import argparse
import importlib
import os

# Pin the process to a single GPU. This is set before torch is imported so it is
# in place before CUDA is initialised. setdefault lets a value exported by the
# caller (e.g. `CUDA_VISIBLE_DEVICES=0 python train.py ...`) take precedence over
# the hard-coded device instead of being silently overwritten.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "7")
# Disable HDF5 file locking.
os.environ["HDF5_USE_FILE_LOCKING"] = "0"

import random
from datetime import date
from shutil import copyfile

import cv2 as cv
import numpy as np
import torch
import torch.backends.cudnn

import admin.settings as ws_settings


def _str2bool(value):
    """Parse a command-line boolean such as '1'/'0', 'true'/'false', 'yes'/'no'.

    argparse's ``type=bool`` converts every non-empty string (including '0' and
    'False') to True, so an explicit parser is required for an on/off flag.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(
        'Expected a boolean value (1/0, true/false), got {!r}.'.format(value))


def run_training(train_module, train_name, seed, cudnn_benchmark=True):
    """Run a train script in train_settings.

    args:
        train_module: Name of module in the "train_settings/" folder.
        train_name: Name of the train settings file.
        seed: Seed value; stored on the settings object as ``settings.seed``.
        cudnn_benchmark: Use cudnn benchmark or not (default is True).
    """
    # This is needed to avoid strange crashes related to opencv
    cv.setNumThreads(0)

    torch.backends.cudnn.benchmark = cudnn_benchmark

    # dd/mm/YYYY
    d1 = date.today().strftime("%d/%m/%Y")
    print('Training: {} {}\nDate: {}'.format(train_module, train_name, d1))

    settings = ws_settings.Settings()
    settings.module_name = train_module
    settings.script_name = train_name
    settings.project_path = 'train_settings/{}/{}'.format(train_module, train_name)
    settings.seed = seed

    # Checkpoints will be saved in this directory; a copy of the train settings
    # script is stored there as well.
    save_dir = os.path.join(settings.env.workspace_dir, settings.project_path)
    os.makedirs(save_dir, exist_ok=True)
    copyfile(settings.project_path + '.py', os.path.join(save_dir, settings.script_name + '.py'))

    # Import train_settings/<module>/<name>.py as a dotted module path and hand
    # control to its run(settings) entry point.
    module_path = 'train_settings.{}.{}'.format(train_module.replace('/', '.'),
                                               train_name.replace('/', '.'))
    expr_module = importlib.import_module(module_path)
    expr_module.run(settings)


def main():
    """Parse command-line arguments, seed all RNGs and launch training."""
    parser = argparse.ArgumentParser(description='Run a train scripts in train_settings.')
    parser.add_argument('--train_module', type=str, required=True,
                        help='Name of module in the "train_settings/" folder.')
    parser.add_argument('--train_name', type=str, required=True,
                        help='Name of the train settings file.')
    parser.add_argument('--cudnn_benchmark', type=_str2bool, default=True,
                        help='Set cudnn benchmark on (1) or off (0) (default is on).')
    parser.add_argument('--seed', type=int, default=None,
                        help='Pseudo-RNG seed (0 <= seed < 2**32). If omitted, a random seed is drawn and printed.')
    args = parser.parse_args()

    if args.seed is None:
        # torch.initial_seed() is a 64-bit value; mask it to 32 bits so the same
        # seed is also accepted by np.random.seed.
        args.seed = torch.initial_seed() & (2 ** 32 - 1)
    elif not 0 <= args.seed < 2 ** 32:
        # np.random.seed only accepts seeds in [0, 2**32).
        parser.error('--seed must satisfy 0 <= seed < 2**32, got {}'.format(args.seed))

    print('Seed is {}'.format(args.seed))
    random.seed(int(args.seed))
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    run_training(args.train_module, args.train_name,
                 cudnn_benchmark=args.cudnn_benchmark, seed=args.seed)


if __name__ == '__main__':
    main()