import random
import os
import torch
import torch.distributed as dist
from PIL import Image
import subprocess
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch.nn as nn
# from shared.utils.multitalk_utils import save_video_ffmpeg
# from .kokoro import KPipeline
from transformers import Wav2Vec2FeatureExtractor
from .wav2vec2 import Wav2Vec2Model
import librosa
import pyloudnorm as pyln
import numpy as np
from einops import rearrange
import soundfile as sf
import re
import math
from shared.utils import files_locator as fl

def custom_init(device, wav2vec):
    audio_encoder = Wav2Vec2Model.from_pretrained(wav2vec, local_files_only=True).to(device)
    audio_encoder.feature_extractor._freeze_parameters()
    wav2vec_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(wav2vec, local_files_only=True)
    return wav2vec_feature_extractor, audio_encoder
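
# Usage sketch (not part of the original module; the checkpoint path is an assumption):
#
#   wav2vec_feature_extractor, audio_encoder = custom_init('cpu', 'ckpts/chinese-wav2vec2-base')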

def loudness_norm(audio_array, sr=16000, lufs=-23):
    meter = pyln.Meter(sr)
    loudness = meter.integrated_loudness(audio_array)
    if abs(loudness) > 100:
        return audio_array
    normalized_audio = pyln.normalize.loudness(audio_array, loudness, lufs)
    return normalized_audio
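
# Illustrative example (assumes a 1 s, 16 kHz mono signal; not from the original code):
#
#   t = np.linspace(0, 1, 16000, endpoint=False)
#   tone = 0.1 * np.sin(2 * np.pi * 440 * t)
#   normalized = loudness_norm(tone, sr=16000, lufs=-23)   # rescaled to roughly -23 LUFS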

def get_embedding(speech_array, wav2vec_feature_extractor, audio_encoder, sr=16000, device='cpu', fps=25):
    audio_duration = len(speech_array) / sr
    video_length = audio_duration * fps

    # wav2vec feature extraction
    audio_feature = np.squeeze(
        wav2vec_feature_extractor(speech_array, sampling_rate=sr).input_values
    )
    audio_feature = torch.from_numpy(audio_feature).float().to(device=device)
    audio_feature = audio_feature.unsqueeze(0)

    # audio encoder
    with torch.no_grad():
        embeddings = audio_encoder(audio_feature, seq_len=int(video_length), output_hidden_states=True)

    if len(embeddings) == 0:
        print("Failed to extract audio embedding")
        return None

    audio_emb = torch.stack(embeddings.hidden_states[1:], dim=1).squeeze(0)
    audio_emb = rearrange(audio_emb, "b s d -> s b d")

    audio_emb = audio_emb.cpu().detach()
    return audio_emb
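
# Usage sketch (not in the original file; file name and shapes are assumptions). With the
# custom Wav2Vec2Model above, the result is expected to stack the encoder's hidden states
# per output video frame, roughly (video_frames, num_hidden_layers, hidden_dim):
#
#   speech, _ = librosa.load('speaker.wav', sr=16000, duration=4)
#   emb = get_embedding(speech, wav2vec_feature_extractor, audio_encoder, sr=16000, fps=25)
#   # e.g. emb.shape ~ (100, 12, 768) for 4 s at 25 fps with a wav2vec2-base checkpoint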

def extract_audio_from_video(filename, sample_rate):
    raw_audio_path = filename.split('/')[-1].split('.')[0] + '.wav'

    # extract a 16 kHz stereo PCM track from the video with ffmpeg
    ffmpeg_command = [
        "ffmpeg",
        "-y",
        "-i",
        str(filename),
        "-vn",
        "-acodec",
        "pcm_s16le",
        "-ar",
        "16000",
        "-ac",
        "2",
        str(raw_audio_path),
    ]
    subprocess.run(ffmpeg_command, check=True)

    human_speech_array, sr = librosa.load(raw_audio_path, sr=sample_rate)
    human_speech_array = loudness_norm(human_speech_array, sr)
    os.remove(raw_audio_path)

    return human_speech_array

def audio_prepare_single(audio_path, sample_rate=16000, duration=0):
    ext = os.path.splitext(audio_path)[1].lower()
    if ext in ['.mp4', '.mov', '.avi', '.mkv']:
        human_speech_array = extract_audio_from_video(audio_path, sample_rate)
        return human_speech_array
    else:
        human_speech_array, sr = librosa.load(audio_path, duration=duration, sr=sample_rate)
        human_speech_array = loudness_norm(human_speech_array, sr)
        return human_speech_array
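
# Example calls (illustrative only; file names are assumptions):
#
#   speech = audio_prepare_single('interview.mp4', sample_rate=16000)              # audio track of a video
#   speech = audio_prepare_single('voice.wav', sample_rate=16000, duration=8)      # first 8 s of a wav file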

def audio_prepare_multi(left_path, right_path, audio_type="add", sample_rate=16000, duration=0, pad=0, min_audio_duration=0):
    if not (left_path is None or right_path is None):
        human_speech_array1 = audio_prepare_single(left_path, duration=duration)
        human_speech_array2 = audio_prepare_single(right_path, duration=duration)
    else:
        audio_type = 'para'
        if left_path is None:
            human_speech_array2 = audio_prepare_single(right_path, duration=duration)
            human_speech_array1 = np.zeros(human_speech_array2.shape[0])
        elif right_path is None:
            human_speech_array1 = audio_prepare_single(left_path, duration=duration)
            human_speech_array2 = np.zeros(human_speech_array1.shape[0])

    if audio_type == 'para':
        new_human_speech1 = human_speech_array1
        new_human_speech2 = human_speech_array2
        if len(new_human_speech1) != len(new_human_speech2):
            if len(new_human_speech1) < len(new_human_speech2):
                new_human_speech1 = np.pad(new_human_speech1, (0, len(new_human_speech2) - len(new_human_speech1)))
            else:
                new_human_speech2 = np.pad(new_human_speech2, (0, len(new_human_speech1) - len(new_human_speech2)))
    elif audio_type == 'add':
        new_human_speech1 = np.concatenate([human_speech_array1[: human_speech_array1.shape[0]], np.zeros(human_speech_array2.shape[0])])
        new_human_speech2 = np.concatenate([np.zeros(human_speech_array1.shape[0]), human_speech_array2[: human_speech_array2.shape[0]]])

    duration_changed = False
    if min_audio_duration > 0:
        min_samples = math.ceil(min_audio_duration * sample_rate)
        if len(new_human_speech1) < min_samples:
            new_human_speech1 = np.concatenate([new_human_speech1, np.zeros(min_samples - len(new_human_speech1))])
            duration_changed = True
        if len(new_human_speech2) < min_samples:
            new_human_speech2 = np.concatenate([new_human_speech2, np.zeros(min_samples - len(new_human_speech2))])
            duration_changed = True

    # don't include the front padding in the summed audio, which is used to build the output audio track
    sum_human_speechs = new_human_speech1 + new_human_speech2
    if pad > 0:
        duration_changed = True
        new_human_speech1 = np.concatenate([np.zeros(pad), new_human_speech1])
        new_human_speech2 = np.concatenate([np.zeros(pad), new_human_speech2])
    return new_human_speech1, new_human_speech2, sum_human_speechs, duration_changed
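
# Illustrative call (paths are placeholders): with audio_type="add" the two speakers are laid out
# one after the other on separate tracks, with audio_type="para" they are kept in parallel and
# padded to the same length.
#
#   s1, s2, mix, changed = audio_prepare_multi('speaker1.wav', 'speaker2.wav', audio_type='add', duration=4)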

def process_tts_single(text, save_dir, voice1):
    s1_sentences = []

    # NOTE: KPipeline comes from the kokoro import that is commented out at the top of this file
    pipeline = KPipeline(lang_code='a', repo_id='weights/Kokoro-82M')

    voice_tensor = torch.load(voice1, weights_only=True)
    generator = pipeline(
        text, voice=voice_tensor,  # <= change voice here
        speed=1, split_pattern=r'\n+'
    )
    audios = []
    for i, (gs, ps, audio) in enumerate(generator):
        audios.append(audio)
    audios = torch.concat(audios, dim=0)
    s1_sentences.append(audios)

    s1_sentences = torch.concat(s1_sentences, dim=0)
    save_path1 = f'{save_dir}/s1.wav'
    sf.write(save_path1, s1_sentences, 24000)  # save each audio file
    s1, _ = librosa.load(save_path1, sr=16000)
    return s1, save_path1

def process_tts_multi(text, save_dir, voice1, voice2):
    pattern = r'\(s(\d+)\)\s*(.*?)(?=\s*\(s\d+\)|$)'
    matches = re.findall(pattern, text, re.DOTALL)

    s1_sentences = []
    s2_sentences = []

    pipeline = KPipeline(lang_code='a', repo_id='weights/Kokoro-82M')
    for idx, (speaker, content) in enumerate(matches):
        if speaker == '1':
            voice_tensor = torch.load(voice1, weights_only=True)
            generator = pipeline(
                content, voice=voice_tensor,  # <= change voice here
                speed=1, split_pattern=r'\n+'
            )
            audios = []
            for i, (gs, ps, audio) in enumerate(generator):
                audios.append(audio)
            audios = torch.concat(audios, dim=0)
            s1_sentences.append(audios)
            s2_sentences.append(torch.zeros_like(audios))
        elif speaker == '2':
            voice_tensor = torch.load(voice2, weights_only=True)
            generator = pipeline(
                content, voice=voice_tensor,  # <= change voice here
                speed=1, split_pattern=r'\n+'
            )
            audios = []
            for i, (gs, ps, audio) in enumerate(generator):
                audios.append(audio)
            audios = torch.concat(audios, dim=0)
            s2_sentences.append(audios)
            s1_sentences.append(torch.zeros_like(audios))

    s1_sentences = torch.concat(s1_sentences, dim=0)
    s2_sentences = torch.concat(s2_sentences, dim=0)
    sum_sentences = s1_sentences + s2_sentences
    save_path1 = f'{save_dir}/s1.wav'
    save_path2 = f'{save_dir}/s2.wav'
    save_path_sum = f'{save_dir}/sum.wav'
    sf.write(save_path1, s1_sentences, 24000)  # save each audio file
    sf.write(save_path2, s2_sentences, 24000)
    sf.write(save_path_sum, sum_sentences, 24000)
    s1, _ = librosa.load(save_path1, sr=16000)
    s2, _ = librosa.load(save_path2, sr=16000)
    # sum, _ = librosa.load(save_path_sum, sr=16000)
    return s1, s2, save_path_sum
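
# The multi-speaker TTS input tags each utterance with (s1)/(s2). A hypothetical call
# (voice tensor paths are assumptions, and the kokoro KPipeline import above must be enabled):
#
#   text = "(s1) Hello there. (s2) Hi, nice to meet you."
#   s1, s2, sum_path = process_tts_multi(text, 'outputs', 'voices/voice1.pt', 'voices/voice2.pt')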

def get_full_audio_embeddings(audio_guide1=None, audio_guide2=None, combination_type="add", num_frames=0, fps=25, sr=16000, padded_frames_for_embeddings=0, min_audio_duration=0, return_sum_only=False):
    wav2vec_feature_extractor, audio_encoder = custom_init('cpu', fl.locate_folder("chinese-wav2vec2-base"))
    # wav2vec_feature_extractor, audio_encoder = custom_init('cpu', "ckpts/wav2vec")
    pad = int(padded_frames_for_embeddings / fps * sr)
    new_human_speech1, new_human_speech2, sum_human_speechs, duration_changed = audio_prepare_multi(audio_guide1, audio_guide2, combination_type, duration=num_frames / fps, pad=pad, min_audio_duration=min_audio_duration)
    if return_sum_only:
        full_audio_embs = None
    else:
        audio_embedding_1 = get_embedding(new_human_speech1, wav2vec_feature_extractor, audio_encoder, sr=sr, fps=fps)
        audio_embedding_2 = get_embedding(new_human_speech2, wav2vec_feature_extractor, audio_encoder, sr=sr, fps=fps)
        full_audio_embs = []
        if audio_guide1 is not None: full_audio_embs.append(audio_embedding_1)
        if audio_guide2 is not None: full_audio_embs.append(audio_embedding_2)
        if audio_guide2 is None and not duration_changed: sum_human_speechs = None

    return full_audio_embs, sum_human_speechs
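
# Usage sketch (not part of the original code; file paths are assumptions). For a 100-frame,
# 25 fps clip driven by two speakers mixed one after the other:
#
#   full_audio_embs, mixed_audio = get_full_audio_embeddings(
#       audio_guide1='speaker1.wav', audio_guide2='speaker2.wav',
#       combination_type='add', num_frames=100, fps=25, sr=16000)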

def get_window_audio_embeddings(full_audio_embs, audio_start_idx=0, clip_length=81, vae_scale=4, audio_window=5):
    if full_audio_embs is None: return None
    HUMAN_NUMBER = len(full_audio_embs)
    audio_end_idx = audio_start_idx + clip_length
    indices = (torch.arange(2 * 2 + 1) - 2) * 1

    # split audio with window size
    audio_embs = []
    for human_idx in range(HUMAN_NUMBER):
        center_indices = torch.arange(audio_start_idx, audio_end_idx, 1).unsqueeze(1) + indices.unsqueeze(0)
        center_indices = torch.clamp(center_indices, min=0, max=full_audio_embs[human_idx].shape[0] - 1).to(full_audio_embs[human_idx].device)
        audio_emb = full_audio_embs[human_idx][center_indices][None, ...]  # .to(self.device)
        audio_embs.append(audio_emb)
    audio_embs = torch.concat(audio_embs, dim=0)  # .to(self.param_dtype)

    # audio_cond = audio.to(device=x.device, dtype=x.dtype)
    audio_cond = audio_embs
    first_frame_audio_emb_s = audio_cond[:, :1, ...]
    latter_frame_audio_emb = audio_cond[:, 1:, ...]
    latter_frame_audio_emb = rearrange(latter_frame_audio_emb, "b (n_t n) w s c -> b n_t n w s c", n=vae_scale)
    middle_index = audio_window // 2
    latter_first_frame_audio_emb = latter_frame_audio_emb[:, :, :1, :middle_index + 1, ...]
    latter_first_frame_audio_emb = rearrange(latter_first_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
    latter_last_frame_audio_emb = latter_frame_audio_emb[:, :, -1:, middle_index:, ...]
    latter_last_frame_audio_emb = rearrange(latter_last_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
    latter_middle_frame_audio_emb = latter_frame_audio_emb[:, :, 1:-1, middle_index:middle_index + 1, ...]
    latter_middle_frame_audio_emb = rearrange(latter_middle_frame_audio_emb, "b n_t n w s c -> b n_t (n w) s c")
    latter_frame_audio_emb_s = torch.concat([latter_first_frame_audio_emb, latter_middle_frame_audio_emb, latter_last_frame_audio_emb], dim=2)

    return [first_frame_audio_emb_s, latter_frame_audio_emb_s]
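
# Sliding-window sketch (illustrative): for a generation window of clip_length frames starting at
# audio_start_idx, each frame gets a 5-frame context of audio embeddings (offsets -2..+2), and the
# embeddings are then regrouped to follow the temporal compression of the VAE (vae_scale).
#
#   first_frame_emb, latter_frames_emb = get_window_audio_embeddings(full_audio_embs, audio_start_idx=0, clip_length=81)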

def resize_and_centercrop(cond_image, target_size):
    """
    Resize image or tensor to the target size without padding.
    """
    # Get the original size
    if isinstance(cond_image, torch.Tensor):
        _, orig_h, orig_w = cond_image.shape
    else:
        orig_h, orig_w = cond_image.height, cond_image.width

    target_h, target_w = target_size

    # Calculate the scaling factor for resizing
    scale_h = target_h / orig_h
    scale_w = target_w / orig_w

    # Compute the final size
    scale = max(scale_h, scale_w)
    final_h = math.ceil(scale * orig_h)
    final_w = math.ceil(scale * orig_w)

    # Resize
    if isinstance(cond_image, torch.Tensor):
        if len(cond_image.shape) == 3:
            cond_image = cond_image[None]
        resized_tensor = nn.functional.interpolate(cond_image, size=(final_h, final_w), mode='nearest').contiguous()
        # crop
        cropped_tensor = transforms.functional.center_crop(resized_tensor, target_size)
        cropped_tensor = cropped_tensor.squeeze(0)
    else:
        resized_image = cond_image.resize((final_w, final_h), resample=Image.BILINEAR)
        resized_image = np.array(resized_image)
        # tensor and crop
        resized_tensor = torch.from_numpy(resized_image)[None, ...].permute(0, 3, 1, 2).contiguous()
        cropped_tensor = transforms.functional.center_crop(resized_tensor, target_size)
        cropped_tensor = cropped_tensor[:, :, None, :, :]

    return cropped_tensor
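
# Example (illustrative; the image file is an assumption): scale-to-cover then center-crop,
# so the aspect ratio is preserved and no padding is introduced.
#
#   img = Image.open('reference.png').convert('RGB')
#   cond = resize_and_centercrop(img, (480, 832))   # PIL input -> tensor of shape (1, 3, 1, 480, 832)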

def timestep_transform(
    t,
    shift=5.0,
    num_timesteps=1000,
):
    t = t / num_timesteps
    # shift the timestep based on ratio
    new_t = shift * t / (1 + (shift - 1) * t)
    new_t = new_t * num_timesteps
    return new_t
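
# Worked example: with shift=5.0 and num_timesteps=1000, the midpoint t=500 maps to
# 5 * 0.5 / (1 + 4 * 0.5) * 1000 = 2.5 / 3 * 1000 ~= 833, i.e. timesteps are pushed toward
# the high-noise end of the schedule.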

def parse_speakers_locations(speakers_locations):
    bbox = {}
    if speakers_locations is None or len(speakers_locations) == 0:
        return None, ""
    speakers = speakers_locations.split(" ")
    if len(speakers) != 2:
        error = "Two speaker locations should be defined"
        return "", error

    for i, speaker in enumerate(speakers):
        location = speaker.strip().split(":")
        if len(location) not in (2, 4):
            error = f"Invalid Speaker Location '{location}'. A Speaker Location should be defined in the format Left:Right or using a BBox Left:Top:Right:Bottom"
            return "", error
        try:
            good = False
            location_float = [float(val) for val in location]
            good = all(0 <= val <= 100 for val in location_float)
        except:
            pass
        if not good:
            error = f"Invalid Speaker Location '{location}'. Each number should be between 0 and 100."
            return "", error
        if len(location_float) == 2:
            location_float = [location_float[0], 0, location_float[1], 100]
        bbox[f"human{i}"] = location_float
    return bbox, ""
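
# Example inputs (illustrative): two speakers given either as Left:Right bands or as full bboxes.
#
#   parse_speakers_locations("0:50 50:100")
#   # -> ({'human0': [0.0, 0, 50.0, 100], 'human1': [50.0, 0, 100.0, 100]}, '')
#   parse_speakers_locations("0:10:50:90 50:10:100:90")
#   # -> ({'human0': [0.0, 10.0, 50.0, 90.0], 'human1': [50.0, 10.0, 100.0, 90.0]}, '')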

# construct human masks
def get_target_masks(HUMAN_NUMBER, lat_h, lat_w, src_h, src_w, face_scale=0.05, bbox=None):
    human_masks = []
    if HUMAN_NUMBER == 1:
        background_mask = torch.ones([src_h, src_w])
        human_mask1 = torch.ones([src_h, src_w])
        human_mask2 = torch.ones([src_h, src_w])
        human_masks = [human_mask1, human_mask2, background_mask]
    elif HUMAN_NUMBER == 2:
        if bbox is not None:
            assert len(bbox) == HUMAN_NUMBER, "The number of target bboxes should match HUMAN_NUMBER"
            background_mask = torch.zeros([src_h, src_w])
            for _, person_bbox in bbox.items():
                y_min, x_min, y_max, x_max = person_bbox
                x_min, y_min, x_max, y_max = max(x_min, 5), max(y_min, 5), min(x_max, 95), min(y_max, 95)
                x_min, y_min, x_max, y_max = int(src_h * x_min / 100), int(src_w * y_min / 100), int(src_h * x_max / 100), int(src_w * y_max / 100)
                human_mask = torch.zeros([src_h, src_w])
                human_mask[int(x_min):int(x_max), int(y_min):int(y_max)] = 1
                background_mask += human_mask
                human_masks.append(human_mask)
        else:
            x_min, x_max = int(src_h * face_scale), int(src_h * (1 - face_scale))
            background_mask = torch.zeros([src_h, src_w])
            human_mask1 = torch.zeros([src_h, src_w])
            human_mask2 = torch.zeros([src_h, src_w])
            lefty_min, lefty_max = int((src_w // 2) * face_scale), int((src_w // 2) * (1 - face_scale))
            righty_min, righty_max = int((src_w // 2) * face_scale + (src_w // 2)), int((src_w // 2) * (1 - face_scale) + (src_w // 2))
            human_mask1[x_min:x_max, lefty_min:lefty_max] = 1
            human_mask2[x_min:x_max, righty_min:righty_max] = 1
            background_mask += human_mask1
            background_mask += human_mask2
            human_masks = [human_mask1, human_mask2]
        background_mask = torch.where(background_mask > 0, torch.tensor(0), torch.tensor(1))
        human_masks.append(background_mask)
    # toto = Image.fromarray(human_masks[2].mul_(255).unsqueeze(-1).repeat(1,1,3).to(torch.uint8).cpu().numpy())
    ref_target_masks = torch.stack(human_masks, dim=0)  # .to(self.device)
    # resize and centercrop for ref_target_masks
    # ref_target_masks = resize_and_centercrop(ref_target_masks, (target_h, target_w))
    N_h, N_w = lat_h // 2, lat_w // 2
    token_ref_target_masks = F.interpolate(ref_target_masks.unsqueeze(0), size=(N_h, N_w), mode='nearest').squeeze()
    token_ref_target_masks = (token_ref_target_masks > 0)
    token_ref_target_masks = token_ref_target_masks.float()  # .to(self.device)
    token_ref_target_masks = token_ref_target_masks.view(token_ref_target_masks.shape[0], -1)
    return token_ref_target_masks
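
# Usage sketch (illustrative sizes; not part of the original code): per-speaker attention masks
# at latent-token resolution for two speakers side by side in a 480x832 frame with 60x104 latents.
#
#   bbox, err = parse_speakers_locations("0:50 50:100")
#   masks = get_target_masks(2, lat_h=60, lat_w=104, src_h=480, src_w=832, bbox=bbox)
#   # masks.shape -> (3, (60 // 2) * (104 // 2)) = (3, 1560): speaker 1, speaker 2, background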