import os

# Questions for Gradio
# - Chat share button is enabled by default but throws an error when clicked.
# - How to add local images in HTML? (https://github.com/gradio-app/gradio/issues/884)
# - How to allow Chatbot to fill the vertical space? (https://github.com/gradio-app/gradio/issues/4001)

# TODO:
# - Add the 1MB models, keras/gemma_1.1_instruct_7b_en
# - Add retry button, for each model individually
# - Add ability to route a message to a single model only.
# - log_applied_layout_map: make it work for Llama3CausalLM and LlamaCausalLM (vicuna)
# - display context length

os.environ["KERAS_BACKEND"] = "jax"
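# Note: the backend must be selected via KERAS_BACKEND before Keras / keras_hub
# is imported below; importing first would lock in the default backend.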

import gradio as gr
from gradio import ChatMessage
import keras_hub
from chatstate import ChatState
from enum import Enum
from models import (
    model_presets,
    load_model,
    model_labels,
    preset_to_website_url,
    get_appropriate_chat_template,
)
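

# TextRoute is held in a gr.State and decides whether a submitted message is
# routed to the left chatbot, the right chatbot, or both.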
class TextRoute(Enum):
    LEFT = 0
    RIGHT = 1
    BOTH = 2


model_labels_list = list(model_labels)

# load and warm up (compile) all the models
models = []
for preset in model_presets:
    model = load_model(preset)
    chat_template = get_appropriate_chat_template(preset)
    chat_state = ChatState(model, "", chat_template)
    prompt, response = chat_state.send_message("Hello")
    print("model " + preset + " loaded and initialized.")
    print("The model responded: " + response)
    models.append(model)

# For local debugging
# model = keras_hub.models.Llama3CausalLM.from_preset(
#     # "hf://meta-llama/Llama-3.2-1B-Instruct", dtype="bfloat16"
#     "../misc-code/ari_tiny_llama3"
# )
# models = [model, model, model, model, model]
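

# Replay the visible chat history into a fresh ChatState for the selected
# model, then generate one assistant reply and append it to the history.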
def chat_turn_assistant(
    message,
    sel,
    history,
    system_message,
    # max_tokens,
    # temperature,
    # top_p,
):
    model = models[sel]
    preset = model_presets[sel]
    chat_template = get_appropriate_chat_template(preset)
    chat_state = ChatState(model, system_message, chat_template)
    for msg in history:
        msg = ChatMessage(**msg)
        if msg.role == "user":
            chat_state.add_to_history_as_user(msg.content)
        elif msg.role == "assistant":
            chat_state.add_to_history_as_model(msg.content)
    prompt, response = chat_state.send_message(message)
    history.append(ChatMessage(role="assistant", content=response))
    return history


def chat_turn_both_assistant(
    message, sel1, sel2, history1, history2, system_message
):
    return (
        chat_turn_assistant(message, sel1, history1, system_message),
        chat_turn_assistant(message, sel2, history2, system_message),
    )


def chat_turn_user(message, history):
    history.append(ChatMessage(role="user", content=message))
    return history


def chat_turn_both_user(message, history1, history2):
    return (
        chat_turn_user(message, history1),
        chat_turn_user(message, history2),
    )
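

# Pick an avatar image for the chatbot based on the model preset name.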
def bot_icon_select(model_name):
    if "gemma" in model_name:
        return "img/gemma.png"
    elif "llama" in model_name:
        return "img/meta.png"
    elif "vicuna" in model_name:
        return "img/vicuna.png"
    elif "mistral" in model_name:
        return "img/mistral.png"
    # default
    return "img/bot.png"
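

# UI factories: each helper builds one Gradio component. They are used both
# for the initial layout and when the input area is re-rendered.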
def instantiate_select_box(sel, model_labels):
    return gr.Dropdown(
        choices=[(name, i) for i, name in enumerate(model_labels)],
        show_label=False,
        value=sel,
        info="<span style='color:black'>Selected model:</span> <a href='"
        + preset_to_website_url(model_presets[sel])
        + "'>"
        + preset_to_website_url(model_presets[sel])
        + "</a>",
    )


def instantiate_chatbot(sel, key):
    model_name = model_presets[sel]
    return gr.Chatbot(
        type="messages",
        key=key,
        show_label=False,
        show_share_button=False,
        show_copy_all_button=True,
        avatar_images=("img/usr.png", bot_icon_select(model_name)),
    )
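

# An arrow button only updates the text_route state; the input area is then
# rebuilt for the new routing by the render function further down.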
def instantiate_arrow_button(route, text_route):
    icons = {
        TextRoute.LEFT: "img/arrowL.png",
        TextRoute.RIGHT: "img/arrowR.png",
        TextRoute.BOTH: "img/arrowRL.png",
    }
    button = gr.Button(
        "",
        size="sm",
        scale=0,
        min_width=40,
        icon=icons[route],
    )
    button.click(lambda: route, outputs=[text_route])
    return button


def instantiate_retry_button(route):
    return gr.Button(
        "",
        size="sm",
        scale=0,
        min_width=40,
        icon="img/retry.png",
    )


def instantiate_trash_button():
    return gr.Button(
        "",
        size="sm",
        scale=0,
        min_width=40,
        icon="img/trash.png",
    )


def instantiate_text_box():
    return gr.Textbox(label="Your message:", submit_btn=True, key="msg")


def instantiate_additional_settings():
    with gr.Accordion("Additional settings", open=False):
        system_message = gr.Textbox(
            label="System prompt",
            key="system_prompt",
            value="You are a helpful assistant and your name is Eliza.",
        )
    return system_message
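

# Pop the last user/assistant exchange and return the user message so it can
# be resubmitted; with fewer than two messages there is nothing to retry.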
def retry_fn(history):
    if len(history) >= 2:
        msg = history.pop(-1)  # assistant message
        msg = history.pop(-1)  # user message
        return msg["content"], history
    else:
        return gr.skip(), gr.skip()


def retry_fn_both(history1, history2):
    msg1, history1 = retry_fn(history1)
    msg2, history2 = retry_fn(history2)
    if isinstance(msg1, str) and isinstance(msg2, str):
        if msg1 == msg2:
            msg = msg1
        else:
            msg = msg1 + " / " + msg2
    elif isinstance(msg1, str):
        msg = msg1
    elif isinstance(msg2, str):
        msg = msg2
    else:
        msg = msg1
    return msg, history1, history2
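

# The model selectors and chatbots are created up front and inserted into the
# page layout below with .render().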
sel1 = instantiate_select_box(0, model_labels_list)
sel2 = instantiate_select_box(1, model_labels_list)
chatbot1 = instantiate_chatbot(sel1.value, "chat1")
chatbot2 = instantiate_chatbot(sel2.value, "chat2")

# to correctly align the left/right arrows
CSS = ".stick-to-the-right {align-items: end; justify-content: end}"

with gr.Blocks(fill_width=True, title="Keras demo", css=CSS) as demo:
    # Where do messages go
    text_route = gr.State(TextRoute.BOTH)

    with gr.Row():
        gr.Image(
            "img/keras_logo_k.png",
            width=80,
            height=80,
            min_width=80,
            show_label=False,
            show_download_button=False,
            show_fullscreen_button=False,
            show_share_button=False,
            interactive=False,
            scale=0,
            container=False,
        )
        gr.HTML(
            "<H2>Keras chatbot arena - running with JAX on TPU</H2>"
            + "All the models are loaded into the TPU memory. "
            + "You can call any of them and compare their answers. "
            + "The entire chat<br/>history is fed to the models at every submission. "
            + "This demo is running on a Google TPU v5e 2x4 (8 cores) in bfloat16 precision."
        )
    with gr.Row():
        sel1.render()
        sel2.render()
    with gr.Row():
        chatbot1.render()
        chatbot2.render()
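
    # The message box, routing arrows, retry and trash buttons live in a
    # dynamic section that is rebuilt whenever the text_route state changes.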
    @gr.render(inputs=[text_route])
    def render_text_area(route):
        if route == TextRoute.BOTH:
            with gr.Row():
                msg = instantiate_text_box()
                with gr.Column(scale=0, min_width=100):
                    with gr.Row():
                        instantiate_arrow_button(TextRoute.LEFT, text_route)
                        retry = instantiate_retry_button(route)
                    with gr.Row():
                        instantiate_arrow_button(TextRoute.RIGHT, text_route)
                        trash = instantiate_trash_button()
            retry.click(
                retry_fn_both,
                inputs=[chatbot1, chatbot2],
                outputs=[msg, chatbot1, chatbot2],
            )
            trash.click(lambda: ("", [], []), outputs=[msg, chatbot1, chatbot2])
        elif route == TextRoute.LEFT:
            with gr.Row():
                with gr.Column(scale=1):
                    msg = instantiate_text_box()
                with gr.Column(scale=1):
                    with gr.Row():
                        instantiate_arrow_button(TextRoute.RIGHT, text_route)
                        retry = instantiate_retry_button(route)
                    with gr.Row():
                        instantiate_arrow_button(TextRoute.BOTH, text_route)
                        trash = instantiate_trash_button()
            retry.click(retry_fn, inputs=[chatbot1], outputs=[msg, chatbot1])
            trash.click(lambda: ("", []), outputs=[msg, chatbot1])
        elif route == TextRoute.RIGHT:
            with gr.Row():
                with gr.Column(scale=1, elem_classes="stick-to-the-right"):
                    with gr.Row(elem_classes="stick-to-the-right"):
                        retry = instantiate_retry_button(route)
                        instantiate_arrow_button(TextRoute.LEFT, text_route)
                    with gr.Row(elem_classes="stick-to-the-right"):
                        trash = instantiate_trash_button()
                        instantiate_arrow_button(TextRoute.BOTH, text_route)
                with gr.Column(scale=1):
                    msg = instantiate_text_box()
            retry.click(retry_fn, inputs=[chatbot2], outputs=[msg, chatbot2])
            trash.click(lambda: ("", []), outputs=[msg, chatbot2])
        system_message = instantiate_additional_settings()

        # Route the submitted message to the left, right or both chatbots
        if route == TextRoute.LEFT:
            submission = msg.submit(
                chat_turn_user, inputs=[msg, chatbot1], outputs=[chatbot1]
            ).then(
                chat_turn_assistant,
                [msg, sel1, chatbot1, system_message],
                outputs=[chatbot1],
            )
        elif route == TextRoute.RIGHT:
            submission = msg.submit(
                chat_turn_user, inputs=[msg, chatbot2], outputs=[chatbot2]
            ).then(
                chat_turn_assistant,
                [msg, sel2, chatbot2, system_message],
                outputs=[chatbot2],
            )
        elif route == TextRoute.BOTH:
            submission = msg.submit(
                chat_turn_both_user,
                inputs=[msg, chatbot1, chatbot2],
                outputs=[chatbot1, chatbot2],
            ).then(
                chat_turn_both_assistant,
                [msg, sel1, sel2, chatbot1, chatbot2, system_message],
                outputs=[chatbot1, chatbot2],
            )

        # In all cases reset text box after submission
        submission.then(lambda: "", outputs=msg)
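
    # Switching a model in a dropdown swaps in a fresh chatbot (clearing its
    # history) and rebuilds the dropdown so its info link points to the new model.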
    sel1.select(
        lambda sel: instantiate_chatbot(sel, "chat1"),
        inputs=[sel1],
        outputs=[chatbot1],
    ).then(
        lambda sel: instantiate_select_box(sel, model_labels_list),
        inputs=[sel1],
        outputs=[sel1],
    )
    sel2.select(
        lambda sel: instantiate_chatbot(sel, "chat2"),
        inputs=[sel2],
        outputs=[chatbot2],
    ).then(
        lambda sel: instantiate_select_box(sel, model_labels_list),
        inputs=[sel2],
        outputs=[sel2],
    )

if __name__ == "__main__":
    demo.launch()