Spaces:
Running
on
Zero
Running
on
Zero
| """Demo of the IBM Granite Guardian model.""" | |
| import json | |
| import os | |
| from pathlib import Path | |
| from typing import Literal | |
| import gradio as gr | |
| from gradio_modal import Modal | |
| from logger import logger | |
| from model import get_guardian_config, get_guardian_response, get_prompt | |
| from themes.research_monochrome import theme as carbon_theme | |
| from utils import ( | |
| get_messages_documents_and_tools, | |
| get_result_description, | |
| load_command_line_args, | |
| to_snake_case, | |
| to_title_case, | |
| ) | |
# Parse command-line options before anything else is configured.
load_command_line_args()

# Maps each sub-catalog name to its list of test-case dicts; populated below.
catalog = {}

# Placeholder JSON used as the tools editor value when a test case has no tools.
toy_json = '{"name": "John"}'

CATALOG_FILE_PATH = Path(__file__).parent / "catalog.json"

# Read the catalog once at import time. The encoding is pinned to UTF-8 so the
# result does not depend on the platform's default locale encoding.
with open(CATALOG_FILE_PATH, encoding="utf-8") as f:
    logger.debug("Loading catalog from json.")
    catalog = json.load(f)
def update_selected_test_case(button_name: str, state: gr.State, event: gr.EventData) -> gr.State:
    """Record in ``state`` which catalog test case was just clicked.

    The clicked button's ``elem_id`` has the form ``"<sub_catalog>---<test_case>"``.
    Both halves are stored in ``state``, and the matching test-case dict is
    looked up in the module-level ``catalog``.

    Args:
        button_name: Label of the clicked button (the title-cased test case name).
        state: Session dict; mutated in place and returned.
        event: Gradio event data; ``event.target.elem_id`` identifies the button.

    Returns:
        The same ``state`` object with ``selected_sub_catalog``,
        ``selected_criteria_name`` and ``selected_test_case`` updated.

    Raises:
        StopIteration: If no catalog entry matches the clicked button.
    """
    sub_catalog_id, test_case_id = event.target.elem_id.split("---")
    state["selected_sub_catalog"] = sub_catalog_id
    state["selected_criteria_name"] = test_case_id

    def _matches(catalog_key, candidate) -> bool:
        # A candidate matches on both its own name and the sub-catalog it lives in.
        return candidate["name"] == to_snake_case(button_name) and to_snake_case(catalog_key) == sub_catalog_id

    matching_cases = (
        candidate
        for catalog_key, test_cases in catalog.items()
        for candidate in test_cases
        if _matches(catalog_key, candidate)
    )
    state["selected_test_case"] = next(matching_cases)
    return state
def on_test_case_click(state: gr.State):
    """Build the component updates that display the currently selected test case.

    Args:
        state: Session dict holding ``selected_sub_catalog``,
            ``selected_criteria_name`` and ``selected_test_case``.

    Returns:
        An 11-tuple, in this order: test case title HTML, criteria text, context
        update, user message update, assistant message (text) update, assistant
        message (JSON) update, tools update, result text update, thinking trace
        accordion update, result trace update, result explanation update.
    """
    sub_catalog_key = state["selected_sub_catalog"]
    criteria_key = state["selected_criteria_name"]
    case = state["selected_test_case"]

    logger.debug(f'Changing to test case "{criteria_key}" from catalog "{sub_catalog_key}".')

    has_assistant_message = case["assistant_message"] is not None
    shows_context = criteria_key in ["context_relevance", "groundedness"]
    has_tools = "tools" in case and case["tools"] is not None

    title_html = f'<h2>{to_title_case(case["name"])}</h2>'
    criteria_text = case["criteria"]

    # The context box is only shown (and editable) for the two criteria listed above.
    if shows_context:
        context_update = gr.update(
            value=case["context"], interactive=True, visible=True, elem_classes=["input-box"]
        )
    else:
        context_update = gr.update(visible=False)

    # When the test case has no tools, the hidden editor still gets a placeholder value.
    tools_update = gr.update(
        visible=has_tools,
        value=case["tools"] if has_tools else toy_json,
        elem_classes=["read-only", "margin-bottom"],
    )

    user_message_update = gr.update(
        value=case["user_message"], visible=True, interactive=True, elem_classes=["input-box"]
    )

    # At most one of the two assistant-message widgets is visible: the JSON
    # viewer when tools are present, otherwise the editable text box.
    assistant_text_update = gr.update(visible=False)
    assistant_json_update = gr.update(visible=False)
    if has_tools:
        assistant_json_update = gr.update(
            visible=True,
            value=case["assistant_message"],
            elem_classes=["read-only", "margin-bottom"],
        )
    elif has_assistant_message:
        assistant_text_update = gr.update(
            value=case["assistant_message"],
            visible=True,
            interactive=True,
            elem_classes=["input-box"],
        )

    # Hide and clear everything left over from a previous evaluation.
    result_text_update = gr.update(visible=False, value="")
    accordion_update = gr.update(visible=False, value="")
    result_trace_update = gr.update(visible=False, value="")

    explanation_update = gr.update(
        value=f"<p><strong>{get_result_description(sub_catalog_key, criteria_key)}</strong></p>",
    )

    return (
        title_html,
        criteria_text,
        context_update,
        user_message_update,
        assistant_text_update,
        assistant_json_update,
        tools_update,
        result_text_update,
        accordion_update,
        result_trace_update,
        explanation_update,
    )
