Implementation of systems based on graph neural networks (GNN)
Graph neural networks are a class of architectures for working with graph data. Unlike CNNs and RNNs, GNNs natively model relationships between objects: social connections, molecular structures, transaction networks, and road networks. Where a table loses context, a graph preserves it.
Theoretical basis and key architectures
The core idea of GNN is message passing: each node aggregates information from its neighbors. After K iterations, a node "sees" a K-hop neighborhood.
Aggregation formula (GraphSAGE):
h_v^(k) = σ(W · CONCAT(h_v^(k-1), AGG({h_u^(k-1), u ∈ N(v)})))
Key architectures:
| Architecture | Aggregation | Application | Features |
|---|---|---|---|
| GCN (Kipf 2017) | Spectral conv | Node classification | Transductive |
| GraphSAGE | Mean/LSTM/Max | Large Graphs | Inductive |
| GAT | Attention | Heterogeneous Graphs | Weighted Edges |
| GIN | Sum (most powerful) | Graph isomorphism | Maximum expressiveness |
| RGCN | Relation-specific | Knowledge graphs | Different types of edges |
GCN Implementation with PyTorch Geometric
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, SAGEConv, GATConv, global_mean_pool
from torch_geometric.data import Data, DataLoader
import numpy as np
import pandas as pd
class GraphConvNet(nn.Module):
    """Multi-layer GCN for node- or graph-level prediction.

    Suited for fraud detection, recommendations, and molecular tasks.
    Each layer applies GCNConv -> BatchNorm -> ReLU -> Dropout; an
    optional mean-pool readout collapses node embeddings into a single
    graph embedding before the linear head.
    """

    def __init__(self, node_features: int,
                 hidden_channels: int = 64,
                 output_dim: int = 1,
                 num_layers: int = 3,
                 dropout: float = 0.3):
        super().__init__()
        # Channel widths: input layer maps node_features -> hidden,
        # all remaining num_layers - 1 convolutions stay at hidden width.
        widths = [node_features] + [hidden_channels] * num_layers
        self.convs = nn.ModuleList(
            GCNConv(w_in, w_out) for w_in, w_out in zip(widths[:-1], widths[1:])
        )
        self.bns = nn.ModuleList(
            nn.BatchNorm1d(w_out) for w_out in widths[1:]
        )
        self.dropout = dropout
        self.classifier = nn.Linear(hidden_channels, output_dim)

    def forward(self, x: torch.Tensor,
                edge_index: torch.Tensor,
                batch: torch.Tensor = None) -> torch.Tensor:
        """Run message passing and the prediction head.

        Args:
            x: (N, node_features) node feature matrix.
            edge_index: (2, E) edge list in COO format.
            batch: (N,) graph membership per node (for batched graphs);
                when given, a graph-level mean-pool readout is applied.
        """
        for conv, norm in zip(self.convs, self.bns):
            h = norm(conv(x, edge_index))
            x = F.dropout(F.relu(h), p=self.dropout, training=self.training)
        if batch is not None:
            # Graph-level readout (for whole-graph tasks).
            x = global_mean_pool(x, batch)
        return self.classifier(x)
class GraphSAGEEncoder(nn.Module):
    """GraphSAGE encoder for inductive learning.

    Works on unseen nodes without retraining; intended for large graphs
    such as social networks and transaction graphs.
    """

    def __init__(self, in_channels: int, hidden_channels: int, out_channels: int,
                 num_layers: int = 3, aggr: str = 'mean'):
        super().__init__()
        # in -> hidden, (num_layers - 2) x hidden -> hidden, hidden -> out
        channel_pairs = (
            [(in_channels, hidden_channels)]
            + [(hidden_channels, hidden_channels)] * (num_layers - 2)
            + [(hidden_channels, out_channels)]
        )
        self.convs = nn.ModuleList(
            SAGEConv(c_in, c_out, aggr=aggr) for c_in, c_out in channel_pairs
        )

    def forward(self, x, edge_index):
        last = len(self.convs) - 1
        for idx, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            # No nonlinearity/dropout after the final layer: raw embeddings.
            if idx != last:
                x = F.dropout(F.relu(x), p=0.2, training=self.training)
        return x

    def encode(self, x, edge_index):
        """Return L2-normalized embeddings for downstream tasks."""
        return F.normalize(self.forward(x, edge_index), p=2, dim=-1)
class GATNetwork(nn.Module):
    """Graph Attention Network: weighted aggregation of neighbors.

    Attention weights indicate the "importance" of each neighbor.
    """

    def __init__(self, in_channels: int, hidden_channels: int,
                 out_channels: int, num_heads: int = 8):
        super().__init__()
        # First layer: multi-head attention; head outputs are concatenated.
        self.conv1 = GATConv(in_channels, hidden_channels,
                             heads=num_heads, dropout=0.6)
        # Second layer: a single averaged head produces the final logits.
        self.conv2 = GATConv(hidden_channels * num_heads, out_channels,
                             heads=1, concat=False, dropout=0.6)

    def forward(self, x, edge_index):
        h = F.dropout(x, p=0.6, training=self.training)
        h = F.elu(self.conv1(h, edge_index))
        h = F.dropout(h, p=0.6, training=self.training)
        return self.conv2(h, edge_index)
