Spaces:
Paused
Paused
| import os | |
| from gradio.themes import ThemeClass as Theme | |
| import numpy as np | |
| import argparse | |
| import gradio as gr | |
| from typing import Any, Iterator | |
| from typing import Iterator, List, Optional, Tuple | |
| import filelock | |
| import glob | |
| import json | |
| import time | |
| from gradio.routes import Request | |
| from gradio.utils import SyncToAsyncIterator, async_iteration | |
| from gradio.helpers import special_args | |
| import anyio | |
| from typing import AsyncGenerator, Callable, Literal, Union, cast, Generator | |
| from gradio_client.documentation import document, set_documentation_group | |
| from gradio.components import Button, Component | |
| from gradio.events import Dependency, EventListenerMethod | |
| from typing import List, Optional, Union, Dict, Tuple | |
| from tqdm.auto import tqdm | |
| from huggingface_hub import snapshot_download | |
| from gradio.components.base import Component | |
| from .base_demo import register_demo, get_demo_class, BaseDemo | |
| from .chat_interface import ( | |
| SYSTEM_PROMPT, | |
| MODEL_NAME, | |
| MAX_TOKENS, | |
| TEMPERATURE, | |
| CHAT_EXAMPLES, | |
| gradio_history_to_openai_conversations, | |
| gradio_history_to_conversation_prompt, | |
| DATETIME_FORMAT, | |
| get_datetime_string, | |
| chat_response_stream_multiturn_engine, | |
| ChatInterfaceDemo, | |
| CustomizedChatInterface, | |
| ) | |
| from gradio.events import Events | |
| import inspect | |
| from typing import AsyncGenerator, Callable, Literal, Union, cast | |
| import anyio | |
| from gradio_client import utils as client_utils | |
| from gradio_client.documentation import document | |
| from gradio.blocks import Blocks | |
| from gradio.components import ( | |
| Button, | |
| Chatbot, | |
| Component, | |
| Markdown, | |
| State, | |
| Textbox, | |
| get_component_instance, | |
| ) | |
| from gradio.events import Dependency, on | |
| from gradio.helpers import create_examples as Examples # noqa: N812 | |
| from gradio.helpers import special_args | |
| from gradio.layouts import Accordion, Group, Row | |
| from gradio.routes import Request | |
| from gradio.themes import ThemeClass as Theme | |
| from gradio.utils import SyncToAsyncIterator, async_iteration | |
| from ..globals import MODEL_ENGINE | |
| from ..configs import ( | |
| USE_PANEL, | |
| IMAGE_TOKEN, | |
| IMAGE_TOKEN_INTERACTIVE, | |
| CHATBOT_HEIGHT, | |
| ALLOWED_PATHS, | |
| ) | |
| from .multimodal_chat_interface import ( | |
| DOC_INSTRUCTION, | |
| DOC_TEMPLATE, | |
| CSS, | |
| undo_history, | |
| undo_history_until_last_assistant_turn, | |
| MultiModalChatInterface, | |
| gradio_history_to_conversation_prompt, | |
| gradio_history_to_openai_conversations, | |
| gradio_history_to_vision_conversation_prompt_paths, | |
| gradio_history_to_doc_conversation_prompt, | |
| gradio_history_to_vision_doc_conversation_prompt_paths, | |
| VisionChatInterfaceDemo, | |
| vision_chat_response_stream_multiturn_engine, | |
| ) | |
| import glob | |
| from pathlib import Path | |
| from gradio import utils as gradio_utils | |
# Storage/config locations, each overridable through an environment variable.
PREF_DIR = os.environ.get("PREF_DIR", "./tmp")
PREFERENCE_MAKE_DATA_PATH = os.environ.get("PREFERENCE_MAKE_DATA_PATH", "assets/example_pref.json")
IMAGE_DIR = os.environ.get("IMAGE_DIR", "./tmp_image")
# Every entry directly under IMAGE_DIR, offered in the UI as example images.
EXAMPLE_IMAGE_PATHS = list(glob.glob(os.path.join(IMAGE_DIR, "*")))
print(f'IMAGES: {EXAMPLE_IMAGE_PATHS[:3]=}')
# ! Existing images
# Root directory that the per-category glob patterns below are resolved against.
IMAGE_GLOB_ROOT = "/mnt/workspace/workgroup/phi/raw_data/multimodal_seallm/processed/sft/dpo_examples"
# ALLOWED_PATHS.append(IMAGE_GLOB_ROOT)
# Category name -> [glob pattern relative to IMAGE_GLOB_ROOT, guidance text shown to annotators].
IMAGE_GLOBS = {
    # "geometry": "geo3k/train/*/img_diagram.png",
    "Geometry": ["geoqa_plus/*png", "Ask question about to solve the puzzle, calculating angles, find values, ... Provide extra information in the question (e.g 'Angle 1 = 30 degrees, find angle 2 from image.')"],
    "Everyday": ["gqa/images/*", "Ask question to (1) describe, (2) find details, (3) negation (e.g 'Where's the cat?' while there is no cat in image.), (4) write stories ...."],
    "OCR (read text)": ["ocr_vqa/images/*", "Ask question (1) full OCR description, (2) read specific details (e.g 'Who wrote the book?')."],
    "OpenViVQA": ["OpenViVQA/training-images/*", "Only vietnamese, (1) full OCR description, (2) read specific details, (3) image description and question answering"],
    "Text-VQA": ["textvqa/train_images/*", "Ask question to (1) describe, (2) find details, (3) negation (e.g 'Where's the cat?' while there is no cat in image.), (4) write stories, (5) reasoning"],
    "Landmarks": ["web-landmark/images/*", "Ask question to (1) Where is landmarks (2) What to do at that place (3) Write stories, (4) give advise for tourists..."],
    "Everyday-VG2": ["vg/VG_100K_2/*", "Same with Everyday"],
}
# Each category keeps glob matches with index in [IMAGE_CUT_OFF_BEGIN, IMAGE_CUT_OFF_BEGIN + IMAGE_CUT_OFF).
IMAGE_CUT_OFF_BEGIN = 0
IMAGE_CUT_OFF = 100
# IMAGE_CUT_OFF = 20
IMAGE_GLOB_PATHS = {}
IMAGE_GLOB_DESCS = {}
for category, (pattern, guidance) in IMAGE_GLOBS.items():
    matched = glob.glob(os.path.join(IMAGE_GLOB_ROOT, pattern))
    IMAGE_GLOB_PATHS[category] = matched[IMAGE_CUT_OFF_BEGIN:IMAGE_CUT_OFF_BEGIN + IMAGE_CUT_OFF]
    IMAGE_GLOB_DESCS[category] = guidance
