Spaces:
Paused
Paused
Commit
·
e2ce7e5
1
Parent(s):
327742a
Remove redundant part of our_ref in inference.
Browse files- app.py +1 -1
- models/baseline.py +24 -22
app.py
CHANGED
|
@@ -35,7 +35,7 @@ class ImagePreprocessor():
|
|
| 35 |
return image
|
| 36 |
|
| 37 |
|
| 38 |
-
model = BiRefNet().to(device)
|
| 39 |
state_dict = './BiRefNet_ep580.pth'
|
| 40 |
if os.path.exists(state_dict):
|
| 41 |
birefnet_dict = torch.load(state_dict, map_location=device)
|
|
|
|
| 35 |
return image
|
| 36 |
|
| 37 |
|
| 38 |
+
model = BiRefNet(bb_pretrained=False).to(device)
|
| 39 |
state_dict = './BiRefNet_ep580.pth'
|
| 40 |
if os.path.exists(state_dict):
|
| 41 |
birefnet_dict = torch.load(state_dict, map_location=device)
|
models/baseline.py
CHANGED
|
@@ -20,11 +20,11 @@ from models.refinement.stem_layer import StemLayer
|
|
| 20 |
|
| 21 |
|
| 22 |
class BiRefNet(nn.Module):
|
| 23 |
-
def __init__(self):
|
| 24 |
super(BiRefNet, self).__init__()
|
| 25 |
self.config = Config()
|
| 26 |
self.epoch = 1
|
| 27 |
-
self.bb = build_backbone(self.config.bb, pretrained=
|
| 28 |
|
| 29 |
channels = self.config.lateral_channels_in_collection
|
| 30 |
|
|
@@ -126,7 +126,7 @@ class BiRefNet(nn.Module):
|
|
| 126 |
x4 = self.squeeze_module(x4)
|
| 127 |
########## Decoder ##########
|
| 128 |
features = [x, x1, x2, x3, x4]
|
| 129 |
-
if self.config.out_ref:
|
| 130 |
features.append(laplacian(torch.mean(x, dim=1).unsqueeze(1), kernel_size=5))
|
| 131 |
scaled_preds = self.decoder(features)
|
| 132 |
return scaled_preds, class_preds
|
|
@@ -231,7 +231,7 @@ class Decoder(nn.Module):
|
|
| 231 |
return torch.cat(patches_batch, dim=0)
|
| 232 |
|
| 233 |
def forward(self, features):
|
| 234 |
-
if self.config.out_ref:
|
| 235 |
outs_gdt_pred = []
|
| 236 |
outs_gdt_label = []
|
| 237 |
x, x1, x2, x3, x4, gdt_gt = features
|
|
@@ -249,18 +249,19 @@ class Decoder(nn.Module):
|
|
| 249 |
p3 = self.decoder_block3(_p3)
|
| 250 |
m3 = self.conv_ms_spvn_3(p3) if self.config.ms_supervision else None
|
| 251 |
if self.config.out_ref:
|
| 252 |
-
# >> GT:
|
| 253 |
-
# m3 --dilation--> m3_dia
|
| 254 |
-
# G_3^gt * m3_dia --> G_3^m, which is the label of gradient
|
| 255 |
-
m3_dia = m3
|
| 256 |
-
gdt_label_main_3 = gdt_gt * F.interpolate(m3_dia, size=gdt_gt.shape[2:], mode='bilinear', align_corners=True)
|
| 257 |
-
outs_gdt_label.append(gdt_label_main_3)
|
| 258 |
-
# >> Pred:
|
| 259 |
-
# p3 --conv--BN--> F_3^G, where F_3^G predicts the \hat{G_3} with xx
|
| 260 |
-
# F_3^G --sigmoid--> A_3^G
|
| 261 |
p3_gdt = self.gdt_convs_3(p3)
|
| 262 |
-
|
| 263 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
gdt_attn_3 = self.gdt_convs_attn_3(p3_gdt).sigmoid()
|
| 265 |
# >> Finally:
|
| 266 |
# p3 = p3 * A_3^G
|
|
@@ -274,14 +275,15 @@ class Decoder(nn.Module):
|
|
| 274 |
p2 = self.decoder_block2(_p2)
|
| 275 |
m2 = self.conv_ms_spvn_2(p2) if self.config.ms_supervision else None
|
| 276 |
if self.config.out_ref:
|
| 277 |
-
# >> GT:
|
| 278 |
-
m2_dia = m2
|
| 279 |
-
gdt_label_main_2 = gdt_gt * F.interpolate(m2_dia, size=gdt_gt.shape[2:], mode='bilinear', align_corners=True)
|
| 280 |
-
outs_gdt_label.append(gdt_label_main_2)
|
| 281 |
-
# >> Pred:
|
| 282 |
p2_gdt = self.gdt_convs_2(p2)
|
| 283 |
-
|
| 284 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 285 |
gdt_attn_2 = self.gdt_convs_attn_2(p2_gdt).sigmoid()
|
| 286 |
# >> Finally:
|
| 287 |
p2 = p2 * gdt_attn_2
|
|
|
|
| 20 |
|
| 21 |
|
| 22 |
class BiRefNet(nn.Module):
|
| 23 |
+
def __init__(self, bb_pretrained=True):
|
| 24 |
super(BiRefNet, self).__init__()
|
| 25 |
self.config = Config()
|
| 26 |
self.epoch = 1
|
| 27 |
+
self.bb = build_backbone(self.config.bb, pretrained=bb_pretrained)
|
| 28 |
|
| 29 |
channels = self.config.lateral_channels_in_collection
|
| 30 |
|
|
|
|
| 126 |
x4 = self.squeeze_module(x4)
|
| 127 |
########## Decoder ##########
|
| 128 |
features = [x, x1, x2, x3, x4]
|
| 129 |
+
if self.training and self.config.out_ref:
|
| 130 |
features.append(laplacian(torch.mean(x, dim=1).unsqueeze(1), kernel_size=5))
|
| 131 |
scaled_preds = self.decoder(features)
|
| 132 |
return scaled_preds, class_preds
|
|
|
|
| 231 |
return torch.cat(patches_batch, dim=0)
|
| 232 |
|
| 233 |
def forward(self, features):
|
| 234 |
+
if self.training and self.config.out_ref:
|
| 235 |
outs_gdt_pred = []
|
| 236 |
outs_gdt_label = []
|
| 237 |
x, x1, x2, x3, x4, gdt_gt = features
|
|
|
|
| 249 |
p3 = self.decoder_block3(_p3)
|
| 250 |
m3 = self.conv_ms_spvn_3(p3) if self.config.ms_supervision else None
|
| 251 |
if self.config.out_ref:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 252 |
p3_gdt = self.gdt_convs_3(p3)
|
| 253 |
+
if self.training:
|
| 254 |
+
# >> GT:
|
| 255 |
+
# m3 --dilation--> m3_dia
|
| 256 |
+
# G_3^gt * m3_dia --> G_3^m, which is the label of gradient
|
| 257 |
+
m3_dia = m3
|
| 258 |
+
gdt_label_main_3 = gdt_gt * F.interpolate(m3_dia, size=gdt_gt.shape[2:], mode='bilinear', align_corners=True)
|
| 259 |
+
outs_gdt_label.append(gdt_label_main_3)
|
| 260 |
+
# >> Pred:
|
| 261 |
+
# p3 --conv--BN--> F_3^G, where F_3^G predicts the \hat{G_3} with xx
|
| 262 |
+
# F_3^G --sigmoid--> A_3^G
|
| 263 |
+
gdt_pred_3 = self.gdt_convs_pred_3(p3_gdt)
|
| 264 |
+
outs_gdt_pred.append(gdt_pred_3)
|
| 265 |
gdt_attn_3 = self.gdt_convs_attn_3(p3_gdt).sigmoid()
|
| 266 |
# >> Finally:
|
| 267 |
# p3 = p3 * A_3^G
|
|
|
|
| 275 |
p2 = self.decoder_block2(_p2)
|
| 276 |
m2 = self.conv_ms_spvn_2(p2) if self.config.ms_supervision else None
|
| 277 |
if self.config.out_ref:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 278 |
p2_gdt = self.gdt_convs_2(p2)
|
| 279 |
+
if self.training:
|
| 280 |
+
# >> GT:
|
| 281 |
+
m2_dia = m2
|
| 282 |
+
gdt_label_main_2 = gdt_gt * F.interpolate(m2_dia, size=gdt_gt.shape[2:], mode='bilinear', align_corners=True)
|
| 283 |
+
outs_gdt_label.append(gdt_label_main_2)
|
| 284 |
+
# >> Pred:
|
| 285 |
+
gdt_pred_2 = self.gdt_convs_pred_2(p2_gdt)
|
| 286 |
+
outs_gdt_pred.append(gdt_pred_2)
|
| 287 |
gdt_attn_2 = self.gdt_convs_attn_2(p2_gdt).sigmoid()
|
| 288 |
# >> Finally:
|
| 289 |
p2 = p2 * gdt_attn_2
|