Building a graph from tabular data
class GraphBuilder:
    """Convert tabular data into graphs for GNN consumption."""

    def build_user_item_graph(self, interactions: pd.DataFrame,
                              user_features: pd.DataFrame,
                              item_features: pd.DataFrame) -> Data:
        """Build a bipartite user-item graph for recommendations.

        Args:
            interactions: columns user_id, item_id, rating (or count).
            user_features: one row per user_id with numeric features.
            item_features: one row per item_id with numeric features.

        Returns:
            Data whose nodes are [users | items]; `n_users` marks the split.
        """
        # Map raw IDs to contiguous node indices: users first, then items.
        user_ids = interactions['user_id'].unique()
        item_ids = interactions['item_id'].unique()
        n_users = len(user_ids)
        user_idx = {uid: i for i, uid in enumerate(user_ids)}
        item_idx = {iid: i + n_users for i, iid in enumerate(item_ids)}
        # Edges: user -> item, mirrored to make the graph bidirectional
        # (typical for message passing in GNNs).
        src = interactions['user_id'].map(user_idx).values
        dst = interactions['item_id'].map(item_idx).values
        edge_index = torch.tensor(
            np.vstack([
                np.concatenate([src, dst]),
                np.concatenate([dst, src])
            ]),
            dtype=torch.long
        )
        # Node feature matrix; IDs missing from the feature tables fall
        # back to zeros. Cast to float explicitly so mixed/object dtypes
        # cannot break torch.tensor below.
        user_feat_matrix = (user_features.set_index('user_id')
                            .reindex(user_ids).fillna(0)
                            .values.astype(np.float32))
        item_feat_matrix = (item_features.set_index('item_id')
                            .reindex(item_ids).fillna(0)
                            .values.astype(np.float32))
        # Zero-pad both feature blocks to a common width so they stack.
        max_dim = max(user_feat_matrix.shape[1], item_feat_matrix.shape[1])
        user_feat_padded = np.pad(user_feat_matrix, ((0, 0), (0, max_dim - user_feat_matrix.shape[1])))
        item_feat_padded = np.pad(item_feat_matrix, ((0, 0), (0, max_dim - item_feat_matrix.shape[1])))
        x = torch.tensor(
            np.vstack([user_feat_padded, item_feat_padded]),
            dtype=torch.float
        )
        # Edge weights (e.g. the rating), duplicated for the mirrored edges.
        edge_attr = torch.tensor(
            np.concatenate([
                interactions['rating'].values,
                interactions['rating'].values
            ]),
            dtype=torch.float
        ).unsqueeze(1)
        return Data(
            x=x,
            edge_index=edge_index,
            edge_attr=edge_attr,
            n_users=n_users
        )

    def build_transaction_graph(self, transactions: pd.DataFrame) -> Data:
        """Build a transaction graph for fraud detection.

        Nodes: accounts and merchants (accounts indexed first).
        Edges: transactions between them, mirrored in both directions;
        every mirrored edge carries the same attributes and label.
        """
        accounts = transactions['account_id'].unique()
        merchants = transactions['merchant_id'].unique()
        n_accounts = len(accounts)
        acc_idx = {a: i for i, a in enumerate(accounts)}
        mer_idx = {m: i + n_accounts for i, m in enumerate(merchants)}
        src = transactions['account_id'].map(acc_idx).values
        dst = transactions['merchant_id'].map(mer_idx).values
        edge_index = torch.tensor(
            np.vstack([
                np.concatenate([src, dst]),
                np.concatenate([dst, src])
            ]),
            dtype=torch.long
        )
        # Transaction features as edge attributes. Cast to float so a
        # boolean column (is_international) cannot trip tensor creation.
        edge_attr = torch.tensor(
            transactions[['amount', 'hour_of_day', 'is_international']]
            .values.astype(np.float32),
            dtype=torch.float
        )
        edge_attr = torch.cat([edge_attr, edge_attr], dim=0)  # mirrored edges
        # Edge-level fraud labels (fraud = 1). Bug fix: labels must be
        # duplicated too, otherwise y has length E while the mirrored
        # graph has 2E edges and labels misalign with edge_index/edge_attr.
        if 'is_fraud' in transactions.columns:
            y = torch.tensor(transactions['is_fraud'].values, dtype=torch.long)
            y = torch.cat([y, y], dim=0)
        else:
            y = None
        return Data(
            x=torch.zeros(n_accounts + len(merchants), 16),  # placeholder node features
            edge_index=edge_index,
            edge_attr=edge_attr,
            y=y
        )
