An LLM trained in 2024 does not know who won the 2025 elections, the current stock price of any company, or whether a recently discovered drug interaction is dangerous. Retraining to update this knowledge costs millions of dollars and takes weeks. Fine-tuning on new facts risks catastrophic forgetting of existing knowledge. RAG (retrieval-augmented generation) adds external knowledge at inference time but increases latency and does not truly update the model’s internal representations.
Knowledge editing offers a different approach: surgically modify the specific weights that encode a particular fact, changing what the model “knows” without retraining. ROME (Rank-One Model Editing) demonstrates that a single factual association (subject → attribute) is stored in the MLP layers of specific transformer blocks and can be updated with a rank-one weight modification. MEMIT (Mass-Editing Memory In a Transformer) extends this to thousands of simultaneous edits.
This post covers the mathematics and implementation of knowledge editing: causal tracing to locate facts, ROME for single-fact editing, MEMIT for batch editing, and the side effects that limit practical deployment.
Where Facts Live in Transformers
Causal Tracing
import numpy as np
from dataclasses import dataclass
from typing import Optional
@dataclass
class CausalTraceResult:
"""Result of causal tracing for a fact."""
subject: str
relation: str
object_true: str
layer_contributions: list
critical_layer: int
critical_token: int
indirect_effect: float
class CausalTracer:
"""
Causal tracing: identify which layers and positions
are critical for recalling a specific fact.
Method (Meng et al., 2022):
1. Run the model on "The Eiffel Tower is in [Paris]"
2. Corrupt the subject tokens ("The ##### Tower is in")
with noise. The model can no longer predict "Paris".
3. For each (layer, position), restore the clean
activation at that single point.
4. Measure how much the probability of "Paris" recovers.
5. Points where restoration recovers the most probability
are causally critical for this fact.
Typical finding: MLP outputs at the last subject token
in mid-to-late layers (layers 15-25 for 32-layer models)
are most critical.
"""
def __init__(self, model, tokenizer):
self.model = model
self.tokenizer = tokenizer
def trace(self, prompt, subject, target_token):
"""
Run causal tracing for a specific fact.
prompt: "The Eiffel Tower is located in"
subject: "Eiffel Tower"
target_token: "Paris"
"""
# Step 1: Clean run
clean_logits = self._forward(prompt)
target_id = self.tokenizer.encode(target_token)[0]
clean_prob = self._softmax(clean_logits)[target_id]
# Step 2: Corrupted run (noise on subject tokens)
subject_tokens = self.tokenizer.encode(subject)
subject_positions = self._find_subject_positions(
prompt, subject
)
corrupted_logits = self._forward_corrupted(
prompt, subject_positions, noise_std=3.0
)
corrupted_prob = self._softmax(
corrupted_logits
)[target_id]
# Step 3: Restoration sweep
n_layers = self._get_n_layers()
n_tokens = len(self.tokenizer.encode(prompt))
layer_contributions = np.zeros(
(n_layers, n_tokens)
)
for layer in range(n_layers):
for pos in range(n_tokens):
restored_logits = self._forward_restored(
prompt, subject_positions,
restore_layer=layer,
restore_position=pos,
noise_std=3.0,
)
restored_prob = self._softmax(
restored_logits
)[target_id]
# Indirect effect: how much does restoring
# this single (layer, position) recover
# the probability?
indirect_effect = (
restored_prob - corrupted_prob
) / (clean_prob - corrupted_prob + 1e-10)
layer_contributions[layer, pos] = (
indirect_effect
)
# Find critical (layer, position)
critical_idx = np.unravel_index(
np.argmax(layer_contributions),
layer_contributions.shape,
)
return CausalTraceResult(
subject=subject,
relation=prompt.replace(subject, "[SUBJECT]"),
object_true=target_token,
layer_contributions=layer_contributions.tolist(),
critical_layer=int(critical_idx[0]),
critical_token=int(critical_idx[1]),
indirect_effect=float(
layer_contributions[critical_idx]
),
)
def _forward(self, prompt):
"""Standard forward pass."""
return np.zeros(50000) # Placeholder logits
def _forward_corrupted(self, prompt, positions,
noise_std):
"""Forward pass with noised subject embeddings."""
return np.zeros(50000)
def _forward_restored(self, prompt, positions,
restore_layer, restore_position,
noise_std):
"""
Forward pass: corrupt subject, restore one point.
"""
return np.zeros(50000)
def _softmax(self, logits):
"""Compute softmax."""
exp = np.exp(logits - np.max(logits))
return exp / exp.sum()
def _find_subject_positions(self, prompt, subject):
"""Find token positions of the subject in prompt."""
return [1, 2] # Placeholder
def _get_n_layers(self):
"""Get number of model layers."""
return 32 # Placeholder
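The normalization in step 4 of the method can be sanity-checked in isolation. The probabilities below are made up purely for illustration:

```python
# Hypothetical probabilities of the target token "Paris"
clean_prob = 0.80      # clean run
corrupted_prob = 0.02  # subject tokens noised
restored_prob = 0.60   # one (layer, position) restored

# Indirect effect: fraction of the lost probability recovered
indirect_effect = (restored_prob - corrupted_prob) / (
    clean_prob - corrupted_prob + 1e-10
)
print(round(indirect_effect, 3))  # → 0.744
```

An indirect effect near 1 means restoring that single site recovers almost the full prediction; a value near 0 means the site carries no causal information for this fact.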
Causal tracing consistently finds that factual associations are stored in the MLP layers (not the attention layers) at the last subject token position in middle-to-late transformer blocks. For GPT-J (28 layers), the critical layer is typically 17-22. For Llama 2 70B (80 layers), it is typically 50-65. This localization is what makes surgical editing possible.
ROME: Rank-One Model Editing
Single-Fact Editing
class ROME:
"""
Rank-One Model Editing (Meng et al., 2022).
Core idea: a factual association (subject -> attribute)
is stored as a key-value pair in the MLP of a critical
layer. The key is the subject representation. The value
is the attribute.
The MLP at layer L computes:
m = W_proj * sigma(W_fc * h)
where h is the hidden state, W_fc maps to the intermediate
dimension, and W_proj maps back. Factual recall happens
when the hidden state h at the last subject token activates
a specific "key" in W_fc, which produces the corresponding
"value" through W_proj.
To edit a fact, we modify W_proj with a rank-one update:
W_proj_new = W_proj + delta
where delta = (v_new - v_old) * k^T / (k^T * k)
This changes the value associated with the subject key
without affecting other key-value pairs.
"""
def __init__(self, model, tokenizer, config):
self.model = model
self.tokenizer = tokenizer
self.critical_layer = config.get(
"critical_layer", 17
)
self.v_lr = config.get("v_lr", 5e-1)
self.v_steps = config.get("v_steps", 20)
self.kl_factor = config.get("kl_factor", 0.0625)
def edit_fact(self, subject, target_new, prompts):
"""
Edit a single fact.
subject: "The Eiffel Tower"
target_new: "Rome" (changing location from Paris to Rome)
prompts: test prompts to verify the edit
e.g., ["The Eiffel Tower is located in",
"The city where the Eiffel Tower stands is"]
"""
# Step 1: Compute the key vector (subject representation)
key = self._compute_key(subject)
# Step 2: Compute the new value vector (target representation)
value_new = self._compute_target_value(
subject, target_new, prompts
)
# Step 3: Compute the old value vector
value_old = self._compute_current_value(
subject, key
)
# Step 4: Compute rank-one update
delta = self._compute_rank_one_update(
key, value_old, value_new
)
# Step 5: Apply update to model weights
self._apply_update(delta)
return {
"subject": subject,
"target_new": target_new,
"key_norm": float(np.linalg.norm(key)),
"delta_norm": float(np.linalg.norm(delta)),
}
def _compute_key(self, subject):
"""
Compute the key vector for the subject.
The key is the hidden state at the last subject
token, at the critical layer, passed through
the MLP's first linear layer.
k = W_fc * h_subject
"""
# Get hidden state at last subject token
prompt = f"{subject} is"
h = self._get_hidden_state(
prompt, layer=self.critical_layer,
position=-2, # Last subject token
)
# Pass through MLP first layer
# k = W_fc * h
k = self._mlp_first_layer(h, self.critical_layer)
return k
def _compute_target_value(self, subject, target_new,
prompts):
"""
Compute the value vector that would make the
model output the new target.
Optimize v such that inserting it at the critical
layer causes the model to output target_new for
all test prompts.
"""
        # Initialize v from the current value
        v = self._get_current_value_at_layer(
            subject, self.critical_layer
        )
        v_orig = v.copy()
        target_id = self.tokenizer.encode(target_new)[0]
        # Optimize v to maximize P(target_new | prompt)
        for step in range(self.v_steps):
            grad = np.zeros_like(v)
            for prompt in prompts:
                # Forward pass with v inserted
                logits = self._forward_with_value_override(
                    prompt, self.critical_layer, v
                )
                # Cross-entropy gradient in logit space:
                # dL/dlogits = softmax(logits) - onehot(target)
                exp = np.exp(logits - np.max(logits))
                probs = exp / exp.sum()
                dlogits = probs.copy()
                dlogits[target_id] -= 1.0
                # A full implementation backpropagates dlogits
                # through the layers above the injection point
                # to obtain dL/dv; this sketch substitutes a
                # zero placeholder of v's shape.
                grad += np.zeros_like(v) / len(prompts)
            # KL penalty to keep v close to the original value
            kl_grad = self.kl_factor * (v - v_orig)
            v = v - self.v_lr * (grad + kl_grad)
        return v
def _compute_current_value(self, subject, key):
"""Compute current value for the subject key."""
return self._get_current_value_at_layer(
subject, self.critical_layer
)
def _compute_rank_one_update(self, key, value_old,
value_new):
"""
Compute the rank-one weight update.
        delta = (v_new - v_old) * (C^{-1} k)^T / (k^T * C^{-1} * k)
where C is the uncentered covariance of keys
(estimated from a sample of inputs).
This ensures the update only affects the target
key direction, minimizing impact on other facts.
"""
delta_v = value_new - value_old
# Simplified: without covariance correction
k_norm_sq = np.dot(key, key)
if k_norm_sq < 1e-10:
return np.zeros(
(len(value_new), len(key))
)
# Outer product: delta = delta_v * k^T / ||k||^2
update = np.outer(delta_v, key) / k_norm_sq
return update
def _apply_update(self, delta):
"""Apply rank-one update to the critical layer's MLP."""
# In practice: model.layers[L].mlp.W_proj += delta
pass
def _get_hidden_state(self, prompt, layer, position):
"""Get hidden state at specific layer and position."""
return np.random.randn(4096)
def _mlp_first_layer(self, h, layer):
"""Apply MLP first layer."""
return np.random.randn(11008)
def _get_current_value_at_layer(self, subject, layer):
"""Get current MLP output for subject."""
return np.random.randn(4096)
def _forward_with_value_override(self, prompt, layer, v):
"""Forward pass with overridden value at layer."""
return np.zeros(50000)
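The defining property of the rank-one update can be verified directly in numpy. This is a sketch of the simplified formula (without the covariance correction); `W` stands in for `W_proj`, and the dimensions are made up:

```python
import numpy as np

rng = np.random.default_rng(0)
d_k, d_v = 8, 6                      # made-up dimensions
W = rng.standard_normal((d_v, d_k))  # stands in for W_proj
k = rng.standard_normal(d_k)         # subject key
v_new = rng.standard_normal(d_v)     # desired new value

# Simplified rank-one update:
# delta = (v_new - v_old) k^T / (k^T k), with v_old = W k
delta = np.outer(v_new - W @ k, k) / (k @ k)
W_edited = W + delta

# The edited matrix maps the subject key exactly to the new value
assert np.allclose(W_edited @ k, v_new)

# Any key orthogonal to k is untouched, which is why a single
# fact can be changed without disturbing unrelated ones
k_orth = rng.standard_normal(d_k)
k_orth -= (k_orth @ k) / (k @ k) * k
assert np.allclose(W_edited @ k_orth, W @ k_orth)
```

In practice keys for different subjects are correlated rather than orthogonal, which is exactly what the covariance term `C` in the full formula compensates for.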
ROME Editing Results on CounterFact Dataset
| Metric | Before Edit | After Edit | Ideal | Description |
|---|---|---|---|---|
| Efficacy (target probability) | 0.02 | 0.92 | 1.0 | P(new target \| edited prompt) |
| Paraphrase (generalization) | 0.01 | 0.85 | 1.0 | P(new target \| rephrased prompt) |
| Neighborhood (specificity) | 0.98 | 0.95 | 1.0 | P(original target \| similar but different subject) |
| Overall edit score | -- | 0.91 | 1.0 | Harmonic mean of the above |
| Perplexity change (general) | 15.2 | 15.3 | 15.2 | Perplexity on unrelated text |
MEMIT: Mass-Editing Memory
Editing Thousands of Facts
class MEMIT:
"""
Mass-Editing Memory In a Transformer (Meng et al., 2023).
ROME edits one fact at a time. Applying N ROME edits
sequentially degrades after ~100 edits because each
edit perturbs the weight space, and perturbations
accumulate.
MEMIT solves this by:
1. Distributing edits across multiple layers
(not just one critical layer)
2. Computing all updates jointly to minimize
interference between edits
3. Applying a single batch update rather than
N sequential updates
The update spreads the information across layers
L, L+1, ..., L+R, where each layer handles a
portion of the total edit.
"""
def __init__(self, model, tokenizer, config):
self.model = model
self.tokenizer = tokenizer
self.edit_layers = config.get(
"edit_layers", list(range(15, 25))
)
self.v_lr = config.get("v_lr", 0.5)
self.v_steps = config.get("v_steps", 25)
def batch_edit(self, edits):
"""
Apply multiple fact edits simultaneously.
edits: list of dicts with keys:
- subject: str
- target_new: str
- prompts: list of test prompts
"""
n_edits = len(edits)
n_layers = len(self.edit_layers)
# Step 1: Compute keys for all edits
keys = []
for edit in edits:
key = self._compute_key(edit["subject"])
keys.append(key)
# Step 2: Compute target values for all edits
target_values = []
for edit in edits:
v = self._compute_target_value(
edit["subject"],
edit["target_new"],
edit["prompts"],
)
target_values.append(v)
# Step 3: Compute current values
current_values = []
for edit in edits:
v_old = self._compute_current_value(
edit["subject"]
)
current_values.append(v_old)
# Step 4: Distribute delta across layers
layer_deltas = self._distribute_deltas(
keys, current_values, target_values
)
# Step 5: Apply all updates
for layer_idx, layer in enumerate(self.edit_layers):
if layer_deltas[layer_idx] is not None:
self._apply_update_to_layer(
layer, layer_deltas[layer_idx]
)
return {
"n_edits": n_edits,
"n_layers_modified": n_layers,
"avg_delta_norm": float(np.mean([
np.linalg.norm(d) for d in layer_deltas
if d is not None
])),
}
def _distribute_deltas(self, keys, current_values,
target_values):
"""
Distribute the total edit across multiple layers.
The residual delta for each edit is:
delta_i = v_new_i - v_old_i
This residual is distributed equally across
edit layers using a least-squares solution.
"""
n_layers = len(self.edit_layers)
n_edits = len(keys)
layer_deltas = [None] * n_layers
# Compute residuals
residuals = []
for i in range(n_edits):
delta = target_values[i] - current_values[i]
residuals.append(delta)
# Simple distribution: equal share per layer
for layer_idx in range(n_layers):
layer_delta_sum = np.zeros_like(residuals[0])
for i, (key, residual) in enumerate(
zip(keys, residuals)
):
share = residual / n_layers
k_norm_sq = np.dot(key, key)
if k_norm_sq > 1e-10:
update = np.outer(share, key) / k_norm_sq
if layer_deltas[layer_idx] is None:
layer_deltas[layer_idx] = update
else:
layer_deltas[layer_idx] += update
return layer_deltas
def _compute_key(self, subject):
"""Compute subject key vector."""
return np.random.randn(4096)
def _compute_target_value(self, subject, target_new,
prompts):
"""Compute optimized target value."""
return np.random.randn(4096)
def _compute_current_value(self, subject):
"""Compute current value for subject."""
return np.random.randn(4096)
def _apply_update_to_layer(self, layer, delta):
"""Apply weight update to a specific layer."""
pass
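A quick check of the simplified scheme in `_distribute_deltas`: for a single edit, the per-layer shares sum back to the full single-layer ROME delta. (Real MEMIT additionally accounts for cross-layer propagation and key covariance; dimensions here are arbitrary.)

```python
import numpy as np

rng = np.random.default_rng(1)
d = 16                             # made-up dimension
k = rng.standard_normal(d)         # subject key
residual = rng.standard_normal(d)  # v_new - v_old

n_layers = 4
# Single-layer ROME delta for comparison
full_delta = np.outer(residual, k) / (k @ k)
# Equal share applied at each of the edit layers
shares = [np.outer(residual / n_layers, k) / (k @ k)
          for _ in range(n_layers)]

# The shares sum back to the full edit
assert np.allclose(sum(shares), full_delta)
```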
Edit Success Rate: ROME vs MEMIT

| Method | Success rate at 1,000 edits |
|---|---|
| ROME (sequential) | 45% |
| MEMIT (batch) | 80% |
MEMIT maintains an 80% success rate at 1,000 simultaneous edits where ROME drops to 45%. However, even MEMIT degrades at 10,000+ edits. The fundamental limitation is that transformer MLP weights have finite capacity for storing facts. Editing a fact does not create new capacity; it overwrites existing capacity. At large edit counts, the edits begin to interfere with each other and with the model’s general capabilities.
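The interference effect can be reproduced with the simplified update alone: apply a sequence of random rank-one edits to one matrix and track how far the first edit drifts. All values here are synthetic:

```python
import numpy as np

rng = np.random.default_rng(2)
d = 64
W = rng.standard_normal((d, d)) * 0.01
edits = [(rng.standard_normal(d), rng.standard_normal(d))
         for _ in range(200)]

k0, v0 = edits[0]
errors = []
for k, v in edits:
    # Simplified sequential ROME update
    W += np.outer(v - W @ k, k) / (k @ k)
    errors.append(np.linalg.norm(W @ k0 - v0))

# The first edit is exact immediately after it is applied...
assert errors[0] < 1e-9
# ...but random keys are not orthogonal, so later edits
# perturb it and the error accumulates
assert errors[-1] > errors[0]
```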
Side Effects and Limitations
Measuring Edit Quality
class EditQualityEvaluator:
"""
Evaluate the quality of knowledge edits.
Three dimensions:
1. Efficacy: does the edit work on the exact prompt?
2. Generalization: does it work on paraphrased prompts?
3. Specificity: does it NOT affect related but
different facts?
Additional concerns:
4. Consistency: are logical implications updated?
(If Eiffel Tower is in Rome, is it in Italy?)
5. General capability: does the edit degrade
performance on unrelated tasks?
"""
def __init__(self, model, tokenizer):
self.model = model
self.tokenizer = tokenizer
def evaluate_edit(self, edit, test_suite):
"""
Comprehensive evaluation of a single edit.
"""
results = {}
# Efficacy
efficacy_prompts = test_suite.get(
"efficacy_prompts", []
)
efficacy_scores = []
for prompt in efficacy_prompts:
prob = self._get_target_probability(
prompt, edit["target_new"]
)
efficacy_scores.append(prob)
results["efficacy"] = float(
np.mean(efficacy_scores)
) if efficacy_scores else 0.0
# Generalization (paraphrase)
para_prompts = test_suite.get(
"paraphrase_prompts", []
)
para_scores = []
for prompt in para_prompts:
prob = self._get_target_probability(
prompt, edit["target_new"]
)
para_scores.append(prob)
results["generalization"] = float(
np.mean(para_scores)
) if para_scores else 0.0
# Specificity (neighborhood)
neighbor_prompts = test_suite.get(
"neighborhood_prompts", []
)
neighbor_scores = []
for prompt, expected in neighbor_prompts:
prob = self._get_target_probability(
prompt, expected
)
neighbor_scores.append(prob)
results["specificity"] = float(
np.mean(neighbor_scores)
) if neighbor_scores else 0.0
# Consistency (logical implications)
consistency_prompts = test_suite.get(
"consistency_prompts", []
)
consistency_scores = []
for prompt, expected in consistency_prompts:
prob = self._get_target_probability(
prompt, expected
)
consistency_scores.append(prob)
results["consistency"] = float(
np.mean(consistency_scores)
) if consistency_scores else 0.0
# Overall score (harmonic mean)
scores = [
v for v in results.values()
if v > 0
]
results["overall"] = (
len(scores) / sum(1.0 / s for s in scores)
if scores and all(s > 0 for s in scores)
else 0.0
)
return results
def _get_target_probability(self, prompt, target):
"""Get probability of target token given prompt."""
return 0.5 # Placeholder
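The harmonic mean in the overall score is a deliberate choice: it punishes an edit that fails on any single dimension far harder than an arithmetic mean would. The per-dimension scores below are hypothetical:

```python
# Hypothetical per-dimension scores for a single edit
scores = {"efficacy": 0.92, "generalization": 0.85,
          "specificity": 0.95, "consistency": 0.40}

vals = list(scores.values())
arithmetic = sum(vals) / len(vals)
harmonic = len(vals) / sum(1.0 / s for s in vals)
print(round(arithmetic, 3))  # → 0.78
print(round(harmonic, 3))    # → 0.688
```

A weak consistency score drags the harmonic mean well below the arithmetic mean, so edits that break logical implications cannot hide behind strong efficacy numbers.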
Knowledge Editing Limitations
| Limitation | Severity | ROME Impact | MEMIT Impact | Mitigation |
|---|---|---|---|---|
| Consistency (logical implications not updated) | High | Not addressed | Not addressed | Chain-of-edits for implications |
| Multi-hop reasoning breaks | High | Severe after 2+ hops | Moderate | Edit the full reasoning chain |
| General capability degradation | Medium | Visible at 100+ edits | Visible at 5000+ edits | Limit edit count |
| Reversibility (undoing edits) | Low | Supported (store delta) | Supported | Save original weights per layer |
| Cross-lingual transfer | Medium | Edits in English do not transfer | Partial transfer | Edit in multiple languages |
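The reversibility row can be sketched concretely: because the edit is an explicit additive delta, undoing it is exact up to floating-point rounding. Dimensions below are made up:

```python
import numpy as np

rng = np.random.default_rng(3)
d = 32                   # made-up dimension
W = rng.standard_normal((d, d))
W_before = W.copy()

k = rng.standard_normal(d)
v_new = rng.standard_normal(d)
delta = np.outer(v_new - W @ k, k) / (k @ k)

W += delta               # apply the edit
assert np.allclose(W @ k, v_new)

W -= delta               # revert with the stored delta
assert np.allclose(W, W_before)
```

This is why the mitigation column suggests saving the original weights (or the deltas) per edited layer.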
Key Takeaways
Knowledge editing enables surgical modification of specific facts in LLMs without retraining. ROME and MEMIT demonstrate that factual associations are localized in MLP layers and can be updated with targeted weight modifications.
The critical findings:
- Facts are localized in MLP layers: Causal tracing shows that factual associations (subject → attribute) are primarily stored in MLP layers at the last subject token position in mid-to-late transformer blocks. This localization makes surgical editing possible.
- ROME achieves 92% single-edit success: A rank-one update to the MLP projection matrix at the critical layer changes the model’s output for the target fact, with 85% generalization to paraphrased prompts and 95% specificity (it does not affect neighboring facts).
- MEMIT scales to thousands of edits: By distributing edits across multiple layers and computing updates jointly, MEMIT maintains 80% success at 1,000 simultaneous edits where sequential ROME drops to 45%.
- Consistency is the unsolved problem: Editing “The Eiffel Tower is in Rome” does not automatically update “The Eiffel Tower is in the country of [Italy]” or “The famous tower in Paris is […]”. Logical implications require additional edits, and the edit graph grows combinatorially.
- Knowledge editing does not replace retraining for large updates: For fewer than 1,000 targeted fact updates, MEMIT is practical and cost-effective. For larger knowledge updates (a new domain, a comprehensive factual refresh), fine-tuning or retraining remains necessary. Knowledge editing is best viewed as a tool for quick, targeted corrections.