Yee172 committed on
Commit
e179c89
·
verified ·
1 Parent(s): b1123de

Upload inference code

Browse files
Files changed (1) hide show
  1. DFS_search_with_concurrent.py +571 -0
DFS_search_with_concurrent.py ADDED
@@ -0,0 +1,571 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from propositional_logic.random_gen.evaluation_access import *
2
+ from vllm import LLM, SamplingParams
3
+ import time
4
+ import json
5
+ import pickle
6
+ import os
7
+ import time
8
+ import re
9
+ import sys
10
+ from tqdm import tqdm
11
+ import argparse
12
+ from loguru import logger
13
+ from tqdm import tqdm
14
+ import concurrent.futures
15
+ import random
16
+ import subprocess
17
+
18
+ import uuid
19
+
20
+
21
+
22
+
23
+
24
+ #python -m DFS.DFS_search_with_concurrent ~/train_scripts/train_scripts_prop_serv8/v3_2_0_66_0_8/checkpoint-1600 1 ~/leandojo_project/proplogic_serv8/train_val_test/data_45w/data_5_vars/key_directory/key_20w_quantile_0_66_0_8_out_dist_test.json ~/leandojo_project/atp_research/DFS/output/outcome_basic_key_ge_080_1000.pkl
25
+
26
class DFS:
    """Depth-first proof search over a batch of theorems, driven by a vLLM
    tactic generator, with Lean-side verification fanned out to worker
    processes each round.

    All per-theorem bookkeeping lives in dicts keyed by the theorem key:
    the candidate-tactic tree, parent links between state labels, the prompt
    history along the current path, the overall status, and Lean-call counters.
    """
    def __init__(
        self,
        num_sampled_tactics: int,
        temperature,
        test_theorem_list,
        max_workers,
        saved_file_path,
        experiment_id
    ) -> None:
        self.max_workers = max_workers
        self.test_theorem_list = test_theorem_list
        self.whether_backtrack = {}           # key -> whether the last step backtracked
        self.num_sampled_tactics = num_sampled_tactics
        self.count_lean_dict = {}             # key -> Lean-call counters
        self.counter_success = 0
        self.counter_failed = 0
        self.counter_in_process = 0
        self.counter_too_long = 0
        self.tactic_list_tree = {}            # key -> {state label: candidate tactics, or None if unexpanded}
        self.theorem_object_dict = {}         # key -> SingleTheoremEval instance
        self.prompts_tactic_state_list = {}   # key -> prompt fragments along the current DFS path
        self.root = {}                        # key -> overall status string
        self.round_count = {}                 # key -> last round this key was active in
        self.parent_node_of_node = {}         # key -> {child state label: parent state label}
        self.round = 0
        self.saved_file_path = saved_file_path
        self.cstate_round = {} #should never be used once we want to apply tactic. should only be used at the beginning of each round
        self.experiment_id = experiment_id
        self.key_to_be_infered = []
        self.key_not_finished = []
        self.prompts_entered = []
        self.counter_failed_with_error = 0
        self.temperature = temperature

        print(f"{len(self.test_theorem_list)} many theorem loaded")
        for key in tqdm(self.test_theorem_list):
            # SingleTheoremEval comes from the star import at the top of the
            # file; presumably (5, key) selects a 5-variable problem by
            # integer id -- TODO confirm against evaluation_access.
            sample_eval = SingleTheoremEval(5, int(key))
            self.theorem_object_dict[key] = sample_eval
            init_state = self.theorem_object_dict[key].get_initial_prompt()

            self.parent_node_of_node[key] = {}
            self.prompts_tactic_state_list[key] = [init_state]
            self.root[key] = 'open' #open, success or failed
            self.tactic_list_tree[key] = {}
            # None marks "never expanded"; an empty list means "exhausted".
            self.tactic_list_tree[key]["state_0:"] = None
            self.parent_node_of_node[key]["state_0:"] = None
            self.count_lean_dict[key] = {}
            self.count_lean_dict[key]['count_lean_multiple_backtrack'] = 0
            self.count_lean_dict[key]['count_lean_single_backtrack'] = 0
            self.count_lean_dict[key]['count_lean_tactic_success'] = 0
            self.round_count[key] = 0
            self.whether_backtrack[key] = False

        print('initialization done')
