ZaynZhu committed
Commit 7c08dc3 · 0 Parent(s)

Clean version without large assets

This view is limited to 50 files because it contains too many changes.

Files changed (50)
  1. .gitattributes +40 -0
  2. .gitignore +18 -0
  3. Paper2Poster/.gitignore +17 -0
  4. Paper2Poster/LICENSE +21 -0
  5. Paper2Poster/Paper2Poster-eval/create_paper_questions.py +40 -0
  6. Paper2Poster/Paper2Poster-eval/eval_poster_pipeline.py +479 -0
  7. Paper2Poster/Paper2Poster-eval/eval_qa_fix.py +114 -0
  8. Paper2Poster/PosterAgent/LLM_direct_generate.py +103 -0
  9. Paper2Poster/PosterAgent/LLM_direct_generate_beamer.py +189 -0
  10. Paper2Poster/PosterAgent/__init__.py +16 -0
  11. Paper2Poster/PosterAgent/apply_theme.py +281 -0
  12. Paper2Poster/PosterAgent/beamer_pipeline.py +182 -0
  13. Paper2Poster/PosterAgent/create_dataset.py +69 -0
  14. Paper2Poster/PosterAgent/deoverflow.py +234 -0
  15. Paper2Poster/PosterAgent/deoverflow_parallel.py +485 -0
  16. Paper2Poster/PosterAgent/fill_and_style.py +215 -0
  17. Paper2Poster/PosterAgent/gen_beamer_code.py +299 -0
  18. Paper2Poster/PosterAgent/gen_outline_layout.py +851 -0
  19. Paper2Poster/PosterAgent/gen_outline_layout_parallel.py +949 -0
  20. Paper2Poster/PosterAgent/gen_poster_content.py +529 -0
  21. Paper2Poster/PosterAgent/gen_pptx_code.py +249 -0
  22. Paper2Poster/PosterAgent/new_pipeline.py +547 -0
  23. Paper2Poster/PosterAgent/parse_raw.py +237 -0
  24. Paper2Poster/PosterAgent/poster_gen_pipeline.py +101 -0
  25. Paper2Poster/PosterAgent/tree_split_layout.py +750 -0
  26. Paper2Poster/README.md +215 -0
  27. Paper2Poster/__init__.py +3 -0
  28. Paper2Poster/camel/__init__.py +25 -0
  29. Paper2Poster/camel/agents/__init__.py +44 -0
  30. Paper2Poster/camel/agents/base.py +29 -0
  31. Paper2Poster/camel/agents/chat_agent.py +1539 -0
  32. Paper2Poster/camel/agents/critic_agent.py +202 -0
  33. Paper2Poster/camel/agents/deductive_reasoner_agent.py +303 -0
  34. Paper2Poster/camel/agents/embodied_agent.py +201 -0
  35. Paper2Poster/camel/agents/knowledge_graph_agent.py +259 -0
  36. Paper2Poster/camel/agents/multi_hop_generator_agent.py +117 -0
  37. Paper2Poster/camel/agents/programmed_agent_instruction.py +203 -0
  38. Paper2Poster/camel/agents/role_assignment_agent.py +141 -0
  39. Paper2Poster/camel/agents/search_agent.py +133 -0
  40. Paper2Poster/camel/agents/task_agent.py +410 -0
  41. Paper2Poster/camel/agents/tool_agents/__init__.py +20 -0
  42. Paper2Poster/camel/agents/tool_agents/base.py +39 -0
  43. Paper2Poster/camel/agents/tool_agents/hugging_face_tool_agent.py +206 -0
  44. Paper2Poster/camel/benchmarks/__init__.py +30 -0
  45. Paper2Poster/camel/benchmarks/apibank.py +565 -0
  46. Paper2Poster/camel/benchmarks/apibench.py +500 -0
  47. Paper2Poster/camel/benchmarks/base.py +152 -0
  48. Paper2Poster/camel/benchmarks/gaia.py +478 -0
  49. Paper2Poster/camel/benchmarks/nexus.py +518 -0
  50. Paper2Poster/camel/benchmarks/ragbench.py +333 -0
.gitattributes ADDED
@@ -0,0 +1,40 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
+ *.jpg filter=lfs diff=lfs merge=lfs -text
+ *.pdf filter=lfs diff=lfs merge=lfs -text
+ *.wav filter=lfs diff=lfs merge=lfs -text
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,18 @@
+ input/
+ output/Paper2Poster/assets/
+ Paper2Video/assets/
+ posterbuilder/latex_proj/figures/
+ *.png
+ *.pdf
+ *.jpg
+ *.wav
+ *.mp4
+ __pycache__/
+ *.png
+ *.jpg
+ *.pdf
+ *.wav
+ *.mp4
+ Paper2Poster/assets/
+ Paper2Video/assets/
+ posterbuilder/latex_proj/figures/
Paper2Poster/.gitignore ADDED
@@ -0,0 +1,17 @@
+ .env
+ .vscode/
+ ablations/
+ **/__pycache__/
+ *_generated_posters/
+ *_images_and_tables/
+ contents/
+ tmp/
+ tree_splits/
+ eval_results/
+ Paper2Poster-data/
+ Example/
+ *.ipynb
+ eval_time_detail_parallel/
+ *.sh
+ .claude
+ CLAUDE.md
Paper2Poster/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 Paper2Poster
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
Paper2Poster/Paper2Poster-eval/create_paper_questions.py ADDED
@@ -0,0 +1,40 @@
+ from utils.poster_eval_utils import *
+ import argparse
+ import os
+ import json
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--paper_folder', type=str, default=None)
+     parser.add_argument('--model_name', type=str, default='o3')
+     args = parser.parse_args()
+
+     paper_text = get_poster_text(os.path.join(args.paper_folder, 'paper.pdf'))
+
+     if args.model_name == '4o':
+         model_type = ModelType.GPT_4O
+     elif args.model_name == 'o3':
+         model_type = ModelType.O3
+     detail_qa = get_questions(paper_text, 'detail', model_type)
+     understanding_qa = get_questions(paper_text, 'understanding', model_type)
+
+     detail_q, detail_a, detail_aspects = get_answers_and_remove_answers(detail_qa)
+     understanding_q, understanding_a, understanding_aspects = get_answers_and_remove_answers(understanding_qa)
+
+     final_qa = {}
+     detail_qa = {
+         'questions': detail_q,
+         'answers': detail_a,
+         'aspects': detail_aspects,
+     }
+
+     understanding_qa = {
+         'questions': understanding_q,
+         'answers': understanding_a,
+         'aspects': understanding_aspects,
+     }
+     final_qa['detail'] = detail_qa
+     final_qa['understanding'] = understanding_qa
+
+     with open(os.path.join(args.paper_folder, f'{args.model_name}_qa.json'), 'w') as f:
+         json.dump(final_qa, f, indent=4)
Paper2Poster/Paper2Poster-eval/eval_poster_pipeline.py ADDED
@@ -0,0 +1,479 @@
1
+ from utils.poster_eval_utils import *
2
+ import json
3
+ from utils.wei_utils import get_agent_config
4
+ import argparse
5
+ from dotenv import load_dotenv
6
+ import tempfile
7
+ import shutil
8
+ import os
9
+ import glob
10
+ import re
11
+
12
+ load_dotenv()
13
+
14
+ def run_qa_and_update_results(
15
+ args,
16
+ raw_folder,
17
+ gen_poster_path,
18
+ save_path,
19
+ single_model_name=None,
20
+ del_model_name=None,
21
+ ):
22
+ """
23
+ If single_model_name is provided, run QA for that one model only,
24
+ but update an existing JSON file (which already contains the other
25
+ models' results) and re-compute the overall averages.
26
+
27
+ If single_model_name is None, run QA for all models in all_model_names
28
+ and write a new JSON file.
29
+
30
+ :param raw_folder: Path to folder with 'o3_qa.json'.
31
+ :param gen_poster_path: Path to the generated poster image.
32
+ :param save_path: Directory where overall_qa_result.json is saved or should be written.
33
+ :param all_model_names: List of model names (e.g. ['vllm_qwen_vl', '4o', 'o3']).
34
+ :param single_model_name: Optional single model name.
35
+ """
36
+
37
+ # Load the QA data (questions, answers, aspects)
38
+ qa_dict = json.load(open(os.path.join(raw_folder, 'o3_qa.json'), 'r'))
39
+ detail_qa = qa_dict['detail']
40
+ understanding_qa = qa_dict['understanding']
41
+
42
+ # Option A: Single model case
43
+ if single_model_name is not None:
44
+ qa_input_token, qa_output_token = 0, 0
45
+ # Load the existing JSON with all previously computed results
46
+ existing_path = os.path.join(save_path, "overall_qa_result.json")
47
+ with open(existing_path, 'r') as f:
48
+ overall_qa_result = json.load(f)
49
+
50
+ if del_model_name is not None:
51
+ # Remove the specified model from the existing results
52
+ if del_model_name in overall_qa_result['qa_result']:
53
+ del overall_qa_result['qa_result'][del_model_name]
54
+ print(f"Removed model {del_model_name} from existing results.")
55
+
56
+ if single_model_name in overall_qa_result['qa_result']:
57
+ print(f"Model {single_model_name} already evaluated. Skipping.")
58
+ return
59
+
60
+ # Evaluate QA for the single_model_name
61
+ print(f"Running QA for single model: {single_model_name}")
62
+ agent_config = get_agent_config(single_model_name)
63
+
64
+ if args.poster_method == 'paper':
65
+ poster_images = open_folder_images(gen_folder, args.paper_name.replace(' ', '_'), format='jpg')
66
+ else:
67
+ poster_images = [Image.open(gen_poster_path)]
68
+
69
+ poster_images = [ensure_under_limit_pil(image) for image in poster_images]
70
+
71
+ detail_accuracy, detail_aspect_accuracy, detail_agent_answers, input_token, output_token = eval_qa_get_answer(
72
+ poster_input=poster_images,
73
+ questions=detail_qa['questions'],
74
+ answers=detail_qa['answers'],
75
+ aspects=detail_qa['aspects'],
76
+ input_type='image',
77
+ agent_config=agent_config
78
+ )
79
+ qa_input_token += input_token
80
+ qa_output_token += output_token
81
+ print('Detail QA accuracy:', detail_accuracy)
82
+
83
+ understanding_accuracy, understanding_aspect_accuracy, understanding_agent_answers, input_token, output_token = eval_qa_get_answer(
84
+ poster_input=poster_images,
85
+ questions=understanding_qa['questions'],
86
+ answers=understanding_qa['answers'],
87
+ aspects=understanding_qa['aspects'],
88
+ input_type='image',
89
+ agent_config=agent_config
90
+ )
91
+ qa_input_token += input_token
92
+ qa_output_token += output_token
93
+ print('Understanding QA accuracy:', understanding_accuracy)
94
+
95
+ # Update QA result for this one model
96
+ # overall_qa_result["qa_result"] is assumed to already have the others
97
+ overall_qa_result['qa_result'][single_model_name] = {
98
+ 'detail_accuracy': detail_accuracy,
99
+ 'detail_aspect_accuracy': detail_aspect_accuracy,
100
+ 'detail_agent_answers': detail_agent_answers,
101
+ 'understanding_accuracy': understanding_accuracy,
102
+ 'understanding_aspect_accuracy': understanding_aspect_accuracy,
103
+ 'understanding_agent_answers': understanding_agent_answers
104
+ }
105
+
106
+ # Now re-compute the averages across all models present in the JSON
107
+ # Grab all model entries from overall_qa_result['qa_result']
108
+ all_models_in_file = list(overall_qa_result['qa_result'].keys())
109
+ detail_accs = []
110
+ understanding_accs = []
111
+ for m in all_models_in_file:
112
+ detail_accs.append(overall_qa_result['qa_result'][m]['detail_accuracy'])
113
+ understanding_accs.append(overall_qa_result['qa_result'][m]['understanding_accuracy'])
114
+
115
+ avg_detail_accuracy = float(np.mean(detail_accs)) if detail_accs else 0.0
116
+ avg_understanding_accuracy = float(np.mean(understanding_accs)) if understanding_accs else 0.0
117
+
118
+ overall_qa_result['avg_detail_accuracy'] = avg_detail_accuracy
119
+ overall_qa_result['avg_understanding_accuracy'] = avg_understanding_accuracy
120
+
121
+ # Finally, overwrite the same JSON file with the updated results
122
+ with open(existing_path, 'w') as f:
123
+ json.dump(overall_qa_result, f, indent=4)
124
+
125
+ print(f'Input tokens: {qa_input_token}')
126
+ print(f'Output tokens: {qa_output_token}')
127
+
128
+ print('Updated overall_qa_result.json with single-model results.')
129
+ print('New average detail accuracy:', avg_detail_accuracy)
130
+ print('New average understanding accuracy:', avg_understanding_accuracy)
131
+
132
+ if __name__ == '__main__':
133
+ parser = argparse.ArgumentParser()
134
+ parser.add_argument('--paper_name', type=str)
135
+ parser.add_argument('--base_dir', type=str, default='Paper2Poster-data')
136
+ parser.add_argument('--poster_method', type=str)
137
+ parser.add_argument('--poster_image_name', type=str, default='poster.png', choices=['poster.png'])
138
+ parser.add_argument('--metric', type=str, choices=['stats', 'qa', 'judge', 'word_count', 'token_count', 'figure_count', 'aesthetic_judge'], default='stats')
139
+ parser.add_argument('--fix', type=str, default=None)
140
+ parser.add_argument('--del_model_name', type=str, default=None)
141
+
142
+ args = parser.parse_args()
143
+
144
+ raw_poster_path = f'{args.base_dir}/{args.paper_name}/poster.png'
145
+ raw_folder = f'{args.base_dir}/{args.paper_name}'
146
+
147
+ gen_poster_path = f'{args.poster_method}/{args.base_dir}/{args.paper_name}/{args.poster_image_name}'
148
+ gen_folder = f'{args.poster_method}/{args.base_dir}/{args.paper_name}'
149
+
150
+ save_path = f'eval_results/{args.paper_name}/{args.poster_method}'
151
+ os.makedirs(save_path, exist_ok=True)
152
+
153
+ if args.poster_method == 'paper':
154
+ if args.metric == 'qa' and args.fix is not None:
155
+ overall_qa_result = json.load(open(f'{save_path}/overall_qa_result.json', 'r'))
156
+ if args.fix in overall_qa_result['qa_result']:
157
+ print(f"Model {args.fix} already evaluated. Skipping.")
158
+ exit(0)
159
+ # create a temp folder to store the paper
160
+ # 1) Create a unique temp folder
161
+ temp_dir = tempfile.mkdtemp(prefix="eval_temp", suffix="_data")
162
+
163
+ # 2) Build your source directory path, replacing spaces
164
+ paper_slug = args.paper_name.replace(' ', '_')
165
+ source_dir = os.path.join('<4o_vllm_qwen>_images_and_tables', paper_slug)
166
+
167
+ # 3) Sequentially copy files named "<paper_slug>-<index>.png"
168
+ index = 1
169
+ while True:
170
+ filename = f"{paper_slug}-{index}.png"
171
+ src_path = os.path.join(source_dir, filename)
172
+ if not os.path.isfile(src_path):
173
+ # stop once the next index is missing
174
+ break
175
+ shutil.copy2(src_path, os.path.join(temp_dir, filename))
176
+ index += 1
177
+ if index > 20 and args.metric != 'word_count' and args.metric != 'token_count':
178
+ break
179
+
180
+ gen_folder = temp_dir
181
+ gen_poster_path = f'{args.base_dir}/{args.paper_name}/paper.pdf'
182
+
183
+
184
+ print('Evaluating poster:', args.paper_name)
185
+
186
+ if args.metric == 'stats':
187
+ stats_file = os.path.join(save_path, 'stats_result.json')
188
+
189
+ # 1) load existing results if there are any
190
+ if os.path.exists(stats_file):
191
+ with open(stats_file, 'r') as f:
192
+ stats_result = json.load(f)
193
+ print(f"Loaded existing stats from {stats_file}")
194
+ else:
195
+ stats_result = {}
196
+
197
+ # 2) CLIP similarity
198
+ if 'CLIP_similarity' not in stats_result:
199
+ _, cos_sim = compare_folders_with_clip(raw_folder, gen_folder)
200
+ stats_result['CLIP_similarity'] = cos_sim
201
+ print(f'CLIP similarity: {cos_sim}')
202
+ else:
203
+ print(f"Skipping CLIP similarity (already {stats_result['CLIP_similarity']})")
204
+
205
+ # 3) we only need to regenerate markdown+images if any of the text/image metrics is missing
206
+ need_eval = any(k not in stats_result for k in ('textual_ppl', 'mixtual_ppl', 'visual_relevance', 'visual_ppl'))
207
+ if need_eval:
208
+ images, poster_text, raw_markdown, new_markdown = gen_eval_markdown(
209
+ args.paper_name,
210
+ args.poster_method,
211
+ gen_poster_path
212
+ )
213
+
214
+ # textual PPL
215
+ if 'textual_ppl' not in stats_result:
216
+ textual_ppl = get_ppl(poster_text)
217
+ stats_result['textual_ppl'] = textual_ppl
218
+ print(f'Textual PPL: {textual_ppl}')
219
+ else:
220
+ print(f"Skipping textual PPL (already {stats_result['textual_ppl']})")
221
+
222
+ # mixtual PPL
223
+ if 'mixtual_ppl' not in stats_result:
224
+ mixtual_ppl = get_ppl(new_markdown)
225
+ stats_result['mixtual_ppl'] = mixtual_ppl
226
+ print(f'Mixtual PPL: {mixtual_ppl}')
227
+ else:
228
+ print(f"Skipping mixtual PPL (already {stats_result['mixtual_ppl']})")
229
+
230
+ # visual relevance
231
+ if 'visual_relevance' not in stats_result:
232
+ if images:
233
+ sims = [
234
+ compute_cosine_similarity(v['image_clip_embedding'],
235
+ v['section_text_clip_embedding'])
236
+ for v in images.values()
237
+ ]
238
+ avg_sim = float(np.mean(sims))
239
+ stats_result['visual_relevance'] = avg_sim
240
+ print(f'Average cosine similarity: {avg_sim}')
241
+ else:
242
+ stats_result['visual_relevance'] = 0.0
243
+ print('No images found in the poster. Set visual_relevance to 0.')
244
+ else:
245
+ print(f"Skipping visual relevance (already {stats_result['visual_relevance']})")
246
+
247
+ if 'visual_ppl' not in stats_result or math.isnan(stats_result['visual_ppl']):
248
+ visual_ppls = []
249
+ for relative_path, v in images.items():
250
+ image_path = os.path.join('eval_poster_markdown', args.paper_name, args.poster_method, relative_path)
251
+ image = Image.open(image_path)
252
+ visual_ppl = get_visual_ppl(image, poster_text)
253
+ visual_ppls.append(visual_ppl)
254
+ avg_visual_ppl = float(np.mean(visual_ppls))
255
+ stats_result['visual_ppl'] = avg_visual_ppl
256
+ print(f'Average visual PPL: {avg_visual_ppl}')
257
+ else:
258
+ print("All textual and visual metrics already computed; skipping gen_eval_markdown.")
259
+
260
+ if 'interleaved_ppl' not in stats_result:
261
+ interleaved_ppl = compute_interleaved_ppl(args.paper_name, args.poster_method)
262
+ stats_result['interleaved_ppl'] = interleaved_ppl
263
+ print(f'Interleaved PPL: {interleaved_ppl}')
264
+ else:
265
+ print(f"Skipping interleaved PPL (already {stats_result['interleaved_ppl']})")
266
+
267
+ if 'poster_image_ppl' not in stats_result:
268
+ if args.poster_method == 'paper':
269
+ poster_images = open_folder_images(gen_folder, args.paper_name.replace(' ', '_'), format='jpg')
270
+ else:
271
+ poster_images = [Image.open(gen_poster_path)]
272
+ poster_image_ppl = compute_poster_image_ppl(poster_images)
273
+ stats_result['poster_image_ppl'] = poster_image_ppl
274
+ print(f'Poster image PPL: {poster_image_ppl}')
275
+ else:
276
+ print(f"Skipping poster image PPL (already {stats_result['poster_image_ppl']})")
277
+
278
+ # 4) write back updated file
279
+ with open(stats_file, 'w') as f:
280
+ json.dump(stats_result, f, indent=4)
281
+ print(f"Updated stats written to {stats_file}")
282
+ elif args.metric == 'figure_count':
283
+ save_file_path = os.path.join(save_path, 'figure_count.json')
284
+ if os.path.exists(save_file_path):
285
+ print(f"Figure count already exists at {save_file_path}. Skipping.")
286
+ else:
287
+ figure_count = gen_eval_markdown(
288
+ args.paper_name,
289
+ args.poster_method,
290
+ gen_poster_path,
291
+ figure_count_only=True
292
+ )
293
+ with open(save_file_path, 'w') as f:
294
+ json.dump({'figure_count': figure_count}, f, indent=4)
295
+ print(f"Figure count saved to {save_file_path}")
296
+ elif args.metric == 'qa':
297
+ if args.fix is not None:
298
+ run_qa_and_update_results(
299
+ args,
300
+ raw_folder,
301
+ gen_poster_path,
302
+ save_path,
303
+ single_model_name=args.fix,
304
+ del_model_name=args.del_model_name
305
+ )
306
+ else:
307
+ overall_qa_result = {}
308
+ qa_result = {}
309
+ qa_dict = json.load(open(os.path.join(raw_folder, 'o3_qa.json'), 'r'))
310
+ detail_qa = qa_dict['detail']
311
+ understanding_qa = qa_dict['understanding']
312
+ model_names = [
313
+ '4o',
314
+ 'o3',
315
+ '4o-mini'
316
+ ]
317
+ if args.poster_method == 'paper':
318
+ poster_images = open_folder_images(gen_folder, args.paper_name.replace(' ', '_'))
319
+ else:
320
+ poster_images = [Image.open(gen_poster_path)]
321
+
322
+ poster_images = [ensure_under_limit_pil(image) for image in poster_images]
323
+
324
+ for model_name in model_names:
325
+ qa_input_token, qa_output_token = 0, 0
326
+ print('QA model:', model_name)
327
+ agent_config = get_agent_config(model_name)
328
+ detail_accuracy, detail_aspect_accuracy, detail_agent_answers, input_token, output_token = eval_qa_get_answer(
329
+ poster_input=poster_images,
330
+ questions=detail_qa['questions'],
331
+ answers=detail_qa['answers'],
332
+ aspects=detail_qa['aspects'],
333
+ input_type='image',
334
+ agent_config=agent_config
335
+ )
336
+ print(f'{model_name} Detail QA accuracy:', detail_accuracy)
337
+ qa_input_token += input_token
338
+ qa_output_token += output_token
339
+
340
+ understanding_accuracy, understanding_aspect_accuracy, understanding_agent_answers, input_token, output_token = eval_qa_get_answer(
341
+ poster_input=poster_images,
342
+ questions=understanding_qa['questions'],
343
+ answers=understanding_qa['answers'],
344
+ aspects=understanding_qa['aspects'],
345
+ input_type='image',
346
+ agent_config=agent_config
347
+ )
348
+ print(f'{model_name} Understanding QA accuracy:', understanding_accuracy)
349
+ qa_input_token += input_token
350
+ qa_output_token += output_token
351
+
352
+ qa_result[model_name] = {
353
+ 'detail_accuracy': detail_accuracy,
354
+ 'detail_aspect_accuracy': detail_aspect_accuracy,
355
+ 'detail_agent_answers': detail_agent_answers,
356
+ 'understanding_accuracy': understanding_accuracy,
357
+ 'understanding_aspect_accuracy': understanding_aspect_accuracy,
358
+ 'understanding_agent_answers': understanding_agent_answers
359
+ }
360
+
361
+ print(f'{model_name} Input tokens:', qa_input_token)
362
+ print(f'{model_name} Output tokens:', qa_output_token)
363
+
364
+ # average the results
365
+ avg_detail_accuracy = np.mean([qa_result[model_name]['detail_accuracy'] for model_name in model_names])
366
+ avg_understanding_accuracy = np.mean([qa_result[model_name]['understanding_accuracy'] for model_name in model_names])
367
+
368
+ print('Average detail accuracy:', avg_detail_accuracy)
369
+ print('Average understanding accuracy:', avg_understanding_accuracy)
370
+
371
+ overall_qa_result['avg_detail_accuracy'] = avg_detail_accuracy
372
+ overall_qa_result['avg_understanding_accuracy'] = avg_understanding_accuracy
373
+ overall_qa_result['qa_result'] = qa_result
374
+
375
+ with open(f'{save_path}/overall_qa_result.json', 'w') as f:
376
+ json.dump(overall_qa_result, f, indent=4)
377
+
378
+ elif args.metric == 'word_count':
379
+ if args.poster_method == 'paper':
380
+ # loop through all images in the folder
381
+ image_paths = open_folder_images(gen_folder, args.paper_name.replace(' ', '_'), return_path=True)
382
+ word_count = 0
383
+ for image_path in image_paths:
384
+ # count words in each image
385
+ word_count += count_words_in_image(image_path)
386
+ else:
387
+ word_count = count_words_in_image(gen_poster_path)
388
+ # save to json
389
+ with open(f'{save_path}/word_count.json', 'w') as f:
390
+ json.dump({'word_count': word_count}, f, indent=4)
391
+
392
+ elif args.metric == 'token_count':
393
+ if args.poster_method == 'paper':
394
+ # loop through all images in the folder
395
+ image_paths = open_folder_images(gen_folder, args.paper_name.replace(' ', '_'), return_path=True)
396
+ token_count = 0
397
+ for image_path in image_paths:
398
+ # count tokens in each image
399
+ token_count += count_tokens_in_image(image_path)
400
+ else:
401
+ token_count = count_tokens_in_image(gen_poster_path)
402
+ # save to json
403
+ with open(f'{save_path}/token_count.json', 'w') as f:
404
+ json.dump({'token_count': token_count}, f, indent=4)
405
+ elif args.metric == 'judge':
406
+ agent_config = get_agent_config('4o')
407
+
408
+ if args.poster_method == 'paper':
409
+ poster_images = open_folder_images(gen_folder, args.paper_name.replace(' ', '_'))
410
+ else:
411
+ poster_images = [Image.open(gen_poster_path)]
412
+
413
+ results = eval_vlm_as_judge(
414
+ poster_image_list=poster_images,
415
+ agent_config=agent_config,
416
+ )
417
+
418
+ aesthetic_aspects = [
419
+ 'aesthetic_element',
420
+ 'aesthetic_engagement',
421
+ 'aesthetic_layout'
422
+ ]
423
+
424
+ information_aspects = [
425
+ 'information_low_level',
426
+ 'information_logic',
427
+ 'information_content',
428
+ ]
429
+
430
+ # compute average scores for all, for aesthetic, and for information
431
+ overall_average = np.mean([results[aspect]['score'] for aspect in results])
432
+ aesthetic_average = np.mean([results[aspect]['score'] for aspect in results if aspect in aesthetic_aspects])
433
+ information_average = np.mean([results[aspect]['score'] for aspect in results if aspect in information_aspects])
434
+
435
+ judge_result = {
436
+ 'overall_average': overall_average,
437
+ 'aesthetic_average': aesthetic_average,
438
+ 'information_average': information_average,
439
+ 'results': results
440
+ }
441
+
442
+ # save to json
443
+ with open(f'{save_path}/judge_result.json', 'w') as f:
444
+ json.dump(judge_result, f, indent=4)
445
+ elif args.metric == 'aesthetic_judge':
446
+ agent_config = get_agent_config('4o')
447
+
448
+ if args.poster_method == 'paper':
449
+ poster_images = open_folder_images(gen_folder, args.paper_name.replace(' ', '_'))
450
+ else:
451
+ poster_images = [Image.open(gen_poster_path)]
452
+
453
+ results = eval_vlm_as_judge(
454
+ poster_image_list=poster_images,
455
+ agent_config=agent_config,
456
+ aspect='aesthetic'
457
+ )
458
+
459
+ aesthetic_aspects = [
460
+ 'aesthetic_element',
461
+ 'aesthetic_engagement',
462
+ 'aesthetic_layout'
463
+ ]
464
+
465
+ aesthetic_average = np.mean([results[aspect]['score'] for aspect in results if aspect in aesthetic_aspects])
466
+
467
+ judge_result = {
468
+ 'aesthetic_average': aesthetic_average,
469
+ 'results': results
470
+ }
471
+
472
+ # save to json
473
+ with open(f'{save_path}/aesthetic_judge_result.json', 'w') as f:
474
+ json.dump(judge_result, f, indent=4)
475
+
476
+ if args.poster_method == 'paper':
477
+ # remove the temp folder
478
+ shutil.rmtree(temp_dir)
479
+ print(f"Removed temporary folder {temp_dir}")
Paper2Poster/Paper2Poster-eval/eval_qa_fix.py ADDED
@@ -0,0 +1,114 @@
+ #!/usr/bin/env python3
+ """
+ Run eval_poster_pipeline.py for every sub-folder in poster_sum_100,
+ using up to 10 threads. poster_method and fix are now taken from
+ command-line arguments.
+
+ Example:
+     python run_eval_threads.py \
+         --poster_method poster_sum_50 \
+         --fix llama-3-70b-vl
+
+ """
+ import argparse
+ import concurrent.futures as cf
+ import pathlib
+ import signal
+ import subprocess
+ import sys
+
+ BASE_DIR = pathlib.Path("poster_sum_100")  # directory holding the papers
+
+ # ── Argument parsing ───────────────────────────────────────────────────────────
+ parser = argparse.ArgumentParser(
+     description="Run eval_poster_pipeline.py concurrently on all papers."
+ )
+ parser.add_argument(
+     "--poster_method",
+     default="poster_sum_100",
+     help="Name of the poster-generation method to evaluate (default: %(default)s)",
+ )
+ parser.add_argument(
+     "--fix",
+     default="qwen-2.5-vl-72b",
+     help="Value to pass to --fix in eval_poster_pipeline.py (default: %(default)s)",
+ )
+
+ parser.add_argument(
+     '--max_workers',
+     type=int,
+     default=1,
+ )
+
+ parser.add_argument('--del_model_name', type=str)
+ args = parser.parse_args()
+ # ───────────────────────────────────────────────────────────────────────────────
+
+
+ MAX_WORKERS = args.max_workers
+
+ def run_pipeline(subfolder: str, poster_method: str, fix: str) -> None:
+     """Invoke eval_poster_pipeline.py for a single paper."""
+     cmd = [
+         "python",
+         "eval_poster_pipeline.py",
+         "--paper_name",
+         subfolder,
+         "--poster_method",
+         poster_method,
+         "--poster_image_name",
+         "poster.png",
+         "--metric",
+         "qa",
+         "--fix",
+         fix,
+     ]
+     if args.del_model_name:
+         cmd += ["--del_model_name", args.del_model_name]
+     subprocess.run(cmd, check=True)
+
+
+ MAX_RETRIES = 50
+
+ def run_with_retries(folder: str, poster_method, fix) -> None:
+     """
+     Tries to run_pipeline up to MAX_RETRIES times before giving up.
+     """
+     for attempt in range(1, MAX_RETRIES + 1):
+         try:
+             run_pipeline(folder, poster_method, fix)
+             return
+         except Exception as e:
+             if attempt < MAX_RETRIES:
+                 print(f"⚠️ {folder}: attempt {attempt} failed ({e!r}), retrying…")
+             else:
+                 # Last attempt also failed, re-raise so the pool will catch it
+                 raise
+
+ def main() -> None:
+     folders = sorted(p.name for p in BASE_DIR.iterdir() if p.is_dir())
+
+     with cf.ThreadPoolExecutor(max_workers=MAX_WORKERS) as pool:
+         futures = {
+             pool.submit(run_with_retries, f, args.poster_method, args.fix): f
+             for f in folders
+         }
+         for fut in cf.as_completed(futures):
+             paper = futures[fut]
+             try:
+                 fut.result()
+                 print(f"✓ {paper} done")
+             except Exception as e:
+                 print(f"✗ {paper} failed after {MAX_RETRIES} attempts: {e}", file=sys.stderr)
+
+ # ── Graceful shutdown on Ctrl-C / SIGTERM ──────────────────────────────────────
+ def _handle_signal(signum, frame):
+     print("\nReceived signal, shutting down…", file=sys.stderr)
+     sys.exit(1)
+
+
+ signal.signal(signal.SIGINT, _handle_signal)
+ signal.signal(signal.SIGTERM, _handle_signal)
+
+ # ── Entry point ────────────────────────────────────────────────────────────────
+ if __name__ == "__main__":
+     main()
Paper2Poster/PosterAgent/LLM_direct_generate.py ADDED
@@ -0,0 +1,103 @@
+ from dotenv import load_dotenv
+ from utils.src.utils import get_json_from_response
+
+ from camel.models import ModelFactory
+ from camel.agents import ChatAgent
+
+
+ from utils.wei_utils import account_token, get_agent_config, html_to_png
+
+ from utils.pptx_utils import *
+ from utils.critic_utils import *
+ import yaml
+ import time
+ from jinja2 import Environment, StrictUndefined
+ from utils.poster_eval_utils import get_poster_text
+ import argparse
+ import json
+ import os
+
+ load_dotenv()
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--paper_path', type=str)
+     parser.add_argument('--model_name', type=str, default='4o')
+
+     args = parser.parse_args()
+
+     # get current directory
+     current_dir = os.path.dirname(os.path.abspath(__file__))
+
+     meta_dir = args.paper_path.replace('paper.pdf', 'meta.json')
+     meta = json.load(open(meta_dir, 'r'))
+     poster_width = meta['width']
+     poster_height = meta['height']
+
+     output_dir = f"{args.model_name}_HTML/{args.paper_path.replace('paper.pdf', '')}"
+     os.makedirs(output_dir, exist_ok=True)
+
+     total_input_token = 0
+     total_output_token = 0
+
+     start_time = time.time()
+     model_config = get_agent_config(args.model_name)
+     model = ModelFactory.create(
+         model_platform=model_config['model_platform'],
+         model_type=model_config['model_type'],
+         model_config_dict=model_config['model_config'],
+     )
+     paper_text = get_poster_text(args.paper_path)
+
+     actor_agent_name = 'LLM_gen_HTML'
+
+     with open(f'prompt_templates/{actor_agent_name}.yaml', "r") as f:
+         content_config = yaml.safe_load(f)
+     jinja_env = Environment(undefined=StrictUndefined)
+     template = jinja_env.from_string(content_config["template"])
+
+     actor_sys_msg = content_config['system_prompt']
+     actor_agent = ChatAgent(
+         system_message=actor_sys_msg,
+         model=model,
+         message_window_size=None
+     )
+
+     jinja_args = {
+         'document_markdown': paper_text,
+         'poster_width': poster_width,
+         'poster_height': poster_height,
+     }
+     prompt = template.render(**jinja_args)
+
+     actor_agent.reset()
+     response = actor_agent.step(prompt)
+     input_token, output_token = account_token(response)
+     total_input_token += input_token
+     total_output_token += output_token
+     result_json = get_json_from_response(response.msgs[0].content)
+     html_str = result_json['HTML']
+
+     # write to poster.html
+     with open(f'{output_dir}/poster.html', 'w') as f:
+         f.write(html_str)
+
+     html_to_png(
+         os.path.join(current_dir, output_dir, 'poster.html'),
+         poster_width,
+         poster_height,
+         os.path.join(current_dir, output_dir, 'poster.png')
+     )
+
+
+     end_time = time.time()
+     elapsed_time = end_time - start_time
+
+     log = {
+         'input_token': total_input_token,
+         'output_token': total_output_token,
+         'time_taken': elapsed_time
+     }
+
+     with open(f'{output_dir}/log.json', 'w') as f:
+         json.dump(log, f, indent=4)
Paper2Poster/PosterAgent/LLM_direct_generate_beamer.py ADDED
@@ -0,0 +1,189 @@
1
+ import os
2
+ import json
3
+ import time
4
+ from dotenv import load_dotenv
5
+ from jinja2 import Environment, StrictUndefined
6
+
7
+ from utils.src.utils import get_json_from_response, account_token, html_to_png
8
+ from utils.config_utils import load_poster_yaml_config
9
+
10
+ from camel.models import ModelFactory
11
+ from camel.agents import ChatAgent
12
+ from camel.configs import ChatGPTConfig
13
+ from camel.types import ModelPlatformType, ModelType
14
+
15
+ load_dotenv()
16
+
17
+ def gen_beamer_poster_direct(
18
+ paper_text: str,
19
+ poster_width_cm: float = 120,
20
+ poster_height_cm: float = 90,
21
+ beamer_theme: str = "default",
22
+ output_dir: str = "output",
23
+ model_name: str = "4o"
24
+ ):
25
+ """
26
+ Generate Beamer poster directly from paper text using LLM.
27
+
28
+ Args:
29
+ paper_text: Extracted text from the paper
30
+ poster_width_cm: Poster width in centimeters
31
+ poster_height_cm: Poster height in centimeters
32
+ beamer_theme: Beamer theme name
33
+ output_dir: Output directory
34
+ model_name: Model name for generation
35
+ """
36
+ start_time = time.time()
37
+ total_input_token, total_output_token = 0, 0
38
+
39
+ # Load configuration
40
+ config_path = "utils/prompt_templates/LLM_gen_Beamer.yaml"
41
+ with open(config_path, "r") as f:
42
+ config = yaml.safe_load(f)
43
+
44
+ # Create model and agent
45
+ actor_model = ModelFactory.create(
46
+ model_platform=ModelPlatformType.OPENAI,
47
+ model_type=ModelType.GPT_4O,
48
+ model_config_dict=ChatGPTConfig().as_dict(),
49
+ )
50
+
51
+ actor_agent = ChatAgent(
52
+ system_message=config['system_prompt'],
53
+ model=actor_model,
54
+ message_window_size=None
55
+ )
56
+
57
+ # Prepare template arguments
58
+ jinja_args = {
59
+ 'document_markdown': paper_text,
60
+ 'poster_width_cm': poster_width_cm,
61
+ 'poster_height_cm': poster_height_cm,
62
+ 'beamer_theme': beamer_theme,
63
+ 'aspect_ratio': "169",
64
+ 'title_color': "[47, 85, 151]",
65
+ 'text_color': "[0, 0, 0]"
66
+ }
67
+
68
+ # Render template
69
+ jinja_env = Environment(undefined=StrictUndefined)
70
+ template = jinja_env.from_string(config["template"])
71
+ prompt = template.render(**jinja_args)
72
+
73
+ # Generate Beamer code
74
+ actor_agent.reset()
75
+ response = actor_agent.step(prompt)
76
+ input_token, output_token = account_token(response)
77
+ total_input_token += input_token
78
+ total_output_token += output_token
79
+
80
+ # Extract LaTeX code
81
+ result_json = get_json_from_response(response.msgs[0].content)
82
+ latex_str = result_json['LATEX']
83
+
84
+ # Save LaTeX file
85
+ os.makedirs(output_dir, exist_ok=True)
86
+ tex_path = os.path.join(output_dir, 'poster.tex')
87
+ with open(tex_path, 'w', encoding='utf-8') as f:
88
+ f.write(latex_str)
89
+
90
+ # Compile to PDF
91
+ print("Compiling LaTeX to PDF...")
92
+ success = compile_beamer_to_pdf(tex_path, output_dir)
93
+
94
+ if success:
95
+ print(f"✅ Beamer poster generated successfully: {tex_path}")
96
+ else:
97
+ print("❌ Failed to compile LaTeX to PDF")
98
+
99
+ # Save log
100
+ end_time = time.time()
101
+ elapsed_time = end_time - start_time
102
+
103
+ log = {
104
+ 'input_token': total_input_token,
105
+ 'output_token': total_output_token,
106
+ 'time_taken': elapsed_time,
107
+ 'output_format': 'beamer',
108
+ 'beamer_theme': beamer_theme
109
+ }
110
+
111
+ with open(os.path.join(output_dir, 'log.json'), 'w') as f:
112
+ json.dump(log, f, indent=4)
113
+
114
+ return tex_path, success
115
+
116
+ def compile_beamer_to_pdf(tex_path: str, output_dir: str = "."):
117
+ """
118
+ Compile Beamer .tex file to PDF using pdflatex.
119
+
120
+ Args:
121
+ tex_path: Path to .tex file
122
+ output_dir: Output directory for PDF
123
+ """
124
+ import subprocess
125
+
126
+ try:
127
+ # Run pdflatex twice for proper cross-references
128
+ result1 = subprocess.run(
129
+ ['pdflatex', '-output-directory', output_dir, tex_path],
130
+ capture_output=True,
131
+ text=True,
132
+ timeout=60
133
+ )
134
+
135
+ result2 = subprocess.run(
136
+ ['pdflatex', '-output-directory', output_dir, tex_path],
137
+ capture_output=True,
138
+ text=True,
139
+ timeout=60
140
+ )
141
+
142
+ if result1.returncode == 0 and result2.returncode == 0:
143
+ print(f"Successfully compiled {tex_path} to PDF")
144
+ return True
145
+ else:
146
+ print(f"Error compiling {tex_path}:")
147
+ print(result1.stderr)
148
+ print(result2.stderr)
149
+ return False
150
+
151
+ except subprocess.TimeoutExpired:
152
+ print(f"Timeout while compiling {tex_path}")
153
+ return False
154
+ except Exception as e:
155
+ print(f"Error compiling {tex_path}: {e}")
156
+ return False
157
+
158
+ if __name__ == "__main__":
159
+ import argparse
160
+
161
+ parser = argparse.ArgumentParser(description='Generate Beamer poster directly from paper')
162
+ parser.add_argument('--paper_path', required=True, help='Path to paper PDF')
163
+ parser.add_argument('--output_dir', default='beamer_output', help='Output directory')
164
+ parser.add_argument('--poster_width_cm', type=float, default=120, help='Poster width in cm')
165
+ parser.add_argument('--poster_height_cm', type=float, default=90, help='Poster height in cm')
166
+ parser.add_argument('--beamer_theme', default='default', help='Beamer theme')
167
+ parser.add_argument('--model_name', default='4o', help='Model name')
168
+
169
+ args = parser.parse_args()
170
+
171
+ # Extract text from paper (you'll need to implement this)
172
+ # For now, using placeholder text
173
+ paper_text = "This is placeholder text. In practice, you would extract text from the PDF."
174
+
175
+ # Generate Beamer poster
176
+ tex_path, success = gen_beamer_poster_direct(
177
+ paper_text=paper_text,
178
+ poster_width_cm=args.poster_width_cm,
179
+ poster_height_cm=args.poster_height_cm,
180
+ beamer_theme=args.beamer_theme,
181
+ output_dir=args.output_dir,
182
+ model_name=args.model_name
183
+ )
184
+
185
+ if success:
186
+ print(f"Beamer poster generated at: {tex_path}")
187
+ else:
188
+ print("Failed to generate Beamer poster")
189
Paper2Poster/PosterAgent/__init__.py ADDED
@@ -0,0 +1,16 @@
+ from . import (
+     apply_theme,
+     create_dataset,
+     deoverflow,
+     deoverflow_parallel,
+     fill_and_style,
+     gen_outline_layout_parallel,
+     gen_outline_layout,
+     gen_poster_content,
+     gen_pptx_code,
+     LLM_direct_generate,
+     new_pipeline,
+     parse_raw,
+     poster_gen_pipeline,
+     tree_split_layout
+ )
Paper2Poster/PosterAgent/apply_theme.py ADDED
@@ -0,0 +1,281 @@
1
+ from dotenv import load_dotenv
2
+ from utils.src.utils import ppt_to_images, get_json_from_response
3
+ import json
4
+ import shutil
5
+
6
+ from camel.models import ModelFactory
7
+ from camel.agents import ChatAgent
8
+
9
+ from utils.wei_utils import *
10
+
11
+ from camel.messages import BaseMessage
12
+ from PIL import Image
13
+ import pickle as pkl
14
+ from utils.pptx_utils import *
15
+ from utils.critic_utils import *
16
+ import yaml
17
+ from jinja2 import Environment, StrictUndefined
18
+ from pdf2image import convert_from_path
19
+ import argparse
20
+
21
+ load_dotenv()
22
+
23
+ def poster_apply_theme(args, actor_config, critic_config):
24
+ total_input_token, total_output_token = 0, 0
25
+ extract_input_token, extract_output_token = 0, 0
26
+ gen_input_token, gen_output_token = 0, 0
27
+ non_overlap_ckpt = pkl.load(open(f'checkpoints/{args.model_name}_{args.poster_name}_non_overlap_ckpt_{args.index}.pkl', 'rb'))
28
+ non_overlap_code = non_overlap_ckpt['final_code_by_section']
29
+ sections = list(non_overlap_code.keys())
30
+ sections = [s for s in sections if s != 'meta']
31
+ template_img = convert_from_path(args.template_path)[0]
32
+ image_bytes = io.BytesIO()
33
+ template_img.save(image_bytes, format="PNG")
34
+ image_bytes.seek(0)
35
+
36
+ # Reload the image from memory as a standard PIL.Image.Image
37
+ template_img = Image.open(image_bytes)
38
+
39
+
40
+ title_actor_agent_name = 'theme_agent_title'
41
+ with open(f"prompt_templates/{title_actor_agent_name}.yaml", "r") as f:
42
+ title_theme_actor_config = yaml.safe_load(f)
43
+
44
+ section_actor_agent_name = 'theme_agent_section'
45
+ with open(f"prompt_templates/{section_actor_agent_name}.yaml", "r") as f:
46
+ section_theme_actor_config = yaml.safe_load(f)
47
+
48
+ title_actor_model = ModelFactory.create(
49
+ model_platform=actor_config['model_platform'],
50
+ model_type=actor_config['model_type'],
51
+ model_config_dict=actor_config['model_config'], # [Optional] the config for model
52
+ )
53
+
54
+ title_actor_sys_msg = title_theme_actor_config['system_prompt']
55
+
56
+ title_actor_agent = ChatAgent(
57
+ system_message=title_actor_sys_msg,
58
+ model=title_actor_model,
59
+ message_window_size=10, # [Optional] the length for chat memory
60
+ )
61
+
62
+ section_actor_model = ModelFactory.create(
63
+ model_platform=actor_config['model_platform'],
64
+ model_type=actor_config['model_type'],
65
+ model_config_dict=actor_config['model_config'], # [Optional] the config for model
66
+ )
67
+
68
+ section_actor_sys_msg = section_theme_actor_config['system_prompt']
69
+
70
+ section_actor_agent = ChatAgent(
71
+ system_message=section_actor_sys_msg,
72
+ model=section_actor_model,
73
+ message_window_size=10, # [Optional] the length for chat memory
74
+ )
75
+
76
+ critic_model = ModelFactory.create(
77
+ model_platform=critic_config['model_platform'],
78
+ model_type=critic_config['model_type'],
79
+ model_config_dict=critic_config['model_config'],
80
+ )
81
+
82
+ critic_sys_msg = 'You are a helpful assistant.'
83
+
84
+ critic_agent = ChatAgent(
85
+ system_message=critic_sys_msg,
86
+ model=critic_model,
87
+ message_window_size=None,
88
+ )
89
+
90
+ theme_aspects = {
91
+ 'background': ['background'],
92
+ 'title': ['title_author', 'title_author_border'],
93
+ 'section': ['section_body', 'section_title', 'section_border']
94
+ }
95
+
96
+ theme_styles = {}
97
+ for aspect in theme_aspects.keys():
98
+ theme_styles[aspect] = {}
99
+
100
+ for aspect, prompt_types in theme_aspects.items():
101
+ for prompt_type in prompt_types:
102
+ print(f'Getting style for {prompt_type}')
103
+ with open(f"prompt_templates/theme_templates/theme_{prompt_type}.txt", "r") as f:
104
+ prompt = f.read()
105
+ msg = BaseMessage.make_user_message(
106
+ role_name="User",
107
+ content=prompt,
108
+ image_list=[template_img],
109
+ )
110
+
111
+ critic_agent.reset()
112
+ response = critic_agent.step(msg)
113
+ input_token, output_token = account_token(response)
114
+ total_input_token += input_token
115
+ total_output_token += output_token
116
+ extract_input_token += input_token
117
+ extract_output_token += output_token
118
+ theme_style = get_json_from_response(response.msgs[0].content)
119
+ theme_styles[aspect][prompt_type] = theme_style
120
+
121
+ if 'fontStyle' in theme_styles['section']['section_body']:
122
+ del theme_styles['section']['section_body']['fontStyle']
123
+
124
+ outline_path = f'outlines/{args.model_name}_{args.poster_name}_outline_{args.index}.json'
125
+ outline = json.load(open(outline_path, 'r'))
126
+ outline_skeleton = {}
127
+ for key, val in outline.items():
128
+ if key == 'meta':
129
+ continue
130
+ if not 'subsections' in val:
131
+ outline_skeleton[key] = {
132
+ 'section': key
133
+ }
134
+ else:
135
+ for subsection_name, subsection_dict in val['subsections'].items():
136
+ outline_skeleton[subsection_dict['name']] = {
137
+ 'section': key
138
+ }
139
+
140
+ for key in outline_skeleton.keys():
141
+ if 'title' in key.lower() or 'author' in key.lower():
142
+ outline_skeleton[key]['style'] = theme_styles['section']['section_title']
143
+ else:
144
+ outline_skeleton[key]['style'] = theme_styles['section']['section_body']
145
+
146
+ outline_skeleton_list = []
147
+ for section in sections[1:]:
148
+ # append all subsections whose section key is the current section
149
+ for key, val in outline_skeleton.items():
150
+ if val['section'] == section:
151
+ outline_skeleton_list.append({key: val})
152
+
153
+ theme_logs = {}
154
+ theme_code = {}
155
+ concatenated_code = {}
156
+
157
+ # Title
158
+ jinja_env = Environment(undefined=StrictUndefined)
159
+
160
+ title_actor_template = jinja_env.from_string(title_theme_actor_config["template"])
161
+
162
+ # Title section
163
+ print(f'Processing section {sections[0]}')
164
+ curr_title_code = non_overlap_code[sections[0]]
165
+ for style in ['background', 'title']:
166
+ for sub_style in theme_styles[style].keys():
167
+ print(f' Applying theme for {sub_style}')
168
+ jinja_args = {
169
+ 'style_json': {sub_style: theme_styles[style][sub_style]},
170
+ 'function_docs': documentation,
171
+ 'existing_code': curr_title_code
172
+ }
173
+ actor_prompt = title_actor_template.render(**jinja_args)
174
+ log = apply_theme(title_actor_agent, actor_prompt, args.max_retry, existing_code='')
175
+ if log[-1]['error'] is not None:
176
+ raise Exception(log[-1]['error'])
177
+
178
+ input_token, output_token = log[-1]['cumulative_tokens']
179
+ total_input_token += input_token
180
+ total_output_token += output_token
181
+ gen_input_token += input_token
182
+ gen_output_token += output_token
183
+
184
+ shutil.copy('poster.pptx', f'tmp/theme_poster_<{sections[0]}>_<{style}>_<{sub_style}>.pptx')
185
+
186
+ if not style in theme_logs:
187
+ theme_logs[style] = {}
188
+
189
+ theme_logs[style][sub_style] = log
190
+ curr_title_code = log[-1]['code']
191
+
192
+ theme_code[sections[0]] = curr_title_code
193
+ concatenated_code[sections[0]] = log[-1]['concatenated_code']
194
+
195
+ # Remaining sections
196
+
197
+ jinja_env = Environment(undefined=StrictUndefined)
198
+
199
+ section_actor_template = jinja_env.from_string(section_theme_actor_config["template"])
200
+
201
+ prev_section = None
202
+ for style_dict in outline_skeleton_list:
203
+ curr_subsection = list(style_dict.keys())[0]
204
+ curr_section = style_dict[curr_subsection]['section']
205
+ section_index = sections.index(curr_section)
206
+ print(f'Processing section {curr_section}')
207
+ if prev_section != curr_section:
208
+ prev_section = curr_section
209
+ curr_section_code = non_overlap_code[curr_section]
210
+ print(f' Applying theme for {curr_subsection}')
211
+ jinja_args = {
212
+ 'style_json': json.dumps({curr_subsection: style_dict[curr_subsection]['style']}, indent=4),
213
+ 'function_docs': documentation,
214
+ 'existing_code': curr_section_code
215
+ }
216
+ actor_prompt = section_actor_template.render(**jinja_args)
217
+ existing_code = concatenated_code[sections[section_index - 1]]
218
+ log = apply_theme(section_actor_agent, actor_prompt, args.max_retry, existing_code=existing_code)
219
+ if log[-1]['error'] is not None:
220
+ raise Exception(log[-1]['error'])
221
+
222
+ input_token, output_token = log[-1]['cumulative_tokens']
223
+ total_input_token += input_token
224
+ total_output_token += output_token
225
+ gen_input_token += input_token
226
+ gen_output_token += output_token
227
+
228
+ shutil.copy('poster.pptx', f'tmp/theme_poster_<{curr_section}>_<{curr_subsection}>.pptx')
229
+
230
+ if not style in theme_logs:
231
+ theme_logs[style] = {}
232
+
233
+ theme_logs[style][sub_style] = log
234
+ curr_section_code = log[-1]['code']
235
+
236
+ theme_code[curr_section] = curr_section_code
237
+ concatenated_code[curr_section] = log[-1]['concatenated_code']
238
+
239
+ ppt_to_images(f'poster.pptx', 'tmp/theme_preview')
240
+
241
+ result_dir = f'results/{args.poster_name}/{args.model_name}/{args.index}'
242
+ shutil.copy('poster.pptx', f'{result_dir}/theme_poster.pptx')
243
+ ppt_to_images(f'poster.pptx', f'{result_dir}/theme_poster_preview')
244
+
245
+
246
+ ckpt = {
247
+ 'theme_styles': theme_styles,
248
+ 'theme_logs': theme_logs,
249
+ 'theme_code': theme_code,
250
+ 'concatenated_code': concatenated_code,
251
+ 'total_input_token': total_input_token,
252
+ 'total_output_token': total_output_token,
253
+ 'extract_input_token': extract_input_token,
254
+ 'extract_output_token': extract_output_token,
255
+ 'gen_input_token': gen_input_token,
256
+ 'gen_output_token': gen_output_token
257
+ }
258
+
259
+ pkl.dump(ckpt, open(f'checkpoints/{args.model_name}_{args.poster_name}_theme_ckpt.pkl', 'wb'))
260
+
261
+ return total_input_token, total_output_token
262
+
263
+ if __name__ == '__main__':
264
+ parser = argparse.ArgumentParser()
265
+ parser.add_argument('--poster_name', type=str, default=None)
266
+ parser.add_argument('--model_name', type=str, default='4o')
267
+ parser.add_argument('--poster_path', type=str, required=True)
268
+ parser.add_argument('--index', type=int, default=0)
269
+ parser.add_argument('--template_path', type=str)
270
+ parser.add_argument('--max_retry', type=int, default=3)
271
+ args = parser.parse_args()
272
+
273
+ actor_config = get_agent_config(args.model_name)
274
+ critic_config = get_agent_config(args.model_name)
275
+
276
+ if args.poster_name is None:
277
+ args.poster_name = args.poster_path.split('/')[-1].replace('.pdf', '').replace(' ', '_')
278
+
279
+ input_token, output_token = poster_apply_theme(args, actor_config, critic_config)
280
+
281
+ print(f'Token consumption: {input_token} -> {output_token}')
Paper2Poster/PosterAgent/beamer_pipeline.py ADDED
@@ -0,0 +1,182 @@
1
+ import os
2
+ import json
3
+ import argparse
4
+ from typing import Dict, Any, List
5
+
6
+ # Import existing modules
7
+ from PosterAgent.gen_beamer_code import (
8
+ generate_beamer_poster_code,
9
+ convert_pptx_layout_to_beamer,
10
+ save_beamer_code,
11
+ compile_beamer_to_pdf
12
+ )
13
+ from PosterAgent.gen_pptx_code import generate_poster_code
14
+ from utils.wei_utils import run_code
15
+ from utils.theme_utils import get_default_theme, create_theme_with_alignment
16
+
17
+ def generate_beamer_poster(
18
+ panel_arrangement_inches: List[Dict[str, Any]],
19
+ text_arrangement_inches: List[Dict[str, Any]],
20
+ figure_arrangement_inches: List[Dict[str, Any]],
21
+ bullet_content: List[Dict[str, Any]],
22
+ poster_info: Dict[str, str],
23
+ args,
24
+ width_cm: float = 120,
25
+ height_cm: float = 90,
26
+ theme: str = "default"
27
+ ):
28
+ """
29
+ Generate Beamer poster instead of PowerPoint.
30
+
31
+ Args:
32
+ panel_arrangement_inches: Panel layout data
33
+ text_arrangement_inches: Text layout data
34
+ figure_arrangement_inches: Figure layout data
35
+ bullet_content: Content for text boxes
36
+ poster_info: Poster metadata (title, author, institute)
37
+ args: Command line arguments
38
+ width_cm: Poster width in centimeters
39
+ height_cm: Poster height in centimeters
40
+ theme: Beamer theme name
41
+ """
42
+ print("\n🎯 Generating Beamer poster code...", flush=True)
43
+
44
+ # Convert layout data to Beamer format
45
+ beamer_data = convert_pptx_layout_to_beamer({
46
+ 'text_arrangement': text_arrangement_inches,
47
+ 'figure_arrangement': figure_arrangement_inches
48
+ })
49
+
50
+ # Update poster info
51
+ beamer_data['poster_info'].update(poster_info)
52
+
53
+ # Generate Beamer code
54
+ beamer_code = generate_beamer_poster_code(
55
+ sections=beamer_data['sections'],
56
+ figures=beamer_data['figures'],
57
+ poster_info=beamer_data['poster_info'],
58
+ width_cm=width_cm,
59
+ height_cm=height_cm,
60
+ theme=theme
61
+ )
62
+
63
+ # Save Beamer code
64
+ tex_path = f'{args.tmp_dir}/poster.tex'
65
+ save_beamer_code(beamer_code, tex_path)
66
+
67
+ # Compile to PDF
68
+ print("\n📄 Compiling Beamer to PDF...", flush=True)
69
+ success = compile_beamer_to_pdf(tex_path, args.tmp_dir)
70
+
71
+ if not success:
72
+ raise RuntimeError('Error in compiling Beamer to PDF')
73
+
74
+ print(f"✅ Beamer poster generated successfully: {tex_path}")
75
+ return tex_path
76
+
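# --- Illustrative usage sketch (editorial, not part of the committed module) ---
# The layout lists, `demo_args`, and the content below are hypothetical stand-ins for
# what the earlier pipeline stages would produce; the call itself is left commented out
# because it compiles LaTeX and therefore needs pdflatex plus an existing args.tmp_dir.
#
# from types import SimpleNamespace
# demo_args = SimpleNamespace(tmp_dir='tmp')
# tex_path = generate_beamer_poster(
#     panel_arrangement_inches=[],
#     text_arrangement_inches=[{'textbox_name': 'intro', 'title': 'Introduction',
#                               'content': 'One-sentence summary of the paper.'}],
#     figure_arrangement_inches=[],
#     bullet_content=[],
#     poster_info={'title': 'Demo Poster', 'author': 'A. Author', 'institute': 'Some Lab'},
#     args=demo_args,
#     width_cm=120, height_cm=90, theme='default',
# )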
77
+ def modify_new_pipeline_for_beamer(args):
78
+ """
79
+ Modified version of new_pipeline.py to support Beamer output.
80
+ This function replaces the PowerPoint generation part with Beamer generation.
81
+ """
82
+ # Import the original pipeline components
83
+ from PosterAgent.new_pipeline import (
84
+ parse_paper_content,
85
+ gen_outline_layout_parallel,
86
+ gen_poster_content,
87
+ deoverflow_parallel,
88
+ apply_theme
89
+ )
90
+
91
+ # ... (keep all the existing pipeline steps until poster generation; those steps are assumed to produce panel_arrangement_inches, text_arrangement_inches, figure_arrangement_inches and bullet_content used below)
92
+
93
+ # At the poster generation step, replace PowerPoint with Beamer:
94
+
95
+ # === Beamer Poster Generation ===
96
+ print("\n🎯 Generating Beamer poster...", flush=True)
97
+
98
+ # Extract poster information from content
99
+ poster_info = {
100
+ 'title': 'Research Poster Title', # Extract from paper content
101
+ 'author': 'Author Name', # Extract from paper content
102
+ 'institute': 'Institute Name' # Extract from paper content
103
+ }
104
+
105
+ # Convert inches to centimeters (1 inch = 2.54 cm)
106
+ width_cm = args.poster_width_inches * 2.54
107
+ height_cm = args.poster_height_inches * 2.54
108
+
109
+ # Generate Beamer poster
110
+ tex_path = generate_beamer_poster(
111
+ panel_arrangement_inches=panel_arrangement_inches,
112
+ text_arrangement_inches=text_arrangement_inches,
113
+ figure_arrangement_inches=figure_arrangement_inches,
114
+ bullet_content=bullet_content,
115
+ poster_info=poster_info,
116
+ args=args,
117
+ width_cm=width_cm,
118
+ height_cm=height_cm,
119
+ theme=getattr(args, 'beamer_theme', 'default')
120
+ )
121
+
122
+ # Copy output to final directory
123
+ output_dir = f'<{args.model_name_t}_{args.model_name_v}>_generated_posters/{args.poster_path.replace("paper.pdf", "")}'
124
+ os.makedirs(output_dir, exist_ok=True)
125
+
126
+ # Copy generated files
127
+ import shutil
128
+ shutil.copy(tex_path, f'{output_dir}/poster.tex')
129
+ shutil.copy(f'{args.tmp_dir}/poster.pdf', f'{output_dir}/poster.pdf')
130
+
131
+ print(f"✅ Beamer poster saved to: {output_dir}")
132
+ return output_dir
133
+
134
+ def add_beamer_arguments(parser):
135
+ """Add Beamer-specific command line arguments."""
136
+ parser.add_argument(
137
+ '--output_format',
138
+ choices=['pptx', 'beamer'],
139
+ default='pptx',
140
+ help='Output format: pptx (PowerPoint) or beamer (LaTeX)'
141
+ )
142
+ parser.add_argument(
143
+ '--beamer_theme',
144
+ default='default',
145
+ help='Beamer theme name (default, Madrid, Warsaw, etc.)'
146
+ )
147
+ parser.add_argument(
148
+ '--beamer_width_cm',
149
+ type=float,
150
+ default=120,
151
+ help='Beamer poster width in centimeters'
152
+ )
153
+ parser.add_argument(
154
+ '--beamer_height_cm',
155
+ type=float,
156
+ default=90,
157
+ help='Beamer poster height in centimeters'
158
+ )
159
+ return parser
160
+
161
+ # Example integration with existing pipeline
162
+ def integrate_beamer_with_existing_pipeline():
163
+ """
164
+ Example of how to integrate Beamer generation with the existing pipeline.
165
+ """
166
+ # This would be added to the main pipeline function
167
+ pass
168
+
169
+ if __name__ == "__main__":
170
+ parser = argparse.ArgumentParser(description='Generate Beamer poster from paper')
171
+ parser = add_beamer_arguments(parser)
172
+
173
+ # Add other existing arguments...
174
+
175
+ args = parser.parse_args()
176
+
177
+ if args.output_format == 'beamer':
178
+ modify_new_pipeline_for_beamer(args)
179
+ else:
180
+ # Use original PowerPoint pipeline
181
+ pass
182
+
Paper2Poster/PosterAgent/create_dataset.py ADDED
@@ -0,0 +1,69 @@
1
+ from datasets import load_dataset
2
+ import os
3
+ import subprocess
4
+
5
+ from PIL import Image
6
+ import json
7
+
8
+ def generate_meta_json(base_dir='Paper2Poster-data'):
9
+ # Loop over each item in the specified base directory
10
+ for folder_name in os.listdir(base_dir):
11
+ subfolder_path = os.path.join(base_dir, folder_name)
12
+
13
+ # Ensure the item is a directory
14
+ if os.path.isdir(subfolder_path):
15
+ poster_path = os.path.join(subfolder_path, 'poster.png')
16
+
17
+ # Check if the poster.png exists in the subfolder
18
+ if os.path.exists(poster_path):
19
+ try:
20
+ # Open the image and get size (width, height)
21
+ with Image.open(poster_path) as img:
22
+ width, height = img.size
23
+
24
+ # Prepare metadata dictionary
25
+ metadata = {
26
+ 'width': width,
27
+ 'height': height
28
+ }
29
+
30
+ # Write metadata to meta.json in the same subfolder
31
+ meta_json_path = os.path.join(subfolder_path, 'meta.json')
32
+ with open(meta_json_path, 'w') as json_file:
33
+ json.dump(metadata, json_file)
34
+
35
+ print(f"Metadata for '{folder_name}' saved successfully.")
36
+ except Exception as e:
37
+ print(f"Error processing image in folder '{folder_name}': {e}")
38
+ else:
39
+ print(f"No poster.png found in folder '{folder_name}'.")
40
+
41
+ if __name__ == "__main__":
42
+ dataset = load_dataset("Paper2Poster/Paper2Poster", split="train")
43
+ os.makedirs('Paper2Poster-data', exist_ok=True)
44
+ for data in dataset:
45
+ paper_title = data['title']
46
+ paper_url = data['paper_url']
47
+ poster_url = data['image_url']
48
+ qa = data['qa']
49
+
50
+ os.makedirs(f'Paper2Poster-data/{paper_title}', exist_ok=True)
51
+
52
+ paper_output_path = os.path.join('Paper2Poster-data', paper_title, 'paper.pdf')
53
+ poster_output_path = os.path.join('Paper2Poster-data', paper_title, 'poster.png')
54
+ qa_path = os.path.join('Paper2Poster-data', paper_title, 'o3_qa.json')
55
+
56
+ qa_dict = json.loads(qa)
57
+ with open(qa_path, 'w') as f:
58
+ json.dump(qa_dict, f, indent=4)
59
+ print(f"Saved QA for {paper_title} into {qa_path}")
60
+
61
+ try:
62
+ subprocess.run(['wget', paper_url, '-O', paper_output_path], check=True)
63
+ subprocess.run(['wget', poster_url, '-O', poster_output_path], check=True)
64
+ print(f"Downloaded {poster_url} into {poster_output_path}")
65
+ print(f"Downloaded {paper_url} into {paper_output_path}")
66
+ except subprocess.CalledProcessError as e:
67
+ print(f"Error downloading {paper_url} or {poster_url}: {e}")
68
+
69
+ generate_meta_json('Paper2Poster-data')
Paper2Poster/PosterAgent/deoverflow.py ADDED
@@ -0,0 +1,234 @@
1
+ from dotenv import load_dotenv
2
+ from utils.src.utils import ppt_to_images, get_json_from_response
3
+ import json
4
+
5
+ from camel.models import ModelFactory
6
+ from camel.agents import ChatAgent
7
+
8
+ from utils.wei_utils import *
9
+
10
+ from camel.messages import BaseMessage
11
+ from PIL import Image
12
+ import pickle as pkl
13
+ from utils.pptx_utils import *
14
+ from utils.critic_utils import *
15
+ import yaml
16
+ import argparse
17
+ import shutil
18
+ from jinja2 import Environment, StrictUndefined
19
+
20
+ load_dotenv()
21
+
22
+ MAX_ATTEMPTS = 5
23
+
24
+ def deoverflow(args, actor_config, critic_config):
25
+ total_input_token, total_output_token = 0, 0
26
+ style_ckpt = pkl.load(open(f'checkpoints/{args.model_name}_{args.poster_name}_style_ckpt_{args.index}.pkl', 'rb'))
27
+ logs_ckpt = pkl.load(open(f'checkpoints/{args.model_name}_{args.poster_name}_ckpt_{args.index}.pkl', 'rb'))
28
+
29
+ style_logs = style_ckpt['style_logs']
30
+ sections = list(style_logs.keys())
31
+ sections = [s for s in sections if s != 'meta']
32
+
33
+ slide_width = style_ckpt['outline']['meta']['width']
34
+ slide_height = style_ckpt['outline']['meta']['height']
35
+
36
+ content = json.load(open(f'contents/{args.model_name}_{args.poster_name}_poster_content_{args.index}.json', 'r'))
37
+ outline = logs_ckpt['outline']
38
+
39
+ name_to_hierarchy = get_hierarchy(outline, 1)
40
+
41
+ critic_agent_name = 'critic_overlap_agent'
42
+ with open(f"prompt_templates/{critic_agent_name}.yaml", "r") as f:
43
+ deoverflow_critic_config = yaml.safe_load(f)
44
+
45
+ actor_agent_name = 'actor_editor_agent'
46
+
47
+ with open(f"prompt_templates/{actor_agent_name}.yaml", "r") as f:
48
+ deoverflow_actor_config = yaml.safe_load(f)
49
+
50
+ actor_model = ModelFactory.create(
51
+ model_platform=actor_config['model_platform'],
52
+ model_type=actor_config['model_type'],
53
+ model_config_dict=actor_config['model_config'],
54
+ )
55
+
56
+ actor_sys_msg = deoverflow_actor_config['system_prompt']
57
+
58
+ actor_agent = ChatAgent(
59
+ system_message=actor_sys_msg,
60
+ model=actor_model,
61
+ message_window_size=10,
62
+ )
63
+
64
+ critic_model = ModelFactory.create(
65
+ model_platform=critic_config['model_platform'],
66
+ model_type=critic_config['model_type'],
67
+ model_config_dict=critic_config['model_config'],
68
+ )
69
+
70
+ critic_sys_msg = deoverflow_critic_config['system_prompt']
71
+
72
+ critic_agent = ChatAgent(
73
+ system_message=critic_sys_msg,
74
+ model=critic_model,
75
+ message_window_size=None,
76
+ )
77
+
78
+ jinja_env = Environment(undefined=StrictUndefined)
79
+
80
+ actor_template = jinja_env.from_string(deoverflow_actor_config["template"])
81
+ critic_template = jinja_env.from_string(deoverflow_critic_config["template"])
82
+
83
+ critic_logs = {}
84
+ actor_logs = {}
85
+ img_logs = {}
86
+
87
+ # Load neg and pos examples
88
+ neg_img = Image.open('overflow_example/neg.jpg')
89
+ pos_img = Image.open('overflow_example/pos.jpg')
90
+
91
+ for section_index in range(len(sections)):
92
+ section_name = sections[section_index]
93
+ section_code = style_logs[section_name][-1]['code']
94
+
95
+ if 'subsections' in content[section_name]:
96
+ subsections = list(content[section_name]['subsections'].keys())
97
+ else:
98
+ subsections = [section_name]
99
+
100
+ log = []
101
+
102
+ for leaf_section in subsections:
103
+ if leaf_section in outline:
104
+ leaf_name = outline[leaf_section]['name']
105
+ else:
106
+ leaf_name = outline[section_name]['subsections'][leaf_section]['name']
107
+ num_rounds = 0
108
+ while True:
109
+ print(f"Section: {section_name}, Leaf Section: {leaf_section}, Round: {num_rounds}")
110
+ num_rounds += 1
111
+ if num_rounds > MAX_ATTEMPTS:
112
+ break
113
+ poster = create_poster(slide_width, slide_height)
114
+ add_blank_slide(poster)
115
+ save_presentation(poster, file_name='poster.pptx')
116
+ curr_location, zoomed_in_img, zoomed_in_img_path = get_snapshot_from_section(
117
+ leaf_section,
118
+ section_name,
119
+ name_to_hierarchy,
120
+ leaf_name,
121
+ section_code
122
+ )
123
+
124
+ if not leaf_section in img_logs:
125
+ img_logs[leaf_section] = []
126
+ img_logs[leaf_section].append(zoomed_in_img)
127
+
128
+ jinja_args = {
129
+ 'content_json': content[leaf_section] if leaf_section in content else content[section_name]['subsections'][leaf_section],
130
+ 'existing_code': section_code,
131
+ }
132
+
133
+ critic_prompt = critic_template.render(**jinja_args)
134
+
135
+ critic_msg = BaseMessage.make_user_message(
136
+ role_name="User",
137
+ content=critic_prompt,
138
+ image_list=[neg_img, pos_img, zoomed_in_img],
139
+ )
140
+
141
+ critic_agent.reset()
142
+ response = critic_agent.step(critic_msg)
143
+ resp = response.msgs[0].content
144
+ input_token, output_token = account_token(response)
145
+ total_input_token += input_token
146
+ total_output_token += output_token
147
+ if not leaf_section in critic_logs:
148
+ critic_logs[leaf_section] = []
149
+
150
+ critic_logs[leaf_section].append(response)
151
+
152
+ if type(resp) == str:
153
+ if resp in ['NO', 'NO.', '"NO"', "'NO'"]:
154
+ break
155
+
156
+ feedback = get_json_from_response(resp)
157
+ print(feedback)
158
+ jinja_args = {
159
+ 'content_json': content[leaf_section] if leaf_section in content else content[section_name]['subsections'][leaf_section],
160
+ 'function_docs': documentation,
161
+ 'existing_code': section_code,
162
+ 'suggestion_json': feedback,
163
+ }
164
+
165
+ actor_prompt = actor_template.render(**jinja_args)
166
+
167
+ log = edit_code(actor_agent, actor_prompt, 3, existing_code='')
168
+ if log[-1]['error'] is not None:
169
+ raise Exception(log[-1]['error'])
170
+
171
+ input_token = log[-1]['cumulative_tokens'][0]
172
+ output_token = log[-1]['cumulative_tokens'][1]
173
+ total_input_token += input_token
174
+ total_output_token += output_token
175
+
176
+ section_code = log[-1]['code']
177
+
178
+ if not leaf_section in actor_logs:
179
+ actor_logs[leaf_section] = []
180
+
181
+ actor_logs[leaf_section].append(log)
182
+ if len(log) > 0:
183
+ style_logs[section_name].append(log[-1])
184
+
185
+ final_code = ''
186
+ for section in sections:
187
+ final_code += style_logs[section][-1]['code'] + '\n'
188
+
189
+ run_code_with_utils(final_code, utils_functions)
190
+ ppt_to_images(f'poster.pptx', 'tmp/non_overlap_preview')
191
+
192
+ result_dir = f'results/{args.poster_name}/{args.model_name}/{args.index}'
193
+ if not os.path.exists(result_dir):
194
+ os.makedirs(result_dir)
195
+ shutil.copy('poster.pptx', f'{result_dir}/non_overlap_poster.pptx')
196
+ ppt_to_images(f'poster.pptx', f'{result_dir}/non_overlap_poster_preview')
197
+
198
+ final_code_by_section = {}
199
+ for section in sections:
200
+ final_code_by_section[section] = style_logs[section][-1]['code']
201
+
202
+ non_overlap_ckpt = {
203
+ 'critic_logs': critic_logs,
204
+ 'actor_logs': actor_logs,
205
+ 'img_logs': img_logs,
206
+ 'name_to_hierarchy': name_to_hierarchy,
207
+ 'final_code': final_code,
208
+ 'final_code_by_section': final_code_by_section,
209
+ 'total_input_token': total_input_token,
210
+ 'total_output_token': total_output_token
211
+ }
212
+
213
+ pkl.dump(non_overlap_ckpt, open(f'checkpoints/{args.model_name}_{args.poster_name}_non_overlap_ckpt_{args.index}.pkl', 'wb'))
214
+
215
+ return total_input_token, total_output_token
216
+
217
+ if __name__ == '__main__':
218
+ parser = argparse.ArgumentParser()
219
+ parser.add_argument('--poster_name', type=str, default=None)
220
+ parser.add_argument('--model_name', type=str, default='4o')
221
+ parser.add_argument('--poster_path', type=str, required=True)
222
+ parser.add_argument('--index', type=int, default=0)
223
+ parser.add_argument('--max_retry', type=int, default=3)
224
+ args = parser.parse_args()
225
+
226
+ actor_config = get_agent_config(args.model_name)
227
+ critic_config = get_agent_config(args.model_name)
228
+
229
+ if args.poster_name is None:
230
+ args.poster_name = args.poster_path.split('/')[-1].replace('.pdf', '').replace(' ', '_')
231
+
232
+ input_token, output_token = deoverflow(args, actor_config, critic_config)
233
+
234
+ print(f'Token consumption: {input_token} -> {output_token}')
Paper2Poster/PosterAgent/deoverflow_parallel.py ADDED
@@ -0,0 +1,485 @@
1
+ from dotenv import load_dotenv
2
+ from utils.src.utils import ppt_to_images, get_json_from_response
3
+ import json
4
+
5
+ from camel.models import ModelFactory
6
+ from camel.agents import ChatAgent
7
+
8
+ from utils.wei_utils import *
9
+
10
+ from camel.messages import BaseMessage
11
+ from PIL import Image
12
+ import pickle as pkl
13
+ from utils.pptx_utils import *
14
+ from utils.critic_utils import *
15
+ import yaml
16
+ import argparse
17
+ import shutil
18
+ from jinja2 import Environment, StrictUndefined
19
+ from concurrent.futures import ThreadPoolExecutor
20
+ import copy
21
+
22
+ load_dotenv()
23
+
24
+ MAX_ATTEMPTS = 5
25
+
26
+ def process_leaf_section(
27
+ leaf_section,
28
+ section_name,
29
+ outline,
30
+ content,
31
+ style_logs,
32
+ critic_logs,
33
+ actor_logs,
34
+ img_logs,
35
+ slide_width,
36
+ slide_height,
37
+ name_to_hierarchy,
38
+ critic_template,
39
+ actor_template,
40
+ critic_agent,
41
+ actor_agent,
42
+ neg_img,
43
+ pos_img,
44
+ MAX_ATTEMPTS,
45
+ documentation,
46
+ total_input_token,
47
+ total_output_token,
48
+ ):
49
+ """
50
+ Handles the logic for a single leaf_section within a section_name.
51
+
52
+ Returns a dictionary of updated logs and tokens.
53
+ """
54
+ section_code = style_logs[section_name][-1]['code'] # current code for this section
55
+ log = []
56
+ leaf_name = None
57
+ if leaf_section in outline:
58
+ leaf_name = outline[leaf_section]['name']
59
+ else:
60
+ leaf_name = outline[section_name]['subsections'][leaf_section]['name']
61
+
62
+ num_rounds = 0
63
+ while True:
64
+ print(f"Section: {section_name}, Leaf Section: {leaf_section}, Round: {num_rounds}")
65
+ num_rounds += 1
66
+ if num_rounds > MAX_ATTEMPTS:
67
+ break
68
+
69
+ poster = create_poster(slide_width, slide_height)
70
+ add_blank_slide(poster)
71
+ empty_poster_path = f'tmp/empty_poster_{section_name}_{leaf_section}.pptx'
72
+ save_presentation(poster, file_name=empty_poster_path)
73
+
74
+ curr_location, zoomed_in_img, zoomed_in_img_path = get_snapshot_from_section(
75
+ leaf_section,
76
+ section_name,
77
+ name_to_hierarchy,
78
+ leaf_name,
79
+ section_code,
80
+ empty_poster_path
81
+ )
82
+
83
+ if leaf_section not in img_logs:
84
+ img_logs[leaf_section] = []
85
+ img_logs[leaf_section].append(zoomed_in_img)
86
+
87
+ jinja_args = {
88
+ 'content_json': content[leaf_section] if leaf_section in content
89
+ else content[section_name]['subsections'][leaf_section],
90
+ 'existing_code': section_code,
91
+ }
92
+
93
+ critic_prompt = critic_template.render(**jinja_args)
94
+
95
+ critic_msg = BaseMessage.make_user_message(
96
+ role_name="User",
97
+ content=critic_prompt,
98
+ image_list=[neg_img, pos_img, zoomed_in_img],
99
+ )
100
+
101
+ critic_agent.reset()
102
+ response = critic_agent.step(critic_msg)
103
+ resp = response.msgs[0].content
104
+
105
+ # Track tokens
106
+ input_token, output_token = account_token(response)
107
+ total_input_token += input_token
108
+ total_output_token += output_token
109
+
110
+ if leaf_section not in critic_logs:
111
+ critic_logs[leaf_section] = []
112
+ critic_logs[leaf_section].append(response)
113
+
114
+ # Stop condition
115
+ if isinstance(resp, str):
116
+ if resp in ['NO', 'NO.', '"NO"', "'NO'"]:
117
+ break
118
+
119
+ feedback = get_json_from_response(resp)
120
+ print(feedback)
121
+
122
+ jinja_args = {
123
+ 'content_json': content[leaf_section] if leaf_section in content
124
+ else content[section_name]['subsections'][leaf_section],
125
+ 'function_docs': documentation,
126
+ 'existing_code': section_code,
127
+ 'suggestion_json': feedback,
128
+ }
129
+
130
+ actor_prompt = actor_template.render(**jinja_args)
131
+
132
+ leaf_log = edit_code(actor_agent, actor_prompt, 3, existing_code='')
133
+ if leaf_log[-1]['error'] is not None:
134
+ raise Exception(leaf_log[-1]['error'])
135
+
136
+ # Track tokens
137
+ in_tok = leaf_log[-1]['cumulative_tokens'][0]
138
+ out_tok = leaf_log[-1]['cumulative_tokens'][1]
139
+ total_input_token += in_tok
140
+ total_output_token += out_tok
141
+
142
+ section_code = leaf_log[-1]['code']
143
+
144
+ if leaf_section not in actor_logs:
145
+ actor_logs[leaf_section] = []
146
+ actor_logs[leaf_section].append(leaf_log)
147
+
148
+ log.extend(leaf_log)
149
+
150
+ return {
151
+ "section_code": section_code,
152
+ "log": log,
153
+ "img_logs": img_logs,
154
+ "critic_logs": critic_logs,
155
+ "actor_logs": actor_logs,
156
+ "total_input_token": total_input_token,
157
+ "total_output_token": total_output_token,
158
+ }
159
+
160
+
161
+ def process_section(
162
+ section_name,
163
+ content,
164
+ outline,
165
+ sections,
166
+ style_logs,
167
+ critic_logs,
168
+ actor_logs,
169
+ img_logs,
170
+ slide_width,
171
+ slide_height,
172
+ name_to_hierarchy,
173
+ critic_template,
174
+ actor_template,
175
+ critic_agent,
176
+ actor_agent,
177
+ neg_img,
178
+ pos_img,
179
+ MAX_ATTEMPTS,
180
+ documentation,
181
+ total_input_token,
182
+ total_output_token,
183
+ ):
184
+ """
185
+ Handles processing of a single section and its subsections (leaf sections).
186
+ Returns updated logs and token counters for this section.
187
+ """
188
+ results_per_leaf = []
189
+
190
+ # Grab the current code for this section
191
+ section_code = style_logs[section_name][-1]['code']
192
+
193
+ # Determine which leaf sections to process
194
+ if 'subsections' in content[section_name]:
195
+ subsections = list(content[section_name]['subsections'].keys())
196
+ else:
197
+ subsections = [section_name]
198
+
199
+ all_logs_for_section = []
200
+
201
+ for leaf_section in subsections:
202
+ # Process this leaf section
203
+ leaf_result = process_leaf_section(
204
+ leaf_section,
205
+ section_name,
206
+ outline,
207
+ content,
208
+ style_logs,
209
+ critic_logs,
210
+ actor_logs,
211
+ img_logs,
212
+ slide_width,
213
+ slide_height,
214
+ name_to_hierarchy,
215
+ critic_template,
216
+ actor_template,
217
+ critic_agent,
218
+ actor_agent,
219
+ neg_img,
220
+ pos_img,
221
+ MAX_ATTEMPTS,
222
+ documentation,
223
+ total_input_token,
224
+ total_output_token,
225
+ )
226
+
227
+ # Update logs/tokens
228
+ section_code = leaf_result["section_code"]
229
+ all_logs_for_section.extend(leaf_result["log"])
230
+ img_logs = leaf_result["img_logs"]
231
+ critic_logs = leaf_result["critic_logs"]
232
+ actor_logs = leaf_result["actor_logs"]
233
+ total_input_token = leaf_result["total_input_token"]
234
+ total_output_token = leaf_result["total_output_token"]
235
+
236
+ # If we have any logs from the last leaf in this section, append them
237
+ if all_logs_for_section:
238
+ style_logs[section_name].append(all_logs_for_section[-1])
239
+
240
+ # Return updated state for merging back in the main thread
241
+ return {
242
+ "section_name": section_name,
243
+ "style_logs": style_logs,
244
+ "critic_logs": critic_logs,
245
+ "actor_logs": actor_logs,
246
+ "img_logs": img_logs,
247
+ "total_input_token": total_input_token,
248
+ "total_output_token": total_output_token
249
+ }
250
+
251
+ def parallel_by_sections(
252
+ sections,
253
+ content,
254
+ outline,
255
+ style_logs,
256
+ critic_logs,
257
+ actor_logs,
258
+ img_logs,
259
+ slide_width,
260
+ slide_height,
261
+ name_to_hierarchy,
262
+ critic_template,
263
+ actor_template,
264
+ critic_agent,
265
+ actor_agent,
266
+ neg_img,
267
+ pos_img,
268
+ MAX_ATTEMPTS,
269
+ documentation,
270
+ total_input_token,
271
+ total_output_token,
272
+ max_workers=4
273
+ ):
274
+ """
275
+ Main entry point to parallelize processing across sections.
276
+
277
+ Returns the merged logs and token counters after processing all sections in parallel.
278
+ """
279
+ # Because we’ll be modifying dictionaries (like style_logs, etc.),
280
+ # it can be safer to create a copy for the workers, then merge results
281
+ # after. (Below is a simple approach—depending on your scale, consider
282
+ # explicit concurrency controls or a database-backed approach.)
283
+
284
+ # Summaries from each future
285
+ results = []
286
+
287
+ # We’ll store fresh copies for each section to avoid concurrency collisions
288
+ # on dictionary updates. If the data is large, you might want a more
289
+ # sophisticated synchronization or partition approach rather than naive copies.
290
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
291
+ futures = []
292
+
293
+ for section_name in sections:
294
+ # Make shallow copies or deep copies of logs
295
+ _style_logs = copy.deepcopy(style_logs)
296
+ _critic_logs = copy.deepcopy(critic_logs)
297
+ _actor_logs = copy.deepcopy(actor_logs)
298
+ _img_logs = copy.deepcopy(img_logs)
299
+
300
+ futures.append(executor.submit(
301
+ process_section,
302
+ section_name,
303
+ content,
304
+ outline,
305
+ sections,
306
+ _style_logs,
307
+ _critic_logs,
308
+ _actor_logs,
309
+ _img_logs,
310
+ slide_width,
311
+ slide_height,
312
+ name_to_hierarchy,
313
+ critic_template,
314
+ actor_template,
315
+ critic_agent,
316
+ actor_agent,
317
+ neg_img,
318
+ pos_img,
319
+ MAX_ATTEMPTS,
320
+ documentation,
321
+ total_input_token,
322
+ total_output_token
323
+ ))
324
+
325
+ for future in futures:
326
+ results.append(future.result())
327
+
328
+ # The code below merges the results. The method of merging depends on how
329
+ # you prefer to aggregate. For a minimal approach, we’ll pick the logs from
330
+ # each section, then overwrite or update them in the main dictionaries.
331
+
332
+ for res in results:
333
+ sec_name = res["section_name"]
334
+ # Overwrite or merge logs as needed
335
+ style_logs[sec_name] = res["style_logs"][sec_name]
336
+ critic_logs.update(res["critic_logs"])
337
+ actor_logs.update(res["actor_logs"])
338
+ img_logs.update(res["img_logs"])
339
+ total_input_token = res["total_input_token"]
340
+ total_output_token = res["total_output_token"]
341
+
342
+ return style_logs, critic_logs, actor_logs, img_logs, total_input_token, total_output_token
343
+
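# --- Minimal sketch of the copy-then-merge pattern used above (editorial; the real call
# passes the agents, templates and checkpoint logs assembled in deoverflow() below) ---
from concurrent.futures import ThreadPoolExecutor as _TPE
import copy as _copy

def _demo_worker(section, logs):
    # Each worker mutates only its own deep copy of the shared logs.
    logs.setdefault(section, []).append(f"processed {section}")
    return {"section_name": section, "logs": logs}

_shared = {}
with _TPE(max_workers=2) as _ex:
    _futs = [_ex.submit(_demo_worker, s, _copy.deepcopy(_shared)) for s in ["Intro", "Method"]]
    for _f in _futs:
        _r = _f.result()
        # Merge each worker's result back into the shared dictionary.
        _shared[_r["section_name"]] = _r["logs"][_r["section_name"]]
# _shared == {"Intro": ["processed Intro"], "Method": ["processed Method"]}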
344
+
345
+ def deoverflow(args, actor_config, critic_config):
346
+ total_input_token, total_output_token = 0, 0
347
+ style_ckpt = pkl.load(open(f'checkpoints/{args.model_name}_{args.poster_name}_style_ckpt_{args.index}.pkl', 'rb'))
348
+ logs_ckpt = pkl.load(open(f'checkpoints/{args.model_name}_{args.poster_name}_ckpt_{args.index}.pkl', 'rb'))
349
+
350
+ style_logs = style_ckpt['style_logs']
351
+ sections = list(style_logs.keys())
352
+ sections = [s for s in sections if s != 'meta']
353
+
354
+ slide_width = style_ckpt['outline']['meta']['width']
355
+ slide_height = style_ckpt['outline']['meta']['height']
356
+
357
+ content = json.load(open(f'contents/{args.model_name}_{args.poster_name}_poster_content_{args.index}.json', 'r'))
358
+ outline = logs_ckpt['outline']
359
+
360
+ name_to_hierarchy = get_hierarchy(outline, 1)
361
+
362
+ critic_agent_name = 'critic_overlap_agent'
363
+ with open(f"prompt_templates/{critic_agent_name}.yaml", "r") as f:
364
+ deoverflow_critic_config = yaml.safe_load(f)
365
+
366
+ actor_agent_name = 'actor_editor_agent'
367
+
368
+ with open(f"prompt_templates/{actor_agent_name}.yaml", "r") as f:
369
+ deoverflow_actor_config = yaml.safe_load(f)
370
+
371
+ actor_model = ModelFactory.create(
372
+ model_platform=actor_config['model_platform'],
373
+ model_type=actor_config['model_type'],
374
+ model_config_dict=actor_config['model_config'],
375
+ )
376
+
377
+ actor_sys_msg = deoverflow_actor_config['system_prompt']
378
+
379
+ actor_agent = ChatAgent(
380
+ system_message=actor_sys_msg,
381
+ model=actor_model,
382
+ message_window_size=10,
383
+ )
384
+
385
+ critic_model = ModelFactory.create(
386
+ model_platform=critic_config['model_platform'],
387
+ model_type=critic_config['model_type'],
388
+ model_config_dict=critic_config['model_config'],
389
+ )
390
+
391
+ critic_sys_msg = deoverflow_critic_config['system_prompt']
392
+
393
+ critic_agent = ChatAgent(
394
+ system_message=critic_sys_msg,
395
+ model=critic_model,
396
+ message_window_size=None,
397
+ )
398
+
399
+ jinja_env = Environment(undefined=StrictUndefined)
400
+
401
+ actor_template = jinja_env.from_string(deoverflow_actor_config["template"])
402
+ critic_template = jinja_env.from_string(deoverflow_critic_config["template"])
403
+
404
+ critic_logs = {}
405
+ actor_logs = {}
406
+ img_logs = {}
407
+
408
+ # Load neg and pos examples
409
+ neg_img = Image.open('overflow_example/neg.jpg')
410
+ pos_img = Image.open('overflow_example/pos.jpg')
411
+
412
+ style_logs, critic_logs, actor_logs, img_logs, total_input_token, total_output_token = parallel_by_sections(
413
+ sections=sections,
414
+ content=content,
415
+ outline=outline,
416
+ style_logs=style_logs,
417
+ critic_logs=critic_logs,
418
+ actor_logs=actor_logs,
419
+ img_logs=img_logs,
420
+ slide_width=slide_width,
421
+ slide_height=slide_height,
422
+ name_to_hierarchy=name_to_hierarchy,
423
+ critic_template=critic_template,
424
+ actor_template=actor_template,
425
+ critic_agent=critic_agent,
426
+ actor_agent=actor_agent,
427
+ neg_img=neg_img,
428
+ pos_img=pos_img,
429
+ MAX_ATTEMPTS=MAX_ATTEMPTS,
430
+ documentation=documentation,
431
+ total_input_token=total_input_token,
432
+ total_output_token=total_output_token,
433
+ max_workers=100, # or however many worker threads you want
434
+ )
435
+
436
+ final_code = ''
437
+ for section in sections:
438
+ final_code += style_logs[section][-1]['code'] + '\n'
439
+
440
+ run_code_with_utils(final_code, utils_functions)
441
+ ppt_to_images(f'poster.pptx', 'tmp/non_overlap_preview')
442
+
443
+ result_dir = f'results/{args.poster_name}/{args.model_name}/{args.index}'
444
+ if not os.path.exists(result_dir):
445
+ os.makedirs(result_dir)
446
+ shutil.copy('poster.pptx', f'{result_dir}/non_overlap_poster.pptx')
447
+ ppt_to_images(f'poster.pptx', f'{result_dir}/non_overlap_poster_preview')
448
+
449
+ final_code_by_section = {}
450
+ for section in sections:
451
+ final_code_by_section[section] = style_logs[section][-1]['code']
452
+
453
+ non_overlap_ckpt = {
454
+ 'critic_logs': critic_logs,
455
+ 'actor_logs': actor_logs,
456
+ 'img_logs': img_logs,
457
+ 'name_to_hierarchy': name_to_hierarchy,
458
+ 'final_code': final_code,
459
+ 'final_code_by_section': final_code_by_section,
460
+ 'total_input_token': total_input_token,
461
+ 'total_output_token': total_output_token
462
+ }
463
+
464
+ pkl.dump(non_overlap_ckpt, open(f'checkpoints/{args.model_name}_{args.poster_name}_non_overlap_ckpt_{args.index}.pkl', 'wb'))
465
+
466
+ return total_input_token, total_output_token
467
+
468
+ if __name__ == '__main__':
469
+ parser = argparse.ArgumentParser()
470
+ parser.add_argument('--poster_name', type=str, default=None)
471
+ parser.add_argument('--model_name', type=str, default='4o')
472
+ parser.add_argument('--poster_path', type=str, required=True)
473
+ parser.add_argument('--index', type=int, default=0)
474
+ parser.add_argument('--max_retry', type=int, default=3)
475
+ args = parser.parse_args()
476
+
477
+ actor_config = get_agent_config(args.model_name)
478
+ critic_config = get_agent_config(args.model_name)
479
+
480
+ if args.poster_name is None:
481
+ args.poster_name = args.poster_path.split('/')[-1].replace('.pdf', '').replace(' ', '_')
482
+
483
+ input_token, output_token = deoverflow(args, actor_config, critic_config)
484
+
485
+ print(f'Token consumption: {input_token} -> {output_token}')
Paper2Poster/PosterAgent/fill_and_style.py ADDED
@@ -0,0 +1,215 @@
1
+ from dotenv import load_dotenv
2
+ import os
3
+ from utils.src.utils import ppt_to_images, get_json_from_response
4
+ import json
5
+ import pptx
6
+
7
+ from camel.models import ModelFactory
8
+ from camel.types import ModelPlatformType, ModelType
9
+ from camel.configs import ChatGPTConfig, QwenConfig
10
+ from camel.agents import ChatAgent
11
+
12
+ from utils.wei_utils import fill_content
13
+
14
+ from camel.messages import BaseMessage
15
+ from PIL import Image
16
+ import pickle as pkl
17
+ from utils.pptx_utils import *
18
+ from utils.critic_utils import *
19
+ from utils.wei_utils import *
20
+ import importlib
21
+ import yaml
22
+ import os
23
+ import shutil
24
+ from datetime import datetime
25
+ from jinja2 import Environment, StrictUndefined, Template
26
+ import argparse
27
+
28
+ load_dotenv()
29
+
30
+ def fill_poster_content(args, actor_config):
31
+ total_input_token, total_output_token = 0, 0
32
+ poster_content = json.load(open(f'contents/{args.model_name}_{args.poster_name}_poster_content_{args.index}.json', 'r'))
33
+ agent_name = 'content_filler_agent'
34
+
35
+ with open(f"prompt_templates/{agent_name}.yaml", "r") as f:
36
+ fill_config = yaml.safe_load(f)
37
+
38
+ actor_model = ModelFactory.create(
39
+ model_platform=actor_config['model_platform'],
40
+ model_type=actor_config['model_type'],
41
+ model_config_dict=actor_config['model_config'],
42
+ )
43
+
44
+ actor_sys_msg = fill_config['system_prompt']
45
+
46
+ actor_agent = ChatAgent(
47
+ system_message=actor_sys_msg,
48
+ model=actor_model,
49
+ message_window_size=10,
50
+ )
51
+
52
+ ckpt = pkl.load(open(f'checkpoints/{args.model_name}_{args.poster_name}_ckpt_{args.index}.pkl', 'rb'))
53
+ logs = ckpt['logs']
54
+ outline = ckpt['outline']
55
+
56
+ sections = list(outline.keys())
57
+ sections = [s for s in sections if s != 'meta']
58
+
59
+ jinja_env = Environment(undefined=StrictUndefined)
60
+
61
+ template = jinja_env.from_string(fill_config["template"])
62
+ content_logs = {}
63
+
64
+ for section_index in range(len(sections)):
65
+ section_name = sections[section_index]
66
+ section_code = logs[section_name][-1]['code']
67
+
68
+ print(f'Filling content for {section_name}')
69
+
70
+ jinja_args = {
71
+ 'content_json': poster_content[section_name],
72
+ 'function_docs': documentation,
73
+ 'existing_code': section_code
74
+ }
75
+
76
+ prompt = template.render(**jinja_args)
77
+ if section_index == 0:
78
+ existing_code = ''
79
+ else:
80
+ existing_code = content_logs[sections[section_index - 1]][-1]['concatenated_code']
81
+ content_logs[section_name] = fill_content(
82
+ actor_agent,
83
+ prompt,
84
+ 3,
85
+ existing_code
86
+ )
87
+
88
+ shutil.copy('poster.pptx', f'tmp/content_poster_<{section_name}>.pptx')
89
+
90
+ if content_logs[section_name][-1]['error'] is not None:
91
+ raise Exception(f'Error in filling content for {section_name}: {content_logs[section_name][-1]["error"]}')
92
+
93
+ total_input_token += content_logs[section_name][-1]['cumulative_tokens'][0]
94
+ total_output_token += content_logs[section_name][-1]['cumulative_tokens'][1]
95
+
96
+ ppt_to_images(f'tmp/content_poster_<{sections[-1]}>.pptx', 'tmp/content_preview')
97
+
98
+ ckpt = {
99
+ 'logs': logs,
100
+ 'content_logs': content_logs,
101
+ 'outline': outline,
102
+ 'total_input_token': total_input_token,
103
+ 'total_output_token': total_output_token
104
+ }
105
+
106
+ pkl.dump(ckpt, open(f'checkpoints/{args.model_name}_{args.poster_name}_content_ckpt_{args.index}.pkl', 'wb'))
107
+
108
+ return total_input_token, total_output_token
109
+
110
+ def stylize_poster(args, actor_config):
111
+ total_input_token, total_output_token = 0, 0
112
+ poster_content = json.load(open(f'contents/{args.model_name}_{args.poster_name}_poster_content_{args.index}.json', 'r'))
113
+ agent_name = 'style_agent'
114
+
115
+ with open(f"prompt_templates/{agent_name}.yaml", "r") as f:
116
+ style_config = yaml.safe_load(f)
117
+
118
+ actor_model = ModelFactory.create(
119
+ model_platform=actor_config['model_platform'],
120
+ model_type=actor_config['model_type'],
121
+ model_config_dict=actor_config['model_config'],
122
+ )
123
+
124
+ actor_sys_msg = style_config['system_prompt']
125
+
126
+ actor_agent = ChatAgent(
127
+ system_message=actor_sys_msg,
128
+ model=actor_model,
129
+ message_window_size=10,
130
+ )
131
+
132
+ ckpt = pkl.load(open(f'checkpoints/{args.model_name}_{args.poster_name}_content_ckpt_{args.index}.pkl', 'rb'))
133
+ content_logs = ckpt['content_logs']
134
+ outline = ckpt['outline']
135
+
136
+ sections = list(outline.keys())
137
+ sections = [s for s in sections if s != 'meta']
138
+
139
+ jinja_env = Environment(undefined=StrictUndefined)
140
+
141
+ template = jinja_env.from_string(style_config["template"])
142
+ style_logs = {}
143
+
144
+ for section_index in range(len(sections)):
145
+ section_name = sections[section_index]
146
+ section_outline = json.dumps(outline[section_name])
147
+ section_code = content_logs[section_name][-1]['code']
148
+
149
+ print(f'Stylizing for {section_name}')
150
+
151
+ img_ratio_json = get_img_ratio_in_section(poster_content[section_name])
152
+
153
+ jinja_args = {
154
+ 'content_json': poster_content[section_name],
155
+ 'function_docs': documentation,
156
+ 'existing_code': section_code,
157
+ 'image_ratio': img_ratio_json,
158
+ }
159
+
160
+ prompt = template.render(**jinja_args)
161
+ if section_index == 0:
162
+ existing_code = ''
163
+ else:
164
+ existing_code = style_logs[sections[section_index - 1]][-1]['concatenated_code']
165
+ style_logs[section_name] = stylize(
166
+ actor_agent,
167
+ prompt,
168
+ args.max_retry,
169
+ existing_code
170
+ )
171
+
172
+ shutil.copy('poster.pptx', f'tmp/style_poster_<{section_name}>.pptx')
173
+
174
+ if style_logs[section_name][-1]['error'] is not None:
175
+ raise Exception(f'Error in stylizing for {section_name}')
176
+
177
+ total_input_token += style_logs[section_name][-1]['cumulative_tokens'][0]
178
+ total_output_token += style_logs[section_name][-1]['cumulative_tokens'][1]
179
+
180
+ ppt_to_images(f'tmp/style_poster_<{sections[-1]}>.pptx', 'tmp/style_preview')
181
+ ckpt = {
182
+ 'logs': ckpt['logs'],
183
+ 'content_logs': content_logs,
184
+ 'style_logs': style_logs,
185
+ 'outline': outline,
186
+ 'total_input_token': total_input_token,
187
+ 'total_output_token': total_output_token
188
+ }
189
+
190
+ with open(f'checkpoints/{args.model_name}_{args.poster_name}_style_ckpt_{args.index}.pkl', 'wb') as f:
191
+ pkl.dump(ckpt, f)
192
+
193
+ return total_input_token, total_output_token
194
+
195
+ if __name__ == '__main__':
196
+ parser = argparse.ArgumentParser()
197
+ parser.add_argument('--poster_name', type=str, default=None)
198
+ parser.add_argument('--model_name', type=str, default='4o')
199
+ parser.add_argument('--poster_path', type=str, required=True)
200
+ parser.add_argument('--index', type=int, default=0)
201
+ parser.add_argument('--max_retry', type=int, default=3)
202
+ args = parser.parse_args()
203
+
204
+ actor_config = get_agent_config(args.model_name)
205
+
206
+ if args.poster_name is None:
207
+ args.poster_name = args.poster_path.split('/')[-1].replace('.pdf', '').replace(' ', '_')
208
+
209
+ fill_total_input_token, fill_total_output_token = fill_poster_content(args, actor_config)
210
+ style_total_input_token, style_total_output_token = stylize_poster(args, actor_config)
211
+
212
+ total_input_token = fill_total_input_token + style_total_input_token
213
+ total_output_token = fill_total_output_token + style_total_output_token
214
+
215
+ print(f'Token consumption: {total_input_token} -> {total_output_token}')
Paper2Poster/PosterAgent/gen_beamer_code.py ADDED
@@ -0,0 +1,299 @@
1
+ import re
2
+ import json
3
+ import os
4
+ from typing import List, Dict, Any
5
+
6
+ def sanitize_for_latex(name):
7
+ """Convert any character that is not alphanumeric into underscore for LaTeX compatibility."""
8
+ return re.sub(r'[^0-9a-zA-Z_]+', '_', name)
9
+
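# Illustrative check of the helper above (editorial note, not part of the module):
#   sanitize_for_latex("Results & Ablations (2024)")  ->  'Results_Ablations_2024_'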
10
+ def initialize_beamer_document(width_cm=120, height_cm=90, theme="default"):
11
+ """
12
+ Initialize a Beamer document with specified dimensions and theme.
13
+
14
+ Args:
15
+ width_cm: Width in centimeters (default 120cm for poster)
16
+ height_cm: Height in centimeters (default 90cm for poster)
17
+ theme: Beamer theme name (default, Madrid, Warsaw, etc.)
18
+ """
19
+ code = f'''\\documentclass[aspectratio=169]{{beamer}}
20
+ \\usepackage[utf8]{{inputenc}}
21
+ \\usepackage[T1]{{fontenc}}
22
+ \\usepackage{{graphicx}}
23
+ \\usepackage{{tikz}}
24
+ \\usepackage{{xcolor}}
25
+ \\usepackage{{geometry}}
26
+ \\usepackage{{multicol}}
27
+ \\usepackage{{array}}
28
+ \\usepackage{{booktabs}}
29
+ \\usepackage{{adjustbox}}
30
+
31
+ % Set page dimensions for poster
32
+ \\geometry{{paperwidth={width_cm}cm, paperheight={height_cm}cm, margin=1cm}}
33
+
34
+ % Beamer theme
35
+ \\usetheme{{{theme}}}
36
+ \\usecolortheme{{default}}
37
+
38
+ % Custom colors
39
+ \\definecolor{{titlecolor}}{{RGB}}{{47, 85, 151}}
40
+ \\definecolor{{textcolor}}{{RGB}}{{0, 0, 0}}
41
+ \\definecolor{{bgcolor}}{{RGB}}{{255, 255, 255}}
42
+
43
+ % Remove navigation symbols
44
+ \\setbeamertemplate{{navigation symbols}}{{}}
45
+
46
+ % Custom title page
47
+ \\setbeamertemplate{{title page}}{{
48
+ \\begin{{center}}
49
+ \\vspace{{1cm}}
50
+ {{\\color{{titlecolor}}\\Huge\\textbf{{\\inserttitle}}}}
51
+ \\vspace{{0.5cm}}
52
+ \\Large{{\\insertauthor}}
53
+ \\vspace{{0.3cm}}
54
+ \\normalsize{{\\insertinstitute}}
55
+ \\end{{center}}
56
+ }}
57
+
58
+ % Custom frame title
59
+ \\setbeamertemplate{{frametitle}}{{
60
+ \\vspace{{0.5cm}}
61
+ \\begin{{flushleft}}
62
+ {{\\color{{titlecolor}}\\Large\\textbf{{\\insertframetitle}}}}
63
+ \\end{{flushleft}}
64
+ \\vspace{{0.3cm}}
65
+ }}
66
+
67
+ \\begin{{document}}
68
+
69
+ % Title frame
70
+ \\title{{POSTER_TITLE_PLACEHOLDER}}
71
+ \\author{{POSTER_AUTHOR_PLACEHOLDER}}
72
+ \\institute{{POSTER_INSTITUTE_PLACEHOLDER}}
73
+ \\date{{\\today}}
74
+
75
+ \\begin{{frame}}[plain]
76
+ \\titlepage
77
+ \\end{{frame}}
78
+
79
+ '''
80
+ return code
81
+
82
+ def generate_beamer_section_code(section_data: Dict[str, Any], section_index: int):
83
+ """
84
+ Compatible with the Paper2Poster bullet JSON:
85
+ - section_data contains title_blocks / textbox1_blocks / textbox2_blocks
86
+ - each *_blocks entry is a list[ {bullet: bool, runs: [{text: str, ...}], ...} ]
87
+ """
88
+ def blocks_to_lines(blocks):
89
+ """把 blocks 转成 list[str],并标注是否 bullet"""
90
+ lines = []
91
+ for blk in blocks or []:
92
+ text = " ".join([r.get("text","") for r in blk.get("runs", [])]).strip()
93
+ if not text:
94
+ continue
95
+ lines.append({
96
+ "text": text,
97
+ "bullet": bool(blk.get("bullet", False))
98
+ })
99
+ return lines
100
+
101
+ # Frame title: prefer the text from title_blocks, fall back to title_str, then to "Untitled"
102
+ if isinstance(section_data.get("title_blocks"), list) and section_data["title_blocks"]:
103
+ frame_title = " ".join([r.get("text","") for r in section_data["title_blocks"][0].get("runs", [])]).strip()
104
+ else:
105
+ frame_title = section_data.get("title_str") or "Untitled"
106
+
107
+ frame_title = frame_title.replace("{","\\{").replace("}","\\}") # simple escaping in case the title contains braces
108
+
109
+ code = f"\n% ===== Section {section_index} =====\n"
110
+ code += f"\\begin{{frame}}[t]{{{frame_title}}}\n"
111
+ code += " \\vspace{-0.5cm}\n"
112
+
113
+ for key in ["textbox1_blocks", "textbox2_blocks"]:
114
+ lines = blocks_to_lines(section_data.get(key, []))
115
+ if not lines:
116
+ continue
117
+
118
+ # If every line is a bullet, merge them into one itemize; otherwise handle each line separately
119
+ if all(l["bullet"] for l in lines):
120
+ code += " \\begin{itemize}\n"
121
+ for l in lines:
122
+ code += f" \\item {l['text']}\n"
123
+ code += " \\end{itemize}\n"
124
+ else:
125
+ for l in lines:
126
+ if l["bullet"]:
127
+ code += f" \\begin{{itemize}}\\item {l['text']}\\end{{itemize}}\n"
128
+ else:
129
+ code += f" {l['text']}\\\\\n"
130
+
131
+ code += "\\end{frame}\n\n"
132
+ return code
133
+
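# --- Illustrative example (editorial, not part of the committed module) ---
# A minimal, hypothetical section in the bullet-JSON shape described above:
_demo_section = {
    "title_blocks": [{"runs": [{"text": "Motivation"}]}],
    "textbox1_blocks": [
        {"bullet": True, "runs": [{"text": "Posters are tedious to lay out by hand."}]},
        {"bullet": True, "runs": [{"text": "Beamer gives reproducible, compilable output."}]},
    ],
    "textbox2_blocks": [],
}
_demo_frame = generate_beamer_section_code(_demo_section, section_index=0)
# _demo_frame is a \begin{frame}[t]{Motivation} ... \end{frame} block with one itemize list.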
134
+
135
+
136
+ def generate_beamer_figure_code(figure_data: Dict[str, Any], figure_index: int):
137
+ """
138
+ Generate Beamer code for including figures.
139
+
140
+ Args:
141
+ figure_data: Dictionary containing figure information
142
+ figure_index: Index of the figure
143
+ """
144
+ figure_name = sanitize_for_latex(figure_data.get('figure_name', f'figure_{figure_index}'))
145
+ figure_path = figure_data.get('figure_path', '')
146
+
147
+ # Convert inches to centimeters (1 inch = 2.54 cm)
148
+ width_cm = figure_data.get('width', 10) * 2.54
149
+ height_cm = figure_data.get('height', 8) * 2.54
150
+
151
+ code = f'''
152
+ % Figure: {figure_name}
153
+ \\begin{{frame}}[t]{{{figure_data.get('title', 'Figure')}}}
154
+ \\vspace{{-0.5cm}}
155
+ \\begin{{center}}
156
+ \\includegraphics[width={width_cm:.2f}cm, height={height_cm:.2f}cm]{{{figure_path}}}
157
+ \\end{{center}}
158
+ \\vspace{{0.3cm}}
159
+ \\begin{{center}}
160
+ \\small{{\\textbf{{{figure_data.get('caption', 'Figure Caption')}}}}}
161
+ \\end{{center}}
162
+ \\end{{frame}}
163
+
164
+ '''
165
+ return code
166
+
167
+ def generate_beamer_poster_code(
168
+ sections: List[Dict[str, Any]],
169
+ figures: List[Dict[str, Any]],
170
+ poster_info: Dict[str, str],
171
+ width_cm: float = 120,
172
+ height_cm: float = 90,
173
+ theme: str = "default",
174
+ output_path: str = "poster.tex"
175
+ ):
176
+ """
177
+ Generate complete Beamer poster code.
178
+
179
+ Args:
180
+ sections: List of section dictionaries
181
+ figures: List of figure dictionaries
182
+ poster_info: Dictionary with title, author, institute
183
+ width_cm: Poster width in centimeters
184
+ height_cm: Poster height in centimeters
185
+ theme: Beamer theme name
186
+ output_path: Output .tex file path
187
+ """
188
+ code = initialize_beamer_document(width_cm, height_cm, theme)
189
+
190
+ # Replace placeholders with actual content
191
+ code = code.replace('POSTER_TITLE_PLACEHOLDER', poster_info.get('title', 'Poster Title'))
192
+ code = code.replace('POSTER_AUTHOR_PLACEHOLDER', poster_info.get('author', 'Author Name'))
193
+ code = code.replace('POSTER_INSTITUTE_PLACEHOLDER', poster_info.get('institute', 'Institute Name'))
194
+
195
+ # Add sections
196
+ for i, section in enumerate(sections):
197
+ code += generate_beamer_section_code(section, i)
198
+
199
+ # Add figures
200
+ for i, figure in enumerate(figures):
201
+ code += generate_beamer_figure_code(figure, i)
202
+
203
+ # Close document
204
+ code += '''
205
+ \\end{document}
206
+ '''
207
+
208
+ return code
209
+
210
+ def save_beamer_code(code: str, output_path: str):
211
+ """Save Beamer code to file."""
212
+ with open(output_path, 'w', encoding='utf-8') as f:
213
+ f.write(code)
214
+
215
+ def compile_beamer_to_pdf(tex_path: str, output_dir: str = "."):
216
+ """
217
+ Compile Beamer .tex file to PDF using pdflatex.
218
+
219
+ Args:
220
+ tex_path: Path to .tex file
221
+ output_dir: Output directory for PDF
222
+ """
223
+ import subprocess
224
+
225
+ try:
226
+ # Run pdflatex twice for proper cross-references
227
+ result1 = subprocess.run(
228
+ ['pdflatex', '-output-directory', output_dir, tex_path],
229
+ capture_output=True,
230
+ text=True,
231
+ timeout=60
232
+ )
233
+
234
+ result2 = subprocess.run(
235
+ ['pdflatex', '-output-directory', output_dir, tex_path],
236
+ capture_output=True,
237
+ text=True,
238
+ timeout=60
239
+ )
240
+
241
+ if result1.returncode == 0 and result2.returncode == 0:
242
+ print(f"Successfully compiled {tex_path} to PDF")
243
+ return True
244
+ else:
245
+ print(f"Error compiling {tex_path}:")
246
+ print(result1.stderr)
247
+ print(result2.stderr)
248
+ return False
249
+
250
+ except subprocess.TimeoutExpired:
251
+ print(f"Timeout while compiling {tex_path}")
252
+ return False
253
+ except Exception as e:
254
+ print(f"Error compiling {tex_path}: {e}")
255
+ return False
256
+
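# --- Illustrative end-to-end sketch for the helpers above (editorial, hypothetical data) ---
# Builds a one-section poster, saves it, and only attempts compilation if pdflatex is available.
import shutil as _shutil

_demo_code = generate_beamer_poster_code(
    sections=[{"title_str": "Overview",
               "textbox1_blocks": [{"bullet": True, "runs": [{"text": "Key idea in one line."}]}]}],
    figures=[],
    poster_info={"title": "Demo", "author": "A. Author", "institute": "Some Lab"},
    width_cm=120, height_cm=90, theme="default",
)
save_beamer_code(_demo_code, "demo_poster.tex")
if _shutil.which("pdflatex"):
    compile_beamer_to_pdf("demo_poster.tex", ".")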
257
+ # Example usage functions
258
+ def convert_pptx_layout_to_beamer(pptx_layout_data: Dict[str, Any]) -> Dict[str, Any]:
259
+ """
260
+ Convert PowerPoint layout data to Beamer-compatible format.
261
+
262
+ Args:
263
+ pptx_layout_data: Layout data from PowerPoint generation
264
+ """
265
+ beamer_data = {
266
+ 'sections': [],
267
+ 'figures': [],
268
+ 'poster_info': {
269
+ 'title': 'Default Title',
270
+ 'author': 'Default Author',
271
+ 'institute': 'Default Institute'
272
+ }
273
+ }
274
+
275
+ # Convert text arrangements to sections
276
+ if 'text_arrangement' in pptx_layout_data:
277
+ for i, text_item in enumerate(pptx_layout_data['text_arrangement']):
278
+ section = {
279
+ 'section_name': text_item.get('textbox_name', f'section_{i}'),
280
+ 'title': text_item.get('title', f'Section {i+1}'),
281
+ 'content': text_item.get('content', 'Content placeholder')
282
+ }
283
+ beamer_data['sections'].append(section)
284
+
285
+ # Convert figure arrangements to figures
286
+ if 'figure_arrangement' in pptx_layout_data:
287
+ for i, figure_item in enumerate(pptx_layout_data['figure_arrangement']):
288
+ figure = {
289
+ 'figure_name': figure_item.get('figure_name', f'figure_{i}'),
290
+ 'figure_path': figure_item.get('figure_path', ''),
291
+ 'width': figure_item.get('width', 10),
292
+ 'height': figure_item.get('height', 8),
293
+ 'title': figure_item.get('title', f'Figure {i+1}'),
294
+ 'caption': figure_item.get('caption', 'Figure caption')
295
+ }
296
+ beamer_data['figures'].append(figure)
297
+
298
+ return beamer_data
299
+
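# --- Illustrative conversion example (editorial, hypothetical layout data) ---
_demo_layout = {
    "text_arrangement": [
        {"textbox_name": "intro", "title": "Introduction", "content": "One-line summary."}
    ],
    "figure_arrangement": [
        {"figure_name": "arch", "figure_path": "figures/arch.png",
         "width": 10, "height": 8, "title": "Architecture", "caption": "System overview."}
    ],
}
_demo_beamer = convert_pptx_layout_to_beamer(_demo_layout)
# _demo_beamer["sections"][0]["title"] == "Introduction"
# _demo_beamer["figures"][0]["figure_path"] == "figures/arch.png"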
Paper2Poster/PosterAgent/gen_outline_layout.py ADDED
@@ -0,0 +1,851 @@
 
1
+ from dotenv import load_dotenv
2
+ import os
3
+ import json
4
+ import copy
5
+ import yaml
6
+ from jinja2 import Environment, StrictUndefined
7
+
8
+ from utils.src.utils import ppt_to_images, get_json_from_response
9
+
10
+ from camel.models import ModelFactory
11
+ from camel.agents import ChatAgent
12
+ from camel.messages import BaseMessage
13
+
14
+ from utils.pptx_utils import *
15
+ from utils.wei_utils import *
16
+
17
+ import pickle as pkl
18
+ import argparse
19
+
20
+ load_dotenv()
21
+
22
+ IMAGE_SCALE_RATIO_MIN = 50
23
+ IMAGE_SCALE_RATIO_MAX = 40
24
+ TABLE_SCALE_RATIO_MIN = 100
25
+ TABLE_SCALE_RATIO_MAX = 80
26
+
27
+ def compute_tp(raw_content_json):
28
+ total_length = 0
29
+ for section in raw_content_json['sections']:
30
+ total_length += len(section['content'])
31
+
32
+ for i in range(len(raw_content_json['sections'])):
33
+ raw_content_json['sections'][i]['tp'] = len(raw_content_json['sections'][i]['content']) / total_length
34
+ raw_content_json['sections'][i]['text_len'] = len(raw_content_json['sections'][i]['content'])
35
+
36
+ def compute_gp(table_info, image_info):
37
+ total_area = 0
38
+ for k, v in table_info.items():
39
+ total_area += v['figure_size']
40
+
41
+ for k, v in image_info.items():
42
+ total_area += v['figure_size']
43
+
44
+ for k, v in table_info.items():
45
+ v['gp'] = v['figure_size'] / total_area
46
+
47
+ for k, v in image_info.items():
48
+ v['gp'] = v['figure_size'] / total_area
49
+
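# --- Illustrative numbers for the two helpers above (editorial, hypothetical data) ---
# compute_tp assigns each section a text proportion 'tp' of the total character count;
# compute_gp assigns each table/image a graphics proportion 'gp' of the total figure area.
_demo_content = {"sections": [{"content": "a" * 300}, {"content": "b" * 100}]}
compute_tp(_demo_content)                      # tp -> 0.75 and 0.25
_demo_tables = {"t1": {"figure_size": 40}}
_demo_images = {"i1": {"figure_size": 60}}
compute_gp(_demo_tables, _demo_images)         # gp -> 0.4 and 0.6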
50
+ def get_outline_location(outline, subsection=False):
51
+ outline_location = {}
52
+ for k, v in outline.items():
53
+ if k == 'meta':
54
+ continue
55
+ outline_location[k] = {
56
+ 'location': v['location'],
57
+ }
58
+ if subsection:
59
+ if 'subsections' in v:
60
+ outline_location[k]['subsections'] = get_outline_location(v['subsections'])
61
+ return outline_location
62
+
63
+ def apply_outline_location(outline, location, subsection=False):
64
+ new_outline = {}
65
+ for k, v in outline.items():
66
+ if k == 'meta':
67
+ new_outline[k] = v
68
+ continue
69
+ new_outline[k] = copy.deepcopy(v)
70
+ new_outline[k]['location'] = location[k]['location']
71
+ if subsection:
72
+ if 'subsections' in v:
73
+ new_outline[k]['subsections'] = apply_outline_location(v['subsections'], location[k]['subsections'])
74
+
75
+ return new_outline
76
+
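# --- Round-trip sketch for the two helpers above (editorial, hypothetical outline) ---
_demo_outline = {
    "meta": {"width": 48, "height": 36},
    "Introduction": {"name": "Introduction",
                     "location": {"left": 0, "top": 0, "width": 24, "height": 12}},
}
_demo_locations = get_outline_location(_demo_outline)
_demo_locations["Introduction"]["location"]["width"] = 20   # e.g. a layout adjustment
_demo_outline = apply_outline_location(_demo_outline, _demo_locations)
# _demo_outline["Introduction"]["location"]["width"] == 20; 'meta' is passed through untouched.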
77
+ def fill_location(outline, section_name, location_dict):
78
+ new_outline = copy.deepcopy(outline)
79
+ if 'subsections' not in new_outline[section_name]:
80
+ return new_outline
81
+ for k, v in new_outline[section_name]['subsections'].items():
82
+ v['location'] = location_dict[k]['location']
83
+ return new_outline
84
+
85
+ def recover_name_and_location(outline_no_name, outline):
86
+ new_outline = copy.deepcopy(outline_no_name)
87
+ for k, v in outline_no_name.items():
88
+ if k == 'meta':
89
+ continue
90
+ new_outline[k]['name'] = outline[k]['name']
91
+ if type(new_outline[k]['location']) == list:
92
+ new_outline[k]['location'] = {
93
+ 'left': v['location'][0],
94
+ 'top': v['location'][1],
95
+ 'width': v['location'][2],
96
+ 'height': v['location'][3]
97
+ }
98
+ if 'subsections' in v:
99
+ for k_sub, v_sub in v['subsections'].items():
100
+ new_outline[k]['subsections'][k_sub]['name'] = outline[k]['subsections'][k_sub]['name']
101
+ if type(new_outline[k]['subsections'][k_sub]['location']) == list:
102
+ new_outline[k]['subsections'][k_sub]['location'] = {
103
+ 'left': v_sub['location'][0],
104
+ 'top': v_sub['location'][1],
105
+ 'width': v_sub['location'][2],
106
+ 'height': v_sub['location'][3]
107
+ }
108
+ return new_outline
109
+
110
+
111
+ def validate_and_adjust_subsections(section_bbox, subsection_bboxes):
112
+ """
113
+ Validate that the given subsections collectively occupy the entire section.
114
+ If not, return an adjusted version that fixes the layout.
115
+
116
+ We assume all subsections are intended to be stacked vertically with no gaps,
117
+ spanning the full width of the section.
118
+
119
+ :param section_bbox: dict with keys ["left", "top", "width", "height"]
120
+ :param subsection_bboxes: dict of subsection_name -> bounding_box (each also
121
+ with keys ["left", "top", "width", "height"])
122
+ :return: (is_valid, revised_subsections)
123
+ where is_valid is True/False,
124
+ and revised_subsections is either the same as subsection_bboxes if valid,
125
+ or a new dict of adjusted bounding boxes if invalid.
126
+ """
127
+
128
+ # Helper functions
129
+ def _right(bbox):
130
+ return bbox["left"] + bbox["width"]
131
+
132
+ def _bottom(bbox):
133
+ return bbox["top"] + bbox["height"]
134
+
135
+ section_left = section_bbox["left"]
136
+ section_top = section_bbox["top"]
137
+ section_right = section_left + section_bbox["width"]
138
+ section_bottom = section_top + section_bbox["height"]
139
+
140
+ # Convert dictionary to a list of (subsection_name, bbox) pairs
141
+ items = list(subsection_bboxes.items())
142
+ if not items:
143
+ # No subsections is definitely not valid if we want to fill the section
144
+ return False, None
145
+
146
+ # Sort subsections by their 'top' coordinate
147
+ items_sorted = sorted(items, key=lambda x: x[1]["top"])
148
+
149
+ # ---------------------------
150
+ # Step 1: Validate
151
+ # ---------------------------
152
+ # We'll check:
153
+ # 1. left/right boundaries match the section for each subsection
154
+ # 2. The first subsection's top == section_top
155
+ # 3. The last subsection's bottom == section_bottom
156
+ # 4. Each pair of consecutive subsections lines up exactly
157
+ # (previous bottom == current top) with no gap or overlap.
158
+
159
+ is_valid = True
160
+
161
+ # Check left/right for each
162
+ for name, bbox in items_sorted:
163
+ if bbox["left"] != section_left or _right(bbox) != section_right:
164
+ is_valid = False
165
+ break
166
+
167
+ # Check alignment for the first and last
168
+ if is_valid:
169
+ first_sub_name, first_sub_bbox = items_sorted[0]
170
+ if first_sub_bbox["top"] != section_top:
171
+ is_valid = False
172
+
173
+ if is_valid:
174
+ last_sub_name, last_sub_bbox = items_sorted[-1]
175
+ if _bottom(last_sub_bbox) != section_bottom:
176
+ is_valid = False
177
+
178
+ # Check consecutive alignment
179
+ if is_valid:
180
+ for i in range(len(items_sorted) - 1):
181
+ _, current_bbox = items_sorted[i]
182
+ _, next_bbox = items_sorted[i + 1]
183
+ if _bottom(current_bbox) != next_bbox["top"]:
184
+ is_valid = False
185
+ break
186
+
187
+ # If everything passed, we return
188
+ if is_valid:
189
+ return True, subsection_bboxes
190
+
191
+ # ---------------------------
192
+ # Step 2: Revise
193
+ # ---------------------------
194
+ # We will adjust all subsection bboxes so that they occupy
195
+ # the entire section exactly, preserving each original bbox's
196
+ # height *ratio* if possible.
197
+
198
+ # 2a. Compute total original height (in the order of sorted items)
199
+ original_heights = [bbox["height"] for _, bbox in items_sorted]
200
+ total_original_height = sum(original_heights)
201
+
202
+ # Avoid divide-by-zero if somehow there's a 0 height
203
+ if total_original_height <= 0:
204
+ # Fallback: split the section equally among subsections
205
+ # to avoid zero or negative heights
206
+ chunk_height = section_bbox["height"] / len(items_sorted)
207
+ scale_heights = [chunk_height] * len(items_sorted)
208
+ else:
209
+ # Scale each original height by the ratio of
210
+ # (section total height / sum of original heights)
211
+ scale = section_bbox["height"] / total_original_height
212
+ scale_heights = [h * scale for h in original_heights]
213
+
214
+ # 2b. Assign bounding boxes top->bottom, ensuring no gap
215
+ revised = {}
216
+ current_top = section_top
217
+ for i, (name, original_bbox) in enumerate(items_sorted):
218
+ revised_height = scale_heights[i]
219
+ # If there's floating error, we can clamp in the last iteration
220
+ # so that the bottom exactly matches section_bottom.
221
+ # But for simplicity, we'll keep it straightforward unless needed.
222
+
223
+ revised[name] = {
224
+ "left": section_left,
225
+ "top": current_top,
226
+ "width": section_bbox["width"],
227
+ "height": revised_height
228
+ }
229
+ # Update current_top for next subsection
230
+ current_top += revised_height
231
+
232
+ # Due to potential float rounding, we can enforce the last subsection
233
+ # to exactly end at section_bottom:
234
+ last_name = items_sorted[-1][0]
235
+ # Recompute the actual bottom after the above assignment
236
+ new_bottom = revised[last_name]["top"] + revised[last_name]["height"]
237
+ diff = new_bottom - section_bottom
238
+ if abs(diff) > 1e-9:
239
+ # Adjust the last subsection's height
240
+ revised[last_name]["height"] -= diff
241
+
242
+ # Return the revised dictionary
243
+ return False, revised
244
+
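A minimal usage sketch of validate_and_adjust_subsections, with illustrative coordinates: two stacked subsections that stop short of the section bottom are rescaled so they tile the section exactly.

    section = {'left': 0, 'top': 10, 'width': 20, 'height': 30}
    subs = {
        'A': {'left': 0, 'top': 10, 'width': 20, 'height': 10},
        'B': {'left': 0, 'top': 20, 'width': 20, 'height': 15},  # ends 5 units above the section bottom
    }
    ok, fixed = validate_and_adjust_subsections(section, subs)
    # ok is False; the heights are rescaled by 30/25, so 'A' becomes 12 and 'B' becomes 18,
    # and the revised boxes span tops 10..22 and 22..40, filling the section exactly.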
245
+ def filter_image_table(args, filter_config):
246
+ images = json.load(open(f'<{args.model_name_t}_{args.model_name_v}>_images_and_tables/{args.poster_name}_images.json', 'r'))
247
+ tables = json.load(open(f'<{args.model_name_t}_{args.model_name_v}>_images_and_tables/{args.poster_name}_tables.json', 'r'))
248
+ doc_json = json.load(open(f'contents/<{args.model_name_t}_{args.model_name_v}>_{args.poster_name}_raw_content.json', 'r'))
249
+ agent_filter = 'image_table_filter_agent'
250
+ with open(f"utils/prompt_templates/{agent_filter}.yaml", "r", encoding="utf-8") as f:
251
+ config_filter = yaml.safe_load(f)
252
+
253
+ image_information = {}
254
+ for k, v in images.items():
255
+ image_information[k] = copy.deepcopy(v)
256
+ image_information[k]['min_width'] = v['width'] // IMAGE_SCALE_RATIO_MIN
257
+ image_information[k]['min_height'] = v['height'] // IMAGE_SCALE_RATIO_MIN
258
+ image_information[k]['max_width'] = v['width'] // IMAGE_SCALE_RATIO_MAX
259
+ image_information[k]['max_height'] = v['height'] // IMAGE_SCALE_RATIO_MAX
260
+
261
+ table_information = {}
262
+ for k, v in tables.items():
263
+ table_information[k] = copy.deepcopy(v)
264
+ table_information[k]['min_width'] = v['width'] // TABLE_SCALE_RATIO_MIN
265
+ table_information[k]['min_height'] = v['height'] // TABLE_SCALE_RATIO_MIN
266
+ table_information[k]['max_width'] = v['width'] // TABLE_SCALE_RATIO_MAX
267
+ table_information[k]['max_height'] = v['height'] // TABLE_SCALE_RATIO_MAX
268
+
269
+ filter_actor_sys_msg = config_filter['system_prompt']
270
+
271
+ if args.model_name_t.startswith('vllm_qwen'):
272
+ filter_model = ModelFactory.create(
273
+ model_platform=filter_config['model_platform'],
274
+ model_type=filter_config['model_type'],
275
+ model_config_dict=filter_config['model_config'],
276
+ url=filter_config['url'],
277
+ )
278
+ else:
279
+ filter_model = ModelFactory.create(
280
+ model_platform=filter_config['model_platform'],
281
+ model_type=filter_config['model_type'],
282
+ model_config_dict=filter_config['model_config'],
283
+ )
284
+
285
+ filter_actor_agent = ChatAgent(
286
+ system_message=filter_actor_sys_msg,
287
+ model=filter_model,
288
+ message_window_size=10,
289
+ )
290
+
291
+ filter_jinja_args = {
292
+ 'json_content': doc_json,
293
+ 'table_information': json.dumps(table_information, indent=4),
294
+ 'image_information': json.dumps(image_information, indent=4),
295
+ }
296
+ jinja_env = Environment(undefined=StrictUndefined)
297
+ filter_prompt = jinja_env.from_string(config_filter["template"])
298
+ filter_actor_agent.reset()
299
+ response = filter_actor_agent.step(filter_prompt.render(**filter_jinja_args))
300
+ input_token, output_token = account_token(response)
301
+ response_json = get_json_from_response(response.msgs[0].content)
302
+ table_information = response_json['table_information']
303
+ image_information = response_json['image_information']
304
+ json.dump(images, open(f'<{args.model_name_t}_{args.model_name_v}>_images_and_tables/{args.poster_name}_images_filtered.json', 'w'), indent=4)
305
+ json.dump(tables, open(f'<{args.model_name_t}_{args.model_name_v}>_images_and_tables/{args.poster_name}_tables_filtered.json', 'w'), indent=4)
306
+
307
+ return input_token, output_token
308
+
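Note that the '_MIN' scale ratios are the larger divisors, so they yield the smaller rendered size; a quick check with round numbers (width and height are illustrative):

    # With IMAGE_SCALE_RATIO_MIN = 50 and IMAGE_SCALE_RATIO_MAX = 40:
    width, height = 2000, 1000
    min_width, min_height = width // 50, height // 50   # 40, 20  -> smallest allowed rendering
    max_width, max_height = width // 40, height // 40   # 50, 25  -> largest allowed rendering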
309
+ def gen_outline_layout_v2(args, actor_config):
310
+ total_input_token, total_output_token = 0, 0
311
+ agent_name = 'poster_planner_new_v2'
312
+ doc_json = json.load(open(f'contents/<{args.model_name_t}_{args.model_name_v}>_{args.poster_name}_raw_content.json', 'r'))
313
+ filtered_table_information = json.load(open(f'<{args.model_name_t}_{args.model_name_v}>_images_and_tables/{args.poster_name}_tables_filtered.json', 'r'))
314
+ filtered_image_information = json.load(open(f'<{args.model_name_t}_{args.model_name_v}>_images_and_tables/{args.poster_name}_images_filtered.json', 'r'))
315
+
316
+ filtered_table_information_captions = {}
317
+ filtered_image_information_captions = {}
318
+
319
+ for k, v in filtered_table_information.items():
320
+ filtered_table_information_captions[k] = {
321
+ 'caption': v['caption']
322
+ }
323
+
324
+ for k, v in filtered_image_information.items():
325
+ filtered_image_information_captions[k] = {
326
+ 'caption': v['caption']
327
+ }
328
+
329
+ with open(f"utils/prompt_templates/{agent_name}.yaml", "r", encoding="utf-8") as f:
330
+ planner_config = yaml.safe_load(f)
331
+
332
+ compute_tp(doc_json)
333
+
334
+ jinja_env = Environment(undefined=StrictUndefined)
335
+ outline_template = jinja_env.from_string(planner_config["template"])
336
+ planner_jinja_args = {
337
+ 'json_content': doc_json,
338
+ 'table_information': filtered_table_information_captions,
339
+ 'image_information': filtered_image_information_captions,
340
+ }
341
+
342
+ if args.model_name_t.startswith('vllm_qwen'):
343
+ planner_model = ModelFactory.create(
344
+ model_platform=actor_config['model_platform'],
345
+ model_type=actor_config['model_type'],
346
+ model_config_dict=actor_config['model_config'],
347
+ url=actor_config['url'],
348
+ )
349
+ else:
350
+ planner_model = ModelFactory.create(
351
+ model_platform=actor_config['model_platform'],
352
+ model_type=actor_config['model_type'],
353
+ model_config_dict=actor_config['model_config'],
354
+ )
355
+
356
+
357
+ planner_agent = ChatAgent(
358
+ system_message=planner_config['system_prompt'],
359
+ model=planner_model,
360
+ message_window_size=10,
361
+ )
362
+
363
+ print(f'Generating outline...')
364
+ planner_prompt = outline_template.render(**planner_jinja_args)
365
+ planner_agent.reset()
366
+ response = planner_agent.step(planner_prompt)
367
+ input_token, output_token = account_token(response)
368
+ total_input_token += input_token
369
+ total_output_token += output_token
370
+
371
+ figure_arrangement = get_json_from_response(response.msgs[0].content)
372
+
373
+ print(f'Figure arrangement: {json.dumps(figure_arrangement, indent=4)}')
374
+
375
+ arranged_images = {}
376
+ arranged_tables = {}
377
+ assigned_images = set()
378
+ assigned_tables = set()
379
+
380
+ for section_name, figure in figure_arrangement.items():
381
+ if 'image' in figure:
382
+ image_id = str(figure['image'])
383
+ if image_id in assigned_images:
384
+ continue
385
+ if image_id in filtered_image_information:
386
+ arranged_images[image_id] = filtered_image_information[image_id]
387
+ assigned_images.add(image_id)
388
+ if 'table' in figure:
389
+ table_id = str(figure['table'])
390
+ if table_id in assigned_tables:
391
+ continue
392
+ if table_id in filtered_table_information:
393
+ arranged_tables[table_id] = filtered_table_information[table_id]
394
+ assigned_tables.add(table_id)
395
+
396
+ compute_gp(arranged_tables, arranged_images)
397
+
398
+ # Obtain panel input
399
+ paper_panels = []
400
+ for i in range(len(doc_json['sections'])):
401
+ section = doc_json['sections'][i]
402
+ panel = {}
403
+ panel['panel_id'] = i
404
+ panel['section_name'] = section['title']
405
+ panel['tp'] = section['tp']
406
+ panel['text_len'] = section['text_len']
407
+ panel['gp'] = 0
408
+ panel['figure_size'] = 0
409
+ panel['figure_aspect'] = 1
410
+ if section['title'] in figure_arrangement:
411
+ curr_arrangement = figure_arrangement[section['title']]
412
+ if 'table' in curr_arrangement:
413
+ table_id = str(curr_arrangement['table'])
414
+ if table_id in arranged_tables:
415
+ panel['gp'] = arranged_tables[table_id]['gp']
416
+ panel['figure_size'] = arranged_tables[table_id]['figure_size']
417
+ panel['figure_aspect'] = arranged_tables[table_id]['figure_aspect']
418
+ elif 'image' in curr_arrangement:
419
+ image_id = str(curr_arrangement['image'])
420
+ if image_id in arranged_images:
421
+ panel['gp'] = arranged_images[image_id]['gp']
422
+ panel['figure_size'] = arranged_images[image_id]['figure_size']
423
+ panel['figure_aspect'] = arranged_images[image_id]['figure_aspect']
424
+
425
+ paper_panels.append(panel)
426
+
427
+ return total_input_token, total_output_token, paper_panels, figure_arrangement
428
+
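For reference, one entry of the returned paper_panels list would look roughly like the sketch below; 'tp' and 'text_len' are filled in by compute_tp, and the gp/figure fields keep their defaults when no figure is assigned to the section (all values here are illustrative):

    panel = {
        'panel_id': 2,
        'section_name': 'Method',
        'tp': 0.18,            # presumably the section's text proportion, from compute_tp
        'text_len': 1250,
        'gp': 0.4,             # graphical proportion of the assigned figure, 0 if none
        'figure_size': 300000,
        'figure_aspect': 1.6,
    }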
429
+ def gen_outline_layout(args, actor_config, critic_config):
430
+ poster_log_path = f'log/{args.model_name}_{args.poster_name}_poster_{args.index}'
431
+ if not os.path.exists(poster_log_path):
432
+ os.mkdir(poster_log_path)
433
+ total_input_token, total_output_token = 0, 0
434
+ consumption_log = {
435
+ 'outline': [],
436
+ 'h1_actor': [],
437
+ 'h2_actor': [],
438
+ 'h1_critic': [],
439
+ 'gen_layout': []
440
+ }
441
+ jinja_env = Environment(undefined=StrictUndefined)
442
+ outline_file_path = f'outlines/{args.model_name}_{args.poster_name}_outline_{args.index}.json'
443
+ agent_name = 'poster_planner_new'
444
+ agent_init_name = 'layout_agent_init'
445
+ agent_new_section_name = 'layout_agent_new_section'
446
+ h1_critic_name = 'critic_layout_hierarchy_1'
447
+ h2_actor_name = 'actor_layout_hierarchy_2'
448
+
449
+ doc_json = json.load(open(f'contents/{args.model_name}_{args.poster_name}_raw_content.json', 'r'))
450
+ filtered_table_information = json.load(open(f'images_and_tables/{args.poster_name}_tables_filtered.json', 'r'))
451
+ filtered_image_information = json.load(open(f'images_and_tables/{args.poster_name}_images_filtered.json', 'r'))
452
+
453
+ with open(f"utils/prompt_templates/{agent_name}.yaml", "r", encoding="utf-8") as f:
454
+ planner_config = yaml.safe_load(f)
455
+
456
+ with open(f"utils/prompt_templates/{agent_init_name}.yaml", "r", encoding="utf-8") as f:
457
+ config_init = yaml.safe_load(f)
458
+
459
+ with open(f"utils/prompt_templates/{agent_new_section_name}.yaml", "r", encoding="utf-8") as f:
460
+ config_new_section = yaml.safe_load(f)
461
+
462
+ with open(f"utils/prompt_templates/{h1_critic_name}.yaml", "r", encoding="utf-8") as f:
463
+ config_h1_critic = yaml.safe_load(f)
464
+
465
+ with open(f"utils/prompt_templates/{h2_actor_name}.yaml", "r", encoding="utf-8") as f:
466
+ config_h2_actor = yaml.safe_load(f)
467
+
468
+ planner_model = ModelFactory.create(
469
+ model_platform=actor_config['model_platform'],
470
+ model_type=actor_config['model_type'],
471
+ model_config_dict=actor_config['model_config'],
472
+ )
473
+
474
+ planner_agent = ChatAgent(
475
+ system_message=planner_config['system_prompt'],
476
+ model=planner_model,
477
+ message_window_size=10,
478
+ )
479
+
480
+ outline_template = jinja_env.from_string(planner_config["template"])
481
+
482
+ planner_jinja_args = {
483
+ 'json_content': doc_json,
484
+ 'table_information': filtered_table_information,
485
+ 'image_information': filtered_image_information,
486
+ }
487
+
488
+ actor_model = ModelFactory.create(
489
+ model_platform=actor_config['model_platform'],
490
+ model_type=actor_config['model_type'],
491
+ model_config_dict=actor_config['model_config'],
492
+ )
493
+
494
+ init_actor_sys_msg = config_init['system_prompt']
495
+
496
+ init_actor_agent = ChatAgent(
497
+ system_message=init_actor_sys_msg,
498
+ model=actor_model,
499
+ message_window_size=10,
500
+ )
501
+
502
+ new_section_actor_sys_msg = config_new_section['system_prompt']
503
+ new_section_actor_agent = ChatAgent(
504
+ system_message=new_section_actor_sys_msg,
505
+ model=actor_model,
506
+ message_window_size=10,
507
+ )
508
+
509
+ h1_critic_model = ModelFactory.create(
510
+ model_platform=critic_config['model_platform'],
511
+ model_type=critic_config['model_type'],
512
+ model_config_dict=critic_config['model_config'],
513
+ )
514
+
515
+ h1_critic_sys_msg = config_h1_critic['system_prompt']
516
+
517
+ h1_critic_agent = ChatAgent(
518
+ system_message=h1_critic_sys_msg,
519
+ model=h1_critic_model,
520
+ message_window_size=None,
521
+ )
522
+
523
+ h1_pos_example = Image.open('assets/h1_example/h1_pos.jpg')
524
+ h1_neg_example = Image.open('assets/h1_example/h1_neg.jpg')
525
+
526
+ h2_actor_model = ModelFactory.create(
527
+ model_platform=actor_config['model_platform'],
528
+ model_type=actor_config['model_type'],
529
+ model_config_dict=actor_config['model_config'],
530
+ )
531
+
532
+ h2_actor_sys_msg = config_h2_actor['system_prompt']
533
+
534
+ h2_actor_agent = ChatAgent(
535
+ system_message=h2_actor_sys_msg,
536
+ model=h2_actor_model,
537
+ message_window_size=10,
538
+ )
539
+
540
+ attempt = 0
541
+ while True:
542
+ print(f'Generating outline attempt {attempt}...')
543
+ planner_prompt = outline_template.render(**planner_jinja_args)
544
+ planner_agent.reset()
545
+ response = planner_agent.step(planner_prompt)
546
+ input_token, output_token = account_token(response)
547
+ consumption_log['outline'].append((input_token, output_token))
548
+ total_input_token += input_token
549
+ total_output_token += output_token
550
+
551
+ outline = get_json_from_response(response.msgs[0].content)
552
+ name_to_hierarchy = get_hierarchy(outline)
553
+
554
+ sections = list(outline.keys())
555
+ sections = [x for x in sections if x != 'meta']
556
+ init_template = jinja_env.from_string(config_init["template"])
557
+ new_section_template = jinja_env.from_string(config_new_section["template"])
558
+ h1_critic_template = jinja_env.from_string(config_h1_critic["template"])
559
+ init_outline = {'meta': outline['meta'], sections[0]: outline[sections[0]]}
560
+
561
+ new_outline = outline
562
+
563
+ init_jinja_args = {
564
+ 'json_outline': init_outline,
565
+ 'function_docs': documentation
566
+ }
567
+
568
+ init_prompt = init_template.render(**init_jinja_args)
569
+
570
+ # hierarchy 1 only
571
+ outline_location = get_outline_location(outline, subsection=False)
572
+ logs = {}
573
+ curr_section = sections[0]
574
+
575
+ layout_cumulative_input_token = 0
576
+ layout_cumulative_output_token = 0
577
+
578
+ print('Generating h1 layout...\n')
579
+ print(f'Generating h1 layout for section {curr_section}...')
580
+ logs[curr_section] = gen_layout(
581
+ init_actor_agent,
582
+ init_prompt,
583
+ args.max_retry,
584
+ name_to_hierarchy,
585
+ visual_identifier=curr_section
586
+ )
587
+
588
+ if logs[curr_section][-1]['error'] is not None:
589
+ raise ValueError(f'Failed to generate layout for section {curr_section}.')
590
+
591
+ layout_cumulative_input_token += logs[curr_section][-1]['cumulative_tokens'][0]
592
+ layout_cumulative_output_token += logs[curr_section][-1]['cumulative_tokens'][1]
593
+
594
+ for section_index in range(1, len(sections)):
595
+ curr_section = sections[section_index]
596
+ print(f'generating h1 layout for section {curr_section}...')
597
+ new_section_outline = {curr_section: new_outline[curr_section]}
598
+ new_section_jinja_args = {
599
+ 'json_outline': new_section_outline,
600
+ 'function_docs': documentation
601
+ }
602
+ new_section_prompt = new_section_template.render(**new_section_jinja_args)
603
+
604
+ logs[curr_section] = gen_layout(
605
+ new_section_actor_agent,
606
+ new_section_prompt,
607
+ args.max_retry,
608
+ name_to_hierarchy,
609
+ visual_identifier=curr_section,
610
+ existing_code = logs[sections[section_index - 1]][-1]['concatenated_code']
611
+ )
612
+ if logs[curr_section][-1]['error'] is not None:
613
+ raise ValueError(f'Failed to generate layout for section {curr_section}.')
614
+
615
+ layout_cumulative_input_token += logs[curr_section][-1]['cumulative_tokens'][0]
616
+ layout_cumulative_output_token += logs[curr_section][-1]['cumulative_tokens'][1]
617
+
618
+ consumption_log['h1_actor'].append((layout_cumulative_input_token, layout_cumulative_output_token))
619
+ total_input_token += layout_cumulative_input_token
620
+ total_output_token += layout_cumulative_output_token
621
+
622
+ h1_path = f'tmp/poster_<{sections[-1]}>_hierarchy_1.pptx'
623
+ h2_path = f'tmp/poster_<{sections[-1]}>_hierarchy_2.pptx'
624
+
625
+ h1_filled_path = f'tmp/poster_<{sections[-1]}>_hierarchy_1_filled.pptx'
626
+ h2_filled_path = f'tmp/poster_<{sections[-1]}>_hierarchy_2_filled.pptx'
627
+
628
+ ppt_to_images(h1_path, 'tmp/layout_h1')
629
+ ppt_to_images(h2_path, 'tmp/layout_h2')
630
+ ppt_to_images(h1_filled_path, 'tmp/layout_h1_filled')
631
+ ppt_to_images(h2_filled_path, 'tmp/layout_h2_filled')
632
+
633
+ h1_img = Image.open('tmp/layout_h1/slide_0001.jpg')
634
+ h2_img = Image.open('tmp/layout_h2/slide_0001.jpg')
635
+ h1_filled_img = Image.open('tmp/layout_h1_filled/slide_0001.jpg')
636
+ h2_filled_img = Image.open('tmp/layout_h2_filled/slide_0001.jpg')
637
+
638
+ h1_critic_msg = BaseMessage.make_user_message(
639
+ role_name='User',
640
+ content=h1_critic_template.render(),
641
+ image_list=[h1_neg_example, h1_pos_example, h1_filled_img]
642
+ )
643
+
644
+ outline_bbox_dict = {}
645
+ for k, v in outline_location.items():
646
+ outline_bbox_dict[k] = v['location']
647
+
648
+ bbox_check_result = check_bounding_boxes(
649
+ outline_bbox_dict,
650
+ new_outline['meta']['width'],
651
+ new_outline['meta']['height']
652
+ )
653
+
654
+ if len(bbox_check_result) != 0:
655
+ print(bbox_check_result)
656
+ attempt += 1
657
+ continue
658
+
659
+ h1_critic_agent.reset()
660
+ response = h1_critic_agent.step(h1_critic_msg)
661
+ input_token, output_token = account_token(response)
662
+ consumption_log['h1_critic'].append((input_token, output_token))
663
+ total_input_token += input_token
664
+ total_output_token += output_token
665
+ if response.msgs[0].content == 'T':
666
+ print('Blank area detected.')
667
+ attempt += 1
668
+ continue
669
+
670
+ break
671
+
672
+ outline_bbox_dict = {}
673
+ for k, v in outline_location.items():
674
+ outline_bbox_dict[k] = v['location']
675
+
676
+ # Generate subsection locations
677
+ outline_no_sub_locations = copy.deepcopy(new_outline)
678
+ if 'meta' in outline_no_sub_locations:
679
+ outline_no_sub_locations.pop('meta')
680
+
681
+ for k, v in outline_no_sub_locations.items():
682
+ if 'subsections' in v:
683
+ subsections = v['subsections']
684
+ for k_sub, v_sub in subsections.items():
685
+ del v_sub['location']
686
+ del v_sub['name']
687
+
688
+ h2_actor_template = jinja_env.from_string(config_h2_actor["template"])
689
+
690
+ h2_cumulative_input_token = 0
691
+ h2_cumulative_output_token = 0
692
+
693
+ for section in sections:
694
+ while True:
695
+ print(f'generating h2 for section {section}...')
696
+ section_outline = {section: outline_no_sub_locations[section]}
697
+ section_jinja_args = {
698
+ 'section_outline': json.dumps(section_outline, indent=4),
699
+ }
700
+
701
+ section_prompt = h2_actor_template.render(**section_jinja_args)
702
+
703
+ h2_actor_agent.reset()
704
+ response = h2_actor_agent.step(section_prompt)
705
+ input_token, output_token = account_token(response)
706
+ h2_cumulative_input_token += input_token
707
+ h2_cumulative_output_token += output_token
708
+ subsection_location = get_json_from_response(response.msgs[0].content)
709
+
710
+ sec_bbox = outline_no_sub_locations[section]['location']
711
+ subsection_location_dict = {}
712
+ for k, v in subsection_location.items():
713
+ subsection_location_dict[k] = {
714
+ 'left': v['location'][0],
715
+ 'top': v['location'][1],
716
+ 'width': v['location'][2],
717
+ 'height': v['location'][3]
718
+ }
719
+
720
+ is_valid, revised = validate_and_adjust_subsections(sec_bbox, subsection_location_dict)
721
+ if not is_valid:
722
+ is_valid, revised = validate_and_adjust_subsections(sec_bbox, revised)
723
+ assert is_valid, "Failed to adjust subsections to fit section"
724
+ outline_no_sub_locations = fill_location(outline_no_sub_locations, section, revised)
725
+ else:
726
+ outline_no_sub_locations = fill_location(outline_no_sub_locations, section, subsection_location)
727
+ break
728
+
729
+ consumption_log['h2_actor'].append((h2_cumulative_input_token, h2_cumulative_output_token))
730
+ total_input_token += h2_cumulative_input_token
731
+ total_output_token += h2_cumulative_output_token
732
+
733
+ outline_no_sub_locations['meta'] = outline['meta']
734
+ outline_no_sub_locations_with_name = recover_name_and_location(outline_no_sub_locations, new_outline)
735
+ new_outline = outline_no_sub_locations_with_name
736
+
737
+ ### Outline finalized, actually generate layout
738
+
739
+ logs = {}
740
+
741
+ gen_layout_cumulative_input_token = 0
742
+ gen_layout_cumulative_output_token = 0
743
+ curr_section = sections[0]
744
+
745
+ init_outline = {'meta': new_outline['meta'], sections[0]: new_outline[sections[0]]}
746
+
747
+ init_jinja_args = {
748
+ 'json_outline': init_outline,
749
+ 'function_docs': documentation
750
+ }
751
+
752
+ init_prompt = init_template.render(**init_jinja_args)
753
+ logs[curr_section] = gen_layout(
754
+ init_actor_agent,
755
+ init_prompt,
756
+ args.max_retry,
757
+ name_to_hierarchy,
758
+ visual_identifier=curr_section
759
+ )
760
+
761
+ if logs[curr_section][-1]['error'] is not None:
762
+ raise ValueError(f'Failed to generate layout for section {curr_section}.')
763
+
764
+ gen_layout_cumulative_input_token += logs[curr_section][-1]['cumulative_tokens'][0]
765
+ gen_layout_cumulative_output_token += logs[curr_section][-1]['cumulative_tokens'][1]
766
+
767
+ for section_index in range(1, len(sections)):
768
+ curr_section = sections[section_index]
769
+ print(f'generating section {curr_section}...')
770
+ new_section_outline = {curr_section: new_outline[curr_section]}
771
+ new_section_jinja_args = {
772
+ 'json_outline': new_section_outline,
773
+ 'function_docs': documentation
774
+ }
775
+ new_section_prompt = new_section_template.render(**new_section_jinja_args)
776
+
777
+ logs[curr_section] = gen_layout(
778
+ new_section_actor_agent,
779
+ new_section_prompt,
780
+ args.max_retry,
781
+ name_to_hierarchy,
782
+ visual_identifier=curr_section,
783
+ existing_code = logs[sections[section_index - 1]][-1]['concatenated_code']
784
+ )
785
+ if logs[curr_section][-1]['error'] is not None:
786
+ raise ValueError(f'Failed to generate layout for section {curr_section}.')
787
+
788
+ gen_layout_cumulative_input_token += logs[curr_section][-1]['cumulative_tokens'][0]
789
+ gen_layout_cumulative_output_token += logs[curr_section][-1]['cumulative_tokens'][1]
790
+
791
+ consumption_log['gen_layout'].append((gen_layout_cumulative_input_token, gen_layout_cumulative_output_token))
792
+ total_input_token += gen_layout_cumulative_input_token
793
+ total_output_token += gen_layout_cumulative_output_token
794
+
795
+ h1_path = f'tmp/poster_<{sections[-1]}>_hierarchy_1.pptx'
796
+ h2_path = f'tmp/poster_<{sections[-1]}>_hierarchy_2.pptx'
797
+
798
+ h1_filled_path = f'tmp/poster_<{sections[-1]}>_hierarchy_1_filled.pptx'
799
+ h2_filled_path = f'tmp/poster_<{sections[-1]}>_hierarchy_2_filled.pptx'
800
+
801
+ ppt_to_images(h1_path, f'{poster_log_path}/layout_h1')
802
+ ppt_to_images(h2_path, f'{poster_log_path}/layout_h2')
803
+ ppt_to_images(h1_filled_path, f'{poster_log_path}/layout_h1_filled')
804
+ ppt_to_images(h2_filled_path, f'{poster_log_path}/layout_h2_filled')
805
+
806
+ h1_img = Image.open(f'{poster_log_path}/layout_h1/slide_0001.jpg')
807
+ h2_img = Image.open(f'{poster_log_path}/layout_h2/slide_0001.jpg')
808
+ h1_filled_img = Image.open(f'{poster_log_path}/layout_h1_filled/slide_0001.jpg')
809
+ h2_filled_img = Image.open(f'{poster_log_path}/layout_h2_filled/slide_0001.jpg')
810
+
811
+ ckpt = {
812
+ 'logs': logs,
813
+ 'outline': new_outline,
814
+ 'name_to_hierarchy': name_to_hierarchy,
815
+ 'consumption_log': consumption_log,
816
+ 'total_input_token': total_input_token,
817
+ 'total_output_token': total_output_token,
818
+ }
819
+
820
+ with open(f'checkpoints/{args.model_name}_{args.poster_name}_ckpt_{args.index}.pkl', 'wb') as f:
821
+ pkl.dump(ckpt, f)
822
+
823
+ json.dump(
824
+ new_outline,
825
+ open(outline_file_path, "w"),
826
+ ensure_ascii=False,
827
+ indent=4,
828
+ )
829
+
830
+ return total_input_token, total_output_token
831
+
832
+ if __name__ == '__main__':
833
+ parser = argparse.ArgumentParser()
834
+ parser.add_argument('--poster_name', type=str, default=None)
835
+ parser.add_argument('--model_name', type=str, default='4o')
836
+ parser.add_argument('--poster_path', type=str, required=True)
837
+ parser.add_argument('--index', type=int, default=0)
838
+ parser.add_argument('--max_retry', type=int, default=3)
839
+ args = parser.parse_args()
840
+
841
+ actor_config = get_agent_config(args.model_name)
842
+ critic_config = get_agent_config(args.model_name)
843
+
844
+ if args.poster_name is None:
845
+ args.poster_name = args.poster_path.split('/')[-1].replace('.pdf', '').replace(' ', '_')
846
+
847
+ input_token, output_token = filter_image_table(args, actor_config)
848
+ print(f'Token consumption: {input_token} -> {output_token}')
849
+
850
+ input_token, output_token = gen_outline_layout(args, actor_config, critic_config)
851
+ print(f'Token consumption: {input_token} -> {output_token}')
Paper2Poster/PosterAgent/gen_outline_layout_parallel.py ADDED
@@ -0,0 +1,949 @@
1
+ from dotenv import load_dotenv
2
+ import os
3
+ import json
4
+ import copy
5
+ import yaml
6
+ import logging
7
+ import time
8
+ from jinja2 import Environment, StrictUndefined
9
+
10
+ from utils.src.utils import ppt_to_images, get_json_from_response
11
+
12
+ from camel.models import ModelFactory
13
+ from camel.agents import ChatAgent
14
+ from camel.messages import BaseMessage
15
+
16
+ from utils.pptx_utils import *
17
+ from utils.wei_utils import *
18
+
19
+ import pickle as pkl
20
+ import argparse
21
+ from concurrent.futures import ThreadPoolExecutor
22
+ from concurrent.futures import ProcessPoolExecutor, as_completed
23
+ import concurrent.futures
24
+ import sys
25
+
26
+ load_dotenv()
27
+
28
+ logging.basicConfig(
29
+ level=logging.DEBUG,
30
+ format='%(threadName)s: %(message)s',
31
+ stream=sys.stdout
32
+ )
33
+ logger = logging.getLogger(__name__)
34
+
35
+ IMAGE_SCALE_RATIO_MIN = 50
36
+ IMAGE_SCALE_RATIO_MAX = 40
37
+ TABLE_SCALE_RATIO_MIN = 100
38
+ TABLE_SCALE_RATIO_MAX = 80
39
+
40
+ def layout_process_section_wrapped(
41
+ sections,
42
+ new_outline,
43
+ init_template,
44
+ new_section_template,
45
+ init_actor_sys_msg,
46
+ new_section_actor_sys_msg,
47
+ actor_config,
48
+ documentation,
49
+ max_retry,
50
+ slide_width,
51
+ slide_height
52
+ ):
53
+ logs = {}
54
+ parallel_results = {}
55
+ total_input_token, total_output_token = 0, 0
56
+
57
+ # Run each section's layout generation concurrently in a thread pool
58
+ with ThreadPoolExecutor() as executor:
59
+ futures = []
60
+
61
+ for section_index in range(len(sections)):
62
+ if section_index == 0:
63
+ sys_msg = init_actor_sys_msg
64
+ prompt_template = init_template
65
+ else:
66
+ sys_msg = new_section_actor_sys_msg
67
+ prompt_template = new_section_template
68
+
69
+ actor_model = ModelFactory.create(
70
+ model_platform=actor_config['model_platform'],
71
+ model_type=actor_config['model_type'],
72
+ model_config_dict=actor_config['model_config'],
73
+ )
74
+
75
+ future = executor.submit(
76
+ layout_process_section,
77
+ section_index,
78
+ sections,
79
+ new_outline,
80
+ prompt_template,
81
+ documentation,
82
+ sys_msg,
83
+ actor_model,
84
+ 10,
85
+ max_retry,
86
+ slide_width,
87
+ slide_height
88
+ )
89
+ futures.append(future)
90
+
91
+ # Collect results as processes complete
92
+ for future in as_completed(futures):
93
+ section_index, section_logs, in_toks, out_toks = future.result()
94
+
95
+ # Store logs by section index
96
+ parallel_results[section_index] = section_logs
97
+
98
+ # Update token counters
99
+ total_input_token += in_toks
100
+ total_output_token += out_toks
101
+
102
+ # Merge results back into `logs`
103
+ for section_index, section_logs in parallel_results.items():
104
+ curr_section = sections[section_index]
105
+ logs[curr_section] = section_logs
106
+
107
+ return logs, total_input_token, total_output_token
108
+
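The wrapper above is a standard submit/as_completed fan-out, fan-in: each section is generated independently, results may arrive in any order, and they are re-keyed by section index afterwards. A stripped-down sketch of the same pattern (function names here are placeholders):

    from concurrent.futures import ThreadPoolExecutor, as_completed

    def fan_out(n_sections, work_fn):
        # work_fn(i) must return (i, result); completion order does not matter.
        results = {}
        with ThreadPoolExecutor() as executor:
            futures = [executor.submit(work_fn, i) for i in range(n_sections)]
            for future in as_completed(futures):
                i, result = future.result()
                results[i] = result
        return [results[i] for i in range(n_sections)]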
109
+ def create_agent_fn(sys_msg, agent_model, window_size=10):
110
+ agent = ChatAgent(
111
+ system_message=sys_msg,
112
+ model=agent_model,
113
+ message_window_size=window_size,
114
+ )
115
+ return agent
116
+
117
+ def layout_h2_process_section(
118
+ section,
119
+ outline_no_sub_locations,
120
+ h2_actor_template,
121
+ create_h2_actor_agent, # If you need a fresh agent for each thread
122
+ ):
123
+ """
124
+ Run the logic for a single section.
125
+ Returns a tuple containing:
126
+ - section name (or id),
127
+ - updated subsection-location dict,
128
+ - input token count,
129
+ - output token count
130
+ """
131
+ print(f'Generating h2 for section {section}...', flush=True)
132
+
133
+ # 1) Create the prompt
134
+ section_outline = {section: outline_no_sub_locations[section]}
135
+ section_jinja_args = {
136
+ 'section_outline': json.dumps(section_outline, indent=4),
137
+ }
138
+ section_prompt = h2_actor_template.render(**section_jinja_args)
139
+
140
+ # 2) Prepare a fresh agent or reuse existing (thread-safe?) agent
141
+ # If your h2_actor_agent is not thread-safe, instantiate a new one here:
142
+ h2_actor_agent = create_h2_actor_agent()
143
+ h2_actor_agent.reset()
144
+
145
+ # 3) Get response
146
+ response = h2_actor_agent.step(section_prompt)
147
+ input_token, output_token = account_token(response)
148
+
149
+ # 4) Parse JSON
150
+ subsection_location = get_json_from_response(response.msgs[0].content)
151
+
152
+ # 5) Create a dict from the sub-locations
153
+ sec_bbox = outline_no_sub_locations[section]['location']
154
+ subsection_location_dict = {}
155
+ for k, v in subsection_location.items():
156
+ subsection_location_dict[k] = {
157
+ 'left': v['location'][0],
158
+ 'top': v['location'][1],
159
+ 'width': v['location'][2],
160
+ 'height': v['location'][3]
161
+ }
162
+
163
+ # 6) Validate and possibly revise
164
+ is_valid, revised = validate_and_adjust_subsections(sec_bbox, subsection_location_dict)
165
+ if not is_valid:
166
+ # Try once more
167
+ is_valid, revised = validate_and_adjust_subsections(sec_bbox, revised)
168
+ assert is_valid, "Failed to adjust subsections to fit section"
169
+ final_sub_loc = revised
170
+ else:
171
+ final_sub_loc = subsection_location
172
+
173
+ # Return all data needed by the main thread
174
+ return section, final_sub_loc, input_token, output_token
175
+
176
+ def layout_process_section(
177
+ section_index,
178
+ sections,
179
+ new_outline,
180
+ new_section_template,
181
+ documentation,
182
+ sys_msg,
183
+ agent_model,
184
+ window_size,
185
+ max_retry,
186
+ slide_width,
187
+ slide_height
188
+ ):
189
+ """
190
+ Runs the 'gen_layout' logic for a single section_index.
191
+
192
+ Returns a tuple:
193
+ (section_index, updated_log, input_tokens, output_tokens)
194
+ """
195
+ curr_section = sections[section_index]
196
+ print(f'Generating h1 layout for section {curr_section}...')
197
+
198
+ # Build outline JSON just for current section
199
+ new_section_outline = {curr_section: new_outline[curr_section]}
200
+ if section_index == 0:
201
+ new_section_outline = {'meta': new_outline['meta'], curr_section: new_outline[curr_section]}
202
+ new_section_jinja_args = {
203
+ 'json_outline': new_section_outline,
204
+ 'function_docs': documentation,
205
+ 'file_name': f'poster_{section_index}.pptx'
206
+ }
207
+
208
+ # Render prompt
209
+ new_section_prompt = new_section_template.render(**new_section_jinja_args)
210
+
211
+ existing_code = '' # Or fetch from a stable location that is not dependent on real-time results
212
+
213
+ # Call gen_layout
214
+ section_logs = gen_layout_parallel(
215
+ create_agent_fn(
216
+ sys_msg,
217
+ agent_model,
218
+ window_size
219
+ ),
220
+ new_section_prompt,
221
+ max_retry,
222
+ existing_code=existing_code,
223
+ slide_width=slide_width,
224
+ slide_height=slide_height,
225
+ tmp_name=section_index
226
+ )
227
+
228
+ if section_logs[-1]['error'] is not None:
229
+ print(f'Failed to generate layout for section {curr_section}.')
230
+ return None
231
+
232
+ in_toks, out_toks = section_logs[-1]['cumulative_tokens']
233
+ return (section_index, section_logs, in_toks, out_toks)
234
+
235
+ def get_outline_location(outline, subsection=False):
236
+ outline_location = {}
237
+ for k, v in outline.items():
238
+ if k == 'meta':
239
+ continue
240
+ outline_location[k] = {
241
+ 'location': v['location'],
242
+ }
243
+ if subsection:
244
+ if 'subsections' in v:
245
+ outline_location[k]['subsections'] = get_outline_location(v['subsections'])
246
+ return outline_location
247
+
248
+ def apply_outline_location(outline, location, subsection=False):
249
+ new_outline = {}
250
+ for k, v in outline.items():
251
+ if k == 'meta':
252
+ new_outline[k] = v
253
+ continue
254
+ new_outline[k] = copy.deepcopy(v)
255
+ new_outline[k]['location'] = location[k]['location']
256
+ if subsection:
257
+ if 'subsections' in v:
258
+ new_outline[k]['subsections'] = apply_outline_location(v['subsections'], location[k]['subsections'])
259
+
260
+ return new_outline
261
+
262
+ def fill_location(outline, section_name, location_dict):
263
+ new_outline = copy.deepcopy(outline)
264
+ if 'subsections' not in new_outline[section_name]:
265
+ return new_outline
266
+ for k, v in new_outline[section_name]['subsections'].items():
267
+ v['location'] = location_dict[k]['location']
268
+ return new_outline
269
+
270
+ def recover_name_and_location(outline_no_name, outline):
271
+ new_outline = copy.deepcopy(outline_no_name)
272
+ for k, v in outline_no_name.items():
273
+ if k == 'meta':
274
+ continue
275
+ new_outline[k]['name'] = outline[k]['name']
276
+ if type(new_outline[k]['location']) == list:
277
+ new_outline[k]['location'] = {
278
+ 'left': v['location'][0],
279
+ 'top': v['location'][1],
280
+ 'width': v['location'][2],
281
+ 'height': v['location'][3]
282
+ }
283
+ if 'subsections' in v:
284
+ for k_sub, v_sub in v['subsections'].items():
285
+ new_outline[k]['subsections'][k_sub]['name'] = outline[k]['subsections'][k_sub]['name']
286
+ if type(new_outline[k]['subsections'][k_sub]['location']) == list:
287
+ new_outline[k]['subsections'][k_sub]['location'] = {
288
+ 'left': v_sub['location'][0],
289
+ 'top': v_sub['location'][1],
290
+ 'width': v_sub['location'][2],
291
+ 'height': v_sub['location'][3]
292
+ }
293
+ return new_outline
294
+
295
+
296
+ def validate_and_adjust_subsections(section_bbox, subsection_bboxes):
297
+ """
298
+ Validate that the given subsections collectively occupy the entire section.
299
+ If not, return an adjusted version that fixes the layout.
300
+
301
+ We assume all subsections are intended to be stacked vertically with no gaps,
302
+ spanning the full width of the section.
303
+
304
+ :param section_bbox: dict with keys ["left", "top", "width", "height"]
305
+ :param subsection_bboxes: dict of subsection_name -> bounding_box (each also
306
+ with keys ["left", "top", "width", "height"])
307
+ :return: (is_valid, revised_subsections)
308
+ where is_valid is True/False,
309
+ and revised_subsections is either the same as subsection_bboxes if valid,
310
+ or a new dict of adjusted bounding boxes if invalid.
311
+ """
312
+
313
+ # Helper functions
314
+ def _right(bbox):
315
+ return bbox["left"] + bbox["width"]
316
+
317
+ def _bottom(bbox):
318
+ return bbox["top"] + bbox["height"]
319
+
320
+ section_left = section_bbox["left"]
321
+ section_top = section_bbox["top"]
322
+ section_right = section_left + section_bbox["width"]
323
+ section_bottom = section_top + section_bbox["height"]
324
+
325
+ # Convert dictionary to a list of (subsection_name, bbox) pairs
326
+ items = list(subsection_bboxes.items())
327
+ if not items:
328
+ # No subsections is definitely not valid if we want to fill the section
329
+ return False, None
330
+
331
+ # Sort subsections by their 'top' coordinate
332
+ items_sorted = sorted(items, key=lambda x: x[1]["top"])
333
+
334
+ # ---------------------------
335
+ # Step 1: Validate
336
+ # ---------------------------
337
+ # We'll check:
338
+ # 1. left/right boundaries match the section for each subsection
339
+ # 2. The first subsection's top == section_top
340
+ # 3. The last subsection's bottom == section_bottom
341
+ # 4. Each pair of consecutive subsections lines up exactly
342
+ # (previous bottom == current top) with no gap or overlap.
343
+
344
+ is_valid = True
345
+
346
+ # Check left/right for each
347
+ for name, bbox in items_sorted:
348
+ if bbox["left"] != section_left or _right(bbox) != section_right:
349
+ is_valid = False
350
+ break
351
+
352
+ # Check alignment for the first and last
353
+ if is_valid:
354
+ first_sub_name, first_sub_bbox = items_sorted[0]
355
+ if first_sub_bbox["top"] != section_top:
356
+ is_valid = False
357
+
358
+ if is_valid:
359
+ last_sub_name, last_sub_bbox = items_sorted[-1]
360
+ if _bottom(last_sub_bbox) != section_bottom:
361
+ is_valid = False
362
+
363
+ # Check consecutive alignment
364
+ if is_valid:
365
+ for i in range(len(items_sorted) - 1):
366
+ _, current_bbox = items_sorted[i]
367
+ _, next_bbox = items_sorted[i + 1]
368
+ if _bottom(current_bbox) != next_bbox["top"]:
369
+ is_valid = False
370
+ break
371
+
372
+ # If everything passed, we return
373
+ if is_valid:
374
+ return True, subsection_bboxes
375
+
376
+ # ---------------------------
377
+ # Step 2: Revise
378
+ # ---------------------------
379
+ # We will adjust all subsection bboxes so that they occupy
380
+ # the entire section exactly, preserving each original bbox's
381
+ # height *ratio* if possible.
382
+
383
+ # 2a. Compute total original height (in the order of sorted items)
384
+ original_heights = [bbox["height"] for _, bbox in items_sorted]
385
+ total_original_height = sum(original_heights)
386
+
387
+ # Avoid divide-by-zero if somehow there's a 0 height
388
+ if total_original_height <= 0:
389
+ # Fallback: split the section equally among subsections
390
+ # to avoid zero or negative heights
391
+ chunk_height = section_bbox["height"] / len(items_sorted)
392
+ scale_heights = [chunk_height] * len(items_sorted)
393
+ else:
394
+ # Scale each original height by the ratio of
395
+ # (section total height / sum of original heights)
396
+ scale = section_bbox["height"] / total_original_height
397
+ scale_heights = [h * scale for h in original_heights]
398
+
399
+ # 2b. Assign bounding boxes top->bottom, ensuring no gap
400
+ revised = {}
401
+ current_top = section_top
402
+ for i, (name, original_bbox) in enumerate(items_sorted):
403
+ revised_height = scale_heights[i]
404
+ # If there's floating error, we can clamp in the last iteration
405
+ # so that the bottom exactly matches section_bottom.
406
+ # But for simplicity, we'll keep it straightforward unless needed.
407
+
408
+ revised[name] = {
409
+ "left": section_left,
410
+ "top": current_top,
411
+ "width": section_bbox["width"],
412
+ "height": revised_height
413
+ }
414
+ # Update current_top for next subsection
415
+ current_top += revised_height
416
+
417
+ # Due to potential float rounding, we can enforce the last subsection
418
+ # to exactly end at section_bottom:
419
+ last_name = items_sorted[-1][0]
420
+ # Recompute the actual bottom after the above assignment
421
+ new_bottom = revised[last_name]["top"] + revised[last_name]["height"]
422
+ diff = new_bottom - section_bottom
423
+ if abs(diff) > 1e-9:
424
+ # Adjust the last subsection's height
425
+ revised[last_name]["height"] -= diff
426
+
427
+ # Return the revised dictionary
428
+ return False, revised
429
+
430
+ def filter_image_table(args, filter_config):
431
+ images = json.load(open(f'images_and_tables/{args.poster_name}_images.json', 'r'))
432
+ tables = json.load(open(f'images_and_tables/{args.poster_name}_tables.json', 'r'))
433
+ doc_json = json.load(open(f'contents/{args.model_name}_{args.poster_name}_raw_content.json', 'r'))
434
+ agent_filter = 'image_table_filter_agent'
435
+ with open(f"prompt_templates/{agent_filter}.yaml", "r") as f:
436
+ config_filter = yaml.safe_load(f)
437
+
438
+ image_information = {}
439
+ for k, v in images.items():
440
+ image_information[k] = copy.deepcopy(v)
441
+ image_information[k]['min_width'] = v['width'] // IMAGE_SCALE_RATIO_MIN
442
+ image_information[k]['min_height'] = v['height'] // IMAGE_SCALE_RATIO_MIN
443
+ image_information[k]['max_width'] = v['width'] // IMAGE_SCALE_RATIO_MAX
444
+ image_information[k]['max_height'] = v['height'] // IMAGE_SCALE_RATIO_MAX
445
+
446
+ table_information = {}
447
+ for k, v in tables.items():
448
+ table_information[k] = copy.deepcopy(v)
449
+ table_information[k]['min_width'] = v['width'] // TABLE_SCALE_RATIO_MIN
450
+ table_information[k]['min_height'] = v['height'] // TABLE_SCALE_RATIO_MIN
451
+ table_information[k]['max_width'] = v['width'] // TABLE_SCALE_RATIO_MAX
452
+ table_information[k]['max_height'] = v['height'] // TABLE_SCALE_RATIO_MAX
453
+
454
+ filter_actor_sys_msg = config_filter['system_prompt']
455
+
456
+ filter_model = ModelFactory.create(
457
+ model_platform=filter_config['model_platform'],
458
+ model_type=filter_config['model_type'],
459
+ model_config_dict=filter_config['model_config'],
460
+ )
461
+ filter_actor_agent = ChatAgent(
462
+ system_message=filter_actor_sys_msg,
463
+ model=filter_model,
464
+ message_window_size=10, # [Optional] the length for chat memory
465
+ )
466
+
467
+ filter_jinja_args = {
468
+ 'json_content': doc_json,
469
+ 'table_information': table_information,
470
+ 'image_information': image_information,
471
+ }
472
+ jinja_env = Environment(undefined=StrictUndefined)
473
+ filter_prompt = jinja_env.from_string(config_filter["template"])
474
+ response = filter_actor_agent.step(filter_prompt.render(**filter_jinja_args))
475
+ input_token, output_token = account_token(response)
476
+ response_json = get_json_from_response(response.msgs[0].content)
477
+ table_information = response_json['table_information']
478
+ image_information = response_json['image_information']
479
+ json.dump(images, open(f'images_and_tables/{args.poster_name}_images_filtered.json', 'w'), indent=4)
480
+ json.dump(tables, open(f'images_and_tables/{args.poster_name}_tables_filtered.json', 'w'), indent=4)
481
+
482
+ return input_token, output_token
483
+
484
+ def gen_outline_layout(args, actor_config, critic_config):
485
+ poster_log_path = f'log/{args.model_name}_{args.poster_name}_poster_{args.index}'
486
+ if not os.path.exists(poster_log_path):
487
+ os.mkdir(poster_log_path)
488
+ total_input_token, total_output_token = 0, 0
489
+ consumption_log = {
490
+ 'outline': [],
491
+ 'h1_actor': [],
492
+ 'h2_actor': [],
493
+ 'h1_critic': [],
494
+ 'gen_layout': []
495
+ }
496
+ jinja_env = Environment(undefined=StrictUndefined)
497
+ outline_file_path = f'outlines/{args.model_name}_{args.poster_name}_outline_{args.index}.json'
498
+ agent_name = 'poster_planner_new'
499
+ agent_init_name = 'layout_agent_init_parallel'
500
+ agent_new_section_name = 'layout_agent_new_section_parallel'
501
+ h1_critic_name = 'critic_layout_hierarchy_1'
502
+ h2_actor_name = 'actor_layout_hierarchy_2'
503
+
504
+ doc_json = json.load(open(f'contents/{args.model_name}_{args.poster_name}_raw_content.json', 'r'))
505
+ filtered_table_information = json.load(open(f'images_and_tables/{args.poster_name}_tables_filtered.json', 'r'))
506
+ filtered_image_information = json.load(open(f'images_and_tables/{args.poster_name}_images_filtered.json', 'r'))
507
+
508
+ with open(f"prompt_templates/{agent_name}.yaml", "r") as f:
509
+ planner_config = yaml.safe_load(f)
510
+
511
+ with open(f"prompt_templates/{agent_init_name}.yaml", "r") as f:
512
+ config_init = yaml.safe_load(f)
513
+
514
+ with open(f"prompt_templates/{agent_new_section_name}.yaml", "r") as f:
515
+ config_new_section = yaml.safe_load(f)
516
+
517
+ with open(f"prompt_templates/{h1_critic_name}.yaml", "r") as f:
518
+ config_h1_critic = yaml.safe_load(f)
519
+
520
+ with open(f"prompt_templates/{h2_actor_name}.yaml", "r") as f:
521
+ config_h2_actor = yaml.safe_load(f)
522
+
523
+ planner_model = ModelFactory.create(
524
+ model_platform=actor_config['model_platform'],
525
+ model_type=actor_config['model_type'],
526
+ model_config_dict=actor_config['model_config'],
527
+ )
528
+
529
+ planner_agent = ChatAgent(
530
+ system_message=planner_config['system_prompt'],
531
+ model=planner_model,
532
+ message_window_size=10,
533
+ )
534
+
535
+ outline_template = jinja_env.from_string(planner_config["template"])
536
+
537
+ planner_jinja_args = {
538
+ 'json_content': doc_json,
539
+ 'table_information': filtered_table_information,
540
+ 'image_information': filtered_image_information,
541
+ }
542
+
543
+ actor_model = ModelFactory.create(
544
+ model_platform=actor_config['model_platform'],
545
+ model_type=actor_config['model_type'],
546
+ model_config_dict=actor_config['model_config'],
547
+ )
548
+
549
+ init_actor_sys_msg = config_init['system_prompt']
550
+
551
+ def create_init_actor_agent():
552
+ actor_model = ModelFactory.create(
553
+ model_platform=actor_config['model_platform'],
554
+ model_type=actor_config['model_type'],
555
+ model_config_dict=actor_config['model_config'],
556
+ )
557
+ init_actor_agent = ChatAgent(
558
+ system_message=init_actor_sys_msg,
559
+ model=actor_model,
560
+ message_window_size=10,
561
+ )
562
+ return init_actor_agent
563
+
564
+ new_section_actor_sys_msg = config_new_section['system_prompt']
565
+
566
+ def create_new_section_actor_agent():
567
+ actor_model = ModelFactory.create(
568
+ model_platform=actor_config['model_platform'],
569
+ model_type=actor_config['model_type'],
570
+ model_config_dict=actor_config['model_config'],
571
+ )
572
+ new_section_actor_agent = ChatAgent(
573
+ system_message=new_section_actor_sys_msg,
574
+ model=actor_model,
575
+ message_window_size=10,
576
+ )
577
+ return new_section_actor_agent
578
+
579
+ h1_critic_model = ModelFactory.create(
580
+ model_platform=critic_config['model_platform'],
581
+ model_type=critic_config['model_type'],
582
+ model_config_dict=critic_config['model_config'],
583
+ )
584
+
585
+ h1_critic_sys_msg = config_h1_critic['system_prompt']
586
+
587
+ h1_critic_agent = ChatAgent(
588
+ system_message=h1_critic_sys_msg,
589
+ model=h1_critic_model,
590
+ message_window_size=None,
591
+ )
592
+
593
+ h1_pos_example = Image.open('h1_example/h1_pos.jpg')
594
+ h1_neg_example = Image.open('h1_example/h1_neg.jpg')
595
+
596
+ h2_actor_model = ModelFactory.create(
597
+ model_platform=actor_config['model_platform'],
598
+ model_type=actor_config['model_type'],
599
+ model_config_dict=actor_config['model_config'],
600
+ )
601
+
602
+ h2_actor_sys_msg = config_h2_actor['system_prompt']
603
+
604
+ def create_h2_actor_agent():
605
+ h2_actor_model = ModelFactory.create(
606
+ model_platform=actor_config['model_platform'],
607
+ model_type=actor_config['model_type'],
608
+ model_config_dict=actor_config['model_config'],
609
+ )
610
+ h2_actor_agent = ChatAgent(
611
+ system_message=h2_actor_sys_msg,
612
+ model=h2_actor_model,
613
+ message_window_size=10,
614
+ )
615
+ return h2_actor_agent
616
+
617
+
618
+ init_template = jinja_env.from_string(config_init["template"])
619
+ new_section_template = jinja_env.from_string(config_new_section["template"])
620
+ h1_critic_template = jinja_env.from_string(config_h1_critic["template"])
621
+
622
+ attempt = 0
623
+ while True:
624
+ print(f'Generating outline attempt {attempt}...', flush=True)
625
+ planner_prompt = outline_template.render(**planner_jinja_args)
626
+ planner_agent.reset()
627
+ response = planner_agent.step(planner_prompt)
628
+ outline = get_json_from_response(response.msgs[0].content)
629
+ input_token, output_token = account_token(response)
630
+ sections = list(outline.keys())
631
+ sections = [x for x in sections if x != 'meta']
632
+ slide_width = outline['meta']['width']
633
+ slide_height = outline['meta']['height']
634
+ name_to_hierarchy = get_hierarchy(outline)
635
+ consumption_log['outline'].append((input_token, output_token))
636
+ total_input_token += input_token
637
+ total_output_token += output_token
638
+ init_outline = {'meta': outline['meta'], sections[0]: outline[sections[0]]}
639
+
640
+ new_outline = outline
641
+
642
+ init_jinja_args = {
643
+ 'json_outline': init_outline,
644
+ 'function_docs': documentation
645
+ }
646
+
647
+ init_prompt = init_template.render(**init_jinja_args)
648
+
649
+ # hierarchy 1 only
650
+ outline_location = get_outline_location(outline, subsection=False)
651
+
652
+ logs, layout_cumulative_input_token, layout_cumulative_output_token = layout_process_section_wrapped(
653
+ sections,
654
+ new_outline,
655
+ init_template,
656
+ new_section_template,
657
+ init_actor_sys_msg,
658
+ new_section_actor_sys_msg,
659
+ actor_config,
660
+ documentation,
661
+ args.max_retry,
662
+ slide_width,
663
+ slide_height
664
+ )
665
+
666
+ concatenated_code = utils_functions
667
+ for section_index in range(len(sections)):
668
+ section = sections[section_index]
669
+ concatenated_code += '\n' + logs[section][-1]['code']
670
+ presentation_object_name = logs[section][-1]['output'].replace('\n', '')
671
+ concatenated_code += '\n' + f'save_presentation({presentation_object_name}, file_name="poster_{section_index + 1}.pptx")'
672
+
673
+ concatenated_code += f'''
674
+ name_to_hierarchy = {name_to_hierarchy}
675
+ identifier = "parallel"
676
+ poster_path = "poster_{section_index + 1}.pptx"
677
+ get_visual_cues(name_to_hierarchy, identifier, poster_path)
678
+ '''
679
+ output, error = run_code_with_utils(concatenated_code, utils_functions)
680
+ if error is not None:
681
+ print(error, flush=True)
682
+ attempt += 1
683
+ continue
684
+
685
+ consumption_log['h1_actor'].append((layout_cumulative_input_token, layout_cumulative_output_token))
686
+ total_input_token += layout_cumulative_input_token
687
+ total_output_token += layout_cumulative_output_token
688
+
689
+ h1_path = f'tmp/poster_<parallel>_hierarchy_1.pptx'
690
+ h2_path = f'tmp/poster_<parallel>_hierarchy_2.pptx'
691
+
692
+ h1_filled_path = f'tmp/poster_<parallel>_hierarchy_1_filled.pptx'
693
+ h2_filled_path = f'tmp/poster_<parallel>_hierarchy_2_filled.pptx'
694
+
695
+ ppt_to_images(h1_path, 'tmp/layout_h1')
696
+ ppt_to_images(h2_path, 'tmp/layout_h2')
697
+ ppt_to_images(h1_filled_path, 'tmp/layout_h1_filled')
698
+ ppt_to_images(h2_filled_path, 'tmp/layout_h2_filled')
699
+
700
+ h1_img = Image.open('tmp/layout_h1/slide_0001.jpg')
701
+ h2_img = Image.open('tmp/layout_h2/slide_0001.jpg')
702
+ h1_filled_img = Image.open('tmp/layout_h1_filled/slide_0001.jpg')
703
+ h2_filled_img = Image.open('tmp/layout_h2_filled/slide_0001.jpg')
704
+
705
+ h1_critic_msg = BaseMessage.make_user_message(
706
+ role_name='User',
707
+ content=h1_critic_template.render(),
708
+ image_list=[h1_neg_example, h1_pos_example, h1_filled_img]
709
+ )
710
+
711
+ outline_bbox_dict = {}
712
+ for k, v in outline_location.items():
713
+ outline_bbox_dict[k] = v['location']
714
+
715
+ bbox_check_result = check_bounding_boxes(
716
+ outline_bbox_dict,
717
+ new_outline['meta']['width'],
718
+ new_outline['meta']['height']
719
+ )
720
+
721
+ if len(bbox_check_result) != 0:
722
+ print(bbox_check_result, flush=True)
723
+ attempt += 1
724
+ continue
725
+
726
+ h1_critic_agent.reset()
727
+ response = h1_critic_agent.step(h1_critic_msg)
728
+ input_token, output_token = account_token(response)
729
+ consumption_log['h1_critic'].append((input_token, output_token))
730
+ total_input_token += input_token
731
+ total_output_token += output_token
732
+ if response.msgs[0].content == 'T':
733
+ print('Blank area detected.', flush=True)
734
+ attempt += 1
735
+ continue
736
+
737
+ print('Successfully generated outline.', flush=True)
738
+
739
+ break
740
+
741
+ outline_bbox_dict = {}
742
+ for k, v in outline_location.items():
743
+ outline_bbox_dict[k] = v['location']
744
+
745
+ # Generate subsection locations
746
+ outline_no_sub_locations = copy.deepcopy(new_outline)
747
+ if 'meta' in outline_no_sub_locations:
748
+ outline_no_sub_locations.pop('meta')
749
+
750
+ for k, v in outline_no_sub_locations.items():
751
+ if 'subsections' in v:
752
+ subsections = v['subsections']
753
+ for k_sub, v_sub in subsections.items():
754
+ del v_sub['location']
755
+ del v_sub['name']
756
+
757
+ h2_actor_template = jinja_env.from_string(config_h2_actor["template"])
758
+
759
+ h2_cumulative_input_token = 0
760
+ h2_cumulative_output_token = 0
761
+
762
+ updated_sections = []
763
+
764
+ with ThreadPoolExecutor() as executor:
765
+ # Kick off all tasks
766
+ future_to_section = {
767
+ executor.submit(
768
+ layout_h2_process_section,
769
+ section,
770
+ outline_no_sub_locations,
771
+ h2_actor_template,
772
+ create_h2_actor_agent # pass the factory function
773
+ ): section
774
+ for section in sections
775
+ }
776
+
777
+ # Gather results as they complete
778
+ for future in concurrent.futures.as_completed(future_to_section):
779
+ section = future_to_section[future]
780
+ sec, final_sub_loc, in_toks, out_toks = future.result()
781
+
782
+ # Accumulate token usage
783
+ h2_cumulative_input_token += in_toks
784
+ h2_cumulative_output_token += out_toks
785
+
786
+ # Stash the final sub-loc for merging
787
+ updated_sections.append((sec, final_sub_loc))
788
+
789
+ # Now merge each updated subsection location back into outline_no_sub_locations
790
+ for (section, final_sub_loc) in updated_sections:
791
+ outline_no_sub_locations = fill_location(
792
+ outline_no_sub_locations,
793
+ section,
794
+ final_sub_loc
795
+ )
796
+
797
+ consumption_log['h2_actor'].append((h2_cumulative_input_token, h2_cumulative_output_token))
798
+ total_input_token += h2_cumulative_input_token
799
+ total_output_token += h2_cumulative_output_token
800
+
801
+ outline_no_sub_locations['meta'] = outline['meta']
802
+ outline_no_sub_locations_with_name = recover_name_and_location(outline_no_sub_locations, new_outline)
803
+ new_outline = outline_no_sub_locations_with_name
804
+
805
+ ### Outline finalized, actually generate layout
806
+
807
+ logs = {}
808
+
809
+ gen_layout_cumulative_input_token = 0
810
+ gen_layout_cumulative_output_token = 0
811
+ init_outline = {'meta': outline['meta'], sections[0]: outline[sections[0]]}
812
+
813
+ new_outline = outline
814
+
815
+ init_jinja_args = {
816
+ 'json_outline': init_outline,
817
+ 'function_docs': documentation
818
+ }
819
+
820
+ outline_location = get_outline_location(outline, subsection=False)
821
+ logs = {}
822
+
823
+ # We'll store all updated logs here, keyed by section_index.
824
+ parallel_results = {}
825
+
826
+ with concurrent.futures.ThreadPoolExecutor() as executor:
827
+ futures = []
828
+ for section_index in range(len(sections)):
829
+ if section_index == 0:
830
+ create_agent_fn = create_init_actor_agent
831
+ prompt_template = init_template
832
+ else:
833
+ create_agent_fn = create_new_section_actor_agent
834
+ prompt_template = new_section_template
835
+ future = executor.submit(
836
+ layout_process_section,
837
+ section_index,
838
+ sections,
839
+ new_outline,
840
+ prompt_template,
841
+ documentation,
842
+ create_agent_fn,
843
+ args.max_retry,
844
+ name_to_hierarchy,
845
+ slide_width,
846
+ slide_height
847
+ )
848
+ futures.append(future)
849
+
850
+ # Collect the results as they come in
851
+ for future in concurrent.futures.as_completed(futures):
852
+ try:
853
+ section_index, section_logs, in_toks, out_toks = future.result()
854
+
855
+ # Store these logs in a dictionary keyed by the section index
856
+ parallel_results[section_index] = section_logs
857
+
858
+ # Update token counters
859
+ gen_layout_cumulative_input_token += in_toks
860
+ gen_layout_cumulative_output_token += out_toks
861
+
862
+ except Exception as exc:
863
+ print(f"[ERROR] A section failed: {exc}", flush=True)
864
+ # Possibly re-raise if you want to stop everything on error
865
+ # raise
866
+
867
+ # After all tasks complete, merge the results back into `logs`
868
+ for section_index, section_logs in parallel_results.items():
869
+ curr_section = sections[section_index]
870
+ logs[curr_section] = section_logs
871
+
872
+ concatenated_code = utils_functions
873
+ for section_index in range(len(sections)):
874
+ section = sections[section_index]
875
+ concatenated_code += '\n' + logs[section][-1]['code']
876
+ concatenated_code += '\n' + f'save_presentation(presentation, file_name="poster_{section_index + 1}.pptx")'
877
+
878
+ concatenated_code += f'''
879
+ name_to_hierarchy = {name_to_hierarchy}
880
+ identifier = "parallel"
881
+ poster_path = "poster_{section_index + 1}.pptx"
882
+ get_visual_cues(name_to_hierarchy, identifier, poster_path)
883
+ '''
884
+ output, error = run_code(concatenated_code)
885
+ if error is not None:
886
+ print(f'Failed to generate layout: {error}', flush=True)
887
+
888
+ consumption_log['h1_actor'].append((gen_layout_cumulative_input_token, gen_layout_cumulative_output_token))
889
+ total_input_token += gen_layout_cumulative_input_token
890
+ total_output_token += gen_layout_cumulative_output_token
891
+
892
+ h1_path = f'tmp/poster_<parallel>_hierarchy_1.pptx'
893
+ h2_path = f'tmp/poster_<parallel>_hierarchy_2.pptx'
894
+
895
+ h1_filled_path = f'tmp/poster_<parallel>_hierarchy_1_filled.pptx'
896
+ h2_filled_path = f'tmp/poster_<parallel>_hierarchy_2_filled.pptx'
897
+
898
+ ppt_to_images(h1_path, 'tmp/layout_h1')
899
+ ppt_to_images(h2_path, 'tmp/layout_h2')
900
+ ppt_to_images(h1_filled_path, 'tmp/layout_h1_filled')
901
+ ppt_to_images(h2_filled_path, 'tmp/layout_h2_filled')
902
+
903
+ h1_img = Image.open('tmp/layout_h1/slide_0001.jpg')
904
+ h2_img = Image.open('tmp/layout_h2/slide_0001.jpg')
905
+ h1_filled_img = Image.open('tmp/layout_h1_filled/slide_0001.jpg')
906
+ h2_filled_img = Image.open('tmp/layout_h2_filled/slide_0001.jpg')
907
+
908
+
909
+ ckpt = {
910
+ 'logs': logs,
911
+ 'outline': new_outline,
912
+ 'name_to_hierarchy': name_to_hierarchy,
913
+ 'consumption_log': consumption_log,
914
+ 'total_input_token': total_input_token,
915
+ 'total_output_token': total_output_token,
916
+ }
917
+
918
+ with open(f'checkpoints/{args.model_name}_{args.poster_name}_ckpt_{args.index}.pkl', 'wb') as f:
919
+ pkl.dump(ckpt, f)
920
+
921
+ json.dump(
922
+ new_outline,
923
+ open(outline_file_path, "w"),
924
+ ensure_ascii=False,
925
+ indent=4,
926
+ )
927
+
928
+ return total_input_token, total_output_token
929
+
930
+ if __name__ == '__main__':
931
+ parser = argparse.ArgumentParser()
932
+ parser.add_argument('--poster_name', type=str, default=None)
933
+ parser.add_argument('--model_name', type=str, default='4o')
934
+ parser.add_argument('--poster_path', type=str, required=True)
935
+ parser.add_argument('--index', type=int, default=0)
936
+ parser.add_argument('--max_retry', type=int, default=3)
937
+ args = parser.parse_args()
938
+
939
+ actor_config = get_agent_config(args.model_name)
940
+ critic_config = get_agent_config(args.model_name)
941
+
942
+ if args.poster_name is None:
943
+ args.poster_name = args.poster_path.split('/')[-1].replace('.pdf', '').replace(' ', '_')
944
+
945
+ input_token, output_token = filter_image_table(args, actor_config)
946
+ print(f'Token consumption: {input_token} -> {output_token}', flush=True)
947
+
948
+ input_token, output_token = gen_outline_layout(args, actor_config, critic_config)
949
+ print(f'Token consumption: {input_token} -> {output_token}', flush=True)
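Note on the pattern above: both hierarchy passes follow the same fan-out/fan-in shape — submit one agent task per section to a ThreadPoolExecutor, collect results keyed by section (or section index) as they complete, then merge them back in the original order while accumulating token counts. A minimal, self-contained sketch of that shape, with process_section standing in for layout_process_section (the names and return values here are illustrative only, not this module's actual API):

from concurrent.futures import ThreadPoolExecutor, as_completed

def process_section(index, name):
    # Stand-in for an agent call; returns (index, logs, input_tokens, output_tokens).
    return index, [f"code for {name}"], 10, 5

sections = ["Introduction", "Method", "Results"]
results, total_in, total_out = {}, 0, 0

with ThreadPoolExecutor() as executor:
    futures = {executor.submit(process_section, i, s): i for i, s in enumerate(sections)}
    for future in as_completed(futures):
        idx, logs, in_toks, out_toks = future.result()
        results[idx] = logs          # keyed by index so the original ordering is recoverable
        total_in += in_toks
        total_out += out_toks

# Merge back in the original section order, as the pipeline does before concatenating code.
ordered_logs = {sections[i]: results[i] for i in range(len(sections))}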
Paper2Poster/PosterAgent/gen_poster_content.py ADDED
@@ -0,0 +1,529 @@
1
+ import tempfile
2
+ import shutil
3
+ from dotenv import load_dotenv
4
+ from utils.src.utils import get_json_from_response
5
+ from concurrent.futures import ThreadPoolExecutor, as_completed
6
+ import json
7
+
8
+ from camel.models import ModelFactory
9
+ from PosterAgent.gen_pptx_code import generate_poster_code
10
+ from camel.agents import ChatAgent
11
+ from camel.messages import BaseMessage
12
+ from utils.src.utils import ppt_to_images
13
+ from PIL import Image
14
+
15
+ from utils.wei_utils import *
16
+
17
+ from utils.pptx_utils import *
18
+ from utils.critic_utils import *
19
+ import yaml
20
+ from jinja2 import Environment, StrictUndefined
21
+ import argparse
22
+
23
+ load_dotenv()
24
+ MAX_ATTEMPT = 10
25
+
26
+ def gen_content_process_section(
27
+ section_name,
28
+ outline,
29
+ raw_content,
30
+ raw_outline,
31
+ template,
32
+ create_actor_agent,
33
+ MAX_ATTEMPT
34
+ ):
35
+ """
36
+ Process a single section in its own thread or process.
37
+ Returns (section_name, result_json, total_input_token, total_output_token).
38
+ """
39
+ # Create a fresh ActorAgent instance for each parallel call
40
+ actor_agent = create_actor_agent()
41
+
42
+ section_outline = ''
43
+ num_attempts = 0
44
+ total_input_token = 0
45
+ total_output_token = 0
46
+ result_json = None
47
+
48
+ while True:
49
+ print(f"[Thread] Generating content for section: {section_name}")
50
+
51
+ if len(section_outline) == 0:
52
+ # Initialize the section outline
53
+ section_outline = json.dumps(outline[section_name], indent=4)
54
+
55
+ # Render prompt using Jinja template
56
+ jinja_args = {
57
+ 'json_outline': section_outline,
58
+ 'json_content': raw_content,
59
+ }
60
+ prompt = template.render(**jinja_args)
61
+
62
+ # Step the actor_agent and track tokens
63
+ response = actor_agent.step(prompt)
64
+ input_token, output_token = account_token(response)
65
+ total_input_token += input_token
66
+ total_output_token += output_token
67
+
68
+ # Parse JSON and possibly adjust text length
69
+ result_json = get_json_from_response(response.msgs[0].content)
70
+ new_section_outline, suggested = generate_length_suggestions(
71
+ result_json,
72
+ json.dumps(outline[section_name]),
73
+ raw_outline[section_name]
74
+ )
75
+ section_outline = json.dumps(new_section_outline, indent=4)
76
+
77
+ if not suggested:
78
+ # No more adjustments needed
79
+ break
80
+
81
+ print(f"[Thread] Adjusting text length for section: {section_name}...")
82
+
83
+ num_attempts += 1
84
+ if num_attempts >= MAX_ATTEMPT:
85
+ break
86
+
87
+ return section_name, result_json, total_input_token, total_output_token
88
+
89
+
90
+ def gen_content_parallel_process_sections(
91
+ sections,
92
+ outline,
93
+ raw_content,
94
+ raw_outline,
95
+ template,
96
+ create_actor_agent,
97
+ MAX_ATTEMPT=3
98
+ ):
99
+ """
100
+ Parallelize the section processing using ThreadPoolExecutor.
101
+ """
102
+ poster_content = {}
103
+ total_input_token = 0
104
+ total_output_token = 0
105
+
106
+ # Create a pool of worker threads (or processes)
107
+ with ThreadPoolExecutor() as executor:
108
+ futures = []
109
+
110
+ # Submit each section to be processed in parallel
111
+ for section_name in sections:
112
+ futures.append(
113
+ executor.submit(
114
+ gen_content_process_section,
115
+ section_name,
116
+ outline,
117
+ raw_content,
118
+ raw_outline,
119
+ template,
120
+ create_actor_agent,
121
+ MAX_ATTEMPT
122
+ )
123
+ )
124
+
125
+ # Collect results as they complete
126
+ for future in as_completed(futures):
127
+ section_name, result_json, sec_input_token, sec_output_token = future.result()
128
+ poster_content[section_name] = result_json
129
+ total_input_token += sec_input_token
130
+ total_output_token += sec_output_token
131
+
132
+ return poster_content, total_input_token, total_output_token
133
+
134
+ def render_textbox(text_arrangement, textbox_content, tmp_dir):
135
+ arrangement = copy.deepcopy(text_arrangement)
136
+ arrangement['x'] = 1
137
+ arrangement['y'] = 1
138
+
139
+ poster_code = generate_poster_code(
140
+ [],
141
+ [arrangement],
142
+ [],
143
+ presentation_object_name='poster_presentation',
144
+ slide_object_name='poster_slide',
145
+ utils_functions=utils_functions,
146
+ slide_width=text_arrangement['width'] + 3,
147
+ slide_height=text_arrangement['height'] + 3,
148
+ img_path='placeholder.jpg',
149
+ save_path=f'{tmp_dir}/poster.pptx',
150
+ visible=True,
151
+ content=textbox_content,
152
+ check_overflow=True,
153
+ tmp_dir=tmp_dir,
154
+ )
155
+
156
+ output, err = run_code(poster_code)
157
+ ppt_to_images(f'{tmp_dir}/poster.pptx', tmp_dir, output_type='jpg')
158
+ img = Image.open(f'{tmp_dir}/poster.jpg')
159
+
160
+ return img
161
+
162
+ def gen_poster_title_content(args, actor_config):
163
+ total_input_token, total_output_token = 0, 0
164
+ raw_content = json.load(open(f'contents/<{args.model_name_t}_{args.model_name_v}>_{args.poster_name}_raw_content.json', 'r'))
165
+ actor_agent_name = 'poster_title_agent'
166
+
167
+ title_string = raw_content['meta']
168
+
169
+ with open(f'utils/prompt_templates/{actor_agent_name}.yaml', "r") as f:
170
+ content_config = yaml.safe_load(f)
171
+ jinja_env = Environment(undefined=StrictUndefined)
172
+ template = jinja_env.from_string(content_config["template"])
173
+
174
+ if args.model_name_t == 'vllm_qwen':
175
+ actor_model = ModelFactory.create(
176
+ model_platform=actor_config['model_platform'],
177
+ model_type=actor_config['model_type'],
178
+ model_config_dict=actor_config['model_config'],
179
+ url=actor_config['url'],
180
+ )
181
+ else:
182
+ actor_model = ModelFactory.create(
183
+ model_platform=actor_config['model_platform'],
184
+ model_type=actor_config['model_type'],
185
+ model_config_dict=actor_config['model_config']
186
+ )
187
+
188
+ actor_sys_msg = content_config['system_prompt']
189
+ actor_agent = ChatAgent(
190
+ system_message=actor_sys_msg,
191
+ model=actor_model,
192
+ message_window_size=30
193
+ )
194
+
195
+ jinja_args = {
196
+ 'title_string': title_string,
197
+ 'title_font_size': getattr(args, 'poster_title_font_size', None) or getattr(args, 'title_font_size', None),
198
+ 'author_font_size': getattr(args, 'poster_author_font_size', None) or getattr(args, 'author_font_size', None),
199
+ }
200
+ prompt = template.render(**jinja_args)
201
+ # Step the actor_agent and track tokens
202
+ actor_agent.reset()
203
+ response = actor_agent.step(prompt)
204
+ input_token, output_token = account_token(response)
205
+ total_input_token += input_token
206
+ total_output_token += output_token
207
+ result_json = get_json_from_response(response.msgs[0].content)
208
+
209
+ return result_json, total_input_token, total_output_token
210
+
211
+ def gen_bullet_point_content(args, actor_config, critic_config, agent_modify=True, tmp_dir='tmp'):
212
+ import json, yaml, copy, threading
213
+ from concurrent.futures import ThreadPoolExecutor, as_completed
214
+ from PIL import Image
215
+ from jinja2 import Environment, StrictUndefined
216
+
217
+ # ----------------------- Load data & configs -----------------------
218
+ total_input_token_t = total_output_token_t = 0
219
+ total_input_token_v = total_output_token_v = 0
220
+
221
+ raw_content = json.load(open(f'contents/<{args.model_name_t}_{args.model_name_v}>_{args.poster_name}_raw_content.json', 'r'))
222
+ with open(f'tree_splits/<{args.model_name_t}_{args.model_name_v}>_{args.poster_name}_tree_split_{args.index}.json', 'r') as f:
223
+ tree_split_results = json.load(f)
224
+
225
+ panels = tree_split_results['panels']
226
+ text_arrangement_list = tree_split_results['text_arrangement_inches']
227
+
228
+ actor_agent_name = 'bullet_point_agent'
229
+ if args.model_name_v == 'vllm_qwen_vl':
230
+ critic_agent_name = 'critic_overlap_agent_v3_short'
231
+ else:
232
+ critic_agent_name = 'critic_overlap_agent_v3'
233
+
234
+ with open(f"utils/prompt_templates/{actor_agent_name}.yaml", "r") as f:
235
+ content_config = yaml.safe_load(f)
236
+ with open(f"utils/prompt_templates/{critic_agent_name}.yaml", "r") as f:
237
+ critic_content_config = yaml.safe_load(f)
238
+
239
+ jinja_env = Environment(undefined=StrictUndefined)
240
+ template = jinja_env.from_string(content_config["template"])
241
+ critic_template = jinja_env.from_string(critic_content_config["template"])
242
+
243
+ # Preload images once (each worker can reopen if needed, or just pass paths)
244
+ neg_img_path = 'assets/overflow_example_v2/neg.jpg'
245
+ pos_img_path = 'assets/overflow_example_v2/pos.jpg'
246
+
247
+ # Group text arrangements by panel_id for O(1) lookup in workers
248
+ from collections import defaultdict
249
+ textboxes_by_panel = defaultdict(list)
250
+ for ta in text_arrangement_list:
251
+ textboxes_by_panel[ta['panel_id']].append(ta)
252
+ # Ensure deterministic order inside each panel
253
+ for k in textboxes_by_panel:
254
+ textboxes_by_panel[k] = sorted(textboxes_by_panel[k], key=lambda x: x.get('textbox_id', 0))
255
+
256
+ # ----------------------- Worker (defined INSIDE main fn) -----------------------
257
+ def _process_section(i):
258
+ """
259
+ Returns:
260
+ (i, result_json, t_in, t_out, v_in, v_out)
261
+ """
262
+ local_t_in = local_t_out = 0
263
+ local_v_in = local_v_out = 0
264
+
265
+ arrangement = panels[i]
266
+ num_textboxes = 2 if arrangement.get('gp', 0) > 0 else 1
267
+
268
+ local_tmp_dir = tempfile.mkdtemp(prefix=f"sec_{i}_", dir=tmp_dir)
269
+
270
+ jinja_args = {
271
+ 'summary_of_section': raw_content['sections'][i]['content'],
272
+ 'number_of_textboxes': num_textboxes,
273
+ 'section_title': raw_content['sections'][i]['title'],
274
+ 'bullet_font_size': args.bullet_font_size,
275
+ 'section_title_font_size': args.section_title_font_size,
276
+ }
277
+
278
+ target_textboxes = textboxes_by_panel[i][1:] # skip first (section title)
279
+ total_expected_length = sum(tb['num_chars'] for tb in target_textboxes)
280
+
281
+ # Create fresh models & agents per thread for safety
282
+ if args.model_name_t.startswith('vllm_qwen'):
283
+ actor_model = ModelFactory.create(
284
+ model_platform=actor_config['model_platform'],
285
+ model_type=actor_config['model_type'],
286
+ model_config_dict=actor_config['model_config'],
287
+ url=actor_config['url'],
288
+ )
289
+ else:
290
+ actor_model = ModelFactory.create(
291
+ model_platform=actor_config['model_platform'],
292
+ model_type=actor_config['model_type'],
293
+ model_config_dict=actor_config['model_config']
294
+ )
295
+ if args.model_name_v.startswith('vllm_qwen'):
296
+ critic_model = ModelFactory.create(
297
+ model_platform=critic_config['model_platform'],
298
+ model_type=critic_config['model_type'],
299
+ model_config_dict=critic_config['model_config'],
300
+ url=critic_config['url'],
301
+ )
302
+ else:
303
+ critic_model = ModelFactory.create(
304
+ model_platform=critic_config['model_platform'],
305
+ model_type=critic_config['model_type'],
306
+ model_config_dict=critic_config['model_config']
307
+ )
308
+
309
+ actor_agent = ChatAgent(system_message=content_config['system_prompt'], model=actor_model, message_window_size=30)
310
+ critic_agent = ChatAgent(system_message=critic_content_config['system_prompt'], model=critic_model, message_window_size=10)
311
+
312
+ prompt = template.render(**jinja_args)
313
+ actor_agent.reset()
314
+ response = actor_agent.step(prompt)
315
+ t_in, t_out = account_token(response)
316
+ local_t_in += t_in
317
+ local_t_out += t_out
318
+
319
+ result_json = get_json_from_response(response.msgs[0].content)
320
+
321
+ max_attempts = 5
322
+ num_attempts = 0
323
+ old_result_json = copy.deepcopy(result_json)
324
+
325
+ # Length control loop
326
+ while args.estimate_chars:
327
+ num_attempts += 1
328
+ if num_attempts > max_attempts:
329
+ result_json = old_result_json
330
+ break
331
+ try:
332
+ total_bullet_length = 0
333
+ for j in range(num_textboxes):
334
+ bullet_content_key = f'textbox{j + 1}'
335
+ total_bullet_length += compute_bullet_length(result_json[bullet_content_key])
336
+ except Exception:
337
+ result_json = old_result_json
338
+ break
339
+
340
+ if total_bullet_length > total_expected_length:
341
+ percentage_to_shrink = int((total_bullet_length - total_expected_length) / total_bullet_length * 100)
342
+ percentage_to_shrink = min(90, percentage_to_shrink + 10)
343
+ old_result_json = copy.deepcopy(result_json)
344
+ response = actor_agent.step('Too long, please shorten the bullet points by ' + str(percentage_to_shrink) + '%.')
345
+ t_in, t_out = account_token(response)
346
+ local_t_in += t_in
347
+ local_t_out += t_out
348
+ result_json = get_json_from_response(response.msgs[0].content)
349
+ else:
350
+ break
351
+
352
+ critic_prompt = critic_template.render()
353
+ bullet_contents = ['textbox1'] + (['textbox2'] if num_textboxes == 2 else [])
354
+
355
+ # Visual overflow/blank detection & correction
356
+ for j, text_arrangement in enumerate(target_textboxes[:num_textboxes]):
357
+ bullet_content = bullet_contents[j]
358
+ curr_round = 0
359
+ while True:
360
+ if args.ablation_no_commenter:
361
+ break
362
+ curr_round += 1
363
+ img = render_textbox(text_arrangement, result_json[bullet_content], local_tmp_dir)
364
+ if args.model_name_v.startswith('vllm_qwen') or args.ablation_no_example:
365
+ critic_msg = BaseMessage.make_user_message(
366
+ role_name="User",
367
+ content=critic_prompt,
368
+ image_list=[img],
369
+ )
370
+ else:
371
+ critic_msg = BaseMessage.make_user_message(
372
+ role_name="User",
373
+ content=critic_prompt,
374
+ image_list=[Image.open(neg_img_path), Image.open(pos_img_path), img],
375
+ )
376
+
377
+ critic_agent.reset()
378
+ response = critic_agent.step(critic_msg)
379
+ v_in, v_out = account_token(response)
380
+ local_v_in += v_in
381
+ local_v_out += v_out
382
+
383
+ decision = response.msgs[0].content.lower()
384
+ if decision in ['1', '1.', '"1"', "'1'"]:
385
+ if curr_round > 10:
386
+ print(f'Section {i}: Too many rounds of modification, breaking...')
387
+ break
388
+ if agent_modify:
389
+ print(f'Section {i}: Text overflow detected, modifying...')
390
+ modify_message = f'{bullet_content} is too long, please shorten that part, other content should stay the same. Return the entire modified JSON.'
391
+ response = actor_agent.step(modify_message)
392
+ t_in, t_out = account_token(response)
393
+ local_t_in += t_in
394
+ local_t_out += t_out
395
+ result_json = get_json_from_response(response.msgs[0].content)
396
+ else:
397
+ # naive truncate
398
+ result_json[bullet_content] = result_json[bullet_content][:-1]
399
+ continue
400
+ elif decision in ['2', '2.', '"2"', "'2'"]:
401
+ if args.no_blank_detection:
402
+ print(f'Section {i}: No blank space detection, skipping...')
403
+ break
404
+ if curr_round > 10:
405
+ print(f'Section {i}: Too many rounds of modification, breaking...')
406
+ break
407
+ print(f'Section {i}: Too much blank space detected, modifying...')
408
+ modify_message = f'{bullet_content} is too short, please add one more bullet point, other content should stay the same. Return the entire modified JSON.'
409
+ response = actor_agent.step(modify_message)
410
+ t_in, t_out = account_token(response)
411
+ local_t_in += t_in
412
+ local_t_out += t_out
413
+ result_json = get_json_from_response(response.msgs[0].content)
414
+ else:
415
+ break
416
+
417
+ # Clean up temp dir
418
+ if local_tmp_dir:
419
+ try:
420
+ print(f'Section {i}: Cleaning up temp dir {local_tmp_dir}')
421
+ shutil.rmtree(local_tmp_dir)
422
+ except Exception as e:
423
+ print(f"Error cleaning up temp dir {local_tmp_dir}: {e}")
424
+ return i, result_json, local_t_in, local_t_out, local_v_in, local_v_out
425
+
426
+ # ----------------------- Parallel execution -----------------------
427
+ max_workers = getattr(args, 'max_workers', 4)
428
+ results = {}
429
+ lock = threading.Lock()
430
+
431
+ with ThreadPoolExecutor(max_workers=max_workers) as ex:
432
+ futures = {
433
+ ex.submit(_process_section, i): i
434
+ for i in range(1, len(raw_content['sections']))
435
+ }
436
+ for fut in as_completed(futures):
437
+ i, rjson, t_in, t_out, v_in, v_out = fut.result()
438
+ with lock:
439
+ results[i] = rjson
440
+ total_input_token_t += t_in
441
+ total_output_token_t += t_out
442
+ total_input_token_v += v_in
443
+ total_output_token_v += v_out
444
+
445
+ # ----------------------- Title generation (sequential) -----------------------
446
+ title_json, title_input_token, title_output_token = gen_poster_title_content(args, actor_config)
447
+ total_input_token_t += title_input_token
448
+ total_output_token_t += title_output_token
449
+
450
+ # ----------------------- Assemble & save -----------------------
451
+ bullet_point_content = [title_json]
452
+ for idx in range(1, len(raw_content['sections'])):
453
+ bullet_point_content.append(results[idx])
454
+
455
+ json.dump(
456
+ bullet_point_content,
457
+ open(f'contents/<{args.model_name_t}_{args.model_name_v}>_{args.poster_name}_bullet_point_content_{args.index}.json', 'w'),
458
+ indent=2
459
+ )
460
+
461
+ return total_input_token_t, total_output_token_t, total_input_token_v, total_output_token_v
462
+
463
+ def gen_poster_content(args, actor_config):
464
+ total_input_token, total_output_token = 0, 0
465
+ raw_content = json.load(open(f'contents/{args.model_name}_{args.poster_name}_raw_content.json', 'r'))
466
+ agent_name = 'poster_content_agent'
467
+
468
+ with open(f"utils/prompt_templates/{agent_name}.yaml", "r") as f:
469
+ content_config = yaml.safe_load(f)
470
+
471
+ actor_model = ModelFactory.create(
472
+ model_platform=actor_config['model_platform'],
473
+ model_type=actor_config['model_type'],
474
+ model_config_dict=actor_config['model_config']
475
+ )
476
+
477
+ actor_sys_msg = content_config['system_prompt']
478
+
479
+ def create_actor_agent():
480
+ actor_agent = ChatAgent(
481
+ system_message=actor_sys_msg,
482
+ model=actor_model,
483
+ message_window_size=10
484
+ )
485
+ return actor_agent
486
+
487
+ outline = json.load(open(f'outlines/{args.model_name}_{args.poster_name}_outline_{args.index}.json', 'r'))
488
+ raw_outline = json.loads(json.dumps(outline))
489
+ outline_estimate_num_chars(outline)
490
+ outline = remove_hierarchy_and_id(outline)
491
+
492
+ sections = list(outline.keys())
493
+ sections = [s for s in sections if s != 'meta']
494
+
495
+ jinja_env = Environment(undefined=StrictUndefined)
496
+
497
+ template = jinja_env.from_string(content_config["template"])
498
+
499
+ poster_content = {}
500
+
501
+ poster_content, total_input_token, total_output_token = gen_content_parallel_process_sections(
502
+ sections,
503
+ outline,
504
+ raw_content,
505
+ raw_outline,
506
+ template,
507
+ create_actor_agent,
508
+ MAX_ATTEMPT=5
509
+ )
510
+
511
+ json.dump(poster_content, open(f'contents/{args.model_name}_{args.poster_name}_poster_content_{args.index}.json', 'w'), indent=2)
512
+ return total_input_token, total_output_token
513
+
514
+ if __name__ == '__main__':
515
+ parser = argparse.ArgumentParser()
516
+ parser.add_argument('--poster_name', type=str, default=None)
517
+ parser.add_argument('--model_name', type=str, default='4o')
518
+ parser.add_argument('--poster_path', type=str, required=True)
519
+ parser.add_argument('--index', type=int, default=0)
520
+ parser.add_argument('--max_retry', type=int, default=3)
521
+ args = parser.parse_args()
522
+
523
+ actor_config = get_agent_config(args.model_name)
524
+ if args.poster_name is None:
525
+ args.poster_name = args.poster_path.split('/')[-1].replace('.pdf', '').replace(' ', '_')
526
+
527
+ input_token, output_token = gen_poster_content(args, actor_config)
528
+
529
+ print(f'Token consumption: {input_token} -> {output_token}')
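The length-control loop inside _process_section (shorten the bullets until they fit the panel's character budget) is the part most worth understanding before tuning --estimate_chars. A minimal sketch of that loop, assuming ask_model wraps actor_agent.step plus get_json_from_response and measure returns the rendered character count — both are stand-ins, not the module's actual helpers:

def fit_to_budget(ask_model, measure, draft, budget, max_attempts=5):
    """Ask the model to shorten `draft` until `measure(draft)` fits within `budget` characters."""
    previous = draft
    for _ in range(max_attempts):
        length = measure(draft)
        if length <= budget:
            return draft
        # Same heuristic as in gen_bullet_point_content: shrink by the overflow
        # percentage plus a 10% margin, capped at 90%.
        shrink_pct = min(90, int((length - budget) / length * 100) + 10)
        previous = draft
        draft = ask_model(f"Too long, please shorten the bullet points by {shrink_pct}%.")
    return previous  # after max_attempts, fall back to the last draft that was measured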
Paper2Poster/PosterAgent/gen_pptx_code.py ADDED
@@ -0,0 +1,249 @@
1
+ import re
2
+ import json
3
+
4
+ def sanitize_for_var(name):
5
+ # Convert any character that is not alphanumeric or underscore into underscore.
6
+ return re.sub(r'[^0-9a-zA-Z_]+', '_', name)
7
+
8
+ def initialize_poster_code(width, height, slide_object_name, presentation_object_name, utils_functions):
9
+ code = utils_functions
10
+ code += fr'''
11
+ # Poster: {presentation_object_name}
12
+ {presentation_object_name} = create_poster(width_inch={width}, height_inch={height})
13
+
14
+ # Slide: {slide_object_name}
15
+ {slide_object_name} = add_blank_slide({presentation_object_name})
16
+ '''
17
+
18
+ return code
19
+
20
+ def save_poster_code(output_file, utils_functions, presentation_object_name):
21
+ code = utils_functions
22
+ code += fr'''
23
+ # Save the presentation
24
+ save_presentation({presentation_object_name}, file_name="{output_file}")
25
+ '''
26
+ return code
27
+
28
+ def generate_panel_code(panel_dict, utils_functions, slide_object_name, visible=False, theme=None):
29
+ code = utils_functions
30
+ raw_name = panel_dict["panel_name"]
31
+ var_name = 'var_' + sanitize_for_var(raw_name)
32
+
33
+ code += fr'''
34
+ # Panel: {raw_name}
35
+ {var_name} = add_textbox(
36
+ {slide_object_name},
37
+ '{var_name}',
38
+ {panel_dict['x']},
39
+ {panel_dict['y']},
40
+ {panel_dict['width']},
41
+ {panel_dict['height']},
42
+ text="",
43
+ word_wrap=True,
44
+ font_size=40,
45
+ bold=False,
46
+ italic=False,
47
+ alignment="left",
48
+ fill_color=None,
49
+ font_name="Arial"
50
+ )
51
+ '''
52
+
53
+ if visible:
54
+ if theme is None:
55
+ code += fr'''
56
+ # Make border visible
57
+ style_shape_border({var_name}, color=(0, 0, 0), thickness=5, line_style="solid")
58
+ '''
59
+ else:
60
+ code += fr'''
61
+ # Make border visible
62
+ style_shape_border({var_name}, color={theme['color']}, thickness={theme['thickness']}, line_style="{theme['line_style']}")
63
+ '''
64
+
65
+ return code
66
+
67
+ def generate_textbox_code(
68
+ text_dict,
69
+ utils_functions,
70
+ slide_object_name,
71
+ visible=False,
72
+ content=None,
73
+ theme=None,
74
+ tmp_dir='tmp',
75
+ is_title=False,
76
+ ):
77
+ code = utils_functions
78
+ raw_name = text_dict["textbox_name"]
79
+ var_name = sanitize_for_var(raw_name)
80
+
81
+ code += fr'''
82
+ # Textbox: {raw_name}
83
+ {var_name} = add_textbox(
84
+ {slide_object_name},
85
+ '{var_name}',
86
+ {text_dict['x']},
87
+ {text_dict['y']},
88
+ {text_dict['width']},
89
+ {text_dict['height']},
90
+ text="",
91
+ word_wrap=True,
92
+ font_size=40,
93
+ bold=False,
94
+ italic=False,
95
+ alignment="left",
96
+ fill_color=None,
97
+ font_name="Arial"
98
+ )
99
+ '''
100
+ if visible:
101
+ # Extract textbox_theme from full theme if needed
102
+ textbox_border_theme = None
103
+ if theme is not None and isinstance(theme, dict):
104
+ textbox_border_theme = theme.get('textbox_theme')
105
+
106
+ if textbox_border_theme is None:
107
+ code += fr'''
108
+ # Make border visible
109
+ style_shape_border({var_name}, color=(255, 0, 0), thickness=5, line_style="solid")
110
+ '''
111
+ else:
112
+ code += fr'''
113
+ # Make border visible
114
+ style_shape_border({var_name}, color={textbox_border_theme['color']}, thickness={textbox_border_theme['thickness']}, line_style="{textbox_border_theme['line_style']}")
115
+ '''
116
+
117
+ if content is not None:
118
+ tmp_name = f'{tmp_dir}/{var_name}_content.json'
119
+ json.dump(content, open(tmp_name, 'w'), indent=4)
120
+
121
+ # Determine vertical alignment
122
+ vertical_anchor = None
123
+ if is_title and theme is not None and 'section_title_vertical_align' in theme:
124
+ vertical_anchor = theme['section_title_vertical_align']
125
+
126
+ if vertical_anchor:
127
+ code += fr'''
128
+ fill_textframe({var_name}, json.load(open('{tmp_name}', 'r')), vertical_anchor="{vertical_anchor}")
129
+ '''
130
+ else:
131
+ code += fr'''
132
+ fill_textframe({var_name}, json.load(open('{tmp_name}', 'r')))
133
+ '''
134
+
135
+ return code
136
+
137
+ def generate_figure_code(figure_dict, utils_functions, slide_object_name, img_path, visible=False, theme=None):
138
+ code = utils_functions
139
+ raw_name = figure_dict["figure_name"]
140
+ var_name = sanitize_for_var(raw_name)
141
+
142
+ code += fr'''
143
+ # Figure: {raw_name}
144
+ {var_name} = add_image(
145
+ {slide_object_name},
146
+ '{var_name}',
147
+ {figure_dict['x']},
148
+ {figure_dict['y']},
149
+ {figure_dict['width']},
150
+ {figure_dict['height']},
151
+ image_path="{img_path}"
152
+ )
153
+ '''
154
+
155
+ if visible:
156
+ if theme is None:
157
+ code += fr'''
158
+ # Make border visible
159
+ style_shape_border({var_name}, color=(0, 0, 255), thickness=5, line_style="long_dash_dot")
160
+ '''
161
+ else:
162
+ code += fr'''
163
+ # Make border visible
164
+ style_shape_border({var_name}, color={theme['color']}, thickness={theme['thickness']}, line_style="{theme['line_style']}")
165
+ '''
166
+
167
+ return code
168
+
169
+ def generate_poster_code(
170
+ panel_arrangement_list,
171
+ text_arrangement_list,
172
+ figure_arrangement_list,
173
+ presentation_object_name,
174
+ slide_object_name,
175
+ utils_functions,
176
+ slide_width,
177
+ slide_height,
178
+ img_path,
179
+ save_path,
180
+ visible=False,
181
+ content=None,
182
+ check_overflow=False,
183
+ theme=None,
184
+ tmp_dir='tmp',
185
+ ):
186
+ code = ''
187
+ code += initialize_poster_code(slide_width, slide_height, slide_object_name, presentation_object_name, utils_functions)
188
+
189
+ if theme is None:
190
+ panel_visible = visible
191
+ textbox_visible = visible
192
+ figure_visible = visible
193
+
194
+ panel_theme, textbox_theme, figure_theme = None, None, None
195
+ else:
196
+ panel_visible = theme['panel_visible']
197
+ textbox_visible = theme['textbox_visible']
198
+ figure_visible = theme['figure_visible']
199
+ panel_theme = theme['panel_theme']
200
+ textbox_theme = theme['textbox_theme']
201
+ figure_theme = theme['figure_theme']
202
+
203
+ for p in panel_arrangement_list:
204
+ code += generate_panel_code(p, '', slide_object_name, panel_visible, panel_theme)
205
+
206
+ if check_overflow:
207
+ t = text_arrangement_list[0]
208
+ # Pass full theme for consistency
209
+ code += generate_textbox_code(t, '', slide_object_name, textbox_visible, content, theme, tmp_dir, is_title=False)
210
+ else:
211
+ all_content = []
212
+ title_indices = set() # Track which indices are section titles
213
+ if content is not None:
214
+ idx = 0
215
+ for section_content in content:
216
+ if 'title' in section_content:
217
+ all_content.append(section_content['title'])
218
+ title_indices.add(idx) # Mark this index as a title
219
+ idx += 1
220
+ if len(section_content) == 2:
221
+ all_content.append(section_content['textbox1'])
222
+ idx += 1
223
+ elif len(section_content) == 3:
224
+ all_content.append(section_content['textbox1'])
225
+ all_content.append(section_content['textbox2'])
226
+ idx += 2
227
+ else:
228
+ raise ValueError(f"Unexpected content length: {len(section_content)}")
229
+
230
+ for i in range(len(text_arrangement_list)):
231
+ t = text_arrangement_list[i]
232
+ if content is not None:
233
+ textbox_content = all_content[i]
234
+ is_title = i in title_indices
235
+ else:
236
+ textbox_content = None
237
+ is_title = False
238
+ # Pass full theme (not textbox_theme) so vertical alignment config is available
239
+ code += generate_textbox_code(t, '', slide_object_name, textbox_visible, textbox_content, theme, tmp_dir, is_title=is_title)
240
+
241
+ for f in figure_arrangement_list:
242
+ if img_path is None:
243
+ code += generate_figure_code(f, '', slide_object_name, f['figure_path'], figure_visible, figure_theme)
244
+ else:
245
+ code += generate_figure_code(f, '', slide_object_name, img_path, figure_visible, figure_theme)
246
+
247
+ code += save_poster_code(save_path, '', presentation_object_name)
248
+
249
+ return code
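generate_poster_code does not build the .pptx directly; it emits a Python script (as a string) that calls helpers such as create_poster, add_textbox and save_presentation, and the pipeline then executes that script. A hypothetical usage sketch — utils_functions and run_code come from utils.wei_utils as imported in new_pipeline.py, and the coordinate values are placeholder inches:

from utils.wei_utils import utils_functions, run_code
from PosterAgent.gen_pptx_code import generate_poster_code

panel = {"panel_name": "Introduction", "x": 1, "y": 1, "width": 20, "height": 10}
textbox = {"textbox_name": "Introduction_text", "x": 1.5, "y": 1.5, "width": 19, "height": 9}

code = generate_poster_code(
    [panel], [textbox], [],                        # panels, textboxes, no figures
    presentation_object_name="poster_presentation",
    slide_object_name="poster_slide",
    utils_functions=utils_functions,               # helper source prepended to the generated script
    slide_width=48, slide_height=36,
    img_path=None,
    save_path="tmp/poster.pptx",
    visible=True,                                  # draw borders so the layout is easy to inspect
)
output, err = run_code(code)                       # executing the script writes tmp/poster.pptx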
Paper2Poster/PosterAgent/new_pipeline.py ADDED
@@ -0,0 +1,547 @@
1
+ import os
2
+ print("Initializing...")
3
+ from PosterAgent.parse_raw import parse_raw, gen_image_and_table
4
+ from PosterAgent.gen_outline_layout import filter_image_table, gen_outline_layout_v2
5
+ from utils.wei_utils import get_agent_config, utils_functions, run_code, scale_to_target_area, char_capacity
6
+ from PosterAgent.tree_split_layout import main_train, main_inference, get_arrangments_in_inches, split_textbox, to_inches
7
+ # from PosterAgent.gen_pptx_code import generate_poster_code
8
+ # from utils.src.utils import ppt_to_images
9
+ # from PosterAgent.gen_poster_content import gen_bullet_point_content
10
+ from utils.ablation_utils import no_tree_get_layout
11
+
12
+ # Import refactored utilities
13
+ from utils.logo_utils import LogoManager, add_logos_to_poster_code
14
+ # from utils.config_utils import (
15
+ # load_poster_yaml_config, extract_font_sizes, extract_colors,
16
+ # extract_vertical_alignment, extract_section_title_symbol, normalize_config_values
17
+ # )
18
+ # from utils.style_utils import apply_all_styles
19
+ # from utils.theme_utils import get_default_theme, create_theme_with_alignment, resolve_colors
20
+
21
+ # from PosterAgent.gen_beamer_code import (
22
+ # generate_beamer_poster_code,
23
+ # save_beamer_code,
24
+ # compile_beamer_to_pdf,
25
+ # convert_pptx_layout_to_beamer
26
+ # )
27
+
28
+ import argparse
29
+ import json
30
+
31
+ import time
32
+ import shutil
33
+
34
+ units_per_inch = 25
35
+
36
+ if __name__ == '__main__':
37
+
38
+ parser = argparse.ArgumentParser(description='Poster Generation Pipeline with Logo Support')
39
+ parser.add_argument('--poster_path', type=str)
40
+ parser.add_argument('--model_name_t', type=str, default='4o')
41
+ parser.add_argument('--model_name_v', type=str, default='4o')
42
+ parser.add_argument('--index', type=int, default=0)
43
+ parser.add_argument('--poster_name', type=str, default=None)
44
+ parser.add_argument('--tmp_dir', type=str, default='tmp')
45
+ parser.add_argument('--estimate_chars', action='store_true')
46
+ parser.add_argument('--max_workers', type=int, default=10)
47
+ parser.add_argument('--poster_width_inches', type=int, default=None)
48
+ parser.add_argument('--poster_height_inches', type=int, default=None)
49
+ parser.add_argument('--no_blank_detection', action='store_true', help='When overflow is severe, try this option.')
50
+ parser.add_argument('--ablation_no_tree_layout', action='store_true', help='Ablation study: no tree layout')
51
+ parser.add_argument('--ablation_no_commenter', action='store_true', help='Ablation study: no commenter')
52
+ parser.add_argument('--ablation_no_example', action='store_true', help='Ablation study: no example')
53
+
54
+ # Logo-related arguments
55
+ parser.add_argument('--conference_venue', type=str, default=None,
56
+ help='Conference name for automatic logo search (e.g., "NeurIPS", "CVPR")')
57
+ parser.add_argument('--institution_logo_path', type=str, default=None,
58
+ help='Custom path to institution logo (auto-searches from paper metadata if not provided)')
59
+ parser.add_argument('--conference_logo_path', type=str, default=None,
60
+ help='Custom path to conference logo (auto-searches if venue specified)')
61
+ parser.add_argument('--use_google_search', action='store_true',
62
+ help='Use Google Custom Search API for logo search (requires API keys in .env)')
63
+
64
+ args = parser.parse_args()
65
+
66
+ start_time = time.time()
67
+
68
+ os.makedirs(args.tmp_dir, exist_ok=True)
69
+
70
+ detail_log = {}
71
+
72
+ agent_config_t = get_agent_config(args.model_name_t)
73
+ agent_config_v = get_agent_config(args.model_name_v)
74
+ poster_name = args.poster_path.split('/')[-2].replace(' ', '_')
75
+ if args.poster_name is None:
76
+ args.poster_name = poster_name
77
+ else:
78
+ poster_name = args.poster_name
79
+ meta_json_path = args.poster_path.replace('paper.pdf', 'meta.json')
80
+ if args.poster_width_inches is not None and args.poster_height_inches is not None:
81
+ poster_width = args.poster_width_inches * units_per_inch
82
+ poster_height = args.poster_height_inches * units_per_inch
83
+ elif os.path.exists(meta_json_path):
84
+ meta_json = json.load(open(meta_json_path, 'r'))
85
+ poster_width = meta_json['width']
86
+ poster_height = meta_json['height']
87
+ else:
88
+ poster_width = 48 * units_per_inch
89
+ poster_height = 36 * units_per_inch
90
+
91
+ poster_width, poster_height = scale_to_target_area(poster_width, poster_height)
92
+ poster_width_inches = to_inches(poster_width, units_per_inch)
93
+ poster_height_inches = to_inches(poster_height, units_per_inch)
94
+
95
+ if poster_width_inches > 56 or poster_height_inches > 56:
96
+ # Work out which side is longer, then compute a single scale factor
97
+ if poster_width_inches >= poster_height_inches:
98
+ scale_factor = 56 / poster_width_inches
99
+ else:
100
+ scale_factor = 56 / poster_height_inches
101
+
102
+ poster_width_inches *= scale_factor
103
+ poster_height_inches *= scale_factor
104
+
105
+ # convert back to internal units
106
+ poster_width = poster_width_inches * units_per_inch
107
+ poster_height = poster_height_inches * units_per_inch
108
+
109
+ print(f'Poster size: {poster_width_inches} x {poster_height_inches} inches')
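For example (setting aside the earlier scale_to_target_area step), a requested 72 x 48 inch poster has width as its longer side, so scale_factor = 56/72 ≈ 0.78; the poster becomes 56 x ~37.3 inches, which converts back to 1400 x ~933 internal units at 25 units per inch.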
110
+
111
+ total_input_tokens_t, total_output_tokens_t = 0, 0
112
+ total_input_tokens_v, total_output_tokens_v = 0, 0
113
+
114
+ # Step 1: Parse the raw poster
115
+ input_token, output_token, raw_result = parse_raw(args, agent_config_t, version=2)
116
+ total_input_tokens_t += input_token
117
+ total_output_tokens_t += output_token
118
+
119
+ _, _, images, tables = gen_image_and_table(args, raw_result)
120
+
121
+ print(f'Parsing token consumption: {input_token} -> {output_token}')
122
+
123
+ parser_time_taken = time.time() - start_time
124
+ print(f'Parser time: {parser_time_taken:.2f} seconds')
125
+ detail_log['parser_time'] = parser_time_taken
126
+
127
+ parser_time = time.time()
128
+
129
+ detail_log['parser_in_t'] = input_token
130
+ detail_log['parser_out_t'] = output_token
131
+
132
+ # Initialize LogoManager
133
+ logo_manager = LogoManager()
134
+ institution_logo_path = args.institution_logo_path
135
+ conference_logo_path = args.conference_logo_path
136
+
137
+ # Auto-detect institution from paper if not provided
138
+ # Now using the raw_result directly instead of reading from file
139
+ if not institution_logo_path:
140
+ print("\n" + "="*60)
141
+ print("🔍 AUTO-DETECTING INSTITUTION FROM PAPER")
142
+ print("="*60)
143
+
144
+ # Use the raw_result we already have from the parser
145
+ if raw_result:
146
+ print(f"📄 Using parsed paper content")
147
+ # Extract text content from the ConversionResult object
148
+ try:
149
+ paper_text = raw_result.document.export_to_markdown()
150
+ except Exception:
151
+ # Fallback: try to get text content in another way
152
+ paper_text = str(raw_result)
153
+
154
+ print("🔎 Searching for FIRST AUTHOR's institution...")
155
+ first_author_inst = logo_manager.extract_first_author_institution(paper_text)
156
+
157
+ if first_author_inst:
158
+ print(f"\n✅ FIRST AUTHOR INSTITUTION: {first_author_inst}")
159
+ print(f"🔍 Searching for logo: {first_author_inst}")
160
+
161
+ inst_logo_path = logo_manager.get_logo_path(first_author_inst, category="institute", use_google=args.use_google_search)
162
+ if inst_logo_path:
163
+ institution_logo_path = str(inst_logo_path)
164
+ print(f"✅ Institution logo found: {institution_logo_path}")
165
+ else:
166
+ print(f"❌ Could not find/download logo for: {first_author_inst}")
167
+ else:
168
+ print("❌ No first author institution detected or matched with available logos")
169
+ else:
170
+ print("❌ No parsed content available")
171
+ print("="*60 + "\n")
172
+
173
+ # Handle conference logo
174
+ if args.conference_venue and not conference_logo_path:
175
+ print("\n" + "="*60)
176
+ print("🏛️ SEARCHING FOR CONFERENCE LOGO")
177
+ print("="*60)
178
+ print(f"📍 Conference: {args.conference_venue}")
179
+ print(f"🔍 Searching for logo...")
180
+
181
+ conf_logo_path = logo_manager.get_logo_path(args.conference_venue, category="conference", use_google=args.use_google_search)
182
+ if conf_logo_path:
183
+ conference_logo_path = str(conf_logo_path)
184
+ print(f"✅ Conference logo found: {conference_logo_path}")
185
+ else:
186
+ print(f"❌ Could not find/download logo for: {args.conference_venue}")
187
+ # Note: Web search is now handled inside get_logo_path automatically
188
+ print("="*60 + "\n")
189
+
190
+ # Step 2: Filter unnecessary images and tables
191
+ input_token, output_token = filter_image_table(args, agent_config_t)
192
+ total_input_tokens_t += input_token
193
+ total_output_tokens_t += output_token
194
+ print(f'Filter figures token consumption: {input_token} -> {output_token}')
195
+
196
+ filter_time_taken = time.time() - parser_time
197
+ print(f'Filter time: {filter_time_taken:.2f} seconds')
198
+ detail_log['filter_time'] = filter_time_taken
199
+
200
+ filter_time = time.time()
201
+
202
+ detail_log['filter_in_t'] = input_token
203
+ detail_log['filter_out_t'] = output_token
204
+
205
+ # Step 3: Generate outline
206
+ input_token, output_token, panels, figures = gen_outline_layout_v2(args, agent_config_t)
207
+ total_input_tokens_t += input_token
208
+ total_output_tokens_t += output_token
209
+ print(f'Outline token consumption: {input_token} -> {output_token}')
210
+
211
+ outline_time_taken = time.time() - filter_time
212
+ print(f'Outline time: {outline_time_taken:.2f} seconds')
213
+ detail_log['outline_time'] = outline_time_taken
214
+
215
+ outline_time = time.time()
216
+
217
+ detail_log['outline_in_t'] = input_token
218
+ detail_log['outline_out_t'] = output_token
219
+
220
+ if args.ablation_no_tree_layout:
221
+ panel_arrangement, figure_arrangement, text_arrangement, input_token, output_token = no_tree_get_layout(
222
+ poster_width,
223
+ poster_height,
224
+ panels,
225
+ figures,
226
+ agent_config_t
227
+ )
228
+ total_input_tokens_t += input_token
229
+ total_output_tokens_t += output_token
230
+ print(f'No tree layout token consumption: {input_token} -> {output_token}')
231
+ detail_log['no_tree_layout_in_t'] = input_token
232
+ detail_log['no_tree_layout_out_t'] = output_token
233
+ else:
234
+
235
+ # Step 4: Learn and generate layout
236
+ panel_model_params, figure_model_params = main_train()
237
+
238
+ panel_arrangement, figure_arrangement, text_arrangement = main_inference(
239
+ panels,
240
+ panel_model_params,
241
+ figure_model_params,
242
+ poster_width,
243
+ poster_height,
244
+ shrink_margin=3
245
+ )
246
+
247
+ text_arrangement_title = text_arrangement[0]
248
+ text_arrangement = text_arrangement[1:]
249
+ # Split the title textbox into two parts
250
+ text_arrangement_title_top, text_arrangement_title_bottom = split_textbox(
251
+ text_arrangement_title,
252
+ 0.8
253
+ )
254
+ # Add the split textboxes back to the list
255
+ text_arrangement = [text_arrangement_title_top, text_arrangement_title_bottom] + text_arrangement
256
+
257
+ for i in range(len(figure_arrangement)):
258
+ panel_id = figure_arrangement[i]['panel_id']
259
+ panel_section_name = panels[panel_id]['section_name']
260
+ figure_info = figures[panel_section_name]
261
+ if 'image' in figure_info:
262
+ figure_id = figure_info['image']
263
+ if figure_id not in images:
264
+ figure_path = images[str(figure_id)]['image_path']
265
+ else:
266
+ figure_path = images[figure_id]['image_path']
267
+ elif 'table' in figure_info:
268
+ figure_id = figure_info['table']
269
+ if figure_id not in tables:
270
+ figure_path = tables[str(figure_id)]['table_path']
271
+ else:
272
+ figure_path = tables[figure_id]['table_path']
273
+
274
+ figure_arrangement[i]['figure_path'] = figure_path
275
+
276
+ for text_arrangement_item in text_arrangement:
277
+ num_chars = char_capacity(
278
+ bbox=(text_arrangement_item['x'], text_arrangement_item['y'], text_arrangement_item['height'], text_arrangement_item['width'])
279
+ )
280
+ text_arrangement_item['num_chars'] = num_chars
281
+
282
+
283
+ width_inch, height_inch, panel_arrangement_inches, figure_arrangement_inches, text_arrangement_inches = get_arrangments_in_inches(
284
+ poster_width, poster_height, panel_arrangement, figure_arrangement, text_arrangement, 25
285
+ )
286
+
287
+ # Save to file
288
+ tree_split_results = {
289
+ 'poster_width': poster_width,
290
+ 'poster_height': poster_height,
291
+ 'poster_width_inches': width_inch,
292
+ 'poster_height_inches': height_inch,
293
+ 'panels': panels,
294
+ 'panel_arrangement': panel_arrangement,
295
+ 'figure_arrangement': figure_arrangement,
296
+ 'text_arrangement': text_arrangement,
297
+ 'panel_arrangement_inches': panel_arrangement_inches,
298
+ 'figure_arrangement_inches': figure_arrangement_inches,
299
+ 'text_arrangement_inches': text_arrangement_inches,
300
+ }
301
+ os.makedirs('tree_splits', exist_ok=True)
302
+ with open(f'tree_splits/<{args.model_name_t}_{args.model_name_v}>_{args.poster_name}_tree_split_{args.index}.json', 'w') as f:
303
+ json.dump(tree_split_results, f, indent=4)
304
+
305
+ layout_time_taken = time.time() - outline_time
306
+ print(f'Layout time: {layout_time_taken:.2f} seconds')
307
+ detail_log['layout_time'] = layout_time_taken
308
+
309
+ layout_time = time.time()
310
+
311
+ # # === Configuration Loading ===
312
+ # print("\n📋 Loading configuration from YAML files...", flush=True)
313
+ # yaml_cfg = load_poster_yaml_config(args.poster_path)
314
+
315
+ # # Extract configuration values
316
+ # bullet_fs, title_fs, poster_title_fs, poster_author_fs = extract_font_sizes(yaml_cfg)
317
+ # title_text_color, title_fill_color, main_text_color, main_text_fill_color = extract_colors(yaml_cfg)
318
+ # section_title_vertical_align = extract_vertical_alignment(yaml_cfg)
319
+ # section_title_symbol = extract_section_title_symbol(yaml_cfg)
320
+
321
+ # # Normalize configuration values
322
+ # bullet_fs, title_fs, poster_title_fs, poster_author_fs, \
323
+ # title_text_color, title_fill_color, main_text_color, main_text_fill_color = normalize_config_values(
324
+ # bullet_fs, title_fs, poster_title_fs, poster_author_fs,
325
+ # title_text_color, title_fill_color, main_text_color, main_text_fill_color
326
+ # )
327
+
328
+ # # Store configuration in args
329
+ # setattr(args, 'bullet_font_size', bullet_fs)
330
+ # setattr(args, 'section_title_font_size', title_fs)
331
+ # setattr(args, 'poster_title_font_size', poster_title_fs)
332
+ # setattr(args, 'poster_author_font_size', poster_author_fs)
333
+ # setattr(args, 'title_text_color', title_text_color)
334
+ # setattr(args, 'title_fill_color', title_fill_color)
335
+ # setattr(args, 'main_text_color', main_text_color)
336
+ # setattr(args, 'main_text_fill_color', main_text_fill_color)
337
+ # setattr(args, 'section_title_vertical_align', section_title_vertical_align)
338
+
339
+ # # Step 5: Generate content
340
+ # print(f"\n✍️ Generating poster content (max_workers={args.max_workers})...", flush=True)
341
+ # # --- Step 1: Check for cached content ---
342
+ # content_cache_path = f'contents/<{args.model_name_t}_{args.model_name_v}>_{args.poster_name}_bullet_point_content_{args.index}.json'
343
+
344
+ # if os.path.exists(content_cache_path):
345
+ # print(f"🧩 Cache found: {content_cache_path}")
346
+ # print("⚡ Skipping model generation, loading from cache...")
347
+ # bullet_content = json.load(open(content_cache_path, 'r'))
348
+ # input_token_t = output_token_t = input_token_v = output_token_v = 0
349
+ # else:
350
+ # print("🧠 Running model to generate poster content...")
351
+ # input_token_t, output_token_t, input_token_v, output_token_v = gen_bullet_point_content(
352
+ # args, agent_config_t, agent_config_v, tmp_dir=args.tmp_dir
353
+ # )
354
+ # bullet_content = json.load(open(content_cache_path, 'r'))
355
+
356
+ # input_token_t, output_token_t, input_token_v, output_token_v = gen_bullet_point_content(args, agent_config_t, agent_config_v, tmp_dir=args.tmp_dir)
357
+ # total_input_tokens_t += input_token
358
+ # total_output_tokens_t += output_token
359
+ # total_input_tokens_v += input_token_v
360
+ # total_output_tokens_v += output_token_v
361
+ # print(f'Content generation token consumption T: {input_token_t} -> {output_token_t}')
362
+ # print(f'Content generation token consumption V: {input_token_v} -> {output_token_v}')
363
+
364
+ # content_time_taken = time.time() - layout_time
365
+ # print(f'Content generation time: {content_time_taken:.2f} seconds')
366
+ # detail_log['content_time'] = content_time_taken
367
+
368
+ # content_time = time.time()
369
+
370
+ # bullet_content = json.load(open(f'contents/<{args.model_name_t}_{args.model_name_v}>_{args.poster_name}_bullet_point_content_{args.index}.json', 'r'))
371
+
372
+ # detail_log['content_in_t'] = input_token_t
373
+ # detail_log['content_out_t'] = output_token_t
374
+ # detail_log['content_in_v'] = input_token_v
375
+ # detail_log['content_out_v'] = output_token_v
376
+
377
+ # # === Style Application ===
378
+ # print("\n🎨 Applying styles and colors...", flush=True)
379
+
380
+ # # Resolve colors with fallbacks
381
+ # final_title_text_color, final_title_fill_color, final_main_text_color, final_main_text_fill_color = resolve_colors(
382
+ # getattr(args, 'title_text_color', None),
383
+ # getattr(args, 'title_fill_color', None),
384
+ # getattr(args, 'main_text_color', None),
385
+ # getattr(args, 'main_text_fill_color', None)
386
+ # )
387
+
388
+ # # Apply all styles in one go
389
+ # bullet_content = apply_all_styles(
390
+ # bullet_content,
391
+ # title_text_color=final_title_text_color,
392
+ # title_fill_color=final_title_fill_color,
393
+ # main_text_color=final_main_text_color,
394
+ # main_text_fill_color=final_main_text_fill_color,
395
+ # section_title_symbol=section_title_symbol,
396
+ # main_text_font_size=bullet_fs
397
+ # )
398
+
399
+ # # === Poster Generation ===
400
+ # # print("\n🎯 Generating PowerPoint code...", flush=True)
401
+
402
+ # # Create theme with alignment
403
+ # base_theme = get_default_theme()
404
+ # theme_with_alignment = create_theme_with_alignment(
405
+ # base_theme,
406
+ # getattr(args, 'section_title_vertical_align', None)
407
+ # )
408
+
409
+ # # poster_code = generate_poster_code(
410
+ # # panel_arrangement_inches,
411
+ # # text_arrangement_inches,
412
+ # # figure_arrangement_inches,
413
+ # # presentation_object_name='poster_presentation',
414
+ # # slide_object_name='poster_slide',
415
+ # # utils_functions=utils_functions,
416
+ # # slide_width=width_inch,
417
+ # # slide_height=height_inch,
418
+ # # img_path=None,
419
+ # # save_path=f'{args.tmp_dir}/poster.pptx',
420
+ # # visible=False,
421
+ # # content=bullet_content,
422
+ # # theme=theme_with_alignment,
423
+ # # tmp_dir=args.tmp_dir,
424
+ # # )
425
+ # print("\n🎯 Generating Beamer poster (LaTeX)...", flush=True)
426
+
427
+ # # --- 1. Extract poster_info ---
428
+ # poster_info = {
429
+ # "title": args.poster_name,
430
+ # "author": "AutoGen",
431
+ # "institute": "Auto-detected Institution"
432
+ # }
433
+ # if isinstance(bullet_content, list) and len(bullet_content) > 0:
434
+ # first_section = bullet_content[0]
435
+ # if isinstance(first_section, dict):
436
+ # if "poster_title" in first_section:
437
+ # poster_info["title"] = first_section["poster_title"]
438
+ # elif "title" in first_section:
439
+ # poster_info["title"] = first_section["title"]
440
+
441
+ # --- 2. Build the Beamer data structures ---
442
+ # layout_data = {
443
+ # "text_arrangement": text_arrangement,
444
+ # "figure_arrangement": figure_arrangement
445
+ # }
446
+ # beamer_data = convert_pptx_layout_to_beamer(layout_data)
447
+
448
+ # Map bullet_content into the Beamer sections
449
+ # for i, section in enumerate(beamer_data["sections"]):
450
+ # if i < len(bullet_content):
451
+ # section_data = bullet_content[i]
452
+ # if isinstance(section_data, dict):
453
+ # section["content"] = section_data.get("textbox1") or section_data.get("title") or json.dumps(section_data)
454
+ # else:
455
+ # section["content"] = str(section_data)
456
+
457
+ # --- 3. Generate the LaTeX file ---
458
+ # poster_info = {k: (str(v) if not isinstance(v, str) else v) for k, v in poster_info.items()}
459
+
460
+ # beamer_code = generate_beamer_poster_code(
461
+ # beamer_data["sections"],
462
+ # beamer_data["figures"],
463
+ # poster_info,
464
+ # width_cm=poster_width_inches * 2.54,
465
+ # height_cm=poster_height_inches * 2.54,
466
+ # theme="Madrid",
467
+ # output_path=f"{args.tmp_dir}/{poster_name}.tex"
468
+ # )
469
+ # save_beamer_code(beamer_code, f"{args.tmp_dir}/{poster_name}.tex")
470
+
471
+
472
+ # --- 4. Compile to PDF ---
473
+ # output_dir = f'<{args.model_name_t}_{args.model_name_v}>_generated_beamer_posters/{args.poster_path.replace("paper.pdf", "")}'
474
+ # compile_beamer_to_pdf(f"{args.tmp_dir}/{poster_name}.tex", output_dir=args.tmp_dir)
475
+ # pdf_path = os.path.join(args.tmp_dir, f"{poster_name}.pdf")
476
+ # os.makedirs(output_dir, exist_ok=True)
477
+ # os.rename(pdf_path, os.path.join(output_dir, f"{poster_name}.pdf"))
478
+
479
+ # print(f"✅ Beamer poster PDF saved to {output_dir}")
480
+ # Add logos to the poster
481
+ # print("\n🖼️ Adding logos to poster...", flush=True)
482
+ # poster_code = add_logos_to_poster_code(
483
+ # poster_code,
484
+ # width_inch,
485
+ # height_inch,
486
+ # institution_logo_path=institution_logo_path,
487
+ # conference_logo_path=conference_logo_path
488
+ # )
489
+
490
+ # output, err = run_code(poster_code)
491
+ # if err is not None:
492
+ # raise RuntimeError(f'Error in generating PowerPoint: {err}')
493
+
494
+ # # Step 8: Create a folder in the output directory
495
+ # output_dir = f'<{args.model_name_t}_{args.model_name_v}>_generated_posters/{args.poster_path.replace("paper.pdf", "")}'
496
+ # os.makedirs(output_dir, exist_ok=True)
497
+
498
+ # # Copy logos to output directory for reference
499
+ # logos_dir = os.path.join(output_dir, 'logos')
500
+ # if institution_logo_path or conference_logo_path:
501
+ # os.makedirs(logos_dir, exist_ok=True)
502
+ # if institution_logo_path and os.path.exists(institution_logo_path):
503
+ # shutil.copy2(institution_logo_path, os.path.join(logos_dir, 'institution_logo' + os.path.splitext(institution_logo_path)[1]))
504
+ # if conference_logo_path and os.path.exists(conference_logo_path):
505
+ # shutil.copy2(conference_logo_path, os.path.join(logos_dir, 'conference_logo' + os.path.splitext(conference_logo_path)[1]))
506
+
507
+ # # Step 9: Move poster.pptx to the output directory
508
+ # pptx_path = os.path.join(output_dir, f'{poster_name}.pptx')
509
+ # os.rename(f'{args.tmp_dir}/poster.pptx', pptx_path)
510
+ # print(f'Poster PowerPoint saved to {pptx_path}')
511
+ # # Step 10: Convert the PowerPoint to images
512
+ # ppt_to_images(pptx_path, output_dir)
513
+ # print(f'Poster images saved to {output_dir}')
514
+
515
+ # end_time = time.time()
516
+ # time_taken = end_time - start_time
517
+
518
+ # render_time_taken = time.time() - content_time
519
+ # print(f'Render time: {render_time_taken:.2f} seconds')
520
+ # detail_log['render_time'] = render_time_taken
521
+
522
+ # # log
523
+ # log_file = os.path.join(output_dir, 'log.json')
524
+ # with open(log_file, 'w') as f:
525
+ # log_data = {
526
+ # 'input_tokens_t': total_input_tokens_t,
527
+ # 'output_tokens_t': total_output_tokens_t,
528
+ # 'input_tokens_v': total_input_tokens_v,
529
+ # 'output_tokens_v': total_output_tokens_v,
530
+ # 'time_taken': time_taken,
531
+ # 'institution_logo': institution_logo_path,
532
+ # 'conference_logo': conference_logo_path,
533
+ # }
534
+ # json.dump(log_data, f, indent=4)
535
+
536
+ # detail_log_file = os.path.join(output_dir, 'detail_log.json')
537
+ # with open(detail_log_file, 'w') as f:
538
+ # json.dump(detail_log, f, indent=4)
539
+
540
+ # print(f'\nTotal time: {time_taken:.2f} seconds')
541
+ # print(f'Total text model tokens: {total_input_tokens_t} -> {total_output_tokens_t}')
542
+ # print(f'Total vision model tokens: {total_input_tokens_v} -> {total_output_tokens_v}')
543
+
544
+ # if institution_logo_path:
545
+ # print(f'Institution logo added: {institution_logo_path}')
546
+ # if conference_logo_path:
547
+ # print(f'Conference logo added: {conference_logo_path}')
Paper2Poster/PosterAgent/parse_raw.py ADDED
@@ -0,0 +1,237 @@
1
+ from dotenv import load_dotenv
2
+ from utils.src.utils import get_json_from_response
3
+ from utils.src.model_utils import parse_pdf
4
+ import json
5
+ import random
6
+
7
+ from camel.models import ModelFactory
8
+ from camel.agents import ChatAgent
9
+ from tenacity import retry, stop_after_attempt
10
+ from docling_core.types.doc import ImageRefMode, PictureItem, TableItem
11
+
12
+ from docling.datamodel.base_models import InputFormat
13
+ from docling.datamodel.pipeline_options import PdfPipelineOptions
14
+ from docling.document_converter import DocumentConverter, PdfFormatOption
15
+
16
+ from pathlib import Path
17
+
18
+ import PIL
19
+
20
+ from marker.models import create_model_dict
21
+
22
+ from utils.wei_utils import *
23
+
24
+ from utils.pptx_utils import *
25
+ from utils.critic_utils import *
26
+ import torch
27
+ from jinja2 import Template
28
+ import re
29
+ import argparse
30
+
31
+ load_dotenv()
32
+ IMAGE_RESOLUTION_SCALE = 5.0
33
+
34
+ pipeline_options = PdfPipelineOptions()
35
+ pipeline_options.images_scale = IMAGE_RESOLUTION_SCALE
36
+ pipeline_options.generate_page_images = True
37
+ pipeline_options.generate_picture_images = True
38
+
39
+ doc_converter = DocumentConverter(
40
+ format_options={
41
+ InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options)
42
+ }
43
+ )
44
+
45
+ @retry(stop=stop_after_attempt(5))
46
+ def parse_raw(args, actor_config, version=2):
47
+ raw_source = args.poster_path
48
+ markdown_clean_pattern = re.compile(r"<!--[\s\S]*?-->")
49
+
50
+ raw_result = doc_converter.convert(raw_source)
51
+
52
+ raw_markdown = raw_result.document.export_to_markdown()
53
+ text_content = markdown_clean_pattern.sub("", raw_markdown)
54
+
55
+ if len(text_content) < 500:
56
+ print('\nParsing with docling failed, using marker instead\n')
57
+ parser_model = create_model_dict(device='cuda', dtype=torch.float16)
58
+ text_content, rendered = parse_pdf(raw_source, model_lst=parser_model, save_file=False)
59
+
60
+ if version == 1:
61
+ template = Template(open("utils/prompts/gen_poster_raw_content.txt").read())
62
+ elif version == 2:
63
+ print('Using v2 prompt template')
64
+ template = Template(open("utils/prompts/gen_poster_raw_content_v2.txt").read())
65
+
66
+ if args.model_name_t.startswith('vllm_qwen'):
67
+ actor_model = ModelFactory.create(
68
+ model_platform=actor_config['model_platform'],
69
+ model_type=actor_config['model_type'],
70
+ model_config_dict=actor_config['model_config'],
71
+ url=actor_config['url'],
72
+ )
73
+ else:
74
+ actor_model = ModelFactory.create(
75
+ model_platform=actor_config['model_platform'],
76
+ model_type=actor_config['model_type'],
77
+ model_config_dict=actor_config['model_config'],
78
+ )
79
+
80
+ actor_sys_msg = 'You are the author of the paper, and you will create a poster for the paper.'
81
+
82
+ actor_agent = ChatAgent(
83
+ system_message=actor_sys_msg,
84
+ model=actor_model,
85
+ message_window_size=10,
86
+ token_limit=actor_config.get('token_limit', None)
87
+ )
88
+
89
+ while True:
90
+ prompt = template.render(
91
+ markdown_document=text_content,
92
+ )
93
+ actor_agent.reset()
94
+ response = actor_agent.step(prompt)
95
+ input_token, output_token = account_token(response)
96
+
97
+ content_json = get_json_from_response(response.msgs[0].content)
98
+
99
+ if len(content_json) > 0:
100
+ break
101
+ print('Error: Empty response, retrying...')
102
+ if args.model_name_t.startswith('vllm_qwen'):
103
+ text_content = text_content[:80000]
104
+
105
+ if len(content_json['sections']) > 9:
106
+ # First 2 sections + randomly select 5 sections + last 2 sections
107
+ selected_sections = content_json['sections'][:2] + random.sample(content_json['sections'][2:-2], 5) + content_json['sections'][-2:]
108
+ content_json['sections'] = selected_sections
109
+
110
+ has_title = False
111
+
112
+ for section in content_json['sections']:
113
+ if not isinstance(section, dict) or 'title' not in section or 'content' not in section:
114
+ print(f"Ouch! The response is invalid, the LLM is not following the format :(")
115
+ print('Trying again...')
116
+ raise ValueError('Invalid section format in LLM response')
117
+ if 'title' in section['title'].lower():
118
+ has_title = True
119
+
120
+ if not has_title:
121
+ print('Ouch! The response is invalid, the LLM is not following the format :(')
122
+ raise ValueError('LLM response does not contain a title section')
123
+
124
+ os.makedirs('contents', exist_ok=True)
125
+ json.dump(content_json, open(f'contents/<{args.model_name_t}_{args.model_name_v}>_{args.poster_name}_raw_content.json', 'w'), indent=4)
126
+ return input_token, output_token, raw_result
127
+
128
+
129
+ def gen_image_and_table(args, conv_res):
130
+ input_token, output_token = 0, 0
131
+ raw_source = args.poster_path
132
+
133
+ output_dir = Path(f'<{args.model_name_t}_{args.model_name_v}>_images_and_tables/{args.poster_name}')
134
+
135
+ output_dir.mkdir(parents=True, exist_ok=True)
136
+ doc_filename = args.poster_name
137
+
138
+ # Save page images
139
+ for page_no, page in conv_res.document.pages.items():
140
+ page_no = page.page_no
141
+ page_image_filename = output_dir / f"{doc_filename}-{page_no}.png"
142
+ with page_image_filename.open("wb") as fp:
143
+ page.image.pil_image.save(fp, format="PNG")
144
+
145
+ # Save images of figures and tables
146
+ table_counter = 0
147
+ picture_counter = 0
148
+ for element, _level in conv_res.document.iterate_items():
149
+ if isinstance(element, TableItem):
150
+ table_counter += 1
151
+ element_image_filename = (
152
+ output_dir / f"{doc_filename}-table-{table_counter}.png"
153
+ )
154
+ with element_image_filename.open("wb") as fp:
155
+ element.get_image(conv_res.document).save(fp, "PNG")
156
+
157
+ if isinstance(element, PictureItem):
158
+ picture_counter += 1
159
+ element_image_filename = (
160
+ output_dir / f"{doc_filename}-picture-{picture_counter}.png"
161
+ )
162
+ with element_image_filename.open("wb") as fp:
163
+ element.get_image(conv_res.document).save(fp, "PNG")
164
+
165
+ # Save markdown with embedded pictures
166
+ md_filename = output_dir / f"{doc_filename}-with-images.md"
167
+ conv_res.document.save_as_markdown(md_filename, image_mode=ImageRefMode.EMBEDDED)
168
+
169
+ # Save markdown with externally referenced pictures
170
+ md_filename = output_dir / f"{doc_filename}-with-image-refs.md"
171
+ conv_res.document.save_as_markdown(md_filename, image_mode=ImageRefMode.REFERENCED)
172
+
173
+ # Save HTML with externally referenced pictures
174
+ html_filename = output_dir / f"{doc_filename}-with-image-refs.html"
175
+ conv_res.document.save_as_html(html_filename, image_mode=ImageRefMode.REFERENCED)
176
+
177
+ tables = {}
178
+
179
+ table_index = 1
180
+ for table in conv_res.document.tables:
181
+ caption = table.caption_text(conv_res.document)
182
+ if len(caption) > 0:
183
+ table_img_path = f'<{args.model_name_t}_{args.model_name_v}>_images_and_tables/{args.poster_name}/{args.poster_name}-table-{table_index}.png'
184
+ table_img = PIL.Image.open(table_img_path)
185
+ tables[str(table_index)] = {
186
+ 'caption': caption,
187
+ 'table_path': table_img_path,
188
+ 'width': table_img.width,
189
+ 'height': table_img.height,
190
+ 'figure_size': table_img.width * table_img.height,
191
+ 'figure_aspect': table_img.width / table_img.height,
192
+ }
193
+
194
+ table_index += 1
195
+
196
+ images = {}
197
+ image_index = 1
198
+ for image in conv_res.document.pictures:
199
+ caption = image.caption_text(conv_res.document)
200
+ if len(caption) > 0:
201
+ image_img_path = f'<{args.model_name_t}_{args.model_name_v}>_images_and_tables/{args.poster_name}/{args.poster_name}-picture-{image_index}.png'
202
+ image_img = PIL.Image.open(image_img_path)
203
+ images[str(image_index)] = {
204
+ 'caption': caption,
205
+ 'image_path': image_img_path,
206
+ 'width': image_img.width,
207
+ 'height': image_img.height,
208
+ 'figure_size': image_img.width * image_img.height,
209
+ 'figure_aspect': image_img.width / image_img.height,
210
+ }
211
+ image_index += 1
212
+
213
+ json.dump(images, open(f'<{args.model_name_t}_{args.model_name_v}>_images_and_tables/{args.poster_name}_images.json', 'w'), indent=4)
214
+ json.dump(tables, open(f'<{args.model_name_t}_{args.model_name_v}>_images_and_tables/{args.poster_name}_tables.json', 'w'), indent=4)
215
+
216
+ return input_token, output_token, images, tables
217
+
218
+ if __name__ == '__main__':
219
+ parser = argparse.ArgumentParser()
220
+ parser.add_argument('--poster_name', type=str, default=None)
221
+ parser.add_argument('--model_name', type=str, default='4o')
222
+ parser.add_argument('--poster_path', type=str, required=True)
223
+ parser.add_argument('--index', type=int, default=0)
224
+ args = parser.parse_args()
225
+
226
+ agent_config = get_agent_config(args.model_name)
+ # parse_raw() and gen_image_and_table() read the text/vision model names
+ # from args.model_name_t / args.model_name_v, so alias them here.
+ args.model_name_t = args.model_name
+ args.model_name_v = args.model_name
227
+
228
+ if args.poster_name is None:
229
+ args.poster_name = args.poster_path.split('/')[-1].replace('.pdf', '').replace(' ', '_')
230
+
231
+ # Parse raw content
232
+ input_token, output_token, conv_res = parse_raw(args, agent_config)
233
+
234
+ # Generate images and tables
235
+ _, _, images, tables = gen_image_and_table(args, conv_res)
236
+
237
+ print(f'Token consumption: {input_token} -> {output_token}')
Paper2Poster/PosterAgent/poster_gen_pipeline.py ADDED
@@ -0,0 +1,101 @@
1
+ import argparse
2
+ import os
+ import time
3
+
4
+ from utils.wei_utils import get_agent_config
5
+ from PosterAgent.parse_raw import parse_raw, gen_image_and_table
6
+ from PosterAgent.gen_outline_layout import filter_image_table, gen_outline_layout
7
+ from PosterAgent.gen_poster_content import gen_poster_content
8
+ from PosterAgent.fill_and_style import fill_poster_content, stylize_poster
9
+ from PosterAgent.deoverflow import deoverflow
10
+ from PosterAgent.apply_theme import poster_apply_theme
11
+
12
+ if __name__ == '__main__':
13
+ parser = argparse.ArgumentParser()
14
+ parser.add_argument('--poster_name', type=str, default=None)
15
+ parser.add_argument('--model_name', type=str, default='4o')
16
+ parser.add_argument('--poster_path', type=str, required=True)
17
+ parser.add_argument('--index', type=int, default=0)
18
+ parser.add_argument('--template_path', type=str, default=None)
19
+ parser.add_argument('--max_retry', type=int, default=3)
20
+ args = parser.parse_args()
+ # parse_raw() and gen_image_and_table() read the text/vision model names
+ # from args.model_name_t / args.model_name_v, so alias them here.
+ args.model_name_t = args.model_name
+ args.model_name_v = args.model_name
21
+
22
+ start_time = time.time()
23
+
24
+ actor_config = get_agent_config(args.model_name)
25
+ critic_config = get_agent_config(args.model_name)
26
+
27
+ if args.poster_name is None:
28
+ args.poster_name = args.poster_path.split('/')[-1].replace('.pdf', '').replace(' ', '_')
29
+
30
+ total_input_token, total_output_token = 0, 0
31
+
32
+ # Parse raw content
33
+ input_token, output_token, conv_res = parse_raw(args, actor_config)
34
+ total_input_token += input_token
35
+ total_output_token += output_token
36
+
37
+ # Generate images and tables
38
+ _, _, images, tables = gen_image_and_table(args, conv_res)
39
+
40
+ print()
41
+ print(f'Parsing token consumption: {input_token} -> {output_token}')
42
+
43
+ input_token, output_token = filter_image_table(args, actor_config)
44
+ total_input_token += input_token
45
+ total_output_token += output_token
46
+ print()
47
+ print(f'Filter images and tables token consumption: {input_token} -> {output_token}')
48
+
49
+ input_token, output_token = gen_outline_layout(args, actor_config, critic_config)
50
+ total_input_token += input_token
51
+ total_output_token += output_token
52
+ print()
53
+ print(f'Generate outline and layout token consumption: {input_token} -> {output_token}')
54
+
55
+ input_token, output_token = gen_poster_content(args, actor_config)
56
+ total_input_token += input_token
57
+ total_output_token += output_token
58
+ print()
59
+ print(f'Generate poster content token consumption: {input_token} -> {output_token}')
60
+
61
+ input_token, output_token = fill_poster_content(args, actor_config)
62
+ total_input_token += input_token
63
+ total_output_token += output_token
64
+ print()
65
+ print(f'Fill poster content token consumption: {input_token} -> {output_token}')
66
+
67
+ input_token, output_token = stylize_poster(args, actor_config)
68
+ total_input_token += input_token
69
+ total_output_token += output_token
70
+ print()
71
+ print(f'Stylize poster token consumption: {input_token} -> {output_token}')
72
+
73
+ input_token, output_token = deoverflow(args, actor_config, critic_config)
74
+ total_input_token += input_token
75
+ total_output_token += output_token
76
+ print()
77
+ print(f'Deoverflow token consumption: {input_token} -> {output_token}')
78
+
79
+ if args.template_path is not None:
80
+ input_token, output_token = poster_apply_theme(args, actor_config, critic_config)
81
+ total_input_token += input_token
82
+ total_output_token += output_token
83
+ print()
84
+ print(f'Apply theme token consumption: {input_token} -> {output_token}')
85
+
86
+ print()
87
+ print(f'Total token consumption: {total_input_token} -> {total_output_token}')
88
+
89
+ end_time = time.time()
90
+ elapsed_time = end_time - start_time
91
+ # Convert to hh:mm:ss format
92
+ hours, rem = divmod(elapsed_time, 3600)
93
+ minutes, seconds = divmod(rem, 60)
94
+
95
+ print(f"Execution Time: {int(hours):02}:{int(minutes):02}:{int(seconds):02}")
96
+
97
+ os.makedirs('log', exist_ok=True)
+ log_path = f'log/{args.model_name}_{args.poster_name}_{args.index}_log.txt'
98
+ with open(log_path, 'w') as f:
99
+ f.write(f'Total token consumption: {total_input_token} -> {total_output_token}\n')
100
+ f.write(f'Execution Time: {int(hours):02}:{int(minutes):02}:{int(seconds):02}\n')
101
+
Paper2Poster/PosterAgent/tree_split_layout.py ADDED
@@ -0,0 +1,750 @@
1
+ from lxml import etree
2
+ import os
3
+ import copy
4
+ import glob
5
+ import numpy as np
6
+ from sklearn.linear_model import LinearRegression, LogisticRegression
7
+ import matplotlib.pyplot as plt
8
+ import matplotlib.patches as patches
9
+
10
+ def parse_xml_with_recovery(xml_file_path):
11
+ parser = etree.XMLParser(recover=True)
12
+ tree = etree.parse(xml_file_path, parser)
13
+ return tree.getroot()
14
+
15
+ def parse_poster_xml(xml_file):
16
+ """
17
+ Parse an XML describing a single poster layout, e.g.:
18
+
19
+ <Poster Width="685" Height="968">
20
+ <Panel left="5" right="160" width="674" height="123">
21
+ <Text>Introduction</Text>
22
+ <Figure left="567" right="178" width="81" height="99" no="1" ... />
23
+ </Panel>
24
+ ...
25
+ </Poster>
26
+
27
+ Returns a dict with:
28
+ {
29
+ 'poster_width': float,
30
+ 'poster_height': float,
31
+ 'panels': [
32
+ {
33
+ 'x': float,
34
+ 'y': float,
35
+ 'width': float,
36
+ 'height': float,
37
+ 'text_blocks': [string, string, ...],
38
+ 'figure_blocks': [(fx, fy, fw, fh), ...]
39
+ },
40
+ ...
41
+ ]
42
+ }
43
+ """
44
+ root = parse_xml_with_recovery(xml_file)
45
+
46
+ # Poster dimensions
47
+ poster_w = float(root.get("Width", "1"))
48
+ poster_h = float(root.get("Height", "1"))
49
+
50
+ panels_data = []
51
+
52
+ # Iterate <Panel> elements
53
+ for panel_node in root.findall("Panel"):
54
+ x = float(panel_node.get("left", "0"))
55
+ y = float(panel_node.get("right", "0"))
56
+ w = float(panel_node.get("width", "0"))
57
+ h = float(panel_node.get("height", "0"))
58
+
59
+ # Gather text blocks
60
+ text_blocks = []
61
+ for text_node in panel_node.findall("Text"):
62
+ txt = text_node.text or ""
63
+ txt = txt.strip()
64
+ if txt:
65
+ text_blocks.append(txt)
66
+
67
+ # Gather figure blocks
68
+ figure_blocks = []
69
+ for fig_node in panel_node.findall("Figure"):
70
+ fx = float(fig_node.get("left", "0"))
71
+ fy = float(fig_node.get("right", "0"))
72
+ fw = float(fig_node.get("width", "0"))
73
+ fh = float(fig_node.get("height", "0"))
74
+ figure_blocks.append((fx, fy, fw, fh))
75
+
76
+ panel_info = {
77
+ "x": x,
78
+ "y": y,
79
+ "width": w,
80
+ "height": h,
81
+ "text_blocks": text_blocks,
82
+ "figure_blocks": figure_blocks
83
+ }
84
+ panels_data.append(panel_info)
85
+
86
+ return {
87
+ "poster_width": poster_w,
88
+ "poster_height": poster_h,
89
+ "panels": panels_data
90
+ }
91
+
92
+ def compute_panel_attributes(poster_data):
93
+ """
94
+ Given poster_data, compute:
95
+ - tp: ratio of text length for each panel
96
+ - gp: ratio of figure area for each panel
97
+ - sp: ratio of panel area to total poster area
98
+ - rp: aspect ratio (width / height)
99
+
100
+ Returns a list of dicts, each:
101
+ {
102
+ 'tp': float,
103
+ 'gp': float,
104
+ 'sp': float,
105
+ 'rp': float
106
+ }
107
+ """
108
+
109
+ poster_w = poster_data["poster_width"]
110
+ poster_h = poster_data["poster_height"]
111
+ panels = poster_data["panels"]
112
+
113
+ poster_area = max(poster_w * poster_h, 1.0) # avoid zero
114
+
115
+ # 1) Compute total text length across all panels
116
+ # 2) Compute total figure area across all panels
117
+ total_text_length = 0
118
+ total_figure_area = 0
119
+
120
+ # We'll store partial info about each panel so we don't parse multiple times
121
+ panel_list = []
122
+ for p in panels:
123
+ # Combine all text
124
+ panel_text_joined = " ".join(p["text_blocks"])
125
+ panel_text_len = len(panel_text_joined)
126
+
127
+ # Sum area of figure blocks
128
+ panel_fig_area = 0.0
129
+ for (fx, fy, fw, fh) in p["figure_blocks"]:
130
+ panel_fig_area += (fw * fh)
131
+
132
+ panel_list.append({
133
+ "x": p["x"],
134
+ "y": p["y"],
135
+ "width": p["width"],
136
+ "height": p["height"],
137
+ "text_len": panel_text_len,
138
+ "fig_area": panel_fig_area
139
+ })
140
+
141
+ total_text_length += panel_text_len
142
+ total_figure_area += panel_fig_area
143
+
144
+ # Avoid divide by zero
145
+ if total_text_length < 1:
146
+ total_text_length = 1
147
+ if total_figure_area < 1e-9:
148
+ total_figure_area = 1e-9
149
+
150
+ # 3) Compute attributes
151
+ results = []
152
+ for pinfo in panel_list:
153
+ pw = pinfo["width"]
154
+ ph = pinfo["height"]
155
+
156
+ panel_area = pw * ph
157
+ sp = panel_area / poster_area # fraction of total area
158
+ rp = (pw / ph) if ph > 0 else 1.0
159
+
160
+ tp = pinfo["text_len"] / float(total_text_length)
161
+ gp = pinfo["fig_area"] / float(total_figure_area)
162
+
163
+ results.append({
164
+ "tp": tp,
165
+ "gp": gp,
166
+ "sp": sp,
167
+ "rp": rp
168
+ })
169
+
170
+ return results
171
+
172
+ def train_panel_attribute_inference(panel_records):
173
+ """
174
+ The training data `panel_records` is a list of dicts, each containing:
175
+ {
176
+ 'tp': float,
177
+ 'gp': float,
178
+ 'sp': float, # (label for the sp regression)
179
+ 'rp': float # (label for the rp regression)
180
+ }
181
+
182
+ We'll train two linear regressors:
183
+ sp = w_s * [tp, gp, 1]
184
+ rp = w_r * [tp, gp, 1]
185
+
186
+ Returns dict with learned parameters:
187
+ {
188
+ 'w_s': array, # shape (3,) => sp = w_s[0]*tp + w_s[1]*gp + w_s[2]
189
+ 'sigma_s': float, # variance of residual for sp
190
+ 'w_r': array,
191
+ 'sigma_r': float
192
+ }
193
+ """
194
+ # Build data arrays
195
+ X_list = []
196
+ sp_list = []
197
+ rp_list = []
198
+
199
+ for rec in panel_records:
200
+ tp = rec['tp']
201
+ gp = rec['gp']
202
+ sp = rec['sp']
203
+ rp = rec['rp']
204
+ # X = [tp, gp, 1]
205
+ X_list.append([tp, gp, 1.0])
206
+ sp_list.append(sp)
207
+ rp_list.append(rp)
208
+
209
+ X_array = np.array(X_list, dtype=float)
210
+ y_sp = np.array(sp_list, dtype=float)
211
+ y_rp = np.array(rp_list, dtype=float)
212
+
213
+ # Fit linear regression for sp
214
+ linreg_sp = LinearRegression(fit_intercept=False)
215
+ linreg_sp.fit(X_array, y_sp)
216
+ w_s = linreg_sp.coef_
217
+ pred_sp = linreg_sp.predict(X_array)
218
+ residual_sp = y_sp - pred_sp
219
+ sigma_s = np.var(residual_sp, ddof=1)
220
+
221
+ # Fit linear regression for rp
222
+ linreg_rp = LinearRegression(fit_intercept=False)
223
+ linreg_rp.fit(X_array, y_rp)
224
+ w_r = linreg_rp.coef_
225
+ pred_rp = linreg_rp.predict(X_array)
226
+ residual_rp = y_rp - pred_rp
227
+ sigma_r = np.var(residual_rp, ddof=1)
228
+
229
+ model_params = {
230
+ "w_s": w_s,
231
+ "sigma_s": sigma_s,
232
+ "w_r": w_r,
233
+ "sigma_r": sigma_r
234
+ }
235
+ return model_params
236
+
237
+
238
+ def parse_poster_xml_for_figures(xml_path):
239
+ root = parse_xml_with_recovery(xml_path)
240
+
241
+ poster_w = float(root.get("Width", "1"))
242
+ poster_h = float(root.get("Height", "1"))
243
+ poster_area = poster_w * poster_h
244
+
245
+ records = []
246
+
247
+ for panel in root.findall("Panel"):
248
+ px, py = float(panel.get("left", 0)), float(panel.get("right", 0))
249
+ pw, ph = float(panel.get("width", 1)), float(panel.get("height", 1))
250
+ panel_area = pw * ph
251
+ sp = panel_area / poster_area
252
+ rp = pw / ph if ph > 0 else 1.0
253
+
254
+ lp = sum(len(t.text.strip()) for t in panel.findall("Text") if t.text)
255
+
256
+ for fig in panel.findall("Figure"):
257
+ fx, fy = float(fig.get("left", 0)), float(fig.get("right", 0))
258
+ fw, fh = float(fig.get("width", 1)), float(fig.get("height", 1))
259
+
260
+ sg = (fw * fh) / poster_area
261
+ rg = fw / fh if fh > 0 else 1.0
262
+ ug = fw / pw if pw > 0 else 0.1
263
+
264
+ panel_center_x = px + pw / 2
265
+ fig_center_x = fx + fw / 2
266
+ delta_x = fig_center_x - panel_center_x
267
+
268
+ hg = 0 if delta_x < -pw / 6 else (2 if delta_x > pw / 6 else 1)
269
+
270
+ record = {"sp": sp, "rp": rp, "lp": lp, "sg": sg, "rg": rg, "hg": hg, "ug": ug}
271
+ records.append(record)
272
+
273
+ return records
274
+
275
+
276
+ def train_figure_model(figure_records):
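+ # Fits a multinomial logistic regression for the figure's horizontal placement
+ # class hg (left / centre / right) and a linear regression, with residual
+ # variance, for ug, the figure's width as a fraction of its panel width; both
+ # use the feature vector [sp, lp, sg, 1].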
277
+ X_hg, y_hg, X_ug, y_ug = [], [], [], []
278
+ for r in figure_records:
279
+ feats = [r["sp"], r["lp"], r["sg"], 1.0]
280
+ X_hg.append(feats)
281
+ y_hg.append(r["hg"])
282
+ X_ug.append(feats)
283
+ y_ug.append(r["ug"])
284
+
285
+ clf_hg = LogisticRegression(multi_class="multinomial", solver="lbfgs", fit_intercept=False)
286
+ clf_hg.fit(X_hg, y_hg)
287
+
288
+ lin_ug = LinearRegression(fit_intercept=False)
289
+ lin_ug.fit(X_ug, y_ug)
290
+ residuals = y_ug - lin_ug.predict(X_ug)
291
+ sigma_u = np.var(residuals, ddof=1)
292
+
293
+ return {
294
+ "clf_hg": clf_hg,
295
+ "w_u": lin_ug.coef_,
296
+ "sigma_u": sigma_u
297
+ }
298
+
299
+
300
+ def main_train():
301
+ poster_dataset_path = 'assets/poster_data/Train'
302
+ # loop through all folders in the dataset
303
+ xml_files = []
304
+ for folder in os.listdir(poster_dataset_path):
305
+ folder_path = os.path.join(poster_dataset_path, folder)
306
+ if os.path.isdir(folder_path):
307
+ # find all XML files in this folder
308
+ xml_files.extend(glob.glob(os.path.join(folder_path, "*.txt")))
309
+
310
+ all_panel_records = []
311
+ for xml_file in xml_files:
312
+ poster_data = parse_poster_xml(xml_file)
313
+ # compute tp, gp, sp, rp
314
+ panel_attrs = compute_panel_attributes(poster_data)
315
+ # each panel_attrs entry is {tp, gp, sp, rp}
316
+ all_panel_records.extend(panel_attrs)
317
+
318
+ all_figure_records = []
319
+ for xml_path in xml_files:
320
+ recs = parse_poster_xml_for_figures(xml_path)
321
+ all_figure_records.extend(recs)
322
+
323
+ panel_model_params = train_panel_attribute_inference(all_panel_records)
324
+ figure_model_params = train_figure_model(all_figure_records)
325
+
326
+ return panel_model_params, figure_model_params
327
+
328
+ def place_text_and_figures_exact(panel_dict, figure_model_params, section_title_height=32):
329
+ """
330
+ Lay out text and figure boxes inside a panel.
331
+
332
+ The figure’s aspect ratio (width / height) is now enforced strictly:
333
+ • width ≤ panel width
334
+ • height ≤ 0.60 × panel height (empirical upper bound)
335
+ • width / height == panel_dict["figure_aspect"]
336
+ """
337
+ # ---------------- Constants used for text layout -----------------
338
+ char_width_px = 7
339
+ line_height_px = 16
340
+ chars_per_line = max(int(panel_dict["width"] / char_width_px), 1)
341
+
342
+ total_lines_text = np.ceil(panel_dict["text_len"] / chars_per_line)
343
+ total_text_height = total_lines_text * line_height_px
344
+
345
+ x_p, y_p = panel_dict["x"], panel_dict["y"]
346
+ w_p, h_p = panel_dict["width"], panel_dict["height"]
347
+
348
+ figure_boxes, text_boxes = [], []
349
+
350
+ panel_name_lower = panel_dict["panel_name"].lower()
351
+ has_title_in_name = "title" in panel_name_lower
352
+
353
+ # -------------------------------------------------------
354
+ # Helper to build a text‑box dict
355
+ # -------------------------------------------------------
356
+ def make_text_box(panel_id, x, y, width, height, textbox_id, textbox_name):
357
+ return {
358
+ "panel_id": panel_id,
359
+ "x": float(x),
360
+ "y": float(y),
361
+ "width": float(width),
362
+ "height": float(height),
363
+ "textbox_id": textbox_id,
364
+ "textbox_name": textbox_name,
365
+ }
366
+
367
+ # -----------------------------------------------------------------------
368
+ # Case 1 — no figure in this panel
369
+ # -----------------------------------------------------------------------
370
+ if panel_dict["figure_size"] <= 0:
371
+ if has_title_in_name:
372
+ text_boxes.append(
373
+ make_text_box(panel_dict["panel_id"], x_p, y_p, w_p, h_p,
374
+ textbox_id=0,
375
+ textbox_name=f'p<{panel_dict["panel_name"]}>_t0')
376
+ )
377
+ else:
378
+ title_h = min(section_title_height, h_p)
379
+ text_boxes.extend([
380
+ make_text_box(panel_dict["panel_id"], x_p, y_p, w_p, title_h,
381
+ textbox_id=0,
382
+ textbox_name=f'p<{panel_dict["panel_name"]}>_t0'),
383
+ make_text_box(panel_dict["panel_id"], x_p, y_p + title_h, w_p, h_p - title_h,
384
+ textbox_id=1,
385
+ textbox_name=f'p<{panel_dict["panel_name"]}>_t1'),
386
+ ])
387
+ return text_boxes, figure_boxes # early‑return (simpler branch)
388
+
389
+ # -----------------------------------------------------------------------
390
+ # Case 2 — there *is* a figure
391
+ # -----------------------------------------------------------------------
392
+ # 1. Sample horizontal‑alignment class (hg) and raw width fraction (ug)
393
+ feat = np.array([panel_dict["sp"],
394
+ panel_dict["text_len"],
395
+ panel_dict["figure_size"],
396
+ 1.0]).reshape(1, -1)
397
+
398
+ clf_hg = figure_model_params["clf_hg"]
399
+ hg_sample = int(np.argmax(clf_hg.predict_proba(feat)[0]))
400
+
401
+ mean_ug = float(np.dot(figure_model_params["w_u"], feat.flatten()))
402
+ sigma_u = float(np.sqrt(figure_model_params["sigma_u"]))
403
+ ug_sample = float(np.clip(np.random.normal(mean_ug, sigma_u), 0.10, 0.80)) # 10‑80 % of width
404
+
405
+ # 2. **Size the figure while *preserving* aspect ratio**
406
+ aspect = float(panel_dict["figure_aspect"]) # width / height
407
+ fig_w = ug_sample * w_p # preliminary width
408
+ fig_h = fig_w / aspect
409
+
410
+ max_fig_h = 0.60 * h_p # same limit you had
411
+ if fig_h > max_fig_h: # too tall → scale down
412
+ scale = max_fig_h / fig_h
413
+ fig_w *= scale
414
+ fig_h = max_fig_h # (ratio still intact)
415
+
416
+ # 3. Horizontal placement
417
+ if hg_sample == 0: # left
418
+ fig_x = x_p
419
+ elif hg_sample == 2: # right
420
+ fig_x = x_p + w_p - fig_w
421
+ else: # center
422
+ fig_x = x_p + 0.5 * (w_p - fig_w)
423
+ # Vertical centering
424
+ fig_y = y_p + 0.5 * (h_p - fig_h)
425
+
426
+ # 4. Split text into “top” and “bottom” areas around the figure
427
+ top_text_h = (fig_y - y_p)
428
+ bottom_text_h = (y_p + h_p) - (fig_y + fig_h)
429
+
430
+ # --- build top‑text boxes
431
+ if has_title_in_name:
432
+ text_boxes.append(
433
+ make_text_box(panel_dict["panel_id"], x_p, y_p, w_p, top_text_h,
434
+ textbox_id=0,
435
+ textbox_name=f'p<{panel_dict["panel_name"]}>_t0')
436
+ )
437
+ next_id = 1
438
+ else:
439
+ title_h = min(section_title_height, top_text_h)
440
+ text_boxes.extend([
441
+ make_text_box(panel_dict["panel_id"], x_p, y_p, w_p, title_h,
442
+ textbox_id=0,
443
+ textbox_name=f'p<{panel_dict["panel_name"]}>_t0'),
444
+ make_text_box(panel_dict["panel_id"], x_p, y_p + title_h, w_p, top_text_h - title_h,
445
+ textbox_id=1,
446
+ textbox_name=f'p<{panel_dict["panel_name"]}>_t1'),
447
+ ])
448
+ next_id = 2
449
+
450
+ # --- bottom text box
451
+ text_boxes.append(
452
+ make_text_box(panel_dict["panel_id"], x_p, fig_y + fig_h, w_p, bottom_text_h,
453
+ textbox_id=next_id,
454
+ textbox_name=f'p<{panel_dict["panel_name"]}>_t{next_id}')
455
+ )
456
+
457
+ # 5. Figure box
458
+ figure_boxes.append({
459
+ "panel_id": panel_dict["panel_id"],
460
+ "x": float(fig_x),
461
+ "y": float(fig_y),
462
+ "width": float(fig_w),
463
+ "height": float(fig_h),
464
+ "figure_id": 0,
465
+ "figure_name": f'p<{panel_dict["panel_name"]}>_f0',
466
+ })
467
+
468
+ return text_boxes, figure_boxes
469
+
470
+
471
+ def to_inches(value_in_units, units_per_inch=72):
472
+ """
473
+ Convert a single coordinate or dimension from 'units' to inches.
474
+ For example, if your units are 'points' (72 points = 1 inch),
475
+ then units_per_inch=72.
476
+ If your units are 'pixels' at 96 DPI, then units_per_inch=96.
477
+ """
478
+ return value_in_units / units_per_inch
479
+
480
+
481
+ def from_inches(value_in_inches, units_per_inch=72):
482
+ """
483
+ Convert from inches back to the original 'units'.
484
+ """
485
+ return value_in_inches * units_per_inch
486
+
487
+
488
+ def softmax(logits):
489
+ s = sum(np.exp(logits))
490
+ return [np.exp(l)/s for l in logits]
491
+
492
+
493
+ def infer_panel_attrs(panel_model, tp, gp):
494
+ # sp = w_s dot [tp, gp, 1]
495
+ # rp = w_r dot [tp, gp, 1]
496
+ vec = np.array([tp, gp, 1.0])
497
+ w_s = panel_model["w_s"]
498
+ w_r = panel_model["w_r"]
499
+ sp = np.dot(w_s, vec)
500
+ rp = np.dot(w_r, vec)
501
+ # clamp
502
+ sp = max(sp, 0.01)
503
+ rp = max(rp, 0.05)
504
+ return sp, rp
505
+
506
+
507
+ def panel_layout_generation(panels, x, y, w, h):
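+ # Recursively split the rectangle (x, y, w, h) among `panels`: every split
+ # point of the ordered panel list is tried, cutting the space either
+ # horizontally or vertically in proportion to the summed target area
+ # fractions `sp`, and the arrangement minimising the accumulated
+ # |rp - actual aspect ratio| loss is kept.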
508
+ # If only 1 panel, place it entirely
509
+ if len(panels) == 1:
510
+ p = panels[0]
511
+ cur_rp = (w/h) if h>1e-9 else p["rp"]
512
+ loss = abs(p["rp"] - cur_rp)
513
+ arrangement = [{
514
+ "panel_name": p["section_name"],
515
+ "panel_id": p["panel_id"],
516
+ "x": x, "y": y,
517
+ "width": w, "height": h
518
+ }]
519
+ return loss, arrangement
520
+
521
+ best_loss = float('inf')
522
+ best_arr = []
523
+ total_sp = sum(pp["sp"] for pp in panels)
524
+ n = len(panels)
525
+
526
+ for i in range(1, n):
527
+ subset1 = panels[:i]
528
+ subset2 = panels[i:]
529
+ sp1 = sum(pp["sp"] for pp in subset1)
530
+ ratio = sp1 / total_sp
531
+
532
+ # horizontal
533
+ h_top = ratio * h
534
+ if 0 < h_top < h:
535
+ l1, a1 = panel_layout_generation(subset1, x, y, w, h_top)
536
+ l2, a2 = panel_layout_generation(subset2, x, y + h_top, w, h - h_top)
537
+ if (l1 + l2) < best_loss:
538
+ best_loss = l1 + l2
539
+ best_arr = a1 + a2
540
+
541
+ # vertical
542
+ w_left = ratio * w
543
+ if 0 < w_left < w:
544
+ l1, a1 = panel_layout_generation(subset1, x, y, w_left, h)
545
+ l2, a2 = panel_layout_generation(subset2, x + w_left, y, w - w_left, h)
546
+ if (l1 + l2) < best_loss:
547
+ best_loss = l1 + l2
548
+ best_arr = a1 + a2
549
+
550
+ return best_loss, best_arr
551
+
552
+ def split_textbox(textbox, ratio):
553
+ """
554
+ Splits a textbox dictionary horizontally into two parts.
555
+
556
+ Parameters:
557
+ textbox (dict): A dictionary with the keys
558
+ 'panel_id', 'x', 'y', 'width', 'height', 'textbox_id', 'textbox_name'
559
+ ratio (float or int): Ratio of top height to bottom height.
560
+ For example, if ratio is 3, then:
561
+ top_height = (3/4) * height
562
+ bottom_height = (1/4) * height
563
+
564
+ Returns:
565
+ tuple: Two dictionaries corresponding to the top and bottom split textboxes.
566
+ """
567
+ # Calculate the new heights
568
+ total_ratio = ratio + 1 # because the ratio represents top:bottom as (ratio):(1)
569
+ top_height = textbox['height'] * ratio / total_ratio
570
+ bottom_height = textbox['height'] * 1 / total_ratio
571
+
572
+ # Derive the base textbox name by splitting off the existing _t suffix if present.
573
+ # This assumes the original textbox_name ends with "_t<number>".
574
+ base_name = textbox['textbox_name'].rsplit('_t', 1)[0]
575
+
576
+ # Create the top textbox dictionary
577
+ top_box = dict(textbox) # make a shallow copy
578
+ top_box['height'] = top_height
579
+ # y remains the same for the top textbox
580
+ top_box['textbox_name'] = f"{base_name}_t0" # rename with _t0
581
+
582
+ # Create the bottom textbox dictionary
583
+ bottom_box = dict(textbox) # make a shallow copy
584
+ bottom_box['y'] = textbox['y'] + top_height # adjust the y position
585
+ bottom_box['height'] = bottom_height
586
+ bottom_box['textbox_name'] = f"{base_name}_t1" # rename with _t1
587
+
588
+ return top_box, bottom_box
589
+
590
+ def generate_constrained_layout(paper_panels, poster_w, poster_h, title_height_ratio=0.1):
591
+ # Find title panel explicitly
592
+ try:
593
+ title_panel = next(p for p in paper_panels if ('title' in p["section_name"].lower()))
594
+ other_panels = [p for p in paper_panels if ('title' not in p["section_name"].lower())]
595
+ except StopIteration:
596
+ print('Oops, no title found, please try again.')
597
+ raise
598
+
599
+ title_h = poster_h * title_height_ratio
600
+ title_layout = {
601
+ "panel_name": title_panel["section_name"],
602
+ "panel_id": title_panel["panel_id"],
603
+ "x": 0, "y": 0,
604
+ "width": poster_w, "height": title_h
605
+ }
606
+
607
+ # Generate recursive layout on remaining space for other panels
608
+ layout_loss, remaining_layout = panel_layout_generation(
609
+ other_panels,
610
+ x=0, y=title_h,
611
+ w=poster_w, h=poster_h - title_h
612
+ )
613
+
614
+ # Combine title panel with others
615
+ complete_layout = [title_layout] + remaining_layout
616
+ return layout_loss, complete_layout
617
+
618
+
619
+ def main_inference(
620
+ paper_panels,
621
+ panel_model_params,
622
+ figure_model_params,
623
+ poster_width=1200,
624
+ poster_height=800,
625
+ shrink_margin=0
626
+ ):
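+ # End-to-end layout inference: infer each panel's size/aspect attributes from
+ # its text and figure shares, arrange the panels recursively on the poster
+ # canvas, then place text boxes and figures inside every panel.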
627
+ for p in paper_panels:
628
+ sp, rp = infer_panel_attrs(panel_model_params, p["tp"], p["gp"])
629
+ p["sp"] = sp
630
+ p["rp"] = rp
631
+
632
+ layout_loss, panel_arrangement = generate_constrained_layout(paper_panels, poster_width, poster_height, title_height_ratio=0.1)
633
+ print("Panel layout cost:", layout_loss)
634
+ for p in panel_arrangement:
635
+ print("Panel:", p)
636
+
637
+ panel_map = {}
638
+ for p in paper_panels:
639
+ panel_map[p["panel_id"]] = p
640
+
641
+ final_panels = []
642
+ for pa in panel_arrangement:
643
+ # Merge bounding box with the original sp,rp data
644
+ pid = pa["panel_id"]
645
+ merged_panel = {
646
+ "panel_id": pid,
647
+ "panel_name": pa['panel_name'],
648
+ "x": pa["x"] + shrink_margin,
649
+ "y": pa["y"] + shrink_margin,
650
+ "width": pa["width"] - 2 * shrink_margin,
651
+ "height": pa["height"] - 2 * shrink_margin,
652
+ "sp": panel_map[pid]["sp"],
653
+ "rp": panel_map[pid]["rp"],
654
+ "text_len": panel_map[pid]["text_len"],
655
+ "figure_size": panel_map[pid]["figure_size"],
656
+ "figure_aspect": panel_map[pid]["figure_aspect"]
657
+ }
658
+ final_panels.append(merged_panel)
659
+
660
+ text_arrangement = []
661
+ figure_arrangement = []
662
+
663
+ for p in final_panels:
664
+ text_boxes, fig_boxes = place_text_and_figures_exact(p, figure_model_params)
665
+ text_arrangement.extend(text_boxes) # text arrangement
666
+ figure_arrangement.extend(fig_boxes) # figure arrangement
667
+
668
+ return panel_arrangement, figure_arrangement, text_arrangement
669
+
670
+ def visualize_complete_layout(
671
+ panels, text_boxes, figure_boxes, poster_width, poster_height
672
+ ):
673
+ fig, ax = plt.subplots(figsize=(12,8))
674
+ ax.set_xlim(0, poster_width)
675
+ ax.set_ylim(0, poster_height)
676
+ ax.set_aspect('equal')
677
+
678
+ # Draw panels
679
+ for panel in panels:
680
+ rect = patches.Rectangle(
681
+ (panel["x"], panel["y"]), panel["width"], panel["height"],
682
+ linewidth=1, edgecolor='black', facecolor='none'
683
+ )
684
+ ax.add_patch(rect)
685
+ ax.text(
686
+ panel["x"] + 5, panel["y"] + panel["height"] - 5,
687
+ f'Panel {panel["panel_id"]}', fontsize=8, va='top', color='black'
688
+ )
689
+
690
+ # Draw text boxes
691
+ for txt in text_boxes:
692
+ rect = patches.Rectangle(
693
+ (txt["x"], txt["y"]), txt["width"], txt["height"],
694
+ linewidth=1, edgecolor='green', linestyle='-.', facecolor='none'
695
+ )
696
+ ax.add_patch(rect)
697
+ ax.text(
698
+ txt["x"] + 2, txt["y"] + txt["height"] - 2,
699
+ f'Text {txt["panel_id"]}', fontsize=7, color='green', va='top'
700
+ )
701
+
702
+ # Draw figures
703
+ for fig_box in figure_boxes:
704
+ rect = patches.Rectangle(
705
+ (fig_box["x"], fig_box["y"]), fig_box["width"], fig_box["height"],
706
+ linewidth=1, edgecolor='blue', linestyle='--', facecolor='none'
707
+ )
708
+ ax.add_patch(rect)
709
+ ax.text(
710
+ fig_box["x"] + 2, fig_box["y"] + 2,
711
+ f'Fig {fig_box["panel_id"]}', fontsize=7, color='blue', va='bottom'
712
+ )
713
+
714
+ plt.gca().invert_yaxis() # optional: invert y-axis if needed
715
+ plt.show()
716
+
717
+
718
+ def get_arrangments_in_inches(
719
+ width,
720
+ height,
721
+ panel_arrangement,
722
+ figure_arrangement,
723
+ text_arrangement,
724
+ units_per_inch=72
725
+ ):
726
+
727
+ panel_arrangement_inches = copy.deepcopy(panel_arrangement)
728
+ figure_arrangement_inches = copy.deepcopy(figure_arrangement)
729
+ text_arrangement_inches = copy.deepcopy(text_arrangement)
730
+
731
+ for p in panel_arrangement_inches:
732
+ p["x"] = to_inches(p["x"], units_per_inch)
733
+ p["y"] = to_inches(p["y"], units_per_inch)
734
+ p["width"] = to_inches(p["width"], units_per_inch)
735
+ p["height"] = to_inches(p["height"], units_per_inch)
736
+
737
+ for f in figure_arrangement_inches:
738
+ f["x"] = to_inches(f["x"], units_per_inch)
739
+ f["y"] = to_inches(f["y"], units_per_inch)
740
+ f["width"] = to_inches(f["width"], units_per_inch)
741
+ f["height"] = to_inches(f["height"], units_per_inch)
742
+
743
+ for t in text_arrangement_inches:
744
+ t["x"] = to_inches(t["x"], units_per_inch)
745
+ t["y"] = to_inches(t["y"], units_per_inch)
746
+ t["width"] = to_inches(t["width"], units_per_inch)
747
+ t["height"] = to_inches(t["height"], units_per_inch)
748
+
749
+ width_inch, height_inch = to_inches(width, units_per_inch), to_inches(height, units_per_inch)
750
+ return width_inch, height_inch, panel_arrangement_inches, figure_arrangement_inches, text_arrangement_inches
Paper2Poster/README.md ADDED
@@ -0,0 +1,215 @@
1
+ # 🎓Paper2Poster: Multimodal Poster Automation from Scientific Papers
2
+ # Automatic Academic Poster Generation from Scientific Papers
3
+
4
+ <p align="center">
5
+ <a href="https://arxiv.org/abs/2505.21497" target="_blank"><img src="https://img.shields.io/badge/arXiv-2505.21497-red"></a>
6
+ <a href="https://paper2poster.github.io/" target="_blank"><img src="https://img.shields.io/badge/Project-Page-brightgreen"></a>
7
+ <a href="https://huggingface.co/datasets/Paper2Poster/Paper2Poster" target="_blank"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Dataset-orange"></a>
8
+ <a href="https://huggingface.co/papers/2505.21497" target="_blank"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Daily Papers-red"></a>
9
+ <a href="https://x.com/_akhaliq/status/1927721150584390129" target="_blank"><img alt="X (formerly Twitter) URL" src="https://img.shields.io/twitter/url?url=https%3A%2F%2Fx.com%2F_akhaliq%2Fstatus%2F1927721150584390129"></a>
10
+ </p>
11
+
12
+ We address **how to create a poster from a paper** and **how to evaluate a poster**.
13
+
14
+ ![Overview](./assets/overall.png)
15
+
16
+
17
+ ## 🔥 Update
18
+ - [x] [2025.10.13] Added automatic **logo support** for conferences and institutions, **YAML-based style customization**, and a new default theme.
19
+ - [x] [2025.9.18] Paper2Poster has been accepted to the **NeurIPS 2025 Datasets and Benchmarks Track**.
20
+ - [x] [2025.9.3] We now support generating per-section content in **parallel** for faster generation, by simply specifying `--max_workers`.
21
+ - [x] [2025.5.27] We release the [arXiv](https://arxiv.org/abs/2505.21497), [code](https://github.com/Paper2Poster/Paper2Poster) and [`dataset`](https://huggingface.co/datasets/Paper2Poster/Paper2Poster)
22
+
23
+ <!--## 📚 Introduction-->
24
+
25
+ **PosterAgent** is a top-down, visual-in-the-loop multi-agent system from `paper.pdf` to **editable** `poster.pptx`.
26
+
27
+ ![PosterAgent Overview](./assets/posteragent.png)
28
+
29
+ <!--A Top-down, visual-in-the-loop, efficient multi-agent pipeline, which includes (a) Parser distills the paper into a structured asset library; the (b) Planner aligns text–visual pairs into a binary‐tree layout that preserves reading order and spatial balance; and the (c) Painter-Commentor loop refines each panel by executing rendering code and using VLM feedback to eliminate overflow and ensure alignment.-->
30
+
31
+ <!--![Paper2Poster Overview](./assets/paperquiz.png)-->
32
+
33
+ <!--**Paper2Poster:** A benchmark for paper to poster generation, paired with human generated poster, with a comprehensive evaluation suite, including metrics like **Visual Quality**, **Textual Coherence**, **VLM-as-Judge** and **PaperQuiz**. Notably, PaperQuiz is a novel evaluation which assume A Good poster should convey core paper content visually.-->
34
+
35
+ ## 📋 Table of Contents
36
+
37
+ <!--- [📚 Introduction](#-introduction)-->
38
+ - [🛠️ Installation](#-installation)
39
+ - [🚀 Quick Start](#-quick-start)
40
+ - [🔮 Evaluation](#-evaluation)
41
+ ---
42
+
43
+ ## 🛠️ Installation
44
+ Our Paper2Poster supports both local deployment (via [vLLM](https://docs.vllm.ai/en/v0.6.6/getting_started/installation.html)) and API-based access (e.g., GPT-4o).
45
+
46
+ **Python Environment**
47
+ ```bash
48
+ pip install -r requirements.txt
49
+ ```
50
+
51
+ **Install Libreoffice**
52
+ ```bash
53
+ sudo apt install libreoffice
54
+ ```
55
+
56
+ Or, if you do **not** have sudo access, download the `soffice` executable directly from https://www.libreoffice.org/download/download-libreoffice/ and add its directory to your `$PATH`.
57
+
58
+ **Install poppler**
59
+ ```bash
60
+ conda install -c conda-forge poppler
61
+ ```
62
+
63
+ **API Key**
64
+
65
+ Create a `.env` file in the project root and add your OpenAI API key:
66
+
67
+ ```bash
68
+ OPENAI_API_KEY=<your_openai_api_key>
69
+ ```
70
+
71
+ **Optional: Google Search API (for logo search)**
72
+
73
+ To use Google Custom Search for more reliable logo search, add these to your `.env` file:
74
+
75
+ ```bash
76
+ GOOGLE_SEARCH_API_KEY=<your_google_search_api_key>
77
+ GOOGLE_SEARCH_ENGINE_ID=<your_search_engine_id>
78
+ ```
79
+
80
+ ---
81
+
82
+ ## 🚀 Quick Start
83
+ Create a folder named `{paper_name}` under `{dataset_dir}`, and place your paper inside it as a PDF file named `paper.pdf`.
84
+ ```
85
+ 📁 {dataset_dir}/
86
+ └── 📁 {paper_name}/
87
+ └── 📄 paper.pdf
88
+ ```
89
+ To use open-source models, first deploy them with [vLLM](https://docs.vllm.ai/en/v0.6.6/getting_started/installation.html), and make sure the port is correctly specified in the `get_agent_config()` function in [`utils/wei_utils.py`](utils/wei_utils.py).
90
+
91
+ - [High Performance] Generate a poster with `GPT-4o`:
92
+
93
+ ```bash
94
+ python -m PosterAgent.new_pipeline \
95
+ --poster_path="${dataset_dir}/${paper_name}/paper.pdf" \
96
+ --model_name_t="4o" \ # LLM
97
+ --model_name_v="4o" \ # VLM
98
+ --poster_width_inches=48 \
99
+ --poster_height_inches=36
100
+ ```
101
+
102
+ - [Economic] Generate a poster with `Qwen-2.5-7B-Instruct` and `GPT-4o`:
103
+
104
+ ```bash
105
+ python -m PosterAgent.new_pipeline \
106
+ --poster_path="${dataset_dir}/${paper_name}/paper.pdf" \
107
+ --model_name_t="vllm_qwen" \ # LLM
108
+ --model_name_v="4o" \ # VLM
109
+ --poster_width_inches=48 \
110
+ --poster_height_inches=36 \
111
+ --no_blank_detection # An option to disable blank detection
112
+ ```
113
+
114
+ - [Local] Generate a poster with `Qwen-2.5-7B-Instruct`:
115
+
116
+ ```bash
117
+ python -m PosterAgent.new_pipeline \
118
+ --poster_path="${dataset_dir}/${paper_name}/paper.pdf" \
119
+ --model_name_t="vllm_qwen" \ # LLM
120
+ --model_name_v="vllm_qwen_vl" \ # VLM
121
+ --poster_width_inches=48 \
122
+ --poster_height_inches=36
123
+ ```
124
+
125
+ PosterAgent **supports flexible combinations of LLMs and VLMs**; feel free to try other options, or customize your own settings in `get_agent_config()` in [`utils/wei_utils.py`](utils/wei_utils.py).
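+
+ The config entries returned by `get_agent_config()` are consumed elsewhere in the pipeline (for example in `PosterAgent/parse_raw.py`) through keys such as `model_platform`, `model_type`, `model_config`, an optional `url` for locally served models, and an optional `token_limit`. Below is a minimal sketch of what a custom entry could look like; the platform and model identifiers are illustrative placeholders, not the project's real values.
+
+ ```python
+ # Hypothetical custom entry, shaped after how the agent configs are consumed in
+ # PosterAgent/parse_raw.py (ModelFactory.create + ChatAgent). The platform and
+ # model identifiers below are placeholders, not the project's real ones.
+ my_local_llm_config = {
+     'model_platform': 'vllm',                  # assumed platform identifier
+     'model_type': 'Qwen/Qwen2.5-7B-Instruct',  # model served locally
+     'model_config': {'temperature': 0.2},      # passed as model_config_dict
+     'url': 'http://localhost:8000/v1',         # endpoint of the local server
+     'token_limit': 32768,                      # optional; read via .get(...)
+ }
+ ```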
126
+
127
+ ### Adding Logos to Posters
128
+
129
+ You can automatically add institutional and conference logos to your posters:
130
+
131
+ ```bash
132
+ python -m PosterAgent.new_pipeline \
133
+ --poster_path="${dataset_dir}/${paper_name}/paper.pdf" \
134
+ --model_name_t="4o" \
135
+ --model_name_v="4o" \
136
+ --poster_width_inches=48 \
137
+ --poster_height_inches=36 \
138
+ --conference_venue="NeurIPS" # Automatically searches for conference logo
139
+ ```
140
+
141
+ **Logo Search Strategy:**
142
+ 1. **Local search**: First checks the provided logo store (`logo_store/institutes/` and `logo_store/conferences/`)
143
+ 2. **Web search**: If not found locally, performs online search
144
+ - By default, uses DuckDuckGo (no API key required)
145
+ - For more reliable results, use `--use_google_search` (requires `GOOGLE_SEARCH_API_KEY` and `GOOGLE_SEARCH_ENGINE_ID` in `.env`)
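+
+ A small sketch of this resolution order is shown below; only the `logo_store/` layout comes from this README, while the helper names and the web-search step are hypothetical placeholders rather than the project's actual API.
+
+ ```python
+ # Sketch of the logo resolution order described above. Only the logo_store/
+ # directory layout is taken from this README; search_logo_online is a
+ # hypothetical placeholder for the DuckDuckGo / Google web-search fallback.
+ from pathlib import Path
+ from typing import Optional
+
+ def search_logo_online(name: str) -> Optional[str]:
+     # Placeholder: the real pipeline downloads a logo image via web search.
+     return None
+
+ def resolve_logo(name: str, kind: str = 'conferences') -> Optional[str]:
+     # 1) Local search: check the bundled logo store first.
+     store = Path('logo_store') / kind
+     for ext in ('.png', '.svg', '.jpg'):
+         candidate = store / f'{name.lower()}{ext}'
+         if candidate.exists():
+             return str(candidate)
+     # 2) Web search fallback if no local logo was found.
+     return search_logo_online(name)
+ ```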
146
+
147
+ You can also specify custom logo paths to skip auto-detection:
148
+ ```bash
149
+ --institution_logo_path="path/to/institution_logo.png" \
150
+ --conference_logo_path="path/to/conference_logo.png"
151
+ ```
152
+
153
+ ### YAML Style Customization
154
+
155
+ Customize poster appearance via YAML configuration files:
156
+ - **Global defaults**: `config/poster.yaml` (applies to all posters)
157
+ - **Per-poster override**: Place `poster.yaml` next to your `paper.pdf` for custom styling
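+
+ A sketch of the implied precedence is shown below, assuming a simple shallow merge; the pipeline's actual merge behaviour may differ.
+
+ ```python
+ # Sketch of the YAML precedence described above: global defaults from
+ # config/poster.yaml, overridden by a poster.yaml placed next to paper.pdf.
+ # The shallow merge is an assumption, not the pipeline's actual code.
+ import os
+ import yaml
+
+ def load_poster_style(paper_pdf_path: str) -> dict:
+     with open('config/poster.yaml') as f:
+         style = yaml.safe_load(f) or {}
+     per_poster = os.path.join(os.path.dirname(paper_pdf_path), 'poster.yaml')
+     if os.path.exists(per_poster):
+         with open(per_poster) as f:
+             style.update(yaml.safe_load(f) or {})  # per-poster keys win
+     return style
+ ```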
158
+
159
+
160
+ ## 🔮 Evaluation
161
+ Download Paper2Poster evaluation dataset via:
162
+ ```bash
163
+ python -m PosterAgent.create_dataset
164
+ ```
165
+
166
+ In evaluation, papers are stored under a directory called `Paper2Poster-data`.
167
+
168
+ To evaluate a generated poster with **PaperQuiz**:
169
+ ```bash
170
+ python -m Paper2Poster-eval.eval_poster_pipeline \
171
+ --paper_name="${paper_name}" \
172
+ --poster_method="${model_t}_${model_v}_generated_posters" \
173
+ --metric=qa # PaperQuiz
174
+ ```
175
+
176
+ To evaluate a generated poster with **VLM-as-Judge**:
177
+ ```bash
178
+ python -m Paper2Poster-eval.eval_poster_pipeline \
179
+ --paper_name="${paper_name}" \
180
+ --poster_method="${model_t}_${model_v}_generated_posters" \
181
+ --metric=judge # VLM-as-Judge
182
+ ```
183
+
184
+ To evaluate a generated poster with other statistical metrics (such as visual similarity, PPL, etc.):
185
+ ```bash
186
+ python -m Paper2Poster-eval.eval_poster_pipeline \
187
+ --paper_name="${paper_name}" \
188
+ --poster_method="${model_t}_${model_v}_generated_posters" \
189
+ --metric=stats # statistical measures
190
+ ```
191
+
192
+ If you want to create a PaperQuiz for your own paper:
193
+ ```bash
194
+ python -m Paper2Poster-eval.create_paper_questions \
195
+ --paper_folder="Paper2Poster-data/${paper_name}"
196
+ ```
197
+
198
+ ## ❤ Acknowledgement
199
+ We extend our gratitude to [🐫CAMEL](https://github.com/camel-ai/camel), [🦉OWL](https://github.com/camel-ai/owl), [Docling](https://github.com/docling-project/docling), [PPTAgent](https://github.com/icip-cas/PPTAgent) for providing their codebases.
200
+
201
+ ## 📖 Citation
202
+
203
+ Please kindly cite our paper if you find this project helpful.
204
+
205
+ ```bibtex
206
+ @misc{paper2poster,
207
+ title={Paper2Poster: Towards Multimodal Poster Automation from Scientific Papers},
208
+ author={Wei Pang and Kevin Qinghong Lin and Xiangru Jian and Xi He and Philip Torr},
209
+ year={2025},
210
+ eprint={2505.21497},
211
+ archivePrefix={arXiv},
212
+ primaryClass={cs.CV},
213
+ url={https://arxiv.org/abs/2505.21497},
214
+ }
215
+ ```
Paper2Poster/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ import sys, os
2
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "camel"))
3
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), "docling"))
Paper2Poster/camel/__init__.py ADDED
@@ -0,0 +1,25 @@
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+
15
+ from camel.logger import disable_logging, enable_logging, set_log_level
16
+
17
+ __version__ = '0.2.19'
18
+
19
+ __all__ = [
20
+ '__version__',
21
+ 'camel',
22
+ 'disable_logging',
23
+ 'enable_logging',
24
+ 'set_log_level',
25
+ ]
Paper2Poster/camel/agents/__init__.py ADDED
@@ -0,0 +1,44 @@
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+ from .base import BaseAgent
15
+ from .chat_agent import ChatAgent
16
+ from .critic_agent import CriticAgent
17
+ from .embodied_agent import EmbodiedAgent
18
+ from .knowledge_graph_agent import KnowledgeGraphAgent
19
+ from .role_assignment_agent import RoleAssignmentAgent
20
+ from .search_agent import SearchAgent
21
+ from .task_agent import (
22
+ TaskCreationAgent,
23
+ TaskPlannerAgent,
24
+ TaskPrioritizationAgent,
25
+ TaskSpecifyAgent,
26
+ )
27
+ from .tool_agents.base import BaseToolAgent
28
+ from .tool_agents.hugging_face_tool_agent import HuggingFaceToolAgent
29
+
30
+ __all__ = [
31
+ 'BaseAgent',
32
+ 'ChatAgent',
33
+ 'TaskSpecifyAgent',
34
+ 'TaskPlannerAgent',
35
+ 'TaskCreationAgent',
36
+ 'TaskPrioritizationAgent',
37
+ 'CriticAgent',
38
+ 'BaseToolAgent',
39
+ 'HuggingFaceToolAgent',
40
+ 'EmbodiedAgent',
41
+ 'RoleAssignmentAgent',
42
+ 'SearchAgent',
43
+ 'KnowledgeGraphAgent',
44
+ ]
Paper2Poster/camel/agents/base.py ADDED
@@ -0,0 +1,29 @@
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+ from abc import ABC, abstractmethod
15
+ from typing import Any
16
+
17
+
18
+ class BaseAgent(ABC):
19
+ r"""An abstract base class for all CAMEL agents."""
20
+
21
+ @abstractmethod
22
+ def reset(self, *args: Any, **kwargs: Any) -> Any:
23
+ r"""Resets the agent to its initial state."""
24
+ pass
25
+
26
+ @abstractmethod
27
+ def step(self, *args: Any, **kwargs: Any) -> Any:
28
+ r"""Performs a single step of the agent."""
29
+ pass
Paper2Poster/camel/agents/chat_agent.py ADDED
@@ -0,0 +1,1539 @@
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+ from __future__ import annotations
15
+
16
+ import json
17
+ import logging
18
+ import re
19
+ import uuid
20
+ from collections import defaultdict
21
+ from typing import (
22
+ TYPE_CHECKING,
23
+ Any,
24
+ Callable,
25
+ Dict,
26
+ List,
27
+ Optional,
28
+ Tuple,
29
+ Type,
30
+ Union,
31
+ )
32
+
33
+ from openai.types.chat import ChatCompletionMessageToolCall
34
+ from openai.types.chat.chat_completion_message_tool_call import Function
35
+ from pydantic import BaseModel, ValidationError
36
+
37
+ from camel.agents.base import BaseAgent
38
+ from camel.memories import (
39
+ AgentMemory,
40
+ ChatHistoryMemory,
41
+ MemoryRecord,
42
+ ScoreBasedContextCreator,
43
+ )
44
+ from camel.messages import BaseMessage, FunctionCallingMessage, OpenAIMessage
45
+ from camel.models import (
46
+ BaseModelBackend,
47
+ ModelFactory,
48
+ ModelManager,
49
+ ModelProcessingError,
50
+ )
51
+ from camel.responses import ChatAgentResponse
52
+ from camel.types import (
53
+ ChatCompletion,
54
+ ChatCompletionChunk,
55
+ ModelPlatformType,
56
+ ModelType,
57
+ OpenAIBackendRole,
58
+ RoleType,
59
+ )
60
+ from camel.utils import (
61
+ func_string_to_callable,
62
+ generate_prompt_for_structured_output,
63
+ get_model_encoding,
64
+ get_pydantic_object_schema,
65
+ json_to_function_code,
66
+ )
67
+
68
+ if TYPE_CHECKING:
69
+ from openai import Stream
70
+
71
+ from camel.terminators import ResponseTerminator
72
+ from camel.toolkits import FunctionTool
73
+
74
+
75
+ logger = logging.getLogger(__name__)
76
+
77
+ # AgentOps decorator setting
78
+ try:
79
+ import os
80
+
81
+ if os.getenv("AGENTOPS_API_KEY") is not None:
82
+ from agentops import track_agent
83
+ else:
84
+ raise ImportError
85
+ except (ImportError, AttributeError):
86
+ from camel.utils import track_agent
87
+
88
+
89
+ class FunctionCallingRecord(BaseModel):
90
+ r"""Historical records of functions called in the conversation.
91
+
92
+ Attributes:
93
+ func_name (str): The name of the function being called.
94
+ args (Dict[str, Any]): The dictionary of arguments passed to
95
+ the function.
96
+ result (Any): The execution result of calling this function.
97
+ tool_call_id (str): The ID of the tool call, if available.
98
+ """
99
+
100
+ func_name: str
101
+ args: Dict[str, Any]
102
+ result: Any
103
+ tool_call_id: str
104
+
105
+ def __str__(self) -> str:
106
+ r"""Overridden version of the string function.
107
+
108
+ Returns:
109
+ str: Modified string to represent the function calling.
110
+ """
111
+ return (
112
+ f"Function Execution: {self.func_name}\n"
113
+ f"\tArgs: {self.args}\n"
114
+ f"\tResult: {self.result}\n"
115
+ )
116
+
117
+ def as_dict(self) -> dict[str, Any]:
118
+ r"""Returns the function calling record as a dictionary.
119
+
120
+ Returns:
121
+ dict[str, Any]: The function calling record as a dictionary.
122
+ """
123
+ return self.model_dump()
124
+
125
+
126
+ @track_agent(name="ChatAgent")
127
+ class ChatAgent(BaseAgent):
128
+ r"""Class for managing conversations of CAMEL Chat Agents.
129
+
130
+ Args:
131
+ system_message (Union[BaseMessage, str], optional): The system message
132
+ for the chat agent.
133
+ model (BaseModelBackend, optional): The model backend to use for
134
+ generating responses. (default: :obj:`ModelPlatformType.DEFAULT`
135
+ with `ModelType.DEFAULT`)
136
+ memory (AgentMemory, optional): The agent memory for managing chat
137
+ messages. If `None`, a :obj:`ChatHistoryMemory` will be used.
138
+ (default: :obj:`None`)
139
+ message_window_size (int, optional): The maximum number of previous
140
+ messages to include in the context window. If `None`, no windowing
141
+ is performed. (default: :obj:`None`)
142
+ token_limit (int, optional): The maximum number of tokens in a context.
143
+ The context will be automatically pruned to fulfill the limitation.
144
+ If `None`, it will be set according to the backend model.
145
+ (default: :obj:`None`)
146
+ output_language (str, optional): The language to be output by the
147
+ agent. (default: :obj:`None`)
148
+ tools (Optional[List[Union[FunctionTool, Callable]]], optional): List
149
+ of available :obj:`FunctionTool` or :obj:`Callable`. (default:
150
+ :obj:`None`)
151
+ external_tools (Optional[List[Union[FunctionTool, Callable]]],
152
+ optional): List of external tools (:obj:`FunctionTool` or
153
+ :obj:`Callable`) bound to one chat agent. When these tools are
154
+ called, the agent will directly return the request instead of
155
+ processing it. (default: :obj:`None`)
156
+ response_terminators (List[ResponseTerminator], optional): List of
157
+ :obj:`ResponseTerminator` bound to one chat agent.
158
+ (default: :obj:`None`)
159
+ scheduling_strategy (str): name of function that defines how to select
160
+ the next model in ModelManager. (default: :str:`round_robin`)
161
+ single_iteration (bool): Whether to let the agent perform only one
162
+ model calling at each step. (default: :obj:`False`)
163
+ """
164
+
165
+ def __init__(
166
+ self,
167
+ system_message: Optional[Union[BaseMessage, str]] = None,
168
+ model: Optional[
169
+ Union[BaseModelBackend, List[BaseModelBackend]]
170
+ ] = None,
171
+ memory: Optional[AgentMemory] = None,
172
+ message_window_size: Optional[int] = None,
173
+ token_limit: Optional[int] = None,
174
+ output_language: Optional[str] = None,
175
+ tools: Optional[List[Union[FunctionTool, Callable]]] = None,
176
+ external_tools: Optional[List[Union[FunctionTool, Callable]]] = None,
177
+ response_terminators: Optional[List[ResponseTerminator]] = None,
178
+ scheduling_strategy: str = "round_robin",
179
+ single_iteration: bool = False,
180
+ ) -> None:
181
+ # Initialize the system message, converting string to BaseMessage if needed
182
+ if isinstance(system_message, str):
183
+ system_message = BaseMessage.make_assistant_message(
184
+ role_name='Assistant', content=system_message
185
+ )
186
+
187
+ self.orig_sys_message: Optional[BaseMessage] = system_message
188
+ self._system_message: Optional[BaseMessage] = system_message
189
+ self.role_name: str = (
190
+ getattr(system_message, 'role_name', None) or "assistant"
191
+ )
192
+ self.role_type: RoleType = (
193
+ getattr(system_message, 'role_type', None) or RoleType.ASSISTANT
194
+ )
195
+ self.model_backend = ModelManager(
196
+ model
197
+ if model is not None
198
+ else ModelFactory.create(
199
+ model_platform=ModelPlatformType.DEFAULT,
200
+ model_type=ModelType.DEFAULT,
201
+ ),
202
+ scheduling_strategy=scheduling_strategy,
203
+ )
204
+ self.model_type = self.model_backend.model_type
205
+
206
+ # Initialize tools
207
+ self.tools: List[FunctionTool] = (
208
+ self._initialize_tools(tools) if tools else []
209
+ )
210
+ self.external_tools: List[FunctionTool] = (
211
+ self._initialize_tools(external_tools) if external_tools else []
212
+ )
213
+ self.external_tool_names: List[str] = [
214
+ tool.get_function_name() for tool in self.external_tools
215
+ ]
216
+ self.all_tools = self.tools + self.external_tools or []
217
+
218
+ # Create tool dictionaries and configure backend tools if necessary
219
+ self.tool_dict = {
220
+ tool.get_function_name(): tool for tool in self.all_tools
221
+ }
222
+
223
+ # If the user set tools from `ChatAgent`, it will override the
224
+ # configured tools in `BaseModelBackend`.
225
+ if self.all_tools:
226
+ logger.warning(
227
+ "Overriding the configured tools in `BaseModelBackend` with the tools from `ChatAgent`."
228
+ )
229
+ tool_schema_list = [
230
+ tool.get_openai_tool_schema() for tool in self.all_tools
231
+ ]
232
+ self.model_backend.model_config_dict['tools'] = tool_schema_list
233
+
234
+ self.model_token_limit = token_limit or self.model_backend.token_limit
235
+ context_creator = ScoreBasedContextCreator(
236
+ self.model_backend.token_counter,
237
+ self.model_token_limit,
238
+ )
239
+ self.memory: AgentMemory = memory or ChatHistoryMemory(
240
+ context_creator, window_size=message_window_size
241
+ )
242
+
243
+ self.output_language: Optional[str] = output_language
244
+ if self.output_language is not None:
245
+ self.set_output_language(self.output_language)
246
+
247
+ self.terminated: bool = False
248
+ self.response_terminators = response_terminators or []
249
+ self.init_messages()
250
+ self.tool_prompt_added = False
251
+ self.single_iteration = single_iteration
252
+
253
+ def _initialize_tools(
254
+ self, tools: List[Union[FunctionTool, Callable]]
255
+ ) -> List[FunctionTool]:
256
+ r"""Helper method to initialize tools as FunctionTool instances."""
257
+ from camel.toolkits import FunctionTool
258
+
259
+ func_tools = []
260
+ for tool in tools:
261
+ if not isinstance(tool, FunctionTool):
262
+ tool = FunctionTool(tool)
263
+ func_tools.append(tool)
264
+ return func_tools
265
+
266
+ def add_tool(
267
+ self, tool: Union[FunctionTool, Callable], is_external: bool = False
268
+ ) -> None:
269
+ r"""Add a tool to the agent, specifying if it's an external tool."""
270
+ # Initialize the tool
271
+ initialized_tool = self._initialize_tools([tool])
272
+
273
+ # Update tools or external tools based on is_external flag
274
+ if is_external:
275
+ self.external_tools = self.external_tools + initialized_tool
276
+ self.external_tool_names.extend(
277
+ tool.get_function_name() for tool in initialized_tool
278
+ )
279
+ else:
280
+ self.tools = self.tools + initialized_tool
281
+
282
+ # Rebuild all_tools, and tool_dict
283
+ self.all_tools = self.tools + self.external_tools
284
+ self.tool_dict = {
285
+ tool.get_function_name(): tool for tool in self.all_tools
286
+ }
287
+
288
+ tool_schema_list = [
289
+ tool.get_openai_tool_schema() for tool in self.all_tools
290
+ ]
291
+ self.model_backend.model_config_dict['tools'] = tool_schema_list
292
+
293
+ def remove_tool(self, tool_name: str, is_external: bool = False) -> bool:
294
+ r"""Remove a tool by name, specifying if it's an external tool."""
295
+ tool_list = self.external_tools if is_external else self.tools
296
+ if not tool_list:
297
+ return False
298
+
299
+ for tool in tool_list:
300
+ if tool.get_function_name() == tool_name:
301
+ tool_list.remove(tool)
302
+ if is_external:
303
+ self.external_tool_names.remove(tool_name)
304
+ # Reinitialize the tool dictionary
305
+ self.all_tools = (self.tools or []) + (
306
+ self.external_tools or []
307
+ )
308
+ self.tool_dict = {
309
+ tool.get_function_name(): tool for tool in self.all_tools
310
+ }
311
+ tool_schema_list = [
312
+ tool.get_openai_tool_schema() for tool in self.all_tools
313
+ ]
314
+ self.model_backend.model_config_dict['tools'] = (
315
+ tool_schema_list
316
+ )
317
+ return True
318
+ return False
319
+
320
+ def list_tools(self) -> dict:
321
+ r"""List all tools, separated into normal and external tools."""
322
+ normal_tools = [
323
+ tool.get_function_name() for tool in (self.tools or [])
324
+ ]
325
+ external_tools = [
326
+ tool.get_function_name() for tool in (self.external_tools or [])
327
+ ]
328
+
329
+ return {"normal_tools": normal_tools, "external_tools": external_tools}
330
+
331
+ # ruff: noqa: E501
332
+ def _generate_tool_prompt(self, tool_schema_list: List[Dict]) -> str:
333
+ r"""Generates a tool prompt based on the provided tool schema list.
334
+
335
+ Args:
336
+ tool_schema_list (List[Dict]): A list of dictionaries, each
337
+ containing a tool schema.
338
+
339
+ Returns:
340
+ str: A string representing the tool prompt.
341
+ """
342
+ tool_prompts = []
343
+
344
+ for tool in tool_schema_list:
345
+ tool_info = tool['function']
346
+ tool_name = tool_info['name']
347
+ tool_description = tool_info['description']
348
+ tool_json = json.dumps(tool_info, indent=4)
349
+
350
+ prompt = f"Use the function '{tool_name}' to '{tool_description}':\n{tool_json}\n"
351
+ tool_prompts.append(prompt)
352
+
353
+ tool_prompt_str = "\n".join(tool_prompts)
354
+
355
+ final_prompt = f"""
356
+ You have access to the following functions:
357
+
358
+ {tool_prompt_str}
359
+
360
+ If you choose to call a function ONLY reply in the following format with no
361
+ prefix or suffix:
362
+
363
+ <function=example_function_name>{{"example_name": "example_value"}}</function>
364
+
365
+ Reminder:
366
+ - Function calls MUST follow the specified format, start with <function= and end with </function>
367
+ - Required parameters MUST be specified
368
+ - Only call one function at a time
369
+ - Put the entire function call reply on one line
370
+ - If there is no function call available, answer the question like normal
371
+ with your current knowledge and do not tell the user about function calls
372
+ """
373
+ return final_prompt
374
+
375
+ def _parse_tool_response(self, response: str):
376
+ r"""Parses the tool response to extract the function name and
377
+ arguments.
378
+
379
+ Args:
380
+ response (str): The response from the model containing the
381
+ function call.
382
+
383
+ Returns:
384
+ Optional[Dict[str, Any]]: The parsed function name and arguments
385
+ if found, otherwise :obj:`None`.
386
+ """
387
+ function_regex = r"<function=(\w+)>(.*?)</function>"
388
+ match = re.search(function_regex, response)
389
+
390
+ if match:
391
+ function_name, args_string = match.groups()
392
+ try:
393
+ args = json.loads(args_string)
394
+ return {"function": function_name, "arguments": args}
395
+ except json.JSONDecodeError as error:
396
+ logger.error(f"Error parsing function arguments: {error}")
397
+ return None
398
+ return None
399
+
400
+ def reset(self):
401
+ r"""Resets the :obj:`ChatAgent` to its initial state."""
402
+ self.terminated = False
403
+ self.init_messages()
404
+ for terminator in self.response_terminators:
405
+ terminator.reset()
406
+
407
+ @property
408
+ def system_message(self) -> Optional[BaseMessage]:
409
+ r"""The getter method for the property :obj:`system_message`.
410
+
411
+ Returns:
412
+ Optional[BaseMessage]: The system message of this agent if set,
413
+ else :obj:`None`.
414
+ """
415
+ return self._system_message
416
+
417
+ @system_message.setter
418
+ def system_message(self, message: BaseMessage) -> None:
419
+ r"""The setter method for the property :obj:`system_message`.
420
+
421
+ Args:
422
+ message (BaseMessage): The message to be set as the
423
+ new system message of this agent.
424
+ """
425
+ self._system_message = message
426
+
427
+ def is_tools_added(self) -> bool:
428
+ r"""Whether tool calling is enabled for this agent.
429
+
430
+ Returns:
431
+ bool: Whether tool calling is enabled for this agent, determined
432
+ by whether the dictionary of tools is empty.
433
+ """
434
+ return len(self.tool_dict) > 0
435
+
436
+ def update_memory(
437
+ self, message: BaseMessage, role: OpenAIBackendRole
438
+ ) -> None:
439
+ r"""Updates the agent memory with a new message.
440
+
441
+ Args:
442
+ message (BaseMessage): The new message to add to the stored
443
+ messages.
444
+ role (OpenAIBackendRole): The backend role type.
445
+ """
446
+ self.memory.write_record(
447
+ MemoryRecord(message=message, role_at_backend=role)
448
+ )
449
+
450
+ def set_output_language(self, output_language: str) -> BaseMessage:
451
+ r"""Sets the output language for the system message. This method
452
+ updates the output language for the system message. The output
453
+ language determines the language in which the output text should be
454
+ generated.
455
+
456
+ Args:
457
+ output_language (str): The desired output language.
458
+
459
+ Returns:
460
+ BaseMessage: The updated system message object.
461
+ """
462
+ self.output_language = output_language
463
+ language_prompt = (
464
+ "\nRegardless of the input language, "
465
+ f"you must output text in {output_language}."
466
+ )
467
+ if self.orig_sys_message is not None:
468
+ content = self.orig_sys_message.content + language_prompt
469
+ self._system_message = self.orig_sys_message.create_new_instance(
470
+ content
471
+ )
472
+ else:
473
+ self._system_message = BaseMessage.make_assistant_message(
474
+ role_name="Assistant",
475
+ content=language_prompt,
476
+ )
477
+
478
+ system_record = MemoryRecord(
479
+ message=self._system_message,
480
+ role_at_backend=OpenAIBackendRole.SYSTEM,
481
+ )
482
+ self.memory.clear()
483
+ self.memory.write_record(system_record)
484
+ return self._system_message
485
+
486
+ def get_info(
487
+ self,
488
+ session_id: Optional[str],
489
+ usage: Optional[Dict[str, int]],
490
+ termination_reasons: List[str],
491
+ num_tokens: int,
492
+ tool_calls: List[FunctionCallingRecord],
493
+ external_tool_request: Optional[ChatCompletionMessageToolCall] = None,
494
+ ) -> Dict[str, Any]:
495
+ r"""Returns a dictionary containing information about the chat session.
496
+
497
+ Args:
498
+ session_id (str, optional): The ID of the chat session.
499
+ usage (Dict[str, int], optional): Information about the usage of
500
+ the LLM.
501
+ termination_reasons (List[str]): The reasons for the termination
502
+ of the chat session.
503
+ num_tokens (int): The number of tokens used in the chat session.
504
+ tool_calls (List[FunctionCallingRecord]): The list of function
505
+ calling records, containing the information of called tools.
506
+ external_tool_request
507
+ (Optional[ChatCompletionMessageToolCall], optional):
508
+ The tool calling request of external tools from the model.
509
+ These requests are directly returned to the user instead of
510
+ being processed by the agent automatically.
511
+ (default: :obj:`None`)
512
+
513
+ Returns:
514
+ Dict[str, Any]: The chat session information.
515
+ """
516
+ return {
517
+ "id": session_id,
518
+ "usage": usage,
519
+ "termination_reasons": termination_reasons,
520
+ "num_tokens": num_tokens,
521
+ "tool_calls": tool_calls,
522
+ "external_tool_request": external_tool_request,
523
+ }
524
+
525
+ def init_messages(self) -> None:
526
+ r"""Initializes the stored messages list with the current system
527
+ message.
528
+ """
529
+ if self._system_message is not None:
530
+ system_record = MemoryRecord(
531
+ message=self._system_message,
532
+ role_at_backend=OpenAIBackendRole.SYSTEM,
533
+ )
534
+ self.memory.clear()
535
+ self.memory.write_record(system_record)
536
+ else:
537
+ self.memory.clear()
538
+
539
+ def record_message(self, message: BaseMessage) -> None:
540
+ r"""Records the externally provided message into the agent memory as if
541
+ it were an answer of the :obj:`ChatAgent` from the backend. Currently,
542
+ the choice of the critic is submitted with this method.
543
+
544
+ Args:
545
+ message (BaseMessage): An external message to be recorded in the
546
+ memory.
547
+ """
548
+ self.update_memory(message, OpenAIBackendRole.ASSISTANT)
549
+
550
+ def step(
551
+ self,
552
+ input_message: Union[BaseMessage, str],
553
+ response_format: Optional[Type[BaseModel]] = None,
554
+ ) -> ChatAgentResponse:
555
+ r"""Executes a single step in the chat session, generating a response
556
+ to the input message.
557
+
558
+ Args:
559
+ input_message (Union[BaseMessage, str]): The input message for the
560
+ agent. If provided as a BaseMessage, the `role` is adjusted to
561
+ `user` to indicate an external message.
562
+ response_format (Optional[Type[BaseModel]], optional): A Pydantic
563
+ model defining the expected structure of the response. Used to
564
+ generate a structured response if provided. (default:
565
+ :obj:`None`)
566
+
567
+ Returns:
568
+ ChatAgentResponse: Contains output messages, a termination status
569
+ flag, and session information.
570
+ """
571
+
572
+ if (
573
+ self.model_backend.model_config_dict.get("response_format")
574
+ and response_format
575
+ ):
576
+ raise ValueError(
577
+ "The `response_format` parameter cannot be set both in "
578
+ "the model configuration and in the ChatAgent step."
579
+ )
580
+
581
+ self.original_model_dict = self.model_backend.model_config_dict
582
+ model_response_format_modified = False
583
+ if (
584
+ response_format
585
+ and self.model_type.support_native_structured_output
586
+ ):
587
+ self.model_backend.model_config_dict = (
588
+ self.original_model_dict.copy()
589
+ )
590
+ self.model_backend.model_config_dict["response_format"] = (
591
+ response_format
592
+ )
593
+ model_response_format_modified = True
594
+
595
+ # Convert input message to BaseMessage if necessary
596
+ if isinstance(input_message, str):
597
+ input_message = BaseMessage.make_user_message(
598
+ role_name='User', content=input_message
599
+ )
600
+
601
+ # Handle tool prompt injection if needed
602
+ if (
603
+ self.is_tools_added()
604
+ and not self.model_type.support_native_tool_calling
605
+ and not self.tool_prompt_added
606
+ ):
607
+ self._inject_tool_prompt()
608
+
609
+ # Add user input to memory
610
+ self.update_memory(input_message, OpenAIBackendRole.USER)
611
+
612
+ try:
613
+ return self._handle_step(response_format, self.single_iteration)
614
+ finally:
615
+ if model_response_format_modified:
616
+ # Reset model config back to original state
617
+ self.model_backend.model_config_dict = self.original_model_dict
618
+
619
+ def _inject_tool_prompt(self) -> None:
620
+ r"""Generate and add the tool prompt to memory."""
621
+ tool_prompt = self._generate_tool_prompt(
622
+ self.model_backend.model_config_dict["tools"]
623
+ )
624
+ tool_msg = BaseMessage.make_assistant_message(
625
+ role_name="Assistant", content=tool_prompt
626
+ )
627
+ self.update_memory(tool_msg, OpenAIBackendRole.SYSTEM)
628
+ self.tool_prompt_added = True
629
+
630
+ def _handle_step(
631
+ self,
632
+ response_format: Optional[Type[BaseModel]],
633
+ single_step: bool,
634
+ ) -> ChatAgentResponse:
635
+ r"""Handles a single or multi-step interaction."""
636
+
637
+ if (
638
+ self.model_backend.model_config_dict.get("tool_choice")
639
+ == "required"
640
+ and not single_step
641
+ ):
642
+ raise ValueError(
643
+ "`tool_choice` cannot be set to `required` for multi-step"
644
+ " mode. To proceed, set `single_iteration` to `True`."
645
+ )
646
+
647
+ # Record function calls made during the session
648
+ tool_call_records: List[FunctionCallingRecord] = []
649
+
650
+ external_tool_request = None
651
+
652
+ while True:
653
+ try:
654
+ openai_messages, num_tokens = self.memory.get_context()
655
+ except RuntimeError as e:
656
+ self.model_backend.model_config_dict = self.original_model_dict
657
+ return self._step_token_exceed(
658
+ e.args[1], tool_call_records, "max_tokens_exceeded"
659
+ )
660
+
661
+ # Prompt engineering approach for structured output for non-native tool calling models
662
+ inject_prompt_for_structured_output = (
663
+ response_format
664
+ and not self.model_type.support_native_structured_output
665
+ )
666
+
667
+ if inject_prompt_for_structured_output:
668
+ # update last openai message
669
+ usr_msg = openai_messages.pop()
670
+ usr_msg["content"] = generate_prompt_for_structured_output(
671
+ response_format,
672
+ usr_msg["content"], # type: ignore [arg-type]
673
+ )
674
+ openai_messages.append(usr_msg)
675
+
676
+ # Process model response
677
+ (
678
+ response,
679
+ output_messages,
680
+ finish_reasons,
681
+ usage_dict,
682
+ response_id,
683
+ ) = self._step_model_response(openai_messages, num_tokens)
684
+
685
+ # Try to parse structured output to return a Pydantic object
686
+ if inject_prompt_for_structured_output and isinstance(
687
+ response, ChatCompletion
688
+ ):
689
+ content = response.choices[0].message.content
690
+ try:
691
+ json_content = json.loads(str(content))
692
+ output_messages[0].parsed = response_format(**json_content) # type: ignore [assignment, misc]
693
+ except json.JSONDecodeError as e:
694
+ logger.error(
695
+ f"Failed in parsing the output into JSON: {e}"
696
+ )
697
+ output_messages[0].parsed = None
698
+ except ValidationError as e:
699
+ logger.warning(
700
+ "Successfully generating JSON response, "
701
+ "but failed in parsing it into Pydantic object :"
702
+ f"{e}, return the JSON response in parsed field"
703
+ )
704
+ output_messages[0].parsed = json_content
705
+
706
+ # Finalize on standard response in multi-step mode
707
+ if self._is_standard_response(response):
708
+ break
709
+
710
+ # Handle tool requests
711
+ tool_request = self._extract_tool_call(response)
712
+ if isinstance(response, ChatCompletion) and tool_request:
713
+ response.choices[0].message.tool_calls = [tool_request]
714
+ tool_call_records.append(
715
+ self._step_tool_call_and_update(response)
716
+ )
717
+
718
+ if tool_request.function.name in self.external_tool_names:
719
+ external_tool_request = tool_request
720
+ info = self._step_get_info(
721
+ output_messages,
722
+ finish_reasons,
723
+ usage_dict,
724
+ response_id,
725
+ tool_call_records,
726
+ num_tokens,
727
+ tool_request,
728
+ )
729
+ self._log_final_output(output_messages)
730
+ self.model_backend.model_config_dict = (
731
+ self.original_model_dict
732
+ )
733
+ return ChatAgentResponse(
734
+ msgs=output_messages,
735
+ terminated=self.terminated,
736
+ info=info,
737
+ )
738
+
739
+ # Single-step mode ends after one iteration
740
+ if single_step:
741
+ break
742
+
743
+ # Optional structured output via function calling
744
+ if (
745
+ response_format
746
+ and not inject_prompt_for_structured_output
747
+ and self.model_type
748
+ not in {
749
+ "gpt-4o",
750
+ "gpt-4o-mini",
751
+ }
752
+ ):
753
+ (
754
+ output_messages,
755
+ finish_reasons,
756
+ usage_dict,
757
+ response_id,
758
+ tool_call,
759
+ num_tokens,
760
+ ) = self._structure_output_with_function(response_format)
761
+ tool_call_records.append(tool_call)
762
+
763
+ # Final info and response
764
+ info = self._step_get_info(
765
+ output_messages,
766
+ finish_reasons,
767
+ usage_dict,
768
+ response_id,
769
+ tool_call_records,
770
+ num_tokens,
771
+ external_tool_request,
772
+ )
773
+ self._log_final_output(output_messages)
774
+ self.model_backend.model_config_dict = self.original_model_dict
775
+ return ChatAgentResponse(
776
+ msgs=output_messages, terminated=self.terminated, info=info
777
+ )
778
+
779
+ def _extract_tool_call(
780
+ self, response: Any
781
+ ) -> Optional[ChatCompletionMessageToolCall]:
782
+ r"""Extract the tool call from the model response, if present.
783
+
784
+ Args:
785
+ response (Any): The model's response object.
786
+
787
+ Returns:
788
+ Optional[ChatCompletionMessageToolCall]: The parsed tool call if
789
+ present, otherwise None.
790
+ """
791
+ # Check if the response contains tool calls
792
+ if (
793
+ self.is_tools_added()
794
+ and not self.model_type.support_native_tool_calling
795
+ and "</function>" in response.choices[0].message.content
796
+ ):
797
+ parsed_content = self._parse_tool_response(
798
+ response.choices[0].message.content
799
+ )
800
+ if parsed_content:
801
+ return ChatCompletionMessageToolCall(
802
+ id=str(uuid.uuid4()),
803
+ function=Function(
804
+ arguments=str(parsed_content["arguments"]).replace(
805
+ "'", '"'
806
+ ),
807
+ name=str(parsed_content["function"]),
808
+ ),
809
+ type="function",
810
+ )
811
+ elif (
812
+ self.is_tools_added()
813
+ and self.model_type.support_native_tool_calling
814
+ and response.choices[0].message.tool_calls
815
+ ):
816
+ return response.choices[0].message.tool_calls[0]
817
+
818
+ # No tool call found
819
+ return None
820
+
821
+ def _is_standard_response(self, response: Any) -> bool:
822
+ r"""Determine if the provided response is a standard reply without
823
+ tool calls.
824
+
825
+ Args:
826
+ response (Any): The response object to evaluate.
827
+
828
+ Returns:
829
+ bool: `True` if the response is a standard reply, `False`
830
+ otherwise.
831
+ """
832
+ if not self.is_tools_added():
833
+ return True
834
+
835
+ if not isinstance(response, ChatCompletion):
836
+ return True
837
+
838
+ if self.model_type.support_native_tool_calling:
839
+ return not response.choices[0].message.tool_calls
840
+
841
+ return "</function>" not in str(
842
+ response.choices[0].message.content or ""
843
+ )
844
+
845
+ def _log_final_output(self, output_messages: List[BaseMessage]) -> None:
846
+ r"""Log final messages or warnings about multiple responses."""
847
+ if len(output_messages) == 1:
848
+ self.record_message(output_messages[0])
849
+ else:
850
+ logger.warning(
851
+ "Multiple messages returned in `step()`. Record "
852
+ "selected message manually using `record_message()`."
853
+ )
854
+
855
+ async def step_async(
856
+ self,
857
+ input_message: Union[BaseMessage, str],
858
+ response_format: Optional[Type[BaseModel]] = None,
859
+ ) -> ChatAgentResponse:
860
+ r"""Performs a single step in the chat session by generating a response
861
+ to the input message. This agent step can call async function calls.
862
+
863
+ Args:
864
+ input_message (Union[BaseMessage, str]): The input message to the
865
+ agent. For BaseMessage input, its `role` field that specifies
866
+ the role at backend may be either `user` or `assistant` but it
867
+ will be set to `user` anyway since for the self agent any
868
+ incoming message is external. For str input, the `role_name`
869
+ would be `User`.
870
+ response_format (Optional[Type[BaseModel]], optional): A pydantic
871
+ model class that includes value types and field descriptions
872
+ used to generate a structured response by LLM. This schema
873
+ helps in defining the expected output format. (default:
874
+ :obj:`None`)
875
+
876
+ Returns:
877
+ ChatAgentResponse: A struct containing the output messages,
878
+ a boolean indicating whether the chat session has terminated,
879
+ and information about the chat session.
880
+ """
881
+ if isinstance(input_message, str):
882
+ input_message = BaseMessage.make_user_message(
883
+ role_name='User', content=input_message
884
+ )
885
+
886
+ self.update_memory(input_message, OpenAIBackendRole.USER)
887
+
888
+ tool_call_records: List[FunctionCallingRecord] = []
889
+ while True:
890
+ try:
891
+ openai_messages, num_tokens = self.memory.get_context()
892
+ except RuntimeError as e:
893
+ return self._step_token_exceed(
894
+ e.args[1], tool_call_records, "max_tokens_exceeded"
895
+ )
896
+
897
+ (
898
+ response,
899
+ output_messages,
900
+ finish_reasons,
901
+ usage_dict,
902
+ response_id,
903
+ ) = self._step_model_response(openai_messages, num_tokens)
904
+
905
+ if (
906
+ not self.is_tools_added()
907
+ or not isinstance(response, ChatCompletion)
908
+ or not response.choices[0].message.tool_calls
909
+ ):
910
+ break
911
+
912
+ # Check for external tool call
913
+ external_tool_request = response.choices[0].message.tool_calls[0]
914
+ if external_tool_request.function.name in self.external_tool_names:
915
+ # if model calls an external tool, directly return the request
916
+ info = self._step_get_info(
917
+ output_messages,
918
+ finish_reasons,
919
+ usage_dict,
920
+ response_id,
921
+ tool_call_records,
922
+ num_tokens,
923
+ external_tool_request,
924
+ )
925
+ return ChatAgentResponse(
926
+ msgs=output_messages, terminated=self.terminated, info=info
927
+ )
928
+
929
+ # Normal function calling
930
+ tool_call_records.append(
931
+ await self._step_tool_call_and_update_async(response)
932
+ )
933
+
934
+ if (
935
+ response_format is not None
936
+ and self.model_type.support_native_tool_calling
937
+ ):
938
+ (
939
+ output_messages,
940
+ finish_reasons,
941
+ usage_dict,
942
+ response_id,
943
+ tool_call_record,
944
+ num_tokens,
945
+ ) = self._structure_output_with_function(response_format)
946
+ tool_call_records.append(tool_call_record)
947
+
948
+ info = self._step_get_info(
949
+ output_messages,
950
+ finish_reasons,
951
+ usage_dict,
952
+ response_id,
953
+ tool_call_records,
954
+ num_tokens,
955
+ )
956
+
957
+ if len(output_messages) == 1:
958
+ # Auto record if the output result is a single message
959
+ self.record_message(output_messages[0])
960
+ else:
961
+ logger.warning(
962
+ "Multiple messages returned in `step()`, message won't be "
963
+ "recorded automatically. Please call `record_message()` to "
964
+ "record the selected message manually."
965
+ )
966
+
967
+ return ChatAgentResponse(
968
+ msgs=output_messages, terminated=self.terminated, info=info
969
+ )
970
+
971
+ def _step_tool_call_and_update(
972
+ self, response: ChatCompletion
973
+ ) -> FunctionCallingRecord:
974
+ r"""Processes a function call within the chat completion response,
975
+ records the function call in the provided list of tool calls and
976
+ updates the memory of the current agent.
977
+
978
+ Args:
979
+ response (ChatCompletion): The response object from the chat
980
+ completion.
981
+
982
+ Returns:
983
+ FunctionCallingRecord: The record of calling the function.
984
+ """
985
+
986
+ # Perform function calling
987
+ func_assistant_msg, func_result_msg, tool_call_record = (
988
+ self._step_tool_call(response)
989
+ )
990
+
991
+ # Update the messages
992
+ self.update_memory(func_assistant_msg, OpenAIBackendRole.ASSISTANT)
993
+ self.update_memory(func_result_msg, OpenAIBackendRole.FUNCTION)
994
+
995
+ return tool_call_record
996
+
997
+ async def _step_tool_call_and_update_async(
998
+ self, response: ChatCompletion
999
+ ) -> FunctionCallingRecord:
1000
+ (
1001
+ func_assistant_msg,
1002
+ func_result_msg,
1003
+ func_record,
1004
+ ) = await self.step_tool_call_async(response)
1005
+
1006
+ self.update_memory(func_assistant_msg, OpenAIBackendRole.ASSISTANT)
1007
+ self.update_memory(func_result_msg, OpenAIBackendRole.FUNCTION)
1008
+
1009
+ return func_record
1010
+
1011
+ def _structure_output_with_function(
1012
+ self, response_format: Type[BaseModel]
1013
+ ) -> Tuple[
1014
+ List[BaseMessage],
1015
+ List[str],
1016
+ Dict[str, int],
1017
+ str,
1018
+ FunctionCallingRecord,
1019
+ int,
1020
+ ]:
1021
+ r"""Internal function of structuring the output of the agent based on
1022
+ the given output schema.
1023
+
1024
+ Args:
1025
+ response_format (Type[BaseModel]): The output schema to use for
1026
+ structuring the output.
1027
+
1028
+ Returns:
1029
+ Tuple[List[BaseMessage], List[str], Dict[str, int], str,
1030
+ FunctionCallingRecord, int]:
1031
+ A tuple containing the output messages, finish reasons, usage
1032
+ dictionary, response ID, function calling record, and number of
1033
+ tokens.
1034
+ """
1035
+ from camel.toolkits import FunctionTool
1036
+
1037
+ schema_json = get_pydantic_object_schema(response_format)
1038
+ func_str = json_to_function_code(schema_json)
1039
+ func_callable = func_string_to_callable(func_str)
1040
+ func = FunctionTool(func_callable)
1041
+
1042
+ original_model_dict = self.model_backend.model_config_dict
1043
+
1044
+ # Replace the original tools with the structuring function
1045
+ self.tool_dict = {func.get_function_name(): func}
1046
+ self.model_backend.model_config_dict = original_model_dict.copy()
1047
+ self.model_backend.model_config_dict["tools"] = [
1048
+ func.get_openai_tool_schema()
1049
+ ]
1050
+ self.model_backend.model_config_dict["tool_choice"] = "required"
1051
+
1052
+ openai_messages, num_tokens = self.memory.get_context()
1053
+ (
1054
+ response,
1055
+ output_messages,
1056
+ finish_reasons,
1057
+ usage_dict,
1058
+ response_id,
1059
+ ) = self._step_model_response(openai_messages, num_tokens)
1060
+
1061
+ if isinstance(response, ChatCompletion):
1062
+ tool_call_record = self._step_tool_call_and_update(response)
1063
+ else:
1064
+ raise ValueError(
1065
+ "Structured output is not supported for stream responses."
1066
+ )
1067
+
1068
+ for base_message_item in output_messages:
1069
+ base_message_item.content = json.dumps(tool_call_record.result)
1070
+
1071
+ # Recover the original tools
1072
+ self.model_backend.model_config_dict = original_model_dict
1073
+
1074
+ return (
1075
+ output_messages,
1076
+ finish_reasons,
1077
+ usage_dict,
1078
+ response_id,
1079
+ tool_call_record,
1080
+ num_tokens,
1081
+ )
1082
+
1083
+ def _step_model_response(
1084
+ self,
1085
+ openai_messages: List[OpenAIMessage],
1086
+ num_tokens: int,
1087
+ ) -> tuple[
1088
+ Union[ChatCompletion, Stream],
1089
+ List[BaseMessage],
1090
+ List[str],
1091
+ Dict[str, int],
1092
+ str,
1093
+ ]:
1094
+ r"""Internal function for agent step model response."""
1095
+
1096
+ response = None
1097
+ # Obtain the model's response
1098
+ for _ in range(len(self.model_backend.models)):
1099
+ try:
1100
+ response = self.model_backend.run(openai_messages)
1101
+ break
1102
+ except Exception as exc:
1103
+ logger.error(
1104
+ f"An error occurred while running model "
1105
+ f"{self.model_backend.model_type}, "
1106
+ f"index: {self.model_backend.current_model_index}",
1107
+ exc_info=exc,
1108
+ )
1109
+ continue
1110
+ if not response:
1111
+ raise ModelProcessingError(
1112
+ "Unable to process messages: none of the provided models "
1113
+ "run succesfully."
1114
+ )
1115
+
1116
+ logger.info(
1117
+ f"Model {self.model_backend.model_type}, "
1118
+ f"index {self.model_backend.current_model_index}, "
1119
+ f"processed these messages: {openai_messages}"
1120
+ )
1121
+
1122
+ if isinstance(response, ChatCompletion):
1123
+ output_messages, finish_reasons, usage_dict, response_id = (
1124
+ self.handle_batch_response(response)
1125
+ )
1126
+ else:
1127
+ output_messages, finish_reasons, usage_dict, response_id = (
1128
+ self.handle_stream_response(response, num_tokens)
1129
+ )
1130
+ return (
1131
+ response,
1132
+ output_messages,
1133
+ finish_reasons,
1134
+ usage_dict,
1135
+ response_id,
1136
+ )
1137
+
1138
+ def _step_get_info(
1139
+ self,
1140
+ output_messages: List[BaseMessage],
1141
+ finish_reasons: List[str],
1142
+ usage_dict: Dict[str, int],
1143
+ response_id: str,
1144
+ tool_calls: List[FunctionCallingRecord],
1145
+ num_tokens: int,
1146
+ external_tool_request: Optional[ChatCompletionMessageToolCall] = None,
1147
+ ) -> Dict[str, Any]:
1148
+ r"""Process the output of a chat step and gather information about the
1149
+ step.
1150
+
1151
+ This method checks for termination conditions, updates the agent's
1152
+ state, and collects information about the chat step, including tool
1153
+ calls and termination reasons.
1154
+
1155
+ Args:
1156
+ output_messages (List[BaseMessage]): The messages generated in
1157
+ this step.
1158
+ finish_reasons (List[str]): The reasons for finishing the
1159
+ generation for each message.
1160
+ usage_dict (Dict[str, int]): Dictionary containing token usage
1161
+ information.
1162
+ response_id (str): The ID of the response from the model.
1163
+ tool_calls (List[FunctionCallingRecord]): Records of function calls
1164
+ made during this step.
1165
+ num_tokens (int): The number of tokens used in this step.
1166
+ external_tool_request (Optional[ChatCompletionMessageToolCall]):
1167
+ Any external tool request made during this step.
1168
+ (default: :obj:`None`)
1169
+
1170
+ Returns:
1171
+ Dict[str, Any]: A dictionary containing information about the chat
1172
+ step, including termination status, reasons, and tool call
1173
+ information.
1174
+
1175
+ Note:
1176
+ This method iterates over all response terminators and checks if
1177
+ any of them signal termination. If a terminator signals
1178
+ termination, the agent's state is updated accordingly, and the
1179
+ termination reason is recorded.
1180
+ """
1181
+ termination = [
1182
+ terminator.is_terminated(output_messages)
1183
+ for terminator in self.response_terminators
1184
+ ]
1185
+ # Terminate the agent if any of the terminator terminates
1186
+ self.terminated, termination_reason = next(
1187
+ (
1188
+ (terminated, termination_reason)
1189
+ for terminated, termination_reason in termination
1190
+ if terminated
1191
+ ),
1192
+ (False, None),
1193
+ )
1194
+ # For now only retain the first termination reason
1195
+ if self.terminated and termination_reason is not None:
1196
+ finish_reasons = [termination_reason] * len(finish_reasons)
1197
+
1198
+ info = self.get_info(
1199
+ response_id,
1200
+ usage_dict,
1201
+ finish_reasons,
1202
+ num_tokens,
1203
+ tool_calls,
1204
+ external_tool_request,
1205
+ )
1206
+ return info
1207
+
1208
+ def handle_batch_response(
1209
+ self, response: ChatCompletion
1210
+ ) -> Tuple[List[BaseMessage], List[str], Dict[str, int], str]:
1211
+ r"""Process a batch response from the model and extract the necessary
1212
+ information.
1213
+
1214
+ Args:
1215
+ response (ChatCompletion): Model response.
1216
+
1217
+ Returns:
1218
+ tuple: A tuple of list of output `ChatMessage`, list of
1219
+ finish reasons, usage dictionary, and response id.
1220
+ """
1221
+ output_messages: List[BaseMessage] = []
1222
+ for choice in response.choices:
1223
+ chat_message = BaseMessage(
1224
+ role_name=self.role_name,
1225
+ role_type=self.role_type,
1226
+ meta_dict=dict(),
1227
+ content=choice.message.content or "",
1228
+ parsed=getattr(choice.message, 'parsed', None),
1229
+ )
1230
+ # Process log probabilities and append to the message meta information
1231
+ if choice.logprobs is not None:
1232
+ tokens_logprobs = choice.logprobs.content
1233
+
1234
+ if tokens_logprobs is not None:
1235
+ # Extract and structure logprob information
1236
+ logprobs_info = [
1237
+ {
1238
+ "token": token_logprob.token,
1239
+ "logprob": token_logprob.logprob,
1240
+ "top_logprobs": [
1241
+ (top_logprob.token, top_logprob.logprob)
1242
+ for top_logprob in token_logprob.top_logprobs
1243
+ ],
1244
+ }
1245
+ for token_logprob in tokens_logprobs
1246
+ ]
1247
+ # Ensure meta_dict exists before adding logprobs info
1248
+ if chat_message.meta_dict is None:
1249
+ chat_message.meta_dict = {}
1250
+ chat_message.meta_dict["logprobs_info"] = logprobs_info
1251
+ # Append the processed chat message to output
1252
+ output_messages.append(chat_message)
1253
+
1254
+ finish_reasons = [
1255
+ str(choice.finish_reason) for choice in response.choices
1256
+ ]
1257
+ usage = (
1258
+ self._safe_model_dump(response.usage)
1259
+ if response.usage is not None
1260
+ else {}
1261
+ )
1262
+ return (
1263
+ output_messages,
1264
+ finish_reasons,
1265
+ usage,
1266
+ response.id,
1267
+ )
1268
+
1269
+ def _safe_model_dump(self, obj) -> dict:
1270
+ r"""Safely dump a Pydantic model to a dictionary.
1271
+
1272
+ This method attempts to use the `model_dump` method if available,
1273
+ otherwise it falls back to the `dict` method.
1274
+
1275
+ Args:
1276
+ obj: The Pydantic model instance to be dumped.
1277
+
1278
+ Returns:
1279
+ dict: A dictionary representation of the Pydantic model.
1280
+ """
1281
+ # Check if the `model_dump` method exists (Pydantic v2)
1282
+ if hasattr(obj, 'model_dump'):
1283
+ return obj.model_dump()
1284
+ # Fallback to `dict()` method (Pydantic v1)
1285
+ elif hasattr(obj, 'dict'):
1286
+ return obj.dict()
1287
+ else:
1288
+ raise TypeError("The object is not a Pydantic model")
1289
+
1290
+ def handle_stream_response(
1291
+ self,
1292
+ response: Stream[ChatCompletionChunk],
1293
+ prompt_tokens: int,
1294
+ ) -> Tuple[List[BaseMessage], List[str], Dict[str, int], str]:
1295
+ r"""Process a stream response from the model and extract the necessary
1296
+ information.
1297
+
1298
+ Args:
1299
+ response (Stream[ChatCompletionChunk]): Model response.
1300
+ prompt_tokens (int): Number of input prompt tokens.
1301
+
1302
+ Returns:
1303
+ tuple: A tuple of list of output `ChatMessage`, list of
1304
+ finish reasons, usage dictionary, and response id.
1305
+ """
1306
+ content_dict: defaultdict = defaultdict(lambda: "")
1307
+ finish_reasons_dict: defaultdict = defaultdict(lambda: "")
1308
+ output_messages: List[BaseMessage] = []
1309
+ response_id: str = ""
1310
+ # All choices in one response share one role
1311
+ for chunk in response:
1312
+ response_id = chunk.id
1313
+ for choice in chunk.choices:
1314
+ index = choice.index
1315
+ delta = choice.delta
1316
+ if delta.content is not None:
1317
+ # When response has not been stopped
1318
+ # Notice that only the first chunk_dict has the "role"
1319
+ content_dict[index] += delta.content
1320
+ if choice.finish_reason:
1321
+ finish_reasons_dict[index] = choice.finish_reason
1322
+ chat_message = BaseMessage(
1323
+ role_name=self.role_name,
1324
+ role_type=self.role_type,
1325
+ meta_dict=dict(),
1326
+ content=content_dict[index],
1327
+ )
1328
+ output_messages.append(chat_message)
1329
+ finish_reasons = [
1330
+ finish_reasons_dict[i] for i in range(len(finish_reasons_dict))
1331
+ ]
1332
+ usage_dict = self.get_usage_dict(output_messages, prompt_tokens)
1333
+ return output_messages, finish_reasons, usage_dict, response_id
1334
+
1335
+ def _step_token_exceed(
1336
+ self,
1337
+ num_tokens: int,
1338
+ tool_calls: List[FunctionCallingRecord],
1339
+ termination_reason: str,
1340
+ ) -> ChatAgentResponse:
1341
+ r"""Return trivial response containing number of tokens and information
1342
+ of called functions when the number of tokens exceeds the limit.
1343
+
1344
+ Args:
1345
+ num_tokens (int): Number of tokens in the messages.
1346
+ tool_calls (List[FunctionCallingRecord]): List of information
1347
+ objects of functions called in the current step.
1348
+ termination_reason (str): String of termination reason.
1349
+
1350
+ Returns:
1351
+ ChatAgentResponse: The struct containing trivial outputs and
1352
+ information about token number and called functions.
1353
+ """
1354
+ self.terminated = True
1355
+ output_messages: List[BaseMessage] = []
1356
+
1357
+ info = self.get_info(
1358
+ None,
1359
+ None,
1360
+ [termination_reason],
1361
+ num_tokens,
1362
+ tool_calls,
1363
+ )
1364
+
1365
+ return ChatAgentResponse(
1366
+ msgs=output_messages,
1367
+ terminated=self.terminated,
1368
+ info=info,
1369
+ )
1370
+
1371
+ def _step_tool_call(
1372
+ self,
1373
+ response: ChatCompletion,
1374
+ ) -> Tuple[
1375
+ FunctionCallingMessage, FunctionCallingMessage, FunctionCallingRecord
1376
+ ]:
1377
+ r"""Execute the function with arguments following the model's response.
1378
+
1379
+ Args:
1380
+ response (Dict[str, Any]): The response obtained by calling the
1381
+ model.
1382
+
1383
+ Returns:
1384
+ tuple: A tuple consisting of two obj:`FunctionCallingMessage`,
1385
+ one about the arguments and the other about the execution
1386
+ result, and a struct for logging information about this
1387
+ function call.
1388
+ """
1389
+ choice = response.choices[0]
1390
+ if choice.message.tool_calls is None:
1391
+ raise RuntimeError("Tool call is None")
1392
+ func_name = choice.message.tool_calls[0].function.name
1393
+
1394
+ arguments_str = choice.message.tool_calls[0].function.arguments
1395
+ args = self._safe_json_loads(arguments_str)
1396
+
1397
+ tool = self.tool_dict[func_name]
1398
+ result = tool(**args)
1399
+ tool_call_id = choice.message.tool_calls[0].id
1400
+
1401
+ assist_msg = FunctionCallingMessage(
1402
+ role_name=self.role_name,
1403
+ role_type=self.role_type,
1404
+ meta_dict=None,
1405
+ content="",
1406
+ func_name=func_name,
1407
+ args=args,
1408
+ tool_call_id=tool_call_id,
1409
+ )
1410
+ func_msg = FunctionCallingMessage(
1411
+ role_name=self.role_name,
1412
+ role_type=self.role_type,
1413
+ meta_dict=None,
1414
+ content="",
1415
+ func_name=func_name,
1416
+ result=result,
1417
+ tool_call_id=tool_call_id,
1418
+ )
1419
+
1420
+ # Record information about this function call
1421
+ func_record = FunctionCallingRecord(
1422
+ func_name=func_name,
1423
+ args=args,
1424
+ result=result,
1425
+ tool_call_id=tool_call_id,
1426
+ )
1427
+ return assist_msg, func_msg, func_record
1428
+
1429
+ def _safe_json_loads(self, arguments_str):
1430
+ # Replace Python types with their JSON equivalents
1431
+ arguments_str = arguments_str.replace("None", "null")
1432
+ arguments_str = arguments_str.replace("True", "true")
1433
+ arguments_str = arguments_str.replace("False", "false")
1434
+
1435
+ # Attempt to parse the corrected string
1436
+ try:
1437
+ return json.loads(arguments_str)
1438
+ except json.JSONDecodeError as e:
1439
+ raise ValueError(f"Invalid JSON format: {e}")
1440
+
1441
+ async def step_tool_call_async(
1442
+ self,
1443
+ response: ChatCompletion,
1444
+ ) -> Tuple[
1445
+ FunctionCallingMessage, FunctionCallingMessage, FunctionCallingRecord
1446
+ ]:
1447
+ r"""Execute the async function with arguments following the model's
1448
+ response.
1449
+
1450
+ Args:
1451
+ response (Dict[str, Any]): The response obtained by calling the
1452
+ model.
1453
+
1454
+ Returns:
1455
+ tuple: A tuple consisting of two obj:`FunctionCallingMessage`,
1456
+ one about the arguments and the other about the execution
1457
+ result, and a struct for logging information about this
1458
+ function call.
1459
+ """
1460
+ # Note that when function calling is enabled, `n` is set to 1.
1461
+ choice = response.choices[0]
1462
+ if choice.message.tool_calls is None:
1463
+ raise RuntimeError("Tool call is None")
1464
+ func_name = choice.message.tool_calls[0].function.name
1465
+
1466
+ args = json.loads(choice.message.tool_calls[0].function.arguments)
1467
+ tool = self.tool_dict[func_name]
1468
+ result = await tool(**args)
1469
+ tool_call_id = choice.message.tool_calls[0].id
1470
+
1471
+ assist_msg = FunctionCallingMessage(
1472
+ role_name=self.role_name,
1473
+ role_type=self.role_type,
1474
+ meta_dict=None,
1475
+ content="",
1476
+ func_name=func_name,
1477
+ args=args,
1478
+ tool_call_id=tool_call_id,
1479
+ )
1480
+ func_msg = FunctionCallingMessage(
1481
+ role_name=self.role_name,
1482
+ role_type=self.role_type,
1483
+ meta_dict=None,
1484
+ content="",
1485
+ func_name=func_name,
1486
+ result=result,
1487
+ tool_call_id=tool_call_id,
1488
+ )
1489
+
1490
+ # Record information about this function call
1491
+ func_record = FunctionCallingRecord(
1492
+ func_name=func_name,
1493
+ args=args,
1494
+ result=result,
1495
+ tool_call_id=tool_call_id,
1496
+ )
1497
+ return assist_msg, func_msg, func_record
1498
+
1499
+ def get_usage_dict(
1500
+ self, output_messages: List[BaseMessage], prompt_tokens: int
1501
+ ) -> Dict[str, int]:
1502
+ r"""Get usage dictionary when using the stream mode.
1503
+
1504
+ Args:
1505
+ output_messages (list): List of output messages.
1506
+ prompt_tokens (int): Number of input prompt tokens.
1507
+
1508
+ Returns:
1509
+ dict: Usage dictionary.
1510
+ """
1511
+ encoding = get_model_encoding(self.model_type.value_for_tiktoken)
1512
+ completion_tokens = 0
1513
+ for message in output_messages:
1514
+ completion_tokens += len(encoding.encode(message.content))
1515
+ usage_dict = dict(
1516
+ completion_tokens=completion_tokens,
1517
+ prompt_tokens=prompt_tokens,
1518
+ total_tokens=completion_tokens + prompt_tokens,
1519
+ )
1520
+ return usage_dict
1521
+
1522
+ def add_model_scheduling_strategy(self, name: str, strategy_fn: Callable):
1523
+ r"""Add a scheduling strategy method provided by the user to ModelManager.
1524
+
1525
+ Args:
1526
+ name (str): The name of the strategy.
1527
+ strategy_fn (Callable): The scheduling strategy function.
1528
+ """
1529
+ self.model_backend.add_strategy(name, strategy_fn)
1530
+
1531
+ def __repr__(self) -> str:
1532
+ r"""Returns a string representation of the :obj:`ChatAgent`.
1533
+
1534
+ Returns:
1535
+ str: The string representation of the :obj:`ChatAgent`.
1536
+ """
1537
+ return (
1538
+ f"ChatAgent({self.role_name}, {self.role_type}, {self.model_type})"
1539
+ )
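A minimal standalone sketch of the `_safe_json_loads` behaviour defined above (illustrative only, not part of the committed file; it mirrors the naive literal replacement the method performs):

import json

def safe_json_loads(arguments_str: str) -> dict:
    # Mirror _safe_json_loads: rewrite Python literals into JSON literals,
    # then parse. Note that the naive replace would also touch occurrences
    # of "None"/"True"/"False" inside string values.
    arguments_str = (
        arguments_str.replace("None", "null")
        .replace("True", "true")
        .replace("False", "false")
    )
    try:
        return json.loads(arguments_str)
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON format: {e}")

# Example: tool-call arguments emitted with Python-style literals.
print(safe_json_loads('{"verbose": True, "limit": None}'))
# -> {'verbose': True, 'limit': None}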
Paper2Poster/camel/agents/critic_agent.py ADDED
@@ -0,0 +1,202 @@
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+ import random
15
+ import warnings
16
+ from typing import Any, Dict, Optional, Sequence
17
+
18
+ from colorama import Fore
19
+
20
+ from camel.agents.chat_agent import ChatAgent
21
+ from camel.memories import AgentMemory
22
+ from camel.messages import BaseMessage
23
+ from camel.models import BaseModelBackend
24
+ from camel.responses import ChatAgentResponse
25
+ from camel.utils import get_first_int, print_text_animated
26
+
27
+ # AgentOps decorator setting
28
+ try:
29
+ import os
30
+
31
+ if os.getenv("AGENTOPS_API_KEY") is not None:
32
+ from agentops import track_agent
33
+ else:
34
+ raise ImportError
35
+ except (ImportError, AttributeError):
36
+ from camel.utils import track_agent
37
+
38
+
39
+ @track_agent(name="CriticAgent")
40
+ class CriticAgent(ChatAgent):
41
+ r"""A class for the critic agent that assists in selecting an option.
42
+
43
+ Args:
44
+ system_message (BaseMessage): The system message for the critic
45
+ agent.
46
+ model (BaseModelBackend, optional): The model backend to use for
47
+ generating responses. (default: :obj:`OpenAIModel` with
48
+ `GPT_4O_MINI`)
49
+ message_window_size (int, optional): The maximum number of previous
50
+ messages to include in the context window. If `None`, no windowing
51
+ is performed. (default: :obj:`6`)
52
+ retry_attempts (int, optional): The number of retry attempts if the
53
+ critic fails to return a valid option. (default: :obj:`2`)
54
+ verbose (bool, optional): Whether to print the critic's messages.
55
+ logger_color (Any): The color of the menu options displayed to the
56
+ user. (default: :obj:`Fore.MAGENTA`)
57
+ """
58
+
59
+ def __init__(
60
+ self,
61
+ system_message: BaseMessage,
62
+ model: Optional[BaseModelBackend] = None,
63
+ memory: Optional[AgentMemory] = None,
64
+ message_window_size: int = 6,
65
+ retry_attempts: int = 2,
66
+ verbose: bool = False,
67
+ logger_color: Any = Fore.MAGENTA,
68
+ ) -> None:
69
+ super().__init__(
70
+ system_message,
71
+ model=model,
72
+ memory=memory,
73
+ message_window_size=message_window_size,
74
+ )
75
+ self.options_dict: Dict[str, str] = dict()
76
+ self.retry_attempts = retry_attempts
77
+ self.verbose = verbose
78
+ self.logger_color = logger_color
79
+
80
+ def flatten_options(self, messages: Sequence[BaseMessage]) -> str:
81
+ r"""Flattens the options to the critic.
82
+
83
+ Args:
84
+ messages (Sequence[BaseMessage]): A list of `BaseMessage` objects.
85
+
86
+ Returns:
87
+ str: A string containing the flattened options to the critic.
88
+ """
89
+ options = [message.content for message in messages]
90
+ flatten_options = (
91
+ f"> Proposals from "
92
+ f"{messages[0].role_name} ({messages[0].role_type}). "
93
+ "Please choose an option:\n"
94
+ )
95
+ for index, option in enumerate(options):
96
+ flatten_options += f"Option {index + 1}:\n{option}\n\n"
97
+ self.options_dict[str(index + 1)] = option
98
+ format = (
99
+ f"Please first enter your choice ([1-{len(self.options_dict)}]) "
100
+ "and then your explanation and comparison: "
101
+ )
102
+ return flatten_options + format
103
+
104
+ def get_option(self, input_message: BaseMessage) -> str:
105
+ r"""Gets the option selected by the critic.
106
+
107
+ Args:
108
+ input_message (BaseMessage): A `BaseMessage` object representing
109
+ the input message.
110
+
111
+ Returns:
112
+ str: The option selected by the critic.
113
+ """
114
+ # TODO: Add support for editing options by the critic.
115
+ msg_content = input_message.content
116
+ i = 0
117
+ while i < self.retry_attempts:
118
+ critic_response = self.step(input_message)
119
+
120
+ if critic_response.msgs is None or len(critic_response.msgs) == 0:
121
+ raise RuntimeError("Got None critic messages.")
122
+ if critic_response.terminated:
123
+ raise RuntimeError("Critic step failed.")
124
+
125
+ critic_msg = critic_response.msg
126
+ if self.verbose:
127
+ print_text_animated(
128
+ self.logger_color + "\n> Critic response: "
129
+ f"\x1b[3m{critic_msg.content}\x1b[0m\n"
130
+ )
131
+ choice = self.parse_critic(critic_msg)
132
+
133
+ if choice in self.options_dict:
134
+ return self.options_dict[choice]
135
+ else:
136
+ input_message = BaseMessage(
137
+ role_name=input_message.role_name,
138
+ role_type=input_message.role_type,
139
+ meta_dict=input_message.meta_dict,
140
+ content="> Invalid choice. Please choose again.\n"
141
+ + msg_content,
142
+ )
143
+ i += 1
144
+ warnings.warn(
145
+ "Critic failed to get a valid option. "
146
+ f"After {self.retry_attempts} attempts. "
147
+ "Returning a random option."
148
+ )
149
+ return random.choice(list(self.options_dict.values()))
150
+
151
+ def parse_critic(self, critic_msg: BaseMessage) -> Optional[str]:
152
+ r"""Parses the critic's message and extracts the choice.
153
+
154
+ Args:
155
+ critic_msg (BaseMessage): A `BaseMessage` object representing the
156
+ critic's response.
157
+
158
+ Returns:
159
+ Optional[str]: The critic's choice as a string, or None if the
160
+ message could not be parsed.
161
+ """
162
+ choice = str(get_first_int(critic_msg.content))
163
+ return choice
164
+
165
+ def reduce_step(
166
+ self,
167
+ input_messages: Sequence[BaseMessage],
168
+ ) -> ChatAgentResponse:
169
+ r"""Performs one step of the conversation by flattening options to the
170
+ critic, getting the option, and parsing the choice.
171
+
172
+ Args:
173
+ input_messages (Sequence[BaseMessage]): A list of BaseMessage
174
+ objects.
175
+
176
+ Returns:
177
+ ChatAgentResponse: A `ChatAgentResponse` object includes the
178
+ critic's choice.
179
+ """
180
+ meta_chat_message = BaseMessage(
181
+ role_name=input_messages[0].role_name,
182
+ role_type=input_messages[0].role_type,
183
+ meta_dict=input_messages[0].meta_dict,
184
+ content="",
185
+ )
186
+
187
+ flatten_options = self.flatten_options(input_messages)
188
+ if self.verbose:
189
+ print_text_animated(
190
+ self.logger_color + f"\x1b[3m{flatten_options}\x1b[0m\n"
191
+ )
192
+ input_msg = meta_chat_message.create_new_instance(flatten_options)
193
+
194
+ option = self.get_option(input_msg)
195
+ output_msg = meta_chat_message.create_new_instance(option)
196
+
197
+ # TODO: The return `info` can be improved.
198
+ return ChatAgentResponse(
199
+ msgs=[output_msg],
200
+ terminated=False,
201
+ info={},
202
+ )
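An illustrative usage sketch for `CriticAgent.reduce_step` (not part of the committed file; it assumes the camel package is installed and a default model backend is configured via environment variables):

from camel.agents.critic_agent import CriticAgent
from camel.messages import BaseMessage
from camel.types import RoleType

def make_msg(content: str) -> BaseMessage:
    # Helper for building assistant proposals.
    return BaseMessage(
        role_name="Assistant",
        role_type=RoleType.ASSISTANT,
        meta_dict=None,
        content=content,
    )

critic = CriticAgent(
    system_message=make_msg("You select the best option."),
    verbose=True,
)
options = [
    make_msg("Plan A: survey related work first."),
    make_msg("Plan B: prototype the pipeline first."),
]
# reduce_step flattens the options, queries the critic, and returns its choice.
choice = critic.reduce_step(options)
print(choice.msg.content)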
Paper2Poster/camel/agents/deductive_reasoner_agent.py ADDED
@@ -0,0 +1,303 @@
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+ import re
15
+ from typing import Dict, List, Optional, Union
16
+
17
+ from camel.agents.chat_agent import ChatAgent
18
+ from camel.logger import get_logger
19
+ from camel.messages import BaseMessage
20
+ from camel.models import BaseModelBackend
21
+ from camel.prompts import TextPrompt
22
+ from camel.types import RoleType
23
+
24
+ logger = get_logger(__name__)
25
+
26
+ # AgentOps decorator setting
27
+ try:
28
+ import os
29
+
30
+ if os.getenv("AGENTOPS_API_KEY") is not None:
31
+ from agentops import track_agent
32
+ else:
33
+ raise ImportError
34
+ except (ImportError, AttributeError):
35
+ from camel.utils import track_agent
36
+
37
+
38
+ @track_agent(name="DeductiveReasonerAgent")
39
+ class DeductiveReasonerAgent(ChatAgent):
40
+ r"""An agent responsible for deductive reasoning. Model of deductive
41
+ reasoning:
42
+ - L: A ⊕ C -> q * B
43
+ - A represents the known starting state.
44
+ - B represents the known target state.
45
+ - C represents the conditions required to transition from A to B.
46
+ - Q represents the quality or effectiveness of the transition from
47
+ A to B.
48
+ - L represents the path or process from A to B.
49
+
50
+ Args:
51
+ model (BaseModelBackend, optional): The model backend to use for
52
+ generating responses. (default: :obj:`OpenAIModel` with
53
+ `GPT_4O_MINI`)
54
+ """
55
+
56
+ def __init__(
57
+ self,
58
+ model: Optional[BaseModelBackend] = None,
59
+ ) -> None:
60
+ system_message = BaseMessage(
61
+ role_name="Insight Agent",
62
+ role_type=RoleType.ASSISTANT,
63
+ meta_dict=None,
64
+ content="You assign roles based on tasks.",
65
+ )
66
+ super().__init__(system_message, model=model)
67
+
68
+ def deduce_conditions_and_quality(
69
+ self,
70
+ starting_state: str,
71
+ target_state: str,
72
+ role_descriptions_dict: Optional[Dict[str, str]] = None,
73
+ ) -> Dict[str, Union[List[str], Dict[str, str]]]:
74
+ r"""Derives the conditions and quality from the starting state and the
75
+ target state based on the model of the deductive reasoning and the
76
+ knowledge base. It can optionally consider the roles involved in the
77
+ scenario, which allows tailoring the output more closely to the AI
78
+ agent's environment.
79
+
80
+ Args:
81
+ starting_state (str): The initial or starting state from which
82
+ conditions are deduced.
83
+ target_state (str): The target state of the task.
86
+ role_descriptions_dict (Optional[Dict[str, str]], optional): A
87
+ dictionary describing the roles involved in the scenario. This
88
+ is optional and can be used to provide a context for the
89
+ CAMEL's role-playing, enabling the generation of more relevant
90
+ and tailored conditions and quality assessments. This could be
91
+ generated using a `RoleAssignmentAgent()` or defined manually
92
+ by the user. (default: :obj:`None`)
93
+
94
+ Returns:
95
+ Dict[str, Union[List[str], Dict[str, str]]]: A dictionary with the
96
+ extracted data from the message. The dictionary contains three
97
+ keys:
98
+ - 'conditions': A dictionary where each key is a condition ID and
99
+ each value is the corresponding condition text.
100
+ - 'labels': A list of label strings extracted from the message.
101
+ - 'evaluate_quality': A string with the quality assessment extracted
102
+ from the message.
103
+ """
104
+ self.reset()
105
+
106
+ deduce_prompt = """You are a deductive reasoner. You are tasked to
107
+ complete the TASK based on the THOUGHT OF DEDUCTIVE REASONING, the
108
+ STARTING STATE A and the TARGET STATE B. You are given the CONTEXT
109
+ CONTENT to help you complete the TASK.
110
+ Your answer MUST strictly adhere to the structure of ANSWER TEMPLATE, ONLY
111
+ fill in the BLANKs, and DO NOT alter or modify any other part of the template
112
+
113
+ ===== MODELING OF DEDUCTIVE REASONING =====
114
+ You are tasked with understanding a mathematical model based on the components
115
+ ${A, B, C, Q, L}$. In this model: ``L: A ⊕ C -> q * B``.
116
+ - $A$ represents the known starting state.
117
+ - $B$ represents the known target state.
118
+ - $C$ represents the conditions required to transition from $A$ to $B$.
119
+ - $Q$ represents the quality or effectiveness of the transition from $A$ to
120
+ $B$.
121
+ - $L$ represents the path or process from $A$ to $B$.
122
+
123
+ ===== THOUGHT OF DEDUCTIVE REASONING =====
124
+ 1. Define the Parameters of A and B:
125
+ - Characterization: Before delving into transitions, thoroughly understand
126
+ the nature and boundaries of both $A$ and $B$. This includes the type,
127
+ properties, constraints, and possible interactions between the two.
128
+ - Contrast and Compare: Highlight the similarities and differences between
129
+ $A$ and $B$. This comparative analysis will give an insight into what
130
+ needs changing and what remains constant.
131
+ 2. Historical & Empirical Analysis:
132
+ - Previous Transitions according to the Knowledge Base of GPT: (if
133
+ applicable) Extract conditions and patterns from the historical instances
134
+ where a similar transition from a state comparable to $A$ moved towards
135
+ $B$.
136
+ - Scientific Principles: (if applicable) Consider the underlying
137
+ scientific principles governing or related to the states and their
138
+ transition. For example, if $A$ and $B$ are physical states, laws of
139
+ physics might apply.
140
+ 3. Logical Deduction of Conditions ($C$):
141
+ - Direct Path Analysis: What are the immediate and direct conditions
142
+ required to move from $A$ to $B$?
143
+ - Intermediate States: Are there states between $A$ and $B$ that must be
144
+ traversed or can be used to make the transition smoother or more
145
+ efficient? If yes, what is the content?
146
+ - Constraints & Limitations: Identify potential barriers or restrictions
147
+ in moving from $A$ to $B$. These can be external (e.g., environmental
148
+ factors) or internal (properties of $A$ or $B$).
149
+ - Resource and Information Analysis: What resources and information are
150
+ required for the transition? This could be time, entity, factor, code
151
+ language, software platform, unknowns, etc.
152
+ - External Influences: Consider socio-economic, political, or
153
+ environmental factors (if applicable) that could influence the transition
154
+ conditions.
155
+ - Creative/Heuristic Reasoning: Open your mind to multiple possible $C$'s,
156
+ no matter how unconventional they might seem. Utilize analogies,
157
+ metaphors, or brainstorming techniques to envision possible conditions or
158
+ paths from $A$ to $B$.
159
+ - The conditions $C$ should be multiple but in one sentence. And each
160
+ condition should be concerned with one aspect/entity.
161
+ 4. Entity/Label Recognition of Conditions ($C$):
162
+ - Identify and categorize entities of Conditions ($C$) such as the names,
163
+ locations, dates, specific technical terms or contextual parameters that
164
+ might be associated with events, innovations post-2022.
165
+ - The output of the entities/labels will be used as tags or labels for
166
+ semantic similarity searches. The entities/labels may be the words, or
167
+ phrases, each of them should contain valuable, high information entropy
168
+ information, and should be independent.
169
+ - Ensure that the identified entities are formatted in a manner suitable
170
+ for database indexing and retrieval. Organize the entities into
171
+ categories, and combine the category with its instance into a continuous
172
+ phrase, without using colons or other separators.
173
+ - Format these entities for database indexing: output the category rather
174
+ than its instance/content into a continuous phrase. For example, instead
175
+ of "Jan. 02", identify it as "Event time".
176
+ 5. Quality Assessment ($Q$):
177
+ - Efficiency: How efficient is the transition from $A$ to $B$, which
178
+ measures the resources used versus the desired outcome?
179
+ - Effectiveness: Did the transition achieve the desired outcome or was the
180
+ target state achieved as intended?
181
+ - Safety & Risks: Assess any risks associated with the transition and the
182
+ measures to mitigate them.
183
+ - Feedback Mechanisms: Incorporate feedback loops to continuously monitor
184
+ and adjust the quality of transition, making it more adaptive.
185
+ 6. Iterative Evaluation:
186
+ - Test & Refine: Based on the initially deduced conditions and assessed
187
+ quality, iterate the process to refine and optimize the transition. This
188
+ might involve tweaking conditions, employing different paths, or changing
189
+ resources.
190
+ - Feedback Integration: Use feedback to make improvements and increase the
191
+ quality of the transition.
192
+ 7. Real-world scenarios often present challenges that may not be captured by
193
+ models and frameworks. While using the model, maintain an adaptive mindset:
194
+ - Scenario Exploration: Continuously imagine various possible scenarios,
195
+ both positive and negative, to prepare for unexpected events.
196
+ - Flexibility: Be prepared to modify conditions ($C$) or alter the path/
197
+ process ($L$) if unforeseen challenges arise.
198
+ - Feedback Integration: Rapidly integrate feedback from actual
199
+ implementations to adjust the model's application, ensuring relevancy and
200
+ effectiveness.
201
+
202
+ ===== TASK =====
203
+ Given the starting state $A$ and the target state $B$, assuming that a path
204
+ $L$ always exists between $A$ and $B$, how can one deduce or identify the
205
+ necessary conditions $C$ and the quality $Q$ of the transition?
206
+
207
+ ===== STARTING STATE $A$ =====
208
+ {starting_state}
209
+
210
+ ===== TARGET STATE $B$ =====
211
+ {target_state}
212
+
213
+ {role_with_description_prompt}
214
+ ===== ANSWER TEMPLATE =====
215
+ - Characterization and comparison of $A$ and $B$:\n<BLANK>
216
+ - Historical & Empirical Analysis:\n<BLANK>/None
217
+ - Logical Deduction of Conditions ($C$) (multiple conditions can be deduced):
218
+ condition <NUM>:
219
+ <BLANK>.
220
+ - Entity/Label Recognition of Conditions:\n[<BLANK>, <BLANK>, ...] (include
221
+ square brackets)
222
+ - Quality Assessment ($Q$) (do not use symbols):
223
+ <BLANK>.
224
+ - Iterative Evaluation:\n<BLANK>/None"""
225
+
226
+ if role_descriptions_dict is not None:
227
+ role_names = role_descriptions_dict.keys()
228
+ role_with_description_prompt = (
229
+ "===== ROLES WITH DESCRIPTIONS =====\n"
230
+ + "\n".join(
231
+ f"{role_name}:\n{role_descriptions_dict[role_name]}\n"
232
+ for role_name in role_names
233
+ )
234
+ + "\n\n"
235
+ )
236
+ else:
237
+ role_with_description_prompt = ""
238
+ deduce_prompt = TextPrompt(deduce_prompt)
239
+
240
+ deduce = deduce_prompt.format(
241
+ starting_state=starting_state,
242
+ target_state=target_state,
243
+ role_with_description_prompt=role_with_description_prompt,
244
+ )
245
+
246
+ conditions_and_quality_generation_msg = BaseMessage.make_user_message(
247
+ role_name="Deductive Reasoner", content=deduce
248
+ )
249
+
250
+ response = self.step(
251
+ input_message=conditions_and_quality_generation_msg
252
+ )
253
+
254
+ if response.terminated:
255
+ raise RuntimeError(
256
+ "Deduction failed. Error:\n" + f"{response.info}"
257
+ )
258
+ msg: BaseMessage = response.msg
259
+ logger.info(f"Message content:\n{msg.content}")
260
+
261
+ # Extract the conditions from the message
262
+ conditions_dict = {
263
+ f"condition {i}": cdt.replace("<", "")
264
+ .replace(">", "")
265
+ .strip()
266
+ .strip('\n')
267
+ for i, cdt in re.findall(
268
+ r"condition (\d+):\s*(.+?)(?=condition \d+|- Entity)",
269
+ msg.content,
270
+ re.DOTALL,
271
+ )
272
+ }
273
+
274
+ # Extract the labels from the message
275
+ labels = [
276
+ label.strip().strip('\n').strip("\"'")
277
+ for label in re.findall(
278
+ r"Entity/Label Recognition of Conditions:\n\[(.+?)\]",
279
+ msg.content,
280
+ re.DOTALL,
281
+ )[0].split(",")
282
+ ]
283
+
284
+ # Extract the quality from the message
285
+ quality = next(
286
+ q.strip().strip('\n')
287
+ for q in re.findall(
288
+ r"Quality Assessment \(\$Q\$\) \(do not use symbols\):"
289
+ r"\n(.+?)- Iterative",
290
+ msg.content,
291
+ re.DOTALL,
292
+ )
293
+ )
294
+
295
+ # Convert them into JSON format
296
+ conditions_and_quality_json: Dict[
297
+ str, Union[List[str], Dict[str, str]]
298
+ ] = {}
299
+ conditions_and_quality_json["conditions"] = conditions_dict
300
+ conditions_and_quality_json["labels"] = labels
301
+ conditions_and_quality_json["evaluate_quality"] = quality
302
+
303
+ return conditions_and_quality_json
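An illustrative call to `deduce_conditions_and_quality` (not part of the committed file; assumes the camel package is installed and a model backend is available):

from camel.agents.deductive_reasoner_agent import DeductiveReasonerAgent

reasoner = DeductiveReasonerAgent()
result = reasoner.deduce_conditions_and_quality(
    starting_state="A research paper in PDF form",
    target_state="A one-page conference poster",
)
# The returned dictionary mirrors the parsing above: deduced conditions,
# extracted labels, and the quality assessment under "evaluate_quality".
print(result["conditions"])
print(result["labels"])
print(result["evaluate_quality"])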
Paper2Poster/camel/agents/embodied_agent.py ADDED
@@ -0,0 +1,201 @@
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+ from typing import Any, List, Optional
15
+
16
+ from colorama import Fore
17
+
18
+ from camel.agents.chat_agent import ChatAgent
19
+ from camel.agents.tool_agents.base import BaseToolAgent
20
+ from camel.interpreters import (
21
+ BaseInterpreter,
22
+ InternalPythonInterpreter,
23
+ SubprocessInterpreter,
24
+ )
25
+ from camel.messages import BaseMessage
26
+ from camel.models import BaseModelBackend
27
+ from camel.responses import ChatAgentResponse
28
+ from camel.utils import print_text_animated
29
+
30
+ # AgentOps decorator setting
31
+ try:
32
+ import os
33
+
34
+ if os.getenv("AGENTOPS_API_KEY") is not None:
35
+ from agentops import track_agent
36
+ else:
37
+ raise ImportError
38
+ except (ImportError, AttributeError):
39
+ from camel.utils import track_agent
40
+
41
+
42
+ @track_agent(name="EmbodiedAgent")
43
+ class EmbodiedAgent(ChatAgent):
44
+ r"""Class for managing conversations of CAMEL Embodied Agents.
45
+
46
+ Args:
47
+ system_message (BaseMessage): The system message for the chat agent.
48
+ model (BaseModelBackend, optional): The model backend to use for
49
+ generating responses. (default: :obj:`OpenAIModel` with
50
+ `GPT_4O_MINI`)
51
+ message_window_size (int, optional): The maximum number of previous
52
+ messages to include in the context window. If `None`, no windowing
53
+ is performed. (default: :obj:`None`)
54
+ tool_agents (List[BaseToolAgent], optional): The tool agents to use in
55
+ the embodied agent. (default: :obj:`None`)
56
+ code_interpreter (BaseInterpreter, optional): The code interpreter to
57
+ execute codes. If `code_interpreter` and `tool_agent` are both
58
+ `None`, default to `SubProcessInterpreter`. If `code_interpreter`
59
+ is `None` and `tool_agents` is not `None`, default to
60
+ `InternalPythonInterpreter`. (default: :obj:`None`)
61
+ verbose (bool, optional): Whether to print the agent's messages.
62
+ logger_color (Any): The color of the logger displayed to the user.
63
+ (default: :obj:`Fore.MAGENTA`)
64
+ """
65
+
66
+ def __init__(
67
+ self,
68
+ system_message: BaseMessage,
69
+ model: Optional[BaseModelBackend] = None,
70
+ message_window_size: Optional[int] = None,
71
+ tool_agents: Optional[List[BaseToolAgent]] = None,
72
+ code_interpreter: Optional[BaseInterpreter] = None,
73
+ verbose: bool = False,
74
+ logger_color: Any = Fore.MAGENTA,
75
+ ) -> None:
76
+ self.tool_agents = tool_agents
77
+ self.code_interpreter: BaseInterpreter
78
+ if code_interpreter is not None:
79
+ self.code_interpreter = code_interpreter
80
+ elif self.tool_agents:
81
+ self.code_interpreter = InternalPythonInterpreter()
82
+ else:
83
+ self.code_interpreter = SubprocessInterpreter()
84
+
85
+ if self.tool_agents:
86
+ system_message = self._set_tool_agents(system_message)
87
+ self.verbose = verbose
88
+ self.logger_color = logger_color
89
+ super().__init__(
90
+ system_message=system_message,
91
+ model=model,
92
+ message_window_size=message_window_size,
93
+ )
94
+
95
+ def _set_tool_agents(self, system_message: BaseMessage) -> BaseMessage:
96
+ action_space_prompt = self._get_tool_agents_prompt()
97
+ result_message = system_message.create_new_instance(
98
+ content=system_message.content.format(
99
+ action_space=action_space_prompt
100
+ )
101
+ )
102
+ if self.tool_agents is not None:
103
+ self.code_interpreter.update_action_space(
104
+ {tool.name: tool for tool in self.tool_agents}
105
+ )
106
+ return result_message
107
+
108
+ def _get_tool_agents_prompt(self) -> str:
109
+ r"""Returns the action space prompt.
110
+
111
+ Returns:
112
+ str: The action space prompt.
113
+ """
114
+ if self.tool_agents is not None:
115
+ return "\n".join(
116
+ [
117
+ f"*** {tool.name} ***:\n {tool.description}"
118
+ for tool in self.tool_agents
119
+ ]
120
+ )
121
+ else:
122
+ return ""
123
+
124
+ def get_tool_agent_names(self) -> List[str]:
125
+ r"""Returns the names of tool agents.
126
+
127
+ Returns:
128
+ List[str]: The names of tool agents.
129
+ """
130
+ if self.tool_agents is not None:
131
+ return [tool.name for tool in self.tool_agents]
132
+ else:
133
+ return []
134
+
135
+ # ruff: noqa: E501
136
+ def step(self, input_message: BaseMessage) -> ChatAgentResponse: # type: ignore[override]
137
+ r"""Performs a step in the conversation.
138
+
139
+ Args:
140
+ input_message (BaseMessage): The input message.
141
+
142
+ Returns:
143
+ ChatAgentResponse: A struct containing the output messages,
144
+ a boolean indicating whether the chat session has terminated,
145
+ and information about the chat session.
146
+ """
147
+ response = super().step(input_message)
148
+
149
+ if response.msgs is None or len(response.msgs) == 0:
150
+ raise RuntimeError("Got None output messages.")
151
+ if response.terminated:
152
+ raise RuntimeError(f"{self.__class__.__name__} step failed.")
153
+
154
+ # NOTE: Only single output messages are supported
155
+ explanations, codes = response.msg.extract_text_and_code_prompts()
156
+
157
+ if self.verbose:
158
+ for explanation, code in zip(explanations, codes):
159
+ print_text_animated(
160
+ self.logger_color + f"> Explanation:\n{explanation}"
161
+ )
162
+ print_text_animated(self.logger_color + f"> Code:\n{code}")
163
+
164
+ if len(explanations) > len(codes):
165
+ print_text_animated(
166
+ self.logger_color + f"> Explanation:\n{explanations[-1]}"
167
+ )
168
+
169
+ content = response.msg.content
170
+
171
+ if codes is not None:
172
+ try:
173
+ content = "\n> Executed Results:\n"
174
+ for block_idx, code in enumerate(codes):
175
+ executed_output = self.code_interpreter.run(
176
+ code, code.code_type
177
+ )
178
+ content += (
179
+ f"Executing code block {block_idx}: {{\n"
180
+ + executed_output
181
+ + "}\n"
182
+ )
183
+ except InterruptedError as e:
184
+ content = (
185
+ f"\n> Running code fail: {e}\n"
186
+ "Please regenerate the code."
187
+ )
188
+
189
+ # TODO: Handle errors
190
+ content = input_message.content + f"\n> Embodied Actions:\n{content}"
191
+ message = BaseMessage(
192
+ input_message.role_name,
193
+ input_message.role_type,
194
+ input_message.meta_dict,
195
+ content,
196
+ )
197
+ return ChatAgentResponse(
198
+ msgs=[message],
199
+ terminated=response.terminated,
200
+ info=response.info,
201
+ )
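An illustrative `EmbodiedAgent` step without tool agents, so generated code blocks run through the subprocess interpreter (not part of the committed file; assumes the camel package is installed and a model backend is configured):

from camel.agents.embodied_agent import EmbodiedAgent
from camel.messages import BaseMessage
from camel.types import RoleType

sys_msg = BaseMessage(
    role_name="Programmer",
    role_type=RoleType.ASSISTANT,
    meta_dict=None,
    content="You write and run small Python snippets to answer questions.",
)
agent = EmbodiedAgent(system_message=sys_msg, verbose=True)

user_msg = BaseMessage(
    role_name="User",
    role_type=RoleType.USER,
    meta_dict=None,
    content="Compute the sum of the first 10 squares.",
)
# The returned content appends the executed results of any code blocks.
response = agent.step(user_msg)
print(response.msg.content)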
Paper2Poster/camel/agents/knowledge_graph_agent.py ADDED
@@ -0,0 +1,259 @@
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+ from typing import TYPE_CHECKING, Optional, Union
15
+
16
+ if TYPE_CHECKING:
17
+ from unstructured.documents.elements import Element
18
+
19
+ from camel.agents import ChatAgent
20
+ from camel.messages import BaseMessage
21
+ from camel.models import BaseModelBackend
22
+ from camel.prompts import TextPrompt
23
+ from camel.storages.graph_storages.graph_element import (
24
+ GraphElement,
25
+ Node,
26
+ Relationship,
27
+ )
28
+ from camel.types import RoleType
29
+
30
+ # AgentOps decorator setting
31
+ try:
32
+ import os
33
+
34
+ if os.getenv("AGENTOPS_API_KEY") is not None:
35
+ from agentops import track_agent
36
+ else:
37
+ raise ImportError
38
+ except (ImportError, AttributeError):
39
+ from camel.utils import track_agent
40
+
41
+
42
+ text_prompt = """
43
+ You are tasked with extracting nodes and relationships from given content and
44
+ structuring them into Node and Relationship objects. Here's the outline of what
45
+ you need to do:
46
+
47
+ Content Extraction:
48
+ You should be able to process input content and identify entities mentioned
49
+ within it.
50
+ Entities can be any noun phrases or concepts that represent distinct entities
51
+ in the context of the given content.
52
+
53
+ Node Extraction:
54
+ For each identified entity, you should create a Node object.
55
+ Each Node object should have a unique identifier (id) and a type (type).
56
+ Additional properties associated with the node can also be extracted and
57
+ stored.
58
+
59
+ Relationship Extraction:
60
+ You should identify relationships between entities mentioned in the content.
61
+ For each relationship, create a Relationship object.
62
+ A Relationship object should have a subject (subj) and an object (obj) which
63
+ are Node objects representing the entities involved in the relationship.
64
+ Each relationship should also have a type (type), and additional properties if
65
+ applicable.
66
+
67
+ Output Formatting:
68
+ The extracted nodes and relationships should be formatted as instances of the
69
+ provided Node and Relationship classes.
70
+ Ensure that the extracted data adheres to the structure defined by the classes.
71
+ Output the structured data in a format that can be easily validated against
72
+ the provided code.
73
+
74
+ Instructions for you:
75
+ Read the provided content thoroughly.
76
+ Identify distinct entities mentioned in the content and categorize them as
77
+ nodes.
78
+ Determine relationships between these entities and represent them as directed
79
+ relationships.
80
+ Provide the extracted nodes and relationships in the specified format below.
81
+ Example for you:
82
+
83
+ Example Content:
84
+ "John works at XYZ Corporation. He is a software engineer. The company is
85
+ located in New York City."
86
+
87
+ Expected Output:
88
+
89
+ Nodes:
90
+
91
+ Node(id='John', type='Person')
92
+ Node(id='XYZ Corporation', type='Organization')
93
+ Node(id='New York City', type='Location')
94
+
95
+ Relationships:
96
+
97
+ Relationship(subj=Node(id='John', type='Person'), obj=Node(id='XYZ
98
+ Corporation', type='Organization'), type='WorksAt')
99
+ Relationship(subj=Node(id='John', type='Person'), obj=Node(id='New York City',
100
+ type='Location'), type='ResidesIn')
101
+
102
+ ===== TASK =====
103
+ Please extract nodes and relationships from the given content and structure them
104
+ into Node and Relationship objects.
105
+
106
+ {task}
107
+ """
108
+
109
+
110
+ @track_agent(name="KnowledgeGraphAgent")
111
+ class KnowledgeGraphAgent(ChatAgent):
112
+ r"""An agent that can extract node and relationship information for
113
+ different entities from given `Element` content.
114
+
115
+ Attributes:
116
+ task_prompt (TextPrompt): A prompt for the agent to extract node and
117
+ relationship information for different entities.
118
+ """
119
+
120
+ def __init__(
121
+ self,
122
+ model: Optional[BaseModelBackend] = None,
123
+ ) -> None:
124
+ r"""Initialize the `KnowledgeGraphAgent`.
125
+
126
+ Args:
127
+ model (BaseModelBackend, optional): The model backend to use for
128
+ generating responses. (default: :obj:`OpenAIModel` with
129
+ `GPT_4O_MINI`)
130
+ """
131
+ system_message = BaseMessage(
132
+ role_name="Graphify",
133
+ role_type=RoleType.ASSISTANT,
134
+ meta_dict=None,
135
+ content="Your mission is to transform unstructured content "
136
+ "into structured graph data. Extract nodes and relationships with "
137
+ "precision, and let the connections unfold. Your graphs will "
138
+ "illuminate the hidden connections within the chaos of "
139
+ "information.",
140
+ )
141
+ super().__init__(system_message, model=model)
142
+
143
+ def run(
144
+ self,
145
+ element: "Element",
146
+ parse_graph_elements: bool = False,
147
+ ) -> Union[str, GraphElement]:
148
+ r"""Run the agent to extract node and relationship information.
149
+
150
+ Args:
151
+ element (Element): The input element.
152
+ parse_graph_elements (bool, optional): Whether to parse into
153
+ `GraphElement`. Defaults to `False`.
154
+
155
+ Returns:
156
+ Union[str, GraphElement]: The extracted node and relationship
157
+ information. If `parse_graph_elements` is `True` then return
158
+ `GraphElement`, else return `str`.
159
+ """
160
+ self.reset()
161
+ self.element = element
162
+
163
+ knowledge_graph_prompt = TextPrompt(text_prompt)
164
+ knowledge_graph_generation = knowledge_graph_prompt.format(
165
+ task=str(element)
166
+ )
167
+
168
+ knowledge_graph_generation_msg = BaseMessage.make_user_message(
169
+ role_name="Graphify", content=knowledge_graph_generation
170
+ )
171
+
172
+ response = self.step(input_message=knowledge_graph_generation_msg)
173
+
174
+ content = response.msg.content
175
+
176
+ if parse_graph_elements:
177
+ content = self._parse_graph_elements(content)
178
+
179
+ return content
180
+
181
+ def _validate_node(self, node: Node) -> bool:
182
+ r"""Validate if the object is a valid Node.
183
+
184
+ Args:
185
+ node (Node): Object to be validated.
186
+
187
+ Returns:
188
+ bool: True if the object is a valid Node, False otherwise.
189
+ """
190
+ return (
191
+ isinstance(node, Node)
192
+ and isinstance(node.id, (str, int))
193
+ and isinstance(node.type, str)
194
+ )
195
+
196
+ def _validate_relationship(self, relationship: Relationship) -> bool:
197
+ r"""Validate if the object is a valid Relationship.
198
+
199
+ Args:
200
+ relationship (Relationship): Object to be validated.
201
+
202
+ Returns:
203
+ bool: True if the object is a valid Relationship, False otherwise.
204
+ """
205
+ return (
206
+ isinstance(relationship, Relationship)
207
+ and self._validate_node(relationship.subj)
208
+ and self._validate_node(relationship.obj)
209
+ and isinstance(relationship.type, str)
210
+ )
211
+
212
+ def _parse_graph_elements(self, input_string: str) -> GraphElement:
213
+ r"""Parses graph elements from given content.
214
+
215
+ Args:
216
+ input_string (str): The input content.
217
+
218
+ Returns:
219
+ GraphElement: The parsed graph elements.
220
+ """
221
+ import re
222
+
223
+ # Regular expressions to extract nodes and relationships
224
+ node_pattern = r"Node\(id='(.*?)', type='(.*?)'\)"
225
+ rel_pattern = (
226
+ r"Relationship\(subj=Node\(id='(.*?)', type='(.*?)'\), "
227
+ r"obj=Node\(id='(.*?)', type='(.*?)'\), type='(.*?)'\)"
228
+ )
229
+
230
+ nodes = {}
231
+ relationships = []
232
+
233
+ # Extract nodes
234
+ for match in re.finditer(node_pattern, input_string):
235
+ id, type = match.groups()
236
+ properties = {'source': 'agent_created'}
237
+ if id not in nodes:
238
+ node = Node(id=id, type=type, properties=properties)
239
+ if self._validate_node(node):
240
+ nodes[id] = node
241
+
242
+ # Extract relationships
243
+ for match in re.finditer(rel_pattern, input_string):
244
+ subj_id, subj_type, obj_id, obj_type, rel_type = match.groups()
245
+ properties = {'source': 'agent_created'}
246
+ if subj_id in nodes and obj_id in nodes:
247
+ subj = nodes[subj_id]
248
+ obj = nodes[obj_id]
249
+ relationship = Relationship(
250
+ subj=subj, obj=obj, type=rel_type, properties=properties
251
+ )
252
+ if self._validate_relationship(relationship):
253
+ relationships.append(relationship)
254
+
255
+ return GraphElement(
256
+ nodes=list(nodes.values()),
257
+ relationships=relationships,
258
+ source=self.element,
259
+ )
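A standalone sketch of the regular expressions used by `_parse_graph_elements`, applied to the example output format from the prompt (runnable as-is; it exercises only the parsing, not the model):

import re

sample = (
    "Node(id='John', type='Person')\n"
    "Node(id='XYZ Corporation', type='Organization')\n"
    "Relationship(subj=Node(id='John', type='Person'), "
    "obj=Node(id='XYZ Corporation', type='Organization'), type='WorksAt')"
)

node_pattern = r"Node\(id='(.*?)', type='(.*?)'\)"
rel_pattern = (
    r"Relationship\(subj=Node\(id='(.*?)', type='(.*?)'\), "
    r"obj=Node\(id='(.*?)', type='(.*?)'\), type='(.*?)'\)"
)

# The node pattern also matches the Node(...) literals nested inside the
# Relationship line, which is why _parse_graph_elements deduplicates by id.
print(re.findall(node_pattern, sample))
print(re.findall(rel_pattern, sample))
# -> [('John', 'Person', 'XYZ Corporation', 'Organization', 'WorksAt')]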
Paper2Poster/camel/agents/multi_hop_generator_agent.py ADDED
@@ -0,0 +1,117 @@
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+
15
+ import textwrap
16
+ from typing import Any
17
+
18
+ from pydantic import ConfigDict
19
+
20
+ from camel.agents.programmed_agent_instruction import (
21
+ ProgrammableChatAgent,
22
+ ProgrammedAgentInstructionResult,
23
+ programmable_capability,
24
+ )
25
+ from camel.datagen.source2synth.models import (
26
+ ContextPrompt,
27
+ MultiHopQA,
28
+ )
29
+ from camel.messages import BaseMessage
30
+
31
+
32
+ class MultiHopGeneratorAgent(ProgrammableChatAgent):
33
+ r"""An agent specialized in generating multi-hop question-answer pairs.
34
+
35
+ This agent is designed to create complex questions that require multiple
36
+ steps of reasoning to answer. It analyzes context to identify related
37
+ facts and generates questions that require connecting these facts
38
+ logically.
39
+
40
+ Attributes:
41
+ model_config (ConfigDict): Configuration for model behavior.
42
+ system_message (BaseMessage): System message defining agent's role and
43
+ instructions.
44
+ """
45
+
46
+ model_config = ConfigDict(arbitrary_types_allowed=True)
47
+
48
+ def __init__(self, **kwargs: Any) -> None:
49
+ r"""Initialize the MultiHopGeneratorAgent.
50
+
51
+ Args:
52
+ **kwargs (Any): Additional keyword arguments to pass to parent
53
+ class.
54
+ """
55
+ super().__init__(**kwargs)
56
+
57
+ system_text: str = textwrap.dedent(
58
+ """\
59
+ You are an expert at generating
60
+ multi-hop question-answer pairs.
61
+ For each context, you should:
62
+ 1. Identify multiple related facts or pieces of information
63
+ 2. Create questions that require reasoning across these multiple pieces
64
+ 3. Ensure the reasoning chain is clear and logical
65
+ 4. Generate questions that require at least 2-3 steps of reasoning
66
+ 5. Include the reasoning steps in the answer
67
+
68
+ Give your response with this information:
69
+ Question: [Complex question requiring multiple reasoning steps]
70
+ Reasoning Steps:
71
+ 1. [First reasoning step]
72
+ 2. [Second reasoning step]
73
+ 3. [Final reasoning step]
74
+ Answer: [Final answer]
75
+ Supporting Facts: [List of relevant text segments used]
76
+ """ # noqa: E501
77
+ )
78
+ self.system_message = BaseMessage.make_assistant_message(
79
+ role_name='Assistant', content=system_text
80
+ )
81
+
82
+ @programmable_capability
83
+ def generate_multi_hop_qa(
84
+ self, context: str
85
+ ) -> ProgrammedAgentInstructionResult[MultiHopQA]:
86
+ r"""Generate a multi-hop question-answer pair from given context.
87
+
88
+ Args:
89
+ context (str): The input text context to generate QA from.
90
+
91
+ Returns:
92
+ ProgrammedAgentInstructionResult[MultiHopQA]: Result containing the
93
+ generated question, reasoning steps, answer, and supporting
94
+ facts.
95
+
96
+ Raises:
97
+ RuntimeError: If the agent fails to generate a response.
98
+ """
99
+ context_prompt = ContextPrompt(
100
+ main_context=context, related_contexts=None
101
+ )
102
+
103
+ user_message = BaseMessage.make_user_message(
104
+ content=context_prompt.model_dump_json(), role_name="User"
105
+ )
106
+ response = self.step(
107
+ input_message=user_message, response_format=MultiHopQA
108
+ )
109
+ if not response.msgs:
110
+ raise RuntimeError("No response from agent")
111
+
112
+ value = MultiHopQA.model_validate_json(response.msgs[0].content)
113
+ return ProgrammedAgentInstructionResult(
114
+ user_message=user_message,
115
+ agent_message=response.msgs[0],
116
+ value=value,
117
+ )
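An illustrative call to `generate_multi_hop_qa` (not part of the committed file; assumes the camel package is installed and a backend that supports structured `response_format` output is configured):

from camel.agents.multi_hop_generator_agent import MultiHopGeneratorAgent

agent = MultiHopGeneratorAgent()
context = (
    "Ada Lovelace worked with Charles Babbage on the Analytical Engine. "
    "The Analytical Engine was designed in the 1830s in London."
)
result = agent.generate_multi_hop_qa(context)
# result.value is a MultiHopQA with the question, reasoning steps,
# answer, and supporting facts.
print(result.value)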
Paper2Poster/camel/agents/programmed_agent_instruction.py ADDED
@@ -0,0 +1,203 @@
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+ import abc
15
+ import threading
16
+ from enum import Enum
17
+ from functools import wraps
18
+ from typing import Any, Callable, Generic, Optional, TypeVar
19
+
20
+ from pydantic import BaseModel, ConfigDict
21
+
22
+ from camel.agents import ChatAgent
23
+ from camel.messages import BaseMessage
24
+
25
+ T = TypeVar('T')
26
+
27
+
28
+ class ProgrammableAgentRequirement(Enum):
29
+ r"""Requirements for programmable agent state.
30
+
31
+ Defines the possible requirements that can be used to repair the state
32
+ of a programmable agent.
33
+
34
+ Attributes:
35
+ LAST_MESSAGE_NOT_USER (str): Requires that the last message in the
36
+ conversation was not from the user.
37
+ """
38
+
39
+ LAST_MESSAGE_NOT_USER = "LAST_MESSAGE_NOT_USER"
40
+
41
+
42
+ class ProgrammedAgentInstructionResult(BaseModel, Generic[T]):
43
+ r"""Result of a programmable agent instruction execution.
44
+
45
+ Contains the messages exchanged during execution and the computed value.
46
+ The value type is specified by the generic type parameter T.
47
+
48
+ Attributes:
49
+ user_message (BaseMessage): The message sent by the user.
50
+ agent_message (BaseMessage): The message sent by the agent.
51
+ value (T): The computed result value of type T.
52
+ """
53
+
54
+ user_message: BaseMessage
55
+ agent_message: BaseMessage
56
+ value: T
57
+
58
+ model_config = ConfigDict(arbitrary_types_allowed=True)
59
+
60
+
61
+ class AbstractProgrammableAgent(abc.ABC):
62
+ r"""Abstract class for a programmable agent.
63
+
64
+ A programmable agent is an agent that can be programmed to perform a
65
+ specific function or task. This class defines the interface for a
66
+ programmable agent.
67
+
68
+ These methods should be implemented in order to ensure the agent supports
69
+ the necessary guarantees to enable a programming interface while
70
+ maintaining compatibility in a multi-agent system.
71
+
72
+ A programmable agent is responsible for providing and maintaining a
73
+ programming interface for its functionality.
74
+ """
75
+
76
+ @abc.abstractmethod
77
+ def run_atomic(
78
+ self, callback: Callable[[], ProgrammedAgentInstructionResult[T]]
79
+ ) -> ProgrammedAgentInstructionResult[T]:
80
+ r"""Run an atomic operation on the agent.
81
+
82
+ An atomic operation is an operation that is guaranteed to
83
+ be executed without interruption by any other operation.
84
+
85
+ Args:
86
+ callback (Callable[[], ProgrammedAgentInstructionResult[T]]): The
87
+ operation to execute atomically.
88
+
89
+ Returns:
90
+ ProgrammedAgentInstructionResult[T]: The result of the operation.
91
+
92
+ Raises:
93
+ RuntimeError: If an operation is already in progress.
94
+ """
95
+ raise NotImplementedError
96
+
97
+ @abc.abstractmethod
98
+ def repair_state(self, requirement: ProgrammableAgentRequirement) -> None:
99
+ r"""Repair the state of the agent.
100
+
101
+ Agents may have other non-atomic interfaces, such as a user interface,
102
+ or chat between other agents. This method should restore the agent to
103
+ a state where it can perform operations according to the specified
104
+ requirement.
105
+
106
+ Args:
107
+ requirement (ProgrammableAgentRequirement): The requirement to
108
+ repair the state for.
109
+ """
110
+ raise NotImplementedError
111
+
112
+
113
+ def programmable_capability(
114
+ func: Callable[..., ProgrammedAgentInstructionResult[T]],
115
+ ) -> Callable[..., ProgrammedAgentInstructionResult[T]]:
116
+ r"""Decorator for programmable agent capabilities.
117
+
118
+ This decorator ensures that the decorated method is executed atomically
119
+ and maintains the agent's state guarantees.
120
+
121
+ Args:
122
+ func (Callable[..., ProgrammedAgentInstructionResult[T]]): The method
123
+ to decorate.
124
+
125
+ Returns:
126
+ Callable[..., ProgrammedAgentInstructionResult[T]]: The decorated
127
+ method that ensures atomic execution.
128
+ """
129
+
130
+ @wraps(func)
131
+ def wrapper(
132
+ self, *args: Any, **kwargs: Any
133
+ ) -> ProgrammedAgentInstructionResult[T]:
134
+ return self.run_atomic(lambda: func(self, *args, **kwargs))
135
+
136
+ return wrapper
137
+
138
+
139
+ class ProgrammableChatAgent(ChatAgent, AbstractProgrammableAgent):
140
+ r"""A chat agent that can be programmed to perform specific tasks.
141
+
142
+ Provides a default implementation of atomic execution using threading locks
143
+ and basic state tracking for message roles. Implementing classes need to
144
+ provide specific repair logic for their use cases.
145
+
146
+ Attributes:
147
+ _operation_lock (threading.Lock): Lock for ensuring atomic operations.
148
+ _last_message_role (Optional[str]): Role of the last message in the
149
+ conversation.
150
+ """
151
+
152
+ def __init__(self, **kwargs: Any) -> None:
153
+ r"""Initialize the ProgrammableChatAgent.
154
+
155
+ Args:
156
+ **kwargs (Any): Additional keyword arguments to pass to parent
157
+ class.
158
+ """
159
+ super().__init__(**kwargs)
160
+ self._operation_lock = threading.Lock()
161
+ self._last_message_role: Optional[str] = None
162
+
163
+ def run_atomic(
164
+ self, callback: Callable[[], ProgrammedAgentInstructionResult[T]]
165
+ ) -> ProgrammedAgentInstructionResult[T]:
166
+ r"""Run an atomic operation on the agent.
167
+
168
+ Ensures thread-safe execution of the callback function by using a lock.
169
+
170
+ Args:
171
+ callback (Callable[[], ProgrammedAgentInstructionResult[T]]): The
172
+ operation to execute atomically.
173
+
174
+ Returns:
175
+ ProgrammedAgentInstructionResult[T]: The result of the operation.
176
+
177
+ Raises:
178
+ RuntimeError: If an operation is already in progress.
179
+ """
180
+ if not self._operation_lock.acquire(blocking=False):
181
+ raise RuntimeError("Operation already in progress")
182
+
183
+ try:
184
+ result = callback()
185
+ self._last_message_role = result.agent_message.role_name
186
+ return result
187
+ finally:
188
+ self._operation_lock.release()
189
+
190
+ def repair_state(self, requirement: ProgrammableAgentRequirement) -> None:
191
+ r"""Repair the state of the agent.
192
+
193
+ Implements basic state repair for message role requirements.
194
+
195
+ Args:
196
+ requirement (ProgrammableAgentRequirement): The requirement to
197
+ repair the state for.
198
+ """
199
+ if requirement == ProgrammableAgentRequirement.LAST_MESSAGE_NOT_USER:
200
+ if self._last_message_role == "user":
201
+ raise NotImplementedError(
202
+ "Must implement repair for LAST_MESSAGE_NOT_USER"
203
+ )
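An illustrative subclass showing how `programmable_capability` wraps a method in `run_atomic` (not part of the committed file; `EchoAgent` and its `echo` capability are hypothetical names, and the camel package is assumed to be installed):

from camel.agents.programmed_agent_instruction import (
    ProgrammableChatAgent,
    ProgrammedAgentInstructionResult,
    programmable_capability,
)
from camel.messages import BaseMessage

class EchoAgent(ProgrammableChatAgent):
    @programmable_capability
    def echo(self, text: str) -> ProgrammedAgentInstructionResult[str]:
        # Runs atomically: concurrent calls on the same agent raise
        # RuntimeError instead of interleaving their messages.
        user_message = BaseMessage.make_user_message(
            role_name="User", content=text
        )
        response = self.step(input_message=user_message)
        return ProgrammedAgentInstructionResult(
            user_message=user_message,
            agent_message=response.msgs[0],
            value=response.msgs[0].content,
        )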
Paper2Poster/camel/agents/role_assignment_agent.py ADDED
@@ -0,0 +1,141 @@
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+ import re
15
+ from typing import Dict, Optional, Union
16
+
17
+ from camel.agents.chat_agent import ChatAgent
18
+ from camel.messages import BaseMessage
19
+ from camel.models import BaseModelBackend
20
+ from camel.prompts import TextPrompt
21
+ from camel.types import RoleType
22
+
23
+ # AgentOps decorator setting
24
+ try:
25
+ import os
26
+
27
+ if os.getenv("AGENTOPS_API_KEY") is not None:
28
+ from agentops import track_agent
29
+ else:
30
+ raise ImportError
31
+ except (ImportError, AttributeError):
32
+ from camel.utils import track_agent
33
+
34
+
35
+ @track_agent(name="RoleAssignmentAgent")
36
+ class RoleAssignmentAgent(ChatAgent):
37
+ r"""An agent that generates role names based on the task prompt.
38
+
39
+ Args:
40
+ model (BaseModelBackend, optional): The model backend to use for
41
+ generating responses. (default: :obj:`OpenAIModel` with
42
+ `GPT_4O_MINI`)
43
+
44
+ Attributes:
45
+ role_assignment_prompt (TextPrompt): A prompt for the agent to generate
46
+ role names.
47
+ """
48
+
49
+ def __init__(
50
+ self,
51
+ model: Optional[BaseModelBackend] = None,
52
+ ) -> None:
53
+ system_message = BaseMessage(
54
+ role_name="Role Assigner",
55
+ role_type=RoleType.ASSISTANT,
56
+ meta_dict=None,
57
+ content="You assign roles based on tasks.",
58
+ )
59
+ super().__init__(system_message, model=model)
60
+
61
+ def run(
62
+ self,
63
+ task_prompt: Union[str, TextPrompt],
64
+ num_roles: int = 2,
65
+ ) -> Dict[str, str]:
66
+ r"""Generate role names based on the input task prompt.
67
+
68
+ Args:
69
+ task_prompt (Union[str, TextPrompt]): The prompt
70
+ for the task based on which the roles are to be generated.
71
+ num_roles (int, optional): The number of roles to generate.
72
+ (default: :obj:`2`)
73
+
74
+ Returns:
75
+ Dict[str, str]: A dictionary mapping role names to their
76
+ descriptions.
77
+ """
78
+ self.reset()
79
+
80
+ expert_prompt = "===== ANSWER PROMPT =====\n" + "\n".join(
81
+ f"Domain expert {i + 1}: <BLANK>\n"
82
+ f"Associated competencies, characteristics, duties "
83
+ f"and workflows: <BLANK>. End."
84
+ for i in range(num_roles or 0)
85
+ )
86
+ role_assignment_generation_prompt = TextPrompt(
87
+ "You are a role assignment agent, and you're in charge of "
88
+ + "recruiting {num_roles} experts for the following task."
89
+ + "\n==== TASK =====\n {task}\n\n"
90
+ + "Identify the domain experts you'd recruit and detail their "
91
+ + "associated competencies, characteristics, duties and workflows "
92
+ + "to complete the task.\n "
93
+ + "Your answer MUST adhere to the format of ANSWER PROMPT, and "
94
+ + "ONLY answer the BLANKs.\n"
95
+ + expert_prompt
96
+ )
97
+ role_assignment_generation = role_assignment_generation_prompt.format(
98
+ num_roles=num_roles, task=task_prompt
99
+ )
100
+
101
+ role_assignment_generation_msg = BaseMessage.make_user_message(
102
+ role_name="Role Assigner", content=role_assignment_generation
103
+ )
104
+
105
+ response = self.step(input_message=role_assignment_generation_msg)
106
+
107
+ msg = response.msg # type: BaseMessage
108
+ terminated = response.terminated
109
+
110
+ # Distribute the output completions into role names and descriptions
111
+ role_names = [
112
+ desc.replace("<|", "").replace("|>", "")
113
+ for desc in re.findall(
114
+ r"Domain expert \d: (.+?)\nAssociated competencies,",
115
+ msg.content,
116
+ re.DOTALL,
117
+ )
118
+ ]
119
+ role_descriptions = [
120
+ desc.replace("<|", "").replace("|>", "")
121
+ for desc in re.findall(
122
+ r"Associated competencies, characteristics, "
123
+ r"duties and workflows: (.+?) End.",
124
+ msg.content,
125
+ re.DOTALL,
126
+ )
127
+ ]
128
+
129
+ if len(role_names) != num_roles or len(role_descriptions) != num_roles:
130
+ raise RuntimeError(
131
+ "Got None or insufficient information of roles."
132
+ )
133
+ if terminated:
134
+ raise RuntimeError("Role assignment failed.")
135
+
136
+ role_descriptions_dict = {
137
+ role_name: description
138
+ for role_name, description in zip(role_names, role_descriptions)
139
+ }
140
+
141
+ return role_descriptions_dict
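A minimal usage sketch of `RoleAssignmentAgent` (illustrative only, not part of the added file); it assumes a default model backend can be constructed from the environment, e.g. via an OpenAI API key:

```python
# Illustrative sketch, not part of role_assignment_agent.py: run the agent on a
# made-up task. Assumes a default model backend (e.g. an OpenAI key) is configured.
from camel.agents.role_assignment_agent import RoleAssignmentAgent

agent = RoleAssignmentAgent()
roles = agent.run(
    task_prompt="Turn a research paper into a conference poster.",
    num_roles=2,
)
for role_name, description in roles.items():
    print(role_name, "->", description)
```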
Paper2Poster/camel/agents/search_agent.py ADDED
@@ -0,0 +1,133 @@
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+ from typing import Optional
15
+
16
+ from camel.agents.chat_agent import ChatAgent
17
+ from camel.messages import BaseMessage
18
+ from camel.models import BaseModelBackend
19
+ from camel.prompts import TextPrompt
20
+ from camel.types import RoleType
21
+ from camel.utils import create_chunks
22
+
23
+ # AgentOps decorator setting
24
+ try:
25
+ import os
26
+
27
+ if os.getenv("AGENTOPS_API_KEY") is not None:
28
+ from agentops import track_agent
29
+ else:
30
+ raise ImportError
31
+ except (ImportError, AttributeError):
32
+ from camel.utils import track_agent
33
+
34
+
35
+ @track_agent(name="SearchAgent")
36
+ class SearchAgent(ChatAgent):
37
+ r"""An agent that summarizes text based on a query and evaluates the
38
+ relevance of an answer.
39
+
40
+ Args:
41
+ model (BaseModelBackend, optional): The model backend to use for
42
+ generating responses. (default: :obj:`OpenAIModel` with
43
+ `GPT_4O_MINI`)
44
+ """
45
+
46
+ def __init__(
47
+ self,
48
+ model: Optional[BaseModelBackend] = None,
49
+ ) -> None:
50
+ system_message = BaseMessage(
51
+ role_name="Assistant",
52
+ role_type=RoleType.ASSISTANT,
53
+ meta_dict=None,
54
+ content="You are a helpful assistant.",
55
+ )
56
+ super().__init__(system_message, model=model)
57
+
58
+ def summarize_text(self, text: str, query: str) -> str:
59
+ r"""Summarize the information from the text, base on the query.
60
+
61
+ Args:
62
+ text (str): Text to summarize.
63
+ query (str): What information you want.
64
+
65
+ Returns:
66
+ str: Summarized information relevant to the query.
67
+ """
68
+ self.reset()
69
+
70
+ summary_prompt = TextPrompt(
71
+ '''Gather information from this text that is relevant to the
72
+ question, but do not directly answer the question.\nquestion:
73
+ {query}\ntext '''
74
+ )
75
+ summary_prompt = summary_prompt.format(query=query)
76
+ # Max length of each chunk
77
+ max_len = 3000
78
+ results = ""
79
+ chunks = create_chunks(text, max_len)
80
+ # Summarize
81
+ for i, chunk in enumerate(chunks, start=1):
82
+ prompt = summary_prompt + str(i) + ": " + chunk
83
+ user_msg = BaseMessage.make_user_message(
84
+ role_name="User",
85
+ content=prompt,
86
+ )
87
+ result = self.step(user_msg).msg.content
88
+ results += result + "\n"
89
+
90
+ # Final summarization
91
+ final_prompt = TextPrompt(
92
+ '''Here are some summarized texts which were split from one text. Use
93
+ the information to answer the question. If you can't find the answer,
94
+ you must answer "I can not find the answer to the query" and
95
+ explain why.\n Query:\n{query}.\n\nText:\n'''
96
+ )
97
+ final_prompt = final_prompt.format(query=query)
98
+ prompt = final_prompt + results
99
+
100
+ user_msg = BaseMessage.make_user_message(
101
+ role_name="User",
102
+ content=prompt,
103
+ )
104
+ response = self.step(user_msg).msg.content
105
+
106
+ return response
107
+
108
+ def continue_search(self, query: str, answer: str) -> bool:
109
+ r"""Ask whether to continue search or not based on the provided answer.
110
+
111
+ Args:
112
+ query (str): The question.
113
+ answer (str): The answer to the question.
114
+
115
+ Returns:
116
+ bool: `True` if the search should continue, `False`
117
+ otherwise.
118
+ """
119
+ prompt = TextPrompt(
120
+ "Do you think the ANSWER can answer the QUERY? "
121
+ "Use only 'yes' or 'no' to answer.\n"
122
+ "===== QUERY =====\n{query}\n\n"
123
+ "===== ANSWER =====\n{answer}"
124
+ )
125
+ prompt = prompt.format(query=query, answer=answer)
126
+ user_msg = BaseMessage.make_user_message(
127
+ role_name="User",
128
+ content=prompt,
129
+ )
130
+ response = self.step(user_msg).msg.content
131
+ if "yes" in str(response).lower():
132
+ return False
133
+ return True
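A hedged usage sketch of `SearchAgent` (illustrative only, not part of the added file); the text and query below are placeholders and a default model backend is assumed:

```python
# Illustrative sketch, not part of search_agent.py: summarize retrieved text and
# decide whether to keep searching. The text and query are placeholders.
from camel.agents.search_agent import SearchAgent

agent = SearchAgent()
query = "Which datasets does the paper evaluate on?"
answer = agent.summarize_text(text="<long retrieved passage>", query=query)
if agent.continue_search(query=query, answer=answer):
    print("Answer judged insufficient; keep searching.")
else:
    print(answer)
```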
Paper2Poster/camel/agents/task_agent.py ADDED
@@ -0,0 +1,410 @@
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+ from typing import Any, Dict, List, Optional, Union
15
+
16
+ from camel.agents.chat_agent import ChatAgent
17
+ from camel.messages import BaseMessage
18
+ from camel.models import BaseModelBackend
19
+ from camel.prompts import PromptTemplateGenerator, TextPrompt
20
+ from camel.types import RoleType, TaskType
21
+ from camel.utils import get_task_list
22
+
23
+ # AgentOps decorator setting
24
+ try:
25
+ import os
26
+
27
+ if os.getenv("AGENTOPS_API_KEY") is not None:
28
+ from agentops import track_agent
29
+ else:
30
+ raise ImportError
31
+ except (ImportError, AttributeError):
32
+ from camel.utils import track_agent
33
+
34
+
35
+ @track_agent(name="TaskSpecifyAgent")
36
+ class TaskSpecifyAgent(ChatAgent):
37
+ r"""An agent that specifies a given task prompt by prompting the user to
38
+ provide more details.
39
+
40
+ Attributes:
41
+ DEFAULT_WORD_LIMIT (int): The default word limit for the task prompt.
42
+ task_specify_prompt (TextPrompt): The prompt for specifying the task.
43
+
44
+ Args:
45
+ model (BaseModelBackend, optional): The model backend to use for
46
+ generating responses. (default: :obj:`OpenAIModel` with
47
+ `GPT_4O_MINI`)
48
+ task_type (TaskType, optional): The type of task for which to generate
49
+ a prompt. (default: :obj:`TaskType.AI_SOCIETY`)
50
+ task_specify_prompt (Union[str, TextPrompt], optional): The prompt for
51
+ specifying the task. (default: :obj:`None`)
52
+ word_limit (int, optional): The word limit for the task prompt.
53
+ (default: :obj:`50`)
54
+ output_language (str, optional): The language to be output by the
55
+ agent. (default: :obj:`None`)
56
+ """
57
+
58
+ DEFAULT_WORD_LIMIT = 50
59
+
60
+ def __init__(
61
+ self,
62
+ model: Optional[BaseModelBackend] = None,
63
+ task_type: TaskType = TaskType.AI_SOCIETY,
64
+ task_specify_prompt: Optional[Union[str, TextPrompt]] = None,
65
+ word_limit: int = DEFAULT_WORD_LIMIT,
66
+ output_language: Optional[str] = None,
67
+ ) -> None:
68
+ self.task_specify_prompt: Union[str, TextPrompt]
69
+ if task_specify_prompt is None:
70
+ task_specify_prompt_template = (
71
+ PromptTemplateGenerator().get_task_specify_prompt(task_type)
72
+ )
73
+
74
+ self.task_specify_prompt = task_specify_prompt_template.format(
75
+ word_limit=word_limit
76
+ )
77
+ else:
78
+ self.task_specify_prompt = TextPrompt(task_specify_prompt)
79
+
80
+ system_message = BaseMessage(
81
+ role_name="Task Specifier",
82
+ role_type=RoleType.ASSISTANT,
83
+ meta_dict=None,
84
+ content="You can make a task more specific.",
85
+ )
86
+
87
+ super().__init__(
88
+ system_message,
89
+ model=model,
90
+ output_language=output_language,
91
+ )
92
+
93
+ def run(
94
+ self,
95
+ task_prompt: Union[str, TextPrompt],
96
+ meta_dict: Optional[Dict[str, Any]] = None,
97
+ ) -> TextPrompt:
98
+ r"""Specify the given task prompt by providing more details.
99
+
100
+ Args:
101
+ task_prompt (Union[str, TextPrompt]): The original task
102
+ prompt.
103
+ meta_dict (Dict[str, Any], optional): A dictionary containing
104
+ additional information to include in the prompt.
105
+ (default: :obj:`None`)
106
+
107
+ Returns:
108
+ TextPrompt: The specified task prompt.
109
+ """
110
+ self.reset()
111
+ task_specify_prompt = self.task_specify_prompt.format(task=task_prompt)
112
+
113
+ if meta_dict is not None:
114
+ task_specify_prompt = task_specify_prompt.format(**meta_dict)
115
+ task_msg = BaseMessage.make_user_message(
116
+ role_name="Task Specifier", content=task_specify_prompt
117
+ )
118
+ specifier_response = self.step(task_msg)
119
+
120
+ if specifier_response.terminated:
121
+ raise RuntimeError("Task specification failed.")
122
+ if len(specifier_response.msgs) == 0:
123
+ raise RuntimeError("Got no specification message.")
124
+
125
+ specified_task_msg = specifier_response.msgs[0]
126
+
127
+ return TextPrompt(specified_task_msg.content)
128
+
129
+
130
+ @track_agent(name="TaskPlannerAgent")
131
+ class TaskPlannerAgent(ChatAgent):
132
+ r"""An agent that helps divide a task into subtasks based on the input
133
+ task prompt.
134
+
135
+ Attributes:
136
+ task_planner_prompt (TextPrompt): A prompt for the agent to divide
137
+ the task into subtasks.
138
+
139
+ Args:
140
+ model (BaseModelBackend, optional): The model backend to use for
141
+ generating responses. (default: :obj:`OpenAIModel` with
142
+ `GPT_4O_MINI`)
143
+ output_language (str, optional): The language to be output by the
144
+ agent. (default: :obj:`None`)
145
+ """
146
+
147
+ def __init__(
148
+ self,
149
+ model: Optional[BaseModelBackend] = None,
150
+ output_language: Optional[str] = None,
151
+ ) -> None:
152
+ self.task_planner_prompt = TextPrompt(
153
+ "Divide this task into subtasks: {task}. Be concise."
154
+ )
155
+ system_message = BaseMessage(
156
+ role_name="Task Planner",
157
+ role_type=RoleType.ASSISTANT,
158
+ meta_dict=None,
159
+ content="You are a helpful task planner.",
160
+ )
161
+
162
+ super().__init__(
163
+ system_message,
164
+ model=model,
165
+ output_language=output_language,
166
+ )
167
+
168
+ def run(
169
+ self,
170
+ task_prompt: Union[str, TextPrompt],
171
+ ) -> TextPrompt:
172
+ r"""Generate subtasks based on the input task prompt.
173
+
174
+ Args:
175
+ task_prompt (Union[str, TextPrompt]): The prompt for the task to
176
+ be divided into subtasks.
177
+
178
+ Returns:
179
+ TextPrompt: A prompt for the subtasks generated by the agent.
180
+ """
181
+ # TODO: Maybe include roles information.
182
+ self.reset()
183
+ task_planner_prompt = self.task_planner_prompt.format(task=task_prompt)
184
+
185
+ task_msg = BaseMessage.make_user_message(
186
+ role_name="Task Planner", content=task_planner_prompt
187
+ )
188
+
189
+ task_response = self.step(task_msg)
190
+
191
+ if task_response.terminated:
192
+ raise RuntimeError("Task planning failed.")
193
+ if len(task_response.msgs) == 0:
194
+ raise RuntimeError("Got no task planning message.")
195
+
196
+ sub_tasks_msg = task_response.msgs[0]
197
+ return TextPrompt(sub_tasks_msg.content)
198
+
199
+
200
+ @track_agent(name="TaskCreationAgent")
201
+ class TaskCreationAgent(ChatAgent):
202
+ r"""An agent that helps create new tasks based on the objective
203
+ and last completed task. Compared to :obj:`TaskPlannerAgent`,
204
+ it's still a task planner, but it has more context information
205
+ like last task and incomplete task list. Modified from
206
+ `BabyAGI <https://github.com/yoheinakajima/babyagi>`_.
207
+
208
+ Attributes:
209
+ task_creation_prompt (TextPrompt): A prompt for the agent to
210
+ create new tasks.
211
+
212
+ Args:
213
+ role_name (str): The role name of the Agent to create the task.
214
+ objective (Union[str, TextPrompt]): The objective of the Agent to
215
+ perform the task.
216
+ model (BaseModelBackend, optional): The LLM backend to use for
217
+ generating responses. (default: :obj:`OpenAIModel` with
218
+ `GPT_4O_MINI`)
219
+ output_language (str, optional): The language to be output by the
220
+ agent. (default: :obj:`None`)
221
+ message_window_size (int, optional): The maximum number of previous
222
+ messages to include in the context window. If `None`, no windowing
223
+ is performed. (default: :obj:`None`)
224
+ max_task_num (int, optional): The maximum number of planned
225
+ tasks in one round. (default: :obj:3)
226
+ """
227
+
228
+ def __init__(
229
+ self,
230
+ role_name: str,
231
+ objective: Union[str, TextPrompt],
232
+ model: Optional[BaseModelBackend] = None,
233
+ output_language: Optional[str] = None,
234
+ message_window_size: Optional[int] = None,
235
+ max_task_num: Optional[int] = 3,
236
+ ) -> None:
237
+ task_creation_prompt = TextPrompt(
238
+ """Create new a task with the following objective: {objective}.
239
+ Never forget you are a Task Creator of {role_name}.
240
+ You must instruct me based on my expertise and your needs to solve the task.
241
+ You should consider past solved tasks and in-progress tasks: {task_list}.
242
+ The new created tasks must not overlap with these past tasks.
243
+ The result must be a numbered list in the format:
244
+
245
+ #. First Task
246
+ #. Second Task
247
+ #. Third Task
248
+
249
+ You can only give me up to {max_task_num} tasks at a time. \
250
+ Each task should be concise, concrete and doable for a {role_name}.
251
+ You should make task plan and not ask me questions.
252
+ If you think no new tasks are needed right now, write "No tasks to add."
253
+ Now start to give me new tasks one by one. No more than three tasks.
254
+ Be concrete.
255
+ """
256
+ )
257
+
258
+ self.task_creation_prompt = task_creation_prompt.format(
259
+ objective=objective, role_name=role_name, max_task_num=max_task_num
260
+ )
261
+ self.objective = objective
262
+
263
+ system_message = BaseMessage(
264
+ role_name="Task Creator",
265
+ role_type=RoleType.ASSISTANT,
266
+ meta_dict=None,
267
+ content="You are a helpful task creator.",
268
+ )
269
+
270
+ super().__init__(
271
+ system_message,
272
+ model=model,
273
+ output_language=output_language,
274
+ message_window_size=message_window_size,
275
+ )
276
+
277
+ def run(
278
+ self,
279
+ task_list: List[str],
280
+ ) -> List[str]:
281
+ r"""Generate subtasks based on the previous task results and
282
+ incomplete task list.
283
+
284
+ Args:
285
+ task_list (List[str]): The completed or in-progress
286
+ tasks which should not overlap with new created tasks.
287
+
288
+ Returns:
289
+ List[str]: The new task list generated by the Agent.
290
+ """
291
+
292
+ if len(task_list) > 0:
293
+ task_creation_prompt = self.task_creation_prompt.format(
294
+ task_list=task_list
295
+ )
296
+ else:
297
+ task_creation_prompt = self.task_creation_prompt.format(
298
+ task_list=""
299
+ )
300
+
301
+ task_msg = BaseMessage.make_user_message(
302
+ role_name="Task Creator", content=task_creation_prompt
303
+ )
304
+ task_response = self.step(task_msg)
305
+
306
+ if task_response.terminated:
307
+ raise RuntimeError("Task creation failed.")
308
+ if len(task_response.msgs) == 0:
309
+ raise RuntimeError("Got no task creation message.")
310
+
311
+ sub_tasks_msg = task_response.msgs[0]
312
+ return get_task_list(sub_tasks_msg.content)
313
+
314
+
315
+ @track_agent(name="TaskPrioritizationAgent")
316
+ class TaskPrioritizationAgent(ChatAgent):
317
+ r"""An agent that helps re-prioritize the task list and
318
+ returns numbered prioritized list. Modified from
319
+ `BabyAGI <https://github.com/yoheinakajima/babyagi>`_.
320
+
321
+ Attributes:
322
+ task_prioritization_prompt (TextPrompt): A prompt for the agent to
323
+ prioritize tasks.
324
+
325
+ Args:
326
+ objective (Union[str, TextPrompt]): The objective of the Agent to
327
+ perform the task.
328
+ model (BaseModelBackend, optional): The LLM backend to use for
329
+ generating responses. (default: :obj:`OpenAIModel` with
330
+ `GPT_4O_MINI`)
331
+ output_language (str, optional): The language to be output by the
332
+ agent. (default: :obj:`None`)
333
+ message_window_size (int, optional): The maximum number of previous
334
+ messages to include in the context window. If `None`, no windowing
335
+ is performed. (default: :obj:`None`)
336
+ """
337
+
338
+ def __init__(
339
+ self,
340
+ objective: Union[str, TextPrompt],
341
+ model: Optional[BaseModelBackend] = None,
342
+ output_language: Optional[str] = None,
343
+ message_window_size: Optional[int] = None,
344
+ ) -> None:
345
+ task_prioritization_prompt = TextPrompt(
346
+ """Prioritize the following tasks : {task_list}.
347
+ Consider the ultimate objective of you: {objective}.
348
+ Tasks should be sorted from highest to lowest priority, where higher-priority \
349
+ tasks are those that act as pre-requisites or are more essential for meeting \
350
+ the objective. Return one task per line in your response.
351
+ Do not remove or modify any tasks.
352
+ The result must be a numbered list in the format:
353
+
354
+ #. First task
355
+ #. Second task
356
+
357
+ The entries must be consecutively numbered, starting with 1.
358
+ The number of each entry must be followed by a period.
359
+ Do not include any headers before your ranked list or follow your list \
360
+ with any other output."""
361
+ )
362
+
363
+ self.task_prioritization_prompt = task_prioritization_prompt.format(
364
+ objective=objective
365
+ )
366
+ self.objective = objective
367
+
368
+ system_message = BaseMessage(
369
+ role_name="Task Prioritizer",
370
+ role_type=RoleType.ASSISTANT,
371
+ meta_dict=None,
372
+ content="You are a helpful task prioritizer.",
373
+ )
374
+
375
+ super().__init__(
376
+ system_message,
377
+ model=model,
378
+ output_language=output_language,
379
+ message_window_size=message_window_size,
380
+ )
381
+
382
+ def run(
383
+ self,
384
+ task_list: List[str],
385
+ ) -> List[str]:
386
+ r"""Prioritize the task list given the agent objective.
387
+
388
+ Args:
389
+ task_list (List[str]): The unprioritized tasks of agent.
390
+
391
+ Returns:
392
+ List[str]: The new prioritized task list generated by the Agent.
393
+ """
394
+ task_prioritization_prompt = self.task_prioritization_prompt.format(
395
+ task_list=task_list
396
+ )
397
+
398
+ task_msg = BaseMessage.make_user_message(
399
+ role_name="Task Prioritizer", content=task_prioritization_prompt
400
+ )
401
+
402
+ task_response = self.step(task_msg)
403
+
404
+ if task_response.terminated:
405
+ raise RuntimeError("Task prioritization failed.")
406
+ if len(task_response.msgs) == 0:
407
+ raise RuntimeError("Got no task prioritization message.")
408
+
409
+ sub_tasks_msg = task_response.msgs[0]
410
+ return get_task_list(sub_tasks_msg.content)
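An illustrative sketch (not part of the added file) chaining the task agents defined above; it assumes a default model backend is available:

```python
# Illustrative sketch, not part of task_agent.py: specify a vague task, then
# split the specified task into subtasks. Assumes a default model backend.
from camel.agents.task_agent import TaskPlannerAgent, TaskSpecifyAgent
from camel.types import TaskType

specifier = TaskSpecifyAgent(task_type=TaskType.AI_SOCIETY, word_limit=50)
specified_task = specifier.run(
    task_prompt="Make a poster for a machine learning paper."
)

planner = TaskPlannerAgent()
subtasks = planner.run(task_prompt=specified_task)
print(subtasks)
```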
Paper2Poster/camel/agents/tool_agents/__init__.py ADDED
@@ -0,0 +1,20 @@
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+ from .base import BaseToolAgent
15
+ from .hugging_face_tool_agent import HuggingFaceToolAgent
16
+
17
+ __all__ = [
18
+ 'BaseToolAgent',
19
+ 'HuggingFaceToolAgent',
20
+ ]
Paper2Poster/camel/agents/tool_agents/base.py ADDED
@@ -0,0 +1,39 @@
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+ from camel.agents import BaseAgent
15
+
16
+
17
+ class BaseToolAgent(BaseAgent):
18
+ r"""Creates a :obj:`BaseToolAgent` object with the specified name and
19
+ description.
20
+
21
+ Args:
22
+ name (str): The name of the tool agent.
23
+ description (str): The description of the tool agent.
24
+ """
25
+
26
+ def __init__(self, name: str, description: str) -> None:
27
+ self.name = name
28
+ self.description = description
29
+
30
+ def reset(self) -> None:
31
+ r"""Resets the agent to its initial state."""
32
+ pass
33
+
34
+ def step(self) -> None:
35
+ r"""Performs a single step of the agent."""
36
+ pass
37
+
38
+ def __str__(self) -> str:
39
+ return f"{self.name}: {self.description}"
Paper2Poster/camel/agents/tool_agents/hugging_face_tool_agent.py ADDED
@@ -0,0 +1,206 @@
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+ from typing import Any, Optional
15
+
16
+ from camel.agents.tool_agents.base import BaseToolAgent
17
+
18
+
19
+ # flake8: noqa :E501
20
+ class HuggingFaceToolAgent(BaseToolAgent):
21
+ r"""Tool agent for calling HuggingFace models. This agent is a wrapper
22
+ around agents from the `transformers` library. For more information
23
+ about the available models, please see the `transformers` documentation
24
+ at https://huggingface.co/docs/transformers/transformers_agents.
25
+
26
+ Args:
27
+ name (str): The name of the agent.
28
+ *args (Any): Additional positional arguments to pass to the underlying
29
+ Agent class.
30
+ remote (bool, optional): Flag indicating whether to run the agent
31
+ remotely. (default: :obj:`True`)
32
+ **kwargs (Any): Additional keyword arguments to pass to the underlying
33
+ Agent class.
34
+ """
35
+
36
+ def __init__(
37
+ self,
38
+ name: str,
39
+ *args: Any,
40
+ remote: bool = True,
41
+ **kwargs: Any,
42
+ ) -> None:
43
+ try:
44
+ # TODO: Support other tool agents
45
+ import transformers
46
+ from packaging import version
47
+
48
+ if version.parse(transformers.__version__) < version.parse(
49
+ "4.31.0"
50
+ ):
51
+ raise ValueError(
52
+ "The version of \"transformers\" package should >= 4.31.0"
53
+ )
54
+
55
+ from transformers.tools import OpenAiAgent
56
+ from transformers.tools.agent_types import AgentImage
57
+ except (ImportError, ValueError):
58
+ raise ValueError(
59
+ "Could not import transformers tool agents. "
60
+ "Please setup the environment with "
61
+ "pip install huggingface_hub==0.14.1 transformers==4.31.0 diffusers accelerate==0.20.3 datasets torch soundfile sentencepiece opencv-python"
62
+ )
63
+ self.agent_image_type = AgentImage
64
+ self.agent = OpenAiAgent(*args, **kwargs)
65
+ description = f"""The `{name}` is a tool agent that can perform a variety of tasks including:
66
+ - Document question answering: given a document (such as a PDF) in image format, answer a question on this document
67
+ - Text question answering: given a long text and a question, answer the question in the text
68
+ - Unconditional image captioning: Caption the image!
69
+ - Image question answering: given an image, answer a question on this image
70
+ - Image segmentation: given an image and a prompt, output the segmentation mask of that prompt
71
+ - Speech to text: given an audio recording of a person talking, transcribe the speech into text
72
+ - Text to speech: convert text to speech
73
+ - Zero-shot text classification: given a text and a list of labels, identify to which label the text corresponds the most
74
+ - Text summarization: summarize a long text in one or a few sentences
75
+ - Translation: translate the text into a given language
76
+ - Text downloading: to download a text from a web URL
77
+ - Text to image: generate an image according to a prompt, leveraging stable diffusion
78
+ - Image transformation: modify an image given an initial image and a prompt, leveraging instruct pix2pix stable diffusion
79
+ - Text to video: generate a small video according to a prompt
80
+
81
+ Here are some python code examples of what you can do with this agent:
82
+
83
+ Single execution (step) mode, the single execution method is when using the step() method of the agent:
84
+ ```
85
+ # Text to image
86
+ rivers_and_lakes_image = {name}.step("Draw me a picture of rivers and lakes.")
87
+ rivers_and_lakes_image.save("./rivers_and_lakes_image.png")
88
+
89
+ # Text to image -> Image transformation
90
+ sea_add_island_image = {name}.step("Draw me a picture of the sea then transform the picture to add an island")
91
+ sea_add_island_image.save("./sea_add_island_image.png")
92
+
93
+ # If you'd like to keep a state across executions or to pass non-text objects to the agent,
94
+ # you can do so by specifying variables that you would like the agent to use. For example,
95
+ # you could generate the first image of rivers and lakes, and ask the model to update that picture to add an island by doing the following:
96
+ picture = {name}.step("Generate a picture of rivers and lakes.")
97
+ picture.save("./picture.png")
98
+ updated_picture = {name}.step("Transform the image in `picture` to add an island to it.", picture=picture)
99
+ updated_picture.save("./updated_picture.png")
100
+
101
+ capybara_sea_image = {name}.step("Draw me a picture of the `prompt`", prompt="a capybara swimming in the sea")
102
+ capybara_sea_image.save("./capybara_sea_image.png")
103
+
104
+ # Document question answering
105
+ answer = {name}.step(
106
+ "In the following `document`, where will the TRRF Scientific Advisory Council Meeting take place?",
107
+ document=document,
108
+ )
109
+ print(answer)
110
+
111
+
112
+ # Text to image
113
+ boat_image = {name}.step("Generate an image of a boat in the water")
114
+ boat_image.save("./boat_image.png")
115
+
116
+ # Unconditional image captioning
117
+ boat_image_caption = {name}.step("Can you caption the `boat_image`?", boat_image=boat_image)
118
+ print(boat_image_caption)
119
+
120
+ # Text to image -> Unconditional image captioning -> Text to speech
121
+ boat_audio = {name}.step("Can you generate an image of a boat? Please read out loud the contents of the image afterwards")
122
+
123
+ # Text downloading
124
+ document = {name}.step("Download the text from http://hf.co")
125
+ print(document)
126
+
127
+ # Text summarization
128
+ summary = {name}.step("Summarize the following text: `document`", document=document)
129
+ print(summary)
130
+
131
+ # Text downloading -> Text summarization -> Text to speech
132
+ audio = {name}.step("Read out loud the summary of http://hf.co")
133
+ ```
134
+
135
+ Chat-based execution (chat), the agent also has a chat-based approach, using the chat() method:
136
+ ```
137
+ # Clean the chat history
138
+ {name}.reset()
139
+
140
+ # Text to image
141
+ capybara_image = {name}.chat("Show me an an image of a capybara")
142
+ capybara_image.save("./capybara_image.png")
143
+
144
+ # Image transformation
145
+ transformed_capybara_image = {name}.chat("Transform the image so that it snows")
146
+ transformed_capybara_image.save("./transformed_capybara_image.png")
147
+
148
+ # Image segmentation
149
+ segmented_transformed_capybara_image = {name}.chat("Show me a mask of the snowy capybaras")
150
+ segmented_transformed_capybara_image.save("./segmented_transformed_capybara_image.png")
151
+ ```
152
+ """
153
+ super(HuggingFaceToolAgent, self).__init__(name, description)
154
+ self.remote = remote
155
+
156
+ def reset(self) -> None:
157
+ r"""Resets the chat history of the agent."""
158
+ self.agent.prepare_for_new_chat()
159
+
160
+ def step(
161
+ self,
162
+ *args: Any,
163
+ remote: Optional[bool] = None,
164
+ **kwargs: Any,
165
+ ) -> Any:
166
+ r"""Runs the agent in single execution mode.
167
+
168
+ Args:
169
+ *args (Any): Positional arguments to pass to the agent.
170
+ remote (bool, optional): Flag indicating whether to run the agent
171
+ remotely. Overrides the default setting. (default: :obj:`None`)
172
+ **kwargs (Any): Keyword arguments to pass to the agent.
173
+
174
+ Returns:
175
+ str: The response from the agent.
176
+ """
177
+ if remote is None:
178
+ remote = self.remote
179
+ agent_output = self.agent.run(*args, remote=remote, **kwargs)
180
+ if isinstance(agent_output, self.agent_image_type):
181
+ agent_output = agent_output.to_raw()
182
+ return agent_output
183
+
184
+ def chat(
185
+ self,
186
+ *args: Any,
187
+ remote: Optional[bool] = None,
188
+ **kwargs: Any,
189
+ ) -> Any:
190
+ r"""Runs the agent in a chat conversation mode.
191
+
192
+ Args:
193
+ *args (Any): Positional arguments to pass to the agent.
194
+ remote (bool, optional): Flag indicating whether to run the agent
195
+ remotely. Overrides the default setting. (default: :obj:`None`)
196
+ **kwargs (Any): Keyword arguments to pass to the agent.
197
+
198
+ Returns:
199
+ str: The response from the agent.
200
+ """
201
+ if remote is None:
202
+ remote = self.remote
203
+ agent_output = self.agent.chat(*args, remote=remote, **kwargs)
204
+ if isinstance(agent_output, self.agent_image_type):
205
+ agent_output = agent_output.to_raw()
206
+ return agent_output
Paper2Poster/camel/benchmarks/__init__.py ADDED
@@ -0,0 +1,30 @@
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+
15
+ from .apibank import APIBankBenchmark
16
+ from .apibench import APIBenchBenchmark
17
+ from .base import BaseBenchmark
18
+ from .gaia import DefaultGAIARetriever, GAIABenchmark
19
+ from .nexus import NexusBenchmark
20
+ from .ragbench import RAGBenchBenchmark
21
+
22
+ __all__ = [
23
+ "BaseBenchmark",
24
+ "GAIABenchmark",
25
+ "DefaultGAIARetriever",
26
+ "NexusBenchmark",
27
+ "APIBenchBenchmark",
28
+ "APIBankBenchmark",
29
+ "RAGBenchBenchmark",
30
+ ]
Paper2Poster/camel/benchmarks/apibank.py ADDED
@@ -0,0 +1,565 @@
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+
15
+ import json
16
+ import logging
17
+ import os
18
+ import random
19
+ import re
20
+ import sys
21
+ from pathlib import Path
22
+ from typing import Any, Dict, List, Literal, Optional
23
+
24
+ import numpy as np
25
+ from rouge import Rouge
26
+ from tqdm import tqdm
27
+
28
+ from camel.agents import ChatAgent
29
+ from camel.benchmarks.base import BaseBenchmark
30
+ from camel.messages import BaseMessage
31
+ from camel.utils import download_github_subdirectory
32
+
33
+ logger = logging.getLogger(__name__)
34
+
35
+ # Add current folder to sys.path to enable relative import
36
+ current_folder = os.getcwd()
37
+ if current_folder not in sys.path:
38
+ sys.path.append(current_folder)
39
+
40
+
41
+ def process_messages(
42
+ chat_history: List[Dict[str, Any]],
43
+ prompt: str,
44
+ ) -> List[Dict[str, str]]:
45
+ """
46
+ Processes chat history into a structured format for further use.
47
+
48
+ Args:
49
+ chat_history (List[Dict[str, Any]]):
50
+ A list of dictionaries representing the chat history.
51
+ prompt (str): A prompt to be set as the system message.
52
+
53
+ Returns:
54
+ List[Dict[str, str]]: A list of dictionaries representing
55
+ the processed messages, where each dictionary has:
56
+ - 'role': The role of the message ('system', 'user', or 'assistant').
57
+ - 'content': The content of the message, including formatted
58
+ API responses when applicable.
59
+ """
60
+ messages = [{'role': 'system', 'content': prompt}]
61
+ for item in chat_history:
62
+ role_map = {'User': 'user', 'AI': 'assistant', 'API': 'system'}
63
+ chat_role = role_map.get(
64
+ item['role'], 'unknown'
65
+ ) # default role to 'unknown'
66
+ if item['role'] == 'API':
67
+ chat_content = '[{}({})] Response: {}'.format(
68
+ item['api_name'],
69
+ ', '.join(
70
+ [
71
+ '{}=\'{}\''.format(k, v)
72
+ for k, v in item['param_dict'].items()
73
+ ]
74
+ ),
75
+ str(item['result']['output']),
76
+ )
77
+ else:
78
+ chat_content = item['text']
79
+ messages.append({'role': chat_role, 'content': chat_content})
80
+ return messages
81
+
82
+
83
+ class APIBankBenchmark(BaseBenchmark):
84
+ r"""API-Bank Benchmark adapted from `API-Bank:
85
+ A Comprehensive Benchmark for Tool-Augmented LLMs`
86
+ <https://github.com/AlibabaResearch/DAMO-ConvAI/tree/main/api-bank>.
87
+
88
+ Args:
89
+ save_to (str): The file to save the results.
90
+ processes (int, optional): The number of processes to use.
91
+ (default: :obj:`1`)
92
+ """
93
+
94
+ def __init__(
95
+ self,
96
+ save_to: str,
97
+ processes: int = 1,
98
+ ):
99
+ r"""Initialize the APIBank benchmark.
100
+
101
+ Args:
102
+ save_to (str): The file to save the results.
103
+ processes (int, optional): The number of processes to use for
104
+ parallel processing. (default: :obj:`1`)
105
+ """
106
+ # Predefine data_dir for better import management
107
+ super().__init__("apibank", "api_bank", save_to, processes)
108
+ self._data: Dict[str, List[APIBankSample]] = dict() # type: ignore[assignment]
109
+
110
+ def download(self):
111
+ r"""Download APIBank dataset and code from Github."""
112
+
113
+ repo = "AlibabaResearch/DAMO-ConvAI"
114
+ subdir = "api-bank"
115
+ data_dir = self.data_dir
116
+
117
+ download_github_subdirectory(repo, subdir, data_dir)
118
+
119
+ sys.path.insert(0, self.data_dir)
120
+ logger.info("Download completed.")
121
+
122
+ def load(self, level: str, force_download: bool = False): # type: ignore[override]
123
+ r"""Load the APIBank Benchmark dataset.
124
+
125
+ Args:
126
+ level (str): Level to run benchmark on.
127
+ force_download (bool, optional): Whether to
128
+ force download the data.
129
+ """
130
+ if force_download:
131
+ logger.info("Force downloading data.")
132
+ self.download()
133
+
134
+ if level == "level-1":
135
+ file_path = Path("api_bank/lv1-lv2-samples/level-1-given-desc")
136
+ elif level == 'level-2':
137
+ file_path = Path("api_bank/lv1-lv2-samples/level-2-toolsearcher")
138
+ jsonl_files = [
139
+ f for f in os.listdir(file_path) if f.endswith('.jsonl')
140
+ ]
141
+ for file in tqdm(jsonl_files, desc="Processing files"):
142
+ history = []
143
+ with open(file_path / file, 'r') as f:
144
+ for line in f:
145
+ history.append(json.loads(line))
146
+ samples = APIBankSample.from_chat_history(history)
147
+ self._data[file.rsplit('.', 1)[0]] = samples
148
+
149
+ # Change import to relative import in the downloaded python files
150
+ def process_files(folder_path, replacements):
151
+ r"""Replace absolute imports in downloaded files with
152
+ relative import."""
153
+ for file in os.listdir(folder_path):
154
+ if file.endswith(".py"):
155
+ file_path = os.path.join(folder_path, file)
156
+ try:
157
+ with open(file_path, "r", encoding="utf-8") as file:
158
+ content = file.read()
159
+
160
+ original_content = content
161
+
162
+ for pattern, replacement in replacements:
163
+ content = re.sub(pattern, replacement, content)
164
+
165
+ if content != original_content:
166
+ with open(
167
+ file_path, "w", encoding="utf-8"
168
+ ) as file:
169
+ file.write(content)
170
+ logger.info(f"Updated file: {file_path}")
171
+
172
+ except Exception as e:
173
+ logger.info(f"Error processing file {file_path}: {e}")
174
+
175
+ api_bank_folder = "api_bank"
176
+ apis_folder = os.path.join(api_bank_folder, "apis")
177
+
178
+ apis_replacements = [
179
+ (r"from apis.api", "from .api"),
180
+ (r"from apis import", "from .api import"),
181
+ ]
182
+
183
+ api_bank_replacements = [
184
+ (r"from apis", "from .apis"),
185
+ (r"from api_call_extraction", "from .api_call_extraction"),
186
+ (r"f'{basename}", r"f'api_bank.{basename}"),
187
+ ]
188
+
189
+ process_files(apis_folder, apis_replacements)
190
+ process_files(api_bank_folder, api_bank_replacements)
191
+
192
+ def run( # type: ignore[override, return]
193
+ self,
194
+ agent: ChatAgent,
195
+ level: Literal["level-1", "level-2"],
196
+ api_test_enabled=True,
197
+ randomize: bool = False,
198
+ subset: Optional[int] = None,
199
+ ) -> Dict[str, Any]:
200
+ r"""Run the benchmark.
201
+
202
+ Args:
203
+ agent (ChatAgent): The agent to run the
204
+ benchmark.
205
+ level (Literal['level-1', 'level-2']):
206
+ The level to run the benchmark on.
207
+ randomize (bool, optional): Whether to
208
+ randomize the data.
209
+ api_test_enabled (bool): Whether to test
210
+ API calling (`True`) or response (`False`)
211
+ (default: :obj:`True`)
212
+ subset (Optional[int], optional):
213
+ The subset of data to run.
214
+ (default: :obj:`None`)
215
+
216
+ Returns:
217
+ Dict[str, Any]: The results of the benchmark.
218
+ """
219
+ logger.info(f"Running APIBench benchmark on {level}.")
220
+ self.load(level)
221
+ datas = self._data
222
+
223
+ # Shuffle and subset data if necessary
224
+ if randomize:
225
+ randomized_items = list(datas.items())
226
+ random.shuffle(randomized_items)
227
+ datas = dict(randomized_items)
228
+ if subset:
229
+ datas = dict(list(datas.items())[:subset])
230
+
231
+ logger.info(f"Number of tasks: {len(datas)}")
232
+
233
+ # Initialize results storage
234
+ self._results = []
235
+
236
+ # The following code is adapted from the evaluator
237
+ # from the original repo:
238
+ tool_search_enabled = level == "level-2"
239
+ dialog_test_enabled = not api_test_enabled
240
+ total_api_calls, correct_api_calls, rougel_scores = 0, 0, []
241
+
242
+ with open(self.save_to, "w") as f:
243
+ for test in tqdm(datas, desc="Running"):
244
+ samples = self._data[test]
245
+ evaluator = Evaluator(samples) # type: ignore[arg-type]
246
+
247
+ for sample_id in evaluator.get_all_sample_ids():
248
+ # Process sample and generate response
249
+ sample = evaluator.dataset[sample_id]
250
+
251
+ if (
252
+ sample.ground_truth['role'] == 'API'
253
+ and api_test_enabled
254
+ ):
255
+ if tool_search_enabled:
256
+ _, chat_history = evaluator.get_model_input(
257
+ sample_id
258
+ )
259
+ api_descriptions = evaluator.get_api_description(
260
+ 'ToolSearcher'
261
+ )
262
+ else:
263
+ api_descriptions, chat_history = (
264
+ evaluator.get_model_input(sample_id)
265
+ )
266
+ messages = process_messages(
267
+ chat_history, API_CALL_PROMPT + api_descriptions
268
+ )
269
+ model_output = agent_call(messages, agent)
270
+ api_call = get_api_call(model_output)
271
+
272
+ # Evaluate API call
273
+ if api_call:
274
+ try:
275
+ correct, model_output_result = (
276
+ evaluator.evaluate(sample_id, api_call)
277
+ )
278
+ except AssertionError as e:
279
+ if 'The API name is not correct.' not in str(
280
+ e
281
+ ):
282
+ raise e
283
+ logging.info('AssertionError: {}'.format(e))
284
+ correct = False
285
+ else:
286
+ model_output_result = 'No API call found'
287
+ correct = False
288
+ if correct:
289
+ correct_api_calls += 1
290
+ logging.info(
291
+ 'Correct API call: {} Ground truth: {}'.format(
292
+ api_call, sample.ground_truth
293
+ )
294
+ )
295
+ else:
296
+ logging.info(
297
+ 'Incorrect model output: {} Result: {} \
298
+ Ground truth: {} File: {} Sample ID: {} \
299
+ Messages: {}'.format(
300
+ model_output.replace('\n', ' '),
301
+ model_output_result,
302
+ sample.ground_truth,
303
+ test,
304
+ sample_id,
305
+ messages[1:],
306
+ )
307
+ )
308
+ total_api_calls += 1
309
+ self._results.append(
310
+ {
311
+ 'Role': 'API',
312
+ 'Model_output': model_output,
313
+ 'Model_output_result': model_output_result,
314
+ 'Ground_truth': sample.ground_truth,
315
+ 'Test': test,
316
+ 'Correct': correct,
317
+ }
318
+ )
319
+ f.write(json.dumps(self._results[-1], indent=2) + "\n")
320
+
321
+ elif (
322
+ sample.ground_truth['role'] == 'AI'
323
+ and dialog_test_enabled
324
+ ):
325
+ # Process sample and generate response
326
+ api_descriptions, chat_history = (
327
+ evaluator.get_model_input(sample_id)
328
+ )
329
+
330
+ messages = process_messages(
331
+ chat_history, RESPONSE_PROMPT + api_descriptions
332
+ )
333
+ model_output = agent_call(messages, agent)
334
+
335
+ # Evaluate model response
336
+ if model_output:
337
+ score = evaluator.evaluate(sample_id, model_output)
338
+ else:
339
+ score = 0
340
+ rougel_scores.append(score)
341
+ if score < 0.2:
342
+ logging.info(
343
+ 'Low score: {} Score: {} Ground truth: {} \
344
+ Test: {} Sample ID: {} \
345
+ Messages: {}'.format(
346
+ model_output.replace('\n', ' '),
347
+ score,
348
+ sample.ground_truth,
349
+ test,
350
+ sample_id,
351
+ messages[1:],
352
+ )
353
+ )
354
+
355
+ self._results.append(
356
+ {
357
+ 'Role': 'AI',
358
+ 'Model_output': model_output,
359
+ 'Score': score,
360
+ 'Ground_truth': sample.ground_truth,
361
+ 'Test': test,
362
+ }
363
+ )
364
+ f.write(json.dumps(self._results[-1], indent=2) + "\n")
365
+
366
+ f.flush()
367
+
368
+ if api_test_enabled:
369
+ return {
370
+ 'total': total_api_calls,
371
+ 'correct': correct_api_calls,
372
+ "accuracy": correct_api_calls / total_api_calls
373
+ if total_api_calls
374
+ else 0,
375
+ }
376
+ elif dialog_test_enabled:
377
+ return {'Dialog_score': np.mean(rougel_scores)}
378
+
379
+
380
+ # The following code is migrated from the original repo:
381
+ # https://github.com/AlibabaResearch/DAMO-ConvAI/tree/main/api-bank
382
+ def agent_call(messages: List[Dict], agent: ChatAgent):
383
+ r"""Add messages to agent memory and get response."""
384
+ for i, msg in enumerate(messages):
385
+ if msg['role'] == 'user':
386
+ message = BaseMessage.make_user_message(
387
+ role_name="CAMEL User", content=msg['content']
388
+ )
389
+ elif msg['role'] == 'assistant':
390
+ message = BaseMessage.make_assistant_message(
391
+ role_name="CAMEL Assistant", content=msg['content']
392
+ )
393
+ elif msg['role'] == 'system':
394
+ message = BaseMessage.make_assistant_message(
395
+ role_name="System", content=msg['content']
396
+ )
397
+ else:
398
+ raise ValueError(f"Unrecognized role: {msg['role']}")
399
+
400
+ if i == len(messages) - 1:
401
+ break
402
+ agent.record_message(message)
403
+
404
+ response = agent.step(message)
405
+ model_output = response.msgs[0].content
406
+ agent.reset()
407
+ return model_output
408
+
409
+
410
+ def calculate_rouge_l_score(reference, hypothesis):
411
+ r"""Calculate rouge l score between hypothesis and reference."""
412
+ rouge = Rouge()
413
+ scores = rouge.get_scores(hypothesis, reference)
414
+ rouge_l_score = scores[0]['rouge-l']['f']
415
+ return rouge_l_score
416
+
417
+
418
+ def get_api_call(model_output):
419
+ r"""Parse api call from model output."""
420
+ api_call_pattern = r"\[(\w+)\((.*)\)\]"
421
+ api_call_pattern = re.compile(api_call_pattern)
422
+ match = api_call_pattern.search(model_output)
423
+ if match:
424
+ return match.group(0)
425
+ else:
426
+ return None
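A hypothetical example (not part of the added file) of what `get_api_call` extracts; the API name below is made up, and note that the whole bracketed call, brackets included, is returned:

```python
# Illustrative only: get_api_call keeps the full "[ApiName(...)]" span, or None.
output = "Sure, calling it now. [GetWeather(city='Paris', unit='C')]"
print(get_api_call(output))            # [GetWeather(city='Paris', unit='C')]
print(get_api_call("No call here."))   # None
```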
427
+
428
+
429
+ class APIBankSample:
430
+ r"""APIBank sample used to load the datasets."""
431
+
432
+ def __init__(self, chat_history, apis, ground_truth):
433
+ self.chat_history = chat_history
434
+ self.apis = apis
435
+ self.ground_truth = ground_truth
436
+
437
+ def __repr__(self):
438
+ return 'Sample(chat_history={}, apis={}, ground_truth={})'.format(
439
+ self.chat_history, self.apis, self.ground_truth
440
+ )
441
+
442
+ @classmethod
443
+ def from_chat_history(cls, chat_history):
444
+ apis = set()
445
+ api_positions = []
446
+ for i, item in enumerate(chat_history):
447
+ if item['role'] == 'API':
448
+ apis.add(item['api_name'])
449
+ api_positions.append(i)
450
+
451
+ samples = []
452
+ for i in api_positions:
453
+ sample = cls(chat_history[:i], apis, chat_history[i])
454
+ samples.append(sample)
455
+ sample = cls(chat_history[: i + 1], apis, chat_history[i + 1])
456
+ samples.append(sample)
457
+
458
+ return samples
459
+
460
+
461
+ class Evaluator:
462
+ r"""Evaluator for APIBank benchmark."""
463
+
464
+ def __init__(self, samples: List[APIBankSample]):
465
+ # Placeholder for the import, as the import
466
+ # only works after the files have been downloaded
467
+ try:
468
+ from api_bank.tool_manager import ( # type: ignore[import-not-found]
469
+ ToolManager,
470
+ )
471
+ except Exception as e:
472
+ logger.info(f"{e}, Module will be imported after download.")
473
+ self.dataset = samples
474
+ self.sample_ids = list(range(len(self.dataset)))
475
+ os.chdir("api_bank")
476
+ self.tool_manager = ToolManager("apis")
477
+ os.chdir("..")
478
+
479
+ def get_all_sample_ids(self):
480
+ return self.sample_ids
481
+
482
+ def get_api_description(self, api_name):
483
+ return self.tool_manager.get_api_description(api_name)
484
+
485
+ def get_model_input(self, sample_id: int):
486
+ sample = self.dataset[sample_id]
487
+ apis = sample.apis
488
+ chat_history = sample.chat_history
489
+ api_descriptions = []
490
+ for api_name in apis:
491
+ api_descriptions.append(
492
+ self.tool_manager.get_api_description(api_name)
493
+ )
494
+ api_description = '\n'.join(api_descriptions)
495
+ return api_description, chat_history
496
+
497
+ def evaluate(self, sample_id, model_output):
498
+ try:
499
+ from api_bank.api_call_extraction import ( # type: ignore[import-not-found]
500
+ parse_api_call,
501
+ )
502
+ except Exception as e:
503
+ logger.info(f"{e}, Module will be imported after download.")
504
+ sample = self.dataset[sample_id]
505
+ ground_truth = sample.ground_truth
506
+ if ground_truth['role'] == 'API':
507
+ api_name, param_dict = parse_api_call(model_output)
508
+ if api_name != ground_truth['api_name']:
509
+ return False, 'API Name Mismatch: {} vs {}'.format(
510
+ api_name, ground_truth['api_name']
511
+ )
512
+ try:
513
+ result = self.tool_manager.api_call(api_name, **param_dict)
514
+ except Exception as e:
515
+ return False, str(e)
516
+ api = self.tool_manager.init_tool(api_name)
517
+ try:
518
+ correct = api.check_api_call_correctness(
519
+ result, ground_truth['result']
520
+ )
521
+ except KeyError:
522
+ correct = False
523
+ result = 'KeyError' + str(result)
524
+ return correct, result
525
+ elif ground_truth['role'] == 'AI':
526
+ score = calculate_rouge_l_score(ground_truth['text'], model_output)
527
+ return round(score, 4)
528
+
529
+
530
+ API_CALL_PROMPT = '''
531
+ Based on the given API description and the existing \
532
+ conversation history 1..t, please generate the API request \
533
+ that the AI should call in step t+1 and output it in the \
534
+ format of [ApiName(key1='value1', key2='value2', ...)], \
535
+ replace the ApiName with the actual API name, and \
536
+ replace the key and value with the actual parameters. \
537
+ Your output should start with a square bracket "[" \
538
+ and end with a square bracket "]". Do not output any \
539
+ other explanation or prompt or the result of the API call in your output.
540
+ This year is 2023.
541
+ Input:
542
+ User: [User's utterance]
543
+ AI: [AI's utterance]
544
+
545
+ Expected output:
546
+ [ApiName(key1='value1', key2='value2', ...)]
547
+
548
+ API descriptions:
549
+ '''
550
+
551
+ RESPONSE_PROMPT = '''
552
+ Based on the given API description and the existing \
553
+ conversation history 1..t, please generate the next \
554
+ dialog that the AI should response after the API call t.
555
+ This year is 2023.
556
+ Input:
557
+ User: [User's utterance]
558
+ AI: [AI's utterance]
559
+ [ApiName(key1='value1', key2='value2', …)]
560
+
561
+ Expected output:
562
+ AI: [AI's utterance]
563
+
564
+ API descriptions:
565
+ '''
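A hedged end-to-end sketch (not part of the added file) of running the benchmark; it assumes a `ChatAgent` has already been constructed elsewhere and that GitHub is reachable for the one-time data download:

```python
# Illustrative sketch, not part of apibank.py: download the data once, then run
# the level-1 API-call test on a handful of dialogues. `chat_agent` is assumed
# to be a previously constructed camel ChatAgent.
from camel.benchmarks import APIBankBenchmark

benchmark = APIBankBenchmark(save_to="apibank_results.jsonl")
benchmark.download()                  # fetches the api-bank subdirectory
results = benchmark.run(
    agent=chat_agent,
    level="level-1",
    api_test_enabled=True,
    subset=5,
)
print(results["accuracy"])
```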
Paper2Poster/camel/benchmarks/apibench.py ADDED
@@ -0,0 +1,500 @@
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+
15
+ import json
16
+ import logging
17
+ import random
18
+ from pathlib import Path
19
+ from typing import Any, Dict, Literal, Optional
20
+
21
+ import tree_sitter_python as tspython
22
+ from tqdm import tqdm
23
+ from tree_sitter import Language, Parser
24
+
25
+ from camel.agents import ChatAgent
26
+ from camel.benchmarks.base import BaseBenchmark
27
+ from camel.messages import BaseMessage
28
+ from camel.utils import download_github_subdirectory
29
+
30
+ logger = logging.getLogger(__name__)
31
+
32
+
33
+ # Mapping of dataset names to file names
34
+ # 'Oracle' retriver used here which means all the full
35
+ # API documentation will be included in the prompt
36
+ dataset_mapping = {
37
+ "huggingface": {
38
+ "api": "huggingface_api.jsonl",
39
+ "eval": "huggingface_eval.json",
40
+ "train": "huggingface_train.json",
41
+ "questions": "questions_huggingface_oracle.jsonl",
42
+ },
43
+ "tensorflowhub": {
44
+ "api": "tensorflowhub_api.jsonl",
45
+ "eval": "tensorflow_eval.json",
46
+ "train": "tensorflow_train.json",
47
+ "questions": "questions_tensorflowhub_oracle.jsonl",
48
+ },
49
+ "torchhub": {
50
+ "api": "torchhub_api.jsonl",
51
+ "eval": "torchhub_eval.json",
52
+ "train": "torchhub_train.json",
53
+ "questions": "questions_torchhub_oracle.jsonl",
54
+ },
55
+ }
56
+
57
+
58
+ # This function is migrated from the original repo:
59
+ # https://github.com/ShishirPatil/gorilla
60
+ def encode_question(question: str, dataset_name: str) -> str:
61
+ r"""Encode multiple prompt instructions into a single string."""
62
+
63
+ if dataset_name == "torchhub":
64
+ domains = "1. $DOMAIN is inferred from the task description and \
65
+ should include one of {Classification, Semantic Segmentation, \
66
+ Object Detection, Audio Separation, Video Classification, \
67
+ Text-to-Speech}."
68
+ elif dataset_name == "huggingface":
69
+ domains = "1. $DOMAIN should include one of {Multimodal Feature \
70
+ Extraction, Multimodal Text-to-Image, Multimodal \
71
+ Image-to-Text, Multimodal Text-to-Video, \
72
+ Multimodal Visual Question Answering, Multimodal Document \
73
+ Question Answer, Multimodal Graph Machine Learning, \
74
+ Computer Vision Depth Estimation, Computer Vision Image \
75
+ Classification, Computer Vision Object Detection, \
76
+ Computer Vision Image Segmentation, Computer Vision \
77
+ Image-to-Image, Computer Vision Unconditional \
78
+ Image Generation, Computer Vision Video Classification, \
79
+ Computer Vision Zero-Shot Image Classification, \
80
+ Natural Language Processing Text Classification, \
81
+ Natural Language Processing Token Classification, \
82
+ Natural Language Processing Table Question Answering, \
83
+ Natural Language Processing Question Answering, \
84
+ Natural Language Processing Zero-Shot Classification, \
85
+ Natural Language Processing Translation, Natural Language \
86
+ Processing Summarization, Natural Language Processing \
87
+ Conversational, Natural Language Processing Text \
88
+ Generation, Natural Language Processing Fill-Mask, \
89
+ Natural Language Processing Text2Text Generation, \
90
+ Natural Language Processing Sentence Similarity, \
91
+ Audio Text-to-Speech, Audio Automatic Speech Recognition, \
92
+ Audio Audio-to-Audio, Audio Audio Classification, \
93
+ Audio Voice Activity Detection, Tabular Tabular \
94
+ Classification, Tabular Tabular Regression, \
95
+ Reinforcement Learning Reinforcement Learning, \
96
+ Reinforcement Learning Robotics }"
97
+ elif dataset_name == "tensorflowhub":
98
+ domains = "1. $DOMAIN is inferred from the task description \
99
+ and should include one of {text-sequence-alignment, \
100
+ text-embedding, text-language-model, text-preprocessing, \
101
+ text-classification, text-generation, text-question-answering, \
102
+ text-retrieval-question-answering, text-segmentation, \
103
+ text-to-mel, image-classification, image-feature-vector, \
104
+ image-object-detection, image-segmentation, \
105
+ image-generator, image-pose-detection, image-rnn-agent, \
106
+ image-augmentation, image-classifier, image-style-transfer, \
107
+ image-aesthetic-quality, image-depth-estimation, \
108
+ image-super-resolution, image-deblurring, image-extrapolation, \
109
+ image-text-recognition, image-dehazing, image-deraining, \
110
+ image-enhancemenmt, image-classification-logits, \
111
+ image-frame-interpolation, image-text-detection, image-denoising, \
112
+ image-others, video-classification, video-feature-extraction, \
113
+ video-generation, video-audio-text, video-text, \
114
+ audio-embedding, audio-event-classification, audio-command-detection, \
115
+ audio-paralinguists-classification, audio-speech-to-text, \
116
+ audio-speech-synthesis, audio-synthesis, audio-pitch-extraction}"
117
+ else:
118
+ logger.info("Error: API name is not supported.")
119
+
120
+ prompt = (
121
+ question
122
+ + "\nWrite a python program in 1 to 2 lines to call API in "
123
+ + dataset_name
124
+ + ".\n\nThe answer should follow the format: <<<domain>>> $DOMAIN, \
125
+ <<<api_call>>>: $API_CALL, <<<api_provider>>>: $API_PROVIDER, \
126
+ <<<explanation>>>: $EXPLANATION, <<<code>>>: $CODE}. \
127
+ Here are the requirements:\n"
128
+ + domains
129
+ + "\n2. The $API_CALL should have only 1 line of code \
130
+ that calls api.\n 3. The $API_PROVIDER should be the \
131
+ programming framework used.\n4. $EXPLANATION should be \
132
+ a step-by-step explanation.\n5. The $CODE is the python code.\n6. \
133
+ Do not repeat the format in your answer."
134
+ )
135
+ return prompt
136
+
137
+
138
+ class APIBenchBenchmark(BaseBenchmark):
139
+ r"""APIBench Benchmark adopted from `Gorilla: Large Language Model
140
+ Connected with Massive APIs`
141
+ <https://huggingface.co/datasets/gorilla-llm/APIBench>.
142
+
143
+ Args:
144
+ data_dir (str): The directory to save the data.
145
+ save_to (str): The file to save the results.
146
+ processes (int, optional): The number of processes to use.
147
+ (default: :obj:`1`)
148
+ """
149
+
150
+ # TODO: Integrate retriever (pending)
151
+
152
+ def __init__(
153
+ self,
154
+ data_dir: str,
155
+ save_to: str,
156
+ processes: int = 1,
157
+ ):
158
+ r"""Initialize the APIBench benchmark.
159
+
160
+ Args:
161
+ data_dir (str): The directory to save the data.
162
+ save_to (str): The file to save the results.
163
+ processes (int, optional): The number of processes to use for
164
+ parallel processing. (default: :obj:`1`)
165
+ """
166
+ super().__init__("apibench", data_dir, save_to, processes)
167
+
168
+ def download(self):
169
+ r"""Download the APIBench dataset."""
170
+ from huggingface_hub import snapshot_download
171
+
172
+ snapshot_download(
173
+ repo_id="gorilla-llm/APIBench",
174
+ repo_type="dataset",
175
+ local_dir=self.data_dir,
176
+ local_dir_use_symlinks=True,
177
+ )
178
+
179
+ repo = "ShishirPatil/gorilla"
180
+ subdir = "/gorilla/eval/eval-data/questions"
181
+ data_dir = self.data_dir
182
+
183
+ download_github_subdirectory(repo, subdir, data_dir)
184
+
185
+ def load(self, dataset_name: str, force_download: bool = False): # type: ignore[override]
186
+ r"""Load the APIBench Benchmark dataset.
187
+
188
+ Args:
189
+ dataset_name (str): Name of the specific dataset to be loaded.
190
+ force_download (bool, optional): Whether to force
191
+ download the data. (default: :obj:`False`)
192
+ """
193
+
194
+ if force_download:
195
+ logger.info("Force downloading data.")
196
+ self.download()
197
+
198
+ def load_json_lines(file_path: Path):
199
+ r"""Helper function to load JSON lines from a file."""
200
+ try:
201
+ with open(file_path, "r") as f:
202
+ return [json.loads(line) for line in f]
203
+ except FileNotFoundError:
204
+ raise FileNotFoundError(f"File not found: {file_path}")
205
+ except json.JSONDecodeError as e:
206
+ raise ValueError(
207
+ f"Error decoding JSON in file {file_path}: {e}"
208
+ )
209
+
210
+ dataset_path = self.data_dir / dataset_name
211
+ if not dataset_path.exists():
212
+ raise FileNotFoundError(
213
+ f"Dataset directory does not exist: {dataset_path}"
214
+ )
215
+
216
+ for label in ['api', 'eval', 'questions']:
217
+ file_name = dataset_mapping[dataset_name][label]
218
+ file_path = (
219
+ dataset_path / file_name
220
+ if label == 'questions'
221
+ else self.data_dir / file_name
222
+ )
223
+
224
+ # Load data based on label type
225
+ if label in ['api', 'questions', 'eval']:
226
+ data = load_json_lines(file_path)
227
+
228
+ if label == 'eval':
229
+ # Extract 'api_data' specifically for eval label
230
+ data = [item['api_data'] for item in data]
231
+
232
+ self._data[label] = data
233
+ else:
234
+ raise ValueError(f"Unknown label: {label}")
235
+
236
+ ast_database = []
237
+ for data in self._data['api']:
238
+ ast_tree = ast_parse(data['api_call'])
239
+ ast_database.append(ast_tree)
240
+ self._data['ast'] = ast_database
241
+
242
+ def run( # type: ignore[override]
243
+ self,
244
+ agent: ChatAgent,
245
+ dataset_name: Literal["huggingface", "tensorflowhub", "torchhub"],
246
+ randomize: bool = False,
247
+ subset: Optional[int] = None,
248
+ ) -> Dict[str, Any]:
249
+ r"""Run the benchmark.
250
+
251
+ Args:
252
+ agent (ChatAgent): The agent to run the
253
+ benchmark.
254
+ dataset_name (Literal["huggingface",
255
+ "tensorflowhub", "torchhub"]):
256
+ The dataset to run the benchmark.
257
+ randomize (bool, optional): Whether to randomize the data.
258
+ (default: :obj:`False`)
259
+ subset (Optional[int], optional): The subset of data to run.
260
+ (default: :obj:`None`)
261
+ """
262
+
263
+ if dataset_name not in dataset_mapping:
264
+ raise ValueError(f"Invalid value for dataset: {dataset_name}.")
265
+
266
+ logger.info(f"Running APIBench benchmark on {dataset_name}.")
267
+ self.load(dataset_name)
268
+ datas = self._data['questions']
269
+
270
+ # Shuffle and subset data if necessary
271
+ if randomize:
272
+ random.shuffle(datas)
273
+ if subset:
274
+ datas = datas[:subset]
275
+
276
+ logger.info(f"Number of tasks: {len(datas)}")
277
+
278
+ # Initialize results storage
279
+ self._results = []
280
+
281
+ with open(self.save_to, "w") as f:
282
+ for question in tqdm(datas, desc="Running"):
283
+ prompt = encode_question(question["text"], dataset_name)
284
+ msg = BaseMessage.make_user_message(
285
+ role_name="User", content=prompt
286
+ )
287
+ try:
288
+ # Generate response
289
+ responses = agent.step(msg)
290
+ response = responses.msgs[0].content
291
+ api_database = self._data['api']
292
+ qa_pairs = self._data['eval']
293
+ ast_database = self._data['ast']
294
+ question_id = question['question_id']
295
+
296
+ # Evaluate response
297
+ error, correct, hallucination = evaluate_response(
298
+ response,
299
+ question_id,
300
+ dataset_name,
301
+ api_database,
302
+ qa_pairs,
303
+ ast_database,
304
+ )
305
+ self._results.append(
306
+ {
307
+ "question": question,
308
+ "agent_response": response,
309
+ "correct": correct,
310
+ "hallucination": hallucination,
311
+ "error": str(error) if error else None,
312
+ }
313
+ )
314
+ except Exception as e:
315
+ logger.warning(
316
+ f"Error in processing task: {question}: {e}"
317
+ )
318
+ self._results.append(
319
+ {
320
+ "question": question,
321
+ "agent_response": None,
322
+ "correct": False,
323
+ "hallucination": False,
324
+ "error": str(e),
325
+ }
326
+ )
327
+
328
+ agent.reset()
329
+
330
+ f.write(json.dumps(self._results[-1], indent=2) + "\n")
331
+ f.flush()
332
+
333
+ total = len(self._results)
334
+ correct = sum(r["correct"] for r in self.results)
335
+ hallucination = sum(r["hallucination"] for r in self.results)
336
+
337
+ return {
338
+ "total": total,
339
+ "correct": correct,
340
+ "hallucination": hallucination,
341
+ "accuracy": correct / total if total else "N/A",
342
+ "hallucination rate": hallucination / total if total else "N/A",
343
+ }
344
+
345
+
346
+ # This code is modified from the
347
+ # evaluators in the original repo
348
+ # https://github.com/ShishirPatil/gorilla
349
+ # Get all the subtrees given a root_node
350
+ def get_all_sub_trees(root_node):
351
+ node_stack = []
352
+ sub_tree_sexp_list = []
353
+ depth = 1
354
+ # text = root_node.text
355
+ node_stack.append([root_node, depth])
356
+ while len(node_stack) != 0:
357
+ cur_node, cur_depth = node_stack.pop()
358
+ if cur_node.child_count > 0:
359
+ sub_tree_sexp_list.append(
360
+ [
361
+ str(cur_node),
362
+ cur_depth,
363
+ cur_node,
364
+ cur_node.children[0].text,
365
+ ]
366
+ )
367
+ else:
368
+ sub_tree_sexp_list.append(
369
+ [str(cur_node), cur_depth, cur_node, None]
370
+ )
371
+ for child_node in cur_node.children:
372
+ if len(child_node.children) != 0:
373
+ depth = cur_depth + 1
374
+ node_stack.append([child_node, depth])
375
+ return sub_tree_sexp_list
376
+
377
+
378
+ # Parse the program into AST trees
379
+ def ast_parse(candidate):
380
+ PY_LANGUAGE = Language(tspython.language())
381
+ parser = Parser(PY_LANGUAGE)
382
+
383
+ candidate_tree = parser.parse(bytes(candidate, "utf8")).root_node
384
+ return candidate_tree
385
+
386
+
387
+ # Get all the arguments in the ast tree
388
+ def get_args(node, dataset_name):
389
+ if node.child_count == 0:
390
+ return []
391
+ args_list = []
392
+ if dataset_name == "huggingface":
393
+ for child in node.children[0].children[0].children[1].children:
394
+ if "=" in child.text.decode():
395
+ args_list.append(child.children[2].text)
396
+ elif (
397
+ child.text.decode() != "("
398
+ and child.text.decode() != ")"
399
+ and child.text.decode() != ","
400
+ ):
401
+ args_list.append(child.text)
402
+ elif dataset_name == "tensorflowhub":
403
+ for child in node.children[0].children[0].children[1].children:
404
+ if (
405
+ 'model=' in child.text.decode()
406
+ or 'model =' in child.text.decode()
407
+ ):
408
+ args_list.append(child.children[2].text)
409
+ elif (
410
+ child.text.decode() != "("
411
+ and child.text.decode() != ")"
412
+ and child.text.decode() != ","
413
+ ):
414
+ args_list.append(child.text)
415
+ elif dataset_name == "torchhub":
416
+ for child in node.children[0].children[0].children[1].children:
417
+ if (
418
+ "repo_or_dir" in child.text.decode()
419
+ or "model" in child.text.decode()
420
+ ):
421
+ args_list.append(child.children[2].text)
422
+ return args_list
423
+
424
+
425
+ # Check if there is an api match
426
+ def ast_check(candidate_subtree_list, base_tree_list, dataset_name):
427
+ for idx, base_tree in enumerate(base_tree_list):
428
+ if base_tree.children[0].children[0].child_count == 0:
429
+ continue
430
+ api_name = base_tree.children[0].children[0].children[0].text
431
+ for candidate_tree in candidate_subtree_list:
432
+ if candidate_tree[3] == api_name:
433
+ break
434
+ # Now we have a sub-tree
435
+ candidate_tree = candidate_tree[2]
436
+ args_list = get_args(base_tree, dataset_name)
437
+ if len(args_list) == 0:
438
+ continue
439
+ ast_match = True
440
+ for arg in args_list:
441
+ if (
442
+ arg.decode().lstrip("'").rstrip("'")
443
+ not in candidate_tree.text.decode()
444
+ ):
445
+ ast_match = False
446
+ break
447
+ if ast_match:
448
+ return idx
449
+ return -1
450
+
451
+
452
+ def evaluate_response(
453
+ response, question_id, dataset_name, api_database, qa_pairs, ast_database
454
+ ):
455
+ try:
456
+ # Index the "api_call" domain
457
+ output = response.split("api_call")
458
+ if len(output) == 1:
459
+ api_call = output[0]
460
+ else:
461
+ # Parse the output
462
+ output = output[1].split("api_provider")[0]
463
+ if ":" not in output:
464
+ start = 0
465
+ else:
466
+ start = output.index(":")
467
+ if ")" not in output:
468
+ end = -2
469
+ else:
470
+ end = output.rindex(")")
471
+ api_call = output[start + 2 : end + 1]
472
+
473
+ try:
474
+ ast_tree = ast_parse(api_call)
475
+ except Exception as parse_error:
476
+ print(f"Error parsing api_call: {api_call}, error: {parse_error}")
477
+ return parse_error, False, False
478
+ # Search for a subtree
479
+ ast_subtree_list = get_all_sub_trees(ast_tree)
480
+ # Check which ast tree is matching
481
+ database_index = ast_check(
482
+ ast_subtree_list, ast_database, dataset_name
483
+ )
484
+ # We cannot index this ast in our database
485
+ if database_index == -1:
486
+ hallucination = True
487
+ correct = False
488
+ # We index our reference api_call
489
+ ref_api_call = api_database[database_index]
490
+ # Check for functionality
491
+ if ref_api_call['domain'] == qa_pairs[question_id - 1]['domain']:
492
+ correct = True
493
+ hallucination = False
494
+ else:
495
+ return None, False, False
496
+ except Exception as e:
497
+ print(f'Error parsing response: {response}, error: {e}')
498
+ return e, False, False
499
+
500
+ return None, correct, hallucination
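A minimal usage sketch of this benchmark is shown below; it assumes a working camel ChatAgent and network access for the downloads, and the paths, system prompt, and subset size are placeholders rather than recommended settings.

# Hedged usage sketch: paths, prompt, and subset size are illustrative only.
from camel.agents import ChatAgent
from camel.benchmarks.apibench import APIBenchBenchmark

benchmark = APIBenchBenchmark(data_dir="datasets/apibench", save_to="apibench_results.jsonl")
benchmark.download()  # fetches the APIBench dataset and the Gorilla question files

agent = ChatAgent("You answer with a single API call in the requested format.")
summary = benchmark.run(agent, dataset_name="torchhub", randomize=False, subset=10)
print(summary["accuracy"], summary["hallucination rate"])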
Paper2Poster/camel/benchmarks/base.py ADDED
@@ -0,0 +1,152 @@
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+
15
+ import logging
16
+ from abc import ABC, abstractmethod
17
+ from pathlib import Path
18
+ from typing import Any, Dict, List, Literal, Optional
19
+
20
+ from camel.agents import ChatAgent
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ class BaseBenchmark(ABC):
26
+ r"""Base class for benchmarks.
27
+
28
+ Attributes:
29
+ name (str): Name of the benchmark.
30
+ data_dir (str): Path to the data directory.
31
+ save_to (str): Path to save the results.
32
+ processes (int): Number of processes to use for parallel
33
+ processing. (default: :obj:`1`)
34
+ """
35
+
36
+ def __init__(
37
+ self, name: str, data_dir: str, save_to: str, processes: int = 1
38
+ ):
39
+ r"""Initialize the benchmark.
40
+
41
+ Args:
42
+ name (str): Name of the benchmark.
43
+ data_dir (str): Path to the data directory.
44
+ save_to (str): Path to save the results.
45
+ processes (int): Number of processes to use for parallel
46
+ processing. (default: :obj:`1`)
47
+
48
+ """
49
+ self.name = name
50
+ self.data_dir = Path(data_dir)
51
+ self.processes = processes
52
+ self.save_to = save_to
53
+ if not self.data_dir.exists():
54
+ logger.info(
55
+ f"Data directory {data_dir} does not exist. Creating it."
56
+ )
57
+ self.data_dir.mkdir(parents=True, exist_ok=True)
58
+ if not self.data_dir.is_dir():
59
+ raise NotADirectoryError(
60
+ f"Data directory {data_dir} is not a directory"
61
+ )
62
+ self._data: Dict[str, List[Dict[str, Any]]] = dict()
63
+ self._results: List[Dict[str, Any]] = []
64
+
65
+ @abstractmethod
66
+ def download(self) -> "BaseBenchmark":
67
+ r"""Download the benchmark data.
68
+
69
+ Returns:
70
+ BaseBenchmark: The benchmark instance.
71
+ """
72
+ pass
73
+
74
+ @abstractmethod
75
+ def load(self, force_download: bool = False) -> "BaseBenchmark":
76
+ r"""Load the benchmark data.
77
+
78
+ Args:
79
+ force_download (bool): Whether to force download the data.
80
+
81
+ Returns:
82
+ BaseBenchmark: The benchmark instance.
83
+ """
84
+ pass
85
+
86
+ @property
87
+ def train(self) -> List[Dict[str, Any]]:
88
+ r"""Get the training data.
89
+
90
+ Returns:
91
+ List[Dict[str, Any]]: The training data.
92
+ """
93
+ if not self._data:
94
+ logger.info("Data not loaded. Loading data.")
95
+ self.load()
96
+ return self._data["train"]
97
+
98
+ @property
99
+ def valid(self) -> List[Dict[str, Any]]:
100
+ r"""Get the validation data.
101
+
102
+ Returns:
103
+ List[Dict[str, Any]]: The validation data.
104
+ """
105
+ if not self._data:
106
+ logger.info("Data not loaded. Loading data.")
107
+ self.load()
108
+ return self._data["valid"]
109
+
110
+ @property
111
+ def test(self) -> List[Dict[str, Any]]:
112
+ r"""Get the test data.
113
+
114
+ Returns:
115
+ List[Dict[str, Any]]: The test data.
116
+ """
117
+ if not self._data:
118
+ logger.info("Data not loaded. Loading data.")
119
+ self.load()
120
+ return self._data["test"]
121
+
122
+ @abstractmethod
123
+ def run(
124
+ self,
125
+ agent: ChatAgent,
126
+ on: Literal["train", "valid", "test"],
127
+ randomize: bool = False,
128
+ subset: Optional[int] = None,
129
+ *args,
130
+ **kwargs,
131
+ ) -> "BaseBenchmark":
132
+ r"""Run the benchmark.
133
+
134
+ Args:
135
+ agent (ChatAgent): The chat agent.
136
+ on (str): The data split to run the benchmark on.
137
+ randomize (bool): Whether to randomize the data.
138
+ subset (int): The subset of the data to run the benchmark on.
139
+
140
+ Returns:
141
+ BaseBenchmark: The benchmark instance.
142
+ """
143
+ pass
144
+
145
+ @property
146
+ def results(self) -> List[Dict[str, Any]]:
147
+ r"""Get the results.
148
+
149
+ Returns:
150
+ List[Dict[str, Any]]: The results.
151
+ """
152
+ return self._results
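As a rough sketch, a new benchmark would subclass BaseBenchmark and implement the three abstract methods; the dataset fields and class name below are hypothetical, not part of the library.

# Illustrative skeleton only; the "question"/"answer" fields are made up.
from typing import Literal, Optional

from camel.agents import ChatAgent
from camel.benchmarks.base import BaseBenchmark
from camel.messages import BaseMessage


class ToyBenchmark(BaseBenchmark):
    def download(self) -> "BaseBenchmark":
        return self  # nothing to fetch for this toy example

    def load(self, force_download: bool = False) -> "BaseBenchmark":
        self._data = {
            "train": [],
            "valid": [],
            "test": [{"question": "What is 2 + 2?", "answer": "4"}],
        }
        return self

    def run(
        self,
        agent: ChatAgent,
        on: Literal["train", "valid", "test"],
        randomize: bool = False,
        subset: Optional[int] = None,
        *args,
        **kwargs,
    ) -> "BaseBenchmark":
        for item in self._data[on][:subset]:
            msg = BaseMessage.make_user_message(role_name="User", content=item["question"])
            reply = agent.step(msg).msgs[0].content
            self._results.append({"expected": item["answer"], "got": reply})
        return self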
Paper2Poster/camel/benchmarks/gaia.py ADDED
@@ -0,0 +1,478 @@
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+
15
+ import json
16
+ import logging
17
+ import os
18
+ import random
19
+ import re
20
+ import string
21
+ import uuid
22
+ from pathlib import Path
23
+ from typing import Any, Dict, List, Literal, Optional, Protocol, Union
24
+
25
+ from tqdm import tqdm
26
+
27
+ from camel.agents import ChatAgent
28
+ from camel.benchmarks.base import BaseBenchmark
29
+ from camel.messages import BaseMessage
30
+ from camel.retrievers.auto_retriever import AutoRetriever
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
+ class RetrieverProtocol(Protocol):
36
+ r"""Protocol for the retriever class. Any retriever class implementing
37
+ this protocol can be used in the benchmark class.
38
+ """
39
+
40
+ def retrieve(
41
+ self, query: str, contents: List[str], **kwargs: Dict[str, Any]
42
+ ) -> Dict[str, Any]:
43
+ r"""Retrieve the relevant content for the query.
44
+
45
+ Args:
46
+ query (str): The query to retrieve the content for.
47
+ contents (List[str]): The list of contents to search in.
48
+ **kwargs (Dict[str, Any]): Additional keyword arguments.
49
+
50
+ Returns:
51
+ Dict[str, Any]: The relevant content for the query.
52
+ """
53
+ ...
54
+
55
+ def reset(self, **kwargs) -> bool:
56
+ r"""Reset the retriever.
57
+ Some benchmarks may require resetting the retriever
58
+ after each query.
59
+
60
+ Args:
61
+ **kwargs: Additional keyword arguments.
62
+
63
+ Returns:
64
+ bool: True if the reset was successful, False otherwise.
65
+ """
66
+ ...
67
+
68
+
69
+ class DefaultGAIARetriever(AutoRetriever):
70
+ r"""Default retriever for the GAIA benchmark.
71
+ This retriever uses AutoRetriever in camel to retrieve the content based on
72
+ the query.
73
+ """
74
+
75
+ def retrieve(
76
+ self, query: str, contents: List[str], **kwargs: Any
77
+ ) -> Dict[str, Any]:
78
+ r"""Retrieve the content based on the query.
79
+
80
+ Args:
81
+ query (str): The query to search for.
82
+ contents (List[str]): The list of contents to search from.
83
+ **kwargs (Any): The keyword arguments to pass to the
84
+ retriever.
85
+
86
+ Returns:
87
+ Dict[str, Any]: The retrieved content.
88
+ """
89
+ return self.run_vector_retriever(query, contents, **kwargs) # type: ignore[arg-type]
90
+
91
+ def reset(self, **kwargs: Any) -> bool:
92
+ r"""Reset the retriever.
93
+
94
+ Args:
95
+ **kwargs (Any): The keyword arguments to pass to the
96
+ retriever.
97
+
98
+ Returns:
99
+ bool: Whether the reset was successful.
100
+ """
101
+ path = Path(self.vector_storage_local_path or os.getcwd())
102
+ task_id = str(kwargs.get("task_id", uuid.uuid4()))
103
+ retriever_dir = path / task_id
104
+ if not retriever_dir.exists():
105
+ try:
106
+ retriever_dir.mkdir(parents=True)
107
+ except Exception as e:
108
+ logger.error(
109
+ "Error in creating directory: " + f"{retriever_dir}: {e!s}"
110
+ )
111
+ return False
112
+ self.vector_storage_local_path = str(retriever_dir)
113
+ return True
114
+
115
+
116
+ class GAIABenchmark(BaseBenchmark):
117
+ r"""GAIA Benchmark adapted from `"GAIA: a benchmark for General AI
118
+ Assistants"
119
+ <https://huggingface.co/datasets/gaia-benchmark/GAIA>`_.
120
+
121
+ Args:
122
+ data_dir (str): The directory to save the data.
123
+ save_to (str): The file to save the results.
124
+ retriever (Optional[RetrieverProtocol]): The retriever to use.
125
+ (default: :obj:`None`)
126
+ processes (int, optional): The number of processes to use.
127
+ (default: :obj:`1`)
128
+ """
129
+
130
+ def __init__(
131
+ self,
132
+ data_dir: str,
133
+ save_to: str,
134
+ retriever: Optional[RetrieverProtocol] = None,
135
+ processes: int = 1,
136
+ ):
137
+ r"""Initialize the GAIA benchmark.
138
+
139
+ Args:
140
+ data_dir (str): The directory to save the data.
141
+ save_to (str): The file to save the results.
142
+ retriever (Optional[RetrieverProtocol], optional): The retriever to
143
+ use. (default: :obj:`None`)
144
+ processes (int, optional): The number of processes to use for
145
+ parallel processing. (default: :obj:`1`)
146
+ """
147
+ super().__init__("gaia", data_dir, save_to, processes)
148
+ self.retriever = retriever or DefaultGAIARetriever()
149
+
150
+ def download(self):
151
+ r"""Download the GAIA dataset."""
152
+ from huggingface_hub import snapshot_download
153
+
154
+ snapshot_download(
155
+ repo_id="gaia-benchmark/GAIA",
156
+ repo_type="dataset",
157
+ local_dir=self.data_dir,
158
+ local_dir_use_symlinks=True,
159
+ )
160
+
161
+ def load(self, force_download=False):
162
+ r"""Load the GAIA dataset.
163
+
164
+ Args:
165
+ force_download (bool, optional): Whether to
166
+ force download the data.
167
+ """
168
+ if force_download:
169
+ logger.info("Force downloading data.")
170
+ self.download()
171
+
172
+ # Define validation and test directories
173
+ valid_dir = self.data_dir / "2023/validation"
174
+ test_dir = self.data_dir / "2023/test"
175
+
176
+ # Check if directories exist; if not, download the data
177
+ if not valid_dir.is_dir() or not test_dir.is_dir():
178
+ logger.info("Data not found. Downloading data.")
179
+ self.download()
180
+
181
+ # Load metadata for both validation and test datasets
182
+ for path, label in zip([valid_dir, test_dir], ["valid", "test"]):
183
+ self._data[label] = []
184
+ with open(path / "metadata.jsonl", "r") as f:
185
+ lines = f.readlines()
186
+ for line in lines:
187
+ data = json.loads(line)
188
+ if data["task_id"] == "0-0-0-0-0":
189
+ continue
190
+ if data["file_name"]:
191
+ data["file_name"] = path / data["file_name"]
192
+ self._data[label].append(data)
193
+ return self
194
+
195
+ @property
196
+ def train(self):
197
+ r"""Get the training set."""
198
+ raise NotImplementedError("GAIA does not have a training set.")
199
+
200
+ def run( # type: ignore[override]
201
+ self,
202
+ agent: ChatAgent,
203
+ on: Literal["train", "valid", "test"],
204
+ level: Union[int, List[int], Literal["all"]],
205
+ randomize: bool = False,
206
+ subset: Optional[int] = None,
207
+ ) -> Dict[str, Any]:
208
+ r"""Run the benchmark.
209
+
210
+ Args:
211
+ agent (ChatAgent): The agent to run the benchmark.
212
+ on (Literal["valid", "test"]): The set to run the benchmark.
213
+ level (Union[int, List[int], Literal["all"]]): The level to run
214
+ the benchmark.
215
+ randomize (bool, optional): Whether to randomize the data.
216
+ (default: :obj:`False`)
217
+ subset (Optional[int], optional): The subset of data to run.
218
+ (default: :obj:`None`)
219
+
220
+ Returns:
221
+ Dict[str, Any]: The results of the benchmark.
222
+ """
223
+ # Validate inputs
224
+ if on not in ["valid", "test"]:
225
+ raise ValueError(
226
+ f"Invalid value for `on`: {on}, expected 'valid' or 'test'."
227
+ )
228
+
229
+ levels = (
230
+ [1, 2, 3]
231
+ if level == "all"
232
+ else [level]
233
+ if isinstance(level, int)
234
+ else level
235
+ )
236
+ if not all(
237
+ isinstance(level, int) and level in [1, 2, 3] for level in levels
238
+ ):
239
+ raise ValueError(
240
+ f"Invalid value for `level`: {level}, expected 1, 2, 3 "
241
+ "or 'all'."
242
+ )
243
+
244
+ logger.info(f"Running benchmark on {on} set at levels {levels}.")
245
+ datas = [data for data in self._data[on] if data["Level"] in levels]
246
+
247
+ # Shuffle and subset data if necessary
248
+ if randomize:
249
+ random.shuffle(datas)
250
+ if subset:
251
+ datas = datas[:subset]
252
+
253
+ logger.info(f"Number of tasks: {len(datas)}")
254
+
255
+ # Initialize results storage
256
+ self._results = []
257
+
258
+ # Process tasks
259
+ with open(self.save_to, "w") as f:
260
+ for task in tqdm(datas, desc="Running"):
261
+ if not self._prepare_task(task):
262
+ continue
263
+
264
+ try:
265
+ result = agent.step(self._create_user_message(task))
266
+ self._process_result(agent, task, result, f)
267
+ except Exception as e:
268
+ self._handle_error(task, e, f)
269
+ finally:
270
+ agent.reset()
271
+
272
+ return self._generate_summary()
273
+
274
+ def _prepare_task(self, task: Dict[str, Any]) -> bool:
275
+ r"""Prepare the task by validating and enriching its data."""
276
+ if task["file_name"]:
277
+ file_path = Path(task["file_name"])
278
+ if not file_path.exists():
279
+ logger.info(
280
+ f"Skipping task because file not found: {file_path}"
281
+ )
282
+ return False
283
+ if file_path.suffix in [".pdf", ".docx", ".doc", ".txt"]:
284
+ if not self.retriever.reset(task_id=task["task_id"]):
285
+ return False
286
+ retrieved_info = self.retriever.retrieve(
287
+ query=task["Question"], contents=[task["file_name"]]
288
+ )
289
+ retrieved_content = [
290
+ item["text"]
291
+ for item in retrieved_info.get("Retrieved Context", [])
292
+ ]
293
+ if retrieved_content:
294
+ task["Question"] += "\n" + "\n".join(retrieved_content)
295
+ else:
296
+ logger.info(
297
+ f"Skipping task due to unsupported file "
298
+ f"format: {file_path.suffix}"
299
+ )
300
+ return False
301
+ return True
302
+
303
+ def _create_user_message(self, task: Dict[str, Any]) -> BaseMessage:
304
+ r"""Create a user message from a task."""
305
+ return BaseMessage.make_user_message(
306
+ role_name="User",
307
+ content=task["Question"],
308
+ )
309
+
310
+ def _process_result(
311
+ self,
312
+ agent: ChatAgent,
313
+ task: Dict[str, Any],
314
+ result: Any,
315
+ file_obj: Any,
316
+ ) -> None:
317
+ r"""Process and store the result of a task."""
318
+ model_answer = self.get_final_answer(result.msgs[0].content)
319
+ final_answer = task["Final answer"]
320
+ score = self.question_scorer(model_answer, final_answer)
321
+ tool_calls = result.info.get("tool_calls", [])
322
+
323
+ result_data = {
324
+ "task_id": task["task_id"],
325
+ "question": task["Question"],
326
+ "level": task["Level"],
327
+ "model_answer": model_answer,
328
+ "ground_truth": final_answer,
329
+ "tool_calls": [tool.model_dump() for tool in tool_calls],
330
+ "error": None,
331
+ "score": int(score),
332
+ "history": agent.memory.get_context(),
333
+ }
334
+ self._results.append(result_data)
335
+ file_obj.write(json.dumps(result_data, indent=2) + "\n")
336
+ file_obj.flush()
337
+
338
+ def _handle_error(
339
+ self, task: Dict[str, Any], error: Exception, file_obj: Any
340
+ ) -> None:
341
+ r"""Handle errors encountered during task processing."""
342
+ logger.warning(f"Error processing task {task['task_id']}: {error}")
343
+ error_data = {
344
+ "task_id": task["task_id"],
345
+ "question": task["Question"],
346
+ "level": task["Level"],
347
+ "model_answer": "ERROR",
348
+ "ground_truth": task["Final answer"],
349
+ "tool_calls": [],
350
+ "error": str(error),
351
+ "score": 0,
352
+ }
353
+ self._results.append(error_data)
354
+ file_obj.write(json.dumps(error_data, indent=2) + "\n")
355
+ file_obj.flush()
356
+
357
+ def _generate_summary(self) -> Dict[str, Any]:
358
+ r"""Generate and return a summary of the benchmark results."""
359
+ return {
360
+ "total": len(self._results),
361
+ "correct": sum(result["score"] for result in self._results),
362
+ "results": self._results,
363
+ }
364
+
365
+ def question_scorer(self, model_answer: str, ground_truth: str) -> bool:
366
+ r"""Scorer for the GAIA benchmark.
367
+ https://huggingface.co/spaces/gaia-benchmark/leaderboard/blob/main/
368
+ scorer.py
369
+
370
+ Args:
371
+ model_answer (str): The model answer.
372
+ ground_truth (str): The ground truth answer.
373
+
374
+ Returns:
375
+ bool: The score of the model
376
+ """
377
+
378
+ def is_float(element: Any) -> bool:
379
+ try:
380
+ float(element)
381
+ return True
382
+ except ValueError:
383
+ return False
384
+
385
+ if is_float(ground_truth):
386
+ logger.info(f"Evaluating {model_answer} as a number.")
387
+ normalized_answer = self.normalize_number_str(model_answer)
388
+ return normalized_answer == float(ground_truth)
389
+
390
+ elif any(char in ground_truth for char in [",", ";"]):
391
+ logger.info(
392
+ f"Evaluating {model_answer} as a comma separated list."
393
+ )
394
+ gt_elems = self.split_string(ground_truth)
395
+ ma_elems = self.split_string(model_answer)
396
+
397
+ if len(gt_elems) != len(ma_elems):
398
+ logger.warning(
399
+ "Answer lists have different lengths, returning False.",
400
+ UserWarning,
401
+ )
402
+ return False
403
+
404
+ comparisons = []
405
+ for ma_elem, gt_elem in zip(ma_elems, gt_elems):
406
+ if is_float(gt_elem):
407
+ normalized_ma_elem = self.normalize_number_str(ma_elem)
408
+ comparisons.append(normalized_ma_elem == float(gt_elem))
409
+ else:
410
+ ma_elem = self.normalize_str(ma_elem, remove_punct=False)
411
+ gt_elem = self.normalize_str(gt_elem, remove_punct=False)
412
+ comparisons.append(ma_elem == gt_elem)
413
+ return all(comparisons)
414
+ else:
415
+ logger.info(f"Evaluating {model_answer} as a string.")
416
+ ma_elem = self.normalize_str(model_answer)
417
+ gt_elem = self.normalize_str(ground_truth)
418
+ return ma_elem == gt_elem
419
+
420
+ def normalize_number_str(self, number_str: str) -> float:
421
+ for char in ["$", "%", ","]:
422
+ number_str = number_str.replace(char, "")
423
+ try:
424
+ return float(number_str)
425
+ except ValueError:
426
+ logger.error(
427
+ f"String {number_str} cannot be normalized to number str."
428
+ )
429
+ return float("inf")
430
+
431
+ def split_string(
432
+ self, s: str, char_list: Optional[List[str]] = None
433
+ ) -> list[str]:
434
+ r"""Split a string based on a list of characters.
435
+
436
+ Args:
437
+ s (str): The string to split.
438
+ char_list (Optional[List[str]], optional): The list of
+ characters to split on.
440
+ (default: :obj:`None`)
441
+ """
442
+ if char_list is None:
443
+ char_list = [",", ";"]
444
+ pattern = f"[{''.join(char_list)}]"
445
+ return re.split(pattern, s)
446
+
447
+ def normalize_str(self, input_str, remove_punct=True) -> str:
448
+ r"""Normalize a string.
449
+
450
+ Args:
451
+ input_str: The input string to normalize.
452
+ remove_punct: Whether to remove punctuation.
453
+
454
+ Returns:
455
+ str: The normalized string.
456
+ """
457
+ no_spaces = re.sub(r"\s", "", input_str)
458
+ if remove_punct:
459
+ translator = str.maketrans("", "", string.punctuation)
460
+ return no_spaces.lower().translate(translator)
461
+ else:
462
+ return no_spaces.lower()
463
+
464
+ def get_final_answer(self, content: str) -> str:
465
+ r"""Get the final answer from the content.
466
+
467
+ Args:
468
+ content (str): The content to extract the final answer from.
469
+
470
+ Returns:
471
+ str: The final answer.
472
+ """
473
+ final_answer_index = content.find("FINAL ANSWER")
474
+ if final_answer_index == -1:
475
+ return "FINAL ANSWER not found"
476
+ start_index = final_answer_index + len("FINAL ANSWER: ")
477
+ final_answer_content = content[start_index:].strip()
478
+ return final_answer_content
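The scorer's three normalization paths can be illustrated with hand-made values (not real GAIA answers); this sketch assumes the default retriever can be constructed without further configuration.

# Hedged sketch of question_scorer behaviour on made-up answers.
from camel.benchmarks.gaia import GAIABenchmark

bench = GAIABenchmark(data_dir="datasets/gaia", save_to="gaia_results.jsonl")

bench.question_scorer("$3,120.50", "3120.5")    # numeric: "$", "%", "," stripped, compared as floats -> True
bench.question_scorer("cat; dog", "cat,dog")    # list: split on "," or ";", compared element-wise -> True
bench.question_scorer("New  York", "new york")  # string: whitespace/punctuation removed, lowercased -> True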
Paper2Poster/camel/benchmarks/nexus.py ADDED
@@ -0,0 +1,518 @@
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+
15
+ import ast
16
+ import json
17
+ import logging
18
+ import os
19
+ import random
20
+ import textwrap
21
+ from dataclasses import dataclass
22
+ from pathlib import Path
23
+ from typing import Any, Dict, List, Literal, Optional, Tuple, Union
24
+
25
+ import pandas as pd
26
+ from datasets import load_dataset
27
+ from tqdm import tqdm
28
+
29
+ from camel.agents import ChatAgent
30
+ from camel.benchmarks.base import BaseBenchmark
31
+ from camel.messages import BaseMessage
32
+
33
+ logger = logging.getLogger(__name__)
34
+
35
+
36
+ # Define the data class
37
+ @dataclass
38
+ class NexusSample:
39
+ r"""Nexus benchmark dataset sample."""
40
+
41
+ input: str
42
+ output: str
43
+
44
+
45
+ @dataclass
46
+ class NexusTool:
47
+ r"""Nexus benchmark tool"""
48
+
49
+ function_calls: str
50
+ descriptions: str
51
+
52
+
53
+ dataset_mapping = {
54
+ "NVDLibrary": "Nexusflow/NVDLibraryBenchmark",
55
+ "VirusTotal": "Nexusflow/VirusTotalBenchmark",
56
+ "PlacesAPI": "Nexusflow/PlacesAPIBenchmark",
57
+ "ClimateAPI": "Nexusflow/ClimateAPIBenchmark",
58
+ "OTX": "Nexusflow/OTXAPIBenchmark",
59
+ "VirusTotal-NestedCalls": "Nexusflow/vt_multiapi",
60
+ "VirusTotal-ParallelCalls": "Nexusflow/vt_multiapi",
61
+ "NVDLibrary-NestedCalls": "Nexusflow/CVECPEAPIBenchmark",
62
+ }
63
+
64
+ TOOL_CALLING_PROMPT = """
65
+ You are given multiple functions and a user query.
66
+
67
+ Please proceed with generating a function call for the function \
68
+ with the proper arguments that best answers the given prompt.
69
+
70
+ Respond with nothing but the function call ONLY, such that I can \
71
+ directly execute your function call without any post processing \
72
+ necessary from my end. Do not use variables.
73
+ If there are more than two function calls, separate them with a semicolon (;).
74
+
75
+ {tools}
76
+
77
+ Question: {input}
78
+ """
79
+
80
+
81
+ class NexusBenchmark(BaseBenchmark):
82
+ r"""Nexus Function Calling Benchmark adapted from `NexusRaven V2
83
+ Function Calling Benchmark`
84
+ <https://huggingface.co/collections/Nexusflow/nexusraven-v2-function-calling-benchmark-657a597fb84dbe7a09ebfc3e>.
85
+
86
+ Args:
87
+ data_dir (str): The directory to save the data.
88
+ save_to (str): The file to save the results.
89
+ processes (int, optional): The number of processes to use.
90
+ (default: :obj:`1`)
91
+ """
92
+
93
+ def __init__(
94
+ self,
95
+ data_dir: str,
96
+ save_to: str,
97
+ processes: int = 1,
98
+ ):
99
+ r"""Initialize the Nexus Function Calling benchmark.
100
+
101
+ Args:
102
+ data_dir (str): The directory to save the data.
103
+ save_to (str): The file to save the results.
104
+ processes (int, optional): The number of processes to use for
105
+ parallel processing. (default: :obj:`1`)
106
+ """
107
+ super().__init__("nexus", data_dir, save_to, processes)
108
+ self._data: List[NexusSample] = [] # type: ignore[assignment]
109
+
110
+ def download(self):
111
+ r"""Download the Nexus Functional Calling Benchmark dataset."""
112
+ from huggingface_hub import snapshot_download
113
+
114
+ for dataset_name, repo_id in dataset_mapping.items():
115
+ local_dir = self.data_dir / dataset_name
116
+ snapshot_download(
117
+ repo_id=repo_id,
118
+ repo_type="dataset",
119
+ local_dir=local_dir,
120
+ local_dir_use_symlinks=True,
121
+ )
122
+
123
+ def load(self, dataset_name: str, force_download: bool = False): # type: ignore[override]
124
+ r"""Load the Nexus Benchmark dataset.
125
+
126
+ Args:
127
+ dataset_name (str): Name of the specific dataset to be loaded.
128
+ force_download (bool): Whether to force download the data.
129
+ """
130
+
131
+ def _load_csv_data(dataset_dir: Path) -> List:
132
+ r"""Load datasets from CSV files."""
133
+ dataset = []
134
+ for file_name in os.listdir(dataset_dir):
135
+ file_path = dataset_dir / file_name
136
+ if file_name.endswith(".csv"):
137
+ data = pd.read_csv(file_path)
138
+ for _, sample in data.iterrows():
139
+ dataset.append(
140
+ NexusSample(
141
+ sample["Input"], "".join(sample["Output"])
142
+ )
143
+ )
144
+ continue
145
+
146
+ logger.warning(f"Skipping unsupported file: {file_name}")
147
+ return dataset
148
+
149
+ def _load_parquet_data(data_dir: Path, dataset_name: str) -> List:
150
+ r"""Load datasets from Parquet files."""
151
+ dataset = []
152
+ if not data_dir.exists():
153
+ raise FileNotFoundError(
154
+ f"Data directory '{data_dir}' does not exist."
155
+ )
156
+
157
+ for file_name in os.listdir(data_dir):
158
+ file_path = data_dir / file_name
159
+ if file_name.endswith(".parquet"):
160
+ data = pd.read_parquet(file_path)
161
+ dataset.extend(_process_parquet_data(data, dataset_name))
162
+ continue
163
+
164
+ logger.warning(f"Skipping unsupported file: {file_name}")
165
+
166
+ return dataset
167
+
168
+ def _process_parquet_data(
169
+ data: pd.DataFrame, dataset_name: str
170
+ ) -> List:
171
+ r"""Process data from Parquet files based on dataset name."""
172
+ dataset: List = []
173
+ dataset_handlers = {
174
+ "NVDLibrary": _process_nvdlibrary,
175
+ "VirusTotal": _process_simple,
176
+ "PlacesAPI": _process_simple,
177
+ "ClimateAPI": _process_simple,
178
+ "OTX": _process_simple,
179
+ "VirusTotal-NestedCalls": _process_nested_calls,
180
+ "VirusTotal-ParallelCalls": _process_parallel_calls,
181
+ }
182
+
183
+ if dataset_name not in dataset_handlers:
184
+ logger.warning(
185
+ f"No specific handler for dataset: {dataset_name}"
186
+ )
187
+ return dataset
188
+
189
+ handler = dataset_handlers[dataset_name]
190
+ for _, sample in data.iterrows():
191
+ processed_sample = handler(sample)
192
+ if processed_sample:
193
+ dataset.append(processed_sample)
194
+ return dataset
195
+
196
+ def _process_nvdlibrary(sample) -> NexusSample:
197
+ r"""Process samples for the NVDLibrary dataset."""
198
+ return NexusSample(
199
+ sample["Input"], sample["Output"].replace("r = nvdlib.", "")
200
+ )
201
+
202
+ def _process_simple(sample) -> NexusSample:
203
+ r"""Process samples for simple datasets (e.g., VirusTotal)."""
204
+ return NexusSample(sample["Input"], sample["Output"])
205
+
206
+ def _process_nested_calls(sample) -> Union[NexusSample, None]:
207
+ r"""Process samples for VirusTotal-NestedCalls dataset."""
208
+ if len(sample["fncall"]) == 1:
209
+ return NexusSample(
210
+ sample["generated_question"], "".join(sample["fncall"])
211
+ )
212
+ return None
213
+
214
+ def _process_parallel_calls(sample) -> Union[NexusSample, None]:
215
+ r"""Process samples for VirusTotal-ParallelCalls dataset."""
216
+ if len(sample["fncall"]) > 1:
217
+ return NexusSample(
218
+ sample["generated_question"], "; ".join(sample["fncall"])
219
+ )
220
+ return None
221
+
222
+ if force_download:
223
+ logger.info("Force downloading data.")
224
+ self.download()
225
+
226
+ # Validate dataset name
227
+ if dataset_name not in dataset_mapping:
228
+ available_datasets = list(dataset_mapping.keys())
229
+ raise ValueError(
230
+ f"Dataset '{dataset_name}' is not recognized. "
231
+ f"Available datasets: {available_datasets}"
232
+ )
233
+
234
+ # Get the dataset directory
235
+ dataset_dir = self.data_dir / dataset_name
236
+ if not dataset_dir.exists():
237
+ raise FileNotFoundError(
238
+ f"The dataset directory for '{dataset_name}' \
239
+ does not exist at {dataset_dir}. "
240
+ "Please download it first."
241
+ )
242
+
243
+ # Load the dataset
244
+ if dataset_name == "NVDLibrary-NestedCalls":
245
+ self._data = _load_csv_data(dataset_dir)
246
+ else:
247
+ self._data = _load_parquet_data(dataset_dir / "data", dataset_name)
248
+
249
+ @property
250
+ def train(self):
251
+ r"""Get the training set."""
252
+ raise NotImplementedError(
253
+ "Nexus Functional Calling has only a single 'train' set."
254
+ )
255
+
256
+ def run( # type: ignore[override, return]
257
+ self,
258
+ agent: ChatAgent,
259
+ task: Literal[
260
+ "NVDLibrary",
261
+ "VirusTotal",
262
+ "OTX",
263
+ "PlacesAPI",
264
+ "ClimateAPI",
265
+ "VirusTotal-ParallelCalls",
266
+ "VirusTotal-NestedCalls",
267
+ "NVDLibrary-NestedCalls",
268
+ ],
269
+ randomize: bool = False,
270
+ subset: Optional[int] = None,
271
+ ) -> Dict[str, Any]:
272
+ r"""Run the benchmark.
273
+
274
+ Args:
275
+ agent (ChatAgent): The agent to run the benchmark.
276
+ task (Literal["NVDLibrary", "VirusTotal", "OTX",
277
+ "PlacesAPI", "ClimateAPI", "VirusTotal-ParallelCalls",
278
+ "VirusTotal-NestedCalls",
279
+ "NVDLibrary-NestedCalls"]): The task to run the benchmark.
280
+ randomize (bool, optional): Whether to randomize the data.
281
+ (default: :obj:`False`)
282
+ subset (Optional[int], optional): The subset of data to run.
283
+ (default: :obj:`None`)
284
+
285
+ Returns:
286
+ Dict[str, Any]: The results of the benchmark.
287
+ """
288
+
289
+ if task not in dataset_mapping:
290
+ raise ValueError(f"Invalid value for dataset: {task}.")
291
+
292
+ logger.info(f"Running Nexus Function Calling benchmark on {task}.")
293
+ self.load(task)
294
+ datas = self._data
295
+
296
+ # Shuffle and subset data if necessary
297
+ if randomize:
298
+ random.shuffle(datas)
299
+ if subset:
300
+ datas = datas[:subset]
301
+
302
+ logger.info(f"Number of tasks: {len(datas)}")
303
+
304
+ # Initialize results storage
305
+ self._results = []
306
+
307
+ # Process samples
308
+ tools = construct_tool_descriptions(task)
309
+ with open(self.save_to, "w") as f:
310
+ for sample in tqdm(datas, desc="Running"):
311
+ prompt = construct_prompt(input=sample.input, tools=tools)
312
+ msg = BaseMessage.make_user_message(
313
+ role_name="User", content=prompt
314
+ )
315
+ ground_truth_call = sample.output
316
+ try:
317
+ # Generate response
318
+ response = agent.step(msg)
319
+ agent_call = response.msgs[0].content
320
+
321
+ # Evaluate response
322
+ if agent_call:
323
+ result = compare_function_calls(
324
+ agent_call=agent_call,
325
+ ground_truth_call=ground_truth_call,
326
+ )
327
+ self._results.append(
328
+ {
329
+ "input": sample.input,
330
+ "agent_call": agent_call,
331
+ "ground_truth_call": ground_truth_call,
332
+ "result": result,
333
+ "error": None,
334
+ }
335
+ )
336
+ except Exception as e:
337
+ logger.warning(f"Error in processing task: {sample.input}")
338
+ self._results.append(
339
+ {
340
+ "input": sample.input,
341
+ "agent_call": None,
342
+ "ground_truth_call": ground_truth_call,
343
+ "result": 0,
344
+ "error": str(e),
345
+ }
346
+ )
347
+
348
+ agent.reset()
349
+
350
+ f.write(json.dumps(self._results[-1], indent=2) + "\n")
351
+ f.flush()
352
+
353
+ total = len(self._results)
354
+ correct = sum(r["result"] for r in self._results)
355
+
356
+ return {
357
+ "total": total,
358
+ "correct": correct,
359
+ "accuracy": correct / total,
360
+ }
361
+
362
+
363
+ # Utility functions
364
+ def construct_tool_descriptions(dataset_name: str) -> str:
365
+ r"""Construct tool descriptions from function definitions and
366
+ descriptions."""
367
+ tool_dataset_mapping = {
368
+ "NVDLibrary": "CVECPE",
369
+ "VirusTotal": "VirusTotal",
370
+ "PlacesAPI": "Places",
371
+ "ClimateAPI": "Climate",
372
+ "OTX": "OTX",
373
+ "VirusTotal-NestedCalls": "VT_Multi (Nested)",
374
+ "VirusTotal-ParallelCalls": "VT_Multi (Parallel)",
375
+ "NVDLibrary-NestedCalls": "CVECPE_Multi (Nested)",
376
+ }
377
+
378
+ if dataset_name not in tool_dataset_mapping:
379
+ raise ValueError(
380
+ f"Dataset '{dataset_name}' is not recognized. "
381
+ f"Available datasets: {list(dataset_mapping.keys())}"
382
+ )
383
+
384
+ # Load the dataset based on the dataset name
385
+ dataset = load_dataset(
386
+ "Nexusflow/Function_Call_Definitions",
387
+ name=tool_dataset_mapping[dataset_name],
388
+ )["train"]
389
+
390
+ # Construct tool descriptions
391
+ tools = [
392
+ NexusTool(tool["function_calls"], tool["descriptions"])
393
+ for tool in dataset
394
+ ]
395
+
396
+ # Generate the tool prompt
397
+ tool_prompt = "".join(
398
+ f"Function:\ndef {tool.function_calls}:\n"
399
+ + "\"\"\"\n"
400
+ + f"{tool.descriptions}\n"
401
+ + "\"\"\"\n"
402
+ for tool in tools
403
+ )
404
+
405
+ return tool_prompt
406
+
407
+
408
+ def construct_prompt(input: str, tools: str) -> str:
409
+ r"Construct prompt from tools and input."
410
+ return TOOL_CALLING_PROMPT.format(tools=tools, input=input)
411
+
412
+
413
+ # Functions for function call evaluation
414
+ def parse_function_call(
415
+ call: str,
416
+ ) -> Tuple[Optional[str], Optional[List[Any]], Optional[Dict[str, Any]]]:
417
+ r"""Parse a function call string to extract the function name,
418
+ positional arguments, and keyword arguments, including
419
+ nested function calls.
420
+
421
+ Args:
422
+ call (str): A string in the format `func(arg1, arg2, kwarg=value)`.
423
+
424
+ Returns:
425
+ tuple: (function_name (str), positional_args (list),
426
+ keyword_args (dict)) or (None, None, None).
427
+ """
428
+
429
+ def preprocess_input(call: str) -> str:
430
+ r"""Remove formatting like code blocks and whitespace."""
431
+ if call.strip().startswith("```python"):
432
+ call = call.strip().removeprefix("```python").removesuffix("```")
433
+ return textwrap.dedent(call).strip()
434
+
435
+ def evaluate_arg(arg):
436
+ r"""Recursively evaluate arguments, including nested calls."""
437
+ if isinstance(arg, ast.Call):
438
+ # Recursively parse nested calls
439
+ func_name, args, kwargs = parse_function_call(ast.unparse(arg))
440
+ return func_name, args, kwargs
441
+ elif isinstance(
442
+ arg, ast.Constant
443
+ ): # Handle literals like numbers, strings, etc.
444
+ return arg.value
445
+ elif isinstance(arg, ast.List): # Handle list literals
446
+ return [evaluate_arg(el) for el in arg.elts]
447
+ elif isinstance(arg, ast.Dict): # Handle dictionary literals
448
+ return {
449
+ evaluate_arg(k): evaluate_arg(v)
450
+ for k, v in zip(arg.keys, arg.values)
451
+ }
452
+ elif isinstance(arg, ast.Tuple): # Handle tuple literals
453
+ return tuple(evaluate_arg(el) for el in arg.elts)
454
+ else:
455
+ return ast.literal_eval(arg) # Safely evaluate other types
456
+
457
+ call = preprocess_input(call)
458
+ parsed_calls = []
459
+
460
+ try:
461
+ # Parse the string into an AST
462
+ parsed_calls = call.split(";")
463
+ for single_call in parsed_calls:
464
+ tree = ast.parse(single_call, mode='eval')
465
+
466
+ # Ensure it's a function call
467
+ if isinstance(tree.body, ast.Call):
468
+ # Extract function name
469
+ if isinstance(
470
+ tree.body.func, ast.Name
471
+ ): # Simple function call
472
+ func_name = tree.body.func.id
473
+ elif isinstance(
474
+ tree.body.func, ast.Attribute
475
+ ): # Attribute function call
476
+ func_name = (
477
+ f"{tree.body.func.value.id}.{tree.body.func.attr}" # type: ignore[attr-defined]
478
+ )
479
+ else:
480
+ raise ValueError(f"Unsupported function call: {call}")
481
+
482
+ # Extract positional arguments
483
+ args = [evaluate_arg(arg) for arg in tree.body.args]
484
+
485
+ # Extract keyword arguments
486
+ kwargs: Dict[str, Any] = {
487
+ kw.arg: evaluate_arg(kw.value)
488
+ for kw in tree.body.keywords
489
+ if kw.arg is not None
490
+ }
491
+ logger.info("Valid call.")
492
+ return func_name, args, kwargs
493
+ else:
494
+ raise ValueError(f"Not a valid function call: {call}")
495
+ except Exception as e:
496
+ logger.info(f"Error parsing call: {call}, {e}")
497
+ return None, None, None
498
+
499
+
500
+ def compare_function_calls(agent_call: str, ground_truth_call: str) -> bool:
501
+ r"""Compare the function name and arguments of
502
+ agent_call and ground_truth_call.
503
+ Args:
504
+ agent_call (str): Function call by agent.
505
+ ground_truth_call (str): Ground truth function call.
506
+
507
+ Returns:
508
+ - `True` if the function names and arguments match.
509
+ - `False` otherwise.
510
+ """
511
+ # Parse both calls
512
+ agent_parsed = parse_function_call(agent_call)
513
+ gt_parsed = parse_function_call(ground_truth_call)
514
+
515
+ if agent_parsed[0] is not None and gt_parsed[0] is not None:
516
+ return agent_parsed == gt_parsed
517
+ else:
518
+ return False
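A quick sketch of the parsing and comparison helpers on hand-written calls follows; the function names and arguments are invented and do not come from the Nexus datasets.

# Hedged sketch: illustrative calls only.
from camel.benchmarks.nexus import compare_function_calls, parse_function_call

name, args, kwargs = parse_function_call("searchCVE(keyword='openssl', limit=5)")
# -> ("searchCVE", [], {"keyword": "openssl", "limit": 5})

compare_function_calls(
    agent_call="lookup_ip(ip='8.8.8.8')",
    ground_truth_call="lookup_ip(ip='8.8.8.8')",
)  # True: same function name, args, and kwargs

compare_function_calls(
    agent_call="lookup_ip(ip='8.8.8.8')",
    ground_truth_call="lookup_ip(ip='1.1.1.1')",
)  # False: keyword arguments differ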
Paper2Poster/camel/benchmarks/ragbench.py ADDED
@@ -0,0 +1,333 @@
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+
15
+ from typing import Any, Callable, Dict, List, Literal, Optional, Sequence
16
+
17
+ import numpy as np
18
+ from datasets import Dataset, load_dataset
19
+
20
+ from camel.agents import ChatAgent
21
+ from camel.benchmarks import BaseBenchmark
22
+ from camel.logger import get_logger
23
+ from camel.retrievers import AutoRetriever
24
+
25
+ logger = get_logger(__name__)
26
+
27
+
28
+ class RagasFields:
29
+ r"""Constants for RAGAS evaluation field names."""
30
+
31
+ INPUT_CONTEXT = "contexts"
32
+ INPUT_QUESTION = "question"
33
+ INPUT_ANSWER = "answer"
34
+
35
+
36
+ def annotate_dataset(
37
+ dataset: Dataset,
38
+ context_call: Optional[Callable[[Dict[str, Any]], List[str]]],
39
+ answer_call: Optional[Callable[[Dict[str, Any]], str]],
40
+ ) -> Dataset:
41
+ r"""Annotate the dataset by adding context and answers using the provided
42
+ functions.
43
+
44
+ Args:
45
+ dataset (Dataset): The input dataset to annotate.
46
+ context_call (Optional[Callable[[Dict[str, Any]], List[str]]]):
47
+ Function to generate context for each example.
48
+ answer_call (Optional[Callable[[Dict[str, Any]], str]]): Function to
49
+ generate an answer for each example.
50
+
51
+ Returns:
52
+ Dataset: The annotated dataset with added contexts and/or answers.
53
+ """
54
+
55
+ def process_example(example: Dict[str, Any]) -> Dict[str, Any]:
56
+ if context_call:
57
+ example["contexts"] = context_call(example)
58
+ if answer_call:
59
+ example["answer"] = answer_call(example)
60
+ return example
61
+
62
+ return dataset.map(process_example)
63
+
64
+
65
+ def rmse(
66
+ input_trues: Sequence[float],
67
+ input_preds: Sequence[float],
68
+ ) -> Optional[float]:
69
+ r"""Calculate Root Mean Squared Error (RMSE).
70
+
71
+ Args:
72
+ input_trues (Sequence[float]): Ground truth values.
73
+ input_preds (Sequence[float]): Predicted values.
74
+
75
+ Returns:
76
+ Optional[float]: RMSE value, or None if inputs have different lengths.
77
+ """
78
+ if len(input_trues) != len(input_preds):
79
+ logger.warning("Input lengths mismatch in RMSE calculation")
80
+ return None
81
+
82
+ trues = np.array(input_trues)
83
+ preds = np.array(input_preds, dtype=float)
84
+
85
+ # Ignore NaN values in predictions
86
+ eval_idx = ~np.isnan(preds)
87
+ if not np.any(eval_idx):
88
+ logger.warning("No valid predictions for RMSE calculation")
89
+ return None
90
+
91
+ trues = trues[eval_idx]
92
+ preds = preds[eval_idx]
93
+
94
+ return float(np.sqrt(np.mean((preds - trues) ** 2)))
95
+
96
+
97
+ def auroc(trues: Sequence[bool], preds: Sequence[float]) -> float:
98
+ r"""Calculate Area Under Receiver Operating Characteristic Curve (AUROC).
99
+
100
+ Args:
101
+ trues (Sequence[bool]): Ground truth binary values.
102
+ preds (Sequence[float]): Predicted probability values.
103
+
104
+ Returns:
105
+ float: AUROC score.
106
+ """
107
+ from sklearn.metrics import roc_auc_score # type: ignore[import-untyped]
108
+
109
+ eval_idx = ~np.isnan(preds)
110
+ if not np.any(eval_idx):
111
+ logger.warning("No valid predictions for AUROC calculation")
112
+ return 0.5 # Return random classifier score
113
+
114
+ return float(
115
+ roc_auc_score(np.array(trues)[eval_idx], np.array(preds)[eval_idx])
116
+ )
117
+
118
+
119
+ def ragas_calculate_metrics(
120
+ dataset: Dataset,
121
+ pred_context_relevance_field: Optional[str],
122
+ pred_faithfulness_field: Optional[str],
123
+ metrics_to_evaluate: Optional[List[str]] = None,
124
+ ground_truth_context_relevance_field: str = "relevance_score",
125
+ ground_truth_faithfulness_field: str = "adherence_score",
126
+ ) -> Dict[str, Optional[float]]:
127
+ r"""Calculate RAGAS evaluation metrics.
128
+
129
+ Args:
130
+ dataset (Dataset): The dataset containing predictions and ground truth.
131
+ pred_context_relevance_field (Optional[str]): Field name for predicted
132
+ context relevance.
133
+ pred_faithfulness_field (Optional[str]): Field name for predicted
134
+ faithfulness.
135
+ metrics_to_evaluate (Optional[List[str]]): List of metrics to evaluate.
136
+ ground_truth_context_relevance_field (str): Field name for ground truth
137
+ relevance.
138
+ ground_truth_faithfulness_field (str): Field name for ground truth
139
+ adherence.
140
+
141
+ Returns:
142
+ Dict[str, Optional[float]]: Dictionary of calculated metrics.
143
+ """
144
+ metrics_to_evaluate = metrics_to_evaluate or [
145
+ "context_relevancy",
146
+ "faithfulness",
147
+ ]
148
+ calculated_metrics: Dict[str, Optional[float]] = {}
149
+
150
+ if (
151
+ "context_relevancy" in metrics_to_evaluate
152
+ and pred_context_relevance_field
153
+ ):
154
+ trues_relevance = dataset[ground_truth_context_relevance_field]
155
+ preds_relevance = dataset[pred_context_relevance_field]
156
+ calculated_metrics["relevance_rmse"] = rmse(
157
+ trues_relevance, preds_relevance
158
+ )
159
+
160
+ if "faithfulness" in metrics_to_evaluate and pred_faithfulness_field:
161
+ trues_hallucination = ~np.array(
162
+ dataset[ground_truth_faithfulness_field]
163
+ )
164
+ preds_hallucination = 1 - np.array(
165
+ dataset[pred_faithfulness_field], dtype=float
166
+ )
167
+ calculated_metrics["hallucination_auroc"] = auroc(
168
+ trues_hallucination.tolist(), preds_hallucination.tolist()
169
+ )
170
+
171
+ return calculated_metrics
172
+
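+ # Worked example (illustrative values, not from the RAGBench data): with
+ # ground-truth adherence [True, False] and predicted faithfulness
+ # [0.9, 0.2], the hallucination labels become [False, True] and the
+ # hallucination scores become [0.1, 0.8]; the hallucinated answer gets
+ # the higher score, so auroc() returns 1.0 for this pair.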
173
+
174
+ def ragas_evaluate_dataset(
175
+ dataset: Dataset,
176
+ contexts_field_name: Optional[str],
177
+ answer_field_name: Optional[str],
178
+ metrics_to_evaluate: Optional[List[str]] = None,
179
+ ) -> Dataset:
180
+ r"""Evaluate the dataset using RAGAS metrics.
181
+
182
+ Args:
183
+ dataset (Dataset): Input dataset to evaluate.
184
+ contexts_field_name (Optional[str]): Field name containing contexts.
185
+ answer_field_name (Optional[str]): Field name containing answers.
186
+ metrics_to_evaluate (Optional[List[str]]): List of metrics to evaluate.
187
+
188
+ Returns:
189
+ Dataset: Dataset with added evaluation metrics.
190
+ """
191
+ from ragas import evaluate
192
+ from ragas.metrics import ( # type: ignore[import-untyped]
193
+ context_relevancy,
194
+ faithfulness,
195
+ )
196
+
197
+ metrics_to_evaluate = metrics_to_evaluate or [
198
+ "context_relevancy",
199
+ "faithfulness",
200
+ ]
201
+
202
+ # Rename fields if necessary
203
+ if (
204
+ contexts_field_name
205
+ and contexts_field_name != RagasFields.INPUT_CONTEXT
206
+ ):
207
+ dataset = dataset.rename_column(
208
+ contexts_field_name, RagasFields.INPUT_CONTEXT
209
+ )
210
+ if answer_field_name and answer_field_name != RagasFields.INPUT_ANSWER:
211
+ dataset = dataset.rename_column(
212
+ answer_field_name, RagasFields.INPUT_ANSWER
213
+ )
214
+
215
+ metrics = []
216
+ if "context_relevancy" in metrics_to_evaluate:
217
+ metrics.append(context_relevancy)
218
+ if "faithfulness" in metrics_to_evaluate:
219
+ metrics.append(faithfulness)
220
+
221
+ ragas_result = evaluate(dataset, metrics=metrics)
222
+ return Dataset.from_pandas(ragas_result.to_pandas())
223
+
224
+
225
+ class RAGBenchBenchmark(BaseBenchmark):
226
+ r"""RAGBench Benchmark for evaluating RAG performance.
227
+
228
+ This benchmark uses the rungalileo/ragbench dataset to evaluate
229
+ retrieval-augmented generation (RAG) systems. It measures context
230
+ relevancy and faithfulness metrics as described in
231
+ https://arxiv.org/abs/2407.11005.
232
+
233
+ Args:
234
+ processes (int, optional): Number of processes for parallel processing.
235
+ subset (str, optional): Dataset subset to use (e.g., "hotpotqa").
236
+ split (str, optional): Dataset split to use (e.g., "test").
237
+ """
238
+
239
+ def __init__(
240
+ self,
241
+ processes: int = 1,
242
+ subset: Literal[
243
+ "covidqa",
244
+ "cuad",
245
+ "delucionqa",
246
+ "emanual",
247
+ "expertqa",
248
+ "finqa",
249
+ "hagrid",
250
+ "hotpotqa",
251
+ "msmarco",
252
+ "pubmedqa",
253
+ "tatqa",
254
+ "techqa",
255
+ ] = "hotpotqa",
256
+ split: Literal["train", "test", "validation"] = "test",
257
+ ) -> None:
258
+ super().__init__("ragbench", "rag_bench", "", processes)
259
+ self.subset = subset
260
+ self.split = split
261
+ self.dataset: Optional[Dataset] = None
262
+
263
+ def download(self):
264
+ r"""Download the RAGBench dataset."""
265
+ try:
266
+ self.dataset = load_dataset(
267
+ "rungalileo/ragbench", self.subset, split=self.split
268
+ )
269
+ except Exception as e:
270
+ logger.error(f"Failed to download dataset: {e}")
271
+ raise
272
+
273
+ def load(self, force_download: bool = False):
274
+ r"""Load the RAGBench dataset.
275
+
276
+ Args:
277
+ force_download (bool, optional): Whether to force download the
278
+ data.
279
+ """
280
+ if force_download or self.dataset is None:
281
+ logger.info(
282
+ "%s dataset",
283
+ "Force downloading" if force_download else "Loading",
284
+ )
285
+ self.download()
286
+
287
+ def run( # type: ignore[override, return]
288
+ self,
289
+ agent: ChatAgent,
290
+ auto_retriever: AutoRetriever,
291
+ ) -> Dict[str, Optional[float]]:
292
+ r"""Run the benchmark evaluation.
293
+
294
+ Args:
295
+ agent (ChatAgent): Chat agent for generating answers.
296
+ auto_retriever (AutoRetriever): Retriever for finding relevant
297
+ contexts.
298
+
299
+ Returns:
300
+ Dict[str, Optional[float]]: Dictionary of evaluation metrics.
301
+ """
302
+
+ # Lazily load the dataset so run() also works when load() has not been
+ # called first; annotate_dataset() would otherwise receive None.
+ if self.dataset is None:
+ self.load()
+
303
+ def context_call(example):
304
+ retrieved_info = auto_retriever.run_vector_retriever(
305
+ query=example['question'],
306
+ contents=example['documents'],
307
+ top_k=1,
308
+ return_detailed_info=True,
309
+ similarity_threshold=0.5,
310
+ )
311
+ return [c['text'] for c in retrieved_info['Retrieved Context']]
312
+
313
+ def answer_call(example: Dict[str, Any]) -> str:
314
+ user_msg = str(example)
315
+ assistant_response = agent.step(user_msg)
316
+ return assistant_response.msg.content
317
+
318
+ # Annotate the dataset
319
+ annotated_ds = annotate_dataset(
320
+ self.dataset, context_call, answer_call
321
+ )
322
+ evaluated_ds = ragas_evaluate_dataset(
323
+ annotated_ds,
324
+ contexts_field_name="contexts",
325
+ answer_field_name="answer",
326
+ metrics_to_evaluate=["context_relevancy", "faithfulness"],
327
+ )
328
+
329
+ return ragas_calculate_metrics(
330
+ evaluated_ds,
331
+ pred_context_relevance_field="context_relevancy",
332
+ pred_faithfulness_field="faithfulness",
333
+ )
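
A minimal usage sketch for the benchmark added above (the agent and retriever construction are illustrative assumptions; any ChatAgent and AutoRetriever configuration supported by the codebase should work):

    from camel.agents import ChatAgent
    from camel.benchmarks.ragbench import RAGBenchBenchmark
    from camel.retrievers import AutoRetriever

    # Illustrative setup: the constructor arguments below are assumptions,
    # not prescribed by ragbench.py.
    agent = ChatAgent("You answer questions using the provided documents.")
    retriever = AutoRetriever()

    benchmark = RAGBenchBenchmark(subset="hotpotqa", split="test")
    benchmark.load()  # downloads rungalileo/ragbench via load_dataset
    metrics = benchmark.run(agent, retriever)
    print(metrics)  # e.g. {'relevance_rmse': ..., 'hallucination_auroc': ...}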