Spaces:
Configuration error
Configuration error
| import argparse | |
| import datetime | |
| import hashlib | |
| import json | |
| import os | |
| import time | |
| import gradio as gr | |
| import requests | |
| from constants import LOGDIR | |
| from conversation import (default_conversation, conv_templates, | |
| SeparatorStyle) | |
| from utils import (build_logger, server_error_msg) | |
# Module-wide logger for the web server.
# NOTE(review): model_worker is imported only after the logger is built; keep this
# order unless you confirm build_logger has no import-time side effects model_worker relies on.
logger = build_logger("gradio_web_server", "gradio_web_server.log")
from model_worker import ModelWorker

# Shared button updates returned by the event handlers below.
# no_change_btn is returned on the "skip" paths (presumably leaves buttons as they are);
# enable_btn / disable_btn toggle the buttons' interactivity.
no_change_btn = gr.Button()
enable_btn, disable_btn = gr.Button(interactive=True), gr.Button(interactive=False)
def get_conv_log_filename():
    """Return today's conversation-log path, ``LOGDIR/YYYY-MM-DD-conv.json``.

    The date is the local calendar date at call time; callers append one JSON
    record per line to this file.
    """
    today = datetime.datetime.now().date()
    basename = "{}-{:02d}-{:02d}-conv.json".format(today.year, today.month, today.day)
    return os.path.join(LOGDIR, basename)
# JavaScript executed in the browser on page load (passed as ``js=`` to
# ``demo.load``). It returns the page's URL query parameters as a plain object,
# whose value is used as the ``url_params`` input of ``load_demo``.
# ``url_params`` is declared with ``const`` so it does not leak an implicit global.
get_window_url_params = """
function() {
    const params = new URLSearchParams(window.location.search);
    const url_params = Object.fromEntries(params);
    console.log(url_params);
    return url_params;
}
"""
def load_demo(url_params, request: gr.Request):
    """Initialise a browser session when the page loads.

    Wired to ``demo.load`` with the query-string parameters collected by
    ``get_window_url_params``.

    The ``ModelWorker`` is constructed on the first page load only and then
    reused for every later session. Constructing it presumably loads the model
    weights, so rebuilding it per visitor is very expensive.

    Args:
        url_params: URL query parameters sent by the browser (only logged).
        request: Gradio request; used to log the client IP.

    Returns:
        ``(state, dropdown_update)``: a fresh copy of ``default_conversation``
        and an update that makes the model dropdown visible.
    """
    global worker
    logger.info(f"load_demo. ip: {request.client.host}. params: {url_params}")
    dropdown_update = gr.Dropdown(visible=True)
    # model_path / model_name / lora_path are module globals assigned in __main__.
    if globals().get("worker") is None:
        worker = ModelWorker(model_path, None, model_name, True, lora_path)
    state = default_conversation.copy()
    return state, dropdown_update
def vote_last_response(state, vote_type, model_selector, request: gr.Request):
    """Append one feedback record as a JSON line to today's conversation log.

    Args:
        state: current conversation; serialised with ``state.dict()``.
        vote_type: label stored under ``"type"`` (the callers in this file pass
            ``"upvote"``, ``"downvote"`` or ``"flag"``).
        model_selector: model name currently selected in the dropdown.
        request: Gradio request; the client IP is recorded.
    """
    with open(get_conv_log_filename(), "a") as log_file:
        record = dict(
            tstamp=round(time.time(), 4),
            type=vote_type,
            model=model_selector,
            state=state.dict(),
            ip=request.client.host,
        )
        print(json.dumps(record), file=log_file)
def upvote_last_response(state, model_selector, request: gr.Request):
    """Log an upvote, clear the textbox and disable the three feedback buttons.

    Returns updates for ``[textbox, upvote_btn, downvote_btn, flag_btn]``.
    """
    logger.info(f"upvote. ip: {request.client.host}")
    vote_last_response(state, "upvote", model_selector, request)
    return "", disable_btn, disable_btn, disable_btn
