Bisher committed on
Commit
2b1b4a0
·
verified ·
1 Parent(s): 13db942

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +153 -284
app.py CHANGED
@@ -1,311 +1,180 @@
 
 
 
 
1
  import gradio as gr
2
- from gradio_client import Client, handle_file
3
  import jiwer
4
- import os
5
- import time
6
- import warnings
7
  import pyarabic.araby as araby
8
- import difflib # Import difflib
9
-
10
- # Suppress specific UserWarnings from jiwer related to empty strings
11
- warnings.filterwarnings("ignore", message="Reference is empty.*", category=UserWarning)
12
- warnings.filterwarnings("ignore", message="Hypothesis is empty.*", category=UserWarning)
13
-
14
- # --- Constants ---
15
- DIACRITIZATION_API_URL = "Bisher/CATT.diacratization"
16
- TRANSCRIPTION_API_URL = "Bisher/arabic_syllable_transcription"
17
-
18
- # Define Arabic diacritics
19
- # Use a try-except block in case pyarabic is not installed or fails to import
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  try:
21
- ARABIC_DIACRITICS = {
22
  araby.FATHA, araby.FATHATAN, araby.DAMMA, araby.DAMMATAN,
23
  araby.KASRA, araby.KASRATAN, araby.SUKUN, araby.SHADDA,
24
  }
25
- except (ImportError, NameError):
26
- print("Warning: pyarabic not found or failed to import. Using fallback diacritics set.")
27
- ARABIC_DIACRITICS = {'\u064B', '\u064C', '\u064D', '\u064E', '\u064F', '\u0650', '\u0651', '\u0652'}
28
-
29
- # --- API Clients ---
30
- # Use caching or global clients to avoid re-initializing on every call
31
- diacritization_client = None
32
- transcription_client = None
33
-
34
-
35
- def get_diacritization_client():
36
- global diacritization_client
37
- if diacritization_client is None:
38
- try:
39
- diacritization_client = Client(DIACRITIZATION_API_URL, download_files=True)
40
- except Exception as e:
41
- print(f"Error initializing diacritization client: {e}")
42
- return None
43
- return diacritization_client
44
-
45
- def get_transcription_client():
46
- global transcription_client
47
- if transcription_client is None:
48
- try:
49
- transcription_client = Client(TRANSCRIPTION_API_URL, download_files=True)
50
- except Exception as e:
51
- print(f"Error initializing transcription client: {e}")
52
- return None
53
- return transcription_client
54
-
55
-
56
-
57
- # --- Helper Functions ---
58
- def diacritize_text_api(text_to_diacritize):
59
- """Calls the diacritization API."""
60
- if not text_to_diacritize or not text_to_diacritize.strip():
61
- return "Please enter some text to diacritize.", "" # Return two values as expected by the click handler
62
- client = get_diacritization_client()
63
- if not client:
64
- return "Error: Could not connect to the diacritization service.", ""
65
- try:
66
- result = client.predict(
67
- model_type="Encoder-Only",
68
- input_text=text_to_diacritize,
69
- api_name="/predict"
70
- )
71
- # Ensure result is a string, handle potential None or unexpected types
72
- result_str = str(result) if result is not None else "Error: Empty response from diacritization service."
73
- # Return the result for both the output textbox and the state
74
- return result_str, result_str
75
- except Exception as e:
76
- print(f"Error during diacritization API call: {e}")
77
- return f"Error during diacritization: {e}", ""
78
-
79
- def transcribe_audio_api(audio_filepath):
80
- """Calls the standard transcription API."""
81
- if not audio_filepath:
82
- return "Error: Please provide an audio recording or file."
83
- if not os.path.exists(audio_filepath):
84
- return f"Error: Audio file not found at {audio_filepath}"
85
-
86
- client = get_transcription_client()
87
- if not client:
88
- return "Error: Could not connect to the transcription service."
89
- try:
90
- # Add a small delay if needed, sometimes helps with API race conditions
91
- # time.sleep(0.5)
92
- result = client.predict(
93
- audio=handle_file(audio_filepath),
94
- api_name="/predict"
95
- )
96
- return result[0], result[1]
97
- except Exception as e:
98
- print(f"Error during transcription API call: {e}")
99
- return f"Error during transcription: {e}"
100
-
101
- def get_diacritics_sequence(text):
102
- """Extracts diacritics from a string."""
103
- if not isinstance(text, str):
104
- return ""
105
- diacritics_only = [c for c in text if c in ARABIC_DIACRITICS]
106
- return ' '.join(diacritics_only)
107
-
108
- def calculate_metrics(reference, hypothesis):
109
- """Calculates WER, DER, CER."""
110
- ref = reference or ""
111
- hyp = hypothesis or ""
112
-
113
- # Handle cases where one or both are empty or just whitespace
114
- if not ref.strip() and not hyp.strip():
115
- return 0.0, 0.0, 0.0 # Both empty, 0 error
116
- if not ref.strip():
117
- return 1.0, 1.0, 1.0 # Reference empty, hypothesis not: Max error
118
- if not hyp.strip():
119
- # Hypothesis empty, reference not: Max error (though jiwer might handle this)
120
- # Let jiwer calculate based on its rules for empty hypothesis
121
- pass
122
 
