Spaces:
Runtime error
Runtime error
| import time | |
| from fairseq import checkpoint_utils, distributed_utils, options, tasks, utils | |
| from inference.engine import Model | |
| from flask import Flask, request | |
| from flask import jsonify | |
| from flask_cors import CORS, cross_origin | |
| import webvtt | |
| from io import StringIO | |
# Flask application object with CORS support attached to it.
app = Flask(__name__)
cors = CORS(app)
app.config.update(CORS_HEADERS='Content-Type')

# One Model instance per translation direction, created once at import time
# so request handlers can reuse them. The expdir values are relative paths.
indic2en_model = Model(expdir='../models/v3/indic-en')
en2indic_model = Model(expdir='../models/v3/en-indic')
m2m_model = Model(expdir='../models/m2m')
# Maps the language names accepted in request form fields to the short
# language codes passed to the translation models.
language_dict = {
    'Assamese': 'as',
    'Hindi': 'hi',
    'Marathi': 'mr',
    'Tamil': 'ta',
    'Bengali': 'bn',
    'Kannada': 'kn',
    'Oriya': 'or',
    'Telugu': 'te',
    'Gujarati': 'gu',
    'Malayalam': 'ml',
    'Punjabi': 'pa',
}
def get_inference_params():
    """Resolve the model and language codes for the current request.

    Reads the ``model_type``, ``source_language`` and ``target_language``
    form fields of the active Flask request.

    Returns:
        A ``(model, source_lang, target_lang)`` tuple, where ``model`` is one
        of the module-level Model instances and the two language values are
        short codes (``'en'`` or a value from ``language_dict``).

    Raises:
        werkzeug.exceptions.HTTPException: a 400 response is raised via
            ``flask.abort`` when ``model_type`` is not one of ``'indic-en'``,
            ``'en-indic'`` or ``'m2m'``, when the English side of a
            single-direction model is not ``'English'``, or when a language
            name is not present in ``language_dict``.
    """
    # Local import keeps this block self-contained; flask is already a
    # dependency of this module.
    from flask import abort

    model_type = request.form['model_type']
    source_language = request.form['source_language']
    target_language = request.form['target_language']

    def lang_code(name, field):
        # Translate a language name into its code, rejecting unknown names
        # with a client error instead of an unhandled KeyError.
        try:
            return language_dict[name]
        except KeyError:
            abort(400, description=f"Unsupported {field}: {name!r}")

    # Explicit checks replace the former `assert` statements, which are
    # removed when Python runs with -O and so must not guard client input.
    if model_type == 'indic-en':
        source_lang = lang_code(source_language, 'source_language')
        if target_language != 'English':
            abort(400, description="target_language must be 'English' for model_type 'indic-en'")
        return indic2en_model, source_lang, 'en'

    if model_type == 'en-indic':
        if source_language != 'English':
            abort(400, description="source_language must be 'English' for model_type 'en-indic'")
        return en2indic_model, 'en', lang_code(target_language, 'target_language')

    if model_type == 'm2m':
        source_lang = lang_code(source_language, 'source_language')
        target_lang = lang_code(target_language, 'target_language')
        return m2m_model, source_lang, target_lang

    # Previously an unknown model_type fell through to the return statement
    # and failed with UnboundLocalError.
    abort(400, description=f"Unknown model_type: {model_type!r}")
def main():
    """Return the fixed banner string identifying this service."""
    banner = "IndicTrans API"
    return banner
def infer_indic_en():
    """Translate the ``text`` form field of the current request.

    The model and language codes come from ``get_inference_params()``.

    Returns:
        A dict with ``'text'`` (the result of ``model.translate_paragraph``)
        and ``'duration'`` (elapsed wall-clock seconds for that call, rounded
        to two decimals).
    """
    translator, src_code, tgt_code = get_inference_params()
    paragraph = request.form['text']

    # Time only the translation call itself.
    began = time.time()
    translated = translator.translate_paragraph(paragraph, src_code, tgt_code)
    elapsed = time.time() - began

    return {'text': translated, 'duration': round(elapsed, 2)}
def infer_vtt_indic_en():
    """Translate the caption text of a WebVTT document sent in the ``text`` field.

    The model and language codes come from ``get_inference_params()``. Each
    caption's text is flattened to a single line, all captions are translated
    with one ``model.batch_translate`` call, and the translations are written
    back onto the captions by position.

    Returns:
        A dict with ``'text'`` (``captions.content`` after the caption texts
        have been replaced) and ``'duration'`` (elapsed wall-clock seconds
        for the batch translation, rounded to two decimals).
    """
    translator, src_code, tgt_code = get_inference_params()
    vtt_payload = request.form['text']
    captions = webvtt.read_buffer(StringIO(vtt_payload))

    # Drop carriage returns and join multi-line cues with spaces so each
    # caption is passed to the model as one line.
    flattened = []
    for cue in captions:
        flattened.append(cue.text.replace('\r', '').replace('\n', ' '))

    # Time only the translation call itself.
    began = time.time()
    translations = translator.batch_translate(flattened, src_code, tgt_code)
    elapsed = time.time() - began

    # Write each translation back onto the caption at the same index.
    for position, translated in enumerate(translations):
        captions[position].text = translated

    return {'text': captions.content, 'duration': round(elapsed, 2)}