[Add] deploy option, auto remove aux head
Browse files
examples/notebook_inference.ipynb
CHANGED
|
@@ -43,8 +43,8 @@
|
|
| 43 |
"outputs": [],
|
| 44 |
"source": [
|
| 45 |
"with initialize(config_path=CONFIG_PATH, version_base=None, job_name=\"notebook_job\"):\n",
|
| 46 |
-
" cfg: Config = compose(config_name=CONFIG_NAME, overrides=[\"task=inference\", f\"task.data.source={IMAGE_PATH}\", \"model=v9-
|
| 47 |
-
" model = create_model(cfg.model, class_num=CLASS_NUM, weight_path=WEIGHT_PATH
|
| 48 |
" transform = AugmentationComposer([], cfg.image_size)\n",
|
| 49 |
" vec2box = Vec2Box(model, cfg.image_size, device)"
|
| 50 |
]
|
|
@@ -70,7 +70,7 @@
|
|
| 70 |
" predict = vec2box(predict[\"Main\"])\n",
|
| 71 |
"\n",
|
| 72 |
"predict_box = bbox_nms(predict[0], predict[2], cfg.task.nms)\n",
|
| 73 |
-
"draw_bboxes(image, predict_box,
|
| 74 |
]
|
| 75 |
},
|
| 76 |
{
|
|
@@ -81,13 +81,6 @@
|
|
| 81 |
"\n",
|
| 82 |
""
|
| 83 |
]
|
| 84 |
-
},
|
| 85 |
-
{
|
| 86 |
-
"cell_type": "code",
|
| 87 |
-
"execution_count": null,
|
| 88 |
-
"metadata": {},
|
| 89 |
-
"outputs": [],
|
| 90 |
-
"source": []
|
| 91 |
}
|
| 92 |
],
|
| 93 |
"metadata": {
|
|
|
|
| 43 |
"outputs": [],
|
| 44 |
"source": [
|
| 45 |
"with initialize(config_path=CONFIG_PATH, version_base=None, job_name=\"notebook_job\"):\n",
|
| 46 |
+
" cfg: Config = compose(config_name=CONFIG_NAME, overrides=[\"task=inference\", f\"task.data.source={IMAGE_PATH}\", \"model=v9-m\"])\n",
|
| 47 |
+
" model = create_model(cfg.model, class_num=CLASS_NUM, weight_path=WEIGHT_PATH, device = device)\n",
|
| 48 |
" transform = AugmentationComposer([], cfg.image_size)\n",
|
| 49 |
" vec2box = Vec2Box(model, cfg.image_size, device)"
|
| 50 |
]
|
|
|
|
| 70 |
" predict = vec2box(predict[\"Main\"])\n",
|
| 71 |
"\n",
|
| 72 |
"predict_box = bbox_nms(predict[0], predict[2], cfg.task.nms)\n",
|
| 73 |
+
"draw_bboxes(image, predict_box, idx2label=cfg.class_list)"
|
| 74 |
]
|
| 75 |
},
|
| 76 |
{
|
|
|
|
| 81 |
"\n",
|
| 82 |
""
|
| 83 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
}
|
| 85 |
],
|
| 86 |
"metadata": {
|
yolo/utils/deploy_utils.py
CHANGED
|
@@ -9,14 +9,15 @@ from yolo.model.yolo import create_model
|
|
| 9 |
|
| 10 |
|
| 11 |
class FastModelLoader:
|
| 12 |
-
def __init__(self, cfg: Config):
|
| 13 |
self.cfg = cfg
|
|
|
|
| 14 |
self.compiler = cfg.task.fast_inference
|
| 15 |
self._validate_compiler()
|
| 16 |
self.model_path = f"{os.path.splitext(cfg.weight)[0]}.{self.compiler}"
|
| 17 |
|
| 18 |
def _validate_compiler(self):
|
| 19 |
-
if self.compiler not in ["onnx", "trt"]:
|
| 20 |
logger.warning(f"⚠️ Compiler '{self.compiler}' is not supported. Using original model.")
|
| 21 |
self.compiler = None
|
| 22 |
if self.cfg.device == "mps" and self.compiler == "trt":
|
|
@@ -28,7 +29,11 @@ class FastModelLoader:
|
|
| 28 |
return self._load_onnx_model()
|
| 29 |
elif self.compiler == "trt":
|
| 30 |
return self._load_trt_model()
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
def _load_onnx_model(self):
|
| 34 |
from onnxruntime import InferenceSession
|
|
|
|
| 9 |
|
| 10 |
|
| 11 |
class FastModelLoader:
|
| 12 |
+
def __init__(self, cfg: Config, device):
|
| 13 |
self.cfg = cfg
|
| 14 |
+
self.device = device
|
| 15 |
self.compiler = cfg.task.fast_inference
|
| 16 |
self._validate_compiler()
|
| 17 |
self.model_path = f"{os.path.splitext(cfg.weight)[0]}.{self.compiler}"
|
| 18 |
|
| 19 |
def _validate_compiler(self):
|
| 20 |
+
if self.compiler not in ["onnx", "trt", "deploy"]:
|
| 21 |
logger.warning(f"⚠️ Compiler '{self.compiler}' is not supported. Using original model.")
|
| 22 |
self.compiler = None
|
| 23 |
if self.cfg.device == "mps" and self.compiler == "trt":
|
|
|
|
| 29 |
return self._load_onnx_model()
|
| 30 |
elif self.compiler == "trt":
|
| 31 |
return self._load_trt_model()
|
| 32 |
+
elif self.compiler == "deploy":
|
| 33 |
+
self.cfg.model.model.auxiliary = {}
|
| 34 |
+
return create_model(
|
| 35 |
+
self.cfg.model, class_num=self.cfg.class_num, weight_path=self.cfg.weight, device=self.device
|
| 36 |
+
)
|
| 37 |
|
| 38 |
def _load_onnx_model(self):
|
| 39 |
from onnxruntime import InferenceSession
|