wasmdashai commited on
Commit
bf679e1
·
verified ·
1 Parent(s): 4142f53

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -108
app.py CHANGED
@@ -1,131 +1,82 @@
1
- from logging import error
2
- import gradio as gr
3
- import spaces
4
- import torch
5
  from transformers import AutoTokenizer, VitsModel
6
- import os
7
  import numpy as np
 
8
  import noisereduce as nr
9
- import torch.nn as nn
10
- from typing import Optional, Iterator
 
11
 
12
  # قراءة التوكن من Secrets
13
- token = os.getenv("acees-token") # تأكد أنك سميته بنفس الاسم في Settings → Repository secrets
14
 
15
- # كائن لتخزين النماذج
16
  models = {}
17
 
18
- # اختيار الجهاز (CUDA لو متوفر، غير كذا CPU)
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
20
 
21
-
22
- # دالة إزالة الضوضاء
23
- def remove_noise_nr(audio_data, sr=16000):
24
  return nr.reduce_noise(y=audio_data, hop_length=256, sr=sr)
25
 
26
-
27
- # دالة inference (streaming / non-streaming)
28
- def _inference_forward_stream(
29
- self,
30
- input_ids: Optional[torch.Tensor] = None,
31
- attention_mask: Optional[torch.Tensor] = None,
32
- speaker_embeddings: Optional[torch.Tensor] = None,
33
- chunk_size: int = 32,
34
- is_streaming: bool = True
35
- ) -> Iterator[torch.Tensor]:
36
-
37
- padding_mask = attention_mask.unsqueeze(-1).float() if attention_mask is not None else torch.ones_like(input_ids).unsqueeze(-1).float()
38
- text_encoder_output = self.text_encoder(input_ids=input_ids, padding_mask=padding_mask, attention_mask=attention_mask)
39
- hidden_states = text_encoder_output[0].transpose(1, 2)
40
- input_padding_mask = padding_mask.transpose(1, 2)
41
-
42
- log_duration = self.duration_predictor(hidden_states, input_padding_mask, speaker_embeddings)
43
- length_scale = 1.0 / self.speaking_rate
44
- duration = torch.ceil(torch.exp(log_duration) * input_padding_mask * length_scale)
45
- predicted_lengths = torch.clamp_min(torch.sum(duration, [1,2]), 1).long()
46
-
47
- indices = torch.arange(predicted_lengths.max(), device=predicted_lengths.device)
48
- output_padding_mask = indices.unsqueeze(0) < predicted_lengths.unsqueeze(1)
49
- output_padding_mask = output_padding_mask.unsqueeze(1).to(input_padding_mask.dtype)
50
-
51
- attn_mask = torch.unsqueeze(input_padding_mask, 2) * torch.unsqueeze(output_padding_mask, -1)
52
- batch_size, _, output_length, input_length = attn_mask.shape
53
- cum_duration = torch.cumsum(duration, -1).view(batch_size * input_length, 1)
54
- indices = torch.arange(output_length, device=duration.device)
55
- valid_indices = indices.unsqueeze(0) < cum_duration
56
- valid_indices = valid_indices.to(attn_mask.dtype).view(batch_size, input_length, output_length)
57
- padded_indices = valid_indices - nn.functional.pad(valid_indices, [0,0,1,0,0,0])[:, :-1]
58
- attn = padded_indices.unsqueeze(1).transpose(2,3) * attn_mask
59
-
60
- prior_means = text_encoder_output[1]
61
- prior_log_variances = text_encoder_output[2]
62
- prior_latents = prior_means + torch.randn_like(prior_means) * torch.exp(prior_log_variances) * self.noise_scale
63
- latents = self.flow(prior_latents, output_padding_mask, speaker_embeddings, reverse=True)
64
- spectrogram = latents * output_padding_mask
65
-
66
- if is_streaming:
67
- for i in range(0, spectrogram.size(-1), chunk_size):
68
- with torch.no_grad():
69
- wav = self.decoder(spectrogram[:,:,i:i+chunk_size], speaker_embeddings)
70
- yield wav.squeeze().cpu().numpy()
71
- else:
72
- with torch.no_grad():
73
- wav = self.decoder(spectrogram, speaker_embeddings)
74
- yield wav.squeeze().cpu().numpy()
75
-
76
-
77
- # تحميل النموذج + التوكن
78
  def get_model(name_model):
79
  global models
80
  if name_model in models:
81
  tokenizer = AutoTokenizer.from_pretrained(name_model, token=token)
82
  return models[name_model], tokenizer
83
 
84
- models[name_model] = VitsModel.from_pretrained(name_model, token=token)
85
- models[name_model].decoder.apply_weight_norm()
86
- for flow in models[name_model].flow.flows:
87
  torch.nn.utils.weight_norm(flow.conv_pre)
88
  torch.nn.utils.weight_norm(flow.conv_post)
 
 
89
 
90
  tokenizer = AutoTokenizer.from_pretrained(name_model, token=token)
