Tennineee commited on
Commit
447dfce
·
verified ·
1 Parent(s): 0bbbc19

Upload PDFNet.py

Browse files
Files changed (1) hide show
  1. models/PDFNet.py +97 -27
models/PDFNet.py CHANGED
@@ -18,12 +18,12 @@ class PDF_depth_decoder(nn.Module):
18
 
19
  emb_dim = 128
20
  self.Decoder = nn.ModuleList()
21
- self.Decoder.append(nn.Sequential(make_crs(emb_dim*2,emb_dim*2),make_crs(emb_dim*2,emb_dim)))
22
- self.Decoder.append(nn.Sequential(make_crs(emb_dim*2,emb_dim*2),make_crs(emb_dim*2,emb_dim)))
23
- self.Decoder.append(nn.Sequential(make_crs(emb_dim*2,emb_dim*2),make_crs(emb_dim*2,emb_dim)))
24
- self.Decoder.append(nn.Sequential(make_crs(emb_dim*2,emb_dim*2),make_crs(emb_dim*2,emb_dim)))
25
 
26
- self.shallow = nn.Sequential(nn.Conv2d(raw_ch, emb_dim, kernel_size=3, stride=1, padding=1))
27
  self.upsample1 = make_crs(emb_dim,emb_dim)
28
  self.upsample2 = make_crs(emb_dim,emb_dim)
29
 
@@ -34,7 +34,7 @@ class PDF_depth_decoder(nn.Module):
34
  self.Bside.append(nn.Conv2d(emb_dim,out_ch,3,padding=1))
35
  self.Bside.append(nn.Conv2d(emb_dim,out_ch,3,padding=1))
36
 
37
- def forward(self,img,img_feature):
38
 
39
  L1_feature,L2_feature,L3_feature,L4_feature,global_feature = img_feature
40
 
@@ -46,7 +46,7 @@ class PDF_depth_decoder(nn.Module):
46
 
47
  De_L1 = self.Decoder[3](torch.cat([_upsample_like(De_L2,L1_feature),L1_feature],dim=1))
48
 
49
- shallow = self.shallow(img)
50
  final_output = De_L1 + _upsample_like(shallow, De_L1)
51
  final_output = self.upsample1(_upsample_(final_output,[final_output.shape[-2]*2,final_output.shape[-1]*2]))
52
  final_output = _upsample_(final_output + _upsample_like(shallow, final_output),[final_output.shape[-2]*2,final_output.shape[-1]*2])
@@ -65,7 +65,8 @@ class CoA(nn.Module):
65
  def __init__(self, emb_dim=128):
66
  super(CoA, self).__init__()
67
  self.Att = nn.MultiheadAttention(emb_dim,1,bias=False,batch_first=True,dropout=0.1)
68
- self.Norm1 = RMSNorm(emb_dim,data_format='channels_last')
 
69
  self.drop1 = nn.Dropout(0.1)
70
  self.FFN = SwiGLU(emb_dim,emb_dim)
71
  self.Norm2 = RMSNorm(emb_dim,data_format='channels_last')
@@ -73,13 +74,31 @@ class CoA(nn.Module):
73
 
74
  def forward(self,q,kv):
75
  res = q
76
- KV_feature = self.Att(q, kv, kv)[0]
77
- KV_feature = self.Norm1(self.drop1(KV_feature)) + res
78
  res = KV_feature
79
- KV_feature = self.FFN(KV_feature)
80
- KV_feature = self.Norm2(self.drop2(KV_feature)) + res
81
  return KV_feature
82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
 
84
  class FSE(nn.Module):
85
  def __init__(self, img_dim=128, depth_dim=128, patch_dim=128, emb_dim=128, pool_ratio=[1,1,1], patch_ratio=4):
@@ -149,11 +168,11 @@ class FSE(nn.Module):
149
 
150
  def BIS(self,pred):
151
  if pred.shape[-2]//8 % 2 == 0:
152
- boundary = 2*self.get_boundary(pred.sigmoid())
153
- return boundary, F.relu(pred.sigmoid()-5*boundary)
154
  else:
155
- boundary = 2*self.get_boundary(pred.sigmoid())
156
- return boundary, F.relu(pred.sigmoid()-5*boundary)
157
 
158
  def forward(self,img,depth,patch,last_pred):
159
  boundary,integrity = self.BIS(last_pred)