def downvote_last_response(state, model_selector, request: gr.Request):
    """Log a downvote, clear the textbox and disable the three feedback buttons.

    Returns updates for ``[textbox, upvote_btn, downvote_btn, flag_btn]``.
    """
    logger.info(f"downvote. ip: {request.client.host}")
    vote_last_response(state, "downvote", model_selector, request)
    return "", disable_btn, disable_btn, disable_btn
def flag_last_response(state, model_selector, request: gr.Request):
    """Log a flag report, clear the textbox and disable the three feedback buttons.

    Returns updates for ``[textbox, upvote_btn, downvote_btn, flag_btn]``.
    """
    logger.info(f"flag. ip: {request.client.host}")
    vote_last_response(state, "flag", model_selector, request)
    return "", disable_btn, disable_btn, disable_btn
def regenerate(state, image_process_mode, request: gr.Request):
    """Prepare *state* so the last assistant reply can be generated again.

    Clears the last reply. If the preceding user turn is an image tuple/list,
    it re-packs it as ``(text, image, image_process_mode)`` with the newly
    selected preprocessing mode. ``http_bot`` is chained after this handler.

    Returns:
        ``(state, chatbot, "", None, *5 disabled buttons)`` for the outputs
        ``[state, chatbot, textbox, imagebox] + btn_list``.
    """
    logger.info(f"regenerate. ip: {request.client.host}")
    state.messages[-1][-1] = None
    last_user_turn = state.messages[-2]
    user_content = last_user_turn[1]
    if type(user_content) in (tuple, list):
        # Keep the first two items (text, image); swap in the chosen mode.
        last_user_turn[1] = (*user_content[:2], image_process_mode)
    state.skip_next = False
    buttons = (disable_btn,) * 5
    return (state, state.to_gradio_chatbot(), "", None) + buttons
def clear_history(request: gr.Request):
    """Reset the session to a fresh copy of ``default_conversation``.

    Returns:
        ``(state, chatbot, "", None, *5 disabled buttons)`` for the outputs
        ``[state, chatbot, textbox, imagebox] + btn_list``.
    """
    logger.info(f"clear_history. ip: {request.client.host}")
    fresh = default_conversation.copy()
    return (fresh, fresh.to_gradio_chatbot(), "", None, *([disable_btn] * 5))
def add_text(state, text, image, image_process_mode, request: gr.Request):
    """Append the user's message (and optional image) to the conversation.

    Text is truncated to 1536 characters, or 1200 when an image is attached.
    With an image, an ``<image>`` placeholder is appended if the text lacks one.
    The user turn then becomes ``(text, image, image_process_mode)`` and the
    conversation restarts from ``default_conversation``. An empty submission
    (no text, no image) marks ``state.skip_next`` so ``http_bot`` does nothing.

    Returns:
        ``(state, chatbot, "", None, *5 button updates)`` for the outputs
        ``[state, chatbot, textbox, imagebox] + btn_list``.
    """
    logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
    if not text and image is None:
        state.skip_next = True
        return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5

    text = text[:1536]  # Hard cut-off
    if image is not None:
        text = text[:1200]  # Hard cut-off for images
        if '<image>' not in text:
            text += '\n<image>'
        message = (text, image, image_process_mode)
        # An image starts a new conversation.
        state = default_conversation.copy()
    else:
        message = text

    state.append_message(state.roles[0], message)
    state.append_message(state.roles[1], None)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
# Cursor appended to the partial reply while the answer is streaming.
_STREAM_CURSOR = "▌"


def _select_template_name(model_name):
    """Return the ``conv_templates`` key for *model_name* (case-insensitive match)."""
    name = model_name.lower()
    if "llava" in name:
        if "llama-2" in name:
            return "llava_llama_2"
        if "mistral" in name or "mixtral" in name:
            if "orca" in name:
                return "mistral_orca"
            if "hermes" in name:
                return "chatml_direct"
            return "mistral_instruct"
        if "llava-v1.6-34b" in name:
            return "chatml_direct"
        is_plain = "plain" in name and "finetune" not in name
        if "v1" in name:
            return "v1_mmtag" if "mmtag" in name or is_plain else "llava_v1"
        if "mpt" in name:
            return "mpt"
        return "v0_mmtag" if "mmtag" in name or is_plain else "llava_v0"
    # Text-only models; matched case-insensitively like the branches above.
    if "mpt" in name:
        return "mpt_text"
    if "llama-2" in name:
        return "llama_2"
    return "vicuna_v1"