81
+ def get_current_state_number(self, key):
82
+ string = self.theorem_object_dict[key].get_current_state_with_label()
83
+ for line in string.split('\n'):
84
+ break
85
+ return line
86
+ def get_prev_state_number(self, key):
87
+ string = self.theorem_object_dict[key].get_prev_state_with_label()
88
+ for line in string.split('\n'):
89
+ break
90
+ return line
91
+ def revise_entered_tactic(self,entered_tactic,key):
92
+ if len(entered_tactic) != 2:
93
+ assert True==False
94
+ current_state_label = self.get_current_state_number(key)
95
+ entered_tactic[0] = current_state_label[:-1] + '_tactic_0:'
96
+ return entered_tactic
97
+ def back_track_tactic(self, key):
98
+ current_state_number = self.get_current_state_number(key)
99
+ for line in current_state_number.split('\n'):
100
+ break
101
+ match = re.search(r'\d+', line)
102
+ if match:
103
+ extracted_current_integer = int(match.group())
104
+ else:
105
+ assert False, 'no number in current state'
106
+
107
+ previous_state = self.parent_node_of_node[key][current_state_number]
108
+ previous_state_to_be_checked = self.get_prev_state_number(key)
109
+ if previous_state != previous_state_to_be_checked:
110
+ assert False, f'key is {key}, during backtrack, previous state marked and previous state by system are not the same'
111
+ for line in previous_state.split('\n'):
112
+ break
113
+ match = re.search(r'\d+', line)
114
+ if match:
115
+ extracted_previous_integer = int(match.group())
116
+ else:
117
+ assert False
118
+ return f"no solution, return to state {extracted_previous_integer} [that leads to state {extracted_current_integer}]"
119
+ def revise_output_list(self, output_text):
120
+ output_line_list = output_text.split("\n")
121
+ is_tactic = False
122
+ for idx_tactic, line in enumerate(output_line_list):
123
+ if '_tactic_' in line:
124
+ is_tactic = True
125
+ break
126
+
127
+ if is_tactic==False:
128
+ #print('output, warning: no tactic')
129
+ return ['no_tactic','no_tactic']
130
+
131
+ if "::: " in output_line_list[idx_tactic]:
132
+ output_line_list[idx_tactic] = output_line_list[idx_tactic][4:]
133
+ entered_tactic_list = output_line_list[idx_tactic:idx_tactic+2]
134
+
135
+ if len(entered_tactic_list) == 1:
136
+ return ['no_tactic','no_tactic']
137
+ return entered_tactic_list
138
+ def check_if_failure_per_key(self, key):
139
+ if len(self.tactic_list_tree[key]['state_0:']) == 0 and self.get_current_state_number(key) == 'state_0:':
140
+ print('triggered failure')
141
+ return True
142
+ else:
143
+ return False
144
+ def check_path_length(self,key):
145
+ current_state_label = self.get_current_state_number(key)
146
+ previous_state_label = current_state_label
147
+ theorem_object_length = 1
148
+ #print(f'key is {key}, current tactic list tree is {self.tactic_list_tree[key]}')
149
+ while True:
150
+ #print("state chain label is: ", previous_state_label)
151
+ if previous_state_label == 'state_0:':
152
+ break
153
+ previous_state_label = self.parent_node_of_node[key][current_state_label]
154
+ theorem_object_length += 1
155
+ current_state_label = previous_state_label
156
+ #print(f'prompt tactic list length is {len(self.prompts_tactic_state_list[key])}')
157
+ #print(f'theorem_object_length is {theorem_object_length}')
158
+ if theorem_object_length != len(self.prompts_tactic_state_list[key]):
159
+ assert True==False, "path_length not equal to each other"
160
+ def check_if_program_finished(self):
161
+ stop_signal = True
162
+ for key in self.test_theorem_list:
163
+ if self.root[key] == 'open':
164
+ stop_signal = False
165
+ else:
166
+ pass
167
+ return stop_signal
168
+ def revise_prompt(self, prompts_tactic_state_list_per_key):
169
+ pattern = r'state_\d+:'
170
+ matches = re.findall(pattern, prompts_tactic_state_list_per_key)
171
+ state_order = {}
172
+ order = 0
173
+ for match in matches:
174
+ if match not in state_order:
175
+ state_order[match] = order
176
+ order += 1
177
+
178
+ for state, ord in state_order.items():
179
+ prompts_tactic_state_list_per_key = prompts_tactic_state_list_per_key.replace(state, f'state_{ord}:')
180
+
181
+ last_state_id = None
182
+ output_prompt = []
183
+ for line in prompts_tactic_state_list_per_key.split('\n'):
184
+ if re.search('state_\d+:', line):
185
+ last_state_id = line[6:-1]
186
+ elif re.search('state_\d+_tactic_', line):
187
+ line = f'state_{last_state_id}_tactic_0:'
188
+ output_prompt.append(line)
189
+
190
+
191
+ '''for idx, item in enumerate(prompts_tactic_state_list_per_key):
192
+ temp_string = re.sub(r'state_(\d+):',f'state_{idx}:', item)
193
+ prompts_tactic_state_list_per_key[idx] = re.sub(r'state_(\d+)_tactic_(\d+):',f'state_{idx-1}_tactic_0:', temp_string)'''
194
+ return "\n".join(output_prompt)
195
+ def status_report(self):
196
+ counter_in_process = 0
197
+ counter_success = 0
198
+ counter_failed = 0
199
+ counter_too_long = 0
200
+ counter_failed_with_error = 0
201
+ for key in self.test_theorem_list:
202
+ if self.root[key] == 'open':
203
+ counter_in_process += 1
204
+ if self.root[key] == 'success':
205
+ counter_success += 1
206
+ if self.root[key] == 'failed':
207
+ counter_failed += 1
208
+ if self.root[key] == 'failed, too long':
209
+ counter_too_long += 1
210
+ if self.root[key] == 'failed with error':
211
+ counter_failed_with_error += 1
212
+ self.counter_success = counter_success
213
+ self.counter_failed = counter_failed
214
+ self.counter_failed_with_error = counter_failed_with_error
215
+ self.counter_too_long = counter_too_long
216
+ self.counter_in_process = counter_in_process
217
+ if counter_success + counter_failed + counter_too_long + counter_in_process + counter_failed_with_error != len(test_theorem_list):
218
+ assert False, 'success, failed, too long, in process, failed with error add up not equal to total number'
219
+ print(f'saved_file_path is {self.saved_file_path}')
220
+ print(f'total number of theorem is {len(self.test_theorem_list)}')
221
+ print(f'proof success number is {self.counter_success}')
222
+ print(f'proof failed number is {self.counter_failed}')
223
+ print(f'proof failed with error number is {self.counter_failed_with_error}')
224
+ print(f'proof too long number is {self.counter_too_long}')
225
+ print(f'proof in process number is {self.counter_in_process}')
226
+
227
+ count_lean_single_backtrack = 0
228
+ count_lean_multiple_backtrack = 0
229
+ count_lean_tactic_success = 0
230
+ for key in test_theorem_list:
231
+ count_lean_single_backtrack += self.count_lean_dict[key]['count_lean_single_backtrack']
232
+ count_lean_multiple_backtrack += self.count_lean_dict[key]['count_lean_multiple_backtrack']
233
+ count_lean_tactic_success += self.count_lean_dict[key]['count_lean_tactic_success']
234
+
235
+ print(f'total lean count tactic success is {count_lean_tactic_success}')
236
+ print(f'total lean count single backtrack is {count_lean_single_backtrack}')
237
+ print(f'total lean count multiple backtrack is {count_lean_multiple_backtrack}')
238
    def collect_inference_result(self, key_to_be_infered, outputs):
        """Parse the vLLM generations and attach, per key, the de-duplicated
        list of candidate [label, tactic] pairs to the state that was current
        at the start of the round (self.cstate_round[key]).

        `outputs` is assumed to align 1:1 with `key_to_be_infered`, each
        element carrying `num_sampled_tactics` sampled completions -- this
        matches how search() builds `prompts_entered`.
        """
        for idx, output_list in tqdm(enumerate(outputs), total=len(outputs), desc=f"Processing LLM output for Round {self.round}"):
            assinged_output_list_per_key = []
            for i in range(0, self.num_sampled_tactics):
                # Completions without a recognizable tactic line come back as
                # ['no_tactic', 'no_tactic'] and are dropped.
                output_tactic = self.revise_output_list(output_list.outputs[i].text)
                if output_tactic[0] == 'no_tactic' or output_tactic[1] == 'no_tactic':
                    pass
                else:
                    assinged_output_list_per_key.append(output_tactic)
            # De-duplicate candidate pairs while preserving sampling order.
            seen = set()
            unique_assigned_output_list_per_key = []
            for inner_list in assinged_output_list_per_key:
                inner_tuple = tuple(inner_list)
                if inner_tuple not in seen:
                    seen.add(inner_tuple)
                    unique_assigned_output_list_per_key.append(inner_list)
            self.tactic_list_tree[key_to_be_infered[idx]][self.cstate_round[key_to_be_infered[idx]]] = unique_assigned_output_list_per_key
