Оптимизация ML-модели (дистилляция) для мобильного устройства
Knowledge Distillation — обучение маленькой модели («студент») воспроизводить поведение большой («учитель»). Не просто копировать правильные ответы, а перенять «мягкие» вероятности учителя — распределение по всем классам, которое содержит информацию о похожести понятий между собой.
Дистилляция принципиально отличается от pruning и квантизации: вы получаете новую, меньшую архитектуру с её собственными весами. Размер «студента» вы выбираете сами. Это мощнее, но требует больше ресурсов: нужен датасет, GPU, время на обучение.
Почему мягкие метки работают лучше
Обычное обучение: правильный класс = 1.0, остальные = 0.0. Hard labels.
Учитель на изображении кошки выдаёт: кошка 0.85, рысь 0.08, тигр 0.04, собака 0.02 ... Эти «мягкие» метки несут информацию, что рысь похожа на кошку больше, чем самолёт. Студент, обученный на таких метках, усваивает структуру пространства признаков, а не просто решение бинарного классификатора.
import torch
import torch.nn.functional as F
def distillation_loss(student_logits, teacher_logits, true_labels, temperature=4.0, alpha=0.7):
"""
alpha — вес дистилляции vs hard label loss
temperature — сглаживает распределение учителя
"""
# Soft targets loss (KL divergence между студентом и учителем)
soft_teacher = F.softmax(teacher_logits / temperature, dim=-1)
soft_student = F.log_softmax(student_logits / temperature, dim=-1)
distill_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean') * (temperature ** 2)
# Hard label loss (обычная кросс-энтропия)
hard_loss = F.cross_entropy(student_logits, true_labels)
return alpha * distill_loss + (1 - alpha) * hard_loss
temperature ** 2 — нормализующий множитель, компенсирующий масштаб градиентов при высокой температуре. Без него distill_loss и hard_loss находятся в разных масштабах.
Выбор архитектуры студента
Студент должен быть меньше учителя, но не произвольно. Хорошие базовые архитектуры для мобиля:
- MobileNetV3-Small — 2.5 МБ, проектировался для мобиля с нуля, depthwise separable convolutions
- EfficientNet-Lite0/1 — хороший баланс точность/скорость
- MobileViT-XXS — hybrid CNN+Transformer, 1.3 МБ
- DistilBERT (для NLP) — уже дистиллированный из BERT, 66 МБ vs 440 МБ
Для задач детекции объектов на мобиле: студент на базе YOLOv8n (8 МБ) дистиллируется из YOLOv8l (87 МБ).
Процесс дистилляции: пример для классификации
# Предположим: учитель — ResNet-50, студент — MobileNetV3-Small
teacher = torchvision.models.resnet50(pretrained=True).eval()
student = torchvision.models.mobilenet_v3_small(pretrained=False)
# Замораживаем учителя
for param in teacher.parameters():
param.requires_grad = False
optimizer = torch.optim.AdamW(student.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
for epoch in range(100):
student.train()
for images, labels in train_loader:
with torch.no_grad():
teacher_logits = teacher(images)
student_logits = student(images)
loss = distillation_loss(student_logits, teacher_logits, labels,
temperature=4.0, alpha=0.7)
optimizer.zero_grad()
loss.backward()
optimizer.step()
scheduler.step()
val_acc = evaluate(student, val_loader)
print(f"Epoch {epoch}: student_acc={val_acc:.4f}")
Типичные результаты: MobileNetV3-Small обученный обычно — 67–68% top-1 на ImageNet. После дистилляции из ResNet-50 — 71–72%. Прирост 3–4% за счёт knowledge transfer.
Intermediate layer distillation
Дистилляция только по выходам (logits) — базовый вариант. Более сильный: добавляем соответствие промежуточных feature maps.
# FitNets / PKT: студент учит feature maps учителя
class DistillationHook:
"""Хук для захвата промежуточных активаций"""
def __init__(self):
self.output = None
def __call__(self, module, input, output):
self.output = output
teacher_hook = DistillationHook()
student_hook = DistillationHook()
# Регистрируем на соответствующих слоях
teacher.layer3.register_forward_hook(teacher_hook)
student.features[9].register_forward_hook(student_hook) # Analogous layer
# В цикле обучения добавляем feature distillation loss
with torch.no_grad():
teacher(images)
teacher_features = teacher_hook.output
student(images) # с grad
student_features = student_hook.output
# Если размерности отличаются — нужен adapter (1x1 Conv)
if teacher_features.shape[1] != student_features.shape[1]:
student_features = adapter_conv(student_features) # adapter обучается вместе
feature_loss = F.mse_loss(student_features, teacher_features.detach())
Такой подход требует выравнивания размерностей feature maps между учителем и студентом — через adapter 1×1 свёрток. Адаптер добавляет немного параметров студенту, но остаётся маленьким.
Data-Free Distillation
Иногда исходный датасет недоступен (IP restrictions, privacy). Data-free distillation — генерируем синтетические данные, которые максимизируют активации учителя:
# DAFL (Data-Free Learning): генератор создаёт «образцы» для дистилляции
generator = Generator(latent_dim=256, img_channels=3)
optimizer_G = torch.optim.Adam(generator.parameters(), lr=1e-4)
for step in range(1000):
z = torch.randn(batch_size, 256)
fake_images = generator(z)
# Потери: максимизируем уверенность учителя + минимизируем BatchNorm statistics mismatch
teacher_out = teacher(fake_images)
activation_loss = -teacher_out.max(dim=1)[0].mean() # учитель должен быть уверен
# BN statistics matching
bn_loss = compute_bn_statistics_loss(teacher, fake_images)
total_loss = activation_loss + 0.1 * bn_loss
optimizer_G.zero_grad()
total_loss.backward()
optimizer_G.step()
Качество data-free дистилляции ниже полноданных варианта, но иногда это единственный вариант.
Дистилляция для NLP задач на мобиле
Для мобильных приложений с NLP (классификация отзывов, определение intent, суммаризация): дистиллируем из GPT-4 / Claude API ответов в маленький BERT/DistilBERT.
# Собираем soft labels от учителя (GPT-4 API)
# Для каждого обучающего примера запрашиваем вероятности классов
# Сохраняем как обучающие метки для студента
# Студент — DistilBERT fine-tuned на этих мягких метках
DistilBERT (66 МБ, ONNX int8 — 18 МБ) работает на устройстве за 30–80 мс на iOS/Android. GPT-4 в облаке — сотни мс, деньги за запросы.
Процесс
Определение архитектуры студента под ресурсный бюджет → настройка дистилляции (temperature, alpha, intermediate layers) → обучение на GPU → проверка точности vs учитель → конвертация в Core ML / TFLite → финальные замеры на устройствах.
Ориентиры по срокам
Базовая логит-дистилляция для классификационной задачи — 2–4 недели (GPU-время плюс подбор гиперпараметров). Полная дистилляция с промежуточными слоями, нестандартными архитектурами, data augmentation — 5–10 недель.