def _save_serve_images(images, image_hashes):
    """Save each PIL image as ``LOGDIR/serve_images/YYYY-MM-DD/<hash>.jpg`` unless it already exists."""
    t = datetime.datetime.now()
    day_dir = os.path.join(LOGDIR, "serve_images", f"{t.year}-{t.month:02d}-{t.day:02d}")
    for image, image_hash in zip(images, image_hashes):
        filename = os.path.join(day_dir, f"{image_hash}.jpg")
        if not os.path.isfile(filename):
            os.makedirs(day_dir, exist_ok=True)
            image.save(filename)


def http_bot(state, model_selector, temperature, top_p, max_new_tokens, request: gr.Request):
    """Generate the reply to the latest user turn, streaming it into the chatbot.

    This is a generator used as a Gradio event handler. Each yielded tuple is
    ``(state, chatbot_messages, *5 button updates)``, matching the outputs
    ``[state, chatbot] + btn_list`` registered in ``build_demo``. Buttons are
    ordered upvote, downvote, flag, regenerate, clear.

    Args:
        state: conversation produced by ``add_text`` / ``regenerate``.
        model_selector: model name from the dropdown; selects the prompt template.
        temperature, top_p: sampling parameters forwarded to the worker.
        max_new_tokens: generation length, capped at 1536.
        request: Gradio request; the client IP is logged.

    Uses the module-level ``worker`` created by ``load_demo``. On success, a
    ``"chat"`` record is appended to today's conversation log.
    """
    logger.info(f"http_bot. ip: {request.client.host}")
    start_tstamp = time.time()
    model_name = model_selector

    if state.skip_next:
        # This generate call is skipped due to invalid inputs
        yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
        return

    if len(state.messages) == state.offset + 2:
        # First round of conversation: rebuild the state from the model's
        # template, carrying over only the user's message.
        new_state = conv_templates[_select_template_name(model_name)].copy()
        new_state.append_message(new_state.roles[0], state.messages[-2][1])
        new_state.append_message(new_state.roles[1], None)
        state = new_state

    # Construct prompt
    prompt = state.get_prompt()
    all_images = state.get_images(return_pil=True)
    all_image_hash = [hashlib.md5(image.tobytes()).hexdigest() for image in all_images]
    _save_serve_images(all_images, all_image_hash)

    # Make request; the logged payload lists image hashes instead of image data.
    pload = {
        "model": model_name,
        "prompt": prompt,
        "temperature": float(temperature),
        "top_p": float(top_p),
        "max_new_tokens": min(int(max_new_tokens), 1536),
        "stop": state.sep if state.sep_style in [SeparatorStyle.SINGLE, SeparatorStyle.MPT] else state.sep2,
        "images": f'List of {len(state.get_images())} images: {all_image_hash}',
    }
    logger.info(f"==== request ====\n{pload}")
    pload['images'] = state.get_images()

    state.messages[-1][-1] = _STREAM_CURSOR
    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5

    # Initialised so the final log line works even if the worker yields nothing.
    output = ""
    try:
        # Stream output
        for chunk in worker.generate_stream_gate(pload):
            if chunk:
                # NOTE(review): upstream LLaVA workers terminate each JSON message
                # with b"\0" (the HTTP client split on it). The worker is called
                # in-process here, so strip it; this is a no-op if it is absent.
                data = json.loads(chunk.decode().rstrip("\0"))
                if data["error_code"] == 0:
                    # The worker's text presumably echoes the prompt; drop it.
                    output = data["text"][len(prompt):].strip()
                    state.messages[-1][-1] = output + _STREAM_CURSOR
                    yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
                else:
                    output = data["text"] + f" (error_code: {data['error_code']})"
                    state.messages[-1][-1] = output
                    # Only regenerate and clear stay usable after an error.
                    yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
                    return
                time.sleep(0.03)
    except requests.exceptions.RequestException:
        state.messages[-1][-1] = server_error_msg
        yield (state, state.to_gradio_chatbot()) + (disable_btn, disable_btn, disable_btn, enable_btn, enable_btn)
        return

    # Remove the streaming cursor from the final reply.
    reply = state.messages[-1][-1]
    if reply.endswith(_STREAM_CURSOR):
        reply = reply[:-len(_STREAM_CURSOR)]
    state.messages[-1][-1] = reply
    yield (state, state.to_gradio_chatbot()) + (enable_btn,) * 5

    finish_tstamp = time.time()
    logger.info(f"{output}")

    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": model_name,
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": state.dict(),
            "images": all_image_hash,
            "ip": request.client.host,
        }
        fout.write(json.dumps(data) + "\n")