print(IMAGE_GLOB_PATHS['Geometry'][:10])
def read_json(json_file):
    """Parse and return the contents of the UTF-8 JSON file at *json_file*.

    The path is printed first as a simple progress log.
    """
    print(f'Reading : {json_file}')
    with open(json_file, encoding='utf-8') as handle:
        return json.load(handle)
def write_json(data, json_file):
    """Write *data* to *json_file* as 4-space-indented UTF-8 JSON.

    Non-ASCII characters are written as-is rather than ``\\u`` escaped; any
    existing file is overwritten.
    """
    dump_options = dict(indent=4, ensure_ascii=False)
    with open(json_file, mode='w', encoding='utf-8') as out:
        json.dump(data, out, **dump_options)
def convert_pref_data_to_openai_format(rows_dict):
    """Normalize preference rows, in place, to OpenAI-style message lists.

    Rows that already contain ``conversation_prefix`` must also contain
    ``responses`` and are otherwise left as-is. Legacy rows carrying a
    ``history`` list of ``[user, assistant]`` pairs are converted into:

    * ``conversation_prefix``: every stripped message except the final one;
    * ``responses``: a one-element list holding the final assistant message;
    * ``original_response``: that same final assistant message.

    Every row (legacy or pre-converted) that lacks ``lang`` gets one derived
    from the last two characters of its dict key.

    Returns:
        ``(rows_dict, lang_set)`` -- the same dict, mutated in place, and the
        distinct ``lang`` values in first-seen order.
    """
    for key, row in rows_dict.items():
        if "conversation_prefix" in row:
            assert "responses" in row, f'invalid: {row}'
        else:
            conversations = []
            for user, assistant in row['history']:
                conversations.append({"role": "user", "content": user.strip()})
                conversations.append({"role": "assistant", "content": assistant.strip()})
            row['conversation_prefix'] = conversations[:-1]
            # NOTE: responses holds only the original reply here
            # (original note: "missing an item in responses").
            row['responses'] = [conversations[-1]]
            row['original_response'] = conversations[-1]
        # Applied to pre-converted rows too; skipping them made the lang
        # collection below raise KeyError for rows without "lang".
        if "lang" not in row:
            row['lang'] = key[-2:]
    # dict.fromkeys de-duplicates while preserving first-seen order, so the
    # result is deterministic (set() iteration order is not).
    lang_set = list(dict.fromkeys(row['lang'] for row in rows_dict.values()))
    return rows_dict, lang_set
