Upload PDFNet.py
Browse files- models/PDFNet.py +97 -27
models/PDFNet.py
CHANGED
|
@@ -18,12 +18,12 @@ class PDF_depth_decoder(nn.Module):
|
|
| 18 |
|
| 19 |
emb_dim = 128
|
| 20 |
self.Decoder = nn.ModuleList()
|
| 21 |
-
self.Decoder.append(nn.Sequential(make_crs(emb_dim*2,emb_dim*2),make_crs(emb_dim*2,emb_dim)))
|
| 22 |
-
self.Decoder.append(nn.Sequential(make_crs(emb_dim*
|
| 23 |
-
self.Decoder.append(nn.Sequential(make_crs(emb_dim*
|
| 24 |
-
self.Decoder.append(nn.Sequential(make_crs(emb_dim*
|
| 25 |
|
| 26 |
-
self.shallow = nn.Sequential(nn.Conv2d(raw_ch, emb_dim, kernel_size=3, stride=1, padding=1))
|
| 27 |
self.upsample1 = make_crs(emb_dim,emb_dim)
|
| 28 |
self.upsample2 = make_crs(emb_dim,emb_dim)
|
| 29 |
|
|
@@ -34,7 +34,7 @@ class PDF_depth_decoder(nn.Module):
|
|
| 34 |
self.Bside.append(nn.Conv2d(emb_dim,out_ch,3,padding=1))
|
| 35 |
self.Bside.append(nn.Conv2d(emb_dim,out_ch,3,padding=1))
|
| 36 |
|
| 37 |
-
def forward(self,img,img_feature):
|
| 38 |
|
| 39 |
L1_feature,L2_feature,L3_feature,L4_feature,global_feature = img_feature
|
| 40 |
|
|
@@ -46,7 +46,7 @@ class PDF_depth_decoder(nn.Module):
|
|
| 46 |
|
| 47 |
De_L1 = self.Decoder[3](torch.cat([_upsample_like(De_L2,L1_feature),L1_feature],dim=1))
|
| 48 |
|
| 49 |
-
shallow = self.shallow(img)
|
| 50 |
final_output = De_L1 + _upsample_like(shallow, De_L1)
|
| 51 |
final_output = self.upsample1(_upsample_(final_output,[final_output.shape[-2]*2,final_output.shape[-1]*2]))
|
| 52 |
final_output = _upsample_(final_output + _upsample_like(shallow, final_output),[final_output.shape[-2]*2,final_output.shape[-1]*2])
|
|
@@ -65,7 +65,8 @@ class CoA(nn.Module):
|
|
| 65 |
def __init__(self, emb_dim=128):
|
| 66 |
super(CoA, self).__init__()
|
| 67 |
self.Att = nn.MultiheadAttention(emb_dim,1,bias=False,batch_first=True,dropout=0.1)
|
| 68 |
-
self.
|
|
|
|
| 69 |
self.drop1 = nn.Dropout(0.1)
|
| 70 |
self.FFN = SwiGLU(emb_dim,emb_dim)
|
| 71 |
self.Norm2 = RMSNorm(emb_dim,data_format='channels_last')
|
|
@@ -73,13 +74,31 @@ class CoA(nn.Module):
|
|
| 73 |
|
| 74 |
def forward(self,q,kv):
|
| 75 |
res = q
|
| 76 |
-
KV_feature = self.Att(q, kv, kv)[0]
|
| 77 |
-
KV_feature = self.
|
| 78 |
res = KV_feature
|
| 79 |
-
KV_feature = self.FFN(KV_feature)
|
| 80 |
-
KV_feature = self.
|
| 81 |
return KV_feature
|
| 82 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
class FSE(nn.Module):
|
| 85 |
def __init__(self, img_dim=128, depth_dim=128, patch_dim=128, emb_dim=128, pool_ratio=[1,1,1], patch_ratio=4):
|
|
@@ -149,11 +168,11 @@ class FSE(nn.Module):
|
|
| 149 |
|
| 150 |
def BIS(self,pred):
|
| 151 |
if pred.shape[-2]//8 % 2 == 0:
|
| 152 |
-
boundary =
|
| 153 |
-
return boundary, F.relu(pred.sigmoid()-
|
| 154 |
else:
|
| 155 |
-
boundary =
|
| 156 |
-
return boundary, F.relu(pred.sigmoid()-
|
| 157 |
|
| 158 |
def forward(self,img,depth,patch,last_pred):
|
| 159 |
boundary,integrity = self.BIS(last_pred)
|
|
@@ -171,7 +190,7 @@ class FSE(nn.Module):
|
|
| 171 |
#give depth the integrity prior
|
| 172 |
integrity = _upsample_like(integrity,depth)
|
| 173 |
last_pred_sigmoid = _upsample_like(last_pred,depth).sigmoid()
|
| 174 |
-
enhance_depth = depth*(last_pred_sigmoid
|
| 175 |
depth_cs = self.D_channelswich(enhance_depth)
|
| 176 |
pool_depth_cs = F.adaptive_avg_pool2d(depth_cs,output_size=[depth_H//pd,depth_W//pd])
|
| 177 |
pool_depth_cs = rearrange(pool_depth_cs, 'b c h w -> b (h w) c')
|
|
@@ -181,7 +200,8 @@ class FSE(nn.Module):
|
|
| 181 |
patch_batch = self.split(patch,patch_ratio=self.patch_ratio)
|
| 182 |
boundary_batch = self.split(boundary,patch_ratio=self.patch_ratio)
|
| 183 |
boundary_score = boundary_batch.mean(dim=[2,3])[...,None,None]
|
| 184 |
-
select_patch = patch_batch * (1+
|
|
|
|
| 185 |
select_patch = self.merge(select_patch,batch_size=B)
|
| 186 |
|
| 187 |
patch_cs = self.P_channelswich(select_patch)
|
|
@@ -203,6 +223,52 @@ class FSE(nn.Module):
|
|
| 203 |
|
| 204 |
return img_feature + rearrange(img_cs, 'b (h w) c -> b c h w',h=img_H), depth_feature + depth_cs, patch_feature + patch_cs
|
| 205 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
|
| 207 |
class PDF_decoder(nn.Module):
|
| 208 |
def __init__(self, args,raw_ch=3,out_ch=1):
|
|
@@ -287,7 +353,7 @@ class PDFNet_process(nn.Module):
|
|
| 287 |
emb = args.emb
|
| 288 |
self.Glob = nn.Sequential(make_crs(emb,emb))
|
| 289 |
self.decoder = decoder
|
| 290 |
-
self.depth_decoder = depth_decoder
|
| 291 |
self.decoder.patch_ratio = self.patch_ratio
|
| 292 |
self.args=args
|
| 293 |
|
|
@@ -448,17 +514,18 @@ class PDFNet_process(nn.Module):
|
|
| 448 |
[Depth_latent_I1,Depth_latent_I2,Depth_latent_I3,Depth_latent_I4,Depth_x_glob],
|
| 449 |
[patch_latent_I1,patch_latent_I2,patch_latent_I3,patch_latent_I4,patch_x_glob])
|
| 450 |
|
| 451 |
-
pred_depth = self.depth_decoder(RIMG,[latent_I1
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
|
| 455 |
-
|
| 456 |
|
| 457 |
loss, target_loss = self.loss_compute(pred_m,RGT)
|
| 458 |
integrity_loss,_ = self.Integrity_Loss(pred_m,depth_gt,RGT)
|
| 459 |
depth_loss,_ = self.depth_loss(pred_depth,depth_gt)
|
| 460 |
|
| 461 |
loss = loss + integrity_loss/2 + depth_loss/10
|
|
|
|
| 462 |
|
| 463 |
if self.args.DEBUG:
|
| 464 |
print(pred_m[0].shape)
|
|
@@ -467,11 +534,13 @@ class PDFNet_process(nn.Module):
|
|
| 467 |
RDEPTH.reshape([-1,H,W])[:1].cpu().detach(),
|
| 468 |
RGT.reshape([-1,H,W])[:1].cpu().detach(),
|
| 469 |
pred_m[0].sigmoid().reshape([-1,H,W])[:1].cpu().detach(),
|
| 470 |
-
_upsample_like(pred_depth[0],pred_m[0]).sigmoid().reshape([-1,H,W])[:1].cpu().detach(),
|
|
|
|
| 471 |
show_gray_images(Show_X,m=RIMG.shape[0]*4,alpha=1.5,cmap='gray')
|
| 472 |
return [i.sigmoid() for i in pred_m], loss, target_loss
|
| 473 |
|
| 474 |
@torch.no_grad()
|
|
|
|
| 475 |
def inference(self,img,depth):
|
| 476 |
depth = (depth-depth.min())/(depth.max()-depth.min())
|
| 477 |
B,C,H,W = img.size()
|
|
@@ -506,6 +575,7 @@ class PDFNet_process(nn.Module):
|
|
| 506 |
|
| 507 |
def build_model(args):
|
| 508 |
if args.back_bone == 'PDFNet_swinB':
|
| 509 |
-
return PDFNet_process(encoder=SwinB(args=args,in_chans=3,pretrained=
|
| 510 |
decoder=PDF_decoder(args=args),depth_decoder=PDF_depth_decoder(args=args),
|
| 511 |
-
device=args.device, args=args),args.model
|
|
|
|
|
|
| 18 |
|
| 19 |
emb_dim = 128
|
| 20 |
self.Decoder = nn.ModuleList()
|
| 21 |
+
self.Decoder.append(nn.Sequential(make_crs(emb_dim*2*3,emb_dim*2),make_crs(emb_dim*2,emb_dim)))
|
| 22 |
+
self.Decoder.append(nn.Sequential(make_crs(emb_dim*(1+3),emb_dim*2),make_crs(emb_dim*2,emb_dim)))
|
| 23 |
+
self.Decoder.append(nn.Sequential(make_crs(emb_dim*(1+3),emb_dim*2),make_crs(emb_dim*2,emb_dim)))
|
| 24 |
+
self.Decoder.append(nn.Sequential(make_crs(emb_dim*(1+3),emb_dim*2),make_crs(emb_dim*2,emb_dim)))
|
| 25 |
|
| 26 |
+
self.shallow = nn.Sequential(nn.Conv2d(raw_ch*2, emb_dim, kernel_size=3, stride=1, padding=1))
|
| 27 |
self.upsample1 = make_crs(emb_dim,emb_dim)
|
| 28 |
self.upsample2 = make_crs(emb_dim,emb_dim)
|
| 29 |
|
|
|
|
| 34 |
self.Bside.append(nn.Conv2d(emb_dim,out_ch,3,padding=1))
|
| 35 |
self.Bside.append(nn.Conv2d(emb_dim,out_ch,3,padding=1))
|
| 36 |
|
| 37 |
+
def forward(self,img,depth,img_feature):
|
| 38 |
|
| 39 |
L1_feature,L2_feature,L3_feature,L4_feature,global_feature = img_feature
|
| 40 |
|
|
|
|
| 46 |
|
| 47 |
De_L1 = self.Decoder[3](torch.cat([_upsample_like(De_L2,L1_feature),L1_feature],dim=1))
|
| 48 |
|
| 49 |
+
shallow = self.shallow(torch.cat([img,depth],dim=1))
|
| 50 |
final_output = De_L1 + _upsample_like(shallow, De_L1)
|
| 51 |
final_output = self.upsample1(_upsample_(final_output,[final_output.shape[-2]*2,final_output.shape[-1]*2]))
|
| 52 |
final_output = _upsample_(final_output + _upsample_like(shallow, final_output),[final_output.shape[-2]*2,final_output.shape[-1]*2])
|
|
|
|
| 65 |
def __init__(self, emb_dim=128):
|
| 66 |
super(CoA, self).__init__()
|
| 67 |
self.Att = nn.MultiheadAttention(emb_dim,1,bias=False,batch_first=True,dropout=0.1)
|
| 68 |
+
self.Normq = RMSNorm(emb_dim,data_format='channels_last')
|
| 69 |
+
self.Normkv = RMSNorm(emb_dim,data_format='channels_last')
|
| 70 |
self.drop1 = nn.Dropout(0.1)
|
| 71 |
self.FFN = SwiGLU(emb_dim,emb_dim)
|
| 72 |
self.Norm2 = RMSNorm(emb_dim,data_format='channels_last')
|
|
|
|
| 74 |
|
| 75 |
def forward(self,q,kv):
|
| 76 |
res = q
|
| 77 |
+
KV_feature = self.Att(self.Normq(q), self.Normkv(kv), self.Normkv(kv))[0]
|
| 78 |
+
KV_feature = self.drop1(KV_feature) + res
|
| 79 |
res = KV_feature
|
| 80 |
+
KV_feature = self.FFN(self.Norm2(KV_feature))
|
| 81 |
+
KV_feature = self.drop2(KV_feature) + res
|
| 82 |
return KV_feature
|
| 83 |
|
| 84 |
+
class CoA(nn.Module):
|
| 85 |
+
def __init__(self, emb_dim=128):
|
| 86 |
+
super(CoA, self).__init__()
|
| 87 |
+
self.Att = nn.MultiheadAttention(emb_dim,1,bias=False,batch_first=True,dropout=0.1)
|
| 88 |
+
self.Norm1 = RMSNorm(emb_dim,data_format='channels_last')
|
| 89 |
+
self.drop1 = nn.Dropout(0.1)
|
| 90 |
+
self.FFN = SwiGLU(emb_dim,emb_dim)
|
| 91 |
+
self.Norm2 = RMSNorm(emb_dim,data_format='channels_last')
|
| 92 |
+
self.drop2 = nn.Dropout(0.1)
|
| 93 |
+
|
| 94 |
+
def forward(self,q,kv):
|
| 95 |
+
res = q
|
| 96 |
+
KV_feature = self.Norm1(self.Att(q,kv,kv)[0])
|
| 97 |
+
KV_feature = self.drop1(KV_feature) + res
|
| 98 |
+
res = KV_feature
|
| 99 |
+
KV_feature = self.Norm2(self.FFN(KV_feature))
|
| 100 |
+
KV_feature = self.drop2(KV_feature) + res
|
| 101 |
+
return KV_feature
|
| 102 |
|
| 103 |
class FSE(nn.Module):
|
| 104 |
def __init__(self, img_dim=128, depth_dim=128, patch_dim=128, emb_dim=128, pool_ratio=[1,1,1], patch_ratio=4):
|
|
|
|
| 168 |
|
| 169 |
def BIS(self,pred):
|
| 170 |
if pred.shape[-2]//8 % 2 == 0:
|
| 171 |
+
boundary = (self.get_boundary(pred.sigmoid())>0.1).float()
|
| 172 |
+
return boundary, F.relu(pred.sigmoid()-boundary)
|
| 173 |
else:
|
| 174 |
+
boundary = self.get_boundary(pred.sigmoid())
|
| 175 |
+
return boundary, F.relu(pred.sigmoid()-boundary)
|
| 176 |
|
| 177 |
def forward(self,img,depth,patch,last_pred):
|
| 178 |
boundary,integrity = self.BIS(last_pred)
|
|
|
|
| 190 |
#give depth the integrity prior
|
| 191 |
integrity = _upsample_like(integrity,depth)
|
| 192 |
last_pred_sigmoid = _upsample_like(last_pred,depth).sigmoid()
|
| 193 |
+
enhance_depth = depth*(last_pred_sigmoid+integrity)
|
| 194 |
depth_cs = self.D_channelswich(enhance_depth)
|
| 195 |
pool_depth_cs = F.adaptive_avg_pool2d(depth_cs,output_size=[depth_H//pd,depth_W//pd])
|
| 196 |
pool_depth_cs = rearrange(pool_depth_cs, 'b c h w -> b (h w) c')
|
|
|
|
| 200 |
patch_batch = self.split(patch,patch_ratio=self.patch_ratio)
|
| 201 |
boundary_batch = self.split(boundary,patch_ratio=self.patch_ratio)
|
| 202 |
boundary_score = boundary_batch.mean(dim=[2,3])[...,None,None]
|
| 203 |
+
select_patch = patch_batch * (1+(boundary_score>0).float())
|
| 204 |
+
# select_patch = patch_batch*0
|
| 205 |
select_patch = self.merge(select_patch,batch_size=B)
|
| 206 |
|
| 207 |
patch_cs = self.P_channelswich(select_patch)
|
|
|
|
| 223 |
|
| 224 |
return img_feature + rearrange(img_cs, 'b (h w) c -> b c h w',h=img_H), depth_feature + depth_cs, patch_feature + patch_cs
|
| 225 |
|
| 226 |
+
# def forward(self,img,depth,patch,last_pred):
|
| 227 |
+
# boundary,integrity = self.BIS(last_pred)
|
| 228 |
+
# # img = img * _upsample_like(last_pred.sigmoid(),img)
|
| 229 |
+
# # depth = depth * _upsample_like(last_pred.sigmoid(),depth)
|
| 230 |
+
# # patch = patch * _upsample_like(last_pred.sigmoid(),patch)
|
| 231 |
+
# pi,pd,pp = self.pool_ratio
|
| 232 |
+
# B,C,img_H,img_W = img.size()
|
| 233 |
+
# img_cs = self.I_channelswich(img* (1+_upsample_like(integrity,depth)))
|
| 234 |
+
# pool_img_cs = F.adaptive_avg_pool2d(img_cs,output_size=[img_H//pi,img_W//pi])
|
| 235 |
+
# # img_cs = rearrange(img_cs, 'b c h w -> b (h w) c')
|
| 236 |
+
# pool_img_cs = rearrange(pool_img_cs, 'b c h w -> b (h w) c')
|
| 237 |
+
# B,C,depth_H,depth_W = depth.size()
|
| 238 |
+
|
| 239 |
+
# #give depth the integrity prior
|
| 240 |
+
# enhance_depth = depth * _upsample_like(last_pred.sigmoid(),depth)
|
| 241 |
+
# depth_cs = self.D_channelswich(enhance_depth)
|
| 242 |
+
# pool_depth_cs = F.adaptive_avg_pool2d(depth_cs,output_size=[depth_H//pd,depth_W//pd])
|
| 243 |
+
# depth_cs = rearrange(depth_cs, 'b c h w -> b (h w) c')
|
| 244 |
+
# pool_depth_cs = rearrange(pool_depth_cs, 'b c h w -> b (h w) c')
|
| 245 |
+
# B,C,patch_H,patch_W = patch.size()
|
| 246 |
+
|
| 247 |
+
# #select the boundary patches to select patches
|
| 248 |
+
# patch_batch = self.split(patch,patch_ratio=self.patch_ratio)
|
| 249 |
+
# boundary_batch = self.split(boundary,patch_ratio=self.patch_ratio)
|
| 250 |
+
# boundary_score = boundary_batch.mean(dim=[2,3])[...,None,None]
|
| 251 |
+
# select_patch = patch_batch * (1+(boundary_score>0).float())
|
| 252 |
+
# select_patch = self.merge(select_patch,batch_size=B)
|
| 253 |
+
|
| 254 |
+
# patch_cs = self.P_channelswich(select_patch)
|
| 255 |
+
# pool_patch_cs = F.adaptive_avg_pool2d(patch_cs,output_size=[patch_H//pp,patch_W//pp])
|
| 256 |
+
# pool_patch_cs = rearrange(pool_patch_cs, 'b c h w -> b (h w) c')
|
| 257 |
+
|
| 258 |
+
# patch_feature = self.PI(pool_patch_cs, torch.cat([pool_img_cs,pool_depth_cs],dim=1))
|
| 259 |
+
# depth_feature = self.IP(depth_cs,patch_feature)
|
| 260 |
+
|
| 261 |
+
# img_feature = self.DI(pool_img_cs, torch.cat([pool_img_cs,pool_patch_cs],dim=1))
|
| 262 |
+
# depth_feature = self.ID(depth_feature,img_feature)
|
| 263 |
+
|
| 264 |
+
# patch_feature = rearrange(patch_feature, 'b (h w) c -> b c h w',h=patch_H//pp)
|
| 265 |
+
# depth_feature = rearrange(depth_feature, 'b (h w) c -> b c h w',h=depth_H)
|
| 266 |
+
# img_feature = rearrange(img_feature, 'b (h w) c -> b c h w',h=img_H//pi)
|
| 267 |
+
|
| 268 |
+
# img_feature = _upsample_like(img_feature,img)
|
| 269 |
+
# patch_feature = _upsample_like(patch_feature,patch)
|
| 270 |
+
|
| 271 |
+
# return img_feature + img_cs, depth_feature + rearrange(depth_cs, 'b (h w) c -> b c h w',h=depth_H), patch_feature + patch_cs
|
| 272 |
|
| 273 |
class PDF_decoder(nn.Module):
|
| 274 |
def __init__(self, args,raw_ch=3,out_ch=1):
|
|
|
|
| 353 |
emb = args.emb
|
| 354 |
self.Glob = nn.Sequential(make_crs(emb,emb))
|
| 355 |
self.decoder = decoder
|
| 356 |
+
# self.depth_decoder = depth_decoder
|
| 357 |
self.decoder.patch_ratio = self.patch_ratio
|
| 358 |
self.args=args
|
| 359 |
|
|
|
|
| 514 |
[Depth_latent_I1,Depth_latent_I2,Depth_latent_I3,Depth_latent_I4,Depth_x_glob],
|
| 515 |
[patch_latent_I1,patch_latent_I2,patch_latent_I3,patch_latent_I4,patch_x_glob])
|
| 516 |
|
| 517 |
+
pred_depth = self.depth_decoder(RIMG,RDEPTH,[torch.cat([latent_I1,Depth_latent_I1,_upsample_like(patch_latent_I1,latent_I1)],dim=1),
|
| 518 |
+
torch.cat([latent_I2,Depth_latent_I2,_upsample_like(patch_latent_I2,latent_I2)],dim=1),
|
| 519 |
+
torch.cat([latent_I3,Depth_latent_I3,_upsample_like(patch_latent_I3,latent_I3)],dim=1),
|
| 520 |
+
torch.cat([latent_I4,Depth_latent_I4,_upsample_like(patch_latent_I4,latent_I4)],dim=1),
|
| 521 |
+
torch.cat([x_glob,Depth_x_glob,_upsample_like(patch_x_glob,x_glob)],dim=1)])
|
| 522 |
|
| 523 |
loss, target_loss = self.loss_compute(pred_m,RGT)
|
| 524 |
integrity_loss,_ = self.Integrity_Loss(pred_m,depth_gt,RGT)
|
| 525 |
depth_loss,_ = self.depth_loss(pred_depth,depth_gt)
|
| 526 |
|
| 527 |
loss = loss + integrity_loss/2 + depth_loss/10
|
| 528 |
+
# loss = loss + integrity_loss/2
|
| 529 |
|
| 530 |
if self.args.DEBUG:
|
| 531 |
print(pred_m[0].shape)
|
|
|
|
| 534 |
RDEPTH.reshape([-1,H,W])[:1].cpu().detach(),
|
| 535 |
RGT.reshape([-1,H,W])[:1].cpu().detach(),
|
| 536 |
pred_m[0].sigmoid().reshape([-1,H,W])[:1].cpu().detach(),
|
| 537 |
+
# _upsample_like(pred_depth[0],pred_m[0]).sigmoid().reshape([-1,H,W])[:1].cpu().detach(),
|
| 538 |
+
],dim=0)
|
| 539 |
show_gray_images(Show_X,m=RIMG.shape[0]*4,alpha=1.5,cmap='gray')
|
| 540 |
return [i.sigmoid() for i in pred_m], loss, target_loss
|
| 541 |
|
| 542 |
@torch.no_grad()
|
| 543 |
+
|
| 544 |
def inference(self,img,depth):
|
| 545 |
depth = (depth-depth.min())/(depth.max()-depth.min())
|
| 546 |
B,C,H,W = img.size()
|
|
|
|
| 575 |
|
| 576 |
def build_model(args):
|
| 577 |
if args.back_bone == 'PDFNet_swinB':
|
| 578 |
+
return PDFNet_process(encoder=SwinB(args=args,in_chans=3,pretrained=True),
|
| 579 |
decoder=PDF_decoder(args=args),depth_decoder=PDF_depth_decoder(args=args),
|
| 580 |
+
device=args.device, args=args),args.model
|
| 581 |
+
|