import datasets
import gradio as gr
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer, DataCollatorForLanguageModeling
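
# Stream the English OSCAR subset so nothing is downloaded up front; shuffle with a
# small buffer and expose it as an iterator for the "random sentence" button below.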
ds = datasets.load_dataset("nthngdy/oscar-small", "unshuffled_deduplicated_en", streaming=True, split="train")
ds = ds.shuffle(buffer_size=1000)
ds = iter(ds)
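
# Load the masked-LM checkpoint. The collator randomly masks tokens (15% by default)
# and sets the labels of untouched positions to -100 so they are ignored.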
model_name = "RomanCast/roberta-en-100k"
model = AutoModelForMaskedLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
collate_fn = DataCollatorForLanguageModeling(tokenizer)

with gr.Blocks() as demo:
    inputs_oscar = gr.TextArea(
        placeholder="Type a sentence or click the button below to get a random sentence from the English OSCAR corpus",
        label="Input",
        lines=6,
        interactive=True,
    )
    next_button = gr.Button("Random OSCAR sentence")
    next_button.click(fn=lambda: next(ds)["text"], outputs=inputs_oscar)
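
    # Displays the input after random masking, so the user can see which tokens the model had to predict.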
    masked_text = gr.Textbox(label="Masked sentence")
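
    # Two rows of four columns; each column pairs a Textbox (the original token) with a
    # Label (the model's top-5 predictions), so up to 8 masked positions can be displayed.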
    labels_and_outputs = []
    for _ in range(2):
        with gr.Row():
            for _ in range(4):
                with gr.Column():
                    labels_and_outputs.append(gr.Textbox(label="Label"))
                    labels_and_outputs.append(gr.Label(num_top_classes=5, show_label=False))
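
    # Mask a random subset of the input tokens, run the model, and route the decoded
    # masked sentence plus the per-position predictions to the components above.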
    def model_inputs_and_outputs(example):
        token_ids = tokenizer(example, return_tensors="pt", truncation=True, max_length=128)
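        # The collator expects a batch, so wrap the single encoding in a tuple;
        # the extra batch dimension it adds is stripped on the next line.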
        model_inputs = collate_fn((token_ids,))
        model_inputs = {k: v[0] for k, v in model_inputs.items()}
        masked_tokens = tokenizer.batch_decode(model_inputs["input_ids"])[0]
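        # Labels are -100 everywhere except at masked positions; keep only the real targets.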
        original_labels = [
            tokenizer.convert_ids_to_tokens([tok_id])[0]
            for tok_id in model_inputs["labels"][0]
            if tok_id != -100
        ]
        with torch.no_grad():
            out = model(**model_inputs)
        all_logits = out.logits[model_inputs["labels"] != -100].softmax(-1)
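        # Map every vocabulary token to its probability; gr.Label keeps only the top 5.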
        all_outputs = [
            {tokenizer.convert_ids_to_tokens([tok_id])[0]: val.item() for tok_id, val in enumerate(logits)}
            for logits in all_logits
        ]
        out_dict = {masked_text: masked_tokens}
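        # Fill each Textbox/Label pair in order; pairs beyond the number of masked tokens are blanked out.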
        for i in range(len(labels_and_outputs) // 2):
            try:
                out_dict[labels_and_outputs[2 * i]] = original_labels[i]
                out_dict[labels_and_outputs[2 * i + 1]] = all_outputs[i]
            except IndexError:
                out_dict[labels_and_outputs[2 * i]] = ""
                out_dict[labels_and_outputs[2 * i + 1]] = {}
        return out_dict

    button = gr.Button("Predict tokens")
    button.click(fn=model_inputs_and_outputs, inputs=inputs_oscar, outputs=[masked_text] + labels_and_outputs)

demo.launch()