def convert_mm_pref_data_to_openai_format(rows_dict):
    """Placeholder for a multimodal variant of the preference-data converter.

    Not implemented: *rows_dict* is ignored and ``None`` is returned.
    """
    return None
# Normalized preference rows from PREFERENCE_MAKE_DATA_PATH; stays None when the file is absent.
PREFERENCE_RATE_DICT = None
# Default language codes; languages found in the loaded data are appended after these.
LANG_SET = ["en", "vi", "id", "ms", "th", "zh", "lo", "km", "tl", "my"]
if PREFERENCE_MAKE_DATA_PATH is not None and os.path.exists(PREFERENCE_MAKE_DATA_PATH):
    print(f'Loading {PREFERENCE_MAKE_DATA_PATH}')
    PREFERENCE_RATE_DICT = read_json(PREFERENCE_MAKE_DATA_PATH)
    PREFERENCE_RATE_DICT, _LANG_SET = convert_pref_data_to_openai_format(PREFERENCE_RATE_DICT)
    LANG_SET = LANG_SET + [lang for lang in _LANG_SET if lang not in LANG_SET]
class CustomJsonlLogger(gr.FlaggingCallback):
    """Flagging callback that appends one JSON object per flag to ``log.jsonl``.

    Each line maps every component's label (or ``"component {idx}"`` when the
    component has no label) to the value produced for it, plus ``flag``,
    ``username`` and ``timestamp`` entries.
    """

    def __init__(self):
        # Running count of lines in log.jsonl; refreshed from disk in setup().
        self.num_lines = 0

    def setup(
        self,
        components: list[Component],
        flagging_dir: Union[str, Path],
    ):
        """Remember the components to log and prepare *flagging_dir*.

        Creates the directory if needed and counts the lines already present
        in ``log.jsonl`` so ``flag`` keeps returning a cumulative total.
        """
        self.components = components
        self.flagging_dir = flagging_dir
        os.makedirs(flagging_dir, exist_ok=True)
        log_filepath = Path(flagging_dir) / "log.jsonl"
        if log_filepath.exists():
            with open(log_filepath, "rb") as f:
                self.num_lines = sum(1 for _ in f)
        else:
            self.num_lines = 0

    def flag(
        self,
        flag_data: list[Any],
        flag_option: str = "",
        username: Union[str, None] = None,
    ) -> int:
        """Append one JSON line for *flag_data* and return the new line count.

        *flag_data* is paired positionally with the components given to
        ``setup``. Update objects are stored as ``str(sample)``, ``None``
        samples as ``None``, and anything else as the return value of
        ``component.flag(sample, flag_dir=...)``.
        """
        import datetime
        flagging_dir = self.flagging_dir
        log_filepath = Path(flagging_dir) / "log.jsonl"
        json_obj = {}
        for idx, (component, sample) in enumerate(zip(self.components, flag_data)):
            label = getattr(component, "label", None) or f"component {idx}"
            # Per-component folder handed to component.flag for any files it writes.
            save_dir = Path(flagging_dir) / client_utils.strip_invalid_filename_characters(label)
            if gradio_utils.is_update(sample):
                value = str(sample)
            else:
                # Call component.flag exactly once: it may write files into save_dir.
                value = component.flag(sample, flag_dir=save_dir) if sample is not None else None
            json_obj[label] = value
        json_obj['flag'] = flag_option
        json_obj['username'] = username if username is not None else ""
        json_obj['timestamp'] = str(datetime.datetime.now())
        with open(log_filepath, "a", encoding="utf-8") as jsonl_file:
            jsonl_file.write(json.dumps(json_obj, ensure_ascii=False) + "\n")
        self.num_lines += 1
        return self.num_lines
