Spaces:
Running
on
Zero
Running
on
Zero
File size: 7,619 Bytes
295978e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 |
# pylint: disable=C0301
'''
This module contains the AudioProcessor class and related functions for processing audio data.
It utilizes various libraries and models to perform tasks such as preprocessing, feature extraction,
and audio separation. The class is initialized with configuration parameters and can process
audio files using the provided models.
'''
import os
import subprocess
import librosa
import numpy as np
import torch
from audio_separator.separator import Separator
from transformers import WhisperModel, AutoFeatureExtractor
import torch.nn.functional as F
def linear_interpolation_fps(features, input_fps, output_fps, output_len=None):
    """
    Resample a feature sequence along its time axis with linear interpolation.

    :param features: 3-D tensor indexed as [batch, time, channels]; dim 1 is
        the axis that gets resampled.
    :param input_fps: Frame rate the input time axis is sampled at.
    :param output_fps: Frame rate wanted for the output time axis.
    :param output_len: Number of output frames. When ``None`` it is derived
        as ``int(time * output_fps / input_fps)``.
    :return: Tensor indexed as [batch, output_len, channels].
    """
    # F.interpolate resamples the last axis, so move time there: [B, C, T].
    features = features.transpose(1, 2)
    if output_len is None:
        # Multiply before dividing. Computing (T / input_fps) * output_fps
        # rounds the intermediate quotient, and int() then truncates results
        # such as 28.999999999999996 down to 28 although the exact answer
        # is 29 (T=29, 100 fps -> 100 fps).
        output_len = int(features.shape[2] * output_fps / float(input_fps))
    output_features = F.interpolate(features, size=output_len, align_corners=True, mode='linear')
    # Restore the caller's [B, T, C] layout.
    return output_features.transpose(1, 2)
def resample_audio(input_audio_file: str, output_audio_file: str, sample_rate: int):
    """
    Resample an audio file by invoking the ``ffmpeg`` command-line tool.

    :param input_audio_file: Path of the audio file to read.
    :param output_audio_file: Path the resampled audio is written to; an
        existing file is overwritten (``-y``).
    :param sample_rate: Target sampling rate, passed to ffmpeg as ``-ar``.
    :return: ``output_audio_file``, unchanged.
    :raises AssertionError: if ffmpeg exits with a non-zero status.
    :raises FileNotFoundError: if the ``ffmpeg`` executable cannot be found.
    """
    completed = subprocess.run(
        [
            "ffmpeg", "-y", "-v", "error",
            "-i", input_audio_file,
            "-ar", str(sample_rate),
            output_audio_file,
        ],
        check=False,
    )
    if completed.returncode != 0:
        # Raised explicitly instead of via an `assert` statement so the check
        # still runs under `python -O`; the exception type is kept for callers.
        raise AssertionError("Resample audio failed!")
    return output_audio_file
