Spaces:
Running
on
Zero
Running
on
Zero
| import os | |
| import torch | |
| from argparse import ArgumentParser | |
| from torch import nn | |
| from torch.utils.data import ConcatDataset | |
| import torch.distributed as dist | |
| from torch.nn.parallel import DistributedDataParallel as DDP | |
| import json | |
| import wandb | |
| from romatch.benchmarks import MegadepthDenseBenchmark | |
| from romatch.datasets.megadepth import MegadepthBuilder | |
| from romatch.losses.robust_loss import RobustLosses | |
| from romatch.benchmarks import MegaDepthPoseEstimationBenchmark, MegadepthDenseBenchmark, HpatchesHomogBenchmark | |
| from romatch.train.train import train_k_steps | |
| from romatch.models.matcher import * | |
| from romatch.models.transformer import Block, TransformerDecoder, MemEffAttention | |
| from romatch.models.encoders import * | |
| from romatch.checkpointing import CheckPoint | |
# Named input resolutions; each value is unpacked as (h, w) by the code below.
# "medium" and "high" are 14 * 8 * 5 = 560 and 14 * 8 * 6 = 672 pixels per side.
resolutions = {
    "low": (448,) * 2,
    "medium": (14 * 8 * 5,) * 2,
    "high": (14 * 8 * 6,) * 2,
}
def get_model(pretrained_backbone=True, resolution="medium", **kwargs):
    """Assemble the matcher: encoder, coarse decoder and per-scale refiners.

    Args:
        pretrained_backbone: Forwarded as ``pretrained`` inside the encoder's
            ``cnn_kwargs``.
        resolution: Key into the module-level ``resolutions`` table; the
            selected ``(h, w)`` pair is handed to ``RegressionMatcher``.
        **kwargs: Passed through unchanged to ``RegressionMatcher``.

    Returns:
        The constructed ``RegressionMatcher``.

    Raises:
        KeyError: If ``resolution`` is not a key of ``resolutions``.
    """
    import warnings

    warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')

    # NOTE(review): sub-modules are deliberately built in the same order as
    # before (coordinate decoder, refiners 16..1, GP, projections, decoder,
    # encoder) -- presumably constructors draw from the global RNG, so keep
    # this order when editing if seeded runs must stay comparable.
    gp_dim = 512
    feat_dim = 512
    decoder_dim = gp_dim + feat_dim
    cls_to_coord_res = 64

    transformer_blocks = [
        Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)
    ]
    coordinate_decoder = TransformerDecoder(
        nn.Sequential(*transformer_blocks),
        decoder_dim,
        cls_to_coord_res**2 + 1,
        is_classifier=True,
        amp=True,
        pos_enc=False,
    )

    # Keyword arguments shared by every ConvRefiner below.
    refiner_common = dict(
        kernel_size=5,
        dw=True,
        hidden_blocks=8,
        displacement_emb="linear",
        amp=True,
        disable_local_corr_grad=True,
        bn_momentum=0.01,
    )

    def corr_refiner(feat_channels, emb_dim, radius):
        # Input/hidden width: two feature maps + displacement embedding
        # + a (2 * radius + 1)^2 local-correlation window.
        width = 2 * feat_channels + emb_dim + (2 * radius + 1) ** 2
        return ConvRefiner(
            width,
            width,
            2 + 1,
            displacement_emb_dim=emb_dim,
            local_corr_radius=radius,
            corr_in_other=True,
            **refiner_common,
        )

    def plain_refiner(feat_channels, emb_dim):
        # Same as above but without any local-correlation arguments.
        width = 2 * feat_channels + emb_dim
        return ConvRefiner(
            width,
            width,
            2 + 1,
            displacement_emb_dim=emb_dim,
            **refiner_common,
        )

    conv_refiner = nn.ModuleDict(
        {
            "16": corr_refiner(512, 128, 7),
            "8": corr_refiner(512, 64, 3),
            "4": corr_refiner(256, 32, 2),
            "2": plain_refiner(64, 16),
            "1": plain_refiner(9, 6),
        }
    )

    gp16 = GP(
        CosKernel,
        T=0.2,
        learn_temperature=False,
        only_attention=False,
        gp_dim=gp_dim,
        basis="fourier",
        no_cov=True,
    )
    gps = nn.ModuleDict({"16": gp16})

    # (scale, in_channels, out_channels) of the 1x1 conv + batch-norm
    # projection applied at each scale.
    proj_channels = (
        ("16", 1024, 512),
        ("8", 512, 512),
        ("4", 256, 256),
        ("2", 128, 64),
        ("1", 64, 9),
    )
    proj = nn.ModuleDict(
        {
            scale: nn.Sequential(nn.Conv2d(c_in, c_out, 1, 1), nn.BatchNorm2d(c_out))
            for scale, c_in, c_out in proj_channels
        }
    )

    decoder = Decoder(
        coordinate_decoder,
        gps,
        proj,
        conv_refiner,
        detach=True,
        scales=["16", "8", "4", "2", "1"],
        displacement_dropout_p=0.0,
        gm_warp_dropout_p=0.0,
    )

    h, w = resolutions[resolution]
    encoder = CNNandDinov2(
        cnn_kwargs=dict(pretrained=pretrained_backbone, amp=True),
        amp=True,
        use_vgg=True,
    )
    return RegressionMatcher(encoder, decoder, h=h, w=w, **kwargs)
