import torch

print(f"Is CUDA available: {torch.cuda.is_available()}")
# True
print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
| STYLE = """ | |
| .container { | |
| width: 100%; | |
| display: grid; | |
| align-items: center; | |
| margin: 0!important; | |
| } | |
| .prose ul ul { | |
| margin: 0!important; | |
| } | |
| .tree { | |
| padding: 0px; | |
| margin: 0!important; | |
| box-sizing: border-box; | |
| font-size: 16px; | |
| width: 100%; | |
| height: auto; | |
| text-align: center; | |
| } | |
| .tree ul { | |
| padding-top: 20px; | |
| position: relative; | |
| transition: .5s; | |
| margin: 0!important; | |
| } | |
| .tree li { | |
| display: inline-table; | |
| text-align: center; | |
| list-style-type: none; | |
| position: relative; | |
| padding: 10px; | |
| transition: .5s; | |
| } | |
| .tree li::before, .tree li::after { | |
| content: ''; | |
| position: absolute; | |
| top: 0; | |
| right: 50%; | |
| border-top: 1px solid #ccc; | |
| width: 51%; | |
| height: 10px; | |
| } | |
| .tree li::after { | |
| right: auto; | |
| left: 50%; | |
| border-left: 1px solid #ccc; | |
| } | |
| .tree li:only-child::after, .tree li:only-child::before { | |
| display: none; | |
| } | |
| .tree li:only-child { | |
| padding-top: 0; | |
| } | |
| .tree li:first-child::before, .tree li:last-child::after { | |
| border: 0 none; | |
| } | |
| .tree li:last-child::before { | |
| border-right: 1px solid #ccc; | |
| border-radius: 0 5px 0 0; | |
| -webkit-border-radius: 0 5px 0 0; | |
| -moz-border-radius: 0 5px 0 0; | |
| } | |
| .tree li:first-child::after { | |
| border-radius: 5px 0 0 0; | |
| -webkit-border-radius: 5px 0 0 0; | |
| -moz-border-radius: 5px 0 0 0; | |
| } | |
| .tree ul ul::before { | |
| content: ''; | |
| position: absolute; | |
| top: 0; | |
| left: 50%; | |
| border-left: 1px solid #ccc; | |
| width: 0; | |
| height: 20px; | |
| } | |
| .tree li a { | |
| border: 1px solid #ccc; | |
| padding: 10px; | |
| display: inline-grid; | |
| border-radius: 5px; | |
| text-decoration-line: none; | |
| border-radius: 5px; | |
| transition: .5s; | |
| } | |
| .tree li a span { | |
| border: 1px solid #ccc; | |
| border-radius: 5px; | |
| color: #666; | |
| padding: 8px; | |
| font-size: 12px; | |
| text-transform: uppercase; | |
| letter-spacing: 1px; | |
| font-weight: 500; | |
| } | |
| /*Hover-Section*/ | |
| .tree li a:hover, .tree li a:hover i, .tree li a:hover span, .tree li a:hover+ul li a { | |
| background: #c8e4f8; | |
| color: #000; | |
| border: 1px solid #94a0b4; | |
| } | |
| .tree li a:hover+ul li::after, .tree li a:hover+ul li::before, .tree li a:hover+ul::before, .tree li a:hover+ul ul::before { | |
| border-color: #94a0b4; | |
| } | |
| """ | |
from transformers import GPT2Tokenizer, AutoModelForCausalLM
import numpy as np

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
# GPT-2 has no pad token by default, so reuse the EOS token for padding
tokenizer.pad_token_id = tokenizer.eos_token_id
print("Loading finished.")
def generate_html(token, node):
    """Recursively generate HTML for the tree."""
    html_content = f" <li> <a href='#'> <span> <b>{token}</b> </span> "
    html_content += node["table"] if node["table"] is not None else ""
    html_content += "</a>"
    if len(node["children"]) > 0:
        html_content += "<ul> "
        for child_token, child_node in node["children"].items():
            html_content += generate_html(child_token, child_node)
        html_content += "</ul>"
    html_content += "</li>"
    return html_content
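# Each node of the tree is a plain dict of the form
# {"table": <HTML table string or None>, "children": {token_string: child_node}},
# built by display_tree() below and rendered recursively by generate_html().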
def generate_markdown_table(scores, top_k=4, chosen_tokens=None):
    markdown_table = """
<table>
<tr>
<th><b>Token</b></th>
<th><b>Probability</b></th>
</tr>"""
    for token_idx in np.argsort(scores)[-top_k:]:
        token = tokenizer.decode([token_idx])
        style = ""
        if chosen_tokens and token in chosen_tokens:
            style = "background-color:red"
        markdown_table += f"""
<tr style="{style}">
<td>{token}</td>
<td>{scores[token_idx]}</td>
</tr>"""
    markdown_table += """
</table>"""
    return markdown_table
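# The table lists the top_k highest-scoring candidate tokens at one generation
# step; rows whose token was actually kept on some beam (chosen_tokens) are
# highlighted in red via the inline style above.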
def display_tree(scores, sequences, beam_indices, prompt):
    display = """<div class="container">
<div class="tree">
<ul>"""
    sequences = sequences.cpu().numpy()
    print(tokenizer.batch_decode(sequences))
    original_tree = {"table": None, "children": {}}
    for sequence_ix in range(len(sequences)):
        current_tree = original_tree
        for step, step_scores in enumerate(scores):
            current_token_choice = tokenizer.decode([sequences[sequence_ix, step]])
            current_beam = beam_indices[sequence_ix, step]
            if current_token_choice not in current_tree["children"]:
                current_tree["children"][current_token_choice] = {
                    "table": None,
                    "children": {},
                }
            # Rewrite the probability table even if it was already there, since new
            # chosen nodes may have appeared among the children of the current tree
            markdown_table = generate_markdown_table(
                step_scores[current_beam, :],
                chosen_tokens=current_tree["children"].keys(),
            )
            current_tree["table"] = markdown_table
            current_tree = current_tree["children"][current_token_choice]
    display += generate_html(prompt, original_tree)
    display += """
</ul>
</div>
</div>
"""
    print(display)
    return display
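# display_tree() builds a prefix tree over the returned beam sequences: each
# node corresponds to a generated token, and the table attached to a node shows
# the top candidate tokens considered at that step on the corresponding beam.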
def get_tables(input_text, number_steps, number_beams):
    inputs = tokenizer([input_text], return_tensors="pt")
    # Number of prompt tokens, used to strip the prompt from the generated output
    input_length = inputs["input_ids"].shape[-1]
    outputs = model.generate(
        **inputs,
        max_new_tokens=number_steps,
        num_beams=number_beams,
        num_return_sequences=number_beams,
        return_dict_in_generate=True,
        output_scores=True,
        top_k=5,
        temperature=1.0,
        do_sample=True,
    )
    tables = display_tree(
        outputs.scores,
        outputs.sequences[:, input_length:],
        outputs.beam_indices[:, :-input_length],
        input_text,
    )
    return tables
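# A minimal sketch of calling the pipeline without the Gradio UI (assumes the
# model and tokenizer loaded above; writing the result to a file is purely
# illustrative):
#   html = get_tables("Today is", number_steps=4, number_beams=3)
#   with open("tree.html", "w", encoding="utf-8") as f:
#       f.write(f"<style>{STYLE}</style>{html}")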
import gradio as gr

with gr.Blocks(
    theme=gr.themes.Soft(
        text_size="lg", font=["monospace"], primary_hue=gr.themes.colors.green
    ),
    css=STYLE,
) as demo:
    text = gr.Textbox(label="Sentence to decode from🪶", value="Today is")
    steps = gr.Slider(label="Number of steps", minimum=1, maximum=10, step=1, value=4)
    beams = gr.Slider(label="Number of beams", minimum=1, maximum=3, step=1, value=3)
    button = gr.Button()
    out = gr.Markdown(label="Output")
    button.click(get_tables, inputs=[text, steps, beams], outputs=out)

demo.launch()
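# Clicking the button runs get_tables() on the textbox and slider values and
# renders the returned HTML tree (styled by STYLE) in the Markdown output.
# To run locally, a sketch assuming `torch`, `transformers`, `numpy`, and
# `gradio` are installed and this file is saved as app.py (the usual Spaces
# convention): `python app.py`.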