| import glob | |
| import os | |
| from configs import global_config, paths_config, hyperparameters | |
| from scripts.latent_creators.sg2_plus_latent_creator import SG2PlusLatentCreator | |
| from scripts.latent_creators.e4e_latent_creator import E4ELatentCreator | |
| from scripts.run_pti import run_PTI | |
| import pickle | |
| import torch | |
| from utils.models_utils import toogle_grad, load_old_G | |
class ExperimentRunner:
    """Run a PTI experiment and, optionally, the baseline latent creators.

    On construction the runner does three things:
    - lists the entries of ``paths_config.input_data_path``;
    - loads a generator via ``load_old_G``;
    - passes that generator to ``toogle_grad(..., False)``. This presumably
      freezes its parameters; confirm in ``utils.models_utils``.
    """

    def __init__(self, run_id=''):
        """Collect the input paths and load the original generator.

        Args:
            run_id: Run identifier. It is forwarded to ``run_PTI`` and replaced
                by that function's return value when PTI is executed.
        """
        input_paths = self._list_input_paths(paths_config.input_data_path)
        # Keep two independent lists (as the original did with two glob calls)
        # so that mutating one attribute can never affect the other.
        self.images_paths = input_paths
        self.target_paths = list(input_paths)
        self.run_id = run_id
        self.sampled_ws = None
        self.old_G = load_old_G()
        toogle_grad(self.old_G, False)

    @staticmethod
    def _list_input_paths(directory):
        """Return the non-hidden entries of *directory*, sorted by path.

        ``glob.glob`` yields entries in file-system order, which varies
        between machines and file systems. Sorting keeps the experiment
        inputs in a reproducible order. A missing directory yields ``[]``.
        """
        return sorted(glob.glob(f'{directory}/*'))

    def run_experiment(self, run_pt, create_other_latents, use_multi_id_training, use_wandb=False):
        """Execute the selected experiment stages and return the run id.

        Args:
            run_pt: If true, call ``run_PTI`` and store the run id it returns.
            create_other_latents: If true, run the SG2Plus and e4e latent
                creators.
            use_multi_id_training: Forwarded to ``run_PTI``.
            use_wandb: Forwarded to ``run_PTI`` and to both latent creators.

        Returns:
            ``self.run_id``. It is updated by ``run_PTI`` when ``run_pt`` is
            true; otherwise it is the value given at construction.
        """
        if run_pt:
            self.run_id = run_PTI(self.run_id, use_wandb=use_wandb,
                                  use_multi_id_training=use_multi_id_training)
        if create_other_latents:
            sg2_plus_latent_creator = SG2PlusLatentCreator(use_wandb=use_wandb)
            sg2_plus_latent_creator.create_latents()
            e4e_latent_creator = E4ELatentCreator(use_wandb=use_wandb)
            e4e_latent_creator.create_latents()
        # Release cached, unused GPU memory held by PyTorch's allocator
        # (a no-op when CUDA was never initialized).
        torch.cuda.empty_cache()
        return self.run_id
if __name__ == '__main__':
    # Enumerate GPUs by PCI bus ID (CUDA's default order is fastest-first),
    # then restrict the visible devices to those named in the global config.
    env = os.environ
    env['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    env['CUDA_VISIBLE_DEVICES'] = global_config.cuda_visible_devices

    # Run PTI only: no baseline latents, single-identity training, no wandb.
    experiment = ExperimentRunner()
    experiment.run_experiment(
        run_pt=True,
        create_other_latents=False,
        use_multi_id_training=False,
    )