Реалізація Neural Architecture Search (NAS) для проектування оптимальної архітектури моделі

Проектуємо та впроваджуємо системи штучного інтелекту: від прототипу до production-ready рішення. Наша команда поєднує експертизу в машинному навчанні, дата-інжинірингу та MLOps, щоб AI працював не в лабораторії, а в реальному бізнесі.
Показано 1 з 1Усі 1566 послуг
Реалізація Neural Architecture Search (NAS) для проектування оптимальної архітектури моделі
Складний
від 1 тижня до 3 місяців
Часті запитання

Напрямки AI-розробки

Етапи розробки AI-рішення

Останні роботи

  • image_website-b2b-advance_0.webp
    Розробка сайту компанії B2B ADVANCE
    1284
  • image_web-applications_feedme_466_0.webp
    Розробка веб-додатків для компанії FEEDME
    1196
  • image_websites_belfingroup_462_0.webp
    Розробка веб-сайту для компанії БЕЛФІНГРУП
    901
  • image_ecommerce_furnoro_435_0.webp
    Розробка інтернет магазину для компанії FURNORO
    1119
  • image_logo-advance_0.webp
    Розробка логотипу компанії B2B Advance
    586
  • image_crm_enviok_479_0.webp
    Розробка веб-додатків для компанії Enviok
    853

Neural Architecture Search (NAS)

Вручну спроектовані архітектури – результат досвіду та інтуїції. NAS - алгоритмічний перебір архітектурного простору з оптимізацією під конкретне завдання, датасет та hardware target. Чи не заміна архітектурного мислення, а спосіб знайти зміни, до яких людина просто не дійде за розумний час.

Чому наївний NAS вбиває GPU-бюджет

Класичний NAS у виконанні NASNet (Google, 2017) – 500 GPU-днів на A100-еквіваленті. Проблема полягає в тому, що кожна кандидатна архітектура навчалася з нуля до збіжності. При просторі пошуку 10^10 змін повний перебір неможливий у принципі.

Сучасні підходи вирішують це через три принципово різні ідеї:

One-shot NAS / Weight Sharing. Супермережа (supernet) включає всі можливі підграфи. Кожен кандидат — «шлях» через цю супермережу, яка використовує вже навчені ваги. DARTS, SNAS, Single-Path NAS – всі вони будуються на цій ідеї. Час пошуку падає із сотень GPU-днів до 1–4 днів.

Predictor-based NAS. Навчається surrogate-модель, яка передбачає прозорість архітектури без її повного навчання. BANANAS, NASBOWL, NAO використовують цей підхід. Вибірка із простору пошуку + 100–200 реальних оцінок → предиктор точності для наступних мільйонів кандидатів.

Hardware-aware NAS. Оптимізація не тільки по accuracy, але по latency на конкретному пристрої. MNasNet, FBNet, Once-for-All – шукають Pareto-front у просторі (accuracy, latency/MACs). Критично для edge deployment.

Глибокий розбір: DARTS та його проблеми у production

DARTS (Differentiable Architecture Search) - найбільш використовуваний one-shot метод. Ідея: замість дискретного вибору операції (3×3 conv vs 5×5 conv vs skip) використовуємо безперервні ваги для кожної операції, оптимізовані через gradient descent.

import torch
import torch.nn as nn
from torch.nn import functional as F

