TransformerLens: Mechanistic Interpretability for Transformers
TransformerLens is the de facto standard library for mechanistic interpretability research on GPT-style language models. Created by Neel Nanda and maintained by Bryce Meyer, it provides clean interfaces to inspect and manipulate model internals via HookPoints on every activation.
GitHub: TransformerLensOrg/TransformerLens (2,900+ stars)
When to Use TransformerLens
Use TransformerLens when you need to:
- Reverse-engineer algorithms learned during training
- Perform activation patching / causal tracing experiments
- Study attention patterns and information flow
- Analyze circuits (e.g., induction heads, the IOI circuit)
- Cache and inspect intermediate activations
- Apply direct logit attribution
Consider alternatives when:
- You need to work with non-transformer architectures → use nnsight or pyvene
- You want to train or analyze Sparse Autoencoders → use SAELens
- You need remote execution on massive models → use nnsight with NDIF
- You want higher-level causal intervention abstractions → use pyvene
Installation
pip install transformer-lens
For the development version:
pip install git+https://github.com/TransformerLensOrg/TransformerLens
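To verify the install, a minimal smoke test (the first run downloads GPT-2 small, roughly 500 MB, from Hugging Face):

```python
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2-small")
print(model.cfg.n_layers, model.cfg.n_heads, model.cfg.d_model)  # 12 12 768
```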
Core Concepts
HookedTransformer
The main class that wraps transformer models with HookPoints on every activation:
from transformer_lens import HookedTransformer

# Load a model
model = HookedTransformer.from_pretrained("gpt2-small")

# For gated models (LLaMA, Mistral), authenticate with Hugging Face first
import os
os.environ["HF_TOKEN"] = "your_token"
model = HookedTransformer.from_pretrained("meta-llama/Llama-2-7b-hf")
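from_pretrained applies several weight-processing simplifications by default (folding LayerNorm weights into the following linear layers, centering weights that write to the residual stream, centering the unembed). A sketch of disabling them and choosing a device; the parameter names below match the documented defaults, but double-check them against your installed version:

```python
import torch
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained(
    "gpt2-small",
    fold_ln=False,                 # keep LayerNorm scale/bias as separate parameters
    center_writing_weights=False,  # don't mean-center weights writing to the residual stream
    center_unembed=False,          # don't mean-center the unembedding matrix
    device="cuda" if torch.cuda.is_available() else "cpu",
)
```

The defaults simplify interpretability math without changing model outputs, so only disable them if you need the raw checkpoint weights.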
Supported Models (50+)
| Family | Models |
|---|---|
| GPT-2 | gpt2, gpt2-medium, gpt2-large, gpt2-xl |
| LLaMA | llama-7b, llama-13b, llama-2-7b, llama-2-13b |
| EleutherAI | pythia-70m to pythia-12b, gpt-neo, gpt-j-6b |
| Mistral | mistral-7b, mixtral-8x7b |
| Others | phi, qwen, opt, gemma |
Activation Caching
Run the model and cache all intermediate activations:
# Get all activations
tokens = model.to_tokens("The Eiffel Tower is in")
logits, cache = model.run_with_cache(tokens)

# Access specific activations
residual = cache["resid_post", 5]   # Layer 5 residual stream
attn_pattern = cache["pattern", 3]  # Layer 3 attention pattern
mlp_out = cache["mlp_out", 7]       # Layer 7 MLP output

# Filter which activations to cache (saves memory)
logits, cache = model.run_with_cache(
    tokens,
    names_filter=lambda name: "resid_post" in name,
)
ActivationCache Keys
| Key Pattern | Shape | Description |
|---|---|---|
| ("resid_pre", layer) | [batch, pos, d_model] | Residual stream before attention |
| ("resid_mid", layer) | [batch, pos, d_model] | Residual stream after attention, before MLP |
| ("resid_post", layer) | [batch, pos, d_model] | Residual stream after MLP |
| ("attn_out", layer) | [batch, pos, d_model] | Attention output |
| ("mlp_out", layer) | [batch, pos, d_model] | MLP output |
| ("pattern", layer) | [batch, head, q_pos, k_pos] | Attention pattern (post-softmax) |
| ("q", layer) | [batch, pos, head, d_head] | Query vectors |
| ("k", layer) | [batch, pos, head, d_head] | Key vectors |
| ("v", layer) | [batch, pos, head, d_head] | Value vectors |
Workflow 1: Activation Patching (Causal Tracing)
Identify which activations causally affect model output by patching clean activations into corrupted runs.
Step-by-Step
from transformer_lens import HookedTransformer, patching
import torch

model = HookedTransformer.from_pretrained("gpt2-small")

# 1. Define clean and corrupted prompts
# (they must tokenize to the same length; check with model.to_str_tokens)
clean_prompt = "The Eiffel Tower is in the city of"
corrupted_prompt = "The Colosseum is in the city of"

clean_tokens = model.to_tokens(clean_prompt)
corrupted_tokens = model.to_tokens(corrupted_prompt)

# 2. Get clean activations
_, clean_cache = model.run_with_cache(clean_tokens)

# 3. Define metric (e.g., logit difference at the final position)
paris_token = model.to_single_token(" Paris")
rome_token = model.to_single_token(" Rome")

def metric(logits):
    return logits[0, -1, paris_token] - logits[0, -1, rome_token]

# 4. Patch each position and layer
results = torch.zeros(model.cfg.n_layers, clean_tokens.shape[1])

for layer in range(model.cfg.n_layers):
    for pos in range(clean_tokens.shape[1]):
        def patch_hook(activation, hook, pos=pos):
            # Overwrite the corrupted activation at this position with the clean one
            activation[0, pos] = clean_cache[hook.name][0, pos]
            return activation

        patched_logits = model.run_with_hooks(
            corrupted_tokens,
            fwd_hooks=[(f"blocks.{layer}.hook_resid_post", patch_hook)],
        )
        results[layer, pos] = metric(patched_logits)

# 5. Visualize results (layer x position heatmap)
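The demos usually plot this with plotly or circuitsvis; a minimal matplotlib sketch of the layer × position heatmap, continuing from the results tensor above:

```python
import matplotlib.pyplot as plt

str_tokens = model.to_str_tokens(clean_tokens)  # token strings for the x-axis

fig, ax = plt.subplots(figsize=(8, 4))
im = ax.imshow(results.detach().cpu().numpy(), aspect="auto", cmap="RdBu")
ax.set_xticks(range(len(str_tokens)))
ax.set_xticklabels(str_tokens, rotation=45, ha="right")
ax.set_xlabel("Position")
ax.set_ylabel("Layer")
ax.set_title("Logit difference after patching resid_post")
fig.colorbar(im, ax=ax)
plt.show()
```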
Checklist
- Define clean and corrupted inputs that differ minimally
- Choose a metric that captures the behavior difference
- Cache clean activations
- Systematically patch each (layer, position) combination
- Visualize results as a heatmap
- Identify causal hotspots
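The manual loop above shows the mechanics; the transformer_lens.patching module (imported but unused in the example) bundles the same sweeps into helper functions, with variants for attention head outputs and other activations. A sketch using the pre-block residual stream helper; the function name and argument order match recent versions, but confirm against your install:

```python
from transformer_lens import patching

# Patches the clean resid_pre into the corrupted run at every (layer, position)
# and evaluates `metric` on the patched logits
resid_pre_results = patching.get_act_patch_resid_pre(
    model, corrupted_tokens, clean_cache, metric
)
print(resid_pre_results.shape)  # [n_layers, n_pos]
```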
Workflow 2: Circuit Analysis (Indirect Object Identification)
Replicate the IOI circuit discovery from "Interpretability in the Wild".
Step-by-Step
from transformer_lens import HookedTransformer
import torch

model = HookedTransformer.from_pretrained("gpt2-small")

# IOI task: "When John and Mary went to the store, Mary gave a bottle to"
# The model should predict " John" (the indirect object), not " Mary" (the subject)
prompt = "When John and Mary went to the store, Mary gave a bottle to"
tokens = model.to_tokens(prompt)

# 1. Get baseline logits
logits, cache = model.run_with_cache(tokens)

john_token = model.to_single_token(" John")
mary_token = model.to_single_token(" Mary")

# 2. Compute logit difference (IO - S)
logit_diff = logits[0, -1, john_token] - logits[0, -1, mary_token]
print(f"Logit difference: {logit_diff.item():.3f}")

# 3. Direct logit attribution by head
def get_head_contribution(layer, head):
    # Project the head's output at the final position onto the unembedding
    head_out = cache["z", layer][0, :, head, :]  # [pos, d_head]
    W_O = model.W_O[layer, head]                 # [d_head, d_model]
    W_U = model.W_U                              # [d_model, d_vocab]

    # Head contribution to logits at the final position
    # (approximate: ignores the final LayerNorm's scaling)
    contribution = head_out[-1] @ W_O @ W_U
    return contribution[john_token] - contribution[mary_token]

# 4. Map all heads
head_contributions = torch.zeros(model.cfg.n_layers, model.cfg.n_heads)
for layer in range(model.cfg.n_layers):
    for head in range(model.cfg.n_heads):
        head_contributions[layer, head] = get_head_contribution(layer, head)

# 5. Identify top contributing heads (name movers, backup name movers)
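A sketch of turning the (layer, head) score matrix into a ranked list of candidate name-mover heads, continuing from the code above:

```python
# Top 5 heads by direct contribution to the IO - S logit difference
top_vals, top_idx = torch.topk(head_contributions.flatten(), k=5)
for val, idx in zip(top_vals.tolist(), top_idx.tolist()):
    layer, head = divmod(idx, model.cfg.n_heads)
    print(f"L{layer}H{head}: {val:.3f}")
```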
Checklist
- Set up the task with clear IO/S tokens
- Compute the baseline logit difference
- Decompose by attention head contributions
- Identify key circuit components (name movers, S-inhibition heads, induction heads)
- Validate with ablation experiments (see the sketch below)
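A minimal validation sketch, continuing from the workflow code above: zero-ablate one candidate head's output (hook_z) and check how much of the logit difference disappears. Zero-ablation is the crudest option; the IOI paper uses mean-ablation over a reference distribution, but the hook mechanics are the same. The specific head below is only an illustrative placeholder:

```python
def make_ablate_head_hook(head):
    def ablate_head_hook(z, hook):
        # z: [batch, pos, head, d_head]; zero out one head's output
        z[:, :, head, :] = 0.0
        return z
    return ablate_head_hook

layer, head = 9, 9  # substitute one of your top-scoring heads
ablated_logits = model.run_with_hooks(
    tokens,
    fwd_hooks=[(f"blocks.{layer}.attn.hook_z", make_ablate_head_hook(head))],
)
ablated_diff = ablated_logits[0, -1, john_token] - ablated_logits[0, -1, mary_token]
print(f"Logit diff: {logit_diff.item():.3f} -> {ablated_diff.item():.3f}")
```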
Workflow 3: Induction Head Detection
Find induction heads, which implement the [A][B] ... [A] → [B] completion pattern.
from transformer_lens import HookedTransformer
import torch

model = HookedTransformer.from_pretrained("gpt2-small")

# Create a repeated sequence: [A][B][A] -- an induction head should predict [B]
repeated_tokens = torch.tensor([[1000, 2000, 1000]])  # arbitrary token ids

_, cache = model.run_with_cache(repeated_tokens)

# Induction heads attend from the second [A] back to the token after the first [A],
# i.e. from position 2 to position 1
induction_scores = torch.zeros(model.cfg.n_layers, model.cfg.n_heads)

for layer in range(model.cfg.n_layers):
    pattern = cache["pattern", layer][0]        # [head, q_pos, k_pos]
    induction_scores[layer] = pattern[:, 2, 1]  # attention from pos 2 to pos 1

# Heads with high scores are induction head candidates
top_heads = torch.topk(induction_scores.flatten(), k=5)
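A three-token prompt is a noisy probe. The standard recipe (used in the induction heads paper and the TransformerLens/ARENA demos) repeats a longer random sequence and averages the attention each head pays to the position exactly seq_len - 1 tokens back. A sketch continuing from the model above, with seq_len and batch as arbitrary choices:

```python
seq_len, batch = 50, 4
first_half = torch.randint(100, 20000, (batch, seq_len))
rep_tokens = torch.cat([first_half, first_half], dim=1)  # [batch, 2 * seq_len]

_, rep_cache = model.run_with_cache(rep_tokens)

induction_scores = torch.zeros(model.cfg.n_layers, model.cfg.n_heads)
for layer in range(model.cfg.n_layers):
    pattern = rep_cache["pattern", layer]  # [batch, head, q_pos, k_pos]
    # An induction head at query position i (in the second half) attends to key
    # position i - (seq_len - 1): the diagonal with offset 1 - seq_len
    stripe = pattern.diagonal(offset=1 - seq_len, dim1=-2, dim2=-1)
    induction_scores[layer] = stripe.mean(dim=(0, -1))
```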
Common Issues & Solutions
Issue: Hooks persist after debugging
# WRONG: hooks added with add_hook stay attached to the model
model.add_hook("blocks.0.attn.hook_pattern", debug_hook)
model(tokens)  # debug_hook runs
model(tokens)  # debug_hook still runs; it was never removed

# RIGHT: reset hooks when you are done
model.reset_hooks()
# run_with_hooks removes the hooks it adds after the forward pass by default,
# so prefer it for one-off experiments
model.run_with_hooks(tokens, fwd_hooks=[("blocks.0.attn.hook_pattern", debug_hook)])
Issue: Tokenization gotchas
# WRONG: assuming every word is a single token
model.to_tokens("Tim")    # single token
model.to_tokens("Neel")   # becomes "Ne" + "el" (two tokens!)

# RIGHT: check tokenization explicitly
# (to_tokens prepends a BOS token by default; disable it to see just the word)
tokens = model.to_tokens("Neel", prepend_bos=False)
print(model.to_str_tokens(tokens))  # ['Ne', 'el']
Issue: LayerNorm ignored in analysis
# WRONG: ignoring LayerNorm before the MLP
pre_activation = residual @ model.W_in[layer]

# RIGHT: apply the pre-MLP LayerNorm (ln2) first
# (with the default fold_ln=True the LN scale/bias are folded into W_in,
# but the centering and normalization still have to be applied)
ln_out = model.blocks[layer].ln2(residual)
pre_activation = ln_out @ model.W_in[layer]
Issue: Memory explosion with large models
# Use selective caching and keep the cache on the CPU
logits, cache = model.run_with_cache(
    tokens,
    names_filter=lambda n: "resid_post" in n or "pattern" in n,
    device="cpu",  # store cached activations on the CPU
)
Key Classes Reference
| Class | Purpose |
|---|---|
| HookedTransformer | Main model wrapper with hooks |
| ActivationCache | Dictionary-like cache of activations |
| HookedTransformerConfig | Model configuration |
| FactoredMatrix | Efficient factored matrix operations |
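FactoredMatrix appears when composing weight matrices without materializing the full product, most visibly in the per-head OV and QK circuits exposed as model.OV and model.QK. A small sketch; the indexing and the .AB property reflect the documented API, but check the FactoredMatrix reference for the exact set of methods in your version:

```python
from transformer_lens import HookedTransformer

model = HookedTransformer.from_pretrained("gpt2-small")

# model.OV holds W_V @ W_O for every head as a FactoredMatrix,
# stored as low-rank factors rather than the full d_model x d_model product
head_ov = model.OV[5, 3]  # OV circuit for layer 5, head 3
dense = head_ov.AB        # materialize the full matrix when needed
print(dense.shape)        # torch.Size([768, 768])
```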
Integration with SAELens
TransformerLens integrates with SAELens for Sparse Autoencoder analysis:
from transformer_lens import HookedTransformer
from sae_lens import SAE

model = HookedTransformer.from_pretrained("gpt2-small")
# Note: some SAELens versions return a (sae, cfg_dict, sparsity) tuple here
sae = SAE.from_pretrained("gpt2-small-res-jb", "blocks.8.hook_resid_pre")

# Run the model, then encode the cached residual stream with the SAE
tokens = model.to_tokens("Hello world")
_, cache = model.run_with_cache(tokens)
sae_acts = sae.encode(cache["resid_pre", 8])
Reference Documentation
For detailed API documentation, tutorials, and advanced usage, see the references/ folder:
| File | Contents |
|---|---|
| references/README.md | Overview and quick start guide |
| references/api.md | Complete API reference for HookedTransformer, ActivationCache, HookPoints |
| references/tutorials.md | Step-by-step tutorials for activation patching, circuit analysis, logit lens |
External Resources
Tutorials
- Main Demo Notebook
- Activation Patching Demo
- ARENA Mech Interp Course (200+ hours of tutorials)
Papers
- A Mathematical Framework for Transformer Circuits
- In-context Learning and Induction Heads
- Interpretability in the Wild (IOI)
Official Documentation
- Official Docs
- Model Properties Table
- Neel Nanda's Glossary
Version Notes
- v2.0: Removed HookedSAE (moved to SAELens)
- v3.0 (alpha): TransformerBridge for loading any nn.Module