# TTSS / local_server_new.py
# Moustafa1111111111
# Cleaned up: switched to runtime model download with huggingface_hub
# 4390e63
import functools
import logging
import os

import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from huggingface_hub import hf_hub_download
from langdetect import detect
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
from TTS.api import TTS
# Allowlist XttsConfig so torch.load doesn't raise UnpicklingError.
# add_safe_globals registers classes that torch's restricted unpickler
# (used when torch.load runs with weights_only=True) is allowed to rebuild.
# Presumably the XTTS checkpoint pickles an XttsConfig instance — TODO confirm.
# NOTE(review): this only matters for torch.load calls that actually run with
# weights_only=True; calls using weights_only=False never consult the allowlist.
from torch.serialization import add_safe_globals
from TTS.tts.configs.xtts_config import XttsConfig
add_safe_globals([XttsConfig])
# Monkey-patch torch.load so calls that do not specify weights_only use
# weights_only=False (torch >= 2.6 defaults it to True, which rejects the
# pickled objects in the XTTS checkpoint).
# SECURITY NOTE(review): weights_only=False unpickles arbitrary objects and can
# therefore execute code embedded in a checkpoint. It is only acceptable for
# trusted files. Callers that explicitly ask for weights_only=True (e.g.
# transformers) keep the safe loader.
_original_torch_load = torch.load


@functools.wraps(_original_torch_load)
def patched_torch_load(*args, **kwargs):
    """Call the original ``torch.load``, defaulting ``weights_only`` to False.

    An explicit ``weights_only=...`` passed by the caller is respected instead
    of being silently overwritten.
    """
    kwargs.setdefault("weights_only", False)
    return _original_torch_load(*args, **kwargs)


torch.load = patched_torch_load
# Root logger at DEBUG level for the whole process.
logging.basicConfig(level=logging.DEBUG)

# FastAPI application. CORS is fully open: any origin, method and header.
app = FastAPI()
_cors_options = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
# Download the XTTS v2 model files from the Hugging Face Hub at import time and
# build the TTS engine on GPU when available, otherwise on CPU.
# NOTE(review): no `revision` is pinned, so upstream changes to the repo are
# picked up automatically. Consider pinning a commit hash, because the
# checkpoint is unpickled (torch.load is patched in this module to default
# weights_only=False).
XTTS_REPO_ID = "coqui/XTTS-v2"
try:
    logging.info("Downloading XTTS v2 model files from Hugging Face...")
    model_path = hf_hub_download(XTTS_REPO_ID, "model.pth")
    config_path = hf_hub_download(XTTS_REPO_ID, "config.json")
    # These files are not referenced below. They are fetched so they are present
    # next to model.pth in the local cache directory (model_dir), presumably
    # where TTS looks for them — TODO confirm.
    vocab_path = hf_hub_download(XTTS_REPO_ID, "vocab.json")
    dvae_path = hf_hub_download(XTTS_REPO_ID, "dvae.pth")
    speakers_path = hf_hub_download(XTTS_REPO_ID, "speakers_xtts.pth")
    model_dir = os.path.dirname(model_path)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tts = TTS(model_path=model_dir, config_path=config_path).to(device)
    logging.info("✅ XTTS v2 model loaded successfully on %s.", device)
except Exception as e:
    # Startup boundary: log the full traceback, then fail loudly with the cause chained.
    logging.exception("❌ Failed to load XTTS v2 model: %s", e)
    raise RuntimeError("Failed to initialize TTS model.") from e
# Load sentiment models.
# The Arabic tokenizer and classifier must come from the same checkpoint;
# otherwise token ids index into a different model's embedding table.
arabic_model_name = "aubmindlab/bert-base-arabertv02-twitter"
sentiment_tokenizer = AutoTokenizer.from_pretrained(arabic_model_name)
# num_labels=3 matches the ["negative", "neutral", "positive"] label list that
# arabic_sentiment_analysis indexes with the argmax of the logits.
# NOTE(review): this is a pretrained base checkpoint, not a sentiment-fine-tuned
# one, so the classification head is freshly initialized. Its predictions are
# not meaningful until a fine-tuned Arabic sentiment model is used here.
sentiment_model = AutoModelForSequenceClassification.from_pretrained(arabic_model_name, num_labels=3)
# English sentiment: SST-2 fine-tuned DistilBERT (labels POSITIVE / NEGATIVE).
sentiment_analyzer = pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english")
# Input class for POST body
class Message(BaseModel):
    """Request body for the POST /text-to-speech/ endpoint."""

    text: str  # text to synthesize; its language is auto-detected
# Language detection
# Language codes the XTTS v2 checkpoint can synthesize (per the coqui/XTTS-v2
# model card). langdetect can return many other codes, often spuriously for
# short inputs, and XTTS rejects those at synthesis time.
XTTS_SUPPORTED_LANGUAGES = frozenset({
    "en", "es", "fr", "de", "it", "pt", "pl", "tr", "ru",
    "nl", "cs", "ar", "zh-cn", "ja", "hu", "ko", "hi",
})


