The Hidden Cost of Adaptive Optimizers
Every discussion of large model training eventually hits the same wall: GPU memory. A 7B parameter model in FP16 requires 14 GB just for the weights. But the optimizer state — the internal bookkeeping that Adam maintains for every parameter — often requires more memory than the model itself.
This is not a minor implementation detail. For a 70B parameter model trained with standard Adam in mixed precision, the optimizer state plus the FP32 master weights consume 840 GB. That is not a typo. Understanding why this happens, and what can be done about it, is essential for anyone training models at scale.
This article dissects the memory cost of optimizers, explains the most important memory-efficient alternatives (8-bit Adam via bitsandbytes, Adafactor, CAME), provides concrete memory savings analysis, and gives a practical decision framework for choosing the right optimizer.
Why Adam Requires 12 Bytes Per Parameter
The Adam Update Rule
Adam maintains two exponential moving averages per parameter: the first moment (mean of gradients) and the second moment (mean of squared gradients). The update at step $t$ is:

$$m_t = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t$$

$$v_t = \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2$$

$$\hat{m}_t = \frac{m_t}{1 - \beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1 - \beta_2^t}$$

$$\theta_t = \theta_{t-1} - \alpha \, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}$$
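A minimal NumPy sketch of this update (single tensor, illustrative hyperparameters) shows the moving parts:

```python
import numpy as np

def adam_step(theta, grad, m, v, t, lr=1e-3,
              beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update for a single parameter tensor."""
    m = beta1 * m + (1 - beta1) * grad       # first moment (EMA of grads)
    v = beta2 * v + (1 - beta2) * grad ** 2  # second moment (EMA of grad^2)
    m_hat = m / (1 - beta1 ** t)             # bias correction
    v_hat = v / (1 - beta2 ** t)
    theta = theta - lr * m_hat / (np.sqrt(v_hat) + eps)
    return theta, m, v

# Minimize f(x) = x^2 (gradient 2x): x moves toward the minimum at 0.
x = np.array([1.0])
m = np.zeros_like(x)
v = np.zeros_like(x)
for t in range(1, 501):
    x, m, v = adam_step(x, 2 * x, m, v, t, lr=0.01)
```

Note that `m`, `v`, and `theta` are all full-size copies of the parameter tensor — this is exactly the per-parameter state the rest of this article is about.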
Each parameter requires storing $m_t$ and $v_t$, plus the parameter itself, plus the gradient. In mixed-precision training (the standard for modern LLMs), the memory breakdown is:
| Component | Precision | Bytes per Parameter |
|---|---|---|
| Parameters (for training) | FP32 | 4 bytes |
| Gradients | FP16 (or BF16) | 2 bytes |
| First moment | FP32 | 4 bytes |
| Second moment | FP32 | 4 bytes |
| FP16 parameters (for forward/backward) | FP16 | 2 bytes |
| Total | | 16 bytes |
The optimizer states ($m$ and $v$) must be stored in FP32 because they accumulate small updates over many steps. In FP16, the limited precision would cause these running averages to stagnate or diverge. The FP32 master copy of parameters is also required because weight updates are often too small to represent in FP16.
If we isolate just the optimizer’s contribution (what Adam adds beyond what SGD would need), it is the two FP32 state tensors: $4 + 4 = 8$ bytes per parameter. Combined with the FP32 master weights (4 bytes), Adam requires 12 bytes per parameter for its state, compared to SGD’s 4 bytes (just the FP32 master weights).
def calculate_training_memory(num_params_billions, optimizer='adam', precision='mixed'):
"""
Calculate memory requirements for training a model.
Returns breakdown in GB.
"""
num_params = num_params_billions * 1e9
if precision == 'mixed':
# Mixed precision: FP16 forward/backward, FP32 optimizer states
param_memory = num_params * 2 / 1e9 # FP16 params: 2 bytes
master_params = num_params * 4 / 1e9 # FP32 master: 4 bytes
gradient_memory = num_params * 2 / 1e9 # FP16 grads: 2 bytes
elif precision == 'fp32':
param_memory = num_params * 4 / 1e9
master_params = 0
gradient_memory = num_params * 4 / 1e9
else:
raise ValueError(f"Unknown precision: {precision}")
if optimizer == 'adam':
# FP32 first and second moments
optimizer_state = num_params * 8 / 1e9 # 4 + 4 bytes
elif optimizer == 'sgd_momentum':
optimizer_state = num_params * 4 / 1e9 # 4 bytes (momentum)
elif optimizer == 'sgd':
optimizer_state = 0
elif optimizer == '8bit_adam':
# 8-bit first and second moments + quantization overhead
optimizer_state = num_params * 2 / 1e9 # 1 + 1 bytes
optimizer_state += num_params / 2048 * 8 / 1e9 # block scales
elif optimizer == 'adafactor':
# Factored second moment (row + col factors instead of full matrix)
# For simplicity, assume average 2 bytes per parameter
optimizer_state = num_params * 2 / 1e9
else:
raise ValueError(f"Unknown optimizer: {optimizer}")
total = param_memory + master_params + gradient_memory + optimizer_state
return {
'parameters_gb': param_memory,
'master_params_gb': master_params,
'gradients_gb': gradient_memory,
'optimizer_state_gb': optimizer_state,
'total_gb': total,
}
Training Memory Breakdown by Optimizer (Mixed Precision)
| Model | Params (FP16) | Master (FP32) | Gradients (FP16) | Adam State (FP32) | Total |
|---|---|---|---|---|---|
| 1B params | 2 GB | 4 GB | 2 GB | 8 GB | 16 GB |
| 7B params | 14 GB | 28 GB | 14 GB | 56 GB | 112 GB |
| 13B params | 26 GB | 52 GB | 26 GB | 104 GB | 208 GB |
| 70B params | 140 GB | 280 GB | 140 GB | 560 GB | 1,120 GB |
| 175B params | 350 GB | 700 GB | 350 GB | 1,400 GB | 2,800 GB |
A useful rule of thumb: training a model with Adam in mixed precision requires roughly 16 bytes per parameter. A 7B model needs about 112 GB just for parameters, gradients, and optimizer state — before accounting for activations, which can add another 50-200 GB depending on batch size and sequence length.
Where the Memory Actually Goes
For a 7B parameter model trained with Adam in mixed precision, the 112 GB breaks down as:
- FP32 optimizer state (m + v): 56 GB (50%) — This is the dominant cost
- FP32 master weights: 28 GB (25%)
- FP16 parameters: 14 GB (12.5%)
- FP16 gradients: 14 GB (12.5%)
The optimizer state is the single largest memory consumer. This is why memory-efficient optimizers focus on reducing the size of $m$ and $v$.
[Chart: Memory Breakdown for a 7B Parameter Model, in GB]

8-bit Adam (bitsandbytes)
How It Works
Tim Dettmers’ bitsandbytes library introduced 8-bit Adam (published 2022, but the concepts were explored earlier), which quantizes the optimizer states $m$ and $v$ from FP32 (32 bits) to INT8 (8 bits). This cuts the optimizer state from 8 bytes per parameter to 2 bytes per parameter.
The key challenge is that naive INT8 quantization of optimizer states destroys training dynamics. The states span a wide range of values, and uniform 8-bit quantization cannot represent both the large values (common early in training) and small values (common late in training) accurately.
bitsandbytes solves this with two techniques:
1. Block-wise quantization: Instead of quantizing all values with a single scale factor, the optimizer states are divided into blocks of 2048 elements. Each block has its own scale factor, which allows different parts of the model to use different dynamic ranges.
2. Dynamic exponent data type: Instead of uniform INT8, bitsandbytes uses a non-uniform quantization that allocates more precision to values near zero (where most optimizer state values cluster). This is similar in spirit to the NormalFloat (NF4) quantization used in QLoRA.
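A small NumPy experiment (illustrative values, not the bitsandbytes kernels) shows why per-block scaling matters: with one scale for the whole tensor, a block of tiny values is rounded entirely to zero, while a per-block scale preserves it.

```python
import numpy as np

def quantize_int8(x, scale_max):
    """Symmetric INT8 round trip with a given absmax scale."""
    q = np.clip(np.round(x / scale_max * 127), -127, 127)
    return q / 127 * scale_max  # dequantize to measure error

rng = np.random.default_rng(0)
big = rng.normal(0, 1.0, 2048)     # block with large dynamic range
small = rng.normal(0, 1e-4, 2048)  # block with tiny values
x = np.concatenate([big, small])

# Per-tensor: one scale for everything
per_tensor = quantize_int8(x, np.abs(x).max())

# Block-wise: one scale per 2048-element block
per_block = np.concatenate([
    quantize_int8(big, np.abs(big).max()),
    quantize_int8(small, np.abs(small).max()),
])

err_tensor = np.abs(per_tensor[2048:] - small).mean()
err_block = np.abs(per_block[2048:] - small).mean()
# Block-wise error on the small block is orders of magnitude lower.
```

With a single per-tensor scale, the quantization step is set by the largest block, so every value in the small block rounds to zero.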
import bitsandbytes as bnb
import torch
def setup_8bit_adam(model, lr=1e-4, weight_decay=0.01):
"""
Replace standard Adam with 8-bit Adam from bitsandbytes.
This is a drop-in replacement that saves ~75% of optimizer memory.
"""
optimizer = bnb.optim.Adam8bit(
model.parameters(),
lr=lr,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=weight_decay,
)
return optimizer
def compare_optimizer_memory(model):
"""
Compare memory usage between standard and 8-bit Adam.
"""
num_params = sum(p.numel() for p in model.parameters())
# Standard Adam: 2 FP32 states per parameter
standard_state_bytes = num_params * 4 * 2 # m and v, each FP32
standard_state_gb = standard_state_bytes / 1e9
# 8-bit Adam: 2 INT8 states + block scales
block_size = 2048
n_blocks = (num_params + block_size - 1) // block_size
eightbit_state_bytes = num_params * 1 * 2 # m and v, each INT8
eightbit_scales_bytes = n_blocks * 4 * 2 # FP32 scale per block
eightbit_total = eightbit_state_bytes + eightbit_scales_bytes
eightbit_gb = eightbit_total / 1e9
savings_pct = (1 - eightbit_gb / standard_state_gb) * 100
return {
'num_params': num_params,
'standard_adam_gb': standard_state_gb,
'8bit_adam_gb': eightbit_gb,
'savings_gb': standard_state_gb - eightbit_gb,
'savings_pct': savings_pct,
}
The Block-wise Quantization Algorithm
The quantization process for each optimizer step:
- After computing the Adam update for $m$ and $v$ in FP32, divide each state tensor into blocks of 2048 elements
- For each block, compute the absolute maximum value
- Quantize the block to INT8 using that maximum as the scale: $q_i = \operatorname{round}\left(127 \cdot x_i / \max_j |x_j|\right)$
- Store the INT8 values and the FP32 scale factor
- At the next step, dequantize to FP32 for the update computation, then re-quantize
def blockwise_quantize_state(state_fp32, block_size=2048):
"""
Quantize an optimizer state tensor to INT8 using block-wise scaling.
Each block of 2048 elements gets its own scale factor,
preserving dynamic range across different parts of the tensor.
"""
n_elements = state_fp32.numel()
state_flat = state_fp32.flatten()
# Pad to block boundary
n_blocks = (n_elements + block_size - 1) // block_size
padded_size = n_blocks * block_size
if padded_size > n_elements:
state_flat = torch.nn.functional.pad(
state_flat, (0, padded_size - n_elements)
)
# Reshape into blocks
blocks = state_flat.reshape(n_blocks, block_size)
# Compute per-block absolute maximum
absmax = blocks.abs().amax(dim=1, keepdim=True) # (n_blocks, 1)
absmax = absmax.clamp(min=1e-10) # Avoid division by zero
# Quantize to INT8 range [-127, 127]
scale = 127.0 / absmax
quantized = (blocks * scale).round().clamp(-127, 127).to(torch.int8)
# Return quantized values and scale factors
return quantized.flatten()[:n_elements], absmax.flatten()
def blockwise_dequantize_state(quantized_int8, scales, block_size=2048):
"""
Dequantize INT8 optimizer state back to FP32 for update computation.
"""
n_elements = quantized_int8.numel()
n_blocks = scales.numel()
# Pad if necessary
padded_size = n_blocks * block_size
q_flat = quantized_int8.float()
if padded_size > n_elements:
q_flat = torch.nn.functional.pad(q_flat, (0, padded_size - n_elements))
# Reshape and apply scales
blocks = q_flat.reshape(n_blocks, block_size)
dequantized = blocks / 127.0 * scales.unsqueeze(1)
return dequantized.flatten()[:n_elements]
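A self-contained NumPy round trip of the same scheme (same block size and INT8 range as the PyTorch functions above; the test tensor's scale is an illustrative choice) confirms that the reconstruction error stays within half a quantization step per block:

```python
import numpy as np

def quantize_blocks(x, block_size=2048):
    """Block-wise symmetric INT8 quantization; returns codes and scales."""
    n = len(x)
    n_blocks = -(-n // block_size)  # ceiling division
    padded = np.zeros(n_blocks * block_size)
    padded[:n] = x
    blocks = padded.reshape(n_blocks, block_size)
    absmax = np.maximum(np.abs(blocks).max(axis=1, keepdims=True), 1e-10)
    q = np.clip(np.round(blocks / absmax * 127), -127, 127).astype(np.int8)
    return q, absmax

def dequantize_blocks(q, absmax, n):
    """Invert quantize_blocks, dropping the padding."""
    return (q.astype(np.float64) / 127 * absmax).reshape(-1)[:n]

rng = np.random.default_rng(1)
x = rng.normal(0, 0.01, 5000)  # optimizer-state-like magnitudes
q, s = quantize_blocks(x)
x_hat = dequantize_blocks(q, s, len(x))
max_err = np.abs(x_hat - x).max()
```

The worst-case error per element is `absmax / 254` for its block, which is why the per-block `absmax` (rather than a global maximum) keeps small-magnitude regions accurate.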
Training Quality With 8-bit Adam
The critical question: does 8-bit Adam converge to the same accuracy as standard Adam? The empirical evidence is strongly positive:
8-bit Adam vs Standard Adam: Training Quality
| Model / Task | Standard Adam (Loss / Accuracy) | 8-bit Adam (Loss / Accuracy) | Difference | Memory Saved |
|---|---|---|---|---|
| BERT-Base (MNLI) | 84.6% accuracy | 84.5% accuracy | -0.1% | 5.4 GB |
| GPT-2 (117M, WikiText) | 21.1 perplexity | 21.2 perplexity | +0.1 ppl | 0.9 GB |
| RoBERTa-Large (SQuAD) | 88.9 F1 | 88.8 F1 | -0.1 F1 | 2.7 GB |
| T5-3B (SuperGLUE) | 89.2% avg | 89.1% avg | -0.1% | 22.4 GB |
| LLaMA-7B (fine-tuning) | 5.68 perplexity | 5.70 perplexity | +0.02 ppl | 42 GB |
| ViT-Large (ImageNet) | 85.2% top-1 | 85.1% top-1 | -0.1% | 2.4 GB |
The accuracy difference is consistently within noise. The bitsandbytes paper (Dettmers et al., 2022) showed that 8-bit Adam matches standard Adam across a wide range of models and tasks, with no cases of significant degradation.
In nearly all practical scenarios, 8-bit Adam is strictly better than standard Adam: same convergence, 75% less optimizer memory, negligible compute overhead. There is almost no reason to use standard Adam when bitsandbytes is available. The only exception is if you are running into numerical stability issues in unusual training setups.
Adafactor: Factored Second Moments
The Factorization Idea
Adafactor (Shazeer and Stern, 2018) takes a more radical approach to memory reduction. Instead of storing the full second-moment matrix $V$ (which has the same shape as the parameter tensor), Adafactor stores only row and column factors whose outer product approximates $V$.

For a weight matrix $W \in \mathbb{R}^{n \times m}$, standard Adam stores $V \in \mathbb{R}^{n \times m}$ (the full second moment). Adafactor stores:

$R \in \mathbb{R}^{n}$ (row factor) and $C \in \mathbb{R}^{m}$ (column factor)

The approximation is:

$$\hat{V}_{ij} = \frac{R_i \, C_j}{\operatorname{mean}(R)}$$

This reduces memory from $O(nm)$ to $O(n + m)$, which for large matrices is a dramatic saving. For a $4096 \times 4096$ attention weight matrix, this reduces the second moment from 16M values (64 MB in FP32) to 8192 values (32 KB in FP32) — a 2000x reduction for that tensor.
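The arithmetic for the 4096×4096 example can be checked directly:

```python
def factored_savings(n, m, bytes_per_value=4):
    """Compare full vs factored second-moment storage for an n x m matrix."""
    full = n * m * bytes_per_value        # V: one FP32 value per parameter
    factored = (n + m) * bytes_per_value  # R (n values) + C (m values)
    return full, factored, full / factored

full, factored, ratio = factored_savings(4096, 4096)
# full = 67,108,864 bytes (64 MiB), factored = 32,768 bytes (32 KiB),
# a 2048x reduction for this tensor
```

Note the savings ratio `n*m / (n+m)` grows with matrix size, which is why Adafactor's advantage is largest for the big projection matrices in transformers.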
Adafactor also optionally eliminates the first moment entirely (running without momentum), further reducing memory.
class SimplifiedAdafactor:
"""
Simplified Adafactor implementation illustrating the key ideas.
For production use, use the implementation in transformers or fairseq.
"""
def __init__(self, params, lr=None, eps=(1e-30, 1e-3),
clip_threshold=1.0, decay_rate=-0.8,
beta1=None, weight_decay=0.0,
scale_parameter=True, relative_step=True):
self.params = list(params)
self.lr = lr
self.eps = eps
self.clip_threshold = clip_threshold
self.decay_rate = decay_rate
self.beta1 = beta1 # None = no first moment (saves memory)
self.weight_decay = weight_decay
self.scale_parameter = scale_parameter
self.relative_step = relative_step
self.step_count = 0
# Initialize states
self.state = {}
for p in self.params:
state = {}
if len(p.shape) >= 2:
# 2D+ tensors: use factored second moment
state['v_row'] = torch.zeros(p.shape[:-1], device=p.device)
state['v_col'] = torch.zeros(
p.shape[:-2] + (p.shape[-1],), device=p.device
)
else:
# 1D tensors (biases, norms): use full second moment
state['v'] = torch.zeros_like(p)
if beta1 is not None:
state['m'] = torch.zeros_like(p)
self.state[id(p)] = state
def _rms(self, tensor):
"""Root mean square of tensor elements."""
return tensor.norm(2) / (tensor.numel() ** 0.5)
def _get_lr(self, step):
"""Compute learning rate (relative step schedule)."""
if self.relative_step:
return max(1e-6, min(1e-2, 1.0 / (step ** 0.5)))
return self.lr
def _get_rho(self, step):
"""Compute second moment decay rate."""
return min(1 - step ** self.decay_rate, 0.999)
def step(self):
self.step_count += 1
lr = self._get_lr(self.step_count)
rho = self._get_rho(self.step_count)
for p in self.params:
if p.grad is None:
continue
grad = p.grad.data
state = self.state[id(p)]
if len(p.shape) >= 2:
# Factored second moment update
# Update row factor
state['v_row'].mul_(rho).add_(
(grad ** 2).mean(dim=-1), alpha=1 - rho
)
# Update column factor
state['v_col'].mul_(rho).add_(
(grad ** 2).mean(dim=-2), alpha=1 - rho
)
# Reconstruct approximate v from factors
row_mean = state['v_row'].mean(dim=-1, keepdim=True)
v_approx = (
state['v_row'].unsqueeze(-1)
* state['v_col'].unsqueeze(-2)
/ row_mean.unsqueeze(-1).clamp(min=self.eps[0])
)
else:
# Full second moment for 1D parameters
state['v'].mul_(rho).add_(grad ** 2, alpha=1 - rho)
v_approx = state['v']
# Compute update
update = grad / (v_approx.sqrt() + self.eps[0])
# RMS clipping
update_rms = self._rms(update)
if update_rms > self.clip_threshold:
update.mul_(self.clip_threshold / update_rms)
# Optional first moment
if self.beta1 is not None:
state['m'].mul_(self.beta1).add_(update, alpha=1 - self.beta1)
update = state['m']
# Scale by parameter RMS if enabled
if self.scale_parameter:
param_rms = self._rms(p.data).clamp(min=self.eps[1])
lr_scaled = lr * param_rms
else:
lr_scaled = lr
# Weight decay
if self.weight_decay > 0:
p.data.add_(p.data, alpha=-self.weight_decay * lr_scaled)
# Apply update
p.data.add_(update, alpha=-lr_scaled)
Adafactor Memory Savings
Memory Savings: Adam vs 8-bit Adam vs Adafactor
| Model | Adam State | 8-bit Adam State | Adafactor State | Adafactor Savings vs Adam |
|---|---|---|---|---|
| BERT-Base (110M) | 0.88 GB | 0.22 GB | 0.14 GB | 84% |
| GPT-2 (1.5B) | 12 GB | 3 GB | 1.8 GB | 85% |
| T5-3B | 24 GB | 6 GB | 3.6 GB | 85% |
| LLaMA-7B | 56 GB | 14 GB | 8.4 GB | 85% |
| LLaMA-70B | 560 GB | 140 GB | 84 GB | 85% |
Adafactor Trade-offs
Adafactor is not a drop-in replacement for Adam. It has meaningful trade-offs:
Convergence differences: Adafactor often converges to slightly different (sometimes worse) solutions than Adam. The factored approximation of the second moment is not exact, and this affects the effective learning rate per-parameter. T5 was specifically designed and tuned for Adafactor, so it works well there. For models designed for Adam, switching to Adafactor may require hyperparameter re-tuning.
The first moment question: Adafactor can run without a first moment (`beta1=None`), which saves even more memory. But removing the momentum term often slows convergence, especially for tasks with noisy gradients.
Relative step size: Adafactor’s default learning rate schedule ($\alpha_t = \min(10^{-2},\, 1/\sqrt{t})$) is different from the typical warmup + cosine schedule used with Adam. Using Adafactor with a fixed learning rate is possible but requires different hyperparameters.
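The relative-step schedule, as used in the simplified implementation above, can be written as a standalone function (the `1e-6` floor is from that sketch, not the original paper):

```python
def adafactor_relative_step(step, min_lr=1e-6, max_lr=1e-2):
    """Adafactor's default schedule: lr_t = min(max_lr, 1/sqrt(t))."""
    return max(min_lr, min(max_lr, 1.0 / step ** 0.5))

# The cap at 1e-2 dominates early training; the 1/sqrt(t) decay only
# takes over after step 10,000, where 1/sqrt(t) drops below 1e-2.
```

This implicit warmup-free, slowly decaying schedule is one reason Adam-tuned warmup and cosine settings do not transfer directly.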
Unlike 8-bit Adam, which is a true drop-in replacement for Adam, Adafactor changes the optimization dynamics. If you switch from Adam to Adafactor, expect to re-tune learning rate, warmup steps, and potentially other hyperparameters. The memory savings are larger than 8-bit Adam, but the engineering cost is also higher.
CAME: Confidence-Guided Adaptive Memory-Efficient Optimization
Beyond Adafactor
CAME (Luo et al., 2023) builds on Adafactor’s factored second moment idea but adds a confidence-guided mechanism that improves convergence. The key insight is that Adafactor’s factored approximation introduces error, and this error can be partially corrected by weighting the update based on how well the factored approximation matches the true second moment.
CAME maintains the same row and column factors as Adafactor but adds a “confidence” term that measures the quality of the approximation:

$$C_t = \beta_3 C_{t-1} + (1 - \beta_3)\, \frac{\hat{V}_t}{\hat{V}_t + \lvert V_t - \hat{V}_t \rvert}$$

where $\hat{V}_t$ is the factored approximation and $V_t$ the true second moment. In practice, since the true second moment is not available (that is the whole point of factoring), CAME uses the current gradient squared $g_t^2$ as a proxy for $V_t$.
class SimplifiedCAME:
"""
Simplified CAME optimizer illustrating the confidence mechanism.
"""
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999, 0.9999),
eps=(1e-30, 1e-6), weight_decay=0.0):
self.params = list(params)
self.lr = lr
self.betas = betas # beta1 for m, beta2 for v, beta3 for confidence
self.eps = eps
self.weight_decay = weight_decay
self.step_count = 0
self.state = {}
for p in self.params:
state = {
'm': torch.zeros_like(p),
}
if len(p.shape) >= 2:
state['v_row'] = torch.zeros(p.shape[:-1], device=p.device)
state['v_col'] = torch.zeros(
p.shape[:-2] + (p.shape[-1],), device=p.device
)
# CAME addition: confidence accumulator
state['confidence'] = torch.ones_like(p)
else:
state['v'] = torch.zeros_like(p)
self.state[id(p)] = state
def step(self):
self.step_count += 1
beta1, beta2, beta3 = self.betas
for p in self.params:
if p.grad is None:
continue
grad = p.grad.data
state = self.state[id(p)]
# First moment update (same as Adam)
state['m'].mul_(beta1).add_(grad, alpha=1 - beta1)
if len(p.shape) >= 2:
# Factored second moment (same as Adafactor)
grad_sq = grad ** 2
state['v_row'].mul_(beta2).add_(
grad_sq.mean(dim=-1), alpha=1 - beta2
)
state['v_col'].mul_(beta2).add_(
grad_sq.mean(dim=-2), alpha=1 - beta2
)
# Reconstruct factored approximation
row_mean = state['v_row'].mean(dim=-1, keepdim=True)
v_factored = (
state['v_row'].unsqueeze(-1)
* state['v_col'].unsqueeze(-2)
/ row_mean.unsqueeze(-1).clamp(min=self.eps[0])
)
# CAME: compute confidence from approximation quality
approx_error = (grad_sq - v_factored).abs()
confidence_update = v_factored / (
v_factored + approx_error + self.eps[0]
)
state['confidence'].mul_(beta3).add_(
confidence_update, alpha=1 - beta3
)
# Apply confidence-weighted update
v_effective = v_factored / (
state['confidence'].clamp(min=0.1)
)
else:
state['v'].mul_(beta2).add_(grad ** 2, alpha=1 - beta2)
v_effective = state['v']
# Bias correction
m_hat = state['m'] / (1 - beta1 ** self.step_count)
# Update
denom = v_effective.sqrt() + self.eps[1]
update = m_hat / denom
if self.weight_decay > 0:
p.data.add_(p.data, alpha=-self.weight_decay * self.lr)
p.data.add_(update, alpha=-self.lr)
CAME vs Adafactor vs Adam: Training Results
| Model / Task | Adam (Loss) | Adafactor (Loss) | CAME (Loss) | CAME Memory vs Adam |
|---|---|---|---|---|
| T5-Base (C4 pre-training) | 2.08 | 2.12 | 2.09 | ~50% savings |
| T5-Large (C4 pre-training) | 1.85 | 1.89 | 1.86 | ~50% savings |
| GPT-2 Medium (OpenWebText) | 3.18 | 3.25 | 3.19 | ~50% savings |
| ViT-Base (ImageNet) | 21.8% error | 22.4% error | 21.9% error | ~50% savings |
CAME consistently closes the gap between Adafactor and Adam, achieving near-Adam quality with Adafactor-level memory usage. However, it adds the confidence tensor, which partially offsets the memory savings for 2D parameters.
When To Use SGD vs Adam vs 8-bit Adam vs Adafactor
The Decision Framework
The choice of optimizer depends on three factors: memory budget, training quality requirements, and engineering effort budget.
Optimizer Selection Guide
| Scenario | Recommended Optimizer | Rationale | Optimizer State (per param) |
|---|---|---|---|
| Small model, plenty of memory | AdamW | Best convergence, simplest | 8 bytes |
| Medium model, memory tight | 8-bit AdamW | Same convergence, 75% less state | ~2 bytes |
| Large model pre-training | 8-bit AdamW or Adafactor | Depends on model architecture | 2-3 bytes |
| T5 / mT5 training | Adafactor | Model designed for it | ~1.5 bytes |
| Fine-tuning with QLoRA | 8-bit AdamW (paged) | Only trainable params have state | 2 bytes |
| Computer vision (ResNet, ViT) | SGD + momentum or AdamW | SGD often competitive for CV | 4 bytes (SGD+m) |
| Maximum memory savings needed | Adafactor (no first moment) | Minimum possible state | ~1 byte |
| Stable Diffusion fine-tuning | 8-bit AdamW | Proven to work, easy setup | ~2 bytes |
SGD Is Not Dead
For computer vision tasks, SGD with momentum remains competitive with Adam:
- ImageNet training: The standard recipe for ResNet-50 uses SGD with momentum. Adam offers marginal improvement at 3x the memory cost.
- Object detection: YOLO, Faster R-CNN, and DETR all train well with SGD.
- Fine-tuning large vision models: When fine-tuning ViT or similar models, SGD with a well-tuned learning rate can match Adam.
The key advantage of SGD is its 4-byte state (just momentum) vs. Adam’s 8-byte state. For a ViT-Large (307M params), this saves about 1.2 GB — not huge on a single GPU, but meaningful when training on consumer hardware or running multiple experiments.
def choose_optimizer(model, task_type, memory_budget_gb, num_params_billions):
"""
Recommend an optimizer based on constraints.
"""
adam_state_gb = num_params_billions * 8 # 8 bytes per param
sgd_state_gb = num_params_billions * 4 # 4 bytes per param
eightbit_state_gb = num_params_billions * 2 # ~2 bytes per param
adafactor_state_gb = num_params_billions * 1.5 # ~1.5 bytes per param
if task_type in ('image_classification', 'object_detection'):
if sgd_state_gb <= memory_budget_gb:
return "SGD + momentum (cosine LR schedule)"
else:
return "SGD + momentum (gradient accumulation)"
if task_type in ('language_modeling', 'seq2seq', 'fine_tuning'):
if adam_state_gb <= memory_budget_gb:
return "AdamW (standard)"
elif eightbit_state_gb <= memory_budget_gb:
return "8-bit AdamW (bitsandbytes)"
elif adafactor_state_gb <= memory_budget_gb:
return "Adafactor (may need HP tuning)"
else:
return "8-bit AdamW + gradient accumulation + activation checkpointing"
return "8-bit AdamW (safe default)"
Complete Memory Savings Analysis
Here is a concrete comparison for a LLaMA-7B fine-tuning scenario:
LLaMA-7B Fine-tuning Memory: Full Model
| Component | AdamW | 8-bit AdamW | Adafactor | SGD+Momentum |
|---|---|---|---|---|
| FP16 Parameters | 14 GB | 14 GB | 14 GB | 14 GB |
| FP32 Master Weights | 28 GB | 28 GB | 28 GB | 28 GB |
| FP16 Gradients | 14 GB | 14 GB | 14 GB | 14 GB |
| Optimizer State | 56 GB | 14 GB | 10.5 GB | 28 GB |
| Total (no activations) | 112 GB | 70 GB | 66.5 GB | 84 GB |
| Min GPUs needed (80GB each) | 2 | 1 | 1 | 2 |
For fine-tuning, QLoRA (4-bit base model + LoRA adapters + 8-bit Adam on adapters only) reduces total memory even further. Only the LoRA adapter parameters (typically 0.1-1% of the model) have optimizer states. A 7B model can be fine-tuned on a single 24 GB GPU with QLoRA.
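A rough estimate of the QLoRA scenario can be sketched as follows. The 0.5% adapter fraction is an illustrative assumption within the 0.1-1% range above, and quantization constants and activations are ignored:

```python
def qlora_memory_gb(num_params_billions, adapter_fraction=0.005):
    """Rough QLoRA memory estimate: 4-bit frozen base model plus FP16
    LoRA adapters, with 8-bit Adam state on the adapters only."""
    n = num_params_billions * 1e9
    base_4bit = n * 0.5 / 1e9                    # 4-bit base: 0.5 bytes/param
    adapters = n * adapter_fraction * 2 / 1e9    # FP16 adapter weights
    grads = n * adapter_fraction * 2 / 1e9       # FP16 adapter gradients
    opt_state = n * adapter_fraction * 2 / 1e9   # 8-bit Adam (m + v)
    return base_4bit + adapters + grads + opt_state

# A 7B model: ~3.5 GB for the base plus ~0.2 GB adapter-related,
# leaving ample headroom for activations on a 24 GB GPU.
total = qlora_memory_gb(7)
```

The striking point is that the optimizer state, the dominant cost in full fine-tuning, becomes negligible because it exists only for the adapter parameters.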
Paged Optimizers for GPU Memory Management
bitsandbytes also provides “paged” optimizer variants that use NVIDIA unified memory to automatically page optimizer states between GPU and CPU memory:
import bitsandbytes as bnb
def setup_paged_optimizer(model, lr=1e-4):
"""
Paged 8-bit Adam: optimizer states are automatically
paged to CPU when GPU memory is full.
This prevents OOM errors at the cost of some speed
when paging occurs.
"""
optimizer = bnb.optim.PagedAdam8bit(
model.parameters(),
lr=lr,
betas=(0.9, 0.999),
weight_decay=0.01,
)
return optimizer
Paged optimizers are particularly useful for:
- Fine-tuning with variable-length sequences (memory usage fluctuates)
- Training with large batch sizes that temporarily exceed GPU memory
- Multi-task training where different tasks have different memory profiles
Advanced: Lion, Sophia, and Beyond
The optimizer landscape continues to evolve. Notable recent developments:
Lion (Chen et al., 2023): Uses only the sign of the momentum, eliminating the second moment entirely. Memory cost is identical to SGD with momentum (4 bytes per parameter). Has shown competitive results with Adam on language and vision tasks, though it requires careful learning rate tuning.
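A minimal sketch of the Lion update (following the published algorithm; the constant-gradient example is illustrative) makes the sign-based rule concrete:

```python
import numpy as np

def lion_step(theta, grad, m, lr=1e-4, beta1=0.9, beta2=0.99,
              weight_decay=0.0):
    """One Lion update: step by the sign of an interpolated momentum,
    then update the momentum with a second beta. Only one state tensor
    (m), so memory matches SGD with momentum."""
    update = np.sign(beta1 * m + (1 - beta1) * grad)
    theta = theta - lr * (update + weight_decay * theta)
    m = beta2 * m + (1 - beta2) * grad
    return theta, m

# With a constant positive gradient, every step moves theta by exactly
# lr, regardless of gradient magnitude: 50 steps of 0.01 take 1.0 -> 0.5.
x = np.array([1.0])
m = np.zeros_like(x)
for _ in range(50):
    x, m = lion_step(x, np.array([1.0]), m, lr=0.01)
```

Because the update magnitude is always exactly `lr`, Lion is unusually sensitive to the learning rate — consistent with the tuning caveat above.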
Sophia (Liu et al., 2023): Uses a diagonal Hessian estimate instead of the second moment. Claims 2x faster convergence than Adam for LLM pre-training. The Hessian estimate adds some compute overhead but uses the same memory as standard Adam.
LOMO (Lv et al., 2023): Fuses the gradient computation and parameter update, eliminating the need to store gradients separately. This saves the 2 bytes per parameter for gradient storage. Most useful for extremely memory-constrained fine-tuning.
Optimizer Comparison Summary
| Optimizer | State Bytes/Param | Convergence Quality | Drop-in Replacement? | Best Use Case |
|---|---|---|---|---|
| SGD | 0 | Good for CV | No (different HP) | Vision, when memory is critical |
| SGD + Momentum | 4 | Good for CV | No | Standard vision training |
| Adam / AdamW | 8 | Excellent | Baseline | Default choice when memory allows |
| 8-bit Adam | ~2 | Excellent | Yes | Default when memory is tight |
| Adafactor | ~1.5 | Good (needs tuning) | No | T5-family, extreme memory savings |
| CAME | ~2 | Very Good | No | When Adafactor accuracy is not enough |
| Lion | 4 | Good (needs tuning) | No | Experimental, some LLM training |
| Sophia | 8 | Excellent (faster) | No | LLM pre-training |
[Chart: Optimizer Memory Cost per 1B Parameters, in GB of optimizer state]

Practical Implementation: Putting It All Together
Here is a complete training setup that uses 8-bit Adam with mixed precision:
import torch
from torch.cuda.amp import autocast, GradScaler
import bitsandbytes as bnb
def create_training_setup(model, lr=2e-5, max_grad_norm=1.0):
"""
Production training setup with 8-bit Adam and mixed precision.
"""
# 8-bit AdamW optimizer
optimizer = bnb.optim.AdamW8bit(
model.parameters(),
lr=lr,
betas=(0.9, 0.999),
eps=1e-8,
weight_decay=0.01,
)
# Mixed precision scaler
scaler = GradScaler()
return optimizer, scaler
def training_step(model, batch, optimizer, scaler, max_grad_norm=1.0):
"""
Single training step with 8-bit Adam and mixed precision.
"""
optimizer.zero_grad()
# Forward pass in mixed precision
with autocast():
outputs = model(**batch)
loss = outputs.loss
# Backward pass with gradient scaling
scaler.scale(loss).backward()
# Gradient clipping (unscale first)
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
# Optimizer step
scaler.step(optimizer)
scaler.update()
return loss.item()
def print_memory_usage(model, optimizer_name='8-bit AdamW'):
"""Print detailed memory usage breakdown."""
num_params = sum(p.numel() for p in model.parameters())
num_trainable = sum(
p.numel() for p in model.parameters() if p.requires_grad
)
print(f"Model parameters: {num_params/1e6:.1f}M "
f"({num_trainable/1e6:.1f}M trainable)")
print(f"Parameter memory: {num_params * 2 / 1e9:.2f} GB (FP16)")
print(f"Gradient memory: {num_trainable * 2 / 1e9:.2f} GB (FP16)")
if optimizer_name == 'AdamW':
state_gb = num_trainable * 8 / 1e9
elif optimizer_name == '8-bit AdamW':
state_gb = num_trainable * 2 / 1e9
elif optimizer_name == 'Adafactor':
state_gb = num_trainable * 1.5 / 1e9
else:
state_gb = 0
print(f"Optimizer state: {state_gb:.2f} GB ({optimizer_name})")
print(f"Total (no activations): "
f"{num_params * 2 / 1e9 + num_trainable * 2 / 1e9 + state_gb:.2f} GB")
if torch.cuda.is_available():
print(f"\nGPU memory allocated: "
f"{torch.cuda.memory_allocated() / 1e9:.2f} GB")
print(f"GPU memory reserved: "
f"{torch.cuda.memory_reserved() / 1e9:.2f} GB")
Conclusion
The memory cost of optimizers is one of the most impactful factors in large model training. Adam’s requirement for 8 bytes of FP32 state per parameter means that optimizer state alone consumes 560 GB for a 70B model — nearly the entire capacity of a typical 8-GPU node, before weights, gradients, or activations are counted.
The solutions, ranked by practicality:
8-bit Adam (bitsandbytes) is the clear winner for most use cases. It is a true drop-in replacement that reduces optimizer state by 75% with no measurable impact on convergence. If you are training with Adam today, switch to 8-bit Adam. There is no downside.
Adafactor provides even greater savings (85% reduction) but is not a drop-in replacement. It requires hyperparameter adjustment and may converge differently. It is the right choice for T5-family models (which were designed for it) and for scenarios where even 8-bit Adam’s state is too large.
CAME bridges the gap between Adafactor’s memory efficiency and Adam’s convergence quality. It is a good choice when Adafactor’s approximation error causes problems but you cannot afford full Adam.
SGD with momentum remains relevant for computer vision, where it matches Adam at a fraction of the memory cost. Do not dismiss it for CV tasks.
The broader lesson: optimizer memory is a tax on every parameter in your model, and this tax compounds as models scale. Choosing the right optimizer is not just about convergence — it determines how large a model you can train on your hardware.