Гибкое управление экспериментами по детекции волейбольного мяча с использованием YAML-конфигураций

Гибкое управление экспериментами по детекции волейбольного мяча с использованием YAML-конфигураций

2025/08/19

Введение

В задачах компьютерного зрения, таких как детекция волейбольного мяча на видео, важна возможность быстро экспериментировать с различными архитектурами моделей, гиперпараметрами и настройками. В этой статье описан подход к организации проекта для детекции мяча с использованием модели TrackNetV4 и её модификаций (включая версии с GRU, графовыми нейронными сетями (GNN) и механизмами внимания). Основной акцент сделан на гибкое управление экспериментами через YAML-конфигурации, что обеспечивает воспроизводимость, масштабируемость и удобство сравнения результатов.

Задача: модель принимает 9 grayscale-кадров размером 512x288, возвращает 9 heatmap-предсказаний для каждого кадра, должна работать на CPU с частотой 60+ FPS в формате ONNX, а также использовать карты внимания для улучшения детекции движущегося мяча.

Структура проекта

Для упрощения экспериментов и поддержки кода проект организован следующим образом:

ball-detection/
├── models/                     # Все модели (TrackNetV2, V4, GRU, GNN, Attention)
│   ├── __init__.py
│   ├── tracknet_v2.py
│   ├── tracknet_v4.py
│   ├── tracknet_v4_gru.py
│   ├── tracknet_v4_gnn.py
│   ├── tracknet_v4_attention.py
│   └── base_model.py           # Базовый класс для всех моделей
├── configs/                    # Конфигурации для экспериментов
│   ├── tracknet_v2.yaml
│   ├── tracknet_v4_gru.yaml
│   ├── tracknet_v4_attention.yaml
│   └── default.yaml
├── data/                       # Датасет и предобработка
│   ├── dataset.py
│   ├── transforms.py
│   └── loader.py
├── experiments/                # Результаты тренировок (логи, чекпоинты)
│   ├── exp_001_tracknet_v2/
│   ├── exp_002_tracknet_v4_gru/
│   └── exp_003_attention/
├── scripts/
│   ├── train.py                # Скрипт обучения
│   ├── eval.py                 # Оценка модели
│   ├── export_onnx.py          # Экспорт в ONNX
│   └── infer.py                # Инференс для тестирования FPS
├── utils/
│   ├── metrics.py              # Метрики (mAP, PCK, heatmap loss)
│   ├── visualization.py
│   ├── onnx_exporter.py
│   └── fps_benchmark.py        # Тестирование скорости на CPU
├── notebooks/                  # Анализ данных и визуализация
├── logs/                       # Логи (TensorBoard, WandB)
├── .gitignore
├── requirements.txt
└── README.md

Управление версиями моделей

Единый интерфейс для моделей

Для упрощения работы с разными архитектурами создан базовый класс BaseBallModel, от которого наследуются все модели:


# models/base_model.py
import torch.nn as nn

class BaseBallModel(nn.Module):
    def __init__(self, in_frames=9, img_h=288, img_w=512):
        super().__init__()
        self.in_frames = in_frames
        self.img_h = img_h
        self.img_w = img_w

    def forward(self, x):
        raise NotImplementedError

    def get_config(self):
        return {
            "model_type": self.__class__.__name__,
            "in_frames": self.in_frames,
            "img_h": self.img_h,
            "img_w": self.img_w
        }

Пример реализации модели с механизмом внимания:

 # models/base_model.py
import torch.nn as nn

