import importlib.util  # needed for importlib.util.spec_from_file_location below
import numpy as np
import io
import os
import posixpath
import random
import re
import subprocess
import time
import torch
import torchaudio
import webdataset as wds
from aeiou.core import is_silence
from os import path
from pedalboard.io import AudioFile
from torchaudio import transforms as T
from typing import Optional, Callable, List
from torchdata.datapipes.iter import IterDataPipe, IterableWrapper, Prefetcher
from .utils import Stereo, Mono, PhaseFlipper, PadCrop_Normalized_T
import json
import datetime
from memory_profiler import profile

AUDIO_KEYS = ("flac", "wav", "mp3", "m4a", "ogg", "opus")

# fast_scandir implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py
def fast_scandir(
    dir: str,   # top-level directory at which to begin scanning
    ext: list,  # list of allowed file extensions
    #max_size = 1 * 1000 * 1000 * 1000 # Only files < 1 GB
):
    "very fast `glob` alternative. from https://stackoverflow.com/a/59803793/4259243"
    subfolders, files = [], []
    ext = ['.' + x if x[0] != '.' else x for x in ext]  # add starting period to extensions if needed
    try:  # hope to avoid 'permission denied' by this try
        for f in os.scandir(dir):
            try:  # 'hope to avoid too many levels of symbolic links' error
                if f.is_dir():
                    subfolders.append(f.path)
                elif f.is_file():
                    file_ext = os.path.splitext(f.name)[1].lower()
                    is_hidden = os.path.basename(f.path).startswith(".")
                    if file_ext in ext and not is_hidden:
                        files.append(f.path)
            except:
                pass
    except:
        pass

    for dir in list(subfolders):
        sf, f = fast_scandir(dir, ext)
        subfolders.extend(sf)
        files.extend(f)
    return subfolders, files
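
# Usage sketch (illustrative, not part of the original pipeline; the directory
# name below is hypothetical): scanning a local tree for wav/flac files.
#
#   subfolders, files = fast_scandir("./audio_data", ["wav", "flac"])
#   print(f"{len(files)} files under {len(subfolders)} subfolders")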

def extract_audio_paths(jsonl_file, exts):
    audio_paths = []
    video_paths = []
    text_prompts = []
    data_types = []
    with open(jsonl_file, 'r') as file:
        for line in file:
            try:
                data = json.loads(line.strip())
                path = data.get('path', '')
                video_path = data.get('video_path', '')
                text_prompt = data.get('caption', '')
                data_type = data.get('type', None)
                if any(path.endswith(ext) for ext in exts):
                    audio_paths.append(path)
                    video_paths.append(video_path)
                    text_prompts.append(text_prompt)
                    data_types.append(data_type)
            except json.JSONDecodeError:
                print(f"Error decoding JSON line: {line.strip()}")
    return audio_paths, video_paths, text_prompts, data_types
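
# Expected JSONL record shape (the values below are hypothetical examples; only
# these keys are read by extract_audio_paths):
#   {"path": "/data/clip.wav", "video_path": "/data/clip.mp4",
#    "caption": "a dog barking", "type": "text_condition-audio"}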

def keyword_scandir(
    dir: str,        # top-level directory at which to begin scanning
    ext: list,       # list of allowed file extensions
    keywords: list,  # list of keywords to search for in the file name
):
    "very fast `glob` alternative. from https://stackoverflow.com/a/59803793/4259243"
    subfolders, files = [], []
    # make keywords case insensitive
    keywords = [keyword.lower() for keyword in keywords]
    # add starting period to extensions if needed
    ext = ['.' + x if x[0] != '.' else x for x in ext]
    banned_words = ["paxheader", "__macosx"]
    try:  # hope to avoid 'permission denied' by this try
        for f in os.scandir(dir):
            try:  # 'hope to avoid too many levels of symbolic links' error
                if f.is_dir():
                    subfolders.append(f.path)
                elif f.is_file():
                    is_hidden = f.name.split("/")[-1][0] == '.'
                    has_ext = os.path.splitext(f.name)[1].lower() in ext
                    name_lower = f.name.lower()
                    has_keyword = any(keyword in name_lower for keyword in keywords)
                    has_banned = any(banned_word in name_lower for banned_word in banned_words)
                    if has_ext and has_keyword and not has_banned and not is_hidden and not os.path.basename(f.path).startswith("._"):
                        files.append(f.path)
            except:
                pass
    except:
        pass

    for dir in list(subfolders):
        sf, f = keyword_scandir(dir, ext, keywords)
        subfolders.extend(sf)
        files.extend(f)
    return subfolders, files

def get_audio_filenames(
    paths: list,  # directories (or .jsonl metadata files) in which to search
    keywords=None,
    exts=['.wav', '.mp3', '.flac', '.ogg', '.aif', '.opus']
):
    "recursively get a list of audio filenames"
    filenames = []
    video_filenames = []
    text_prompts = []
    data_types = []
    if type(paths) is str:
        paths = [paths]
    if os.path.isdir(paths[0]):
        for path in paths:  # get a list of relevant filenames
            if keywords is not None:
                subfolders, files = keyword_scandir(path, exts, keywords)
            else:
                subfolders, files = fast_scandir(path, exts)
            filenames.extend(files)
        # keep the return arity consistent with the .jsonl branch so callers can
        # always unpack four lists; directory scans carry no video paths,
        # captions or data types, so pad with aligned placeholders
        video_filenames = [''] * len(filenames)
        text_prompts = [''] * len(filenames)
        data_types = [None] * len(filenames)
        return filenames, video_filenames, text_prompts, data_types
    elif os.path.isfile(paths[0]):
        assert paths[0].endswith('.jsonl')
        for path in paths:
            audio_paths, video_paths, text_prompt, data_type = extract_audio_paths(path, exts)
            filenames.extend(audio_paths)
            video_filenames.extend(video_paths)
            text_prompts.extend(text_prompt)
            data_types.extend(data_type)
        return filenames, video_filenames, text_prompts, data_types
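
# Usage sketch (illustrative; the metadata file name is hypothetical). The
# function returns four aligned lists, one entry per audio clip:
#
#   audio_paths, video_paths, captions, types = get_audio_filenames("train.jsonl")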

class LocalDatasetConfig:
    def __init__(
        self,
        id: str,
        path: str,
        video_fps: int = 5,  # default matches SampleDataset; some callers omit this argument
        custom_metadata_fn: Optional[Callable[[str], str]] = None
    ):
        self.id = id
        self.path = path
        self.video_fps = video_fps
        self.custom_metadata_fn = custom_metadata_fn

# @profile
class SampleDataset(torch.utils.data.Dataset):
    # @profile
    def __init__(
        self,
        configs,
        sample_size=65536,
        sample_rate=48000,
        keywords=None,
        random_crop=True,
        force_channels="stereo",
        video_fps=5
    ):
        super().__init__()
        self.filenames = []
        self.video_filenames = []
        self.text_prompts = []
        self.data_types = []

        self.augs = torch.nn.Sequential(
            PhaseFlipper(),
        )

        self.root_paths = []

        self.pad_crop = PadCrop_Normalized_T(sample_size, sample_rate, randomize=random_crop)

        self.force_channels = force_channels
        self.encoding = torch.nn.Sequential(
            Stereo() if self.force_channels == "stereo" else torch.nn.Identity(),
            Mono() if self.force_channels == "mono" else torch.nn.Identity(),
        )

        self.sr = sample_rate
        self.custom_metadata_fns = {}

        for config in configs:
            self.video_fps = config.video_fps
            self.root_paths.append(config.path)
            audio_files, video_files, text_prompt, data_types = get_audio_filenames(config.path, keywords)
            self.filenames.extend(audio_files)
            self.video_filenames.extend(video_files)
            self.text_prompts.extend(text_prompt)
            self.data_types.extend(data_types)
            if config.custom_metadata_fn is not None:
                self.custom_metadata_fns[config.path] = config.custom_metadata_fn

        print(f'Found {len(self.filenames)} files')

    def load_file(self, filename):
        ext = filename.split(".")[-1]
        if ext == "mp3":
            with AudioFile(filename) as f:
                audio = f.read(f.frames)
                audio = torch.from_numpy(audio)
                in_sr = f.samplerate
        else:
            audio, in_sr = torchaudio.load(filename, format=ext)
        if in_sr != self.sr:
            resample_tf = T.Resample(in_sr, self.sr)
            audio = resample_tf(audio)
        return audio

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx):
        audio_filename = self.filenames[idx]
        video_filename = self.video_filenames[idx]
        text_prompt = self.text_prompts[idx]
        data_type = self.data_types[idx]
        try:
            start_time = time.time()
            audio = self.load_file(audio_filename)

            # Decide whether this sample carries an audio condition. Types
            # without one get a silent 10-second placeholder prompt.
            if data_type in ["text_condition-audio", "text_condition-music",
                             "video_condition-audio", "video_condition-music",
                             "text+video_condition-audio", "text+video_condition-music"]:
                if_audio_condition = False
                audio_prompt = torch.zeros((2, self.sr * 10))
            elif data_type in ["audio_condition-audio", "audio_condition-music",
                               "uni_condition-audio", "uni_condition-music"]:
                if_audio_condition = True
            else:
                # e.g. entries from a plain directory scan have no data_type;
                # fall back to a silent placeholder prompt
                if_audio_condition = False
                audio_prompt = torch.zeros((2, self.sr * 10))

            if if_audio_condition:
                audio_org = audio.clamp(-1, 1)

            audio, t_start, t_end, seconds_start, seconds_total, padding_mask = self.pad_crop(audio)

            if self.augs is not None:
                audio = self.augs(audio)

            audio = audio.clamp(-1, 1)

            if if_audio_condition:
                if data_type.split("-")[-1] == "audio":
                    # prompt is the 10 seconds starting at the cropped segment
                    start_index = max(0, int(seconds_start * self.sr))
                    end_index = int((seconds_start + 10) * self.sr)
                    audio_prompt = audio_org[:, start_index:end_index]
                elif data_type.split("-")[-1] == "music":
                    # prompt is the 10 seconds immediately before the cropped segment
                    if seconds_start < 10:
                        start_index = 0
                        end_index = int(10 * self.sr)
                    else:
                        start_index = max(0, int((seconds_start - 10) * self.sr))
                        end_index = int(seconds_start * self.sr)
                    audio_prompt = audio_org[:, start_index:end_index]

            # Encode the file to assist in prediction
            if self.encoding is not None:
                audio = self.encoding(audio)

            info = {}
            info["path"] = audio_filename
            info["video_path"] = video_filename
            info["text_prompt"] = text_prompt
            info["audio_prompt"] = audio_prompt
            info["data_type"] = data_type

            for root_path in self.root_paths:
                if root_path in audio_filename:
                    info["relpath"] = path.relpath(audio_filename, root_path)

            info["timestamps"] = (t_start, t_end)
            info["seconds_start"] = seconds_start
            info["seconds_total"] = seconds_total
            info["padding_mask"] = padding_mask
            info["video_fps"] = self.video_fps

            end_time = time.time()
            info["load_time"] = end_time - start_time

            for custom_md_path in self.custom_metadata_fns.keys():
                if os.path.isdir(custom_md_path):
                    if custom_md_path in audio_filename:
                        custom_metadata_fn = self.custom_metadata_fns[custom_md_path]
                        custom_metadata = custom_metadata_fn(info, audio)
                        info.update(custom_metadata)
                elif os.path.isfile(custom_md_path):
                    custom_metadata_fn = self.custom_metadata_fns[custom_md_path]
                    custom_metadata = custom_metadata_fn(info, audio)
                    info.update(custom_metadata)

            if "__reject__" in info and info["__reject__"]:
                return self[random.randrange(len(self))]

            return (audio, info)
        except Exception as e:
            print(f'Couldn\'t load file {audio_filename}: {e}')
            return self[random.randrange(len(self))]
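
# Usage sketch (illustrative; the id and metadata path below are hypothetical):
#
#   config = LocalDatasetConfig(id="my_set", path="metadata/train.jsonl", video_fps=5)
#   dataset = SampleDataset([config], sample_size=65536, sample_rate=48000)
#   audio, info = dataset[0]  # audio: (channels, sample_size) tensor, info: metadata dict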

def group_by_keys(data, keys=wds.tariterators.base_plus_ext, lcase=True, suffixes=None, handler=None):
    """Return function over iterator that groups key, value pairs into samples.
    :param keys: function that splits the key into key and extension (base_plus_ext)
    :param lcase: convert suffixes to lower case (Default value = True)
    """
    current_sample = None
    for filesample in data:
        assert isinstance(filesample, dict)
        fname, value = filesample["fname"], filesample["data"]
        prefix, suffix = keys(fname)
        if wds.tariterators.trace:
            print(
                prefix,
                suffix,
                current_sample.keys() if isinstance(current_sample, dict) else None,
            )
        if prefix is None:
            continue
        if lcase:
            suffix = suffix.lower()
        if current_sample is None or prefix != current_sample["__key__"]:
            if wds.tariterators.valid_sample(current_sample):
                yield current_sample
            current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
        if suffix in current_sample:
            print(f"{fname}: duplicate file name in tar file {suffix} {current_sample.keys()}")
        if suffixes is None or suffix in suffixes:
            current_sample[suffix] = value
    if wds.tariterators.valid_sample(current_sample):
        yield current_sample

wds.tariterators.group_by_keys = group_by_keys

# S3 code and WDS preprocessing code based on implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py
def get_s3_contents(dataset_path, s3_url_prefix=None, filter='', recursive=True, debug=False, profile=None):
    """
    Returns a list of full S3 paths to files in a given S3 bucket and directory path.
    """
    # Ensure dataset_path ends with a trailing slash
    if dataset_path != '' and not dataset_path.endswith('/'):
        dataset_path += '/'
    # Use posixpath to construct the S3 URL path
    bucket_path = posixpath.join(s3_url_prefix or '', dataset_path)
    # Construct the `aws s3 ls` command
    cmd = ['aws', 's3', 'ls', bucket_path]

    if profile is not None:
        cmd.extend(['--profile', profile])

    if recursive:
        # Add the --recursive flag if requested
        cmd.append('--recursive')

    # Run the `aws s3 ls` command and capture the output
    run_ls = subprocess.run(cmd, capture_output=True, check=True)
    # Split the output into lines and strip whitespace from each line
    contents = run_ls.stdout.decode('utf-8').split('\n')
    contents = [x.strip() for x in contents if x]
    # Remove the timestamp from lines that begin with a timestamp
    contents = [re.sub(r'^\S+\s+\S+\s+\d+\s+', '', x)
                if re.match(r'^\S+\s+\S+\s+\d+\s+', x) else x for x in contents]
    # Construct a full S3 path for each file in the contents list
    contents = [posixpath.join(s3_url_prefix or '', x)
                for x in contents if not x.endswith('/')]
    # Apply the filter, if specified
    if filter:
        contents = [x for x in contents if filter in x]
    # Remove redundant directory names in the S3 URL
    if recursive:
        # Get the main directory name from the S3 URL
        main_dir = "/".join(bucket_path.split('/')[3:])
        # Remove the redundant directory names from each file path
        contents = [x.replace(f'{main_dir}', '').replace('//', '/') for x in contents]
    # Print debugging information, if requested
    if debug:
        print("contents = \n", contents)
    # Return the list of S3 paths to files
    return contents

def get_all_s3_urls(
    names=[],            # list of all valid [LAION AudioDataset] dataset names
    subsets=[''],        # list of subsets you want from those datasets, e.g. ['train','valid']
    s3_url_prefix=None,  # prefix for those dataset names
    recursive=True,      # recursively list all tar files in all subdirs
    filter_str='tar',    # only grab files with this substring
    debug=False,         # print debugging info -- note: info displayed likely to change at dev's whims
    profiles={},         # dictionary of profiles for each item in names, e.g. {'dataset1': 'profile1', 'dataset2': 'profile2'}
):
    "get urls of shards (tar files) for multiple datasets in one s3 bucket"
    urls = []
    for name in names:
        # If s3_url_prefix is not specified, assume the full S3 path is included in each element of the names list
        if s3_url_prefix is None:
            contents_str = name
        else:
            # Construct the S3 path using the s3_url_prefix and the current name value
            contents_str = posixpath.join(s3_url_prefix, name)
        if debug:
            print(f"get_all_s3_urls: {contents_str}:")
        for subset in subsets:
            subset_str = posixpath.join(contents_str, subset)
            if debug:
                print(f"subset_str = {subset_str}")
            # Get the list of tar files in the current subset directory
            profile = profiles.get(name, None)
            tar_list = get_s3_contents(
                subset_str, s3_url_prefix=None, recursive=recursive, filter=filter_str, debug=debug, profile=profile)
            for tar in tar_list:
                # Escape spaces and parentheses in the tar filename for use in the shell command
                tar = tar.replace(" ", "\\ ").replace("(", "\\(").replace(")", "\\)")
                # Construct the S3 path to the current tar file
                s3_path = posixpath.join(name, subset, tar) + " -"
                # Construct the AWS CLI command to download the current tar file
                if s3_url_prefix is None:
                    request_str = f"pipe:aws s3 --cli-connect-timeout 0 cp {s3_path}"
                else:
                    request_str = f"pipe:aws s3 --cli-connect-timeout 0 cp {posixpath.join(s3_url_prefix, s3_path)}"
                if profiles.get(name):
                    request_str += f" --profile {profiles.get(name)}"
                if debug:
                    print("request_str = ", request_str)
                # Add the constructed URL to the list of URLs
                urls.append(request_str)
    return urls

def log_and_continue(exn):
    """Call in an exception handler to ignore any exception, issue a warning, and continue."""
    print(f"Handling webdataset error ({repr(exn)}). Ignoring.")
    return True

def is_valid_sample(sample):
    has_json = "json" in sample
    has_audio = "audio" in sample
    # check for presence first so missing keys do not raise before the checks below
    if not (has_json and has_audio):
        return False
    is_silent = is_silence(sample["audio"])
    is_rejected = "__reject__" in sample["json"] and sample["json"]["__reject__"]
    return not is_silent and not is_rejected

class S3DatasetConfig:
    def __init__(
        self,
        id: str,
        s3_path: str,
        custom_metadata_fn: Optional[Callable[[str], str]] = None,
        profile: Optional[str] = None,
    ):
        self.id = id
        self.path = s3_path
        self.custom_metadata_fn = custom_metadata_fn
        self.profile = profile
        self.urls = []

    def load_data_urls(self):
        self.urls = get_all_s3_urls(
            names=[self.path],
            s3_url_prefix=None,
            recursive=True,
            profiles={self.path: self.profile} if self.profile else {},
        )
        return self.urls

class LocalWebDatasetConfig:
    def __init__(
        self,
        id: str,
        path: str,
        custom_metadata_fn: Optional[Callable[[str], str]] = None,
        profile: Optional[str] = None,
    ):
        self.id = id
        self.path = path
        self.custom_metadata_fn = custom_metadata_fn
        self.urls = []

    def load_data_urls(self):
        self.urls = fast_scandir(self.path, ["tar"])[1]
        return self.urls

def audio_decoder(key, value):
    # Get file extension from key
    ext = key.split(".")[-1]
    if ext in AUDIO_KEYS:
        return torchaudio.load(io.BytesIO(value))
    else:
        return None

def collation_fn(samples):
    batched = list(zip(*samples))
    result = []
    for b in batched:
        if isinstance(b[0], (int, float)):
            b = np.array(b)
        elif isinstance(b[0], torch.Tensor):
            b = torch.stack(b)
        elif isinstance(b[0], np.ndarray):
            b = np.array(b)
        else:
            b = b
        result.append(b)
    return result
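
# Illustrative behaviour of collation_fn: for a batch of (audio, info) pairs the
# audio tensors are stacked along a new batch dimension, while non-numeric
# fields (e.g. the info dicts) are left as a tuple, one entry per sample.
#
#   batch = collation_fn([(torch.zeros(2, 4), {"path": "a"}),
#                         (torch.zeros(2, 4), {"path": "b"})])
#   # batch[0].shape == (2, 2, 4); batch[1] is a tuple of two dicts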

class WebDatasetDataLoader():
    def __init__(
        self,
        datasets: List[S3DatasetConfig],
        batch_size,
        sample_size,
        sample_rate=48000,
        num_workers=8,
        epoch_steps=1000,
        random_crop=True,
        force_channels="stereo",
        augment_phase=True,
        **data_loader_kwargs
    ):
        self.datasets = datasets

        self.sample_size = sample_size
        self.sample_rate = sample_rate
        self.random_crop = random_crop
        self.force_channels = force_channels
        self.augment_phase = augment_phase

        urls = [dataset.load_data_urls() for dataset in datasets]

        # Flatten the list of lists of URLs
        urls = [url for dataset_urls in urls for url in dataset_urls]

        # Shuffle the urls
        random.shuffle(urls)

        self.dataset = wds.DataPipeline(
            wds.ResampledShards(urls),
            wds.tarfile_to_samples(handler=log_and_continue),
            wds.decode(audio_decoder, handler=log_and_continue),
            wds.map(self.wds_preprocess, handler=log_and_continue),
            wds.select(is_valid_sample),
            wds.to_tuple("audio", "json", handler=log_and_continue),
            #wds.shuffle(bufsize=1000, initial=5000),
            wds.batched(batch_size, partial=False, collation_fn=collation_fn),
        ).with_epoch(epoch_steps // num_workers if num_workers > 0 else epoch_steps)

        self.data_loader = wds.WebLoader(self.dataset, num_workers=num_workers, **data_loader_kwargs)

    def wds_preprocess(self, sample):
        found_key, rewrite_key = '', ''
        for k, v in sample.items():  # iterate over all entries in the sample dict
            for akey in AUDIO_KEYS:
                if k.endswith(akey):
                    # to rename long/weird key with its simpler counterpart
                    found_key, rewrite_key = k, akey
                    break
            if '' != found_key:
                break
        if '' == found_key:  # got no audio!
            return None  # try returning None to tell WebDataset to skip this one

        audio, in_sr = sample[found_key]
        if in_sr != self.sample_rate:
            resample_tf = T.Resample(in_sr, self.sample_rate)
            audio = resample_tf(audio)

        if self.sample_size is not None:
            # Pad/crop and get the relative timestamp
            pad_crop = PadCrop_Normalized_T(
                self.sample_size, randomize=self.random_crop, sample_rate=self.sample_rate)
            audio, t_start, t_end, seconds_start, seconds_total, padding_mask = pad_crop(audio)
            sample["json"]["seconds_start"] = seconds_start
            sample["json"]["seconds_total"] = seconds_total
            sample["json"]["padding_mask"] = padding_mask
        else:
            t_start, t_end = 0, 1

        # Check if audio is length zero, initialize to a single zero if so
        if audio.shape[-1] == 0:
            audio = torch.zeros(1, 1)

        # Make the audio stereo and augment by randomly inverting phase
        augs = torch.nn.Sequential(
            Stereo() if self.force_channels == "stereo" else torch.nn.Identity(),
            Mono() if self.force_channels == "mono" else torch.nn.Identity(),
            PhaseFlipper() if self.augment_phase else torch.nn.Identity()
        )

        audio = augs(audio)

        sample["json"]["timestamps"] = (t_start, t_end)

        if "text" in sample["json"]:
            sample["json"]["prompt"] = sample["json"]["text"]

        # Check for custom metadata functions
        for dataset in self.datasets:
            if dataset.custom_metadata_fn is None:
                continue
            if dataset.path in sample["__url__"]:
                custom_metadata = dataset.custom_metadata_fn(sample["json"], audio)
                sample["json"].update(custom_metadata)

        if found_key != rewrite_key:  # rename long/weird key with its simpler counterpart
            del sample[found_key]

        sample["audio"] = audio

        # Add audio to the metadata as well for conditioning
        sample["json"]["audio"] = audio

        return sample
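
# Usage sketch (illustrative; the bucket path below is hypothetical):
#
#   wds_config = S3DatasetConfig(id="my_shards", s3_path="s3://my-bucket/shards/")
#   loader = WebDatasetDataLoader([wds_config], batch_size=16, sample_size=65536,
#                                 sample_rate=48000, num_workers=4).data_loader
#   for audio_batch, json_batch in loader:
#       ...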

def create_dataloader_from_config(dataset_config, batch_size, sample_size, sample_rate, audio_channels=2, num_workers=4, video_fps=5):
    dataset_type = dataset_config.get("dataset_type", None)

    assert dataset_type is not None, "Dataset type must be specified in dataset config"

    if audio_channels == 1:
        force_channels = "mono"
    else:
        force_channels = "stereo"

    if dataset_type == "audio_dir":
        audio_dir_configs = dataset_config.get("datasets", None)

        assert audio_dir_configs is not None, "Directory configuration must be specified in datasets[\"dataset\"]"

        configs = []

        for audio_dir_config in audio_dir_configs:
            audio_dir_path = audio_dir_config.get("path", None)
            assert audio_dir_path is not None, "Path must be set for local audio directory configuration"

            custom_metadata_fn = None
            custom_metadata_module_path = audio_dir_config.get("custom_metadata_module", None)

            if custom_metadata_module_path is not None:
                spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path)
                metadata_module = importlib.util.module_from_spec(spec)
                spec.loader.exec_module(metadata_module)

                custom_metadata_fn = metadata_module.get_custom_metadata

            configs.append(
                LocalDatasetConfig(
                    id=audio_dir_config["id"],
                    path=audio_dir_path,
                    custom_metadata_fn=custom_metadata_fn,
                    video_fps=video_fps
                )
            )

        train_set = SampleDataset(
            configs,
            sample_rate=sample_rate,
            sample_size=sample_size,
            random_crop=dataset_config.get("random_crop", True),
            force_channels=force_channels,
            video_fps=video_fps
        )

        return torch.utils.data.DataLoader(train_set, batch_size, shuffle=True,
                                           num_workers=num_workers, persistent_workers=True, pin_memory=False, drop_last=True, collate_fn=collation_fn)

    elif dataset_type in ["s3", "wds"]:  # Support "s3" type for backwards compatibility
        wds_configs = []

        for wds_config in dataset_config["datasets"]:
            custom_metadata_fn = None
            custom_metadata_module_path = wds_config.get("custom_metadata_module", None)

            if custom_metadata_module_path is not None:
                spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path)
                metadata_module = importlib.util.module_from_spec(spec)
                spec.loader.exec_module(metadata_module)

                custom_metadata_fn = metadata_module.get_custom_metadata

            if "s3_path" in wds_config:
                wds_configs.append(
                    S3DatasetConfig(
                        id=wds_config["id"],
                        s3_path=wds_config["s3_path"],
                        custom_metadata_fn=custom_metadata_fn,
                        profile=wds_config.get("profile", None),
                    )
                )
            elif "path" in wds_config:
                wds_configs.append(
                    LocalWebDatasetConfig(
                        id=wds_config["id"],
                        path=wds_config["path"],
                        custom_metadata_fn=custom_metadata_fn
                    )
                )

        return WebDatasetDataLoader(
            wds_configs,
            sample_rate=sample_rate,
            sample_size=sample_size,
            batch_size=batch_size,
            random_crop=dataset_config.get("random_crop", True),
            num_workers=num_workers,
            persistent_workers=True,
            force_channels=force_channels,
            epoch_steps=dataset_config.get("epoch_steps", 2000)
        ).data_loader
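
# Illustrative dataset_config for the "audio_dir" branch (the id/path values are
# hypothetical; "custom_metadata_module" and "random_crop" are optional):
#
#   dataset_config = {
#       "dataset_type": "audio_dir",
#       "datasets": [{"id": "local_audio", "path": "/data/audio_or_metadata.jsonl"}],
#       "random_crop": True,
#   }
#   train_dl = create_dataloader_from_config(dataset_config, batch_size=16,
#                                            sample_size=65536, sample_rate=48000)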

def create_dataloader_from_config_valid(dataset_config, batch_size, sample_size, sample_rate, audio_channels=2, num_workers=4):
    dataset_type = dataset_config.get("dataset_type", None)

    assert dataset_type is not None, "Dataset type must be specified in dataset config"

    if audio_channels == 1:
        force_channels = "mono"
    else:
        force_channels = "stereo"

    if dataset_type == "audio_dir":
        audio_dir_configs = dataset_config.get("datasets", None)

        assert audio_dir_configs is not None, "Directory configuration must be specified in datasets[\"dataset\"]"

        configs = []

        for audio_dir_config in audio_dir_configs:
            audio_dir_path = audio_dir_config.get("path", None)
            assert audio_dir_path is not None, "Path must be set for local audio directory configuration"

            custom_metadata_fn = None
            custom_metadata_module_path = audio_dir_config.get("custom_metadata_module", None)

            if custom_metadata_module_path is not None:
                spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path)
                metadata_module = importlib.util.module_from_spec(spec)
                spec.loader.exec_module(metadata_module)

                custom_metadata_fn = metadata_module.get_custom_metadata

            configs.append(
                LocalDatasetConfig(
                    id=audio_dir_config["id"],
                    path=audio_dir_path,
                    custom_metadata_fn=custom_metadata_fn
                )
            )

        valid_set = SampleDataset(
            configs,
            sample_rate=sample_rate,
            sample_size=sample_size,
            random_crop=dataset_config.get("random_crop", True),
            force_channels=force_channels
        )

        return torch.utils.data.DataLoader(valid_set, batch_size, shuffle=False,
                                           num_workers=num_workers, persistent_workers=True, pin_memory=True, drop_last=True, collate_fn=collation_fn)

    elif dataset_type in ["s3", "wds"]:  # Support "s3" type for backwards compatibility
        wds_configs = []

        for wds_config in dataset_config["datasets"]:
            custom_metadata_fn = None
            custom_metadata_module_path = wds_config.get("custom_metadata_module", None)

            if custom_metadata_module_path is not None:
                spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path)
                metadata_module = importlib.util.module_from_spec(spec)
                spec.loader.exec_module(metadata_module)

                custom_metadata_fn = metadata_module.get_custom_metadata

            if "s3_path" in wds_config:
                wds_configs.append(
                    S3DatasetConfig(
                        id=wds_config["id"],
                        s3_path=wds_config["s3_path"],
                        custom_metadata_fn=custom_metadata_fn,
                        profile=wds_config.get("profile", None),
                    )
                )
            elif "path" in wds_config:
                wds_configs.append(
                    LocalWebDatasetConfig(
                        id=wds_config["id"],
                        path=wds_config["path"],
                        custom_metadata_fn=custom_metadata_fn
                    )
                )

        return WebDatasetDataLoader(
            wds_configs,
            sample_rate=sample_rate,
            sample_size=sample_size,
            batch_size=batch_size,
            random_crop=dataset_config.get("random_crop", True),
            num_workers=num_workers,
            persistent_workers=True,
            force_channels=force_channels,
            epoch_steps=dataset_config.get("epoch_steps", 2000)
        ).data_loader
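
# Minimal end-to-end sketch (illustrative; the file path is hypothetical). The
# validation loader mirrors the training one, but with shuffling disabled:
#
#   if __name__ == "__main__":
#       cfg = {"dataset_type": "audio_dir",
#              "datasets": [{"id": "valid_audio", "path": "metadata/valid.jsonl"}]}
#       valid_dl = create_dataloader_from_config_valid(cfg, batch_size=8,
#                                                      sample_size=65536, sample_rate=48000)
#       audio, info = next(iter(valid_dl))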