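"""Dataset classes for the ImpactSet / CountixAV data: randomly sampled
waveform and mel-spectrogram clips, optionally conditioned on video frames."""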
import json
import os
import matplotlib.pyplot as plt
import torch
from torchvision import transforms
import numpy as np
from tqdm import tqdm
import random
from random import sample
import torchaudio
import logging
from glob import glob
import sys
import soundfile
import copy
import csv
import noisereduce as nr
from PIL import Image

sys.path.insert(0, '.')  # nopep8
from train import instantiate_from_config
from foleycrafter.models.specvqgan.data.transforms import *

torchaudio.set_audio_backend("sox_io")
logger = logging.getLogger(f'main.{__name__}')

SR = 22050
FPS = 15
MAX_SAMPLE_ITER = 10
def non_negative(x): return int(np.round(max(0, x), 0))


def rms(x): return np.sqrt(np.mean(x**2))


def get_GH_data_identifier(video_name, start_idx, split='_'):
    if isinstance(start_idx, str):
        return video_name + split + start_idx
    elif isinstance(start_idx, int):
        return video_name + split + str(start_idx)
    else:
        raise NotImplementedError


def draw_spec(spec, dest, cmap='magma'):
    plt.imshow(spec, cmap=cmap, origin='lower')
    plt.axis('off')
    plt.savefig(dest, bbox_inches='tight', pad_inches=0., dpi=300)
    plt.close()


def convert_to_decibel(arr):
    ref = 1
    # add the epsilon after taking the magnitude so the log argument is
    # always strictly positive
    return 20 * np.log10((np.abs(arr) + 1e-4) / ref)


class ResampleFrames(object):
    def __init__(self, feat_sample_size, times_to_repeat_after_resample=None):
        self.feat_sample_size = feat_sample_size
        self.times_to_repeat_after_resample = times_to_repeat_after_resample

    def __call__(self, item):
        feat_len = item['feature'].shape[0]

        ## resample
        assert feat_len >= self.feat_sample_size
        # evenly spaced points (abcdefghkl -> aoooofoooo)
        idx = np.linspace(0, feat_len, self.feat_sample_size, dtype=int, endpoint=False)
        # xoooo xoooo -> ooxoo ooxoo
        shift = feat_len // (self.feat_sample_size + 1)
        idx = idx + shift

        ## repeat after resampling (abc -> aaaabbbbcccc)
        if self.times_to_repeat_after_resample is not None and self.times_to_repeat_after_resample > 1:
            idx = np.repeat(idx, self.times_to_repeat_after_resample)

        item['feature'] = item['feature'][idx, :]
        return item
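

# ImpactSetWave returns raw L-second waveform clips. For each video it draws
# random start frames for up to MAX_SAMPLE_ITER attempts, keeping the loudest
# clip, and stops early once the mean level reaches -40 dB.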
class ImpactSetWave(torch.utils.data.Dataset):
    def __init__(self, split, random_crop, mel_num, spec_crop_len,
                 L=2.0, denoise=False, splits_path='./data',
                 data_path='data/ImpactSet/impactset-proccess-resize'):
        super().__init__()
        self.split = split
        self.splits_path = splits_path
        self.data_path = data_path
        self.L = L
        self.denoise = denoise

        video_name_split_path = os.path.join(splits_path, f'countixAV_{split}.json')
        if not os.path.exists(video_name_split_path):
            self.make_split_files()
        video_name = json.load(open(video_name_split_path, 'r'))
        self.video_frame_cnt = {v: len(os.listdir(os.path.join(self.data_path, v, 'frames'))) for v in video_name}
        self.left_over = int(FPS * L + 1)
        self.video_audio_path = {v: os.path.join(self.data_path, v, f'audio/{v}_resampled.wav') for v in video_name}
        self.dataset = video_name

        self.wav_transforms = transforms.Compose([
            MakeMono(),
            Padding(target_len=int(SR * self.L)),
        ])
        self.spec_transforms = CropImage([mel_num, spec_crop_len], random_crop)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = {}
        video = self.dataset[idx]
        available_frame_idx = self.video_frame_cnt[video] - self.left_over
        wav = None
        max_db = -np.inf
        wave_path = ''
        cur_wave_path = self.video_audio_path[video]
        if self.denoise:
            cur_wave_path = cur_wave_path.replace('.wav', '_denoised.wav')
        for _ in range(MAX_SAMPLE_ITER):
            start_idx = torch.randint(0, available_frame_idx, (1,)).tolist()[0]
            # target
            start_t = (start_idx + 0.5) / FPS
            start_audio_idx = non_negative(start_t * SR)
            cur_wav, _ = soundfile.read(cur_wave_path, frames=int(SR * self.L), start=start_audio_idx)
            decibel = convert_to_decibel(cur_wav)
            if float(np.mean(decibel)) > max_db:
                wav = cur_wav
                wave_path = cur_wave_path
                max_db = float(np.mean(decibel))
            if max_db >= -40:
                break

        wav = self.wav_transforms(wav)

        item['image'] = wav  # (44100,)
        # item['wav'] = wav
        item['file_path_wav_'] = wave_path

        item['label'] = 'None'
        item['target'] = 'None'
        return item

    def make_split_files(self):
        raise NotImplementedError


class ImpactSetWaveTrain(ImpactSetWave):
    def __init__(self, specs_dataset_cfg):
        super().__init__('train', **specs_dataset_cfg)


class ImpactSetWaveValidation(ImpactSetWave):
    def __init__(self, specs_dataset_cfg):
        super().__init__('val', **specs_dataset_cfg)


class ImpactSetWaveTest(ImpactSetWave):
    def __init__(self, specs_dataset_cfg):
        super().__init__('test', **specs_dataset_cfg)
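

# ImpactSetSpec is the mel-spectrogram counterpart of ImpactSetWave: the same
# loudest-of-MAX_SAMPLE_ITER sampling, but each clip is converted to an 80-bin
# log-mel spectrogram normalized to [0, 1] and then rescaled to [-1, 1].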
class ImpactSetSpec(torch.utils.data.Dataset):
    def __init__(self, split, random_crop, mel_num, spec_crop_len,
                 L=2.0, denoise=False, splits_path='./data',
                 data_path='data/ImpactSet/impactset-proccess-resize'):
        super().__init__()
        self.split = split
        self.splits_path = splits_path
        self.data_path = data_path
        self.L = L
        self.denoise = denoise

        video_name_split_path = os.path.join(splits_path, f'countixAV_{split}.json')
        if not os.path.exists(video_name_split_path):
            self.make_split_files()
        video_name = json.load(open(video_name_split_path, 'r'))
        self.video_frame_cnt = {v: len(os.listdir(os.path.join(self.data_path, v, 'frames'))) for v in video_name}
        self.left_over = int(FPS * L + 1)
        self.video_audio_path = {v: os.path.join(self.data_path, v, f'audio/{v}_resampled.wav') for v in video_name}
        self.dataset = video_name

        self.wav_transforms = transforms.Compose([
            MakeMono(),
            SpectrogramTorchAudio(nfft=1024, hoplen=1024//4, spec_power=1),
            MelScaleTorchAudio(sr=SR, stft=513, fmin=125, fmax=7600, nmels=80),
            LowerThresh(1e-5),
            Log10(),
            Multiply(20),
            Subtract(20),
            Add(100),
            Divide(100),
            Clip(0, 1.0),
            TrimSpec(173),
        ])
        self.spec_transforms = CropImage([mel_num, spec_crop_len], random_crop)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = {}
        video = self.dataset[idx]
        available_frame_idx = self.video_frame_cnt[video] - self.left_over
        wav = None
        spec = None
        max_rms = -np.inf
        wave_path = ''
        cur_wave_path = self.video_audio_path[video]
        if self.denoise:
            cur_wave_path = cur_wave_path.replace('.wav', '_denoised.wav')
        for _ in range(MAX_SAMPLE_ITER):
            start_idx = torch.randint(0, available_frame_idx, (1,)).tolist()[0]
            # target
            start_t = (start_idx + 0.5) / FPS
            start_audio_idx = non_negative(start_t * SR)
            cur_wav, _ = soundfile.read(cur_wave_path, frames=int(SR * self.L), start=start_audio_idx)

            if self.wav_transforms is not None:
                spec_tensor = self.wav_transforms(torch.tensor(cur_wav).float())
                cur_spec = spec_tensor.numpy()
            # zero-pad if there are not enough spectrogram time steps
            if cur_spec.shape[1] < 173:
                pad = np.zeros((80, 173), dtype=cur_spec.dtype)
                pad[:, :cur_spec.shape[1]] = cur_spec
                cur_spec = pad

            rms_val = rms(cur_spec)
            if rms_val > max_rms:
                wav = cur_wav
                spec = cur_spec
                wave_path = cur_wave_path
                max_rms = rms_val
            if max_rms >= 0.1:
                break

        item['image'] = 2 * spec - 1  # (80, 173)
        # item['wav'] = wav
        item['file_path_wav_'] = wave_path

        item['label'] = 'None'
        item['target'] = 'None'

        if self.spec_transforms is not None:
            item = self.spec_transforms(item)
        return item

    def make_split_files(self):
        raise NotImplementedError


class ImpactSetSpecTrain(ImpactSetSpec):
    def __init__(self, specs_dataset_cfg):
        super().__init__('train', **specs_dataset_cfg)


class ImpactSetSpecValidation(ImpactSetSpec):
    def __init__(self, specs_dataset_cfg):
        super().__init__('val', **specs_dataset_cfg)


class ImpactSetSpecTest(ImpactSetSpec):
    def __init__(self, specs_dataset_cfg):
        super().__init__('test', **specs_dataset_cfg)
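

# ImpactSetWaveTestTime evaluates on a fixed list of held-out raw recordings
# (stock videos, YouTube impact clips, and self-recorded audio) rather than the
# CountixAV splits; clips can optionally be denoised with noisereduce first.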
class ImpactSetWaveTestTime(torch.utils.data.Dataset):
    def __init__(self, split, random_crop, mel_num, spec_crop_len,
                 L=2.0, denoise=False, splits_path='./data',
                 data_path='data/ImpactSet/impactset-proccess-resize'):
        super().__init__()
        self.split = split
        self.splits_path = splits_path
        self.data_path = data_path
        self.L = L
        self.denoise = denoise

        self.video_list = glob('data/ImpactSet/RawVideos/StockVideo_sound/*.wav') + [
            'data/ImpactSet/RawVideos/YouTube-impact-ccl/1_ckbCU5aQs/1_ckbCU5aQs_0013_0016_resize.wav',
            'data/ImpactSet/RawVideos/YouTube-impact-ccl/GFmuVBiwz6k/GFmuVBiwz6k_0034_0054_resize.wav',
            'data/ImpactSet/RawVideos/YouTube-impact-ccl/OsPcY316h1M/OsPcY316h1M_0000_0005_resize.wav',
            'data/ImpactSet/RawVideos/YouTube-impact-ccl/SExIpBIBj_k/SExIpBIBj_k_0009_0019_resize.wav',
            'data/ImpactSet/RawVideos/YouTube-impact-ccl/S6TkbV4B4QI/S6TkbV4B4QI_0028_0036_resize.wav',
            'data/ImpactSet/RawVideos/YouTube-impact-ccl/2Ld24pPIn3k/2Ld24pPIn3k_0005_0011_resize.wav',
            'data/ImpactSet/RawVideos/YouTube-impact-ccl/6d1YS7fdBK4/6d1YS7fdBK4_0007_0019_resize.wav',
            'data/ImpactSet/RawVideos/YouTube-impact-ccl/JnBsmJgEkiw/JnBsmJgEkiw_0008_0016_resize.wav',
            'data/ImpactSet/RawVideos/YouTube-impact-ccl/xcUyiXt0gjo/xcUyiXt0gjo_0015_0021_resize.wav',
            'data/ImpactSet/RawVideos/YouTube-impact-ccl/4DRFJnZjpMM/4DRFJnZjpMM_0000_0010_resize.wav'
        ] + glob('data/ImpactSet/RawVideos/self_recorded/*_resize.wav')

        self.wav_transforms = transforms.Compose([
            MakeMono(),
            SpectrogramTorchAudio(nfft=1024, hoplen=1024//4, spec_power=1),
            MelScaleTorchAudio(sr=SR, stft=513, fmin=125, fmax=7600, nmels=80),
            LowerThresh(1e-5),
            Log10(),
            Multiply(20),
            Subtract(20),
            Add(100),
            Divide(100),
            Clip(0, 1.0),
            TrimSpec(173),
        ])
        self.spec_transforms = CropImage([mel_num, spec_crop_len], random_crop)

    def __len__(self):
        return len(self.video_list)

    def __getitem__(self, idx):
        item = {}

        wave_path = self.video_list[idx]
        wav, _ = soundfile.read(wave_path)
        # clamp the upper bound at 0 so clips shorter than L seconds do not
        # produce a negative range for randint
        start_idx = random.randint(0, max(0, min(4, wav.shape[0] - int(SR * self.L))))
        wav = wav[start_idx:start_idx+int(SR * self.L)]
        if self.denoise:
            if len(wav.shape) == 1:
                wav = wav[None, :]
            wav = nr.reduce_noise(y=wav, sr=SR, n_fft=1024, hop_length=1024//4)
            wav = wav.squeeze()

        if self.wav_transforms is not None:
            spec_tensor = self.wav_transforms(torch.tensor(wav).float())
            spec = spec_tensor.numpy()
        if spec.shape[1] < 173:
            pad = np.zeros((80, 173), dtype=spec.dtype)
            pad[:, :spec.shape[1]] = spec
            spec = pad

        item['image'] = 2 * spec - 1  # (80, 173)
        # item['wav'] = wav
        item['file_path_wav_'] = wave_path

        item['label'] = 'None'
        item['target'] = 'None'

        if self.spec_transforms is not None:
            item = self.spec_transforms(item)
        return item

    def make_split_files(self):
        raise NotImplementedError


class ImpactSetWaveTestTimeTrain(ImpactSetWaveTestTime):
    def __init__(self, specs_dataset_cfg):
        super().__init__('train', **specs_dataset_cfg)


class ImpactSetWaveTestTimeValidation(ImpactSetWaveTestTime):
    def __init__(self, specs_dataset_cfg):
        super().__init__('val', **specs_dataset_cfg)


class ImpactSetWaveTestTimeTest(ImpactSetWaveTestTime):
    def __init__(self, specs_dataset_cfg):
        super().__init__('test', **specs_dataset_cfg)
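

# ImpactSetWaveWithSilent draws a single random clip per video with no loudness
# check, so silent segments are kept in the training distribution.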
class ImpactSetWaveWithSilent(torch.utils.data.Dataset):
    def __init__(self, split, random_crop, mel_num, spec_crop_len,
                 L=2.0, denoise=False, splits_path='./data',
                 data_path='data/ImpactSet/impactset-proccess-resize'):
        super().__init__()
        self.split = split
        self.splits_path = splits_path
        self.data_path = data_path
        self.L = L
        self.denoise = denoise

        video_name_split_path = os.path.join(splits_path, f'countixAV_{split}.json')
        if not os.path.exists(video_name_split_path):
            self.make_split_files()
        video_name = json.load(open(video_name_split_path, 'r'))
        self.video_frame_cnt = {v: len(os.listdir(os.path.join(self.data_path, v, 'frames'))) for v in video_name}
        self.left_over = int(FPS * L + 1)
        self.video_audio_path = {v: os.path.join(self.data_path, v, f'audio/{v}_resampled.wav') for v in video_name}
        self.dataset = video_name

        self.wav_transforms = transforms.Compose([
            MakeMono(),
            Padding(target_len=int(SR * self.L)),
        ])
        self.spec_transforms = CropImage([mel_num, spec_crop_len], random_crop)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = {}
        video = self.dataset[idx]
        available_frame_idx = self.video_frame_cnt[video] - self.left_over
        wave_path = self.video_audio_path[video]
        if self.denoise:
            wave_path = wave_path.replace('.wav', '_denoised.wav')
        start_idx = torch.randint(0, available_frame_idx, (1,)).tolist()[0]
        # target
        start_t = (start_idx + 0.5) / FPS
        start_audio_idx = non_negative(start_t * SR)
        wav, _ = soundfile.read(wave_path, frames=int(SR * self.L), start=start_audio_idx)
        wav = self.wav_transforms(wav)

        item['image'] = wav  # (44100,)
        # item['wav'] = wav
        item['file_path_wav_'] = wave_path

        item['label'] = 'None'
        item['target'] = 'None'
        return item

    def make_split_files(self):
        raise NotImplementedError


class ImpactSetWaveWithSilentTrain(ImpactSetWaveWithSilent):
    def __init__(self, specs_dataset_cfg):
        super().__init__('train', **specs_dataset_cfg)


class ImpactSetWaveWithSilentValidation(ImpactSetWaveWithSilent):
    def __init__(self, specs_dataset_cfg):
        super().__init__('val', **specs_dataset_cfg)


class ImpactSetWaveWithSilentTest(ImpactSetWaveWithSilent):
    def __init__(self, specs_dataset_cfg):
        super().__init__('test', **specs_dataset_cfg)
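

# ImpactSetWaveCondOnImage pairs each target clip (waveform + frames) with a
# conditioning clip. With probability p_outside_cond the condition comes from
# another video of the same action class; otherwise it is a non-overlapping
# segment of the same video.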
class ImpactSetWaveCondOnImage(torch.utils.data.Dataset):
    def __init__(self, split,
                 L=2.0, frame_transforms=None, denoise=False, splits_path='./data',
                 data_path='data/ImpactSet/impactset-proccess-resize',
                 p_outside_cond=0.):
        super().__init__()
        self.split = split
        self.splits_path = splits_path
        self.frame_transforms = frame_transforms
        self.data_path = data_path
        self.L = L
        self.denoise = denoise
        self.p_outside_cond = torch.tensor(p_outside_cond)

        video_name_split_path = os.path.join(splits_path, f'countixAV_{split}.json')
        if not os.path.exists(video_name_split_path):
            self.make_split_files()
        video_name = json.load(open(video_name_split_path, 'r'))
        self.video_frame_cnt = {v: len(os.listdir(os.path.join(self.data_path, v, 'frames'))) for v in video_name}
        self.left_over = int(FPS * L + 1)
        # drop videos too short to sample non-overlapping target and condition clips
        for v, cnt in self.video_frame_cnt.items():
            if cnt - (3 * self.left_over) <= 0:
                video_name.remove(v)
        self.video_audio_path = {v: os.path.join(self.data_path, v, f'audio/{v}_resampled.wav') for v in video_name}
        self.dataset = video_name

        video_timing_split_path = os.path.join(splits_path, f'countixAV_{split}_timing.json')
        self.video_timing = json.load(open(video_timing_split_path, 'r'))
        self.video_timing = {v: [int(float(t) * FPS) for t in ts] for v, ts in self.video_timing.items()}

        if split != 'test':
            video_class_path = os.path.join(splits_path, f'countixAV_{split}_class.json')
            if not os.path.exists(video_class_path):
                self.make_video_class()
            self.video_class = json.load(open(video_class_path, 'r'))
            self.class2video = {}
            for v, c in self.video_class.items():
                if c not in self.class2video.keys():
                    self.class2video[c] = []
                self.class2video[c].append(v)

        self.wav_transforms = transforms.Compose([
            MakeMono(),
            Padding(target_len=int(SR * self.L)),
        ])
        if self.frame_transforms is None:
            self.frame_transforms = transforms.Compose([
                Resize3D(128),
                RandomResizedCrop3D(112, scale=(0.5, 1.0)),
                RandomHorizontalFlip3D(),
                ColorJitter3D(brightness=0.1, saturation=0.1),
                ToTensor3D(),
                Normalize3D(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225]),
            ])

    def make_video_class(self):
        meta_path = f'data/ImpactSet/data-info/CountixAV_{self.split}.csv'
        video_class = {}
        with open(meta_path, 'r') as f:
            reader = csv.reader(f)
            for i, row in enumerate(reader):
                if i == 0:
                    continue
                vid, k_st, k_et = row[:3]
                video_name = f'{vid}_{int(k_st):0>4d}_{int(k_et):0>4d}'
                if video_name not in self.dataset:
                    continue
                video_class[video_name] = row[-1]
        with open(os.path.join(self.splits_path, f'countixAV_{self.split}_class.json'), 'w') as f:
            json.dump(video_class, f)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = {}
        video = self.dataset[idx]
        available_frame_idx = self.video_frame_cnt[video] - self.left_over
        rep_start_idx, rep_end_idx = self.video_timing[video]
        rep_end_idx = min(available_frame_idx, rep_end_idx)
        if available_frame_idx <= rep_start_idx + self.L * FPS:
            idx_set = list(range(0, available_frame_idx))
        else:
            idx_set = list(range(rep_start_idx, rep_end_idx))
        start_idx = sample(idx_set, k=1)[0]

        wave_path = self.video_audio_path[video]
        if self.denoise:
            wave_path = wave_path.replace('.wav', '_denoised.wav')

        # target
        start_t = (start_idx + 0.5) / FPS
        end_idx = non_negative(start_idx + FPS * self.L)
        start_audio_idx = non_negative(start_t * SR)
        wav, sr = soundfile.read(wave_path, frames=int(SR * self.L), start=start_audio_idx)
        assert sr == SR
        wav = self.wav_transforms(wav)
        frame_path = os.path.join(self.data_path, video, 'frames')
        frames = [Image.open(os.path.join(
            frame_path, f'frame{i+1:0>6d}.jpg')).convert('RGB') for i in
            range(start_idx, end_idx)]

        if torch.all(torch.bernoulli(self.p_outside_cond) == 1.) and self.split != 'test':
            # condition on a clip from another video of the same class
            cur_class = self.video_class[video]
            tmp_video = copy.copy(self.class2video[cur_class])
            # if only one video is in the class, condition on the video itself
            if len(tmp_video) > 1:
                tmp_video.remove(video)
            cond_video = sample(tmp_video, k=1)[0]
            cond_available_frame_idx = self.video_frame_cnt[cond_video] - self.left_over
            cond_start_idx = torch.randint(0, cond_available_frame_idx, (1,)).tolist()[0]
        else:
            cond_video = video
            idx_set = list(range(0, start_idx)) + list(range(end_idx, available_frame_idx))
            cond_start_idx = random.sample(idx_set, k=1)[0]

        cond_end_idx = non_negative(cond_start_idx + FPS * self.L)
        cond_start_t = (cond_start_idx + 0.5) / FPS
        cond_audio_idx = non_negative(cond_start_t * SR)
        cond_frame_path = os.path.join(self.data_path, cond_video, 'frames')
        cond_wave_path = self.video_audio_path[cond_video]

        cond_frames = [Image.open(os.path.join(
            cond_frame_path, f'frame{i+1:0>6d}.jpg')).convert('RGB') for i in
            range(cond_start_idx, cond_end_idx)]
        cond_wav, sr = soundfile.read(cond_wave_path, frames=int(SR * self.L), start=cond_audio_idx)
        assert sr == SR
        cond_wav = self.wav_transforms(cond_wav)

        item['image'] = wav  # (44100,)
        item['cond_image'] = cond_wav  # (44100,)
        item['file_path_wav_'] = wave_path
        item['file_path_cond_wav_'] = cond_wave_path

        if self.frame_transforms is not None:
            cond_frames = self.frame_transforms(cond_frames)
            frames = self.frame_transforms(frames)

        item['feature'] = np.stack(cond_frames + frames, axis=0)  # (30 * L, 112, 112, 3)
        item['file_path_feats_'] = (frame_path, start_idx)
        item['file_path_cond_feats_'] = (cond_frame_path, cond_start_idx)

        item['label'] = 'None'
        item['target'] = 'None'
        return item

    def make_split_files(self):
        raise NotImplementedError


class ImpactSetWaveCondOnImageTrain(ImpactSetWaveCondOnImage):
    def __init__(self, dataset_cfg):
        train_transforms = transforms.Compose([
            Resize3D(128),
            RandomResizedCrop3D(112, scale=(0.5, 1.0)),
            RandomHorizontalFlip3D(),
            ColorJitter3D(brightness=0.4, saturation=0.4, contrast=0.2, hue=0.1),
            ToTensor3D(),
            Normalize3D(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])
        super().__init__('train', frame_transforms=train_transforms, **dataset_cfg)


class ImpactSetWaveCondOnImageValidation(ImpactSetWaveCondOnImage):
    def __init__(self, dataset_cfg):
        valid_transforms = transforms.Compose([
            Resize3D(128),
            CenterCrop3D(112),
            ToTensor3D(),
            Normalize3D(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])
        super().__init__('val', frame_transforms=valid_transforms, **dataset_cfg)


class ImpactSetWaveCondOnImageTest(ImpactSetWaveCondOnImage):
    def __init__(self, dataset_cfg):
        test_transforms = transforms.Compose([
            Resize3D(128),
            CenterCrop3D(112),
            ToTensor3D(),
            Normalize3D(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])
        super().__init__('test', frame_transforms=test_transforms, **dataset_cfg)
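

# ImpactSetCleanWaveCondOnImage restricts sampling to impact timings predicted
# by a separate model (countixAV_{split}_timing_processed_0.20.json), yielding
# one dataset entry per (video, predicted onset time) pair.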
class ImpactSetCleanWaveCondOnImage(ImpactSetWaveCondOnImage):
    def __init__(self, split, L=2.0, frame_transforms=None, denoise=False, splits_path='./data',
                 data_path='data/ImpactSet/impactset-proccess-resize', p_outside_cond=0.):
        super().__init__(split, L, frame_transforms, denoise, splits_path, data_path, p_outside_cond)
        pred_timing_path = f'data/countixAV_{split}_timing_processed_0.20.json'
        assert os.path.exists(pred_timing_path)
        self.pred_timing = json.load(open(pred_timing_path, 'r'))

        self.dataset = []
        for v, ts in self.pred_timing.items():
            if v in self.video_audio_path.keys():
                for t in ts:
                    self.dataset.append([v, t])

    def __getitem__(self, idx):
        item = {}
        video, start_t = self.dataset[idx]
        available_frame_idx = self.video_frame_cnt[video] - self.left_over
        available_timing = (available_frame_idx + 0.5) / FPS
        start_t = float(start_t)
        start_t = min(start_t, available_timing)
        start_idx = non_negative(start_t * FPS - 0.5)

        wave_path = self.video_audio_path[video]
        if self.denoise:
            wave_path = wave_path.replace('.wav', '_denoised.wav')

        # target
        end_idx = non_negative(start_idx + FPS * self.L)
        start_audio_idx = non_negative(start_t * SR)
        wav, sr = soundfile.read(wave_path, frames=int(SR * self.L), start=start_audio_idx)
        assert sr == SR
        wav = self.wav_transforms(wav)
        frame_path = os.path.join(self.data_path, video, 'frames')
        frames = [Image.open(os.path.join(
            frame_path, f'frame{i+1:0>6d}.jpg')).convert('RGB') for i in
            range(start_idx, end_idx)]

        if torch.all(torch.bernoulli(self.p_outside_cond) == 1.):
            other_video = list(self.pred_timing.keys())
            other_video.remove(video)
            cond_video = sample(other_video, k=1)[0]
            cond_available_frame_idx = self.video_frame_cnt[cond_video] - self.left_over
            cond_available_timing = (cond_available_frame_idx + 0.5) / FPS
        else:
            cond_video = video
            cond_available_timing = available_timing
        cond_start_t = sample(self.pred_timing[cond_video], k=1)[0]
        cond_start_t = float(cond_start_t)
        cond_start_t = min(cond_start_t, cond_available_timing)
        cond_start_idx = non_negative(cond_start_t * FPS - 0.5)
        cond_end_idx = non_negative(cond_start_idx + FPS * self.L)
        cond_audio_idx = non_negative(cond_start_t * SR)
        cond_frame_path = os.path.join(self.data_path, cond_video, 'frames')
        cond_wave_path = self.video_audio_path[cond_video]

        cond_frames = [Image.open(os.path.join(
            cond_frame_path, f'frame{i+1:0>6d}.jpg')).convert('RGB') for i in
            range(cond_start_idx, cond_end_idx)]
        cond_wav, sr = soundfile.read(cond_wave_path, frames=int(SR * self.L), start=cond_audio_idx)
        assert sr == SR
        cond_wav = self.wav_transforms(cond_wav)

        item['image'] = wav  # (44100,)
        item['cond_image'] = cond_wav  # (44100,)
        item['file_path_wav_'] = wave_path
        item['file_path_cond_wav_'] = cond_wave_path

        if self.frame_transforms is not None:
            cond_frames = self.frame_transforms(cond_frames)
            frames = self.frame_transforms(frames)

        item['feature'] = np.stack(cond_frames + frames, axis=0)  # (30 * L, 112, 112, 3)
        item['file_path_feats_'] = (frame_path, start_idx)
        item['file_path_cond_feats_'] = (cond_frame_path, cond_start_idx)

        item['label'] = 'None'
        item['target'] = 'None'
        return item


class ImpactSetCleanWaveCondOnImageTrain(ImpactSetCleanWaveCondOnImage):
    def __init__(self, dataset_cfg):
        train_transforms = transforms.Compose([
            Resize3D(128),
            RandomResizedCrop3D(112, scale=(0.5, 1.0)),
            RandomHorizontalFlip3D(),
            ColorJitter3D(brightness=0.4, saturation=0.4, contrast=0.2, hue=0.1),
            ToTensor3D(),
            Normalize3D(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])
        super().__init__('train', frame_transforms=train_transforms, **dataset_cfg)


class ImpactSetCleanWaveCondOnImageValidation(ImpactSetCleanWaveCondOnImage):
    def __init__(self, dataset_cfg):
        valid_transforms = transforms.Compose([
            Resize3D(128),
            CenterCrop3D(112),
            ToTensor3D(),
            Normalize3D(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])
        super().__init__('val', frame_transforms=valid_transforms, **dataset_cfg)


class ImpactSetCleanWaveCondOnImageTest(ImpactSetCleanWaveCondOnImage):
    def __init__(self, dataset_cfg):
        test_transforms = transforms.Compose([
            Resize3D(128),
            CenterCrop3D(112),
            ToTensor3D(),
            Normalize3D(mean=[0.485, 0.456, 0.406],
                        std=[0.229, 0.224, 0.225]),
        ])
        super().__init__('test', frame_transforms=test_transforms, **dataset_cfg)
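

# Quick smoke test: build the datasets from a config and, past the early
# exit(), dump a few sampled waveforms and spectrograms for inspection.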
if __name__ == '__main__':
    from omegaconf import OmegaConf

    cfg = OmegaConf.load('configs/countixAV_transformer_denoise_clean.yaml')
    data = instantiate_from_config(cfg.data)
    data.prepare_data()
    data.setup()
    print(data.datasets['train'])
    print(len(data.datasets['train']))
    # print(data.datasets['train'][24])
    exit()

    stats = []
    torch.manual_seed(0)
    np.random.seed(0)
    random.seed(0)
    for k in range(1):
        x = np.arange(SR * 2)
        for i in tqdm(range(len(data.datasets['train']))):
            wav = data.datasets['train'][i]['wav']
            spec = data.datasets['train'][i]['image']
            spec = 0.5 * (spec + 1)
            spec_rms = rms(spec)
            stats.append(float(spec_rms))
            # plt.plot(x, wav)
            # plt.ylim(-1, 1)
            # plt.savefig(f'tmp/th0.1_wav_e_{k}_{i}_{mean_val:.3f}_{spec_rms:.3f}.png')
            # plt.close()
            # plt.cla()
            soundfile.write(f'tmp/wav_e_{k}_{i}_{spec_rms:.3f}.wav', wav, SR)
            draw_spec(spec, f'tmp/wav_spec_e_{k}_{i}_{spec_rms:.3f}.png')
            if i == 100:
                break
    # plt.hist(stats, bins=50)
    # plt.savefig(f'tmp/rms_spec_stats.png')