Update app.py
Browse files
app.py
CHANGED
|
@@ -89,7 +89,7 @@ def predict(image):
|
|
| 89 |
depth = torch.nn.functional.interpolate(torch.from_numpy(depth)[None,None,...],size=[1024,1024],mode='bilinear',align_corners=True)
|
| 90 |
image = torch.divide(image,255.0)
|
| 91 |
depth = torch.divide(depth,255.0)
|
| 92 |
-
depth = transforms(depth[0].
|
| 93 |
image = transforms(image).unsqueeze(0)
|
| 94 |
DIS_map = model.inference(image.to(device),depth.to(device))[0][0][0].cpu()
|
| 95 |
DIS_map = cv2.resize(np.array(DIS_map), (W,H))
|
|
|
|
| 89 |
depth = torch.nn.functional.interpolate(torch.from_numpy(depth)[None,None,...],size=[1024,1024],mode='bilinear',align_corners=True)
|
| 90 |
image = torch.divide(image,255.0)
|
| 91 |
depth = torch.divide(depth,255.0)
|
| 92 |
+
depth = transforms(depth[0].repeat(3,1,1)).unsqueeze(0)
|
| 93 |
image = transforms(image).unsqueeze(0)
|
| 94 |
DIS_map = model.inference(image.to(device),depth.to(device))[0][0][0].cpu()
|
| 95 |
DIS_map = cv2.resize(np.array(DIS_map), (W,H))
|