Training моделей классификации изображений: ResNet, ViT, EfficientNet
Классификация — базовая CV-задача, но «обучить ResNet на своих данных» на практике означает решение целой цепочки вопросов: какую архитектуру выбрать при имеющемся объёме данных, как избежать переобучения при малом датасете, как справиться с дисбалансом классов, как оптимизировать модель для инференса.
Selection архитектуры в зависимости от данных
| Датасет | Рекомендация | Почему |
|---|---|---|
| <500 примеров/класс | EfficientNet-B0/B2 (frozen → partial unfreeze) | Меньше параметров, меньше переобучение |
| 500–5000/класс | EfficientNet-B4, ConvNeXt-T, ResNet-50 | Баланс точности и скорости обучения |
| >5000/класс | ViT-B/16, ConvNeXt-S/B | Трансформеры раскрываются на больших данных |
| Медицина, мало данных | ResNet-50 с pretrain на MedNet | Доменный pretrain важнее архитектуры |
| Edge / мобайл | MobileNetV3-Large, EfficientNet-Lite | Latency < 10ms на телефоне |
Двухфазное обучение с ViT при малом датасете
ViT на 300 примерах без правильной стратегии даст accuracy 65% там, где EfficientNet даст 84%. Причина: attention heads на маленьком датасете переобучаются быстрее, чем convolutional inductive bias в ResNet. Решение — агрессивная заморозка + постепенное размораживание:
import timm
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
def train_vit_two_phase(
num_classes: int,
train_loader,
val_loader,
phase1_epochs: int = 20,
phase2_epochs: int = 60,
device: str = 'cuda'
) -> nn.Module:
model = timm.create_model(
'vit_base_patch16_224',
pretrained=True,
num_classes=num_classes,
drop_rate=0.1,
drop_path_rate=0.1
).to(device)
# ФАЗА 1: обучаем только head
for name, param in model.named_parameters():
param.requires_grad = 'head' in name
opt1 = AdamW(
filter(lambda p: p.requires_grad, model.parameters()),
lr=1e-3, weight_decay=0.05
)
sched1 = CosineAnnealingLR(opt1, T_max=phase1_epochs, eta_min=1e-5)
for epoch in range(phase1_epochs):
_train_epoch(model, train_loader, opt1, device)
sched1.step()
# ФАЗА 2: размораживаем последние 6 блоков (из 12)
for name, param in model.named_parameters():
if any(f'blocks.{i}.' in name for i in range(6, 12)):
param.requires_grad = True
opt2 = AdamW([
{'params': model.head.parameters(), 'lr': 5e-5},
{'params': [p for n, p in model.named_parameters()
if 'blocks' in n and p.requires_grad], 'lr': 5e-6}
], weight_decay=0.05)
sched2 = CosineAnnealingLR(opt2, T_max=phase2_epochs, eta_min=1e-7)
for epoch in range(phase2_epochs):
_train_epoch(model, train_loader, opt2, device)
_validate(model, val_loader, device)
sched2.step()
return model
def _train_epoch(model, loader, optimizer, device):
model.train()
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
for images, labels in loader:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
loss = criterion(model(images), labels)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
Focal Loss при дисбалансе классов
Accuracy 94% при дисбалансе 1:200 — это обычно означает, что модель предсказывает только мажоритарный класс. Metrics для дисбаланса: macro F1, balanced accuracy, per-class recall.
import torch
import torch.nn.functional as F
class FocalLoss(nn.Module):
def __init__(self, alpha: float = 1.0, gamma: float = 2.0,
class_weights: torch.Tensor = None):
super().__init__()
self.alpha = alpha
self.gamma = gamma
self.class_weights = class_weights
def forward(self, inputs: torch.Tensor,
targets: torch.Tensor) -> torch.Tensor:
ce_loss = F.cross_entropy(
inputs, targets, weight=self.class_weights, reduction='none'
)
pt = torch.exp(-ce_loss)
focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
return focal_loss.mean()
# Веса классов обратно пропорционально частоте
class_counts = torch.tensor([10000, 500, 80], dtype=torch.float)
class_weights = (1.0 / class_counts).to(device)
criterion = FocalLoss(gamma=2.0, class_weights=class_weights)
Test Time Augmentation (TTA) для улучшения инференса
TTA — простой способ поднять accuracy на 1–3% без переобучения: прогоняем несколько аугментированных версий изображения и усредняем предсказания.
import torchvision.transforms as T
class TTAClassifier:
def __init__(self, model: nn.Module, n_augments: int = 5):
self.model = model.eval()
self.n_augments = n_augments
self.base_transform = T.Compose([
T.Resize(256), T.CenterCrop(224),
T.ToTensor(),
T.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
])
self.tta_transforms = [
T.Compose([T.Resize(256), T.CenterCrop(224),
T.RandomHorizontalFlip(p=1.0),
T.ToTensor(),
T.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])]),
T.Compose([T.Resize(224), T.ToTensor(),
T.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])]),
]
@torch.no_grad()
def predict(self, image) -> torch.Tensor:
preds = [
torch.softmax(self.model(self.base_transform(image).unsqueeze(0)), dim=1)
]
for transform in self.tta_transforms:
preds.append(
torch.softmax(self.model(transform(image).unsqueeze(0)), dim=1)
)
return torch.stack(preds).mean(dim=0)
Экспорт для production
# ONNX export — для CPU/TensorRT деплоя
dummy = torch.randn(1, 3, 224, 224, device='cpu')
torch.onnx.export(
model.cpu().eval(),
dummy,
'classifier.onnx',
opset_version=17,
input_names=['image'],
output_names=['logits'],
dynamic_axes={'image': {0: 'batch'}, 'logits': {0: 'batch'}}
)
Сроки
| Задача | Срок |
|---|---|
| Fine-tuning готовой архитектуры (готовые данные) | 1–3 недели |
| Training с нуля + аугментации + оптимизация | 4–7 недель |
| Разработка кастомной архитектуры под домен | 8–14 недель |







