Spaces:
Sleeping
Sleeping
๐ [Rename] get_model to create, rename examples
Browse files- examples/lazy.py +0 -37
- examples/notebook_colab.ipynb +0 -0
- examples/notebook_inference.ipynb +25 -0
- examples/notebook_train.ipynb +25 -0
- examples/{example_inference.py โ sample_inference.py} +3 -7
- examples/{example_train.py โ sample_train.py} +2 -2
- tests/test_model/test_yolo.py +4 -4
- yolo/lazy.py +2 -2
- yolo/model/yolo.py +1 -1
- yolo/tools/drawer.py +4 -3
examples/lazy.py
DELETED
|
@@ -1,37 +0,0 @@
|
|
| 1 |
-
import sys
|
| 2 |
-
from pathlib import Path
|
| 3 |
-
|
| 4 |
-
import hydra
|
| 5 |
-
import torch
|
| 6 |
-
|
| 7 |
-
project_root = Path(__file__).resolve().parent.parent
|
| 8 |
-
sys.path.append(str(project_root))
|
| 9 |
-
|
| 10 |
-
from yolo.config.config import Config
|
| 11 |
-
from yolo.model.yolo import get_model
|
| 12 |
-
from yolo.tools.data_loader import create_dataloader
|
| 13 |
-
from yolo.tools.solver import ModelTester, ModelTrainer
|
| 14 |
-
from yolo.utils.logging_utils import custom_logger, validate_log_directory
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
@hydra.main(config_path="../yolo/config", config_name="config", version_base=None)
|
| 18 |
-
def main(cfg: Config):
|
| 19 |
-
custom_logger()
|
| 20 |
-
|
| 21 |
-
custom_logger()
|
| 22 |
-
save_path = validate_log_directory(cfg, cfg.name)
|
| 23 |
-
dataloader = create_dataloader(cfg)
|
| 24 |
-
device = torch.device(cfg.device)
|
| 25 |
-
model = get_model(cfg).to(device)
|
| 26 |
-
|
| 27 |
-
if cfg.task.task == "train":
|
| 28 |
-
trainer = ModelTrainer(cfg, model, save_path, device)
|
| 29 |
-
trainer.solve(dataloader)
|
| 30 |
-
|
| 31 |
-
if cfg.task.task == "inference":
|
| 32 |
-
tester = ModelTester(cfg, model, save_path, device)
|
| 33 |
-
tester.solve(dataloader)
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
if __name__ == "__main__":
|
| 37 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
examples/notebook_colab.ipynb
ADDED
|
File without changes
|
examples/notebook_inference.ipynb
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": null,
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"outputs": [],
|
| 8 |
+
"source": []
|
| 9 |
+
},
|
| 10 |
+
{
|
| 11 |
+
"cell_type": "code",
|
| 12 |
+
"execution_count": null,
|
| 13 |
+
"metadata": {},
|
| 14 |
+
"outputs": [],
|
| 15 |
+
"source": []
|
| 16 |
+
}
|
| 17 |
+
],
|
| 18 |
+
"metadata": {
|
| 19 |
+
"language_info": {
|
| 20 |
+
"name": "python"
|
| 21 |
+
}
|
| 22 |
+
},
|
| 23 |
+
"nbformat": 4,
|
| 24 |
+
"nbformat_minor": 2
|
| 25 |
+
}
|
examples/notebook_train.ipynb
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": null,
|
| 6 |
+
"metadata": {},
|
| 7 |
+
"outputs": [],
|
| 8 |
+
"source": []
|
| 9 |
+
},
|
| 10 |
+
{
|
| 11 |
+
"cell_type": "code",
|
| 12 |
+
"execution_count": null,
|
| 13 |
+
"metadata": {},
|
| 14 |
+
"outputs": [],
|
| 15 |
+
"source": []
|
| 16 |
+
}
|
| 17 |
+
],
|
| 18 |
+
"metadata": {
|
| 19 |
+
"language_info": {
|
| 20 |
+
"name": "python"
|
| 21 |
+
}
|
| 22 |
+
},
|
| 23 |
+
"nbformat": 4,
|
| 24 |
+
"nbformat_minor": 2
|
| 25 |
+
}
|
examples/{example_inference.py โ sample_inference.py}
RENAMED
|
@@ -8,7 +8,7 @@ project_root = Path(__file__).resolve().parent.parent
|
|
| 8 |
sys.path.append(str(project_root))
|
| 9 |
|
| 10 |
from yolo.config.config import Config
|
| 11 |
-
from yolo.model.yolo import
|
| 12 |
from yolo.tools.data_loader import create_dataloader
|
| 13 |
from yolo.tools.solver import ModelTester
|
| 14 |
from yolo.utils.logging_utils import custom_logger, validate_log_directory
|
|
@@ -17,15 +17,11 @@ from yolo.utils.logging_utils import custom_logger, validate_log_directory
|
|
| 17 |
@hydra.main(config_path="../yolo/config", config_name="config", version_base=None)
|
| 18 |
def main(cfg: Config):
|
| 19 |
custom_logger()
|
| 20 |
-
save_path = validate_log_directory(cfg, cfg.name)
|
| 21 |
-
|
| 22 |
-
device = torch.device(cfg.device)
|
| 23 |
-
model = get_model(cfg).to(device)
|
| 24 |
-
|
| 25 |
save_path = validate_log_directory(cfg, cfg.name)
|
| 26 |
dataloader = create_dataloader(cfg)
|
|
|
|
| 27 |
device = torch.device(cfg.device)
|
| 28 |
-
model =
|
| 29 |
|
| 30 |
tester = ModelTester(cfg, model, save_path, device)
|
| 31 |
tester.solve(dataloader)
|
|
|
|
| 8 |
sys.path.append(str(project_root))
|
| 9 |
|
| 10 |
from yolo.config.config import Config
|
| 11 |
+
from yolo.model.yolo import create_model
|
| 12 |
from yolo.tools.data_loader import create_dataloader
|
| 13 |
from yolo.tools.solver import ModelTester
|
| 14 |
from yolo.utils.logging_utils import custom_logger, validate_log_directory
|
|
|
|
| 17 |
@hydra.main(config_path="../yolo/config", config_name="config", version_base=None)
|
| 18 |
def main(cfg: Config):
|
| 19 |
custom_logger()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
save_path = validate_log_directory(cfg, cfg.name)
|
| 21 |
dataloader = create_dataloader(cfg)
|
| 22 |
+
|
| 23 |
device = torch.device(cfg.device)
|
| 24 |
+
model = create_model(cfg).to(device)
|
| 25 |
|
| 26 |
tester = ModelTester(cfg, model, save_path, device)
|
| 27 |
tester.solve(dataloader)
|
examples/{example_train.py โ sample_train.py}
RENAMED
|
@@ -8,7 +8,7 @@ project_root = Path(__file__).resolve().parent.parent
|
|
| 8 |
sys.path.append(str(project_root))
|
| 9 |
|
| 10 |
from yolo.config.config import Config
|
| 11 |
-
from yolo.model.yolo import
|
| 12 |
from yolo.tools.data_loader import create_dataloader
|
| 13 |
from yolo.tools.solver import ModelTrainer
|
| 14 |
from yolo.utils.logging_utils import custom_logger, validate_log_directory
|
|
@@ -21,7 +21,7 @@ def main(cfg: Config):
|
|
| 21 |
dataloader = create_dataloader(cfg)
|
| 22 |
# TODO: get_device or rank, for DDP mode
|
| 23 |
device = torch.device(cfg.device)
|
| 24 |
-
model =
|
| 25 |
|
| 26 |
trainer = ModelTrainer(cfg, model, save_path, device)
|
| 27 |
trainer.solve(dataloader, cfg.task.epoch)
|
|
|
|
| 8 |
sys.path.append(str(project_root))
|
| 9 |
|
| 10 |
from yolo.config.config import Config
|
| 11 |
+
from yolo.model.yolo import create_model
|
| 12 |
from yolo.tools.data_loader import create_dataloader
|
| 13 |
from yolo.tools.solver import ModelTrainer
|
| 14 |
from yolo.utils.logging_utils import custom_logger, validate_log_directory
|
|
|
|
| 21 |
dataloader = create_dataloader(cfg)
|
| 22 |
# TODO: get_device or rank, for DDP mode
|
| 23 |
device = torch.device(cfg.device)
|
| 24 |
+
model = create_model(cfg).to(device)
|
| 25 |
|
| 26 |
trainer = ModelTrainer(cfg, model, save_path, device)
|
| 27 |
trainer.solve(dataloader, cfg.task.epoch)
|
tests/test_model/test_yolo.py
CHANGED
|
@@ -8,7 +8,7 @@ from omegaconf import OmegaConf
|
|
| 8 |
project_root = Path(__file__).resolve().parent.parent.parent
|
| 9 |
sys.path.append(str(project_root))
|
| 10 |
|
| 11 |
-
from yolo.model.yolo import YOLO,
|
| 12 |
|
| 13 |
config_path = "../../yolo/config"
|
| 14 |
config_name = "config"
|
|
@@ -24,18 +24,18 @@ def test_build_model():
|
|
| 24 |
assert len(model.model) == 38
|
| 25 |
|
| 26 |
|
| 27 |
-
def
|
| 28 |
with initialize(config_path=config_path, version_base=None):
|
| 29 |
cfg = compose(config_name=config_name)
|
| 30 |
cfg.weight = None
|
| 31 |
-
model =
|
| 32 |
assert isinstance(model, YOLO)
|
| 33 |
|
| 34 |
|
| 35 |
def test_yolo_forward_output_shape():
|
| 36 |
with initialize(config_path=config_path, version_base=None):
|
| 37 |
cfg = compose(config_name=config_name)
|
| 38 |
-
model =
|
| 39 |
# 2 - batch size, 3 - number of channels, 640x640 - image dimensions
|
| 40 |
dummy_input = torch.rand(2, 3, 640, 640)
|
| 41 |
|
|
|
|
| 8 |
project_root = Path(__file__).resolve().parent.parent.parent
|
| 9 |
sys.path.append(str(project_root))
|
| 10 |
|
| 11 |
+
from yolo.model.yolo import YOLO, create_model
|
| 12 |
|
| 13 |
config_path = "../../yolo/config"
|
| 14 |
config_name = "config"
|
|
|
|
| 24 |
assert len(model.model) == 38
|
| 25 |
|
| 26 |
|
| 27 |
+
def test_create_model():
|
| 28 |
with initialize(config_path=config_path, version_base=None):
|
| 29 |
cfg = compose(config_name=config_name)
|
| 30 |
cfg.weight = None
|
| 31 |
+
model = create_model(cfg)
|
| 32 |
assert isinstance(model, YOLO)
|
| 33 |
|
| 34 |
|
| 35 |
def test_yolo_forward_output_shape():
|
| 36 |
with initialize(config_path=config_path, version_base=None):
|
| 37 |
cfg = compose(config_name=config_name)
|
| 38 |
+
model = create_model(cfg)
|
| 39 |
# 2 - batch size, 3 - number of channels, 640x640 - image dimensions
|
| 40 |
dummy_input = torch.rand(2, 3, 640, 640)
|
| 41 |
|
yolo/lazy.py
CHANGED
|
@@ -8,7 +8,7 @@ project_root = Path(__file__).resolve().parent.parent
|
|
| 8 |
sys.path.append(str(project_root))
|
| 9 |
|
| 10 |
from yolo.config.config import Config
|
| 11 |
-
from yolo.model.yolo import
|
| 12 |
from yolo.tools.data_loader import create_dataloader
|
| 13 |
from yolo.tools.solver import ModelTester, ModelTrainer
|
| 14 |
from yolo.utils.logging_utils import custom_logger, validate_log_directory
|
|
@@ -20,7 +20,7 @@ def main(cfg: Config):
|
|
| 20 |
save_path = validate_log_directory(cfg, cfg.name)
|
| 21 |
dataloader = create_dataloader(cfg)
|
| 22 |
device = torch.device(cfg.device)
|
| 23 |
-
model =
|
| 24 |
|
| 25 |
if cfg.task.task == "train":
|
| 26 |
trainer = ModelTrainer(cfg, model, save_path, device)
|
|
|
|
| 8 |
sys.path.append(str(project_root))
|
| 9 |
|
| 10 |
from yolo.config.config import Config
|
| 11 |
+
from yolo.model.yolo import create_model
|
| 12 |
from yolo.tools.data_loader import create_dataloader
|
| 13 |
from yolo.tools.solver import ModelTester, ModelTrainer
|
| 14 |
from yolo.utils.logging_utils import custom_logger, validate_log_directory
|
|
|
|
| 20 |
save_path = validate_log_directory(cfg, cfg.name)
|
| 21 |
dataloader = create_dataloader(cfg)
|
| 22 |
device = torch.device(cfg.device)
|
| 23 |
+
model = create_model(cfg).to(device)
|
| 24 |
|
| 25 |
if cfg.task.task == "train":
|
| 26 |
trainer = ModelTrainer(cfg, model, save_path, device)
|
yolo/model/yolo.py
CHANGED
|
@@ -116,7 +116,7 @@ class YOLO(nn.Module):
|
|
| 116 |
raise ValueError(f"Unsupported layer type: {layer_type}")
|
| 117 |
|
| 118 |
|
| 119 |
-
def
|
| 120 |
"""Constructs and returns a model from a Dictionary configuration file.
|
| 121 |
|
| 122 |
Args:
|
|
|
|
| 116 |
raise ValueError(f"Unsupported layer type: {layer_type}")
|
| 117 |
|
| 118 |
|
| 119 |
+
def create_model(cfg: Config) -> YOLO:
|
| 120 |
"""Constructs and returns a model from a Dictionary configuration file.
|
| 121 |
|
| 122 |
Args:
|
yolo/tools/drawer.py
CHANGED
|
@@ -14,6 +14,7 @@ def draw_bboxes(
|
|
| 14 |
*,
|
| 15 |
scaled_bbox: bool = True,
|
| 16 |
save_path: str = "",
|
|
|
|
| 17 |
):
|
| 18 |
"""
|
| 19 |
Draw bounding boxes on an image.
|
|
@@ -46,7 +47,7 @@ def draw_bboxes(
|
|
| 46 |
draw.rectangle(shape, outline="red", width=3)
|
| 47 |
draw.text((x_min, y_min), str(int(class_id)), font=font, fill="blue")
|
| 48 |
|
| 49 |
-
save_image_path = os.path.join(save_path,
|
| 50 |
img.save(save_image_path) # Save the image with annotations
|
| 51 |
logger.info(f"๐พ Saved visualize image at {save_image_path}")
|
| 52 |
return img
|
|
@@ -56,9 +57,9 @@ def draw_model(*, model_cfg=None, model=None, v7_base=False):
|
|
| 56 |
from graphviz import Digraph
|
| 57 |
|
| 58 |
if model_cfg:
|
| 59 |
-
from yolo.model.yolo import
|
| 60 |
|
| 61 |
-
model =
|
| 62 |
elif model is None:
|
| 63 |
raise ValueError("Drawing Object is None")
|
| 64 |
|
|
|
|
| 14 |
*,
|
| 15 |
scaled_bbox: bool = True,
|
| 16 |
save_path: str = "",
|
| 17 |
+
save_name: str = "visualize.png",
|
| 18 |
):
|
| 19 |
"""
|
| 20 |
Draw bounding boxes on an image.
|
|
|
|
| 47 |
draw.rectangle(shape, outline="red", width=3)
|
| 48 |
draw.text((x_min, y_min), str(int(class_id)), font=font, fill="blue")
|
| 49 |
|
| 50 |
+
save_image_path = os.path.join(save_path, save_name)
|
| 51 |
img.save(save_image_path) # Save the image with annotations
|
| 52 |
logger.info(f"๐พ Saved visualize image at {save_image_path}")
|
| 53 |
return img
|
|
|
|
| 57 |
from graphviz import Digraph
|
| 58 |
|
| 59 |
if model_cfg:
|
| 60 |
+
from yolo.model.yolo import create_model
|
| 61 |
|
| 62 |
+
model = create_model(model_cfg)
|
| 63 |
elif model is None:
|
| 64 |
raise ValueError("Drawing Object is None")
|
| 65 |
|