260
+
261
+ def current_state_obtained_list(self, key):
262
+ if self.root[key] == 'open':
263
+ self.cstate_round[key] = self.get_current_state_number(key)
264
+ cstate = self.cstate_round[key]
265
+ pickle.dump(cstate, open(f'~/leandojo_project/atp_research/DFS/temp/current_state_{key}_{self.round}_{self.experiment_id}.pkl','wb'))
266
    def search(self):
        """Run the round-based DFS loop until every theorem is closed or the
        round budget (65) is exhausted.

        Each round:
          1. snapshot current state labels in a process pool (file-based IPC),
          2. build prompts for states that have never been expanded,
          3. sample tactics with vLLM for those prompts,
          4. apply/backtrack each open theorem in a process pool,
          5. merge worker results (temp pickles) back into this process.
        """
        # NOTE(review): `llm` is read from the module namespace (set in
        # __main__), not passed in -- this method only works inside that script.
        tokenizer = llm.get_tokenizer()
        while True:
            self.round += 1
            print(f'Round {self.round}------')
            if self.check_if_program_finished() or self.round > 65:
                print('confirmed test theorem finished. exit.')
                self.status_report()
                break

            # Per-round scratch state.
            self.key_to_be_infered = []
            self.key_not_finished = []
            self.prompts_entered = []

            # Workers write current_state_*.pkl files; results come back via
            # files because each worker operates on a pickled copy of self.
            with concurrent.futures.ProcessPoolExecutor(max_workers=self.max_workers) as executor:
                list(tqdm(executor.map(self.current_state_obtained_list, self.test_theorem_list), total=len(self.test_theorem_list), desc=f"get current state for Round {self.round}"))

            key_prompt_length_list = []
            for key in tqdm(self.test_theorem_list, total=len(self.test_theorem_list), desc=f"check state for theorems for round {self.round}" ):
                if self.root[key] == 'open':
                    self.round_count[key] = self.round
                    self.key_not_finished.append(key)
                    # NOTE(review): '~' is not expanded by open(); this is a
                    # literal relative path matching the worker-side writer.
                    # The handle is never explicitly closed (CPython GC only).
                    self.cstate_round[key] = pickle.load(open(f'~/leandojo_project/atp_research/DFS/temp/current_state_{key}_{self.round}_{self.experiment_id}.pkl','rb'))
                    # None means this state was never expanded -> needs inference.
                    if self.tactic_list_tree[key][self.cstate_round[key]] == None:
                        prompt_per_key = self.revise_prompt('\n'.join(self.prompts_tactic_state_list[key]))
                        key_prompt_length_list.append(len(prompt_per_key.split()))
                        tokenized_prompt_per_key = tokenizer.encode(prompt_per_key)
                        # Skip prompts too long for the model context window.
                        if len(prompt_per_key.split()) < 1500 and len(tokenized_prompt_per_key) < 4000: # used to be 1500
                            self.key_to_be_infered.append(key)
                            self.prompts_entered.append(prompt_per_key)
                        else:
                            self.root[key] = 'failed, too long'
                            self.key_not_finished.remove(key)
            print(f'key open need inference before check length, length list is {key_prompt_length_list}')

            print(f'key to be infered is {self.key_to_be_infered}')

            sampling_params = SamplingParams(n=self.num_sampled_tactics, temperature=self.temperature, top_p=1,
                                             max_tokens=200) # temperature is 1.2 at beginning
            outputs = llm.generate(self.prompts_entered, sampling_params)

            print('now we collect the inference')
            self.collect_inference_result(self.key_to_be_infered, outputs)
            print('inference collected')

            print(f'enter concurrent process with max_workers as {self.max_workers}')

            with concurrent.futures.ProcessPoolExecutor(max_workers=self.max_workers) as executor:
                list(tqdm(executor.map(self.search_per_key_per_step, self.key_not_finished), total=len(self.key_not_finished), desc=f"Lean_verifying for Round {self.round}"))

            '''for key in self.key_not_finished:
                self.search_per_key_per_step(key)'''

            # Merge each worker's report (written in its finally block) back
            # into this process's bookkeeping.
            for key in tqdm(self.key_not_finished, total=len(self.key_not_finished), desc=f"Retrieve results from concurrent processing for Round {self.round}"):
                report_key_file = pickle.load(open(f'~/leandojo_project/atp_research/DFS/temp/temp_file_{key}_{self.round}_{self.experiment_id}.pkl','rb'))
                self.count_lean_dict[key]['count_lean_single_backtrack'] += report_key_file['count_lean_single_backtrack']
                self.count_lean_dict[key]['count_lean_multiple_backtrack'] += report_key_file['count_lean_multiple_backtrack']
                self.count_lean_dict[key]['count_lean_tactic_success'] += report_key_file['count_lean_tactic_success']
                self.root[key] = report_key_file['key_status']
                self.tactic_list_tree[key] = report_key_file['tactic_list_tree']
                self.prompts_tactic_state_list[key] = report_key_file['prompts_tactic_state_list']
                self.theorem_object_dict[key] = report_key_file['theorem_object_dict']
                self.parent_node_of_node[key] = report_key_file['node_relation']
                self.whether_backtrack[key] = report_key_file['whether_backtrack']

            self.status_report()
            self.save_outcome()