def train(args):
    """Run distributed (DDP, NCCL) training of the matcher on MegaDepth.

    Expects to be launched with one process per GPU and the ``WORLD_SIZE``
    environment variable set. After every chunk of ``k`` steps the model is
    checkpointed and the dense MegaDepth benchmark result is logged to wandb.

    Args:
        args: Parsed CLI namespace. The attributes read here are
            ``train_resolution``, ``dont_log_wandb``, ``wandb_entity`` and
            ``gpu_batch_size``.
    """
    # Import locally so this function does not depend on the `import romatch`
    # that only runs inside the `__main__` guard: the `from romatch.x import ...`
    # lines at the top of the file do not bind the name `romatch` themselves.
    import romatch

    dist.init_process_group('nccl')
    #torch._dynamo.config.verbose=True
    gpus = int(os.environ['WORLD_SIZE'])
    # create model and move it to GPU with id rank
    rank = dist.get_rank()
    print(f"Start running DDP on rank {rank}")
    device_id = rank % torch.cuda.device_count()
    romatch.LOCAL_RANK = device_id
    torch.cuda.set_device(device_id)

    resolution = args.train_resolution
    wandb_log = not args.dont_log_wandb
    experiment_name = os.path.splitext(os.path.basename(__file__))[0]
    # Only rank 0 reports to wandb; every other rank runs it in disabled mode.
    wandb_mode = "online" if wandb_log and rank == 0 else "disabled"
    wandb.init(project="romatch", entity=args.wandb_entity, name=experiment_name, reinit=False, mode=wandb_mode)
    checkpoint_dir = "workspace/checkpoints/"
    h, w = resolutions[resolution]
    model = get_model(pretrained_backbone=True, resolution=resolution, attenuate_cert=False).to(device_id)

    # Num steps
    global_step = 0
    batch_size = args.gpu_batch_size
    step_size = gpus * batch_size  # samples consumed per step across all GPUs
    romatch.STEP_SIZE = step_size
    N = (32 * 250000)  # 250k steps of batch size 32
    # checkpoint every
    k = 25000 // romatch.STEP_SIZE

    # Data
    mega = MegadepthBuilder(data_root="data/megadepth", loftr_ignore=True, imc21_ignore=True)
    use_horizontal_flip_aug = True
    rot_prob = 0
    depth_interpolation_mode = "bilinear"
    # Two copies of the same split that differ only in their min_overlap threshold.
    megadepth_train1 = mega.build_scenes(
        split="train_loftr", min_overlap=0.01, shake_t=32,
        use_horizontal_flip_aug=use_horizontal_flip_aug, rot_prob=rot_prob,
        ht=h, wt=w,
    )
    megadepth_train2 = mega.build_scenes(
        split="train_loftr", min_overlap=0.35, shake_t=32,
        use_horizontal_flip_aug=use_horizontal_flip_aug, rot_prob=rot_prob,
        ht=h, wt=w,
    )
    megadepth_train = ConcatDataset(megadepth_train1 + megadepth_train2)
    mega_ws = mega.weight_scenes(megadepth_train, alpha=0.75)

    # Loss and optimizer
    depth_loss = RobustLosses(
        ce_weight=0.01,
        local_dist={1: 4, 2: 4, 4: 8, 8: 8},
        local_largest_scale=8,
        depth_interpolation_mode=depth_interpolation_mode,
        alpha=0.5,
        c=1e-4,
    )
    # Learning rates scale linearly with the global batch size (reference size 8);
    # the encoder uses a 20x smaller base rate than the decoder.
    parameters = [
        {"params": model.encoder.parameters(), "lr": romatch.STEP_SIZE * 5e-6 / 8},
        {"params": model.decoder.parameters(), "lr": romatch.STEP_SIZE * 1e-4 / 8},
    ]
    optimizer = torch.optim.AdamW(parameters, weight_decay=0.01)
    # Single LR milestone placed at 9/10 of N / STEP_SIZE.
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[(9*N/romatch.STEP_SIZE)//10])
    megadense_benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples=1000, h=h, w=w)
    checkpointer = CheckPoint(checkpoint_dir, experiment_name)
    # Whatever the checkpointer returns replaces model/optimizer/scheduler/step
    # (presumably resuming from an existing checkpoint -- see CheckPoint.load).
    model, optimizer, lr_scheduler, global_step = checkpointer.load(model, optimizer, lr_scheduler, global_step)
    romatch.GLOBAL_STEP = global_step
    ddp_model = DDP(model, device_ids=[device_id], find_unused_parameters=False, gradient_as_bucket_view=True)
    grad_scaler = torch.cuda.amp.GradScaler(growth_interval=1_000_000)
    grad_clip_norm = 0.01

    # Each iteration covers k * STEP_SIZE samples, then checkpoints and benchmarks.
    for n in range(romatch.GLOBAL_STEP, N, k * romatch.STEP_SIZE):
        # A fresh sampler per chunk: batch_size * k samples for this process,
        # drawn without replacement according to the scene weights.
        mega_sampler = torch.utils.data.WeightedRandomSampler(
            mega_ws, num_samples=batch_size * k, replacement=False
        )
        mega_dataloader = iter(
            torch.utils.data.DataLoader(
                megadepth_train,
                batch_size=batch_size,
                sampler=mega_sampler,
                num_workers=8,
            )
        )
        train_k_steps(
            n, k, mega_dataloader, ddp_model, depth_loss, optimizer, lr_scheduler, grad_scaler,
            grad_clip_norm=grad_clip_norm,
        )
        # The unwrapped model (not the DDP wrapper) is saved and benchmarked.
        checkpointer.save(model, optimizer, lr_scheduler, romatch.GLOBAL_STEP)
        wandb.log(megadense_benchmark.benchmark(model), step=romatch.GLOBAL_STEP)
