Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| import urllib.request | |
| import torch | |
| import gradio as gr | |
| import jiwer | |
| import difflib | |
| import pyarabic.araby as araby | |
| from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer | |
# ---------- Setup: Clone CATT repo & download diacritization models ----------
CATT_REPO_URL = "https://github.com/abjadai/catt.git"
CATT_FOLDER = "catt"
MODELS_DIR = "models"
ED_URL = "https://github.com/abjadai/catt/releases/download/v2/best_ed_mlm_ns_epoch_178.pt"
EO_URL = "https://github.com/abjadai/catt/releases/download/v2/best_eo_mlm_ns_epoch_193.pt"


def _ensure_catt_assets():
    """Make the CATT sources importable and fetch both model checkpoints.

    Steps, each skipped when its result is already on disk:

    1. Create ``MODELS_DIR``.
    2. Clone ``CATT_REPO_URL`` into ``CATT_FOLDER`` and append that folder to
       ``sys.path`` so the CATT modules can be imported by bare name.
    3. Download ``ED_URL`` and ``EO_URL`` into ``MODELS_DIR``.

    Raises:
        subprocess.CalledProcessError: if ``git clone`` exits non-zero.
        OSError: if ``git`` cannot be started or a download fails
            (``urllib.error.URLError`` is an ``OSError`` subclass).
    """
    # Imported here because the file's top-level import block does not
    # provide subprocess.
    import subprocess

    os.makedirs(MODELS_DIR, exist_ok=True)

    # Clone if needed. An argument list (no shell) with check=True makes a
    # failed clone stop start-up here instead of surfacing later as an
    # ImportError for the CATT modules.
    if not os.path.isdir(CATT_FOLDER):
        subprocess.run(["git", "clone", CATT_REPO_URL, CATT_FOLDER], check=True)
    if CATT_FOLDER not in sys.path:
        sys.path.append(CATT_FOLDER)

    # Download checkpoints. Each file is written to a temporary name and only
    # renamed once complete, so an interrupted download is never mistaken for
    # a finished checkpoint on the next start.
    for url in (ED_URL, EO_URL):
        dest = os.path.join(MODELS_DIR, os.path.basename(url))
        if os.path.isfile(dest):
            continue
        partial = dest + ".part"
        try:
            urllib.request.urlretrieve(url, partial)
            os.replace(partial, dest)
        finally:
            # Only still present when the download or rename failed.
            if os.path.exists(partial):
                os.remove(partial)


_ensure_catt_assets()
| # Import CATT modules | |
| from tashkeel_tokenizer import TashkeelTokenizer | |
| from utils import remove_non_arabic | |
| from ed_pl import TashkeelModel as TashkeelModel_ED | |
| from eo_pl import TashkeelModel as TashkeelModel_EO | |
# Shared tokenizer instance, plus the torch device string: "cuda" when CUDA is
# available, otherwise "cpu".
tokenizer = TashkeelTokenizer()
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
| # Load diacritization models | |
def load_diacritization_models():
    """Build both CATT models, load their checkpoints and publish them as globals.

    Sets the module-level names ``model_ed`` (``TashkeelModel_ED``, 3 layers)
    and ``model_eo`` (``TashkeelModel_EO``, 6 layers). Each model is created
    with a maximum sequence length of 1024, receives the state dict stored
    under ``MODELS_DIR``, is switched to eval mode and moved to ``device``.
    """
    global model_ed, model_eo

    def _build(model_cls, checkpoint_url, layer_count):
        # Both variants share everything except class, depth and checkpoint.
        net = model_cls(
            tokenizer,
            max_seq_len=1024,
            n_layers=layer_count,
            learnable_pos_emb=False,
        )
        checkpoint_path = os.path.join(MODELS_DIR, os.path.basename(checkpoint_url))
        # NOTE(review): torch.load unpickles the file; the checkpoints are
        # assumed to be trusted release artifacts - confirm.
        weights = torch.load(checkpoint_path, map_location=device)
        net.load_state_dict(weights)
        net.eval().to(device)
        return net

    model_ed = _build(TashkeelModel_ED, ED_URL, 3)
    model_eo = _build(TashkeelModel_EO, EO_URL, 6)
