FSDP (Fully Sharded Data Parallel) Training Setup

FSDP is a native implementation of fully sharded data parallelism in PyTorch (introduced in version 1.11). Unlike DeepSpeed ZeRO, FSDP is part of the PyTorch core and requires no additional dependencies. It shards parameters, gradients, and optimizer state across GPUs, similar to DeepSpeed ZeRO Stage 3.

Operating principle

Forward pass: the parameters of each sharded layer are all-gathered from all GPUs just before computation. After the forward pass they are immediately freed again if reshard_after_forward is enabled. Backward pass: the parameters are all-gathered once more, the gradients are computed, and reduce-scatter then leaves each GPU holding only its shard of the gradients.

This eliminates the situation where each GPU stores a full copy of the model, as in regular DDP.
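The communication pattern above can be sketched with a toy, single-process model of the two collectives. Plain Python lists stand in for tensors and ranks; this illustrates the all-gather/reduce-scatter data flow, not the real NCCL calls:

```python
# Toy model of the FULL_SHARD flow: each "rank" permanently stores only
# one shard of the parameters; all_gather temporarily materializes the
# full parameter list everywhere, and reduce_scatter sums gradients and
# leaves each rank with only its own gradient shard.

def all_gather(shards):
    # every rank temporarily holds the full parameter list
    full = [x for shard in shards for x in shard]
    return [list(full) for _ in shards]

def reduce_scatter(per_rank_grads):
    # sum the full-size gradients element-wise, then split the result
    # back into one shard per rank
    summed = [sum(vals) for vals in zip(*per_rank_grads)]
    n = len(per_rank_grads)
    size = len(summed) // n
    return [summed[i * size:(i + 1) * size] for i in range(n)]

params = [[1, 2], [3, 4]]               # params sharded across 2 ranks
full_per_rank = all_gather(params)      # forward: full params on each rank
grads = [[1, 1, 1, 1], [2, 2, 2, 2]]    # each rank computes full-size grads
grad_shards = reduce_scatter(grads)     # backward: each rank keeps its shard
print(full_per_rank[0])  # [1, 2, 3, 4]
print(grad_shards)       # [[3, 3], [3, 3]]
```

Between the two collectives each rank holds only its shard, which is where the memory savings over DDP come from.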

Basic setup

import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    ShardingStrategy,
    MixedPrecision,
    CPUOffload,
    BackwardPrefetch,
)
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
import functools

def setup_fsdp(rank, world_size):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def wrap_model_with_fsdp(model, rank):
    # Auto-wrap policy: wrap layers with more than 100M parameters as separate FSDP units
    auto_wrap_policy = functools.partial(
        size_based_auto_wrap_policy,
        min_num_params=100_000_000
    )

    model = FSDP(
        model,
        auto_wrap_policy=auto_wrap_policy,
        cpu_offload=CPUOffload(offload_params=False),
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        device_id=torch.cuda.current_device(),
        sharding_strategy=ShardingStrategy.FULL_SHARD,  # analogous to ZeRO-3
        mixed_precision=MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.float32,
            buffer_dtype=torch.bfloat16,
        ),
    )
    return model
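A sketch of how this might be launched (the script name train.py is hypothetical; under torchrun, rank and world size should be read from the environment variables it sets, e.g. `int(os.environ["LOCAL_RANK"])` and `int(os.environ["WORLD_SIZE"])`):

```bash
# Single node, 8 GPUs; torchrun sets RANK/LOCAL_RANK/WORLD_SIZE for each process
torchrun --nproc_per_node=8 train.py

# Two nodes, 8 GPUs each (run on every node; --node_rank differs per node)
torchrun --nnodes=2 --node_rank=0 --nproc_per_node=8 \
    --rdzv_backend=c10d --rdzv_endpoint=$MASTER_ADDR:29500 train.py
```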

Sharding strategies

from torch.distributed.fsdp import ShardingStrategy

# FULL_SHARD — full sharding (analogous to ZeRO-3).
# Maximum memory savings, maximum communication overhead.
strategy = ShardingStrategy.FULL_SHARD

# SHARD_GRAD_OP — shard only gradients and optimizer state (ZeRO-2).
# A balance between memory and speed.
strategy = ShardingStrategy.SHARD_GRAD_OP

# NO_SHARD — plain DDP, no sharding.
strategy = ShardingStrategy.NO_SHARD

# HYBRID_SHARD — FULL_SHARD within a node, replication across nodes.
# Optimal for multi-node setups with fast NVLink inside each node.
strategy = ShardingStrategy.HYBRID_SHARD
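A back-of-envelope estimate makes the trade-offs concrete. The sketch below uses the classic ZeRO accounting for mixed-precision Adam training — 2 bytes (fp16/bf16 params) + 2 bytes (grads) + 12 bytes (fp32 master params plus the two Adam moments) = 16 bytes per parameter — and ignores activations; the 13B model size and cluster shape are illustrative assumptions:

```python
# Toy per-GPU memory model for the four FSDP sharding strategies.
# 16 bytes/param = 2 (low-precision params) + 2 (grads) + 12 (fp32
# master params + Adam m and v). Activation memory is not modeled.

def per_gpu_gb(num_params, strategy, gpus_per_node=8, num_nodes=1):
    world = gpus_per_node * num_nodes
    if strategy == "NO_SHARD":          # plain DDP: everything replicated
        bytes_per_param = 16
    elif strategy == "SHARD_GRAD_OP":   # ZeRO-2: grads + optimizer sharded
        bytes_per_param = 2 + 14 / world
    elif strategy == "FULL_SHARD":      # ZeRO-3: everything sharded
        bytes_per_param = 16 / world
    elif strategy == "HYBRID_SHARD":    # shard within a node, replicate across
        bytes_per_param = 16 / gpus_per_node
    else:
        raise ValueError(strategy)
    return num_params * bytes_per_param / 1e9

p = 13_000_000_000  # hypothetical 13B model, 2 nodes x 8 GPUs
for s in ("NO_SHARD", "SHARD_GRAD_OP", "FULL_SHARD", "HYBRID_SHARD"):
    print(s, round(per_gpu_gb(p, s, gpus_per_node=8, num_nodes=2), 1))
# NO_SHARD 208.0, SHARD_GRAD_OP 37.4, FULL_SHARD 13.0, HYBRID_SHARD 26.0
```

HYBRID_SHARD pays more memory than FULL_SHARD but keeps the heavy all-gather traffic inside the fast intra-node interconnect.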

Wrap policy for Transformer models

For transformers, it is important to wrap each Transformer block separately:

from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

# Each LlamaDecoderLayer becomes its own FSDP unit
llama_auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={LlamaDecoderLayer},
)

model = FSDP(model, auto_wrap_policy=llama_auto_wrap_policy)
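To see why the class-based policy is preferable, here is a toy illustration of the two wrapping decisions. Real policies receive nn.Module objects; here modules are (name, class_name, num_params) tuples, and all parameter counts are made up for illustration:

```python
# Toy versions of the two auto-wrap decisions: size-based wraps any
# module above a parameter threshold; transformer-based wraps exactly
# the listed layer classes.

def size_based_policy(num_params, min_num_params=100_000_000):
    return num_params >= min_num_params

def transformer_policy(class_name, layer_classes):
    return class_name in layer_classes

modules = [
    ("embed_tokens", "Embedding", 262_144_000),      # illustrative counts
    ("layers.0", "LlamaDecoderLayer", 202_375_168),
    ("norm", "RMSNorm", 4_096),
]

# size-based: wraps anything big enough, including the embedding
wrapped_by_size = [n for n, _, p in modules if size_based_policy(p)]

# transformer-based: wraps exactly the decoder layers
wrapped_by_cls = [n for n, c, _ in modules
                 if transformer_policy(c, {"LlamaDecoderLayer"})]
print(wrapped_by_size)  # ['embed_tokens', 'layers.0']
print(wrapped_by_cls)   # ['layers.0']
```

Wrapping each decoder layer as its own unit lets FSDP free one layer's parameters before gathering the next, which is what keeps peak memory low.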

Saving and loading checkpoints

With FSDP, the checkpoint requires special handling because the parameters are sharded:

from torch.distributed.fsdp import FullStateDictConfig, StateDictType

# Saving: gather the full state dict on CPU, on rank 0 only
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
    cpu_state = model.state_dict()
    if rank == 0:
        torch.save(cpu_state, "checkpoint.pt")

# Loading: every rank loads the full state dict on CPU; load_state_dict
# then scatters the appropriate shard to each GPU
state_dict = torch.load("checkpoint.pt", map_location="cpu")

with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
    model.load_state_dict(state_dict)

FSDP vs. DeepSpeed ZeRO: A Comparison

Criterion                 | FSDP           | DeepSpeed ZeRO-3
--------------------------|----------------|--------------------------------------
Integration with PyTorch  | Native         | External library
CPU/NVMe offload          | Limited        | Advanced (ZeRO-Infinity)
Hugging Face support      | Via Accelerate | Native
Performance               | Comparable     | Slightly faster for very large models
Setup difficulty          | Lower          | Higher

Integration with Accelerate

from accelerate import Accelerator
from accelerate.utils import FullyShardedDataParallelPlugin
from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig

fsdp_plugin = FullyShardedDataParallelPlugin(
    state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
    optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False),
)

accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
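The same settings can also be driven from an Accelerate config file instead of code. The sketch below shows the general shape; key names vary between Accelerate versions, so generate a canonical file for your installation with `accelerate config` rather than copying this verbatim:

```yaml
# Sketch of an Accelerate FSDP config (key names are version-dependent)
compute_environment: LOCAL_MACHINE
distributed_type: FSDP
mixed_precision: bf16
num_processes: 8
fsdp_config:
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_offload_params: false
```

With such a file in place, training starts with `accelerate launch train.py` (script name hypothetical).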

FSDP is the right choice for teams working in the PyTorch ecosystem that want to avoid adding DeepSpeed as an extra dependency. For LLaMA-2 70B on 8x A100 80GB GPUs, FSDP with FULL_SHARD delivers roughly 800-900 tokens/s in BF16 training.