def test_mega_8_scenes(model, name):
    """Benchmark pose estimation on the 16 'mega_8_scenes' files and save the results.

    The results are printed and written as JSON to
    ``results/mega_8_scenes_{name}.json``.

    Args:
        model: Model handed to the benchmark.
        name: Passed to the benchmark as ``model_name`` and used in the output
            file name.
    """
    scene_names = [
        'mega_8_scenes_0019_0.1_0.3.npz',
        'mega_8_scenes_0025_0.1_0.3.npz',
        'mega_8_scenes_0021_0.1_0.3.npz',
        'mega_8_scenes_0008_0.1_0.3.npz',
        'mega_8_scenes_0032_0.1_0.3.npz',
        'mega_8_scenes_1589_0.1_0.3.npz',
        'mega_8_scenes_0063_0.1_0.3.npz',
        'mega_8_scenes_0024_0.1_0.3.npz',
        'mega_8_scenes_0019_0.3_0.5.npz',
        'mega_8_scenes_0025_0.3_0.5.npz',
        'mega_8_scenes_0021_0.3_0.5.npz',
        'mega_8_scenes_0008_0.3_0.5.npz',
        'mega_8_scenes_0032_0.3_0.5.npz',
        'mega_8_scenes_1589_0.3_0.5.npz',
        'mega_8_scenes_0063_0.3_0.5.npz',
        'mega_8_scenes_0024_0.3_0.5.npz',
    ]
    mega_8_scenes_benchmark = MegaDepthPoseEstimationBenchmark(
        "data/megadepth", scene_names=scene_names
    )
    mega_8_scenes_results = mega_8_scenes_benchmark.benchmark(model, model_name=name)
    print(mega_8_scenes_results)
    # Context manager so the file is flushed and closed even if dumping fails.
    with open(f"results/mega_8_scenes_{name}.json", "w") as results_file:
        json.dump(mega_8_scenes_results, results_file)
