Mgolo's picture
Update app.py
b19fe5a verified
raw
history blame
7.32 kB
import io
import os
import re
import tempfile

import chardet
import docx
import fitz  # PyMuPDF
import gradio as gr
import markdown2
import torch
import whisper
from bs4 import BeautifulSoup
from transformers import pipeline, MarianTokenizer, AutoModelForSeq2SeqLM
# Prefer the GPU when CUDA is available; otherwise everything runs on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Lazily created singletons, filled in on first use by load_model() and
# load_whisper_model() respectively.
translator = whisper_model = None

# Supported (input language, output language) pairs. Each entry names the
# Hugging Face checkpoint to load and the target-language tag that
# translate_text() prepends to every sentence sent to the model.
MODELS = {
    ("English", "Wolof"): dict(
        model_name="LocaleNLP/localenlp-eng-wol-0.03", tag=">>wol<<"
    ),
    ("Wolof", "English"): dict(
        model_name="LocaleNLP/localenlp-wol-eng-0.03", tag=">>eng<<"
    ),
    ("English", "Hausa"): dict(
        model_name="LocaleNLP/localenlp-eng-hau-0.01", tag=">>hau<<"
    ),
    ("Hausa", "English"): dict(
        model_name="LocaleNLP/localenlp-hau-eng-0.01", tag=">>eng<<"
    ),
    ("English", "Darija"): dict(
        model_name="LocaleNLP/english_darija", tag=">>dar<<"
    ),
}

# Hub access token passed to from_pretrained(); read from the "hffff"
# environment variable and None when that variable is unset.
HF_TOKEN = os.getenv("hffff")
def load_model(input_lang, output_lang):
    """Return ``(translator, tag)`` for the requested language pair.

    Only one translation pipeline is kept in the module-level ``translator``.
    It is rebuilt whenever the requested checkpoint differs from the one the
    current pipeline's model was loaded from (its ``config._name_or_path``).

    Args:
        input_lang: Source language name, e.g. ``"English"``.
        output_lang: Target language name, e.g. ``"Wolof"``.

    Returns:
        The Hugging Face ``"translation"`` pipeline and the target-language tag
        string from ``MODELS``.

    Raises:
        ValueError: if ``(input_lang, output_lang)`` is not a key of ``MODELS``.
    """
    global translator

    cfg = MODELS.get((input_lang, output_lang))
    if cfg is None:
        raise ValueError("Language pair not supported.")

    wanted = cfg["model_name"]
    loaded = None if translator is None else translator.model.config._name_or_path
    if loaded != wanted:
        # The pipeline API takes a device index: 0 = first GPU, -1 = CPU.
        pipeline_device = 0 if device.type == "cuda" else -1
        seq2seq = AutoModelForSeq2SeqLM.from_pretrained(wanted, token=HF_TOKEN).to(device)
        tok = MarianTokenizer.from_pretrained(wanted, token=HF_TOKEN)
        translator = pipeline(
            "translation", model=seq2seq, tokenizer=tok, device=pipeline_device
        )

    return translator, cfg["tag"]
def load_whisper_model():
    """Return the shared Whisper ``"base"`` model, loading it on first call.

    The model is cached in the module-level ``whisper_model`` so later calls
    reuse it.
    """
    global whisper_model
    if whisper_model is not None:
        return whisper_model
    whisper_model = whisper.load_model("base")
    return whisper_model
def transcribe_audio(audio_file):
    """Transcribe speech with the shared Whisper model.

    Args:
        audio_file: Either a filesystem path (what ``gr.Audio(type="filepath")``
            supplies) or a readable binary file-like object. File-like input is
            first copied to a temporary ``.wav`` file, because
            ``model.transcribe`` is called with a path here.

    Returns:
        The ``"text"`` field of Whisper's transcription result.
    """
    model = load_whisper_model()
    if isinstance(audio_file, str):
        return model.transcribe(audio_file)["text"]

    # delete=False so the file can be reopened by path after the handle is
    # closed. We remove it ourselves in ``finally`` so it is cleaned up even
    # when copying or transcription raises. Previously it leaked on error.
    tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
    try:
        with tmp:
            tmp.write(audio_file.read())
        result = model.transcribe(tmp.name)
    finally:
        os.remove(tmp.name)
    return result["text"]
