Setting up dynamic batching for LLM
Dynamic batching combines multiple concurrent queries into a single forward pass on the GPU. This is a key mechanism for high LLM throughput: GPUs are massively parallel and execute matrix multiplications far more efficiently at large batch sizes.
Why batching is critical for LLMs
An A100 80GB GPU running Llama-3-8B: at batch=1, ~30 tokens/sec; at batch=16, ~300 tokens/sec (10x); at batch=64, ~900 tokens/sec (30x). The scaling is not linear (per-token overhead grows), but the gain is substantial.
Without batching, 100 concurrent users are served sequentially, so tail latency reaches hundreds of seconds. With continuous batching, all requests progress in parallel and latency drops to seconds.
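A back-of-the-envelope sketch of these numbers (throughput figures from above; the average request size is an assumption):

```python
import math

TOKENS_PER_REQUEST = 256   # assumed average generation length
USERS = 100

# Sequential serving (batch=1, ~30 tok/s): the last user waits for all others.
seq_latency_s = USERS * TOKENS_PER_REQUEST / 30                # ≈ 853 s

# Continuous batching (batch=64, ~900 tok/s aggregate): requests are served
# in ceil(100 / 64) = 2 waves; each wave takes 64 * 256 / 900 ≈ 18 s.
waves = math.ceil(USERS / 64)
batched_latency_s = waves * 64 * TOKENS_PER_REQUEST / 900      # ≈ 36 s

print(round(seq_latency_s), round(batched_latency_s))
```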
Continuous (In-flight) Batching in vLLM
vLLM implements continuous batching automatically:
# Key batching parameters:
#   --max-num-seqs            maximum number of concurrently batched sequences
#   --max-num-batched-tokens  tokens processed in one forward pass
#   --scheduler-delay-factor  delay before scheduling, for better batch grouping
#   --use-v2-block-manager    improved memory manager
#   --enable-chunked-prefill  split long prefill requests into chunks
python -m vllm.entrypoints.openai.api_server \
    --model meta-llama/Meta-Llama-3-8B-Instruct \
    --max-num-seqs 256 \
    --max-num-batched-tokens 32768 \
    --scheduler-delay-factor 0.5 \
    --use-v2-block-manager \
    --enable-chunked-prefill
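A quick sanity check of how the two caps interact (a sketch using the values above): in steady-state decode every running sequence contributes exactly one token per step, so most of the token budget is headroom for prefill chunks.

```python
MAX_NUM_SEQS = 256          # --max-num-seqs
MAX_BATCHED_TOKENS = 32768  # --max-num-batched-tokens

# During pure decode, each running sequence adds exactly one token per step:
decode_tokens_per_step = MAX_NUM_SEQS                            # 256

# The remainder of the per-step budget can absorb incoming prefills:
prefill_headroom = MAX_BATCHED_TOKENS - decode_tokens_per_step

print(prefill_headroom)  # 32512
```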
Chunked prefill: a long prefill (e.g. a 4K-token system prompt) is split into chunks so that it does not block decoding of other requests:
--enable-chunked-prefill
--max-num-batched-tokens 8192  # max tokens per chunk
Setting up dynamic batching in TensorRT-LLM / Triton
# tensorrt_llm/config.pbtxt
parameters {
  key: "max_tokens_in_paged_kv_cache"
  value: { string_value: "40000" }  # total KV cache across all sequences
}
parameters {
  key: "batch_scheduler_policy"
  value: { string_value: "guaranteed_no_evict" }  # never evict in-flight requests
}
parameters {
  key: "executor_static_batch_size"
  value: { string_value: "-1" }  # -1 = dynamic batch size
}
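To choose a value for max_tokens_in_paged_kv_cache, it helps to know what one token costs in KV-cache memory. A rough estimate, assuming Llama-3-8B dimensions (32 layers, 8 KV heads via GQA, head_dim 128, fp16):

```python
LAYERS, KV_HEADS, HEAD_DIM, BYTES_FP16 = 32, 8, 128, 2

# K and V tensors per token, summed across all layers:
kv_bytes_per_token = 2 * LAYERS * KV_HEADS * HEAD_DIM * BYTES_FP16  # 131072 B = 128 KiB

budget_tokens = 40_000
total_gib = budget_tokens * kv_bytes_per_token / 2**30

print(kv_bytes_per_token, round(total_gib, 2))  # 131072 4.88
```

So the 40000-token budget above corresponds to roughly 5 GiB of GPU memory for this model.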
Manual implementation of batching
If you run your own inference server:
import asyncio
import time
from collections import deque
from dataclasses import dataclass

@dataclass
class PendingRequest:
    id: str
    prompt: str
    max_tokens: int
    future: asyncio.Future
    enqueued_at: float

class DynamicBatchInferenceServer:
    def __init__(
        self,
        model,
        max_batch_size: int = 64,
        max_wait_ms: float = 20.0,        # wait up to 20 ms to accumulate a batch
        max_tokens_per_batch: int = 16384,
    ):
        self.model = model
        self.max_batch_size = max_batch_size
        self.max_wait_ms = max_wait_ms
        self.max_tokens_per_batch = max_tokens_per_batch
        self.queue: deque[PendingRequest] = deque()
        self.lock = asyncio.Lock()
        self._batch_worker_task = None

    async def start(self):
        self._batch_worker_task = asyncio.create_task(self._batch_worker())

    async def predict(self, prompt: str, max_tokens: int = 512) -> str:
        future = asyncio.get_running_loop().create_future()
        request = PendingRequest(
            id=str(time.time()),
            prompt=prompt,
            max_tokens=max_tokens,
            future=future,
            enqueued_at=time.time(),
        )
        async with self.lock:
            self.queue.append(request)
        return await future

    async def _batch_worker(self):
        while True:
            await asyncio.sleep(self.max_wait_ms / 1000)

            # Form a batch under the lock
            async with self.lock:
                batch: list[PendingRequest] = []
                total_tokens = 0
                while self.queue and len(batch) < self.max_batch_size:
                    req = self.queue[0]
                    # Rough token estimate: whitespace-split prompt + generation budget
                    req_tokens = len(req.prompt.split()) + req.max_tokens
                    # Always admit at least one request so oversized ones don't starve
                    if batch and total_tokens + req_tokens > self.max_tokens_per_batch:
                        break
                    self.queue.popleft()
                    batch.append(req)
                    total_tokens += req_tokens
            if not batch:
                continue

            # Run inference outside the lock so new requests can keep enqueueing
            prompts = [req.prompt for req in batch]
            max_tokens_list = [req.max_tokens for req in batch]
            try:
                # generate_batch is assumed synchronous; offload it to a thread
                # so the event loop stays responsive
                outputs = await asyncio.to_thread(
                    self.model.generate_batch, prompts, max(max_tokens_list)
                )
                for req, output in zip(batch, outputs):
                    if not req.future.done():
                        req.future.set_result(output)
            except Exception as e:
                for req in batch:
                    if not req.future.done():
                        req.future.set_exception(e)
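The Future-based handshake at the heart of this server can be exercised standalone. A minimal, self-contained sketch (the worker loop and the uppercase "model" are toy stand-ins for the class above):

```python
import asyncio

async def batch_worker(queue: asyncio.Queue):
    while True:
        # Block for the first request, then drain whatever else has queued up
        reqs = [await queue.get()]
        while not queue.empty():
            reqs.append(queue.get_nowait())
        # Stand-in for model.generate_batch: one "forward pass" for the whole batch
        outputs = [prompt.upper() for prompt, _ in reqs]
        for (_, fut), out in zip(reqs, outputs):
            fut.set_result(out)

async def predict(queue: asyncio.Queue, prompt: str) -> str:
    fut = asyncio.get_running_loop().create_future()
    await queue.put((prompt, fut))
    return await fut

async def main():
    queue = asyncio.Queue()
    worker = asyncio.create_task(batch_worker(queue))
    results = await asyncio.gather(
        predict(queue, "hello"),
        predict(queue, "batch"),
    )
    worker.cancel()
    return results

results = asyncio.run(main())
print(results)  # ['HELLO', 'BATCH']
```

Both callers await their own Future while the worker resolves them from a single batched step; that decoupling is what lets one forward pass serve many clients.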
Monitoring batching
vllm:num_requests_running                  # requests in the active batch
vllm:num_requests_waiting                  # requests waiting in the queue
vllm:avg_prompt_throughput_toks_per_s      # prefill tokens/s
vllm:avg_generation_throughput_toks_per_s  # decode tokens/s
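These metrics are exported in Prometheus format, so simple alerts can be sketched in PromQL (thresholds here are illustrative, not recommendations):

```
# Sustained queueing means the batch is saturated: scale out or raise limits
avg_over_time(vllm:num_requests_waiting[5m]) > 10

# Decode throughput per running request: drops when batches grow too large
vllm:avg_generation_throughput_toks_per_s / vllm:num_requests_running
```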
The optimal batch limits for a given GPU are found through benchmarking: run a load test with varying numbers of concurrent users and measure the throughput vs. latency trade-off.