Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| from transformers import AutoModelForSeq2SeqLM | |
| from huggingface_hub import InferenceClient | |
# Character vocabulary for the tokenizer.
#
# Token ids are positional: the four special tokens take ids 0-3 and every
# character of the literal below takes the next id in order. Do not reorder,
# deduplicate or normalize this data -- presumably the ids must line up with
# the ones the checkpoint was trained with (TODO confirm against training code).
special_tokens = ["<pad>", "<s>", "</s>", "<unk>"]
nepali_chars = list("अआइईउऊऋॠऌॡऎएऐओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरलवशषसह्ािीुूृॄेैोौंंःँ।०१२३४५६७८९,.;?!़ॅंःॊॅऒऽॉड़ॐ॥ऑऱफ़ढ़")
char_vocab = special_tokens + nepali_chars

# NOTE: the literal above repeats some characters. For a repeated character
# the dict comprehension keeps only its LAST position, so char2id has fewer
# entries than char_vocab and the earlier positions have no id2char entry
# (such ids decode as "<unk>").
char2id = {char: idx for idx, char in enumerate(char_vocab)}
id2char = {idx: char for char, idx in char2id.items()}
class CharTokenizer:
    """Character-level tokenizer backed by explicit char<->id lookup tables.

    Each character of the input is one token. Characters missing from
    ``char2id`` encode to the id stored under ``"<unk>"``; ids missing from
    ``id2char`` decode to the literal string ``"<unk>"``.
    """

    # Special symbols that are markers, not text, and must never be written
    # into human-readable output.
    _SILENT_TOKENS = frozenset({"<pad>", "<s>", "</s>"})

    def __init__(self, char2id, id2char):
        """Store the lookup tables.

        Args:
            char2id: Mapping from character to token id; must contain the
                key ``"<unk>"``.
            id2char: Mapping from token id back to character.
        """
        self.char2id = char2id
        self.id2char = id2char

    def encode(self, text):
        """Return the list of token ids for *text*, one id per character."""
        return [self.char2id.get(char, self.char2id["<unk>"]) for char in text]

    def decode(self, tokens):
        """Return the raw decoding of *tokens*, special symbols included."""
        return "".join(self.id2char.get(token, "<unk>") for token in tokens)

    def decodex(self, tokens):
        """Return a display-ready decoding of *tokens*.

        Rules applied per token:

        * ``<unk>`` becomes a single space, except when it is the first or
          last token or directly follows another ``<unk>`` (then it is
          dropped). Presumably this restores spaces, which are not part of
          the vocabulary and therefore encode as ``<unk>`` -- TODO confirm.
        * ``<pad>``, ``<s>`` and ``</s>`` are dropped. (Previously only
          ``<pad>`` was dropped, so sequence markers leaked into the text.)
        * Every other symbol is emitted unchanged.

        Args:
            tokens: Sized sequence of token ids.

        Returns:
            The decoded string; ``""`` for an empty sequence.
        """
        pieces = []
        last_index = len(tokens) - 1
        previous = None  # symbol decoded in the previous iteration
        for index, token in enumerate(tokens):
            symbol = self.id2char.get(token, "<unk>")
            if symbol == "<unk>":
                at_edge = index == 0 or index == last_index
                if not at_edge and previous != "<unk>":
                    pieces.append(" ")
            elif symbol not in self._SILENT_TOKENS:
                pieces.append(symbol)
            previous = symbol
        return "".join(pieces)
# Build the module-wide tokenizer from the vocabulary lookup tables.
tokenizer = CharTokenizer(char2id, id2char)

# Load the seq2seq checkpoint used for correction (a T5 model, per its name).
model_name = "bashyaldhiraj2067/t5_char_nepali"
# model_name = "bashyaldhiraj2067/attention_epoch_2_xpu_64_copymechanism_nepali_GEC_new_21"
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
def correct_text(input_text, max_length=256):
    """Run the seq2seq model on *input_text* and return the decoded output.

    The text is encoded one character per token with the module-level
    ``tokenizer``, passed to ``model.generate`` as a batch of one, and the
    first generated sequence is decoded with ``tokenizer.decodex``.

    Args:
        input_text: Text to correct. Empty or ``None`` short-circuits.
        max_length: Value forwarded to ``model.generate`` as ``max_length``.
            Coerced to ``int`` because the UI slider may deliver a float.

    Returns:
        The decoded model output, or ``""`` when there is no input.
    """
    if not input_text:
        # torch.tensor([[]]) would be an empty float tensor, which is not a
        # valid batch of token ids -- skip the model entirely.
        return ""

    input_ids = tokenizer.encode(input_text)
    # Token ids must be integers; state the dtype instead of relying on
    # inference from the Python list.
    input_tensor = torch.tensor([input_ids], dtype=torch.long)

    with torch.no_grad():  # inference only, no gradients needed
        outputs = model.generate(
            input_tensor,
            max_length=int(max_length),
            return_dict_in_generate=True,
        )

    generated_tokens = outputs.sequences[0].tolist()
    return tokenizer.decodex(generated_tokens)
# Web UI: a text box and a max-length slider feed correct_text; the result
# is shown in a second text box.
demo = gr.Interface(
    fn=correct_text,
    inputs=[
        gr.Textbox(label="Enter Nepali Text"),
        gr.Slider(50, 256, step=10, label="Max Length"),
    ],
    outputs=gr.Textbox(label="Corrected Text"),
    title="Nepali Text Correction",
    description="Enter text with errors and get corrected output using a T5 model trained on Nepali text.",
)

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()