Spaces:
Sleeping
Sleeping
| from transformers import AutoProcessor | |
| from optimum.onnxruntime import ORTModelForSpeechSeq2Seq | |
| import torch | |
| import os | |
| import base64 | |
| import tempfile | |
| from fastapi import FastAPI, Request | |
| from fastapi.responses import JSONResponse | |
| import uvicorn | |
| import deepl | |
| from dotenv import load_dotenv | |
| import soundfile as sf | |
| import logging | |
# --- Basic Configuration ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# --- Load environment variables and initialize DeepL ---
# load_dotenv() merges a local .env file (if any) into os.environ before the
# key is read, so the key may come from either source.
load_dotenv()
DEEPL_AUTH_KEY = os.getenv("DEEPL_AUTH_KEY")

# The DeepL client stays None when no key is configured or construction
# raises; code that uses it tests it for truthiness first.
deepl_translator = None
if not DEEPL_AUTH_KEY:
    logging.warning("DEEPL_AUTH_KEY not found. DeepL will be unavailable.")
else:
    try:
        deepl_translator = deepl.Translator(DEEPL_AUTH_KEY)
    except Exception as exc:
        logging.error(f"Error initializing DeepL translator: {exc}")
    else:
        logging.info("DeepL translator initialized successfully.")
# --- Load Models ---
logging.info("Loading all models...")


def _clear_forced_decoder_ids(config, label):
    """Set ``config.forced_decoder_ids`` to None if the attribute is present.

    ``label`` is used only to name the config in the log message.
    """
    if not hasattr(config, 'forced_decoder_ids'):
        return
    logging.info(f"Found conflicting 'forced_decoder_ids' in {label}. Setting to None.")
    config.forced_decoder_ids = None


# ASR Model: Whisper loaded through optimum's onnxruntime wrapper on the CPU
# execution provider. asr_model / asr_processor each remain None if their own
# load (or an earlier step) raises; transcribe_audio checks for that.
asr_model_id = "openai/whisper-base"
asr_model = None
asr_processor = None
try:
    asr_model = ORTModelForSpeechSeq2Seq.from_pretrained(asr_model_id, provider="CPUExecutionProvider")
    asr_processor = AutoProcessor.from_pretrained(asr_model_id)
    # The checkpoint's default 'forced_decoder_ids' conflicts with passing
    # task=... to generate() on recent library versions. The attribute is kept
    # but set to None (rather than deleted), since the library still expects it
    # to exist. Both the model config and the generation config are reset.
    _clear_forced_decoder_ids(asr_model.config, "model config")
    _clear_forced_decoder_ids(asr_model.generation_config, "generation_config")
    logging.info("ASR model and processor loaded and configured successfully.")
except Exception as e:
    logging.error(f"Fatal error loading ASR model: {e}", exc_info=True)
# Translation Pipelines
from transformers import pipeline

# Hugging Face checkpoint for each "<source>-<target>" pair. The keys are the
# lookup keys translate_text builds as f"{source_lang}-{target_lang}".
_TRANSLATION_MODEL_IDS = {
    "en-zh": "Helsinki-NLP/opus-mt-en-zh",
    "zh-en": "Helsinki-NLP/opus-mt-zh-en",
    "en-ja": "staka/fugumt-en-ja",
    "ja-en": "Helsinki-NLP/opus-mt-ja-en",
    "en-ko": "Helsinki-NLP/opus-mt-tc-big-en-ko",
    "ko-en": "Helsinki-NLP/opus-mt-ko-en",
}


def _load_translators(model_ids):
    """Build a translation pipeline for every pair in ``model_ids``.

    Each model is loaded independently, so one failed download or tokenizer
    error disables only its own pair instead of every pair.

    Args:
        model_ids: mapping of pair key (e.g. "en-zh") -> model name.

    Returns:
        dict of pair key -> pipeline, containing only the pairs that loaded.
    """
    loaded = {}
    for pair, model_name in model_ids.items():
        try:
            loaded[pair] = pipeline("translation", model=model_name)
        except Exception as e:
            logging.error(f"Failed to load translation model '{model_name}' for {pair}: {e}", exc_info=True)
    return loaded


