AI-система анализа медицинских изображений
Медицинский CV — область с высокими требованиями: ошибка модели имеет клинические последствия. Это означает не просто высокую accuracy, но и калиброванную уверенность (модель должна знать, когда она не уверена), интерпретируемость (Grad-CAM, SHAP), соответствие регуляторным требованиям (MDR, FDA 510(k)), и обязательный human-in-the-loop для всех высоко-рисковых решений.
Architecture медицинской CV-системы
Медицинские снимки — разные модальности с разными пайплайнами:
- Рентген (X-ray): PNG/DICOM, 2D, патологии как изменения плотности
- КТ: DICOM 3D volume, HU-единицы, требует 3D-модели или 2.5D slice ensemble
- МРТ: несколько последовательностей (T1, T2, FLAIR), мультиканальный вход
- Гистология: WSI (Whole Slide Image), гигапиксельные изображения
- Дерматоскопия: обычные RGB, но специфические паттерны (ABCDE)
Предобработка DICOM
import pydicom
import numpy as np
import cv2
def dicom_to_array(
dcm_path: str,
target_modality: str = 'xray', # 'xray' | 'ct' | 'mri'
window_center: float = None,
window_width: float = None
) -> np.ndarray:
"""
Нормализация DICOM в диапазон [0, 255] uint8.
Для КТ обязательно windowing по HU.
"""
dcm = pydicom.dcmread(dcm_path)
array = dcm.pixel_array.astype(np.float32)
# Применяем линейное преобразование пикселей
slope = float(getattr(dcm, 'RescaleSlope', 1))
intercept = float(getattr(dcm, 'RescaleIntercept', 0))
array = array * slope + intercept
if target_modality == 'ct':
# Windowing по HU: лёгкие [-1000, 200], кости [-500, 1000]
wc = window_center or float(getattr(dcm, 'WindowCenter', -600))
ww = window_width or float(getattr(dcm, 'WindowWidth', 1500))
lower = wc - ww / 2
upper = wc + ww / 2
array = np.clip(array, lower, upper)
elif target_modality == 'xray':
# Для рентгена — нормализация по перцентилям (убираем артефакты)
p1, p99 = np.percentile(array, [1, 99])
array = np.clip(array, p1, p99)
# Нормализация в [0, 255]
arr_min, arr_max = array.min(), array.max()
if arr_max > arr_min:
array = (array - arr_min) / (arr_max - arr_min) * 255
return array.astype(np.uint8)
Детекция патологий на рентгене: CheXNet-подход
import torch
import torch.nn as nn
import timm
from torch.cuda.amp import autocast
# 14 классов патологий рентгена груди (CheXNet / CheXpert)
PATHOLOGY_CLASSES = [
'Atelectasis', 'Cardiomegaly', 'Consolidation', 'Edema',
'Enlarged Cardiomediastinum', 'Fracture', 'Lung Lesion',
'Lung Opacity', 'No Finding', 'Pleural Effusion',
'Pleural Other', 'Pneumonia', 'Pneumothorax', 'Support Devices'
]
class ChestXRayClassifier(nn.Module):
def __init__(
self,
backbone: str = 'densenet121',
num_classes: int = 14,
pretrained: bool = True
):
super().__init__()
self.backbone = timm.create_model(
backbone,
pretrained=pretrained,
num_classes=0, # убираем голову
global_pool='avg'
)
feat_dim = self.backbone.num_features
self.classifier = nn.Sequential(
nn.Linear(feat_dim, 512),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(512, num_classes)
# Sigmoid применяем отдельно — multi-label задача
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
features = self.backbone(x)
return self.classifier(features)
# Loss для multi-label classification с дисбалансом
class WeightedBCEWithLogitsLoss(nn.Module):
def __init__(self, pos_weights: torch.Tensor):
"""
pos_weights[i] = n_neg[i] / n_pos[i] для класса i.
CheXpert: типичный дисбаланс 15:1 - 100:1 для редких патологий.
"""
super().__init__()
self.loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weights)
def forward(self, logits, targets):
return self.loss_fn(logits, targets)
Grad-CAM для объяснимости
В медицине интерпретируемость обязательна — врач должен видеть, на каком регионе модель основывала решение:
import torch
import numpy as np
import cv2
class GradCAM:
def __init__(self, model: nn.Module, target_layer: nn.Module):
self.model = model
self.gradients = None
self.activations = None
target_layer.register_forward_hook(
lambda m, i, o: setattr(self, 'activations', o)
)
target_layer.register_backward_hook(
lambda m, gi, go: setattr(self, 'gradients', go[0])
)
def generate(
self,
image_tensor: torch.Tensor, # (1, C, H, W)
target_class: int,
original_size: tuple
) -> np.ndarray:
self.model.eval()
output = self.model(image_tensor)
self.model.zero_grad()
output[0, target_class].backward()
# Взвешивание активаций по градиентам
weights = self.gradients.mean(dim=[2, 3], keepdim=True)
cam = (weights * self.activations).sum(dim=1, keepdim=True)
cam = torch.relu(cam).squeeze().cpu().numpy()
# Нормализация и resize
cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
cam = cv2.resize(cam, (original_size[1], original_size[0]))
return cam
Metrics для медицинской классификации
| Метрика | Использование | Почему не accuracy |
|---|---|---|
| AUC-ROC | Основная метрика | Устойчива к дисбалансу |
| Sensitivity (Recall) | Критична для скрининга | Пропустить болезнь — хуже |
| Specificity | Баланс с sensitivity | Ложные тревоги — нагрузка |
| F1 (micro/macro) | Multi-label задачи | Баланс P/R |
| Calibration (ECE) | Уверенность модели | Для клинического доверия |
Сроки
| Задача | Срок |
|---|---|
| Классификатор патологий рентгена (fine-tuning) | 4–6 недель |
| Детекция/сегментация на КТ/МРТ | 8–14 недель |
| Медицинская система с CE/FDA-документацией | 20–40 недель |







