# GenAU/src/utilities/data/videoaudio_dataset.py
import os
import pandas as pd
from librosa.filters import mel as librosa_mel_fn
import random
from torch.utils.data import Dataset
import torch.nn.functional
import torch
import numpy as np
import torchaudio
from moviepy.editor import VideoFileClip
from torchvision import transforms as T
import src.utilities.audio as Audio
from torch.utils.data.dataloader import default_collate
import shutil
import re
from src.tools.io import load_file, write_json, load_json
from src.tools.torch_utils import spectral_normalize_torch, random_uniform
from src.tools.training_utils import build_dataset_json_from_list
import gc
import librosa
import threading
class VideoAudioDataset(Dataset):
def __init__(
self,
config=None,
load_video=True,
load_audio=True,
keep_audio_files=True,
video_transform=None,
target_frame_cnt=10,
split="train",
waveform_only=False,
add_ons=[],
dataset_json=None,
sample_single_caption=True,
augment_p=0.0,
limit_data_percentage = None,
cache_dir=None
):
"""
Dataset that manages audio recordings
:param audio_conf: Dictionary containing the audio loading and preprocessing settings
:param dataset_json_file
"""
self.target_frame_cnt = target_frame_cnt
self.config = config
self.split = split
        self.pad_wav_start_sample = 0  # if None, pad_wav() places short clips at a random offset
self.load_video = load_video
self.load_audio = load_audio
self.keep_audio_files = keep_audio_files
self.sample_single_caption = sample_single_caption
        self.limit_data_percentage = config['data'].get('limit_data_percentage', limit_data_percentage)  # config value takes precedence over the constructor argument
        self.trim_wav_enabled = False  # kept distinct from the trim_wav() method so the flag does not shadow it
self.waveform_only = waveform_only
self.augment_p = augment_p
self.add_ons = [eval(x) for x in add_ons]
self.consistent_start_time = config['data'].get('consistent_start_time', False)
self.cache_dir = config['data'].get('cache_dir', None)
if self.cache_dir is not None:
os.makedirs(self.cache_dir, exist_ok=True)
print("[INFO] Add-ons:", self.add_ons)
self.obtained_samples = 0
# transforms
if video_transform is None:
video_transform = T.Compose(
[
# T, H, W, C
T.ToTensor()
]
)
self.video_transform = video_transform
self.build_setting_parameters()
# For an external dataset
if dataset_json is not None:
self.data = dataset_json["data"]
self.dataset_name = "TEST"
self.id2label, self.index_dict, self.num2label = {}, {}, {}
else:
self.metadata_root = load_json(self.config['data']["metadata_root"])
self.dataset_name = self.config["data"][self.split]
assert split in self.config["data"].keys(), (
"The dataset split %s you specified is not present in the config. You can choose from %s"
% (split, self.config["data"].keys())
)
self.retrieve_paths()
if split=='train' and self.limit_data_percentage:
print(f"[INFO] limiting data to only {self.limit_data_percentage} of the total data {len(self.data)}")
num_datapoints = int(len(self.data) * self.limit_data_percentage)
# fix the seed to make sure we select the same data.
np.random.seed(42)
selected_idx = np.random.randint(0, len(self.data), size=num_datapoints)
# select
self.video_json_paths = np.asarray(self.video_json_paths)[selected_idx]
self.data = np.asarray(self.data)[selected_idx]
self.datasets_of_datapoints = np.asarray(self.datasets_of_datapoints)[selected_idx]
self.build_dsp()
if 'keys_synonyms' in config['data']:
self.keys_synonyms = config['data']['keys_synonyms']
else:
self.keys_synonyms = {
"gt_audio_caption": ["audiocaps_gt_captions", "gt_audio_caption", "caption", "gt_caption", "gt_captions", 'best_model_w_meta_pred_caption', "wavcaps_caption"],
"tags": ["keywords", "tags"],
}
print("[INFO] Dataset initialize finished")
def filter_text(self, text):
filtered_text = re.sub(r'[^\x00-\x7F\u00A0-\u00FF]+', '', text).strip()
return filtered_text
def get_data_from_keys(self, data, key, default_value=None):
"""
Check for each possible key and return the value if found
"""
if key not in self.keys_synonyms:
return data.get(key, default_value)
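        # e.g. get_data_from_keys(datum, "gt_audio_caption", []) returns datum["caption"] when the
        # ground-truth caption was stored under the synonym key "caption" (see self.keys_synonyms)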
        possible_keys = self.keys_synonyms[key]
        for possible_key in possible_keys:
            if possible_key in data:
                return data[possible_key]
        return default_value  # none of the synonym keys were found
def default_sample(self):
data = {
"dataset_name": "UNK",
"json_path": "UNK",
"fname": "UNK", # list
"waveform": "" if (not self.load_audio) else torch.zeros(1, int(self.sampling_rate * self.duration)),
# "waveform": torch.zeros(1, int(self.sampling_rate * self.duration)),
# tensor, [batchsize, t-steps, f-bins]
"stft": "" if self.waveform_only else torch.zeros(int(self.duration * 100), 512),
# tensor, [batchsize, t-steps, mel-bins]
"log_mel_spec": "" if self.waveform_only else torch.zeros(int(self.duration * 100), 64),
"duration": self.duration,
"sampling_rate": self.sampling_rate,
"random_start_sample_in_original_audio_file": -1,
"labels": "UNK",
# # video
"frames": "",
# additional meta data
"title": "UNK",
"url": "UNK",
"description": "UNK",
"original_captions": "UNK",
"automatic_captions": "UNK",
"gt_audio_caption": "UNK" if self.sample_single_caption else ["UNK"] * 5,
"video_caption": "UNK",
"videollama_caption": "UNK",
"text": "UNK" if self.sample_single_caption else ["UNK"] * 5
}
return data
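    # __getitem__ falls back to the placeholder above when a sample cannot be loaded after the
    # allowed retries, so collation still sees the same set of fields with compatible dummy values.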
def __getitem__(self, index, augment=True):
retries = 0
max_retries = 1
while retries < max_retries:
try:
if '.json' in self.data[index]:
dataset_name = self.datasets_of_datapoints[index]
absolute_file_path = self._relative_path_to_absolute_path([self.data[index]], dataset_name)[0]
if not os.path.exists(absolute_file_path):
print(f"file {absolute_file_path} does not exists. Retying..")
index = random.randint(0, len(self.data) - 1)
retries += 1
continue
else:
dataset_name = absolute_file_path = ""
(
index,
fname,
video_frames,
waveform,
stft,
log_mel_spec,
_, # the one-hot representation of the audio class
(datum, mix_datum),
random_start,
) = self.feature_extraction(index)
data = {
"dataset_name": dataset_name,
"json_path": absolute_file_path,
"fname": fname, # list
"waveform": "" if (not self.load_audio) else waveform.float(),
# tensor, [batchsize, t-steps, f-bins]
"stft": "" if (stft is None) else stft.float(),
# tensor, [batchsize, t-steps, mel-bins]
"log_mel_spec": "" if (log_mel_spec is None) else log_mel_spec.float(),
"duration": self.duration,
"sampling_rate": self.sampling_rate,
"random_start_sample_in_original_audio_file": -1 if random_start is None else random_start,
"labels": ', '.join(datum.get('labels', [])),
# # video
"frames": video_frames if self.load_video else "",
# additional meta data
"title": self.filter_text(datum.get('title', '')),
"url": self.filter_text(datum.get('url', '')),
"description": self.filter_text(self.get_sample_description(datum)),
"original_captions": self.filter_text(datum.get('original_captions', '')),
"automatic_captions": self.filter_text(datum.get('automatic_captions', '')),
"gt_audio_caption": self.get_sample_caption(datum, index=index),
"video_caption": datum.get('panda70m_caption_0000', '').replace("<unk>", "").strip(),
"videollama_caption": datum.get('videollama_caption_0000', ''),
}
# select one caption if multiple exists
if isinstance(data['gt_audio_caption'], list) and len(data['gt_audio_caption']) > 0 and self.sample_single_caption:
idx = np.random.randint(len(data['gt_audio_caption']))
data['gt_audio_caption'] = data['gt_audio_caption'][idx]
for add_on in self.add_ons:
data.update(add_on(self.config, data, self.data[index]))
# augment data
if augment and np.random.rand() < self.augment_p:
data = self.pair_augmentation(data)
data['text'] = data['gt_audio_caption']
self.obtained_samples += 1
if self.obtained_samples % 20 == 0:
gc.collect()
return data
except Exception as e:
if '.json' in self.data[index]:
dataset_name = self.datasets_of_datapoints[index]
file_path = self._relative_path_to_absolute_path([self.data[index]], dataset_name)[0]
else:
file_path = ""
index = random.randint(0, len(self.data) - 1)
retries += 1
print("[ERROR, videoaudio_dataset] error while loading", file_path, e)
continue
return self.default_sample()
def text_to_filename(self, text):
return text.replace(" ", "_").replace("'", "_").replace('"', "_")
def get_dataset_root_path(self, dataset):
assert dataset in self.metadata_root.keys()
return self.metadata_root[dataset]
def get_dataset_metadata_path(self, dataset, key):
# key: train, test, val, class_label_indices
try:
if dataset in self.metadata_root["metadata"]["path"].keys():
return self.metadata_root["metadata"]["path"][dataset][key]
except KeyError as e:
print("Error:", e)
raise ValueError(
            '[ERROR, videoaudio_dataset] Dataset %s does not have metadata "%s" specified' % (dataset, key)
)
def __len__(self):
return len(self.data)
def replace_extension(self, path, new_ext):
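        # e.g. replace_extension("clips/video_001.json", "wav") -> "clips/video_001.wav"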
return f"{'/'.join(path.split('.')[:-1])}.{new_ext}"
def feature_extraction(self, index):
# Read wave file and extract feature
if isinstance(self.data[index], str) and '.json' in self.data[index]:
dataset_name = self.datasets_of_datapoints[index]
file_path = self._relative_path_to_absolute_path([self.data[index]], dataset_name)[0]
datum = load_json(file_path)
else:
datum = self.data[index]
if 'path' in datum and datum['path']:
datum['path'] = self._relative_path_to_absolute_path([datum['path']], dataset_name)[0]
if 'wav' in datum and datum['wav']:
datum['wav'] = self._relative_path_to_absolute_path([datum['wav']], dataset_name)[0]
random_start = None
log_mel_spec, stft, waveform, frames = None, None, None, None
audio_file = None
if self.load_audio and not ('wav' in datum.keys() and os.path.exists(datum['wav'])):
# assume that a .wav file exists in the same location as the .json file
wav_path = self.replace_extension(file_path, 'wav')
flac_path = self.replace_extension(file_path, 'flac')
if os.path.exists(wav_path):
datum['wav'] = wav_path
elif os.path.exists(flac_path):
datum['wav'] = flac_path
elif 'wav' in datum:
del datum['wav']
        # cache wav file: useful when local storage offers faster read access than the original file location
if self.load_audio and 'wav' in datum and self.cache_dir is not None:
target_audio_file_path = f"{self.cache_dir}{datum['wav']}"
if not os.path.exists(target_audio_file_path):
os.makedirs(os.path.dirname(target_audio_file_path), exist_ok=True)
shutil.copy2(datum['wav'] , target_audio_file_path)
# update
datum['wav'] = target_audio_file_path
save_random_start = False
random_start = None
if self.consistent_start_time: # always sample from the same start time
if 'random_start_t' in datum:
random_start = datum.get('random_start_t', None)
save_random_start = False
else:
save_random_start = True
# load audio
if self.load_audio:
if 'wav' in datum:
(
log_mel_spec,
stft,
waveform,
random_start,
) = self.read_audio_file(datum["wav"], random_start=random_start)
waveform = torch.FloatTensor(waveform)
else:
(
frames,
log_mel_spec,
stft,
waveform,
random_start,
audio_file
) = self.read_video_file(datum["path"], random_start=random_start, load_audio=True)
waveform = torch.FloatTensor(waveform)
# load video
if self.load_video and 'path' in datum:
(frames, _, _, _, _, _ ) = self.read_video_file(datum["path"], random_start=random_start, load_audio=self.load_audio and waveform is None)
elif self.load_video and 'path' in datum:
(
frames,
log_mel_spec,
stft,
waveform,
random_start,
audio_file
) = self.read_video_file(datum["path"], random_start=random_start, load_audio=True)
waveform = torch.FloatTensor(waveform)
if audio_file is not None:
            # update the json to include the audio path. Only effective when keep_audio_files is enabled
updated_json = load_json(file_path)
updated_json['wav'] = self._absolute_path_to_relative_path([audio_file], dataset_name)[0]
datum["wav"] = updated_json['wav']
updated_json['random_start_t'] = random_start
# write_json(updated_json, file_path)
elif save_random_start and random_start is not None:
# update json to include the randomly sampled start time for future experiments
updated_json = load_json(file_path)
updated_json['random_start_t'] = random_start
write_json(updated_json, file_path)
mix_datum = None
if self.load_video:
assert frames.shape == (3, self.target_frame_cnt, self.frame_width, self.frame_height)
# The filename of the wav file
fname = datum["path"] if 'path' in datum and self.load_video else datum.get('wav', '')
if not fname:
fname = datum['fname']
return (
index,
fname,
frames,
waveform,
stft,
log_mel_spec,
None,
(datum, mix_datum),
random_start,
)
def combine_captions(self, caption1, caption2, remove_duplicates=False, background=False):
"""
        Combine two captions when doing mixup augmentation.
"""
words1 = caption1.split()
words2 = caption2.split()
seen = set(words1)
combined_words = words1.copy()
combined_words.append('and')
for word in words2:
if word not in seen or (not remove_duplicates):
combined_words.append(word)
seen.add(word) # Add to set to keep track of seen words
combined_caption = " ".join(combined_words)
if background:
combined_caption += " in the background"
return combined_caption
def pair_augmentation(self, batch):
"""
        Mixup augmentation that combines two audio clips with different weights, so that one clip acts as a background sound.
"""
# load a random audio
idx = np.random.randint(0, self.__len__())
second_data = self.__getitem__(idx, augment=False)
if np.random.randint(0, 2):
ratio = 0.2 + np.random.rand() * 0.2
else:
ratio = 0.5
batch['waveform'] = ((1 - ratio) * batch['waveform'] + ratio * second_data['waveform'])
batch['gt_audio_caption'] = self.combine_captions(batch['gt_audio_caption'], second_data['gt_audio_caption'], background=(ratio!=0.5))
        if 'panda_caption' in batch:  # only present when an add-on provides it
            batch['panda_caption'] = f"{batch['panda_caption']} and {second_data.get('panda_caption', '')}"
batch['description'] = f"{batch['description']} and {second_data['description']}"
return batch
def build_setting_parameters(self):
# Read from the json config
self.melbins = self.config["preprocessing"]["mel"]["n_mel_channels"]
self.sampling_rate = self.config["preprocessing"]["audio"]["sampling_rate"]
self.hopsize = self.config["preprocessing"]["stft"]["hop_length"]
self.duration = self.config["preprocessing"]["audio"]["duration"]
self.target_length = int(self.duration * self.sampling_rate / self.hopsize)
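        # e.g. with illustrative values sampling_rate=16000, hop_length=160 and duration=10.24 s,
        # target_length = 10.24 * 16000 / 160 = 1024 spectrogram frames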
def merge_paths(self, path1, path2):
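        # e.g. merge_paths("/data/audioset", "audioset/clips/x.json") -> "/data/audioset/clips/x.json"
        # (the overlapping "audioset" segment is collapsed); without an overlap the paths are simply joined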
parts1 = path1.split('/')
parts2 = path2.split('/')
common_part = None
for i, part in enumerate(parts1):
if parts1[i:] == parts2[:len(parts1)-i]:
common_part = i
break
if common_part is not None:
merged_path = '/'.join(parts1[:common_part] + parts2)
else:
# no common part, simply concatenate
merged_path = '/'.join([path1, path2])
return merged_path
def _relative_path_to_absolute_path(self, paths, dataset_name):
root_path = self.get_dataset_root_path(dataset_name)
for i, path in enumerate(paths):
assert path[0] != "/", (
"The dataset metadata should only contain relative path to the json file: "
+ str(path)
)
merged_path = self.merge_paths(
root_path, path
)
if not os.path.exists(merged_path):
merged_path = self.merge_paths(root_path, path.split('/')[-1])
paths[i] = merged_path
return paths
def _absolute_path_to_relative_path(self, paths, dataset_name):
root_path = self.get_dataset_root_path(dataset_name)
for i, path in enumerate(paths):
assert path[0] == "/", (
"the json file should be absolute: "
+ str(path)
)
paths[i] = os.path.relpath(path, root_path)
return paths
def retrieve_paths(self):
self.video_json_paths = []
self.data = []
self.datasets_of_datapoints = []
print("[INFO] Build dataset split %s from %s" % (self.split, self.dataset_name))
if type(self.dataset_name) is str:
video_paths_list = load_file(
self.get_dataset_metadata_path(self.dataset_name, key=self.split)
)
self.video_json_paths = video_paths_list
self.datasets_of_datapoints = [self.dataset_name] * len(video_paths_list)
elif type(self.dataset_name) is list:
for dataset_name in self.dataset_name:
video_paths_list = load_file(
self.get_dataset_metadata_path(dataset_name, key=self.split)
)
self.datasets_of_datapoints += [dataset_name] * len(video_paths_list)
self.video_json_paths += video_paths_list
else:
raise Exception("[ERROR, videoaudio_dataset] Invalid data format:", type(self.dataset_name))
self.data = self.video_json_paths
print("[INFO] Data size: {}".format(len(self.data)))
return self.data
def build_dsp(self):
self.mel_basis = {}
self.hann_window = {}
self.filter_length = self.config["preprocessing"]["stft"]["filter_length"]
self.hop_length = self.config["preprocessing"]["stft"]["hop_length"]
self.win_length = self.config["preprocessing"]["stft"]["win_length"]
self.n_mel = self.config["preprocessing"]["mel"]["n_mel_channels"]
self.sampling_rate = self.config["preprocessing"]["audio"]["sampling_rate"]
self.mel_fmin = self.config["preprocessing"]["mel"]["mel_fmin"]
self.mel_fmax = self.config["preprocessing"]["mel"]["mel_fmax"]
# video
self.video_fps = self.config["preprocessing"]["video"]["fps"]
self.frame_height = self.config["preprocessing"]["video"]["height"]
self.frame_width = self.config["preprocessing"]["video"]["width"]
if not self.waveform_only:
self.STFT = Audio.stft.TacotronSTFT(
self.config["preprocessing"]["stft"]["filter_length"],
self.config["preprocessing"]["stft"]["hop_length"],
self.config["preprocessing"]["stft"]["win_length"],
self.config["preprocessing"]["mel"]["n_mel_channels"],
self.config["preprocessing"]["audio"]["sampling_rate"],
self.config["preprocessing"]["mel"]["mel_fmin"],
self.config["preprocessing"]["mel"]["mel_fmax"],
)
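    # Illustrative sketch (not the authoritative schema) of the config sections read by
    # build_setting_parameters() / build_dsp(); actual values come from the YAML config,
    # e.g. settings/simple_runs/genau.yaml:
    #   preprocessing:
    #     audio: {sampling_rate: ..., duration: ...}
    #     stft:  {filter_length: ..., hop_length: ..., win_length: ...}
    #     mel:   {n_mel_channels: ..., mel_fmin: ..., mel_fmax: ...}
    #     video: {fps: ..., height: ..., width: ...}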
def build_id_to_label(self):
id2label = {}
id2num = {}
num2label = {}
class_label_indices_path = self.get_dataset_metadata_path(
dataset=self.config["data"]["class_label_indices"],
key="class_label_indices",
)
if class_label_indices_path is not None:
df = pd.read_csv(class_label_indices_path)
for _, row in df.iterrows():
index, mid, display_name = row["index"], row["mid"], row["display_name"]
id2label[mid] = display_name
id2num[mid] = index
num2label[index] = display_name
self.id2label, self.index_dict, self.num2label = id2label, id2num, num2label
else:
self.id2label, self.index_dict, self.num2label = {}, {}, {}
def resample_wav(self, waveform, sr):
waveform = torchaudio.functional.resample(waveform, sr, self.sampling_rate)
return waveform
def normalize_wav(self, waveform):
waveform = waveform - np.mean(waveform)
waveform = waveform / (np.max(np.abs(waveform)) + 1e-8)
        return waveform * 0.5  # manually limit the maximum amplitude to 0.5
def random_segment_wav(self, waveform, target_length, random_start=None):
waveform_length = waveform.shape[-1]
assert waveform_length > 100, "Waveform is too short, %s" % waveform_length
# Too short
if (waveform_length - target_length) <= 0:
return waveform, 0
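        # try up to 10 random offsets so the selected window is not near-silent (some sample with
        # |amplitude| > 1e-4); if all attempts fail, the last sampled offset is used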
if random_start is None:
for i in range(10):
random_start = int(random_uniform(0, waveform_length - target_length))
if torch.max(
torch.abs(waveform[:, random_start : random_start + target_length])
> 1e-4
):
break
return waveform[:, random_start : random_start + target_length], random_start
def pad_wav(self, waveform, target_length):
waveform_length = waveform.shape[-1]
assert waveform_length > 100, "Waveform is too short, %s" % waveform_length
if waveform_length == target_length:
return waveform
# Pad
temp_wav = np.zeros((1, target_length), dtype=np.float32)
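        # pad_wav_start_sample is set to 0 in __init__, so shorter clips are left-aligned and
        # zero-padded on the right; setting it to None would place them at a random offset instead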
if self.pad_wav_start_sample is None:
rand_start = int(random_uniform(0, target_length - waveform_length))
else:
rand_start = 0
temp_wav[:, rand_start : rand_start + waveform_length] = waveform
return temp_wav
def trim_wav(self, waveform):
if np.max(np.abs(waveform)) < 0.0001:
return waveform
def detect_leading_silence(waveform, threshold=0.0001):
chunk_size = 1000
waveform_length = waveform.shape[0]
start = 0
while start + chunk_size < waveform_length:
if np.max(np.abs(waveform[start : start + chunk_size])) < threshold:
start += chunk_size
else:
break
return start
def detect_ending_silence(waveform, threshold=0.0001):
chunk_size = 1000
waveform_length = waveform.shape[0]
start = waveform_length
while start - chunk_size > 0:
if np.max(np.abs(waveform[start - chunk_size : start])) < threshold:
start -= chunk_size
else:
break
if start == waveform_length:
return start
else:
return start + chunk_size
start = detect_leading_silence(waveform)
end = detect_ending_silence(waveform)
return waveform[start:end]
def process_wavform(self, waveform, sr):
waveform = self.resample_wav(waveform, sr)
waveform = waveform.numpy()[0, ...]
waveform = self.normalize_wav(waveform)
        if self.trim_wav_enabled:
            waveform = self.trim_wav(waveform)
waveform = waveform[None, ...]
waveform = self.pad_wav(
waveform, target_length=int(self.sampling_rate * self.duration)
)
return waveform
def load_audio_with_timeout(self, file_path, timeout):
"""
Load an audio file with a specified timeout using threading.
:param file_path: Path to the audio file.
:param timeout: Maximum time (in seconds) to allow for loading the file.
:return: (waveform, sample_rate) if successful, None if timeout occurs.
"""
class AudioLoader(threading.Thread):
def __init__(self, file_path):
super().__init__()
self.file_path = file_path
self.result = None
def run(self):
try:
waveform, sample_rate = torchaudio.load(self.file_path)
self.result = (waveform, sample_rate)
except Exception as e:
print(f"Failed to load audio: {e}")
self.result = None
# Start the thread
audio_loader = AudioLoader(file_path)
audio_loader.start()
# Wait for the thread to complete or timeout
audio_loader.join(timeout=timeout)
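        # Note: join(timeout=...) does not kill the worker; on timeout the thread keeps running in
        # the background until torchaudio.load returns or raises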
if audio_loader.is_alive():
print(f"Timeout while loading {file_path}")
return None, None # Timeout case
return audio_loader.result
def read_wav_file(self, filename, random_start=None):
# waveform, sr = librosa.load(filename, sr=None, mono=True) # 4 times slower
waveform, sr = self.load_audio_with_timeout(filename, timeout=10)
if waveform is None:
print("[INFO] timeout when loading the audio")
            # TODO: important, fall back to a silent dummy clip
            waveform = torch.zeros(1, int(self.sampling_rate * self.duration))
            sr = self.sampling_rate  # the dummy clip is already at the target sampling rate
waveform, random_start = self.random_segment_wav(
waveform, target_length=int(sr * self.duration), random_start=random_start
)
waveform = self.process_wavform(waveform, sr)
return waveform, random_start
def read_mp4_file(self, filename, random_start=None, load_audio=True):
video = VideoFileClip(filename)
video = video.resize(newsize=(self.frame_width, self.frame_height))
audio = video.audio
# audio part
waveform = None
tmp_audio_file = None
if load_audio:
tmp_audio_file = f"{filename[:-4]}.wav"
audio.write_audiofile(tmp_audio_file, codec='pcm_s16le', fps=self.sampling_rate, nbytes=2, ffmpeg_params=["-ac", "1"])
waveform, sr = torchaudio.load(tmp_audio_file, format='wav')
            if not self.keep_audio_files:  # otherwise keep the audio file and save its path in the json file
os.remove(tmp_audio_file)
tmp_audio_file = None
# random segment
waveform, random_start = self.random_segment_wav(
waveform, target_length=int(sr * self.duration), random_start=random_start
)
random_start_sec = random_start / sr
waveform = self.process_wavform(waveform, sr)
else:
            random_start_sec = 0 if random_start is None else random_start / self.sampling_rate  # no audio loaded; default to the clip start
# video part
frames = []
if self.load_video:
interval = 1 / self.video_fps
current_time = random_start_sec
# assuming fixed fps
while current_time <= video.duration and len(frames) < self.target_frame_cnt:
frame = video.get_frame(current_time)
frames.append(frame) # T x H x W x C
current_time += interval
# transform frames
frames = torch.from_numpy(np.stack(frames[:self.target_frame_cnt]))
frames = frames.permute(3, 0, 1, 2).float() # (C, T, H, W)
            frames = self.video_transform(frames)  # the transform (e.g. the default T.Compose) is callable
return frames, waveform, random_start, tmp_audio_file
def read_video_file(self, filename, load_audio=True, random_start=None):
if os.path.exists(filename):
frames, waveform, random_start, audio_file = self.read_mp4_file(filename, load_audio=load_audio, random_start=random_start)
            # frames: C x T x H x W; zero-pad the time axis if too few frames were decoded
            if torch.is_tensor(frames) and frames.shape[1] < self.target_frame_cnt:
extra_frames = torch.zeros((frames.shape[0], self.target_frame_cnt - frames.shape[1], frames.shape[2], frames.shape[3]))
frames = torch.cat([frames, extra_frames], dim=1)
else:
            print(
                '[WARNING, videoaudio_dataset] The path "',
                filename,
                '" was not found in the metadata. Using an empty video instead. This is normal during inference.',
            )
target_wavform_length = int(self.sampling_rate * self.duration)
waveform = torch.zeros((1, target_wavform_length))
frames = torch.zeros((3, self.target_frame_cnt, self.frame_height, self.frame_width))
random_start = 0
audio_file = None
if load_audio and not self.waveform_only:
log_mel_spec, stft = self.wav_feature_extraction(waveform)
else:
# Load waveform data only
# Use zero array to keep the format unified
log_mel_spec, stft = None, None
return frames, log_mel_spec, stft, waveform, random_start, audio_file
def read_audio_file(self, filename, random_start=None):
if os.path.exists(filename):
waveform, random_start = self.read_wav_file(filename, random_start=random_start)
else:
            print(
                'Non-fatal Warning [dataset.py]: The wav path "',
                filename,
                '" was not found in the metadata. Using an empty waveform instead. This is normal during inference.',
            )
target_length = int(self.sampling_rate * self.duration)
waveform = torch.zeros((1, target_length))
random_start = 0
if not self.waveform_only:
log_mel_spec, stft = self.wav_feature_extraction(waveform)
else:
# Load waveform data only
# Use zero array to keep the format unified
log_mel_spec, stft = None, None
return log_mel_spec, stft, waveform, random_start
def get_sample_caption(self, datum, index):
"""
Use groundtruth caption if exists, otherwise use a hand crafted caption based on the labels
"""
caption = self.get_data_from_keys(datum, "gt_audio_caption", [])
if caption:
return caption
        # convert labels to a caption
labels = self.get_data_from_keys(datum, "labels", [])
if not labels:
dataset_name = self.datasets_of_datapoints[index]
absolute_file_path = self._relative_path_to_absolute_path([self.data[index]], dataset_name)[0]
print(f"Warning file {absolute_file_path} does not have gt caption")
return ["Sound"]
if not isinstance(labels, list):
labels = [labels]
# process each tag
def clean_tag(tag):
# Replace common delimiters with spaces
for delimiter in [';', '_', '-', ',']:
tag = tag.replace(delimiter, ' ')
return tag.strip().lower()
unique_keywords = set()
for tag in labels:
words = clean_tag(tag).split(' ')
unique_keywords.update(words) # Add words to the set, which automatically removes duplicates
cleaned_labels = list(unique_keywords)
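        # e.g. labels ["dog_bark", "rain"] yield the keyword set {"dog", "bark", "rain"} and a caption
        # like "The sound of <kw1>, <kw2>, and <kw3>." (keyword order follows set iteration)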
        if len(cleaned_labels) > 1:
            caption = 'The sound of ' + ', '.join(cleaned_labels[:-1]) + ', and ' + cleaned_labels[-1] + '.'
        else:
            caption = 'The sound of ' + cleaned_labels[0] + '.'
return [caption]
def get_sample_description(self, datum):
"""
Use description from metadata if exists, otherwise use a hand crafted description based on the tags
"""
if 'description' in datum.keys():
return datum['description']
        # convert tags to a description
tags = self.get_data_from_keys(datum, "tags", [])
if not tags:
# print("[INFO] could not find tags for:", datum)
return ""
if not isinstance(tags, list):
tags = [tags]
# process each tag
def clean_tag(tag):
# Replace common delimiters with spaces
for delimiter in [';', '_', '-', ',']:
tag = tag.replace(delimiter, ' ')
return tag.strip().lower()
unique_keywords = set()
for tag in tags:
words = clean_tag(tag).split(' ')
unique_keywords.update(words) # Add words to the set, which automatically removes duplicates
cleaned_tags = list(unique_keywords)
        if len(cleaned_tags) > 1:
            description = ', '.join(cleaned_tags[:-1]) + ', and ' + cleaned_tags[-1] + '.'
        else:
            description = cleaned_tags[0] + '.'
return description
def mel_spectrogram_train(self, y):
if torch.min(y) < -1.0:
print("train min value is ", torch.min(y))
if torch.max(y) > 1.0:
print("train max value is ", torch.max(y))
        if str(self.mel_fmax) + "_" + str(y.device) not in self.mel_basis:  # build the mel basis and Hann window once per device
mel = librosa_mel_fn(
sr=self.sampling_rate,
n_fft=self.filter_length,
n_mels=self.n_mel,
fmin=self.mel_fmin,
fmax=self.mel_fmax,
)
self.mel_basis[str(self.mel_fmax) + "_" + str(y.device)] = (
torch.from_numpy(mel).float().to(y.device)
)
self.hann_window[str(y.device)] = torch.hann_window(self.win_length).to(
y.device
)
y = torch.nn.functional.pad(
y.unsqueeze(1),
(
int((self.filter_length - self.hop_length) / 2),
int((self.filter_length - self.hop_length) / 2),
),
mode="reflect",
)
y = y.squeeze(1)
stft_spec = torch.stft(
y,
self.filter_length,
hop_length=self.hop_length,
win_length=self.win_length,
window=self.hann_window[str(y.device)],
center=False,
pad_mode="reflect",
normalized=False,
onesided=True,
return_complex=True,
)
stft_spec = torch.abs(stft_spec)
mel = spectral_normalize_torch(
torch.matmul(
self.mel_basis[str(self.mel_fmax) + "_" + str(y.device)], stft_spec
)
)
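        # per-example output shapes: mel -> (n_mel_channels, frames), stft_spec -> (filter_length // 2 + 1, frames)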
return mel[0], stft_spec[0]
# This one is significantly slower than "wav_feature_extraction_torchaudio" if num_worker > 1
def wav_feature_extraction(self, waveform):
waveform = waveform[0, ...]
waveform = torch.FloatTensor(waveform)
# log_mel_spec, stft, energy = Audio.tools.get_mel_from_wav(waveform, self.STFT)[0]
log_mel_spec, stft = self.mel_spectrogram_train(waveform.unsqueeze(0))
log_mel_spec = torch.FloatTensor(log_mel_spec.T)
stft = torch.FloatTensor(stft.T)
log_mel_spec, stft = self.pad_spec(log_mel_spec), self.pad_spec(stft)
return log_mel_spec, stft
# @profile
# def wav_feature_extraction_torchaudio(self, waveform):
# waveform = waveform[0, ...]
# waveform = torch.FloatTensor(waveform)
# stft = self.stft_transform(waveform)
# mel_spec = self.melscale_transform(stft)
# log_mel_spec = torch.log(mel_spec + 1e-7)
# log_mel_spec = torch.FloatTensor(log_mel_spec.T)
# stft = torch.FloatTensor(stft.T)
# log_mel_spec, stft = self.pad_spec(log_mel_spec), self.pad_spec(stft)
# return log_mel_spec, stft
def pad_spec(self, log_mel_spec):
n_frames = log_mel_spec.shape[0]
p = self.target_length - n_frames
if p > 0:
m = torch.nn.ZeroPad2d((0, 0, 0, p))
log_mel_spec = m(log_mel_spec)
elif p < 0:
log_mel_spec = log_mel_spec[0 : self.target_length, :]
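        # drop the last frequency bin when the count is odd (e.g. a 513-bin STFT becomes 512 bins)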
if log_mel_spec.size(-1) % 2 != 0:
log_mel_spec = log_mel_spec[..., :-1]
return log_mel_spec
def custom_collate_fn(batch):
# for test
# for k in batch[0].keys():
# try:
# default_collate([{k:item[k]} for item in batch])
# except Exception as e:
# print("collect error in key", k)
# print("files", [b['fname'] for b in batch])
# print("shape", [item[k].shape for item in batch])
# print("error", e)
collated_batch = default_collate(batch)
    # Keep captions as a plain Python list; default_collate would otherwise try to zip them across the batch and fail when samples carry different numbers of captions
collated_batch['gt_audio_caption'] = [item['gt_audio_caption'] for item in batch]
return collated_batch
if __name__ == "__main__":
import torch
from tqdm import tqdm
from pytorch_lightning import seed_everything
from torch.utils.data import DataLoader
from src.tools.configuration import Configuration
model_config = "settings/simple_runs/genau.yaml"
config = Configuration(model_config)
config = config.get_config()
seed_everything(20)
dataset = VideoAudioDataset(
config=config, split="train", waveform_only=False,
load_video=False, sample_single_caption=True, augment_p=1.0)
print("[INFO] Dataset len:", len(dataset))
loader = DataLoader(dataset, batch_size=64, num_workers=0, shuffle=True, collate_fn=custom_collate_fn)
# # test augmentation on a single audio
# audio_1 = dataset.__getitem__(0, augment=False)
# aug_audio_1 = dataset.__getitem__(1, augment=True)
# aug_audio_2 = dataset.__getitem__(2, augment=True)
# print("orginal_caption:", audio_1['gt_audio_caption'])
# print("aug_caption_1:", aug_audio_1['gt_audio_caption'])
# print("aug_caption_2:", aug_audio_2['gt_audio_caption'])
# # save audio
# torchaudio.save("original_audio.wav", audio_1['waveform'], 16000)
# torchaudio.save("aug_audio_1.wav", aug_audio_1['waveform'], 16000)
# torchaudio.save("aug_audio_2.wav", aug_audio_2['waveform'], 16000)
for cnt, each in tqdm(enumerate(loader)):
print("wav shape:", each['waveform'].shape, flush=True)
print("log_mel_spec shape:", each['log_mel_spec'].shape, flush=True)
print("names:", each['fname'], flush=True)
break
# Test from dataset_json
dataset = VideoAudioDataset(
config=config,
split='test',
dataset_json= build_dataset_json_from_list("tests/captionlist/inference_submission.lst"),
load_audio=False,
load_video=False
)
# print("Item 0", dataset[0])