Оптимізація ML-моделі (дистилляція) для мобільного пристрою
Knowledge Distillation навчає маленьку модель ("студента") відтворювати поведінку великої ("учителя"). Не просто копіювання правильних відповідей, а засвоєння "м'яких" ймовірностей — розподілів по класам, що містять інформацію про подібність понять.
Дистилляція принципово відрізняється від pruning та квантизації: ви отримуєте нову, меншу архітектуру з власними вагами. Розмір студента ви вибираєте. Більш потужно, але вимагає ресурсів: набору даних, GPU, часу навчання.
Чому м'які мітки працюють краще
Нормальне навчання: правильний клас = 1.0, інші = 0.0. Hard labels.
Учитель на зображенні кота выдає: кіт 0.85, рись 0.08, тигр 0.04, собака 0.02. Ці м'які мітки несуть інформацію, що рись більше схожа на кота, ніж на літак. Студент, навчений на таких мітках, засвоює структуру простору ознак, не просто бінарну класифікацію.
import torch
import torch.nn.functional as F
def distillation_loss(student_logits, teacher_logits, true_labels, temperature=4.0, alpha=0.7):
"""
alpha — вага дистилляції vs hard label loss
temperature — сглажує розподіл учителя
"""
# Soft targets loss (KL divergence між студентом та учителем)
soft_teacher = F.softmax(teacher_logits / temperature, dim=-1)
soft_student = F.log_softmax(student_logits / temperature, dim=-1)
distill_loss = F.kl_div(soft_student, soft_teacher, reduction='batchmean') * (temperature ** 2)
# Hard label loss (звичайна крос-ентропія)
hard_loss = F.cross_entropy(student_logits, true_labels)
return alpha * distill_loss + (1 - alpha) * hard_loss
temperature ** 2 — нормалізуючий множник, який компенсує масштаб градієнтів при високій температурі. Без нього distill_loss та hard_loss мають різні масштаби.
Вибір архітектури студента
Студент повинен бути меншим за учителя, але не довільно. Гарні базові архітектури для мобіля:
- MobileNetV3-Small — 2.5 МБ, розроблена для мобіля з нуля, depthwise separable convolutions
- EfficientNet-Lite0/1 — хороший баланс точність/швидкість
- MobileViT-XXS — гібридна CNN+Transformer, 1.3 МБ
- DistilBERT (для NLP) — уже дистильована з BERT, 66 МБ vs 440 МБ
Для мобільної детекції об'єктів: студент на основі YOLOv8n (8 МБ) дистильована з YOLOv8l (87 МБ).
Процес дистилляції: приклад класифікації
# Припустимо: учитель — ResNet-50, студент — MobileNetV3-Small
teacher = torchvision.models.resnet50(pretrained=True).eval()
student = torchvision.models.mobilenet_v3_small(pretrained=False)
# Заморозьте учителя
for param in teacher.parameters():
param.requires_grad = False
optimizer = torch.optim.AdamW(student.parameters(), lr=1e-3, weight_decay=1e-4)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
for epoch in range(100):
student.train()
for images, labels in train_loader:
with torch.no_grad():
teacher_logits = teacher(images)
student_logits = student(images)
loss = distillation_loss(student_logits, teacher_logits, labels,
temperature=4.0, alpha=0.7)
optimizer.zero_grad()
loss.backward()
optimizer.step()
scheduler.step()
val_acc = evaluate(student, val_loader)
print(f"Epoch {epoch}: student_acc={val_acc:.4f}")
Типові результати: MobileNetV3-Small, навчена нормально — 67–68% top-1 на ImageNet. Після дистилляції з ResNet-50 — 71–72%. 3–4% прирост від передачі знань.
Дистилляція проміжних шарів
Дистилляція тільки logits — базовий варіант. Сильніше: відповідність проміжних feature maps.
# FitNets / PKT: студент навчає feature maps учителя
class DistillationHook:
"""Хук для захоплення проміжних активацій"""
def __init__(self):
self.output = None
def __call__(self, module, input, output):
self.output = output
teacher_hook = DistillationHook()
student_hook = DistillationHook()
# Реєструйте на відповідних шарах
teacher.layer3.register_forward_hook(teacher_hook)
student.features[9].register_forward_hook(student_hook) # Аналогічний шар
# В циклі навчання додайте feature distillation loss
with torch.no_grad():
teacher(images)
teacher_features = teacher_hook.output
student(images) # з grad
student_features = student_hook.output
# Якщо розмірності відрізняються — потрібен adapter (1x1 Conv)
if teacher_features.shape[1] != student_features.shape[1]:
student_features = adapter_conv(student_features) # adapter навчається
feature_loss = F.mse_loss(student_features, teacher_features.detach())
Цей підхід вимагає вирівнювання розмірностей feature map через 1×1 conv adapter. Adapter додає кілька параметрів студента, залишається малим.
Data-Free Distillation
Іноді вихідний набір даних недоступний (IP обмеження, приватність). Data-free дистилляція — генеруйте синтетичні дані, що максимізують активації учителя:
# DAFL (Data-Free Learning): генератор створює зразки для дистилляції
generator = Generator(latent_dim=256, img_channels=3)
optimizer_G = torch.optim.Adam(generator.parameters(), lr=1e-4)
for step in range(1000):
z = torch.randn(batch_size, 256)
fake_images = generator(z)
# Втрати: максимізуйте впевненість учителя + мінімізуйте невідповідність BatchNorm статистики
teacher_out = teacher(fake_images)
activation_loss = -teacher_out.max(dim=1)[0].mean() # учитель повинен бути впевнений
# BN статистика відповідності
bn_loss = compute_bn_statistics_loss(teacher, fake_images)
total_loss = activation_loss + 0.1 * bn_loss
optimizer_G.zero_grad()
total_loss.backward()
optimizer_G.step()
Якість data-free дистилляції нижче, ніж у повноданного варіанту, але іноді це єдиний варіант.
Дистилляція для мобільних NLP задач
Для мобільного NLP (класифікація відгуків, визначення intent, суммаризація): дистильовуйте від GPT-4 / Claude API відповідей у маленький BERT/DistilBERT.
# Зберіть м'які мітки від учителя (GPT-4 API)
# Для кожного навчального прикладу запитайте ймовірності класів
# Збережіть як мітки навчання для студента
# Студент — DistilBERT fine-tuned на цих м'яких мітках
DistilBERT (66 МБ, ONNX int8 — 18 МБ) працює на пристрої за 30–80 мс на iOS/Android. GPT-4 в хмарі — сотні мс, кошти за запити.
Процес
Визначте архітектуру студента за бюджетом ресурсів → налаштуйте дистилляцію (температура, альфа, проміжні шари) → навчайте на GPU → перевірте точність vs учитель → конвертуйте в Core ML / TFLite → остаточні замери на пристроях.
Орієнтири за часом
Базова логіт-дистилляція для класифікаційної задачі — 2–4 тижні (час GPU плюс підбір гіперпараметрів). Повна дистилляція з проміжними шарами, нестандартними архітектурами, доповненням даних — 5–10 тижнів.







