import os
from collections import deque
from itertools import combinations
from os.path import join

import hydra
import numpy as np
import pytorch_lightning as pl
import torch
import torch.distributed as dist
import torch.nn.functional as F
from omegaconf import DictConfig, OmegaConf
from peft import get_peft_model, LoraConfig
from pytorch_lightning import Trainer
from pytorch_lightning import seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities import grad_norm
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, SequentialLR, LambdaLR
from torchmetrics.functional.classification import binary_average_precision
from huggingface_hub import PyTorchModelHubMixin

from DenseAV.denseav.aggregators import get_aggregator
from DenseAV.denseav.aligners import get_aligner, ProgressiveGrowing
from DenseAV.denseav.constants import *
from DenseAV.denseav.data.AVDatasets import AVDataModule
from DenseAV.denseav.shared import flatten_preds, GatherLayer, \
    get_image_featurizer, get_audio_featurizer, RollingAvg, create_model_from_cfg

torch.multiprocessing.set_sharing_strategy('file_system')
def _imposter_indices_helper(true_indices: torch.Tensor, samples: torch.Tensor):
    mask = (true_indices == samples).to(torch.int64)
    n = mask.shape[0]
    if not mask.any():
        return samples
    else:
        new_samples = torch.randint(0, n, size=(n,), device=true_indices.device)
        comb_samples = mask * new_samples + (1 - mask) * samples
        return _imposter_indices_helper(true_indices, comb_samples)


def imposter_indices(n, device):
    return _imposter_indices_helper(
        torch.arange(0, n, device=device),
        torch.randint(0, n, size=(n,), device=device))
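# Note: imposter_indices draws, for each row i, a random "imposter" index j != i by
# rejection resampling: any sampled index that collides with its own row is redrawn
# (only the colliding rows are resampled) until no collisions remain. For example,
# with n=4 a valid result could be [2, 0, 3, 1].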
def get_sim_per_row(image_outputs, audio_outputs, n_frames, sim_type):
    max_t = audio_outputs.shape[-1]
    oh = F.one_hot(n_frames - 1, num_classes=max_t)
    audio_mask = 1 - torch.cumsum(oh, dim=1)
    audio_mask = F.pad(audio_mask, [1, 0], value=1)[:, :max_t].to(audio_outputs.dtype)
    full_sim = torch.einsum("bct,bchw->bthw", audio_outputs, image_outputs)
    expanded_am = audio_mask.unsqueeze(-1).unsqueeze(-1)

    if sim_type.endswith("mi"):
        offset = 10 * (full_sim.max() - full_sim.min())
        full_sim = (full_sim - ((1 - expanded_am) * offset)).max(1, keepdim=True).values

    if sim_type.startswith("mi"):
        full_sim = full_sim.max(-1, keepdim=True).values.max(-2, keepdim=True).values

    if sim_type.endswith("sa"):
        full_sim = (full_sim * (expanded_am / expanded_am.sum(1, keepdim=True).clamp_min(1))).sum(1, keepdim=True)

    return full_sim.mean(dim=[1, 2, 3])
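# The sim_type string controls how the per-pair similarity volume (batch x time x h x w)
# is reduced: an "mi" suffix max-pools over audio time (padded frames are pushed down by
# a large negative offset first), an "mi" prefix additionally max-pools over image space,
# and an "sa" suffix averages over the unmasked audio frames instead. Whatever dimensions
# remain are mean-pooled into one scalar per row.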
def sampled_margin_rank_loss(image_outputs, audio_outputs, n_frames, sim_type, margin=1.):
    """
    Computes the triplet margin ranking loss for each anchor image/caption pair.
    The impostor image/caption is randomly sampled from the minibatch.
    """
    assert (image_outputs.dim() == 4)
    assert (audio_outputs.dim() == 3)
    n = image_outputs.size(0)
    imp_ind_i = imposter_indices(n, image_outputs.device)
    imp_ind_a = imposter_indices(n, image_outputs.device)
    true_sim = get_sim_per_row(image_outputs, audio_outputs, n_frames, sim_type)
    imp_sim_i = get_sim_per_row(image_outputs[imp_ind_i], audio_outputs, n_frames, sim_type)
    imp_sim_a = get_sim_per_row(image_outputs, audio_outputs[imp_ind_a], n_frames[imp_ind_a], sim_type)
    a2i_loss = (margin + imp_sim_i - true_sim).clamp_min(0)
    i2a_loss = (margin + imp_sim_a - true_sim).clamp_min(0)
    return (a2i_loss + i2a_loss).mean() / 2
class SimilarityCalibrator(torch.nn.Module):
    def __init__(self, cal_init, max_w=100, min_w=.01, subtract_mean=True, use_bias=False):
        super().__init__()
        self.max_w = max_w
        self.min_w = min_w
        self.w = torch.nn.Parameter(torch.tensor([cal_init]).log())
        self.use_bias = use_bias
        if self.use_bias:
            self.b = torch.nn.Parameter(torch.tensor([0.0]))
        self.subtract_mean = subtract_mean

    def get_w(self):
        return torch.exp(self.w).clamp_max(self.max_w).clamp_min(self.min_w)

    def forward(self, x):
        sims = self.get_w() * x
        if self.use_bias:
            sims = sims + self.b
        if self.subtract_mean:
            return sims - sims.mean()
        else:
            return sims
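# SimilarityCalibrator is a learned temperature (and optional bias) applied to the raw
# similarities before the contrastive loss, in the spirit of CLIP's logit scale. The
# weight is stored in log space so it stays positive, and get_w() clamps it to
# [min_w, max_w] to keep training stable.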
class SpatialDropout(torch.nn.Module):
    def __init__(self, p, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.p = p

    def forward(self, x):
        # Zero out entire spatial locations (shared across all channels) with probability p,
        # but only during training; no rescaling is applied.
        b, c, h, w = x.shape
        dropout = torch.rand((b, 1, h, w), dtype=x.dtype, device=x.device) > self.p
        if self.training:
            return x * dropout
        else:
            return x
class LitAVAligner(pl.LightningModule,
                   PyTorchModelHubMixin,
                   repo_url="https://github.com/mhamilton723/DenseAV",
                   license="mit",
                   tags=["denseav"]):

    def __init__(self,
                 code_dim,
                 image_model_type,
                 image_model_token_type,
                 image_aligner_type,
                 image_pool_width,
                 audio_model_type,
                 audio_aligner_type,
                 audio_pool_width,
                 audio_lora,
                 audio_lora_rank,
                 image_lora,
                 image_lora_rank,
                 gradient_clipping,
                 learn_audio_cls,
                 silence_l1,
                 silence_l2,
                 tv_weight,
                 nonneg_sim,
                 nonneg_pressure,
                 pretrain_lr,
                 lr,
                 lr_warmup,
                 lr_schedule,
                 lr_cycle_length,
                 optimizer,
                 gather_tensors,
                 sim_agg_type,
                 sim_agg_heads,
                 sim_use_cls,
                 disentangle_weight,
                 norm_vectors,
                 cal_init,
                 cal_balance_weight,
                 loss_type,
                 loss_margin,
                 mask_silence,
                 finetune_image_model,
                 finetune_audio_model,
                 use_cached_embs,
                 output_root,
                 neg_audio,
                 neg_audio_weight,
                 head_agg,
                 adaptive_clipping,
                 specialization_weight,
                 spatial_dropout,
                 channel_dropout,
                 mixup_weight,
                 memory_buffer_size,
                 loss_leak,
                 ):
        super().__init__()
        self.code_dim = code_dim
        self.image_model_type = image_model_type
        self.image_model_token_type = image_model_token_type
        self.image_aligner_type = image_aligner_type
        self.image_pool_width = image_pool_width
        self.audio_model_type = audio_model_type
        self.audio_aligner_type = audio_aligner_type
        self.audio_pool_width = audio_pool_width
        self.gradient_clipping = gradient_clipping
        self.learn_audio_cls = learn_audio_cls
        self.silence_l1 = silence_l1
        self.silence_l2 = silence_l2
        self.tv_weight = tv_weight
        self.nonneg_sim = nonneg_sim
        self.nonneg_pressure = nonneg_pressure
        self.pretrain_lr = pretrain_lr
        self.lr = lr
        self.lr_warmup = lr_warmup
        self.lr_schedule = lr_schedule
        self.lr_cycle_length = lr_cycle_length
        self.optimizer = optimizer
        self.gather_tensors = gather_tensors
        self.sim_agg_type = sim_agg_type
        self.sim_agg_heads = sim_agg_heads
        self.sim_use_cls = sim_use_cls
        self.disentangle_weight = disentangle_weight
        self.norm_vectors = norm_vectors
        self.cal_init = cal_init
        self.cal_balance_weight = cal_balance_weight
        self.loss_type = loss_type
        self.loss_margin = loss_margin
        self.mask_silence = mask_silence
        self.finetune_image_model = finetune_image_model
        self.finetune_audio_model = finetune_audio_model
        self.use_cached_embs = use_cached_embs
        self.output_root = output_root
        self.audio_lora = audio_lora
        self.audio_lora_rank = audio_lora_rank
        self.image_lora = image_lora
        self.image_lora_rank = image_lora_rank
        self.neg_audio = neg_audio
        self.neg_audio_weight = neg_audio_weight
        self.head_agg = head_agg
        self.adaptive_clipping = adaptive_clipping
        self.specialization_weight = specialization_weight
        self.spatial_dropout = spatial_dropout
        self.channel_dropout = channel_dropout
        self.mixup_weight = mixup_weight
        self.memory_buffer_size = memory_buffer_size
        self.memory_buffer = deque(maxlen=self.memory_buffer_size)
        self.loss_leak = loss_leak
        self.full_train = False  # Added by me

        if self.audio_model_type in {"audiomae", "audiomae-finetuned", "cavmae", "cavmae-mixed", "imagebind"}:
            self.audio_input = "spec"
        elif self.audio_model_type == "davenet":
            self.audio_input = "davenet_spec"
        elif self.audio_model_type == "fnac":
            self.audio_input = "fnac_spec"
        else:
            self.audio_input = "audio"

        extra_model_args = dict(output_root=output_root)

        self.image_model, _, self.image_feat_dim = get_image_featurizer(
            image_model_type, token_type=self.image_model_token_type, **extra_model_args)
        self.image_model.eval()
        if not self.finetune_image_model:
            for param in self.image_model.parameters():
                param.requires_grad = False

        if image_model_type in {"cavmae", "cavmae-mixed", "imagebind", "fnac"}:
            extra_model_args["model"] = self.image_model.model

        if use_cached_embs:
            _, self.audio_feat_dim = get_audio_featurizer(audio_model_type, **extra_model_args)
        else:
            self.audio_model, self.audio_feat_dim = get_audio_featurizer(audio_model_type, **extra_model_args)
            self.audio_model.eval()
            if not self.finetune_audio_model:
                for param in self.audio_model.parameters():
                    param.requires_grad = False

        if self.image_lora:
            if self.image_model_type in {"sam", "dino8", "dinov2", "cavmae", "cavmae-mixed"}:
                target_modules = ["qkv"]
            elif self.image_model_type == "clip":
                target_modules = ["out_proj"]
            elif self.image_model_type == "imagebind":
                target_modules = ["out_proj", "fc1", "fc2"]
            else:
                target_modules = ["q", "k", "v"]

            peft_config = LoraConfig(
                target_modules=target_modules,
                inference_mode=False,
                r=image_lora_rank,
                lora_alpha=32,
                lora_dropout=0.1
            )
            self.image_model = get_peft_model(self.image_model, peft_config)
            self.image_model.print_trainable_parameters()

        if self.audio_lora:
            if self.audio_model_type == "hubert":
                target_modules = ["q_proj", "k_proj", "v_proj"]
            else:
                target_modules = ["q", "k", "v"]

            peft_config = LoraConfig(
                inference_mode=False,
                target_modules=target_modules,
                r=audio_lora_rank,
                lora_alpha=32,
                lora_dropout=0.1
            )
            self.audio_model = get_peft_model(self.audio_model, peft_config)
            self.audio_model.print_trainable_parameters()

        shared_aligner_args = dict(out_dim=self.code_dim)
        self.audio_aligner = get_aligner(
            self.audio_aligner_type, self.audio_feat_dim, **shared_aligner_args)
        self.image_aligner = get_aligner(
            self.image_aligner_type, self.image_feat_dim, **shared_aligner_args)

        if self.loss_type == "nce":
            self.sim_cal = SimilarityCalibrator(self.cal_init, subtract_mean=True, use_bias=False)
        else:
            self.sim_cal = SimilarityCalibrator(self.cal_init, subtract_mean=False, use_bias=True)

        if self.learn_audio_cls:
            self.audio_cls = torch.nn.Parameter(torch.randn(self.audio_feat_dim))

        if self.spatial_dropout > 0.0:
            self.spatial_dropout_layer = SpatialDropout(self.spatial_dropout)

        if self.channel_dropout > 0.0:
            self.channel_dropout_layer = torch.nn.Dropout2d(self.channel_dropout)

        self.sim_agg = get_aggregator(
            self.sim_agg_type,
            self.nonneg_sim,
            self.mask_silence,
            self.sim_agg_heads,
            self.head_agg,
            self.sim_use_cls,
            dim=self.image_feat_dim
        )

        self.hparams_logged = False
        self.rolling_avg = RollingAvg(50)
        self.grad_avg = RollingAvg(50, nonzero=True)
        self.save_hyperparameters()
    def set_full_train(self, full_train):
        self.full_train = full_train

    def prep_feats(self, feats, is_audio):
        if not is_audio and self.training and self.image_pool_width > 1:
            feats = torch.nn.AvgPool2d(self.image_pool_width)(feats)

        if is_audio and self.training and self.audio_pool_width > 1:
            feats = torch.nn.AvgPool2d((1, self.audio_pool_width))(feats)

        if self.norm_vectors:
            feats = F.normalize(feats, dim=1)

        return feats
    def on_before_optimizer_step(self, optimizer, optimizer_idx):
        norms = grad_norm(self, norm_type=2)
        avg_grads = self.grad_avg.get_all()
        params = {
            f"grad_2.0_norm/{name}": p
            for name, p in self.named_parameters()
            if p.grad is not None
        }

        if self.adaptive_clipping:
            for k in norms.keys():
                if k in params:
                    avg_grad = max(avg_grads.get(k, norms[k]), 1e-5)
                    if self.global_step > 10 and norms[k] > avg_grad * 5:
                        print(f"Bad grad for {k}: {norms[k]} scaling to {avg_grad * 5}")
                        torch.nn.utils.clip_grad_norm_(params[k], avg_grad * 5)
                        norms[k] = avg_grad * 5
                    if norms[k] > self.gradient_clipping:
                        # print(f"Bad grad for {k}: {norms[k]} scaling to {self.gradient_clipping}")
                        torch.nn.utils.clip_grad_norm_(params[k], self.gradient_clipping)

        # self.grad_avg.add_all(norms)
        # self.log_dict(norms)
    def interpolate_mask(self, mask, target_length, discrete):
        b, t = mask.shape
        mask = F.interpolate(mask.reshape(b, 1, 1, t), (1, target_length), mode="bilinear") \
            .reshape(b, target_length)

        if discrete:
            mask = mask > 0.01
            sums = mask.sum(1)
            all_zeros = torch.where(sums == 0)[0]
            if len(all_zeros) > 0:
                print("Fixing a bad mask")
                for entry in all_zeros:
                    mask[entry, torch.randint(0, target_length - 1, size=())] = True

        return mask
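    # interpolate_mask resizes a per-sample temporal mask to the target length (here, the
    # pooled audio feature time axis). With discrete=True it thresholds the result and, if
    # a row ends up all False, flips one random position back on so downstream masked
    # reductions never operate on an empty mask.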
    def forward_audio(self, batch):
        if self.use_cached_embs:
            audio_feats = batch["audio_emb"]
            if "audio_cls" in batch:
                audio_cls = batch["audio_cls"]
            else:
                audio_cls = None
        else:
            audio = batch[self.audio_input]
            if self.full_train:
                audio_feats, audio_cls = self.audio_model(audio, include_cls=True)
            else:
                with torch.no_grad():
                    audio_feats, audio_cls = self.audio_model(audio, include_cls=True)

        mask = batch[AUDIO_MASK] if AUDIO_MASK in batch else torch.ones_like(audio)
        pos_mask = batch[AUDIO_POS_MASK] if AUDIO_POS_MASK in batch else torch.ones_like(audio)

        if self.learn_audio_cls:
            assert audio_cls is None
            audio_cls = torch.broadcast_to(self.audio_cls.unsqueeze(0), (audio_feats.shape[0], audio_feats.shape[1]))

        aligned_audio_feats, aligned_audio_cls = self.audio_aligner(audio_feats, audio_cls)

        if self.channel_dropout > 0.0:
            aligned_audio_feats = self.channel_dropout_layer(aligned_audio_feats)

        aligned_audio_feats = self.prep_feats(aligned_audio_feats, is_audio=True)
        audio_mask = self.interpolate_mask(mask, aligned_audio_feats.shape[-1], True)
        audio_pos_mask = self.interpolate_mask(pos_mask, aligned_audio_feats.shape[-1], False)

        ret = {
            AUDIO_MASK: audio_mask,
            AUDIO_POS_MASK: audio_pos_mask,
            AUDIO_FEATS: aligned_audio_feats,
        }

        if aligned_audio_cls is not None:
            ret[AUDIO_CLS] = aligned_audio_cls

        return ret
    # @autocast(device_type="cuda", enabled=False)
    def forward_image(self, batch, max_batch_size=None):
        with torch.no_grad():
            image = batch[IMAGE_INPUT]
            b, nf, c, h, w = image.shape
            image = image.reshape(b * nf, c, h, w)

        if max_batch_size is None:
            max_batch_size = image.shape[0]

        chunks = [image[i:i + max_batch_size] for i in range(0, image.shape[0], max_batch_size)]

        all_image_feats = []
        all_image_cls = []
        for chunk in chunks:
            if self.full_train:
                image_feats, image_cls = self.image_model(chunk, include_cls=True)
            else:
                with torch.no_grad():
                    image_feats, image_cls = self.image_model(chunk, include_cls=True)
            aligned_image_feats, aligned_image_cls = self.image_aligner(image_feats, image_cls)
            all_image_feats.append(aligned_image_feats)
            all_image_cls.append(aligned_image_cls)

        # Stitch the chunks back together
        aligned_image_feats = torch.cat(all_image_feats, dim=0)
        aligned_image_cls = torch.cat(all_image_cls, dim=0)

        if self.channel_dropout > 0.0:
            aligned_image_feats = self.channel_dropout_layer(aligned_image_feats)

        if self.spatial_dropout > 0.0:
            aligned_image_feats = self.spatial_dropout_layer(aligned_image_feats)

        aligned_image_feats = self.prep_feats(aligned_image_feats, is_audio=False)
        ret = {IMAGE_FEATS: aligned_image_feats}

        if IMAGE_MASK in batch:
            with torch.no_grad():
                mask = batch[IMAGE_MASK]
                mask = mask.reshape(b * nf, 1, h, w)
                b, c, h, w = aligned_image_feats.shape
                mask = F.adaptive_avg_pool2d(mask.to(aligned_image_feats), output_size=(h, w))
                ret[IMAGE_MASK] = mask

        if aligned_image_cls is not None:
            ret[IMAGE_CLS] = aligned_image_cls

        return ret
    def forward(self, batch):
        audio_feat_dict = self.forward_audio(batch)
        image_feat_dict = self.forward_image(batch)
        return {**image_feat_dict, **audio_feat_dict}
    def contrast_loss(self, sims):
        b = sims.shape[0]
        sims = sims - torch.eye(b, b, device=sims.device) * self.loss_margin

        sims_1 = sims
        sims_2 = sims.permute(1, 0)

        if self.loss_leak > 0.0:
            id = torch.eye(sims_1.shape[0], sims_1.shape[1], device=sims.device, dtype=sims.dtype)
            label_mask = id * (1 - self.loss_leak)
            label_mask += (1 - id) * self.loss_leak / (sims_1.shape[0] - 1)
            label_mask /= label_mask.sum(dim=1, keepdim=True)
        else:
            label_mask = torch.eye(sims_1.shape[0], sims_1.shape[1], device=sims.device, dtype=sims.dtype)

        labels = torch.arange(0, sims.shape[0], device=sims.device)
        self.rolling_avg.add(f"acc/1", (sims.argmax(dim=1) == labels).to(sims).mean())
        self.rolling_avg.add(f"acc/2", (sims.argmax(dim=0) == labels).to(sims).mean())

        if self.loss_type == "margin":
            margin_loss_tensor = (sims - torch.diag(sims)).clamp_min(0)
            margin_loss = margin_loss_tensor.mean()
            self.rolling_avg.add(f"loss/frac_nonzero", (margin_loss_tensor > 0).to(sims).mean())
            self.rolling_avg.add(f"loss/margin", margin_loss)
            return margin_loss
        elif self.loss_type == "ce":
            ce_loss = 1 / 2 * F.cross_entropy(sims_1, labels) + \
                      1 / 2 * F.cross_entropy(sims_2, labels)
            self.rolling_avg.add(f"loss/ce", ce_loss)
            return ce_loss
        elif self.loss_type == "bce":
            bce_loss = F.binary_cross_entropy_with_logits(sims_1.flatten(), label_mask.flatten())
            self.rolling_avg.add(f"loss/bce", bce_loss)
            return bce_loss
        elif self.loss_type == "nce":
            nce_loss = 1 / 2 * (-F.log_softmax(sims_1, dim=-1) * label_mask).sum(1).mean() + \
                       1 / 2 * (-F.log_softmax(sims_2, dim=-1) * label_mask).sum(1).mean()
            self.rolling_avg.add(f"loss/nce", nce_loss)
            return nce_loss
        else:
            raise ValueError(f"Unknown loss type {self.loss_type}")
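    # Notes on contrast_loss: the diagonal is first shifted down by loss_margin so positives
    # must beat negatives by a margin before they stop contributing. loss_leak > 0 turns the
    # one-hot targets into smoothed labels that spread a small amount of probability mass
    # onto the off-diagonal (negative) pairs, which acts like label smoothing for the
    # contrastive objective; the "nce" and "bce" branches consume these soft targets.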
    def loss(self, preds):
        image_feats = preds[IMAGE_FEATS]
        audio_feats = preds[AUDIO_FEATS]
        audio_mask = preds[AUDIO_MASK]
        image_mask = preds[IMAGE_MASK]
        audio_pos_mask = preds[AUDIO_POS_MASK]

        if DATA_SOURCE in preds:
            source = preds[DATA_SOURCE].to(torch.int64)
        else:
            source = None

        uncal_sims = self.sim_agg(preds, agg_heads=True)
        sims = self.sim_cal(uncal_sims)

        _mask = 1 - torch.eye(sims.shape[0], device=sims.device)
        self.log(f"sim/pos", torch.diag(sims).mean())
        self.log(f"sim/neg", (sims * _mask).sum() / (_mask.sum()))
        self.log(f"sim/uncal_pos", torch.diag(uncal_sims).mean())
        self.log(f"sim/uncal_neg", (uncal_sims * _mask).sum() / (_mask.sum()))

        b, c, h, w = image_feats.shape
        b, c, f, t = audio_feats.shape
        n_samples = 250
        nh = self.sim_agg_heads
        image_feats_by_head = image_feats.reshape(b, self.sim_agg_heads, c // nh, h, w)
        audio_feats_by_head = audio_feats.reshape(b, self.sim_agg_heads, c // nh, f, t)

        def maybe_clamp(t):
            return t.clamp_min(0) if self.nonneg_sim else t

        paired_sim_raw = self.sim_agg.get_pairwise_sims(preds, raw=True, agg_sim=False, agg_heads=False)
        paired_sim = maybe_clamp(paired_sim_raw)

        loss = 0.0

        if self.nonneg_pressure:
            afb, afk, afc, aff, aft = audio_feats_by_head.shape
            ifb, ifk, ifc, ifh, ifw = image_feats_by_head.shape
            assert (afb == ifb)
            device = audio_feats_by_head.device

            random_b = torch.randint(0, afb, size=(n_samples,), device=device)
            random_t = torch.randint(0, aft, size=(n_samples,), device=device)
            random_f = torch.randint(0, aff, size=(n_samples,), device=device)
            random_h = torch.randint(0, ifh, size=(n_samples,), device=device)
            random_w = torch.randint(0, ifw, size=(n_samples,), device=device)

            random_audio_feats = audio_feats_by_head[random_b, :, :, random_f, random_t]
            random_image_feats = image_feats_by_head[random_b, :, :, random_h, random_w]
            random_sim_raw = torch.einsum("bkc,dkc->bdk", random_audio_feats, random_image_feats)

            nonneg_loss = random_sim_raw.clamp_max(0).square().mean()
            self.rolling_avg.add(f"loss/nonneg", nonneg_loss)
            loss += nonneg_loss * self.nonneg_pressure

        if self.silence_l1 > 0 or self.silence_l2 > 0:
            masked_b, masked_t = torch.where(~audio_mask)
            if len(masked_b) > n_samples:
                subset = torch.randperm(len(masked_b))[:n_samples]
                masked_b = masked_b[subset]
                masked_t = masked_t[subset]

            if len(masked_b) == n_samples:
                silent_audio_feats = audio_feats_by_head[masked_b, :, :, :, masked_t].mean(-1)  # d k c
                silence_tensor = maybe_clamp(
                    torch.einsum("bkchw,dkc->bkdhw", image_feats_by_head, silent_audio_feats))

                silence_l1_loss = silence_tensor.abs().mean()
                self.rolling_avg.add(f"loss/silence_l1", silence_l1_loss)
                loss += silence_l1_loss * self.silence_l1

                silence_l2_loss = silence_tensor.square().mean()
                self.rolling_avg.add(f"loss/silence_l2", silence_l2_loss)
                loss += silence_l2_loss * self.silence_l2

        if self.neg_audio_weight > 0 and self.neg_audio:
            b, t = audio_pos_mask.shape
            negative_weight = ((1 - audio_pos_mask) * audio_mask.to(sims)).reshape(b, 1, 1, 1, 1, t)
            negative_weight = torch.broadcast_to(negative_weight, paired_sim.shape)
            if negative_weight.sum() > 0:
                neg_audio_loss = (paired_sim.square() * negative_weight).sum() \
                                 / negative_weight.sum().clamp_min(0.1)
                self.rolling_avg.add(f"loss/neg_audio", neg_audio_loss)
                self.rolling_avg.add(f"loss/neg_weight_avg", negative_weight.mean())
                loss += neg_audio_loss * self.neg_audio_weight
            else:
                print("WARNING: No negative samples found in batch")

        if self.tv_weight > 0:
            tv_loss = (paired_sim[:, :, :, :, :, 1:] - paired_sim[:, :, :, :, :, :-1]).square().mean()
            self.rolling_avg.add(f"loss/tv", tv_loss)
            loss += tv_loss * self.tv_weight

        self.log(f"cal/w", self.sim_cal.get_w())

        if self.cal_balance_weight > 0.0:
            cal_balance = (np.log(self.cal_init) - torch.log(self.sim_cal.get_w().clamp_min(.00000001))) \
                .clamp_min(0).square().mean()
            self.rolling_avg.add(f"loss/cal_balance", cal_balance)
            loss += cal_balance * self.cal_balance_weight

        if self.disentangle_weight > 0.0:
            assert source is not None
            assert self.sim_agg_heads % 2 == 0
            dilation = self.sim_agg_heads // 2
            sources_oh = F.one_hot(source, num_classes=2)
            b, h = sources_oh.shape
            sources_mask = 1 - torch.broadcast_to(sources_oh.unsqueeze(-1), (b, h, dilation)) \
                .reshape(b, h * dilation).to(paired_sim)
            disentangle_loss = torch.einsum("bkhwft,bk->bhwft", paired_sim, sources_mask).square().mean()
            self.rolling_avg.add(f"loss/disentangle", disentangle_loss)
            loss += disentangle_loss * self.disentangle_weight

        if self.specialization_weight > 0.0 and self.sim_agg_heads > 1:
            total_specialization_loss = 0.0
            combos = list(combinations(range(self.sim_agg_heads), 2))
            for i, j in combos:
                specialization_loss_pair = (paired_sim[:, i].abs() * paired_sim[:, j].abs()).mean()
                total_specialization_loss += specialization_loss_pair
            avg_specialization_loss = total_specialization_loss / len(combos)
            self.rolling_avg.add(f"loss/specialize", avg_specialization_loss)
            loss += avg_specialization_loss * self.specialization_weight

        if self.mixup_weight > 0.0:
            b, _, h, w = image_mask.shape
            neg_img_mask = torch.broadcast_to(
                1 - image_mask.to(paired_sim).reshape(b, 1, h, w, 1, 1),
                paired_sim.shape)
            image_mixup_loss = (paired_sim * neg_img_mask).square().sum() / neg_img_mask.sum().clamp_min(0.1)
            self.rolling_avg.add(f"loss/image_mixup", image_mixup_loss)
            loss += image_mixup_loss * self.mixup_weight

        loss += self.contrast_loss(sims)
        self.rolling_avg.add(f"loss/total", loss)
        return loss
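    # Summary of the regularizers above: nonneg_pressure penalizes negative raw similarities
    # on randomly sampled audio/image positions, silence_l1/l2 suppress activations paired
    # with masked (silent) audio, neg_audio_weight suppresses similarity on known-negative
    # audio frames, tv_weight smooths activations over time, cal_balance keeps the learned
    # temperature near its initial value, disentangle penalizes heads that activate for the
    # wrong data source, specialization penalizes pairs of heads that co-activate, and
    # mixup penalizes activations outside the image mask. The calibrated contrastive term
    # is added last.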
    def setup_hparams(self):
        recalls = ['A_r1', 'A_r5', 'A_r10', 'I_r1', 'I_r5', 'I_r10']

        if self.trainer.datamodule.use_extra_val_sets:
            datasets = ["Places", "AudioSet"]
        else:
            datasets = ["Val"]

        heads = ["total"]

        metric_names = [
            "hp/speech_basic_ap", "hp/speech_advanced_ap", "hp/sound_basic_ap",
            "hp/speech_basic_iou", "hp/speech_advanced_iou", "hp/sound_basic_iou",
        ]

        for dataset in datasets:
            for head in heads:
                for recall in recalls:
                    metric_names.append(f"hp/{dataset}/{head}/{recall}")

        if self.sim_agg_heads == 2:
            metric_names.extend(["hp/ap_dis", "hp/act_dis"])

        if hasattr(self.trainer, "datamodule"):
            all_hparams = {**self.hparams, **self.trainer.datamodule.hparams}
        else:
            all_hparams = self.hparams

        starting_values = {n: torch.nan for n in metric_names}
        self.logger.log_hyperparams(all_hparams, starting_values)

    def on_train_start(self):
        self.setup_hparams()
        self.hparams_logged = True
    def on_train_batch_start(self, batch, batch_idx):
        remake_optimizers = False

        if isinstance(self.image_aligner, ProgressiveGrowing):
            should_remake = self.image_aligner.maybe_change_phase(self.global_step)
            remake_optimizers = remake_optimizers or should_remake
        if isinstance(self.audio_aligner, ProgressiveGrowing):
            should_remake = self.audio_aligner.maybe_change_phase(self.global_step)
            remake_optimizers = remake_optimizers or should_remake

        if remake_optimizers:
            raise NotImplementedError()

    def _combine_preds(self, all_preds):
        temp = {}
        new_preds = {}

        # Collect tensors for each key into lists
        for d in all_preds:
            for key, value in d.items():
                if isinstance(value, torch.Tensor):
                    if key not in temp:
                        temp[key] = []
                    temp[key].append(value)

        # Concatenate all tensors for each key using a single call to torch.cat
        for key, tensor_list in temp.items():
            new_preds[key] = torch.cat(tensor_list)

        return new_preds
    def training_step(self, batch, batch_idx):
        assert batch[IMAGE_INPUT].shape[1] == 1
        preds = self.forward(batch)

        if DATA_SOURCE in batch:
            preds[DATA_SOURCE] = batch[DATA_SOURCE]

        if self.trainer.world_size > 1 and self.gather_tensors:
            for k, v in preds.items():
                new_v = v.contiguous()
                preds[k] = torch.cat(GatherLayer.apply(new_v), dim=0)

        if self.memory_buffer_size > 0:
            new_preds = self._combine_preds(list(self.memory_buffer) + [preds])
        else:
            new_preds = preds

        loss = self.loss(new_preds)

        if self.memory_buffer_size > 0:
            self.memory_buffer.append(self._recursive_detach(preds, gather=False))

        if self.trainer.is_global_zero and self.global_step % 50 == 1:
            writer = self.logger.experiment
            self.rolling_avg.logall(lambda k, v: writer.add_scalar(k, v, global_step=self.global_step))
            if self.trainer.scaler is not None:
                self.log("loss_scale", self.trainer.scaler.get_scale())

        if self.global_step % 10000 == 0 and self.global_step > 0:
            print("RESETTING TFEVENT FILE")
            self.logger.experiment.close()
            self.logger.experiment._get_file_writer()

        return loss
    def on_validation_start(self) -> None:
        if not self.hparams_logged:
            self.setup_hparams()
            self.hparams_logged = True

    def _auto_gather(self, t):
        if t.dtype == torch.bool:
            t = t.to(torch.float)

        if self.trainer.num_devices == 1:
            return t.cpu()

        t = torch.clone(t).contiguous()
        if self.trainer.is_global_zero:
            gather_list = [torch.zeros_like(t) for _ in range(dist.get_world_size())]
            dist.gather(t, gather_list)
            return torch.cat(gather_list, dim=0).cpu()
        else:
            dist.gather(t)
    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        with torch.no_grad():
            preds = self.forward(batch)

            ret = {}
            for k, v in preds.items():
                ret[k] = self._auto_gather(v)

            batch_keys = [IMAGE_INPUT, "spec", "semseg", "num_pixels_per_class", 'total_length']
            for k in batch_keys:
                if k in batch:
                    ret[k] = self._auto_gather(batch[k])

            if "metadata" in batch:
                if isinstance(batch["metadata"]["id"], torch.Tensor):
                    ret["id"] = self._auto_gather(batch["metadata"]["id"])
                ret["index"] = self._auto_gather(batch["metadata"]["index"])

            return ret
    def _calc_recalls(self, sim):
        top_10_a = sim.topk(10, 0).indices == torch.arange(sim.shape[0]).unsqueeze(0)
        top_10_i = (sim.topk(10, 1).indices == torch.arange(sim.shape[0]).unsqueeze(1)).permute(1, 0)
        a_recall = lambda p: top_10_a[0:p].any(0).to(sim).mean()
        i_recall = lambda p: top_10_i[0:p].any(0).to(sim).mean()
        return {'A_r1': a_recall(1),
                'A_r5': a_recall(5),
                'A_r10': a_recall(10),
                'I_r1': i_recall(1),
                'I_r5': i_recall(5),
                'I_r10': i_recall(10)}
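    # Cross-modal recall@k: top_10_a records, for each column of the similarity matrix, at
    # which of its top-10 ranked rows the true (diagonal) match appears; a_recall(p) is then
    # the fraction of columns whose match falls within the top p, and i_recall(p) is the
    # symmetric row-wise quantity.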
    def calc_recalls(self, preds, dataset):
        sim = self.sim_agg.forward_batched(
            preds=preds,
            agg_heads=False,
            batch_size=4,
        ).cpu()

        all_metrics = dict()
        for k, v in self._calc_recalls(sim.sum(-1)).items():
            all_metrics[f"hp/{dataset}/total/" + k] = v

        return all_metrics
    def retrieval_validation(self, outputs, dataset_name):
        if len(outputs) == 0:
            return

        if self.trainer.is_global_zero:
            results = flatten_preds(outputs)
            if not self.trainer.sanity_checking:
                print(results[IMAGE_FEATS].shape[0])
                # assert (results[IMAGE_FEATS].shape[0] == 1000)

            results[IMAGE_FEATS] = results[IMAGE_FEATS].cpu()
            results[AUDIO_FEATS] = results[AUDIO_FEATS].cuda()
            if self.sim_use_cls:
                results[AUDIO_CLS] = results[AUDIO_CLS].cuda()
            results[AUDIO_MASK] = results[AUDIO_MASK].cuda()

            recalls = self.calc_recalls(results, dataset_name)

            results[IMAGE_FEATS] = results[IMAGE_FEATS].cuda()

            writer = self.logger.experiment
            print("here")
            for name, v in recalls.items():
                writer.add_scalar(f"{name}", v, self.global_step + 1)
    def semseg_validation(self, speech_preds, sound_preds):
        if self.trainer.is_global_zero:
            from eval_utils import get_paired_heatmaps

            def prep_preds(preds, loader):
                results = flatten_preds(preds)
                metadata = loader.dataset.metadata
                ordered_metadata = metadata.iloc[results["index"].numpy(), :].copy()
                ordered_metadata["order"] = range(len(ordered_metadata))
                return results, ordered_metadata

            [_, _, speech_loader, sound_loader] = self.trainer.val_dataloaders

            speech_results, speech_metadata = prep_preds(speech_preds, speech_loader)
            sound_results, sound_metadata = prep_preds(sound_preds, sound_loader)

            self.sound_metrics, unique_sound_indices = get_paired_heatmaps(
                self, sound_results, sound_metadata["ade_class_id"], None)

            self.speech_metrics, unique_word_indices = get_paired_heatmaps(
                self, speech_results, speech_metadata["ade_class_id"], speech_metadata["timing"])

            writer = self.logger.experiment

            all_metrics = {
                **{"sound_" + k: v for k, v in self.sound_metrics.items()},
                **{"speech_" + k: v for k, v in self.speech_metrics.items()},
            }

            for k, v in all_metrics.items():
                writer.add_scalar(f"hp/{k}", torch.tensor(v).mean(), self.global_step + 1)
    def disentangle_validation(self, word_preds, sound_preds):
        if len(word_preds) == 0 or len(sound_preds) == 0:
            return

        if self.trainer.is_global_zero:
            word_preds = flatten_preds(word_preds)
            sound_preds = flatten_preds(sound_preds)

            word_scores = self.sim_agg.get_pairwise_sims(
                word_preds,
                raw=False,
                agg_sim=True,
                agg_heads=False,
            )

            sound_scores = self.sim_agg.get_pairwise_sims(
                sound_preds,
                raw=False,
                agg_sim=True,
                agg_heads=False,
            )

            all_scores = torch.cat([word_scores, sound_scores], dim=0)
            all_scores -= all_scores.min(dim=0, keepdim=True).values
            all_scores /= all_scores.max(dim=0, keepdim=True).values.clamp_min(.0001)

            is_words = torch.cat([
                torch.ones(word_scores.shape[0]),
                torch.zeros(sound_scores.shape[0])], dim=0).to(torch.bool)

            assert all_scores.shape[1] == 2

            ap_matrix = torch.zeros(2, 2)
            act_matrix = torch.zeros(2, 2)
            for head in range(2):
                # writer.add_histogram(f"h{head}_all_scores", all_scores[:, head])
                for dataset_num in range(2):
                    if dataset_num == 0:
                        labels = is_words
                    else:
                        labels = ~is_words

                    ap_matrix[head, dataset_num] = binary_average_precision(
                        all_scores[:, head].cpu(), labels.to(torch.int64).cpu())
                    act_matrix[head, dataset_num] = 1 - (all_scores[:, head][labels]).mean()

            ap_dis = max(.5 * (ap_matrix[0, 0] + ap_matrix[1, 1]),
                         .5 * (ap_matrix[0, 1] + ap_matrix[1, 0]))
            act_dis = max(.5 * (act_matrix[0, 0] + act_matrix[1, 1]),
                          .5 * (act_matrix[0, 1] + act_matrix[1, 0]))

            print("AP", ap_matrix)
            print("AP dis", ap_dis)
            print("Act", act_matrix)
            print("Act dis", act_dis)

            writer = self.logger.experiment
            writer.add_scalar("hp/ap_dis", ap_dis, self.global_step + 1)
            writer.add_scalar("hp/act_dis", act_dis, self.global_step + 1)
    def validation_epoch_end(self, outputs) -> None:
        print("Val end")
        with torch.no_grad():
            if self.trainer.datamodule.use_extra_val_sets:
                if self.sim_agg_heads == 2:
                    self.disentangle_validation(outputs[0], outputs[1])
                self.retrieval_validation(outputs[0], "Places")
                self.retrieval_validation(outputs[1], "AudioSet")
                self.semseg_validation(outputs[2], outputs[3])
            else:
                print("HERE!")
                self.retrieval_validation(outputs, "Val")

            writer = self.logger.experiment
            writer.flush()
    def _recursive_detach(self, obj, gather=True):
        if isinstance(obj, torch.Tensor):
            if gather:
                return self._auto_gather(obj)
            else:
                return obj.detach()
        elif isinstance(obj, dict):
            return {k: self._recursive_detach(v, gather) for k, v in obj.items()}
        elif isinstance(obj, list):
            return [self._recursive_detach(v, gather) for v in obj]
        else:
            return obj
    def predict_step(self, batch, batch_idx: int, dataloader_idx: int = 0):
        with torch.no_grad():
            predictions = {}
            for k, v in batch.items():
                predictions[k] = self._recursive_detach(v)
            for k, v in self.forward(batch).items():
                predictions[k] = self._auto_gather(v)
            return predictions
    def _configure_optimizers(self, full_train, lr):
        params = [
            *self.audio_aligner.parameters(),
            *self.image_aligner.parameters(),
            *self.sim_cal.parameters(),
            *self.sim_agg.parameters()
        ]

        if (self.finetune_image_model or self.image_lora) and full_train:
            params.extend(self.image_model.parameters())

        if (self.finetune_audio_model or self.audio_lora) and full_train:
            params.extend(self.audio_model.parameters())

        if self.learn_audio_cls:
            params.append(self.audio_cls)

        last_epoch = self.global_step - 1

        if self.optimizer == "adam":
            opt = torch.optim.Adam(params, lr=lr, eps=1e-7)
        elif self.optimizer == "nadam":
            opt = torch.optim.NAdam(params, lr=lr, eps=1e-7)
        else:
            raise ValueError(f"Unknown optimizer {self.optimizer}")

        if self.lr_schedule == "sgdr":
            scheduler = CosineAnnealingWarmRestarts(
                opt, self.lr_cycle_length, 2, eta_min=lr * 2e-2, last_epoch=last_epoch)
        else:
            scheduler = LambdaLR(opt, lr_lambda=lambda step: 1.0, last_epoch=last_epoch)

        if self.lr_warmup > 0:
            warmup = LambdaLR(
                opt,
                lr_lambda=lambda step: min(max(float(step), 0.0) / self.lr_warmup, 1.0),
                last_epoch=last_epoch,
            )
            scheduler = SequentialLR(
                opt,
                schedulers=[warmup, scheduler],
                milestones=[self.lr_warmup],
                last_epoch=last_epoch)

        scheduler = {"scheduler": scheduler, "interval": "step"}
        return [opt], [scheduler]
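    # Scheduling note: the learning rate is stepped per optimizer step ("interval": "step").
    # With lr_warmup > 0, a linear warmup runs for lr_warmup steps and then hands off (via
    # SequentialLR) to either cosine annealing with warm restarts ("sgdr", initial cycle
    # length lr_cycle_length, period doubling, floor at 2% of lr) or a constant schedule.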
    def configure_optimizers(self):
        if self.full_train:
            return self._configure_optimizers(self.full_train, self.lr)
        else:
            return self._configure_optimizers(self.full_train, self.pretrain_lr)
# NOTE: my_app is invoked without arguments below and hydra is imported above, so a
# @hydra.main decorator was presumably lost in extraction; the config_path and
# config_name values here are assumptions, not taken from the original source.
@hydra.main(config_path="configs", config_name="av_align.yaml", version_base=None)
def my_app(cfg: DictConfig) -> None:
    print(OmegaConf.to_yaml(cfg))
    seed_everything(cfg.seed, workers=True)

    exp_name = f"{cfg.resume_prefix}"

    if cfg.image_model_type == "dino8":
        patch_size = 8 * cfg.image_pool_width
    elif cfg.image_model_type == "cavmae":
        patch_size = 16 * cfg.image_pool_width
    elif cfg.image_model_type == "imagebind":
        patch_size = 16 * cfg.image_pool_width
    elif cfg.image_model_type == "clip":
        patch_size = 16 * cfg.image_pool_width
    elif cfg.image_model_type == "cavmae-mixed":
        patch_size = 16 * cfg.image_pool_width
    elif cfg.image_model_type == "dinov2":
        patch_size = 14 * cfg.image_pool_width
    else:
        raise ValueError(f"Unknown patch size for model {cfg.image_model_type}")
    datamodule = AVDataModule(
        dataset_name=cfg.dataset_name,
        load_size=cfg.load_size,
        image_aug=cfg.image_aug,
        audio_aug=cfg.audio_aug,
        extra_audio_masking=cfg.extra_audio_masking,
        audio_model_type=cfg.audio_model_type,
        pytorch_data_dir=cfg.pytorch_data_dir,
        use_cached_embs=cfg.use_cached_embs,
        batch_size=cfg.batch_size,
        num_workers=cfg.num_workers,
        audio_level=cfg.audio_level,
        neg_audio=cfg.neg_audio,
        use_original_val_set=not cfg.use_extra_val_sets,
        use_extra_val_sets=cfg.use_extra_val_sets,
        data_for_plotting=False,
        quad_mixup=cfg.quad_mixup,
        bg_mixup=cfg.bg_mixup,
        patch_mixup=cfg.patch_mixup,
        patch_size=patch_size
    )
    datamodule.maybe_unpack(remove_source=cfg.submitting_to_aml)

    aligner = create_model_from_cfg(LitAVAligner, cfg, {})

    if cfg.starting_weights is not None:
        loaded = torch.load(join(cfg.output_root, cfg.starting_weights), map_location='cpu')
        state = loaded["state_dict"]
        aligner.load_state_dict(state, strict=cfg.load_strict)
        del state
        del loaded
    if cfg.num_gpus > 1:
        # strategy = "ddp_sharded"  # _find_unused_parameters_true"
        strategy = "ddp"  # _find_unused_parameters_true"
    else:
        strategy = "auto"

    if cfg.dataset_name in {"places-audio", "mixed", "audio-set", "mixed-full"}:
        val_args = dict(check_val_every_n_epoch=2)
    elif cfg.dataset_name in {"dolphin"}:
        val_args = dict(check_val_every_n_epoch=5)
    else:
        val_args = dict(val_check_interval=10000)
    # val_args = dict(val_check_interval=1000)

    def maybe_get_ckpt(ckpt_dir):
        if cfg.auto_resume and os.path.exists(ckpt_dir):
            print(f"Attempting to resume from {ckpt_dir}")
            candidates = os.listdir(ckpt_dir)
            assert (len(candidates) == 1)
            return join(ckpt_dir, candidates[0])
        elif cfg.auto_resume:
            print(f"Could not find checkpoint at {ckpt_dir}")
            return None
        else:
            return None

    log_dir = join(cfg.output_root, "logs", cfg.grouping_name, exp_name)
    ckpt_dir = join(cfg.output_root, "checkpoints", cfg.grouping_name, exp_name)

    import gc
    torch.cuda.empty_cache()
    gc.collect()
    def run_exp(aligner, full_train):
        trainer_args = dict(
            accelerator='gpu',
            strategy=strategy,
            devices=cfg.num_gpus,
            num_sanity_val_steps=cfg.num_sanity_val_steps,
            log_every_n_steps=50,
            reload_dataloaders_every_n_epochs=10,
            precision="16",
            # profiler="simple",
            # precision="bf16",
            max_steps=cfg.max_steps,
            **val_args)

        aligner.set_full_train(full_train)

        if full_train:
            suffix = "train"
        else:
            suffix = "pretrain"
            trainer_args["max_steps"] = cfg.pretrain_steps

        print(f"Starting {suffix} phase")

        logger = TensorBoardLogger(join(log_dir, suffix), default_hp_metric=False)
        callbacks = [
            ModelCheckpoint(join(ckpt_dir, suffix), every_n_epochs=1),
            LearningRateMonitor(logging_interval='step'),
        ]

        Trainer(logger=logger,
                callbacks=callbacks,
                **trainer_args).fit(
            aligner,
            datamodule=datamodule,
            ckpt_path=maybe_get_ckpt(join(ckpt_dir, suffix)))
    train_chkpt = maybe_get_ckpt(join(ckpt_dir, "train"))

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    if cfg.pretrain_steps > 0 and train_chkpt is None:
        print("---" * 10)
        print("Setup with full_train = False")
        run_exp(aligner, full_train=False)
        print("---" * 10)
    else:
        print("---" * 10)
        print("Setup with full_train = True")
        run_exp(aligner, full_train=True)
        print("---" * 10)
if __name__ == "__main__":
    my_app()
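# Usage sketch (not part of the training script): because LitAVAligner mixes in
# PyTorchModelHubMixin, a checkpoint pushed to the Hugging Face Hub can be reloaded with
# LitAVAligner.from_pretrained(<repo_id>). The repo id below is a hypothetical placeholder,
# not a reference to an actual published checkpoint.
#
#   model = LitAVAligner.from_pretrained("your-username/your-denseav-checkpoint")
#   model.eval()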