Spaces:
Sleeping
Sleeping
🚸 [Add] converter function, auto choose conveter
Browse files- examples/notebook_inference.ipynb +10 -21
- yolo/__init__.py +2 -1
- yolo/lazy.py +6 -5
- yolo/utils/bounding_box_utils.py +9 -1
examples/notebook_inference.ipynb
CHANGED
|
@@ -16,8 +16,15 @@
|
|
| 16 |
"project_root = Path().resolve().parent\n",
|
| 17 |
"sys.path.append(str(project_root))\n",
|
| 18 |
"\n",
|
| 19 |
-
"from yolo import
|
| 20 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
]
|
| 22 |
},
|
| 23 |
{
|
|
@@ -48,8 +55,7 @@
|
|
| 48 |
" cfg: Config = compose(config_name=CONFIG_NAME, overrides=[\"task=inference\", f\"task.data.source={IMAGE_PATH}\", f\"model={MODEL}\"])\n",
|
| 49 |
" model = create_model(cfg.model, class_num=CLASS_NUM).to(device)\n",
|
| 50 |
" transform = AugmentationComposer([], cfg.image_size)\n",
|
| 51 |
-
" converter =
|
| 52 |
-
" # converter = Vec2Box(model, cfg.model.anchor, cfg.image_size, device)\n",
|
| 53 |
" post_proccess = PostProccess(converter, cfg.task.nms)"
|
| 54 |
]
|
| 55 |
},
|
|
@@ -86,23 +92,6 @@
|
|
| 86 |
"\n",
|
| 87 |
""
|
| 88 |
]
|
| 89 |
-
},
|
| 90 |
-
{
|
| 91 |
-
"cell_type": "code",
|
| 92 |
-
"execution_count": null,
|
| 93 |
-
"metadata": {},
|
| 94 |
-
"outputs": [],
|
| 95 |
-
"source": [
|
| 96 |
-
"%load_ext autoreload\n",
|
| 97 |
-
"%autoreload 2"
|
| 98 |
-
]
|
| 99 |
-
},
|
| 100 |
-
{
|
| 101 |
-
"cell_type": "code",
|
| 102 |
-
"execution_count": null,
|
| 103 |
-
"metadata": {},
|
| 104 |
-
"outputs": [],
|
| 105 |
-
"source": []
|
| 106 |
}
|
| 107 |
],
|
| 108 |
"metadata": {
|
|
|
|
| 16 |
"project_root = Path().resolve().parent\n",
|
| 17 |
"sys.path.append(str(project_root))\n",
|
| 18 |
"\n",
|
| 19 |
+
"from yolo import (\n",
|
| 20 |
+
" AugmentationComposer,\n",
|
| 21 |
+
" Config,\n",
|
| 22 |
+
" PostProccess,\n",
|
| 23 |
+
" create_converter,\n",
|
| 24 |
+
" create_model,\n",
|
| 25 |
+
" custom_logger,\n",
|
| 26 |
+
" draw_bboxes,\n",
|
| 27 |
+
")"
|
| 28 |
]
|
| 29 |
},
|
| 30 |
{
|
|
|
|
| 55 |
" cfg: Config = compose(config_name=CONFIG_NAME, overrides=[\"task=inference\", f\"task.data.source={IMAGE_PATH}\", f\"model={MODEL}\"])\n",
|
| 56 |
" model = create_model(cfg.model, class_num=CLASS_NUM).to(device)\n",
|
| 57 |
" transform = AugmentationComposer([], cfg.image_size)\n",
|
| 58 |
+
" converter = create_converter(cfg.model.name, model, cfg.model.anchor, cfg.image_size, device)\n",
|
|
|
|
| 59 |
" post_proccess = PostProccess(converter, cfg.task.nms)"
|
| 60 |
]
|
| 61 |
},
|
|
|
|
| 92 |
"\n",
|
| 93 |
""
|
| 94 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
}
|
| 96 |
],
|
| 97 |
"metadata": {
|
yolo/__init__.py
CHANGED
|
@@ -3,7 +3,7 @@ from yolo.model.yolo import create_model
|
|
| 3 |
from yolo.tools.data_loader import AugmentationComposer, create_dataloader
|
| 4 |
from yolo.tools.drawer import draw_bboxes
|
| 5 |
from yolo.tools.solver import ModelTester, ModelTrainer, ModelValidator
|
| 6 |
-
from yolo.utils.bounding_box_utils import Anc2Box, Vec2Box, bbox_nms
|
| 7 |
from yolo.utils.deploy_utils import FastModelLoader
|
| 8 |
from yolo.utils.logging_utils import custom_logger
|
| 9 |
from yolo.utils.model_utils import PostProccess
|
|
@@ -18,6 +18,7 @@ all = [
|
|
| 18 |
"Vec2Box",
|
| 19 |
"Anc2Box",
|
| 20 |
"bbox_nms",
|
|
|
|
| 21 |
"AugmentationComposer",
|
| 22 |
"create_dataloader",
|
| 23 |
"FastModelLoader",
|
|
|
|
| 3 |
from yolo.tools.data_loader import AugmentationComposer, create_dataloader
|
| 4 |
from yolo.tools.drawer import draw_bboxes
|
| 5 |
from yolo.tools.solver import ModelTester, ModelTrainer, ModelValidator
|
| 6 |
+
from yolo.utils.bounding_box_utils import Anc2Box, Vec2Box, bbox_nms, create_converter
|
| 7 |
from yolo.utils.deploy_utils import FastModelLoader
|
| 8 |
from yolo.utils.logging_utils import custom_logger
|
| 9 |
from yolo.utils.model_utils import PostProccess
|
|
|
|
| 18 |
"Vec2Box",
|
| 19 |
"Anc2Box",
|
| 20 |
"bbox_nms",
|
| 21 |
+
"create_converter",
|
| 22 |
"AugmentationComposer",
|
| 23 |
"create_dataloader",
|
| 24 |
"FastModelLoader",
|
yolo/lazy.py
CHANGED
|
@@ -10,7 +10,7 @@ from yolo.config.config import Config
|
|
| 10 |
from yolo.model.yolo import create_model
|
| 11 |
from yolo.tools.data_loader import create_dataloader
|
| 12 |
from yolo.tools.solver import ModelTester, ModelTrainer, ModelValidator
|
| 13 |
-
from yolo.utils.bounding_box_utils import
|
| 14 |
from yolo.utils.deploy_utils import FastModelLoader
|
| 15 |
from yolo.utils.logging_utils import ProgressLogger
|
| 16 |
from yolo.utils.model_utils import get_device
|
|
@@ -27,13 +27,14 @@ def main(cfg: Config):
|
|
| 27 |
model = create_model(cfg.model, class_num=cfg.class_num, weight_path=cfg.weight)
|
| 28 |
model = model.to(device)
|
| 29 |
|
| 30 |
-
|
|
|
|
| 31 |
if cfg.task.task == "train":
|
| 32 |
-
solver = ModelTrainer(cfg, model,
|
| 33 |
if cfg.task.task == "validation":
|
| 34 |
-
solver = ModelValidator(cfg.task, cfg.dataset, model,
|
| 35 |
if cfg.task.task == "inference":
|
| 36 |
-
solver = ModelTester(cfg, model,
|
| 37 |
progress.start()
|
| 38 |
solver.solve(dataloader)
|
| 39 |
|
|
|
|
| 10 |
from yolo.model.yolo import create_model
|
| 11 |
from yolo.tools.data_loader import create_dataloader
|
| 12 |
from yolo.tools.solver import ModelTester, ModelTrainer, ModelValidator
|
| 13 |
+
from yolo.utils.bounding_box_utils import create_converter
|
| 14 |
from yolo.utils.deploy_utils import FastModelLoader
|
| 15 |
from yolo.utils.logging_utils import ProgressLogger
|
| 16 |
from yolo.utils.model_utils import get_device
|
|
|
|
| 27 |
model = create_model(cfg.model, class_num=cfg.class_num, weight_path=cfg.weight)
|
| 28 |
model = model.to(device)
|
| 29 |
|
| 30 |
+
converter = create_converter(cfg.model.name, model, cfg.model.anchor, cfg.image_size, device)
|
| 31 |
+
|
| 32 |
if cfg.task.task == "train":
|
| 33 |
+
solver = ModelTrainer(cfg, model, converter, progress, device, use_ddp)
|
| 34 |
if cfg.task.task == "validation":
|
| 35 |
+
solver = ModelValidator(cfg.task, cfg.dataset, model, converter, progress, device)
|
| 36 |
if cfg.task.task == "inference":
|
| 37 |
+
solver = ModelTester(cfg, model, converter, progress, device)
|
| 38 |
progress.start()
|
| 39 |
solver.solve(dataloader)
|
| 40 |
|
yolo/utils/bounding_box_utils.py
CHANGED
|
@@ -364,7 +364,15 @@ class Anc2Box:
|
|
| 364 |
return preds_cls, None, preds_box, preds_cnf.sigmoid()
|
| 365 |
|
| 366 |
|
| 367 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 368 |
cls_dist = cls_dist.sigmoid() * (1 if confidence is None else confidence)
|
| 369 |
|
| 370 |
# filter class by confidence
|
|
|
|
| 364 |
return preds_cls, None, preds_box, preds_cnf.sigmoid()
|
| 365 |
|
| 366 |
|
| 367 |
+
def create_converter(model_version: str = "v9-c", *args, **kwargs):
|
| 368 |
+
if "v7" in model_version: # check model if v7
|
| 369 |
+
converter = Anc2Box(*args, **kwargs)
|
| 370 |
+
else:
|
| 371 |
+
converter = Vec2Box(*args, **kwargs)
|
| 372 |
+
return converter
|
| 373 |
+
|
| 374 |
+
|
| 375 |
+
def bbox_nms(cls_dist: Tensor, bbox: Tensor, nms_cfg: NMSConfig, confidence: Optional[Tensor] = None):
|
| 376 |
cls_dist = cls_dist.sigmoid() * (1 if confidence is None else confidence)
|
| 377 |
|
| 378 |
# filter class by confidence
|