Neural Architecture Search (NAS)
Вручну спроектовані архітектури – результат досвіду та інтуїції. NAS - алгоритмічний перебір архітектурного простору з оптимізацією під конкретне завдання, датасет та hardware target. Чи не заміна архітектурного мислення, а спосіб знайти зміни, до яких людина просто не дійде за розумний час.
Чому наївний NAS вбиває GPU-бюджет
Класичний NAS у виконанні NASNet (Google, 2017) – 500 GPU-днів на A100-еквіваленті. Проблема полягає в тому, що кожна кандидатна архітектура навчалася з нуля до збіжності. При просторі пошуку 10^10 змін повний перебір неможливий у принципі.
Сучасні підходи вирішують це через три принципово різні ідеї:
One-shot NAS / Weight Sharing. Супермережа (supernet) включає всі можливі підграфи. Кожен кандидат — «шлях» через цю супермережу, яка використовує вже навчені ваги. DARTS, SNAS, Single-Path NAS – всі вони будуються на цій ідеї. Час пошуку падає із сотень GPU-днів до 1–4 днів.
Predictor-based NAS. Навчається surrogate-модель, яка передбачає прозорість архітектури без її повного навчання. BANANAS, NASBOWL, NAO використовують цей підхід. Вибірка із простору пошуку + 100–200 реальних оцінок → предиктор точності для наступних мільйонів кандидатів.
Hardware-aware NAS. Оптимізація не тільки по accuracy, але по latency на конкретному пристрої. MNasNet, FBNet, Once-for-All – шукають Pareto-front у просторі (accuracy, latency/MACs). Критично для edge deployment.
Глибокий розбір: DARTS та його проблеми у production
DARTS (Differentiable Architecture Search) - найбільш використовуваний one-shot метод. Ідея: замість дискретного вибору операції (3×3 conv vs 5×5 conv vs skip) використовуємо безперервні ваги для кожної операції, оптимізовані через gradient descent.
import torch
import torch.nn as nn
from torch.nn import functional as F
class MixedOp(nn.Module):
"""
DARTS mixed operation: взвешенная сумма всех кандидатных операций.
Веса alpha оптимизируются через архитектурный градиент.
"""
def __init__(self, C: int, stride: int):
super().__init__()
self._ops = nn.ModuleList()
for primitive in PRIMITIVES: # ['none', 'skip_connect', 'sep_conv_3x3', ...]
op = OPS[primitive](C, stride, affine=False)
self._ops.append(op)
def forward(self, x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
# weights = softmax(alpha) — архитектурные веса
return sum(w * op(x) for w, op in zip(weights, self._ops))
class DARTSCell(nn.Module):
def __init__(self, steps: int, multiplier: int, C_prev_prev: int,
C_prev: int, C: int, reduction: bool, reduction_prev: bool):
super().__init__()
self._steps = steps # число промежуточных узлов (обычно 4)
self._multiplier = multiplier # сколько узлов конкатенируется на выходе
# ... инициализация preprocess и mixed ops
def forward(self, s0: torch.Tensor, s1: torch.Tensor,
weights: torch.Tensor) -> torch.Tensor:
states = [s0, s1]
offset = 0
for i in range(self._steps):
s = sum(
self._ops[offset + j](h, weights[offset + j])
for j, h in enumerate(states)
)
offset += len(states)
states.append(s)
return torch.cat(states[-self._multiplier:], dim=1)
Дворівнева оптимізація DARTS – головна інженерна складність. Мережеві ваги w та архітектурні ваги α оптимізуються поперемінно:
def train_darts_step(model, architect, optimizer_w, optimizer_alpha,
train_queue, valid_queue, lr_w: float):
"""
DARTS: чередование шагов оптимизации весов сети и архитектурных весов.
"""
for step, (input_train, target_train) in enumerate(train_queue):
# 1. Архитектурный шаг: обновляем alpha по валидационной потере
input_valid, target_valid = next(iter(valid_queue))
architect.step(
input_train, target_train,
input_valid, target_valid,
lr=lr_w, optimizer=optimizer_w,
unrolled=False # True = second-order DARTS, в 2x дороже
)
# 2. Шаг весов: обновляем w по тренировочной потере
optimizer_w.zero_grad()
logits = model(input_train)
loss = F.cross_entropy(logits, target_train)
loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
optimizer_w.step()
Проблема колапсу операцій. У чистому DARTS skip-connection операції майже завжди "перемагають" - у них нульові параметри, вони добре навчаються на ранніх етапах, і архітектурні ваги α[skip] стійко ростуть. Результат: знайдена архітектура вироджується в майже skip-only мережу з поганим узагальненням. Рішення:
- DARTS+: відсікання skip-connections з найбільшим α на фінальному етапі
- P-DARTS: прогресивне збільшення глибини мережі під час пошуку
- GDAS: Gumbel-softmax замість softmax для α – розріджений вибір операцій
Hardware-aware NAS на практиці
Для мобільного деплою (Android, CoreML) accuracy – не єдина метрика. Latency на цільовому залозі важливіший за FLOP-підрахунок, тому що різні операції виконуються по-різному на реальному залозі.
Once-for-All (MIT) — навчається одна супермережа, з якої без донавчання витягуються підмережі під будь-який hardware constraint:
from ofa.model_zoo import ofa_net
# Загружаем предобученную OFA суперсеть
ofa_network = ofa_net('ofa_mbv3_d234_e346_k357_w1.0', pretrained=True)
# Специализируем под конкретный device с latency constraint
from ofa.nas.efficiency_predictor import Latency_MBV3_MeasuredNet
efficiency_predictor = Latency_MBV3_MeasuredNet(
'note10', # Samsung Note10 — реальные замеры латентности
ofa_network
)
# Evolutionary search: ищем подсеть с latency < 25ms и max accuracy
from ofa.nas.search_algorithm.evolution_finder import EvolutionFinder
finder = EvolutionFinder(
efficiency_constraint=25, # ms
efficiency_predictor=efficiency_predictor,
accuracy_predictor=accuracy_predictor,
population_size=100,
max_time_budget=500 # эволюционных шагов
)
best_valids, best_info = finder.run_evolution_search()
У реальному проекті: NAS під MobileNetV3-space для завдання класифікації виробничого шлюбу (640 480, 12 класів). Цільова платформа - NVIDIA Jetson Nano (4GB RAM, 128 CUDA cores). Обмеження: latency < 30ms при batch=1. Ручна архітектура MobileNetV3-Large давала 28.4ms та accuracy 91.3%. OFA-пошук за 6 годин знайшов підмережа: 22.1ms, accuracy 92.7%. Без жодної ручної зміни архітектури.
Практичний стек і коли NAS виправданий
| Сценарій | Підхід | Час пошуку | Інструмент |
|---|---|---|---|
| Image classification, стандартний | DARTS / PC-DARTS | 1–2 дні (4× A100) | nni (Microsoft) або automl (torchvision) |
| Edge deployment (мобайл, MCU) | OFA / MNasNet-style | 6–24 години | Once-for-All, TuNAS |
| NLP / Transformer architecture | NAS-BERT, AutoFormer | 2-5 днів | Hugging Face NAS toolkit |
| Кастомні операції, custom hardware | Predictor-based NAS | 1-3 дні + 100 eval | BANANAS, NASBOWL |
Microsoft NNI — найбільш зрілий open-source фреймворк для NAS. Підтримує DARTS, ENAS, Random NAS, SPOS із коробки. Інтеграція з PyTorch та TensorFlow.
import nni
from nni.nas.pytorch.darts import DartsTrainer
from nni.nas.pytorch.callbacks import LRSchedulerCallback
trainer = DartsTrainer(
model=model,
loss=nn.CrossEntropyLoss(),
metrics=lambda output, target: accuracy(output, target, topk=(1,)),
optimizer=optimizer,
num_epochs=50,
dataset_train=dataset_train,
dataset_valid=dataset_valid,
batch_size=64,
log_frequency=10,
callbacks=[LRSchedulerCallback(scheduler)]
)
trainer.fit()
# Получаем финальную архитектуру
export_result = trainer.export()
Коли NAS не потрібен. Якщо завдання стандартне і даних < 50k прикладів - візьміть передбачувану ResNet-50 або EfficientNet-B0 та fine-tune. NAS виправданий за: кастомних hardware-обмеження, нетипових вхідних даних (гіперспектральні знімки, специфічні модальності), необхідності кардинально зменшити модель без втрати якості.
Процес роботи
- Визначення search space - критично важливий крок: задаємо блоки, операції, діапазони каналів, максимальну глибину. Неправильний search space = поганий результат незалежно від алгоритму
- Вибір стратегії пошуку - DARTS для GPU-rich оточення, evolutionary для hardware-aware, predictor-based при обмеженому бюджеті оцінок
- Профілювання цільового заліза — реальні виміри latency/throughput для операцій із search space на production hardware
- Пошук та оцінка кандидатів — Weights & Biases для трекінгу, MLflow для зберігання знайдених архітектур
- Full training знайденої архітектури з нуля — ваги з фази пошуку не використовуються
- Validation на holdout set, профіль на production-залізі
Терміни: визначення search space та setup – 1 тиждень. Сам пошук – 1–5 днів обчислень. Full training кандидата + валідація - 1-2 тижні. Разом: 3-6 тижнів на повний NAS-цикл.







