Part 27 of 27 in the series The Dataset Frontier

Deduplicating 15 trillion tokens with single-node Python takes 18 months. With Spark across 1,000 nodes, it takes 36 hours. The difference is infrastructure: a single machine hits memory limits around 100M documents, while a distributed system can fingerprint 100B documents using MinHash LSH partitioned across terabytes of RAM. At frontier scale, data processing is not a Python script; it is a distributed systems problem where coordination overhead, stragglers, and fault tolerance determine whether your pipeline finishes in days or never completes.
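A back-of-envelope check on those headline numbers (illustrative arithmetic only, not a benchmark): 540 days of single-node work spread across 1,000 machines would ideally finish in about 13 hours; the quoted 36 hours corresponds to roughly 36% parallel efficiency once coordination overhead and stragglers take their share.

```python
def distributed_wall_hours(single_node_days: float,
                           n_nodes: int,
                           parallel_efficiency: float) -> float:
    """Wall-clock hours when work taking `single_node_days` on
    one machine is spread across n_nodes at the given parallel
    efficiency (1.0 = perfect linear scaling)."""
    return single_node_days * 24 / (n_nodes * parallel_efficiency)

# ~18 months (540 days) on one node, 1000 nodes, perfect scaling: ~13 h
ideal = distributed_wall_hours(540, 1000, 1.0)
# At ~36% parallel efficiency the same job takes ~36 h
realistic = distributed_wall_hours(540, 1000, 0.36)
```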

This post covers the distributed data pipeline architecture for processing pretraining data at 15T-token scale: Spark for ETL, Ray Data for GPU operations, MinHash LSH for deduplication, tokenization throughput, fault tolerance, and cost optimization.

Pipeline Architecture

Overview

from dataclasses import dataclass, field
from enum import Enum
from typing import Optional

class PipelineStage(Enum):
    DOWNLOAD = "download"
    EXTRACT = "extract"
    LANGUAGE_FILTER = "language_filter"
    QUALITY_FILTER = "quality_filter"
    DEDUPLICATION = "deduplication"
    PII_SCRUB = "pii_scrub"
    TOKENIZE = "tokenize"
    SHUFFLE = "shuffle"
    PACKAGE = "package"

@dataclass
class PipelineConfig:
    """Configuration for the data processing pipeline."""
    input_path: str
    output_path: str
    checkpoint_path: str
    n_workers: int = 1000
    n_gpus: int = 0
    target_tokens: int = 15_000_000_000_000
    spark_memory_per_executor: str = "32g"
    ray_object_store_memory: int = 50_000_000_000
    partitions: int = 100_000
    stages: list = field(
        default_factory=lambda: [
            PipelineStage.DOWNLOAD,
            PipelineStage.EXTRACT,
            PipelineStage.LANGUAGE_FILTER,
            PipelineStage.QUALITY_FILTER,
            PipelineStage.DEDUPLICATION,
            PipelineStage.PII_SCRUB,
            PipelineStage.TOKENIZE,
            PipelineStage.SHUFFLE,
            PipelineStage.PACKAGE,
        ]
    )

@dataclass
class StageMetrics:
    """Metrics for a single pipeline stage."""
    stage: PipelineStage
    input_records: int
    output_records: int
    bytes_read: int
    bytes_written: int
    wall_time_s: float
    cpu_hours: float
    gpu_hours: float = 0.0
    cost_usd: float = 0.0

class PipelineOrchestrator:
    """
    Orchestrate the full data processing pipeline.

    Design principles:
    1. Idempotent stages: each stage can be re-run safely
    2. Checkpoint after every stage: resume from failures
    3. Data-parallel: each stage processes partitions
       independently
    4. Framework-appropriate: Spark for CPU stages,
       Ray for GPU stages
    """

    def __init__(self, config):
        self.config = config
        self.metrics = {}

    def run(self):
        """Run all stages sequentially."""
        current_path = self.config.input_path

        for stage in self.config.stages:
            # Check if stage already completed (checkpoint)
            checkpoint = self._load_checkpoint(stage)
            if checkpoint:
                current_path = checkpoint["output_path"]
                self.metrics[stage] = checkpoint["metrics"]
                continue

            # Run stage
            output_path = f"{self.config.output_path}/{stage.value}"
            metrics = self._run_stage(
                stage, current_path, output_path
            )

            # Save checkpoint
            self._save_checkpoint(stage, output_path, metrics)

            self.metrics[stage] = metrics
            current_path = output_path

        return self.metrics

    def _run_stage(self, stage, input_path, output_path):
        """
        Run a single pipeline stage.

        Routes to Spark or Ray based on stage requirements.
        """
        if stage in (
            PipelineStage.DOWNLOAD,
            PipelineStage.EXTRACT,
            PipelineStage.LANGUAGE_FILTER,
            PipelineStage.PII_SCRUB,
            PipelineStage.SHUFFLE,
        ):
            return self._run_spark_stage(
                stage, input_path, output_path
            )
        elif stage in (
            PipelineStage.QUALITY_FILTER,
            PipelineStage.DEDUPLICATION,
            PipelineStage.TOKENIZE,
        ):
            return self._run_ray_stage(
                stage, input_path, output_path
            )
        else:
            return self._run_spark_stage(
                stage, input_path, output_path
            )

    def _run_spark_stage(self, stage, input_path,
                          output_path):
        """Run a Spark-based stage."""
        return StageMetrics(
            stage=stage,
            input_records=0,
            output_records=0,
            bytes_read=0,
            bytes_written=0,
            wall_time_s=0.0,
            cpu_hours=0.0,
        )

    def _run_ray_stage(self, stage, input_path,
                        output_path):
        """Run a Ray-based stage."""
        return StageMetrics(
            stage=stage,
            input_records=0,
            output_records=0,
            bytes_read=0,
            bytes_written=0,
            wall_time_s=0.0,
            cpu_hours=0.0,
        )

    def _load_checkpoint(self, stage):
        """Load checkpoint for a stage if it exists."""
        return None

    def _save_checkpoint(self, stage, output_path, metrics):
        """Save checkpoint after stage completion."""
        pass
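The checkpoint hooks above are left as stubs. A minimal file-based implementation, sketched here with one JSON manifest per stage under `checkpoint_path` (layout and names are illustrative, not from the original), might look like:

```python
import json
import tempfile
from pathlib import Path

def save_checkpoint(checkpoint_path, stage_name, output_path, metrics):
    """Record a completed stage atomically: write a temp file,
    then rename, so a crash mid-write never leaves a partial
    manifest behind."""
    ckpt_dir = Path(checkpoint_path)
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    manifest = {"output_path": output_path, "metrics": metrics}
    tmp = ckpt_dir / f"{stage_name}.json.tmp"
    tmp.write_text(json.dumps(manifest))
    tmp.rename(ckpt_dir / f"{stage_name}.json")

def load_checkpoint(checkpoint_path, stage_name):
    """Return the stage manifest, or None if the stage has not
    completed -- the signal the orchestrator's run() loop checks."""
    path = Path(checkpoint_path) / f"{stage_name}.json"
    if not path.exists():
        return None
    return json.loads(path.read_text())

# Usage: record the "extract" stage, then resume-check it
ckpt_dir = tempfile.mkdtemp()
save_checkpoint(ckpt_dir, "extract", "/out/extract",
                {"input_records": 10})
ck = load_checkpoint(ckpt_dir, "extract")
```

The write-then-rename dance matters at scale: with hundreds of workers racing on shared storage, a half-written manifest must never be mistaken for a completed stage.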

Pipeline Stage Resource Requirements (15T Token Target)

| Stage | Framework | Input Size | Output Size | Wall Time (1000 nodes) | CPU Hours | GPU Hours |
|---|---|---|---|---|---|---|
| Download (Common Crawl) | wget/aria2 | 300 TB (compressed) | 300 TB | 8h | 8,000 | 0 |
| HTML extraction | Spark | 300 TB | 90 TB | 12h | 120,000 | 0 |
| Language filter | Spark | 90 TB | 65 TB | 4h | 40,000 | 0 |
| Quality filter | Ray + GPU | 65 TB | 25 TB | 8h | 20,000 | 8,000 |
| Deduplication (MinHash) | Spark + Ray | 25 TB | 18 TB | 16h | 160,000 | 2,000 |
| PII scrubbing | Spark | 18 TB | 17.5 TB | 6h | 60,000 | 0 |
| Tokenization | Ray + GPU | 17.5 TB | 60 TB (tokens) | 4h | 10,000 | 4,000 |
| Shuffle + package | Spark | 60 TB | 60 TB | 3h | 30,000 | 0 |

Spark for CPU-Bound ETL

HTML Extraction and Language Filtering

class SparkETLPipeline:
    """
    Spark-based ETL for the CPU-bound stages.

    Spark configuration for 1000-node cluster:
    - 1000 executors, 32 cores each = 32,000 total cores
    - 32 GB memory per executor = 32 TB total memory
    - 100,000 partitions for 300 TB input
    - Each partition: ~3 GB raw, processes in ~5 minutes
    """

    def __init__(self, spark_session, config):
        self.spark = spark_session
        self.config = config

    def extract_html(self, input_path, output_path):
        """
        Extract clean text from WARC files (Common Crawl format).

        Uses trafilatura or resiliparse for content extraction.
        Removes boilerplate, navigation, ads, and other
        non-content elements.
        """
        # Read WARC files
        warc_rdd = (
            self.spark.sparkContext
            .binaryFiles(f"{input_path}/*.warc.gz")
        )

        # Extract text from each WARC record
        extracted = warc_rdd.flatMap(
            self._extract_warc_records
        )

        # Write as parquet for downstream processing
        extracted_df = extracted.toDF([
            "url", "text", "language_hint",
            "content_length", "crawl_date",
        ])

        extracted_df.repartition(
            self.config.partitions
        ).write.parquet(
            output_path, mode="overwrite"
        )

        return extracted_df.count()

    def _extract_warc_records(self, warc_pair):
        """
        Extract text from a single WARC file.

        Uses trafilatura for high-quality extraction.
        Falls back to simple tag stripping for speed.
        """
        filename, content = warc_pair
        records = []

        # Parse WARC records (simplified)
        # In production, use warcio library
        try:
            import trafilatura
            # Process each HTML document in the WARC
            text = trafilatura.extract(
                content.decode("utf-8", errors="ignore"),
                include_comments=False,
                include_tables=True,
                no_fallback=False,
            )
            if text and len(text) > 100:
                records.append((
                    filename,
                    text,
                    "",  # language hint
                    len(text),
                    "",  # crawl date
                ))
        except Exception:
            pass

        return records

    def language_filter(self, input_path, output_path,
                         target_languages=None):
        """
        Filter documents by language.

        Uses fasttext language identification model
        (lid.176.bin) which supports 176 languages.
        """
        if target_languages is None:
            target_languages = {"en", "zh", "de", "fr", "es", "ja", "ko"}

        df = self.spark.read.parquet(input_path)

        # UDF for language detection
        from pyspark.sql.functions import udf
        from pyspark.sql.types import StringType

        def detect_language(text):
            """Detect language using fasttext."""
            if not text or len(text) < 20:
                return "unknown"
            # In production: fasttext.load_model() called once
            # per executor using a broadcast variable
            return "en"  # Placeholder

        lang_udf = udf(detect_language, StringType())

        filtered = (
            df.withColumn("detected_language", lang_udf(df.text))
            .filter(
                df.text.isNotNull()
                & (df.content_length > 100)
            )
        )

        # Filter to target languages
        from pyspark.sql.functions import col
        filtered = filtered.filter(
            col("detected_language").isin(target_languages)
        )

        filtered.repartition(
            self.config.partitions
        ).write.parquet(
            output_path, mode="overwrite"
        )

        return filtered.count()
ℹ️ Note

Spark’s lazy evaluation means the full pipeline (read, filter, write) is optimized as a single DAG. No intermediate data is materialized to disk unless explicitly requested. For a 300 TB input, this avoids writing 200+ TB of intermediate results. The tradeoff: if the job fails mid-execution, you restart from the beginning unless you add explicit checkpointing between stages.

MinHash Deduplication at Scale

Near-Duplicate Detection Across 15T Tokens

import hashlib
import struct
import numpy as np

class MinHashDeduplicator:
    """
    MinHash LSH deduplication for near-duplicate detection.

    At 15T token scale (~5 billion documents), pairwise
    comparison is O(n^2) = infeasible. MinHash LSH reduces
    this to approximately O(n) by hashing documents into
    fixed-size signatures and comparing only documents
    that share hash buckets.

    Parameters:
    - n_hashes: number of hash functions (128-256)
    - n_bands: number of LSH bands
    - rows_per_band: n_hashes / n_bands
    - Threshold: documents with Jaccard similarity above
      (1/n_bands)^(1/rows_per_band) are likely candidates

    For n_hashes=128, n_bands=16, rows_per_band=8:
    threshold ~ (1/16)^(1/8) ~ 0.71

    For n_hashes=256, n_bands=32, rows_per_band=8:
    threshold ~ (1/32)^(1/8) ~ 0.65
    """

    MAX_HASH = 2**32 - 1
    MERSENNE_PRIME = (1 << 61) - 1

    def __init__(self, n_hashes=128, n_bands=16, n_grams=5):
        self.n_hashes = n_hashes
        self.n_bands = n_bands
        self.rows_per_band = n_hashes // n_bands
        self.n_grams = n_grams

        # Generate random hash function parameters:
        #   h(x) = ((a*x + b) mod p) & MAX_HASH
        # a and b are drawn below 2^32 so that a*x + b always
        # fits in uint64 (x is a 32-bit n-gram hash) -- with
        # 61-bit coefficients the product would silently
        # overflow int64 in numpy
        rng = np.random.default_rng(42)
        self.hash_a = rng.integers(
            1, 1 << 32, size=n_hashes, dtype=np.uint64
        )
        self.hash_b = rng.integers(
            0, 1 << 32, size=n_hashes, dtype=np.uint64
        )

    def compute_signature(self, text):
        """
        Compute MinHash signature for a document.

        Steps:
        1. Tokenize into n-grams (word-level)
        2. Hash each n-gram
        3. For each hash function, take the minimum
           hash value across all n-grams
        """
        # Generate n-grams
        words = text.lower().split()
        if len(words) < self.n_grams:
            return None

        ngrams = set()
        for i in range(len(words) - self.n_grams + 1):
            ngram = " ".join(words[i:i + self.n_grams])
            ngrams.add(ngram)

        if not ngrams:
            return None

        # Hash each n-gram to a 32-bit integer
        ngram_hashes = np.array([
            struct.unpack(
                "<I",
                hashlib.md5(
                    ng.encode()
                ).digest()[:4],
            )[0]
            for ng in ngrams
        ], dtype=np.uint64)

        # Compute MinHash signature
        # For each hash function h_i, signature[i] =
        #   min over all n-grams of h_i(ngram)
        signature = np.full(
            self.n_hashes, self.MAX_HASH, dtype=np.uint64
        )

        for ng_hash in ngram_hashes:
            hashed = np.bitwise_and(
                (self.hash_a * ng_hash + self.hash_b)
                % self.MERSENNE_PRIME,
                self.MAX_HASH,
            )
            signature = np.minimum(signature, hashed)

        return signature

    def compute_lsh_buckets(self, signature):
        """
        Compute LSH bucket keys from a MinHash signature.

        The signature is divided into n_bands bands.
        Each band is hashed to produce a bucket key.
        Two documents that share a bucket in any band
        are candidate duplicates.
        """
        buckets = []

        for band_idx in range(self.n_bands):
            start = band_idx * self.rows_per_band
            end = start + self.rows_per_band
            band = signature[start:end]

            # Hash the band to a bucket key
            band_bytes = band.tobytes()
            bucket_key = hashlib.md5(
                band_bytes
            ).hexdigest()[:16]

            buckets.append((band_idx, bucket_key))

        return buckets

    def estimate_jaccard(self, sig_a, sig_b):
        """
        Estimate Jaccard similarity from MinHash signatures.

        J(A, B) ~ (number of positions where sig_a == sig_b)
                   / n_hashes
        """
        if sig_a is None or sig_b is None:
            return 0.0
        matches = np.sum(sig_a == sig_b)
        return float(matches) / self.n_hashes
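The thresholds in the docstring come from the LSH S-curve: a pair with Jaccard similarity s collides in at least one of b bands of r rows with probability 1 - (1 - s^r)^b, which rises sharply near (1/b)^(1/r). A quick numeric check:

```python
def lsh_collision_prob(s: float, n_bands: int,
                       rows_per_band: int) -> float:
    """Probability that a pair with Jaccard similarity s shares
    at least one LSH bucket: 1 - (1 - s^r)^b."""
    return 1.0 - (1.0 - s ** rows_per_band) ** n_bands

def lsh_threshold(n_bands: int, rows_per_band: int) -> float:
    """Approximate similarity at which the S-curve is steepest."""
    return (1.0 / n_bands) ** (1.0 / rows_per_band)

b, r = 16, 8  # 128 hashes split into 16 bands of 8 rows
t = lsh_threshold(b, r)               # ~0.71
high = lsh_collision_prob(0.9, b, r)  # near-certain candidate
low = lsh_collision_prob(0.3, b, r)   # almost never a candidate
```

Adding bands catches more true near-duplicates at a given similarity but also inflates the candidate-pair count, which is exactly the shuffle-volume tradeoff discussed below.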

Distributed Dedup with Spark

class DistributedDeduplication:
    """
    Distributed deduplication using Spark.

    Architecture:
    1. Compute MinHash signatures (map phase)
    2. Compute LSH buckets (map phase)
    3. Group by bucket (shuffle phase)
    4. Find connected components (graph phase)
    5. Keep one document per component (filter phase)

    The shuffle phase is the bottleneck: with 5 billion
    documents and 16 bands, we produce 80 billion
    (band, bucket, doc_id) tuples. At ~50 bytes each,
    that is ~4 TB of shuffle data.
    """

    def __init__(self, spark_session, minhash_config):
        self.spark = spark_session
        self.deduplicator = MinHashDeduplicator(
            n_hashes=minhash_config.get("n_hashes", 128),
            n_bands=minhash_config.get("n_bands", 16),
            n_grams=minhash_config.get("n_grams", 5),
        )

    def run_dedup(self, input_path, output_path):
        """
        Run full deduplication pipeline.
        """
        # Step 1: Compute signatures
        docs_rdd = self.spark.sparkContext.textFile(
            input_path
        ).zipWithIndex()

        signatures_rdd = docs_rdd.map(
            lambda x: (
                x[1],  # doc_id
                self.deduplicator.compute_signature(x[0]),
            )
        ).filter(lambda x: x[1] is not None)

        # Step 2: Compute LSH buckets
        bucket_rdd = signatures_rdd.flatMap(
            lambda x: [
                ((band_idx, bucket_key), x[0])
                for band_idx, bucket_key
                in self.deduplicator.compute_lsh_buckets(x[1])
            ]
        )

        # Step 3: Group by bucket
        candidate_pairs = (
            bucket_rdd
            .groupByKey()
            .flatMap(self._generate_pairs)
            .distinct()
        )

        # Step 4: Find connected components (Union-Find)
        # This identifies clusters of near-duplicates
        components = self._find_connected_components(
            candidate_pairs
        )

        # Step 5: Keep one document per component.
        # Documents that never appear in any candidate pair
        # are unique, so we compute the (much smaller) set of
        # documents to REMOVE rather than the set to keep --
        # filtering on a keep-set would drop every unique doc.
        representatives = components.map(
            lambda x: (x[1], x[0])  # (component_id, doc_id)
        ).reduceByKey(min)  # lowest doc_id represents each cluster

        docs_to_remove = components.map(
            lambda x: (x[1], x[0])
        ).join(representatives).filter(
            lambda x: x[1][0] != x[1][1]  # not the representative
        ).map(lambda x: x[1][0])

        remove_set = self.spark.sparkContext.broadcast(
            set(docs_to_remove.collect())
        )

        # Filter original dataset
        deduped = docs_rdd.filter(
            lambda x: x[1] not in remove_set.value
        )

        return deduped.count()

    def _generate_pairs(self, bucket_group):
        """
        Generate candidate pairs from a bucket.

        Cap at 100 pairs per bucket to avoid quadratic
        explosion in large buckets.
        """
        _key, doc_ids = bucket_group  # _key = (band_idx, bucket_key)
        doc_ids = list(doc_ids)

        if len(doc_ids) > 100:
            # Large bucket: sample pairs
            import random
            pairs = set()
            for _ in range(100):
                i, j = random.sample(range(len(doc_ids)), 2)
                pairs.add(
                    (min(doc_ids[i], doc_ids[j]),
                     max(doc_ids[i], doc_ids[j]))
                )
            return list(pairs)

        # Small bucket: all pairs
        pairs = []
        for i in range(len(doc_ids)):
            for j in range(i + 1, len(doc_ids)):
                pairs.append(
                    (doc_ids[i], doc_ids[j])
                )
        return pairs

    def _find_connected_components(self, pairs_rdd):
        """
        Find connected components using iterative
        label propagation on Spark.
        """
        # Initialize: each node is its own component
        nodes = pairs_rdd.flatMap(
            lambda x: [x[0], x[1]]
        ).distinct().map(lambda x: (x, x))

        for iteration in range(20):
            # Propagate minimum label
            edges = pairs_rdd.flatMap(
                lambda x: [(x[0], x[1]), (x[1], x[0])]
            )

            new_labels = edges.join(nodes).map(
                lambda x: (x[1][0], x[1][1])
            ).reduceByKey(min)

            combined = nodes.union(new_labels).reduceByKey(min)

            # Check convergence
            changed = combined.join(nodes).filter(
                lambda x: x[1][0] != x[1][1]
            ).count()

            nodes = combined

            if changed == 0:
                break

        return nodes
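On small inputs, the label-propagation result can be cross-checked against a plain single-machine union-find (a sketch using the same minimum-doc_id representative convention as the Spark version):

```python
def connected_components(pairs):
    """Union-find over candidate-duplicate pairs. Returns a dict
    mapping each doc_id to its component representative, which
    is the minimum doc_id in the component."""
    parent = {}

    def find(x):
        parent.setdefault(x, x)
        # Path halving keeps trees shallow
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    def union(a, b):
        ra, rb = find(a), find(b)
        if ra == rb:
            return
        if ra > rb:          # always root at the smaller id so the
            ra, rb = rb, ra  # representative is min(doc_id)
        parent[rb] = ra

    for a, b in pairs:
        union(a, b)
    return {x: find(x) for x in parent}

# Two clusters: {1, 2, 3, 9} and {7, 8}
cc = connected_components([(1, 2), (2, 3), (7, 8), (9, 3)])
```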

Deduplication Performance: Documents Processed per Hour (millions), by Cluster Size

| Method | 10 nodes | 50 nodes | 100 nodes | 250 nodes | 500 nodes | 1000 nodes |
|---|---|---|---|---|---|---|
| MinHash LSH (n=128, bands=16) | 5 | 24 | 48 | 115 | 220 | 400 |
| Exact dedup (hash-based) | 15 | 70 | 140 | 340 | 650 | 1,200 |
| Embedding-based SimHash | 2 | 9 | 18 | 42 | 80 | 150 |
⚠️ Warning

The LSH shuffle phase is the memory and network bottleneck. With 5 billion documents and 16 bands, the shuffle writes approximately 4 TB of data across the cluster. Insufficient shuffle partitions cause out-of-memory errors on individual executors. Set spark.sql.shuffle.partitions to at least 200,000 for datasets above 1 billion documents. Also set spark.shuffle.compress=true and spark.shuffle.spill.compress=true.
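That sizing arithmetic is worth automating before a launch. A small helper (the 50-byte tuple estimate comes from the class docstring above; the ~20 MB partition target is what 4 TB spread over 200,000 partitions implies):

```python
def lsh_shuffle_sizing(n_docs: int, n_bands: int,
                       bytes_per_tuple: int = 50,
                       target_partition_bytes: int = 20 * 1024**2):
    """Estimate LSH shuffle volume (one (band, bucket, doc_id)
    tuple per document per band) and a partition count that keeps
    each shuffle partition near the target size."""
    total_bytes = n_docs * n_bands * bytes_per_tuple
    partitions = -(-total_bytes // target_partition_bytes)  # ceil div
    return total_bytes, partitions

# 5B documents x 16 bands x ~50 B = 4 TB of shuffle data,
# ~190K partitions at ~20 MB each
total, parts = lsh_shuffle_sizing(5_000_000_000, 16)
```

Feed the computed partition count into `spark.sql.shuffle.partitions` rather than hard-coding a number that silently goes stale as the corpus grows.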

Ray Data for GPU Operations

Quality Filtering with GPU Models

class RayQualityFilterPipeline:
    """
    Ray Data pipeline for GPU-accelerated quality filtering.

    Uses a quality classifier (fine-tuned BERT or similar)
    to score each document. GPU inference is 10-50x faster
    than CPU for the classifier forward pass.

    Ray Data advantages over Spark for this stage:
    - Native GPU scheduling
    - Zero-copy data transfer to GPU
    - Streaming execution (no full materialization)
    """

    def __init__(self, ray_config):
        self.model_path = ray_config["model_path"]
        self.batch_size = ray_config.get("batch_size", 256)
        self.quality_threshold = ray_config.get(
            "threshold", 0.5
        )
        self.n_gpus = ray_config.get("n_gpus", 100)

    def run_quality_filter(self, input_path, output_path):
        """
        Filter documents by quality using GPU classifier.
        """
        import ray

        # Read data as Ray Dataset
        ds = ray.data.read_parquet(input_path)

        # Apply quality scoring using GPU
        scored = ds.map_batches(
            QualityScorer,
            fn_constructor_kwargs={
                "model_path": self.model_path,
                "batch_size": self.batch_size,
            },
            num_gpus=1,
            concurrency=self.n_gpus,
            batch_size=self.batch_size,
        )

        # Filter by quality threshold
        filtered = scored.filter(
            lambda row: row["quality_score"]
            > self.quality_threshold
        )

        # Write output
        filtered.write_parquet(output_path)

        return filtered.count()

class QualityScorer:
    """
    Ray actor for GPU-based quality scoring.

    Loaded once per GPU, reused across batches.
    """

    def __init__(self, model_path, batch_size=256):
        import torch
        from transformers import AutoModelForSequenceClassification
        from transformers import AutoTokenizer

        self.device = torch.device("cuda")
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path
        )
        self.model = (
            AutoModelForSequenceClassification
            .from_pretrained(model_path)
            .to(self.device)
            .eval()
        )
        self.batch_size = batch_size

    def __call__(self, batch):
        """Score a batch of documents."""
        import torch

        texts = batch["text"].tolist()
        scores = []

        for i in range(0, len(texts), self.batch_size):
            chunk = texts[i:i + self.batch_size]

            # Truncate to model max length
            inputs = self.tokenizer(
                chunk,
                max_length=512,
                truncation=True,
                padding=True,
                return_tensors="pt",
            ).to(self.device)

            with torch.no_grad():
                outputs = self.model(**inputs)
                probs = torch.softmax(
                    outputs.logits, dim=-1
                )
                # Column 1 = "high quality" probability
                quality_scores = (
                    probs[:, 1].cpu().numpy().tolist()
                )
                scores.extend(quality_scores)

        batch["quality_score"] = scores
        return batch

Tokenization at Scale

High-Throughput Tokenization

class DistributedTokenizer:
    """
    Tokenize 18 TB of text into training-ready token sequences.

    Tokenization throughput requirements:
    - 18 TB of text at ~4 bytes per character = 4.5T characters
    - BPE tokenization: ~1M tokens/second per CPU core
    - 32,000 cores: ~32B tokens/second
    - 15T tokens / 32B tokens/s = ~470 seconds = ~8 minutes
    - In practice: 2-4 hours due to I/O, serialization,
      and sequence packing overhead

    Output format: packed token sequences of fixed length
    (4096, 8192, or 32768 tokens) for efficient training.
    """

    def __init__(self, tokenizer_path, seq_length=4096):
        self.tokenizer_path = tokenizer_path
        self.seq_length = seq_length

    def tokenize_with_ray(self, input_path, output_path,
                           n_workers=1000):
        """
        Tokenize using Ray for parallel processing.
        """
        import ray

        ds = ray.data.read_parquet(input_path)

        tokenized = ds.map_batches(
            TokenizeBatch,
            fn_constructor_kwargs={
                "tokenizer_path": self.tokenizer_path,
                "seq_length": self.seq_length,
            },
            concurrency=n_workers,
            batch_size=1000,
        )

        tokenized.write_parquet(output_path)

        return tokenized.count()

class TokenizeBatch:
    """
    Tokenize a batch of documents and pack into
    fixed-length sequences.

    Packing strategy:
    - Concatenate documents with [EOS] separator
    - Split into fixed-length sequences
    - Pad the last sequence or discard if too short
    - No document-level padding (wastes training compute)
    """

    def __init__(self, tokenizer_path, seq_length=4096):
        from tokenizers import Tokenizer
        self.tokenizer = Tokenizer.from_file(tokenizer_path)
        self.seq_length = seq_length
        eos_id = self.tokenizer.token_to_id("</s>")
        # token_to_id returns None for unknown tokens; `or 2`
        # would also clobber a legitimate id of 0, so test
        # explicitly against None
        self.eos_token_id = eos_id if eos_id is not None else 2
        self.buffer = []

    def __call__(self, batch):
        """Tokenize and pack a batch."""
        texts = batch["text"].tolist()
        packed_sequences = []

        for text in texts:
            encoded = self.tokenizer.encode(text)
            tokens = encoded.ids

            # Add EOS token
            tokens.append(self.eos_token_id)

            self.buffer.extend(tokens)

            # Pack into sequences
            while len(self.buffer) >= self.seq_length:
                seq = self.buffer[:self.seq_length]
                packed_sequences.append(seq)
                self.buffer = self.buffer[self.seq_length:]

        return {
            "input_ids": packed_sequences,
            "length": [len(s) for s in packed_sequences],
        }
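The concatenate-and-split logic in `__call__` is easiest to verify in isolation with pre-tokenized toy documents (sequence length and EOS id here are arbitrary illustrative values):

```python
def pack_documents(token_docs, seq_length, eos_id):
    """Concatenate per-document token lists with an EOS separator
    and split into fixed-length sequences. Returns the packed
    sequences plus the leftover buffer tail (carried into the
    next batch in the stateful actor above)."""
    buffer, sequences = [], []
    for tokens in token_docs:
        buffer.extend(tokens)
        buffer.append(eos_id)
        while len(buffer) >= seq_length:
            sequences.append(buffer[:seq_length])
            buffer = buffer[seq_length:]
    return sequences, buffer

docs = [[5, 6, 7], [8, 9], [10, 11, 12, 13]]
seqs, tail = pack_documents(docs, seq_length=8, eos_id=2)
# seqs == [[5, 6, 7, 2, 8, 9, 2, 10]]; tail == [11, 12, 13, 2]
```

Note the third document straddles a sequence boundary: packing deliberately lets documents span sequences so no training tokens are wasted on padding.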

Tokenization Throughput by Method

| Method | Tokens/sec (1 core) | Tokens/sec (1000 nodes × 32 cores) | Compute Time for 15T Tokens | Memory per Worker |
|---|---|---|---|---|
| HuggingFace tokenizers (Rust) | 2.5M | 80B | 3.1 min | 200 MB |
| SentencePiece (C++) | 1.5M | 48B | 5.2 min | 150 MB |
| tiktoken (Rust/Python) | 3.0M | 96B | 2.6 min | 180 MB |
| Custom BPE (Python) | 0.3M | 9.6B | 26 min | 100 MB |
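The fleet-level columns follow directly from the per-core numbers, assuming the 32-cores-per-node configuration used throughout this post:

```python
def fleet_tokens_per_sec(per_core_tps: float,
                         n_nodes: int = 1000,
                         cores_per_node: int = 32) -> float:
    """Aggregate tokenization throughput, assuming linear
    scaling across cores (I/O contention breaks this first)."""
    return per_core_tps * n_nodes * cores_per_node

def ideal_minutes(total_tokens: float, fleet_tps: float) -> float:
    """Pure-compute time to tokenize total_tokens at fleet_tps."""
    return total_tokens / fleet_tps / 60

hf = fleet_tokens_per_sec(2_500_000)  # 80B tokens/sec
mins = ideal_minutes(15e12, hf)       # ~3.1 minutes of pure compute
```

The gap between minutes of compute and the hours of observed wall time is entirely I/O, serialization, and packing overhead, which is why the takeaways call tokenization I/O-bound.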

Fault Tolerance and Cost

Handling Failures at Scale

class FaultTolerantPipeline:
    """
    Fault tolerance for 1000-node processing clusters.

    At 1000 nodes, node failures are not exceptions --
    they are expected. Mean time between failures (MTBF)
    for a 1000-node cluster is approximately:

    MTBF_cluster = MTBF_node / n_nodes
    If MTBF_node = 30 days: MTBF_cluster = 43 minutes

    A 1000-node cluster will lose a node approximately
    every 43 minutes. The pipeline must handle this.
    """

    def __init__(self, config):
        self.checkpoint_interval_s = config.get(
            "checkpoint_interval_s", 1800
        )
        self.max_retries = config.get("max_retries", 3)
        self.speculative_execution = config.get(
            "speculative_execution", True
        )

    def run_with_checkpointing(self, stage_fn,
                                 input_partitions,
                                 output_path):
        """
        Run a processing stage with partition-level
        checkpointing and automatic retry.
        """
        completed = self._load_completed_partitions(
            output_path
        )
        remaining = [
            p for p in input_partitions
            if p not in completed
        ]

        failed_partitions = []

        for partition in remaining:
            success = False
            for attempt in range(self.max_retries):
                try:
                    result = stage_fn(partition)
                    self._save_partition_result(
                        output_path, partition, result
                    )
                    success = True
                    break
                except Exception as e:
                    if attempt == self.max_retries - 1:
                        failed_partitions.append(
                            (partition, str(e))
                        )

        return {
            "completed": (
                len(input_partitions)
                - len(failed_partitions)
            ),
            "failed": len(failed_partitions),
            "failure_rate": (
                len(failed_partitions)
                / len(input_partitions)
            ),
        }

    def _load_completed_partitions(self, output_path):
        """Load set of already-completed partitions."""
        return set()

    def _save_partition_result(self, output_path,
                                partition, result):
        """Save a single partition result."""
        pass
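The MTBF arithmetic from the docstring, generalized to the expected number of node failures over a full pipeline run:

```python
def cluster_mtbf_minutes(node_mtbf_days: float,
                         n_nodes: int) -> float:
    """Mean time between failures anywhere in the cluster:
    MTBF_cluster = MTBF_node / n_nodes."""
    return node_mtbf_days * 24 * 60 / n_nodes

def expected_failures(node_mtbf_days: float, n_nodes: int,
                      run_hours: float) -> float:
    """Expected number of node failures during a run, assuming
    independent failures at a constant rate."""
    return n_nodes * run_hours / (node_mtbf_days * 24)

# 30-day node MTBF, 1000 nodes: a failure every ~43 minutes,
# and ~70 failures over a 50-hour pipeline run
mtbf = cluster_mtbf_minutes(30, 1000)
fails = expected_failures(30, 1000, 50)
```

Seventy-odd failures per run is why partition-level retry is table stakes: any design that restarts a whole stage on a single lost node never finishes.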

Pipeline Cost per Stage (USD): Spark on AWS EMR vs Ray on GPU Instances

| Cluster | Download | Extract | Lang Filter | Quality Filter | Dedup | PII Scrub | Tokenize | Shuffle | Total |
|---|---|---|---|---|---|---|---|---|---|
| Spark on r5.8xlarge (CPU) | $800 | $12,000 | $4,000 | $20,000 | $16,000 | $6,000 | $10,000 | $3,000 | $71,800 |
| Ray on p4d.24xlarge (GPU) | $800 | $12,000 | $4,000 | $5,000 | $8,000 | $6,000 | $3,000 | $3,000 | $41,800 |

Key Takeaways

Processing 15 trillion tokens requires a distributed pipeline that handles 300+ TB of raw data across 1000 nodes. The pipeline must be fault-tolerant (nodes fail every 43 minutes at this scale), checkpointed (resume from any failure), and cost-optimized (total cost of $40-70K per run).

The critical engineering decisions:

  1. Use the right framework for each stage: Spark excels at CPU-bound ETL (extraction, language filtering, PII scrubbing). Ray Data excels at GPU-accelerated operations (quality classification, embedding computation). Using both in the same pipeline, with Parquet as the interchange format, captures the strengths of each.

  2. MinHash LSH is the only viable dedup at this scale: Exact dedup (hash-based) catches identical documents but misses near-duplicates (same content, different formatting). Embedding-based dedup is too slow. MinHash LSH with 128 hashes and 16 bands catches documents above roughly 70% Jaccard similarity ((1/16)^(1/8) ~ 0.71) at approximately O(n) cost.

  3. The shuffle is the bottleneck in dedup: With 5 billion documents and 16 LSH bands, the shuffle phase produces 4 TB of intermediate data. Set shuffle partitions to 200,000+, enable compression, and use SSD-backed shuffle disks.

  4. Tokenization is I/O-bound, not compute-bound: Rust-based tokenizers (HuggingFace tokenizers, tiktoken) process 2-3M tokens/second per core. At 1000 nodes with 32 cores each, tokenization completes in 2-5 hours. The bottleneck is reading input data and writing packed sequences, not the tokenization itself.

  5. Checkpoint after every stage: At 1000 nodes, a full pipeline run takes 40-60 hours. Without checkpointing, a single failure in the final stage forces a restart from the beginning. Stage-level checkpointing with idempotent stages allows resume from the last successful stage.