from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from pydantic import BaseModel
import logging
import os
import torch
from TTS.api import TTS
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
from langdetect import detect
from huggingface_hub import hf_hub_download

# Allowlist XttsConfig so torch.load doesn't raise UnpicklingError
from torch.serialization import add_safe_globals
from TTS.tts.configs.xtts_config import XttsConfig

add_safe_globals([XttsConfig])

# ✅ Monkey-patch torch.load to always use weights_only=False
_original_torch_load = torch.load

def patched_torch_load(*args, **kwargs):
    kwargs["weights_only"] = False
    return _original_torch_load(*args, **kwargs)

torch.load = patched_torch_load

logging.basicConfig(level=logging.DEBUG)

# Initialize FastAPI
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# ✅ Dynamically download XTTS v2 model files from Hugging Face
try:
    print("Downloading XTTS v2 model files from Hugging Face...")
    model_path = hf_hub_download("coqui/XTTS-v2", "model.pth")
    config_path = hf_hub_download("coqui/XTTS-v2", "config.json")
    vocab_path = hf_hub_download("coqui/XTTS-v2", "vocab.json")
    dvae_path = hf_hub_download("coqui/XTTS-v2", "dvae.pth")
    speakers_path = hf_hub_download("coqui/XTTS-v2", "speakers_xtts.pth")

    # All files land in the same snapshot directory of the Hugging Face cache
    model_dir = os.path.dirname(model_path)
    tts = TTS(model_path=model_dir, config_path=config_path).to(
        "cuda" if torch.cuda.is_available() else "cpu"
    )
    print("✅ XTTS v2 model loaded successfully.")
except Exception as e:
    print(f"❌ Failed to load XTTS v2 model: {e}")
    raise RuntimeError("Failed to initialize TTS model.")

# Load sentiment models
# NOTE: the Arabic tokenizer and classifier come from different checkpoints; for
# reliable predictions both should be loaded from the same fine-tuned sentiment model.
arabic_model_name = "aubmindlab/bert-base-arabertv02-twitter"
sentiment_tokenizer = AutoTokenizer.from_pretrained(arabic_model_name)
sentiment_model = AutoModelForSequenceClassification.from_pretrained("UBC-NLP/MARBERT")
sentiment_analyzer = pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english")

# Input model for the POST body
class Message(BaseModel):
    text: str

# Language detection with an Arabic-script shortcut and a safe fallback
def detect_language_safely(text):
    try:
        if any('\u0600' <= c <= '\u06FF' for c in text):
            return "ar"
        return detect(text)
    except Exception:
        return "ar" if any('\u0600' <= c <= '\u06FF' for c in text) else "en"

# Map a sentiment label to a coarse emotion
def map_sentiment_to_emotion(sentiment, language="en"):
    if language == "ar":
        return "happy" if sentiment == "positive" else "sad" if sentiment == "negative" else "neutral"
    return "happy" if "positive" in sentiment.lower() else "sad" if "negative" in sentiment.lower() else "neutral"

# Simple keyword-based Arabic sentiment analysis with a model fallback
def arabic_sentiment_analysis(text):
    pos_words = ["سعيد", "فرح", "ممتاز", "رائع", "جيد", "حب", "جميل", "نجاح", "أحسنت", "شكرا"]
    neg_words = ["حزين", "غاضب", "سيء", "فشل", "خطأ", "مشكلة", "صعب", "لا أحب", "سخيف", "مؤسف"]

    pos_count = sum(1 for word in pos_words if word in text.lower())
    neg_count = sum(1 for word in neg_words if word in text.lower())

    if pos_count > neg_count:
        return "positive"
    elif neg_count > pos_count:
        return "negative"
    else:
        # Tie or no keyword hits: fall back to the transformer classifier
        try:
            inputs = sentiment_tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
            outputs = sentiment_model(**inputs)
            sentiment_class = torch.argmax(outputs.logits).item()
            return ["negative", "neutral", "positive"][sentiment_class]
        except Exception:
            return "neutral"

# Main TTS endpoint
@app.post("/text-to-speech/")
def text_to_speech(msg: Message):
    text = msg.text
    language = detect_language_safely(text)
    emotion = "neutral"

    if language == "en":
        try:
            sentiment_result = sentiment_analyzer(text)[0]
            emotion = map_sentiment_to_emotion(sentiment_result["label"])
        except Exception:
            pass
    else:
        try:
            sentiment_result = arabic_sentiment_analysis(text)
            emotion = map_sentiment_to_emotion(sentiment_result, language="ar")
        except Exception:
            pass

    # NOTE: a single shared output file means concurrent requests overwrite each other
    output_filename = "output.wav"
    try:
        tts.tts_to_file(
            text=text,
            file_path=output_filename,
            emotion=emotion,
            speaker_wav="/app/audio/speaker_reference.wav",
            language=language
        )
        return {
            "status": "success",
            "audio_file": output_filename,
            "url": "/audio"
        }
    except Exception as e:
        return {"status": "error", "message": str(e)}

# ✅ Serve the audio file
@app.get("/audio")
def get_audio():
    return FileResponse("output.wav", media_type="audio/wav", filename="output.wav")
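
# Optional local entry point: a minimal sketch, assuming uvicorn is installed and that
# running this module directly is desired. The host and port below are assumptions,
# not part of the original service (which can also be started via `uvicorn <module>:app`).
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example client call (illustration only, not part of the service). The base URL is an
# assumption; adjust it to wherever the app is deployed.
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/text-to-speech/",
#       json={"text": "مرحبا، كيف حالك؟"},
#   )
#   print(resp.json())  # e.g. {"status": "success", "audio_file": "output.wav", "url": "/audio"}
#
#   audio = requests.get("http://localhost:8000/audio")
#   with open("reply.wav", "wb") as f:
#       f.write(audio.content)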