Upload inference code
Browse files- DFS_search_with_concurrent.py +571 -0
DFS_search_with_concurrent.py
ADDED
|
@@ -0,0 +1,571 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from propositional_logic.random_gen.evaluation_access import *
|
| 2 |
+
from vllm import LLM, SamplingParams
|
| 3 |
+
import time
|
| 4 |
+
import json
|
| 5 |
+
import pickle
|
| 6 |
+
import os
|
| 7 |
+
import time
|
| 8 |
+
import re
|
| 9 |
+
import sys
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
import argparse
|
| 12 |
+
from loguru import logger
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
import concurrent.futures
|
| 15 |
+
import random
|
| 16 |
+
import subprocess
|
| 17 |
+
|
| 18 |
+
import uuid
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
#python -m DFS.DFS_search_with_concurrent ~/train_scripts/train_scripts_prop_serv8/v3_2_0_66_0_8/checkpoint-1600 1 ~/leandojo_project/proplogic_serv8/train_val_test/data_45w/data_5_vars/key_directory/key_20w_quantile_0_66_0_8_out_dist_test.json ~/leandojo_project/atp_research/DFS/output/outcome_basic_key_ge_080_1000.pkl
|
| 25 |
+
|
| 26 |
+
class DFS:
    """Depth-first proof search over LLM-proposed tactics.

    Per theorem ``key`` the instance keeps a tactic tree, a child->parent
    state map and a growing prompt transcript.  Each round it asks the
    (module-level global) vLLM model for candidate tactics at every open
    frontier state, applies them through the ``SingleTheoremEval`` Lean
    wrapper in worker processes, and backtracks when a state runs out of
    candidates.  Worker processes report back through pickle files in
    ``_TEMP_DIR`` because mutations made in a child process do not propagate
    to the parent.
    """

    # Scratch directory for the pickle hand-off with worker processes.
    # ``open()`` does not expand '~', so the original literal '~/...' paths
    # would fail at runtime; expand once here.  NOTE(review): confirm this
    # directory exists on the target machine.
    _TEMP_DIR = os.path.expanduser('~/leandojo_project/atp_research/DFS/temp')

    def __init__(
        self,
        num_sampled_tactics: int,
        temperature,
        test_theorem_list,
        max_workers,
        saved_file_path,
        experiment_id
    ) -> None:
        """Load every theorem and initialise per-key search state.

        num_sampled_tactics: tactic candidates sampled per prompt.
        temperature: sampling temperature forwarded to vLLM.
        test_theorem_list: theorem keys (stringified ints) to attempt.
        max_workers: process-pool size for the Lean verification phase.
        saved_file_path: where ``save_outcome`` pickles the final results.
        experiment_id: unique id namespacing the temp pickle files.
        """
        self.max_workers = max_workers
        self.test_theorem_list = test_theorem_list
        self.whether_backtrack = {}
        self.num_sampled_tactics = num_sampled_tactics
        self.count_lean_dict = {}
        self.counter_success = 0
        self.counter_failed = 0
        self.counter_in_process = 0
        self.counter_too_long = 0
        self.tactic_list_tree = {}      # key -> {state label: list of tactics or None}
        self.theorem_object_dict = {}   # key -> SingleTheoremEval wrapper
        self.prompts_tactic_state_list = {}  # key -> transcript lines along current path
        self.root = {}                  # key -> 'open' / 'success' / failure label
        self.round_count = {}           # key -> last round the key was still open
        self.parent_node_of_node = {}   # key -> {state label: parent state label}
        self.round = 0
        self.saved_file_path = saved_file_path
        # Snapshot of each key's state label taken at the start of a round.
        # Should never be used once we want to apply a tactic.
        self.cstate_round = {}
        self.experiment_id = experiment_id
        self.key_to_be_infered = []
        self.key_not_finished = []
        self.prompts_entered = []
        self.counter_failed_with_error = 0
        self.temperature = temperature

        print(f"{len(self.test_theorem_list)} many theorem loaded")
        for key in tqdm(self.test_theorem_list):
            # 5-variable propositional-logic theorem identified by the key.
            sample_eval = SingleTheoremEval(5, int(key))
            self.theorem_object_dict[key] = sample_eval
            init_state = sample_eval.get_initial_prompt()

            self.parent_node_of_node[key] = {}
            self.prompts_tactic_state_list[key] = [init_state]
            self.root[key] = 'open'
            # None marks "tactics not generated yet" for a state.
            self.tactic_list_tree[key] = {"state_0:": None}
            self.parent_node_of_node[key]["state_0:"] = None
            self.count_lean_dict[key] = {
                'count_lean_multiple_backtrack': 0,
                'count_lean_single_backtrack': 0,
                'count_lean_tactic_success': 0,
            }
            self.round_count[key] = 0
            self.whether_backtrack[key] = False

        print('initialization done')

    def _state_file(self, key):
        """Pickle path holding *key*'s current state label for this round."""
        return os.path.join(
            self._TEMP_DIR,
            f'current_state_{key}_{self.round}_{self.experiment_id}.pkl')

    def _report_file(self, key):
        """Pickle path used by the worker process to report *key*'s results."""
        return os.path.join(
            self._TEMP_DIR,
            f'temp_file_{key}_{self.round}_{self.experiment_id}.pkl')

    def get_current_state_number(self, key):
        """Return the first line (the ``state_N:`` label) of the current state."""
        return self.theorem_object_dict[key].get_current_state_with_label().split('\n', 1)[0]

    def get_prev_state_number(self, key):
        """Return the first line (the ``state_N:`` label) of the previous state."""
        return self.theorem_object_dict[key].get_prev_state_with_label().split('\n', 1)[0]

    def revise_entered_tactic(self, entered_tactic, key):
        """Rewrite the tactic label in-place so it targets the current state."""
        if len(entered_tactic) != 2:
            raise AssertionError('entered_tactic must be a [label, tactic] pair')
        current_state_label = self.get_current_state_number(key)
        # 'state_N:' -> 'state_N_tactic_0:'
        entered_tactic[0] = current_state_label[:-1] + '_tactic_0:'
        return entered_tactic

    def back_track_tactic(self, key):
        """Build the backtrack command string understood by the Lean wrapper."""
        current_state = self.get_current_state_number(key)
        match = re.search(r'\d+', current_state)
        if match is None:
            raise AssertionError('no number in current state')
        extracted_current_integer = int(match.group())

        previous_state = self.parent_node_of_node[key][current_state]
        # Sanity check: the recorded parent must agree with the wrapper's view.
        if previous_state != self.get_prev_state_number(key):
            raise AssertionError(
                f'key is {key}, during backtrack, previous state marked and '
                f'previous state by system are not the same')
        match = re.search(r'\d+', previous_state.split('\n', 1)[0])
        if match is None:
            raise AssertionError('no number in previous state')
        extracted_previous_integer = int(match.group())
        return (f"no solution, return to state {extracted_previous_integer} "
                f"[that leads to state {extracted_current_integer}]")

    def revise_output_list(self, output_text):
        """Extract the first (label, tactic-line) pair from raw model output.

        Returns ``['no_tactic', 'no_tactic']`` when no usable pair is present.
        """
        output_line_list = output_text.split("\n")
        idx_tactic = None
        for idx, line in enumerate(output_line_list):
            if '_tactic_' in line:
                idx_tactic = idx
                break
        if idx_tactic is None:
            return ['no_tactic', 'no_tactic']

        # Strip the '::: ' marker some generations prepend to the label.
        if "::: " in output_line_list[idx_tactic]:
            output_line_list[idx_tactic] = output_line_list[idx_tactic][4:]
        entered_tactic_list = output_line_list[idx_tactic:idx_tactic + 2]
        if len(entered_tactic_list) == 1:
            # Label with no following tactic line.
            return ['no_tactic', 'no_tactic']
        return entered_tactic_list

    def check_if_failure_per_key(self, key):
        """True when the root has no tactics left and we are back at state_0."""
        if (len(self.tactic_list_tree[key]['state_0:']) == 0
                and self.get_current_state_number(key) == 'state_0:'):
            print('triggered failure')
            return True
        return False

    def check_path_length(self, key):
        """Assert the root-to-current state chain matches the prompt length."""
        current_state_label = self.get_current_state_number(key)
        theorem_object_length = 1
        while current_state_label != 'state_0:':
            current_state_label = self.parent_node_of_node[key][current_state_label]
            theorem_object_length += 1
        if theorem_object_length != len(self.prompts_tactic_state_list[key]):
            raise AssertionError("path_length not equal to each other")

    def check_if_program_finished(self):
        """True when no theorem is still 'open'."""
        return all(self.root[key] != 'open' for key in self.test_theorem_list)

    def revise_prompt(self, prompts_tactic_state_list_per_key):
        """Renumber state labels consecutively and normalise tactic labels.

        States are remapped to state_0, state_1, ... in order of first
        appearance; every tactic label is rewritten as
        ``state_<last-seen-state>_tactic_0:``.
        """
        matches = re.findall(r'state_\d+:', prompts_tactic_state_list_per_key)
        state_order = {}
        for match in matches:
            if match not in state_order:
                state_order[match] = len(state_order)

        # NOTE(review): sequential str.replace can collide when an old label
        # equals another label's new name; behavior preserved from original.
        for state, position in state_order.items():
            prompts_tactic_state_list_per_key = \
                prompts_tactic_state_list_per_key.replace(state, f'state_{position}:')

        last_state_id = None
        output_prompt = []
        for line in prompts_tactic_state_list_per_key.split('\n'):
            if re.search(r'state_\d+:', line):
                last_state_id = line[6:-1]
            elif re.search(r'state_\d+_tactic_', line):
                line = f'state_{last_state_id}_tactic_0:'
            output_prompt.append(line)
        return "\n".join(output_prompt)

    def _tally_status(self):
        """Count theorems per status and sanity-check that totals add up."""
        counts = {'open': 0, 'success': 0, 'failed': 0,
                  'failed, too long': 0, 'failed with error': 0}
        # Was iterating the module-level global 'test_theorem_list'; use the
        # instance copy so the class works outside this script's __main__.
        for key in self.test_theorem_list:
            status = self.root[key]
            if status in counts:
                counts[status] += 1
        if sum(counts.values()) != len(self.test_theorem_list):
            raise AssertionError(
                'success, failed, too long, in process, failed with error '
                'add up not equal to total number')
        return counts

    def status_report(self):
        """Print a progress summary and refresh the counter_* attributes."""
        counts = self._tally_status()
        self.counter_success = counts['success']
        self.counter_failed = counts['failed']
        self.counter_failed_with_error = counts['failed with error']
        self.counter_too_long = counts['failed, too long']
        self.counter_in_process = counts['open']
        print(f'saved_file_path is {self.saved_file_path}')
        print(f'total number of theorem is {len(self.test_theorem_list)}')
        print(f'proof success number is {self.counter_success}')
        print(f'proof failed number is {self.counter_failed}')
        print(f'proof failed with error number is {self.counter_failed_with_error}')
        print(f'proof too long number is {self.counter_too_long}')
        print(f'proof in process number is {self.counter_in_process}')

        count_lean_single_backtrack = 0
        count_lean_multiple_backtrack = 0
        count_lean_tactic_success = 0
        for key in self.test_theorem_list:
            count_lean_single_backtrack += self.count_lean_dict[key]['count_lean_single_backtrack']
            count_lean_multiple_backtrack += self.count_lean_dict[key]['count_lean_multiple_backtrack']
            count_lean_tactic_success += self.count_lean_dict[key]['count_lean_tactic_success']

        print(f'total lean count tactic success is {count_lean_tactic_success}')
        print(f'total lean count single backtrack is {count_lean_single_backtrack}')
        print(f'total lean count multiple backtrack is {count_lean_multiple_backtrack}')

    def collect_inference_result(self, key_to_be_infered, outputs):
        """Attach a deduplicated tactic list to each inferred key's frontier state."""
        for idx, output_list in tqdm(enumerate(outputs), total=len(outputs),
                                     desc=f"Processing LLM output for Round {self.round}"):
            assigned_output_list_per_key = []
            for i in range(self.num_sampled_tactics):
                output_tactic = self.revise_output_list(output_list.outputs[i].text)
                # Drop generations that produced no parseable tactic pair.
                if 'no_tactic' not in output_tactic:
                    assigned_output_list_per_key.append(output_tactic)

            # Deduplicate while preserving sampling order.
            seen = set()
            unique_assigned_output_list_per_key = []
            for inner_list in assigned_output_list_per_key:
                inner_tuple = tuple(inner_list)
                if inner_tuple not in seen:
                    seen.add(inner_tuple)
                    unique_assigned_output_list_per_key.append(inner_list)
            key = key_to_be_infered[idx]
            self.tactic_list_tree[key][self.cstate_round[key]] = unique_assigned_output_list_per_key

    def current_state_obtained_list(self, key):
        """Snapshot *key*'s current state label and persist it for this round.

        Runs in a worker process; the parent reads the pickle back afterwards.
        """
        if self.root[key] == 'open':
            self.cstate_round[key] = self.get_current_state_number(key)
            with open(self._state_file(key), 'wb') as f:
                pickle.dump(self.cstate_round[key], f)

    def search(self):
        """Run DFS rounds until every theorem closes or the round cap (65) hits.

        Each round: snapshot frontier states, batch-generate tactics for
        states lacking them, apply/backtrack in worker processes, then merge
        the per-key reports back into this instance.
        """
        # NOTE(review): 'llm' is the module-level global created in __main__.
        tokenizer = llm.get_tokenizer()
        while True:
            self.round += 1
            print(f'Round {self.round}------')
            if self.check_if_program_finished() or self.round > 65:
                print('confirmed test theorem finished. exit.')
                self.status_report()
                break

            self.key_to_be_infered = []
            self.key_not_finished = []
            self.prompts_entered = []

            with concurrent.futures.ProcessPoolExecutor(max_workers=self.max_workers) as executor:
                list(tqdm(executor.map(self.current_state_obtained_list, self.test_theorem_list),
                          total=len(self.test_theorem_list),
                          desc=f"get current state for Round {self.round}"))

            key_prompt_length_list = []
            for key in tqdm(self.test_theorem_list, total=len(self.test_theorem_list),
                            desc=f"check state for theorems for round {self.round}"):
                if self.root[key] != 'open':
                    continue
                self.round_count[key] = self.round
                self.key_not_finished.append(key)
                with open(self._state_file(key), 'rb') as f:
                    self.cstate_round[key] = pickle.load(f)
                # None means this frontier state has no generated tactics yet.
                if self.tactic_list_tree[key][self.cstate_round[key]] is None:
                    prompt_per_key = self.revise_prompt('\n'.join(self.prompts_tactic_state_list[key]))
                    key_prompt_length_list.append(len(prompt_per_key.split()))
                    tokenized_prompt_per_key = tokenizer.encode(prompt_per_key)
                    # Skip prompts the model cannot fit (word and token caps).
                    if len(prompt_per_key.split()) < 1500 and len(tokenized_prompt_per_key) < 4000:
                        self.key_to_be_infered.append(key)
                        self.prompts_entered.append(prompt_per_key)
                    else:
                        self.root[key] = 'failed, too long'
                        self.key_not_finished.remove(key)
            print(f'key open need inference before check length, length list is {key_prompt_length_list}')

            print(f'key to be infered is {self.key_to_be_infered}')

            sampling_params = SamplingParams(n=self.num_sampled_tactics,
                                             temperature=self.temperature,
                                             top_p=1, max_tokens=200)
            outputs = llm.generate(self.prompts_entered, sampling_params)

            print('now we collect the inference')
            self.collect_inference_result(self.key_to_be_infered, outputs)
            print('inference collected')

            print(f'enter concurrent process with max_workers as {self.max_workers}')

            with concurrent.futures.ProcessPoolExecutor(max_workers=self.max_workers) as executor:
                list(tqdm(executor.map(self.search_per_key_per_step, self.key_not_finished),
                          total=len(self.key_not_finished),
                          desc=f"Lean_verifying for Round {self.round}"))

            # Merge worker reports back into the parent process.
            for key in tqdm(self.key_not_finished, total=len(self.key_not_finished),
                            desc=f"Retrieve results from concurrent processing for Round {self.round}"):
                with open(self._report_file(key), 'rb') as f:
                    report_key_file = pickle.load(f)
                self.count_lean_dict[key]['count_lean_single_backtrack'] += report_key_file['count_lean_single_backtrack']
                self.count_lean_dict[key]['count_lean_multiple_backtrack'] += report_key_file['count_lean_multiple_backtrack']
                self.count_lean_dict[key]['count_lean_tactic_success'] += report_key_file['count_lean_tactic_success']
                self.root[key] = report_key_file['key_status']
                self.tactic_list_tree[key] = report_key_file['tactic_list_tree']
                self.prompts_tactic_state_list[key] = report_key_file['prompts_tactic_state_list']
                self.theorem_object_dict[key] = report_key_file['theorem_object_dict']
                self.parent_node_of_node[key] = report_key_file['node_relation']
                self.whether_backtrack[key] = report_key_file['whether_backtrack']

            self.status_report()
            self.save_outcome()

    def search_per_key_per_step(self, key):
        """Apply the next tactic for *key* (or backtrack) and pickle a report.

        Runs inside a worker process, so results are communicated through the
        report pickle rather than by mutating self in place.
        """
        try:
            key_status = 'open'
            count_lean_tactic = 0
            count_lean_single_backtrack = 0
            count_lean_multiple_backtrack = 0
            whether_backtrack = False

            tactic_list_at_top_per_key = self.tactic_list_tree[key][self.get_current_state_number(key)]
            if tactic_list_at_top_per_key is None:
                raise AssertionError(
                    f"tactic_list_at_top_per_key is None, key is {key}\n"
                    f"current tactic list is {self.tactic_list_tree[key]}\n"
                    f"whether key in key_to_be_infered {key in self.key_to_be_infered}")

            if len(tactic_list_at_top_per_key) != 0:
                try:
                    count_lean_tactic += 1
                    entered_tactic = self.revise_entered_tactic(tactic_list_at_top_per_key[0], key)
                    label_before_tactic = self.get_current_state_number(key)

                    lean_output = self.theorem_object_dict[key].provide_tactic(
                        entered_tactic[0], entered_tactic[1])
                    label_after_tactic = self.get_current_state_number(key)

                    if 'proof is complete' in lean_output[1]:
                        key_status = 'success'
                        print(f'key is {key}, search is success!')
                        print(f'key is {key}, the successful proof \n{self.theorem_object_dict[key].get_current_lean_proof()}')

                    # Record the edge and extend the transcript, then mark the
                    # new state as "tactics not generated yet" and consume the
                    # tactic we just applied.
                    self.parent_node_of_node[key][label_after_tactic] = label_before_tactic
                    self.prompts_tactic_state_list[key].append(
                        f"{entered_tactic[0]}\n{entered_tactic[1]}\n{lean_output[1]}")
                    self.tactic_list_tree[key][self.get_current_state_number(key)] = None
                    del self.tactic_list_tree[key][label_before_tactic][0]
                    self.check_path_length(key)
                except Exception:
                    # Tactic rejected by Lean: drop it from the candidate list
                    # and check whether the root is now exhausted.
                    del self.tactic_list_tree[key][self.get_current_state_number(key)][0]
                    self.check_path_length(key)
                    if self.check_if_failure_per_key(key):
                        key_status = 'failed'
                        print(f"key is {key}, tactic error then search failed!")
            else:
                if self.check_if_failure_per_key(key):
                    key_status = 'failed'
                    print(f"key is {key}, backtrack to zero and no tactic to try, search failed!")

                print(f'key is {key}, backtrack phase activated')
                whether_backtrack = True
                count_lean_single_backtrack += 1
                # Walk back up until a state with untried tactics (or failure).
                while True:
                    tactic_list_at_intermediate_node = self.tactic_list_tree[key][self.get_current_state_number(key)]
                    if len(tactic_list_at_intermediate_node) != 0:
                        break
                    if self.check_if_failure_per_key(key):
                        key_status = 'failed'
                        print(f"key is {key}, backtrack to zero and no tactic to try, search failed!")
                        break
                    self.check_path_length(key)
                    count_lean_multiple_backtrack += 1
                    self.theorem_object_dict[key].do_back_track(self.back_track_tactic(key))
                    # Drop the transcript entry for the abandoned step.
                    del self.prompts_tactic_state_list[key][-1]
                    self.check_path_length(key)
        except Exception as e:
            print(f"key is {key}, exception happend")
            print(e)
            key_status = 'failed with error'
            count_lean_tactic = 0
            count_lean_single_backtrack = 0
            count_lean_multiple_backtrack = 0

        finally:
            # Always write a report so the parent never blocks on a missing file.
            report_key_file = {
                'count_lean_multiple_backtrack': count_lean_multiple_backtrack + count_lean_tactic,
                'count_lean_single_backtrack': count_lean_single_backtrack + count_lean_tactic,
                'count_lean_tactic_success': count_lean_tactic,
                'key_status': key_status,
                'tactic_list_tree': self.tactic_list_tree[key],
                'prompts_tactic_state_list': self.prompts_tactic_state_list[key],
                'theorem_object_dict': self.theorem_object_dict[key],
                'node_relation': self.parent_node_of_node[key],
                'whether_backtrack': whether_backtrack,
            }
            with open(self._report_file(key), 'wb') as f:
                pickle.dump(report_key_file, f)

    def save_outcome(self):
        """Aggregate statistics and proofs for all theorems and pickle them."""
        count_lean_single_backtrack = 0
        count_lean_multiple_backtrack = 0
        count_lean_tactic_success = 0

        proof_dict = {}
        for key in tqdm(self.test_theorem_list, total=len(self.test_theorem_list),
                        desc='saving results'):
            count_lean_single_backtrack += self.count_lean_dict[key]['count_lean_single_backtrack']
            count_lean_multiple_backtrack += self.count_lean_dict[key]['count_lean_multiple_backtrack']
            count_lean_tactic_success += self.count_lean_dict[key]['count_lean_tactic_success']
            proof_dict[key] = self.theorem_object_dict[key].get_current_lean_proof()

        counts = self._tally_status()
        self.counter_success = counts['success']
        self.counter_failed = counts['failed']
        self.counter_failed_with_error = counts['failed with error']
        self.counter_too_long = counts['failed, too long']

        outcome = {'stats': {}}
        outcome['stats']['total_lean_count_single_backtrack'] = count_lean_single_backtrack
        outcome['stats']['total_lean_count_multiple_backtrack'] = count_lean_multiple_backtrack
        outcome['stats']['count_lean_tactic_success'] = count_lean_tactic_success
        outcome['stats']['num_success'] = self.counter_success
        outcome['stats']['num_failed'] = self.counter_failed
        outcome['stats']['num_failed_with_error'] = self.counter_failed_with_error
        outcome['stats']['num_too_long'] = self.counter_too_long
        outcome['stats']['num_sampled_tactics'] = self.num_sampled_tactics
        outcome['stats']['temperature'] = self.temperature
        outcome['key_final_state'] = self.root
        outcome['key_lean_count'] = self.count_lean_dict
        outcome['key_proof'] = proof_dict
        outcome['tactic_list_tree'] = self.tactic_list_tree
        outcome['round_count'] = self.round_count

        with open(self.saved_file_path, 'wb') as f:
            pickle.dump(outcome, f)
