HuMo_local / humo /utils /audio_processor_whisper.py
alexnasa's picture
Upload 54 files
295978e verified
# pylint: disable=C0301
'''
This module contains the AudioProcessor class and related functions for processing audio data.
It utilizes various libraries and models to perform tasks such as preprocessing, feature extraction,
and audio separation. The class is initialized with configuration parameters and can process
audio files using the provided models.
'''
import os
import subprocess
import librosa
import numpy as np
import torch
from audio_separator.separator import Separator
from transformers import WhisperModel, AutoFeatureExtractor
import torch.nn.functional as F
def linear_interpolation_fps(features, input_fps, output_fps, output_len=None):
    """Linearly resample a time-major feature tensor from ``input_fps`` to ``output_fps``.

    The time axis is treated as a signal sampled at ``input_fps``: its duration in
    seconds is ``T / input_fps`` and, unless ``output_len`` is given, the resampled
    length is ``int(duration * output_fps)`` (truncated toward zero).

    :param features: 3-D tensor laid out as ``[batch, T, channels]``; axis 1 is the
        time axis that gets resampled (``mode='linear'`` requires 3-D input).
    :param input_fps: sampling rate of ``features`` along axis 1.
    :param output_fps: target rate, used only when ``output_len`` is None.
    :param output_len: explicit number of output steps; overrides the rate-based length.
    :return: tensor of shape ``[batch, output_len, channels]``.
    """
    # F.interpolate resamples the LAST axis, so move time there: [B, C, T].
    time_last = features.transpose(1, 2)
    duration_sec = time_last.shape[2] / float(input_fps)
    target_len = int(duration_sec * output_fps) if output_len is None else output_len
    resampled = F.interpolate(time_last, size=target_len, mode='linear', align_corners=True)
    # Restore the original [B, T, C] layout.
    return resampled.transpose(1, 2)
def resample_audio(input_audio_file: str, output_audio_file: str, sample_rate: int):
    """Resample an audio file to ``sample_rate`` Hz using the ``ffmpeg`` CLI.

    ffmpeg is called with ``-y`` (overwrite the output) and ``-v error`` (errors-only
    logging). The command is passed as an argument list with no shell, so file names
    are never shell-interpreted.

    :param input_audio_file: path of the source audio file.
    :param output_audio_file: path the resampled audio is written to.
    :param sample_rate: target sampling rate in Hz (passed to ``-ar``).
    :return: ``output_audio_file``.
    :raises AssertionError: if ffmpeg exits with a non-zero status.
    :raises OSError: (e.g. FileNotFoundError) if the ``ffmpeg`` executable cannot be started.
    """
    cmd = [
        "ffmpeg", "-y", "-v", "error",
        "-i", input_audio_file,
        "-ar", str(sample_rate),
        output_audio_file,
    ]
    # subprocess.run waits for the child and cleans it up even if we are interrupted.
    completed = subprocess.run(cmd, check=False)
    # Explicit raise rather than `assert`: asserts are stripped under `python -O`, which
    # would silently return a missing/stale output path. AssertionError is kept so
    # existing callers that catch it keep working.
    if completed.returncode != 0:
        raise AssertionError("Resample audio failed!")
    return output_audio_file
