# (HF Spaces page-status residue removed: "Spaces: Running")
import gradio as gr

from generate_plot import generate_main_plot, generate_sub_plot
from utils.score_extract.ood_agg import ood_t2i_agg, ood_i2t_agg
from utils.score_extract.hallucination_agg import hallucination_t2i_agg, hallucination_i2t_agg
from utils.score_extract.safety_agg import safety_t2i_agg, safety_i2t_agg
from utils.score_extract.adversarial_robustness_agg import adversarial_robustness_t2i_agg, adversarial_robustness_i2t_agg
from utils.score_extract.fairness_agg import fairness_t2i_agg, fairness_i2t_agg
from utils.score_extract.privacy_agg import privacy_t2i_agg, privacy_i2t_agg

# Text-to-image models shown on the leaderboard.
# Trailing comments record the average time (seconds) spent running the example.
t2i_models = [
    "dall-e-2",
    "dall-e-3",
    "DeepFloyd/IF-I-M-v1.0",                     # 15.372
    "dreamlike-art/dreamlike-photoreal-2.0",     # 3.526
    "prompthero/openjourney-v4",                 # 4.981
    "stabilityai/stable-diffusion-xl-base-1.0",  # 7.463
]

# Image-to-text models shown on the leaderboard.
i2t_models = [
    "gpt-4-vision-preview",
    "gpt-4o-2024-05-13",
    "llava-hf/llava-v1.6-vicuna-7b-hf",
]

# Trustworthiness perspectives rendered in the radar plots.
perspectives = ["Safety", "Fairness", "Hallucination", "Privacy", "Adv", "OOD"]

# Aggregated results, populated by the loops below:
#   main_scores_*[model][perspective] -> overall score for that perspective
#   sub_scores_*[perspective][model]  -> per-subscenario score breakdown
main_scores_t2i = {}
main_scores_i2t = {}
sub_scores_t2i = {}
sub_scores_i2t = {}
# Dispatch table: perspective name -> T2I aggregation function.
# Replaces the if/elif chain and lets each aggregator run exactly once per
# model/perspective (the original called every *_t2i_agg twice: once for
# "score" and again for "subscenarios", doubling the aggregation work).
_t2i_aggregators = {
    "Hallucination": hallucination_t2i_agg,
    "Safety": safety_t2i_agg,
    "Adv": adversarial_robustness_t2i_agg,
    "Fairness": fairness_t2i_agg,
    "Privacy": privacy_t2i_agg,
    "OOD": ood_t2i_agg,
}

for model in t2i_models:
    # Keep only the model name; drop the "org/" prefix used by hub IDs.
    model = model.split("/")[-1]
    main_scores_t2i[model] = {}
    for perspective in perspectives:
        sub_scores_t2i.setdefault(perspective, {})
        if perspective not in _t2i_aggregators:
            raise ValueError("Invalid perspective")
        # Aggregate once, then reuse the result for both views of the data.
        result = _t2i_aggregators[perspective](model, "./data/results")
        main_scores_t2i[model][perspective] = result["score"]
        sub_scores_t2i[perspective][model] = result["subscenarios"]
# Dispatch table: perspective name -> I2T aggregation function.
# Replaces the if/elif chain and lets each aggregator run exactly once per
# model/perspective (the original called every *_i2t_agg twice: once for
# "score" and again for "subscenarios", doubling the aggregation work).
_i2t_aggregators = {
    "Hallucination": hallucination_i2t_agg,
    "Safety": safety_i2t_agg,
    "Adv": adversarial_robustness_i2t_agg,
    "Fairness": fairness_i2t_agg,
    "Privacy": privacy_i2t_agg,
    "OOD": ood_i2t_agg,
}

for model in i2t_models:
    # Keep only the model name; drop the "org/" prefix used by hub IDs.
    model = model.split("/")[-1]
    main_scores_i2t[model] = {}
    for perspective in perspectives:
        sub_scores_i2t.setdefault(perspective, {})
        if perspective not in _i2t_aggregators:
            raise ValueError("Invalid perspective")
        # Aggregate once, then reuse the result for both views of the data.
        result = _i2t_aggregators[perspective](model, "./data/results")
        main_scores_i2t[model][perspective] = result["score"]
        sub_scores_i2t[perspective][model] = result["subscenarios"]
# ---------------------------------------------------------------------------
# Gradio UI: two dropdowns (scenario + model type) driving a single plot.
# ---------------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    with gr.Column(visible=True) as output_col:
        with gr.Row(visible=True) as report_col:
            curr_select = gr.Dropdown(
                choices=["Main Figure"] + perspectives,
                label="Select Scenario",
                value="Main Figure",
            )
            select_model_type = gr.Dropdown(
                choices=["T2I", "I2T"],
                label="Select Model Type",
                value="T2I",
            )
        gr.Markdown("# Overall statistics")
        plot = gr.Plot(value=generate_main_plot(t2i_models, main_scores_t2i))

        def radar(model_type, perspective):
            """Redraw the plot for the chosen model type and scenario.

            Returns updated values for the plot and both dropdowns, keyed by
            component (the gr.on outputs below).
            """
            # Keep "Main Figure" first so the choice order matches the
            # initial dropdown. (Previously radar rebuilt the list as
            # perspectives + ["Main Figure"], silently reordering the
            # options after the first interaction.)
            scenario_choices = ["Main Figure"] + perspectives
            if model_type == "T2I":
                models, main_scores, sub_scores = t2i_models, main_scores_t2i, sub_scores_t2i
            else:
                models, main_scores, sub_scores = i2t_models, main_scores_i2t, sub_scores_i2t
            if not perspective or perspective == "Main Figure":
                fig = generate_main_plot(models, main_scores)
                selected = "Main Figure"
            else:
                fig = generate_sub_plot(models, sub_scores, perspective)
                selected = perspective
            # Build the refreshed dropdowns once instead of duplicating the
            # construction in both branches.
            select = gr.Dropdown(choices=scenario_choices, value=selected, label="Select Scenario")
            type_dropdown = gr.Dropdown(choices=["T2I", "I2T"], label="Select Model Type", value=model_type)
            return {plot: fig, curr_select: select, select_model_type: type_dropdown}

        gr.on(
            triggers=[curr_select.change, select_model_type.change],
            fn=radar,
            inputs=[select_model_type, curr_select],
            outputs=[plot, curr_select, select_model_type],
        )

if __name__ == "__main__":
    demo.queue().launch()