|  | import os | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
# Select the JAX backend for Keras. This assignment sits before the
# `import keras_hub` line so the backend choice is in place when Keras loads.
os.environ["KERAS_BACKEND"] = "jax"
					
						
						|  |  | 
					
						
						|  | import gradio as gr | 
					
						
						|  | from gradio import ChatMessage | 
					
						
						|  | import keras_hub | 
					
						
						|  |  | 
					
						
						|  | from chatstate import ChatState | 
					
						
						|  | from enum import Enum | 
					
						
						|  | from models import ( | 
					
						
						|  | model_presets, | 
					
						
						|  | load_model, | 
					
						
						|  | model_labels, | 
					
						
						|  | preset_to_website_url, | 
					
						
						|  | get_appropriate_chat_template, | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class TextRoute(Enum): | 
					
						
						|  | LEFT = 0 | 
					
						
						|  | RIGHT = 1 | 
					
						
						|  | BOTH = 2 | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
# Materialize the labels once; the dropdowns use each label's position as the
# model index. NOTE(review): this assumes model_labels is ordered the same way
# as model_presets — confirm in models.py.
model_labels_list = list(model_labels)

# Load every preset up front. models[i] corresponds to model_presets[i]:
# handlers index both lists with the same dropdown selection.
models = []
for preset in model_presets:
    model = load_model(preset)
    chat_template = get_appropriate_chat_template(preset)
    # Throwaway chat state with an empty system prompt.
    chat_state = ChatState(model, "", chat_template)
    # One-off "Hello" turn whose reply is only printed; presumably a smoke
    # test / warm-up of the generation path — TODO confirm intent.
    prompt, response = chat_state.send_message("Hello")
    print("model " + preset + " loaded and initialized.")
    print("The model responded: " + response)
    models.append(model)
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def chat_turn_assistant( | 
					
						
						|  | message, | 
					
						
						|  | sel, | 
					
						
						|  | history, | 
					
						
						|  | system_message, | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | ): | 
					
						
						|  | model = models[sel] | 
					
						
						|  | preset = model_presets[sel] | 
					
						
						|  | chat_template = get_appropriate_chat_template(preset) | 
					
						
						|  | chat_state = ChatState(model, system_message, chat_template) | 
					
						
						|  |  | 
					
						
						|  | for msg in history: | 
					
						
						|  | msg = ChatMessage(**msg) | 
					
						
						|  | if msg.role == "user": | 
					
						
						|  | chat_state.add_to_history_as_user(msg.content) | 
					
						
						|  | elif msg.role == "assistant": | 
					
						
						|  | chat_state.add_to_history_as_model(msg.content) | 
					
						
						|  |  | 
					
						
						|  | prompt, response = chat_state.send_message(message) | 
					
						
						|  | history.append(ChatMessage(role="assistant", content=response)) | 
					
						
						|  | return history | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def chat_turn_both_assistant( | 
					
						
						|  | message, sel1, sel2, history1, history2, system_message | 
					
						
						|  | ): | 
					
						
						|  | return ( | 
					
						
						|  | chat_turn_assistant(message, sel1, history1, system_message), | 
					
						
						|  | chat_turn_assistant(message, sel2, history2, system_message), | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def chat_turn_user(message, history): | 
					
						
						|  | history.append(ChatMessage(role="user", content=message)) | 
					
						
						|  | return history | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def chat_turn_both_user(message, history1, history2): | 
					
						
						|  | return ( | 
					
						
						|  | chat_turn_user(message, history1), | 
					
						
						|  | chat_turn_user(message, history2), | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def bot_icon_select(model_name): | 
					
						
						|  | if "gemma" in model_name: | 
					
						
						|  | return "img/gemma.png" | 
					
						
						|  | elif "llama" in model_name: | 
					
						
						|  | return "img/meta.png" | 
					
						
						|  | elif "vicuna" in model_name: | 
					
						
						|  | return "img/vicuna.png" | 
					
						
						|  | elif "mistral" in model_name: | 
					
						
						|  | return "img/mistral.png" | 
					
						
						|  |  | 
					
						
						|  | return "img/bot.png" | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def instantiate_select_box(sel, model_labels): | 
					
						
						|  | return gr.Dropdown( | 
					
						
						|  | choices=[(name, i) for i, name in enumerate(model_labels)], | 
					
						
						|  | show_label=False, | 
					
						
						|  | value=sel, | 
					
						
						|  | info="<span style='color:black'>Selected model:</span> <a href='" | 
					
						
						|  | + preset_to_website_url(model_presets[sel]) | 
					
						
						|  | + "'>" | 
					
						
						|  | + preset_to_website_url(model_presets[sel]) | 
					
						
						|  | + "</a>", | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def instantiate_chatbot(sel, key): | 
					
						
						|  | model_name = model_presets[sel] | 
					
						
						|  | return gr.Chatbot( | 
					
						
						|  | type="messages", | 
					
						
						|  | key=key, | 
					
						
						|  | show_label=False, | 
					
						
						|  | show_share_button=False, | 
					
						
						|  | show_copy_all_button=True, | 
					
						
						|  | avatar_images=("img/usr.png", bot_icon_select(model_name)), | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def instantiate_arrow_button(route, text_route): | 
					
						
						|  | icons = { | 
					
						
						|  | TextRoute.LEFT: "img/arrowL.png", | 
					
						
						|  | TextRoute.RIGHT: "img/arrowR.png", | 
					
						
						|  | TextRoute.BOTH: "img/arrowRL.png", | 
					
						
						|  | } | 
					
						
						|  | button = gr.Button( | 
					
						
						|  | "", | 
					
						
						|  | size="sm", | 
					
						
						|  | scale=0, | 
					
						
						|  | min_width=40, | 
					
						
						|  | icon=icons[route], | 
					
						
						|  | ) | 
					
						
						|  | button.click(lambda: route, outputs=[text_route]) | 
					
						
						|  | return button | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def instantiate_retry_button(route): | 
					
						
						|  | return gr.Button( | 
					
						
						|  | "", | 
					
						
						|  | size="sm", | 
					
						
						|  | scale=0, | 
					
						
						|  | min_width=40, | 
					
						
						|  | icon="img/retry.png", | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def instantiate_trash_button(): | 
					
						
						|  | return gr.Button( | 
					
						
						|  | "", | 
					
						
						|  | size="sm", | 
					
						
						|  | scale=0, | 
					
						
						|  | min_width=40, | 
					
						
						|  | icon="img/trash.png", | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def instantiate_text_box(): | 
					
						
						|  | return gr.Textbox(label="Your message:", submit_btn=True, key="msg") | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def instantiate_additional_settings(): | 
					
						
						|  | with gr.Accordion("Additional settings", open=False): | 
					
						
						|  | system_message = gr.Textbox( | 
					
						
						|  | label="Sytem prompt", | 
					
						
						|  | key="system_prompt", | 
					
						
						|  | value="You are a helpful assistant and your name is Eliza.", | 
					
						
						|  | ) | 
					
						
						|  | return system_message | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def retry_fn(history): | 
					
						
						|  | if len(history) >= 2: | 
					
						
						|  | msg = history.pop(-1) | 
					
						
						|  | msg = history.pop(-1) | 
					
						
						|  | return msg["content"], history | 
					
						
						|  | else: | 
					
						
						|  | return gr.skip(), gr.skip() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def retry_fn_both(history1, history2): | 
					
						
						|  | msg1, history1 = retry_fn(history1) | 
					
						
						|  | msg2, history2 = retry_fn(history2) | 
					
						
						|  | if isinstance(msg1, str) and isinstance(msg2, str): | 
					
						
						|  | if msg1 == msg2: | 
					
						
						|  | msg = msg1 | 
					
						
						|  | else: | 
					
						
						|  | msg = msg1 + " / " + msg2 | 
					
						
						|  | elif isinstance(msg1, str): | 
					
						
						|  | msg = msg1 | 
					
						
						|  | elif isinstance(msg2, str): | 
					
						
						|  | msg = msg2 | 
					
						
						|  | else: | 
					
						
						|  | msg = msg1 | 
					
						
						|  | return msg, history1, history2 | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
# Components created outside the Blocks context and placed into the layout
# later with .render(). The left panel starts on model 0 and the right panel
# on model 1, so at least two presets must be configured.
sel1 = instantiate_select_box(0, model_labels_list)
sel2 = instantiate_select_box(1, model_labels_list)
chatbot1 = instantiate_chatbot(sel1.value, "chat1")
chatbot2 = instantiate_chatbot(sel2.value, "chat2")

# Right-aligns the button rows/columns tagged with this class (used in the
# RIGHT text-route layout).
CSS = ".stick-to-the-right {align-items: end; justify-content: end}"
					
						
						|  |  | 
					
						
						|  | with gr.Blocks(fill_width=True, title="Keras demo", css=CSS) as demo: | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | text_route = gr.State(TextRoute.BOTH) | 
					
						
						|  |  | 
					
						
						|  | with gr.Row(): | 
					
						
						|  | gr.Image( | 
					
						
						|  | "img/keras_logo_k.png", | 
					
						
						|  | width=80, | 
					
						
						|  | height=80, | 
					
						
						|  | min_width=80, | 
					
						
						|  | show_label=False, | 
					
						
						|  | show_download_button=False, | 
					
						
						|  | show_fullscreen_button=False, | 
					
						
						|  | show_share_button=False, | 
					
						
						|  | interactive=False, | 
					
						
						|  | scale=0, | 
					
						
						|  | container=False, | 
					
						
						|  | ) | 
					
						
						|  | gr.HTML( | 
					
						
						|  | "<H2>Keras chatbot arena  - running with JAX on TPU</H2>" | 
					
						
						|  | + "All the models are loaded into the TPU memory. " | 
					
						
						|  | + "You can call any of them and compare their answers. " | 
					
						
						|  | + "The entire chat<br/>history is fed to the models at every submission. " | 
					
						
						|  | + "This demo is runnig on a Google TPU v5e 2x4 (8 cores) in bfloat16 precision." | 
					
						
						|  | ) | 
					
						
						|  | with gr.Row(): | 
					
						
						|  | sel1.render(), | 
					
						
						|  | sel2.render(), | 
					
						
						|  |  | 
					
						
						|  | with gr.Row(): | 
					
						
						|  | chatbot1.render() | 
					
						
						|  | chatbot2.render() | 
					
						
						|  |  | 
					
						
						|  | @gr.render(inputs=text_route) | 
					
						
						|  | def render_text_area(route): | 
					
						
						|  |  | 
					
						
						|  | if route == TextRoute.BOTH: | 
					
						
						|  | with gr.Row(): | 
					
						
						|  | msg = instantiate_text_box() | 
					
						
						|  | with gr.Column(scale=0, min_width=100): | 
					
						
						|  | with gr.Row(): | 
					
						
						|  | instantiate_arrow_button(TextRoute.LEFT, text_route) | 
					
						
						|  | retry = instantiate_retry_button(route) | 
					
						
						|  | with gr.Row(): | 
					
						
						|  | instantiate_arrow_button(TextRoute.RIGHT, text_route) | 
					
						
						|  | trash = instantiate_trash_button() | 
					
						
						|  | retry.click( | 
					
						
						|  | retry_fn_both, | 
					
						
						|  | inputs=[chatbot1, chatbot2], | 
					
						
						|  | outputs=[msg, chatbot1, chatbot2], | 
					
						
						|  | ) | 
					
						
						|  | trash.click(lambda: ("", [], []), outputs=[msg, chatbot1, chatbot2]) | 
					
						
						|  |  | 
					
						
						|  | elif route == TextRoute.LEFT: | 
					
						
						|  | with gr.Row(): | 
					
						
						|  | with gr.Column(scale=1): | 
					
						
						|  | msg = instantiate_text_box() | 
					
						
						|  | with gr.Column(scale=1): | 
					
						
						|  | with gr.Row(): | 
					
						
						|  | instantiate_arrow_button(TextRoute.RIGHT, text_route) | 
					
						
						|  | retry = instantiate_retry_button(route) | 
					
						
						|  | with gr.Row(): | 
					
						
						|  | instantiate_arrow_button(TextRoute.BOTH, text_route) | 
					
						
						|  | trash = instantiate_trash_button() | 
					
						
						|  | retry.click(retry_fn, inputs=[chatbot1], outputs=[msg, chatbot1]) | 
					
						
						|  | trash.click(lambda: ("", []), outputs=[msg, chatbot1]) | 
					
						
						|  |  | 
					
						
						|  | elif route == TextRoute.RIGHT: | 
					
						
						|  | with gr.Row(): | 
					
						
						|  | with gr.Column(scale=1, elem_classes="stick-to-the-right"): | 
					
						
						|  | with gr.Row(elem_classes="stick-to-the-right"): | 
					
						
						|  | retry = instantiate_retry_button(route) | 
					
						
						|  | instantiate_arrow_button(TextRoute.LEFT, text_route) | 
					
						
						|  | with gr.Row(elem_classes="stick-to-the-right"): | 
					
						
						|  | trash = instantiate_trash_button() | 
					
						
						|  | instantiate_arrow_button(TextRoute.BOTH, text_route) | 
					
						
						|  | with gr.Column(scale=1): | 
					
						
						|  | msg = instantiate_text_box() | 
					
						
						|  | retry.click(retry_fn, inputs=[chatbot2], outputs=[msg, chatbot2]) | 
					
						
						|  | trash.click(lambda: ("", []), outputs=[msg, chatbot2]) | 
					
						
						|  |  | 
					
						
						|  | system_message = instantiate_additional_settings() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if route == TextRoute.LEFT: | 
					
						
						|  | submission = msg.submit( | 
					
						
						|  | chat_turn_user, inputs=[msg, chatbot1], outputs=[chatbot1] | 
					
						
						|  | ).then( | 
					
						
						|  | chat_turn_assistant, | 
					
						
						|  | [msg, sel1, chatbot1, system_message], | 
					
						
						|  | outputs=[chatbot1], | 
					
						
						|  | ) | 
					
						
						|  | elif route == TextRoute.RIGHT: | 
					
						
						|  | submission = msg.submit( | 
					
						
						|  | chat_turn_user, inputs=[msg, chatbot2], outputs=[chatbot2] | 
					
						
						|  | ).then( | 
					
						
						|  | chat_turn_assistant, | 
					
						
						|  | [msg, sel2, chatbot2, system_message], | 
					
						
						|  | outputs=[chatbot2], | 
					
						
						|  | ) | 
					
						
						|  | elif route == TextRoute.BOTH: | 
					
						
						|  | submission = msg.submit( | 
					
						
						|  | chat_turn_both_user, | 
					
						
						|  | inputs=[msg, chatbot1, chatbot2], | 
					
						
						|  | outputs=[chatbot1, chatbot2], | 
					
						
						|  | ).then( | 
					
						
						|  | chat_turn_both_assistant, | 
					
						
						|  | [msg, sel1, sel2, chatbot1, chatbot2, system_message], | 
					
						
						|  | outputs=[chatbot1, chatbot2], | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | submission.then(lambda: "", outputs=msg) | 
					
						
						|  |  | 
					
						
						|  | sel1.select( | 
					
						
						|  | lambda sel: instantiate_chatbot(sel, "chat1"), | 
					
						
						|  | inputs=[sel1], | 
					
						
						|  | outputs=[chatbot1], | 
					
						
						|  | ).then( | 
					
						
						|  | lambda sel: instantiate_select_box(sel, model_labels_list), | 
					
						
						|  | inputs=[sel1], | 
					
						
						|  | outputs=[sel1], | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  | sel2.select( | 
					
						
						|  | lambda sel: instantiate_chatbot(sel, "chat2"), | 
					
						
						|  | inputs=[sel2], | 
					
						
						|  | outputs=[chatbot2], | 
					
						
						|  | ).then( | 
					
						
						|  | lambda sel: instantiate_select_box(sel, model_labels_list), | 
					
						
						|  | inputs=[sel2], | 
					
						
						|  | outputs=[sel2], | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
# Launch the Gradio app only when this file is run directly, not when imported.
if __name__ == "__main__":
    demo.launch()
					
						
						|  |  |