Spaces:
Runtime error
Runtime error
| import streamlit as st | |
| from parse import retrieve | |
| from transfer import retrieve_transfer | |
def main():
    """Draw the sidebar page picker and render whichever page is selected."""
    # Map each sidebar label to the function that renders that page; the
    # dict's insertion order is the order the options appear in the radio.
    pages = {
        "PromptBench": promptbench,
        "Retrieve Transferability Information": retrieve_transferability_information,
    }
    st.sidebar.title("Choose Function")
    selected_page = st.sidebar.radio("", list(pages))
    render = pages.get(selected_page)
    if render is not None:
        render()
def promptbench():
    """Render the PromptBench page.

    The user picks a model, dataset, attack and prompt type; the selection is
    echoed back, and when the "Retrieve" button is pressed the matching
    records are fetched with ``retrieve`` and each record's original/attacked
    prompt and accuracy are written to the page.
    """
    st.title("PromptBench")
    model_name = st.selectbox(
        "Select Model",
        options=["T5", "Vicuna", "UL2", "ChatGPT"],
        index=0,
    )
    dataset_name = st.selectbox(
        "Select Dataset",
        options=[
            "SST-2", "CoLA", "QQP", "MRPC", "MNLI", "QNLI",
            "RTE", "WNLI", "MMLU", "SQuAD V2", "IWSLT 2017", "UN Multi", "Math",
        ],
        index=0,
    )
    attack_name = st.selectbox(
        "Select Attack",
        options=[
            "BertAttack", "CheckList", "DeepWordBug", "StressTest",
            "TextFooler", "TextBugger", "Semantic",
        ],
        index=0,
    )
    prompt_type = st.selectbox(
        "Select Prompt Type",
        options=["zeroshot-task", "zeroshot-role", "fewshot-task", "fewshot-role"],
        index=0,
    )

    # Echo every value that is passed to ``retrieve`` (the attack used to be
    # missing from this summary).
    st.write(f"Model: {model_name}")
    st.write(f"Dataset: {dataset_name}")
    st.write(f"Attack: {attack_name}")
    st.write(f"Prompt Type: {prompt_type}")

    if st.button("Retrieve"):
        results = retrieve(model_name, dataset_name, attack_name, prompt_type)
        # Assumes ``retrieve`` returns an iterable of dicts, possibly None or
        # empty — TODO confirm against parse.py. Materialize it once so the
        # emptiness check works for lists and generators alike.
        records = [] if results is None else list(results)
        if not records:
            st.write("No results found for this selection.")
            return
        for result in records:
            st.write(f"Original prompt: {result['origin prompt']}")
            st.write(f"Original acc: {result['origin acc']}")
            st.write(f"Attack prompt: {result['attack prompt']}")
            st.write(f"Attack acc: {result['attack acc']}")
def retrieve_transferability_information():
    """Render the attack-transferability page.

    The user picks a source model, a (different) target model, an attack and
    a shot count; the matching records from ``retrieve_transfer`` are shown
    one expander per record, with the original/attacked prompt and the
    accuracies measured on both the source and the target model.
    """
    st.title("Retrieve Transferability Information")
    model_options = ["T5", "Vicuna", "UL2", "ChatGPT"]
    source_model_name = st.selectbox(
        "Select Source Model",
        options=model_options,
        index=0,
    )
    # Default the target to a different model than the source. With both
    # defaults at index 0 the page always opened on the "cannot be the same"
    # message and returned before the remaining controls were drawn.
    target_model_name = st.selectbox(
        "Select Target Model",
        options=model_options,
        index=1,
    )
    if source_model_name == target_model_name:
        st.write("Source model and target model cannot be the same.")
        return
    attack_name = st.selectbox(
        "Select Attack",
        options=[
            "BertAttack", "CheckList", "DeepWordBug", "StressTest",
            "TextFooler", "TextBugger", "Semantic",
        ],
        index=0,
    )
    # The "Semantic" attack is queried under the name "translation"
    # (presumably its name in the transfer data — confirm against transfer.py).
    if attack_name == "Semantic":
        attack_name = "translation"
    shot = st.selectbox(
        "Select Shot",
        options=[0, 3],
        index=0,
    )

    data = retrieve_transfer(source_model_name, target_model_name, attack_name, shot)
    # Assumes ``retrieve_transfer`` returns an iterable of dicts, possibly
    # None or empty — TODO confirm. Materialize once so emptiness can be tested.
    records = [] if data is None else list(data)
    if not records:
        st.write("No transferability results found for this selection.")
        return
    for d in records:
        with st.expander(f"Dataset: {d['dataset']} Prompt Type: {d['type']}-oriented"):
            st.write(f"Origin prompt: {d['origin_prompt']}")
            st.write(f"Attack prompt: {d['atk_prompt']}")
            st.write(f"Source model: origin acc: {d['origin_acc']}, attack acc: {d['atk_acc']}")
            st.write(f"Target model: origin acc: {d['transfer_ori_acc']}, attack acc: {d['transfer_atk_acc']}")
if __name__ == "__main__":
    # Build the app only when this file is executed as the main script
    # (not when it is imported as a module).
    main()