Spaces:
Runtime error
Runtime error
Vincentqyw
commited on
Commit
·
7b977a8
1
Parent(s):
cd1bd30
update: ui
Browse files- app.py +186 -162
- common/utils.py +323 -12
- common/visualize_util.py +0 -642
- common/{plotting.py → viz.py} +116 -21
- style.css +18 -0
app.py
CHANGED
|
@@ -1,59 +1,20 @@
|
|
| 1 |
import argparse
|
| 2 |
import gradio as gr
|
| 3 |
-
|
| 4 |
-
from hloc import extract_features
|
| 5 |
from common.utils import (
|
| 6 |
matcher_zoo,
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
get_feature_model,
|
| 12 |
-
display_matches,
|
| 13 |
)
|
| 14 |
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
-
|
| 17 |
-
match_threshold, extract_max_keypoints, keypoint_threshold, key, image0, image1
|
| 18 |
-
):
|
| 19 |
-
# image0 and image1 is RGB mode
|
| 20 |
-
if image0 is None or image1 is None:
|
| 21 |
-
raise gr.Error("Error: No images found! Please upload two images.")
|
| 22 |
-
|
| 23 |
-
model = matcher_zoo[key]
|
| 24 |
-
match_conf = model["config"]
|
| 25 |
-
# update match config
|
| 26 |
-
match_conf["model"]["match_threshold"] = match_threshold
|
| 27 |
-
match_conf["model"]["max_keypoints"] = extract_max_keypoints
|
| 28 |
|
| 29 |
-
|
| 30 |
-
if model["dense"]:
|
| 31 |
-
pred = match_dense.match_images(
|
| 32 |
-
matcher, image0, image1, match_conf["preprocessing"], device=device
|
| 33 |
-
)
|
| 34 |
-
del matcher
|
| 35 |
-
extract_conf = None
|
| 36 |
-
else:
|
| 37 |
-
extract_conf = model["config_feature"]
|
| 38 |
-
# update extract config
|
| 39 |
-
extract_conf["model"]["max_keypoints"] = extract_max_keypoints
|
| 40 |
-
extract_conf["model"]["keypoint_threshold"] = keypoint_threshold
|
| 41 |
-
extractor = get_feature_model(extract_conf)
|
| 42 |
-
pred0 = extract_features.extract(
|
| 43 |
-
extractor, image0, extract_conf["preprocessing"]
|
| 44 |
-
)
|
| 45 |
-
pred1 = extract_features.extract(
|
| 46 |
-
extractor, image1, extract_conf["preprocessing"]
|
| 47 |
-
)
|
| 48 |
-
pred = match_features.match_images(matcher, pred0, pred1)
|
| 49 |
-
del extractor
|
| 50 |
-
fig, num_inliers = display_matches(pred)
|
| 51 |
-
del pred
|
| 52 |
-
return (
|
| 53 |
-
fig,
|
| 54 |
-
{"matches number": num_inliers},
|
| 55 |
-
{"match_conf": match_conf, "extractor_conf": extract_conf},
|
| 56 |
-
)
|
| 57 |
|
| 58 |
|
| 59 |
def ui_change_imagebox(choice):
|
|
@@ -61,7 +22,18 @@ def ui_change_imagebox(choice):
|
|
| 61 |
|
| 62 |
|
| 63 |
def ui_reset_state(
|
| 64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
):
|
| 66 |
match_threshold = 0.2
|
| 67 |
extract_max_keypoints = 1000
|
|
@@ -69,31 +41,35 @@ def ui_reset_state(
|
|
| 69 |
key = list(matcher_zoo.keys())[0]
|
| 70 |
image0 = None
|
| 71 |
image1 = None
|
|
|
|
| 72 |
return (
|
|
|
|
|
|
|
| 73 |
match_threshold,
|
| 74 |
extract_max_keypoints,
|
| 75 |
keypoint_threshold,
|
| 76 |
key,
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
{"value": None, "source": "upload", "__type__": "update"},
|
| 80 |
-
{"value": None, "source": "upload", "__type__": "update"},
|
| 81 |
"upload",
|
| 82 |
None,
|
| 83 |
{},
|
| 84 |
{},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
)
|
| 86 |
|
| 87 |
|
|
|
|
| 88 |
def run(config):
|
| 89 |
-
with gr.Blocks(css="
|
| 90 |
-
gr.Markdown(
|
| 91 |
-
"""
|
| 92 |
-
<p align="center">
|
| 93 |
-
<h1 align="center">Image Matching WebUI</h1>
|
| 94 |
-
</p>
|
| 95 |
-
"""
|
| 96 |
-
)
|
| 97 |
|
| 98 |
with gr.Row(equal_height=False):
|
| 99 |
with gr.Column():
|
|
@@ -109,43 +85,6 @@ def run(config):
|
|
| 109 |
label="Image Source",
|
| 110 |
value="upload",
|
| 111 |
)
|
| 112 |
-
|
| 113 |
-
with gr.Row():
|
| 114 |
-
match_setting_threshold = gr.Slider(
|
| 115 |
-
minimum=0.0,
|
| 116 |
-
maximum=1,
|
| 117 |
-
step=0.001,
|
| 118 |
-
label="Match threshold",
|
| 119 |
-
value=0.1,
|
| 120 |
-
)
|
| 121 |
-
match_setting_max_features = gr.Slider(
|
| 122 |
-
minimum=10,
|
| 123 |
-
maximum=10000,
|
| 124 |
-
step=10,
|
| 125 |
-
label="Max number of features",
|
| 126 |
-
value=1000,
|
| 127 |
-
)
|
| 128 |
-
# TODO: add line settings
|
| 129 |
-
with gr.Row():
|
| 130 |
-
detect_keypoints_threshold = gr.Slider(
|
| 131 |
-
minimum=0,
|
| 132 |
-
maximum=1,
|
| 133 |
-
step=0.001,
|
| 134 |
-
label="Keypoint threshold",
|
| 135 |
-
value=0.015,
|
| 136 |
-
)
|
| 137 |
-
detect_line_threshold = gr.Slider(
|
| 138 |
-
minimum=0.1,
|
| 139 |
-
maximum=1,
|
| 140 |
-
step=0.01,
|
| 141 |
-
label="Line threshold",
|
| 142 |
-
value=0.2,
|
| 143 |
-
)
|
| 144 |
-
# matcher_lists = gr.Radio(
|
| 145 |
-
# ["NN-mutual", "Dual-Softmax"],
|
| 146 |
-
# label="Matcher mode",
|
| 147 |
-
# value="NN-mutual",
|
| 148 |
-
# )
|
| 149 |
with gr.Row():
|
| 150 |
input_image0 = gr.Image(
|
| 151 |
label="Image 0",
|
|
@@ -166,89 +105,147 @@ def run(config):
|
|
| 166 |
label="Run Match", value="Run Match", variant="primary"
|
| 167 |
)
|
| 168 |
|
| 169 |
-
with gr.Accordion("
|
| 170 |
-
gr.
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
|
|
|
|
| 177 |
# collect inputs
|
| 178 |
inputs = [
|
|
|
|
|
|
|
| 179 |
match_setting_threshold,
|
| 180 |
match_setting_max_features,
|
| 181 |
detect_keypoints_threshold,
|
| 182 |
matcher_list,
|
| 183 |
-
|
| 184 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
]
|
| 186 |
|
| 187 |
# Add some examples
|
| 188 |
with gr.Row():
|
| 189 |
-
examples = [
|
| 190 |
-
[
|
| 191 |
-
0.1,
|
| 192 |
-
2000,
|
| 193 |
-
0.015,
|
| 194 |
-
"disk+lightglue",
|
| 195 |
-
"datasets/sacre_coeur/mapping/71295362_4051449754.jpg",
|
| 196 |
-
"datasets/sacre_coeur/mapping/93341989_396310999.jpg",
|
| 197 |
-
],
|
| 198 |
-
[
|
| 199 |
-
0.1,
|
| 200 |
-
2000,
|
| 201 |
-
0.015,
|
| 202 |
-
"loftr",
|
| 203 |
-
"datasets/sacre_coeur/mapping/03903474_1471484089.jpg",
|
| 204 |
-
"datasets/sacre_coeur/mapping/02928139_3448003521.jpg",
|
| 205 |
-
],
|
| 206 |
-
[
|
| 207 |
-
0.1,
|
| 208 |
-
2000,
|
| 209 |
-
0.015,
|
| 210 |
-
"disk",
|
| 211 |
-
"datasets/sacre_coeur/mapping/10265353_3838484249.jpg",
|
| 212 |
-
"datasets/sacre_coeur/mapping/51091044_3486849416.jpg",
|
| 213 |
-
],
|
| 214 |
-
[
|
| 215 |
-
0.1,
|
| 216 |
-
2000,
|
| 217 |
-
0.015,
|
| 218 |
-
"topicfm",
|
| 219 |
-
"datasets/sacre_coeur/mapping/44120379_8371960244.jpg",
|
| 220 |
-
"datasets/sacre_coeur/mapping/93341989_396310999.jpg",
|
| 221 |
-
],
|
| 222 |
-
[
|
| 223 |
-
0.1,
|
| 224 |
-
2000,
|
| 225 |
-
0.015,
|
| 226 |
-
"superpoint+superglue",
|
| 227 |
-
"datasets/sacre_coeur/mapping/17295357_9106075285.jpg",
|
| 228 |
-
"datasets/sacre_coeur/mapping/44120379_8371960244.jpg",
|
| 229 |
-
],
|
| 230 |
-
]
|
| 231 |
# Example inputs
|
| 232 |
gr.Examples(
|
| 233 |
-
examples=
|
| 234 |
inputs=inputs,
|
| 235 |
outputs=[],
|
| 236 |
fn=run_matching,
|
| 237 |
-
cache_examples=
|
| 238 |
-
label=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
)
|
| 240 |
|
| 241 |
with gr.Column():
|
| 242 |
-
output_mkpts = gr.Image(
|
| 243 |
-
|
| 244 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 245 |
|
| 246 |
# callbacks
|
| 247 |
match_image_src.change(
|
| 248 |
-
fn=ui_change_imagebox,
|
|
|
|
|
|
|
| 249 |
)
|
| 250 |
match_image_src.change(
|
| 251 |
-
fn=ui_change_imagebox,
|
|
|
|
|
|
|
| 252 |
)
|
| 253 |
|
| 254 |
# collect outputs
|
|
@@ -256,34 +253,61 @@ def run(config):
|
|
| 256 |
output_mkpts,
|
| 257 |
matches_result_info,
|
| 258 |
matcher_info,
|
|
|
|
|
|
|
| 259 |
]
|
| 260 |
# button callbacks
|
| 261 |
button_run.click(fn=run_matching, inputs=inputs, outputs=outputs)
|
| 262 |
|
| 263 |
# Reset images
|
| 264 |
reset_outputs = [
|
|
|
|
|
|
|
| 265 |
match_setting_threshold,
|
| 266 |
match_setting_max_features,
|
| 267 |
detect_keypoints_threshold,
|
| 268 |
matcher_list,
|
| 269 |
input_image0,
|
| 270 |
input_image1,
|
| 271 |
-
input_image0,
|
| 272 |
-
input_image1,
|
| 273 |
match_image_src,
|
| 274 |
output_mkpts,
|
| 275 |
matches_result_info,
|
| 276 |
matcher_info,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 277 |
]
|
| 278 |
-
button_reset.click(
|
| 279 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 280 |
app.launch(share=False)
|
| 281 |
|
| 282 |
|
| 283 |
if __name__ == "__main__":
|
| 284 |
parser = argparse.ArgumentParser()
|
| 285 |
parser.add_argument(
|
| 286 |
-
"--config_path",
|
|
|
|
|
|
|
|
|
|
| 287 |
)
|
| 288 |
args = parser.parse_args()
|
| 289 |
config = None
|
|
|
|
| 1 |
import argparse
|
| 2 |
import gradio as gr
|
|
|
|
|
|
|
| 3 |
from common.utils import (
|
| 4 |
matcher_zoo,
|
| 5 |
+
change_estimate_geom,
|
| 6 |
+
run_matching,
|
| 7 |
+
ransac_zoo,
|
| 8 |
+
gen_examples,
|
|
|
|
|
|
|
| 9 |
)
|
| 10 |
|
| 11 |
+
DESCRIPTION = """
|
| 12 |
+
# Image Matching WebUI
|
| 13 |
+
This Space demonstrates [Image Matching WebUI](https://github.com/Vincentqyw/image-matching-webui) by vincent qin. Feel free to play with it, or duplicate to run image matching without a queue!
|
| 14 |
|
| 15 |
+
🔎 For more details about supported local features and matchers, please refer to https://github.com/Vincentqyw/image-matching-webui
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
+
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
|
| 20 |
def ui_change_imagebox(choice):
|
|
|
|
| 22 |
|
| 23 |
|
| 24 |
def ui_reset_state(
|
| 25 |
+
image0,
|
| 26 |
+
image1,
|
| 27 |
+
match_threshold,
|
| 28 |
+
extract_max_keypoints,
|
| 29 |
+
keypoint_threshold,
|
| 30 |
+
key,
|
| 31 |
+
enable_ransac=False,
|
| 32 |
+
ransac_method="RANSAC",
|
| 33 |
+
ransac_reproj_threshold=8,
|
| 34 |
+
ransac_confidence=0.999,
|
| 35 |
+
ransac_max_iter=10000,
|
| 36 |
+
choice_estimate_geom="Homography",
|
| 37 |
):
|
| 38 |
match_threshold = 0.2
|
| 39 |
extract_max_keypoints = 1000
|
|
|
|
| 41 |
key = list(matcher_zoo.keys())[0]
|
| 42 |
image0 = None
|
| 43 |
image1 = None
|
| 44 |
+
enable_ransac = False
|
| 45 |
return (
|
| 46 |
+
image0,
|
| 47 |
+
image1,
|
| 48 |
match_threshold,
|
| 49 |
extract_max_keypoints,
|
| 50 |
keypoint_threshold,
|
| 51 |
key,
|
| 52 |
+
ui_change_imagebox("upload"),
|
| 53 |
+
ui_change_imagebox("upload"),
|
|
|
|
|
|
|
| 54 |
"upload",
|
| 55 |
None,
|
| 56 |
{},
|
| 57 |
{},
|
| 58 |
+
None,
|
| 59 |
+
{},
|
| 60 |
+
False,
|
| 61 |
+
"RANSAC",
|
| 62 |
+
8,
|
| 63 |
+
0.999,
|
| 64 |
+
10000,
|
| 65 |
+
"Homography",
|
| 66 |
)
|
| 67 |
|
| 68 |
|
| 69 |
+
# "footer {visibility: hidden}"
|
| 70 |
def run(config):
|
| 71 |
+
with gr.Blocks(css="style.css") as app:
|
| 72 |
+
gr.Markdown(DESCRIPTION)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
|
| 74 |
with gr.Row(equal_height=False):
|
| 75 |
with gr.Column():
|
|
|
|
| 85 |
label="Image Source",
|
| 86 |
value="upload",
|
| 87 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
with gr.Row():
|
| 89 |
input_image0 = gr.Image(
|
| 90 |
label="Image 0",
|
|
|
|
| 105 |
label="Run Match", value="Run Match", variant="primary"
|
| 106 |
)
|
| 107 |
|
| 108 |
+
with gr.Accordion("Advanced Setting", open=False):
|
| 109 |
+
with gr.Accordion("Matching Setting", open=True):
|
| 110 |
+
with gr.Row():
|
| 111 |
+
match_setting_threshold = gr.Slider(
|
| 112 |
+
minimum=0.0,
|
| 113 |
+
maximum=1,
|
| 114 |
+
step=0.001,
|
| 115 |
+
label="Match thres.",
|
| 116 |
+
value=0.1,
|
| 117 |
+
)
|
| 118 |
+
match_setting_max_features = gr.Slider(
|
| 119 |
+
minimum=10,
|
| 120 |
+
maximum=10000,
|
| 121 |
+
step=10,
|
| 122 |
+
label="Max features",
|
| 123 |
+
value=1000,
|
| 124 |
+
)
|
| 125 |
+
# TODO: add line settings
|
| 126 |
+
with gr.Row():
|
| 127 |
+
detect_keypoints_threshold = gr.Slider(
|
| 128 |
+
minimum=0,
|
| 129 |
+
maximum=1,
|
| 130 |
+
step=0.001,
|
| 131 |
+
label="Keypoint thres.",
|
| 132 |
+
value=0.015,
|
| 133 |
+
)
|
| 134 |
+
detect_line_threshold = gr.Slider(
|
| 135 |
+
minimum=0.1,
|
| 136 |
+
maximum=1,
|
| 137 |
+
step=0.01,
|
| 138 |
+
label="Line thres.",
|
| 139 |
+
value=0.2,
|
| 140 |
+
)
|
| 141 |
+
# matcher_lists = gr.Radio(
|
| 142 |
+
# ["NN-mutual", "Dual-Softmax"],
|
| 143 |
+
# label="Matcher mode",
|
| 144 |
+
# value="NN-mutual",
|
| 145 |
+
# )
|
| 146 |
+
with gr.Accordion("RANSAC Setting", open=False):
|
| 147 |
+
with gr.Row(equal_height=False):
|
| 148 |
+
enable_ransac = gr.Checkbox(label="Enable RANSAC")
|
| 149 |
+
ransac_method = gr.Dropdown(
|
| 150 |
+
choices=ransac_zoo.keys(),
|
| 151 |
+
value="RANSAC",
|
| 152 |
+
label="RANSAC Method",
|
| 153 |
+
interactive=True,
|
| 154 |
+
)
|
| 155 |
+
ransac_reproj_threshold = gr.Slider(
|
| 156 |
+
minimum=0.0,
|
| 157 |
+
maximum=12,
|
| 158 |
+
step=0.01,
|
| 159 |
+
label="Ransac Reproj threshold",
|
| 160 |
+
value=8.0,
|
| 161 |
+
)
|
| 162 |
+
ransac_confidence = gr.Slider(
|
| 163 |
+
minimum=0.0,
|
| 164 |
+
maximum=1,
|
| 165 |
+
step=0.00001,
|
| 166 |
+
label="Ransac Confidence",
|
| 167 |
+
value=0.99999,
|
| 168 |
+
)
|
| 169 |
+
ransac_max_iter = gr.Slider(
|
| 170 |
+
minimum=0.0,
|
| 171 |
+
maximum=100000,
|
| 172 |
+
step=100,
|
| 173 |
+
label="Ransac Iterations",
|
| 174 |
+
value=10000,
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
with gr.Accordion("Geometry Setting", open=True):
|
| 178 |
+
with gr.Row(equal_height=False):
|
| 179 |
+
# show_geom = gr.Checkbox(label="Show Geometry")
|
| 180 |
+
choice_estimate_geom = gr.Radio(
|
| 181 |
+
["Fundamental", "Homography"],
|
| 182 |
+
label="Reconstruct Geometry",
|
| 183 |
+
value="Homography",
|
| 184 |
+
)
|
| 185 |
|
| 186 |
+
# with gr.Column():
|
| 187 |
# collect inputs
|
| 188 |
inputs = [
|
| 189 |
+
input_image0,
|
| 190 |
+
input_image1,
|
| 191 |
match_setting_threshold,
|
| 192 |
match_setting_max_features,
|
| 193 |
detect_keypoints_threshold,
|
| 194 |
matcher_list,
|
| 195 |
+
enable_ransac,
|
| 196 |
+
ransac_method,
|
| 197 |
+
ransac_reproj_threshold,
|
| 198 |
+
ransac_confidence,
|
| 199 |
+
ransac_max_iter,
|
| 200 |
+
choice_estimate_geom,
|
| 201 |
]
|
| 202 |
|
| 203 |
# Add some examples
|
| 204 |
with gr.Row():
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
# Example inputs
|
| 206 |
gr.Examples(
|
| 207 |
+
examples=gen_examples(),
|
| 208 |
inputs=inputs,
|
| 209 |
outputs=[],
|
| 210 |
fn=run_matching,
|
| 211 |
+
cache_examples=False,
|
| 212 |
+
label=(
|
| 213 |
+
"Examples (click one of the images below to Run"
|
| 214 |
+
" Match)"
|
| 215 |
+
),
|
| 216 |
+
)
|
| 217 |
+
with gr.Accordion("Open for More!", open=False):
|
| 218 |
+
gr.Markdown(
|
| 219 |
+
f"""
|
| 220 |
+
<h3>Supported Algorithms</h3>
|
| 221 |
+
{", ".join(matcher_zoo.keys())}
|
| 222 |
+
"""
|
| 223 |
)
|
| 224 |
|
| 225 |
with gr.Column():
|
| 226 |
+
output_mkpts = gr.Image(
|
| 227 |
+
label="Keypoints Matching", type="numpy"
|
| 228 |
+
)
|
| 229 |
+
with gr.Accordion(
|
| 230 |
+
"Open for More: Matches Statistics", open=False
|
| 231 |
+
):
|
| 232 |
+
matches_result_info = gr.JSON(label="Matches Statistics")
|
| 233 |
+
matcher_info = gr.JSON(label="Match info")
|
| 234 |
+
|
| 235 |
+
output_wrapped = gr.Image(label="Wrapped Pair", type="numpy")
|
| 236 |
+
with gr.Accordion("Open for More: Geometry info", open=False):
|
| 237 |
+
geometry_result = gr.JSON(label="Reconstructed Geometry")
|
| 238 |
|
| 239 |
# callbacks
|
| 240 |
match_image_src.change(
|
| 241 |
+
fn=ui_change_imagebox,
|
| 242 |
+
inputs=match_image_src,
|
| 243 |
+
outputs=input_image0,
|
| 244 |
)
|
| 245 |
match_image_src.change(
|
| 246 |
+
fn=ui_change_imagebox,
|
| 247 |
+
inputs=match_image_src,
|
| 248 |
+
outputs=input_image1,
|
| 249 |
)
|
| 250 |
|
| 251 |
# collect outputs
|
|
|
|
| 253 |
output_mkpts,
|
| 254 |
matches_result_info,
|
| 255 |
matcher_info,
|
| 256 |
+
geometry_result,
|
| 257 |
+
output_wrapped,
|
| 258 |
]
|
| 259 |
# button callbacks
|
| 260 |
button_run.click(fn=run_matching, inputs=inputs, outputs=outputs)
|
| 261 |
|
| 262 |
# Reset images
|
| 263 |
reset_outputs = [
|
| 264 |
+
input_image0,
|
| 265 |
+
input_image1,
|
| 266 |
match_setting_threshold,
|
| 267 |
match_setting_max_features,
|
| 268 |
detect_keypoints_threshold,
|
| 269 |
matcher_list,
|
| 270 |
input_image0,
|
| 271 |
input_image1,
|
|
|
|
|
|
|
| 272 |
match_image_src,
|
| 273 |
output_mkpts,
|
| 274 |
matches_result_info,
|
| 275 |
matcher_info,
|
| 276 |
+
output_wrapped,
|
| 277 |
+
geometry_result,
|
| 278 |
+
enable_ransac,
|
| 279 |
+
ransac_method,
|
| 280 |
+
ransac_reproj_threshold,
|
| 281 |
+
ransac_confidence,
|
| 282 |
+
ransac_max_iter,
|
| 283 |
+
choice_estimate_geom,
|
| 284 |
]
|
| 285 |
+
button_reset.click(
|
| 286 |
+
fn=ui_reset_state, inputs=inputs, outputs=reset_outputs
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
# estimate geo
|
| 290 |
+
choice_estimate_geom.change(
|
| 291 |
+
fn=change_estimate_geom,
|
| 292 |
+
inputs=[
|
| 293 |
+
input_image0,
|
| 294 |
+
input_image1,
|
| 295 |
+
geometry_result,
|
| 296 |
+
choice_estimate_geom,
|
| 297 |
+
],
|
| 298 |
+
outputs=[output_wrapped, geometry_result],
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
app.launch(share=False)
|
| 302 |
|
| 303 |
|
| 304 |
if __name__ == "__main__":
|
| 305 |
parser = argparse.ArgumentParser()
|
| 306 |
parser.add_argument(
|
| 307 |
+
"--config_path",
|
| 308 |
+
type=str,
|
| 309 |
+
default="config.yaml",
|
| 310 |
+
help="configuration file path",
|
| 311 |
)
|
| 312 |
args = parser.parse_args()
|
| 313 |
config = None
|
common/utils.py
CHANGED
|
@@ -1,11 +1,14 @@
|
|
| 1 |
-
import
|
|
|
|
| 2 |
import numpy as np
|
|
|
|
|
|
|
| 3 |
import cv2
|
|
|
|
| 4 |
from hloc import matchers, extractors
|
| 5 |
from hloc.utils.base_model import dynamic_load
|
| 6 |
from hloc import match_dense, match_features, extract_features
|
| 7 |
-
from .
|
| 8 |
-
from .visualize_util import plot_images, plot_color_line_matches
|
| 9 |
|
| 10 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 11 |
|
|
@@ -22,6 +25,217 @@ def get_feature_model(conf):
|
|
| 22 |
return model
|
| 23 |
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
def display_matches(pred: dict):
|
| 26 |
img0 = pred["image0_orig"]
|
| 27 |
img1 = pred["image1_orig"]
|
|
@@ -42,7 +256,10 @@ def display_matches(pred: dict):
|
|
| 42 |
img1,
|
| 43 |
mconf,
|
| 44 |
dpi=300,
|
| 45 |
-
titles=[
|
|
|
|
|
|
|
|
|
|
| 46 |
)
|
| 47 |
fig = fig_mkpts
|
| 48 |
if "line0_orig" in pred.keys() and "line1_orig" in pred.keys():
|
|
@@ -69,13 +286,107 @@ def display_matches(pred: dict):
|
|
| 69 |
else:
|
| 70 |
mconf = np.ones(len(mkpts0))
|
| 71 |
fig_mkpts = draw_matches(mkpts0, mkpts1, img0, img1, mconf, dpi=300)
|
| 72 |
-
fig_lines = cv2.resize(
|
|
|
|
|
|
|
| 73 |
fig = np.concatenate([fig_mkpts, fig_lines], axis=0)
|
| 74 |
else:
|
| 75 |
fig = fig_lines
|
| 76 |
return fig, num_inliers
|
| 77 |
|
| 78 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
# Matchers collections
|
| 80 |
matcher_zoo = {
|
| 81 |
"gluestick": {"config": match_dense.confs["gluestick"], "dense": True},
|
|
@@ -147,11 +458,11 @@ matcher_zoo = {
|
|
| 147 |
"config_feature": extract_features.confs["d2net-ss"],
|
| 148 |
"dense": False,
|
| 149 |
},
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
"alike": {
|
| 156 |
"config": match_features.confs["NN-mutual"],
|
| 157 |
"config_feature": extract_features.confs["alike"],
|
|
@@ -177,6 +488,6 @@ matcher_zoo = {
|
|
| 177 |
"config_feature": extract_features.confs["sift"],
|
| 178 |
"dense": False,
|
| 179 |
},
|
| 180 |
-
|
| 181 |
-
|
| 182 |
}
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import random
|
| 3 |
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
from itertools import combinations
|
| 6 |
import cv2
|
| 7 |
+
import gradio as gr
|
| 8 |
from hloc import matchers, extractors
|
| 9 |
from hloc.utils.base_model import dynamic_load
|
| 10 |
from hloc import match_dense, match_features, extract_features
|
| 11 |
+
from .viz import draw_matches, fig2im, plot_images, plot_color_line_matches
|
|
|
|
| 12 |
|
| 13 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 14 |
|
|
|
|
| 25 |
return model
|
| 26 |
|
| 27 |
|
| 28 |
+
def gen_examples():
|
| 29 |
+
random.seed(1)
|
| 30 |
+
example_matchers = [
|
| 31 |
+
"disk+lightglue",
|
| 32 |
+
"loftr",
|
| 33 |
+
"disk",
|
| 34 |
+
"d2net",
|
| 35 |
+
"topicfm",
|
| 36 |
+
"superpoint+superglue",
|
| 37 |
+
"disk+dualsoftmax",
|
| 38 |
+
"lanet",
|
| 39 |
+
]
|
| 40 |
+
|
| 41 |
+
def gen_images_pairs(path: str, count: int = 5):
|
| 42 |
+
imgs_list = [
|
| 43 |
+
os.path.join(path, file)
|
| 44 |
+
for file in os.listdir(path)
|
| 45 |
+
if file.lower().endswith((".jpg", ".jpeg", ".png"))
|
| 46 |
+
]
|
| 47 |
+
pairs = list(combinations(imgs_list, 2))
|
| 48 |
+
selected = random.sample(range(len(pairs)), count)
|
| 49 |
+
return [pairs[i] for i in selected]
|
| 50 |
+
# image pair path
|
| 51 |
+
path = "datasets/sacre_coeur/mapping"
|
| 52 |
+
pairs = gen_images_pairs(path, len(example_matchers))
|
| 53 |
+
match_setting_threshold = 0.1
|
| 54 |
+
match_setting_max_features = 2000
|
| 55 |
+
detect_keypoints_threshold = 0.01
|
| 56 |
+
enable_ransac = False
|
| 57 |
+
ransac_method = "RANSAC"
|
| 58 |
+
ransac_reproj_threshold = 8
|
| 59 |
+
ransac_confidence = 0.999
|
| 60 |
+
ransac_max_iter = 10000
|
| 61 |
+
input_lists = []
|
| 62 |
+
for pair, mt in zip(pairs, example_matchers):
|
| 63 |
+
input_lists.append(
|
| 64 |
+
[
|
| 65 |
+
pair[0],
|
| 66 |
+
pair[1],
|
| 67 |
+
match_setting_threshold,
|
| 68 |
+
match_setting_max_features,
|
| 69 |
+
detect_keypoints_threshold,
|
| 70 |
+
mt,
|
| 71 |
+
enable_ransac,
|
| 72 |
+
ransac_method,
|
| 73 |
+
ransac_reproj_threshold,
|
| 74 |
+
ransac_confidence,
|
| 75 |
+
ransac_max_iter,
|
| 76 |
+
]
|
| 77 |
+
)
|
| 78 |
+
return input_lists
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def filter_matches(
|
| 82 |
+
pred,
|
| 83 |
+
ransac_method="RANSAC",
|
| 84 |
+
ransac_reproj_threshold=8,
|
| 85 |
+
ransac_confidence=0.999,
|
| 86 |
+
ransac_max_iter=10000,
|
| 87 |
+
):
|
| 88 |
+
mkpts0 = None
|
| 89 |
+
mkpts1 = None
|
| 90 |
+
feature_type = None
|
| 91 |
+
if "keypoints0_orig" in pred.keys() and "keypoints1_orig" in pred.keys():
|
| 92 |
+
mkpts0 = pred["keypoints0_orig"]
|
| 93 |
+
mkpts1 = pred["keypoints1_orig"]
|
| 94 |
+
feature_type = "KEYPOINT"
|
| 95 |
+
elif (
|
| 96 |
+
"line_keypoints0_orig" in pred.keys()
|
| 97 |
+
and "line_keypoints1_orig" in pred.keys()
|
| 98 |
+
):
|
| 99 |
+
mkpts0 = pred["line_keypoints0_orig"]
|
| 100 |
+
mkpts1 = pred["line_keypoints1_orig"]
|
| 101 |
+
feature_type = "LINE"
|
| 102 |
+
else:
|
| 103 |
+
return pred
|
| 104 |
+
if mkpts0 is None or mkpts0 is None:
|
| 105 |
+
return pred
|
| 106 |
+
if ransac_method not in ransac_zoo.keys():
|
| 107 |
+
ransac_method = "RANSAC"
|
| 108 |
+
H, mask = cv2.findHomography(
|
| 109 |
+
mkpts0,
|
| 110 |
+
mkpts1,
|
| 111 |
+
method=ransac_zoo[ransac_method],
|
| 112 |
+
ransacReprojThreshold=ransac_reproj_threshold,
|
| 113 |
+
confidence=ransac_confidence,
|
| 114 |
+
maxIters=ransac_max_iter,
|
| 115 |
+
)
|
| 116 |
+
mask = np.array(mask.ravel().astype("bool"), dtype="bool")
|
| 117 |
+
if H is not None:
|
| 118 |
+
if feature_type == "KEYPOINT":
|
| 119 |
+
pred["keypoints0_orig"] = mkpts0[mask]
|
| 120 |
+
pred["keypoints1_orig"] = mkpts1[mask]
|
| 121 |
+
pred["mconf"] = pred["mconf"][mask]
|
| 122 |
+
elif feature_type == "LINE":
|
| 123 |
+
pred["line_keypoints0_orig"] = mkpts0[mask]
|
| 124 |
+
pred["line_keypoints1_orig"] = mkpts1[mask]
|
| 125 |
+
return pred
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def compute_geom(
|
| 129 |
+
pred,
|
| 130 |
+
ransac_method="RANSAC",
|
| 131 |
+
ransac_reproj_threshold=8,
|
| 132 |
+
ransac_confidence=0.999,
|
| 133 |
+
ransac_max_iter=10000,
|
| 134 |
+
) -> dict:
|
| 135 |
+
mkpts0 = None
|
| 136 |
+
mkpts1 = None
|
| 137 |
+
|
| 138 |
+
if "keypoints0_orig" in pred.keys() and "keypoints1_orig" in pred.keys():
|
| 139 |
+
mkpts0 = pred["keypoints0_orig"]
|
| 140 |
+
mkpts1 = pred["keypoints1_orig"]
|
| 141 |
+
|
| 142 |
+
if (
|
| 143 |
+
"line_keypoints0_orig" in pred.keys()
|
| 144 |
+
and "line_keypoints1_orig" in pred.keys()
|
| 145 |
+
):
|
| 146 |
+
mkpts0 = pred["line_keypoints0_orig"]
|
| 147 |
+
mkpts1 = pred["line_keypoints1_orig"]
|
| 148 |
+
|
| 149 |
+
if mkpts0 is not None and mkpts1 is not None:
|
| 150 |
+
if len(mkpts0) < 8:
|
| 151 |
+
return {}
|
| 152 |
+
h1, w1, _ = pred["image0_orig"].shape
|
| 153 |
+
geo_info = {}
|
| 154 |
+
F, inliers = cv2.findFundamentalMat(
|
| 155 |
+
mkpts0,
|
| 156 |
+
mkpts1,
|
| 157 |
+
method=ransac_zoo[ransac_method],
|
| 158 |
+
ransacReprojThreshold=ransac_reproj_threshold,
|
| 159 |
+
confidence=ransac_confidence,
|
| 160 |
+
maxIters=ransac_max_iter,
|
| 161 |
+
)
|
| 162 |
+
geo_info["Fundamental"] = F.tolist()
|
| 163 |
+
H, _ = cv2.findHomography(
|
| 164 |
+
mkpts1,
|
| 165 |
+
mkpts0,
|
| 166 |
+
method=ransac_zoo[ransac_method],
|
| 167 |
+
ransacReprojThreshold=ransac_reproj_threshold,
|
| 168 |
+
confidence=ransac_confidence,
|
| 169 |
+
maxIters=ransac_max_iter,
|
| 170 |
+
)
|
| 171 |
+
geo_info["Homography"] = H.tolist()
|
| 172 |
+
_, H1, H2 = cv2.stereoRectifyUncalibrated(
|
| 173 |
+
mkpts0.reshape(-1, 2), mkpts1.reshape(-1, 2), F, imgSize=(w1, h1)
|
| 174 |
+
)
|
| 175 |
+
geo_info["H1"] = H1.tolist()
|
| 176 |
+
geo_info["H2"] = H2.tolist()
|
| 177 |
+
return geo_info
|
| 178 |
+
else:
|
| 179 |
+
return {}
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def wrap_images(img0, img1, geo_info, geom_type):
|
| 183 |
+
h1, w1, _ = img0.shape
|
| 184 |
+
h2, w2, _ = img1.shape
|
| 185 |
+
result_matrix = None
|
| 186 |
+
if geo_info is not None and len(geo_info) != 0:
|
| 187 |
+
rectified_image0 = img0
|
| 188 |
+
rectified_image1 = None
|
| 189 |
+
H = np.array(geo_info["Homography"])
|
| 190 |
+
F = np.array(geo_info["Fundamental"])
|
| 191 |
+
title = []
|
| 192 |
+
if geom_type == "Homography":
|
| 193 |
+
rectified_image1 = cv2.warpPerspective(
|
| 194 |
+
img1, H, (img0.shape[1] + img1.shape[1], img0.shape[0])
|
| 195 |
+
)
|
| 196 |
+
result_matrix = H
|
| 197 |
+
title = ["Image 0", "Image 1 - warped"]
|
| 198 |
+
elif geom_type == "Fundamental":
|
| 199 |
+
H1, H2 = np.array(geo_info["H1"]), np.array(geo_info["H2"])
|
| 200 |
+
rectified_image0 = cv2.warpPerspective(img0, H1, (w1, h1))
|
| 201 |
+
rectified_image1 = cv2.warpPerspective(img1, H2, (w2, h2))
|
| 202 |
+
result_matrix = F
|
| 203 |
+
title = ["Image 0 - warped", "Image 1 - warped"]
|
| 204 |
+
else:
|
| 205 |
+
print("Error: Unknown geometry type")
|
| 206 |
+
fig = plot_images(
|
| 207 |
+
[rectified_image0.squeeze(), rectified_image1.squeeze()],
|
| 208 |
+
title,
|
| 209 |
+
dpi=300,
|
| 210 |
+
)
|
| 211 |
+
dictionary = {
|
| 212 |
+
"row1": result_matrix[0].tolist(),
|
| 213 |
+
"row2": result_matrix[1].tolist(),
|
| 214 |
+
"row3": result_matrix[2].tolist(),
|
| 215 |
+
}
|
| 216 |
+
return fig2im(fig), dictionary
|
| 217 |
+
else:
|
| 218 |
+
return None, None
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def change_estimate_geom(input_image0, input_image1, matches_info, choice):
|
| 222 |
+
if (
|
| 223 |
+
matches_info is None
|
| 224 |
+
or len(matches_info) < 1
|
| 225 |
+
or "geom_info" not in matches_info.keys()
|
| 226 |
+
):
|
| 227 |
+
return None, None
|
| 228 |
+
geom_info = matches_info["geom_info"]
|
| 229 |
+
wrapped_images = None
|
| 230 |
+
if choice != "No":
|
| 231 |
+
wrapped_images, _ = wrap_images(
|
| 232 |
+
input_image0, input_image1, geom_info, choice
|
| 233 |
+
)
|
| 234 |
+
return wrapped_images, matches_info
|
| 235 |
+
else:
|
| 236 |
+
return None, None
|
| 237 |
+
|
| 238 |
+
|
| 239 |
def display_matches(pred: dict):
|
| 240 |
img0 = pred["image0_orig"]
|
| 241 |
img1 = pred["image1_orig"]
|
|
|
|
| 256 |
img1,
|
| 257 |
mconf,
|
| 258 |
dpi=300,
|
| 259 |
+
titles=[
|
| 260 |
+
"Image 0 - matched keypoints",
|
| 261 |
+
"Image 1 - matched keypoints",
|
| 262 |
+
],
|
| 263 |
)
|
| 264 |
fig = fig_mkpts
|
| 265 |
if "line0_orig" in pred.keys() and "line1_orig" in pred.keys():
|
|
|
|
| 286 |
else:
|
| 287 |
mconf = np.ones(len(mkpts0))
|
| 288 |
fig_mkpts = draw_matches(mkpts0, mkpts1, img0, img1, mconf, dpi=300)
|
| 289 |
+
fig_lines = cv2.resize(
|
| 290 |
+
fig_lines, (fig_mkpts.shape[1], fig_mkpts.shape[0])
|
| 291 |
+
)
|
| 292 |
fig = np.concatenate([fig_mkpts, fig_lines], axis=0)
|
| 293 |
else:
|
| 294 |
fig = fig_lines
|
| 295 |
return fig, num_inliers
|
| 296 |
|
| 297 |
|
| 298 |
+
def run_matching(
|
| 299 |
+
image0,
|
| 300 |
+
image1,
|
| 301 |
+
match_threshold,
|
| 302 |
+
extract_max_keypoints,
|
| 303 |
+
keypoint_threshold,
|
| 304 |
+
key,
|
| 305 |
+
enable_ransac=False,
|
| 306 |
+
ransac_method="RANSAC",
|
| 307 |
+
ransac_reproj_threshold=8,
|
| 308 |
+
ransac_confidence=0.999,
|
| 309 |
+
ransac_max_iter=10000,
|
| 310 |
+
choice_estimate_geom="Homography",
|
| 311 |
+
):
|
| 312 |
+
# image0 and image1 is RGB mode
|
| 313 |
+
if image0 is None or image1 is None:
|
| 314 |
+
raise gr.Error("Error: No images found! Please upload two images.")
|
| 315 |
+
|
| 316 |
+
model = matcher_zoo[key]
|
| 317 |
+
match_conf = model["config"]
|
| 318 |
+
# update match config
|
| 319 |
+
match_conf["model"]["match_threshold"] = match_threshold
|
| 320 |
+
match_conf["model"]["max_keypoints"] = extract_max_keypoints
|
| 321 |
+
|
| 322 |
+
matcher = get_model(match_conf)
|
| 323 |
+
if model["dense"]:
|
| 324 |
+
pred = match_dense.match_images(
|
| 325 |
+
matcher, image0, image1, match_conf["preprocessing"], device=device
|
| 326 |
+
)
|
| 327 |
+
del matcher
|
| 328 |
+
extract_conf = None
|
| 329 |
+
else:
|
| 330 |
+
extract_conf = model["config_feature"]
|
| 331 |
+
# update extract config
|
| 332 |
+
extract_conf["model"]["max_keypoints"] = extract_max_keypoints
|
| 333 |
+
extract_conf["model"]["keypoint_threshold"] = keypoint_threshold
|
| 334 |
+
extractor = get_feature_model(extract_conf)
|
| 335 |
+
pred0 = extract_features.extract(
|
| 336 |
+
extractor, image0, extract_conf["preprocessing"]
|
| 337 |
+
)
|
| 338 |
+
pred1 = extract_features.extract(
|
| 339 |
+
extractor, image1, extract_conf["preprocessing"]
|
| 340 |
+
)
|
| 341 |
+
pred = match_features.match_images(matcher, pred0, pred1)
|
| 342 |
+
del extractor
|
| 343 |
+
|
| 344 |
+
if enable_ransac:
|
| 345 |
+
filter_matches(
|
| 346 |
+
pred,
|
| 347 |
+
ransac_method=ransac_method,
|
| 348 |
+
ransac_reproj_threshold=ransac_reproj_threshold,
|
| 349 |
+
ransac_confidence=ransac_confidence,
|
| 350 |
+
ransac_max_iter=ransac_max_iter,
|
| 351 |
+
)
|
| 352 |
+
|
| 353 |
+
fig, num_inliers = display_matches(pred)
|
| 354 |
+
geom_info = compute_geom(pred)
|
| 355 |
+
output_wrapped, _ = change_estimate_geom(
|
| 356 |
+
pred["image0_orig"],
|
| 357 |
+
pred["image1_orig"],
|
| 358 |
+
{"geom_info": geom_info},
|
| 359 |
+
choice_estimate_geom,
|
| 360 |
+
)
|
| 361 |
+
del pred
|
| 362 |
+
return (
|
| 363 |
+
fig,
|
| 364 |
+
{"matches number": num_inliers},
|
| 365 |
+
{
|
| 366 |
+
"match_conf": match_conf,
|
| 367 |
+
"extractor_conf": extract_conf,
|
| 368 |
+
},
|
| 369 |
+
{
|
| 370 |
+
"geom_info": geom_info,
|
| 371 |
+
},
|
| 372 |
+
output_wrapped,
|
| 373 |
+
# geometry_result,
|
| 374 |
+
)
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
# @ref: https://docs.opencv.org/4.x/d0/d74/md__build_4_x-contrib_docs-lin64_opencv_doc_tutorials_calib3d_usac.html
|
| 378 |
+
# AND: https://opencv.org/blog/2021/06/09/evaluating-opencvs-new-ransacs
|
| 379 |
+
ransac_zoo = {
|
| 380 |
+
"RANSAC": cv2.RANSAC,
|
| 381 |
+
"USAC_MAGSAC": cv2.USAC_MAGSAC,
|
| 382 |
+
"USAC_DEFAULT": cv2.USAC_DEFAULT,
|
| 383 |
+
"USAC_FM_8PTS": cv2.USAC_FM_8PTS,
|
| 384 |
+
"USAC_PROSAC": cv2.USAC_PROSAC,
|
| 385 |
+
"USAC_FAST": cv2.USAC_FAST,
|
| 386 |
+
"USAC_ACCURATE": cv2.USAC_ACCURATE,
|
| 387 |
+
"USAC_PARALLEL": cv2.USAC_PARALLEL,
|
| 388 |
+
}
|
| 389 |
+
|
| 390 |
# Matchers collections
|
| 391 |
matcher_zoo = {
|
| 392 |
"gluestick": {"config": match_dense.confs["gluestick"], "dense": True},
|
|
|
|
| 458 |
"config_feature": extract_features.confs["d2net-ss"],
|
| 459 |
"dense": False,
|
| 460 |
},
|
| 461 |
+
"d2net-ms": {
|
| 462 |
+
"config": match_features.confs["NN-mutual"],
|
| 463 |
+
"config_feature": extract_features.confs["d2net-ms"],
|
| 464 |
+
"dense": False,
|
| 465 |
+
},
|
| 466 |
"alike": {
|
| 467 |
"config": match_features.confs["NN-mutual"],
|
| 468 |
"config_feature": extract_features.confs["alike"],
|
|
|
|
| 488 |
"config_feature": extract_features.confs["sift"],
|
| 489 |
"dense": False,
|
| 490 |
},
|
| 491 |
+
"roma": {"config": match_dense.confs["roma"], "dense": True},
|
| 492 |
+
"DKMv3": {"config": match_dense.confs["dkm"], "dense": True},
|
| 493 |
}
|
common/visualize_util.py
DELETED
|
@@ -1,642 +0,0 @@
|
|
| 1 |
-
""" Organize some frequently used visualization functions. """
|
| 2 |
-
import cv2
|
| 3 |
-
import numpy as np
|
| 4 |
-
import matplotlib
|
| 5 |
-
import matplotlib.pyplot as plt
|
| 6 |
-
import copy
|
| 7 |
-
import seaborn as sns
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
# Plot junctions onto the image (return a separate copy)
|
| 11 |
-
def plot_junctions(input_image, junctions, junc_size=3, color=None):
|
| 12 |
-
"""
|
| 13 |
-
input_image: can be 0~1 float or 0~255 uint8.
|
| 14 |
-
junctions: Nx2 or 2xN np array.
|
| 15 |
-
junc_size: the size of the plotted circles.
|
| 16 |
-
"""
|
| 17 |
-
# Create image copy
|
| 18 |
-
image = copy.copy(input_image)
|
| 19 |
-
# Make sure the image is converted to 255 uint8
|
| 20 |
-
if image.dtype == np.uint8:
|
| 21 |
-
pass
|
| 22 |
-
# A float type image ranging from 0~1
|
| 23 |
-
elif image.dtype in [np.float32, np.float64, np.float] and image.max() <= 2.0:
|
| 24 |
-
image = (image * 255.0).astype(np.uint8)
|
| 25 |
-
# A float type image ranging from 0.~255.
|
| 26 |
-
elif image.dtype in [np.float32, np.float64, np.float] and image.mean() > 10.0:
|
| 27 |
-
image = image.astype(np.uint8)
|
| 28 |
-
else:
|
| 29 |
-
raise ValueError(
|
| 30 |
-
"[Error] Unknown image data type. Expect 0~1 float or 0~255 uint8."
|
| 31 |
-
)
|
| 32 |
-
|
| 33 |
-
# Check whether the image is single channel
|
| 34 |
-
if len(image.shape) == 2 or ((len(image.shape) == 3) and (image.shape[-1] == 1)):
|
| 35 |
-
# Squeeze to H*W first
|
| 36 |
-
image = image.squeeze()
|
| 37 |
-
|
| 38 |
-
# Stack to channle 3
|
| 39 |
-
image = np.concatenate([image[..., None] for _ in range(3)], axis=-1)
|
| 40 |
-
|
| 41 |
-
# Junction dimensions should be N*2
|
| 42 |
-
if not len(junctions.shape) == 2:
|
| 43 |
-
raise ValueError("[Error] junctions should be 2-dim array.")
|
| 44 |
-
|
| 45 |
-
# Always convert to N*2
|
| 46 |
-
if junctions.shape[-1] != 2:
|
| 47 |
-
if junctions.shape[0] == 2:
|
| 48 |
-
junctions = junctions.T
|
| 49 |
-
else:
|
| 50 |
-
raise ValueError("[Error] At least one of the two dims should be 2.")
|
| 51 |
-
|
| 52 |
-
# Round and convert junctions to int (and check the boundary)
|
| 53 |
-
H, W = image.shape[:2]
|
| 54 |
-
junctions = (np.round(junctions)).astype(np.int)
|
| 55 |
-
junctions[junctions < 0] = 0
|
| 56 |
-
junctions[junctions[:, 0] >= H, 0] = H - 1 # (first dim) max bounded by H-1
|
| 57 |
-
junctions[junctions[:, 1] >= W, 1] = W - 1 # (second dim) max bounded by W-1
|
| 58 |
-
|
| 59 |
-
# Iterate through all the junctions
|
| 60 |
-
num_junc = junctions.shape[0]
|
| 61 |
-
if color is None:
|
| 62 |
-
color = (0, 255.0, 0)
|
| 63 |
-
for idx in range(num_junc):
|
| 64 |
-
# Fetch one junction
|
| 65 |
-
junc = junctions[idx, :]
|
| 66 |
-
cv2.circle(
|
| 67 |
-
image, tuple(np.flip(junc)), radius=junc_size, color=color, thickness=3
|
| 68 |
-
)
|
| 69 |
-
|
| 70 |
-
return image
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
# Plot line segements given junctions and line adjecent map
|
| 74 |
-
def plot_line_segments(
|
| 75 |
-
input_image,
|
| 76 |
-
junctions,
|
| 77 |
-
line_map,
|
| 78 |
-
junc_size=3,
|
| 79 |
-
color=(0, 255.0, 0),
|
| 80 |
-
line_width=1,
|
| 81 |
-
plot_survived_junc=True,
|
| 82 |
-
):
|
| 83 |
-
"""
|
| 84 |
-
input_image: can be 0~1 float or 0~255 uint8.
|
| 85 |
-
junctions: Nx2 or 2xN np array.
|
| 86 |
-
line_map: NxN np array
|
| 87 |
-
junc_size: the size of the plotted circles.
|
| 88 |
-
color: color of the line segments (can be string "random")
|
| 89 |
-
line_width: width of the drawn segments.
|
| 90 |
-
plot_survived_junc: whether we only plot the survived junctions.
|
| 91 |
-
"""
|
| 92 |
-
# Create image copy
|
| 93 |
-
image = copy.copy(input_image)
|
| 94 |
-
# Make sure the image is converted to 255 uint8
|
| 95 |
-
if image.dtype == np.uint8:
|
| 96 |
-
pass
|
| 97 |
-
# A float type image ranging from 0~1
|
| 98 |
-
elif image.dtype in [np.float32, np.float64, np.float] and image.max() <= 2.0:
|
| 99 |
-
image = (image * 255.0).astype(np.uint8)
|
| 100 |
-
# A float type image ranging from 0.~255.
|
| 101 |
-
elif image.dtype in [np.float32, np.float64, np.float] and image.mean() > 10.0:
|
| 102 |
-
image = image.astype(np.uint8)
|
| 103 |
-
else:
|
| 104 |
-
raise ValueError(
|
| 105 |
-
"[Error] Unknown image data type. Expect 0~1 float or 0~255 uint8."
|
| 106 |
-
)
|
| 107 |
-
|
| 108 |
-
# Check whether the image is single channel
|
| 109 |
-
if len(image.shape) == 2 or ((len(image.shape) == 3) and (image.shape[-1] == 1)):
|
| 110 |
-
# Squeeze to H*W first
|
| 111 |
-
image = image.squeeze()
|
| 112 |
-
|
| 113 |
-
# Stack to channle 3
|
| 114 |
-
image = np.concatenate([image[..., None] for _ in range(3)], axis=-1)
|
| 115 |
-
|
| 116 |
-
# Junction dimensions should be 2
|
| 117 |
-
if not len(junctions.shape) == 2:
|
| 118 |
-
raise ValueError("[Error] junctions should be 2-dim array.")
|
| 119 |
-
|
| 120 |
-
# Always convert to N*2
|
| 121 |
-
if junctions.shape[-1] != 2:
|
| 122 |
-
if junctions.shape[0] == 2:
|
| 123 |
-
junctions = junctions.T
|
| 124 |
-
else:
|
| 125 |
-
raise ValueError("[Error] At least one of the two dims should be 2.")
|
| 126 |
-
|
| 127 |
-
# line_map dimension should be 2
|
| 128 |
-
if not len(line_map.shape) == 2:
|
| 129 |
-
raise ValueError("[Error] line_map should be 2-dim array.")
|
| 130 |
-
|
| 131 |
-
# Color should be "random" or a list or tuple with length 3
|
| 132 |
-
if color != "random":
|
| 133 |
-
if not (isinstance(color, tuple) or isinstance(color, list)):
|
| 134 |
-
raise ValueError("[Error] color should have type list or tuple.")
|
| 135 |
-
else:
|
| 136 |
-
if len(color) != 3:
|
| 137 |
-
raise ValueError(
|
| 138 |
-
"[Error] color should be a list or tuple with length 3."
|
| 139 |
-
)
|
| 140 |
-
|
| 141 |
-
# Make a copy of the line_map
|
| 142 |
-
line_map_tmp = copy.copy(line_map)
|
| 143 |
-
|
| 144 |
-
# Parse line_map back to segment pairs
|
| 145 |
-
segments = np.zeros([0, 4])
|
| 146 |
-
for idx in range(junctions.shape[0]):
|
| 147 |
-
# if no connectivity, just skip it
|
| 148 |
-
if line_map_tmp[idx, :].sum() == 0:
|
| 149 |
-
continue
|
| 150 |
-
# record the line segment
|
| 151 |
-
else:
|
| 152 |
-
for idx2 in np.where(line_map_tmp[idx, :] == 1)[0]:
|
| 153 |
-
p1 = np.flip(junctions[idx, :]) # Convert to xy format
|
| 154 |
-
p2 = np.flip(junctions[idx2, :]) # Convert to xy format
|
| 155 |
-
segments = np.concatenate(
|
| 156 |
-
(segments, np.array([p1[0], p1[1], p2[0], p2[1]])[None, ...]),
|
| 157 |
-
axis=0,
|
| 158 |
-
)
|
| 159 |
-
|
| 160 |
-
# Update line_map
|
| 161 |
-
line_map_tmp[idx, idx2] = 0
|
| 162 |
-
line_map_tmp[idx2, idx] = 0
|
| 163 |
-
|
| 164 |
-
# Draw segment pairs
|
| 165 |
-
for idx in range(segments.shape[0]):
|
| 166 |
-
seg = np.round(segments[idx, :]).astype(np.int)
|
| 167 |
-
# Decide the color
|
| 168 |
-
if color != "random":
|
| 169 |
-
color = tuple(color)
|
| 170 |
-
else:
|
| 171 |
-
color = tuple(
|
| 172 |
-
np.random.rand(
|
| 173 |
-
3,
|
| 174 |
-
)
|
| 175 |
-
)
|
| 176 |
-
cv2.line(
|
| 177 |
-
image, tuple(seg[:2]), tuple(seg[2:]), color=color, thickness=line_width
|
| 178 |
-
)
|
| 179 |
-
|
| 180 |
-
# Also draw the junctions
|
| 181 |
-
if not plot_survived_junc:
|
| 182 |
-
num_junc = junctions.shape[0]
|
| 183 |
-
for idx in range(num_junc):
|
| 184 |
-
# Fetch one junction
|
| 185 |
-
junc = junctions[idx, :]
|
| 186 |
-
cv2.circle(
|
| 187 |
-
image,
|
| 188 |
-
tuple(np.flip(junc)),
|
| 189 |
-
radius=junc_size,
|
| 190 |
-
color=(0, 255.0, 0),
|
| 191 |
-
thickness=3,
|
| 192 |
-
)
|
| 193 |
-
# Only plot the junctions which are part of a line segment
|
| 194 |
-
else:
|
| 195 |
-
for idx in range(segments.shape[0]):
|
| 196 |
-
seg = np.round(segments[idx, :]).astype(np.int) # Already in HW format.
|
| 197 |
-
cv2.circle(
|
| 198 |
-
image,
|
| 199 |
-
tuple(seg[:2]),
|
| 200 |
-
radius=junc_size,
|
| 201 |
-
color=(0, 255.0, 0),
|
| 202 |
-
thickness=3,
|
| 203 |
-
)
|
| 204 |
-
cv2.circle(
|
| 205 |
-
image,
|
| 206 |
-
tuple(seg[2:]),
|
| 207 |
-
radius=junc_size,
|
| 208 |
-
color=(0, 255.0, 0),
|
| 209 |
-
thickness=3,
|
| 210 |
-
)
|
| 211 |
-
|
| 212 |
-
return image
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
# Plot line segments given Nx4 or Nx2x2 line segments
|
| 216 |
-
def plot_line_segments_from_segments(
|
| 217 |
-
input_image, line_segments, junc_size=3, color=(0, 255.0, 0), line_width=1
|
| 218 |
-
):
|
| 219 |
-
# Create image copy
|
| 220 |
-
image = copy.copy(input_image)
|
| 221 |
-
# Make sure the image is converted to 255 uint8
|
| 222 |
-
if image.dtype == np.uint8:
|
| 223 |
-
pass
|
| 224 |
-
# A float type image ranging from 0~1
|
| 225 |
-
elif image.dtype in [np.float32, np.float64, np.float] and image.max() <= 2.0:
|
| 226 |
-
image = (image * 255.0).astype(np.uint8)
|
| 227 |
-
# A float type image ranging from 0.~255.
|
| 228 |
-
elif image.dtype in [np.float32, np.float64, np.float] and image.mean() > 10.0:
|
| 229 |
-
image = image.astype(np.uint8)
|
| 230 |
-
else:
|
| 231 |
-
raise ValueError(
|
| 232 |
-
"[Error] Unknown image data type. Expect 0~1 float or 0~255 uint8."
|
| 233 |
-
)
|
| 234 |
-
|
| 235 |
-
# Check whether the image is single channel
|
| 236 |
-
if len(image.shape) == 2 or ((len(image.shape) == 3) and (image.shape[-1] == 1)):
|
| 237 |
-
# Squeeze to H*W first
|
| 238 |
-
image = image.squeeze()
|
| 239 |
-
|
| 240 |
-
# Stack to channle 3
|
| 241 |
-
image = np.concatenate([image[..., None] for _ in range(3)], axis=-1)
|
| 242 |
-
|
| 243 |
-
# Check the if line_segments are in (1) Nx4, or (2) Nx2x2.
|
| 244 |
-
H, W, _ = image.shape
|
| 245 |
-
# (1) Nx4 format
|
| 246 |
-
if len(line_segments.shape) == 2 and line_segments.shape[-1] == 4:
|
| 247 |
-
# Round to int32
|
| 248 |
-
line_segments = line_segments.astype(np.int32)
|
| 249 |
-
|
| 250 |
-
# Clip H dimension
|
| 251 |
-
line_segments[:, 0] = np.clip(line_segments[:, 0], a_min=0, a_max=H - 1)
|
| 252 |
-
line_segments[:, 2] = np.clip(line_segments[:, 2], a_min=0, a_max=H - 1)
|
| 253 |
-
|
| 254 |
-
# Clip W dimension
|
| 255 |
-
line_segments[:, 1] = np.clip(line_segments[:, 1], a_min=0, a_max=W - 1)
|
| 256 |
-
line_segments[:, 3] = np.clip(line_segments[:, 3], a_min=0, a_max=W - 1)
|
| 257 |
-
|
| 258 |
-
# Convert to Nx2x2 format
|
| 259 |
-
line_segments = np.concatenate(
|
| 260 |
-
[
|
| 261 |
-
np.expand_dims(line_segments[:, :2], axis=1),
|
| 262 |
-
np.expand_dims(line_segments[:, 2:], axis=1),
|
| 263 |
-
],
|
| 264 |
-
axis=1,
|
| 265 |
-
)
|
| 266 |
-
|
| 267 |
-
# (2) Nx2x2 format
|
| 268 |
-
elif len(line_segments.shape) == 3 and line_segments.shape[-1] == 2:
|
| 269 |
-
# Round to int32
|
| 270 |
-
line_segments = line_segments.astype(np.int32)
|
| 271 |
-
|
| 272 |
-
# Clip H dimension
|
| 273 |
-
line_segments[:, :, 0] = np.clip(line_segments[:, :, 0], a_min=0, a_max=H - 1)
|
| 274 |
-
line_segments[:, :, 1] = np.clip(line_segments[:, :, 1], a_min=0, a_max=W - 1)
|
| 275 |
-
|
| 276 |
-
else:
|
| 277 |
-
raise ValueError(
|
| 278 |
-
"[Error] line_segments should be either Nx4 or Nx2x2 in HW format."
|
| 279 |
-
)
|
| 280 |
-
|
| 281 |
-
# Draw segment pairs (all segments should be in HW format)
|
| 282 |
-
image = image.copy()
|
| 283 |
-
for idx in range(line_segments.shape[0]):
|
| 284 |
-
seg = np.round(line_segments[idx, :, :]).astype(np.int32)
|
| 285 |
-
# Decide the color
|
| 286 |
-
if color != "random":
|
| 287 |
-
color = tuple(color)
|
| 288 |
-
else:
|
| 289 |
-
color = tuple(
|
| 290 |
-
np.random.rand(
|
| 291 |
-
3,
|
| 292 |
-
)
|
| 293 |
-
)
|
| 294 |
-
cv2.line(
|
| 295 |
-
image,
|
| 296 |
-
tuple(np.flip(seg[0, :])),
|
| 297 |
-
tuple(np.flip(seg[1, :])),
|
| 298 |
-
color=color,
|
| 299 |
-
thickness=line_width,
|
| 300 |
-
)
|
| 301 |
-
|
| 302 |
-
# Also draw the junctions
|
| 303 |
-
cv2.circle(
|
| 304 |
-
image,
|
| 305 |
-
tuple(np.flip(seg[0, :])),
|
| 306 |
-
radius=junc_size,
|
| 307 |
-
color=(0, 255.0, 0),
|
| 308 |
-
thickness=3,
|
| 309 |
-
)
|
| 310 |
-
cv2.circle(
|
| 311 |
-
image,
|
| 312 |
-
tuple(np.flip(seg[1, :])),
|
| 313 |
-
radius=junc_size,
|
| 314 |
-
color=(0, 255.0, 0),
|
| 315 |
-
thickness=3,
|
| 316 |
-
)
|
| 317 |
-
|
| 318 |
-
return image
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
# Additional functions to visualize multiple images at the same time,
|
| 322 |
-
# e.g. for line matching
|
| 323 |
-
def plot_images(imgs, titles=None, cmaps="gray", dpi=100, size=5, pad=0.5):
|
| 324 |
-
"""Plot a set of images horizontally.
|
| 325 |
-
Args:
|
| 326 |
-
imgs: a list of NumPy or PyTorch images, RGB (H, W, 3) or mono (H, W).
|
| 327 |
-
titles: a list of strings, as titles for each image.
|
| 328 |
-
cmaps: colormaps for monochrome images.
|
| 329 |
-
"""
|
| 330 |
-
n = len(imgs)
|
| 331 |
-
if not isinstance(cmaps, (list, tuple)):
|
| 332 |
-
cmaps = [cmaps] * n
|
| 333 |
-
# figsize = (size*n, size*3/4) if size is not None else None
|
| 334 |
-
figsize = (size * n, size * 6 / 5) if size is not None else None
|
| 335 |
-
fig, ax = plt.subplots(1, n, figsize=figsize, dpi=dpi)
|
| 336 |
-
|
| 337 |
-
if n == 1:
|
| 338 |
-
ax = [ax]
|
| 339 |
-
for i in range(n):
|
| 340 |
-
ax[i].imshow(imgs[i], cmap=plt.get_cmap(cmaps[i]))
|
| 341 |
-
ax[i].get_yaxis().set_ticks([])
|
| 342 |
-
ax[i].get_xaxis().set_ticks([])
|
| 343 |
-
ax[i].set_axis_off()
|
| 344 |
-
for spine in ax[i].spines.values(): # remove frame
|
| 345 |
-
spine.set_visible(False)
|
| 346 |
-
if titles:
|
| 347 |
-
ax[i].set_title(titles[i])
|
| 348 |
-
fig.tight_layout(pad=pad)
|
| 349 |
-
return fig
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
def plot_keypoints(kpts, colors="lime", ps=4):
|
| 353 |
-
"""Plot keypoints for existing images.
|
| 354 |
-
Args:
|
| 355 |
-
kpts: list of ndarrays of size (N, 2).
|
| 356 |
-
colors: string, or list of list of tuples (one for each keypoints).
|
| 357 |
-
ps: size of the keypoints as float.
|
| 358 |
-
"""
|
| 359 |
-
if not isinstance(colors, list):
|
| 360 |
-
colors = [colors] * len(kpts)
|
| 361 |
-
axes = plt.gcf().axes
|
| 362 |
-
for a, k, c in zip(axes, kpts, colors):
|
| 363 |
-
a.scatter(k[:, 0], k[:, 1], c=c, s=ps, linewidths=0)
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, indices=(0, 1), a=1.0):
|
| 367 |
-
"""Plot matches for a pair of existing images.
|
| 368 |
-
Args:
|
| 369 |
-
kpts0, kpts1: corresponding keypoints of size (N, 2).
|
| 370 |
-
color: color of each match, string or RGB tuple. Random if not given.
|
| 371 |
-
lw: width of the lines.
|
| 372 |
-
ps: size of the end points (no endpoint if ps=0)
|
| 373 |
-
indices: indices of the images to draw the matches on.
|
| 374 |
-
a: alpha opacity of the match lines.
|
| 375 |
-
"""
|
| 376 |
-
fig = plt.gcf()
|
| 377 |
-
ax = fig.axes
|
| 378 |
-
assert len(ax) > max(indices)
|
| 379 |
-
ax0, ax1 = ax[indices[0]], ax[indices[1]]
|
| 380 |
-
fig.canvas.draw()
|
| 381 |
-
|
| 382 |
-
assert len(kpts0) == len(kpts1)
|
| 383 |
-
if color is None:
|
| 384 |
-
color = matplotlib.cm.hsv(np.random.rand(len(kpts0))).tolist()
|
| 385 |
-
elif len(color) > 0 and not isinstance(color[0], (tuple, list)):
|
| 386 |
-
color = [color] * len(kpts0)
|
| 387 |
-
|
| 388 |
-
if lw > 0:
|
| 389 |
-
# transform the points into the figure coordinate system
|
| 390 |
-
transFigure = fig.transFigure.inverted()
|
| 391 |
-
fkpts0 = transFigure.transform(ax0.transData.transform(kpts0))
|
| 392 |
-
fkpts1 = transFigure.transform(ax1.transData.transform(kpts1))
|
| 393 |
-
fig.lines += [
|
| 394 |
-
matplotlib.lines.Line2D(
|
| 395 |
-
(fkpts0[i, 0], fkpts1[i, 0]),
|
| 396 |
-
(fkpts0[i, 1], fkpts1[i, 1]),
|
| 397 |
-
zorder=1,
|
| 398 |
-
transform=fig.transFigure,
|
| 399 |
-
c=color[i],
|
| 400 |
-
linewidth=lw,
|
| 401 |
-
alpha=a,
|
| 402 |
-
)
|
| 403 |
-
for i in range(len(kpts0))
|
| 404 |
-
]
|
| 405 |
-
|
| 406 |
-
# freeze the axes to prevent the transform to change
|
| 407 |
-
ax0.autoscale(enable=False)
|
| 408 |
-
ax1.autoscale(enable=False)
|
| 409 |
-
|
| 410 |
-
if ps > 0:
|
| 411 |
-
ax0.scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps, zorder=2)
|
| 412 |
-
ax1.scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps, zorder=2)
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
def plot_lines(
|
| 416 |
-
lines, line_colors="orange", point_colors="cyan", ps=4, lw=2, indices=(0, 1)
|
| 417 |
-
):
|
| 418 |
-
"""Plot lines and endpoints for existing images.
|
| 419 |
-
Args:
|
| 420 |
-
lines: list of ndarrays of size (N, 2, 2).
|
| 421 |
-
colors: string, or list of list of tuples (one for each keypoints).
|
| 422 |
-
ps: size of the keypoints as float pixels.
|
| 423 |
-
lw: line width as float pixels.
|
| 424 |
-
indices: indices of the images to draw the matches on.
|
| 425 |
-
"""
|
| 426 |
-
if not isinstance(line_colors, list):
|
| 427 |
-
line_colors = [line_colors] * len(lines)
|
| 428 |
-
if not isinstance(point_colors, list):
|
| 429 |
-
point_colors = [point_colors] * len(lines)
|
| 430 |
-
|
| 431 |
-
fig = plt.gcf()
|
| 432 |
-
ax = fig.axes
|
| 433 |
-
assert len(ax) > max(indices)
|
| 434 |
-
axes = [ax[i] for i in indices]
|
| 435 |
-
fig.canvas.draw()
|
| 436 |
-
|
| 437 |
-
# Plot the lines and junctions
|
| 438 |
-
for a, l, lc, pc in zip(axes, lines, line_colors, point_colors):
|
| 439 |
-
for i in range(len(l)):
|
| 440 |
-
line = matplotlib.lines.Line2D(
|
| 441 |
-
(l[i, 0, 0], l[i, 1, 0]),
|
| 442 |
-
(l[i, 0, 1], l[i, 1, 1]),
|
| 443 |
-
zorder=1,
|
| 444 |
-
c=lc,
|
| 445 |
-
linewidth=lw,
|
| 446 |
-
)
|
| 447 |
-
a.add_line(line)
|
| 448 |
-
pts = l.reshape(-1, 2)
|
| 449 |
-
a.scatter(pts[:, 0], pts[:, 1], c=pc, s=ps, linewidths=0, zorder=2)
|
| 450 |
-
|
| 451 |
-
return fig
|
| 452 |
-
|
| 453 |
-
|
| 454 |
-
def plot_line_matches(kpts0, kpts1, color=None, lw=1.5, indices=(0, 1), a=1.0):
|
| 455 |
-
"""Plot matches for a pair of existing images, parametrized by their middle point.
|
| 456 |
-
Args:
|
| 457 |
-
kpts0, kpts1: corresponding middle points of the lines of size (N, 2).
|
| 458 |
-
color: color of each match, string or RGB tuple. Random if not given.
|
| 459 |
-
lw: width of the lines.
|
| 460 |
-
indices: indices of the images to draw the matches on.
|
| 461 |
-
a: alpha opacity of the match lines.
|
| 462 |
-
"""
|
| 463 |
-
fig = plt.gcf()
|
| 464 |
-
ax = fig.axes
|
| 465 |
-
assert len(ax) > max(indices)
|
| 466 |
-
ax0, ax1 = ax[indices[0]], ax[indices[1]]
|
| 467 |
-
fig.canvas.draw()
|
| 468 |
-
|
| 469 |
-
assert len(kpts0) == len(kpts1)
|
| 470 |
-
if color is None:
|
| 471 |
-
color = matplotlib.cm.hsv(np.random.rand(len(kpts0))).tolist()
|
| 472 |
-
elif len(color) > 0 and not isinstance(color[0], (tuple, list)):
|
| 473 |
-
color = [color] * len(kpts0)
|
| 474 |
-
|
| 475 |
-
if lw > 0:
|
| 476 |
-
# transform the points into the figure coordinate system
|
| 477 |
-
transFigure = fig.transFigure.inverted()
|
| 478 |
-
fkpts0 = transFigure.transform(ax0.transData.transform(kpts0))
|
| 479 |
-
fkpts1 = transFigure.transform(ax1.transData.transform(kpts1))
|
| 480 |
-
fig.lines += [
|
| 481 |
-
matplotlib.lines.Line2D(
|
| 482 |
-
(fkpts0[i, 0], fkpts1[i, 0]),
|
| 483 |
-
(fkpts0[i, 1], fkpts1[i, 1]),
|
| 484 |
-
zorder=1,
|
| 485 |
-
transform=fig.transFigure,
|
| 486 |
-
c=color[i],
|
| 487 |
-
linewidth=lw,
|
| 488 |
-
alpha=a,
|
| 489 |
-
)
|
| 490 |
-
for i in range(len(kpts0))
|
| 491 |
-
]
|
| 492 |
-
|
| 493 |
-
# freeze the axes to prevent the transform to change
|
| 494 |
-
ax0.autoscale(enable=False)
|
| 495 |
-
ax1.autoscale(enable=False)
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
def plot_color_line_matches(lines, correct_matches=None, lw=2, indices=(0, 1)):
|
| 499 |
-
"""Plot line matches for existing images with multiple colors.
|
| 500 |
-
Args:
|
| 501 |
-
lines: list of ndarrays of size (N, 2, 2).
|
| 502 |
-
correct_matches: bool array of size (N,) indicating correct matches.
|
| 503 |
-
lw: line width as float pixels.
|
| 504 |
-
indices: indices of the images to draw the matches on.
|
| 505 |
-
"""
|
| 506 |
-
n_lines = len(lines[0])
|
| 507 |
-
colors = sns.color_palette("husl", n_colors=n_lines)
|
| 508 |
-
np.random.shuffle(colors)
|
| 509 |
-
alphas = np.ones(n_lines)
|
| 510 |
-
# If correct_matches is not None, display wrong matches with a low alpha
|
| 511 |
-
if correct_matches is not None:
|
| 512 |
-
alphas[~np.array(correct_matches)] = 0.2
|
| 513 |
-
|
| 514 |
-
fig = plt.gcf()
|
| 515 |
-
ax = fig.axes
|
| 516 |
-
assert len(ax) > max(indices)
|
| 517 |
-
axes = [ax[i] for i in indices]
|
| 518 |
-
fig.canvas.draw()
|
| 519 |
-
|
| 520 |
-
# Plot the lines
|
| 521 |
-
for a, l in zip(axes, lines):
|
| 522 |
-
# Transform the points into the figure coordinate system
|
| 523 |
-
transFigure = fig.transFigure.inverted()
|
| 524 |
-
endpoint0 = transFigure.transform(a.transData.transform(l[:, 0]))
|
| 525 |
-
endpoint1 = transFigure.transform(a.transData.transform(l[:, 1]))
|
| 526 |
-
fig.lines += [
|
| 527 |
-
matplotlib.lines.Line2D(
|
| 528 |
-
(endpoint0[i, 0], endpoint1[i, 0]),
|
| 529 |
-
(endpoint0[i, 1], endpoint1[i, 1]),
|
| 530 |
-
zorder=1,
|
| 531 |
-
transform=fig.transFigure,
|
| 532 |
-
c=colors[i],
|
| 533 |
-
alpha=alphas[i],
|
| 534 |
-
linewidth=lw,
|
| 535 |
-
)
|
| 536 |
-
for i in range(n_lines)
|
| 537 |
-
]
|
| 538 |
-
|
| 539 |
-
return fig
|
| 540 |
-
|
| 541 |
-
|
| 542 |
-
def plot_color_lines(lines, correct_matches, wrong_matches, lw=2, indices=(0, 1)):
|
| 543 |
-
"""Plot line matches for existing images with multiple colors:
|
| 544 |
-
green for correct matches, red for wrong ones, and blue for the rest.
|
| 545 |
-
Args:
|
| 546 |
-
lines: list of ndarrays of size (N, 2, 2).
|
| 547 |
-
correct_matches: list of bool arrays of size N with correct matches.
|
| 548 |
-
wrong_matches: list of bool arrays of size (N,) with correct matches.
|
| 549 |
-
lw: line width as float pixels.
|
| 550 |
-
indices: indices of the images to draw the matches on.
|
| 551 |
-
"""
|
| 552 |
-
# palette = sns.color_palette()
|
| 553 |
-
palette = sns.color_palette("hls", 8)
|
| 554 |
-
blue = palette[5] # palette[0]
|
| 555 |
-
red = palette[0] # palette[3]
|
| 556 |
-
green = palette[2] # palette[2]
|
| 557 |
-
colors = [np.array([blue] * len(l)) for l in lines]
|
| 558 |
-
for i, c in enumerate(colors):
|
| 559 |
-
c[np.array(correct_matches[i])] = green
|
| 560 |
-
c[np.array(wrong_matches[i])] = red
|
| 561 |
-
|
| 562 |
-
fig = plt.gcf()
|
| 563 |
-
ax = fig.axes
|
| 564 |
-
assert len(ax) > max(indices)
|
| 565 |
-
axes = [ax[i] for i in indices]
|
| 566 |
-
fig.canvas.draw()
|
| 567 |
-
|
| 568 |
-
# Plot the lines
|
| 569 |
-
for a, l, c in zip(axes, lines, colors):
|
| 570 |
-
# Transform the points into the figure coordinate system
|
| 571 |
-
transFigure = fig.transFigure.inverted()
|
| 572 |
-
endpoint0 = transFigure.transform(a.transData.transform(l[:, 0]))
|
| 573 |
-
endpoint1 = transFigure.transform(a.transData.transform(l[:, 1]))
|
| 574 |
-
fig.lines += [
|
| 575 |
-
matplotlib.lines.Line2D(
|
| 576 |
-
(endpoint0[i, 0], endpoint1[i, 0]),
|
| 577 |
-
(endpoint0[i, 1], endpoint1[i, 1]),
|
| 578 |
-
zorder=1,
|
| 579 |
-
transform=fig.transFigure,
|
| 580 |
-
c=c[i],
|
| 581 |
-
linewidth=lw,
|
| 582 |
-
)
|
| 583 |
-
for i in range(len(l))
|
| 584 |
-
]
|
| 585 |
-
|
| 586 |
-
|
| 587 |
-
def plot_subsegment_matches(lines, subsegments, lw=2, indices=(0, 1)):
|
| 588 |
-
"""Plot line matches for existing images with multiple colors and
|
| 589 |
-
highlight the actually matched subsegments.
|
| 590 |
-
Args:
|
| 591 |
-
lines: list of ndarrays of size (N, 2, 2).
|
| 592 |
-
subsegments: list of ndarrays of size (N, 2, 2).
|
| 593 |
-
lw: line width as float pixels.
|
| 594 |
-
indices: indices of the images to draw the matches on.
|
| 595 |
-
"""
|
| 596 |
-
n_lines = len(lines[0])
|
| 597 |
-
colors = sns.cubehelix_palette(
|
| 598 |
-
start=2, rot=-0.2, dark=0.3, light=0.7, gamma=1.3, hue=1, n_colors=n_lines
|
| 599 |
-
)
|
| 600 |
-
|
| 601 |
-
fig = plt.gcf()
|
| 602 |
-
ax = fig.axes
|
| 603 |
-
assert len(ax) > max(indices)
|
| 604 |
-
axes = [ax[i] for i in indices]
|
| 605 |
-
fig.canvas.draw()
|
| 606 |
-
|
| 607 |
-
# Plot the lines
|
| 608 |
-
for a, l, ss in zip(axes, lines, subsegments):
|
| 609 |
-
# Transform the points into the figure coordinate system
|
| 610 |
-
transFigure = fig.transFigure.inverted()
|
| 611 |
-
|
| 612 |
-
# Draw full line
|
| 613 |
-
endpoint0 = transFigure.transform(a.transData.transform(l[:, 0]))
|
| 614 |
-
endpoint1 = transFigure.transform(a.transData.transform(l[:, 1]))
|
| 615 |
-
fig.lines += [
|
| 616 |
-
matplotlib.lines.Line2D(
|
| 617 |
-
(endpoint0[i, 0], endpoint1[i, 0]),
|
| 618 |
-
(endpoint0[i, 1], endpoint1[i, 1]),
|
| 619 |
-
zorder=1,
|
| 620 |
-
transform=fig.transFigure,
|
| 621 |
-
c="red",
|
| 622 |
-
alpha=0.7,
|
| 623 |
-
linewidth=lw,
|
| 624 |
-
)
|
| 625 |
-
for i in range(n_lines)
|
| 626 |
-
]
|
| 627 |
-
|
| 628 |
-
# Draw matched subsegment
|
| 629 |
-
endpoint0 = transFigure.transform(a.transData.transform(ss[:, 0]))
|
| 630 |
-
endpoint1 = transFigure.transform(a.transData.transform(ss[:, 1]))
|
| 631 |
-
fig.lines += [
|
| 632 |
-
matplotlib.lines.Line2D(
|
| 633 |
-
(endpoint0[i, 0], endpoint1[i, 0]),
|
| 634 |
-
(endpoint0[i, 1], endpoint1[i, 1]),
|
| 635 |
-
zorder=1,
|
| 636 |
-
transform=fig.transFigure,
|
| 637 |
-
c=colors[i],
|
| 638 |
-
alpha=1,
|
| 639 |
-
linewidth=lw,
|
| 640 |
-
)
|
| 641 |
-
for i in range(n_lines)
|
| 642 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
common/{plotting.py → viz.py}
RENAMED
|
@@ -6,6 +6,7 @@ import matplotlib.cm as cm
|
|
| 6 |
from PIL import Image
|
| 7 |
import torch.nn.functional as F
|
| 8 |
import torch
|
|
|
|
| 9 |
|
| 10 |
|
| 11 |
def _compute_conf_thresh(data):
|
|
@@ -19,7 +20,77 @@ def _compute_conf_thresh(data):
|
|
| 19 |
return thr
|
| 20 |
|
| 21 |
|
| 22 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
|
| 25 |
def make_matching_figure(
|
|
@@ -57,7 +128,7 @@ def make_matching_figure(
|
|
| 57 |
axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c="w", s=5)
|
| 58 |
|
| 59 |
# draw matches
|
| 60 |
-
if mkpts0.shape[0]
|
| 61 |
fig.canvas.draw()
|
| 62 |
transFigure = fig.transFigure.inverted()
|
| 63 |
fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0))
|
|
@@ -105,8 +176,12 @@ def _make_evaluation_figure(data, b_id, alpha="dynamic"):
|
|
| 105 |
b_mask = data["m_bids"] == b_id
|
| 106 |
conf_thr = _compute_conf_thresh(data)
|
| 107 |
|
| 108 |
-
img0 = (
|
| 109 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
kpts0 = data["mkpts0_f"][b_mask].cpu().numpy()
|
| 111 |
kpts1 = data["mkpts1_f"][b_mask].cpu().numpy()
|
| 112 |
|
|
@@ -131,8 +206,10 @@ def _make_evaluation_figure(data, b_id, alpha="dynamic"):
|
|
| 131 |
|
| 132 |
text = [
|
| 133 |
f"#Matches {len(kpts0)}",
|
| 134 |
-
f"Precision({conf_thr:.2e}) ({100 * precision:.1f}%):
|
| 135 |
-
f"
|
|
|
|
|
|
|
| 136 |
]
|
| 137 |
|
| 138 |
# make the figure
|
|
@@ -188,7 +265,9 @@ def error_colormap(err, thr, alpha=1.0):
|
|
| 188 |
assert alpha <= 1.0 and alpha > 0, f"Invaid alpha value: {alpha}"
|
| 189 |
x = 1 - np.clip(err / (thr * 2), 0, 1)
|
| 190 |
return np.clip(
|
| 191 |
-
np.stack(
|
|
|
|
|
|
|
| 192 |
0,
|
| 193 |
1,
|
| 194 |
)
|
|
@@ -200,9 +279,13 @@ np.random.shuffle(color_map)
|
|
| 200 |
|
| 201 |
|
| 202 |
def draw_topics(
|
| 203 |
-
data,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 204 |
):
|
| 205 |
-
|
| 206 |
topic0, topic1 = data["topic_matrix"]["img0"], data["topic_matrix"]["img1"]
|
| 207 |
hw0_c, hw1_c = data["hw0_c"], data["hw1_c"]
|
| 208 |
hw0_i, hw1_i = data["hw0_i"], data["hw1_i"]
|
|
@@ -237,7 +320,10 @@ def draw_topics(
|
|
| 237 |
dim=-1, keepdim=True
|
| 238 |
) # .float() / (n_topics - 1) #* 255 + 1
|
| 239 |
# topic1[~mask1_nonzero] = -1
|
| 240 |
-
label_img0, label_img1 =
|
|
|
|
|
|
|
|
|
|
| 241 |
for i, k in enumerate(top_topics):
|
| 242 |
label_img0[topic0 == k] = color_map[k]
|
| 243 |
label_img1[topic1 == k] = color_map[k]
|
|
@@ -312,24 +398,30 @@ def draw_topicfm_demo(
|
|
| 312 |
opencv_display=False,
|
| 313 |
opencv_title="",
|
| 314 |
):
|
| 315 |
-
topic_map0, topic_map1 = draw_topics(
|
| 316 |
-
|
| 317 |
-
mask_tm0, mask_tm1 = np.expand_dims(topic_map0 >= 0, axis=-1), np.expand_dims(
|
| 318 |
-
topic_map1 >= 0, axis=-1
|
| 319 |
)
|
| 320 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 321 |
topic_cm0, topic_cm1 = cm.jet(topic_map0 / 99.0), cm.jet(topic_map1 / 99.0)
|
| 322 |
-
topic_cm0 = cv2.cvtColor(
|
| 323 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 324 |
overlay0 = (mask_tm0 * topic_cm0 + (1 - mask_tm0) * img0).astype(np.float32)
|
| 325 |
overlay1 = (mask_tm1 * topic_cm1 + (1 - mask_tm1) * img1).astype(np.float32)
|
| 326 |
|
| 327 |
cv2.addWeighted(overlay0, topic_alpha, img0, 1 - topic_alpha, 0, overlay0)
|
| 328 |
cv2.addWeighted(overlay1, topic_alpha, img1, 1 - topic_alpha, 0, overlay1)
|
| 329 |
|
| 330 |
-
overlay0, overlay1 = (overlay0 * 255).astype(np.uint8), (
|
| 331 |
-
|
| 332 |
-
)
|
| 333 |
|
| 334 |
h0, w0 = img0.shape[:2]
|
| 335 |
h1, w1 = img1.shape[:2]
|
|
@@ -338,7 +430,9 @@ def draw_topicfm_demo(
|
|
| 338 |
out_fig[:h0, :w0] = overlay0
|
| 339 |
if h0 >= h1:
|
| 340 |
start = (h0 - h1) // 2
|
| 341 |
-
out_fig[
|
|
|
|
|
|
|
| 342 |
else:
|
| 343 |
start = (h1 - h0) // 2
|
| 344 |
out_fig[:h0, (w0 + margin) : (w0 + margin + w1)] = overlay1[
|
|
@@ -358,7 +452,8 @@ def draw_topicfm_demo(
|
|
| 358 |
img1[start : start + h0] * 255
|
| 359 |
).astype(np.uint8)
|
| 360 |
|
| 361 |
-
# draw matching lines, this is inspried from
|
|
|
|
| 362 |
mkpts0, mkpts1 = np.round(mkpts0).astype(int), np.round(mkpts1).astype(int)
|
| 363 |
mcolor = (np.array(mcolor[:, [2, 1, 0]]) * 255).astype(int)
|
| 364 |
|
|
|
|
| 6 |
from PIL import Image
|
| 7 |
import torch.nn.functional as F
|
| 8 |
import torch
|
| 9 |
+
import seaborn as sns
|
| 10 |
|
| 11 |
|
| 12 |
def _compute_conf_thresh(data):
|
|
|
|
| 20 |
return thr
|
| 21 |
|
| 22 |
|
| 23 |
+
def plot_images(imgs, titles=None, cmaps="gray", dpi=100, size=5, pad=0.5):
|
| 24 |
+
"""Plot a set of images horizontally.
|
| 25 |
+
Args:
|
| 26 |
+
imgs: a list of NumPy or PyTorch images, RGB (H, W, 3) or mono (H, W).
|
| 27 |
+
titles: a list of strings, as titles for each image.
|
| 28 |
+
cmaps: colormaps for monochrome images.
|
| 29 |
+
"""
|
| 30 |
+
n = len(imgs)
|
| 31 |
+
if not isinstance(cmaps, (list, tuple)):
|
| 32 |
+
cmaps = [cmaps] * n
|
| 33 |
+
# figsize = (size*n, size*3/4) if size is not None else None
|
| 34 |
+
figsize = (size * n, size * 6 / 5) if size is not None else None
|
| 35 |
+
fig, ax = plt.subplots(1, n, figsize=figsize, dpi=dpi)
|
| 36 |
+
|
| 37 |
+
if n == 1:
|
| 38 |
+
ax = [ax]
|
| 39 |
+
for i in range(n):
|
| 40 |
+
ax[i].imshow(imgs[i], cmap=plt.get_cmap(cmaps[i]))
|
| 41 |
+
ax[i].get_yaxis().set_ticks([])
|
| 42 |
+
ax[i].get_xaxis().set_ticks([])
|
| 43 |
+
ax[i].set_axis_off()
|
| 44 |
+
for spine in ax[i].spines.values(): # remove frame
|
| 45 |
+
spine.set_visible(False)
|
| 46 |
+
if titles:
|
| 47 |
+
ax[i].set_title(titles[i])
|
| 48 |
+
fig.tight_layout(pad=pad)
|
| 49 |
+
return fig
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def plot_color_line_matches(lines, correct_matches=None, lw=2, indices=(0, 1)):
|
| 53 |
+
"""Plot line matches for existing images with multiple colors.
|
| 54 |
+
Args:
|
| 55 |
+
lines: list of ndarrays of size (N, 2, 2).
|
| 56 |
+
correct_matches: bool array of size (N,) indicating correct matches.
|
| 57 |
+
lw: line width as float pixels.
|
| 58 |
+
indices: indices of the images to draw the matches on.
|
| 59 |
+
"""
|
| 60 |
+
n_lines = len(lines[0])
|
| 61 |
+
colors = sns.color_palette("husl", n_colors=n_lines)
|
| 62 |
+
np.random.shuffle(colors)
|
| 63 |
+
alphas = np.ones(n_lines)
|
| 64 |
+
# If correct_matches is not None, display wrong matches with a low alpha
|
| 65 |
+
if correct_matches is not None:
|
| 66 |
+
alphas[~np.array(correct_matches)] = 0.2
|
| 67 |
+
|
| 68 |
+
fig = plt.gcf()
|
| 69 |
+
ax = fig.axes
|
| 70 |
+
assert len(ax) > max(indices)
|
| 71 |
+
axes = [ax[i] for i in indices]
|
| 72 |
+
fig.canvas.draw()
|
| 73 |
+
|
| 74 |
+
# Plot the lines
|
| 75 |
+
for a, l in zip(axes, lines):
|
| 76 |
+
# Transform the points into the figure coordinate system
|
| 77 |
+
transFigure = fig.transFigure.inverted()
|
| 78 |
+
endpoint0 = transFigure.transform(a.transData.transform(l[:, 0]))
|
| 79 |
+
endpoint1 = transFigure.transform(a.transData.transform(l[:, 1]))
|
| 80 |
+
fig.lines += [
|
| 81 |
+
matplotlib.lines.Line2D(
|
| 82 |
+
(endpoint0[i, 0], endpoint1[i, 0]),
|
| 83 |
+
(endpoint0[i, 1], endpoint1[i, 1]),
|
| 84 |
+
zorder=1,
|
| 85 |
+
transform=fig.transFigure,
|
| 86 |
+
c=colors[i],
|
| 87 |
+
alpha=alphas[i],
|
| 88 |
+
linewidth=lw,
|
| 89 |
+
)
|
| 90 |
+
for i in range(n_lines)
|
| 91 |
+
]
|
| 92 |
+
|
| 93 |
+
return fig
|
| 94 |
|
| 95 |
|
| 96 |
def make_matching_figure(
|
|
|
|
| 128 |
axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c="w", s=5)
|
| 129 |
|
| 130 |
# draw matches
|
| 131 |
+
if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0:
|
| 132 |
fig.canvas.draw()
|
| 133 |
transFigure = fig.transFigure.inverted()
|
| 134 |
fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0))
|
|
|
|
| 176 |
b_mask = data["m_bids"] == b_id
|
| 177 |
conf_thr = _compute_conf_thresh(data)
|
| 178 |
|
| 179 |
+
img0 = (
|
| 180 |
+
(data["image0"][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
|
| 181 |
+
)
|
| 182 |
+
img1 = (
|
| 183 |
+
(data["image1"][b_id][0].cpu().numpy() * 255).round().astype(np.int32)
|
| 184 |
+
)
|
| 185 |
kpts0 = data["mkpts0_f"][b_mask].cpu().numpy()
|
| 186 |
kpts1 = data["mkpts1_f"][b_mask].cpu().numpy()
|
| 187 |
|
|
|
|
| 206 |
|
| 207 |
text = [
|
| 208 |
f"#Matches {len(kpts0)}",
|
| 209 |
+
f"Precision({conf_thr:.2e}) ({100 * precision:.1f}%):"
|
| 210 |
+
f" {n_correct}/{len(kpts0)}",
|
| 211 |
+
f"Recall({conf_thr:.2e}) ({100 * recall:.1f}%):"
|
| 212 |
+
f" {n_correct}/{n_gt_matches}",
|
| 213 |
]
|
| 214 |
|
| 215 |
# make the figure
|
|
|
|
| 265 |
assert alpha <= 1.0 and alpha > 0, f"Invaid alpha value: {alpha}"
|
| 266 |
x = 1 - np.clip(err / (thr * 2), 0, 1)
|
| 267 |
return np.clip(
|
| 268 |
+
np.stack(
|
| 269 |
+
[2 - x * 2, x * 2, np.zeros_like(x), np.ones_like(x) * alpha], -1
|
| 270 |
+
),
|
| 271 |
0,
|
| 272 |
1,
|
| 273 |
)
|
|
|
|
| 279 |
|
| 280 |
|
| 281 |
def draw_topics(
|
| 282 |
+
data,
|
| 283 |
+
img0,
|
| 284 |
+
img1,
|
| 285 |
+
saved_folder="viz_topics",
|
| 286 |
+
show_n_topics=8,
|
| 287 |
+
saved_name=None,
|
| 288 |
):
|
|
|
|
| 289 |
topic0, topic1 = data["topic_matrix"]["img0"], data["topic_matrix"]["img1"]
|
| 290 |
hw0_c, hw1_c = data["hw0_c"], data["hw1_c"]
|
| 291 |
hw0_i, hw1_i = data["hw0_i"], data["hw1_i"]
|
|
|
|
| 320 |
dim=-1, keepdim=True
|
| 321 |
) # .float() / (n_topics - 1) #* 255 + 1
|
| 322 |
# topic1[~mask1_nonzero] = -1
|
| 323 |
+
label_img0, label_img1 = (
|
| 324 |
+
torch.zeros_like(topic0) - 1,
|
| 325 |
+
torch.zeros_like(topic1) - 1,
|
| 326 |
+
)
|
| 327 |
for i, k in enumerate(top_topics):
|
| 328 |
label_img0[topic0 == k] = color_map[k]
|
| 329 |
label_img1[topic1 == k] = color_map[k]
|
|
|
|
| 398 |
opencv_display=False,
|
| 399 |
opencv_title="",
|
| 400 |
):
|
| 401 |
+
topic_map0, topic_map1 = draw_topics(
|
| 402 |
+
data, img0, img1, show_n_topics=show_n_topics
|
|
|
|
|
|
|
| 403 |
)
|
| 404 |
|
| 405 |
+
mask_tm0, mask_tm1 = np.expand_dims(
|
| 406 |
+
topic_map0 >= 0, axis=-1
|
| 407 |
+
), np.expand_dims(topic_map1 >= 0, axis=-1)
|
| 408 |
+
|
| 409 |
topic_cm0, topic_cm1 = cm.jet(topic_map0 / 99.0), cm.jet(topic_map1 / 99.0)
|
| 410 |
+
topic_cm0 = cv2.cvtColor(
|
| 411 |
+
topic_cm0[..., :3].astype(np.float32), cv2.COLOR_RGB2BGR
|
| 412 |
+
)
|
| 413 |
+
topic_cm1 = cv2.cvtColor(
|
| 414 |
+
topic_cm1[..., :3].astype(np.float32), cv2.COLOR_RGB2BGR
|
| 415 |
+
)
|
| 416 |
overlay0 = (mask_tm0 * topic_cm0 + (1 - mask_tm0) * img0).astype(np.float32)
|
| 417 |
overlay1 = (mask_tm1 * topic_cm1 + (1 - mask_tm1) * img1).astype(np.float32)
|
| 418 |
|
| 419 |
cv2.addWeighted(overlay0, topic_alpha, img0, 1 - topic_alpha, 0, overlay0)
|
| 420 |
cv2.addWeighted(overlay1, topic_alpha, img1, 1 - topic_alpha, 0, overlay1)
|
| 421 |
|
| 422 |
+
overlay0, overlay1 = (overlay0 * 255).astype(np.uint8), (
|
| 423 |
+
overlay1 * 255
|
| 424 |
+
).astype(np.uint8)
|
| 425 |
|
| 426 |
h0, w0 = img0.shape[:2]
|
| 427 |
h1, w1 = img1.shape[:2]
|
|
|
|
| 430 |
out_fig[:h0, :w0] = overlay0
|
| 431 |
if h0 >= h1:
|
| 432 |
start = (h0 - h1) // 2
|
| 433 |
+
out_fig[
|
| 434 |
+
start : (start + h1), (w0 + margin) : (w0 + margin + w1)
|
| 435 |
+
] = overlay1
|
| 436 |
else:
|
| 437 |
start = (h1 - h0) // 2
|
| 438 |
out_fig[:h0, (w0 + margin) : (w0 + margin + w1)] = overlay1[
|
|
|
|
| 452 |
img1[start : start + h0] * 255
|
| 453 |
).astype(np.uint8)
|
| 454 |
|
| 455 |
+
# draw matching lines, this is inspried from
|
| 456 |
+
# https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/master/models/utils.py
|
| 457 |
mkpts0, mkpts1 = np.round(mkpts0).astype(int), np.round(mkpts1).astype(int)
|
| 458 |
mcolor = (np.array(mcolor[:, [2, 1, 0]]) * 255).astype(int)
|
| 459 |
|
style.css
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
h1 {
|
| 2 |
+
text-align: center;
|
| 3 |
+
}
|
| 4 |
+
|
| 5 |
+
#duplicate-button {
|
| 6 |
+
margin: auto;
|
| 7 |
+
color: white;
|
| 8 |
+
background: #1565c0;
|
| 9 |
+
border-radius: 100vh;
|
| 10 |
+
}
|
| 11 |
+
|
| 12 |
+
#component-0 {
|
| 13 |
+
/* max-width: 900px; */
|
| 14 |
+
margin: auto;
|
| 15 |
+
padding-top: 1.5rem;
|
| 16 |
+
}
|
| 17 |
+
|
| 18 |
+
footer {visibility: hidden}
|