class AudioProcessor:
    """
    AudioProcessor turns an audio file into Whisper-encoder embeddings and slices
    those embeddings into per-latent windows.

    The constructor loads a Whisper model and its feature extractor from
    ``wav2vec_model_path``. It can also load an ``audio_separator`` model
    configured to output only the "vocals" stem.

    :param sample_rate: Sampling rate of the audio file. It is stored on the instance;
        the feature extraction in this class always loads audio at 16 kHz.
    :param fps: Frames per second for the extracted features (stored on the instance)
    :param wav2vec_model_path: Whisper checkpoint path/name passed to
        ``WhisperModel.from_pretrained`` and ``AutoFeatureExtractor.from_pretrained``.
        The parameter name is historical.
    :param wav2vec_feature_type: Accepted for interface compatibility; not used by any
        code in this class
    :param audio_separator_model_path: Directory holding the audio separator model
    :param audio_separator_model_name: Name of the audio separator model; when None,
        no separator is loaded and ``self.audio_separator`` is None
    :param cache_dir: Output directory for the separator's intermediate results
    :param device: Device the Whisper model is moved to
    """
    def __init__(
        self,
        sample_rate,
        fps,
        wav2vec_model_path,
        wav2vec_feature_type,
        audio_separator_model_path:str=None,
        audio_separator_model_name:str=None,
        cache_dir:str='',
        device="cuda:0",
    ) -> None:
        self.sample_rate = sample_rate
        self.fps = fps
        self.device = device
        # Frozen Whisper model in eval mode: used for inference only.
        self.whisper = WhisperModel.from_pretrained(wav2vec_model_path).to(device).eval()
        self.whisper.requires_grad_(False)
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(wav2vec_model_path)
        if audio_separator_model_name is not None:
            try:
                os.makedirs(cache_dir, exist_ok=True)
            except OSError:
                # Best effort (e.g. an empty cache_dir cannot be created); the separator
                # is still constructed with the given directory.
                print("Fail to create the output cache dir.")
            self.audio_separator = Separator(
                output_dir=cache_dir,
                output_single_stem="vocals",
                model_file_dir=audio_separator_model_path,
            )
            self.audio_separator.load_model(audio_separator_model_name)
            assert self.audio_separator.model_instance is not None, "Fail to load audio separate model."
        else:
            self.audio_separator=None
            print("Use audio directly without vocals seperator.")

    def get_audio_feature(self, audio_path):
        """
        Load ``audio_path`` as 16 kHz audio and compute Whisper input features.

        The waveform is passed to the feature extractor in chunks of
        750 * 640 = 480000 samples (30 s at 16 kHz). The per-chunk features are
        concatenated along the last axis.
        NOTE(review): this presumably relies on the Whisper extractor padding each
        chunk to a fixed 3000 feature frames, which matches ``preprocess``'s window;
        confirm.

        :param audio_path: path of the audio file to load
        :return: ``(features, n_frames)`` where ``n_frames = len(samples) // 640``,
            i.e. the number of whole 40 ms (25 fps) frames in the audio
        """
        audio_input, sampling_rate = librosa.load(audio_path, sr=16000)
        assert sampling_rate == 16000
        audio_features = []
        window = 750*640  # 480000 samples = 30 s at 16 kHz
        for i in range(0, len(audio_input), window):
            audio_feature = self.feature_extractor(audio_input[i:i+window],
                                                   sampling_rate=sampling_rate,
                                                   return_tensors="pt",
                                                   ).input_features
            audio_features.append(audio_feature)
        audio_features = torch.cat(audio_features, dim=-1)
        return audio_features, len(audio_input) // 640

    def preprocess(self, audio_path: str):
        """
        Encode ``audio_path`` with the Whisper encoder and merge its hidden states.

        Processing steps:

        - The input features are encoded in windows of 3000 feature frames.
        - For each window, all hidden states returned by the encoder are stacked
          on dim 2.
        - The windows are concatenated along dim 1 and truncated to
          ``2 * n_frames`` steps. Presumably this is the encoder's 50 Hz rate
          versus the 25 fps frame count; confirm against the model.
        - The result is reduced by ``audio_emb_enc``.

        :param audio_path: path of the audio file to encode
        :return: ``(audio_emb, audio_emb.shape[0])``
        """
        audio_input, audio_len = self.get_audio_feature(audio_path)
        audio_feature = audio_input.to(self.whisper.device).float()
        window = 3000
        audio_prompts = []
        for i in range(0, audio_feature.shape[-1], window):
            audio_prompt = self.whisper.encoder(audio_feature[:,:,i:i+window], output_hidden_states=True).hidden_states
            audio_prompt = torch.stack(audio_prompt, dim=2)
            audio_prompts.append(audio_prompt)
        audio_prompts = torch.cat(audio_prompts, dim=1)
        # Drop the encoder frames produced for padding beyond the real audio length.
        audio_prompts = audio_prompts[:,:audio_len*2]
        audio_emb = self.audio_emb_enc(audio_prompts, wav_enc_type="whisper")
        return audio_emb, audio_emb.shape[0]

    def audio_emb_enc(self, audio_emb, wav_enc_type="whisper"):
        """
        Reduce stacked encoder hidden states to the per-frame embedding used downstream.

        - ``"wav2vec"``: ``audio_emb`` is returned unchanged.
        - ``"whisper"``: the hidden-state axis (dim 2) is reduced to 5 groups: the
          means of layers 0-7, 8-15, 16-23 and 24-31, plus layer 32 alone. Each
          group is resampled in time from 50 to 25 fps, and the groups are stacked
          on dim 2. Only batch element 0 is returned.

        :raises ValueError: for any other ``wav_enc_type``
        """
        if wav_enc_type == "wav2vec":
            feat_merge = audio_emb
        elif wav_enc_type == "whisper":
            # Expected input (per the original author): [1, T, 33, 1280]
            feat0 = linear_interpolation_fps(audio_emb[:, :, 0: 8].mean(dim=2), 50, 25)
            feat1 = linear_interpolation_fps(audio_emb[:, :, 8: 16].mean(dim=2), 50, 25)
            feat2 = linear_interpolation_fps(audio_emb[:, :, 16: 24].mean(dim=2), 50, 25)
            feat3 = linear_interpolation_fps(audio_emb[:, :, 24: 32].mean(dim=2), 50, 25)
            feat4 = linear_interpolation_fps(audio_emb[:, :, 32], 50, 25)
            feat_merge = torch.stack([feat0, feat1, feat2, feat3, feat4], dim=2)[0] # [T, 5, 1280]
        else:
            raise ValueError(f"Unsupported wav_enc_type: {wav_enc_type}")
        return feat_merge

    def get_audio_emb_window(self, audio_emb, frame_num, frame0_idx, audio_shift=2):
        """
        Gather a window of per-frame audio embeddings for every latent frame.

        Latent 0 covers the single video frame ``frame0_idx``. Each later latent
        ``k`` covers the four frames ``frame0_idx + 1 + 4*(k-1)`` to
        ``frame0_idx + 4*k``. Every window is widened by ``audio_shift`` frames on
        both sides, so it has ``4 + 2*audio_shift`` rows.

        Latent 0's window has only ``1 + 2*audio_shift`` rows, so it is
        left-padded with 3 zero rows to reach the same length. Frame indices
        outside ``[0, audio_emb.shape[0])`` read as zero embeddings.

        :param audio_emb: 3-D tensor indexed by frame along dim 0 (e.g. the output of
            ``preprocess``)
        :param frame_num: number of video frames; yields ``1 + (frame_num - 1) // 4`` windows
        :param frame0_idx: audio frame index aligned with the first video frame
        :param audio_shift: context frames added on each side of every window
        :return: ``(windows, next_idx)``. ``windows`` has shape
            ``[1 + (frame_num - 1) // 4, 4 + 2*audio_shift, audio_emb.shape[1], audio_emb.shape[2]]``.
            ``next_idx`` is the first frame after the last window's core (unshifted) span.
        """
        n_audio_frames = audio_emb.shape[0]
        zero_audio_embed = torch.zeros((audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device)
        zero_audio_embed_3 = torch.zeros((3, audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device)

        def _gather(st, ed):
            # Stack frames [st, ed); out-of-range (negative or past-the-end) frames are zeros.
            return torch.stack([
                audio_emb[i] if (0 <= i < n_audio_frames) else zero_audio_embed
                for i in range(st, ed)
            ], dim=0)

        iter_ = 1 + (frame_num - 1) // 4
        audio_emb_wind = []
        for lt_i in range(iter_):
            if lt_i == 0:
                # First VAE latent = first video frame only: take frame0_idx +/- audio_shift
                # and left-pad the audio with 3 zero rows to mark it. The padding also makes
                # its length (1 + 2*shift + 3) equal to the later windows' (4 + 2*shift).
                # This used to hard-code shift=2, which broke torch.stack for other shifts.
                st = frame0_idx - audio_shift
                ed = frame0_idx + 1 + audio_shift
                wind_feat = torch.cat((zero_audio_embed_3, _gather(st, ed)), dim=0)
            else:
                st = frame0_idx + 1 + 4 * (lt_i - 1) - audio_shift
                ed = frame0_idx + 1 + 4 * lt_i + audio_shift
                wind_feat = _gather(st, ed)
            audio_emb_wind.append(wind_feat)
        audio_emb_wind = torch.stack(audio_emb_wind, dim=0)
        return audio_emb_wind, ed - audio_shift

    def close(self):
        """
        Release held resources. Currently a no-op placeholder that returns ``self``.

        TODO: to be implemented
        """
        return self

    def __enter__(self):
        return self

    def __exit__(self, _exc_type, _exc_val, _exc_tb):
        self.close()