def test_mega1500(model, name):
    """Run the default MegaDepth pose-estimation benchmark and save the results.

    The results are written as JSON to ``results/mega1500_{name}.json``.

    Args:
        model: Model handed to the benchmark.
        name: Passed to the benchmark as ``model_name`` and used in the output
            file name.
    """
    mega1500_benchmark = MegaDepthPoseEstimationBenchmark("data/megadepth")
    mega1500_results = mega1500_benchmark.benchmark(model, model_name=name)
    # Context manager so the file is flushed and closed even if dumping fails.
    with open(f"results/mega1500_{name}.json", "w") as results_file:
        json.dump(mega1500_results, results_file)
def test_mega_dense(model, name):
    """Run the dense MegaDepth benchmark (1000 samples) and save the results.

    The results are written as JSON to ``results/mega_dense_{name}.json``.

    Args:
        model: Model handed to the benchmark.
        name: Used in the output file name.
    """
    megadense_benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples=1000)
    megadense_results = megadense_benchmark.benchmark(model)
    # Context manager so the file is flushed and closed even if dumping fails.
    with open(f"results/mega_dense_{name}.json", "w") as results_file:
        json.dump(megadense_results, results_file)
def test_hpatches(model, name):
    """Run the HPatches homography benchmark and save the results.

    The results are written as JSON to ``results/hpatches_{name}.json``.

    Args:
        model: Model handed to the benchmark.
        name: Used in the output file name.
    """
    hpatches_benchmark = HpatchesHomogBenchmark("data/hpatches")
    hpatches_results = hpatches_benchmark.benchmark(model)
    # Context manager so the file is flushed and closed even if dumping fails.
    with open(f"results/hpatches_{name}.json", "w") as results_file:
        json.dump(hpatches_results, results_file)
if __name__ == "__main__":
    # Script entry point: set runtime knobs, parse the CLI, then train unless
    # --only_test was given.
    os.environ.update(
        {
            "TORCH_CUDNN_V8_API_ENABLED": "1",  # For BF16 computations
            "OMP_NUM_THREADS": "16",
        }
    )
    torch.backends.cudnn.allow_tf32 = True  # allow tf32 on cudnn

    # Imported at module level (after the environment is set) so that the name
    # `romatch` is a global of this script.
    import romatch

    parser = ArgumentParser()
    for boolean_flag in ("--only_test", "--debug_mode", "--dont_log_wandb"):
        parser.add_argument(boolean_flag, action='store_true')
    parser.add_argument("--train_resolution", default='medium')
    parser.add_argument("--gpu_batch_size", default=8, type=int)
    parser.add_argument("--wandb_entity", required=False)

    # parse_known_args: unrecognised command-line arguments are ignored.
    args, _ = parser.parse_known_args()
    romatch.DEBUG_MODE = args.debug_mode

    if not args.only_test:
        train(args)