@@ -171,7 +190,7 @@ class FSE(nn.Module):
171
  #give depth the integrity prior
172
  integrity = _upsample_like(integrity,depth)
173
  last_pred_sigmoid = _upsample_like(last_pred,depth).sigmoid()
174
- enhance_depth = depth*(last_pred_sigmoid + integrity)
175
  depth_cs = self.D_channelswich(enhance_depth)
176
  pool_depth_cs = F.adaptive_avg_pool2d(depth_cs,output_size=[depth_H//pd,depth_W//pd])
177
  pool_depth_cs = rearrange(pool_depth_cs, 'b c h w -> b (h w) c')
@@ -181,7 +200,8 @@ class FSE(nn.Module):
181
  patch_batch = self.split(patch,patch_ratio=self.patch_ratio)
182
  boundary_batch = self.split(boundary,patch_ratio=self.patch_ratio)
183
  boundary_score = boundary_batch.mean(dim=[2,3])[...,None,None]
184
- select_patch = patch_batch * (1+5*boundary_score)
 
185
  select_patch = self.merge(select_patch,batch_size=B)
186
 
187
  patch_cs = self.P_channelswich(select_patch)
@@ -203,6 +223,52 @@ class FSE(nn.Module):
203
 
204
  return img_feature + rearrange(img_cs, 'b (h w) c -> b c h w',h=img_H), depth_feature + depth_cs, patch_feature + patch_cs
205
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
 
207
  class PDF_decoder(nn.Module):
208
  def __init__(self, args,raw_ch=3,out_ch=1):
@@ -287,7 +353,7 @@ class PDFNet_process(nn.Module):
287
  emb = args.emb
288
  self.Glob = nn.Sequential(make_crs(emb,emb))
289
  self.decoder = decoder
290
- self.depth_decoder = depth_decoder
291
  self.decoder.patch_ratio = self.patch_ratio
292
  self.args=args
293
 
@@ -448,17 +514,18 @@ class PDFNet_process(nn.Module):
448
  [Depth_latent_I1,Depth_latent_I2,Depth_latent_I3,Depth_latent_I4,Depth_x_glob],
449
  [patch_latent_I1,patch_latent_I2,patch_latent_I3,patch_latent_I4,patch_x_glob])
450
 
451
- pred_depth = self.depth_decoder(RIMG,[latent_I1+Depth_latent_I1+_upsample_like(patch_latent_I1,latent_I1),
452
- latent_I2+Depth_latent_I2+_upsample_like(patch_latent_I2,latent_I2),
453
- latent_I3+Depth_latent_I3+_upsample_like(patch_latent_I3,latent_I3),
454
- latent_I4+Depth_latent_I4+_upsample_like(patch_latent_I4,latent_I4),
455
- x_glob+Depth_x_glob+_upsample_like(patch_x_glob,x_glob)])
456
 
457
  loss, target_loss = self.loss_compute(pred_m,RGT)
458
  integrity_loss,_ = self.Integrity_Loss(pred_m,depth_gt,RGT)
459
  depth_loss,_ = self.depth_loss(pred_depth,depth_gt)
460
 
461
  loss = loss + integrity_loss/2 + depth_loss/10
 
462
 
463
  if self.args.DEBUG:
464
  print(pred_m[0].shape)
@@ -467,11 +534,13 @@ class PDFNet_process(nn.Module):
467
  RDEPTH.reshape([-1,H,W])[:1].cpu().detach(),
468
  RGT.reshape([-1,H,W])[:1].cpu().detach(),
469
  pred_m[0].sigmoid().reshape([-1,H,W])[:1].cpu().detach(),
470
- _upsample_like(pred_depth[0],pred_m[0]).sigmoid().reshape([-1,H,W])[:1].cpu().detach(),],dim=0)
 
471
  show_gray_images(Show_X,m=RIMG.shape[0]*4,alpha=1.5,cmap='gray')
472
  return [i.sigmoid() for i in pred_m], loss, target_loss
473
 
474
  @torch.no_grad()
 
475
  def inference(self,img,depth):
476
  depth = (depth-depth.min())/(depth.max()-depth.min())
477
  B,C,H,W = img.size()
@@ -506,6 +575,7 @@ class PDFNet_process(nn.Module):
506
 
507
  def build_model(args):
508
  if args.back_bone == 'PDFNet_swinB':