123
- try:
124
- # WER
125
- wer = jiwer.wer(ref, hyp)
126
- # DER
127
- ref_d = get_diacritics_sequence(ref)
128
- hyp_d = get_diacritics_sequence(hyp)
129
- # Handle empty diacritic sequences for DER calculation
130
- if not ref_d.strip() and not hyp_d.strip():
131
- der = 0.0
132
- elif not ref_d.strip():
133
- der = 1.0
134
- else:
135
- der = jiwer.wer(ref_d, hyp_d) # jiwer handles empty hyp_d if ref_d is not empty
136
- # CER
137
- cer = jiwer.cer(ref, hyp)
138
- return round(wer, 4), round(der, 4), round(cer, 4)
139
- except Exception as e:
140
- print(f"Error calculating metrics: {e}")
141
- return None, None, None # Indicate error in calculation
142
 
 
 
143
 
144
- def highlight_errors(reference, hypothesis):
145
- """Highlights differences between reference and hypothesis using HTML mark tag."""
146
- ref = reference or ""
147
- hyp = hypothesis or ""
148
- ref_words = ref.split()
149
- hyp_words = hyp.split()
150
 
151
- if not ref_words and not hyp_words:
152
- return "", "" # No errors if both are empty
 
 
 
 
 
 
153
 
154
- matcher = difflib.SequenceMatcher(None, ref_words, hyp_words, autojunk=False)
155
- highlighted_hyp_words = []
156
- error_words_ref = [] # Words in reference that were deleted or replaced
157
- error_words_hyp = [] # Words in hypothesis that were inserted or replaced
158
 
 
 
 
 
159
  for tag, i1, i2, j1, j2 in matcher.get_opcodes():
160
  if tag == 'equal':
161
- highlighted_hyp_words.extend(hyp_words[j1:j2])
162
  elif tag == 'replace':
163
- # Mark incorrect words in hypothesis red
164
- for word in hyp_words[j1:j2]:
165
- highlighted_hyp_words.append(f"<mark style='background-color: #ffcccb;'>{word}</mark>")
166
- error_words_ref.extend(ref_words[i1:i2])
167
- error_words_hyp.extend(hyp_words[j1:j2])
168
  elif tag == 'delete':
169
- # Indicate missing words (maybe with a placeholder?) - for now, just note them
170
- # We don't add anything to highlighted_hyp_words here as they are missing
171
- error_words_ref.extend(ref_words[i1:i2])
172
- # Optionally add a placeholder in the output to show where deletion happened
173
- # highlighted_hyp_words.append("<mark style='background-color: #lightgrey;'>[missing]</mark>")
174
  elif tag == 'insert':
