import torch
import torchaudio
import torchaudio.functional
from torchvision import transforms
import torchvision.transforms.functional as F
import torch.nn as nn
from PIL import Image
import numpy as np
import math
import random
import soundfile
import os
import librosa
import albumentations
from torch_pitch_shift import *

SR = 22050
class ResizeShortSide(object):
    def __init__(self, size):
        super().__init__()
        self.size = size

    def __call__(self, x):
        '''
        x must be PIL.Image
        '''
        w, h = x.size
        short_side = min(w, h)
        w_target = int((w / short_side) * self.size)
        h_target = int((h / short_side) * self.size)
        return x.resize((w_target, h_target))
class Crop(object):
    def __init__(self, cropped_shape=None, random_crop=False):
        self.cropped_shape = cropped_shape
        if cropped_shape is not None:
            mel_num, spec_len = cropped_shape
            if random_crop:
                self.cropper = albumentations.RandomCrop
            else:
                self.cropper = albumentations.CenterCrop
            self.preprocessor = albumentations.Compose([self.cropper(mel_num, spec_len)])
        else:
            self.preprocessor = lambda **kwargs: kwargs

    def __call__(self, item):
        item['image'] = self.preprocessor(image=item['image'])['image']
        if 'cond_image' in item.keys():
            item['cond_image'] = self.preprocessor(image=item['cond_image'])['image']
        return item
class CropImage(Crop):
    def __init__(self, *crop_args):
        super().__init__(*crop_args)


class CropFeats(Crop):
    def __init__(self, *crop_args):
        super().__init__(*crop_args)

    def __call__(self, item):
        item['feature'] = self.preprocessor(image=item['feature'])['image']
        return item


class CropCoords(Crop):
    def __init__(self, *crop_args):
        super().__init__(*crop_args)

    def __call__(self, item):
        item['coord'] = self.preprocessor(image=item['coord'])['image']
        return item
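# Example (sketch): how the dict-based Crop transforms are typically used.
# The spectrogram shape (80, 848) and crop size below are illustrative, not taken
# from this file's pipelines.
#
#   crop = CropImage((80, 173), True)                       # random 80x173 crop
#   item = {'image': np.random.rand(80, 848).astype(np.float32)}
#   item = crop(item)                                       # item['image'] -> (80, 173)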
class RandomResizedCrop3D(nn.Module):
    """Crop the given series of images to a random size and aspect ratio.

    Each image can be a PIL Image or a Tensor, in which case it is expected to have
    [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

    A crop of random size (default: 0.08 to 1.0 of the original size) and random
    aspect ratio (default: 3/4 to 4/3 of the original aspect ratio) is made. This crop
    is finally resized to the given size. This is popularly used to train the Inception networks.

    Args:
        size (int or sequence): expected output size of each edge. If size is an
            int instead of a sequence like (h, w), a square output size ``(size, size)`` is
            made. If provided a tuple or list of length 1, it will be interpreted as (size[0], size[0]).
        scale (tuple of float): range of the size of the cropped region relative to the original size.
        ratio (tuple of float): range of the aspect ratio of the cropped region.
        interpolation (InterpolationMode): desired interpolation mode.
            Default is ``InterpolationMode.BILINEAR``. If the input is a Tensor, only
            ``NEAREST``, ``BILINEAR`` and ``BICUBIC`` are supported.
    """

    def __init__(self, size, scale=(0.08, 1.0), ratio=(3. / 4., 4. / 3.),
                 interpolation=transforms.InterpolationMode.BILINEAR):
        super().__init__()
        if isinstance(size, tuple) and len(size) == 2:
            self.size = size
        else:
            self.size = (size, size)
        self.interpolation = interpolation
        self.scale = scale
        self.ratio = ratio

    @staticmethod
    def get_params(img, scale, ratio):
        """Get parameters for ``crop`` for a random sized crop.

        Args:
            img (PIL Image): Input image.
            scale (list): range of the size of the cropped region.
            ratio (list): range of the aspect ratio of the cropped region.

        Returns:
            tuple: params (i, j, h, w) to be passed to ``crop`` for a random sized crop.
        """
        width, height = img.size
        area = height * width
        for _ in range(10):
            target_area = area * torch.empty(1).uniform_(scale[0], scale[1]).item()
            log_ratio = torch.log(torch.tensor(ratio))
            aspect_ratio = torch.exp(
                torch.empty(1).uniform_(log_ratio[0], log_ratio[1])
            ).item()
            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))
            if 0 < w <= width and 0 < h <= height:
                i = torch.randint(0, height - h + 1, size=(1,)).item()
                j = torch.randint(0, width - w + 1, size=(1,)).item()
                return i, j, h, w
        # Fallback to central crop
        in_ratio = float(width) / float(height)
        if in_ratio < min(ratio):
            w = width
            h = int(round(w / min(ratio)))
        elif in_ratio > max(ratio):
            h = height
            w = int(round(h * max(ratio)))
        else:  # whole image
            w = width
            h = height
        i = (height - h) // 2
        j = (width - w) // 2
        return i, j, h, w

    def forward(self, imgs):
        """
        Args:
            imgs (list of PIL Image or Tensor): Images to be cropped and resized.

        Returns:
            list of PIL Image or Tensor: Randomly cropped and resized images; the same
            crop parameters are applied to every image in the series.
        """
        i, j, h, w = self.get_params(imgs[0], self.scale, self.ratio)
        return [F.resized_crop(img, i, j, h, w, self.size, self.interpolation) for img in imgs]
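# Example (sketch): applying the same random crop to every frame of a short clip.
# The frame size and clip length below are illustrative.
#
#   frames = [Image.new('RGB', (320, 240)) for _ in range(8)]
#   rrc = RandomResizedCrop3D(size=224)
#   cropped = rrc(frames)    # list of 8 PIL images, each 224x224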
class Resize3D(object):
    def __init__(self, size):
        super().__init__()
        self.size = size

    def __call__(self, imgs):
        '''
        imgs must be a list of PIL.Image
        '''
        return [x.resize((self.size, self.size)) for x in imgs]


class RandomHorizontalFlip3D(object):
    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def __call__(self, imgs):
        '''
        imgs must be a list of PIL.Image
        '''
        if np.random.rand() < self.p:
            return [x.transpose(Image.FLIP_LEFT_RIGHT) for x in imgs]
        else:
            return imgs
class ColorJitter3D(torch.nn.Module):
    """Randomly change the brightness, contrast, saturation and hue of a series of images.

    The jitter factors are sampled once and the same transform is applied to every
    image in the series.

    Args:
        brightness (float): How much to jitter brightness.
            brightness_factor is chosen uniformly from [1 - brightness, 1 + brightness].
            Should be non-negative.
        contrast (float): How much to jitter contrast.
            contrast_factor is chosen uniformly from [1 - contrast, 1 + contrast].
            Should be non-negative.
        saturation (float): How much to jitter saturation.
            saturation_factor is chosen uniformly from [1 - saturation, 1 + saturation].
            Should be non-negative.
        hue (float): How much to jitter hue.
            hue_factor is chosen uniformly from [-hue, hue].
            Should satisfy 0 <= hue <= 0.5.
    """

    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        super().__init__()
        self.brightness = (1 - brightness, 1 + brightness)
        self.contrast = (1 - contrast, 1 + contrast)
        self.saturation = (1 - saturation, 1 + saturation)
        self.hue = (0 - hue, 0 + hue)

    @staticmethod
    def get_params(brightness, contrast, saturation, hue):
        """Get a randomized transform to be applied on an image.

        Arguments are the (min, max) ranges built in ``__init__``.

        Returns:
            Transform which randomly adjusts brightness, contrast,
            saturation and hue in a random order.
        """
        tfs = []
        if brightness is not None:
            brightness_factor = random.uniform(brightness[0], brightness[1])
            tfs.append(transforms.Lambda(
                lambda img: F.adjust_brightness(img, brightness_factor)))
        if contrast is not None:
            contrast_factor = random.uniform(contrast[0], contrast[1])
            tfs.append(transforms.Lambda(
                lambda img: F.adjust_contrast(img, contrast_factor)))
        if saturation is not None:
            saturation_factor = random.uniform(saturation[0], saturation[1])
            tfs.append(transforms.Lambda(
                lambda img: F.adjust_saturation(img, saturation_factor)))
        if hue is not None:
            hue_factor = random.uniform(hue[0], hue[1])
            tfs.append(transforms.Lambda(
                lambda img: F.adjust_hue(img, hue_factor)))
        random.shuffle(tfs)
        transform = transforms.Compose(tfs)
        return transform

    def forward(self, imgs):
        """
        Args:
            imgs (list of PIL Image or Tensor): Input images.

        Returns:
            list of PIL Image or Tensor: Color jittered images.
        """
        transform = self.get_params(
            self.brightness, self.contrast, self.saturation, self.hue)
        return [transform(img) for img in imgs]
class ToTensor3D(object):
    def __init__(self):
        super().__init__()

    def __call__(self, imgs):
        '''
        imgs must be a list of PIL.Image
        '''
        return [F.to_tensor(img) for img in imgs]


class Normalize3D(object):
    def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], inplace=False):
        super().__init__()
        self.mean = mean
        self.std = std
        self.inplace = inplace

    def __call__(self, imgs):
        '''
        imgs must be a list of Tensors of shape (C, H, W)
        '''
        return [F.normalize(img, self.mean, self.std, self.inplace) for img in imgs]


class CenterCrop3D(object):
    def __init__(self, size):
        super().__init__()
        self.size = size

    def __call__(self, imgs):
        '''
        imgs must be a list of PIL.Image or Tensors
        '''
        return [F.center_crop(img, self.size) for img in imgs]
class FrequencyMasking(object):
    def __init__(self, freq_mask_param: int, iid_masks: bool = False):
        super().__init__()
        self.masking = torchaudio.transforms.FrequencyMasking(freq_mask_param, iid_masks)

    def __call__(self, item):
        if 'cond_image' in item.keys():
            batched_spec = torch.stack(
                [torch.tensor(item['image']), torch.tensor(item['cond_image'])], dim=0
            )[:, None]  # (2, 1, H, W)
            masked = self.masking(batched_spec).numpy()
            item['image'] = masked[0, 0]
            item['cond_image'] = masked[1, 0]
        elif 'image' in item.keys():
            inp = torch.tensor(item['image'])
            item['image'] = self.masking(inp).numpy()
        else:
            raise NotImplementedError()
        return item


class TimeMasking(object):
    def __init__(self, time_mask_param: int, iid_masks: bool = False):
        super().__init__()
        self.masking = torchaudio.transforms.TimeMasking(time_mask_param, iid_masks)

    def __call__(self, item):
        if 'cond_image' in item.keys():
            batched_spec = torch.stack(
                [torch.tensor(item['image']), torch.tensor(item['cond_image'])], dim=0
            )[:, None]  # (2, 1, H, W)
            masked = self.masking(batched_spec).numpy()
            item['image'] = masked[0, 0]
            item['cond_image'] = masked[1, 0]
        elif 'image' in item.keys():
            inp = torch.tensor(item['image'])
            item['image'] = self.masking(inp).numpy()
        else:
            raise NotImplementedError()
        return item
class PitchShift(nn.Module):
    def __init__(self, up=12, down=-12, sample_rate=SR):
        super().__init__()
        self.range = (down, up)
        self.sr = sample_rate

    def forward(self, x):
        assert len(x.shape) == 2
        x = x[:, None, :]
        ratio = float(random.randint(self.range[0], self.range[1]) / 12.)
        shifted = pitch_shift(x, ratio, self.sr)
        return shifted.squeeze()
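# Example (sketch): pitch-shifting a batch of mono waveforms by a random interval
# drawn from the default +/- 12 semitone range. Batch size and duration are illustrative.
#
#   waves = torch.randn(4, SR * 2)    # (batch, samples), 2 seconds at 22050 Hz
#   shifted = PitchShift()(waves)     # same shape, randomly pitch shifted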
class MelSpectrogram(object):
    def __init__(self, sr, nfft, fmin, fmax, nmels, hoplen, spec_power, inverse=False):
        self.sr = sr
        self.nfft = nfft
        self.fmin = fmin
        self.fmax = fmax
        self.nmels = nmels
        self.hoplen = hoplen
        self.spec_power = spec_power
        self.inverse = inverse
        self.mel_basis = librosa.filters.mel(sr=sr, n_fft=nfft, fmin=fmin, fmax=fmax, n_mels=nmels)

    def __call__(self, x):
        x = x.numpy()
        if self.inverse:
            spec = librosa.feature.inverse.mel_to_stft(
                x, sr=self.sr, n_fft=self.nfft, fmin=self.fmin, fmax=self.fmax, power=self.spec_power
            )
            wav = librosa.griffinlim(spec, hop_length=self.hoplen)
            return torch.FloatTensor(wav)
        else:
            spec = np.abs(librosa.stft(x, n_fft=self.nfft, hop_length=self.hoplen)) ** self.spec_power
            mel_spec = np.dot(self.mel_basis, spec)
            return torch.FloatTensor(mel_spec)
class SpectrogramTorchAudio(object):
    def __init__(self, nfft, hoplen, spec_power, inverse=False):
        self.nfft = nfft
        self.hoplen = hoplen
        self.spec_power = spec_power
        self.inverse = inverse
        self.spec_trans = torchaudio.transforms.Spectrogram(
            n_fft=self.nfft,
            hop_length=self.hoplen,
            power=self.spec_power,
        )
        self.inv_spec_trans = torchaudio.transforms.GriffinLim(
            n_fft=self.nfft,
            hop_length=self.hoplen,
            power=self.spec_power,
        )

    def __call__(self, x):
        if self.inverse:
            wav = self.inv_spec_trans(x)
            return wav
        else:
            spec = torch.abs(self.spec_trans(x))
            return spec
class MelScaleTorchAudio(object):
    def __init__(self, sr, stft, fmin, fmax, nmels, inverse=False):
        self.sr = sr
        self.stft = stft
        self.fmin = fmin
        self.fmax = fmax
        self.nmels = nmels
        self.inverse = inverse
        self.mel_trans = torchaudio.transforms.MelScale(
            n_mels=self.nmels,
            sample_rate=self.sr,
            f_min=self.fmin,
            f_max=self.fmax,
            n_stft=self.stft,
            norm='slaney'
        )
        self.inv_mel_trans = torchaudio.transforms.InverseMelScale(
            n_mels=self.nmels,
            sample_rate=self.sr,
            f_min=self.fmin,
            f_max=self.fmax,
            n_stft=self.stft,
            norm='slaney'
        )

    def __call__(self, x):
        if self.inverse:
            spec = self.inv_mel_trans(x)
            return spec
        else:
            mel_spec = self.mel_trans(x)
            return mel_spec
class Padding(object):
    def __init__(self, target_len, inverse=False):
        self.target_len = int(target_len)
        self.inverse = inverse

    def __call__(self, x):
        if self.inverse:
            return x
        else:
            x = x.squeeze()
            if x.shape[0] < self.target_len:
                pad = torch.zeros((self.target_len,), dtype=x.dtype, device=x.device)
                pad[:x.shape[0]] = x
                x = pad
            elif x.shape[0] > self.target_len:
                raise NotImplementedError()
            return x


class MakeMono(object):
    def __init__(self, inverse=False):
        self.inverse = inverse

    def __call__(self, x):
        if self.inverse:
            return x
        else:
            x = x.squeeze()
            if len(x.shape) == 1:
                return torch.FloatTensor(x)
            elif len(x.shape) == 2:
                target_dim = int(torch.argmin(torch.tensor(x.shape)))
                return torch.mean(x, dim=target_dim)
            else:
                raise NotImplementedError
class LowerThresh(object):
    def __init__(self, min_val, inverse=False):
        self.min_val = torch.tensor(min_val)
        self.inverse = inverse

    def __call__(self, x):
        if self.inverse:
            return x
        else:
            return torch.maximum(self.min_val, x)


class Add(object):
    def __init__(self, val, inverse=False):
        self.inverse = inverse
        self.val = val

    def __call__(self, x):
        if self.inverse:
            return x - self.val
        else:
            return x + self.val


class Subtract(Add):
    def __init__(self, val, inverse=False):
        self.inverse = inverse
        self.val = val

    def __call__(self, x):
        if self.inverse:
            return x + self.val
        else:
            return x - self.val


class Multiply(object):
    def __init__(self, val, inverse=False) -> None:
        self.val = val
        self.inverse = inverse

    def __call__(self, x):
        if self.inverse:
            return x / self.val
        else:
            return x * self.val


class Divide(Multiply):
    def __init__(self, val, inverse=False):
        self.inverse = inverse
        self.val = val

    def __call__(self, x):
        if self.inverse:
            return x * self.val
        else:
            return x / self.val


class Log10(object):
    def __init__(self, inverse=False):
        self.inverse = inverse

    def __call__(self, x):
        if self.inverse:
            return 10 ** x
        else:
            return torch.log10(x)
class Clip(object):
    def __init__(self, min_val, max_val, inverse=False):
        self.min_val = min_val
        self.max_val = max_val
        self.inverse = inverse

    def __call__(self, x):
        if self.inverse:
            return x
        else:
            return torch.clip(x, self.min_val, self.max_val)


class TrimSpec(object):
    def __init__(self, max_len, inverse=False):
        self.max_len = max_len
        self.inverse = inverse

    def __call__(self, x):
        if self.inverse:
            return x
        else:
            return x[:, :self.max_len]


class MaxNorm(object):
    def __init__(self, inverse=False):
        self.inverse = inverse
        self.eps = 1e-10

    def __call__(self, x):
        if self.inverse:
            return x
        else:
            return x / (x.max() + self.eps)


class NormalizeAudio(object):
    def __init__(self, inverse=False, desired_rms=0.1, eps=1e-4):
        self.inverse = inverse
        self.desired_rms = desired_rms
        self.eps = torch.tensor(eps)

    def __call__(self, x):
        if self.inverse:
            return x
        else:
            rms = torch.maximum(self.eps, torch.sqrt(torch.mean(x ** 2)))
            x = x * (self.desired_rms / rms)
            x[x > 1.] = 1.
            x[x < -1.] = -1.
            return x


class RandomNormalizeAudio(object):
    def __init__(self, inverse=False, rms_range=[0.05, 0.2], eps=1e-4):
        self.inverse = inverse
        self.rms_low, self.rms_high = rms_range
        self.eps = torch.tensor(eps)

    def __call__(self, x):
        if self.inverse:
            return x
        else:
            rms = torch.maximum(self.eps, torch.sqrt(torch.mean(x ** 2)))
            desired_rms = (torch.rand(1) * (self.rms_high - self.rms_low)) + self.rms_low
            x = x * (desired_rms / rms)
            x[x > 1.] = 1.
            x[x < -1.] = -1.
            return x
class MakeDouble(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.to(torch.double)


class MakeFloat(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.to(torch.float)


class Wave2Spectrogram(nn.Module):
    def __init__(self, mel_num, spec_crop_len):
        super().__init__()
        self.trans = transforms.Compose([
            LowerThresh(1e-5),
            Log10(),
            Multiply(20),
            Subtract(20),
            Add(100),
            Divide(100),
            Clip(0, 1.0),
            TrimSpec(173),
            transforms.CenterCrop((mel_num, spec_crop_len))
        ])

    def forward(self, x):
        return self.trans(x)
TRANSFORMS = transforms.Compose([
    SpectrogramTorchAudio(nfft=1024, hoplen=1024 // 4, spec_power=1),
    MelScaleTorchAudio(sr=22050, stft=513, fmin=125, fmax=7600, nmels=80),
    LowerThresh(1e-5),
    Log10(),
    Multiply(20),
    Subtract(20),
    Add(100),
    Divide(100),
    Clip(0, 1.0),
])
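# Numeric sketch of the scaling chain shared by Wave2Spectrogram and TRANSFORMS
# (the input values below are illustrative): a linear mel magnitude m is mapped to
# (20 * log10(max(m, 1e-5)) - 20 + 100) / 100, a dB-like value rescaled towards [0, 1],
# and then clipped.
#
#   m = torch.tensor([1e-5, 1e-3, 1.0, 10.0])
#   v = (20 * torch.log10(torch.maximum(torch.tensor(1e-5), m)) - 20 + 100) / 100
#   # v ≈ tensor([-0.2, 0.2, 0.8, 1.0]); Clip(0, 1.0) then yields [0.0, 0.2, 0.8, 1.0]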
def get_spectrogram_torch(audio_path, save_dir, length, save_results=True):
    wav, _ = soundfile.read(audio_path)
    wav = torch.FloatTensor(wav)
    y = torch.zeros(length)
    if wav.shape[0] < length:
        y[:len(wav)] = wav
    else:
        y = wav[:length]
    mel_spec = TRANSFORMS(y).numpy()
    y = y.numpy()
    if save_results:
        os.makedirs(save_dir, exist_ok=True)
        audio_name = os.path.basename(audio_path).split('.')[0]
        np.save(os.path.join(save_dir, audio_name + '_mel.npy'), mel_spec)
        np.save(os.path.join(save_dir, audio_name + '_audio.npy'), y)
    else:
        return y, mel_spec
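# Example (sketch): getting the padded/truncated waveform and its normalized mel
# spectrogram without writing to disk. The path and clip length are illustrative.
#
#   wav, mel = get_spectrogram_torch('example.wav', save_dir=None,
#                                    length=22050 * 10, save_results=False)
#   # wav: np.ndarray of shape (220500,); mel: np.ndarray of shape (80, n_frames)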