The Google TPU is one of the few production AI accelerators designed from the start as a domain-specific architecture for matrix multiplication — not a general-purpose GPU adapted for deep learning. The TPU’s core compute unit, the Matrix Multiply Unit (MXU), is a 128x128 systolic array that computes a complete matrix product in a single pass through the array. There are no CUDA cores, no warp schedulers, no register files in the GPU sense. The programming model is fundamentally different: you write JAX or TensorFlow code, the XLA compiler converts it to HLO (High-Level Operations), and the TPU runtime maps HLO operations to the systolic array and on-chip memory.
Google has deployed TPUs at massive scale — TPU v4 pods contain 4,096 chips connected by the ICI (Inter-Chip Interconnect) in a 3D torus topology with 1.1 Exaflops aggregate compute. The TPU v5e (the inference-optimized variant) powers Google Search, YouTube recommendations, and Gemini serving.
This post covers the TPU hardware architecture from v2 through v5e/v5p, the MXU systolic array design, the memory hierarchy, the ICI interconnect topology, the XLA compilation pipeline, and a quantitative analysis of when TPUs outperform GPUs.
TPU Architecture Overview
TPU Generational Comparison
| Specification | TPU v2 | TPU v3 | TPU v4 | TPU v5e | TPU v5p |
|---|---|---|---|---|---|
| Release year | 2017 | 2018 | 2021 | 2023 | 2023 |
| Process | 16nm | 16nm | 7nm | TBD | TBD |
| MXU count (per chip) | 2 | 4 | 8 | 4 | 8 |
| MXU dimensions | 128×128 | 128×128 | 128×128 | 128×128 | 128×128 |
| BF16 Peak TFLOPS | 46 | 123 | 275 | 197 | 459 |
| INT8 Peak TOPS | 92 | 246 | 550 | 393 | 918 |
| HBM type | HBM2 | HBM2 | HBM2e | HBM2e | HBM2e |
| HBM capacity | 16 GB | 32 GB | 32 GB | 16 GB | 95 GB |
| HBM bandwidth | 600 GB/s | 900 GB/s | 1,200 GB/s | 820 GB/s | 2,765 GB/s |
| ICI bandwidth (per chip) | 496 Gbps | 656 Gbps | ~4.8 Tbps | ~1.6 Tbps | ~4.8 Tbps |
| Max pod size | 256 chips | 1024 chips | 4096 chips | 256 chips | 8960 chips |
| TDP per chip | ~250 W | ~250 W | ~175 W | ~100 W | ~250 W |
The Matrix Multiply Unit (MXU)
Systolic Array Architecture
The MXU is a 128x128 systolic array. Each cell in the array contains a multiply-accumulate (MAC) unit. Data flows through the array in two directions: one matrix enters from the left (row by row), the other enters from the top (column by column). Each cell multiplies the incoming values and adds the product to its running accumulation.
// Systolic array operation for C = A × B:
// A is 128×K, B is K×128, C is 128×128
//
// Data flow:
// A rows enter from the left, one row per cycle (shifted by 1 cycle per row)
// B columns enter from the top, one column per cycle (shifted by 1 cycle per column)
//
// At each cell (i, j):
// accumulator[i][j] += A[i, k] * B[k, j]
//
// After roughly K + 2×127 cycles (pipeline fill plus drain), the full 128×128 output C is computed
// Peak throughput: 128 × 128 × 2 = 32,768 FLOPs per cycle (per MXU)
// With 2 TensorCores of 4 MXUs each (TPU v4, 8 MXUs per chip):
// 8 × 32,768 = 262,144 BF16 FLOPs per cycle
// At ~1.05 GHz: 262,144 × 1.05 × 10^9 ≈ 275 TFLOPS BF16, matching the published figure
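The data flow above can be sanity-checked with a toy, cycle-level simulation in plain Python (a sketch of an output-stationary array, not TPU code): operands are skewed by one cycle per row/column, so cell (i, j) sees its k-th operand pair at cycle i + j + k.

```python
def systolic_matmul(A, B):
    """Toy output-stationary systolic array computing C = A x B.
    Inputs are skewed by one cycle per row/column, so cell (i, j)
    receives its k-th operand pair (A[i][k], B[k][j]) at cycle i + j + k."""
    n, K, m = len(A), len(A[0]), len(B[0])
    acc = [[0.0] * m for _ in range(n)]       # one accumulator per cell
    total_cycles = K + (n - 1) + (m - 1)      # K plus pipeline fill/drain
    for cycle in range(total_cycles):
        for i in range(n):
            for j in range(m):
                k = cycle - i - j             # operand index reaching this cell now
                if 0 <= k < K:
                    acc[i][j] += A[i][k] * B[k][j]
    return acc

# 2x3 @ 3x2 example; exact same schedule, just a smaller array
C = systolic_matmul([[1, 2, 3], [4, 5, 6]], [[7, 8], [9, 10], [11, 12]])
```

For n = m = 128 this model runs K + 254 cycles, so the fill/drain overhead is amortized once K is large.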
Why Systolic Arrays
Systolic arrays maximize data reuse and minimize memory bandwidth requirements:
// Comparison: GPU tensor core vs TPU MXU
// GPU tensor core (Hopper): 16×8×16 per warp
// Data loaded per operation: 16×16 + 16×8 = 384 BF16 values
// Output: 16×8 = 128 values
// FLOPs: 16 × 8 × 16 × 2 = 4,096
// Arithmetic intensity: 4,096 / (384 × 2 bytes) = 5.3 FLOPs/byte
// Multiple warps needed to build up to large matrix operations
// TPU MXU (128×128 systolic): 128×128×1 per cycle
// Data loaded per cycle: 128 (one A row) + 128 (one B column) = 256 BF16 values
// Output accumulates over K cycles
// FLOPs per cycle: 128 × 128 × 2 = 32,768
// Arithmetic intensity per cycle: 32,768 / (256 × 2 bytes) = 64 FLOPs/byte
// 12x higher arithmetic intensity than GPU tensor core per operation
// The MXU achieves high intensity because:
// Each value entering from the left is multiplied by ALL 128 values in its row
// Each value entering from the top is multiplied by ALL 128 values in its column
// One input value participates in 128 multiply-accumulate operations
Unlike GPU tensor cores (which software can invoke at various tile sizes via WMMA/WGMMA instructions), the MXU dimension is hardwired at 128x128. Operations smaller than 128x128 still pass through the full array — unused cells produce zeros. This means TPUs waste compute on problems that are not multiples of 128. For a 64x64 matrix multiply, the MXU operates at 25% efficiency (only 64x64 of the 128x128 cells produce useful results).
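The efficiency penalty for non-multiple shapes reduces to a ceiling division; a small helper (hypothetical, not a TPU API) makes the numbers concrete:

```python
import math

def mxu_utilization(m, n, tile=128):
    """Fraction of systolic-array cells doing useful work when an m x n
    operand is padded up to the next multiple of the tile size."""
    padded_m = math.ceil(m / tile) * tile
    padded_n = math.ceil(n / tile) * tile
    return (m * n) / (padded_m * padded_n)

print(f"{mxu_utilization(64, 64):.0%}")   # 25%: the case described above
print(f"{mxu_utilization(3, 128):.1%}")   # 2.3%: a 3-channel first conv layer
```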
Precision Support
// TPU MXU precision modes:
// BF16 × BF16 → FP32 accumulation (primary mode for training)
// INT8 × INT8 → INT32 accumulation (inference quantization)
// FP8 × FP8 → FP32 accumulation (TPU v5 and later)
//
// Notable: NO native FP16 support (unlike GPUs)
// BF16 (Brain Floating Point 16):
// 1 sign, 8 exponent, 7 mantissa bits
// Same dynamic range as FP32 (8-bit exponent)
// Less precision than FP16 (7 vs 10 mantissa bits)
// But better for training because the exponent range avoids underflow/overflow
//
// Google's choice of BF16 over FP16 was deliberate:
// FP16 requires loss scaling (dynamic exponent management)
// BF16 does not — training is simpler and more robust
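Because BF16 is just FP32 with the low 16 mantissa bits dropped, the format can be demonstrated with stdlib bit manipulation (a sketch using round-to-nearest-even truncation; NaN handling omitted):

```python
import struct

def fp32_to_bf16_bits(x):
    """Truncate an FP32 bit pattern to BF16, rounding the discarded
    low 16 bits to nearest-even (NaN handling omitted)."""
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    return ((bits + rounding_bias) >> 16) & 0xFFFF

def bf16_bits_to_fp32(b):
    """Widen a BF16 bit pattern back to FP32 by appending 16 zero bits."""
    return struct.unpack(">f", struct.pack(">I", (b & 0xFFFF) << 16))[0]

# Same dynamic range as FP32: 1e38 survives the round trip (FP16 tops out at 65,504)
big = bf16_bits_to_fp32(fp32_to_bf16_bits(1e38))
# ...but only 7 mantissa bits of precision: pi becomes 3.140625
pi_bf16 = bf16_bits_to_fp32(fp32_to_bf16_bits(3.14159265))
```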
TPU Memory Hierarchy
On-Chip Memory: VMEM and CMEM
Each TPU chip has two types of on-chip memory:
// VMEM (Vector Memory): ~32 MB per chip (varies by generation)
// - Serves as the register file and scratchpad for the VPU (Vector Processing Unit)
// - 128 lanes, 32-bit each → processes 128 elements per cycle
// - Used for non-matrix operations: activations, normalization, softmax
// - Bandwidth: ~tens of TB/s (on-chip SRAM)
// CMEM (Common Memory / Shared Memory): variable
// - Feeds the MXU systolic array
// - Stores matrix tiles being fed into the MXU
// - Double-buffered: one buffer feeds MXU while another is loaded from HBM
// HBM: 32-95 GB per chip
// - Stores model weights, activations, optimizer state
// - 820-2,765 GB/s per chip (depending on generation)
The VPU (Vector Processing Unit)
The VPU handles all non-matrix operations:
// VPU: 128-lane SIMD unit
// Operations: elementwise add, multiply, exp, tanh, ReLU, etc.
// Used for: layer normalization, softmax, activation functions, reductions
//
// VPU is separate from MXU — they can execute concurrently:
// MXU: computing A × B (matrix multiply)
// VPU: computing softmax(previous_layer_output) (vector operation)
// HBM controller: loading next layer's weights (DMA)
// All three operate in parallel, overlapping compute and memory access
TPU v4 Memory Hierarchy
| Level | Size | Bandwidth | Latency | Used For |
|---|---|---|---|---|
| MXU registers (systolic) | ~16 KB (accumulator) | Matches MXU clock | 0 cycles | Matrix accumulation |
| VMEM (vector scratchpad) | ~32 MB | ~10+ TB/s | ~5-10 cycles | Activations, normalization |
| CMEM (MXU scratchpad) | ~16 MB | ~10+ TB/s | ~5-10 cycles | Matrix tiles (A, B operands) |
| HBM2e | 32 GB | 1,200 GB/s | ~100-200 cycles | Weights, full activations |
| ICI (inter-chip) | N/A | ~600 GB/s | ~1-10 us | Cross-chip data exchange |
Inter-Chip Interconnect (ICI)
Topology: 3D Torus
TPU pods use a 3D torus topology — each chip connects to 6 neighbors (2 per dimension):
// TPU v4 pod: 4,096 chips in a 3D torus
// Dimensions: 16 × 16 × 16
// Each chip has 6 ICI links (one in each direction: +X, -X, +Y, -Y, +Z, -Z)
//
// 3D torus properties:
// - Diameter: O(N^(1/3)) hops for N chips
// 4,096 chips → at most 8 hops per dimension (wraparound halves the distance), 24 hops maximum
// - Bisection bandwidth: O(N^(2/3))
// 4,096 chips → cross-section bandwidth scales as 16^2 = 256 links
// Each link: ~600 Gbps → bisection BW: ~256 × 600 Gbps = ~19.2 TB/s
//
// vs NVSwitch all-to-all:
// NVSwitch: 1 hop between any pair, O(N) bisection bandwidth
// ICI torus: up to 24 hops, O(N^(2/3)) bisection bandwidth
// NVSwitch wins on per-pair bandwidth and latency
// ICI torus wins on scalability (4096+ chips without monster switches)
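The hop counts follow from per-dimension wraparound distance; a minimal sketch:

```python
def torus_hops(a, b, dims=(16, 16, 16)):
    """Minimum hops between chips a and b in a 3D torus: per dimension,
    take the shorter of the direct path and the wraparound path."""
    return sum(min(abs(x - y), d - abs(x - y)) for x, y, d in zip(a, b, dims))

assert torus_hops((0, 0, 0), (15, 0, 0)) == 1   # wraparound link
assert torus_hops((0, 0, 0), (8, 8, 8)) == 24   # worst case: the torus diameter
```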
Allreduce on 3D Torus
// Allreduce on 3D torus uses dimension-wise reduction:
// 1. Reduce along X dimension (16 chips per row)
// Each chip sends to its X neighbor, accumulates, passes forward
// Latency: O(16) hops, bandwidth: 1 ICI link
//
// 2. Reduce along Y dimension (16 chips per column)
// Same process along Y
//
// 3. Reduce along Z dimension (16 chips per plane)
// Same process along Z
//
// Total latency: O(3 × 16) = O(48) hops
// vs ring allreduce on flat topology: O(4096) hops
// The 3D decomposition reduces allreduce latency by ~85x
// For 1 GB allreduce across 4,096 v4 chips:
// Step 1 (X): 1 GB / 16 = 64 MB per chip, 16 steps at ~600 GB/s ≈ ~1.7 ms
// Step 2 (Y): 64 MB per chip, 16 steps ≈ ~1.7 ms
// Step 3 (Z): 64 MB per chip, 16 steps ≈ ~1.7 ms
// Total: ~5.1 ms for 4,096-chip allreduce
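The three-stage estimate can be packaged as a helper (a sketch using the simplified model above: a payload/16 shard per chip, 16 ring steps per dimension, ~600 GB/s usable per-chip ICI bandwidth, hop latency ignored):

```python
def torus_allreduce_ms(payload_gb, dim=16, bw_gb_per_s=600):
    """Hierarchical allreduce estimate on a dim^3 torus: each of the three
    dimension-wise stages moves a payload/dim shard per chip across `dim`
    ring steps at the per-chip ICI bandwidth."""
    shard_gb = payload_gb / dim
    stage_ms = dim * shard_gb / bw_gb_per_s * 1000
    return 3 * stage_ms

print(f"{torus_allreduce_ms(1.0):.1f} ms")  # 5.0 ms for 1 GB on a 16^3 torus
```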
Allreduce Time (1 GB payload) by System Scale

[Chart: allreduce time in ms across system scales; lower is better]

At small scale (8-64 chips), NVSwitch beats ICI on bandwidth and latency. At large scale (1,000+ chips), ICI’s torus topology provides more efficient allreduce than fat-tree InfiniBand networks because each dimension is reduced independently. The 3D torus also provides natural support for 3D parallelism (data × tensor × pipeline) — each parallelism dimension maps to one physical torus dimension.
XLA Compilation Pipeline
From JAX to TPU
The TPU has no user-facing ISA. You cannot write assembly or PTX for a TPU. All programming goes through the XLA (Accelerated Linear Algebra) compiler:
# Step 1: User writes JAX code
import jax
import jax.numpy as jnp
def attention(Q, K, V):
scores = jnp.matmul(Q, K.T) / jnp.sqrt(Q.shape[-1])
weights = jax.nn.softmax(scores, axis=-1)
return jnp.matmul(weights, V)
# Step 2: JAX traces the function to produce an HLO graph
# HLO (High-Level Operations): dot, reduce, broadcast, reshape, etc.
# attention → [dot(Q, K^T), div(_, sqrt(d)), softmax(_), dot(_, V)]
# Step 3: XLA optimizes the HLO graph
# - Operator fusion: merge elementwise ops into single kernels
# - Layout optimization: choose memory layouts (row-major, tiled)
# - Tiling: partition matrices for MXU-sized tiles (128x128)
# - Memory planning: schedule HBM ↔ VMEM/CMEM transfers
# - Parallelism: distribute across TPU cores via SPMD
# Step 4: XLA generates TPU machine code
# Low-level instructions for MXU, VPU, and DMA engine
# Explicit data movement between HBM and on-chip SRAMs
XLA Optimizations That Matter
# Operator fusion: XLA fuses chains of elementwise operations
# Before fusion:
# temp1 = matmul(Q, K) # HBM write
# temp2 = temp1 / sqrt(d) # HBM read, HBM write
# temp3 = softmax(temp2) # HBM read, HBM write
# out = matmul(temp3, V) # HBM read, HBM write
# 4 HBM reads + 4 HBM writes
# After fusion:
# temp1 = matmul(Q, K) # HBM write (MXU output)
# out = fused_softmax_matmul(temp1, V) # VPU + MXU, VMEM only
# 2 HBM reads + 2 HBM writes (2x reduction in HBM traffic)
# Layout optimization: choose tiled layout for MXU
# Default (row-major): A[i][j] at address i*N + j
# Tiled (128x128 tiles): A_tile[i//128][j//128][i%128][j%128]
# Tiled layout allows MXU to read contiguous data without gather
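The fusion savings grow quadratically with sequence length, since the avoided intermediates are seq_len × seq_len score matrices; a quick sizing helper (a sketch assuming BF16 and a single batch-head):

```python
def score_matrix_mb(seq_len, bytes_per_el=2):
    """Size of one seq_len x seq_len attention score matrix in MB (BF16)."""
    return seq_len * seq_len * bytes_per_el / 2**20

# each elementwise op that XLA fuses avoids one HBM round trip of this size
print(f"{score_matrix_mb(8192):.0f} MB")  # 128 MB at 8K context
```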
SPMD Partitioning
XLA automatically partitions computations across multiple TPU chips using SPMD (Single Program, Multiple Data) with explicit sharding annotations:
# Sharding annotation in JAX
from jax.sharding import NamedSharding, PartitionSpec, Mesh
# Create a mesh of 8 TPU chips
mesh = Mesh(jax.devices()[:8], axis_names=('data',))
# Shard input along batch dimension
sharding = NamedSharding(mesh, PartitionSpec('data', None))
x = jax.device_put(x, sharding)
# x is now distributed: each chip holds 1/8 of the batch
# JAX + XLA automatically:
# Partition the computation
# Insert ICI allreduce where gradients need to be synchronized
# Insert ICI all-gather where full tensors are needed
# Optimize the communication schedule to overlap with computation
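The effect of the inserted all-reduce can be mimicked on one host with plain lists (a toy sketch; the real XLA collective runs over ICI, typically as reduce-scatter plus all-gather):

```python
def allreduce_mean(grads_per_chip):
    """Data-parallel gradient sync: every chip ends up holding the mean
    of all per-chip gradients, which is what the all-reduce computes."""
    n = len(grads_per_chip)
    dim = len(grads_per_chip[0])
    mean = [sum(g[i] for g in grads_per_chip) / n for i in range(dim)]
    return [list(mean) for _ in range(n)]

# two chips, two parameters each
synced = allreduce_mean([[1.0, 2.0], [3.0, 4.0]])  # -> [[2.0, 3.0], [2.0, 3.0]]
```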
When TPUs Win
Matrix-Heavy Workloads with Regular Shapes
The 128x128 MXU achieves peak throughput when matrix dimensions are multiples of 128:
// Workload: BERT-Large fine-tuning, batch size 512
// Matrix dimensions: all multiples of 128 (hidden_size=1024, FFN=4096)
// MXU utilization: >85%
//
// TPU v4 (275 TFLOPS BF16): 275 × 0.85 = 233.75 TFLOPS effective
// H100 (990 TFLOPS BF16): needs >23.6% MFU to match
// H100 typical MFU for BERT: 35-45% → 346-445 TFLOPS effective
// H100 wins on raw throughput, BUT:
// TPU v4 at $3.22/hr vs H100 at $11+/hr
// Cost-adjusted: TPU v4 wins for BERT training by ~2x
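The cost claim reduces to effective TFLOPS per dollar (a sketch using the figures above; cloud prices and achieved MFU vary):

```python
def tflops_per_dollar(peak_tflops, mfu, dollars_per_hr):
    """Effective sustained TFLOPS per dollar-hour of rental."""
    return peak_tflops * mfu / dollars_per_hr

tpu_v4 = tflops_per_dollar(275, 0.85, 3.22)   # ~72.6 TFLOPS/$
h100   = tflops_per_dollar(990, 0.45, 11.00)  # ~40.5 TFLOPS/$ (best-case MFU)
print(f"cost-adjusted TPU advantage: {tpu_v4 / h100:.1f}x")  # ~1.8x
```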
Pod-Scale Training
For training that requires thousands of accelerators, ICI provides an advantage:
// Scenario: training a 540B parameter PaLM model
// Google used 6,144 TPU v4 chips (two full pods)
// All-reduce across 6,144 chips via ICI: low latency, high bandwidth
//
// A comparable GPU deployment: 768 H100 GPUs (96 DGX nodes × 8 GPUs)
// Inter-node communication: InfiniBand 400G (50 GB/s per GPU)
// GPU NVLink: 900 GB/s intra-node, 50 GB/s inter-node → 18x gap
// TPU ICI: ~600 GB/s per chip, uniform across all 6,144 chips
// TPU has no intra-node / inter-node bandwidth cliff
Cost-Adjusted Training Throughput (normalized to TPU v4 = 100)

[Chart: relative training throughput per dollar]

Inference at Google’s Scale
TPU v5e is optimized for inference cost efficiency:
// TPU v5e for inference:
// 197 TFLOPS BF16, 16 GB HBM, ~100W TDP
// Cost: ~$1.20/hr per chip on Google Cloud
//
// For small model inference (BERT, T5):
// Single v5e chip handles thousands of queries/second
// The 16 GB HBM limits model size but is sufficient for encoder models
//
// For large model inference (Gemini, PaLM):
// Multiple v5e chips with model parallelism
// ICI enables efficient tensor parallelism across chips
//
// Key advantage: Google owns the entire stack
// Hardware (TPU) + software (JAX/XLA) + framework (Gemini) + serving (GKE)
// No mismatched abstractions, no driver compatibility issues
When GPUs Win
Irregular Workloads
// Workloads with non-128-multiple dimensions:
// NLP with vocab_size=50257 (GPT-2): 50257 = 392×128 + 81 → 393 tiles,
// last tile 81/128 ≈ 63% utilized (37% of its cells wasted)
// Image models with channel=3: 3/128 = 2.3% MXU utilization for the first layer
//
// GPU tensor cores (16-wide tiles): much finer granularity
// Padding 50257 to a multiple of 16 wastes 15 elements (~0.03%) vs 47 (~0.09%) at 128
// For channel=3, a 16-wide tile is 3/16 ≈ 19% utilized vs 2.3% on the MXU
// GPUs handle small and irregular shapes with much less wasted compute
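Padding overhead for any dimension/tile combination (a sketch; waste is measured against the padded size):

```python
import math

def padding_waste(dim, tile):
    """Fraction of compute wasted when `dim` is padded to a multiple of `tile`."""
    padded = math.ceil(dim / tile) * tile
    return (padded - dim) / padded

# GPT-2 vocab: waste is tiny for both tile widths; large dims amortize padding
print(f"{padding_waste(50257, 128):.2%}")  # 0.09%
print(f"{padding_waste(50257, 16):.2%}")   # 0.03%
# small dims are where tile granularity bites
print(f"{padding_waste(3, 128):.0%}")      # 98%
print(f"{padding_waste(3, 16):.0%}")       # 81%
```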
Software Ecosystem
// CUDA ecosystem advantages:
// - PyTorch native support (most researchers use PyTorch)
// - Custom CUDA kernels (FlashAttention, custom fused ops)
// - CUTLASS for custom GEMM configurations
// - Triton for rapid kernel development
// - Nsight Compute for detailed profiling
//
// TPU ecosystem:
// - JAX + XLA (functional programming model)
// - Limited custom kernel support (Pallas, still maturing)
// - No equivalent to Nsight Compute (profiling is less granular)
// - PyTorch support via torch_xla (works but adds overhead)
// - Model must be expressible as XLA-compatible operations
The single biggest TPU limitation is custom kernels. On GPUs, FlashAttention (a hand-tuned CUDA kernel) provides 2-4x speedup over XLA-compiled attention. On TPUs, you must rely on XLA’s auto-optimization of attention, which is good but not hand-tuned. Google’s Pallas project (a low-level TPU kernel language) is closing this gap, but as of 2025, the GPU ecosystem for custom kernels remains significantly ahead.
Memory Capacity for Large Models
// H100: 80 GB HBM per GPU (or H200: 141 GB)
// TPU v4: 32 GB HBM per chip
// TPU v5e: 16 GB HBM per chip
// TPU v5p: 95 GB HBM per chip
//
// For a 70B FP16 model (140 GB):
// H100: 2 GPUs (tensor parallel)
// H200: 1 GPU (fits entirely)
// TPU v4: 5 chips minimum (tensor parallel)
// TPU v5e: 9 chips minimum
// TPU v5p: 2 chips
//
// More chips = more ICI communication = more overhead
// GPUs with larger per-device memory need fewer devices
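The chip counts above follow from a ceiling division over HBM capacity (a sketch that ignores activations and KV cache, which raise the real minimum):

```python
import math

def min_devices(model_gb, hbm_gb):
    """Minimum accelerators needed to hold the weights alone."""
    return math.ceil(model_gb / hbm_gb)

model_gb = 70 * 2  # 70B parameters x 2 bytes/param (FP16) = 140 GB
for name, hbm in [("H100", 80), ("H200", 141), ("TPU v4", 32),
                  ("TPU v5e", 16), ("TPU v5p", 95)]:
    print(f"{name}: {min_devices(model_gb, hbm)} device(s)")
```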
TPU vs GPU: Quantitative Summary
TPU v5p vs H100 vs B200: Head-to-Head
| Metric | TPU v5p | H100 SXM | B200 | Notes |
|---|---|---|---|---|
| BF16 TFLOPS | 459 | 990 | 2,250 | B200 leads |
| INT8 TOPS | 918 | 1,979 | 4,500 | B200 leads |
| HBM capacity | 95 GB | 80 GB | 192 GB | B200 leads |
| HBM bandwidth | 2,765 GB/s | 3,350 GB/s | 8,000 GB/s | B200 leads |
| Interconnect BW (per chip) | ~600 GB/s ICI | 900 GB/s NVLink | 1,800 GB/s NVLink | B200 leads |
| Max scale (single fabric) | 8,960 chips | 8 GPUs (NVSwitch) | 72 GPUs (NVL72) | TPU leads |
| Cloud cost ($/hr) | ~$4.50 | ~$11.00 | ~$15+ (est.) | TPU cheapest |
| Software ecosystem | JAX/XLA | CUDA/PyTorch | CUDA/PyTorch | GPU leads |
| Custom kernel support | Limited (Pallas) | Extensive (CUDA) | Extensive (CUDA) | GPU leads |
Summary
TPUs are purpose-built matrix multiplication engines. The 128x128 systolic array achieves higher arithmetic intensity per operation than GPU tensor cores, the ICI torus provides uniform bandwidth at pod scale without the intra/inter-node bandwidth cliff of GPU clusters, and Google’s vertical integration (hardware + XLA + JAX + training frameworks) eliminates interface overhead.
TPUs win on cost-adjusted throughput for large-scale training of models with regular shapes (multiples of 128) using the JAX/XLA stack. GPUs win on flexibility — irregular shapes, custom kernels, PyTorch ecosystem, profiling tools, and memory capacity per device.
The choice is not purely technical. If you are Google (or a Google Cloud customer running JAX workloads), TPUs provide the best cost/performance. If you need PyTorch, custom CUDA kernels, or workloads with irregular shapes, GPUs are the clear choice. Both are converging: Google is adding more flexibility to TPUs (Pallas, larger HBM), and NVIDIA is adding more domain-specific features to GPUs (Transformer Engine, decompression engines).