Реалізація Federated Learning для навчання моделей без передачі даних
Federated Learning — парадигма навчання ML-моделей, коли дані залишаються на пристроях клієнтів (смартфони, лікарняні сервери, банківські системи), а центральний сервер передаються лише оновлення ваг моделі. Це дозволяє навчати моделі на чутливих даних без їхньої централізації.
Коли застосовувати FL
- Медичні дані: декілька лікарень навчають модель діагностики без обміну даними пацієнтів
- Фінанси: банки-конкуренти спільно навчають модель фрода без розкриття транзакцій
- Мобільні пристрої: персоналізовані моделі на даних користувачів без їх upload
- IoT: моделі на даних промислового обладнання, які не можна передавати з security причин
FedAvg - базовий алгоритм
Federated Averaging (McMahan et al., 2017) - стандартний алгоритм FL:
- Сервер ініціалізує глобальну модель $w_0$
- На кожному раунді t: сервер вибирає підмножину клієнтів, розсилає поточні ваги
- Кожен клієнт: навчає модель локальних даних (кілька епох), повертає $\Delta w_i$
- Сервер агрегує $w_{t+1} = \sum_i \frac{n_i}{n} w_i^t$, де $n_i$ — розмір датасету клієнта i
Реалізація з PySyft / Flower
Flower (flwr) — найбільш зрілий open-source FL фреймворк:
import flwr as fl
import torch
from typing import Dict, List, Tuple, Optional
# Клиентская часть
class MedicalModelClient(fl.client.NumPyClient):
def __init__(self, model, train_loader, val_loader):
self.model = model
self.train_loader = train_loader
self.val_loader = val_loader
def get_parameters(self, config) -> List[np.ndarray]:
return [param.data.numpy() for param in self.model.parameters()]
def set_parameters(self, parameters: List[np.ndarray]):
for param, new_param in zip(self.model.parameters(), parameters):
param.data = torch.tensor(new_param)
def fit(self, parameters, config) -> Tuple[List[np.ndarray], int, Dict]:
self.set_parameters(parameters)
# Локальное обучение
optimizer = torch.optim.SGD(self.model.parameters(),
lr=config.get("lr", 0.01))
local_epochs = config.get("local_epochs", 3)
self.model.train()
for epoch in range(local_epochs):
for batch in self.train_loader:
optimizer.zero_grad()
loss = self.model(batch)
loss.backward()
optimizer.step()
return self.get_parameters(config), len(self.train_loader.dataset), {}
def evaluate(self, parameters, config) -> Tuple[float, int, Dict]:
self.set_parameters(parameters)
loss, accuracy = test(self.model, self.val_loader)
return float(loss), len(self.val_loader.dataset), {"accuracy": float(accuracy)}
# Серверная часть
class FedAvgWithDP(fl.server.strategy.FedAvg):
"""FedAvg с Differential Privacy"""
def aggregate_fit(self, server_round, results, failures):
aggregated_params, aggregated_metrics = super().aggregate_fit(
server_round, results, failures
)
if aggregated_params is not None:
# Добавление Gaussian noise для DP
noise_multiplier = 0.1
for param in fl.common.parameters_to_ndarrays(aggregated_params):
noise = np.random.normal(0, noise_multiplier, param.shape)
param += noise
return aggregated_params, aggregated_metrics
strategy = FedAvgWithDP(
min_fit_clients=5,
min_evaluate_clients=3,
min_available_clients=10,
fraction_fit=0.5, # 50% клиентов на каждый раунд
)
fl.server.start_server(
server_address="0.0.0.0:8080",
strategy=strategy,
config=fl.server.ServerConfig(num_rounds=50)
)
Differential Privacy у FL
DP гарантує, що участь окремого клієнта не може бути виявлена за глобальною моделлю:
from opacus import PrivacyEngine
privacy_engine = PrivacyEngine()
model, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
module=model,
optimizer=optimizer,
data_loader=train_loader,
epochs=local_epochs,
target_epsilon=5.0, # ε-DP параметр (меньше = приватнее)
target_delta=1e-5,
max_grad_norm=1.0, # Gradient clipping
)
Проблеми та рішення
Non-IID дані - дані на різних клієнтах мають різні розподіли. Рішення: FedProx (додає proximal term), SCAFFOLD, FedNova.
Комунікаційні накладні витрати — передача ваги моделі при тисячах клієнтів. Рішення: gradient compression (Top-k sparsification), quantization (8-bit weights).
Stragglers — повільні клієнти затримують раунд. Рішення: асинхронний FL (FedAsync), тайм на участь клієнта.
Backdoor атаки - шкідливий клієнт отруює глобальну модель. Захисту: Byzantine-robust aggregation (Krum, Median), аномально виявити на оновленнях.
Метрики оцінки FL системи
- Communication efficiency: кількість раундів до досягнення target accuracy
- Accuracy gap: різниця між centralised training та FL (зазвичай 1-5%)
- Privacy budget: $(\epsilon, \delta)$-DP досягнутий за підсумками навчання
- Participation rate: % клієнтів, які успішно завершили кожен раунд
Типовий проект: медичний консорціум із 10 лікарень навчає модель детекції раку на рентгенограмах. FL дозволяє досягти AUC 0.94 – проти 0.87 у кращої окремої лікарні – без жодної передачі даних пацієнтів.







