Mqleet commited on
Commit
fcaa164
·
1 Parent(s): 0aba08a
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +16 -0
  2. LICENSE +21 -0
  3. ProjectPageAgent/__init__.py +7 -0
  4. ProjectPageAgent/content_planner.py +509 -0
  5. ProjectPageAgent/css_checker.py +111 -0
  6. ProjectPageAgent/html_finder.py +32 -0
  7. ProjectPageAgent/html_generator.py +633 -0
  8. ProjectPageAgent/main_pipline.py +379 -0
  9. ProjectPageAgent/parse_paper.py +88 -0
  10. ProjectPageAgent/parse_raw.py +256 -0
  11. ProjectPageAgent/template_analyzer.py +436 -0
  12. app.py +1671 -0
  13. camel/__init__.py +25 -0
  14. camel/agents/__init__.py +44 -0
  15. camel/agents/base.py +29 -0
  16. camel/agents/chat_agent.py +1539 -0
  17. camel/agents/critic_agent.py +202 -0
  18. camel/agents/deductive_reasoner_agent.py +303 -0
  19. camel/agents/embodied_agent.py +201 -0
  20. camel/agents/knowledge_graph_agent.py +259 -0
  21. camel/agents/multi_hop_generator_agent.py +117 -0
  22. camel/agents/programmed_agent_instruction.py +203 -0
  23. camel/agents/role_assignment_agent.py +141 -0
  24. camel/agents/search_agent.py +133 -0
  25. camel/agents/task_agent.py +410 -0
  26. camel/agents/tool_agents/__init__.py +20 -0
  27. camel/agents/tool_agents/base.py +39 -0
  28. camel/agents/tool_agents/hugging_face_tool_agent.py +206 -0
  29. camel/benchmarks/__init__.py +30 -0
  30. camel/benchmarks/apibank.py +565 -0
  31. camel/benchmarks/apibench.py +500 -0
  32. camel/benchmarks/base.py +152 -0
  33. camel/benchmarks/gaia.py +478 -0
  34. camel/benchmarks/nexus.py +518 -0
  35. camel/benchmarks/ragbench.py +333 -0
  36. camel/bots/__init__.py +34 -0
  37. camel/bots/discord/__init__.py +26 -0
  38. camel/bots/discord/discord_app.py +384 -0
  39. camel/bots/discord/discord_installation.py +64 -0
  40. camel/bots/discord/discord_store.py +160 -0
  41. camel/bots/slack/__init__.py +30 -0
  42. camel/bots/slack/models.py +158 -0
  43. camel/bots/slack/slack_app.py +255 -0
  44. camel/bots/telegram_bot.py +82 -0
  45. camel/configs/__init__.py +85 -0
  46. camel/configs/anthropic_config.py +71 -0
  47. camel/configs/base_config.py +89 -0
  48. camel/configs/cohere_config.py +76 -0
  49. camel/configs/deepseek_config.py +134 -0
  50. camel/configs/gemini_config.py +114 -0
.gitignore ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ templates/**/*.wav
2
+ templates/**/*.mp4
3
+ templates/**/*.gif
4
+ templates/**/*.webm
5
+ templates/**/*.mov
6
+ templates/**/*.ttf
7
+ templates/**/*.pdf
8
+ templates/**/*?
9
+ *.woff
10
+ *.woff2
11
+ *.png
12
+ *.jpg
13
+
14
+ .DS_Store
15
+
16
+ **/__pycache__/*
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Qianli Ma
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
ProjectPageAgent/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ """
2
+ ProjectPageAgent: A multi-agent system for generating project pages from research papers.
3
+ Based on Paper2Poster architecture, adapted for project page generation.
4
+ """
5
+
6
+ __version__ = "1.0.0"
7
+ __author__ = "Paper2ProjectPage Team"
ProjectPageAgent/content_planner.py ADDED
@@ -0,0 +1,509 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Content planner for project page generation.
3
+ Plans the structure and content organization for the project page.
4
+ """
5
+
6
+ import json
7
+ import yaml
8
+ import os
9
+ from jinja2 import Environment, StrictUndefined
10
+ from camel.models import ModelFactory
11
+ from camel.agents import ChatAgent
12
+ from utils.wei_utils import account_token
13
+ from utils.src.utils import get_json_from_response
14
+ from camel.messages import BaseMessage
15
+ from rich import print
16
+ from rich.pretty import Pretty
17
+ import base64
18
+ from camel.messages import BaseMessage
19
+ from camel.models import ModelFactory
20
+
21
def filter_references(md_content: str) -> str:
    """Truncate a markdown document just before its "## References" heading.

    The heading match is case-insensitive and ignores surrounding whitespace.
    If no references heading is present, the content is returned unchanged.
    """
    all_lines = md_content.splitlines()
    cut_at = len(all_lines)
    for idx, line in enumerate(all_lines):
        if line.strip().lower().startswith("## references"):
            cut_at = idx
            break
    return "\n".join(all_lines[:cut_at])
30
+
31
class ProjectPageContentPlanner:
    """Plans the content structure and organization for project pages.

    Wraps two LLM agents — a *planner* that drafts page content and a
    *reviewer* that critiques it — and drives a generate -> review -> revise
    loop, persisting every intermediate artifact under ``project_contents/``.
    """

    def __init__(self, agent_config, args):
        # agent_config: dict with model_platform / model_type / model_config
        # (plus optional 'url' for vLLM endpoints and 'token_limit').
        # args: CLI namespace; this class reads model_name_t, paper_name,
        # human_input, resume and full_content_check_times from it.
        self.agent_config = agent_config
        self.args = args
        self.planner_agent = self._create_planner_agent()
        self.reviewer_agent = self._create_reviewer_agent()
        os.makedirs('project_contents', exist_ok=True)

    # ------------------------------------------------------------------
    # Agent construction (shared by planner and reviewer — previously this
    # logic was duplicated verbatim in both _create_*_agent methods)
    # ------------------------------------------------------------------

    def _resolve_api_key(self):
        """Return the provider API key matching ``args.model_name_t``.

        Returns None when the model name maps to no known provider (e.g. a
        local vLLM deployment that needs no key).
        """
        name = self.args.model_name_t
        if name in ['4o', '4o-mini', 'gpt-4.1', 'gpt-4.1-mini', 'o1', 'o3', 'o3-mini']:
            return os.environ.get('OPENAI_API_KEY')
        if name in ['gemini', 'gemini-2.5-pro', 'gemini-2.5-flash']:
            return os.environ.get('GEMINI_API_KEY')
        if name in ['qwen', 'qwen-plus', 'qwen-max', 'qwen-long']:
            return os.environ.get('QWEN_API_KEY')
        if name.startswith('openrouter_'):
            return os.environ.get('OPENROUTER_API_KEY')
        if name in ['zhipuai']:
            return os.environ.get('ZHIPUAI_API_KEY')
        return None

    def _create_model(self):
        """Instantiate the backing LLM; vLLM-style models also receive a url."""
        model_type = str(self.agent_config['model_type'])
        create_kwargs = dict(
            model_platform=self.agent_config['model_platform'],
            model_type=self.agent_config['model_type'],
            model_config_dict=self.agent_config['model_config'],
            api_key=self._resolve_api_key(),
        )
        if model_type.startswith('vllm_qwen') or 'vllm' in model_type.lower():
            create_kwargs['url'] = self.agent_config.get('url', None)
        return ModelFactory.create(**create_kwargs)

    def _make_agent(self, system_message):
        """Build a ChatAgent with the shared window/token configuration."""
        return ChatAgent(
            system_message=system_message,
            model=self._create_model(),
            message_window_size=10,
            token_limit=self.agent_config.get('token_limit', None),
        )

    def _create_planner_agent(self):
        """Create the content planning (generation) agent."""
        system_message = """You are a helpful academic expert and web developer, who is specialized in generating a paper project page, from given research paper's contents and figures."""
        return self._make_agent(system_message)

    def _create_reviewer_agent(self):
        """Create the agent that reviews generated page content."""
        reviewer_system = (
            "You are a precise, constructive reviewer of generated project pages. "
        )
        return self._make_agent(reviewer_system)

    # ------------------------------------------------------------------
    # Prompt building
    # ------------------------------------------------------------------

    def _load_prompt_template(self, path):
        """Load a YAML prompt config and compile its 'template' field with Jinja."""
        with open(path, 'r') as f:
            planner_config = yaml.safe_load(f)
        jinja_env = Environment(undefined=StrictUndefined)
        return jinja_env.from_string(planner_config["template"])

    def _render_generation_prompt(self, paper_content, figures, text_page_content, template_str):
        """Render a generation prompt from an in-memory Jinja template string."""
        jinja_env = Environment(undefined=StrictUndefined)
        template = jinja_env.from_string(template_str)
        return template.render(
            paper_content=paper_content,
            figures=json.dumps(figures, indent=2),
            project_page_content=json.dumps(text_page_content, indent=2),
        )

    def _build_reviewer_prompt(self, paper_content, figures, text_page_content, generated_json):
        """Build the review prompt over previously generated content.

        NOTE(review): text_page_content is accepted for interface parity but
        is not referenced by the review template — kept for compatibility.
        """
        template = self._load_prompt_template(
            'utils/prompt_templates/page_templates/full_content_review.yaml')
        return template.render(
            paper_content=paper_content,
            figures=json.dumps(figures['images'], indent=2),
            tables=json.dumps(figures['tables'], indent=2),
            generated_content=generated_json,
        )

    def _build_revision_prompt(self, review_json):
        """Build a revision prompt carrying only the review feedback."""
        template = self._load_prompt_template(
            'utils/prompt_templates/page_templates/full_content_revise.yaml')
        return template.render(review_content=json.dumps(review_json, indent=2))

    def _build_revision_prompt_with_resume(self, review_json, current_content, figures):
        """Build a revision prompt that also re-supplies the current content."""
        template = self._load_prompt_template(
            'utils/prompt_templates/page_templates/full_content_revise_with_resume.yaml')
        print(review_json)
        return template.render(
            review_content=json.dumps(review_json, indent=2),
            figures=json.dumps(figures, indent=2),
            current_content=current_content,
        )

    # ------------------------------------------------------------------
    # Interactive refinement (previously duplicated in two methods)
    # ------------------------------------------------------------------

    def _human_feedback_loop(self, current_output, prompt_text):
        """Interactively refine *current_output* until the user enters 'yes'.

        Returns (final_output, extra_input_tokens, extra_output_tokens).
        """
        extra_in, extra_out = 0, 0
        print('-' * 50)
        print(Pretty(current_output, expand_all=True))
        print('-' * 50)
        user_feedback = input(prompt_text)
        while user_feedback.lower() != 'yes':
            message = BaseMessage.make_assistant_message(
                role_name='User',
                content='human feedback' + user_feedback + "The above is human feedback. Please make modifications based on this feedback and the original content.The output format is as specified above."
            )
            response = self.planner_agent.step(message)
            current_output = get_json_from_response(response.msgs[0].content)
            print('-' * 50)
            print(Pretty(current_output, expand_all=True))
            print('-' * 50)
            user_feedback = input(prompt_text)
            in_tok, out_tok = account_token(response)
            extra_in += in_tok
            extra_out += out_tok
        return current_output, extra_in, extra_out

    # ------------------------------------------------------------------
    # Pipeline stages
    # ------------------------------------------------------------------

    def full_content_generation(
        self,
        args,
        paper_content,
        figures,
        generated_section,
        text_page_content,
    ):
        """
        Plan + Generate -> Review -> Revise

        Args:
            paper_content: parsed paper content
            figures: list/dict of figures
            generated_section: format_instructions / schema hints
            text_page_content: initial text-only page structure

        Returns:
            tuple: (final_generated_content_json, input_token_total, output_token_total)
        """
        if args.resume in ['parse_pdf', 'generate_content']:
            print("full content generation start")
            template = self._load_prompt_template(
                'utils/prompt_templates/page_templates/full_content_generation.yaml')
            prompt = template.render(
                paper_content=paper_content,
                figures=json.dumps(figures, indent=2),
                project_page_content=json.dumps(text_page_content, indent=2),
            )

            self.planner_agent.reset()
            response = self.planner_agent.step(prompt)
            gen_in_tok, gen_out_tok = account_token(response)
            current_output = get_json_from_response(response.msgs[0].content)

            first_path = f'project_contents/{self.args.paper_name}_generated_full_content.v0.json'
            with open(first_path, 'w', encoding='utf-8') as f:
                json.dump(current_output, f, ensure_ascii=False, indent=2)
            print(f" - Initial generation saved: {first_path}")

            total_in_tok, total_out_tok = gen_in_tok, gen_out_tok
        else:
            # Resuming: reuse the previously saved v0 content.
            print("Skipping initial full content generation, loading existing content.")
            with open(f'project_contents/{self.args.paper_name}_generated_full_content.v0.json', 'r', encoding='utf-8') as f:
                current_output = json.load(f)
            total_in_tok, total_out_tok = 0, 0

        for it in range(args.full_content_check_times):
            # Review pass: the reviewer critiques the current draft.
            self.reviewer_agent.reset()
            review_prompt = self._build_reviewer_prompt(
                paper_content=paper_content,
                figures=figures,
                text_page_content=text_page_content,
                generated_json=current_output,
            )
            review_resp = self.reviewer_agent.step(review_prompt)
            rin, rout = account_token(review_resp)
            review_json = get_json_from_response(review_resp.msgs[0].content)

            review_path = f'project_contents/{self.args.paper_name}_review.iter{it}.json'
            with open(review_path, 'w', encoding='utf-8') as f:
                json.dump(review_json, f, ensure_ascii=False, indent=2)
            print(f" - Review saved: {review_path}")

            total_in_tok += rin
            total_out_tok += rout

            # Revision pass: when resuming at the check stage the planner has
            # no conversation history, so the current content must be resent.
            if args.resume != 'full_content_check':
                revision_prompt = self._build_revision_prompt(review_json=review_json)
            else:
                revision_prompt = self._build_revision_prompt_with_resume(
                    review_json=review_json,
                    current_content=current_output,
                    figures=figures,
                )
            rev_resp = self.planner_agent.step(revision_prompt)
            rin2, rout2 = account_token(rev_resp)
            revised_output = get_json_from_response(rev_resp.msgs[0].content)

            out_path = f'project_contents/{self.args.paper_name}_generated_full_content.v{it+1}.json'
            with open(out_path, 'w', encoding='utf-8') as f:
                json.dump(revised_output, f, ensure_ascii=False, indent=2)
            print(f" - Revised generation saved: {out_path}")

            total_in_tok += rin2
            total_out_tok += rout2
            current_output = revised_output

        if self.args.human_input == '1':
            current_output, extra_in, extra_out = self._human_feedback_loop(
                current_output,
                'The above is the final generated full content! If you are satisfied with the generated content, enter yes. \n If not, enter your feedback.\n',
            )
            total_in_tok += extra_in
            total_out_tok += extra_out

        # Final save (keeps the original file naming scheme).
        final_path = f'project_contents/{self.args.paper_name}_generated_full_content.json'
        with open(final_path, 'w', encoding='utf-8') as f:
            json.dump(current_output, f, ensure_ascii=False, indent=2)
        print(f"full content generation completed. Tokens: {total_in_tok} -> {total_out_tok}")
        print(f" - Final content: {final_path}")

        return current_output, total_in_tok, total_out_tok

    def section_generation(self, paper_content, figures):
        """
        Plan the section structure for the project page.

        Args:
            paper_content: Parsed paper content
            figures: unused here; accepted for interface parity

        Returns:
            tuple: (section_dict, input_tokens, output_tokens)
        """
        template = self._load_prompt_template(
            'utils/prompt_templates/page_templates/section_generation.yaml')

        # Example output shape shown to the model.  Doubled braces keep the
        # braces literal if the template pipeline ever str.format()s this.
        json_format_example = """
    ```json
    {{
        "Introduction": "Brief overview of the paper's main topic and objectives.",
        "Methodology": "Description of the methods used in the research.",
        "Results": "Summary of the key findings and results."
    }}
    ```
    """

        # BUG FIX: this previously passed json.dumps(paper_content) as the
        # format example while the json_format_example variable above was
        # defined and never used; the template now receives the example.
        prompt = template.render(
            paper_content=paper_content,
            json_format_example=json_format_example,
        )

        self.planner_agent.reset()
        response = self.planner_agent.step(prompt)
        input_token, output_token = account_token(response)
        generated_section = get_json_from_response(response.msgs[0].content)

        if self.args.human_input == '1':
            generated_section, extra_in, extra_out = self._human_feedback_loop(
                generated_section,
                'The above is the generated section! If you are satisfied with the generated section, enter yes. \nIf not, enter your feedback.\n',
            )
            input_token += extra_in
            output_token += extra_out

        print(f"section planning completed. Tokens: {input_token} -> {output_token}")

        def create_dynamic_page_dict(sections):
            # Fixed header fields always come first; generated sections follow.
            page_dict = {
                "title": "Title of the paper",
                "authors": "Authors of the paper, Each author must be accompanied by the superscript number(s) of their corresponding affiliation(s).",
                "affiliation": "Affiliation of the authors, each affiliation must be accompanied by the corresponding superscript number.",
            }
            page_dict.update(sections)
            return page_dict

        generated_section = create_dynamic_page_dict(generated_section)

        generated_path = f'project_contents/{self.args.paper_name}_generated_section.json'
        with open(generated_path, 'w') as f:
            json.dump(generated_section, f, indent=4)
        print(f" - Generated section plan: {generated_path}")

        return generated_section, input_token, output_token

    def text_content_generation(self, paper_content, figures, generated_section):
        """
        Generate the text content for every planned section.

        Args:
            paper_content: Parsed paper content
            figures: dict with 'images' and 'tables' sub-dicts
            generated_section: section plan used as format instructions

        Returns:
            tuple: (text_content_dict, input_tokens, output_tokens)
        """
        # Strip the internal 'tag' field before showing figures to the model.
        figures_ = {
            'images': [{k: v for k, v in value.items() if k != 'tag'}
                       for value in figures['images'].values()],
            'tables': [{k: v for k, v in value.items() if k != 'tag'}
                       for value in figures['tables'].values()],
        }

        template = self._load_prompt_template(
            'utils/prompt_templates/page_templates/text_content_generation.yaml')
        prompt = template.render(
            paper_content=paper_content,
            figures=json.dumps(figures_, indent=2),
            format_instructions=json.dumps(generated_section, indent=2),
        )

        self.planner_agent.reset()
        response = self.planner_agent.step(prompt)
        input_token, output_token = account_token(response)
        generated_text_content = get_json_from_response(response.msgs[0].content)

        print(f"text content generation completed. Tokens: {input_token} -> {output_token}")

        generated_path = f'project_contents/{self.args.paper_name}_generated_text_content.json'
        with open(generated_path, 'w') as f:
            json.dump(generated_text_content, f, indent=4)
        print(f" - Generated text content: {generated_path}")

        return generated_text_content, input_token, output_token

    def filter_raw_content(self, paper_content, figures):
        """Drop the references section and figures the model deems irrelevant.

        Returns (filtered_paper_content, filtered_figures, in_tokens, out_tokens).
        """
        paper_content = filter_references(paper_content)

        template = self._load_prompt_template(
            'utils/prompt_templates/page_templates/filter_figures.yaml')
        prompt = template.render(
            paper_content=paper_content,
            figures=json.dumps(figures, indent=2),
        )

        self.planner_agent.reset()
        response = self.planner_agent.step(prompt)
        input_token, output_token = account_token(response)
        filtered_figures = get_json_from_response(response.msgs[0].content)

        def remove_items_without_section(data):
            # Keep only figures/tables the model tied to a source section.
            for key in ["images", "tables"]:
                if key in data and isinstance(data[key], dict):
                    data[key] = {
                        k: v for k, v in data[key].items()
                        if v.get("original_section") is not None
                    }
            return data

        filtered_figures = remove_items_without_section(filtered_figures)

        print(f"filtered figures generation completed. Tokens: {input_token} -> {output_token}")

        generated_path = f'project_contents/{self.args.paper_name}_generated_filtered_figures.json'
        with open(generated_path, 'w') as f:
            json.dump(filtered_figures, f, indent=4)
        print(f" - Generated filtered figures: {generated_path}")

        return paper_content, filtered_figures, input_token, output_token
507
+
508
+
509
+
ProjectPageAgent/css_checker.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ from collections import OrderedDict
3
+ from ProjectPageAgent.html_finder import HtmlFinder
4
+ import os
5
+
6
+
7
+
8
# Matches a <link ...> tag whose href points at a .css file.  The href value
# may be double-quoted, single-quoted, or bare, optionally followed by a
# ?query suffix.  Exactly one of the three capture groups is non-empty per
# match (see _first_nonempty).  Flags: i=case-insensitive, s=dotall,
# x=verbose (whitespace in the pattern is ignored).
_LINK_CSS_RE = re.compile(
    r'''(?isx)
    <link[^>]*?
    href\s*=\s*
    (?:
        "([^"]+?\.css(?:\?[^"]*)?)" |
        '([^']+?\.css(?:\?[^']*)?)' |
        ([^\s"'=<>`]+?\.css(?:\?[^\s"'=<>`]*)?)
    )
    [^>]*?>
    '''
)


# Matches a CSS @import rule referencing a .css file, with or without a
# url(...) wrapper; again the path may be double-quoted, single-quoted, or
# bare, with an optional ?query suffix, and exactly one group is populated.
_IMPORT_CSS_RE = re.compile(
    r'''(?isx)
    @import
    \s+(?:url\()?
    \s*
    (?:
        "([^"]+?\.css(?:\?[^"]*)?)" |
        '([^']+?\.css(?:\?[^']*)?)' |
        ([^'")\s;]+?\.css(?:\?[^'")\s;]+)?)
    )
    \s*
    \)?
    '''
)
36
+
37
+
38
+ def _first_nonempty(groups_list):
39
+ out = []
40
+ for groups in groups_list:
41
+ for g in groups:
42
+ if g:
43
+ out.append(g)
44
+ break
45
+ return out
46
+
47
def extract_css_paths(html: str):
    """Collect unique .css paths from <link> tags and @import rules.

    Link-tag hrefs come first, then @import paths; first-seen order is
    preserved and duplicates (after stripping whitespace) are dropped.
    """
    candidates = _first_nonempty(_LINK_CSS_RE.findall(html))
    candidates += _first_nonempty(_IMPORT_CSS_RE.findall(html))

    ordered = OrderedDict()
    for raw in candidates:
        path = raw.strip()
        if path and path not in ordered:
            ordered[path] = True
    return list(ordered.keys())
57
+
58
def check_css(generated_html: str, template_html: str, page_name=None):
    """Rewrite local CSS paths in *generated_html* to match the template.

    Local (non-http) stylesheet references that do not exactly match a
    template path are re-mapped by basename to the template's path; remote
    stylesheets are left untouched.

    Args:
        generated_html: HTML produced by the generator.
        template_html: the reference template HTML.
        page_name: optional label used only in warning output.  Added for
            backward-compatible support of call sites (e.g. the module's
            __main__ driver) that pass a third argument.

    Returns:
        The corrected HTML string.
    """
    generated_css = extract_css_paths(generated_html)
    template_css = extract_css_paths(template_html)
    print(f'num of css in generated page: {len(generated_css)}')
    print(f'num of css in template page: {len(template_css)}')
    # Map each template CSS basename back to its full template path.
    template_css_name = {css.strip().split('/')[-1]: css for css in template_css}

    errors = {}
    for css in generated_css:
        if css.startswith('http'):
            continue  # remote stylesheet: nothing to re-map
        if css not in template_css:
            match = template_css_name.get(css.strip().split('/')[-1], None)
            if match is not None:
                errors[css] = match
            else:
                label = f' (page: {page_name})' if page_name else ''
                print(f"[⚠️ Warning] Missing CSS match for {css}{label}")

    new_html = generated_html
    for css, new_css in errors.items():
        if new_css:
            new_html = new_html.replace(css, new_css)

    return new_html
82
+
83
+
84
+
85
+
86
+
87
if __name__ == "__main__":
    # Smoke test: re-link CSS in every generated page against its template.
    templates_root = '/home/jimu/Project_resources/project_page/page_assets/'
    html_finder = HtmlFinder(specific_name='index.html')

    count = 0
    for page in os.listdir('generated_FastVGGT'):
        print(page)
        count += 1
        with open(html_finder.find_html(os.path.join('generated_FastVGGT', page)), 'r') as f:
            generated_html = f.read()

        with open(html_finder.find_html(os.path.join(templates_root, page)), 'r') as f:
            template_html = f.read()

        # BUG FIX: check_css takes (generated_html, template_html); the
        # previous call passed a third positional argument (page), which
        # raised TypeError at runtime.
        _ = check_css(generated_html, template_html)
    print(count)
105
+
106
+
107
+
108
+
109
+
110
+
111
+
ProjectPageAgent/html_finder.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+
4
class HtmlFinder(object):
    """Breadth-first search for an HTML file under a directory tree."""

    def __init__(self, specific_name=None):
        # Pending directories for the breadth-first traversal.
        self.queue = []
        # When set, only files ending with this name match (e.g. 'index.html');
        # otherwise any '.html' file matches.
        self.specific_name = specific_name

    def find_html(self, path):
        """Return the first matching HTML file under *path*, or None.

        Fixes over the previous implementation:
        - no IndexError when the queue is exhausted (previously masked by
          the broad except, which printed an error on every failed search);
        - each call starts a fresh traversal, so stale queue entries from a
          previous search cannot leak into this one;
        - when specific_name is set, a differently-named .html file is NOT
          accepted (previously the generic '.html' branch matched anyway).
        """
        try:
            if not os.path.isdir(path):
                return None
            # Fresh traversal state for this search.
            self.queue = [path]
            while self.queue:
                current = self.queue.pop(0)
                for entry in os.listdir(current):
                    entry_path = os.path.join(current, entry)
                    if os.path.isdir(entry_path):
                        self.queue.append(entry_path)
                    elif self.specific_name is not None:
                        if entry_path.endswith(self.specific_name):
                            return entry_path
                    elif entry_path.endswith(".html"):
                        return entry_path
            return None
        except Exception as e:
            # Best-effort: report and return None rather than crash the caller.
            print(f"Error appears when finding {path}, error: {str(e)}")
            return None

    def reset_queue(self):
        """Clear any leftover traversal state."""
        self.queue = []
ProjectPageAgent/html_generator.py ADDED
@@ -0,0 +1,633 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HTML generator for project page generation.
3
+ Generates the final HTML project page from planned content.
4
+ """
5
+
6
+ import json
7
+ import yaml
8
+ import os
9
+ import io
10
+ import re
11
+ import json
12
+ import yaml
13
+ from pathlib import Path
14
+ from urllib.parse import urlparse
15
+ from datetime import datetime
16
+ from jinja2 import Environment, StrictUndefined
17
+ from camel.models import ModelFactory
18
+ from camel.agents import ChatAgent
19
+ from utils.wei_utils import get_agent_config, account_token
20
+ from utils.src.utils import get_json_from_response, extract_html_code_block
21
+ from ProjectPageAgent.css_checker import check_css
22
+ from utils.src.utils import run_sync_screenshots
23
+ from PIL import Image
24
+ from camel.messages import BaseMessage
25
+
26
+
27
+ from camel.models import ModelFactory
28
+
29
def to_url(input_path_or_url: str) -> str:
    """Normalize a local path or URL into a fetchable URL.

    http/https/file URLs pass through unchanged; anything else is treated
    as a filesystem path and converted to a file:// URI.

    Raises:
        FileNotFoundError: if the path does not exist on disk.
    """
    scheme = urlparse(input_path_or_url).scheme
    if scheme in ("http", "https", "file"):
        return input_path_or_url
    resolved = Path(input_path_or_url).expanduser().resolve()
    if not resolved.exists():
        raise FileNotFoundError(f"Input not found: {resolved}")
    return resolved.as_uri()  # file://...
37
+
38
+
39
def crop_image_to_max_size(image_path, max_bytes=8*1024*1024, output_path=None):
    """Ensure the image file fits within *max_bytes*, cropping height if not.

    The image is re-encoded in its original format.  If the encoded size
    exceeds *max_bytes*, the image is cropped from the top down to a height
    scaled by the byte ratio (a rough heuristic, not an exact fit).

    Returns the path the (possibly cropped) image was written to.
    """
    image = Image.open(image_path)
    fmt = image.format
    target = image_path if output_path is None else output_path

    probe = io.BytesIO()
    image.save(probe, format=fmt)
    encoded_size = probe.getbuffer().nbytes

    if encoded_size <= max_bytes:
        image.save(target, format=fmt)
        return target

    width, height = image.size
    ratio = max_bytes / encoded_size
    cropped_height = max(int(height * ratio), 1)
    image.crop((0, 0, width, cropped_height)).save(target, format=fmt)

    return target
60
class ProjectPageHTMLGenerator:
    """Generates HTML project pages from planned content.

    Owns four ChatAgent instances:
      * ``html_agent``   -- text model that writes and revises the page HTML.
      * ``review_agent`` -- vision model that critiques rendered screenshots.
      * ``table_agent``  -- vision model that transcribes table images to HTML.
      * ``long_agent``   -- text model (large context window) used for the
        whole-page table-merge step.
    """

    def __init__(self, agent_config, args):
        # agent_config: dict from get_agent_config() for the text model.
        # args: parsed CLI namespace (model names, output dirs, iteration counts).
        self.agent_config = agent_config
        self.args = args
        self.html_agent = self._create_html_agent()
        self.review_agent = self._create_review_agent()
        self.table_agent = self._create_table_agent()
        self.long_agent = self._create_long_agent()

        # self.client = OpenAI(api_key=api_key,base_url=api_url)

    def _create_html_agent(self):
        """Create the HTML generation agent (text model ``args.model_name_t``)."""
        model_type = str(self.agent_config['model_type'])

        # Get API key from environment variables, keyed by provider family.
        api_key = None
        if self.args.model_name_t in ['4o', '4o-mini', 'gpt-4.1', 'gpt-4.1-mini', 'o1', 'o3', 'o3-mini']:
            api_key = os.environ.get('OPENAI_API_KEY')
        elif self.args.model_name_t in ['gemini', 'gemini-2.5-pro', 'gemini-2.5-flash']:
            api_key = os.environ.get('GEMINI_API_KEY')
        elif self.args.model_name_t in ['qwen', 'qwen-plus', 'qwen-max', 'qwen-long']:
            api_key = os.environ.get('QWEN_API_KEY')
        elif self.args.model_name_t.startswith('openrouter_'):
            api_key = os.environ.get('OPENROUTER_API_KEY')
        elif self.args.model_name_t in ['zhipuai']:
            api_key = os.environ.get('ZHIPUAI_API_KEY')

        # vLLM-served models additionally need an explicit endpoint URL.
        if model_type.startswith('vllm_qwen') or 'vllm' in model_type.lower():
            model = ModelFactory.create(
                model_platform=self.agent_config['model_platform'],
                model_type=self.agent_config['model_type'],
                model_config_dict=self.agent_config['model_config'],
                url=self.agent_config.get('url', None),
                api_key=api_key,
            )
        else:
            model = ModelFactory.create(
                model_platform=self.agent_config['model_platform'],
                model_type=self.agent_config['model_type'],
                model_config_dict=self.agent_config['model_config'],
                api_key=api_key,
            )

        system_message = """You are an expert web developer specializing in creating professional project pages for research papers.
You have extensive experience in HTML5, CSS3, responsive design, and academic content presentation.
Your goal is to create engaging, well-structured, and visually appealing project pages."""

        return ChatAgent(
            system_message=system_message,
            model=model,
            message_window_size=10
        )

    def _create_review_agent(self):
        """Create the vision agent that reviews rendered page screenshots.

        Its system prompt is rendered (no variables) from
        ``utils/prompt_templates/page_templates/html_review.yaml``.
        """
        with open('utils/prompt_templates/page_templates/html_review.yaml', 'r') as f:
            prompt_config = yaml.safe_load(f)

        jinja_env = Environment(undefined=StrictUndefined)
        system_message_template = jinja_env.from_string(prompt_config["system_prompt"])

        system_message = system_message_template.render()

        model_type = self.args.model_name_v

        # Get API key from environment variables, keyed by provider family.
        api_key = None
        if self.args.model_name_v in ['4o', '4o-mini', 'gpt-4.1', 'gpt-4.1-mini', 'o1', 'o3', 'o3-mini']:
            api_key = os.environ.get('OPENAI_API_KEY')
        elif self.args.model_name_v in ['gemini', 'gemini-2.5-pro', 'gemini-2.5-flash']:
            api_key = os.environ.get('GEMINI_API_KEY')
        elif self.args.model_name_v in ['qwen', 'qwen-plus', 'qwen-max', 'qwen-long']:
            api_key = os.environ.get('QWEN_API_KEY')
        elif self.args.model_name_v.startswith('openrouter_'):
            api_key = os.environ.get('OPENROUTER_API_KEY')
        elif self.args.model_name_v in ['zhipuai']:
            api_key = os.environ.get('ZHIPUAI_API_KEY')

        config = get_agent_config(model_type)
        model = ModelFactory.create(
            model_platform=config['model_platform'],
            model_type=config['model_type'],
            model_config_dict=config['model_config'],
            url=config.get('url', None),
            api_key=api_key,
        )

        return ChatAgent(
            system_message=system_message,
            model=model,
            message_window_size=10
        )

    def _create_table_agent(self):
        """Create the vision agent that transcribes table screenshots to HTML.

        No system message is set; task instructions are supplied per call in
        modify_html_table().
        """
        model_type = self.args.model_name_v

        # Get API key from environment variables, keyed by provider family.
        api_key = None
        if self.args.model_name_v in ['4o', '4o-mini', 'gpt-4.1', 'gpt-4.1-mini', 'o1', 'o3', 'o3-mini']:
            api_key = os.environ.get('OPENAI_API_KEY')
        elif self.args.model_name_v in ['gemini', 'gemini-2.5-pro', 'gemini-2.5-flash']:
            api_key = os.environ.get('GEMINI_API_KEY')
        elif self.args.model_name_v in ['qwen', 'qwen-plus', 'qwen-max', 'qwen-long']:
            api_key = os.environ.get('QWEN_API_KEY')
        elif self.args.model_name_v.startswith('openrouter_'):
            api_key = os.environ.get('OPENROUTER_API_KEY')
        elif self.args.model_name_v in ['zhipuai']:
            api_key = os.environ.get('ZHIPUAI_API_KEY')

        vlm_config = get_agent_config(model_type)
        vlm_model = ModelFactory.create(
            model_platform=vlm_config['model_platform'],
            model_type=vlm_config['model_type'],
            model_config_dict=vlm_config['model_config'],
            url=vlm_config.get('url', None),
            api_key=api_key,
        )
        return ChatAgent(
            system_message=None,
            model=vlm_model,
            message_window_size=10,
        )

    def _create_long_agent(self):
        """Create the long-context text agent used to merge HTML tables into
        the full page (configured with the model's token limit when known)."""
        model_type = self.args.model_name_t

        # Get API key from environment variables, keyed by provider family.
        api_key = None
        if self.args.model_name_t in ['4o', '4o-mini', 'gpt-4.1', 'gpt-4.1-mini', 'o1', 'o3', 'o3-mini']:
            api_key = os.environ.get('OPENAI_API_KEY')
        elif self.args.model_name_t in ['gemini', 'gemini-2.5-pro', 'gemini-2.5-flash']:
            api_key = os.environ.get('GEMINI_API_KEY')
        elif self.args.model_name_t in ['qwen', 'qwen-plus', 'qwen-max', 'qwen-long']:
            api_key = os.environ.get('QWEN_API_KEY')
        elif self.args.model_name_t.startswith('openrouter_'):
            api_key = os.environ.get('OPENROUTER_API_KEY')
        elif self.args.model_name_t in ['zhipuai']:
            api_key = os.environ.get('ZHIPUAI_API_KEY')

        long_config = get_agent_config(model_type)
        long_model = ModelFactory.create(
            model_platform=long_config['model_platform'],
            model_type=long_config['model_type'],
            model_config_dict=long_config['model_config'],
            url=long_config.get('url', None),
            api_key=api_key,
        )

        return ChatAgent(
            system_message=None,
            model=long_model,
            message_window_size=10,
            token_limit=long_config.get('token_limit', None)
        )

    def render_html_to_png(self, iter, html_content, project_output_dir) -> str:
        """Write *html_content* to ``index_iter{iter}.html`` and screenshot it.

        Args:
            iter: Revision-loop index, used in the file names.  (Shadows the
                ``iter`` builtin; harmless here since the builtin is unused.)
            html_content: Full HTML document to render.
            project_output_dir: Directory receiving both the HTML and the PNG.

        Returns:
            str: Path of the rendered ``page_iter{iter}.png`` screenshot.
        """
        import time  # NOTE(review): unused import — confirm it can be dropped.
        tmp_html = Path(project_output_dir) / f"index_iter{iter}.html"
        tmp_html.write_text(html_content, encoding="utf-8")
        url = tmp_html.resolve().as_uri()

        image_path = str(Path(project_output_dir) / f"page_iter{iter}.png")

        run_sync_screenshots(url, image_path)
        return image_path

    def get_revision_suggestions(self, image_path: str, html_path) -> str:
        """Send the page screenshot to the review agent and return its
        structured suggestions.

        Args:
            image_path: Path of the PNG screenshot.  Modified in place: it is
                cropped to 1280 px wide and shrunk under the 8 MiB size budget
                before upload.
            html_path: Path of the corresponding HTML file.
                NOTE(review): currently unused — confirm whether it should be
                sent to the reviewer as well.

        Returns:
            dict parsed from the review agent's JSON reply.
        """
        def crop_image_max_width(img, max_width=1280):
            # Keep only the leftmost 1280 px; very wide screenshots are
            # rejected or down-weighted by vision APIs.
            width, height = img.size
            if width > max_width:
                img = img.crop((0, 0, max_width, height))  # (left, top, right, bottom)
            return img
        img = Image.open(image_path)
        img = crop_image_max_width(img, max_width=1280)
        img.save(image_path, format='PNG')
        crop_image_to_max_size(image_path=image_path, output_path=image_path)
        img = Image.open(image_path)

        message = BaseMessage.make_user_message(
            role_name="User",
            content='\nHere is the image of the generated project page.',
            image_list=[img]
        )
        response = self.review_agent.step(message)

        return get_json_from_response(response.msgs[0].content.strip())

    def modify_html_table(self, html_content: str, html_dir: str):
        """Replace table screenshots in the page with real HTML tables.

        Phase 1: the vision ``table_agent`` transcribes every
        ``assets/<paper>-table-N.png`` image referenced by the page into an
        HTML ``<table>`` (each saved under ``table_html/`` for inspection).
        Phase 2: the ``long_agent`` merges those tables into the page, guided
        by color suggestions derived from a full-page screenshot.

        Args:
            html_content: Page HTML.  NOTE(review): this parameter is
                immediately overwritten by re-reading
                ``index_no_modify_table.html`` from disk — confirm intent.
            html_dir: Sub-directory (under the paper's output dir) holding the
                page and its assets.

        Returns:
            tuple: (merged HTML or None, total input tokens, total output tokens).
        """
        in_tokens, out_tokens = 0, 0
        print("Starting table modification...")
        # NOTE(review): this helper is defined but never invoked — the merge
        # is delegated to the LLM in phase 2 instead; confirm it is still needed.
        def replace_tables_in_html(html_content, table_html_map, paper_name):
            # Swap each <img src="assets/<paper>-table-N.png"> for its
            # transcribed HTML table, leaving unknown images untouched.
            pattern = rf'<img[^>]*src="(assets/{paper_name}-table-\d+\.png)"[^>]*>'

            def repl(match):
                img_path = match.group(1)  # e.g. assets/MambaFusion-table-10.png
                if img_path in table_html_map:
                    return table_html_map[img_path]
                return match.group(0)

            return re.sub(pattern, repl, html_content)

        # ============ step 1: extract tables ============

        pattern = rf"assets/{self.args.paper_name}-table-\d+\.png"
        with open(os.path.join(self.args.output_dir, self.args.paper_name, html_dir, 'index_no_modify_table.html'), 'r', encoding='utf-8') as f:
            html_content = f.read()
        matches = re.findall(pattern, html_content)

        # NOTE(review): re.findall() always returns a list, never None, so this
        # early-exit path is dead; `if not matches:` was probably intended.
        if matches is None:
            print("No table images found, skipping modification.")
            return None, 0, 0

        model_type = self.args.model_name_v
        print(f"Starting table modification phase 1: Table Extraction with {model_type}...")

        # Prime the table agent with the extraction instructions once.
        with open('utils/prompt_templates/page_templates/extract_table.yaml', 'r') as f:
            table_extraction_config = yaml.safe_load(f)
        content = table_extraction_config["system_prompt"]

        init_message = BaseMessage.make_user_message(
            role_name="User",
            content=content
        )
        response = self.table_agent.step(init_message)
        in_tok, out_tok = account_token(response)
        in_tokens += in_tok
        out_tokens += out_tok
        # Step 2: transcribe each distinct table image.
        table_html_map = {}

        matches = list(set(matches))  # de-duplicate repeated references
        for match in matches:
            img_path = os.path.join(self.args.output_dir, self.args.paper_name, html_dir, match)
            print(f"Processing table image: {img_path}")
            img = Image.open(img_path)
            msg = BaseMessage.make_user_message(
                role_name="User",
                content=f'''Here is table image: {match}
Please output its HTML table (<table>...</table>) with an inline <style>...</style> block.
Only return pure HTML , nothing else.
''',
                image_list=[img]
            )
            response = self.table_agent.step(msg)
            in_tok, out_tok = account_token(response)
            in_tokens += in_tok
            out_tokens += out_tok
            print(f'in:{in_tok},out:{out_tok}')
            _output_html = response.msgs[0].content.strip()
            table_html_map[match] = _output_html
            # Persist each transcription for debugging/inspection.
            tabel_dir = os.path.join(self.args.output_dir, self.args.paper_name, html_dir)
            os.makedirs(f'{tabel_dir}/table_html', exist_ok=True)

            with open(f'{tabel_dir}/table_html/{match.replace("/", "_")}.html', 'w', encoding='utf-8') as f:
                f.write(table_html_map[match])

        # ============ Phase 2: HTML merge ============

        # Ask the (reset) vision agent for color suggestions from the
        # full-page screenshot so merged tables match the page palette.
        self.table_agent.reset()
        img_path = os.path.join(self.args.output_dir, self.args.paper_name, html_dir, 'page_final_no_modify_table.png')
        img = Image.open(img_path)
        with open('utils/prompt_templates/page_templates/color_suggestion.yaml', 'r') as f:
            prompt_config = yaml.safe_load(f)

        jinja_env = Environment(undefined=StrictUndefined)
        init_prompt_template = jinja_env.from_string(prompt_config["system_prompt"])

        init_prompt = init_prompt_template.render()

        msg = BaseMessage.make_user_message(
            role_name="User",
            content=init_prompt,
            image_list=[img]
        )

        color_response = self.table_agent.step(msg)
        color_suggestion = color_response.msgs[0].content.strip()
        in_tok, out_tok = account_token(color_response)
        in_tokens += in_tok
        out_tokens += out_tok

        print(f"Starting table modification phase 2: HTML Merging with {model_type}...")

        tables_str = "\n\n".join(
            [f"Table extracted for {fname}:\n{html}" for fname, html in table_html_map.items()]
        )
        with open("utils/prompt_templates/page_templates/merge_html_table.yaml", 'r') as f:
            prompt_config = yaml.safe_load(f)

        jinja_env = Environment(undefined=StrictUndefined)
        template = jinja_env.from_string(prompt_config["template"])

        jinja_args = {
            'html_content': html_content,
            'color_suggestion': color_suggestion,
            'tables_str': tables_str
        }

        prompt = template.render(**jinja_args)

        final_message = BaseMessage.make_user_message(
            role_name="User",
            content=prompt
        )

        # Retry up to three times until the model returns a parseable
        # ```html ...``` code block.
        for i in range(3):
            self.long_agent.reset()
            response = self.long_agent.step(final_message)
            in_tok, out_tok = account_token(response)
            in_tokens += in_tok
            out_tokens += out_tok
            output_html = response.msgs[0].content.strip()
            print(f'in:{in_tok},out:{out_tok}')
            exteact_html_code = extract_html_code_block(output_html)
            if exteact_html_code is not None:
                break
            print(f"html format is not correct, regenerate {i} turn")

        # May be None if all three attempts failed to produce a code block.
        return exteact_html_code, in_tokens, out_tokens

    def modify_html_from_human_feedback(self, html_content: str, user_feedback: str):
        """
        Modify HTML based on human feedback using the HTML agent.

        Args:
            html_content: Original HTML content
            user_feedback: Feedback from human reviewers

        Returns:
            tuple: (modified HTML or None if no parseable code block was
            produced after 3 attempts, input tokens, output tokens)
        """
        in_tokens, out_tokens = 0, 0
        print("Starting HTML modification based on human feedback...")
        with open('utils/prompt_templates/page_templates/modify_html_from_human_feedback.yaml', 'r') as f:
            modifier_config = yaml.safe_load(f)

        jinja_env = Environment(undefined=StrictUndefined)
        template = jinja_env.from_string(modifier_config["template"])

        jinja_args = {
            'generated_html': html_content,
            'user_feedback': user_feedback
        }

        prompt = template.render(**jinja_args)
        # Retry until the agent returns a parseable ```html ...``` block.
        for i in range(3):
            self.html_agent.reset()
            response = self.html_agent.step(prompt)
            in_tok, out_tok = account_token(response)
            in_tokens += in_tok
            out_tokens += out_tok
            print(f'input_token: {in_tok}, output_token: {out_tok}')
            modified_html = extract_html_code_block(response.msgs[0].content)

            if modified_html is not None:
                break
            print(f"html format is not correct, regenerate {i} turn")

        return modified_html, in_tokens, out_tokens

    def generate_complete_html(self, args, generated_content, html_dir, html_template=None):
        """
        Generate complete HTML by combining all sections, then render to PNG,
        send to the vision review agent for feedback, and regenerate HTML
        with the suggestions applied.

        First pass (skipped when resuming at 'html_check'): render the
        html_generation prompt with the planned content and template, ask the
        HTML agent for the page, sanity-check CSS paths, and save it as
        ``index_init.html``.  Then, for ``args.html_check_times`` rounds:
        screenshot the page, collect revision suggestions, and apply them.

        Returns:
            tuple: (revised HTML, input tokens, output tokens) — the token
            counts reflect only the LAST agent call, not the total.
        """

        # Create output directory for this specific project
        project_output_dir = f"{args.output_dir}/{args.paper_name}"
        html_path = os.path.join(project_output_dir, html_dir)
        if args.resume != 'html_check':
            with open('utils/prompt_templates/page_templates/html_generation.yaml', 'r') as f:
                generator_config = yaml.safe_load(f)

            jinja_env = Environment(undefined=StrictUndefined)
            template = jinja_env.from_string(generator_config["template"])

            jinja_args = {
                'generated_content': json.dumps(generated_content, indent=2),
                'html_template': html_template,
            }

            prompt = template.render(**jinja_args)
            # Retry until the agent returns a parseable ```html ...``` block.
            for i in range(3):
                self.html_agent.reset()
                # print(self.html_agent)

                response = self.html_agent.step(prompt)
                # print(response.msgs[0].content)
                input_token, output_token = account_token(response)
                print(f'input_token: {input_token}, output_token: {output_token}')
                # print(input_token, output_token)
                html_content = extract_html_code_block(response.msgs[0].content)

                if html_content is not None:
                    break
                print(f"html format is not correct, regenerate {i} turn")

            # check css paths
            html_content = check_css(html_content, html_template)

            with open(os.path.join(html_path, 'index_init.html'), 'w') as f:
                f.write(html_content)

            print(f"Initial HTML generation completed. Tokens: {input_token} -> {output_token}")

        else:
            # Resuming: reuse the previously generated initial page.
            with open(os.path.join(html_path, 'index_init.html'), 'r', encoding='utf-8') as f:
                html_content = f.read()

        revised_html = html_content

        # NOTE(review): if resume == 'html_check' and html_check_times == 0,
        # input_token/output_token are never assigned before the final return
        # (NameError) — confirm html_check_times is always >= 1.
        for i in range(self.args.html_check_times):
            if i == 0:
                print("starting html check and revision...")

            image_path = self.render_html_to_png(i, revised_html, html_path)

            suggestions = self.get_revision_suggestions(image_path, os.path.join(html_path, f'index_iter{i}.html'))
            # print(f"Revision suggestions from {self.args.model_name_v}:\n", suggestions)

            # Persist each round's review for later inspection.
            review_path = f'project_contents/{args.paper_name}_html_review_iter{i}.json'
            with open(review_path, 'w') as f:
                json.dump(suggestions, f, indent=4)

            self.html_agent.reset()
            with open('utils/prompt_templates/page_templates/html_modify_from_suggestion.yaml', 'r') as f:
                regenerator_config = yaml.safe_load(f)

            jinja_env = Environment(undefined=StrictUndefined)
            _template = jinja_env.from_string(regenerator_config["template"])

            _jinja_args = {
                'existing_html': revised_html,
                'suggestions': suggestions
            }

            revision_prompt = _template.render(**_jinja_args)

            # print(revision_prompt)
            revised_response = self.html_agent.step(revision_prompt)
            # print(revised_response.msgs[0].content)
            revised_html = extract_html_code_block(revised_response.msgs[0].content)

            print("Revised HTML generation completed.")
            input_token, output_token = account_token(revised_response)
            print(f'in:{input_token}, out:{output_token}')

        return revised_html, input_token, output_token

    def save_html_file(self, html_content, args, html_dir, output_dir="generated_project_pages"):
        """
        Save the generated HTML to a file.

        Args:
            html_content: Generated HTML content
            args: Command line arguments
            html_dir: Sub-directory (under the paper's output dir) for the page
            output_dir: Output directory for the HTML file

        Returns:
            str: Path to the saved HTML file
        """
        os.makedirs(output_dir, exist_ok=True)

        # Create output directory for this specific project
        project_output_dir = f"{output_dir}/{args.paper_name}"
        os.makedirs(project_output_dir, exist_ok=True)

        # Save HTML file
        html_file_path = f"{project_output_dir}/{html_dir}/index.html"
        with open(html_file_path, 'w', encoding='utf-8') as f:
            f.write(html_content)

        print(f"HTML project page saved to: {html_file_path}")

        return html_file_path

    def create_assets_directory(self, args, html_dir, output_dir="generated_project_pages"):
        """
        Create assets directory and copy images/tables.

        Args:
            args: Command line arguments
            html_dir: Sub-directory (under the paper's output dir) for the page
            output_dir: Output directory

        Returns:
            str: Path to the assets directory
        """
        project_output_dir = f"{output_dir}/{args.paper_name}"
        assets_dir = os.path.join(project_output_dir, html_dir, "assets")
        os.makedirs(assets_dir, exist_ok=True)

        # Copy images and tables from the extracted assets
        source_assets_dir = f"generated_project_pages/images_and_tables/{args.paper_name}"
        if os.path.exists(source_assets_dir):
            import shutil
            for file in os.listdir(source_assets_dir):
                if file.endswith(('.png', '.jpg', '.jpeg', '.gif')):
                    src_path = os.path.join(source_assets_dir, file)
                    dst_path = os.path.join(assets_dir, file)
                    shutil.copy2(src_path, dst_path)

        print(f"Assets directory created at: {assets_dir}")
        return assets_dir

    def generate_metadata(self, generated_content, args):
        """
        Generate metadata for the project page.

        Args:
            generated_content: Generated content
            args: Command line arguments

        Returns:
            dict: Metadata for the project page
        """
        metadata = {
            'title': generated_content.get('meta', {}).get('poster_title', 'Research Project'),
            # Truncated to 160 chars — the usual meta-description length limit.
            'description': generated_content.get('meta', {}).get('abstract', '')[:160],
            'authors': generated_content.get('meta', {}).get('authors', ''),
            'affiliations': generated_content.get('meta', {}).get('affiliations', ''),
            'keywords': [],
            'generated_by': f"Paper2ProjectPage ({args.model_name_t}_{args.model_name_v})",
            'generation_date': str(datetime.now())
        }

        # Extract keywords from content
        content_text = json.dumps(generated_content, ensure_ascii=False)
        # Simple keyword extraction (can be improved): frequency of
        # alphabetic words longer than 4 characters.
        words = content_text.lower().split()
        word_freq = {}
        for word in words:
            if len(word) > 4 and word.isalpha():
                word_freq[word] = word_freq.get(word, 0) + 1

        # Get top 10 most frequent words as keywords
        sorted_words = sorted(word_freq.items(), key=lambda x: x[1], reverse=True)
        metadata['keywords'] = [word for word, freq in sorted_words[:10]]

        return metadata

    def save_metadata(self, metadata, args, output_dir="generated_project_pages"):
        """
        Save metadata to a JSON file.

        Args:
            metadata: Generated metadata
            args: Command line arguments
            output_dir: Output directory

        Returns:
            str: Path to the saved metadata file
        """
        project_output_dir = f"{output_dir}/{args.paper_name}"
        metadata_file_path = f"{project_output_dir}/metadata.json"

        with open(metadata_file_path, 'w', encoding='utf-8') as f:
            json.dump(metadata, f, indent=4, ensure_ascii=False)

        print(f"Metadata saved to: {metadata_file_path}")
        return metadata_file_path
ProjectPageAgent/main_pipline.py ADDED
@@ -0,0 +1,379 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Main pipeline for Paper2ProjectPage.
3
+ Integrates all modules to generate project pages from research papers.
4
+ """
5
+
6
+ import argparse
7
+ import json
8
+ import os
9
+ import time
10
+ from dotenv import load_dotenv
11
+ from pathlib import Path
12
+ import shutil
13
+ from ProjectPageAgent.parse_paper import parse_paper_for_project_page, save_parsed_content
14
+ from ProjectPageAgent.html_finder import HtmlFinder
15
+ from ProjectPageAgent.content_planner import ProjectPageContentPlanner
16
+ from ProjectPageAgent.html_generator import ProjectPageHTMLGenerator,to_url
17
+ from utils.wei_utils import get_agent_config
18
+ from ProjectPageAgent.content_planner import filter_references
19
+ from utils.src.utils import run_sync_screenshots
20
+
21
+ load_dotenv()
22
+
23
def matching(requirement):
    """Rank the bundled page templates against the user's style requirement.

    Each template described in ``tags.json`` is scored by summing a fixed
    weight for every feature whose tag value equals the requested value; the
    names of the best-scoring templates are returned.

    Args:
        requirement: Mapping of feature name -> desired value, e.g.
            {"background_color": "light", "has_navigation": "yes", ...}.

    Returns:
        list[str]: Names of up to three best-matching templates, best first
        (fewer if fewer templates matched any feature at all).
    """
    weight = {
        "background_color": 1.0,
        "has_hero_section": 0.75,
        "Page density": 0.85,
        "image_layout": 0.65,
        "title_color": 0.6,
        "has_navigation": 0.7
    }
    with open('tags.json', 'r') as f:
        template_tags = json.load(f)

    points = {}
    for name, tag in template_tags.items():
        for feature, value in tag.items():
            # .get() guards against features that are missing from the
            # requirement dict or absent from the weight table — previously
            # either case raised a KeyError.
            if requirement.get(feature) == value:
                points[name] = points.get(name, 0.0) + weight.get(feature, 0.0)
    sorted_points = sorted(points.items(), key=lambda item: item[1], reverse=True)
    return [name for name, _ in sorted_points[:3]]
45
+
46
def copy_static_files(template_file_path, template_root_dir, output_dir, paper_name):
    """Copy the chosen template tree (minus its HTML entry file) into the
    per-paper output directory so relative asset links keep working.

    Args:
        template_file_path: Path of the template's HTML entry file.
        template_root_dir: Root directory of the template being copied.
        output_dir: Root output directory for generated pages.
        paper_name: Paper name, used as the project sub-directory name.

    Returns:
        Path of the project's ``static`` directory, or None when the template
        entry file cannot be read at the end of the copy step.
    """
    print(f"Detecting Static files: {template_file_path}")
    os.makedirs(output_dir, exist_ok=True)

    # Create output directory for this specific project
    project_output_dir = f"{output_dir}/{paper_name}"
    os.makedirs(project_output_dir, exist_ok=True)

    # template_dir = os.path.dirname(template_file_path)
    static_dir = os.path.join(project_output_dir, 'static')
    os.makedirs(static_dir, exist_ok=True)

    # Path of the HTML entry relative to the template root; used below to
    # delete the copied entry file (a fresh index.html is generated later).
    html_relative_path = os.path.relpath(template_file_path, template_root_dir)

    # template_static_dir = os.path.join(template_dir, 'static')
    if os.path.exists(template_root_dir) and os.path.isdir(template_root_dir):
        print(f"Found template dir: {template_root_dir}")
        try:
            shutil.copytree(template_root_dir, project_output_dir, dirs_exist_ok=True)
            os.remove(os.path.join(project_output_dir, html_relative_path))
            print(f"Copied template to: {project_output_dir}")
        except Exception as e:
            # Best-effort: a failed copy is reported but not fatal.
            print(f"Failed to copy static files: {e}")

    # NOTE(review): html_content is never used; this read only verifies the
    # template file is readable, and on failure the function returns None
    # instead of static_dir — confirm that is intentional.
    try:
        with open(template_file_path, 'r', encoding='utf-8') as f:
            html_content = f.read()
    except Exception as e:
        print(f"Failed to read template file: {e}")
        return

    return static_dir
80
+
81
+ def main():
82
+ """Main pipeline for generating project pages from research papers."""
83
+ parser = argparse.ArgumentParser(description='Paper2ProjectPage Generation Pipeline')
84
+ parser.add_argument('--paper_path', type=str, required=True, help='Path to the research paper PDF')
85
+ parser.add_argument('--model_name_t', type=str, default='4o', help='Text model name')
86
+ parser.add_argument('--model_name_v', type=str, default='4o', help='Vision model name')
87
+ parser.add_argument('--template_root', type=str, default="project_templates", help='Directory containing all templates')
88
+ parser.add_argument('--template_dir', type=str, help='Directory of chosen template')
89
+ parser.add_argument('--template_file', type=str, help='Path to a specific template file to use')
90
+ parser.add_argument('--output_dir', type=str, default='generated_project_pages', help='Output directory for generated pages')
91
+ parser.add_argument('--style_preference', type=str, default=None, help='Path to style preference JSON file')
92
+ parser.add_argument('--tmp_dir', type=str, default='tmp', help='Temporary directory')
93
+ parser.add_argument('--full_content_check_times', type=int, default='0', help='Temporary directory')
94
+ parser.add_argument('--background_color', type=str, choices=['light', 'dark'], required=True,
95
+ help='Background color of generated project page')
96
+ parser.add_argument('--has_navigation', type=str, choices=['yes', 'no'], required=True,
97
+ help='Is the generated project page has navigation')
98
+ parser.add_argument('--has_hero_section', type=str, choices=['yes', 'no'], required=True,
99
+ help='Is the generated project page has hero section')
100
+ parser.add_argument('--title_color', type=str, choices=['pure', 'colorful'], required=True,
101
+ help="Is the title's color of the project page is pure or colorful")
102
+ parser.add_argument('--page_density', type=str, choices=['spacious', 'compact'], required=True,
103
+ help="The overall spacing tightness—amount of white space vs. information density")
104
+ parser.add_argument('--image_layout', type=str, choices=['rotation', 'parallelism'], required=True,
105
+ help="The dominant arrangement style for images.")
106
+ parser.add_argument('--html_check_times', type=int, default='1', help='Temporary directory')
107
+ parser.add_argument(
108
+ '--resume',
109
+ type=str,
110
+ choices=['parse_pdf', 'generate_content','full_content_check', 'generate_html', 'html_check','modify_table','html_feedback'],
111
+ default='parse_pdf',
112
+ help="From which step to resume: 'parse_pdf', 'generate_content','full_content_check', 'generate_html', 'html_check','modify_table','html_feedback'",
113
+ )
114
+ parser.add_argument('--human_input', type=str, default='1',choices=['0','1'] ,help='Human input for feedback')
115
+
116
+ args = parser.parse_args()
117
+
118
+ if not args.template_dir:
119
+ template_requirement = {
120
+ "background_color": args.background_color,
121
+ "has_hero_section": args.has_hero_section,
122
+ "Page density": args.page_density,
123
+ "image_layout": args.image_layout,
124
+ "has_navigation": args.has_navigation,
125
+ "title_color": args.title_color
126
+ }
127
+ matched_template = matching(template_requirement)
128
+ print('Below is names of the most matching 3 templates:')
129
+ print(' '.join(matched_template))
130
+ template_name = input('Please choose one from them, you can just input the name of your favorite template')
131
+ while template_name not in matched_template:
132
+ template_name = input('Please input the correct name of your favorite template!!')
133
+ args.template_dir = os.path.join(args.template_root, template_name)
134
+
135
+ # Extract html path from root path
136
+ if not args.template_file:
137
+ html_finder_ = HtmlFinder()
138
+ args.template_file = html_finder_.find_html(args.template_dir)
139
+
140
+ # Extract paper name from path
141
+ paper_name = args.paper_path.split('/')[-1].replace('.pdf', '') if '/' in args.paper_path else args.paper_path.replace('.pdf', '')
142
+ args.paper_name = paper_name
143
+
144
+ print(f"Starting Paper2ProjectPage generation for: {paper_name}")
145
+ print(f"Paper path: {args.paper_path}")
146
+ print(f"Models: {args.model_name_t} (text), {args.model_name_v} (vision)")
147
+
148
+ start_time = time.time()
149
+ total_input_tokens_t = 0
150
+ total_output_tokens_t = 0
151
+ total_input_tokens_v = 0
152
+ total_output_tokens_v = 0
153
+
154
+ # Create temporary directory
155
+ os.makedirs(args.tmp_dir, exist_ok=True)
156
+
157
+ try:
158
+ # Get agent configurations
159
+ agent_config_t = get_agent_config(args.model_name_t)
160
+ agent_config_v = get_agent_config(args.model_name_v)
161
+
162
+ # Step 1: Parse the research paper
163
+ print("\n" + "="*50)
164
+ print("STEP 1: Parsing Research Paper")
165
+ print("="*50)
166
+
167
+ raw_content_path = f'project_contents/{args.paper_name}_raw_content.json'
168
+ if not os.path.exists(raw_content_path):
169
+ print(f"Raw content does not exist at {raw_content_path}")
170
+
171
+
172
+ input_token, output_token, raw_result, images, tables = parse_paper_for_project_page(args, agent_config_t)
173
+ total_input_tokens_t += input_token
174
+ total_output_tokens_t += output_token
175
+
176
+ # Save parsed content
177
+ raw_content_path, token_log_path = save_parsed_content(args, raw_result, images, tables, input_token, output_token)
178
+
179
+ # Load parsed content
180
+ with open(raw_content_path, 'r') as f:
181
+ paper_content = json.load(f)
182
+ else:
183
+ print(f"Loading existing raw content from {raw_content_path}")
184
+ with open(raw_content_path, 'r') as f:
185
+ paper_content = json.load(f)
186
+ # Load images and tables from the saved content
187
+ images = paper_content.get('images', [])
188
+ tables = paper_content.get('tables', [])
189
+ token_log_path = raw_content_path.replace('_raw_content.json', '_parse_log.json')
190
+
191
+ images = paper_content.get('images', [])
192
+ tables = paper_content.get('tables', [])
193
+ figures = {
194
+ 'images': images,
195
+ 'tables': tables
196
+ }
197
+ paper_content = paper_content.get('markdown_content', "")
198
+
199
+
200
+ print("\n" + "="*50)
201
+ print("STEP 2: Generate project page content")
202
+ print("="*50)
203
+
204
+ planner = ProjectPageContentPlanner(agent_config_t, args)
205
+ figures_path = f'project_contents/{args.paper_name}_generated_filtered_figures.json'
206
+ generated_section_path = f'project_contents/{args.paper_name}_generated_section.json'
207
+ text_page_content_path = f'project_contents/{args.paper_name}_generated_text_content.json'
208
+ generated_content_path = f'project_contents/{args.paper_name}_generated_full_content.json'
209
+ if args.resume in ['parse_pdf','generate_content','full_content_check']:
210
+
211
+ if args.resume != 'full_content_check':
212
+
213
+ paper_content, figures, input_token, output_token = planner.filter_raw_content(paper_content, figures)
214
+ total_input_tokens_t += input_token
215
+ total_output_tokens_t += output_token
216
+
217
+ generated_section, input_token, output_token = planner.section_generation(paper_content, figures)
218
+ total_input_tokens_t += input_token
219
+ total_output_tokens_t += output_token
220
+
221
+ text_page_content, input_token, output_token = planner.text_content_generation(paper_content, figures, generated_section)
222
+ total_input_tokens_t += input_token
223
+ total_output_tokens_t += output_token
224
+
225
+ else :
226
+ print("Skipping content generation: filter_raw_content, section_generation, text_content_generation")
227
+ print("Loading existing content from previous steps.")
228
+ paper_content = filter_references(paper_content)
229
+ with open(figures_path, 'r') as f:
230
+ figures = json.load(f)
231
+ with open(generated_section_path, 'r') as f:
232
+ generated_section = json.load(f)
233
+ with open(text_page_content_path, 'r') as f:
234
+ text_page_content = json.load(f)
235
+
236
+ generated_content, input_token, output_token = planner.full_content_generation(args, paper_content, figures, generated_section, text_page_content)
237
+ total_input_tokens_t += input_token
238
+ total_output_tokens_t += output_token
239
+
240
+ print("\n" + "="*50)
241
+ print("STEP 2.5: Copying Static Files")
242
+ print("="*50)
243
+ static_dir = copy_static_files(args.template_file, args.template_dir, args.output_dir, args.paper_name)
244
+
245
+ else:
246
+ print("Page content is already generated, loading existing content.")
247
+
248
+ paper_content = filter_references(paper_content)
249
+ with open(generated_section_path, 'r') as f:
250
+ generated_section = json.load(f)
251
+ with open(text_page_content_path, 'r') as f:
252
+ text_page_content = json.load(f)
253
+ with open(generated_content_path, 'r') as f:
254
+ generated_content = json.load(f)
255
+
256
+ static_dir = copy_static_files(args.template_file, args.template_dir, args.output_dir, args.paper_name)
257
+ # static_dir = os.path.join(args.output_dir, args.paper_name, 'static')
258
+ # Step 3: Generate HTML project page
259
+ print("\n" + "="*50)
260
+ print("STEP 3: Generating HTML Project Page")
261
+ print("="*50)
262
+ html_relative_path = os.path.relpath(args.template_file, args.template_dir)
263
+ html_dir = '/'.join(html_relative_path.strip().split('/')[:-1])
264
+ html_generator = ProjectPageHTMLGenerator(agent_config_t,args)
265
+ with open(args.template_file, 'r', encoding='utf-8') as file:
266
+ html_template = file.read()
267
+ # Generate HTML
268
+ if args.resume != 'modify_table' and args.resume != 'html_feedback':
269
+
270
+ # Create assets directory and copy images
271
+ assets_dir = html_generator.create_assets_directory(args, html_dir, args.output_dir)
272
+ # Generate complete HTML
273
+ html_content, input_token, output_token = html_generator.generate_complete_html(
274
+ args, generated_content, html_dir, html_template
275
+ )
276
+ total_input_tokens_t += input_token
277
+ total_output_tokens_t += output_token
278
+
279
+ # Save HTML file
280
+ html_file_path = os.path.join(args.output_dir, args.paper_name, html_dir, 'index_no_modify_table.html')
281
+ with open(html_file_path,'w') as file:
282
+ file.write(html_content)
283
+ run_sync_screenshots(to_url(html_file_path), os.path.join(args.output_dir,args.paper_name, html_dir,'page_final_no_modify_table.png'))
284
+
285
+ else:
286
+ print(f"skip generate_html and html_check, load html from {os.path.join(args.output_dir,args.paper_name, html_dir,'index.html')}")
287
+ assets_dir = os.path.join(args.output_dir, args.paper_name, html_dir,'assets')
288
+ with open(os.path.join(args.output_dir,args.paper_name, html_dir,'index_no_modify_table.html'),'r') as file:
289
+ html_content = file.read()
290
+
291
+ if args.resume != 'html_feedback':
292
+ html_content ,input_token,output_token = html_generator.modify_html_table(html_content,html_dir)
293
+ total_input_tokens_t += input_token
294
+ total_output_tokens_t += output_token
295
+ html_file_path = os.path.join(args.output_dir, args.paper_name, html_dir, 'index_modify_table.html')
296
+ with open(html_file_path,'w') as file:
297
+ file.write(html_content)
298
+ # html_file_path = html_generator.save_html_file(html_content, args, html_dir,args.output_dir)
299
+ else:
300
+ print("skipping modify_table,go to html_feedback")
301
+ html_file_path = os.path.join(args.output_dir, args.paper_name, html_dir, 'index_modify_table.html')
302
+ with open(html_file_path,'r') as file:
303
+ html_content = file.read()
304
+
305
+ print('-'*50)
306
+ run_sync_screenshots(to_url(html_file_path), os.path.join(args.output_dir, args.paper_name, html_dir,'page_final.png'))
307
+ if args.human_input == '1':
308
+ human_feedback = input('Please view the final html in index.html,and image in page_final.png,If there are no problems, enter yes and press Enter.\n If there are any problems, please give me feedback directly.\n')
309
+ while human_feedback.lower() != 'yes':
310
+
311
+ html_content ,input_token,output_token = html_generator.modify_html_from_human_feedback(html_content,human_feedback)
312
+ total_input_tokens_t += input_token
313
+ total_output_tokens_t += output_token
314
+ with open(os.path.join(args.output_dir, args.paper_name, html_dir, 'index.html'),'w') as file:
315
+ file.write(html_content)
316
+ run_sync_screenshots(to_url(os.path.join(args.output_dir, args.paper_name, html_dir, 'index.html')), os.path.join(args.output_dir, args.paper_name, html_dir,'page_final.png'))
317
+ print('-'*50)
318
+ human_feedback = input('Please view the final html in index.html,and image in page_final.png,If there are no problems, enter yes and press Enter. \n If there are any problems, please give me feedback directly.\n')
319
+
320
+ html_file_path = html_generator.save_html_file(html_content, args, html_dir,args.output_dir)
321
+
322
+ # Generate and save metadata
323
+ metadata = html_generator.generate_metadata(generated_content, args)
324
+ metadata_path = html_generator.save_metadata(metadata, args, args.output_dir)
325
+
326
+ # Step 4: Finalize and save logs
327
+ print("\n" + "="*50)
328
+ print("STEP 4: Finalizing Generation")
329
+ print("="*50)
330
+
331
+ end_time = time.time()
332
+ time_taken = end_time - start_time
333
+
334
+ # Save generation log
335
+ log_data = {
336
+ 'paper_name': paper_name,
337
+ 'paper_path': args.paper_path,
338
+ 'models': {
339
+ 'text_model': args.model_name_t,
340
+ 'vision_model': args.model_name_v
341
+ },
342
+ 'token_usage': {
343
+ 'text_input_tokens': total_input_tokens_t,
344
+ 'text_output_tokens': total_output_tokens_t,
345
+ 'vision_input_tokens': total_input_tokens_v,
346
+ 'vision_output_tokens': total_output_tokens_v
347
+ },
348
+ 'generation_time': time_taken,
349
+ 'output_files': {
350
+ 'html_file': html_file_path,
351
+ 'assets_dir': assets_dir,
352
+ 'static_dir': static_dir,
353
+ 'metadata_file': metadata_path
354
+ },
355
+ 'content_files': {
356
+ 'raw_content': raw_content_path,
357
+ 'token_log': token_log_path
358
+ }
359
+ }
360
+
361
+ log_path = f"{args.output_dir}/{args.paper_name}/generation_log.json"
362
+ with open(log_path, 'w') as f:
363
+ json.dump(log_data, f, indent=4)
364
+
365
+ print(f"\n✅ Paper2ProjectPage generation completed successfully!")
366
+ print(f"📁 Output directory: {args.output_dir}/{args.paper_name}")
367
+ print(f"🌐 HTML file: {html_file_path}")
368
+ print(f"📊 Assets directory: {assets_dir}")
369
+ print(f"🎨 Static directory: {static_dir}")
370
+ print(f"📋 Metadata file: {metadata_path}")
371
+ print(f"⏱️ Total time: {time_taken:.2f} seconds")
372
+ print(f"🔢 Token usage - Text: {total_input_tokens_t}→{total_output_tokens_t}, Vision: {total_input_tokens_v}→{total_output_tokens_v}")
373
+
374
+ except Exception as e:
375
+ print(f"\n❌ Error during generation: {str(e)}")
376
+ raise
377
+
378
# Script entry point: run the full Paper2ProjectPage generation pipeline.
if __name__ == '__main__':
    main()
ProjectPageAgent/parse_paper.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Paper parsing module for ProjectPageAgent.
3
+ Reuses the parsing capabilities from Paper2Poster.
4
+ """
5
+
6
+ from ProjectPageAgent.parse_raw import parse_raw, gen_image_and_table
7
+ from utils.wei_utils import get_agent_config
8
+ import json
9
+ import os
10
+ import argparse
11
+
12
def parse_paper_for_project_page(args, agent_config_t, version=2):
    """
    Parse a research paper PDF and extract content for project page generation.

    Args:
        args: Command line arguments.
        agent_config_t: Text model configuration.
        version: Parser version to use.

    Returns:
        tuple: (input_tokens, output_tokens, raw_result, images, tables)
    """
    print("Step 1: Parsing the research paper...")

    # parse_raw() expects poster_* attributes; alias them from the paper_*
    # ones when they are missing, for compatibility with Paper2Poster code.
    for alias, source in (('poster_path', 'paper_path'), ('poster_name', 'paper_name')):
        if not hasattr(args, alias):
            setattr(args, alias, getattr(args, source))

    # Parse the raw paper content, then pull out figures and tables.
    input_token, output_token, raw_result = parse_raw(args, agent_config_t, version=version)
    _, _, images, tables = gen_image_and_table(args, raw_result)

    print(f"Parsing completed. Tokens: {input_token} -> {output_token}")
    print(f"Extracted {len(images)} images and {len(tables)} tables")

    return input_token, output_token, raw_result, images, tables
43
+
44
def save_parsed_content(args, raw_result, images, tables, input_token, output_token):
    """
    Save parsed paper content and token-usage statistics to disk.

    Args:
        args: Command line arguments (only ``args.paper_name`` is used).
        raw_result: Parsed raw content — either a docling conversion result
            (exposes a ``document`` attribute) or an already-serialized dict.
        images: Extracted image metadata.
        tables: Extracted table metadata.
        input_token: Input token count from parsing.
        output_token: Output token count from parsing.

    Returns:
        tuple: (raw_content_path, token_log_path) of the written JSON files.
    """
    os.makedirs('project_contents', exist_ok=True)
    raw_content_path = f'project_contents/{args.paper_name}_raw_content.json'

    if hasattr(raw_result, 'document'):
        # Docling result: serialize the document to markdown text.
        raw_markdown = raw_result.document.export_to_markdown()
        content_json = {
            'markdown_content': raw_markdown,
            'images': images,
            'tables': tables
        }
    else:
        # Already-serialized content: keep it, but make sure the images and
        # tables extracted in this run are present — downstream code reloads
        # them from this file via .get('images') / .get('tables'), so dropping
        # them here would silently lose all figures on the resume path.
        content_json = dict(raw_result)
        content_json.setdefault('images', images)
        content_json.setdefault('tables', tables)

    with open(raw_content_path, 'w') as f:
        json.dump(content_json, f, indent=4)

    # Save token usage alongside simple extraction counts.
    token_log = {
        'parse_input_tokens': input_token,
        'parse_output_tokens': output_token,
        'total_images': len(images),
        'total_tables': len(tables)
    }

    token_log_path = f'project_contents/{args.paper_name}_parse_log.json'
    with open(token_log_path, 'w') as f:
        json.dump(token_log, f, indent=4)

    print(f"Parsed content saved to {raw_content_path}")
    return raw_content_path, token_log_path
ProjectPageAgent/parse_raw.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dotenv import load_dotenv
2
+ from utils.src.utils import get_json_from_response
3
+ from utils.src.model_utils import parse_pdf
4
+ import json
5
+ import random
6
+ import os
7
+
8
+ from camel.models import ModelFactory
9
+ from camel.agents import ChatAgent
10
+ from tenacity import retry, stop_after_attempt
11
+ from docling_core.types.doc import ImageRefMode, PictureItem, TableItem
12
+
13
+ from docling.datamodel.base_models import InputFormat
14
+ from docling.datamodel.pipeline_options import PdfPipelineOptions
15
+ from docling.document_converter import DocumentConverter, PdfFormatOption
16
+
17
+ from pathlib import Path
18
+
19
+ import PIL
20
+
21
+ from marker.models import create_model_dict
22
+
23
+ from utils.wei_utils import *
24
+
25
+ from utils.pptx_utils import *
26
+ from utils.critic_utils import *
27
+ import torch
28
+ from jinja2 import Template
29
+ import re
30
+ import argparse
31
+
32
# Load API keys and other settings from a local .env file.
load_dotenv()

# Scale factor applied when rasterizing PDF pages and figures; higher values
# yield larger, sharper extracted images.
IMAGE_RESOLUTION_SCALE = 5.0

# Configure docling to render page and picture images at the chosen scale.
pipeline_options = PdfPipelineOptions()
pipeline_options.images_scale = IMAGE_RESOLUTION_SCALE
pipeline_options.generate_page_images = True
pipeline_options.generate_picture_images = True

# Module-level converter reused by parse_raw() for every PDF.
doc_converter = DocumentConverter(
    format_options={
        InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options)
    }
)
45
+
46
@retry(stop=stop_after_attempt(5))
def parse_raw(args, actor_config, version=1):
    """Convert a paper PDF to markdown and have an LLM structure it into sections.

    Retries the whole function (up to 5 attempts via @retry) on any exception,
    including the deliberate bare ``raise`` statements below that reject
    malformed LLM output.

    Args:
        args: Namespace with ``poster_path``, ``poster_name`` and ``model_name_t``.
        actor_config: Agent configuration dict (platform, model type, config, ...).
        version: Prompt template version. NOTE(review): must be 1 or 2 —
            any other value leaves ``template`` unbound and fails with NameError.

    Returns:
        tuple: (input_token, output_token, raw_result) where ``raw_result``
        is the docling conversion result.
    """
    raw_source = args.poster_path
    # Strip HTML comments that docling can leave in the exported markdown.
    markdown_clean_pattern = re.compile(r"<!--[\s\S]*?-->")

    raw_result = doc_converter.convert(raw_source)

    raw_markdown = raw_result.document.export_to_markdown()
    text_content = markdown_clean_pattern.sub("", raw_markdown)

    # Heuristic: a very short extraction means docling failed; fall back to
    # the marker parser (requires a CUDA device).
    if len(text_content) < 500:
        print('\nParsing with docling failed, using marker instead\n')
        parser_model = create_model_dict(device='cuda', dtype=torch.float16)
        text_content, rendered = parse_pdf(raw_source, model_lst=parser_model, save_file=False)

    if version == 1:
        template = Template(open("utils/prompts/gen_page_raw_content.txt").read())
    elif version == 2:
        template = Template(open("utils/prompts/gen_page_raw_content_v2.txt").read())

    # Get API key from environment variables, matched to the chosen model family.
    api_key = None
    if args.model_name_t in ['4o', '4o-mini', 'gpt-4.1', 'gpt-4.1-mini', 'o1', 'o3', 'o3-mini']:
        api_key = os.environ.get('OPENAI_API_KEY')
    elif args.model_name_t in ['gemini', 'gemini-2.5-pro', 'gemini-2.5-flash']:
        api_key = os.environ.get('GEMINI_API_KEY')
    elif args.model_name_t in ['qwen', 'qwen-plus', 'qwen-max', 'qwen-long']:
        api_key = os.environ.get('QWEN_API_KEY')
    elif args.model_name_t.startswith('openrouter_'):
        api_key = os.environ.get('OPENROUTER_API_KEY')
    elif args.model_name_t in ['zhipuai']:
        api_key = os.environ.get('ZHIPUAI_API_KEY')

    # vLLM-served Qwen models additionally need the server URL.
    if args.model_name_t.startswith('vllm_qwen'):
        actor_model = ModelFactory.create(
            model_platform=actor_config['model_platform'],
            model_type=actor_config['model_type'],
            model_config_dict=actor_config['model_config'],
            url=actor_config['url'],
            api_key=api_key,
        )
    else:
        actor_model = ModelFactory.create(
            model_platform=actor_config['model_platform'],
            model_type=actor_config['model_type'],
            model_config_dict=actor_config['model_config'],
            api_key=api_key,
        )

    actor_sys_msg = 'You are the author of the paper, and you will create a poster for the paper.'

    actor_agent = ChatAgent(
        system_message=actor_sys_msg,
        model=actor_model,
        message_window_size=10,
        token_limit=actor_config.get('token_limit', None)
    )

    # Keep prompting until the model returns non-empty JSON.
    while True:
        prompt = template.render(
            markdown_document=text_content,
        )
        actor_agent.reset()
        response = actor_agent.step(prompt)
        input_token, output_token = account_token(response)

        content_json = get_json_from_response(response.msgs[0].content)

        if len(content_json) > 0:
            break
        print('Error: Empty response, retrying...')
        # For local vLLM Qwen, truncate the document to fit the context
        # window before the next attempt.
        if args.model_name_t.startswith('vllm_qwen'):
            text_content = text_content[:80000]

    if len(content_json['sections']) > 9:
        # First 2 sections + randomly select 5 sections + last 2 sections
        selected_sections = content_json['sections'][:2] + random.sample(content_json['sections'][2:-2], 5) + content_json['sections'][-2:]
        content_json['sections'] = selected_sections

    has_title = False

    # Validate the LLM output: every section must be a dict carrying 'title'
    # and 'content', and at least one section title must mention "title".
    for section in content_json['sections']:
        if type(section) != dict or not 'title' in section or not 'content' in section:
            print(f"Ouch! The response is invalid, the LLM is not following the format :(")
            print('Trying again...')
            # NOTE(review): bare raise outside an except block raises
            # RuntimeError; @retry catches it and restarts the whole function.
            raise
        if 'title' in section['title'].lower():
            has_title = True

    if not has_title:
        print('Ouch! The response is invalid, the LLM is not following the format :(')
        raise

    os.makedirs('contents', exist_ok=True)
    json.dump(content_json, open(f'contents/{args.poster_name}_raw_content.json', 'w'), indent=4)
    return input_token, output_token, raw_result
142
+
143
+
144
def gen_image_and_table(args, conv_res):
    """Export page/figure/table images from a docling conversion result.

    Writes per-page PNGs, per-figure/table PNGs, markdown (embedded and
    referenced images) and HTML renditions under
    ``generated_project_pages/images_and_tables/<poster_name>/``, then builds
    metadata dicts for every captioned figure and table.

    Args:
        args: Namespace with ``poster_path`` and ``poster_name``.
        conv_res: Docling conversion result (as returned by parse_raw).

    Returns:
        tuple: (input_token, output_token, images, tables) — token counts are
        always 0 here (no LLM calls); images/tables map 1-based string indices
        to caption/path/size metadata.
    """
    input_token, output_token = 0, 0
    raw_source = args.poster_path  # NOTE(review): currently unused

    output_dir = Path(f'generated_project_pages/images_and_tables/{args.poster_name}')

    output_dir.mkdir(parents=True, exist_ok=True)
    doc_filename = args.poster_name

    # Save page images
    for page_no, page in conv_res.document.pages.items():
        # Rebinds the dict key: filenames use the page's own page_no attribute.
        page_no = page.page_no
        page_image_filename = output_dir / f"{doc_filename}-{page_no}.png"
        with page_image_filename.open("wb") as fp:
            page.image.pil_image.save(fp, format="PNG")

    # Save images of figures and tables.
    # NOTE(review): the 1-based counters assume iterate_items() yields tables
    # and pictures in the same order as document.tables / document.pictures,
    # since the loops below reopen these files by index — confirm.
    table_counter = 0
    picture_counter = 0
    for element, _level in conv_res.document.iterate_items():
        if isinstance(element, TableItem):
            table_counter += 1
            element_image_filename = (
                output_dir / f"{doc_filename}-table-{table_counter}.png"
            )
            with element_image_filename.open("wb") as fp:
                element.get_image(conv_res.document).save(fp, "PNG")

        if isinstance(element, PictureItem):
            picture_counter += 1
            element_image_filename = (
                output_dir / f"{doc_filename}-picture-{picture_counter}.png"
            )
            with element_image_filename.open("wb") as fp:
                element.get_image(conv_res.document).save(fp, "PNG")

    # Save markdown with embedded pictures
    md_filename = output_dir / f"{doc_filename}-with-images.md"
    conv_res.document.save_as_markdown(md_filename, image_mode=ImageRefMode.EMBEDDED)

    # Save markdown with externally referenced pictures
    md_filename = output_dir / f"{doc_filename}-with-image-refs.md"
    conv_res.document.save_as_markdown(md_filename, image_mode=ImageRefMode.REFERENCED)

    # Save HTML with externally referenced pictures
    html_filename = output_dir / f"{doc_filename}-with-image-refs.html"
    conv_res.document.save_as_html(html_filename, image_mode=ImageRefMode.REFERENCED)

    tables = {}

    # Collect metadata for captioned tables only; uncaptioned ones still
    # advance the index so PNG filenames stay aligned.
    table_index = 1
    for table in conv_res.document.tables:
        caption = table.caption_text(conv_res.document)
        if len(caption) > 0:
            table_img_path = f'generated_project_pages/images_and_tables/{args.poster_name}/{args.poster_name}-table-{table_index}.png'
            # Path recorded for the page's assets/ directory, not where the
            # PNG was just written.
            assests_table_path = f'assets/{args.poster_name}-table-{table_index}.png'
            table_img = PIL.Image.open(table_img_path)
            tables[str(table_index)] = {
                'caption': caption,
                'table_path': assests_table_path,
                # 'assests_table_path': assests_table_path,
                'width': table_img.width,
                'height': table_img.height,
                'figure_size': table_img.width * table_img.height,
                'figure_aspect': table_img.width / table_img.height,
            }

        table_index += 1

    images = {}
    image_index = 1
    # Same pattern as tables: only captioned pictures get metadata entries.
    for image in conv_res.document.pictures:
        caption = image.caption_text(conv_res.document)
        if len(caption) > 0:
            image_img_path = f'generated_project_pages/images_and_tables/{args.poster_name}/{args.poster_name}-picture-{image_index}.png'
            assests_image_path = f'assets/{args.poster_name}-picture-{image_index}.png'
            image_img = PIL.Image.open(image_img_path)
            images[str(image_index)] = {
                'caption': caption,
                'image_path': assests_image_path,
                # 'assests_image_path': assests_image_path,
                'width': image_img.width,
                'height': image_img.height,
                'figure_size': image_img.width * image_img.height,
                'figure_aspect': image_img.width / image_img.height,
            }
        image_index += 1

    json.dump(images, open(f'generated_project_pages/images_and_tables/{args.poster_name}_images.json', 'w'), indent=4)
    json.dump(tables, open(f'generated_project_pages/images_and_tables/{args.poster_name}_tables.json', 'w'), indent=4)

    return input_token, output_token, images, tables
236
+
237
if __name__ == '__main__':
    # Standalone CLI: parse one PDF and export its figures/tables.
    parser = argparse.ArgumentParser()
    parser.add_argument('--poster_name', type=str, default=None)
    parser.add_argument('--model_name', type=str, default='4o')
    parser.add_argument('--poster_path', type=str, required=True)
    parser.add_argument('--index', type=int, default=0)
    args = parser.parse_args()

    agent_config = get_agent_config(args.model_name)

    # parse_raw() reads args.model_name_t; mirror the CLI flag onto it so the
    # standalone invocation matches the pipeline's argument shape.
    if not hasattr(args, 'model_name_t'):
        args.model_name_t = args.model_name

    if args.poster_name is None:
        args.poster_name = args.poster_path.split('/')[-1].replace('.pdf', '').replace(' ', '_')

    # Parse raw content. parse_raw returns three values (tokens + the docling
    # conversion result); the old two-name unpack raised ValueError.
    input_token, output_token, raw_result = parse_raw(args, agent_config)

    # Generate images and tables: gen_image_and_table requires the conversion
    # result and returns four values.
    _, _, images, tables = gen_image_and_table(args, raw_result)

    print(f'Token consumption: {input_token} -> {output_token}')
ProjectPageAgent/template_analyzer.py ADDED
@@ -0,0 +1,436 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Template analyzer for project page generation.
3
+ Analyzes existing project page templates to understand structure and style.
4
+ """
5
+
6
import json
import os
import re
from pathlib import Path

import yaml
from bs4 import BeautifulSoup, Doctype
from jinja2 import Environment, StrictUndefined
13
+
14
+ class ProjectPageTemplateAnalyzer:
15
+ """Analyzes project page templates to extract structure and styling patterns."""
16
+
17
+ def __init__(self, template_dir="project_templates"):
18
+ self.template_dir = Path(template_dir)
19
+ self.template_dir.mkdir(exist_ok=True)
20
+ self.templates = {}
21
+ self.common_patterns = {}
22
+
23
    def analyze_html_template(self, html_file_path):
        """
        Analyze an HTML template file to extract structure and styling.

        Args:
            html_file_path: Path to the HTML template file

        Returns:
            dict: Analysis results including structure, styling, sections,
            components and meta info — or ``None`` if reading or parsing the
            file fails for any reason (the error is printed, not raised).
        """
        try:
            with open(html_file_path, 'r', encoding='utf-8') as f:
                html_content = f.read()

            soup = BeautifulSoup(html_content, 'html.parser')

            # Each extractor inspects a different facet of the parsed page.
            analysis = {
                'file_path': html_file_path,
                'structure': self._extract_structure(soup),
                'styling': self._extract_styling(soup),
                'sections': self._extract_sections(soup),
                'components': self._extract_components(soup),
                'meta_info': self._extract_meta_info(soup)
            }

            return analysis

        except Exception as e:
            # Broad catch is deliberate: one bad template should not abort
            # a batch analysis run.
            print(f"Error analyzing template {html_file_path}: {e}")
            return None
53
+
54
+ def _extract_structure(self, soup):
55
+ """Extract the overall structure of the HTML document."""
56
+ structure = {
57
+ 'doctype': soup.find('!DOCTYPE') is not None,
58
+ 'html_lang': soup.html.get('lang', 'en') if soup.html else 'en',
59
+ 'head_sections': [],
60
+ 'body_sections': [],
61
+ 'main_content': None,
62
+ 'navigation': None,
63
+ 'footer': None
64
+ }
65
+
66
+ # Extract head sections
67
+ if soup.head:
68
+ for tag in soup.head.find_all(['meta', 'link', 'script', 'title']):
69
+ structure['head_sections'].append({
70
+ 'tag': tag.name,
71
+ 'attrs': dict(tag.attrs)
72
+ })
73
+
74
+ # Extract body structure
75
+ if soup.body:
76
+ for section in soup.body.find_all(['header', 'nav', 'main', 'section', 'article', 'aside', 'footer']):
77
+ structure['body_sections'].append({
78
+ 'tag': section.name,
79
+ 'id': section.get('id', ''),
80
+ 'class': section.get('class', []),
81
+ 'content_type': self._identify_content_type(section)
82
+ })
83
+
84
+ return structure
85
+
86
+ def _extract_styling(self, soup):
87
+ """Extract CSS styling information."""
88
+ styling = {
89
+ 'inline_styles': [],
90
+ 'external_css': [],
91
+ 'color_scheme': [],
92
+ 'typography': {},
93
+ 'layout': {}
94
+ }
95
+
96
+ # Extract inline styles
97
+ for tag in soup.find_all(style=True):
98
+ styling['inline_styles'].append({
99
+ 'tag': tag.name,
100
+ 'style': tag.get('style', '')
101
+ })
102
+
103
+ # Extract external CSS links
104
+ for link in soup.find_all('link', rel='stylesheet'):
105
+ styling['external_css'].append(link.get('href', ''))
106
+
107
+ # Extract color information
108
+ color_pattern = re.compile(r'#[0-9a-fA-F]{3,6}|rgb\([^)]+\)|rgba\([^)]+\)')
109
+ for tag in soup.find_all(style=True):
110
+ colors = color_pattern.findall(tag.get('style', ''))
111
+ styling['color_scheme'].extend(colors)
112
+
113
+ # Extract typography patterns
114
+ for tag in soup.find_all(['h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'p']):
115
+ font_size = re.search(r'font-size:\s*([^;]+)', tag.get('style', ''))
116
+ if font_size:
117
+ styling['typography'][tag.name] = font_size.group(1)
118
+
119
+ return styling
120
+
121
+ def _extract_sections(self, soup):
122
+ """Extract content sections and their organization."""
123
+ sections = []
124
+
125
+ for section in soup.find_all(['section', 'article', 'div'], class_=True):
126
+ section_info = {
127
+ 'tag': section.name,
128
+ 'id': section.get('id', ''),
129
+ 'classes': section.get('class', []),
130
+ 'content': self._extract_section_content(section),
131
+ 'images': self._extract_images(section),
132
+ 'tables': self._extract_tables(section)
133
+ }
134
+ sections.append(section_info)
135
+
136
+ return sections
137
+
138
+ def _extract_components(self, soup):
139
+ """Extract reusable components and their patterns."""
140
+ components = {
141
+ 'navigation': self._extract_navigation(soup),
142
+ 'hero_section': self._extract_hero_section(soup),
143
+ 'content_blocks': self._extract_content_blocks(soup),
144
+ 'image_galleries': self._extract_image_galleries(soup),
145
+ 'contact_forms': self._extract_contact_forms(soup)
146
+ }
147
+
148
+ return components
149
+
150
+ def _extract_meta_info(self, soup):
151
+ """Extract meta information and SEO elements."""
152
+ meta_info = {
153
+ 'title': soup.title.string if soup.title else '',
154
+ 'meta_tags': [],
155
+ 'open_graph': {},
156
+ 'twitter_cards': {}
157
+ }
158
+
159
+ for meta in soup.find_all('meta'):
160
+ meta_info['meta_tags'].append({
161
+ 'name': meta.get('name', ''),
162
+ 'content': meta.get('content', ''),
163
+ 'property': meta.get('property', '')
164
+ })
165
+
166
+ # Extract Open Graph tags
167
+ if meta.get('property', '').startswith('og:'):
168
+ meta_info['open_graph'][meta.get('property')] = meta.get('content', '')
169
+
170
+ # Extract Twitter Card tags
171
+ if meta.get('name', '').startswith('twitter:'):
172
+ meta_info['twitter_cards'][meta.get('name')] = meta.get('content', '')
173
+
174
+ return meta_info
175
+
176
+ def _identify_content_type(self, element):
177
+ """Identify the type of content in an element."""
178
+ text = element.get_text().lower()
179
+
180
+ if any(word in text for word in ['abstract', 'summary', 'overview']):
181
+ return 'abstract'
182
+ elif any(word in text for word in ['introduction', 'background']):
183
+ return 'introduction'
184
+ elif any(word in text for word in ['method', 'approach', 'methodology']):
185
+ return 'methodology'
186
+ elif any(word in text for word in ['result', 'experiment', 'evaluation']):
187
+ return 'results'
188
+ elif any(word in text for word in ['conclusion', 'discussion', 'future']):
189
+ return 'conclusion'
190
+ elif any(word in text for word in ['contact', 'author', 'team']):
191
+ return 'contact'
192
+ else:
193
+ return 'general'
194
+
195
+ def _extract_section_content(self, element):
196
+ """Extract text content from a section."""
197
+ content = {
198
+ 'headings': [],
199
+ 'paragraphs': [],
200
+ 'lists': [],
201
+ 'code_blocks': []
202
+ }
203
+
204
+ for heading in element.find_all(['h1', 'h2', 'h3', 'h4', 'h5', 'h6']):
205
+ content['headings'].append({
206
+ 'level': int(heading.name[1]),
207
+ 'text': heading.get_text().strip()
208
+ })
209
+
210
+ for p in element.find_all('p'):
211
+ content['paragraphs'].append(p.get_text().strip())
212
+
213
+ for ul in element.find_all(['ul', 'ol']):
214
+ items = [li.get_text().strip() for li in ul.find_all('li')]
215
+ content['lists'].append({
216
+ 'type': ul.name,
217
+ 'items': items
218
+ })
219
+
220
+ for code in element.find_all(['code', 'pre']):
221
+ content['code_blocks'].append({
222
+ 'type': code.name,
223
+ 'content': code.get_text().strip()
224
+ })
225
+
226
+ return content
227
+
228
+ def _extract_images(self, element):
229
+ """Extract image information from an element."""
230
+ images = []
231
+ for img in element.find_all('img'):
232
+ images.append({
233
+ 'src': img.get('src', ''),
234
+ 'alt': img.get('alt', ''),
235
+ 'title': img.get('title', ''),
236
+ 'class': img.get('class', [])
237
+ })
238
+ return images
239
+
240
+ def _extract_tables(self, element):
241
+ """Extract table information from an element."""
242
+ tables = []
243
+ for table in element.find_all('table'):
244
+ table_info = {
245
+ 'class': table.get('class', []),
246
+ 'headers': [],
247
+ 'rows': []
248
+ }
249
+
250
+ # Extract headers
251
+ for th in table.find_all('th'):
252
+ table_info['headers'].append(th.get_text().strip())
253
+
254
+ # Extract rows
255
+ for tr in table.find_all('tr'):
256
+ row = [td.get_text().strip() for td in tr.find_all('td')]
257
+ if row:
258
+ table_info['rows'].append(row)
259
+
260
+ tables.append(table_info)
261
+
262
+ return tables
263
+
264
+ def _extract_navigation(self, soup):
265
+ """Extract navigation structure."""
266
+ nav = soup.find('nav')
267
+ if nav:
268
+ return {
269
+ 'links': [a.get('href', '') for a in nav.find_all('a')],
270
+ 'texts': [a.get_text().strip() for a in nav.find_all('a')],
271
+ 'structure': self._extract_nav_structure(nav)
272
+ }
273
+ return None
274
+
275
+ def _extract_nav_structure(self, nav_element):
276
+ """Extract the hierarchical structure of navigation."""
277
+ structure = []
278
+ for item in nav_element.find_all(['a', 'li'], recursive=False):
279
+ if item.name == 'a':
280
+ structure.append({
281
+ 'type': 'link',
282
+ 'text': item.get_text().strip(),
283
+ 'href': item.get('href', '')
284
+ })
285
+ elif item.name == 'li':
286
+ sub_items = []
287
+ for sub_item in item.find_all('a'):
288
+ sub_items.append({
289
+ 'text': sub_item.get_text().strip(),
290
+ 'href': sub_item.get('href', '')
291
+ })
292
+ structure.append({
293
+ 'type': 'group',
294
+ 'items': sub_items
295
+ })
296
+ return structure
297
+
298
+ def _extract_hero_section(self, soup):
299
+ """Extract hero section information."""
300
+ hero = soup.find(['header', 'section'], class_=re.compile(r'hero|banner|intro'))
301
+ if hero:
302
+ return {
303
+ 'title': hero.find(['h1', 'h2']).get_text().strip() if hero.find(['h1', 'h2']) else '',
304
+ 'subtitle': hero.find(['h2', 'h3', 'p']).get_text().strip() if hero.find(['h2', 'h3', 'p']) else '',
305
+ 'background_image': hero.find('img').get('src', '') if hero.find('img') else '',
306
+ 'cta_buttons': [a.get_text().strip() for a in hero.find_all('a', class_=re.compile(r'btn|button'))]
307
+ }
308
+ return None
309
+
310
+ def _extract_content_blocks(self, soup):
311
+ """Extract content block patterns."""
312
+ blocks = []
313
+ for block in soup.find_all(['div', 'section'], class_=re.compile(r'content|block|section')):
314
+ blocks.append({
315
+ 'classes': block.get('class', []),
316
+ 'content_type': self._identify_content_type(block),
317
+ 'has_images': bool(block.find('img')),
318
+ 'has_tables': bool(block.find('table')),
319
+ 'has_code': bool(block.find(['code', 'pre']))
320
+ })
321
+ return blocks
322
+
323
+ def _extract_image_galleries(self, soup):
324
+ """Extract image gallery patterns."""
325
+ galleries = []
326
+ for gallery in soup.find_all(['div', 'section'], class_=re.compile(r'gallery|carousel|slider')):
327
+ images = gallery.find_all('img')
328
+ galleries.append({
329
+ 'image_count': len(images),
330
+ 'layout': 'grid' if 'grid' in str(gallery.get('class', [])) else 'carousel',
331
+ 'images': [img.get('src', '') for img in images]
332
+ })
333
+ return galleries
334
+
335
+ def _extract_contact_forms(self, soup):
336
+ """Extract contact form patterns."""
337
+ forms = []
338
+ for form in soup.find_all('form'):
339
+ form_info = {
340
+ 'action': form.get('action', ''),
341
+ 'method': form.get('method', 'get'),
342
+ 'fields': []
343
+ }
344
+
345
+ for input_field in form.find_all(['input', 'textarea', 'select']):
346
+ form_info['fields'].append({
347
+ 'type': input_field.get('type', input_field.name),
348
+ 'name': input_field.get('name', ''),
349
+ 'placeholder': input_field.get('placeholder', ''),
350
+ 'required': input_field.get('required') is not None
351
+ })
352
+
353
+ forms.append(form_info)
354
+
355
+ return forms
356
+
357
+ def analyze_multiple_templates(self, template_files):
358
+ """
359
+ Analyze multiple template files and find common patterns.
360
+
361
+ Args:
362
+ template_files: List of template file paths
363
+
364
+ Returns:
365
+ dict: Analysis results with common patterns
366
+ """
367
+ all_analyses = []
368
+
369
+ for template_file in template_files:
370
+ analysis = self.analyze_html_template(template_file)
371
+ if analysis:
372
+ all_analyses.append(analysis)
373
+
374
+ # Find common patterns
375
+ common_patterns = self._find_common_patterns(all_analyses)
376
+
377
+ return {
378
+ 'individual_analyses': all_analyses,
379
+ 'common_patterns': common_patterns
380
+ }
381
+
382
+ def _find_common_patterns(self, analyses):
383
+ """Find common patterns across multiple template analyses."""
384
+ patterns = {
385
+ 'common_sections': [],
386
+ 'common_styles': [],
387
+ 'common_components': [],
388
+ 'color_schemes': [],
389
+ 'layout_patterns': []
390
+ }
391
+
392
+ # Analyze common sections
393
+ all_sections = []
394
+ for analysis in analyses:
395
+ all_sections.extend(analysis['sections'])
396
+
397
+ section_types = {}
398
+ for section in all_sections:
399
+ content_type = section.get('content_type', 'unknown')
400
+ if content_type not in section_types:
401
+ section_types[content_type] = 0
402
+ section_types[content_type] += 1
403
+
404
+ patterns['common_sections'] = [
405
+ section_type for section_type, count in section_types.items()
406
+ if count > len(analyses) * 0.5 # Appears in more than 50% of templates
407
+ ]
408
+
409
+ # Analyze common styles
410
+ all_colors = []
411
+ for analysis in analyses:
412
+ all_colors.extend(analysis['styling']['color_scheme'])
413
+
414
+ color_counts = {}
415
+ for color in all_colors:
416
+ if color not in color_counts:
417
+ color_counts[color] = 0
418
+ color_counts[color] += 1
419
+
420
+ patterns['color_schemes'] = [
421
+ color for color, count in color_counts.items()
422
+ if count > len(analyses) * 0.3 # Appears in more than 30% of templates
423
+ ]
424
+
425
+ return patterns
426
+
427
+ def save_analysis(self, analysis, output_path):
428
+ """Save analysis results to a JSON file."""
429
+ try:
430
+ with open(output_path, 'w') as f:
431
+ json.dump(analysis, f, indent=2)
432
+ print(f"Analysis saved to {output_path}")
433
+ return True
434
+ except Exception as e:
435
+ print(f"Error saving analysis: {e}")
436
+ return False
app.py ADDED
@@ -0,0 +1,1671 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
# Standard library
import base64
import json
import os
import re
import socket
import subprocess
from http.server import HTTPServer, SimpleHTTPRequestHandler
from pathlib import Path
from threading import Thread

# Third-party
import gradio as gr
from dotenv import load_dotenv

# Project-local
from ProjectPageAgent.parse_paper import parse_paper_for_project_page, save_parsed_content
from ProjectPageAgent.html_finder import HtmlFinder
from ProjectPageAgent.content_planner import ProjectPageContentPlanner, filter_references
from ProjectPageAgent.html_generator import ProjectPageHTMLGenerator, to_url
from ProjectPageAgent.main_pipline import matching, copy_static_files
from utils.wei_utils import get_agent_config
from utils.src.utils import run_sync_screenshots

load_dotenv()

# Ensure the headless browser used for page screenshots is installed before
# the app starts serving (fails fast with check=True if installation breaks).
subprocess.run(["playwright", "install", "chromium"], check=True)
27
+
28
+
29
def get_agent_config_with_keys(model_type, openai_api_key="", gemini_api_key="",
                               qwen_api_key="", zhipuai_api_key="", openrouter_api_key=""):
    """
    Build an agent configuration using user-supplied API keys.

    Non-empty keys are exported as environment variables before delegating
    to get_agent_config(); empty ones leave any pre-existing environment
    values untouched. NOTE: the environment is intentionally NOT restored
    afterwards — the keys stay set for the lifetime of the process.
    """
    provided = {
        'OPENAI_API_KEY': openai_api_key,
        'GEMINI_API_KEY': gemini_api_key,
        'QWEN_API_KEY': qwen_api_key,
        'ZHIPUAI_API_KEY': zhipuai_api_key,
        'OPENROUTER_API_KEY': openrouter_api_key,
    }
    for env_name, env_value in provided.items():
        if env_value and env_value.strip():
            os.environ[env_name] = env_value

    return get_agent_config(model_type)
54
+
55
def validate_api_keys(model_name_t, model_name_v, openai_api_key, gemini_api_key,
                      qwen_api_key, zhipuai_api_key, openrouter_api_key):
    """
    Check that the API keys required by the chosen text/vision models exist.

    Returns:
        list[str]: human-readable error messages; empty means all good.
    """
    def missing(key):
        # Treat None, "" and whitespace-only strings as absent.
        return not key or not key.strip()

    errors = []

    # --- text model requirements ---
    if model_name_t in ('4o', '4o-mini', 'gpt-4.1', 'gpt-4.1-mini', 'o1', 'o3', 'o3-mini'):
        if missing(openai_api_key):
            errors.append("OpenAI API key is required for GPT models")
    elif model_name_t in ('gemini', 'gemini-2.5-pro', 'gemini-2.5-flash'):
        if missing(gemini_api_key):
            errors.append("Gemini API key is required for Gemini models")
    elif model_name_t in ('qwen', 'qwen-plus', 'qwen-max', 'qwen-long'):
        if missing(qwen_api_key):
            errors.append("Qwen API key is required for Qwen models")
    elif model_name_t.startswith('openrouter_'):
        if missing(openrouter_api_key):
            errors.append("OpenRouter API key is required for OpenRouter models")

    # --- vision model requirements ---
    if model_name_v in ('4o', '4o-mini'):
        if missing(openai_api_key):
            errors.append("OpenAI API key is required for GPT vision models")
    elif model_name_v in ('gemini', 'gemini-2.5-pro', 'gemini-2.5-flash'):
        if missing(gemini_api_key):
            errors.append("Gemini API key is required for Gemini vision models")
    elif model_name_v in ('qwen-vl-max', 'qwen-2.5-vl-72b'):
        if missing(qwen_api_key):
            errors.append("Qwen API key is required for Qwen vision models")
    elif model_name_v.startswith('openrouter_'):
        if missing(openrouter_api_key):
            errors.append("OpenRouter API key is required for OpenRouter vision models")

    return errors
91
+
92
# Global Variables (preview-server state)
# Directory currently served by the main preview HTTP server.
current_html_dir = None
# Running HTTPServer instance for the main preview (None when stopped).
preview_server = None
# TCP port the main preview server listens on (None when stopped).
preview_port = None
# (server, port) pairs for ephemeral template-preview servers.
template_preview_servers = []
97
+
98
class CustomHTTPRequestHandler(SimpleHTTPRequestHandler):
    """Static-file handler rooted at the module-global ``current_html_dir``."""

    def __init__(self, *args, **kwargs):
        # `current_html_dir` is read when the handler is instantiated for a
        # request, so updating the global redirects subsequent requests
        # without recreating the server.
        super().__init__(*args, directory=current_html_dir, **kwargs)

    def log_message(self, format, *args):
        # Silence per-request logging to keep the console output clean.
        pass
104
+
105
def find_free_port(start_port=8000, max_attempts=100):
    """Return the first bindable TCP port in [start_port, start_port + max_attempts).

    Raises:
        RuntimeError: when every candidate port is already in use.
    """
    candidate = start_port
    while candidate < start_port + max_attempts:
        probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            probe.bind(('', candidate))
            return candidate
        except OSError:
            candidate += 1
        finally:
            probe.close()
    raise RuntimeError(f"Could not find available port")
114
+
115
def start_preview_server(html_dir):
    """(Re)start the main preview server rooted at *html_dir*; return its port."""
    global current_html_dir, preview_server, preview_port
    # Tear down any previous instance so only one preview server runs.
    stop_preview_server()
    current_html_dir = html_dir
    preview_port = find_free_port()
    preview_server = HTTPServer(('0.0.0.0', preview_port), CustomHTTPRequestHandler)
    Thread(target=preview_server.serve_forever, daemon=True).start()
    return preview_port
124
+
125
def stop_preview_server():
    """Shut down the main preview server if one is running; reset its globals."""
    global preview_server, preview_port
    if preview_server is None:
        return
    preview_server.shutdown()
    preview_server = None
    preview_port = None
131
+
132
def start_ephemeral_server_for_dir(html_dir):
    """Serve *html_dir* on a fresh port via a throwaway daemon server.

    The (server, port) pair is recorded in ``template_preview_servers`` so
    it can be shut down later; returns the chosen port.
    """
    port = find_free_port()

    class _TempHandler(SimpleHTTPRequestHandler):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, directory=html_dir, **kwargs)

        def log_message(self, format, *args):
            pass  # keep the console quiet

    server = HTTPServer(('0.0.0.0', port), _TempHandler)
    Thread(target=server.serve_forever, daemon=True).start()
    template_preview_servers.append((server, port))
    return port
144
+
145
def stop_all_template_preview_servers():
    """Shut down every recorded template-preview server, ignoring failures."""
    global template_preview_servers
    for server, _port in list(template_preview_servers):
        try:
            server.shutdown()
        except Exception:
            pass  # best effort: an already-dead server is fine
    template_preview_servers = []
153
+
154
class GenerationArgs:
    """Bag of CLI-style options driving one project-page generation run.

    Mirrors the argparse namespace used by the non-Gradio pipeline so the
    downstream ProjectPageAgent components can consume it unchanged.
    """

    def __init__(self, paper_path, model_name_t, model_name_v, template_root,
                 template_dir, template_file, output_dir, style_preference, tmp_dir,
                 full_content_check_times, background_color, has_navigation,
                 has_hero_section, title_color, page_density, image_layout,
                 html_check_times, resume, human_input):
        # Input paper and model selection
        self.paper_path = paper_path          # path to the uploaded PDF
        self.model_name_t = model_name_t      # text-model identifier
        self.model_name_v = model_name_v      # vision-model identifier
        # Template locations
        self.template_root = template_root
        self.template_dir = template_dir
        self.template_file = template_file
        # Output / working directories and style options
        self.output_dir = output_dir
        self.style_preference = style_preference
        self.tmp_dir = tmp_dir
        # Iterative quality-check count for full-content generation
        self.full_content_check_times = full_content_check_times
        # Visual preferences used for template matching
        self.background_color = background_color
        self.has_navigation = has_navigation
        self.has_hero_section = has_hero_section
        self.title_color = title_color
        self.page_density = page_density
        self.image_layout = image_layout
        self.html_check_times = html_check_times
        # Run-control flags
        self.resume = resume
        self.human_input = human_input
        # Derived later from paper_path (basename without the .pdf suffix)
        self.paper_name = None
180
+
181
+ # ==================== Formatting Functions ====================
182
+
183
def format_section_to_markdown(section_data):
    """
    Render a section-plan JSON dict as readable Markdown.

    Args:
        section_data: section JSON (title/authors/affiliation plus
            arbitrary section entries).

    Returns:
        str: formatted Markdown, or a placeholder when there is no data.
    """
    if not section_data:
        return "No data available"

    meta_keys = ("title", "authors", "affiliation")
    lines = ["# 📄 Paper Page Structure Preview\n"]

    # Paper metadata, when present.
    if "title" in section_data:
        lines.append(f"## 📌 Title\n**{section_data['title']}**\n")
    if "authors" in section_data:
        lines.append(f"## 👥 Authors\n{section_data['authors']}\n")
    if "affiliation" in section_data:
        lines.append(f"## 🏛️ Affiliation\n{section_data['affiliation']}\n")

    lines.append("## 📑 Page Sections\n")

    count = 0
    for key, value in section_data.items():
        if key in meta_keys:
            continue

        count += 1
        lines.append(f"### {count}. {key.replace('_', ' ').title()}\n")

        if isinstance(value, dict):
            # Nested mapping: one bolded "Key: value" line per entry.
            lines.extend(
                f"**{sub_key.replace('_', ' ').title()}**: {sub_value}\n"
                for sub_key, sub_value in value.items()
            )
        elif isinstance(value, list):
            # Lists become bullet points; dict items expand to bold pairs.
            for item in value:
                if isinstance(item, str):
                    lines.append(f"- {item}\n")
                elif isinstance(item, dict):
                    lines.extend(f"- **{k}**: {v}\n" for k, v in item.items())
        else:
            # Scalar value: emit verbatim.
            lines.append(f"{value}\n")

        lines.append("")  # blank separator between sections

    lines.append("---\n")
    lines.append(f"**📊 Total {count} sections**\n")

    return "\n".join(lines)
250
+
251
+
252
def format_full_content_to_markdown(content_data, figures=None):
    """
    Convert Full Content JSON to beautifully formatted Markdown.

    Walks every top-level entry of *content_data* (skipping the
    title/authors/affiliation metadata), rendering dict entries as labeled
    sub-blocks, lists as bullet items, and recognizing image/table/code
    sub-keys specially. Ends with a statistics footer.

    Args:
        content_data: Full Content JSON data (dict).
        figures: Images and tables data (optional); only its 'images' and
            'tables' list lengths are used, for the statistics footer.

    Returns:
        str: Formatted Markdown string, or a placeholder when empty.
    """
    if not content_data:
        return "No data available"

    md_lines = []

    # Title
    md_lines.append("# 📄 Full Content Preview\n")

    # Basic Information
    if "title" in content_data:
        md_lines.append(f"# {content_data['title']}\n")

    if "authors" in content_data:
        md_lines.append(f"**Authors**: {content_data['authors']}\n")

    if "affiliation" in content_data:
        md_lines.append(f"**Affiliation**: {content_data['affiliation']}\n")

    md_lines.append("---\n")

    # Process Each Section; counters feed the statistics footer.
    section_count = 0
    image_count = 0
    table_count = 0

    for key, value in content_data.items():
        if key in ["title", "authors", "affiliation"]:
            continue

        section_count += 1

        # Section Title (snake_case key -> Title Case heading)
        section_title = key.replace("_", " ").title()
        md_lines.append(f"## {section_count}. {section_title}\n")

        # Process Content
        if isinstance(value, dict):
            # Process dictionary type content; sub-keys are dispatched by name.
            for sub_key, sub_value in value.items():
                if sub_key.lower() in ['content', 'description', 'text']:
                    # Main text content
                    md_lines.append(f"{sub_value}\n")
                elif sub_key.lower() in ['image', 'figure', 'img']:
                    # Image: dict form may carry caption/path, else inline text.
                    image_count += 1
                    if isinstance(sub_value, dict):
                        caption = sub_value.get('caption', f'Figure {image_count}')
                        path = sub_value.get('path', '')
                        md_lines.append(f"\n**🖼️ {caption}**\n")
                        if path:
                            md_lines.append(f"*Image path: `{path}`*\n")
                    else:
                        md_lines.append(f"\n**🖼️ Figure {image_count}**: {sub_value}\n")
                elif sub_key.lower() in ['table']:
                    # Table
                    table_count += 1
                    md_lines.append(f"\n**📊 Table {table_count}**\n")
                    if isinstance(sub_value, dict):
                        caption = sub_value.get('caption', f'Table {table_count}')
                        md_lines.append(f"*{caption}*\n")
                    else:
                        md_lines.append(f"{sub_value}\n")
                elif sub_key.lower() in ['code']:
                    # Code block
                    md_lines.append(f"\n```\n{sub_value}\n```\n")
                else:
                    # Other subtitles: render as sub-heading + body.
                    sub_title = sub_key.replace("_", " ").title()
                    md_lines.append(f"\n### {sub_title}\n")
                    md_lines.append(f"{sub_value}\n")

        elif isinstance(value, list):
            # Process list type content
            for idx, item in enumerate(value):
                if isinstance(item, dict):
                    # Dictionary items in list: optional title/name heading.
                    if 'title' in item or 'name' in item:
                        item_title = item.get('title', item.get('name', f'Item {idx+1}'))
                        md_lines.append(f"\n### {item_title}\n")

                    for k, v in item.items():
                        if k not in ['title', 'name']:
                            if k.lower() in ['content', 'description', 'text']:
                                md_lines.append(f"{v}\n")
                            elif k.lower() in ['image', 'figure']:
                                image_count += 1
                                md_lines.append(f"\n**🖼️ Figure {image_count}**: {v}\n")
                            elif k.lower() == 'table':
                                table_count += 1
                                md_lines.append(f"\n**📊 Table {table_count}**: {v}\n")
                            else:
                                k_title = k.replace("_", " ").title()
                                md_lines.append(f"**{k_title}**: {v}\n")
                else:
                    # Simple list item
                    md_lines.append(f"- {item}\n")

        else:
            # Simple text value
            md_lines.append(f"{value}\n")

        md_lines.append("")  # Empty line between sections

    # Add Statistics footer
    md_lines.append("\n---\n")
    stats = []
    stats.append(f"📊 **Statistics**")
    stats.append(f"- Sections: {section_count}")
    if image_count > 0:
        stats.append(f"- Images: {image_count}")
    if table_count > 0:
        stats.append(f"- Tables: {table_count}")

    # If figures data is provided, add availability counts as well.
    if figures:
        if 'images' in figures and figures['images']:
            stats.append(f"- Available images: {len(figures['images'])}")
        if 'tables' in figures and figures['tables']:
            stats.append(f"- Available tables: {len(figures['tables'])}")

    md_lines.append("\n".join(stats))
    md_lines.append("\n")

    return "\n".join(md_lines)
387
+
388
+ # ==================== Global State Management ====================
389
+
390
class GenerationState:
    """Mutable holder for everything produced during one generation run."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Return every field to its pristine pre-generation value."""
        # Inputs / configuration
        self.args = None
        self.agent_config_t = None
        # Parsed paper artifacts
        self.paper_content = None
        self.figures = None
        # Intermediate generation products
        self.generated_section = None
        self.text_page_content = None
        self.generated_content = None
        # HTML outputs
        self.html_content = None
        self.html_file_path = None
        self.html_dir = None
        # Collaborating agents
        self.planner = None
        self.html_generator = None
        # Token accounting and progress tracking
        self.total_input_tokens_t = 0
        self.total_output_tokens_t = 0
        self.current_stage = "init"
        self.preview_url = None


# Single module-wide state instance shared by all Gradio callbacks.
state = GenerationState()
413
+
414
def create_project_zip(project_dir, output_dir, paper_name):
    """
    Zip a generated project directory for download.

    Args:
        project_dir: directory whose files are archived.
        output_dir: where the zip is written; archive entry names are
            relative to it, so the project folder name is preserved.
        paper_name: basename for "<paper_name>_project_page.zip".

    Returns:
        str: path to the created archive, or None on failure.
    """
    import zipfile

    zip_path = os.path.join(output_dir, f"{paper_name}_project_page.zip")

    print(f"Creating project archive: {zip_path}")

    try:
        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as archive:
            for root, dirs, files in os.walk(project_dir):
                for filename in files:
                    absolute = os.path.join(root, filename)
                    # Entries are stored relative to output_dir.
                    archive.write(absolute, os.path.relpath(absolute, output_dir))

        print(f"Archive created successfully: {zip_path}")
        print(f"Archive size: {os.path.getsize(zip_path) / (1024 * 1024):.2f} MB")
        return zip_path

    except Exception as e:
        print(f"Archive creation failed: {e}")
        return None
455
+
456
def start_generation(pdf_file, model_name_t, model_name_v, template_root,
                     template_dir, template_file, output_dir, style_preference,
                     tmp_dir, full_content_check_times, background_color,
                     has_navigation, has_hero_section, title_color, page_density,
                     image_layout, html_check_times, resume, human_input,
                     template_choice_value, openai_api_key, gemini_api_key,
                     qwen_api_key, zhipuai_api_key, openrouter_api_key):
    """Start the generation process (Gradio callback).

    Runs steps 1–3 of the pipeline: parse the uploaded PDF, filter the
    content, and generate the initial section plan, then pauses for human
    feedback. Always returns a 7-tuple matching the Gradio outputs:
    (status message, feedback-section visibility, section Markdown,
    section JSON, template-choice update, template-preview update,
    extra text slot).
    """
    if pdf_file is None:
        return "❌ Please upload a PDF file", gr.update(visible=False), "", "", gr.update(), gr.update(), ""

    # Validate API keys before doing any expensive work.
    validation_errors = validate_api_keys(
        model_name_t, model_name_v, openai_api_key, gemini_api_key,
        qwen_api_key, zhipuai_api_key, openrouter_api_key
    )

    if validation_errors:
        error_msg = "❌ API Key Validation Failed:\n" + "\n".join(f"• {error}" for error in validation_errors)
        return error_msg, gr.update(visible=False), "", "", gr.update(), gr.update(), ""

    state.reset()

    # Handle template selection: when no explicit template_dir was given,
    # either recommend templates (no choice yet) or resolve the chosen one.
    if not (template_dir and str(template_dir).strip()):
        if not template_choice_value:
            stop_all_template_preview_servers()
            template_requirement = {
                "background_color": background_color,
                "has_hero_section": has_hero_section,
                "Page density": page_density,
                "image_layout": image_layout,
                "has_navigation": has_navigation,
                "title_color": title_color
            }
            try:
                matched = matching(template_requirement)
            except Exception as e:
                return f"❌ Template recommendation failed: {e}", gr.update(visible=False), "", "", gr.update(choices=[], value=None), gr.update(visible=False, value=""), ""

            html_finder_ = HtmlFinder()
            with open('templates/template_link.json','r') as f:
                template_link = json.load(f)
            previews = []
            for name in matched:
                t_dir = os.path.join(template_root, name)
                try:
                    html_path = html_finder_.find_html(t_dir)
                    if not os.path.exists(html_path):
                        continue
                    html_dir = os.path.dirname(os.path.abspath(html_path))
                    # NOTE(review): `filename` and `port` are computed but never
                    # used below — the preview link comes from template_link.
                    filename = os.path.basename(html_path)
                    port = start_ephemeral_server_for_dir(html_dir)
                    url = template_link[name]
                    previews.append((name, html_path, url))
                except Exception:
                    # Best-effort: skip templates that fail to resolve/serve.
                    continue

            if not previews:
                return "❌ No previewable templates found", gr.update(visible=False), "", "", gr.update(choices=[], value=None), gr.update(visible=False, value=""), ""

            md_lines = ["### 🔍 Please select a template to preview before clicking **Start Generation**", ""]
            for name, _, url in previews:
                md_lines.append(f"- **{name}** → [{url}]({url})")
            md = "\n".join(md_lines)

            # Early exit: user must pick a template and click again.
            return "Recommended 3 templates, please select one to continue", gr.update(visible=False), "", "", gr.update(choices=[n for n, _, _ in previews], value=None), gr.update(visible=True, value=md), ""

        template_dir = os.path.join(template_root, template_choice_value)

    # Create arguments object consumed by the downstream pipeline.
    args = GenerationArgs(
        paper_path=pdf_file.name,
        model_name_t=model_name_t,
        model_name_v=model_name_v,
        template_root=template_root,
        template_dir=template_dir,
        template_file=template_file,
        output_dir=output_dir,
        style_preference=style_preference,
        tmp_dir=tmp_dir,
        full_content_check_times=full_content_check_times,
        background_color=background_color,
        has_navigation=has_navigation,
        has_hero_section=has_hero_section,
        title_color=title_color,
        page_density=page_density,
        image_layout=image_layout,
        html_check_times=html_check_times,
        resume=resume,
        human_input=human_input
    )

    if not args.template_dir:
        return "❌ Please select a template", gr.update(visible=False), "", "", gr.update(), gr.update(), ""

    if not args.template_file:
        html_finder_ = HtmlFinder()
        args.template_file = html_finder_.find_html(args.template_dir)

    # Derive the paper name from the path basename without the .pdf suffix.
    # NOTE(review): assumes '/' separators — Windows paths with '\\' would
    # keep their directory prefix; confirm upload paths are POSIX-style.
    paper_name = args.paper_path.split('/')[-1].replace('.pdf', '') if '/' in args.paper_path else args.paper_path.replace('.pdf', '')
    args.paper_name = paper_name

    os.makedirs(args.tmp_dir, exist_ok=True)

    try:
        # Initialization: build the text-model agent config (also exports keys).
        agent_config_t = get_agent_config_with_keys(
            args.model_name_t, openai_api_key, gemini_api_key,
            qwen_api_key, zhipuai_api_key, openrouter_api_key
        )
        state.agent_config_t = agent_config_t
        state.args = args

        # Step 1: Parse PDF (skipped when cached raw content already exists)
        print("="*50)
        print("STEP 1: Parsing Research Paper")
        print("="*50)

        raw_content_path = f'project_contents/{args.paper_name}_raw_content.json'
        if not os.path.exists(raw_content_path):
            # NOTE(review): agent_config_v is built but parse_paper_for_project_page
            # is called with agent_config_t — confirm whether the vision config
            # should be passed instead.
            agent_config_v = get_agent_config_with_keys(
                args.model_name_v, openai_api_key, gemini_api_key,
                qwen_api_key, zhipuai_api_key, openrouter_api_key
            )
            input_token, output_token, raw_result, images, tables = parse_paper_for_project_page(args, agent_config_t)
            state.total_input_tokens_t += input_token
            state.total_output_tokens_t += output_token
            raw_content_path, _ = save_parsed_content(args, raw_result, images, tables, input_token, output_token)

        with open(raw_content_path, 'r') as f:
            paper_content = json.load(f)

        # Split the cached JSON into figures and the markdown body.
        images = paper_content.get('images', [])
        tables = paper_content.get('tables', [])
        figures = {'images': images, 'tables': tables}
        paper_content = paper_content.get('markdown_content', "")

        state.paper_content = paper_content
        state.figures = figures

        # Step 2: Filter content down to what belongs on a project page
        print("="*50)
        print("STEP 2: Filtering Content")
        print("="*50)

        planner = ProjectPageContentPlanner(agent_config_t, args)
        state.planner = planner

        paper_content, figures, input_token, output_token = planner.filter_raw_content(paper_content, figures)
        state.total_input_tokens_t += input_token
        state.total_output_tokens_t += output_token
        state.paper_content = paper_content
        state.figures = figures

        # Step 3: Generate Section plan, then pause for human review
        print("="*50)
        print("STEP 3: Generating Sections")
        print("="*50)

        state.current_stage = "section"

        generated_section, input_token, output_token = generate_section_initial()
        state.total_input_tokens_t += input_token
        state.total_output_tokens_t += output_token

        # Use Markdown formatting for display; JSON kept for the hidden view.
        section_display_md = format_section_to_markdown(generated_section)
        section_display_json = json.dumps(generated_section, indent=2, ensure_ascii=False)

        return (
            f"✅ Section generation completed, please review and provide feedback\n\nTokens: {input_token} → {output_token}",
            gr.update(visible=True),  # feedback_section
            section_display_md,  # Markdown format
            section_display_json,  # JSON format (hidden)
            gr.update(),
            gr.update(visible=False, value=""),
            ""
        )

    except Exception as e:
        import traceback
        error_msg = f"❌ Generation failed: {str(e)}\n{traceback.format_exc()}"
        return error_msg, gr.update(visible=False), "", "", gr.update(), gr.update(), ""
640
+
641
def generate_section_initial():
    """Generate the initial section plan via the planner agent.

    Renders the section-generation prompt template with the (already
    filtered) paper content held in the module-level `state`, runs the
    planner agent once, parses its JSON reply, prepends the fixed
    title/authors/affiliation placeholders, and persists the result to
    project_contents/<paper_name>_generated_section.json.

    Returns:
        tuple: (generated_section dict, input_token count, output_token count)
    """
    import yaml
    from jinja2 import Environment, StrictUndefined
    from utils.wei_utils import account_token
    from utils.src.utils import get_json_from_response

    with open('utils/prompt_templates/page_templates/section_generation.yaml', 'r') as f:
        planner_config = yaml.safe_load(f)

    # StrictUndefined: fail loudly on any missing template variable.
    jinja_env = Environment(undefined=StrictUndefined)
    template = jinja_env.from_string(planner_config["template"])

    jinja_args = {
        'paper_content': state.paper_content,
        'json_format_example': json.dumps(state.paper_content, indent=2)
    }

    prompt = template.render(**jinja_args)

    # Reset conversation history so each run starts from a clean context.
    state.planner.planner_agent.reset()
    response = state.planner.planner_agent.step(prompt)
    input_token, output_token = account_token(response)
    generated_section = get_json_from_response(response.msgs[0].content)

    def create_dynamic_page_dict(sections):
        # Fixed metadata placeholders always come first; the generated
        # sections are appended after them (dicts preserve insertion order).
        poster_dict = {
            "title": "Title of the paper",
            "authors": "Authors of the paper",
            "affiliation": "Affiliation of the authors",
        }
        poster_dict.update(sections)
        return poster_dict

    generated_section = create_dynamic_page_dict(generated_section)
    state.generated_section = generated_section

    # Persist so later pipeline stages (and resume) can reload the plan.
    generated_path = f'project_contents/{state.args.paper_name}_generated_section.json'
    with open(generated_path, 'w') as f:
        json.dump(generated_section, f, indent=4)

    return generated_section, input_token, output_token
683
+
684
def submit_section_feedback(feedback_text):
    """Handle one round of user feedback on the generated Section plan.

    Empty input or 'yes' (case-insensitive) means the user is satisfied and
    the pipeline advances to the Text Content stage. Any other text is sent
    back to the planner agent as a revision request.

    Args:
        feedback_text: Raw text entered by the user in the feedback box.

    Returns:
        tuple: Values/updates for the Gradio components of this stage.
    """
    if not feedback_text or feedback_text.strip().lower() == 'yes':
        # User satisfied, proceed to next stage
        result = proceed_to_text_content()
        status, fc_section_visible, fc_display_visible, fc_display_md, fc_display_json, fc_feedback_visible = result
        return (
            status,
            "",  # section_display_md clear
            "",  # section_display_json clear
            "",  # section_feedback_input clear
            gr.update(visible=False),  # feedback_section hide
            fc_section_visible,  # feedback_full_content show
            fc_display_visible,  # full_content_display_md show
            fc_display_md,  # full_content_display_md content
            fc_display_json,  # full_content_display_json content
            fc_feedback_visible  # full_content_feedback_input show
        )

    # User provides feedback, modify Section
    from camel.messages import BaseMessage
    from utils.wei_utils import account_token
    from utils.src.utils import get_json_from_response

    message = BaseMessage.make_assistant_message(
        role_name='User',
        content=f'human feedback: {feedback_text}\n\nPlease make modifications based on this feedback. Output format as specified above.'
    )
    # Continue the existing planner conversation (no reset) so the agent
    # revises its previous answer in context.
    response = state.planner.planner_agent.step(message)
    input_token, output_token = account_token(response)
    state.total_input_tokens_t += input_token
    state.total_output_tokens_t += output_token

    generated_section = get_json_from_response(response.msgs[0].content)
    state.generated_section = generated_section

    generated_path = f'project_contents/{state.args.paper_name}_generated_section.json'
    # Explicit UTF-8 for consistency with the other content dumps in this file.
    with open(generated_path, 'w', encoding='utf-8') as f:
        json.dump(generated_section, f, indent=4)

    # Use Markdown formatting
    section_display_md = format_section_to_markdown(generated_section)
    section_display_json = json.dumps(generated_section, indent=2, ensure_ascii=False)

    return (
        f"✅ Section updated, please continue reviewing\n\nTokens: {input_token} → {output_token}",
        section_display_md,  # Markdown format
        section_display_json,  # JSON format
        "",  # Clear input box
        gr.update(visible=True),  # feedback_section keep visible
        gr.update(visible=False),  # feedback_full_content keep hidden
        gr.update(visible=False),  # full_content_display_md keep hidden
        "",  # full_content_display_md content
        "",  # full_content_display_json content
        gr.update(visible=False)  # full_content_feedback_input keep hidden
    )
740
+
741
def proceed_to_text_content():
    """Run the Text Content generation stage, then chain into Full Content."""
    banner = "=" * 50
    print(banner)
    print("STEP 4: Generating Text Content")
    print(banner)

    content, tokens_in, tokens_out = state.planner.text_content_generation(
        state.paper_content, state.figures, state.generated_section
    )
    state.total_input_tokens_t += tokens_in
    state.total_output_tokens_t += tokens_out
    state.text_page_content = content

    # Hand off immediately to the Full Content stage of the pipeline.
    return proceed_to_full_content()
756
+
757
def proceed_to_full_content():
    """Run the Full Content generation stage and prepare its review UI."""
    banner = "=" * 50
    print(banner)
    print("STEP 5: Generating Full Content")
    print(banner)

    state.current_stage = "full_content"

    content, tokens_in, tokens_out = generate_full_content_initial()
    state.total_input_tokens_t += tokens_in
    state.total_output_tokens_t += tokens_out

    # Build both a human-friendly Markdown preview and the raw JSON view.
    md_view = format_full_content_to_markdown(content, state.figures)
    json_view = json.dumps(content, indent=2, ensure_ascii=False)

    status = (
        f"✅ Full Content generation completed, please review and provide feedback"
        f"\n\nTokens: {tokens_in} → {tokens_out}"
    )
    return (
        status,
        gr.update(visible=True),   # feedback_full_content show
        gr.update(visible=True),   # full_content_display_md show
        md_view,                   # Markdown format
        json_view,                 # JSON format
        gr.update(visible=True)    # full_content_feedback_input show
    )
781
+
782
def generate_full_content_initial():
    """Produce the first full-content draft from the planner agent.

    Returns:
        tuple: (draft content dict, input token count, output token count)
    """
    import yaml
    from jinja2 import Environment, StrictUndefined
    from utils.wei_utils import account_token
    from utils.src.utils import get_json_from_response

    template_path = 'utils/prompt_templates/page_templates/full_content_generation.yaml'
    with open(template_path, 'r') as fh:
        config = yaml.safe_load(fh)

    env = Environment(undefined=StrictUndefined)
    prompt = env.from_string(config["template"]).render(
        paper_content=state.paper_content,
        figures=json.dumps(state.figures, indent=2),
        project_page_content=json.dumps(state.text_page_content, indent=2),
    )

    # Start a clean agent conversation for this stage.
    state.planner.planner_agent.reset()
    reply = state.planner.planner_agent.step(prompt)
    tokens_in, tokens_out = account_token(reply)
    draft = get_json_from_response(reply.msgs[0].content)

    state.generated_content = draft

    # Persist the very first draft as version 0 for later comparison.
    v0_path = f'project_contents/{state.args.paper_name}_generated_full_content.v0.json'
    with open(v0_path, 'w', encoding='utf-8') as fh:
        json.dump(draft, fh, ensure_ascii=False, indent=2)

    return draft, tokens_in, tokens_out
815
+
816
def submit_full_content_feedback(feedback_text):
    """Handle one round of user feedback on the full page content.

    Empty input or 'yes' accepts the content and moves to HTML generation;
    anything else is forwarded to the planner agent as a revision request.
    """
    satisfied = not feedback_text or feedback_text.strip().lower() == 'yes'
    if satisfied:
        # User accepted the content; move on to HTML generation.
        status, html_feedback_visible, preview_info, preview_url, open_btn_visible = \
            proceed_to_html_generation()
        return (
            status,
            "",                        # full_content_display_md clear
            "",                        # full_content_display_json clear
            "",                        # full_content_feedback_input clear
            gr.update(visible=False),  # feedback_full_content hide
            html_feedback_visible,     # feedback_html show
            preview_info,              # preview_info_display
            preview_url,               # preview_url_state
            open_btn_visible           # open_preview_btn show
        )

    # Otherwise, forward the feedback to the planner agent for revision.
    from camel.messages import BaseMessage
    from utils.wei_utils import account_token
    from utils.src.utils import get_json_from_response

    revision_request = BaseMessage.make_assistant_message(
        role_name='User',
        content=f'human feedback: {feedback_text}\n\nPlease make modifications based on this feedback. Output format as specified above.'
    )
    reply = state.planner.planner_agent.step(revision_request)
    tokens_in, tokens_out = account_token(reply)
    state.total_input_tokens_t += tokens_in
    state.total_output_tokens_t += tokens_out

    revised = get_json_from_response(reply.msgs[0].content)
    state.generated_content = revised

    final_path = f'project_contents/{state.args.paper_name}_generated_full_content.json'
    with open(final_path, 'w', encoding='utf-8') as fh:
        json.dump(revised, fh, ensure_ascii=False, indent=2)

    # Refresh both preview representations for the review tabs.
    md_view = format_full_content_to_markdown(revised, state.figures)
    json_view = json.dumps(revised, indent=2, ensure_ascii=False)

    return (
        f"✅ Full Content updated, please continue reviewing\n\nTokens: {tokens_in} → {tokens_out}",
        md_view,                   # Markdown format
        json_view,                 # JSON format
        "",                        # Clear input box
        gr.update(visible=True),   # feedback_full_content keep visible
        gr.update(visible=False),  # feedback_html keep hidden
        "",                        # preview_info_display
        "",                        # preview_url_state
        gr.update(visible=False)   # open_preview_btn keep hidden
    )
870
+
871
def proceed_to_html_generation():
    """Enter the HTML generation stage.

    Copies template static assets into the output tree, renders the final
    HTML from the approved content, runs the table-modification pass,
    captures before/after screenshots, starts a local preview server, and
    returns Gradio updates for the HTML-feedback stage.

    Returns:
        tuple: (status message, feedback_html visibility update,
        preview info markdown, preview URL string,
        open-preview-button visibility update)
    """
    print("="*50)
    print("STEP 6: Generating HTML")
    print("="*50)

    state.current_stage = "html"

    # Copy static files (CSS/JS/images) from the template into the output tree.
    static_dir = copy_static_files(
        state.args.template_file,
        state.args.template_dir,
        state.args.output_dir,
        state.args.paper_name
    )

    # Generate HTML
    # html_dir is the template file's directory relative to the template
    # root, preserved so relative asset links keep working in the output.
    html_relative_path = os.path.relpath(state.args.template_file, state.args.template_dir)
    html_dir = '/'.join(html_relative_path.strip().split('/')[:-1])
    state.html_dir = html_dir

    html_generator = ProjectPageHTMLGenerator(state.agent_config_t, state.args)
    state.html_generator = html_generator

    with open(state.args.template_file, 'r', encoding='utf-8') as file:
        html_template = file.read()

    # Create assets directory
    assets_dir = html_generator.create_assets_directory(state.args, html_dir, state.args.output_dir)

    # Generate HTML
    html_content, input_token, output_token = html_generator.generate_complete_html(
        state.args, state.generated_content, html_dir, html_template
    )
    state.total_input_tokens_t += input_token
    state.total_output_tokens_t += output_token

    # Save HTML (before table modification)
    html_dir_path = os.path.join(state.args.output_dir, state.args.paper_name, html_dir)
    os.makedirs(html_dir_path, exist_ok=True)

    html_file_path_no_modify = os.path.join(html_dir_path, 'index_no_modify_table.html')
    with open(html_file_path_no_modify, 'w', encoding='utf-8') as file:
        file.write(html_content)

    # Generate screenshot (before table modification)
    screenshot_path_no_modify = os.path.join(html_dir_path, 'page_final_no_modify_table.png')
    run_sync_screenshots(to_url(html_file_path_no_modify), screenshot_path_no_modify)

    # Modify tables
    # NOTE: input_token/output_token are re-bound here, so the preview info
    # below reports only the table-modification step, not the whole stage.
    html_content, input_token, output_token = html_generator.modify_html_table(html_content, html_dir)
    state.total_input_tokens_t += input_token
    state.total_output_tokens_t += output_token

    state.html_content = html_content

    # Save HTML (after table modification)
    html_file_path = os.path.join(html_dir_path, 'index.html')
    with open(html_file_path, 'w', encoding='utf-8') as file:
        file.write(html_content)

    state.html_file_path = html_file_path

    # Generate screenshot (after table modification)
    run_sync_screenshots(
        to_url(html_file_path),
        os.path.join(html_dir_path, 'page_final.png')
    )

    # Start preview server
    # NOTE(review): preview URL is localhost-only — presumably the browser
    # runs on the same machine as the app; confirm for remote deployments.
    html_full_dir = os.path.dirname(os.path.abspath(html_file_path))
    port = start_preview_server(html_full_dir)
    preview_url = f"http://localhost:{port}/index.html"
    state.preview_url = preview_url

    # Create preview info display
    preview_info = f"""
### 🌐 HTML Generation Completed

**Preview URL**: {preview_url}

**Instructions**:
1. Click the **"🌐 Open Preview in New Tab"** button below to view the generated webpage
2. Carefully review the page in the new tab
3. If satisfied, enter **'yes'** in the feedback box and submit
4. If modifications are needed, provide detailed feedback and submit

**Token Usage**: {input_token} → {output_token}
"""

    return (
        f"✅ HTML generation completed\n\nTokens: {input_token} → {output_token}",
        gr.update(visible=True),  # feedback_html show
        preview_info,  # preview_info_display
        preview_url,  # preview_url_state
        gr.update(visible=True)  # open_preview_btn show
    )
968
+
969
def submit_html_feedback(feedback_text):
    """Handle one round of user feedback on the generated HTML page.

    Empty input or 'yes' finalizes the run. Any other text is forwarded to
    the HTML generator; the page file is rewritten, a timestamped copy is
    kept for the iteration history, and the screenshot is refreshed.

    Args:
        feedback_text: Raw text entered by the user in the feedback box.

    Returns:
        tuple: Values/updates for the Gradio components of this stage.
    """
    if not feedback_text or feedback_text.strip().lower() == 'yes':
        # User satisfied, complete generation
        result = finalize_generation()
        status, html_file = result
        return (
            status,
            "",  # preview_info_display clear
            "",  # html_feedback_input clear
            gr.update(visible=False),  # feedback_html hide
            gr.update(visible=False),  # open_preview_btn hide
            html_file  # html_file_output
        )

    # User provides feedback
    html_content, input_token, output_token = state.html_generator.modify_html_from_human_feedback(
        state.html_content, feedback_text
    )
    state.total_input_tokens_t += input_token
    state.total_output_tokens_t += output_token
    state.html_content = html_content

    # Save updated HTML
    html_dir_path = os.path.dirname(state.html_file_path)

    # Save as temporary version (for possible feedback iteration)
    import time
    timestamp = int(time.time())
    html_file_feedback = os.path.join(html_dir_path, f'index_feedback_{timestamp}.html')
    with open(html_file_feedback, 'w', encoding='utf-8') as file:
        file.write(html_content)

    # Also update main file
    with open(state.html_file_path, 'w', encoding='utf-8') as file:
        file.write(html_content)

    # Regenerate screenshot
    # Best-effort: a screenshot failure must not abort the feedback loop.
    screenshot_path = os.path.join(html_dir_path, 'page_final.png')
    try:
        run_sync_screenshots(to_url(state.html_file_path), screenshot_path)
    except Exception as e:
        print(f"Screenshot generation failed: {e}")

    # Update preview info
    preview_info = f"""
### 🌐 HTML Updated

**Preview URL**: {state.preview_url}

**Instructions**:
1. Click the **"🌐 Open Preview in New Tab"** button below to view the updated webpage
2. **Refresh the browser** to see the latest version
3. If satisfied, enter **'yes'** in the feedback box and submit
4. If further modifications are needed, continue providing feedback

**Token Usage**: {input_token} → {output_token}
"""

    return (
        f"✅ HTML updated, please refresh the preview page\n\nTokens: {input_token} → {output_token}",
        preview_info,  # preview_info_display
        "",  # Clear input box
        gr.update(visible=True),  # feedback_html keep visible
        gr.update(visible=True),  # open_preview_btn keep visible
        None  # html_file_output no download yet
    )
1036
+
1037
def finalize_generation():
    """Complete generation and save final results.

    Writes the final HTML (main file plus an `index_final.html` copy),
    saves metadata, a README, and a generation log, then zips the whole
    project directory for download.

    Returns:
        tuple: (status message, path of the file to offer for download —
        the zip archive on success, the HTML file if zipping failed)
    """
    import time

    # Ensure final HTML is saved
    html_dir_path = os.path.dirname(state.html_file_path)

    # Save final version
    final_html_path = os.path.join(html_dir_path, 'index_final.html')
    with open(final_html_path, 'w', encoding='utf-8') as file:
        file.write(state.html_content)

    # Also update main file
    with open(state.html_file_path, 'w', encoding='utf-8') as file:
        file.write(state.html_content)

    # Save metadata
    metadata = state.html_generator.generate_metadata(state.generated_content, state.args)
    metadata_path = state.html_generator.save_metadata(metadata, state.args, state.args.output_dir)

    # Create README file
    readme_path = os.path.join(state.args.output_dir, state.args.paper_name, 'README.md')
    readme_content = f"""# {state.args.paper_name} - Project Page

## 📄 Project Information

- **Paper Name**: {state.args.paper_name}
- **Generation Time**: {time.strftime('%Y-%m-%d %H:%M:%S')}
- **Text Model**: {state.args.model_name_t}
- **Vision Model**: {state.args.model_name_v}

## 🚀 Usage

1. Extract this archive to any directory
2. Open `index.html` to view the project page
3. All resources (CSS, images, etc.) are included

## 📁 File Structure

- `index.html` - Main page file
- `index_final.html` - Final confirmed version
- `assets/` - Image and table resources
- `css/` or `styles/` - Style files
- `js/` or `scripts/` - JavaScript files
- `metadata.json` - Page metadata
- `generation_log.json` - Generation log

## 💡 Tips

- Recommended browsers: Chrome, Firefox, Safari, Edge
- For web deployment, simply upload the entire folder
- Feel free to modify HTML and CSS for customization

---
Generated by Paper2ProjectPage
"""

    with open(readme_path, 'w', encoding='utf-8') as f:
        f.write(readme_content)

    # Save generation log
    log_data = {
        'paper_name': state.args.paper_name,
        'paper_path': state.args.paper_path,
        'models': {
            'text_model': state.args.model_name_t,
            'vision_model': state.args.model_name_v
        },
        'token_usage': {
            'text_input_tokens': state.total_input_tokens_t,
            'text_output_tokens': state.total_output_tokens_t
        },
        'output_files': {
            'html_file': state.html_file_path,
            'final_html_file': final_html_path,
            'metadata_file': metadata_path,
            'readme_file': readme_path
        },
        'timestamp': time.strftime('%Y-%m-%d %H:%M:%S')
    }

    log_path = f"{state.args.output_dir}/{state.args.paper_name}/generation_log.json"
    # NOTE(review): opened without encoding='utf-8' while dumping with
    # ensure_ascii=False — may fail on non-UTF-8 default locales; confirm.
    with open(log_path, 'w') as f:
        json.dump(log_data, f, indent=4, ensure_ascii=False)

    # Create project archive
    project_dir = os.path.join(state.args.output_dir, state.args.paper_name)
    zip_path = create_project_zip(project_dir, state.args.output_dir, state.args.paper_name)

    if zip_path and os.path.exists(zip_path):
        # Get archive size
        zip_size = os.path.getsize(zip_path)
        zip_size_mb = zip_size / (1024 * 1024)
        zip_filename = os.path.basename(zip_path)

        success_msg = f"""
✅ Project page generation completed!

📁 Output directory: {state.args.output_dir}/{state.args.paper_name}
🌐 HTML file: {state.html_file_path}
🌐 Final version: {final_html_path}
📋 Metadata: {metadata_path}
📖 README: {readme_path}
📊 Log file: {log_path}
📦 Archive: {zip_filename} ({zip_size_mb:.2f} MB)
🔢 Total token usage: {state.total_input_tokens_t} → {state.total_output_tokens_t}

🎉 All feedback completed, page successfully generated!
Click the button below to download the complete project archive (including HTML, CSS, images, README, and all resources).
"""

        return (
            success_msg,
            zip_path  # Return archive for download
        )

    else:
        error_msg = f"""
⚠️ Project page generated, but archive creation failed!

📁 Output directory: {state.args.output_dir}/{state.args.paper_name}
🌐 HTML file: {state.html_file_path}
📋 Metadata: {metadata_path}

You can manually retrieve all files from the output directory {project_dir}.
"""
        return (
            error_msg,
            state.html_file_path  # Return HTML file
        )
1167
+
1168
# ==================== Gradio Interface ====================

# Custom CSS for better English font rendering
# Injected into gr.Blocks(css=custom_css) below: Inter for UI text,
# JetBrains Mono for code, plus spacing tweaks for markdown content.
custom_css = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&family=JetBrains+Mono:wght@400;500&display=swap');

* {
    font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, sans-serif !important;
}

code, pre, .code {
    font-family: 'JetBrains Mono', 'Courier New', Consolas, Monaco, monospace !important;
}

h1, h2, h3, h4, h5, h6 {
    font-weight: 600 !important;
    letter-spacing: -0.02em !important;
}

.markdown-text {
    line-height: 1.7 !important;
    font-size: 15px !important;
}

.gr-button {
    font-weight: 500 !important;
    letter-spacing: 0.01em !important;
}

.gr-input, .gr-textarea {
    font-size: 14px !important;
    line-height: 1.6 !important;
}

.gr-box {
    border-radius: 8px !important;
}

/* Better spacing for English content */
.gr-markdown p {
    margin-bottom: 0.8em !important;
}

.gr-markdown ul, .gr-markdown ol {
    margin-left: 1.2em !important;
}

.gr-markdown li {
    margin-bottom: 0.4em !important;
}
"""
1219
+
1220
+ with gr.Blocks(title="Paper2ProjectPage Generator", theme=gr.themes.Soft(), css=custom_css) as demo:
1221
+
1222
+ gr.Markdown("""
1223
+ # 📄 AutoPage Generator with Interactive Feedback
1224
+
1225
+ Upload your research paper PDF and generate beautiful project pages through multi-round interactive feedback
1226
+ """)
1227
+
1228
+ with gr.Row():
1229
+ with gr.Column(scale=1):
1230
+ # PDF Upload
1231
+ pdf_input = gr.File(
1232
+ label="📎 Upload PDF Paper",
1233
+ file_types=[".pdf"],
1234
+ type="filepath"
1235
+ )
1236
+
1237
+ gr.Markdown("### 🔑 API Keys Configuration")
1238
+ gr.Markdown("""
1239
+ **⚠️ Security Notice**: Your API keys are only stored in memory during the session and are never saved to disk.
1240
+
1241
+ **📋 How to get API keys:**
1242
+ - **OpenAI**: Get your API key from [OpenAI Platform](https://platform.openai.com/api-keys)
1243
+ - **Gemini**: Get your API key from [Google AI Studio](https://aistudio.google.com/app/apikey)
1244
+ - **Qwen**: Get your API key from [DashScope](https://dashscope.console.aliyun.com/apiKey)
1245
+ - **ZhipuAI**: Get your API key from [ZhipuAI Console](https://open.bigmodel.cn/usercenter/apikeys)
1246
+ - **OpenRouter**: Get your API key from [OpenRouter](https://openrouter.ai/keys)
1247
+
1248
+ **🚀 For HuggingFace Spaces**: You can also set these as environment variables in your Space settings.
1249
+ """)
1250
+
1251
+ with gr.Row():
1252
+ openai_api_key = gr.Textbox(
1253
+ label="OpenAI API Key",
1254
+ value=os.getenv("OPENAI_API_KEY", ""),
1255
+ type="password",
1256
+ placeholder="sk-...",
1257
+ info="Required for GPT models"
1258
+ )
1259
+ gemini_api_key = gr.Textbox(
1260
+ label="Gemini API Key",
1261
+ value=os.getenv("GEMINI_API_KEY", ""),
1262
+ type="password",
1263
+ placeholder="AI...",
1264
+ info="Required for Gemini models"
1265
+ )
1266
+
1267
+ with gr.Row():
1268
+ qwen_api_key = gr.Textbox(
1269
+ label="Qwen API Key",
1270
+ value=os.getenv("QWEN_API_KEY", ""),
1271
+ type="password",
1272
+ placeholder="sk-...",
1273
+ info="Required for Qwen models"
1274
+ )
1275
+ zhipuai_api_key = gr.Textbox(
1276
+ label="ZhipuAI API Key",
1277
+ value=os.getenv("ZHIPUAI_API_KEY", ""),
1278
+ type="password",
1279
+ placeholder="...",
1280
+ info="Required for GLM models"
1281
+ )
1282
+
1283
+ openrouter_api_key = gr.Textbox(
1284
+ label="OpenRouter API Key",
1285
+ value=os.getenv("OPENROUTER_API_KEY", ""),
1286
+ type="password",
1287
+ placeholder="sk-or-...",
1288
+ info="Required for OpenRouter models"
1289
+ )
1290
+
1291
+ gr.Markdown("### 🤖 Model Configuration")
1292
+
1293
+ # Text Model Options
1294
+ text_model_options = [
1295
+ ("GPT-4o", "4o"),
1296
+ ("GPT-4o Mini", "4o-mini"),
1297
+ ("GPT-4.1", "gpt-4.1"),
1298
+ ("GPT-4.1 Mini", "gpt-4.1-mini"),
1299
+ ("O1", "o1"),
1300
+ ("O3", "o3"),
1301
+ ("O3 Mini", "o3-mini"),
1302
+ ("Gemini 2.5 Pro", "gemini"),
1303
+ ("Gemini 2.5 Pro (Alt)", "gemini-2.5-pro"),
1304
+ ("Gemini 2.5 Flash", "gemini-2.5-flash"),
1305
+ ("Qwen", "qwen"),
1306
+ ("Qwen Plus", "qwen-plus"),
1307
+ ("Qwen Max", "qwen-max"),
1308
+ ("Qwen Long", "qwen-long"),
1309
+ ("OpenRouter Qwen Plus", "openrouter_qwen-plus"),
1310
+ ("OpenRouter GPT-4o Mini", "openrouter_gpt-4o-mini"),
1311
+ ("OpenRouter Gemini 2.5 Flash", "openrouter_gemini-2.5-flash"),
1312
+ ("OpenRouter O3", "openrouter_openai/o3"),
1313
+ ("OpenRouter Claude Sonnet 4.5", "openrouter_claude-sonnet-4.5"),
1314
+ ]
1315
+
1316
+ # Vision Model Options
1317
+ vision_model_options = [
1318
+ ("GPT-4o", "4o"),
1319
+ ("GPT-4o Mini", "4o-mini"),
1320
+ ("Gemini 2.5 Pro", "gemini"),
1321
+ ("Gemini 2.5 Pro (Alt)", "gemini-2.5-pro"),
1322
+ ("Gemini 2.5 Flash", "gemini-2.5-flash"),
1323
+ ("Qwen VL Max", "qwen-vl-max"),
1324
+ ("Qwen 2.5 VL 72B", "qwen-2.5-vl-72b"),
1325
+ ("OpenRouter Qwen VL 72B", "openrouter_qwen_vl_72b"),
1326
+ ("OpenRouter Qwen VL 7B", "openrouter_qwen_vl_7b"),
1327
+ ("OpenRouter Qwen VL Max", "openrouter_qwen-vl-max"),
1328
+ ("OpenRouter Gemini 2.5 Flash", "openrouter_gemini-2.5-flash"),
1329
+ ]
1330
+
1331
+ with gr.Row():
1332
+ model_name_t = gr.Dropdown(
1333
+ label="Text Model",
1334
+ choices=text_model_options,
1335
+ value="gemini",
1336
+ info="Select model for text processing"
1337
+ )
1338
+ model_name_v = gr.Dropdown(
1339
+ label="Vision Model",
1340
+ choices=vision_model_options,
1341
+ value="gemini",
1342
+ info="Select model for vision processing"
1343
+ )
1344
+
1345
+ gr.Markdown("### 📁 Path Configuration")
1346
+ template_root = gr.Textbox(
1347
+ label="Template Root",
1348
+ value="templates",
1349
+ info="Root directory for templates"
1350
+ )
1351
+ template_dir = gr.Textbox(
1352
+ label="Template Directory",
1353
+ value="",
1354
+ info="Selected template directory (optional)"
1355
+ )
1356
+ template_file = gr.Textbox(
1357
+ label="Template File",
1358
+ value="",
1359
+ info="Specific template file path (optional)"
1360
+ )
1361
+ template_choice = gr.Radio(
1362
+ label="Recommended Templates",
1363
+ choices=[],
1364
+ value=None,
1365
+ info="Select from recommended templates",
1366
+ visible=True
1367
+ )
1368
+ output_dir = gr.Textbox(
1369
+ label="Output Directory",
1370
+ value="generated_project_pages",
1371
+ info="Directory for output files"
1372
+ )
1373
+ style_preference = gr.Textbox(
1374
+ label="Style Preference JSON",
1375
+ value="",
1376
+ info="Style preference JSON file path (optional)"
1377
+ )
1378
+ tmp_dir = gr.Textbox(
1379
+ label="Temporary Directory",
1380
+ value="tmp",
1381
+ info="Directory for temporary files"
1382
+ )
1383
+
1384
+ template_preview_links = gr.Markdown(
1385
+ label="Template Preview Links",
1386
+ value="",
1387
+ visible=False
1388
+ )
1389
+
1390
+ # ===== Hidden parameters with default values =====
1391
+ resume = gr.Radio(
1392
+ label="Resume From Step",
1393
+ choices=['parse_pdf', 'generate_content','full_content_check', 'generate_html', 'html_check','modify_table','html_feedback'],
1394
+ value='parse_pdf',
1395
+ visible=False
1396
+ )
1397
+
1398
+ human_input = gr.Radio(
1399
+ label="Enable Human Feedback",
1400
+ choices=[0, 1],
1401
+ value=1,
1402
+ visible=False
1403
+ )
1404
+
1405
+ with gr.Column(scale=1):
1406
+ gr.Markdown("### 🎨 Style Configuration")
1407
+
1408
+ background_color = gr.Radio(
1409
+ label="Background Color",
1410
+ choices=["light", "dark"],
1411
+ value="light",
1412
+ info="Background color theme"
1413
+ )
1414
+
1415
+ has_navigation = gr.Radio(
1416
+ label="Has Navigation",
1417
+ choices=["yes", "no"],
1418
+ value="yes",
1419
+ info="Include navigation bar"
1420
+ )
1421
+
1422
+ has_hero_section = gr.Radio(
1423
+ label="Has Hero Section",
1424
+ choices=["yes", "no"],
1425
+ value="yes",
1426
+ info="Include hero/header section"
1427
+ )
1428
+
1429
+ title_color = gr.Radio(
1430
+ label="Title Color",
1431
+ choices=["pure", "colorful"],
1432
+ value="pure",
1433
+ info="Title color style"
1434
+ )
1435
+
1436
+ page_density = gr.Radio(
1437
+ label="Page Density",
1438
+ choices=["spacious", "compact"],
1439
+ value="spacious",
1440
+ info="Page spacing density"
1441
+ )
1442
+
1443
+ image_layout = gr.Radio(
1444
+ label="Image Layout",
1445
+ choices=["rotation", "parallelism"],
1446
+ value="parallelism",
1447
+ info="Image layout style"
1448
+ )
1449
+
1450
+ gr.Markdown("### ⚙️ Advanced Options")
1451
+
1452
+ full_content_check_times = gr.Number(
1453
+ label="Full Content Check Times",
1454
+ value=1,
1455
+ precision=0,
1456
+ info="Number of full content validation checks"
1457
+ )
1458
+
1459
+ html_check_times = gr.Number(
1460
+ label="HTML Check Times",
1461
+ value=1,
1462
+ precision=0,
1463
+ info="Number of HTML validation checks"
1464
+ )
1465
+
1466
+ # Start Generation Button
1467
+ start_btn = gr.Button("🚀 Start Generation", variant="primary", size="lg")
1468
+
1469
+ # Status Output
1470
+ status_output = gr.Textbox(
1471
+ label="📊 Generation Status",
1472
+ lines=5,
1473
+ interactive=False
1474
+ )
1475
+
1476
+ # Section Feedback Area
1477
+ with gr.Group(visible=False) as feedback_section:
1478
+ gr.Markdown("### 📝 Section Generation Results")
1479
+ gr.Markdown("Please review the generated section structure. If satisfied, enter **'yes'**, otherwise provide modification feedback:")
1480
+
1481
+ with gr.Tabs():
1482
+ with gr.Tab("📖 Preview (Markdown)"):
1483
+ section_display_md = gr.Markdown(
1484
+ label="Section Preview",
1485
+ value=""
1486
+ )
1487
+ with gr.Tab("📋 Raw Data (JSON)"):
1488
+ section_display_json = gr.Code(
1489
+ label="Section JSON",
1490
+ language="json",
1491
+ value="",
1492
+ lines=15
1493
+ )
1494
+
1495
+ section_feedback_input = gr.TextArea(
1496
+ label="Your Feedback",
1497
+ placeholder="Enter 'yes' to continue, or provide modification feedback...",
1498
+ lines=3
1499
+ )
1500
+ section_submit_btn = gr.Button("Submit Feedback", variant="primary")
1501
+
1502
+ # Full Content Feedback Area
1503
+ with gr.Group(visible=False) as feedback_full_content:
1504
+ gr.Markdown("### 📄 Full Content Generation Results")
1505
+ gr.Markdown("Please review the generated full content. If satisfied, enter **'yes'**, otherwise provide modification feedback:")
1506
+
1507
+ with gr.Tabs():
1508
+ with gr.Tab("📖 Preview (Markdown)"):
1509
+ full_content_display_md = gr.Markdown(
1510
+ label="Full Content Preview",
1511
+ value=""
1512
+ )
1513
+ with gr.Tab("📋 Raw Data (JSON)"):
1514
+ full_content_display_json = gr.Code(
1515
+ label="Full Content JSON",
1516
+ language="json",
1517
+ value="",
1518
+ lines=15
1519
+ )
1520
+
1521
+ full_content_feedback_input = gr.TextArea(
1522
+ label="Your Feedback",
1523
+ placeholder="Enter 'yes' to continue, or provide modification feedback...",
1524
+ lines=3
1525
+ )
1526
+ full_content_submit_btn = gr.Button("Submit Feedback", variant="primary")
1527
+
1528
+ # HTML Feedback Area
1529
+ with gr.Group(visible=False) as feedback_html:
1530
+ gr.Markdown("### 🌐 HTML Generation Results")
1531
+
1532
+ # Preview Info Display
1533
+ preview_info_display = gr.Markdown(
1534
+ value="",
1535
+ label="Preview Information"
1536
+ )
1537
+
1538
+ # Preview URL (hidden state for JS)
1539
+ preview_url_state = gr.Textbox(visible=False)
1540
+
1541
+ # Open Preview in New Tab Button
1542
+ open_preview_btn = gr.Button(
1543
+ "🌐 Open Preview in New Tab",
1544
+ variant="secondary",
1545
+ size="lg",
1546
+ visible=False
1547
+ )
1548
+
1549
+ gr.Markdown("---")
1550
+
1551
+ # Feedback Input Area
1552
+ html_feedback_input = gr.TextArea(
1553
+ label="Your Feedback",
1554
+ placeholder="Enter 'yes' to finalize, or provide modification feedback...",
1555
+ lines=3
1556
+ )
1557
+ html_submit_btn = gr.Button("Submit Feedback", variant="primary")
1558
+
1559
+ # Final Output
1560
+ html_file_output = gr.File(
1561
+ label="📥 Download Project Archive",
1562
+ interactive=False
1563
+ )
1564
+
1565
+ gr.Markdown("""
1566
+ ---
1567
+ ### 💡 User Guide
1568
+
1569
+ 1. **Upload PDF**: Select your research paper PDF file
1570
+ 2. **Configure Parameters**: Adjust model, path, and style settings as needed
1571
+ 3. **Start Generation**: Click the "Start Generation" button
1572
+ 4. **Three-Stage Feedback**:
1573
+ - 📝 **Section Feedback**: Review the generated page structure (Markdown preview + JSON data), provide feedback or enter 'yes' to continue
1574
+ - 📄 **Full Content Feedback**: Review the generated complete content (Markdown preview + JSON data), provide feedback or enter 'yes' to continue
1575
+ - 🌐 **HTML Feedback**: View the generated webpage in a new tab, provide feedback or enter 'yes' to finalize
1576
+ 5. **Download Results**: Download the complete project archive after completion
1577
+
1578
+ ⚠️ **Tips**:
1579
+ - Each stage supports multiple rounds of feedback until you're satisfied
1580
+ - Section and Full Content stages offer **Markdown preview** and **JSON raw data** viewing options
1581
+ - Markdown preview is more visually appealing, JSON data shows complete structure
1582
+ - HTML stage requires clicking "Open Preview in New Tab" to view the full page in browser
1583
+ - Enter 'yes' to indicate satisfaction and proceed to the next stage
1584
+ - The final ZIP download includes the complete project folder with all resources
1585
+ """)
1586
+
1587
    # Bind Events
    # Wires the three-stage human-in-the-loop pipeline: start -> section
    # feedback -> full-content feedback -> HTML feedback -> download.
    start_btn.click(
        fn=start_generation,
        inputs=[
            pdf_input, model_name_t, model_name_v, template_root,
            template_dir, template_file, output_dir, style_preference,
            tmp_dir, full_content_check_times, background_color,
            has_navigation, has_hero_section, title_color, page_density,
            image_layout, html_check_times, resume, human_input,
            template_choice, openai_api_key, gemini_api_key,
            qwen_api_key, zhipuai_api_key, openrouter_api_key
        ],
        outputs=[
            status_output,
            feedback_section,
            section_display_md,
            section_display_json,
            template_choice,
            template_preview_links,
            section_feedback_input
        ]
    )

    # Stage 1: section-structure feedback. On approval this hides the
    # section panel and reveals the full-content panel.
    section_submit_btn.click(
        fn=submit_section_feedback,
        inputs=[section_feedback_input],
        outputs=[
            status_output,
            section_display_md,
            section_display_json,
            section_feedback_input,
            feedback_section,
            feedback_full_content,
            full_content_display_md,
            # NOTE(review): `full_content_display_md` appears twice in this
            # outputs list (previous line and next line). That looks like a
            # copy-paste duplicate; verify against the return arity of
            # `submit_section_feedback` before removing — Gradio requires
            # the callback to return one value per listed output.
            full_content_display_md,
            full_content_display_json,
            full_content_feedback_input
        ]
    )

    # Stage 2: full-content feedback. On approval this reveals the HTML
    # preview panel and populates the hidden preview-URL state.
    full_content_submit_btn.click(
        fn=submit_full_content_feedback,
        inputs=[full_content_feedback_input],
        outputs=[
            status_output,
            full_content_display_md,
            full_content_display_json,
            full_content_feedback_input,
            feedback_full_content,
            feedback_html,
            preview_info_display,
            preview_url_state,
            open_preview_btn
        ]
    )

    # Stage 3: HTML feedback. On approval this produces the downloadable
    # project archive in `html_file_output`.
    html_submit_btn.click(
        fn=submit_html_feedback,
        inputs=[html_feedback_input],
        outputs=[
            status_output,
            preview_info_display,
            html_feedback_input,
            feedback_html,
            open_preview_btn,
            html_file_output
        ]
    )

    # Open Preview Button - Use JavaScript to open in new tab
    # fn=None means no Python callback runs; the `js` snippet receives the
    # hidden `preview_url_state` value and opens it client-side.
    open_preview_btn.click(
        fn=None,
        inputs=[preview_url_state],
        outputs=None,
        js="(url) => window.open(url, '_blank')"
    )
1663
+
1664
# Launch Application
if __name__ == "__main__":
    # Bind on all interfaces for container/Space deployment; surface server
    # errors in the UI instead of failing silently.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_error": True,
    }
    demo.launch(**launch_options)
camel/__init__.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+
15
+ from camel.logger import disable_logging, enable_logging, set_log_level
16
+
17
+ __version__ = '0.2.19'
18
+
19
+ __all__ = [
20
+ '__version__',
21
+ 'camel',
22
+ 'disable_logging',
23
+ 'enable_logging',
24
+ 'set_log_level',
25
+ ]
camel/agents/__init__.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+ from .base import BaseAgent
15
+ from .chat_agent import ChatAgent
16
+ from .critic_agent import CriticAgent
17
+ from .embodied_agent import EmbodiedAgent
18
+ from .knowledge_graph_agent import KnowledgeGraphAgent
19
+ from .role_assignment_agent import RoleAssignmentAgent
20
+ from .search_agent import SearchAgent
21
+ from .task_agent import (
22
+ TaskCreationAgent,
23
+ TaskPlannerAgent,
24
+ TaskPrioritizationAgent,
25
+ TaskSpecifyAgent,
26
+ )
27
+ from .tool_agents.base import BaseToolAgent
28
+ from .tool_agents.hugging_face_tool_agent import HuggingFaceToolAgent
29
+
30
+ __all__ = [
31
+ 'BaseAgent',
32
+ 'ChatAgent',
33
+ 'TaskSpecifyAgent',
34
+ 'TaskPlannerAgent',
35
+ 'TaskCreationAgent',
36
+ 'TaskPrioritizationAgent',
37
+ 'CriticAgent',
38
+ 'BaseToolAgent',
39
+ 'HuggingFaceToolAgent',
40
+ 'EmbodiedAgent',
41
+ 'RoleAssignmentAgent',
42
+ 'SearchAgent',
43
+ 'KnowledgeGraphAgent',
44
+ ]
camel/agents/base.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+ from abc import ABC, abstractmethod
15
+ from typing import Any
16
+
17
+
18
class BaseAgent(ABC):
    r"""An abstract base class for all CAMEL agents.

    Subclasses must implement both :meth:`reset` and :meth:`step`;
    instantiating this class directly raises :exc:`TypeError`.
    """

    @abstractmethod
    def reset(self, *args: Any, **kwargs: Any) -> Any:
        r"""Resets the agent to its initial state."""

    @abstractmethod
    def step(self, *args: Any, **kwargs: Any) -> Any:
        r"""Performs a single step of the agent."""
camel/agents/chat_agent.py ADDED
@@ -0,0 +1,1539 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+ from __future__ import annotations
15
+
16
+ import json
17
+ import logging
18
+ import re
19
+ import uuid
20
+ from collections import defaultdict
21
+ from typing import (
22
+ TYPE_CHECKING,
23
+ Any,
24
+ Callable,
25
+ Dict,
26
+ List,
27
+ Optional,
28
+ Tuple,
29
+ Type,
30
+ Union,
31
+ )
32
+
33
+ from openai.types.chat import ChatCompletionMessageToolCall
34
+ from openai.types.chat.chat_completion_message_tool_call import Function
35
+ from pydantic import BaseModel, ValidationError
36
+
37
+ from camel.agents.base import BaseAgent
38
+ from camel.memories import (
39
+ AgentMemory,
40
+ ChatHistoryMemory,
41
+ MemoryRecord,
42
+ ScoreBasedContextCreator,
43
+ )
44
+ from camel.messages import BaseMessage, FunctionCallingMessage, OpenAIMessage
45
+ from camel.models import (
46
+ BaseModelBackend,
47
+ ModelFactory,
48
+ ModelManager,
49
+ ModelProcessingError,
50
+ )
51
+ from camel.responses import ChatAgentResponse
52
+ from camel.types import (
53
+ ChatCompletion,
54
+ ChatCompletionChunk,
55
+ ModelPlatformType,
56
+ ModelType,
57
+ OpenAIBackendRole,
58
+ RoleType,
59
+ )
60
+ from camel.utils import (
61
+ func_string_to_callable,
62
+ generate_prompt_for_structured_output,
63
+ get_model_encoding,
64
+ get_pydantic_object_schema,
65
+ json_to_function_code,
66
+ )
67
+
68
+ if TYPE_CHECKING:
69
+ from openai import Stream
70
+
71
+ from camel.terminators import ResponseTerminator
72
+ from camel.toolkits import FunctionTool
73
+
74
+
75
+ logger = logging.getLogger(__name__)
76
+
77
# AgentOps decorator setting
# Use the real `agentops.track_agent` decorator only when an API key is
# configured in the environment; otherwise fall back to the no-op stub
# provided by `camel.utils`.
try:
    import os

    if os.getenv("AGENTOPS_API_KEY") is not None:
        from agentops import track_agent
    else:
        # No key configured: deliberately raise to reach the fallback import.
        raise ImportError
except (ImportError, AttributeError):
    from camel.utils import track_agent
87
+
88
+
89
class FunctionCallingRecord(BaseModel):
    r"""Historical records of functions called in the conversation.

    Attributes:
        func_name (str): The name of the function being called.
        args (Dict[str, Any]): The dictionary of arguments passed to
            the function.
        result (Any): The execution result of calling this function.
        tool_call_id (str): The ID of the tool call, if available.
    """

    func_name: str
    args: Dict[str, Any]
    result: Any
    tool_call_id: str

    def __str__(self) -> str:
        r"""Render the record as a human-readable multi-line string.

        Returns:
            str: One line per field, trailing newline included.
        """
        parts = [
            f"Function Execution: {self.func_name}",
            f"\tArgs: {self.args}",
            f"\tResult: {self.result}",
            "",
        ]
        return "\n".join(parts)

    def as_dict(self) -> dict[str, Any]:
        r"""Return the record's fields as a plain dictionary.

        Returns:
            dict[str, Any]: The function calling record as a dictionary.
        """
        return self.model_dump()
124
+
125
+
126
+ @track_agent(name="ChatAgent")
127
+ class ChatAgent(BaseAgent):
128
+ r"""Class for managing conversations of CAMEL Chat Agents.
129
+
130
+ Args:
131
+ system_message (Union[BaseMessage, str], optional): The system message
132
+ for the chat agent.
133
+ model (BaseModelBackend, optional): The model backend to use for
134
+ generating responses. (default: :obj:`ModelPlatformType.DEFAULT`
135
+ with `ModelType.DEFAULT`)
136
+ memory (AgentMemory, optional): The agent memory for managing chat
137
+ messages. If `None`, a :obj:`ChatHistoryMemory` will be used.
138
+ (default: :obj:`None`)
139
+ message_window_size (int, optional): The maximum number of previous
140
+ messages to include in the context window. If `None`, no windowing
141
+ is performed. (default: :obj:`None`)
142
+ token_limit (int, optional): The maximum number of tokens in a context.
143
+ The context will be automatically pruned to fulfill the limitation.
144
+ If `None`, it will be set according to the backend model.
145
+ (default: :obj:`None`)
146
+ output_language (str, optional): The language to be output by the
147
+ agent. (default: :obj:`None`)
148
+ tools (Optional[List[Union[FunctionTool, Callable]]], optional): List
149
+ of available :obj:`FunctionTool` or :obj:`Callable`. (default:
150
+ :obj:`None`)
151
+ external_tools (Optional[List[Union[FunctionTool, Callable]]],
152
+ optional): List of external tools (:obj:`FunctionTool` or or
153
+ :obj:`Callable`) bind to one chat agent. When these tools are
154
+ called, the agent will directly return the request instead of
155
+ processing it. (default: :obj:`None`)
156
+ response_terminators (List[ResponseTerminator], optional): List of
157
+ :obj:`ResponseTerminator` bind to one chat agent.
158
+ (default: :obj:`None`)
159
+ scheduling_strategy (str): name of function that defines how to select
160
+ the next model in ModelManager. (default: :str:`round_robin`)
161
+ single_iteration (bool): Whether to let the agent perform only one
162
+ model calling at each step. (default: :obj:`False`)
163
+ """
164
+
165
    def __init__(
        self,
        system_message: Optional[Union[BaseMessage, str]] = None,
        model: Optional[
            Union[BaseModelBackend, List[BaseModelBackend]]
        ] = None,
        memory: Optional[AgentMemory] = None,
        message_window_size: Optional[int] = None,
        token_limit: Optional[int] = None,
        output_language: Optional[str] = None,
        tools: Optional[List[Union[FunctionTool, Callable]]] = None,
        external_tools: Optional[List[Union[FunctionTool, Callable]]] = None,
        response_terminators: Optional[List[ResponseTerminator]] = None,
        scheduling_strategy: str = "round_robin",
        single_iteration: bool = False,
    ) -> None:
        r"""Initialize the agent; see the class docstring for argument
        semantics. Statement order matters: the model backend must exist
        before tools are registered, and memory must exist before
        ``set_output_language`` / ``init_messages`` write to it.
        """
        # Initialize the system message, converting string to BaseMessage if needed
        if isinstance(system_message, str):
            system_message = BaseMessage.make_assistant_message(
                role_name='Assistant', content=system_message
            )

        # Keep the pristine system message separately so output-language
        # changes can re-derive `_system_message` from it.
        self.orig_sys_message: Optional[BaseMessage] = system_message
        self._system_message: Optional[BaseMessage] = system_message
        # Fall back to generic assistant identity when no system message.
        self.role_name: str = (
            getattr(system_message, 'role_name', None) or "assistant"
        )
        self.role_type: RoleType = (
            getattr(system_message, 'role_type', None) or RoleType.ASSISTANT
        )
        # Wrap one or more backends in a ModelManager; default backend is
        # created when none is supplied.
        self.model_backend = ModelManager(
            model
            if model is not None
            else ModelFactory.create(
                model_platform=ModelPlatformType.DEFAULT,
                model_type=ModelType.DEFAULT,
            ),
            scheduling_strategy=scheduling_strategy,
        )
        self.model_type = self.model_backend.model_type

        # Initialize tools
        self.tools: List[FunctionTool] = (
            self._initialize_tools(tools) if tools else []
        )
        # External tools are returned to the caller instead of executed.
        self.external_tools: List[FunctionTool] = (
            self._initialize_tools(external_tools) if external_tools else []
        )
        self.external_tool_names: List[str] = [
            tool.get_function_name() for tool in self.external_tools
        ]
        # NOTE: `or []` is redundant here (list + list is never falsy-None)
        # but harmless; kept as-is.
        self.all_tools = self.tools + self.external_tools or []

        # Create tool dictionaries and configure backend tools if necessary
        self.tool_dict = {
            tool.get_function_name(): tool for tool in self.all_tools
        }

        # If the user set tools from `ChatAgent`, it will override the
        # configured tools in `BaseModelBackend`.
        if self.all_tools:
            logger.warning(
                "Overriding the configured tools in `BaseModelBackend` with the tools from `ChatAgent`."
            )
            tool_schema_list = [
                tool.get_openai_tool_schema() for tool in self.all_tools
            ]
            self.model_backend.model_config_dict['tools'] = tool_schema_list

        # Context is pruned by score to stay within the token budget.
        self.model_token_limit = token_limit or self.model_backend.token_limit
        context_creator = ScoreBasedContextCreator(
            self.model_backend.token_counter,
            self.model_token_limit,
        )
        self.memory: AgentMemory = memory or ChatHistoryMemory(
            context_creator, window_size=message_window_size
        )

        self.output_language: Optional[str] = output_language
        if self.output_language is not None:
            # Rewrites the system message and resets memory.
            self.set_output_language(self.output_language)

        self.terminated: bool = False
        self.response_terminators = response_terminators or []
        # Seed memory with the (possibly language-augmented) system message.
        self.init_messages()
        self.tool_prompt_added = False
        self.single_iteration = single_iteration
252
+
253
+ def _initialize_tools(
254
+ self, tools: List[Union[FunctionTool, Callable]]
255
+ ) -> List[FunctionTool]:
256
+ r"""Helper method to initialize tools as FunctionTool instances."""
257
+ from camel.toolkits import FunctionTool
258
+
259
+ func_tools = []
260
+ for tool in tools:
261
+ if not isinstance(tool, FunctionTool):
262
+ tool = FunctionTool(tool)
263
+ func_tools.append(tool)
264
+ return func_tools
265
+
266
+ def add_tool(
267
+ self, tool: Union[FunctionTool, Callable], is_external: bool = False
268
+ ) -> None:
269
+ r"""Add a tool to the agent, specifying if it's an external tool."""
270
+ # Initialize the tool
271
+ initialized_tool = self._initialize_tools([tool])
272
+
273
+ # Update tools or external tools based on is_external flag
274
+ if is_external:
275
+ self.external_tools = self.external_tools + initialized_tool
276
+ self.external_tool_names.extend(
277
+ tool.get_function_name() for tool in initialized_tool
278
+ )
279
+ else:
280
+ self.tools = self.tools + initialized_tool
281
+
282
+ # Rebuild all_tools, and tool_dict
283
+ self.all_tools = self.tools + self.external_tools
284
+ self.tool_dict = {
285
+ tool.get_function_name(): tool for tool in self.all_tools
286
+ }
287
+
288
+ tool_schema_list = [
289
+ tool.get_openai_tool_schema() for tool in self.all_tools
290
+ ]
291
+ self.model_backend.model_config_dict['tools'] = tool_schema_list
292
+
293
+ def remove_tool(self, tool_name: str, is_external: bool = False) -> bool:
294
+ r"""Remove a tool by name, specifying if it's an external tool."""
295
+ tool_list = self.external_tools if is_external else self.tools
296
+ if not tool_list:
297
+ return False
298
+
299
+ for tool in tool_list:
300
+ if tool.get_function_name() == tool_name:
301
+ tool_list.remove(tool)
302
+ if is_external:
303
+ self.external_tool_names.remove(tool_name)
304
+ # Reinitialize the tool dictionary
305
+ self.all_tools = (self.tools or []) + (
306
+ self.external_tools or []
307
+ )
308
+ self.tool_dict = {
309
+ tool.get_function_name(): tool for tool in self.all_tools
310
+ }
311
+ tool_schema_list = [
312
+ tool.get_openai_tool_schema() for tool in self.all_tools
313
+ ]
314
+ self.model_backend.model_config_dict['tools'] = (
315
+ tool_schema_list
316
+ )
317
+ return True
318
+ return False
319
+
320
+ def list_tools(self) -> dict:
321
+ r"""List all tools, separated into normal and external tools."""
322
+ normal_tools = [
323
+ tool.get_function_name() for tool in (self.tools or [])
324
+ ]
325
+ external_tools = [
326
+ tool.get_function_name() for tool in (self.external_tools or [])
327
+ ]
328
+
329
+ return {"normal_tools": normal_tools, "external_tools": external_tools}
330
+
331
+ # ruff: noqa: E501
332
+ def _generate_tool_prompt(self, tool_schema_list: List[Dict]) -> str:
333
+ r"""Generates a tool prompt based on the provided tool schema list.
334
+
335
+ Args:
336
+ tool_schema_list (List[Dict]): A list of dictionaries, each
337
+ containing a tool schema.
338
+
339
+ Returns:
340
+ str: A string representing the tool prompt.
341
+ """
342
+ tool_prompts = []
343
+
344
+ for tool in tool_schema_list:
345
+ tool_info = tool['function']
346
+ tool_name = tool_info['name']
347
+ tool_description = tool_info['description']
348
+ tool_json = json.dumps(tool_info, indent=4)
349
+
350
+ prompt = f"Use the function '{tool_name}' to '{tool_description}':\n{tool_json}\n"
351
+ tool_prompts.append(prompt)
352
+
353
+ tool_prompt_str = "\n".join(tool_prompts)
354
+
355
+ final_prompt = f"""
356
+ You have access to the following functions:
357
+
358
+ {tool_prompt_str}
359
+
360
+ If you choose to call a function ONLY reply in the following format with no
361
+ prefix or suffix:
362
+
363
+ <function=example_function_name>{{"example_name": "example_value"}}</function>
364
+
365
+ Reminder:
366
+ - Function calls MUST follow the specified format, start with <function= and end with </function>
367
+ - Required parameters MUST be specified
368
+ - Only call one function at a time
369
+ - Put the entire function call reply on one line
370
+ - If there is no function call available, answer the question like normal
371
+ with your current knowledge and do not tell the user about function calls
372
+ """
373
+ return final_prompt
374
+
375
+ def _parse_tool_response(self, response: str):
376
+ r"""Parses the tool response to extract the function name and
377
+ arguments.
378
+
379
+ Args:
380
+ response (str): The response from the model containing the
381
+ function call.
382
+
383
+ Returns:
384
+ Optional[Dict[str, Any]]: The parsed function name and arguments
385
+ if found, otherwise :obj:`None`.
386
+ """
387
+ function_regex = r"<function=(\w+)>(.*?)</function>"
388
+ match = re.search(function_regex, response)
389
+
390
+ if match:
391
+ function_name, args_string = match.groups()
392
+ try:
393
+ args = json.loads(args_string)
394
+ return {"function": function_name, "arguments": args}
395
+ except json.JSONDecodeError as error:
396
+ logger.error(f"Error parsing function arguments: {error}")
397
+ return None
398
+ return None
399
+
400
+ def reset(self):
401
+ r"""Resets the :obj:`ChatAgent` to its initial state."""
402
+ self.terminated = False
403
+ self.init_messages()
404
+ for terminator in self.response_terminators:
405
+ terminator.reset()
406
+
407
+ @property
408
+ def system_message(self) -> Optional[BaseMessage]:
409
+ r"""The getter method for the property :obj:`system_message`.
410
+
411
+ Returns:
412
+ Optional[BaseMessage]: The system message of this agent if set,
413
+ else :obj:`None`.
414
+ """
415
+ return self._system_message
416
+
417
+ @system_message.setter
418
+ def system_message(self, message: BaseMessage) -> None:
419
+ r"""The setter method for the property :obj:`system_message`.
420
+
421
+ Args:
422
+ message (BaseMessage): The message to be set as the
423
+ new system message of this agent.
424
+ """
425
+ self._system_message = message
426
+
427
+ def is_tools_added(self) -> bool:
428
+ r"""Whether tool calling is enabled for this agent.
429
+
430
+ Returns:
431
+ bool: Whether tool calling is enabled for this agent, determined
432
+ by whether the dictionary of tools is empty.
433
+ """
434
+ return len(self.tool_dict) > 0
435
+
436
+ def update_memory(
437
+ self, message: BaseMessage, role: OpenAIBackendRole
438
+ ) -> None:
439
+ r"""Updates the agent memory with a new message.
440
+
441
+ Args:
442
+ message (BaseMessage): The new message to add to the stored
443
+ messages.
444
+ role (OpenAIBackendRole): The backend role type.
445
+ """
446
+ self.memory.write_record(
447
+ MemoryRecord(message=message, role_at_backend=role)
448
+ )
449
+
450
+ def set_output_language(self, output_language: str) -> BaseMessage:
451
+ r"""Sets the output language for the system message. This method
452
+ updates the output language for the system message. The output
453
+ language determines the language in which the output text should be
454
+ generated.
455
+
456
+ Args:
457
+ output_language (str): The desired output language.
458
+
459
+ Returns:
460
+ BaseMessage: The updated system message object.
461
+ """
462
+ self.output_language = output_language
463
+ language_prompt = (
464
+ "\nRegardless of the input language, "
465
+ f"you must output text in {output_language}."
466
+ )
467
+ if self.orig_sys_message is not None:
468
+ content = self.orig_sys_message.content + language_prompt
469
+ self._system_message = self.orig_sys_message.create_new_instance(
470
+ content
471
+ )
472
+ else:
473
+ self._system_message = BaseMessage.make_assistant_message(
474
+ role_name="Assistant",
475
+ content=language_prompt,
476
+ )
477
+
478
+ system_record = MemoryRecord(
479
+ message=self._system_message,
480
+ role_at_backend=OpenAIBackendRole.SYSTEM,
481
+ )
482
+ self.memory.clear()
483
+ self.memory.write_record(system_record)
484
+ return self._system_message
485
+
486
+ def get_info(
487
+ self,
488
+ session_id: Optional[str],
489
+ usage: Optional[Dict[str, int]],
490
+ termination_reasons: List[str],
491
+ num_tokens: int,
492
+ tool_calls: List[FunctionCallingRecord],
493
+ external_tool_request: Optional[ChatCompletionMessageToolCall] = None,
494
+ ) -> Dict[str, Any]:
495
+ r"""Returns a dictionary containing information about the chat session.
496
+
497
+ Args:
498
+ session_id (str, optional): The ID of the chat session.
499
+ usage (Dict[str, int], optional): Information about the usage of
500
+ the LLM.
501
+ termination_reasons (List[str]): The reasons for the termination
502
+ of the chat session.
503
+ num_tokens (int): The number of tokens used in the chat session.
504
+ tool_calls (List[FunctionCallingRecord]): The list of function
505
+ calling records, containing the information of called tools.
506
+ external_tool_request
507
+ (Optional[ChatCompletionMessageToolCall], optional):
508
+ The tool calling request of external tools from the model.
509
+ These requests are directly returned to the user instead of
510
+ being processed by the agent automatically.
511
+ (default: :obj:`None`)
512
+
513
+ Returns:
514
+ Dict[str, Any]: The chat session information.
515
+ """
516
+ return {
517
+ "id": session_id,
518
+ "usage": usage,
519
+ "termination_reasons": termination_reasons,
520
+ "num_tokens": num_tokens,
521
+ "tool_calls": tool_calls,
522
+ "external_tool_request": external_tool_request,
523
+ }
524
+
525
+ def init_messages(self) -> None:
526
+ r"""Initializes the stored messages list with the current system
527
+ message.
528
+ """
529
+ if self._system_message is not None:
530
+ system_record = MemoryRecord(
531
+ message=self._system_message,
532
+ role_at_backend=OpenAIBackendRole.SYSTEM,
533
+ )
534
+ self.memory.clear()
535
+ self.memory.write_record(system_record)
536
+ else:
537
+ self.memory.clear()
538
+
539
+ def record_message(self, message: BaseMessage) -> None:
540
+ r"""Records the externally provided message into the agent memory as if
541
+ it were an answer of the :obj:`ChatAgent` from the backend. Currently,
542
+ the choice of the critic is submitted with this method.
543
+
544
+ Args:
545
+ message (BaseMessage): An external message to be recorded in the
546
+ memory.
547
+ """
548
+ self.update_memory(message, OpenAIBackendRole.ASSISTANT)
549
+
550
    def step(
        self,
        input_message: Union[BaseMessage, str],
        response_format: Optional[Type[BaseModel]] = None,
    ) -> ChatAgentResponse:
        r"""Executes a single step in the chat session, generating a response
        to the input message.

        Args:
            input_message (Union[BaseMessage, str]): The input message for the
                agent. If provided as a BaseMessage, the `role` is adjusted to
                `user` to indicate an external message.
            response_format (Optional[Type[BaseModel]], optional): A Pydantic
                model defining the expected structure of the response. Used to
                generate a structured response if provided. (default:
                :obj:`None`)

        Returns:
            ChatAgentResponse: Contains output messages, a termination status
                flag, and session information.

        Raises:
            ValueError: If `response_format` is set both in the model
                configuration and in this call.
        """

        # Reject ambiguous configuration: only one source may define the
        # structured-output format.
        if (
            self.model_backend.model_config_dict.get("response_format")
            and response_format
        ):
            raise ValueError(
                "The `response_format` parameter cannot be set both in "
                "the model configuration and in the ChatAgent step."
            )

        # Keep a handle to the untouched config so the `finally` below can
        # restore it after a temporary response_format override.
        self.original_model_dict = self.model_backend.model_config_dict
        model_response_format_modified = False
        if (
            response_format
            and self.model_type.support_native_structured_output
        ):
            # Work on a copy so the original dict is never mutated.
            self.model_backend.model_config_dict = (
                self.original_model_dict.copy()
            )
            self.model_backend.model_config_dict["response_format"] = (
                response_format
            )
            model_response_format_modified = True

        # Convert input message to BaseMessage if necessary
        if isinstance(input_message, str):
            input_message = BaseMessage.make_user_message(
                role_name='User', content=input_message
            )

        # Handle tool prompt injection if needed
        # (only once, and only for models without native tool calling).
        if (
            self.is_tools_added()
            and not self.model_type.support_native_tool_calling
            and not self.tool_prompt_added
        ):
            self._inject_tool_prompt()

        # Add user input to memory
        self.update_memory(input_message, OpenAIBackendRole.USER)

        try:
            return self._handle_step(response_format, self.single_iteration)
        finally:
            if model_response_format_modified:
                # Reset model config back to original state
                self.model_backend.model_config_dict = self.original_model_dict
618
+
619
+ def _inject_tool_prompt(self) -> None:
620
+ r"""Generate and add the tool prompt to memory."""
621
+ tool_prompt = self._generate_tool_prompt(
622
+ self.model_backend.model_config_dict["tools"]
623
+ )
624
+ tool_msg = BaseMessage.make_assistant_message(
625
+ role_name="Assistant", content=tool_prompt
626
+ )
627
+ self.update_memory(tool_msg, OpenAIBackendRole.SYSTEM)
628
+ self.tool_prompt_added = True
629
+
630
+ def _handle_step(
631
+ self,
632
+ response_format: Optional[Type[BaseModel]],
633
+ single_step: bool,
634
+ ) -> ChatAgentResponse:
635
+ r"""Handles a single or multi-step interaction."""
636
+
637
+ if (
638
+ self.model_backend.model_config_dict.get("tool_choice")
639
+ == "required"
640
+ and not single_step
641
+ ):
642
+ raise ValueError(
643
+ "`tool_choice` cannot be set to `required` for multi-step"
644
+ " mode. To proceed, set `single_iteration` to `True`."
645
+ )
646
+
647
+ # Record function calls made during the session
648
+ tool_call_records: List[FunctionCallingRecord] = []
649
+
650
+ external_tool_request = None
651
+
652
+ while True:
653
+ try:
654
+ openai_messages, num_tokens = self.memory.get_context()
655
+ except RuntimeError as e:
656
+ self.model_backend.model_config_dict = self.original_model_dict
657
+ return self._step_token_exceed(
658
+ e.args[1], tool_call_records, "max_tokens_exceeded"
659
+ )
660
+
661
+ # Prompt engineering approach for structured output for non-native tool calling models
662
+ inject_prompt_for_structured_output = (
663
+ response_format
664
+ and not self.model_type.support_native_structured_output
665
+ )
666
+
667
+ if inject_prompt_for_structured_output:
668
+ # update last openai message
669
+ usr_msg = openai_messages.pop()
670
+ usr_msg["content"] = generate_prompt_for_structured_output(
671
+ response_format,
672
+ usr_msg["content"], # type: ignore [arg-type]
673
+ )
674
+ openai_messages.append(usr_msg)
675
+
676
+ # Process model response
677
+ (
678
+ response,
679
+ output_messages,
680
+ finish_reasons,
681
+ usage_dict,
682
+ response_id,
683
+ ) = self._step_model_response(openai_messages, num_tokens)
684
+
685
+ # Try to parse structured output to return a Pydantic object
686
+ if inject_prompt_for_structured_output and isinstance(
687
+ response, ChatCompletion
688
+ ):
689
+ content = response.choices[0].message.content
690
+ try:
691
+ json_content = json.loads(str(content))
692
+ output_messages[0].parsed = response_format(**json_content) # type: ignore [assignment, misc]
693
+ except json.JSONDecodeError as e:
694
+ logger.error(
695
+ f"Failed in parsing the output into JSON: {e}"
696
+ )
697
+ output_messages[0].parsed = None
698
+ except ValidationError as e:
699
+ logger.warning(
700
+ "Successfully generating JSON response, "
701
+ "but failed in parsing it into Pydantic object :"
702
+ f"{e}, return the JSON response in parsed field"
703
+ )
704
+ output_messages[0].parsed = json_content
705
+
706
+ # Finalize on standard response in multi-step mode
707
+ if self._is_standard_response(response):
708
+ break
709
+
710
+ # Handle tool requests
711
+ tool_request = self._extract_tool_call(response)
712
+ if isinstance(response, ChatCompletion) and tool_request:
713
+ response.choices[0].message.tool_calls = [tool_request]
714
+ tool_call_records.append(
715
+ self._step_tool_call_and_update(response)
716
+ )
717
+
718
+ if tool_request.function.name in self.external_tool_names:
719
+ external_tool_request = tool_request
720
+ info = self._step_get_info(
721
+ output_messages,
722
+ finish_reasons,
723
+ usage_dict,
724
+ response_id,
725
+ tool_call_records,
726
+ num_tokens,
727
+ tool_request,
728
+ )
729
+ self._log_final_output(output_messages)
730
+ self.model_backend.model_config_dict = (
731
+ self.original_model_dict
732
+ )
733
+ return ChatAgentResponse(
734
+ msgs=output_messages,
735
+ terminated=self.terminated,
736
+ info=info,
737
+ )
738
+
739
+ # Single-step mode ends after one iteration
740
+ if single_step:
741
+ break
742
+
743
+ # Optional structured output via function calling
744
+ if (
745
+ response_format
746
+ and not inject_prompt_for_structured_output
747
+ and self.model_type
748
+ not in {
749
+ "gpt-4o",
750
+ "gpt-4o-mini",
751
+ }
752
+ ):
753
+ (
754
+ output_messages,
755
+ finish_reasons,
756
+ usage_dict,
757
+ response_id,
758
+ tool_call,
759
+ num_tokens,
760
+ ) = self._structure_output_with_function(response_format)
761
+ tool_call_records.append(tool_call)
762
+
763
+ # Final info and response
764
+ info = self._step_get_info(
765
+ output_messages,
766
+ finish_reasons,
767
+ usage_dict,
768
+ response_id,
769
+ tool_call_records,
770
+ num_tokens,
771
+ external_tool_request,
772
+ )
773
+ self._log_final_output(output_messages)
774
+ self.model_backend.model_config_dict = self.original_model_dict
775
+ return ChatAgentResponse(
776
+ msgs=output_messages, terminated=self.terminated, info=info
777
+ )
778
+
779
+ def _extract_tool_call(
780
+ self, response: Any
781
+ ) -> Optional[ChatCompletionMessageToolCall]:
782
+ r"""Extract the tool call from the model response, if present.
783
+
784
+ Args:
785
+ response (Any): The model's response object.
786
+
787
+ Returns:
788
+ Optional[ChatCompletionMessageToolCall]: The parsed tool call if
789
+ present, otherwise None.
790
+ """
791
+ # Check if the response contains tool calls
792
+ if (
793
+ self.is_tools_added()
794
+ and not self.model_type.support_native_tool_calling
795
+ and "</function>" in response.choices[0].message.content
796
+ ):
797
+ parsed_content = self._parse_tool_response(
798
+ response.choices[0].message.content
799
+ )
800
+ if parsed_content:
801
+ return ChatCompletionMessageToolCall(
802
+ id=str(uuid.uuid4()),
803
+ function=Function(
804
+ arguments=str(parsed_content["arguments"]).replace(
805
+ "'", '"'
806
+ ),
807
+ name=str(parsed_content["function"]),
808
+ ),
809
+ type="function",
810
+ )
811
+ elif (
812
+ self.is_tools_added()
813
+ and self.model_type.support_native_tool_calling
814
+ and response.choices[0].message.tool_calls
815
+ ):
816
+ return response.choices[0].message.tool_calls[0]
817
+
818
+ # No tool call found
819
+ return None
820
+
821
+ def _is_standard_response(self, response: Any) -> bool:
822
+ r"""Determine if the provided response is a standard reply without
823
+ tool calls.
824
+
825
+ Args:
826
+ response (Any): The response object to evaluate.
827
+
828
+ Returns:
829
+ bool: `True` if the response is a standard reply, `False`
830
+ otherwise.
831
+ """
832
+ if not self.is_tools_added():
833
+ return True
834
+
835
+ if not isinstance(response, ChatCompletion):
836
+ return True
837
+
838
+ if self.model_type.support_native_tool_calling:
839
+ return not response.choices[0].message.tool_calls
840
+
841
+ return "</function>" not in str(
842
+ response.choices[0].message.content or ""
843
+ )
844
+
845
+ def _log_final_output(self, output_messages: List[BaseMessage]) -> None:
846
+ r"""Log final messages or warnings about multiple responses."""
847
+ if len(output_messages) == 1:
848
+ self.record_message(output_messages[0])
849
+ else:
850
+ logger.warning(
851
+ "Multiple messages returned in `step()`. Record "
852
+ "selected message manually using `record_message()`."
853
+ )
854
+
855
+ async def step_async(
856
+ self,
857
+ input_message: Union[BaseMessage, str],
858
+ response_format: Optional[Type[BaseModel]] = None,
859
+ ) -> ChatAgentResponse:
860
+ r"""Performs a single step in the chat session by generating a response
861
+ to the input message. This agent step can call async function calls.
862
+
863
+ Args:
864
+ input_message (Union[BaseMessage, str]): The input message to the
865
+ agent. For BaseMessage input, its `role` field that specifies
866
+ the role at backend may be either `user` or `assistant` but it
867
+ will be set to `user` anyway since for the self agent any
868
+ incoming message is external. For str input, the `role_name`
869
+ would be `User`.
870
+ response_format (Optional[Type[BaseModel]], optional): A pydantic
871
+ model class that includes value types and field descriptions
872
+ used to generate a structured response by LLM. This schema
873
+ helps in defining the expected output format. (default:
874
+ :obj:`None`)
875
+
876
+ Returns:
877
+ ChatAgentResponse: A struct containing the output messages,
878
+ a boolean indicating whether the chat session has terminated,
879
+ and information about the chat session.
880
+ """
881
+ if isinstance(input_message, str):
882
+ input_message = BaseMessage.make_user_message(
883
+ role_name='User', content=input_message
884
+ )
885
+
886
+ self.update_memory(input_message, OpenAIBackendRole.USER)
887
+
888
+ tool_call_records: List[FunctionCallingRecord] = []
889
+ while True:
890
+ try:
891
+ openai_messages, num_tokens = self.memory.get_context()
892
+ except RuntimeError as e:
893
+ return self._step_token_exceed(
894
+ e.args[1], tool_call_records, "max_tokens_exceeded"
895
+ )
896
+
897
+ (
898
+ response,
899
+ output_messages,
900
+ finish_reasons,
901
+ usage_dict,
902
+ response_id,
903
+ ) = self._step_model_response(openai_messages, num_tokens)
904
+
905
+ if (
906
+ not self.is_tools_added()
907
+ or not isinstance(response, ChatCompletion)
908
+ or not response.choices[0].message.tool_calls
909
+ ):
910
+ break
911
+
912
+ # Check for external tool call
913
+ external_tool_request = response.choices[0].message.tool_calls[0]
914
+ if external_tool_request.function.name in self.external_tool_names:
915
+ # if model calls an external tool, directly return the request
916
+ info = self._step_get_info(
917
+ output_messages,
918
+ finish_reasons,
919
+ usage_dict,
920
+ response_id,
921
+ tool_call_records,
922
+ num_tokens,
923
+ external_tool_request,
924
+ )
925
+ return ChatAgentResponse(
926
+ msgs=output_messages, terminated=self.terminated, info=info
927
+ )
928
+
929
+ # Normal function calling
930
+ tool_call_records.append(
931
+ await self._step_tool_call_and_update_async(response)
932
+ )
933
+
934
+ if (
935
+ response_format is not None
936
+ and self.model_type.support_native_tool_calling
937
+ ):
938
+ (
939
+ output_messages,
940
+ finish_reasons,
941
+ usage_dict,
942
+ response_id,
943
+ tool_call_record,
944
+ num_tokens,
945
+ ) = self._structure_output_with_function(response_format)
946
+ tool_call_records.append(tool_call_record)
947
+
948
+ info = self._step_get_info(
949
+ output_messages,
950
+ finish_reasons,
951
+ usage_dict,
952
+ response_id,
953
+ tool_call_records,
954
+ num_tokens,
955
+ )
956
+
957
+ if len(output_messages) == 1:
958
+ # Auto record if the output result is a single message
959
+ self.record_message(output_messages[0])
960
+ else:
961
+ logger.warning(
962
+ "Multiple messages returned in `step()`, message won't be "
963
+ "recorded automatically. Please call `record_message()` to "
964
+ "record the selected message manually."
965
+ )
966
+
967
+ return ChatAgentResponse(
968
+ msgs=output_messages, terminated=self.terminated, info=info
969
+ )
970
+
971
+ def _step_tool_call_and_update(
972
+ self, response: ChatCompletion
973
+ ) -> FunctionCallingRecord:
974
+ r"""Processes a function call within the chat completion response,
975
+ records the function call in the provided list of tool calls and
976
+ updates the memory of the current agent.
977
+
978
+ Args:
979
+ response (ChatCompletion): The response object from the chat
980
+ completion.
981
+
982
+ Returns:
983
+ FunctionCallingRecord: The record of calling the function.
984
+ """
985
+
986
+ # Perform function calling
987
+ func_assistant_msg, func_result_msg, tool_call_record = (
988
+ self._step_tool_call(response)
989
+ )
990
+
991
+ # Update the messages
992
+ self.update_memory(func_assistant_msg, OpenAIBackendRole.ASSISTANT)
993
+ self.update_memory(func_result_msg, OpenAIBackendRole.FUNCTION)
994
+
995
+ return tool_call_record
996
+
997
+ async def _step_tool_call_and_update_async(
998
+ self, response: ChatCompletion
999
+ ) -> FunctionCallingRecord:
1000
+ (
1001
+ func_assistant_msg,
1002
+ func_result_msg,
1003
+ func_record,
1004
+ ) = await self.step_tool_call_async(response)
1005
+
1006
+ self.update_memory(func_assistant_msg, OpenAIBackendRole.ASSISTANT)
1007
+ self.update_memory(func_result_msg, OpenAIBackendRole.FUNCTION)
1008
+
1009
+ return func_record
1010
+
1011
+ def _structure_output_with_function(
1012
+ self, response_format: Type[BaseModel]
1013
+ ) -> Tuple[
1014
+ List[BaseMessage],
1015
+ List[str],
1016
+ Dict[str, int],
1017
+ str,
1018
+ FunctionCallingRecord,
1019
+ int,
1020
+ ]:
1021
+ r"""Internal function of structuring the output of the agent based on
1022
+ the given output schema.
1023
+
1024
+ Args:
1025
+ response_format (Type[BaseModel]): The output schema to use for
1026
+ structuring the output.
1027
+
1028
+ Returns:
1029
+ Tuple[List[BaseMessage], List[str], Dict[str, int], str,
1030
+ FunctionCallingRecord, int]:
1031
+ A tuple containing the output messages, finish reasons, usage
1032
+ dictionary, response ID, function calling record, and number of
1033
+ tokens.
1034
+ """
1035
+ from camel.toolkits import FunctionTool
1036
+
1037
+ schema_json = get_pydantic_object_schema(response_format)
1038
+ func_str = json_to_function_code(schema_json)
1039
+ func_callable = func_string_to_callable(func_str)
1040
+ func = FunctionTool(func_callable)
1041
+
1042
+ original_model_dict = self.model_backend.model_config_dict
1043
+
1044
+ # Replace the original tools with the structuring function
1045
+ self.tool_dict = {func.get_function_name(): func}
1046
+ self.model_backend.model_config_dict = original_model_dict.copy()
1047
+ self.model_backend.model_config_dict["tools"] = [
1048
+ func.get_openai_tool_schema()
1049
+ ]
1050
+ self.model_backend.model_config_dict["tool_choice"] = "required"
1051
+
1052
+ openai_messages, num_tokens = self.memory.get_context()
1053
+ (
1054
+ response,
1055
+ output_messages,
1056
+ finish_reasons,
1057
+ usage_dict,
1058
+ response_id,
1059
+ ) = self._step_model_response(openai_messages, num_tokens)
1060
+
1061
+ if isinstance(response, ChatCompletion):
1062
+ tool_call_record = self._step_tool_call_and_update(response)
1063
+ else:
1064
+ raise ValueError(
1065
+ "Structured output is not supported for stream responses."
1066
+ )
1067
+
1068
+ for base_message_item in output_messages:
1069
+ base_message_item.content = json.dumps(tool_call_record.result)
1070
+
1071
+ # Recover the original tools
1072
+ self.model_backend.model_config_dict = original_model_dict
1073
+
1074
+ return (
1075
+ output_messages,
1076
+ finish_reasons,
1077
+ usage_dict,
1078
+ response_id,
1079
+ tool_call_record,
1080
+ num_tokens,
1081
+ )
1082
+
1083
+ def _step_model_response(
1084
+ self,
1085
+ openai_messages: List[OpenAIMessage],
1086
+ num_tokens: int,
1087
+ ) -> tuple[
1088
+ Union[ChatCompletion, Stream],
1089
+ List[BaseMessage],
1090
+ List[str],
1091
+ Dict[str, int],
1092
+ str,
1093
+ ]:
1094
+ r"""Internal function for agent step model response."""
1095
+
1096
+ response = None
1097
+ # Obtain the model's response
1098
+ for _ in range(len(self.model_backend.models)):
1099
+ try:
1100
+ response = self.model_backend.run(openai_messages)
1101
+ break
1102
+ except Exception as exc:
1103
+ logger.error(
1104
+ f"An error occurred while running model "
1105
+ f"{self.model_backend.model_type}, "
1106
+ f"index: {self.model_backend.current_model_index}",
1107
+ exc_info=exc,
1108
+ )
1109
+ continue
1110
+ if not response:
1111
+ raise ModelProcessingError(
1112
+ "Unable to process messages: none of the provided models "
1113
+ "run succesfully."
1114
+ )
1115
+
1116
+ logger.info(
1117
+ f"Model {self.model_backend.model_type}, "
1118
+ f"index {self.model_backend.current_model_index}, "
1119
+ f"processed these messages: {openai_messages}"
1120
+ )
1121
+
1122
+ if isinstance(response, ChatCompletion):
1123
+ output_messages, finish_reasons, usage_dict, response_id = (
1124
+ self.handle_batch_response(response)
1125
+ )
1126
+ else:
1127
+ output_messages, finish_reasons, usage_dict, response_id = (
1128
+ self.handle_stream_response(response, num_tokens)
1129
+ )
1130
+ return (
1131
+ response,
1132
+ output_messages,
1133
+ finish_reasons,
1134
+ usage_dict,
1135
+ response_id,
1136
+ )
1137
+
1138
+ def _step_get_info(
1139
+ self,
1140
+ output_messages: List[BaseMessage],
1141
+ finish_reasons: List[str],
1142
+ usage_dict: Dict[str, int],
1143
+ response_id: str,
1144
+ tool_calls: List[FunctionCallingRecord],
1145
+ num_tokens: int,
1146
+ external_tool_request: Optional[ChatCompletionMessageToolCall] = None,
1147
+ ) -> Dict[str, Any]:
1148
+ r"""Process the output of a chat step and gather information about the
1149
+ step.
1150
+
1151
+ This method checks for termination conditions, updates the agent's
1152
+ state, and collects information about the chat step, including tool
1153
+ calls and termination reasons.
1154
+
1155
+ Args:
1156
+ output_messages (List[BaseMessage]): The messages generated in
1157
+ this step.
1158
+ finish_reasons (List[str]): The reasons for finishing the
1159
+ generation for each message.
1160
+ usage_dict (Dict[str, int]): Dictionary containing token usage
1161
+ information.
1162
+ response_id (str): The ID of the response from the model.
1163
+ tool_calls (List[FunctionCallingRecord]): Records of function calls
1164
+ made during this step.
1165
+ num_tokens (int): The number of tokens used in this step.
1166
+ external_tool_request (Optional[ChatCompletionMessageToolCall]):
1167
+ Any external tool request made during this step.
1168
+ (default: :obj:`None`)
1169
+
1170
+ Returns:
1171
+ Dict[str, Any]: A dictionary containing information about the chat
1172
+ step, including termination status, reasons, and tool call
1173
+ information.
1174
+
1175
+ Note:
1176
+ This method iterates over all response terminators and checks if
1177
+ any of them signal termination. If a terminator signals
1178
+ termination, the agent's state is updated accordingly, and the
1179
+ termination reason is recorded.
1180
+ """
1181
+ termination = [
1182
+ terminator.is_terminated(output_messages)
1183
+ for terminator in self.response_terminators
1184
+ ]
1185
+ # Terminate the agent if any of the terminator terminates
1186
+ self.terminated, termination_reason = next(
1187
+ (
1188
+ (terminated, termination_reason)
1189
+ for terminated, termination_reason in termination
1190
+ if terminated
1191
+ ),
1192
+ (False, None),
1193
+ )
1194
+ # For now only retain the first termination reason
1195
+ if self.terminated and termination_reason is not None:
1196
+ finish_reasons = [termination_reason] * len(finish_reasons)
1197
+
1198
+ info = self.get_info(
1199
+ response_id,
1200
+ usage_dict,
1201
+ finish_reasons,
1202
+ num_tokens,
1203
+ tool_calls,
1204
+ external_tool_request,
1205
+ )
1206
+ return info
1207
+
1208
+ def handle_batch_response(
1209
+ self, response: ChatCompletion
1210
+ ) -> Tuple[List[BaseMessage], List[str], Dict[str, int], str]:
1211
+ r"""Process a batch response from the model and extract the necessary
1212
+ information.
1213
+
1214
+ Args:
1215
+ response (dict): Model response.
1216
+
1217
+ Returns:
1218
+ tuple: A tuple of list of output `ChatMessage`, list of
1219
+ finish reasons, usage dictionary, and response id.
1220
+ """
1221
+ output_messages: List[BaseMessage] = []
1222
+ for choice in response.choices:
1223
+ chat_message = BaseMessage(
1224
+ role_name=self.role_name,
1225
+ role_type=self.role_type,
1226
+ meta_dict=dict(),
1227
+ content=choice.message.content or "",
1228
+ parsed=getattr(choice.message, 'parsed', None),
1229
+ )
1230
+ # Process log probabilities and append to the message meta information
1231
+ if choice.logprobs is not None:
1232
+ tokens_logprobs = choice.logprobs.content
1233
+
1234
+ if tokens_logprobs is not None:
1235
+ # Extract and structure logprob information
1236
+ logprobs_info = [
1237
+ {
1238
+ "token": token_logprob.token,
1239
+ "logprob": token_logprob.logprob,
1240
+ "top_logprobs": [
1241
+ (top_logprob.token, top_logprob.logprob)
1242
+ for top_logprob in token_logprob.top_logprobs
1243
+ ],
1244
+ }
1245
+ for token_logprob in tokens_logprobs
1246
+ ]
1247
+ # Ensure meta_dict exists before adding logprobs info
1248
+ if chat_message.meta_dict is None:
1249
+ chat_message.meta_dict = {}
1250
+ chat_message.meta_dict["logprobs_info"] = logprobs_info
1251
+ # Append the processed chat message to output
1252
+ output_messages.append(chat_message)
1253
+
1254
+ finish_reasons = [
1255
+ str(choice.finish_reason) for choice in response.choices
1256
+ ]
1257
+ usage = (
1258
+ self._safe_model_dump(response.usage)
1259
+ if response.usage is not None
1260
+ else {}
1261
+ )
1262
+ return (
1263
+ output_messages,
1264
+ finish_reasons,
1265
+ usage,
1266
+ response.id,
1267
+ )
1268
+
1269
+ def _safe_model_dump(self, obj) -> dict:
1270
+ r"""Safely dump a Pydantic model to a dictionary.
1271
+
1272
+ This method attempts to use the `model_dump` method if available,
1273
+ otherwise it falls back to the `dict` method.
1274
+
1275
+ Args:
1276
+ obj: The Pydantic model instance to be dumped.
1277
+
1278
+ Returns:
1279
+ dict: A dictionary representation of the Pydantic model.
1280
+ """
1281
+ # Check if the `model_dump` method exists (Pydantic v2)
1282
+ if hasattr(obj, 'model_dump'):
1283
+ return obj.model_dump()
1284
+ # Fallback to `dict()` method (Pydantic v1)
1285
+ elif hasattr(obj, 'dict'):
1286
+ return obj.dict()
1287
+ else:
1288
+ raise TypeError("The object is not a Pydantic model")
1289
+
1290
+ def handle_stream_response(
1291
+ self,
1292
+ response: Stream[ChatCompletionChunk],
1293
+ prompt_tokens: int,
1294
+ ) -> Tuple[List[BaseMessage], List[str], Dict[str, int], str]:
1295
+ r"""Process a stream response from the model and extract the necessary
1296
+ information.
1297
+
1298
+ Args:
1299
+ response (dict): Model response.
1300
+ prompt_tokens (int): Number of input prompt tokens.
1301
+
1302
+ Returns:
1303
+ tuple: A tuple of list of output `ChatMessage`, list of
1304
+ finish reasons, usage dictionary, and response id.
1305
+ """
1306
+ content_dict: defaultdict = defaultdict(lambda: "")
1307
+ finish_reasons_dict: defaultdict = defaultdict(lambda: "")
1308
+ output_messages: List[BaseMessage] = []
1309
+ response_id: str = ""
1310
+ # All choices in one response share one role
1311
+ for chunk in response:
1312
+ response_id = chunk.id
1313
+ for choice in chunk.choices:
1314
+ index = choice.index
1315
+ delta = choice.delta
1316
+ if delta.content is not None:
1317
+ # When response has not been stopped
1318
+ # Notice that only the first chunk_dict has the "role"
1319
+ content_dict[index] += delta.content
1320
+ if choice.finish_reason:
1321
+ finish_reasons_dict[index] = choice.finish_reason
1322
+ chat_message = BaseMessage(
1323
+ role_name=self.role_name,
1324
+ role_type=self.role_type,
1325
+ meta_dict=dict(),
1326
+ content=content_dict[index],
1327
+ )
1328
+ output_messages.append(chat_message)
1329
+ finish_reasons = [
1330
+ finish_reasons_dict[i] for i in range(len(finish_reasons_dict))
1331
+ ]
1332
+ usage_dict = self.get_usage_dict(output_messages, prompt_tokens)
1333
+ return output_messages, finish_reasons, usage_dict, response_id
1334
+
1335
+ def _step_token_exceed(
1336
+ self,
1337
+ num_tokens: int,
1338
+ tool_calls: List[FunctionCallingRecord],
1339
+ termination_reason: str,
1340
+ ) -> ChatAgentResponse:
1341
+ r"""Return trivial response containing number of tokens and information
1342
+ of called functions when the number of tokens exceeds.
1343
+
1344
+ Args:
1345
+ num_tokens (int): Number of tokens in the messages.
1346
+ tool_calls (List[FunctionCallingRecord]): List of information
1347
+ objects of functions called in the current step.
1348
+ termination_reason (str): String of termination reason.
1349
+
1350
+ Returns:
1351
+ ChatAgentResponse: The struct containing trivial outputs and
1352
+ information about token number and called functions.
1353
+ """
1354
+ self.terminated = True
1355
+ output_messages: List[BaseMessage] = []
1356
+
1357
+ info = self.get_info(
1358
+ None,
1359
+ None,
1360
+ [termination_reason],
1361
+ num_tokens,
1362
+ tool_calls,
1363
+ )
1364
+
1365
+ return ChatAgentResponse(
1366
+ msgs=output_messages,
1367
+ terminated=self.terminated,
1368
+ info=info,
1369
+ )
1370
+
1371
+ def _step_tool_call(
1372
+ self,
1373
+ response: ChatCompletion,
1374
+ ) -> Tuple[
1375
+ FunctionCallingMessage, FunctionCallingMessage, FunctionCallingRecord
1376
+ ]:
1377
+ r"""Execute the function with arguments following the model's response.
1378
+
1379
+ Args:
1380
+ response (Dict[str, Any]): The response obtained by calling the
1381
+ model.
1382
+
1383
+ Returns:
1384
+ tuple: A tuple consisting of two obj:`FunctionCallingMessage`,
1385
+ one about the arguments and the other about the execution
1386
+ result, and a struct for logging information about this
1387
+ function call.
1388
+ """
1389
+ choice = response.choices[0]
1390
+ if choice.message.tool_calls is None:
1391
+ raise RuntimeError("Tool call is None")
1392
+ func_name = choice.message.tool_calls[0].function.name
1393
+
1394
+ arguments_str = choice.message.tool_calls[0].function.arguments
1395
+ args = self._safe_json_loads(arguments_str)
1396
+
1397
+ tool = self.tool_dict[func_name]
1398
+ result = tool(**args)
1399
+ tool_call_id = choice.message.tool_calls[0].id
1400
+
1401
+ assist_msg = FunctionCallingMessage(
1402
+ role_name=self.role_name,
1403
+ role_type=self.role_type,
1404
+ meta_dict=None,
1405
+ content="",
1406
+ func_name=func_name,
1407
+ args=args,
1408
+ tool_call_id=tool_call_id,
1409
+ )
1410
+ func_msg = FunctionCallingMessage(
1411
+ role_name=self.role_name,
1412
+ role_type=self.role_type,
1413
+ meta_dict=None,
1414
+ content="",
1415
+ func_name=func_name,
1416
+ result=result,
1417
+ tool_call_id=tool_call_id,
1418
+ )
1419
+
1420
+ # Record information about this function call
1421
+ func_record = FunctionCallingRecord(
1422
+ func_name=func_name,
1423
+ args=args,
1424
+ result=result,
1425
+ tool_call_id=tool_call_id,
1426
+ )
1427
+ return assist_msg, func_msg, func_record
1428
+
1429
+ def _safe_json_loads(self, arguments_str):
1430
+ # Replace Python types with their JSON equivalents
1431
+ arguments_str = arguments_str.replace("None", "null")
1432
+ arguments_str = arguments_str.replace("True", "true")
1433
+ arguments_str = arguments_str.replace("False", "false")
1434
+
1435
+ # Attempt to parse the corrected string
1436
+ try:
1437
+ return json.loads(arguments_str)
1438
+ except json.JSONDecodeError as e:
1439
+ raise ValueError(f"Invalid JSON format: {e}")
1440
+
1441
+ async def step_tool_call_async(
1442
+ self,
1443
+ response: ChatCompletion,
1444
+ ) -> Tuple[
1445
+ FunctionCallingMessage, FunctionCallingMessage, FunctionCallingRecord
1446
+ ]:
1447
+ r"""Execute the async function with arguments following the model's
1448
+ response.
1449
+
1450
+ Args:
1451
+ response (Dict[str, Any]): The response obtained by calling the
1452
+ model.
1453
+
1454
+ Returns:
1455
+ tuple: A tuple consisting of two obj:`FunctionCallingMessage`,
1456
+ one about the arguments and the other about the execution
1457
+ result, and a struct for logging information about this
1458
+ function call.
1459
+ """
1460
+ # Note that when function calling is enabled, `n` is set to 1.
1461
+ choice = response.choices[0]
1462
+ if choice.message.tool_calls is None:
1463
+ raise RuntimeError("Tool call is None")
1464
+ func_name = choice.message.tool_calls[0].function.name
1465
+
1466
+ args = json.loads(choice.message.tool_calls[0].function.arguments)
1467
+ tool = self.tool_dict[func_name]
1468
+ result = await tool(**args)
1469
+ tool_call_id = choice.message.tool_calls[0].id
1470
+
1471
+ assist_msg = FunctionCallingMessage(
1472
+ role_name=self.role_name,
1473
+ role_type=self.role_type,
1474
+ meta_dict=None,
1475
+ content="",
1476
+ func_name=func_name,
1477
+ args=args,
1478
+ tool_call_id=tool_call_id,
1479
+ )
1480
+ func_msg = FunctionCallingMessage(
1481
+ role_name=self.role_name,
1482
+ role_type=self.role_type,
1483
+ meta_dict=None,
1484
+ content="",
1485
+ func_name=func_name,
1486
+ result=result,
1487
+ tool_call_id=tool_call_id,
1488
+ )
1489
+
1490
+ # Record information about this function call
1491
+ func_record = FunctionCallingRecord(
1492
+ func_name=func_name,
1493
+ args=args,
1494
+ result=result,
1495
+ tool_call_id=tool_call_id,
1496
+ )
1497
+ return assist_msg, func_msg, func_record
1498
+
1499
+ def get_usage_dict(
1500
+ self, output_messages: List[BaseMessage], prompt_tokens: int
1501
+ ) -> Dict[str, int]:
1502
+ r"""Get usage dictionary when using the stream mode.
1503
+
1504
+ Args:
1505
+ output_messages (list): List of output messages.
1506
+ prompt_tokens (int): Number of input prompt tokens.
1507
+
1508
+ Returns:
1509
+ dict: Usage dictionary.
1510
+ """
1511
+ encoding = get_model_encoding(self.model_type.value_for_tiktoken)
1512
+ completion_tokens = 0
1513
+ for message in output_messages:
1514
+ completion_tokens += len(encoding.encode(message.content))
1515
+ usage_dict = dict(
1516
+ completion_tokens=completion_tokens,
1517
+ prompt_tokens=prompt_tokens,
1518
+ total_tokens=completion_tokens + prompt_tokens,
1519
+ )
1520
+ return usage_dict
1521
+
1522
+ def add_model_scheduling_strategy(self, name: str, strategy_fn: Callable):
1523
+ r"""Add a scheduling strategy method provided by user to ModelManger.
1524
+
1525
+ Args:
1526
+ name (str): The name of the strategy.
1527
+ strategy_fn (Callable): The scheduling strategy function.
1528
+ """
1529
+ self.model_backend.add_strategy(name, strategy_fn)
1530
+
1531
+ def __repr__(self) -> str:
1532
+ r"""Returns a string representation of the :obj:`ChatAgent`.
1533
+
1534
+ Returns:
1535
+ str: The string representation of the :obj:`ChatAgent`.
1536
+ """
1537
+ return (
1538
+ f"ChatAgent({self.role_name}, {self.role_type}, {self.model_type})"
1539
+ )
camel/agents/critic_agent.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+ import random
15
+ import warnings
16
+ from typing import Any, Dict, Optional, Sequence
17
+
18
+ from colorama import Fore
19
+
20
+ from camel.agents.chat_agent import ChatAgent
21
+ from camel.memories import AgentMemory
22
+ from camel.messages import BaseMessage
23
+ from camel.models import BaseModelBackend
24
+ from camel.responses import ChatAgentResponse
25
+ from camel.utils import get_first_int, print_text_animated
26
+
27
+ # AgentOps decorator setting
28
+ try:
29
+ import os
30
+
31
+ if os.getenv("AGENTOPS_API_KEY") is not None:
32
+ from agentops import track_agent
33
+ else:
34
+ raise ImportError
35
+ except (ImportError, AttributeError):
36
+ from camel.utils import track_agent
37
+
38
+
39
@track_agent(name="CriticAgent")
class CriticAgent(ChatAgent):
    r"""A class for the critic agent that assists in selecting an option.

    Args:
        system_message (BaseMessage): The system message for the critic
            agent.
        model (BaseModelBackend, optional): The model backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)
        memory (AgentMemory, optional): The agent memory to use.
            (default: :obj:`None`)
        message_window_size (int, optional): The maximum number of previous
            messages to include in the context window. If `None`, no windowing
            is performed. (default: :obj:`6`)
        retry_attempts (int, optional): The number of retry attempts if the
            critic fails to return a valid option. (default: :obj:`2`)
        verbose (bool, optional): Whether to print the critic's messages.
        logger_color (Any): The color of the menu options displayed to the
            user. (default: :obj:`Fore.MAGENTA`)
    """

    def __init__(
        self,
        system_message: BaseMessage,
        model: Optional[BaseModelBackend] = None,
        memory: Optional[AgentMemory] = None,
        message_window_size: int = 6,
        retry_attempts: int = 2,
        verbose: bool = False,
        logger_color: Any = Fore.MAGENTA,
    ) -> None:
        super().__init__(
            system_message,
            model=model,
            memory=memory,
            message_window_size=message_window_size,
        )
        # Maps the option number (as a string) to the option text for the
        # current round of `flatten_options`/`get_option`.
        self.options_dict: Dict[str, str] = dict()
        self.retry_attempts = retry_attempts
        self.verbose = verbose
        self.logger_color = logger_color

    def flatten_options(self, messages: Sequence[BaseMessage]) -> str:
        r"""Flattens the options to the critic.

        Args:
            messages (Sequence[BaseMessage]): A list of `BaseMessage` objects.

        Returns:
            str: A string containing the flattened options to the critic.
        """
        # Fix: rebuild the option map from scratch each round. Previously,
        # entries from earlier rounds lingered, so the advertised choice
        # range over-counted and `get_option` could accept or randomly
        # return a stale option from a previous round.
        self.options_dict.clear()
        options = [message.content for message in messages]
        flatten_options = (
            f"> Proposals from "
            f"{messages[0].role_name} ({messages[0].role_type}). "
            "Please choose an option:\n"
        )
        for index, option in enumerate(options):
            flatten_options += f"Option {index + 1}:\n{option}\n\n"
            self.options_dict[str(index + 1)] = option
        # Renamed from `format` to avoid shadowing the builtin.
        choice_prompt = (
            f"Please first enter your choice ([1-{len(self.options_dict)}]) "
            "and then your explanation and comparison: "
        )
        return flatten_options + choice_prompt

    def get_option(self, input_message: BaseMessage) -> str:
        r"""Gets the option selected by the critic.

        Args:
            input_message (BaseMessage): A `BaseMessage` object representing
                the input message.

        Returns:
            str: The option selected by the critic. Falls back to a random
                option after `retry_attempts` invalid responses.

        Raises:
            RuntimeError: If the critic returns no messages or terminates.
        """
        # TODO: Add support for editing options by the critic.
        msg_content = input_message.content
        i = 0
        while i < self.retry_attempts:
            critic_response = self.step(input_message)

            if critic_response.msgs is None or len(critic_response.msgs) == 0:
                raise RuntimeError("Got None critic messages.")
            if critic_response.terminated:
                raise RuntimeError("Critic step failed.")

            critic_msg = critic_response.msg
            if self.verbose:
                print_text_animated(
                    self.logger_color + "\n> Critic response: "
                    f"\x1b[3m{critic_msg.content}\x1b[0m\n"
                )
            choice = self.parse_critic(critic_msg)

            if choice in self.options_dict:
                return self.options_dict[choice]
            else:
                # Re-prompt with the original menu prefixed by an error note.
                input_message = BaseMessage(
                    role_name=input_message.role_name,
                    role_type=input_message.role_type,
                    meta_dict=input_message.meta_dict,
                    content="> Invalid choice. Please choose again.\n"
                    + msg_content,
                )
            i += 1
        warnings.warn(
            "Critic failed to get a valid option. "
            f"After {self.retry_attempts} attempts. "
            "Returning a random option."
        )
        return random.choice(list(self.options_dict.values()))

    def parse_critic(self, critic_msg: BaseMessage) -> Optional[str]:
        r"""Parses the critic's message and extracts the choice.

        Args:
            critic_msg (BaseMessage): A `BaseMessage` object representing the
                critic's response.

        Returns:
            Optional[str]: The critic's choice as a string, or None if the
                message could not be parsed.
        """
        # The first integer in the reply is taken as the chosen option id.
        choice = str(get_first_int(critic_msg.content))
        return choice

    def reduce_step(
        self,
        input_messages: Sequence[BaseMessage],
    ) -> ChatAgentResponse:
        r"""Performs one step of the conversation by flattening options to the
        critic, getting the option, and parsing the choice.

        Args:
            input_messages (Sequence[BaseMessage]): A list of BaseMessage
                objects.

        Returns:
            ChatAgentResponse: A `ChatAgentResponse` object includes the
                critic's choice.
        """
        # Template message carrying the proposer's identity; cloned below for
        # both the menu sent to the critic and the chosen option returned.
        meta_chat_message = BaseMessage(
            role_name=input_messages[0].role_name,
            role_type=input_messages[0].role_type,
            meta_dict=input_messages[0].meta_dict,
            content="",
        )

        flatten_options = self.flatten_options(input_messages)
        if self.verbose:
            print_text_animated(
                self.logger_color + f"\x1b[3m{flatten_options}\x1b[0m\n"
            )
        input_msg = meta_chat_message.create_new_instance(flatten_options)

        option = self.get_option(input_msg)
        output_msg = meta_chat_message.create_new_instance(option)

        # TODO: The return `info` can be improved.
        return ChatAgentResponse(
            msgs=[output_msg],
            terminated=False,
            info={},
        )
camel/agents/deductive_reasoner_agent.py ADDED
@@ -0,0 +1,303 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+ import re
15
+ from typing import Dict, List, Optional, Union
16
+
17
+ from camel.agents.chat_agent import ChatAgent
18
+ from camel.logger import get_logger
19
+ from camel.messages import BaseMessage
20
+ from camel.models import BaseModelBackend
21
+ from camel.prompts import TextPrompt
22
+ from camel.types import RoleType
23
+
24
+ logger = get_logger(__name__)
25
+
26
+ # AgentOps decorator setting
27
+ try:
28
+ import os
29
+
30
+ if os.getenv("AGENTOPS_API_KEY") is not None:
31
+ from agentops import track_agent
32
+ else:
33
+ raise ImportError
34
+ except (ImportError, AttributeError):
35
+ from camel.utils import track_agent
36
+
37
+
38
@track_agent(name="DeductiveReasonerAgent")
class DeductiveReasonerAgent(ChatAgent):
    r"""An agent responsible for deductive reasoning. Model of deductive
    reasoning:
        - L: A ⊕ C -> q * B
        - A represents the known starting state.
        - B represents the known target state.
        - C represents the conditions required to transition from A to B.
        - Q represents the quality or effectiveness of the transition from
        A to B.
        - L represents the path or process from A to B.

    Args:
        model (BaseModelBackend, optional): The model backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)
    """

    def __init__(
        self,
        model: Optional[BaseModelBackend] = None,
    ) -> None:
        # NOTE(review): role name and content read like they were copied
        # from a role-assignment agent ("You assign roles based on tasks.")
        # rather than a deductive reasoner — confirm upstream intent.
        system_message = BaseMessage(
            role_name="Insight Agent",
            role_type=RoleType.ASSISTANT,
            meta_dict=None,
            content="You assign roles based on tasks.",
        )
        super().__init__(system_message, model=model)

    def deduce_conditions_and_quality(
        self,
        starting_state: str,
        target_state: str,
        role_descriptions_dict: Optional[Dict[str, str]] = None,
    ) -> Dict[str, Union[List[str], Dict[str, str]]]:
        r"""Derives the conditions and quality from the starting state and the
        target state based on the model of the deductive reasoning and the
        knowledge base. It can optionally consider the roles involved in the
        scenario, which allows tailoring the output more closely to the AI
        agent's environment.

        Args:
            starting_state (str): The initial or starting state from which
                conditions are deduced.
            target_state (str): The target state of the task.
            role_descriptions_dict (Optional[Dict[str, str]], optional): The
                descriptions of the roles. (default: :obj:`None`)
            role_descriptions_dict (Optional[Dict[str, str]], optional): A
                dictionary describing the roles involved in the scenario. This
                is optional and can be used to provide a context for the
                CAMEL's role-playing, enabling the generation of more relevant
                and tailored conditions and quality assessments. This could be
                generated using a `RoleAssignmentAgent()` or defined manually
                by the user.

        Returns:
            Dict[str, Union[List[str], Dict[str, str]]]: A dictionary with the
                extracted data from the message. The dictionary contains three
                keys:
                - 'conditions': A list where each key is a condition ID and
                    each value is the corresponding condition text.
                - 'labels': A list of label strings extracted from the message.
                - 'quality': A string of quality assessment strings extracted
                    from the message.

        Raises:
            RuntimeError: If the underlying chat step terminates abnormally.
        """
        # Clear any prior conversation state so each deduction is
        # independent of previous calls.
        self.reset()

        # The answer template below is what the regexes further down key on;
        # any change to the template wording must be mirrored in those
        # patterns.
        deduce_prompt = """You are a deductive reasoner. You are tasked to
complete the TASK based on the THOUGHT OF DEDUCTIVE REASONING, the
STARTING STATE A and the TARGET STATE B. You are given the CONTEXT
CONTENT to help you complete the TASK.
Your answer MUST strictly adhere to the structure of ANSWER TEMPLATE, ONLY
fill in the BLANKs, and DO NOT alter or modify any other part of the template

===== MODELING OF DEDUCTIVE REASONING =====
You are tasked with understanding a mathematical model based on the components
${A, B, C, Q, L}$. In this model: ``L: A ⊕ C -> q * B``.
- $A$ represents the known starting state.
- $B$ represents the known target state.
- $C$ represents the conditions required to transition from $A$ to $B$.
- $Q$ represents the quality or effectiveness of the transition from $A$ to
$B$.
- $L$ represents the path or process from $A$ to $B$.

===== THOUGHT OF DEDUCTIVE REASONING =====
1. Define the Parameters of A and B:
- Characterization: Before delving into transitions, thoroughly understand
the nature and boundaries of both $A$ and $B$. This includes the type,
properties, constraints, and possible interactions between the two.
- Contrast and Compare: Highlight the similarities and differences between
$A$ and $B$. This comparative analysis will give an insight into what
needs changing and what remains constant.
2. Historical & Empirical Analysis:
- Previous Transitions according to the Knowledge Base of GPT: (if
applicable) Extract conditions and patterns from the historical instances
where a similar transition from a state comparable to $A$ moved towards
$B$.
- Scientific Principles: (if applicable) Consider the underlying
scientific principles governing or related to the states and their
transition. For example, if $A$ and $B$ are physical states, laws of
physics might apply.
3. Logical Deduction of Conditions ($C$):
- Direct Path Analysis: What are the immediate and direct conditions
required to move from $A$ to $B$?
- Intermediate States: Are there states between $A$ and $B$ that must be
traversed or can be used to make the transition smoother or more
efficient? If yes, what is the content?
- Constraints & Limitations: Identify potential barriers or restrictions
in moving from $A$ to $B$. These can be external (e.g., environmental
factors) or internal (properties of $A$ or $B$).
- Resource and Information Analysis: What resources and information are
required for the transition? This could be time, entity, factor, code
language, software platform, unknowns, etc.
- External Influences: Consider socio-economic, political, or
environmental factors (if applicable) that could influence the transition
conditions.
- Creative/Heuristic Reasoning: Open your mind to multiple possible $C$'s,
no matter how unconventional they might seem. Utilize analogies,
metaphors, or brainstorming techniques to envision possible conditions or
paths from $A$ to $B$.
- The conditions $C$ should be multiple but in one sentence. And each
condition should be concerned with one aspect/entity.
4. Entity/Label Recognition of Conditions ($C$):
- Identify and categorize entities of Conditions ($C$) such as the names,
locations, dates, specific technical terms or contextual parameters that
might be associated with events, innovations post-2022.
- The output of the entities/labels will be used as tags or labels for
semantic similarity searches. The entities/labels may be the words, or
phrases, each of them should contain valuable, high information entropy
information, and should be independent.
- Ensure that the identified entities are formatted in a manner suitable
for database indexing and retrieval. Organize the entities into
categories, and combine the category with its instance into a continuous
phrase, without using colons or other separators.
- Format these entities for database indexing: output the category rather
than its instance/content into a continuous phrase. For example, instead
of "Jan. 02", identify it as "Event time".
5. Quality Assessment ($Q$):
- Efficiency: How efficient is the transition from $A$ to $B$, which
measures the resources used versus the desired outcome?
- Effectiveness: Did the transition achieve the desired outcome or was the
target state achieved as intended?
- Safety & Risks: Assess any risks associated with the transition and the
measures to mitigate them.
- Feedback Mechanisms: Incorporate feedback loops to continuously monitor
and adjust the quality of transition, making it more adaptive.
6. Iterative Evaluation:
- Test & Refine: Based on the initially deduced conditions and assessed
quality, iterate the process to refine and optimize the transition. This
might involve tweaking conditions, employing different paths, or changing
resources.
- Feedback Integration: Use feedback to make improvements and increase the
quality of the transition.
7. Real-world scenarios often present challenges that may not be captured by
models and frameworks. While using the model, maintain an adaptive mindset:
- Scenario Exploration: Continuously imagine various possible scenarios,
both positive and negative, to prepare for unexpected events.
- Flexibility: Be prepared to modify conditions ($C$) or alter the path/
process ($L$) if unforeseen challenges arise.
- Feedback Integration: Rapidly integrate feedback from actual
implementations to adjust the model's application, ensuring relevancy and
effectiveness.

===== TASK =====
Given the starting state $A$ and the target state $B$, assuming that a path
$L$ always exists between $A$ and $B$, how can one deduce or identify the
necessary conditions $C$ and the quality $Q$ of the transition?

===== STARTING STATE $A$ =====
{starting_state}

===== TARGET STATE $B$ =====
{target_state}

{role_with_description_prompt}
===== ANSWER TEMPLATE =====
- Characterization and comparison of $A$ and $B$:\n<BLANK>
- Historical & Empirical Analysis:\n<BLANK>/None
- Logical Deduction of Conditions ($C$) (multiple conditions can be deduced):
condition <NUM>:
<BLANK>.
- Entity/Label Recognition of Conditions:\n[<BLANK>, <BLANK>, ...] (include
square brackets)
- Quality Assessment ($Q$) (do not use symbols):
<BLANK>.
- Iterative Evaluation:\n<BLANK>/None"""

        if role_descriptions_dict is not None:
            role_names = role_descriptions_dict.keys()
            role_with_description_prompt = (
                "===== ROLES WITH DESCRIPTIONS =====\n"
                + "\n".join(
                    f"{role_name}:\n{role_descriptions_dict[role_name]}\n"
                    for role_name in role_names
                )
                + "\n\n"
            )
        else:
            role_with_description_prompt = ""
        deduce_prompt = TextPrompt(deduce_prompt)

        deduce = deduce_prompt.format(
            starting_state=starting_state,
            target_state=target_state,
            role_with_description_prompt=role_with_description_prompt,
        )

        conditions_and_quality_generation_msg = BaseMessage.make_user_message(
            role_name="Deductive Reasoner", content=deduce
        )

        response = self.step(
            input_message=conditions_and_quality_generation_msg
        )

        if response.terminated:
            raise RuntimeError(
                "Deduction failed. Error:\n" + f"{response.info}"
            )
        msg: BaseMessage = response.msg
        logger.info(f"Message content:\n{msg.content}")

        # Extract the conditions from the message.
        # Each "condition <i>:" paragraph up to the next condition or the
        # "- Entity" section becomes an entry; angle brackets left over from
        # the template placeholders are stripped.
        conditions_dict = {
            f"condition {i}": cdt.replace("<", "")
            .replace(">", "")
            .strip()
            .strip('\n')
            for i, cdt in re.findall(
                r"condition (\d+):\s*(.+?)(?=condition \d+|- Entity)",
                msg.content,
                re.DOTALL,
            )
        }

        # Extract the labels from the message.
        # The template guarantees a single "[...]" list after the
        # Entity/Label heading; [0] therefore raises IndexError if the model
        # deviated from the template.
        labels = [
            label.strip().strip('\n').strip("\"'")
            for label in re.findall(
                r"Entity/Label Recognition of Conditions:\n\[(.+?)\]",
                msg.content,
                re.DOTALL,
            )[0].split(",")
        ]

        # Extract the quality from the message.
        # `next` takes the first quality paragraph; raises StopIteration if
        # the section is missing from the reply.
        quality = next(
            q.strip().strip('\n')
            for q in re.findall(
                r"Quality Assessment \(\$Q\$\) \(do not use symbols\):"
                r"\n(.+?)- Iterative",
                msg.content,
                re.DOTALL,
            )
        )

        # Convert them into JSON format
        conditions_and_quality_json: Dict[
            str, Union[List[str], Dict[str, str]]
        ] = {}
        conditions_and_quality_json["conditions"] = conditions_dict
        conditions_and_quality_json["labels"] = labels
        # NOTE(review): the docstring advertises a 'quality' key but the
        # code emits 'evaluate_quality' — callers depend on the latter.
        conditions_and_quality_json["evaluate_quality"] = quality

        return conditions_and_quality_json
camel/agents/embodied_agent.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+ from typing import Any, List, Optional
15
+
16
+ from colorama import Fore
17
+
18
+ from camel.agents.chat_agent import ChatAgent
19
+ from camel.agents.tool_agents.base import BaseToolAgent
20
+ from camel.interpreters import (
21
+ BaseInterpreter,
22
+ InternalPythonInterpreter,
23
+ SubprocessInterpreter,
24
+ )
25
+ from camel.messages import BaseMessage
26
+ from camel.models import BaseModelBackend
27
+ from camel.responses import ChatAgentResponse
28
+ from camel.utils import print_text_animated
29
+
30
+ # AgentOps decorator setting
31
+ try:
32
+ import os
33
+
34
+ if os.getenv("AGENTOPS_API_KEY") is not None:
35
+ from agentops import track_agent
36
+ else:
37
+ raise ImportError
38
+ except (ImportError, AttributeError):
39
+ from camel.utils import track_agent
40
+
41
+
42
@track_agent(name="EmbodiedAgent")
class EmbodiedAgent(ChatAgent):
    r"""Class for managing conversations of CAMEL Embodied Agents.

    Args:
        system_message (BaseMessage): The system message for the chat agent.
        model (BaseModelBackend, optional): The model backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)
        message_window_size (int, optional): The maximum number of previous
            messages to include in the context window. If `None`, no windowing
            is performed. (default: :obj:`None`)
        tool_agents (List[BaseToolAgent], optional): The tools agents to use in
            the embodied agent. (default: :obj:`None`)
        code_interpreter (BaseInterpreter, optional): The code interpreter to
            execute codes. If `code_interpreter` and `tool_agent` are both
            `None`, default to `SubProcessInterpreter`. If `code_interpreter`
            is `None` and `tool_agents` is not `None`, default to
            `InternalPythonInterpreter`. (default: :obj:`None`)
        verbose (bool, optional): Whether to print the critic's messages.
        logger_color (Any): The color of the logger displayed to the user.
            (default: :obj:`Fore.MAGENTA`)
    """

    def __init__(
        self,
        system_message: BaseMessage,
        model: Optional[BaseModelBackend] = None,
        message_window_size: Optional[int] = None,
        tool_agents: Optional[List[BaseToolAgent]] = None,
        code_interpreter: Optional[BaseInterpreter] = None,
        verbose: bool = False,
        logger_color: Any = Fore.MAGENTA,
    ) -> None:
        self.tool_agents = tool_agents
        self.code_interpreter: BaseInterpreter
        # Interpreter precedence: explicit choice > in-process interpreter
        # (needed so tool agents can live in the action space) > subprocess.
        if code_interpreter is not None:
            self.code_interpreter = code_interpreter
        elif self.tool_agents:
            self.code_interpreter = InternalPythonInterpreter()
        else:
            self.code_interpreter = SubprocessInterpreter()

        if self.tool_agents:
            # Inject the tool descriptions into the system message template.
            system_message = self._set_tool_agents(system_message)
        self.verbose = verbose
        self.logger_color = logger_color
        super().__init__(
            system_message=system_message,
            model=model,
            message_window_size=message_window_size,
        )

    def _set_tool_agents(self, system_message: BaseMessage) -> BaseMessage:
        r"""Formats the system message with the tool-agent action space and
        registers the tool agents with the code interpreter.

        Args:
            system_message (BaseMessage): System message containing an
                ``{action_space}`` placeholder.

        Returns:
            BaseMessage: A new message with the placeholder filled in.
        """
        action_space_prompt = self._get_tool_agents_prompt()
        result_message = system_message.create_new_instance(
            content=system_message.content.format(
                action_space=action_space_prompt
            )
        )
        if self.tool_agents is not None:
            self.code_interpreter.update_action_space(
                {tool.name: tool for tool in self.tool_agents}
            )
        return result_message

    def _get_tool_agents_prompt(self) -> str:
        r"""Returns the action space prompt.

        Returns:
            str: The action space prompt.
        """
        if self.tool_agents is not None:
            return "\n".join(
                [
                    f"*** {tool.name} ***:\n {tool.description}"
                    for tool in self.tool_agents
                ]
            )
        else:
            return ""

    def get_tool_agent_names(self) -> List[str]:
        r"""Returns the names of tool agents.

        Returns:
            List[str]: The names of tool agents.
        """
        if self.tool_agents is not None:
            return [tool.name for tool in self.tool_agents]
        else:
            return []

    # ruff: noqa: E501
    def step(self, input_message: BaseMessage) -> ChatAgentResponse:  # type: ignore[override]
        r"""Performs a step in the conversation.

        Args:
            input_message (BaseMessage): The input message.

        Returns:
            ChatAgentResponse: A struct containing the output messages,
                a boolean indicating whether the chat session has terminated,
                and information about the chat session.

        Raises:
            RuntimeError: If the underlying chat step returns no messages or
                terminates.
        """
        response = super().step(input_message)

        if response.msgs is None or len(response.msgs) == 0:
            raise RuntimeError("Got None output messages.")
        if response.terminated:
            raise RuntimeError(f"{self.__class__.__name__} step failed.")

        # NOTE: Only single output messages are supported
        explanations, codes = response.msg.extract_text_and_code_prompts()

        if self.verbose:
            for explanation, code in zip(explanations, codes):
                print_text_animated(
                    self.logger_color + f"> Explanation:\n{explanation}"
                )
                print_text_animated(self.logger_color + f"> Code:\n{code}")

            if len(explanations) > len(codes):
                print_text_animated(
                    self.logger_color + f"> Explanation:\n{explanations[-1]}"
                )

        content = response.msg.content

        # Fix: `codes` is a list, so the previous `if codes is not None:`
        # was always true and a reply with no code blocks had its text
        # replaced by an empty "Executed Results" header. Only execute (and
        # rewrite `content`) when there is actually code to run.
        if codes:
            try:
                content = "\n> Executed Results:\n"
                for block_idx, code in enumerate(codes):
                    executed_output = self.code_interpreter.run(
                        code, code.code_type
                    )
                    content += (
                        f"Executing code block {block_idx}: {{\n"
                        + executed_output
                        + "}\n"
                    )
            except InterruptedError as e:
                content = (
                    f"\n> Running code fail: {e}\n"
                    "Please regenerate the code."
                )

        # TODO: Handle errors
        content = input_message.content + f"\n> Embodied Actions:\n{content}"
        message = BaseMessage(
            input_message.role_name,
            input_message.role_type,
            input_message.meta_dict,
            content,
        )
        return ChatAgentResponse(
            msgs=[message],
            terminated=response.terminated,
            info=response.info,
        )
camel/agents/knowledge_graph_agent.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+ from typing import TYPE_CHECKING, Optional, Union
15
+
16
+ if TYPE_CHECKING:
17
+ from unstructured.documents.elements import Element
18
+
19
+ from camel.agents import ChatAgent
20
+ from camel.messages import BaseMessage
21
+ from camel.models import BaseModelBackend
22
+ from camel.prompts import TextPrompt
23
+ from camel.storages.graph_storages.graph_element import (
24
+ GraphElement,
25
+ Node,
26
+ Relationship,
27
+ )
28
+ from camel.types import RoleType
29
+
30
+ # AgentOps decorator setting
31
+ try:
32
+ import os
33
+
34
+ if os.getenv("AGENTOPS_API_KEY") is not None:
35
+ from agentops import track_agent
36
+ else:
37
+ raise ImportError
38
+ except (ImportError, AttributeError):
39
+ from camel.utils import track_agent
40
+
41
+
42
# Prompt instructing the LLM to extract entities as ``Node`` objects and the
# relations between them as ``Relationship`` objects, with a worked example.
# ``{task}`` is filled with the stringified input element at run time.
text_prompt = """
You are tasked with extracting nodes and relationships from given content and
structures them into Node and Relationship objects. Here's the outline of what
you needs to do:

Content Extraction:
You should be able to process input content and identify entities mentioned
within it.
Entities can be any noun phrases or concepts that represent distinct entities
in the context of the given content.

Node Extraction:
For each identified entity, you should create a Node object.
Each Node object should have a unique identifier (id) and a type (type).
Additional properties associated with the node can also be extracted and
stored.

Relationship Extraction:
You should identify relationships between entities mentioned in the content.
For each relationship, create a Relationship object.
A Relationship object should have a subject (subj) and an object (obj) which
are Node objects representing the entities involved in the relationship.
Each relationship should also have a type (type), and additional properties if
applicable.

Output Formatting:
The extracted nodes and relationships should be formatted as instances of the
provided Node and Relationship classes.
Ensure that the extracted data adheres to the structure defined by the classes.
Output the structured data in a format that can be easily validated against
the provided code.

Instructions for you:
Read the provided content thoroughly.
Identify distinct entities mentioned in the content and categorize them as
nodes.
Determine relationships between these entities and represent them as directed
relationships.
Provide the extracted nodes and relationships in the specified format below.
Example for you:

Example Content:
"John works at XYZ Corporation. He is a software engineer. The company is
located in New York City."

Expected Output:

Nodes:

Node(id='John', type='Person')
Node(id='XYZ Corporation', type='Organization')
Node(id='New York City', type='Location')

Relationships:

Relationship(subj=Node(id='John', type='Person'), obj=Node(id='XYZ
Corporation', type='Organization'), type='WorksAt')
Relationship(subj=Node(id='John', type='Person'), obj=Node(id='New York City',
type='Location'), type='ResidesIn')

===== TASK =====
Please extracts nodes and relationships from given content and structures them
into Node and Relationship objects.

{task}
"""
108
+
109
+
110
@track_agent(name="KnowledgeGraphAgent")
class KnowledgeGraphAgent(ChatAgent):
    r"""An agent that can extract node and relationship information for
    different entities from given `Element` content.

    Attributes:
        task_prompt (TextPrompt): A prompt for the agent to extract node and
            relationship information for different entities.
    """

    def __init__(
        self,
        model: Optional[BaseModelBackend] = None,
    ) -> None:
        r"""Initialize the `KnowledgeGraphAgent`.

        Args:
            model (BaseModelBackend, optional): The model backend to use for
                generating responses. (default: :obj:`OpenAIModel` with
                `GPT_4O_MINI`)
        """
        system_message = BaseMessage(
            role_name="Graphify",
            role_type=RoleType.ASSISTANT,
            meta_dict=None,
            content="Your mission is to transform unstructured content "
            "into structured graph data. Extract nodes and relationships with "
            "precision, and let the connections unfold. Your graphs will "
            "illuminate the hidden connections within the chaos of "
            "information.",
        )
        super().__init__(system_message, model=model)

    def run(
        self,
        element: "Element",
        parse_graph_elements: bool = False,
    ) -> Union[str, GraphElement]:
        r"""Run the agent to extract node and relationship information.

        Args:
            element (Element): The input element.
            parse_graph_elements (bool, optional): Whether to parse into
                `GraphElement`. Defaults to `False`.

        Returns:
            Union[str, GraphElement]: The extracted node and relationship
                information. If `parse_graph_elements` is `True` then return
                `GraphElement`, else return `str`.
        """
        self.reset()
        # Keep a reference to the source element so the parsed GraphElement
        # can record where its nodes and relationships came from.
        self.element = element

        knowledge_graph_prompt = TextPrompt(text_prompt)
        knowledge_graph_generation = knowledge_graph_prompt.format(
            task=str(element)
        )

        knowledge_graph_generation_msg = BaseMessage.make_user_message(
            role_name="Graphify", content=knowledge_graph_generation
        )

        response = self.step(input_message=knowledge_graph_generation_msg)

        content = response.msg.content

        if parse_graph_elements:
            content = self._parse_graph_elements(content)

        return content

    def _validate_node(self, node: Node) -> bool:
        r"""Validate if the object is a valid Node.

        Args:
            node (Node): Object to be validated.

        Returns:
            bool: True if the object is a valid Node, False otherwise.
        """
        return (
            isinstance(node, Node)
            and isinstance(node.id, (str, int))
            and isinstance(node.type, str)
        )

    def _validate_relationship(self, relationship: Relationship) -> bool:
        r"""Validate if the object is a valid Relationship.

        Args:
            relationship (Relationship): Object to be validated.

        Returns:
            bool: True if the object is a valid Relationship, False otherwise.
        """
        return (
            isinstance(relationship, Relationship)
            and self._validate_node(relationship.subj)
            and self._validate_node(relationship.obj)
            and isinstance(relationship.type, str)
        )

    def _parse_graph_elements(self, input_string: str) -> GraphElement:
        r"""Parses graph elements from given content.

        Args:
            input_string (str): The input content.

        Returns:
            GraphElement: The parsed graph elements.
        """
        import re

        # Regular expressions matching the textual repr format that the
        # prompt instructs the model to emit (single-quoted ids/types only).
        node_pattern = r"Node\(id='(.*?)', type='(.*?)'\)"
        rel_pattern = (
            r"Relationship\(subj=Node\(id='(.*?)', type='(.*?)'\), "
            r"obj=Node\(id='(.*?)', type='(.*?)'\), type='(.*?)'\)"
        )

        nodes = {}
        relationships = []

        # Extract nodes, keeping only the first valid occurrence per id.
        # (Locals renamed from ``id``/``type`` to avoid shadowing builtins.)
        for match in re.finditer(node_pattern, input_string):
            node_id, node_type = match.groups()
            properties = {'source': 'agent_created'}
            if node_id not in nodes:
                node = Node(id=node_id, type=node_type, properties=properties)
                if self._validate_node(node):
                    nodes[node_id] = node

        # Extract relationships; only keep those whose endpoints were
        # successfully parsed and validated as nodes above.
        for match in re.finditer(rel_pattern, input_string):
            subj_id, subj_type, obj_id, obj_type, rel_type = match.groups()
            properties = {'source': 'agent_created'}
            if subj_id in nodes and obj_id in nodes:
                relationship = Relationship(
                    subj=nodes[subj_id],
                    obj=nodes[obj_id],
                    type=rel_type,
                    properties=properties,
                )
                if self._validate_relationship(relationship):
                    relationships.append(relationship)

        return GraphElement(
            nodes=list(nodes.values()),
            relationships=relationships,
            source=self.element,
        )
camel/agents/multi_hop_generator_agent.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+
15
+ import textwrap
16
+ from typing import Any
17
+
18
+ from pydantic import ConfigDict
19
+
20
+ from camel.agents.programmed_agent_instruction import (
21
+ ProgrammableChatAgent,
22
+ ProgrammedAgentInstructionResult,
23
+ programmable_capability,
24
+ )
25
+ from camel.datagen.source2synth.models import (
26
+ ContextPrompt,
27
+ MultiHopQA,
28
+ )
29
+ from camel.messages import BaseMessage
30
+
31
+
32
class MultiHopGeneratorAgent(ProgrammableChatAgent):
    r"""An agent specialized in generating multi-hop question-answer pairs.

    This agent is designed to create complex questions that require multiple
    steps of reasoning to answer. It analyzes context to identify related
    facts and generates questions that require connecting these facts
    logically.

    Attributes:
        model_config (ConfigDict): Configuration for model behavior.
        system_message (BaseMessage): System message defining agent's role and
            instructions.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    def __init__(self, **kwargs: Any) -> None:
        r"""Initialize the MultiHopGeneratorAgent.

        Args:
            **kwargs (Any): Additional keyword arguments to pass to parent
                class.
        """
        super().__init__(**kwargs)

        system_text: str = textwrap.dedent(
            """\
            You are an expert at generating
            multi-hop question-answer pairs.
            For each context, you should:
            1. Identify multiple related facts or pieces of information
            2. Create questions that require reasoning across these multiple pieces
            3. Ensure the reasoning chain is clear and logical
            4. Generate questions that require at least 2-3 steps of reasoning
            5. Include the reasoning steps in the answer

            Give your response with this information:
            Question: [Complex question requiring multiple reasoning steps]
            Reasoning Steps:
            1. [First reasoning step]
            2. [Second reasoning step]
            3. [Final reasoning step]
            Answer: [Final answer]
            Supporting Facts: [List of relevant text segments used]
            """  # noqa: E501
        )
        self.system_message = BaseMessage.make_assistant_message(
            role_name='Assistant', content=system_text
        )

    @programmable_capability
    def generate_multi_hop_qa(
        self, context: str
    ) -> ProgrammedAgentInstructionResult[MultiHopQA]:
        r"""Generate a multi-hop question-answer pair from given context.

        Args:
            context (str): The input text context to generate QA from.

        Returns:
            ProgrammedAgentInstructionResult[MultiHopQA]: Result containing
                the generated question, reasoning steps, answer, and
                supporting facts.

        Raises:
            RuntimeError: If the agent fails to generate a response.
        """
        context_prompt = ContextPrompt(
            main_context=context, related_contexts=None
        )

        user_message = BaseMessage.make_user_message(
            content=context_prompt.model_dump_json(), role_name="User"
        )
        response = self.step(
            input_message=user_message, response_format=MultiHopQA
        )

        # Check for an empty response BEFORE indexing into it. The previous
        # code called ``response.msgs[0]`` first, so an empty message list
        # raised IndexError instead of the intended RuntimeError.
        if not response.msgs:
            raise RuntimeError("No response from agent")

        value = MultiHopQA.model_validate_json(response.msgs[0].content)
        return ProgrammedAgentInstructionResult(
            user_message=user_message,
            agent_message=response.msgs[0],
            value=value,
        )
camel/agents/programmed_agent_instruction.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+ import abc
15
+ import threading
16
+ from enum import Enum
17
+ from functools import wraps
18
+ from typing import Any, Callable, Generic, Optional, TypeVar
19
+
20
+ from pydantic import BaseModel, ConfigDict
21
+
22
+ from camel.agents import ChatAgent
23
+ from camel.messages import BaseMessage
24
+
25
+ T = TypeVar('T')
26
+
27
+
28
class ProgrammableAgentRequirement(Enum):
    r"""State requirements a programmable agent can be asked to satisfy.

    Each member names a condition that ``repair_state`` implementations may
    be asked to restore before an atomic operation proceeds.

    Attributes:
        LAST_MESSAGE_NOT_USER (str): Requires that the last message in the
            conversation was not from the user.
    """

    LAST_MESSAGE_NOT_USER = "LAST_MESSAGE_NOT_USER"
40
+
41
+
42
class ProgrammedAgentInstructionResult(BaseModel, Generic[T]):
    r"""Result of a programmable agent instruction execution.

    Contains the messages exchanged during execution and the computed value.
    The value type is specified by the generic type parameter T.

    Attributes:
        user_message (BaseMessage): The message sent by the user.
        agent_message (BaseMessage): The message sent by the agent.
        value (T): The computed result value of type T.
    """

    # Message the caller sent to trigger the instruction.
    user_message: BaseMessage
    # Message the agent produced in response.
    agent_message: BaseMessage
    # Typed payload parsed/computed from the agent's response.
    value: T

    # BaseMessage is not a pydantic model, so arbitrary field types must be
    # explicitly allowed for validation to accept the message fields.
    model_config = ConfigDict(arbitrary_types_allowed=True)
59
+
60
+
61
class AbstractProgrammableAgent(abc.ABC):
    r"""Abstract class for a programmable agent.

    A programmable agent is an agent that can be programmed to perform a
    specific function or task. This class defines the interface for a
    programmable agent.

    These methods should be implemented in order to ensure the agent supports
    the necessary guarantees to enable a programming interface while
    maintaining compatibility in a multi-agent system.

    A programmable agent is responsible for providing and maintaining a
    programming interface for its functionality.
    """

    @abc.abstractmethod
    def run_atomic(
        self, callback: Callable[[], ProgrammedAgentInstructionResult[T]]
    ) -> ProgrammedAgentInstructionResult[T]:
        r"""Run an atomic operation on the agent.

        An atomic operation is an operation that is guaranteed to
        be executed without interruption by any other operation.

        Args:
            callback (Callable[[], ProgrammedAgentInstructionResult[T]]): The
                operation to execute atomically.

        Returns:
            ProgrammedAgentInstructionResult[T]: The result of the operation.

        Raises:
            RuntimeError: If an operation is already in progress.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def repair_state(self, requirement: ProgrammableAgentRequirement) -> None:
        r"""Repair the state of the agent.

        Agents may have other non-atomic interfaces, such as a user interface,
        or chat between other agents. This method should restore the agent to
        a state where it can perform operations according to the specified
        requirement.

        Args:
            requirement (ProgrammableAgentRequirement): The requirement to
                repair the state for.
        """
        raise NotImplementedError
111
+
112
+
113
def programmable_capability(
    func: Callable[..., ProgrammedAgentInstructionResult[T]],
) -> Callable[..., ProgrammedAgentInstructionResult[T]]:
    r"""Decorate an agent capability so it executes atomically.

    The returned wrapper defers every call to ``self.run_atomic`` so the
    wrapped method runs under the agent's atomicity guarantees.

    Args:
        func (Callable[..., ProgrammedAgentInstructionResult[T]]): The method
            to decorate.

    Returns:
        Callable[..., ProgrammedAgentInstructionResult[T]]: The decorated
            method that ensures atomic execution.
    """

    @wraps(func)
    def atomic_call(
        self, *call_args: Any, **call_kwargs: Any
    ) -> ProgrammedAgentInstructionResult[T]:
        # Bundle the actual invocation into a zero-argument callable so the
        # agent can run it under its own locking scheme.
        def operation() -> ProgrammedAgentInstructionResult[T]:
            return func(self, *call_args, **call_kwargs)

        return self.run_atomic(operation)

    return atomic_call
137
+
138
+
139
class ProgrammableChatAgent(ChatAgent, AbstractProgrammableAgent):
    r"""A chat agent that can be programmed to perform specific tasks.

    Provides a default implementation of atomic execution using threading
    locks and basic state tracking for message roles. Implementing classes
    need to provide specific repair logic for their use cases.

    Attributes:
        _operation_lock (threading.Lock): Lock for ensuring atomic operations.
        _last_message_role (Optional[str]): Role of the last message in the
            conversation.
    """

    def __init__(self, **kwargs: Any) -> None:
        r"""Initialize the ProgrammableChatAgent.

        Args:
            **kwargs (Any): Additional keyword arguments to pass to parent
                class.
        """
        super().__init__(**kwargs)
        self._operation_lock = threading.Lock()
        self._last_message_role: Optional[str] = None

    def run_atomic(
        self, callback: Callable[[], ProgrammedAgentInstructionResult[T]]
    ) -> ProgrammedAgentInstructionResult[T]:
        r"""Run an atomic operation on the agent.

        Ensures thread-safe execution of the callback via a non-blocking
        lock acquisition: a concurrent attempt fails fast instead of queueing.

        Args:
            callback (Callable[[], ProgrammedAgentInstructionResult[T]]): The
                operation to execute atomically.

        Returns:
            ProgrammedAgentInstructionResult[T]: The result of the operation.

        Raises:
            RuntimeError: If an operation is already in progress.
        """
        acquired = self._operation_lock.acquire(blocking=False)
        if not acquired:
            raise RuntimeError("Operation already in progress")

        try:
            outcome = callback()
            # Remember who spoke last so repair_state can reason about it.
            self._last_message_role = outcome.agent_message.role_name
            return outcome
        finally:
            self._operation_lock.release()

    def repair_state(self, requirement: ProgrammableAgentRequirement) -> None:
        r"""Repair the state of the agent.

        Implements basic state repair for message role requirements.

        Args:
            requirement (ProgrammableAgentRequirement): The requirement to
                repair the state for.
        """
        needs_repair = (
            requirement == ProgrammableAgentRequirement.LAST_MESSAGE_NOT_USER
            and self._last_message_role == "user"
        )
        if needs_repair:
            raise NotImplementedError(
                "Must implement repair for LAST_MESSAGE_NOT_USER"
            )
camel/agents/role_assignment_agent.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+ import re
15
+ from typing import Dict, Optional, Union
16
+
17
+ from camel.agents.chat_agent import ChatAgent
18
+ from camel.messages import BaseMessage
19
+ from camel.models import BaseModelBackend
20
+ from camel.prompts import TextPrompt
21
+ from camel.types import RoleType
22
+
23
+ # AgentOps decorator setting
24
+ try:
25
+ import os
26
+
27
+ if os.getenv("AGENTOPS_API_KEY") is not None:
28
+ from agentops import track_agent
29
+ else:
30
+ raise ImportError
31
+ except (ImportError, AttributeError):
32
+ from camel.utils import track_agent
33
+
34
+
35
@track_agent(name="RoleAssignmentAgent")
class RoleAssignmentAgent(ChatAgent):
    r"""An agent that generates role names based on the task prompt.

    Args:
        model (BaseModelBackend, optional): The model backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)

    Attributes:
        role_assignment_prompt (TextPrompt): A prompt for the agent to generate
            role names.
    """

    def __init__(
        self,
        model: Optional[BaseModelBackend] = None,
    ) -> None:
        system_message = BaseMessage(
            role_name="Role Assigner",
            role_type=RoleType.ASSISTANT,
            meta_dict=None,
            content="You assign roles based on tasks.",
        )
        super().__init__(system_message, model=model)

    def run(
        self,
        task_prompt: Union[str, TextPrompt],
        num_roles: int = 2,
    ) -> Dict[str, str]:
        r"""Generate role names based on the input task prompt.

        Args:
            task_prompt (Union[str, TextPrompt]): The prompt
                for the task based on which the roles are to be generated.
            num_roles (int, optional): The number of roles to generate.
                (default: :obj:`2`)

        Returns:
            Dict[str, str]: A dictionary mapping role names to their
                descriptions.

        Raises:
            RuntimeError: If the model output does not contain exactly
                ``num_roles`` role names/descriptions, or the chat session
                terminated.
        """
        self.reset()

        expert_prompt = "===== ANSWER PROMPT =====\n" + "\n".join(
            f"Domain expert {i + 1}: <BLANK>\n"
            f"Associated competencies, characteristics, duties "
            f"and workflows: <BLANK>. End."
            for i in range(num_roles or 0)
        )
        role_assignment_generation_prompt = TextPrompt(
            "You are a role assignment agent, and you're in charge of "
            + "recruiting {num_roles} experts for the following task."
            + "\n==== TASK =====\n {task}\n\n"
            + "Identify the domain experts you'd recruit and detail their "
            + "associated competencies, characteristics, duties and workflows "
            + "to complete the task.\n "
            + "Your answer MUST adhere to the format of ANSWER PROMPT, and "
            + "ONLY answer the BLANKs.\n"
            + expert_prompt
        )
        role_assignment_generation = role_assignment_generation_prompt.format(
            num_roles=num_roles, task=task_prompt
        )

        role_assignment_generation_msg = BaseMessage.make_user_message(
            role_name="Role Assigner", content=role_assignment_generation
        )

        response = self.step(input_message=role_assignment_generation_msg)

        msg = response.msg  # type: BaseMessage
        terminated = response.terminated

        # Distribute the output completions into role names and descriptions.
        # ``\d+`` (not ``\d``) so indices >= 10 still match when more than
        # nine roles are requested; the literal dot in "End." is escaped so
        # it no longer matches an arbitrary character.
        role_names = [
            desc.replace("<|", "").replace("|>", "")
            for desc in re.findall(
                r"Domain expert \d+: (.+?)\nAssociated competencies,",
                msg.content,
                re.DOTALL,
            )
        ]
        role_descriptions = [
            desc.replace("<|", "").replace("|>", "")
            for desc in re.findall(
                r"Associated competencies, characteristics, "
                r"duties and workflows: (.+?) End\.",
                msg.content,
                re.DOTALL,
            )
        ]

        if len(role_names) != num_roles or len(role_descriptions) != num_roles:
            raise RuntimeError(
                "Got None or insufficient information of roles."
            )
        if terminated:
            raise RuntimeError("Role assignment failed.")

        role_descriptions_dict = {
            role_name: description
            for role_name, description in zip(role_names, role_descriptions)
        }

        return role_descriptions_dict
camel/agents/search_agent.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+ from typing import Optional
15
+
16
+ from camel.agents.chat_agent import ChatAgent
17
+ from camel.messages import BaseMessage
18
+ from camel.models import BaseModelBackend
19
+ from camel.prompts import TextPrompt
20
+ from camel.types import RoleType
21
+ from camel.utils import create_chunks
22
+
23
+ # AgentOps decorator setting
24
+ try:
25
+ import os
26
+
27
+ if os.getenv("AGENTOPS_API_KEY") is not None:
28
+ from agentops import track_agent
29
+ else:
30
+ raise ImportError
31
+ except (ImportError, AttributeError):
32
+ from camel.utils import track_agent
33
+
34
+
35
@track_agent(name="SearchAgent")
class SearchAgent(ChatAgent):
    r"""An agent that summarizes text based on a query and evaluates the
    relevance of an answer.

    Args:
        model (BaseModelBackend, optional): The model backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)
    """

    def __init__(
        self,
        model: Optional[BaseModelBackend] = None,
    ) -> None:
        assistant_sys_msg = BaseMessage(
            role_name="Assistant",
            role_type=RoleType.ASSISTANT,
            meta_dict=None,
            content="You are a helpful assistant.",
        )
        super().__init__(assistant_sys_msg, model=model)

    def summarize_text(self, text: str, query: str) -> str:
        r"""Summarize the information from the text, base on the query.

        Args:
            text (str): Text to summarize.
            query (str): What information you want.

        Returns:
            str: Strings with information.
        """
        self.reset()

        summary_prompt = TextPrompt(
            '''Gather information from this text that relative to the
            question, but do not directly answer the question.\nquestion:
            {query}\ntext '''
        )
        summary_prompt = summary_prompt.format(query=query)
        # Max length of each chunk
        max_len = 3000
        chunks = create_chunks(text, max_len)

        # Summarize each chunk independently, collecting partial summaries.
        partial_summaries = []
        for idx, chunk in enumerate(chunks, start=1):
            chunk_msg = BaseMessage.make_user_message(
                role_name="User",
                content=summary_prompt + str(idx) + ": " + chunk,
            )
            partial_summaries.append(self.step(chunk_msg).msg.content)
        results = "".join(part + "\n" for part in partial_summaries)

        # Final summarization over the concatenated partial summaries.
        final_prompt = TextPrompt(
            '''Here are some summarized texts which split from one text. Using
            the information to answer the question. If can't find the answer,
            you must answer "I can not find the answer to the query" and
            explain why.\n Query:\n{query}.\n\nText:\n'''
        )
        final_prompt = final_prompt.format(query=query)

        final_msg = BaseMessage.make_user_message(
            role_name="User",
            content=final_prompt + results,
        )
        return self.step(final_msg).msg.content

    def continue_search(self, query: str, answer: str) -> bool:
        r"""Ask whether to continue search or not based on the provided answer.

        Args:
            query (str): The question.
            answer (str): The answer to the question.

        Returns:
            bool: `True` if the user want to continue search, `False`
                otherwise.
        """
        judge_prompt = TextPrompt(
            "Do you think the ANSWER can answer the QUERY? "
            "Use only 'yes' or 'no' to answer.\n"
            "===== QUERY =====\n{query}\n\n"
            "===== ANSWER =====\n{answer}"
        )
        judge_prompt = judge_prompt.format(query=query, answer=answer)
        judge_msg = BaseMessage.make_user_message(
            role_name="User",
            content=judge_prompt,
        )
        verdict = self.step(judge_msg).msg.content
        # "yes" means the answer already covers the query -> stop searching.
        return "yes" not in str(verdict).lower()
camel/agents/task_agent.py ADDED
@@ -0,0 +1,410 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+ from typing import Any, Dict, List, Optional, Union
15
+
16
+ from camel.agents.chat_agent import ChatAgent
17
+ from camel.messages import BaseMessage
18
+ from camel.models import BaseModelBackend
19
+ from camel.prompts import PromptTemplateGenerator, TextPrompt
20
+ from camel.types import RoleType, TaskType
21
+ from camel.utils import get_task_list
22
+
23
+ # AgentOps decorator setting
24
+ try:
25
+ import os
26
+
27
+ if os.getenv("AGENTOPS_API_KEY") is not None:
28
+ from agentops import track_agent
29
+ else:
30
+ raise ImportError
31
+ except (ImportError, AttributeError):
32
+ from camel.utils import track_agent
33
+
34
+
35
@track_agent(name="TaskSpecifyAgent")
class TaskSpecifyAgent(ChatAgent):
    r"""An agent that makes a given task prompt more specific by asking the
    model to elaborate on it.

    Attributes:
        DEFAULT_WORD_LIMIT (int): The default word limit for the task prompt.
        task_specify_prompt (TextPrompt): The prompt for specifying the task.

    Args:
        model (BaseModelBackend, optional): The model backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)
        task_type (TaskType, optional): The type of task for which to generate
            a prompt. (default: :obj:`TaskType.AI_SOCIETY`)
        task_specify_prompt (Union[str, TextPrompt], optional): The prompt for
            specifying the task. (default: :obj:`None`)
        word_limit (int, optional): The word limit for the task prompt.
            (default: :obj:`50`)
        output_language (str, optional): The language to be output by the
            agent. (default: :obj:`None`)
    """

    DEFAULT_WORD_LIMIT = 50

    def __init__(
        self,
        model: Optional[BaseModelBackend] = None,
        task_type: TaskType = TaskType.AI_SOCIETY,
        task_specify_prompt: Optional[Union[str, TextPrompt]] = None,
        word_limit: int = DEFAULT_WORD_LIMIT,
        output_language: Optional[str] = None,
    ) -> None:
        self.task_specify_prompt: Union[str, TextPrompt]
        if task_specify_prompt is not None:
            # Use the caller-supplied prompt verbatim.
            self.task_specify_prompt = TextPrompt(task_specify_prompt)
        else:
            # Fall back to the template registered for this task type and
            # pre-fill the word limit.
            template = PromptTemplateGenerator().get_task_specify_prompt(
                task_type
            )
            self.task_specify_prompt = template.format(word_limit=word_limit)

        super().__init__(
            BaseMessage(
                role_name="Task Specifier",
                role_type=RoleType.ASSISTANT,
                meta_dict=None,
                content="You can make a task more specific.",
            ),
            model=model,
            output_language=output_language,
        )

    def run(
        self,
        task_prompt: Union[str, TextPrompt],
        meta_dict: Optional[Dict[str, Any]] = None,
    ) -> TextPrompt:
        r"""Specify the given task prompt by providing more details.

        Args:
            task_prompt (Union[str, TextPrompt]): The original task
                prompt.
            meta_dict (Dict[str, Any], optional): A dictionary containing
                additional information to include in the prompt.
                (default: :obj:`None`)

        Returns:
            TextPrompt: The specified task prompt.

        Raises:
            RuntimeError: If the agent terminated or produced no message.
        """
        self.reset()
        prompt = self.task_specify_prompt.format(task=task_prompt)

        if meta_dict is not None:
            # Fill any remaining template slots with caller-provided values.
            prompt = prompt.format(**meta_dict)
        response = self.step(
            BaseMessage.make_user_message(
                role_name="Task Specifier", content=prompt
            )
        )

        if response.terminated:
            raise RuntimeError("Task specification failed.")
        if not response.msgs:
            raise RuntimeError("Got no specification message.")

        return TextPrompt(response.msgs[0].content)
128
+
129
+
130
@track_agent(name="TaskPlannerAgent")
class TaskPlannerAgent(ChatAgent):
    r"""An agent that helps divide a task into subtasks based on the input
    task prompt.

    Attributes:
        task_planner_prompt (TextPrompt): A prompt for the agent to divide
            the task into subtasks.

    Args:
        model (BaseModelBackend, optional): The model backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)
        output_language (str, optional): The language to be output by the
            agent. (default: :obj:`None`)
    """

    def __init__(
        self,
        model: Optional[BaseModelBackend] = None,
        output_language: Optional[str] = None,
    ) -> None:
        self.task_planner_prompt = TextPrompt(
            "Divide this task into subtasks: {task}. Be concise."
        )
        super().__init__(
            BaseMessage(
                role_name="Task Planner",
                role_type=RoleType.ASSISTANT,
                meta_dict=None,
                content="You are a helpful task planner.",
            ),
            model=model,
            output_language=output_language,
        )

    def run(
        self,
        task_prompt: Union[str, TextPrompt],
    ) -> TextPrompt:
        r"""Generate subtasks based on the input task prompt.

        Args:
            task_prompt (Union[str, TextPrompt]): The prompt for the task to
                be divided into subtasks.

        Returns:
            TextPrompt: A prompt for the subtasks generated by the agent.

        Raises:
            RuntimeError: If the agent terminated or produced no message.
        """
        # TODO: Maybe include roles information.
        self.reset()
        planner_prompt = self.task_planner_prompt.format(task=task_prompt)

        response = self.step(
            BaseMessage.make_user_message(
                role_name="Task Planner", content=planner_prompt
            )
        )

        if response.terminated:
            raise RuntimeError("Task planning failed.")
        if not response.msgs:
            raise RuntimeError("Got no task planning message.")

        return TextPrompt(response.msgs[0].content)
198
+
199
+
200
@track_agent(name="TaskCreationAgent")
class TaskCreationAgent(ChatAgent):
    r"""An agent that helps create new tasks based on the objective
    and last completed task. Compared to :obj:`TaskPlannerAgent`,
    it's still a task planner, but it has more context information
    like last task and incomplete task list. Modified from
    `BabyAGI <https://github.com/yoheinakajima/babyagi>`_.

    Attributes:
        task_creation_prompt (TextPrompt): A prompt for the agent to
            create new tasks.

    Args:
        role_name (str): The role name of the Agent to create the task.
        objective (Union[str, TextPrompt]): The objective of the Agent to
            perform the task.
        model (BaseModelBackend, optional): The LLM backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)
        output_language (str, optional): The language to be output by the
            agent. (default: :obj:`None`)
        message_window_size (int, optional): The maximum number of previous
            messages to include in the context window. If `None`, no windowing
            is performed. (default: :obj:`None`)
        max_task_num (int, optional): The maximum number of planned
            tasks in one round. (default: :obj:`3`)
    """

    def __init__(
        self,
        role_name: str,
        objective: Union[str, TextPrompt],
        model: Optional[BaseModelBackend] = None,
        output_language: Optional[str] = None,
        message_window_size: Optional[int] = None,
        max_task_num: Optional[int] = 3,
    ) -> None:
        task_creation_prompt = TextPrompt(
            """Create new a task with the following objective: {objective}.
Never forget you are a Task Creator of {role_name}.
You must instruct me based on my expertise and your needs to solve the task.
You should consider past solved tasks and in-progress tasks: {task_list}.
The new created tasks must not overlap with these past tasks.
The result must be a numbered list in the format:

#. First Task
#. Second Task
#. Third Task

You can only give me up to {max_task_num} tasks at a time. \
Each task should be concise, concrete and doable for a {role_name}.
You should make task plan and not ask me questions.
If you think no new tasks are needed right now, write "No tasks to add."
Now start to give me new tasks one by one. No more than three tasks.
Be concrete.
"""
        )

        # Pre-fill the static slots; the {task_list} slot is deliberately
        # left unformatted so `run()` can fill it per call.
        self.task_creation_prompt = task_creation_prompt.format(
            objective=objective, role_name=role_name, max_task_num=max_task_num
        )
        self.objective = objective

        system_message = BaseMessage(
            role_name="Task Creator",
            role_type=RoleType.ASSISTANT,
            meta_dict=None,
            content="You are a helpful task creator.",
        )

        super().__init__(
            system_message,
            model=model,
            output_language=output_language,
            message_window_size=message_window_size,
        )

    def run(
        self,
        task_list: List[str],
    ) -> List[str]:
        r"""Generate subtasks based on the previous task results and
        incomplete task list.

        Args:
            task_list (List[str]): The completed or in-progress
                tasks which should not overlap with new created tasks.

        Returns:
            List[str]: The new task list generated by the Agent.
        """

        # An empty list would render as "[]" in the prompt, so substitute an
        # empty string instead when there is no history yet.
        if len(task_list) > 0:
            task_creation_prompt = self.task_creation_prompt.format(
                task_list=task_list
            )
        else:
            task_creation_prompt = self.task_creation_prompt.format(
                task_list=""
            )

        task_msg = BaseMessage.make_user_message(
            role_name="Task Creator", content=task_creation_prompt
        )
        task_response = self.step(task_msg)

        if task_response.terminated:
            raise RuntimeError("Task creation failed.")
        if len(task_response.msgs) == 0:
            raise RuntimeError("Got no task creation message.")

        sub_tasks_msg = task_response.msgs[0]
        # Parse the model's numbered list into a plain list of task strings.
        return get_task_list(sub_tasks_msg.content)
313
+
314
+
315
@track_agent(name="TaskPrioritizationAgent")
class TaskPrioritizationAgent(ChatAgent):
    r"""An agent that helps re-prioritize the task list and
    returns numbered prioritized list. Modified from
    `BabyAGI <https://github.com/yoheinakajima/babyagi>`_.

    Attributes:
        task_prioritization_prompt (TextPrompt): A prompt for the agent to
            prioritize tasks.

    Args:
        objective (Union[str, TextPrompt]): The objective of the Agent to
            perform the task.
        model (BaseModelBackend, optional): The LLM backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)
        output_language (str, optional): The language to be output by the
            agent. (default: :obj:`None`)
        message_window_size (int, optional): The maximum number of previous
            messages to include in the context window. If `None`, no windowing
            is performed. (default: :obj:`None`)
    """

    def __init__(
        self,
        objective: Union[str, TextPrompt],
        model: Optional[BaseModelBackend] = None,
        output_language: Optional[str] = None,
        message_window_size: Optional[int] = None,
    ) -> None:
        task_prioritization_prompt = TextPrompt(
            """Prioritize the following tasks : {task_list}.
Consider the ultimate objective of you: {objective}.
Tasks should be sorted from highest to lowest priority, where higher-priority \
tasks are those that act as pre-requisites or are more essential for meeting \
the objective. Return one task per line in your response.
Do not remove or modify any tasks.
The result must be a numbered list in the format:

#. First task
#. Second task

The entries must be consecutively numbered, starting with 1.
The number of each entry must be followed by a period.
Do not include any headers before your ranked list or follow your list \
with any other output."""
        )

        # Pre-fill the objective; the {task_list} slot is deliberately left
        # unformatted so `run()` can fill it per call.
        self.task_prioritization_prompt = task_prioritization_prompt.format(
            objective=objective
        )
        self.objective = objective

        system_message = BaseMessage(
            role_name="Task Prioritizer",
            role_type=RoleType.ASSISTANT,
            meta_dict=None,
            content="You are a helpful task prioritizer.",
        )

        super().__init__(
            system_message,
            model=model,
            output_language=output_language,
            message_window_size=message_window_size,
        )

    def run(
        self,
        task_list: List[str],
    ) -> List[str]:
        r"""Prioritize the task list given the agent objective.

        Args:
            task_list (List[str]): The unprioritized tasks of agent.

        Returns:
            List[str]: The new prioritized task list generated by the Agent.
        """
        task_prioritization_prompt = self.task_prioritization_prompt.format(
            task_list=task_list
        )

        task_msg = BaseMessage.make_user_message(
            role_name="Task Prioritizer", content=task_prioritization_prompt
        )

        task_response = self.step(task_msg)

        if task_response.terminated:
            raise RuntimeError("Task prioritization failed.")
        if len(task_response.msgs) == 0:
            raise RuntimeError("Got no task prioritization message.")

        sub_tasks_msg = task_response.msgs[0]
        # Parse the model's numbered list into a plain list of task strings.
        return get_task_list(sub_tasks_msg.content)
camel/agents/tool_agents/__init__.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+ from .base import BaseToolAgent
15
+ from .hugging_face_tool_agent import HuggingFaceToolAgent
16
+
17
+ __all__ = [
18
+ 'BaseToolAgent',
19
+ 'HuggingFaceToolAgent',
20
+ ]
camel/agents/tool_agents/base.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+ from camel.agents import BaseAgent
15
+
16
+
17
class BaseToolAgent(BaseAgent):
    r"""Creates a :obj:`BaseToolAgent` object with the specified name and
    description.

    Args:
        name (str): The name of the tool agent.
        description (str): The description of the tool agent.
    """

    def __init__(self, name: str, description: str) -> None:
        self.name = name
        self.description = description

    def reset(self) -> None:
        r"""Resets the agent to its initial state."""

    def step(self) -> None:
        r"""Performs a single step of the agent."""

    def __str__(self) -> str:
        # Render as "<name>: <description>" for logs and prompts.
        return "{}: {}".format(self.name, self.description)
camel/agents/tool_agents/hugging_face_tool_agent.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+ from typing import Any, Optional
15
+
16
+ from camel.agents.tool_agents.base import BaseToolAgent
17
+
18
+
19
# flake8: noqa :E501
class HuggingFaceToolAgent(BaseToolAgent):
    r"""Tool agent for calling HuggingFace models. This agent is a wrapper
    around agents from the `transformers` library. For more information
    about the available models, please see the `transformers` documentation
    at https://huggingface.co/docs/transformers/transformers_agents.

    Args:
        name (str): The name of the agent.
        *args (Any): Additional positional arguments to pass to the underlying
            Agent class.
        remote (bool, optional): Flag indicating whether to run the agent
            remotely. (default: :obj:`True`)
        **kwargs (Any): Additional keyword arguments to pass to the underlying
            Agent class.

    Raises:
        ValueError: If the `transformers` tool-agent stack is missing or the
            installed `transformers` version is older than 4.31.0.
    """

    def __init__(
        self,
        name: str,
        *args: Any,
        remote: bool = True,
        **kwargs: Any,
    ) -> None:
        try:
            # TODO: Support other tool agents
            import transformers
            from packaging import version

            # The transformers tool-agent API imported below only exists in
            # releases >= 4.31.0, so gate on the installed version.
            if version.parse(transformers.__version__) < version.parse(
                "4.31.0"
            ):
                raise ValueError(
                    "The version of \"transformers\" package should >= 4.31.0"
                )

            from transformers.tools import OpenAiAgent
            from transformers.tools.agent_types import AgentImage
        except (ImportError, ValueError):
            # Both a missing package and a too-old version surface as the
            # same setup error with installation instructions.
            raise ValueError(
                "Could not import transformers tool agents. "
                "Please setup the environment with "
                "pip install huggingface_hub==0.14.1 transformers==4.31.0 diffusers accelerate==0.20.3 datasets torch soundfile sentencepiece opencv-python"
            )
        # Remember the image wrapper type so step()/chat() can unwrap results.
        self.agent_image_type = AgentImage
        self.agent = OpenAiAgent(*args, **kwargs)
        # The description embeds the agent's own name in usage examples via
        # f-string interpolation; it is surfaced to orchestrating agents.
        description = f"""The `{name}` is a tool agent that can perform a variety of tasks including:
- Document question answering: given a document (such as a PDF) in image format, answer a question on this document
- Text question answering: given a long text and a question, answer the question in the text
- Unconditional image captioning: Caption the image!
- Image question answering: given an image, answer a question on this image
- Image segmentation: given an image and a prompt, output the segmentation mask of that prompt
- Speech to text: given an audio recording of a person talking, transcribe the speech into text
- Text to speech: convert text to speech
- Zero-shot text classification: given a text and a list of labels, identify to which label the text corresponds the most
- Text summarization: summarize a long text in one or a few sentences
- Translation: translate the text into a given language
- Text downloading: to download a text from a web URL
- Text to image: generate an image according to a prompt, leveraging stable diffusion
- Image transformation: modify an image given an initial image and a prompt, leveraging instruct pix2pix stable diffusion
- Text to video: generate a small video according to a prompt

Here are some python code examples of what you can do with this agent:

Single execution (step) mode, the single execution method is when using the step() method of the agent:
```
# Text to image
rivers_and_lakes_image = {name}.step("Draw me a picture of rivers and lakes.")
rivers_and_lakes_image.save("./rivers_and_lakes_image.png")

# Text to image -> Image transformation
sea_add_island_image = {name}.step("Draw me a picture of the sea then transform the picture to add an island")
sea_add_island_image.save("./sea_add_island_image.png")

# If you'd like to keep a state across executions or to pass non-text objects to the agent,
# you can do so by specifying variables that you would like the agent to use. For example,
# you could generate the first image of rivers and lakes, and ask the model to update that picture to add an island by doing the following:
picture = {name}.step("Generate a picture of rivers and lakes.")
picture.save("./picture.png")
updated_picture = {name}.step("Transform the image in `picture` to add an island to it.", picture=picture)
updated_picture.save("./updated_picture.png")

capybara_sea_image = {name}.step("Draw me a picture of the `prompt`", prompt="a capybara swimming in the sea")
capybara_sea_image.save("./capybara_sea_image.png")

# Document question answering
answer = {name}.step(
    "In the following `document`, where will the TRRF Scientific Advisory Council Meeting take place?",
    document=document,
)
print(answer)


# Text to image
boat_image = {name}.step("Generate an image of a boat in the water")
boat_image.save("./boat_image.png")

# Unconditional image captioning
boat_image_caption = {name}.step("Can you caption the `boat_image`?", boat_image=boat_image)
print(boat_image_caption)

# Text to image -> Unconditional image captioning -> Text to speech
boat_audio = {name}.step("Can you generate an image of a boat? Please read out loud the contents of the image afterwards")

# Text downloading
document = {name}.step("Download the text from http://hf.co")
print(document)

# Text summarization
summary = {name}.step("Summarize the following text: `document`", document=document)
print(summary)

# Text downloading -> Text summarization -> Text to speech
audio = {name}.step("Read out loud the summary of http://hf.co")
```

Chat-based execution (chat), the agent also has a chat-based approach, using the chat() method:
```
# Clean the chat history
{name}.reset()

# Text to image
capybara_image = {name}.chat("Show me an an image of a capybara")
capybara_image.save("./capybara_image.png")

# Image transformation
transformed_capybara_image = {name}.chat("Transform the image so that it snows")
transformed_capybara_image.save("./transformed_capybara_image.png")

# Image segmentation
segmented_transformed_capybara_image = {name}.chat("Show me a mask of the snowy capybaras")
segmented_transformed_capybara_image.save("./segmented_transformed_capybara_image.png")
```
"""
        super(HuggingFaceToolAgent, self).__init__(name, description)
        self.remote = remote

    def reset(self) -> None:
        r"""Resets the chat history of the agent."""
        self.agent.prepare_for_new_chat()

    def step(
        self,
        *args: Any,
        remote: Optional[bool] = None,
        **kwargs: Any,
    ) -> Any:
        r"""Runs the agent in single execution mode.

        Args:
            *args (Any): Positional arguments to pass to the agent.
            remote (bool, optional): Flag indicating whether to run the agent
                remotely. Overrides the default setting. (default: :obj:`None`)
            **kwargs (Any): Keyword arguments to pass to the agent.

        Returns:
            str: The response from the agent.
        """
        if remote is None:
            remote = self.remote
        agent_output = self.agent.run(*args, remote=remote, **kwargs)
        # Unwrap AgentImage results into their raw image representation.
        if isinstance(agent_output, self.agent_image_type):
            agent_output = agent_output.to_raw()
        return agent_output

    def chat(
        self,
        *args: Any,
        remote: Optional[bool] = None,
        **kwargs: Any,
    ) -> Any:
        r"""Runs the agent in a chat conversation mode.

        Args:
            *args (Any): Positional arguments to pass to the agent.
            remote (bool, optional): Flag indicating whether to run the agent
                remotely. Overrides the default setting. (default: :obj:`None`)
            **kwargs (Any): Keyword arguments to pass to the agent.

        Returns:
            str: The response from the agent.
        """
        if remote is None:
            remote = self.remote
        agent_output = self.agent.chat(*args, remote=remote, **kwargs)
        # Unwrap AgentImage results into their raw image representation.
        if isinstance(agent_output, self.agent_image_type):
            agent_output = agent_output.to_raw()
        return agent_output
camel/benchmarks/__init__.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+
15
+ from .apibank import APIBankBenchmark
16
+ from .apibench import APIBenchBenchmark
17
+ from .base import BaseBenchmark
18
+ from .gaia import DefaultGAIARetriever, GAIABenchmark
19
+ from .nexus import NexusBenchmark
20
+ from .ragbench import RAGBenchBenchmark
21
+
22
+ __all__ = [
23
+ "BaseBenchmark",
24
+ "GAIABenchmark",
25
+ "DefaultGAIARetriever",
26
+ "NexusBenchmark",
27
+ "APIBenchBenchmark",
28
+ "APIBankBenchmark",
29
+ "RAGBenchBenchmark",
30
+ ]
camel/benchmarks/apibank.py ADDED
@@ -0,0 +1,565 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+
15
+ import json
16
+ import logging
17
+ import os
18
+ import random
19
+ import re
20
+ import sys
21
+ from pathlib import Path
22
+ from typing import Any, Dict, List, Literal, Optional
23
+
24
+ import numpy as np
25
+ from rouge import Rouge
26
+ from tqdm import tqdm
27
+
28
+ from camel.agents import ChatAgent
29
+ from camel.benchmarks.base import BaseBenchmark
30
+ from camel.messages import BaseMessage
31
+ from camel.utils import download_github_subdirectory
32
+
33
logger = logging.getLogger(__name__)

# Add current folder to sys.path to enable relative import
# (the API-Bank code downloaded into the working tree expects to be
# importable from the current working directory -- TODO confirm).
current_folder = os.getcwd()
if current_folder not in sys.path:
    sys.path.append(current_folder)
39
+
40
+
41
def process_messages(
    chat_history: List[Dict[str, Any]],
    prompt: str,
) -> List[Dict[str, str]]:
    """Convert an API-Bank chat history into OpenAI-style messages.

    Args:
        chat_history (List[Dict[str, Any]]): Recorded turns; each item
            carries a 'role' of 'User', 'AI', or 'API' plus role-specific
            payload keys ('text' for User/AI; 'api_name', 'param_dict',
            and 'result' for API).
        prompt (str): Text placed in the leading system message.

    Returns:
        List[Dict[str, str]]: Messages with keys:
        - 'role': 'system', 'user', 'assistant', or 'unknown' for
          unrecognized speaker labels.
        - 'content': The turn text, with API turns rendered as
          "[api_name(k='v', ...)] Response: output".
    """
    role_map = {'User': 'user', 'AI': 'assistant', 'API': 'system'}
    messages = [{'role': 'system', 'content': prompt}]
    for entry in chat_history:
        # Unrecognized speaker labels fall back to the role 'unknown'.
        chat_role = role_map.get(entry['role'], 'unknown')
        if entry['role'] == 'API':
            # Render an API call and its result on a single line.
            params = ', '.join(
                f"{key}='{value}'"
                for key, value in entry['param_dict'].items()
            )
            chat_content = '[{}({})] Response: {}'.format(
                entry['api_name'], params, str(entry['result']['output'])
            )
        else:
            chat_content = entry['text']
        messages.append({'role': chat_role, 'content': chat_content})
    return messages
81
+
82
+
83
class APIBankBenchmark(BaseBenchmark):
    r"""API-Bank Benchmark adapted from `API-Bank:
    A Comprehensive Benchmark for Tool-Augmented LLMs`
    <https://github.com/AlibabaResearch/DAMO-ConvAI/tree/main/api-bank>.

    Args:
        save_to (str): The file to save the results.
        processes (int, optional): The number of processes to use.
            (default: :obj:`1`)
    """

    def __init__(
        self,
        save_to: str,
        processes: int = 1,
    ):
        r"""Initialize the APIBank benchmark.

        Args:
            save_to (str): The file to save the results.
            processes (int, optional): The number of processes to use for
                parallel processing. (default: :obj:`1`)
        """
        # Predefine data_dir for better import management
        super().__init__("apibank", "api_bank", save_to, processes)
        self._data: Dict[str, List[APIBankSample]] = dict()  # type: ignore[assignment]

    def download(self):
        r"""Download APIBank dataset and code from Github."""

        repo = "AlibabaResearch/DAMO-ConvAI"
        subdir = "api-bank"
        data_dir = self.data_dir

        download_github_subdirectory(repo, subdir, data_dir)

        # Make the freshly downloaded `api_bank` package importable.
        sys.path.insert(0, self.data_dir)
        logger.info("Download completed.")

    def load(self, level: str, force_download: bool = False):  # type: ignore[override]
        r"""Load the APIBank Benchmark dataset.

        Args:
            level (str): Level to run benchmark on. Must be either
                ``"level-1"`` or ``"level-2"``.
            force_download (bool, optional): Whether to
                force download the data.

        Raises:
            ValueError: If ``level`` is not a supported level name.
        """
        if force_download:
            logger.info("Force downloading data.")
            self.download()

        if level == "level-1":
            file_path = Path("api_bank/lv1-lv2-samples/level-1-given-desc")
        elif level == 'level-2':
            file_path = Path("api_bank/lv1-lv2-samples/level-2-toolsearcher")
        else:
            # Fail fast instead of hitting UnboundLocalError on
            # `file_path` below (the original left it unbound here).
            raise ValueError(
                f"Invalid level: {level!r}. "
                "Expected 'level-1' or 'level-2'."
            )
        jsonl_files = [
            f for f in os.listdir(file_path) if f.endswith('.jsonl')
        ]
        for file in tqdm(jsonl_files, desc="Processing files"):
            history = []
            with open(file_path / file, 'r') as f:
                for line in f:
                    history.append(json.loads(line))
            samples = APIBankSample.from_chat_history(history)
            self._data[file.rsplit('.', 1)[0]] = samples

        # Change import to relative import in the downloaded python files
        def process_files(folder_path, replacements):
            r"""Replace absolute imports in downloaded files with
            relative import."""
            for file in os.listdir(folder_path):
                if file.endswith(".py"):
                    file_path = os.path.join(folder_path, file)
                    try:
                        # `fh` avoids shadowing the loop variable `file`.
                        with open(file_path, "r", encoding="utf-8") as fh:
                            content = fh.read()

                        original_content = content

                        for pattern, replacement in replacements:
                            content = re.sub(pattern, replacement, content)

                        # Only rewrite files whose content changed.
                        if content != original_content:
                            with open(
                                file_path, "w", encoding="utf-8"
                            ) as fh:
                                fh.write(content)
                            logger.info(f"Updated file: {file_path}")

                    except Exception as e:
                        logger.info(f"Error processing file {file_path}: {e}")

        api_bank_folder = "api_bank"
        apis_folder = os.path.join(api_bank_folder, "apis")

        apis_replacements = [
            (r"from apis.api", "from .api"),
            (r"from apis import", "from .api import"),
        ]

        api_bank_replacements = [
            (r"from apis", "from .apis"),
            (r"from api_call_extraction", "from .api_call_extraction"),
            (r"f'{basename}", r"f'api_bank.{basename}"),
        ]

        process_files(apis_folder, apis_replacements)
        process_files(api_bank_folder, api_bank_replacements)

    def run(  # type: ignore[override, return]
        self,
        agent: ChatAgent,
        level: Literal["level-1", "level-2"],
        api_test_enabled: bool = True,
        randomize: bool = False,
        subset: Optional[int] = None,
    ) -> Dict[str, Any]:
        r"""Run the benchmark.

        Args:
            agent (ChatAgent): The agent to run the
                benchmark.
            level (Literal['level-1', 'level-2']):
                The level to run the benchmark on.
            api_test_enabled (bool): Whether to test
                API calling (`True`) or response (`False`).
                (default: :obj:`True`)
            randomize (bool, optional): Whether to
                randomize the data.
            subset (Optional[int], optional):
                The subset of data to run.
                (default: :obj:`None`)

        Returns:
            Dict[str, Any]: The results of the benchmark.
        """
        logger.info(f"Running APIBench benchmark on {level}.")
        self.load(level)
        datas = self._data

        # Shuffle and subset data if necessary
        if randomize:
            randomized_items = list(datas.items())
            random.shuffle(randomized_items)
            datas = dict(randomized_items)
        if subset:
            datas = dict(list(datas.items())[:subset])

        logger.info(f"Number of tasks: {len(datas)}")

        # Initialize results storage
        self._results = []

        # The following code are adapted from the evaluator
        # from the original repo:
        tool_search_enabled = level == "level-2"
        dialog_test_enabled = not api_test_enabled
        total_api_calls, correct_api_calls, rougel_scores = 0, 0, []

        with open(self.save_to, "w") as f:
            for test in tqdm(datas, desc="Running"):
                samples = self._data[test]
                evaluator = Evaluator(samples)  # type: ignore[arg-type]

                for sample_id in evaluator.get_all_sample_ids():
                    # Process sample and generate response
                    sample = evaluator.dataset[sample_id]

                    if (
                        sample.ground_truth['role'] == 'API'
                        and api_test_enabled
                    ):
                        if tool_search_enabled:
                            _, chat_history = evaluator.get_model_input(
                                sample_id
                            )
                            api_descriptions = evaluator.get_api_description(
                                'ToolSearcher'
                            )
                        else:
                            api_descriptions, chat_history = (
                                evaluator.get_model_input(sample_id)
                            )
                        messages = process_messages(
                            chat_history, API_CALL_PROMPT + api_descriptions
                        )
                        model_output = agent_call(messages, agent)
                        api_call = get_api_call(model_output)

                        # Evaluate API call
                        if api_call:
                            try:
                                correct, model_output_result = (
                                    evaluator.evaluate(sample_id, api_call)
                                )
                            except AssertionError as e:
                                if 'The API name is not correct.' not in str(
                                    e
                                ):
                                    raise e
                                logger.info('AssertionError: {}'.format(e))
                                correct = False
                                # Record the failure reason; the original
                                # left `model_output_result` unset here,
                                # leaking a stale value from a previous
                                # iteration (or raising NameError).
                                model_output_result = str(e)
                        else:
                            model_output_result = 'No API call found'
                            correct = False
                        if correct:
                            correct_api_calls += 1
                            logger.info(
                                'Correct API call: {} Ground truth: {}'.format(
                                    api_call, sample.ground_truth
                                )
                            )
                        else:
                            logger.info(
                                'Incorrect model output: {} Result: {} \
                                Ground truth: {} File: {} Sample ID: {} \
                                Messages: {}'.format(
                                    model_output.replace('\n', ' '),
                                    model_output_result,
                                    sample.ground_truth,
                                    test,
                                    sample_id,
                                    messages[1:],
                                )
                            )
                        total_api_calls += 1
                        self._results.append(
                            {
                                'Role': 'API',
                                'Model_output': model_output,
                                'Model_output_result': model_output_result,
                                'Ground_truth': sample.ground_truth,
                                'Test': test,
                                'Correct': correct,
                            }
                        )
                        f.write(json.dumps(self._results[-1], indent=2) + "\n")

                    elif (
                        sample.ground_truth['role'] == 'AI'
                        and dialog_test_enabled
                    ):
                        # Process sample and generate response
                        api_descriptions, chat_history = (
                            evaluator.get_model_input(sample_id)
                        )

                        messages = process_messages(
                            chat_history, RESPONSE_PROMPT + api_descriptions
                        )
                        model_output = agent_call(messages, agent)

                        # Evaluate model response with ROUGE-L
                        if model_output:
                            score = evaluator.evaluate(sample_id, model_output)
                        else:
                            score = 0
                        rougel_scores.append(score)
                        if score < 0.2:
                            logger.info(
                                'Low score: {} Score: {} Ground truth: {} \
                                Test: {} Sample ID: {} \
                                Messages: {}'.format(
                                    model_output.replace('\n', ' '),
                                    score,
                                    sample.ground_truth,
                                    test,
                                    sample_id,
                                    messages[1:],
                                )
                            )

                        self._results.append(
                            {
                                'Role': 'AI',
                                'Model_output': model_output,
                                'Score': score,
                                'Ground_truth': sample.ground_truth,
                                'Test': test,
                            }
                        )
                        f.write(json.dumps(self._results[-1], indent=2) + "\n")

                    # Stream results to disk as they are produced.
                    f.flush()

        if api_test_enabled:
            return {
                'total': total_api_calls,
                'correct': correct_api_calls,
                "accuracy": correct_api_calls / total_api_calls
                if total_api_calls
                else 0,
            }
        elif dialog_test_enabled:
            return {'Dialog_score': np.mean(rougel_scores)}
378
+
379
+
380
+ # The following code are migrated from the original repo:
381
+ # https://github.com/AlibabaResearch/DAMO-ConvAI/tree/main/api-bank
382
def agent_call(messages: List[Dict], agent: ChatAgent):
    r"""Feed a message list into the agent's memory and return its reply.

    Every message except the final one is recorded into the agent's
    memory; the final message is sent via ``agent.step`` and the content
    of the first response message is returned. The agent is reset
    afterwards so consecutive calls do not share state.
    """

    def _to_base_message(entry: Dict) -> BaseMessage:
        # Map an OpenAI-style role dict onto a CAMEL BaseMessage.
        role = entry['role']
        if role == 'user':
            return BaseMessage.make_user_message(
                role_name="CAMEL User", content=entry['content']
            )
        if role == 'assistant':
            return BaseMessage.make_assistant_message(
                role_name="CAMEL Assistant", content=entry['content']
            )
        if role == 'system':
            return BaseMessage.make_assistant_message(
                role_name="System", content=entry['content']
            )
        raise ValueError(f"Unrecognized role: {role}")

    last_index = len(messages) - 1
    for idx, entry in enumerate(messages):
        message = _to_base_message(entry)
        if idx == last_index:
            break
        agent.record_message(message)

    response = agent.step(message)
    model_output = response.msgs[0].content
    agent.reset()
    return model_output
408
+
409
+
410
def calculate_rouge_l_score(reference, hypothesis):
    r"""Return the ROUGE-L F1 score of ``hypothesis`` against ``reference``."""
    scorer = Rouge()
    all_scores = scorer.get_scores(hypothesis, reference)
    return all_scores[0]['rouge-l']['f']
416
+
417
+
418
def get_api_call(model_output):
    r"""Extract the first ``[ApiName(args)]`` call from model output.

    Returns the full bracketed call string, or ``None`` when the output
    contains no recognizable API call.
    """
    pattern = re.compile(r"\[(\w+)\((.*)\)\]")
    found = pattern.search(model_output)
    return found.group(0) if found else None
427
+
428
+
429
class APIBankSample:
    r"""APIBank sample used to load the datasets.

    Attributes:
        chat_history: The dialog turns preceding the turn to predict.
        apis: Set of API names used anywhere in the dialog.
        ground_truth: The turn (API call or AI reply) to be predicted.
    """

    def __init__(self, chat_history, apis, ground_truth):
        self.chat_history = chat_history
        self.apis = apis
        self.ground_truth = ground_truth

    def __repr__(self):
        return 'Sample(chat_history={}, apis={}, ground_truth={})'.format(
            self.chat_history, self.apis, self.ground_truth
        )

    @classmethod
    def from_chat_history(cls, chat_history):
        r"""Build samples from a full dialog.

        For every API turn at position ``i`` up to two samples are
        created: one whose ground truth is the API call itself, and one
        whose ground truth is the turn that follows it (normally the
        AI's reply summarizing the API result).

        Args:
            chat_history (list): Parsed JSONL dialog turns.

        Returns:
            list: The constructed :obj:`APIBankSample` objects.
        """
        apis = set()
        api_positions = []
        for i, item in enumerate(chat_history):
            if item['role'] == 'API':
                apis.add(item['api_name'])
                api_positions.append(i)

        samples = []
        for i in api_positions:
            samples.append(cls(chat_history[:i], apis, chat_history[i]))
            # Guard against a dialog ending on an API turn; the original
            # indexed chat_history[i + 1] unconditionally and raised
            # IndexError in that case.
            if i + 1 < len(chat_history):
                samples.append(
                    cls(chat_history[: i + 1], apis, chat_history[i + 1])
                )

        return samples
459
+
460
+
461
class Evaluator:
    r"""Evaluator for APIBank benchmark.

    Thin wrapper around the downloaded ``api_bank`` tooling: it holds the
    loaded samples and delegates API execution and correctness checking
    to the repo's ``ToolManager``.
    """

    def __init__(self, samples: List[APIBankSample]):
        # Place holder for import as the import
        # only works after the files have been downloaded
        try:
            from api_bank.tool_manager import (  # type: ignore[import-not-found]
                ToolManager,
            )
        except Exception as e:
            # NOTE(review): if this import fails, `ToolManager` stays
            # unbound and the call below raises NameError — presumably
            # `download()` always runs first; confirm before relying on
            # this fallback path.
            logger.info(f"{e}, Module will be imported after download.")
        self.dataset = samples
        # Sample ids are simply indices into `self.dataset`.
        self.sample_ids = list(range(len(self.dataset)))
        # ToolManager resolves its API modules relative to the CWD, so
        # temporarily enter the downloaded folder and return afterwards.
        os.chdir("api_bank")
        self.tool_manager = ToolManager("apis")
        os.chdir("..")

    def get_all_sample_ids(self):
        # All indices of the loaded samples.
        return self.sample_ids

    def get_api_description(self, api_name):
        # Delegates to the downloaded repo's ToolManager.
        return self.tool_manager.get_api_description(api_name)

    def get_model_input(self, sample_id: int):
        # Returns (joined API descriptions, chat history) for one sample.
        sample = self.dataset[sample_id]
        apis = sample.apis
        chat_history = sample.chat_history
        api_descriptions = []
        for api_name in apis:
            api_descriptions.append(
                self.tool_manager.get_api_description(api_name)
            )
        api_description = '\n'.join(api_descriptions)
        return api_description, chat_history

    def evaluate(self, sample_id, model_output):
        # For an 'API' ground truth: execute the predicted call and
        # compare its result with the recorded one -> (bool, detail).
        # For an 'AI' ground truth: score the text with ROUGE-L -> float.
        try:
            from api_bank.api_call_extraction import (  # type: ignore[import-not-found]
                parse_api_call,
            )
        except Exception as e:
            # NOTE(review): same unbound-name caveat as in __init__.
            logger.info(f"{e}, Module will be imported after download.")
        sample = self.dataset[sample_id]
        ground_truth = sample.ground_truth
        if ground_truth['role'] == 'API':
            api_name, param_dict = parse_api_call(model_output)
            if api_name != ground_truth['api_name']:
                return False, 'API Name Mismatch: {} vs {}'.format(
                    api_name, ground_truth['api_name']
                )
            try:
                result = self.tool_manager.api_call(api_name, **param_dict)
            except Exception as e:
                # Execution failure counts as incorrect; report the error.
                return False, str(e)
            api = self.tool_manager.init_tool(api_name)
            try:
                correct = api.check_api_call_correctness(
                    result, ground_truth['result']
                )
            except KeyError:
                correct = False
                result = 'KeyError' + str(result)
            return correct, result
        elif ground_truth['role'] == 'AI':
            score = calculate_rouge_l_score(ground_truth['text'], model_output)
            return round(score, 4)
528
+
529
+
530
# System prompt for the API-call test: instructs the model to emit
# exactly one bracketed call of the form [ApiName(key='value', ...)].
API_CALL_PROMPT = '''
Based on the given API description and the existing \
conversation history 1..t, please generate the API request \
that the AI should call in step t+1 and output it in the \
format of [ApiName(key1='value1', key2='value2', ...)], \
replace the ApiName with the actual API name, and \
replace the key and value with the actual parameters. \
Your output should start with a square bracket "[" \
and end with a square bracket "]". Do not output any \
other explanation or prompt or the result of the API call in your output.
This year is 2023.
Input:
User: [User's utterence]
AI: [AI's utterence]

Expected output:
[ApiName(key1='value1', key2='value2', ...)]

API descriptions:
'''

# System prompt for the dialog-response test: instructs the model to
# produce the AI's next utterance after an API call.
RESPONSE_PROMPT = '''
Based on the given API description and the existing \
conversation history 1..t, please generate the next \
dialog that the AI should response after the API call t.
This year is 2023.
Input:
User: [User's utterence]
AI: [AI's utterence]
[ApiName(key1='value1', key2='value2', …)]

Expected output:
AI: [AI's utterence]

API descriptions:
'''
camel/benchmarks/apibench.py ADDED
@@ -0,0 +1,500 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+
15
+ import json
16
+ import logging
17
+ import random
18
+ from pathlib import Path
19
+ from typing import Any, Dict, Literal, Optional
20
+
21
+ import tree_sitter_python as tspython
22
+ from tqdm import tqdm
23
+ from tree_sitter import Language, Parser
24
+
25
+ from camel.agents import ChatAgent
26
+ from camel.benchmarks.base import BaseBenchmark
27
+ from camel.messages import BaseMessage
28
+ from camel.utils import download_github_subdirectory
29
+
30
+ logger = logging.getLogger(__name__)
31
+
32
+
33
# Mapping of dataset names to the file names holding their API docs,
# evaluation pairs, training data, and oracle questions.
# The 'Oracle' retriever is used here, which means the full
# API documentation will be included in the prompt.
dataset_mapping = {
    "huggingface": {
        "api": "huggingface_api.jsonl",
        "eval": "huggingface_eval.json",
        "train": "huggingface_train.json",
        "questions": "questions_huggingface_oracle.jsonl",
    },
    "tensorflowhub": {
        "api": "tensorflowhub_api.jsonl",
        "eval": "tensorflow_eval.json",
        "train": "tensorflow_train.json",
        "questions": "questions_tensorflowhub_oracle.jsonl",
    },
    "torchhub": {
        "api": "torchhub_api.jsonl",
        "eval": "torchhub_eval.json",
        "train": "torchhub_train.json",
        "questions": "questions_torchhub_oracle.jsonl",
    },
}
56
+
57
+
58
# This function is migrated from the original repo:
# https://github.com/ShishirPatil/gorilla
def encode_question(question: str, dataset_name: str) -> str:
    r"""Encode multiple prompt instructions into a single string.

    Prepends dataset-specific domain constraints and Gorilla's
    output-format instructions to the question.

    Args:
        question (str): The raw user question text.
        dataset_name (str): One of ``"torchhub"``, ``"huggingface"``,
            or ``"tensorflowhub"``.

    Returns:
        str: The fully formatted prompt.
    """

    if dataset_name == "torchhub":
        domains = "1. $DOMAIN is inferred from the task description and \
        should include one of {Classification, Semantic Segmentation, \
        Object Detection, Audio Separation, Video Classification, \
        Text-to-Speech}."
    elif dataset_name == "huggingface":
        domains = "1. $DOMAIN should include one of {Multimodal Feature \
        Extraction, Multimodal Text-to-Image, Multimodal \
        Image-to-Text, Multimodal Text-to-Video, \
        Multimodal Visual Question Answering, Multimodal Document \
        Question Answer, Multimodal Graph Machine Learning, \
        Computer Vision Depth Estimation, Computer Vision Image \
        Classification, Computer Vision Object Detection, \
        Computer Vision Image Segmentation, Computer Vision \
        Image-to-Image, Computer Vision Unconditional \
        Image Generation, Computer Vision Video Classification, \
        Computer Vision Zero-Shor Image Classification, \
        Natural Language Processing Text Classification, \
        Natural Language Processing Token Classification, \
        Natural Language Processing Table Question Answering, \
        Natural Language Processing Question Answering, \
        Natural Language Processing, Zero-Shot Classification \
        Natural Language Processing Translation, Natural Language \
        Processing Summarization, Natural Language Processing \
        Conversational, Natural Language Processing Text \
        Generation, Natural Language Processing Fill-Mask, \
        Natural Language Processing Text2Text Generation, \
        Natural Language Processing Sentence Similarity, \
        Audio Text-to-Speech, Audio Automatic Speech Recognition, \
        Audio Audio-to-Audio, Audio Audio Classification, \
        Audio Voice Activity Detection, Tabular Tabular \
        Classification, Tabular Tabular Regression, \
        Reinforcement Learning Reinforcement Learning, \
        Reinforcement Learning Robotics }"
    elif dataset_name == "tensorflowhub":
        domains = "1. $DOMAIN is inferred from the task description \
        and should include one of {text-sequence-alignment, \
        text-embedding, text-language-model, text-preprocessing, \
        text-classification, text-generation, text-question-answering, \
        text-retrieval-question-answering, text-segmentation, \
        text-to-mel, image-classification, image-feature-vector, \
        image-object-detection, image-segmentation, \
        image-generator, image-pose-detection, image-rnn-agent, \
        image-augmentation, image-classifier, image-style-transfer, \
        image-aesthetic-quality, image-depth-estimation, \
        image-super-resolution, image-deblurring, image-extrapolation, \
        image-text-recognition, image-dehazing, image-deraining, \
        image-enhancemenmt, image-classification-logits, \
        image-frame-interpolation, image-text-detection, image-denoising, \
        image-others, video-classification, video-feature-extraction, \
        video-generation, video-audio-text, video-text, \
        audio-embedding, audio-event-classification, audio-command-detection, \
        audio-paralinguists-classification, audio-speech-to-text, \
        audio-speech-synthesis, audio-synthesis, audio-pitch-extraction}"
    else:
        # NOTE(review): `domains` is left unbound here, so building the
        # prompt below raises NameError for unsupported datasets;
        # callers validate dataset_name first (see APIBenchBenchmark.run).
        logger.info("Error: API name is not supported.")

    prompt = (
        question
        + "\nWrite a python program in 1 to 2 lines to call API in "
        + dataset_name
        + ".\n\nThe answer should follow the format: <<<domain>>> $DOMAIN, \
        <<<api_call>>>: $API_CALL, <<<api_provider>>>: $API_PROVIDER, \
        <<<explanation>>>: $EXPLANATION, <<<code>>>: $CODE}. \
        Here are the requirements:\n"
        + domains
        + "\n2. The $API_CALL should have only 1 line of code \
        that calls api.\n 3. The $API_PROVIDER should be the \
        programming framework used.\n4. $EXPLANATION should be \
        a step-by-step explanation.\n5. The $CODE is the python code.\n6. \
        Do not repeat the format in your answer."
    )
    return prompt
136
+
137
+
138
class APIBenchBenchmark(BaseBenchmark):
    r"""APIBench Benchmark adopted from `Gorilla: Large Language Model
    Connected with Massive APIs`
    <https://huggingface.co/datasets/gorilla-llm/APIBench>.

    Args:
        data_dir (str): The directory to save the data.
        save_to (str): The file to save the results.
        processes (int, optional): The number of processes to use.
            (default: :obj:`1`)
    """

    # TODO: Integrate retriever (pending)

    def __init__(
        self,
        data_dir: str,
        save_to: str,
        processes: int = 1,
    ):
        r"""Initialize the APIBench benchmark.

        Args:
            data_dir (str): The directory to save the data.
            save_to (str): The file to save the results.
            processes (int, optional): The number of processes to use for
                parallel processing. (default: :obj:`1`)
        """
        super().__init__("apibench", data_dir, save_to, processes)

    def download(self):
        r"""Download the APIBench dataset.

        Fetches the dataset snapshot from the Hugging Face Hub and the
        oracle question files from the Gorilla GitHub repository into
        ``self.data_dir``.
        """
        from huggingface_hub import snapshot_download

        snapshot_download(
            repo_id="gorilla-llm/APIBench",
            repo_type="dataset",
            local_dir=self.data_dir,
            local_dir_use_symlinks=True,
        )

        repo = "ShishirPatil/gorilla"
        subdir = "/gorilla/eval/eval-data/questions"
        data_dir = self.data_dir

        download_github_subdirectory(repo, subdir, data_dir)

    def load(self, dataset_name: str, force_download: bool = False):  # type: ignore[override]
        r"""Load the APIBench Benchmark dataset.

        Populates ``self._data`` with the ``'api'``, ``'eval'``,
        ``'questions'`` and pre-parsed ``'ast'`` entries for the
        requested dataset.

        Args:
            dataset_name (str): Name of the specific dataset to be loaded.
            force_download (bool, optional): Whether to force
                download the data. (default: :obj:`False`)
        """

        if force_download:
            logger.info("Force downloading data.")
            self.download()

        def load_json_lines(file_path: Path):
            r"""Helper function to load JSON lines from a file."""
            try:
                with open(file_path, "r") as f:
                    return [json.loads(line) for line in f]
            except FileNotFoundError:
                raise FileNotFoundError(f"File not found: {file_path}")
            except json.JSONDecodeError as e:
                raise ValueError(
                    f"Error decoding JSON in file {file_path}: {e}"
                )

        dataset_path = self.data_dir / dataset_name
        if not dataset_path.exists():
            raise FileNotFoundError(
                f"Dataset directory does not exist: {dataset_path}"
            )

        # Question files live in the per-dataset subdirectory; API and
        # eval files sit directly under data_dir (see dataset_mapping).
        for label in ['api', 'eval', 'questions']:
            file_name = dataset_mapping[dataset_name][label]
            file_path = (
                dataset_path / file_name
                if label == 'questions'
                else self.data_dir / file_name
            )

            # Load data based on label type
            if label in ['api', 'questions', 'eval']:
                data = load_json_lines(file_path)

                if label == 'eval':
                    # Extract 'api_data' specifically for eval label
                    data = [item['api_data'] for item in data]

                self._data[label] = data
            else:
                raise ValueError(f"Unknown label: {label}")

        # Pre-parse every reference API call into a tree-sitter AST once,
        # so run() can match candidate calls against them.
        ast_database = []
        for data in self._data['api']:
            ast_tree = ast_parse(data['api_call'])
            ast_database.append(ast_tree)
        self._data['ast'] = ast_database

    def run(  # type: ignore[override]
        self,
        agent: ChatAgent,
        dataset_name: Literal["huggingface", "tensorflowhub", "torchhub"],
        randomize: bool = False,
        subset: Optional[int] = None,
    ) -> Dict[str, Any]:
        r"""Run the benchmark.

        Args:
            agent (ChatAgent): The agent to run the
                benchmark.
            dataset_name (Literal["huggingface",
                "tensorflowhub", "torchhub"]):
                The dataset to run the benchmark.
            randomize (bool, optional): Whether to randomize the data.
                (default: :obj:`False`)
            subset (Optional[int], optional): The subset of data to run.
                (default: :obj:`None`)

        Returns:
            Dict[str, Any]: Totals, correct/hallucination counts and the
                derived accuracy and hallucination rate.
        """

        if dataset_name not in dataset_mapping:
            raise ValueError(f"Invalid value for dataset: {dataset_name}.")

        logger.info(f"Running APIBench benchmark on {dataset_name}.")
        self.load(dataset_name)
        datas = self._data['questions']

        # Shuffle and subset data if necessary
        if randomize:
            random.shuffle(datas)
        if subset:
            datas = datas[:subset]

        logger.info(f"Number of tasks: {len(datas)}")

        # Initialize results storage
        self._results = []

        with open(self.save_to, "w") as f:
            for question in tqdm(datas, desc="Running"):
                prompt = encode_question(question["text"], dataset_name)
                msg = BaseMessage.make_user_message(
                    role_name="User", content=prompt
                )
                try:
                    # Generate response
                    responses = agent.step(msg)
                    response = responses.msgs[0].content
                    api_database = self._data['api']
                    qa_pairs = self._data['eval']
                    ast_database = self._data['ast']
                    question_id = question['question_id']

                    # Evaluate response
                    error, correct, hallucination = evaluate_response(
                        response,
                        question_id,
                        dataset_name,
                        api_database,
                        qa_pairs,
                        ast_database,
                    )
                    self._results.append(
                        {
                            "question": question,
                            "agent_response": response,
                            "correct": correct,
                            "hallucination": hallucination,
                            "error": str(error) if error else None,
                        }
                    )
                except Exception as e:
                    # Record the failure and continue with the next task.
                    logger.warning(
                        f"Error in processing task: {question}: {e}"
                    )
                    self._results.append(
                        {
                            "question": question,
                            "agent_response": None,
                            "correct": False,
                            "hallucination": False,
                            "error": str(e),
                        }
                    )

                # Reset agent state between questions.
                agent.reset()

                # Stream each result to disk as it is produced.
                f.write(json.dumps(self._results[-1], indent=2) + "\n")
                f.flush()

        total = len(self._results)
        correct = sum(r["correct"] for r in self.results)
        hallucination = sum(r["hallucination"] for r in self.results)

        return {
            "total": total,
            "correct": correct,
            "hallucination": hallucination,
            "accuracy": correct / total if total else "N/A",
            "hallucination rate": hallucination / total if total else "N/A",
        }
344
+
345
+
346
+ # This code is modified from the
347
+ # evaluators in the original repo
348
+ # https://github.com/ShishirPatil/gorilla
349
+ # Get all the subtrees given a root_node
350
def get_all_sub_trees(root_node):
    r"""Collect every subtree reachable from ``root_node``.

    Performs a depth-first walk and returns, for each visited node, a
    list ``[repr, depth, node, first_child_text]`` where
    ``first_child_text`` is ``None`` for leaf nodes. Only children that
    themselves have children are pushed for further expansion, matching
    the original Gorilla evaluator.
    """
    collected = []
    pending = [(root_node, 1)]
    while pending:
        node, level = pending.pop()
        if node.child_count > 0:
            head_text = node.children[0].text
        else:
            head_text = None
        collected.append([str(node), level, node, head_text])
        # Expand only non-leaf children, one level deeper.
        pending.extend(
            (child, level + 1)
            for child in node.children
            if len(child.children) != 0
        )
    return collected
376
+
377
+
378
+ # Parse the program into AST trees
379
def ast_parse(candidate):
    r"""Parse ``candidate`` Python source with tree-sitter.

    Builds a fresh parser per call and returns the ``root_node`` of the
    resulting parse tree.
    """
    language = Language(tspython.language())
    tree = Parser(language).parse(bytes(candidate, "utf8"))
    return tree.root_node
385
+
386
+
387
+ # Get all the arguments in the ast tree
388
def get_args(node, dataset_name):
    r"""Extract argument texts from a parsed API call.

    Walks the argument list of the call expression rooted at ``node``
    and returns the raw byte strings of the arguments relevant to the
    given dataset: keyword values for ``key=value`` arguments, plus
    positional arguments where the dataset expects them.
    """
    if node.child_count == 0:
        return []

    def _arg_nodes():
        # Argument-list children of the call expression.
        return node.children[0].children[0].children[1].children

    _PUNCT = ("(", ")", ",")
    collected = []
    if dataset_name == "huggingface":
        for child in _arg_nodes():
            text = child.text.decode()
            if "=" in text:
                collected.append(child.children[2].text)
            elif text not in _PUNCT:
                collected.append(child.text)
    elif dataset_name == "tensorflowhub":
        for child in _arg_nodes():
            text = child.text.decode()
            if 'model=' in text or 'model =' in text:
                collected.append(child.children[2].text)
            elif text not in _PUNCT:
                collected.append(child.text)
    elif dataset_name == "torchhub":
        for child in _arg_nodes():
            text = child.text.decode()
            if "repo_or_dir" in text or "model" in text:
                collected.append(child.children[2].text)
    return collected
423
+
424
+
425
# Check if there is an api match
def ast_check(candidate_subtree_list, base_tree_list, dataset_name):
    r"""Return the index of the first reference tree matched by the candidate.

    For each reference tree, finds a candidate subtree whose first-child
    text equals the reference API name, then requires every reference
    argument (quotes stripped) to appear in that subtree's text.
    Returns ``-1`` when no reference matches.
    """
    for idx, base_tree in enumerate(base_tree_list):
        if base_tree.children[0].children[0].child_count == 0:
            continue
        api_name = base_tree.children[0].children[0].children[0].text
        for candidate_tree in candidate_subtree_list:
            if candidate_tree[3] == api_name:
                break
        # NOTE(review): no for-else guard — when nothing matched, the
        # loop variable is simply the *last* candidate subtree; kept
        # as in the original Gorilla evaluator.
        # Now we have a sub-tree
        candidate_tree = candidate_tree[2]
        args_list = get_args(base_tree, dataset_name)
        if len(args_list) == 0:
            continue
        ast_match = True
        for arg in args_list:
            if (
                arg.decode().lstrip("'").rstrip("'")
                not in candidate_tree.text.decode()
            ):
                ast_match = False
                break
        if ast_match:
            return idx
    return -1
450
+
451
+
452
def evaluate_response(
    response, question_id, dataset_name, api_database, qa_pairs, ast_database
):
    r"""Evaluate a model response against the reference API database.

    Extracts the ``api_call`` section from the response, parses it into
    an AST, and matches it against the pre-parsed reference ASTs.

    Args:
        response (str): The raw model output.
        question_id (int): 1-based id used to index ``qa_pairs``.
        dataset_name (str): One of the supported dataset names.
        api_database (list): Reference API entries.
        qa_pairs (list): Reference QA pairs (the ``'domain'`` key is used).
        ast_database (list): Pre-parsed reference ASTs.

    Returns:
        tuple: ``(error, correct, hallucination)`` where ``error`` is the
            exception raised during parsing/evaluation, or ``None``.
    """
    try:
        # Index the "api_call" domain
        output = response.split("api_call")
        if len(output) == 1:
            # No explicit section; treat the whole response as the call.
            api_call = output[0]
        else:
            # Parse the text between the section's ':' and its last ')'.
            output = output[1].split("api_provider")[0]
            if ":" not in output:
                start = 0
            else:
                start = output.index(":")
            if ")" not in output:
                end = -2
            else:
                end = output.rindex(")")
            api_call = output[start + 2 : end + 1]

        try:
            ast_tree = ast_parse(api_call)
        except Exception as parse_error:
            print(f"Error parsing api_call: {api_call}, error: {parse_error}")
            return parse_error, False, False
        # Search for a subtree
        ast_subtree_list = get_all_sub_trees(ast_tree)
        # Check which ast tree is matching
        database_index = ast_check(
            ast_subtree_list, ast_database, dataset_name
        )
        # We cannot index this ast in our database: hallucination.
        if database_index == -1:
            # Bug fix: return immediately. The original fell through and
            # indexed api_database[-1], which could silently flip the
            # verdict back to correct / non-hallucinated.
            return None, False, True
        # We index our reference api_call
        ref_api_call = api_database[database_index]
        # Check for functionality: domains must agree with the QA pair.
        if ref_api_call['domain'] == qa_pairs[question_id - 1]['domain']:
            return None, True, False
        else:
            return None, False, False
    except Exception as e:
        print(f'Error parsing response: {response}, error: {e}')
        return e, False, False
camel/benchmarks/base.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+
15
+ import logging
16
+ from abc import ABC, abstractmethod
17
+ from pathlib import Path
18
+ from typing import Any, Dict, List, Literal, Optional
19
+
20
+ from camel.agents import ChatAgent
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
class BaseBenchmark(ABC):
    r"""Common scaffolding shared by all benchmark implementations.

    Concrete subclasses implement :meth:`download`, :meth:`load` and
    :meth:`run`; this base class handles the data directory, lazy loading
    of splits, and result storage.

    Attributes:
        name (str): Name of the benchmark.
        data_dir (str): Path to the data directory.
        save_to (str): Path to save the results.
        processes (int): Number of processes to use for parallel
            processing. (default: :obj:`1`)
    """

    def __init__(
        self, name: str, data_dir: str, save_to: str, processes: int = 1
    ):
        r"""Initialize the benchmark.

        Args:
            name (str): Name of the benchmark.
            data_dir (str): Path to the data directory.
            save_to (str): Path to save the results.
            processes (int): Number of processes to use for parallel
                processing. (default: :obj:`1`)
        """
        self.name = name
        self.save_to = save_to
        self.processes = processes
        self.data_dir = Path(data_dir)
        # Create the data directory on first use rather than failing.
        if not self.data_dir.exists():
            logger.info(
                f"Data directory {data_dir} does not exist. Creating it."
            )
            self.data_dir.mkdir(parents=True, exist_ok=True)
        if not self.data_dir.is_dir():
            raise NotADirectoryError(
                f"Data directory {data_dir} is not a directory"
            )
        self._data: Dict[str, List[Dict[str, Any]]] = {}
        self._results: List[Dict[str, Any]] = []

    @abstractmethod
    def download(self) -> "BaseBenchmark":
        r"""Download the benchmark data.

        Returns:
            BaseBenchmark: The benchmark instance.
        """
        pass

    @abstractmethod
    def load(self, force_download: bool = False) -> "BaseBenchmark":
        r"""Load the benchmark data.

        Args:
            force_download (bool): Whether to force download the data.

        Returns:
            BaseBenchmark: The benchmark instance.
        """
        pass

    def _split(self, key: str) -> List[Dict[str, Any]]:
        # Lazily load the dataset the first time any split is requested.
        if not self._data:
            logger.info("Data not loaded. Loading data.")
            self.load()
        return self._data[key]

    @property
    def train(self) -> List[Dict[str, Any]]:
        r"""Get the training data.

        Returns:
            List[Dict[str, Any]]: The training data.
        """
        return self._split("train")

    @property
    def valid(self) -> List[Dict[str, Any]]:
        r"""Get the validation data.

        Returns:
            List[Dict[str, Any]]: The validation data.
        """
        return self._split("valid")

    @property
    def test(self) -> List[Dict[str, Any]]:
        r"""Get the test data.

        Returns:
            List[Dict[str, Any]]: The test data.
        """
        return self._split("test")

    @abstractmethod
    def run(
        self,
        agent: ChatAgent,
        on: Literal["train", "valid", "test"],
        randomize: bool = False,
        subset: Optional[int] = None,
        *args,
        **kwargs,
    ) -> "BaseBenchmark":
        r"""Run the benchmark.

        Args:
            agent (ChatAgent): The chat agent.
            on (str): The data split to run the benchmark on.
            randomize (bool): Whether to randomize the data.
            subset (int): The subset of the data to run the benchmark on.

        Returns:
            BaseBenchmark: The benchmark instance.
        """
        pass

    @property
    def results(self) -> List[Dict[str, Any]]:
        r"""Get the results.

        Returns:
            List[Dict[str, Any]]: The results.
        """
        return self._results
camel/benchmarks/gaia.py ADDED
@@ -0,0 +1,478 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+
15
+ import json
16
+ import logging
17
+ import os
18
+ import random
19
+ import re
20
+ import string
21
+ import uuid
22
+ from pathlib import Path
23
+ from typing import Any, Dict, List, Literal, Optional, Protocol, Union
24
+
25
+ from tqdm import tqdm
26
+
27
+ from camel.agents import ChatAgent
28
+ from camel.benchmarks.base import BaseBenchmark
29
+ from camel.messages import BaseMessage
30
+ from camel.retrievers.auto_retriever import AutoRetriever
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
class RetrieverProtocol(Protocol):
    r"""Structural interface a retriever must satisfy to be usable by the
    benchmark: any class exposing ``retrieve`` and ``reset`` with these
    signatures is accepted, no inheritance required.
    """

    def retrieve(
        self, query: str, contents: List[str], **kwargs: Dict[str, Any]
    ) -> Dict[str, Any]:
        r"""Fetch the content relevant to the given query.

        Args:
            query (str): The query to retrieve the content for.
            contents (List[str]): The list of contents to search in.
            **kwargs (Dict[str, Any]): Additional keyword arguments.

        Returns:
            Dict[str, Any]: The relevant content for the query.
        """
        ...

    def reset(self, **kwargs) -> bool:
        r"""Reset the retriever's internal state.

        Some benchmarks may require resetting the retriever after each
        query.

        Args:
            **kwargs: Additional keyword arguments.

        Returns:
            bool: True if the reset was successful, False otherwise.
        """
        ...
67
+
68
+
69
class DefaultGAIARetriever(AutoRetriever):
    r"""Retriever used by the GAIA benchmark when none is supplied.

    Delegates retrieval to CAMEL's ``AutoRetriever`` vector retriever and
    keeps a per-task vector store directory.
    """

    def retrieve(
        self, query: str, contents: List[str], **kwargs: Any
    ) -> Dict[str, Any]:
        r"""Retrieve the content based on the query.

        Args:
            query (str): The query to search for.
            contents (List[str]): The list of contents to search from.
            **kwargs (Any): The keyword arguments to pass to the
                retriever.

        Returns:
            Dict[str, Any]: The retrieved content.
        """
        return self.run_vector_retriever(query, contents, **kwargs)  # type: ignore[arg-type]

    def reset(self, **kwargs: Any) -> bool:
        r"""Reset the retriever by pointing it at a task-specific directory.

        Args:
            **kwargs (Any): The keyword arguments to pass to the
                retriever.

        Returns:
            bool: Whether the reset was successful.
        """
        base_path = Path(self.vector_storage_local_path or os.getcwd())
        # One vector-store subdirectory per task; a random UUID is used
        # when no task_id is supplied.
        task_dir = base_path / str(kwargs.get("task_id", uuid.uuid4()))
        if not task_dir.exists():
            try:
                task_dir.mkdir(parents=True)
            except Exception as e:
                logger.error(
                    "Error in creating directory: " + f"{task_dir}: {e!s}"
                )
                return False
        self.vector_storage_local_path = str(task_dir)
        return True
114
+
115
+
116
class GAIABenchmark(BaseBenchmark):
    r"""GAIA Benchmark adapted from `"GAIA: a benchmark for General AI
    Assistants"
    <https://huggingface.co/datasets/gaia-benchmark/GAIA>`_.

    Args:
        data_dir (str): The directory to save the data.
        save_to (str): The file to save the results.
        retriever (Optional[RetrieverProtocol]): The retriever to use.
            (default: :obj:`None`)
        processes (int, optional): The number of processes to use.
            (default: :obj:`1`)
    """

    def __init__(
        self,
        data_dir: str,
        save_to: str,
        retriever: Optional[RetrieverProtocol] = None,
        processes: int = 1,
    ):
        r"""Initialize the GAIA benchmark.

        Args:
            data_dir (str): The directory to save the data.
            save_to (str): The file to save the results.
            retriever (Optional[RetrieverProtocol], optional): The retriever
                to use. (default: :obj:`None`)
            processes (int, optional): The number of processes to use for
                parallel processing. (default: :obj:`1`)
        """
        super().__init__("gaia", data_dir, save_to, processes)
        self.retriever = retriever or DefaultGAIARetriever()

    def download(self):
        r"""Download the GAIA dataset from the Hugging Face Hub."""
        from huggingface_hub import snapshot_download

        snapshot_download(
            repo_id="gaia-benchmark/GAIA",
            repo_type="dataset",
            local_dir=self.data_dir,
            local_dir_use_symlinks=True,
        )

    def load(self, force_download=False):
        r"""Load the GAIA dataset.

        Args:
            force_download (bool, optional): Whether to
                force download the data.
        """
        if force_download:
            logger.info("Force downloading data.")
            self.download()

        # Define validation and test directories
        valid_dir = self.data_dir / "2023/validation"
        test_dir = self.data_dir / "2023/test"

        # Check if directories exist; if not, download the data
        if not valid_dir.is_dir() or not test_dir.is_dir():
            logger.info("Data not found. Downloading data.")
            self.download()

        # Load metadata for both validation and test datasets
        for path, label in zip([valid_dir, test_dir], ["valid", "test"]):
            self._data[label] = []
            with open(path / "metadata.jsonl", "r") as f:
                lines = f.readlines()
                for line in lines:
                    data = json.loads(line)
                    # Skip the placeholder row shipped with the dataset.
                    if data["task_id"] == "0-0-0-0-0":
                        continue
                    if data["file_name"]:
                        # Resolve attachment names relative to the split dir.
                        data["file_name"] = path / data["file_name"]
                    self._data[label].append(data)
        return self

    @property
    def train(self):
        r"""Get the training set."""
        raise NotImplementedError("GAIA does not have a training set.")

    def run(  # type: ignore[override]
        self,
        agent: ChatAgent,
        on: Literal["train", "valid", "test"],
        level: Union[int, List[int], Literal["all"]],
        randomize: bool = False,
        subset: Optional[int] = None,
    ) -> Dict[str, Any]:
        r"""Run the benchmark.

        Args:
            agent (ChatAgent): The agent to run the benchmark.
            on (Literal["valid", "test"]): The set to run the benchmark.
            level (Union[int, List[int], Literal["all"]]): The level to run
                the benchmark.
            randomize (bool, optional): Whether to randomize the data.
                (default: :obj:`False`)
            subset (Optional[int], optional): The subset of data to run.
                (default: :obj:`None`)

        Returns:
            Dict[str, Any]: The results of the benchmark.
        """
        # Validate inputs
        if on not in ["valid", "test"]:
            raise ValueError(
                f"Invalid value for `on`: {on}, expected 'valid' or 'test'."
            )

        levels = (
            [1, 2, 3]
            if level == "all"
            else [level]
            if isinstance(level, int)
            else level
        )
        if not all(
            isinstance(level, int) and level in [1, 2, 3] for level in levels
        ):
            raise ValueError(
                f"Invalid value for `level`: {level}, expected 1, 2, 3 "
                "or 'all'."
            )

        # Bug fix: lazily load the dataset so that calling run() without a
        # prior load() does not raise a KeyError on self._data[on].
        if not self._data:
            self.load()

        logger.info(f"Running benchmark on {on} set at levels {levels}.")
        datas = [data for data in self._data[on] if data["Level"] in levels]

        # Shuffle and subset data if necessary
        if randomize:
            random.shuffle(datas)
        if subset:
            datas = datas[:subset]

        logger.info(f"Number of tasks: {len(datas)}")

        # Initialize results storage
        self._results = []

        # Process tasks
        with open(self.save_to, "w") as f:
            for task in tqdm(datas, desc="Running"):
                if not self._prepare_task(task):
                    continue

                try:
                    result = agent.step(self._create_user_message(task))
                    self._process_result(agent, task, result, f)
                except Exception as e:
                    self._handle_error(task, e, f)
                finally:
                    agent.reset()

        return self._generate_summary()

    def _prepare_task(self, task: Dict[str, Any]) -> bool:
        r"""Prepare the task by validating and enriching its data.

        For tasks with a supported attachment, the retriever is reset and
        the retrieved context is appended to the question text.
        """
        if task["file_name"]:
            file_path = Path(task["file_name"])
            if not file_path.exists():
                logger.info(
                    f"Skipping task because file not found: {file_path}"
                )
                return False
            if file_path.suffix in [".pdf", ".docx", ".doc", ".txt"]:
                if not self.retriever.reset(task_id=task["task_id"]):
                    return False
                retrieved_info = self.retriever.retrieve(
                    query=task["Question"], contents=[task["file_name"]]
                )
                retrieved_content = [
                    item["text"]
                    for item in retrieved_info.get("Retrieved Context", [])
                ]
                if retrieved_content:
                    task["Question"] += "\n" + "\n".join(retrieved_content)
            else:
                logger.info(
                    f"Skipping task due to unsupported file "
                    f"format: {file_path.suffix}"
                )
                return False
        return True

    def _create_user_message(self, task: Dict[str, Any]) -> BaseMessage:
        r"""Create a user message from a task."""
        return BaseMessage.make_user_message(
            role_name="User",
            content=task["Question"],
        )

    def _process_result(
        self,
        agent: ChatAgent,
        task: Dict[str, Any],
        result: Any,
        file_obj: Any,
    ) -> None:
        r"""Process and store the result of a task."""
        model_answer = self.get_final_answer(result.msgs[0].content)
        final_answer = task["Final answer"]
        score = self.question_scorer(model_answer, final_answer)
        tool_calls = result.info.get("tool_calls", [])

        result_data = {
            "task_id": task["task_id"],
            "question": task["Question"],
            "level": task["Level"],
            "model_answer": model_answer,
            "ground_truth": final_answer,
            "tool_calls": [tool.model_dump() for tool in tool_calls],
            "error": None,
            "score": int(score),
            "history": agent.memory.get_context(),
        }
        self._results.append(result_data)
        file_obj.write(json.dumps(result_data, indent=2) + "\n")
        file_obj.flush()

    def _handle_error(
        self, task: Dict[str, Any], error: Exception, file_obj: Any
    ) -> None:
        r"""Handle errors encountered during task processing."""
        logger.warning(f"Error processing task {task['task_id']}: {error}")
        error_data = {
            "task_id": task["task_id"],
            "question": task["Question"],
            "level": task["Level"],
            "model_answer": "ERROR",
            "ground_truth": task["Final answer"],
            "tool_calls": [],
            "error": str(error),
            "score": 0,
        }
        self._results.append(error_data)
        file_obj.write(json.dumps(error_data, indent=2) + "\n")
        file_obj.flush()

    def _generate_summary(self) -> Dict[str, Any]:
        r"""Generate and return a summary of the benchmark results."""
        return {
            "total": len(self._results),
            "correct": sum(result["score"] for result in self._results),
            "results": self._results,
        }

    def question_scorer(self, model_answer: str, ground_truth: str) -> bool:
        r"""Scorer for the GAIA benchmark.
        https://huggingface.co/spaces/gaia-benchmark/leaderboard/blob/main/
        scorer.py

        Args:
            model_answer (str): The model answer.
            ground_truth (str): The ground truth answer.

        Returns:
            bool: The score of the model
        """

        def is_float(element: Any) -> bool:
            try:
                float(element)
                return True
            except ValueError:
                return False

        if is_float(ground_truth):
            logger.info(f"Evaluating {model_answer} as a number.")
            normalized_answer = self.normalize_number_str(model_answer)
            return normalized_answer == float(ground_truth)

        elif any(char in ground_truth for char in [",", ";"]):
            logger.info(
                f"Evaluating {model_answer} as a comma separated list."
            )
            gt_elems = self.split_string(ground_truth)
            ma_elems = self.split_string(model_answer)

            if len(gt_elems) != len(ma_elems):
                # Bug fix: the stray ``UserWarning`` second argument (a
                # leftover from ``warnings.warn``) made logging attempt
                # ``msg % (UserWarning,)`` on a message with no
                # %-placeholders, raising a formatting error.
                logger.warning(
                    "Answer lists have different lengths, returning False."
                )
                return False

            comparisons = []
            for ma_elem, gt_elem in zip(ma_elems, gt_elems):
                if is_float(gt_elem):
                    normalized_ma_elem = self.normalize_number_str(ma_elem)
                    comparisons.append(normalized_ma_elem == float(gt_elem))
                else:
                    ma_elem = self.normalize_str(ma_elem, remove_punct=False)
                    gt_elem = self.normalize_str(gt_elem, remove_punct=False)
                    comparisons.append(ma_elem == gt_elem)
            return all(comparisons)
        else:
            logger.info(f"Evaluating {model_answer} as a string.")
            ma_elem = self.normalize_str(model_answer)
            gt_elem = self.normalize_str(ground_truth)
            return ma_elem == gt_elem

    def normalize_number_str(self, number_str: str) -> float:
        r"""Strip currency/percent/thousands symbols and parse as float.

        Returns ``float("inf")`` when parsing fails so that the comparison
        against the ground truth is simply False (matches the official
        GAIA scorer behavior).
        """
        for char in ["$", "%", ","]:
            number_str = number_str.replace(char, "")
        try:
            return float(number_str)
        except ValueError:
            logger.error(
                f"String {number_str} cannot be normalized to number str."
            )
            return float("inf")

    def split_string(
        self, s: str, char_list: Optional[List[str]] = None
    ) -> list[str]:
        r"""Split a string based on a list of characters.

        Args:
            s (str): The string to split.
            char_list (Optional[List[str]], optional): The list of
                characters to split on.
                (default: :obj:`None`)
        """
        if char_list is None:
            char_list = [",", ";"]
        pattern = f"[{''.join(char_list)}]"
        return re.split(pattern, s)

    def normalize_str(self, input_str, remove_punct=True) -> str:
        r"""Normalize a string.

        Args:
            input_str: The input string to normalize.
            remove_punct: Whether to remove punctuation.

        Returns:
            str: The normalized string.
        """
        no_spaces = re.sub(r"\s", "", input_str)
        if remove_punct:
            translator = str.maketrans("", "", string.punctuation)
            return no_spaces.lower().translate(translator)
        else:
            return no_spaces.lower()

    def get_final_answer(self, content: str) -> str:
        r"""Get the final answer from the content.

        Args:
            content (str): The content to extract the final answer from.

        Returns:
            str: The final answer.
        """
        final_answer_index = content.find("FINAL ANSWER")
        if final_answer_index == -1:
            return "FINAL ANSWER not found"
        start_index = final_answer_index + len("FINAL ANSWER: ")
        final_answer_content = content[start_index:].strip()
        return final_answer_content
camel/benchmarks/nexus.py ADDED
@@ -0,0 +1,518 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+
15
+ import ast
16
+ import json
17
+ import logging
18
+ import os
19
+ import random
20
+ import textwrap
21
+ from dataclasses import dataclass
22
+ from pathlib import Path
23
+ from typing import Any, Dict, List, Literal, Optional, Tuple, Union
24
+
25
+ import pandas as pd
26
+ from datasets import load_dataset
27
+ from tqdm import tqdm
28
+
29
+ from camel.agents import ChatAgent
30
+ from camel.benchmarks.base import BaseBenchmark
31
+ from camel.messages import BaseMessage
32
+
33
+ logger = logging.getLogger(__name__)
34
+
35
+
36
+ # Define the data class
37
@dataclass
class NexusSample:
    r"""A single Nexus benchmark example.

    Attributes:
        input: The natural-language user query.
        output: The ground-truth function call(s) for the query.
    """

    input: str
    output: str
43
+
44
+
45
@dataclass
class NexusTool:
    r"""A tool definition used to build the Nexus prompt.

    Attributes:
        function_calls: The function signature(s) of the tool.
        descriptions: The human-readable description of the tool.
    """

    function_calls: str
    descriptions: str
51
+
52
+
53
# Maps the short task name used throughout this module to the Hugging Face
# dataset repository that hosts it.  Note that both VirusTotal multi-call
# variants share the same underlying repo (Nexusflow/vt_multiapi) and are
# distinguished later by how the samples are filtered.
dataset_mapping = {
    "NVDLibrary": "Nexusflow/NVDLibraryBenchmark",
    "VirusTotal": "Nexusflow/VirusTotalBenchmark",
    "PlacesAPI": "Nexusflow/PlacesAPIBenchmark",
    "ClimateAPI": "Nexusflow/ClimateAPIBenchmark",
    "OTX": "Nexusflow/OTXAPIBenchmark",
    "VirusTotal-NestedCalls": "Nexusflow/vt_multiapi",
    "VirusTotal-ParallelCalls": "Nexusflow/vt_multiapi",
    "NVDLibrary-NestedCalls": "Nexusflow/CVECPEAPIBenchmark",
}

# Prompt template for the function-calling task.  ``{tools}`` is replaced
# with the tool descriptions and ``{input}`` with the user question; the
# trailing-backslash line continuations keep single sentences on one
# logical line inside the triple-quoted string.
TOOL_CALLING_PROMPT = """
You are given multiple functions and a user query.

Please proceed with generating a function call for the function \
with the proper arguments that best answers the given prompt.

Respond with nothing but the function call ONLY, such that I can \
directly execute your function call without any post processing \
necessary from my end. Do not use variables.
If there are more than two function calls, separate them with a semicolon (;).

{tools}

Question: {input}
"""
79
+
80
+
81
class NexusBenchmark(BaseBenchmark):
    r"""Nexus Function Calling Benchmark adapted from `NexusRaven V2
    Function Calling Benchmark`
    <https://huggingface.co/collections/Nexusflow/nexusraven-v2-function-calling-benchmark-657a597fb84dbe7a09ebfc3e>.

    Args:
        data_dir (str): The directory to save the data.
        save_to (str): The file to save the results.
        processes (int, optional): The number of processes to use.
            (default: :obj:`1`)
    """

    def __init__(
        self,
        data_dir: str,
        save_to: str,
        processes: int = 1,
    ):
        r"""Initialize the Nexus Function Calling benchmark.

        Args:
            data_dir (str): The directory to save the data.
            save_to (str): The file to save the results.
            processes (int, optional): The number of processes to use for
                parallel processing. (default: :obj:`1`)
        """
        super().__init__("nexus", data_dir, save_to, processes)
        # Unlike the base class, the data is a flat list of samples.
        self._data: List[NexusSample] = []  # type: ignore[assignment]

    def download(self):
        r"""Download the Nexus Functional Calling Benchmark dataset."""
        from huggingface_hub import snapshot_download

        for dataset_name, repo_id in dataset_mapping.items():
            local_dir = self.data_dir / dataset_name
            snapshot_download(
                repo_id=repo_id,
                repo_type="dataset",
                local_dir=local_dir,
                local_dir_use_symlinks=True,
            )

    def load(self, dataset_name: str, force_download: bool = False):  # type: ignore[override]
        r"""Load the Nexus Benchmark dataset.

        Args:
            dataset_name (str): Name of the specific dataset to be loaded.
            force_download (bool): Whether to force download the data.
        """

        def _load_csv_data(dataset_dir: Path) -> List:
            r"""Load datasets from CSV files."""
            dataset = []
            for file_name in os.listdir(dataset_dir):
                file_path = dataset_dir / file_name
                if file_name.endswith(".csv"):
                    data = pd.read_csv(file_path)
                    for _, sample in data.iterrows():
                        dataset.append(
                            NexusSample(
                                sample["Input"], "".join(sample["Output"])
                            )
                        )
                    continue

                logger.warning(f"Skipping unsupported file: {file_name}")
            return dataset

        def _load_parquet_data(data_dir: Path, dataset_name: str) -> List:
            r"""Load datasets from Parquet files."""
            dataset = []
            if not data_dir.exists():
                raise FileNotFoundError(
                    f"Data directory '{data_dir}' does not exist."
                )

            for file_name in os.listdir(data_dir):
                file_path = data_dir / file_name
                if file_name.endswith(".parquet"):
                    data = pd.read_parquet(file_path)
                    dataset.extend(_process_parquet_data(data, dataset_name))
                    continue

                logger.warning(f"Skipping unsupported file: {file_name}")

            return dataset

        def _process_parquet_data(
            data: pd.DataFrame, dataset_name: str
        ) -> List:
            r"""Process data from Parquet files based on dataset name."""
            dataset: List = []
            # Dispatch table: one sample handler per dataset family.
            dataset_handlers = {
                "NVDLibrary": _process_nvdlibrary,
                "VirusTotal": _process_simple,
                "PlacesAPI": _process_simple,
                "ClimateAPI": _process_simple,
                "OTX": _process_simple,
                "VirusTotal-NestedCalls": _process_nested_calls,
                "VirusTotal-ParallelCalls": _process_parallel_calls,
            }

            if dataset_name not in dataset_handlers:
                logger.warning(
                    f"No specific handler for dataset: {dataset_name}"
                )
                return dataset

            handler = dataset_handlers[dataset_name]
            for _, sample in data.iterrows():
                processed_sample = handler(sample)
                if processed_sample:
                    dataset.append(processed_sample)
            return dataset

        def _process_nvdlibrary(sample) -> NexusSample:
            r"""Process samples for the NVDLibrary dataset."""
            # Strip the "r = nvdlib." assignment prefix from ground truth.
            return NexusSample(
                sample["Input"], sample["Output"].replace("r = nvdlib.", "")
            )

        def _process_simple(sample) -> NexusSample:
            r"""Process samples for simple datasets (e.g., VirusTotal)."""
            return NexusSample(sample["Input"], sample["Output"])

        def _process_nested_calls(sample) -> Union[NexusSample, None]:
            r"""Process samples for VirusTotal-NestedCalls dataset.

            Only single (nested) calls qualify; others are dropped.
            """
            if len(sample["fncall"]) == 1:
                return NexusSample(
                    sample["generated_question"], "".join(sample["fncall"])
                )
            return None

        def _process_parallel_calls(sample) -> Union[NexusSample, None]:
            r"""Process samples for VirusTotal-ParallelCalls dataset.

            Only multi-call samples qualify; calls are joined with "; ".
            """
            if len(sample["fncall"]) > 1:
                return NexusSample(
                    sample["generated_question"], "; ".join(sample["fncall"])
                )
            return None

        if force_download:
            logger.info("Force downloading data.")
            self.download()

        # Validate dataset name
        if dataset_name not in dataset_mapping:
            available_datasets = list(dataset_mapping.keys())
            raise ValueError(
                f"Dataset '{dataset_name}' is not recognized. "
                f"Available datasets: {available_datasets}"
            )

        # Get the dataset directory
        dataset_dir = self.data_dir / dataset_name
        if not dataset_dir.exists():
            raise FileNotFoundError(
                f"The dataset directory for '{dataset_name}' \
                does not exist at {dataset_dir}. "
                "Please download it first."
            )

        # Load the dataset
        if dataset_name == "NVDLibrary-NestedCalls":
            self._data = _load_csv_data(dataset_dir)
        else:
            self._data = _load_parquet_data(dataset_dir / "data", dataset_name)

    @property
    def train(self):
        r"""Get the training set."""
        raise NotImplementedError(
            "Nexus Functional Calling has only a single 'train' set."
        )

    def run(  # type: ignore[override, return]
        self,
        agent: ChatAgent,
        task: Literal[
            "NVDLibrary",
            "VirusTotal",
            "OTX",
            "PlacesAPI",
            "ClimateAPI",
            "VirusTotal-ParallelCalls",
            "VirusTotal-NestedCalls",
            "NVDLibrary-NestedCalls",
        ],
        randomize: bool = False,
        subset: Optional[int] = None,
    ) -> Dict[str, Any]:
        r"""Run the benchmark.

        Args:
            agent (ChatAgent): The agent to run the benchmark.
            task (Literal["NVDLibrary", "VirusTotal", "OTX",
                "PlacesAPI", "ClimateAPI", "VirusTotal-ParallelCalls",
                "VirusTotal-NestedCalls",
                "NVDLibrary-NestedCalls"]): The task to run the benchmark.
            randomize (bool, optional): Whether to randomize the data.
                (default: :obj:`False`)
            subset (Optional[int], optional): The subset of data to run.
                (default: :obj:`None`)

        Returns:
            Dict[str, Any]: The results of the benchmark.
        """

        if task not in dataset_mapping:
            raise ValueError(f"Invalid value for dataset: {task}.")

        logger.info(f"Running Nexus Function Calling benchmark on {task}.")
        self.load(task)
        # Bug fix: copy before shuffling/slicing so randomize does not
        # mutate self._data in place.
        datas = list(self._data)

        # Shuffle and subset data if necessary
        if randomize:
            random.shuffle(datas)
        if subset:
            datas = datas[:subset]

        logger.info(f"Number of tasks: {len(datas)}")

        # Initialize results storage
        self._results = []

        # Process samples
        tools = construct_tool_descriptions(task)
        with open(self.save_to, "w") as f:
            for sample in tqdm(datas, desc="Running"):
                prompt = construct_prompt(input=sample.input, tools=tools)
                msg = BaseMessage.make_user_message(
                    role_name="User", content=prompt
                )
                ground_truth_call = sample.output
                try:
                    # Generate response
                    response = agent.step(msg)
                    agent_call = response.msgs[0].content

                    # Evaluate response.  Bug fix: an empty agent_call
                    # previously appended nothing but still wrote
                    # self._results[-1], duplicating the previous record
                    # (or raising IndexError on the first sample).
                    if agent_call:
                        result = compare_function_calls(
                            agent_call=agent_call,
                            ground_truth_call=ground_truth_call,
                        )
                        entry = {
                            "input": sample.input,
                            "agent_call": agent_call,
                            "ground_truth_call": ground_truth_call,
                            "result": result,
                            "error": None,
                        }
                    else:
                        entry = {
                            "input": sample.input,
                            "agent_call": agent_call,
                            "ground_truth_call": ground_truth_call,
                            "result": 0,
                            "error": "Empty response from agent.",
                        }
                except Exception as e:
                    logger.warning(f"Error in processing task: {sample.input}")
                    entry = {
                        "input": sample.input,
                        "agent_call": None,
                        "ground_truth_call": ground_truth_call,
                        "result": 0,
                        "error": str(e),
                    }
                finally:
                    agent.reset()

                self._results.append(entry)
                f.write(json.dumps(entry, indent=2) + "\n")
                f.flush()

        total = len(self._results)
        correct = sum(r["result"] for r in self._results)

        return {
            "total": total,
            "correct": correct,
            # Bug fix: avoid ZeroDivisionError when no samples were run.
            "accuracy": correct / total if total else 0.0,
        }
+ }
361
+
362
+
363
# Utility functions
def construct_tool_descriptions(dataset_name: str) -> str:
    r"""Construct tool descriptions from function definitions and
    descriptions.

    Downloads the function-call definitions matching ``dataset_name`` from
    the ``Nexusflow/Function_Call_Definitions`` hub dataset and renders each
    tool as a Python-style function stub followed by its docstring.

    Args:
        dataset_name (str): Name of the benchmark dataset whose tool
            definitions should be rendered.

    Returns:
        str: Concatenated tool descriptions ready to embed in a prompt.

    Raises:
        ValueError: If ``dataset_name`` has no known tool definitions.
    """
    tool_dataset_mapping = {
        "NVDLibrary": "CVECPE",
        "VirusTotal": "VirusTotal",
        "PlacesAPI": "Places",
        "ClimateAPI": "Climate",
        "OTX": "OTX",
        "VirusTotal-NestedCalls": "VT_Multi (Nested)",
        "VirusTotal-ParallelCalls": "VT_Multi (Parallel)",
        "NVDLibrary-NestedCalls": "CVECPE_Multi (Nested)",
    }

    if dataset_name not in tool_dataset_mapping:
        # Report the keys this function actually accepts
        # (tool_dataset_mapping), not the module-level dataset_mapping,
        # which is a different dictionary.
        raise ValueError(
            f"Dataset '{dataset_name}' is not recognized. "
            f"Available datasets: {list(tool_dataset_mapping.keys())}"
        )

    # Load the dataset based on the dataset name
    dataset = load_dataset(
        "Nexusflow/Function_Call_Definitions",
        name=tool_dataset_mapping[dataset_name],
    )["train"]

    # Construct tool descriptions
    tools = [
        NexusTool(tool["function_calls"], tool["descriptions"])
        for tool in dataset
    ]

    # Generate the tool prompt: one stub per tool, docstring attached.
    tool_prompt = "".join(
        f"Function:\ndef {tool.function_calls}:\n"
        + "\"\"\"\n"
        + f"{tool.descriptions}\n"
        + "\"\"\"\n"
        for tool in tools
    )

    return tool_prompt
406
+
407
+
408
def construct_prompt(input: str, tools: str) -> str:
    r"""Fill the tool-calling prompt template with the given tools and input."""
    prompt = TOOL_CALLING_PROMPT.format(tools=tools, input=input)
    return prompt
411
+
412
+
413
# Functions for function call evaluation
def parse_function_call(
    call: str,
) -> Tuple[Optional[str], Optional[List[Any]], Optional[Dict[str, Any]]]:
    r"""Parse a function call string to extract the function name,
    positional arguments, and keyword arguments, including
    nested function calls.

    Args:
        call (str): A string in the format `func(arg1, arg2, kwarg=value)`.

    Returns:
        tuple: (function_name (str), positional_args (list),
            keyword_args (dict)) or (None, None, None).
    """

    def preprocess_input(call: str) -> str:
        r"""Remove formatting like code blocks and whitespace."""
        # Strip a ```python fenced block if the model wrapped its answer.
        if call.strip().startswith("```python"):
            call = call.strip().removeprefix("```python").removesuffix("```")
        return textwrap.dedent(call).strip()

    def evaluate_arg(arg):
        r"""Recursively evaluate arguments, including nested calls."""
        # A nested call is represented by its own (name, args, kwargs) tuple.
        if isinstance(arg, ast.Call):
            # Recursively parse nested calls
            func_name, args, kwargs = parse_function_call(ast.unparse(arg))
            return func_name, args, kwargs
        elif isinstance(
            arg, ast.Constant
        ):  # Handle literals like numbers, strings, etc.
            return arg.value
        elif isinstance(arg, ast.List):  # Handle list literals
            return [evaluate_arg(el) for el in arg.elts]
        elif isinstance(arg, ast.Dict):  # Handle dictionary literals
            return {
                evaluate_arg(k): evaluate_arg(v)
                for k, v in zip(arg.keys, arg.values)
            }
        elif isinstance(arg, ast.Tuple):  # Handle tuple literals
            return tuple(evaluate_arg(el) for el in arg.elts)
        else:
            return ast.literal_eval(arg)  # Safely evaluate other types

    call = preprocess_input(call)
    parsed_calls = []

    try:
        # Parse the string into an AST
        # NOTE(review): the input is split on ";" but the loop returns after
        # the FIRST valid call — any subsequent calls are ignored. Confirm
        # this is intentional for multi-call (parallel) samples.
        parsed_calls = call.split(";")
        for single_call in parsed_calls:
            tree = ast.parse(single_call, mode='eval')

            # Ensure it's a function call
            if isinstance(tree.body, ast.Call):
                # Extract function name
                if isinstance(
                    tree.body.func, ast.Name
                ):  # Simple function call
                    func_name = tree.body.func.id
                elif isinstance(
                    tree.body.func, ast.Attribute
                ):  # Attribute function call
                    func_name = (
                        f"{tree.body.func.value.id}.{tree.body.func.attr}"  # type: ignore[attr-defined]
                    )
                else:
                    raise ValueError(f"Unsupported function call: {call}")

                # Extract positional arguments
                args = [evaluate_arg(arg) for arg in tree.body.args]

                # Extract keyword arguments
                kwargs: Dict[str, Any] = {
                    kw.arg: evaluate_arg(kw.value)
                    for kw in tree.body.keywords
                    if kw.arg is not None
                }
                logger.info("Valid call.")
                return func_name, args, kwargs
            else:
                raise ValueError(f"Not a valid function call: {call}")
    # Any parse/eval failure collapses to the (None, None, None) sentinel.
    except Exception as e:
        logger.info(f"Error parsing call: {call}, {e}")
        return None, None, None
498
+
499
+
500
def compare_function_calls(agent_call: str, ground_truth_call: str) -> bool:
    r"""Compare the function name and arguments of
    agent_call and ground_truth_call.

    Args:
        agent_call (str): Function call by agent.
        ground_truth_call (str): Ground truth function call.

    Returns:
        - `True` if the function names and arguments match.
        - `False` otherwise.
    """
    # Parse both calls
    agent_parsed = parse_function_call(agent_call)
    gt_parsed = parse_function_call(ground_truth_call)

    # parse_function_call signals failure with (None, None, None), which is
    # a *truthy* tuple, so a plain `if agent_parsed and gt_parsed` would
    # report two unparseable calls as a match. Check the function-name
    # element explicitly instead.
    if agent_parsed[0] is None or gt_parsed[0] is None:
        return False
    return agent_parsed == gt_parsed
camel/benchmarks/ragbench.py ADDED
@@ -0,0 +1,333 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+
15
+ from typing import Any, Callable, Dict, List, Literal, Optional, Sequence
16
+
17
+ import numpy as np
18
+ from datasets import Dataset, load_dataset
19
+
20
+ from camel.agents import ChatAgent
21
+ from camel.benchmarks import BaseBenchmark
22
+ from camel.logger import get_logger
23
+ from camel.retrievers import AutoRetriever
24
+
25
+ logger = get_logger(__name__)
26
+
27
+
28
class RagasFields:
    r"""Column names expected by the RAGAS evaluation pipeline."""

    INPUT_QUESTION = "question"
    INPUT_CONTEXT = "contexts"
    INPUT_ANSWER = "answer"
34
+
35
+
36
def annotate_dataset(
    dataset: Dataset,
    context_call: Optional[Callable[[Dict[str, Any]], List[str]]],
    answer_call: Optional[Callable[[Dict[str, Any]], str]],
) -> Dataset:
    r"""Attach retrieved contexts and/or generated answers to every example.

    Args:
        dataset (Dataset): The input dataset to annotate.
        context_call (Optional[Callable[[Dict[str, Any]], List[str]]]):
            Function producing the ``contexts`` field for an example.
        answer_call (Optional[Callable[[Dict[str, Any]], str]]): Function
            producing the ``answer`` field for an example.

    Returns:
        Dataset: The annotated dataset with added contexts and/or answers.
    """

    def _annotate(example: Dict[str, Any]) -> Dict[str, Any]:
        # Only populate the fields whose generator was supplied.
        if context_call:
            example["contexts"] = context_call(example)
        if answer_call:
            example["answer"] = answer_call(example)
        return example

    return dataset.map(_annotate)
63
+
64
+
65
def rmse(
    input_trues: Sequence[float],
    input_preds: Sequence[float],
) -> Optional[float]:
    r"""Root-mean-squared error between predictions and ground truth.

    NaN predictions are excluded from the computation.

    Args:
        input_trues (Sequence[float]): Ground truth values.
        input_preds (Sequence[float]): Predicted values.

    Returns:
        Optional[float]: RMSE value, or None if inputs have different
            lengths or no prediction is valid.
    """
    if len(input_trues) != len(input_preds):
        logger.warning("Input lengths mismatch in RMSE calculation")
        return None

    preds = np.asarray(input_preds, dtype=float)
    trues = np.asarray(input_trues)

    # Keep only positions where the prediction is a real number.
    valid = ~np.isnan(preds)
    if not valid.any():
        logger.warning("No valid predictions for RMSE calculation")
        return None

    diff = preds[valid] - trues[valid]
    return float(np.sqrt(np.mean(diff * diff)))
95
+
96
+
97
def auroc(trues: Sequence[bool], preds: Sequence[float]) -> float:
    r"""Area Under the Receiver Operating Characteristic curve.

    Args:
        trues (Sequence[bool]): Ground truth binary values.
        preds (Sequence[float]): Predicted probability values.

    Returns:
        float: AUROC score; 0.5 when no prediction is usable.
    """
    from sklearn.metrics import roc_auc_score  # type: ignore[import-untyped]

    valid = ~np.isnan(preds)
    if not np.any(valid):
        logger.warning("No valid predictions for AUROC calculation")
        # No usable predictions: report the random-classifier score.
        return 0.5

    y_true = np.array(trues)[valid]
    y_score = np.array(preds)[valid]
    return float(roc_auc_score(y_true, y_score))
117
+
118
+
119
def ragas_calculate_metrics(
    dataset: Dataset,
    pred_context_relevance_field: Optional[str],
    pred_faithfulness_field: Optional[str],
    metrics_to_evaluate: Optional[List[str]] = None,
    ground_truth_context_relevance_field: str = "relevance_score",
    ground_truth_faithfulness_field: str = "adherence_score",
) -> Dict[str, Optional[float]]:
    r"""Calculate RAGAS evaluation metrics.

    Args:
        dataset (Dataset): The dataset containing predictions and ground
            truth.
        pred_context_relevance_field (Optional[str]): Field name for
            predicted context relevance.
        pred_faithfulness_field (Optional[str]): Field name for predicted
            faithfulness.
        metrics_to_evaluate (Optional[List[str]]): Metrics to evaluate;
            defaults to context relevancy and faithfulness.
        ground_truth_context_relevance_field (str): Field name for ground
            truth relevance.
        ground_truth_faithfulness_field (str): Field name for ground truth
            adherence.

    Returns:
        Dict[str, Optional[float]]: Dictionary of calculated metrics.
    """
    metrics_to_evaluate = metrics_to_evaluate or [
        "context_relevancy",
        "faithfulness",
    ]
    results: Dict[str, Optional[float]] = {}

    wants_relevance = (
        "context_relevancy" in metrics_to_evaluate
        and pred_context_relevance_field
    )
    if wants_relevance:
        results["relevance_rmse"] = rmse(
            dataset[ground_truth_context_relevance_field],
            dataset[pred_context_relevance_field],
        )

    if "faithfulness" in metrics_to_evaluate and pred_faithfulness_field:
        # Faithfulness is stored as adherence; hallucination is its
        # complement on both the ground-truth and prediction sides.
        true_hallucination = ~np.array(
            dataset[ground_truth_faithfulness_field]
        )
        pred_hallucination = 1 - np.array(
            dataset[pred_faithfulness_field], dtype=float
        )
        results["hallucination_auroc"] = auroc(
            true_hallucination.tolist(), pred_hallucination.tolist()
        )

    return results
172
+
173
+
174
def ragas_evaluate_dataset(
    dataset: Dataset,
    contexts_field_name: Optional[str],
    answer_field_name: Optional[str],
    metrics_to_evaluate: Optional[List[str]] = None,
) -> Dataset:
    r"""Evaluate the dataset using RAGAS metrics.

    Args:
        dataset (Dataset): Input dataset to evaluate.
        contexts_field_name (Optional[str]): Field name containing contexts.
        answer_field_name (Optional[str]): Field name containing answers.
        metrics_to_evaluate (Optional[List[str]]): Metrics to evaluate;
            defaults to context relevancy and faithfulness.

    Returns:
        Dataset: Dataset with added evaluation metrics.
    """
    from ragas import evaluate
    from ragas.metrics import (  # type: ignore[import-untyped]
        context_relevancy,
        faithfulness,
    )

    metrics_to_evaluate = metrics_to_evaluate or [
        "context_relevancy",
        "faithfulness",
    ]

    # Normalize column names to what RAGAS expects.
    renames = [
        (contexts_field_name, RagasFields.INPUT_CONTEXT),
        (answer_field_name, RagasFields.INPUT_ANSWER),
    ]
    for source, target in renames:
        if source and source != target:
            dataset = dataset.rename_column(source, target)

    metric_registry = {
        "context_relevancy": context_relevancy,
        "faithfulness": faithfulness,
    }
    selected = [
        metric_registry[name]
        for name in ("context_relevancy", "faithfulness")
        if name in metrics_to_evaluate
    ]

    ragas_result = evaluate(dataset, metrics=selected)
    return Dataset.from_pandas(ragas_result.to_pandas())
223
+
224
+
225
class RAGBenchBenchmark(BaseBenchmark):
    r"""RAGBench Benchmark for evaluating RAG performance.

    This benchmark uses the rungalileo/ragbench dataset to evaluate
    retrieval-augmented generation (RAG) systems. It measures context
    relevancy and faithfulness metrics as described in
    https://arxiv.org/abs/2407.11005.

    Args:
        processes (int, optional): Number of processes for parallel processing.
        subset (str, optional): Dataset subset to use (e.g., "hotpotqa").
        split (str, optional): Dataset split to use (e.g., "test").
    """

    def __init__(
        self,
        processes: int = 1,
        subset: Literal[
            "covidqa",
            "cuad",
            "delucionqa",
            "emanual",
            "expertqa",
            "finqa",
            "hagrid",
            "hotpotqa",
            "msmarco",
            "pubmedqa",
            "tatqa",
            "techqa",
        ] = "hotpotqa",
        split: Literal["train", "test", "validation"] = "test",
    ) -> None:
        # NOTE(review): positional args presumed to be
        # (name, data_dir, save_to, processes) — confirm against
        # BaseBenchmark; save_to is intentionally empty here.
        super().__init__("ragbench", "rag_bench", "", processes)
        self.subset = subset
        self.split = split
        # Populated lazily by download()/load(); None until then.
        self.dataset: Optional[Dataset] = None

    def download(self):
        r"""Download the RAGBench dataset."""
        try:
            self.dataset = load_dataset(
                "rungalileo/ragbench", self.subset, split=self.split
            )
        except Exception as e:
            # Surface the failure to the caller after logging it.
            logger.error(f"Failed to download dataset: {e}")
            raise

    def load(self, force_download: bool = False):
        r"""Load the RAGBench dataset.

        Args:
            force_download (bool, optional): Whether to force download the
                data.
        """
        # Re-download when forced or when nothing is cached yet.
        if force_download or self.dataset is None:
            logger.info(
                "%s dataset",
                "Force downloading" if force_download else "Loading",
            )
            self.download()

    def run(  # type: ignore[override, return]
        self,
        agent: ChatAgent,
        auto_retriever: AutoRetriever,
    ) -> Dict[str, Optional[float]]:
        r"""Run the benchmark evaluation.

        Args:
            agent (ChatAgent): Chat agent for generating answers.
            auto_retriever (AutoRetriever): Retriever for finding relevant
                contexts.

        Returns:
            Dict[str, Optional[float]]: Dictionary of evaluation metrics.
        """

        # Retrieve the single best-matching context for each question.
        def context_call(example):
            retrieved_info = auto_retriever.run_vector_retriever(
                query=example['question'],
                contents=example['documents'],
                top_k=1,
                return_detailed_info=True,
                similarity_threshold=0.5,
            )
            return [c['text'] for c in retrieved_info['Retrieved Context']]

        # Generate an answer by handing the whole example dict, stringified,
        # to the agent.
        def answer_call(example: Dict[str, Any]) -> str:
            user_msg = str(example)
            assistant_response = agent.step(user_msg)
            return assistant_response.msg.content

        # Annotate the dataset
        # NOTE(review): assumes load()/download() has already populated
        # self.dataset; otherwise annotate_dataset receives None — confirm
        # callers always load first.
        annotated_ds = annotate_dataset(
            self.dataset, context_call, answer_call
        )
        evaluated_ds = ragas_evaluate_dataset(
            annotated_ds,
            contexts_field_name="contexts",
            answer_field_name="answer",
            metrics_to_evaluate=["context_relevancy", "faithfulness"],
        )

        return ragas_calculate_metrics(
            evaluated_ds,
            pred_context_relevance_field="context_relevancy",
            pred_faithfulness_field="faithfulness",
        )
camel/bots/__init__.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+ from .discord import DiscordApp
15
+ from .slack.models import (
16
+ SlackAppMentionEventBody,
17
+ SlackAppMentionEventProfile,
18
+ SlackAuthProfile,
19
+ SlackEventBody,
20
+ SlackEventProfile,
21
+ )
22
+ from .slack.slack_app import SlackApp
23
+ from .telegram_bot import TelegramBot
24
+
25
+ __all__ = [
26
+ 'DiscordApp',
27
+ 'SlackApp',
28
+ 'SlackAppMentionEventBody',
29
+ 'SlackAppMentionEventProfile',
30
+ 'SlackAuthProfile',
31
+ 'SlackEventBody',
32
+ 'SlackEventProfile',
33
+ 'TelegramBot',
34
+ ]
camel/bots/discord/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+ from .discord_app import DiscordApp
15
+ from .discord_installation import DiscordInstallation
16
+ from .discord_store import (
17
+ DiscordBaseInstallationStore,
18
+ DiscordSQLiteInstallationStore,
19
+ )
20
+
21
+ __all__ = [
22
+ "DiscordApp",
23
+ "DiscordInstallation",
24
+ "DiscordSQLiteInstallationStore",
25
+ "DiscordBaseInstallationStore",
26
+ ]
camel/bots/discord/discord_app.py ADDED
@@ -0,0 +1,384 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+ import os
15
+ from datetime import datetime, timedelta
16
+ from typing import TYPE_CHECKING, List, Optional
17
+
18
+ import discord
19
+ import httpx
20
+ from fastapi import FastAPI
21
+
22
+ from camel.bots.discord.discord_installation import DiscordInstallation
23
+ from camel.logger import get_logger
24
+ from camel.utils import api_keys_required, dependencies_required
25
+
26
+ from .discord_store import DiscordBaseInstallationStore
27
+
28
+ if TYPE_CHECKING:
29
+ from discord import Message
30
+
31
+ logger = get_logger(__name__)
32
+
33
+ TOKEN_URL = "https://discord.com/api/oauth2/token"
34
+ USER_URL = "https://discord.com/api/users/@me"
35
+
36
+
37
class DiscordApp:
    r"""A class representing a Discord app that uses the `discord.py` library
    to interact with Discord servers.

    This bot can respond to messages in specific channels and only reacts to
    messages that mention the bot.

    Attributes:
        channel_ids (Optional[List[int]]): A list of allowed channel IDs. If
            provided, the bot will only respond to messages in these channels.
        token (Optional[str]): The Discord bot token used for authentication.
    """

    @dependencies_required('discord')
    @api_keys_required(
        [
            ("token", "DISCORD_BOT_TOKEN"),
        ]
    )
    def __init__(
        self,
        channel_ids: Optional[List[int]] = None,
        token: Optional[str] = None,
        client_id: Optional[str] = None,
        client_secret: Optional[str] = None,
        redirect_uri: Optional[str] = None,
        installation_store: Optional[DiscordBaseInstallationStore] = None,
        intents: Optional[discord.Intents] = None,
    ) -> None:
        r"""Initialize the DiscordApp instance by setting up the Discord client
        and event handlers.

        Args:
            channel_ids (Optional[List[int]]): A list of allowed channel IDs.
                The bot will only respond to messages in these channels if
                provided. (default: :obj:`None`)
            token (Optional[str]): The Discord bot token for authentication.
                If not provided, the token will be retrieved from the
                environment variable `DISCORD_BOT_TOKEN`.
                (default: :obj:`None`)
            client_id (str, optional): The client ID for Discord OAuth.
                (default: :obj:`None`)
            client_secret (Optional[str]): The client secret for Discord OAuth.
                (default: :obj:`None`)
            redirect_uri (str): The redirect URI for OAuth callbacks.
                (default: :obj:`None`)
            installation_store (DiscordBaseInstallationStore): The database
                that stores all information of all installations.
                (default: :obj:`None`)
            intents (discord.Intents): The Discord intents of this app.
                (default: :obj:`None`)

        Raises:
            ValueError: If the `DISCORD_BOT_TOKEN` is not found in environment
                variables.
        """
        self.token = token or os.getenv("DISCORD_BOT_TOKEN")
        self.channel_ids = channel_ids
        self.installation_store = installation_store

        # Default to all intents with message content so the bot can read
        # and respond to guild messages out of the box.
        if not intents:
            intents = discord.Intents.all()
            intents.message_content = True
            intents.guilds = True

        self._client = discord.Client(intents=intents)

        # Register event handlers
        self._client.event(self.on_ready)
        self._client.event(self.on_message)

        # OAuth flow
        self.client_id = client_id or os.getenv("DISCORD_CLIENT_ID")
        self.client_secret = client_secret or os.getenv(
            "DISCORD_CLIENT_SECRET"
        )
        self.redirect_uri = redirect_uri

        # OAuth is only enabled when every required piece is configured.
        self.oauth_flow = bool(
            self.client_id
            and self.client_secret
            and self.redirect_uri
            and self.installation_store
        )

        self.app = FastAPI()

    async def start(self):
        r"""Asynchronously start the Discord bot using its token.

        This method starts the bot and logs into Discord asynchronously using
        the provided token. It should be awaited when used in an async
        environment.
        """
        await self._client.start(self.token)

    def run(self) -> None:
        r"""Start the Discord bot using its token.

        This method starts the bot and logs into Discord synchronously using
        the provided token. It blocks execution and keeps the bot running.
        """
        self._client.run(self.token)  # type: ignore[arg-type]

    async def exchange_code_for_token_response(
        self, code: str
    ) -> Optional[dict]:
        r"""Exchange the authorization code for Discord's token response.

        Args:
            code (str): The authorization code received from Discord after
                user authorization.

        Returns:
            Optional[dict]: The full token response data (including access
                and refresh tokens) if successful, otherwise None.

        Raises:
            ValueError: If OAuth configuration is incomplete or invalid.
            httpx.RequestError: If there is a network issue during the request.
        """
        if not self.oauth_flow:
            logger.warning(
                "OAuth is not enabled. Missing client_id, "
                "client_secret, or redirect_uri."
            )
            return None
        data = {
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "grant_type": "authorization_code",
            "code": code,
            "redirect_uri": self.redirect_uri,
        }
        headers = {"Content-Type": "application/x-www-form-urlencoded"}
        try:
            async with httpx.AsyncClient() as client:
                response = await client.post(
                    TOKEN_URL, data=data, headers=headers
                )
                if response.status_code != 200:
                    logger.error(f"Failed to exchange code: {response.text}")
                    return None
                response_data = response.json()

                # Return the whole token payload; callers extract the
                # individual token fields themselves.
                return response_data
        except (httpx.RequestError, ValueError) as e:
            logger.error(f"Error during token fetch: {e}")
            return None

    async def get_user_info(self, access_token: str) -> Optional[dict]:
        r"""Retrieve user information using the access token.

        Args:
            access_token (str): The access token received from Discord.

        Returns:
            Optional[dict]: The user information retrieved from Discord, or
                None when OAuth is not configured.
        """
        if not self.oauth_flow:
            logger.warning(
                "OAuth is not enabled. Missing client_id, "
                "client_secret, or redirect_uri."
            )
            return None
        headers = {"Authorization": f"Bearer {access_token}"}
        async with httpx.AsyncClient() as client:
            user_response = await client.get(USER_URL, headers=headers)
            return user_response.json()

    async def refresh_access_token(self, refresh_token: str) -> Optional[str]:
        r"""Refresh the access token using a refresh token.

        Args:
            refresh_token (str): The refresh token issued by Discord that
                can be used to obtain a new access token.

        Returns:
            Optional[str]: The new access token if successful, otherwise None.
        """
        if not self.oauth_flow:
            logger.warning(
                "OAuth is not enabled. Missing client_id, "
                "client_secret, or redirect_uri."
            )
            return None
        data = {
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "grant_type": "refresh_token",
            "refresh_token": refresh_token,
            "redirect_uri": self.redirect_uri,
        }
        headers = {"Content-Type": "application/x-www-form-urlencoded"}
        async with httpx.AsyncClient() as client:
            response = await client.post(TOKEN_URL, data=data, headers=headers)
            if response.status_code != 200:
                logger.error(f"Failed to refresh token: {response.text}")
                return None
            response_data = response.json()
            return response_data.get("access_token")

    async def get_valid_access_token(self, guild_id: str) -> Optional[str]:
        r"""Retrieve a valid access token for the specified guild.

        This method attempts to retrieve an access token for a specific guild.
        If the current access token is expired, it will refresh the token using
        the refresh token.

        Args:
            guild_id (str): The ID of the guild to retrieve the access
                token for.

        Returns:
            Optional[str]: The valid access token if successful,
                otherwise None.
        """
        if not self.oauth_flow:
            logger.warning(
                "OAuth is not enabled. Missing client_id, "
                "client_secret, or redirect_uri."
            )
            return None
        # oauth_flow implies an installation store is configured.
        assert self.installation_store is not None
        installation = await self.installation_store.find_by_guild(
            guild_id=guild_id
        )
        if not installation:
            logger.error(f"No installation found for guild: {guild_id}")
            return None

        if (
            installation.token_expires_at
            and datetime.now() >= installation.token_expires_at
        ):
            logger.info(
                f"Access token expired for guild: {guild_id}, "
                f"refreshing token..."
            )
            new_access_token = await self.refresh_access_token(
                installation.refresh_token
            )
            if new_access_token:
                installation.access_token = new_access_token
                # NOTE(review): assumes a fixed 1-hour lifetime for refreshed
                # tokens rather than reading expires_in from the refresh
                # response — confirm against Discord's OAuth documentation.
                installation.token_expires_at = datetime.now() + timedelta(
                    seconds=3600
                )
                await self.installation_store.save(installation)
                return new_access_token
            else:
                logger.error(
                    f"Failed to refresh access token for guild: {guild_id}"
                )
                return None

        return installation.access_token

    async def save_installation(
        self,
        guild_id: str,
        access_token: str,
        refresh_token: str,
        expires_in: int,
    ):
        r"""Save the installation information for a given guild.

        Args:
            guild_id (str): The ID of the guild where the bot is installed.
            access_token (str): The access token for the guild.
            refresh_token (str): The refresh token for the guild.
            expires_in (int): The lifetime of the access token, in seconds.
        """
        if not self.oauth_flow:
            logger.warning(
                "OAuth is not enabled. Missing client_id, "
                "client_secret, or redirect_uri."
            )
            return None
        assert self.installation_store is not None
        expires_at = datetime.now() + timedelta(seconds=expires_in)
        installation = DiscordInstallation(
            guild_id=guild_id,
            access_token=access_token,
            refresh_token=refresh_token,
            installed_at=datetime.now(),
            token_expires_at=expires_at,
        )
        await self.installation_store.save(installation)
        logger.info(f"Installation saved for guild: {guild_id}")

    async def remove_installation(self, guild: discord.Guild):
        r"""Remove the installation for a given guild.

        Args:
            guild (discord.Guild): The guild from which the bot is
                being removed.
        """
        if not self.oauth_flow:
            logger.warning(
                "OAuth is not enabled. Missing client_id, "
                "client_secret, or redirect_uri."
            )
            return None
        assert self.installation_store is not None
        await self.installation_store.delete(guild_id=str(guild.id))
        # Use the module logger, consistent with every other code path in
        # this class (was a bare print()).
        logger.info(f"Bot removed from guild: {guild.id}")

    async def on_ready(self) -> None:
        r"""Event handler that is called when the bot has successfully
        connected to the Discord server.

        When the bot is ready and logged into Discord, it logs a message
        displaying the bot's username.
        """
        logger.info(f'We have logged in as {self._client.user}')

    async def on_message(self, message: 'Message') -> None:
        r"""Event handler for processing incoming messages.

        This method is called whenever a new message is received by the bot. It
        will ignore messages sent by the bot itself, only respond to messages
        in allowed channels (if specified), and only to messages that mention
        the bot.

        Args:
            message (discord.Message): The message object received from
                Discord.
        """
        # If the message author is the bot itself,
        # do not respond to this message
        if message.author == self._client.user:
            return

        # If allowed channel IDs are provided,
        # only respond to messages in those channels
        if self.channel_ids and message.channel.id not in self.channel_ids:
            return

        # Only respond to messages that mention the bot
        if not self._client.user or not self._client.user.mentioned_in(
            message
        ):
            return

        logger.info(f"Received message: {message.content}")

    @property
    def client(self):
        r"""The underlying `discord.Client` instance."""
        return self._client
camel/bots/discord/discord_installation.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+ from datetime import datetime
15
+ from typing import Optional
16
+
17
+
18
class DiscordInstallation:
    r"""Represents an installation of a Discord application in a
    specific guild (server).

    Attributes:
        guild_id (str): The unique identifier for the Discord guild (server)
            where the application is installed.
        access_token (str): The access token used to authenticate API requests
            for the installed application.
        refresh_token (str): The token used to refresh the access token when
            it expires.
        installed_at (datetime): The timestamp indicating when the application
            was installed in the guild.
        token_expires_at (Optional[datetime]): The optional timestamp
            indicating when the access token will expire. Defaults to None
            if the token does not have an expiration time.
    """

    def __init__(
        self,
        guild_id: str,
        access_token: str,
        refresh_token: str,
        installed_at: datetime,
        token_expires_at: Optional[datetime] = None,
    ):
        r"""Initialize the DiscordInstallation.

        Args:
            guild_id (str): The unique identifier for the Discord guild
                (server) where the application is installed.
            access_token (str): The access token used to authenticate API
                requests for the installed application.
            refresh_token (str): The token used to refresh the access token
                when it expires.
            installed_at (datetime): The timestamp indicating when the
                application was installed in the guild.
            token_expires_at (Optional[datetime]): The optional timestamp
                indicating when the access token will expire. Defaults to None
                if the token does not have an expiration time.
                (default: :obj:`None`)
        """
        self.guild_id = guild_id
        self.access_token = access_token
        self.refresh_token = refresh_token
        self.installed_at = installed_at
        self.token_expires_at = token_expires_at

    def __repr__(self) -> str:
        r"""Return a debug-friendly representation.

        The access and refresh tokens are deliberately omitted so that
        logging or printing an installation never leaks credentials.
        """
        return (
            f"{type(self).__name__}(guild_id={self.guild_id!r}, "
            f"installed_at={self.installed_at!r}, "
            f"token_expires_at={self.token_expires_at!r})"
        )
camel/bots/discord/discord_store.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+
15
+ from typing import Optional
16
+
17
+ from .discord_installation import DiscordInstallation
18
+
19
+
20
class DiscordBaseInstallationStore:
    r"""Abstract base class for managing Discord installations.

    This class defines the interface for database operations related to storing
    and retrieving Discord installation data. Subclasses must implement these
    methods to handle database-specific logic.

    Note:
        Every method here is a no-op placeholder (``pass``), so a subclass
        that forgets to override one fails silently rather than raising.
    """

    async def init(self):
        r"""Initializes the database connection or structure."""
        pass

    async def save(self, installation: DiscordInstallation):
        r"""Saves or updates a Discord installation record.

        Args:
            installation (DiscordInstallation): The record to persist.
        """
        pass

    async def find_by_guild(
        self, guild_id: str
    ) -> Optional[DiscordInstallation]:
        r"""Finds an installation record by guild ID.

        Args:
            guild_id (str): The guild ID to look up.

        Returns:
            Optional[DiscordInstallation]: The matching record, or ``None``.
            (This base implementation always returns ``None``.)
        """
        pass

    async def delete(self, guild_id: str):
        r"""Deletes an installation record by guild ID.

        Args:
            guild_id (str): The guild ID of the record to delete.
        """
        pass
45
+
46
+
47
class DiscordSQLiteInstallationStore(DiscordBaseInstallationStore):
    r"""SQLite-based implementation for managing Discord installations.

    This class provides methods for initializing the database, saving,
    retrieving, and deleting installation records using SQLite.

    Attributes:
        database (str): Path to the SQLite database file.
    """

    def __init__(self, database: str):
        r"""Initializes the SQLite installation store.

        Args:
            database (str): Path to the SQLite database file.
        """
        self.database = database

    async def init(self):
        r"""Initializes the database by creating the required table if it
        does not exist."""
        # Imported lazily so `aiosqlite` stays an optional dependency.
        import aiosqlite

        async with aiosqlite.connect(self.database) as db:
            await db.execute(
                """
                CREATE TABLE IF NOT EXISTS discord_installations (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    guild_id TEXT NOT NULL UNIQUE,
                    access_token TEXT NOT NULL,
                    refresh_token TEXT NOT NULL,
                    installed_at DATETIME NOT NULL,
                    token_expires_at DATETIME
                );
                """
            )
            await db.commit()

    async def save(self, installation: DiscordInstallation):
        r"""Saves a new installation record or updates an existing one.

        Uses an SQLite upsert keyed on ``guild_id``. On conflict the tokens
        and expiry are refreshed while ``installed_at`` keeps its original
        value (it is absent from the UPDATE clause — presumably intentional
        to preserve the first-install timestamp; confirm with the author).

        Args:
            installation (DiscordInstallation): The installation data to save.
        """
        import aiosqlite

        async with aiosqlite.connect(self.database) as db:
            await db.execute(
                """
                INSERT INTO discord_installations (
                    guild_id, access_token, refresh_token,
                    installed_at, token_expires_at
                ) VALUES (?, ?, ?, ?, ?)
                ON CONFLICT(guild_id) DO UPDATE SET
                    access_token = excluded.access_token,
                    refresh_token = excluded.refresh_token,
                    token_expires_at = excluded.token_expires_at;
                """,
                [
                    installation.guild_id,
                    installation.access_token,
                    installation.refresh_token,
                    installation.installed_at,
                    installation.token_expires_at,
                ],
            )
            await db.commit()

    async def find_by_guild(
        self, guild_id: str
    ) -> Optional[DiscordInstallation]:
        r"""Finds an installation record by guild ID.

        Args:
            guild_id (str): The guild ID to search for.

        Returns:
            Optional[DiscordInstallation]: The installation record if found,
                otherwise None.

        NOTE(review): without `detect_types`, sqlite returns DATETIME
        columns as strings, so `installed_at`/`token_expires_at` on the
        returned object may be `str` rather than `datetime` — verify
        against callers' expectations.
        """
        import aiosqlite

        async with aiosqlite.connect(self.database) as db:
            async with db.execute(
                "SELECT guild_id, access_token, refresh_token, "
                "installed_at, token_expires_at FROM discord_installations "
                "WHERE guild_id = ?",
                [guild_id],
            ) as cursor:
                row = await cursor.fetchone()
                if row:
                    return DiscordInstallation(
                        guild_id=row[0],
                        access_token=row[1],
                        refresh_token=row[2],
                        installed_at=row[3],
                        token_expires_at=row[4],
                    )
                return None

    async def delete(self, guild_id: str):
        r"""Deletes an installation record by guild ID.

        Args:
            guild_id (str): The guild ID of the record to delete.
        """
        import aiosqlite

        async with aiosqlite.connect(self.database) as db:
            await db.execute(
                "DELETE FROM discord_installations WHERE guild_id = ?",
                [guild_id],
            )
            await db.commit()
camel/bots/slack/__init__.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+ from .models import (
15
+ SlackAppMentionEventBody,
16
+ SlackAppMentionEventProfile,
17
+ SlackAuthProfile,
18
+ SlackEventBody,
19
+ SlackEventProfile,
20
+ )
21
+ from .slack_app import SlackApp
22
+
23
+ __all__ = [
24
+ 'SlackApp',
25
+ 'SlackAppMentionEventBody',
26
+ 'SlackAppMentionEventProfile',
27
+ 'SlackAuthProfile',
28
+ 'SlackEventBody',
29
+ 'SlackEventProfile',
30
+ ]
camel/bots/slack/models.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+ from typing import Optional
15
+
16
+ from pydantic import BaseModel
17
+
18
+
19
class SlackAuthProfile(BaseModel):
    r"""Represents the authorization profile within a Slack event.

    Events will contain a single, compact authorizations field that shows one
    installation of your app that the event is visible to.
    In other words, lists of authorizations will be truncated to one element.

    If there's more than one installing party that your app is keeping track
    of, it's best not to rely on the single party listed in authorizations to
    be any particular one.

    To get a full list of who can see events, call the apps.event.
    authorizations.list method after obtaining an app-level token. Read more on
    the changes here; they have taken effect for existing apps as of
    February 24, 2021.

    References:

    - https://api.slack.com/apis/events-api#authorizations
    - https://api.slack.com/changelog/2020-09-15-events-api-truncate-authed-users#no_context
    """

    enterprise_id: Optional[str] = None
    """The ID of the enterprise associated with the authorization, if any."""

    team_id: str
    """The ID of the team associated with the authorization."""

    user_id: str
    """The ID of the user associated with the authorization."""

    is_bot: bool
    """Whether the authorized user is a bot."""

    is_enterprise_install: bool
    """Whether the authorization is for an enterprise installation."""
55
+
56
+
57
class SlackEventProfile(BaseModel):
    r"""Represents the detailed profile of a Slack event, including user,
    message, and context data.

    Fields without defaults are required: validation fails if Slack omits
    them from the payload.
    """

    user: str
    """The ID of the user associated with the event."""

    type: str
    """The type of the event (e.g., 'message')."""

    ts: str
    """A timestamp representing when the event was triggered."""

    thread_ts: Optional[str] = None
    """The timestamp of the parent message in a thread, if threaded."""

    client_msg_id: str
    """A unique ID generated by the client for the message (if available)."""

    text: str
    """The message content text."""

    team: str
    """The ID of the team that the event is associated with."""

    blocks: list
    """The list of message blocks, providing structured information."""

    channel: str
    """The ID of the Slack channel where the event happened."""

    event_ts: str
    """The event-specific timestamp when it occurred."""

    channel_type: Optional[str]
    """The type of Slack channel (e.g., 'channel', 'im'). Required in the
    payload, though its value may be null."""
94
+
95
+
96
class SlackEventBody(BaseModel):
    r"""Represents the entire body of a Slack event, including the event
    profile, authorization, and context.
    """

    token: str
    """The token to verify the source of the event."""

    team_id: str
    """The ID of the team where the event is happening."""

    context_team_id: Optional[str]
    """The team ID for the shared channel context, if applicable."""

    context_enterprise_id: Optional[str] = None
    """The enterprise ID for the shared channel context, if applicable."""

    api_app_id: str
    """The unique identifier for the Slack app that received the event."""

    event: SlackEventProfile
    """A detailed profile of the event."""

    type: str
    """The overall type of event received (e.g., 'event_callback')."""

    event_id: str
    """A unique identifier assigned to this event by Slack."""

    event_time: int
    """The timestamp (in seconds) representing when the event was triggered."""

    authorizations: Optional[list[SlackAuthProfile]] = None
    """An optional list of authorizations that describe which installation can
    see the event."""

    is_ext_shared_channel: bool
    """Indicates if the event is part of a shared channel between different
    organizations."""

    event_context: str
    """A unique string representing the context of the event."""
138
+
139
+
140
class SlackAppMentionEventProfile(SlackEventProfile):
    r"""Represents the detailed profile of a Slack event where the app was
    mentioned in a message.

    Identical to :class:`SlackEventProfile` except that ``channel_type``
    gains a ``None`` default, since app-mention payloads omit it.
    """

    channel_type: Optional[str] = None
    """The type of Slack channel. It is None for app mentions."""
147
+
148
+
149
class SlackAppMentionEventBody(SlackEventBody):
    r"""Represents the entire body of a Slack event where the app was mentioned
    in a message.

    Identical to :class:`SlackEventBody` except that ``context_team_id``
    gains a ``None`` default and ``event`` is narrowed to the app-mention
    profile type.
    """

    context_team_id: Optional[str] = None
    """The team ID for the shared channel context. It is None for app
    mentions."""

    event: SlackAppMentionEventProfile
    """A detailed profile of the app-mention event."""
camel/bots/slack/slack_app.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+ import logging
15
+ import os
16
+ from typing import TYPE_CHECKING, Any, Dict, Optional
17
+
18
+ from slack_sdk.oauth.installation_store.async_installation_store import (
19
+ AsyncInstallationStore,
20
+ )
21
+ from starlette import requests, responses
22
+
23
+ from camel.bots.slack.models import (
24
+ SlackAppMentionEventBody,
25
+ SlackAppMentionEventProfile,
26
+ SlackEventBody,
27
+ SlackEventProfile,
28
+ )
29
+ from camel.utils import dependencies_required
30
+
31
+ if TYPE_CHECKING:
32
+ from slack_bolt.context.async_context import AsyncBoltContext
33
+ from slack_bolt.context.say.async_say import AsyncSay
34
+ from slack_sdk.web.async_client import AsyncWebClient
35
+
36
+ logging.basicConfig(level=logging.INFO)
37
+ logger = logging.getLogger(__name__)
38
+
39
+
40
class SlackApp:
    r"""Represents a Slack app that is powered by a Slack Bolt `AsyncApp`.

    This class is responsible for initializing and managing the Slack
    application by setting up event handlers, running the app server, and
    handling events such as messages and mentions from Slack.

    Args:
        token (Optional[str]): Slack API token for authentication.
        scopes (Optional[str]): Slack app scopes for permissions.
        signing_secret (Optional[str]): Signing secret for verifying Slack
            requests.
        client_id (Optional[str]): Slack app client ID.
        client_secret (Optional[str]): Slack app client secret.
        redirect_uri_path (str): The URI path for OAuth redirect, defaults to
            "/slack/oauth_redirect".
        installation_store (Optional[AsyncInstallationStore]): The installation
            store for handling OAuth installations.
    """

    @dependencies_required('slack_bolt')
    def __init__(
        self,
        token: Optional[str] = None,
        scopes: Optional[str] = None,
        signing_secret: Optional[str] = None,
        client_id: Optional[str] = None,
        client_secret: Optional[str] = None,
        redirect_uri_path: str = "/slack/oauth_redirect",
        installation_store: Optional[AsyncInstallationStore] = None,
    ) -> None:
        r"""Initializes the SlackApp instance by setting up the Slack Bolt app
        and configuring event handlers and OAuth settings.

        Args:
            token (Optional[str]): The Slack API token.
            scopes (Optional[str]): The scopes for Slack app permissions.
            signing_secret (Optional[str]): The signing secret for verifying
                requests.
            client_id (Optional[str]): The Slack app client ID.
            client_secret (Optional[str]): The Slack app client secret.
            redirect_uri_path (str): The URI path for handling OAuth redirects
                (default is "/slack/oauth_redirect").
            installation_store (Optional[AsyncInstallationStore]): An optional
                installation store for OAuth installations.

        Raises:
            ValueError: If the token, scopes, or signing secret cannot be
                resolved from the arguments or the environment.
        """
        # Imported lazily so `slack_bolt` stays an optional dependency
        # (presence is checked by the decorator above).
        from slack_bolt.adapter.starlette.async_handler import (
            AsyncSlackRequestHandler,
        )
        from slack_bolt.app.async_app import AsyncApp
        from slack_bolt.oauth.async_oauth_settings import AsyncOAuthSettings

        # Fall back to environment variables for any credential not passed
        # in explicitly.
        self.token: Optional[str] = token or os.getenv("SLACK_TOKEN")
        self.scopes: Optional[str] = scopes or os.getenv("SLACK_SCOPES")
        self.signing_secret: Optional[str] = signing_secret or os.getenv(
            "SLACK_SIGNING_SECRET"
        )
        self.client_id: Optional[str] = client_id or os.getenv(
            "SLACK_CLIENT_ID"
        )
        self.client_secret: Optional[str] = client_secret or os.getenv(
            "SLACK_CLIENT_SECRET"
        )

        if not all([self.token, self.scopes, self.signing_secret]):
            raise ValueError(
                "`SLACK_TOKEN`, `SLACK_SCOPES`, and `SLACK_SIGNING_SECRET` "
                "environment variables must be set. Get it here: "
                "`https://api.slack.com/apps`."
            )

        # Setup OAuth settings if client ID and secret are provided
        if self.client_id and self.client_secret:
            self._app = AsyncApp(
                oauth_settings=AsyncOAuthSettings(
                    client_id=self.client_id,
                    client_secret=self.client_secret,
                    scopes=self.scopes,
                    redirect_uri_path=redirect_uri_path,
                ),
                logger=logger,
                signing_secret=self.signing_secret,
                installation_store=installation_store,
                token=self.token,
            )
        else:
            # Initialize Slack Bolt AsyncApp with settings (no OAuth flow)
            self._app = AsyncApp(
                logger=logger,
                signing_secret=self.signing_secret,
                installation_store=installation_store,
                token=self.token,
            )

        # Bridge Starlette requests into the Bolt app.
        self._handler = AsyncSlackRequestHandler(self._app)
        self.setup_handlers()

    def setup_handlers(self) -> None:
        r"""Sets up the event handlers for Slack events, such as `app_mention`
        and `message`.

        This method registers the `app_mention` and `on_message` event handlers
        with the Slack Bolt app to respond to Slack events.
        """
        self._app.event("app_mention")(self.app_mention)
        self._app.event("message")(self.on_message)

    def run(
        self,
        port: int = 3000,
        path: str = "/slack/events",
        host: Optional[str] = None,
    ) -> None:
        r"""Starts the Slack Bolt app server to listen for incoming Slack
        events. Blocks until the server is stopped.

        Args:
            port (int): The port on which the server should run (default is
                3000).
            path (str): The endpoint path for receiving Slack events (default
                is "/slack/events").
            host (Optional[str]): The hostname to bind the server (default is
                None).
        """
        self._app.start(port=port, path=path, host=host)

    async def handle_request(
        self, request: requests.Request
    ) -> responses.Response:
        r"""Handles incoming requests from Slack through the request handler.

        Args:
            request (requests.Request): A Starlette request object
                representing the incoming request.

        Returns:
            responses.Response: The response generated by the Slack Bolt
                handler.
        """
        return await self._handler.handle(request)

    async def app_mention(
        self,
        context: "AsyncBoltContext",
        client: "AsyncWebClient",
        event: Dict[str, Any],
        body: Dict[str, Any],
        say: "AsyncSay",
    ) -> None:
        r"""Event handler for `app_mention` events.

        This method is triggered when someone mentions the app in Slack.
        It currently only validates and logs the payload; no reply is sent.

        Args:
            context (AsyncBoltContext): The Slack Bolt context for the event.
            client (AsyncWebClient): The Slack Web API client.
            event (Dict[str, Any]): The event data for the app mention.
            body (Dict[str, Any]): The full request body from Slack.
            say (AsyncSay): A function to send a response back to the channel.
        """
        # Validate the raw payloads into typed models.
        event_profile = SlackAppMentionEventProfile(**event)
        event_body = SlackAppMentionEventBody(**body)

        logger.info(f"app_mention, context: {context}")
        logger.info(f"app_mention, client: {client}")
        logger.info(f"app_mention, event_profile: {event_profile}")
        logger.info(f"app_mention, event_body: {event_body}")
        logger.info(f"app_mention, say: {say}")

    async def on_message(
        self,
        context: "AsyncBoltContext",
        client: "AsyncWebClient",
        event: Dict[str, Any],
        body: Dict[str, Any],
        say: "AsyncSay",
    ) -> None:
        r"""Event handler for `message` events.

        This method is triggered when the app receives a message in Slack.
        It acknowledges the event, then validates and logs the payload; no
        reply is sent.

        Args:
            context (AsyncBoltContext): The Slack Bolt context for the event.
            client (AsyncWebClient): The Slack Web API client.
            event (Dict[str, Any]): The event data for the message.
            body (Dict[str, Any]): The full request body from Slack.
            say (AsyncSay): A function to send a response back to the channel.
        """
        # Acknowledge receipt so Slack does not retry delivery.
        await context.ack()

        # Validate the raw payloads into typed models.
        event_profile = SlackEventProfile(**event)
        event_body = SlackEventBody(**body)

        logger.info(f"on_message, context: {context}")
        logger.info(f"on_message, client: {client}")
        logger.info(f"on_message, event_profile: {event_profile}")
        logger.info(f"on_message, event_body: {event_body}")
        logger.info(f"on_message, say: {say}")

        logger.info(f"Received message: {event_profile.text}")

    def mention_me(
        self, context: "AsyncBoltContext", body: SlackEventBody
    ) -> bool:
        r"""Check if the bot is mentioned in the message.

        Args:
            context (AsyncBoltContext): The Slack Bolt context for the event.
            body (SlackEventBody): The body of the Slack event.

        Returns:
            bool: True if the bot is mentioned in the message, False otherwise.
        """
        message = body.event.text
        bot_user_id = context.bot_user_id
        # Slack encodes user mentions as "<@USERID>" inside message text.
        mention = f"<@{bot_user_id}>"
        return mention in message
camel/bots/telegram_bot.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+ import os
15
+ from typing import TYPE_CHECKING, Optional
16
+
17
+ from camel.agents import ChatAgent
18
+ from camel.messages import BaseMessage
19
+ from camel.utils import dependencies_required
20
+
21
+ # Conditionally import telebot types only for type checking
22
+ if TYPE_CHECKING:
23
+ from telebot.types import ( # type: ignore[import-untyped]
24
+ Message,
25
+ )
26
+
27
+
28
class TelegramBot:
    r"""Represents a Telegram bot that is powered by an agent.

    Attributes:
        chat_agent (ChatAgent): Chat agent that will power the bot.
        telegram_token (str, optional): The bot token.
    """

    @dependencies_required('telebot')
    def __init__(
        self,
        chat_agent: ChatAgent,
        telegram_token: Optional[str] = None,
    ) -> None:
        self.chat_agent = chat_agent

        # Prefer the explicit token; otherwise fall back to the environment.
        self.token = telegram_token or os.getenv('TELEGRAM_TOKEN')
        if not self.token:
            raise ValueError(
                "`TELEGRAM_TOKEN` not found in environment variables. "
                "Get it from t.me/BotFather."
            )

        import telebot  # type: ignore[import-untyped]

        self.bot = telebot.TeleBot(token=self.token)

        # Route every incoming message to `on_message`.
        self.bot.message_handler(func=lambda _message: True)(self.on_message)

    def run(self) -> None:
        r"""Start the Telegram bot and poll for updates indefinitely."""
        print("Telegram bot is running...")
        self.bot.infinity_polling()

    def on_message(self, message: 'Message') -> None:
        r"""Handles incoming messages from the user.

        Args:
            message (types.Message): The incoming message object.
        """
        # Each message starts a fresh conversation with the agent.
        self.chat_agent.reset()

        # Ignore non-text updates (photos, stickers, etc.).
        if not message.text:
            return

        user_msg = BaseMessage.make_user_message(
            role_name="User", content=message.text
        )
        agent_reply = self.chat_agent.step(user_msg)

        self.bot.reply_to(message, agent_reply.msg.content)
camel/configs/__init__.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+ from .anthropic_config import ANTHROPIC_API_PARAMS, AnthropicConfig
15
+ from .base_config import BaseConfig
16
+ from .cohere_config import COHERE_API_PARAMS, CohereConfig
17
+ from .deepseek_config import DEEPSEEK_API_PARAMS, DeepSeekConfig
18
+ from .gemini_config import Gemini_API_PARAMS, GeminiConfig
19
+ from .groq_config import GROQ_API_PARAMS, GroqConfig
20
+ from .internlm_config import INTERNLM_API_PARAMS, InternLMConfig
21
+ from .litellm_config import LITELLM_API_PARAMS, LiteLLMConfig
22
+ from .mistral_config import MISTRAL_API_PARAMS, MistralConfig
23
+ from .nvidia_config import NVIDIA_API_PARAMS, NvidiaConfig
24
+ from .ollama_config import OLLAMA_API_PARAMS, OllamaConfig
25
+ from .openai_config import OPENAI_API_PARAMS, ChatGPTConfig
26
+ from .qwen_config import QWEN_API_PARAMS, QwenConfig
27
+ from .reka_config import REKA_API_PARAMS, RekaConfig
28
+ from .openrouter_config import OPENROUTER_API_PARAMS, OpenRouterConfig
29
+ from .samba_config import (
30
+ SAMBA_CLOUD_API_PARAMS,
31
+ SAMBA_VERSE_API_PARAMS,
32
+ SambaCloudAPIConfig,
33
+ SambaVerseAPIConfig,
34
+ )
35
+ from .sglang_config import SGLANG_API_PARAMS, SGLangConfig
36
+ from .togetherai_config import TOGETHERAI_API_PARAMS, TogetherAIConfig
37
+ from .vllm_config import VLLM_API_PARAMS, VLLMConfig
38
+ from .yi_config import YI_API_PARAMS, YiConfig
39
+ from .zhipuai_config import ZHIPUAI_API_PARAMS, ZhipuAIConfig
40
+
41
+ __all__ = [
42
+ 'BaseConfig',
43
+ 'ChatGPTConfig',
44
+ 'OPENAI_API_PARAMS',
45
+ 'AnthropicConfig',
46
+ 'ANTHROPIC_API_PARAMS',
47
+ 'GROQ_API_PARAMS',
48
+ 'GroqConfig',
49
+ 'LiteLLMConfig',
50
+ 'LITELLM_API_PARAMS',
51
+ 'NvidiaConfig',
52
+ 'NVIDIA_API_PARAMS',
53
+ 'OllamaConfig',
54
+ 'OLLAMA_API_PARAMS',
55
+ 'ZhipuAIConfig',
56
+ 'ZHIPUAI_API_PARAMS',
57
+ 'GeminiConfig',
58
+ 'Gemini_API_PARAMS',
59
+ 'VLLMConfig',
60
+ 'VLLM_API_PARAMS',
61
+ 'SGLangConfig',
62
+ 'SGLANG_API_PARAMS',
63
+ 'MistralConfig',
64
+ 'MISTRAL_API_PARAMS',
65
+ 'RekaConfig',
66
+ 'REKA_API_PARAMS',
67
+ 'SambaVerseAPIConfig',
68
+ 'SAMBA_VERSE_API_PARAMS',
69
+ 'SambaCloudAPIConfig',
70
+ 'SAMBA_CLOUD_API_PARAMS',
71
+ 'TogetherAIConfig',
72
+ 'TOGETHERAI_API_PARAMS',
73
+ 'CohereConfig',
74
+ 'COHERE_API_PARAMS',
75
+ 'YiConfig',
76
+ 'YI_API_PARAMS',
77
+ 'QwenConfig',
78
+ 'QWEN_API_PARAMS',
79
+ 'DeepSeekConfig',
80
+ 'DEEPSEEK_API_PARAMS',
81
+ 'InternLMConfig',
82
+ 'INTERNLM_API_PARAMS',
83
+ 'OPENROUTER_API_PARAMS',
84
+ 'OpenRouterConfig',
85
+ ]
camel/configs/anthropic_config.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+ from __future__ import annotations
15
+
16
+ from typing import Any, ClassVar, List, Union
17
+
18
+ from camel.configs.base_config import BaseConfig
19
+ from camel.types import NotGiven
20
+
21
+
22
class AnthropicConfig(BaseConfig):
    r"""Parameters accepted when generating chat completions through the
    Anthropic API.

    See: https://docs.anthropic.com/claude/reference/complete_post

    Args:
        max_tokens (int, optional): Absolute upper bound on the number of
            tokens to generate before stopping; the model may stop earlier.
            (default: :obj:`8192`)
        stop_sequences (List[str], optional): Extra strings that cause
            generation to stop, in addition to Anthropic's built-in stop
            sequences such as ``"\n\nHuman:"``. (default: :obj:`[]`)
        temperature (float, optional): Amount of randomness injected into
            the response, ranging from 0 to 1. Values near 0 suit
            analytical / multiple-choice tasks, values near 1 suit creative
            tasks. (default: :obj:`1`)
        top_p (float, optional): Nucleus sampling: cut off the cumulative
            probability distribution at `top_p`. Alter either `temperature`
            or `top_p`, not both. (default: :obj:`0.7`)
        top_k (int, optional): Sample only from the top K options per
            token, trimming "long tail" low-probability responses.
            (default: :obj:`5`)
        stream (bool, optional): Whether to stream the response
            incrementally via server-sent events. (default: :obj:`False`)
    """

    max_tokens: int = 8192
    # NOTE(review): as a ClassVar this is not a pydantic field, so it is
    # absent from ``model_fields`` and thus from ANTHROPIC_API_PARAMS —
    # confirm that excluding stop_sequences from API params is intended.
    stop_sequences: ClassVar[Union[List[str], NotGiven]] = []
    temperature: float = 1
    top_p: Union[float, NotGiven] = 0.7
    top_k: Union[int, NotGiven] = 5
    stream: bool = False

    def as_dict(self) -> dict[str, Any]:
        config_dict = super().as_dict()
        # Tool calling is not supported for Anthropic yet, so drop the
        # inherited "tools" entry if present. TODO: Support tool calling.
        config_dict.pop("tools", None)
        return config_dict
69
+
70
+
71
# Set of parameter names the Anthropic backend accepts.
ANTHROPIC_API_PARAMS = set(AnthropicConfig.model_fields.keys())
camel/configs/base_config.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+ from __future__ import annotations
15
+
16
+ from abc import ABC
17
+ from typing import Any, List, Optional
18
+
19
+ from pydantic import BaseModel, ConfigDict, field_validator
20
+
21
+
22
class BaseConfig(ABC, BaseModel):
    r"""Common base class for all model configurations.

    Guarantees every model config exposes a consistent set of attributes
    and helpers.
    """

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="forbid",
        frozen=True,
        # Silences the pydantic UserWarning about the protected
        # "model_" namespace.
        protected_namespaces=(),
    )

    # Optional list of tools the model may call. Only functions are
    # supported as tools; the model may generate JSON inputs for them.
    # At most 128 functions are supported.
    tools: Optional[List[Any]] = None

    @field_validator("tools", mode="before")
    @classmethod
    def fields_type_checking(cls, tools):
        r"""Ensure every configured tool is a `FunctionTool`.

        Raises:
            ValueError: If any entry is not a `FunctionTool` instance.
        """
        if tools is None:
            return tools
        # Imported lazily to avoid a circular import at module load time.
        from camel.toolkits import FunctionTool

        for tool in tools:
            if not isinstance(tool, FunctionTool):
                raise ValueError(
                    f"The tool {tool} should "
                    "be an instance of `FunctionTool`."
                )
        return tools

    def as_dict(self) -> dict[str, Any]:
        r"""Convert the current configuration to a dictionary.

        The ``tools`` entry is replaced with each tool's OpenAI tool
        schema (or ``None`` when no tools are configured).

        Returns:
            dict[str, Any]: A dictionary representation of the current
                configuration.
        """
        config_dict = self.model_dump()

        schemas = None
        if self.tools:
            from camel.toolkits import FunctionTool

            schemas = []
            for tool in self.tools:
                if not isinstance(tool, FunctionTool):
                    raise ValueError(
                        f"The tool {tool} should "
                        "be an instance of `FunctionTool`."
                    )
                schemas.append(tool.get_openai_tool_schema())
        # Always overwrite: an empty/None tools field dumps as None here.
        config_dict["tools"] = schemas
        return config_dict
camel/configs/cohere_config.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+ from __future__ import annotations
15
+
16
+ from typing import List, Optional
17
+
18
+ from camel.configs.base_config import BaseConfig
19
+
20
+
21
class CohereConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    Cohere API.

    Args:
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more random,
            while lower values make it more focused and deterministic.
            (default: :obj:`0.2`)
        documents (list, optional): A list of relevant documents that the
            model can cite to generate a more accurate reply. Each document is
            either a string or document object with content and metadata.
            (default: :obj:`None`)
        max_tokens (int, optional): The maximum number of tokens the model
            will generate as part of the response. (default: :obj:`None`)
        stop_sequences (List(str), optional): A list of up to 5 strings that
            the model will use to stop generation. If the model generates a
            string that matches any of the strings in the list, it will stop
            generating tokens and return the generated text up to that point
            not including the stop sequence. (default: :obj:`None`)
        seed (int, optional): If specified, the backend will make a best
            effort to sample tokens deterministically, such that repeated
            requests with the same seed and parameters should return the same
            result. However, determinism cannot be totally guaranteed.
            (default: :obj:`None`)
        frequency_penalty (float, optional): Min value of `0.0`, max value of
            `1.0`. Used to reduce repetitiveness of generated tokens. The
            higher the value, the stronger a penalty is applied to previously
            present tokens, proportional to how many times they have already
            appeared in the prompt or prior generation. (default: :obj:`0.0`)
        presence_penalty (float, optional): Min value of `0.0`, max value of
            `1.0`. Used to reduce repetitiveness of generated tokens. Similar
            to `frequency_penalty`, except that this penalty is applied
            equally to all tokens that have already appeared, regardless of
            their exact frequencies. (default: :obj:`0.0`)
        k (int, optional): Ensures only the top k most likely tokens are
            considered for generation at each step. Min value of `0`, max
            value of `500`. (default: :obj:`0`)
        p (float, optional): Ensures that only the most likely tokens, with
            total probability mass of `p`, are considered for generation at
            each step. If both k and p are enabled, `p` acts after `k`. Min
            value of `0.01`, max value of `0.99`. (default: :obj:`0.75`)
    """

    # NOTE(review): Cohere's own API default temperature is 0.3; the 0.2
    # here is this library's choice — confirm it is intentional.
    temperature: Optional[float] = 0.2
    documents: Optional[list] = None
    max_tokens: Optional[int] = None
    stop_sequences: Optional[List[str]] = None
    seed: Optional[int] = None
    frequency_penalty: Optional[float] = 0.0
    presence_penalty: Optional[float] = 0.0
    k: Optional[int] = 0
    p: Optional[float] = 0.75
+ p: Optional[float] = 0.75
74
+
75
+
76
# Set of parameter names the Cohere backend accepts. Read off the class,
# not an instance: instantiating the config here is needless work, and
# accessing ``model_fields`` on an instance is deprecated in pydantic
# >= 2.11. This also matches every sibling ``*_API_PARAMS`` definition.
COHERE_API_PARAMS = {param for param in CohereConfig.model_fields.keys()}
camel/configs/deepseek_config.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+
15
+ from __future__ import annotations
16
+
17
+ from typing import Any, Optional, Sequence, Type, Union
18
+
19
+ from pydantic import BaseModel
20
+
21
+ from camel.configs.base_config import BaseConfig
22
+ from camel.types import NOT_GIVEN, NotGiven
23
+
24
+
25
class DeepSeekConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    DeepSeek API.

    Args:
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more random,
            while lower values make it more focused and deterministic.
            (default: :obj:`1.0`)
        top_p (float, optional): Controls the diversity and focus of the
            generated results. Higher values make the output more diverse,
            while lower values make it more focused. (default: :obj:`1.0`)
        response_format (object, optional): Specifies the format of the
            returned content. The available values are `{"type": "text"}` or
            `{"type": "json_object"}`. Setting it to `{"type": "json_object"}`
            will output a standard JSON string.
            (default: :obj:`{"type": "text"}`)
        stream (bool, optional): If set, partial message deltas will be sent.
            Tokens will be sent as data-only server-sent events (SSE) as
            they become available, with the stream terminated by a
            data: [DONE] message. (default: :obj:`False`)
        stop (Union[str, list[str]], optional): Up to 16 sequences where
            the API will stop generating further tokens. (default: :obj:`None`)
        max_tokens (int, optional): The maximum number of tokens that can
            be generated in the chat completion. The total length of input
            tokens and generated tokens is limited by the model's context
            length. (default: :obj:`None`)
        presence_penalty (float, optional): Number between -2.0 and 2.0.
            Positive values penalize new tokens based on whether they
            appear in the text so far, increasing the model's likelihood
            to talk about new topics. (default: :obj:`0.0`)
        frequency_penalty (float, optional): Number between -2.0 and 2.0.
            Positive values penalize new tokens based on their existing
            frequency in the text so far, decreasing the model's likelihood
            to repeat the same line verbatim. (default: :obj:`0`)
        tools (list[FunctionTool], optional): A list of tools the model may
            call. Currently, only functions are supported as a tool. Use
            this to provide a list of functions the model may generate JSON
            inputs for. A max of 128 functions are supported.
            (default: :obj:`None`)
        tool_choice (Union[dict[str, str], str], optional): Controls which
            (if any) tool is called by the model. "none" means the model
            will not call any tool and instead generates a message. "auto"
            means the model can pick between generating a message or calling
            one or more tools. "required" means the model must call one or
            more tools. Specifying a particular tool via
            {"type": "function", "function": {"name": "my_function"}} forces
            the model to call that tool. "none" is the default when no tools
            are present. "auto" is the default if tools are present.
            (default: :obj:`"auto"`)
        logprobs (bool, optional): Whether to return log probabilities of
            the output tokens or not. If true, returns the log probabilities
            of each output token returned in the content of message.
            (default: :obj:`False`)
        top_logprobs (int, optional): An integer between 0 and 20 specifying
            the number of most likely tokens to return at each token
            position, each with an associated log probability. logprobs
            must be set to true if this parameter is used.
            (default: :obj:`None`)
        include_usage (bool, optional): When streaming, specifies whether to
            include usage information in `stream_options`. (default:
            :obj:`True`)
    """

    temperature: float = 1.0  # deepseek default: 1.0
    top_p: float = 1.0
    stream: bool = False
    stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
    max_tokens: Union[int, NotGiven] = NOT_GIVEN
    presence_penalty: float = 0.0
    response_format: Union[Type[BaseModel], dict, NotGiven] = NOT_GIVEN
    frequency_penalty: float = 0.0
    tool_choice: Optional[Union[dict[str, str], str]] = None
    logprobs: bool = False
    top_logprobs: Optional[int] = None

    def __init__(self, include_usage: bool = True, **kwargs):
        super().__init__(**kwargs)
        # Only set stream_options when stream is True.
        # Otherwise, it will raise error when calling the API.
        # NOTE(review): BaseConfig declares ``frozen=True`` and
        # ``extra="forbid"``, so on pydantic v2 this assignment looks like
        # it would raise for streaming configs — confirm with a
        # ``stream=True`` instantiation.
        if self.stream:
            self.stream_options = {"include_usage": include_usage}

    def as_dict(self) -> dict[str, Any]:
        r"""Convert the current configuration to a dictionary.

        Tools are validated but never forwarded: DeepSeek requests are
        sent with ``tools`` replaced by ``NOT_GIVEN``.

        Returns:
            dict[str, Any]: A dictionary representation of the current
                configuration.

        Raises:
            ValueError: If any entry in ``tools`` is not a `FunctionTool`.
        """
        config_dict = self.model_dump()
        if self.tools:
            from camel.toolkits import FunctionTool

            # Validate only. The previous implementation also built the
            # OpenAI tool-schema list here and then discarded it — that
            # dead accumulation is removed.
            for tool in self.tools:
                if not isinstance(tool, FunctionTool):
                    raise ValueError(
                        f"The tool {tool} should "
                        "be an instance of `FunctionTool`."
                    )
        config_dict["tools"] = NOT_GIVEN
        return config_dict
132
+
133
+
134
# Set of parameter names the DeepSeek backend accepts.
DEEPSEEK_API_PARAMS = set(DeepSeekConfig.model_fields.keys())
camel/configs/gemini_config.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
14
+
15
+ from __future__ import annotations
16
+
17
+ from typing import Any, Optional, Sequence, Type, Union
18
+
19
+ from pydantic import BaseModel
20
+
21
+ from camel.configs.base_config import BaseConfig
22
+ from camel.types import NOT_GIVEN, NotGiven
23
+
24
+
25
class GeminiConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    Gemini API.

    Args:
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more random,
            while lower values make it more focused and deterministic.
            (default: :obj:`0.2`)
        top_p (float, optional): An alternative to sampling with temperature,
            called nucleus sampling, where the model considers the results of
            the tokens with top_p probability mass. So :obj:`0.1` means only
            the tokens comprising the top 10% probability mass are considered.
            (default: :obj:`1.0`)
        n (int, optional): How many chat completion choices to generate for
            each input message. (default: :obj:`1`)
        response_format (object, optional): An object specifying the format
            that the model must output. Setting to {"type": "json_object"}
            enables JSON mode, which guarantees the message the model
            generates is valid JSON. Important: when using JSON mode, you
            must also instruct the model to produce JSON yourself via a
            system or user message; otherwise the model may generate an
            unending stream of whitespace until the generation reaches the
            token limit. Also note that the message content may be partially
            cut off if finish_reason="length", which indicates the
            generation exceeded max_tokens or the conversation exceeded the
            max context length.
        stream (bool, optional): If True, partial message deltas will be sent
            as data-only server-sent events as they become available.
            (default: :obj:`False`)
        stop (str or list, optional): Up to :obj:`4` sequences where the API
            will stop generating further tokens. (default: :obj:`None`)
        max_tokens (int, optional): The maximum number of tokens to generate
            in the chat completion. The total length of input tokens and
            generated tokens is limited by the model's context length.
            (default: :obj:`None`)
        tools (list[FunctionTool], optional): A list of tools the model may
            call. Currently, only functions are supported as a tool. Use this
            to provide a list of functions the model may generate JSON inputs
            for. A max of 128 functions are supported.
        tool_choice (Union[dict[str, str], str], optional): Controls which (if
            any) tool is called by the model. :obj:`"none"` means the model
            will not call any tool and instead generates a message.
            :obj:`"auto"` means the model can pick between generating a
            message or calling one or more tools. :obj:`"required"` means the
            model must call one or more tools. Specifying a particular tool
            via {"type": "function", "function": {"name": "my_function"}}
            forces the model to call that tool. :obj:`"none"` is the default
            when no tools are present. :obj:`"auto"` is the default if tools
            are present.
    """

    temperature: float = 0.2  # openai default: 1.0
    top_p: float = 1.0
    n: int = 1
    stream: bool = False
    stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
    max_tokens: Union[int, NotGiven] = NOT_GIVEN
    response_format: Union[Type[BaseModel], dict, NotGiven] = NOT_GIVEN
    tool_choice: Optional[Union[dict[str, str], str, NotGiven]] = NOT_GIVEN

    def as_dict(self) -> dict[str, Any]:
        r"""Convert the current configuration to a dictionary.

        Tools are validated but never forwarded: Gemini requests are sent
        with ``tools`` replaced by ``NOT_GIVEN``.

        Returns:
            dict[str, Any]: A dictionary representation of the current
                configuration.

        Raises:
            ValueError: If any entry in ``tools`` is not a `FunctionTool`.
        """
        config_dict = self.model_dump()
        if self.tools:
            from camel.toolkits import FunctionTool

            # Validate only. The previous implementation also built the
            # OpenAI tool-schema list here and then discarded it — that
            # dead accumulation is removed.
            for tool in self.tools:
                if not isinstance(tool, FunctionTool):
                    raise ValueError(
                        f"The tool {tool} should "
                        "be an instance of `FunctionTool`."
                    )
        config_dict["tools"] = NOT_GIVEN
        return config_dict
+
113
+
114
# Set of parameter names the Gemini backend accepts. The mixed-case name
# is kept: it is exported as 'Gemini_API_PARAMS' in the package __all__.
Gemini_API_PARAMS = set(GeminiConfig.model_fields.keys())