Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| from dkm import * | |
| from .local_corr import LocalCorr | |
| from .corr_channels import NormedCorr | |
| from torchvision.models import resnet as tv_resnet | |
# Registry of released matcher checkpoints.
# Outer key: model family; inner key: the variant name accepted by the
# corresponding builder's ``version`` argument; value: download URL.
dkm_pretrained_urls = dict(
    DKM=dict(
        mega_synthetic="https://github.com/Parskatt/storage/releases/download/dkm_mega_synthetic/dkm_mega_synthetic.pth",
        mega="https://github.com/Parskatt/storage/releases/download/dkm_mega/dkm_mega.pth",
    ),
    DKMv2=dict(
        outdoor="https://github.com/Parskatt/storage/releases/download/dkmv2/dkm_v2_outdoor.pth",
        indoor="https://github.com/Parskatt/storage/releases/download/dkmv2/dkm_v2_indoor.pth",
    ),
)
def DKM(pretrained=True, version="mega_synthetic", device=None):
    """Build the DKM matcher and optionally load its released weights.

    Args:
        pretrained: When true, the checkpoint registered under
            ``dkm_pretrained_urls["DKM"][version]`` is downloaded and loaded
            into the matcher. When false, no matcher checkpoint is loaded and
            the ResNet-50 backbone is created with ``pretrained=True``
            instead.
        version: Key into ``dkm_pretrained_urls["DKM"]``; only read when
            ``pretrained`` is true.
        device: Device the matcher is moved to. ``None`` selects CUDA when
            ``torch.cuda.is_available()`` and CPU otherwise.

    Returns:
        The ``RegressionMatcher`` (constructed with ``h=384, w=512``) on
        ``device``.

    Raises:
        KeyError: If ``pretrained`` is true and ``version`` is not registered.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Resolve the checkpoint first so an unknown version fails before the
    # (comparatively expensive) model construction.
    weights_url = dkm_pretrained_urls["DKM"][version] if pretrained else None

    gp_dim = 256
    dfn_dim = 384
    feat_dim = 256
    coarse_scales = ("32", "16")

    # NOTE: modules are created in the same order as the hand-written original
    # (and each dict keeps its key order) so that random initialisation and
    # the state_dict layout are unchanged.
    coordinate_decoder = DFN(
        internal_dim=dfn_dim,
        feat_input_modules=nn.ModuleDict(
            {scale: nn.Conv2d(512, feat_dim, 1, 1) for scale in coarse_scales}
        ),
        pred_input_modules=nn.ModuleDict(
            {scale: nn.Identity() for scale in coarse_scales}
        ),
        rrb_d_dict=nn.ModuleDict(
            {scale: RRB(gp_dim + feat_dim, dfn_dim) for scale in coarse_scales}
        ),
        cab_dict=nn.ModuleDict(
            {scale: CAB(2 * dfn_dim, dfn_dim) for scale in coarse_scales}
        ),
        rrb_u_dict=nn.ModuleDict(
            {scale: RRB(dfn_dim, dfn_dim) for scale in coarse_scales}
        ),
        terminal_module=nn.ModuleDict(
            {scale: nn.Conv2d(dfn_dim, 3, 1, 1, 0) for scale in coarse_scales}
        ),
    )

    dw = True
    hidden_blocks = 8
    kernel_size = 5
    # scale -> (first, second) positional argument of ConvRefiner.
    refiner_dims = {
        "16": (2 * 512, 1024),
        "8": (2 * 512, 1024),
        "4": (2 * 256, 512),
        "2": (2 * 64, 128),
        "1": (2 * 3, 24),
    }
    conv_refiner = nn.ModuleDict(
        {
            scale: ConvRefiner(
                in_dim,
                hidden_dim,
                3,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
            )
            for scale, (in_dim, hidden_dim) in refiner_dims.items()
        }
    )

    gps = nn.ModuleDict(
        {
            scale: GP(
                CosKernel,
                T=0.2,
                learn_temperature=False,
                only_attention=False,
                gp_dim=gp_dim,
                basis="fourier",
                no_cov=True,
            )
            for scale in coarse_scales
        }
    )
    proj = nn.ModuleDict(
        {"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)}
    )
    decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True)

    h, w = 384, 512
    # Only load pretrained backbone weights if not loading a pretrained matcher.
    encoder = Encoder(tv_resnet.resnet50(pretrained=not pretrained))
    matcher = RegressionMatcher(encoder, decoder, h=h, w=w).to(device)

    if pretrained:
        # Deserialize onto the CPU so checkpoints that contain CUDA tensors
        # also load on CPU-only hosts; load_state_dict copies the values onto
        # whatever device the matcher's parameters live on.
        weights = torch.hub.load_state_dict_from_url(
            weights_url, map_location="cpu"
        )
        matcher.load_state_dict(weights)
    return matcher
def DKMv2(pretrained=True, version="outdoor", resolution="low", device=None, **kwargs):
    """Build the DKMv2 matcher and optionally load its released weights.

    Args:
        pretrained: When true, the checkpoint registered under
            ``dkm_pretrained_urls["DKMv2"][version]`` is loaded into the
            matcher. When false, no matcher checkpoint is loaded and the
            ResNet-50 backbone is created with ``pretrained=True`` instead.
        version: Key into ``dkm_pretrained_urls["DKMv2"]``; only read when
            ``pretrained`` is true.
        resolution: ``"low"`` builds the matcher with ``h, w = 384, 512``;
            ``"high"`` with ``h, w = 480, 640``.
        device: Device the matcher is moved to. ``None`` selects CUDA when
            ``torch.cuda.is_available()`` and CPU otherwise. (Previously the
            function referenced a ``device`` name it never defined.)
        **kwargs: Forwarded unchanged to ``RegressionMatcher``.

    Returns:
        The ``RegressionMatcher`` on ``device``.

    Raises:
        ValueError: If ``resolution`` is neither ``"low"`` nor ``"high"``.
        KeyError: If ``pretrained`` is true and ``version`` is not registered.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    sizes = {"low": (384, 512), "high": (480, 640)}
    if resolution not in sizes:
        # The original fell through with h and w unbound.
        raise ValueError(
            f"resolution must be 'low' or 'high', got {resolution!r}"
        )
    h, w = sizes[resolution]

    # Resolve the checkpoint first so an unknown version fails before the
    # (comparatively expensive) model construction.
    checkpoint = dkm_pretrained_urls["DKMv2"][version] if pretrained else None

    gp_dim = 256
    dfn_dim = 384
    feat_dim = 256
    coarse_scales = ("32", "16")

    # NOTE: modules are created in the same order as the hand-written original
    # (and each dict keeps its key order) so that random initialisation and
    # the state_dict layout are unchanged.
    coordinate_decoder = DFN(
        internal_dim=dfn_dim,
        feat_input_modules=nn.ModuleDict(
            {scale: nn.Conv2d(512, feat_dim, 1, 1) for scale in coarse_scales}
        ),
        pred_input_modules=nn.ModuleDict(
            {scale: nn.Identity() for scale in coarse_scales}
        ),
        rrb_d_dict=nn.ModuleDict(
            {scale: RRB(gp_dim + feat_dim, dfn_dim) for scale in coarse_scales}
        ),
        cab_dict=nn.ModuleDict(
            {scale: CAB(2 * dfn_dim, dfn_dim) for scale in coarse_scales}
        ),
        rrb_u_dict=nn.ModuleDict(
            {scale: RRB(dfn_dim, dfn_dim) for scale in coarse_scales}
        ),
        terminal_module=nn.ModuleDict(
            {scale: nn.Conv2d(dfn_dim, 3, 1, 1, 0) for scale in coarse_scales}
        ),
    )

    dw = True
    hidden_blocks = 8
    kernel_size = 5
    displacement_emb = "linear"
    # scale -> (first positional arg, second positional arg, displacement_emb_dim)
    # of ConvRefiner. Scale "1" keeps a second argument of 24, i.e. the
    # embedding width is not added to it, exactly as in the original.
    refiner_dims = {
        "16": (2 * 512 + 128, 1024 + 128, 128),
        "8": (2 * 512 + 64, 1024 + 64, 64),
        "4": (2 * 256 + 32, 512 + 32, 32),
        "2": (2 * 64 + 16, 128 + 16, 16),
        "1": (2 * 3 + 6, 24, 6),
    }
    conv_refiner = nn.ModuleDict(
        {
            scale: ConvRefiner(
                in_dim,
                hidden_dim,
                3,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=emb_dim,
            )
            for scale, (in_dim, hidden_dim, emb_dim) in refiner_dims.items()
        }
    )

    gps = nn.ModuleDict(
        {
            scale: GP(
                CosKernel,
                T=0.2,
                learn_temperature=False,
                only_attention=False,
                gp_dim=gp_dim,
                basis="fourier",
                no_cov=True,
            )
            for scale in coarse_scales
        }
    )
    proj = nn.ModuleDict(
        {"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)}
    )
    decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True)

    # Only load pretrained backbone weights if not loading a pretrained matcher.
    encoder = Encoder(tv_resnet.resnet50(pretrained=not pretrained))
    matcher = RegressionMatcher(encoder, decoder, h=h, w=w, **kwargs).to(device)

    if pretrained:
        # map_location="cpu" lets checkpoints containing CUDA tensors load on
        # CPU-only hosts; load_state_dict copies onto the matcher's device.
        try:
            weights = torch.hub.load_state_dict_from_url(
                checkpoint, map_location="cpu"
            )
        except Exception:
            # Fallback kept from the original: treat the registered entry as
            # a local file path. Narrowed from a bare ``except`` so that
            # KeyboardInterrupt/SystemExit are no longer swallowed.
            weights = torch.load(checkpoint, map_location="cpu")
        matcher.load_state_dict(weights)
    return matcher
