Налаштування FSDP (Fully Sharded Data Parallel) для навчання

Проектуємо та впроваджуємо системи штучного інтелекту: від прототипу до production-ready рішення. Наша команда поєднує експертизу в машинному навчанні, дата-інжинірингу та MLOps, щоб AI працював не в лабораторії, а в реальному бізнесі.
Показано 1 з 1Усі 1566 послуг
Налаштування FSDP (Fully Sharded Data Parallel) для навчання
Складний
~3-5 днів
Часті запитання

Напрямки AI-розробки

Етапи розробки AI-рішення

Останні роботи

  • image_website-b2b-advance_0.webp
    Розробка сайту компанії B2B ADVANCE
    1284
  • image_web-applications_feedme_466_0.webp
    Розробка веб-додатків для компанії FEEDME
    1196
  • image_websites_belfingroup_462_0.webp
    Розробка веб-сайту для компанії БЕЛФІНГРУП
    901
  • image_ecommerce_furnoro_435_0.webp
    Розробка інтернет магазину для компанії FURNORO
    1119
  • image_logo-advance_0.webp
    Розробка логотипу компанії B2B Advance
    586
  • image_crm_enviok_479_0.webp
    Розробка веб-додатків для компанії Enviok
    853

Налаштування FSDP (Fully Sharded Data Parallel) для навчання

FSDP - нативна реалізація fully sharded data parallelism у PyTorch (з'явилася у версії 1.11). На відміну від DeepSpeed ZeRO, FSDP є частиною PyTorch core та не потребує додаткових залежностей. Шардує параметри, градієнти та стан оптимізатора між GPU аналогічно до DeepSpeed ZeRO Stage 3.

Принцип роботи

При forward pass: параметри кожного sharded layer збираються (all-gather) з усіх GPU перед обчисленням. Після forward - негайно звільняються, якщо включений reshard_after_forward. При backward pass: параметри знову збираються, градієнти обчислюються, потім reduce-scatter розподіляє шарди градієнтів GPU.

Це усуває ситуацію, коли кожен GPU зберігає повну копію моделі, як у звичайному DDP.

Базове налаштування

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    CPUOffload,
    BackwardPrefetch,
)
from torch.distributed.fsdp.wrap import (
    size_based_auto_wrap_policy,
    enable_wrap,
    wrap,
)
import functools

def setup_fsdp(rank, world_size):
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def wrap_model_with_fsdp(model, rank):
    # Политика автоматического оборачивания: шардировать слои > 100M параметров
    auto_wrap_policy = functools.partial(
        size_based_auto_wrap_policy,
        min_num_params=100_000_000
    )

    model = FSDP(
        model,
        auto_wrap_policy=auto_wrap_policy,
        cpu_offload=CPUOffload(offload_params=False),
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        device_id=torch.cuda.current_device(),
        sharding_strategy=ShardingStrategy.FULL_SHARD,  # Аналог ZeRO-3
        mixed_precision=MixedPrecision(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.float32,
            buffer_dtype=torch.bfloat16,
        ),
    )
    return model

Стратегії шардування

from torch.distributed.fsdp import ShardingStrategy

# FULL_SHARD — полное шардирование (аналог ZeRO-3)
# Максимальная экономия памяти, максимальный overhead на коммуникацию
strategy = ShardingStrategy.FULL_SHARD

# SHARD_GRAD_OP — шардирование только градиентов и оптимизатора (ZeRO-2)
# Баланс между памятью и скоростью
strategy = ShardingStrategy.SHARD_GRAD_OP

# NO_SHARD — обычный DDP без шардирования
strategy = ShardingStrategy.NO_SHARD

# HYBRID_SHARD — FULL_SHARD внутри узла, репликация между узлами
# Оптимален для multi-node с быстрым NVLink внутри узла
strategy = ShardingStrategy.HYBRID_SHARD

Wrap policy для Transformer моделей

Для трансформерів важливо обертати кожен Transformer block окремо:

from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

# Каждый LlamaDecoderLayer будет отдельным FSDP unit
llama_auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={LlamaDecoderLayer},
)

model = FSDP(model, auto_wrap_policy=llama_auto_wrap_policy)

Збереження та завантаження checkpoint

З FSDP checkpoint вимагає спеціальної обробки, оскільки параметри шардовані:

from torch.distributed.fsdp import FullStateDictConfig, StateDictType

# Сохранение — собираем полный state dict на rank 0
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
    cpu_state = model.state_dict()
    if rank == 0:
        torch.save(cpu_state, "checkpoint.pt")

# Загрузка — загружаем на CPU, затем распределяем
if rank == 0:
    state_dict = torch.load("checkpoint.pt")
else:
    state_dict = {}

with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
    model.load_state_dict(state_dict)

FSDP vs DeepSpeed ZeRO: порівняння

Критерій FSDP DeepSpeed ZeRO-3
Інтеграція з PyTorch Нативна Зовнішня бібліотека
CPU/NVMe offload Обмежений Просунутий (ZeRO-Infinity)
Підтримка Hugging Face Через Accelerate Нативна
Продуктивність Порівнянно Незначно швидше для великих моделей
Складність налаштування Нижче Вище

Інтеграція з Accelerate

from accelerate import Accelerator
from accelerate.utils import FullyShardedDataParallelPlugin
from torch.distributed.fsdp.fully_sharded_data_parallel import FullOptimStateDictConfig, FullStateDictConfig

fsdp_plugin = FullyShardedDataParallelPlugin(
    state_dict_config=FullStateDictConfig(offload_to_cpu=True, rank0_only=False),
    optim_state_dict_config=FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=False),
)

accelerator = Accelerator(fsdp_plugin=fsdp_plugin)

FSDP – правильний вибір для команд, які працюють в екосистемі PyTorch без бажання додавати DeepSpeed як залежність. Для LLaMA-2 70B на 8x A100 80GB FSDP FULL_SHARD забезпечує ~800-900 tokens/s у BF16 навчанні.