Skill: Use PyTorch FSDP2 (`fully_shard`) correctly in a training script

This skill teaches a coding agent how to add PyTorch FSDP2 to a training loop with correct initialization, sharding, mixed-precision/offload configuration, and checkpointing.

FSDP2 in PyTorch is exposed primarily via `torch.distributed.fsdp.fully_shard` and the `FSDPModule` methods it adds in place to sharded modules. See: `references/pytorch_fully_shard_api.md`, `references/pytorch_fsdp2_tutorial.md`.
When to use this skill

Use FSDP2 when:

- Your model doesn't fit on one GPU (parameters + gradients + optimizer state).
- You want eager-mode, DTensor-based per-parameter sharding (more inspectable, with simpler sharded state dicts) rather than FSDP1's flat-parameter sharding.
- You may later compose data parallelism with tensor parallelism using `DeviceMesh`.

Avoid it (or be careful) if:

- You need strictly backwards-compatible checkpoints across PyTorch versions (the DCP docs warn that cross-version compatibility is not guaranteed).
- You are forced onto older PyTorch versions without the FSDP2 stack.
Alternatives (when FSDP2 is not the best fit)

- `DistributedDataParallel` (DDP): replicate the model on every rank and all-reduce gradients; the right choice when the model fits comfortably on one GPU.
- `FullyShardedDataParallel` (FSDP1): the original flat-parameter sharding wrapper; use it only if you must stay on the older API.

Reference: `references/pytorch_ddp_notes.md`, `references/pytorch_fsdp1_api.md`.
Contract the agent must follow

- Launch with `torchrun` and set the CUDA device per process (usually via `LOCAL_RANK`).
- Apply `fully_shard()` bottom-up, i.e., shard submodules (e.g., transformer blocks) before the root module.
- Call `model(input)`, not `model.forward(input)`, so the FSDP2 hooks run (unless you explicitly `unshard()` or register the forward method).
- Create the optimizer after sharding, so it is built on the DTensor parameters produced by `fully_shard`.
- Checkpoint using Distributed Checkpoint (DCP) or the distributed-state-dict helpers, not naive `torch.save(model.state_dict())`, unless you deliberately gather to full tensors first.

(Each of these rules is directly described in the official API docs/tutorial; see references.)
Step-by-step procedure

1. Version & environment sanity

   - Prefer a recent stable PyTorch release whose docs include FSDP2 (`fully_shard`) and the current DCP APIs.
   - Use `torchrun --nproc_per_node <gpus_per_node> ...` and ensure `RANK`, `WORLD_SIZE`, and `LOCAL_RANK` are visible to each process.

   Reference: `references/pytorch_fsdp2_tutorial.md` (launch commands and setup), `references/pytorch_fully_shard_api.md` (user contract).
2. Initialize distributed and set the device

   Minimal, correct pattern:

   - `dist.init_process_group(backend="nccl")`
   - `torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))`
   - Optionally create a `DeviceMesh` to describe the data-parallel group(s).

   Reference: `references/pytorch_device_mesh_tutorial.md` (why `DeviceMesh` exists and how it manages process groups).
3. Build the model on the meta device (recommended for very large models)

   For big models, initialize on `meta`, apply sharding, then materialize weights on GPU:

   - `with torch.device("meta"): model = ...`
   - Apply `fully_shard(...)` on submodules, then `fully_shard(model)`.
   - `model.to_empty(device="cuda")`
   - `model.reset_parameters()` (or your own init routine, applied per module).

   Reference: `references/pytorch_fsdp2_tutorial.md` (the migration guide shows this flow explicitly).
4. Apply `fully_shard()` bottom-up (wrapping policy = "apply where needed")

   Do not only call `fully_shard` on the topmost module. Recommended sharding pattern for transformer-like models:

   - Iterate modules; if `isinstance(m, TransformerBlock): fully_shard(m, ...)`.
   - Then `fully_shard(model, ...)`.

   Why: `fully_shard` forms "parameter groups" for collective efficiency and excludes parameters already grouped by earlier calls. Bottom-up sharding gives better communication/compute overlap and lower peak memory.

   Reference: `references/pytorch_fully_shard_api.md` (bottom-up requirement and rationale).
5. Configure `reshard_after_forward` for memory/perf trade-offs

   Default behavior:

   - `None` means `True` for non-root modules and `False` for the root module, a good default: the root's parameters are needed again immediately at the start of backward.

   Heuristics:

   - If you're memory-bound: keep the defaults, or force `True` on many blocks.
   - If you're throughput-bound and can afford the memory: keep unsharded parameters around longer (root often `False`).
   - Advanced: pass an `int` to reshard to a smaller world size after forward (e.g., intra-node), provided it is a meaningful divisor of the mesh size.

   Reference: `references/pytorch_fully_shard_api.md` (full semantics).
6. Mixed precision & offload (optional but common)

   FSDP2 uses:

   - `mp_policy=MixedPrecisionPolicy(param_dtype=..., reduce_dtype=..., output_dtype=..., cast_forward_inputs=...)`
   - `offload_policy=CPUOffloadPolicy()` if you want CPU offload

   Rules of thumb:

   - Start with BF16 parameters/reductions on H100/A100-class GPUs (if numerically stable for your model).
   - Keep `reduce_dtype` aligned with your gradient-reduction expectations (e.g., FP32 reductions for extra stability with BF16 parameters).
   - If you use CPU offload, budget for PCIe/NVLink traffic and runtime overhead.

   Reference: `references/pytorch_fully_shard_api.md` (`MixedPrecisionPolicy` / `OffloadPolicy` classes).
7. Optimizer, gradient clipping, accumulation

   - Create the optimizer after sharding so it holds DTensor parameters.
   - For gradient accumulation, use the FSDP2 mechanism (`set_requires_gradient_sync(False)` on non-final microbatches) instead of FSDP1's `no_sync()` context manager.

   Gradient clipping:

   - Use the approach shown in the FSDP2 tutorial ("Gradient Clipping and Optimizer with DTensor"), because parameters and gradients are DTensors.

   Reference: `references/pytorch_fsdp2_tutorial.md`.
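The accumulation and clipping rules can be sketched as follows, assuming `model` has already been sharded with `fully_shard` (which is what adds `set_requires_gradient_sync`) and a cross-entropy objective; `train_steps` and `accum` are names of our choosing.

```python
import torch


def train_steps(model, optimizer, microbatches, max_norm=1.0, accum=4):
    """One optimizer step per `accum` microbatches, FSDP2-style accumulation."""
    for i, (inputs, targets) in enumerate(microbatches):
        last = (i + 1) % accum == 0
        # Skip the gradient reduce-scatter on non-final microbatches
        # (FSDP2's replacement for FSDP1's no_sync() context manager).
        model.set_requires_gradient_sync(last)
        # Call model(inputs), never model.forward(inputs), so hooks run.
        loss = torch.nn.functional.cross_entropy(model(inputs), targets)
        loss.backward()
        if last:
            # Gradients are DTensors; clip_grad_norm_ handles them.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            optimizer.step()
            optimizer.zero_grad()
```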
8. Checkpointing: prefer DCP or the distributed-state-dict helpers

   Two recommended approaches:

   A) Distributed Checkpoint (DCP), the best default

   - DCP saves/loads from multiple ranks in parallel and supports load-time resharding.
   - DCP produces multiple files (often at least one per rank) and loads "in place", i.e., into a pre-allocated state dict.

   B) Distributed state-dict helpers

   - `get_model_state_dict` / `set_model_state_dict` with `StateDictOptions(full_state_dict=True, cpu_offload=True, broadcast_from_rank0=True, ...)`
   - For the optimizer: `get_optimizer_state_dict` / `set_optimizer_state_dict`

   Avoid:

   - Saving DTensor state dicts with plain `torch.save`, unless you intentionally convert with `DTensor.full_tensor()` and manage the memory cost carefully.

   References:

   - `references/pytorch_dcp_overview.md` (DCP behavior and caveats)
   - `references/pytorch_dcp_recipe.md` and `references/pytorch_dcp_async_recipe.md` (end-to-end usage)
   - `references/pytorch_fsdp2_tutorial.md` (DTensor vs. DCP state-dict flows)
   - `references/pytorch_examples_fsdp2.md` (working checkpoint scripts)
Workflow checklists (copy-paste friendly)

Workflow A: Retrofit FSDP2 into an existing training script

1. Launch with `torchrun` and initialize the process group.
2. Set the CUDA device from `LOCAL_RANK`; create a `DeviceMesh` if you need multi-dimensional parallelism.
3. Build the model (on `meta` if needed), apply `fully_shard` bottom-up on submodules, then `fully_shard(model)`.
4. Create the optimizer after sharding so it captures DTensor parameters.
5. Use `model(inputs)` so hooks run; use `set_requires_gradient_sync` for accumulation.
6. Add DCP save/load via the `torch.distributed.checkpoint` helpers.

Reference: `references/pytorch_fsdp2_tutorial.md`, `references/pytorch_fully_shard_api.md`, `references/pytorch_device_mesh_tutorial.md`, `references/pytorch_dcp_recipe.md`.
Workflow B: Add DCP save/load (minimal pattern)

1. Wrap state in `Stateful` objects or assemble state via `get_state_dict`.
2. Call `dcp.save(...)` from all ranks with a shared path.
3. Call `dcp.load(...)` and restore with `set_state_dict`.
4. Validate any resharding assumptions when loading into a different mesh.

Reference: `references/pytorch_dcp_recipe.md`.
Debug checklist (what the agent should check first)

- All ranks on distinct GPUs? If not, verify `torch.cuda.set_device(LOCAL_RANK)` and your `torchrun` flags.
- Did you accidentally call `forward()` directly? Use `model(input)`, or explicitly `unshard()` / register the forward method.
- Is `fully_shard()` applied bottom-up? If only the root is sharded, expect worse memory/perf and confusing behavior.
- Optimizer created at the right time? It must be built on DTensor parameters, after sharding.
- Checkpointing path consistent? If using DCP, don't mix in ad-hoc `torch.save` unless you understand the conversions, and be mindful of DCP's PyTorch-version-compatibility warnings.
Common issues and fixes

- Forward hooks not running → call `model(inputs)` (or `unshard()` explicitly) instead of `model.forward(...)`.
- Optimizer sees non-DTensor params → create the optimizer after all `fully_shard` calls.
- Only the root module sharded → apply `fully_shard` bottom-up on submodules before the root.
- Memory spikes after forward → set `reshard_after_forward=True` on more modules.
- Gradient accumulation desync → use `set_requires_gradient_sync` instead of FSDP1's `no_sync()`.

Reference: `references/pytorch_fully_shard_api.md`, `references/pytorch_fsdp2_tutorial.md`.
Minimal reference implementation outline (agent-friendly)

The coding agent should implement a script with these labeled blocks:

- `init_distributed()`: init the process group, set the device
- `build_model_meta()`: model on `meta`, apply `fully_shard`, materialize weights
- `build_optimizer()`: optimizer created after sharding
- `train_step()`: forward/backward/step with `model(inputs)` and DTensor-aware patterns
- `checkpoint_save/load()`: DCP or the distributed-state-dict helpers

Concrete examples live in `references/pytorch_examples_fsdp2.md` and the official tutorial reference.
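How the labeled blocks compose, as a sketch; each function here is one of the blocks named above, to be implemented by the agent, and the ordering is the contract.

```python
import torch.distributed as dist


def main(make_model, make_loader, num_steps=1000, ckpt_path="ckpt"):
    """Skeleton wiring the labeled blocks together; ordering matters."""
    mesh = init_distributed()                    # 1) pg + device (+ mesh)
    model = build_model_meta(make_model, mesh)   # 2) meta -> shard -> materialize
    optimizer = build_optimizer(model)           # 3) AFTER sharding: DTensor params
    for step, batch in zip(range(num_steps), make_loader()):
        train_step(model, optimizer, batch)      # 4) model(inputs), not .forward
        if step % 500 == 499:
            checkpoint_save(model, optimizer, ckpt_path)  # 5) DCP, all ranks
    dist.destroy_process_group()
```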
References

- `references/pytorch_fsdp2_tutorial.md`
- `references/pytorch_fully_shard_api.md`
- `references/pytorch_ddp_notes.md`
- `references/pytorch_fsdp1_api.md`
- `references/pytorch_device_mesh_tutorial.md`
- `references/pytorch_tp_tutorial.md`
- `references/pytorch_dcp_overview.md`
- `references/pytorch_dcp_recipe.md`
- `references/pytorch_dcp_async_recipe.md`
- `references/pytorch_examples_fsdp2.md`
- `references/torchtitan_fsdp_notes.md` (optional, production notes)
- `references/ray_train_fsdp2_example.md` (optional, integration example)