Jamba replaces two-thirds of transformer layers with Mamba state-space layers that cost O(n) instead of O(n²) in sequence length. The result: 12x faster prefill at 256K context length compared to pure attention, with only a 2-3 point drop on reasoning benchmarks. The architecture is a bet that most tokens do not need full attention — cheap recurrence handles sequential dependencies, and expensive attention fires only when global context matters. If Jamba proves Mamba quality at scale, every frontier lab will adopt hybrid architectures within 18 months.
Architecture Design
import torch
import torch.nn as nn
import torch.nn.functional as F
class JambaConfig:
    """Jamba architecture configuration.

    A 32-layer stack interleaving attention (every third layer, starting
    at index 0) with Mamba layers, plus MoE FFNs on a subset of layers.
    """
    hidden_size = 4096
    num_layers = 32  # total layers in the stack

    # Layer placement: attention at indices 0, 3, 6, ...; Mamba everywhere else.
    attention_layers = list(range(0, 32, 3))        # 11 attention layers
    mamba_layers = [i for i in range(32) if i % 3]  # 21 Mamba layers

    # Attention sub-layer config (GQA: 32 query heads share 8 KV heads)
    num_attention_heads = 32
    num_key_value_heads = 8
    head_dim = 128

    # Mamba sub-layer config
    mamba_d_state = 16     # SSM state dimension per channel
    mamba_d_conv = 4       # local convolution width
    mamba_expand = 2       # inner-width expansion factor
    mamba_dt_rank = 'auto' # delta (dt) projection rank

    # MoE config: applied at every third layer starting at index 1
    num_experts = 16
    num_experts_per_tok = 2
    moe_layers = list(range(1, 32, 3))  # 11 MoE layers

    # Total params: ~52B (12B active per token)
    vocab_size = 65536
