| import os | |
| import torch | |
| from tqdm import tqdm | |
| from configs import paths_config, hyperparameters, global_config | |
| from training.coaches.base_coach import BaseCoach | |
| from utils.log_utils import log_images_from_w | |
class MultiIDCoach(BaseCoach):
    """Coach that tunes a single generator jointly on several images.

    All images from ``self.data_loader`` (up to
    ``hyperparameters.max_images_to_invert``) are inverted to latent pivots
    first. The generator ``self.G`` is then optimized for
    ``hyperparameters.max_pti_steps`` outer steps. Each outer step takes one
    optimizer step per (image, pivot) pair. A single ``*_multi_id.pt``
    checkpoint is saved at the end.
    """

    def __init__(self, data_loader, use_wandb):
        super().__init__(data_loader, use_wandb)

    def train(self):
        """Invert the input images, tune ``self.G`` on all of them, and save it.

        Side effects:
            - creates the embedding directories under
              ``paths_config.embedding_base_dir``;
            - advances ``global_config.training_step`` once per optimizer step;
            - logs to wandb when ``self.use_wandb`` is set;
            - writes the tuned generator with ``torch.save``.
        """
        self.G.synthesis.train()
        self.G.mapping.train()

        w_path_dir = f'{paths_config.embedding_base_dir}/{paths_config.input_data_id}'
        os.makedirs(w_path_dir, exist_ok=True)
        os.makedirs(f'{w_path_dir}/{paths_config.pti_results_keyword}', exist_ok=True)

        images, w_pivots = self._invert_images(w_path_dir)
        self._tune_generator(images, w_pivots)

        if self.use_wandb:
            log_images_from_w(w_pivots, self.G, [image_name for image_name, _ in images])

        # Create the destination first. Otherwise a missing checkpoint
        # directory makes torch.save fail only after all tuning is done.
        os.makedirs(paths_config.checkpoints_dir, exist_ok=True)
        torch.save(self.G,
                   f'{paths_config.checkpoints_dir}/model_{global_config.run_name}_multi_id.pt')

    def _invert_images(self, w_path_dir):
        """Compute a latent pivot for each image yielded by ``self.data_loader``.

        Stops once ``self.image_counter`` reaches
        ``hyperparameters.max_images_to_invert``. The counter is not reset
        here, so it continues from its current value.

        Returns:
            ``(images, w_pivots)``: parallel lists. ``images`` holds
            ``(image_name, image)`` tuples. ``w_pivots`` holds the values
            returned by ``self.get_inversion`` for the same images.
        """
        images = []
        w_pivots = []
        for fname, image in self.data_loader:
            if self.image_counter >= hyperparameters.max_images_to_invert:
                break

            # NOTE(review): presumably the loader yields batches of size 1, so
            # fname is a one-element sequence of names. TODO confirm.
            image_name = fname[0]
            if hyperparameters.first_inv_type == 'w+':
                embedding_dir = f'{w_path_dir}/{paths_config.e4e_results_keyword}/{image_name}'
            else:
                embedding_dir = f'{w_path_dir}/{paths_config.pti_results_keyword}/{image_name}'
            os.makedirs(embedding_dir, exist_ok=True)

            w_pivots.append(self.get_inversion(w_path_dir, image_name, image))
            images.append((image_name, image))
            self.image_counter += 1
        return images, w_pivots

    def _tune_generator(self, images, w_pivots):
        """Optimize ``self.G`` against every (image, pivot) pair.

        Runs ``hyperparameters.max_pti_steps`` outer iterations. Each one
        resets ``self.image_counter`` and takes one optimizer step per pair.
        """
        use_ball_holder = True
        for _ in tqdm(range(hyperparameters.max_pti_steps)):
            self.image_counter = 0

            for (image_name, image), w_pivot in zip(images, w_pivots):
                if self.image_counter >= hyperparameters.max_images_to_invert:
                    break

                real_images_batch = image.to(global_config.device)

                generated_images = self.forward(w_pivot)
                loss, _, _ = self.calc_loss(generated_images, real_images_batch, image_name,
                                            self.G, use_ball_holder, w_pivot)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # Decide the use_ball_holder flag for the *next* step. It is
                # True only when the global step counter (before increment)
                # is a multiple of locality_regularization_interval.
                use_ball_holder = (
                    global_config.training_step % hyperparameters.locality_regularization_interval == 0
                )
                global_config.training_step += 1
                self.image_counter += 1