class VisionJsonlLogger(CustomJsonlLogger):
    """JSONL flagging callback that also copies chat images beside the log.

    List-valued flag results (chat histories) are split by
    ``gradio_history_to_vision_conversations_paths`` into conversations and
    image paths. Each image is copied into ``<flagging_dir>/images``, and the
    copied paths are stored under ``"<label>-images"``.
    """

    # ! must save the image
    def flag(
        self,
        flag_data: list[Any],
        flag_option: str = "",
        username: Union[str, None] = None,
    ) -> int:
        """Append one JSON line for *flag_data*, copying any chat images.

        Returns the cumulative number of lines written to ``log.jsonl``.
        """
        import datetime
        from shutil import copyfile
        flagging_dir = self.flagging_dir
        log_filepath = Path(flagging_dir) / "log.jsonl"
        image_dir = Path(flagging_dir) / "images"
        os.makedirs(image_dir, exist_ok=True)
        json_obj = {}
        for idx, (component, sample) in enumerate(zip(self.components, flag_data)):
            label = getattr(component, "label", None) or f"component {idx}"
            save_dir = Path(flagging_dir) / client_utils.strip_invalid_filename_characters(label)
            if gradio_utils.is_update(sample):
                value = str(sample)
            else:
                # Call component.flag exactly once: it may write files into save_dir.
                value = component.flag(sample, flag_dir=save_dir) if sample is not None else None
            if isinstance(value, list):
                # Expecting history
                from .multimodal_chat_interface import gradio_history_to_vision_conversations_paths
                conversations, image_paths = gradio_history_to_vision_conversations_paths(value)
                # NOTE(review): two images with the same basename copied within
                # one clock tick would map to the same destination file.
                new_paths = [
                    os.path.join(image_dir, str(datetime.datetime.now()) + os.path.basename(p))
                    for p in image_paths
                ]
                # `dst`/`src` rather than `np`: avoid shadowing the numpy alias.
                for dst, src in zip(new_paths, image_paths):
                    copyfile(src, dst)
                json_obj[label] = conversations
                json_obj[label + "-images"] = new_paths
            else:
                json_obj[label] = value
        json_obj['flag'] = flag_option
        json_obj['username'] = username if username is not None else ""
        json_obj['timestamp'] = str(datetime.datetime.now())
        with open(log_filepath, "a", encoding="utf-8") as jsonl_file:
            jsonl_file.write(json.dumps(json_obj, ensure_ascii=False) + "\n")
        self.num_lines += 1
        return self.num_lines
def get_preference_radio():
    """Create the radio input annotators use to record which bot answered better."""
    options = ['1 Better', '2 Better', 'Add best', 'dirty/undecided']
    guidance = "Indicate if 1 or 2 is better. If both not excellent, pick 'Add best' and write the better one below. If question or answer is problematic, cannot decide, then choose dirty/undecided."
    return gr.Radio(options, label='preference', info=guidance)
