Spaces: Build error
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
import logging
import torch
import os
from TTS.api import TTS
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
from langdetect import detect

# Allowlist XttsConfig so torch.load doesn't raise UnpicklingError
from torch.serialization import add_safe_globals
from TTS.tts.configs.xtts_config import XttsConfig

add_safe_globals([XttsConfig])

# ✅ Monkey-patch torch.load to always use weights_only=False
# (recent PyTorch versions default to weights_only=True, which breaks loading XTTS checkpoints)
_original_torch_load = torch.load

def patched_torch_load(*args, **kwargs):
    kwargs["weights_only"] = False
    return _original_torch_load(*args, **kwargs)

torch.load = patched_torch_load
logging.basicConfig(level=logging.DEBUG)

# Initialize FastAPI
app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# Load the XTTS v2 model from local files
try:
    model_dir = "/app/models/xtts_v2"
    config_path = os.path.join(model_dir, "config.json")
    # When providing config_path, TTS might expect the directory for model_path
    tts = TTS(model_path=model_dir, config_path=config_path).to(
        "cuda" if torch.cuda.is_available() else "cpu"
    )
    print("XTTS v2 model loaded successfully from local files.")
except Exception as e:
    print(f"Error loading XTTS v2 model from local files: {e}")
    print("Falling back to loading by model name (license acceptance might be required).")
    tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(
        "cuda" if torch.cuda.is_available() else "cpu"
    )
# Load sentiment models
# Note: the tokenizer below comes from a different checkpoint (AraBERT) than the
# classifier (MARBERT); normally the tokenizer should match the classification model.
arabic_model_name = "aubmindlab/bert-base-arabertv02-twitter"
sentiment_tokenizer = AutoTokenizer.from_pretrained(arabic_model_name)
sentiment_model = AutoModelForSequenceClassification.from_pretrained("UBC-NLP/MARBERT")
sentiment_analyzer = pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english")
# Input class for the POST body
class Message(BaseModel):
    text: str

# Language detection
def detect_language_safely(text):
    try:
        # Treat any text containing Arabic-block characters as Arabic
        if any('\u0600' <= c <= '\u06FF' for c in text):
            return "ar"
        return detect(text)
    except Exception:
        return "ar" if any('\u0600' <= c <= '\u06FF' for c in text) else "en"

# Sentiment to emotion mapping
def map_sentiment_to_emotion(sentiment, language="en"):
    if language == "ar":
        return "happy" if sentiment == "positive" else "sad" if sentiment == "negative" else "neutral"
    return "happy" if "positive" in sentiment.lower() else "sad" if "negative" in sentiment.lower() else "neutral"
# Simple lexicon-based Arabic sentiment analysis, with a model fallback on ties
def arabic_sentiment_analysis(text):
    pos_words = ["سعيد", "فرح", "ممتاز", "رائع", "جيد", "حب", "جميل", "نجاح", "أحسنت", "شكرا"]
    neg_words = ["حزين", "غاضب", "سيء", "فشل", "خطأ", "مشكلة", "صعب", "لا أحب", "سخيف", "مؤسف"]
    pos_count = sum(1 for word in pos_words if word in text)
    neg_count = sum(1 for word in neg_words if word in text)
    if pos_count > neg_count:
        return "positive"
    elif neg_count > pos_count:
        return "negative"
    else:
        try:
            inputs = sentiment_tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
            outputs = sentiment_model(**inputs)
            sentiment_class = torch.argmax(outputs.logits).item()
            return ["negative", "neutral", "positive"][sentiment_class]
        except Exception:
            return "neutral"
# Main TTS endpoint
# NOTE: the "/tts" route path is an assumption; the original snippet omitted the decorator.
@app.post("/tts")
def text_to_speech(msg: Message):
    text = msg.text
    language = detect_language_safely(text)
    emotion = "neutral"

    if language == "en":
        try:
            sentiment_result = sentiment_analyzer(text)[0]
            emotion = map_sentiment_to_emotion(sentiment_result["label"])
        except Exception:
            pass
    else:
        try:
            sentiment_result = arabic_sentiment_analysis(text)
            emotion = map_sentiment_to_emotion(sentiment_result, language="ar")
        except Exception:
            pass

    output_filename = "output.wav"
    try:
        tts.tts_to_file(
            text=text,
            file_path=output_filename,
            emotion=emotion,
            speaker_wav="/app/audio/speaker_reference.wav",  # Updated path
            language=language
        )
        return {
            "status": "success",
            "audio_file": output_filename,
            "url": "/audio"
        }
    except Exception as e:
        return {"status": "error", "message": str(e)}
# ✅ Serve the generated audio file
@app.get("/audio")
def get_audio():
    return FileResponse("output.wav", media_type="audio/wav", filename="output.wav")

# Serve static files (the web page) from the 'web' directory
# Mounted last so it does not shadow the API routes defined above
app.mount("/", StaticFiles(directory="web", html=True), name="static")
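
For reference, here is a minimal client sketch showing how the endpoints above could be exercised once the Space builds. The base URL and the "/tts" path are assumptions made for this example (only "/audio" appears in the original response payload); adjust them to match your deployment.

# Hypothetical client for the API above; assumes the app is reachable on port 7860
import requests

BASE_URL = "http://localhost:7860"  # assumed base URL, not part of the original code

# Request speech synthesis for an Arabic sentence
resp = requests.post(f"{BASE_URL}/tts", json={"text": "أحسنت، هذا عمل رائع"})
resp.raise_for_status()
payload = resp.json()
print(payload)  # e.g. {"status": "success", "audio_file": "output.wav", "url": "/audio"}

# Download the generated audio if synthesis succeeded
if payload.get("status") == "success":
    audio = requests.get(f"{BASE_URL}{payload['url']}")
    audio.raise_for_status()
    with open("result.wav", "wb") as f:
        f.write(audio.content)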