class BaseBallModel(nn.Module):
    def __init__(self, in_frames=9, img_h=288, img_w=512):
        super().__init__()
        self.in_frames = in_frames
        self.img_h = img_h
        self.img_w = img_w

    def forward(self, x):
        raise NotImplementedError

    def get_config(self):
        return {
            "model_type": self.__class__.__name__,
            "in_frames": self.in_frames,
            "img_h": self.img_h,
            "img_w": self.img_w
        }     `

Динамическая загрузка моделей

Для удобства переключения между моделями используется модуль get_model:

# models/__init__.py
from .tracknet_v2 import TrackNetV2
from .tracknet_v4_attention import TrackNetV4Attention
from .tracknet_v4_gru import TrackNetV4GRU
from .tracknet_v4_gnn import TrackNetV4GNN

def get_model(name, **kwargs):
    available_models = {
        "TrackNetV2": TrackNetV2,
        "TrackNetV4Attention": TrackNetV4Attention,
        "TrackNetV4GRU": TrackNetV4GRU,
        "TrackNetV4GNN": TrackNetV4GNN,
    }
    if name not in available_models:
        raise ValueError(f"Model {name} not found")
    return available_models[name](**kwargs)

YAML-конфигурации

Каждый эксперимент описывается в отдельном YAML-файле. Пример tracknetv4attention.yaml:

model:
  type: TrackNetV4Attention
  in_frames: 9
  img_h: 288
  img_w: 512

train:
  epochs: 100
  batch_size: 8
  lr: 1e-4
  optimizer: Adam
  scheduler: StepLR

data:
  dataset_path: /data/volleyball/
  augment: true

eval:
  iou_threshold: 0.5

onnx:
  opset: 13
  dynamic_axes: false

Кто использует YAML?

  1. train.py — читает параметры модели, оптимизатора, датасета и гиперпараметров для обучения.
  2. eval.py — использует конфигурацию для воспроизведения условий валидации (например, iou_threshold).
  3. export_onnx.py — загружает модель с параметрами из YAML перед экспортом в ONNX.
  4. infer.py — применяет размеры входных данных и настройки предобработки.
  5. fps_benchmark.py — использует размеры входа для тестирования производительности.

Пример запуска:

python scripts/train.py --config configs/tracknet_v4_attention.yaml
python scripts/eval.py --config configs/tracknet_v4_attention.yaml --ckpt experiments/exp003/best.pth
python scripts/export_onnx.py --config configs/tracknet_v4_attention.yaml --ckpt best.pth --output model.onnx

Обучение

Скрипт train.py обеспечивает единый пайплайн для всех моделей:

# scripts/train.py
import yaml
import torch
from models import get_model
from data import BallDataset
from utils import AverageMeter, save_checkpoint

def train(config_path):
    with open(config_path) as f:
        config = yaml.safe_load(f)

    model_cls = get_model(config['model']['type'])
    model = model_cls(**config['model'])
    dataloader = ...
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=config['train']['lr'])

    for epoch in range(config['train']['epochs']):
        model.train()
        for batch in dataloader:
            x, y = batch
            pred = model(x)
            loss = criterion(pred, y)
            loss.backward()
            optimizer.step()
        save_checkpoint(model, epoch, config)

Сохранение чекпоинтов включает конфигурацию для воспроизводимости:

torch.save({
    'state_dict': model.state_dict(),
    'config': config,
    'epoch': epoch,
    'loss': loss
}, 'checkpoint.pth')

Экспорт в ONNX и оптимизация

Для достижения 60+ FPS на CPU модель экспортируется в ONNX:

# scripts/export_onnx.py
import torch
from models import get_model

def export_model(config_path, checkpoint_path, output_path):
    with open(config_path) as f:
        config = yaml.safe_load(f)

    model = get_model(config['model']['type'])(**config['model'])
    state = torch.load(checkpoint_path, map_location="cpu")
    model.load_state_dict(state)
    model.eval()

    dummy_input = torch.randn(1, 9, 288, 512)
    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        opset_version=13,
        input_names=["input"],
        output_names=["heatmaps"],
        dynamic_axes=None
    )

Для оптимизации производительности:

  • Используйте ONNX Runtime с максимальным уровнем оптимизации (--opt_level 99).
  • Рассмотрите квантование (int8).
  • Экспериментируйте с лёгкими энкодерами (например, MobileNet вместо VGG).

Сравнение моделей

Результаты экспериментов сохраняются в results.csv:

exp_id model params(M) FPS(CPU) mAP@0.5 PCK@5px size(MB) notes
001 TrackNetV2 1.8 72 0.85 0.91 7.2 Базовая модель
002 V4 + GRU 3.1 48 0.89 0.93 12.4 Лучший трекинг
003 V4 + Attention 2.2 65 0.91 0.94 8.8 Лучший баланс

Тестирование FPS

# utils/fps_benchmark.py
import torch
import time

def benchmark_fps(model, device="cpu", seq_len=9, H=288, W=512, warmup=10, runs=100):
    model.to(device)
    model.eval()
    x = torch.randn(1, seq_len, H, W).to(device)

    for _ in range(warmup):
        with torch.no_grad():
            _ = model(x)

    start = time.time()
    for _ in range(runs):
        with torch.no_grad():
            _ = model(x)
    end = time.time()

    return runs / (end - start)

Рекомендации по архитектурам

  1. Attention: Лёгкие механизмы (Channel/Spatial Attention) встраиваются в skip-connections TrackNet, улучшая детекцию движущихся объектов.
  2. GRU: Полезен для обработки временных последовательностей, но увеличивает число параметров и снижает FPS.
  3. GNN: Подходит для сложных сценариев с перекрытиями, но менее эффективен для достижения 60 FPS на CPU.