Pruning (обрізка) нейронної мережі для оптимізації
Pruning — видалення малозначимих параметрів (ваг, нейронів, attention heads, шарів) з навченої нейронної мережі. Мета — зменшити розмір моделі та прискорити інференс при мінімальній втраті якості. Для LLM pruning часто комбінують з квантизацією та дистиляцією для максимального стиснення.
Види pruning
Unstructured pruning: обнуляються окремі ваги по всій матриці. Високе стиснення, але потребує sparse computation — стандартні GPU не прискорюють sparse операції «з коробки».
Structured pruning: видаляються цілі структурні елементи — нейрони, attention heads, шари. Результат — реально менша щільна модель, яка працює швидше на стандартному залізі.
Semi-structured pruning (N:M sparsity): видаляються N ваг з кожного блоку M. Формат 2:4 підтримується NVIDIA Ampere та вище на апаратному рівні (до 2× прискорення).
LLM-Pruner: структурований pruning LLM
# Приклад використання LLM-Pruner
# pip install llm-pruner
from LLMPruner.pruner import LlamaStructuredPruner
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-7B")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-7B")
pruner = LlamaStructuredPruner(
model=model,
tokenizer=tokenizer,
pruning_ratio=0.25, # Видалити 25% параметрів
)
# Обчислення важливості параметрів на calibration data
calibration_data = ["Текст для аналізу важливості ваг...", ...]
pruner.get_mask(calibration_data, method="taylor") # Taylor expansion importance
# Застосування маски та pruning
pruned_model = pruner.prune()
SparseGPT: ефективний unstructured pruning без retraining
SparseGPT — метод, що дозволяє pruning 50–60% ваг LLM за кілька годин без повторного навчання:
# sparsegpt — бібліотека від авторів методу
# Приклад концептуального коду
from sparsegpt import SparseGPT
sparsegpt = SparseGPT(model)
sparsegpt.fasterprune(
sparsity=0.5, # 50% sparsity
prunen=2, # N в N:M
prunem=4, # M в N:M (2:4 — підтримується апаратно)
percdamp=0.01,
blocksize=128,
)
При 2:4 sparsity (50%) на NVIDIA A100/H100 прискорення inference на Tensor Core близько 1.7–2×.
Wanda: простий та ефективний pruning
Wanda (Pruning by Weights and Activations) — один з найефективніших методів, що використовує добуток |W| × ||X|| для визначення важливості ваг:
# Wanda простіше за SparseGPT, але порівнянна якість
# Працює за кілька хвилин на 7B моделі
def wanda_pruning(model, calibration_loader, sparsity=0.5):
"""Спрощена реалізація Wanda"""
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear):
# Накопичуємо статистику активацій
activation_norms = get_activation_norms(module, calibration_loader)
# Importance score = |W| * ||X||
importance = module.weight.abs() * activation_norms
# Pruning за threshold
threshold = torch.quantile(importance, sparsity)
mask = importance > threshold
module.weight.data *= mask
return model
Depth pruning: видалення шарів
Для LLM середні шари часто менш критичні, ніж перші та останні:
def depth_prune_llm(model, layers_to_remove: list[int]):
"""Видалення вказаних decoder layers"""
# Для Llama-архітектури
remaining_layers = [
layer for i, layer in enumerate(model.model.layers)
if i not in layers_to_remove
]
model.model.layers = torch.nn.ModuleList(remaining_layers)
return model
# Приклад: видалюємо 8 середніх шарів з 32 (25% depth reduction)
pruned_model = depth_prune_llm(model, layers_to_remove=list(range(12, 20)))
# Результат: 24-шарова модель з 32-шарової
Практичний case study: оптимізація edge-deploy
Завдання: дофіно-tuned Llama 3.1 8B для промислового контролера (ARM-сервер, 16GB RAM, немає GPU). Вимога: інференс < 2s на запит.
Стратегія оптимізації:
- GGUF Q4_K_M квантизація: 8B → 4.1GB, 8 tok/s на CPU (недостатньо)
- Depth pruning (видалення 8 шарів з 32): -25% latency, -3% якості
- Width pruning attention heads (видалення 20% голів): -15% latency
- Повторна квантизація: GGUF Q4_K_M на pruned моделі
Фінальні характеристики pruned+quantized моделі:
- Розмір: 3.1GB (vs 4.1GB)
- Throughput: 14 tok/s на ARM (vs 8 tok/s)
- Latency для 100-токенної відповіді: 7с → 1.8с (ціль досягнута)
- Втрата якості (LLM-judge): 7%
Recovery Fine-Tuning після pruning
Pruning завжди викликає деградацію. Recovery training відновлює частину якості:
# Після pruning — короткий fine-tuning для відновлення
from trl import SFTTrainer, SFTConfig
# Використовуємо той самий датасет, що для fine-tuning, але з нижчим LR
recovery_config = SFTConfig(
num_train_epochs=1, # 1 епоха для recovery
learning_rate=5e-5, # Нижче, ніж при full fine-tuning
gradient_checkpointing=True,
bf16=True,
)
trainer = SFTTrainer(model=pruned_model, args=recovery_config, train_dataset=dataset)
trainer.train()
Recovery fine-tuning типово повертає 50–70% втраченої якості при 1 епохі навчання.
Часові рамки
- Вибір стратегії pruning: 3–5 днів
- Calibration та pruning: 4–24 години (залежить від методу та розміру)
- Recovery fine-tuning: 2–8 годин
- Benchmarking та оцінка: 3–5 днів
- Всього: 2–4 тижні