class JambaBlock(nn.Module):
    """
    One Jamba layer: a sequence mixer (Mamba or attention) followed by a
    channel mixer (dense FFN or MoE), each wrapped in a pre-norm residual
    branch. Which variant a layer gets is decided by its index against
    the config's placement lists.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx
        self.use_attention = layer_idx in config.attention_layers
        self.use_moe = layer_idx in config.moe_layers
        # Pre-norms for the two residual branches.
        self.input_norm = nn.RMSNorm(config.hidden_size)
        self.post_ffn_norm = nn.RMSNorm(config.hidden_size)
        # Sequence mixing: attention on designated layers, Mamba elsewhere.
        mixer_cls = TransformerAttention if self.use_attention else MambaLayer
        self.sequence_mixer = mixer_cls(config)
        # Channel mixing: MoE on designated layers, dense FFN elsewhere.
        ffn_cls = MoEFFN if self.use_moe else DenseFFN
        self.channel_mixer = ffn_cls(config)

    def forward(self, x, cache=None):
        # Residual branch 1: pre-norm + sequence mixing (Mamba or attention).
        x = x + self.sequence_mixer(self.input_norm(x), cache=cache)
        # Residual branch 2: pre-norm + channel mixing (dense or MoE FFN).
        return x + self.channel_mixer(self.post_ffn_norm(x))
Mamba Layer Implementation
class MambaLayer(nn.Module):
    """
    Selective State-Space Model (Mamba) layer.

    Key property: O(n) in sequence length during generation,
    versus O(n^2) for attention. Maintains a fixed-size state
    that is updated per token, rather than a growing KV cache.
    """
    def __init__(self, config):
        super().__init__()
        self.d_model = config.hidden_size
        self.d_state = config.mamba_d_state  # 16
        self.d_conv = config.mamba_d_conv  # 4
        self.expand = config.mamba_expand  # 2
        self.d_inner = self.d_model * self.expand  # 8192
        # Input projection: [d_model] -> [2 * d_inner]
        # Split into x and z paths
        self.in_proj = nn.Linear(self.d_model, 2 * self.d_inner, bias=False)
        # Depthwise 1D convolution (local context): groups=d_inner means each
        # channel is convolved independently with its own length-d_conv kernel.
        self.conv1d = nn.Conv1d(
            self.d_inner, self.d_inner,
            kernel_size=self.d_conv,
            padding=self.d_conv - 1,
            groups=self.d_inner,
        )
        # SSM parameters (input-dependent = "selective")
        # dt, B, C are functions of the input, not fixed
        # NOTE(review): dt_proj is full-rank (d_inner -> d_inner) and the
        # config's mamba_dt_rank ('auto') is never read; reference Mamba uses
        # a low-rank dt projection — confirm this is intentional.
        self.dt_proj = nn.Linear(self.d_inner, self.d_inner, bias=True)
        self.B_proj = nn.Linear(self.d_inner, self.d_state, bias=False)
        self.C_proj = nn.Linear(self.d_inner, self.d_state, bias=False)
        # A is fixed (not input-dependent) — stored in log space so that
        # A = -exp(A_log) in forward() is strictly negative (decaying state).
        # NOTE(review): randn init differs from the reference S4/Mamba init
        # (log of 1..d_state per channel) — verify against the original impl.
        self.A_log = nn.Parameter(torch.randn(self.d_inner, self.d_state))
        # D is a per-channel gain on the skip connection from the SSM input
        self.D = nn.Parameter(torch.ones(self.d_inner))
        # Output projection back to the model width
        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=False)

    def forward(self, x, cache=None):
        """
        x: [batch, seq_len, d_model]
        cache: (conv_state, ssm_state) for incremental generation

        During generation, this processes ONE token at a time
        with O(1) computation (vs O(seq_len) for attention).

        NOTE(review): new_cache is assembled at the bottom but never
        returned or stored, so callers can never obtain the updated
        state — the incremental-generation path appears unreachable
        as written. Returning it would change the call signature that
        JambaBlock.forward expects; needs a coordinated fix.
        """
        batch, seq_len, _ = x.shape
        # Project input and split into the SSM path and the gate path
        xz = self.in_proj(x)  # [batch, seq, 2*d_inner]
        x_path, z = xz.chunk(2, dim=-1)  # Each: [batch, seq, d_inner]
        # 1D convolution (local context mixing); Conv1d wants channels-first
        x_path = x_path.transpose(1, 2)  # [batch, d_inner, seq]
        if cache is not None and cache[0] is not None:
            # Incremental: prepend the cached final d_conv-1 input columns.
            # NOTE(review): conv1d still pads both sides by d_conv-1 here, and
            # slicing [..., -seq_len:] keeps outputs that overlap the RIGHT
            # padding — this looks non-causal on the cached path (the fresh
            # branch below slices [..., :seq_len] instead). Verify.
            x_path = torch.cat([cache[0], x_path], dim=-1)
            conv_out = self.conv1d(x_path)[..., -seq_len:]
            new_conv_state = x_path[..., -(self.d_conv - 1):]
        else:
            # Fresh sequence: left padding of d_conv-1 plus this slice
            # yields a causal depthwise convolution.
            conv_out = self.conv1d(x_path)[..., :seq_len]
            new_conv_state = x_path[..., -(self.d_conv - 1):]
        x_path = conv_out.transpose(1, 2)  # [batch, seq, d_inner]
        x_path = F.silu(x_path)
        # Selective SSM: A fixed and negative; dt, B, C derived per token
        A = -torch.exp(self.A_log)  # [d_inner, d_state]
        dt = F.softplus(self.dt_proj(x_path))  # [batch, seq, d_inner]; softplus keeps step sizes positive
        B = self.B_proj(x_path)  # [batch, seq, d_state]
        C = self.C_proj(x_path)  # [batch, seq, d_state]
        # SSM recurrence: resume from cached state if present, else zeros
        if cache is not None and cache[1] is not None:
            ssm_state = cache[1]  # [batch, d_inner, d_state]
        else:
            ssm_state = torch.zeros(
                batch, self.d_inner, self.d_state,
                device=x.device, dtype=x.dtype
            )
        # Sequential per-token scan.
        # NOTE(review): O(seq_len) Python loop — presumably a reference
        # implementation; production Mamba uses a fused selective-scan kernel.
        outputs = []
        for t in range(seq_len):
            # Discretize A (zero-order hold): A_bar = exp(A * dt)
            dt_t = dt[:, t, :].unsqueeze(-1)  # [batch, d_inner, 1]
            A_bar = torch.exp(A.unsqueeze(0) * dt_t)  # [batch, d_inner, d_state]
            B_t = B[:, t, :].unsqueeze(1)  # [batch, 1, d_state]
            x_t = x_path[:, t, :].unsqueeze(-1)  # [batch, d_inner, 1]
            # State update: h_t = A_bar * h_{t-1} + B_bar * x_t
            # NOTE(review): the input term uses B_t without a dt factor;
            # standard discretization scales it (B_bar ≈ dt * B) — confirm.
            ssm_state = A_bar * ssm_state + x_t * B_t
            # Output: y_t = C_t * h_t (contraction over the state dimension)
            C_t = C[:, t, :].unsqueeze(1)  # [batch, 1, d_state]
            y_t = (ssm_state * C_t).sum(dim=-1)  # [batch, d_inner]
            # Skip connection: D passes the post-conv input straight through
            y_t = y_t + self.D * x_path[:, t, :]
            outputs.append(y_t)
        y = torch.stack(outputs, dim=1)  # [batch, seq, d_inner]
        y = y * F.silu(z)  # Gate with z path (SiLU-gated, GLU-style)
        output = self.out_proj(y)
        new_cache = (new_conv_state, ssm_state)  # NOTE(review): dead — never returned
        return output
Memory Efficiency: Mamba vs Attention
def compare_kv_cache_vs_ssm_state():
    """
    Compare per-layer inference state: attention's KV cache grows linearly
    with context length, while Mamba's SSM + conv state is constant size.

    Prints one line per context length (sizes in decimal MB) and returns
    the raw numbers as a list of (context, kv_bytes, mamba_bytes) tuples
    so callers can reuse them programmatically.

    Fix vs. previous revision: the transcribed example output mixed binary
    MiB (for KV) with decimal MB (for Mamba), producing ratios (51x, 206x,
    823x, 3293x) that the code never prints; the comments below now match
    actual execution.
    """
    num_kv_heads = 8   # GQA: KV heads, not query heads
    head_dim = 128
    d_inner = 8192     # Mamba inner width (hidden_size * expand)
    d_state = 16       # SSM state dim per channel
    # Mamba per-layer state is constant regardless of context:
    # SSM state + the cached conv window (d_conv - 1 = 3 columns), FP16.
    ssm_bytes = d_inner * d_state * 2
    conv_bytes = d_inner * 3 * 2
    mamba_bytes = ssm_bytes + conv_bytes
    contexts = [1024, 4096, 16384, 65536, 262144]
    rows = []
    for ctx in contexts:
        # Attention KV cache per layer: K and V, FP16 (2 bytes/element).
        kv_bytes = 2 * num_kv_heads * head_dim * ctx * 2
        rows.append((ctx, kv_bytes, mamba_bytes))
        ratio = kv_bytes / mamba_bytes
        print(f"Context {ctx:>7,}: "
              f"Attention KV={kv_bytes / 1e6:>8.1f}MB | "
              f"Mamba state={mamba_bytes / 1e6:>5.3f}MB | "
              f"Ratio={ratio:>6.0f}x")
    return rows
# Context   1,024: Attention KV=     4.2MB | Mamba state=0.311MB | Ratio=    13x
# Context   4,096: Attention KV=    16.8MB | Mamba state=0.311MB | Ratio=    54x
# Context  16,384: Attention KV=    67.1MB | Mamba state=0.311MB | Ratio=   216x
# Context  65,536: Attention KV=   268.4MB | Mamba state=0.311MB | Ratio=   862x
# Context 262,144: Attention KV=  1073.7MB | Mamba state=0.311MB | Ratio=  3449x
Per-Layer State Memory: Attention vs Mamba
| Context Length | Attention KV Cache | Mamba State | Savings Ratio | Impact at 32 Layers |
|---|---|---|---|---|
| 1K | 4.0 MB | 0.30 MB | 13x | 118 MB saved |
| 4K | 16.0 MB | 0.30 MB | 54x | 502 MB saved |
| 16K | 64.0 MB | 0.30 MB | 216x | 2.0 GB saved |
| 64K | 256.0 MB | 0.30 MB | 862x | 8.0 GB saved |
| 256K | 1,024.0 MB | 0.30 MB | 3449x | 32.0 GB saved |
At 256K context, a single Mamba layer stores 0.31 MB of state versus 1,024 MB of KV cache for an attention layer. In Jamba’s hybrid design with 21 Mamba layers and 11 attention layers, this means 21 layers have negligible state memory while only 11 layers need KV cache. The total KV cache is reduced by 65% compared to a pure-attention model with the same layer count.
Jamba’s Hybrid Memory Budget
def jamba_memory_analysis(context_length=256000):
    """
    Analyze Jamba's total inference-state footprint with hybrid layers,
    versus a pure-attention baseline of the same 32-layer depth.

    Prints a summary (decimal GB/MB) and returns the raw byte counts as a
    dict with keys 'attn_kv_total', 'mamba_total', 'pure_attn_kv'.

    Fix vs. previous revision: the transcribed example output (11.26 GB /
    32.77 GB) did not match what the code computes at the default
    context_length=256,000; the comments below now match actual execution.
    """
    # Jamba layer split: 11 attention + 21 Mamba.
    num_attn_layers = 11
    num_mamba_layers = 21
    # Attention layers need a KV cache: K + V, 8 GQA heads x 128 dims, FP16.
    num_kv_heads = 8
    head_dim = 128
    kv_per_token_per_layer = 2 * num_kv_heads * head_dim * 2  # 4096 bytes
    attn_kv_total = kv_per_token_per_layer * context_length * num_attn_layers
    # Mamba layers carry a fixed-size state regardless of context:
    # SSM state + conv window (d_conv - 1 = 3 columns), FP16.
    d_inner = 8192
    d_state = 16
    mamba_state_per_layer = d_inner * d_state * 2 + d_inner * 3 * 2
    mamba_total = mamba_state_per_layer * num_mamba_layers
    # Pure-attention baseline: all 32 layers keep a KV cache.
    pure_attn_kv = kv_per_token_per_layer * context_length * 32
    print(f"Context length: {context_length:,}")
    print(f"Jamba attention KV: {attn_kv_total/1e9:.2f} GB ({num_attn_layers} layers)")
    print(f"Jamba Mamba state: {mamba_total/1e6:.2f} MB ({num_mamba_layers} layers)")
    print(f"Jamba total state: {(attn_kv_total + mamba_total)/1e9:.2f} GB")
    print(f"Pure attention KV: {pure_attn_kv/1e9:.2f} GB (32 layers)")
    print(f"Savings: {(1 - (attn_kv_total + mamba_total)/pure_attn_kv)*100:.0f}%")
    return {
        'attn_kv_total': attn_kv_total,
        'mamba_total': mamba_total,
        'pure_attn_kv': pure_attn_kv,
    }
# At context_length=256,000:
# Jamba attention KV: 11.53 GB (11 layers)
# Jamba Mamba state: 6.54 MB (21 layers)
# Jamba total state: 11.54 GB
# Pure attention KV: 33.55 GB (32 layers)
# Savings: 66%
Inference State Memory at 256K Context
Layer Interleaving Design
def analyze_interleaving_patterns():
    """
    Compare Mamba-vs-attention interleaving patterns.

    AI21 settled on a 1:2 attention:Mamba ratio after experimentation.
    Prints a one-line summary per pattern and returns the pattern dict.

    Fix vs. previous revision: the 'jamba_default' pattern string was
    'MAAMM...' (50% attention), contradicting its own attn_ratio of 11/32
    and JambaConfig's placement (attention at layers 0, 3, 6, ...);
    'sparse_attention' was likewise irregular. Pattern strings now match
    their stated ratios.
    """
    patterns = {
        'jamba_default': {
            # A=Attention, M=Mamba; attention at layer indices 0, 3, 6, ...
            'pattern': 'AMMAMMAMMAMMAMM...',
            'attn_ratio': 11/32,
            'description': 'Attention every 3rd layer',
        },
        'all_attention': {
            'pattern': 'AAAAAAAAAA...',
            'attn_ratio': 1.0,
            'description': 'Standard transformer',
        },
        'all_mamba': {
            'pattern': 'MMMMMMMMMM...',
            'attn_ratio': 0.0,
            'description': 'Pure Mamba (struggles with in-context learning)',
        },
        'alternating': {
            'pattern': 'MAMAMAMAMA...',
            'attn_ratio': 0.5,
            'description': '1:1 ratio — more memory but better quality',
        },
        'sparse_attention': {
            'pattern': 'MMMMAMMMMAMMMMA...',
            'attn_ratio': 0.2,
            'description': 'Attention only every 5th layer',
        },
    }
    for name, info in patterns.items():
        print(f"{name:20s}: attn_ratio={info['attn_ratio']:.1%}, "
              f"{info['description']}")
    return patterns
Layer Interleaving Impact (52B scale, 256K context)
| Pattern | Attn Ratio | MMLU | Long-Context Recall | State Memory (256K) | Speed (tok/s) |
|---|---|---|---|---|---|
| Pure Transformer | 100% | 72.1 | 98% | 32.8 GB | 28 |
| 1:1 Alternating | 50% | 71.5 | 96% | 16.4 GB | 42 |
| Jamba (1:2) | 34% | 70.8 | 94% | 11.3 GB | 51 |
| Sparse (1:5) | 20% | 68.2 | 87% | 6.6 GB | 58 |
| Pure Mamba | 0% | 62.4 | 72% | 0.007 GB | 85 |
The 1:2 ratio (Jamba’s choice) represents a strong quality-efficiency tradeoff: only 1.3 MMLU points below a pure transformer but 2.9x less state memory and 1.8x higher throughput. Pure Mamba suffers significantly on quality, particularly on tasks requiring long-range exact recall (like needle-in-haystack), where attention excels.
MoE Integration
class JambaMoEConfig:
    """
    MoE placement within Jamba's FFN sub-layers.

    Only a subset of FFN sub-layers are sparse: 11 of the 32 layers
    (every third model layer, starting at index 1) use MoE, while the
    remaining 21 use dense FFNs — roughly 35% of FFN sub-layers are MoE.
    With 16 experts and top-2 routing, about 12B of the ~52B total
    parameters are active per token.

    Design interplay: Mamba layers provide cheap sequence mixing and MoE
    layers provide cheap channel mixing — both save compute relative to
    their dense alternatives.
    """
    moe_expert_count = 16   # experts per MoE layer
    moe_top_k = 2           # experts routed per token
    moe_ffn_dim = 14336     # hidden dim of each expert FFN
    dense_ffn_dim = 14336   # dense FFN matches a single expert's size
Jamba vs Comparable Models
| Model | Total Params | Active Params | MMLU | HellaSwag | Context | tok/s (bs=1) |
|---|---|---|---|---|---|---|
| Jamba 52B | 52B | 12B | 70.8 | 87.1 | 256K | 51 |
| Mixtral 8x7B | 47B | 13B | 70.6 | 86.5 | 32K | 48 |
| Llama 2 70B | 70B | 70B | 69.8 | 85.3 | 4K | 18 |
| Mamba 2.8B | 2.8B | 2.8B | 46.2 | 72.4 | Unlimited* | 120 |
| Command R 35B | 35B | 35B | 68.4 | 84.7 | 128K | 32 |
Generation Performance
def jamba_generation_performance():
    """
    Break down per-token decode compute for Jamba by layer family.

    Generation-speed benefits come from three sources:
      1. Mamba layers: O(1) per token (no KV-cache lookup).
      2. Fewer attention layers: a smaller KV cache to attend over.
      3. MoE: only 2 of 16 experts computed per token.

    Returns a dict keyed by layer family with per-layer FLOP estimates.
    """
    hidden = 4096        # model width
    d_inner = 8192       # Mamba inner width (hidden * expand)
    kv_width = 1024      # 8 KV heads x 128 head dim
    ffn_dim = 14336      # FFN / expert hidden dim
    flops_per_mac = 2    # each multiply-accumulate counts as 2 FLOPs
    per_token = {
        'mamba_layers': {
            'count': 21,
            # State update + output projection
            'flops_per_layer': 2 * hidden * d_inner * flops_per_mac,
            'no_kv_lookup': True,
        },
        'attention_layers': {
            'count': 11,
            'flops_per_layer': {
                'qkv_proj': hidden * (hidden + kv_width + kv_width) * flops_per_mac,
                'attention': 'varies with context',  # O(seq_len) per decoded token
                'output_proj': hidden * hidden * flops_per_mac,
            },
        },
        'dense_ffn_layers': {
            'count': 21,
            # Gated FFN: three hidden x ffn_dim matmuls
            'flops_per_layer': 3 * hidden * ffn_dim * flops_per_mac,
        },
        'moe_ffn_layers': {
            'count': 11,
            'active_experts': 2,
            # Two experts, each the size of one dense FFN
            'flops_per_layer': 2 * 3 * hidden * ffn_dim * flops_per_mac,
        },
    }
    return per_token
Generation Speed vs Context Length (batch=1, A100)
At 256K context, Jamba generates at 38 tokens/s versus 4 tokens/s for a comparable pure transformer. The 9.5x speedup comes from: (1) Mamba layers needing no KV cache lookup, and (2) only 11 attention layers scanning the full context versus 32 in the transformer. Jamba’s generation speed degrades gracefully with context length, while transformers degrade quadratically.
Jamba established the hybrid Mamba-Attention architecture as a practical alternative to pure transformers. The key design principle is straightforward: use Mamba for the bulk of sequence processing (cheap, O(1) per generation step) and insert attention layers at regular intervals to maintain quality on tasks requiring global token interaction. The addition of MoE to a subset of FFN layers further improves the parameter-to-compute ratio. The result is a model that matches Mixtral-class quality at significantly lower serving cost, particularly for long-context workloads.