175
- # Mark inserted words in hypothesis green
176
- for word in hyp_words[j1:j2]:
177
- highlighted_hyp_words.append(f"<mark style='background-color: #ccffcc;'>{word}</mark>")
178
- error_words_hyp.extend(hyp_words[j1:j2])
179
-
180
- html_output = ' '.join(highlighted_hyp_words)
181
- # Combine unique error words for the list
182
- error_list = sorted(list(set(error_words_ref + error_words_hyp)))
183
-
184
- return html_output, ', '.join(error_list)
185
-
186
-
187
- # --- Gradio Interface ---
 
 
 
188
  with gr.Blocks(theme=gr.themes.Soft()) as app:
189
- gr.Markdown(
190
- """
191
- # Arabic Diacritization and Reading Assessment Tool
192
- 1. Enter undiacritized Arabic text and click **Diacritize Text**.
193
- 2. Read the generated **Diacritized Text** aloud and record or upload audio.
194
- 3. Click **Transcribe and Compare** to see the transcript, syllable transcript, WER/DER/CER, and mispronounced words highlighted.
195
- """
196
- )
197
-
198
- # Using gr.State to hold the diacritized reference text between steps
199
- reference_text_state = gr.State("")
200
 
201
  with gr.Row():
202
  with gr.Column(scale=1):
203
- text_input = gr.Textbox(label="Undiacritized Arabic Text", lines=3, text_align="right")
204
- diacritize_btn = gr.Button("Diacritize Text")
205
- diacritized_output = gr.Textbox(
206
- label="Diacritized Text (Reference)",
207
- lines=3,
208
- interactive=True, # User shouldn't edit this directly
209
- text_align="right",
210
- )
211
- diacritized_output.change(
212
- fn=lambda text: text,
213
- inputs=diacritized_output,
214
- outputs=reference_text_state
215
- )
216
 
217
  with gr.Column(scale=1):
218
- audio_input = gr.Audio(label="Record or Upload Audio", type="filepath", sources=["microphone", "upload"])
219
- transcribe_btn = gr.Button("Transcribe and Compare")
220
- transcript_output = gr.Textbox(
221
- label="Transcript (Hypothesis)",
222
- lines=3,
223
- interactive=False,
224
- text_align="right"
225
- )
226
- # Ensure this Textbox is defined correctly
227
- transcript_syllables_output = gr.Textbox(
228
- label="Transcript Syllables (Hypothesis)", # Corrected label slightly for clarity
229
- lines=3,
230
- interactive=False,
231
- text_align="right"
232
  )
