Spaces:
Sleeping
Sleeping
🚑️ [Fix] when creating onnx, trt force use cuda
Browse files
yolo/utils/deploy_utils.py
CHANGED
|
@@ -51,8 +51,8 @@ class FastModelLoader:
|
|
| 51 |
from onnxruntime import InferenceSession
|
| 52 |
from torch.onnx import export
|
| 53 |
|
| 54 |
-
model = create_model(self.cfg).eval()
|
| 55 |
-
dummy_input = torch.ones((1, 3, *self.cfg.image_size))
|
| 56 |
export(
|
| 57 |
model,
|
| 58 |
dummy_input,
|
|
@@ -80,8 +80,8 @@ class FastModelLoader:
|
|
| 80 |
def _create_trt_weight(self):
|
| 81 |
from torch2trt import torch2trt
|
| 82 |
|
| 83 |
-
model = create_model(self.cfg).eval()
|
| 84 |
-
dummy_input = torch.ones((1, 3, *self.cfg.image_size))
|
| 85 |
logger.info(f"♻️ Creating TensorRT model")
|
| 86 |
model_trt = torch2trt(model, [dummy_input])
|
| 87 |
torch.save(model_trt.state_dict(), self.weight)
|
|
|
|
| 51 |
from onnxruntime import InferenceSession
|
| 52 |
from torch.onnx import export
|
| 53 |
|
| 54 |
+
model = create_model(self.cfg).eval()
|
| 55 |
+
dummy_input = torch.ones((1, 3, *self.cfg.image_size))
|
| 56 |
export(
|
| 57 |
model,
|
| 58 |
dummy_input,
|
|
|
|
| 80 |
def _create_trt_weight(self):
|
| 81 |
from torch2trt import torch2trt
|
| 82 |
|
| 83 |
+
model = create_model(self.cfg).eval()
|
| 84 |
+
dummy_input = torch.ones((1, 3, *self.cfg.image_size))
|
| 85 |
logger.info(f"♻️ Creating TensorRT model")
|
| 86 |
model_trt = torch2trt(model, [dummy_input])
|
| 87 |
torch.save(model_trt.state_dict(), self.weight)
|