def local_corr(pretrained=True, version="mega_synthetic", device=None):
    """Build the matcher variant whose refiners at scales 16/8/4/2 are ``LocalCorr``.

    Args:
        pretrained: When true, a checkpoint registered under
            ``dkm_pretrained_urls["local_corr"][version]`` is downloaded and
            loaded. When false, no matcher checkpoint is loaded and the
            ResNet-50 backbone is created with ``pretrained=True`` instead.
        version: Key into ``dkm_pretrained_urls["local_corr"]``; only read
            when ``pretrained`` is true.
        device: Device the matcher is moved to. ``None`` selects CUDA when
            ``torch.cuda.is_available()`` and CPU otherwise. (Previously the
            function referenced a ``device`` name it never defined.)

    Returns:
        The ``RegressionMatcher`` (constructed with ``h=384, w=512``) on
        ``device``.

    Raises:
        KeyError: If ``pretrained`` is true and no checkpoint is registered
            for this variant/version. The registry visible in this module has
            no ``"local_corr"`` entry, so pass ``pretrained=False`` unless one
            has been added.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    weights_url = None
    if pretrained:
        # Fail before building the model, and with a message that says what
        # is missing rather than a bare key name.
        try:
            weights_url = dkm_pretrained_urls["local_corr"][version]
        except KeyError:
            raise KeyError(
                f"no pretrained weights registered for 'local_corr' version "
                f"{version!r}; register a URL or pass pretrained=False"
            ) from None

    gp_dim = 256
    dfn_dim = 384
    feat_dim = 256
    coarse_scales = ("32", "16")

    # NOTE: modules are created in the same order as the hand-written original
    # (and each dict keeps its key order) so that random initialisation and
    # the state_dict layout are unchanged.
    coordinate_decoder = DFN(
        internal_dim=dfn_dim,
        feat_input_modules=nn.ModuleDict(
            {scale: nn.Conv2d(512, feat_dim, 1, 1) for scale in coarse_scales}
        ),
        pred_input_modules=nn.ModuleDict(
            {scale: nn.Identity() for scale in coarse_scales}
        ),
        rrb_d_dict=nn.ModuleDict(
            {scale: RRB(gp_dim + feat_dim, dfn_dim) for scale in coarse_scales}
        ),
        cab_dict=nn.ModuleDict(
            {scale: CAB(2 * dfn_dim, dfn_dim) for scale in coarse_scales}
        ),
        rrb_u_dict=nn.ModuleDict(
            {scale: RRB(dfn_dim, dfn_dim) for scale in coarse_scales}
        ),
        terminal_module=nn.ModuleDict(
            {scale: nn.Conv2d(dfn_dim, 3, 1, 1, 0) for scale in coarse_scales}
        ),
    )

    dw = True
    hidden_blocks = 8
    kernel_size = 5
    # scale -> second positional argument of LocalCorr (the first is always 81).
    local_corr_dims = {"16": 81 * 12, "8": 81 * 12, "4": 81 * 6, "2": 81}
    refiners = {
        scale: LocalCorr(
            81,
            hidden_dim,
            3,
            kernel_size=kernel_size,
            dw=dw,
            hidden_blocks=hidden_blocks,
        )
        for scale, hidden_dim in local_corr_dims.items()
    }
    # The finest scale keeps a plain ConvRefiner.
    refiners["1"] = ConvRefiner(
        2 * 3,
        24,
        3,
        kernel_size=kernel_size,
        dw=dw,
        hidden_blocks=hidden_blocks,
    )
    conv_refiner = nn.ModuleDict(refiners)

    gps = nn.ModuleDict(
        {
            scale: GP(
                CosKernel,
                T=0.2,
                learn_temperature=False,
                only_attention=False,
                gp_dim=gp_dim,
                basis="fourier",
                no_cov=True,
            )
            for scale in coarse_scales
        }
    )
    proj = nn.ModuleDict(
        {"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)}
    )
    decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True)

    h, w = 384, 512
    # Only load pretrained backbone weights if not loading a pretrained matcher.
    encoder = Encoder(tv_resnet.resnet50(pretrained=not pretrained))
    matcher = RegressionMatcher(encoder, decoder, h=h, w=w).to(device)

    if pretrained:
        # map_location="cpu" lets checkpoints containing CUDA tensors load on
        # CPU-only hosts; load_state_dict copies onto the matcher's device.
        weights = torch.hub.load_state_dict_from_url(
            weights_url, map_location="cpu"
        )
        matcher.load_state_dict(weights)
    return matcher
def corr_channels(pretrained=True, version="mega_synthetic", device=None):
    """Build the matcher variant that uses ``NormedCorr`` at scales 32 and 16.

    The width fed into the ``RRB`` input blocks is derived from the working
    resolution: ``(h // s) * (w // s)`` channels plus ``feat_dim`` for stride
    ``s`` in {32, 16}.

    Args:
        pretrained: When true, a checkpoint registered under
            ``dkm_pretrained_urls["corr_channels"][version]`` is downloaded
            and loaded. When false, no matcher checkpoint is loaded and the
            ResNet-50 backbone is created with ``pretrained=True`` instead.
        version: Key into ``dkm_pretrained_urls["corr_channels"]``; only read
            when ``pretrained`` is true.
        device: Device the matcher is moved to. ``None`` selects CUDA when
            ``torch.cuda.is_available()`` and CPU otherwise. (Previously the
            function referenced a ``device`` name it never defined.)

    Returns:
        The ``RegressionMatcher`` (constructed with ``h=384, w=512``) on
        ``device``.

    Raises:
        KeyError: If ``pretrained`` is true and no checkpoint is registered
            for this variant/version. The registry visible in this module has
            no ``"corr_channels"`` entry, so pass ``pretrained=False`` unless
            one has been added.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    weights_url = None
    if pretrained:
        # Fail before building the model, and with a message that says what
        # is missing rather than a bare key name.
        try:
            weights_url = dkm_pretrained_urls["corr_channels"][version]
        except KeyError:
            raise KeyError(
                f"no pretrained weights registered for 'corr_channels' version "
                f"{version!r}; register a URL or pass pretrained=False"
            ) from None

    h, w = 384, 512
    # One channel count per coarse scale, tied to the h x w grid at that stride.
    gp_dims = {"32": (h // 32) * (w // 32), "16": (h // 16) * (w // 16)}
    dfn_dim = 384
    feat_dim = 256
    coarse_scales = ("32", "16")

    # NOTE: modules are created in the same order as the hand-written original
    # (and each dict keeps its key order) so that random initialisation and
    # the state_dict layout are unchanged.
    coordinate_decoder = DFN(
        internal_dim=dfn_dim,
        feat_input_modules=nn.ModuleDict(
            {scale: nn.Conv2d(512, feat_dim, 1, 1) for scale in coarse_scales}
        ),
        pred_input_modules=nn.ModuleDict(
            {scale: nn.Identity() for scale in coarse_scales}
        ),
        rrb_d_dict=nn.ModuleDict(
            {
                scale: RRB(gp_dims[scale] + feat_dim, dfn_dim)
                for scale in coarse_scales
            }
        ),
        cab_dict=nn.ModuleDict(
            {scale: CAB(2 * dfn_dim, dfn_dim) for scale in coarse_scales}
        ),
        rrb_u_dict=nn.ModuleDict(
            {scale: RRB(dfn_dim, dfn_dim) for scale in coarse_scales}
        ),
        terminal_module=nn.ModuleDict(
            {scale: nn.Conv2d(dfn_dim, 3, 1, 1, 0) for scale in coarse_scales}
        ),
    )

    dw = True
    hidden_blocks = 8
    kernel_size = 5
    # scale -> (first, second) positional argument of ConvRefiner.
    refiner_dims = {
        "16": (2 * 512, 1024),
        "8": (2 * 512, 1024),
        "4": (2 * 256, 512),
        "2": (2 * 64, 128),
        "1": (2 * 3, 24),
    }
    conv_refiner = nn.ModuleDict(
        {
            scale: ConvRefiner(
                in_dim,
                hidden_dim,
                3,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
            )
            for scale, (in_dim, hidden_dim) in refiner_dims.items()
        }
    )

    gps = nn.ModuleDict({scale: NormedCorr() for scale in coarse_scales})
    proj = nn.ModuleDict(
        {"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)}
    )
    decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True)

    # Only load pretrained backbone weights if not loading a pretrained matcher.
    encoder = Encoder(tv_resnet.resnet50(pretrained=not pretrained))
    matcher = RegressionMatcher(encoder, decoder, h=h, w=w).to(device)

    if pretrained:
        # map_location="cpu" lets checkpoints containing CUDA tensors load on
        # CPU-only hosts; load_state_dict copies onto the matcher's device.
        weights = torch.hub.load_state_dict_from_url(
            weights_url, map_location="cpu"
        )
        matcher.load_state_dict(weights)
    return matcher
def baseline(pretrained=True, version="mega_synthetic", device=None):
    """Build the matcher variant combining ``NormedCorr`` with ``LocalCorr`` refiners.

    ``NormedCorr`` is used at scales 32 and 16, ``LocalCorr`` refiners at
    scales 16/8/4/2 and a ``ConvRefiner`` at scale 1. The width fed into the
    ``RRB`` input blocks is ``(h // s) * (w // s) + feat_dim`` for stride
    ``s`` in {32, 16}.

    Args:
        pretrained: When true, a checkpoint registered under
            ``dkm_pretrained_urls["baseline"][version]`` is downloaded and
            loaded. When false, no matcher checkpoint is loaded and the
            ResNet-50 backbone is created with ``pretrained=True`` instead.
        version: Key into ``dkm_pretrained_urls["baseline"]``; only read when
            ``pretrained`` is true.
        device: Device the matcher is moved to. ``None`` selects CUDA when
            ``torch.cuda.is_available()`` and CPU otherwise. (Previously the
            function referenced a ``device`` name it never defined.)

    Returns:
        The ``RegressionMatcher`` (constructed with ``h=384, w=512``) on
        ``device``.

    Raises:
        KeyError: If ``pretrained`` is true and no checkpoint is registered
            for this variant/version. The registry visible in this module has
            no ``"baseline"`` entry, so pass ``pretrained=False`` unless one
            has been added.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    weights_url = None
    if pretrained:
        # Fail before building the model, and with a message that says what
        # is missing rather than a bare key name.
        try:
            weights_url = dkm_pretrained_urls["baseline"][version]
        except KeyError:
            raise KeyError(
                f"no pretrained weights registered for 'baseline' version "
                f"{version!r}; register a URL or pass pretrained=False"
            ) from None

    h, w = 384, 512
    # One channel count per coarse scale, tied to the h x w grid at that stride.
    gp_dims = {"32": (h // 32) * (w // 32), "16": (h // 16) * (w // 16)}
    dfn_dim = 384
    feat_dim = 256
    coarse_scales = ("32", "16")

    # NOTE: modules are created in the same order as the hand-written original
    # (and each dict keeps its key order) so that random initialisation and
    # the state_dict layout are unchanged.
    coordinate_decoder = DFN(
        internal_dim=dfn_dim,
        feat_input_modules=nn.ModuleDict(
            {scale: nn.Conv2d(512, feat_dim, 1, 1) for scale in coarse_scales}
        ),
        pred_input_modules=nn.ModuleDict(
            {scale: nn.Identity() for scale in coarse_scales}
        ),
        rrb_d_dict=nn.ModuleDict(
            {
                scale: RRB(gp_dims[scale] + feat_dim, dfn_dim)
                for scale in coarse_scales
            }
        ),
        cab_dict=nn.ModuleDict(
            {scale: CAB(2 * dfn_dim, dfn_dim) for scale in coarse_scales}
        ),
        rrb_u_dict=nn.ModuleDict(
            {scale: RRB(dfn_dim, dfn_dim) for scale in coarse_scales}
        ),
        terminal_module=nn.ModuleDict(
            {scale: nn.Conv2d(dfn_dim, 3, 1, 1, 0) for scale in coarse_scales}
        ),
    )

    dw = True
    hidden_blocks = 8
    kernel_size = 5
    # scale -> second positional argument of LocalCorr (the first is always 81).
    local_corr_dims = {"16": 81 * 12, "8": 81 * 12, "4": 81 * 6, "2": 81}
    refiners = {
        scale: LocalCorr(
            81,
            hidden_dim,
            3,
            kernel_size=kernel_size,
            dw=dw,
            hidden_blocks=hidden_blocks,
        )
        for scale, hidden_dim in local_corr_dims.items()
    }
    # The finest scale keeps a plain ConvRefiner.
    refiners["1"] = ConvRefiner(
        2 * 3,
        24,
        3,
        kernel_size=kernel_size,
        dw=dw,
        hidden_blocks=hidden_blocks,
    )
    conv_refiner = nn.ModuleDict(refiners)

    gps = nn.ModuleDict({scale: NormedCorr() for scale in coarse_scales})
    proj = nn.ModuleDict(
        {"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)}
    )
    decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True)

    # Only load pretrained backbone weights if not loading a pretrained matcher.
    encoder = Encoder(tv_resnet.resnet50(pretrained=not pretrained))
    matcher = RegressionMatcher(encoder, decoder, h=h, w=w).to(device)

    if pretrained:
        # map_location="cpu" lets checkpoints containing CUDA tensors load on
        # CPU-only hosts; load_state_dict copies onto the matcher's device.
        weights = torch.hub.load_state_dict_from_url(
            weights_url, map_location="cpu"
        )
        matcher.load_state_dict(weights)
    return matcher
def linear(pretrained=True, version="mega_synthetic", device=None):
    """Build the matcher variant whose ``GP`` modules use ``basis="linear"``.

    Apart from the basis and the checkpoint registry key, the construction
    matches ``DKM`` in this module.

    Args:
        pretrained: When true, a checkpoint registered under
            ``dkm_pretrained_urls["linear"][version]`` is downloaded and
            loaded. When false, no matcher checkpoint is loaded and the
            ResNet-50 backbone is created with ``pretrained=True`` instead.
        version: Key into ``dkm_pretrained_urls["linear"]``; only read when
            ``pretrained`` is true.
        device: Device the matcher is moved to. ``None`` selects CUDA when
            ``torch.cuda.is_available()`` and CPU otherwise. (Previously the
            function referenced a ``device`` name it never defined.)

    Returns:
        The ``RegressionMatcher`` (constructed with ``h=384, w=512``) on
        ``device``.

    Raises:
        KeyError: If ``pretrained`` is true and no checkpoint is registered
            for this variant/version. The registry visible in this module has
            no ``"linear"`` entry, so pass ``pretrained=False`` unless one has
            been added.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    weights_url = None
    if pretrained:
        # Fail before building the model, and with a message that says what
        # is missing rather than a bare key name.
        try:
            weights_url = dkm_pretrained_urls["linear"][version]
        except KeyError:
            raise KeyError(
                f"no pretrained weights registered for 'linear' version "
                f"{version!r}; register a URL or pass pretrained=False"
            ) from None

    gp_dim = 256
    dfn_dim = 384
    feat_dim = 256
    coarse_scales = ("32", "16")

    # NOTE: modules are created in the same order as the hand-written original
    # (and each dict keeps its key order) so that random initialisation and
    # the state_dict layout are unchanged.
    coordinate_decoder = DFN(
        internal_dim=dfn_dim,
        feat_input_modules=nn.ModuleDict(
            {scale: nn.Conv2d(512, feat_dim, 1, 1) for scale in coarse_scales}
        ),
        pred_input_modules=nn.ModuleDict(
            {scale: nn.Identity() for scale in coarse_scales}
        ),
        rrb_d_dict=nn.ModuleDict(
            {scale: RRB(gp_dim + feat_dim, dfn_dim) for scale in coarse_scales}
        ),
        cab_dict=nn.ModuleDict(
            {scale: CAB(2 * dfn_dim, dfn_dim) for scale in coarse_scales}
        ),
        rrb_u_dict=nn.ModuleDict(
            {scale: RRB(dfn_dim, dfn_dim) for scale in coarse_scales}
        ),
        terminal_module=nn.ModuleDict(
            {scale: nn.Conv2d(dfn_dim, 3, 1, 1, 0) for scale in coarse_scales}
        ),
    )

    dw = True
    hidden_blocks = 8
    kernel_size = 5
    # scale -> (first, second) positional argument of ConvRefiner.
    refiner_dims = {
        "16": (2 * 512, 1024),
        "8": (2 * 512, 1024),
        "4": (2 * 256, 512),
        "2": (2 * 64, 128),
        "1": (2 * 3, 24),
    }
    conv_refiner = nn.ModuleDict(
        {
            scale: ConvRefiner(
                in_dim,
                hidden_dim,
                3,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
            )
            for scale, (in_dim, hidden_dim) in refiner_dims.items()
        }
    )

    gps = nn.ModuleDict(
        {
            scale: GP(
                CosKernel,
                T=0.2,
                learn_temperature=False,
                only_attention=False,
                gp_dim=gp_dim,
                basis="linear",
                no_cov=True,
            )
            for scale in coarse_scales
        }
    )
    proj = nn.ModuleDict(
        {"16": nn.Conv2d(1024, 512, 1, 1), "32": nn.Conv2d(2048, 512, 1, 1)}
    )
    decoder = Decoder(coordinate_decoder, gps, proj, conv_refiner, detach=True)

    h, w = 384, 512
    # Only load pretrained backbone weights if not loading a pretrained matcher.
    encoder = Encoder(tv_resnet.resnet50(pretrained=not pretrained))
    matcher = RegressionMatcher(encoder, decoder, h=h, w=w).to(device)

    if pretrained:
        # map_location="cpu" lets checkpoints containing CUDA tensors load on
        # CPU-only hosts; load_state_dict copies onto the matcher's device.
        weights = torch.hub.load_state_dict_from_url(
            weights_url, map_location="cpu"
        )
        matcher.load_state_dict(weights)
    return matcher