# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import typing as tp

import flashy
import julius
import omegaconf
import torch
import torch.nn.functional as F

from . import builders
from . import base
from .. import models
from ..modules.diffusion_schedule import NoiseSchedule
from ..metrics import RelativeVolumeMel
from ..models.builders import get_processor
from ..utils.samples.manager import SampleManager
from ..solvers.compression import CompressionSolver


class PerStageMetrics:
    """Compute metrics per stage of the diffusion process.

    Metrics are reported per range of diffusion steps,
    e.g. the average loss when t is in [250, 500].
    """
    def __init__(self, num_steps: int, num_stages: int = 4):
        self.num_steps = num_steps
        self.num_stages = num_stages

    def __call__(self, losses: dict, step: tp.Union[int, torch.Tensor]):
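        # A scalar step means the whole batch shares one diffusion step, so every
        # loss lands in a single stage; a tensor of per-sample steps is bucketed
        # into stages below, one averaged entry per stage.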
        if type(step) is int:
            stage = int((step / self.num_steps) * self.num_stages)
            return {f"{name}_{stage}": loss for name, loss in losses.items()}
        elif type(step) is torch.Tensor:
            stage_tensor = ((step / self.num_steps) * self.num_stages).long()
            out: tp.Dict[str, float] = {}
            for stage_idx in range(self.num_stages):
                mask = (stage_tensor == stage_idx)
                N = mask.sum()
                stage_out = {}
                if N > 0:  # skip stages with no elements
                    for name, loss in losses.items():
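                        # Masked mean: losses outside this stage are zeroed by the
                        # boolean mask, so the sum divided by N averages only the
                        # elements whose step falls in this stage.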
                        stage_loss = (mask * loss).sum() / N
                        stage_out[f"{name}_{stage_idx}"] = stage_loss
                    out = {**out, **stage_out}
            return out


class DataProcess:
    """Apply filtering or resampling to audio batches.

    Args:
        initial_sr (int): Sample rate of the dataset.
        target_sr (int): Sample rate after resampling.
        use_resampling (bool): Whether to perform resampling.
        use_filter (bool): When True, filter the data to keep only one frequency band.
        n_bands (int): Number of bands used for the band filtering.
        idx_band (int): Index of the kept frequency band, from 0 (lows) to n_bands - 1 (highs).
        device (torch.device or str): Device on which to run the filtering.
        cutoffs (list, optional): Cutoff frequencies of the band filtering;
            if None, mel scale bands are used.
        boost (bool): Rescale the data to match the scale of our music dataset.
    """
    def __init__(self, initial_sr: int = 24000, target_sr: int = 16000, use_resampling: bool = False,
                 use_filter: bool = False, n_bands: int = 4,
                 idx_band: int = 0, device: torch.device = torch.device('cpu'), cutoffs=None, boost=False):
| """Apply filtering or resampling | |
| Args: | |
| initial_sr (int): sample rate of the dataset | |
| target_sr (int): sample rate after resampling | |
| use_resampling (bool): whether or not performs resampling | |
| use_filter (bool): when True filter the data to keep only one frequency band | |
| n_bands (int): Number of bands used | |
| cuts (none or list): The cutoff frequencies of the band filtering | |
| if None then we use mel scale bands. | |
| idx_band (int): index of the frequency band. 0 are lows ... (n_bands - 1) highs | |
| boost (bool): make the data scale match our music dataset. | |
| """ | |
        assert idx_band < n_bands
        self.idx_band = idx_band
        if use_filter:
            if cutoffs is not None:
                self.filter = julius.SplitBands(sample_rate=initial_sr, cutoffs=cutoffs).to(device)
            else:
                self.filter = julius.SplitBands(sample_rate=initial_sr, n_bands=n_bands).to(device)
        self.use_filter = use_filter
        self.use_resampling = use_resampling
        self.target_sr = target_sr
        self.initial_sr = initial_sr
        self.boost = boost

    def process_data(self, x, metric=False):
        if x is None:
            return None
        if self.boost:
            # Normalize to unit scale (clamped to avoid dividing silence by ~0),
            # then rescale to the 0.22 std of the music dataset.
            x /= torch.clamp(x.std(dim=(1, 2), keepdim=True), min=1e-4)
            x = x * 0.22
        if self.use_filter and not metric:
            x = self.filter(x)[self.idx_band]
        if self.use_resampling:
            x = julius.resample_frac(x, old_sr=self.initial_sr, new_sr=self.target_sr)
        return x

    def inverse_process(self, x):
        """Upsampling only."""
        if self.use_resampling:
            x = julius.resample_frac(x, old_sr=self.target_sr, new_sr=self.initial_sr)
        return x


class DiffusionSolver(base.StandardSolver):
    """Solver for the diffusion task.

    The diffusion task allows for MultiBand diffusion model training.

    Args:
        cfg (DictConfig): Configuration.
    """
    def __init__(self, cfg: omegaconf.DictConfig):
        super().__init__(cfg)
        self.cfg = cfg
        self.device = cfg.device
        self.sample_rate: int = self.cfg.sample_rate
        self.codec_model = CompressionSolver.model_from_checkpoint(
            cfg.compression_model_checkpoint, device=self.device)
        self.codec_model.set_num_codebooks(cfg.n_q)
        assert self.codec_model.sample_rate == self.cfg.sample_rate, (
            f"Codec model sample rate is {self.codec_model.sample_rate} but "
            f"Solver sample rate is {self.cfg.sample_rate}."
        )
        self.sample_processor = get_processor(cfg.processor, sample_rate=self.sample_rate)
        self.register_stateful('sample_processor')
        self.sample_processor.to(self.device)

        self.schedule = NoiseSchedule(
            **cfg.schedule, device=self.device, sample_processor=self.sample_processor)

        self.eval_metric: tp.Optional[torch.nn.Module] = None
        self.rvm = RelativeVolumeMel()
        self.data_processor = DataProcess(initial_sr=self.sample_rate, target_sr=cfg.resampling.target_sr,
                                          use_resampling=cfg.resampling.use, cutoffs=cfg.filter.cutoffs,
                                          use_filter=cfg.filter.use, n_bands=cfg.filter.n_bands,
                                          idx_band=cfg.filter.idx_band, device=self.device)

    @property
    def best_metric_name(self) -> tp.Optional[str]:
        if self._current_stage == "evaluate":
            return 'rvm'
        else:
            return 'loss'

    def get_condition(self, wav: torch.Tensor) -> torch.Tensor:
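        # Condition the diffusion model on the continuous latent obtained by
        # decoding the codec tokens, rather than on the discrete tokens themselves.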
        codes, scale = self.codec_model.encode(wav)
        assert scale is None, "Scaled compression models not supported."
        emb = self.codec_model.decode_latent(codes)
        return emb

    def build_model(self):
        """Build model and optimizer as well as optional Exponential Moving Average of the model.
        """
        # Model and optimizer
        self.model = models.builders.get_diffusion_model(self.cfg).to(self.device)
        self.optimizer = builders.get_optimizer(self.model.parameters(), self.cfg.optim)
        self.register_stateful('model', 'optimizer')
        self.register_best_state('model')
        self.register_ema('model')

    def build_dataloaders(self):
        """Build audio dataloaders for each stage."""
        self.dataloaders = builders.get_audio_datasets(self.cfg)

    def show(self):
        # TODO
        raise NotImplementedError()

    def run_step(self, idx: int, batch: torch.Tensor, metrics: dict):
        """Perform one training or valid step on a given batch."""
        x = batch.to(self.device)
        loss_fun = F.mse_loss if self.cfg.loss.kind == 'mse' else F.l1_loss

        condition = self.get_condition(x)  # [bs, 128, T/hop, n_emb]
        sample = self.data_processor.process_data(x)
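        # The noise schedule draws a diffusion step (one per sample when
        # variable_step_batch is set) and returns the noised input together
        # with the matching regression target.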
        input_, target, step = self.schedule.get_training_item(
            sample, tensor_step=self.cfg.schedule.variable_step_batch)
        out = self.model(input_, step, condition=condition).sample
        base_loss = loss_fun(out, target, reduction='none').mean(dim=(1, 2))
        reference_loss = loss_fun(input_, target, reduction='none').mean(dim=(1, 2))
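        # Normalize each sample's loss by the loss of the raw noisy input, so
        # easy (low-noise) steps do not dominate; norm_power controls how strong
        # the normalization is (0 disables it, 1 is the full ratio).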
        loss = base_loss / reference_loss ** self.cfg.loss.norm_power

        if self.is_training:
            loss.mean().backward()
            flashy.distrib.sync_model(self.model)
            self.optimizer.step()
            self.optimizer.zero_grad()
        metrics = {
            'loss': loss.mean(), 'normed_loss': (base_loss / reference_loss).mean(),
        }
        metrics.update(self.per_stage({'loss': loss, 'normed_loss': base_loss / reference_loss}, step))
        metrics.update({
            'std_in': input_.std(), 'std_out': out.std()})
        return metrics

    def run_epoch(self):
        # reset random seed at the beginning of the epoch
        self.rng = torch.Generator()
        self.rng.manual_seed(1234 + self.epoch)
        self.per_stage = PerStageMetrics(self.schedule.num_steps, self.cfg.metrics.num_stage)
        # run epoch
        super().run_epoch()

    def evaluate(self):
        """Evaluate stage.
        Runs audio reconstruction evaluation.
        """
        self.model.eval()
        evaluate_stage_name = f'{self.current_stage}'
        loader = self.dataloaders['evaluate']
        updates = len(loader)
        lp = self.log_progress(f'{evaluate_stage_name} estimate', loader, total=updates, updates=self.log_updates)

        metrics = {}
        n = 1
        for idx, batch in enumerate(lp):
            x = batch.to(self.device)
            with torch.no_grad():
                y_pred = self.regenerate(x)

            y_pred = y_pred.cpu()
            y = batch.cpu()  # should already be on CPU but just in case
            rvm = self.rvm(y_pred, y)
            lp.update(**rvm)
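            # Maintain a streaming mean over batches: the accumulated mean is
            # re-weighted by the number of batches seen so far before folding in
            # the new RVM values.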
            if len(metrics) == 0:
                metrics = rvm
            else:
                for key in rvm.keys():
                    metrics[key] = (metrics[key] * n + rvm[key]) / (n + 1)
                n += 1
        metrics = flashy.distrib.average_metrics(metrics)
        return metrics

    def regenerate(self, wav: torch.Tensor, step_list: tp.Optional[list] = None):
        """Regenerate the given waveform."""
        condition = self.get_condition(wav)
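        # The initial noise is shaped like the band-filtered (and possibly
        # resampled) input, then denoised step by step under the codec condition.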
        initial = self.schedule.get_initial_noise(self.data_processor.process_data(wav))  # sampling rate changes.
        result = self.schedule.generate_subsampled(self.model, initial=initial, condition=condition,
                                                   step_list=step_list)
        result = self.data_processor.inverse_process(result)
        return result

    def generate(self):
        """Generate stage."""
        sample_manager = SampleManager(self.xp)
        self.model.eval()
        generate_stage_name = f'{self.current_stage}'

        loader = self.dataloaders['generate']
        updates = len(loader)
        lp = self.log_progress(generate_stage_name, loader, total=updates, updates=self.log_updates)

        for batch in lp:
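            # Each batch is a pair; the second element (extra metadata) is
            # unused here, only the reference waveform is regenerated and logged.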
            reference, _ = batch
            reference = reference.to(self.device)
            estimate = self.regenerate(reference)
            reference = reference.cpu()
            estimate = estimate.cpu()
            sample_manager.add_samples(estimate, self.epoch, ground_truth_wavs=reference)

        flashy.distrib.barrier()