Part of the Inference Optimization series (20 of 23).

Serving a 70B-parameter model at 128K context length is not a configuration knob you turn. It is a memory engineering problem. The KV cache alone for Llama 2 70B at 128K context, with 8 KV heads (Llama 2 70B uses grouped-query attention), a 128-dimensional head size, 80 layers, and FP16 precision, occupies:

\text{KV bytes} = 2 \times n_{\text{layers}} \times n_{\text{kv}} \times d_{\text{head}} \times \text{seq\_len} \times 2 = 2 \times 80 \times 8 \times 128 \times 131072 \times 2 \approx 43 \text{ GB}

The factor of 2 at the front accounts for the K and V projections. The trailing factor of 2 is bytes per FP16 element. An H100 has 80 GB of HBM3. The model weights for Llama 2 70B in FP16 occupy ~140 GB (requiring at minimum 2-GPU tensor parallelism); in INT4 they occupy ~35 GB. Even with INT4 weights on a single GPU, model weights (35 GB) plus KV cache (43 GB) comes to 78 GB — leaving 2 GB for activations, the CUDA context, and the framework overhead. That is for a single request. Multi-tenant serving with batch sizes of 8-64 is impossible without intervention.

This post covers three techniques that make long-context serving tractable: KV cache offloading (move cold KV blocks to CPU DRAM), Ring Attention (distribute the sequence across GPUs), and chunked prefill (process long prompts in manageable pieces). Each technique addresses a different dimension of the problem, and production systems combine all three.


1. The Memory Budget at Scale

Before diving into solutions, let us establish the numbers precisely. The KV cache size depends on the model architecture and the serving configuration.

📊

KV Cache Memory by Model and Context Length

Model                  Context  KV Heads  Head Dim  Layers  KV Cache (FP16)  KV Cache (FP8)
Llama 2 70B (GQA)      4K       8         128       80      1.34 GB          0.67 GB
Llama 2 70B (GQA)      32K      8         128       80      10.74 GB         5.37 GB
Llama 2 70B (GQA)      128K     8         128       80      42.95 GB         21.47 GB
Hypothetical 70B MHA   128K     64        128       80      343.6 GB         171.8 GB
Llama 3 405B (GQA)     128K     8         128       126     67.6 GB          33.8 GB
Mistral 7B (GQA)       128K     8         128       32      17.2 GB          8.6 GB
Note: GQA (Grouped Query Attention) reduces KV heads from n_heads to n_kv_heads; the 70B Llama models use 8 KV heads instead of 64, cutting the KV cache 8x relative to full MHA. Llama 3 70B shares the same KV configuration (80 layers, 8 KV heads) and hence the same numbers as Llama 2 70B. KV cache formula: 2 * layers * kv_heads * head_dim * seq_len * bytes_per_element, with 128K = 131,072 tokens.

The critical observation: GQA is the single most impactful architectural change for long-context serving. With full MHA (one KV head per query head, 64 heads), a 70B-class model at 128K would need ~344 GB of KV cache per request; grouped-query attention with 8 KV heads cuts that 8x, to ~43 GB. But even 43 GB per request becomes 86 GB at batch size 2 and 344 GB at batch size 8. The memory problem does not disappear with GQA — it reappears as soon as you batch.

The Memory Equation for Multi-Tenant Serving

For a serving system handling B concurrent requests at maximum context length S:

M_{\text{total}} = M_{\text{weights}} + B \times M_{\text{KV}}(S) + M_{\text{activations}}(B) + M_{\text{overhead}}

where M_KV(S) = 2 × L × n_kv × d × S × dtype_size, and M_activations scales with batch size and the peak intermediate tensor size (typically the attention score matrix during prefill, which is O(S^2) per head per layer but computed blockwise with FlashAttention).

For Llama 3 70B in INT4 weights, FP16 KV, batch size 2, 128K context on 2x H100 (160 GB total):

M = 35 \text{ GB} + 2 \times 43 \text{ GB} + {\sim}4 \text{ GB} + {\sim}2 \text{ GB} = 35 + 86 + 6 = 127 \text{ GB}

That fits in 160 GB. At batch size 3 the KV alone is 129 GB and the total, roughly 170 GB, exceeds our 2-GPU budget; dropping the KV cache to FP8 roughly doubles the feasible batch. And this assumes every request simultaneously fills 128K tokens. In practice, most requests use far less, which is where paged attention and dynamic memory allocation help. But the peak case must be handled, and that is where offloading enters.
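A quick way to sanity-check these budgets is a throwaway helper (the function and its default constants are ours, for illustration; they just encode the equation above):

def total_memory_gb(batch, kv_per_request_gb=43.0, weights_gb=35.0,
                    activations_gb=4.0, overhead_gb=2.0):
    """Peak memory if every request in the batch fills the full 128K context."""
    return weights_gb + batch * kv_per_request_gb + activations_gb + overhead_gb

for b in (1, 2, 3, 4):
    status = "fits" if total_memory_gb(b) <= 160 else "exceeds 2x H100"
    print(f"batch {b}: {total_memory_gb(b):.0f} GB ({status})")
# batch 1: 84 GB (fits)      batch 2: 127 GB (fits)
# batch 3: 170 GB (exceeds)  batch 4: 213 GB (exceeds)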


2. KV Cache Offloading: GPU to CPU DRAM

The core idea is simple: not all KV cache entries are equally “hot.” During autoregressive decoding, the attention mechanism accesses the full KV cache for each new token — but the attention scores concentrate on recent tokens and a sparse set of “important” tokens from the distant past (attention sinks, topic-relevant tokens). Older KV blocks that receive low attention scores can be moved to CPU DRAM and restored on demand.

Block-Level Offloading Architecture

KV cache offloading operates at the block level, borrowing from vLLM’s PagedAttention abstraction. The KV cache is divided into blocks of B_S tokens (typically 16-256 tokens per block). Each block stores the K and V tensors for all layers:

\text{block\_size} = 2 \times L \times n_{\text{kv}} \times d \times B_S \times \text{dtype\_size}

For Llama 3 70B with GQA (8 KV heads, 128 dim, 80 layers) at block size 256, FP16:

\text{block\_size} = 2 \times 80 \times 8 \times 128 \times 256 \times 2 = 83,886,080 \text{ bytes} \approx 84 \text{ MB}

At block size 32:

\text{block\_size} = 2 \times 80 \times 8 \times 128 \times 32 \times 2 = 10,485,760 \text{ bytes} \approx 10 \text{ MB}

Smaller blocks give finer-grained eviction control but increase the bookkeeping overhead (more blocks to track, more individual transfers).

import torch
from dataclasses import dataclass
from collections import OrderedDict
from typing import Optional

@dataclass
class KVBlock:
    """A block of KV cache entries for all layers."""
    block_id: int
    seq_id: int
    start_pos: int       # Token position within the sequence
    num_tokens: int       # Actual tokens stored (may be less than block_size)
    gpu_tensor: Optional[torch.Tensor] = None  # [2, layers, kv_heads, num_tokens, head_dim]
    cpu_tensor: Optional[torch.Tensor] = None  # Pinned CPU copy
    on_gpu: bool = True
    last_access_step: int = 0
    access_count: int = 0
    importance_score: float = 0.0

class KVOffloadManager:
    """Manages KV cache blocks across GPU HBM and CPU DRAM.

    Keeps a sliding window of W recent blocks on GPU.
    Older blocks are offloaded to pinned CPU memory.
    Blocks are restored to GPU on demand via async CUDA streams.
    """

    def __init__(
        self,
        num_layers: int,
        num_kv_heads: int,
        head_dim: int,
        block_size: int = 32,
        gpu_budget_blocks: int = 256,    # Max blocks on GPU
        dtype: torch.dtype = torch.float16,
        device: str = "cuda:0",
    ):
        self.num_layers = num_layers
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.block_size = block_size
        self.gpu_budget = gpu_budget_blocks
        self.dtype = dtype
        self.device = device

        # Block shape: [2, layers, kv_heads, block_size, head_dim]
        #   2 = K and V
        self.block_shape = (2, num_layers, num_kv_heads, block_size, head_dim)
        # Bytes per block = element count * element size (works for any dtype)
        elem_bytes = torch.empty(0, dtype=dtype).element_size()
        self.block_bytes = (
            2 * num_layers * num_kv_heads * block_size * head_dim * elem_bytes
        )

        # GPU block pool (pre-allocated)
        self.gpu_pool = torch.zeros(
            gpu_budget_blocks, *self.block_shape,
            dtype=dtype, device=device
        )
        self.gpu_free_slots = list(range(gpu_budget_blocks))
        self.gpu_slot_map = {}  # block_id -> gpu_slot_index

        # CPU pinned memory pool
        self.cpu_pool = {}  # block_id -> pinned tensor

        # Block registry
        self.blocks = {}  # block_id -> KVBlock
        self.next_block_id = 0

        # Async transfer streams
        self.h2d_stream = torch.cuda.Stream(device=device)  # CPU->GPU
        self.d2h_stream = torch.cuda.Stream(device=device)  # GPU->CPU

        # LRU tracking for eviction
        self.gpu_lru = OrderedDict()  # block_id -> access_step
        self.current_step = 0

    def allocate_block(self, seq_id, start_pos, num_tokens):
        """Allocate a new KV block, evicting from GPU if necessary."""
        block_id = self.next_block_id
        self.next_block_id += 1

        # Get a GPU slot (evict if needed)
        gpu_slot = self._get_gpu_slot()

        block = KVBlock(
            block_id=block_id,
            seq_id=seq_id,
            start_pos=start_pos,
            num_tokens=num_tokens,
            on_gpu=True,
            last_access_step=self.current_step,
        )
        block.gpu_tensor = self.gpu_pool[gpu_slot]

        self.blocks[block_id] = block
        self.gpu_slot_map[block_id] = gpu_slot
        self.gpu_lru[block_id] = self.current_step

        return block

    def _get_gpu_slot(self):
        """Get a free GPU slot, evicting the LRU block if needed."""
        if self.gpu_free_slots:
            return self.gpu_free_slots.pop()

        # Evict the least-recently-used block to CPU.
        # next(iter(...)) on an OrderedDict yields the oldest *key*.
        evict_id = next(iter(self.gpu_lru))
        self._offload_to_cpu(evict_id)
        return self.gpu_free_slots.pop()

    def _offload_to_cpu(self, block_id):
        """Move a block from GPU to CPU pinned memory (async)."""
        block = self.blocks[block_id]
        gpu_slot = self.gpu_slot_map[block_id]

        # Allocate pinned CPU tensor if not exists
        if block_id not in self.cpu_pool:
            self.cpu_pool[block_id] = torch.empty(
                self.block_shape, dtype=self.dtype,
                pin_memory=True
            )

        # Async copy GPU -> CPU
        with torch.cuda.stream(self.d2h_stream):
            self.cpu_pool[block_id].copy_(self.gpu_pool[gpu_slot], non_blocking=True)

        # Block until the copy completes before handing the slot back.
        # (A fully async design would record an event here and defer
        # slot reuse until the event fires.)
        self.d2h_stream.synchronize()

        # Free GPU slot
        block.on_gpu = False
        block.gpu_tensor = None
        del self.gpu_slot_map[block_id]
        del self.gpu_lru[block_id]
        self.gpu_free_slots.append(gpu_slot)

    def restore_to_gpu(self, block_id):
        """Restore a block from CPU to GPU (async). Returns the GPU tensor."""
        block = self.blocks[block_id]
        if block.on_gpu:
            # Already on GPU, just update LRU
            self.gpu_lru.move_to_end(block_id)
            gpu_slot = self.gpu_slot_map[block_id]
            return self.gpu_pool[gpu_slot]

        # Need to restore from CPU
        gpu_slot = self._get_gpu_slot()

        with torch.cuda.stream(self.h2d_stream):
            self.gpu_pool[gpu_slot].copy_(self.cpu_pool[block_id], non_blocking=True)

        # Must synchronize before the compute stream uses this data
        event = self.h2d_stream.record_event()
        torch.cuda.current_stream().wait_event(event)

        block.on_gpu = True
        block.gpu_tensor = self.gpu_pool[gpu_slot]
        self.gpu_slot_map[block_id] = gpu_slot
        self.gpu_lru[block_id] = self.current_step

        return self.gpu_pool[gpu_slot]

    def get_kv_for_attention(self, seq_id, required_positions):
        """Gather all KV blocks for a sequence, restoring from CPU as needed.

        Returns a contiguous KV tensor on GPU: [2, layers, kv_heads, total_tokens, head_dim]
        """
        self.current_step += 1

        seq_blocks = sorted(
            [b for b in self.blocks.values() if b.seq_id == seq_id],
            key=lambda b: b.start_pos
        )

        gpu_tensors = []
        for block in seq_blocks:
            if block.start_pos + block.num_tokens <= required_positions[0]:
                continue  # Before our window
            if block.start_pos > required_positions[-1]:
                break

            tensor = self.restore_to_gpu(block.block_id)
            # Slice to actual token count
            gpu_tensors.append(tensor[:, :, :, :block.num_tokens, :])

            block.last_access_step = self.current_step
            block.access_count += 1
            if block.block_id in self.gpu_lru:
                self.gpu_lru.move_to_end(block.block_id)

        if not gpu_tensors:
            return None

        return torch.cat(gpu_tensors, dim=3)  # Concatenate along token dimension

Transfer Latency Analysis

The critical performance question: how fast can we restore a KV block from CPU to GPU? The transfer uses PCIe or NVLink CPU-GPU interconnect. On a standard H100 SXM system:

  • PCIe Gen5 x16: ~64 GB/s per direction theoretical; assume 32 GB/s effective for pinned-host DMA transfers
  • Block size (Llama 3 70B, GQA, 32 tokens): 10 MB
  • Transfer time: 10 MB / 32 GB/s = 312.5 us

With block size 256 tokens (~84 MB per block):

  • Transfer time: 84 MB / 32 GB/s ≈ 2.6 ms
📊

KV Block Transfer Latency (PCIe Gen5)

Block Size (tokens)  Block Size (MB)  Transfer Time  Tokens Decoded During Transfer  Overhead per Token
16                   5.24             164 us         0.002                           0.2%
32                   10.49            328 us         0.003                           0.3%
64                   20.97            655 us         0.007                           0.7%
128                  41.94            1.31 ms        0.013                           1.3%
256                  83.89            2.62 ms        0.026                           2.6%
Note: Transfer time assumes 32 GB/s effective PCIe Gen5 bandwidth. Tokens decoded during transfer assumes 100 ms per-token latency (10 tok/s) for a 70B model. Overhead assumes one block restore per generated token.

The latency is manageable for two reasons. First, the transfer is asynchronous: the GPU can compute attention over already-resident blocks while the CPU-to-GPU copy of the next block proceeds on a separate CUDA stream. Second, during decoding, the model generates one token at a time (typically 50-200ms for a 70B model), and the block restore (0.2-2.6ms) is a small fraction of that.
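The table’s numbers fall out of a two-line calculation. A small sketch (constants assume the Llama 3 70B GQA config and the 32 GB/s effective bandwidth used above):

def restore_overhead(block_tokens, kv_bytes_per_token=327_680,
                     pcie_bw=32e9, decode_step_s=0.1):
    """Per-block transfer time and its share of a 100 ms decode step.

    327,680 bytes/token = 2 (K and V) * 80 layers * 8 KV heads * 128 dim * 2 bytes.
    """
    t = block_tokens * kv_bytes_per_token / pcie_bw
    return t * 1e6, 100.0 * t / decode_step_s

for bs in (16, 32, 64, 128, 256):
    us, pct = restore_overhead(bs)
    print(f"{bs:4d} tokens: {us:6.0f} us  ({pct:.1f}% of a decode step)")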

Eviction Policies

LRU (Least Recently Used) is the baseline eviction policy, but it is not optimal for attention patterns. Better policies consider the attention score distribution:

class AttentionAwareEvictionPolicy:
    """Eviction policy that considers attention patterns.

    Keeps blocks with high cumulative attention scores on GPU,
    even if they have not been accessed recently.
    This handles 'attention sinks' -- early tokens that receive
    disproportionate attention throughout generation.
    """

    def __init__(self, window_size, sink_tokens=4):
        self.window_size = window_size  # Always keep last W blocks
        self.sink_tokens = sink_tokens  # Always keep first S tokens
        self.attention_accumulator = {}  # block_id -> cumulative score
        self.decay = 0.95  # Exponential decay for old scores

    def update_scores(self, block_ids, attention_scores):
        """Update cumulative attention scores after each decode step.

        attention_scores: average attention weight each block received
                          from the new query token, averaged across
                          all heads and layers.
        """
        # Decay existing scores
        for bid in self.attention_accumulator:
            self.attention_accumulator[bid] *= self.decay

        # Add new scores
        for bid, score in zip(block_ids, attention_scores):
            if bid not in self.attention_accumulator:
                self.attention_accumulator[bid] = 0.0
            self.attention_accumulator[bid] += score

    def select_eviction_candidates(self, all_blocks, num_to_evict):
        """Select blocks to evict from GPU.

        Never evicts:
        - Blocks in the sliding window (most recent W blocks)
        - Blocks containing attention sink tokens (first S tokens)

        Among the rest, evicts blocks with lowest cumulative
        attention score.
        """
        # Sort by position
        sorted_blocks = sorted(all_blocks, key=lambda b: b.start_pos)

        # Protected: sink blocks and window blocks
        sink_blocks = set()
        window_blocks = set()

        # A block is a sink block if it overlaps the first `sink_tokens` positions
        pos = 0
        for block in sorted_blocks:
            if pos < self.sink_tokens:
                sink_blocks.add(block.block_id)
            pos += block.num_tokens

        # Last W blocks are window
        for block in sorted_blocks[-self.window_size:]:
            window_blocks.add(block.block_id)

        protected = sink_blocks | window_blocks

        # Eviction candidates: everything else, sorted by attention score
        candidates = [
            b for b in all_blocks
            if b.block_id not in protected and b.on_gpu
        ]
        candidates.sort(
            key=lambda b: self.attention_accumulator.get(b.block_id, 0.0)
        )

        return [c.block_id for c in candidates[:num_to_evict]]
ℹ️ Attention Sinks

Xiao et al. (2023) showed that the first 1-4 tokens in a sequence receive disproportionately high attention scores regardless of their content. These “attention sink” tokens act as a bias term in the softmax normalization. Evicting their KV cache entries causes severe quality degradation. Any eviction policy must protect these initial tokens unconditionally.

Prefetching Strategy

Naive on-demand restoration creates a bubble: the GPU stalls waiting for the block transfer. Prefetching eliminates this by predicting which blocks will be needed and starting the transfer before the attention kernel needs them.

class KVPrefetcher:
    """Prefetches KV blocks from CPU to GPU based on predicted access patterns.

    For autoregressive decoding, the access pattern is deterministic:
    every decode step accesses all blocks for the active sequences.
    The prefetcher restores blocks in order of expected attention
    score (high-attention blocks first) to minimize quality impact
    if a transfer is still in flight when the kernel launches.
    """

    def __init__(self, offload_manager, lookahead_steps=2):
        self.manager = offload_manager
        self.lookahead = lookahead_steps
        self.pending_transfers = {}  # block_id -> (gpu_slot, cuda_event)

    def prefetch_for_step(self, seq_id, step):
        """Issue prefetch commands for blocks needed in upcoming steps.

        Strategy: For each active sequence, ensure all blocks are
        either on GPU or have an in-flight transfer. Prioritize
        blocks with high historical attention scores.
        """
        seq_blocks = [
            b for b in self.manager.blocks.values()
            if b.seq_id == seq_id and not b.on_gpu
        ]

        # Sort by importance (high attention score first)
        seq_blocks.sort(
            key=lambda b: b.importance_score, reverse=True
        )

        for block in seq_blocks:
            if block.block_id in self.pending_transfers:
                continue  # Already in flight

            if not self.manager.gpu_free_slots:
                break  # No GPU space

            # Start async transfer
            gpu_slot = self.manager._get_gpu_slot()

            with torch.cuda.stream(self.manager.h2d_stream):
                self.manager.gpu_pool[gpu_slot].copy_(
                    self.manager.cpu_pool[block.block_id],
                    non_blocking=True
                )
                event = self.manager.h2d_stream.record_event()

            self.pending_transfers[block.block_id] = (gpu_slot, event)

    def wait_and_finalize(self, block_id):
        """Wait for a prefetched block and register it as GPU-resident."""
        if block_id not in self.pending_transfers:
            return

        gpu_slot, event = self.pending_transfers.pop(block_id)
        event.synchronize()

        block = self.manager.blocks[block_id]
        block.on_gpu = True
        block.gpu_tensor = self.manager.gpu_pool[gpu_slot]
        self.manager.gpu_slot_map[block_id] = gpu_slot
        self.manager.gpu_lru[block_id] = self.manager.current_step

3. Ring Attention for Distributed Long Context

KV offloading handles the case where a single GPU cannot fit the entire KV cache. Ring Attention (Liu et al., 2023) handles the case where the sequence itself is too long to process on a single GPU during prefill. Rather than replicating the full sequence on each device, Ring Attention partitions the sequence across GPUs and implements attention via a ring communication pattern.

The Ring Attention Algorithm

Given P GPUs and a sequence of N tokens, each GPU i holds a contiguous chunk of N/P tokens: queries Q_i, keys K_i, and values V_i.

To compute attention for Q_i, GPU i needs KV from all chunks. Ring Attention circulates the KV blocks around the ring:

Step 0: GPU_i computes attention(Q_i, K_i, V_i)      -- local KV
Step 1: GPU_i receives K_{i-1}, V_{i-1} from left neighbor
        GPU_i computes attention(Q_i, K_{i-1}, V_{i-1})
Step 2: GPU_i receives K_{i-2}, V_{i-2}
        GPU_i computes attention(Q_i, K_{i-2}, V_{i-2})
...
Step P-1: GPU_i has accumulated attention from all chunks

At each step, every GPU simultaneously sends its current KV buffer to the right neighbor and receives a new KV buffer from the left neighbor. The communication overlaps with the attention computation on the previous KV buffer.

import torch
import torch.distributed as dist

def ring_attention_forward(
    Q, K, V,
    rank, world_size,
    process_group,
    causal=True,
    block_size=4096
):
    """Ring Attention: distributed attention across GPUs.

    Each GPU holds a chunk of the sequence.
    Q: [batch, local_seq_len, num_heads, head_dim] -- queries for this chunk
    K: [batch, local_seq_len, num_kv_heads, head_dim] -- keys for this chunk
    V: [batch, local_seq_len, num_kv_heads, head_dim] -- values for this chunk

    Returns: attention output for this GPU's query chunk.
    """
    B, S_local, H, D = Q.shape
    _, _, H_kv, _ = K.shape

    # Initialize accumulators for online softmax
    # O_i = sum of softmax(scores) * V, unnormalized
    # l_i = sum of exp(scores - max_score), for normalization
    # m_i = running max of scores, for numerical stability
    O = torch.zeros_like(Q)          # [B, S_local, H, D]
    l = torch.zeros(B, S_local, H, 1, device=Q.device, dtype=Q.dtype)
    m = torch.full((B, S_local, H, 1), float('-inf'), device=Q.device, dtype=Q.dtype)

    # Double-buffered KV for overlapping communication and compute
    kv_send = torch.cat([K, V], dim=2)  # [B, S_local, 2*H_kv, D]
    kv_recv = torch.empty_like(kv_send)

    for step in range(world_size):
        # Source of the KV at this step
        src_rank = (rank - step) % world_size

        # Start async ring communication (send right, recv from left)
        if step < world_size - 1:
            send_rank = (rank + 1) % world_size
            recv_rank = (rank - 1) % world_size
            send_op = dist.isend(kv_send, dst=send_rank, group=process_group)
            recv_op = dist.irecv(kv_recv, src=recv_rank, group=process_group)

        # Extract K, V from the current buffer
        K_step = kv_send[:, :, :H_kv, :]
        V_step = kv_send[:, :, H_kv:, :]

        # Causal masking: GPU i's queries at positions [i*S_local, (i+1)*S_local)
        # can only attend to KV at positions up to their own position.
        # KV from src_rank covers positions [src_rank*S_local, (src_rank+1)*S_local)
        if causal and src_rank > rank:
            # This KV chunk is entirely in the future -- skip the compute
            # (but still take part in the ring exchange below)
            pass
        else:
            # Compute blockwise attention with online softmax update.
            # GQA: each KV head serves H // H_kv query heads, so expand with
            # repeat_interleave (plain repeat() would interleave heads wrongly).
            K_rep = torch.repeat_interleave(K_step, H // H_kv, dim=2)
            V_rep = torch.repeat_interleave(V_step, H // H_kv, dim=2)
            # Scores: [B, S_local, H, S_local_kv]
            scores = torch.einsum('bshd,bthd->bsht', Q, K_rep)
            scores = scores / (D ** 0.5)

            # Apply causal mask for the diagonal block
            if causal and src_rank == rank:
                # Only this block needs a causal mask
                q_positions = torch.arange(rank * S_local, (rank + 1) * S_local, device=Q.device)
                kv_positions = torch.arange(src_rank * S_local, (src_rank + 1) * S_local, device=Q.device)
                mask = q_positions.unsqueeze(1) < kv_positions.unsqueeze(0)
                scores = scores.masked_fill(mask.unsqueeze(0).unsqueeze(2), float('-inf'))

            # Online softmax accumulation (FlashAttention-2 style)
            m_step = scores.max(dim=-1, keepdim=True).values  # [B, S, H, 1]
            m_new = torch.maximum(m, m_step)

            # Rescale old accumulator
            exp_diff_old = torch.exp(m - m_new)
            # New exponentials
            exp_scores = torch.exp(scores - m_new)

            # Update running sum of exponentials
            l_new = l * exp_diff_old + exp_scores.sum(dim=-1, keepdim=True)

            # Update output accumulator
            O = O * exp_diff_old + torch.einsum('bsht,bthd->bshd', exp_scores, V_rep)

            l = l_new
            m = m_new

        # Wait for communication to finish, then swap buffers
        if step < world_size - 1:
            send_op.wait()
            recv_op.wait()
            kv_send, kv_recv = kv_recv, kv_send

    # Final normalization
    O = O / l

    return O
ℹ️ Online Softmax Is Essential

Ring Attention requires the online softmax trick (Milakov and Gimelshein, 2018; Dao et al., 2022). Without it, you would need to compute the global softmax denominator \sum_j \exp(q \cdot k_j) across all chunks before computing outputs — requiring a global reduction and preventing overlap. With online softmax, each chunk contributes incrementally: update the running max m, rescale the previous accumulator, and add the new contribution. The final normalization at the end produces numerically identical results to standard attention.
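The trick is easy to verify on a toy example. A minimal single-head sketch (pure PyTorch, no distribution) that accumulates softmax over KV chunks and checks the result against full attention:

import torch

def online_softmax_attention(q, k_chunks, v_chunks):
    """Accumulate attention over KV chunks with a running max and denominator."""
    m = torch.full((q.shape[0], 1), float("-inf"))     # running max
    l = torch.zeros(q.shape[0], 1)                     # running sum of exp
    o = torch.zeros(q.shape[0], v_chunks[0].shape[1])  # unnormalized output
    for k, v in zip(k_chunks, v_chunks):
        s = q @ k.T / q.shape[-1] ** 0.5
        m_new = torch.maximum(m, s.max(dim=-1, keepdim=True).values)
        scale = torch.exp(m - m_new)   # rescale old accumulators
        p = torch.exp(s - m_new)
        l = l * scale + p.sum(dim=-1, keepdim=True)
        o = o * scale + p @ v
        m = m_new
    return o / l

q, k, v = torch.randn(4, 64), torch.randn(256, 64), torch.randn(256, 64)
ref = torch.softmax(q @ k.T / 64 ** 0.5, dim=-1) @ v
out = online_softmax_attention(q, k.chunk(4), v.chunk(4))
assert torch.allclose(ref, out, atol=1e-5)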

Communication-Computation Overlap

The ring pattern’s key property is that communication and computation overlap. At each step, while GPU i computes attention(Q_i, KV_current), it simultaneously sends KV_current to GPU i+1 and receives KV_next from GPU i-1. The overlap is effective as long as:

T_{\text{compute}} \geq T_{\text{comm}}

For a local chunk of S/P tokens:

T_{\text{compute}} \approx \frac{2 \times B \times H \times (S/P)^2 \times D}{\text{FLOPS}_{\text{GPU}}}

T_{\text{comm}} \approx \frac{2 \times B \times (S/P) \times H_{\text{kv}} \times D \times \text{dtype\_size}}{\text{BW}_{\text{link}}}

Ring Attention: Compute vs Communication Time

Configuration                       Compute per ring pass  NVLink communication
128K over 4 GPUs (32K tokens/GPU)   42 ms                  3.2 ms (3 ring steps)
512K over 8 GPUs (64K tokens/GPU)   67 ms                  6.4 ms (7 ring steps)
1M over 16 GPUs (64K tokens/GPU)    65 ms                  12.8 ms (15 ring steps)
Note: NVLink at 900 GB/s bidirectional; 8 KV heads per chunk.

With NVLink (900 GB/s), communication time is 5-15x smaller than compute time for typical configurations. The overlap is nearly perfect. On PCIe-only systems (64 GB/s), communication time grows by ~14x, and overlap breaks down beyond 4 GPUs.
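For intuition, here is a rough estimator mirroring the two formulas above (the FLOPS and per-direction link bandwidth defaults are assumed H100-class constants, not measurements):

def ring_step_times(seq_len, num_gpus, batch=1, heads=64, kv_heads=8,
                    head_dim=128, dtype_bytes=2,
                    gpu_flops=1e15, link_bw=450e9):
    """Per-ring-step compute time vs KV transfer time, per the formulas above."""
    chunk = seq_len // num_gpus
    t_compute = 2 * batch * heads * chunk ** 2 * head_dim / gpu_flops
    kv_bytes = 2 * batch * chunk * kv_heads * head_dim * dtype_bytes
    return t_compute, kv_bytes / link_bw

t_c, t_m = ring_step_times(131_072, 4)
print(f"compute {t_c * 1e3:.1f} ms vs comm {t_m * 1e3:.2f} ms per step")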

Handling Causal Masking Efficiently

A subtle problem with Ring Attention: causal masking creates workload imbalance. GPU 0 (holding the earliest tokens) has queries that can only attend to their own chunk. GPU P-1 (holding the latest tokens) has queries that attend to all P chunks. In a naive implementation, GPU 0 does useful work in only 1 ring step while GPU P-1 computes in all P steps.

The Striped Attention optimization (Brandon et al., 2023) interleaves token positions across GPUs instead of assigning contiguous chunks:

Contiguous:  GPU0=[0,1,2,3]  GPU1=[4,5,6,7]  GPU2=[8,9,10,11]  GPU3=[12,13,14,15]
Striped:     GPU0=[0,4,8,12] GPU1=[1,5,9,13]  GPU2=[2,6,10,14]  GPU3=[3,7,11,15]

With striped assignment, every GPU has a mix of early and late tokens, and the causal-mask workload is roughly equal across GPUs. The workload imbalance drops from O(P) to O(1).

def stripe_sequence(tokens, world_size):
    """Distribute tokens across GPUs in a striped pattern.

    Instead of contiguous chunks [0..N/P], [N/P..2N/P], ...,
    assign token i to GPU (i % world_size).
    This balances causal mask workload across GPUs.
    """
    N = len(tokens)
    local_indices = {}
    for rank in range(world_size):
        local_indices[rank] = list(range(rank, N, world_size))
    return local_indices

def unstripe_output(local_outputs, local_indices, total_length):
    """Reassemble the full output from striped local results."""
    full_output = torch.empty(total_length, *local_outputs[0].shape[1:],
                               device=local_outputs[0].device,
                               dtype=local_outputs[0].dtype)
    for rank, indices in local_indices.items():
        full_output[indices] = local_outputs[rank]
    return full_output
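To see the rebalancing concretely, count the query-key pairs each GPU evaluates under causal masking for the two layouts (16 tokens, 4 GPUs, reusing stripe_sequence from above):

def causal_workload(assignment):
    """Query-key pairs per GPU: the token at position q attends to positions 0..q."""
    return {rank: sum(q + 1 for q in idx) for rank, idx in assignment.items()}

N, P = 16, 4
contiguous = {r: list(range(r * N // P, (r + 1) * N // P)) for r in range(P)}
striped = stripe_sequence(list(range(N)), P)
print(causal_workload(contiguous))  # {0: 10, 1: 26, 2: 42, 3: 58} -- ~6x spread
print(causal_workload(striped))     # {0: 28, 1: 32, 2: 36, 3: 40} -- near-uniform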

4. Chunked Prefill Processing

The third technique addresses a different bottleneck: processing the initial prompt (prefill phase) for very long contexts. A 128K-token prompt processed in a single forward pass creates an intermediate attention score matrix of size 128K × 128K per head — which, even with FlashAttention’s tiling, requires massive compute. More critically, the prefill blocks the GPU for the entire duration, preventing any decode-phase tokens from being generated for other requests.

Chunked prefill splits the long prompt into chunks of C tokens (typically 4K-8K) and processes them sequentially. Each chunk attends to all previous chunks via the already-computed KV cache, plus its own tokens via causal self-attention.

The Chunked Prefill Algorithm

class ChunkedPrefillEngine:
    """Process long prompts in chunks to bound memory and enable scheduling.

    Key insight: attention is decomposable. Processing tokens [0, C) first,
    then [C, 2C) with the KV cache from [0, C), produces identical results
    to processing [0, 2C) at once. This is because:

    - For positions in [C, 2C): they attend to [0, C) via cached KV and to
      [C, 2C) via causal self-attention within the chunk. The attention
      scores and outputs are numerically identical to full-sequence attention.

    - For positions in [0, C): they were already fully processed in chunk 0.
      Their KV entries are correct and final.
    """

    def __init__(self, model, chunk_size=4096, kv_cache=None):
        self.model = model
        self.chunk_size = chunk_size
        self.kv_cache = kv_cache  # Shared KV cache (e.g., vLLM PagedAttention)

    def prefill(self, input_ids, seq_id):
        """Process a long prompt in chunks.

        input_ids: [1, total_seq_len] -- the full prompt
        Returns: logits for the last token (for generation)
        """
        total_len = input_ids.shape[1]
        num_chunks = (total_len + self.chunk_size - 1) // self.chunk_size

        all_logits = None

        for chunk_idx in range(num_chunks):
            start = chunk_idx * self.chunk_size
            end = min(start + self.chunk_size, total_len)
            chunk_ids = input_ids[:, start:end]

            # The model attends to:
            # 1. All previous KV cache entries (positions 0..start-1)
            # 2. Current chunk tokens (positions start..end-1) with causal mask

            # Position offsets so the model knows where these tokens sit
            # in the full sequence
            position_ids = torch.arange(start, end, device=chunk_ids.device).unsqueeze(0)

            outputs = self.model(
                input_ids=chunk_ids,
                position_ids=position_ids,
                past_key_values=self.kv_cache.get_cache(seq_id),
                use_cache=True,
            )

            # The model appends this chunk's KV to the cache automatically
            # Now the cache covers positions 0..end-1

            all_logits = outputs.logits

        # Return logits for the last token only (used for generation)
        return all_logits[:, -1:, :]

    def prefill_with_scheduling(self, input_ids, seq_id, decode_queue):
        """Chunked prefill interleaved with decode steps for other requests.

        After each prefill chunk, check if any decode-phase requests
        are waiting. If so, batch their decode steps together and
        process them before continuing with the next prefill chunk.

        This prevents a single long-context prefill from starving
        all decode-phase requests (which are latency-sensitive).
        """
        total_len = input_ids.shape[1]
        num_chunks = (total_len + self.chunk_size - 1) // self.chunk_size

        for chunk_idx in range(num_chunks):
            start = chunk_idx * self.chunk_size
            end = min(start + self.chunk_size, total_len)
            chunk_ids = input_ids[:, start:end]
            position_ids = torch.arange(start, end, device=chunk_ids.device).unsqueeze(0)

            # Process this prefill chunk
            self.model(
                input_ids=chunk_ids,
                position_ids=position_ids,
                past_key_values=self.kv_cache.get_cache(seq_id),
                use_cache=True,
            )

            # Yield to decode requests between chunks
            if decode_queue.has_pending():
                decode_batch = decode_queue.get_batch()
                self._run_decode_step(decode_batch)

    def _run_decode_step(self, decode_batch):
        """Run one decode step for a batch of requests."""
        input_ids = torch.stack([r.last_token for r in decode_batch])
        position_ids = torch.stack([r.current_pos for r in decode_batch])

        outputs = self.model(
            input_ids=input_ids,
            position_ids=position_ids,
            past_key_values=[self.kv_cache.get_cache(r.seq_id) for r in decode_batch],
            use_cache=True,
        )

        # Sample next tokens and update decode state
        for i, req in enumerate(decode_batch):
            logits = outputs.logits[i, -1, :]
            next_token = torch.argmax(logits)
            req.append_token(next_token)

Chunk Size Selection

The chunk size C controls the tradeoff between prefill throughput and decode latency:

📊

Chunk Size Impact on Prefill and Decode

Chunk Size     Prefill Throughput  Decode Interruption  Peak Activation Memory  Use Case
1024           65% of max          Every 8 ms           0.5 GB                  Latency-critical (chatbots)
4096           88% of max          Every 32 ms          2 GB                    Balanced (default)
8192           94% of max          Every 65 ms          4 GB                    Throughput-focused
16384          97% of max          Every 130 ms         8 GB                    Batch processing
Full sequence  100%                Never (blocked)      O(S) GB                 Offline only
Note: Measured on H100, Llama 3 70B. Prefill throughput as percentage of unchunked processing. Decode interruption = time between decode scheduling opportunities.

The sweet spot for interactive serving is 4096-8192. At 4K chunks, the overhead versus unchunked processing is ~12% (due to repeated KV cache lookups and kernel launch overhead per chunk), but decode requests can be serviced every 32ms, keeping time-to-first-token (TTFT) for concurrent requests reasonable.

⚠️ Correctness Guarantee

Chunked prefill produces numerically identical results to full-sequence prefill. This is not an approximation. Each chunk computes exact causal attention over its tokens plus all previously cached KV. The decomposition works because attention is a function of (Q, K, V) pairs and the causal mask only depends on position indices, both of which are correctly supplied per-chunk via position_ids and the growing KV cache.
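The claim is easy to check on a toy single-head attention (a self-contained sketch; production engines make the same comparison at the kernel level):

import torch

def full_causal_attention(q, k, v):
    """Reference: one pass over the whole sequence with a causal mask."""
    T = q.shape[0]
    s = q @ k.T / q.shape[-1] ** 0.5
    s = s.masked_fill(torch.triu(torch.ones(T, T, dtype=torch.bool), 1), float("-inf"))
    return torch.softmax(s, dim=-1) @ v

def chunked_prefill_attention(q, k, v, chunk=16):
    """Process chunks sequentially against a growing KV 'cache' (k[:end], v[:end])."""
    outs = []
    for start in range(0, q.shape[0], chunk):
        end = min(start + chunk, q.shape[0])
        s = q[start:end] @ k[:end].T / q.shape[-1] ** 0.5
        pos_q = torch.arange(start, end).unsqueeze(1)
        pos_k = torch.arange(end).unsqueeze(0)
        s = s.masked_fill(pos_k > pos_q, float("-inf"))  # causal: no future keys
        outs.append(torch.softmax(s, dim=-1) @ v[:end])
    return torch.cat(outs)

q, k, v = (torch.randn(64, 32) for _ in range(3))
assert torch.allclose(full_causal_attention(q, k, v),
                      chunked_prefill_attention(q, k, v), atol=1e-5)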


5. Combining Techniques: Production Architecture

A production long-context serving system combines all three techniques. Here is the architecture used by systems like vLLM, SGLang, and TensorRT-LLM for 128K+ serving:

Layer 1: Prefix Caching

Many long-context requests share a common prefix (system prompt, document context). Prefix caching stores the KV cache for shared prefixes and reuses it across requests.

class PrefixCacheManager:
    """Cache KV for common prefixes across requests.

    Hash-based lookup: hash the token sequence to find cached KV.
    Supports hierarchical matching -- if the full prefix is not
    cached, find the longest cached prefix and compute the rest.
    """

    def __init__(self, block_size=256):
        self.block_size = block_size
        self.cache = {}  # hash -> list of KV block IDs
        self.hash_to_tokens = {}  # hash -> token sequence (for verification)

    def _hash_block(self, prefix_hash, token_ids):
        """Chain hash: a block's KV depends on every token before it,
        so the cache key must cover the whole prefix, not just this block."""
        return hash((prefix_hash, tuple(token_ids.tolist())))

    def lookup(self, input_ids):
        """Find the longest cached prefix for this input.

        Returns: (num_cached_tokens, kv_block_ids)
        """
        total_len = input_ids.shape[1]
        cached_blocks = []
        cached_tokens = 0
        prefix_hash = 0

        for start in range(0, total_len, self.block_size):
            end = min(start + self.block_size, total_len)
            block_tokens = input_ids[0, start:end]
            prefix_hash = self._hash_block(prefix_hash, block_tokens)

            if prefix_hash in self.cache:
                # Verify block tokens (hash collision check; the chained
                # key already pins the preceding prefix)
                if torch.equal(self.hash_to_tokens[prefix_hash], block_tokens):
                    cached_blocks.extend(self.cache[prefix_hash])
                    cached_tokens = end
                else:
                    break  # Hash collision, stop matching
            else:
                break  # No more cached prefix

        return cached_tokens, cached_blocks

    def store(self, input_ids, kv_block_ids):
        """Store KV blocks for this prefix in the cache."""
        total_len = input_ids.shape[1]
        block_idx = 0
        prefix_hash = 0

        for start in range(0, total_len, self.block_size):
            end = min(start + self.block_size, total_len)
            block_tokens = input_ids[0, start:end]
            prefix_hash = self._hash_block(prefix_hash, block_tokens)

            if prefix_hash not in self.cache:
                self.cache[prefix_hash] = [kv_block_ids[block_idx]]
                self.hash_to_tokens[prefix_hash] = block_tokens.clone()

            block_idx += 1

Layer 2: Integrated Serving Pipeline

class LongContextServingPipeline:
    """Production pipeline combining prefix caching, chunked prefill,
    KV offloading, and (optionally) Ring Attention.

    Request flow:
    1. Check prefix cache for shared prefix hit
    2. Chunked prefill for remaining tokens
    3. During decode, offload cold KV blocks to CPU
    4. For sequences exceeding single-GPU KV budget, use Ring Attention
    """

    def __init__(
        self,
        model,
        chunk_size=4096,
        gpu_kv_budget_gb=40,
        enable_offloading=True,
        enable_ring_attention=False,
        world_size=1,
    ):
        self.model = model
        self.chunk_size = chunk_size
        self.prefix_cache = PrefixCacheManager(block_size=256)
        self.chunked_prefill = ChunkedPrefillEngine(model, chunk_size)

        if enable_offloading:
            # Calculate max blocks that fit in GPU budget
            block_bytes = model.config.kv_block_bytes(block_size=32)
            max_gpu_blocks = int(gpu_kv_budget_gb * 1e9 / block_bytes)
            self.offloader = KVOffloadManager(
                num_layers=model.config.num_hidden_layers,
                num_kv_heads=model.config.num_key_value_heads,
                head_dim=model.config.hidden_size // model.config.num_attention_heads,
                gpu_budget_blocks=max_gpu_blocks,
            )
        else:
            self.offloader = None

        self.ring_attention = enable_ring_attention
        self.world_size = world_size

    def serve_request(self, input_ids, max_new_tokens=4096):
        """Full serving pipeline for a single request."""
        seq_id = self._allocate_sequence()

        # Step 1: Prefix cache lookup
        cached_tokens, cached_blocks = self.prefix_cache.lookup(input_ids)

        if cached_tokens > 0:
            # Reuse cached KV blocks
            self._load_cached_kv(seq_id, cached_blocks)
            remaining_ids = input_ids[:, cached_tokens:]
        else:
            remaining_ids = input_ids

        # Step 2: Chunked prefill for uncached tokens
        if remaining_ids.shape[1] > 0:
            logits = self.chunked_prefill.prefill(remaining_ids, seq_id)

        # Step 3: Store prefix in cache for future requests
        kv_blocks = self._get_kv_block_ids(seq_id)
        self.prefix_cache.store(input_ids, kv_blocks)

        # Step 4: Autoregressive decode with KV offloading
        generated_tokens = []
        for step in range(max_new_tokens):
            # Offload cold blocks if needed
            if self.offloader:
                self.offloader.current_step = step
                self._manage_kv_residency(seq_id)

            # Decode one token
            logits = self._decode_step(seq_id)
            next_token = torch.argmax(logits, dim=-1)
            generated_tokens.append(next_token.item())

            if next_token.item() == self.model.config.eos_token_id:
                break

        return generated_tokens

    def _manage_kv_residency(self, seq_id):
        """Decide which KV blocks stay on GPU and which get offloaded."""
        total_blocks = self._count_kv_blocks(seq_id)
        gpu_blocks = self._count_gpu_resident_blocks(seq_id)

        if gpu_blocks > self.offloader.gpu_budget * 0.9:
            # Approaching the GPU budget -- evict cold blocks.
            # Assumes the offload manager is composed with the
            # AttentionAwareEvictionPolicy defined earlier.
            num_evict = gpu_blocks - int(self.offloader.gpu_budget * 0.75)
            candidates = self.offloader.select_eviction_candidates(
                self._get_blocks(seq_id), num_evict
            )
            for block_id in candidates:
                self.offloader._offload_to_cpu(block_id)

    def _allocate_sequence(self):
        pass  # Implementation depends on the KV cache backend

    def _load_cached_kv(self, seq_id, block_ids):
        pass

    def _get_kv_block_ids(self, seq_id):
        pass

    def _decode_step(self, seq_id):
        pass

    def _count_kv_blocks(self, seq_id):
        pass

    def _count_gpu_resident_blocks(self, seq_id):
        pass

    def _get_blocks(self, seq_id):
        pass

End-to-End Performance

128K Context Serving: Throughput by Configuration

Configuration                                   Throughput    Speedup vs naive
Naive (no optimization; OOM at batch > 1)       12 tokens/s   --
Chunked prefill only (batch 1, 4K chunks)       45 tokens/s   3.8x
+ KV offloading (batch 8, cold blocks to CPU)   180 tokens/s  15x
+ Prefix caching (80% prefix hit rate)          310 tokens/s  26x
+ Ring Attention (4 GPUs, sequence parallel)    520 tokens/s  43x
📊

Production Configuration for 128K Serving (Llama 3 70B)

Component           Setting                          Rationale
Weight precision    INT4 (AWQ)                       35 GB weights, fits a single GPU
KV cache precision  FP8 E4M3                         Halves KV memory vs FP16, negligible quality loss
Chunk size          8192 tokens                      94% prefill efficiency, 65 ms decode gaps
KV offload policy   Attention-aware, 4 sink tokens   Protects attention sinks, evicts cold middle context
GPU KV budget       75% of free HBM                  25% headroom for activation spikes
Block size          32 tokens                        Fine-grained eviction, ~10 MB per block
Prefix cache        256-token blocks, LRU eviction   Shared system prompts across requests
Tensor parallelism  2 GPUs (TP=2)                    Model weights split, KV per-GPU
Note: Configuration for 2x H100 80GB SXM serving Llama 3 70B at 128K context. Supports batch size 8-12 depending on actual context lengths.

6. Measuring Long-Context Quality Under Offloading

KV offloading with eviction is, in theory, lossy: if a block is evicted and not restored before the attention kernel runs, the model produces different outputs. In practice, the eviction policy ensures all needed blocks are restored before compute, making it lossless. But there are failure modes:

class LongContextQualityBenchmark:
    """Benchmark to verify that KV offloading does not degrade quality.

    Tests:
    1. Needle-in-a-haystack: plant a fact at position P in a long context,
       ask about it at the end. Measure recall across positions.
    2. Passkey retrieval: hide a 5-digit passkey at various depths.
    3. Multi-hop reasoning: facts spread across the context, answer
       requires combining 2-3 of them.
    """

    def __init__(self, model_runner, offload_manager):
        self.runner = model_runner
        self.offloader = offload_manager

    def needle_in_haystack(self, context_length, needle_position, num_trials=20):
        """Test recall of a specific fact placed at needle_position."""
        results = []
        for trial in range(num_trials):
            # Generate haystack (irrelevant text)
            haystack = self._generate_haystack(context_length)

            # Insert needle
            needle_fact = f"The special code is {trial * 1000 + 42}."
            needle_query = "What is the special code?"

            # Place needle at specified position
            token_pos = int(needle_position * context_length)
            context = haystack[:token_pos] + [needle_fact] + haystack[token_pos:]

            # Run with offloading
            self.offloader.reset()
            response = self.runner.generate(context + [needle_query])

            # Check if the response contains the correct code
            expected = str(trial * 1000 + 42)
            recall = expected in response
            results.append(recall)

        return sum(results) / len(results)

    def sweep_positions(self, context_length, num_positions=20):
        """Sweep needle position across the full context length."""
        positions = [i / num_positions for i in range(num_positions)]
        recall_map = {}

        for pos in positions:
            recall = self.needle_in_haystack(context_length, pos)
            recall_map[pos] = recall

        return recall_map

    def _generate_haystack(self, num_tokens):
        pass  # Generate filler text tokens

The expected result: with a correctly implemented offloading system (all blocks restored before compute), needle-in-a-haystack recall should be identical to the non-offloaded baseline. Any degradation indicates a bug in the eviction or restore logic, not an inherent limitation of offloading.

🚨 Eviction Ordering Bug

A common implementation bug: evicting blocks that are still needed by in-flight prefetches on a different CUDA stream. The eviction sees the block as “not recently accessed” and offloads it, but a concurrent prefetch was about to use it. Fix: use CUDA events to track block dependencies across streams. Never evict a block that has an unrealized event dependency.
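A minimal version of that guard, expressed in terms of the pending_transfers map from the prefetcher sketch above:

def safe_to_evict(block_id, pending_transfers):
    """Only evict blocks with no unfinished cross-stream transfer.

    event.query() returns True once the recorded copy has completed on
    its stream; False means the block is still in flight and must not
    be evicted or have its GPU slot reused.
    """
    entry = pending_transfers.get(block_id)
    if entry is None:
        return True
    _gpu_slot, event = entry
    return event.query()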


7. Scaling Limits and Future Directions

Each technique has limits:

KV offloading is bounded by PCIe bandwidth. If the model accesses more blocks per step than the interconnect can restore, the GPU stalls. The critical ratio is:

\text{max\_offloaded\_restores\_per\_step} = \frac{T_{\text{decode\_step}} \times \text{BW}_{\text{PCIe}}}{\text{block\_size\_bytes}}

For a 70B decode step of 100 ms and PCIe Gen5 at 32 GB/s: 100 ms × 32 GB/s / 10 MB = 320 blocks. At 32 tokens per block, that is 10,240 tokens’ worth of KV restored per step — more than enough for sparse attention patterns but a ceiling for models that attend uniformly over 128K tokens.
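In code (same assumptions: 100 ms decode step, 32 GB/s effective PCIe, ~10 MB blocks of 32 tokens):

def max_restores_per_step(decode_step_s=0.1, pcie_bw=32e9,
                          block_bytes=10.49e6, tokens_per_block=32):
    """PCIe-bound ceiling on offloaded KV restored per decode step."""
    blocks = int(decode_step_s * pcie_bw / block_bytes)
    return blocks, blocks * tokens_per_block

print(max_restores_per_step())  # (305, 9760) -- near the ~320-block figure above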

Ring Attention is bounded by the number of ring steps: P-1 communications for P GPUs. With NVLink, this is fine up to 8 GPUs. Beyond that (multi-node), the inter-node bandwidth (400 Gbps InfiniBand, ~40 GB/s effective — roughly 20x slower than NVLink) makes the communication-computation overlap break down unless chunks are very large.

Chunked prefill adds fixed overhead per chunk: kernel launch costs (~10us per kernel, hundreds of kernels per chunk = ~1-2ms total), KV cache lookup overhead, and reduced arithmetic intensity for small chunks. Below chunk size 1024, the overhead exceeds 20% and the approach becomes inefficient.

📊

Scaling Limits Summary

Technique        Scaling Limit               Bottleneck              Mitigation
KV offloading    ~10K tokens restored/step   PCIe Gen5 bandwidth     CXL memory (future), FP8 KV
Ring Attention   8 GPUs (single node)        Inter-node bandwidth    Striped attention, async prefetch
Chunked prefill  ~1K-token minimum chunk     Kernel launch overhead  Fused kernels, CUDA graphs per chunk
Prefix caching   Cache capacity              GPU/CPU memory          Hierarchical caching (GPU/CPU/SSD)
Note: Limits are approximate and depend on model size, hardware generation, and serving configuration.

Looking ahead, CXL (Compute Express Link) memory will blur the GPU/CPU memory boundary, providing 128+ GB of device-attached DRAM at bandwidth between CPU DRAM and GPU HBM. This shrinks the KV offloading problem: instead of explicitly managing block transfers, the KV cache lives in CXL memory and is accessed by the GPU at ~100 GB/s — an order of magnitude below HBM3, but roughly 3x faster than the 32 GB/s effective PCIe transfers assumed above.


Exercise: Sizing the KV Budget

As a closing exercise: write a function that computes the KV cache size in bytes for an arbitrary transformer and, given a GPU memory budget and model weight size, determines the maximum batch size achievable at a target context length. It should also report how many KV blocks must be offloaded to CPU when a batch exceeds the GPU budget. A reference sketch:

def compute_kv_budget(
    num_layers, num_kv_heads, head_dim, dtype_bytes,
    context_length, block_size,
    gpu_total_bytes, model_weight_bytes, activation_overhead_bytes
):
    kv_per_token = 2 * num_layers * num_kv_heads * head_dim * dtype_bytes
    kv_per_request = kv_per_token * context_length

    free_for_kv = gpu_total_bytes - model_weight_bytes - activation_overhead_bytes
    max_batch = int(free_for_kv / kv_per_request)

    tokens_per_block = block_size
    blocks_per_request = (context_length + tokens_per_block - 1) // tokens_per_block
    bytes_per_block = kv_per_token * tokens_per_block
    max_gpu_blocks = int(free_for_kv / bytes_per_block)

    return {
        "kv_per_request_gb": kv_per_request / 1e9,
        "max_batch_on_gpu": max_batch,
        "blocks_per_request": blocks_per_request,
        "max_gpu_blocks": max_gpu_blocks,
    }

def offload_plan(total_requests, blocks_per_request, max_gpu_blocks):
    total_blocks = total_requests * blocks_per_request
    offloaded = max(0, total_blocks - max_gpu_blocks)
    return {"total_blocks": total_blocks, "gpu_blocks": min(total_blocks, max_gpu_blocks), "cpu_blocks": offloaded}