Bisher commited on
Commit
17ddd97
·
verified ·
1 Parent(s): 7b8a9a4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -65
app.py CHANGED
@@ -4,6 +4,10 @@ import jiwer
4
  import os
5
  import time
6
  import warnings
 
 
 
 
7
 
8
  # Suppress specific UserWarnings from jiwer related to empty strings
9
  warnings.filterwarnings("ignore", message="Reference is empty.*", category=UserWarning)
@@ -13,16 +17,29 @@ warnings.filterwarnings("ignore", message="Hypothesis is empty.*", category=User
13
  DIACRITIZATION_API_URL = "Bisher/CATT.diacratization"
14
  TRANSCRIPTION_API_URL = "gh-kaka22/diacritic_level_arabic_transcription"
15
 
16
- # --- Gradio API Clients ---
17
- # It's good practice to initialize clients outside the functions
18
- # if the app runs continuously, but be mindful of potential state issues
19
- # or connection timeouts in long-running deployments. For simplicity here,
20
- # we might re-initialize, though a single initialization is often preferred.
 
 
 
 
 
 
 
 
 
 
 
 
21
 
 
 
22
  def get_diacritization_client():
23
  """Initializes and returns the client for the text diacritization API."""
24
  try:
25
- # Added timeout for robustness
26
  return Client(DIACRITIZATION_API_URL, download_files=True)
27
  except Exception as e:
28
  print(f"Error initializing diacritization client: {e}")
@@ -31,7 +48,6 @@ def get_diacritization_client():
31
  def get_transcription_client():
32
  """Initializes and returns the client for the audio transcription API."""
33
  try:
34
- # Added timeout for robustness
35
  return Client(TRANSCRIPTION_API_URL, download_files=True)
36
  except Exception as e:
37
  print(f"Error initializing transcription client: {e}")
@@ -52,31 +68,26 @@ def diacritize_text_api(text_to_diacritize):
52
  """
53
  if not text_to_diacritize or not text_to_diacritize.strip():
54
  error_msg = "Please enter some text to diacritize."
55
- # Return the error message twice
56
  return error_msg, error_msg
57
 
58
  client = get_diacritization_client()
59
  if not client:
60
  error_msg = "Error: Could not connect to the diacritization service."
61
- # Return the error message twice
62
  return error_msg, error_msg
63
 
64
  try:
65
  print(f"Sending text to diacritization API: {text_to_diacritize}")
66
  result = client.predict(
67
- model_type="Encoder-Only", # Or 'Encoder-Decoder' if preferred
68
  input_text=text_to_diacritize,
69
  api_name="/predict"
70
  )
71
  print(f"Received diacritized text: {result}")
72
- # Ensure result is a string before returning
73
  result_str = str(result) if result is not None else "Error: Received empty response from diacritization service."
74
- # Return the result twice
75
  return result_str, result_str
76
  except Exception as e:
77
  print(f"Error during text diacritization API call: {e}")
78
  error_msg = f"Error during diacritization: {e}"
79
- # Return the error message twice
80
  return error_msg, error_msg
81
 
82
  def transcribe_audio_api(audio_filepath):
@@ -92,7 +103,6 @@ def transcribe_audio_api(audio_filepath):
92
  if not audio_filepath:
93
  return "Error: Please provide an audio recording or file."
94
 
95
- # Check if file exists and is accessible
96
  if not os.path.exists(audio_filepath):
97
  return f"Error: Audio file not found at {audio_filepath}"
98
 
@@ -102,14 +112,11 @@ def transcribe_audio_api(audio_filepath):
102
 
103
  try:
104
  print(f"Sending audio file to transcription API: {audio_filepath}")
105
- # Use handle_file to manage the audio file for the API call
106
  result = client.predict(
107
  audio=handle_file(audio_filepath),
108
  api_name="/predict"
109
  )
110
  print(f"Received transcript: {result}")
111
- # The API might return more structure, adapt if needed. Assuming it returns the text directly.
112
- # Example: if result is {'text': '...'}, use result['text']
113
  if isinstance(result, dict) and 'text' in result:
114
  transcript = result['text']
115
  elif isinstance(result, str):
@@ -118,17 +125,40 @@ def transcribe_audio_api(audio_filepath):
118
  print(f"Unexpected transcription result format: {result}")
119
  return "Error: Unexpected format received from transcription service."
120
 
121
- # Ensure transcript is a string
122
  return str(transcript) if transcript is not None else "Error: Received empty response from transcription service."
123
 
124
  except Exception as e:
125
  print(f"Error during audio transcription API call: {e}")
126
- # Provide more specific error feedback if possible
127
  return f"Error during transcription: {e}"
128
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  def calculate_metrics(reference, hypothesis):
130
  """
131
  Calculates Word Error Rate (WER) and Diacritic Error Rate (DER).
 
132
 
133
  Args:
134
  reference (str): The original diacritized text.
@@ -140,45 +170,53 @@ def calculate_metrics(reference, hypothesis):
140
  # Ensure inputs are strings before proceeding
141
  if not isinstance(reference, str):
142
  print(f"Error: Reference input is not a string (type: {type(reference)}). Value: {reference}")
143
- reference = "" # Default to empty string to avoid downstream errors
144
  if not isinstance(hypothesis, str):
145
  print(f"Error: Hypothesis input is not a string (type: {type(hypothesis)}). Value: {hypothesis}")
146
- hypothesis = "" # Default to empty string
147
 
148
- # Handle empty strings to avoid jiwer warnings/errors if not suppressed
149
  ref_strip = reference.strip()
150
  hyp_strip = hypothesis.strip()
151
 
152
- if not ref_strip and not hyp_strip:
153
- return 0.0, 0.0 # Both empty, 0% error
154
- if not ref_strip:
155
- print("Warning: Reference text is empty.")
156
- # WER/DER are typically 1.0 (or inf) if reference is empty and hypothesis is not.
157
- return 1.0, 1.0
158
- # Note: If hypothesis is empty but reference is not, jiwer calculates WER=1.0, which is correct.
159
 
160
  try:
161
  # 1. Calculate Word Error Rate (WER)
162
- wer = jiwer.wer(reference, hypothesis)
163
-
164
- # 2. Calculate Diacritic Error Rate (DER)
165
- # - Treat each character (including diacritics) as a token.
166
- # - Join characters with spaces to make jiwer treat them as "words".
167
- ref_chars = ' '.join(list(reference))
168
- hyp_chars = ' '.join(list(hypothesis))
169
- # Need to handle potential empty strings after join for jiwer
170
- if not ref_chars.strip() and not hyp_chars.strip():
171
- der = 0.0
172
- elif not ref_chars.strip():
173
- der = 1.0
174
  else:
175
- der = jiwer.wer(ref_chars, hyp_chars)
176
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
- return round(wer, 4), round(der, 4)
 
 
 
179
 
180
  except Exception as e:
181
  print(f"Error calculating metrics: {e}")
 
182
  return None, None
183
 
184
 
@@ -194,30 +232,23 @@ def process_audio_and_compare(audio_input, original_diacritized_text):
194
  der (float | None): Diacritic Error Rate or None if error.
195
  """
196
  print("Processing audio and comparing...")
197
- # Check if original_diacritized_text is valid
198
  if not original_diacritized_text or not isinstance(original_diacritized_text, str) or original_diacritized_text.startswith("Error:"):
199
  error_msg = "Error: Valid reference diacritized text not available. Please diacritize text first."
200
  print(error_msg)
201
- # Return default/error values for all outputs
202
  return error_msg, None, None
203
 
204
- # --- 1. Transcribe Audio ---
205
- # Gradio provides the audio data (e.g., filepath for upload/mic)
206
  transcript = transcribe_audio_api(audio_input)
207
 
208
  if not isinstance(transcript, str) or transcript.startswith("Error:"):
209
- # If transcription failed, return the error and None for metrics
210
  error_msg = transcript if isinstance(transcript, str) else "Error: Transcription failed with non-string output."
211
  print(error_msg)
212
  return error_msg, None, None
213
 
214
- # --- 2. Calculate Metrics ---
215
  wer, der = calculate_metrics(original_diacritized_text, transcript)
216
 
217
  if wer is None or der is None:
218
  print("Metrics calculation failed.")
219
- # Return transcript but indicate metric failure
220
- return transcript, None, None
221
 
222
  print(f"Comparison complete. WER: {wer}, DER: {der}")
223
  return transcript, wer, der
@@ -231,11 +262,12 @@ with gr.Blocks(theme=gr.themes.Soft()) as app:
231
  1. Enter undiacritized Arabic text and click **Diacritize Text**.
232
  2. Read the generated **Diacritized Text** aloud and record it using the microphone or upload an audio file.
233
  3. Click **Transcribe and Compare** to get the transcript and see the WER/DER scores compared to the original diacritized text.
 
 
234
  """
235
  )
