Training моделей распознавания лиц: ArcFace, CosFace, AdaFace
Распознавание лиц технически — задача metric learning: обучить embedding-пространство, где лица одного человека близко, а разных — далеко. Softmax-классификатор для этого не подходит — он не обобщается на новые identity, не представленные в train. ArcFace решает это через angular margin в пространстве embedding.
ArcFace loss — математика и реализация
ArcFace добавляет аддитивный угловой margin m к углу между embedding и соответствующим центром класса:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class ArcFaceLoss(nn.Module):
def __init__(
self,
embedding_size: int = 512,
num_classes: int = 10000,
margin: float = 0.5, # угловой margin в радианах (~28.6°)
scale: float = 64.0 # масштаб логитов
):
super().__init__()
self.margin = margin
self.scale = scale
# Обучаемые центры классов (нормализованные)
self.weight = nn.Parameter(
torch.FloatTensor(num_classes, embedding_size)
)
nn.init.xavier_uniform_(self.weight)
self.cos_m = math.cos(margin)
self.sin_m = math.sin(margin)
self.th = math.cos(math.pi - margin) # порог для численной стабильности
self.mm = math.sin(math.pi - margin) * margin
def forward(
self,
embeddings: torch.Tensor, # (B, embedding_size), L2-нормализованные
labels: torch.Tensor # (B,)
) -> torch.Tensor:
# L2-нормализация весов
W = F.normalize(self.weight, dim=1)
# cos(θ) = emb · W^T
cosine = F.linear(embeddings, W) # (B, num_classes)
sine = torch.sqrt(1.0 - cosine.pow(2).clamp(0, 1))
# cos(θ + m) = cos(θ)cos(m) - sin(θ)sin(m)
phi = cosine * self.cos_m - sine * self.sin_m
# Numerical stability: если θ > π - m, используем косинусный penalty
phi = torch.where(cosine > self.th, phi, cosine - self.mm)
# One-hot target mask
one_hot = torch.zeros_like(cosine)
one_hot.scatter_(1, labels.view(-1, 1), 1)
# Заменяем logit только для правильного класса
output = one_hot * phi + (1.0 - one_hot) * cosine
output *= self.scale
return F.cross_entropy(output, labels)
Backbone и embedding
InsightFace / ArcFace обычно использует ResNet-50/100 или IResNet. Для production на мобильных устройствах — MobileFaceNet:
import timm
def build_face_recognition_model(
backbone: str = 'resnet50', # 'resnet100', 'mobilenetv3_small'
embedding_size: int = 512,
pretrained: bool = True
) -> nn.Module:
class FaceEmbedder(nn.Module):
def __init__(self):
super().__init__()
self.backbone = timm.create_model(
backbone,
pretrained=pretrained,
num_classes=0, # убираем classifier head
global_pool='avg'
)
feat_dim = self.backbone.num_features
self.bn = nn.BatchNorm1d(feat_dim)
self.drop = nn.Dropout(p=0.4)
self.fc = nn.Linear(feat_dim, embedding_size, bias=False)
self.bn2 = nn.BatchNorm1d(embedding_size)
def forward(self, x: torch.Tensor) -> torch.Tensor:
feat = self.backbone(x)
feat = self.bn(feat)
feat = self.drop(feat)
emb = self.fc(feat)
emb = self.bn2(emb)
return F.normalize(emb, dim=1) # L2-нормализация
return FaceEmbedder()
Сборка идентичностей: open-set recognition
В production система встречает новых людей, не бывших в train. Используем cosine similarity threshold:
import numpy as np
from scipy.spatial.distance import cosine
class FaceRecognitionSystem:
def __init__(
self,
model: nn.Module,
threshold: float = 0.4 # cosine distance; подбирается по ROC
):
self.model = model.eval()
self.threshold = threshold
self.gallery: dict[str, np.ndarray] = {} # id → embedding
def enroll(self, person_id: str, face_image: torch.Tensor) -> None:
"""Регистрация нового лица в галерее"""
with torch.no_grad():
emb = self.model(face_image.unsqueeze(0))
self.gallery[person_id] = emb.cpu().numpy().squeeze()
def identify(
self,
face_image: torch.Tensor,
top_k: int = 1
) -> list[dict]:
"""Поиск по галерее — 1:N идентификация"""
with torch.no_grad():
query_emb = self.model(face_image.unsqueeze(0))
query_np = query_emb.cpu().numpy().squeeze()
distances = {
person_id: cosine(query_np, gallery_emb)
for person_id, gallery_emb in self.gallery.items()
}
sorted_matches = sorted(distances.items(), key=lambda x: x[1])
results = []
for person_id, dist in sorted_matches[:top_k]:
results.append({
'identity': person_id if dist < self.threshold else 'unknown',
'distance': float(dist),
'confidence': float(1 - dist)
})
return results
Evaluation: TAR@FAR
Метрика для face recognition — не accuracy, а TAR@FAR (True Accept Rate при заданном False Accept Rate):
| Метрика | Значение | Применение |
|---|---|---|
| TAR@FAR=0.1% | 98.5%+ | Телефонная разблокировка |
| TAR@FAR=0.01% | 95%+ | Физический доступ |
| TAR@FAR=0.001% | 90%+ | Криминалистика |
| 1:1 Verification AUC | > 0.998 | Верификация документов |
Сравнение методов loss
| Loss | LFW Acc | IJB-C TAR@FAR=0.1% | Сложность | Применение |
|---|---|---|---|---|
| Softmax | 98.8% | 91.3% | Низкая | Закрытое множество |
| CosFace | 99.3% | 94.1% | Низкая | Стандарт |
| ArcFace | 99.5% | 95.6% | Низкая | Стандарт |
| AdaFace | 99.6% | 96.8% | Средняя | Низкое качество фото |
| ElasticFace | 99.6% | 96.4% | Средняя | Общий случай |
Сроки
| Задача | Срок |
|---|---|
| Fine-tuning ArcFace на корпоративные данные | 3–5 недель |
| Полная система 1:N с галереей | 5–8 недель |
| Кастомный пайплайн (detection + alignment + recognition) | 8–14 недель |







