Налаштування FSDP (Fully Sharded Data Parallel) для навчання
FSDP - нативна реалізація fully sharded data parallelism у PyTorch (з'явилася у версії 1.11). На відміну від DeepSpeed ZeRO, FSDP є частиною PyTorch core та не потребує додаткових залежностей. Шардує параметри, градієнти та стан оптимізатора між GPU аналогічно до DeepSpeed ZeRO Stage 3.
Принцип роботи
При forward pass: параметри кожного sharded layer збираються (all-gather) з усіх GPU перед обчисленням. Після forward - негайно звільняються, якщо включений reshard_after_forward. При backward pass: параметри знову збираються, градієнти обчислюються, потім reduce-scatter розподіляє шарди градієнтів GPU.
Це усуває ситуацію, коли кожен GPU зберігає повну копію моделі, як у звичайному DDP.
Базове налаштування
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import (
CPUOffload,
BackwardPrefetch,
)
from torch.distributed.fsdp.wrap import (
size_based_auto_wrap_policy,
enable_wrap,
wrap,
)
import functools
def setup_fsdp(rank, world_size):
dist.init_process_group("nccl", rank=rank, world_size=world_size)
torch.cuda.set_device(rank)
def wrap_model_with_fsdp(model, rank):
# Политика автоматического оборачивания: шардировать слои > 100M параметров
auto_wrap_policy = functools.partial(
size_based_auto_wrap_policy,
min_num_params=100_000_000
)
model = FSDP(
model,
auto_wrap_policy=auto_wrap_policy,
cpu_offload=CPUOffload(offload_params=False),
backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
device_id=torch.cuda.current_device(),
sharding_strategy=ShardingStrategy.FULL_SHARD, # Аналог ZeRO-3
mixed_precision=MixedPrecision(
param_dtype=torch.bfloat16,
reduce_dtype=torch.float32,
buffer_dtype=torch.bfloat16,
),
)
return model
Стратегії шардування
from torch.distributed.fsdp import ShardingStrategy
# FULL_SHARD — полное шардирование (аналог ZeRO-3)
# Максимальная экономия памяти, максимальный overhead на коммуникацию
strategy = ShardingStrategy.FULL_SHARD
# SHARD_GRAD_OP — шардирование только градиентов и оптимизатора (ZeRO-2)
# Баланс между памятью и скоростью
strategy = ShardingStrategy.SHARD_GRAD_OP
# NO_SHARD — обычный DDP без шардирования
strategy = ShardingStrategy.NO_SHARD
# HYBRID_SHARD — FULL_SHARD внутри узла, репликация между узлами
# Оптимален для multi-node с быстрым NVLink внутри узла
strategy = ShardingStrategy.HYBRID_SHARD
Wrap policy для Transformer моделей
Для трансформерів важливо обертати кожен Transformer block окремо:
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
# Каждый LlamaDecoderLayer будет отдельным FSDP unit
llama_auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls={LlamaDecoderLayer},
)
model = FSDP(model, auto_wrap_policy=llama_auto_wrap_policy)
Збереження та завантаження checkpoint
З FSDP checkpoint вимагає спеціальної обробки, оскільки параметри шардовані:
from torch.distributed.fsdp import FullStateDictConfig, StateDictType
# Сохранение — собираем полный state dict на rank 0
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
cpu_state = model.state_dict()
if rank == 0:
torch.save(cpu_state, "checkpoint.pt")
# Загрузка — загружаем на CPU, затем распределяем
if rank == 0:
state_dict = torch.load("checkpoint.pt")
else:
state_dict = {}
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
model.load_state_dict(state_dict)
FSDP vs DeepSpeed ZeRO: порівняння
| Критерій | FSDP | DeepSpeed ZeRO-3 |
|---|---|---|
| Інтеграція з PyTorch | Нативна | Зовнішня бібліотека |
| CPU/NVMe offload | Обмежений | Просунутий (ZeRO-Infinity) |
| Підтримка Hugging Face | Через Accelerate | Нативна |
| Продуктивність | Порівнянно | Незначно швидше для великих моделей |
| Складність налаштування | Нижче | Вище |
Інтеграція з Accelerate
from accelerate import Accelerator
from accelerate.utils import FullyShardedDataParallelPlugin
from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig
fsdp_plugin = FullyShardedDataParallelPlugin(
state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False),
)
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
FSDP – правильний вибір для команд, які працюють в екосистемі PyTorch без бажання додавати DeepSpeed як залежність. Для LLaMA-2 70B на 8x A100 80GB FSDP FULL_SHARD забезпечує ~800-900 tokens/s у BF16 навчанні.







