Part of Series Frontier Model Architectures 20 of 27

Jamba replaces two-thirds of its transformer layers with Mamba state-space layers, whose cost grows as O(n) in sequence length instead of O(n^2). The result: 12x faster prefill at 256K context length compared to pure attention, with only a 2-3 point drop on reasoning benchmarks. The architecture is a bet that most tokens do not need full attention: cheap recurrence handles sequential dependencies, and expensive attention fires only when global context matters. If Jamba proves Mamba quality at scale, every frontier lab will adopt hybrid architectures within 18 months.

Architecture Design

import torch
import torch.nn as nn
import torch.nn.functional as F

class JambaConfig:
    """Jamba architecture configuration."""
    hidden_size = 4096
    num_layers = 32              # Total layers
    attention_layers = [0, 3, 6, 9, 12, 15, 18, 21, 24, 27, 30]  # 11 attn layers
    mamba_layers = [1, 2, 4, 5, 7, 8, 10, 11, 13, 14, 16, 17,
                    19, 20, 22, 23, 25, 26, 28, 29, 31]           # 21 mamba layers

    # Attention config (for attention layers)
    num_attention_heads = 32
    num_key_value_heads = 8      # GQA
    head_dim = 128

    # Mamba config (for mamba layers)
    mamba_d_state = 16           # SSM state dimension
    mamba_d_conv = 4             # Local convolution width
    mamba_expand = 2             # Expansion factor
    mamba_dt_rank = 'auto'       # Delta rank

    # MoE config (applied to some layers)
    num_experts = 16
    num_experts_per_tok = 2
    moe_layers = [1, 4, 7, 10, 13, 16, 19, 22, 25, 28, 31]  # Every 3rd layer (the Mamba layer right after each attention layer)

    # Total params: ~52B (12B active per token)
    vocab_size = 65536

class JambaBlock(nn.Module):
    """
    A Jamba block can be either Mamba-based or Attention-based,
    optionally with MoE for the FFN.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx
        self.use_attention = layer_idx in config.attention_layers
        self.use_moe = layer_idx in config.moe_layers

        # Layer norm
        self.input_norm = nn.RMSNorm(config.hidden_size)
        self.post_ffn_norm = nn.RMSNorm(config.hidden_size)

        # Sequence mixing: either Mamba or Attention
        if self.use_attention:
            self.sequence_mixer = TransformerAttention(config)
        else:
            self.sequence_mixer = MambaLayer(config)

        # Channel mixing: either dense FFN or MoE
        if self.use_moe:
            self.channel_mixer = MoEFFN(config)
        else:
            self.channel_mixer = DenseFFN(config)

    def forward(self, x, cache=None):
        # Sequence mixing (Mamba or Attention)
        residual = x
        x = self.input_norm(x)
        x = self.sequence_mixer(x, cache=cache)
        x = residual + x

        # Channel mixing (Dense or MoE FFN)
        residual = x
        x = self.post_ffn_norm(x)
        x = self.channel_mixer(x)
        x = residual + x

        return x
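To make the interleaving concrete, the sketch below derives each layer's mixer and FFN type from the index lists in the configuration above. The helper name `describe_layers` and the single-letter codes are illustrative assumptions, not part of any released code.

```python
# Layer layout copied from the JambaConfig shown in this article.
NUM_LAYERS = 32
ATTENTION_LAYERS = set(range(0, NUM_LAYERS, 3))   # 0, 3, ..., 30
MOE_LAYERS = set(range(1, NUM_LAYERS, 3))         # 1, 4, ..., 31


def describe_layers(num_layers=NUM_LAYERS):
    """Return one (index, mixer, ffn) tuple per layer (hypothetical helper)."""
    schedule = []
    for idx in range(num_layers):
        mixer = "attention" if idx in ATTENTION_LAYERS else "mamba"
        ffn = "moe" if idx in MOE_LAYERS else "dense"
        schedule.append((idx, mixer, ffn))
    return schedule


schedule = describe_layers()
# One letter per layer: A = attention, M = Mamba.
pattern = "".join("A" if mixer == "attention" else "M" for _, mixer, _ in schedule)
num_attn = sum(mixer == "attention" for _, mixer, _ in schedule)
num_moe = sum(ffn == "moe" for _, _, ffn in schedule)

print(pattern)
print(f"attention={num_attn}, mamba={NUM_LAYERS - num_attn}, moe={num_moe}")
```

With these index lists every MoE FFN sits in a Mamba layer, so the attention layers always use the dense FFN.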

Mamba Layer Implementation

class MambaLayer(nn.Module):
    """
    Selective State-Space Model (Mamba) layer.

    Key property: O(n) in sequence length during generation,
    versus O(n^2) for attention. Maintains a fixed-size state
    that is updated per token, rather than a growing KV cache.
    """

    def __init__(self, config):
        super().__init__()
        self.d_model = config.hidden_size
        self.d_state = config.mamba_d_state       # 16
        self.d_conv = config.mamba_d_conv         # 4
        self.expand = config.mamba_expand         # 2
        self.d_inner = self.d_model * self.expand  # 8192

        # Input projection: [d_model] -> [2 * d_inner]
        # Split into x and z paths
        self.in_proj = nn.Linear(self.d_model, 2 * self.d_inner, bias=False)

        # 1D convolution (local context)
        self.conv1d = nn.Conv1d(
            self.d_inner, self.d_inner,
            kernel_size=self.d_conv,
            padding=self.d_conv - 1,
            groups=self.d_inner,
        )

        # SSM parameters (input-dependent = "selective")
        # dt, B, C are functions of the input, not fixed
        self.dt_proj = nn.Linear(self.d_inner, self.d_inner, bias=True)
        self.B_proj = nn.Linear(self.d_inner, self.d_state, bias=False)
        self.C_proj = nn.Linear(self.d_inner, self.d_state, bias=False)

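        # A is fixed (not input-dependent); stored as log(-A) for stability.
        # Initialized to log(1..d_state) per channel, as in the reference Mamba code,
        # so that A = -exp(A_log) starts with well-separated decay rates.
        self.A_log = nn.Parameter(
            torch.log(torch.arange(1, self.d_state + 1, dtype=torch.float32))
            .repeat(self.d_inner, 1)
        )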

        # D is a skip connection parameter
        self.D = nn.Parameter(torch.ones(self.d_inner))

        # Output projection
        self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=False)

    def forward(self, x, cache=None):
        """
        x: [batch, seq_len, d_model]
        cache: (conv_state, ssm_state) for incremental generation

        During generation, this processes ONE token at a time
        with O(1) computation (vs O(seq_len) for attention).
        """
        batch, seq_len, _ = x.shape

        # Project input
        xz = self.in_proj(x)  # [batch, seq, 2*d_inner]
        x_path, z = xz.chunk(2, dim=-1)  # Each: [batch, seq, d_inner]

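        # 1D convolution (local context mixing)
        x_path = x_path.transpose(1, 2)  # [batch, d_inner, seq]
        if cache is not None and cache[0] is not None:
            # Incremental: prepend the last (d_conv - 1) cached inputs
            x_path = torch.cat([cache[0], x_path], dim=-1)
            # conv1d pads both sides. Skip the first (d_conv - 1) outputs, which
            # end on cached inputs, and drop the tail that sees right padding.
            start = self.d_conv - 1
            conv_out = self.conv1d(x_path)[..., start:start + seq_len]
        else:
            # Causal: the first seq_len outputs only see left padding
            conv_out = self.conv1d(x_path)[..., :seq_len]
        # Keep the last (d_conv - 1) inputs; left-pad with zeros for short prompts
        new_conv_state = F.pad(x_path, (self.d_conv - 1, 0))[..., -(self.d_conv - 1):]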

        x_path = conv_out.transpose(1, 2)  # [batch, seq, d_inner]
        x_path = F.silu(x_path)

        # Selective SSM
        A = -torch.exp(self.A_log)  # [d_inner, d_state]
        dt = F.softplus(self.dt_proj(x_path))  # [batch, seq, d_inner]
        B = self.B_proj(x_path)     # [batch, seq, d_state]
        C = self.C_proj(x_path)     # [batch, seq, d_state]

        # SSM recurrence
        if cache is not None and cache[1] is not None:
            ssm_state = cache[1]  # [batch, d_inner, d_state]
        else:
            ssm_state = torch.zeros(
                batch, self.d_inner, self.d_state,
                device=x.device, dtype=x.dtype
            )

        outputs = []
        for t in range(seq_len):
            # Discretize A and B
            dt_t = dt[:, t, :].unsqueeze(-1)       # [batch, d_inner, 1]
            A_bar = torch.exp(A.unsqueeze(0) * dt_t)  # [batch, d_inner, d_state]
            B_t = B[:, t, :].unsqueeze(1)           # [batch, 1, d_state]
            x_t = x_path[:, t, :].unsqueeze(-1)     # [batch, d_inner, 1]

            # State update: h_t = A_bar * h_{t-1} + B_bar * x_t,
            # with B_bar = dt * B (simplified Euler discretization of B)
            ssm_state = A_bar * ssm_state + (dt_t * B_t) * x_t

            # Output: y_t = C_t * h_t
            C_t = C[:, t, :].unsqueeze(1)           # [batch, 1, d_state]
            y_t = (ssm_state * C_t).sum(dim=-1)     # [batch, d_inner]

            # Skip connection
            y_t = y_t + self.D * x_path[:, t, :]

            outputs.append(y_t)

        y = torch.stack(outputs, dim=1)  # [batch, seq, d_inner]
        y = y * F.silu(z)  # Gate with z path
        output = self.out_proj(y)

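        # Write the updated state back so the next call can resume from it.
        # Assumes the caller passes a mutable two-element list
        # [conv_state, ssm_state]; an immutable tuple is left untouched.
        if isinstance(cache, list):
            cache[0], cache[1] = new_conv_state, ssm_state
        return output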
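The practical consequence of the fixed-size state is that a prompt can be processed in pieces: running the recurrence over a prefix, keeping the state, and continuing later gives the same outputs as one pass over the whole sequence. The sketch below checks this for a single-channel selective SSM in plain Python. The function name `selective_scan`, the toy inputs, and the use of `B_bar = dt * B` are assumptions made for illustration, not the optimized kernel a real Mamba implementation uses.

```python
import math


def selective_scan(xs, dts, Bs, Cs, A, state=None):
    """Run a single-channel selective SSM over a sequence.

    xs, dts: per-token input value and step size (lists of floats)
    Bs, Cs:  per-token projections, each a list of d_state floats
    A:       d_state negative decay rates (fixed, not input-dependent)
    state:   d_state floats carried over from a previous call, or None
    """
    d_state = len(A)
    h = list(state) if state is not None else [0.0] * d_state
    ys = []
    for x, dt, B, C in zip(xs, dts, Bs, Cs):
        for n in range(d_state):
            a_bar = math.exp(A[n] * dt)      # discretized decay
            b_bar = dt * B[n]                # discretized input gain
            h[n] = a_bar * h[n] + b_bar * x  # state update
        ys.append(sum(C[n] * h[n] for n in range(d_state)))
    return ys, h


# Toy inputs (hypothetical values); B and C change per token, which is
# what makes the SSM "selective".
A = [-1.0, -2.0]
xs = [0.5, -1.0, 2.0, 0.25, 1.5, -0.75]
dts = [0.1, 0.2, 0.05, 0.3, 0.1, 0.2]
Bs = [[1.0, 0.5], [0.8, -0.3], [0.2, 0.9], [1.1, 0.0], [-0.4, 0.6], [0.7, 0.7]]
Cs = [[0.3, -0.2], [0.5, 0.1], [-0.6, 0.4], [0.2, 0.2], [0.9, -0.1], [0.0, 1.0]]

# One pass over all six tokens.
full_ys, full_state = selective_scan(xs, dts, Bs, Cs, A)

# Same tokens, split into a 4-token prefill and a 2-token continuation.
ys_a, state_a = selective_scan(xs[:4], dts[:4], Bs[:4], Cs[:4], A)
ys_b, state_b = selective_scan(xs[4:], dts[4:], Bs[4:], Cs[4:], A, state=state_a)

max_diff = max(abs(p - q) for p, q in zip(full_ys, ys_a + ys_b))
print(f"max difference between full and chunked outputs: {max_diff:.2e}")
```

The carried state has `d_state` entries per channel regardless of how many tokens came before, which is why the memory comparison in the next section comes out constant for Mamba.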

Memory Efficiency: Mamba vs Attention

def compare_kv_cache_vs_ssm_state():
    """
    The key advantage of Mamba layers: fixed-size state vs growing KV cache.
    """
    hidden_dim = 4096
    num_kv_heads = 8
    head_dim = 128
    d_inner = 8192
    d_state = 16

    contexts = [1024, 4096, 16384, 65536, 262144]

    for ctx in contexts:
        # Attention KV cache per layer
        kv_bytes = 2 * num_kv_heads * head_dim * ctx * 2  # K+V, FP16
        kv_mb = kv_bytes / 1e6

        # Mamba SSM state per layer (constant!)
        ssm_bytes = d_inner * d_state * 2  # FP16
        conv_bytes = d_inner * 3 * 2       # Conv state
        mamba_bytes = ssm_bytes + conv_bytes
        mamba_mb = mamba_bytes / 1e6

        ratio = kv_bytes / mamba_bytes

        print(f"Context {ctx:>7,}: "
              f"Attention KV={kv_mb:>8.1f}MB | "
              f"Mamba state={mamba_mb:>5.3f}MB | "
              f"Ratio={ratio:>6.0f}x")

# compare_kv_cache_vs_ssm_state() prints:
# Context   1,024: Attention KV=     4.2MB | Mamba state=0.311MB | Ratio=    13x
# Context   4,096: Attention KV=    16.8MB | Mamba state=0.311MB | Ratio=    54x
# Context  16,384: Attention KV=    67.1MB | Mamba state=0.311MB | Ratio=   216x
# Context  65,536: Attention KV=   268.4MB | Mamba state=0.311MB | Ratio=   862x
# Context 262,144: Attention KV=  1073.7MB | Mamba state=0.311MB | Ratio=  3449x

Per-Layer State Memory: Attention vs Mamba

| Context Length | Attention KV Cache | Mamba State | Savings Ratio | Savings Across 32 Layers |
|---|---|---|---|---|
| 1K | 4.2 MB | 0.31 MB | 13x | 124 MB |
| 4K | 16.8 MB | 0.31 MB | 54x | 527 MB |
| 16K | 67.1 MB | 0.31 MB | 216x | 2.1 GB |
| 64K | 268.4 MB | 0.31 MB | 862x | 8.6 GB |
| 256K | 1,073.7 MB | 0.31 MB | 3,449x | 34.3 GB |

Performance

At 256K context (262,144 tokens, FP16), a single Mamba layer stores 0.31 MB of state versus about 1,074 MB of KV cache for an attention layer. In Jamba’s hybrid design with 21 Mamba layers and 11 attention layers, this means 21 layers have negligible state memory while only 11 layers need KV cache. The total KV cache is reduced by about 66% (21 of 32 layers no longer hold one) compared to a pure-attention model with the same layer count.

Jamba’s Hybrid Memory Budget

def jamba_memory_analysis(context_length=256000):
    """
    Analyze Jamba's total memory footprint with hybrid layers.
    """
    # Jamba: 11 attention layers + 21 Mamba layers
    num_attn_layers = 11
    num_mamba_layers = 21

    # Attention layers: need KV cache
    num_kv_heads = 8
    head_dim = 128
    kv_per_token_per_layer = 2 * num_kv_heads * head_dim * 2  # 4096 bytes
    attn_kv_total = kv_per_token_per_layer * context_length * num_attn_layers

    # Mamba layers: fixed state
    d_inner = 8192
    d_state = 16
    mamba_state_per_layer = d_inner * d_state * 2 + d_inner * 3 * 2
    mamba_total = mamba_state_per_layer * num_mamba_layers

    # Pure attention baseline (32 layers)
    pure_attn_kv = kv_per_token_per_layer * context_length * 32

    print(f"Context length: {context_length:,}")
    print(f"Jamba attention KV: {attn_kv_total/1e9:.2f} GB ({num_attn_layers} layers)")
    print(f"Jamba Mamba state:  {mamba_total/1e6:.2f} MB ({num_mamba_layers} layers)")
    print(f"Jamba total state:  {(attn_kv_total + mamba_total)/1e9:.2f} GB")
    print(f"Pure attention KV:  {pure_attn_kv/1e9:.2f} GB (32 layers)")
    print(f"Savings:            {(1 - (attn_kv_total + mamba_total)/pure_attn_kv)*100:.0f}%")

# jamba_memory_analysis(262_144) prints:
# Context length: 262,144
# Jamba attention KV: 11.81 GB (11 layers)
# Jamba Mamba state:  6.54 MB (21 layers)
# Jamba total state:  11.82 GB
# Pure attention KV:  34.36 GB (32 layers)
# Savings:            66%

Inference State Memory at 256K Context

| Configuration | State Memory (GB) |
|---|---|
| Pure Transformer (32 layers) | 34.4 |
| Jamba Hybrid (11 attn + 21 mamba) | 11.8 |
| Pure Mamba (32 layers) | 0.010 |
| Jamba attention-only portion | 11.81 |
| Jamba Mamba-only portion | 0.007 |

Layer Interleaving Design

def analyze_interleaving_patterns():
    """
    The ratio and pattern of Mamba vs Attention layers matters.
    AI21 settled on a 2:1 Mamba:Attention ratio after experimentation.
    """
    patterns = {
        'jamba_default': {
            'pattern': 'AMMAMMAMMAMMAMMA...',  # M=Mamba, A=Attention
            'attn_ratio': 11/32,
            'description': 'Attention every 3rd layer',
        },
        'all_attention': {
            'pattern': 'AAAAAAAAAA...',
            'attn_ratio': 1.0,
            'description': 'Standard transformer',
        },
        'all_mamba': {
            'pattern': 'MMMMMMMMMM...',
            'attn_ratio': 0.0,
            'description': 'Pure Mamba (struggles with in-context learning)',
        },
        'alternating': {
            'pattern': 'MAMAMAMAMA...',
            'attn_ratio': 0.5,
            'description': '1:1 ratio — more memory but better quality',
        },
        'sparse_attention': {
            'pattern': 'MMMMAMMMMAMMMMA...',
            'attn_ratio': 0.2,
            'description': 'Attention only every 5th layer',
        },
    }

    for name, info in patterns.items():
        print(f"{name:20s}: attn_ratio={info['attn_ratio']:.1%}, "
              f"{info['description']}")

Layer Interleaving Impact (52B scale, 256K context)

| Pattern | Attn Ratio | MMLU | Long-Context Recall | State Memory (256K) | Speed (tok/s) |
|---|---|---|---|---|---|
| Pure Transformer | 100% | 72.1 | 98% | 34.4 GB | 28 |
| 1:1 Alternating | 50% | 71.5 | 96% | 17.2 GB | 42 |
| Jamba (1:2) | 34% | 70.8 | 94% | 11.8 GB | 51 |
| Sparse (1:4) | 20% | 68.2 | 87% | 6.9 GB | 58 |
| Pure Mamba | 0% | 62.4 | 72% | 0.010 GB | 85 |

ℹ️ Note

The 1:2 ratio (Jamba’s choice) represents a strong quality-efficiency tradeoff: only 1.3 MMLU points below a pure transformer but 2.9x less state memory and 1.8x higher throughput. Pure Mamba suffers significantly on quality, particularly on tasks requiring long-range exact recall (like needle-in-haystack), where attention excels.
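The state-memory side of this tradeoff follows directly from the layer counts. As a rough sketch, the function below totals KV cache and Mamba state for any mix of layers, using the head and state sizes from the configuration earlier in this article. The name `state_memory_bytes` and the layout dictionary are illustrative assumptions; the figures cover one sequence in FP16 and ignore weights and activations.

```python
def state_memory_bytes(num_attn_layers, num_mamba_layers, context_length,
                       num_kv_heads=8, head_dim=128,
                       d_inner=8192, d_state=16, d_conv=4, bytes_per_elem=2):
    """Total inference state (KV cache + Mamba state) for one sequence."""
    # K and V for every cached token in each attention layer.
    kv_per_layer = 2 * num_kv_heads * head_dim * context_length * bytes_per_elem
    # SSM state plus the (d_conv - 1)-wide convolution buffer in each Mamba layer.
    mamba_per_layer = (d_inner * d_state + d_inner * (d_conv - 1)) * bytes_per_elem
    return num_attn_layers * kv_per_layer + num_mamba_layers * mamba_per_layer


# (attention layers, Mamba layers) for a 32-layer stack.
layouts = {
    "pure transformer": (32, 0),
    "1:1 alternating": (16, 16),
    "jamba-style": (11, 21),
    "pure mamba": (0, 32),
}

for name, (n_attn, n_mamba) in layouts.items():
    gb = state_memory_bytes(n_attn, n_mamba, context_length=262_144) / 1e9
    print(f"{name:17s} {gb:8.3f} GB")
```

Because the Mamba term does not depend on `context_length`, the total is dominated by the number of attention layers at any context long enough to matter.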

MoE Integration

class JambaMoEConfig:
    """
    Jamba applies MoE to a subset of the FFN layers.
    Not every layer uses MoE: only every 3rd layer (indices 1, 4, ..., 31).
    """
    # MoE layers: 11 out of 32 layers use MoE
    # Dense FFN layers: 21 out of 32
    # This means ~34% of FFN layers use MoE

    # With MoE: 16 experts, top-2 routing
    # Active params per token: ~12B (out of ~52B total)

    moe_expert_count = 16
    moe_top_k = 2
    moe_ffn_dim = 14336  # Per expert

    dense_ffn_dim = 14336  # Dense FFN (same size as single expert)

    # The interplay: Mamba layers provide cheap sequence mixing,
    # MoE layers provide cheap channel mixing.
    # Both save compute compared to their dense alternatives.
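The top-2 routing step can be sketched for a single token as follows. This is a minimal illustration under stated assumptions: the function name `top_k_route` and the toy logits are hypothetical, and renormalizing the selected weights after the top-k cut is one common variant rather than something this article confirms for Jamba.

```python
import math


def top_k_route(gate_logits, k=2):
    """Pick the k highest-scoring experts and renormalize their weights.

    gate_logits: one router score per expert for a single token.
    Returns a list of (expert_index, weight) pairs whose weights sum to 1.
    """
    # Numerically stable softmax over all experts.
    peak = max(gate_logits)
    exps = [math.exp(v - peak) for v in gate_logits]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Keep the k most probable experts.
    chosen = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]

    # Renormalize so the selected weights sum to 1 (assumed variant).
    mass = sum(probs[i] for i in chosen)
    return [(i, probs[i] / mass) for i in chosen]


# Toy router scores for 16 experts; experts 11 and 3 score highest.
logits = [0.1 * i for i in range(16)]
logits[3] = 2.5
logits[11] = 3.0

routes = top_k_route(logits, k=2)
for expert, weight in routes:
    print(f"expert {expert:2d}: weight {weight:.3f}")
```

Only the two selected expert FFNs run for this token, so the per-token FFN compute is 2/16 of what evaluating every expert would cost, while all 16 experts still count toward total parameters.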

Jamba vs Comparable Models

| Model | Total Params | Active Params | MMLU | HellaSwag | Context | tok/s (bs=1) |
|---|---|---|---|---|---|---|
| Jamba 52B | 52B | 12B | 70.8 | 87.1 | 256K | 51 |
| Mixtral 8x7B | 47B | 13B | 70.6 | 86.5 | 32K | 48 |
| Llama 2 70B | 70B | 70B | 69.8 | 85.3 | 4K | 18 |
| Mamba 2.8B | 2.8B | 2.8B | 46.2 | 72.4 | Unlimited* | 120 |
| Command R 35B | 35B | 35B | 68.4 | 84.7 | 128K | 32 |

Generation Performance

def jamba_generation_performance():
    """
    Jamba's generation speed benefits come from:
    1. Mamba layers: O(1) per token (no KV cache lookup)
    2. Fewer attention layers: smaller KV cache to attend over
    3. MoE: only 2 of 16 experts computed per token
    """
    # Per-token compute breakdown
    per_token = {
        'mamba_layers': {
            'count': 21,
            'flops_per_layer': 2 * 4096 * 8192 * 2,  # State update + output
            'no_kv_lookup': True,
        },
        'attention_layers': {
            'count': 11,
            'flops_per_layer': {
                'qkv_proj': 4096 * (4096 + 1024 + 1024) * 2,
                'attention': 'varies with context',  # O(seq_len)
                'output_proj': 4096 * 4096 * 2,
            },
        },
        'dense_ffn_layers': {
            'count': 21,
            'flops_per_layer': 3 * 4096 * 14336 * 2,
        },
        'moe_ffn_layers': {
            'count': 11,
            'active_experts': 2,
            'flops_per_layer': 2 * 3 * 4096 * 14336 * 2,  # 2 experts
        },
    }

    return per_token
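The `'varies with context'` entry can be made concrete. As a rough sketch, the code below counts the score and value-mixing FLOPs one attention layer spends per generated token, using the same 2-FLOPs-per-multiply-add convention as the dictionary above. The function name and the break-even calculation are illustrative assumptions, and the count ignores softmax and memory traffic.

```python
def attention_flops_per_token(context_length, num_heads=32, head_dim=128):
    """FLOPs for one new token to attend over `context_length` cached tokens."""
    qk_scores = 2 * num_heads * head_dim * context_length   # q . k for every cached key
    weighted_sum = 2 * num_heads * head_dim * context_length  # probabilities times values
    return qk_scores + weighted_sum


# Context-independent projection cost of the same layer (from the dict above).
qkv_proj = 4096 * (4096 + 1024 + 1024) * 2
output_proj = 4096 * 4096 * 2
projection_flops = qkv_proj + output_proj

# Context length at which the context-dependent part matches the projections.
break_even_context = projection_flops // attention_flops_per_token(1)

for ctx in (4_096, 65_536, 262_144):
    attn = attention_flops_per_token(ctx)
    print(f"ctx={ctx:>7,}: attention {attn / 1e9:6.2f} GFLOPs "
          f"vs projections {projection_flops / 1e9:.2f} GFLOPs per layer")
print(f"break-even context: {break_even_context:,} tokens")
```

Past the break-even point the context-dependent term dominates each attention layer, which is why reducing the number of attention layers matters more as the context grows.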

Generation Speed vs Context Length (batch=1, A100)

| Context | Jamba (tok/s) | Transformer (tok/s) |
|---|---|---|
| 4K | 51 | 28 |
| 64K | 45 | 14 |
| 256K | 38 | 4 |

Performance

At 256K context, Jamba generates at 38 tokens/s versus 4 tokens/s for a comparable pure transformer. The 9.5x speedup comes from two sources: (1) Mamba layers need no KV cache lookup, and (2) only 11 attention layers scan the full context versus 32 in the transformer. Jamba’s generation speed degrades gracefully with context length, while the transformer’s per-token cost keeps growing with the KV cache it must scan in every layer.

Jamba established the hybrid Mamba-Attention architecture as a practical alternative to pure transformers. The key design principle is straightforward: use Mamba for the bulk of sequence processing (cheap, O(1) per generation step) and insert attention layers at regular intervals to maintain quality on tasks requiring global token interaction. The addition of MoE to a subset of FFN layers further improves the parameter-to-compute ratio. The result is a model that matches Mixtral-class quality at significantly lower serving cost, particularly for long-context workloads.