Claude 3 shipped with a 200K-token context window, but most of its pretraining data came from 2-8K-token documents. This mismatch creates a generalization problem: models that see only short documents during training exhibit position bias at long context lengths, retrieving information reliably from the start and end of the window while largely ignoring the middle 80%. Fixing this requires long-context training data at scale: book-length documents, multi-document reasoning tasks, and synthetic needle-in-haystack examples that force the model to attend uniformly across 100K+ tokens.
This post covers the construction of training data for long-context models: document sourcing, synthetic long-context generation, multi-document QA, needle-in-haystack evaluation, and position-aware training strategies.
Long Document Sourcing
Finding and Preparing Book-Length Content
from dataclasses import dataclass, field
from enum import Enum
import re
class DocumentSource(Enum):
GUTENBERG = "project_gutenberg"
ARXIV = "arxiv_papers"
WIKIPEDIA = "wikipedia"
LEGAL = "legal_documents"
CODE_REPOS = "code_repositories"
SYNTHETIC = "synthetic"
GOV_REPORTS = "government_reports"
@dataclass
class LongDocument:
"""A document suitable for long-context training."""
source: DocumentSource
title: str
text: str
token_count: int
sections: list = field(default_factory=list)
metadata: dict = field(default_factory=dict)
class LongDocumentPipeline:
"""
Pipeline for sourcing and preparing long documents.
Challenges:
- Most web text is short (median ~500 tokens)
- Book-length text is copyrighted
- Legal/gov documents are long but domain-specific
- arXiv papers are 5-20K tokens (not long enough alone)
Strategy: combine natural long documents with
synthetic multi-document compositions.
"""
SOURCE_CONFIGS = {
DocumentSource.GUTENBERG: {
"avg_tokens": 80000,
"copyright": "public_domain",
"quality": "high",
"topics": "fiction, classic literature",
"available_count": 70000,
},
DocumentSource.ARXIV: {
"avg_tokens": 8000,
"copyright": "open_access",
"quality": "high",
"topics": "STEM research",
"available_count": 2500000,
},
DocumentSource.LEGAL: {
"avg_tokens": 50000,
"copyright": "public_domain",
"quality": "medium",
"topics": "law, regulation, court opinions",
"available_count": 500000,
},
DocumentSource.GOV_REPORTS: {
"avg_tokens": 30000,
"copyright": "public_domain",
"quality": "medium",
"topics": "policy, economics, science",
"available_count": 200000,
},
}
def __init__(self, tokenizer, target_lengths=None):
self.tokenizer = tokenizer
self.target_lengths = target_lengths or [
16384, 32768, 65536, 131072, 262144,
]
def prepare_natural_documents(self, raw_documents):
"""
Prepare naturally long documents for training.
Steps:
1. Clean and normalize text
2. Tokenize and measure length
3. Extract section structure
4. Filter by quality metrics
"""
prepared = []
for doc in raw_documents:
clean = self._clean_text(doc["text"])
tokens = self.tokenizer.encode(clean)
token_count = len(tokens)
# Skip documents shorter than our minimum target
if token_count < self.target_lengths[0]:
continue
sections = self._extract_sections(clean)
prepared.append(
LongDocument(
source=doc["source"],
title=doc.get("title", ""),
text=clean,
token_count=token_count,
sections=sections,
metadata={
"original_length": len(doc["text"]),
"clean_length": len(clean),
"compression_ratio": (
len(clean) / len(doc["text"])
),
},
)
)
return prepared
def _clean_text(self, text):
"""Clean raw document text."""
# Remove multiple newlines
text = re.sub(r"\n{3,}", "\n\n", text)
# Remove page headers/footers (common in PDFs)
text = re.sub(
r"^Page \d+ of \d+.*$", "", text, flags=re.MULTILINE
)
# Normalize whitespace
text = re.sub(r"[ \t]+", " ", text)
return text.strip()
def _extract_sections(self, text):
"""
Extract document section structure.
Identifies headings, chapters, and major breaks.
Used for structured QA generation.
"""
sections = []
patterns = [
r"^(Chapter \d+[:\.]?\s*.*)$",
r"^(SECTION \d+[:\.]?\s*.*)$",
r"^(#{1,3}\s+.*)$",
r"^(\d+\.\d*\s+[A-Z].*)$",
]
lines = text.split("\n")
current_section = {"title": "Introduction", "start": 0}
for i, line in enumerate(lines):
for pattern in patterns:
match = re.match(pattern, line.strip())
if match:
current_section["end"] = i
sections.append(current_section.copy())
current_section = {
"title": match.group(1).strip(),
"start": i,
}
break
current_section["end"] = len(lines)
sections.append(current_section)
return sections
Long Document Sources: Availability and Properties
| Source | Avg Length (tokens) | Available Docs | Copyright Status | Domain Coverage | Quality |
|---|---|---|---|---|---|
| Project Gutenberg | 80K | 70K | Public domain | Literature | High |
| arXiv (concatenated) | 8K per paper | 2.5M | Open access | STEM | High |
| US Court Opinions | 50K | 500K | Public domain | Legal | Medium |
| Congressional Reports | 30K | 200K | Public domain | Policy | Medium |
| GitHub repos (full) | 100K+ | 10M+ | Mixed licenses | Code | Variable |
| Synthetic compositions | Configurable | Unlimited | N/A | Any | Controlled |
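The pipeline above drops documents shorter than the smallest target length; the same target list can also drive length bucketing, assigning each prepared document to the largest context length it can fill. A minimal sketch (the `assign_bucket` helper is hypothetical, not part of the pipeline above):

```python
import bisect

# Target context lengths from the pipeline above (tokens).
TARGET_LENGTHS = [16_384, 32_768, 65_536, 131_072, 262_144]

def assign_bucket(token_count, targets=TARGET_LENGTHS):
    """Return the largest target length the document can fill,
    or None if it is shorter than the smallest target."""
    # bisect_right gives the index of the first target > token_count
    idx = bisect.bisect_right(targets, token_count)
    if idx == 0:
        return None  # too short for any bucket
    return targets[idx - 1]

# Example: an 80K-token Gutenberg book fills the 64K bucket.
print(assign_bucket(80_000))  # 65536
print(assign_bucket(5_000))   # None
```

Bucketing this way lets each training phase draw only from documents that fully occupy its context length, rather than padding short documents to fit.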
Synthetic Long-Context Generation
Multi-Document Composition
import random
import numpy as np
class SyntheticLongContextGenerator:
"""
Generate synthetic long-context training data by
composing multiple shorter documents with
cross-references, shared entities, and
information dependencies.
This produces training data that specifically
exercises long-context skills:
- Information retrieval from arbitrary positions
- Cross-document reasoning
- Entity tracking across distance
- Temporal ordering from scattered evidence
"""
def __init__(self, document_pool, tokenizer, model):
self.document_pool = document_pool
self.tokenizer = tokenizer
self.model = model
def generate_multi_doc_composition(self, target_tokens,
n_documents=10):
"""
Compose N documents into a single long context
with cross-references and shared entities.
Steps:
1. Select N thematically related documents
2. Insert cross-reference markers
3. Generate QA pairs requiring multi-doc reasoning
4. Shuffle document order for position invariance
"""
# Select related documents
anchor = random.choice(self.document_pool)
related = self._find_related(
anchor, n_documents - 1
)
documents = [anchor] + related
# Trim to fit target length
tokens_per_doc = target_tokens // n_documents
trimmed = []
for doc in documents:
tokens = self.tokenizer.encode(doc["text"])
if len(tokens) > tokens_per_doc:
tokens = tokens[:tokens_per_doc]
trimmed.append(
self.tokenizer.decode(tokens)
)
# Insert document separators with metadata
composed = ""
doc_positions = []
for i, doc_text in enumerate(trimmed):
start_pos = len(self.tokenizer.encode(composed))
header = (
f"\n\n--- Document {i+1}: "
f"{documents[i].get('title', f'Doc {i+1}')} "
f"---\n\n"
)
composed += header + doc_text
end_pos = len(self.tokenizer.encode(composed))
doc_positions.append({
"doc_index": i,
"start_token": start_pos,
"end_token": end_pos,
"title": documents[i].get("title", ""),
})
return {
"text": composed,
"doc_positions": doc_positions,
"total_tokens": len(
self.tokenizer.encode(composed)
),
}
def generate_cross_reference_qa(self, composition):
"""
Generate QA pairs that require reading multiple
documents in the composition.
Question types:
1. Comparison: "How does Document A's view on X
differ from Document B's?"
2. Aggregation: "What are all the dates mentioned
across all documents?"
3. Reasoning: "Based on the evidence in Documents
A and C, what conclusion..."
4. Contradiction: "Document B claims X, but
Document D claims Y. Which is supported by..."
"""
qa_types = [
"comparison", "aggregation",
"reasoning", "contradiction",
]
qa_pairs = []
for qa_type in qa_types:
prompt = (
f"Given the following multi-document context, "
f"generate a {qa_type} question that requires "
f"reading at least 2 different documents to "
f"answer correctly.\n\n"
f"Context:\n{composition['text'][:50000]}\n\n"
f"Generate:\n"
f"1. A question requiring {qa_type} reasoning\n"
f"2. The correct answer\n"
f"3. Which document numbers are needed"
)
response = self.model.generate(
prompt, temperature=0.7, max_tokens=500
)
qa_pairs.append({
"type": qa_type,
"raw_response": response,
})
return qa_pairs
def _find_related(self, anchor, n):
"""Find n documents related to the anchor."""
# Simplified: random selection from same domain
same_domain = [
doc for doc in self.document_pool
if doc.get("domain") == anchor.get("domain")
and doc != anchor
]
        if len(same_domain) >= n:
            return random.sample(same_domain, n)
        # Fall back to the full pool, excluding the anchor itself
        # so it cannot appear twice in the composition
        pool = [d for d in self.document_pool if d is not anchor]
        return random.sample(pool, min(n, len(pool)))
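One practical note on `generate_multi_doc_composition`: it re-encodes the full `composed` string on every iteration to compute token offsets, which is quadratic in total length. The offsets can be tracked incrementally instead. A sketch using `str.split` as a stand-in tokenizer (the `compose_with_offsets` helper is illustrative, not the implementation above):

```python
def compose_with_offsets(docs, tokenize=str.split):
    """Concatenate documents with headers, recording each
    document's token span without re-encoding the full text.
    `tokenize` is a whitespace stand-in; a real pipeline would
    pass the model tokenizer's encode function."""
    parts, positions, offset = [], [], 0
    for i, doc in enumerate(docs):
        header = f"--- Document {i+1}: {doc['title']} ---"
        body = header + "\n" + doc["text"]
        n_tokens = len(tokenize(body))
        positions.append({
            "doc_index": i,
            "start_token": offset,
            "end_token": offset + n_tokens,
            "title": doc["title"],
        })
        offset += n_tokens
        parts.append(body)
    return "\n\n".join(parts), positions

docs = [
    {"title": "A", "text": "alpha beta gamma"},
    {"title": "B", "text": "delta epsilon"},
]
text, pos = compose_with_offsets(docs)
```

With a subword tokenizer the separator-join can shift boundaries by a token or two, but for position bookkeeping at 100K+ tokens that error is negligible compared to the cost of re-encoding.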
Needle-in-Haystack Evaluation Data
Constructing Rigorous Retrieval Tests
import random

import numpy as np

class NeedleInHaystackBuilder:
"""
Build needle-in-haystack (NIAH) evaluation data.
NIAH tests whether a model can retrieve a specific
piece of information from an arbitrary position in
a long context. The 'needle' is a fact planted at
a specific position. The 'haystack' is filler text.
The question asks about the needle.
Key design decisions:
- Needle must be unrelated to haystack (prevents
shortcut via topic matching)
- Position must be varied systematically
- Multiple needles test multi-fact retrieval
- Distractor needles test precision
"""
def __init__(self, tokenizer):
self.tokenizer = tokenizer
def build_single_needle(self, haystack_tokens,
needle_position_pct):
"""
Insert a single needle at a specified position.
needle_position_pct: 0.0 = beginning, 1.0 = end
"""
# Generate a unique needle fact
needle = self._generate_needle()
# Get haystack text of target length
haystack = self._get_haystack(haystack_tokens)
# Calculate insertion point
haystack_lines = haystack.split("\n")
insert_idx = int(
len(haystack_lines) * needle_position_pct
)
insert_idx = max(1, min(
insert_idx, len(haystack_lines) - 1
))
# Insert needle
haystack_lines.insert(insert_idx, needle["text"])
composed = "\n".join(haystack_lines)
return {
"context": composed,
"needle": needle,
"position_pct": needle_position_pct,
"question": needle["question"],
"answer": needle["answer"],
"total_tokens": len(
self.tokenizer.encode(composed)
),
}
def build_multi_needle(self, haystack_tokens,
n_needles=5):
"""
Insert multiple needles at different positions.
Tests the model's ability to retrieve multiple
facts from a single long context.
"""
needles = [
self._generate_needle() for _ in range(n_needles)
]
haystack = self._get_haystack(haystack_tokens)
haystack_lines = haystack.split("\n")
# Distribute needles evenly
positions = np.linspace(
0.1, 0.9, n_needles
)
needle_positions = []
for needle, pos in zip(needles, positions):
insert_idx = int(len(haystack_lines) * pos)
haystack_lines.insert(insert_idx, needle["text"])
needle_positions.append({
"needle": needle,
"position_pct": pos,
})
composed = "\n".join(haystack_lines)
# Generate a question that requires ALL needles
combined_question = (
"Answer the following questions based on "
"the context:\n"
)
combined_answer = ""
for i, np_info in enumerate(needle_positions):
combined_question += (
f"{i+1}. {np_info['needle']['question']}\n"
)
combined_answer += (
f"{i+1}. {np_info['needle']['answer']}\n"
)
return {
"context": composed,
"needles": needle_positions,
"question": combined_question,
"answer": combined_answer,
"total_tokens": len(
self.tokenizer.encode(composed)
),
}
def build_distractor_needles(self, haystack_tokens,
n_real=3, n_distractors=5):
"""
Insert real needles and distractor needles.
Distractor needles are facts from a similar domain
that are NOT asked about. Tests whether the model
retrieves the correct facts rather than any fact.
"""
real_needles = [
self._generate_needle(domain="science")
for _ in range(n_real)
]
distractor_needles = [
self._generate_needle(domain="science")
for _ in range(n_distractors)
]
all_needles = real_needles + distractor_needles
random.shuffle(all_needles)
haystack = self._get_haystack(haystack_tokens)
haystack_lines = haystack.split("\n")
positions = np.linspace(
0.05, 0.95, len(all_needles)
)
for needle, pos in zip(all_needles, positions):
insert_idx = int(len(haystack_lines) * pos)
haystack_lines.insert(insert_idx, needle["text"])
composed = "\n".join(haystack_lines)
# Question only asks about real needles
question = "Based on the context, answer:\n"
answer = ""
for i, needle in enumerate(real_needles):
question += f"{i+1}. {needle['question']}\n"
answer += f"{i+1}. {needle['answer']}\n"
return {
"context": composed,
"real_needles": real_needles,
"distractor_needles": distractor_needles,
"question": question,
"answer": answer,
"total_tokens": len(
self.tokenizer.encode(composed)
),
}
def build_evaluation_grid(self, context_lengths,
position_percentages):
"""
Build a full evaluation grid: every combination
of context length and needle position.
Standard grid:
- Lengths: [4K, 8K, 16K, 32K, 64K, 128K]
- Positions: [0%, 10%, 25%, 50%, 75%, 90%, 100%]
- Total tests: 6 * 7 = 42 per needle type
"""
grid = []
for length in context_lengths:
for position in position_percentages:
sample = self.build_single_needle(
haystack_tokens=length,
needle_position_pct=position,
)
sample["grid_length"] = length
sample["grid_position"] = position
grid.append(sample)
return grid
def _generate_needle(self, domain="general"):
"""Generate a unique needle fact."""
# Predefined needles for reproducibility
needles = [
{
"text": (
"The special magic number for this "
"experiment is 7392."
),
"question": (
"What is the special magic number "
"mentioned in the text?"
),
"answer": "7392",
},
{
"text": (
"The secret project codename is "
"Operation Sapphire Dawn."
),
"question": (
"What is the secret project codename?"
),
"answer": "Operation Sapphire Dawn",
},
{
"text": (
"The research facility is located at "
"coordinates 47.6N, 122.3W."
),
"question": (
"What are the coordinates of the "
"research facility?"
),
"answer": "47.6N, 122.3W",
},
]
return random.choice(needles)
def _get_haystack(self, target_tokens):
"""Get haystack text of approximately target length."""
# In practice, this draws from a corpus of
# topically neutral text (Paul Graham essays,
# Wikipedia articles, etc.)
return "Placeholder haystack text.\n" * (
target_tokens // 5
)
NIAH Accuracy by Context Length and Position
| Context length | 0% | 10% | 25% | 50% | 75% | 90% | 100% |
|---|---|---|---|---|---|---|---|
| 4K | | | | | | | |
| 32K | | | | | | | |
| 128K | | | | | | | |
| 1M | | | | | | | |
The U-shaped accuracy curve (high at start and end, low in the middle) is called the “lost in the middle” effect (Liu et al., 2023). Models attend preferentially to the beginning (primacy bias) and end (recency bias) of the context. Training data must counteract this by placing critical information uniformly across all positions, with additional emphasis on the middle positions where models perform worst.
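Grid results only become legible once aggregated into a length-by-position table like the one above. A minimal sketch of that aggregation, assuming each graded sample carries its `grid_length`, `grid_position`, and a boolean `correct` field (the `correct` field is an assumption; the builder above constructs tests but does not grade answers):

```python
from collections import defaultdict

def accuracy_matrix(results):
    """Aggregate graded NIAH samples into
    {context_length: {position_pct: accuracy}}."""
    # Each cell accumulates [hits, total]
    totals = defaultdict(lambda: defaultdict(lambda: [0, 0]))
    for r in results:
        cell = totals[r["grid_length"]][r["grid_position"]]
        cell[0] += int(r["correct"])
        cell[1] += 1
    return {
        length: {pos: hits / n for pos, (hits, n) in row.items()}
        for length, row in totals.items()
    }

results = [
    {"grid_length": 4096, "grid_position": 0.5, "correct": True},
    {"grid_length": 4096, "grid_position": 0.5, "correct": False},
    {"grid_length": 4096, "grid_position": 0.0, "correct": True},
]
print(accuracy_matrix(results))
# {4096: {0.5: 0.5, 0.0: 1.0}}
```

Plotting this matrix as a heatmap (length on one axis, position on the other) makes the lost-in-the-middle dip immediately visible.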
Position-Aware Data Augmentation
Counteracting Position Bias
import random

import numpy as np

class PositionAwareAugmenter:
"""
Augment long-context training data to counteract
position bias (lost-in-the-middle effect).
Strategies:
1. Answer-position randomization: place the answer-
relevant information at different positions across
training samples
2. Middle-heavy sampling: over-represent samples where
key information is in positions 25-75%
3. Position-tagged training: prepend position markers
so the model learns position-independent retrieval
"""
def __init__(self, tokenizer):
self.tokenizer = tokenizer
def randomize_answer_position(self, qa_sample,
n_augmentations=5):
"""
Create multiple versions of a QA sample with
the answer-relevant passage at different positions.
"""
context = qa_sample["context"]
answer_passage = qa_sample.get("answer_passage", "")
if not answer_passage:
return [qa_sample]
# Remove answer passage from context
context_without = context.replace(answer_passage, "")
lines = context_without.split("\n")
lines = [l for l in lines if l.strip()]
augmented = []
positions = np.linspace(0.0, 1.0, n_augmentations)
for pos in positions:
insert_idx = int(len(lines) * pos)
insert_idx = max(0, min(
insert_idx, len(lines)
))
new_lines = lines.copy()
new_lines.insert(insert_idx, answer_passage)
augmented.append({
"context": "\n".join(new_lines),
"question": qa_sample["question"],
"answer": qa_sample["answer"],
"answer_position_pct": pos,
})
return augmented
def apply_middle_heavy_sampling(self, dataset,
middle_weight=3.0):
"""
Over-sample examples where the answer is in
the middle 50% of the context.
Standard uniform sampling gives equal weight to
all positions. This function applies a weight of
middle_weight to samples where the answer is
between 25% and 75% of the context.
"""
weighted_dataset = []
for sample in dataset:
pos = sample.get("answer_position_pct", 0.5)
weight = 1.0
if 0.25 <= pos <= 0.75:
weight = middle_weight
# Duplicate by weight (integer approximation)
n_copies = max(1, int(weight))
for _ in range(n_copies):
weighted_dataset.append(sample)
random.shuffle(weighted_dataset)
return weighted_dataset
def chunk_with_overlap(self, document, chunk_size,
overlap_tokens=512):
"""
Create overlapping chunks for sliding-window training.
For documents longer than the training context,
create chunks with overlap so that cross-chunk
dependencies are seen during training.
"""
tokens = self.tokenizer.encode(document)
chunks = []
stride = chunk_size - overlap_tokens
for start in range(0, len(tokens), stride):
end = min(start + chunk_size, len(tokens))
chunk_tokens = tokens[start:end]
chunks.append({
"tokens": chunk_tokens,
"text": self.tokenizer.decode(chunk_tokens),
"start_token": start,
"end_token": end,
"is_first": start == 0,
"is_last": end == len(tokens),
})
if end == len(tokens):
break
return chunks
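The integer duplication in `apply_middle_heavy_sampling` rounds fractional weights down (a weight of 2.5 yields only 2 copies). Sampling with replacement via `random.choices` honors fractional weights exactly. A minimal sketch (the `weighted_resample` helper is an alternative approach, not the method above):

```python
import random

def weighted_resample(dataset, n_samples, middle_weight=3.0, seed=0):
    """Resample with replacement, weighting middle-position
    (25-75%) answers by `middle_weight`. Fractional weights
    are honored exactly, unlike integer duplication."""
    rng = random.Random(seed)
    weights = [
        middle_weight
        if 0.25 <= s.get("answer_position_pct", 0.5) <= 0.75
        else 1.0
        for s in dataset
    ]
    return rng.choices(dataset, weights=weights, k=n_samples)

dataset = [
    {"answer_position_pct": 0.1},
    {"answer_position_pct": 0.5},
]
resampled = weighted_resample(dataset, 1000)
# With weight 3.0, the middle sample should appear ~75% of the time.
```

The tradeoff: sampling with replacement changes epoch semantics (some samples may never be drawn), whereas duplication guarantees every sample appears at least once.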
Multi-Document QA Datasets
Cross-Document Reasoning
class MultiDocQABuilder:
"""
Build QA datasets requiring reasoning across
multiple documents within a long context.
Task types:
1. Information synthesis: combine facts from
multiple documents
2. Contradiction detection: find conflicting
claims across documents
3. Timeline reconstruction: order events from
scattered mentions
4. Entity resolution: track entities across
documents with different naming
"""
def __init__(self, model, tokenizer):
self.model = model
self.tokenizer = tokenizer
def build_synthesis_qa(self, documents):
"""
Build QA requiring synthesis of information
from 2+ documents.
Example: Document A states the population of
city X. Document B states the GDP per capita.
Question: What is the total GDP of city X?
"""
# Select 2-3 documents with shared entities
entity_groups = self._find_shared_entities(documents)
qa_pairs = []
for entity, doc_indices in entity_groups.items():
if len(doc_indices) < 2:
continue
# Extract facts about entity from each document
facts = {}
for idx in doc_indices:
doc_facts = self._extract_facts(
documents[idx], entity
)
facts[idx] = doc_facts
# Generate synthesis question
if len(facts) >= 2:
qa = self._generate_synthesis_question(
entity, facts, documents
)
if qa:
qa_pairs.append(qa)
return qa_pairs
def build_contradiction_qa(self, documents):
"""
Build QA requiring detection of contradictions
across documents.
Two documents may state conflicting facts about
the same topic. The model must identify the
contradiction and potentially resolve it using
additional evidence.
"""
qa_pairs = []
for i in range(len(documents)):
for j in range(i + 1, len(documents)):
# Check for topic overlap
overlap = self._compute_topic_overlap(
documents[i], documents[j]
)
if overlap > 0.3:
contradictions = self._find_contradictions(
documents[i], documents[j]
)
for contradiction in contradictions:
qa_pairs.append({
"question": (
f"Document {i+1} and Document "
f"{j+1} make conflicting claims "
f"about {contradiction['topic']}. "
f"What are the two claims and "
f"which is better supported?"
),
"answer": contradiction["resolution"],
"required_docs": [i, j],
"type": "contradiction",
})
return qa_pairs
def build_timeline_qa(self, documents):
"""
Build QA requiring temporal ordering from
scattered mentions across documents.
"""
# Extract temporal events from all documents
all_events = []
for i, doc in enumerate(documents):
events = self._extract_temporal_events(doc)
for event in events:
event["source_doc"] = i
all_events.extend(events)
if len(all_events) < 3:
return []
# Sort by date
all_events.sort(key=lambda e: e.get("date", ""))
qa_pairs = []
# "What happened first" questions
if len(all_events) >= 2:
e1, e2 = all_events[0], all_events[-1]
qa_pairs.append({
"question": (
f"Based on the documents, which happened "
f"first: {e1['description']} or "
f"{e2['description']}?"
),
"answer": (
f"{e1['description']} happened first "
f"({e1.get('date', 'earlier')})."
),
"required_docs": [
e1["source_doc"], e2["source_doc"]
],
"type": "timeline",
})
return qa_pairs
def _find_shared_entities(self, documents):
"""Find entities mentioned in multiple documents."""
entity_docs = {}
for i, doc in enumerate(documents):
entities = self._extract_entities(doc)
for entity in entities:
if entity not in entity_docs:
entity_docs[entity] = []
entity_docs[entity].append(i)
return {
k: v for k, v in entity_docs.items()
if len(v) >= 2
}
def _extract_entities(self, document):
"""Extract named entities from document."""
return [] # Placeholder for NER
def _extract_facts(self, document, entity):
"""Extract facts about an entity from a document."""
return [] # Placeholder
def _generate_synthesis_question(self, entity, facts,
documents):
"""Generate a synthesis question from multiple facts."""
return None # Placeholder
def _compute_topic_overlap(self, doc_a, doc_b):
"""Compute topic overlap between two documents."""
return 0.0 # Placeholder
def _find_contradictions(self, doc_a, doc_b):
"""Find contradicting claims between documents."""
return [] # Placeholder
def _extract_temporal_events(self, document):
"""Extract dated events from a document."""
return [] # Placeholder
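`_compute_topic_overlap` is left as a placeholder above. One simple instantiation is Jaccard similarity over content words; a sketch under that assumption (a production system would more likely use TF-IDF or embedding cosine similarity):

```python
import re

def topic_overlap(text_a, text_b, min_len=4):
    """Jaccard similarity over lowercased content words.
    `min_len` crudely filters short function words; a real
    system would use a stopword list or embeddings instead."""
    def content_words(text):
        return {
            w for w in re.findall(r"[a-z]+", text.lower())
            if len(w) >= min_len
        }
    a, b = content_words(text_a), content_words(text_b)
    if not a or not b:
        return 0.0
    return len(a & b) / len(a | b)

print(topic_overlap(
    "The court ruled on the antitrust case",
    "The antitrust ruling surprised the court",
))
```

The 0.3 threshold used in `build_contradiction_qa` would then mean: at least 30% of the combined content vocabulary is shared between the two documents.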
Multi-Document QA Task Difficulty
| Task Type | Documents Required | Avg Answer Length | Human Accuracy | GPT-4 128K Accuracy | Claude 200K Accuracy |
|---|---|---|---|---|---|
| Single-doc retrieval | 1 | Short (1-2 sentences) | 95% | 92% | 94% |
| Cross-doc synthesis | 2-3 | Medium (3-5 sentences) | 88% | 78% | 82% |
| Contradiction detection | 2 | Medium | 82% | 65% | 71% |
| Timeline reconstruction | 3-5 | Long (ordered list) | 75% | 58% | 64% |
| Entity resolution | 4+ | Medium | 70% | 52% | 59% |
Training Strategy for Long Context
Progressive Length Training
class LongContextTrainingSchedule:
"""
Training schedule for progressive context length extension.
Models are not trained at full context length from
the start. Instead, context length is gradually
increased during training:
    Base:    4K tokens (standard pretraining data)
    Phase 1: 16K tokens (RoPE base frequency adjusted)
    Phase 2: 64K tokens (with long-context data mix)
    Phase 3: 128K+ tokens (final extension)
Each phase uses a different data mix optimized
for the target length.
"""
def __init__(self, base_model_context=4096):
self.base_context = base_model_context
def get_schedule(self, target_context=131072):
"""
Generate a training schedule for extending
context from base to target.
"""
phases = []
current = self.base_context
while current < target_context:
next_length = min(current * 4, target_context)
phase = {
"context_length": next_length,
"rope_base": self._compute_rope_base(
next_length
),
"learning_rate": self._compute_lr(
next_length
),
"data_mix": self._compute_data_mix(
next_length
),
"steps": self._compute_steps(
next_length
),
"batch_tokens": self._compute_batch_tokens(
next_length
),
}
phases.append(phase)
current = next_length
return phases
def _compute_rope_base(self, context_length):
"""
Compute RoPE base frequency for target context.
Standard: base = 10000
For 4x extension: base = 10000 * 4 = 40000
For 32x extension: base = 10000 * 32 = 320000
Alternatively, use NTK-aware scaling:
base_new = base * (scale ** (dim / (dim - 2)))
"""
scale = context_length / 4096
dim = 128 # Typical head dimension
base = 10000.0
# NTK-aware scaling
ntk_base = base * (scale ** (dim / (dim - 2)))
return ntk_base
def _compute_lr(self, context_length):
"""
Learning rate decreases with context length.
Longer contexts have more gradient signal per step,
so a lower learning rate prevents instability.
"""
base_lr = 2e-5
scale = context_length / 4096
return base_lr / (scale ** 0.5)
def _compute_data_mix(self, context_length):
"""
Data mix changes with context length.
Short context: mostly standard pretraining data
Long context: increasing proportion of long documents
and synthetic long-context tasks
"""
if context_length <= 8192:
return {
"standard": 0.9,
"long_natural": 0.05,
"long_synthetic": 0.05,
}
elif context_length <= 32768:
return {
"standard": 0.5,
"long_natural": 0.25,
"long_synthetic": 0.25,
}
else:
return {
"standard": 0.2,
"long_natural": 0.40,
"long_synthetic": 0.40,
}
def _compute_steps(self, context_length):
"""Training steps per phase."""
base_steps = 1000
scale = context_length / 4096
return int(base_steps * (scale ** 0.5))
def _compute_batch_tokens(self, context_length):
"""
Constant token budget per batch.
As context length increases, batch size (in sequences)
decreases proportionally to maintain constant memory.
"""
target_batch_tokens = 4 * 1024 * 1024 # 4M tokens
return target_batch_tokens
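The NTK-aware formula in `_compute_rope_base` is worth sanity-checking with concrete numbers. A standalone sketch using the class defaults above (4096-token base context, head dimension 128):

```python
def ntk_rope_base(context_length, base_context=4096,
                  base=10000.0, dim=128):
    """NTK-aware RoPE base scaling:
    base_new = base * scale ** (dim / (dim - 2))."""
    scale = context_length / base_context
    return base * scale ** (dim / (dim - 2))

# 32x extension (4K -> 128K): slightly more than linear
# scaling (320,000), because the exponent dim/(dim-2)
# = 128/126 is a little above 1.
print(round(ntk_rope_base(131_072)))
```

At `scale = 1` the formula returns the unmodified base of 10000, so the schedule degrades gracefully when no extension is needed.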
Progressive Length Training: Perplexity by Phase
| Metric | Base (4K) | Phase 1 (16K) | Phase 2 (64K) | Phase 3 (128K) | Phase 4 (256K) |
|---|---|---|---|---|---|
| Perplexity at 128K tokens | | | | | |
| Perplexity at 4K tokens (regression check) | | | | | |
Key Takeaways
Long-context training data is the binding constraint for extending context windows beyond 128K tokens. Architecture changes (RoPE scaling, ring attention) are necessary but not sufficient — the model must see long-context examples during training.
The critical decisions:
- Natural long documents are scarce and copyright-constrained: Project Gutenberg (public domain) provides 70K book-length texts. arXiv provides millions of 8K-token papers. Synthetic multi-document compositions fill the gap by combining shorter documents with cross-references.
- Needle-in-haystack testing requires careful construction: Simple NIAH tests (one needle, one question) are near-solved for most models at 128K. Multi-needle, distractor-needle, and cross-reference NIAH tests better differentiate model capabilities. The lost-in-the-middle effect persists even in the best models.
- Position-aware augmentation is essential: Without explicit counteraction, models develop primacy and recency biases. Randomizing answer position across training samples and over-sampling middle-position examples reduces the accuracy gap between edge and middle positions by 10-15 percentage points.
- Progressive length training is more stable than direct extension: Training directly at 128K tokens from a 4K-token base model causes instability. Gradually increasing context length (4K, 16K, 64K, 128K) with adjusted RoPE frequencies and learning rates produces stable convergence with minimal regression on short-context tasks.
- Multi-document QA is harder than single-document retrieval: Tasks requiring cross-document reasoning (contradiction detection, timeline reconstruction, entity resolution) see 20-30% accuracy drops compared to single-document retrieval. Training data must include these multi-hop reasoning tasks explicitly.