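"""FastAPI speech-to-text and translation service.

Exposes /api/asr (Whisper ASR served through ONNX Runtime) and /api/translate
(DeepL when a key is configured, with Helsinki-NLP/FuguMT pipeline fallbacks).
"""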
from transformers import AutoProcessor, pipeline
from optimum.onnxruntime import ORTModelForSpeechSeq2Seq
import numpy as np
import os
import base64
import tempfile
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
import uvicorn
import deepl
from dotenv import load_dotenv
import soundfile as sf
import logging

# --- Basic Configuration ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# --- Load environment variables and initialize DeepL ---
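# DEEPL_AUTH_KEY may come from a local .env file (via load_dotenv) or the process environment.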
load_dotenv()

DEEPL_AUTH_KEY = os.getenv("DEEPL_AUTH_KEY")
deepl_translator = None
if DEEPL_AUTH_KEY:
    try:
        deepl_translator = deepl.Translator(DEEPL_AUTH_KEY)
        logging.info("DeepL translator initialized successfully.")
    except Exception as e:
        logging.error(f"Error initializing DeepL translator: {e}")
else:
    logging.warning("DEEPL_AUTH_KEY not found. DeepL will be unavailable.")

# --- Load Models ---
logging.info("Loading all models...")

# ASR Model
asr_model_id = "openai/whisper-base"
asr_model = None
asr_processor = None
try:
    asr_model = ORTModelForSpeechSeq2Seq.from_pretrained(asr_model_id, provider="CPUExecutionProvider")
    asr_processor = AutoProcessor.from_pretrained(asr_model_id)

    # Workaround: the model's default config ships a 'forced_decoder_ids' value that
    # conflicts with recent transformers versions, which expect the attribute to be
    # None so that the task/language can be passed to generate() instead.
    if hasattr(asr_model.config, 'forced_decoder_ids'):
        logging.info("Found conflicting 'forced_decoder_ids' in model config. Setting to None.")
        asr_model.config.forced_decoder_ids = None
    
    if hasattr(asr_model.generation_config, 'forced_decoder_ids'):
        logging.info("Found conflicting 'forced_decoder_ids' in generation_config. Setting to None.")
        asr_model.generation_config.forced_decoder_ids = None

    logging.info("ASR model and processor loaded and configured successfully.")
except Exception as e:
    logging.error(f"Fatal error loading ASR model: {e}", exc_info=True)

# Translation Pipelines
translators = {}
try:
    translators = {
        "en-zh": pipeline("translation", model="Helsinki-NLP/opus-mt-en-zh"),
        "zh-en": pipeline("translation", model="Helsinki-NLP/opus-mt-zh-en"),
        "en-ja": pipeline("translation", model="staka/fugumt-en-ja"),
        "ja-en": pipeline("translation", model="Helsinki-NLP/opus-mt-ja-en"),
        "en-ko": pipeline("translation", model="Helsinki-NLP/opus-mt-tc-big-en-ko"),
        "ko-en": pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en"),
    }
    logging.info("Translation models loaded successfully.")
except Exception as e:
    logging.error(f"Failed to load translation models: {e}")

# --- Core Logic Functions ---
def transcribe_audio(audio_bytes):
    if not asr_model or not asr_processor:
        return None, "ASR model or processor is not available."
    try:
        # soundfile can only decode formats libsndfile supports (MP3 needs libsndfile >= 1.1.0).
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
            tmp_file.write(audio_bytes)
            audio_path = tmp_file.name

        audio_input, sample_rate = sf.read(audio_path)
        if audio_input.ndim > 1:
            audio_input = audio_input.mean(axis=1)

        # Whisper expects 16 kHz audio, but sf.read returns the file's native rate;
        # resample with linear interpolation (dependency-free, though a dedicated
        # resampler would be higher quality) rather than mislabeling the rate.
        if sample_rate != 16000:
            target_len = int(len(audio_input) * 16000 / sample_rate)
            positions = np.linspace(0, len(audio_input) - 1, num=target_len)
            audio_input = np.interp(positions, np.arange(len(audio_input)), audio_input)

        input_features = asr_processor(audio_input, sampling_rate=16000, return_tensors="pt").input_features
        
        # By setting forced_decoder_ids to None in the config, we can now safely
        # let the generate function handle the task without conflicts.
        predicted_ids = asr_model.generate(input_features, task="transcribe")
        
        text = asr_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
        
        logging.info(f"ASR transcribed text: '{text}'")

        os.remove(audio_path)
        return text, None
    except Exception as e:
        logging.error(f"ASR transcription failed: {e}", exc_info=True)
        if 'audio_path' in locals() and os.path.exists(audio_path):
            os.remove(audio_path)
        return None, str(e)

def translate_text(text, source_lang, target_lang):
    # Priority 1: Use DeepL for the zh->ja and en->ja pairs when a key is configured
    if deepl_translator and ((source_lang == 'zh' and target_lang == 'ja') or (source_lang == 'en' and target_lang == 'ja')):
        try:
            dl_source_lang = "ZH" if source_lang == 'zh' else "EN"
            logging.info(f"Attempting DeepL translation for {source_lang} -> {target_lang}")
            result = deepl_translator.translate_text(text, source_lang=dl_source_lang, target_lang="JA")
            return result.text, None
        except Exception as e:
            logging.error(f"DeepL failed: {e}. Falling back to HF models.")

    # Priority 2: Try direct HF translation
    model_key = f"{source_lang}-{target_lang}"
    translator = translators.get(model_key)
    if translator:
        try:
            logging.info(f"Attempting direct HF translation for {model_key}")
            translated_text = translator(text, max_length=512)[0]['translation_text']
            return translated_text, None
        except Exception as e:
            logging.error(f"Direct HF translation for {model_key} failed: {e}", exc_info=True)
            # Don't return here, allow fallback to pivot

    # Priority 3: Try pivot translation via English
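    # e.g. zh -> ko has no direct pipeline above, so chain "zh-en" then "en-ko".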
    if source_lang != 'en' and target_lang != 'en':
        to_en_key = f"{source_lang}-en"
        from_en_key = f"en-{target_lang}"
        translator_to_en = translators.get(to_en_key)
        translator_from_en = translators.get(from_en_key)

        if translator_to_en and translator_from_en:
            try:
                logging.info(f"Attempting pivot translation for {source_lang} -> en -> {target_lang}")
                # Step 1: Source to English
                english_text = translator_to_en(text, max_length=512)[0]['translation_text']
                logging.info(f"Pivot step (to en) result: '{english_text}'")
                # Step 2: English to Target
                final_text = translator_from_en(english_text, max_length=512)[0]['translation_text']
                logging.info(f"Pivot step (from en) result: '{final_text}'")
                return final_text, None
            except Exception as e:
                logging.error(f"Pivot translation failed: {e}", exc_info=True)

    # If all else fails
    logging.warning(f"No translation path found for {source_lang} -> {target_lang}")
    return None, f"No model available for {source_lang} to {target_lang}"

# --- FastAPI App ---
app = FastAPI()

@app.get("/")
def root():
    return {"status": "ok", "message": "Translator API is running."}

@app.post("/api/asr")
async def api_asr(request: Request):
    try:
        body = await request.json()
        audio_b64 = body.get('audio_base64')
        if not audio_b64:
            logging.error("Request is missing 'audio_base64'")
            return JSONResponse(status_code=400, content={"error": "No audio_base64 found in request"})

        audio_bytes = base64.b64decode(audio_b64)
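        # Malformed base64 (e.g. bad padding) raises binascii.Error here and is
        # turned into a 500 response by the outer except.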
        
        text, error = transcribe_audio(audio_bytes)

        if error:
            logging.error(f"ASR transcription function returned an error: {error}")
            return JSONResponse(status_code=500, content={"error": f"ASR Error: {error}"})

        response_data = {"text": text}
        logging.info(f"Returning ASR response: {response_data}")
        return JSONResponse(content=response_data)

    except Exception as e:
        logging.error(f"Critical error in /api/asr endpoint: {e}", exc_info=True)
        return JSONResponse(status_code=500, content={"error": str(e)})

@app.post("/api/translate")
async def api_translate(request: Request):
    try:
        body = await request.json()
        text = body.get('text')
        source_lang = body.get('source_lang')
        target_lang = body.get('target_lang')

        if not all([text, source_lang, target_lang]):
            return JSONResponse(status_code=400, content={"error": "Missing parameters: text, source_lang, or target_lang"})

        translated_text, error = translate_text(text, source_lang, target_lang)

        if error:
            return JSONResponse(status_code=500, content={"error": f"Translation Error: {error}"})
        
        response_data = {"translated_text": translated_text}
        logging.info(f"Returning translation response: {response_data}")
        return JSONResponse(content=response_data)

    except Exception as e:
        logging.error(f"Error in /api/translate endpoint: {e}", exc_info=True)
        return JSONResponse(status_code=500, content={"error": str(e)})

# --- Main Execution ---
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
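
# Example requests (hypothetical values; the base64 payload must be an audio
# file soundfile can decode):
#   curl -X POST http://localhost:7860/api/asr \
#        -H "Content-Type: application/json" \
#        -d '{"audio_base64": "<base64-encoded audio>"}'
#   curl -X POST http://localhost:7860/api/translate \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Hello", "source_lang": "en", "target_lang": "ja"}'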