Merge branch 'SETUP' into MODEL
- LICENSE +21 -0
- README.md +13 -13
- examples/example_train.py +35 -0
- {config → yolo/config}/README.md +0 -0
- yolo/config/config.py +91 -0
- yolo/config/config.yaml +11 -0
- yolo/config/data/augmentation.yaml +3 -0
- {config → yolo/config}/data/coco.yaml +0 -0
- yolo/config/data/download.yaml +21 -0
- yolo/config/hyper/default.yaml +19 -0
- {config → yolo/config}/model/v7-base.yaml +0 -0
- {model → yolo/model}/README.md +0 -0
- {model → yolo/model}/module.py +0 -0
- {model → yolo/model}/yolo.py +1 -1
- yolo/tools/__init__.py +0 -0
- yolo/tools/dataset_helper.py +103 -0
- {tools → yolo/tools}/layer_helper.py +1 -1
- {tools → yolo/tools}/log_helper.py +0 -0
- {tools → yolo/tools}/model_helper.py +1 -1
- {tools → yolo/tools}/trainer.py +4 -4
- {utils → yolo/utils}/README.md +0 -0
- yolo/utils/converter_json2txt.py +90 -0
- yolo/utils/data_augment.py +125 -0
- yolo/utils/dataloader.py +206 -0
- yolo/utils/drawer.py +41 -0
- yolo/utils/get_dataset.py +84 -0
- yolo/utils/loss.py +2 -0
LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2024 Kin-Yiu, Wong
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
README.md
CHANGED
@@ -20,31 +20,31 @@ If you are interested in contributing, please keep an eye on project updates or
 ## To-Do Lists
 - [ ] Project Setup
   - [X] requirements
+  - [x] LICENSE
   - [ ] README
+  - [x] pytests
   - [ ] setup.py/pip install
+  - [x] log format
   - [ ] hugging face
 - [ ] Data process
   - [ ] Dataset
+    - [x] Download script
     - [ ] Auto Download
     - [ ] xywh, xxyy, xcyc
+  - [x] Dataloader
+  - [x] Data augment
 - [ ] Model
   - [ ] load model
     - [ ] from yaml
     - [ ] from github
+  - [x] trainer
+    - [x] train_one_iter
+    - [x] train_one_epoch
   - [ ] DDP
+  - [x] EMA, OTA
+  - [ ] Loss
 - [ ] Run
   - [ ] train
   - [ ] test
   - [ ] demo
-  - [ ] hyperparams: dataclass
-  - [ ] model cfg: yaml
+  - [x] Configuration
examples/example_train.py
ADDED
@@ -0,0 +1,35 @@
+import sys
+from pathlib import Path
+
+import hydra
+import torch
+from loguru import logger
+
+project_root = Path(__file__).resolve().parent.parent
+sys.path.append(str(project_root))
+
+from yolo.config.config import Config
+from yolo.model.yolo import get_model
+from yolo.tools.log_helper import custom_logger
+from yolo.tools.trainer import Trainer
+from yolo.utils.dataloader import get_dataloader
+from yolo.utils.get_dataset import prepare_dataset
+
+
+@hydra.main(config_path="../yolo/config", config_name="config", version_base=None)
+def main(cfg: Config):
+    if cfg.download.auto:
+        prepare_dataset(cfg.download)
+
+    dataloader = get_dataloader(cfg)
+    model = get_model(cfg.model)
+    # TODO: get_device or rank, for DDP mode
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    trainer = Trainer(model, cfg.hyper.train, device)
+    trainer.train(dataloader, 10)
+
+
+if __name__ == "__main__":
+    custom_logger()
+    main()
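Because the entry point is wrapped in `@hydra.main`, every field of the composed config can be overridden from the command line with Hydra's standard dotted syntax; the keys below come from `hyper/default.yaml`. Note that `trainer.train(dataloader, 10)` hard-codes the epoch count, so there is no epoch override yet.

```
python examples/example_train.py hyper.data.batch_size=8 hyper.train.optimizer.args.lr=0.0005
```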
{config → yolo/config}/README.md
RENAMED
File without changes
yolo/config/config.py
ADDED
@@ -0,0 +1,91 @@
+from dataclasses import dataclass
+from typing import Dict, List, Union
+
+
+@dataclass
+class Model:
+    anchor: List[List[int]]
+    model: Dict[str, List[Dict[str, Union[Dict, List, int]]]]
+
+
+@dataclass
+class Download:
+    auto: bool
+    path: str
+
+
+@dataclass
+class DataLoaderConfig:
+    batch_size: int
+    shuffle: bool
+    num_workers: int
+    pin_memory: bool
+
+
+@dataclass
+class OptimizerArgs:
+    lr: float
+    weight_decay: float
+
+
+@dataclass
+class OptimizerConfig:
+    type: str
+    args: OptimizerArgs
+
+
+@dataclass
+class SchedulerArgs:
+    step_size: int
+    gamma: float
+
+
+@dataclass
+class SchedulerConfig:
+    type: str
+    args: SchedulerArgs
+
+
+@dataclass
+class EMAConfig:
+    enabled: bool
+    decay: float
+
+
+@dataclass
+class TrainConfig:
+    optimizer: OptimizerConfig
+    scheduler: SchedulerConfig
+    ema: EMAConfig
+
+
+@dataclass
+class HyperConfig:
+    data: DataLoaderConfig
+    train: TrainConfig
+
+
+@dataclass
+class Dataset:
+    file_name: str
+    num_files: int
+
+
+@dataclass
+class Datasets:
+    base_url: str
+    images: Dict[str, Dataset]
+
+
+@dataclass
+class Download:
+    auto: bool
+    save_path: str
+    datasets: Datasets
+
+
+@dataclass
+class Config:
+    model: Model
+    download: Download
+    hyper: HyperConfig
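Note that `Download` is defined twice in this file; at import time the second definition (with `save_path` and `datasets`) replaces the first, and it is the one matching the keys in `data/download.yaml`. The commit uses these dataclasses only as type annotations; a sketch of how they could additionally be registered with Hydra's `ConfigStore` for schema validation (not part of this commit, purely illustrative):

```python
# Hypothetical registration sketch -- this commit does not register a schema.
from hydra.core.config_store import ConfigStore

from yolo.config.config import Config

cs = ConfigStore.instance()
cs.store(name="config_schema", node=Config)  # lets Hydra validate config.yaml against Config
```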
yolo/config/config.yaml
ADDED
@@ -0,0 +1,11 @@
+hydra:
+  run:
+    dir: ./runs
+
+defaults:
+  - data: coco
+  - download: ../data/download
+  - augmentation: ../data/augmentation
+  - model: v7-base
+  - hyper: default
+  - _self_
yolo/config/data/augmentation.yaml
ADDED
@@ -0,0 +1,3 @@
+Mosaic: 1
+# MixUp: 1
+HorizontalFlip: 0.5
{config → yolo/config}/data/coco.yaml
RENAMED
File without changes
yolo/config/data/download.yaml
ADDED
@@ -0,0 +1,21 @@
+auto: True
+save_path: data/coco
+datasets:
+  images:
+    base_url: http://images.cocodataset.org/zips/
+    train2017:
+      file_name: train2017
+      file_num: 118287
+    val2017:
+      file_name: val2017
+      file_num: 5000
+    test2017:
+      file_name: test2017
+      file_num: 40670
+  annotations:
+    base_url: http://images.cocodataset.org/annotations/
+    annotations:
+      file_name: annotations_trainval2017
+hydra:
+  run:
+    dir: ./runs
yolo/config/hyper/default.yaml
ADDED
@@ -0,0 +1,19 @@
+data:
+  batch_size: 4
+  shuffle: True
+  num_workers: 4
+  pin_memory: True
+train:
+  optimizer:
+    type: Adam
+    args:
+      lr: 0.001
+      weight_decay: 0.0001
+  scheduler:
+    type: StepLR
+    args:
+      step_size: 10
+      gamma: 0.1
+  ema:
+    enabled: true
+    decay: 0.995
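The `type`/`args` split mirrors `OptimizerConfig` and `SchedulerConfig` in `yolo/config/config.py`. Below is a minimal sketch of how such a block could be resolved against `torch.optim`; the actual `get_optimizer` in `yolo/tools/model_helper.py` is not shown in this diff and may differ:

```python
# Illustrative only; opt_cfg mirrors OptimizerConfig (.type, .args).
import torch
from torch import nn


def build_optimizer(model: nn.Module, opt_cfg) -> torch.optim.Optimizer:
    optimizer_cls = getattr(torch.optim, opt_cfg.type)  # e.g. "Adam" -> torch.optim.Adam
    return optimizer_cls(model.parameters(), **opt_cfg.args)  # lr, weight_decay
```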
{config → yolo/config}/model/v7-base.yaml
RENAMED
File without changes

{model → yolo/model}/README.md
RENAMED
File without changes

{model → yolo/model}/module.py
RENAMED
File without changes
{model → yolo/model}/yolo.py
RENAMED
@@ -5,7 +5,7 @@ import torch.nn as nn
 from loguru import logger
 from omegaconf import OmegaConf
 
-from tools.layer_helper import get_layer_map
+from yolo.tools.layer_helper import get_layer_map
 
 
 class YOLO(nn.Module):
yolo/tools/__init__.py
ADDED
File without changes
yolo/tools/dataset_helper.py
ADDED
@@ -0,0 +1,103 @@
+import json
+import os
+from itertools import chain
+from os import path
+from typing import Any, Dict, List, Optional, Tuple
+
+import numpy as np
+
+
+def find_labels_path(dataset_path: str, phase_name: str):
+    """
+    Find the path to label files for a specified dataset and phase (e.g. training).
+
+    Args:
+        dataset_path (str): The path to the root directory of the dataset.
+        phase_name (str): The name of the phase for which labels are being searched (e.g., "train", "val", "test").
+
+    Returns:
+        Tuple[str, str]: A tuple containing the path to the labels file and the file format ("json" or "txt").
+    """
+    json_labels_path = path.join(dataset_path, "annotations", f"instances_{phase_name}.json")
+
+    txt_labels_path = path.join(dataset_path, "label", phase_name)
+
+    if path.isfile(json_labels_path):
+        return json_labels_path, "json"
+
+    elif path.isdir(txt_labels_path):
+        txt_files = [f for f in os.listdir(txt_labels_path) if f.endswith(".txt")]
+        if txt_files:
+            return txt_labels_path, "txt"
+
+    raise FileNotFoundError("No labels found in the specified dataset path and phase name.")
+
+
+def create_image_info_dict(labels_path: str) -> Tuple[Dict[str, List], Dict[str, Dict]]:
+    """
+    Create a dictionary containing image information and annotations indexed by image ID.
+
+    Args:
+        labels_path (str): The path to the annotation json file.
+
+    Returns:
+        - annotations_index: A dictionary where keys are image IDs and values are lists of annotations.
+        - image_info_dict: A dictionary where keys are image file names without extension and values are image information dictionaries.
+    """
+    with open(labels_path, "r") as file:
+        labels_data = json.load(file)
+    annotations_index = index_annotations_by_image(labels_data)  # check lookup is a good name?
+    image_info_dict = {path.splitext(img["file_name"])[0]: img for img in labels_data["images"]}
+    return annotations_index, image_info_dict
+
+
+def index_annotations_by_image(data: Dict[str, Any]):
+    """
+    Use the image index to look up every annotation.
+
+    Args:
+        data (Dict[str, Any]): A dictionary containing annotation data.
+
+    Returns:
+        Dict[int, List[Dict[str, Any]]]: A dictionary where keys are image IDs and values are lists of annotations.
+        Annotations with "iscrowd" set to True are excluded from the index.
+    """
+    annotation_lookup = {}
+    for anno in data["annotations"]:
+        if anno["iscrowd"]:
+            continue
+        image_id = anno["image_id"]
+        if image_id not in annotation_lookup:
+            annotation_lookup[image_id] = []
+        annotation_lookup[image_id].append(anno)
+    return annotation_lookup
+
+
+def get_scaled_segmentation(
+    annotations: List[Dict[str, Any]], image_dimensions: Dict[str, int]
+) -> Optional[List[List[float]]]:
+    """
+    Scale the segmentation data based on image dimensions and return a list of scaled segmentation data.
+
+    Args:
+        annotations (List[Dict[str, Any]]): A list of annotation dictionaries.
+        image_dimensions (Dict[str, int]): A dictionary containing image dimensions (height and width).
+
+    Returns:
+        Optional[List[List[float]]]: A list of scaled segmentation data, where each sublist contains category_id followed by scaled (x, y) coordinates.
+    """
+    if annotations is None:
+        return None
+
+    seg_array_with_cat = []
+    h, w = image_dimensions["height"], image_dimensions["width"]
+    for anno in annotations:
+        category_id = anno["category_id"]
+        seg_list = [item for sublist in anno["segmentation"] for item in sublist]
+        scaled_seg_data = (
+            np.array(seg_list).reshape(-1, 2) / [w, h]
+        ).tolist()  # group the list into (x, y) pairs scaled by image width, height
+        scaled_flat_seg_data = [category_id] + list(chain(*scaled_seg_data))  # flatten the scaled_seg_data list
+        seg_array_with_cat.append(scaled_flat_seg_data)
+
+    return seg_array_with_cat
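A worked example of `get_scaled_segmentation`: each polygon is flattened, regrouped into (x, y) pairs, divided by the image width and height, and prefixed with its category id.

```python
from yolo.tools.dataset_helper import get_scaled_segmentation

annotations = [{"category_id": 3, "segmentation": [[0.0, 0.0, 320.0, 240.0, 640.0, 480.0]]}]
dims = {"height": 480, "width": 640}
print(get_scaled_segmentation(annotations, dims))
# [[3, 0.0, 0.0, 0.5, 0.5, 1.0, 1.0]]
```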
{tools → yolo/tools}/layer_helper.py
RENAMED
@@ -2,7 +2,7 @@ import inspect
 
 import torch.nn as nn
 
-from model import module
+from yolo.model import module
 
 
 def auto_pad():
{tools → yolo/tools}/log_helper.py
RENAMED
File without changes
{tools → yolo/tools}/model_helper.py
RENAMED
@@ -4,7 +4,7 @@ import torch
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler
 
-from config.config import OptimizerConfig, SchedulerConfig
+from yolo.config.config import OptimizerConfig, SchedulerConfig
 
 
 class EMA:
{tools → yolo/tools}/trainer.py
RENAMED
@@ -2,10 +2,10 @@ import torch
 from loguru import logger
 from tqdm import tqdm
 
-from config.config import TrainConfig
-from model.yolo import YOLO
-from tools.model_helper import EMA, get_optimizer, get_scheduler
-from utils.loss import get_loss_function
+from yolo.config.config import TrainConfig
+from yolo.model.yolo import YOLO
+from yolo.tools.model_helper import EMA, get_optimizer, get_scheduler
+from yolo.utils.loss import get_loss_function
 
 
 class Trainer:
{utils → yolo/utils}/README.md
RENAMED
File without changes
yolo/utils/converter_json2txt.py
ADDED
@@ -0,0 +1,90 @@
+import json
+import os
+from typing import Dict, List, Optional
+
+from tqdm import tqdm
+
+
+def discretize_categories(categories: List[Dict[str, int]]) -> Dict[int, int]:
+    """
+    Maps each unique 'id' in the list of category dictionaries to a sequential integer index.
+    Indices are assigned based on the sorted 'id' values.
+    """
+    sorted_categories = sorted(categories, key=lambda category: category["id"])
+    return {category["id"]: index for index, category in enumerate(sorted_categories)}
+
+
+def process_annotations(
+    image_annotations: Dict[int, List[Dict]],
+    image_info_dict: Dict[int, tuple],
+    output_dir: str,
+    id_to_idx: Optional[Dict[int, int]] = None,
+) -> None:
+    """
+    Process and save annotations to files, with option to remap category IDs.
+    """
+    for image_id, annotations in tqdm(image_annotations.items(), desc="Processing annotations"):
+        file_path = os.path.join(output_dir, f"{image_id:0>12}.txt")
+        if not annotations:
+            continue
+        with open(file_path, "w") as file:
+            for annotation in annotations:
+                process_annotation(annotation, image_info_dict[image_id], id_to_idx, file)
+
+
+def process_annotation(annotation: Dict, image_dims: tuple, id_to_idx: Optional[Dict[int, int]], file) -> None:
+    """
+    Convert a single annotation's segmentation and write it to the open file handle.
+    """
+    category_id = annotation["category_id"]
+    segmentation = (
+        annotation["segmentation"][0]
+        if annotation["segmentation"] and isinstance(annotation["segmentation"][0], list)
+        else None
+    )
+
+    if segmentation is None:
+        return
+
+    img_width, img_height = image_dims
+    normalized_segmentation = normalize_segmentation(segmentation, img_width, img_height)
+
+    if id_to_idx:
+        category_id = id_to_idx.get(category_id, category_id)
+
+    file.write(f"{category_id} {' '.join(normalized_segmentation)}\n")
+
+
+def normalize_segmentation(segmentation: List[float], img_width: int, img_height: int) -> List[str]:
+    """
+    Normalize and format segmentation coordinates.
+    """
+    normalized = [
+        f"{coord / img_width:.6f}" if index % 2 == 0 else f"{coord / img_height:.6f}"
+        for index, coord in enumerate(segmentation)
+    ]
+    return normalized
+
+
+def convert_annotations(json_file: str, output_dir: str) -> None:
+    """
+    Load annotation data from a JSON file and process all annotations.
+    """
+    with open(json_file) as file:
+        data = json.load(file)
+
+    os.makedirs(output_dir, exist_ok=True)
+
+    image_info_dict = {img["id"]: (img["width"], img["height"]) for img in data.get("images", [])}
+    id_to_idx = discretize_categories(data.get("categories", [])) if "categories" in data else None
+    image_annotations = {img_id: [] for img_id in image_info_dict}
+
+    for annotation in data.get("annotations", []):
+        if not annotation.get("iscrowd", False):
+            image_annotations[annotation["image_id"]].append(annotation)
+
+    process_annotations(image_annotations, image_info_dict, output_dir, id_to_idx)
+
+
+convert_annotations("./data/coco/annotations/instances_train2017.json", "./data/coco/labels/train2017/")
+convert_annotations("./data/coco/annotations/instances_val2017.json", "./data/coco/labels/val2017/")
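Note that the two `convert_annotations` calls run at module level, so importing this file immediately converts the train and val splits. The txt format it writes is the (optionally remapped) class index followed by alternating normalized x and y coordinates, as a quick check shows:

```python
from yolo.utils.converter_json2txt import normalize_segmentation

# Pixel polygon (64, 48), (320, 240) on a 640x480 image:
print(normalize_segmentation([64.0, 48.0, 320.0, 240.0], 640, 480))
# ['0.100000', '0.100000', '0.500000', '0.500000']
```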
yolo/utils/data_augment.py
ADDED
@@ -0,0 +1,125 @@
+import numpy as np
+import torch
+from PIL import Image
+from torchvision.transforms import functional as TF
+
+
+class Compose:
+    """Composes several transforms together."""
+
+    def __init__(self, transforms, image_size: int = 640):
+        self.transforms = transforms
+        self.image_size = image_size
+
+        for transform in self.transforms:
+            if hasattr(transform, "set_parent"):
+                transform.set_parent(self)
+
+    def __call__(self, image, boxes):
+        for transform in self.transforms:
+            image, boxes = transform(image, boxes)
+        return image, boxes
+
+
+class HorizontalFlip:
+    """Randomly horizontally flips the image along with the bounding boxes."""
+
+    def __init__(self, prob=0.5):
+        self.prob = prob
+
+    def __call__(self, image, boxes):
+        if torch.rand(1) < self.prob:
+            image = TF.hflip(image)
+            boxes[:, [1, 3]] = 1 - boxes[:, [3, 1]]
+        return image, boxes
+
+
+class VerticalFlip:
+    """Randomly vertically flips the image along with the bounding boxes."""
+
+    def __init__(self, prob=0.5):
+        self.prob = prob
+
+    def __call__(self, image, boxes):
+        if torch.rand(1) < self.prob:
+            image = TF.vflip(image)
+            boxes[:, [2, 4]] = 1 - boxes[:, [4, 2]]
+        return image, boxes
+
+
+class Mosaic:
+    """Applies the Mosaic augmentation to a batch of images and their corresponding boxes."""
+
+    def __init__(self, prob=0.5):
+        self.prob = prob
+        self.parent = None
+
+    def set_parent(self, parent):
+        self.parent = parent
+
+    def __call__(self, image, boxes):
+        if torch.rand(1) >= self.prob:
+            return image, boxes
+
+        assert self.parent is not None, "Parent is not set. Mosaic cannot retrieve image size."
+
+        img_sz = self.parent.image_size  # Assuming `image_size` is defined in parent
+        more_data = self.parent.get_more_data(3)  # get 3 more images randomly
+
+        data = [(image, boxes)] + more_data
+        mosaic_image = Image.new("RGB", (2 * img_sz, 2 * img_sz))
+        vectors = np.array([(-1, -1), (0, -1), (-1, 0), (0, 0)])
+        center = np.array([img_sz, img_sz])
+        all_labels = []
+
+        for (image, boxes), vector in zip(data, vectors):
+            this_w, this_h = image.size
+            coord = tuple(center + vector * np.array([this_w, this_h]))
+
+            mosaic_image.paste(image, coord)
+            xmin, ymin, xmax, ymax = boxes[:, 1], boxes[:, 2], boxes[:, 3], boxes[:, 4]
+            xmin = (xmin * this_w + coord[0]) / (2 * img_sz)
+            xmax = (xmax * this_w + coord[0]) / (2 * img_sz)
+            ymin = (ymin * this_h + coord[1]) / (2 * img_sz)
+            ymax = (ymax * this_h + coord[1]) / (2 * img_sz)
+
+            adjusted_boxes = torch.stack([boxes[:, 0], xmin, ymin, xmax, ymax], dim=1)
+            all_labels.append(adjusted_boxes)
+
+        all_labels = torch.cat(all_labels, dim=0)
+        mosaic_image = mosaic_image.resize((img_sz, img_sz))
+        return mosaic_image, all_labels
+
+
+class MixUp:
+    """Applies the MixUp augmentation to a pair of images and their corresponding boxes."""
+
+    def __init__(self, prob=0.5, alpha=1.0):
+        self.alpha = alpha
+        self.prob = prob
+        self.parent = None
+
+    def set_parent(self, parent):
+        """Set the parent dataset object for accessing dataset methods."""
+        self.parent = parent
+
+    def __call__(self, image, boxes):
+        if torch.rand(1) >= self.prob:
+            return image, boxes
+
+        assert self.parent is not None, "Parent is not set. MixUp cannot retrieve additional data."
+
+        # Retrieve another image and its boxes randomly from the dataset
+        image2, boxes2 = self.parent.get_more_data()[0]
+
+        # Calculate the mixup lambda parameter
+        lam = np.random.beta(self.alpha, self.alpha) if self.alpha > 0 else 0.5
+
+        # Mix images
+        image1, image2 = TF.to_tensor(image), TF.to_tensor(image2)
+        mixed_image = lam * image1 + (1 - lam) * image2
+
+        # Mix bounding boxes
+        mixed_boxes = torch.cat([lam * boxes, (1 - lam) * boxes2])
+
+        return TF.to_pil_image(mixed_image), mixed_boxes
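The flip transforms rewrite the normalized box coordinates in place; here is a small sketch of the coordinate math for `HorizontalFlip`, using the [class, x_min, y_min, x_max, y_max] box layout this commit uses throughout:

```python
import torch
from PIL import Image

from yolo.utils.data_augment import Compose, HorizontalFlip

image = Image.new("RGB", (640, 640))
boxes = torch.tensor([[0.0, 0.1, 0.2, 0.4, 0.5]])  # class, xmin, ymin, xmax, ymax
transform = Compose([HorizontalFlip(prob=1.0)], image_size=640)
image, boxes = transform(image, boxes)
print(boxes)  # tensor([[0.0, 0.6, 0.2, 0.8, 0.5]]): new xmin = 1 - old xmax
```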
yolo/utils/dataloader.py
ADDED
@@ -0,0 +1,206 @@
+import os
+from os import path
+from typing import List, Tuple, Union
+
+import diskcache as dc
+import hydra
+import numpy as np
+import torch
+from loguru import logger
+from PIL import Image
+from torch.utils.data import DataLoader, Dataset
+from torchvision.transforms import functional as TF
+from tqdm.rich import tqdm
+
+from yolo.tools.dataset_helper import (
+    create_image_info_dict,
+    find_labels_path,
+    get_scaled_segmentation,
+)
+from yolo.utils.data_augment import Compose, HorizontalFlip, MixUp, Mosaic, VerticalFlip
+from yolo.utils.drawer import draw_bboxes
+
+
+class YoloDataset(Dataset):
+    def __init__(self, config: dict, phase: str = "train2017", image_size: int = 640):
+        dataset_cfg = config.data
+        augment_cfg = config.augmentation
+        phase_name = dataset_cfg.get(phase, phase)
+        self.image_size = image_size
+
+        transforms = [eval(aug)(prob) for aug, prob in augment_cfg.items()]
+        self.transform = Compose(transforms, self.image_size)
+        self.transform.get_more_data = self.get_more_data
+        self.data = self.load_data(dataset_cfg.path, phase_name)
+
+    def load_data(self, dataset_path, phase_name):
+        """
+        Loads data from a cache or generates a new cache for a specific dataset phase.
+
+        Parameters:
+            dataset_path (str): The root path to the dataset directory.
+            phase_name (str): The specific phase of the dataset (e.g., 'train', 'test') to load or generate data for.
+
+        Returns:
+            dict: The loaded data from the cache for the specified phase.
+        """
+        cache_path = path.join(dataset_path, ".cache")
+        cache = dc.Cache(cache_path)
+        data = cache.get(phase_name)
+
+        if data is None:
+            logger.info("Generating {} cache", phase_name)
+            data = self.filter_data(dataset_path, phase_name)
+            cache[phase_name] = data
+
+        cache.close()
+        logger.info("Loaded {} cache", phase_name)
+        data = cache[phase_name]
+        return data
+
+    def filter_data(self, dataset_path: str, phase_name: str) -> list:
+        """
+        Filters and collects dataset information by pairing images with their corresponding labels.
+
+        Parameters:
+            dataset_path (str): The root path to the dataset directory.
+            phase_name (str): The dataset phase, used to locate the image and label directories.
+
+        Returns:
+            list: A list of tuples, each containing the path to an image file and its associated segmentation as a tensor.
+        """
+        images_path = path.join(dataset_path, "images", phase_name)
+        labels_path, data_type = find_labels_path(dataset_path, phase_name)
+        images_list = sorted(os.listdir(images_path))
+        if data_type == "json":
+            annotations_index, image_info_dict = create_image_info_dict(labels_path)
+
+        data = []
+        valid_inputs = 0
+        for image_name in tqdm(images_list, desc="Filtering data"):
+            if not image_name.lower().endswith((".jpg", ".jpeg", ".png")):
+                continue
+            image_id, _ = path.splitext(image_name)
+
+            if data_type == "json":
+                image_info = image_info_dict.get(image_id, None)
+                if image_info is None:
+                    continue
+                annotations = annotations_index.get(image_info["id"], [])
+                image_seg_annotations = get_scaled_segmentation(annotations, image_info)
+                if not image_seg_annotations:
+                    continue
+
+            elif data_type == "txt":
+                label_path = path.join(labels_path, f"{image_id}.txt")
+                if not path.isfile(label_path):
+                    continue
+                with open(label_path, "r") as file:
+                    image_seg_annotations = [list(map(float, line.strip().split())) for line in file]
+
+            labels = self.load_valid_labels(image_id, image_seg_annotations)
+            if labels is not None:
+                img_path = path.join(images_path, image_name)
+                data.append((img_path, labels))
+                valid_inputs += 1
+
+        logger.info("Recorded {}/{} valid inputs", valid_inputs, len(images_list))
+        return data
+
+    def load_valid_labels(self, label_path, seg_data_one_img) -> Union[torch.Tensor, None]:
+        """
+        Loads and validates bounding box data (normalized to [0, 1]) from per-image segmentation data.
+
+        Parameters:
+            label_path (str): Identifier used for logging when no valid boxes are found.
+            seg_data_one_img (list): Segmentation entries for one image, each [class, x1, y1, x2, y2, ...].
+
+        Returns:
+            torch.Tensor or None: A tensor of all valid bounding boxes if any are found; otherwise, None.
+        """
+        bboxes = []
+        for seg_data in seg_data_one_img:
+            cls = seg_data[0]
+            points = np.array(seg_data[1:]).reshape(-1, 2)
+            valid_points = points[(points >= 0) & (points <= 1)].reshape(-1, 2)
+            if valid_points.size > 1:
+                bbox = torch.tensor([cls, *valid_points.min(axis=0), *valid_points.max(axis=0)])
+                bboxes.append(bbox)
+
+        if bboxes:
+            return torch.stack(bboxes)
+        else:
+            logger.warning("No valid BBox in {}", label_path)
+            return None
+
+    def get_data(self, idx):
+        img_path, bboxes = self.data[idx]
+        img = Image.open(img_path).convert("RGB")
+        return img, bboxes
+
+    def get_more_data(self, num: int = 1):
+        indices = torch.randint(0, len(self), (num,))
+        return [self.get_data(idx) for idx in indices]
+
+    def __getitem__(self, idx) -> Union[Image.Image, torch.Tensor]:
+        img, bboxes = self.get_data(idx)
+        if self.transform:
+            img, bboxes = self.transform(img, bboxes)
+        img = TF.to_tensor(img)
+        return img, bboxes
+
+    def __len__(self) -> int:
+        return len(self.data)
+
+
+class YoloDataLoader(DataLoader):
+    def __init__(self, config: dict):
+        """Initializes the YoloDataLoader with hydra-config files."""
+        hyper = config.hyper.data
+        dataset = YoloDataset(config)
+
+        super().__init__(
+            dataset,
+            batch_size=hyper.batch_size,
+            shuffle=hyper.shuffle,
+            num_workers=hyper.num_workers,
+            pin_memory=hyper.pin_memory,
+            collate_fn=self.collate_fn,
+        )
+
+    def collate_fn(self, batch: List[Tuple[torch.Tensor, torch.Tensor]]) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+        """
+        A collate function to handle batching of images and their corresponding targets.
+
+        Args:
+            batch (list of tuples): Each tuple contains:
+                - image (torch.Tensor): The image tensor.
+                - labels (torch.Tensor): The tensor of labels for the image.
+
+        Returns:
+            Tuple[torch.Tensor, List[torch.Tensor]]: A tuple containing:
+                - A tensor of batched images.
+                - A list of tensors, each corresponding to bboxes for each image in the batch.
+        """
+        images = torch.stack([item[0] for item in batch])
+        targets = [item[1] for item in batch]
+        return images, targets
+
+
+def get_dataloader(config):
+    return YoloDataLoader(config)
+
+
+@hydra.main(config_path="../config", config_name="config", version_base=None)
+def main(cfg):
+    dataloader = get_dataloader(cfg)
+    draw_bboxes(*next(iter(dataloader)))
+
+
+if __name__ == "__main__":
+    import sys
+
+    sys.path.append("./")
+    from tools.log_helper import custom_logger
+
+    custom_logger()
+    main()
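Since `collate_fn` stacks the images but keeps targets as a plain list, each image in a batch can carry a different number of boxes. A hedged iteration sketch, assuming the composed transforms leave all images the same size (e.g. via `Mosaic`'s final resize) so that `torch.stack` succeeds:

```python
# Illustrative; cfg is assumed to be the Hydra-composed config from yolo/config/config.yaml.
from yolo.utils.dataloader import get_dataloader

dataloader = get_dataloader(cfg)
images, targets = next(iter(dataloader))
print(images.shape)                # e.g. torch.Size([4, 3, 640, 640]) with batch_size: 4
print([t.shape for t in targets])  # one (num_boxes, 5) tensor per image
```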
yolo/utils/drawer.py
ADDED
@@ -0,0 +1,41 @@
+from typing import List, Union
+
+import torch
+from loguru import logger
+from PIL import Image, ImageDraw, ImageFont
+from torchvision.transforms.functional import to_pil_image
+
+
+def draw_bboxes(img: Union[Image.Image, torch.Tensor], bboxes: List[List[Union[int, float]]]):
+    """
+    Draw bounding boxes on an image.
+
+    Args:
+    - img (PIL Image or torch.Tensor): Image on which to draw the bounding boxes.
+    - bboxes (List of Lists/Tensors): Bounding boxes with [class_id, x_min, y_min, x_max, y_max],
+      where coordinates are normalized [0, 1].
+    """
+    # Convert tensor image to PIL Image if necessary
+    if isinstance(img, torch.Tensor):
+        if img.dim() > 3:
+            logger.info("Multi-frame tensor detected, using the first image.")
+            img = img[0]
+            bboxes = bboxes[0]
+        img = to_pil_image(img)
+
+    draw = ImageDraw.Draw(img)
+    width, height = img.size
+    font = ImageFont.load_default(30)
+
+    for bbox in bboxes:
+        class_id, x_min, y_min, x_max, y_max = bbox
+        x_min = x_min * width
+        x_max = x_max * width
+        y_min = y_min * height
+        y_max = y_max * height
+        shape = [(x_min, y_min), (x_max, y_max)]
+        draw.rectangle(shape, outline="red", width=3)
+        draw.text((x_min, y_min), str(int(class_id)), font=font, fill="blue")
+
+    img.save("visualize.jpg")  # Save the image with annotations
+    logger.info("Saved visualized image at visualize.jpg")
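A quick usage check of `draw_bboxes` on a synthetic image (note that `ImageFont.load_default(30)` requires a Pillow version that accepts a size argument):

```python
from PIL import Image

from yolo.utils.drawer import draw_bboxes

img = Image.new("RGB", (640, 480), "white")
draw_bboxes(img, [[0, 0.25, 0.25, 0.75, 0.75]])  # class 0, centered box
# writes visualize.jpg to the current working directory
```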
yolo/utils/get_dataset.py
ADDED
@@ -0,0 +1,84 @@
+import os
+import zipfile
+
+import requests
+from hydra import main
+from loguru import logger
+from tqdm import tqdm
+
+
+def download_file(url, destination):
+    """
+    Downloads a file from the specified URL to the destination path with progress logging.
+    """
+    logger.info(f"Downloading {os.path.basename(destination)}...")
+    with requests.get(url, stream=True) as response:
+        response.raise_for_status()
+        total_size = int(response.headers.get("content-length", 0))
+        progress = tqdm(total=total_size, unit="iB", unit_scale=True, desc=os.path.basename(destination), leave=True)
+
+        with open(destination, "wb") as file:
+            for data in response.iter_content(chunk_size=1024 * 1024):  # 1 MB chunks
+                file.write(data)
+                progress.update(len(data))
+        progress.close()
+    logger.info("Download completed.")
+
+
+def unzip_file(source, destination):
+    """
+    Extracts a ZIP file to the specified directory and removes the ZIP file after extraction.
+    """
+    logger.info(f"Unzipping {os.path.basename(source)}...")
+    with zipfile.ZipFile(source, "r") as zip_ref:
+        zip_ref.extractall(destination)
+    os.remove(source)
+    logger.info(f"Removed {source}.")
+
+
+def check_files(directory, expected_count=None):
+    """
+    Returns True if the number of files in the directory matches expected_count, False otherwise.
+    """
+    files = [f for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f))]
+    return len(files) == expected_count if expected_count is not None else bool(files)
+
+
+@main(config_path="../config/data", config_name="download", version_base=None)
+def prepare_dataset(cfg):
+    """
+    Prepares dataset by downloading and unzipping if necessary.
+    """
+    data_dir = cfg.save_path
+    for data_type, settings in cfg.datasets.items():
+        base_url = settings["base_url"]
+        for dataset_type, dataset_args in settings.items():
+            if dataset_type == "base_url":
+                continue  # Skip the base_url entry
+            file_name = f"{dataset_args.get('file_name', dataset_type)}.zip"
+            url = f"{base_url}{file_name}"
+            local_zip_path = os.path.join(data_dir, file_name)
+            extract_to = os.path.join(data_dir, data_type) if data_type != "annotations" else data_dir
+            final_place = os.path.join(extract_to, dataset_type)
+
+            os.makedirs(extract_to, exist_ok=True)
+            if check_files(final_place, dataset_args.get("file_num")):
+                logger.info(f"Dataset {dataset_type} already verified.")
+                continue
+
+            if not os.path.exists(local_zip_path):
+                download_file(url, local_zip_path)
+            unzip_file(local_zip_path, extract_to)
+
+            if not check_files(final_place, dataset_args.get("file_num")):
+                logger.error(f"Error verifying the {dataset_type} dataset after extraction.")
+
+
+if __name__ == "__main__":
+    import sys
+
+    sys.path.append("./")
+    from tools.log_helper import custom_logger
+
+    custom_logger()
+    prepare_dataset()
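One caveat when reading `prepare_dataset`: `check_files` calls `os.listdir(final_place)` before anything has been extracted, so on a completely fresh `save_path` the split directory does not yet exist and the call would raise `FileNotFoundError`; an `os.path.isdir` guard would be needed. Assuming extraction succeeds, the layout implied by `download.yaml` and the `extract_to` logic should look roughly like this (an inference from the code above, not shown in the commit):

```python
# data/coco/
#     images/
#         train2017/   # 118287 files (file_num in download.yaml)
#         val2017/     # 5000 files
#         test2017/    # 40670 files
#     annotations/     # instances_*.json from annotations_trainval2017.zip
```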
yolo/utils/loss.py
ADDED
@@ -0,0 +1,2 @@
+def get_loss_function(*args, **kwargs):
+    raise NotImplementedError