Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| from peft import PeftModel, PeftConfig | |
| from transformers import AutoModelForTokenClassification | |
| def test_mask(model, sample): | |
| """ | |
| Masks the padded tokens in the input. | |
| Args: | |
| data (list): List of strings. | |
| Returns: | |
| dataset (list): List of dictionaries. | |
| """ | |
| tokens = dict() | |
| input_tokens = [i + 3 for i in sample.encode('utf-8')] | |
| input_tokens.append(0) # eos token | |
| tokens['input_ids'] = torch.tensor([input_tokens], dtype=torch.int64, device=model.device) | |
| # Create attention mask | |
| tokens['attention_mask'] = torch.ones_like(tokens['input_ids'], dtype=torch.int64, device=model.device) | |
| return tokens | |
def rewrite(model, data):
    """Restore Turkish diacritics in encoded text using the model's labels.

    Args:
        model: Callable invoked as ``model(**data)``; its result must expose
            a ``logits`` tensor. Assumed shape is (1, seq_len, num_labels),
            since the code takes argmax over dim 2 and squeezes dim 0.
        data (dict): Dictionary containing 'input_ids' and 'attention_mask'.
            Only a single sequence (batch size 1) is handled.

    Returns:
        str: The decoded text, where every position with a non-zero
        predicted label and an ASCII letter that has a Turkish counterpart
        is replaced by that counterpart.
    """
    plain_letters = 'UIGSCOuigsco'
    turkish_letters = 'ÜİĞŞÇÖüığşçö'
    # ASCII code point -> UTF-8 byte values of the matching Turkish letter.
    replacements = {
        ord(plain): list(turkish.encode())
        for plain, turkish in zip(plain_letters, turkish_letters)
    }

    with torch.no_grad():
        logits = model(**data).logits
        labels = logits.argmax(dim=2).squeeze(0)

    # Undo the +3 id offset; special ids (0, 1, 2) become negative.
    byte_values = (data['input_ids'].squeeze(0) - 3).tolist()

    restored = []  # UTF-8 byte values of the output text
    for byte_value, flagged in zip(byte_values, labels.tolist()):
        if flagged and byte_value in replacements:
            # The model marked this letter as needing its diacritic form.
            restored.extend(replacements[byte_value])
        elif byte_value >= 0:
            # Ordinary byte: copy through. Negative values are special
            # tokens and are dropped.
            restored.append(byte_value)
    return bytes(restored).decode()
def try_it(text):
    """Run the full encode -> predict -> decode pipeline on ``text``.

    Uses the module-level ``model``, which this file assigns only inside
    the ``if __name__ == '__main__'`` block; calling this function before
    that assignment raises ``NameError``.

    Args:
        text (str): Input text.

    Returns:
        str: The text returned by ``rewrite``.
    """
    return rewrite(model, test_mask(model, text))
if __name__ == '__main__':
    # Hub repository ids: the fine-tuned adapter and the base checkpoint.
    ADAPTER_REPO = "bite-the-byte/byt5-small-deASCIIfy-TR"
    BASE_REPO = "google/byt5-small"

    # The adapter config is loaded but not used further in this script.
    config = PeftConfig.from_pretrained(ADAPTER_REPO)

    # Load the base token-classification model, then attach the adapter.
    # `model` is read as a global by try_it().
    model = PeftModel.from_pretrained(
        AutoModelForTokenClassification.from_pretrained(BASE_REPO),
        ADAPTER_REPO,
    )

    # Minimal text-in / text-out web UI around try_it().
    diacritize_app = gr.Interface(fn=try_it, inputs="text", outputs="text")
    diacritize_app.launch(share=True)