from dataclasses import dataclass
from pathlib import Path
from random import randint
from typing import Optional, Tuple

import numpy as np
import torch
from transformers import BartTokenizerFast


@dataclass
class Preprocessor:
    encodec_base_path: Path
    clap_base_path: Path
    tokenizer: BartTokenizerFast = BartTokenizerFast.from_pretrained(
        "facebook/bart-base"
    )
    max_length: int = 1024
    mcm_masking_prob: float = 0.15
    mcm_masking_span: int = 10
    label_pad_token_id: int = -100
    mask_token_id: int = 1024
    num_eval_captions: int = 5

    def __post_init__(self):
        # Allow plain strings for the paths and the tokenizer name and normalize them.
        if isinstance(self.encodec_base_path, str):
            self.encodec_base_path = Path(self.encodec_base_path)
        if isinstance(self.clap_base_path, str):
            self.clap_base_path = Path(self.clap_base_path)
        if isinstance(self.tokenizer, str):
            self.tokenizer = BartTokenizerFast.from_pretrained(self.tokenizer)

    def preprocess_train(self, example):
        path = example["file_path"]
        # Pre-computed EnCodec token matrix (time x num_codebooks) and CLAP embedding.
        encodec = np.load(self.encodec_base_path / path)
        clap_embedding = np.load(self.clap_base_path / path)
        # Two BOS rows are prepended and one EOS row appended below,
        # hence the "+ 3" / "- 3" bookkeeping.
        encodec_mask = np.array(
            [0, 0] + [1] * min(encodec.shape[0], self.max_length - 3) + [0]
        )
        attention_mask = np.ones(min(encodec.shape[0] + 3, self.max_length)).astype(
            np.int64
        )
        target_text = self.tokenizer(text_target=example["caption"])
        if encodec.shape[0] + 3 > self.max_length:
            # Randomly crop sequences that are too long.
            start = randint(0, encodec.shape[0] - self.max_length + 3)
            encodec = encodec[start : start + self.max_length - 3]

        num_rvq = encodec.shape[-1]  # number of RVQ codebooks per frame
        mcm_labels = None
        if self.mcm_masking_prob > 0:
            # Masked codec modeling: sample SpecAugment-style spans over time,
            # independently for each codebook.
            mcm_mask, _ = _compute_mask_indices(
                encodec.T.shape, self.mcm_masking_prob, self.mcm_masking_span
            )
            mcm_mask = mcm_mask.T
            mcm_labels = np.where(mcm_mask, encodec, self.label_pad_token_id)
            mcm_labels = np.concatenate(
                [
                    np.ones((2, num_rvq), dtype=np.int64) * self.label_pad_token_id,
                    mcm_labels,
                    np.ones((1, num_rvq), dtype=np.int64) * self.label_pad_token_id,
                ],
                axis=0,
            )
            encodec[mcm_mask] = self.mask_token_id

        encodec = np.concatenate(
            [
                np.ones((2, num_rvq), dtype=np.int64) * self.tokenizer.bos_token_id,
                encodec,
                np.ones((1, num_rvq), dtype=np.int64) * self.tokenizer.eos_token_id,
            ],
            axis=0,
        )
        return {
            "input_ids": encodec,
            "clap_embedding": clap_embedding,
            "encodec_mask": encodec_mask,
            "attention_mask": attention_mask,
            "mcm_labels": mcm_labels,
            "labels": target_text["input_ids"],
        }

    def preprocess_eval(self, example):
        path = example["file_path"]
        encodec = np.load(self.encodec_base_path / path)
        clap_embedding = np.load(self.clap_base_path / path)
        attention_mask = np.ones(min(encodec.shape[0] + 3, self.max_length)).astype(
            np.int64
        )
        if encodec.shape[0] + 3 > self.max_length:
            encodec = encodec[: self.max_length - 3]

        captions = []
        for i in range(self.num_eval_captions):
            captions.append(example[f"caption_{i+1}"])

        return {
            "input_ids": encodec,
            "attention_mask": attention_mask,
            "clap": clap_embedding,
            "captions": captions,
        }
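

# --- Usage sketch (illustrative, not part of the original file) --------------
# A minimal sketch of how this Preprocessor might be wired into a Hugging Face
# `datasets.DatasetDict`. The split names and directory arguments are
# assumptions made for illustration; the column names ("file_path", "caption",
# "caption_1", ...) follow what the methods above expect.
def _example_build_features(dataset_dict, encodec_dir: str, clap_dir: str):
    preprocessor = Preprocessor(
        encodec_base_path=encodec_dir,  # strings are converted to Path in __post_init__
        clap_base_path=clap_dir,
    )
    # `map` applies the preprocessing example by example; the resulting columns
    # match the dicts returned by preprocess_train / preprocess_eval.
    train_set = dataset_dict["train"].map(preprocessor.preprocess_train)
    eval_set = dataset_dict["validation"].map(preprocessor.preprocess_eval)
    return train_set, eval_set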


def _compute_mask_indices(
    shape: Tuple[int, int],
    mask_prob: float,
    mask_length: int,
    attention_mask: Optional[torch.LongTensor] = None,
    min_masks: int = 0,
) -> Tuple[torch.Tensor, np.ndarray]:
    """
    Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method
    for ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be
    run on CPU as part of the preprocessing during training.

    Args:
        shape: The shape for which to compute masks. This should be a tuple of size 2 where the first element is
            the batch size and the second element is the length of the axis to span.
        mask_prob: The percentage of the whole axis (between 0 and 1) that will be masked. The number of
            independently generated mask spans of length `mask_length` is computed by
            `mask_prob * shape[1] / mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
            actual percentage will be smaller.
        mask_length: size of each masked span
        attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
            each batch dimension.
        min_masks: minimum number of masked spans

    Returns:
        A boolean mask of shape `(batch_size, sequence_length)` (as a `torch.Tensor`) together with the array of
        masked indices.
    """
    batch_size, sequence_length = shape

    if mask_length < 1:
        raise ValueError("`mask_length` has to be bigger than 0.")

    if mask_length > sequence_length:
        raise ValueError(
            f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
            f" and `sequence_length`: {sequence_length}`"
        )

    # epsilon is used for probabilistic rounding
    epsilon = np.random.rand(1).item()

    def compute_num_masked_span(input_length):
        """Given input length, compute how many spans should be masked"""
        num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
        num_masked_span = max(num_masked_span, min_masks)

        # make sure num masked span <= sequence_length
        if num_masked_span * mask_length > sequence_length:
            num_masked_span = sequence_length // mask_length

        # make sure num_masked span is also <= input_length - (mask_length - 1)
        if input_length - (mask_length - 1) < num_masked_span:
            num_masked_span = max(input_length - (mask_length - 1), 0)

        return num_masked_span

    # compute number of masked spans in batch
    input_lengths = (
        attention_mask.sum(-1).detach().tolist()
        if attention_mask is not None
        else [sequence_length for _ in range(batch_size)]
    )

    # SpecAugment mask to fill
    spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
    spec_aug_mask_idxs = []

    max_num_masked_span = compute_num_masked_span(sequence_length)

    if max_num_masked_span == 0:
        # Nothing to mask: return the all-False mask and an empty index array so
        # callers can still unpack two values.
        return torch.from_numpy(spec_aug_mask), np.zeros((batch_size, 0), dtype=np.int32)
    for input_length in input_lengths:
        # compute num of masked spans for this input
        num_masked_span = compute_num_masked_span(input_length)

        # get random indices to mask
        spec_aug_mask_idx = np.random.choice(
            np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
        )

        # pick first sampled index that will serve as a dummy index to pad vector
        # to ensure same dimension for all batches due to probabilistic rounding
        # Picking first sample just pads those vectors twice.
        if len(spec_aug_mask_idx) == 0:
            # this case can only happen if `input_length` is strictly smaller than
            # `sequence_length` in which case the last token has to be a padding
            # token which we can use as a dummy mask id
            dummy_mask_idx = sequence_length - 1
        else:
            dummy_mask_idx = spec_aug_mask_idx[0]

        spec_aug_mask_idx = np.concatenate(
            [
                spec_aug_mask_idx,
                np.ones(max_num_masked_span - num_masked_span, dtype=np.int32)
                * dummy_mask_idx,
            ]
        )
        spec_aug_mask_idxs.append(spec_aug_mask_idx)

    spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)

    # expand masked indices to masked spans
    spec_aug_mask_idxs = np.broadcast_to(
        spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
    )
    spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(
        batch_size, max_num_masked_span * mask_length
    )

    # add offset to the starting indexes so that indexes now create a span
    offsets = np.arange(mask_length)[None, None, :]
    offsets = np.broadcast_to(
        offsets, (batch_size, max_num_masked_span, mask_length)
    ).reshape(batch_size, max_num_masked_span * mask_length)
    spec_aug_mask_idxs = spec_aug_mask_idxs + offsets

    # ensure that we cannot have indices larger than sequence_length
    if spec_aug_mask_idxs.max() > sequence_length - 1:
        spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = (
            sequence_length - 1
        )

    # scatter indices to mask
    np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)

    return torch.from_numpy(spec_aug_mask), spec_aug_mask_idxs
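

# --- Quick self-check (illustrative, not part of the original file) ----------
# Draws SpecAugment-style spans for a hypothetical batch of 4 sequences of
# length 1024, using the same values as the Preprocessor defaults
# (mask_prob=0.15, mask_length=10), and prints how much of each row is masked.
if __name__ == "__main__":
    mask, span_idxs = _compute_mask_indices(
        shape=(4, 1024), mask_prob=0.15, mask_length=10
    )
    print(mask.shape)        # torch.Size([4, 1024])
    print(mask.sum(dim=-1))  # masked positions per row (at most ~160, see above)
    print(span_idxs.shape)   # (4, num_spans * mask_length)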