338
    def search_per_key_per_step(self, key):
        """Perform one DFS step for `key` inside a worker process: either
        apply the top candidate tactic at the current state, or backtrack
        until a state with remaining candidates is found.

        All results (counters, status, updated tree/path structures) are
        pickled to a temp file in the `finally` block; the parent process
        merges them back in search(). Any unexpected exception marks the key
        'failed with error' and zeroes the step counters.
        """
        try:
            key_status = 'open'
            count_lean_tactic = 0
            count_lean_single_backtrack = 0
            count_lean_multiple_backtrack = 0
            whether_backtrack = False

            tactic_list_at_top_per_key = self.tactic_list_tree[key][self.get_current_state_number(key)]
            # None here means search() failed to attach an inference result.
            if tactic_list_at_top_per_key == None:
                assert False, f"tactic_list_at_top_per_key is None, key is {key}\ncurrent tactic list is {self.tactic_list_tree[key]}" \
                              f"\nwhether key in key_to_be_infered {key in self.key_to_be_infered}"

            if len(tactic_list_at_top_per_key) != 0:
                # Candidates remain at the current state: try the first one.
                try:
                    count_lean_tactic += 1
                    entered_tactic = self.revise_entered_tactic(tactic_list_at_top_per_key[0], key)
                    label_before_tactic = self.get_current_state_number(key)

                    lean_output = self.theorem_object_dict[key].provide_tactic(entered_tactic[0], entered_tactic[1])
                    label_after_tactic = self.get_current_state_number(key)

                    if 'proof is complete' in lean_output[1]:
                        key_status = 'success'
                        print(f'key is {key}, search is success!')
                        print(f'key is {key}, the successful proof \n{self.theorem_object_dict[key].get_current_lean_proof()}')

                    # Record the tree edge and extend the prompt path, then
                    # mark the new state unexpanded and consume the tried
                    # tactic from its parent's candidate list.
                    self.parent_node_of_node[key][label_after_tactic] = label_before_tactic
                    self.prompts_tactic_state_list[key].append(
                        f"{entered_tactic[0]}\n{entered_tactic[1]}\n{lean_output[1]}")
                    self.tactic_list_tree[key][self.get_current_state_number(key)] = None
                    del self.tactic_list_tree[key][label_before_tactic][0]
                    self.check_path_length(key)
                except Exception as e:
                    # Tactic rejected by Lean: discard it; the state label did
                    # not advance, so delete from the current state's list.
                    del self.tactic_list_tree[key][self.get_current_state_number(key)][0]
                    self.check_path_length(key)
                    if self.check_if_failure_per_key(key):
                        key_status = 'failed'
                        print(f"key is {key}, tactic error then search failed!")
            else:
                # No candidates left at the current state: backtrack.
                if self.check_if_failure_per_key(key):
                    key_status = 'failed'
                    print(f"key is {key}, backtrack to zero and no tactic to try, search failed!")

                print(f'key is {key}, backtrack phase activated')
                whether_backtrack = True
                count_lean_single_backtrack += 1
                while True:
                    tactic_list_at_intermediate_node = self.tactic_list_tree[key][self.get_current_state_number(key)]
                    if len(tactic_list_at_intermediate_node) != 0:
                        break
                    if self.check_if_failure_per_key(key):
                        key_status = 'failed'
                        print(f"key is {key}, backtrack to zero and no tactic to try, search failed!")
                        break
                    self.check_path_length(key)
                    count_lean_multiple_backtrack += 1
                    lean_output = self.theorem_object_dict[key].do_back_track(self.back_track_tactic(key))

                    # Keep the prompt path in sync with the shortened tree path.
                    del self.prompts_tactic_state_list[key][-1]

                    self.check_path_length(key)
        except Exception as e:
            print(f"key is {key}, exception happend")
            print(e)
            key_status = 'failed with error'
            count_lean_tactic = 0
            count_lean_single_backtrack = 0
            count_lean_multiple_backtrack = 0

        finally:
            # Ship everything the parent needs back via a temp pickle.
            # NOTE(review): '~' is not expanded by open() -- literal relative
            # path, matching the reader in search(). Handle is not closed
            # explicitly (CPython GC only).
            report_key_file = {}
            report_key_file['count_lean_multiple_backtrack'] = count_lean_multiple_backtrack + count_lean_tactic
            report_key_file['count_lean_single_backtrack'] = count_lean_single_backtrack + count_lean_tactic
            report_key_file['count_lean_tactic_success'] = count_lean_tactic
            report_key_file['key_status'] = key_status
            report_key_file['tactic_list_tree'] = self.tactic_list_tree[key]
            report_key_file['prompts_tactic_state_list'] = self.prompts_tactic_state_list[key]
            report_key_file['theorem_object_dict'] = self.theorem_object_dict[key]
            report_key_file['node_relation'] = self.parent_node_of_node[key]
            report_key_file['whether_backtrack'] = whether_backtrack
            pickle.dump(report_key_file, open(f'~/leandojo_project/atp_research/DFS/temp/temp_file_{key}_{self.round}_{self.experiment_id}.pkl','wb'))
