import yaml
import json
import torch
import wandb
import torchvision
import numpy as np
from torch import nn
from tqdm import tqdm
from abc import abstractmethod
from fractions import Fraction
import matplotlib.pyplot as plt
from dataclasses import dataclass
from torch.distributed import barrier
from torch.utils.data import DataLoader

from gdf import GDF
from gdf import AdaptiveLossWeight
from core import WarpCore
from core.data import setup_webdataset_path, MultiGetter, MultiFilter, Bucketeer
from core.utils import EXPECTED, EXPECTED_TRAIN, update_weights_ema, create_folder_if_necessary

import webdataset as wds
from webdataset.handlers import warn_and_continue

import transformers
transformers.utils.logging.set_verbosity_error()


class DataCore(WarpCore):
    @dataclass(frozen=True)
    class Config(WarpCore.Config):
        image_size: int = EXPECTED_TRAIN
        webdataset_path: str = EXPECTED_TRAIN
        grad_accum_steps: int = EXPECTED_TRAIN
        batch_size: int = EXPECTED_TRAIN
        multi_aspect_ratio: list = None

        captions_getter: list = None
        dataset_filters: list = None

        bucketeer_random_ratio: float = 0.05

    @dataclass(frozen=True)
    class Extras(WarpCore.Extras):
        transforms: torchvision.transforms.Compose = EXPECTED
        clip_preprocess: torchvision.transforms.Compose = EXPECTED

    @dataclass(frozen=True)
    class Models(WarpCore.Models):
        tokenizer: nn.Module = EXPECTED
        text_model: nn.Module = EXPECTED
        image_model: nn.Module = None

    config: Config
    def webdataset_path(self):
        if isinstance(self.config.webdataset_path, str) and (self.config.webdataset_path.strip().startswith('pipe:')
                                                             or self.config.webdataset_path.strip().startswith('file:')):
            return self.config.webdataset_path
        else:
            dataset_path = self.config.webdataset_path
            if isinstance(self.config.webdataset_path, str) and self.config.webdataset_path.strip().endswith('.yml'):
                with open(self.config.webdataset_path, 'r', encoding='utf-8') as file:
                    dataset_path = yaml.safe_load(file)
            return setup_webdataset_path(dataset_path, cache_path=f"{self.config.experiment_id}_webdataset_cache.yml")
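
    # Illustrative forms the `webdataset_path` config value can take (the paths and bucket
    # names below are hypothetical, not taken from this codebase):
    #   webdataset_path: "pipe:aws s3 cp s3://my-bucket/shard-{00000..00999}.tar -"   # streamed through a pipe
    #   webdataset_path: "file:/data/shards/shard-{00000..00999}.tar"                 # local tar shards
    #   webdataset_path: "datasets.yml"                                               # YAML resolved via setup_webdataset_path()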

    def webdataset_preprocessors(self, extras: Extras):
        def identity(x):
            if isinstance(x, bytes):
                x = x.decode('utf-8')
            return x

        # CUSTOM CAPTIONS GETTER -----
        def get_caption(oc, c, p_og=0.05):  # cog_contextual, cog_caption
            if p_og > 0 and np.random.rand() < p_og and len(oc) > 0:
                return identity(oc)
            else:
                return identity(c)

        captions_getter = MultiGetter(rules={
            ('old_caption', 'caption'): lambda oc, c: get_caption(json.loads(oc)['og_caption'], c, p_og=0.05)
        })

        return [
            ('jpg;png',
             torchvision.transforms.ToTensor() if self.config.multi_aspect_ratio is not None else extras.transforms,
             'images'),
            ('txt', identity, 'captions') if self.config.captions_getter is None else (
                self.config.captions_getter[0], eval(self.config.captions_getter[1]), 'captions'),
        ]

    def setup_data(self, extras: Extras) -> WarpCore.Data:
        # SETUP DATASET
        dataset_path = self.webdataset_path()
        preprocessors = self.webdataset_preprocessors(extras)

        handler = warn_and_continue
        dataset = wds.WebDataset(
            dataset_path, resampled=True, handler=handler
        ).select(
            MultiFilter(rules={
                f[0]: eval(f[1]) for f in self.config.dataset_filters
            }) if self.config.dataset_filters is not None else lambda _: True
        ).shuffle(690, handler=handler).decode(
            "pilrgb", handler=handler
        ).to_tuple(
            *[p[0] for p in preprocessors], handler=handler
        ).map_tuple(
            *[p[1] for p in preprocessors], handler=handler
        ).map(lambda x: {p[2]: x[i] for i, p in enumerate(preprocessors)})

        def identity(x):
            return x

        # SETUP DATALOADER
        real_batch_size = self.config.batch_size // (self.world_size * self.config.grad_accum_steps)
        dataloader = DataLoader(
            dataset, batch_size=real_batch_size, num_workers=8, pin_memory=True,
            collate_fn=identity if self.config.multi_aspect_ratio is not None else None
        )
        if self.is_main_node:
            print(f"Training with batch size {self.config.batch_size} ({real_batch_size}/GPU)")

        if self.config.multi_aspect_ratio is not None:
            aspect_ratios = [float(Fraction(f)) for f in self.config.multi_aspect_ratio]
            dataloader_iterator = Bucketeer(dataloader, density=self.config.image_size ** 2, factor=32,
                                            ratios=aspect_ratios, p_random_ratio=self.config.bucketeer_random_ratio,
                                            interpolate_nearest=False)  # , use_smartcrop=True
        else:
            dataloader_iterator = iter(dataloader)

        return self.Data(dataset=dataset, dataloader=dataloader, iterator=dataloader_iterator)
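
    # Note: `batch_size` in the config is the global batch size. With illustrative numbers,
    # batch_size=512, 8 GPUs and grad_accum_steps=2 give each GPU
    # 512 // (8 * 2) = 32 samples per forward pass, with an optimizer step every 2 passes.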

    def get_conditions(self, batch: dict, models: Models, extras: Extras, is_eval=False, is_unconditional=False,
                       eval_image_embeds=False, return_fields=None):
        if return_fields is None:
            return_fields = ['clip_text', 'clip_text_pooled', 'clip_img']

        captions = batch.get('captions', None)
        images = batch.get('images', None)
        batch_size = len(captions)

        text_embeddings = None
        text_pooled_embeddings = None
        if 'clip_text' in return_fields or 'clip_text_pooled' in return_fields:
            if is_eval:
                if is_unconditional:
                    captions_unpooled = ["" for _ in range(batch_size)]
                else:
                    captions_unpooled = captions
            else:
                rand_idx = np.random.rand(batch_size) > 0.05
                captions_unpooled = [str(c) if keep else "" for c, keep in zip(captions, rand_idx)]
            clip_tokens_unpooled = models.tokenizer(captions_unpooled, truncation=True, padding="max_length",
                                                    max_length=models.tokenizer.model_max_length,
                                                    return_tensors="pt").to(self.device)
            text_encoder_output = models.text_model(**clip_tokens_unpooled, output_hidden_states=True)
            if 'clip_text' in return_fields:
                text_embeddings = text_encoder_output.hidden_states[-1]
            if 'clip_text_pooled' in return_fields:
                text_pooled_embeddings = text_encoder_output.text_embeds.unsqueeze(1)

        image_embeddings = None
        if 'clip_img' in return_fields:
            image_embeddings = torch.zeros(batch_size, 768, device=self.device)
            if images is not None:
                images = images.to(self.device)
                if is_eval:
                    if not is_unconditional and eval_image_embeds:
                        image_embeddings = models.image_model(extras.clip_preprocess(images)).image_embeds
                else:
                    rand_idx = np.random.rand(batch_size) > 0.9
                    if any(rand_idx):
                        image_embeddings[rand_idx] = models.image_model(extras.clip_preprocess(images[rand_idx])).image_embeds
            image_embeddings = image_embeddings.unsqueeze(1)

        return {
            'clip_text': text_embeddings,
            'clip_text_pooled': text_pooled_embeddings,
            'clip_img': image_embeddings
        }
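
    # Shapes of the returned conditioning tensors (batch size B; exact hidden sizes depend on
    # the CLIP checkpoints plugged in as `tokenizer`/`text_model`/`image_model`):
    #   'clip_text':        [B, seq_len, hidden_dim]  -- last hidden state of the text encoder
    #   'clip_text_pooled': [B, 1, proj_dim]          -- pooled text embedding, unsqueezed to a length-1 sequence
    #   'clip_img':         [B, 1, 768]               -- image embedding (zeros when dropped or when no image is given)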


class TrainingCore(DataCore, WarpCore):
    @dataclass(frozen=True)
    class Config(DataCore.Config, WarpCore.Config):
        updates: int = EXPECTED_TRAIN
        backup_every: int = EXPECTED_TRAIN
        save_every: int = EXPECTED_TRAIN

        # EMA UPDATE
        ema_start_iters: int = None
        ema_iters: int = None
        ema_beta: float = None

        use_fsdp: bool = None

    @dataclass()  # not frozen, meaning the fields are mutable. Doesn't support EXPECTED
    class Info(WarpCore.Info):
        ema_loss: float = None
        adaptive_loss: dict = None

    @dataclass(frozen=True)
    class Models(WarpCore.Models):
        generator: nn.Module = EXPECTED
        generator_ema: nn.Module = None  # optional

    @dataclass(frozen=True)
    class Optimizers(WarpCore.Optimizers):
        generator: any = EXPECTED

    @dataclass(frozen=True)
    class Extras(WarpCore.Extras):
        gdf: GDF = EXPECTED
        sampling_configs: dict = EXPECTED

    info: Info
    config: Config

    @abstractmethod
    def forward_pass(self, data: WarpCore.Data, extras: WarpCore.Extras, models: Models):
        raise NotImplementedError("This method needs to be overridden")

    @abstractmethod
    def backward_pass(self, update, loss, loss_adjusted, models: Models, optimizers: Optimizers,
                      schedulers: WarpCore.Schedulers):
        raise NotImplementedError("This method needs to be overridden")

    @abstractmethod
    def models_to_save(self) -> list:
        raise NotImplementedError("This method needs to be overridden")

    @abstractmethod
    def encode_latents(self, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
        raise NotImplementedError("This method needs to be overridden")

    @abstractmethod
    def decode_latents(self, latents: torch.Tensor, batch: dict, models: Models, extras: Extras) -> torch.Tensor:
        raise NotImplementedError("This method needs to be overridden")
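
    # Contract implied by how train() and sample() below use the abstract methods:
    #   - forward_pass() returns (loss, loss_adjusted); loss_adjusted is the value that gets backpropagated
    #   - backward_pass() takes an `update` flag (True on real optimizer steps after gradient accumulation)
    #     and returns the gradient norm
    #   - models_to_save() returns the keys of `Models` that should be checkpointed, e.g. ['generator', 'generator_ema']
    #   - encode_latents()/decode_latents() map between image batches and the latent space the generator operates in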

    def train(self, data: WarpCore.Data, extras: WarpCore.Extras, models: Models, optimizers: Optimizers,
              schedulers: WarpCore.Schedulers):
        start_iter = self.info.iter + 1
        max_iters = self.config.updates * self.config.grad_accum_steps
        if self.is_main_node:
            print(f"STARTING AT STEP: {start_iter}/{max_iters}")

        pbar = tqdm(range(start_iter, max_iters + 1)) if self.is_main_node else range(start_iter, max_iters + 1)  # <--- DDP
        if 'generator' in self.models_to_save():
            models.generator.train()
        for i in pbar:
            # FORWARD PASS
            loss, loss_adjusted = self.forward_pass(data, extras, models)

            # BACKWARD PASS
            grad_norm = self.backward_pass(
                i % self.config.grad_accum_steps == 0 or i == max_iters, loss, loss_adjusted,
                models, optimizers, schedulers
            )
            self.info.iter = i

            # UPDATE EMA
            if models.generator_ema is not None and i % self.config.ema_iters == 0:
                update_weights_ema(
                    models.generator_ema, models.generator,
                    beta=(self.config.ema_beta if i > self.config.ema_start_iters else 0)
                )

            # UPDATE LOSS METRICS
            self.info.ema_loss = loss.mean().item() if self.info.ema_loss is None else self.info.ema_loss * 0.99 + loss.mean().item() * 0.01
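            # The 0.99/0.01 split is an exponential moving average over roughly the last
            # 1 / (1 - 0.99) = 100 iterations, so the reported 'loss' is a smoothed value.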
            # Alert (from the main node, when wandb is enabled) if the loss or gradient norm became NaN
            if self.is_main_node and self.config.wandb_project is not None and (
                    np.isnan(loss.mean().item()) or np.isnan(grad_norm.item())):
                wandb.alert(
                    title=f"NaN value encountered in training run {self.info.wandb_run_id}",
                    text=f"Loss {loss.mean().item()} - Grad Norm {grad_norm.item()}. Run {self.info.wandb_run_id}",
                    wait_duration=60 * 30
                )

            if self.is_main_node:
                logs = {
                    'loss': self.info.ema_loss,
                    'raw_loss': loss.mean().item(),
                    'grad_norm': grad_norm.item(),
                    'lr': optimizers.generator.param_groups[0]['lr'] if optimizers.generator is not None else 0,
                    'total_steps': self.info.total_steps,
                }
                pbar.set_postfix(logs)
                if self.config.wandb_project is not None:
                    wandb.log(logs)
            if i == 1 or i % (self.config.save_every * self.config.grad_accum_steps) == 0 or i == max_iters:
                # SAVE AND CHECKPOINT STUFF
                if np.isnan(loss.mean().item()):
                    if self.is_main_node and self.config.wandb_project is not None:
                        tqdm.write("Skipping sampling & checkpoint because the loss is NaN")
                        wandb.alert(title=f"Skipping sampling & checkpoint for training run {self.info.wandb_run_id}",
                                    text=f"Skipping sampling & checkpoint at {self.info.total_steps} iters for training run {self.info.wandb_run_id} because the loss is NaN")
                else:
                    if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight):
                        self.info.adaptive_loss = {
                            'bucket_ranges': extras.gdf.loss_weight.bucket_ranges.tolist(),
                            'bucket_losses': extras.gdf.loss_weight.bucket_losses.tolist(),
                        }
                    self.save_checkpoints(models, optimizers)
                    if self.is_main_node:
                        create_folder_if_necessary(f'{self.config.output_path}/{self.config.experiment_id}/')
                    self.sample(models, data, extras)

    def save_checkpoints(self, models: Models, optimizers: Optimizers, suffix=None):
        barrier()
        suffix = '' if suffix is None else suffix
        self.save_info(self.info, suffix=suffix)
        models_dict = models.to_dict()
        optimizers_dict = optimizers.to_dict()
        for key in self.models_to_save():
            model = models_dict[key]
            if model is not None:
                self.save_model(model, f"{key}{suffix}", is_fsdp=self.config.use_fsdp)
        for key in optimizers_dict:
            optimizer = optimizers_dict[key]
            if optimizer is not None:
                self.save_optimizer(optimizer, f'{key}_optim{suffix}',
                                    fsdp_model=models_dict[key] if self.config.use_fsdp else None)
        if suffix == '' and self.info.total_steps > 1 and self.info.total_steps % self.config.backup_every == 0:
            self.save_checkpoints(models, optimizers, suffix=f"_{self.info.total_steps // 1000}k")
        torch.cuda.empty_cache()

    def sample(self, models: Models, data: WarpCore.Data, extras: Extras):
        if 'generator' in self.models_to_save():
            models.generator.eval()
        with torch.no_grad():
            batch = next(data.iterator)
            conditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False)
            unconditions = self.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)

            latents = self.encode_latents(batch, models, extras)
            noised, _, _, logSNR, noise_cond, _ = extras.gdf.diffuse(latents, shift=1, loss_shift=1)

            with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                pred = models.generator(noised, noise_cond, **conditions)
                pred = extras.gdf.undiffuse(noised, logSNR, pred)[0]

            with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                *_, (sampled, _, _) = extras.gdf.sample(
                    models.generator, conditions,
                    latents.shape, unconditions, device=self.device, **extras.sampling_configs
                )

                if models.generator_ema is not None:
                    *_, (sampled_ema, _, _) = extras.gdf.sample(
                        models.generator_ema, conditions,
                        latents.shape, unconditions, device=self.device, **extras.sampling_configs
                    )
                else:
                    sampled_ema = sampled

            if self.is_main_node:
                noised_images = torch.cat(
                    [self.decode_latents(noised[i:i + 1], batch, models, extras) for i in range(len(noised))], dim=0)
                pred_images = torch.cat(
                    [self.decode_latents(pred[i:i + 1], batch, models, extras) for i in range(len(pred))], dim=0)
                sampled_images = torch.cat(
                    [self.decode_latents(sampled[i:i + 1], batch, models, extras) for i in range(len(sampled))], dim=0)
                sampled_images_ema = torch.cat(
                    [self.decode_latents(sampled_ema[i:i + 1], batch, models, extras) for i in range(len(sampled_ema))],
                    dim=0)

                images = batch['images']
                if images.size(-1) != noised_images.size(-1) or images.size(-2) != noised_images.size(-2):
                    images = nn.functional.interpolate(images, size=noised_images.shape[-2:], mode='bicubic')

                collage_img = torch.cat([
                    torch.cat([i for i in images.cpu()], dim=-1),
                    torch.cat([i for i in noised_images.cpu()], dim=-1),
                    torch.cat([i for i in pred_images.cpu()], dim=-1),
                    torch.cat([i for i in sampled_images.cpu()], dim=-1),
                    torch.cat([i for i in sampled_images_ema.cpu()], dim=-1),
                ], dim=-2)
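                # The collage stacks one row per stage (top to bottom: original images, decoded
                # noised latents, single-step prediction, full sampling, EMA sampling), with the
                # batch laid out left to right along each row.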
                torchvision.utils.save_image(collage_img, f'{self.config.output_path}/{self.config.experiment_id}/{self.info.total_steps:06d}.jpg')
                torchvision.utils.save_image(collage_img, f'{self.config.experiment_id}_latest_output.jpg')

                captions = batch['captions']
                if self.config.wandb_project is not None:
                    log_data = [
                        [captions[i]] + [wandb.Image(sampled_images[i])] + [wandb.Image(sampled_images_ema[i])] + [wandb.Image(images[i])]
                        for i in range(len(images))
                    ]
                    log_table = wandb.Table(data=log_data, columns=["Captions", "Sampled", "Sampled EMA", "Orig"])
                    wandb.log({"Log": log_table})
                    if isinstance(extras.gdf.loss_weight, AdaptiveLossWeight):
                        plt.plot(extras.gdf.loss_weight.bucket_ranges, extras.gdf.loss_weight.bucket_losses[:-1])
                        plt.ylabel('Raw Loss')
                        plt.xlabel('LogSNR')
                        wandb.log({"Loss/LogSNR": plt})

        if 'generator' in self.models_to_save():
            models.generator.train()