def extract_text_from_file(uploaded_file):
    """Extract plain text from an uploaded document.

    Args:
        uploaded_file: Either a filesystem path, or a file-like object with a
            ``name`` attribute and a ``read()`` method returning bytes.

    Returns:
        The document text as ``str``. The format is chosen from the file-name
        extension (case-insensitive): pdf, docx, html/htm, md, srt, txt/text.

    Raises:
        ValueError: if the extension is not one of the supported types.
    """
    if isinstance(uploaded_file, str):
        file_name = uploaded_file
        with open(uploaded_file, "rb") as f:
            content = f.read()
    else:
        file_name = uploaded_file.name
        content = uploaded_file.read()
    # Everything after the last "." is treated as the extension.
    file_type = file_name.split(".")[-1].lower()

    if file_type == "pdf":
        with fitz.open(stream=content, filetype="pdf") as doc:
            return "\n".join(page.get_text() for page in doc)
    if file_type == "docx":
        # Parse the bytes already read. For a file-like upload, .read() above
        # consumed the stream, so passing the object itself to python-docx
        # would depend on the stream being seekable.
        doc = docx.Document(io.BytesIO(content))
        return "\n".join(para.text for para in doc.paragraphs)
    if file_type not in ("html", "htm", "md", "srt", "txt", "text"):
        raise ValueError("Unsupported file type")

    # chardet reports encoding=None when it cannot guess (e.g. empty input).
    # Fall back to UTF-8 so a str is always produced: raw bytes would break
    # re.sub (str pattern) and the str-returning txt path.
    encoding = chardet.detect(content)["encoding"] or "utf-8"
    try:
        text = content.decode(encoding, errors="ignore")
    except LookupError:  # chardet named a codec this Python does not provide
        text = content.decode("utf-8", errors="ignore")

    if file_type in ("html", "htm"):
        return BeautifulSoup(text, "html.parser").get_text()
    if file_type == "md":
        html = markdown2.markdown(text)
        return BeautifulSoup(html, "html.parser").get_text()
    if file_type == "srt":
        # Remove each cue's index line and its "HH:MM:SS,mmm --> ..." timing
        # line; accept both LF and CRLF line endings.
        return re.sub(r"\d+\r?\n\d{2}:\d{2}:\d{2},\d{3} --> .*?\n", "", text)
    return text  # "txt" / "text"
def translate_text(text, input_lang, output_lang):
    """Translate ``text`` from ``input_lang`` to ``output_lang``.

    Line structure is preserved:
    - Each line is translated independently, and blank lines stay blank.
    - Within a line, the text is split into sentences on ". ".
    - Each sentence is prefixed with the model's target-language tag.
    - The line's sentences are translated as one batch and re-joined with ". ".

    Args:
        text: Source text, possibly multi-line.
        input_lang: Source language name (a key component of ``MODELS``).
        output_lang: Target language name.

    Returns:
        The translated text, with the same number of lines as ``text``.

    Raises:
        ValueError: propagated from ``load_model`` for an unsupported pair.
    """
    translator, tag = load_model(input_lang, output_lang)
    translated_output = []
    with torch.no_grad():
        for para in text.split("\n"):
            sentences = [s.strip() for s in para.split(". ") if s.strip()]
            if not sentences:
                translated_output.append("")
                continue
            # NOTE(review): max_length=5000 likely exceeds the model's
            # positional limit; kept as-is since generation normally stops
            # early — confirm against the checkpoint config.
            results = translator(
                [f"{tag} {s}" for s in sentences],
                max_length=5000,
                num_beams=5,
                early_stopping=True,
                no_repeat_ngram_size=3,
                repetition_penalty=1.5,
                length_penalty=1.2,
            )
            pieces = []
            for r in results:
                t = r["translation_text"]
                # Upper-case only the first character. str.capitalize() would
                # also lower-case the rest, mangling names and acronyms.
                pieces.append(t[:1].upper() + t[1:])
            # The pieces are re-joined with ". ". Drop a period the model
            # already emitted on non-final sentences so the join cannot
            # produce "..".
            pieces = [p.rstrip(".") for p in pieces[:-1]] + pieces[-1:]
            translated_output.append(". ".join(pieces))
    return "\n".join(translated_output)
