import os
import glob
import warnings

import torch
import torchaudio
import pyloudnorm as pyln

class AudioFile(object):
    def __init__(self, filepath, preload=False, half=False, normalize=True, target_loudness=None):
        """Base class for audio files to handle metadata and loading.

        Args:
            filepath (str): Path to audio file to load from disk.
            preload (bool, optional): If set, load audio data into RAM. Default: False
            half (bool, optional): If set, store audio data as float16 to save space. Default: False
            normalize (bool, optional): Convert samples to float32 in the range -1 to 1 on load. Default: True
            target_loudness (float, optional): Loudness normalize to this dB LUFS value. Default: None
        """
        super().__init__()
        self.filepath = filepath
        self.half = half
        self.normalize = normalize
        self.target_loudness = target_loudness
        self.loaded = False

        if preload:
            # load the audio now and read the shape directly from the tensor
            self.load()
            num_frames = self.audio.shape[-1]
            num_channels = self.audio.shape[0]
        else:
            # defer loading; read the shape from the file header only
            metadata = torchaudio.info(filepath)
            self.audio = None
            self.sample_rate = metadata.sample_rate
            num_frames = metadata.num_frames
            num_channels = metadata.num_channels

        self.num_frames = num_frames
        self.num_channels = num_channels

    def load(self):
        audio, sr = torchaudio.load(self.filepath, normalize=self.normalize)
        self.audio = audio
        self.sample_rate = sr

        if self.target_loudness is not None:
            self.loudness_normalize()

        if self.half:
            self.audio = self.audio.half()

        self.loaded = True

    def loudness_normalize(self):
        meter = pyln.Meter(self.sample_rate)

        # convert mono to stereo for the loudness measurement
        if self.audio.shape[0] == 1:
            tmp_audio = self.audio.repeat(2, 1)
        else:
            tmp_audio = self.audio

        # measure integrated loudness (pyloudnorm expects samples x channels)
        input_loudness = meter.integrated_loudness(tmp_audio.numpy().T)

        # compute the required gain in dB and convert it to a linear factor
        gain_dB = self.target_loudness - input_loudness
        gain_ln = 10 ** (gain_dB / 20.0)
        self.audio *= gain_ln

        # check for potentially clipped samples
        if self.audio.abs().max() >= 1.0:
            warnings.warn("Possible clipped samples in output.")

class AudioFileDataset(torch.utils.data.Dataset):
    """Base class for audio file datasets loaded from disk.

    Datasets can be either paired or unpaired. A paired dataset requires passing the `target_dirs` paths.

    Args:
        input_dirs (List[str]): List of paths to the directories containing input audio files.
        target_dirs (List[str], optional): List of paths to the directories containing corresponding target audio files. Default: []
        subset (str, optional): Dataset subset. One of ["train", "val", "test"]. Default: "train"
        length (int, optional): Number of samples to load for each example. Default: 65536
        normalize (bool, optional): Normalize audio amplitude to -1 to 1. Default: True
        train_per (float, optional): Fraction of the files to use for the training subset. Default: 0.8
        val_per (float, optional): Fraction of the files to use for the validation subset. Default: 0.1
        preload (bool, optional): Read audio files into RAM at the start of training. Default: False
        num_examples_per_epoch (int, optional): Define an epoch as a certain number of audio examples. Default: 10000
        ext (str, optional): Expected audio file extension. Default: "wav"
    """

    def __init__(
        self,
        input_dirs,
        target_dirs=[],
        subset="train",
        length=65536,
        normalize=True,
        train_per=0.8,
        val_per=0.1,
        preload=False,
        num_examples_per_epoch=10000,
        ext="wav",
    ):
        super().__init__()
        self.input_dirs = input_dirs
        self.target_dirs = target_dirs
        self.subset = subset
        self.length = length
        self.normalize = normalize
        self.train_per = train_per
        self.val_per = val_per
        self.preload = preload
        self.num_examples_per_epoch = num_examples_per_epoch
        self.ext = ext

        # gather input audio files
        self.input_filepaths = []
        for input_dir in input_dirs:
            search_path = os.path.join(input_dir, f"*.{ext}")
            self.input_filepaths += glob.glob(search_path)
        self.input_filepaths = sorted(self.input_filepaths)

        # gather target audio files (paired dataset only)
        self.target_filepaths = []
        for target_dir in target_dirs:
            search_path = os.path.join(target_dir, f"*.{ext}")
            self.target_filepaths += glob.glob(search_path)
        self.target_filepaths = sorted(self.target_filepaths)

        # in a paired dataset both sets must have the same number of files
        if len(target_dirs) > 0:
            assert len(self.target_filepaths) == len(self.input_filepaths)

        # get details about audio files
        self.input_files = []
        for input_filepath in self.input_filepaths:
            self.input_files.append(
                AudioFile(input_filepath, preload=preload, normalize=normalize)
            )

        self.target_files = []
        for target_filepath in self.target_filepaths:
            self.target_files.append(
                AudioFile(target_filepath, preload=preload, normalize=normalize)
            )

    def __len__(self):
        return self.num_examples_per_epoch

    def __getitem__(self, idx):
        """Return a random patch of `self.length` samples (paired with its target if available)."""
        # index the current audio file (wrap around since an epoch can be longer than the file list)
        input_file = self.input_files[idx % len(self.input_files)]

        # load the audio data if needed
        if not input_file.loaded:
            input_file.load()

        # get a random patch of size `self.length`
        start_idx = int(torch.rand(1).item() * (input_file.num_frames - self.length))
        stop_idx = start_idx + self.length
        input_audio = input_file.audio[:, start_idx:stop_idx]

        # if there is a target file, get it (and load), then crop with the same indices
        if len(self.target_files) > 0:
            target_file = self.target_files[idx % len(self.target_files)]
            if not target_file.loaded:
                target_file.load()
            target_audio = target_file.audio[:, start_idx:stop_idx]
            return input_audio, target_audio
        else:
            return input_audio
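

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): the directory paths below are
    # hypothetical placeholders; point them at your own folders of paired .wav files, all
    # sharing the same channel count so the default collate can stack them into a batch.
    dataset = AudioFileDataset(
        input_dirs=["data/train/input"],
        target_dirs=["data/train/target"],
        length=65536,
        preload=False,
        num_examples_per_epoch=100,
    )
    loader = torch.utils.data.DataLoader(dataset, batch_size=4, shuffle=True)

    # each example holds input/target patches cropped at matching positions
    input_audio, target_audio = next(iter(loader))
    print(input_audio.shape, target_audio.shape)  # e.g. (4, num_channels, 65536)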