def detect_language_safely(text):
    """Return a language code for *text* that XTTS v2 can synthesize.

    Any character in the Arabic Unicode block (U+0600-U+06FF) forces "ar".
    Otherwise langdetect decides. Failures and unsupported codes fall back
    to "en".
    """
    # Script check first: cheap, and reliable where statistical detection
    # struggles (short or mixed Arabic/Latin input).
    if any("\u0600" <= ch <= "\u06FF" for ch in text):
        return "ar"
    try:
        language = detect(text)
    except Exception:
        # langdetect raises e.g. for input without detectable features (empty,
        # digits/punctuation only).
        return "en"
    return language if language in XTTS_SUPPORTED_LANGUAGES else "en"
# Sentiment to emotion mapping
def map_sentiment_to_emotion(sentiment, language="en"):
    """Translate a sentiment label into an emotion name for the TTS engine.

    For ``language == "ar"`` the label must equal "positive" or "negative"
    exactly. For any other language, the label is lower-cased and searched
    for those words as substrings (e.g. "POSITIVE" from the English pipeline).
    Returns "happy", "sad" or "neutral".
    """
    if language == "ar":
        if sentiment == "positive":
            return "happy"
        if sentiment == "negative":
            return "sad"
        return "neutral"

    label = sentiment.lower()
    for keyword, emotion in (("positive", "happy"), ("negative", "sad")):
        if keyword in label:
            return emotion
    return "neutral"
# Simple Arabic sentiment analysis
# Keyword lists for the rule-based first pass (matched as substrings).
ARABIC_POSITIVE_WORDS = ("سعيد", "فرح", "ممتاز", "رائع", "جيد", "حب", "جميل", "نجاح", "أحسنت", "شكرا")
ARABIC_NEGATIVE_WORDS = ("حزين", "غاضب", "سيء", "فشل", "خطأ", "مشكلة", "صعب", "لا أحب", "سخيف", "مؤسف")
# Label order the Arabic classifier's argmax index is mapped through.
ARABIC_SENTIMENT_LABELS = ("negative", "neutral", "positive")


def arabic_sentiment_analysis(text):
    """Classify Arabic *text* as "positive", "negative" or "neutral".

    Keyword counts decide first. On a tie (including no matches at all) the
    transformer classifier is consulted. Any failure there is logged and
    yields "neutral".
    """
    lowered = text.lower()
    pos_count = sum(word in lowered for word in ARABIC_POSITIVE_WORDS)
    neg_count = sum(word in lowered for word in ARABIC_NEGATIVE_WORDS)
    if pos_count > neg_count:
        return "positive"
    if neg_count > pos_count:
        return "negative"

    # NOTE(review): substring matching means short stems such as "حب" also hit
    # inside other words (including "لا أحب"). Kept as-is because Arabic
    # prefixes make whole-word matching lossy too.
    try:
        inputs = sentiment_tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
        # Inference only: skip building an autograd graph on every request.
        with torch.no_grad():
            logits = sentiment_model(**inputs).logits
        return ARABIC_SENTIMENT_LABELS[torch.argmax(logits).item()]
    except Exception:
        # Best effort: sentiment only tunes the voice, so never fail the request.
        logging.exception("Arabic sentiment model failed; defaulting to neutral")
        return "neutral"
# Main TTS endpoint
@app.post("/text-to-speech/")
def text_to_speech(msg: Message):
    """Synthesize ``msg.text`` into "output.wav" and say where to fetch it.

    The language is detected from the text, and a sentiment-derived emotion is
    passed to the TTS engine. Returns a dict with "status": "success" plus
    "audio_file" and "url", or "status": "error" plus "message". Both cases
    use the default HTTP 200 response.
    """
    text = msg.text
    if not text.strip():
        return {"status": "error", "message": "Text must not be empty."}

    language = detect_language_safely(text)
    emotion = _infer_emotion(text, language)

    output_filename = "output.wav"
    # NOTE(review): FastAPI runs sync endpoints in a thread pool, so concurrent
    # requests all write this same file and /audio serves whichever finished last.
    try:
        tts.tts_to_file(
            text=text,
            file_path=output_filename,
            emotion=emotion,
            speaker_wav="/app/audio/speaker_reference.wav",
            language=language,
        )
    except Exception as e:
        logging.exception("TTS synthesis failed")
        return {"status": "error", "message": str(e)}
    return {
        "status": "success",
        "audio_file": output_filename,
        "url": "/audio",
    }


def _infer_emotion(text, language):
    """Return a best-effort emotion ("happy", "sad" or "neutral") for *text*."""
    try:
        if language == "en":
            label = sentiment_analyzer(text)[0]["label"]
            return map_sentiment_to_emotion(label)
        if language == "ar":
            return map_sentiment_to_emotion(arabic_sentiment_analysis(text), language="ar")
    except Exception:
        # Sentiment only tunes the voice; fall back to neutral rather than fail.
        logging.exception("Sentiment analysis failed; using neutral emotion")
    # Other languages have no sentiment model here.
    return "neutral"
# ✅ Serve the audio file
@app.get("/audio")
def get_audio():
    """Serve the most recently synthesized "output.wav".

    Raises:
        HTTPException: 404 when no audio has been generated yet. Without this
            check, FileResponse fails on the missing file and yields a 500.
    """
    if not os.path.isfile("output.wav"):
        raise HTTPException(status_code=404, detail="No audio has been generated yet.")
    return FileResponse("output.wav", media_type="audio/wav", filename="output.wav")