def vision_submit_vision_response_stream_multiturn_engine_yhistory(
    message: str,
    input_image: str,
    history: List[List[str]],
    temperature: float,
    max_tokens: int,
    system_prompt: Optional[str] = SYSTEM_PROMPT,
    image_token: Optional[str] = IMAGE_TOKEN,
):
    """Append the user's text and/or image to *history* and stream the reply.

    The image (if the path exists) is added as its own user turn
    ``[(input_image,), None]``; non-empty text is added as ``[message, None]``.
    An image-only submission is allowed. Text is required only when no image
    is attached; otherwise a warning is shown and *history* is yielded
    unchanged.

    Yields successive chat histories: first the history with the new
    user turn(s), then one history per streamed chunk from
    ``vision_chat_response_stream_multiturn_engine`` with the reply attached
    to the last user turn.
    """
    # Tolerate None from the textbox as well as whitespace-only input.
    message = (message or "").strip()
    has_image = input_image is not None and os.path.exists(input_image)
    if not has_image and message == "":
        # ! message cannot be empty if there is no input_image
        gr.Warning('Input text cannot be empty!')
        yield history
        return
    new_history = history
    if has_image:
        new_history = new_history + [[(input_image,), None]]
    if message != "":
        new_history = new_history + [[message, None]]
    # ! yield current history
    yield new_history
    # Keep the last user turn's content intact: for image-only submissions it
    # is the image tuple, which must not be replaced by the empty message.
    last_user_turn = new_history[-1][0]
    response = None
    for response, _num_tokens in vision_chat_response_stream_multiturn_engine(
        history=new_history,
        temperature=temperature, max_tokens=max_tokens, system_prompt=system_prompt,
        image_token=image_token,
    ):
        yield new_history[:-1] + [[last_user_turn, response]]
    if response is not None:
        yield new_history[:-1] + [[last_user_turn, response]]
def vision_submit_2_histories(
    message: str,
    input_image: str,
    history1: List[List[str]],
    history2: List[List[str]],
    temperature: float,
    max_tokens: int,
    system_prompt: Optional[str] = SYSTEM_PROMPT,
    image_token: Optional[str] = IMAGE_TOKEN,
):
    """Send the same message/image to two chat histories, one after the other.

    Bot 1 is streamed to completion first while bot 2 is held at *history2*.
    Then bot 2 is streamed while bot 1 stays at its final state. Each step
    yields ``(history1_state, history2_state)``.
    """
    latest_1, latest_2 = history1, history2
    for latest_1 in vision_submit_vision_response_stream_multiturn_engine_yhistory(
        message, input_image, history1, temperature, max_tokens, system_prompt, image_token,
    ):
        yield latest_1, latest_2
    for latest_2 in vision_submit_vision_response_stream_multiturn_engine_yhistory(
        message, input_image, history2, temperature, max_tokens, system_prompt, image_token,
    ):
        yield latest_1, latest_2
def undo_history_until_last_assistant_turn_message(history):
    """Undo one step, then keep undoing while the last turn has no reply.

    Returns the trimmed history twice, one copy per output component.
    """
    trimmed = undo_history(history)
    while trimmed and trimmed[-1][-1] is None:
        trimmed = undo_history(trimmed)
    return trimmed, trimmed
def replace_last_response(input_text: str, history: List[Tuple[str, str]]):
    """Replace the last reply in *history* with *input_text*.

    Returns ``(textbox_value, history)``:

    * empty/whitespace text -> warns and returns ``("", history)`` unchanged;
    * empty history -> warns and returns ``(stripped_text, history)`` so the
      typed text is kept in the textbox;
    * otherwise -> ``("", new_history)``, where the final turn's last element
      is the stripped text.

    A new list is built instead of assigning into ``history[-1]``. The caller's
    history is therefore not mutated, and tuple turns (as in the type hint)
    work.
    """
    input_text = (input_text or "").strip()
    if input_text == "":
        gr.Warning('prompt empty! dont send empty prompt')
        return "", history
    if not history:
        gr.Warning('History empty, cannot replace')
        return input_text, history
    *earlier_turns, last_turn = history
    return "", [*earlier_turns, [*last_turn[:-1], input_text]]
