# Copyright 2019-present NAVER Corp.
# CC BY-NC-SA 3.0
# Available only for non-commercial use

import numpy as np
import torch
import torch.nn.functional as F  # used by extract_multiscale (F.interpolate)
from PIL import Image

from tools import common
from tools.dataloader import norm_RGB
from nets.patchnet import *  # model classes referenced by name in eval() below

def load_network(model_fn):
    checkpoint = torch.load(model_fn)
    print("\n>> Creating net = " + checkpoint["net"])
    net = eval(checkpoint["net"])  # the checkpoint stores the model constructor as a string
    nb_of_weights = common.model_size(net)
    print(f" ( Model size: {nb_of_weights/1000:.0f}K parameters )")

    # load the pre-trained weights, stripping any DataParallel "module." prefix
    weights = checkpoint["state_dict"]
    net.load_state_dict({k.replace("module.", ""): v for k, v in weights.items()})
    return net.eval()
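
# Usage sketch (the checkpoint path below is an assumption; any released R2D2
# model file should work):
#   net = load_network("models/r2d2_WASF_N16.pt")
#   net = net.cuda()  # optional, if a GPU is available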

class NonMaxSuppression(torch.nn.Module):
    def __init__(self, rel_thr=0.7, rep_thr=0.7):
        super().__init__()
        self.max_filter = torch.nn.MaxPool2d(kernel_size=3, stride=1, padding=1)
        self.rel_thr = rel_thr
        self.rep_thr = rep_thr

    def forward(self, reliability, repeatability, **kw):
        assert len(reliability) == len(repeatability) == 1
        reliability, repeatability = reliability[0], repeatability[0]

        # local maxima: keep a pixel only if it equals the max of its 3x3 neighbourhood
        maxima = repeatability == self.max_filter(repeatability)

        # remove low peaks
        maxima = maxima & (repeatability >= self.rep_thr)
        maxima = maxima & (reliability >= self.rel_thr)

        # nonzero() yields (batch, channel, y, x) indices; keep the (y, x) rows
        return maxima.nonzero().t()[2:4]
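
# The detector relies on the classic max-pool trick: a pixel is a local maximum
# iff a 3x3 max filter leaves it unchanged. A minimal standalone sketch of the
# same idea (illustration only, not used by this script):
#   heat = torch.rand(1, 1, 8, 8)
#   peaks = heat == torch.nn.functional.max_pool2d(heat, 3, stride=1, padding=1)
#   ys, xs = peaks.nonzero().t()[2:4]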

def extract_multiscale(
    net,
    img,
    detector,
    scale_f=2**0.25,
    min_scale=0.0,
    max_scale=1,
    min_size=256,
    max_size=1024,
    verbose=False,
):
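    """Run the network over an image pyramid and pool keypoints across scales.

    The image is shrunk by `scale_f` at each iteration; detections from every
    scale are mapped back to original-image coordinates and concatenated.
    Returns XYS rows of (x, y, scale), the matching descriptors D, and one
    score per keypoint (reliability * repeatability).
    """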
    old_bm = torch.backends.cudnn.benchmark
    torch.backends.cudnn.benchmark = False  # the autotuner is wasted on ever-changing input sizes

    # extract keypoints at multiple scales
    B, three, H, W = img.shape
    assert B == 1 and three == 3, "should be a batch with a single RGB image"

    assert max_scale <= 1
    s = 1.0  # current scale factor

    X, Y, S, C, Q, D = [], [], [], [], [], []
    while s + 0.001 >= max(min_scale, min_size / max(H, W)):
        if s - 0.001 <= min(max_scale, max_size / max(H, W)):
            nh, nw = img.shape[2:]
            if verbose:
                print(f"extracting at scale x{s:.02f} = {nw:4d}x{nh:3d}")
            # extract descriptors
            with torch.no_grad():
                res = net(imgs=[img])

            # get output and reliability map
            descriptors = res["descriptors"][0]
            reliability = res["reliability"][0]
            repeatability = res["repeatability"][0]

            # extract maxima and their descriptors
            y, x = detector(**res)  # nms
            c = reliability[0, 0, y, x]
            q = repeatability[0, 0, y, x]
            d = descriptors[0, :, y, x].t()
            n = d.shape[0]

            # accumulate over multiple scales, mapping coordinates back to the original image
            X.append(x.float() * W / nw)
            Y.append(y.float() * H / nh)
            S.append((32 / s) * torch.ones(n, dtype=torch.float32, device=d.device))
            C.append(c)
            Q.append(q)
            D.append(d)
        s /= scale_f

        # down-scale the image for the next iteration
        nh, nw = round(H * s), round(W * s)
        img = F.interpolate(img, (nh, nw), mode="bilinear", align_corners=False)

    # restore the cudnn setting
    torch.backends.cudnn.benchmark = old_bm

    Y = torch.cat(Y)
    X = torch.cat(X)
    S = torch.cat(S)  # scale
    scores = torch.cat(C) * torch.cat(Q)  # scores = reliability * repeatability
    XYS = torch.stack([X, Y, S], dim=-1)
    D = torch.cat(D)
    return XYS, D, scores
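
# Each row of XYS is (x, y, 32/s): image coordinates plus a size that grows as
# the pyramid shrinks. A hypothetical OpenCV conversion could look like this
# (cv2 is not used or imported by this script):
#   kpts = [cv2.KeyPoint(float(x), float(y), float(s)) for x, y, s in XYS]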

def extract_keypoints(args):
    iscuda = common.torch_set_gpu(args.gpu)

    # load the network...
    net = load_network(args.model)
    if iscuda:
        net = net.cuda()

    # create the non-maxima detector
    detector = NonMaxSuppression(
        rel_thr=args.reliability_thr, rep_thr=args.repeatability_thr
    )

    while args.images:
        img_path = args.images.pop(0)

        if img_path.endswith(".txt"):
            # a .txt argument is a list of image paths, one per line
            args.images = open(img_path).read().splitlines() + args.images
            continue

        print(f"\nExtracting features for {img_path}")
        img = Image.open(img_path).convert("RGB")
        W, H = img.size
        img = norm_RGB(img)[None]
        if iscuda:
            img = img.cuda()

        # extract keypoints/descriptors for a single image
        xys, desc, scores = extract_multiscale(
            net,
            img,
            detector,
            scale_f=args.scale_f,
            min_scale=args.min_scale,
            max_scale=args.max_scale,
            min_size=args.min_size,
            max_size=args.max_size,
            verbose=True,
        )

        xys = xys.cpu().numpy()
        desc = desc.cpu().numpy()
        scores = scores.cpu().numpy()
        # keep the top-k best-scoring keypoints (all of them if top_k == 0)
        idxs = scores.argsort()[-args.top_k or None :]

        outpath = img_path + "." + args.tag
        print(f"Saving {len(idxs)} keypoints to {outpath}")
        np.savez(
            open(outpath, "wb"),
            imsize=(W, H),
            keypoints=xys[idxs],
            descriptors=desc[idxs],
            scores=scores[idxs],
        )
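
# Reading an output archive back (a sketch; the keys mirror the np.savez call
# above):
#   with np.load("img.jpg.r2d2") as f:
#       kpts, descs, scores = f["keypoints"], f["descriptors"], f["scores"]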
| if __name__ == "__main__": | |
| import argparse | |
| parser = argparse.ArgumentParser("Extract keypoints for a given image") | |
| parser.add_argument("--model", type=str, required=True, help="model path") | |
| parser.add_argument( | |
| "--images", type=str, required=True, nargs="+", help="images / list" | |
| ) | |
| parser.add_argument("--tag", type=str, default="r2d2", help="output file tag") | |
| parser.add_argument("--top-k", type=int, default=5000, help="number of keypoints") | |
| parser.add_argument("--scale-f", type=float, default=2**0.25) | |
| parser.add_argument("--min-size", type=int, default=256) | |
| parser.add_argument("--max-size", type=int, default=1024) | |
| parser.add_argument("--min-scale", type=float, default=0) | |
| parser.add_argument("--max-scale", type=float, default=1) | |
| parser.add_argument("--reliability-thr", type=float, default=0.7) | |
| parser.add_argument("--repeatability-thr", type=float, default=0.7) | |
| parser.add_argument( | |
| "--gpu", type=int, nargs="+", default=[0], help="use -1 for CPU" | |
| ) | |
| args = parser.parse_args() | |
| extract_keypoints(args) | |
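
# Example invocation (a sketch; the checkpoint path is an assumption, any
# released R2D2 model file would do):
#   python extract.py --model models/r2d2_WASF_N16.pt \
#       --images pic1.jpg pic2.jpg --top-k 5000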