233
- with gr.Row():
234
- wer_out = gr.Number(label="WER", interactive=False, precision=4)
235
- der_out = gr.Number(label="DER", interactive=False, precision=4)
236
- cer_out = gr.Number(label="CER", interactive=False, precision=4)
237
- # Use Markdown for potentially richer HTML display if needed, but HTML component is fine
238
- error_html = gr.HTML(label="Highlighted Errors in Hypothesis")
239
- error_list = gr.Textbox(label="Words Involved in Errors", interactive=False) # Changed label
240
-
241
- # --- Event Handlers ---
242
-
243
- # When Diacritize button is clicked
244
- diacritize_btn.click(
245
- fn=diacritize_text_api,
246
- inputs=[text_input],
247
- # Output to the display box AND the hidden state
248
- outputs=[diacritized_output, reference_text_state]
249
- )
250
-
251
- # Define the main processing function that returns all 7 values
252
- def process_audio_and_compare(audio_filepath, reference_text):
253
- """Processes audio, gets both transcripts, calculates metrics, and highlights errors."""
254
- # Default values in case of errors
255
- transcript = "Error: Processing failed."
256
- syllable_transcript = "Error: Processing failed."
257
- wer, der, cer = None, None, None
258
- html_output = ""
259
- error_words = ""
260
-
261
- # Validate inputs
262
- if not audio_filepath:
263
- transcript = "Error: No audio provided."
264
- syllable_transcript = "Error: No audio provided."
265
- # Return 7 values even on input error
266
- return transcript, syllable_transcript, None, None, None, "", ""
267
- if not reference_text:
268
- transcript = "Error: No reference text found. Please diacritize first."
269
- syllable_transcript = "Error: No reference text found."
270
- # Return 7 values
271
- return transcript, syllable_transcript, None, None, None, "", ""
272
- try:
273
- # --- Call Transcription APIs ---
274
- transcript, syllable_transcript = transcribe_audio_api(audio_filepath)
275
- except:
276
- print(f"Error calculating metrics: {e}")
277
- transcript, syllable_transcript = "error", "error"
278
- # --- Calculate Metrics and Highlight Errors (only if first transcript is not an error) ---
279
- if not transcript.startswith("Error"):
280
- wer, der, cer = calculate_metrics(reference_text, transcript)
281
- # Use the standard transcript for highlighting, adjust if needed
282
- html_output, error_words = highlight_errors(reference_text, transcript)
283
- else:
284
- # If the main transcript failed, indicate no metrics/highlighting possible
285
- wer, der, cer = None, None, None
286
- html_output = "Highlighting not available due to transcription error."
287
- error_words = "N/A"
288
-
289
- # --- Return all 7 values ---
290
- return transcript, syllable_transcript, wer, der, cer, html_output, error_words
291
-
292
- # When Transcribe button is clicked
293
- transcribe_btn.click(
294
- fn=process_audio_and_compare,
295
- # Get audio path and the reference text from the state
296
- inputs=[audio_input, reference_text_state],
297
- # Update all 7 output components
298
- outputs=[
299
- transcript_output,
300
- transcript_syllables_output, # This should now update correctly
301
- wer_out,
302
- der_out,
303
- cer_out,
304
- error_html,
305
- error_list
306
- ]
307
- )
308
 
309
- # Launch the app
310
  if __name__ == "__main__":
311
- app.launch(debug=True, ssr_mode=False) # Set share=True if you need a public link
 
1
+ import os
2
+ import sys
3
+ import urllib.request
4
+ import torch
5
  import gradio as gr
 
6
  import jiwer
7
+ import difflib
 
 
8
  import pyarabic.araby as araby
9
+ from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
10
+
11
# ---------- Setup: Clone CATT repo & download diacritization models ----------
import subprocess

CATT_REPO_URL = "https://github.com/abjadai/catt.git"
CATT_FOLDER = "catt"
MODELS_DIR = "models"
ED_URL = "https://github.com/abjadai/catt/releases/download/v2/best_ed_mlm_ns_epoch_178.pt"
EO_URL = "https://github.com/abjadai/catt/releases/download/v2/best_eo_mlm_ns_epoch_193.pt"

os.makedirs(MODELS_DIR, exist_ok=True)

# Clone the CATT sources if the folder is not there yet. An argument list
# (no shell) with check=True makes a failed clone stop right here instead of
# showing up later as an ImportError on the CATT modules.
if not os.path.isdir(CATT_FOLDER):
    subprocess.run(["git", "clone", CATT_REPO_URL, CATT_FOLDER], check=True)
if CATT_FOLDER not in sys.path:
    sys.path.append(CATT_FOLDER)

# Download checkpoints that are missing from MODELS_DIR.
for url in (ED_URL, EO_URL):
    fname = os.path.basename(url)
    dest = os.path.join(MODELS_DIR, fname)
    if not os.path.isfile(dest):
        # Download under a temporary name and rename once complete, so an
        # interrupted transfer is never mistaken for a finished checkpoint
        # by the isfile() check on the next start.
        partial = dest + ".part"
        urllib.request.urlretrieve(url, partial)
        os.replace(partial, dest)