load_diacritization_models()

# ---------- Setup: Arabic syllable transcription pipelines ----------
# Hub checkpoint ids: speech -> syllables (Wav2Vec2) and syllables -> text (MT5).
_ASR_CHECKPOINT = "IbrahimSalah/Arabic_speech_Syllables_recognition_Using_Wav2vec2"
_MT5_CHECKPOINT = "IbrahimSalah/Arabic_Syllables_to_text_Converter_Using_MT5"

ASR_PIPE = pipeline("automatic-speech-recognition", model=_ASR_CHECKPOINT)
MT5_MODEL = AutoModelForSeq2SeqLM.from_pretrained(_MT5_CHECKPOINT)
MT5_TOKENIZER = AutoTokenizer.from_pretrained(_MT5_CHECKPOINT)
# Switch the seq2seq model to eval mode.
MT5_MODEL.eval()
| # Arabic diacritics set | |
# Characters counted as diacritics when computing the diacritic error rate.
try:
    DIACRITICS = {
        araby.FATHA, araby.FATHATAN, araby.DAMMA, araby.DAMMATAN,
        araby.KASRA, araby.KASRATAN, araby.SUKUN, araby.SHADDA,
    }
except AttributeError:
    # Only a missing constant on the araby module should trigger the
    # fallback; a bare ``except`` would also hide KeyboardInterrupt,
    # SystemExit and unrelated bugs. The literal covers U+064B-U+0652
    # (presumably the same eight marks the araby constants name - verify
    # against pyarabic).
    DIACRITICS = {'\u064B','\u064C','\u064D','\u064E','\u064F','\u0650','\u0651','\u0652'}
| # ---------- Core Functions ---------- | |
def diacritize_text(model_type, input_text):
    """Diacritize *input_text* with the selected CATT model.

    Args:
        model_type: ``"Encoder-Decoder"`` selects ``model_ed``; any other
            value selects ``model_eo``.
        input_text: Raw text from the UI. ``None`` is treated like an empty
            string.

    Returns:
        A 2-tuple. On success both items are the diacritized text (one for
        display, one for state storage). When nothing is left after
        ``remove_non_arabic`` the tuple is
        ``("Please enter some Arabic text.", "")``.
    """
    # Guard against None so a missing value gets the friendly message
    # instead of raising AttributeError on .strip().
    text_clean = remove_non_arabic((input_text or "").strip())
    if not text_clean:
        return "Please enter some Arabic text.", ""

    model = model_ed if model_type == "Encoder-Decoder" else model_eo
    outputs = model.do_tashkeel_batch([text_clean], batch_size=16, verbose=False)

    # One input sentence was submitted; tolerate an empty result list.
    result = outputs[0] if outputs else ""
    return result, result
def get_and_process_syllables(audio_path, max_length=100):
    """Transcribe an audio file to Arabic text via an intermediate syllable sequence.

    Pipeline: ``ASR_PIPE`` produces a space-separated transcript, which is
    rewritten as a ``|``-delimited sequence terminated by ``.`` and passed to
    ``MT5_MODEL`` for conversion to text.

    Args:
        audio_path: Path of the audio file handed to ``ASR_PIPE``.
        max_length: Generation length cap passed to ``MT5_MODEL.generate``.
            Defaults to 100, the value that was previously hard-coded; raise
            it for longer recordings.

    Returns:
        ``(text, seq)``: the decoded text up to the first ``'.'`` and the
        syllable sequence that was fed to the MT5 model.
    """
    # ASR -> syllable sequence -> MT5 conversion
    clip = ASR_PIPE(audio_path)["text"]
    seq = "|" + clip.replace(" ", "|") + "."
    input_ids = MT5_TOKENIZER.encode(seq, return_tensors="pt")
    out_ids = MT5_MODEL.generate(
        input_ids,
        max_length=max_length,
        early_stopping=True,
        pad_token_id=MT5_TOKENIZER.pad_token_id,
        bos_token_id=MT5_TOKENIZER.bos_token_id,
        eos_token_id=MT5_TOKENIZER.eos_token_id,
    )
    # The first generated id is skipped (presumably the decoder start token -
    # confirm against the model card); everything after the first '.' is dropped.
    decoded = MT5_TOKENIZER.decode(out_ids[0][1:], skip_special_tokens=True)
    text = decoded.split('.')[0]
    return text, seq