def process_input(input_mode, input_lang, text, audio_file, file_obj):
    """Return the source text for the selected input mode.

    Args:
        input_mode: ``"Text"``, ``"Audio"`` or ``"File"``.
        input_lang: Selected source language name.
        text: Typed text, returned unchanged in ``"Text"`` mode.
        audio_file: Audio path/stream, transcribed in ``"Audio"`` mode.
        file_obj: Document path/stream, extracted in ``"File"`` mode.

    Returns:
        The text for the active mode. Returns ``""`` when the active mode's
        input is missing or the mode is unrecognised.

    Raises:
        ValueError: in ``"Audio"`` mode with a non-English input language.
            This is checked first, even when no audio was supplied.
    """
    is_audio = input_mode == "Audio"
    if is_audio and input_lang != "English":
        raise ValueError("Audio input must be in English.")
    if input_mode == "Text":
        return text
    if is_audio:
        return "" if audio_file is None else transcribe_audio(audio_file)
    if input_mode == "File":
        return "" if file_obj is None else extract_text_from_file(file_obj)
    return ""
# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("## LocaleNLP Multi-language Translator")
    gr.Markdown(
        "Translate between English, Wolof, and Hausa, or from English to Darija. "
        "Audio input accepts English only."
    )
    with gr.Row():
        input_mode = gr.Radio(choices=["Text", "Audio", "File"], label="Input type", value="Text")
        input_lang = gr.Dropdown(choices=["English", "Wolof", "Hausa"], label="Input language", value="English")
        output_lang = gr.Dropdown(choices=["English", "Wolof", "Hausa", "Darija"], label="Output language", value="Wolof")
    input_text = gr.Textbox(label="Enter text", lines=10, visible=True)
    audio_input = gr.Audio(label="Upload audio (.wav, .mp3, .m4a)", type="filepath", visible=False)
    file_input = gr.File(file_types=['.pdf', '.docx', '.html', '.htm', '.md', '.srt', '.txt'], label="Upload document", visible=False)
    extracted_text = gr.Textbox(label="Extracted / Transcribed Text", lines=10, interactive=False)
    translate_button = gr.Button("Translate")
    output_text = gr.Textbox(label="Translated Text", lines=10, interactive=False)

    def update_visibility(mode):
        """Show only the input widget for ``mode`` and clear both result boxes."""
        return {
            input_text: gr.update(visible=(mode == "Text")),
            audio_input: gr.update(visible=(mode == "Audio")),
            file_input: gr.update(visible=(mode == "File")),
            extracted_text: gr.update(value="", visible=True),
            output_text: gr.update(value=""),
        }

    input_mode.change(fn=update_visibility, inputs=input_mode, outputs=[input_text, audio_input, file_input, extracted_text, output_text])

    def handle_process(mode, lang_in, text, audio, file_obj):
        """Return ``(extracted_text, "")`` on success or ``("", "Error: ...")``."""
        try:
            extracted = process_input(mode, lang_in, text, audio, file_obj)
            return extracted, ""
        except Exception as e:
            return "", f"Error: {str(e)}"

    def handle_translate(text, lang_in, lang_out):
        """Translate ``text``, or return a notice when it is blank."""
        if not text.strip():
            return "No input text to translate."
        return translate_text(text, lang_in, lang_out)

    def handle_extract_and_translate(mode, lang_in, lang_out, text, audio, file_obj):
        """Extract the source text, then translate it, in one event.

        Gradio passes each listener the component values captured when the
        event fires. A separate translate listener would therefore read
        ``extracted_text`` before extraction had updated it, and both
        listeners would write ``output_text``. Doing both steps in one
        handler avoids that.
        """
        extracted, error = handle_process(mode, lang_in, text, audio, file_obj)
        if error:
            return extracted, error
        try:
            return extracted, handle_translate(extracted or "", lang_in, lang_out)
        except Exception as e:  # UI boundary: show e.g. unsupported-pair errors
            return extracted, f"Error: {str(e)}"

    translate_button.click(
        fn=handle_extract_and_translate,
        inputs=[input_mode, input_lang, output_lang, input_text, audio_input, file_input],
        outputs=[extracted_text, output_text],
    )

demo.launch()