translators = _load_translators(_TRANSLATION_MODEL_IDS)
if len(translators) == len(_TRANSLATION_MODEL_IDS):
    logging.info("Translation models loaded successfully.")
else:
    _unavailable_pairs = sorted(set(_TRANSLATION_MODEL_IDS) - set(translators))
    logging.error(f"Failed to load translation models for: {', '.join(_unavailable_pairs)}")
# --- Core Logic Functions ---
def transcribe_audio(audio_bytes):
    """Transcribe an encoded audio clip with the Whisper ASR model.

    The bytes are written to a temporary file so ``soundfile`` can decode them.
    Multi-channel audio is averaged to mono and resampled to 16 kHz, the rate
    this function hands to the Whisper processor. The result is then decoded
    with ``asr_model.generate``.

    Args:
        audio_bytes: raw bytes of an audio file in a format ``soundfile`` can read.

    Returns:
        ``(text, None)`` on success. ``(None, error_message)`` when the ASR model
        or processor is unavailable, or when any step raises. Exceptions are
        logged, not propagated.
    """
    import numpy as np  # local import: only needed for resampling

    if not asr_model or not asr_processor:
        return None, "ASR model or processor is not available."

    target_sample_rate = 16000  # rate passed to the Whisper processor below
    audio_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
            # Record the path before writing so a failed write is still cleaned up.
            audio_path = tmp_file.name
            tmp_file.write(audio_bytes)
        audio_input, sample_rate = sf.read(audio_path)
        if audio_input.ndim > 1:
            # soundfile returns (frames, channels); average the channels to mono.
            audio_input = audio_input.mean(axis=1)
        if sample_rate != target_sample_rate and audio_input.size:
            # Passing sampling_rate=16000 for audio recorded at another rate
            # would silently distort the features, so resample first.
            # Linear interpolation has no anti-aliasing filter; it is adequate
            # for speech, but a dedicated resampler would be higher quality.
            logging.info(f"Resampling audio from {sample_rate} Hz to {target_sample_rate} Hz")
            n_out = max(1, int(round(audio_input.shape[0] * target_sample_rate / sample_rate)))
            src_times = np.arange(audio_input.shape[0]) / sample_rate
            dst_times = np.arange(n_out) / target_sample_rate
            audio_input = np.interp(dst_times, src_times, audio_input)
        input_features = asr_processor(
            audio_input, sampling_rate=target_sample_rate, return_tensors="pt"
        ).input_features
        # forced_decoder_ids is reset to None on the model configs at load
        # time, so the task can be passed to generate() without a conflict.
        predicted_ids = asr_model.generate(input_features, task="transcribe")
        text = asr_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
        logging.info(f"ASR transcribed text: '{text}'")
        return text, None
    except Exception as e:
        logging.error(f"ASR transcription failed: {e}", exc_info=True)
        return None, str(e)
    finally:
        # Always remove the temp file (it was created with delete=False).
        if audio_path and os.path.exists(audio_path):
            try:
                os.remove(audio_path)
            except OSError as cleanup_error:
                logging.warning(f"Could not remove temporary audio file {audio_path}: {cleanup_error}")
# (source, target) pairs routed to DeepL first when a DeepL client is
# configured, mapped to the language codes passed to the DeepL API.
_DEEPL_LANGUAGE_CODES = {
    ("zh", "ja"): ("ZH", "JA"),
    ("en", "ja"): ("EN", "JA"),
}


def _run_hf_translation(translator, text):
    """Translate ``text`` with a Hugging Face pipeline and return the output string."""
    return translator(text, max_length=512)[0]['translation_text']