509
- return PDFNet_process(encoder=SwinB(args=args,in_chans=3,pretrained=False),
510
  decoder=PDF_decoder(args=args),depth_decoder=PDF_depth_decoder(args=args),
511
- device=args.device, args=args),args.model
 
 
18
 
19
  emb_dim = 128
20
  self.Decoder = nn.ModuleList()
21
+ self.Decoder.append(nn.Sequential(make_crs(emb_dim*2*3,emb_dim*2),make_crs(emb_dim*2,emb_dim)))
22
+ self.Decoder.append(nn.Sequential(make_crs(emb_dim*(1+3),emb_dim*2),make_crs(emb_dim*2,emb_dim)))
23
+ self.Decoder.append(nn.Sequential(make_crs(emb_dim*(1+3),emb_dim*2),make_crs(emb_dim*2,emb_dim)))
24
+ self.Decoder.append(nn.Sequential(make_crs(emb_dim*(1+3),emb_dim*2),make_crs(emb_dim*2,emb_dim)))
25
 
26
+ self.shallow = nn.Sequential(nn.Conv2d(raw_ch*2, emb_dim, kernel_size=3, stride=1, padding=1))
27
  self.upsample1 = make_crs(emb_dim,emb_dim)
28
  self.upsample2 = make_crs(emb_dim,emb_dim)
29
 
 
34
  self.Bside.append(nn.Conv2d(emb_dim,out_ch,3,padding=1))
35
  self.Bside.append(nn.Conv2d(emb_dim,out_ch,3,padding=1))
36
 
37
+ def forward(self,img,depth,img_feature):
38
 
39
  L1_feature,L2_feature,L3_feature,L4_feature,global_feature = img_feature
40
 
 
46
 
47
  De_L1 = self.Decoder[3](torch.cat([_upsample_like(De_L2,L1_feature),L1_feature],dim=1))
48
 
49
+ shallow = self.shallow(torch.cat([img,depth],dim=1))
50
  final_output = De_L1 + _upsample_like(shallow, De_L1)
51
  final_output = self.upsample1(_upsample_(final_output,[final_output.shape[-2]*2,final_output.shape[-1]*2]))
52
  final_output = _upsample_(final_output + _upsample_like(shallow, final_output),[final_output.shape[-2]*2,final_output.shape[-1]*2])
 
65
  def __init__(self, emb_dim=128):
66
  super(CoA, self).__init__()
67
  self.Att = nn.MultiheadAttention(emb_dim,1,bias=False,batch_first=True,dropout=0.1)
68
+ self.Normq = RMSNorm(emb_dim,data_format='channels_last')
69
+ self.Normkv = RMSNorm(emb_dim,data_format='channels_last')
70
  self.drop1 = nn.Dropout(0.1)
71
  self.FFN = SwiGLU(emb_dim,emb_dim)
72
  self.Norm2 = RMSNorm(emb_dim,data_format='channels_last')
 
74
 
75
  def forward(self,q,kv):
76
  res = q
77
+ KV_feature = self.Att(self.Normq(q), self.Normkv(kv), self.Normkv(kv))[0]
78
+ KV_feature = self.drop1(KV_feature) + res
79
  res = KV_feature
80
+ KV_feature = self.FFN(self.Norm2(KV_feature))
81
+ KV_feature = self.drop2(KV_feature) + res
82
  return KV_feature
83
 
84
+ class CoA(nn.Module):
85
+ def __init__(self, emb_dim=128):
86
+ super(CoA, self).__init__()
87
+ self.Att = nn.MultiheadAttention(emb_dim,1,bias=False,batch_first=True,dropout=0.1)
88
+ self.Norm1 = RMSNorm(emb_dim,data_format='channels_last')
89
+ self.drop1 = nn.Dropout(0.1)
90
+ self.FFN = SwiGLU(emb_dim,emb_dim)
91
+ self.Norm2 = RMSNorm(emb_dim,data_format='channels_last')
92
+ self.drop2 = nn.Dropout(0.1)
93
+
94
+ def forward(self,q,kv):
95
+ res = q
96
+ KV_feature = self.Norm1(self.Att(q,kv,kv)[0])
97
+ KV_feature = self.drop1(KV_feature) + res
98
+ res = KV_feature
99
+ KV_feature = self.Norm2(self.FFN(KV_feature))
100
+ KV_feature = self.drop2(KV_feature) + res
101
+ return KV_feature
102
 