91
- return models[name_model], tokenizer
92
-
93
-
94
- # النص الافتراضي
95
- TXT = "السلام عليكم ورحمة الله وبركاته يا هلا وسهلا ومراحب بالغالي"
96
-
97
-
98
- # دالة تحويل النص إلى كلام
99
- def modelspeech(text=TXT, name_model="wasmdashai/vits-ar-sa-huba-v2", speaking_rate=16000):
100
- model, tokenizer = get_model(name_model)
101
- inputs = tokenizer(text, return_tensors="pt").to(device) # يشتغل على CPU أو GPU حسب المت��فر
102
- model.speaking_rate = speaking_rate
103
- with torch.no_grad():
104
- outputs = model(**inputs)
105
- waveform = outputs.waveform[0].cpu().numpy()
106
- return model.config.sampling_rate, remove_noise_nr(waveform)
107
-
108
-
109
- # واجهة Gradio
110
- model_choices = gr.Dropdown(
111
- choices=[
112
- "wasmdashai/vits-ar-sa-huba-v1",
113
- "wasmdashai/vits-ar-sa-huba-v2",
114
- "wasmdashai/vits-ar-sa-A",
115
- "wasmdashai/vits-ar-ye-sa",
116
- "wasmdashai/vits-ar-sa-M-v1",
117
- "wasmdashai/vits-en-v1"
118
- ],
119
- label="اختر النموذج",
120
- value="wasmdashai/vits-ar-sa-huba-v2"
121
- )
122
-
123
- demo = gr.Interface(
124
- fn=modelspeech,
125
- inputs=["text", model_choices, gr.Slider(0.1, 1, step=0.1, value=0.8)],
126
- outputs=["audio"]
127
- )
128
-
129
- demo.queue()
130
- demo.launch(server_name="0.0.0.0", server_port=7860)
131
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
 
 
3
  from transformers import AutoTokenizer, VitsModel
4
+ import torch
5
  import numpy as np
6
+ import os
7
  import noisereduce as nr
8
+ import base64
9
+ import io
10
+ import soundfile as sf
11
 
12
  # قراءة التوكن من Secrets
13
+ token = os.getenv("acees-token")
14
 
15
+ # تخزين النماذج
16
  models = {}
17
 
18
+ # اختيار الجهاز
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
20
 
21
+ # إزالة الضوضاء
22
+ def remove_noise(audio_data, sr=16000):
 
23
  return nr.reduce_noise(y=audio_data, hop_length=256, sr=sr)
24
 
25
+ # تحميل النموذج
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  def get_model(name_model):
27
  global models
28
  if name_model in models:
29
  tokenizer = AutoTokenizer.from_pretrained(name_model, token=token)
30
  return models[name_model], tokenizer
31
 
32
+ model = VitsModel.from_pretrained(name_model, token=token)
33
+ model.decoder.apply_weight_norm()
34
+ for flow in model.flow.flows:
35
  torch.nn.utils.weight_norm(flow.conv_pre)
36
  torch.nn.utils.weight_norm(flow.conv_post)
37
+ model.to(device)
38
+ models[name_model] = model
39
 
40
  tokenizer = AutoTokenizer.from_pretrained(name_model, token=token)
41
+ return model, tokenizer
42
+
43
+ # نموذج البيانات للـ POST
44
+ class TTSRequest(BaseModel):
45
+ text: str
46
+ name_model: str = "wasmdashai/vits-ar-sa-huba-v2"
47
+ speaking_rate: float = 16000.0
48
+
49
+ # إنشاء التطبيق
50
+ app = FastAPI(title="VITS TTS API", description="Convert Arabic/English text to speech using VITS models")
51
+
52
+ @app.get("/", summary="Health check")
53
+ def home():
54
+ return {"message": "FastAPI VITS TTS service is running"}
55
+
56
+ @app.post("/predict/", summary="Text-to-Speech", description="Convert text to audio (WAV, Base64)")
57
+ def modelspeech(req: TTSRequest):
58
+ try:
59
+ model, tokenizer = get_model(req.name_model)
60
+ inputs = tokenizer(req.text, return_tensors="pt").to(device)
61
+ model.speaking_rate = req.speaking_rate
62
+
63
+ with torch.no_grad():
64
+ outputs = model(**inputs)
65
+ waveform = outputs.waveform[0].cpu().numpy()
66
+
67
+ # إزالة الضوضاء
68
+ waveform = remove_noise(waveform)
69
+
70
+ # تحويل الصوت إلى Base64 WAV
71
+ buffer = io.BytesIO()
72
+ sf.write(buffer, waveform, samplerate=model.config.sampling_rate, format="WAV")
73
+ buffer.seek(0)
74
+ audio_base64 = base64.b64encode(buffer.read()).decode("utf-8")
75
+
76
+ return {
77
+ "sampling_rate": model.config.sampling_rate,
78
+ "audio_base64": audio_base64
79
+ }
 
80
 
81
+ except Exception as e:
82
+ raise HTTPException(status_code=500, detail=str(e))