236
 
237
- # Store the original diacritized text for comparison later
238
- original_diacritized_state = gr.State("") # Initialize state
239
 
240
  with gr.Row():
241
  with gr.Column(scale=1):
@@ -243,20 +275,20 @@ with gr.Blocks(theme=gr.themes.Soft()) as app:
243
  label="1. Enter Undiacritized Arabic Text",
244
  placeholder="مثال: السلام عليكم",
245
  lines=3,
246
- text_align="right", # Align text right for Arabic
247
  )
248
  diacritize_button = gr.Button("Diacritize Text")
249
  diacritized_text_output = gr.Textbox(
250
  label="2. Diacritized Text (Reference)",
251
  lines=3,
252
- interactive=False, # User shouldn't edit this directly
253
  text_align="right",
254
  )
255
 
256
  with gr.Column(scale=1):
257
  audio_input = gr.Audio(
258
  sources=["microphone", "upload"],
259
- type="filepath", # Get the path to the saved audio file
260
  label="3. Record or Upload Audio of Reading Diacritized Text",
261
  )
262
  transcribe_button = gr.Button("Transcribe and Compare")
@@ -267,26 +299,21 @@ with gr.Blocks(theme=gr.themes.Soft()) as app:
267
  text_align="right",
268
  )
269
  with gr.Row():
