Spaces:
Runtime error
Runtime error
| import torch.utils.data | |
| from dataset import Offline_Dataset | |
| import yaml | |
| from sgmnet.match_model import matcher as SGM_Model | |
| from superglue.match_model import matcher as SG_Model | |
| import torch.distributed as dist | |
| import torch | |
| import os | |
| from collections import namedtuple | |
| from train import train | |
| from config import get_config, print_usage | |
def main(config, model_config):
    """Set up the model, DDP, and data loaders, then hand off to ``train``.

    Args:
        config: Run configuration. This function reads ``model_name``,
            ``local_rank`` and ``train_batch_size`` from it and forwards the
            whole object to ``Offline_Dataset`` and ``train``.
        model_config: Model hyper-parameters, passed to the matcher
            constructor and forwarded to ``train``.

    Raises:
        NotImplementedError: If ``config.model_name`` is neither ``"SGM"``
            nor ``"SG"``.
    """
    # Choose the matcher implementation; anything other than the two known
    # names is rejected up front.
    requested = config.model_name
    if requested not in ("SGM", "SG"):
        raise NotImplementedError
    matcher_cls = SGM_Model if requested == "SGM" else SG_Model
    net = matcher_cls(model_config)

    # Bind this process to its GPU, move the model there, then join the
    # process group (rendezvous details come from environment variables).
    rank = config.local_rank
    torch.cuda.set_device(rank)
    net.to(torch.device(f"cuda:{rank}"))
    dist.init_process_group(backend="nccl", init_method="env://")
    net = torch.nn.parallel.DistributedDataParallel(net, device_ids=[rank])

    if rank == 0:
        os.system("nvidia-smi")

    # The process group is initialised, so the world size can be queried once
    # and reused for every per-process split below.
    world_size = dist.get_world_size()
    workers_per_process = 8 // world_size

    # Training data: shuffled, sharded across processes, global batch size
    # split evenly between processes.
    train_set = Offline_Dataset(config, "train")
    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=config.train_batch_size // world_size,
        num_workers=workers_per_process,
        pin_memory=False,
        sampler=torch.utils.data.distributed.DistributedSampler(
            train_set, shuffle=True
        ),
        collate_fn=train_set.collate_fn,
    )

    # Validation data: sharded across processes without shuffling.
    # NOTE(review): unlike the training loader, batch_size here is the full
    # train_batch_size per process (not divided by world size) -- confirm
    # this is intended.
    valid_set = Offline_Dataset(config, "valid")
    valid_loader = torch.utils.data.DataLoader(
        valid_set,
        batch_size=config.train_batch_size,
        num_workers=workers_per_process,
        pin_memory=False,
        sampler=torch.utils.data.distributed.DistributedSampler(
            valid_set, shuffle=False
        ),
        collate_fn=valid_set.collate_fn,
    )

    if rank == 0:
        print("start training .....")
    train(net, train_loader, valid_loader, config, model_config)
if __name__ == "__main__":
    # ----------------------------------------
    # Parse configuration
    config, unparsed = get_config()

    # If we have unparsed arguments, print usage and exit. This is checked
    # before any file I/O so a mistyped command line fails fast.
    if unparsed:
        print_usage()
        # SystemExit instead of the `exit` helper, which is only injected by
        # the `site` module and is not guaranteed to exist.
        raise SystemExit(1)

    # Load the model hyper-parameters from the YAML file named in the config.
    # safe_load is used because yaml.load() without an explicit Loader is
    # rejected by PyYAML >= 6 and can construct arbitrary objects on older
    # versions.
    with open(config.config_path, "r") as f:
        raw_model_config = yaml.safe_load(f)
    if not isinstance(raw_model_config, dict):
        raise ValueError(
            f"model config {config.config_path!r} must contain a YAML mapping, "
            f"got {type(raw_model_config).__name__}"
        )

    # Expose the mapping's keys as attributes (model_config.<key>).
    model_config = namedtuple("model_config", raw_model_config.keys())(
        *raw_model_config.values()
    )

    main(config, model_config)