🔨 [Add] weight initailize for Detection
Browse files- yolo/model/module.py +4 -1
- yolo/tools/trainer.py +1 -1
yolo/model/module.py
CHANGED
|
@@ -25,7 +25,7 @@ class Conv(nn.Module):
|
|
| 25 |
super().__init__()
|
| 26 |
kwargs.setdefault("padding", auto_pad(kernel_size, **kwargs))
|
| 27 |
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, bias=False, **kwargs)
|
| 28 |
-
self.bn = nn.BatchNorm2d(out_channels)
|
| 29 |
self.act = get_activation(activation)
|
| 30 |
|
| 31 |
def forward(self, x: Tensor) -> Tensor:
|
|
@@ -69,6 +69,9 @@ class Detection(nn.Module):
|
|
| 69 |
Conv(in_channels, class_neck, 3), Conv(class_neck, class_neck, 3), nn.Conv2d(class_neck, num_classes, 1)
|
| 70 |
)
|
| 71 |
|
|
|
|
|
|
|
|
|
|
| 72 |
def forward(self, x: List[Tensor]) -> List[Tensor]:
|
| 73 |
anchor_x = self.anchor_conv(x)
|
| 74 |
class_x = self.class_conv(x)
|
|
|
|
| 25 |
super().__init__()
|
| 26 |
kwargs.setdefault("padding", auto_pad(kernel_size, **kwargs))
|
| 27 |
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, bias=False, **kwargs)
|
| 28 |
+
self.bn = nn.BatchNorm2d(out_channels, eps=1e-3, momentum=3e-2)
|
| 29 |
self.act = get_activation(activation)
|
| 30 |
|
| 31 |
def forward(self, x: Tensor) -> Tensor:
|
|
|
|
| 69 |
Conv(in_channels, class_neck, 3), Conv(class_neck, class_neck, 3), nn.Conv2d(class_neck, num_classes, 1)
|
| 70 |
)
|
| 71 |
|
| 72 |
+
self.anchor_conv[-1].bias.data.fill_(1.0)
|
| 73 |
+
self.class_conv[-1].bias.data.fill_(-10)
|
| 74 |
+
|
| 75 |
def forward(self, x: List[Tensor]) -> List[Tensor]:
|
| 76 |
anchor_x = self.anchor_conv(x)
|
| 77 |
class_x = self.class_conv(x)
|
yolo/tools/trainer.py
CHANGED
|
@@ -79,7 +79,7 @@ class Trainer:
|
|
| 79 |
self.progress.start_train(num_epochs)
|
| 80 |
for epoch in range(num_epochs):
|
| 81 |
|
| 82 |
-
epoch_loss = self.train_one_epoch(dataloader
|
| 83 |
self.progress.one_epoch()
|
| 84 |
|
| 85 |
logger.info(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
|
|
|
|
| 79 |
self.progress.start_train(num_epochs)
|
| 80 |
for epoch in range(num_epochs):
|
| 81 |
|
| 82 |
+
epoch_loss = self.train_one_epoch(dataloader)
|
| 83 |
self.progress.one_epoch()
|
| 84 |
|
| 85 |
logger.info(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
|