270
- # Set precision for number outputs
271
  wer_output = gr.Number(label="Word Error Rate (WER)", interactive=False, precision=4)
272
  der_output = gr.Number(label="Diacritic Error Rate (DER)", interactive=False, precision=4)
273
 
274
 
275
  # --- Connect Components ---
276
-
277
- # Action for Diacritize Button
278
  diacritize_button.click(
279
  fn=diacritize_text_api,
280
  inputs=[text_input],
281
- # Expects two outputs now from the modified function
282
  outputs=[diacritized_text_output, original_diacritized_state]
283
  )
284
 
285
- # Action for Transcribe Button
286
  transcribe_button.click(
287
  fn=process_audio_and_compare,
288
- inputs=[audio_input, original_diacritized_state], # Pass audio and stored text
289
- outputs=[transcript_output, wer_output, der_output] # Update transcript and metrics
290
  )
291
 
292
- app.launch(debug=True, share=True)
 
4
  import os
5
  import time
6
  import warnings
7
+ # Import pyarabic for diacritic identification
8
+ import pyarabic.araby as araby
9
+
10
+
11
 
12
  # Suppress specific UserWarnings from jiwer related to empty strings
13
  warnings.filterwarnings("ignore", message="Reference is empty.*", category=UserWarning)
 
17
  DIACRITIZATION_API_URL = "Bisher/CATT.diacratization"
18
  TRANSCRIPTION_API_URL = "gh-kaka22/diacritic_level_arabic_transcription"
19
 
20
+ # Define the set of Arabic diacritic characters using pyarabic constants if available
21
+ if araby:
22
+ ARABIC_DIACRITICS = {
23
+ araby.FATHA, # U+064E
24
+ araby.FATHATAN, # U+064B
25
+ araby.DAMMA, # U+064F
26
+ araby.DAMMATAN, # U+064C
27
+ araby.KASRA, # U+0650
28
+ araby.KASRATAN, # U+064D
29
+ araby.SUKUN, # U+0652
30
+ araby.SHADDA, # U+0651
31
+ # Consider adding others if needed, e.g., araby.MADDA (U+0653), araby.HAMZA_ABOVE (U+0654), etc.
32
+ # Sticking to the main 8 Tashkeel for now.
33
+ }
34
+ else:
35
+ # Fallback if pyarabic failed to import
36
+ ARABIC_DIACRITICS = {'\u064B', '\u064C', '\u064D', '\u064E', '\u064F', '\u0650', '\u0651', '\u0652'}
37
 
38
+
39
+ # --- Gradio API Clients ---
40
  def get_diacritization_client():
41
  """Initializes and returns the client for the text diacritization API."""
42
  try:
 
43
  return Client(DIACRITIZATION_API_URL, download_files=True)
44
  except Exception as e:
45
  print(f"Error initializing diacritization client: {e}")
 
48
  def get_transcription_client():
49
  """Initializes and returns the client for the audio transcription API."""
50
  try:
 
51
  return Client(TRANSCRIPTION_API_URL, download_files=True)
52
  except Exception as e:
53
  print(f"Error initializing transcription client: {e}")
 
