A transformer with 70 billion parameters processes every input through 80 layers of 8192-dimensional hidden states. The hidden states encode everything the model “knows” about the input, but they are dense, high-dimensional vectors with no human-interpretable structure. Mechanistic interpretability aims to reverse-engineer what these hidden states represent — to decompose them into understandable features and trace how those features interact to produce the model’s output.
The key breakthrough: sparse autoencoders (SAEs) can decompose a model’s dense hidden states into a large set of sparse, interpretable features. Each feature corresponds to a recognizable concept (a language, a topic, a syntactic pattern, a factual association). The activation of these features can be traced through the model’s computation graph, revealing “circuits” — small subnetworks that implement specific behaviors.
This post covers the technical details of sparse autoencoders, feature identification and interpretation, circuit discovery, the logit lens technique, and the practical implications for model debugging and safety.
The Superposition Problem
Why Hidden States Are Uninterpretable
A transformer’s hidden state $h \in \mathbb{R}^{d}$ with $d = 8192$ dimensions could naively represent at most 8192 independent features (one per dimension). But the model clearly knows far more than 8192 things. The resolution: the model uses superposition — it encodes many more than $d$ features by using non-orthogonal directions in the $d$-dimensional space:

$$h \approx \sum_{i=1}^{M} a_i f_i, \qquad M \gg d$$

where the features $f_i$ are directions in the hidden space, and $a_i \geq 0$ are their activations. The features are not orthogonal to each other (they cannot be, since there are more features than dimensions). This means any single neuron’s activation is a mixture of multiple features, making individual neurons uninterpretable.
```python
import torch
import torch.nn as nn
import numpy as np

def demonstrate_superposition():
    """Show how superposition encodes more features than dimensions."""
    d = 128   # Hidden dimension
    M = 1024  # Number of features (much larger than d)

    # Random feature directions (non-orthogonal)
    features = torch.randn(M, d)
    features = features / features.norm(dim=1, keepdim=True)

    # Average dot product between random feature pairs
    # (measures how much features "interfere" with each other)
    sample_pairs = 10000
    i = torch.randint(0, M, (sample_pairs,))
    j = torch.randint(0, M, (sample_pairs,))
    mask = i != j  # Exclude self-pairs, whose dot product is trivially 1
    dots = (features[i[mask]] * features[j[mask]]).sum(dim=1)

    print(f"Dimension: {d}")
    print(f"Features: {M}")
    print(f"Feature-to-dimension ratio: {M/d:.1f}x")
    print(f"Average interference (|dot product|): {dots.abs().mean():.4f}")
    # For d=128: mean |dot| ≈ sqrt(2/(pi*d)) ≈ 0.07, on the order of 1/sqrt(d)
    # Low enough for the model to distinguish features, but not zero
```
The Key Insight
Superposition works because most features are sparse — they are active for only a small fraction of inputs. If feature $f_i$ activates on 1% of inputs and feature $f_j$ activates on 1% of inputs, then (assuming independence) the probability that both are active simultaneously is $0.01 \times 0.01 = 10^{-4}$. When features rarely co-activate, their mutual interference is manageable.
This is exactly the setting where sparse autoencoders excel: they are designed to find sparse decompositions of dense signals.
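The co-activation arithmetic can be checked with a quick simulation. The 1% activation rate below is the illustrative figure from the text, not a measured value:

```python
import torch

torch.manual_seed(0)

# Two independent features, each active on ~1% of inputs (illustrative rate)
n_samples = 1_000_000
p = 0.01
a = torch.rand(n_samples) < p
b = torch.rand(n_samples) < p

rate_a = a.float().mean().item()
rate_both = (a & b).float().mean().item()
print(f"P(a active)    ~ {rate_a:.4f}")
print(f"P(both active) ~ {rate_both:.6f}")  # close to 0.01 * 0.01 = 1e-4
```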
Sparse Autoencoders (SAEs)
Architecture
A sparse autoencoder learns to decompose a hidden state $h \in \mathbb{R}^{d}$ into a sparse set of features and reconstruct $h$ from those features:

$$z = \mathrm{ReLU}(W_{\text{enc}} h + b_{\text{enc}}), \qquad \hat{h} = W_{\text{dec}} z + b_{\text{dec}}$$

where $W_{\text{enc}} \in \mathbb{R}^{M \times d}$ is the encoder matrix, $W_{\text{dec}} \in \mathbb{R}^{d \times M}$ is the decoder matrix, and $z \in \mathbb{R}^{M}$ is the sparse feature activation vector.
```python
class SparseAutoencoder(nn.Module):
    """
    Sparse autoencoder for decomposing transformer hidden states.
    Based on Anthropic's approach from Bricken et al. (2023).
    """
    def __init__(self, d_model, n_features, l1_coefficient=1e-3):
        super().__init__()
        self.d_model = d_model
        self.n_features = n_features
        self.l1_coeff = l1_coefficient
        # Encoder: d_model -> n_features
        self.encoder = nn.Linear(d_model, n_features, bias=True)
        # Decoder: n_features -> d_model
        self.decoder = nn.Linear(n_features, d_model, bias=True)
        # Initialize decoder columns to unit norm
        with torch.no_grad():
            self.decoder.weight.data = self.decoder.weight.data / \
                self.decoder.weight.data.norm(dim=0, keepdim=True)

    def forward(self, h):
        """
        Args:
            h: [batch_size, d_model] hidden state from transformer
        Returns:
            h_reconstructed: [batch_size, d_model]
            z: [batch_size, n_features] sparse feature activations
            loss: reconstruction loss + sparsity penalty
        """
        # Encode to sparse features
        z = torch.relu(self.encoder(h))  # [B, M]
        # Decode back to hidden space
        h_hat = self.decoder(z)  # [B, d]
        # Loss: reconstruction + L1 sparsity
        reconstruction_loss = (h - h_hat).pow(2).mean()
        sparsity_loss = z.abs().mean()
        loss = reconstruction_loss + self.l1_coeff * sparsity_loss
        return h_hat, z, loss

    def get_feature_activations(self, h):
        """Get feature activations for a hidden state."""
        with torch.no_grad():
            z = torch.relu(self.encoder(h))
        return z

    def get_feature_direction(self, feature_idx):
        """Get the direction in hidden space for a feature."""
        return self.decoder.weight[:, feature_idx]

    def avg_active_features(self, h):
        """Average number of active features per input."""
        z = self.get_feature_activations(h)
        return (z > 0).float().sum(dim=1).mean().item()
```
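As a quick sanity check of the shapes and loss terms, the same forward pass can be written functionally on random data (all dimensions and the weight scale here are arbitrary, not tuned values):

```python
import torch

torch.manual_seed(0)
d_model, n_features, batch = 64, 512, 8

W_enc = torch.randn(n_features, d_model) * 0.02
b_enc = torch.zeros(n_features)
W_dec = torch.randn(d_model, n_features) * 0.02
b_dec = torch.zeros(d_model)

h = torch.randn(batch, d_model)
z = torch.relu(h @ W_enc.T + b_enc)  # [batch, n_features]
h_hat = z @ W_dec.T + b_dec          # [batch, d_model]

recon_loss = (h - h_hat).pow(2).mean()
sparsity_loss = z.abs().mean()
loss = recon_loss + 1e-3 * sparsity_loss
print(tuple(z.shape), tuple(h_hat.shape))
```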
Training the SAE
```python
class SAETrainer:
    """Train a sparse autoencoder on transformer activations."""
    def __init__(self, sae, learning_rate=1e-4):
        self.sae = sae
        self.optimizer = torch.optim.Adam(sae.parameters(), lr=learning_rate)

    def train_on_activations(self, activation_dataset, epochs=10, batch_size=4096):
        """
        Train SAE on pre-collected hidden state activations.
        activation_dataset: list of tensors, each [d_model]
        """
        self.sae.train()
        losses = []
        for epoch in range(epochs):
            epoch_loss = 0.0
            n_batches = 0
            # Shuffle
            indices = torch.randperm(len(activation_dataset))
            for i in range(0, len(indices), batch_size):
                batch_idx = indices[i:i + batch_size]
                batch = torch.stack([activation_dataset[j] for j in batch_idx]).to("cuda")
                self.optimizer.zero_grad()
                h_hat, z, loss = self.sae(batch)
                loss.backward()
                # Clip gradients
                torch.nn.utils.clip_grad_norm_(self.sae.parameters(), 1.0)
                self.optimizer.step()
                # Re-normalize decoder columns to unit norm
                with torch.no_grad():
                    self.sae.decoder.weight.data = self.sae.decoder.weight.data / \
                        self.sae.decoder.weight.data.norm(dim=0, keepdim=True)
                epoch_loss += loss.item()
                n_batches += 1
            avg_loss = epoch_loss / n_batches
            losses.append(avg_loss)
            # Compute sparsity stats
            with torch.no_grad():
                sample = torch.stack(activation_dataset[:1000]).to("cuda")
                _, z, _ = self.sae(sample)
                active = (z > 0).float().mean(dim=0)            # Per-feature activation frequency
                avg_active = (z > 0).float().sum(dim=1).mean()  # Avg features active per input
            print(
                f"Epoch {epoch+1}: loss={avg_loss:.6f}, "
                f"avg_active_features={avg_active:.1f}/{self.sae.n_features}, "
                f"dead_features={int((active == 0).sum())}"
            )
        return losses
```
Collecting Activations
```python
class ActivationCollector:
    """Collect hidden state activations from a transformer model."""
    def __init__(self, model, tokenizer, layer_idx):
        self.model = model
        self.tokenizer = tokenizer
        self.layer_idx = layer_idx
        self.activations = []

    def collect(self, texts, max_tokens=512):
        """Collect activations from a list of texts."""
        def hook_fn(module, input, output):
            # output is the hidden state at this layer
            if isinstance(output, tuple):
                hidden = output[0]
            else:
                hidden = output
            # Store all token positions
            self.activations.append(hidden.detach().cpu())

        # Register hook on the target layer
        layer = self.model.model.layers[self.layer_idx]
        hook_handle = layer.register_forward_hook(hook_fn)
        for text in texts:
            input_ids = self.tokenizer.encode(
                text, return_tensors="pt", max_length=max_tokens, truncation=True
            ).to("cuda")
            with torch.no_grad():
                self.model(input_ids)
        hook_handle.remove()

        # Flatten: list of [1, seq_len, d_model] -> list of [d_model]
        all_activations = []
        for act in self.activations:
            for pos in range(act.shape[1]):
                all_activations.append(act[0, pos])
        return all_activations
```
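The forward hook is the load-bearing piece of the collector. The same mechanism can be verified on a toy two-layer network without loading an LLM (the toy model and its sizes are invented for illustration):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
toy_model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 16))

collected = []
def hook_fn(module, inputs, output):
    # Same pattern as ActivationCollector: detach and move off-device
    collected.append(output.detach().cpu())

handle = toy_model[0].register_forward_hook(hook_fn)  # hook the first layer
with torch.no_grad():
    toy_model(torch.randn(4, 16))
handle.remove()

print(len(collected), tuple(collected[0].shape))  # 1 (4, 32)
```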
Feature Interpretation
Identifying What Features Represent
After training the SAE, each feature is a direction in hidden space. To interpret what a feature represents, find the inputs that maximally activate it:
```python
class FeatureInterpreter:
    """Interpret SAE features by finding maximally activating inputs."""
    def __init__(self, sae, model, tokenizer, layer_idx):
        self.sae = sae
        self.model = model
        self.tokenizer = tokenizer
        self.layer_idx = layer_idx

    def find_top_activations(self, feature_idx, texts, top_k=20):
        """Find the inputs that maximally activate a specific feature."""
        activations = []
        for text_idx, text in enumerate(texts):
            input_ids = self.tokenizer.encode(
                text, return_tensors="pt", max_length=512, truncation=True
            ).to("cuda")
            # Get hidden states at the target layer
            with torch.no_grad():
                outputs = self.model(input_ids, output_hidden_states=True)
                hidden = outputs.hidden_states[self.layer_idx]  # [1, seq_len, d]
            # Get SAE feature activations
            z = self.sae.get_feature_activations(hidden[0])  # [seq_len, M]
            # Find the activation of the target feature at each position
            feature_acts = z[:, feature_idx]  # [seq_len]
            for pos in range(len(feature_acts)):
                if feature_acts[pos] > 0:
                    # Get the token and its context
                    tokens = input_ids[0]
                    start = max(0, pos - 5)
                    end = min(len(tokens), pos + 6)
                    context = self.tokenizer.decode(tokens[start:end])
                    target_token = self.tokenizer.decode(tokens[pos:pos+1])
                    activations.append({
                        "text_idx": text_idx,
                        "position": pos,
                        "activation": feature_acts[pos].item(),
                        "token": target_token,
                        "context": context,
                    })
        # Sort by activation strength
        activations.sort(key=lambda x: x["activation"], reverse=True)
        return activations[:top_k]

    def auto_label_feature(self, feature_idx, texts, top_k=50):
        """Automatically generate a description of what a feature represents."""
        top_acts = self.find_top_activations(feature_idx, texts, top_k)
        if not top_acts:
            return "Dead feature (never activates)"
        # Collect the tokens and contexts
        tokens = [a["token"].strip() for a in top_acts]
        contexts = [a["context"] for a in top_acts]
        # Simple heuristic labeling
        from collections import Counter
        token_counts = Counter(tokens)
        most_common = token_counts.most_common(5)
        # Heuristic checks, in order of specificity: all digits -> number feature;
        # punctuation present -> syntax feature; otherwise list the dominant tokens
        label_parts = []
        if all(t.isdigit() or t == '.' for t, _ in most_common):
            label_parts.append("number/digit")
        elif any(t in [',', '.', '!', '?', ';'] for t, _ in most_common):
            label_parts.append("punctuation")
        else:
            label_parts.append(f"tokens: {', '.join(t for t, _ in most_common[:3])}")
        return {
            "feature_idx": feature_idx,
            "auto_label": " | ".join(label_parts),
            "top_tokens": most_common[:10],
            "sample_contexts": contexts[:5],
            "avg_activation": np.mean([a["activation"] for a in top_acts]),
        }
```
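The labeling heuristic in auto_label_feature is just token statistics. On a hand-made token list (invented data, standing in for the stripped top-activating tokens) it behaves like this:

```python
from collections import Counter

tokens = ["def", "class", "def", "import", "return", "def", "class"]
most_common = Counter(tokens).most_common(5)

if all(t.isdigit() or t == "." for t, _ in most_common):
    label = "number/digit"
elif any(t in [",", ".", "!", "?", ";"] for t, _ in most_common):
    label = "punctuation"
else:
    label = "tokens: " + ", ".join(t for t, _ in most_common[:3])

print(label)  # tokens: def, class, import
```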
Example SAE Features Found in Llama 7B (Layer 15)
| Feature ID | Description | Activation Frequency | Top Tokens |
|---|---|---|---|
| F-2341 | Python code keywords | 3.2% | def, class, import, return |
| F-891 | French language | 1.8% | le, de, les, est, que |
| F-15432 | Dates and years | 2.1% | 2023, January, 1990, century |
| F-7823 | Negative sentiment | 4.5% | bad, terrible, wrong, fail |
| F-3021 | Mathematical notation | 1.2% | =, +, sum, equation, proof |
| F-12001 | Legal language | 0.8% | court, law, statute, defendant |
Feature Circuits
Tracing Features Through Layers
A “circuit” is a subnetwork of features across layers that implements a specific behavior. For example, the circuit for “answer a question about the capital of France” might involve:
Layer 5: Feature "France" activates
Layer 10: Feature "capital_of" activates
Layer 15: Features "France" + "capital_of" compose
Layer 20: Feature "Paris" gets boosted in the output direction
```python
class CircuitTracer:
    """Trace feature activations through layers to identify circuits."""
    def __init__(self, saes, model, tokenizer):
        """
        Args:
            saes: dict of {layer_idx: SparseAutoencoder}
        """
        self.saes = saes
        self.model = model
        self.tokenizer = tokenizer

    def _get_hidden_state(self, input_ids, layer_idx, position=-1):
        """Hidden state at one layer for the given position, shape [1, d]."""
        with torch.no_grad():
            outputs = self.model(input_ids, output_hidden_states=True)
        return outputs.hidden_states[layer_idx][:, position]

    def trace_input(self, text, target_token_position=-1):
        """Trace feature activations through all layers for an input."""
        input_ids = self.tokenizer.encode(text, return_tensors="pt").to("cuda")
        with torch.no_grad():
            outputs = self.model(input_ids, output_hidden_states=True)
        layer_features = {}
        for layer_idx, sae in self.saes.items():
            hidden = outputs.hidden_states[layer_idx][0, target_token_position]
            z = sae.get_feature_activations(hidden.unsqueeze(0)).squeeze(0)
            # Get top active features
            active_mask = z > 0
            active_indices = active_mask.nonzero(as_tuple=True)[0]
            active_values = z[active_mask]
            # Sort by activation strength
            sorted_idx = active_values.argsort(descending=True)
            top_features = []
            for idx in sorted_idx[:20]:  # Top 20 per layer
                feature_id = active_indices[idx].item()
                activation = active_values[idx].item()
                top_features.append({
                    "feature_id": feature_id,
                    "activation": activation,
                })
            layer_features[layer_idx] = top_features
        return layer_features

    def find_circuit(self, text, target_output_token):
        """
        Find the circuit responsible for predicting a specific output token.
        Uses attribution: which features most influence the target logit.
        """
        input_ids = self.tokenizer.encode(text, return_tensors="pt").to("cuda")
        target_id = self.tokenizer.encode(target_output_token, add_special_tokens=False)[0]
        circuit = []
        # For each layer, find features that increase the target logit
        for layer_idx, sae in sorted(self.saes.items()):
            hidden = self._get_hidden_state(input_ids, layer_idx)  # [1, d]
            z = sae.get_feature_activations(hidden)
            # For each active feature, compute its contribution to the target logit
            active_features = (z > 0).squeeze(0).nonzero(as_tuple=True)[0]
            for feat_idx in active_features:
                feat_direction = sae.get_feature_direction(feat_idx)
                feat_activation = z[0, feat_idx].item()
                # Project feature direction through remaining layers to output
                # (simplified: use the output head directly)
                output_weight = self.model.lm_head.weight[target_id]
                contribution = (feat_direction * output_weight).sum().item() * feat_activation
                if abs(contribution) > 0.1:
                    circuit.append({
                        "layer": layer_idx,
                        "feature_id": feat_idx.item(),
                        "activation": feat_activation,
                        "contribution_to_target": contribution,
                    })
        # Sort by contribution magnitude
        circuit.sort(key=lambda x: abs(x["contribution_to_target"]), reverse=True)
        return circuit
```
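The attribution inside find_circuit reduces to one dot product per feature. Here it is in isolation on random tensors (the direction, output row, and activation value are all invented stand-ins):

```python
import torch

torch.manual_seed(0)
d = 32
feat_direction = torch.randn(d)
feat_direction = feat_direction / feat_direction.norm()  # unit-norm decoder column
output_row = torch.randn(d)   # stand-in for lm_head.weight[target_id]
feat_activation = 2.5

# Contribution = (feature direction . unembedding row) * feature activation
contribution = (feat_direction * output_row).sum().item() * feat_activation
print(round(contribution, 4))
```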
The Logit Lens
Projecting Intermediate Layers to Output Space
The logit lens (nostalgebraist, 2020) projects hidden states from intermediate layers through the final output head to see what the model is “thinking” at each layer:
```python
class LogitLens:
    """
    Project intermediate hidden states through the output head.
    Reveals how the model's prediction evolves layer by layer.
    """
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def analyze(self, text, position=-1, top_k=5):
        """Show the model's top predictions at each layer."""
        input_ids = self.tokenizer.encode(text, return_tensors="pt").to("cuda")
        with torch.no_grad():
            outputs = self.model(input_ids, output_hidden_states=True)
        results = []
        for layer_idx, hidden_state in enumerate(outputs.hidden_states):
            # Apply final layer norm
            normed = self.model.model.norm(hidden_state[0, position])
            # Project through the output head
            logits = self.model.lm_head(normed.unsqueeze(0)).squeeze(0)
            # Get top-k predictions
            probs = torch.softmax(logits, dim=-1)
            top_probs, top_indices = torch.topk(probs, top_k)
            layer_predictions = []
            for prob, idx in zip(top_probs, top_indices):
                token = self.tokenizer.decode([idx.item()])
                layer_predictions.append({
                    "token": token,
                    "probability": prob.item(),
                })
            results.append({
                "layer": layer_idx,
                "top_predictions": layer_predictions,
            })
        return results

    def visualize_evolution(self, text, target_token, position=-1):
        """Track how the probability of a specific token evolves across layers."""
        input_ids = self.tokenizer.encode(text, return_tensors="pt").to("cuda")
        target_id = self.tokenizer.encode(target_token, add_special_tokens=False)[0]
        with torch.no_grad():
            outputs = self.model(input_ids, output_hidden_states=True)
        layer_probs = []
        for layer_idx, hidden_state in enumerate(outputs.hidden_states):
            normed = self.model.model.norm(hidden_state[0, position])
            logits = self.model.lm_head(normed.unsqueeze(0)).squeeze(0)
            probs = torch.softmax(logits, dim=-1)
            layer_probs.append({
                "layer": layer_idx,
                "target_probability": probs[target_id].item(),
            })
        return layer_probs
```
The logit lens reveals that early layers often predict generic, high-frequency tokens (articles, prepositions); the model’s specific prediction emerges only in the later layers. For factual recall (e.g., “The capital of France is ___”), the correct answer “Paris” typically enters the top predictions around 60-70% of the way through the network.
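The per-layer projection itself is small enough to state in isolation. With a real model, the layer norm and output head come from the network; here they are randomly initialized stand-ins, so the "predictions" are meaningless but the mechanics are identical:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, vocab_size = 32, 100
norm = nn.LayerNorm(d_model)                          # stand-in for model.model.norm
lm_head = nn.Linear(d_model, vocab_size, bias=False)  # stand-in for model.lm_head

hidden = torch.randn(d_model)  # pretend: layer-l hidden state at one position
with torch.no_grad():
    logits = lm_head(norm(hidden))
    probs = torch.softmax(logits, dim=-1)
    top_probs, top_ids = torch.topk(probs, 5)

print(top_ids.tolist())
print([round(p, 4) for p in top_probs.tolist()])
```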
Applications: Debugging and Safety
Debugging Model Behavior
```python
class ModelDebugger:
    """Use interpretability tools to debug model behavior."""
    def __init__(self, saes, model, tokenizer):
        self.circuit_tracer = CircuitTracer(saes, model, tokenizer)
        self.logit_lens = LogitLens(model, tokenizer)
        self.interpreter = FeatureInterpreter(
            list(saes.values())[0], model, tokenizer, list(saes.keys())[0]
        )

    def debug_wrong_answer(self, prompt, wrong_answer, correct_answer):
        """Investigate why the model produces a wrong answer."""
        # Step 1: Find the circuit for the wrong answer
        wrong_circuit = self.circuit_tracer.find_circuit(prompt, wrong_answer)
        # Step 2: Find the circuit for the correct answer
        correct_circuit = self.circuit_tracer.find_circuit(prompt, correct_answer)
        # Step 3: Identify features that boost the wrong answer
        wrong_boosters = [
            f for f in wrong_circuit
            if f["contribution_to_target"] > 0.1
        ]
        # Step 4: Identify features that should boost the correct answer
        correct_boosters = [
            f for f in correct_circuit
            if f["contribution_to_target"] > 0.1
        ]
        # Step 5: Compare
        wrong_features = set(f["feature_id"] for f in wrong_boosters)
        correct_features = set(f["feature_id"] for f in correct_boosters)
        return {
            "wrong_answer": wrong_answer,
            "correct_answer": correct_answer,
            "wrong_circuit_size": len(wrong_boosters),
            "correct_circuit_size": len(correct_boosters),
            "features_unique_to_wrong": wrong_features - correct_features,
            "features_unique_to_correct": correct_features - wrong_features,
            "shared_features": wrong_features & correct_features,
        }

    def _collect_feature_acts(self, texts):
        """SAE feature activations at the last token of each text, [n_texts, M]."""
        interp = self.interpreter
        acts = []
        for text in texts:
            input_ids = interp.tokenizer.encode(
                text, return_tensors="pt", max_length=512, truncation=True
            ).to("cuda")
            with torch.no_grad():
                outputs = interp.model(input_ids, output_hidden_states=True)
            hidden = outputs.hidden_states[interp.layer_idx][0, -1]
            acts.append(interp.sae.get_feature_activations(hidden.unsqueeze(0)).squeeze(0))
        return torch.stack(acts)

    def detect_bias_features(self, texts_group_a, texts_group_b, feature_threshold=0.1):
        """Find features that activate differently for two groups of texts."""
        # Collect feature activations for each group
        acts_a = self._collect_feature_acts(texts_group_a)
        acts_b = self._collect_feature_acts(texts_group_b)
        # Find features with significantly different activation rates
        biased_features = []
        for feat_idx in range(acts_a.shape[1]):
            rate_a = (acts_a[:, feat_idx] > 0).float().mean().item()
            rate_b = (acts_b[:, feat_idx] > 0).float().mean().item()
            diff = abs(rate_a - rate_b)
            if diff > feature_threshold:
                biased_features.append({
                    "feature_id": feat_idx,
                    "rate_group_a": rate_a,
                    "rate_group_b": rate_b,
                    "difference": diff,
                })
        biased_features.sort(key=lambda x: x["difference"], reverse=True)
        return biased_features
```
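The rate comparison in detect_bias_features can be exercised on synthetic activations by planting a biased feature deliberately (all numbers below are invented):

```python
import torch

torch.manual_seed(0)
# Fake SAE activations for two groups of texts: [n_texts, n_features]
acts_a = torch.relu(torch.randn(50, 16) - 1.0)
acts_b = torch.relu(torch.randn(50, 16) - 1.0)
acts_b[:, 3] += 2.0  # plant a feature that always fires for group B

rate_a = (acts_a > 0).float().mean(dim=0)
rate_b = (acts_b > 0).float().mean(dim=0)
diff = (rate_a - rate_b).abs()

biased = (diff > 0.1).nonzero(as_tuple=True)[0].tolist()
print(biased)  # feature 3 should appear (sampling noise may add a few others)
```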
Feature Steering
Once you identify a feature, you can artificially activate or deactivate it to control the model’s behavior:
```python
class FeatureSteering:
    """Steer model behavior by intervening on SAE features."""
    def __init__(self, sae, model, tokenizer, layer_idx):
        self.sae = sae
        self.model = model
        self.tokenizer = tokenizer
        self.layer_idx = layer_idx

    def generate_with_feature(
        self, prompt, feature_idx, activation_strength=5.0, max_tokens=200
    ):
        """Generate text with a specific feature artificially activated."""
        feature_direction = self.sae.get_feature_direction(feature_idx)

        # Hook to inject the feature at the target layer
        def hook_fn(module, input, output):
            if isinstance(output, tuple):
                hidden = output[0]
            else:
                hidden = output
            # Add the feature direction scaled by activation strength
            direction = feature_direction.to(device=hidden.device, dtype=hidden.dtype)
            modified = hidden + activation_strength * direction
            if isinstance(output, tuple):
                return (modified,) + output[1:]
            return modified

        layer = self.model.model.layers[self.layer_idx]
        handle = layer.register_forward_hook(hook_fn)
        input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to("cuda")
        with torch.no_grad():
            output = self.model.generate(
                input_ids, max_new_tokens=max_tokens,
                do_sample=True, temperature=0.7,
            )
        handle.remove()
        return self.tokenizer.decode(output[0][input_ids.shape[1]:], skip_special_tokens=True)
```
Feature steering enables precise behavioral control without retraining. If you discover a feature associated with generating harmful content, you can suppress it (set its activation to zero) during inference. This is more targeted than output filtering because it operates at the representation level, preventing the harmful behavior from forming rather than catching it after the fact.
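Suppression can be sketched as a variant of the steering hook: encode the hidden state with the SAE, zero the target feature, and apply only that change. A minimal functional version on random stand-in weights (no real model or trained SAE; the feature index is arbitrary):

```python
import torch

torch.manual_seed(0)
d, M = 32, 256
# Stand-in SAE weights; a real run would use the trained encoder/decoder
W_enc, b_enc = torch.randn(M, d) * 0.05, torch.zeros(M)
W_dec = torch.randn(d, M) * 0.05
W_dec = W_dec / W_dec.norm(dim=0, keepdim=True)  # unit-norm decoder columns

def suppress_feature(h, feature_idx):
    """Zero one SAE feature and apply only that edit to the hidden state."""
    z = torch.relu(h @ W_enc.T + b_enc)
    delta = torch.zeros_like(z)
    delta[:, feature_idx] = -z[:, feature_idx]
    # Edit h directly instead of replacing it with the SAE reconstruction,
    # so the SAE's reconstruction error never enters the forward pass
    return h + delta @ W_dec.T

h = torch.randn(4, d)
h_edit = suppress_feature(h, feature_idx=7)
print(tuple(h_edit.shape))
```

With a real model, this edit would live inside a forward hook at the target layer, exactly like the injection hook in FeatureSteering.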
Practical Considerations
SAE Scale
SAE Training Costs and Characteristics
| Feature Count | Training Data | Training Time (A100) | Reconstruction Loss | Dead Features |
|---|---|---|---|---|
| 4K features | 100M tokens | 2 hours | 0.015 | 5% |
| 16K features | 500M tokens | 12 hours | 0.008 | 8% |
| 65K features | 1B tokens | 48 hours | 0.004 | 12% |
| 262K features | 4B tokens | 200 hours | 0.002 | 15% |