| # def load_image_from_gallery(selected_state: gr.SelectData): | |
| # convo = sft_data_list[selected_state.index] | |
| # dirname = sft_dirname | |
| # image_path = os.path.join(dirname, convo['image']) | |
| # return image_path | |
def load_image_from_gallery(data_list, selected_state: gr.SelectData):
    """Return the item of *data_list* at the gallery index the user clicked."""
    # dirname = sft_dirname
    # image_path = os.path.join(dirname, convo['image'])
    clicked_index = selected_state.index
    return data_list[clicked_index]
class VisionLivePreferencePickDemo(VisionChatInterfaceDemo):
    """Tab where annotators build two vision conversations live and pick the better one.

    Both chatbots are driven by ``vision_chat_response_stream_multiturn_engine``.
    Recorded choices are written by ``VisionJsonlLogger`` under
    ``<PREF_DIR>/live_preference_pick``.
    """

    def examples(self):
        """Return example ``[prompt, image_path]`` pairs (the image may be None)."""
        return [
            ["What's strange about this image?", "assets/dog_monalisa.jpeg",],
            ["Explain why the sky is blue.", None,],
        ]

    def tab_name(self):
        """Return the label of this demo's tab."""
        return "Vision Live Preference"

    def create_demo(
        self,
        title: str | None = None,
        description: str | None = None,
        **kwargs
    ) -> gr.Blocks:
        """Build the live preference-picking ``gr.Blocks`` UI.

        ``title`` is unused and ``description`` is replaced by built-in text.
        Recognized kwargs: ``system_prompt``, ``max_tokens``, ``temperature``,
        ``model_name``; each defaults to the corresponding chat_interface
        constant.
        """
        system_prompt = kwargs.get("system_prompt", SYSTEM_PROMPT)
        max_tokens = kwargs.get("max_tokens", MAX_TOKENS)
        temperature = kwargs.get("temperature", TEMPERATURE)
        model_name = kwargs.get("model_name", MODEL_NAME)
        log_folder = os.path.join(PREF_DIR, "live_preference_pick")
        description = f"""
## Live generation preference picking
Live generation is similar to the Preference Picking demo, except that linguists can come up with questions/prompts **on their own** instead of pre-existing data.
PREF_DIR: {log_folder}
"""
        # Button names below must match the labels of the buttons created further down.
        instruction_content = f"""
### Tasks
You are enabled to freely build 2 different conversations using the model and pick the better conversations.
You can also create best responses if model's generated ones are not good.
### Requirements
The 2 conversations must share at least the first user query. Other than that, the length, number of turns, user queries (except the first one) can vary.
For example:
```
# Valid conversation pairs
"User: Hello, 1+1=?" -> "Bot: 1+1=2" -> "User: what about 123+13?" -> "Bot: 123+13=136"
-> "Bot: I dont know"
"User: Hello, 1+1=?" -> "Bot: 1+1=2" -> "User: what about 123+13?" -> "Bot: 123+13=136"
-> "Bot: 1+1=3" -> "User: that's wrong!" -> "Bot: Im sorry man."
```
```
# Invalid pairs:
"User: Hello, 1+1=?" -> "Bot: 1+1=2"
"User: Tell me a joke" -> "Bot: here is the joke for your..."
```
### Steps to proceed:
There are multiple buttons:
* `Send both`: Send the text prompt (and image, if any) to both chatboxes, expect different (or same) answers.
* `Clear`: Clear both chatboxes.
The following numbered buttons (1 or 2) apply only to Bot-1 or Bot-2 respectively.
* `Send-1`: Send the text prompt (and image, if any) to only one chatbot (1 or 2).
* `Undo-1`: Undo the last generation (both last response and query)
* `Replace-1`: Replace the last response with the text in the input box (in case the last response is incorrect, unsatisfactory)
Finally, pick a preference and click `Record Choice` to save both conversations.
"""
        callback = VisionJsonlLogger()
        with gr.Blocks(css=CSS) as pdemo:
            gr.Markdown(description)
            with gr.Accordion(label="Instructions and Guidelines", open=False):
                gr.Markdown(instruction_content)
            with gr.Accordion(label="Additional input", open=False):
                temp = gr.Number(value=temperature, label='Temperature', info="Higher -> more random")
                length = gr.Number(value=max_tokens, label='Max tokens', info='Increase if want more generation')
                # freq_pen = gr.Number(value=frequence_penalty, label='Frequency penalty', info='> 0 encourage new tokens over repeated tokens')
                # pres_pen = gr.Number(value=presence_penalty, label='Presence penalty', info='> 0 encourage new tokens, < 0 encourage existing tokens')
                # stop_strings = gr.Textbox(value="<s>,</s>,<|im_start|>", label='Stop strings', info='Comma-separated string to stop generation.', lines=1)
                system_prompt = gr.Textbox(value=system_prompt, label='system_prompt', lines=1)
            with gr.Row():
                chatbot_1 = gr.Chatbot(
                    [],
                    label="Bot-1",
                    elem_id="chatbot-1",
                    bubble_full_width=False,
                    latex_delimiters=[
                        # { "left": "$", "right": "$", "display": False},
                        { "left": "$$", "right": "$$", "display": True},
                    ],
                    show_copy_button=True,
                    layout="panel" if USE_PANEL else "bubble",
                    height=CHATBOT_HEIGHT,
                )
                chatbot_2 = gr.Chatbot(
                    [],
                    label="Bot-2",
                    elem_id="chatbot-2",
                    bubble_full_width=False,
                    latex_delimiters=[
                        # { "left": "$", "right": "$", "display": False},
                        { "left": "$$", "right": "$$", "display": True},
                    ],
                    show_copy_button=True,
                    layout="panel" if USE_PANEL else "bubble",
                    height=CHATBOT_HEIGHT,
                )
            with gr.Row():
                input_text = gr.Textbox(
                    scale=6,
                    lines=12,
                    # lines=4,
                    max_lines=40,
                    show_label=False,
                    placeholder="Enter text and press enter, or upload an image",
                    container=False,
                )
                # submit will submit the same input text to both responses
                input_image = gr.Image(
                    label="input_image", type="filepath", scale=3,
                    # height=250,
                )
            with gr.Row():
                gen_submit = gr.Button('Send both', scale=1, variant='primary')
                # regenerate should not care about input_text, it just undo the previous history
                # regen_submit = gr.Button('Regenerate', scale=1)
                clear_btn = gr.Button('Clear', scale=1)
            # submit
            with gr.Row():
                chat1_submit = gr.Button('Send-1', variant='primary')
                chat1_undo = gr.Button('Undo-1')
                # chat1_regenerate = gr.Button('Regen-1')
                chat1_replace = gr.Button('Replace-1')
                chat2_submit = gr.Button('Send-2', variant='primary')
                chat2_undo = gr.Button('Undo-2')
                # chat2_regenerate = gr.Button('Regen-2')
                chat2_replace = gr.Button('Replace-2')
            gr.Markdown(f'**Do not click `Record Choice` twice with the same data sample!**')
            with gr.Row():
                pref_choice = get_preference_radio()
            # with gr.Row():
            #     text_replace = gr.Textbox(
            #         placeholder="If both responses are not good, write a better response here. Only apply to the last response.",
            #         lines=2,
            #         max_lines=30,
            #         scale=6,
            #         label="best_response"
            #     )
            submit_choice_btn = gr.Button('Record Choice', variant='secondary')
            with gr.Row():
                gr.Examples(
                    label="Random images",
                    examples=[[x] for x in EXAMPLE_IMAGE_PATHS],
                    inputs=input_image,
                    cache_examples=False,
                    examples_per_page=100,
                )
            # One gallery per image category; clicking an image loads it into input_image.
            for k, plist in IMAGE_GLOB_PATHS.items():
                print(f'{k}: {plist[:5]}')
                gr.Markdown(f"{k}: {IMAGE_GLOB_DESCS[k]}")
                gallery = gr.Gallery(
                    label=k,
                    value=plist,
                    allow_preview=False,
                    columns=10,
                    # rows=2,
                    height=250,
                )

                def _load_image_from_gallery(selected_state: gr.SelectData):
                    """Return the file path of the gallery image that was clicked."""
                    image_path = selected_state.value['image']['path']
                    print(f'Select: {image_path}')
                    return image_path

                gallery.select(
                    _load_image_from_gallery,
                    # lambda select: plist[select.index],
                    # inputs=,
                    outputs=[input_image],
                    queue=False
                )

            # ! events for submit choices
            visual_feedback = True

            def flag_method(request: gr.Request, *args):
                """Save ``(chatbot_1, chatbot_2, pref_choice)`` and restore the button.

                The preceding "Saving..." step disables the button. Both the
                success and error paths re-enable it, so a failed save can be
                retried.
                """
                # ! must save the image somewhere
                try:
                    callback.flag(args)
                except Exception as e:
                    print(f"Error while flagging: {e}")
                    if visual_feedback:
                        return gr.Button(value="Error!", interactive=True, variant='secondary')
                if not visual_feedback:
                    return
                gr.Info(f'Saving preference successful ({args[0]})')
                time.sleep(1)  # to provide enough time for the user to observe button change
                return gr.Button(value="Record Choice", interactive=True, variant='secondary')

            callback.setup([chatbot_1, chatbot_2, pref_choice], log_folder)
            # Chain with .then so the "Saving..." state is always applied before
            # flag_method runs; two independent listeners on one click have no ordering.
            submit_choice_btn.click(
                lambda: gr.Button(value="Saving...", interactive=False, variant='stop'),
                None,
                submit_choice_btn,
                queue=False,
                api_name=False,
            ).then(
                flag_method, [chatbot_1, chatbot_2, pref_choice], submit_choice_btn,
                preprocess=False, queue=False, api_name=False,
            )

            # ! button events
            generate_sub_events_both = [input_text.submit, gen_submit.click]
            on(
                generate_sub_events_both,
                vision_submit_2_histories,
                [
                    input_text, input_image, chatbot_1, chatbot_2,
                    temp, length, system_prompt
                ],
                [chatbot_1, chatbot_2],
                api_name=False,
                queue=True,
            ).then(
                lambda mes, img: ("", None),
                [input_text, input_image],
                [input_text, input_image],
                api_name=False,
                queue=False,
            )
            clear_btn.click(
                lambda c1, c2, txt, img: ([], [], "", None),
                [chatbot_1, chatbot_2, input_text, input_image],
                [chatbot_1, chatbot_2, input_text, input_image],
                api_name=False,
                queue=True,
            )
            chat1_submit.click(
                vision_submit_vision_response_stream_multiturn_engine_yhistory,
                [
                    input_text, input_image, chatbot_1,
                    temp, length, system_prompt,
                ],
                [chatbot_1],
                api_name=False,
                queue=True,
            ).then(
                lambda mes, img: ("", None),
                [input_text, input_image],
                [input_text, input_image],
                api_name=False,
                queue=False,
            )
            chat2_submit.click(
                vision_submit_vision_response_stream_multiturn_engine_yhistory,
                [
                    input_text, input_image, chatbot_2,
                    temp, length, system_prompt,
                ],
                [chatbot_2],
                api_name=False,
                queue=True,
            ).then(
                lambda mes, img: ("", None),
                [input_text, input_image],
                [input_text, input_image],
                api_name=False,
                queue=False,
            )
            chat1_undo.click(
                undo_history_until_last_assistant_turn,
                chatbot_1,
                [chatbot_1, input_text],
                api_name=False,
                queue=True,
            )
            chat2_undo.click(
                undo_history_until_last_assistant_turn,
                chatbot_2,
                [chatbot_2, input_text],
                api_name=False,
                queue=True,
            )
            chat1_replace.click(
                replace_last_response,
                [input_text, chatbot_1],
                [input_text, chatbot_1],
                api_name=False,
                queue=True,
            )
            chat2_replace.click(
                replace_last_response,
                [input_text, chatbot_2],
                [input_text, chatbot_2],
                api_name=False,
                queue=True,
            )
        return pdemo