103
  class FSE(nn.Module):
104
  def __init__(self, img_dim=128, depth_dim=128, patch_dim=128, emb_dim=128, pool_ratio=[1,1,1], patch_ratio=4):
 
168
 
169
  def BIS(self,pred):
170
  if pred.shape[-2]//8 % 2 == 0:
171
+ boundary = (self.get_boundary(pred.sigmoid())>0.1).float()
172
+ return boundary, F.relu(pred.sigmoid()-boundary)
173
  else:
174
+ boundary = self.get_boundary(pred.sigmoid())
175
+ return boundary, F.relu(pred.sigmoid()-boundary)
176
 
177
  def forward(self,img,depth,patch,last_pred):
178
  boundary,integrity = self.BIS(last_pred)
 
190
  #give depth the integrity prior
191
  integrity = _upsample_like(integrity,depth)
192
  last_pred_sigmoid = _upsample_like(last_pred,depth).sigmoid()
193
+ enhance_depth = depth*(last_pred_sigmoid+integrity)
194
  depth_cs = self.D_channelswich(enhance_depth)
195
  pool_depth_cs = F.adaptive_avg_pool2d(depth_cs,output_size=[depth_H//pd,depth_W//pd])
196
  pool_depth_cs = rearrange(pool_depth_cs, 'b c h w -> b (h w) c')
 
200
  patch_batch = self.split(patch,patch_ratio=self.patch_ratio)
201
  boundary_batch = self.split(boundary,patch_ratio=self.patch_ratio)
202
  boundary_score = boundary_batch.mean(dim=[2,3])[...,None,None]
203
+ select_patch = patch_batch * (1+(boundary_score>0).float())
204
+ # select_patch = patch_batch*0
205
  select_patch = self.merge(select_patch,batch_size=B)
206
 
207
  patch_cs = self.P_channelswich(select_patch)
 
223
 
224
  return img_feature + rearrange(img_cs, 'b (h w) c -> b c h w',h=img_H), depth_feature + depth_cs, patch_feature + patch_cs
225
 
226
+ # def forward(self,img,depth,patch,last_pred):
227
+ # boundary,integrity = self.BIS(last_pred)
228
+ # # img = img * _upsample_like(last_pred.sigmoid(),img)
229
+ # # depth = depth * _upsample_like(last_pred.sigmoid(),depth)
230
+ # # patch = patch * _upsample_like(last_pred.sigmoid(),patch)
231
+ # pi,pd,pp = self.pool_ratio
232
+ # B,C,img_H,img_W = img.size()
233
+ # img_cs = self.I_channelswich(img* (1+_upsample_like(integrity,depth)))
234
+ # pool_img_cs = F.adaptive_avg_pool2d(img_cs,output_size=[img_H//pi,img_W//pi])
235
+ # # img_cs = rearrange(img_cs, 'b c h w -> b (h w) c')
236
+ # pool_img_cs = rearrange(pool_img_cs, 'b c h w -> b (h w) c')
237
+ # B,C,depth_H,depth_W = depth.size()
238
+
239
+ # #give depth the integrity prior
240
+ # enhance_depth = depth * _upsample_like(last_pred.sigmoid(),depth)
241
+ # depth_cs = self.D_channelswich(enhance_depth)
242
+ # pool_depth_cs = F.adaptive_avg_pool2d(depth_cs,output_size=[depth_H//pd,depth_W//pd])
243
+ # depth_cs = rearrange(depth_cs, 'b c h w -> b (h w) c')
244
+ # pool_depth_cs = rearrange(pool_depth_cs, 'b c h w -> b (h w) c')
245
+ # B,C,patch_H,patch_W = patch.size()
246
+
247
+ # #select the boundary patches to select patches
248
+ # patch_batch = self.split(patch,patch_ratio=self.patch_ratio)
249
+ # boundary_batch = self.split(boundary,patch_ratio=self.patch_ratio)
250
+ # boundary_score = boundary_batch.mean(dim=[2,3])[...,None,None]
251
+ # select_patch = patch_batch * (1+(boundary_score>0).float())
252
+ # select_patch = self.merge(select_patch,batch_size=B)
253
+
254
+ # patch_cs = self.P_channelswich(select_patch)
255
+ # pool_patch_cs = F.adaptive_avg_pool2d(patch_cs,output_size=[patch_H//pp,patch_W//pp])
256
+ # pool_patch_cs = rearrange(pool_patch_cs, 'b c h w -> b (h w) c')
257
+
258
+ # patch_feature = self.PI(pool_patch_cs, torch.cat([pool_img_cs,pool_depth_cs],dim=1))
259
+ # depth_feature = self.IP(depth_cs,patch_feature)
260
+
261
+ # img_feature = self.DI(pool_img_cs, torch.cat([pool_img_cs,pool_patch_cs],dim=1))
262
+ # depth_feature = self.ID(depth_feature,img_feature)
263
+
264
+ # patch_feature = rearrange(patch_feature, 'b (h w) c -> b c h w',h=patch_H//pp)
265
+ # depth_feature = rearrange(depth_feature, 'b (h w) c -> b c h w',h=depth_H)
266
+ # img_feature = rearrange(img_feature, 'b (h w) c -> b c h w',h=img_H//pi)
267
+
268
+ # img_feature = _upsample_like(img_feature,img)
269
+ # patch_feature = _upsample_like(patch_feature,patch)
270
+
271
+ # return img_feature + img_cs, depth_feature + rearrange(depth_cs, 'b (h w) c -> b c h w',h=depth_H), patch_feature + patch_cs
272
 
273
  class PDF_decoder(nn.Module):
274
  def __init__(self, args,raw_ch=3,out_ch=1):
 
353
  emb = args.emb
354
  self.Glob = nn.Sequential(make_crs(emb,emb))
355
  self.decoder = decoder
356
+ # self.depth_decoder = depth_decoder
357
  self.decoder.patch_ratio = self.patch_ratio
358
  self.args=args
359
 
 
514
  [Depth_latent_I1,Depth_latent_I2,Depth_latent_I3,Depth_latent_I4,Depth_x_glob],
515
  [patch_latent_I1,patch_latent_I2,patch_latent_I3,patch_latent_I4,patch_x_glob])
516
 
517
+ pred_depth = self.depth_decoder(RIMG,RDEPTH,[torch.cat([latent_I1,Depth_latent_I1,_upsample_like(patch_latent_I1,latent_I1)],dim=1),
518
+ torch.cat([latent_I2,Depth_latent_I2,_upsample_like(patch_latent_I2,latent_I2)],dim=1),
519
+ torch.cat([latent_I3,Depth_latent_I3,_upsample_like(patch_latent_I3,latent_I3)],dim=1),
520
+ torch.cat([latent_I4,Depth_latent_I4,_upsample_like(patch_latent_I4,latent_I4)],dim=1),
521
+ torch.cat([x_glob,Depth_x_glob,_upsample_like(patch_x_glob,x_glob)],dim=1)])
522
 
523
  loss, target_loss = self.loss_compute(pred_m,RGT)
524
  integrity_loss,_ = self.Integrity_Loss(pred_m,depth_gt,RGT)
525
  depth_loss,_ = self.depth_loss(pred_depth,depth_gt)
526
 
527
  loss = loss + integrity_loss/2 + depth_loss/10
528
+ # loss = loss + integrity_loss/2
529
 
530
  if self.args.DEBUG:
531
  print(pred_m[0].shape)
 
534
  RDEPTH.reshape([-1,H,W])[:1].cpu().detach(),
535
  RGT.reshape([-1,H,W])[:1].cpu().detach(),
536
  pred_m[0].sigmoid().reshape([-1,H,W])[:1].cpu().detach(),
537
+ # _upsample_like(pred_depth[0],pred_m[0]).sigmoid().reshape([-1,H,W])[:1].cpu().detach(),
538
+ ],dim=0)
539
  show_gray_images(Show_X,m=RIMG.shape[0]*4,alpha=1.5,cmap='gray')
540
  return [i.sigmoid() for i in pred_m], loss, target_loss
541
 
542
  @torch.no_grad()
543
+
544
  def inference(self,img,depth):
545
  depth = (depth-depth.min())/(depth.max()-depth.min())
546
  B,C,H,W = img.size()
 
575
 
576
  def build_model(args):
577
  if args.back_bone == 'PDFNet_swinB':
578
+ return PDFNet_process(encoder=SwinB(args=args,in_chans=3,pretrained=True),
579
  decoder=PDF_decoder(args=args),depth_decoder=PDF_depth_decoder(args=args),
580
+ device=args.device, args=args),args.model
581
+