def get_diacritics_sequence(txt):
    """Return the characters of *txt* found in ``DIACRITICS``, in order, space-separated."""
    marks = filter(DIACRITICS.__contains__, txt)
    return ' '.join(marks)
def calculate_metrics(ref, hyp):
    """Compute word, diacritic and character error rates of *hyp* against *ref*.

    Returns:
        ``(wer, der, cer)``, each rounded to 4 decimals. A blank reference
        short-circuits: all zeros when the hypothesis is blank too, all ones
        otherwise. DER is ``jiwer.wer`` over the space-separated diacritic
        sequences of both strings; it is 0.0 when neither has diacritics and
        1.0 when only the hypothesis has them.
    """
    if not ref.strip():
        # Nothing to compare against: perfect only if the hypothesis is blank too.
        if hyp.strip():
            return 1.0, 1.0, 1.0
        return 0.0, 0.0, 0.0

    word_err = jiwer.wer(ref, hyp)

    ref_marks = get_diacritics_sequence(ref)
    hyp_marks = get_diacritics_sequence(hyp)
    if ref_marks:
        diac_err = jiwer.wer(ref_marks, hyp_marks)
    elif hyp_marks:
        diac_err = 1.0
    else:
        diac_err = 0.0

    char_err = jiwer.cer(ref, hyp)
    return round(word_err, 4), round(diac_err, 4), round(char_err, 4)
def highlight_errors(ref, hyp):
    """Mark the words of *hyp* that differ from *ref* and list the words involved.

    Both strings are split on whitespace and aligned with
    ``difflib.SequenceMatcher``.

    Args:
        ref: Reference text.
        hyp: Hypothesis text.

    Returns:
        ``(html, errors)``. ``html`` is the hypothesis with replaced words
        wrapped in a red-tinted ``<mark>`` and inserted words in a green-tinted
        ``<mark>``; words present only in the reference do not appear in it.
        ``errors`` is the sorted, de-duplicated, comma-separated list of all
        reference and hypothesis words involved in a replace, delete or
        insert.
    """
    import html

    def _mark(word, colour):
        return f"<mark style='background-color:{colour};'>{html.escape(word)}</mark>"

    ref_w, hyp_w = ref.split(), hyp.split()
    matcher = difflib.SequenceMatcher(None, ref_w, hyp_w, autojunk=False)
    out_words, errs = [], []
    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
        if tag == 'equal':
            # Every word placed in the HTML string is escaped so that
            # markup-like text in the hypothesis is shown literally.
            out_words.extend(html.escape(w) for w in hyp_w[j1:j2])
        elif tag == 'replace':
            out_words.extend(_mark(w, "#ffcccb") for w in hyp_w[j1:j2])
            errs.extend(ref_w[i1:i2] + hyp_w[j1:j2])
        elif tag == 'delete':
            errs.extend(ref_w[i1:i2])
        elif tag == 'insert':
            out_words.extend(_mark(w, "#ccffcc") for w in hyp_w[j1:j2])
            errs.extend(hyp_w[j1:j2])
    # The error list is plain text, so it keeps the raw words.
    return ' '.join(out_words), ', '.join(sorted(set(errs)))