454
+
455
+
456
+
457
+ def save_outcome(self):
458
+ counter_in_process = 0
459
+ counter_success = 0
460
+ counter_failed = 0
461
+ counter_too_long = 0
462
+ count_lean_single_backtrack = 0
463
+ count_lean_multiple_backtrack = 0
464
+ count_lean_tactic_success = 0
465
+ counter_failed_with_error = 0
466
+
467
+ proof_dict = {}
468
+ for key in tqdm(self.test_theorem_list,total=len(self.test_theorem_list),desc='saving results'):
469
+ count_lean_single_backtrack += self.count_lean_dict[key]['count_lean_single_backtrack']
470
+ count_lean_multiple_backtrack += self.count_lean_dict[key]['count_lean_multiple_backtrack']
471
+ count_lean_tactic_success += self.count_lean_dict[key]['count_lean_tactic_success']
472
+ proof_dict[key] = self.theorem_object_dict[key].get_current_lean_proof()
473
+ if self.root[key] == 'open':
474
+ counter_in_process += 1
475
+ if self.root[key] == 'success':
476
+ counter_success += 1
477
+ if self.root[key] == 'failed':
478
+ counter_failed += 1
479
+ if self.root[key] == 'failed, too long':
480
+ counter_too_long += 1
481
+ if self.root[key] == 'failed with error':
482
+ counter_failed_with_error += 1
483
+ self.counter_success = counter_success
484
+ self.counter_failed = counter_failed
485
+ self.counter_failed_with_error = counter_failed_with_error
486
+ self.counter_too_long = counter_too_long
487
+ if counter_success + counter_failed + counter_too_long + counter_failed_with_error + counter_in_process!= len(test_theorem_list):
488
+ assert False, 'number of theorm not equal to success, failed or too long, or in process'
489
+ outcome = {}
490
+ outcome['stats'] = {}
491
+ outcome['stats']['total_lean_count_single_backtrack'] = count_lean_single_backtrack
492
+ outcome['stats']['total_lean_count_multiple_backtrack'] = count_lean_multiple_backtrack
493
+ outcome['stats']['count_lean_tactic_success'] = count_lean_tactic_success
494
+ outcome['stats']['num_success'] = self.counter_success
495
+ outcome['stats']['num_failed'] = self.counter_failed
496
+ outcome['stats']['num_failed_with_error'] = self.counter_failed_with_error
497
+ outcome['stats']['num_too_long'] = self.counter_too_long
498
+ outcome['stats']['num_sampled_tactics'] = self.num_sampled_tactics
499
+ outcome['stats']['temperature'] = self.temperature
500
+ outcome['key_final_state'] = self.root
501
+ outcome['key_lean_count'] = self.count_lean_dict
502
+ outcome['key_proof'] = proof_dict
503
+ outcome['tactic_list_tree'] = self.tactic_list_tree
504
+ outcome['round_count'] = self.round_count
505
+
506
+ pickle.dump(outcome, open(self.saved_file_path, 'wb'))
507
+
508
if __name__ == '__main__':
    # Fixes: every numeric CLI argument was declared type=str (and then
    # re-cast ad hoc at each use site), and all help strings were copy-pasted
    # as 'test_data_path'. The positional interface is unchanged.
    parser = argparse.ArgumentParser(description='DFS proof search over theorems with a vLLM tactic generator.')
    parser.add_argument('checkpoint_path', type=str, help='path to the fine-tuned model checkpoint')
    parser.add_argument('number_of_gpu', type=int, help='tensor-parallel size for vLLM')
    parser.add_argument('test_data_path', type=str, help='JSON file containing the list of theorem keys')
    parser.add_argument('saved_file_path', type=str, help='output path for the pickled outcome dict')
    parser.add_argument('max_workers', type=int, help='process-pool size for concurrent Lean verification')
    parser.add_argument('num_sampled_tactics', type=int, help='tactics sampled per prompt each round')
    parser.add_argument('temperature', type=float, help='sampling temperature for generation')
    parser.add_argument('num_test_theorem', type=int, help='number of theorems to evaluate from the test set')

    args = parser.parse_args()
    checkpoint = args.checkpoint_path
    number_of_gpu = args.number_of_gpu
    test_data_path = args.test_data_path
    saved_file_path = args.saved_file_path
    max_workers = args.max_workers
    num_sampled_tactics = args.num_sampled_tactics
    temperature = args.temperature
    num_test_theorem = args.num_test_theorem
    swap_space = 100  # CPU swap space handed to vLLM -- units per vLLM docs (GiB)

    print(f'checkpoint is {checkpoint}')
    print(f'number_of_gpu is {number_of_gpu}')
    print(f'test_data_path is {test_data_path}')
    print(f'saved_file_path is {saved_file_path}')
    print(f'num_test_theorem is {num_test_theorem}')

    print(f'max_workers is {max_workers}')
    print(f'number_sampled_tactic is {num_sampled_tactics}')
    print(f'temperature is {temperature}')

    print(f'swap_space is {swap_space}')

    random.seed(42)
    with open(test_data_path, 'r') as f:
        test_theorem_list = json.load(f)

    # Cap the workload; the parser already coerced num_test_theorem to int.
    test_theorem_list = test_theorem_list[:num_test_theorem]

    # Unique id so concurrent runs do not clobber each other's temp pickles.
    experiment_id = uuid.uuid4()
    print(experiment_id)

    # NOTE: `llm` must stay module-global -- DFS.search() reads it from the
    # module namespace rather than taking it as a parameter.
    llm = LLM(model=checkpoint, tensor_parallel_size=number_of_gpu, swap_space=swap_space)
    evaluate_obj = DFS(num_sampled_tactics=num_sampled_tactics, temperature=temperature, test_theorem_list=test_theorem_list, max_workers=max_workers, saved_file_path=saved_file_path, experiment_id=experiment_id)
    evaluate_obj.search()
    print('Now we start saving')
    evaluate_obj.save_outcome()
    print('Now we finish saving. exit')