68
  """
69
  if not text_to_diacritize or not text_to_diacritize.strip():
70
  error_msg = "Please enter some text to diacritize."
 
71
  return error_msg, error_msg
72
 
73
  client = get_diacritization_client()
74
  if not client:
75
  error_msg = "Error: Could not connect to the diacritization service."
 
76
  return error_msg, error_msg
77
 
78
  try:
79
  print(f"Sending text to diacritization API: {text_to_diacritize}")
80
  result = client.predict(
81
+ model_type="Encoder-Only",
82
  input_text=text_to_diacritize,
83
  api_name="/predict"
84
  )
85
  print(f"Received diacritized text: {result}")
 
86
  result_str = str(result) if result is not None else "Error: Received empty response from diacritization service."
 
87
  return result_str, result_str
88
  except Exception as e:
89
  print(f"Error during text diacritization API call: {e}")
90
  error_msg = f"Error during diacritization: {e}"
 
91
  return error_msg, error_msg
92
 
93
  def transcribe_audio_api(audio_filepath):
 
103
  if not audio_filepath:
104
  return "Error: Please provide an audio recording or file."
105
 
 
106
  if not os.path.exists(audio_filepath):
107
  return f"Error: Audio file not found at {audio_filepath}"
108
 
 
112
 
113
  try:
114
  print(f"Sending audio file to transcription API: {audio_filepath}")
 
115
  result = client.predict(
116
  audio=handle_file(audio_filepath),
117
  api_name="/predict"
118
  )
119
  print(f"Received transcript: {result}")
 
 
120
  if isinstance(result, dict) and 'text' in result:
121
  transcript = result['text']
122
  elif isinstance(result, str):
 
125
  print(f"Unexpected transcription result format: {result}")
126
  return "Error: Unexpected format received from transcription service."
127
 
 
128
  return str(transcript) if transcript is not None else "Error: Received empty response from transcription service."
129
 
130
  except Exception as e:
131
  print(f"Error during audio transcription API call: {e}")
 
132
  return f"Error during transcription: {e}"
133
 
134
+ def get_diacritics_sequence(text):
135
+ """
136
+ Extracts only the Arabic diacritic characters from a string.
137
+
138
+ Args:
139
+ text (str): The input string potentially containing diacritics.
140
+
141
+ Returns:
142
+ str: A space-separated string of diacritics found in the text.
143
+ Returns an empty string if no diacritics are found or input is not a string.
144
+ """
145
+ if not isinstance(text, str):
146
+ return "" # Return empty string for non-string input
147
+
148
+ # Check if pyarabic was imported successfully
149
+ if not araby and not ARABIC_DIACRITICS:
150
+ print("Warning: pyarabic not loaded, cannot reliably extract diacritics.")
151
+ return "" # Cannot proceed without diacritic definitions
152
+
153
+ diacritics_only = [char for char in text if char in ARABIC_DIACRITICS]
154
+ # Return as a space-separated string for jiwer.wer
155
+ return ' '.join(diacritics_only)
156
+
157
+
158
  def calculate_metrics(reference, hypothesis):
159
  """
160
  Calculates Word Error Rate (WER) and Diacritic Error Rate (DER).
161
+ DER is calculated based *only* on the sequence of diacritic marks.
162
 
163
  Args:
164
  reference (str): The original diacritized text.
 
170
  # Ensure inputs are strings before proceeding
171
  if not isinstance(reference, str):
172
  print(f"Error: Reference input is not a string (type: {type(reference)}). Value: {reference}")
173
+ reference = ""
174
  if not isinstance(hypothesis, str):
175
  print(f"Error: Hypothesis input is not a string (type: {type(hypothesis)}). Value: {hypothesis}")
176
+ hypothesis = ""
177
 
 
178
  ref_strip = reference.strip()
179
  hyp_strip = hypothesis.strip()
180
 
181
+ wer = None
182
+ der = None
 
 
 
 
 
183
 
184
  try:
185
  # 1. Calculate Word Error Rate (WER)
186
+ if not ref_strip and not hyp_strip:
187
+ wer = 0.0
188
+ elif not ref_strip:
189
+ wer = 1.0 # Reference empty, hypothesis not
 
 
 
 
 
 
 
 
190
  else:
