Serving a 70B-parameter model at 128K context length is not a configuration knob you turn. It is a memory engineering problem. Llama 2 70B has 80 layers, a head dimension of 128, and 8 KV heads (it uses grouped-query attention: 64 query heads share 8 KV heads). Its KV cache at 128K context in FP16 occupies:

2 x 80 layers x 8 KV heads x 128 head_dim x 131,072 tokens x 2 bytes ≈ 42.9 GB
The factor of 2 at the front accounts for K and V projections. The trailing factor of 2 is bytes per FP16 element. An H100 has 80 GB of HBM3. The model weights for Llama 70B in FP16 occupy ~140 GB (requiring at minimum 2-GPU tensor parallelism), and in INT4 occupy ~35 GB. Even with INT4 weights on a single GPU, model weights (35 GB) plus KV cache (42 GB) equals 77 GB — leaving 3 GB for activations, the CUDA context, and the framework overhead. That is for a single request. Multi-tenant serving with batch sizes of 8-64 is impossible without intervention.
This post covers three techniques that make long-context serving tractable: KV cache offloading (move cold KV blocks to CPU DRAM), Ring Attention (distribute the sequence across GPUs), and chunked prefill (process long prompts in manageable pieces). Each technique addresses a different dimension of the problem, and production systems combine all three.
1. The Memory Budget at Scale
Before diving into solutions, let us establish the numbers precisely. The KV cache size depends on the model architecture and the serving configuration.
KV Cache Memory by Model and Context Length
| Model | Context | KV Heads | Head Dim | Layers | KV Cache (FP16) | KV Cache (FP8) |
|---|---|---|---|---|---|---|
| Llama 2 70B (GQA) | 4K | 8 | 128 | 80 | 1.34 GB | 0.67 GB |
| Llama 2 70B (GQA) | 32K | 8 | 128 | 80 | 10.74 GB | 5.37 GB |
| Llama 2 70B (GQA) | 128K | 8 | 128 | 80 | 42.95 GB | 21.47 GB |
| Llama 3 70B (GQA) | 128K | 8 | 128 | 80 | 42.95 GB | 21.47 GB |
| Llama 3.1 405B (GQA) | 128K | 8 | 128 | 126 | 67.65 GB | 33.82 GB |
| Mistral 7B (GQA) | 128K | 8 | 128 | 32 | 17.18 GB | 8.59 GB |
The critical observation: GQA is the single most impactful architectural lever for long-context serving. With full multi-head attention (one KV head per query head, i.e. 64 KV heads for a 70B-class model), the 128K cache would be 8x larger: roughly 344 GB per request. GQA with 8 KV heads brings it down to ~43 GB. But even 43 GB per request becomes 344 GB at batch size 8, and 1.37 TB at batch size 32. The memory problem does not disappear with GQA; it just shifts to higher batch sizes.
The Memory Equation for Multi-Tenant Serving
For a serving system handling B concurrent requests at maximum context length S_max:

M_total = M_weights + B * M_kv(S_max) + M_act

where

M_kv(S) = 2 * n_layers * n_kv_heads * d_head * S * bytes_per_element

and M_act scales with batch size and the peak intermediate tensor size (typically the attention score matrix during prefill, which is O(S^2) per head per layer if materialized, but computed blockwise with FlashAttention).
For Llama 3 70B with INT4 weights (35 GB), FP16 KV, and 128K context on 2x H100 (160 GB total), each full-length request adds 42.95 GB of KV:

35 GB + 2 x 42.95 GB = 120.9 GB (batch size 2: fits)
35 GB + 3 x 42.95 GB = 163.9 GB (batch size 3: over budget)

Batch size 16 would need 687 GB for the KV cache alone. And this assumes all requests simultaneously fill 128K tokens. In practice, most requests use far less, which is where paged attention and dynamic memory allocation help. But the peak case must be handled, and that is where offloading enters.
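These budget checks are worth scripting. A minimal sizing sketch (the function names are illustrative, not from any serving framework):

```python
def kv_cache_bytes(n_layers, n_kv_heads, d_head, seq_len, bytes_per_elem=2):
    """KV cache size for one sequence: K and V for every layer and position."""
    return 2 * n_layers * n_kv_heads * d_head * seq_len * bytes_per_elem

def fits_budget(weights_gb, batch_size, kv_gb_per_seq, budget_gb):
    """Peak-case check: weights plus one full-length KV cache per request."""
    return weights_gb + batch_size * kv_gb_per_seq <= budget_gb

# Llama 3 70B-class config: 80 layers, 8 KV heads (GQA), head_dim 128, FP16
kv_gb = kv_cache_bytes(80, 8, 128, 131_072) / 1e9  # ~42.9 GB per 128K request

fits_budget(35, 2, kv_gb, 160)  # batch size 2 fits on 2x H100
fits_budget(35, 3, kv_gb, 160)  # batch size 3 does not
```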
2. KV Cache Offloading: GPU to CPU DRAM
The core idea is simple: not all KV cache entries are equally “hot.” During autoregressive decoding, the attention mechanism accesses the full KV cache for each new token — but the attention scores concentrate on recent tokens and a sparse set of “important” tokens from the distant past (attention sinks, topic-relevant tokens). Older KV blocks that receive low attention scores can be moved to CPU DRAM and restored on demand.
Block-Level Offloading Architecture
KV cache offloading operates at the block level, borrowing from vLLM’s PagedAttention abstraction. The KV cache is divided into blocks of tokens (typically 16-256 tokens per block). Each block stores the K and V tensors for all layers:
For Llama 3 70B with GQA (8 KV heads, 128 head dim, 80 layers) at block size 256, FP16:

2 x 80 layers x 8 KV heads x 128 head_dim x 256 tokens x 2 bytes ≈ 83.9 MB per block

At block size 32:

2 x 80 x 8 x 128 x 32 x 2 bytes ≈ 10.5 MB per block

Smaller blocks give finer-grained eviction control but increase the bookkeeping overhead (more blocks to track, more individual transfers).
import torch
from dataclasses import dataclass, field
from collections import OrderedDict
@dataclass
class KVBlock:
"""A block of KV cache entries for all layers."""
block_id: int
seq_id: int
start_pos: int # Token position within the sequence
num_tokens: int # Actual tokens stored (may be less than block_size)
gpu_tensor: torch.Tensor = None # [2, layers, kv_heads, num_tokens, head_dim]
cpu_tensor: torch.Tensor = None # Pinned CPU copy
on_gpu: bool = True
last_access_step: int = 0
access_count: int = 0
importance_score: float = 0.0
class KVOffloadManager:
"""Manages KV cache blocks across GPU HBM and CPU DRAM.
Keeps a sliding window of W recent blocks on GPU.
Older blocks are offloaded to pinned CPU memory.
Blocks are restored to GPU on demand via async CUDA streams.
"""
def __init__(
self,
num_layers: int,
num_kv_heads: int,
head_dim: int,
block_size: int = 32,
gpu_budget_blocks: int = 256, # Max blocks on GPU
dtype: torch.dtype = torch.float16,
device: str = "cuda:0",
):
self.num_layers = num_layers
self.num_kv_heads = num_kv_heads
self.head_dim = head_dim
self.block_size = block_size
self.gpu_budget = gpu_budget_blocks
self.dtype = dtype
self.device = device
# Block shape: [2, layers, kv_heads, block_size, head_dim]
# 2 = K and V
self.block_shape = (2, num_layers, num_kv_heads, block_size, head_dim)
self.block_bytes = 2 * num_layers * num_kv_heads * block_size * head_dim
if dtype == torch.float16:
self.block_bytes *= 2
elif dtype == torch.float8_e4m3fn:
self.block_bytes *= 1
# GPU block pool (pre-allocated)
self.gpu_pool = torch.zeros(
gpu_budget_blocks, *self.block_shape,
dtype=dtype, device=device
)
self.gpu_free_slots = list(range(gpu_budget_blocks))
self.gpu_slot_map = {} # block_id -> gpu_slot_index
# CPU pinned memory pool
self.cpu_pool = {} # block_id -> pinned tensor
# Block registry
self.blocks = {} # block_id -> KVBlock
self.next_block_id = 0
# Async transfer streams
self.h2d_stream = torch.cuda.Stream(device=device) # CPU->GPU
self.d2h_stream = torch.cuda.Stream(device=device) # GPU->CPU
# LRU tracking for eviction
self.gpu_lru = OrderedDict() # block_id -> access_step
self.current_step = 0
def allocate_block(self, seq_id, start_pos, num_tokens):
"""Allocate a new KV block, evicting from GPU if necessary."""
block_id = self.next_block_id
self.next_block_id += 1
# Get a GPU slot (evict if needed)
gpu_slot = self._get_gpu_slot()
block = KVBlock(
block_id=block_id,
seq_id=seq_id,
start_pos=start_pos,
num_tokens=num_tokens,
on_gpu=True,
last_access_step=self.current_step,
)
block.gpu_tensor = self.gpu_pool[gpu_slot]
self.blocks[block_id] = block
self.gpu_slot_map[block_id] = gpu_slot
self.gpu_lru[block_id] = self.current_step
return block
def _get_gpu_slot(self):
"""Get a free GPU slot, evicting the LRU block if needed."""
if self.gpu_free_slots:
return self.gpu_free_slots.pop()
# Evict LRU block to CPU
        evict_id = next(iter(self.gpu_lru))  # iterating an OrderedDict yields the oldest key
self._offload_to_cpu(evict_id)
return self.gpu_free_slots.pop()
def _offload_to_cpu(self, block_id):
"""Move a block from GPU to CPU pinned memory (async)."""
block = self.blocks[block_id]
gpu_slot = self.gpu_slot_map[block_id]
# Allocate pinned CPU tensor if not exists
if block_id not in self.cpu_pool:
self.cpu_pool[block_id] = torch.empty(
self.block_shape, dtype=self.dtype,
pin_memory=True
)
# Async copy GPU -> CPU
with torch.cuda.stream(self.d2h_stream):
self.cpu_pool[block_id].copy_(self.gpu_pool[gpu_slot], non_blocking=True)
        # Wait for the copy to finish before the GPU slot is reused
        self.d2h_stream.synchronize()
# Free GPU slot
block.on_gpu = False
block.gpu_tensor = None
del self.gpu_slot_map[block_id]
del self.gpu_lru[block_id]
self.gpu_free_slots.append(gpu_slot)
def restore_to_gpu(self, block_id):
"""Restore a block from CPU to GPU (async). Returns the GPU tensor."""
block = self.blocks[block_id]
if block.on_gpu:
# Already on GPU, just update LRU
self.gpu_lru.move_to_end(block_id)
gpu_slot = self.gpu_slot_map[block_id]
return self.gpu_pool[gpu_slot]
# Need to restore from CPU
gpu_slot = self._get_gpu_slot()
with torch.cuda.stream(self.h2d_stream):
self.gpu_pool[gpu_slot].copy_(self.cpu_pool[block_id], non_blocking=True)
# Must synchronize before the compute stream uses this data
event = self.h2d_stream.record_event()
torch.cuda.current_stream().wait_event(event)
block.on_gpu = True
block.gpu_tensor = self.gpu_pool[gpu_slot]
self.gpu_slot_map[block_id] = gpu_slot
self.gpu_lru[block_id] = self.current_step
return self.gpu_pool[gpu_slot]
def get_kv_for_attention(self, seq_id, required_positions):
"""Gather all KV blocks for a sequence, restoring from CPU as needed.
Returns a contiguous KV tensor on GPU: [2, layers, kv_heads, total_tokens, head_dim]
"""
self.current_step += 1
seq_blocks = sorted(
[b for b in self.blocks.values() if b.seq_id == seq_id],
key=lambda b: b.start_pos
)
gpu_tensors = []
for block in seq_blocks:
if block.start_pos + block.num_tokens <= required_positions[0]:
continue # Before our window
if block.start_pos > required_positions[-1]:
break
tensor = self.restore_to_gpu(block.block_id)
# Slice to actual token count
gpu_tensors.append(tensor[:, :, :, :block.num_tokens, :])
block.last_access_step = self.current_step
block.access_count += 1
if block.block_id in self.gpu_lru:
self.gpu_lru.move_to_end(block.block_id)
if not gpu_tensors:
return None
return torch.cat(gpu_tensors, dim=3) # Concatenate along token dimension
Transfer Latency Analysis
The critical performance question: how fast can we restore a KV block from CPU to GPU? The transfer uses PCIe or NVLink CPU-GPU interconnect. On a standard H100 SXM system:
- PCIe Gen5 x16: 64 GB/s bidirectional (32 GB/s per direction effective)
- Block size (Llama 3 70B, GQA, 32 tokens): 10.49 MB
- Transfer time: 10.49 MB / 32 GB/s ≈ 328 us

With block size 256 tokens (83.89 MB per block):

- Transfer time: 83.89 MB / 32 GB/s ≈ 2.62 ms
KV Block Transfer Latency (PCIe Gen5)
| Block Size (tokens) | Block Size (MB) | Transfer Time | Tokens Decoded During Transfer | Overhead per Token |
|---|---|---|---|---|
| 16 | 5.24 | 164 us | 0.002 (at 10 tok/s) | 0.2% |
| 32 | 10.49 | 328 us | 0.003 | 0.3% |
| 64 | 20.97 | 655 us | 0.007 | 0.7% |
| 128 | 41.94 | 1.31 ms | 0.013 | 1.3% |
| 256 | 83.89 | 2.62 ms | 0.026 | 2.6% |
The latency is manageable for two reasons. First, the transfer is asynchronous: the GPU can compute attention over already-resident blocks while the CPU-to-GPU copy of the next block proceeds on a separate CUDA stream. Second, during decoding, the model generates one token at a time (typically 50-200ms for a 70B model), and the block restore (0.2-2.6ms) is a small fraction of that.
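The table's entries follow from two one-line formulas. A small calculator (illustrative names, assuming the GQA config above and a 100 ms decode step):

```python
KV_BYTES_PER_TOKEN = 2 * 80 * 8 * 128 * 2  # K,V x layers x kv_heads x head_dim x FP16

def transfer_ms(block_tokens, pcie_gbps=32):
    """Host-to-device copy time for one KV block at 32 GB/s (PCIe Gen5 x16)."""
    return block_tokens * KV_BYTES_PER_TOKEN / (pcie_gbps * 1e9) * 1e3

def overhead_pct(block_tokens, decode_ms_per_token=100):
    """Restore cost as a fraction of one decode step (10 tok/s -> 100 ms)."""
    return 100 * transfer_ms(block_tokens) / decode_ms_per_token

transfer_ms(32)    # ~0.33 ms, matching the table's 328 us
overhead_pct(256)  # ~2.6 percent
```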
Eviction Policies
LRU (Least Recently Used) is the baseline eviction policy, but it is not optimal for attention patterns. Better policies consider the attention score distribution:
class AttentionAwareEvictionPolicy:
"""Eviction policy that considers attention patterns.
Keeps blocks with high cumulative attention scores on GPU,
even if they have not been accessed recently.
This handles 'attention sinks' -- early tokens that receive
disproportionate attention throughout generation.
"""
def __init__(self, window_size, sink_tokens=4):
self.window_size = window_size # Always keep last W blocks
self.sink_tokens = sink_tokens # Always keep first S tokens
self.attention_accumulator = {} # block_id -> cumulative score
self.decay = 0.95 # Exponential decay for old scores
def update_scores(self, block_ids, attention_scores):
"""Update cumulative attention scores after each decode step.
attention_scores: average attention weight each block received
from the new query token, averaged across
all heads and layers.
"""
# Decay existing scores
for bid in self.attention_accumulator:
self.attention_accumulator[bid] *= self.decay
# Add new scores
for bid, score in zip(block_ids, attention_scores):
if bid not in self.attention_accumulator:
self.attention_accumulator[bid] = 0.0
self.attention_accumulator[bid] += score
def select_eviction_candidates(self, all_blocks, num_to_evict):
"""Select blocks to evict from GPU.
Never evicts:
- Blocks in the sliding window (most recent W blocks)
- Blocks containing attention sink tokens (first S tokens)
Among the rest, evicts blocks with lowest cumulative
attention score.
"""
# Sort by position
sorted_blocks = sorted(all_blocks, key=lambda b: b.start_pos)
# Protected: sink blocks and window blocks
sink_blocks = set()
window_blocks = set()
        # A block is protected if it overlaps the first S sink tokens
        pos = 0
        for block in sorted_blocks:
            if pos < self.sink_tokens:
                sink_blocks.add(block.block_id)
            pos += block.num_tokens
# Last W blocks are window
for block in sorted_blocks[-self.window_size:]:
window_blocks.add(block.block_id)
protected = sink_blocks | window_blocks
# Eviction candidates: everything else, sorted by attention score
candidates = [
b for b in all_blocks
if b.block_id not in protected and b.on_gpu
]
candidates.sort(
key=lambda b: self.attention_accumulator.get(b.block_id, 0.0)
)
return [c.block_id for c in candidates[:num_to_evict]]
Xiao et al. (2023) showed that the first 1-4 tokens in a sequence receive disproportionately high attention scores regardless of their content. These “attention sink” tokens act as a bias term in the softmax normalization. Evicting their KV cache entries causes severe quality degradation. Any eviction policy must protect these initial tokens unconditionally.
Prefetching Strategy
Naive on-demand restoration creates a bubble: the GPU stalls waiting for the block transfer. Prefetching eliminates this by predicting which blocks will be needed and starting the transfer before the attention kernel needs them.
class KVPrefetcher:
"""Prefetches KV blocks from CPU to GPU based on predicted access patterns.
For autoregressive decoding, the access pattern is deterministic:
every decode step accesses all blocks for the active sequences.
The prefetcher restores blocks in order of expected attention
score (high-attention blocks first) to minimize quality impact
if a transfer is still in flight when the kernel launches.
"""
def __init__(self, offload_manager, lookahead_steps=2):
self.manager = offload_manager
self.lookahead = lookahead_steps
self.pending_transfers = {} # block_id -> cuda_event
def prefetch_for_step(self, seq_id, step):
"""Issue prefetch commands for blocks needed in upcoming steps.
Strategy: For each active sequence, ensure all blocks are
either on GPU or have an in-flight transfer. Prioritize
blocks with high historical attention scores.
"""
seq_blocks = [
b for b in self.manager.blocks.values()
if b.seq_id == seq_id and not b.on_gpu
]
# Sort by importance (high attention score first)
seq_blocks.sort(
key=lambda b: b.importance_score, reverse=True
)
for block in seq_blocks:
if block.block_id in self.pending_transfers:
continue # Already in flight
if not self.manager.gpu_free_slots:
break # No GPU space
# Start async transfer
gpu_slot = self.manager._get_gpu_slot()
with torch.cuda.stream(self.manager.h2d_stream):
self.manager.gpu_pool[gpu_slot].copy_(
self.manager.cpu_pool[block.block_id],
non_blocking=True
)
event = self.manager.h2d_stream.record_event()
self.pending_transfers[block.block_id] = (gpu_slot, event)
def wait_and_finalize(self, block_id):
"""Wait for a prefetched block and register it as GPU-resident."""
if block_id not in self.pending_transfers:
return
gpu_slot, event = self.pending_transfers.pop(block_id)
event.synchronize()
block = self.manager.blocks[block_id]
block.on_gpu = True
block.gpu_tensor = self.manager.gpu_pool[gpu_slot]
self.manager.gpu_slot_map[block_id] = gpu_slot
self.manager.gpu_lru[block_id] = self.manager.current_step
3. Ring Attention for Distributed Long Context
KV offloading handles the case where a single GPU cannot fit the entire KV cache. Ring Attention (Liu et al., 2023) handles the case where the sequence itself is too long to process on a single GPU during prefill. Rather than replicating the full sequence on each device, Ring Attention partitions the sequence across GPUs and implements attention via a ring communication pattern.
The Ring Attention Algorithm
Given P GPUs and a sequence of N tokens, each GPU i holds a contiguous chunk of C = N/P tokens: queries Q_i, keys K_i, and values V_i.
To compute attention for Q_i, GPU i needs KV from all chunks. Ring Attention circulates the KV blocks around the ring:
Step 0: GPU_i computes attention(Q_i, K_i, V_i) -- local KV
Step 1: GPU_i receives K_{i-1}, V_{i-1} from left neighbor
GPU_i computes attention(Q_i, K_{i-1}, V_{i-1})
Step 2: GPU_i receives K_{i-2}, V_{i-2}
GPU_i computes attention(Q_i, K_{i-2}, V_{i-2})
...
Step P-1: GPU_i has accumulated attention from all chunks
At each step, every GPU simultaneously sends its current KV buffer to the right neighbor and receives a new KV buffer from the left neighbor. The communication overlaps with the attention computation on the previous KV buffer.
import torch
import torch.distributed as dist
def ring_attention_forward(
Q, K, V,
rank, world_size,
process_group,
causal=True,
block_size=4096
):
"""Ring Attention: distributed attention across GPUs.
Each GPU holds a chunk of the sequence.
Q: [batch, local_seq_len, num_heads, head_dim] -- queries for this chunk
K: [batch, local_seq_len, num_kv_heads, head_dim] -- keys for this chunk
V: [batch, local_seq_len, num_kv_heads, head_dim] -- values for this chunk
Returns: attention output for this GPU's query chunk.
"""
B, S_local, H, D = Q.shape
_, _, H_kv, _ = K.shape
# Initialize accumulators for online softmax
# O_i = sum of softmax(scores) * V, unnormalized
# l_i = sum of exp(scores - max_score), for normalization
# m_i = running max of scores, for numerical stability
O = torch.zeros_like(Q) # [B, S_local, H, D]
l = torch.zeros(B, S_local, H, 1, device=Q.device, dtype=Q.dtype)
m = torch.full((B, S_local, H, 1), float('-inf'), device=Q.device, dtype=Q.dtype)
# Double-buffered KV for overlapping communication and compute
kv_send = torch.cat([K, V], dim=2) # [B, S_local, 2*H_kv, D]
kv_recv = torch.empty_like(kv_send)
for step in range(world_size):
# Source of the KV at this step
src_rank = (rank - step) % world_size
# Start async ring communication (send right, recv from left)
if step < world_size - 1:
send_rank = (rank + 1) % world_size
recv_rank = (rank - 1) % world_size
send_op = dist.isend(kv_send, dst=send_rank, group=process_group)
recv_op = dist.irecv(kv_recv, src=recv_rank, group=process_group)
# Extract K, V from the current buffer
K_step = kv_send[:, :, :H_kv, :]
V_step = kv_send[:, :, H_kv:, :]
# Causal masking: GPU i's queries at positions [i*S_local, (i+1)*S_local)
# can only attend to KV at positions up to their own position.
# KV from src_rank covers positions [src_rank*S_local, (src_rank+1)*S_local)
if causal and src_rank > rank:
# This KV chunk is entirely in the future -- skip
pass
else:
# Compute blockwise attention with online softmax update
# Score: [B, S_local, H, S_local_kv]
            scores = torch.einsum('bshd,bthd->bsht', Q, K_step.repeat_interleave(H // H_kv, dim=2))  # GQA: map each query head to its group's KV head
scores = scores / (D ** 0.5)
# Apply causal mask for the diagonal block
if causal and src_rank == rank:
# Only this block needs a causal mask
q_positions = torch.arange(rank * S_local, (rank + 1) * S_local, device=Q.device)
kv_positions = torch.arange(src_rank * S_local, (src_rank + 1) * S_local, device=Q.device)
mask = q_positions.unsqueeze(1) < kv_positions.unsqueeze(0)
scores[:, :, :, :].masked_fill_(mask.unsqueeze(0).unsqueeze(2), float('-inf'))
# Online softmax accumulation (FlashAttention-2 style)
m_step = scores.max(dim=-1, keepdim=True).values # [B, S, H, 1]
m_new = torch.maximum(m, m_step)
# Rescale old accumulator
exp_diff_old = torch.exp(m - m_new)
# New exponentials
exp_scores = torch.exp(scores - m_new)
# Update running sum of exponentials
l_new = l * exp_diff_old + exp_scores.sum(dim=-1, keepdim=True)
# Update output accumulator
            O = O * exp_diff_old + torch.einsum('bsht,bthd->bshd', exp_scores, V_step.repeat_interleave(H // H_kv, dim=2))
l = l_new
m = m_new
# Wait for communication to finish, then swap buffers
if step < world_size - 1:
send_op.wait()
recv_op.wait()
kv_send, kv_recv = kv_recv, kv_send
# Final normalization
O = O / l
return O
Ring Attention requires the online softmax trick (Milakov and Gimelshein, 2018; Dao et al., 2022). Without it, you would need to compute the global softmax denominator across all chunks before computing outputs, which requires a global reduction and prevents overlap. With online softmax, each chunk contributes incrementally: update the running max m, rescale the previous accumulator by exp(m_old - m_new), and add the new contribution. The final normalization at the end produces numerically identical results to standard attention.
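The incremental update is easy to demonstrate in isolation. This toy sketch (a single score vector, no batching or heads) asserts that chunk-at-a-time accumulation matches a global softmax:

```python
import torch

def online_softmax_weighted_sum(score_chunks, value_chunks):
    """Accumulate softmax(scores) @ V one chunk at a time.

    Keeps a running max m, running denominator l, and an unnormalized
    output O; each new chunk rescales the old accumulators by exp(m - m_new),
    exactly as in the ring loop above.
    """
    m = torch.tensor(float('-inf'))
    l = torch.tensor(0.0)
    O = None
    for s, v in zip(score_chunks, value_chunks):
        m_new = torch.maximum(m, s.max())
        rescale = torch.exp(m - m_new)      # shrink old contributions
        e = torch.exp(s - m_new)            # new chunk's exponentials
        l = l * rescale + e.sum()
        O = e @ v if O is None else O * rescale + e @ v
        m = m_new
    return O / l

torch.manual_seed(0)
s, v = torch.randn(12), torch.randn(12, 4)
full = torch.softmax(s, dim=0) @ v  # global softmax, all chunks at once
chunked = online_softmax_weighted_sum(s.split(4), v.split(4))
assert torch.allclose(full, chunked, atol=1e-6)
```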
Communication-Computation Overlap
The ring pattern's key property is that communication and computation overlap. At each step, while GPU i computes attention(Q_i, KV_current), it simultaneously sends KV_current to GPU i+1 and receives KV_next from GPU i-1. The overlap is effective as long as T_comm <= T_compute.
For a local chunk of C = N/P tokens, each ring step moves 2 * C * n_kv_heads * d_head * bytes_per_element of KV data while computing roughly 4 * C^2 * d_head * n_heads FLOPs of attention per layer. Compute grows quadratically in C while communication grows linearly, so larger chunks overlap more easily.
Ring Attention: Compute vs Communication Time
With NVLink (900 GB/s), communication time is 5-15x smaller than compute time for typical configurations, and the overlap is nearly perfect. On PCIe-only systems (64 GB/s), communication time grows by ~14x, and overlap breaks down beyond 4 GPUs.
Handling Causal Masking Efficiently
A subtle problem with Ring Attention: causal masking creates workload imbalance. GPU 0 (holding the earliest tokens) has queries that can only attend to their own chunk. GPU P-1 (holding the latest tokens) has queries that attend to all chunks. In a naive implementation, GPU 0 finishes its attention work in 1 ring step while GPU P-1 computes in all P steps.
The Striped Attention optimization (Brandon et al., 2023) interleaves token positions across GPUs instead of assigning contiguous chunks:
Contiguous: GPU0=[0,1,2,3] GPU1=[4,5,6,7] GPU2=[8,9,10,11] GPU3=[12,13,14,15]
Striped: GPU0=[0,4,8,12] GPU1=[1,5,9,13] GPU2=[2,6,10,14] GPU3=[3,7,11,15]
With striped assignment, every GPU has a mix of early and late tokens, and the causal mask blocks are roughly equal across GPUs. The compute imbalance between the most- and least-loaded GPU drops from a factor of P to nearly 1.
def stripe_sequence(tokens, world_size):
"""Distribute tokens across GPUs in a striped pattern.
Instead of contiguous chunks [0..N/P], [N/P..2N/P], ...,
assign token i to GPU (i % world_size).
This balances causal mask workload across GPUs.
"""
N = len(tokens)
local_indices = {}
for rank in range(world_size):
local_indices[rank] = list(range(rank, N, world_size))
return local_indices
def unstripe_output(local_outputs, local_indices, total_length):
"""Reassemble the full output from striped local results."""
full_output = torch.empty(total_length, *local_outputs[0].shape[1:],
device=local_outputs[0].device,
dtype=local_outputs[0].dtype)
for rank, indices in local_indices.items():
full_output[indices] = local_outputs[rank]
return full_output
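As a sanity check, striping round-trips losslessly. This toy sketch (with an inline copy of the striping rule, and a tensor of token indices standing in for per-GPU outputs) verifies that unstriping restores the original order:

```python
import torch

def stripe_indices(n, world_size):
    # Token i goes to GPU (i % world_size), mirroring stripe_sequence above
    return {r: list(range(r, n, world_size)) for r in range(world_size)}

n, p = 16, 4
full = torch.arange(n, dtype=torch.float32).unsqueeze(1)  # stand-in outputs
idx = stripe_indices(n, p)
shards = {r: full[idx[r]] for r in range(p)}  # per-GPU striped shards

# Reassemble, as unstripe_output does
out = torch.empty_like(full)
for r, ind in idx.items():
    out[ind] = shards[r]

assert idx[1] == [1, 5, 9, 13]
assert torch.equal(out, full)
```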
4. Chunked Prefill Processing
The third technique addresses a different bottleneck: processing the initial prompt (prefill phase) for very long contexts. A 128K-token prompt processed in a single forward pass implies an attention score matrix of size S x S per head (128K x 128K), which, even with FlashAttention's tiling, requires massive compute. More critically, the prefill blocks the GPU for the entire duration, preventing any decode-phase tokens from being generated for other requests.
Chunked prefill splits the long prompt into chunks of C tokens (typically 4K-8K) and processes them sequentially. Each chunk attends to all previous chunks via the already-computed KV cache, plus its own tokens via causal self-attention.
The Chunked Prefill Algorithm
class ChunkedPrefillEngine:
"""Process long prompts in chunks to bound memory and enable scheduling.
Key insight: attention is decomposable. Processing tokens [0, C) first,
then [C, 2C) with the KV cache from [0, C), produces identical results
to processing [0, 2C) at once. This is because:
- For positions in [C, 2C): they attend to [0, C) via cached KV and to
[C, 2C) via causal self-attention within the chunk. The attention
scores and outputs are numerically identical to full-sequence attention.
- For positions in [0, C): they were already fully processed in chunk 0.
Their KV entries are correct and final.
"""
def __init__(self, model, chunk_size=4096, kv_cache=None):
self.model = model
self.chunk_size = chunk_size
self.kv_cache = kv_cache # Shared KV cache (e.g., vLLM PagedAttention)
def prefill(self, input_ids, seq_id):
"""Process a long prompt in chunks.
input_ids: [1, total_seq_len] -- the full prompt
Returns: logits for the last token (for generation)
"""
total_len = input_ids.shape[1]
num_chunks = (total_len + self.chunk_size - 1) // self.chunk_size
all_logits = None
for chunk_idx in range(num_chunks):
start = chunk_idx * self.chunk_size
end = min(start + self.chunk_size, total_len)
chunk_ids = input_ids[:, start:end]
# The model attends to:
# 1. All previous KV cache entries (positions 0..start-1)
# 2. Current chunk tokens (positions start..end-1) with causal mask
# Position offsets so the model knows where these tokens sit
# in the full sequence
position_ids = torch.arange(start, end, device=chunk_ids.device).unsqueeze(0)
outputs = self.model(
input_ids=chunk_ids,
position_ids=position_ids,
past_key_values=self.kv_cache.get_cache(seq_id),
use_cache=True,
)
# The model appends this chunk's KV to the cache automatically
# Now the cache covers positions 0..end-1
all_logits = outputs.logits
# Return logits for the last token only (used for generation)
return all_logits[:, -1:, :]
def prefill_with_scheduling(self, input_ids, seq_id, decode_queue):
"""Chunked prefill interleaved with decode steps for other requests.
After each prefill chunk, check if any decode-phase requests
are waiting. If so, batch their decode steps together and
process them before continuing with the next prefill chunk.
This prevents a single long-context prefill from starving
all decode-phase requests (which are latency-sensitive).
"""
total_len = input_ids.shape[1]
num_chunks = (total_len + self.chunk_size - 1) // self.chunk_size
for chunk_idx in range(num_chunks):
start = chunk_idx * self.chunk_size
end = min(start + self.chunk_size, total_len)
chunk_ids = input_ids[:, start:end]
position_ids = torch.arange(start, end, device=chunk_ids.device).unsqueeze(0)
# Process this prefill chunk
self.model(
input_ids=chunk_ids,
position_ids=position_ids,
past_key_values=self.kv_cache.get_cache(seq_id),
use_cache=True,
)
# Yield to decode requests between chunks
if decode_queue.has_pending():
decode_batch = decode_queue.get_batch()
self._run_decode_step(decode_batch)
def _run_decode_step(self, decode_batch):
"""Run one decode step for a batch of requests."""
input_ids = torch.stack([r.last_token for r in decode_batch])
position_ids = torch.stack([r.current_pos for r in decode_batch])
outputs = self.model(
input_ids=input_ids,
position_ids=position_ids,
past_key_values=[self.kv_cache.get_cache(r.seq_id) for r in decode_batch],
use_cache=True,
)
# Sample next tokens and update decode state
for i, req in enumerate(decode_batch):
logits = outputs.logits[i, -1, :]
next_token = torch.argmax(logits)
req.append_token(next_token)
Chunk Size Selection
The chunk size controls the tradeoff between prefill throughput and decode latency:
Chunk Size Impact on Prefill and Decode
| Chunk Size | Prefill Throughput | Decode Interruption | Peak Activation Memory | Use Case |
|---|---|---|---|---|
| 1024 | 65% of max | Every 8ms | 0.5 GB | Latency-critical (chatbots) |
| 4096 | 88% of max | Every 32ms | 2 GB | Balanced (default) |
| 8192 | 94% of max | Every 65ms | 4 GB | Throughput-focused |
| 16384 | 97% of max | Every 130ms | 8 GB | Batch processing |
| Full sequence | 100% | Never (blocked) | O(S) GB | Offline only |
The sweet spot for interactive serving is 4096-8192. At 4K chunks, the overhead versus unchunked processing is ~12% (due to repeated KV cache lookups and kernel launch overhead per chunk), but decode requests can be serviced every 32ms, keeping time-to-first-token (TTFT) for concurrent requests reasonable.
Chunked prefill produces numerically identical results to full-sequence prefill. This is not an approximation. Each chunk computes exact causal attention over its tokens plus all previously cached KV. The decomposition works because attention is a function of (Q, K, V) pairs and the causal mask only depends on position indices, both of which are correctly supplied per-chunk via position_ids and the growing KV cache.
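A toy single-head sketch makes the equivalence concrete. The absolute positions are the part that is easy to get wrong; the `causal_attention` helper here is illustrative, not an engine API:

```python
import torch

def causal_attention(q, k, v):
    """Single-head causal attention; q may be a suffix of the full sequence.

    Query j's absolute position is (len(k) - len(q) + j), so each query
    attends only to keys at positions <= its own.
    """
    d = q.shape[-1]
    scores = q @ k.T / d**0.5
    s_q, s_k = q.shape[0], k.shape[0]
    pos_q = torch.arange(s_k - s_q, s_k).unsqueeze(1)
    pos_k = torch.arange(s_k).unsqueeze(0)
    scores = scores.masked_fill(pos_k > pos_q, float('-inf'))
    return torch.softmax(scores, dim=-1) @ v

torch.manual_seed(0)
S, C, D = 12, 4, 8
q, k, v = torch.randn(S, D), torch.randn(S, D), torch.randn(S, D)

full = causal_attention(q, k, v)          # one-shot prefill

chunks = []
for start in range(0, S, C):              # chunked prefill
    end = start + C
    # Each chunk sees the "cached" KV for [0, end) plus itself
    chunks.append(causal_attention(q[start:end], k[:end], v[:end]))
chunked = torch.cat(chunks)

assert torch.allclose(full, chunked, atol=1e-6)
```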
5. Combining Techniques: Production Architecture
A production long-context serving system combines all three techniques. Here is the architecture used by systems like vLLM, SGLang, and TensorRT-LLM for 128K+ serving:
Layer 1: Prefix Caching
Many long-context requests share a common prefix (system prompt, document context). Prefix caching stores the KV cache for shared prefixes and reuses it across requests.
class PrefixCacheManager:
"""Cache KV for common prefixes across requests.
Hash-based lookup: hash the token sequence to find cached KV.
Supports hierarchical matching -- if the full prefix is not
cached, find the longest cached prefix and compute the rest.
"""
def __init__(self, block_size=256):
self.block_size = block_size
self.cache = {} # hash -> list of KV block IDs
self.hash_to_tokens = {} # hash -> token sequence (for verification)
def _hash_block(self, token_ids):
"""Hash a block of tokens for cache lookup."""
return hash(tuple(token_ids.tolist()))
def lookup(self, input_ids):
"""Find the longest cached prefix for this input.
Returns: (num_cached_tokens, kv_block_ids)
"""
total_len = input_ids.shape[1]
cached_blocks = []
cached_tokens = 0
for start in range(0, total_len, self.block_size):
end = min(start + self.block_size, total_len)
block_tokens = input_ids[0, start:end]
block_hash = self._hash_block(block_tokens)
if block_hash in self.cache:
# Verify (hash collision check)
if torch.equal(self.hash_to_tokens[block_hash], block_tokens):
cached_blocks.extend(self.cache[block_hash])
cached_tokens = end
else:
break # Hash collision, stop matching
else:
break # No more cached prefix
return cached_tokens, cached_blocks
def store(self, input_ids, kv_block_ids):
"""Store KV blocks for this prefix in the cache."""
total_len = input_ids.shape[1]
block_idx = 0
for start in range(0, total_len, self.block_size):
end = min(start + self.block_size, total_len)
block_tokens = input_ids[0, start:end]
block_hash = self._hash_block(block_tokens)
if block_hash not in self.cache:
self.cache[block_hash] = [kv_block_ids[block_idx]]
self.hash_to_tokens[block_hash] = block_tokens.clone()
block_idx += 1
Layer 2: Integrated Serving Pipeline
class LongContextServingPipeline:
"""Production pipeline combining prefix caching, chunked prefill,
KV offloading, and (optionally) Ring Attention.
Request flow:
1. Check prefix cache for shared prefix hit
2. Chunked prefill for remaining tokens
3. During decode, offload cold KV blocks to CPU
4. For sequences exceeding single-GPU KV budget, use Ring Attention
"""
def __init__(
self,
model,
chunk_size=4096,
gpu_kv_budget_gb=40,
enable_offloading=True,
enable_ring_attention=False,
world_size=1,
):
self.model = model
self.chunk_size = chunk_size
self.prefix_cache = PrefixCacheManager(block_size=256)
self.chunked_prefill = ChunkedPrefillEngine(model, chunk_size)
if enable_offloading:
# Calculate max blocks that fit in GPU budget
block_bytes = model.config.kv_block_bytes(block_size=32)
max_gpu_blocks = int(gpu_kv_budget_gb * 1e9 / block_bytes)
self.offloader = KVOffloadManager(
num_layers=model.config.num_hidden_layers,
num_kv_heads=model.config.num_key_value_heads,
head_dim=model.config.hidden_size // model.config.num_attention_heads,
gpu_budget_blocks=max_gpu_blocks,
)
else:
self.offloader = None
self.ring_attention = enable_ring_attention
self.world_size = world_size
def serve_request(self, input_ids, max_new_tokens=4096):
"""Full serving pipeline for a single request."""
seq_id = self._allocate_sequence()
# Step 1: Prefix cache lookup
cached_tokens, cached_blocks = self.prefix_cache.lookup(input_ids)
if cached_tokens > 0:
# Reuse cached KV blocks
self._load_cached_kv(seq_id, cached_blocks)
remaining_ids = input_ids[:, cached_tokens:]
else:
remaining_ids = input_ids
# Step 2: Chunked prefill for uncached tokens
if remaining_ids.shape[1] > 0:
logits = self.chunked_prefill.prefill(remaining_ids, seq_id)
# Step 3: Store prefix in cache for future requests
kv_blocks = self._get_kv_block_ids(seq_id)
self.prefix_cache.store(input_ids, kv_blocks)
# Step 4: Autoregressive decode with KV offloading
generated_tokens = []
for step in range(max_new_tokens):
# Offload cold blocks if needed
if self.offloader:
self.offloader.current_step = step
self._manage_kv_residency(seq_id)
# Decode one token
logits = self._decode_step(seq_id)
next_token = torch.argmax(logits, dim=-1)
generated_tokens.append(next_token.item())
if next_token.item() == self.model.config.eos_token_id:
break
return generated_tokens
def _manage_kv_residency(self, seq_id):
"""Decide which KV blocks stay on GPU and which get offloaded."""
total_blocks = self._count_kv_blocks(seq_id)
gpu_blocks = self._count_gpu_resident_blocks(seq_id)
if gpu_blocks > self.offloader.gpu_budget * 0.9:
# Approaching GPU budget -- evict old blocks
num_evict = gpu_blocks - int(self.offloader.gpu_budget * 0.75)
candidates = self.offloader.select_eviction_candidates(
self._get_blocks(seq_id), num_evict
)
for block_id in candidates:
self.offloader._offload_to_cpu(block_id)
def _allocate_sequence(self):
pass # Implementation depends on the KV cache backend
def _load_cached_kv(self, seq_id, block_ids):
pass
def _get_kv_block_ids(self, seq_id):
pass
def _decode_step(self, seq_id):
pass
def _count_kv_blocks(self, seq_id):
pass
def _count_gpu_resident_blocks(self, seq_id):
pass
def _get_blocks(self, seq_id):
pass
End-to-End Performance
[Figure: 128K context serving, throughput (tokens/s) by configuration]
Production Configuration for 128K Serving (Llama 3 70B)
| Component | Setting | Rationale |
|---|---|---|
| Weight precision | INT4 (AWQ) | 35 GB weights, fits single GPU |
| KV cache precision | FP8 E4M3 | Halves KV memory vs FP16, negligible quality loss |
| Chunk size | 8192 tokens | 94% prefill efficiency, 65ms decode gaps |
| KV offload policy | Attention-aware, 4 sink tokens | Protects attention sinks, evicts cold middle context |
| GPU KV budget | 75% of free HBM | 25% headroom for activation spikes |
| Block size | 32 tokens | Fine-grained eviction, 10 MB per block |
| Prefix cache | 256-token blocks, LRU eviction | Shared system prompts across requests |
| Tensor parallelism | 2 GPUs (TP=2) | Model weights split, KV per-GPU |
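As a quick sanity check on the block-size and budget rows, a few lines convert a GPU KV budget into resident block and token counts. The 40 GB budget and the 10 MB, 32-token blocks are the illustrative values from the table above; the helper name is mine:

```python
def blocks_in_budget(budget_gb, block_bytes, tokens_per_block):
    """Convert a GPU KV budget into resident block and token counts."""
    num_blocks = int(budget_gb * 1e9) // block_bytes
    return num_blocks, num_blocks * tokens_per_block

# 40 GB KV budget, 10 MB blocks holding 32 tokens each:
# roughly 3,800 resident blocks, i.e. ~122K tokens -- on the order of
# one full 128K-context request on the GPU at a time. Everything beyond
# that spills to the offload tier.
blocks, tokens = blocks_in_budget(40, 10 * 1024 * 1024, 32)
```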
6. Measuring Long-Context Quality Under Offloading
KV offloading with eviction is, in theory, lossy: if a block is evicted and not restored before the attention kernel runs, the model produces different outputs. In practice, a correct restore policy brings every needed block back before compute, making the system lossless. But there are failure modes worth testing for:
class LongContextQualityBenchmark:
"""Benchmark to verify that KV offloading does not degrade quality.
Tests:
1. Needle-in-a-haystack: plant a fact at position P in a long context,
ask about it at the end. Measure recall across positions.
2. Passkey retrieval: hide a 5-digit passkey at various depths.
3. Multi-hop reasoning: facts spread across the context, answer
requires combining 2-3 of them.
"""
def __init__(self, model_runner, offload_manager):
self.runner = model_runner
self.offloader = offload_manager
def needle_in_haystack(self, context_length, needle_position, num_trials=20):
"""Test recall of a specific fact placed at needle_position."""
results = []
for trial in range(num_trials):
# Generate haystack (irrelevant text)
haystack = self._generate_haystack(context_length)
# Insert needle
needle_fact = f"The special code is {trial * 1000 + 42}."
needle_query = "What is the special code?"
            # Place the needle at the specified relative depth in the
            # haystack (a list of text pieces, not individual token ids)
            insert_idx = int(needle_position * len(haystack))
            context = haystack[:insert_idx] + [needle_fact] + haystack[insert_idx:]
# Run with offloading
self.offloader.reset()
response = self.runner.generate(context + [needle_query])
# Check if the response contains the correct code
expected = str(trial * 1000 + 42)
recall = expected in response
results.append(recall)
return sum(results) / len(results)
def sweep_positions(self, context_length, num_positions=20):
"""Sweep needle position across the full context length."""
positions = [i / num_positions for i in range(num_positions)]
recall_map = {}
for pos in positions:
recall = self.needle_in_haystack(context_length, pos)
recall_map[pos] = recall
return recall_map
    def _generate_haystack(self, num_tokens):
        pass  # Generate filler text as a list of pieces totaling ~num_tokens tokens
The expected result: with a correctly implemented offloading system (all blocks restored before compute), needle-in-a-haystack recall should be identical to the non-offloaded baseline. Any degradation indicates a bug in the eviction or restore logic, not an inherent limitation of offloading.
A common implementation bug: evicting blocks that are still needed by in-flight prefetches on a different CUDA stream. The eviction sees the block as “not recently accessed” and offloads it, but a concurrent prefetch was about to use it. Fix: use CUDA events to track block dependencies across streams. Never evict a block that has an unrealized event dependency.
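The bookkeeping behind that fix can be sketched without any CUDA specifics. In a real system, `mark_pending` corresponds to recording a `torch.cuda.Event` on the prefetch stream and `safe_to_evict` to `event.query()` returning true; the class and function names here are illustrative. A minimal, framework-free sketch:

```python
class BlockDependencyTracker:
    """Tracks outstanding cross-stream uses of KV blocks.

    Stand-in for CUDA events: mark_pending() models recording an event on
    the prefetch stream; mark_complete() models that event being signaled.
    """
    def __init__(self):
        self._pending = {}  # block_id -> count of outstanding uses

    def mark_pending(self, block_id):
        self._pending[block_id] = self._pending.get(block_id, 0) + 1

    def mark_complete(self, block_id):
        remaining = self._pending.get(block_id, 0) - 1
        if remaining <= 0:
            self._pending.pop(block_id, None)
        else:
            self._pending[block_id] = remaining

    def safe_to_evict(self, block_id):
        return block_id not in self._pending

def filter_evictable(tracker, candidates):
    """Drop any eviction candidate with an unresolved dependency."""
    return [b for b in candidates if tracker.safe_to_evict(b)]
```

The eviction policy then calls `filter_evictable` on its candidate list before offloading, so a block an in-flight prefetch is about to touch can never be pulled out from under it.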
7. Scaling Limits and Future Directions
Each technique has limits:
KV offloading is bounded by PCIe bandwidth. If the model accesses more blocks per step than the interconnect can restore, the GPU stalls. The critical ratio is:
blocks restored per step = (decode step time × PCIe bandwidth) / bytes per block
For a 70B decode step of 100ms and PCIe Gen5 at 32 GB/s, with 10 MB blocks: (0.1 s × 32 GB/s) / 10 MB = 320 blocks. At 32 tokens per block, that is 10,240 tokens worth of KV restored per step — more than enough for sparse attention patterns but a ceiling for models that attend uniformly over 128K tokens.
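The same arithmetic as a short snippet, using the figures above (0.1 s decode step, 32 GB/s PCIe Gen5, decimal 10 MB blocks of 32 tokens; the helper name is mine):

```python
def restorable_per_step(step_time_s, pcie_bytes_per_s, block_bytes, tokens_per_block):
    """KV blocks (and tokens) PCIe can restore within one decode step."""
    blocks = int(step_time_s * pcie_bytes_per_s // block_bytes)
    return blocks, blocks * tokens_per_block

# 100 ms step, PCIe Gen5 (32 GB/s), 10 MB blocks of 32 tokens
blocks, tokens = restorable_per_step(0.1, 32e9, 10e6, 32)
# -> 320 blocks, 10,240 tokens per decode step
```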
Ring Attention is bounded by the number of ring steps: N − 1 point-to-point communications for N GPUs. With NVLink, this is fine up to 8 GPUs. Beyond that (multi-node), the inter-node bandwidth (400 Gbps InfiniBand = ~50 GB/s, roughly 20x slower than NVLink) makes the communication-computation overlap break down unless chunks are very large.
Chunked prefill adds fixed overhead per chunk: kernel launch costs (~10us per kernel, hundreds of kernels per chunk = ~1-2ms total), KV cache lookup overhead, and reduced arithmetic intensity for small chunks. Below chunk size 1024, the overhead exceeds 20% and the approach becomes inefficient.
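A toy cost model makes the crossover concrete. The ~8 µs per-token prefill cost is back-solved from the configuration table (8192-token chunks, 65 ms decode gaps); the 2 ms fixed cost comes from the kernel-launch estimate above. Both are illustrative assumptions, not measurements:

```python
def prefill_overhead_fraction(chunk_tokens, fixed_ms=2.0, per_token_us=8.0):
    """Fraction of a chunk's wall time consumed by fixed per-chunk overhead."""
    compute_ms = chunk_tokens * per_token_us / 1000.0
    return fixed_ms / (fixed_ms + compute_ms)

# chunk 8192 -> ~3% overhead; chunk 1024 -> ~20%; chunk 512 -> ~33%
for chunk in (512, 1024, 4096, 8192):
    print(chunk, round(prefill_overhead_fraction(chunk), 3))
```

Under these assumptions the overhead crosses the 20% line right around a chunk size of 1024, matching the rule of thumb above.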
Scaling Limits Summary
| Technique | Scaling Limit | Bottleneck | Mitigation |
|---|---|---|---|
| KV offloading | ~10K tokens restored/step | PCIe Gen5 bandwidth | CXL memory (future), FP8 KV |
| Ring Attention | 8 GPUs (single node) | Inter-node bandwidth | Striped attention, async prefetch |
| Chunked prefill | Chunk size 1K minimum | Kernel launch overhead | Fused kernels, CUDA graphs per chunk |
| Prefix caching | Cache capacity | GPU/CPU memory | Hierarchical caching (GPU/CPU/SSD) |
Looking ahead, CXL (Compute Express Link) memory will blur the GPU/CPU memory boundary, providing 128+ GB of device-attached DRAM at bandwidth between CPU DRAM and GPU HBM. This collapses the KV offloading problem: instead of explicitly managing block transfers, the KV cache lives in CXL memory and is accessed transparently by the GPU at ~100 GB/s — roughly 3x faster than PCIe Gen5 DMA transfers, though still well over an order of magnitude slower than HBM3.
Worked Example: Sizing the KV Budget
To close the loop on the numbers from Section 1: the function below computes the KV cache footprint in bytes for an arbitrary transformer and, given a GPU memory budget and model weight size, the maximum batch size achievable at a target context length. A companion helper computes how many KV blocks must be offloaded to CPU when a batch exceeds the GPU budget.
def compute_kv_budget(
num_layers, num_kv_heads, head_dim, dtype_bytes,
context_length, block_size,
gpu_total_bytes, model_weight_bytes, activation_overhead_bytes
):
kv_per_token = 2 * num_layers * num_kv_heads * head_dim * dtype_bytes
kv_per_request = kv_per_token * context_length
free_for_kv = gpu_total_bytes - model_weight_bytes - activation_overhead_bytes
    max_batch = max(0, int(free_for_kv // kv_per_request))
tokens_per_block = block_size
blocks_per_request = (context_length + tokens_per_block - 1) // tokens_per_block
bytes_per_block = kv_per_token * tokens_per_block
max_gpu_blocks = int(free_for_kv / bytes_per_block)
return {
"kv_per_request_gb": kv_per_request / 1e9,
"max_batch_on_gpu": max_batch,
"blocks_per_request": blocks_per_request,
"max_gpu_blocks": max_gpu_blocks,
}
def offload_plan(total_requests, blocks_per_request, max_gpu_blocks):
    """Split a batch's KV blocks between GPU residency and CPU offload."""
    total_blocks = total_requests * blocks_per_request
    offloaded = max(0, total_blocks - max_gpu_blocks)
    return {
        "total_blocks": total_blocks,
        "gpu_blocks": min(total_blocks, max_gpu_blocks),
        "cpu_blocks": offloaded,
    }