class MixedOp(nn.Module):
    """
    DARTS mixed operation: взвешенная сумма всех кандидатных операций.
    Веса alpha оптимизируются через архитектурный градиент.
    """
    def __init__(self, C: int, stride: int):
        super().__init__()
        self._ops = nn.ModuleList()
        for primitive in PRIMITIVES:  # ['none', 'skip_connect', 'sep_conv_3x3', ...]
            op = OPS[primitive](C, stride, affine=False)
            self._ops.append(op)

    def forward(self, x: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
        # weights = softmax(alpha) — архитектурные веса
        return sum(w * op(x) for w, op in zip(weights, self._ops))


class DARTSCell(nn.Module):
    def __init__(self, steps: int, multiplier: int, C_prev_prev: int,
                 C_prev: int, C: int, reduction: bool, reduction_prev: bool):
        super().__init__()
        self._steps = steps       # число промежуточных узлов (обычно 4)
        self._multiplier = multiplier  # сколько узлов конкатенируется на выходе
        # ... инициализация preprocess и mixed ops

    def forward(self, s0: torch.Tensor, s1: torch.Tensor,
                weights: torch.Tensor) -> torch.Tensor:
        states = [s0, s1]
        offset = 0
        for i in range(self._steps):
            s = sum(
                self._ops[offset + j](h, weights[offset + j])
                for j, h in enumerate(states)
            )
            offset += len(states)
            states.append(s)
        return torch.cat(states[-self._multiplier:], dim=1)

Дворівнева оптимізація DARTS – головна інженерна складність. Мережеві ваги w та архітектурні ваги α оптимізуються поперемінно:

def train_darts_step(model, architect, optimizer_w, optimizer_alpha,
                     train_queue, valid_queue, lr_w: float):
    """
    DARTS: чередование шагов оптимизации весов сети и архитектурных весов.
    """
    for step, (input_train, target_train) in enumerate(train_queue):
        # 1. Архитектурный шаг: обновляем alpha по валидационной потере
        input_valid, target_valid = next(iter(valid_queue))
        architect.step(
            input_train, target_train,
            input_valid, target_valid,
            lr=lr_w, optimizer=optimizer_w,
            unrolled=False  # True = second-order DARTS, в 2x дороже
        )

        # 2. Шаг весов: обновляем w по тренировочной потере
        optimizer_w.zero_grad()
        logits = model(input_train)
        loss = F.cross_entropy(logits, target_train)
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        optimizer_w.step()

Проблема колапсу операцій. У чистому DARTS skip-connection операції майже завжди "перемагають" - у них нульові параметри, вони добре навчаються на ранніх етапах, і архітектурні ваги α[skip] стійко ростуть. Результат: знайдена архітектура вироджується в майже skip-only мережу з поганим узагальненням. Рішення:

  • DARTS+: відсікання skip-connections з найбільшим α на фінальному етапі
  • P-DARTS: прогресивне збільшення глибини мережі під час пошуку
  • GDAS: Gumbel-softmax замість softmax для α – розріджений вибір операцій

Hardware-aware NAS на практиці

Для мобільного деплою (Android, CoreML) accuracy – не єдина метрика. Latency на цільовому залозі важливіший за FLOP-підрахунок, тому що різні операції виконуються по-різному на реальному залозі.

Once-for-All (MIT) — навчається одна супермережа, з якої без донавчання витягуються підмережі під будь-який hardware constraint:

from ofa.model_zoo import ofa_net

# Загружаем предобученную OFA суперсеть
ofa_network = ofa_net('ofa_mbv3_d234_e346_k357_w1.0', pretrained=True)

# Специализируем под конкретный device с latency constraint
from ofa.nas.efficiency_predictor import Latency_MBV3_MeasuredNet
efficiency_predictor = Latency_MBV3_MeasuredNet(
    'note10',   # Samsung Note10 — реальные замеры латентности
    ofa_network
)

# Evolutionary search: ищем подсеть с latency < 25ms и max accuracy
from ofa.nas.search_algorithm.evolution_finder import EvolutionFinder
finder = EvolutionFinder(
    efficiency_constraint=25,           # ms
    efficiency_predictor=efficiency_predictor,
    accuracy_predictor=accuracy_predictor,
    population_size=100,
    max_time_budget=500                  # эволюционных шагов
)
best_valids, best_info = finder.run_evolution_search()

У реальному проекті: NAS під MobileNetV3-space для завдання класифікації виробничого шлюбу (640 480, 12 класів). Цільова платформа - NVIDIA Jetson Nano (4GB RAM, 128 CUDA cores). Обмеження: latency < 30ms при batch=1. Ручна архітектура MobileNetV3-Large давала 28.4ms та accuracy 91.3%. OFA-пошук за 6 годин знайшов підмережа: 22.1ms, accuracy 92.7%. Без жодної ручної зміни архітектури.

Практичний стек і коли NAS виправданий

Сценарій Підхід Час пошуку Інструмент
Image classification, стандартний DARTS / PC-DARTS 1–2 дні (4× A100) nni (Microsoft) або automl (torchvision)
Edge deployment (мобайл, MCU) OFA / MNasNet-style 6–24 години Once-for-All, TuNAS
NLP / Transformer architecture NAS-BERT, AutoFormer 2-5 днів Hugging Face NAS toolkit
Кастомні операції, custom hardware Predictor-based NAS 1-3 дні + 100 eval BANANAS, NASBOWL

Microsoft NNI — найбільш зрілий open-source фреймворк для NAS. Підтримує DARTS, ENAS, Random NAS, SPOS із коробки. Інтеграція з PyTorch та TensorFlow.

import nni
from nni.nas.pytorch.darts import DartsTrainer
from nni.nas.pytorch.callbacks import LRSchedulerCallback

trainer = DartsTrainer(
    model=model,
    loss=nn.CrossEntropyLoss(),
    metrics=lambda output, target: accuracy(output, target, topk=(1,)),
    optimizer=optimizer,
    num_epochs=50,
    dataset_train=dataset_train,
    dataset_valid=dataset_valid,
    batch_size=64,
    log_frequency=10,
    callbacks=[LRSchedulerCallback(scheduler)]
)
trainer.fit()
# Получаем финальную архитектуру
export_result = trainer.export()

Коли NAS не потрібен. Якщо завдання стандартне і даних < 50k прикладів - візьміть передбачувану ResNet-50 або EfficientNet-B0 та fine-tune. NAS виправданий за: кастомних hardware-обмеження, нетипових вхідних даних (гіперспектральні знімки, специфічні модальності), необхідності кардинально зменшити модель без втрати якості.

Процес роботи

  1. Визначення search space - критично важливий крок: задаємо блоки, операції, діапазони каналів, максимальну глибину. Неправильний search space = поганий результат незалежно від алгоритму
  2. Вибір стратегії пошуку - DARTS для GPU-rich оточення, evolutionary для hardware-aware, predictor-based при обмеженому бюджеті оцінок
  3. Профілювання цільового заліза — реальні виміри latency/throughput для операцій із search space на production hardware
  4. Пошук та оцінка кандидатів — Weights & Biases для трекінгу, MLflow для зберігання знайдених архітектур
  5. Full training знайденої архітектури з нуля — ваги з фази пошуку не використовуються
  6. Validation на holdout set, профіль на production-залізі

Терміни: визначення search space та setup – 1 тиждень. Сам пошук – 1–5 днів обчислень. Full training кандидата + валідація - 1-2 тижні. Разом: 3-6 тижнів на повний NAS-цикл.