Spaces:
Runtime error
Runtime error
Immortalise
commited on
Commit
·
505d6d4
1
Parent(s):
1c79925
init
Browse files
parse.py
CHANGED
|
@@ -123,50 +123,51 @@ def retrieve(model_name, dataset_name, attack_name, prompt_type):
|
|
| 123 |
directory_path = "./db"
|
| 124 |
md_dir = os.path.join(directory_path, model_name + "_" + shot + ".md")
|
| 125 |
sections_dict = split_markdown_by_title(md_dir)
|
| 126 |
-
|
| 127 |
for cur_dataset in sections_dict.keys():
|
| 128 |
if cur_dataset == dataset_name:
|
| 129 |
dataset_dict = sections_dict[cur_dataset]
|
| 130 |
for cur_attack in dataset_dict.keys():
|
|
|
|
| 131 |
if cur_attack == attack_name:
|
| 132 |
-
pass
|
| 133 |
|
| 134 |
if attack_name == "translation":
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
atk_acc = []
|
| 138 |
|
| 139 |
-
for
|
| 140 |
-
if "acc: " not in
|
| 141 |
continue
|
| 142 |
-
|
|
|
|
|
|
|
| 143 |
import re
|
| 144 |
|
| 145 |
match_atk = re.search(r'acc: (\d+\.\d+)%', result)
|
| 146 |
-
|
| 147 |
number_atk = float(match_atk.group(1))
|
| 148 |
-
|
|
|
|
|
|
|
|
|
|
| 149 |
|
| 150 |
-
|
| 151 |
|
| 152 |
-
|
| 153 |
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
|
|
|
| 165 |
if match_origin and match_atk:
|
| 166 |
number_origin = float(match_origin.group(1))
|
| 167 |
number_atk = float(match_atk.group(1))
|
| 168 |
-
summary[title][dataset].append((number_origin - number_atk)/number_origin)
|
| 169 |
-
summary[title]["Avg"].append((number_origin - number_atk)/number_origin)
|
| 170 |
|
| 171 |
# print(model_shot, dataset, title, len(summary[attack][dataset]), num)
|
| 172 |
|
|
|
|
| 123 |
directory_path = "./db"
|
| 124 |
md_dir = os.path.join(directory_path, model_name + "_" + shot + ".md")
|
| 125 |
sections_dict = split_markdown_by_title(md_dir)
|
| 126 |
+
results = {}
|
| 127 |
for cur_dataset in sections_dict.keys():
|
| 128 |
if cur_dataset == dataset_name:
|
| 129 |
dataset_dict = sections_dict[cur_dataset]
|
| 130 |
for cur_attack in dataset_dict.keys():
|
| 131 |
+
|
| 132 |
if cur_attack == attack_name:
|
|
|
|
| 133 |
|
| 134 |
if attack_name == "translation":
|
| 135 |
+
prompts_dict = dataset_dict[attack_name].split("\n")
|
|
|
|
|
|
|
| 136 |
|
| 137 |
+
for prompt_summary in prompts_dict:
|
| 138 |
+
if "acc: " not in prompt_summary:
|
| 139 |
continue
|
| 140 |
+
|
| 141 |
+
prompt = prompt_summary.split("prompt: ")[1]
|
| 142 |
+
|
| 143 |
import re
|
| 144 |
|
| 145 |
match_atk = re.search(r'acc: (\d+\.\d+)%', result)
|
|
|
|
| 146 |
number_atk = float(match_atk.group(1))
|
| 147 |
+
results[prompt] = number_atk
|
| 148 |
+
|
| 149 |
+
sorted_results = sorted(results.items(), key=lambda item: item[1])[:6]
|
| 150 |
+
|
| 151 |
|
| 152 |
+
return sorted_results
|
| 153 |
|
| 154 |
+
elif attack_name in ["bertattack", "checklist", "deepwordbug", "stresstest", "textfooler", "textbugger"]:
|
| 155 |
|
| 156 |
+
prompts_dict = dataset_dict[attack_name].split("\n")
|
| 157 |
+
num = 0
|
| 158 |
+
|
| 159 |
|
| 160 |
+
for prompt_summary in prompts_dict:
|
| 161 |
+
if "Attacked prompt: " not in prompt_summary:
|
| 162 |
+
continue
|
| 163 |
+
|
| 164 |
+
num += 1
|
| 165 |
+
import re
|
| 166 |
+
match_origin = re.search(r'Original acc: (\d+\.\d+)%', prompt_summary)
|
| 167 |
+
match_atk = re.search(r'attacked acc: (\d+\.\d+)%', prompt_summary)
|
| 168 |
if match_origin and match_atk:
|
| 169 |
number_origin = float(match_origin.group(1))
|
| 170 |
number_atk = float(match_atk.group(1))
|
|
|
|
|
|
|
| 171 |
|
| 172 |
# print(model_shot, dataset, title, len(summary[attack][dataset]), num)
|
| 173 |
|