def change_button_color(event: gr.EventData):
    """Mark the clicked catalog button as selected and clear the mark on all others.

    Args:
        event: Gradio event data; ``event.target.elem_id`` identifies the clicked button.

    Returns:
        One update per catalog button, in the iteration order of ``catalog_buttons``.
    """
    updates = []
    for buttons_by_name in catalog_buttons.values():
        for button in buttons_by_name.values():
            css_classes = ["catalog-button"]
            if button.elem_id == event.target.elem_id:
                css_classes.append("selected")
            updates.append(gr.update(elem_classes=css_classes))
    return updates
def on_submit(criteria, context, user_message, assistant_message_text, assistant_message_json, tools, think, state):
    """Evaluate the selected catalog test case and stream the outcome to the UI.

    This is a generator event handler. Every yielded tuple carries three values:
    the result HTML update, the thinking-trace update (``None`` when thinking is
    off) and the trace accordion update.

    Args:
        criteria: Evaluation criteria text shown in the UI.
        context: Context text for the test case.
        user_message: User prompt to evaluate.
        assistant_message_text: Assistant response from the plain-text box.
        assistant_message_json: Assistant response from the JSON editor; used
            for the function-calling criterion.
        tools: Tool/API definitions from the JSON editor.
        think: Whether thinking mode is enabled for this run.
        state: Session dict holding ``selected_criteria_name`` and
            ``selected_sub_catalog``.

    Yields:
        Intermediate updates while the response is generated, followed by a
        final update that shows the label and collapses the trace accordion.
    """
    # Two catalog names differ from the criteria names passed on to the model helpers.
    criteria_name = state["selected_criteria_name"]
    if criteria_name == "general_harm":
        criteria_name = "harm"
    elif criteria_name == "function_calling_hallucination":
        criteria_name = "function_call"

    assistant_message = assistant_message_json if criteria_name == "function_call" else assistant_message_text

    test_case = {
        "name": criteria_name,
        "criteria": criteria,
        "context": context,
        "user_message": user_message,
        "assistant_message": assistant_message,
        "tools": tools,
    }

    sub_catalog_name = state["selected_sub_catalog"]
    logger.debug(f"Starting evaluation for subcatalog {sub_catalog_name} and criteria name {criteria_name}")

    generator = get_guardian_response(
        test_case=test_case,
        sub_catalog_name=sub_catalog_name,
        criteria_name=criteria_name,
        criteria_description=None,
        think=think,
    )

    # Loop-invariant: shown in place of the label until one is available.
    waiting_message = "Waiting for thinking to end..." if think else "Generating result..."

    for label, trace in generator:
        yield (
            gr.update(value=f"<p><strong>{label}</strong></p>" if label else f"<p>{waiting_message}</p>"),
            gr.update(value=trace, visible=True) if think else None,
            gr.update(open=True),
        )

    # Use the ``think`` argument here. The module-level ``think_checkbox`` is a
    # component object and is always truthy, whatever the user ticked.
    yield (
        gr.update(value=f"<p><strong>{label}</strong></p>", visible=True),
        gr.update(value=trace) if think else None,
        gr.update(open=False),
    )