191
+ # Jiwer handles hyp_strip being empty if ref_strip is not
192
+ wer = jiwer.wer(reference, hypothesis)
193
+
194
+ # 2. Calculate Diacritic Error Rate (DER) based *only* on diacritics
195
+ ref_diacritics = get_diacritics_sequence(reference)
196
+ hyp_diacritics = get_diacritics_sequence(hypothesis)
197
+
198
+ ref_diacritics_strip = ref_diacritics.strip()
199
+ hyp_diacritics_strip = hyp_diacritics.strip()
200
+
201
+ if not ref_diacritics_strip and not hyp_diacritics_strip:
202
+ der = 0.0 # No diacritics in either reference or hypothesis
203
+ elif not ref_diacritics_strip:
204
+ # Reference has no diacritics, but hypothesis does. DER is 1.0 (all hyp diacritics are insertions).
205
+ der = 1.0
206
+ print("Warning: No diacritics found in reference text for DER calculation.")
207
+ else:
208
+ # Reference has diacritics. Jiwer calculates WER on the diacritic sequences.
209
+ # If hypothesis has no diacritics, jiwer.wer will be 1.0 (all ref diacritics deleted).
210
+ der = jiwer.wer(ref_diacritics, hyp_diacritics) # Use the space-separated strings
211
 
212
+ # Round the results if they were calculated successfully
213
+ wer_rounded = round(wer, 4) if wer is not None else None
214
+ der_rounded = round(der, 4) if der is not None else None
215
+ return wer_rounded, der_rounded
216
 
217
  except Exception as e:
218
  print(f"Error calculating metrics: {e}")
219
+ # Return None if any exception occurred during calculation
220
  return None, None
221
 
222
 
 
232
  der (float | None): Diacritic Error Rate or None if error.
233
  """
234
  print("Processing audio and comparing...")
 
235
  if not original_diacritized_text or not isinstance(original_diacritized_text, str) or original_diacritized_text.startswith("Error:"):
236
  error_msg = "Error: Valid reference diacritized text not available. Please diacritize text first."
237
  print(error_msg)
 
238
  return error_msg, None, None
239
 
 
 
240
  transcript = transcribe_audio_api(audio_input)
241
 
242
  if not isinstance(transcript, str) or transcript.startswith("Error:"):
 
243
  error_msg = transcript if isinstance(transcript, str) else "Error: Transcription failed with non-string output."
244
  print(error_msg)
245
  return error_msg, None, None
246
 
 
247
  wer, der = calculate_metrics(original_diacritized_text, transcript)
248
 
249
  if wer is None or der is None:
250
  print("Metrics calculation failed.")
251
+ return transcript, None, None # Return transcript but None for metrics
 
252
 
253
  print(f"Comparison complete. WER: {wer}, DER: {der}")
254
  return transcript, wer, der
 
262
  1. Enter undiacritized Arabic text and click **Diacritize Text**.
263
  2. Read the generated **Diacritized Text** aloud and record it using the microphone or upload an audio file.
264
  3. Click **Transcribe and Compare** to get the transcript and see the WER/DER scores compared to the original diacritized text.
265
+
266
+ **Note:** Requires `pyarabic` library (`pip install pyarabic`) for accurate Diacritic Error Rate (DER) calculation.
267
  """
268
  )
269
 
270
+ original_diacritized_state = gr.State("")
 
271
 
272
  with gr.Row():
273
  with gr.Column(scale=1):
 
275
  label="1. Enter Undiacritized Arabic Text",
276
  placeholder="مثال: السلام عليكم",
277
  lines=3,
278
+ text_align="right",
279
  )
280
  diacritize_button = gr.Button("Diacritize Text")
281
  diacritized_text_output = gr.Textbox(
282
  label="2. Diacritized Text (Reference)",
283
  lines=3,
284
+ interactive=False,
285
  text_align="right",
286
  )
287
 
288
  with gr.Column(scale=1):
289
  audio_input = gr.Audio(
290
  sources=["microphone", "upload"],
291
+ type="filepath",
292
  label="3. Record or Upload Audio of Reading Diacritized Text",
293
  )
294
  transcribe_button = gr.Button("Transcribe and Compare")
 
299
  text_align="right",
300
  )
301
  with gr.Row():
 
302
  wer_output = gr.Number(label="Word Error Rate (WER)", interactive=False, precision=4)
303
  der_output = gr.Number(label="Diacritic Error Rate (DER)", interactive=False, precision=4)
304
 
305
 
306
  # --- Connect Components ---
 
 
307
  diacritize_button.click(
308
  fn=diacritize_text_api,
309
  inputs=[text_input],
 
310
  outputs=[diacritized_text_output, original_diacritized_state]
311
  )
312
 
 
313
  transcribe_button.click(
314
  fn=process_audio_and_compare,
315
+ inputs=[audio_input, original_diacritized_state],
316
+ outputs=[transcript_output, wer_output, der_output]
317
  )
318
 
319
+ app.launch(debug=True, share=True)