Spaces:
Sleeping
Sleeping
✅ [Pass] Train, Model, Loss Test
Browse files- tests/test_utils/test_loss.py +2 -3
- yolo/utils/loss.py +1 -1
tests/test_utils/test_loss.py
CHANGED
|
@@ -27,14 +27,13 @@ def loss_function(cfg) -> YOLOLoss:
|
|
| 27 |
def data():
|
| 28 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 29 |
targets = torch.zeros(1, 20, 5, device=device)
|
| 30 |
-
predicts = [
|
| 31 |
return predicts, targets
|
| 32 |
|
| 33 |
|
| 34 |
def test_yolo_loss(loss_function, data):
|
| 35 |
predicts, targets = data
|
| 36 |
-
|
| 37 |
-
assert torch.isnan(loss)
|
| 38 |
assert torch.isnan(loss_iou)
|
| 39 |
assert torch.isnan(loss_dfl)
|
| 40 |
assert torch.isinf(loss_cls)
|
|
|
|
| 27 |
def data():
|
| 28 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 29 |
targets = torch.zeros(1, 20, 5, device=device)
|
| 30 |
+
predicts = [torch.zeros(1, 144, 80 // i, 80 // i, device=device) for i in [1, 2, 4]]
|
| 31 |
return predicts, targets
|
| 32 |
|
| 33 |
|
| 34 |
def test_yolo_loss(loss_function, data):
|
| 35 |
predicts, targets = data
|
| 36 |
+
loss_iou, loss_dfl, loss_cls = loss_function(predicts, targets)
|
|
|
|
| 37 |
assert torch.isnan(loss_iou)
|
| 38 |
assert torch.isnan(loss_dfl)
|
| 39 |
assert torch.isinf(loss_cls)
|
yolo/utils/loss.py
CHANGED
|
@@ -80,7 +80,7 @@ class YOLOLoss:
|
|
| 80 |
self.strides = cfg.model.anchor.strides
|
| 81 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 82 |
|
| 83 |
-
self.reverse_reg = torch.arange(self.reg_max, dtype=torch.
|
| 84 |
self.scale_up = torch.tensor(self.image_size * 2, device=device)
|
| 85 |
|
| 86 |
self.anchors, self.scaler = make_anchor(self.image_size, self.strides, device)
|
|
|
|
| 80 |
self.strides = cfg.model.anchor.strides
|
| 81 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 82 |
|
| 83 |
+
self.reverse_reg = torch.arange(self.reg_max, dtype=torch.float32, device=device)
|
| 84 |
self.scale_up = torch.tensor(self.image_size * 2, device=device)
|
| 85 |
|
| 86 |
self.anchors, self.scaler = make_anchor(self.image_size, self.strides, device)
|