Spaces:
Configuration error
Configuration error
| import logging | |
| import os | |
| from pathlib import Path | |
| from typing import Union | |
| import open_clip | |
| import pandas as pd | |
| import torch | |
| import torchaudio | |
| from torch.utils.data.dataset import Dataset | |
| log = logging.getLogger() | |
class WavTextClipsDataset(Dataset):
    """Fixed-length mono audio clips paired with a caption and its text tokens.

    The dataset is described by two tab-separated files:

    * ``captions_tsv`` must provide the columns ``id`` and ``caption``.
    * ``clips_tsv`` must provide the columns ``id``, ``name``,
      ``start_sample`` and ``end_sample``. ``name`` is the stem of the audio
      file inside ``root`` and is also the key used to find the caption (it
      is matched against the ``id`` column of ``captions_tsv``).

    Clips whose ``name`` has no caption are skipped with a warning.

    Args:
        root: Directory containing ``<name>.flac`` or ``<name>.wav`` files.
        captions_tsv: Path to the caption table.
        clips_tsv: Path to the clip table.
        sample_rate: Sample rate of the returned waveforms; audio stored at
            any other rate is resampled.
        num_samples: Exact length, in samples at ``sample_rate``, of every
            returned waveform.
        normalize_audio: If true, scale each file so that its peak absolute
            amplitude is 0.95 before the clip is cut out.
        reject_silent: If true, items whose file has a peak absolute
            amplitude below 1e-6 are returned as ``None``.
        tokenizer_id: Model name handed to ``open_clip.get_tokenizer``.

    Raises:
        FileNotFoundError: If ``root`` does not exist.
        NotADirectoryError: If ``root`` exists but is not a directory.
    """

    def __init__(
        self,
        root: Union[str, Path],
        *,
        captions_tsv: Union[str, Path],
        clips_tsv: Union[str, Path],
        sample_rate: int,
        num_samples: int,
        normalize_audio: bool = False,
        reject_silent: bool = False,
        tokenizer_id: str = 'ViT-H-14-378-quickgelu',
    ):
        self.root = Path(root)
        self.sample_rate = sample_rate
        self.num_samples = num_samples
        self.normalize_audio = normalize_audio
        self.reject_silent = reject_silent
        self.tokenizer = open_clip.get_tokenizer(tokenizer_id)

        # Fail fast on a bad root: __getitem__ turns per-file errors into
        # None, so a wrong directory would otherwise only show up as a
        # dataset that silently yields nothing.
        if not self.root.exists():
            raise FileNotFoundError(f'Audio root does not exist: {self.root}')
        if not self.root.is_dir():
            raise NotADirectoryError(f'Audio root is not a directory: {self.root}')

        # read the caption tsv
        caption_records = pd.read_csv(captions_tsv, sep='\t', dtype={
            'id': str
        }).to_dict('records')
        self.captions = {record['id']: record['caption'] for record in caption_records}

        # read the clip tsv
        clip_records = pd.read_csv(clips_tsv, sep='\t', dtype={
            'id': str,
            'name': str
        }).to_dict('records')
        self.clips = []
        for record in clip_records:
            name = record['name']
            if name not in self.captions:
                log.warning(f'Audio {name} not found in {captions_tsv}')
                continue
            record['caption'] = self.captions[name]
            self.clips.append(record)

        log.info(f'Found {len(self.clips)} audio files in {self.root}')

        # Resample transforms, created lazily and keyed by source sample rate.
        self.resampler = {}

    def _resample(self, waveform: torch.Tensor, source_rate: int) -> torch.Tensor:
        """Convert ``waveform`` from ``source_rate`` to ``self.sample_rate``."""
        resampler = self.resampler.get(source_rate)
        if resampler is None:
            # https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
            resampler = torchaudio.transforms.Resample(
                source_rate,
                self.sample_rate,
                lowpass_filter_width=64,
                rolloff=0.9475937167399596,
                resampling_method='sinc_interp_kaiser',
                beta=14.769656459379492,
            )
            self.resampler[source_rate] = resampler
        return resampler(waveform)

    def __getitem__(self, idx: int) -> Union[dict, None]:
        """Load, crop and tokenize the clip at position ``idx``.

        Returns:
            A dict with the keys ``waveform`` (1-D tensor holding exactly
            ``num_samples`` samples), ``id``, ``caption`` and ``tokens``, or
            ``None`` when the clip is rejected as silent or cannot be
            produced (missing file, decoding error, too short, ...). Such
            failures are logged instead of raised.

        Raises:
            IndexError: If ``idx`` is outside the range of known clips.
        """
        # Looked up outside the try block so that a bad index surfaces as
        # IndexError instead of being swallowed by the handler below.
        clip = self.clips[idx]
        audio_name = clip['name']

        # Bound before the try block so the error handler can always name a path.
        audio_path = self.root / f'{audio_name}.flac'
        try:
            if not audio_path.exists():
                audio_path = self.root / f'{audio_name}.wav'
                if not audio_path.exists():
                    raise FileNotFoundError(f'no .flac or .wav file for {audio_name}')

            waveform, source_rate = torchaudio.load(audio_path)
            waveform = waveform.mean(dim=0)  # mono

            # The peak is measured on the whole file, not only on the clip.
            peak = waveform.abs().max()
            if self.reject_silent and peak < 1e-6:
                log.warning('Rejecting silent audio')
                return None
            # An all-zero file would turn into NaNs (0 / 0); leave it untouched.
            if self.normalize_audio and peak > 0:
                waveform = waveform / peak * 0.95

            # start_sample / end_sample index the file at its stored sample
            # rate: the crop happens before any resampling.
            waveform = waveform[clip['start_sample']:clip['end_sample']]

            if source_rate != self.sample_rate:
                waveform = self._resample(waveform, source_rate)

            if waveform.shape[0] < self.num_samples:
                raise ValueError('Audio is too short')
            waveform = waveform[:self.num_samples]

            caption = clip['caption']
            # The tokenizer is called on a one-element batch; keep its only row.
            tokens = self.tokenizer([caption])[0]

            return {
                'waveform': waveform,
                'id': clip['id'],
                'caption': caption,
                'tokens': tokens,
            }
        except Exception as e:
            # Best effort by design: a broken clip yields None rather than
            # aborting the caller's iteration.
            log.error(f'Error reading {audio_path}: {e}')
            return None

    def __len__(self):
        """Return the number of clips that have a caption."""
        return len(self.clips)