def translate_text(text, source_lang, target_lang):
    """Translate ``text`` from ``source_lang`` to ``target_lang``.

    Strategies are tried in order. If one is unavailable or raises, the next
    one is tried:
      1. DeepL, for the pairs in ``_DEEPL_LANGUAGE_CODES``, when
         ``deepl_translator`` is set.
      2. A direct Hugging Face pipeline from ``translators`` keyed "<src>-<tgt>".
      3. A pivot through English ("<src>-en" then "en-<tgt>") when neither side
         is English.

    Returns:
        ``(translated_text, None)`` on success. On failure, ``(None, message)``
        where message is:
          - the last strategy's error, if some strategy was attempted and raised;
          - otherwise "No model available for <src> to <tgt>".
    """
    last_error = None

    # Priority 1: Use DeepL for specific, high-quality pairs if available
    deepl_codes = _DEEPL_LANGUAGE_CODES.get((source_lang, target_lang)) if deepl_translator else None
    if deepl_codes:
        dl_source_lang, dl_target_lang = deepl_codes
        try:
            logging.info(f"Attempting DeepL translation for {source_lang} -> {target_lang}")
            result = deepl_translator.translate_text(text, source_lang=dl_source_lang, target_lang=dl_target_lang)
            return result.text, None
        except Exception as e:
            logging.error(f"DeepL failed: {e}. Falling back to HF models.")
            last_error = f"DeepL translation failed: {e}"

    # Priority 2: Try direct HF translation
    model_key = f"{source_lang}-{target_lang}"
    translator = translators.get(model_key)
    if translator is not None:
        try:
            logging.info(f"Attempting direct HF translation for {model_key}")
            return _run_hf_translation(translator, text), None
        except Exception as e:
            # Remember the failure but keep going: the pivot may still work.
            logging.error(f"Direct HF translation for {model_key} failed: {e}", exc_info=True)
            last_error = f"Direct HF translation for {model_key} failed: {e}"

    # Priority 3: Try pivot translation via English
    if source_lang != 'en' and target_lang != 'en':
        translator_to_en = translators.get(f"{source_lang}-en")
        translator_from_en = translators.get(f"en-{target_lang}")
        if translator_to_en is not None and translator_from_en is not None:
            try:
                logging.info(f"Attempting pivot translation for {source_lang} -> en -> {target_lang}")
                english_text = _run_hf_translation(translator_to_en, text)
                logging.info(f"Pivot step (to en) result: '{english_text}'")
                final_text = _run_hf_translation(translator_from_en, english_text)
                logging.info(f"Pivot step (from en) result: '{final_text}'")
                return final_text, None
            except Exception as e:
                logging.error(f"Pivot translation failed: {e}", exc_info=True)
                last_error = f"Pivot translation {source_lang} -> en -> {target_lang} failed: {e}"

    # A strategy existed but raised: report the real cause, not "no model".
    if last_error is not None:
        logging.warning(f"All translation attempts failed for {source_lang} -> {target_lang}")
        return None, last_error
    logging.warning(f"No translation path found for {source_lang} -> {target_lang}")
    return None, f"No model available for {source_lang} to {target_lang}"
# --- FastAPI App ---
app = FastAPI()


# Without a path-operation decorator the handler is never registered with the
# app, so GET / would return 404.
@app.get("/")
def root():
    """Liveness endpoint: report that the API process is up."""
    return {"status": "ok", "message": "Translator API is running."}
