Налаштування динамічного батчингу для LLM
Dynamic batching - об'єднання кількох паралельних запитів в один forward pass через GPU. Ключовий механізм для високого через LLM: GPU паралельний і обробляє матричні множення ефективніше для великих батчів.
Чому батчинг критичний для LLM
GPU A100 80GB у batch=1: ~30 tokens/sec для Llama-3-8B. При batch=16: ~300 tokens/sec (10x). При batch=64: ~900 tokens/sec (30x). Лінійного зростання немає (накладні витрати), але приріст значний.
Без batching у 100 concurrent користувачів: кожен запит обробляється послідовно → сотні секунд очікування. З continuous batching: всі запити обробляються паралельно → секунди.
Continuous (In-flight) Batching в vLLM
vLLM реалізує continuous batching автоматично:
# Ключевые параметры batching
python -m vllm.entrypoints.openai.api_server \
--model meta-llama/Llama-3-8b-instruct \
--max-num-seqs 256 \ # максимальный concurrent batch
--max-num-batched-tokens 32768 \ # токены в одном forward pass
--scheduler-delay-factor 0.5 \ # задержка перед scheduling (для лучшей группировки)
--use-v2-block-manager \ # улучшенный менеджер памяти
--enable-chunked-prefill # чанкинг длинных prefill запросов
Chunked prefill: довгий prefill (системний промпт 4K токенів) розбивається на чанки, що не блокує decode інших запитів:
--enable-chunked-prefill
--max-num-batched-tokens 8192 # max токенов в чанке
Налаштування динамічного батчингу в TensorRT-LLM / Triton
# tensorrt_llm/config.pbtxt
parameters {
key: "max_tokens_in_paged_kv_cache"
value: { string_value: "40000" } # суммарный KV-кеш для всех sequences
}
parameters {
key: "batch_scheduler_policy"
value: { string_value: "guaranteed_no_evict" } # не вытесняем начатые запросы
}
parameters {
key: "executor_static_batch_size"
value: { string_value: "-1" } # -1 = dynamic batch
}
Ручна реалізація батчингу
Якщо використовується власний inference server:
import asyncio
from dataclasses import dataclass
from collections import deque
import time
@dataclass
class PendingRequest:
id: str
prompt: str
max_tokens: int
future: asyncio.Future
enqueued_at: float
class DynamicBatchInferenceServer:
def __init__(
self,
model,
max_batch_size: int = 64,
max_wait_ms: float = 20.0, # ждём 20ms для набора батча
max_tokens_per_batch: int = 16384
):
self.model = model
self.max_batch_size = max_batch_size
self.max_wait_ms = max_wait_ms
self.max_tokens_per_batch = max_tokens_per_batch
self.queue: deque[PendingRequest] = deque()
self.lock = asyncio.Lock()
self._batch_worker_task = None
async def start(self):
self._batch_worker_task = asyncio.create_task(self._batch_worker())
async def predict(self, prompt: str, max_tokens: int = 512) -> str:
future = asyncio.get_event_loop().create_future()
request = PendingRequest(
id=str(time.time()),
prompt=prompt,
max_tokens=max_tokens,
future=future,
enqueued_at=time.time()
)
async with self.lock:
self.queue.append(request)
return await future
async def _batch_worker(self):
while True:
await asyncio.sleep(self.max_wait_ms / 1000)
async with self.lock:
if not self.queue:
continue
# Формируем батч
batch: list[PendingRequest] = []
total_tokens = 0
while (self.queue
and len(batch) < self.max_batch_size
and total_tokens + self.queue[0].max_tokens <= self.max_tokens_per_batch):
req = self.queue.popleft()
batch.append(req)
total_tokens += len(req.prompt.split()) + req.max_tokens
if not batch:
continue
# Инференс батча
prompts = [req.prompt for req in batch]
max_tokens_list = [req.max_tokens for req in batch]
try:
outputs = self.model.generate_batch(prompts, max(max_tokens_list))
for req, output in zip(batch, outputs):
if not req.future.done():
req.future.set_result(output)
except Exception as e:
for req in batch:
if not req.future.done():
req.future.set_exception(e)
Моніторинг батчингу
vllm:num_requests_running # запросов в активном батче
vllm:num_requests_waiting # запросов в очереди
vllm:avg_prompt_throughput_toks_per_s # tokens/s для prefill
vllm:avg_generation_throughput_toks_per_s # tokens/s для decode
Оптимальний batch size для конкретного GPU визначається через бенчмаркінг: запускаємо тест навантаження з різним числом concurrent користувачів і вимірюємо черезput vs latency trade-off.