def on_submit_byor(byor_criteria, byor_user_message, byor_assistant_message, think):
    """Evaluate a user-defined ("bring your own risk") criterion and stream the outcome.

    This is a generator event handler. Every yielded tuple carries three values:
    the result HTML update, the thinking-trace update (``None`` when thinking is
    off) and the trace accordion update.

    Args:
        byor_criteria: Free-text evaluation criteria written by the user.
        byor_user_message: User prompt to evaluate.
        byor_assistant_message: Assistant response to evaluate.
        think: Whether thinking mode is enabled for this run.

    Yields:
        Intermediate updates while the response is generated, followed by a
        final update that shows the label and collapses the trace accordion.
    """
    test_case = {
        "name": "byor",
        "description": byor_criteria,
        "context": "",
        "user_message": byor_user_message,
        "assistant_message": byor_assistant_message,
        "tools": "",
    }

    criteria_name = test_case["name"]
    criteria_description = test_case["description"]

    # Custom criteria do not belong to any sub-catalog.
    sub_catalog_name = None
    logger.debug(f"Starting evaluation for subcatalog {sub_catalog_name} and criteria name {criteria_name}")

    generator = get_guardian_response(
        test_case=test_case,
        sub_catalog_name=sub_catalog_name,
        criteria_name=criteria_name,
        criteria_description=criteria_description,
        think=think,
    )

    # Loop-invariant: shown in place of the label until one is available.
    waiting_message = "Waiting for thinking to end..." if think else "Generating result..."

    for label, trace in generator:
        yield (
            gr.update(value=f"<p><strong>{label}</strong></p>" if label else f"<p>{waiting_message}</p>"),
            gr.update(value=trace, visible=True) if think else None,
            gr.update(open=True),
        )

    # Use the ``think`` argument here. The module-level ``think_checkbox`` is a
    # component object and is always truthy, whatever the user ticked.
    yield (
        gr.update(value=f"<p><strong>{label}</strong></p>", visible=True),
        gr.update(value=trace) if think else None,
        gr.update(open=False),
    )
def on_show_prompt_click(
    criteria, context, user_message, assistant_message_text, assistant_message_json, tools, think, state
) -> gr.Markdown:
    """Render the prompt that would be built for the selected catalog test case.

    Args:
        criteria: Evaluation criteria text shown in the UI.
        context: Context text for the test case.
        user_message: User prompt to evaluate.
        assistant_message_text: Assistant response from the plain-text box.
        assistant_message_json: Assistant response from the JSON editor; used
            for the function-calling criterion.
        tools: Tool/API definitions from the JSON editor.
        think: Whether thinking mode is enabled.
        state: Session dict holding ``selected_criteria_name`` and
            ``selected_sub_catalog``.

    Returns:
        A Markdown component whose content is the HTML-escaped prompt.
    """
    # Two catalog names differ from the criteria names passed on to the model helpers.
    criteria_name = state["selected_criteria_name"]
    if criteria_name == "general_harm":
        criteria_name = "harm"
    elif criteria_name == "function_calling_hallucination":
        criteria_name = "function_call"

    assistant_message = assistant_message_json if criteria_name == "function_call" else assistant_message_text
    sub_catalog_name = state["selected_sub_catalog"]

    test_case = {
        "name": criteria_name,
        "criteria": criteria,
        "context": context,
        "user_message": user_message,
        "assistant_message": assistant_message,
        "tools": tools,
    }

    messages, documents, tools = get_messages_documents_and_tools(
        test_case=test_case, sub_catalog_name=sub_catalog_name
    )

    guardian_config = get_guardian_config(
        criteria_name=criteria_name,
        criteria_description=criteria,
    )

    prompt = get_prompt(
        messages=messages,
        documents=documents,
        tools=tools,
        guardian_config=guardian_config,
        think=think,
    )

    # Escape angle brackets so tag-like text in the prompt is displayed
    # literally by the Markdown component instead of being parsed as HTML.
    prompt = prompt.replace("<", "&lt;").replace(">", "&gt;").replace("\\n", "<br>")
    return gr.Markdown(prompt)
