Обучение модели классификации текста (BERT, RoBERTa, DeBERTa)
Fine-tuning предобученных трансформеров — стандартный путь к высококачественному классификатору текста. BERT, RoBERTa, DeBERTa — три поколения, каждое лучше предыдущего по ряду параметров.
Выбор базовой модели
BERT (bert-base-uncased, DeepPavlov/rubert-base-cased): классика, хорошо изучена, много туториалов. Для большинства задач достаточно.
RoBERTa (roberta-base, ai-forever/ruRoBERTa-large): улучшенное обучение без Next Sentence Prediction, на большем корпусе. Обычно на 1–3% лучше BERT.
DeBERTa (microsoft/deberta-v3-base): диcентанглированное внимание — лучшее качество на benchmarks. Рекомендуется если нужна максимальная точность и есть GPU-ресурсы.
Для русского языка: ai-forever/ruBert-base, DeepPavlov/rubert-base-cased, ai-forever/ruRoBERTa-large, ai-forever/sber-roberta-large.
Pipeline обучения
from transformers import (
AutoTokenizer, AutoModelForSequenceClassification,
TrainingArguments, Trainer
)
from datasets import Dataset
import evaluate
import numpy as np
# Подготовка данных
tokenizer = AutoTokenizer.from_pretrained("DeepPavlov/rubert-base-cased")
def tokenize_function(examples):
return tokenizer(
examples["text"],
padding="max_length",
truncation=True,
max_length=256 # 256 достаточно для большинства задач
)
dataset = Dataset.from_pandas(df)
tokenized = dataset.map(tokenize_function, batched=True)
tokenized = tokenized.train_test_split(test_size=0.2)
# Инициализация модели
model = AutoModelForSequenceClassification.from_pretrained(
"DeepPavlov/rubert-base-cased",
num_labels=num_classes,
id2label=id2label,
label2id=label2id
)
# Метрики
accuracy = evaluate.load("accuracy")
f1 = evaluate.load("f1")
def compute_metrics(eval_pred):
logits, labels = eval_pred
predictions = np.argmax(logits, axis=-1)
return {
"accuracy": accuracy.compute(predictions=predictions, references=labels)["accuracy"],
"f1_macro": f1.compute(predictions=predictions, references=labels, average="macro")["f1"],
}
# Параметры обучения
training_args = TrainingArguments(
output_dir="./results",
num_train_epochs=5,
per_device_train_batch_size=16,
per_device_eval_batch_size=32,
learning_rate=2e-5,
weight_decay=0.01,
warmup_ratio=0.1,
evaluation_strategy="epoch",
save_strategy="epoch",
load_best_model_at_end=True,
metric_for_best_model="f1_macro",
fp16=True, # mixed precision для GPU
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized["train"],
eval_dataset=tokenized["test"],
compute_metrics=compute_metrics,
)
trainer.train()
Гиперпараметры и их влияние
| Параметр | Рекомендуемый диапазон | Влияние |
|---|---|---|
| learning_rate | 1e-5 – 5e-5 | Самый критичный. 2e-5 — хороший старт |
| num_epochs | 3–10 | Переобучение при > 10 |
| batch_size | 8–32 | Больше = стабильнее, но нужно больше VRAM |
| max_length | 64–512 | Зависит от длины текстов |
| warmup_ratio | 0.06–0.1 | Предотвращает нестабильное начало |
Обработка несбалансированных классов
from torch import nn
import torch
# Вычисляем веса классов
class_weights = compute_class_weight("balanced", classes=np.unique(labels), y=labels)
weights_tensor = torch.FloatTensor(class_weights).to(device)
class WeightedTrainer(Trainer):
def compute_loss(self, model, inputs, return_outputs=False):
labels = inputs.pop("labels")
outputs = model(**inputs)
logits = outputs.get("logits")
loss_fn = nn.CrossEntropyLoss(weight=weights_tensor)
loss = loss_fn(logits, labels)
return (loss, outputs) if return_outputs else loss
Оценка и анализ ошибок
После обучения обязательно:
- Confusion matrix по всем классам
- Примеры ошибок для каждой пары (истинный класс, предсказанный класс)
- Calibration plot: насколько достоверны вероятности модели
- Error analysis: есть ли паттерн в ошибках? (определённые слова, длина текста, авторский стиль)
Оптимизация для продакшена
После fine-tuning экспортировать в ONNX:
from optimum.onnxruntime import ORTModelForSequenceClassification
ort_model = ORTModelForSequenceClassification.from_pretrained("./results", export=True)
ort_model.save_pretrained("./onnx_model")
Benchmark: ruBERT fine-tuned → ONNX INT8: 120ms → 18ms на CPU при точности -0.3%.
Типичные результаты
На задачах классификации новостей: 92–96% F1 macro. Классификация обращений клиентов: 88–94%. Мультиметочная классификация: 78–86% Micro F1.