GNN Training and Assessment
class GNNTrainer:
    """Training pipeline for GNN models (node classification)."""

    def __init__(self, model: nn.Module, device: str = 'cuda'):
        # Robustness: fall back to CPU when CUDA is requested but absent,
        # instead of crashing on CPU-only machines.
        if device == 'cuda' and not torch.cuda.is_available():
            device = 'cpu'
        self.model = model.to(device)
        self.device = device
        self.optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

    def train_epoch(self, data: Data, mask: torch.Tensor = None) -> float:
        """Run one training epoch for node classification.

        Args:
            data: graph with `x`, `edge_index` and labels `y`.
            mask: optional boolean node mask restricting the loss.

        Returns:
            The scalar training loss for this epoch.
        """
        self.model.train()
        self.optimizer.zero_grad()
        data = data.to(self.device)
        out = self.model(data.x, data.edge_index)
        if mask is not None:
            # Keep the mask on the same device as the tensors it indexes.
            mask = mask.to(out.device)
            loss = F.cross_entropy(out[mask], data.y[mask])
        else:
            loss = F.cross_entropy(out, data.y)
        loss.backward()
        self.optimizer.step()
        return float(loss)

    def evaluate(self, data: Data, mask: torch.Tensor) -> dict:
        """Evaluate prediction quality on the masked nodes.

        NOTE: AUC assumes binary classification — it uses the class-1
        probability column.
        """
        self.model.eval()
        with torch.no_grad():
            out = self.model(data.x.to(self.device), data.edge_index.to(self.device))
        # Bug fix: move everything to CPU before indexing — a CPU mask
        # cannot index a CUDA tensor (and vice versa).
        out = out.cpu()
        mask = mask.cpu()
        pred = out[mask].argmax(dim=-1)
        true = data.y.cpu()[mask]
        from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
        probs = torch.softmax(out[mask], dim=-1)[:, 1].numpy()
        return {
            'accuracy': accuracy_score(true, pred),
            'f1_macro': f1_score(true, pred, average='macro'),
            # AUC is undefined for a single-class slice; report chance level.
            'auc': roc_auc_score(true, probs) if len(np.unique(true)) > 1 else 0.5
        }

    def train(self, data: Data,
              n_epochs: int = 200,
              train_mask: torch.Tensor = None,
              val_mask: torch.Tensor = None,
              checkpoint_path: str = 'best_gnn_model.pt') -> dict:
        """Full training loop with early stopping on validation AUC.

        The best checkpoint (by validation AUC) is saved to
        `checkpoint_path` and restored into the model before returning.

        Returns:
            dict with 'best_val_auc' and a 'history' of losses/AUCs.
        """
        best_val_auc = 0.0
        patience, patience_counter = 20, 0
        history = {'train_loss': [], 'val_auc': []}
        checkpoint_saved = False
        for epoch in range(n_epochs):
            loss = self.train_epoch(data, train_mask)
            history['train_loss'].append(loss)
            # Validate every 5 epochs to keep evaluation overhead low.
            if val_mask is not None and epoch % 5 == 0:
                metrics = self.evaluate(data, val_mask)
                history['val_auc'].append(metrics['auc'])
                if metrics['auc'] > best_val_auc:
                    best_val_auc = metrics['auc']
                    patience_counter = 0
                    torch.save(self.model.state_dict(), checkpoint_path)
                    checkpoint_saved = True
                else:
                    patience_counter += 1
                    if patience_counter >= patience:
                        print(f"Early stopping at epoch {epoch}")
                        break
        # Bug fix: restore the best weights — previously the model was left
        # at its last (possibly overfitted) epoch despite the checkpointing.
        if checkpoint_saved:
            self.model.load_state_dict(torch.load(checkpoint_path))
        return {'best_val_auc': best_val_auc, 'history': history}
Scaling to large graphs
Standard GNNs don't scale to graphs with millions of nodes—the full adjacency matrix doesn't fit in memory. Solutions:
- GraphSAGE with mini-batches: sample K neighbors instead of all of them. PyG supports this via `NeighborLoader` with the `num_neighbors=[25, 10]` parameter.
- Cluster-GCN: partition the graph into clusters and train within each cluster
- GraphSAINT: random sampling of subgraphs with importance sampling
from torch_geometric.loader import NeighborLoader
def create_scalable_dataloader(data: Data, batch_size: int = 1024) -> NeighborLoader:
    """Mini-batch neighbor-sampling loader for large graphs."""
    # Sample 25/10/5 neighbors at hops 1/2/3 respectively, seeding the
    # batches from the training nodes only.
    loader = NeighborLoader(
        data,
        num_neighbors=[25, 10, 5],
        batch_size=batch_size,
        input_nodes=data.train_mask,
        shuffle=True,
        num_workers=4,
    )
    return loader
Application area and benchmarks
| Task | Dataset | Architecture | AUC/Accuracy |
|---|---|---|---|
| Fraud detection | financial transactions | GraphSAGE | AUC 0.93-0.97 |
| Recommendations | Amazon | LightGCN | NDCG@20 0.045 |
| Social spam | social networks | GAT | F1 0.89 |
| Molecular Properties | ZINC | GIN | MAE 0.163 |
| Road Traffic | METR-LA | Diffusion GCN | RMSE 2.37 |
GNNs outperform traditional methods only when the graph structure is informative. If the relationships between objects are random, a standard GBT or MLP will yield comparable results with lower complexity.







