Spaces:
Runtime error
Runtime error
fix: cache inference
Browse files- .gitignore +2 -0
- common/utils.py +1 -1
- common/viz.py +2 -2
- hloc/extractors/superpoint.py +1 -1
- third_party/SuperGluePretrainedNetwork/models/superpoint.py +16 -6
.gitignore
CHANGED
|
@@ -18,3 +18,5 @@ gradio_cached_examples
|
|
| 18 |
hloc/matchers/quadtree.py
|
| 19 |
third_party/QuadTreeAttention
|
| 20 |
desktop.ini
|
|
|
|
|
|
|
|
|
| 18 |
hloc/matchers/quadtree.py
|
| 19 |
third_party/QuadTreeAttention
|
| 20 |
desktop.ini
|
| 21 |
+
experiments*
|
| 22 |
+
datasets/wxbs_benchmark
|
common/utils.py
CHANGED
|
@@ -518,7 +518,7 @@ def run_matching(
|
|
| 518 |
gr.Info(f"Matching images done using: {time.time()-t1:.3f}s")
|
| 519 |
logger.info(f"Matching images done using: {time.time()-t1:.3f}s")
|
| 520 |
t1 = time.time()
|
| 521 |
-
# plot images with keypoints
|
| 522 |
titles = [
|
| 523 |
"Image 0 - Keypoints",
|
| 524 |
"Image 1 - Keypoints",
|
|
|
|
| 518 |
gr.Info(f"Matching images done using: {time.time()-t1:.3f}s")
|
| 519 |
logger.info(f"Matching images done using: {time.time()-t1:.3f}s")
|
| 520 |
t1 = time.time()
|
| 521 |
+
# plot images with keypoints\
|
| 522 |
titles = [
|
| 523 |
"Image 0 - Keypoints",
|
| 524 |
"Image 1 - Keypoints",
|
common/viz.py
CHANGED
|
@@ -293,7 +293,7 @@ def draw_matches_core(
|
|
| 293 |
mkpts1,
|
| 294 |
color,
|
| 295 |
titles=titles,
|
| 296 |
-
|
| 297 |
path=path,
|
| 298 |
dpi=dpi,
|
| 299 |
pad=pad,
|
|
@@ -308,7 +308,7 @@ def draw_matches_core(
|
|
| 308 |
mkpts1,
|
| 309 |
color,
|
| 310 |
titles=titles,
|
| 311 |
-
|
| 312 |
pad=pad,
|
| 313 |
dpi=dpi,
|
| 314 |
)
|
|
|
|
| 293 |
mkpts1,
|
| 294 |
color,
|
| 295 |
titles=titles,
|
| 296 |
+
text=text,
|
| 297 |
path=path,
|
| 298 |
dpi=dpi,
|
| 299 |
pad=pad,
|
|
|
|
| 308 |
mkpts1,
|
| 309 |
color,
|
| 310 |
titles=titles,
|
| 311 |
+
text=text,
|
| 312 |
pad=pad,
|
| 313 |
dpi=dpi,
|
| 314 |
)
|
hloc/extractors/superpoint.py
CHANGED
|
@@ -44,4 +44,4 @@ class SuperPoint(BaseModel):
|
|
| 44 |
self.net = superpoint.SuperPoint(conf)
|
| 45 |
|
| 46 |
def _forward(self, data):
|
| 47 |
-
return self.net(data)
|
|
|
|
| 44 |
self.net = superpoint.SuperPoint(conf)
|
| 45 |
|
| 46 |
def _forward(self, data):
|
| 47 |
+
return self.net(data, self.conf)
|
third_party/SuperGluePretrainedNetwork/models/superpoint.py
CHANGED
|
@@ -83,9 +83,9 @@ def sample_descriptors(keypoints, descriptors, s: int = 8):
|
|
| 83 |
"""Interpolate descriptors at keypoint locations"""
|
| 84 |
b, c, h, w = descriptors.shape
|
| 85 |
keypoints = keypoints - s / 2 + 0.5
|
| 86 |
-
keypoints /= torch.tensor(
|
| 87 |
-
|
| 88 |
-
)[None]
|
| 89 |
keypoints = keypoints * 2 - 1 # normalize to (-1, 1)
|
| 90 |
args = {"align_corners": True} if torch.__version__ >= "1.3" else {}
|
| 91 |
descriptors = torch.nn.functional.grid_sample(
|
|
@@ -136,7 +136,11 @@ class SuperPoint(nn.Module):
|
|
| 136 |
|
| 137 |
self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
|
| 138 |
self.convDb = nn.Conv2d(
|
| 139 |
-
c5,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
)
|
| 141 |
|
| 142 |
path = Path(__file__).parent / "weights/superpoint_v1.pth"
|
|
@@ -148,8 +152,12 @@ class SuperPoint(nn.Module):
|
|
| 148 |
|
| 149 |
print("Loaded SuperPoint model")
|
| 150 |
|
| 151 |
-
def forward(self, data):
|
| 152 |
"""Compute keypoints, scores, descriptors for image"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
# Shared Encoder
|
| 154 |
x = self.relu(self.conv1a(data["image"]))
|
| 155 |
x = self.relu(self.conv1b(x))
|
|
@@ -182,7 +190,9 @@ class SuperPoint(nn.Module):
|
|
| 182 |
keypoints, scores = list(
|
| 183 |
zip(
|
| 184 |
*[
|
| 185 |
-
remove_borders(
|
|
|
|
|
|
|
| 186 |
for k, s in zip(keypoints, scores)
|
| 187 |
]
|
| 188 |
)
|
|
|
|
| 83 |
"""Interpolate descriptors at keypoint locations"""
|
| 84 |
b, c, h, w = descriptors.shape
|
| 85 |
keypoints = keypoints - s / 2 + 0.5
|
| 86 |
+
keypoints /= torch.tensor(
|
| 87 |
+
[(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)],
|
| 88 |
+
).to(keypoints)[None]
|
| 89 |
keypoints = keypoints * 2 - 1 # normalize to (-1, 1)
|
| 90 |
args = {"align_corners": True} if torch.__version__ >= "1.3" else {}
|
| 91 |
descriptors = torch.nn.functional.grid_sample(
|
|
|
|
| 136 |
|
| 137 |
self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
|
| 138 |
self.convDb = nn.Conv2d(
|
| 139 |
+
c5,
|
| 140 |
+
self.config["descriptor_dim"],
|
| 141 |
+
kernel_size=1,
|
| 142 |
+
stride=1,
|
| 143 |
+
padding=0,
|
| 144 |
)
|
| 145 |
|
| 146 |
path = Path(__file__).parent / "weights/superpoint_v1.pth"
|
|
|
|
| 152 |
|
| 153 |
print("Loaded SuperPoint model")
|
| 154 |
|
| 155 |
+
def forward(self, data, cfg={}):
|
| 156 |
"""Compute keypoints, scores, descriptors for image"""
|
| 157 |
+
self.config = {
|
| 158 |
+
**self.config,
|
| 159 |
+
**cfg,
|
| 160 |
+
}
|
| 161 |
# Shared Encoder
|
| 162 |
x = self.relu(self.conv1a(data["image"]))
|
| 163 |
x = self.relu(self.conv1b(x))
|
|
|
|
| 190 |
keypoints, scores = list(
|
| 191 |
zip(
|
| 192 |
*[
|
| 193 |
+
remove_borders(
|
| 194 |
+
k, s, self.config["remove_borders"], h * 8, w * 8
|
| 195 |
+
)
|
| 196 |
for k, s in zip(keypoints, scores)
|
| 197 |
]
|
| 198 |
)
|