Pruning neural network models for optimization
Pruning is the removal of insignificant parameters (weights, neurons, attention heads, layers) from a trained neural network. The goal is to reduce model size and speed up inference with minimal quality loss. For LLMs, pruning is often combined with quantization and distillation for maximum compression.
Types of pruning
Unstructured pruning: individual weights are zeroed anywhere in the weight matrices. Highest compression ratios, but the result requires sparse computation — standard GPUs don't accelerate unstructured sparse operations out of the box.
Structured pruning: entire structural elements are removed — neurons, attention heads, layers. The result is an actually smaller dense model that runs faster on standard hardware.
Semi-structured pruning (N:M sparsity): N weights are zeroed in every block of M consecutive weights. The 2:4 format is supported in hardware by NVIDIA Ampere and newer (up to 2× speedup on sparse Tensor Cores).
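To make the 2:4 pattern concrete, here is a small sketch (plain PyTorch, illustrative only — the helper name is ours, not a library function) that keeps the two largest-magnitude weights in every block of four and zeroes the rest:

```python
import torch

def apply_2_to_4_sparsity(weight: torch.Tensor) -> torch.Tensor:
    """Zero the 2 smallest-magnitude weights in every contiguous block of 4."""
    blocks = weight.reshape(-1, 4)
    # Indices of the 2 largest-|w| entries per block — these survive
    keep = blocks.abs().topk(2, dim=1).indices
    mask = torch.zeros_like(blocks, dtype=torch.bool)
    mask.scatter_(1, keep, True)
    return (blocks * mask).reshape(weight.shape)

w = torch.tensor([[0.9, -0.1, 0.05, -0.7],
                  [0.2, 0.3, -0.6, 0.01]])
print(apply_2_to_4_sparsity(w))
# Each block of 4 keeps exactly its 2 largest-magnitude entries:
# [[0.9, 0.0, 0.0, -0.7], [0.0, 0.3, -0.6, 0.0]]
```

Real methods choose which two weights to keep using importance scores (see SparseGPT and Wanda below) rather than raw magnitude alone, but the resulting layout is exactly this.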
LLM-Pruner: structured LLM pruning
# Illustrative example in the spirit of LLM-Pruner; the actual repository is
# script-driven, so the class and method names below are simplified assumptions
from LLMPruner.pruner import LlamaStructuredPruner
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-8B")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B")

pruner = LlamaStructuredPruner(
    model=model,
    tokenizer=tokenizer,
    pruning_ratio=0.25,  # remove 25% of parameters
)

# Estimate parameter importance on calibration data
calibration_data = ["Text for weight importance analysis...", ...]
pruner.get_mask(calibration_data, method="taylor")  # Taylor-expansion importance

# Apply the mask and obtain the pruned model
pruned_model = pruner.prune()
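Structured pruning physically shrinks the model, so a quick sanity check is to compare parameter counts before and after. A generic sketch (plain PyTorch, using small stand-in layers instead of a full LLM):

```python
import torch

def count_params(model: torch.nn.Module) -> int:
    """Total number of parameters in a model."""
    return sum(p.numel() for p in model.parameters())

# Toy stand-ins: structured pruning removes whole output neurons
full = torch.nn.Linear(100, 100)   # 100*100 weights + 100 biases = 10100
pruned = torch.nn.Linear(100, 75)  # 25% of output neurons removed

ratio = 1 - count_params(pruned) / count_params(full)
print(f"removed {ratio:.1%} of parameters")  # removed 25.0% of parameters
```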
SparseGPT: efficient unstructured pruning without retraining
SparseGPT is a one-shot method that can prune 50–60% of an LLM's weights in a few hours without any retraining:
# sparsegpt — reference implementation from the method's authors
# (conceptual code; the actual repo applies pruning layer by layer)
from sparsegpt import SparseGPT

sparsegpt = SparseGPT(model)
sparsegpt.fasterprune(
    sparsity=0.5,   # 50% sparsity
    prunen=2,       # N in N:M
    prunem=4,       # M in N:M (2:4 is hardware-supported)
    percdamp=0.01,  # Hessian dampening
    blocksize=128,  # columns processed per block
)
With 2:4 sparsity (50%) on NVIDIA A100/H100, sparse Tensor Cores give roughly 1.7–2× inference speedup on the affected matrix multiplications.
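Since the hardware speedup only applies when the 2:4 constraint actually holds, it's worth verifying the pruned weights. A small check (hypothetical helper, plain PyTorch):

```python
import torch

def satisfies_2_to_4(weight: torch.Tensor) -> bool:
    """True if every contiguous block of 4 weights has at most 2 nonzeros."""
    blocks = weight.reshape(-1, 4)
    return bool(((blocks != 0).sum(dim=1) <= 2).all())

valid = torch.tensor([[0.9, 0.0, 0.0, -0.7],
                      [0.0, 0.3, -0.6, 0.0]])
print(satisfies_2_to_4(valid))             # True — valid 2:4 pattern
print(satisfies_2_to_4(torch.ones(4, 4)))  # False — fully dense blocks
```

Recent PyTorch versions also expose a prototype `torch.sparse.to_sparse_semi_structured` (CUDA-only) to convert such matrices into a format that actually uses the sparse Tensor Cores.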
Wanda: simple and effective pruning
Wanda (Pruning by Weights and activations) is one of the simplest effective methods: it scores each weight by the product |W| × ||X|| of its magnitude and the norm of the corresponding input activation:
# Wanda is simpler than SparseGPT with comparable quality;
# runs in minutes on a 7B model
import torch

def wanda_pruning(model, calibration_loader, sparsity=0.5):
    """Simplified Wanda implementation.

    Assumes get_activation_norms() gathers per-input-channel activation
    norms ||X|| on the calibration data (e.g. via forward hooks) — it is
    a placeholder here, not a library function.
    """
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Linear):
            # Accumulate activation statistics
            activation_norms = get_activation_norms(module, calibration_loader)
            # Importance score = |W| * ||X||
            importance = module.weight.abs() * activation_norms
            # Zero weights below the sparsity quantile
            # (the original method applies the threshold per output row)
            threshold = torch.quantile(importance.flatten(), sparsity)
            module.weight.data *= (importance > threshold)
    return model
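After any unstructured pruning it's worth measuring the sparsity actually achieved. A small helper (our sketch, not part of any library), demonstrated on a toy model:

```python
import torch

def model_sparsity(model: torch.nn.Module) -> float:
    """Fraction of exactly-zero weights across all Linear layers."""
    zeros, total = 0, 0
    for module in model.modules():
        if isinstance(module, torch.nn.Linear):
            zeros += (module.weight == 0).sum().item()
            total += module.weight.numel()
    return zeros / total

# Toy check: zero half of the first layer's weight rows
m = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU(), torch.nn.Linear(8, 4))
with torch.no_grad():
    m[0].weight[:4].zero_()
print(model_sparsity(m))  # one third of all Linear weights (32 of 96) are zero
```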
Depth pruning: layer removal
For LLMs, the middle layers are often less critical than the first and last ones:
import torch

def depth_prune_llm(model, layers_to_remove: list[int]):
    """Remove the specified decoder layers (Llama-style architecture)."""
    remaining_layers = [
        layer for i, layer in enumerate(model.model.layers)
        if i not in layers_to_remove
    ]
    model.model.layers = torch.nn.ModuleList(remaining_layers)
    model.config.num_hidden_layers = len(remaining_layers)  # keep config consistent
    return model

# Example: remove 8 middle layers out of 32 (25% depth reduction)
pruned_model = depth_prune_llm(model, layers_to_remove=list(range(12, 20)))
# Result: a 24-layer model from a 32-layer one
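Which layers to drop can be decided empirically: a layer whose output is nearly identical to its input contributes little. A sketch of this similarity-based scoring (our helper; assumes you collected hidden states, e.g. via `output_hidden_states=True`):

```python
import torch

def layer_redundancy_scores(hidden_states: list[torch.Tensor]) -> list[float]:
    """Cosine similarity between each layer's input and output hidden states.

    hidden_states[i] is the activation after layer i (index 0 = embeddings).
    Higher similarity = the layer changes its input less = better removal
    candidate.
    """
    scores = []
    for h_in, h_out in zip(hidden_states[:-1], hidden_states[1:]):
        sim = torch.nn.functional.cosine_similarity(
            h_in.flatten(1), h_out.flatten(1), dim=1
        ).mean().item()
        scores.append(sim)
    return scores

# Toy check with synthetic (batch, seq, hidden) activations
h0 = torch.randn(2, 16, 64)
h1 = h0 + 0.01 * torch.randn_like(h0)  # near-identity layer: score near 1.0
h2 = torch.randn_like(h1)              # layer that rewrites everything: low score
print(layer_redundancy_scores([h0, h1, h2]))
```

This is the idea behind similarity-based depth pruning: rank layers by redundancy score and remove the top-ranked block, then validate quality.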
Practical case study: edge deployment optimization
Task: a fine-tuned Llama 3.1 8B for an industrial controller (ARM server, 16 GB RAM, no GPU). Requirement: inference under 2 s per request.
Optimization strategy:
- GGUF Q4_K_M quantization: 8B → 4.1 GB, 8 tok/s on CPU (insufficient)
- Depth pruning (removing 8 of 32 layers): −25% latency, −3% quality
- Width pruning (removing 20% of attention heads): −15% latency
- Re-quantization: GGUF Q4_K_M applied to the pruned model
Final pruned+quantized model characteristics:
- Size: 3.1 GB (vs 4.1 GB)
- Throughput: 14 tok/s on ARM (vs 8 tok/s)
- Latency for a 100-token answer: 7 s → 1.8 s (goal achieved)
- Quality loss (LLM-as-judge): 7%
Recovery fine-tuning after pruning
Pruning always causes some degradation. A brief recovery fine-tune restores part of the lost quality:
# After pruning — brief fine-tuning for quality recovery
from trl import SFTTrainer, SFTConfig

# Use the same dataset as the original fine-tuning, but a lower LR
recovery_config = SFTConfig(
    output_dir="recovery",        # any output path
    num_train_epochs=1,           # one epoch is usually enough for recovery
    learning_rate=5e-5,           # lower than in full fine-tuning
    gradient_checkpointing=True,
    bf16=True,
)

trainer = SFTTrainer(model=pruned_model, args=recovery_config, train_dataset=dataset)
trainer.train()
Recovery fine-tuning typically recovers 50–70% of the lost quality within a single training epoch.
Timeline
- Choosing pruning strategy: 3–5 days
- Calibration and pruning: 4–24 hours (depending on method and model size)
- Recovery fine-tuning: 2–8 hours
- Benchmarking and evaluation: 3–5 days
- Total: 2–4 weeks