from logging import error
import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer, VitsModel
import os
import numpy as np
import noisereduce as nr
import torch.nn as nn
from typing import Optional, Iterator
# Read the access token from Secrets
token = os.getenv("acees-token")  # make sure the secret uses this exact name under Settings → Repository secrets
# Cache of loaded models
models = {}
# Pick the device (CUDA if available, otherwise CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
# Noise-reduction helper
def remove_noise_nr(audio_data, sr=16000):
    return nr.reduce_noise(y=audio_data, hop_length=256, sr=sr)
# Inference helper (streaming / non-streaming); called as a plain function with the model passed as `self`
def _inference_forward_stream(
    self,
    input_ids: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    speaker_embeddings: Optional[torch.Tensor] = None,
    chunk_size: int = 32,
    is_streaming: bool = True,
) -> Iterator[torch.Tensor]:
    padding_mask = (
        attention_mask.unsqueeze(-1).float()
        if attention_mask is not None
        else torch.ones_like(input_ids).unsqueeze(-1).float()
    )

    # Encode the text and predict a duration for every input token
    text_encoder_output = self.text_encoder(input_ids=input_ids, padding_mask=padding_mask, attention_mask=attention_mask)
    hidden_states = text_encoder_output[0].transpose(1, 2)
    input_padding_mask = padding_mask.transpose(1, 2)
    log_duration = self.duration_predictor(hidden_states, input_padding_mask, speaker_embeddings)
    length_scale = 1.0 / self.speaking_rate
    duration = torch.ceil(torch.exp(log_duration) * input_padding_mask * length_scale)
    predicted_lengths = torch.clamp_min(torch.sum(duration, [1, 2]), 1).long()

    # Build the monotonic alignment between input tokens and output frames
    indices = torch.arange(predicted_lengths.max(), device=predicted_lengths.device)
    output_padding_mask = indices.unsqueeze(0) < predicted_lengths.unsqueeze(1)
    output_padding_mask = output_padding_mask.unsqueeze(1).to(input_padding_mask.dtype)
    attn_mask = torch.unsqueeze(input_padding_mask, 2) * torch.unsqueeze(output_padding_mask, -1)
    batch_size, _, output_length, input_length = attn_mask.shape
    cum_duration = torch.cumsum(duration, -1).view(batch_size * input_length, 1)
    indices = torch.arange(output_length, device=duration.device)
    valid_indices = indices.unsqueeze(0) < cum_duration
    valid_indices = valid_indices.to(attn_mask.dtype).view(batch_size, input_length, output_length)
    padded_indices = valid_indices - nn.functional.pad(valid_indices, [0, 0, 1, 0, 0, 0])[:, :-1]
    attn = padded_indices.unsqueeze(1).transpose(2, 3) * attn_mask
    prior_means = text_encoder_output[1]
    prior_log_variances = text_encoder_output[2]
    # Expand the prior over the predicted alignment (mirrors VitsModel.forward), then
    # run the flow in reverse to obtain the masked spectrogram latents
    prior_means = torch.matmul(attn.squeeze(1), prior_means).transpose(1, 2)
    prior_log_variances = torch.matmul(attn.squeeze(1), prior_log_variances).transpose(1, 2)
    prior_latents = prior_means + torch.randn_like(prior_means) * torch.exp(prior_log_variances) * self.noise_scale
    latents = self.flow(prior_latents, output_padding_mask, speaker_embeddings, reverse=True)
    spectrogram = latents * output_padding_mask
    if is_streaming:
        # Decode the spectrogram chunk by chunk and yield audio incrementally
        for i in range(0, spectrogram.size(-1), chunk_size):
            with torch.no_grad():
                wav = self.decoder(spectrogram[:, :, i:i + chunk_size], speaker_embeddings)
            yield wav.squeeze().cpu().numpy()
    else:
        with torch.no_grad():
            wav = self.decoder(spectrogram, speaker_embeddings)
        yield wav.squeeze().cpu().numpy()
# Load the model + tokenizer (models are cached after the first load)
def get_model(name_model):
    global models
    if name_model in models:
        tokenizer = AutoTokenizer.from_pretrained(name_model, token=token)
        return models[name_model], tokenizer
    models[name_model] = VitsModel.from_pretrained(name_model, token=token).to(device)
    # Re-apply weight norm to the decoder and flow convolutions before inference
    models[name_model].decoder.apply_weight_norm()
    for flow in models[name_model].flow.flows:
        torch.nn.utils.weight_norm(flow.conv_pre)
        torch.nn.utils.weight_norm(flow.conv_post)
    tokenizer = AutoTokenizer.from_pretrained(name_model, token=token)
    return models[name_model], tokenizer
# Default input text
TXT = "السلام عليكم ورحمة الله وبركاته يا هلا وسهلا ومراحب بالغالي"
# Text-to-speech function (speaking_rate is a length-scale factor, not a sample rate)
def modelspeech(text=TXT, name_model="wasmdashai/vits-ar-sa-huba-v2", speaking_rate=0.8):
    model, tokenizer = get_model(name_model)
    inputs = tokenizer(text, return_tensors="pt").to(device)  # runs on CPU or GPU, whichever is available
    model.speaking_rate = speaking_rate
    with torch.no_grad():
        outputs = model(**inputs)
    waveform = outputs.waveform[0].cpu().numpy()
    return model.config.sampling_rate, remove_noise_nr(waveform)
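# A minimal sketch of how the streaming path in _inference_forward_stream could be
# consumed, assuming speaker_embeddings=None and chunk_size=64 (illustrative values,
# not taken from the original app). This helper is not wired into the Gradio UI below.
def modelspeech_stream(text=TXT, name_model="wasmdashai/vits-ar-sa-huba-v2", speaking_rate=0.8):
    model, tokenizer = get_model(name_model)
    inputs = tokenizer(text, return_tensors="pt").to(device)
    model.speaking_rate = speaking_rate
    # _inference_forward_stream is a plain function, so the model is passed explicitly as `self`
    for chunk in _inference_forward_stream(
        model,
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        speaker_embeddings=None,
        chunk_size=64,
        is_streaming=True,
    ):
        yield model.config.sampling_rate, remove_noise_nr(chunk)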
# Gradio UI
model_choices = gr.Dropdown(
    choices=[
        "wasmdashai/vits-ar-sa-huba-v1",
        "wasmdashai/vits-ar-sa-huba-v2",
        "wasmdashai/vits-ar-sa-A",
        "wasmdashai/vits-ar-ye-sa",
        "wasmdashai/vits-ar-sa-M-v1",
        "wasmdashai/vits-en-v1",
    ],
    label="اختر النموذج",
    value="wasmdashai/vits-ar-sa-huba-v2",
)
demo = gr.Interface(
    fn=modelspeech,
    inputs=["text", model_choices, gr.Slider(0.1, 1, step=0.1, value=0.8)],
    outputs=["audio"],
)
demo.queue()
demo.launch(server_name="0.0.0.0", server_port=7860)