32
+
33
+ # Import CATT modules
34
+ from tashkeel_tokenizer import TashkeelTokenizer
35
+ from utils import remove_non_arabic
36
+ from ed_pl import TashkeelModel as TashkeelModel_ED
37
+ from eo_pl import TashkeelModel as TashkeelModel_EO
38
+
39
+ # Prepare tokenizer & device
40
+ tokenizer = TashkeelTokenizer()
41
+ device = "cuda" if torch.cuda.is_available() else "cpu"
42
+
43
+ # Load diacritization models
44
def load_diacritization_models():
    """Build both CATT diacritization models and publish them as globals.

    Creates the encoder-decoder model (3 layers) and the encoder-only model
    (6 layers), loads each one's checkpoint from ``MODELS_DIR``, puts it in
    eval mode on ``device`` and binds the results to the module-level names
    ``model_ed`` and ``model_eo``.
    """
    global model_ed, model_eo
    seq_limit = 1024

    def _restore(model_cls, layer_count, checkpoint_url):
        # Each checkpoint is stored under the last path component of its URL.
        weights_path = os.path.join(MODELS_DIR, os.path.basename(checkpoint_url))
        net = model_cls(
            tokenizer,
            max_seq_len=seq_limit,
            n_layers=layer_count,
            learnable_pos_emb=False,
        )
        net.load_state_dict(torch.load(weights_path, map_location=device))
        net.eval().to(device)
        return net

    model_ed = _restore(TashkeelModel_ED, 3, ED_URL)
    model_eo = _restore(TashkeelModel_EO, 6, EO_URL)
54
+
55
+ load_diacritization_models()
56
+
57
+ # ---------- Setup: Arabic syllable transcription pipelines ----------
58
+ ASR_PIPE = pipeline("automatic-speech-recognition", model="IbrahimSalah/Arabic_speech_Syllables_recognition_Using_Wav2vec2")
59
+ MT5_MODEL = AutoModelForSeq2SeqLM.from_pretrained("IbrahimSalah/Arabic_Syllables_to_text_Converter_Using_MT5")
60
+ MT5_TOKENIZER = AutoTokenizer.from_pretrained("IbrahimSalah/Arabic_Syllables_to_text_Converter_Using_MT5")
61
+ MT5_MODEL.eval()
62
+
63
# Arabic diacritics set
try:
    DIACRITICS = {
        araby.FATHA, araby.FATHATAN, araby.DAMMA, araby.DAMMATAN,
        araby.KASRA, araby.KASRATAN, araby.SUKUN, araby.SHADDA,
    }
except (NameError, AttributeError):
    # Only the two failures this lookup can produce are handled (module name
    # not bound, or a constant missing); a bare ``except`` would also swallow
    # KeyboardInterrupt/SystemExit. Fallback: code points U+064B..U+0652.
    DIACRITICS = {
        '\u064B', '\u064C', '\u064D', '\u064E',
        '\u064F', '\u0650', '\u0651', '\u0652',
    }
71
+
72
+ # ---------- Core Functions ----------
73
def diacritize_text(model_type, input_text):
    """Diacritize Arabic text with the selected CATT model.

    Args:
        model_type: ``"Encoder-Decoder"`` selects ``model_ed``; any other
            value selects ``model_eo``.
        input_text: Raw text from the UI. It is stripped and passed through
            ``remove_non_arabic`` before inference; ``None`` is treated as
            empty.

    Returns:
        A ``(display_text, reference_text)`` pair, one value per output the
        click handler is wired to (the text box and the reference state).
        When nothing usable was entered, the first item is a prompt message
        and the second is ``""`` so the prompt is never stored as the
        reference.
    """
    text_clean = remove_non_arabic((input_text or "").strip())
    if not text_clean:
        return "Please enter some Arabic text.", ""
    model = model_ed if model_type == "Encoder-Decoder" else model_eo
    out = model.do_tashkeel_batch([text_clean], batch_size=16, verbose=False)
    result = out[0] if out else ""
    # The same string feeds both the visible text box and the stored reference.
    return result, result
