Оптимізація ML-моделі (pruning) для мобільного пристрою
Pruning видаляє ваги або нейрони з моделей. Логіка: в мережах, навчених на реальних даних, значна частка ваг близька до нуля та майже не впливає на вихід. Вони можуть бути обнулені або видалені без суттєвої втрати точності, але з виграшем у швидкості та розмірі.
Звучить привабливо. На практиці — pruning складніше за квантизацію, вимагає перенавчання після розрідження та не завжди дає очікуване прискорення на мобільних пристроях через особливості реалізації.
Два типи pruning
Unstructured pruning — обнулення окремих ваг (розріджені матриці). Матриця з 90% нулів — начебто 10-кратна економія. Але GPU/NPU працюють з щільними матрицями; розріджені обчислення там не прискорюються. Практична користь: зменшений розмір моделі після стиснення (нулі добре стискаються). Не швидкість інференсу на типових пристроях.
Structured pruning — видалення цілих фільтрів (каналів) у шарах згортки або голів у attention. Результат — фізично менший граф, справді швидший на будь-якому залізі. Це те, що потрібно мобілю.
Structured Pruning: практика на PyTorch
import torch
import torch.nn.utils.prune as prune
# L1-based structured pruning: видаліть 30% фільтрів з Conv2d шарів
# за критерієм мінімальної L1-норми (найменш важливі фільтри)
for name, module in model.named_modules():
if isinstance(module, torch.nn.Conv2d):
prune.ln_structured(
module,
name='weight',
amount=0.3, # 30% каналів
n=1, # L1 норма
dim=0 # dim=0 — вихідні фільтри
)
# Після pruning — зробіть ваги постійними (видаліть маску)
for name, module in model.named_modules():
if isinstance(module, torch.nn.Conv2d):
prune.remove(module, 'weight')
Після цього модель містить нульові фільтри, але вони залишаються в графі. Наступний крок — фактичне видалення нульових каналів:
# Користувацька функція для видалення нульових фільтрів
def remove_zero_filters(conv_layer, next_layer=None):
"""Видаліть фільтри з нульовими вагами та синхронізуйте наступний шар"""
weight = conv_layer.weight.data
# Маска: фільтри з ненульовими вагами
nonzero_mask = weight.abs().sum(dim=(1,2,3)) > 1e-6
conv_layer.weight = nn.Parameter(weight[nonzero_mask])
if conv_layer.bias is not None:
conv_layer.bias = nn.Parameter(conv_layer.bias.data[nonzero_mask])
conv_layer.out_channels = nonzero_mask.sum().item()
# Синхронізуйте наступний шар (вхідні канали)
if next_layer is not None and isinstance(next_layer, nn.Conv2d):
next_layer.weight = nn.Parameter(next_layer.weight.data[:, nonzero_mask])
next_layer.in_channels = nonzero_mask.sum().item()
Робіть обережно — BatchNorm шари після Conv також містять параметри для кожного каналу, що вимагають синхронізації.
Fine-tuning після pruning
Після видалення 20–40% фільтрів модель втрачає точність. Fine-tuning обов'язковий. Правило: чим агресивніше pruning, тим довше fine-tuning.
# Fine-tuning після pruning — зазвичай 10-20% від оригінального числа епох
optimizer = torch.optim.Adam(pruned_model.parameters(), lr=1e-4) # нижчий LR
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)
for epoch in range(20):
train_one_epoch(pruned_model, train_loader, optimizer)
val_acc = evaluate(pruned_model, val_loader)
scheduler.step()
print(f"Epoch {epoch}: val_acc={val_acc:.4f}")
Iterative pruning — цикл pruning → fine-tuning → pruning — дає кращі результати ніж однократне видалення великої кількості фільтрів.
Lottery Ticket Hypothesis: глибше
Для задач, де точність критична, використовуйте Lottery Ticket підхід: навчіть повну мережу, знайдіть "виграшні квитки" — розріджені підмережі, які можна навчити до порівнянної точності з нуля. Реалізуйте через torch_pruning:
import torch_pruning as tp
# Аналізуйте залежності шарів
example_inputs = torch.zeros(1, 3, 224, 224)
DG = tp.DependencyGraph()
DG.build_dependency(model, example_inputs=example_inputs)
# Отримайте групи пов'язаних шарів (pruning одного вимагає pruning пов'язаних)
pruner = tp.pruner.MagnitudePruner(
model,
example_inputs,
importance=tp.importance.MagnitudeImportance(p=1),
pruning_ratio=0.5, # видаліть 50% каналів
global_pruning=False,
iterative_steps=5 # ітеративно протягом 5 кроків
)
Чому pruning не завжди прискорює
MobileNetV3 уже оптимізований: depthwise separable convolutions з малою кількістю каналів. Видаліть 30% фільтрів з 16-канального шару — отримайте 11 каналів. Різниця у швидкості — мінімальна, overhead від операцій з тензорами залишається.
Pruning добре працює на великих моделях: ResNet-50, EfficientNet-B4, BERT. На компактних MobileNet/EfficientNet-lite — нижчий ефект. Краще почати з легшої базової архітектури ніж прунінгу важкої.
Комбінація з квантизацією
Pruning + квантизація — стандартна двохкрокова оптимізація:
- Structured pruning 30–40% → fine-tune → зменшіть граф
- INT8 квантизація стиснутого графу → остаточна модель
Приклад результату: EfficientNet-B0 (20 МБ FP32, 80 мс Android) → 35% pruning + INT8 → 4 МБ, 18 мс. Top-1 точність впала з 77.1% до 75.8%.
Процес
Аналіз моделі на pruning-придатність → вибір критерію та ступеня розрідження → ітеративний pruning + fine-tuning → верифікація точності → опціонально: квантизація → замери на цільових пристроях.
Орієнтири за часом
Structured pruning з fine-tuning на готовому наборі даних — 2–4 тижні. Ітеративний pruning з повним експериментальним циклом — 4–8 тижнів.







