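# Beam search visualizer: a Gradio demo that decodes from GPT-2 with beam search
# and renders the explored beams as an HTML tree, with a per-step score table
# attached to each node.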
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import numpy as np
import gradio as gr
import spaces

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer.pad_token_id = tokenizer.eos_token_id
print("Loading finished.")

print(f"Is CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")

| STYLE = """ | |
| .custom-container { | |
| width: 100%; | |
| display: grid; | |
| align-items: center; | |
| margin: 0!important; | |
| overflow: scroll; | |
| } | |
| .prose ul ul { | |
| margin: 0!important; | |
| font-size: 10px!important; | |
| } | |
| .prose td, th { | |
| padding-left: 2px; | |
| padding-right: 2px; | |
| padding-top: 0; | |
| padding-bottom: 0; | |
| text-wrap: nowrap; | |
| } | |
| .tree { | |
| padding: 0px; | |
| margin: 0!important; | |
| box-sizing: border-box; | |
| font-size: 10px; | |
| width: 100%; | |
| height: auto; | |
| } | |
| .tree ul { | |
| padding-top: 20px; | |
| position: relative; | |
| transition: .5s; | |
| margin: 0!important; | |
| display: flex; | |
| flex-direction: row; | |
| justify-content: center; | |
| gap:10px; | |
| } | |
| .tree li { | |
| display: inline-table; | |
| text-align: center; | |
| list-style-type: none; | |
| position: relative; | |
| padding-top: 10px; | |
| transition: .5s; | |
| } | |
| .tree li::before, .tree li::after { | |
| content: ''; | |
| position: absolute; | |
| top: 0; | |
| right: 50%; | |
| border-top: 2px solid var(--body-text-color); | |
| width: 55%; | |
| min-width: 30px; | |
| height: 10px; | |
| } | |
| .tree li::after { | |
| right: auto; | |
| left: 50%; | |
| border-left: 2px solid var(--body-text-color); | |
| } | |
| .tree li:only-child::after, .tree li:only-child::before { | |
| display: none; | |
| } | |
| .tree ul:has(> li:only-child)::before { | |
| height:40px; | |
| } | |
| .tree li:first-child::before, .tree li:last-child::after { | |
| border: 0 none; | |
| } | |
| .tree li:last-child::before { | |
| border-right: 2px solid var(--body-text-color); | |
| border-radius: 0 5px 0 0; | |
| -webkit-border-radius: 0 5px 0 0; | |
| -moz-border-radius: 0 5px 0 0; | |
| } | |
| .tree li:first-child::after { | |
| border-radius: 5px 0 0 0; | |
| -webkit-border-radius: 5px 0 0 0; | |
| -moz-border-radius: 5px 0 0 0; | |
| } | |
| .tree ul ul::before { | |
| content: ''; | |
| position: absolute; | |
| top: 0; | |
| left: 50%; | |
| border-left: 2px solid var(--body-text-color); | |
| width: 0; | |
| height: 20px; | |
| } | |
| .tree li a { | |
| border: 1px solid var(--body-text-color); | |
| padding: 5px; | |
| display: inline-grid; | |
| border-radius: 5px; | |
| text-decoration-line: none; | |
| border-radius: 5px; | |
| transition: .5s; | |
| } | |
| .tree li a span { | |
| padding: 5px; | |
| font-size: 12px; | |
| text-transform: uppercase; | |
| letter-spacing: 1px; | |
| font-weight: 500; | |
| } | |
| /*Hover-Section*/ | |
| .tree li a:hover, .tree li a:hover+ul li a { | |
| background: #ffedd5; | |
| } | |
| .tree li a:hover+ul li::after, .tree li a:hover+ul li::before, .tree li a:hover+ul::before, .tree li a:hover+ul ul::before { | |
| border-color: #f97316; | |
| color= #CCC; | |
| } | |
| .chosen { | |
| background-color: #ea580c; | |
| } | |
| """ | |
def generate_nodes(token, node):
    """Recursively generate HTML for the tree nodes."""
    html_content = f"<li> <a href='#' class='{('chosen' if node.table is None else '')}'> <span> <b>{token}</b> </span>"
    html_content += node.table if node.table is not None else ""
    html_content += "</a>"
    if node.children:
        html_content += "<ul>"
        for token, subnode in node.children.items():
            html_content += generate_nodes(token, subnode)
        html_content += "</ul>"
    html_content += "</li>"
    return html_content

def generate_markdown_table(scores, sequence_prob, top_k=4, chosen_tokens=None):
    """Build an HTML table of the top_k candidate tokens with their step and total scores."""
    markdown_table = """
    <table>
        <tr>
            <th><b>Token</b></th>
            <th><b>Step score</b></th>
            <th><b>Total score</b></th>
        </tr>"""
    for token_idx in np.array(np.argsort(scores)[-top_k:])[::-1]:
        token = tokenizer.decode([token_idx])
        item_class = ""
        if chosen_tokens and token in chosen_tokens:
            item_class = "chosen"
        markdown_table += f"""
        <tr class='{item_class}'>
            <td>{token}</td>
            <td>{scores[token_idx]:.4f}</td>
            <td>{scores[token_idx] + sequence_prob:.4f}</td>
        </tr>"""
    markdown_table += """
    </table>"""
    return markdown_table

def generate_html(start_sentence, original_tree):
    html_output = """<div class="custom-container">
    <div class="tree">
    <ul>"""
    html_output += generate_nodes(start_sentence, original_tree)
    html_output += """
    </ul>
    </div>
    </div>
    """
    return html_output

import pandas as pd
from typing import Dict
from dataclasses import dataclass


@dataclass
class BeamNode:
    cumulative_score: float
    table: str
    current_sentence: str
    children: Dict[str, "BeamNode"]

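# Rebuild the beam-search tree from the per-step scores returned by
# model.generate(..., output_scores=True): at each step, rank candidate
# continuations across beams, attach a score table to the parent nodes,
# and append the selected tokens as children.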
def generate_beams(start_sentence, scores, sequences, beam_indices):
    sequences = sequences.cpu().numpy()
    original_tree = BeamNode(
        cumulative_score=0, table=None, current_sentence=start_sentence, children={}
    )
    n_beams = len(scores[0])
    beam_trees = [original_tree] * n_beams
    for step, step_scores in enumerate(scores):
        (
            top_token_indexes,
            top_cumulative_scores,
            beam_indexes,
            current_completions,
            top_tokens,
        ) = ([], [], [], [], [])
        for beam_ix in range(n_beams):
            current_beam = beam_trees[beam_ix]
            # Get top cumulative scores for the current beam
            current_top_token_indexes = list(
                np.array(scores[step][beam_ix].argsort()[-n_beams:])[::-1]
            )
            top_token_indexes += current_top_token_indexes
            top_cumulative_scores += list(
                np.array(scores[step][beam_ix][current_top_token_indexes])
                + current_beam.cumulative_score
            )
            beam_indexes += [beam_ix] * n_beams
            current_completions += [beam_trees[beam_ix].current_sentence] * n_beams
            top_tokens += [
                tokenizer.decode([el]) for el in current_top_token_indexes
            ]

        top_df = pd.DataFrame.from_dict(
            {
                "token_index": top_token_indexes,
                "cumulative_score": top_cumulative_scores,
                "beam_index": beam_indexes,
                "current_completions": current_completions,
                "token": top_tokens,
            }
        )
        maxes = top_df.groupby(["token_index", "current_completions"])[
            "cumulative_score"
        ].idxmax()
        top_df = top_df.loc[maxes]

        # Sort all top probabilities and keep the top n_beams
        top_df_selected = top_df.sort_values("cumulative_score", ascending=False).iloc[
            :n_beams
        ]

        # Write the scores table, one per beam source.
        # Edge case: if several beam indexes are actually on the same beam, the
        # tokens selected for the later ones would be empty, so we iterate in reverse.
        for beam_ix in reversed(list(range(n_beams))):
            current_beam = beam_trees[beam_ix]
            selected_tokens = top_df_selected.loc[
                top_df_selected["beam_index"] == beam_ix
            ]
            markdown_table = generate_markdown_table(
                step_scores[beam_ix, :],
                current_beam.cumulative_score,
                chosen_tokens=list(selected_tokens["token"].values),
            )
            beam_trees[beam_ix].table = markdown_table

        # Add new children to each source beam
        cumulative_scores = [beam.cumulative_score for beam in beam_trees]
        for beam_ix in range(n_beams):
            current_token_choice_ix = top_df_selected.iloc[beam_ix]["token_index"]
            current_token_choice = tokenizer.decode([current_token_choice_ix])

            # Update the source tree
            source_beam_ix = int(top_df_selected.iloc[beam_ix]["beam_index"])
            previous_len = len(str(original_tree))
            beam_trees[source_beam_ix].children[current_token_choice] = BeamNode(
                table=None,
                children={},
                current_sentence=beam_trees[source_beam_ix].current_sentence
                + current_token_choice,
                cumulative_score=cumulative_scores[source_beam_ix]
                + scores[step][source_beam_ix][current_token_choice_ix].numpy(),
            )
            assert (
                len(str(original_tree)) > previous_len
            ), "Original tree has not increased in size"

        # Reassign all beams at once
        beam_trees = [
            beam_trees[int(top_df_selected.iloc[beam_ix]["beam_index"])]
            for beam_ix in range(n_beams)
        ]

        # Advance all beams by one token
        for beam_ix in range(n_beams):
            current_token_choice_ix = top_df_selected.iloc[beam_ix]["token_index"]
            current_token_choice = tokenizer.decode([current_token_choice_ix])
            beam_trees[beam_ix] = beam_trees[beam_ix].children[current_token_choice]

    return original_tree

def get_beam_search_html(input_text, number_steps, number_beams):
    inputs = tokenizer([input_text], return_tensors="pt")
    outputs = model.generate(
        **inputs,
        max_new_tokens=number_steps,
        num_beams=number_beams,
        num_return_sequences=number_beams,
        return_dict_in_generate=True,
        output_scores=True,
        top_k=5,
        do_sample=False,
    )

    original_tree = generate_beams(
        input_text,
        outputs.scores[:],
        outputs.sequences[:, :],
        outputs.beam_indices[:, :],
    )
    html = generate_html(input_text, original_tree)
    return html

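# Gradio UI: a prompt textbox, sliders for the number of steps and beams,
# and a Markdown component that renders the generated HTML tree.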
with gr.Blocks(
    theme=gr.themes.Soft(
        text_size="lg", font=["monospace"], primary_hue=gr.themes.colors.yellow
    ),
    css=STYLE,
) as demo:
    text = gr.Textbox(label="Sentence to decode from", value="Today is")
    steps = gr.Slider(label="Number of steps", minimum=1, maximum=7, step=1, value=4)
    beams = gr.Slider(label="Number of beams", minimum=2, maximum=4, step=1, value=3)
    button = gr.Button()
    out = gr.Markdown(label="Output")
    button.click(get_beam_search_html, inputs=[text, steps, beams], outputs=out)

demo.launch()