Part of Series: Inference Optimization (50 of 60)

The Hidden Cost of Adaptive Optimizers

Every discussion of large model training eventually hits the same wall: GPU memory. A 7B parameter model in FP16 requires 14 GB just for the weights. But the optimizer state — the internal bookkeeping that Adam maintains for every parameter — often requires more memory than the model itself.

This is not a minor implementation detail. For a 70B parameter model trained with standard Adam in mixed precision, the optimizer state alone consumes 840 GB. That is not a typo. Understanding why this happens, and what can be done about it, is essential for anyone training models at scale.

This article dissects the memory cost of optimizers, explains the most important memory-efficient alternatives (8-bit Adam via bitsandbytes, Adafactor, CAME), provides concrete memory savings analysis, and gives a practical decision framework for choosing the right optimizer.

Why Adam Requires 12 Bytes Per Parameter

The Adam Update Rule

Adam maintains two exponential moving averages per parameter: the first moment m (a running mean of gradients) and the second moment v (a running mean of squared gradients). The update at step t is:

m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t

v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2

\hat{m}_t = \frac{m_t}{1 - \beta_1^t}

\hat{v}_t = \frac{v_t}{1 - \beta_2^t}

\theta_t = \theta_{t-1} - \alpha \, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}

Each parameter θ requires storing m and v, plus the parameter itself, plus the gradient. In mixed-precision training (the standard for modern LLMs), the memory breakdown is:

| Component | Precision | Bytes per Parameter |
|---|---|---|
| Parameters (for training) | FP32 | 4 bytes |
| Gradients | FP16 (or BF16) | 2 bytes |
| First moment m | FP32 | 4 bytes |
| Second moment v | FP32 | 4 bytes |
| FP16 parameters (for forward/backward) | FP16 | 2 bytes |
| Total | | 16 bytes |

The optimizer states (m and v) must be stored in FP32 because they accumulate small updates over many steps. In FP16, the limited precision would cause these running averages to stagnate or diverge. The FP32 master copy of parameters is also required because weight updates are often too small to represent in FP16.

If we isolate just the optimizer’s contribution (what Adam adds beyond what SGD would need), it is the two FP32 state tensors: 4 + 4 = 8 bytes per parameter. Combined with the FP32 master weights (4 bytes), Adam requires 12 bytes per parameter for its state, compared to SGD’s 4 bytes (just the FP32 master weights).
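To make the state cost concrete, the update rule above can be written as a minimal single-tensor sketch (illustrative only, not a production optimizer). The persistent m and v buffers are exactly the 8 extra bytes per parameter:

```python
import torch

def adam_step(theta, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update on a single tensor.

    m and v persist across steps: two FP32 tensors the same shape as
    theta, i.e. the 8 bytes/parameter of optimizer state.
    """
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad ** 2
    m_hat = m / (1 - beta1 ** t)   # bias correction
    v_hat = v / (1 - beta2 ** t)
    theta = theta - lr * m_hat / (v_hat.sqrt() + eps)
    return theta, m, v

# One step with a constant gradient: the bias-corrected update is ~lr in size
theta = torch.ones(3)
m = torch.zeros(3)
v = torch.zeros(3)
theta, m, v = adam_step(theta, torch.full((3,), 0.1), m, v, t=1)
```

Note that m and v must be threaded through every call: unlike SGD, there is no way to discard them between steps.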

def calculate_training_memory(num_params_billions, optimizer='adam', precision='mixed'):
    """
    Calculate memory requirements for training a model.
    Returns breakdown in GB.
    """
    num_params = num_params_billions * 1e9

    if precision == 'mixed':
        # Mixed precision: FP16 forward/backward, FP32 optimizer states
        param_memory = num_params * 2 / 1e9         # FP16 params: 2 bytes
        master_params = num_params * 4 / 1e9         # FP32 master: 4 bytes
        gradient_memory = num_params * 2 / 1e9       # FP16 grads: 2 bytes
    elif precision == 'fp32':
        param_memory = num_params * 4 / 1e9
        master_params = 0
        gradient_memory = num_params * 4 / 1e9
    else:
        raise ValueError(f"Unknown precision: {precision}")

    if optimizer == 'adam':
        # FP32 first and second moments
        optimizer_state = num_params * 8 / 1e9       # 4 + 4 bytes
    elif optimizer == 'sgd_momentum':
        optimizer_state = num_params * 4 / 1e9       # 4 bytes (momentum)
    elif optimizer == 'sgd':
        optimizer_state = 0
    elif optimizer == '8bit_adam':
        # 8-bit first and second moments + quantization overhead
        optimizer_state = num_params * 2 / 1e9       # 1 + 1 bytes
        optimizer_state += num_params / 2048 * 8 / 1e9  # block scales
    elif optimizer == 'adafactor':
        # Factored second moment (row + col factors instead of full matrix)
        # For simplicity, assume average 2 bytes per parameter
        optimizer_state = num_params * 2 / 1e9
    else:
        raise ValueError(f"Unknown optimizer: {optimizer}")

    total = param_memory + master_params + gradient_memory + optimizer_state

    return {
        'parameters_gb': param_memory,
        'master_params_gb': master_params,
        'gradients_gb': gradient_memory,
        'optimizer_state_gb': optimizer_state,
        'total_gb': total,
    }

Training Memory Breakdown by Optimizer (Mixed Precision)

| Model | Params (FP16) | Master (FP32) | Gradients (FP16) | Adam State (FP32) | Total |
|---|---|---|---|---|---|
| 1B params | 2 GB | 4 GB | 2 GB | 8 GB | 16 GB |
| 7B params | 14 GB | 28 GB | 14 GB | 56 GB | 112 GB |
| 13B params | 26 GB | 52 GB | 26 GB | 104 GB | 208 GB |
| 70B params | 140 GB | 280 GB | 140 GB | 560 GB | 1,120 GB |
| 175B params | 350 GB | 700 GB | 350 GB | 1,400 GB | 2,800 GB |
⚠️ The 16x Rule

A useful rule of thumb: training a model with Adam in mixed precision requires roughly 16 bytes per parameter. A 7B model needs about 112 GB just for parameters, gradients, and optimizer state — before accounting for activations, which can add another 50-200 GB depending on batch size and sequence length.
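The rule of thumb is plain arithmetic over the per-component byte counts listed earlier:

```python
# Bytes per parameter for mixed-precision training with Adam
components = {
    "fp16_params": 2,    # forward/backward copy
    "fp32_master": 4,    # master weights for the update
    "fp16_grads": 2,
    "adam_m_fp32": 4,    # first moment
    "adam_v_fp32": 4,    # second moment
}
bytes_per_param = sum(components.values())   # 16 bytes

params_7b = 7e9
total_gb = params_7b * bytes_per_param / 1e9  # 112 GB, before activations
```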

Where the Memory Actually Goes

For a 7B parameter model trained with Adam in mixed precision, the 112 GB breaks down as:

  • FP32 optimizer state (m + v): 56 GB (50%) — This is the dominant cost
  • FP32 master weights: 28 GB (25%)
  • FP16 parameters: 14 GB (12.5%)
  • FP16 gradients: 14 GB (12.5%)

The optimizer state is the single largest memory consumer. This is why memory-efficient optimizers focus on reducing the size of m and v.

Memory Breakdown: 7B Parameter Model (bar chart, GB)

8-bit Adam (bitsandbytes)

How It Works

Tim Dettmers’ bitsandbytes library introduced 8-bit Adam (published 2022, but the concepts were explored earlier), which quantizes the optimizer states m and v from FP32 (32 bits) to INT8 (8 bits). This cuts the optimizer state from 8 bytes per parameter to 2 bytes per parameter.

The key challenge is that naive INT8 quantization of optimizer states destroys training dynamics. The states span a wide range of values, and uniform 8-bit quantization cannot represent both the large values (common early in training) and small values (common late in training) accurately.

bitsandbytes solves this with two techniques:

1. Block-wise quantization: Instead of quantizing all values with a single scale factor, the optimizer states are divided into blocks of 2048 elements. Each block has its own scale factor, which allows different parts of the model to use different dynamic ranges.

2. Dynamic exponent data type: Instead of uniform INT8, bitsandbytes uses a non-uniform quantization that allocates more precision to values near zero (where most optimizer state values cluster). This is similar in spirit to the NormalFloat (NF4) quantization used in QLoRA.
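A toy way to see why non-uniform codes help: compare a uniform 8-bit grid against a codebook whose levels cluster near zero (here, a cubed grid). This is an illustration of the idea only, not the actual dynamic exponent code bitsandbytes uses:

```python
import torch

# 256 quantization levels. Cubing a uniform grid in [-1, 1] concentrates
# levels near zero, where optimizer-state values cluster late in training.
grid = torch.linspace(-1, 1, 256)
nonuniform_cb = grid.sign() * grid.abs() ** 3
uniform_cb = grid

def roundtrip(x, codebook):
    """Quantize each value to its nearest codebook entry and back."""
    idx = (x.unsqueeze(1) - codebook.unsqueeze(0)).abs().argmin(dim=1)
    return codebook[idx]

# Small-magnitude values, typical of optimizer states late in training
x = torch.tensor([0.001, -0.003, 0.01])
err_nonuniform = (roundtrip(x, nonuniform_cb) - x).abs().max()
err_uniform = (roundtrip(x, uniform_cb) - x).abs().max()
```

With the same 8-bit budget, the non-uniform codebook represents these small values with far less error, at the cost of coarser resolution near ±1.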

import bitsandbytes as bnb
import torch

def setup_8bit_adam(model, lr=1e-4, weight_decay=0.01):
    """
    Replace standard Adam with 8-bit Adam from bitsandbytes.
    This is a drop-in replacement that saves ~75% of optimizer memory.
    """
    optimizer = bnb.optim.Adam8bit(
        model.parameters(),
        lr=lr,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=weight_decay,
    )
    return optimizer

def compare_optimizer_memory(model):
    """
    Compare memory usage between standard and 8-bit Adam.
    """
    num_params = sum(p.numel() for p in model.parameters())

    # Standard Adam: 2 FP32 states per parameter
    standard_state_bytes = num_params * 4 * 2  # m and v, each FP32
    standard_state_gb = standard_state_bytes / 1e9

    # 8-bit Adam: 2 INT8 states + block scales
    block_size = 2048
    n_blocks = (num_params + block_size - 1) // block_size
    eightbit_state_bytes = num_params * 1 * 2  # m and v, each INT8
    eightbit_scales_bytes = n_blocks * 4 * 2   # FP32 scale per block
    eightbit_total = eightbit_state_bytes + eightbit_scales_bytes
    eightbit_gb = eightbit_total / 1e9

    savings_pct = (1 - eightbit_gb / standard_state_gb) * 100

    return {
        'num_params': num_params,
        'standard_adam_gb': standard_state_gb,
        '8bit_adam_gb': eightbit_gb,
        'savings_gb': standard_state_gb - eightbit_gb,
        'savings_pct': savings_pct,
    }

The Block-wise Quantization Algorithm

The quantization process for each optimizer step:

  1. After computing the Adam update for mm and vv in FP32, divide each state tensor into blocks of 2048 elements
  2. For each block, compute the absolute maximum value
  3. Quantize the block to INT8 using that maximum as the scale: q = \text{round}(x / \text{absmax} \times 127)
  4. Store the INT8 values and the FP32 scale factor
  5. At the next step, dequantize to FP32 for the update computation, then re-quantize
def blockwise_quantize_state(state_fp32, block_size=2048):
    """
    Quantize an optimizer state tensor to INT8 using block-wise scaling.

    Each block of 2048 elements gets its own scale factor,
    preserving dynamic range across different parts of the tensor.
    """
    n_elements = state_fp32.numel()
    state_flat = state_fp32.flatten()

    # Pad to block boundary
    n_blocks = (n_elements + block_size - 1) // block_size
    padded_size = n_blocks * block_size
    if padded_size > n_elements:
        state_flat = torch.nn.functional.pad(
            state_flat, (0, padded_size - n_elements)
        )

    # Reshape into blocks
    blocks = state_flat.reshape(n_blocks, block_size)

    # Compute per-block absolute maximum
    absmax = blocks.abs().amax(dim=1, keepdim=True)  # (n_blocks, 1)
    absmax = absmax.clamp(min=1e-10)  # Avoid division by zero

    # Quantize to INT8 range [-127, 127]
    scale = 127.0 / absmax
    quantized = (blocks * scale).round().clamp(-127, 127).to(torch.int8)

    # Return quantized values and scale factors
    return quantized.flatten()[:n_elements], absmax.flatten()

def blockwise_dequantize_state(quantized_int8, scales, block_size=2048):
    """
    Dequantize INT8 optimizer state back to FP32 for update computation.
    """
    n_elements = quantized_int8.numel()
    n_blocks = scales.numel()

    # Pad if necessary
    padded_size = n_blocks * block_size
    q_flat = quantized_int8.float()
    if padded_size > n_elements:
        q_flat = torch.nn.functional.pad(q_flat, (0, padded_size - n_elements))

    # Reshape and apply scales
    blocks = q_flat.reshape(n_blocks, block_size)
    dequantized = blocks / 127.0 * scales.unsqueeze(1)

    return dequantized.flatten()[:n_elements]
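The payoff of per-block scales is easy to demonstrate: quantize a tensor that mixes tiny values with a region of outliers, once with a single tensor-wide scale and once with per-block scales (a synthetic sketch using the same symmetric INT8 scheme as the helpers above):

```python
import torch

torch.manual_seed(0)

# Optimizer states often mix tiny values with localized outliers.
state = torch.randn(8192) * 1e-4
state[:64] = torch.randn(64)  # one region with large values

def quantize_dequantize(x, n_blocks):
    """Symmetric INT8 round-trip with one scale per block."""
    blocks = x.reshape(n_blocks, -1)
    absmax = blocks.abs().amax(dim=1, keepdim=True).clamp(min=1e-10)
    q = (blocks / absmax * 127).round().clamp(-127, 127)
    return (q / 127 * absmax).flatten()

# Per-tensor scaling: one scale, dominated by the outliers, so the
# small values fall below the quantization step and collapse to zero.
err_tensor = (quantize_dequantize(state, 1) - state).abs().mean()

# Block-wise scaling: each block of 2048 gets its own scale, so the
# outlier-free blocks keep fine resolution.
err_block = (quantize_dequantize(state, 4) - state).abs().mean()
```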

Training Quality With 8-bit Adam

The critical question: does 8-bit Adam converge to the same accuracy as standard Adam? The empirical evidence is strongly positive:

8-bit Adam vs Standard Adam: Training Quality

| Model / Task | Standard Adam (Loss / Accuracy) | 8-bit Adam (Loss / Accuracy) | Difference | Memory Saved |
|---|---|---|---|---|
| BERT-Base (MNLI) | 84.6% accuracy | 84.5% accuracy | -0.1% | 5.4 GB |
| GPT-2 (117M, WikiText) | 21.1 perplexity | 21.2 perplexity | +0.1 ppl | 0.9 GB |
| RoBERTa-Large (SQuAD) | 88.9 F1 | 88.8 F1 | -0.1 F1 | 2.7 GB |
| T5-3B (SuperGLUE) | 89.2% avg | 89.1% avg | -0.1% | 22.4 GB |
| LLaMA-7B (fine-tuning) | 5.68 perplexity | 5.70 perplexity | +0.02 ppl | 42 GB |
| ViT-Large (ImageNet) | 85.2% top-1 | 85.1% top-1 | -0.1% | 2.4 GB |

The accuracy difference is consistently within noise. The bitsandbytes paper (Dettmers et al., 2022) showed that 8-bit Adam matches standard Adam across a wide range of models and tasks, with no cases of significant degradation.

💡 8-bit Adam Is a Free Lunch

In nearly all practical scenarios, 8-bit Adam is strictly better than standard Adam: same convergence, 75% less optimizer memory, negligible compute overhead. There is almost no reason to use standard Adam when bitsandbytes is available. The only exception is if you are running into numerical stability issues in unusual training setups.

Adafactor: Factored Second Moments

The Factorization Idea

Adafactor (Shazeer and Stern, 2018) takes a more radical approach to memory reduction. Instead of storing the full second-moment matrix v (which has the same shape as the parameter tensor), Adafactor stores only row and column factors whose outer product approximates v.

For a weight matrix W \in \mathbb{R}^{m \times n}, standard Adam stores v \in \mathbb{R}^{m \times n} (the full second moment). Adafactor stores:

r \in \mathbb{R}^{m} (row factor) and c \in \mathbb{R}^{n} (column factor)

The approximation is: \hat{v}_{ij} \approx \frac{r_i \cdot c_j}{\text{mean}(r)}

This reduces memory from O(mn) to O(m + n), which for large matrices is a dramatic saving. For a 4096 \times 4096 attention weight matrix, this reduces the second moment from 16M values (64 MB in FP32) to 8192 values (32 KB in FP32): a roughly 2000x reduction for that tensor.

Adafactor also optionally eliminates the first moment m entirely by using a sliding window momentum estimation, further reducing memory.
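A quick sanity check of the factorization: when the true second moment is itself rank-1, the row/column reconstruction recovers it exactly while storing only m + n values (toy shapes, not the full Adafactor update):

```python
import torch

torch.manual_seed(0)

# A rank-1 "second moment": outer product of positive row/column profiles.
r_true = torch.rand(64) + 0.1
c_true = torch.rand(32) + 0.1
v = torch.outer(r_true, c_true)  # (64, 32): what Adam would store in full

# Adafactor keeps only the marginals: O(m + n) instead of O(m * n)
v_row = v.mean(dim=1)  # (64,) row factor
v_col = v.mean(dim=0)  # (32,) column factor

# Reconstruction, normalized by mean(r) as in the formula above
v_hat = torch.outer(v_row, v_col) / v_row.mean()

# For a rank-1 v the reconstruction is exact up to float error
max_rel_err = ((v_hat - v).abs() / v).max()
```

Real second moments are not exactly rank-1, which is precisely where the approximation error (and CAME's correction, below) comes from.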

class SimplifiedAdafactor:
    """
    Simplified Adafactor implementation illustrating the key ideas.
    For production use, use the implementation in transformers or fairseq.
    """

    def __init__(self, params, lr=None, eps=(1e-30, 1e-3),
                 clip_threshold=1.0, decay_rate=-0.8,
                 beta1=None, weight_decay=0.0,
                 scale_parameter=True, relative_step=True):
        self.params = list(params)
        self.lr = lr
        self.eps = eps
        self.clip_threshold = clip_threshold
        self.decay_rate = decay_rate
        self.beta1 = beta1  # None = no first moment (saves memory)
        self.weight_decay = weight_decay
        self.scale_parameter = scale_parameter
        self.relative_step = relative_step
        self.step_count = 0

        # Initialize states
        self.state = {}
        for p in self.params:
            state = {}
            if len(p.shape) >= 2:
                # 2D+ tensors: use factored second moment
                state['v_row'] = torch.zeros(p.shape[:-1], device=p.device)
                state['v_col'] = torch.zeros(
                    p.shape[:-2] + (p.shape[-1],), device=p.device
                )
            else:
                # 1D tensors (biases, norms): use full second moment
                state['v'] = torch.zeros_like(p)

            if beta1 is not None:
                state['m'] = torch.zeros_like(p)

            self.state[id(p)] = state

    def _rms(self, tensor):
        """Root mean square of tensor elements."""
        return tensor.norm(2) / (tensor.numel() ** 0.5)

    def _get_lr(self, step):
        """Compute learning rate (relative step schedule)."""
        if self.relative_step:
            return max(1e-6, min(1e-2, 1.0 / (step ** 0.5)))
        return self.lr

    def _get_rho(self, step):
        """Compute second moment decay rate."""
        return min(1 - step ** self.decay_rate, 0.999)

    def step(self):
        self.step_count += 1
        lr = self._get_lr(self.step_count)
        rho = self._get_rho(self.step_count)

        for p in self.params:
            if p.grad is None:
                continue

            grad = p.grad.data
            state = self.state[id(p)]

            if len(p.shape) >= 2:
                # Factored second moment update
                # Update row factor
                state['v_row'].mul_(rho).add_(
                    (grad ** 2).mean(dim=-1), alpha=1 - rho
                )
                # Update column factor
                state['v_col'].mul_(rho).add_(
                    (grad ** 2).mean(dim=-2), alpha=1 - rho
                )

                # Reconstruct approximate v from factors
                row_mean = state['v_row'].mean(dim=-1, keepdim=True)
                v_approx = (
                    state['v_row'].unsqueeze(-1)
                    * state['v_col'].unsqueeze(-2)
                    / row_mean.unsqueeze(-1).clamp(min=self.eps[0])
                )
            else:
                # Full second moment for 1D parameters
                state['v'].mul_(rho).add_(grad ** 2, alpha=1 - rho)
                v_approx = state['v']

            # Compute update
            update = grad / (v_approx.sqrt() + self.eps[0])

            # RMS clipping
            update_rms = self._rms(update)
            if update_rms > self.clip_threshold:
                update.mul_(self.clip_threshold / update_rms)

            # Optional first moment
            if self.beta1 is not None:
                state['m'].mul_(self.beta1).add_(update, alpha=1 - self.beta1)
                update = state['m']

            # Scale by parameter RMS if enabled
            if self.scale_parameter:
                param_rms = self._rms(p.data).clamp(min=self.eps[1])
                lr_scaled = lr * param_rms
            else:
                lr_scaled = lr

            # Weight decay
            if self.weight_decay > 0:
                p.data.add_(p.data, alpha=-self.weight_decay * lr_scaled)

            # Apply update
            p.data.add_(update, alpha=-lr_scaled)

Adafactor Memory Savings

Memory Savings: Adam vs 8-bit Adam vs Adafactor

| Model | Adam State | 8-bit Adam State | Adafactor State | Adafactor Savings vs Adam |
|---|---|---|---|---|
| BERT-Base (110M) | 0.88 GB | 0.22 GB | 0.14 GB | 84% |
| GPT-2 (1.5B) | 12 GB | 3 GB | 1.8 GB | 85% |
| T5-3B | 24 GB | 6 GB | 3.6 GB | 85% |
| LLaMA-7B | 56 GB | 14 GB | 8.4 GB | 85% |
| LLaMA-70B | 560 GB | 140 GB | 84 GB | 85% |

Adafactor Trade-offs

Adafactor is not a drop-in replacement for Adam. It has meaningful trade-offs:

Convergence differences: Adafactor often converges to slightly different (sometimes worse) solutions than Adam. The factored approximation of the second moment is not exact, and this affects the effective learning rate per-parameter. T5 was specifically designed and tuned for Adafactor, so it works well there. For models designed for Adam, switching to Adafactor may require hyperparameter re-tuning.

The first moment question: Adafactor can run without a first moment (\beta_1 = \text{None}), which saves even more memory. But removing the momentum term often slows convergence, especially for tasks with noisy gradients.

Relative step size: Adafactor’s default learning rate schedule (\text{lr} \propto 1/\sqrt{t}) is different from the typical warmup + cosine schedule used with Adam. Using Adafactor with a fixed learning rate is possible but requires different hyperparameters.

⚠️ Adafactor Is Not a Drop-in Replacement

Unlike 8-bit Adam, which is a true drop-in replacement for Adam, Adafactor changes the optimization dynamics. If you switch from Adam to Adafactor, expect to re-tune learning rate, warmup steps, and potentially other hyperparameters. The memory savings are larger than 8-bit Adam, but the engineering cost is also higher.

CAME: Confidence-Guided Adaptive Memory-Efficient Optimization

Beyond Adafactor

CAME (Luo et al., 2023) builds on Adafactor’s factored second moment idea but adds a confidence-guided mechanism that improves convergence. The key insight is that Adafactor’s factored approximation introduces error, and this error can be partially corrected by weighting the update based on how well the factored approximation matches the true second moment.

CAME maintains the same row and column factors as Adafactor but adds a “confidence” term that measures the quality of the approximation:

\text{confidence}_{ij} = \frac{r_i \cdot c_j}{r_i \cdot c_j + |\hat{v}_{ij}^{\text{true}} - r_i \cdot c_j|}

In practice, since the true second moment is not available (that is the whole point of factoring), CAME uses the current gradient squared as a proxy.

class SimplifiedCAME:
    """
    Simplified CAME optimizer illustrating the confidence mechanism.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999, 0.9999),
                 eps=(1e-30, 1e-6), weight_decay=0.0):
        self.params = list(params)
        self.lr = lr
        self.betas = betas  # beta1 for m, beta2 for v, beta3 for confidence
        self.eps = eps
        self.weight_decay = weight_decay
        self.step_count = 0
        self.state = {}

        for p in self.params:
            state = {
                'm': torch.zeros_like(p),
            }
            if len(p.shape) >= 2:
                state['v_row'] = torch.zeros(p.shape[:-1], device=p.device)
                state['v_col'] = torch.zeros(
                    p.shape[:-2] + (p.shape[-1],), device=p.device
                )
                # CAME addition: confidence accumulator
                state['confidence'] = torch.ones_like(p)
            else:
                state['v'] = torch.zeros_like(p)
            self.state[id(p)] = state

    def step(self):
        self.step_count += 1
        beta1, beta2, beta3 = self.betas

        for p in self.params:
            if p.grad is None:
                continue

            grad = p.grad.data
            state = self.state[id(p)]

            # First moment update (same as Adam)
            state['m'].mul_(beta1).add_(grad, alpha=1 - beta1)

            if len(p.shape) >= 2:
                # Factored second moment (same as Adafactor)
                grad_sq = grad ** 2
                state['v_row'].mul_(beta2).add_(
                    grad_sq.mean(dim=-1), alpha=1 - beta2
                )
                state['v_col'].mul_(beta2).add_(
                    grad_sq.mean(dim=-2), alpha=1 - beta2
                )

                # Reconstruct factored approximation
                row_mean = state['v_row'].mean(dim=-1, keepdim=True)
                v_factored = (
                    state['v_row'].unsqueeze(-1)
                    * state['v_col'].unsqueeze(-2)
                    / row_mean.unsqueeze(-1).clamp(min=self.eps[0])
                )

                # CAME: compute confidence from approximation quality
                approx_error = (grad_sq - v_factored).abs()
                confidence_update = v_factored / (
                    v_factored + approx_error + self.eps[0]
                )
                state['confidence'].mul_(beta3).add_(
                    confidence_update, alpha=1 - beta3
                )

                # Apply confidence-weighted update
                v_effective = v_factored / (
                    state['confidence'].clamp(min=0.1)
                )
            else:
                state['v'].mul_(beta2).add_(grad ** 2, alpha=1 - beta2)
                v_effective = state['v']

            # Bias correction
            m_hat = state['m'] / (1 - beta1 ** self.step_count)

            # Update
            denom = v_effective.sqrt() + self.eps[1]
            update = m_hat / denom

            if self.weight_decay > 0:
                p.data.add_(p.data, alpha=-self.weight_decay * self.lr)

            p.data.add_(update, alpha=-self.lr)
CAME vs Adafactor vs Adam: Training Results

| Model / Task | Adam (Loss) | Adafactor (Loss) | CAME (Loss) | CAME Memory vs Adam |
|---|---|---|---|---|
| T5-Base (C4 pre-training) | 2.08 | 2.12 | 2.09 | ~50% savings |
| T5-Large (C4 pre-training) | 1.85 | 1.89 | 1.86 | ~50% savings |
| GPT-2 Medium (OpenWebText) | 3.18 | 3.25 | 3.19 | ~50% savings |
| ViT-Base (ImageNet) | 21.8% error | 22.4% error | 21.9% error | ~50% savings |

CAME consistently closes the gap between Adafactor and Adam, achieving near-Adam quality with Adafactor-level memory usage. However, it adds the confidence tensor, which partially offsets the memory savings for 2D parameters.

When To Use SGD vs Adam vs 8-bit Adam vs Adafactor

The Decision Framework

The choice of optimizer depends on three factors: memory budget, training quality requirements, and engineering effort budget.

Optimizer Selection Guide

| Scenario | Recommended Optimizer | Rationale | Optimizer State (per param) |
|---|---|---|---|
| Small model, plenty of memory | AdamW | Best convergence, simplest | 8 bytes |
| Medium model, memory tight | 8-bit AdamW | Same convergence, 75% less state | ~2 bytes |
| Large model pre-training | 8-bit AdamW or Adafactor | Depends on model architecture | 2-3 bytes |
| T5 / mT5 training | Adafactor | Model designed for it | ~1.5 bytes |
| Fine-tuning with QLoRA | 8-bit AdamW (paged) | Only trainable params have state | 2 bytes |
| Computer vision (ResNet, ViT) | SGD + momentum or AdamW | SGD often competitive for CV | 4 bytes (SGD+m) |
| Maximum memory savings needed | Adafactor (no first moment) | Minimum possible state | ~1 byte |
| Stable Diffusion fine-tuning | 8-bit AdamW | Proven to work, easy setup | ~2 bytes |

SGD Is Not Dead

For computer vision tasks, SGD with momentum remains competitive with Adam:

  • ImageNet training: The standard recipe for ResNet-50 uses SGD with momentum. Adam offers marginal improvement at 3x the memory cost.
  • Object detection: YOLO, Faster R-CNN, and DETR all train well with SGD.
  • Fine-tuning large vision models: When fine-tuning ViT or similar models, SGD with a well-tuned learning rate can match Adam.

The key advantage of SGD is its 4-byte state (just momentum) vs. Adam’s 8-byte state. For a ViT-Large (307M params), this saves roughly 1.2 GB: not huge on a single GPU, but meaningful when training on consumer hardware or running multiple experiments.

def choose_optimizer(model, task_type, memory_budget_gb, num_params_billions):
    """
    Recommend an optimizer based on constraints.
    """
    adam_state_gb = num_params_billions * 8  # 8 bytes per param
    sgd_state_gb = num_params_billions * 4  # 4 bytes per param
    eightbit_state_gb = num_params_billions * 2  # ~2 bytes per param
    adafactor_state_gb = num_params_billions * 1.5  # ~1.5 bytes per param

    if task_type in ('image_classification', 'object_detection'):
        if sgd_state_gb <= memory_budget_gb:
            return "SGD + momentum (cosine LR schedule)"
        else:
            return "SGD + momentum (gradient accumulation)"

    if task_type in ('language_modeling', 'seq2seq', 'fine_tuning'):
        if adam_state_gb <= memory_budget_gb:
            return "AdamW (standard)"
        elif eightbit_state_gb <= memory_budget_gb:
            return "8-bit AdamW (bitsandbytes)"
        elif adafactor_state_gb <= memory_budget_gb:
            return "Adafactor (may need HP tuning)"
        else:
            return "8-bit AdamW + gradient accumulation + activation checkpointing"

    return "8-bit AdamW (safe default)"

Complete Memory Savings Analysis

Here is a concrete comparison for a LLaMA-7B fine-tuning scenario:

LLaMA-7B Fine-tuning Memory: Full Model

| Component | AdamW | 8-bit AdamW | Adafactor | SGD+Momentum |
|---|---|---|---|---|
| FP16 Parameters | 14 GB | 14 GB | 14 GB | 14 GB |
| FP32 Master Weights | 28 GB | 28 GB | 28 GB | 28 GB |
| FP16 Gradients | 14 GB | 14 GB | 14 GB | 14 GB |
| Optimizer State | 56 GB | 14 GB | 10.5 GB | 28 GB |
| Total (no activations) | 112 GB | 70 GB | 66.5 GB | 84 GB |
| Min GPUs needed (80 GB each) | 2 | 1 | 1 | 2 |
💡 The QLoRA Approach

For fine-tuning, QLoRA (4-bit base model + LoRA adapters + 8-bit Adam on adapters only) reduces total memory even further. Only the LoRA adapter parameters (typically 0.1-1% of the model) have optimizer states. A 7B model can be fine-tuned on a single 24 GB GPU with QLoRA.

Paged Optimizers for GPU Memory Management

bitsandbytes also provides “paged” optimizer variants that use NVIDIA unified memory to automatically page optimizer states between GPU and CPU memory:

import bitsandbytes as bnb

def setup_paged_optimizer(model, lr=1e-4):
    """
    Paged 8-bit Adam: optimizer states are automatically
    paged to CPU when GPU memory is full.

    This prevents OOM errors at the cost of some speed
    when paging occurs.
    """
    optimizer = bnb.optim.PagedAdamW8bit(  # decoupled weight decay, matching AdamW
        model.parameters(),
        lr=lr,
        betas=(0.9, 0.999),
        weight_decay=0.01,
    )
    return optimizer

Paged optimizers are particularly useful for:

  • Fine-tuning with variable-length sequences (memory usage fluctuates)
  • Training with large batch sizes that temporarily exceed GPU memory
  • Multi-task training where different tasks have different memory profiles

Advanced: Lion, Sophia, and Beyond

The optimizer landscape continues to evolve. Notable recent developments:

Lion (Chen et al., 2023): Uses only the sign of the momentum, eliminating the second moment entirely. Memory cost is identical to SGD with momentum (4 bytes per parameter). Has shown competitive results with Adam on language and vision tasks, though it requires careful learning rate tuning.

Sophia (Liu et al., 2023): Uses a diagonal Hessian estimate instead of the second moment. Claims 2x faster convergence than Adam for LLM pre-training. The Hessian estimate adds some compute overhead but uses the same memory as standard Adam.

LOMO (Lv et al., 2023): Fuses the gradient computation and parameter update, eliminating the need to store gradients separately. This saves the 2 bytes per parameter for gradient storage. Most useful for extremely memory-constrained fine-tuning.
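To make Lion’s “sign of the momentum” concrete, here is a minimal single-tensor sketch of its update rule, following the pseudocode in the Lion paper (illustrative hyperparameters; not the official or bitsandbytes implementation):

```python
import numpy as np

def lion_step(param, grad, m, lr=1e-4, beta1=0.9, beta2=0.99, wd=0.0):
    """One Lion update: the step direction is only a sign, and the
    momentum buffer m is the sole optimizer state (4 bytes/param)."""
    update = np.sign(beta1 * m + (1 - beta1) * grad)  # interpolate, then sign
    param = param - lr * (update + wd * param)        # decoupled weight decay
    m = beta2 * m + (1 - beta2) * grad                # update the single state
    return param, m
```

Because each coordinate moves by exactly ±lr regardless of gradient magnitude, the Lion paper suggests a learning rate several times smaller than AdamW’s, which is the “careful learning rate tuning” caveat above.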

📊 Optimizer Comparison Summary

| Optimizer | State Bytes/Param | Convergence Quality | Drop-in Replacement? | Best Use Case |
|---|---|---|---|---|
| SGD | 0 | Good for CV | No (different HP) | Vision, when memory is critical |
| SGD + Momentum | 4 | Good for CV | No | Standard vision training |
| Adam / AdamW | 8 | Excellent | Baseline | Default choice when memory allows |
| 8-bit Adam | ~2 | Excellent | Yes | Default when memory is tight |
| Adafactor | ~1.5 | Good (needs tuning) | No | T5-family, extreme memory savings |
| CAME | ~2 | Very good | No | When Adafactor accuracy is not enough |
| Lion | 4 | Good (needs tuning) | No | Experimental, some LLM training |
| Sophia | 8 | Excellent (faster) | No | LLM pre-training |
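Adafactor’s ~1.5 bytes/param comes largely from replacing the full second-moment matrix with a factored estimate: for an n×m weight it stores one length-n row statistic and one length-m column statistic instead of n×m entries. A numpy sketch of the rank-1 reconstruction (illustrative shapes):

```python
import numpy as np

rng = np.random.default_rng(0)
g = rng.normal(size=(512, 256))   # gradient of a 512x256 weight matrix
sq = g * g                        # per-element squared gradient

row = sq.mean(axis=1)             # 512 stored numbers for rows...
col = sq.mean(axis=0)             # ...plus 256 for columns

# Rank-1 reconstruction of the second-moment estimate (Adafactor's trick):
v_hat = np.outer(row, col) / sq.mean()

print(row.size + col.size, "stored vs", sq.size, "full")  # 768 stored vs 131072 full
```

The reconstruction is exact along row and column averages (the factored estimate preserves them by construction), which is what makes the approximation usable in practice for large matrices.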

📊 Bar chart: Optimizer Memory Cost per 1B Parameters (GB of optimizer state)

Practical Implementation: Putting It All Together

Here is a complete training setup that uses 8-bit Adam with mixed precision:

import torch
from torch.cuda.amp import autocast, GradScaler
import bitsandbytes as bnb

def create_training_setup(model, lr=2e-5, max_grad_norm=1.0):
    """
    Production training setup with 8-bit Adam and mixed precision.
    """
    # 8-bit AdamW optimizer
    optimizer = bnb.optim.AdamW8bit(
        model.parameters(),
        lr=lr,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0.01,
    )

    # Mixed precision scaler
    scaler = GradScaler()

    return optimizer, scaler

def training_step(model, batch, optimizer, scaler, max_grad_norm=1.0):
    """
    Single training step with 8-bit Adam and mixed precision.
    """
    optimizer.zero_grad()

    # Forward pass in mixed precision
    with autocast():
        outputs = model(**batch)
        loss = outputs.loss

    # Backward pass with gradient scaling
    scaler.scale(loss).backward()

    # Gradient clipping (unscale first)
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

    # Optimizer step
    scaler.step(optimizer)
    scaler.update()

    return loss.item()

def print_memory_usage(model, optimizer_name='8-bit AdamW'):
    """Print a simplified memory breakdown (assumes FP16 weights and
    gradients; excludes FP32 master copies and activations)."""
    num_params = sum(p.numel() for p in model.parameters())
    num_trainable = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )

    print(f"Model parameters: {num_params/1e6:.1f}M "
          f"({num_trainable/1e6:.1f}M trainable)")
    print(f"Parameter memory: {num_params * 2 / 1e9:.2f} GB (FP16)")
    print(f"Gradient memory: {num_trainable * 2 / 1e9:.2f} GB (FP16)")

    if optimizer_name == 'AdamW':
        state_gb = num_trainable * 8 / 1e9
    elif optimizer_name == '8-bit AdamW':
        state_gb = num_trainable * 2 / 1e9
    elif optimizer_name == 'Adafactor':
        state_gb = num_trainable * 1.5 / 1e9
    else:
        state_gb = 0

    print(f"Optimizer state: {state_gb:.2f} GB ({optimizer_name})")
    print(f"Total (no activations or master weights): "
          f"{num_params * 2 / 1e9 + num_trainable * 2 / 1e9 + state_gb:.2f} GB")

    if torch.cuda.is_available():
        print(f"\nGPU memory allocated: "
              f"{torch.cuda.memory_allocated() / 1e9:.2f} GB")
        print(f"GPU memory reserved: "
              f"{torch.cuda.memory_reserved() / 1e9:.2f} GB")

Conclusion

The memory cost of optimizers is one of the most impactful factors in large model training. Adam’s requirement for 8 bytes of FP32 state per parameter means that optimizer state alone consumes 560 GB for a 70B model, most of the 640 GB that a typical 8×80 GB node provides, before weights, gradients, or activations are counted at all.

The solutions, ranked by practicality:

8-bit Adam (bitsandbytes) is the clear winner for most use cases. It is a true drop-in replacement that reduces optimizer state by 75% with no measurable impact on convergence in published evaluations. If you are training with Adam today, switch to 8-bit Adam; the only cost is a small quantize/dequantize overhead per step.

Adafactor provides even greater savings (85% reduction) but is not a drop-in replacement. It requires hyperparameter adjustment and may converge differently. It is the right choice for T5-family models (which were designed for it) and for scenarios where even 8-bit Adam’s state is too large.

CAME bridges the gap between Adafactor’s memory efficiency and Adam’s convergence quality. It is a good choice when Adafactor’s approximation error causes problems but you cannot afford full Adam.

SGD with momentum remains relevant for computer vision, where it matches Adam at a fraction of the memory cost. Do not dismiss it for CV tasks.

The broader lesson: optimizer memory is a tax on every parameter in your model, and this tax compounds as models scale. Choosing the right optimizer is not just about convergence — it determines how large a model you can train on your hardware.