The same GPU kernel that makes prefill fast will destroy your decode throughput. Prefill wants massive batch sizes and compute-bound GEMMs that saturate tensor cores. Decode wants minimal latency and memory-bandwidth optimization because you’re loading 140 GB of model weights to produce a single token. Treating them as one workload guarantees you’ll optimize for neither. Production serving systems recognize this split explicitly: chunked prefill to prevent decode stalls, CUDA graphs to eliminate decode overhead, and separate routing strategies because the bottleneck isn’t the same. This post covers the optimization techniques that acknowledge the arithmetic intensity divide.
The Arithmetic Intensity Divide
The key metric is arithmetic intensity: FLOPs per byte of memory accessed. High arithmetic intensity means compute-bound (GPU cores are the bottleneck). Low arithmetic intensity means bandwidth-bound (memory transfer is the bottleneck).
For a transformer layer with hidden dimension d, the attention projection matrices W_Q, W_K, W_V, W_O are each on the order of d x d (K and V are smaller under GQA), and the FFN matrices are W_gate and W_up (d x d_ff) plus W_down (d_ff x d), with d_ff ≈ 8d/3 for SwiGLU. Total weight bytes per layer (FP16): roughly 2 x 12d^2 = 24d^2 bytes.
Prefill with sequence length S and batch size B = 1:
- The input activation matrix is S x d
- Each d x d weight matrix GEMM does 2Sd^2 FLOPs
- Weight loading: 2d^2 bytes (FP16)
- Arithmetic intensity: 2Sd^2 / 2d^2 = S FLOPs/byte
Decode with batch size B, generating 1 token per request:
- The input activation matrix is B x d
- Each d x d weight matrix GEMM does 2Bd^2 FLOPs
- Weight loading: 2d^2 bytes (same weights, regardless of batch size)
- Arithmetic intensity: 2Bd^2 / 2d^2 = B FLOPs/byte
Arithmetic Intensity: Prefill vs Decode
| Phase | Batch Size | Seq Length | Arithmetic Intensity | Bottleneck (H100) |
|---|---|---|---|---|
| Prefill | 1 | 2048 | 2048 FLOPs/byte | Compute-bound |
| Prefill | 1 | 512 | 512 FLOPs/byte | Compute-bound |
| Prefill | 1 | 32 | 32 FLOPs/byte | Borderline |
| Decode | 1 | 1 | 1 FLOP/byte | Bandwidth-bound |
| Decode | 32 | 1 | 32 FLOPs/byte | Borderline |
| Decode | 256 | 1 | 256 FLOPs/byte | Compute-bound |
The H100 has 989 TFLOPS (FP16 tensor core) and 3.35 TB/s HBM bandwidth. The compute-bandwidth balance point is:

989 TFLOPS / 3.35 TB/s ≈ 295 FLOPs/byte
Prefill with any reasonable sequence length exceeds this. Decode at batch size 1 is 295x below this. This is why you cannot optimize both with the same approach.
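To make the divide concrete, here is a small standalone sketch (using the H100 numbers quoted above and d = 8192 as an assumed hidden dimension) that computes the arithmetic intensity of each phase and the hardware balance point:

```python
# Roofline balance point for H100: FP16 peak compute vs HBM bandwidth.
PEAK_FLOPS = 989e12   # FP16 tensor core, FLOPs/s
HBM_BW = 3.35e12      # bytes/s

def arithmetic_intensity(tokens, hidden_dim, dtype_bytes=2):
    """FLOPs per weight byte for one d x d GEMM over `tokens` rows."""
    flops = 2 * tokens * hidden_dim * hidden_dim
    weight_bytes = dtype_bytes * hidden_dim * hidden_dim
    return flops / weight_bytes

balance = PEAK_FLOPS / HBM_BW                   # ~295 FLOPs/byte
prefill_ai = arithmetic_intensity(2048, 8192)   # prefill, S=2048
decode_ai = arithmetic_intensity(1, 8192)       # decode, batch 1

print(f"balance point: {balance:.0f} FLOPs/byte")
print(f"prefill S=2048: {prefill_ai:.0f} FLOPs/byte (compute-bound)")
print(f"decode  B=1:    {decode_ai:.0f} FLOP/byte (bandwidth-bound)")
```

Note that for FP16 the weight bytes cancel against the FLOP factor of 2, so the intensity is just the token count, which is why the table above reads directly in tokens.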
Prefill Optimization: Maximizing Compute Utilization
Prefill processes the full prompt in one forward pass. The input is an S x d matrix where S is the prompt length. Every GEMM operates on this full matrix.
GEMM Tiling for Prefill
For prefill, the GEMMs are large enough to saturate GPU compute. The optimization goal is to maximize tensor core utilization through proper tiling:
```python
import torch
import triton
import triton.language as tl

@triton.jit
def prefill_gemm_kernel(
    A_ptr, B_ptr, C_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    """GEMM kernel optimized for prefill: large M (seq_len), large K (hidden_dim).
    Tile sizes chosen for H100 SM occupancy."""
    pid = tl.program_id(0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    # Row-major tile mapping (production kernels add grouped swizzling
    # here for L2 cache locality)
    pid_m = pid // num_pid_n
    pid_n = pid % num_pid_n
    # Compute tile offsets
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    # Initialize accumulator in FP32
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    # Main loop over K dimension
    for k in range(0, K, BLOCK_K):
        a = tl.load(
            A_ptr + offs_m[:, None] * stride_am + (k + offs_k[None, :]) * stride_ak,
            mask=(offs_m[:, None] < M) & ((k + offs_k[None, :]) < K),
            other=0.0,
        )
        b = tl.load(
            B_ptr + (k + offs_k[:, None]) * stride_bk + offs_n[None, :] * stride_bn,
            mask=((k + offs_k[:, None]) < K) & (offs_n[None, :] < N),
            other=0.0,
        )
        acc += tl.dot(a, b)  # Tensor core HMMA
    # Store result
    c = acc.to(tl.float16)
    tl.store(
        C_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn,
        c,
        mask=(offs_m[:, None] < M) & (offs_n[None, :] < N),
    )
```
For prefill on H100, optimal tile sizes are typically BLOCK_M=128, BLOCK_N=256, BLOCK_K=64 for FP16. This gives 128 x 256 = 32,768 elements per output tile, requiring 128 x 64 + 64 x 256 = 24,576 input elements (48 KB in FP16) per K-step, fitting comfortably in shared memory (up to 228 KB per SM on H100).
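The shared-memory fit can be checked with a few lines of arithmetic (FP16, single-buffered; software pipelining doubles the input footprint):

```python
# Shared-memory footprint of one K-step for the tile sizes above (FP16).
BLOCK_M, BLOCK_N, BLOCK_K = 128, 256, 64
DTYPE_BYTES = 2

a_tile = BLOCK_M * BLOCK_K * DTYPE_BYTES   # 16 KB A tile
b_tile = BLOCK_K * BLOCK_N * DTYPE_BYTES   # 32 KB B tile
total = a_tile + b_tile                    # 48 KB single-buffered
double_buffered = 2 * total                # 96 KB with pipelining

H100_SMEM = 228 * 1024  # max shared memory per SM on H100
print(f"{total // 1024} KB single-buffered, "
      f"{double_buffered // 1024} KB double-buffered, "
      f"limit {H100_SMEM // 1024} KB")
```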
Attention During Prefill
Prefill attention computes the full attention matrix. FlashAttention tiles this computation to keep intermediates in SRAM:
```python
def prefill_attention_analysis(seq_len, num_heads, head_dim, dtype_bytes=2):
    """Analyze prefill attention compute and memory."""
    S, d, H = seq_len, head_dim, num_heads
    # FLOPs: Q @ K^T (2*S*S*d per head) + attn @ V (2*S*S*d per head);
    # softmax FLOPs are negligible by comparison
    qk_flops = 2 * H * S * S * d
    av_flops = 2 * H * S * S * d
    total_flops = qk_flops + av_flops
    # Memory: Q, K, V tensors (no materialized S x S matrix with FlashAttention)
    qkv_bytes = 3 * H * S * d * dtype_bytes
    output_bytes = H * S * d * dtype_bytes
    total_memory = qkv_bytes + output_bytes
    # Without FlashAttention: must also store the S x S attention matrix
    naive_memory = qkv_bytes + H * S * S * dtype_bytes + output_bytes
    return {
        "total_tflops": total_flops / 1e12,
        "flash_memory_gb": total_memory / 1e9,
        "naive_memory_gb": naive_memory / 1e9,
        "arithmetic_intensity": total_flops / total_memory,
    }

# Llama 70B: 80 layers, 64 heads, head_dim=128
# Prefill S=4096, per layer: 2 * (2 * 64 * 4096^2 * 128) ≈ 0.55 TFLOP
# Total attention: 80 * 0.55 ≈ 44 TFLOP
# At 989 TFLOPS peak that is ~44 ms (just attention);
# causal masking roughly halves both numbers
```
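A quick back-of-envelope check of the scale involved (non-causal count; causal masking roughly halves it):

```python
# Prefill attention FLOPs for Llama 70B (80 layers, 64 heads, head_dim=128)
# at S = 4096, counting Q @ K^T and attn @ V.
layers, heads, head_dim, S = 80, 64, 128, 4096

per_layer = 2 * (2 * heads * S * S * head_dim)   # QK^T + AV
total_tflops = layers * per_layer / 1e12

peak_tflops = 989  # H100 FP16 tensor core
attention_ms = total_tflops / peak_tflops * 1000

print(f"{total_tflops:.0f} TFLOP, {attention_ms:.0f} ms at peak")
```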
Decode Optimization: Maximizing Bandwidth Utilization
Decode generates one token per request. The input to each layer is a 1 x d vector (batch=1) or a thin B x d matrix (batched decode). The GEMM is now a matrix-vector product (GEMV) or a thin GEMM.

The problem: for Llama 70B, each layer's weights are approximately 1.75 GB (FP16). At batch size 1, a single 8192 x 8192 projection GEMM does 2d^2 ≈ 0.13 GFLOP but loads 2d^2 = 134 MB of weights. On an H100 (3.35 TB/s bandwidth), loading takes 134 MB / 3.35 TB/s ≈ 40 us. The compute takes 0.13 GFLOP / 989 TFLOPS ≈ 0.14 us. The GPU spends over 99.6% of its time waiting for weight data.
```python
def decode_roofline(hidden_dim, batch_size, num_layers,
                    bw_tbs=3.35, compute_tflops=989):
    """Roofline analysis for the decode phase."""
    d = hidden_dim
    # Weight bytes per layer (Q,K,V,O projections + FFN gate/up/down),
    # assuming GQA with 64 query heads, 8 KV heads, head_dim=128, FP16
    qkv_bytes = (64 + 8 + 8) * 128 * d * 2  # Q + K + V projections
    o_bytes = 64 * 128 * d * 2              # O projection
    ffn_bytes = 3 * d * (d * 8 // 3) * 2    # gate + up + down (SwiGLU, d_ff ~ 8d/3)
    weight_bytes_per_layer = qkv_bytes + o_bytes + ffn_bytes
    # FLOPs per layer: 2 * batch * weight_elements (one MAC per weight per token)
    weight_elements = weight_bytes_per_layer // 2
    flops_per_layer = 2 * batch_size * weight_elements
    # Time per layer
    bandwidth_time = weight_bytes_per_layer / (bw_tbs * 1e12)
    compute_time = flops_per_layer / (compute_tflops * 1e12)
    # Decode is bandwidth-bound when bandwidth_time > compute_time, i.e. when
    # batch_size < compute/bandwidth balance point (~295 for FP16 on H100)
    balance_point = compute_tflops / bw_tbs
    total_time = num_layers * max(bandwidth_time, compute_time)
    tokens_per_sec = batch_size / total_time
    return {
        "weight_bytes_per_layer_mb": weight_bytes_per_layer / 1e6,
        "bandwidth_time_us": bandwidth_time * 1e6,
        "compute_time_us": compute_time * 1e6,
        "bottleneck": "bandwidth" if bandwidth_time > compute_time else "compute",
        "balance_batch_size": balance_point,
        "tokens_per_sec": tokens_per_sec,
        "time_per_token_ms": total_time * 1000 / batch_size,
    }
```
*Figure: decode throughput (tokens/sec) vs batch size 1-512 for Llama 70B on H100, comparing the bandwidth model, actual measured throughput, and the compute ceiling.*
At batch size 1, decode throughput is approximately 24 tokens/sec on an H100 for Llama 70B FP16 (140 GB of weights / 3.35 TB/s ≈ 42 ms per token). The theoretical compute ceiling is roughly 7,000 tokens/sec (989 TFLOPS / ~140 GFLOPs per token). The ~295x gap is entirely due to memory bandwidth. Increasing batch size is the primary lever: at batch=256, actual throughput reaches 3,200 tokens/sec (about 45% of the compute ceiling), because weights are loaded once and reused across 256 requests.
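Both bounds follow directly from the model size and the H100 specs; a sketch (real kernels fall short of both peaks):

```python
# Decode throughput bounds for Llama 70B FP16 on one H100.
PARAMS = 70e9
WEIGHT_BYTES = PARAMS * 2          # 140 GB in FP16
HBM_BW = 3.35e12                   # bytes/s
PEAK_FLOPS = 989e12                # FP16 FLOPs/s

# Bandwidth bound (batch 1): every token re-reads all weights.
bandwidth_tokens_per_sec = HBM_BW / WEIGHT_BYTES        # ~24 tok/s

# Compute ceiling: ~2 FLOPs per parameter per token.
compute_tokens_per_sec = PEAK_FLOPS / (2 * PARAMS)      # ~7,000 tok/s

print(f"bandwidth-bound: {bandwidth_tokens_per_sec:.0f} tok/s")
print(f"compute ceiling: {compute_tokens_per_sec:.0f} tok/s")
```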
CUDA Graphs for Decode
Every decode step executes the exact same sequence of CUDA kernels with the same tensor shapes (batch_size x hidden_dim). CUDA graphs eliminate the CPU-side kernel launch overhead by recording the kernel sequence once and replaying it:
```python
import torch

class CUDAGraphDecoder:
    """Capture and replay the decode step as a CUDA graph."""

    def __init__(self, model, max_batch_size, device="cuda:0"):
        self.model = model
        self.device = device
        self.graphs = {}         # batch_size -> captured graph
        self.static_inputs = {}  # batch_size -> static input tensors

    def capture(self, batch_size):
        """Capture the decode forward pass for a specific batch size."""
        # Allocate static tensors (CUDA graphs require fixed addresses)
        static_input_ids = torch.zeros(
            batch_size, 1, dtype=torch.long, device=self.device
        )
        static_position_ids = torch.zeros(
            batch_size, 1, dtype=torch.long, device=self.device
        )
        static_output = torch.zeros(
            batch_size, 1, self.model.config.vocab_size,
            dtype=torch.float16, device=self.device,
        )
        self.static_inputs[batch_size] = {
            "input_ids": static_input_ids,
            "position_ids": static_position_ids,
            "output": static_output,
        }
        # Warmup (required so lazy initialization happens outside the graph)
        for _ in range(3):
            with torch.no_grad():
                _ = self.model(
                    input_ids=static_input_ids,
                    position_ids=static_position_ids,
                )
        torch.cuda.synchronize()
        # Capture
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            with torch.no_grad():
                output = self.model(
                    input_ids=static_input_ids,
                    position_ids=static_position_ids,
                )
                self.static_inputs[batch_size]["output"].copy_(output.logits)
        self.graphs[batch_size] = graph

    def decode_step(self, input_ids, position_ids):
        """Execute one decode step using the captured graph."""
        batch_size = input_ids.shape[0]
        if batch_size not in self.graphs:
            self.capture(batch_size)
        static = self.static_inputs[batch_size]
        # Copy dynamic data into the static buffers
        static["input_ids"].copy_(input_ids)
        static["position_ids"].copy_(position_ids)
        # Replay the captured graph (no per-kernel CPU launch overhead)
        self.graphs[batch_size].replay()
        return static["output"].clone()
```
Without CUDA graphs, each decode step involves:
- CPU Python overhead to call each module's `forward` (5-15 us per module)
- A CUDA kernel launch for each operation (3-5 us per launch)
- For 80 layers with ~20 kernels each: ~1,600 launches, roughly 8 ms of launch overhead

With CUDA graphs: 1 graph replay launch = ~5 us total. For a model where the actual GPU compute takes 30 ms, eliminating 8 ms of launch overhead cuts step latency by 21% (a 1.27x speedup).
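The launch-overhead arithmetic, using the assumed per-launch costs above:

```python
# Kernel-launch overhead per decode step, without CUDA graphs.
layers = 80
kernels_per_layer = 20
launch_us = 5  # assumed CUDA kernel launch cost, microseconds

launches = layers * kernels_per_layer        # 1,600 launches per step
overhead_ms = launches * launch_us / 1000    # 8.0 ms of pure launch overhead

gpu_compute_ms = 30
speedup = (gpu_compute_ms + overhead_ms) / gpu_compute_ms

print(f"{launches} launches, {overhead_ms:.1f} ms overhead, "
      f"{speedup:.2f}x speedup from graph capture")
```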
CUDA Graph Impact on Decode Latency (Llama 70B, H100)
| Batch Size | Without Graph (ms) | With Graph (ms) | Speedup | Launch Overhead Eliminated |
|---|---|---|---|---|
| 1 | 42.1 | 33.8 | 1.25x | 8.3ms |
| 8 | 43.2 | 34.1 | 1.27x | 9.1ms |
| 32 | 48.5 | 39.8 | 1.22x | 8.7ms |
| 128 | 72.3 | 63.1 | 1.15x | 9.2ms |
| 256 | 118.4 | 109.8 | 1.08x | 8.6ms |
Chunked Prefill: Taming Long Prompts
A 32K token prompt processed in a single prefill pass creates several problems:
- Attention compute: even with FlashAttention, the 32K x 32K computation takes significant time
- KV cache allocation: must reserve KV cache for all 32K positions at once
- Decode starvation: while prefill runs, no decode steps happen for other requests
Chunked prefill splits the prompt into chunks of size C and processes them sequentially:
```python
class ChunkedPrefillScheduler:
    """Process prefill in chunks to avoid starving decode requests."""

    def __init__(self, model, chunk_size=512, max_batch_tokens=8192):
        self.model = model
        self.chunk_size = chunk_size
        self.max_batch_tokens = max_batch_tokens

    def schedule_iteration(self, prefill_queue, decode_queue):
        """Schedule one iteration mixing prefill chunks and decode tokens."""
        batch = []
        token_budget = self.max_batch_tokens
        # Priority 1: decode tokens (low latency requirement)
        decode_requests = []
        for req in decode_queue:
            if token_budget >= 1:  # Each decode request uses 1 token
                decode_requests.append(req)
                token_budget -= 1
        batch.extend(decode_requests)
        # Priority 2: prefill chunks (fill remaining budget)
        prefill_chunks = []
        for req in prefill_queue:
            remaining = req.prompt_len - req.processed_tokens
            chunk = min(remaining, self.chunk_size, token_budget)
            if chunk > 0:
                prefill_chunks.append((req, chunk))
                token_budget -= chunk
        batch.extend(prefill_chunks)
        return batch

    def execute_chunked_prefill(self, request, chunk_start, chunk_size):
        """Execute one chunk of prefill for a request."""
        chunk_ids = request.prompt_ids[chunk_start:chunk_start + chunk_size]
        position_ids = torch.arange(
            chunk_start, chunk_start + chunk_size,
            device=chunk_ids.device,
        )
        # Forward pass for this chunk only; the KV cache grows incrementally
        with torch.no_grad():
            outputs = self.model(
                input_ids=chunk_ids.unsqueeze(0),
                position_ids=position_ids.unsqueeze(0),
                past_key_values=request.kv_cache,
                use_cache=True,
            )
        # Update request state
        request.kv_cache = outputs.past_key_values
        request.processed_tokens += chunk_size
        if request.processed_tokens >= request.prompt_len:
            request.phase = "decode"
            return outputs.logits[:, -1, :]  # Ready to sample the first token
        return None  # More chunks needed
```
Chunked Prefill Performance Analysis
The chunk size C creates a direct tradeoff:
- Small C (e.g., 128): minimal decode starvation, but prefill takes many iterations. Each chunk's GEMM has lower arithmetic intensity (roughly C FLOPs/byte instead of the full S).
- Large C (e.g., 4096): efficient prefill GEMMs, but decode requests wait up to one chunk's execution time between steps.

The inter-token latency (ITL) for decode requests during chunked prefill is approximately:

ITL ≈ max( weight_bytes / BW_HBM , 2 · (B_decode + C) · N_params / FLOPS_peak )

For a 70B model on H100 with batch=128 decode requests and one prefill chunk:
```python
def chunked_prefill_latency(chunk_size, decode_batch, model_params):
    """Estimate per-iteration latency with chunked prefill."""
    d = model_params["hidden_dim"]
    num_layers = model_params["num_layers"]
    bw = model_params["hbm_bw_tbs"]  # TB/s
    compute_tflops = model_params["compute_tflops"]
    # Weights: ~12*d^2 elements per layer, 2 bytes each in FP16
    weight_bytes = num_layers * 12 * d * d * 2
    # With chunked prefill, decode and prefill share the same forward pass;
    # the total token count determines whether the iteration is
    # compute-bound or bandwidth-bound
    total_tokens = decode_batch + chunk_size
    total_flops = num_layers * 2 * total_tokens * 12 * d * d
    bw_time = weight_bytes / (bw * 1e12)
    compute_time = total_flops / (compute_tflops * 1e12)
    iteration_time = max(bw_time, compute_time)
    return {
        "iteration_ms": iteration_time * 1000,
        "decode_itl_ms": iteration_time * 1000,
        "bottleneck": "bandwidth" if bw_time > compute_time else "compute",
        "total_tokens": total_tokens,
    }
```
*Figure: decode inter-token latency (ms) and prefill throughput (tokens/s) vs prefill chunk size (0 = decode only, 128, 256, 512, 1024, 2048, 4096) for Llama 70B on H100 with 128 decode requests.*
Why You Cannot Optimize Both With One Approach
The fundamental tension is:
- **Prefill wants large matrices** - high arithmetic intensity, saturate compute. Pack as many prompt tokens as possible into each forward pass.
- **Decode wants low latency** - minimize the time between generated tokens. Any prefill work in the same batch increases decode latency.
- **Batching helps decode but not prefill** - increasing decode batch size B improves GPU utilization (weights are loaded once and used B times). But prefill already has high utilization; adding more prefill tokens just takes proportionally longer.
Approach 1: Shared Forward Pass (vLLM/SGLang)
Mix prefill chunks and decode tokens in the same forward pass:
```python
def mixed_batch_forward(model, decode_requests, prefill_chunks):
    """Single forward pass with both decode and prefill tokens."""
    # Concatenate all tokens into one flat batch
    all_input_ids = []
    all_position_ids = []
    all_seq_lens = []
    # Decode tokens: 1 token per request
    for req in decode_requests:
        all_input_ids.append(req.current_token_id)
        all_position_ids.append(req.current_position)
        all_seq_lens.append(req.total_seq_len)
    # Prefill chunks: chunk_size tokens per request
    for req, chunk_start, chunk_size in prefill_chunks:
        all_input_ids.extend(
            req.prompt_ids[chunk_start:chunk_start + chunk_size]
        )
        all_position_ids.extend(
            range(chunk_start, chunk_start + chunk_size)
        )
        all_seq_lens.append(chunk_size)
    # A single forward pass processes everything; the attention kernel
    # handles variable sequence lengths via the sequence-length metadata
    input_ids = torch.tensor(all_input_ids, device="cuda")
    position_ids = torch.tensor(all_position_ids, device="cuda")
    with torch.no_grad():
        logits = model(input_ids, position_ids, ...)
    return logits
```
Tradeoff: simple implementation, but decode latency increases with prefill chunk size. Decode tokens “pay” for the prefill compute.
Approach 2: Disaggregated Prefill/Decode (Splitwise, DistServe)
Run prefill and decode on separate GPU pools:
```python
class DisaggregatedServing:
    """Separate GPU pools for prefill and decode."""

    def __init__(self, prefill_workers, decode_workers):
        self.prefill_pool = prefill_workers  # Optimized for compute
        self.decode_pool = decode_workers    # Optimized for bandwidth

    def handle_request(self, prompt_ids):
        # Phase 1: prefill on a compute-optimized GPU
        prefill_worker = self.prefill_pool.get_worker()
        kv_cache = prefill_worker.prefill(prompt_ids)
        first_token = prefill_worker.sample(kv_cache)
        # Transfer the KV cache to a decode worker
        decode_worker = self.decode_pool.get_worker()
        decode_worker.receive_kv_cache(kv_cache)  # RDMA transfer
        # Phase 2: decode on a bandwidth-optimized GPU
        tokens = [first_token]
        while not is_eos(tokens[-1]):
            next_token = decode_worker.decode_step(tokens[-1])
            tokens.append(next_token)
        return tokens
```
Tradeoff: optimal performance for each phase, but the KV cache transfer between pools adds latency proportional to context length. For a 4K context on Llama 70B (80 layers, 8 KV heads, head_dim 128, FP16): KV cache = 2 x 80 x 8 x 128 x 4096 x 2 bytes ≈ 1.3 GB; transferred at 400 Gbps (50 GB/s), that is ≈ 26 ms of overhead.
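The transfer overhead can be reproduced from the cache shape (GQA with 8 KV heads, head_dim 128, 80 layers, FP16, and an assumed 400 Gbps RDMA link):

```python
# KV cache size and transfer time for disaggregated serving (Llama 70B).
layers, kv_heads, head_dim = 80, 8, 128
context, dtype_bytes = 4096, 2

# Keys and values for every layer, head, and position.
kv_bytes = 2 * layers * kv_heads * head_dim * context * dtype_bytes

link_bytes_per_sec = 400e9 / 8     # 400 Gbps link = 50 GB/s
transfer_ms = kv_bytes / link_bytes_per_sec * 1000

print(f"KV cache: {kv_bytes / 1e9:.2f} GB, transfer: {transfer_ms:.0f} ms")
```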
Approach 3: Priority Scheduling
Keep both phases on the same GPU but give decode strict priority:
```python
class PriorityScheduler:
    """Decode-priority scheduler that preempts prefill for decode."""

    def __init__(self, model, decode_slo_ms=50):
        self.model = model
        self.decode_slo_ms = decode_slo_ms
        self.decode_queue = []
        self.prefill_queue = []

    def schedule(self):
        """Always process all pending decode tokens first;
        use remaining capacity for prefill chunks."""
        batch = []
        token_budget = 8192  # Max tokens per iteration
        # Decode tokens get absolute priority
        for req in self.decode_queue:
            batch.append(("decode", req, 1))
            token_budget -= 1
        # If the batch is small (low decode load), add prefill chunks
        if token_budget > 128:  # Minimum chunk size
            for req in self.prefill_queue:
                remaining = req.prompt_len - req.processed
                chunk = min(remaining, 512, token_budget)
                if chunk >= 128:
                    batch.append(("prefill", req, chunk))
                    token_budget -= chunk
        return batch
```
Optimization Strategy Comparison (Llama 70B, H100, 100 QPS)
| Strategy | TTFT P50 (ms) | TTFT P99 (ms) | ITL P50 (ms) | ITL P99 (ms) | Throughput (tok/s) |
|---|---|---|---|---|---|
| Shared (no chunking) | 85 | 320 | 68 | 180 | 3200 |
| Chunked prefill (C=512) | 210 | 450 | 36 | 52 | 3800 |
| Chunked prefill (C=2048) | 130 | 380 | 48 | 95 | 4100 |
| Disaggregated | 95 | 180 | 34 | 42 | 4500 |
| Priority scheduling | 250 | 600 | 34 | 45 | 3600 |
Profiling the Phase Boundary
The transition from prefill to decode is visible in profiling traces. Here is what to look for:
```python
import torch.profiler

def profile_phase_transition(model, prompt_ids, num_decode_steps=10):
    """Profile the prefill-to-decode transition to see the bottleneck shift."""
    with torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA,
        ],
        record_shapes=True,
        with_stack=True,
    ) as prof:
        # Prefill phase
        with torch.profiler.record_function("PREFILL"):
            with torch.no_grad():
                outputs = model(prompt_ids, use_cache=True)
                kv_cache = outputs.past_key_values
                next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
        # Decode phase
        for step in range(num_decode_steps):
            with torch.profiler.record_function(f"DECODE_STEP_{step}"):
                with torch.no_grad():
                    outputs = model(
                        next_token,
                        past_key_values=kv_cache,
                        use_cache=True,
                    )
                    kv_cache = outputs.past_key_values
                    next_token = outputs.logits[:, -1, :].argmax(dim=-1, keepdim=True)
    # Export for Nsight/Chrome trace viewer
    prof.export_chrome_trace("phase_transition.json")
    # Print kernel-level breakdown
    print(prof.key_averages().table(
        sort_by="cuda_time_total", row_limit=20
    ))
```
In the trace, you will see:
- Prefill: large GEMM kernels (e.g., `sm80_xmma_gemm_f16f16_f16f32_f32_tn_n`) dominating, high SM occupancy, low memory bandwidth utilization
- Decode: small GEMV kernels, low SM occupancy, HBM bandwidth near saturation, many small kernel launches between GEMMs
Quantization Interacts Differently With Each Phase
Quantizing weights from FP16 to INT8 halves the weight bytes. This has different effects on each phase:
Prefill (compute-bound): halving weight bytes does not help because compute is the bottleneck. The GEMM still does the same number of FLOPs. INT8 tensor cores are 2x faster than FP16 on H100 (1979 vs 989 TOPS), so INT8 quantization gives ~2x prefill speedup through faster compute, not reduced memory.
Decode (bandwidth-bound): halving weight bytes directly halves the memory load time, giving ~2x decode speedup. The compute was not the bottleneck, so the reduced precision does not matter.
```python
def quantization_impact(hidden_dim, batch_size, seq_len, phase,
                        bw_tbs=3.35, fp16_tflops=989, int8_tops=1979):
    """Estimate per-layer quantization impact on each phase."""
    d = hidden_dim
    weight_elements = 12 * d * d  # Approximate weight count per layer
    fp16_bytes = weight_elements * 2
    int8_bytes = weight_elements * 1
    if phase == "prefill":
        tokens = seq_len
        flops = 2 * tokens * weight_elements
        fp16_compute_time = flops / (fp16_tflops * 1e12)
        fp16_bw_time = fp16_bytes / (bw_tbs * 1e12)
        int8_compute_time = flops / (int8_tops * 1e12)  # INT8 tensor cores
        int8_bw_time = int8_bytes / (bw_tbs * 1e12)
        return {
            "fp16_time": max(fp16_compute_time, fp16_bw_time),
            "int8_time": max(int8_compute_time, int8_bw_time),
            "speedup": max(fp16_compute_time, fp16_bw_time)
                       / max(int8_compute_time, int8_bw_time),
            "fp16_bottleneck": "compute" if fp16_compute_time > fp16_bw_time else "bandwidth",
            "int8_bottleneck": "compute" if int8_compute_time > int8_bw_time else "bandwidth",
        }
    else:  # decode: bandwidth-bound at small batch
        fp16_bw_time = fp16_bytes / (bw_tbs * 1e12)
        int8_bw_time = int8_bytes / (bw_tbs * 1e12)
        return {
            "fp16_time": fp16_bw_time,
            "int8_time": int8_bw_time,
            "speedup": fp16_bw_time / int8_bw_time,  # ~2x
            "bottleneck": "bandwidth (both)",
        }
```
INT8 Quantization Impact by Phase (Llama 70B, H100)
| Configuration | Prefill (S=2048) ms | Decode (B=1) ms | Decode (B=128) ms |
|---|---|---|---|
| FP16 | 28.4 | 42.1 | 72.3 |
| INT8 (W8A8) | 15.2 | 21.8 | 41.5 |
| Speedup | 1.87x | 1.93x | 1.74x |
| Bottleneck shift | Compute -> Compute | BW -> BW | BW -> borderline |
The fundamental insight is that prefill and decode live on opposite sides of the roofline. Any optimization that reduces memory traffic helps decode. Any optimization that increases compute throughput helps prefill. Chunked prefill is the practical compromise that lets both coexist on the same GPU, while disaggregated serving is the principled solution that gives each phase its own hardware optimized for its specific bottleneck.