def on_show_byor_prompt_click(byor_criteria, byor_user_message, byor_assistant_message, think) -> gr.Markdown:
    """Render the prompt that would be built for a user-defined ("bring your own risk") criterion.

    Args:
        byor_criteria: Free-text evaluation criteria written by the user.
        byor_user_message: User prompt to evaluate.
        byor_assistant_message: Assistant response to evaluate.
        think: Whether thinking mode is enabled.

    Returns:
        A Markdown component whose content is the HTML-escaped prompt.
    """
    test_case = {
        "name": "byor",
        "description": byor_criteria,
        "context": "",
        "user_message": byor_user_message,
        "assistant_message": byor_assistant_message,
        "tools": "",
    }

    criteria_name = test_case["name"]
    criteria_description = test_case["description"]
    logger.debug(f"Building prompt for criteria name {criteria_name}")

    # The helper returns a (messages, documents, tools) triple, as in
    # on_show_prompt_click. Custom criteria have no sub-catalog, so None is
    # passed, matching what on_submit_byor sends to get_guardian_response.
    messages, documents, tools = get_messages_documents_and_tools(test_case=test_case, sub_catalog_name=None)

    guardian_config = get_guardian_config(
        criteria_name=criteria_name,
        criteria_description=criteria_description,
    )

    prompt = get_prompt(
        messages=messages,
        documents=documents,
        tools=tools,
        guardian_config=guardian_config,
        think=think,
    )

    # Escape angle brackets so tag-like text in the prompt is displayed
    # literally by the Markdown component instead of being parsed as HTML.
    prompt = prompt.replace("<", "&lt;").replace(">", "&gt;").replace("\\n", "<br>")
    return gr.Markdown(prompt)
def on_add_assistant_message_button_click():
    """Return updates for clicking "Add Assistant Response" in the custom-risk tab.

    Returns:
        A three-item list: an update hiding the first output, an update showing
        the second output, and the result-description HTML asking about the
        assistant message.
    """
    hide_update = gr.update(visible=False)
    show_update = gr.update(visible=True)
    question_html = "<p><strong>Is the assistant message harmful based on the provided criteria?</strong></p>"
    return [hide_update, show_update, question_html]
def on_remove_assistant_message_button():
    """Return updates for clicking the remove-assistant-message button in the custom-risk tab.

    Returns:
        A four-item list: an update hiding the first output, an empty string
        that clears the second output, an update showing the third output, and
        the result-description HTML asking about the user message.
    """
    hide_update = gr.update(visible=False)
    cleared_text = ""
    show_update = gr.update(visible=True)
    question_html = "<p><strong>Is the user message harmful based on the provided criteria?</strong></p>"
    return [hide_update, cleared_text, show_update, question_html]
# Static assets that live next to this file.
css_file_path = Path(Path(__file__).parent / "app.css")
head_file_path = Path(Path(__file__).parent / "app_head.html")

