Mgolo commited on
Commit
6bd1e51
·
verified ·
1 Parent(s): c56374c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +223 -0
app.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import pipeline, MarianTokenizer, AutoModelForSeq2SeqLM
3
+ import torch
4
+ import unicodedata
5
+ import re
6
+ import whisper
7
+ import tempfile
8
+ import os
9
+
10
+ import nltk
11
+ nltk.download('punkt')
12
+ from nltk.tokenize import sent_tokenize
13
+
14
+ import fitz # PyMuPDF
15
+ import docx
16
+ from bs4 import BeautifulSoup
17
+ import markdown2
18
+ import chardet
19
+
20
+ # Device setup
21
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
22
+
23
+ # Model configuration
24
+ MODELS = {
25
+ "english_wolof": {
26
+ "model_name": "LocaleNLP/localenlp-eng-wol-0.03",
27
+ "target_tag": ">>wol<<"
28
+ },
29
+ "wolof_english": {
30
+ "model_name": "LocaleNLP/localenlp-wol-eng-0.03",
31
+ "target_tag": ">>eng<<"
32
+ },
33
+ "english_hausa": {
34
+ "model_name": "LocaleNLP/localenlp-eng-hau-0.01",
35
+ "target_tag": ">>hau<<"
36
+ },
37
+ "hausa_english": {
38
+ "model_name": "LocaleNLP/localenlp-hau-eng-0.01",
39
+ "target_tag": ">>eng<<"
40
+ }
41
+ }
42
+
43
+ # Global variables
44
+ translator = None
45
+ current_model = None
46
+ whisper_model = None
47
+
48
+ HF_TOKEN = os.getenv("HF_TOKEN")
49
+
50
+ def load_translation_model(input_lang, output_lang):
51
+ global translator, current_model
52
+
53
+ model_key = f"{input_lang.lower()}_{output_lang.lower()}"
54
+ if model_key not in MODELS:
55
+ raise ValueError(f"Translation from {input_lang} to {output_lang} is not supported")
56
+
57
+ if current_model != model_key or translator is None:
58
+ model_config = MODELS[model_key]
59
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_config["model_name"], token=HF_TOKEN).to(device)
60
+ tokenizer = MarianTokenizer.from_pretrained(model_config["model_name"], token=HF_TOKEN)
61
+ translator = {
62
+ "pipeline": pipeline("translation", model=model, tokenizer=tokenizer,
63
+ device=0 if device.type == 'cuda' else -1),
64
+ "target_tag": model_config["target_tag"]
65
+ }
66
+ current_model = model_key
67
+
68
+ return translator
69
+
70
+ def load_whisper_model():
71
+ global whisper_model
72
+ if whisper_model is None:
73
+ whisper_model = whisper.load_model("base")
74
+ return whisper_model
75
+
76
+ def transcribe_audio(audio_file):
77
+ model = load_whisper_model()
78
+ if isinstance(audio_file, str):
79
+ audio_path = audio_file
80
+ else:
81
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
82
+ tmp.write(audio_file.read())
83
+ audio_path = tmp.name
84
+ result = model.transcribe(audio_path)
85
+ if not isinstance(audio_file, str):
86
+ os.remove(audio_path)
87
+ return result["text"]
88
+
89
+ def extract_text_from_file(uploaded_file):
90
+ if isinstance(uploaded_file, str):
91
+ file_path = uploaded_file
92
+ file_type = file_path.split('.')[-1].lower()
93
+ with open(file_path, "rb") as f:
94
+ content = f.read()
95
+ else:
96
+ file_type = uploaded_file.name.split('.')[-1].lower()
97
+ content = uploaded_file.read()
98
+
99
+ if file_type == "pdf":
100
+ with fitz.open(stream=content, filetype="pdf") as doc:
101
+ return "\n".join([page.get_text() for page in doc])
102
+ elif file_type == "docx":
103
+ if isinstance(uploaded_file, str):
104
+ doc = docx.Document(file_path)
105
+ else:
106
+ doc = docx.Document(uploaded_file)
107
+ return "\n".join([para.text for para in doc.paragraphs])
108
+ else:
109
+ encoding = chardet.detect(content)['encoding']
110
+ if encoding:
111
+ content = content.decode(encoding, errors='ignore')
112
+ if file_type in ("html", "htm"):
113
+ soup = BeautifulSoup(content, "html.parser")
114
+ return soup.get_text()
115
+ elif file_type == "md":
116
+ html = markdown2.markdown(content)
117
+ soup = BeautifulSoup(html, "html.parser")
118
+ return soup.get_text()
119
+ elif file_type == "srt":
120
+ return re.sub(r"\d+\n\d{2}:\d{2}:\d{2},\d{3} --> .*?\n", "", content)
121
+ elif file_type in ("txt", "text"):
122
+ return content
123
+ else:
124
+ raise ValueError("Unsupported file type")
125
+
126
+ def translate(text, input_lang, output_lang):
127
+ translator = load_translation_model(input_lang, output_lang)
128
+ lang_tag = translator["target_tag"]
129
+ translation_pipeline = translator["pipeline"]
130
+
131
+ paragraphs = text.split("\n")
132
+ translated_output = []
133
+
134
+ with torch.no_grad():
135
+ for para in paragraphs:
136
+ if not para.strip():
137
+ translated_output.append("")
138
+ continue
139
+ sentences = [s.strip() for s in para.split('. ') if s.strip()]
140
+ formatted = [f"{lang_tag} {s}" for s in sentences]
141
+
142
+ results = translation_pipeline(formatted,
143
+ max_length=5000,
144
+ num_beams=5,
145
+ early_stopping=True,
146
+ no_repeat_ngram_size=3,
147
+ repetition_penalty=1.5,
148
+ length_penalty=1.2)
149
+ translated_sentences = [r['translation_text'].capitalize() for r in results]
150
+ translated_output.append('. '.join(translated_sentences))
151
+
152
+ return "\n".join(translated_output)
153
+
154
+ def process_input(input_mode, text, audio_file, file_obj, input_lang):
155
+ input_text = ""
156
+ if input_mode == "Text":
157
+ input_text = text
158
+ elif input_mode == "Audio":
159
+ if audio_file is not None:
160
+ input_text = transcribe_audio(audio_file)
161
+ elif input_mode == "File":
162
+ if file_obj is not None:
163
+ input_text = extract_text_from_file(file_obj)
164
+ return input_text
165
+
166
+ def translate_and_return(text, input_lang, output_lang):
167
+ if not text.strip():
168
+ return "No input text to translate."
169
+ return translate(text, input_lang, output_lang)
170
+
171
+ def update_input_lang_dropdown(input_mode):
172
+ if input_mode == "Audio":
173
+ return gr.Dropdown(value="English", interactive=False)
174
+ else:
175
+ return gr.Dropdown(interactive=True)
176
+
177
+ # Gradio UI components
178
+ with gr.Blocks() as demo:
179
+ gr.Markdown("## LocaleNLP Translator")
180
+ gr.Markdown("Translate between English, Wolof, and Hausa using Localenlp models.")
181
+
182
+ with gr.Row():
183
+ input_mode = gr.Radio(choices=["Text", "Audio", "File"], label="Select input mode", value="Text")
184
+
185
+ with gr.Row():
186
+ input_lang = gr.Dropdown(choices=["English", "Wolof", "Hausa"], label="Input Language", value="English")
187
+ output_lang = gr.Dropdown(choices=["English", "Wolof", "Hausa"], label="Output Language", value="Hausa")
188
+
189
+ input_text = gr.Textbox(label="Enter text", lines=10, visible=True)
190
+ audio_input = gr.Audio(label="Upload audio (.wav, .mp3, .m4a)", type="filepath", visible=False)
191
+ file_input = gr.File(file_types=['.pdf', '.docx', '.html', '.htm', '.md', '.srt', '.txt'], label="Upload document", visible=False)
192
+
193
+ extracted_text = gr.Textbox(label="Extracted / Transcribed Text", lines=10, interactive=False)
194
+ translate_button = gr.Button("Translate")
195
+ output_text = gr.Textbox(label="Translated Text", lines=10, interactive=False)
196
+
197
+ def update_visibility(mode):
198
+ return {
199
+ input_text: gr.update(visible=(mode=="Text")),
200
+ audio_input: gr.update(visible=(mode=="Audio")),
201
+ file_input: gr.update(visible=(mode=="File")),
202
+ extracted_text: gr.update(value="", visible=True),
203
+ output_text: gr.update(value="")
204
+ }
205
+
206
+ input_mode.change(fn=update_visibility, inputs=input_mode, outputs=[input_text, audio_input, file_input, extracted_text, output_text])
207
+ input_mode.change(fn=update_input_lang_dropdown, inputs=input_mode, outputs=input_lang)
208
+
209
+ def handle_process(mode, text, audio, file_obj, in_lang):
210
+ try:
211
+ extracted = process_input(mode, text, audio, file_obj, in_lang)
212
+ return extracted, ""
213
+ except Exception as e:
214
+ return "", f"Error: {str(e)}"
215
+
216
+ translate_button.click(fn=handle_process, inputs=[input_mode, input_text, audio_input, file_input, input_lang], outputs=[extracted_text, output_text])
217
+
218
+ def handle_translate(text, in_lang, out_lang):
219
+ return translate_and_return(text, in_lang, out_lang)
220
+
221
+ translate_button.click(fn=handle_translate, inputs=[extracted_text, input_lang, output_lang], outputs=output_text)
222
+
223
+ demo.launch()