# Page header: paper title and links.
# TODO: replace the placeholder Project Page URL (https://XXXXX) with the real link.
title_markdown = ("""
# Dr-LLaVA: Visual Instruction Tuning with Symbolic Clinical Grounding
[[Project Page](https://XXXXX)] [[Code](https://github.com/AlaaLab/Dr-LLaVA)] | 📚 [[Dr-LLaVA](https://arxiv.org/abs/2405.19567)]
""")

# Usage disclaimer rendered below the main UI.
tos_markdown = ("""
This demo is intended for research purposes only and not for medical use.
The model has not been fine-tuned on non-medical images.
""")

# License / terms notice rendered at the bottom of the page.
learn_more_markdown = ("""
### License
The service is a research preview intended for non-commercial use only, subject to the model [License](https://github.com/facebookresearch/llama/blob/main/MODEL_CARD.md) of LLaMA, [Terms of Use](https://openai.com/policies/terms-of-use) of the data generated by OpenAI, and [Privacy Practices](https://chrome.google.com/webstore/detail/sharegpt-share-your-chatg/daiacboceoaocpibfodeljbdfacokfjb) of ShareGPT. Please contact us if you find any potential violation.
""")

# Gives the buttons inside the elem_id="buttons" row a minimum width.
block_css = """
#buttons button {
    min-width: min(120px,100%);
}
"""
def build_demo(cur_dir=None, concurrency_count=10):
    """Build the Gradio Blocks UI for the Dr-LLaVA demo.

    Args:
        cur_dir: directory containing the ``examples/`` images; defaults to the
            directory of this file.
        concurrency_count: ``concurrency_limit`` applied to every ``http_bot``
            (generation) event.

    Returns:
        The configured, not yet launched ``gr.Blocks`` app.

    Reads the module-level ``models`` list (assigned in ``__main__``) to
    populate the model dropdown.
    """
    # Created outside the Blocks so gr.Examples can reference it; rendered in the chat column.
    textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER", container=False)
    with gr.Blocks(title="LLaVA", theme=gr.themes.Default(), css=block_css) as demo:
        state = gr.State()
        gr.Markdown(title_markdown)

        with gr.Row():
            with gr.Column(scale=2):
                # Authors, demo credits and usage instructions.
                gr.Markdown("""Shenghuan Sun, Gregory Goldgof, Alex Schubert, Zhiqing Sun, Atul Butte, Ahmed Alaa

Demo Creator: [David W. Day](https://github.com/daviddaytw)

This is the demo for Dr-LLaVA: a conversational vision-language model for diagnosing blood cancer using Bone Marrow Aspirate images.

**Instructions:**
- Drop a single image from a bone marrow aspirate whole slide image taken at 40x.
""")
                # Overview figure, loaded from a remote URL.
                gr.Image(value="https://davidday.tw/wp-content/uploads/2024/08/Dr-LLa-VA-Fig-1.jpg",
                         width=600, interactive=False, type="pil")

            with gr.Column(scale=3):
                with gr.Row(elem_id="model_selector_row"):
                    model_selector = gr.Dropdown(
                        choices=models,
                        value=models[0] if models else "",
                        interactive=True,
                        show_label=False,
                        container=False)

                imagebox = gr.Image(type="pil")
                # Hidden, but add_text() and regenerate() still read its value.
                image_process_mode = gr.Radio(
                    ["Crop", "Resize", "Pad", "Default"],
                    value="Default",
                    label="Preprocess for non-square image", visible=False)

                if cur_dir is None:
                    cur_dir = os.path.dirname(os.path.abspath(__file__))
                gr.Examples(examples=[
                    [f"{cur_dir}/examples/example1.jpeg", "Can you assess if these pathology images are suitable for identifying cancer upon inspection?"],
                    [f"{cur_dir}/examples/example2.jpeg", "Are you able to recognize the probable illness in the image patch?"],
                ], inputs=[imagebox, textbox])

                with gr.Accordion("Parameters", open=False):
                    temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, step=0.1, interactive=True, label="Temperature",)
                    top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, interactive=True, label="Top P",)
                    max_output_tokens = gr.Slider(minimum=0, maximum=1024, value=512, step=64, interactive=True, label="Max output tokens",)

            with gr.Column(scale=6):
                chatbot = gr.Chatbot(
                    elem_id="chatbot",
                    label="LLaVA Chatbot",
                    height=470,
                    layout="panel",
                )
                with gr.Row():
                    with gr.Column(scale=8):
                        textbox.render()
                    with gr.Column(scale=1, min_width=50):
                        submit_btn = gr.Button(value="Send", variant="primary")
                with gr.Row(elem_id="buttons"):
                    upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
                    downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
                    flag_btn = gr.Button(value="⚠️ Flag", interactive=False)
                    regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=False)
                    clear_btn = gr.Button(value="🗑️ Clear", interactive=False)

        gr.Markdown(tos_markdown)
        gr.Markdown(learn_more_markdown)
        # Receives the URL query parameters from get_window_url_params on load.
        url_params = gr.JSON(visible=False)

        # Register listeners. Button order matches the 5 button updates the handlers return.
        btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
        upvote_btn.click(
            upvote_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn, flag_btn]
        )
        downvote_btn.click(
            downvote_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn, flag_btn]
        )
        flag_btn.click(
            flag_last_response,
            [state, model_selector],
            [textbox, upvote_btn, downvote_btn, flag_btn]
        )

        regenerate_btn.click(
            regenerate,
            [state, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list
        ).then(
            http_bot,
            [state, model_selector, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list,
            concurrency_limit=concurrency_count
        )

        clear_btn.click(
            clear_history,
            None,
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        )

        textbox.submit(
            add_text,
            [state, textbox, imagebox, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list,
            queue=False
        ).then(
            http_bot,
            [state, model_selector, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list,
            concurrency_limit=concurrency_count
        )

        submit_btn.click(
            add_text,
            [state, textbox, imagebox, image_process_mode],
            [state, chatbot, textbox, imagebox] + btn_list
        ).then(
            http_bot,
            [state, model_selector, temperature, top_p, max_output_tokens],
            [state, chatbot] + btn_list,
            concurrency_limit=concurrency_count
        )

        demo.load(
            load_demo,
            [url_params],
            [state, model_selector],
            js=get_window_url_params
        )

    return demo
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int)
    parser.add_argument("--concurrency-count", type=int, default=16)
    parser.add_argument("--share", action="store_true")
    args = parser.parse_args()
    logger.info(f"args: {args}")

    # Model configuration; read as module globals by load_demo() and build_demo().
    model_name = 'llava-rlhf-13b-v1.5-336'
    models = [model_name]
    model_path, lora_path = 'daviddaytw/Dr-LLaVA-sft', 'daviddaytw/Dr-LLaVA-lora-adapter'

    demo = build_demo(concurrency_count=args.concurrency_count)
    launch_kwargs = dict(server_name=args.host, server_port=args.port, share=args.share)
    demo.queue(api_open=False).launch(**launch_kwargs)