83
+
84
+
85
def get_and_process_syllables(audio_path):
    """Transcribe audio with the ASR pipeline, then convert it to text with MT5.

    Args:
        audio_path: Value handed to ``ASR_PIPE`` (the UI supplies a file path).

    Returns:
        ``(text, seq)``: ``text`` is the decoded MT5 output up to the first
        ``'.'``; ``seq`` is the ``|``-delimited sequence that was fed to MT5.
    """
    recognized = ASR_PIPE(audio_path)["text"]
    # MT5 input format: leading "|", "|" in place of spaces, trailing ".".
    seq = "|{}.".format(recognized.replace(" ", "|"))
    encoded = MT5_TOKENIZER.encode(seq, return_tensors="pt")
    generation_options = dict(
        max_length=100,
        early_stopping=True,
        pad_token_id=MT5_TOKENIZER.pad_token_id,
        bos_token_id=MT5_TOKENIZER.bos_token_id,
        eos_token_id=MT5_TOKENIZER.eos_token_id,
    )
    generated = MT5_MODEL.generate(encoded, **generation_options)
    # The first generated id is dropped before decoding.
    decoded = MT5_TOKENIZER.decode(generated[0][1:], skip_special_tokens=True)
    text = decoded.partition('.')[0]
    return text, seq
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
def get_diacritics_sequence(txt):
    """Return the characters of *txt* found in ``DIACRITICS``, in order, joined by single spaces."""
    marks = (ch for ch in txt if ch in DIACRITICS)
    return ' '.join(marks)
104
 
 
 
 
 
 
 
105
 
106
def calculate_metrics(ref, hyp):
    """Compute WER, DER and CER of *hyp* against *ref*, rounded to 4 places.

    DER is ``jiwer.wer`` applied to the space-separated diacritic sequences
    of both strings.

    Returns:
        ``(wer, der, cer)``. A blank reference short-circuits to all zeros
        when the hypothesis is blank too, and to all ones otherwise.
    """
    if not ref.strip():
        return (1.0, 1.0, 1.0) if hyp.strip() else (0.0, 0.0, 0.0)

    word_error = jiwer.wer(ref, hyp)

    ref_marks = get_diacritics_sequence(ref)
    hyp_marks = get_diacritics_sequence(hyp)
    if ref_marks:
        mark_error = jiwer.wer(ref_marks, hyp_marks)
    elif hyp_marks:
        # Diacritics were produced although the reference has none.
        mark_error = 1.0
    else:
        mark_error = 0.0

    char_error = jiwer.cer(ref, hyp)
    return round(word_error, 4), round(mark_error, 4), round(char_error, 4)
114
 
 
 
 
 
115
 
116
def highlight_errors(ref, hyp):
    """Diff the hypothesis against the reference word by word.

    Args:
        ref: Reference text, split on whitespace.
        hyp: Hypothesis text, split on whitespace.

    Returns:
        ``(html, error_words)``. ``html`` is the hypothesis with replaced
        words marked ``#ffcccb`` and inserted words marked ``#ccffcc``.
        ``error_words`` is the sorted, comma-separated set of words involved
        in any replace/delete/insert operation (raw text, not escaped).
    """
    from html import escape  # local import keeps this block self-contained

    def _mark(word, colour):
        return f"<mark style='background-color:{colour};'>{escape(word)}</mark>"

    ref_w, hyp_w = ref.split(), hyp.split()
    matcher = difflib.SequenceMatcher(None, ref_w, hyp_w, autojunk=False)
    out_words, errs = [], []
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == 'equal':
            # Words are escaped because the result is rendered as HTML.
            out_words.extend(escape(w) for w in hyp_w[j1:j2])
        elif tag == 'replace':
            out_words.extend(_mark(w, '#ffcccb') for w in hyp_w[j1:j2])
            errs.extend(ref_w[i1:i2] + hyp_w[j1:j2])
        elif tag == 'delete':
            # Missing words have nothing to show in the hypothesis; list them only.
            errs.extend(ref_w[i1:i2])
        elif tag == 'insert':
            out_words.extend(_mark(w, '#ccffcc') for w in hyp_w[j1:j2])
            errs.extend(hyp_w[j1:j2])
    return ' '.join(out_words), ', '.join(sorted(set(errs)))
