Spaces:
Runtime error
fix: cpu roma
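This commit makes the bundled RoMa code runnable on CPU-only hardware, where the Space was hitting a runtime error. The code previously assumed CUDA in two places: the autocast dtype was chosen with torch.cuda.is_bf16_supported(), which presumes a GPU, and coordinate grids were allocated with a hard-coded device="cuda". The fix guards both behind torch.cuda.is_available(), falling back to torch.float32 and the CPU device. A minimal sketch of the pattern, with an illustrative helper name that is not part of the patch:

import torch

# Illustrative helper (not in the patch): pick a device and an autocast
# dtype that are safe on hosts without CUDA.
def pick_device_and_amp_dtype():
    if not torch.cuda.is_available():
        # No GPU: stay in full precision on the CPU.
        return torch.device("cpu"), torch.float32
    # Prefer bfloat16 on GPUs that support it, otherwise float16.
    if torch.cuda.is_bf16_supported():
        return torch.device("cuda"), torch.bfloat16
    return torch.device("cuda"), torch.float16

device, amp_dtype = pick_device_and_amp_dtype()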
third_party/Roma/roma/models/encoders.py
CHANGED
@@ -24,7 +24,10 @@ class ResNet50(nn.Module):
         self.freeze_bn = freeze_bn
         self.early_exit = early_exit
         self.amp = amp
-        self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+        if not torch.cuda.is_available():
+            self.amp_dtype = torch.float32
+        else:
+            self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
 
     def forward(self, x, **kwargs):
         with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
@@ -60,7 +63,10 @@ class VGG19(nn.Module):
         super().__init__()
         self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40])
         self.amp = amp
-        self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+        if not torch.cuda.is_available():
+            self.amp_dtype = torch.float32
+        else:
+            self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
 
     def forward(self, x, **kwargs):
         with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
@@ -94,7 +100,10 @@ class CNNandDinov2(nn.Module):
         else:
             self.cnn = VGG19(**cnn_kwargs)
         self.amp = amp
-        self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+        if not torch.cuda.is_available():
+            self.amp_dtype = torch.float32
+        else:
+            self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
         if self.amp:
             dinov2_vitl14 = dinov2_vitl14.to(self.amp_dtype)
         self.dinov2_vitl14 = [dinov2_vitl14] # ugly hack to not show parameters to DDP
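The CNNandDinov2 hunk also preserves the existing trick of holding dinov2_vitl14 in a plain Python list so DDP never registers its parameters; only the dtype it is cast to changes. A short sketch (illustrative names) of why the list hides the module, should that hack be unfamiliar:

import torch
import torch.nn as nn

# A module stored in a plain list is not registered as a submodule, so
# .parameters(), .state_dict(), and DDP all ignore it.
class Wrapper(nn.Module):
    def __init__(self, frozen_backbone: nn.Module):
        super().__init__()
        self.head = nn.Linear(8, 2)        # registered: visible to DDP
        self.backbone = [frozen_backbone]  # hidden: plain list attribute

    def forward(self, x):
        with torch.no_grad():
            feats = self.backbone[0](x)
        return self.head(feats)

m = Wrapper(nn.Linear(4, 8))
print(sum(p.numel() for p in m.parameters()))  # 18: only the head counts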
third_party/Roma/roma/models/matcher.py
CHANGED
@@ -71,8 +71,12 @@ class ConvRefiner(nn.Module):
         self.disable_local_corr_grad = disable_local_corr_grad
         self.is_classifier = is_classifier
         self.sample_mode = sample_mode
-        self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
-
+        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        if not torch.cuda.is_available():
+            self.amp_dtype = torch.float32
+        else:
+            self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+
     def create_block(
         self,
         in_dim,
@@ -109,8 +113,8 @@ class ConvRefiner(nn.Module):
         if self.has_displacement_emb:
             im_A_coords = torch.meshgrid(
                 (
-                    torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device="cuda"),
-                    torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device="cuda"),
+                    torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=self.device),
+                    torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=self.device),
                 )
             )
             im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
@@ -296,8 +300,11 @@ class Decoder(nn.Module):
         self.displacement_dropout_p = displacement_dropout_p
         self.gm_warp_dropout_p = gm_warp_dropout_p
         self.flow_upsample_mode = flow_upsample_mode
-        self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
-
+        if not torch.cuda.is_available():
+            self.amp_dtype = torch.float32
+        else:
+            self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+
     def get_placeholder_flow(self, b, h, w, device):
         coarse_coords = torch.meshgrid(
             (
@@ -615,8 +622,8 @@ class RegressionMatcher(nn.Module):
             # Create im_A meshgrid
             im_A_coords = torch.meshgrid(
                 (
-                    torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device="cuda"),
-                    torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device="cuda"),
+                    torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
+                    torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
                 )
             )
             im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
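Both meshgrid hunks make the same change: the normalized coordinate grid is now created on the module's device (self.device in ConvRefiner, the local device in RegressionMatcher) instead of a hard-coded "cuda". A self-contained sketch of the grid these lines build, runnable on CPU (the function name is illustrative):

import torch

# Normalized pixel-center coordinates in [-1, 1], allocated directly on
# the target device; indexing="ij" matches torch.meshgrid's old default.
def normalized_grid(hs, ws, device):
    ys, xs = torch.meshgrid(
        torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
        torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
        indexing="ij",
    )
    # Stack x before y, as the diff does with (im_A_coords[1], im_A_coords[0]).
    return torch.stack((xs, ys))  # shape (2, hs, ws)

grid = normalized_grid(4, 6, torch.device("cpu"))
assert grid.shape == (2, 4, 6)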