import gradio as gr
import torch
import torch.nn as nn
import soundfile as sf
import os
import numpy as np
import noisereduce as nr
from typing import Optional, Iterator
from transformers import AutoTokenizer, VitsModel  # make sure these are installed and importable
from concurrent.futures import ThreadPoolExecutor, as_completed
from huggingface_hub import login

# Select the device (CPU or GPU).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("✅ Running on:", device)

# Hugging Face access token, read from the "key_" environment variable.
token = os.environ.get("key_")
models = {}  # cache of already-loaded models, keyed by model id

login(token=token)


# Noise filter: spectral-gating noise reduction on the synthesized waveform.
def remove_noise_nr(audio_data, sr=16000):
    reduced_noise = nr.reduce_noise(y=audio_data, hop_length=256, sr=sr)
    return reduced_noise


def _inference_forward_stream(
    self,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    speaker_embeddings: Optional[torch.Tensor] = None,
    chunk_size: int = 32,
    is_streaming: bool = True,
) -> Iterator[torch.Tensor]:
    """VITS inference pass that yields the predicted spectrogram in chunks,
    so the decoder can run per chunk instead of on the full utterance."""
    padding_mask = (
        attention_mask.unsqueeze(-1).float()
        if attention_mask is not None
        else torch.ones_like(input_ids).unsqueeze(-1).float()
    )

    text_encoder_output = self.text_encoder(
        input_ids=input_ids,
        padding_mask=padding_mask,
        attention_mask=attention_mask,
    )
    hidden_states = text_encoder_output[0].transpose(1, 2)
    input_padding_mask = padding_mask.transpose(1, 2)
    prior_means = text_encoder_output[1]
    prior_log_variances = text_encoder_output[2]

    # Duration prediction.
    if self.config.use_stochastic_duration_prediction:
        log_duration = self.duration_predictor(
            hidden_states,
            input_padding_mask,
            speaker_embeddings,
            reverse=True,
            noise_scale=self.noise_scale_duration,
        )
    else:
        log_duration = self.duration_predictor(hidden_states, input_padding_mask, speaker_embeddings)

    length_scale = 1.0 / self.speaking_rate
    duration = torch.ceil(torch.exp(log_duration) * input_padding_mask * length_scale)
    predicted_lengths = torch.clamp_min(torch.sum(duration, [1, 2]), 1).long()

    # Build the output padding mask from the predicted lengths.
    indices = torch.arange(predicted_lengths.max(), dtype=predicted_lengths.dtype, device=predicted_lengths.device)
    output_padding_mask = indices.unsqueeze(0) < predicted_lengths.unsqueeze(1)
    output_padding_mask = output_padding_mask.unsqueeze(1).to(input_padding_mask.dtype)

    # Monotonic alignment between text tokens and output frames.
    attn_mask = torch.unsqueeze(input_padding_mask, 2) * torch.unsqueeze(output_padding_mask, -1)
    batch_size, _, output_length, input_length = attn_mask.shape
    cum_duration = torch.cumsum(duration, -1).view(batch_size * input_length, 1)
    indices = torch.arange(output_length, dtype=duration.dtype, device=duration.device)
    valid_indices = indices.unsqueeze(0) < cum_duration
    valid_indices = valid_indices.to(attn_mask.dtype).view(batch_size, input_length, output_length)
    padded_indices = valid_indices - nn.functional.pad(valid_indices, [0, 0, 1, 0, 0, 0])[:, :-1]
    attn = padded_indices.unsqueeze(1).transpose(2, 3) * attn_mask

    # Expand the prior to frame level, sample latents, and run the flow in reverse.
    prior_means = torch.matmul(attn.squeeze(1), prior_means).transpose(1, 2)
    prior_log_variances = torch.matmul(attn.squeeze(1), prior_log_variances).transpose(1, 2)
    prior_latents = prior_means + torch.randn_like(prior_means) * torch.exp(prior_log_variances) * self.noise_scale
    latents = self.flow(prior_latents, output_padding_mask, speaker_embeddings, reverse=True)

    spectrogram = latents * output_padding_mask
    if is_streaming:
        for i in range(0, spectrogram.size(-1), chunk_size):
            with torch.no_grad():
                yield spectrogram[:, :, i : i + chunk_size]
    else:
        yield spectrogram
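
# --- Illustrative sketch (not used by the app below) ---------------------------
# The generator above yields spectrogram chunks, but the app decodes them in a
# thread pool and only returns the full waveform at the end. A minimal sketch of
# true chunk-by-chunk streaming, assuming each chunk can be decoded
# independently; `stream_waveform` is a hypothetical helper added here for
# illustration only.
def stream_waveform(model, tokenizer, text, chunk_size=32):
    inputs = tokenizer(text, return_tensors="pt").to(device)
    with torch.no_grad():
        for spec_chunk in _inference_forward_stream(
            model,
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            speaker_embeddings=None,
            chunk_size=chunk_size,
            is_streaming=True,
        ):
            # Decode each spectrogram chunk to audio as soon as it is produced,
            # e.g. to feed a gr.Audio(streaming=True) output.
            wav = model.decoder(spec_chunk, None)
            yield wav.squeeze().cpu().numpy()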

def get_tokenizer(name_model):
    # The English model ships its own tokenizer; all Arabic models share the
    # "wasmdashai/vtk" tokenizer.
    tokenizer_id = "wasmdashai/vits-en-v1" if name_model == "wasmdashai/vits-en-v1" else "wasmdashai/vtk"
    return AutoTokenizer.from_pretrained(tokenizer_id, token=token)


def get_model(name_model):
    global models
    if name_model not in models:
        # Move the model to the selected device; the tokenized inputs are
        # moved there too in modelspeech().
        model = VitsModel.from_pretrained(name_model, token=token).to(device)
        # Re-apply weight norm so the decoder/flow parametrization matches
        # the one the checkpoints were saved with.
        model.decoder.apply_weight_norm()
        # torch.nn.utils.weight_norm(self.decoder.conv_pre)
        # torch.nn.utils.weight_norm(self.decoder.conv_post)
        for flow in model.flow.flows:
            torch.nn.utils.weight_norm(flow.conv_pre)
            torch.nn.utils.weight_norm(flow.conv_post)
        models[name_model] = model
    return models[name_model], get_tokenizer(name_model)


# Default demo text: an Arabic greeting (kept in Arabic, since the models
# synthesize Arabic speech).
TXT = """السلام عليكم ورحمة الله وبركاته يا هلا وسهلا ومراحب بالغالي اخباركم
طيبين ان شاء الله
ارحبوا على العين والراس"""


def process_chunk(chunk_id, spectrogram_chunk, speaker_embeddings, decoder):
    # Decode one spectrogram chunk to a waveform and write it to its own file.
    with torch.no_grad():
        wav = decoder(spectrogram_chunk, speaker_embeddings)
        wav = wav.squeeze().cpu().numpy()
    file_path = f"audio_chunks/chunk_{chunk_id}.wav"
    sf.write(file_path, wav, samplerate=16000)
    return file_path


def modelspeech(text=TXT, name_model="wasmdashai/vits-ar-sa-huba-v2", speaking_rate=0.9):
    os.makedirs("audio_chunks", exist_ok=True)
    model, tokenizer = get_model(name_model)
    model.config.sampling_rate = 16000
    # text = ask_ai(text)
    inputs = tokenizer(text, return_tensors="pt").to(device)
    model.speaking_rate = speaking_rate

    # Stream spectrogram chunks out of the model and decode them in parallel.
    chunk_files = []
    with ThreadPoolExecutor(max_workers=8) as executor:
        futures = []
        chunk_id = 0
        for spectrogram_chunk in _inference_forward_stream(
            model,
            input_ids=inputs.input_ids,
            attention_mask=inputs.attention_mask,
            speaker_embeddings=None,
            is_streaming=True,
            chunk_size=32,
        ):
            futures.append(executor.submit(process_chunk, chunk_id, spectrogram_chunk, None, model.decoder))
            chunk_id += 1
        for future in as_completed(futures):
            chunk_files.append(future.result())

    # Reassemble the chunks in order and denoise the concatenated waveform.
    chunk_files.sort(key=lambda x: int(x.split("_")[-1].split(".")[0]))
    all_audio = np.concatenate([sf.read(f)[0] for f in chunk_files])
    return (model.config.sampling_rate, remove_noise_nr(all_audio))


model_choices = gr.Dropdown(
    choices=[
        "wasmdashai/vits-ar-sa-huba-v1",
        "wasmdashai/vits-ar-sa-huba-v2",
        "wasmdashai/vits-ar-sa-A",
        "wasmdashai/vits-ar-ye-sa",
        "wasmdashai/vits-ar-sa-M-v2",
        "wasmdashai/model-dash-fahd",
        "wasmdashai/vits-ar-huba-fine",
        "wasmdashai/vits-en-v1",
    ],
    label="Choose a model",
    value="wasmdashai/vits-ar-sa-huba-v2",
)

demo = gr.Interface(
    fn=modelspeech,
    inputs=["text", model_choices, gr.Slider(0.1, 1, step=0.1, value=0.8, label="Speaking rate")],
    outputs=[gr.Audio(autoplay=True)],
)
demo.queue()
demo.launch(debug=True)
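
# --- Illustrative usage note (assumption: running headless, without the UI) ----
# modelspeech() returns (sampling_rate, waveform), so it can also be called
# directly, e.g.:
#
#     sr, audio = modelspeech("السلام عليكم", "wasmdashai/vits-ar-sa-huba-v2", 0.9)
#     sf.write("output.wav", audio, samplerate=sr)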