Speculative Decoding: Accelerating LLM Inference
When to Use This Skill
Use Speculative Decoding when you need to:
- Speed up inference by 1.5-3.6× without quality loss
- Reduce latency for real-time applications (chatbots, code generation)
- Optimize throughput for high-volume serving
- Deploy efficiently on limited hardware
- Generate faster without changing model architecture
Key Techniques: Draft model speculative decoding, Medusa (multiple heads), Lookahead Decoding (Jacobi iteration)
Papers: Medusa (arXiv 2401.10774), Lookahead Decoding (ICML 2024), Speculative Decoding Survey (ACL 2024)
Installation
Standard speculative decoding (transformers)
pip install torch transformers accelerate
Medusa (multiple decoding heads)
git clone https://github.com/FasterDecoding/Medusa
cd Medusa
pip install -e .
Lookahead Decoding
git clone https://github.com/hao-ai-lab/LookaheadDecoding
cd LookaheadDecoding
pip install -e .
Optional: vLLM with speculative decoding
pip install vllm
Quick Start
Basic Speculative Decoding (Draft Model)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load target model (large, slow)
target_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b-hf", device_map="auto", torch_dtype=torch.float16
)

# Load draft model (small, fast); it uses the same tokenizer as the target
draft_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", device_map="auto", torch_dtype=torch.float16
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-70b-hf")

# Generate with speculative decoding
prompt = "Explain quantum computing in simple terms:"
inputs = tokenizer(prompt, return_tensors="pt").to(target_model.device)

# Transformers 4.36+ supports assisted generation
outputs = target_model.generate(
    **inputs,
    assistant_model=draft_model,  # Enable speculative decoding
    max_new_tokens=256,
    do_sample=True,
    temperature=0.7,
)

response = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(response)
Medusa (Multiple Decoding Heads)
import torch
from transformers import AutoTokenizer
from medusa.model.medusa_model import MedusaModel

# Load Medusa-enhanced model
model = MedusaModel.from_pretrained(
    "FasterDecoding/medusa-vicuna-7b-v1.3",  # Pre-trained with Medusa heads
    torch_dtype=torch.float16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("FasterDecoding/medusa-vicuna-7b-v1.3")

# Generate with Medusa (2-3× speedup); medusa_generate supports batch size 1
prompt = "Write a Python function to calculate fibonacci numbers:"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to("cuda")

# medusa_generate is a generator that streams the text decoded so far as {"text": ...}
for step in model.medusa_generate(
    input_ids,
    temperature=0.7,
    max_steps=256,             # Decoding steps, not tokens (one step may accept several)
    posterior_threshold=0.09,  # Typical-acceptance threshold (epsilon)
    posterior_alpha=0.3,       # Typical-acceptance scale (delta), about sqrt(posterior_threshold)
):
    response = step["text"]
print(response)
Lookahead Decoding (Jacobi Iteration)
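# Uses the lade package from the LookaheadDecoding repo (per its README; check your installed version)
import os
os.environ["USE_LADE"] = "1"  # The README enables lade through this environment variable

import torch
import lade
from transformers import AutoModelForCausalLM, AutoTokenizer

# Initialize lookahead decoding: patch transformers' generate() and set (N, W, G)
lade.augment_all()
lade.config_lade(
    LEVEL=5,           # N-gram size (N)
    WINDOW_SIZE=15,    # Lookahead window (W)
    GUESS_SET_SIZE=5,  # Number of n-gram guesses verified per step (G)
    DEBUG=0,
)

# Load model
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# Generate (1.5-2.3× speedup) through the usual generate() call
prompt = "Implement quicksort in Python:"
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
output_ids = model.generate(**inputs, max_new_tokens=256)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))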
Core Concepts
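1. Speculative Decoding (Draft Model)
Idea: A small draft model proposes candidate tokens cheaply, and the large target model verifies all of them in one parallel forward pass.
Algorithm:
- The draft model generates K tokens autoregressively
- The target model scores all K tokens in a single forward pass
- Moving left to right, each draft token x is accepted with probability min(1, p_target(x) / p_draft(x))
- At the first rejection, the remaining draft tokens are discarded and a replacement is sampled from the normalized residual max(0, p_target - p_draft); if all K are accepted, one extra token is sampled from the target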
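import torch

@torch.no_grad()
def speculative_decode_step(target_model, draft_model, input_ids, K=4):
    """One round of speculative sampling (batch size 1, no KV cache, temperature 1)."""
    n = input_ids.shape[1]
    # 1. Draft model proposes K tokens autoregressively, keeping its distributions q_i
    draft_ids, draft_probs = input_ids, []
    for _ in range(K):
        q = torch.softmax(draft_model(draft_ids).logits[0, -1].float(), dim=-1)
        tok = torch.multinomial(q, 1)
        draft_probs.append(q)
        draft_ids = torch.cat([draft_ids, tok[None]], dim=1)
    # 2. Target model scores every drafted position in one forward pass (K + 1 rows)
    target_probs = torch.softmax(target_model(draft_ids).logits[0, n - 1:].float(), dim=-1)
    # 3. Accept draft token i with probability min(1, p_i(x) / q_i(x))
    for i in range(K):
        x, p, q = draft_ids[0, n + i], target_probs[i], draft_probs[i]
        if torch.rand((), device=p.device) < torch.clamp(p[x] / q[x], max=1.0):
            continue
        # 4. First rejection: resample from the normalized residual max(0, p - q)
        residual = torch.clamp(p - q, min=0)
        new_tok = torch.multinomial(residual / residual.sum(), 1)
        return torch.cat([draft_ids[:, : n + i], new_tok[None]], dim=1)
    # All K accepted: sample one bonus token from the target's last row
    bonus = torch.multinomial(target_probs[K], 1)
    return torch.cat([draft_ids, bonus[None]], dim=1)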
Performance:
- Speedup: 1.5-2× with a good draft model
- Zero quality loss: the output distribution is identical to sampling from the target model
- Best when the draft model is 5-10× smaller than the target
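The "zero quality loss" claim follows from the accept/resample rule: a drafted token x is kept with probability min(1, p(x)/q(x)), and on rejection a replacement is drawn from the normalized residual max(0, p - q). The sketch below computes the resulting output distribution exactly for a toy vocabulary (the distributions `p` and `q` are made up for illustration) so it can be compared against the target `p`.

```python
def speculative_output_dist(p, q):
    """Exact distribution of the first token emitted by one speculative-sampling round.

    p: target-model probabilities, q: draft-model probabilities (same vocabulary).
    """
    vocab = range(len(p))
    # P(draft proposes x and it is accepted) = q(x) * min(1, p(x)/q(x)) = min(p(x), q(x))
    accepted = [min(p[x], q[x]) for x in vocab]
    reject_prob = 1.0 - sum(accepted)
    # On rejection, resample from the normalized residual max(0, p - q)
    residual = [max(0.0, p[x] - q[x]) for x in vocab]
    total = sum(residual)
    if total == 0.0:  # p == q: a draft token is never rejected
        return accepted
    return [accepted[x] + reject_prob * residual[x] / total for x in vocab]


p = [0.5, 0.3, 0.2]  # toy target distribution
q = [0.2, 0.2, 0.6]  # toy draft distribution
print(speculative_output_dist(p, q))
```

The reason it matches: the rejection probability equals the total residual mass, so each entry reduces to min(p, q) + max(0, p - q) = p.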
2. Medusa (Multiple Decoding Heads)
Source: arXiv 2401.10774 (2024)
Innovation: Add extra decoding heads to an existing model so that it predicts several future tokens itself, with no separate draft model.
Architecture:
Input → Base LLM (frozen in Medusa-1) → Last hidden state
    ├→ Original LM head (predicts token t+1)
    ├→ Medusa head 1 (predicts token t+2)
    ├→ Medusa head 2 (predicts token t+3)
    ├→ Medusa head 3 (predicts token t+4)
    └→ Medusa head 4 (predicts token t+5)
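Per the Medusa paper, each head is a single residual block followed by a projection to the vocabulary: head k computes softmax(W2 · (SiLU(W1 · h) + h)) from the last hidden state h. A minimal PyTorch sketch; the class name `MedusaHead` and the toy sizes are illustrative assumptions, and the Medusa repo's own module may differ in details such as initialization.

```python
import torch
import torch.nn as nn


class MedusaHead(nn.Module):
    """One Medusa head: residual SiLU block, then a linear map to vocabulary logits."""

    def __init__(self, hidden_size, vocab_size):
        super().__init__()
        self.block = nn.Linear(hidden_size, hidden_size)  # W1
        self.act = nn.SiLU()
        self.lm_proj = nn.Linear(hidden_size, vocab_size, bias=False)  # W2

    def forward(self, h):
        # Logits for this head's future position; softmax turns them into probabilities
        return self.lm_proj(h + self.act(self.block(h)))


hidden_size, vocab_size, num_heads = 64, 100, 4  # toy sizes
heads = nn.ModuleList([MedusaHead(hidden_size, vocab_size) for _ in range(num_heads)])
h = torch.randn(2, 10, hidden_size)  # (batch, seq, hidden) last hidden states
logits = torch.stack([head(h) for head in heads])  # (num_heads, batch, seq, vocab)
print(logits.shape)
```

All heads read the same hidden state, so predicting four extra positions costs only four small projections on top of the base model's forward pass.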
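Training:
- Medusa-1: freeze the base LLM and train only the heads; 2.2× speedup, lossless (base model unchanged)
- Medusa-2: fine-tune the base LLM and heads together; 2.3-3.6× speedup from more accurate heads, using a training recipe that preserves the base model's quality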
Tree-based Attention:
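Medusa builds a tree of candidate continuations from each head's top-k predictions and verifies them with a tree-structured attention mask.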
Example: Predict 2 steps ahead with top-2 per step
Root
/ \
T1a T1b (Step 1: 2 candidates)
/ \ / \
T2a T2b T2c T2d (Step 2: 4 candidates total)
A single forward pass with this mask evaluates every path in the tree, and the longest accepted path is kept.
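To score every path in one pass, each candidate token may attend only to itself and its ancestors in the tree (plus the shared prompt prefix, omitted here). A minimal sketch, assuming the tree is given as a list of paths in the same style as the `medusa_choices` parameter (each entry lists the top-k index chosen at each step); `tree_attention_mask` is a hypothetical helper, not the Medusa repo's function.

```python
def tree_attention_mask(paths):
    """Build a tree attention mask over candidate nodes.

    paths: list of tuples; a node's path is the sequence of top-k indices taken
    from the root, e.g. (0,) = T1a and (0, 1) = T2b. Every prefix of a path must
    itself be listed. mask[i][j] is True when node i may attend to node j.
    """
    nodes = sorted({tuple(p) for p in paths}, key=lambda p: (len(p), p))
    index = {p: i for i, p in enumerate(nodes)}
    mask = [[False] * len(nodes) for _ in nodes]
    for p, i in index.items():
        for depth in range(1, len(p) + 1):
            ancestor = p[:depth]  # equals p itself when depth == len(p)
            mask[i][index[ancestor]] = True
    return nodes, mask


# The 2-step, top-2 tree from the diagram: T1a, T1b, then T2a..T2d
paths = [(0,), (1,), (0, 0), (0, 1), (1, 0), (1, 1)]
nodes, mask = tree_attention_mask(paths)
for p, row in zip(nodes, mask):
    print(p, "".join("1" if m else "." for m in row))
```

Because T2a and T2c cannot see each other, the six candidates behave like four independent two-token continuations while sharing one batch of computation.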
Advantages:
- No separate draft model needed
- Minimal training (only the heads, for Medusa-1)
- Can be added to any decoder-only LLM, though the heads must be trained for each base model
3. Lookahead Decoding (Jacobi Iteration)
Source: ICML 2024
Core idea: Reformulate autoregressive decoding as solving a system of nonlinear equations, then solve that system in parallel with Jacobi iteration.
Mathematical formulation:
Traditional (sequential): $y_t = f(x, y_1, \dots, y_{t-1})$
Jacobi (parallel, all $t$ at once): $y_t^{(k+1)} = f(x, y_1^{(k)}, \dots, y_{t-1}^{(k)})$
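A toy illustration of why Jacobi iteration recovers greedy decoding: with a deterministic next-token function f, updating all positions in parallel from the previous iterate reaches the same sequence as sequential decoding, and position t is guaranteed correct after at most t updates. The function `f` here is a made-up stand-in for a greedy LLM step, not a real model.

```python
def f(prefix):
    """Toy deterministic 'model': next token as a function of the whole prefix."""
    return (sum(prefix) * 31 + len(prefix)) % 50


def sequential_decode(x, n):
    """Ordinary greedy decoding: y_t = f(x, y_1, ..., y_{t-1}), one position at a time."""
    y = []
    for _ in range(n):
        y.append(f(x + y))
    return y


def jacobi_decode(x, n, init_token=0):
    """Jacobi decoding: update all n positions in parallel until a fixed point.

    Returns the decoded tokens and the number of parallel updates used.
    """
    y = [init_token] * n  # arbitrary initial guess y^(0)
    k = 0
    while k < n:
        k += 1
        # y_t^(k+1) = f(x, y_1^(k), ..., y_{t-1}^(k)) for every t at once
        new_y = [f(x + y[:t]) for t in range(n)]
        converged = new_y == y
        y = new_y
        if converged:
            break
    return y, k


x = [3, 1, 4]
y_par, updates = jacobi_decode(x, 8)
print(y_par == sequential_decode(x, 8), updates)
```

In a real LLM each parallel update is one batched forward pass, so a speedup only appears when several positions become correct in the same update; the Lookahead Decoding authors report that plain Jacobi rarely manages this, which motivates caching n-grams from the Jacobi trajectory.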
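Two branches:
- Lookahead branch: generate n-grams in parallel from the Jacobi trajectory
  - Window size W: how many future positions are updated in parallel
  - N-gram size N: length of the collected n-grams (the window keeps the last N-1 Jacobi steps)
- Verification branch: verify promising n-grams
  - Select cached n-grams whose first token matches the last generated token
  - Accept the longest prefix that the model confirms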
class LookaheadDecoding:
    """Structural sketch of one lookahead step. generate_ngrams, verify and
    generate_next are placeholder methods, not a real library API."""

    def __init__(self, model, window_size=15, ngram_size=5):
        self.model = model
        self.W = window_size  # Lookahead window
        self.N = ngram_size   # N-gram size
        self.pool = {}        # N-gram pool: first token -> cached n-grams

    def generate_step(self, tokens):
        # Lookahead branch: one Jacobi update over the W-wide window yields new n-grams
        for ngram in self.generate_ngrams(tokens, window=self.W, length=self.N):
            self.pool.setdefault(ngram[0], []).append(ngram)

        # Verification branch: check cached n-grams whose first token matches the last input;
        # verify() returns the longest prefix of the continuation that the model confirms
        # (the real method runs both branches in the same forward pass)
        verified = [self.verify(tokens, ngram[1:]) for ngram in self.pool.get(tokens[-1], [])]

        # Accept the longest verified continuation; otherwise take one ordinary decoding step
        best = max(verified, key=len, default=[])
        return best if best else [self.generate_next(tokens)]
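A runnable toy of the verification branch alone, assuming greedy decoding: n-grams produced by the lookahead branch are cached in a pool keyed by their first token, the candidates are the cached n-grams that start with the last accepted token, and each candidate is accepted up to the first position where it disagrees with the model's greedy choice. The `next_token` function and the pool contents are invented for illustration.

```python
from collections import defaultdict


def next_token(tokens):
    """Toy greedy 'model': continues an arithmetic sequence with step 1."""
    return tokens[-1] + 1


def verify(tokens, continuation):
    """Longest prefix of `continuation` that greedy decoding would also produce."""
    accepted = []
    for tok in continuation:
        if next_token(tokens + accepted) != tok:
            break
        accepted.append(tok)
    return accepted


def verification_step(tokens, pool):
    # Candidates: cached n-grams whose first token equals the last accepted token
    candidates = pool.get(tokens[-1], [])
    best = max((verify(tokens, ngram[1:]) for ngram in candidates), key=len, default=[])
    # Always make progress: fall back to one ordinary greedy step
    return best if best else [next_token(tokens)]


pool = defaultdict(list)
for ngram in ([5, 6, 7, 8], [5, 6, 9, 10], [2, 3, 4, 5]):  # n-grams (N = 4) from the lookahead branch
    pool[ngram[0]].append(ngram)

print(verification_step([3, 4, 5], pool))  # → [6, 7, 8]
```

Here three tokens are accepted in one step because a cached n-gram happened to match the model's own continuation; with no matching n-gram the step degrades gracefully to ordinary decoding.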
Performance:
- Speedup: 1.5-2.3× (up to 3.6× for code generation)
- No draft model or training needed
- Works out-of-the-box with any model
Method Comparison
| Method | Speedup | Training Needed | Draft Model | Quality Loss |
| --- | --- | --- | --- | --- |
| Draft Model Speculative | 1.5-2× | No | Yes (external) | None |
| Medusa | 2-3.6× | Minimal (heads only) | No (built-in heads) | None |
| Lookahead | 1.5-2.3× | None | No | None |
| Naive Batching | 1.2-1.5× | No | No | None |
Advanced Patterns
Training Medusa Heads
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM

# 1. Load base model
base_model = AutoModelForCausalLM.from_pretrained(
    "lmsys/vicuna-7b-v1.3", torch_dtype=torch.float16, device_map="auto"
)
device = base_model.device
hidden_size, vocab_size = base_model.config.hidden_size, base_model.config.vocab_size

# 2. Add Medusa heads (simplified to one linear layer each; the paper's heads
#    also contain a residual SiLU block). Heads stay in fp32 for stable training.
num_heads = 4
medusa_heads = nn.ModuleList([
    nn.Linear(hidden_size, vocab_size, bias=False) for _ in range(num_heads)
]).to(device)

# 3. Training loop (freeze base model for Medusa-1)
for param in base_model.parameters():
    param.requires_grad = False  # Freeze base

optimizer = torch.optim.Adam(medusa_heads.parameters(), lr=1e-3)

for batch in dataloader:  # assumed: yields dicts of input_ids / attention_mask tensors on `device`
    # Forward pass through the frozen base model
    with torch.no_grad():
        out = base_model(**batch, output_hidden_states=True)
    hidden_states = out.hidden_states[-1].float()

    # Predict future tokens with each head
    loss = 0.0
    for i, head in enumerate(medusa_heads):
        shift = i + 2  # Head i predicts token t+i+2; the original LM head covers t+1
        logits = head(hidden_states[:, :-shift])
        target = batch["input_ids"][:, shift:]
        loss = loss + F.cross_entropy(logits.reshape(-1, vocab_size), target.reshape(-1))

    # Backward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
Hybrid: Speculative + Medusa
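# Use a Medusa model as the draft model for speculative decoding (conceptual sketch)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from medusa.model.medusa_model import MedusaModel

draft_medusa = MedusaModel.from_pretrained(
    "FasterDecoding/medusa-vicuna-7b-v1.3", torch_dtype=torch.float16, device_map="auto"
)
target_model = AutoModelForCausalLM.from_pretrained(
    "lmsys/vicuna-33b-v1.3", torch_dtype=torch.float16, device_map="auto"
)
tokenizer = AutoTokenizer.from_pretrained("lmsys/vicuna-33b-v1.3")

prompt = "Explain quantum computing in simple terms:"
inputs = tokenizer(prompt, return_tensors="pt").to(target_model.device)

# Target verifies the draft's proposals in a single forward pass per round.
# Caution: assisted generation treats the assistant as an ordinary causal LM,
# so the Medusa heads may go unused here; a true hybrid needs a custom decoding loop.
outputs = target_model.generate(**inputs, assistant_model=draft_medusa, max_new_tokens=256)

# Intended benefit: Medusa speed + large model quality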
Optimal Draft Model Selection
def select_draft_model(target_model_size):
    """Select a draft model size for speculative decoding."""
    # Rule: the draft should be 5-10× smaller
    if target_model_size == "70B":
        return "7B"  # 10× smaller
    elif target_model_size == "33B":
        return "7B"  # ~5× smaller
    elif target_model_size == "13B":
        return "1B"  # 13× smaller
    else:
        return None  # Target too small, use Medusa/Lookahead instead

# Example
draft = select_draft_model("70B")
# Returns "7B" → use Llama-2-7b as the draft for Llama-2-70b
Best Practices
1. Choose the Right Method
def choose_method(deploying_new_model, have_small_version_of_model, want_plug_and_play):
    # New deployment → Medusa (best overall speedup, no draft model)
    if deploying_new_model:
        return "Medusa"
    # Existing deployment with a small model available → draft-model speculative decoding
    elif have_small_version_of_model:
        return "Draft Model Speculative"
    # Want zero training/setup → Lookahead
    elif want_plug_and_play:
        return "Lookahead Decoding"
    return None
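2. Hyperparameter Tuning
Draft Model Speculative:
# K = number of speculative tokens per round (pick one)
K = 4  # Good default
# K = 2  # Conservative (higher acceptance rate)
# K = 8  # Aggressive (lower acceptance rate, but more tokens per round when accepted)
Rule: Larger K gives more speedup only if the draft model's acceptance rate is high; otherwise the extra draft tokens are wasted work.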
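To reason about K quantitatively, the analysis in the original speculative decoding paper (Leviathan et al., 2023) assumes each draft token is accepted independently with rate α. A round then yields (1 - α^(K+1)) / (1 - α) tokens on average, and if one draft forward pass costs c times a target pass, the expected wall-clock improvement is that quantity divided by (K·c + 1). The helpers below only evaluate these formulas; real acceptance is not i.i.d., so treat the numbers as rough guides.

```python
def expected_tokens_per_round(alpha, k):
    """Mean tokens produced per target forward pass, assuming i.i.d. acceptance rate alpha."""
    if alpha >= 1.0:
        return k + 1.0  # every draft token accepted, plus one token from the target
    return (1 - alpha ** (k + 1)) / (1 - alpha)


def expected_speedup(alpha, k, c):
    """Expected walltime improvement; c = cost of one draft pass / cost of one target pass."""
    return expected_tokens_per_round(alpha, k) / (k * c + 1)


for alpha in (0.6, 0.8):
    print(alpha, [round(expected_speedup(alpha, k, c=0.1), 2) for k in (2, 4, 8)])
```

With c = 0.1, the expected speedup for α = 0.6 drops from about 1.65 at K = 4 to about 1.37 at K = 8, while for α = 0.8 it stays near 2.4, which is the "larger K only helps with a good draft" rule in numbers.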
Medusa:
# Posterior threshold (epsilon) for typical acceptance, used when temperature > 0
posterior_threshold = 0.09  # Standard (from paper)
# posterior_threshold = 0.05  # More permissive: accepts more candidates (faster, may degrade quality)
# posterior_threshold = 0.15  # Stricter: accepts fewer candidates (slower, more conservative)

# Candidate tree: each entry is a path of top-k indices, one per Medusa head
medusa_choices = [[0], [0, 0], [0, 1], [0, 0, 0]]  # Small example tree: 4 nodes, depth 3
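In the Medusa paper's typical acceptance scheme, a candidate token x is accepted when the base model's probability for it exceeds min(ε, δ·exp(-H)), where H is the entropy of the base model's next-token distribution; `posterior_threshold` plays the role of ε and `posterior_alpha` of δ. A minimal single-position sketch of that test, with made-up distributions; `typical_accept` is a hypothetical helper, not the repo's function.

```python
import math


def typical_accept(probs, token, posterior_threshold=0.09, posterior_alpha=0.3):
    """Typical-acceptance test for one candidate token.

    probs: base-model next-token distribution at this position.
    """
    entropy = -sum(p * math.log(p) for p in probs if p > 0)
    threshold = min(posterior_threshold, posterior_alpha * math.exp(-entropy))
    return probs[token] > threshold


confident = [0.90, 0.05, 0.05]  # low entropy: the threshold stays at epsilon
uncertain = [0.25, 0.25, 0.25, 0.25]  # high entropy: the threshold drops
print(typical_accept(confident, 1), typical_accept(uncertain, 3))
print(typical_accept(confident, 1, posterior_threshold=0.04))  # lower epsilon accepts more
```

Because the threshold is a minimum, lowering ε can only lower it, which is why a smaller `posterior_threshold` makes acceptance more permissive rather than more conservative.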
Lookahead:
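# W = lookahead window size (how far ahead to guess), N = n-gram size
# Larger windows cost more compute per step, so bigger models use smaller W
# 7B model (most spare compute per step)
W, N = 15, 5
# 13B model (moderate)
W, N = 10, 5
# 33B+ model (least spare compute per step)
W, N = 7, 5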
3. Production Deployment
vLLM with speculative decoding
from vllm import LLM, SamplingParams

# Initialize with draft model. These keyword arguments match older vLLM releases;
# recent releases take a speculative_config dict instead, so check your version.
llm = LLM(
    model="meta-llama/Llama-2-70b-hf",
    speculative_model="meta-llama/Llama-2-7b-hf",  # Draft model
    num_speculative_tokens=5,
    use_v2_block_manager=True,
    # tensor_parallel_size=4,  # 70B fp16 weights (~140 GB) need several GPUs
)

# Generate
prompts = ["Tell me about AI:", "Explain quantum physics:"]
sampling_params = SamplingParams(temperature=0.7, max_tokens=256)

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(output.outputs[0].text)
Resources
- Medusa Paper: https://arxiv.org/abs/2401.10774
- Medusa GitHub: https://github.com/FasterDecoding/Medusa
- Lookahead Decoding (ICML 2024): https://lmsys.org/blog/2023-11-21-lookahead-decoding/
- Lookahead GitHub: https://github.com/hao-ai-lab/LookaheadDecoding
- Speculative Decoding Survey (ACL 2024): https://aclanthology.org/2024.findings-acl.456.pdf
- Comprehensive Survey: https://arxiv.org/abs/2401.07851
See Also
- references/draft_model.md: Draft model selection and training
- references/medusa.md: Medusa architecture and training
- references/lookahead.md: Lookahead decoding implementation details