import os
import torch
from argparse import ArgumentParser
from torch import nn
from torch.utils.data import ConcatDataset
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import json
import wandb
from tqdm import tqdm

from romatch.benchmarks import MegadepthDenseBenchmark, ScanNetBenchmark
from romatch.datasets.megadepth import MegadepthBuilder
from romatch.datasets.scannet import ScanNetBuilder
from romatch.losses.robust_loss import RobustLosses
from romatch.train.train import train_k_steps
from romatch.models.matcher import *
from romatch.models.transformer import Block, TransformerDecoder, MemEffAttention
from romatch.models.encoders import *
from romatch.checkpointing import CheckPoint

resolutions = {"low":(448, 448), "medium":(14*8*5, 14*8*5), "high":(14*8*6, 14*8*6)}
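# Note: the "medium" and "high" sizes are multiples of 14, which matches the
# DINOv2 patch size used by the CNNandDinov2 encoder constructed in get_model below.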

def get_model(pretrained_backbone=True, resolution = "medium", **kwargs):
    gp_dim = 512
    feat_dim = 512
    decoder_dim = gp_dim + feat_dim
    cls_to_coord_res = 64
    coordinate_decoder = TransformerDecoder(
        nn.Sequential(*[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)]),
        decoder_dim,
        cls_to_coord_res**2 + 1,
        is_classifier=True,
        amp = True,
        pos_enc = False,)
    dw = True
    hidden_blocks = 8
    kernel_size = 5
    displacement_emb = "linear"
    disable_local_corr_grad = True
    conv_refiner = nn.ModuleDict(
        {
            "16": ConvRefiner(
                2 * 512+128+(2*7+1)**2,
                2 * 512+128+(2*7+1)**2,
                2 + 1,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=128,
                local_corr_radius = 7,
                corr_in_other = True,
                amp = True,
                disable_local_corr_grad = disable_local_corr_grad,
                bn_momentum = 0.01,
            ),
            "8": ConvRefiner(
                2 * 512+64+(2*3+1)**2,
                2 * 512+64+(2*3+1)**2,
                2 + 1,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=64,
                local_corr_radius = 3,
                corr_in_other = True,
                amp = True,
                disable_local_corr_grad = disable_local_corr_grad,
                bn_momentum = 0.01,
            ),
            "4": ConvRefiner(
                2 * 256+32+(2*2+1)**2,
                2 * 256+32+(2*2+1)**2,
                2 + 1,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=32,
                local_corr_radius = 2,
                corr_in_other = True,
                amp = True,
                disable_local_corr_grad = disable_local_corr_grad,
                bn_momentum = 0.01,
            ),
            "2": ConvRefiner(
                2 * 64+16,
                128+16,
                2 + 1,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks=hidden_blocks,
                displacement_emb=displacement_emb,
                displacement_emb_dim=16,
                amp = True,
                disable_local_corr_grad = disable_local_corr_grad,
                bn_momentum = 0.01,
            ),
            "1": ConvRefiner(
                2 * 9 + 6,
                24,
                2 + 1,
                kernel_size=kernel_size,
                dw=dw,
                hidden_blocks = hidden_blocks,
                displacement_emb = displacement_emb,
                displacement_emb_dim = 6,
                amp = True,
                disable_local_corr_grad = disable_local_corr_grad,
                bn_momentum = 0.01,
            ),
        }
    )
    kernel_temperature = 0.2
    learn_temperature = False
    no_cov = True
    kernel = CosKernel
    only_attention = False
    basis = "fourier"
    gp16 = GP(
        kernel,
        T=kernel_temperature,
        learn_temperature=learn_temperature,
        only_attention=only_attention,
        gp_dim=gp_dim,
        basis=basis,
        no_cov=no_cov,
    )
    gps = nn.ModuleDict({"16": gp16})
    proj16 = nn.Sequential(nn.Conv2d(1024, 512, 1, 1), nn.BatchNorm2d(512))
    proj8 = nn.Sequential(nn.Conv2d(512, 512, 1, 1), nn.BatchNorm2d(512))
    proj4 = nn.Sequential(nn.Conv2d(256, 256, 1, 1), nn.BatchNorm2d(256))
    proj2 = nn.Sequential(nn.Conv2d(128, 64, 1, 1), nn.BatchNorm2d(64))
    proj1 = nn.Sequential(nn.Conv2d(64, 9, 1, 1), nn.BatchNorm2d(9))
    proj = nn.ModuleDict({
        "16": proj16,
        "8": proj8,
        "4": proj4,
        "2": proj2,
        "1": proj1,
    })
    displacement_dropout_p = 0.0
    gm_warp_dropout_p = 0.0
    decoder = Decoder(coordinate_decoder,
                      gps,
                      proj,
                      conv_refiner,
                      detach=True,
                      scales=["16", "8", "4", "2", "1"],
                      displacement_dropout_p = displacement_dropout_p,
                      gm_warp_dropout_p = gm_warp_dropout_p)
    h,w = resolutions[resolution]
    encoder = CNNandDinov2(
        cnn_kwargs = dict(
            pretrained=pretrained_backbone,
            amp = True),
        amp = True,
        use_vgg = True,
    )
    matcher = RegressionMatcher(encoder, decoder, h=h, w=w, alpha=1, beta=0, **kwargs)
    return matcher
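
# Usage sketch (illustrative only, not part of the training pipeline): once trained,
# the matcher returned by get_model can be used for dense matching roughly as in the
# public RoMa examples, assuming the RegressionMatcher match/sample API:
#
#   model = get_model(pretrained_backbone=True, resolution="medium")
#   warp, certainty = model.match("imA.png", "imB.png")   # dense warp + per-pixel confidence
#   matches, certainty = model.sample(warp, certainty)    # sparse correspondences
#
# Method names and signatures follow romatch.models.matcher.RegressionMatcher and may
# differ between versions; check the repo before relying on this sketch.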

def train(args):
    dist.init_process_group('nccl')
    #torch._dynamo.config.verbose=True
    gpus = int(os.environ['WORLD_SIZE'])
    # create model and move it to GPU with id rank
    rank = dist.get_rank()
    print(f"Start running DDP on rank {rank}")
    device_id = rank % torch.cuda.device_count()
    romatch.LOCAL_RANK = device_id
    torch.cuda.set_device(device_id)

    resolution = args.train_resolution
    wandb_log = not args.dont_log_wandb
    experiment_name = os.path.splitext(os.path.basename(__file__))[0]
    # the trailing "and False" keeps wandb in disabled mode; remove it to actually log online
    wandb_mode = "online" if wandb_log and rank == 0 and False else "disabled"
    wandb.init(project="romatch", entity=args.wandb_entity, name=experiment_name, reinit=False, mode = wandb_mode)
    checkpoint_dir = "workspace/checkpoints/"
    h,w = resolutions[resolution]
    model = get_model(pretrained_backbone=True, resolution=resolution, attenuate_cert = False).to(device_id)
    # Num steps
    global_step = 0
    batch_size = args.gpu_batch_size
    step_size = gpus*batch_size
    romatch.STEP_SIZE = step_size

    N = (32 * 250000)  # 250k steps of batch size 32
    # steps per dataset between checkpoints
    k = 25000 // romatch.STEP_SIZE

    # Data
    mega = MegadepthBuilder(data_root="data/megadepth", loftr_ignore=True, imc21_ignore = True)
    use_horizontal_flip_aug = True
    rot_prob = 0
    depth_interpolation_mode = "bilinear"
    megadepth_train1 = mega.build_scenes(
        split="train_loftr", min_overlap=0.01, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
        ht=h, wt=w,
    )
    megadepth_train2 = mega.build_scenes(
        split="train_loftr", min_overlap=0.35, shake_t=32, use_horizontal_flip_aug = use_horizontal_flip_aug, rot_prob = rot_prob,
        ht=h, wt=w,
    )
    megadepth_train = ConcatDataset(megadepth_train1 + megadepth_train2)
    mega_ws = mega.weight_scenes(megadepth_train, alpha=0.75)

    scannet = ScanNetBuilder(data_root="data/scannet")
    scannet_train = scannet.build_scenes(split="train", ht=h, wt=w, use_horizontal_flip_aug = use_horizontal_flip_aug)
    scannet_train = ConcatDataset(scannet_train)
    scannet_ws = scannet.weight_scenes(scannet_train, alpha=0.75)

    # Losses
    depth_loss_scannet = RobustLosses(
        ce_weight=0.0,
        local_dist={1:4, 2:4, 4:8, 8:8},
        local_largest_scale=8,
        depth_interpolation_mode=depth_interpolation_mode,
        alpha = 0.5,
        c = 1e-4,)
    depth_loss_mega = RobustLosses(
        ce_weight=0.01,
        local_dist={1:4, 2:4, 4:8, 8:8},
        local_largest_scale=8,
        depth_interpolation_mode=depth_interpolation_mode,
        alpha = 0.5,
        c = 1e-4,)
    # Optimizer: learning rates scale linearly with the effective batch size (STEP_SIZE),
    # relative to a reference batch size of 8.
    parameters = [
        {"params": model.encoder.parameters(), "lr": romatch.STEP_SIZE * 5e-6 / 8},
        {"params": model.decoder.parameters(), "lr": romatch.STEP_SIZE * 1e-4 / 8},
    ]
    optimizer = torch.optim.AdamW(parameters, weight_decay=0.01)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[(9*N/romatch.STEP_SIZE)//10])
    megadense_benchmark = MegadepthDenseBenchmark("data/megadepth", num_samples = 1000, h=h, w=w)
    checkpointer = CheckPoint(checkpoint_dir, experiment_name)
    model, optimizer, lr_scheduler, global_step = checkpointer.load(model, optimizer, lr_scheduler, global_step)
    romatch.GLOBAL_STEP = global_step
    ddp_model = DDP(model, device_ids=[device_id], find_unused_parameters = False, gradient_as_bucket_view=True)
    grad_scaler = torch.cuda.amp.GradScaler(growth_interval=1_000_000)
    grad_clip_norm = 0.01
    for n in range(romatch.GLOBAL_STEP, N, k * romatch.STEP_SIZE):
        mega_sampler = torch.utils.data.WeightedRandomSampler(
            mega_ws, num_samples = batch_size * k, replacement=False
        )
        mega_dataloader = iter(
            torch.utils.data.DataLoader(
                megadepth_train,
                batch_size = batch_size,
                sampler = mega_sampler,
                num_workers = 8,
            )
        )
        scannet_ws_sampler = torch.utils.data.WeightedRandomSampler(
            scannet_ws, num_samples=batch_size * k, replacement=False
        )
        scannet_dataloader = iter(
            torch.utils.data.DataLoader(
                scannet_train,
                batch_size=batch_size,
                sampler=scannet_ws_sampler,
                num_workers=gpus * 8,
            )
        )
        # alternate one MegaDepth step and one ScanNet step
        for n_k in tqdm(range(n, n + 2 * k, 2), disable = romatch.RANK > 0):
            train_k_steps(
                n_k, 1, mega_dataloader, ddp_model, depth_loss_mega, optimizer, lr_scheduler, grad_scaler, grad_clip_norm = grad_clip_norm, progress_bar=False
            )
            train_k_steps(
                n_k + 1, 1, scannet_dataloader, ddp_model, depth_loss_scannet, optimizer, lr_scheduler, grad_scaler, grad_clip_norm = grad_clip_norm, progress_bar=False
            )
        checkpointer.save(model, optimizer, lr_scheduler, romatch.GLOBAL_STEP)
        wandb.log(megadense_benchmark.benchmark(model), step = romatch.GLOBAL_STEP)

def test_scannet(model, name, resolution, sample_mode):
    scannet_benchmark = ScanNetBenchmark("data/scannet")
    scannet_results = scannet_benchmark.benchmark(model)
    with open(f"results/scannet_{name}.json", "w") as f:
        json.dump(scannet_results, f)

if __name__ == "__main__":
    import warnings
    warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
    warnings.filterwarnings('ignore')  #, category=UserWarning)#, message='WARNING batched routines are designed for small sizes.')
    os.environ["TORCH_CUDNN_V8_API_ENABLED"] = "1"  # For BF16 computations
    os.environ["OMP_NUM_THREADS"] = "16"
    import romatch
    parser = ArgumentParser()
    parser.add_argument("--test", action='store_true')
    parser.add_argument("--debug_mode", action='store_true')
    parser.add_argument("--dont_log_wandb", action='store_true')
    parser.add_argument("--train_resolution", default='medium')
    parser.add_argument("--gpu_batch_size", default=4, type=int)
    parser.add_argument("--wandb_entity", required = False)
    args, _ = parser.parse_known_args()
    romatch.DEBUG_MODE = args.debug_mode
    if not args.test:
        train(args)

    # Evaluation: rebuild the model, load the trained weights, and run the ScanNet benchmark.
    experiment_name = os.path.splitext(os.path.basename(__file__))[0]
    checkpoint_dir = "workspace/"
    checkpoint_name = checkpoint_dir + experiment_name + ".pth"
    test_resolution = "medium"
    sample_mode = "threshold_balanced"
    symmetric = True
    upsample_preds = False
    attenuate_cert = True

    model = get_model(pretrained_backbone=False, resolution = test_resolution, sample_mode = sample_mode, upsample_preds = upsample_preds, symmetric=symmetric, name=experiment_name, attenuate_cert = attenuate_cert)
    model = model.cuda()
    states = torch.load(checkpoint_name)
    model.load_state_dict(states["model"])
    test_scannet(model, experiment_name, resolution = test_resolution, sample_mode = sample_mode)
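
# Launch note (assumed invocation; the actual script path/name may differ): the script
# initializes torch.distributed with the NCCL backend and reads WORLD_SIZE from the
# environment, so it is intended to be launched with a distributed launcher, e.g.
#
#   torchrun --nproc_per_node=4 train_roma.py --gpu_batch_size 4 --wandb_entity <entity>
#
# and evaluated afterwards with:
#
#   python train_roma.py --test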