# Torch Pipeline Parallelism

## Overview
This skill provides guidance for implementing pipeline parallelism in PyTorch for distributed model training. Pipeline parallelism partitions a model across multiple devices/ranks, where each rank processes a subset of layers and communicates activations/gradients with neighboring ranks.
## Key Concepts

### Pipeline Parallelism Patterns
- AFAB (All-Forward-All-Backward): Process all microbatch forwards first, cache activations, then process all backwards. This is the most common pattern for pipeline parallelism.
- 1F1B (One-Forward-One-Backward): Interleave forward and backward passes for better memory efficiency at the cost of more complex scheduling.
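To make the two orderings concrete, here is a minimal sketch of both schedules for a single rank. The helper names and the `warmup` parameter are illustrative, not part of any library; in a standard 1F1B setup the warmup count would be `world_size - rank - 1`.

```python
def afab_schedule(num_microbatches):
    # AFAB: all forwards first, then all backwards (in reverse order,
    # matching how cached activations are consumed).
    forwards = [("F", i) for i in range(num_microbatches)]
    backwards = [("B", i) for i in reversed(range(num_microbatches))]
    return forwards + backwards


def one_f_one_b_schedule(num_microbatches, warmup):
    # 1F1B: `warmup` forwards up front, then strictly alternate forward
    # and backward, then drain the remaining backwards. Peak cached
    # activations stay at warmup + 1 instead of num_microbatches.
    warmup = min(warmup, num_microbatches)
    schedule = [("F", i) for i in range(warmup)]
    next_fwd, next_bwd = warmup, 0
    while next_bwd < num_microbatches:
        if next_fwd < num_microbatches:
            schedule.append(("F", next_fwd))
            next_fwd += 1
        schedule.append(("B", next_bwd))
        next_bwd += 1
    return schedule
```

Comparing `afab_schedule(3)` against `one_f_one_b_schedule(3, 1)` shows why 1F1B caches fewer activations: AFAB holds all three forward activations before the first backward, while 1F1B never has more than two in flight.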
### Critical Components

- Model Partitioning: Divide model layers across ranks
- Activation Communication: Send/receive hidden states between ranks
- Gradient Communication: Send/receive gradients during the backward pass
- Activation Caching: Store activations for backward pass computation
## Implementation Approach

### Step 1: Understand Model Architecture First
Before implementing, thoroughly understand the model being parallelized:
- Identify the layer structure (e.g., `model.model.layers` for LLaMA)
- Understand embedding layers (input embeddings, position embeddings)
- Identify the output head (e.g., `lm_head` for language models)
- Note any shared parameters or tied weights
### Step 2: Plan Tensor Shape Handling

Critical distinction between ranks:
- Rank 0 (first stage): Receives integer token IDs with shape `[batch, seq_len]`
- Intermediate ranks: Receive hidden states with shape `[batch, seq_len, hidden_size]`
- Final rank: Must apply the output head and compute the loss
Create explicit shape handling logic for each case rather than assuming uniform input types.
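One way to make that shape handling explicit is a small dispatch helper; the function and parameter names below are hypothetical, chosen to mirror the rank roles described above:

```python
import torch

def prepare_stage_input(rank, token_ids=None, recv_hidden=None, embed_tokens=None):
    # Hypothetical helper: rank 0 embeds integer token IDs [batch, seq_len]
    # taken from the dataloader; every other stage consumes float hidden
    # states [batch, seq_len, hidden_size] received from the previous rank.
    if rank == 0:
        assert token_ids.dtype == torch.long, "first stage expects token IDs"
        return embed_tokens(token_ids)
    # A received tensor is a fresh leaf, so requires_grad_(True) is safe
    # here and lets us read .grad later to send upstream.
    return recv_hidden.requires_grad_(True)
```

The explicit `rank == 0` branch keeps the integer-vs-float distinction in one place instead of scattering dtype checks across the forward pass.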
### Step 3: Design Communication Strategy

Use `torch.distributed.P2POp` for batched send/receive operations.

Preferred: batched P2P operations:

```python
import torch.distributed as dist

ops = []
if rank > 0:
    # Receive activations from the previous stage.
    ops.append(dist.P2POp(dist.irecv, recv_tensor, rank - 1))
if rank < world_size - 1:
    # Send activations to the next stage.
    ops.append(dist.P2POp(dist.isend, send_tensor, rank + 1))
reqs = dist.batch_isend_irecv(ops)
for req in reqs:
    req.wait()
```
Avoid bare `dist.send`/`dist.recv` here: they block, and mismatched send/recv ordering across ranks can deadlock.
### Step 4: Handle Gradient Flow Correctly
Critical pitfall: calling `tensor.detach().requires_grad_(True)` severs the computational graph above the tensor.

Correct approach for maintaining gradient flow:

```python
# Cache for stage inputs that need gradients.
input_cache = []

# During forward: cache the received tensor directly (not a detached copy).
# A tensor fresh off the wire is a leaf with no upstream graph, so marking
# it as requiring grad is safe.
stage_input = received_tensor.requires_grad_(True)
input_cache.append(stage_input)

# During backward: run backward through the cached tensor; its .grad is the
# gradient to send to the previous rank. The gradient flows through the
# original tensor, not a severed copy.
```
Verify gradient connectivity with small test cases before full implementation.
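Such a small test can run on a single process with no communication at all; the sketch below simulates one stage boundary on CPU (all names are local to the example):

```python
import torch

# Single-process simulation of a stage boundary: "receive" a hidden state,
# mark it as the stage input, run the local layer, and check that
# backward() fills in .grad on the stage input -- that gradient is what a
# real pipeline would send back to the previous rank.
layer = torch.nn.Linear(8, 8)
received = torch.randn(2, 4, 8)              # stand-in for a recv buffer
stage_input = received.requires_grad_(True)  # leaf tensor: no graph to sever
out = layer(stage_input)
out.sum().backward()
# stage_input.grad now holds the gradient w.r.t. this stage's input;
# layer.weight.grad holds the local parameter gradient.
```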
### Step 5: Implement Shape Communication

Communicate tensor shapes before data when shapes vary:

```python
# Efficient: pack the whole shape into a single 1-D tensor.
shape_tensor = torch.tensor(list(tensor.shape), dtype=torch.long, device=device)
dist.send(shape_tensor, dst_rank)

# Then send the actual data; .contiguous() guarantees the buffer layout
# matches the shape just sent.
dist.send(tensor.contiguous(), dst_rank)
Avoid sending each dimension as a separate tensor.
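The receiving side mirrors this protocol: read the shape vector first, allocate a buffer from it, then read the payload. The sketch below injects the receive call as a callable so it can be exercised without a process group; in a real pipeline `recv_into` would be something like `lambda t: dist.recv(t, src=prev_rank)` (names here are illustrative):

```python
import torch

def recv_with_shape(recv_into, ndim, dtype, device="cpu"):
    # Hypothetical receive helper. `recv_into(t)` is assumed to fill
    # `t` in place.
    shape = torch.empty(ndim, dtype=torch.long, device=device)
    recv_into(shape)                 # 1) receive the shape vector
    data = torch.empty(*shape.tolist(), dtype=dtype, device=device)
    recv_into(data)                  # 2) receive the payload it describes
    return data
```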
## Verification Strategies

### Gradient Flow Verification

Create a minimal test to verify gradients flow correctly:
```python
def test_gradient_flow():
    # Create a simple model partition, then run forward/backward.
    # Every parameter should end up with a non-None, non-zero gradient.
    for name, param in model.named_parameters():
        assert param.grad is not None, f"No gradient for {name}"
        assert not torch.all(param.grad == 0), f"Zero gradient for {name}"
```
### Activation Shape Verification

Log shapes at each stage boundary:

```python
# Before send
print(f"Rank {rank} sending shape: {tensor.shape}")

# After receive
print(f"Rank {rank} received shape: {tensor.shape}")
```
### Single-Rank Testing

Test with `world_size=1` to ensure the implementation handles the degenerate case:
- No communication should occur
- The model should function as standard single-device training
- All gradients should flow correctly
### End-to-End Loss Comparison
Compare loss values between:
- The pipeline parallel implementation
- Standard single-device training (ground truth)
Values should match within numerical precision.
## Common Pitfalls

### Truncated Code Edits
When making large code changes:
- Prefer smaller, targeted edits over large rewrites
- Verify edit completeness by reading the file after each edit
- Use full file writes for major structural changes
### Detach Breaking Gradient Flow

```python
# WRONG: detach() severs the computational graph; gradients stop here.
cached = tensor.detach().requires_grad_(True)

# RIGHT: clone() is differentiable, so the copy stays connected to the
# graph for backward. (Don't call requires_grad_ on it: when `tensor` is
# already part of a graph, the clone is not a leaf.)
cached = tensor.clone()

# Or: simply keep a reference to the original tensor.
cached = tensor
```
### Missing lm_head in Partitions

The output head (`lm_head`) is often separate from the layer list. Ensure:
- It's included in the final rank's computation
- Its parameters receive gradients
- It's not duplicated across ranks
### Position Embeddings Handling
Position embeddings (especially rotary embeddings) require care:
- They may need explicit computation before the first layer
- The API varies between model implementations
- Test with the specific model architecture being used
### Empty Partitions

When `world_size > num_layers`, some ranks may have no layers:
- Add explicit handling for empty partitions
- These ranks still need to forward activations
- Avoid division by zero in layer assignment
### Variable Sequence Lengths
If microbatches have different sequence lengths:
- Communicate shapes before data
- Don't cache and reuse shapes across microbatches
- Consider padding strategies for efficiency
## Code Organization
Structure the implementation clearly:
```
pipeline_parallel.py
├── partition_model()       # Divide layers across ranks
├── get_stage_layers()      # Get this rank's layer subset
├── forward_stage()         # Single stage forward pass
├── backward_stage()        # Single stage backward pass
├── send_activation()       # Send tensor to next rank
├── recv_activation()       # Receive tensor from prev rank
├── send_gradient()         # Send gradient to prev rank
├── recv_gradient()         # Receive gradient from next rank
└── train_step_pipeline()   # Main training step orchestrator
```
## Testing Checklist
Before considering implementation complete:
- Gradient flow verified for all model parameters
- Shapes correct at each stage boundary
- `world_size=1` case works correctly
- Loss matches the non-parallel baseline
- No communication deadlocks
- Memory usage scales appropriately with `world_size`
- Position embeddings handled correctly for the specific model
- Output head (`lm_head`) included and receives gradients