@app.post("/api/asr")
async def api_asr(request: Request):
    """Transcribe base64-encoded audio posted as JSON.

    Expects a JSON object ``{"audio_base64": "<base64-encoded audio file>"}``
    and responds with ``{"text": "<transcript>"}``.

    Error responses (body ``{"error": ...}``):
        400 -- body is not valid JSON or not a JSON object, ``audio_base64``
               is missing or empty, or it cannot be base64-decoded.
        500 -- transcription failed, or any other unexpected error.
    """
    try:
        try:
            body = await request.json()
        except ValueError:
            # Malformed JSON surfaces as json.JSONDecodeError, a ValueError subclass.
            logging.error("Request body is not valid JSON")
            return JSONResponse(status_code=400, content={"error": "Request body must be valid JSON"})
        if not isinstance(body, dict):
            logging.error("Request body is not a JSON object")
            return JSONResponse(status_code=400, content={"error": "Request body must be a JSON object"})
        audio_b64 = body.get('audio_base64')
        if not audio_b64:
            logging.error("Request is missing 'audio_base64'")
            return JSONResponse(status_code=400, content={"error": "No audio_base64 found in request"})
        try:
            audio_bytes = base64.b64decode(audio_b64)
        except (ValueError, TypeError) as e:
            # binascii.Error (e.g. bad padding) subclasses ValueError; TypeError
            # covers non-string JSON values such as numbers or lists.
            logging.error(f"Invalid 'audio_base64' payload: {e}")
            return JSONResponse(status_code=400, content={"error": f"Invalid audio_base64: {e}"})
        # NOTE(review): transcribe_audio is synchronous and is called directly
        # from this async handler, so it blocks the event loop while it runs.
        # Consider offloading it to a thread once model/tokenizer thread-safety
        # has been confirmed.
        text, error = transcribe_audio(audio_bytes)
        if error:
            logging.error(f"ASR transcription function returned an error: {error}")
            return JSONResponse(status_code=500, content={"error": f"ASR Error: {error}"})
        response_data = {"text": text}
        logging.info(f"Returning ASR response: {response_data}")
        return JSONResponse(content=response_data)
    except Exception as e:
        logging.error(f"Critical error in /api/asr endpoint: {e}", exc_info=True)
        return JSONResponse(status_code=500, content={"error": str(e)})
@app.post("/api/translate")
async def api_translate(request: Request):
    """Translate text between two languages.

    Expects a JSON object with non-empty string fields ``text``,
    ``source_lang`` and ``target_lang``. Responds with
    ``{"translated_text": ...}``.

    Error responses (body ``{"error": ...}``):
        400 -- body is not valid JSON or not a JSON object, or a parameter is
               missing, empty, or not a string.
        500 -- translate_text reported an error, or any other unexpected error.
    """
    try:
        try:
            body = await request.json()
        except ValueError:
            # Malformed JSON surfaces as json.JSONDecodeError, a ValueError subclass.
            return JSONResponse(status_code=400, content={"error": "Request body must be valid JSON"})
        if not isinstance(body, dict):
            return JSONResponse(status_code=400, content={"error": "Request body must be a JSON object"})
        text = body.get('text')
        source_lang = body.get('source_lang')
        target_lang = body.get('target_lang')
        if not all([text, source_lang, target_lang]):
            return JSONResponse(status_code=400, content={"error": "Missing parameters: text, source_lang, or target_lang"})
        if not all(isinstance(value, str) for value in (text, source_lang, target_lang)):
            return JSONResponse(status_code=400, content={"error": "Parameters text, source_lang, and target_lang must be strings"})
        # NOTE(review): translate_text is synchronous and blocks the event loop
        # while the model runs.
        translated_text, error = translate_text(text, source_lang, target_lang)
        if error:
            return JSONResponse(status_code=500, content={"error": f"Translation Error: {error}"})
        response_data = {"translated_text": translated_text}
        logging.info(f"Returning translation response: {response_data}")
        return JSONResponse(content=response_data)
    except Exception as e:
        logging.error(f"Error in /api/translate endpoint: {e}", exc_info=True)
        return JSONResponse(status_code=500, content={"error": str(e)})
# --- Main Execution ---
if __name__ == "__main__":
    # Start the server only when executed as a script, not when imported.
    # It listens on all interfaces on port 7860.
    server_host, server_port = "0.0.0.0", 7860
    uvicorn.run(app, host=server_host, port=server_port)