class AudioProcessor:
    """
    AudioProcessor loads a Whisper model and turns audio files into audio embeddings.

    It owns the Whisper encoder, the matching feature extractor and,
    optionally, an audio separator model. It can be used as a context manager.

    :param sample_rate: Sampling rate of the audio file (stored on the instance)
    :param fps: Frames per second for the extracted features (stored on the instance)
    :param wav2vec_model_path: Path given to ``WhisperModel.from_pretrained`` and ``AutoFeatureExtractor.from_pretrained``
    :param wav2vec_feature_type: Accepted for interface compatibility; not used by this class
    :param audio_separator_model_path: Directory holding the audio separator model files
    :param audio_separator_model_name: Name of the audio separator model; ``None`` disables the separator
    :param cache_dir: Directory the audio separator writes its output to
    :param device: Device to run the Whisper model on
    """

    def __init__(
        self,
        sample_rate,
        fps,
        wav2vec_model_path,
        wav2vec_feature_type,
        audio_separator_model_path:str=None,
        audio_separator_model_name:str=None,
        cache_dir:str='',
        device="cuda:0",
    ) -> None:
        self.sample_rate = sample_rate
        self.fps = fps
        self.device = device

        # The Whisper model is used for inference only: eval mode, frozen weights.
        self.whisper = WhisperModel.from_pretrained(wav2vec_model_path).to(device).eval()
        self.whisper.requires_grad_(False)
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(wav2vec_model_path)

        if audio_separator_model_name is not None:
            try:
                os.makedirs(cache_dir, exist_ok=True)
            except OSError:
                # Best effort: report the failure and continue.
                print("Fail to create the output cache dir.")
            self.audio_separator = Separator(
                output_dir=cache_dir,
                output_single_stem="vocals",
                model_file_dir=audio_separator_model_path,
            )
            self.audio_separator.load_model(audio_separator_model_name)
            if self.audio_separator.model_instance is None:
                # Raised explicitly (not via `assert`) so the check survives `python -O`.
                raise AssertionError("Fail to load audio separate model.")
        else:
            self.audio_separator = None
            print("Use audio directly without vocals seperator.")

    def get_audio_feature(self, audio_path):
        """
        Load an audio file at 16 kHz and run the feature extractor over it in chunks.

        :param audio_path: Path of the audio file to load
        :return: ``(audio_features, frame_count)`` where ``audio_features`` is the
            per-chunk ``input_features`` concatenated along the last axis and
            ``frame_count`` is ``len(samples) // 640``
        """
        audio_input, sampling_rate = librosa.load(audio_path, sr=16000)
        assert sampling_rate == 16000
        # 750 * 640 = 480000 samples, i.e. 30 s of audio at 16 kHz per chunk.
        window = 750 * 640
        audio_features = [
            self.feature_extractor(
                audio_input[start:start + window],
                sampling_rate=sampling_rate,
                return_tensors="pt",
            ).input_features
            for start in range(0, len(audio_input), window)
        ]
        audio_features = torch.cat(audio_features, dim=-1)
        # 640 samples at 16 kHz is 1/25 s, so this counts whole 1/25 s units.
        return audio_features, len(audio_input) // 640

    def preprocess(self, audio_path: str):
        """
        Compute the audio embedding of an audio file with the Whisper encoder.

        :param audio_path: Path of the audio file to process
        :return: ``(audio_emb, length)`` where ``length`` is ``audio_emb.shape[0]``
        """
        audio_input, audio_len = self.get_audio_feature(audio_path)
        audio_feature = audio_input.to(self.whisper.device).float()
        # Encode 3000 feature frames at a time.
        # NOTE(review): presumably one 30 s chunk from get_audio_feature is 3000 frames - confirm.
        window = 3000
        audio_prompts = []
        for start in range(0, audio_feature.shape[-1], window):
            hidden_states = self.whisper.encoder(
                audio_feature[:, :, start:start + window],
                output_hidden_states=True,
            ).hidden_states
            # Put every hidden state side by side on a new axis 2.
            audio_prompts.append(torch.stack(hidden_states, dim=2))
        audio_prompts = torch.cat(audio_prompts, dim=1)
        # Keep only the first 2 * audio_len steps along axis 1.
        audio_prompts = audio_prompts[:, :audio_len * 2]
        audio_emb = self.audio_emb_enc(audio_prompts, wav_enc_type="whisper")
        return audio_emb, audio_emb.shape[0]

    def audio_emb_enc(self, audio_emb, wav_enc_type="whisper"):
        """
        Merge stacked encoder hidden states into the final audio embedding.

        :param audio_emb: Stacked hidden states; for ``"whisper"`` axis 2 must
            hold at least 33 entries (noted as [1, T, 33, 1280] by the original author)
        :param wav_enc_type: ``"wav2vec"`` returns the input unchanged; ``"whisper"``
            averages axis 2 in four groups of eight, keeps entry 32 as a fifth
            group, and resamples each group from 50 to 25 fps
        :return: The merged features; for ``"whisper"`` the batch axis is dropped
            (noted as [T, 5, 1280] by the original author)
        :raises ValueError: for any other ``wav_enc_type``
        """
        if wav_enc_type == "wav2vec":
            return audio_emb
        if wav_enc_type != "whisper":
            raise ValueError(f"Unsupported wav_enc_type: {wav_enc_type}")

        # Entries 0-7, 8-15, 16-23 and 24-31 are averaged; entry 32 is used as is.
        grouped = [audio_emb[:, :, start:start + 8].mean(dim=2) for start in range(0, 32, 8)]
        grouped.append(audio_emb[:, :, 32])
        resampled = [linear_interpolation_fps(group, 50, 25) for group in grouped]
        return torch.stack(resampled, dim=2)[0]

    @staticmethod
    def _gather_window(audio_emb, start, end, zero_audio_embed):
        """Stack ``audio_emb[start:end]``, using ``zero_audio_embed`` for out-of-range indices."""
        total = audio_emb.shape[0]
        return torch.stack(
            [audio_emb[i] if 0 <= i < total else zero_audio_embed for i in range(start, end)],
            dim=0,
        )

    def get_audio_emb_window(self, audio_emb, frame_num, frame0_idx, audio_shift=2):
        """
        Cut zero-padded windows of audio embeddings, one window per group of frames.

        The first window covers indices ``frame0_idx - 2 .. frame0_idx + 2`` and is
        left-padded with three zero entries (8 entries in total). Each following
        window covers four more indices, widened by ``audio_shift`` on both sides
        (``4 + 2 * audio_shift`` entries). Because all windows are stacked, more
        than one window requires ``audio_shift == 2``.

        :param audio_emb: Tensor whose axis 0 is the index being windowed
        :param frame_num: Number of frames to cover; ``1 + (frame_num - 1) // 4`` windows are produced
        :param frame0_idx: Index into ``audio_emb`` matching the first frame
        :param audio_shift: Extra indices taken on each side of the later windows
        :return: ``(windows, index)`` where ``windows`` stacks all windows on a new
            axis 0 and ``index`` is the last window's exclusive end minus ``audio_shift``
        :raises ValueError: if ``frame_num`` is smaller than 1
        """
        if frame_num < 1:
            # No window would be produced, which used to fail deep inside torch.stack.
            raise ValueError(f"frame_num must be at least 1, got {frame_num}")

        zero_audio_embed = torch.zeros((audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device)
        zero_audio_embed_3 = torch.zeros((3, audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device)

        iter_ = 1 + (frame_num - 1) // 4
        audio_emb_wind = []
        for lt_i in range(iter_):
            if lt_i == 0:
                # First-frame VAE latent: pad the audio with zeros on the left to mark it.
                st = frame0_idx - 2
                ed = frame0_idx + 3
                wind_feat = self._gather_window(audio_emb, st, ed, zero_audio_embed)  # 5 entries
                wind_feat = torch.cat((zero_audio_embed_3, wind_feat), dim=0)  # 8 entries
            else:
                st = frame0_idx + 1 + 4 * (lt_i - 1) - audio_shift
                ed = frame0_idx + 1 + 4 * lt_i + audio_shift
                wind_feat = self._gather_window(audio_emb, st, ed, zero_audio_embed)  # 4 + 2 * audio_shift entries
            audio_emb_wind.append(wind_feat)
        audio_emb_wind = torch.stack(audio_emb_wind, dim=0)  # [iter_, 8, ...]

        return audio_emb_wind, ed - audio_shift

    def close(self):
        """
        TODO: to be implemented

        Currently releases nothing and returns the instance itself.
        """
        return self

    def __enter__(self):
        """Return the instance itself for use in a ``with`` statement."""
        return self

    def __exit__(self, _exc_type, _exc_val, _exc_tb):
        """Call :meth:`close` on leaving the ``with`` block; exceptions are not suppressed."""
        self.close()
|