132
+
133
+
134
def process_audio_and_compare(audio_path, reference_text):
    """Transcribe a recording and score it against the reference text.

    Args:
        audio_path: Audio value from the UI; falsy means nothing was provided.
        reference_text: Diacritized reference; ``None`` or blank is rejected.

    Returns:
        A 7-tuple ``(transcript, syllables, wer, der, cer, html, error_words)``.
        On any failure the first two items carry the same ``"Error..."``
        message, the metrics are ``None`` and the last two are ``""``.
    """
    def _failure(message):
        return message, message, None, None, None, "", ""

    if not audio_path:
        return _failure("Error: No audio provided.")
    if not (reference_text or "").strip():
        return _failure("Error: No reference text.")

    try:
        hyp, syll = get_and_process_syllables(audio_path)
    except Exception as exc:
        # UI boundary: the transcription helper signals failure by raising,
        # so report it in the output boxes instead of crashing the handler.
        print(f"Error during transcription: {exc}")
        return _failure(f"Error during transcription: {exc}")

    wer, der, cer = calculate_metrics(reference_text, hyp)
    html_out, errs = highlight_errors(reference_text, hyp)
    return hyp, syll, wer, der, cer, html_out, errs
143
+
144
# ---------- Gradio Interface ----------
# Layout: left column diacritizes the input text; right column records or
# uploads audio and shows the transcript, metrics and highlighted errors.
with gr.Blocks(theme=gr.themes.Soft()) as app:
    gr.Markdown("""
    # Arabic Diacritization & Reading Assessment
    1. Enter undiacritized Arabic text, then click **Diacritize Text**.
    2. Read the diacritized text aloud, record or upload the audio, then click **Transcribe & Compare**.
    """)
    # Holds the diacritized reference between the two steps.
    ref_state = gr.State("")

    with gr.Row():
        with gr.Column(scale=1):
            text_in = gr.Textbox(label="Undiacritized Arabic Text", lines=3, text_align="right")
            model_sel = gr.Dropdown(choices=["Encoder-Only", "Encoder-Decoder"], value="Encoder-Only", label="Model")
            diac_btn = gr.Button("Diacritize Text")
            diac_out = gr.Textbox(label="Diacritized Text (Reference)", lines=3, text_align="right")
            # The handler must return one value per listed output.
            diac_btn.click(fn=diacritize_text, inputs=[model_sel, text_in], outputs=[diac_out, ref_state])

        with gr.Column(scale=1):
            audio_in = gr.Audio(label="Record/Upload Audio", type="filepath")
            trans_btn = gr.Button("Transcribe & Compare")
            hyp_out = gr.Textbox(label="Transcript (Hypothesis)", lines=3, text_align="right")
            syl_out = gr.Textbox(label="Transcript Syllables", lines=3, text_align="right")
            wer_n = gr.Number(label="WER", precision=4)
            der_n = gr.Number(label="DER", precision=4)
            cer_n = gr.Number(label="CER", precision=4)
            err_html = gr.HTML(label="Highlighted Errors")
            err_list = gr.Textbox(label="Error Words")

            trans_btn.click(
                fn=process_audio_and_compare,
                inputs=[audio_in, ref_state],
                outputs=[hyp_out, syl_out, wer_n, der_n, cer_n, err_html, err_list],
            )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
+ # Launch
179
  if __name__ == "__main__":
180
+ app.launch(debug=True)