Deduplicating 15 trillion tokens with single-node Python takes 18 months. With Spark across 1,000 nodes, it takes 36 hours. The difference is infrastructure: single machines hit memory limits at 100M documents, while distributed systems can fingerprint billions of documents using MinHash LSH with signatures partitioned across terabytes of cluster memory. At frontier scale, data processing is not a Python script; it is a distributed systems problem where coordination overhead, stragglers, and fault tolerance determine whether your pipeline finishes in days or never completes.
This post covers the distributed data pipeline architecture for processing pretraining data at 15T-token scale: Spark for ETL, Ray Data for GPU operations, MinHash LSH for deduplication, tokenization throughput, fault tolerance, and cost optimization.
Pipeline Architecture
Overview
from dataclasses import dataclass, field
from enum import Enum
class PipelineStage(Enum):
DOWNLOAD = "download"
EXTRACT = "extract"
LANGUAGE_FILTER = "language_filter"
QUALITY_FILTER = "quality_filter"
DEDUPLICATION = "deduplication"
PII_SCRUB = "pii_scrub"
TOKENIZE = "tokenize"
SHUFFLE = "shuffle"
PACKAGE = "package"
@dataclass
class PipelineConfig:
"""Configuration for the data processing pipeline."""
input_path: str
output_path: str
checkpoint_path: str
n_workers: int = 1000
n_gpus: int = 0
target_tokens: int = 15_000_000_000_000
spark_memory_per_executor: str = "32g"
ray_object_store_memory: int = 50_000_000_000
partitions: int = 100_000
stages: list = field(
default_factory=lambda: [
PipelineStage.DOWNLOAD,
PipelineStage.EXTRACT,
PipelineStage.LANGUAGE_FILTER,
PipelineStage.QUALITY_FILTER,
PipelineStage.DEDUPLICATION,
PipelineStage.PII_SCRUB,
PipelineStage.TOKENIZE,
PipelineStage.SHUFFLE,
PipelineStage.PACKAGE,
]
)
@dataclass
class StageMetrics:
"""Metrics for a single pipeline stage."""
stage: PipelineStage
input_records: int
output_records: int
bytes_read: int
bytes_written: int
wall_time_s: float
cpu_hours: float
gpu_hours: float = 0.0
cost_usd: float = 0.0
class PipelineOrchestrator:
"""
Orchestrate the full data processing pipeline.
Design principles:
1. Idempotent stages: each stage can be re-run safely
2. Checkpoint after every stage: resume from failures
3. Data-parallel: each stage processes partitions
independently
4. Framework-appropriate: Spark for CPU stages,
Ray for GPU stages
"""
def __init__(self, config):
self.config = config
self.metrics = {}
def run(self):
"""Run all stages sequentially."""
current_path = self.config.input_path
for stage in self.config.stages:
# Check if stage already completed (checkpoint)
checkpoint = self._load_checkpoint(stage)
if checkpoint:
current_path = checkpoint["output_path"]
self.metrics[stage] = checkpoint["metrics"]
continue
# Run stage
output_path = f"{self.config.output_path}/{stage.value}"
metrics = self._run_stage(
stage, current_path, output_path
)
# Save checkpoint
self._save_checkpoint(stage, output_path, metrics)
self.metrics[stage] = metrics
current_path = output_path
return self.metrics
def _run_stage(self, stage, input_path, output_path):
"""
Run a single pipeline stage.
Routes to Spark or Ray based on stage requirements.
"""
if stage in (
PipelineStage.DOWNLOAD,
PipelineStage.EXTRACT,
PipelineStage.LANGUAGE_FILTER,
PipelineStage.PII_SCRUB,
PipelineStage.SHUFFLE,
):
return self._run_spark_stage(
stage, input_path, output_path
)
elif stage in (
PipelineStage.QUALITY_FILTER,
PipelineStage.DEDUPLICATION,
PipelineStage.TOKENIZE,
):
return self._run_ray_stage(
stage, input_path, output_path
)
else:
return self._run_spark_stage(
stage, input_path, output_path
)
def _run_spark_stage(self, stage, input_path,
output_path):
"""Run a Spark-based stage."""
return StageMetrics(
stage=stage,
input_records=0,
output_records=0,
bytes_read=0,
bytes_written=0,
wall_time_s=0.0,
cpu_hours=0.0,
)
def _run_ray_stage(self, stage, input_path,
output_path):
"""Run a Ray-based stage."""
return StageMetrics(
stage=stage,
input_records=0,
output_records=0,
bytes_read=0,
bytes_written=0,
wall_time_s=0.0,
cpu_hours=0.0,
)
def _load_checkpoint(self, stage):
"""Load checkpoint for a stage if it exists."""
return None
def _save_checkpoint(self, stage, output_path, metrics):
"""Save checkpoint after stage completion."""
pass
Pipeline Stage Resource Requirements (15T Token Target)
| Stage | Framework | Input Size | Output Size | Wall Time (1000 nodes) | CPU Hours | GPU Hours |
|---|---|---|---|---|---|---|
| Download (Common Crawl) | wget/aria2 | 300 TB (compressed) | 300 TB | 8h | 8,000 | 0 |
| HTML extraction | Spark | 300 TB | 90 TB | 12h | 120,000 | 0 |
| Language filter | Spark | 90 TB | 65 TB | 4h | 40,000 | 0 |
| Quality filter | Ray + GPU | 65 TB | 25 TB | 8h | 20,000 | 8,000 |
| Deduplication (MinHash) | Spark + Ray | 25 TB | 18 TB | 16h | 160,000 | 2,000 |
| PII scrubbing | Spark | 18 TB | 17.5 TB | 6h | 60,000 | 0 |
| Tokenization | Ray + GPU | 17.5 TB | 60 TB (tokens) | 4h | 10,000 | 4,000 |
| Shuffle + package | Spark | 60 TB | 60 TB | 3h | 30,000 | 0 |
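Summing the per-stage figures in the table gives the end-to-end budget. A quick sanity check (the numbers below are copied directly from the table):

```python
# Per-stage (wall_hours, cpu_hours, gpu_hours) from the table above.
stages = {
    "download": (8, 8_000, 0),
    "extract": (12, 120_000, 0),
    "language_filter": (4, 40_000, 0),
    "quality_filter": (8, 20_000, 8_000),
    "dedup": (16, 160_000, 2_000),
    "pii_scrub": (6, 60_000, 0),
    "tokenize": (4, 10_000, 4_000),
    "shuffle_package": (3, 30_000, 0),
}

wall_hours = sum(w for w, _, _ in stages.values())
cpu_hours = sum(c for _, c, _ in stages.values())
gpu_hours = sum(g for _, _, g in stages.values())

print(wall_hours, cpu_hours, gpu_hours)  # 61 448000 14000
```

Roughly 61 hours of wall time if every stage runs back to back, which is why the fault-tolerance section below treats multi-day runs as the norm.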
Spark for CPU-Bound ETL
HTML Extraction and Language Filtering
class SparkETLPipeline:
"""
Spark-based ETL for the CPU-bound stages.
Spark configuration for 1000-node cluster:
- 1000 executors, 32 cores each = 32,000 total cores
- 32 GB memory per executor = 32 TB total memory
- 100,000 partitions for 300 TB input
- Each partition: ~3 GB raw, processes in ~5 minutes
"""
def __init__(self, spark_session, config):
self.spark = spark_session
self.config = config
def extract_html(self, input_path, output_path):
"""
Extract clean text from WARC files (Common Crawl format).
Uses trafilatura or resiliparse for content extraction.
Removes boilerplate, navigation, ads, and other
non-content elements.
"""
# Read WARC files
warc_rdd = (
self.spark.sparkContext
.binaryFiles(f"{input_path}/*.warc.gz")
)
# Extract text from each WARC record
extracted = warc_rdd.flatMap(
self._extract_warc_records
)
# Write as parquet for downstream processing
extracted_df = extracted.toDF([
"url", "text", "language_hint",
"content_length", "crawl_date",
])
extracted_df.repartition(
self.config.partitions
).write.parquet(
output_path, mode="overwrite"
)
return extracted_df.count()
def _extract_warc_records(self, warc_pair):
"""
Extract text from a single WARC file.
Uses trafilatura for high-quality extraction.
Falls back to simple tag stripping for speed.
"""
filename, content = warc_pair
records = []
# Parse WARC records (simplified)
# In production, use warcio library
try:
import trafilatura
# Process each HTML document in the WARC
text = trafilatura.extract(
content.decode("utf-8", errors="ignore"),
include_comments=False,
include_tables=True,
no_fallback=False,
)
if text and len(text) > 100:
records.append((
filename,
text,
"", # language hint
len(text),
"", # crawl date
))
except Exception:
pass
return records
def language_filter(self, input_path, output_path,
target_languages=None):
"""
Filter documents by language.
Uses fasttext language identification model
(lid.176.bin) which supports 176 languages.
"""
if target_languages is None:
target_languages = {"en", "zh", "de", "fr", "es", "ja", "ko"}
df = self.spark.read.parquet(input_path)
# UDF for language detection
from pyspark.sql.functions import udf
from pyspark.sql.types import StringType
def detect_language(text):
"""Detect language using fasttext."""
if not text or len(text) < 20:
return "unknown"
# In production: fasttext.load_model() called once
# per executor using a broadcast variable
return "en" # Placeholder
lang_udf = udf(detect_language, StringType())
filtered = (
df.withColumn("detected_language", lang_udf(df.text))
.filter(
df.text.isNotNull()
& (df.content_length > 100)
)
)
# Filter to target languages
from pyspark.sql.functions import col
filtered = filtered.filter(
col("detected_language").isin(target_languages)
)
filtered.repartition(
self.config.partitions
).write.parquet(
output_path, mode="overwrite"
)
return filtered.count()
Spark’s lazy evaluation means the full pipeline (read, filter, write) is optimized as a single DAG. No intermediate data is materialized to disk unless explicitly requested. For a 300 TB input, this avoids writing 200+ TB of intermediate results. The tradeoff: if the job fails mid-execution, you restart from the beginning unless you add explicit checkpointing between stages.
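Lazy evaluation is easy to see in miniature with Python generators (a toy analogy, not Spark code): each stage returns a plan, and nothing executes or materializes until a terminal action consumes it.

```python
# Generator-based pipeline: stages compose without running.
def read(records):
    for r in records:
        yield r

def language_filter(stream):
    # Filter transformation: returns a new lazy stream.
    return (r for r in stream if r["lang"] == "en")

def project(stream):
    # Projection transformation: also lazy.
    return (r["text"] for r in stream)

records = [
    {"lang": "en", "text": "keep me"},
    {"lang": "de", "text": "dropped"},
    {"lang": "en", "text": "keep me too"},
]

plan = project(language_filter(read(records)))  # nothing has run yet
print(list(plan))  # ['keep me', 'keep me too'] -- the "action"
```

Like Spark, no intermediate list is ever built; the tradeoff is the same too, since a failure mid-consumption means re-running from the source.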
MinHash Deduplication at Scale
Near-Duplicate Detection Across 15T Tokens
import hashlib
import struct
import numpy as np
class MinHashDeduplicator:
"""
MinHash LSH deduplication for near-duplicate detection.
At 15T token scale (~5 billion documents), pairwise
comparison is O(n^2) = infeasible. MinHash LSH reduces
this to approximately O(n) by hashing documents into
fixed-size signatures and comparing only documents
that share hash buckets.
Parameters:
- n_hashes: number of hash functions (128-256)
- n_bands: number of LSH bands
- rows_per_band: n_hashes / n_bands
- Threshold: documents with Jaccard similarity above
(1/n_bands)^(1/rows_per_band) are likely candidates
For n_hashes=128, n_bands=16, rows_per_band=8:
threshold ~ (1/16)^(1/8) ~ 0.71
For n_hashes=256, n_bands=32, rows_per_band=8:
threshold ~ (1/32)^(1/8) ~ 0.65
"""
MAX_HASH = 2**32 - 1
MERSENNE_PRIME = (1 << 61) - 1
def __init__(self, n_hashes=128, n_bands=16, n_grams=5):
self.n_hashes = n_hashes
self.n_bands = n_bands
self.rows_per_band = n_hashes // n_bands
self.n_grams = n_grams
# Generate random hash function parameters
# h(x) = (ax + b) mod p mod MAX_HASH
# a and b are drawn below 2**32 (not up to the Mersenne
# prime): for a 32-bit n-gram hash x, a*x + b then stays
# below 2**64, so uint64 arithmetic never overflows.
# With int64 draws up to 2**61, the multiply silently
# overflows and corrupts signatures.
np.random.seed(42)
self.hash_a = np.random.randint(
1, self.MAX_HASH, size=n_hashes,
dtype=np.uint64,
)
self.hash_b = np.random.randint(
0, self.MAX_HASH, size=n_hashes,
dtype=np.uint64,
)
def compute_signature(self, text):
"""
Compute MinHash signature for a document.
Steps:
1. Tokenize into n-grams (word-level)
2. Hash each n-gram
3. For each hash function, take the minimum
hash value across all n-grams
"""
# Generate n-grams
words = text.lower().split()
if len(words) < self.n_grams:
return None
ngrams = set()
for i in range(len(words) - self.n_grams + 1):
ngram = " ".join(words[i:i + self.n_grams])
ngrams.add(ngram)
if not ngrams:
return None
# Hash each n-gram to a 32-bit integer; uint64 dtype so
# the arithmetic with hash_a / hash_b stays in uint64
ngram_hashes = np.array([
struct.unpack(
"<I",
hashlib.md5(
ng.encode()
).digest()[:4],
)[0]
for ng in ngrams
], dtype=np.uint64)
# Compute MinHash signature
# For each hash function h_i, signature[i] =
# min over all n-grams of h_i(ngram)
signature = np.full(
self.n_hashes, self.MAX_HASH, dtype=np.uint64
)
for ng_hash in ngram_hashes:
hashed = (
(self.hash_a * ng_hash + self.hash_b)
% self.MERSENNE_PRIME
) % self.MAX_HASH
signature = np.minimum(signature, hashed)
return signature
def compute_lsh_buckets(self, signature):
"""
Compute LSH bucket keys from a MinHash signature.
The signature is divided into n_bands bands.
Each band is hashed to produce a bucket key.
Two documents that share a bucket in any band
are candidate duplicates.
"""
buckets = []
for band_idx in range(self.n_bands):
start = band_idx * self.rows_per_band
end = start + self.rows_per_band
band = signature[start:end]
# Hash the band to a bucket key
band_bytes = band.tobytes()
bucket_key = hashlib.md5(
band_bytes
).hexdigest()[:16]
buckets.append((band_idx, bucket_key))
return buckets
def estimate_jaccard(self, sig_a, sig_b):
"""
Estimate Jaccard similarity from MinHash signatures.
J(A, B) ~ (number of positions where sig_a == sig_b)
/ n_hashes
"""
if sig_a is None or sig_b is None:
return 0.0
matches = np.sum(sig_a == sig_b)
return float(matches) / self.n_hashes
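The band/row tradeoff can be checked numerically. This standalone sketch uses the standard LSH formulas (threshold approximation (1/b)^(1/r), candidate probability 1 - (1 - s^r)^b), separate from the pipeline classes:

```python
def lsh_threshold(n_bands: int, rows_per_band: int) -> float:
    # Similarity at which the candidate-probability S-curve
    # crosses ~0.5: the standard (1/b)^(1/r) approximation.
    return (1.0 / n_bands) ** (1.0 / rows_per_band)

def candidate_probability(s: float, n_bands: int, rows_per_band: int) -> float:
    # P(two docs with true Jaccard s share at least one LSH bucket).
    return 1.0 - (1.0 - s ** rows_per_band) ** n_bands

# n_hashes=128, n_bands=16 -> rows_per_band=8
print(round(lsh_threshold(16, 8), 2))   # 0.71
# n_hashes=256, n_bands=32 -> rows_per_band=8
print(round(lsh_threshold(32, 8), 2))   # 0.65

# The S-curve: pairs well above the threshold are almost
# always candidates, pairs well below almost never are.
print(candidate_probability(0.9, 16, 8) > 0.999)  # True
print(candidate_probability(0.3, 16, 8) < 0.01)   # True
```

Choosing more bands lowers the threshold (more recall, more candidate pairs to verify); more rows per band raises it.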
Distributed Dedup with Spark
class DistributedDeduplication:
"""
Distributed deduplication using Spark.
Architecture:
1. Compute MinHash signatures (map phase)
2. Compute LSH buckets (map phase)
3. Group by bucket (shuffle phase)
4. Find connected components (graph phase)
5. Keep one document per component (filter phase)
The shuffle phase is the bottleneck: with 5 billion
documents and 16 bands, we produce 80 billion
(band, bucket, doc_id) tuples. At ~50 bytes each,
that is ~4 TB of shuffle data.
"""
def __init__(self, spark_session, minhash_config):
self.spark = spark_session
self.deduplicator = MinHashDeduplicator(
n_hashes=minhash_config.get("n_hashes", 128),
n_bands=minhash_config.get("n_bands", 16),
n_grams=minhash_config.get("n_grams", 5),
)
def run_dedup(self, input_path, output_path):
"""
Run full deduplication pipeline.
"""
# Step 1: Compute signatures
docs_rdd = self.spark.sparkContext.textFile(
input_path
).zipWithIndex()
signatures_rdd = docs_rdd.map(
lambda x: (
x[1], # doc_id
self.deduplicator.compute_signature(x[0]),
)
).filter(lambda x: x[1] is not None)
# Step 2: Compute LSH buckets
bucket_rdd = signatures_rdd.flatMap(
lambda x: [
((band_idx, bucket_key), x[0])
for band_idx, bucket_key
in self.deduplicator.compute_lsh_buckets(x[1])
]
)
# Step 3: Group by bucket
candidate_pairs = (
bucket_rdd
.groupByKey()
.flatMap(self._generate_pairs)
.distinct()
)
# Step 4: Find connected components (Union-Find)
# This identifies clusters of near-duplicates
components = self._find_connected_components(
candidate_pairs
)
# Step 5: Keep one document per component (here: the
# lowest doc_id; in production, rank by quality score).
# Documents that never appeared in a candidate pair are
# unique and must be kept, so compute the set of
# documents to REMOVE (component members that are not
# the representative) rather than the set to keep.
docs_to_remove = components.filter(
lambda x: x[0] != x[1] # not the component's min doc_id
).map(lambda x: x[0])
# At billions of ids, broadcast this set or use a
# left anti-join instead of collecting to the driver.
remove_set = set(docs_to_remove.collect())
# Filter original dataset
deduped = docs_rdd.filter(
lambda x: x[1] not in remove_set
)
return deduped.count()
def _generate_pairs(self, bucket_group):
"""
Generate candidate pairs from a bucket.
Cap at 100 pairs per bucket to avoid quadratic
explosion in large buckets.
"""
bucket_key, doc_ids = bucket_group
doc_ids = list(doc_ids)
if len(doc_ids) > 100:
# Large bucket: sample pairs
import random
pairs = set()
for _ in range(100):
i, j = random.sample(range(len(doc_ids)), 2)
pairs.add(
(min(doc_ids[i], doc_ids[j]),
max(doc_ids[i], doc_ids[j]))
)
return list(pairs)
# Small bucket: all pairs
pairs = []
for i in range(len(doc_ids)):
for j in range(i + 1, len(doc_ids)):
pairs.append(
(doc_ids[i], doc_ids[j])
)
return pairs
def _find_connected_components(self, pairs_rdd):
"""
Find connected components using iterative
label propagation on Spark.
"""
# Initialize: each node is its own component
nodes = pairs_rdd.flatMap(
lambda x: [x[0], x[1]]
).distinct().map(lambda x: (x, x))
for iteration in range(20):
# Propagate minimum label
edges = pairs_rdd.flatMap(
lambda x: [(x[0], x[1]), (x[1], x[0])]
)
new_labels = edges.join(nodes).map(
lambda x: (x[1][0], x[1][1])
).reduceByKey(min)
combined = nodes.union(new_labels).reduceByKey(min)
# Check convergence
changed = combined.join(nodes).filter(
lambda x: x[1][0] != x[1][1]
).count()
nodes = combined
if changed == 0:
break
return nodes
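On a single machine, the connected-components step is just union-find. A minimal sketch with illustrative doc_ids (labels resolve to the minimum id in each component, matching the label-propagation loop above):

```python
def connected_components(pairs):
    parent = {}

    def find(x):
        parent.setdefault(x, x)
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving
            x = parent[x]
        return x

    for a, b in pairs:
        ra, rb = find(a), find(b)
        if ra != rb:
            # Attach the larger root under the smaller so each
            # component ends up labeled by its minimum doc_id.
            if ra < rb:
                parent[rb] = ra
            else:
                parent[ra] = rb

    return {x: find(x) for x in parent}

labels = connected_components([(1, 2), (2, 3), (7, 8)])
print(labels)  # {1: 1, 2: 1, 3: 1, 7: 7, 8: 7}
```

The Spark version exists because `parent` would not fit in one machine's memory at 5 billion documents; the iterative min-label propagation computes the same labels.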
Deduplication Performance: Documents Processed per Hour
| Method (documents/hour) | 10 nodes | 50 nodes | 100 nodes | 250 nodes | 500 nodes | 1000 nodes |
|---|---|---|---|---|---|---|
| MinHash LSH (n=128, bands=16) | | | | | | |
| Exact dedup (hash-based) | | | | | | |
| Embedding-based SimHash | | | | | | |
The LSH shuffle phase is the memory and network bottleneck. With 5 billion documents and 16 bands, the shuffle writes approximately 4 TB of data across the cluster. Insufficient shuffle partitions cause out-of-memory errors on individual executors. Set spark.sql.shuffle.partitions to at least 200,000 for datasets above 1 billion documents. Also set spark.shuffle.compress=true and spark.shuffle.spill.compress=true.
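The 4 TB figure is straightforward arithmetic on the numbers above:

```python
# Back-of-envelope for the dedup shuffle volume.
n_docs = 5_000_000_000   # documents entering dedup
n_bands = 16             # one (band, bucket, doc_id) tuple per band
bytes_per_tuple = 50     # rough serialized size per tuple

tuples = n_docs * n_bands
shuffle_tb = tuples * bytes_per_tuple / 1e12

print(tuples, shuffle_tb)  # 80000000000 4.0
```

With 200,000 shuffle partitions, that is about 20 MB of shuffle data per partition, comfortably within a single executor's memory.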
Ray Data for GPU Operations
Quality Filtering with GPU Models
class RayQualityFilterPipeline:
"""
Ray Data pipeline for GPU-accelerated quality filtering.
Uses a quality classifier (fine-tuned BERT or similar)
to score each document. GPU inference is 10-50x faster
than CPU for the classifier forward pass.
Ray Data advantages over Spark for this stage:
- Native GPU scheduling
- Zero-copy data transfer to GPU
- Streaming execution (no full materialization)
"""
def __init__(self, ray_config):
self.model_path = ray_config["model_path"]
self.batch_size = ray_config.get("batch_size", 256)
self.quality_threshold = ray_config.get(
"threshold", 0.5
)
self.n_gpus = ray_config.get("n_gpus", 100)
def run_quality_filter(self, input_path, output_path):
"""
Filter documents by quality using GPU classifier.
"""
import ray
# Read data as Ray Dataset
ds = ray.data.read_parquet(input_path)
# Apply quality scoring using GPU
scored = ds.map_batches(
QualityScorer,
fn_constructor_kwargs={
"model_path": self.model_path,
"batch_size": self.batch_size,
},
num_gpus=1,
concurrency=self.n_gpus,
batch_size=self.batch_size,
)
# Filter by quality threshold
filtered = scored.filter(
lambda row: row["quality_score"]
> self.quality_threshold
)
# Write output
filtered.write_parquet(output_path)
return filtered.count()
class QualityScorer:
"""
Ray actor for GPU-based quality scoring.
Loaded once per GPU, reused across batches.
"""
def __init__(self, model_path, batch_size=256):
import torch
from transformers import AutoModelForSequenceClassification
from transformers import AutoTokenizer
self.device = torch.device("cuda")
self.tokenizer = AutoTokenizer.from_pretrained(
model_path
)
self.model = (
AutoModelForSequenceClassification
.from_pretrained(model_path)
.to(self.device)
.eval()
)
self.batch_size = batch_size
def __call__(self, batch):
"""Score a batch of documents."""
import torch
texts = batch["text"].tolist()
scores = []
for i in range(0, len(texts), self.batch_size):
chunk = texts[i:i + self.batch_size]
# Truncate to model max length
inputs = self.tokenizer(
chunk,
max_length=512,
truncation=True,
padding=True,
return_tensors="pt",
).to(self.device)
with torch.no_grad():
outputs = self.model(**inputs)
probs = torch.softmax(
outputs.logits, dim=-1
)
# Column 1 = "high quality" probability
quality_scores = (
probs[:, 1].cpu().numpy().tolist()
)
scores.extend(quality_scores)
batch["quality_score"] = scores
return batch
Tokenization at Scale
High-Throughput Tokenization
class DistributedTokenizer:
"""
Tokenize 17.5 TB of text into training-ready token sequences.
Tokenization throughput requirements:
- 17.5 TB of text at ~1 byte per character (mostly
ASCII UTF-8) and ~4 characters per token = ~4.4T
tokens per pass over the corpus
- BPE tokenization: ~1M tokens/second per CPU core
- 32,000 cores: ~32B tokens/second aggregate
- 15T target tokens / 32B tokens/s = ~470 seconds
= ~8 minutes of pure tokenizer compute
- In practice: 2-4 hours due to I/O, serialization,
and sequence packing overhead
Output format: packed token sequences of fixed length
(4096, 8192, or 32768 tokens) for efficient training.
"""
def __init__(self, tokenizer_path, seq_length=4096):
self.tokenizer_path = tokenizer_path
self.seq_length = seq_length
def tokenize_with_ray(self, input_path, output_path,
n_workers=1000):
"""
Tokenize using Ray for parallel processing.
"""
import ray
ds = ray.data.read_parquet(input_path)
tokenized = ds.map_batches(
TokenizeBatch,
fn_constructor_kwargs={
"tokenizer_path": self.tokenizer_path,
"seq_length": self.seq_length,
},
concurrency=n_workers,
batch_size=1000,
)
tokenized.write_parquet(output_path)
return tokenized.count()
class TokenizeBatch:
"""
Tokenize a batch of documents and pack into
fixed-length sequences.
Packing strategy:
- Concatenate documents with [EOS] separator
- Split into fixed-length sequences
- Pad the last sequence or discard if too short
- No document-level padding (wastes training compute)
"""
def __init__(self, tokenizer_path, seq_length=4096):
from tokenizers import Tokenizer
self.tokenizer = Tokenizer.from_file(tokenizer_path)
self.seq_length = seq_length
eos_id = self.tokenizer.token_to_id("</s>")
# token_to_id returns None for unknown tokens. Do not use
# `or 2` here: that would also override a legitimate id of 0.
self.eos_token_id = 2 if eos_id is None else eos_id
self.buffer = []
def __call__(self, batch):
"""Tokenize and pack a batch."""
texts = batch["text"].tolist()
packed_sequences = []
for text in texts:
encoded = self.tokenizer.encode(text)
tokens = encoded.ids
# Add EOS token
tokens.append(self.eos_token_id)
self.buffer.extend(tokens)
# Pack into sequences
while len(self.buffer) >= self.seq_length:
seq = self.buffer[:self.seq_length]
packed_sequences.append(seq)
self.buffer = self.buffer[self.seq_length:]
return {
"input_ids": packed_sequences,
"length": [len(s) for s in packed_sequences],
}
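The packing loop can be illustrated without a real tokenizer. A standalone sketch with made-up token ids and a small seq_len of 8:

```python
EOS = 2
SEQ_LEN = 8

def pack(docs, seq_len=SEQ_LEN, eos=EOS):
    # Concatenate documents with an EOS separator, emit
    # fixed-length sequences, keep the remainder buffered.
    buffer, sequences = [], []
    for tokens in docs:
        buffer.extend(tokens + [eos])
        while len(buffer) >= seq_len:
            sequences.append(buffer[:seq_len])
            buffer = buffer[seq_len:]
    return sequences, buffer

docs = [[10, 11, 12], [20, 21], [30, 31, 32, 33, 34]]
seqs, leftover = pack(docs)
print(seqs)      # [[10, 11, 12, 2, 20, 21, 2, 30]]
print(leftover)  # [31, 32, 33, 34, 2]
```

Note that documents cross sequence boundaries; the EOS tokens are what tell the model where one document ends, and the buffered remainder carries over into the next batch exactly as in `TokenizeBatch`.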
Tokenization Throughput by Method
| Method | Tokens/sec (1 core) | Tokens/sec (1000 nodes) | Compute Time for 15T Tokens | Memory per Worker |
|---|---|---|---|---|
| HuggingFace tokenizers (Rust) | 2.5M | 80B | 3.1 min | 200 MB |
| SentencePiece (C++) | 1.5M | 48B | 5.2 min | 150 MB |
| tiktoken (Rust/Python) | 3.0M | 96B | 2.6 min | 180 MB |
| Custom BPE (Python) | 0.3M | 9.6B | 26 min | 100 MB |
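At the per-core rates in the table, pure tokenizer compute for 15T tokens is minutes, not hours; the wall-clock gap is I/O, serialization, and packing overhead. Re-deriving the times (32,000 cores assumed, as elsewhere in the post):

```python
# Aggregate cluster rate = per-core rate x 32,000 cores
# (1,000 nodes x 32 cores each).
TARGET_TOKENS = 15e12
CORES = 32_000

rates = {                    # tokens/sec per core, from the table
    "hf_tokenizers": 2.5e6,
    "sentencepiece": 1.5e6,
    "tiktoken": 3.0e6,
    "custom_bpe": 0.3e6,
}

minutes = {
    name: TARGET_TOKENS / (rate * CORES) / 60
    for name, rate in rates.items()
}
for name, m in minutes.items():
    print(f"{name}: {m:.1f} min")
```

This is why the takeaways call tokenization I/O-bound: even the slow pure-Python tokenizer is throughput-limited at under half an hour of compute.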
Fault Tolerance and Cost
Handling Failures at Scale
class FaultTolerantPipeline:
"""
Fault tolerance for 1000-node processing clusters.
At 1000 nodes, node failures are not exceptions --
they are expected. Mean time between failures (MTBF)
for a 1000-node cluster is approximately:
MTBF_cluster = MTBF_node / n_nodes
If MTBF_node = 30 days: MTBF_cluster = 43 minutes
A 1000-node cluster will lose a node approximately
every 43 minutes. The pipeline must handle this.
"""
def __init__(self, config):
self.checkpoint_interval_s = config.get(
"checkpoint_interval_s", 1800
)
self.max_retries = config.get("max_retries", 3)
self.speculative_execution = config.get(
"speculative_execution", True
)
def run_with_checkpointing(self, stage_fn,
input_partitions,
output_path):
"""
Run a processing stage with partition-level
checkpointing and automatic retry.
"""
completed = self._load_completed_partitions(
output_path
)
remaining = [
p for p in input_partitions
if p not in completed
]
failed_partitions = []
for partition in remaining:
success = False
for attempt in range(self.max_retries):
try:
result = stage_fn(partition)
self._save_partition_result(
output_path, partition, result
)
success = True
break
except Exception as e:
if attempt == self.max_retries - 1:
failed_partitions.append(
(partition, str(e))
)
return {
"completed": (
len(input_partitions)
- len(failed_partitions)
),
"failed": len(failed_partitions),
"failure_rate": (
len(failed_partitions)
/ len(input_partitions)
),
}
def _load_completed_partitions(self, output_path):
"""Load set of already-completed partitions."""
return set()
def _save_partition_result(self, output_path,
partition, result):
"""Save a single partition result."""
pass
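The failure arithmetic from the docstring, worked out (the ~61-hour run length comes from the stage table earlier in the post):

```python
# Cluster MTBF and expected node losses over one full run.
mtbf_node_days = 30
n_nodes = 1000
run_hours = 61

# MTBF shrinks linearly with cluster size.
mtbf_cluster_min = mtbf_node_days * 24 * 60 / n_nodes
expected_failures = run_hours * 60 / mtbf_cluster_min

print(round(mtbf_cluster_min, 1))  # 43.2 -- minutes between node losses
print(round(expected_failures))    # 85 -- expected node losses per run
```

Dozens of node losses per run is why partition-level retries and speculative execution are defaults here rather than options.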
Pipeline Cost: Spark on AWS EMR vs Ray on GPU Instances
| Instance Type (cost per stage, USD) | Download | Extract | Lang Filter | Quality Filter | Dedup | PII Scrub | Tokenize | Shuffle |
|---|---|---|---|---|---|---|---|---|
| Spark on r5.8xlarge (CPU) | | | | | | | | |
| Ray on p4d.24xlarge (GPU) | | | | | | | | |
Key Takeaways
Processing 15 trillion tokens requires a distributed pipeline that handles 300+ TB of raw data across 1000 nodes. The pipeline must be fault-tolerant (nodes fail every 43 minutes at this scale), checkpointed (resume from any failure), and cost-optimized (total cost of $40-70K per run).
The critical engineering decisions:
- Use the right framework for each stage: Spark excels at CPU-bound ETL (extraction, language filtering, PII scrubbing). Ray Data excels at GPU-accelerated operations (quality classification, embedding computation). Using both in the same pipeline, with Parquet as the interchange format, captures the strengths of each.
- MinHash LSH is the only viable dedup at this scale: Exact dedup (hash-based) catches identical documents but misses near-duplicates (same content, different formatting). Embedding-based dedup is too slow. MinHash LSH with 128 hashes and 16 bands catches documents above roughly 71% Jaccard similarity ((1/16)^(1/8) ~ 0.71) at near-linear cost.
- The shuffle is the bottleneck in dedup: With 5 billion documents and 16 LSH bands, the shuffle phase produces 4 TB of intermediate data. Set shuffle partitions to 200,000+, enable compression, and use SSD-backed shuffle disks.
- Tokenization is I/O-bound, not compute-bound: Rust-based tokenizers (HuggingFace tokenizers, tiktoken) process 2-3M tokens/second per core. At 1000 nodes with 32 cores each, tokenization completes in 2-5 hours. The bottleneck is reading input data and writing packed sequences, not the tokenization itself.
- Checkpoint after every stage: At 1000 nodes, a full pipeline run takes 40-60 hours. Without checkpointing, a single failure in the final stage forces a restart from the beginning. Stage-level checkpointing with idempotent stages allows resume from the last successful stage.