Setting up FSDP (Fully Sharded Data Parallel) for training
FSDP is a native implementation of fully sharded data parallelism in PyTorch (introduced in version 1.11). Unlike DeepSpeed ZeRO, FSDP is part of the PyTorch core and requires no additional dependencies. It shards parameters, gradients, and optimizer state across GPUs, similar to DeepSpeed ZeRO Stage 3.
Operating principle
Forward pass: the parameters of each sharded layer are all-gathered from all GPUs right before computation. After the layer's forward pass, the full parameters are immediately freed again if reshard_after_forward is enabled. Backward pass: the parameters are all-gathered once more, the gradients are computed, and reduce-scatter then distributes the gradient shards across the GPUs.
This avoids every GPU holding a full copy of the model, as happens in regular DDP.
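To get a sense of the scale of the savings, here is a rough back-of-the-envelope estimate (not part of the original setup, and ignoring activations and the temporarily all-gathered layer) of the sharded training state per GPU, assuming BF16 parameters and gradients with an FP32 Adam optimizer:

def fsdp_state_gb_per_gpu(num_params, world_size):
    # BF16 params (2 B) + BF16 grads (2 B) + FP32 master weights (4 B)
    # + Adam momentum (4 B) + Adam variance (4 B) ~= 16 bytes per parameter
    bytes_per_param = 16
    # FULL_SHARD splits this state evenly across all GPUs
    return num_params * bytes_per_param / world_size / 1024**3

# ~7B parameters on 8 GPUs -> roughly 13 GB of sharded state per GPU
print(fsdp_state_gb_per_gpu(7e9, 8))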
Basic setup
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    CPUOffload,
    BackwardPrefetch,
    MixedPrecision,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import (
    size_based_auto_wrap_policy,
    enable_wrap,
    wrap,
)
import functools
def setup_fsdp(rank, world_size):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
def wrap_model_with_fsdp(model, rank):
    # Auto-wrap policy: shard every layer with more than 100M parameters
    auto_wrap_policy = functools.partial(
        size_based_auto_wrap_policy,
        min_num_params=100_000_000,
    )
    model = FSDP(
        model,
        auto_wrap_policy=auto_wrap_policy,
        cpu_offload=CPUOffload(offload_params=False),
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        device_id=torch.cuda.current_device(),
        sharding_strategy=ShardingStrategy.FULL_SHARD,  # analogous to ZeRO-3
        mixed_precision=MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.float32,
            buffer_dtype=torch.bfloat16,
        ),
    )
    return model
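A minimal sketch of how the two helpers above might be tied together in a script launched with torchrun; the model factory, dataloader, and hyperparameters here are placeholders, not part of the original example:

import os

def main():
    # torchrun sets LOCAL_RANK/WORLD_SIZE; this sketch assumes a single node
    rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    setup_fsdp(rank, world_size)

    model = build_model()  # hypothetical factory for your nn.Module
    model = wrap_model_with_fsdp(model, rank)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    for batch in dataloader:  # hypothetical dataloader yielding dicts of tensors
        optimizer.zero_grad()
        loss = model(**batch).loss
        loss.backward()   # reduce-scatter of gradient shards happens here
        optimizer.step()  # each rank updates only its own optimizer shard

    dist.destroy_process_group()

if __name__ == "__main__":
    main()  # e.g. torchrun --nproc_per_node=8 train.py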
Sharding strategies
from torch.distributed.fsdp import ShardingStrategy

# FULL_SHARD: full sharding of parameters, gradients, and optimizer state (analogous to ZeRO-3)
# Maximum memory savings, maximum communication overhead
strategy = ShardingStrategy.FULL_SHARD

# SHARD_GRAD_OP: shard only gradients and optimizer state (analogous to ZeRO-2)
# A balance between memory and speed
strategy = ShardingStrategy.SHARD_GRAD_OP

# NO_SHARD: plain DDP, no sharding
strategy = ShardingStrategy.NO_SHARD

# HYBRID_SHARD: FULL_SHARD within a node, replication across nodes
# A good fit for multi-node setups with fast NVLink inside each node
strategy = ShardingStrategy.HYBRID_SHARD
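Whichever variant you pick is passed to the constructor through the sharding_strategy argument; a minimal sketch:

model = FSDP(
    model,
    sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,  # keep full params, shard grads + optimizer state
    device_id=torch.cuda.current_device(),
)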
Wrap policy for Transformer models
For transformers, it is important to wrap each Transformer block separately:
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

# Each LlamaDecoderLayer becomes its own FSDP unit
llama_auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={LlamaDecoderLayer},
)

model = FSDP(model, auto_wrap_policy=llama_auto_wrap_policy)
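A sketch of applying this policy to a Hugging Face LLaMA checkpoint; the model name is a placeholder, and loading on CPU first while letting device_id move each shard to the local GPU is one common pattern rather than the only option:

from transformers import AutoModelForCausalLM

# Load weights on CPU; FSDP moves each shard to the local GPU via device_id
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # placeholder checkpoint
    torch_dtype=torch.bfloat16,
)
model = FSDP(
    model,
    auto_wrap_policy=llama_auto_wrap_policy,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
    device_id=torch.cuda.current_device(),
)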
Saving and loading checkpoints
With FSDP, checkpoints need special handling because the parameters are sharded across ranks:
from torch.distributed.fsdp import FullStateDictConfig, StateDictType

# Save: gather the full state dict on rank 0, offloaded to CPU
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
    cpu_state = model.state_dict()
if rank == 0:
    torch.save(cpu_state, "checkpoint.pt")

# Load: read the full state dict on rank 0, then let FSDP redistribute the shards
if rank == 0:
    state_dict = torch.load("checkpoint.pt")
else:
    state_dict = {}
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
    model.load_state_dict(state_dict)
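The optimizer state is sharded as well and needs the same kind of handling; a sketch using FSDP.full_optim_state_dict and FSDP.scatter_full_optim_state_dict (older but still available APIs; newer PyTorch releases also expose FSDP.optim_state_dict):

# Save: gather the full (unsharded) optimizer state on rank 0
optim_state = FSDP.full_optim_state_dict(model, optimizer)
if rank == 0:
    torch.save(optim_state, "optimizer.pt")

# Load: rank 0 reads the full state, then it is re-sharded for each rank
full_osd = torch.load("optimizer.pt") if rank == 0 else None
sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, model)
optimizer.load_state_dict(sharded_osd)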
FSDP vs. DeepSpeed ZeRO: A Comparison
| Criterion | FSDP | DeepSpeed ZeRO-3 |
|---|---|---|
| Integration with PyTorch | Native | External library |
| CPU/NVMe offload | Limited | Advanced (ZeRO-Infinity) |
| Hugging Face Support | Via Accelerate | Native |
| Performance | Comparable | Slightly faster for very large models |
| Setup complexity | Lower | Higher |
Integration with Accelerate
from accelerate import Accelerator
from accelerate.utils import FullyShardedDataParallelPlugin
from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig
fsdp_plugin = FullyShardedDataParallelPlugin(
    state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
    optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False),
)
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
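After that, the model, optimizer, and dataloader go through accelerator.prepare() as usual, and Accelerate wraps the model in FSDP according to the plugin (and any settings from accelerate config). A minimal sketch with placeholder objects:

model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

for batch in dataloader:
    optimizer.zero_grad()
    loss = model(**batch).loss
    accelerator.backward(loss)  # Accelerate routes the backward through FSDP
    optimizer.step()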
FSDP is the right choice for teams working in the PyTorch ecosystem who do not want to add DeepSpeed as an extra dependency. As a reference point, for LLaMA-2 70B on 8x A100 80GB, FSDP FULL_SHARD delivers roughly 800-900 tokens/s in BF16 training.