Implementing Federated Learning to Train Models Without Data Transfer
Federated Learning (FL) is a paradigm for training ML models in which the data stays on client devices (smartphones, hospital servers, banking systems) and only model weight updates are sent to a central server. This makes it possible to train models on sensitive data without centralizing it.
When to use FL
- Medical data: several hospitals train a diagnostic model without sharing patient data
- Finance: Competing banks jointly train a fraud model without disclosing transactions
- Mobile devices: personalized models based on user data without uploading it
- IoT: models based on industrial equipment data that cannot be transmitted for security reasons
FedAvg: the baseline algorithm
Federated Averaging (McMahan et al., 2017) is the standard baseline FL algorithm:
- The server initializes the global model $w_0$
- At each round t: the server selects a subset of clients, broadcasts the current weights
- Each client trains the model on its local data for several epochs and returns its updated weights $w_i^{t+1}$
- The server aggregates: $w_{t+1} = \sum_i \frac{n_i}{n} w_i^{t+1}$, where $n_i$ is the size of client $i$'s dataset and $n = \sum_i n_i$
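The weighted average in the aggregation step can be sketched in plain NumPy (a toy illustration; `fedavg_aggregate` is a hypothetical helper, not part of any library):

```python
import numpy as np

def fedavg_aggregate(client_weights, client_sizes):
    """Weighted average of per-client weight lists (one array per layer)."""
    total = sum(client_sizes)
    num_layers = len(client_weights[0])
    aggregated = []
    for layer in range(num_layers):
        # sum_i (n_i / n) * w_i for this layer
        layer_avg = sum(
            (n / total) * w[layer] for w, n in zip(client_weights, client_sizes)
        )
        aggregated.append(layer_avg)
    return aggregated

# Two clients with 100 and 300 samples, one scalar "layer" each:
clients = [[np.array([1.0])], [np.array([5.0])]]
print(fedavg_aggregate(clients, [100, 300]))  # [array([4.])]
```

The client with 300 samples dominates the average (0.25 · 1 + 0.75 · 5 = 4), which is exactly the size-weighted behavior of FedAvg.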
Implementation with PySyft/Flower
Flower (flwr) is one of the most mature open-source FL frameworks:
import numpy as np
import torch
import flwr as fl
from typing import Dict, List, Tuple, Optional

# Client side
class MedicalModelClient(fl.client.NumPyClient):
    def __init__(self, model, train_loader, val_loader):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader

    def get_parameters(self, config) -> List[np.ndarray]:
        return [param.data.cpu().numpy() for param in self.model.parameters()]

    def set_parameters(self, parameters: List[np.ndarray]):
        for param, new_param in zip(self.model.parameters(), parameters):
            param.data = torch.tensor(new_param)

    def fit(self, parameters, config) -> Tuple[List[np.ndarray], int, Dict]:
        self.set_parameters(parameters)
        # Local training; assumes the loader yields (inputs, targets) batches
        optimizer = torch.optim.SGD(self.model.parameters(),
                                    lr=config.get("lr", 0.01))
        criterion = torch.nn.CrossEntropyLoss()
        local_epochs = config.get("local_epochs", 3)
        self.model.train()
        for epoch in range(local_epochs):
            for inputs, targets in self.train_loader:
                optimizer.zero_grad()
                loss = criterion(self.model(inputs), targets)
                loss.backward()
                optimizer.step()
        return self.get_parameters(config), len(self.train_loader.dataset), {}

    def evaluate(self, parameters, config) -> Tuple[float, int, Dict]:
        self.set_parameters(parameters)
        # test() is a user-supplied helper returning (loss, accuracy)
        loss, accuracy = test(self.model, self.val_loader)
        return float(loss), len(self.val_loader.dataset), {"accuracy": float(accuracy)}

# Server side
class FedAvgWithDP(fl.server.strategy.FedAvg):
    """FedAvg with differential-privacy noise added to the aggregate"""

    def aggregate_fit(self, server_round, results, failures):
        aggregated_params, aggregated_metrics = super().aggregate_fit(
            server_round, results, failures
        )
        if aggregated_params is not None:
            # Add Gaussian noise for DP, then re-serialize the parameters
            noise_multiplier = 0.1
            ndarrays = fl.common.parameters_to_ndarrays(aggregated_params)
            noised = [p + np.random.normal(0, noise_multiplier, p.shape)
                      for p in ndarrays]
            aggregated_params = fl.common.ndarrays_to_parameters(noised)
        return aggregated_params, aggregated_metrics

strategy = FedAvgWithDP(
    min_fit_clients=5,
    min_evaluate_clients=3,
    min_available_clients=10,
    fraction_fit=0.5,  # sample 50% of the clients each round
)

fl.server.start_server(
    server_address="0.0.0.0:8080",
    strategy=strategy,
    config=fl.server.ServerConfig(num_rounds=50),
)
Differential Privacy in FL
DP ensures that an individual client's participation cannot be inferred from the global model:
from opacus import PrivacyEngine

privacy_engine = PrivacyEngine()
model, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
    module=model,
    optimizer=optimizer,
    data_loader=train_loader,
    epochs=local_epochs,
    target_epsilon=5.0,   # privacy budget epsilon (smaller = more private)
    target_delta=1e-5,
    max_grad_norm=1.0,    # per-sample gradient clipping threshold
)
Problems and solutions
Non-IID data: clients' local datasets follow different distributions. Solutions: FedProx (adds a proximal term), SCAFFOLD, FedNova.
Communication overhead: transmitting model weights to and from thousands of clients. Solutions: gradient compression (Top-k sparsification), quantization (8-bit weights).
Stragglers: slow clients delay the round. Solutions: asynchronous FL (FedAsync), client participation timeouts.
Backdoor attacks: a malicious client poisons the global model. Defenses: Byzantine-robust aggregation (Krum, Median), anomaly detection on updates.
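Top-k sparsification from the communication-overhead item above can be sketched as follows (a toy NumPy version; real systems also transmit the surviving indices and typically accumulate the zeroed residual locally for later rounds):

```python
import numpy as np

def topk_sparsify(grad, k):
    """Keep only the k largest-magnitude entries of a gradient; zero the rest."""
    flat = grad.ravel()
    if k >= flat.size:
        return grad.copy()
    # Indices of the k entries with the largest absolute value
    idx = np.argpartition(np.abs(flat), -k)[-k:]
    out = np.zeros_like(flat)
    out[idx] = flat[idx]
    return out.reshape(grad.shape)

g = np.array([0.1, -3.0, 0.5, 2.0])
print(topk_sparsify(g, 2))  # [ 0. -3.  0.  2.]
```

Only 2 of 4 values (plus their indices) need to be sent, and the relative compression grows with model size.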
FL system evaluation metrics
- Communication efficiency: number of rounds to achieve target accuracy
- Accuracy gap: the difference between centralized training and FL (usually 1-5%)
- Privacy budget: $(\epsilon, \delta)$-DP achieved upon completion of training
- Participation rate: % of clients who successfully completed each round
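Communication efficiency can be estimated with simple arithmetic (a hypothetical helper; assumes float32 weights and a full-model download and upload per selected client per round, with no compression):

```python
def total_traffic_gb(num_params, clients_per_round, rounds, bytes_per_param=4):
    """Total bytes moved: params * bytes * (down + up) * clients * rounds."""
    return num_params * bytes_per_param * 2 * clients_per_round * rounds / 1e9

# A 10M-parameter model, 5 clients per round, 50 rounds:
print(total_traffic_gb(10_000_000, 5, 50))  # 20.0 (GB)
```

Numbers like this make it clear why compression and quantization matter once the client count grows.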
A typical project: a medical consortium of 10 hospitals trains a model for detecting cancer in X-ray images. FL reaches an AUC of 0.94, versus 0.87 for the best single hospital, without any patient data leaving the hospitals.