Spaces:
Runtime error
Runtime error
Update registry.py
Browse files- registry.py +33 -43
registry.py
CHANGED
|
@@ -1,43 +1,33 @@
|
|
| 1 |
-
import logging
|
| 2 |
-
|
| 3 |
-
# Configure Logger
|
| 4 |
-
logger = logging.getLogger(__name__)
|
| 5 |
-
|
| 6 |
-
def get_model(task: str, model_key: str, device="cpu"):
|
| 7 |
-
"""
|
| 8 |
-
Dynamically retrieves the model instance based on the task and model_key.
|
| 9 |
-
|
| 10 |
-
Args:
|
| 11 |
-
task (str): One of "detection", "segmentation", or "depth".
|
| 12 |
-
model_key (str): Model identifier or variant.
|
| 13 |
-
device (str): Device to run inference on ("cpu" or "cuda").
|
| 14 |
-
|
| 15 |
-
Returns:
|
| 16 |
-
object:
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
""
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
return DepthEstimator(model_key=model_key, device=device)
|
| 35 |
-
|
| 36 |
-
else:
|
| 37 |
-
error_msg = f"Unsupported task '{task}'. Valid options are: 'detection', 'segmentation', 'depth'."
|
| 38 |
-
logger.error(error_msg)
|
| 39 |
-
raise ValueError(error_msg)
|
| 40 |
-
|
| 41 |
-
except Exception as e:
|
| 42 |
-
logger.error(f"Error while loading model '{model_key}' for task '{task}': {e}")
|
| 43 |
-
raise
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
|
| 3 |
+
# Configure Logger
|
| 4 |
+
logger = logging.getLogger(__name__)
|
| 5 |
+
|
| 6 |
+
def get_model(task: str, model_key: str, device="cpu"):
|
| 7 |
+
"""
|
| 8 |
+
Dynamically retrieves the model instance based on the task and model_key.
|
| 9 |
+
|
| 10 |
+
Args:
|
| 11 |
+
task (str): One of "detection", "segmentation", or "depth".
|
| 12 |
+
model_key (str): Model identifier or variant.
|
| 13 |
+
device (str): Device to run inference on ("cpu" or "cuda").
|
| 14 |
+
|
| 15 |
+
Returns:
|
| 16 |
+
object: Uninitialized model instance.
|
| 17 |
+
"""
|
| 18 |
+
logger.info(f"Preparing model wrapper '{model_key}' for task '{task}' on device '{device}'")
|
| 19 |
+
|
| 20 |
+
if task == "detection":
|
| 21 |
+
from models.detection.detector import ObjectDetector
|
| 22 |
+
return ObjectDetector(model_key=model_key, device=device)
|
| 23 |
+
|
| 24 |
+
elif task == "segmentation":
|
| 25 |
+
from models.segmentation.segmenter import Segmenter
|
| 26 |
+
return Segmenter(model_key=model_key, device=device)
|
| 27 |
+
|
| 28 |
+
elif task == "depth":
|
| 29 |
+
from models.depth.depth_estimator import DepthEstimator
|
| 30 |
+
return DepthEstimator(model_key=model_key, device=device)
|
| 31 |
+
|
| 32 |
+
else:
|
| 33 |
+
raise ValueError(f"Unsupported task '{task}'")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|