✅ [Add] test, increase test coverage for dev mode
- tests/conftest.py +91 -0
- tests/test_tools/test_data_loader.py +71 -0
- tests/test_tools/test_dataset_preparation.py +29 -0
- tests/test_tools/test_drawer.py +29 -0
- tests/test_tools/test_solver.py +39 -83
- yolo/tools/data_loader.py +9 -8
- yolo/tools/dataset_preparation.py +1 -11
- yolo/tools/drawer.py +4 -1
- yolo/utils/logging_utils.py +2 -1
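All of the changes below are test code or small fixes the tests exposed. As a point of reference only (not part of the commit), the new suite could be run with an invocation along these lines, assuming pytest and the project's dependencies are installed; the `requires_cuda` marker registered in tests/conftest.py lets CPU-only machines deselect GPU tests:

# Hypothetical runner script, for illustration only — not part of this commit.
import sys

import pytest

if __name__ == "__main__":
    # Deselect tests marked `requires_cuda` so the suite also runs without a GPU.
    sys.exit(pytest.main(["tests", "-m", "not requires_cuda", "-v"]))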
tests/conftest.py
ADDED
@@ -0,0 +1,91 @@
+import sys
+from pathlib import Path
+
+import pytest
+import torch
+from hydra import compose, initialize
+
+project_root = Path(__file__).resolve().parent.parent
+sys.path.append(str(project_root))
+
+from yolo import Config, Vec2Box, create_model
+from yolo.model.yolo import YOLO
+from yolo.tools.data_loader import StreamDataLoader, YoloDataLoader
+from yolo.utils.logging_utils import ProgressLogger, set_seed
+
+
+def pytest_configure(config):
+    config.addinivalue_line("markers", "requires_cuda: mark test to run only if CUDA is available")
+
+
+def get_cfg(overrides=[]) -> Config:
+    config_path = "../yolo/config"
+    with initialize(config_path=config_path, version_base=None):
+        cfg: Config = compose(config_name="config", overrides=overrides)
+        set_seed(cfg.lucky_number)
+        return cfg
+
+
+@pytest.fixture(scope="session")
+def train_cfg() -> Config:
+    return get_cfg(overrides=["task=train", "dataset=mock"])
+
+
+@pytest.fixture(scope="session")
+def validation_cfg():
+    return get_cfg(overrides=["task=validation", "dataset=mock"])
+
+
+@pytest.fixture(scope="session")
+def inference_cfg():
+    return get_cfg(overrides=["task=inference"])
+
+
+@pytest.fixture(scope="session")
+def device():
+    return torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+
+@pytest.fixture(scope="session")
+def train_progress_logger(train_cfg: Config):
+    progress_logger = ProgressLogger(train_cfg, exp_name=train_cfg.name)
+    return progress_logger
+
+
+@pytest.fixture(scope="session")
+def validation_progress_logger(validation_cfg: Config):
+    progress_logger = ProgressLogger(validation_cfg, exp_name=validation_cfg.name)
+    return progress_logger
+
+
+@pytest.fixture(scope="session")
+def model(train_cfg: Config, device) -> YOLO:
+    model = create_model(train_cfg.model)
+    return model.to(device)
+
+
+@pytest.fixture(scope="session")
+def vec2box(train_cfg: Config, model: YOLO, device) -> Vec2Box:
+    vec2box = Vec2Box(model, train_cfg.image_size, device)
+    return vec2box
+
+
+@pytest.fixture(scope="session")
+def train_dataloader(train_cfg: Config):
+    return YoloDataLoader(train_cfg.task.data, train_cfg.dataset, train_cfg.task.task)
+
+
+@pytest.fixture(scope="session")
+def validation_dataloader(validation_cfg: Config):
+    return YoloDataLoader(validation_cfg.task.data, validation_cfg.dataset, validation_cfg.task.task)
+
+
+@pytest.fixture(scope="session")
+def file_stream_data_loader(inference_cfg: Config):
+    return StreamDataLoader(inference_cfg.task.data)
+
+
+@pytest.fixture(scope="session")
+def directory_stream_data_loader(inference_cfg: Config):
+    inference_cfg.task.data.source = "tests/data/images/train"
+    return StreamDataLoader(inference_cfg.task.data)
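The fixtures above are all session-scoped, so the Hydra config, the model, and the data loaders are built once and reused across the whole test run. A hypothetical extra test module (illustration only, not in this commit) could combine these fixtures with the `requires_cuda` marker registered in `pytest_configure`:

# tests/test_example_gpu.py — hypothetical sketch, not part of this commit.
import pytest
import torch

from yolo.model.yolo import YOLO


@pytest.mark.requires_cuda
def test_model_forward_on_gpu(model: YOLO, device):
    # `model` and `device` are the session-scoped fixtures from tests/conftest.py.
    if device.type != "cuda":
        pytest.skip("CUDA is not available on this machine")
    dummy = torch.zeros(1, 3, 640, 640, device=device)
    outputs = model(dummy)
    assert outputs is not None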
tests/test_tools/test_data_loader.py
ADDED
@@ -0,0 +1,71 @@
+import sys
+from pathlib import Path
+
+import pytest
+
+project_root = Path(__file__).resolve().parent.parent.parent
+sys.path.append(str(project_root))
+
+from yolo.config.config import Config, TrainConfig
+from yolo.tools.data_loader import StreamDataLoader, YoloDataLoader, create_dataloader
+
+
+def test_create_dataloader_cache(train_cfg: Config):
+    train_cfg.task.data.shuffle = False
+    train_cfg.task.data.batch_size = 2
+
+    cache_file = Path("tests/data/train.cache")
+    cache_file.unlink(missing_ok=True)
+
+    make_cache_loader = create_dataloader(train_cfg.task.data, train_cfg.dataset)
+    load_cache_loader = create_dataloader(train_cfg.task.data, train_cfg.dataset)
+    m_batch_size, m_images, _, m_reverse_tensors, m_image_paths = next(iter(make_cache_loader))
+    l_batch_size, l_images, _, l_reverse_tensors, l_image_paths = next(iter(load_cache_loader))
+    assert m_batch_size == l_batch_size
+    assert m_images.shape == l_images.shape
+    assert m_reverse_tensors.shape == l_reverse_tensors.shape
+    assert m_image_paths == l_image_paths
+
+
+def test_training_data_loader_correctness(train_dataloader: YoloDataLoader):
+    """Test that the training data loader produces correctly shaped data and metadata."""
+    batch_size, images, _, reverse_tensors, image_paths = next(iter(train_dataloader))
+    assert batch_size == 2
+    assert images.shape == (2, 3, 640, 640)
+    assert reverse_tensors.shape == (2, 5)
+    expected_paths = [
+        Path("tests/data/images/train/000000050725.jpg"),
+        Path("tests/data/images/train/000000167848.jpg"),
+    ]
+    assert image_paths == expected_paths
+
+
+def test_validation_data_loader_correctness(validation_dataloader: YoloDataLoader):
+    batch_size, images, targets, reverse_tensors, image_paths = next(iter(validation_dataloader))
+    assert batch_size == 4
+    assert images.shape == (4, 3, 640, 640)
+    assert targets.shape == (4, 18, 5)
+    assert reverse_tensors.shape == (4, 5)
+    expected_paths = [
+        Path("tests/data/images/val/000000151480.jpg"),
+        Path("tests/data/images/val/000000284106.jpg"),
+        Path("tests/data/images/val/000000323571.jpg"),
+        Path("tests/data/images/val/000000570456.jpg"),
+    ]
+    assert image_paths == expected_paths
+
+
+def test_file_stream_data_loader_frame(file_stream_data_loader: StreamDataLoader):
+    """Test the frame output from the file stream data loader."""
+    frame, rev_tensor, origin_frame = next(iter(file_stream_data_loader))
+    assert frame.shape == (1, 3, 640, 640)
+    assert rev_tensor.shape == (1, 5)
+    assert origin_frame.size == (1024, 768)
+
+
+def test_directory_stream_data_loader_frame(directory_stream_data_loader: StreamDataLoader):
+    """Test the frame output from the directory stream data loader."""
+    frame, rev_tensor, origin_frame = next(iter(directory_stream_data_loader))
+    assert frame.shape == (1, 3, 640, 640)
+    assert rev_tensor.shape == (1, 5)
+    assert origin_frame.size == (480, 640)
tests/test_tools/test_dataset_preparation.py
ADDED
@@ -0,0 +1,29 @@
+import os
+import shutil
+import sys
+from pathlib import Path
+
+project_root = Path(__file__).resolve().parent.parent.parent
+sys.path.append(str(project_root))
+
+from yolo.config.config import Config
+from yolo.tools.dataset_preparation import prepare_dataset, prepare_weight
+
+
+def test_prepare_dataset(train_cfg: Config):
+    dataset_path = Path("tests/data")
+    if dataset_path.exists():
+        shutil.rmtree(dataset_path)
+    prepare_dataset(train_cfg.dataset, task="train")
+    prepare_dataset(train_cfg.dataset, task="val")
+
+    images_path = Path("tests/data/images")
+    for data_type in images_path.iterdir():
+        assert len(os.listdir(data_type)) == 5
+
+    annotations_path = Path("tests/data/annotations")
+    assert os.listdir(annotations_path) == ["instances_val.json", "instances_train.json"]
+
+
+def test_prepare_weight():
+    prepare_weight()
tests/test_tools/test_drawer.py
ADDED
@@ -0,0 +1,29 @@
+import sys
+from pathlib import Path
+
+from PIL import Image
+from torch import tensor
+
+project_root = Path(__file__).resolve().parent.parent.parent
+sys.path.append(str(project_root))
+
+from yolo.config.config import Config
+from yolo.model.yolo import YOLO
+from yolo.tools.drawer import draw_bboxes, draw_model
+
+
+def test_draw_model_by_config(train_cfg: Config):
+    """Test the drawing of a model based on a configuration."""
+    draw_model(model_cfg=train_cfg.model)
+
+
+def test_draw_model_by_model(model: YOLO):
+    """Test the drawing of a YOLO model."""
+    draw_model(model=model)
+
+
+def test_draw_bboxes():
+    """Test drawing bounding boxes on an image."""
+    predictions = tensor([[0, 60, 60, 160, 160, 0.5], [0, 40, 40, 120, 120, 0.5]])
+    pil_image = Image.open("tests/data/images/train/000000050725.jpg")
+    draw_bboxes(pil_image, [predictions])
tests/test_tools/test_solver.py
CHANGED
@@ -1,114 +1,70 @@
 import sys
 from pathlib import Path
-from unittest.mock import MagicMock, patch
 
 import pytest
-import torch
-from hydra import compose, initialize
+from torch import allclose, tensor
 
 project_root = Path(__file__).resolve().parent.parent.parent
 sys.path.append(str(project_root))
 
-from yolo.config.config import (
-    TrainConfig,
-    ValidationConfig,
-)
-from yolo.model.yolo import YOLO, create_model
-from yolo.tools.data_loader import create_dataloader
-from yolo.tools.loss_functions import create_loss_function
-from yolo.tools.solver import (  # Adjust the import to your module
-    ModelTester,
-    ModelTrainer,
-    ModelValidator,
-)
+from yolo.config.config import Config
+from yolo.model.yolo import YOLO
+from yolo.tools.data_loader import StreamDataLoader, YoloDataLoader
+from yolo.tools.solver import ModelTester, ModelTrainer, ModelValidator
 from yolo.utils.bounding_box_utils import Vec2Box
-from yolo.utils.logging_utils import ProgressLogger
-from yolo.utils.model_utils import (
-    ExponentialMovingAverage,
-    create_optimizer,
-    create_scheduler,
-)
 
 
 @pytest.fixture
-    cfg: Config = compose(config_name="config", overrides=["task=validation"])
-    cfg.weight = None
-    return cfg
+def model_validator(validation_cfg: Config, model: YOLO, vec2box: Vec2Box, validation_progress_logger, device):
+    validator = ModelValidator(
+        validation_cfg.task, validation_cfg.dataset, model, vec2box, validation_progress_logger, device
+    )
+    return validator
 
 
+def test_model_validator_initialization(model_validator: ModelValidator):
+    assert isinstance(model_validator.model, YOLO)
+    assert hasattr(model_validator, "solve")
 
 
+def test_model_validator_solve_mock_dataset(model_validator: ModelValidator, validation_dataloader: YoloDataLoader):
+    mAPs = model_validator.solve(validation_dataloader)
+    except_mAPs = {"mAP.5": tensor(0.6969), "mAP.5:.95": tensor(0.4195)}
+    assert allclose(mAPs["mAP.5"], except_mAPs["mAP.5"], rtol=1e-4)
+    print(mAPs)
+    assert allclose(mAPs["mAP.5:.95"], except_mAPs["mAP.5:.95"], rtol=1e-4)
 
 
 @pytest.fixture
-    return model.to(device)
-    model = create_model(cfg.model, weight_path=None).to(device)
-    vec2box = Vec2Box(model, cfg.image_size, device)
-    return vec2box
+def model_tester(inference_cfg: Config, model: YOLO, vec2box: Vec2Box, validation_progress_logger, device):
+    tester = ModelTester(inference_cfg, model, vec2box, validation_progress_logger, device)
+    return tester
 
 
+def test_model_tester_initialization(model_tester: ModelTester):
+    assert isinstance(model_tester.model, YOLO)
+    assert hasattr(model_tester, "solve")
 
 
+def test_model_tester_solve_single_image(model_tester: ModelTester, file_stream_data_loader: StreamDataLoader):
+    model_tester.solve(file_stream_data_loader)
 
 
 @pytest.fixture
-# def test_model_trainer_initialization(cfg: Config, model: YOLO, vec2box: Vec2Box, progress_logger, device):
-#     trainer = ModelTrainer(cfg, model, vec2box, progress_logger, device, use_ddp=False)
-#     assert trainer.model == model
-#     assert trainer.device == device
-#     assert trainer.optimizer is not None
-#     assert trainer.scheduler is not None
-#     assert trainer.loss_fn is not None
-#     assert trainer.progress == progress_logger
-
-# def test_model_trainer_train_one_batch(config, model, vec2box, progress_logger, device):
-#     trainer = ModelTrainer(config, model, vec2box, progress_logger, device, use_ddp=False)
-#     images = torch.rand(1, 3, 224, 224)
-#     targets = torch.rand(1, 5)
-#     loss_item = trainer.train_one_batch(images, targets)
-#     assert isinstance(loss_item, dict)
-
-    assert tester.model == model
-    assert tester.device == device
-    assert tester.progress == progress_logger
+def model_trainer(train_cfg: Config, model: YOLO, vec2box: Vec2Box, train_progress_logger, device):
+    train_cfg.task.epoch = 2
+    trainer = ModelTrainer(train_cfg, model, vec2box, train_progress_logger, device, use_ddp=False)
+    return trainer
 
 
+def test_model_trainer_initialization(model_trainer: ModelTrainer):
 
+    assert isinstance(model_trainer.model, YOLO)
+    assert hasattr(model_trainer, "solve")
+    assert model_trainer.optimizer is not None
+    assert model_trainer.scheduler is not None
+    assert model_trainer.loss_fn is not None
 
 
+def test_model_trainer_solve_mock_dataset(model_trainer: ModelTrainer, train_dataloader: YoloDataLoader):
+    model_trainer.solve(train_dataloader)
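The validator test above pins the mAP values produced on the mock dataset and compares them with `torch.allclose` at `rtol=1e-4`. As a small worked example of what that tolerance admits (illustrative numbers, not project output):

# Illustration of the rtol=1e-4 comparison used in test_model_validator_solve_mock_dataset.
from torch import allclose, tensor

expected = tensor(0.6969)
# allclose passes when |actual - expected| <= atol + rtol * |expected|;
# with rtol=1e-4 (and the default atol=1e-8) that is roughly 7e-5 here.
print(allclose(tensor(0.69695), expected, rtol=1e-4))  # True: difference ~5e-5
print(allclose(tensor(0.6980), expected, rtol=1e-4))   # False: difference ~1.1e-3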
yolo/tools/data_loader.py
CHANGED
@@ -111,7 +111,7 @@ class YoloDataset(Dataset):
         logger.info("Recorded {}/{} valid inputs", valid_inputs, len(images_list))
         return data
 
-    def load_valid_labels(self, label_path: str, seg_data_one_img: list) -> Union[
+    def load_valid_labels(self, label_path: str, seg_data_one_img: list) -> Union[Tensor, None]:
         """
         Loads and validates bounding box data is [0, 1] from a label file.
 
@@ -119,7 +119,7 @@ class YoloDataset(Dataset):
             label_path (str): The filepath to the label file containing bounding box data.
 
         Returns:
-
+            Tensor or None: A tensor of all valid bounding boxes if any are found; otherwise, None.
         """
         bboxes = []
         for seg_data in seg_data_one_img:
@@ -145,7 +145,7 @@ class YoloDataset(Dataset):
         indices = torch.randint(0, len(self), (num,))
         return [self.get_data(idx)[:2] for idx in indices]
 
-    def __getitem__(self, idx) ->
+    def __getitem__(self, idx) -> Tuple[Image.Image, Tensor, Tensor, List[str]]:
         img, bboxes, img_path = self.get_data(idx)
         img, bboxes, rev_tensor = self.transform(img, bboxes)
         return img, bboxes, rev_tensor, img_path
@@ -170,17 +170,17 @@ class YoloDataLoader(DataLoader):
             collate_fn=self.collate_fn,
         )
 
-    def collate_fn(self, batch: List[Tuple[
+    def collate_fn(self, batch: List[Tuple[Tensor, Tensor]]) -> Tuple[Tensor, List[Tensor]]:
         """
         A collate function to handle batching of images and their corresponding targets.
 
         Args:
             batch (list of tuples): Each tuple contains:
-                - image (
-                - labels (
+                - image (Tensor): The image tensor.
+                - labels (Tensor): The tensor of labels for the image.
 
         Returns:
-            Tuple[
+            Tuple[Tensor, List[Tensor]]: A tuple containing:
                 - A tensor of batched images.
                 - A list of tensors, each corresponding to bboxes for each image in the batch.
         """
@@ -213,7 +213,7 @@ def create_dataloader(data_cfg: DataConfig, dataset_cfg: DatasetConfig, task: st
 
 class StreamDataLoader:
     def __init__(self, data_cfg: DataConfig):
-        self.source =
+        self.source = data_cfg.source
         self.running = True
         self.is_stream = isinstance(self.source, int) or str(self.source).lower().startswith("rtmp://")
 
@@ -225,6 +225,7 @@ class StreamDataLoader:
 
             self.cap = cv2.VideoCapture(self.source)
         else:
+            self.source = Path(self.source)
             self.queue = Queue()
             self.thread = Thread(target=self.load_source)
             self.thread.start()
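The `StreamDataLoader` changes read the source directly from `data_cfg.source` and, for non-stream inputs, normalize it to a `Path` before the background loader thread starts. A standalone paraphrase of that source-type check (my sketch, not the project's API) looks like this:

# Sketch of the source-type check in StreamDataLoader.__init__ (paraphrased, not the project's API).
from pathlib import Path
from typing import Union


def classify_source(source: Union[int, str]) -> str:
    # Webcam indices and rtmp:// URLs are treated as live streams;
    # everything else is handled as a file or directory path.
    if isinstance(source, int) or str(source).lower().startswith("rtmp://"):
        return "stream"
    return "directory" if Path(source).is_dir() else "file"


print(classify_source(0))                          # stream (webcam index)
print(classify_source("rtmp://example.com/live"))  # stream
print(classify_source("tests/data/images/train"))  # directory, if the path exists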
yolo/tools/dataset_preparation.py
CHANGED
@@ -82,7 +82,7 @@ def prepare_dataset(dataset_cfg: DatasetConfig, task: str):
         logger.error(f"Error verifying the {dataset_type} dataset after extraction.")
 
 
-def prepare_weight(download_link: Optional[str] = None, weight_path: Path = "v9-c.pt"):
+def prepare_weight(download_link: Optional[str] = None, weight_path: Path = Path("v9-c.pt")):
     weight_name = weight_path.name
     if download_link is None:
         download_link = "https://github.com/WongKinYiu/yolov9mit/releases/download/v1.0-alpha/"
@@ -97,13 +97,3 @@ def prepare_weight(download_link: Optional[str] = None, weight_path: Path = "v9-
         download_file(weight_link, weight_path)
     except requests.exceptions.RequestException as e:
         logger.warning(f"Failed to download the weight file: {e}")
-
-
-if __name__ == "__main__":
-    import sys
-
-    sys.path.append("./")
-    from utils.logging_utils import custom_logger
-
-    custom_logger()
-    prepare_weight()
yolo/tools/drawer.py
CHANGED
@@ -7,6 +7,9 @@ from loguru import logger
 from PIL import Image, ImageDraw, ImageFont
 from torchvision.transforms.functional import to_pil_image
 
+from yolo.config.config import ModelConfig
+from yolo.model.yolo import YOLO
+
 
 def draw_bboxes(
     img: Union[Image.Image, torch.Tensor],
@@ -62,7 +65,7 @@ def draw_bboxes(
     return img
 
 
-def draw_model(*, model_cfg=None, model=None, v7_base=False):
+def draw_model(*, model_cfg: ModelConfig = None, model: YOLO = None, v7_base=False):
     from graphviz import Digraph
 
     if model_cfg:
yolo/utils/logging_utils.py
CHANGED
@@ -138,7 +138,8 @@ class ProgressLogger(Progress):
     def finish_train(self):
         self.remove_task(self.task_epoch)
         self.stop()
-        self.
+        if self.use_wandb:
+            self.wandb.finish()
 
 
 def custom_wandb_log(string="", level=int, newline=True, repeat=True, prefix=True, silent=False):