| 508 |
+
if __name__ == '__main__':
    # Example invocation:
    # python -m DFS.DFS_search_with_concurrent <checkpoint> 1 <keys.json> <out.pkl> <workers> <n_tactics> <temp> <n_theorems>
    parser = argparse.ArgumentParser(
        description='DFS proof search over a theorem test set with a vLLM checkpoint.')
    parser.add_argument('checkpoint_path', type=str, help='path to the model checkpoint')
    parser.add_argument('number_of_gpu', type=int, help='tensor parallel size for vLLM')
    parser.add_argument('test_data_path', type=str, help='JSON file containing the theorem key list')
    parser.add_argument('saved_file_path', type=str, help='output pickle for the search outcome')
    parser.add_argument('max_workers', type=int, help='process-pool size for Lean verification')
    parser.add_argument('num_sampled_tactics', type=int, help='tactic candidates sampled per prompt')
    parser.add_argument('temperature', type=float, help='sampling temperature')
    parser.add_argument('num_test_theorem', type=int, help='number of theorems to evaluate')

    args = parser.parse_args()
    # argparse now performs the int/float conversions that were previously
    # done with type=str plus scattered casts.
    checkpoint = args.checkpoint_path
    number_of_gpu = args.number_of_gpu
    test_data_path = args.test_data_path
    saved_file_path = args.saved_file_path
    max_workers = args.max_workers
    num_sampled_tactics = args.num_sampled_tactics
    temperature = args.temperature
    num_test_theorem = args.num_test_theorem
    swap_space = 100  # GiB of CPU swap space handed to vLLM

    print(f'checkpoint is {checkpoint}')
    print(f'number_of_gpu is {number_of_gpu}')
    print(f'test_data_path is {test_data_path}')
    print(f'saved_file_path is {saved_file_path}')
    print(f'num_test_theorem is {num_test_theorem}')

    print(f'max_workers is {max_workers}')
    print(f'number_sampled_tactic is {num_sampled_tactics}')
    print(f'temperature is {temperature}')

    print(f'swap_space is {swap_space}')

    random.seed(42)
    with open(test_data_path, 'r') as f:
        test_theorem_list = json.load(f)

    # Evaluate only the requested prefix of the key list.
    test_theorem_list = test_theorem_list[:num_test_theorem]

    # Random id namespacing the temp pickle files of this run.
    experiment_id = uuid.uuid4()
    print(experiment_id)

    llm = LLM(model=checkpoint, tensor_parallel_size=number_of_gpu, swap_space=swap_space)
    evaluate_obj = DFS(num_sampled_tactics=num_sampled_tactics,
                       temperature=temperature,
                       test_theorem_list=test_theorem_list,
                       max_workers=max_workers,
                       saved_file_path=saved_file_path,
                       experiment_id=experiment_id)
    evaluate_obj.search()
    print('Now we start saving')
    evaluate_obj.save_outcome()
    print('Now we finish saving. exit')
|