Гибкое управление экспериментами по детекции волейбольного мяча с использованием YAML-конфигураций
2025/08/19
Введение
В задачах компьютерного зрения, таких как детекция волейбольного мяча на видео, важна возможность быстро экспериментировать с различными архитектурами моделей, гиперпараметрами и настройками. В этой статье описан подход к организации проекта для детекции мяча с использованием модели TrackNetV4 и её модификаций (включая версии с GRU, графовыми нейронными сетями (GNN) и механизмами внимания). Основной акцент сделан на гибкое управление экспериментами через YAML-конфигурации, что обеспечивает воспроизводимость, масштабируемость и удобство сравнения результатов.
Задача: модель принимает 9 grayscale-кадров размером 512x288, возвращает 9 heatmap-предсказаний для каждого кадра, должна работать на CPU с частотой 60+ FPS в формате ONNX, а также использовать карты внимания для улучшения детекции движущегося мяча.
Структура проекта
Для упрощения экспериментов и поддержки кода проект организован следующим образом:
ball-detection/
├── models/ # Все модели (TrackNetV2, V4, GRU, GNN, Attention)
│ ├── __init__.py
│ ├── tracknet_v2.py
│ ├── tracknet_v4.py
│ ├── tracknet_v4_gru.py
│ ├── tracknet_v4_gnn.py
│ ├── tracknet_v4_attention.py
│ └── base_model.py # Базовый класс для всех моделей
├── configs/ # Конфигурации для экспериментов
│ ├── tracknet_v2.yaml
│ ├── tracknet_v4_gru.yaml
│ ├── tracknet_v4_attention.yaml
│ └── default.yaml
├── data/ # Датасет и предобработка
│ ├── dataset.py
│ ├── transforms.py
│ └── loader.py
├── experiments/ # Результаты тренировок (логи, чекпоинты)
│ ├── exp_001_tracknet_v2/
│ ├── exp_002_tracknet_v4_gru/
│ └── exp_003_attention/
├── scripts/
│ ├── train.py # Скрипт обучения
│ ├── eval.py # Оценка модели
│ ├── export_onnx.py # Экспорт в ONNX
│ └── infer.py # Инференс для тестирования FPS
├── utils/
│ ├── metrics.py # Метрики (mAP, PCK, heatmap loss)
│ ├── visualization.py
│ ├── onnx_exporter.py
│ └── fps_benchmark.py # Тестирование скорости на CPU
├── notebooks/ # Анализ данных и визуализация
├── logs/ # Логи (TensorBoard, WandB)
├── .gitignore
├── requirements.txt
└── README.md
Управление версиями моделей
Единый интерфейс для моделей
Для упрощения работы с разными архитектурами создан базовый класс BaseBallModel, от которого наследуются все модели:
# models/base_model.py
import torch.nn as nn
class BaseBallModel(nn.Module):
def __init__(self, in_frames=9, img_h=288, img_w=512):
super().__init__()
self.in_frames = in_frames
self.img_h = img_h
self.img_w = img_w
def forward(self, x):
raise NotImplementedError
def get_config(self):
return {
"model_type": self.__class__.__name__,
"in_frames": self.in_frames,
"img_h": self.img_h,
"img_w": self.img_w
}
Пример реализации модели с механизмом внимания:
# models/base_model.py
import torch.nn as nn
class BaseBallModel(nn.Module):
def __init__(self, in_frames=9, img_h=288, img_w=512):
super().__init__()
self.in_frames = in_frames
self.img_h = img_h
self.img_w = img_w
def forward(self, x):
raise NotImplementedError
def get_config(self):
return {
"model_type": self.__class__.__name__,
"in_frames": self.in_frames,
"img_h": self.img_h,
"img_w": self.img_w
} `
Динамическая загрузка моделей
Для удобства переключения между моделями используется модуль get_model:
# models/__init__.py
from .tracknet_v2 import TrackNetV2
from .tracknet_v4_attention import TrackNetV4Attention
from .tracknet_v4_gru import TrackNetV4GRU
from .tracknet_v4_gnn import TrackNetV4GNN
def get_model(name, **kwargs):
available_models = {
"TrackNetV2": TrackNetV2,
"TrackNetV4Attention": TrackNetV4Attention,
"TrackNetV4GRU": TrackNetV4GRU,
"TrackNetV4GNN": TrackNetV4GNN,
}
if name not in available_models:
raise ValueError(f"Model {name} not found")
return available_models[name](**kwargs)
YAML-конфигурации
Каждый эксперимент описывается в отдельном YAML-файле. Пример tracknetv4attention.yaml:
model:
type: TrackNetV4Attention
in_frames: 9
img_h: 288
img_w: 512
train:
epochs: 100
batch_size: 8
lr: 1e-4
optimizer: Adam
scheduler: StepLR
data:
dataset_path: /data/volleyball/
augment: true
eval:
iou_threshold: 0.5
onnx:
opset: 13
dynamic_axes: false
Кто использует YAML?
- train.py — читает параметры модели, оптимизатора, датасета и гиперпараметров для обучения.
- eval.py — использует конфигурацию для воспроизведения условий валидации (например, iou_threshold).
- export_onnx.py — загружает модель с параметрами из YAML перед экспортом в ONNX.
- infer.py — применяет размеры входных данных и настройки предобработки.
- fps_benchmark.py — использует размеры входа для тестирования производительности.
Пример запуска:
python scripts/train.py --config configs/tracknet_v4_attention.yaml
python scripts/eval.py --config configs/tracknet_v4_attention.yaml --ckpt experiments/exp003/best.pth
python scripts/export_onnx.py --config configs/tracknet_v4_attention.yaml --ckpt best.pth --output model.onnx
Обучение
Скрипт train.py обеспечивает единый пайплайн для всех моделей:
# scripts/train.py
import yaml
import torch
from models import get_model
from data import BallDataset
from utils import AverageMeter, save_checkpoint
def train(config_path):
with open(config_path) as f:
config = yaml.safe_load(f)
model_cls = get_model(config['model']['type'])
model = model_cls(**config['model'])
dataloader = ...
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=config['train']['lr'])
for epoch in range(config['train']['epochs']):
model.train()
for batch in dataloader:
x, y = batch
pred = model(x)
loss = criterion(pred, y)
loss.backward()
optimizer.step()
save_checkpoint(model, epoch, config)
Сохранение чекпоинтов включает конфигурацию для воспроизводимости:
torch.save({
'state_dict': model.state_dict(),
'config': config,
'epoch': epoch,
'loss': loss
}, 'checkpoint.pth')
Экспорт в ONNX и оптимизация
Для достижения 60+ FPS на CPU модель экспортируется в ONNX:
# scripts/export_onnx.py
import torch
from models import get_model
def export_model(config_path, checkpoint_path, output_path):
with open(config_path) as f:
config = yaml.safe_load(f)
model = get_model(config['model']['type'])(**config['model'])
state = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(state)
model.eval()
dummy_input = torch.randn(1, 9, 288, 512)
torch.onnx.export(
model,
dummy_input,
output_path,
opset_version=13,
input_names=["input"],
output_names=["heatmaps"],
dynamic_axes=None
)
Для оптимизации производительности:
- Используйте ONNX Runtime с максимальным уровнем оптимизации (--opt_level 99).
- Рассмотрите квантование (int8).
- Экспериментируйте с лёгкими энкодерами (например, MobileNet вместо VGG).
Сравнение моделей
Результаты экспериментов сохраняются в results.csv:
| exp_id | model | params(M) | FPS(CPU) | mAP@0.5 | PCK@5px | size(MB) | notes |
|---|---|---|---|---|---|---|---|
| 001 | TrackNetV2 | 1.8 | 72 | 0.85 | 0.91 | 7.2 | Базовая модель |
| 002 | V4 + GRU | 3.1 | 48 | 0.89 | 0.93 | 12.4 | Лучший трекинг |
| 003 | V4 + Attention | 2.2 | 65 | 0.91 | 0.94 | 8.8 | Лучший баланс |
Тестирование FPS
# utils/fps_benchmark.py
import torch
import time
def benchmark_fps(model, device="cpu", seq_len=9, H=288, W=512, warmup=10, runs=100):
model.to(device)
model.eval()
x = torch.randn(1, seq_len, H, W).to(device)
for _ in range(warmup):
with torch.no_grad():
_ = model(x)
start = time.time()
for _ in range(runs):
with torch.no_grad():
_ = model(x)
end = time.time()
return runs / (end - start)
Рекомендации по архитектурам
- Attention: Лёгкие механизмы (Channel/Spatial Attention) встраиваются в skip-connections TrackNet, улучшая детекцию движущихся объектов.
- GRU: Полезен для обработки временных последовательностей, но увеличивает число параметров и снижает FPS.
- GNN: Подходит для сложных сценариев с перекрытиями, но менее эффективен для достижения 60 FPS на CPU.