# Build the whole UI. Components are created in the order they appear below.
with gr.Blocks(
    title="Granite Guardian",
    css_paths=css_file_path,
    theme=carbon_theme,
    head_paths=head_file_path,
) as demo:
    # Session state; the initial value selects the test case shown on page load.
    state = gr.State(
        value={
            "selected_sub_catalog": "harmful_content_in_user_prompt",
            "selected_criteria_name": "general_harm",
        }
    )

    # Look up the catalog entry for the initial selection.
    # Raises StopIteration if the catalog has no such entry.
    starting_test_case = next(
        iter(
            t
            for sub_catalog_name, sub_catalog in catalog.items()
            for t in sub_catalog
            if t["name"] == state.value["selected_criteria_name"]
            and sub_catalog_name == state.value["selected_sub_catalog"]
        )
    )

    description = """
    <p>Granite Guardian models are specialized language models in the Granite family that can detect harms and risks in
    generative AI systems. They can be used with any large language model to make interactions with generative AI systems
    safe. Select an example in the left panel to see how the Granite Guardian model evaluates harms and risks in user
    prompts, assistant responses, and for hallucinations in retrieval-augmented generation and function calling. In this
    demo, we use granite-guardian-3.3-8b. This version of Granite Guardian is a hybrid thinking model that allows the user to operate in thinking or non-thinking mode.</p>
    """

    # Page header: title and description.
    with gr.Row(elem_classes="header-row", equal_height=True), gr.Column():
        gr.HTML("<h1>IBM Granite Guardian 3.3</h1>", elem_classes="title")
        gr.HTML(
            elem_classes="system-description",
            value=description,
        )

    # ---- Tab 1: predefined catalog examples ----
    with gr.Tab("Try Example"):
        with gr.Row():
            with gr.Column(scale=0):
                title_display_left = gr.HTML("<h2>Example Risks</h2>", elem_classes=["subtitle", "subtitle-harms"])
            with gr.Column(scale=1) as test_case_content:
                with gr.Row():
                    test_case_name = gr.HTML(
                        f'<h2>{to_title_case(starting_test_case["name"])}</h2>', elem_classes="subtitle"
                    )
                    show_propt_button = gr.Button(
                        "Show prompt", size="sm", scale=0, min_width=110, elem_classes="no-stretch"
                    )
        with gr.Row():
            with gr.Column(scale=0, elem_classes="accordions-gap"):
                accordions = []
                # sub-catalog name -> test case name -> button component
                catalog_buttons: dict[str, dict[str, gr.Button]] = {}
                # One accordion per sub-catalog; only the first one starts open.
                for i, (sub_catalog_name, sub_catalog) in enumerate(catalog.items()):
                    with gr.Accordion(
                        to_title_case(sub_catalog_name), open=(i == 0), elem_classes="accordion"
                    ) as accordion:
                        for test_case in sub_catalog:
                            elem_classes = ["catalog-button"]
                            # The id encodes "<sub_catalog>---<test_case>"; the click
                            # handler splits it on "---" to recover both parts.
                            elem_id = f"{sub_catalog_name}---{test_case['name']}"
                            if starting_test_case == test_case:
                                elem_classes.append("selected")
                            if sub_catalog_name not in catalog_buttons:
                                catalog_buttons[sub_catalog_name] = {}
                            catalog_buttons[sub_catalog_name][test_case["name"]] = gr.Button(
                                to_title_case(test_case["name"]),
                                elem_classes=elem_classes,
                                variant="secondary",
                                size="sm",
                                elem_id=elem_id,
                            )
                    accordions.append(accordion)
            # NOTE: rebinds test_case_content, which was first bound in the header row above.
            with gr.Column(scale=1) as test_case_content:
                criteria = gr.Textbox(
                    label="Evaluation Criteria",
                    lines=3,
                    interactive=False,
                    value=starting_test_case["criteria"],
                    elem_classes=["read-only", "input-box", "margin-bottom"],
                )
                gr.HTML(elem_classes=["block", "content-gap"])
                # Hidden initially; on_test_case_click decides its visibility.
                context = gr.Textbox(
                    label="Context",
                    lines=3,
                    interactive=True,
                    value=starting_test_case["context"],
                    visible=False,
                    elem_classes=["input-box"],
                )
                tools = gr.Code(label="API Definition (Tools)", visible=False, language="json")
                user_message = gr.Textbox(
                    label="User Prompt",
                    lines=3,
                    interactive=True,
                    value=starting_test_case["user_message"],
                    elem_classes=["input-box"],
                )
                # Two widgets share the "Assistant Response" label: a text box and
                # a JSON viewer. on_test_case_click shows at most one of them.
                assistant_message_text = gr.Textbox(
                    label="Assistant Response",
                    lines=3,
                    interactive=True,
                    visible=False,
                    value=starting_test_case["assistant_message"],
                    elem_classes=["input-box"],
                )
                assistant_message_json = gr.Code(
                    label="Assistant Response",
                    visible=False,
                    language="json",
                    value=None,
                    elem_classes=["input-box"],
                )
                with gr.Row():
                    with gr.Column():
                        result_description = gr.HTML(
                            value=f"<p><strong>{get_result_description(state.value['selected_sub_catalog'],state.value['selected_criteria_name'])}</strong></p>",
                            elem_classes="result-meaning",
                        )
                with gr.Row():
                    with gr.Column(scale=5):
                        submit_button = gr.Button(
                            value="Evaluate",
                            variant="primary",
                            icon=os.path.join(os.path.dirname(os.path.abspath(__file__)), "send-white.png"),
                        )
                    with gr.Column(scale=1, min_width=0):
                        think_checkbox = gr.Checkbox(value=True, label="Thinking")
                # result_text = gr.HTML(
                #     label="Result", elem_classes=["result-text", "read-only", "input-box"], visible=False, value=""
                # )
                # Result widgets start hidden; the submit handlers reveal them.
                with gr.Accordion("Thought process", open=False, visible=False) as thinking_trace_accordion:
                    result_trace = gr.Textbox(
                        lines=5,
                        max_lines=7,
                        interactive=False,
                        visible=False,
                        value=None,
                        show_label=False,
                    )
                result_text = gr.HTML(label="Result", elem_classes=["result-root"], visible=False, value=None)
        # Modal that displays the rendered prompt.
        with Modal(visible=False, elem_classes="modal") as modal:
            prompt = gr.Markdown("")

        # events
        # Render the prompt first, then reveal the modal.
        show_propt_button.click(
            on_show_prompt_click,
            inputs=[
                criteria,
                context,
                user_message,
                assistant_message_text,
                assistant_message_json,
                tools,
                think_checkbox,
                state,
            ],
            outputs=prompt,
        ).then(lambda: gr.update(visible=True), None, modal)

        # Step 1 clears the result area and shows the accordion if thinking is on.
        # Step 2 streams the evaluation.
        submit_button.click(
            lambda think_checkbox: [
                gr.update(visible=True, value=""),
                gr.update(visible=think_checkbox, open=think_checkbox),
            ],
            inputs=[think_checkbox],
            outputs=[result_text, thinking_trace_accordion],
        ).then(
            on_submit,
            inputs=[
                criteria,
                context,
                user_message,
                assistant_message_text,
                assistant_message_json,
                tools,
                think_checkbox,
                state,
            ],
            outputs=[result_text, result_trace, thinking_trace_accordion],
            scroll_to_output=True,
        )

        # Every catalog button runs the same chain: recolor the buttons, store
        # the selection in state, then refresh the test case panel.
        for button in [
            t for sub_catalog_name, sub_catalog_buttons in catalog_buttons.items() for t in sub_catalog_buttons.values()
        ]:
            button.click(
                change_button_color, inputs=None, outputs=[v for c in catalog_buttons.values() for v in c.values()]
            ).then(update_selected_test_case, inputs=[button, state], outputs=[state]).then(
                on_test_case_click,
                inputs=state,
                outputs=[
                    test_case_name,
                    criteria,
                    context,
                    user_message,
                    assistant_message_text,
                    assistant_message_json,
                    tools,
                    result_text,
                    thinking_trace_accordion,
                    result_trace,
                    result_description,
                ],
            )

    # ---- Tab 2: user-defined criteria ----
    # NOTE: this tab rebinds several names from the first tab (test_case_name,
    # show_propt_button, result_description, submit_button, think_checkbox,
    # thinking_trace_accordion, result_trace, result_text, modal, prompt). The
    # event wiring above already holds the first tab's component objects. After
    # this point the module-level names refer to this tab's components.
    with gr.Tab("Bring Your Own Risk"):
        with gr.Row():
            test_case_name = gr.HTML("<h2>Bring your own risk</h2>", elem_classes="subtitle")
            show_propt_button = gr.Button("Show prompt", size="sm", scale=0, min_width=110, elem_classes="no-stretch")
        byor_criteria = gr.Textbox(
            label="Evaluation Criteria",
            lines=3,
            interactive=True,
            value="",
            elem_classes=["input-box", "margin-bottom"],
            placeholder="Describe your evaluation criteria as a question. e.g. is the response grammatically correct?",
        )
        gr.HTML(elem_classes=["block", "content-gap"])
        with gr.Group(elem_classes=["byor-container"]):
            byor_user_message = gr.Textbox(
                label="User Prompt",
                lines=3,
                interactive=True,
                value="",
                elem_classes=["input-box"],
                placeholder="The user message to evaluate",
            )
            # Assistant response row is hidden until the user clicks "Add Assistant Response".
            with gr.Row(visible=False, elem_classes=["byor-assistant-message-row"]) as byor_assistant_message_row:
                byor_assistant_message = gr.Textbox(
                    label="Assistant Response",
                    lines=3,
                    interactive=True,
                    value="",
                    elem_classes=["input-box"],
                    placeholder="The assistant response to evaluate",
                )
                remove_assistant_message_button = gr.Button(
                    "",
                    variant="secondary",
                    icon=os.path.join(os.path.dirname(os.path.abspath(__file__)), "trash-can.svg"),
                    size="sm",
                    elem_classes=["remove-assistant-message-button"],
                )
            add_assistant_message_button = gr.Button(
                "Add Assistant Response",
                variant="secondary",
                elem_classes="add-assistant-button",
                visible=True,
                icon=os.path.join(os.path.dirname(os.path.abspath(__file__)), "add.svg"),
            )
        with gr.Row():
            with gr.Column():
                result_description = gr.HTML(
                    value=f"<p><strong>{'Is the user message harmful based on the provided criteria?'}</strong></p>",
                    elem_classes="result-meaning",
                )
        with gr.Row():
            with gr.Column(scale=5):
                submit_button = gr.Button(
                    value="Evaluate",
                    variant="primary",
                    icon=os.path.join(os.path.dirname(os.path.abspath(__file__)), "send-white.png"),
                )
            with gr.Column(scale=1, min_width=0):
                think_checkbox = gr.Checkbox(value=True, label="Thinking")
        # Result widgets start hidden; the submit handlers reveal them.
        with gr.Accordion("Thought process", open=False, visible=False) as thinking_trace_accordion:
            result_trace = gr.Textbox(
                lines=4,
                max_lines=7,
                interactive=False,
                visible=False,
                value=None,
                show_label=False,
            )
        result_text = gr.HTML(label="Result", elem_classes=["result-root"], visible=False, value=None)
        # Modal that displays the rendered prompt.
        with Modal(visible=False, elem_classes="modal") as modal:
            prompt = gr.Markdown("")

        # Render the prompt first, then reveal the modal.
        show_propt_button.click(
            on_show_byor_prompt_click,
            inputs=[byor_criteria, byor_user_message, byor_assistant_message, think_checkbox],
            outputs=prompt,
        ).then(lambda: gr.update(visible=True), None, modal)

        # Output order must match the list returned by the handler.
        add_assistant_message_button.click(
            on_add_assistant_message_button_click,
            outputs=[add_assistant_message_button, byor_assistant_message_row, result_description],
        )
        remove_assistant_message_button.click(
            on_remove_assistant_message_button,
            outputs=[
                byor_assistant_message_row,
                byor_assistant_message,
                add_assistant_message_button,
                result_description,
            ],
        )

        # Step 1 clears the result area and shows the accordion if thinking is on.
        # Step 2 streams the evaluation.
        submit_button.click(
            lambda think_checkbox: [
                gr.update(visible=True, value=""),
                gr.update(visible=think_checkbox, open=think_checkbox),
            ],
            inputs=[think_checkbox],
            outputs=[result_text, thinking_trace_accordion],
        ).then(
            on_submit_byor,
            inputs=[byor_criteria, byor_user_message, byor_assistant_message, think_checkbox],
            outputs=[result_text, result_trace, thinking_trace_accordion],
            scroll_to_output=True,
        )

if __name__ == "__main__":
    # "0.0.0.0" makes the server listen on all network interfaces.
    demo.launch(server_name="0.0.0.0")