def process_audio_and_compare(audio_path, reference_text):
    """Transcribe *audio_path* and score the transcript against *reference_text*.

    Returns:
        A 7-tuple ``(hypothesis, syllables, wer, der, cer, html, error_words)``.
        When the audio path is falsy or the reference is blank, the first two
        items carry the error message, the three metrics are ``None`` and the
        last two items are empty strings.
    """
    def _failure(message):
        # Same message in both text slots, no metrics, no highlighting.
        return message, message, None, None, None, "", ""

    if not audio_path:
        return _failure("Error: No audio provided.")
    if not reference_text.strip():
        return _failure("Error: No reference text.")

    hyp, syll = get_and_process_syllables(audio_path)

    if hyp.startswith("Error"):
        # A transcript that looks like an error message is passed through unscored.
        wer = der = cer = None
        html_out, errs = "", ""
    else:
        wer, der, cer = calculate_metrics(reference_text, hyp)
        html_out, errs = highlight_errors(reference_text, hyp)

    return hyp, syll, wer, der, cer, html_out, errs
| # ---------- Gradio Interface ---------- | |
# Two-column UI: the left column produces the (editable) diacritized reference
# text, the right column transcribes audio and compares it with that reference.
with gr.Blocks(theme=gr.themes.Soft()) as app:
    # Usage instructions shown above both columns.
    gr.Markdown("""
# Arabic Diacritization & Reading Assessment
1. Enter undiacritized Arabic text → Diacritize.
2. Optionally edit the diacritized result.
3. Record/upload audio → Transcribe & Compare.
""")
    # Reference text handed to process_audio_and_compare; starts empty.
    ref_state = gr.State("")
    with gr.Row():
        # Left column: diacritization.
        with gr.Column(scale=1):
            text_in = gr.Textbox(label="Undiacritized Arabic Text", lines=3, text_align="right")
            model_sel = gr.Dropdown(choices=["Encoder-Only","Encoder-Decoder"], value="Encoder-Only", label="Model")
            diac_btn = gr.Button("Diacritize Text")
            diac_out = gr.Textbox(label="Diacritized Text (Reference)", lines=3, text_align="right", interactive=True)
            # The button fills both the visible textbox and the stored reference.
            diac_btn.click(fn=diacritize_text, inputs=[model_sel, text_in], outputs=[diac_out, ref_state])
            # Changes to the textbox are mirrored into the stored reference.
            # NOTE(review): if `change` also fires on programmatic updates
            # (confirm against the Gradio docs), the placeholder message that
            # diacritize_text returns for empty input would be copied into
            # ref_state as well.
            diac_out.change(fn=lambda text: text, inputs=diac_out, outputs=ref_state)
        # Right column: audio transcription and comparison results.
        with gr.Column(scale=1):
            audio_in = gr.Audio(label="Record/Upload Audio", type="filepath")
            trans_btn = gr.Button("Transcribe & Compare")
            hyp_out = gr.Textbox(label="Transcript (Hypothesis)", lines=3, text_align="right")
            syl_out = gr.Textbox(label="Transcript Syllables", lines=3, text_align="right")
            wer_n = gr.Number(label="WER", precision=4)
            der_n = gr.Number(label="DER", precision=4)
            cer_n = gr.Number(label="CER", precision=4)
            err_html = gr.HTML(label="Highlighted Errors")
            err_list = gr.Textbox(label="Error Words")
            # The output order must match the 7-tuple returned by
            # process_audio_and_compare.
            trans_btn.click(
                fn=process_audio_and_compare,
                inputs=[audio_in, ref_state],
                outputs=[hyp_out, syl_out, wer_n, der_n, cer_n, err_html, err_list]
            )
| # Launch | |
def main():
    """Launch ``app`` with ``debug=True`` and ``share=True``."""
    app.launch(debug=True, share=True)


if __name__ == "__main__":
    main()