upd code
- .gitignore +16 -0
- LICENSE +21 -0
- ProjectPageAgent/__init__.py +7 -0
- ProjectPageAgent/content_planner.py +509 -0
- ProjectPageAgent/css_checker.py +111 -0
- ProjectPageAgent/html_finder.py +32 -0
- ProjectPageAgent/html_generator.py +633 -0
- ProjectPageAgent/main_pipline.py +379 -0
- ProjectPageAgent/parse_paper.py +88 -0
- ProjectPageAgent/parse_raw.py +256 -0
- ProjectPageAgent/template_analyzer.py +436 -0
- app.py +1671 -0
- camel/__init__.py +25 -0
- camel/agents/__init__.py +44 -0
- camel/agents/base.py +29 -0
- camel/agents/chat_agent.py +1539 -0
- camel/agents/critic_agent.py +202 -0
- camel/agents/deductive_reasoner_agent.py +303 -0
- camel/agents/embodied_agent.py +201 -0
- camel/agents/knowledge_graph_agent.py +259 -0
- camel/agents/multi_hop_generator_agent.py +117 -0
- camel/agents/programmed_agent_instruction.py +203 -0
- camel/agents/role_assignment_agent.py +141 -0
- camel/agents/search_agent.py +133 -0
- camel/agents/task_agent.py +410 -0
- camel/agents/tool_agents/__init__.py +20 -0
- camel/agents/tool_agents/base.py +39 -0
- camel/agents/tool_agents/hugging_face_tool_agent.py +206 -0
- camel/benchmarks/__init__.py +30 -0
- camel/benchmarks/apibank.py +565 -0
- camel/benchmarks/apibench.py +500 -0
- camel/benchmarks/base.py +152 -0
- camel/benchmarks/gaia.py +478 -0
- camel/benchmarks/nexus.py +518 -0
- camel/benchmarks/ragbench.py +333 -0
- camel/bots/__init__.py +34 -0
- camel/bots/discord/__init__.py +26 -0
- camel/bots/discord/discord_app.py +384 -0
- camel/bots/discord/discord_installation.py +64 -0
- camel/bots/discord/discord_store.py +160 -0
- camel/bots/slack/__init__.py +30 -0
- camel/bots/slack/models.py +158 -0
- camel/bots/slack/slack_app.py +255 -0
- camel/bots/telegram_bot.py +82 -0
- camel/configs/__init__.py +85 -0
- camel/configs/anthropic_config.py +71 -0
- camel/configs/base_config.py +89 -0
- camel/configs/cohere_config.py +76 -0
- camel/configs/deepseek_config.py +134 -0
- camel/configs/gemini_config.py +114 -0
.gitignore
ADDED
@@ -0,0 +1,16 @@
+templates/**/*.wav
+templates/**/*.mp4
+templates/**/*.gif
+templates/**/*.webm
+templates/**/*.mov
+templates/**/*.ttf
+templates/**/*.pdf
+templates/**/*?
+*.woff
+*.woff2
+*.png
+*.jpg
+
+.DS_Store
+
+**/__pycache__/*
LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2025 Qianli Ma
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
ProjectPageAgent/__init__.py
ADDED
@@ -0,0 +1,7 @@
+"""
+ProjectPageAgent: A multi-agent system for generating project pages from research papers.
+Based on the Paper2Poster architecture, adapted for project page generation.
+"""
+
+__version__ = "1.0.0"
+__author__ = "Paper2ProjectPage Team"
ProjectPageAgent/content_planner.py
ADDED
@@ -0,0 +1,509 @@
+"""
+Content planner for project page generation.
+Plans the structure and content organization for the project page.
+"""
+
+import json
+import yaml
+import os
+from jinja2 import Environment, StrictUndefined
+from camel.models import ModelFactory
+from camel.agents import ChatAgent
+from utils.wei_utils import account_token
+from utils.src.utils import get_json_from_response
+from camel.messages import BaseMessage
+from rich import print
+from rich.pretty import Pretty
+import base64
+from camel.messages import BaseMessage
+from camel.models import ModelFactory
+
+def filter_references(md_content: str) -> str:
+    """Drop everything from the '## References' heading onward."""
+    lines = md_content.splitlines()
+    result_lines = []
+    for line in lines:
+        if line.strip().lower().startswith("## references"):
+            break
+        result_lines.append(line)
+    return "\n".join(result_lines)
+
+class ProjectPageContentPlanner:
+    """Plans the content structure and organization for project pages."""
+
+    def __init__(self, agent_config, args):
+        self.agent_config = agent_config
+        self.args = args
+        self.planner_agent = self._create_planner_agent()
+        self.reviewer_agent = self._create_reviewer_agent()
+        os.makedirs('project_contents', exist_ok=True)
+
+    def _create_planner_agent(self):
+        """Create the content planning (generation) agent."""
+        model_type = str(self.agent_config['model_type'])
+
+        # Get API key from environment variables
+        api_key = None
+        if self.args.model_name_t in ['4o', '4o-mini', 'gpt-4.1', 'gpt-4.1-mini', 'o1', 'o3', 'o3-mini']:
+            api_key = os.environ.get('OPENAI_API_KEY')
+        elif self.args.model_name_t in ['gemini', 'gemini-2.5-pro', 'gemini-2.5-flash']:
+            api_key = os.environ.get('GEMINI_API_KEY')
+        elif self.args.model_name_t in ['qwen', 'qwen-plus', 'qwen-max', 'qwen-long']:
+            api_key = os.environ.get('QWEN_API_KEY')
+        elif self.args.model_name_t.startswith('openrouter_'):
+            api_key = os.environ.get('OPENROUTER_API_KEY')
+        elif self.args.model_name_t in ['zhipuai']:
+            api_key = os.environ.get('ZHIPUAI_API_KEY')
+
+        if model_type.startswith('vllm_qwen') or 'vllm' in model_type.lower():
+            model = ModelFactory.create(
+                model_platform=self.agent_config['model_platform'],
+                model_type=self.agent_config['model_type'],
+                model_config_dict=self.agent_config['model_config'],
+                url=self.agent_config.get('url', None),
+                api_key=api_key,
+            )
+        else:
+            model = ModelFactory.create(
+                model_platform=self.agent_config['model_platform'],
+                model_type=self.agent_config['model_type'],
+                model_config_dict=self.agent_config['model_config'],
+                api_key=api_key,
+            )
+
+
+        system_message = """You are a helpful academic expert and web developer, specialized in generating a paper project page from a given research paper's contents and figures."""
+
+        return ChatAgent(
+            system_message=system_message,
+            model=model,
+            message_window_size=10,
+            token_limit=self.agent_config.get('token_limit', None)
+        )
+
+    def _create_reviewer_agent(self):
+        """Create the content review agent."""
+        model_type = str(self.agent_config['model_type'])
+
+        # Get API key from environment variables
+        api_key = None
+        if self.args.model_name_t in ['4o', '4o-mini', 'gpt-4.1', 'gpt-4.1-mini', 'o1', 'o3', 'o3-mini']:
+            api_key = os.environ.get('OPENAI_API_KEY')
+        elif self.args.model_name_t in ['gemini', 'gemini-2.5-pro', 'gemini-2.5-flash']:
+            api_key = os.environ.get('GEMINI_API_KEY')
+        elif self.args.model_name_t in ['qwen', 'qwen-plus', 'qwen-max', 'qwen-long']:
+            api_key = os.environ.get('QWEN_API_KEY')
+        elif self.args.model_name_t.startswith('openrouter_'):
+            api_key = os.environ.get('OPENROUTER_API_KEY')
+        elif self.args.model_name_t in ['zhipuai']:
+            api_key = os.environ.get('ZHIPUAI_API_KEY')
+
+        if model_type.startswith('vllm_qwen') or 'vllm' in model_type.lower():
+            model = ModelFactory.create(
+                model_platform=self.agent_config['model_platform'],
+                model_type=self.agent_config['model_type'],
+                model_config_dict=self.agent_config['model_config'],
+                url=self.agent_config.get('url', None),
+                api_key=api_key,
+            )
+        else:
+            model = ModelFactory.create(
+                model_platform=self.agent_config['model_platform'],
+                model_type=self.agent_config['model_type'],
+                model_config_dict=self.agent_config['model_config'],
+                api_key=api_key,
+            )
+
+        reviewer_system = (
+            "You are a precise, constructive reviewer of generated project pages."
+        )
+        return ChatAgent(
+            system_message=reviewer_system,
+            model=model,
+            message_window_size=10,
+            token_limit=self.agent_config.get('token_limit', None)
+        )
+
+    def _render_generation_prompt(self, paper_content, figures, text_page_content, template_str):
+        """Render the full-content generation prompt from a template string."""
+        jinja_env = Environment(undefined=StrictUndefined)
+        template = jinja_env.from_string(template_str)
+        jinja_args = {
+            'paper_content': paper_content,
+            'figures': json.dumps(figures, indent=2),
+            'project_page_content': json.dumps(text_page_content, indent=2),
+        }
+        return template.render(**jinja_args)
+
+    def _build_reviewer_prompt(self, paper_content, figures, text_page_content, generated_json):
+        """Build the prompt that asks the reviewer agent to critique generated content."""
+        with open('utils/prompt_templates/page_templates/full_content_review.yaml', 'r') as f:
+            planner_config = yaml.safe_load(f)
+
+        jinja_env = Environment(undefined=StrictUndefined)
+        template = jinja_env.from_string(planner_config["template"])
+
+        jinja_args = {
+            'paper_content': paper_content,
+            'figures': json.dumps(figures['images'], indent=2),
+            'tables': json.dumps(figures['tables'], indent=2),
+            "generated_content": generated_json
+        }
+
+        prompt = template.render(**jinja_args)
+
+        return prompt
+
+    def _build_revision_prompt(self, review_json):
+        with open('utils/prompt_templates/page_templates/full_content_revise.yaml', 'r') as f:
+            planner_config = yaml.safe_load(f)
+
+        jinja_env = Environment(undefined=StrictUndefined)
+        template = jinja_env.from_string(planner_config["template"])
+
+        jinja_args = {
+            "review_content": json.dumps(review_json, indent=2)
+        }
+
+        prompt = template.render(**jinja_args)
+
+        return prompt
+
+    def _build_revision_prompt_with_resume(self, review_json, current_content, figures):
+        with open('utils/prompt_templates/page_templates/full_content_revise_with_resume.yaml', 'r') as f:
+            planner_config = yaml.safe_load(f)
+
+        jinja_env = Environment(undefined=StrictUndefined)
+        template = jinja_env.from_string(planner_config["template"])
+
+        print(review_json)
+
+        jinja_args = {
+            "review_content": json.dumps(review_json, indent=2),
+            "figures": json.dumps(figures, indent=2),
+            "current_content": current_content
+        }
+
+        prompt = template.render(**jinja_args)
+
+        return prompt
+
+    def full_content_generation(
+        self,
+        args,
+        paper_content,
+        figures,
+        generated_section,
+        text_page_content,
+    ):
+        """
+        Plan + Generate -> Review -> Revise
+
+        Args:
+            paper_content: parsed paper content
+            figures: list/dict of figures
+            generated_section: format_instructions / schema hints
+            text_page_content: initial text-only page structure
+
+        Returns:
+            tuple: (final_generated_content_json, input_token_total, output_token_total)
+        """
+        if args.resume in ['parse_pdf', 'generate_content']:
+
+            print("full content generation start")
+
+            with open('utils/prompt_templates/page_templates/full_content_generation.yaml', 'r') as f:
+                planner_config = yaml.safe_load(f)
+
+            jinja_env = Environment(undefined=StrictUndefined)
+            template = jinja_env.from_string(planner_config["template"])
+
+            jinja_args = {
+                'paper_content': paper_content,
+                'figures': json.dumps(figures, indent=2),
+                'project_page_content': json.dumps(text_page_content, indent=2)
+            }
+
+            prompt = template.render(**jinja_args)
+
+            self.planner_agent.reset()
+            response = self.planner_agent.step(prompt)
+
+            gen_in_tok, gen_out_tok = account_token(response)
+
+            current_output = get_json_from_response(response.msgs[0].content)
+
+            first_path = f'project_contents/{self.args.paper_name}_generated_full_content.v0.json'
+            with open(first_path, 'w', encoding='utf-8') as f:
+                json.dump(current_output, f, ensure_ascii=False, indent=2)
+            print(f" - Initial generation saved: {first_path}")
+
+            total_in_tok, total_out_tok = gen_in_tok, gen_out_tok
+        else:
+            print("Skipping initial full content generation, loading existing content.")
+            with open(f'project_contents/{self.args.paper_name}_generated_full_content.v0.json', 'r', encoding='utf-8') as f:
+                current_output = json.load(f)
+            total_in_tok, total_out_tok = 0, 0
+
+        for it in range(0, args.full_content_check_times):
+            # Review the current output.
+            self.reviewer_agent.reset()
+
+            review_prompt = self._build_reviewer_prompt(
+                paper_content=paper_content,
+                figures=figures,
+                text_page_content=text_page_content,
+                generated_json=current_output
+            )
+            review_resp = self.reviewer_agent.step(review_prompt)
+            rin, rout = account_token(review_resp)
+
+            review_json = get_json_from_response(review_resp.msgs[0].content)
+
+            review_path = f'project_contents/{self.args.paper_name}_review.iter{it}.json'
+            with open(review_path, 'w', encoding='utf-8') as f:
+                json.dump(review_json, f, ensure_ascii=False, indent=2)
+            print(f" - Review saved: {review_path}")
+
+            total_in_tok += rin
+            total_out_tok += rout
+
+            if args.resume != 'full_content_check':
+                revision_prompt = self._build_revision_prompt(
+                    review_json=review_json
+                )
+
+            else:
+                revision_prompt = self._build_revision_prompt_with_resume(
+                    review_json=review_json,
+                    current_content=current_output,
+                    figures=figures
+                )
+            rev_resp = self.planner_agent.step(revision_prompt)
+            rin2, rout2 = account_token(rev_resp)
+
+            revised_output = get_json_from_response(rev_resp.msgs[0].content)
+
+            out_path = f'project_contents/{self.args.paper_name}_generated_full_content.v{it+1}.json'
+            with open(out_path, 'w', encoding='utf-8') as f:
+                json.dump(revised_output, f, ensure_ascii=False, indent=2)
+            print(f" - Revised generation saved: {out_path}")
+
+            total_in_tok += rin2
+            total_out_tok += rout2
+            current_output = revised_output
+        if self.args.human_input == '1':
+            print('-'*50)
+            print(Pretty(current_output, expand_all=True))
+            print('-'*50)
+            user_feedback = input('The above is the final generated full content! If you are satisfied with the generated content, enter yes.\nIf not, enter your feedback.\n')
+            while user_feedback.lower() != 'yes':
+                message = BaseMessage.make_assistant_message(
+                    role_name='User',
+                    content='Human feedback: ' + user_feedback + " The above is human feedback. Please make modifications based on this feedback and the original content. The output format is as specified above."
+                )
+                response = self.planner_agent.step(message)
+                current_output = get_json_from_response(response.msgs[0].content)
+                print('-'*50)
+                print(Pretty(current_output, expand_all=True))
+                print('-'*50)
+                user_feedback = input('The above is the final generated full content! If you are satisfied with the generated content, enter yes.\nIf not, enter your feedback.\n')
+                in_tok, out_tok = account_token(response)
+                total_in_tok += in_tok
+                total_out_tok += out_tok
+
+        # 4) Final save (keep the original file naming).
+        final_path = f'project_contents/{self.args.paper_name}_generated_full_content.json'
+        with open(final_path, 'w', encoding='utf-8') as f:
+            json.dump(current_output, f, ensure_ascii=False, indent=2)
+        print(f"full content generation completed. Tokens: {total_in_tok} -> {total_out_tok}")
+        print(f" - Final content: {final_path}")
+
+        return current_output, total_in_tok, total_out_tok
+
+    def section_generation(self, paper_content, figures):
+        """
+        Plan the section structure for the project page.
+
+        Args:
+            paper_content: Parsed paper content
+
+        Returns:
+            tuple: (generated_section, input_token, output_token)
+        """
+
+        # Load planning prompt template
+
+        with open('utils/prompt_templates/page_templates/section_generation.yaml', 'r') as f:
+            planner_config = yaml.safe_load(f)
+
+        jinja_env = Environment(undefined=StrictUndefined)
+        template = jinja_env.from_string(planner_config["template"])
+
+        json_format_example = """
+```json
+{{
+    "Introduction": "Brief overview of the paper's main topic and objectives.",
+    "Methodology": "Description of the methods used in the research.",
+    "Results": "Summary of the key findings and results."
+}}
+```
+"""
+
+        # Prepare template arguments (pass the example schema, not the paper itself)
+        jinja_args = {
+            'paper_content': paper_content,
+            'json_format_example': json_format_example
+        }
+
+        prompt = template.render(**jinja_args)
+
+        # Generate content plan
+        self.planner_agent.reset()
+        response = self.planner_agent.step(prompt)
+        input_token, output_token = account_token(response)
+        generated_section = get_json_from_response(response.msgs[0].content)
+
+        if self.args.human_input == '1':
+            print('-'*50)
+            print(Pretty(generated_section, expand_all=True))
+            print('-'*50)
+            user_feedback = input('The above is the generated section! If you are satisfied with the generated section, enter yes.\nIf not, enter your feedback.\n')
+            while user_feedback.lower() != 'yes':
+                message = BaseMessage.make_assistant_message(
+                    role_name='User',
+                    content='Human feedback: ' + user_feedback + " The above is human feedback. Please make modifications based on this feedback and the original content. The output format is as specified above."
+                )
+                response = self.planner_agent.step(message)
+                generated_section = get_json_from_response(response.msgs[0].content)
+                print('-'*50)
+                print(Pretty(generated_section, expand_all=True))
+                print('-'*50)
+                user_feedback = input('The above is the generated section! If you are satisfied with the generated section, enter yes.\nIf not, enter your feedback.\n')
+                in_tok, out_tok = account_token(response)
+                input_token += in_tok
+                output_token += out_tok
+
+        print(f"section planning completed. Tokens: {input_token} -> {output_token}")
+
+        def create_dynamic_page_dict(sections: dict[str, str]) -> dict[str, str]:
+            poster_dict = {
+                "title": "Title of the paper",
+                "authors": "Authors of the paper. Each author must be accompanied by the superscript number(s) of their corresponding affiliation(s).",
+                "affiliation": "Affiliation of the authors, each affiliation must be accompanied by the corresponding superscript number.",
+            }
+
+            poster_dict.update(sections)
+            return poster_dict
+
+        generated_section = create_dynamic_page_dict(generated_section)
+
+        # Save generated content
+        # print(self.agent_config)
+        generated_path = f'project_contents/{self.args.paper_name}_generated_section.json'
+        with open(generated_path, 'w') as f:
+            json.dump(generated_section, f, indent=4)
+
+        print(f" - Generated section plan: {generated_path}")
+
+        return generated_section, input_token, output_token
+
+    def text_content_generation(self, paper_content, figures, generated_section):
+        """
+        Generate the text content for each planned section of the project page.
+
+        Args:
+            paper_content: Parsed paper content
+
+        Returns:
+            tuple: (generated_text_content, input_token, output_token)
+        """
+
+        # Delete tags in figures
+        figures_ = {}
+        figures_['images'] = [{k: v for k, v in value.items() if k != 'tag'} for value in figures['images'].values()]
+        figures_['tables'] = [{k: v for k, v in value.items() if k != 'tag'} for value in figures['tables'].values()]
+
+        # Load planning prompt template
+        with open('utils/prompt_templates/page_templates/text_content_generation.yaml', 'r') as f:
+            planner_config = yaml.safe_load(f)
+
+        jinja_env = Environment(undefined=StrictUndefined)
+        template = jinja_env.from_string(planner_config["template"])
+
+        # Prepare template arguments
+        jinja_args = {
+            'paper_content': paper_content,
+            'figures': json.dumps(figures_, indent=2),
+            'format_instructions': json.dumps(generated_section, indent=2)
+        }
+
+        prompt = template.render(**jinja_args)
+
+        # Generate content plan
+        self.planner_agent.reset()
+        response = self.planner_agent.step(prompt)
+        input_token, output_token = account_token(response)
+
+        generated_text_content = get_json_from_response(response.msgs[0].content)
+
+        print(f"text content generation completed. Tokens: {input_token} -> {output_token}")
+
+        # Save generated content
+        generated_path = f'project_contents/{self.args.paper_name}_generated_text_content.json'
+        with open(generated_path, 'w') as f:
+            json.dump(generated_text_content, f, indent=4)
+
+        print(f" - Generated text content: {generated_path}")
+
+        return generated_text_content, input_token, output_token
+
+    def filter_raw_content(self, paper_content, figures):
+        paper_content = filter_references(paper_content)
+        # Load planning prompt template
+        with open('utils/prompt_templates/page_templates/filter_figures.yaml', 'r') as f:
+            planner_config = yaml.safe_load(f)
+
+        jinja_env = Environment(undefined=StrictUndefined)
+        template = jinja_env.from_string(planner_config["template"])
+
+        # Prepare template arguments
+        jinja_args = {
+            'paper_content': paper_content,
+            'figures': json.dumps(figures, indent=2),
+        }
+
+        prompt = template.render(**jinja_args)
+
+        # Generate filtered figures
+        self.planner_agent.reset()
+        response = self.planner_agent.step(prompt)
+        input_token, output_token = account_token(response)
+        filtered_figures = get_json_from_response(response.msgs[0].content)
+        # print(filtered_figures)
+
+        def remove_items_without_section(data: dict) -> dict:
+            """Keep only images/tables that were assigned to a section."""
+            for key in ["images", "tables"]:
+                if key in data and isinstance(data[key], dict):
+                    data[key] = {
+                        k: v for k, v in data[key].items()
+                        if v.get("original_section") is not None
+                    }
+            return data
+
+        filtered_figures = remove_items_without_section(filtered_figures)
+
+        print(f"filtered figures generation completed. Tokens: {input_token} -> {output_token}")
+
+        # Save generated filtered figures
+        generated_path = f'project_contents/{self.args.paper_name}_generated_filtered_figures.json'
+        with open(generated_path, 'w') as f:
+            json.dump(filtered_figures, f, indent=4)
+
+        print(f" - Generated filtered figures: {generated_path}")
+
+        return paper_content, filtered_figures, input_token, output_token
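Taken together, the planner is a four-stage pipeline: `filter_raw_content` strips the references and unplaced figures, `section_generation` plans the page's sections, `text_content_generation` drafts text for them, and `full_content_generation` runs the generate/review/revise loop. Below is a minimal driver sketch inferred from the signatures above; the argument values, file paths, and the `paper_md`/`figures` inputs are placeholders, and `get_agent_config` is the helper imported elsewhere in this diff (`utils.wei_utils`), so treat this as orientation rather than the repo's actual entry point (`main_pipline.py`).

```python
# Hypothetical driver for ProjectPageContentPlanner; values are placeholders.
from types import SimpleNamespace

from ProjectPageAgent.content_planner import ProjectPageContentPlanner
from utils.wei_utils import get_agent_config

args = SimpleNamespace(
    model_name_t='4o',             # selects OPENAI_API_KEY in _create_planner_agent
    paper_name='MyPaper',          # used in all project_contents/ output file names
    human_input='0',               # '1' enables the interactive feedback loop
    resume='parse_pdf',            # 'parse_pdf'/'generate_content' run initial generation
    full_content_check_times=1,    # number of review -> revise rounds
)
agent_config = get_agent_config(args.model_name_t)

paper_md = open('paper.md').read()       # parsed paper markdown (placeholder path)
figures = {'images': {}, 'tables': {}}   # as produced by the parsing stage

planner = ProjectPageContentPlanner(agent_config, args)
# Stage 1: drop the references section and figures mapped to no section.
paper_md, figures, _, _ = planner.filter_raw_content(paper_md, figures)
# Stage 2: decide which sections the page should have.
sections, _, _ = planner.section_generation(paper_md, figures)
# Stage 3: write text for each planned section.
text_content, _, _ = planner.text_content_generation(paper_md, figures, sections)
# Stage 4: generate, review, and revise the full page content.
final_content, in_tok, out_tok = planner.full_content_generation(
    args, paper_md, figures, sections, text_content)
```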
ProjectPageAgent/css_checker.py
ADDED
@@ -0,0 +1,111 @@
+import re
+from collections import OrderedDict
+from ProjectPageAgent.html_finder import HtmlFinder
+import os
+
+
+
+_LINK_CSS_RE = re.compile(
+    r'''(?isx)
+    <link[^>]*?
+    href\s*=\s*
+    (?:
+        "([^"]+?\.css(?:\?[^"]*)?)" |
+        '([^']+?\.css(?:\?[^']*)?)' |
+        ([^\s"'=<>`]+?\.css(?:\?[^\s"'=<>`]*)?)
+    )
+    [^>]*?>
+    '''
+)
+
+
+_IMPORT_CSS_RE = re.compile(
+    r'''(?isx)
+    @import
+    \s+(?:url\()?
+    \s*
+    (?:
+        "([^"]+?\.css(?:\?[^"]*)?)" |
+        '([^']+?\.css(?:\?[^']*)?)' |
+        ([^'")\s;]+?\.css(?:\?[^'")\s;]+)?)
+    )
+    \s*
+    \)?
+    '''
+)
+
+
+def _first_nonempty(groups_list):
+    out = []
+    for groups in groups_list:
+        for g in groups:
+            if g:
+                out.append(g)
+                break
+    return out
+
+def extract_css_paths(html: str):
+    """Return the unique CSS paths referenced via <link> tags or @import rules."""
+    links = _first_nonempty(_LINK_CSS_RE.findall(html))
+    imports = _first_nonempty(_IMPORT_CSS_RE.findall(html))
+    seen = OrderedDict()
+    for u in links + imports:
+        u = u.strip()
+        if u and u not in seen:
+            seen[u] = True
+    return list(seen.keys())
+
+def check_css(generated_html: str, template_html: str):
+    generated_css = extract_css_paths(generated_html)
+    template_css = extract_css_paths(template_html)
+    print(f'num of css in generated page: {len(generated_css)}')
+    print(f'num of css in template page: {len(template_css)}')
+    template_css_name = {css.strip().split('/')[-1]: css for css in template_css}
+
+    errors = {}
+    for css in generated_css:
+        if css.startswith('http'):
+            continue
+        if css not in template_css:
+            match = template_css_name.get(css.strip().split('/')[-1], None)
+            if match is not None:
+                errors[css] = match
+            else:
+                print(f"[⚠️ Warning] Missing CSS match for {css}")
+
+    new_html = generated_html
+    for css, new_css in errors.items():
+        if new_css:
+            new_html = new_html.replace(css, new_css)
+
+    return new_html
+
+
+
+
+
+if __name__ == "__main__":
+
+    templates_root = '/home/jimu/Project_resources/project_page/page_assets/'
+    html_finder = HtmlFinder(specific_name='index.html')
+
+    count = 0
+    for page in os.listdir('generated_FastVGGT'):
+        print(page)
+        count += 1
+        with open(html_finder.find_html(os.path.join('generated_FastVGGT', page)), 'r') as f:
+            generated_html = f.read()
+
+        with open(html_finder.find_html(os.path.join(templates_root, page)), 'r') as f:
+            template_html = f.read()
+
+
+        _ = check_css(generated_html, template_html)
+    print(count)
+
+
+
+
+
+
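In short, `check_css` rewrites local stylesheet paths in the generated page so they match the template's paths whenever the file names agree. A small self-contained illustration (the HTML snippets are made up, not from the repo):

```python
from ProjectPageAgent.css_checker import extract_css_paths, check_css

template_html = '<link rel="stylesheet" href="static/css/bulma.min.css">'
generated_html = '<link rel="stylesheet" href="css/bulma.min.css">'  # wrong directory

print(extract_css_paths(generated_html))    # ['css/bulma.min.css']
fixed = check_css(generated_html, template_html)
print('static/css/bulma.min.css' in fixed)  # True: same basename, path rewritten
```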
ProjectPageAgent/html_finder.py
ADDED
@@ -0,0 +1,32 @@
+import os
+
+
+class HtmlFinder(object):
+    def __init__(self, specific_name=None):
+        self.queue = []
+        self.specific_name = specific_name
+
+    def find_html(self, path):
+        try:
+            if not os.path.isdir(path):
+                return
+            if self.queue:
+                del self.queue[0]
+            for dir in os.listdir(path):
+                dir_path = os.path.join(path, dir)
+                if os.path.isdir(dir_path):
+                    self.queue.append(dir_path)
+                elif self.specific_name is not None and dir_path.endswith(self.specific_name):
+                    return dir_path
+                elif dir_path.endswith(".html"):
+                    html_path = dir_path
+                    return html_path
+                else: continue
+            html_path = self.find_html(self.queue[0]) if self.queue else None
+            if html_path is not None:
+                return html_path
+        except Exception as e:
+            print(f"Error appears when finding {path}, error: {str(e)}")
+
+    def reset_queue(self):
+        self.queue = []
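`HtmlFinder.find_html` walks a directory tree breadth-first and returns the first hit: the file named by `specific_name` if one was given, otherwise any `.html` file. A usage sketch (the directory path is hypothetical):

```python
from ProjectPageAgent.html_finder import HtmlFinder

finder = HtmlFinder(specific_name='index.html')
# Breadth-first: direct children first, then queued subdirectories.
path = finder.find_html('page_assets/some_template')  # hypothetical directory
print(path)           # e.g. 'page_assets/some_template/index.html', or None
finder.reset_queue()  # clear queued subdirectories before reusing the finder
```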
ProjectPageAgent/html_generator.py
ADDED
|
@@ -0,0 +1,633 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
HTML generator for project page generation.
|
| 3 |
+
Generates the final HTML project page from planned content.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import json
|
| 7 |
+
import yaml
|
| 8 |
+
import os
|
| 9 |
+
import io
|
| 10 |
+
import re
|
| 11 |
+
import json
|
| 12 |
+
import yaml
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from urllib.parse import urlparse
|
| 15 |
+
from datetime import datetime
|
| 16 |
+
from jinja2 import Environment, StrictUndefined
|
| 17 |
+
from camel.models import ModelFactory
|
| 18 |
+
from camel.agents import ChatAgent
|
| 19 |
+
from utils.wei_utils import get_agent_config, account_token
|
| 20 |
+
from utils.src.utils import get_json_from_response, extract_html_code_block
|
| 21 |
+
from ProjectPageAgent.css_checker import check_css
|
| 22 |
+
from utils.src.utils import run_sync_screenshots
|
| 23 |
+
from PIL import Image
|
| 24 |
+
from camel.messages import BaseMessage
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
from camel.models import ModelFactory
|
| 28 |
+
|
| 29 |
+
def to_url(input_path_or_url: str) -> str:
|
| 30 |
+
parsed = urlparse(input_path_or_url)
|
| 31 |
+
if parsed.scheme in ("http", "https", "file"):
|
| 32 |
+
return input_path_or_url
|
| 33 |
+
p = Path(input_path_or_url).expanduser().resolve()
|
| 34 |
+
if not p.exists():
|
| 35 |
+
raise FileNotFoundError(f"Input not found: {p}")
|
| 36 |
+
return p.as_uri() # file://...
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def crop_image_to_max_size(image_path, max_bytes=8*1024*1024, output_path=None):
|
| 40 |
+
img = Image.open(image_path)
|
| 41 |
+
img_format = img.format
|
| 42 |
+
if output_path is None:
|
| 43 |
+
output_path = image_path
|
| 44 |
+
|
| 45 |
+
buffer = io.BytesIO()
|
| 46 |
+
img.save(buffer, format=img_format)
|
| 47 |
+
size = buffer.getbuffer().nbytes
|
| 48 |
+
|
| 49 |
+
if size <= max_bytes:
|
| 50 |
+
img.save(output_path, format=img_format)
|
| 51 |
+
return output_path
|
| 52 |
+
|
| 53 |
+
width, height = img.size
|
| 54 |
+
scale = max_bytes / size
|
| 55 |
+
new_height = max(int(height * scale), 1)
|
| 56 |
+
img_cropped = img.crop((0, 0, width, new_height))
|
| 57 |
+
img_cropped.save(output_path, format=img_format)
|
| 58 |
+
|
| 59 |
+
return output_path
|
| 60 |
+
class ProjectPageHTMLGenerator:
|
| 61 |
+
"""Generates HTML project pages from planned content."""
|
| 62 |
+
|
| 63 |
+
def __init__(self, agent_config,args):
|
| 64 |
+
self.agent_config = agent_config
|
| 65 |
+
self.args = args
|
| 66 |
+
self.html_agent = self._create_html_agent()
|
| 67 |
+
self.review_agent = self._create_review_agent()
|
| 68 |
+
self.table_agent = self._create_table_agent()
|
| 69 |
+
self.long_agent = self._create_long_agent()
|
| 70 |
+
|
| 71 |
+
# self.client = OpenAI(api_key=api_key,base_url=api_url)
|
| 72 |
+
|
| 73 |
+
def _create_html_agent(self):
|
| 74 |
+
"""Create the HTML generation agent."""
|
| 75 |
+
model_type = str(self.agent_config['model_type'])
|
| 76 |
+
|
| 77 |
+
# Get API key from environment variables
|
| 78 |
+
api_key = None
|
| 79 |
+
if self.args.model_name_t in ['4o', '4o-mini', 'gpt-4.1', 'gpt-4.1-mini', 'o1', 'o3', 'o3-mini']:
|
| 80 |
+
api_key = os.environ.get('OPENAI_API_KEY')
|
| 81 |
+
elif self.args.model_name_t in ['gemini', 'gemini-2.5-pro', 'gemini-2.5-flash']:
|
| 82 |
+
api_key = os.environ.get('GEMINI_API_KEY')
|
| 83 |
+
elif self.args.model_name_t in ['qwen', 'qwen-plus', 'qwen-max', 'qwen-long']:
|
| 84 |
+
api_key = os.environ.get('QWEN_API_KEY')
|
| 85 |
+
elif self.args.model_name_t.startswith('openrouter_'):
|
| 86 |
+
api_key = os.environ.get('OPENROUTER_API_KEY')
|
| 87 |
+
elif self.args.model_name_t in ['zhipuai']:
|
| 88 |
+
api_key = os.environ.get('ZHIPUAI_API_KEY')
|
| 89 |
+
|
| 90 |
+
if model_type.startswith('vllm_qwen') or 'vllm' in model_type.lower():
|
| 91 |
+
model = ModelFactory.create(
|
| 92 |
+
model_platform=self.agent_config['model_platform'],
|
| 93 |
+
model_type=self.agent_config['model_type'],
|
| 94 |
+
model_config_dict=self.agent_config['model_config'],
|
| 95 |
+
url=self.agent_config.get('url', None),
|
| 96 |
+
api_key=api_key,
|
| 97 |
+
)
|
| 98 |
+
else:
|
| 99 |
+
model = ModelFactory.create(
|
| 100 |
+
model_platform=self.agent_config['model_platform'],
|
| 101 |
+
model_type=self.agent_config['model_type'],
|
| 102 |
+
model_config_dict=self.agent_config['model_config'],
|
| 103 |
+
api_key=api_key,
|
| 104 |
+
)
|
| 105 |
+
|
| 106 |
+
system_message = """You are an expert web developer specializing in creating professional project pages for research papers.
|
| 107 |
+
You have extensive experience in HTML5, CSS3, responsive design, and academic content presentation.
|
| 108 |
+
Your goal is to create engaging, well-structured, and visually appealing project pages."""
|
| 109 |
+
|
| 110 |
+
return ChatAgent(
|
| 111 |
+
system_message=system_message,
|
| 112 |
+
model=model,
|
| 113 |
+
message_window_size=10
|
| 114 |
+
)
|
| 115 |
+
def _create_review_agent(self):
|
| 116 |
+
with open('utils/prompt_templates/page_templates/html_review.yaml', 'r') as f:
|
| 117 |
+
prompt_config = yaml.safe_load(f)
|
| 118 |
+
|
| 119 |
+
jinja_env = Environment(undefined=StrictUndefined)
|
| 120 |
+
system_message_template = jinja_env.from_string(prompt_config["system_prompt"])
|
| 121 |
+
|
| 122 |
+
system_message = system_message_template.render()
|
| 123 |
+
|
| 124 |
+
model_type = self.args.model_name_v
|
| 125 |
+
|
| 126 |
+
# Get API key from environment variables
|
| 127 |
+
api_key = None
|
| 128 |
+
if self.args.model_name_v in ['4o', '4o-mini', 'gpt-4.1', 'gpt-4.1-mini', 'o1', 'o3', 'o3-mini']:
|
| 129 |
+
api_key = os.environ.get('OPENAI_API_KEY')
|
| 130 |
+
elif self.args.model_name_v in ['gemini', 'gemini-2.5-pro', 'gemini-2.5-flash']:
|
| 131 |
+
api_key = os.environ.get('GEMINI_API_KEY')
|
| 132 |
+
elif self.args.model_name_v in ['qwen', 'qwen-plus', 'qwen-max', 'qwen-long']:
|
| 133 |
+
api_key = os.environ.get('QWEN_API_KEY')
|
| 134 |
+
elif self.args.model_name_v.startswith('openrouter_'):
|
| 135 |
+
api_key = os.environ.get('OPENROUTER_API_KEY')
|
| 136 |
+
elif self.args.model_name_v in ['zhipuai']:
|
| 137 |
+
api_key = os.environ.get('ZHIPUAI_API_KEY')
|
| 138 |
+
|
| 139 |
+
config = get_agent_config(model_type)
|
| 140 |
+
model = ModelFactory.create(
|
| 141 |
+
model_platform=config['model_platform'],
|
| 142 |
+
model_type=config['model_type'],
|
| 143 |
+
model_config_dict=config['model_config'],
|
| 144 |
+
url=config.get('url', None),
|
| 145 |
+
api_key=api_key,
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
return ChatAgent(
|
| 149 |
+
system_message=system_message,
|
| 150 |
+
model=model,
|
| 151 |
+
message_window_size=10
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def _create_table_agent(self):
|
| 156 |
+
|
| 157 |
+
model_type = self.args.model_name_v
|
| 158 |
+
|
| 159 |
+
# Get API key from environment variables
|
| 160 |
+
api_key = None
|
| 161 |
+
if self.args.model_name_v in ['4o', '4o-mini', 'gpt-4.1', 'gpt-4.1-mini', 'o1', 'o3', 'o3-mini']:
|
| 162 |
+
api_key = os.environ.get('OPENAI_API_KEY')
|
| 163 |
+
elif self.args.model_name_v in ['gemini', 'gemini-2.5-pro', 'gemini-2.5-flash']:
|
| 164 |
+
api_key = os.environ.get('GEMINI_API_KEY')
|
| 165 |
+
elif self.args.model_name_v in ['qwen', 'qwen-plus', 'qwen-max', 'qwen-long']:
|
| 166 |
+
api_key = os.environ.get('QWEN_API_KEY')
|
| 167 |
+
elif self.args.model_name_v.startswith('openrouter_'):
|
| 168 |
+
api_key = os.environ.get('OPENROUTER_API_KEY')
|
| 169 |
+
elif self.args.model_name_v in ['zhipuai']:
|
| 170 |
+
api_key = os.environ.get('ZHIPUAI_API_KEY')
|
| 171 |
+
|
| 172 |
+
vlm_config = get_agent_config(model_type)
|
| 173 |
+
vlm_model = ModelFactory.create(
|
| 174 |
+
model_platform=vlm_config['model_platform'],
|
| 175 |
+
model_type=vlm_config['model_type'],
|
| 176 |
+
model_config_dict=vlm_config['model_config'],
|
| 177 |
+
url=vlm_config.get('url', None),
|
| 178 |
+
api_key=api_key,
|
| 179 |
+
)
|
| 180 |
+
return ChatAgent(
|
| 181 |
+
system_message=None,
|
| 182 |
+
model=vlm_model,
|
| 183 |
+
message_window_size=10,
|
| 184 |
+
)
|
| 185 |
+
def _create_long_agent(self):
|
| 186 |
+
model_type = self.args.model_name_t
|
| 187 |
+
|
| 188 |
+
# Get API key from environment variables
|
| 189 |
+
api_key = None
|
| 190 |
+
if self.args.model_name_t in ['4o', '4o-mini', 'gpt-4.1', 'gpt-4.1-mini', 'o1', 'o3', 'o3-mini']:
|
| 191 |
+
api_key = os.environ.get('OPENAI_API_KEY')
|
| 192 |
+
elif self.args.model_name_t in ['gemini', 'gemini-2.5-pro', 'gemini-2.5-flash']:
|
| 193 |
+
api_key = os.environ.get('GEMINI_API_KEY')
|
| 194 |
+
elif self.args.model_name_t in ['qwen', 'qwen-plus', 'qwen-max', 'qwen-long']:
|
| 195 |
+
api_key = os.environ.get('QWEN_API_KEY')
|
| 196 |
+
elif self.args.model_name_t.startswith('openrouter_'):
|
| 197 |
+
api_key = os.environ.get('OPENROUTER_API_KEY')
|
| 198 |
+
elif self.args.model_name_t in ['zhipuai']:
|
| 199 |
+
api_key = os.environ.get('ZHIPUAI_API_KEY')
|
| 200 |
+
|
| 201 |
+
long_config = get_agent_config(model_type)
|
| 202 |
+
long_model = ModelFactory.create(
|
| 203 |
+
model_platform=long_config['model_platform'],
|
| 204 |
+
model_type=long_config['model_type'],
|
| 205 |
+
model_config_dict=long_config['model_config'],
|
| 206 |
+
url=long_config.get('url', None),
|
| 207 |
+
api_key=api_key,
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
return ChatAgent(
|
| 211 |
+
system_message=None,
|
| 212 |
+
model=long_model,
|
| 213 |
+
message_window_size=10,
|
| 214 |
+
token_limit=long_config.get('token_limit', None)
|
| 215 |
+
)
|
| 216 |
+
def render_html_to_png(self, iter, html_content, project_output_dir) -> str:
|
| 217 |
+
|
| 218 |
+
import time
|
| 219 |
+
tmp_html = Path(project_output_dir) / f"index_iter{iter}.html"
|
| 220 |
+
tmp_html.write_text(html_content, encoding="utf-8")
|
| 221 |
+
url = tmp_html.resolve().as_uri()
|
| 222 |
+
|
| 223 |
+
image_path = str(Path(project_output_dir) / f"page_iter{iter}.png")
|
| 224 |
+
|
| 225 |
+
run_sync_screenshots(url, image_path)
|
| 226 |
+
return image_path
|
| 227 |
+
|
| 228 |
+
def get_revision_suggestions(self, image_path: str, html_path) -> str:
|
| 229 |
+
|
| 230 |
+
def crop_image_max_width(img, max_width=1280):
|
| 231 |
+
width, height = img.size
|
| 232 |
+
if width > max_width:
|
| 233 |
+
img = img.crop((0, 0, max_width, height)) # (left, top, right, bottom)
|
| 234 |
+
return img
|
| 235 |
+
img = Image.open(image_path)
|
| 236 |
+
img = crop_image_max_width(img, max_width=1280)
|
| 237 |
+
img.save(image_path,format='PNG')
|
| 238 |
+
crop_image_to_max_size(image_path=image_path,output_path=image_path)
|
| 239 |
+
img =Image.open(image_path)
|
| 240 |
+
|
| 241 |
+
message = BaseMessage.make_user_message(
|
| 242 |
+
role_name="User",
|
| 243 |
+
content = '\nHere is the image of the generated project page.',
|
| 244 |
+
image_list=[img]
|
| 245 |
+
)
|
| 246 |
+
response = self.review_agent.step(message)
|
| 247 |
+
|
| 248 |
+
return get_json_from_response(response.msgs[0].content.strip())
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def modify_html_table(self, html_content: str,html_dir: str):
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
in_tokens, out_tokens = 0, 0
|
| 255 |
+
print("Starting table modification...")
|
| 256 |
+
def replace_tables_in_html(html_content, table_html_map, paper_name):
|
| 257 |
+
|
| 258 |
+
pattern = rf'<img[^>]*src="(assets/{paper_name}-table-\d+\.png)"[^>]*>'
|
| 259 |
+
|
| 260 |
+
def repl(match):
|
| 261 |
+
img_path = match.group(1) # e.g. assets/MambaFusion-table-10.png
|
| 262 |
+
if img_path in table_html_map:
|
| 263 |
+
return table_html_map[img_path]
|
| 264 |
+
return match.group(0)
|
| 265 |
+
|
| 266 |
+
return re.sub(pattern, repl, html_content)
|
| 267 |
+
|
| 268 |
+
# ============ step 1 extract table ============
|
| 269 |
+
|
| 270 |
+
pattern = rf"assets/{self.args.paper_name}-table-\d+\.png"
|
| 271 |
+
with open(os.path.join(self.args.output_dir,self.args.paper_name, html_dir,'index_no_modify_table.html'), 'r', encoding='utf-8') as f:
|
| 272 |
+
html_content = f.read()
|
| 273 |
+
matches = re.findall(pattern, html_content)
|
| 274 |
+
|
| 275 |
+
if matches is None:
|
| 276 |
+
print("No table images found, skipping modification.")
|
| 277 |
+
return None, 0, 0
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
model_type = self.args.model_name_v
|
| 281 |
+
print(f"Starting table modification phase 1: Table Extraction with {model_type}...")
|
| 282 |
+
|
| 283 |
+
with open('utils/prompt_templates/page_templates/extract_table.yaml', 'r') as f:
|
| 284 |
+
table_extraction_config = yaml.safe_load(f)
|
| 285 |
+
content = table_extraction_config["system_prompt"]
|
| 286 |
+
|
| 287 |
+
init_message = BaseMessage.make_user_message(
|
| 288 |
+
role_name="User",
|
| 289 |
+
content=content
|
| 290 |
+
)
|
| 291 |
+
response = self.table_agent.step(init_message)
|
| 292 |
+
in_tok , out_tok = account_token(response)
|
| 293 |
+
in_tokens += in_tok
|
| 294 |
+
out_tokens += out_tok
|
| 295 |
+
# Step 2
|
| 296 |
+
table_html_map = {}
|
| 297 |
+
|
| 298 |
+
matches = list(set(matches))
|
| 299 |
+
for match in matches:
|
| 300 |
+
img_path =os.path.join(self.args.output_dir,self.args.paper_name, html_dir,match)
|
| 301 |
+
print(f"Processing table image: {img_path}")
|
| 302 |
+
img = Image.open(img_path)
|
| 303 |
+
msg = BaseMessage.make_user_message(
|
| 304 |
+
role_name="User",
|
| 305 |
+
content=f'''Here is table image: {match}
|
| 306 |
+
Please output its HTML table (<table>...</table>) with an inline <style>...</style> block.
|
| 307 |
+
Only return pure HTML , nothing else.
|
| 308 |
+
''',
|
| 309 |
+
image_list=[img]
|
| 310 |
+
)
|
| 311 |
+
response = self.table_agent.step(msg)
|
| 312 |
+
in_tok , out_tok = account_token(response)
|
| 313 |
+
in_tokens += in_tok
|
| 314 |
+
out_tokens += out_tok
|
| 315 |
+
print(f'in:{in_tok},out:{out_tok}')
|
| 316 |
+
_output_html = response.msgs[0].content.strip()
|
| 317 |
+
table_html_map[match] = _output_html
|
| 318 |
+
tabel_dir = os.path.join(self.args.output_dir,self.args.paper_name, html_dir)
|
| 319 |
+
os.makedirs(f'{tabel_dir}/table_html', exist_ok=True)
|
| 320 |
+
|
| 321 |
+
with open(f'{tabel_dir}/table_html/{match.replace("/", "_")}.html', 'w', encoding='utf-8') as f:
|
| 322 |
+
f.write(table_html_map[match])
|
| 323 |
+
|
| 324 |
+
# ============ 阶段 2:HTML Merge ============
|
| 325 |
+
|
| 326 |
+
self.table_agent.reset()
|
| 327 |
+
img_path =os.path.join(self.args.output_dir,self.args.paper_name, html_dir,'page_final_no_modify_table.png')
|
| 328 |
+
img = Image.open(img_path)
|
| 329 |
+
with open('utils/prompt_templates/page_templates/color_suggestion.yaml','r') as f:
|
| 330 |
+
prompt_config = yaml.safe_load(f)
|
| 331 |
+
|
| 332 |
+
jinja_env = Environment(undefined=StrictUndefined)
|
| 333 |
+
init_prompt_template = jinja_env.from_string(prompt_config["system_prompt"])
|
| 334 |
+
|
| 335 |
+
init_prompt = init_prompt_template.render()
|
| 336 |
+
|
| 337 |
+
msg = BaseMessage.make_user_message(
|
| 338 |
+
role_name="User",
|
| 339 |
+
content=init_prompt,
|
| 340 |
+
image_list=[img]
|
| 341 |
+
)
|
| 342 |
+
|
| 343 |
+
color_response = self.table_agent.step(msg)
|
| 344 |
+
color_suggestion = color_response.msgs[0].content.strip()
|
| 345 |
+
in_tok , out_tok = account_token(color_response)
|
| 346 |
+
in_tokens += in_tok
|
| 347 |
+
out_tokens += out_tok
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
print(f"Starting table modification phase 2: HTML Merging with {model_type}...")
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
tables_str = "\n\n".join(
|
| 354 |
+
[f"Table extracted for {fname}:\n{html}" for fname, html in table_html_map.items()]
|
| 355 |
+
)
|
| 356 |
+
with open("utils/prompt_templates/page_templates/merge_html_table.yaml",'r') as f:
|
| 357 |
+
prompt_config = yaml.safe_load(f)
|
| 358 |
+
|
| 359 |
+
jinja_env = Environment(undefined=StrictUndefined)
|
| 360 |
+
template = jinja_env.from_string(prompt_config["template"])
|
| 361 |
+
|
| 362 |
+
jinja_args = {
|
| 363 |
+
'html_content': html_content,
|
| 364 |
+
'color_suggestion': color_suggestion,
|
| 365 |
+
'tables_str': tables_str
|
| 366 |
+
}
|
| 367 |
+
|
| 368 |
+
prompt = template.render(**jinja_args)
|
| 369 |
+
|
| 370 |
+
final_message = BaseMessage.make_user_message(
|
| 371 |
+
role_name = "User",
|
| 372 |
+
content = prompt
|
| 373 |
+
)
|
| 374 |
+
|
| 375 |
+
for i in range(3):
|
| 376 |
+
self.long_agent.reset()
|
| 377 |
+
response = self.long_agent.step(final_message)
|
| 378 |
+
in_tok, out_tok = account_token(response)
|
| 379 |
+
in_tokens += in_tok
|
| 380 |
+
out_tokens += out_tok
|
| 381 |
+
output_html = response.msgs[0].content.strip()
|
| 382 |
+
print(f'in:{in_tok},out:{out_tok}')
|
| 383 |
+
exteact_html_code = extract_html_code_block(output_html)
|
| 384 |
+
if exteact_html_code is not None:
|
| 385 |
+
break
|
| 386 |
+
print(f"html format is not correct, regenerate {i} turn")
|
| 387 |
+
|
| 388 |
+
return exteact_html_code, in_tokens, out_tokens
|
| 389 |
+
|
| 390 |
+
|
| 391 |
    def modify_html_from_human_feedback(self, html_content: str, user_feedback: str):
        """
        Modify HTML based on human feedback using the HTML agent.

        Args:
            html_content: Original HTML content
            user_feedback: Feedback from human reviewers

        Returns:
            tuple: (modified HTML content, input tokens, output tokens)
        """
        in_tokens, out_tokens = 0, 0
        print("Starting HTML modification based on human feedback...")
        with open('utils/prompt_templates/page_templates/modify_html_from_human_feedback.yaml', 'r') as f:
            modifier_config = yaml.safe_load(f)

        jinja_env = Environment(undefined=StrictUndefined)
        template = jinja_env.from_string(modifier_config["template"])

        jinja_args = {
            'generated_html': html_content,
            'user_feedback': user_feedback
        }

        prompt = template.render(**jinja_args)
        for i in range(3):
            self.html_agent.reset()
            response = self.html_agent.step(prompt)
            in_tok, out_tok = account_token(response)
            in_tokens += in_tok
            out_tokens += out_tok
            print(f'input_token: {in_tok}, output_token: {out_tok}')
            modified_html = extract_html_code_block(response.msgs[0].content)

            if modified_html is not None:
                break
            print(f"HTML format is not correct, regenerating (attempt {i + 1})")

        return modified_html, in_tokens, out_tokens

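A minimal usage sketch (the generator construction and the feedback string are hypothetical):

# `generator` is a ProjectPageHTMLGenerator built earlier in the pipeline.
html = open('index_modify_table.html', encoding='utf-8').read()
feedback = "Make the hero title larger and move BibTeX to the bottom."
html, in_tok, out_tok = generator.modify_html_from_human_feedback(html, feedback)
if html is not None:
    with open('index.html', 'w', encoding='utf-8') as f:
        f.write(html)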
    def generate_complete_html(self, args, generated_content, html_dir, html_template=None):
        """
        Generate complete HTML by combining all sections, then render to PNG,
        send to OpenAI API for feedback, and regenerate HTML with suggestions.
        """

        # Create output directory for this specific project
        project_output_dir = f"{args.output_dir}/{args.paper_name}"
        html_path = os.path.join(project_output_dir, html_dir)
        if args.resume != 'html_check':
            with open('utils/prompt_templates/page_templates/html_generation.yaml', 'r') as f:
                generator_config = yaml.safe_load(f)

            jinja_env = Environment(undefined=StrictUndefined)
            template = jinja_env.from_string(generator_config["template"])

            jinja_args = {
                'generated_content': json.dumps(generated_content, indent=2),
                'html_template': html_template,
            }

            prompt = template.render(**jinja_args)
            for i in range(3):
                self.html_agent.reset()
                response = self.html_agent.step(prompt)
                input_token, output_token = account_token(response)
                print(f'input_token: {input_token}, output_token: {output_token}')
                html_content = extract_html_code_block(response.msgs[0].content)

                if html_content is not None:
                    break
                print(f"HTML format is not correct, regenerating (attempt {i + 1})")

            # Check CSS paths
            html_content = check_css(html_content, html_template)

            with open(os.path.join(html_path, 'index_init.html'), 'w') as f:
                f.write(html_content)

            print(f"Initial HTML generation completed. Tokens: {input_token} -> {output_token}")

        else:
            with open(os.path.join(html_path, 'index_init.html'), 'r', encoding='utf-8') as f:
                html_content = f.read()

        revised_html = html_content

        for i in range(self.args.html_check_times):
            if i == 0:
                print("Starting HTML check and revision...")

            image_path = self.render_html_to_png(i, revised_html, html_path)

            suggestions = self.get_revision_suggestions(image_path, os.path.join(html_path, f'index_iter{i}.html'))

            review_path = f'project_contents/{args.paper_name}_html_review_iter{i}.json'
            with open(review_path, 'w') as f:
                json.dump(suggestions, f, indent=4)

            self.html_agent.reset()
            with open('utils/prompt_templates/page_templates/html_modify_from_suggestion.yaml', 'r') as f:
                regenerator_config = yaml.safe_load(f)

            jinja_env = Environment(undefined=StrictUndefined)
            _template = jinja_env.from_string(regenerator_config["template"])

            _jinja_args = {
                'existing_html': revised_html,
                'suggestions': suggestions
            }

            revision_prompt = _template.render(**_jinja_args)

            revised_response = self.html_agent.step(revision_prompt)
            revised_html = extract_html_code_block(revised_response.msgs[0].content)

            print("Revised HTML generation completed.")
            input_token, output_token = account_token(revised_response)
            print(f'in:{input_token}, out:{output_token}')

        return revised_html, input_token, output_token

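render_html_to_png and get_revision_suggestions are not part of this hunk. A minimal sketch of the rendering step, under the assumption that it simply screenshots the saved iteration file with headless Chromium via Playwright, could look like:

import os
from playwright.sync_api import sync_playwright

def render_html_to_png_sketch(html_path, png_path):
    # Hypothetical stand-in for self.render_html_to_png: load the local
    # file in headless Chromium and capture a full-page screenshot.
    with sync_playwright() as p:
        browser = p.chromium.launch()
        page = browser.new_page()
        page.goto(f"file://{os.path.abspath(html_path)}")
        page.screenshot(path=png_path, full_page=True)
        browser.close()
    return png_path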
    def save_html_file(self, html_content, args, html_dir, output_dir="generated_project_pages"):
        """
        Save the generated HTML to a file.

        Args:
            html_content: Generated HTML content
            args: Command line arguments
            output_dir: Output directory for the HTML file

        Returns:
            str: Path to the saved HTML file
        """
        os.makedirs(output_dir, exist_ok=True)

        # Create output directory for this specific project
        project_output_dir = f"{output_dir}/{args.paper_name}"
        os.makedirs(project_output_dir, exist_ok=True)

        # Save HTML file
        html_file_path = f"{project_output_dir}/{html_dir}/index.html"
        with open(html_file_path, 'w', encoding='utf-8') as f:
            f.write(html_content)

        print(f"HTML project page saved to: {html_file_path}")

        return html_file_path

    def create_assets_directory(self, args, html_dir, output_dir="generated_project_pages"):
        """
        Create assets directory and copy images/tables.

        Args:
            args: Command line arguments
            output_dir: Output directory

        Returns:
            str: Path to the assets directory
        """
        project_output_dir = f"{output_dir}/{args.paper_name}"
        assets_dir = os.path.join(project_output_dir, html_dir, "assets")
        os.makedirs(assets_dir, exist_ok=True)

        # Copy images and tables from the extracted assets
        source_assets_dir = f"generated_project_pages/images_and_tables/{args.paper_name}"
        if os.path.exists(source_assets_dir):
            import shutil
            for file in os.listdir(source_assets_dir):
                if file.endswith(('.png', '.jpg', '.jpeg', '.gif')):
                    src_path = os.path.join(source_assets_dir, file)
                    dst_path = os.path.join(assets_dir, file)
                    shutil.copy2(src_path, dst_path)

        print(f"Assets directory created at: {assets_dir}")
        return assets_dir

    def generate_metadata(self, generated_content, args):
        """
        Generate metadata for the project page.

        Args:
            generated_content: Generated content
            args: Command line arguments

        Returns:
            dict: Metadata for the project page
        """
        metadata = {
            'title': generated_content.get('meta', {}).get('poster_title', 'Research Project'),
            'description': generated_content.get('meta', {}).get('abstract', '')[:160],
            'authors': generated_content.get('meta', {}).get('authors', ''),
            'affiliations': generated_content.get('meta', {}).get('affiliations', ''),
            'keywords': [],
            'generated_by': f"Paper2ProjectPage ({args.model_name_t}_{args.model_name_v})",
            'generation_date': str(datetime.now())
        }

        # Extract keywords from content
        content_text = json.dumps(generated_content, ensure_ascii=False)
        # Simple keyword extraction (can be improved)
        words = content_text.lower().split()
        word_freq = {}
        for word in words:
            if len(word) > 4 and word.isalpha():
                word_freq[word] = word_freq.get(word, 0) + 1

        # Get the top 10 most frequent words as keywords
        sorted_words = sorted(word_freq.items(), key=lambda x: x[1], reverse=True)
        metadata['keywords'] = [word for word, freq in sorted_words[:10]]

        return metadata

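The keyword heuristic above is a plain frequency count over whitespace-split tokens; an equivalent version using collections.Counter, shown here purely as an illustration of the same logic:

from collections import Counter

def top_keywords(text, k=10):
    # Same heuristic as generate_metadata: alphabetic words longer than
    # 4 characters, ranked by frequency, top k returned.
    words = [w for w in text.lower().split() if len(w) > 4 and w.isalpha()]
    return [word for word, _ in Counter(words).most_common(k)]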
    def save_metadata(self, metadata, args, output_dir="generated_project_pages"):
        """
        Save metadata to a JSON file.

        Args:
            metadata: Generated metadata
            args: Command line arguments
            output_dir: Output directory

        Returns:
            str: Path to the saved metadata file
        """
        project_output_dir = f"{output_dir}/{args.paper_name}"
        metadata_file_path = f"{project_output_dir}/metadata.json"

        with open(metadata_file_path, 'w', encoding='utf-8') as f:
            json.dump(metadata, f, indent=4, ensure_ascii=False)

        print(f"Metadata saved to: {metadata_file_path}")
        return metadata_file_path
ProjectPageAgent/main_pipline.py
ADDED
@@ -0,0 +1,379 @@
"""
Main pipeline for Paper2ProjectPage.
Integrates all modules to generate project pages from research papers.
"""

import argparse
import json
import os
import time
from dotenv import load_dotenv
from pathlib import Path
import shutil
from ProjectPageAgent.parse_paper import parse_paper_for_project_page, save_parsed_content
from ProjectPageAgent.html_finder import HtmlFinder
from ProjectPageAgent.content_planner import ProjectPageContentPlanner
from ProjectPageAgent.html_generator import ProjectPageHTMLGenerator, to_url
from utils.wei_utils import get_agent_config
from ProjectPageAgent.content_planner import filter_references
from utils.src.utils import run_sync_screenshots

load_dotenv()

def matching(requirement):
    weight = {
        "background_color": 1.0,
        "has_hero_section": 0.75,
        "Page density": 0.85,
        "image_layout": 0.65,
        "title_color": 0.6,
        "has_navigation": 0.7
    }
    with open('tags.json', 'r') as f:
        template_tags = json.load(f)

    points = {}
    for name, tag in template_tags.items():
        for feature, value in tag.items():
            if requirement[feature] == value:
                if name not in points:
                    points[name] = weight[feature]
                else:
                    points[name] += weight[feature]
    sorted_points = sorted(points.items(), key=lambda x: x[1], reverse=True)
    return [template[0] for template in sorted_points[0:3]]

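matching() scores every template listed in tags.json against the requested features and returns the three highest-scoring names. A sketch of a call, with a hypothetical tags.json entry shown in the comment:

# tags.json might contain entries such as (hypothetical):
# {"nerfies": {"background_color": "light", "has_hero_section": "yes",
#              "Page density": "spacious", "image_layout": "parallelism",
#              "title_color": "pure", "has_navigation": "yes"}}
requirement = {
    "background_color": "light", "has_hero_section": "yes",
    "Page density": "spacious", "image_layout": "parallelism",
    "title_color": "pure", "has_navigation": "yes",
}
print(matching(requirement))  # e.g. ['nerfies', ...]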
def copy_static_files(template_file_path, template_root_dir, output_dir, paper_name):

    print(f"Detecting static files: {template_file_path}")
    os.makedirs(output_dir, exist_ok=True)

    # Create output directory for this specific project
    project_output_dir = f"{output_dir}/{paper_name}"
    os.makedirs(project_output_dir, exist_ok=True)

    static_dir = os.path.join(project_output_dir, 'static')
    os.makedirs(static_dir, exist_ok=True)

    html_relative_path = os.path.relpath(template_file_path, template_root_dir)

    if os.path.exists(template_root_dir) and os.path.isdir(template_root_dir):
        print(f"Found template dir: {template_root_dir}")
        try:
            shutil.copytree(template_root_dir, project_output_dir, dirs_exist_ok=True)
            os.remove(os.path.join(project_output_dir, html_relative_path))
            print(f"Copied template to: {project_output_dir}")
        except Exception as e:
            print(f"Failed to copy static files: {e}")

    try:
        with open(template_file_path, 'r', encoding='utf-8') as f:
            html_content = f.read()
    except Exception as e:
        print(f"Failed to read template file: {e}")
        return

    return static_dir

def main():
    """Main pipeline for generating project pages from research papers."""
    parser = argparse.ArgumentParser(description='Paper2ProjectPage Generation Pipeline')
    parser.add_argument('--paper_path', type=str, required=True, help='Path to the research paper PDF')
    parser.add_argument('--model_name_t', type=str, default='4o', help='Text model name')
    parser.add_argument('--model_name_v', type=str, default='4o', help='Vision model name')
    parser.add_argument('--template_root', type=str, default="project_templates", help='Directory containing all templates')
    parser.add_argument('--template_dir', type=str, help='Directory of the chosen template')
    parser.add_argument('--template_file', type=str, help='Path to a specific template file to use')
    parser.add_argument('--output_dir', type=str, default='generated_project_pages', help='Output directory for generated pages')
    parser.add_argument('--style_preference', type=str, default=None, help='Path to style preference JSON file')
    parser.add_argument('--tmp_dir', type=str, default='tmp', help='Temporary directory')
    parser.add_argument('--full_content_check_times', type=int, default=0, help='Number of full-content check passes')
    parser.add_argument('--background_color', type=str, choices=['light', 'dark'], required=True,
                        help='Background color of the generated project page')
    parser.add_argument('--has_navigation', type=str, choices=['yes', 'no'], required=True,
                        help='Whether the generated project page has navigation')
    parser.add_argument('--has_hero_section', type=str, choices=['yes', 'no'], required=True,
                        help='Whether the generated project page has a hero section')
    parser.add_argument('--title_color', type=str, choices=['pure', 'colorful'], required=True,
                        help="Whether the title color of the project page is pure or colorful")
    parser.add_argument('--page_density', type=str, choices=['spacious', 'compact'], required=True,
                        help="The overall spacing tightness: amount of white space vs. information density")
    parser.add_argument('--image_layout', type=str, choices=['rotation', 'parallelism'], required=True,
                        help="The dominant arrangement style for images.")
    parser.add_argument('--html_check_times', type=int, default=1, help='Number of HTML check/revision iterations')
    parser.add_argument(
        '--resume',
        type=str,
        choices=['parse_pdf', 'generate_content', 'full_content_check', 'generate_html', 'html_check', 'modify_table', 'html_feedback'],
        default='parse_pdf',
        help="From which step to resume: 'parse_pdf', 'generate_content', 'full_content_check', 'generate_html', 'html_check', 'modify_table', 'html_feedback'",
    )
    parser.add_argument('--human_input', type=str, default='1', choices=['0', '1'], help='Whether to ask for human feedback (1) or not (0)')

    args = parser.parse_args()

    if not args.template_dir:
        template_requirement = {
            "background_color": args.background_color,
            "has_hero_section": args.has_hero_section,
            "Page density": args.page_density,
            "image_layout": args.image_layout,
            "has_navigation": args.has_navigation,
            "title_color": args.title_color
        }
        matched_template = matching(template_requirement)
        print('Below are the names of the 3 most closely matching templates:')
        print(' '.join(matched_template))
        template_name = input('Please choose one of them; just input the name of your favorite template: ')
        while template_name not in matched_template:
            template_name = input('Please input the correct name of your favorite template: ')
        args.template_dir = os.path.join(args.template_root, template_name)

    # Extract html path from root path
    if not args.template_file:
        html_finder_ = HtmlFinder()
        args.template_file = html_finder_.find_html(args.template_dir)

    # Extract paper name from path
    paper_name = args.paper_path.split('/')[-1].replace('.pdf', '') if '/' in args.paper_path else args.paper_path.replace('.pdf', '')
    args.paper_name = paper_name

    print(f"Starting Paper2ProjectPage generation for: {paper_name}")
    print(f"Paper path: {args.paper_path}")
    print(f"Models: {args.model_name_t} (text), {args.model_name_v} (vision)")

    start_time = time.time()
    total_input_tokens_t = 0
    total_output_tokens_t = 0
    total_input_tokens_v = 0
    total_output_tokens_v = 0

    # Create temporary directory
    os.makedirs(args.tmp_dir, exist_ok=True)

    try:
        # Get agent configurations
        agent_config_t = get_agent_config(args.model_name_t)
        agent_config_v = get_agent_config(args.model_name_v)

        # Step 1: Parse the research paper
        print("\n" + "="*50)
        print("STEP 1: Parsing Research Paper")
        print("="*50)

        raw_content_path = f'project_contents/{args.paper_name}_raw_content.json'
        if not os.path.exists(raw_content_path):
            print(f"Raw content does not exist at {raw_content_path}")

            input_token, output_token, raw_result, images, tables = parse_paper_for_project_page(args, agent_config_t)
            total_input_tokens_t += input_token
            total_output_tokens_t += output_token

            # Save parsed content
            raw_content_path, token_log_path = save_parsed_content(args, raw_result, images, tables, input_token, output_token)

            # Load parsed content
            with open(raw_content_path, 'r') as f:
                paper_content = json.load(f)
        else:
            print(f"Loading existing raw content from {raw_content_path}")
            with open(raw_content_path, 'r') as f:
                paper_content = json.load(f)
            token_log_path = raw_content_path.replace('_raw_content.json', '_parse_log.json')

        # Load images and tables from the saved content
        images = paper_content.get('images', [])
        tables = paper_content.get('tables', [])
        figures = {
            'images': images,
            'tables': tables
        }
        paper_content = paper_content.get('markdown_content', "")

        print("\n" + "="*50)
        print("STEP 2: Generate project page content")
        print("="*50)

        planner = ProjectPageContentPlanner(agent_config_t, args)
        figures_path = f'project_contents/{args.paper_name}_generated_filtered_figures.json'
        generated_section_path = f'project_contents/{args.paper_name}_generated_section.json'
        text_page_content_path = f'project_contents/{args.paper_name}_generated_text_content.json'
        generated_content_path = f'project_contents/{args.paper_name}_generated_full_content.json'
        if args.resume in ['parse_pdf', 'generate_content', 'full_content_check']:

            if args.resume != 'full_content_check':

                paper_content, figures, input_token, output_token = planner.filter_raw_content(paper_content, figures)
                total_input_tokens_t += input_token
                total_output_tokens_t += output_token

                generated_section, input_token, output_token = planner.section_generation(paper_content, figures)
                total_input_tokens_t += input_token
                total_output_tokens_t += output_token

                text_page_content, input_token, output_token = planner.text_content_generation(paper_content, figures, generated_section)
                total_input_tokens_t += input_token
                total_output_tokens_t += output_token

            else:
                print("Skipping content generation: filter_raw_content, section_generation, text_content_generation")
                print("Loading existing content from previous steps.")
                paper_content = filter_references(paper_content)
                with open(figures_path, 'r') as f:
                    figures = json.load(f)
                with open(generated_section_path, 'r') as f:
                    generated_section = json.load(f)
                with open(text_page_content_path, 'r') as f:
                    text_page_content = json.load(f)

            generated_content, input_token, output_token = planner.full_content_generation(args, paper_content, figures, generated_section, text_page_content)
            total_input_tokens_t += input_token
            total_output_tokens_t += output_token

            print("\n" + "="*50)
            print("STEP 2.5: Copying Static Files")
            print("="*50)
            static_dir = copy_static_files(args.template_file, args.template_dir, args.output_dir, args.paper_name)

        else:
            print("Page content is already generated, loading existing content.")

            paper_content = filter_references(paper_content)
            with open(generated_section_path, 'r') as f:
                generated_section = json.load(f)
            with open(text_page_content_path, 'r') as f:
                text_page_content = json.load(f)
            with open(generated_content_path, 'r') as f:
                generated_content = json.load(f)

            static_dir = copy_static_files(args.template_file, args.template_dir, args.output_dir, args.paper_name)

        # Step 3: Generate HTML project page
        print("\n" + "="*50)
        print("STEP 3: Generating HTML Project Page")
        print("="*50)
        html_relative_path = os.path.relpath(args.template_file, args.template_dir)
        html_dir = '/'.join(html_relative_path.strip().split('/')[:-1])
        html_generator = ProjectPageHTMLGenerator(agent_config_t, args)
        with open(args.template_file, 'r', encoding='utf-8') as file:
            html_template = file.read()
        # Generate HTML
        if args.resume != 'modify_table' and args.resume != 'html_feedback':

            # Create assets directory and copy images
            assets_dir = html_generator.create_assets_directory(args, html_dir, args.output_dir)
            # Generate complete HTML
            html_content, input_token, output_token = html_generator.generate_complete_html(
                args, generated_content, html_dir, html_template
            )
            total_input_tokens_t += input_token
            total_output_tokens_t += output_token

            # Save HTML file
            html_file_path = os.path.join(args.output_dir, args.paper_name, html_dir, 'index_no_modify_table.html')
            with open(html_file_path, 'w') as file:
                file.write(html_content)
            run_sync_screenshots(to_url(html_file_path), os.path.join(args.output_dir, args.paper_name, html_dir, 'page_final_no_modify_table.png'))

        else:
            print(f"Skipping generate_html and html_check, loading HTML from {os.path.join(args.output_dir, args.paper_name, html_dir, 'index_no_modify_table.html')}")
            assets_dir = os.path.join(args.output_dir, args.paper_name, html_dir, 'assets')
            with open(os.path.join(args.output_dir, args.paper_name, html_dir, 'index_no_modify_table.html'), 'r') as file:
                html_content = file.read()

        if args.resume != 'html_feedback':
            html_content, input_token, output_token = html_generator.modify_html_table(html_content, html_dir)
            total_input_tokens_t += input_token
            total_output_tokens_t += output_token
            html_file_path = os.path.join(args.output_dir, args.paper_name, html_dir, 'index_modify_table.html')
            with open(html_file_path, 'w') as file:
                file.write(html_content)
        else:
            print("Skipping modify_table, going to html_feedback")
            html_file_path = os.path.join(args.output_dir, args.paper_name, html_dir, 'index_modify_table.html')
            with open(html_file_path, 'r') as file:
                html_content = file.read()

        print('-'*50)
        run_sync_screenshots(to_url(html_file_path), os.path.join(args.output_dir, args.paper_name, html_dir, 'page_final.png'))
        if args.human_input == '1':
            human_feedback = input('Please view the final HTML in index.html and the image in page_final.png. If there are no problems, enter yes and press Enter.\nIf there are any problems, please give me feedback directly.\n')
            while human_feedback.lower() != 'yes':

                html_content, input_token, output_token = html_generator.modify_html_from_human_feedback(html_content, human_feedback)
                total_input_tokens_t += input_token
                total_output_tokens_t += output_token
                with open(os.path.join(args.output_dir, args.paper_name, html_dir, 'index.html'), 'w') as file:
                    file.write(html_content)
                run_sync_screenshots(to_url(os.path.join(args.output_dir, args.paper_name, html_dir, 'index.html')), os.path.join(args.output_dir, args.paper_name, html_dir, 'page_final.png'))
                print('-'*50)
                human_feedback = input('Please view the final HTML in index.html and the image in page_final.png. If there are no problems, enter yes and press Enter.\nIf there are any problems, please give me feedback directly.\n')

        html_file_path = html_generator.save_html_file(html_content, args, html_dir, args.output_dir)

        # Generate and save metadata
        metadata = html_generator.generate_metadata(generated_content, args)
        metadata_path = html_generator.save_metadata(metadata, args, args.output_dir)

        # Step 4: Finalize and save logs
        print("\n" + "="*50)
        print("STEP 4: Finalizing Generation")
        print("="*50)

        end_time = time.time()
        time_taken = end_time - start_time

        # Save generation log
        log_data = {
            'paper_name': paper_name,
            'paper_path': args.paper_path,
            'models': {
                'text_model': args.model_name_t,
                'vision_model': args.model_name_v
            },
            'token_usage': {
                'text_input_tokens': total_input_tokens_t,
                'text_output_tokens': total_output_tokens_t,
                'vision_input_tokens': total_input_tokens_v,
                'vision_output_tokens': total_output_tokens_v
            },
            'generation_time': time_taken,
            'output_files': {
                'html_file': html_file_path,
                'assets_dir': assets_dir,
                'static_dir': static_dir,
                'metadata_file': metadata_path
            },
            'content_files': {
                'raw_content': raw_content_path,
                'token_log': token_log_path
            }
        }

        log_path = f"{args.output_dir}/{args.paper_name}/generation_log.json"
        with open(log_path, 'w') as f:
            json.dump(log_data, f, indent=4)

        print("\n✅ Paper2ProjectPage generation completed successfully!")
        print(f"📁 Output directory: {args.output_dir}/{args.paper_name}")
        print(f"🌐 HTML file: {html_file_path}")
        print(f"📊 Assets directory: {assets_dir}")
        print(f"🎨 Static directory: {static_dir}")
        print(f"📋 Metadata file: {metadata_path}")
        print(f"⏱️ Total time: {time_taken:.2f} seconds")
        print(f"🔢 Token usage - Text: {total_input_tokens_t}→{total_output_tokens_t}, Vision: {total_input_tokens_v}→{total_output_tokens_v}")

    except Exception as e:
        print(f"\n❌ Error during generation: {str(e)}")
        raise

if __name__ == '__main__':
    main()
ProjectPageAgent/parse_paper.py
ADDED
@@ -0,0 +1,88 @@
"""
Paper parsing module for ProjectPageAgent.
Reuses the parsing capabilities from Paper2Poster.
"""

from ProjectPageAgent.parse_raw import parse_raw, gen_image_and_table
from utils.wei_utils import get_agent_config
import json
import os
import argparse

def parse_paper_for_project_page(args, agent_config_t, version=2):
    """
    Parse a research paper PDF and extract content for project page generation.

    Args:
        args: Command line arguments
        agent_config_t: Text model configuration
        version: Parser version to use

    Returns:
        tuple: (input_tokens, output_tokens, raw_result, images, tables)
    """
    print("Step 1: Parsing the research paper...")

    # Add poster_path and poster_name attributes to args for compatibility with parse_raw
    if not hasattr(args, 'poster_path'):
        args.poster_path = args.paper_path

    if not hasattr(args, 'poster_name'):
        args.poster_name = args.paper_name

    # Parse the raw paper content
    input_token, output_token, raw_result = parse_raw(args, agent_config_t, version=version)

    # Extract images and tables
    _, _, images, tables = gen_image_and_table(args, raw_result)

    print(f"Parsing completed. Tokens: {input_token} -> {output_token}")
    print(f"Extracted {len(images)} images and {len(tables)} tables")

    return input_token, output_token, raw_result, images, tables

def save_parsed_content(args, raw_result, images, tables, input_token, output_token):
    """
    Save parsed content to files for later use.

    Args:
        args: Command line arguments
        raw_result: Parsed raw content
        images: Extracted images
        tables: Extracted tables
        input_token: Input token count
        output_token: Output token count

    Returns:
        tuple: (raw_content_path, token_log_path)
    """
    # Save raw content
    os.makedirs('project_contents', exist_ok=True)
    raw_content_path = f'project_contents/{args.paper_name}_raw_content.json'

    # Convert raw_result to JSON format if needed
    if hasattr(raw_result, 'document'):
        # Extract text content from the docling result
        raw_markdown = raw_result.document.export_to_markdown()
        content_json = {
            'markdown_content': raw_markdown,
            'images': images,
            'tables': tables
        }
    else:
        content_json = raw_result

    with open(raw_content_path, 'w') as f:
        json.dump(content_json, f, indent=4)

    # Save token usage
    token_log = {
        'parse_input_tokens': input_token,
        'parse_output_tokens': output_token,
        'total_images': len(images),
        'total_tables': len(tables)
    }

    token_log_path = f'project_contents/{args.paper_name}_parse_log.json'
    with open(token_log_path, 'w') as f:
        json.dump(token_log, f, indent=4)

    print(f"Parsed content saved to {raw_content_path}")
    return raw_content_path, token_log_path
ProjectPageAgent/parse_raw.py
ADDED
@@ -0,0 +1,256 @@
from dotenv import load_dotenv
from utils.src.utils import get_json_from_response
from utils.src.model_utils import parse_pdf
import json
import random
import os

from camel.models import ModelFactory
from camel.agents import ChatAgent
from tenacity import retry, stop_after_attempt
from docling_core.types.doc import ImageRefMode, PictureItem, TableItem

from docling.datamodel.base_models import InputFormat
from docling.datamodel.pipeline_options import PdfPipelineOptions
from docling.document_converter import DocumentConverter, PdfFormatOption

from pathlib import Path

import PIL

from marker.models import create_model_dict

from utils.wei_utils import *

from utils.pptx_utils import *
from utils.critic_utils import *
import torch
from jinja2 import Template
import re
import argparse

load_dotenv()
IMAGE_RESOLUTION_SCALE = 5.0

pipeline_options = PdfPipelineOptions()
pipeline_options.images_scale = IMAGE_RESOLUTION_SCALE
pipeline_options.generate_page_images = True
pipeline_options.generate_picture_images = True

doc_converter = DocumentConverter(
    format_options={
        InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options)
    }
)

@retry(stop=stop_after_attempt(5))
def parse_raw(args, actor_config, version=1):
    raw_source = args.poster_path
    markdown_clean_pattern = re.compile(r"<!--[\s\S]*?-->")

    raw_result = doc_converter.convert(raw_source)

    raw_markdown = raw_result.document.export_to_markdown()
    text_content = markdown_clean_pattern.sub("", raw_markdown)

    if len(text_content) < 500:
        print('\nParsing with docling failed, using marker instead\n')
        parser_model = create_model_dict(device='cuda', dtype=torch.float16)
        text_content, rendered = parse_pdf(raw_source, model_lst=parser_model, save_file=False)

    if version == 1:
        template = Template(open("utils/prompts/gen_page_raw_content.txt").read())
    elif version == 2:
        template = Template(open("utils/prompts/gen_page_raw_content_v2.txt").read())

    # Get the API key from environment variables
    api_key = None
    if args.model_name_t in ['4o', '4o-mini', 'gpt-4.1', 'gpt-4.1-mini', 'o1', 'o3', 'o3-mini']:
        api_key = os.environ.get('OPENAI_API_KEY')
    elif args.model_name_t in ['gemini', 'gemini-2.5-pro', 'gemini-2.5-flash']:
        api_key = os.environ.get('GEMINI_API_KEY')
    elif args.model_name_t in ['qwen', 'qwen-plus', 'qwen-max', 'qwen-long']:
        api_key = os.environ.get('QWEN_API_KEY')
    elif args.model_name_t.startswith('openrouter_'):
        api_key = os.environ.get('OPENROUTER_API_KEY')
    elif args.model_name_t in ['zhipuai']:
        api_key = os.environ.get('ZHIPUAI_API_KEY')

    if args.model_name_t.startswith('vllm_qwen'):
        actor_model = ModelFactory.create(
            model_platform=actor_config['model_platform'],
            model_type=actor_config['model_type'],
            model_config_dict=actor_config['model_config'],
            url=actor_config['url'],
            api_key=api_key,
        )
    else:
        actor_model = ModelFactory.create(
            model_platform=actor_config['model_platform'],
            model_type=actor_config['model_type'],
            model_config_dict=actor_config['model_config'],
            api_key=api_key,
        )

    actor_sys_msg = 'You are the author of the paper, and you will create a poster for the paper.'

    actor_agent = ChatAgent(
        system_message=actor_sys_msg,
        model=actor_model,
        message_window_size=10,
        token_limit=actor_config.get('token_limit', None)
    )

    while True:
        prompt = template.render(
            markdown_document=text_content,
        )
        actor_agent.reset()
        response = actor_agent.step(prompt)
        input_token, output_token = account_token(response)

        content_json = get_json_from_response(response.msgs[0].content)

        if len(content_json) > 0:
            break
        print('Error: Empty response, retrying...')
        if args.model_name_t.startswith('vllm_qwen'):
            text_content = text_content[:80000]

    if len(content_json['sections']) > 9:
        # Keep the first 2 sections + 5 randomly selected middle sections + the last 2 sections
        selected_sections = content_json['sections'][:2] + random.sample(content_json['sections'][2:-2], 5) + content_json['sections'][-2:]
        content_json['sections'] = selected_sections

    has_title = False

    for section in content_json['sections']:
        if not isinstance(section, dict) or 'title' not in section or 'content' not in section:
            print("Ouch! The response is invalid, the LLM is not following the format :(")
            print('Trying again...')
            raise ValueError('Invalid section format in LLM response')
        if 'title' in section['title'].lower():
            has_title = True

    if not has_title:
        print('Ouch! The response is invalid, the LLM is not following the format :(')
        raise ValueError('No title section found in LLM response')

    os.makedirs('contents', exist_ok=True)
    json.dump(content_json, open(f'contents/{args.poster_name}_raw_content.json', 'w'), indent=4)
    return input_token, output_token, raw_result

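The ValueError raised on a malformed response is what drives the @retry(stop=stop_after_attempt(5)) decorator: tenacity re-invokes parse_raw on each exception until it succeeds or five attempts are exhausted. The mechanism in isolation:

from tenacity import retry, stop_after_attempt

calls = {'n': 0}

@retry(stop=stop_after_attempt(5))
def flaky():
    # Every raise counts as one failed attempt; tenacity re-invokes the
    # function until it returns or the stop condition is reached.
    calls['n'] += 1
    if calls['n'] < 3:
        raise ValueError('malformed response')
    return 'ok'

print(flaky(), calls['n'])  # ok 3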
def gen_image_and_table(args, conv_res):
    input_token, output_token = 0, 0
    raw_source = args.poster_path

    output_dir = Path(f'generated_project_pages/images_and_tables/{args.poster_name}')

    output_dir.mkdir(parents=True, exist_ok=True)
    doc_filename = args.poster_name

    # Save page images
    for page_no, page in conv_res.document.pages.items():
        page_no = page.page_no
        page_image_filename = output_dir / f"{doc_filename}-{page_no}.png"
        with page_image_filename.open("wb") as fp:
            page.image.pil_image.save(fp, format="PNG")

    # Save images of figures and tables
    table_counter = 0
    picture_counter = 0
    for element, _level in conv_res.document.iterate_items():
        if isinstance(element, TableItem):
            table_counter += 1
            element_image_filename = (
                output_dir / f"{doc_filename}-table-{table_counter}.png"
            )
            with element_image_filename.open("wb") as fp:
                element.get_image(conv_res.document).save(fp, "PNG")

        if isinstance(element, PictureItem):
            picture_counter += 1
            element_image_filename = (
                output_dir / f"{doc_filename}-picture-{picture_counter}.png"
            )
            with element_image_filename.open("wb") as fp:
                element.get_image(conv_res.document).save(fp, "PNG")

    # Save markdown with embedded pictures
    md_filename = output_dir / f"{doc_filename}-with-images.md"
    conv_res.document.save_as_markdown(md_filename, image_mode=ImageRefMode.EMBEDDED)

    # Save markdown with externally referenced pictures
    md_filename = output_dir / f"{doc_filename}-with-image-refs.md"
    conv_res.document.save_as_markdown(md_filename, image_mode=ImageRefMode.REFERENCED)

    # Save HTML with externally referenced pictures
    html_filename = output_dir / f"{doc_filename}-with-image-refs.html"
    conv_res.document.save_as_html(html_filename, image_mode=ImageRefMode.REFERENCED)

    tables = {}

    table_index = 1
    for table in conv_res.document.tables:
        caption = table.caption_text(conv_res.document)
        if len(caption) > 0:
            table_img_path = f'generated_project_pages/images_and_tables/{args.poster_name}/{args.poster_name}-table-{table_index}.png'
            assets_table_path = f'assets/{args.poster_name}-table-{table_index}.png'
            table_img = PIL.Image.open(table_img_path)
            tables[str(table_index)] = {
                'caption': caption,
                'table_path': assets_table_path,
                'width': table_img.width,
                'height': table_img.height,
                'figure_size': table_img.width * table_img.height,
                'figure_aspect': table_img.width / table_img.height,
            }

        table_index += 1

    images = {}
    image_index = 1
    for image in conv_res.document.pictures:
        caption = image.caption_text(conv_res.document)
        if len(caption) > 0:
            image_img_path = f'generated_project_pages/images_and_tables/{args.poster_name}/{args.poster_name}-picture-{image_index}.png'
            assets_image_path = f'assets/{args.poster_name}-picture-{image_index}.png'
            image_img = PIL.Image.open(image_img_path)
            images[str(image_index)] = {
                'caption': caption,
                'image_path': assets_image_path,
                'width': image_img.width,
                'height': image_img.height,
                'figure_size': image_img.width * image_img.height,
                'figure_aspect': image_img.width / image_img.height,
            }
        image_index += 1

    json.dump(images, open(f'generated_project_pages/images_and_tables/{args.poster_name}_images.json', 'w'), indent=4)
    json.dump(tables, open(f'generated_project_pages/images_and_tables/{args.poster_name}_tables.json', 'w'), indent=4)

    return input_token, output_token, images, tables

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--poster_name', type=str, default=None)
    parser.add_argument('--model_name', type=str, default='4o')
    parser.add_argument('--poster_path', type=str, required=True)
    parser.add_argument('--index', type=int, default=0)
    args = parser.parse_args()

    agent_config = get_agent_config(args.model_name)

    if args.poster_name is None:
        args.poster_name = args.poster_path.split('/')[-1].replace('.pdf', '').replace(' ', '_')

    # parse_raw reads args.model_name_t; mirror the CLI flag onto it
    args.model_name_t = args.model_name

    # Parse raw content (parse_raw also returns the docling conversion result)
    input_token, output_token, raw_result = parse_raw(args, agent_config)

    # Generate images and tables from the conversion result
    _, _, images, tables = gen_image_and_table(args, raw_result)

    print(f'Token consumption: {input_token} -> {output_token}')

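Each entry written to <paper>_images.json and <paper>_tables.json follows the shape built above; one hypothetical images.json entry for reference:

{
    "1": {
        "caption": "Figure 1: System overview.",
        "image_path": "assets/example-picture-1.png",
        "width": 1860,
        "height": 1040,
        "figure_size": 1934400,
        "figure_aspect": 1.7885
    }
}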
ProjectPageAgent/template_analyzer.py
ADDED
@@ -0,0 +1,436 @@
"""
Template analyzer for project page generation.
Analyzes existing project page templates to understand structure and style.
"""

import os
import json
import re
from bs4 import BeautifulSoup
from pathlib import Path
import yaml
from jinja2 import Environment, StrictUndefined

class ProjectPageTemplateAnalyzer:
    """Analyzes project page templates to extract structure and styling patterns."""

    def __init__(self, template_dir="project_templates"):
        self.template_dir = Path(template_dir)
        self.template_dir.mkdir(exist_ok=True)
        self.templates = {}
        self.common_patterns = {}

    def analyze_html_template(self, html_file_path):
        """
        Analyze an HTML template file to extract structure and styling.

        Args:
            html_file_path: Path to the HTML template file

        Returns:
            dict: Analysis results including structure, styling, and patterns
        """
        try:
            with open(html_file_path, 'r', encoding='utf-8') as f:
                html_content = f.read()

            soup = BeautifulSoup(html_content, 'html.parser')

            analysis = {
                'file_path': html_file_path,
                'structure': self._extract_structure(soup),
                'styling': self._extract_styling(soup),
                'sections': self._extract_sections(soup),
                'components': self._extract_components(soup),
                'meta_info': self._extract_meta_info(soup)
            }

            return analysis

        except Exception as e:
            print(f"Error analyzing template {html_file_path}: {e}")
            return None

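A minimal usage sketch of the analyzer (the template path is hypothetical):

analyzer = ProjectPageTemplateAnalyzer(template_dir="project_templates")
report = analyzer.analyze_html_template("project_templates/nerfies/index.html")
if report is not None:
    print(report['meta_info']['title'])
    print([s['tag'] for s in report['structure']['body_sections']])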
    def _extract_structure(self, soup):
        """Extract the overall structure of the HTML document."""
        structure = {
            'doctype': soup.find('!DOCTYPE') is not None,
            'html_lang': soup.html.get('lang', 'en') if soup.html else 'en',
            'head_sections': [],
            'body_sections': [],
            'main_content': None,
            'navigation': None,
            'footer': None
        }

        # Extract head sections
        if soup.head:
            for tag in soup.head.find_all(['meta', 'link', 'script', 'title']):
                structure['head_sections'].append({
                    'tag': tag.name,
                    'attrs': dict(tag.attrs)
                })

        # Extract body structure
        if soup.body:
            for section in soup.body.find_all(['header', 'nav', 'main', 'section', 'article', 'aside', 'footer']):
                structure['body_sections'].append({
                    'tag': section.name,
                    'id': section.get('id', ''),
                    'class': section.get('class', []),
                    'content_type': self._identify_content_type(section)
                })

        return structure

    def _extract_styling(self, soup):
        """Extract CSS styling information."""
        styling = {
            'inline_styles': [],
            'external_css': [],
            'color_scheme': [],
            'typography': {},
            'layout': {}
        }

        # Extract inline styles
        for tag in soup.find_all(style=True):
            styling['inline_styles'].append({
                'tag': tag.name,
                'style': tag.get('style', '')
            })

        # Extract external CSS links
        for link in soup.find_all('link', rel='stylesheet'):
            styling['external_css'].append(link.get('href', ''))

        # Extract color information
        color_pattern = re.compile(r'#[0-9a-fA-F]{3,6}|rgb\([^)]+\)|rgba\([^)]+\)')
        for tag in soup.find_all(style=True):
            colors = color_pattern.findall(tag.get('style', ''))
            styling['color_scheme'].extend(colors)

        # Extract typography patterns
        for tag in soup.find_all(['h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'p']):
            font_size = re.search(r'font-size:\s*([^;]+)', tag.get('style', ''))
            if font_size:
                styling['typography'][tag.name] = font_size.group(1)

        return styling

    def _extract_sections(self, soup):
        """Extract content sections and their organization."""
        sections = []

        for section in soup.find_all(['section', 'article', 'div'], class_=True):
            section_info = {
                'tag': section.name,
                'id': section.get('id', ''),
                'classes': section.get('class', []),
                'content': self._extract_section_content(section),
                'images': self._extract_images(section),
                'tables': self._extract_tables(section)
            }
            sections.append(section_info)

        return sections

    def _extract_components(self, soup):
        """Extract reusable components and their patterns."""
        components = {
            'navigation': self._extract_navigation(soup),
            'hero_section': self._extract_hero_section(soup),
            'content_blocks': self._extract_content_blocks(soup),
            'image_galleries': self._extract_image_galleries(soup),
            'contact_forms': self._extract_contact_forms(soup)
        }

        return components

    def _extract_meta_info(self, soup):
        """Extract meta information and SEO elements."""
        meta_info = {
            'title': soup.title.string if soup.title else '',
            'meta_tags': [],
            'open_graph': {},
            'twitter_cards': {}
        }

        for meta in soup.find_all('meta'):
            meta_info['meta_tags'].append({
                'name': meta.get('name', ''),
                'content': meta.get('content', ''),
                'property': meta.get('property', '')
            })

            # Extract Open Graph tags
            if meta.get('property', '').startswith('og:'):
                meta_info['open_graph'][meta.get('property')] = meta.get('content', '')

            # Extract Twitter Card tags
            if meta.get('name', '').startswith('twitter:'):
                meta_info['twitter_cards'][meta.get('name')] = meta.get('content', '')

        return meta_info

    def _identify_content_type(self, element):
        """Identify the type of content in an element."""
        text = element.get_text().lower()

        if any(word in text for word in ['abstract', 'summary', 'overview']):
            return 'abstract'
        elif any(word in text for word in ['introduction', 'background']):
            return 'introduction'
        elif any(word in text for word in ['method', 'approach', 'methodology']):
            return 'methodology'
        elif any(word in text for word in ['result', 'experiment', 'evaluation']):
            return 'results'
        elif any(word in text for word in ['conclusion', 'discussion', 'future']):
            return 'conclusion'
        elif any(word in text for word in ['contact', 'author', 'team']):
            return 'contact'
        else:
            return 'general'

    def _extract_section_content(self, element):
        """Extract text content from a section."""
        content = {
            'headings': [],
            'paragraphs': [],
            'lists': [],
            'code_blocks': []
        }

        for heading in element.find_all(['h1', 'h2', 'h3', 'h4', 'h5', 'h6']):
            content['headings'].append({
                'level': int(heading.name[1]),
                'text': heading.get_text().strip()
            })

        for p in element.find_all('p'):
            content['paragraphs'].append(p.get_text().strip())

        for ul in element.find_all(['ul', 'ol']):
|
| 214 |
+
items = [li.get_text().strip() for li in ul.find_all('li')]
|
| 215 |
+
content['lists'].append({
|
| 216 |
+
'type': ul.name,
|
| 217 |
+
'items': items
|
| 218 |
+
})
|
| 219 |
+
|
| 220 |
+
for code in element.find_all(['code', 'pre']):
|
| 221 |
+
content['code_blocks'].append({
|
| 222 |
+
'type': code.name,
|
| 223 |
+
'content': code.get_text().strip()
|
| 224 |
+
})
|
| 225 |
+
|
| 226 |
+
return content
|
| 227 |
+
|
| 228 |
+
def _extract_images(self, element):
|
| 229 |
+
"""Extract image information from an element."""
|
| 230 |
+
images = []
|
| 231 |
+
for img in element.find_all('img'):
|
| 232 |
+
images.append({
|
| 233 |
+
'src': img.get('src', ''),
|
| 234 |
+
'alt': img.get('alt', ''),
|
| 235 |
+
'title': img.get('title', ''),
|
| 236 |
+
'class': img.get('class', [])
|
| 237 |
+
})
|
| 238 |
+
return images
|
| 239 |
+
|
| 240 |
+
def _extract_tables(self, element):
|
| 241 |
+
"""Extract table information from an element."""
|
| 242 |
+
tables = []
|
| 243 |
+
for table in element.find_all('table'):
|
| 244 |
+
table_info = {
|
| 245 |
+
'class': table.get('class', []),
|
| 246 |
+
'headers': [],
|
| 247 |
+
'rows': []
|
| 248 |
+
}
|
| 249 |
+
|
| 250 |
+
# Extract headers
|
| 251 |
+
for th in table.find_all('th'):
|
| 252 |
+
table_info['headers'].append(th.get_text().strip())
|
| 253 |
+
|
| 254 |
+
# Extract rows
|
| 255 |
+
for tr in table.find_all('tr'):
|
| 256 |
+
row = [td.get_text().strip() for td in tr.find_all('td')]
|
| 257 |
+
if row:
|
| 258 |
+
table_info['rows'].append(row)
|
| 259 |
+
|
| 260 |
+
tables.append(table_info)
|
| 261 |
+
|
| 262 |
+
return tables
|
| 263 |
+
|
| 264 |
+
def _extract_navigation(self, soup):
|
| 265 |
+
"""Extract navigation structure."""
|
| 266 |
+
nav = soup.find('nav')
|
| 267 |
+
if nav:
|
| 268 |
+
return {
|
| 269 |
+
'links': [a.get('href', '') for a in nav.find_all('a')],
|
| 270 |
+
'texts': [a.get_text().strip() for a in nav.find_all('a')],
|
| 271 |
+
'structure': self._extract_nav_structure(nav)
|
| 272 |
+
}
|
| 273 |
+
return None
|
| 274 |
+
|
| 275 |
+
def _extract_nav_structure(self, nav_element):
|
| 276 |
+
"""Extract the hierarchical structure of navigation."""
|
| 277 |
+
structure = []
|
| 278 |
+
for item in nav_element.find_all(['a', 'li'], recursive=False):
|
| 279 |
+
if item.name == 'a':
|
| 280 |
+
structure.append({
|
| 281 |
+
'type': 'link',
|
| 282 |
+
'text': item.get_text().strip(),
|
| 283 |
+
'href': item.get('href', '')
|
| 284 |
+
})
|
| 285 |
+
elif item.name == 'li':
|
| 286 |
+
sub_items = []
|
| 287 |
+
for sub_item in item.find_all('a'):
|
| 288 |
+
sub_items.append({
|
| 289 |
+
'text': sub_item.get_text().strip(),
|
| 290 |
+
'href': sub_item.get('href', '')
|
| 291 |
+
})
|
| 292 |
+
structure.append({
|
| 293 |
+
'type': 'group',
|
| 294 |
+
'items': sub_items
|
| 295 |
+
})
|
| 296 |
+
return structure
|
| 297 |
+
|
| 298 |
+
def _extract_hero_section(self, soup):
|
| 299 |
+
"""Extract hero section information."""
|
| 300 |
+
hero = soup.find(['header', 'section'], class_=re.compile(r'hero|banner|intro'))
|
| 301 |
+
if hero:
|
| 302 |
+
return {
|
| 303 |
+
'title': hero.find(['h1', 'h2']).get_text().strip() if hero.find(['h1', 'h2']) else '',
|
| 304 |
+
'subtitle': hero.find(['h2', 'h3', 'p']).get_text().strip() if hero.find(['h2', 'h3', 'p']) else '',
|
| 305 |
+
'background_image': hero.find('img').get('src', '') if hero.find('img') else '',
|
| 306 |
+
'cta_buttons': [a.get_text().strip() for a in hero.find_all('a', class_=re.compile(r'btn|button'))]
|
| 307 |
+
}
|
| 308 |
+
return None
|
| 309 |
+
|
| 310 |
+
def _extract_content_blocks(self, soup):
|
| 311 |
+
"""Extract content block patterns."""
|
| 312 |
+
blocks = []
|
| 313 |
+
for block in soup.find_all(['div', 'section'], class_=re.compile(r'content|block|section')):
|
| 314 |
+
blocks.append({
|
| 315 |
+
'classes': block.get('class', []),
|
| 316 |
+
'content_type': self._identify_content_type(block),
|
| 317 |
+
'has_images': bool(block.find('img')),
|
| 318 |
+
'has_tables': bool(block.find('table')),
|
| 319 |
+
'has_code': bool(block.find(['code', 'pre']))
|
| 320 |
+
})
|
| 321 |
+
return blocks
|
| 322 |
+
|
| 323 |
+
def _extract_image_galleries(self, soup):
|
| 324 |
+
"""Extract image gallery patterns."""
|
| 325 |
+
galleries = []
|
| 326 |
+
for gallery in soup.find_all(['div', 'section'], class_=re.compile(r'gallery|carousel|slider')):
|
| 327 |
+
images = gallery.find_all('img')
|
| 328 |
+
galleries.append({
|
| 329 |
+
'image_count': len(images),
|
| 330 |
+
'layout': 'grid' if 'grid' in str(gallery.get('class', [])) else 'carousel',
|
| 331 |
+
'images': [img.get('src', '') for img in images]
|
| 332 |
+
})
|
| 333 |
+
return galleries
|
| 334 |
+
|
| 335 |
+
def _extract_contact_forms(self, soup):
|
| 336 |
+
"""Extract contact form patterns."""
|
| 337 |
+
forms = []
|
| 338 |
+
for form in soup.find_all('form'):
|
| 339 |
+
form_info = {
|
| 340 |
+
'action': form.get('action', ''),
|
| 341 |
+
'method': form.get('method', 'get'),
|
| 342 |
+
'fields': []
|
| 343 |
+
}
|
| 344 |
+
|
| 345 |
+
for input_field in form.find_all(['input', 'textarea', 'select']):
|
| 346 |
+
form_info['fields'].append({
|
| 347 |
+
'type': input_field.get('type', input_field.name),
|
| 348 |
+
'name': input_field.get('name', ''),
|
| 349 |
+
'placeholder': input_field.get('placeholder', ''),
|
| 350 |
+
'required': input_field.get('required') is not None
|
| 351 |
+
})
|
| 352 |
+
|
| 353 |
+
forms.append(form_info)
|
| 354 |
+
|
| 355 |
+
return forms
|
| 356 |
+
|
| 357 |
+
def analyze_multiple_templates(self, template_files):
|
| 358 |
+
"""
|
| 359 |
+
Analyze multiple template files and find common patterns.
|
| 360 |
+
|
| 361 |
+
Args:
|
| 362 |
+
template_files: List of template file paths
|
| 363 |
+
|
| 364 |
+
Returns:
|
| 365 |
+
dict: Analysis results with common patterns
|
| 366 |
+
"""
|
| 367 |
+
all_analyses = []
|
| 368 |
+
|
| 369 |
+
for template_file in template_files:
|
| 370 |
+
analysis = self.analyze_html_template(template_file)
|
| 371 |
+
if analysis:
|
| 372 |
+
all_analyses.append(analysis)
|
| 373 |
+
|
| 374 |
+
# Find common patterns
|
| 375 |
+
common_patterns = self._find_common_patterns(all_analyses)
|
| 376 |
+
|
| 377 |
+
return {
|
| 378 |
+
'individual_analyses': all_analyses,
|
| 379 |
+
'common_patterns': common_patterns
|
| 380 |
+
}
|
| 381 |
+
|
| 382 |
+
def _find_common_patterns(self, analyses):
|
| 383 |
+
"""Find common patterns across multiple template analyses."""
|
| 384 |
+
patterns = {
|
| 385 |
+
'common_sections': [],
|
| 386 |
+
'common_styles': [],
|
| 387 |
+
'common_components': [],
|
| 388 |
+
'color_schemes': [],
|
| 389 |
+
'layout_patterns': []
|
| 390 |
+
}
|
| 391 |
+
|
| 392 |
+
# Analyze common sections
|
| 393 |
+
all_sections = []
|
| 394 |
+
for analysis in analyses:
|
| 395 |
+
all_sections.extend(analysis['sections'])
|
| 396 |
+
|
| 397 |
+
section_types = {}
|
| 398 |
+
for section in all_sections:
|
| 399 |
+
content_type = section.get('content_type', 'unknown')
|
| 400 |
+
if content_type not in section_types:
|
| 401 |
+
section_types[content_type] = 0
|
| 402 |
+
section_types[content_type] += 1
|
| 403 |
+
|
| 404 |
+
patterns['common_sections'] = [
|
| 405 |
+
section_type for section_type, count in section_types.items()
|
| 406 |
+
if count > len(analyses) * 0.5 # Appears in more than 50% of templates
|
| 407 |
+
]
|
| 408 |
+
|
| 409 |
+
# Analyze common styles
|
| 410 |
+
all_colors = []
|
| 411 |
+
for analysis in analyses:
|
| 412 |
+
all_colors.extend(analysis['styling']['color_scheme'])
|
| 413 |
+
|
| 414 |
+
color_counts = {}
|
| 415 |
+
for color in all_colors:
|
| 416 |
+
if color not in color_counts:
|
| 417 |
+
color_counts[color] = 0
|
| 418 |
+
color_counts[color] += 1
|
| 419 |
+
|
| 420 |
+
patterns['color_schemes'] = [
|
| 421 |
+
color for color, count in color_counts.items()
|
| 422 |
+
if count > len(analyses) * 0.3 # Appears in more than 30% of templates
|
| 423 |
+
]
|
| 424 |
+
|
| 425 |
+
return patterns
|
| 426 |
+
|
| 427 |
+
def save_analysis(self, analysis, output_path):
|
| 428 |
+
"""Save analysis results to a JSON file."""
|
| 429 |
+
try:
|
| 430 |
+
with open(output_path, 'w') as f:
|
| 431 |
+
json.dump(analysis, f, indent=2)
|
| 432 |
+
print(f"Analysis saved to {output_path}")
|
| 433 |
+
return True
|
| 434 |
+
except Exception as e:
|
| 435 |
+
print(f"Error saving analysis: {e}")
|
| 436 |
+
return False
|
app.py
ADDED
@@ -0,0 +1,1671 @@
import gradio as gr
import os
import json
from pathlib import Path
import base64
import re
from threading import Thread
from http.server import HTTPServer, SimpleHTTPRequestHandler
import socket
import subprocess
from dotenv import load_dotenv
from ProjectPageAgent.parse_paper import parse_paper_for_project_page, save_parsed_content
from ProjectPageAgent.html_finder import HtmlFinder
from ProjectPageAgent.content_planner import ProjectPageContentPlanner
from ProjectPageAgent.html_generator import ProjectPageHTMLGenerator, to_url
from utils.wei_utils import get_agent_config

from ProjectPageAgent.content_planner import filter_references
from utils.src.utils import run_sync_screenshots
from ProjectPageAgent.main_pipline import matching, copy_static_files

load_dotenv()

subprocess.run(["playwright", "install", "chromium"], check=True)


def get_agent_config_with_keys(model_type, openai_api_key="", gemini_api_key="",
                               qwen_api_key="", zhipuai_api_key="", openrouter_api_key=""):
    """
    Get agent configuration with user-provided API keys.
    Falls back to environment variables if user keys are not provided.
    Note: This function sets environment variables but does NOT restore them.
    The environment variables will remain set for the duration of the application.
    """
    # Collect the user-provided keys
    api_keys = {
        'OPENAI_API_KEY': openai_api_key,
        'GEMINI_API_KEY': gemini_api_key,
        'QWEN_API_KEY': qwen_api_key,
        'ZHIPUAI_API_KEY': zhipuai_api_key,
        'OPENROUTER_API_KEY': openrouter_api_key
    }

    # Export only non-empty keys to the environment
    for key, value in api_keys.items():
        if value and value.strip():
            os.environ[key] = value

    # Get agent config with the new API keys
    config = get_agent_config(model_type)
    return config

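# A minimal illustration of the fallback behavior above (the key value is a
# placeholder): a non-empty key is exported before get_agent_config runs.
#
#     config_t = get_agent_config_with_keys('4o', openai_api_key='sk-proj-...')
#     assert os.environ['OPENAI_API_KEY'] == 'sk-proj-...'
#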
def validate_api_keys(model_name_t, model_name_v, openai_api_key, gemini_api_key,
                      qwen_api_key, zhipuai_api_key, openrouter_api_key):
    """
    Validate that required API keys are provided for the selected models.
    """
    errors = []

    # Check text model requirements
    if model_name_t in ['4o', '4o-mini', 'gpt-4.1', 'gpt-4.1-mini', 'o1', 'o3', 'o3-mini']:
        if not openai_api_key or not openai_api_key.strip():
            errors.append("OpenAI API key is required for GPT models")
    elif model_name_t in ['gemini', 'gemini-2.5-pro', 'gemini-2.5-flash']:
        if not gemini_api_key or not gemini_api_key.strip():
            errors.append("Gemini API key is required for Gemini models")
    elif model_name_t in ['qwen', 'qwen-plus', 'qwen-max', 'qwen-long']:
        if not qwen_api_key or not qwen_api_key.strip():
            errors.append("Qwen API key is required for Qwen models")
    elif model_name_t.startswith('openrouter_'):
        if not openrouter_api_key or not openrouter_api_key.strip():
            errors.append("OpenRouter API key is required for OpenRouter models")

    # Check vision model requirements
    if model_name_v in ['4o', '4o-mini']:
        if not openai_api_key or not openai_api_key.strip():
            errors.append("OpenAI API key is required for GPT vision models")
    elif model_name_v in ['gemini', 'gemini-2.5-pro', 'gemini-2.5-flash']:
        if not gemini_api_key or not gemini_api_key.strip():
            errors.append("Gemini API key is required for Gemini vision models")
    elif model_name_v in ['qwen-vl-max', 'qwen-2.5-vl-72b']:
        if not qwen_api_key or not qwen_api_key.strip():
            errors.append("Qwen API key is required for Qwen vision models")
    elif model_name_v.startswith('openrouter_'):
        if not openrouter_api_key or not openrouter_api_key.strip():
            errors.append("OpenRouter API key is required for OpenRouter vision models")

    return errors

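# Illustrative check (model names come from the lists above): a GPT text model
# paired with a Gemini vision model needs both provider keys, so an empty
# Gemini key yields exactly one error.
#
#     errors = validate_api_keys('4o', 'gemini-2.5-flash',
#                                openai_api_key='sk-...', gemini_api_key='',
#                                qwen_api_key='', zhipuai_api_key='',
#                                openrouter_api_key='')
#     # -> ['Gemini API key is required for Gemini vision models']
#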
# Global Variables
current_html_dir = None
preview_server = None
preview_port = None
template_preview_servers = []

class CustomHTTPRequestHandler(SimpleHTTPRequestHandler):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, directory=current_html_dir, **kwargs)

    def log_message(self, format, *args):
        pass

def find_free_port(start_port=8000, max_attempts=100):
    for port in range(start_port, start_port + max_attempts):
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
                s.bind(('', port))
                return port
        except OSError:
            continue
    raise RuntimeError(
        f"Could not find an available port in range {start_port}-{start_port + max_attempts - 1}"
    )

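# Illustrative probe: the helper bind-tests candidate ports and returns the
# first one the OS accepts, so a busy port is simply skipped.
#
#     port = find_free_port(start_port=8000)  # e.g. 8000, or 8001 if 8000 is taken
#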
def start_preview_server(html_dir):
    global current_html_dir, preview_server, preview_port
    stop_preview_server()
    current_html_dir = html_dir
    preview_port = find_free_port()
    preview_server = HTTPServer(('0.0.0.0', preview_port), CustomHTTPRequestHandler)
    server_thread = Thread(target=preview_server.serve_forever, daemon=True)
    server_thread.start()
    return preview_port

def stop_preview_server():
    global preview_server, preview_port
    if preview_server:
        preview_server.shutdown()
        preview_server = None
        preview_port = None

def start_ephemeral_server_for_dir(html_dir):
    port = find_free_port()

    class _TempHandler(SimpleHTTPRequestHandler):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, directory=html_dir, **kwargs)

        def log_message(self, format, *args):
            pass

    srv = HTTPServer(('0.0.0.0', port), _TempHandler)
    t = Thread(target=srv.serve_forever, daemon=True)
    t.start()
    template_preview_servers.append((srv, port))
    return port

def stop_all_template_preview_servers():
    global template_preview_servers
    for srv, _ in template_preview_servers:
        try:
            srv.shutdown()
        except Exception:
            pass
    template_preview_servers = []

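# Illustrative lifecycle (the directory is hypothetical): each recommended
# template gets its own throwaway server, and all of them are torn down at once.
#
#     port = start_ephemeral_server_for_dir('templates/demo_a')
#     print(f'http://localhost:{port}/index.html')
#     stop_all_template_preview_servers()
#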
class GenerationArgs:
    def __init__(self, paper_path, model_name_t, model_name_v, template_root,
                 template_dir, template_file, output_dir, style_preference, tmp_dir,
                 full_content_check_times, background_color, has_navigation,
                 has_hero_section, title_color, page_density, image_layout,
                 html_check_times, resume, human_input):
        self.paper_path = paper_path
        self.model_name_t = model_name_t
        self.model_name_v = model_name_v
        self.template_root = template_root
        self.template_dir = template_dir
        self.template_file = template_file
        self.output_dir = output_dir
        self.style_preference = style_preference
        self.tmp_dir = tmp_dir
        self.full_content_check_times = full_content_check_times
        self.background_color = background_color
        self.has_navigation = has_navigation
        self.has_hero_section = has_hero_section
        self.title_color = title_color
        self.page_density = page_density
        self.image_layout = image_layout
        self.html_check_times = html_check_times
        self.resume = resume
        self.human_input = human_input
        self.paper_name = None

# ==================== Formatting Functions ====================

def format_section_to_markdown(section_data):
    """
    Convert section JSON into nicely formatted Markdown.

    Args:
        section_data: Section JSON data

    Returns:
        str: Formatted Markdown string
    """
    if not section_data:
        return "No data available"

    md_lines = []

    # Title
    md_lines.append("# 📄 Paper Page Structure Preview\n")

    # Basic information
    if "title" in section_data:
        md_lines.append(f"## 📌 Title\n**{section_data['title']}**\n")

    if "authors" in section_data:
        md_lines.append(f"## 👥 Authors\n{section_data['authors']}\n")

    if "affiliation" in section_data:
        md_lines.append(f"## 🏛️ Affiliation\n{section_data['affiliation']}\n")

    # Other sections
    md_lines.append("## 📑 Page Sections\n")

    section_count = 0
    for key, value in section_data.items():
        if key in ["title", "authors", "affiliation"]:
            continue

        section_count += 1

        # Section title
        section_title = key.replace("_", " ").title()
        md_lines.append(f"### {section_count}. {section_title}\n")

        # Section content
        if isinstance(value, dict):
            # Dictionary: render one level of sub-entries
            for sub_key, sub_value in value.items():
                sub_title = sub_key.replace("_", " ").title()
                md_lines.append(f"**{sub_title}**: {sub_value}\n")
        elif isinstance(value, list):
            # List: render each item as a bullet
            for item in value:
                if isinstance(item, str):
                    md_lines.append(f"- {item}\n")
                elif isinstance(item, dict):
                    for k, v in item.items():
                        md_lines.append(f"- **{k}**: {v}\n")
        else:
            # Simple value
            md_lines.append(f"{value}\n")

        md_lines.append("")  # Empty line

    # Add statistics
    md_lines.append("---\n")
    md_lines.append(f"**📊 Total {section_count} sections**\n")

    return "\n".join(md_lines)

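# Illustrative input/output (the data is hypothetical): keys other than
# title/authors/affiliation become numbered "### N. ..." subsections.
#
#     md = format_section_to_markdown({
#         'title': 'Demo Paper',
#         'authors': 'A. Author',
#         'abstract': 'One-paragraph summary.',
#     })
#     # md contains '## 📌 Title', '## 👥 Authors' and '### 1. Abstract'
#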
def format_full_content_to_markdown(content_data, figures=None):
    """
    Convert full-content JSON into nicely formatted Markdown.

    Args:
        content_data: Full Content JSON data
        figures: Images and tables data (optional)

    Returns:
        str: Formatted Markdown string
    """
    if not content_data:
        return "No data available"

    md_lines = []

    # Title
    md_lines.append("# 📄 Full Content Preview\n")

    # Basic information
    if "title" in content_data:
        md_lines.append(f"# {content_data['title']}\n")

    if "authors" in content_data:
        md_lines.append(f"**Authors**: {content_data['authors']}\n")

    if "affiliation" in content_data:
        md_lines.append(f"**Affiliation**: {content_data['affiliation']}\n")

    md_lines.append("---\n")

    # Process each section
    section_count = 0
    image_count = 0
    table_count = 0

    for key, value in content_data.items():
        if key in ["title", "authors", "affiliation"]:
            continue

        section_count += 1

        # Section title
        section_title = key.replace("_", " ").title()
        md_lines.append(f"## {section_count}. {section_title}\n")

        # Process content
        if isinstance(value, dict):
            # Dictionary-shaped sections
            for sub_key, sub_value in value.items():
                if sub_key.lower() in ['content', 'description', 'text']:
                    # Main text content
                    md_lines.append(f"{sub_value}\n")
                elif sub_key.lower() in ['image', 'figure', 'img']:
                    # Image
                    image_count += 1
                    if isinstance(sub_value, dict):
                        caption = sub_value.get('caption', f'Figure {image_count}')
                        path = sub_value.get('path', '')
                        md_lines.append(f"\n**🖼️ {caption}**\n")
                        if path:
                            md_lines.append(f"*Image path: `{path}`*\n")
                    else:
                        md_lines.append(f"\n**🖼️ Figure {image_count}**: {sub_value}\n")
                elif sub_key.lower() in ['table']:
                    # Table
                    table_count += 1
                    md_lines.append(f"\n**📊 Table {table_count}**\n")
                    if isinstance(sub_value, dict):
                        caption = sub_value.get('caption', f'Table {table_count}')
                        md_lines.append(f"*{caption}*\n")
                    else:
                        md_lines.append(f"{sub_value}\n")
                elif sub_key.lower() in ['code']:
                    # Code block
                    md_lines.append(f"\n```\n{sub_value}\n```\n")
                else:
                    # Other subsections
                    sub_title = sub_key.replace("_", " ").title()
                    md_lines.append(f"\n### {sub_title}\n")
                    md_lines.append(f"{sub_value}\n")

        elif isinstance(value, list):
            # List-shaped sections
            for idx, item in enumerate(value):
                if isinstance(item, dict):
                    # Dictionary items in the list
                    if 'title' in item or 'name' in item:
                        item_title = item.get('title', item.get('name', f'Item {idx+1}'))
                        md_lines.append(f"\n### {item_title}\n")

                    for k, v in item.items():
                        if k not in ['title', 'name']:
                            if k.lower() in ['content', 'description', 'text']:
                                md_lines.append(f"{v}\n")
                            elif k.lower() in ['image', 'figure']:
                                image_count += 1
                                md_lines.append(f"\n**🖼️ Figure {image_count}**: {v}\n")
                            elif k.lower() == 'table':
                                table_count += 1
                                md_lines.append(f"\n**📊 Table {table_count}**: {v}\n")
                            else:
                                k_title = k.replace("_", " ").title()
                                md_lines.append(f"**{k_title}**: {v}\n")
                else:
                    # Simple list item
                    md_lines.append(f"- {item}\n")

        else:
            # Simple text value
            md_lines.append(f"{value}\n")

        md_lines.append("")  # Empty line between sections

    # Add statistics
    md_lines.append("\n---\n")
    stats = []
    stats.append("📊 **Statistics**")
    stats.append(f"- Sections: {section_count}")
    if image_count > 0:
        stats.append(f"- Images: {image_count}")
    if table_count > 0:
        stats.append(f"- Tables: {table_count}")

    # If figures data is provided, add more information
    if figures:
        if 'images' in figures and figures['images']:
            stats.append(f"- Available images: {len(figures['images'])}")
        if 'tables' in figures and figures['tables']:
            stats.append(f"- Available tables: {len(figures['tables'])}")

    md_lines.append("\n".join(stats))
    md_lines.append("\n")

    return "\n".join(md_lines)

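# Illustrative call (the figures dict mirrors the parser output): the optional
# `figures` argument only extends the statistics block at the end.
#
#     md = format_full_content_to_markdown(
#         {'title': 'Demo Paper', 'method': {'content': 'We do X.'}},
#         figures={'images': [{'path': 'img/fig1.png'}], 'tables': []},
#     )
#     # the statistics block now includes '- Available images: 1'
#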
# ==================== Global State Management ====================

class GenerationState:
    def __init__(self):
        self.reset()

    def reset(self):
        self.args = None
        self.paper_content = None
        self.figures = None
        self.generated_section = None
        self.text_page_content = None
        self.generated_content = None
        self.html_content = None
        self.html_file_path = None
        self.html_dir = None
        self.planner = None
        self.html_generator = None
        self.agent_config_t = None
        self.total_input_tokens_t = 0
        self.total_output_tokens_t = 0
        self.current_stage = "init"
        self.preview_url = None

state = GenerationState()

def create_project_zip(project_dir, output_dir, paper_name):
    """
    Create the project archive.

    Args:
        project_dir: Project directory path
        output_dir: Output directory
        paper_name: Paper name

    Returns:
        str: Archive path, or None on failure
    """
    import zipfile

    zip_filename = f"{paper_name}_project_page.zip"
    zip_path = os.path.join(output_dir, zip_filename)

    print(f"Creating project archive: {zip_path}")

    try:
        with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
            # Walk the project directory and add every file
            for root, dirs, files in os.walk(project_dir):
                for file in files:
                    file_path = os.path.join(root, file)
                    # Store paths relative to the output directory
                    arcname = os.path.relpath(file_path, output_dir)
                    zipf.write(file_path, arcname)

        print(f"Archive created successfully: {zip_path}")

        # Report the archive size
        zip_size = os.path.getsize(zip_path)
        zip_size_mb = zip_size / (1024 * 1024)
        print(f"Archive size: {zip_size_mb:.2f} MB")

        return zip_path

    except Exception as e:
        print(f"Archive creation failed: {e}")
        return None

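# Illustrative call (paths are hypothetical): entries are archived relative to
# `output_dir`, so the zip unpacks back into `<paper>_project_page/...`.
#
#     zip_path = create_project_zip('output/demo_project_page', 'output', 'demo')
#     # -> 'output/demo_project_page.zip' on success, None on failure
#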
def start_generation(pdf_file, model_name_t, model_name_v, template_root,
                     template_dir, template_file, output_dir, style_preference,
                     tmp_dir, full_content_check_times, background_color,
                     has_navigation, has_hero_section, title_color, page_density,
                     image_layout, html_check_times, resume, human_input,
                     template_choice_value, openai_api_key, gemini_api_key,
                     qwen_api_key, zhipuai_api_key, openrouter_api_key):
    """Start the generation process."""
    if pdf_file is None:
        return "❌ Please upload a PDF file", gr.update(visible=False), "", "", gr.update(), gr.update(), ""

    # Validate API keys
    validation_errors = validate_api_keys(
        model_name_t, model_name_v, openai_api_key, gemini_api_key,
        qwen_api_key, zhipuai_api_key, openrouter_api_key
    )

    if validation_errors:
        error_msg = "❌ API Key Validation Failed:\n" + "\n".join(f"• {error}" for error in validation_errors)
        return error_msg, gr.update(visible=False), "", "", gr.update(), gr.update(), ""

    state.reset()

    # Handle template selection
    if not (template_dir and str(template_dir).strip()):
        if not template_choice_value:
            stop_all_template_preview_servers()
            template_requirement = {
                "background_color": background_color,
                "has_hero_section": has_hero_section,
                "Page density": page_density,
                "image_layout": image_layout,
                "has_navigation": has_navigation,
                "title_color": title_color
            }
            try:
                matched = matching(template_requirement)
            except Exception as e:
                return f"❌ Template recommendation failed: {e}", gr.update(visible=False), "", "", gr.update(choices=[], value=None), gr.update(visible=False, value=""), ""

            html_finder_ = HtmlFinder()
            with open('templates/template_link.json', 'r') as f:
                template_link = json.load(f)
            previews = []
            for name in matched:
                t_dir = os.path.join(template_root, name)
                try:
                    html_path = html_finder_.find_html(t_dir)
                    if not os.path.exists(html_path):
                        continue
                    html_dir = os.path.dirname(os.path.abspath(html_path))
                    filename = os.path.basename(html_path)
                    port = start_ephemeral_server_for_dir(html_dir)
                    url = template_link[name]
                    previews.append((name, html_path, url))
                except Exception:
                    continue

            if not previews:
                return "❌ No previewable templates found", gr.update(visible=False), "", "", gr.update(choices=[], value=None), gr.update(visible=False, value=""), ""

            md_lines = ["### 🔍 Please select a template to preview before clicking **Start Generation**", ""]
            for name, _, url in previews:
                md_lines.append(f"- **{name}** → [{url}]({url})")
            md = "\n".join(md_lines)

            return "Recommended 3 templates, please select one to continue", gr.update(visible=False), "", "", gr.update(choices=[n for n, _, _ in previews], value=None), gr.update(visible=True, value=md), ""

        template_dir = os.path.join(template_root, template_choice_value)

    # Create arguments object
    args = GenerationArgs(
        paper_path=pdf_file.name,
        model_name_t=model_name_t,
        model_name_v=model_name_v,
        template_root=template_root,
        template_dir=template_dir,
        template_file=template_file,
        output_dir=output_dir,
        style_preference=style_preference,
        tmp_dir=tmp_dir,
        full_content_check_times=full_content_check_times,
        background_color=background_color,
        has_navigation=has_navigation,
        has_hero_section=has_hero_section,
        title_color=title_color,
        page_density=page_density,
        image_layout=image_layout,
        html_check_times=html_check_times,
        resume=resume,
        human_input=human_input
    )

    if not args.template_dir:
        return "❌ Please select a template", gr.update(visible=False), "", "", gr.update(), gr.update(), ""

    if not args.template_file:
        html_finder_ = HtmlFinder()
        args.template_file = html_finder_.find_html(args.template_dir)

    paper_name = os.path.basename(args.paper_path).replace('.pdf', '')
    args.paper_name = paper_name

    os.makedirs(args.tmp_dir, exist_ok=True)

    try:
        # Initialization
        agent_config_t = get_agent_config_with_keys(
            args.model_name_t, openai_api_key, gemini_api_key,
            qwen_api_key, zhipuai_api_key, openrouter_api_key
        )
        state.agent_config_t = agent_config_t
        state.args = args

        # Step 1: Parse PDF
        print("=" * 50)
        print("STEP 1: Parsing Research Paper")
        print("=" * 50)

        raw_content_path = f'project_contents/{args.paper_name}_raw_content.json'
        if not os.path.exists(raw_content_path):
            agent_config_v = get_agent_config_with_keys(
                args.model_name_v, openai_api_key, gemini_api_key,
                qwen_api_key, zhipuai_api_key, openrouter_api_key
            )
            input_token, output_token, raw_result, images, tables = parse_paper_for_project_page(args, agent_config_t)
            state.total_input_tokens_t += input_token
            state.total_output_tokens_t += output_token
            raw_content_path, _ = save_parsed_content(args, raw_result, images, tables, input_token, output_token)

        with open(raw_content_path, 'r') as f:
            paper_content = json.load(f)

        images = paper_content.get('images', [])
        tables = paper_content.get('tables', [])
        figures = {'images': images, 'tables': tables}
        paper_content = paper_content.get('markdown_content', "")

        state.paper_content = paper_content
        state.figures = figures

        # Step 2: Filter content
        print("=" * 50)
        print("STEP 2: Filtering Content")
        print("=" * 50)

        planner = ProjectPageContentPlanner(agent_config_t, args)
        state.planner = planner

        paper_content, figures, input_token, output_token = planner.filter_raw_content(paper_content, figures)
        state.total_input_tokens_t += input_token
        state.total_output_tokens_t += output_token
        state.paper_content = paper_content
        state.figures = figures

        # Step 3: Generate sections
        print("=" * 50)
        print("STEP 3: Generating Sections")
        print("=" * 50)

        state.current_stage = "section"

        generated_section, input_token, output_token = generate_section_initial()
        state.total_input_tokens_t += input_token
        state.total_output_tokens_t += output_token

        # Use Markdown formatting
        section_display_md = format_section_to_markdown(generated_section)
        section_display_json = json.dumps(generated_section, indent=2, ensure_ascii=False)

        return (
            f"✅ Section generation completed, please review and provide feedback\n\nTokens: {input_token} → {output_token}",
            gr.update(visible=True),   # feedback_section
            section_display_md,        # Markdown format
            section_display_json,      # JSON format (hidden)
            gr.update(),
            gr.update(visible=False, value=""),
            ""
        )

    except Exception as e:
        import traceback
        error_msg = f"❌ Generation failed: {str(e)}\n{traceback.format_exc()}"
        return error_msg, gr.update(visible=False), "", "", gr.update(), gr.update(), ""

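# Illustrative requirement dict (field values are hypothetical): this is the
# shape handed to matching() above when no template has been chosen yet; per
# the UI message it is expected to return three template names.
#
#     matched = matching({
#         'background_color': 'light',
#         'has_hero_section': True,
#         'Page density': 'medium',
#         'image_layout': 'grid',
#         'has_navigation': True,
#         'title_color': 'dark',
#     })
#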
def generate_section_initial():
    """Generate the initial section plan."""
    import yaml
    from jinja2 import Environment, StrictUndefined
    from utils.wei_utils import account_token
    from utils.src.utils import get_json_from_response

    with open('utils/prompt_templates/page_templates/section_generation.yaml', 'r') as f:
        planner_config = yaml.safe_load(f)

    jinja_env = Environment(undefined=StrictUndefined)
    template = jinja_env.from_string(planner_config["template"])

    jinja_args = {
        'paper_content': state.paper_content,
        'json_format_example': json.dumps(state.paper_content, indent=2)
    }

    prompt = template.render(**jinja_args)

    state.planner.planner_agent.reset()
    response = state.planner.planner_agent.step(prompt)
    input_token, output_token = account_token(response)
    generated_section = get_json_from_response(response.msgs[0].content)

    def create_dynamic_page_dict(sections):
        poster_dict = {
            "title": "Title of the paper",
            "authors": "Authors of the paper",
            "affiliation": "Affiliation of the authors",
        }
        poster_dict.update(sections)
        return poster_dict

    generated_section = create_dynamic_page_dict(generated_section)
    state.generated_section = generated_section

    generated_path = f'project_contents/{state.args.paper_name}_generated_section.json'
    with open(generated_path, 'w') as f:
        json.dump(generated_section, f, indent=4)

    return generated_section, input_token, output_token

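# Illustrative rendering step (the template string is hypothetical):
# StrictUndefined makes a missing placeholder raise an error instead of
# silently rendering an empty string, which is why it is used above.
#
#     from jinja2 import Environment, StrictUndefined
#     env = Environment(undefined=StrictUndefined)
#     tmpl = env.from_string('Summarize the paper:\n{{ paper_content }}')
#     prompt = tmpl.render(paper_content='...')  # UndefinedError if the key is missing
#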
def submit_section_feedback(feedback_text):
    """Submit section feedback."""
    if not feedback_text or feedback_text.strip().lower() == 'yes':
        # User is satisfied, proceed to the next stage
        result = proceed_to_text_content()
        status, fc_section_visible, fc_display_visible, fc_display_md, fc_display_json, fc_feedback_visible = result
        return (
            status,
            "",                        # section_display_md clear
            "",                        # section_display_json clear
            "",                        # section_feedback_input clear
            gr.update(visible=False),  # feedback_section hide
            fc_section_visible,        # feedback_full_content show
            fc_display_visible,        # full_content_display_md show
            fc_display_md,             # full_content_display_md content
            fc_display_json,           # full_content_display_json content
            fc_feedback_visible        # full_content_feedback_input show
        )

    # User provided feedback, revise the sections
    from camel.messages import BaseMessage
    from utils.wei_utils import account_token
    from utils.src.utils import get_json_from_response

    message = BaseMessage.make_assistant_message(
        role_name='User',
        content=f'human feedback: {feedback_text}\n\nPlease make modifications based on this feedback. Output format as specified above.'
    )
    response = state.planner.planner_agent.step(message)
    input_token, output_token = account_token(response)
    state.total_input_tokens_t += input_token
    state.total_output_tokens_t += output_token

    generated_section = get_json_from_response(response.msgs[0].content)
    state.generated_section = generated_section

    generated_path = f'project_contents/{state.args.paper_name}_generated_section.json'
    with open(generated_path, 'w') as f:
        json.dump(generated_section, f, indent=4)

    # Use Markdown formatting
    section_display_md = format_section_to_markdown(generated_section)
    section_display_json = json.dumps(generated_section, indent=2, ensure_ascii=False)

    return (
        f"✅ Section updated, please continue reviewing\n\nTokens: {input_token} → {output_token}",
        section_display_md,        # Markdown format
        section_display_json,      # JSON format
        "",                        # Clear input box
        gr.update(visible=True),   # feedback_section stays visible
        gr.update(visible=False),  # feedback_full_content stays hidden
        gr.update(visible=False),  # full_content_display_md stays hidden
        "",                        # full_content_display_md content
        "",                        # full_content_display_json content
        gr.update(visible=False)   # full_content_feedback_input stays hidden
    )

def proceed_to_text_content():
    """Enter the text content generation stage."""
    print("=" * 50)
    print("STEP 4: Generating Text Content")
    print("=" * 50)

    text_page_content, input_token, output_token = state.planner.text_content_generation(
        state.paper_content, state.figures, state.generated_section
    )
    state.total_input_tokens_t += input_token
    state.total_output_tokens_t += output_token
    state.text_page_content = text_page_content

    # Enter the full content stage
    return proceed_to_full_content()

def proceed_to_full_content():
    """Enter the full content generation stage."""
    print("=" * 50)
    print("STEP 5: Generating Full Content")
    print("=" * 50)

    state.current_stage = "full_content"

    generated_content, input_token, output_token = generate_full_content_initial()
    state.total_input_tokens_t += input_token
    state.total_output_tokens_t += output_token

    # Use Markdown formatting
    content_display_md = format_full_content_to_markdown(generated_content, state.figures)
    content_display_json = json.dumps(generated_content, indent=2, ensure_ascii=False)

    return (
        f"✅ Full Content generation completed, please review and provide feedback\n\nTokens: {input_token} → {output_token}",
        gr.update(visible=True),   # feedback_full_content show
        gr.update(visible=True),   # full_content_display_md show
        content_display_md,        # Markdown format
        content_display_json,      # JSON format
        gr.update(visible=True)    # full_content_feedback_input show
    )

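# Illustrative feedback turn (the feedback text is hypothetical): the submit_*
# feedback handlers wrap the human note in an assistant-role message and replay
# it to the same planner agent, preserving the conversation state.
#
#     from camel.messages import BaseMessage
#     message = BaseMessage.make_assistant_message(
#         role_name='User',
#         content='human feedback: shorten the abstract\n\n'
#                 'Please make modifications based on this feedback.'
#     )
#     response = state.planner.planner_agent.step(message)
#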
def generate_full_content_initial():
|
| 783 |
+
"""Generate initial Full Content"""
|
| 784 |
+
import yaml
|
| 785 |
+
from jinja2 import Environment, StrictUndefined
|
| 786 |
+
from utils.wei_utils import account_token
|
| 787 |
+
from utils.src.utils import get_json_from_response
|
| 788 |
+
|
| 789 |
+
with open('utils/prompt_templates/page_templates/full_content_generation.yaml', 'r') as f:
|
| 790 |
+
planner_config = yaml.safe_load(f)
|
| 791 |
+
|
| 792 |
+
jinja_env = Environment(undefined=StrictUndefined)
|
| 793 |
+
template = jinja_env.from_string(planner_config["template"])
|
| 794 |
+
|
| 795 |
+
jinja_args = {
|
| 796 |
+
'paper_content': state.paper_content,
|
| 797 |
+
'figures': json.dumps(state.figures, indent=2),
|
| 798 |
+
'project_page_content': json.dumps(state.text_page_content, indent=2)
|
| 799 |
+
}
|
| 800 |
+
|
| 801 |
+
prompt = template.render(**jinja_args)
|
| 802 |
+
|
| 803 |
+
state.planner.planner_agent.reset()
|
| 804 |
+
response = state.planner.planner_agent.step(prompt)
|
| 805 |
+
input_token, output_token = account_token(response)
|
| 806 |
+
generated_content = get_json_from_response(response.msgs[0].content)
|
| 807 |
+
|
| 808 |
+
state.generated_content = generated_content
|
| 809 |
+
|
| 810 |
+
first_path = f'project_contents/{state.args.paper_name}_generated_full_content.v0.json'
|
| 811 |
+
with open(first_path, 'w', encoding='utf-8') as f:
|
| 812 |
+
json.dump(generated_content, f, ensure_ascii=False, indent=2)
|
| 813 |
+
|
| 814 |
+
return generated_content, input_token, output_token
|
| 815 |
+
|
| 816 |
+
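
# `get_json_from_response` is imported from utils.src.utils; its real
# implementation lives elsewhere in this repo and is not shown here. As a
# rough illustration only, such an extractor typically strips an optional
# Markdown code fence from the LLM reply before parsing:
def _illustrative_get_json_from_response(text: str) -> dict:
    import json
    import re

    # Prefer a fenced ```json ... ``` block if the model emitted one.
    match = re.search(r"```(?:json)?\s*(.*?)\s*```", text, re.DOTALL)
    payload = match.group(1) if match else text
    return json.loads(payload)
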
def submit_full_content_feedback(feedback_text):
    """Submit Full Content feedback."""
    if not feedback_text or feedback_text.strip().lower() == 'yes':
        # User is satisfied; proceed to HTML generation
        result = proceed_to_html_generation()
        status, html_feedback_visible, preview_info, preview_url, open_btn_visible = result
        return (
            status,
            "",  # full_content_display_md clear
            "",  # full_content_display_json clear
            "",  # full_content_feedback_input clear
            gr.update(visible=False),  # feedback_full_content hide
            html_feedback_visible,     # feedback_html show
            preview_info,              # preview_info_display
            preview_url,               # preview_url_state
            open_btn_visible           # open_preview_btn show
        )

    # User provided feedback
    from camel.messages import BaseMessage
    from utils.wei_utils import account_token
    from utils.src.utils import get_json_from_response

    message = BaseMessage.make_assistant_message(
        role_name='User',
        content=f'human feedback: {feedback_text}\n\nPlease make modifications based on this feedback. Output format as specified above.'
    )
    response = state.planner.planner_agent.step(message)
    input_token, output_token = account_token(response)
    state.total_input_tokens_t += input_token
    state.total_output_tokens_t += output_token

    generated_content = get_json_from_response(response.msgs[0].content)
    state.generated_content = generated_content

    final_path = f'project_contents/{state.args.paper_name}_generated_full_content.json'
    with open(final_path, 'w', encoding='utf-8') as f:
        json.dump(generated_content, f, ensure_ascii=False, indent=2)

    # Use Markdown formatting for the preview
    content_display_md = format_full_content_to_markdown(generated_content, state.figures)
    content_display_json = json.dumps(generated_content, indent=2, ensure_ascii=False)

    return (
        f"✅ Full Content updated, please continue reviewing\n\nTokens: {input_token} → {output_token}",
        content_display_md,        # Markdown format
        content_display_json,      # JSON format
        "",                        # clear input box
        gr.update(visible=True),   # feedback_full_content stays visible
        gr.update(visible=False),  # feedback_html stays hidden
        "",                        # preview_info_display
        "",                        # preview_url_state
        gr.update(visible=False)   # open_preview_btn stays hidden
    )


def proceed_to_html_generation():
    """Enter the HTML generation stage."""
    print("=" * 50)
    print("STEP 6: Generating HTML")
    print("=" * 50)

    state.current_stage = "html"

    # Copy static files
    static_dir = copy_static_files(
        state.args.template_file,
        state.args.template_dir,
        state.args.output_dir,
        state.args.paper_name
    )

    # Resolve the template's directory relative to the template root
    html_relative_path = os.path.relpath(state.args.template_file, state.args.template_dir)
    html_dir = '/'.join(html_relative_path.strip().split('/')[:-1])
    state.html_dir = html_dir

    html_generator = ProjectPageHTMLGenerator(state.agent_config_t, state.args)
    state.html_generator = html_generator

    with open(state.args.template_file, 'r', encoding='utf-8') as file:
        html_template = file.read()

    # Create the assets directory
    assets_dir = html_generator.create_assets_directory(state.args, html_dir, state.args.output_dir)

    # Generate HTML
    html_content, input_token, output_token = html_generator.generate_complete_html(
        state.args, state.generated_content, html_dir, html_template
    )
    state.total_input_tokens_t += input_token
    state.total_output_tokens_t += output_token

    # Save HTML (before table modification)
    html_dir_path = os.path.join(state.args.output_dir, state.args.paper_name, html_dir)
    os.makedirs(html_dir_path, exist_ok=True)

    html_file_path_no_modify = os.path.join(html_dir_path, 'index_no_modify_table.html')
    with open(html_file_path_no_modify, 'w', encoding='utf-8') as file:
        file.write(html_content)

    # Generate a screenshot (before table modification)
    screenshot_path_no_modify = os.path.join(html_dir_path, 'page_final_no_modify_table.png')
    run_sync_screenshots(to_url(html_file_path_no_modify), screenshot_path_no_modify)

    # Modify tables
    html_content, input_token, output_token = html_generator.modify_html_table(html_content, html_dir)
    state.total_input_tokens_t += input_token
    state.total_output_tokens_t += output_token

    state.html_content = html_content

    # Save HTML (after table modification)
    html_file_path = os.path.join(html_dir_path, 'index.html')
    with open(html_file_path, 'w', encoding='utf-8') as file:
        file.write(html_content)

    state.html_file_path = html_file_path

    # Generate a screenshot (after table modification)
    run_sync_screenshots(
        to_url(html_file_path),
        os.path.join(html_dir_path, 'page_final.png')
    )

    # Start the preview server
    html_full_dir = os.path.dirname(os.path.abspath(html_file_path))
    port = start_preview_server(html_full_dir)
    preview_url = f"http://localhost:{port}/index.html"
    state.preview_url = preview_url

    # Create the preview info display
    preview_info = f"""
### 🌐 HTML Generation Completed

**Preview URL**: {preview_url}

**Instructions**:
1. Click the **"🌐 Open Preview in New Tab"** button below to view the generated webpage
2. Carefully review the page in the new tab
3. If satisfied, enter **'yes'** in the feedback box and submit
4. If modifications are needed, provide detailed feedback and submit

**Token Usage**: {input_token} → {output_token}
"""

    return (
        f"✅ HTML generation completed\n\nTokens: {input_token} → {output_token}",
        gr.update(visible=True),  # feedback_html show
        preview_info,             # preview_info_display
        preview_url,              # preview_url_state
        gr.update(visible=True)   # open_preview_btn show
    )

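
# `start_preview_server` is defined earlier in app.py; its implementation is
# not shown in this hunk. A minimal sketch of the same idea (illustrative
# only, not the actual helper) serves the HTML directory on a free port from
# a daemon thread:
def _illustrative_start_preview_server(directory: str) -> int:
    import functools
    import http.server
    import socketserver
    import threading

    handler = functools.partial(
        http.server.SimpleHTTPRequestHandler, directory=directory
    )
    httpd = socketserver.TCPServer(("", 0), handler)  # port 0 = OS picks a free port
    threading.Thread(target=httpd.serve_forever, daemon=True).start()
    return httpd.server_address[1]
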
def submit_html_feedback(feedback_text):
    """Submit HTML feedback."""
    if not feedback_text or feedback_text.strip().lower() == 'yes':
        # User is satisfied; complete generation
        result = finalize_generation()
        status, html_file = result
        return (
            status,
            "",  # preview_info_display clear
            "",  # html_feedback_input clear
            gr.update(visible=False),  # feedback_html hide
            gr.update(visible=False),  # open_preview_btn hide
            html_file  # html_file_output
        )

    # User provided feedback
    html_content, input_token, output_token = state.html_generator.modify_html_from_human_feedback(
        state.html_content, feedback_text
    )
    state.total_input_tokens_t += input_token
    state.total_output_tokens_t += output_token
    state.html_content = html_content

    # Save the updated HTML
    html_dir_path = os.path.dirname(state.html_file_path)

    # Save a timestamped version (for possible feedback iteration)
    import time
    timestamp = int(time.time())
    html_file_feedback = os.path.join(html_dir_path, f'index_feedback_{timestamp}.html')
    with open(html_file_feedback, 'w', encoding='utf-8') as file:
        file.write(html_content)

    # Also update the main file
    with open(state.html_file_path, 'w', encoding='utf-8') as file:
        file.write(html_content)

    # Regenerate the screenshot
    screenshot_path = os.path.join(html_dir_path, 'page_final.png')
    try:
        run_sync_screenshots(to_url(state.html_file_path), screenshot_path)
    except Exception as e:
        print(f"Screenshot generation failed: {e}")

    # Update the preview info
    preview_info = f"""
### 🌐 HTML Updated

**Preview URL**: {state.preview_url}

**Instructions**:
1. Click the **"🌐 Open Preview in New Tab"** button below to view the updated webpage
2. **Refresh the browser** to see the latest version
3. If satisfied, enter **'yes'** in the feedback box and submit
4. If further modifications are needed, continue providing feedback

**Token Usage**: {input_token} → {output_token}
"""

    return (
        f"✅ HTML updated, please refresh the preview page\n\nTokens: {input_token} → {output_token}",
        preview_info,             # preview_info_display
        "",                       # clear input box
        gr.update(visible=True),  # feedback_html stays visible
        gr.update(visible=True),  # open_preview_btn stays visible
        None                      # html_file_output: no download yet
    )

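
# `to_url` and `run_sync_screenshots` are helpers defined earlier in app.py.
# Illustrative sketch only (the real helper may differ): turning a local HTML
# path into a browser-loadable file:// URL can be done with pathlib:
def _illustrative_to_url(html_path: str) -> str:
    from pathlib import Path

    return Path(html_path).resolve().as_uri()  # e.g. file:///abs/path/index.html
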
def finalize_generation():
    """Complete generation and save the final results."""
    import time

    # Ensure the final HTML is saved
    html_dir_path = os.path.dirname(state.html_file_path)

    # Save the final version
    final_html_path = os.path.join(html_dir_path, 'index_final.html')
    with open(final_html_path, 'w', encoding='utf-8') as file:
        file.write(state.html_content)

    # Also update the main file
    with open(state.html_file_path, 'w', encoding='utf-8') as file:
        file.write(state.html_content)

    # Save metadata
    metadata = state.html_generator.generate_metadata(state.generated_content, state.args)
    metadata_path = state.html_generator.save_metadata(metadata, state.args, state.args.output_dir)

    # Create the README file
    readme_path = os.path.join(state.args.output_dir, state.args.paper_name, 'README.md')
    readme_content = f"""# {state.args.paper_name} - Project Page

## 📄 Project Information

- **Paper Name**: {state.args.paper_name}
- **Generation Time**: {time.strftime('%Y-%m-%d %H:%M:%S')}
- **Text Model**: {state.args.model_name_t}
- **Vision Model**: {state.args.model_name_v}

## 🚀 Usage

1. Extract this archive to any directory
2. Open `index.html` to view the project page
3. All resources (CSS, images, etc.) are included

## 📁 File Structure

- `index.html` - Main page file
- `index_final.html` - Final confirmed version
- `assets/` - Image and table resources
- `css/` or `styles/` - Style files
- `js/` or `scripts/` - JavaScript files
- `metadata.json` - Page metadata
- `generation_log.json` - Generation log

## 💡 Tips

- Recommended browsers: Chrome, Firefox, Safari, Edge
- For web deployment, simply upload the entire folder
- Feel free to modify the HTML and CSS for customization

---
Generated by Paper2ProjectPage
"""

    with open(readme_path, 'w', encoding='utf-8') as f:
        f.write(readme_content)

    # Save the generation log
    log_data = {
        'paper_name': state.args.paper_name,
        'paper_path': state.args.paper_path,
        'models': {
            'text_model': state.args.model_name_t,
            'vision_model': state.args.model_name_v
        },
        'token_usage': {
            'text_input_tokens': state.total_input_tokens_t,
            'text_output_tokens': state.total_output_tokens_t
        },
        'output_files': {
            'html_file': state.html_file_path,
            'final_html_file': final_html_path,
            'metadata_file': metadata_path,
            'readme_file': readme_path
        },
        'timestamp': time.strftime('%Y-%m-%d %H:%M:%S')
    }

    log_path = f"{state.args.output_dir}/{state.args.paper_name}/generation_log.json"
    with open(log_path, 'w') as f:
        json.dump(log_data, f, indent=4, ensure_ascii=False)

    # Create the project archive
    project_dir = os.path.join(state.args.output_dir, state.args.paper_name)
    zip_path = create_project_zip(project_dir, state.args.output_dir, state.args.paper_name)

    if zip_path and os.path.exists(zip_path):
        # Get the archive size
        zip_size = os.path.getsize(zip_path)
        zip_size_mb = zip_size / (1024 * 1024)
        zip_filename = os.path.basename(zip_path)

        success_msg = f"""
✅ Project page generation completed!

📁 Output directory: {state.args.output_dir}/{state.args.paper_name}
🌐 HTML file: {state.html_file_path}
🌐 Final version: {final_html_path}
📋 Metadata: {metadata_path}
📖 README: {readme_path}
📊 Log file: {log_path}
📦 Archive: {zip_filename} ({zip_size_mb:.2f} MB)
🔢 Total token usage: {state.total_input_tokens_t} → {state.total_output_tokens_t}

🎉 All feedback completed, page successfully generated!
Click the button below to download the complete project archive (including HTML, CSS, images, README, and all resources).
"""

        return (
            success_msg,
            zip_path  # return the archive for download
        )

    else:
        error_msg = f"""
⚠️ Project page generated, but archive creation failed!

📁 Output directory: {state.args.output_dir}/{state.args.paper_name}
🌐 HTML file: {state.html_file_path}
📋 Metadata: {metadata_path}

You can manually retrieve all files from the output directory {project_dir}.
"""
        return (
            error_msg,
            state.html_file_path  # return the HTML file instead
        )

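
# `create_project_zip` is defined earlier in app.py and is not shown in this
# hunk. A minimal equivalent (illustrative only, not the actual helper) can
# be built on shutil.make_archive:
def _illustrative_create_project_zip(project_dir, output_dir, paper_name):
    import os
    import shutil

    base_name = os.path.join(output_dir, paper_name)  # archive path minus ".zip"
    return shutil.make_archive(base_name, 'zip', root_dir=project_dir)
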
# ==================== Gradio Interface ====================

# Custom CSS for better English font rendering
custom_css = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&family=JetBrains+Mono:wght@400;500&display=swap');

* {
    font-family: 'Inter', -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, sans-serif !important;
}

code, pre, .code {
    font-family: 'JetBrains Mono', 'Courier New', Consolas, Monaco, monospace !important;
}

h1, h2, h3, h4, h5, h6 {
    font-weight: 600 !important;
    letter-spacing: -0.02em !important;
}

.markdown-text {
    line-height: 1.7 !important;
    font-size: 15px !important;
}

.gr-button {
    font-weight: 500 !important;
    letter-spacing: 0.01em !important;
}

.gr-input, .gr-textarea {
    font-size: 14px !important;
    line-height: 1.6 !important;
}

.gr-box {
    border-radius: 8px !important;
}

/* Better spacing for English content */
.gr-markdown p {
    margin-bottom: 0.8em !important;
}

.gr-markdown ul, .gr-markdown ol {
    margin-left: 1.2em !important;
}

.gr-markdown li {
    margin-bottom: 0.4em !important;
}
"""

with gr.Blocks(title="Paper2ProjectPage Generator", theme=gr.themes.Soft(), css=custom_css) as demo:

    gr.Markdown("""
    # 📄 AutoPage Generator with Interactive Feedback

    Upload your research paper PDF and generate beautiful project pages through multi-round interactive feedback
    """)

    with gr.Row():
        with gr.Column(scale=1):
            # PDF Upload
            pdf_input = gr.File(
                label="📎 Upload PDF Paper",
                file_types=[".pdf"],
                type="filepath"
            )

            gr.Markdown("### 🔑 API Keys Configuration")
            gr.Markdown("""
            **⚠️ Security Notice**: Your API keys are only stored in memory during the session and are never saved to disk.

            **📋 How to get API keys:**
            - **OpenAI**: Get your API key from [OpenAI Platform](https://platform.openai.com/api-keys)
            - **Gemini**: Get your API key from [Google AI Studio](https://aistudio.google.com/app/apikey)
            - **Qwen**: Get your API key from [DashScope](https://dashscope.console.aliyun.com/apiKey)
            - **ZhipuAI**: Get your API key from [ZhipuAI Console](https://open.bigmodel.cn/usercenter/apikeys)
            - **OpenRouter**: Get your API key from [OpenRouter](https://openrouter.ai/keys)

            **🚀 For HuggingFace Spaces**: You can also set these as environment variables in your Space settings.
            """)

            with gr.Row():
                openai_api_key = gr.Textbox(
                    label="OpenAI API Key",
                    value=os.getenv("OPENAI_API_KEY", ""),
                    type="password",
                    placeholder="sk-...",
                    info="Required for GPT models"
                )
                gemini_api_key = gr.Textbox(
                    label="Gemini API Key",
                    value=os.getenv("GEMINI_API_KEY", ""),
                    type="password",
                    placeholder="AI...",
                    info="Required for Gemini models"
                )

            with gr.Row():
                qwen_api_key = gr.Textbox(
                    label="Qwen API Key",
                    value=os.getenv("QWEN_API_KEY", ""),
                    type="password",
                    placeholder="sk-...",
                    info="Required for Qwen models"
                )
                zhipuai_api_key = gr.Textbox(
                    label="ZhipuAI API Key",
                    value=os.getenv("ZHIPUAI_API_KEY", ""),
                    type="password",
                    placeholder="...",
                    info="Required for GLM models"
                )

            openrouter_api_key = gr.Textbox(
                label="OpenRouter API Key",
                value=os.getenv("OPENROUTER_API_KEY", ""),
                type="password",
                placeholder="sk-or-...",
                info="Required for OpenRouter models"
            )

            gr.Markdown("### 🤖 Model Configuration")

            # Text Model Options
            text_model_options = [
                ("GPT-4o", "4o"),
                ("GPT-4o Mini", "4o-mini"),
                ("GPT-4.1", "gpt-4.1"),
                ("GPT-4.1 Mini", "gpt-4.1-mini"),
                ("O1", "o1"),
                ("O3", "o3"),
                ("O3 Mini", "o3-mini"),
                ("Gemini 2.5 Pro", "gemini"),
                ("Gemini 2.5 Pro (Alt)", "gemini-2.5-pro"),
                ("Gemini 2.5 Flash", "gemini-2.5-flash"),
                ("Qwen", "qwen"),
                ("Qwen Plus", "qwen-plus"),
                ("Qwen Max", "qwen-max"),
                ("Qwen Long", "qwen-long"),
                ("OpenRouter Qwen Plus", "openrouter_qwen-plus"),
                ("OpenRouter GPT-4o Mini", "openrouter_gpt-4o-mini"),
                ("OpenRouter Gemini 2.5 Flash", "openrouter_gemini-2.5-flash"),
                ("OpenRouter O3", "openrouter_openai/o3"),
                ("OpenRouter Claude Sonnet 4.5", "openrouter_claude-sonnet-4.5"),
            ]

            # Vision Model Options
            vision_model_options = [
                ("GPT-4o", "4o"),
                ("GPT-4o Mini", "4o-mini"),
                ("Gemini 2.5 Pro", "gemini"),
                ("Gemini 2.5 Pro (Alt)", "gemini-2.5-pro"),
                ("Gemini 2.5 Flash", "gemini-2.5-flash"),
                ("Qwen VL Max", "qwen-vl-max"),
                ("Qwen 2.5 VL 72B", "qwen-2.5-vl-72b"),
                ("OpenRouter Qwen VL 72B", "openrouter_qwen_vl_72b"),
                ("OpenRouter Qwen VL 7B", "openrouter_qwen_vl_7b"),
                ("OpenRouter Qwen VL Max", "openrouter_qwen-vl-max"),
                ("OpenRouter Gemini 2.5 Flash", "openrouter_gemini-2.5-flash"),
            ]

            with gr.Row():
                model_name_t = gr.Dropdown(
                    label="Text Model",
                    choices=text_model_options,
                    value="gemini",
                    info="Select model for text processing"
                )
                model_name_v = gr.Dropdown(
                    label="Vision Model",
                    choices=vision_model_options,
                    value="gemini",
                    info="Select model for vision processing"
                )

            gr.Markdown("### 📁 Path Configuration")
            template_root = gr.Textbox(
                label="Template Root",
                value="templates",
                info="Root directory for templates"
            )
            template_dir = gr.Textbox(
                label="Template Directory",
                value="",
                info="Selected template directory (optional)"
            )
            template_file = gr.Textbox(
                label="Template File",
                value="",
                info="Specific template file path (optional)"
            )
            template_choice = gr.Radio(
                label="Recommended Templates",
                choices=[],
                value=None,
                info="Select from recommended templates",
                visible=True
            )
            output_dir = gr.Textbox(
                label="Output Directory",
                value="generated_project_pages",
                info="Directory for output files"
            )
            style_preference = gr.Textbox(
                label="Style Preference JSON",
                value="",
                info="Style preference JSON file path (optional)"
            )
            tmp_dir = gr.Textbox(
                label="Temporary Directory",
                value="tmp",
                info="Directory for temporary files"
            )

            template_preview_links = gr.Markdown(
                label="Template Preview Links",
                value="",
                visible=False
            )

            # ===== Hidden parameters with default values =====
            resume = gr.Radio(
                label="Resume From Step",
                choices=['parse_pdf', 'generate_content', 'full_content_check', 'generate_html', 'html_check', 'modify_table', 'html_feedback'],
                value='parse_pdf',
                visible=False
            )

            human_input = gr.Radio(
                label="Enable Human Feedback",
                choices=[0, 1],
                value=1,
                visible=False
            )

        with gr.Column(scale=1):
            gr.Markdown("### 🎨 Style Configuration")

            background_color = gr.Radio(
                label="Background Color",
                choices=["light", "dark"],
                value="light",
                info="Background color theme"
            )

            has_navigation = gr.Radio(
                label="Has Navigation",
                choices=["yes", "no"],
                value="yes",
                info="Include navigation bar"
            )

            has_hero_section = gr.Radio(
                label="Has Hero Section",
                choices=["yes", "no"],
                value="yes",
                info="Include hero/header section"
            )

            title_color = gr.Radio(
                label="Title Color",
                choices=["pure", "colorful"],
                value="pure",
                info="Title color style"
            )

            page_density = gr.Radio(
                label="Page Density",
                choices=["spacious", "compact"],
                value="spacious",
                info="Page spacing density"
            )

            image_layout = gr.Radio(
                label="Image Layout",
                choices=["rotation", "parallelism"],
                value="parallelism",
                info="Image layout style"
            )

            gr.Markdown("### ⚙️ Advanced Options")

            full_content_check_times = gr.Number(
                label="Full Content Check Times",
                value=1,
                precision=0,
                info="Number of full content validation checks"
            )

            html_check_times = gr.Number(
                label="HTML Check Times",
                value=1,
                precision=0,
                info="Number of HTML validation checks"
            )

    # Start Generation Button
    start_btn = gr.Button("🚀 Start Generation", variant="primary", size="lg")

    # Status Output
    status_output = gr.Textbox(
        label="📊 Generation Status",
        lines=5,
        interactive=False
    )

    # Section Feedback Area
    with gr.Group(visible=False) as feedback_section:
        gr.Markdown("### 📝 Section Generation Results")
        gr.Markdown("Please review the generated section structure. If satisfied, enter **'yes'**; otherwise provide modification feedback:")

        with gr.Tabs():
            with gr.Tab("📖 Preview (Markdown)"):
                section_display_md = gr.Markdown(
                    label="Section Preview",
                    value=""
                )
            with gr.Tab("📋 Raw Data (JSON)"):
                section_display_json = gr.Code(
                    label="Section JSON",
                    language="json",
                    value="",
                    lines=15
                )

        section_feedback_input = gr.TextArea(
            label="Your Feedback",
            placeholder="Enter 'yes' to continue, or provide modification feedback...",
            lines=3
        )
        section_submit_btn = gr.Button("Submit Feedback", variant="primary")

    # Full Content Feedback Area
    with gr.Group(visible=False) as feedback_full_content:
        gr.Markdown("### 📄 Full Content Generation Results")
        gr.Markdown("Please review the generated full content. If satisfied, enter **'yes'**; otherwise provide modification feedback:")

        with gr.Tabs():
            with gr.Tab("📖 Preview (Markdown)"):
                full_content_display_md = gr.Markdown(
                    label="Full Content Preview",
                    value=""
                )
            with gr.Tab("📋 Raw Data (JSON)"):
                full_content_display_json = gr.Code(
                    label="Full Content JSON",
                    language="json",
                    value="",
                    lines=15
                )

        full_content_feedback_input = gr.TextArea(
            label="Your Feedback",
            placeholder="Enter 'yes' to continue, or provide modification feedback...",
            lines=3
        )
        full_content_submit_btn = gr.Button("Submit Feedback", variant="primary")

    # HTML Feedback Area
    with gr.Group(visible=False) as feedback_html:
        gr.Markdown("### 🌐 HTML Generation Results")

        # Preview Info Display
        preview_info_display = gr.Markdown(
            value="",
            label="Preview Information"
        )

        # Preview URL (hidden state for JS)
        preview_url_state = gr.Textbox(visible=False)

        # Open Preview in New Tab Button
        open_preview_btn = gr.Button(
            "🌐 Open Preview in New Tab",
            variant="secondary",
            size="lg",
            visible=False
        )

        gr.Markdown("---")

        # Feedback Input Area
        html_feedback_input = gr.TextArea(
            label="Your Feedback",
            placeholder="Enter 'yes' to finalize, or provide modification feedback...",
            lines=3
        )
        html_submit_btn = gr.Button("Submit Feedback", variant="primary")

    # Final Output
    html_file_output = gr.File(
        label="📥 Download Project Archive",
        interactive=False
    )

    gr.Markdown("""
    ---
    ### 💡 User Guide

    1. **Upload PDF**: Select your research paper PDF file
    2. **Configure Parameters**: Adjust model, path, and style settings as needed
    3. **Start Generation**: Click the "Start Generation" button
    4. **Three-Stage Feedback**:
       - 📝 **Section Feedback**: Review the generated page structure (Markdown preview + JSON data), then provide feedback or enter 'yes' to continue
       - 📄 **Full Content Feedback**: Review the generated complete content (Markdown preview + JSON data), then provide feedback or enter 'yes' to continue
       - 🌐 **HTML Feedback**: View the generated webpage in a new tab, then provide feedback or enter 'yes' to finalize
    5. **Download Results**: Download the complete project archive after completion

    ⚠️ **Tips**:
    - Each stage supports multiple rounds of feedback until you're satisfied
    - The Section and Full Content stages offer both a **Markdown preview** and the **raw JSON data**
    - The Markdown preview is easier to read; the JSON view shows the complete structure
    - The HTML stage requires clicking "Open Preview in New Tab" to view the full page in a browser
    - Enter 'yes' to indicate satisfaction and proceed to the next stage
    - The final ZIP download includes the complete project folder with all resources
    """)

    # Bind Events
    start_btn.click(
        fn=start_generation,
        inputs=[
            pdf_input, model_name_t, model_name_v, template_root,
            template_dir, template_file, output_dir, style_preference,
            tmp_dir, full_content_check_times, background_color,
            has_navigation, has_hero_section, title_color, page_density,
            image_layout, html_check_times, resume, human_input,
            template_choice, openai_api_key, gemini_api_key,
            qwen_api_key, zhipuai_api_key, openrouter_api_key
        ],
        outputs=[
            status_output,
            feedback_section,
            section_display_md,
            section_display_json,
            template_choice,
            template_preview_links,
            section_feedback_input
        ]
    )

    section_submit_btn.click(
        fn=submit_section_feedback,
        inputs=[section_feedback_input],
        outputs=[
            status_output,
            section_display_md,
            section_display_json,
            section_feedback_input,
            feedback_section,
            feedback_full_content,
            full_content_display_md,
            full_content_display_md,
            full_content_display_json,
            full_content_feedback_input
        ]
    )

    full_content_submit_btn.click(
        fn=submit_full_content_feedback,
        inputs=[full_content_feedback_input],
        outputs=[
            status_output,
            full_content_display_md,
            full_content_display_json,
            full_content_feedback_input,
            feedback_full_content,
            feedback_html,
            preview_info_display,
            preview_url_state,
            open_preview_btn
        ]
    )

    html_submit_btn.click(
        fn=submit_html_feedback,
        inputs=[html_feedback_input],
        outputs=[
            status_output,
            preview_info_display,
            html_feedback_input,
            feedback_html,
            open_preview_btn,
            html_file_output
        ]
    )

    # Open Preview Button - use JavaScript to open in a new tab
    open_preview_btn.click(
        fn=None,
        inputs=[preview_url_state],
        outputs=None,
        js="(url) => window.open(url, '_blank')"
    )

# Launch Application
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True
    )
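
For long-running jobs like this one (PDF parsing plus several LLM calls), Gradio deployments often enable request queuing before launching. The line below is only a hedged variant of the launch block above, not what app.py ships with; `max_size` bounds how many requests may wait:

    demo.queue(max_size=16).launch(server_name="0.0.0.0", server_port=7860)
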
camel/__init__.py
ADDED
@@ -0,0 +1,25 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

from camel.logger import disable_logging, enable_logging, set_log_level

__version__ = '0.2.19'

__all__ = [
    '__version__',
    'camel',
    'disable_logging',
    'enable_logging',
    'set_log_level',
]
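
A quick usage sketch for the helpers re-exported above. It assumes `set_log_level` accepts standard `logging` level names or constants, which camel/logger.py (not shown in this diff) would define:

    import camel

    print(camel.__version__)      # '0.2.19'
    camel.set_log_level("DEBUG")  # assumption: accepts a level name or int
    camel.disable_logging()       # silence CAMEL's loggers
    camel.enable_logging()
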
camel/agents/__init__.py
ADDED
@@ -0,0 +1,44 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .base import BaseAgent
from .chat_agent import ChatAgent
from .critic_agent import CriticAgent
from .embodied_agent import EmbodiedAgent
from .knowledge_graph_agent import KnowledgeGraphAgent
from .role_assignment_agent import RoleAssignmentAgent
from .search_agent import SearchAgent
from .task_agent import (
    TaskCreationAgent,
    TaskPlannerAgent,
    TaskPrioritizationAgent,
    TaskSpecifyAgent,
)
from .tool_agents.base import BaseToolAgent
from .tool_agents.hugging_face_tool_agent import HuggingFaceToolAgent

__all__ = [
    'BaseAgent',
    'ChatAgent',
    'TaskSpecifyAgent',
    'TaskPlannerAgent',
    'TaskCreationAgent',
    'TaskPrioritizationAgent',
    'CriticAgent',
    'BaseToolAgent',
    'HuggingFaceToolAgent',
    'EmbodiedAgent',
    'RoleAssignmentAgent',
    'SearchAgent',
    'KnowledgeGraphAgent',
]
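
Because of these re-exports, downstream code can import agents from the package root instead of reaching into submodules, e.g.:

    from camel.agents import ChatAgent, TaskPlannerAgent
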
camel/agents/base.py
ADDED
@@ -0,0 +1,29 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from abc import ABC, abstractmethod
from typing import Any


class BaseAgent(ABC):
    r"""An abstract base class for all CAMEL agents."""

    @abstractmethod
    def reset(self, *args: Any, **kwargs: Any) -> Any:
        r"""Resets the agent to its initial state."""
        pass

    @abstractmethod
    def step(self, *args: Any, **kwargs: Any) -> Any:
        r"""Performs a single step of the agent."""
        pass
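
`BaseAgent` only pins down the `reset`/`step` contract. A minimal concrete agent, shown here purely as an illustration (it is not part of this repo), just implements both methods:

    from typing import Any

    from camel.agents.base import BaseAgent


    class EchoAgent(BaseAgent):
        """Toy agent that returns whatever input it is given."""

        def reset(self, *args: Any, **kwargs: Any) -> None:
            pass  # no internal state to clear

        def step(self, *args: Any, **kwargs: Any) -> Any:
            return args[0] if args else None
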
camel/agents/chat_agent.py
ADDED
@@ -0,0 +1,1539 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations

import json
import logging
import re
import uuid
from collections import defaultdict
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    List,
    Optional,
    Tuple,
    Type,
    Union,
)

from openai.types.chat import ChatCompletionMessageToolCall
from openai.types.chat.chat_completion_message_tool_call import Function
from pydantic import BaseModel, ValidationError

from camel.agents.base import BaseAgent
from camel.memories import (
    AgentMemory,
    ChatHistoryMemory,
    MemoryRecord,
    ScoreBasedContextCreator,
)
from camel.messages import BaseMessage, FunctionCallingMessage, OpenAIMessage
from camel.models import (
    BaseModelBackend,
    ModelFactory,
    ModelManager,
    ModelProcessingError,
)
from camel.responses import ChatAgentResponse
from camel.types import (
    ChatCompletion,
    ChatCompletionChunk,
    ModelPlatformType,
    ModelType,
    OpenAIBackendRole,
    RoleType,
)
from camel.utils import (
    func_string_to_callable,
    generate_prompt_for_structured_output,
    get_model_encoding,
    get_pydantic_object_schema,
    json_to_function_code,
)

if TYPE_CHECKING:
    from openai import Stream

    from camel.terminators import ResponseTerminator
    from camel.toolkits import FunctionTool


logger = logging.getLogger(__name__)

# AgentOps decorator setting
try:
    import os

    if os.getenv("AGENTOPS_API_KEY") is not None:
        from agentops import track_agent
    else:
        raise ImportError
except (ImportError, AttributeError):
    from camel.utils import track_agent


class FunctionCallingRecord(BaseModel):
    r"""Historical records of functions called in the conversation.

    Attributes:
        func_name (str): The name of the function being called.
        args (Dict[str, Any]): The dictionary of arguments passed to
            the function.
        result (Any): The execution result of calling this function.
        tool_call_id (str): The ID of the tool call, if available.
    """

    func_name: str
    args: Dict[str, Any]
    result: Any
    tool_call_id: str

    def __str__(self) -> str:
        r"""Overridden version of the string function.

        Returns:
            str: Modified string to represent the function calling.
        """
        return (
            f"Function Execution: {self.func_name}\n"
            f"\tArgs: {self.args}\n"
            f"\tResult: {self.result}\n"
        )

    def as_dict(self) -> dict[str, Any]:
        r"""Returns the function calling record as a dictionary.

        Returns:
            dict[str, Any]: The function calling record as a dictionary.
        """
        return self.model_dump()

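
# Illustrative usage only (not part of chat_agent.py): one record is built per
# executed tool call; the names and values below are made up.
def _example_function_calling_record() -> dict:
    record = FunctionCallingRecord(
        func_name="search",             # hypothetical tool name
        args={"query": "camel agents"},
        result=["doc-1", "doc-2"],
        tool_call_id="call_0",          # hypothetical tool-call id
    )
    print(record)  # "Function Execution: search\n\tArgs: ...\n\tResult: ..."
    return record.as_dict()  # JSON-serializable dict via pydantic model_dump
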
@track_agent(name="ChatAgent")
|
| 127 |
+
class ChatAgent(BaseAgent):
|
| 128 |
+
r"""Class for managing conversations of CAMEL Chat Agents.
|
| 129 |
+
|
| 130 |
+
Args:
|
| 131 |
+
system_message (Union[BaseMessage, str], optional): The system message
|
| 132 |
+
for the chat agent.
|
| 133 |
+
model (BaseModelBackend, optional): The model backend to use for
|
| 134 |
+
generating responses. (default: :obj:`ModelPlatformType.DEFAULT`
|
| 135 |
+
with `ModelType.DEFAULT`)
|
| 136 |
+
memory (AgentMemory, optional): The agent memory for managing chat
|
| 137 |
+
messages. If `None`, a :obj:`ChatHistoryMemory` will be used.
|
| 138 |
+
(default: :obj:`None`)
|
| 139 |
+
message_window_size (int, optional): The maximum number of previous
|
| 140 |
+
messages to include in the context window. If `None`, no windowing
|
| 141 |
+
is performed. (default: :obj:`None`)
|
| 142 |
+
token_limit (int, optional): The maximum number of tokens in a context.
|
| 143 |
+
The context will be automatically pruned to fulfill the limitation.
|
| 144 |
+
If `None`, it will be set according to the backend model.
|
| 145 |
+
(default: :obj:`None`)
|
| 146 |
+
output_language (str, optional): The language to be output by the
|
| 147 |
+
agent. (default: :obj:`None`)
|
| 148 |
+
tools (Optional[List[Union[FunctionTool, Callable]]], optional): List
|
| 149 |
+
of available :obj:`FunctionTool` or :obj:`Callable`. (default:
|
| 150 |
+
:obj:`None`)
|
| 151 |
+
external_tools (Optional[List[Union[FunctionTool, Callable]]],
|
| 152 |
+
optional): List of external tools (:obj:`FunctionTool` or or
|
| 153 |
+
:obj:`Callable`) bind to one chat agent. When these tools are
|
| 154 |
+
called, the agent will directly return the request instead of
|
| 155 |
+
processing it. (default: :obj:`None`)
|
| 156 |
+
response_terminators (List[ResponseTerminator], optional): List of
|
| 157 |
+
:obj:`ResponseTerminator` bind to one chat agent.
|
| 158 |
+
(default: :obj:`None`)
|
| 159 |
+
scheduling_strategy (str): name of function that defines how to select
|
| 160 |
+
the next model in ModelManager. (default: :str:`round_robin`)
|
| 161 |
+
single_iteration (bool): Whether to let the agent perform only one
|
| 162 |
+
model calling at each step. (default: :obj:`False`)
|
| 163 |
+
"""
|
| 164 |
+
|
| 165 |
+

    def __init__(
        self,
        system_message: Optional[Union[BaseMessage, str]] = None,
        model: Optional[
            Union[BaseModelBackend, List[BaseModelBackend]]
        ] = None,
        memory: Optional[AgentMemory] = None,
        message_window_size: Optional[int] = None,
        token_limit: Optional[int] = None,
        output_language: Optional[str] = None,
        tools: Optional[List[Union[FunctionTool, Callable]]] = None,
        external_tools: Optional[List[Union[FunctionTool, Callable]]] = None,
        response_terminators: Optional[List[ResponseTerminator]] = None,
        scheduling_strategy: str = "round_robin",
        single_iteration: bool = False,
    ) -> None:
        # Initialize the system message, converting string to BaseMessage
        # if needed
        if isinstance(system_message, str):
            system_message = BaseMessage.make_assistant_message(
                role_name='Assistant', content=system_message
            )

        self.orig_sys_message: Optional[BaseMessage] = system_message
        self._system_message: Optional[BaseMessage] = system_message
        self.role_name: str = (
            getattr(system_message, 'role_name', None) or "assistant"
        )
        self.role_type: RoleType = (
            getattr(system_message, 'role_type', None) or RoleType.ASSISTANT
        )
        self.model_backend = ModelManager(
            model
            if model is not None
            else ModelFactory.create(
                model_platform=ModelPlatformType.DEFAULT,
                model_type=ModelType.DEFAULT,
            ),
            scheduling_strategy=scheduling_strategy,
        )
        self.model_type = self.model_backend.model_type

        # Initialize tools
        self.tools: List[FunctionTool] = (
            self._initialize_tools(tools) if tools else []
        )
        self.external_tools: List[FunctionTool] = (
            self._initialize_tools(external_tools) if external_tools else []
        )
        self.external_tool_names: List[str] = [
            tool.get_function_name() for tool in self.external_tools
        ]
        self.all_tools = self.tools + self.external_tools or []

        # Create tool dictionaries and configure backend tools if necessary
        self.tool_dict = {
            tool.get_function_name(): tool for tool in self.all_tools
        }

        # If the user set tools from `ChatAgent`, it will override the
        # configured tools in `BaseModelBackend`.
        if self.all_tools:
            logger.warning(
                "Overriding the configured tools in `BaseModelBackend` with the tools from `ChatAgent`."
            )
            tool_schema_list = [
                tool.get_openai_tool_schema() for tool in self.all_tools
            ]
            self.model_backend.model_config_dict['tools'] = tool_schema_list

        self.model_token_limit = token_limit or self.model_backend.token_limit
        context_creator = ScoreBasedContextCreator(
            self.model_backend.token_counter,
            self.model_token_limit,
        )
        self.memory: AgentMemory = memory or ChatHistoryMemory(
            context_creator, window_size=message_window_size
        )

        self.output_language: Optional[str] = output_language
        if self.output_language is not None:
            self.set_output_language(self.output_language)

        self.terminated: bool = False
        self.response_terminators = response_terminators or []
        self.init_messages()
        self.tool_prompt_added = False
        self.single_iteration = single_iteration

    def _initialize_tools(
        self, tools: List[Union[FunctionTool, Callable]]
    ) -> List[FunctionTool]:
        r"""Helper method to initialize tools as FunctionTool instances."""
        from camel.toolkits import FunctionTool

        func_tools = []
        for tool in tools:
            if not isinstance(tool, FunctionTool):
                tool = FunctionTool(tool)
            func_tools.append(tool)
        return func_tools

    def add_tool(
        self, tool: Union[FunctionTool, Callable], is_external: bool = False
    ) -> None:
        r"""Add a tool to the agent, specifying if it's an external tool."""
        # Initialize the tool
        initialized_tool = self._initialize_tools([tool])

        # Update tools or external tools based on the is_external flag
        if is_external:
            self.external_tools = self.external_tools + initialized_tool
            self.external_tool_names.extend(
                tool.get_function_name() for tool in initialized_tool
            )
        else:
            self.tools = self.tools + initialized_tool

        # Rebuild all_tools and tool_dict
        self.all_tools = self.tools + self.external_tools
        self.tool_dict = {
            tool.get_function_name(): tool for tool in self.all_tools
        }

        tool_schema_list = [
            tool.get_openai_tool_schema() for tool in self.all_tools
        ]
        self.model_backend.model_config_dict['tools'] = tool_schema_list

    def remove_tool(self, tool_name: str, is_external: bool = False) -> bool:
        r"""Remove a tool by name, specifying if it's an external tool."""
        tool_list = self.external_tools if is_external else self.tools
        if not tool_list:
            return False

        for tool in tool_list:
            if tool.get_function_name() == tool_name:
                tool_list.remove(tool)
                if is_external:
                    self.external_tool_names.remove(tool_name)
                # Reinitialize the tool dictionary
                self.all_tools = (self.tools or []) + (
                    self.external_tools or []
                )
                self.tool_dict = {
                    tool.get_function_name(): tool for tool in self.all_tools
                }
                tool_schema_list = [
                    tool.get_openai_tool_schema() for tool in self.all_tools
                ]
                self.model_backend.model_config_dict['tools'] = (
                    tool_schema_list
                )
                return True
        return False

    def list_tools(self) -> dict:
        r"""List all tools, separated into normal and external tools."""
        normal_tools = [
            tool.get_function_name() for tool in (self.tools or [])
        ]
        external_tools = [
            tool.get_function_name() for tool in (self.external_tools or [])
        ]

        return {"normal_tools": normal_tools, "external_tools": external_tools}

    # ruff: noqa: E501
    def _generate_tool_prompt(self, tool_schema_list: List[Dict]) -> str:
        r"""Generates a tool prompt based on the provided tool schema list.

        Args:
            tool_schema_list (List[Dict]): A list of dictionaries, each
                containing a tool schema.

        Returns:
            str: A string representing the tool prompt.
        """
        tool_prompts = []

        for tool in tool_schema_list:
            tool_info = tool['function']
            tool_name = tool_info['name']
            tool_description = tool_info['description']
            tool_json = json.dumps(tool_info, indent=4)

            prompt = f"Use the function '{tool_name}' to '{tool_description}':\n{tool_json}\n"
            tool_prompts.append(prompt)

        tool_prompt_str = "\n".join(tool_prompts)

        final_prompt = f"""
You have access to the following functions:

{tool_prompt_str}

If you choose to call a function ONLY reply in the following format with no
prefix or suffix:

<function=example_function_name>{{"example_name": "example_value"}}</function>

Reminder:
- Function calls MUST follow the specified format, start with <function= and end with </function>
- Required parameters MUST be specified
- Only call one function at a time
- Put the entire function call reply on one line
- If there is no function call available, answer the question like normal
with your current knowledge and do not tell the user about function calls
"""
        return final_prompt

    def _parse_tool_response(self, response: str):
        r"""Parses the tool response to extract the function name and
        arguments.

        Args:
            response (str): The response from the model containing the
                function call.

        Returns:
            Optional[Dict[str, Any]]: The parsed function name and arguments
                if found, otherwise :obj:`None`.
        """
        function_regex = r"<function=(\w+)>(.*?)</function>"
        match = re.search(function_regex, response)

        if match:
            function_name, args_string = match.groups()
            try:
                args = json.loads(args_string)
                return {"function": function_name, "arguments": args}
            except json.JSONDecodeError as error:
                logger.error(f"Error parsing function arguments: {error}")
                return None
        return None

    def reset(self):
        r"""Resets the :obj:`ChatAgent` to its initial state."""
        self.terminated = False
        self.init_messages()
        for terminator in self.response_terminators:
            terminator.reset()

    @property
    def system_message(self) -> Optional[BaseMessage]:
        r"""The getter method for the property :obj:`system_message`.

        Returns:
            Optional[BaseMessage]: The system message of this agent if set,
                else :obj:`None`.
        """
        return self._system_message

    @system_message.setter
    def system_message(self, message: BaseMessage) -> None:
        r"""The setter method for the property :obj:`system_message`.

        Args:
            message (BaseMessage): The message to be set as the
                new system message of this agent.
        """
        self._system_message = message

    def is_tools_added(self) -> bool:
        r"""Whether tool calling is enabled for this agent.

        Returns:
            bool: Whether tool calling is enabled for this agent, determined
                by whether the dictionary of tools is empty.
        """
        return len(self.tool_dict) > 0

    def update_memory(
        self, message: BaseMessage, role: OpenAIBackendRole
    ) -> None:
        r"""Updates the agent memory with a new message.

        Args:
            message (BaseMessage): The new message to add to the stored
                messages.
            role (OpenAIBackendRole): The backend role type.
        """
        self.memory.write_record(
            MemoryRecord(message=message, role_at_backend=role)
        )

    def set_output_language(self, output_language: str) -> BaseMessage:
        r"""Sets the output language for the system message. This method
        updates the output language for the system message. The output
        language determines the language in which the output text should be
        generated.

        Args:
            output_language (str): The desired output language.

        Returns:
            BaseMessage: The updated system message object.
        """
        self.output_language = output_language
        language_prompt = (
            "\nRegardless of the input language, "
            f"you must output text in {output_language}."
        )
        if self.orig_sys_message is not None:
            content = self.orig_sys_message.content + language_prompt
            self._system_message = self.orig_sys_message.create_new_instance(
                content
            )
        else:
            self._system_message = BaseMessage.make_assistant_message(
                role_name="Assistant",
                content=language_prompt,
            )

        system_record = MemoryRecord(
            message=self._system_message,
            role_at_backend=OpenAIBackendRole.SYSTEM,
        )
        self.memory.clear()
        self.memory.write_record(system_record)
        return self._system_message

    def get_info(
        self,
        session_id: Optional[str],
        usage: Optional[Dict[str, int]],
        termination_reasons: List[str],
        num_tokens: int,
        tool_calls: List[FunctionCallingRecord],
        external_tool_request: Optional[ChatCompletionMessageToolCall] = None,
    ) -> Dict[str, Any]:
        r"""Returns a dictionary containing information about the chat session.

        Args:
            session_id (str, optional): The ID of the chat session.
            usage (Dict[str, int], optional): Information about the usage of
                the LLM.
            termination_reasons (List[str]): The reasons for the termination
                of the chat session.
            num_tokens (int): The number of tokens used in the chat session.
            tool_calls (List[FunctionCallingRecord]): The list of function
                calling records, containing the information of called tools.
            external_tool_request
                (Optional[ChatCompletionMessageToolCall], optional):
                The tool calling request of external tools from the model.
                These requests are directly returned to the user instead of
                being processed by the agent automatically.
                (default: :obj:`None`)

        Returns:
            Dict[str, Any]: The chat session information.
        """
        return {
            "id": session_id,
            "usage": usage,
            "termination_reasons": termination_reasons,
            "num_tokens": num_tokens,
            "tool_calls": tool_calls,
            "external_tool_request": external_tool_request,
        }

    def init_messages(self) -> None:
        r"""Initializes the stored messages list with the current system
        message.
        """
        if self._system_message is not None:
            system_record = MemoryRecord(
                message=self._system_message,
                role_at_backend=OpenAIBackendRole.SYSTEM,
            )
            self.memory.clear()
            self.memory.write_record(system_record)
        else:
            self.memory.clear()

    def record_message(self, message: BaseMessage) -> None:
        r"""Records the externally provided message into the agent memory as if
        it were an answer of the :obj:`ChatAgent` from the backend. Currently,
        the choice of the critic is submitted with this method.

        Args:
            message (BaseMessage): An external message to be recorded in the
                memory.
        """
        self.update_memory(message, OpenAIBackendRole.ASSISTANT)

    def step(
        self,
        input_message: Union[BaseMessage, str],
        response_format: Optional[Type[BaseModel]] = None,
    ) -> ChatAgentResponse:
        r"""Executes a single step in the chat session, generating a response
        to the input message.

        Args:
            input_message (Union[BaseMessage, str]): The input message for the
                agent. If provided as a BaseMessage, the `role` is adjusted to
                `user` to indicate an external message.
            response_format (Optional[Type[BaseModel]], optional): A Pydantic
                model defining the expected structure of the response. Used to
                generate a structured response if provided. (default:
                :obj:`None`)

        Returns:
            ChatAgentResponse: Contains output messages, a termination status
                flag, and session information.
        """

        if (
            self.model_backend.model_config_dict.get("response_format")
            and response_format
        ):
            raise ValueError(
                "The `response_format` parameter cannot be set both in "
                "the model configuration and in the ChatAgent step."
            )

        self.original_model_dict = self.model_backend.model_config_dict
        model_response_format_modified = False
        if (
            response_format
            and self.model_type.support_native_structured_output
        ):
            self.model_backend.model_config_dict = (
                self.original_model_dict.copy()
            )
            self.model_backend.model_config_dict["response_format"] = (
                response_format
            )
            model_response_format_modified = True

        # Convert input message to BaseMessage if necessary
        if isinstance(input_message, str):
            input_message = BaseMessage.make_user_message(
                role_name='User', content=input_message
            )

        # Handle tool prompt injection if needed
        if (
            self.is_tools_added()
            and not self.model_type.support_native_tool_calling
            and not self.tool_prompt_added
        ):
            self._inject_tool_prompt()

        # Add user input to memory
        self.update_memory(input_message, OpenAIBackendRole.USER)

        try:
            return self._handle_step(response_format, self.single_iteration)
        finally:
            if model_response_format_modified:
                # Reset model config back to original state
                self.model_backend.model_config_dict = self.original_model_dict

    def _inject_tool_prompt(self) -> None:
        r"""Generate and add the tool prompt to memory."""
        tool_prompt = self._generate_tool_prompt(
            self.model_backend.model_config_dict["tools"]
        )
        tool_msg = BaseMessage.make_assistant_message(
            role_name="Assistant", content=tool_prompt
        )
        self.update_memory(tool_msg, OpenAIBackendRole.SYSTEM)
        self.tool_prompt_added = True

    def _handle_step(
        self,
        response_format: Optional[Type[BaseModel]],
        single_step: bool,
    ) -> ChatAgentResponse:
        r"""Handles a single- or multi-step interaction."""

        if (
            self.model_backend.model_config_dict.get("tool_choice")
            == "required"
            and not single_step
        ):
            raise ValueError(
                "`tool_choice` cannot be set to `required` for multi-step"
                " mode. To proceed, set `single_iteration` to `True`."
            )

        # Record function calls made during the session
        tool_call_records: List[FunctionCallingRecord] = []

        external_tool_request = None

        while True:
            try:
                openai_messages, num_tokens = self.memory.get_context()
            except RuntimeError as e:
                self.model_backend.model_config_dict = self.original_model_dict
                return self._step_token_exceed(
                    e.args[1], tool_call_records, "max_tokens_exceeded"
                )

            # Prompt-engineering approach to structured output for models
            # without native tool calling
            inject_prompt_for_structured_output = (
                response_format
                and not self.model_type.support_native_structured_output
            )

            if inject_prompt_for_structured_output:
                # Update the last OpenAI message
                usr_msg = openai_messages.pop()
                usr_msg["content"] = generate_prompt_for_structured_output(
                    response_format,
                    usr_msg["content"],  # type: ignore [arg-type]
                )
                openai_messages.append(usr_msg)

            # Process model response
            (
                response,
                output_messages,
                finish_reasons,
                usage_dict,
                response_id,
            ) = self._step_model_response(openai_messages, num_tokens)

            # Try to parse structured output to return a Pydantic object
            if inject_prompt_for_structured_output and isinstance(
                response, ChatCompletion
            ):
                content = response.choices[0].message.content
                try:
                    json_content = json.loads(str(content))
                    output_messages[0].parsed = response_format(**json_content)  # type: ignore [assignment, misc]
                except json.JSONDecodeError as e:
                    logger.error(
                        f"Failed to parse the output into JSON: {e}"
                    )
                    output_messages[0].parsed = None
                except ValidationError as e:
                    logger.warning(
                        "Successfully generated a JSON response, "
                        "but failed to parse it into a Pydantic object: "
                        f"{e}; returning the JSON response in the parsed field"
                    )
                    output_messages[0].parsed = json_content

            # Finalize on a standard response in multi-step mode
            if self._is_standard_response(response):
                break

            # Handle tool requests
            tool_request = self._extract_tool_call(response)
            if isinstance(response, ChatCompletion) and tool_request:
                response.choices[0].message.tool_calls = [tool_request]
                tool_call_records.append(
                    self._step_tool_call_and_update(response)
                )

                if tool_request.function.name in self.external_tool_names:
                    external_tool_request = tool_request
                    info = self._step_get_info(
                        output_messages,
                        finish_reasons,
                        usage_dict,
                        response_id,
                        tool_call_records,
                        num_tokens,
                        tool_request,
                    )
                    self._log_final_output(output_messages)
                    self.model_backend.model_config_dict = (
                        self.original_model_dict
                    )
                    return ChatAgentResponse(
                        msgs=output_messages,
                        terminated=self.terminated,
                        info=info,
                    )

            # Single-step mode ends after one iteration
            if single_step:
                break

        # Optional structured output via function calling
        if (
            response_format
            and not inject_prompt_for_structured_output
            and self.model_type
            not in {
                "gpt-4o",
                "gpt-4o-mini",
            }
        ):
            (
                output_messages,
                finish_reasons,
                usage_dict,
                response_id,
                tool_call,
                num_tokens,
            ) = self._structure_output_with_function(response_format)
            tool_call_records.append(tool_call)

        # Final info and response
        info = self._step_get_info(
            output_messages,
            finish_reasons,
            usage_dict,
            response_id,
            tool_call_records,
            num_tokens,
            external_tool_request,
        )
        self._log_final_output(output_messages)
        self.model_backend.model_config_dict = self.original_model_dict
        return ChatAgentResponse(
            msgs=output_messages, terminated=self.terminated, info=info
        )

    def _extract_tool_call(
        self, response: Any
    ) -> Optional[ChatCompletionMessageToolCall]:
        r"""Extract the tool call from the model response, if present.

        Args:
            response (Any): The model's response object.

        Returns:
            Optional[ChatCompletionMessageToolCall]: The parsed tool call if
                present, otherwise None.
        """
        # Check if the response contains tool calls
        if (
            self.is_tools_added()
            and not self.model_type.support_native_tool_calling
            and "</function>" in response.choices[0].message.content
        ):
            parsed_content = self._parse_tool_response(
                response.choices[0].message.content
            )
            if parsed_content:
                return ChatCompletionMessageToolCall(
                    id=str(uuid.uuid4()),
                    function=Function(
                        arguments=str(parsed_content["arguments"]).replace(
                            "'", '"'
                        ),
                        name=str(parsed_content["function"]),
                    ),
                    type="function",
                )
        elif (
            self.is_tools_added()
            and self.model_type.support_native_tool_calling
            and response.choices[0].message.tool_calls
        ):
            return response.choices[0].message.tool_calls[0]

        # No tool call found
        return None

    def _is_standard_response(self, response: Any) -> bool:
        r"""Determine if the provided response is a standard reply without
        tool calls.

        Args:
            response (Any): The response object to evaluate.

        Returns:
            bool: `True` if the response is a standard reply, `False`
                otherwise.
        """
        if not self.is_tools_added():
            return True

        if not isinstance(response, ChatCompletion):
            return True

        if self.model_type.support_native_tool_calling:
            return not response.choices[0].message.tool_calls

        return "</function>" not in str(
            response.choices[0].message.content or ""
        )

    def _log_final_output(self, output_messages: List[BaseMessage]) -> None:
        r"""Log final messages or warnings about multiple responses."""
        if len(output_messages) == 1:
            self.record_message(output_messages[0])
        else:
            logger.warning(
                "Multiple messages returned in `step()`. Record "
                "selected message manually using `record_message()`."
            )

    async def step_async(
        self,
        input_message: Union[BaseMessage, str],
        response_format: Optional[Type[BaseModel]] = None,
    ) -> ChatAgentResponse:
        r"""Performs a single step in the chat session by generating a response
        to the input message. This agent step can call async function calls.

        Args:
            input_message (Union[BaseMessage, str]): The input message to the
                agent. For BaseMessage input, the `role` field that specifies
                the role at the backend may be either `user` or `assistant`,
                but it will be set to `user` anyway, since any incoming
                message is external to the agent itself. For str input, the
                `role_name` would be `User`.
            response_format (Optional[Type[BaseModel]], optional): A pydantic
                model class that includes value types and field descriptions
                used to generate a structured response by the LLM. This schema
                helps in defining the expected output format. (default:
                :obj:`None`)

        Returns:
            ChatAgentResponse: A struct containing the output messages,
                a boolean indicating whether the chat session has terminated,
                and information about the chat session.
        """
        if isinstance(input_message, str):
            input_message = BaseMessage.make_user_message(
                role_name='User', content=input_message
            )

        self.update_memory(input_message, OpenAIBackendRole.USER)

        tool_call_records: List[FunctionCallingRecord] = []
        while True:
            try:
                openai_messages, num_tokens = self.memory.get_context()
            except RuntimeError as e:
                return self._step_token_exceed(
                    e.args[1], tool_call_records, "max_tokens_exceeded"
                )

            (
                response,
                output_messages,
                finish_reasons,
                usage_dict,
                response_id,
            ) = self._step_model_response(openai_messages, num_tokens)

            if (
                not self.is_tools_added()
                or not isinstance(response, ChatCompletion)
                or not response.choices[0].message.tool_calls
            ):
                break

            # Check for an external tool call
            external_tool_request = response.choices[0].message.tool_calls[0]
            if external_tool_request.function.name in self.external_tool_names:
                # If the model calls an external tool, directly return the
                # request
                info = self._step_get_info(
                    output_messages,
                    finish_reasons,
                    usage_dict,
                    response_id,
                    tool_call_records,
                    num_tokens,
                    external_tool_request,
                )
                return ChatAgentResponse(
                    msgs=output_messages, terminated=self.terminated, info=info
                )

            # Normal function calling
            tool_call_records.append(
                await self._step_tool_call_and_update_async(response)
            )

        if (
            response_format is not None
            and self.model_type.support_native_tool_calling
        ):
            (
                output_messages,
                finish_reasons,
                usage_dict,
                response_id,
                tool_call_record,
                num_tokens,
            ) = self._structure_output_with_function(response_format)
            tool_call_records.append(tool_call_record)

        info = self._step_get_info(
            output_messages,
            finish_reasons,
            usage_dict,
            response_id,
            tool_call_records,
            num_tokens,
        )

        if len(output_messages) == 1:
            # Auto-record if the output result is a single message
            self.record_message(output_messages[0])
        else:
            logger.warning(
                "Multiple messages returned in `step()`, message won't be "
                "recorded automatically. Please call `record_message()` to "
                "record the selected message manually."
            )

        return ChatAgentResponse(
            msgs=output_messages, terminated=self.terminated, info=info
        )

    def _step_tool_call_and_update(
        self, response: ChatCompletion
    ) -> FunctionCallingRecord:
        r"""Processes a function call within the chat completion response,
        records the function call in the provided list of tool calls and
        updates the memory of the current agent.

        Args:
            response (ChatCompletion): The response object from the chat
                completion.

        Returns:
            FunctionCallingRecord: The record of calling the function.
        """

        # Perform function calling
        func_assistant_msg, func_result_msg, tool_call_record = (
            self._step_tool_call(response)
        )

        # Update the messages
        self.update_memory(func_assistant_msg, OpenAIBackendRole.ASSISTANT)
        self.update_memory(func_result_msg, OpenAIBackendRole.FUNCTION)

        return tool_call_record

    async def _step_tool_call_and_update_async(
        self, response: ChatCompletion
    ) -> FunctionCallingRecord:
        r"""Async version of :obj:`_step_tool_call_and_update`."""
        (
            func_assistant_msg,
            func_result_msg,
            func_record,
        ) = await self.step_tool_call_async(response)

        self.update_memory(func_assistant_msg, OpenAIBackendRole.ASSISTANT)
        self.update_memory(func_result_msg, OpenAIBackendRole.FUNCTION)

        return func_record

    def _structure_output_with_function(
        self, response_format: Type[BaseModel]
    ) -> Tuple[
        List[BaseMessage],
        List[str],
        Dict[str, int],
        str,
        FunctionCallingRecord,
        int,
    ]:
        r"""Internal function for structuring the output of the agent based on
        the given output schema.

        Args:
            response_format (Type[BaseModel]): The output schema to use for
                structuring the output.

        Returns:
            Tuple[List[BaseMessage], List[str], Dict[str, int], str,
                FunctionCallingRecord, int]:
                A tuple containing the output messages, finish reasons, usage
                dictionary, response ID, function calling record, and number
                of tokens.
        """
        from camel.toolkits import FunctionTool

        schema_json = get_pydantic_object_schema(response_format)
        func_str = json_to_function_code(schema_json)
        func_callable = func_string_to_callable(func_str)
        func = FunctionTool(func_callable)

        original_model_dict = self.model_backend.model_config_dict

        # Replace the original tools with the structuring function
        self.tool_dict = {func.get_function_name(): func}
        self.model_backend.model_config_dict = original_model_dict.copy()
        self.model_backend.model_config_dict["tools"] = [
            func.get_openai_tool_schema()
        ]
        self.model_backend.model_config_dict["tool_choice"] = "required"

        openai_messages, num_tokens = self.memory.get_context()
        (
            response,
            output_messages,
            finish_reasons,
            usage_dict,
            response_id,
        ) = self._step_model_response(openai_messages, num_tokens)

        if isinstance(response, ChatCompletion):
            tool_call_record = self._step_tool_call_and_update(response)
        else:
            raise ValueError(
                "Structured output is not supported for stream responses."
            )

        for base_message_item in output_messages:
            base_message_item.content = json.dumps(tool_call_record.result)

        # Recover the original tools
        self.model_backend.model_config_dict = original_model_dict

        return (
            output_messages,
            finish_reasons,
            usage_dict,
            response_id,
            tool_call_record,
            num_tokens,
        )

    def _step_model_response(
        self,
        openai_messages: List[OpenAIMessage],
        num_tokens: int,
    ) -> tuple[
        Union[ChatCompletion, Stream],
        List[BaseMessage],
        List[str],
        Dict[str, int],
        str,
    ]:
        r"""Internal function for the agent step model response."""

        response = None
        # Obtain the model's response
        for _ in range(len(self.model_backend.models)):
            try:
                response = self.model_backend.run(openai_messages)
                break
            except Exception as exc:
                logger.error(
                    f"An error occurred while running model "
                    f"{self.model_backend.model_type}, "
                    f"index: {self.model_backend.current_model_index}",
                    exc_info=exc,
                )
                continue
        if not response:
            raise ModelProcessingError(
                "Unable to process messages: none of the provided models "
                "ran successfully."
            )

        logger.info(
            f"Model {self.model_backend.model_type}, "
            f"index {self.model_backend.current_model_index}, "
            f"processed these messages: {openai_messages}"
        )

        if isinstance(response, ChatCompletion):
            output_messages, finish_reasons, usage_dict, response_id = (
                self.handle_batch_response(response)
            )
        else:
            output_messages, finish_reasons, usage_dict, response_id = (
                self.handle_stream_response(response, num_tokens)
            )
        return (
            response,
            output_messages,
            finish_reasons,
            usage_dict,
            response_id,
        )

    def _step_get_info(
        self,
        output_messages: List[BaseMessage],
        finish_reasons: List[str],
        usage_dict: Dict[str, int],
        response_id: str,
        tool_calls: List[FunctionCallingRecord],
        num_tokens: int,
        external_tool_request: Optional[ChatCompletionMessageToolCall] = None,
    ) -> Dict[str, Any]:
        r"""Process the output of a chat step and gather information about the
        step.

        This method checks for termination conditions, updates the agent's
        state, and collects information about the chat step, including tool
        calls and termination reasons.

        Args:
            output_messages (List[BaseMessage]): The messages generated in
                this step.
            finish_reasons (List[str]): The reasons for finishing the
                generation for each message.
            usage_dict (Dict[str, int]): Dictionary containing token usage
                information.
            response_id (str): The ID of the response from the model.
            tool_calls (List[FunctionCallingRecord]): Records of function calls
                made during this step.
            num_tokens (int): The number of tokens used in this step.
            external_tool_request (Optional[ChatCompletionMessageToolCall]):
                Any external tool request made during this step.
                (default: :obj:`None`)

        Returns:
            Dict[str, Any]: A dictionary containing information about the chat
                step, including termination status, reasons, and tool call
                information.

        Note:
            This method iterates over all response terminators and checks if
            any of them signal termination. If a terminator signals
            termination, the agent's state is updated accordingly, and the
            termination reason is recorded.
        """
        termination = [
            terminator.is_terminated(output_messages)
            for terminator in self.response_terminators
        ]
        # Terminate the agent if any of the terminators signals termination
        self.terminated, termination_reason = next(
            (
                (terminated, termination_reason)
                for terminated, termination_reason in termination
                if terminated
            ),
            (False, None),
        )
        # For now, only retain the first termination reason
        if self.terminated and termination_reason is not None:
            finish_reasons = [termination_reason] * len(finish_reasons)

        info = self.get_info(
            response_id,
            usage_dict,
            finish_reasons,
            num_tokens,
            tool_calls,
            external_tool_request,
        )
        return info

    def handle_batch_response(
        self, response: ChatCompletion
    ) -> Tuple[List[BaseMessage], List[str], Dict[str, int], str]:
        r"""Process a batch response from the model and extract the necessary
        information.

        Args:
            response (ChatCompletion): Model response.

        Returns:
            tuple: A tuple of the list of output `ChatMessage`, the list of
                finish reasons, the usage dictionary, and the response ID.
        """
        output_messages: List[BaseMessage] = []
        for choice in response.choices:
            chat_message = BaseMessage(
                role_name=self.role_name,
                role_type=self.role_type,
                meta_dict=dict(),
                content=choice.message.content or "",
                parsed=getattr(choice.message, 'parsed', None),
            )
            # Process log probabilities and append them to the message meta
            # information
            if choice.logprobs is not None:
                tokens_logprobs = choice.logprobs.content

                if tokens_logprobs is not None:
                    # Extract and structure logprob information
                    logprobs_info = [
                        {
                            "token": token_logprob.token,
                            "logprob": token_logprob.logprob,
                            "top_logprobs": [
                                (top_logprob.token, top_logprob.logprob)
                                for top_logprob in token_logprob.top_logprobs
                            ],
                        }
                        for token_logprob in tokens_logprobs
                    ]
                    # Ensure meta_dict exists before adding logprobs info
                    if chat_message.meta_dict is None:
                        chat_message.meta_dict = {}
                    chat_message.meta_dict["logprobs_info"] = logprobs_info
            # Append the processed chat message to the output
            output_messages.append(chat_message)

        finish_reasons = [
            str(choice.finish_reason) for choice in response.choices
        ]
        usage = (
            self._safe_model_dump(response.usage)
            if response.usage is not None
            else {}
        )
        return (
            output_messages,
            finish_reasons,
            usage,
            response.id,
        )

    def _safe_model_dump(self, obj) -> dict:
        r"""Safely dump a Pydantic model to a dictionary.

        This method attempts to use the `model_dump` method if available,
        otherwise it falls back to the `dict` method.

        Args:
            obj: The Pydantic model instance to be dumped.

        Returns:
            dict: A dictionary representation of the Pydantic model.
        """
        # Check if the `model_dump` method exists (Pydantic v2)
        if hasattr(obj, 'model_dump'):
            return obj.model_dump()
        # Fall back to the `dict()` method (Pydantic v1)
        elif hasattr(obj, 'dict'):
            return obj.dict()
        else:
            raise TypeError("The object is not a Pydantic model")

    def handle_stream_response(
        self,
        response: Stream[ChatCompletionChunk],
        prompt_tokens: int,
    ) -> Tuple[List[BaseMessage], List[str], Dict[str, int], str]:
        r"""Process a stream response from the model and extract the necessary
        information.

        Args:
            response (Stream[ChatCompletionChunk]): Model response.
            prompt_tokens (int): Number of input prompt tokens.

        Returns:
            tuple: A tuple of the list of output `ChatMessage`, the list of
                finish reasons, the usage dictionary, and the response ID.
        """
        content_dict: defaultdict = defaultdict(lambda: "")
        finish_reasons_dict: defaultdict = defaultdict(lambda: "")
        output_messages: List[BaseMessage] = []
        response_id: str = ""
        # All choices in one response share one role
        for chunk in response:
            response_id = chunk.id
            for choice in chunk.choices:
                index = choice.index
                delta = choice.delta
                if delta.content is not None:
                    # When the response has not been stopped.
                    # Notice that only the first chunk_dict has the "role".
                    content_dict[index] += delta.content
                if choice.finish_reason:
                    finish_reasons_dict[index] = choice.finish_reason
                    chat_message = BaseMessage(
                        role_name=self.role_name,
                        role_type=self.role_type,
                        meta_dict=dict(),
                        content=content_dict[index],
                    )
                    output_messages.append(chat_message)
        finish_reasons = [
            finish_reasons_dict[i] for i in range(len(finish_reasons_dict))
        ]
        usage_dict = self.get_usage_dict(output_messages, prompt_tokens)
        return output_messages, finish_reasons, usage_dict, response_id

    def _step_token_exceed(
        self,
        num_tokens: int,
        tool_calls: List[FunctionCallingRecord],
        termination_reason: str,
    ) -> ChatAgentResponse:
        r"""Return a trivial response containing the number of tokens and
        information about called functions when the number of tokens exceeds
        the limit.

        Args:
            num_tokens (int): Number of tokens in the messages.
            tool_calls (List[FunctionCallingRecord]): List of information
                objects of functions called in the current step.
            termination_reason (str): String of the termination reason.

        Returns:
            ChatAgentResponse: The struct containing trivial outputs and
                information about the token number and called functions.
        """
        self.terminated = True
        output_messages: List[BaseMessage] = []

        info = self.get_info(
            None,
            None,
            [termination_reason],
            num_tokens,
            tool_calls,
        )

        return ChatAgentResponse(
            msgs=output_messages,
            terminated=self.terminated,
            info=info,
        )

    def _step_tool_call(
        self,
        response: ChatCompletion,
    ) -> Tuple[
        FunctionCallingMessage, FunctionCallingMessage, FunctionCallingRecord
    ]:
        r"""Execute the function with arguments following the model's response.

        Args:
            response (ChatCompletion): The response obtained by calling the
                model.

        Returns:
            tuple: A tuple consisting of two :obj:`FunctionCallingMessage`
                objects, one about the arguments and the other about the
                execution result, and a struct for logging information about
                this function call.
        """
        choice = response.choices[0]
        if choice.message.tool_calls is None:
            raise RuntimeError("Tool call is None")
        func_name = choice.message.tool_calls[0].function.name

        arguments_str = choice.message.tool_calls[0].function.arguments
        args = self._safe_json_loads(arguments_str)

        tool = self.tool_dict[func_name]
        result = tool(**args)
        tool_call_id = choice.message.tool_calls[0].id

        assist_msg = FunctionCallingMessage(
            role_name=self.role_name,
            role_type=self.role_type,
            meta_dict=None,
            content="",
            func_name=func_name,
            args=args,
            tool_call_id=tool_call_id,
        )
        func_msg = FunctionCallingMessage(
            role_name=self.role_name,
            role_type=self.role_type,
            meta_dict=None,
            content="",
            func_name=func_name,
            result=result,
            tool_call_id=tool_call_id,
        )

        # Record information about this function call
        func_record = FunctionCallingRecord(
            func_name=func_name,
            args=args,
            result=result,
            tool_call_id=tool_call_id,
        )
        return assist_msg, func_msg, func_record

    def _safe_json_loads(self, arguments_str):
        # Replace Python literals with their JSON equivalents
        arguments_str = arguments_str.replace("None", "null")
        arguments_str = arguments_str.replace("True", "true")
        arguments_str = arguments_str.replace("False", "false")

        # Attempt to parse the corrected string
        try:
            return json.loads(arguments_str)
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON format: {e}")

    async def step_tool_call_async(
        self,
        response: ChatCompletion,
    ) -> Tuple[
        FunctionCallingMessage, FunctionCallingMessage, FunctionCallingRecord
    ]:
        r"""Execute the async function with arguments following the model's
        response.

        Args:
            response (ChatCompletion): The response obtained by calling the
                model.

        Returns:
            tuple: A tuple consisting of two :obj:`FunctionCallingMessage`
                objects, one about the arguments and the other about the
                execution result, and a struct for logging information about
                this function call.
        """
        # Note that when function calling is enabled, `n` is set to 1.
        choice = response.choices[0]
        if choice.message.tool_calls is None:
            raise RuntimeError("Tool call is None")
        func_name = choice.message.tool_calls[0].function.name

        args = json.loads(choice.message.tool_calls[0].function.arguments)
        tool = self.tool_dict[func_name]
        result = await tool(**args)
        tool_call_id = choice.message.tool_calls[0].id

        assist_msg = FunctionCallingMessage(
            role_name=self.role_name,
            role_type=self.role_type,
            meta_dict=None,
            content="",
            func_name=func_name,
            args=args,
            tool_call_id=tool_call_id,
        )
        func_msg = FunctionCallingMessage(
            role_name=self.role_name,
            role_type=self.role_type,
            meta_dict=None,
            content="",
            func_name=func_name,
            result=result,
            tool_call_id=tool_call_id,
        )

        # Record information about this function call
        func_record = FunctionCallingRecord(
            func_name=func_name,
            args=args,
            result=result,
            tool_call_id=tool_call_id,
        )
        return assist_msg, func_msg, func_record

    def get_usage_dict(
        self, output_messages: List[BaseMessage], prompt_tokens: int
    ) -> Dict[str, int]:
        r"""Get the usage dictionary when using the stream mode.

        Args:
            output_messages (list): List of output messages.
            prompt_tokens (int): Number of input prompt tokens.

        Returns:
            dict: Usage dictionary.
        """
        encoding = get_model_encoding(self.model_type.value_for_tiktoken)
        completion_tokens = 0
        for message in output_messages:
            completion_tokens += len(encoding.encode(message.content))
        usage_dict = dict(
            completion_tokens=completion_tokens,
            prompt_tokens=prompt_tokens,
            total_tokens=completion_tokens + prompt_tokens,
        )
        return usage_dict

    def add_model_scheduling_strategy(self, name: str, strategy_fn: Callable):
        r"""Add a user-provided scheduling strategy method to the
        ModelManager.

        Args:
            name (str): The name of the strategy.
            strategy_fn (Callable): The scheduling strategy function.
        """
        self.model_backend.add_strategy(name, strategy_fn)

    def __repr__(self) -> str:
        r"""Returns a string representation of the :obj:`ChatAgent`.

        Returns:
            str: The string representation of the :obj:`ChatAgent`.
        """
        return (
            f"ChatAgent({self.role_name}, {self.role_type}, {self.model_type})"
        )
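
A minimal usage sketch of the `ChatAgent` API defined above: a plain callable is passed as a tool (it gets wrapped into a `FunctionTool` by `_initialize_tools()`), and `step()` returns a `ChatAgentResponse` whose `info["tool_calls"]` holds the `FunctionCallingRecord` entries. This is a sketch, not part of the commit; it assumes a default model backend is configured (e.g. an OpenAI API key in the environment), and the `add` tool and the prompts are hypothetical.

# usage_sketch.py -- illustrative only, assumes a configured default backend
from camel.agents.chat_agent import ChatAgent

def add(a: int, b: int) -> int:
    r"""Adds two integers and returns the sum."""
    return a + b

# A str system_message is converted to a BaseMessage inside __init__.
agent = ChatAgent(
    system_message="You are a helpful assistant.",
    tools=[add],
)

# step() also accepts a plain string, converted to a user BaseMessage.
response = agent.step("What is 2 + 3? Use the add tool.")
print(response.msgs[0].content)
# Tool invocations are exposed as FunctionCallingRecord entries:
for record in response.info["tool_calls"]:
    print(record)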
camel/agents/critic_agent.py
ADDED
@@ -0,0 +1,202 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import random
import warnings
from typing import Any, Dict, Optional, Sequence

from colorama import Fore

from camel.agents.chat_agent import ChatAgent
from camel.memories import AgentMemory
from camel.messages import BaseMessage
from camel.models import BaseModelBackend
from camel.responses import ChatAgentResponse
from camel.utils import get_first_int, print_text_animated

# AgentOps decorator setting
try:
    import os

    if os.getenv("AGENTOPS_API_KEY") is not None:
        from agentops import track_agent
    else:
        raise ImportError
except (ImportError, AttributeError):
    from camel.utils import track_agent


@track_agent(name="CriticAgent")
class CriticAgent(ChatAgent):
    r"""A class for the critic agent that assists in selecting an option.

    Args:
        system_message (BaseMessage): The system message for the critic
            agent.
        model (BaseModelBackend, optional): The model backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)
        message_window_size (int, optional): The maximum number of previous
            messages to include in the context window. If `None`, no windowing
            is performed. (default: :obj:`6`)
        retry_attempts (int, optional): The number of retry attempts if the
            critic fails to return a valid option. (default: :obj:`2`)
        verbose (bool, optional): Whether to print the critic's messages.
            (default: :obj:`False`)
        logger_color (Any): The color of the menu options displayed to the
            user. (default: :obj:`Fore.MAGENTA`)
    """

    def __init__(
        self,
        system_message: BaseMessage,
        model: Optional[BaseModelBackend] = None,
        memory: Optional[AgentMemory] = None,
        message_window_size: int = 6,
        retry_attempts: int = 2,
        verbose: bool = False,
        logger_color: Any = Fore.MAGENTA,
    ) -> None:
        super().__init__(
            system_message,
            model=model,
            memory=memory,
            message_window_size=message_window_size,
        )
        self.options_dict: Dict[str, str] = dict()
        self.retry_attempts = retry_attempts
        self.verbose = verbose
        self.logger_color = logger_color

    def flatten_options(self, messages: Sequence[BaseMessage]) -> str:
        r"""Flattens the options to the critic.

        Args:
            messages (Sequence[BaseMessage]): A list of `BaseMessage` objects.

        Returns:
            str: A string containing the flattened options to the critic.
        """
        options = [message.content for message in messages]
        flatten_options = (
            f"> Proposals from "
            f"{messages[0].role_name} ({messages[0].role_type}). "
            "Please choose an option:\n"
        )
        for index, option in enumerate(options):
            flatten_options += f"Option {index + 1}:\n{option}\n\n"
            self.options_dict[str(index + 1)] = option
        format = (
            f"Please first enter your choice ([1-{len(self.options_dict)}]) "
            "and then your explanation and comparison: "
        )
        return flatten_options + format

    def get_option(self, input_message: BaseMessage) -> str:
        r"""Gets the option selected by the critic.

        Args:
            input_message (BaseMessage): A `BaseMessage` object representing
                the input message.

        Returns:
            str: The option selected by the critic.
        """
        # TODO: Add support for editing options by the critic.
        msg_content = input_message.content
        i = 0
        while i < self.retry_attempts:
            critic_response = self.step(input_message)

            if critic_response.msgs is None or len(critic_response.msgs) == 0:
                raise RuntimeError("Got None critic messages.")
            if critic_response.terminated:
                raise RuntimeError("Critic step failed.")

            critic_msg = critic_response.msg
            if self.verbose:
                print_text_animated(
                    self.logger_color + "\n> Critic response: "
                    f"\x1b[3m{critic_msg.content}\x1b[0m\n"
                )
            choice = self.parse_critic(critic_msg)

            if choice in self.options_dict:
                return self.options_dict[choice]
            else:
                input_message = BaseMessage(
                    role_name=input_message.role_name,
                    role_type=input_message.role_type,
                    meta_dict=input_message.meta_dict,
                    content="> Invalid choice. Please choose again.\n"
                    + msg_content,
                )
            i += 1
        warnings.warn(
            "Critic failed to get a valid option "
            f"after {self.retry_attempts} attempts. "
            "Returning a random option."
        )
        return random.choice(list(self.options_dict.values()))

    def parse_critic(self, critic_msg: BaseMessage) -> Optional[str]:
        r"""Parses the critic's message and extracts the choice.

        Args:
            critic_msg (BaseMessage): A `BaseMessage` object representing the
                critic's response.

        Returns:
            Optional[str]: The critic's choice as a string, or None if the
                message could not be parsed.
        """
        choice = str(get_first_int(critic_msg.content))
        return choice

    def reduce_step(
        self,
        input_messages: Sequence[BaseMessage],
    ) -> ChatAgentResponse:
        r"""Performs one step of the conversation by flattening options to the
        critic, getting the option, and parsing the choice.

        Args:
            input_messages (Sequence[BaseMessage]): A list of BaseMessage
                objects.

        Returns:
            ChatAgentResponse: A `ChatAgentResponse` object that includes the
                critic's choice.
        """
        meta_chat_message = BaseMessage(
            role_name=input_messages[0].role_name,
            role_type=input_messages[0].role_type,
            meta_dict=input_messages[0].meta_dict,
            content="",
        )

        flatten_options = self.flatten_options(input_messages)
        if self.verbose:
            print_text_animated(
                self.logger_color + f"\x1b[3m{flatten_options}\x1b[0m\n"
            )
        input_msg = meta_chat_message.create_new_instance(flatten_options)

        option = self.get_option(input_msg)
        output_msg = meta_chat_message.create_new_instance(option)

        # TODO: The return `info` can be improved.
        return ChatAgentResponse(
            msgs=[output_msg],
            terminated=False,
            info={},
        )
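A minimal usage sketch for the class above; everything it calls (`reduce_step`, `BaseMessage.make_assistant_message`, `BaseMessage.make_user_message`) appears in this diff, while the default model backend and its API key are assumed from the environment:

# Sketch: have a CriticAgent pick between two proposal messages.
from camel.agents.critic_agent import CriticAgent
from camel.messages import BaseMessage

critic = CriticAgent(
    system_message=BaseMessage.make_assistant_message(
        role_name="Critic", content="You pick the stronger proposal."
    ),
    verbose=True,
)

proposals = [
    BaseMessage.make_user_message(role_name="Coder", content="Use a heap."),
    BaseMessage.make_user_message(role_name="Coder", content="Sort, then scan."),
]
response = critic.reduce_step(proposals)
print(response.msg.content)  # the option text the critic selected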
camel/agents/deductive_reasoner_agent.py
ADDED
@@ -0,0 +1,303 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import re
from typing import Dict, List, Optional, Union

from camel.agents.chat_agent import ChatAgent
from camel.logger import get_logger
from camel.messages import BaseMessage
from camel.models import BaseModelBackend
from camel.prompts import TextPrompt
from camel.types import RoleType

logger = get_logger(__name__)

# AgentOps decorator setting
try:
    import os

    if os.getenv("AGENTOPS_API_KEY") is not None:
        from agentops import track_agent
    else:
        raise ImportError
except (ImportError, AttributeError):
    from camel.utils import track_agent


@track_agent(name="DeductiveReasonerAgent")
class DeductiveReasonerAgent(ChatAgent):
    r"""An agent responsible for deductive reasoning. Model of deductive
    reasoning:
        - L: A ⊕ C -> q * B
        - A represents the known starting state.
        - B represents the known target state.
        - C represents the conditions required to transition from A to B.
        - Q represents the quality or effectiveness of the transition from
          A to B.
        - L represents the path or process from A to B.

    Args:
        model (BaseModelBackend, optional): The model backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)
    """

    def __init__(
        self,
        model: Optional[BaseModelBackend] = None,
    ) -> None:
        system_message = BaseMessage(
            role_name="Insight Agent",
            role_type=RoleType.ASSISTANT,
            meta_dict=None,
            content="You assign roles based on tasks.",
        )
        super().__init__(system_message, model=model)

    def deduce_conditions_and_quality(
        self,
        starting_state: str,
        target_state: str,
        role_descriptions_dict: Optional[Dict[str, str]] = None,
    ) -> Dict[str, Union[List[str], Dict[str, str]]]:
        r"""Derives the conditions and quality from the starting state and the
        target state based on the model of the deductive reasoning and the
        knowledge base. It can optionally consider the roles involved in the
        scenario, which allows tailoring the output more closely to the AI
        agent's environment.

        Args:
            starting_state (str): The initial or starting state from which
                conditions are deduced.
            target_state (str): The target state of the task.
            role_descriptions_dict (Optional[Dict[str, str]], optional): A
                dictionary describing the roles involved in the scenario. It
                provides context for CAMEL's role-playing, enabling the
                generation of more relevant and tailored conditions and
                quality assessments. It can be generated with a
                `RoleAssignmentAgent()` or defined manually by the user.
                (default: :obj:`None`)

        Returns:
            Dict[str, Union[List[str], Dict[str, str]]]: A dictionary with the
                extracted data from the message. The dictionary contains
                three keys:
                - 'conditions': A dict mapping each condition ID to the
                  corresponding condition text.
                - 'labels': A list of label strings extracted from the
                  message.
                - 'evaluate_quality': A string with the quality assessment
                  extracted from the message.
        """
        self.reset()

        deduce_prompt = """You are a deductive reasoner. You are tasked to
complete the TASK based on the THOUGHT OF DEDUCTIVE REASONING, the
STARTING STATE A and the TARGET STATE B. You are given the CONTEXT
CONTENT to help you complete the TASK.
Your answer MUST strictly adhere to the structure of ANSWER TEMPLATE, ONLY
fill in the BLANKs, and DO NOT alter or modify any other part of the template.

===== MODELING OF DEDUCTIVE REASONING =====
You are tasked with understanding a mathematical model based on the components
${A, B, C, Q, L}$. In this model: ``L: A ⊕ C -> q * B``.
- $A$ represents the known starting state.
- $B$ represents the known target state.
- $C$ represents the conditions required to transition from $A$ to $B$.
- $Q$ represents the quality or effectiveness of the transition from $A$ to
$B$.
- $L$ represents the path or process from $A$ to $B$.

===== THOUGHT OF DEDUCTIVE REASONING =====
1. Define the Parameters of A and B:
    - Characterization: Before delving into transitions, thoroughly understand
    the nature and boundaries of both $A$ and $B$. This includes the type,
    properties, constraints, and possible interactions between the two.
    - Contrast and Compare: Highlight the similarities and differences between
    $A$ and $B$. This comparative analysis will give an insight into what
    needs changing and what remains constant.
2. Historical & Empirical Analysis:
    - Previous Transitions according to the Knowledge Base of GPT: (if
    applicable) Extract conditions and patterns from the historical instances
    where a similar transition from a state comparable to $A$ moved towards
    $B$.
    - Scientific Principles: (if applicable) Consider the underlying
    scientific principles governing or related to the states and their
    transition. For example, if $A$ and $B$ are physical states, laws of
    physics might apply.
3. Logical Deduction of Conditions ($C$):
    - Direct Path Analysis: What are the immediate and direct conditions
    required to move from $A$ to $B$?
    - Intermediate States: Are there states between $A$ and $B$ that must be
    traversed or can be used to make the transition smoother or more
    efficient? If yes, what is the content?
    - Constraints & Limitations: Identify potential barriers or restrictions
    in moving from $A$ to $B$. These can be external (e.g., environmental
    factors) or internal (properties of $A$ or $B$).
    - Resource and Information Analysis: What resources and information are
    required for the transition? This could be time, entity, factor, code
    language, software platform, unknowns, etc.
    - External Influences: Consider socio-economic, political, or
    environmental factors (if applicable) that could influence the transition
    conditions.
    - Creative/Heuristic Reasoning: Open your mind to multiple possible $C$'s,
    no matter how unconventional they might seem. Utilize analogies,
    metaphors, or brainstorming techniques to envision possible conditions or
    paths from $A$ to $B$.
    - The conditions $C$ should be multiple but in one sentence. And each
    condition should be concerned with one aspect/entity.
4. Entity/Label Recognition of Conditions ($C$):
    - Identify and categorize entities of Conditions ($C$) such as the names,
    locations, dates, specific technical terms or contextual parameters that
    might be associated with events, innovations post-2022.
    - The output of the entities/labels will be used as tags or labels for
    semantic similarity searches. The entities/labels may be the words, or
    phrases, each of them should contain valuable, high information entropy
    information, and should be independent.
    - Ensure that the identified entities are formatted in a manner suitable
    for database indexing and retrieval. Organize the entities into
    categories, and combine the category with its instance into a continuous
    phrase, without using colons or other separators.
    - Format these entities for database indexing: output the category rather
    than its instance/content into a continuous phrase. For example, instead
    of "Jan. 02", identify it as "Event time".
5. Quality Assessment ($Q$):
    - Efficiency: How efficient is the transition from $A$ to $B$, which
    measures the resources used versus the desired outcome?
    - Effectiveness: Did the transition achieve the desired outcome or was the
    target state achieved as intended?
    - Safety & Risks: Assess any risks associated with the transition and the
    measures to mitigate them.
    - Feedback Mechanisms: Incorporate feedback loops to continuously monitor
    and adjust the quality of transition, making it more adaptive.
6. Iterative Evaluation:
    - Test & Refine: Based on the initially deduced conditions and assessed
    quality, iterate the process to refine and optimize the transition. This
    might involve tweaking conditions, employing different paths, or changing
    resources.
    - Feedback Integration: Use feedback to make improvements and increase the
    quality of the transition.
7. Real-world scenarios often present challenges that may not be captured by
models and frameworks. While using the model, maintain an adaptive mindset:
    - Scenario Exploration: Continuously imagine various possible scenarios,
    both positive and negative, to prepare for unexpected events.
    - Flexibility: Be prepared to modify conditions ($C$) or alter the path/
    process ($L$) if unforeseen challenges arise.
    - Feedback Integration: Rapidly integrate feedback from actual
    implementations to adjust the model's application, ensuring relevancy and
    effectiveness.

===== TASK =====
Given the starting state $A$ and the target state $B$, assuming that a path
$L$ always exists between $A$ and $B$, how can one deduce or identify the
necessary conditions $C$ and the quality $Q$ of the transition?

===== STARTING STATE $A$ =====
{starting_state}

===== TARGET STATE $B$ =====
{target_state}

{role_with_description_prompt}
===== ANSWER TEMPLATE =====
- Characterization and comparison of $A$ and $B$:\n<BLANK>
- Historical & Empirical Analysis:\n<BLANK>/None
- Logical Deduction of Conditions ($C$) (multiple conditions can be deduced):
    condition <NUM>:
        <BLANK>.
- Entity/Label Recognition of Conditions:\n[<BLANK>, <BLANK>, ...] (include
square brackets)
- Quality Assessment ($Q$) (do not use symbols):
    <BLANK>.
- Iterative Evaluation:\n<BLANK>/None"""

        if role_descriptions_dict is not None:
            role_names = role_descriptions_dict.keys()
            role_with_description_prompt = (
                "===== ROLES WITH DESCRIPTIONS =====\n"
                + "\n".join(
                    f"{role_name}:\n{role_descriptions_dict[role_name]}\n"
                    for role_name in role_names
                )
                + "\n\n"
            )
        else:
            role_with_description_prompt = ""
        deduce_prompt = TextPrompt(deduce_prompt)

        deduce = deduce_prompt.format(
            starting_state=starting_state,
            target_state=target_state,
            role_with_description_prompt=role_with_description_prompt,
        )

        conditions_and_quality_generation_msg = BaseMessage.make_user_message(
            role_name="Deductive Reasoner", content=deduce
        )

        response = self.step(
            input_message=conditions_and_quality_generation_msg
        )

        if response.terminated:
            raise RuntimeError(
                "Deduction failed. Error:\n" + f"{response.info}"
            )
        msg: BaseMessage = response.msg
        logger.info(f"Message content:\n{msg.content}")

        # Extract the conditions from the message
        conditions_dict = {
            f"condition {i}": cdt.replace("<", "")
            .replace(">", "")
            .strip()
            .strip('\n')
            for i, cdt in re.findall(
                r"condition (\d+):\s*(.+?)(?=condition \d+|- Entity)",
                msg.content,
                re.DOTALL,
            )
        }

        # Extract the labels from the message
        labels = [
            label.strip().strip('\n').strip("\"'")
            for label in re.findall(
                r"Entity/Label Recognition of Conditions:\n\[(.+?)\]",
                msg.content,
                re.DOTALL,
            )[0].split(",")
        ]

        # Extract the quality from the message
        quality = next(
            q.strip().strip('\n')
            for q in re.findall(
                r"Quality Assessment \(\$Q\$\) \(do not use symbols\):"
                r"\n(.+?)- Iterative",
                msg.content,
                re.DOTALL,
            )
        )

        # Convert them into JSON format
        conditions_and_quality_json: Dict[
            str, Union[List[str], Dict[str, str]]
        ] = {}
        conditions_and_quality_json["conditions"] = conditions_dict
        conditions_and_quality_json["labels"] = labels
        conditions_and_quality_json["evaluate_quality"] = quality

        return conditions_and_quality_json
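A sketch of calling `deduce_conditions_and_quality`; the two states are illustrative, the default model backend is assumed from the environment, and the returned keys follow the parsing code above:

# Sketch: deduce transition conditions between two hypothetical states.
from camel.agents.deductive_reasoner_agent import DeductiveReasonerAgent

reasoner = DeductiveReasonerAgent()  # default model backend assumed
result = reasoner.deduce_conditions_and_quality(
    starting_state="A single-threaded web scraper fetching one page at a time.",
    target_state="A polite concurrent scraper with per-domain rate limits.",
)
print(result["conditions"])        # e.g. {'condition 1': '...', ...}
print(result["labels"])            # e.g. ['Rate limiting policy', ...]
print(result["evaluate_quality"])  # free-text quality assessment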
camel/agents/embodied_agent.py
ADDED
@@ -0,0 +1,201 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Any, List, Optional

from colorama import Fore

from camel.agents.chat_agent import ChatAgent
from camel.agents.tool_agents.base import BaseToolAgent
from camel.interpreters import (
    BaseInterpreter,
    InternalPythonInterpreter,
    SubprocessInterpreter,
)
from camel.messages import BaseMessage
from camel.models import BaseModelBackend
from camel.responses import ChatAgentResponse
from camel.utils import print_text_animated

# AgentOps decorator setting
try:
    import os

    if os.getenv("AGENTOPS_API_KEY") is not None:
        from agentops import track_agent
    else:
        raise ImportError
except (ImportError, AttributeError):
    from camel.utils import track_agent


@track_agent(name="EmbodiedAgent")
class EmbodiedAgent(ChatAgent):
    r"""Class for managing conversations of CAMEL Embodied Agents.

    Args:
        system_message (BaseMessage): The system message for the chat agent.
        model (BaseModelBackend, optional): The model backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)
        message_window_size (int, optional): The maximum number of previous
            messages to include in the context window. If `None`, no windowing
            is performed. (default: :obj:`None`)
        tool_agents (List[BaseToolAgent], optional): The tool agents to use in
            the embodied agent. (default: :obj:`None`)
        code_interpreter (BaseInterpreter, optional): The code interpreter to
            execute codes. If `code_interpreter` and `tool_agents` are both
            `None`, defaults to `SubprocessInterpreter`. If `code_interpreter`
            is `None` and `tool_agents` is not `None`, defaults to
            `InternalPythonInterpreter`. (default: :obj:`None`)
        verbose (bool, optional): Whether to print the agent's messages.
            (default: :obj:`False`)
        logger_color (Any): The color of the logger displayed to the user.
            (default: :obj:`Fore.MAGENTA`)
    """

    def __init__(
        self,
        system_message: BaseMessage,
        model: Optional[BaseModelBackend] = None,
        message_window_size: Optional[int] = None,
        tool_agents: Optional[List[BaseToolAgent]] = None,
        code_interpreter: Optional[BaseInterpreter] = None,
        verbose: bool = False,
        logger_color: Any = Fore.MAGENTA,
    ) -> None:
        self.tool_agents = tool_agents
        self.code_interpreter: BaseInterpreter
        if code_interpreter is not None:
            self.code_interpreter = code_interpreter
        elif self.tool_agents:
            self.code_interpreter = InternalPythonInterpreter()
        else:
            self.code_interpreter = SubprocessInterpreter()

        if self.tool_agents:
            system_message = self._set_tool_agents(system_message)
        self.verbose = verbose
        self.logger_color = logger_color
        super().__init__(
            system_message=system_message,
            model=model,
            message_window_size=message_window_size,
        )

    def _set_tool_agents(self, system_message: BaseMessage) -> BaseMessage:
        action_space_prompt = self._get_tool_agents_prompt()
        result_message = system_message.create_new_instance(
            content=system_message.content.format(
                action_space=action_space_prompt
            )
        )
        if self.tool_agents is not None:
            self.code_interpreter.update_action_space(
                {tool.name: tool for tool in self.tool_agents}
            )
        return result_message

    def _get_tool_agents_prompt(self) -> str:
        r"""Returns the action space prompt.

        Returns:
            str: The action space prompt.
        """
        if self.tool_agents is not None:
            return "\n".join(
                [
                    f"*** {tool.name} ***:\n {tool.description}"
                    for tool in self.tool_agents
                ]
            )
        else:
            return ""

    def get_tool_agent_names(self) -> List[str]:
        r"""Returns the names of tool agents.

        Returns:
            List[str]: The names of tool agents.
        """
        if self.tool_agents is not None:
            return [tool.name for tool in self.tool_agents]
        else:
            return []

    # ruff: noqa: E501
    def step(self, input_message: BaseMessage) -> ChatAgentResponse:  # type: ignore[override]
        r"""Performs a step in the conversation.

        Args:
            input_message (BaseMessage): The input message.

        Returns:
            ChatAgentResponse: A struct containing the output messages,
                a boolean indicating whether the chat session has terminated,
                and information about the chat session.
        """
        response = super().step(input_message)

        if response.msgs is None or len(response.msgs) == 0:
            raise RuntimeError("Got None output messages.")
        if response.terminated:
            raise RuntimeError(f"{self.__class__.__name__} step failed.")

        # NOTE: Only single output messages are supported
        explanations, codes = response.msg.extract_text_and_code_prompts()

        if self.verbose:
            for explanation, code in zip(explanations, codes):
                print_text_animated(
                    self.logger_color + f"> Explanation:\n{explanation}"
                )
                print_text_animated(self.logger_color + f"> Code:\n{code}")

            if len(explanations) > len(codes):
                print_text_animated(
                    self.logger_color + f"> Explanation:\n{explanations[-1]}"
                )

        content = response.msg.content

        if codes is not None:
            try:
                content = "\n> Executed Results:\n"
                for block_idx, code in enumerate(codes):
                    executed_output = self.code_interpreter.run(
                        code, code.code_type
                    )
                    content += (
                        f"Executing code block {block_idx}: {{\n"
                        + executed_output
                        + "}\n"
                    )
            except InterruptedError as e:
                content = (
                    f"\n> Running code failed: {e}\n"
                    "Please regenerate the code."
                )

        # TODO: Handle errors
        content = input_message.content + f"\n> Embodied Actions:\n{content}"
        message = BaseMessage(
            input_message.role_name,
            input_message.role_type,
            input_message.meta_dict,
            content,
        )
        return ChatAgentResponse(
            msgs=[message],
            terminated=response.terminated,
            info=response.info,
        )
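A minimal sketch of running the agent above without tool agents, so `SubprocessInterpreter` is used by default; note that the `{action_space}` placeholder in the system prompt is only required when `tool_agents` are supplied:

# Sketch: ask an EmbodiedAgent to write and execute a short snippet.
from camel.agents.embodied_agent import EmbodiedAgent
from camel.messages import BaseMessage

agent = EmbodiedAgent(
    system_message=BaseMessage.make_assistant_message(
        role_name="Executor",
        content="You write and run short Python snippets to solve tasks.",
    ),
    verbose=True,  # animates explanations and code blocks as they arrive
)
task = BaseMessage.make_user_message(
    role_name="User", content="Compute the sum of the integers 1..100."
)
response = agent.step(task)  # extracted code blocks run in the interpreter
print(response.msg.content)  # original task + "> Embodied Actions:" results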
camel/agents/knowledge_graph_agent.py
ADDED
@@ -0,0 +1,259 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import TYPE_CHECKING, Optional, Union

if TYPE_CHECKING:
    from unstructured.documents.elements import Element

from camel.agents import ChatAgent
from camel.messages import BaseMessage
from camel.models import BaseModelBackend
from camel.prompts import TextPrompt
from camel.storages.graph_storages.graph_element import (
    GraphElement,
    Node,
    Relationship,
)
from camel.types import RoleType

# AgentOps decorator setting
try:
    import os

    if os.getenv("AGENTOPS_API_KEY") is not None:
        from agentops import track_agent
    else:
        raise ImportError
except (ImportError, AttributeError):
    from camel.utils import track_agent


text_prompt = """
You are tasked with extracting nodes and relationships from given content and
structuring them into Node and Relationship objects. Here's the outline of
what you need to do:

Content Extraction:
You should be able to process input content and identify entities mentioned
within it.
Entities can be any noun phrases or concepts that represent distinct entities
in the context of the given content.

Node Extraction:
For each identified entity, you should create a Node object.
Each Node object should have a unique identifier (id) and a type (type).
Additional properties associated with the node can also be extracted and
stored.

Relationship Extraction:
You should identify relationships between entities mentioned in the content.
For each relationship, create a Relationship object.
A Relationship object should have a subject (subj) and an object (obj) which
are Node objects representing the entities involved in the relationship.
Each relationship should also have a type (type), and additional properties if
applicable.

Output Formatting:
The extracted nodes and relationships should be formatted as instances of the
provided Node and Relationship classes.
Ensure that the extracted data adheres to the structure defined by the classes.
Output the structured data in a format that can be easily validated against
the provided code.

Instructions for you:
Read the provided content thoroughly.
Identify distinct entities mentioned in the content and categorize them as
nodes.
Determine relationships between these entities and represent them as directed
relationships.
Provide the extracted nodes and relationships in the specified format below.
Example for you:

Example Content:
"John works at XYZ Corporation. He is a software engineer. The company is
located in New York City."

Expected Output:

Nodes:

Node(id='John', type='Person')
Node(id='XYZ Corporation', type='Organization')
Node(id='New York City', type='Location')

Relationships:

Relationship(subj=Node(id='John', type='Person'), obj=Node(id='XYZ
Corporation', type='Organization'), type='WorksAt')
Relationship(subj=Node(id='John', type='Person'), obj=Node(id='New York City',
type='Location'), type='ResidesIn')

===== TASK =====
Please extract nodes and relationships from the given content and structure
them into Node and Relationship objects.

{task}
"""


@track_agent(name="KnowledgeGraphAgent")
class KnowledgeGraphAgent(ChatAgent):
    r"""An agent that can extract node and relationship information for
    different entities from given `Element` content.

    Attributes:
        task_prompt (TextPrompt): A prompt for the agent to extract node and
            relationship information for different entities.
    """

    def __init__(
        self,
        model: Optional[BaseModelBackend] = None,
    ) -> None:
        r"""Initialize the `KnowledgeGraphAgent`.

        Args:
            model (BaseModelBackend, optional): The model backend to use for
                generating responses. (default: :obj:`OpenAIModel` with
                `GPT_4O_MINI`)
        """
        system_message = BaseMessage(
            role_name="Graphify",
            role_type=RoleType.ASSISTANT,
            meta_dict=None,
            content="Your mission is to transform unstructured content "
            "into structured graph data. Extract nodes and relationships with "
            "precision, and let the connections unfold. Your graphs will "
            "illuminate the hidden connections within the chaos of "
            "information.",
        )
        super().__init__(system_message, model=model)

    def run(
        self,
        element: "Element",
        parse_graph_elements: bool = False,
    ) -> Union[str, GraphElement]:
        r"""Run the agent to extract node and relationship information.

        Args:
            element (Element): The input element.
            parse_graph_elements (bool, optional): Whether to parse into
                `GraphElement`. Defaults to `False`.

        Returns:
            Union[str, GraphElement]: The extracted node and relationship
                information. If `parse_graph_elements` is `True` then return
                `GraphElement`, else return `str`.
        """
        self.reset()
        self.element = element

        knowledge_graph_prompt = TextPrompt(text_prompt)
        knowledge_graph_generation = knowledge_graph_prompt.format(
            task=str(element)
        )

        knowledge_graph_generation_msg = BaseMessage.make_user_message(
            role_name="Graphify", content=knowledge_graph_generation
        )

        response = self.step(input_message=knowledge_graph_generation_msg)

        content = response.msg.content

        if parse_graph_elements:
            content = self._parse_graph_elements(content)

        return content

    def _validate_node(self, node: Node) -> bool:
        r"""Validate if the object is a valid Node.

        Args:
            node (Node): Object to be validated.

        Returns:
            bool: True if the object is a valid Node, False otherwise.
        """
        return (
            isinstance(node, Node)
            and isinstance(node.id, (str, int))
            and isinstance(node.type, str)
        )

    def _validate_relationship(self, relationship: Relationship) -> bool:
        r"""Validate if the object is a valid Relationship.

        Args:
            relationship (Relationship): Object to be validated.

        Returns:
            bool: True if the object is a valid Relationship, False otherwise.
        """
        return (
            isinstance(relationship, Relationship)
            and self._validate_node(relationship.subj)
            and self._validate_node(relationship.obj)
            and isinstance(relationship.type, str)
        )

    def _parse_graph_elements(self, input_string: str) -> GraphElement:
        r"""Parses graph elements from given content.

        Args:
            input_string (str): The input content.

        Returns:
            GraphElement: The parsed graph elements.
        """
        import re

        # Regular expressions to extract nodes and relationships
        node_pattern = r"Node\(id='(.*?)', type='(.*?)'\)"
        rel_pattern = (
            r"Relationship\(subj=Node\(id='(.*?)', type='(.*?)'\), "
            r"obj=Node\(id='(.*?)', type='(.*?)'\), type='(.*?)'\)"
        )

        nodes = {}
        relationships = []

        # Extract nodes
        for match in re.finditer(node_pattern, input_string):
            id, type = match.groups()
            properties = {'source': 'agent_created'}
            if id not in nodes:
                node = Node(id=id, type=type, properties=properties)
                if self._validate_node(node):
                    nodes[id] = node

        # Extract relationships
        for match in re.finditer(rel_pattern, input_string):
            subj_id, subj_type, obj_id, obj_type, rel_type = match.groups()
            properties = {'source': 'agent_created'}
            if subj_id in nodes and obj_id in nodes:
                subj = nodes[subj_id]
                obj = nodes[obj_id]
                relationship = Relationship(
                    subj=subj, obj=obj, type=rel_type, properties=properties
                )
                if self._validate_relationship(relationship):
                    relationships.append(relationship)

        return GraphElement(
            nodes=list(nodes.values()),
            relationships=relationships,
            source=self.element,
        )
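A sketch of `run` on a single element; the `Text` element class is assumed from the `unstructured` package that the `TYPE_CHECKING` import above references, and the default model backend is assumed from the environment:

# Sketch: extract a small knowledge graph from one Text element.
from unstructured.documents.elements import Text

from camel.agents.knowledge_graph_agent import KnowledgeGraphAgent

kg_agent = KnowledgeGraphAgent()
element = Text(text="Ada Lovelace worked with Charles Babbage in London.")

# With parse_graph_elements=True the regex parser above is applied and a
# GraphElement is returned; with the default False, the raw model text is.
graph = kg_agent.run(element, parse_graph_elements=True)
print(graph.nodes)          # validated Node objects
print(graph.relationships)  # validated Relationship objects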
camel/agents/multi_hop_generator_agent.py
ADDED
@@ -0,0 +1,117 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

import textwrap
from typing import Any

from pydantic import ConfigDict

from camel.agents.programmed_agent_instruction import (
    ProgrammableChatAgent,
    ProgrammedAgentInstructionResult,
    programmable_capability,
)
from camel.datagen.source2synth.models import (
    ContextPrompt,
    MultiHopQA,
)
from camel.messages import BaseMessage


class MultiHopGeneratorAgent(ProgrammableChatAgent):
    r"""An agent specialized in generating multi-hop question-answer pairs.

    This agent is designed to create complex questions that require multiple
    steps of reasoning to answer. It analyzes context to identify related
    facts and generates questions that require connecting these facts
    logically.

    Attributes:
        model_config (ConfigDict): Configuration for model behavior.
        system_message (BaseMessage): System message defining agent's role and
            instructions.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    def __init__(self, **kwargs: Any) -> None:
        r"""Initialize the MultiHopGeneratorAgent.

        Args:
            **kwargs (Any): Additional keyword arguments to pass to parent
                class.
        """
        super().__init__(**kwargs)

        system_text: str = textwrap.dedent(
            """\
            You are an expert at generating
            multi-hop question-answer pairs.
            For each context, you should:
            1. Identify multiple related facts or pieces of information
            2. Create questions that require reasoning across these multiple pieces
            3. Ensure the reasoning chain is clear and logical
            4. Generate questions that require at least 2-3 steps of reasoning
            5. Include the reasoning steps in the answer

            Give your response with this information:
            Question: [Complex question requiring multiple reasoning steps]
            Reasoning Steps:
            1. [First reasoning step]
            2. [Second reasoning step]
            3. [Final reasoning step]
            Answer: [Final answer]
            Supporting Facts: [List of relevant text segments used]
            """  # noqa: E501
        )
        self.system_message = BaseMessage.make_assistant_message(
            role_name='Assistant', content=system_text
        )

    @programmable_capability
    def generate_multi_hop_qa(
        self, context: str
    ) -> ProgrammedAgentInstructionResult[MultiHopQA]:
        r"""Generate a multi-hop question-answer pair from given context.

        Args:
            context (str): The input text context to generate QA from.

        Returns:
            ProgrammedAgentInstructionResult[MultiHopQA]: Result containing the
                generated question, reasoning steps, answer, and supporting
                facts.

        Raises:
            RuntimeError: If the agent fails to generate a response.
        """
        context_prompt = ContextPrompt(
            main_context=context, related_contexts=None
        )

        user_message = BaseMessage.make_user_message(
            content=context_prompt.model_dump_json(), role_name="User"
        )
        response = self.step(
            input_message=user_message, response_format=MultiHopQA
        )

        if response.msgs:
            # Validate only after confirming a message exists, so an empty
            # response raises the intended RuntimeError rather than an
            # IndexError.
            value = MultiHopQA.model_validate_json(response.msgs[0].content)
            return ProgrammedAgentInstructionResult(
                user_message=user_message,
                agent_message=response.msgs[0],
                value=value,
            )
        raise RuntimeError("No response from agent")
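A sketch of generating one QA pair with the agent above; the `question`/`answer` field names on `MultiHopQA` are assumptions about `camel.datagen.source2synth.models`, which is not shown in this diff:

# Sketch: generate a multi-hop QA pair from a short hypothetical context.
from camel.agents.multi_hop_generator_agent import MultiHopGeneratorAgent

agent = MultiHopGeneratorAgent()
context = (
    "The dam was finished in 1936. The reservoir it created irrigates the "
    "valley, and the valley's farms supply the nearby canning plant."
)
result = agent.generate_multi_hop_qa(context)  # runs atomically (see below)
qa = result.value              # a validated MultiHopQA instance
print(qa.question, qa.answer)  # field names assumed from the schema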
camel/agents/programmed_agent_instruction.py
ADDED
@@ -0,0 +1,203 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import abc
import threading
from enum import Enum
from functools import wraps
from typing import Any, Callable, Generic, Optional, TypeVar

from pydantic import BaseModel, ConfigDict

from camel.agents import ChatAgent
from camel.messages import BaseMessage

T = TypeVar('T')


class ProgrammableAgentRequirement(Enum):
    r"""Requirements for programmable agent state.

    Defines the possible requirements that can be used to repair the state
    of a programmable agent.

    Attributes:
        LAST_MESSAGE_NOT_USER (str): Requires that the last message in the
            conversation was not from the user.
    """

    LAST_MESSAGE_NOT_USER = "LAST_MESSAGE_NOT_USER"


class ProgrammedAgentInstructionResult(BaseModel, Generic[T]):
    r"""Result of a programmable agent instruction execution.

    Contains the messages exchanged during execution and the computed value.
    The value type is specified by the generic type parameter T.

    Attributes:
        user_message (BaseMessage): The message sent by the user.
        agent_message (BaseMessage): The message sent by the agent.
        value (T): The computed result value of type T.
    """

    user_message: BaseMessage
    agent_message: BaseMessage
    value: T

    model_config = ConfigDict(arbitrary_types_allowed=True)


class AbstractProgrammableAgent(abc.ABC):
    r"""Abstract class for a programmable agent.

    A programmable agent is an agent that can be programmed to perform a
    specific function or task. This class defines the interface for a
    programmable agent.

    These methods should be implemented in order to ensure the agent supports
    the necessary guarantees to enable a programming interface while
    maintaining compatibility in a multi-agent system.

    A programmable agent is responsible for providing and maintaining a
    programming interface for its functionality.
    """

    @abc.abstractmethod
    def run_atomic(
        self, callback: Callable[[], ProgrammedAgentInstructionResult[T]]
    ) -> ProgrammedAgentInstructionResult[T]:
        r"""Run an atomic operation on the agent.

        An atomic operation is an operation that is guaranteed to
        be executed without interruption by any other operation.

        Args:
            callback (Callable[[], ProgrammedAgentInstructionResult[T]]): The
                operation to execute atomically.

        Returns:
            ProgrammedAgentInstructionResult[T]: The result of the operation.

        Raises:
            RuntimeError: If an operation is already in progress.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def repair_state(self, requirement: ProgrammableAgentRequirement) -> None:
        r"""Repair the state of the agent.

        Agents may have other non-atomic interfaces, such as a user interface,
        or chat between other agents. This method should restore the agent to
        a state where it can perform operations according to the specified
        requirement.

        Args:
            requirement (ProgrammableAgentRequirement): The requirement to
                repair the state for.
        """
        raise NotImplementedError


def programmable_capability(
    func: Callable[..., ProgrammedAgentInstructionResult[T]],
) -> Callable[..., ProgrammedAgentInstructionResult[T]]:
    r"""Decorator for programmable agent capabilities.

    This decorator ensures that the decorated method is executed atomically
    and maintains the agent's state guarantees.

    Args:
        func (Callable[..., ProgrammedAgentInstructionResult[T]]): The method
            to decorate.

    Returns:
        Callable[..., ProgrammedAgentInstructionResult[T]]: The decorated
            method that ensures atomic execution.
    """

    @wraps(func)
    def wrapper(
        self, *args: Any, **kwargs: Any
    ) -> ProgrammedAgentInstructionResult[T]:
        return self.run_atomic(lambda: func(self, *args, **kwargs))

    return wrapper


class ProgrammableChatAgent(ChatAgent, AbstractProgrammableAgent):
    r"""A chat agent that can be programmed to perform specific tasks.

    Provides a default implementation of atomic execution using threading
    locks and basic state tracking for message roles. Implementing classes
    need to provide specific repair logic for their use cases.

    Attributes:
        _operation_lock (threading.Lock): Lock for ensuring atomic operations.
        _last_message_role (Optional[str]): Role of the last message in the
            conversation.
    """

    def __init__(self, **kwargs: Any) -> None:
        r"""Initialize the ProgrammableChatAgent.

        Args:
            **kwargs (Any): Additional keyword arguments to pass to parent
                class.
        """
        super().__init__(**kwargs)
        self._operation_lock = threading.Lock()
        self._last_message_role: Optional[str] = None

    def run_atomic(
        self, callback: Callable[[], ProgrammedAgentInstructionResult[T]]
    ) -> ProgrammedAgentInstructionResult[T]:
        r"""Run an atomic operation on the agent.

        Ensures thread-safe execution of the callback function by using a
        lock.

        Args:
            callback (Callable[[], ProgrammedAgentInstructionResult[T]]): The
                operation to execute atomically.

        Returns:
            ProgrammedAgentInstructionResult[T]: The result of the operation.

        Raises:
            RuntimeError: If an operation is already in progress.
        """
        if not self._operation_lock.acquire(blocking=False):
            raise RuntimeError("Operation already in progress")

        try:
            result = callback()
            self._last_message_role = result.agent_message.role_name
            return result
        finally:
            self._operation_lock.release()

    def repair_state(self, requirement: ProgrammableAgentRequirement) -> None:
        r"""Repair the state of the agent.

        Implements basic state repair for message role requirements.

        Args:
|
| 196 |
+
requirement (ProgrammableAgentRequirement): The requirement to
|
| 197 |
+
repair the state for.
|
| 198 |
+
"""
|
| 199 |
+
if requirement == ProgrammableAgentRequirement.LAST_MESSAGE_NOT_USER:
|
| 200 |
+
if self._last_message_role == "user":
|
| 201 |
+
raise NotImplementedError(
|
| 202 |
+
"Must implement repair for LAST_MESSAGE_NOT_USER"
|
| 203 |
+
)
|
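A minimal usage sketch (illustrative, not from the repository): the `CalculatorAgent` class and its `compute_answer` method are made-up names, but they exercise `programmable_capability`, `run_atomic`, and `ProgrammedAgentInstructionResult` exactly as defined above.

# Hypothetical subclass; only CalculatorAgent/compute_answer are invented.
class CalculatorAgent(ProgrammableChatAgent):
    @programmable_capability
    def compute_answer(
        self, question: str
    ) -> ProgrammedAgentInstructionResult[str]:
        user_msg = BaseMessage.make_user_message(
            role_name="User", content=question
        )
        response = self.step(user_msg)
        return ProgrammedAgentInstructionResult(
            user_message=user_msg,
            agent_message=response.msg,
            value=response.msg.content,
        )

# The decorator routes the call through run_atomic(), so a concurrent second
# call raises RuntimeError("Operation already in progress") instead of
# interleaving messages in the conversation history.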
camel/agents/role_assignment_agent.py
ADDED
@@ -0,0 +1,141 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import re
from typing import Dict, Optional, Union

from camel.agents.chat_agent import ChatAgent
from camel.messages import BaseMessage
from camel.models import BaseModelBackend
from camel.prompts import TextPrompt
from camel.types import RoleType

# AgentOps decorator setting
try:
    import os

    if os.getenv("AGENTOPS_API_KEY") is not None:
        from agentops import track_agent
    else:
        raise ImportError
except (ImportError, AttributeError):
    from camel.utils import track_agent


@track_agent(name="RoleAssignmentAgent")
class RoleAssignmentAgent(ChatAgent):
    r"""An agent that generates role names based on the task prompt.

    Args:
        model (BaseModelBackend, optional): The model backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)

    Attributes:
        role_assignment_prompt (TextPrompt): A prompt for the agent to
            generate role names.
    """

    def __init__(
        self,
        model: Optional[BaseModelBackend] = None,
    ) -> None:
        system_message = BaseMessage(
            role_name="Role Assigner",
            role_type=RoleType.ASSISTANT,
            meta_dict=None,
            content="You assign roles based on tasks.",
        )
        super().__init__(system_message, model=model)

    def run(
        self,
        task_prompt: Union[str, TextPrompt],
        num_roles: int = 2,
    ) -> Dict[str, str]:
        r"""Generate role names based on the input task prompt.

        Args:
            task_prompt (Union[str, TextPrompt]): The prompt for the task
                based on which the roles are to be generated.
            num_roles (int, optional): The number of roles to generate.
                (default: :obj:`2`)

        Returns:
            Dict[str, str]: A dictionary mapping role names to their
                descriptions.
        """
        self.reset()

        expert_prompt = "===== ANSWER PROMPT =====\n" + "\n".join(
            f"Domain expert {i + 1}: <BLANK>\n"
            f"Associated competencies, characteristics, duties "
            f"and workflows: <BLANK>. End."
            for i in range(num_roles or 0)
        )
        role_assignment_generation_prompt = TextPrompt(
            "You are a role assignment agent, and you're in charge of "
            + "recruiting {num_roles} experts for the following task."
            + "\n==== TASK =====\n {task}\n\n"
            + "Identify the domain experts you'd recruit and detail their "
            + "associated competencies, characteristics, duties and workflows "
            + "to complete the task.\n "
            + "Your answer MUST adhere to the format of ANSWER PROMPT, and "
            + "ONLY answer the BLANKs.\n"
            + expert_prompt
        )
        role_assignment_generation = role_assignment_generation_prompt.format(
            num_roles=num_roles, task=task_prompt
        )

        role_assignment_generation_msg = BaseMessage.make_user_message(
            role_name="Role Assigner", content=role_assignment_generation
        )

        response = self.step(input_message=role_assignment_generation_msg)

        msg = response.msg  # type: BaseMessage
        terminated = response.terminated

        # Distribute the output completions into role names and descriptions
        role_names = [
            desc.replace("<|", "").replace("|>", "")
            for desc in re.findall(
                r"Domain expert \d: (.+?)\nAssociated competencies,",
                msg.content,
                re.DOTALL,
            )
        ]
        role_descriptions = [
            desc.replace("<|", "").replace("|>", "")
            for desc in re.findall(
                r"Associated competencies, characteristics, "
                r"duties and workflows: (.+?) End.",
                msg.content,
                re.DOTALL,
            )
        ]

        if len(role_names) != num_roles or len(role_descriptions) != num_roles:
            raise RuntimeError(
                "Got None or insufficient information of roles."
            )
        if terminated:
            raise RuntimeError("Role assignment failed.")

        role_descriptions_dict = {
            role_name: description
            for role_name, description in zip(role_names, role_descriptions)
        }

        return role_descriptions_dict
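A short usage sketch for `RoleAssignmentAgent.run` (illustrative; the task string is made up). It relies on the default `num_roles=2` and on the ANSWER PROMPT format parsed by the regexes above.

agent = RoleAssignmentAgent()
roles = agent.run(
    task_prompt="Develop a trading bot for the stock market",
    num_roles=2,
)
# roles maps each generated role name to its description, e.g.
# {"Quantitative Analyst": "...", "Software Engineer": "..."}
for name, description in roles.items():
    print(f"{name}: {description}")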
camel/agents/search_agent.py
ADDED
@@ -0,0 +1,133 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Optional

from camel.agents.chat_agent import ChatAgent
from camel.messages import BaseMessage
from camel.models import BaseModelBackend
from camel.prompts import TextPrompt
from camel.types import RoleType
from camel.utils import create_chunks

# AgentOps decorator setting
try:
    import os

    if os.getenv("AGENTOPS_API_KEY") is not None:
        from agentops import track_agent
    else:
        raise ImportError
except (ImportError, AttributeError):
    from camel.utils import track_agent


@track_agent(name="SearchAgent")
class SearchAgent(ChatAgent):
    r"""An agent that summarizes text based on a query and evaluates the
    relevance of an answer.

    Args:
        model (BaseModelBackend, optional): The model backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)
    """

    def __init__(
        self,
        model: Optional[BaseModelBackend] = None,
    ) -> None:
        system_message = BaseMessage(
            role_name="Assistant",
            role_type=RoleType.ASSISTANT,
            meta_dict=None,
            content="You are a helpful assistant.",
        )
        super().__init__(system_message, model=model)

    def summarize_text(self, text: str, query: str) -> str:
        r"""Summarize the information from the text, based on the query.

        Args:
            text (str): Text to summarize.
            query (str): The information being sought.

        Returns:
            str: A string containing the relevant information.
        """
        self.reset()

        summary_prompt = TextPrompt(
            '''Gather information from this text that is relevant to the
            question, but do not directly answer the question.\nquestion:
            {query}\ntext '''
        )
        summary_prompt = summary_prompt.format(query=query)
        # Max length of each chunk
        max_len = 3000
        results = ""
        chunks = create_chunks(text, max_len)
        # Summarize each chunk in turn
        for i, chunk in enumerate(chunks, start=1):
            prompt = summary_prompt + str(i) + ": " + chunk
            user_msg = BaseMessage.make_user_message(
                role_name="User",
                content=prompt,
            )
            result = self.step(user_msg).msg.content
            results += result + "\n"

        # Final summarization
        final_prompt = TextPrompt(
            '''Here are some summarized texts split from one text. Use
            the information to answer the question. If you can't find the
            answer, you must answer "I can not find the answer to the query"
            and explain why.\n Query:\n{query}.\n\nText:\n'''
        )
        final_prompt = final_prompt.format(query=query)
        prompt = final_prompt + results

        user_msg = BaseMessage.make_user_message(
            role_name="User",
            content=prompt,
        )
        response = self.step(user_msg).msg.content

        return response

    def continue_search(self, query: str, answer: str) -> bool:
        r"""Ask whether to continue searching or not based on the provided
        answer.

        Args:
            query (str): The question.
            answer (str): The answer to the question.

        Returns:
            bool: `True` if the search should continue (the answer does not
                satisfy the query), `False` otherwise.
        """
        prompt = TextPrompt(
            "Do you think the ANSWER can answer the QUERY? "
            "Use only 'yes' or 'no' to answer.\n"
            "===== QUERY =====\n{query}\n\n"
            "===== ANSWER =====\n{answer}"
        )
        prompt = prompt.format(query=query, answer=answer)
        user_msg = BaseMessage.make_user_message(
            role_name="User",
            content=prompt,
        )
        response = self.step(user_msg).msg.content
        if "yes" in str(response).lower():
            return False
        return True
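A sketch of the intended summarize-then-verify loop (illustrative; `page_text` and the query are placeholders). Note that `continue_search` returns `True` when the answer does not satisfy the query, i.e. when more searching is needed.

searcher = SearchAgent()
page_text = "..."  # e.g. the body of a fetched web page
query = "What license does the CAMEL project use?"

answer = searcher.summarize_text(page_text, query)
if searcher.continue_search(query, answer):
    print("Answer judged insufficient; fetch another source.")
else:
    print(answer)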
camel/agents/task_agent.py
ADDED
@@ -0,0 +1,410 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Any, Dict, List, Optional, Union

from camel.agents.chat_agent import ChatAgent
from camel.messages import BaseMessage
from camel.models import BaseModelBackend
from camel.prompts import PromptTemplateGenerator, TextPrompt
from camel.types import RoleType, TaskType
from camel.utils import get_task_list

# AgentOps decorator setting
try:
    import os

    if os.getenv("AGENTOPS_API_KEY") is not None:
        from agentops import track_agent
    else:
        raise ImportError
except (ImportError, AttributeError):
    from camel.utils import track_agent


@track_agent(name="TaskSpecifyAgent")
class TaskSpecifyAgent(ChatAgent):
    r"""An agent that specifies a given task prompt by prompting the user to
    provide more details.

    Attributes:
        DEFAULT_WORD_LIMIT (int): The default word limit for the task prompt.
        task_specify_prompt (TextPrompt): The prompt for specifying the task.

    Args:
        model (BaseModelBackend, optional): The model backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)
        task_type (TaskType, optional): The type of task for which to generate
            a prompt. (default: :obj:`TaskType.AI_SOCIETY`)
        task_specify_prompt (Union[str, TextPrompt], optional): The prompt for
            specifying the task. (default: :obj:`None`)
        word_limit (int, optional): The word limit for the task prompt.
            (default: :obj:`50`)
        output_language (str, optional): The language to be output by the
            agent. (default: :obj:`None`)
    """

    DEFAULT_WORD_LIMIT = 50

    def __init__(
        self,
        model: Optional[BaseModelBackend] = None,
        task_type: TaskType = TaskType.AI_SOCIETY,
        task_specify_prompt: Optional[Union[str, TextPrompt]] = None,
        word_limit: int = DEFAULT_WORD_LIMIT,
        output_language: Optional[str] = None,
    ) -> None:
        self.task_specify_prompt: Union[str, TextPrompt]
        if task_specify_prompt is None:
            task_specify_prompt_template = (
                PromptTemplateGenerator().get_task_specify_prompt(task_type)
            )

            self.task_specify_prompt = task_specify_prompt_template.format(
                word_limit=word_limit
            )
        else:
            self.task_specify_prompt = TextPrompt(task_specify_prompt)

        system_message = BaseMessage(
            role_name="Task Specifier",
            role_type=RoleType.ASSISTANT,
            meta_dict=None,
            content="You can make a task more specific.",
        )

        super().__init__(
            system_message,
            model=model,
            output_language=output_language,
        )

    def run(
        self,
        task_prompt: Union[str, TextPrompt],
        meta_dict: Optional[Dict[str, Any]] = None,
    ) -> TextPrompt:
        r"""Specify the given task prompt by providing more details.

        Args:
            task_prompt (Union[str, TextPrompt]): The original task
                prompt.
            meta_dict (Dict[str, Any], optional): A dictionary containing
                additional information to include in the prompt.
                (default: :obj:`None`)

        Returns:
            TextPrompt: The specified task prompt.
        """
        self.reset()
        task_specify_prompt = self.task_specify_prompt.format(task=task_prompt)

        if meta_dict is not None:
            task_specify_prompt = task_specify_prompt.format(**meta_dict)
        task_msg = BaseMessage.make_user_message(
            role_name="Task Specifier", content=task_specify_prompt
        )
        specifier_response = self.step(task_msg)

        if specifier_response.terminated:
            raise RuntimeError("Task specification failed.")
        if len(specifier_response.msgs) == 0:
            raise RuntimeError("Got no specification message.")

        specified_task_msg = specifier_response.msgs[0]

        return TextPrompt(specified_task_msg.content)


@track_agent(name="TaskPlannerAgent")
class TaskPlannerAgent(ChatAgent):
    r"""An agent that helps divide a task into subtasks based on the input
    task prompt.

    Attributes:
        task_planner_prompt (TextPrompt): A prompt for the agent to divide
            the task into subtasks.

    Args:
        model (BaseModelBackend, optional): The model backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)
        output_language (str, optional): The language to be output by the
            agent. (default: :obj:`None`)
    """

    def __init__(
        self,
        model: Optional[BaseModelBackend] = None,
        output_language: Optional[str] = None,
    ) -> None:
        self.task_planner_prompt = TextPrompt(
            "Divide this task into subtasks: {task}. Be concise."
        )
        system_message = BaseMessage(
            role_name="Task Planner",
            role_type=RoleType.ASSISTANT,
            meta_dict=None,
            content="You are a helpful task planner.",
        )

        super().__init__(
            system_message,
            model=model,
            output_language=output_language,
        )

    def run(
        self,
        task_prompt: Union[str, TextPrompt],
    ) -> TextPrompt:
        r"""Generate subtasks based on the input task prompt.

        Args:
            task_prompt (Union[str, TextPrompt]): The prompt for the task to
                be divided into subtasks.

        Returns:
            TextPrompt: A prompt for the subtasks generated by the agent.
        """
        # TODO: Maybe include roles information.
        self.reset()
        task_planner_prompt = self.task_planner_prompt.format(task=task_prompt)

        task_msg = BaseMessage.make_user_message(
            role_name="Task Planner", content=task_planner_prompt
        )

        task_response = self.step(task_msg)

        if task_response.terminated:
            raise RuntimeError("Task planning failed.")
        if len(task_response.msgs) == 0:
            raise RuntimeError("Got no task planning message.")

        sub_tasks_msg = task_response.msgs[0]
        return TextPrompt(sub_tasks_msg.content)


@track_agent(name="TaskCreationAgent")
class TaskCreationAgent(ChatAgent):
    r"""An agent that helps create new tasks based on the objective
    and last completed task. Compared to :obj:`TaskPlannerAgent`,
    it's still a task planner, but it has more context information
    like last task and incomplete task list. Modified from
    `BabyAGI <https://github.com/yoheinakajima/babyagi>`_.

    Attributes:
        task_creation_prompt (TextPrompt): A prompt for the agent to
            create new tasks.

    Args:
        role_name (str): The role name of the Agent to create the task.
        objective (Union[str, TextPrompt]): The objective of the Agent to
            perform the task.
        model (BaseModelBackend, optional): The LLM backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)
        output_language (str, optional): The language to be output by the
            agent. (default: :obj:`None`)
        message_window_size (int, optional): The maximum number of previous
            messages to include in the context window. If `None`, no windowing
            is performed. (default: :obj:`None`)
        max_task_num (int, optional): The maximum number of planned
            tasks in one round. (default: :obj:`3`)
    """

    def __init__(
        self,
        role_name: str,
        objective: Union[str, TextPrompt],
        model: Optional[BaseModelBackend] = None,
        output_language: Optional[str] = None,
        message_window_size: Optional[int] = None,
        max_task_num: Optional[int] = 3,
    ) -> None:
        task_creation_prompt = TextPrompt(
            """Create a new task with the following objective: {objective}.
Never forget you are a Task Creator of {role_name}.
You must instruct me based on my expertise and your needs to solve the task.
You should consider past solved tasks and in-progress tasks: {task_list}.
The newly created tasks must not overlap with these past tasks.
The result must be a numbered list in the format:

#. First Task
#. Second Task
#. Third Task

You can only give me up to {max_task_num} tasks at a time. \
Each task should be concise, concrete and doable for a {role_name}.
You should make a task plan and not ask me questions.
If you think no new tasks are needed right now, write "No tasks to add."
Now start to give me new tasks one by one. No more than three tasks.
Be concrete.
"""
        )

        self.task_creation_prompt = task_creation_prompt.format(
            objective=objective, role_name=role_name, max_task_num=max_task_num
        )
        self.objective = objective

        system_message = BaseMessage(
            role_name="Task Creator",
            role_type=RoleType.ASSISTANT,
            meta_dict=None,
            content="You are a helpful task creator.",
        )

        super().__init__(
            system_message,
            model=model,
            output_language=output_language,
            message_window_size=message_window_size,
        )

    def run(
        self,
        task_list: List[str],
    ) -> List[str]:
        r"""Generate subtasks based on the previous task results and
        incomplete task list.

        Args:
            task_list (List[str]): The completed or in-progress tasks which
                should not overlap with the newly created tasks.

        Returns:
            List[str]: The new task list generated by the Agent.
        """

        if len(task_list) > 0:
            task_creation_prompt = self.task_creation_prompt.format(
                task_list=task_list
            )
        else:
            task_creation_prompt = self.task_creation_prompt.format(
                task_list=""
            )

        task_msg = BaseMessage.make_user_message(
            role_name="Task Creator", content=task_creation_prompt
        )
        task_response = self.step(task_msg)

        if task_response.terminated:
            raise RuntimeError("Task creation failed.")
        if len(task_response.msgs) == 0:
            raise RuntimeError("Got no task creation message.")

        sub_tasks_msg = task_response.msgs[0]
        return get_task_list(sub_tasks_msg.content)


@track_agent(name="TaskPrioritizationAgent")
class TaskPrioritizationAgent(ChatAgent):
    r"""An agent that helps re-prioritize the task list and
    returns a numbered prioritized list. Modified from
    `BabyAGI <https://github.com/yoheinakajima/babyagi>`_.

    Attributes:
        task_prioritization_prompt (TextPrompt): A prompt for the agent to
            prioritize tasks.

    Args:
        objective (Union[str, TextPrompt]): The objective of the Agent to
            perform the task.
        model (BaseModelBackend, optional): The LLM backend to use for
            generating responses. (default: :obj:`OpenAIModel` with
            `GPT_4O_MINI`)
        output_language (str, optional): The language to be output by the
            agent. (default: :obj:`None`)
        message_window_size (int, optional): The maximum number of previous
            messages to include in the context window. If `None`, no windowing
            is performed. (default: :obj:`None`)
    """

    def __init__(
        self,
        objective: Union[str, TextPrompt],
        model: Optional[BaseModelBackend] = None,
        output_language: Optional[str] = None,
        message_window_size: Optional[int] = None,
    ) -> None:
        task_prioritization_prompt = TextPrompt(
            """Prioritize the following tasks: {task_list}.
Consider your ultimate objective: {objective}.
Tasks should be sorted from highest to lowest priority, where higher-priority \
tasks are those that act as pre-requisites or are more essential for meeting \
the objective. Return one task per line in your response.
Do not remove or modify any tasks.
The result must be a numbered list in the format:

#. First task
#. Second task

The entries must be consecutively numbered, starting with 1.
The number of each entry must be followed by a period.
Do not include any headers before your ranked list or follow your list \
with any other output."""
        )

        self.task_prioritization_prompt = task_prioritization_prompt.format(
            objective=objective
        )
        self.objective = objective

        system_message = BaseMessage(
            role_name="Task Prioritizer",
            role_type=RoleType.ASSISTANT,
            meta_dict=None,
            content="You are a helpful task prioritizer.",
        )

        super().__init__(
            system_message,
            model=model,
            output_language=output_language,
            message_window_size=message_window_size,
        )

    def run(
        self,
        task_list: List[str],
    ) -> List[str]:
        r"""Prioritize the task list given the agent objective.

        Args:
            task_list (List[str]): The unprioritized tasks of the agent.

        Returns:
            List[str]: The new prioritized task list generated by the Agent.
        """
        task_prioritization_prompt = self.task_prioritization_prompt.format(
            task_list=task_list
        )

        task_msg = BaseMessage.make_user_message(
            role_name="Task Prioritizer", content=task_prioritization_prompt
        )

        task_response = self.step(task_msg)

        if task_response.terminated:
            raise RuntimeError("Task prioritization failed.")
        if len(task_response.msgs) == 0:
            raise RuntimeError("Got no task prioritization message.")

        sub_tasks_msg = task_response.msgs[0]
        return get_task_list(sub_tasks_msg.content)
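A BabyAGI-style planning loop combining the two agents above (illustrative; the objective is invented and the two-round cap is arbitrary).

objective = "Write a survey of tool-augmented LLM agents"
creator = TaskCreationAgent(role_name="Researcher", objective=objective)
prioritizer = TaskPrioritizationAgent(objective=objective)

tasks = []
for _ in range(2):  # two planning rounds
    tasks.extend(creator.run(task_list=tasks))  # propose non-overlapping tasks
    tasks = prioritizer.run(task_list=tasks)    # re-rank against the objective
print(tasks)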
camel/agents/tool_agents/__init__.py
ADDED
@@ -0,0 +1,20 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .base import BaseToolAgent
from .hugging_face_tool_agent import HuggingFaceToolAgent

__all__ = [
    'BaseToolAgent',
    'HuggingFaceToolAgent',
]
camel/agents/tool_agents/base.py
ADDED
@@ -0,0 +1,39 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from camel.agents import BaseAgent


class BaseToolAgent(BaseAgent):
    r"""Creates a :obj:`BaseToolAgent` object with the specified name and
    description.

    Args:
        name (str): The name of the tool agent.
        description (str): The description of the tool agent.
    """

    def __init__(self, name: str, description: str) -> None:
        self.name = name
        self.description = description

    def reset(self) -> None:
        r"""Resets the agent to its initial state."""
        pass

    def step(self) -> None:
        r"""Performs a single step of the agent."""
        pass

    def __str__(self) -> str:
        return f"{self.name}: {self.description}"
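A minimal concrete subclass (illustrative; `EchoToolAgent` is a made-up name) showing the name/description contract that `HuggingFaceToolAgent` below also follows.

class EchoToolAgent(BaseToolAgent):
    def __init__(self) -> None:
        super().__init__(
            name="echo",
            description="The `echo` agent returns its input unchanged.",
        )

    def step(self, text: str) -> str:  # type: ignore[override]
        return text

print(EchoToolAgent())  # echo: The `echo` agent returns its input unchanged.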
camel/agents/tool_agents/hugging_face_tool_agent.py
ADDED
@@ -0,0 +1,206 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Any, Optional

from camel.agents.tool_agents.base import BaseToolAgent


# flake8: noqa :E501
class HuggingFaceToolAgent(BaseToolAgent):
    r"""Tool agent for calling HuggingFace models. This agent is a wrapper
    around agents from the `transformers` library. For more information
    about the available models, please see the `transformers` documentation
    at https://huggingface.co/docs/transformers/transformers_agents.

    Args:
        name (str): The name of the agent.
        *args (Any): Additional positional arguments to pass to the underlying
            Agent class.
        remote (bool, optional): Flag indicating whether to run the agent
            remotely. (default: :obj:`True`)
        **kwargs (Any): Additional keyword arguments to pass to the underlying
            Agent class.
    """

    def __init__(
        self,
        name: str,
        *args: Any,
        remote: bool = True,
        **kwargs: Any,
    ) -> None:
        try:
            # TODO: Support other tool agents
            import transformers
            from packaging import version

            if version.parse(transformers.__version__) < version.parse(
                "4.31.0"
            ):
                raise ValueError(
                    "The version of \"transformers\" package should be >= 4.31.0"
                )

            from transformers.tools import OpenAiAgent
            from transformers.tools.agent_types import AgentImage
        except (ImportError, ValueError):
            raise ValueError(
                "Could not import transformers tool agents. "
                "Please set up the environment with "
                "pip install huggingface_hub==0.14.1 transformers==4.31.0 diffusers accelerate==0.20.3 datasets torch soundfile sentencepiece opencv-python"
            )
        self.agent_image_type = AgentImage
        self.agent = OpenAiAgent(*args, **kwargs)
        description = f"""The `{name}` is a tool agent that can perform a variety of tasks including:
- Document question answering: given a document (such as a PDF) in image format, answer a question on this document
- Text question answering: given a long text and a question, answer the question in the text
- Unconditional image captioning: Caption the image!
- Image question answering: given an image, answer a question on this image
- Image segmentation: given an image and a prompt, output the segmentation mask of that prompt
- Speech to text: given an audio recording of a person talking, transcribe the speech into text
- Text to speech: convert text to speech
- Zero-shot text classification: given a text and a list of labels, identify to which label the text corresponds the most
- Text summarization: summarize a long text in one or a few sentences
- Translation: translate the text into a given language
- Text downloading: to download a text from a web URL
- Text to image: generate an image according to a prompt, leveraging stable diffusion
- Image transformation: modify an image given an initial image and a prompt, leveraging instruct pix2pix stable diffusion
- Text to video: generate a small video according to a prompt

Here are some python code examples of what you can do with this agent:

Single execution (step) mode, the single execution method is when using the step() method of the agent:
```
# Text to image
rivers_and_lakes_image = {name}.step("Draw me a picture of rivers and lakes.")
rivers_and_lakes_image.save("./rivers_and_lakes_image.png")

# Text to image -> Image transformation
sea_add_island_image = {name}.step("Draw me a picture of the sea then transform the picture to add an island")
sea_add_island_image.save("./sea_add_island_image.png")

# If you'd like to keep a state across executions or to pass non-text objects to the agent,
# you can do so by specifying variables that you would like the agent to use. For example,
# you could generate the first image of rivers and lakes, and ask the model to update that picture to add an island by doing the following:
picture = {name}.step("Generate a picture of rivers and lakes.")
picture.save("./picture.png")
updated_picture = {name}.step("Transform the image in `picture` to add an island to it.", picture=picture)
updated_picture.save("./updated_picture.png")

capybara_sea_image = {name}.step("Draw me a picture of the `prompt`", prompt="a capybara swimming in the sea")
capybara_sea_image.save("./capybara_sea_image.png")

# Document question answering
answer = {name}.step(
    "In the following `document`, where will the TRRF Scientific Advisory Council Meeting take place?",
    document=document,
)
print(answer)


# Text to image
boat_image = {name}.step("Generate an image of a boat in the water")
boat_image.save("./boat_image.png")

# Unconditional image captioning
boat_image_caption = {name}.step("Can you caption the `boat_image`?", boat_image=boat_image)
print(boat_image_caption)

# Text to image -> Unconditional image captioning -> Text to speech
boat_audio = {name}.step("Can you generate an image of a boat? Please read out loud the contents of the image afterwards")

# Text downloading
document = {name}.step("Download the text from http://hf.co")
print(document)

# Text summarization
summary = {name}.step("Summarize the following text: `document`", document=document)
print(summary)

# Text downloading -> Text summarization -> Text to speech
audio = {name}.step("Read out loud the summary of http://hf.co")
```

Chat-based execution (chat), the agent also has a chat-based approach, using the chat() method:
```
# Clean the chat history
{name}.reset()

# Text to image
capybara_image = {name}.chat("Show me an image of a capybara")
capybara_image.save("./capybara_image.png")

# Image transformation
transformed_capybara_image = {name}.chat("Transform the image so that it snows")
transformed_capybara_image.save("./transformed_capybara_image.png")

# Image segmentation
segmented_transformed_capybara_image = {name}.chat("Show me a mask of the snowy capybaras")
segmented_transformed_capybara_image.save("./segmented_transformed_capybara_image.png")
```
"""
        super(HuggingFaceToolAgent, self).__init__(name, description)
        self.remote = remote

    def reset(self) -> None:
        r"""Resets the chat history of the agent."""
        self.agent.prepare_for_new_chat()

    def step(
        self,
        *args: Any,
        remote: Optional[bool] = None,
        **kwargs: Any,
    ) -> Any:
        r"""Runs the agent in single execution mode.

        Args:
            *args (Any): Positional arguments to pass to the agent.
            remote (bool, optional): Flag indicating whether to run the agent
                remotely. Overrides the default setting. (default: :obj:`None`)
            **kwargs (Any): Keyword arguments to pass to the agent.

        Returns:
            str: The response from the agent.
        """
        if remote is None:
            remote = self.remote
        agent_output = self.agent.run(*args, remote=remote, **kwargs)
        if isinstance(agent_output, self.agent_image_type):
            agent_output = agent_output.to_raw()
        return agent_output

    def chat(
        self,
        *args: Any,
        remote: Optional[bool] = None,
        **kwargs: Any,
    ) -> Any:
        r"""Runs the agent in a chat conversation mode.

        Args:
            *args (Any): Positional arguments to pass to the agent.
            remote (bool, optional): Flag indicating whether to run the agent
                remotely. Overrides the default setting. (default: :obj:`None`)
            **kwargs (Any): Keyword arguments to pass to the agent.

        Returns:
            str: The response from the agent.
        """
        if remote is None:
            remote = self.remote
        agent_output = self.agent.chat(*args, remote=remote, **kwargs)
        if isinstance(agent_output, self.agent_image_type):
            agent_output = agent_output.to_raw()
        return agent_output
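An instantiation sketch (illustrative): positional and keyword arguments after `name` are forwarded to transformers' `OpenAiAgent`, so the `model` and `api_key` values here are placeholders, not required settings.

hf_agent = HuggingFaceToolAgent(
    "hf_agent",              # name used inside the generated description
    model="gpt-3.5-turbo",   # forwarded to OpenAiAgent
    api_key="sk-...",        # forwarded to OpenAiAgent
    remote=True,
)
image = hf_agent.step("Draw me a picture of rivers and lakes.")
image.save("./rivers_and_lakes.png")  # AgentImage is unwrapped via to_raw()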
camel/benchmarks/__init__.py
ADDED
@@ -0,0 +1,30 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

from .apibank import APIBankBenchmark
from .apibench import APIBenchBenchmark
from .base import BaseBenchmark
from .gaia import DefaultGAIARetriever, GAIABenchmark
from .nexus import NexusBenchmark
from .ragbench import RAGBenchBenchmark

__all__ = [
    "BaseBenchmark",
    "GAIABenchmark",
    "DefaultGAIARetriever",
    "NexusBenchmark",
    "APIBenchBenchmark",
    "APIBankBenchmark",
    "RAGBenchBenchmark",
]
camel/benchmarks/apibank.py
ADDED
@@ -0,0 +1,565 @@
| 1 |
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
#
|
| 6 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 7 |
+
#
|
| 8 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 9 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 10 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 11 |
+
# See the License for the specific language governing permissions and
|
| 12 |
+
# limitations under the License.
|
| 13 |
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
| 14 |
+
|
| 15 |
+
import json
|
| 16 |
+
import logging
|
| 17 |
+
import os
|
| 18 |
+
import random
|
| 19 |
+
import re
|
| 20 |
+
import sys
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
from typing import Any, Dict, List, Literal, Optional
|
| 23 |
+
|
| 24 |
+
import numpy as np
|
| 25 |
+
from rouge import Rouge
|
| 26 |
+
from tqdm import tqdm
|
| 27 |
+
|
| 28 |
+
from camel.agents import ChatAgent
|
| 29 |
+
from camel.benchmarks.base import BaseBenchmark
|
| 30 |
+
from camel.messages import BaseMessage
|
| 31 |
+
from camel.utils import download_github_subdirectory
|
| 32 |
+
|
| 33 |
+
logger = logging.getLogger(__name__)
|
| 34 |
+
|
| 35 |
+
# Add current folder to sys.path to enable relative import
|
| 36 |
+
current_folder = os.getcwd()
|
| 37 |
+
if current_folder not in sys.path:
|
| 38 |
+
sys.path.append(current_folder)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def process_messages(
|
| 42 |
+
chat_history: List[Dict[str, Any]],
|
| 43 |
+
prompt: str,
|
| 44 |
+
) -> List[Dict[str, str]]:
|
| 45 |
+
"""
|
| 46 |
+
Processes chat history into a structured format for further use.
|
| 47 |
+
|
| 48 |
+
Args:
|
| 49 |
+
chat_history (List[Dict[str, Any]]):
|
| 50 |
+
A list of dictionaries representing the chat history.
|
| 51 |
+
prompt (str): A prompt to be set as the system message.
|
| 52 |
+
|
| 53 |
+
Returns:
|
| 54 |
+
List[Dict[str, str]]: A list of dictionaries representing
|
| 55 |
+
the processed messages, where each dictionary has:
|
| 56 |
+
- 'role': The role of the message ('system', 'user', or 'assistant').
|
| 57 |
+
- 'content': The content of the message, including formatted
|
| 58 |
+
API responses when applicable.
|
| 59 |
+
"""
|
| 60 |
+
messages = [{'role': 'system', 'content': prompt}]
|
| 61 |
+
for item in chat_history:
|
| 62 |
+
role_map = {'User': 'user', 'AI': 'assistant', 'API': 'system'}
|
| 63 |
+
chat_role = role_map.get(
|
| 64 |
+
item['role'], 'unknown'
|
| 65 |
+
) # default role to 'unknown'
|
| 66 |
+
if item['role'] == 'API':
|
| 67 |
+
chat_content = '[{}({})] Response: {}'.format(
|
| 68 |
+
item['api_name'],
|
| 69 |
+
', '.join(
|
| 70 |
+
[
|
| 71 |
+
'{}=\'{}\''.format(k, v)
|
| 72 |
+
for k, v in item['param_dict'].items()
|
| 73 |
+
]
|
| 74 |
+
),
|
| 75 |
+
str(item['result']['output']),
|
| 76 |
+
)
|
| 77 |
+
else:
|
| 78 |
+
chat_content = item['text']
|
| 79 |
+
messages.append({'role': chat_role, 'content': chat_content})
|
| 80 |
+
return messages
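For illustration, this is what the function produces for a tiny, hypothetical history (the 'API' entry is flattened into a single system-role string; all names and values below are made up, but the field names follow the chat-history schema used above):

# Hypothetical input following the API-Bank chat-history schema.
history = [
    {'role': 'User', 'text': 'What is the weather in Paris?'},
    {
        'role': 'API',
        'api_name': 'GetWeather',
        'param_dict': {'city': 'Paris'},
        'result': {'output': 'Sunny, 21C'},
    },
]
msgs = process_messages(history, prompt='You are a helpful assistant.')
# msgs[0] -> {'role': 'system', 'content': 'You are a helpful assistant.'}
# msgs[1] -> {'role': 'user', 'content': 'What is the weather in Paris?'}
# msgs[2] -> {'role': 'system',
#             'content': "[GetWeather(city='Paris')] Response: Sunny, 21C"}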
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class APIBankBenchmark(BaseBenchmark):
|
| 84 |
+
r"""API-Bank Benchmark adapted from `API-Bank:
|
| 85 |
+
A Comprehensive Benchmark for Tool-Augmented LLMs`
|
| 86 |
+
<https://github.com/AlibabaResearch/DAMO-ConvAI/tree/main/api-bank>.
|
| 87 |
+
|
| 88 |
+
Args:
|
| 89 |
+
save_to (str): The file to save the results.
|
| 90 |
+
processes (int, optional): The number of processes to use.
|
| 91 |
+
(default: :obj:`1`)
|
| 92 |
+
"""
|
| 93 |
+
|
| 94 |
+
def __init__(
|
| 95 |
+
self,
|
| 96 |
+
save_to: str,
|
| 97 |
+
processes: int = 1,
|
| 98 |
+
):
|
| 99 |
+
r"""Initialize the APIBank benchmark.
|
| 100 |
+
|
| 101 |
+
Args:
|
| 102 |
+
save_to (str): The file to save the results.
|
| 103 |
+
processes (int, optional): The number of processes to use for
|
| 104 |
+
parallel processing. (default: :obj:`1`)
|
| 105 |
+
"""
|
| 106 |
+
# data_dir is fixed to "api_bank" so the downloaded modules can be imported
|
| 107 |
+
super().__init__("apibank", "api_bank", save_to, processes)
|
| 108 |
+
self._data: Dict[str, List[APIBankSample]] = dict() # type: ignore[assignment]
|
| 109 |
+
|
| 110 |
+
def download(self):
|
| 111 |
+
r"""Download APIBank dataset and code from Github."""
|
| 112 |
+
|
| 113 |
+
repo = "AlibabaResearch/DAMO-ConvAI"
|
| 114 |
+
subdir = "api-bank"
|
| 115 |
+
data_dir = self.data_dir
|
| 116 |
+
|
| 117 |
+
download_github_subdirectory(repo, subdir, data_dir)
|
| 118 |
+
|
| 119 |
+
sys.path.insert(0, self.data_dir)
|
| 120 |
+
logger.info("Download completed.")
|
| 121 |
+
|
| 122 |
+
def load(self, level: str, force_download: bool = False): # type: ignore[override]
|
| 123 |
+
r"""Load the APIBank Benchmark dataset.
|
| 124 |
+
|
| 125 |
+
Args:
|
| 126 |
+
level (str): Level to run benchmark on.
|
| 127 |
+
force_download (bool, optional): Whether to
|
| 128 |
+
force download the data.
|
| 129 |
+
"""
|
| 130 |
+
if force_download:
|
| 131 |
+
logger.info("Force downloading data.")
|
| 132 |
+
self.download()
|
| 133 |
+
|
| 134 |
+
if level == "level-1":
|
| 135 |
+
file_path = Path("api_bank/lv1-lv2-samples/level-1-given-desc")
|
| 136 |
+
elif level == 'level-2':
|
| 137 |
+
file_path = Path("api_bank/lv1-lv2-samples/level-2-toolsearcher")
|
| 138 |
+
jsonl_files = [
|
| 139 |
+
f for f in os.listdir(file_path) if f.endswith('.jsonl')
|
| 140 |
+
]
|
| 141 |
+
for file in tqdm(jsonl_files, desc="Processing files"):
|
| 142 |
+
history = []
|
| 143 |
+
with open(file_path / file, 'r') as f:
|
| 144 |
+
for line in f:
|
| 145 |
+
history.append(json.loads(line))
|
| 146 |
+
samples = APIBankSample.from_chat_history(history)
|
| 147 |
+
self._data[file.rsplit('.', 1)[0]] = samples
|
| 148 |
+
|
| 149 |
+
# Change import to relative import in the downloaded python files
|
| 150 |
+
def process_files(folder_path, replacements):
|
| 151 |
+
r"""Replace absolute imports in downloaded files with
|
| 152 |
+
relative import."""
|
| 153 |
+
for file in os.listdir(folder_path):
|
| 154 |
+
if file.endswith(".py"):
|
| 155 |
+
file_path = os.path.join(folder_path, file)
|
| 156 |
+
try:
|
| 157 |
+
with open(file_path, "r", encoding="utf-8") as file:
|
| 158 |
+
content = file.read()
|
| 159 |
+
|
| 160 |
+
original_content = content
|
| 161 |
+
|
| 162 |
+
for pattern, replacement in replacements:
|
| 163 |
+
content = re.sub(pattern, replacement, content)
|
| 164 |
+
|
| 165 |
+
if content != original_content:
|
| 166 |
+
with open(
|
| 167 |
+
file_path, "w", encoding="utf-8"
|
| 168 |
+
) as file:
|
| 169 |
+
file.write(content)
|
| 170 |
+
logger.info(f"Updated file: {file_path}")
|
| 171 |
+
|
| 172 |
+
except Exception as e:
|
| 173 |
+
logger.info(f"Error processing file {file_path}: {e}")
|
| 174 |
+
|
| 175 |
+
api_bank_folder = "api_bank"
|
| 176 |
+
apis_folder = os.path.join(api_bank_folder, "apis")
|
| 177 |
+
|
| 178 |
+
apis_replacements = [
|
| 179 |
+
(r"from apis.api", "from .api"),
|
| 180 |
+
(r"from apis import", "from .api import"),
|
| 181 |
+
]
|
| 182 |
+
|
| 183 |
+
api_bank_replacements = [
|
| 184 |
+
(r"from apis", "from .apis"),
|
| 185 |
+
(r"from api_call_extraction", "from .api_call_extraction"),
|
| 186 |
+
(r"f'{basename}", r"f'api_bank.{basename}"),
|
| 187 |
+
]
|
| 188 |
+
|
| 189 |
+
process_files(apis_folder, apis_replacements)
|
| 190 |
+
process_files(api_bank_folder, api_bank_replacements)
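Each replacement pair above is applied with re.sub. As a quick illustration of the first api_bank rule (the input line is hypothetical):

import re

line = "from apis import ToolSearcher"
print(re.sub(r"from apis", "from .apis", line))
# -> from .apis import ToolSearcher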
|
| 191 |
+
|
| 192 |
+
def run( # type: ignore[override, return]
|
| 193 |
+
self,
|
| 194 |
+
agent: ChatAgent,
|
| 195 |
+
level: Literal["level-1", "level-2"],
|
| 196 |
+
api_test_enabled: bool = True,
|
| 197 |
+
randomize: bool = False,
|
| 198 |
+
subset: Optional[int] = None,
|
| 199 |
+
) -> Dict[str, Any]:
|
| 200 |
+
r"""Run the benchmark.
|
| 201 |
+
|
| 202 |
+
Args:
|
| 203 |
+
agent (ChatAgent): The agent to run the
|
| 204 |
+
benchmark.
|
| 205 |
+
level (Literal['level-1', 'level-2']):
|
| 206 |
+
The level to run the benchmark on.
|
| 207 |
+
randomize (bool, optional): Whether to
|
| 208 |
+
randomize the data.
|
| 209 |
+
api_test_enabled (bool): Whether to test
|
| 210 |
+
API calling (`True`) or response (`False`)
|
| 211 |
+
(default: :obj:`True`)
|
| 212 |
+
subset (Optional[int], optional):
|
| 213 |
+
The subset of data to run.
|
| 214 |
+
(default: :obj:`None`)
|
| 215 |
+
|
| 216 |
+
Returns:
|
| 217 |
+
Dict[str, Any]: The results of the benchmark.
|
| 218 |
+
"""
|
| 219 |
+
logger.info(f"Running APIBench benchmark on {level}.")
|
| 220 |
+
self.load(level)
|
| 221 |
+
datas = self._data
|
| 222 |
+
|
| 223 |
+
# Shuffle and subset data if necessary
|
| 224 |
+
if randomize:
|
| 225 |
+
randomized_items = list(datas.items())
|
| 226 |
+
random.shuffle(randomized_items)
|
| 227 |
+
datas = dict(randomized_items)
|
| 228 |
+
if subset:
|
| 229 |
+
datas = dict(list(datas.items())[:subset])
|
| 230 |
+
|
| 231 |
+
logger.info(f"Number of tasks: {len(datas)}")
|
| 232 |
+
|
| 233 |
+
# Initialize results storage
|
| 234 |
+
self._results = []
|
| 235 |
+
|
| 236 |
+
# The following code is adapted from the evaluator
|
| 237 |
+
# from the original repo:
|
| 238 |
+
tool_search_enabled = level == "level-2"
|
| 239 |
+
dialog_test_enabled = not api_test_enabled
|
| 240 |
+
total_api_calls, correct_api_calls, rougel_scores = 0, 0, []
|
| 241 |
+
|
| 242 |
+
with open(self.save_to, "w") as f:
|
| 243 |
+
for test in tqdm(datas, desc="Running"):
|
| 244 |
+
samples = self._data[test]
|
| 245 |
+
evaluator = Evaluator(samples) # type: ignore[arg-type]
|
| 246 |
+
|
| 247 |
+
for sample_id in evaluator.get_all_sample_ids():
|
| 248 |
+
# Process sample and generate response
|
| 249 |
+
sample = evaluator.dataset[sample_id]
|
| 250 |
+
|
| 251 |
+
if (
|
| 252 |
+
sample.ground_truth['role'] == 'API'
|
| 253 |
+
and api_test_enabled
|
| 254 |
+
):
|
| 255 |
+
if tool_search_enabled:
|
| 256 |
+
_, chat_history = evaluator.get_model_input(
|
| 257 |
+
sample_id
|
| 258 |
+
)
|
| 259 |
+
api_descriptions = evaluator.get_api_description(
|
| 260 |
+
'ToolSearcher'
|
| 261 |
+
)
|
| 262 |
+
else:
|
| 263 |
+
api_descriptions, chat_history = (
|
| 264 |
+
evaluator.get_model_input(sample_id)
|
| 265 |
+
)
|
| 266 |
+
messages = process_messages(
|
| 267 |
+
chat_history, API_CALL_PROMPT + api_descriptions
|
| 268 |
+
)
|
| 269 |
+
model_output = agent_call(messages, agent)
|
| 270 |
+
api_call = get_api_call(model_output)
|
| 271 |
+
|
| 272 |
+
# Evaluate API call
|
| 273 |
+
if api_call:
|
| 274 |
+
try:
|
| 275 |
+
correct, model_output_result = (
|
| 276 |
+
evaluator.evaluate(sample_id, api_call)
|
| 277 |
+
)
|
| 278 |
+
except AssertionError as e:
|
| 279 |
+
if 'The API name is not correct.' not in str(
|
| 280 |
+
e
|
| 281 |
+
):
|
| 282 |
+
raise e
|
| 283 |
+
logging.info('AssertionError: {}'.format(e))
|
| 284 |
+
correct = False
|
| 285 |
+
else:
|
| 286 |
+
model_output_result = 'No API call found'
|
| 287 |
+
correct = False
|
| 288 |
+
if correct:
|
| 289 |
+
correct_api_calls += 1
|
| 290 |
+
logging.info(
|
| 291 |
+
'Correct API call: {} Ground truth: {}'.format(
|
| 292 |
+
api_call, sample.ground_truth
|
| 293 |
+
)
|
| 294 |
+
)
|
| 295 |
+
else:
|
| 296 |
+
logging.info(
|
| 297 |
+
'Incorrect model output: {} Result: {} \
|
| 298 |
+
Ground truth: {} File: {} Sample ID: {} \
|
| 299 |
+
Messages: {}'.format(
|
| 300 |
+
model_output.replace('\n', ' '),
|
| 301 |
+
model_output_result,
|
| 302 |
+
sample.ground_truth,
|
| 303 |
+
test,
|
| 304 |
+
sample_id,
|
| 305 |
+
messages[1:],
|
| 306 |
+
)
|
| 307 |
+
)
|
| 308 |
+
total_api_calls += 1
|
| 309 |
+
self._results.append(
|
| 310 |
+
{
|
| 311 |
+
'Role': 'API',
|
| 312 |
+
'Model_output': model_output,
|
| 313 |
+
'Model_output_result': model_output_result,
|
| 314 |
+
'Ground_truth': sample.ground_truth,
|
| 315 |
+
'Test': test,
|
| 316 |
+
'Correct': correct,
|
| 317 |
+
}
|
| 318 |
+
)
|
| 319 |
+
f.write(json.dumps(self._results[-1], indent=2) + "\n")
|
| 320 |
+
|
| 321 |
+
elif (
|
| 322 |
+
sample.ground_truth['role'] == 'AI'
|
| 323 |
+
and dialog_test_enabled
|
| 324 |
+
):
|
| 325 |
+
# Process sample and generate response
|
| 326 |
+
api_descriptions, chat_history = (
|
| 327 |
+
evaluator.get_model_input(sample_id)
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
messages = process_messages(
|
| 331 |
+
chat_history, RESPONSE_PROMPT + api_descriptions
|
| 332 |
+
)
|
| 333 |
+
model_output = agent_call(messages, agent)
|
| 334 |
+
|
| 335 |
+
# Evaluate model response
|
| 336 |
+
if model_output:
|
| 337 |
+
score = evaluator.evaluate(sample_id, model_output)
|
| 338 |
+
else:
|
| 339 |
+
score = 0
|
| 340 |
+
rougel_scores.append(score)
|
| 341 |
+
if score < 0.2:
|
| 342 |
+
logging.info(
|
| 343 |
+
'Low score: {} Score: {} Ground truth: {} \
|
| 344 |
+
Test: {} Sample ID: {} \
|
| 345 |
+
Messages: {}'.format(
|
| 346 |
+
model_output.replace('\n', ' '),
|
| 347 |
+
score,
|
| 348 |
+
sample.ground_truth,
|
| 349 |
+
test,
|
| 350 |
+
sample_id,
|
| 351 |
+
messages[1:],
|
| 352 |
+
)
|
| 353 |
+
)
|
| 354 |
+
|
| 355 |
+
self._results.append(
|
| 356 |
+
{
|
| 357 |
+
'Role': 'AI',
|
| 358 |
+
'Model_output': model_output,
|
| 359 |
+
'Score': score,
|
| 360 |
+
'Ground_truth': sample.ground_truth,
|
| 361 |
+
'Test': test,
|
| 362 |
+
}
|
| 363 |
+
)
|
| 364 |
+
f.write(json.dumps(self._results[-1], indent=2) + "\n")
|
| 365 |
+
|
| 366 |
+
f.flush()
|
| 367 |
+
|
| 368 |
+
if api_test_enabled:
|
| 369 |
+
return {
|
| 370 |
+
'total': total_api_calls,
|
| 371 |
+
'correct': correct_api_calls,
|
| 372 |
+
"accuracy": correct_api_calls / total_api_calls
|
| 373 |
+
if total_api_calls
|
| 374 |
+
else 0,
|
| 375 |
+
}
|
| 376 |
+
elif dialog_test_enabled:
|
| 377 |
+
return {'Dialog_score': np.mean(rougel_scores)}
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
# The following code is migrated from the original repo:
|
| 381 |
+
# https://github.com/AlibabaResearch/DAMO-ConvAI/tree/main/api-bank
|
| 382 |
+
def agent_call(messages: List[Dict], agent: ChatAgent):
|
| 383 |
+
r"""Add messages to agent memory and get response."""
|
| 384 |
+
for i, msg in enumerate(messages):
|
| 385 |
+
if msg['role'] == 'user':
|
| 386 |
+
message = BaseMessage.make_user_message(
|
| 387 |
+
role_name="CAMEL User", content=msg['content']
|
| 388 |
+
)
|
| 389 |
+
elif msg['role'] == 'assistant':
|
| 390 |
+
message = BaseMessage.make_assistant_message(
|
| 391 |
+
role_name="CAMEL Assistant", content=msg['content']
|
| 392 |
+
)
|
| 393 |
+
elif msg['role'] == 'system':
|
| 394 |
+
message = BaseMessage.make_assistant_message(
|
| 395 |
+
role_name="System", content=msg['content']
|
| 396 |
+
)
|
| 397 |
+
else:
|
| 398 |
+
raise ValueError(f"Unrecognized role: {msg['role']}")
|
| 399 |
+
|
| 400 |
+
if i == len(messages) - 1:
|
| 401 |
+
break
|
| 402 |
+
agent.record_message(message)
|
| 403 |
+
|
| 404 |
+
response = agent.step(message)
|
| 405 |
+
model_output = response.msgs[0].content
|
| 406 |
+
agent.reset()
|
| 407 |
+
return model_output
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
def calculate_rouge_l_score(reference, hypothesis):
|
| 411 |
+
r"""Calculate rouge l score between hypothesis and reference."""
|
| 412 |
+
rouge = Rouge()
|
| 413 |
+
scores = rouge.get_scores(hypothesis, reference)
|
| 414 |
+
rouge_l_score = scores[0]['rouge-l']['f']
|
| 415 |
+
return rouge_l_score
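A quick sanity check of this helper (the rouge package reports precision, recall, and F1 per metric; the ROUGE-L F1 is what the benchmark records):

print(calculate_rouge_l_score('the cat sat', 'the cat sat'))  # close to 1.0
print(calculate_rouge_l_score('the cat sat', 'a dog ran'))    # 0.0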
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
def get_api_call(model_output):
|
| 419 |
+
r"""Parse api call from model output."""
|
| 420 |
+
api_call_pattern = r"\[(\w+)\((.*)\)\]"
|
| 421 |
+
api_call_pattern = re.compile(api_call_pattern)
|
| 422 |
+
match = api_call_pattern.search(model_output)
|
| 423 |
+
if match:
|
| 424 |
+
return match.group(0)
|
| 425 |
+
else:
|
| 426 |
+
return None
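For example, the pattern picks out the first bracketed call embedded in free-form model output (the text below is made up):

text = "Sure, calling it now: [GetWeather(city='Paris')] done."
print(get_api_call(text))            # [GetWeather(city='Paris')]
print(get_api_call('no call here'))  # None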
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
class APIBankSample:
|
| 430 |
+
r"""APIBank sample used to load the datasets."""
|
| 431 |
+
|
| 432 |
+
def __init__(self, chat_history, apis, ground_truth):
|
| 433 |
+
self.chat_history = chat_history
|
| 434 |
+
self.apis = apis
|
| 435 |
+
self.ground_truth = ground_truth
|
| 436 |
+
|
| 437 |
+
def __repr__(self):
|
| 438 |
+
return 'Sample(chat_history={}, apis={}, ground_truth={})'.format(
|
| 439 |
+
self.chat_history, self.apis, self.ground_truth
|
| 440 |
+
)
|
| 441 |
+
|
| 442 |
+
@classmethod
|
| 443 |
+
def from_chat_history(cls, chat_history):
|
| 444 |
+
apis = set()
|
| 445 |
+
api_positions = []
|
| 446 |
+
for i, item in enumerate(chat_history):
|
| 447 |
+
if item['role'] == 'API':
|
| 448 |
+
apis.add(item['api_name'])
|
| 449 |
+
api_positions.append(i)
|
| 450 |
+
|
| 451 |
+
samples = []
|
| 452 |
+
for i in api_positions:
|
| 453 |
+
sample = cls(chat_history[:i], apis, chat_history[i])
|
| 454 |
+
samples.append(sample)
|
| 455 |
+
sample = cls(chat_history[: i + 1], apis, chat_history[i + 1])
|
| 456 |
+
samples.append(sample)
|
| 457 |
+
|
| 458 |
+
return samples
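In other words, every 'API' turn at position i yields two samples: one whose ground truth is the API call itself (chat_history[i]) and one whose ground truth is the AI reply that follows it (chat_history[i + 1]); this assumes, as in the dataset, that each API turn is followed by an AI turn. A made-up three-turn history illustrates this:

history = [
    {'role': 'User', 'text': 'Book a table.'},
    {'role': 'API', 'api_name': 'BookTable', 'param_dict': {},
     'result': {'output': 'ok'}},
    {'role': 'AI', 'text': 'Your table is booked.'},
]
samples = APIBankSample.from_chat_history(history)
print(len(samples))                     # 2
print(samples[0].ground_truth['role'])  # API
print(samples[1].ground_truth['role'])  # AI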
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
class Evaluator:
|
| 462 |
+
r"""Evaluator for APIBank benchmark."""
|
| 463 |
+
|
| 464 |
+
def __init__(self, samples: List[APIBankSample]):
|
| 465 |
+
# Placeholder for import, as the import
|
| 466 |
+
# only works after the files have been downloaded
|
| 467 |
+
try:
|
| 468 |
+
from api_bank.tool_manager import ( # type: ignore[import-not-found]
|
| 469 |
+
ToolManager,
|
| 470 |
+
)
|
| 471 |
+
except Exception as e:
|
| 472 |
+
logger.info(f"{e}, Module will be imported after download.")
|
| 473 |
+
self.dataset = samples
|
| 474 |
+
self.sample_ids = list(range(len(self.dataset)))
|
| 475 |
+
os.chdir("api_bank")
|
| 476 |
+
self.tool_manager = ToolManager("apis")
|
| 477 |
+
os.chdir("..")
|
| 478 |
+
|
| 479 |
+
def get_all_sample_ids(self):
|
| 480 |
+
return self.sample_ids
|
| 481 |
+
|
| 482 |
+
def get_api_description(self, api_name):
|
| 483 |
+
return self.tool_manager.get_api_description(api_name)
|
| 484 |
+
|
| 485 |
+
def get_model_input(self, sample_id: int):
|
| 486 |
+
sample = self.dataset[sample_id]
|
| 487 |
+
apis = sample.apis
|
| 488 |
+
chat_history = sample.chat_history
|
| 489 |
+
api_descriptions = []
|
| 490 |
+
for api_name in apis:
|
| 491 |
+
api_descriptions.append(
|
| 492 |
+
self.tool_manager.get_api_description(api_name)
|
| 493 |
+
)
|
| 494 |
+
api_description = '\n'.join(api_descriptions)
|
| 495 |
+
return api_description, chat_history
|
| 496 |
+
|
| 497 |
+
def evaluate(self, sample_id, model_output):
|
| 498 |
+
try:
|
| 499 |
+
from api_bank.api_call_extraction import ( # type: ignore[import-not-found]
|
| 500 |
+
parse_api_call,
|
| 501 |
+
)
|
| 502 |
+
except Exception as e:
|
| 503 |
+
logger.info(f"{e}, Module will be imported after download.")
|
| 504 |
+
sample = self.dataset[sample_id]
|
| 505 |
+
ground_truth = sample.ground_truth
|
| 506 |
+
if ground_truth['role'] == 'API':
|
| 507 |
+
api_name, param_dict = parse_api_call(model_output)
|
| 508 |
+
if api_name != ground_truth['api_name']:
|
| 509 |
+
return False, 'API Name Mismatch: {} vs {}'.format(
|
| 510 |
+
api_name, ground_truth['api_name']
|
| 511 |
+
)
|
| 512 |
+
try:
|
| 513 |
+
result = self.tool_manager.api_call(api_name, **param_dict)
|
| 514 |
+
except Exception as e:
|
| 515 |
+
return False, str(e)
|
| 516 |
+
api = self.tool_manager.init_tool(api_name)
|
| 517 |
+
try:
|
| 518 |
+
correct = api.check_api_call_correctness(
|
| 519 |
+
result, ground_truth['result']
|
| 520 |
+
)
|
| 521 |
+
except KeyError:
|
| 522 |
+
correct = False
|
| 523 |
+
result = 'KeyError' + str(result)
|
| 524 |
+
return correct, result
|
| 525 |
+
elif ground_truth['role'] == 'AI':
|
| 526 |
+
score = calculate_rouge_l_score(ground_truth['text'], model_output)
|
| 527 |
+
return round(score, 4)
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
API_CALL_PROMPT = '''
|
| 531 |
+
Based on the given API description and the existing \
|
| 532 |
+
conversation history 1..t, please generate the API request \
|
| 533 |
+
that the AI should call in step t+1 and output it in the \
|
| 534 |
+
format of [ApiName(key1='value1', key2='value2', ...)], \
|
| 535 |
+
replace the ApiName with the actual API name, and \
|
| 536 |
+
replace the key and value with the actual parameters. \
|
| 537 |
+
Your output should start with a square bracket "[" \
|
| 538 |
+
and end with a square bracket "]". Do not output any \
|
| 539 |
+
other explanation or prompt or the result of the API call in your output.
|
| 540 |
+
This year is 2023.
|
| 541 |
+
Input:
|
| 542 |
+
User: [User's utterance]
|
| 543 |
+
AI: [AI's utterance]
|
| 544 |
+
|
| 545 |
+
Expected output:
|
| 546 |
+
[ApiName(key1='value1', key2='value2', ...)]
|
| 547 |
+
|
| 548 |
+
API descriptions:
|
| 549 |
+
'''
|
| 550 |
+
|
| 551 |
+
RESPONSE_PROMPT = '''
|
| 552 |
+
Based on the given API description and the existing \
|
| 553 |
+
conversation history 1..t, please generate the next \
|
| 554 |
+
dialog that the AI should give in response to the API call t.
|
| 555 |
+
This year is 2023.
|
| 556 |
+
Input:
|
| 557 |
+
User: [User's utterance]
|
| 558 |
+
AI: [AI's utterance]
|
| 559 |
+
[ApiName(key1='value1', key2='value2', …)]
|
| 560 |
+
|
| 561 |
+
Expected output:
|
| 562 |
+
AI: [AI's utterance]
|
| 563 |
+
|
| 564 |
+
API descriptions:
|
| 565 |
+
'''
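Putting the pieces together, a minimal driver could look like the sketch below. This is illustrative only: the agent's system message and the save path are placeholders, and download() fetches the dataset and evaluation code on first use.

from camel.agents import ChatAgent

agent = ChatAgent("You are a helpful assistant.")
benchmark = APIBankBenchmark(save_to="apibank_results.jsonl")
benchmark.download()
result = benchmark.run(agent, level="level-1", api_test_enabled=True, subset=5)
print(result)  # {'total': ..., 'correct': ..., 'accuracy': ...}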
|
camel/benchmarks/apibench.py
ADDED
|
@@ -0,0 +1,500 @@
| 1 |
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
#
|
| 6 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 7 |
+
#
|
| 8 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 9 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 10 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 11 |
+
# See the License for the specific language governing permissions and
|
| 12 |
+
# limitations under the License.
|
| 13 |
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
| 14 |
+
|
| 15 |
+
import json
|
| 16 |
+
import logging
|
| 17 |
+
import random
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
from typing import Any, Dict, Literal, Optional
|
| 20 |
+
|
| 21 |
+
import tree_sitter_python as tspython
|
| 22 |
+
from tqdm import tqdm
|
| 23 |
+
from tree_sitter import Language, Parser
|
| 24 |
+
|
| 25 |
+
from camel.agents import ChatAgent
|
| 26 |
+
from camel.benchmarks.base import BaseBenchmark
|
| 27 |
+
from camel.messages import BaseMessage
|
| 28 |
+
from camel.utils import download_github_subdirectory
|
| 29 |
+
|
| 30 |
+
logger = logging.getLogger(__name__)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# Mapping of dataset names to file names
|
| 34 |
+
# The 'Oracle' retriever is used here, which means the full
|
| 35 |
+
# API documentation will be included in the prompt
|
| 36 |
+
dataset_mapping = {
|
| 37 |
+
"huggingface": {
|
| 38 |
+
"api": "huggingface_api.jsonl",
|
| 39 |
+
"eval": "huggingface_eval.json",
|
| 40 |
+
"train": "huggingface_train.json",
|
| 41 |
+
"questions": "questions_huggingface_oracle.jsonl",
|
| 42 |
+
},
|
| 43 |
+
"tensorflowhub": {
|
| 44 |
+
"api": "tensorflowhub_api.jsonl",
|
| 45 |
+
"eval": "tensorflow_eval.json",
|
| 46 |
+
"train": "tensorflow_train.json",
|
| 47 |
+
"questions": "questions_tensorflowhub_oracle.jsonl",
|
| 48 |
+
},
|
| 49 |
+
"torchhub": {
|
| 50 |
+
"api": "torchhub_api.jsonl",
|
| 51 |
+
"eval": "torchhub_eval.json",
|
| 52 |
+
"train": "torchhub_train.json",
|
| 53 |
+
"questions": "questions_torchhub_oracle.jsonl",
|
| 54 |
+
},
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# This function is migrated from the original repo:
|
| 59 |
+
# https://github.com/ShishirPatil/gorilla
|
| 60 |
+
def encode_question(question: str, dataset_name: str) -> str:
|
| 61 |
+
r"""Encode multiple prompt instructions into a single string."""
|
| 62 |
+
|
| 63 |
+
if dataset_name == "torchhub":
|
| 64 |
+
domains = "1. $DOMAIN is inferred from the task description and \
|
| 65 |
+
should include one of {Classification, Semantic Segmentation, \
|
| 66 |
+
Object Detection, Audio Separation, Video Classification, \
|
| 67 |
+
Text-to-Speech}."
|
| 68 |
+
elif dataset_name == "huggingface":
|
| 69 |
+
domains = "1. $DOMAIN should include one of {Multimodal Feature \
|
| 70 |
+
Extraction, Multimodal Text-to-Image, Multimodal \
|
| 71 |
+
Image-to-Text, Multimodal Text-to-Video, \
|
| 72 |
+
Multimodal Visual Question Answering, Multimodal Document \
|
| 73 |
+
Question Answer, Multimodal Graph Machine Learning, \
|
| 74 |
+
Computer Vision Depth Estimation, Computer Vision Image \
|
| 75 |
+
Classification, Computer Vision Object Detection, \
|
| 76 |
+
Computer Vision Image Segmentation, Computer Vision \
|
| 77 |
+
Image-to-Image, Computer Vision Unconditional \
|
| 78 |
+
Image Generation, Computer Vision Video Classification, \
|
| 79 |
+
Computer Vision Zero-Shot Image Classification, \
|
| 80 |
+
Natural Language Processing Text Classification, \
|
| 81 |
+
Natural Language Processing Token Classification, \
|
| 82 |
+
Natural Language Processing Table Question Answering, \
|
| 83 |
+
Natural Language Processing Question Answering, \
|
| 84 |
+
Natural Language Processing Zero-Shot Classification, \
|
| 85 |
+
Natural Language Processing Translation, Natural Language \
|
| 86 |
+
Processing Summarization, Natural Language Processing \
|
| 87 |
+
Conversational, Natural Language Processing Text \
|
| 88 |
+
Generation, Natural Language Processing Fill-Mask, \
|
| 89 |
+
Natural Language Processing Text2Text Generation, \
|
| 90 |
+
Natural Language Processing Sentence Similarity, \
|
| 91 |
+
Audio Text-to-Speech, Audio Automatic Speech Recognition, \
|
| 92 |
+
Audio Audio-to-Audio, Audio Audio Classification, \
|
| 93 |
+
Audio Voice Activity Detection, Tabular Tabular \
|
| 94 |
+
Classification, Tabular Tabular Regression, \
|
| 95 |
+
Reinforcement Learning Reinforcement Learning, \
|
| 96 |
+
Reinforcement Learning Robotics }"
|
| 97 |
+
elif dataset_name == "tensorflowhub":
|
| 98 |
+
domains = "1. $DOMAIN is inferred from the task description \
|
| 99 |
+
and should include one of {text-sequence-alignment, \
|
| 100 |
+
text-embedding, text-language-model, text-preprocessing, \
|
| 101 |
+
text-classification, text-generation, text-question-answering, \
|
| 102 |
+
text-retrieval-question-answering, text-segmentation, \
|
| 103 |
+
text-to-mel, image-classification, image-feature-vector, \
|
| 104 |
+
image-object-detection, image-segmentation, \
|
| 105 |
+
image-generator, image-pose-detection, image-rnn-agent, \
|
| 106 |
+
image-augmentation, image-classifier, image-style-transfer, \
|
| 107 |
+
image-aesthetic-quality, image-depth-estimation, \
|
| 108 |
+
image-super-resolution, image-deblurring, image-extrapolation, \
|
| 109 |
+
image-text-recognition, image-dehazing, image-deraining, \
|
| 110 |
+
image-enhancement, image-classification-logits, \
|
| 111 |
+
image-frame-interpolation, image-text-detection, image-denoising, \
|
| 112 |
+
image-others, video-classification, video-feature-extraction, \
|
| 113 |
+
video-generation, video-audio-text, video-text, \
|
| 114 |
+
audio-embedding, audio-event-classification, audio-command-detection, \
|
| 115 |
+
audio-paralinguists-classification, audio-speech-to-text, \
|
| 116 |
+
audio-speech-synthesis, audio-synthesis, audio-pitch-extraction}"
|
| 117 |
+
else:
|
| 118 |
+
logger.info("Error: API name is not supported.")
|
| 119 |
+
|
| 120 |
+
prompt = (
|
| 121 |
+
question
|
| 122 |
+
+ "\nWrite a python program in 1 to 2 lines to call API in "
|
| 123 |
+
+ dataset_name
|
| 124 |
+
+ ".\n\nThe answer should follow the format: <<<domain>>> $DOMAIN, \
|
| 125 |
+
<<<api_call>>>: $API_CALL, <<<api_provider>>>: $API_PROVIDER, \
|
| 126 |
+
<<<explanation>>>: $EXPLANATION, <<<code>>>: $CODE}. \
|
| 127 |
+
Here are the requirements:\n"
|
| 128 |
+
+ domains
|
| 129 |
+
+ "\n2. The $API_CALL should have only 1 line of code \
|
| 130 |
+
that calls api.\n 3. The $API_PROVIDER should be the \
|
| 131 |
+
programming framework used.\n4. $EXPLANATION should be \
|
| 132 |
+
a step-by-step explanation.\n5. The $CODE is the python code.\n6. \
|
| 133 |
+
Do not repeat the format in your answer."
|
| 134 |
+
)
|
| 135 |
+
return prompt
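For instance, wrapping a raw task description for the torchhub split (the question text is made up):

question = "I want to classify images of cats and dogs."
prompt = encode_question(question, "torchhub")
# The returned prompt instructs the model to answer with <<<domain>>>,
# <<<api_call>>>, <<<api_provider>>>, <<<explanation>>> and <<<code>>>
# fields in a single message.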
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class APIBenchBenchmark(BaseBenchmark):
|
| 139 |
+
r"""APIBench Benchmark adopted from `Gorilla: Large Language Model
|
| 140 |
+
Connected with Massive APIs`
|
| 141 |
+
<https://huggingface.co/datasets/gorilla-llm/APIBench>.
|
| 142 |
+
|
| 143 |
+
Args:
|
| 144 |
+
data_dir (str): The directory to save the data.
|
| 145 |
+
save_to (str): The file to save the results.
|
| 146 |
+
processes (int, optional): The number of processes to use.
|
| 147 |
+
(default: :obj:`1`)
|
| 148 |
+
"""
|
| 149 |
+
|
| 150 |
+
# TODO: Integrate retriever (pending)
|
| 151 |
+
|
| 152 |
+
def __init__(
|
| 153 |
+
self,
|
| 154 |
+
data_dir: str,
|
| 155 |
+
save_to: str,
|
| 156 |
+
processes: int = 1,
|
| 157 |
+
):
|
| 158 |
+
r"""Initialize the APIBench benchmark.
|
| 159 |
+
|
| 160 |
+
Args:
|
| 161 |
+
data_dir (str): The directory to save the data.
|
| 162 |
+
save_to (str): The file to save the results.
|
| 163 |
+
processes (int, optional): The number of processes to use for
|
| 164 |
+
parallel processing. (default: :obj:`1`)
|
| 165 |
+
"""
|
| 166 |
+
super().__init__("apibench", data_dir, save_to, processes)
|
| 167 |
+
|
| 168 |
+
def download(self):
|
| 169 |
+
r"""Download the APIBench dataset."""
|
| 170 |
+
from huggingface_hub import snapshot_download
|
| 171 |
+
|
| 172 |
+
snapshot_download(
|
| 173 |
+
repo_id="gorilla-llm/APIBench",
|
| 174 |
+
repo_type="dataset",
|
| 175 |
+
local_dir=self.data_dir,
|
| 176 |
+
local_dir_use_symlinks=True,
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
repo = "ShishirPatil/gorilla"
|
| 180 |
+
subdir = "/gorilla/eval/eval-data/questions"
|
| 181 |
+
data_dir = self.data_dir
|
| 182 |
+
|
| 183 |
+
download_github_subdirectory(repo, subdir, data_dir)
|
| 184 |
+
|
| 185 |
+
def load(self, dataset_name: str, force_download: bool = False): # type: ignore[override]
|
| 186 |
+
r"""Load the APIBench Benchmark dataset.
|
| 187 |
+
|
| 188 |
+
Args:
|
| 189 |
+
dataset_name (str): Name of the specific dataset to be loaded.
|
| 190 |
+
force_download (bool, optional): Whether to force
|
| 191 |
+
download the data. (default: :obj:`False`)
|
| 192 |
+
"""
|
| 193 |
+
|
| 194 |
+
if force_download:
|
| 195 |
+
logger.info("Force downloading data.")
|
| 196 |
+
self.download()
|
| 197 |
+
|
| 198 |
+
def load_json_lines(file_path: Path):
|
| 199 |
+
r"""Helper function to load JSON lines from a file."""
|
| 200 |
+
try:
|
| 201 |
+
with open(file_path, "r") as f:
|
| 202 |
+
return [json.loads(line) for line in f]
|
| 203 |
+
except FileNotFoundError:
|
| 204 |
+
raise FileNotFoundError(f"File not found: {file_path}")
|
| 205 |
+
except json.JSONDecodeError as e:
|
| 206 |
+
raise ValueError(
|
| 207 |
+
f"Error decoding JSON in file {file_path}: {e}"
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
dataset_path = self.data_dir / dataset_name
|
| 211 |
+
if not dataset_path.exists():
|
| 212 |
+
raise FileNotFoundError(
|
| 213 |
+
f"Dataset directory does not exist: {dataset_path}"
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
for label in ['api', 'eval', 'questions']:
|
| 217 |
+
file_name = dataset_mapping[dataset_name][label]
|
| 218 |
+
file_path = (
|
| 219 |
+
dataset_path / file_name
|
| 220 |
+
if label == 'questions'
|
| 221 |
+
else self.data_dir / file_name
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
# Load data based on label type
|
| 225 |
+
if label in ['api', 'questions', 'eval']:
|
| 226 |
+
data = load_json_lines(file_path)
|
| 227 |
+
|
| 228 |
+
if label == 'eval':
|
| 229 |
+
# Extract 'api_data' specifically for eval label
|
| 230 |
+
data = [item['api_data'] for item in data]
|
| 231 |
+
|
| 232 |
+
self._data[label] = data
|
| 233 |
+
else:
|
| 234 |
+
raise ValueError(f"Unknown label: {label}")
|
| 235 |
+
|
| 236 |
+
ast_database = []
|
| 237 |
+
for data in self._data['api']:
|
| 238 |
+
ast_tree = ast_parse(data['api_call'])
|
| 239 |
+
ast_database.append(ast_tree)
|
| 240 |
+
self._data['ast'] = ast_database
|
| 241 |
+
|
| 242 |
+
def run( # type: ignore[override]
|
| 243 |
+
self,
|
| 244 |
+
agent: ChatAgent,
|
| 245 |
+
dataset_name: Literal["huggingface", "tensorflowhub", "torchhub"],
|
| 246 |
+
randomize: bool = False,
|
| 247 |
+
subset: Optional[int] = None,
|
| 248 |
+
) -> Dict[str, Any]:
|
| 249 |
+
r"""Run the benchmark.
|
| 250 |
+
|
| 251 |
+
Args:
|
| 252 |
+
agent (ChatAgent): The agent to run the
|
| 253 |
+
benchmark.
|
| 254 |
+
dataset_name (Literal["huggingface",
|
| 255 |
+
"tensorflowhub", "torchhub"]):
|
| 256 |
+
The dataset to run the benchmark.
|
| 257 |
+
randomize (bool, optional): Whether to randomize the data.
|
| 258 |
+
(default: :obj:`False`)
|
| 259 |
+
subset (Optional[int], optional): The subset of data to run.
|
| 260 |
+
(default: :obj:`None`)
|
| 261 |
+
"""
|
| 262 |
+
|
| 263 |
+
if dataset_name not in dataset_mapping:
|
| 264 |
+
raise ValueError(f"Invalid value for dataset: {dataset_name}.")
|
| 265 |
+
|
| 266 |
+
logger.info(f"Running APIBench benchmark on {dataset_name}.")
|
| 267 |
+
self.load(dataset_name)
|
| 268 |
+
datas = self._data['questions']
|
| 269 |
+
|
| 270 |
+
# Shuffle and subset data if necessary
|
| 271 |
+
if randomize:
|
| 272 |
+
random.shuffle(datas)
|
| 273 |
+
if subset:
|
| 274 |
+
datas = datas[:subset]
|
| 275 |
+
|
| 276 |
+
logger.info(f"Number of tasks: {len(datas)}")
|
| 277 |
+
|
| 278 |
+
# Initialize results storage
|
| 279 |
+
self._results = []
|
| 280 |
+
|
| 281 |
+
with open(self.save_to, "w") as f:
|
| 282 |
+
for question in tqdm(datas, desc="Running"):
|
| 283 |
+
prompt = encode_question(question["text"], dataset_name)
|
| 284 |
+
msg = BaseMessage.make_user_message(
|
| 285 |
+
role_name="User", content=prompt
|
| 286 |
+
)
|
| 287 |
+
try:
|
| 288 |
+
# Generate response
|
| 289 |
+
responses = agent.step(msg)
|
| 290 |
+
response = responses.msgs[0].content
|
| 291 |
+
api_database = self._data['api']
|
| 292 |
+
qa_pairs = self._data['eval']
|
| 293 |
+
ast_database = self._data['ast']
|
| 294 |
+
question_id = question['question_id']
|
| 295 |
+
|
| 296 |
+
# Evaluate response
|
| 297 |
+
error, correct, hallucination = evaluate_response(
|
| 298 |
+
response,
|
| 299 |
+
question_id,
|
| 300 |
+
dataset_name,
|
| 301 |
+
api_database,
|
| 302 |
+
qa_pairs,
|
| 303 |
+
ast_database,
|
| 304 |
+
)
|
| 305 |
+
self._results.append(
|
| 306 |
+
{
|
| 307 |
+
"question": question,
|
| 308 |
+
"agent_response": response,
|
| 309 |
+
"correct": correct,
|
| 310 |
+
"hallucination": hallucination,
|
| 311 |
+
"error": str(error) if error else None,
|
| 312 |
+
}
|
| 313 |
+
)
|
| 314 |
+
except Exception as e:
|
| 315 |
+
logger.warning(
|
| 316 |
+
f"Error in processing task: {question}: {e}"
|
| 317 |
+
)
|
| 318 |
+
self._results.append(
|
| 319 |
+
{
|
| 320 |
+
"question": question,
|
| 321 |
+
"agent_response": None,
|
| 322 |
+
"correct": False,
|
| 323 |
+
"hallucination": False,
|
| 324 |
+
"error": str(e),
|
| 325 |
+
}
|
| 326 |
+
)
|
| 327 |
+
|
| 328 |
+
agent.reset()
|
| 329 |
+
|
| 330 |
+
f.write(json.dumps(self._results[-1], indent=2) + "\n")
|
| 331 |
+
f.flush()
|
| 332 |
+
|
| 333 |
+
total = len(self._results)
|
| 334 |
+
correct = sum(r["correct"] for r in self.results)
|
| 335 |
+
hallucination = sum(r["hallucination"] for r in self.results)
|
| 336 |
+
|
| 337 |
+
return {
|
| 338 |
+
"total": total,
|
| 339 |
+
"correct": correct,
|
| 340 |
+
"hallucination": hallucination,
|
| 341 |
+
"accuracy": correct / total if total else "N/A",
|
| 342 |
+
"hallucination rate": hallucination / total if total else "N/A",
|
| 343 |
+
}
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
# This code is modified from the
|
| 347 |
+
# evaluators in the original repo
|
| 348 |
+
# https://github.com/ShishirPatil/gorilla
|
| 349 |
+
# Get all the subtrees given a root_node
|
| 350 |
+
def get_all_sub_trees(root_node):
|
| 351 |
+
node_stack = []
|
| 352 |
+
sub_tree_sexp_list = []
|
| 353 |
+
depth = 1
|
| 354 |
+
# text = root_node.text
|
| 355 |
+
node_stack.append([root_node, depth])
|
| 356 |
+
while len(node_stack) != 0:
|
| 357 |
+
cur_node, cur_depth = node_stack.pop()
|
| 358 |
+
if cur_node.child_count > 0:
|
| 359 |
+
sub_tree_sexp_list.append(
|
| 360 |
+
[
|
| 361 |
+
str(cur_node),
|
| 362 |
+
cur_depth,
|
| 363 |
+
cur_node,
|
| 364 |
+
cur_node.children[0].text,
|
| 365 |
+
]
|
| 366 |
+
)
|
| 367 |
+
else:
|
| 368 |
+
sub_tree_sexp_list.append(
|
| 369 |
+
[str(cur_node), cur_depth, cur_node, None]
|
| 370 |
+
)
|
| 371 |
+
for child_node in cur_node.children:
|
| 372 |
+
if len(child_node.children) != 0:
|
| 373 |
+
depth = cur_depth + 1
|
| 374 |
+
node_stack.append([child_node, depth])
|
| 375 |
+
return sub_tree_sexp_list
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
# Parse the program into AST trees
|
| 379 |
+
def ast_parse(candidate):
|
| 380 |
+
PY_LANGUAGE = Language(tspython.language())
|
| 381 |
+
parser = Parser(PY_LANGUAGE)
|
| 382 |
+
|
| 383 |
+
candidate_tree = parser.parse(bytes(candidate, "utf8")).root_node
|
| 384 |
+
return candidate_tree
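tree-sitter parses the string into a concrete syntax tree and returns its root node; get_all_sub_trees can then enumerate every subtree for matching. A small check (the api_call string is made up):

root = ast_parse("model = torch.hub.load('pytorch/vision', 'resnet18')")
print(root.type)                          # module
print(len(get_all_sub_trees(root)) > 0)   # True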
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
# Get all the arguments in the ast tree
|
| 388 |
+
def get_args(node, dataset_name):
|
| 389 |
+
if node.child_count == 0:
|
| 390 |
+
return []
|
| 391 |
+
args_list = []
|
| 392 |
+
if dataset_name == "huggingface":
|
| 393 |
+
for child in node.children[0].children[0].children[1].children:
|
| 394 |
+
if "=" in child.text.decode():
|
| 395 |
+
args_list.append(child.children[2].text)
|
| 396 |
+
elif (
|
| 397 |
+
child.text.decode() != "("
|
| 398 |
+
and child.text.decode() != ")"
|
| 399 |
+
and child.text.decode() != ","
|
| 400 |
+
):
|
| 401 |
+
args_list.append(child.text)
|
| 402 |
+
elif dataset_name == "tensorflowhub":
|
| 403 |
+
for child in node.children[0].children[0].children[1].children:
|
| 404 |
+
if (
|
| 405 |
+
'model=' in child.text.decode()
|
| 406 |
+
or 'model =' in child.text.decode()
|
| 407 |
+
):
|
| 408 |
+
args_list.append(child.children[2].text)
|
| 409 |
+
elif (
|
| 410 |
+
child.text.decode() != "("
|
| 411 |
+
and child.text.decode() != ")"
|
| 412 |
+
and child.text.decode() != ","
|
| 413 |
+
):
|
| 414 |
+
args_list.append(child.text)
|
| 415 |
+
elif dataset_name == "torchhub":
|
| 416 |
+
for child in node.children[0].children[0].children[1].children:
|
| 417 |
+
if (
|
| 418 |
+
"repo_or_dir" in child.text.decode()
|
| 419 |
+
or "model" in child.text.decode()
|
| 420 |
+
):
|
| 421 |
+
args_list.append(child.children[2].text)
|
| 422 |
+
return args_list
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
# Check if there is an api match
|
| 426 |
+
def ast_check(candidate_subtree_list, base_tree_list, dataset_name):
|
| 427 |
+
for idx, base_tree in enumerate(base_tree_list):
|
| 428 |
+
if base_tree.children[0].children[0].child_count == 0:
|
| 429 |
+
continue
|
| 430 |
+
api_name = base_tree.children[0].children[0].children[0].text
|
| 431 |
+
for candidate_tree in candidate_subtree_list:
|
| 432 |
+
if candidate_tree[3] == api_name:
|
| 433 |
+
break
|
| 434 |
+
# Now we have a sub-tree
|
| 435 |
+
candidate_tree = candidate_tree[2]
|
| 436 |
+
args_list = get_args(base_tree, dataset_name)
|
| 437 |
+
if len(args_list) == 0:
|
| 438 |
+
continue
|
| 439 |
+
ast_match = True
|
| 440 |
+
for arg in args_list:
|
| 441 |
+
if (
|
| 442 |
+
arg.decode().lstrip("'").rstrip("'")
|
| 443 |
+
not in candidate_tree.text.decode()
|
| 444 |
+
):
|
| 445 |
+
ast_match = False
|
| 446 |
+
break
|
| 447 |
+
if ast_match:
|
| 448 |
+
return idx
|
| 449 |
+
return -1
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
def evaluate_response(
|
| 453 |
+
response, question_id, dataset_name, api_database, qa_pairs, ast_database
|
| 454 |
+
):
|
| 455 |
+
try:
|
| 456 |
+
# Index the "api_call" domain
|
| 457 |
+
output = response.split("api_call")
|
| 458 |
+
if len(output) == 1:
|
| 459 |
+
api_call = output[0]
|
| 460 |
+
else:
|
| 461 |
+
# Parse the output
|
| 462 |
+
output = output[1].split("api_provider")[0]
|
| 463 |
+
if ":" not in output:
|
| 464 |
+
start = 0
|
| 465 |
+
else:
|
| 466 |
+
start = output.index(":")
|
| 467 |
+
if ")" not in output:
|
| 468 |
+
end = -2
|
| 469 |
+
else:
|
| 470 |
+
end = output.rindex(")")
|
| 471 |
+
api_call = output[start + 2 : end + 1]
|
| 472 |
+
|
| 473 |
+
try:
|
| 474 |
+
ast_tree = ast_parse(api_call)
|
| 475 |
+
except Exception as parse_error:
|
| 476 |
+
print(f"Error parsing api_call: {api_call}, error: {parse_error}")
|
| 477 |
+
return parse_error, False, False
|
| 478 |
+
# Search for a subtree
|
| 479 |
+
ast_subtree_list = get_all_sub_trees(ast_tree)
|
| 480 |
+
# Check which ast tree is matching
|
| 481 |
+
database_index = ast_check(
|
| 482 |
+
ast_subtree_list, ast_database, dataset_name
|
| 483 |
+
)
|
| 484 |
+
# We cannot index this ast in our database
|
| 485 |
+
if database_index == -1:
|
| 486 |
+
hallucination = True
|
| 487 |
+
correct = False
# Unmatched call: report the hallucination instead of indexing with -1
return None, correct, hallucination
|
| 488 |
+
# We index our reference api_call
|
| 489 |
+
ref_api_call = api_database[database_index]
|
| 490 |
+
# Check for functionality
|
| 491 |
+
if ref_api_call['domain'] == qa_pairs[question_id - 1]['domain']:
|
| 492 |
+
correct = True
|
| 493 |
+
hallucination = False
|
| 494 |
+
else:
|
| 495 |
+
return None, False, False
|
| 496 |
+
except Exception as e:
|
| 497 |
+
print(f'Error parsing response: {response}, error: {e}')
|
| 498 |
+
return e, False, False
|
| 499 |
+
|
| 500 |
+
return None, correct, hallucination
|
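As with API-Bank, a minimal illustrative driver for this benchmark (agent configuration and paths are placeholders):

from camel.agents import ChatAgent

agent = ChatAgent("You are a coding assistant.")
benchmark = APIBenchBenchmark(
    data_dir="apibench_data", save_to="apibench_results.jsonl"
)
result = benchmark.run(agent, dataset_name="torchhub", subset=5)
print(result["accuracy"], result["hallucination rate"])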
camel/benchmarks/base.py
ADDED
|
@@ -0,0 +1,152 @@
| 1 |
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
#
|
| 6 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 7 |
+
#
|
| 8 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 9 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 10 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 11 |
+
# See the License for the specific language governing permissions and
|
| 12 |
+
# limitations under the License.
|
| 13 |
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
| 14 |
+
|
| 15 |
+
import logging
|
| 16 |
+
from abc import ABC, abstractmethod
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
from typing import Any, Dict, List, Literal, Optional
|
| 19 |
+
|
| 20 |
+
from camel.agents import ChatAgent
|
| 21 |
+
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class BaseBenchmark(ABC):
|
| 26 |
+
r"""Base class for benchmarks.
|
| 27 |
+
|
| 28 |
+
Attributes:
|
| 29 |
+
name (str): Name of the benchmark.
|
| 30 |
+
data_dir (str): Path to the data directory.
|
| 31 |
+
save_to (str): Path to save the results.
|
| 32 |
+
processes (int): Number of processes to use for parallel
|
| 33 |
+
processing. (default: :obj:`1`)
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
def __init__(
|
| 37 |
+
self, name: str, data_dir: str, save_to: str, processes: int = 1
|
| 38 |
+
):
|
| 39 |
+
r"""Initialize the benchmark.
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
name (str): Name of the benchmark.
|
| 43 |
+
data_dir (str): Path to the data directory.
|
| 44 |
+
save_to (str): Path to save the results.
|
| 45 |
+
processes (int): Number of processes to use for parallel
|
| 46 |
+
processing. (default: :obj:`1`)
|
| 47 |
+
|
| 48 |
+
"""
|
| 49 |
+
self.name = name
|
| 50 |
+
self.data_dir = Path(data_dir)
|
| 51 |
+
self.processes = processes
|
| 52 |
+
self.save_to = save_to
|
| 53 |
+
if not self.data_dir.exists():
|
| 54 |
+
logger.info(
|
| 55 |
+
f"Data directory {data_dir} does not exist. Creating it."
|
| 56 |
+
)
|
| 57 |
+
self.data_dir.mkdir(parents=True, exist_ok=True)
|
| 58 |
+
if not self.data_dir.is_dir():
|
| 59 |
+
raise NotADirectoryError(
|
| 60 |
+
f"Data directory {data_dir} is not a directory"
|
| 61 |
+
)
|
| 62 |
+
self._data: Dict[str, List[Dict[str, Any]]] = dict()
|
| 63 |
+
self._results: List[Dict[str, Any]] = []
|
| 64 |
+
|
| 65 |
+
@abstractmethod
|
| 66 |
+
def download(self) -> "BaseBenchmark":
|
| 67 |
+
r"""Download the benchmark data.
|
| 68 |
+
|
| 69 |
+
Returns:
|
| 70 |
+
BaseBenchmark: The benchmark instance.
|
| 71 |
+
"""
|
| 72 |
+
pass
|
| 73 |
+
|
| 74 |
+
@abstractmethod
|
| 75 |
+
def load(self, force_download: bool = False) -> "BaseBenchmark":
|
| 76 |
+
r"""Load the benchmark data.
|
| 77 |
+
|
| 78 |
+
Args:
|
| 79 |
+
force_download (bool): Whether to force download the data.
|
| 80 |
+
|
| 81 |
+
Returns:
|
| 82 |
+
BaseBenchmark: The benchmark instance.
|
| 83 |
+
"""
|
| 84 |
+
pass
|
| 85 |
+
|
| 86 |
+
@property
|
| 87 |
+
def train(self) -> List[Dict[str, Any]]:
|
| 88 |
+
r"""Get the training data.
|
| 89 |
+
|
| 90 |
+
Returns:
|
| 91 |
+
List[Dict[str, Any]]: The training data.
|
| 92 |
+
"""
|
| 93 |
+
if not self._data:
|
| 94 |
+
logger.info("Data not loaded. Loading data.")
|
| 95 |
+
self.load()
|
| 96 |
+
return self._data["train"]
|
| 97 |
+
|
| 98 |
+
@property
|
| 99 |
+
def valid(self) -> List[Dict[str, Any]]:
|
| 100 |
+
r"""Get the validation data.
|
| 101 |
+
|
| 102 |
+
Returns:
|
| 103 |
+
List[Dict[str, Any]]: The validation data.
|
| 104 |
+
"""
|
| 105 |
+
if not self._data:
|
| 106 |
+
logger.info("Data not loaded. Loading data.")
|
| 107 |
+
self.load()
|
| 108 |
+
return self._data["valid"]
|
| 109 |
+
|
| 110 |
+
@property
|
| 111 |
+
def test(self) -> List[Dict[str, Any]]:
|
| 112 |
+
r"""Get the test data.
|
| 113 |
+
|
| 114 |
+
Returns:
|
| 115 |
+
List[Dict[str, Any]]: The test data.
|
| 116 |
+
"""
|
| 117 |
+
if not self._data:
|
| 118 |
+
logger.info("Data not loaded. Loading data.")
|
| 119 |
+
self.load()
|
| 120 |
+
return self._data["test"]
|
| 121 |
+
|
| 122 |
+
@abstractmethod
|
| 123 |
+
def run(
|
| 124 |
+
self,
|
| 125 |
+
agent: ChatAgent,
|
| 126 |
+
on: Literal["train", "valid", "test"],
|
| 127 |
+
randomize: bool = False,
|
| 128 |
+
subset: Optional[int] = None,
|
| 129 |
+
*args,
|
| 130 |
+
**kwargs,
|
| 131 |
+
) -> "BaseBenchmark":
|
| 132 |
+
r"""Run the benchmark.
|
| 133 |
+
|
| 134 |
+
Args:
|
| 135 |
+
agent (ChatAgent): The chat agent.
|
| 136 |
+
on (str): The data split to run the benchmark on.
|
| 137 |
+
randomize (bool): Whether to randomize the data.
|
| 138 |
+
subset (int): The subset of the data to run the benchmark on.
|
| 139 |
+
|
| 140 |
+
Returns:
|
| 141 |
+
BaseBenchmark: The benchmark instance.
|
| 142 |
+
"""
|
| 143 |
+
pass
|
| 144 |
+
|
| 145 |
+
@property
|
| 146 |
+
def results(self) -> List[Dict[str, Any]]:
|
| 147 |
+
r"""Get the results.
|
| 148 |
+
|
| 149 |
+
Returns:
|
| 150 |
+
List[Dict[str, Any]]: The results.
|
| 151 |
+
"""
|
| 152 |
+
return self._results
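To make the contract concrete, here is a toy subclass, a sketch under the assumption that agent.step accepts a plain string (as recent ChatAgent versions do); nothing here is a real benchmark:

from typing import Literal, Optional

from camel.agents import ChatAgent
from camel.benchmarks.base import BaseBenchmark


class ToyBenchmark(BaseBenchmark):
    def download(self) -> "ToyBenchmark":
        return self  # nothing to fetch for this toy dataset

    def load(self, force_download: bool = False) -> "ToyBenchmark":
        self._data = {
            "train": [], "valid": [], "test": [{"q": "2+2=?", "a": "4"}]
        }
        return self

    def run(
        self,
        agent: ChatAgent,
        on: Literal["train", "valid", "test"],
        randomize: bool = False,
        subset: Optional[int] = None,
        *args,
        **kwargs,
    ) -> "ToyBenchmark":
        for item in self._data[on][:subset]:
            reply = agent.step(item["q"]).msgs[0].content
            self._results.append(
                {"q": item["q"], "correct": item["a"] in reply}
            )
        return self

# Usage: bench = ToyBenchmark("toy", "./toy_data", "toy_results.jsonl").load()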
|
camel/benchmarks/gaia.py
ADDED
|
@@ -0,0 +1,478 @@
| 1 |
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

import json
import logging
import os
import random
import re
import string
import uuid
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Protocol, Union

from tqdm import tqdm

from camel.agents import ChatAgent
from camel.benchmarks.base import BaseBenchmark
from camel.messages import BaseMessage
from camel.retrievers.auto_retriever import AutoRetriever

logger = logging.getLogger(__name__)


class RetrieverProtocol(Protocol):
    r"""Protocol for the retriever class. Any retriever class implementing
    this protocol can be used in the benchmark class.
    """

    def retrieve(
        self, query: str, contents: List[str], **kwargs: Dict[str, Any]
    ) -> Dict[str, Any]:
        r"""Retrieve the relevant content for the query.

        Args:
            query (str): The query to retrieve the content for.
            contents (List[str]): The list of contents to search in.
            **kwargs (Dict[str, Any]): Additional keyword arguments.

        Returns:
            Dict[str, Any]: The relevant content for the query.
        """
        ...

    def reset(self, **kwargs) -> bool:
        r"""Reset the retriever.
        Some benchmarks may require resetting the retriever
        after each query.

        Args:
            **kwargs: Additional keyword arguments.

        Returns:
            bool: True if the reset was successful, False otherwise.
        """
        ...


class DefaultGAIARetriever(AutoRetriever):
    r"""Default retriever for the GAIA benchmark.
    This retriever uses AutoRetriever in camel to retrieve the content based
    on the query.
    """

    def retrieve(
        self, query: str, contents: List[str], **kwargs: Any
    ) -> Dict[str, Any]:
        r"""Retrieve the content based on the query.

        Args:
            query (str): The query to search for.
            contents (List[str]): The list of contents to search from.
            **kwargs (Any): The keyword arguments to pass to the retriever.

        Returns:
            Dict[str, Any]: The retrieved content.
        """
        return self.run_vector_retriever(query, contents, **kwargs)  # type: ignore[arg-type]

    def reset(self, **kwargs: Any) -> bool:
        r"""Reset the retriever.

        Args:
            **kwargs (Any): The keyword arguments to pass to the retriever.

        Returns:
            bool: Whether the reset was successful.
        """
        path = Path(self.vector_storage_local_path or os.getcwd())
        task_id = str(kwargs.get("task_id", uuid.uuid4()))
        retriever_dir = path / task_id
        if not retriever_dir.exists():
            try:
                retriever_dir.mkdir(parents=True)
            except Exception as e:
                logger.error(
                    f"Error in creating directory: {retriever_dir}: {e!s}"
                )
                return False
        self.vector_storage_local_path = str(retriever_dir)
        return True


class GAIABenchmark(BaseBenchmark):
    r"""GAIA Benchmark adapted from `"GAIA: a benchmark for General AI
    Assistants"
    <https://huggingface.co/datasets/gaia-benchmark/GAIA>`_.

    Args:
        data_dir (str): The directory to save the data.
        save_to (str): The file to save the results.
        retriever (Optional[RetrieverProtocol]): The retriever to use.
            (default: :obj:`None`)
        processes (int, optional): The number of processes to use.
            (default: :obj:`1`)
    """

    def __init__(
        self,
        data_dir: str,
        save_to: str,
        retriever: Optional[RetrieverProtocol] = None,
        processes: int = 1,
    ):
        r"""Initialize the GAIA benchmark.

        Args:
            data_dir (str): The directory to save the data.
            save_to (str): The file to save the results.
            retriever (Optional[RetrieverProtocol], optional): The retriever
                to use. (default: :obj:`None`)
            processes (int, optional): The number of processes to use for
                parallel processing. (default: :obj:`1`)
        """
        super().__init__("gaia", data_dir, save_to, processes)
        self.retriever = retriever or DefaultGAIARetriever()

    def download(self):
        r"""Download the GAIA dataset."""
        from huggingface_hub import snapshot_download

        snapshot_download(
            repo_id="gaia-benchmark/GAIA",
            repo_type="dataset",
            local_dir=self.data_dir,
            local_dir_use_symlinks=True,
        )

    def load(self, force_download=False):
        r"""Load the GAIA dataset.

        Args:
            force_download (bool, optional): Whether to force download the
                data.
        """
        if force_download:
            logger.info("Force downloading data.")
            self.download()

        # Define validation and test directories
        valid_dir = self.data_dir / "2023/validation"
        test_dir = self.data_dir / "2023/test"

        # Check if directories exist; if not, download the data
        if not valid_dir.is_dir() or not test_dir.is_dir():
            logger.info("Data not found. Downloading data.")
            self.download()

        # Load metadata for both validation and test datasets
        for path, label in zip([valid_dir, test_dir], ["valid", "test"]):
            self._data[label] = []
            with open(path / "metadata.jsonl", "r") as f:
                lines = f.readlines()
                for line in lines:
                    data = json.loads(line)
                    if data["task_id"] == "0-0-0-0-0":
                        continue
                    if data["file_name"]:
                        data["file_name"] = path / data["file_name"]
                    self._data[label].append(data)
        return self

    @property
    def train(self):
        r"""Get the training set."""
        raise NotImplementedError("GAIA does not have a training set.")

    def run(  # type: ignore[override]
        self,
        agent: ChatAgent,
        on: Literal["train", "valid", "test"],
        level: Union[int, List[int], Literal["all"]],
        randomize: bool = False,
        subset: Optional[int] = None,
    ) -> Dict[str, Any]:
        r"""Run the benchmark.

        Args:
            agent (ChatAgent): The agent to run the benchmark.
            on (Literal["valid", "test"]): The set to run the benchmark.
            level (Union[int, List[int], Literal["all"]]): The level to run
                the benchmark.
            randomize (bool, optional): Whether to randomize the data.
                (default: :obj:`False`)
            subset (Optional[int], optional): The subset of data to run.
                (default: :obj:`None`)

        Returns:
            Dict[str, Any]: The results of the benchmark.
        """
        # Validate inputs
        if on not in ["valid", "test"]:
            raise ValueError(
                f"Invalid value for `on`: {on}, expected 'valid' or 'test'."
            )

        levels = (
            [1, 2, 3]
            if level == "all"
            else [level]
            if isinstance(level, int)
            else level
        )
        if not all(
            isinstance(level, int) and level in [1, 2, 3] for level in levels
        ):
            raise ValueError(
                f"Invalid value for `level`: {level}, expected 1, 2, 3 "
                "or 'all'."
            )

        logger.info(f"Running benchmark on {on} set at levels {levels}.")
        datas = [data for data in self._data[on] if data["Level"] in levels]

        # Shuffle and subset data if necessary
        if randomize:
            random.shuffle(datas)
        if subset:
            datas = datas[:subset]

        logger.info(f"Number of tasks: {len(datas)}")

        # Initialize results storage
        self._results = []

        # Process tasks
        with open(self.save_to, "w") as f:
            for task in tqdm(datas, desc="Running"):
                if not self._prepare_task(task):
                    continue

                try:
                    result = agent.step(self._create_user_message(task))
                    self._process_result(agent, task, result, f)
                except Exception as e:
                    self._handle_error(task, e, f)
                finally:
                    agent.reset()

        return self._generate_summary()

    def _prepare_task(self, task: Dict[str, Any]) -> bool:
        r"""Prepare the task by validating and enriching its data."""
        if task["file_name"]:
            file_path = Path(task["file_name"])
            if not file_path.exists():
                logger.info(
                    f"Skipping task because file not found: {file_path}"
                )
                return False
            if file_path.suffix in [".pdf", ".docx", ".doc", ".txt"]:
                if not self.retriever.reset(task_id=task["task_id"]):
                    return False
                retrieved_info = self.retriever.retrieve(
                    query=task["Question"], contents=[task["file_name"]]
                )
                retrieved_content = [
                    item["text"]
                    for item in retrieved_info.get("Retrieved Context", [])
                ]
                if retrieved_content:
                    task["Question"] += "\n" + "\n".join(retrieved_content)
            else:
                logger.info(
                    f"Skipping task due to unsupported file "
                    f"format: {file_path.suffix}"
                )
                return False
        return True

    def _create_user_message(self, task: Dict[str, Any]) -> BaseMessage:
        r"""Create a user message from a task."""
        return BaseMessage.make_user_message(
            role_name="User",
            content=task["Question"],
        )

    def _process_result(
        self,
        agent: ChatAgent,
        task: Dict[str, Any],
        result: Any,
        file_obj: Any,
    ) -> None:
        r"""Process and store the result of a task."""
        model_answer = self.get_final_answer(result.msgs[0].content)
        final_answer = task["Final answer"]
        score = self.question_scorer(model_answer, final_answer)
        tool_calls = result.info.get("tool_calls", [])

        result_data = {
            "task_id": task["task_id"],
            "question": task["Question"],
            "level": task["Level"],
            "model_answer": model_answer,
            "ground_truth": final_answer,
            "tool_calls": [tool.model_dump() for tool in tool_calls],
            "error": None,
            "score": int(score),
            "history": agent.memory.get_context(),
        }
        self._results.append(result_data)
        file_obj.write(json.dumps(result_data, indent=2) + "\n")
        file_obj.flush()

    def _handle_error(
        self, task: Dict[str, Any], error: Exception, file_obj: Any
    ) -> None:
        r"""Handle errors encountered during task processing."""
        logger.warning(f"Error processing task {task['task_id']}: {error}")
        error_data = {
            "task_id": task["task_id"],
            "question": task["Question"],
            "level": task["Level"],
            "model_answer": "ERROR",
            "ground_truth": task["Final answer"],
            "tool_calls": [],
            "error": str(error),
            "score": 0,
        }
        self._results.append(error_data)
        file_obj.write(json.dumps(error_data, indent=2) + "\n")
        file_obj.flush()

    def _generate_summary(self) -> Dict[str, Any]:
        r"""Generate and return a summary of the benchmark results."""
        return {
            "total": len(self._results),
            "correct": sum(result["score"] for result in self._results),
            "results": self._results,
        }

    def question_scorer(self, model_answer: str, ground_truth: str) -> bool:
        r"""Scorer for the GAIA benchmark.
        https://huggingface.co/spaces/gaia-benchmark/leaderboard/blob/main/
        scorer.py

        Args:
            model_answer (str): The model answer.
            ground_truth (str): The ground truth answer.

        Returns:
            bool: The score of the model.
        """

        def is_float(element: Any) -> bool:
            try:
                float(element)
                return True
            except ValueError:
                return False

        if is_float(ground_truth):
            logger.info(f"Evaluating {model_answer} as a number.")
            normalized_answer = self.normalize_number_str(model_answer)
            return normalized_answer == float(ground_truth)

        elif any(char in ground_truth for char in [",", ";"]):
            logger.info(
                f"Evaluating {model_answer} as a comma separated list."
            )
            gt_elems = self.split_string(ground_truth)
            ma_elems = self.split_string(model_answer)

            if len(gt_elems) != len(ma_elems):
                logger.warning(
                    "Answer lists have different lengths, returning False."
                )
                return False

            comparisons = []
            for ma_elem, gt_elem in zip(ma_elems, gt_elems):
                if is_float(gt_elem):
                    normalized_ma_elem = self.normalize_number_str(ma_elem)
                    comparisons.append(normalized_ma_elem == float(gt_elem))
                else:
                    ma_elem = self.normalize_str(ma_elem, remove_punct=False)
                    gt_elem = self.normalize_str(gt_elem, remove_punct=False)
                    comparisons.append(ma_elem == gt_elem)
            return all(comparisons)
        else:
            logger.info(f"Evaluating {model_answer} as a string.")
            ma_elem = self.normalize_str(model_answer)
            gt_elem = self.normalize_str(ground_truth)
            return ma_elem == gt_elem

    def normalize_number_str(self, number_str: str) -> float:
        for char in ["$", "%", ","]:
            number_str = number_str.replace(char, "")
        try:
            return float(number_str)
        except ValueError:
            logger.error(
                f"String {number_str} cannot be normalized to number str."
            )
            return float("inf")

    def split_string(
        self, s: str, char_list: Optional[List[str]] = None
    ) -> List[str]:
        r"""Split a string based on a list of characters.

        Args:
            s (str): The string to split.
            char_list (Optional[List[str]], optional): The list of characters
                to split on. (default: :obj:`None`)
        """
        if char_list is None:
            char_list = [",", ";"]
        pattern = f"[{''.join(char_list)}]"
        return re.split(pattern, s)

    def normalize_str(self, input_str, remove_punct=True) -> str:
        r"""Normalize a string.

        Args:
            input_str: The input string to normalize.
            remove_punct: Whether to remove punctuation.

        Returns:
            str: The normalized string.
        """
        no_spaces = re.sub(r"\s", "", input_str)
        if remove_punct:
            translator = str.maketrans("", "", string.punctuation)
            return no_spaces.lower().translate(translator)
        else:
            return no_spaces.lower()

    def get_final_answer(self, content: str) -> str:
        r"""Get the final answer from the content.

        Args:
            content (str): The content to extract the final answer from.

        Returns:
            str: The final answer.
        """
        final_answer_index = content.find("FINAL ANSWER")
        if final_answer_index == -1:
            return "FINAL ANSWER not found"
        start_index = final_answer_index + len("FINAL ANSWER: ")
        final_answer_content = content[start_index:].strip()
        return final_answer_content
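For orientation, here is a minimal usage sketch of the benchmark defined above. The agent construction and the system prompt are illustrative assumptions, not part of this commit; the only hard requirement from the code above is that the agent's reply contains a "FINAL ANSWER: ..." marker for get_final_answer to pick up.

# Hypothetical usage sketch for GAIABenchmark (paths and prompt are assumed).
from camel.agents import ChatAgent
from camel.benchmarks.gaia import GAIABenchmark

agent = ChatAgent(
    "You are a general AI assistant. "
    "End every reply with 'FINAL ANSWER: <answer>'."
)
benchmark = GAIABenchmark(
    data_dir="datasets/gaia", save_to="gaia_results.jsonl"
)
benchmark.load()  # downloads the dataset on first use
summary = benchmark.run(agent, on="valid", level="all", subset=5)
print(f"{summary['correct']}/{summary['total']} correct")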
camel/benchmarks/nexus.py
ADDED
@@ -0,0 +1,518 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

import ast
import json
import logging
import os
import random
import textwrap
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

import pandas as pd
from datasets import load_dataset
from tqdm import tqdm

from camel.agents import ChatAgent
from camel.benchmarks.base import BaseBenchmark
from camel.messages import BaseMessage

logger = logging.getLogger(__name__)


# Define the data class
@dataclass
class NexusSample:
    r"""Nexus benchmark dataset sample."""

    input: str
    output: str


@dataclass
class NexusTool:
    r"""Nexus benchmark tool."""

    function_calls: str
    descriptions: str


dataset_mapping = {
    "NVDLibrary": "Nexusflow/NVDLibraryBenchmark",
    "VirusTotal": "Nexusflow/VirusTotalBenchmark",
    "PlacesAPI": "Nexusflow/PlacesAPIBenchmark",
    "ClimateAPI": "Nexusflow/ClimateAPIBenchmark",
    "OTX": "Nexusflow/OTXAPIBenchmark",
    "VirusTotal-NestedCalls": "Nexusflow/vt_multiapi",
    "VirusTotal-ParallelCalls": "Nexusflow/vt_multiapi",
    "NVDLibrary-NestedCalls": "Nexusflow/CVECPEAPIBenchmark",
}

TOOL_CALLING_PROMPT = """
You are given multiple functions and a user query.

Please proceed with generating a function call for the function \
with the proper arguments that best answers the given prompt.

Respond with nothing but the function call ONLY, such that I can \
directly execute your function call without any post processing \
necessary from my end. Do not use variables.
If there are more than two function calls, separate them with a semicolon (;).

{tools}

Question: {input}
"""


class NexusBenchmark(BaseBenchmark):
    r"""Nexus Function Calling Benchmark adapted from `NexusRaven V2
    Function Calling Benchmark
    <https://huggingface.co/collections/Nexusflow/nexusraven-v2-function-calling-benchmark-657a597fb84dbe7a09ebfc3e>`_.

    Args:
        data_dir (str): The directory to save the data.
        save_to (str): The file to save the results.
        processes (int, optional): The number of processes to use.
            (default: :obj:`1`)
    """

    def __init__(
        self,
        data_dir: str,
        save_to: str,
        processes: int = 1,
    ):
        r"""Initialize the Nexus Function Calling benchmark.

        Args:
            data_dir (str): The directory to save the data.
            save_to (str): The file to save the results.
            processes (int, optional): The number of processes to use for
                parallel processing. (default: :obj:`1`)
        """
        super().__init__("nexus", data_dir, save_to, processes)
        self._data: List[NexusSample] = []  # type: ignore[assignment]

    def download(self):
        r"""Download the Nexus Functional Calling Benchmark dataset."""
        from huggingface_hub import snapshot_download

        for dataset_name, repo_id in dataset_mapping.items():
            local_dir = self.data_dir / dataset_name
            snapshot_download(
                repo_id=repo_id,
                repo_type="dataset",
                local_dir=local_dir,
                local_dir_use_symlinks=True,
            )

    def load(self, dataset_name: str, force_download: bool = False):  # type: ignore[override]
        r"""Load the Nexus Benchmark dataset.

        Args:
            dataset_name (str): Name of the specific dataset to be loaded.
            force_download (bool): Whether to force download the data.
        """

        def _load_csv_data(dataset_dir: Path) -> List:
            r"""Load datasets from CSV files."""
            dataset = []
            for file_name in os.listdir(dataset_dir):
                file_path = dataset_dir / file_name
                if file_name.endswith(".csv"):
                    data = pd.read_csv(file_path)
                    for _, sample in data.iterrows():
                        dataset.append(
                            NexusSample(
                                sample["Input"], "".join(sample["Output"])
                            )
                        )
                    continue

                logger.warning(f"Skipping unsupported file: {file_name}")
            return dataset

        def _load_parquet_data(data_dir: Path, dataset_name: str) -> List:
            r"""Load datasets from Parquet files."""
            dataset = []
            if not data_dir.exists():
                raise FileNotFoundError(
                    f"Data directory '{data_dir}' does not exist."
                )

            for file_name in os.listdir(data_dir):
                file_path = data_dir / file_name
                if file_name.endswith(".parquet"):
                    data = pd.read_parquet(file_path)
                    dataset.extend(_process_parquet_data(data, dataset_name))
                    continue

                logger.warning(f"Skipping unsupported file: {file_name}")

            return dataset

        def _process_parquet_data(
            data: pd.DataFrame, dataset_name: str
        ) -> List:
            r"""Process data from Parquet files based on dataset name."""
            dataset: List = []
            dataset_handlers = {
                "NVDLibrary": _process_nvdlibrary,
                "VirusTotal": _process_simple,
                "PlacesAPI": _process_simple,
                "ClimateAPI": _process_simple,
                "OTX": _process_simple,
                "VirusTotal-NestedCalls": _process_nested_calls,
                "VirusTotal-ParallelCalls": _process_parallel_calls,
            }

            if dataset_name not in dataset_handlers:
                logger.warning(
                    f"No specific handler for dataset: {dataset_name}"
                )
                return dataset

            handler = dataset_handlers[dataset_name]
            for _, sample in data.iterrows():
                processed_sample = handler(sample)
                if processed_sample:
                    dataset.append(processed_sample)
            return dataset

        def _process_nvdlibrary(sample) -> NexusSample:
            r"""Process samples for the NVDLibrary dataset."""
            return NexusSample(
                sample["Input"], sample["Output"].replace("r = nvdlib.", "")
            )

        def _process_simple(sample) -> NexusSample:
            r"""Process samples for simple datasets (e.g., VirusTotal)."""
            return NexusSample(sample["Input"], sample["Output"])

        def _process_nested_calls(sample) -> Union[NexusSample, None]:
            r"""Process samples for the VirusTotal-NestedCalls dataset."""
            if len(sample["fncall"]) == 1:
                return NexusSample(
                    sample["generated_question"], "".join(sample["fncall"])
                )
            return None

        def _process_parallel_calls(sample) -> Union[NexusSample, None]:
            r"""Process samples for the VirusTotal-ParallelCalls dataset."""
            if len(sample["fncall"]) > 1:
                return NexusSample(
                    sample["generated_question"], "; ".join(sample["fncall"])
                )
            return None

        if force_download:
            logger.info("Force downloading data.")
            self.download()

        # Validate dataset name
        if dataset_name not in dataset_mapping:
            available_datasets = list(dataset_mapping.keys())
            raise ValueError(
                f"Dataset '{dataset_name}' is not recognized. "
                f"Available datasets: {available_datasets}"
            )

        # Get the dataset directory
        dataset_dir = self.data_dir / dataset_name
        if not dataset_dir.exists():
            raise FileNotFoundError(
                f"The dataset directory for '{dataset_name}' does not "
                f"exist at {dataset_dir}. Please download it first."
            )

        # Load the dataset
        if dataset_name == "NVDLibrary-NestedCalls":
            self._data = _load_csv_data(dataset_dir)
        else:
            self._data = _load_parquet_data(dataset_dir / "data", dataset_name)

    @property
    def train(self):
        r"""Get the training set."""
        raise NotImplementedError(
            "Nexus Functional Calling has only a single 'train' set."
        )

    def run(  # type: ignore[override, return]
        self,
        agent: ChatAgent,
        task: Literal[
            "NVDLibrary",
            "VirusTotal",
            "OTX",
            "PlacesAPI",
            "ClimateAPI",
            "VirusTotal-ParallelCalls",
            "VirusTotal-NestedCalls",
            "NVDLibrary-NestedCalls",
        ],
        randomize: bool = False,
        subset: Optional[int] = None,
    ) -> Dict[str, Any]:
        r"""Run the benchmark.

        Args:
            agent (ChatAgent): The agent to run the benchmark.
            task (Literal["NVDLibrary", "VirusTotal", "OTX", "PlacesAPI",
                "ClimateAPI", "VirusTotal-ParallelCalls",
                "VirusTotal-NestedCalls", "NVDLibrary-NestedCalls"]): The
                task to run the benchmark.
            randomize (bool, optional): Whether to randomize the data.
                (default: :obj:`False`)
            subset (Optional[int], optional): The subset of data to run.
                (default: :obj:`None`)

        Returns:
            Dict[str, Any]: The results of the benchmark.
        """

        if task not in dataset_mapping:
            raise ValueError(f"Invalid value for dataset: {task}.")

        logger.info(f"Running Nexus Function Calling benchmark on {task}.")
        self.load(task)
        datas = self._data

        # Shuffle and subset data if necessary
        if randomize:
            random.shuffle(datas)
        if subset:
            datas = datas[:subset]

        logger.info(f"Number of tasks: {len(datas)}")

        # Initialize results storage
        self._results = []

        # Process samples
        tools = construct_tool_descriptions(task)
        with open(self.save_to, "w") as f:
            for sample in tqdm(datas, desc="Running"):
                prompt = construct_prompt(input=sample.input, tools=tools)
                msg = BaseMessage.make_user_message(
                    role_name="User", content=prompt
                )
                ground_truth_call = sample.output
                try:
                    # Generate response
                    response = agent.step(msg)
                    agent_call = response.msgs[0].content

                    # Evaluate response
                    if agent_call:
                        result = compare_function_calls(
                            agent_call=agent_call,
                            ground_truth_call=ground_truth_call,
                        )
                        self._results.append(
                            {
                                "input": sample.input,
                                "agent_call": agent_call,
                                "ground_truth_call": ground_truth_call,
                                "result": result,
                                "error": None,
                            }
                        )
                except Exception as e:
                    logger.warning(f"Error in processing task: {sample.input}")
                    self._results.append(
                        {
                            "input": sample.input,
                            "agent_call": None,
                            "ground_truth_call": ground_truth_call,
                            "result": 0,
                            "error": str(e),
                        }
                    )

                agent.reset()

                f.write(json.dumps(self._results[-1], indent=2) + "\n")
                f.flush()

        total = len(self._results)
        correct = sum(r["result"] for r in self._results)

        return {
            "total": total,
            "correct": correct,
            "accuracy": correct / total,
        }


# Utility functions
def construct_tool_descriptions(dataset_name: str) -> str:
    r"""Construct tool descriptions from function definitions and
    descriptions."""
    tool_dataset_mapping = {
        "NVDLibrary": "CVECPE",
        "VirusTotal": "VirusTotal",
        "PlacesAPI": "Places",
        "ClimateAPI": "Climate",
        "OTX": "OTX",
        "VirusTotal-NestedCalls": "VT_Multi (Nested)",
        "VirusTotal-ParallelCalls": "VT_Multi (Parallel)",
        "NVDLibrary-NestedCalls": "CVECPE_Multi (Nested)",
    }

    if dataset_name not in tool_dataset_mapping:
        raise ValueError(
            f"Dataset '{dataset_name}' is not recognized. "
            f"Available datasets: {list(dataset_mapping.keys())}"
        )

    # Load the dataset based on the dataset name
    dataset = load_dataset(
        "Nexusflow/Function_Call_Definitions",
        name=tool_dataset_mapping[dataset_name],
    )["train"]

    # Construct tool descriptions
    tools = [
        NexusTool(tool["function_calls"], tool["descriptions"])
        for tool in dataset
    ]

    # Generate the tool prompt
    tool_prompt = "".join(
        f"Function:\ndef {tool.function_calls}:\n"
        + "\"\"\"\n"
        + f"{tool.descriptions}\n"
        + "\"\"\"\n"
        for tool in tools
    )

    return tool_prompt


def construct_prompt(input: str, tools: str) -> str:
    r"""Construct the prompt from tools and input."""
    return TOOL_CALLING_PROMPT.format(tools=tools, input=input)


# Functions for function call evaluation
def parse_function_call(
    call: str,
) -> Tuple[Optional[str], Optional[List[Any]], Optional[Dict[str, Any]]]:
    r"""Parse a function call string to extract the function name,
    positional arguments, and keyword arguments, including
    nested function calls.

    Args:
        call (str): A string in the format `func(arg1, arg2, kwarg=value)`.

    Returns:
        tuple: (function_name (str), positional_args (list),
            keyword_args (dict)) or (None, None, None).
    """

    def preprocess_input(call: str) -> str:
        r"""Remove formatting like code blocks and whitespace."""
        if call.strip().startswith("```python"):
            call = call.strip().removeprefix("```python").removesuffix("```")
        return textwrap.dedent(call).strip()

    def evaluate_arg(arg):
        r"""Recursively evaluate arguments, including nested calls."""
        if isinstance(arg, ast.Call):
            # Recursively parse nested calls
            func_name, args, kwargs = parse_function_call(ast.unparse(arg))
            return func_name, args, kwargs
        elif isinstance(
            arg, ast.Constant
        ):  # Handle literals like numbers, strings, etc.
            return arg.value
        elif isinstance(arg, ast.List):  # Handle list literals
            return [evaluate_arg(el) for el in arg.elts]
        elif isinstance(arg, ast.Dict):  # Handle dictionary literals
            return {
                evaluate_arg(k): evaluate_arg(v)
                for k, v in zip(arg.keys, arg.values)
            }
        elif isinstance(arg, ast.Tuple):  # Handle tuple literals
            return tuple(evaluate_arg(el) for el in arg.elts)
        else:
            return ast.literal_eval(arg)  # Safely evaluate other types

    call = preprocess_input(call)

    try:
        # Parse the string into an AST.
        # NOTE: when several calls are separated by ";", only the first one
        # is parsed and returned.
        parsed_calls = call.split(";")
        for single_call in parsed_calls:
            tree = ast.parse(single_call, mode='eval')

            # Ensure it's a function call
            if isinstance(tree.body, ast.Call):
                # Extract function name
                if isinstance(
                    tree.body.func, ast.Name
                ):  # Simple function call
                    func_name = tree.body.func.id
                elif isinstance(
                    tree.body.func, ast.Attribute
                ):  # Attribute function call
                    func_name = (
                        f"{tree.body.func.value.id}.{tree.body.func.attr}"  # type: ignore[attr-defined]
                    )
                else:
                    raise ValueError(f"Unsupported function call: {call}")

                # Extract positional arguments
                args = [evaluate_arg(arg) for arg in tree.body.args]

                # Extract keyword arguments
                kwargs: Dict[str, Any] = {
                    kw.arg: evaluate_arg(kw.value)
                    for kw in tree.body.keywords
                    if kw.arg is not None
                }
                logger.info("Valid call.")
                return func_name, args, kwargs
            else:
                raise ValueError(f"Not a valid function call: {call}")
    except Exception as e:
        logger.info(f"Error parsing call: {call}, {e}")
        return None, None, None


def compare_function_calls(agent_call: str, ground_truth_call: str) -> bool:
    r"""Compare the function name and arguments of
    agent_call and ground_truth_call.

    Args:
        agent_call (str): Function call by the agent.
        ground_truth_call (str): Ground truth function call.

    Returns:
        bool:
            - `True` if the function names and arguments match.
            - `False` otherwise.
    """
    # Parse both calls
    agent_parsed = parse_function_call(agent_call)
    gt_parsed = parse_function_call(ground_truth_call)

    # parse_function_call returns (None, None, None) on failure, which is
    # still a truthy tuple, so check the parsed function names explicitly.
    if agent_parsed[0] is not None and gt_parsed[0] is not None:
        return agent_parsed == gt_parsed
    return False
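Two hedged snippets to make the evaluation path above concrete. First, the parser and comparator can be exercised directly on toy calls (the function name below is made up for illustration); keyword-argument order does not affect the comparison, because the parsed kwargs are compared as dicts:

# Illustrative only; 'searchCVE' is a hypothetical tool name here.
from camel.benchmarks.nexus import compare_function_calls, parse_function_call

print(parse_function_call("searchCVE(cpeName='cpe:2.3', limit=10)"))
# ('searchCVE', [], {'cpeName': 'cpe:2.3', 'limit': 10})

print(compare_function_calls(
    "searchCVE(limit=10, cpeName='cpe:2.3')",
    "searchCVE(cpeName='cpe:2.3', limit=10)",
))  # True: dict equality ignores keyword order

Second, a sketch of a full run; the agent setup is an assumption, not part of this commit:

# Hypothetical benchmark run (agent configuration is assumed).
from camel.agents import ChatAgent
from camel.benchmarks.nexus import NexusBenchmark

benchmark = NexusBenchmark(
    data_dir="datasets/nexus", save_to="nexus_results.jsonl"
)
result = benchmark.run(ChatAgent(), task="VirusTotal", subset=10)
print(result)  # {'total': ..., 'correct': ..., 'accuracy': ...}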
camel/benchmarks/ragbench.py
ADDED
@@ -0,0 +1,333 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

from typing import Any, Callable, Dict, List, Literal, Optional, Sequence

import numpy as np
from datasets import Dataset, load_dataset

from camel.agents import ChatAgent
from camel.benchmarks import BaseBenchmark
from camel.logger import get_logger
from camel.retrievers import AutoRetriever

logger = get_logger(__name__)


class RagasFields:
    r"""Constants for RAGAS evaluation field names."""

    INPUT_CONTEXT = "contexts"
    INPUT_QUESTION = "question"
    INPUT_ANSWER = "answer"


def annotate_dataset(
    dataset: Dataset,
    context_call: Optional[Callable[[Dict[str, Any]], List[str]]],
    answer_call: Optional[Callable[[Dict[str, Any]], str]],
) -> Dataset:
    r"""Annotate the dataset by adding context and answers using the provided
    functions.

    Args:
        dataset (Dataset): The input dataset to annotate.
        context_call (Optional[Callable[[Dict[str, Any]], List[str]]]):
            Function to generate context for each example.
        answer_call (Optional[Callable[[Dict[str, Any]], str]]): Function to
            generate answer for each example.

    Returns:
        Dataset: The annotated dataset with added contexts and/or answers.
    """

    def process_example(example: Dict[str, Any]) -> Dict[str, Any]:
        if context_call:
            example["contexts"] = context_call(example)
        if answer_call:
            example["answer"] = answer_call(example)
        return example

    return dataset.map(process_example)


def rmse(
    input_trues: Sequence[float],
    input_preds: Sequence[float],
) -> Optional[float]:
    r"""Calculate Root Mean Squared Error (RMSE).

    Args:
        input_trues (Sequence[float]): Ground truth values.
        input_preds (Sequence[float]): Predicted values.

    Returns:
        Optional[float]: RMSE value, or None if inputs have different lengths.
    """
    if len(input_trues) != len(input_preds):
        logger.warning("Input lengths mismatch in RMSE calculation")
        return None

    trues = np.array(input_trues)
    preds = np.array(input_preds, dtype=float)

    # Ignore NaN values in predictions
    eval_idx = ~np.isnan(preds)
    if not np.any(eval_idx):
        logger.warning("No valid predictions for RMSE calculation")
        return None

    trues = trues[eval_idx]
    preds = preds[eval_idx]

    return float(np.sqrt(np.mean((preds - trues) ** 2)))


def auroc(trues: Sequence[bool], preds: Sequence[float]) -> float:
    r"""Calculate Area Under Receiver Operating Characteristic Curve (AUROC).

    Args:
        trues (Sequence[bool]): Ground truth binary values.
        preds (Sequence[float]): Predicted probability values.

    Returns:
        float: AUROC score.
    """
    from sklearn.metrics import roc_auc_score  # type: ignore[import-untyped]

    eval_idx = ~np.isnan(preds)
    if not np.any(eval_idx):
        logger.warning("No valid predictions for AUROC calculation")
        return 0.5  # Return random classifier score

    return float(
        roc_auc_score(np.array(trues)[eval_idx], np.array(preds)[eval_idx])
    )


def ragas_calculate_metrics(
    dataset: Dataset,
    pred_context_relevance_field: Optional[str],
    pred_faithfulness_field: Optional[str],
    metrics_to_evaluate: Optional[List[str]] = None,
    ground_truth_context_relevance_field: str = "relevance_score",
    ground_truth_faithfulness_field: str = "adherence_score",
) -> Dict[str, Optional[float]]:
    r"""Calculate RAGAS evaluation metrics.

    Args:
        dataset (Dataset): The dataset containing predictions and ground truth.
        pred_context_relevance_field (Optional[str]): Field name for predicted
            context relevance.
        pred_faithfulness_field (Optional[str]): Field name for predicted
            faithfulness.
        metrics_to_evaluate (Optional[List[str]]): List of metrics to evaluate.
        ground_truth_context_relevance_field (str): Field name for ground truth
            relevance.
        ground_truth_faithfulness_field (str): Field name for ground truth
            adherence.

    Returns:
        Dict[str, Optional[float]]: Dictionary of calculated metrics.
    """
    metrics_to_evaluate = metrics_to_evaluate or [
        "context_relevancy",
        "faithfulness",
    ]
    calculated_metrics: Dict[str, Optional[float]] = {}

    if (
        "context_relevancy" in metrics_to_evaluate
        and pred_context_relevance_field
    ):
        trues_relevance = dataset[ground_truth_context_relevance_field]
        preds_relevance = dataset[pred_context_relevance_field]
        calculated_metrics["relevance_rmse"] = rmse(
            trues_relevance, preds_relevance
        )

    if "faithfulness" in metrics_to_evaluate and pred_faithfulness_field:
        trues_hallucination = ~np.array(
            dataset[ground_truth_faithfulness_field]
        )
        preds_hallucination = 1 - np.array(
            dataset[pred_faithfulness_field], dtype=float
        )
        calculated_metrics["hallucination_auroc"] = auroc(
            trues_hallucination.tolist(), preds_hallucination.tolist()
        )

    return calculated_metrics


def ragas_evaluate_dataset(
    dataset: Dataset,
    contexts_field_name: Optional[str],
    answer_field_name: Optional[str],
    metrics_to_evaluate: Optional[List[str]] = None,
) -> Dataset:
    r"""Evaluate the dataset using RAGAS metrics.

    Args:
        dataset (Dataset): Input dataset to evaluate.
        contexts_field_name (Optional[str]): Field name containing contexts.
        answer_field_name (Optional[str]): Field name containing answers.
        metrics_to_evaluate (Optional[List[str]]): List of metrics to evaluate.

    Returns:
        Dataset: Dataset with added evaluation metrics.
    """
    from ragas import evaluate
    from ragas.metrics import (  # type: ignore[import-untyped]
        context_relevancy,
        faithfulness,
    )

    metrics_to_evaluate = metrics_to_evaluate or [
        "context_relevancy",
        "faithfulness",
    ]

    # Rename fields if necessary
    if (
        contexts_field_name
        and contexts_field_name != RagasFields.INPUT_CONTEXT
    ):
        dataset = dataset.rename_column(
            contexts_field_name, RagasFields.INPUT_CONTEXT
        )
    if answer_field_name and answer_field_name != RagasFields.INPUT_ANSWER:
        dataset = dataset.rename_column(
            answer_field_name, RagasFields.INPUT_ANSWER
        )

    metrics = []
    if "context_relevancy" in metrics_to_evaluate:
        metrics.append(context_relevancy)
    if "faithfulness" in metrics_to_evaluate:
        metrics.append(faithfulness)

    ragas_result = evaluate(dataset, metrics=metrics)
    return Dataset.from_pandas(ragas_result.to_pandas())


class RAGBenchBenchmark(BaseBenchmark):
    r"""RAGBench Benchmark for evaluating RAG performance.

    This benchmark uses the rungalileo/ragbench dataset to evaluate
    retrieval-augmented generation (RAG) systems. It measures context
    relevancy and faithfulness metrics as described in
    https://arxiv.org/abs/2407.11005.

    Args:
        processes (int, optional): Number of processes for parallel processing.
        subset (str, optional): Dataset subset to use (e.g., "hotpotqa").
        split (str, optional): Dataset split to use (e.g., "test").
    """

    def __init__(
        self,
        processes: int = 1,
        subset: Literal[
            "covidqa",
            "cuad",
            "delucionqa",
            "emanual",
            "expertqa",
            "finqa",
            "hagrid",
            "hotpotqa",
            "msmarco",
            "pubmedqa",
            "tatqa",
            "techqa",
        ] = "hotpotqa",
        split: Literal["train", "test", "validation"] = "test",
    ) -> None:
        super().__init__("ragbench", "rag_bench", "", processes)
        self.subset = subset
        self.split = split
        self.dataset: Optional[Dataset] = None

    def download(self):
        r"""Download the RAGBench dataset."""
        try:
            self.dataset = load_dataset(
                "rungalileo/ragbench", self.subset, split=self.split
            )
        except Exception as e:
            logger.error(f"Failed to download dataset: {e}")
            raise

    def load(self, force_download: bool = False):
        r"""Load the RAGBench dataset.

        Args:
            force_download (bool, optional): Whether to force download the
                data.
        """
        if force_download or self.dataset is None:
            logger.info(
                "%s dataset",
                "Force downloading" if force_download else "Loading",
            )
            self.download()

    def run(  # type: ignore[override, return]
        self,
        agent: ChatAgent,
        auto_retriever: AutoRetriever,
    ) -> Dict[str, Optional[float]]:
        r"""Run the benchmark evaluation.

        Args:
            agent (ChatAgent): Chat agent for generating answers.
            auto_retriever (AutoRetriever): Retriever for finding relevant
                contexts.

        Returns:
            Dict[str, Optional[float]]: Dictionary of evaluation metrics.
        """

        def context_call(example):
            retrieved_info = auto_retriever.run_vector_retriever(
                query=example['question'],
                contents=example['documents'],
                top_k=1,
                return_detailed_info=True,
                similarity_threshold=0.5,
            )
            return [c['text'] for c in retrieved_info['Retrieved Context']]

        def answer_call(example: Dict[str, Any]) -> str:
            user_msg = str(example)
            assistant_response = agent.step(user_msg)
            return assistant_response.msg.content

        # Annotate the dataset
        annotated_ds = annotate_dataset(
            self.dataset, context_call, answer_call
        )
        evaluated_ds = ragas_evaluate_dataset(
            annotated_ds,
            contexts_field_name="contexts",
            answer_field_name="answer",
            metrics_to_evaluate=["context_relevancy", "faithfulness"],
        )

        return ragas_calculate_metrics(
            evaluated_ds,
            pred_context_relevance_field="context_relevancy",
            pred_faithfulness_field="faithfulness",
        )
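An end-to-end sketch of the benchmark above; the agent prompt and retriever configuration are assumptions for illustration (neither is specified by this file):

# Hypothetical RAGBench run; configuration details are assumed.
from camel.agents import ChatAgent
from camel.benchmarks.ragbench import RAGBenchBenchmark
from camel.retrievers import AutoRetriever

benchmark = RAGBenchBenchmark(subset="hotpotqa", split="test")
benchmark.load()
metrics = benchmark.run(
    agent=ChatAgent("Answer the question from the given context."),
    auto_retriever=AutoRetriever(),
)
print(metrics)  # e.g. {'relevance_rmse': ..., 'hallucination_auroc': ...}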
camel/bots/__init__.py
ADDED
@@ -0,0 +1,34 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .discord import DiscordApp
from .slack.models import (
    SlackAppMentionEventBody,
    SlackAppMentionEventProfile,
    SlackAuthProfile,
    SlackEventBody,
    SlackEventProfile,
)
from .slack.slack_app import SlackApp
from .telegram_bot import TelegramBot

__all__ = [
    'DiscordApp',
    'SlackApp',
    'SlackAppMentionEventBody',
    'SlackAppMentionEventProfile',
    'SlackAuthProfile',
    'SlackEventBody',
    'SlackEventProfile',
    'TelegramBot',
]
camel/bots/discord/__init__.py
ADDED
@@ -0,0 +1,26 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .discord_app import DiscordApp
from .discord_installation import DiscordInstallation
from .discord_store import (
    DiscordBaseInstallationStore,
    DiscordSQLiteInstallationStore,
)

__all__ = [
    "DiscordApp",
    "DiscordInstallation",
    "DiscordSQLiteInstallationStore",
    "DiscordBaseInstallationStore",
]
camel/bots/discord/discord_app.py
ADDED
@@ -0,0 +1,384 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
#
|
| 6 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 7 |
+
#
|
| 8 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 9 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 10 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 11 |
+
# See the License for the specific language governing permissions and
|
| 12 |
+
# limitations under the License.
|
| 13 |
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
| 14 |
+
import os
|
| 15 |
+
from datetime import datetime, timedelta
|
| 16 |
+
from typing import TYPE_CHECKING, List, Optional
|
| 17 |
+
|
| 18 |
+
import discord
|
| 19 |
+
import httpx
|
| 20 |
+
from fastapi import FastAPI
|
| 21 |
+
|
| 22 |
+
from camel.bots.discord.discord_installation import DiscordInstallation
|
| 23 |
+
from camel.logger import get_logger
|
| 24 |
+
from camel.utils import api_keys_required, dependencies_required
|
| 25 |
+
|
| 26 |
+
from .discord_store import DiscordBaseInstallationStore
|
| 27 |
+
|
| 28 |
+
if TYPE_CHECKING:
|
| 29 |
+
from discord import Message
|
| 30 |
+
|
| 31 |
+
logger = get_logger(__name__)
|
| 32 |
+
|
| 33 |
+
TOKEN_URL = "https://discord.com/api/oauth2/token"
|
| 34 |
+
USER_URL = "https://discord.com/api/users/@me"
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class DiscordApp:
|
| 38 |
+
r"""A class representing a Discord app that uses the `discord.py` library
|
| 39 |
+
to interact with Discord servers.
|
| 40 |
+
|
| 41 |
+
This bot can respond to messages in specific channels and only reacts to
|
| 42 |
+
messages that mention the bot.
|
| 43 |
+
|
| 44 |
+
Attributes:
|
| 45 |
+
channel_ids (Optional[List[int]]): A list of allowed channel IDs. If
|
| 46 |
+
provided, the bot will only respond to messages in these channels.
|
| 47 |
+
token (Optional[str]): The Discord bot token used for authentication.
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
@dependencies_required('discord')
|
| 51 |
+
@api_keys_required(
|
| 52 |
+
[
|
| 53 |
+
("token", "DISCORD_BOT_TOKEN"),
|
| 54 |
+
]
|
| 55 |
+
)
|
| 56 |
+
def __init__(
|
| 57 |
+
self,
|
| 58 |
+
channel_ids: Optional[List[int]] = None,
|
| 59 |
+
token: Optional[str] = None,
|
| 60 |
+
client_id: Optional[str] = None,
|
| 61 |
+
client_secret: Optional[str] = None,
|
| 62 |
+
redirect_uri: Optional[str] = None,
|
| 63 |
+
installation_store: Optional[DiscordBaseInstallationStore] = None,
|
| 64 |
+
intents: Optional[discord.Intents] = None,
|
| 65 |
+
) -> None:
|
| 66 |
+
r"""Initialize the DiscordApp instance by setting up the Discord client
|
| 67 |
+
and event handlers.
|
| 68 |
+
|
| 69 |
+
Args:
|
| 70 |
+
channel_ids (Optional[List[int]]): A list of allowed channel IDs.
|
| 71 |
+
The bot will only respond to messages in these channels if
|
| 72 |
+
provided. (default: :obj:`None`)
|
| 73 |
+
token (Optional[str]): The Discord bot token for authentication.
|
| 74 |
+
If not provided, the token will be retrieved from the
|
| 75 |
+
environment variable `DISCORD_TOKEN`. (default: :obj:`None`)
|
| 76 |
+
client_id (str, optional): The client ID for Discord OAuth.
|
| 77 |
+
(default: :obj:`None`)
|
| 78 |
+
client_secret (Optional[str]): The client secret for Discord OAuth.
|
| 79 |
+
(default: :obj:`None`)
|
| 80 |
+
redirect_uri (str): The redirect URI for OAuth callbacks.
|
| 81 |
+
(default: :obj:`None`)
|
| 82 |
+
installation_store (DiscordAsyncInstallationStore): The database
|
| 83 |
+
stores all information of all installations.
|
| 84 |
+
(default: :obj:`None`)
|
| 85 |
+
intents (discord.Intents): The Discord intents of this app.
|
| 86 |
+
(default: :obj:`None`)
|
| 87 |
+
|
| 88 |
+
Raises:
|
| 89 |
+
ValueError: If the `DISCORD_BOT_TOKEN` is not found in environment
|
| 90 |
+
variables.
|
| 91 |
+
"""
|
| 92 |
+
self.token = token or os.getenv("DISCORD_BOT_TOKEN")
|
| 93 |
+
self.channel_ids = channel_ids
|
| 94 |
+
self.installation_store = installation_store
|
| 95 |
+
|
| 96 |
+
if not intents:
|
| 97 |
+
intents = discord.Intents.all()
|
| 98 |
+
intents.message_content = True
|
| 99 |
+
intents.guilds = True
|
| 100 |
+
|
| 101 |
+
self._client = discord.Client(intents=intents)
|
| 102 |
+
|
| 103 |
+
# Register event handlers
|
| 104 |
+
self._client.event(self.on_ready)
|
| 105 |
+
self._client.event(self.on_message)
|
| 106 |
+
|
| 107 |
+
# OAuth flow
|
| 108 |
+
self.client_id = client_id or os.getenv("DISCORD_CLIENT_ID")
|
| 109 |
+
self.client_secret = client_secret or os.getenv(
|
| 110 |
+
"DISCORD_CLIENT_SECRET"
|
| 111 |
+
)
|
| 112 |
+
self.redirect_uri = redirect_uri
|
| 113 |
+
|
| 114 |
+
self.oauth_flow = bool(
|
| 115 |
+
self.client_id
|
| 116 |
+
and self.client_secret
|
| 117 |
+
and self.redirect_uri
|
| 118 |
+
and self.installation_store
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
self.app = FastAPI()
|
| 122 |
+
|
| 123 |
+
async def start(self):
|
| 124 |
+
r"""Asynchronously start the Discord bot using its token.
|
| 125 |
+
|
| 126 |
+
This method starts the bot and logs into Discord asynchronously using
|
| 127 |
+
the provided token. It should be awaited when used in an async
|
| 128 |
+
environment.
|
| 129 |
+
"""
|
| 130 |
+
await self._client.start(self.token)
|
| 131 |
+
|
| 132 |
+
def run(self) -> None:
|
| 133 |
+
r"""Start the Discord bot using its token.
|
| 134 |
+
|
| 135 |
+
This method starts the bot and logs into Discord synchronously using
|
| 136 |
+
the provided token. It blocks execution and keeps the bot running.
|
| 137 |
+
"""
|
| 138 |
+
self._client.run(self.token) # type: ignore[arg-type]
|
| 139 |
+
|
| 140 |
+
async def exchange_code_for_token_response(
|
| 141 |
+
self, code: str
|
| 142 |
+
) -> Optional[str]:
|
| 143 |
+
r"""Exchange the authorization code for an access token.
|
| 144 |
+
|
| 145 |
+
Args:
|
| 146 |
+
code (str): The authorization code received from Discord after
|
| 147 |
+
user authorization.
|
| 148 |
+
|
| 149 |
+
Returns:
|
| 150 |
+
Optional[str]: The access token if successful, otherwise None.
|
| 151 |
+
|
| 152 |
+
Raises:
|
| 153 |
+
ValueError: If OAuth configuration is incomplete or invalid.
|
| 154 |
+
httpx.RequestError: If there is a network issue during the request.
|
| 155 |
+
"""
|
| 156 |
+
if not self.oauth_flow:
|
| 157 |
+
logger.warning(
|
| 158 |
+
"OAuth is not enabled. Missing client_id, "
|
| 159 |
+
"client_secret, or redirect_uri."
|
| 160 |
+
)
|
| 161 |
+
return None
|
| 162 |
+
data = {
|
| 163 |
+
"client_id": self.client_id,
|
| 164 |
+
"client_secret": self.client_secret,
|
| 165 |
+
"grant_type": "authorization_code",
|
| 166 |
+
"code": code,
|
| 167 |
+
"redirect_uri": self.redirect_uri,
|
| 168 |
+
}
|
| 169 |
+
headers = {"Content-Type": "application/x-www-form-urlencoded"}
|
| 170 |
+
try:
|
| 171 |
+
async with httpx.AsyncClient() as client:
|
| 172 |
+
response = await client.post(
|
| 173 |
+
TOKEN_URL, data=data, headers=headers
|
| 174 |
+
)
|
| 175 |
+
if response.status_code != 200:
|
| 176 |
+
logger.error(f"Failed to exchange code: {response.text}")
|
| 177 |
+
return None
|
| 178 |
+
response_data = response.json()
|
| 179 |
+
|
| 180 |
+
return response_data
|
| 181 |
+
except (httpx.RequestError, ValueError) as e:
|
| 182 |
+
logger.error(f"Error during token fetch: {e}")
|
| 183 |
+
return None
|
| 184 |
+
|
| 185 |
+
async def get_user_info(self, access_token: str) -> Optional[dict]:
|
| 186 |
+
r"""Retrieve user information using the access token.
|
| 187 |
+
|
| 188 |
+
Args:
|
| 189 |
+
access_token (str): The access token received from Discord.
|
| 190 |
+
|
| 191 |
+
Returns:
|
| 192 |
+
dict: The user information retrieved from Discord.
|
| 193 |
+
"""
|
| 194 |
+
if not self.oauth_flow:
|
| 195 |
+
logger.warning(
|
| 196 |
+
"OAuth is not enabled. Missing client_id, "
|
| 197 |
+
"client_secret, or redirect_uri."
|
| 198 |
+
)
|
| 199 |
+
return None
|
| 200 |
+
headers = {"Authorization": f"Bearer {access_token}"}
|
| 201 |
+
async with httpx.AsyncClient() as client:
|
| 202 |
+
user_response = await client.get(USER_URL, headers=headers)
|
| 203 |
+
return user_response.json()
|
| 204 |
+
|
| 205 |
+
async def refresh_access_token(self, refresh_token: str) -> Optional[str]:
|
| 206 |
+
r"""Refresh the access token using a refresh token.
|
| 207 |
+
|
| 208 |
+
Args:
|
| 209 |
+
refresh_token (str): The refresh token issued by Discord that
|
| 210 |
+
can be used to obtain a new access token.
|
| 211 |
+
|
| 212 |
+
Returns:
|
| 213 |
+
Optional[str]: The new access token if successful, otherwise None.
|
| 214 |
+
"""
|
| 215 |
+
if not self.oauth_flow:
|
| 216 |
+
logger.warning(
|
| 217 |
+
"OAuth is not enabled. Missing client_id, "
|
| 218 |
+
"client_secret, or redirect_uri."
|
| 219 |
+
)
|
| 220 |
+
return None
|
| 221 |
+
data = {
|
| 222 |
+
"client_id": self.client_id,
|
| 223 |
+
"client_secret": self.client_secret,
|
| 224 |
+
"grant_type": "refresh_token",
|
| 225 |
+
"refresh_token": refresh_token,
|
| 226 |
+
"redirect_uri": self.redirect_uri,
|
| 227 |
+
}
|
| 228 |
+
headers = {"Content-Type": "application/x-www-form-urlencoded"}
|
| 229 |
+
async with httpx.AsyncClient() as client:
|
| 230 |
+
response = await client.post(TOKEN_URL, data=data, headers=headers)
|
| 231 |
+
if response.status_code != 200:
|
| 232 |
+
logger.error(f"Failed to refresh token: {response.text}")
|
| 233 |
+
return None
|
| 234 |
+
response_data = response.json()
|
| 235 |
+
return response_data.get("access_token")
|
| 236 |
+
|
| 237 |
+
async def get_valid_access_token(self, guild_id: str) -> Optional[str]:
|
| 238 |
+
r"""Retrieve a valid access token for the specified guild.
|
| 239 |
+
|
| 240 |
+
This method attempts to retrieve an access token for a specific guild.
|
| 241 |
+
If the current access token is expired, it will refresh the token using
|
| 242 |
+
the refresh token.
|
| 243 |
+
|
| 244 |
+
Args:
|
| 245 |
+
guild_id (str): The ID of the guild to retrieve the access
|
| 246 |
+
token for.
|
| 247 |
+
|
| 248 |
+
Returns:
|
| 249 |
+
Optional[str]: The valid access token if successful,
|
| 250 |
+
otherwise None.
|
| 251 |
+
"""
|
| 252 |
+
if not self.oauth_flow:
|
| 253 |
+
logger.warning(
|
| 254 |
+
"OAuth is not enabled. Missing client_id, "
|
| 255 |
+
"client_secret, or redirect_uri."
|
| 256 |
+
)
|
| 257 |
+
return None
|
| 258 |
+
assert self.installation_store is not None
|
| 259 |
+
installation = await self.installation_store.find_by_guild(
|
| 260 |
+
guild_id=guild_id
|
| 261 |
+
)
|
| 262 |
+
if not installation:
|
| 263 |
+
logger.error(f"No installation found for guild: {guild_id}")
|
| 264 |
+
return None
|
| 265 |
+
|
| 266 |
+
if (
|
| 267 |
+
installation.token_expires_at
|
| 268 |
+
and datetime.now() >= installation.token_expires_at
|
| 269 |
+
):
|
| 270 |
+
logger.info(
|
| 271 |
+
f"Access token expired for guild: {guild_id}, "
|
| 272 |
+
f"refreshing token..."
|
| 273 |
+
)
|
| 274 |
+
new_access_token = await self.refresh_access_token(
|
| 275 |
+
installation.refresh_token
|
| 276 |
+
)
|
| 277 |
+
if new_access_token:
|
| 278 |
+
installation.access_token = new_access_token
|
| 279 |
+
installation.token_expires_at = datetime.now() + timedelta(
|
| 280 |
+
seconds=3600
|
| 281 |
+
)
|
| 282 |
+
await self.installation_store.save(installation)
|
| 283 |
+
return new_access_token
|
| 284 |
+
else:
|
| 285 |
+
logger.error(
|
| 286 |
+
f"Failed to refresh access token for guild: {guild_id}"
|
| 287 |
+
)
|
| 288 |
+
return None
|
| 289 |
+
|
| 290 |
+
return installation.access_token
|
| 291 |
+
|
| 292 |
+
async def save_installation(
|
| 293 |
+
self,
|
| 294 |
+
guild_id: str,
|
| 295 |
+
access_token: str,
|
| 296 |
+
refresh_token: str,
|
| 297 |
+
expires_in: int,
|
| 298 |
+
):
|
| 299 |
+
r"""Save the installation information for a given guild.
|
| 300 |
+
|
| 301 |
+
Args:
|
| 302 |
+
guild_id (str): The ID of the guild where the bot is installed.
|
| 303 |
+
access_token (str): The access token for the guild.
|
| 304 |
+
refresh_token (str): The refresh token for the guild.
|
| 305 |
+
expires_in: (int): The expiration time of the
|
| 306 |
+
access token.
|
| 307 |
+
"""
|
| 308 |
+
if not self.oauth_flow:
|
| 309 |
+
logger.warning(
|
| 310 |
+
"OAuth is not enabled. Missing client_id, "
|
| 311 |
+
"client_secret, or redirect_uri."
|
| 312 |
+
)
|
| 313 |
+
return None
|
| 314 |
+
assert self.installation_store is not None
|
| 315 |
+
expires_at = datetime.now() + timedelta(seconds=expires_in)
|
| 316 |
+
installation = DiscordInstallation(
|
| 317 |
+
guild_id=guild_id,
|
| 318 |
+
access_token=access_token,
|
| 319 |
+
refresh_token=refresh_token,
|
| 320 |
+
installed_at=datetime.now(),
|
| 321 |
+
token_expires_at=expires_at,
|
| 322 |
+
)
|
| 323 |
+
await self.installation_store.save(installation)
|
| 324 |
+
logger.info(f"Installation saved for guild: {guild_id}")
|
| 325 |
+
|
| 326 |
+
async def remove_installation(self, guild: discord.Guild):
|
| 327 |
+
r"""Remove the installation for a given guild.
|
| 328 |
+
|
| 329 |
+
Args:
|
| 330 |
+
guild (discord.Guild): The guild from which the bot is
|
| 331 |
+
being removed.
|
| 332 |
+
"""
|
| 333 |
+
if not self.oauth_flow:
|
| 334 |
+
logger.warning(
|
| 335 |
+
"OAuth is not enabled. Missing client_id, "
|
| 336 |
+
"client_secret, or redirect_uri."
|
| 337 |
+
)
|
| 338 |
+
return None
|
| 339 |
+
assert self.installation_store is not None
|
| 340 |
+
await self.installation_store.delete(guild_id=str(guild.id))
|
| 341 |
+
print(f"Bot removed from guild: {guild.id}")
|
| 342 |
+
|
| 343 |
+
async def on_ready(self) -> None:
|
| 344 |
+
r"""Event handler that is called when the bot has successfully
|
| 345 |
+
connected to the Discord server.
|
| 346 |
+
|
| 347 |
+
When the bot is ready and logged into Discord, it prints a message
|
| 348 |
+
displaying the bot's username.
|
| 349 |
+
"""
|
| 350 |
+
logger.info(f'We have logged in as {self._client.user}')
|
| 351 |
+
|
| 352 |
+
async def on_message(self, message: 'Message') -> None:
|
| 353 |
+
r"""Event handler for processing incoming messages.
|
| 354 |
+
|
| 355 |
+
This method is called whenever a new message is received by the bot. It
|
| 356 |
+
will ignore messages sent by the bot itself, only respond to messages
|
| 357 |
+
in allowed channels (if specified), and only to messages that mention
|
| 358 |
+
the bot.
|
| 359 |
+
|
| 360 |
+
Args:
|
| 361 |
+
message (discord.Message): The message object received from
|
| 362 |
+
Discord.
|
| 363 |
+
"""
|
| 364 |
+
# If the message author is the bot itself,
|
| 365 |
+
# do not respond to this message
|
| 366 |
+
if message.author == self._client.user:
|
| 367 |
+
return
|
| 368 |
+
|
| 369 |
+
# If allowed channel IDs are provided,
|
| 370 |
+
# only respond to messages in those channels
|
| 371 |
+
if self.channel_ids and message.channel.id not in self.channel_ids:
|
| 372 |
+
return
|
| 373 |
+
|
| 374 |
+
# Only respond to messages that mention the bot
|
| 375 |
+
if not self._client.user or not self._client.user.mentioned_in(
|
| 376 |
+
message
|
| 377 |
+
):
|
| 378 |
+
return
|
| 379 |
+
|
| 380 |
+
logger.info(f"Received message: {message.content}")
|
| 381 |
+
|
| 382 |
+
@property
|
| 383 |
+
def client(self):
|
| 384 |
+
return self._client
|
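
A minimal usage sketch for the class above (not part of the committed file); the channel ID is a placeholder and a real `DISCORD_BOT_TOKEN` must be present in the environment:

# Usage sketch for DiscordApp (placeholder channel ID; assumes
# DISCORD_BOT_TOKEN is set in the environment).
import asyncio

from camel.bots.discord import DiscordApp

app = DiscordApp(channel_ids=[123456789])

# Either block synchronously with app.run(), or drive the client from
# an existing asyncio program:
asyncio.run(app.start())
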
camel/bots/discord/discord_installation.py
ADDED
@@ -0,0 +1,64 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from datetime import datetime
from typing import Optional


class DiscordInstallation:
    r"""Represents an installation of a Discord application in a
    specific guild (server).

    Attributes:
        guild_id (str): The unique identifier for the Discord guild (server)
            where the application is installed.
        access_token (str): The access token used to authenticate API
            requests for the installed application.
        refresh_token (str): The token used to refresh the access token when
            it expires.
        installed_at (datetime): The timestamp indicating when the
            application was installed in the guild.
        token_expires_at (Optional[datetime]): The optional timestamp
            indicating when the access token will expire. Defaults to None
            if the token does not have an expiration time.
    """

    def __init__(
        self,
        guild_id: str,
        access_token: str,
        refresh_token: str,
        installed_at: datetime,
        token_expires_at: Optional[datetime] = None,
    ):
        r"""Initialize the DiscordInstallation.

        Args:
            guild_id (str): The unique identifier for the Discord guild
                (server) where the application is installed.
            access_token (str): The access token used to authenticate API
                requests for the installed application.
            refresh_token (str): The token used to refresh the access token
                when it expires.
            installed_at (datetime): The timestamp indicating when the
                application was installed in the guild.
            token_expires_at (Optional[datetime]): The optional timestamp
                indicating when the access token will expire. Defaults to
                None if the token does not have an expiration time.
                (default: :obj:`None`)
        """
        self.guild_id = guild_id
        self.access_token = access_token
        self.refresh_token = refresh_token
        self.installed_at = installed_at
        self.token_expires_at = token_expires_at
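
A short construction sketch for the record class above (placeholder IDs and tokens):

# Sketch: building an installation record by hand.
from datetime import datetime, timedelta

from camel.bots.discord import DiscordInstallation

installation = DiscordInstallation(
    guild_id="123456789",            # placeholder guild ID
    access_token="<access-token>",   # placeholder
    refresh_token="<refresh-token>", # placeholder
    installed_at=datetime.now(),
    token_expires_at=datetime.now() + timedelta(seconds=3600),
)
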
camel/bots/discord/discord_store.py
ADDED
@@ -0,0 +1,160 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

from typing import Optional

from .discord_installation import DiscordInstallation


class DiscordBaseInstallationStore:
    r"""Abstract base class for managing Discord installations.

    This class defines the interface for database operations related to
    storing and retrieving Discord installation data. Subclasses must
    implement these methods to handle database-specific logic.
    """

    async def init(self):
        r"""Initializes the database connection or structure."""
        pass

    async def save(self, installation: DiscordInstallation):
        r"""Saves or updates a Discord installation record."""
        pass

    async def find_by_guild(
        self, guild_id: str
    ) -> Optional[DiscordInstallation]:
        r"""Finds an installation record by guild ID."""
        pass

    async def delete(self, guild_id: str):
        r"""Deletes an installation record by guild ID."""
        pass


class DiscordSQLiteInstallationStore(DiscordBaseInstallationStore):
    r"""SQLite-based implementation for managing Discord installations.

    This class provides methods for initializing the database, saving,
    retrieving, and deleting installation records using SQLite.

    Attributes:
        database (str): Path to the SQLite database file.
    """

    def __init__(self, database: str):
        r"""Initializes the SQLite installation store.

        Args:
            database (str): Path to the SQLite database file.
        """
        self.database = database

    async def init(self):
        r"""Initializes the database by creating the required table if it
        does not exist."""
        import aiosqlite

        async with aiosqlite.connect(self.database) as db:
            await db.execute(
                """
                CREATE TABLE IF NOT EXISTS discord_installations (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    guild_id TEXT NOT NULL UNIQUE,
                    access_token TEXT NOT NULL,
                    refresh_token TEXT NOT NULL,
                    installed_at DATETIME NOT NULL,
                    token_expires_at DATETIME
                );
                """
            )
            await db.commit()

    async def save(self, installation: DiscordInstallation):
        r"""Saves a new installation record or updates an existing one.

        Args:
            installation (DiscordInstallation): The installation data to
                save.
        """
        import aiosqlite

        async with aiosqlite.connect(self.database) as db:
            await db.execute(
                """
                INSERT INTO discord_installations (
                    guild_id, access_token, refresh_token,
                    installed_at, token_expires_at
                ) VALUES (?, ?, ?, ?, ?)
                ON CONFLICT(guild_id) DO UPDATE SET
                    access_token = excluded.access_token,
                    refresh_token = excluded.refresh_token,
                    token_expires_at = excluded.token_expires_at;
                """,
                [
                    installation.guild_id,
                    installation.access_token,
                    installation.refresh_token,
                    installation.installed_at,
                    installation.token_expires_at,
                ],
            )
            await db.commit()

    async def find_by_guild(
        self, guild_id: str
    ) -> Optional[DiscordInstallation]:
        r"""Finds an installation record by guild ID.

        Args:
            guild_id (str): The guild ID to search for.

        Returns:
            Optional[DiscordInstallation]: The installation record if found,
                otherwise None.
        """
        import aiosqlite

        async with aiosqlite.connect(self.database) as db:
            async with db.execute(
                "SELECT guild_id, access_token, refresh_token, "
                "installed_at, token_expires_at FROM discord_installations "
                "WHERE guild_id = ?",
                [guild_id],
            ) as cursor:
                row = await cursor.fetchone()
                if row:
                    return DiscordInstallation(
                        guild_id=row[0],
                        access_token=row[1],
                        refresh_token=row[2],
                        installed_at=row[3],
                        token_expires_at=row[4],
                    )
                return None

    async def delete(self, guild_id: str):
        r"""Deletes an installation record by guild ID.

        Args:
            guild_id (str): The guild ID of the record to delete.
        """
        import aiosqlite

        async with aiosqlite.connect(self.database) as db:
            await db.execute(
                "DELETE FROM discord_installations WHERE guild_id = ?",
                [guild_id],
            )
            await db.commit()
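
A sketch exercising the SQLite store end to end (requires the `aiosqlite` extra; the database path and record values are placeholders):

# Sketch: init, save, look up, and delete an installation record.
import asyncio
from datetime import datetime

from camel.bots.discord import (
    DiscordInstallation,
    DiscordSQLiteInstallationStore,
)


async def main() -> None:
    store = DiscordSQLiteInstallationStore(database="installations.db")
    await store.init()  # creates the table on first use
    await store.save(
        DiscordInstallation(
            guild_id="123456789",
            access_token="<access-token>",
            refresh_token="<refresh-token>",
            installed_at=datetime.now(),
        )
    )
    found = await store.find_by_guild(guild_id="123456789")
    print(found.access_token if found else "not found")
    await store.delete(guild_id="123456789")


asyncio.run(main())
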
camel/bots/slack/__init__.py
ADDED
@@ -0,0 +1,30 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .models import (
    SlackAppMentionEventBody,
    SlackAppMentionEventProfile,
    SlackAuthProfile,
    SlackEventBody,
    SlackEventProfile,
)
from .slack_app import SlackApp

__all__ = [
    'SlackApp',
    'SlackAppMentionEventBody',
    'SlackAppMentionEventProfile',
    'SlackAuthProfile',
    'SlackEventBody',
    'SlackEventProfile',
]
camel/bots/slack/models.py
ADDED
@@ -0,0 +1,158 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from typing import Optional

from pydantic import BaseModel


class SlackAuthProfile(BaseModel):
    r"""Represents the authorization profile within a Slack event.

    Events will contain a single, compact authorizations field that shows
    one installation of your app that the event is visible to.
    In other words, lists of authorizations will be truncated to one
    element.

    If there's more than one installing party that your app is keeping
    track of, it's best not to rely on the single party listed in
    authorizations to be any particular one.

    To get a full list of who can see events, call the
    apps.event.authorizations.list method after obtaining an app-level
    token. Read more on the changes here; they have taken effect for
    existing apps as of February 24, 2021.

    References:

    - https://api.slack.com/apis/events-api#authorizations
    - https://api.slack.com/changelog/2020-09-15-events-api-truncate-authed-users#no_context
    """

    enterprise_id: Optional[str] = None
    """The ID of the enterprise associated with the authorization."""

    team_id: str
    """The ID of the team associated with the authorization."""

    user_id: str
    """The ID of the user associated with the authorization."""

    is_bot: bool
    """Whether the authorized user is a bot."""

    is_enterprise_install: bool
    """Whether the authorization is for an enterprise installation."""


class SlackEventProfile(BaseModel):
    r"""Represents the detailed profile of a Slack event, including user,
    message, and context data.
    """

    user: str
    """The ID of the user associated with the event."""

    type: str
    """The type of the event (e.g., 'message')."""

    ts: str
    """A timestamp representing when the event was triggered."""

    thread_ts: Optional[str] = None
    """The timestamp of the parent message in a thread."""

    client_msg_id: str
    """A unique ID generated by the client for the message (if available)."""

    text: str
    """The message content text."""

    team: str
    """The ID of the team that the event is associated with."""

    blocks: list
    """The list of message blocks, providing structured information."""

    channel: str
    """The ID of the Slack channel where the event happened."""

    event_ts: str
    """The event-specific timestamp when it occurred."""

    channel_type: Optional[str]
    """The type of Slack channel (e.g., 'channel', 'im')."""


class SlackEventBody(BaseModel):
    r"""Represents the entire body of a Slack event, including the event
    profile, authorization, and context.
    """

    token: str
    """The token to verify the source of the event."""

    team_id: str
    """The ID of the team where the event is happening."""

    context_team_id: Optional[str]
    """The team ID for the shared channel context, if applicable."""

    context_enterprise_id: Optional[str] = None
    """The enterprise ID for the shared channel context, if applicable."""

    api_app_id: str
    """The unique identifier for the Slack app that received the event."""

    event: SlackEventProfile
    """A detailed profile of the event."""

    type: str
    """The overall type of event received (e.g., 'event_callback')."""

    event_id: str
    """A unique identifier assigned to this event by Slack."""

    event_time: int
    """The timestamp (in seconds) representing when the event was
    triggered."""

    authorizations: Optional[list[SlackAuthProfile]] = None
    """An optional list of authorizations that describe which installation
    can see the event."""

    is_ext_shared_channel: bool
    """Indicates if the event is part of a shared channel between different
    organizations."""

    event_context: str
    """A unique string representing the context of the event."""


class SlackAppMentionEventProfile(SlackEventProfile):
    r"""Represents the detailed profile of a Slack event where the app was
    mentioned in a message.
    """

    channel_type: Optional[str] = None
    """The type of Slack channel. It's None for app mentions."""


class SlackAppMentionEventBody(SlackEventBody):
    r"""Represents the entire body of a Slack event where the app was
    mentioned in a message.
    """

    context_team_id: Optional[str] = None
    """The team ID for the shared channel context. It's None for app
    mentions."""

    event: SlackAppMentionEventProfile
    """A detailed profile of the event."""
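
A sketch validating a hypothetical, abbreviated Events API payload against the models above; every ID and timestamp below is a placeholder, not a real Slack value:

# Sketch: parsing a (fabricated) message-event payload with the models.
from camel.bots.slack.models import SlackEventBody

payload = {
    "token": "verification-token",
    "team_id": "T0000001",
    "context_team_id": "T0000001",
    "api_app_id": "A0000001",
    "event": {
        "user": "U0000001",
        "type": "message",
        "ts": "1700000000.000100",
        "client_msg_id": "11111111-2222-3333-4444-555555555555",
        "text": "hello",
        "team": "T0000001",
        "blocks": [],
        "channel": "C0000001",
        "event_ts": "1700000000.000100",
        "channel_type": "channel",
    },
    "type": "event_callback",
    "event_id": "Ev0000001",
    "event_time": 1700000000,
    "is_ext_shared_channel": False,
    "event_context": "1-message-T0000001-C0000001",
}

body = SlackEventBody(**payload)
print(body.event.text)  # -> "hello"
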
camel/bots/slack/slack_app.py
ADDED
@@ -0,0 +1,255 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import logging
import os
from typing import TYPE_CHECKING, Any, Dict, Optional

from slack_sdk.oauth.installation_store.async_installation_store import (
    AsyncInstallationStore,
)
from starlette import requests, responses

from camel.bots.slack.models import (
    SlackAppMentionEventBody,
    SlackAppMentionEventProfile,
    SlackEventBody,
    SlackEventProfile,
)
from camel.utils import dependencies_required

if TYPE_CHECKING:
    from slack_bolt.context.async_context import AsyncBoltContext
    from slack_bolt.context.say.async_say import AsyncSay
    from slack_sdk.web.async_client import AsyncWebClient

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class SlackApp:
    r"""Represents a Slack app that is powered by a Slack Bolt `AsyncApp`.

    This class is responsible for initializing and managing the Slack
    application by setting up event handlers, running the app server, and
    handling events such as messages and mentions from Slack.

    Args:
        token (Optional[str]): Slack API token for authentication.
        scopes (Optional[str]): Slack app scopes for permissions.
        signing_secret (Optional[str]): Signing secret for verifying Slack
            requests.
        client_id (Optional[str]): Slack app client ID.
        client_secret (Optional[str]): Slack app client secret.
        redirect_uri_path (str): The URI path for OAuth redirect, defaults
            to "/slack/oauth_redirect".
        installation_store (Optional[AsyncInstallationStore]): The
            installation store for handling OAuth installations.
    """

    @dependencies_required('slack_bolt')
    def __init__(
        self,
        token: Optional[str] = None,
        scopes: Optional[str] = None,
        signing_secret: Optional[str] = None,
        client_id: Optional[str] = None,
        client_secret: Optional[str] = None,
        redirect_uri_path: str = "/slack/oauth_redirect",
        installation_store: Optional[AsyncInstallationStore] = None,
    ) -> None:
        r"""Initializes the SlackApp instance by setting up the Slack Bolt
        app and configuring event handlers and OAuth settings.

        Args:
            token (Optional[str]): The Slack API token.
            scopes (Optional[str]): The scopes for Slack app permissions.
            signing_secret (Optional[str]): The signing secret for verifying
                requests.
            client_id (Optional[str]): The Slack app client ID.
            client_secret (Optional[str]): The Slack app client secret.
            redirect_uri_path (str): The URI path for handling OAuth
                redirects (default is "/slack/oauth_redirect").
            installation_store (Optional[AsyncInstallationStore]): An
                optional installation store for OAuth installations.
        """
        from slack_bolt.adapter.starlette.async_handler import (
            AsyncSlackRequestHandler,
        )
        from slack_bolt.app.async_app import AsyncApp
        from slack_bolt.oauth.async_oauth_settings import AsyncOAuthSettings

        self.token: Optional[str] = token or os.getenv("SLACK_TOKEN")
        self.scopes: Optional[str] = scopes or os.getenv("SLACK_SCOPES")
        self.signing_secret: Optional[str] = signing_secret or os.getenv(
            "SLACK_SIGNING_SECRET"
        )
        self.client_id: Optional[str] = client_id or os.getenv(
            "SLACK_CLIENT_ID"
        )
        self.client_secret: Optional[str] = client_secret or os.getenv(
            "SLACK_CLIENT_SECRET"
        )

        if not all([self.token, self.scopes, self.signing_secret]):
            raise ValueError(
                "`SLACK_TOKEN`, `SLACK_SCOPES`, and `SLACK_SIGNING_SECRET` "
                "environment variables must be set. Get it here: "
                "`https://api.slack.com/apps`."
            )

        # Setup OAuth settings if client ID and secret are provided
        if self.client_id and self.client_secret:
            self._app = AsyncApp(
                oauth_settings=AsyncOAuthSettings(
                    client_id=self.client_id,
                    client_secret=self.client_secret,
                    scopes=self.scopes,
                    redirect_uri_path=redirect_uri_path,
                ),
                logger=logger,
                signing_secret=self.signing_secret,
                installation_store=installation_store,
                token=self.token,
            )
        else:
            # Initialize Slack Bolt AsyncApp with settings
            self._app = AsyncApp(
                logger=logger,
                signing_secret=self.signing_secret,
                installation_store=installation_store,
                token=self.token,
            )

        self._handler = AsyncSlackRequestHandler(self._app)
        self.setup_handlers()

    def setup_handlers(self) -> None:
        r"""Sets up the event handlers for Slack events, such as
        `app_mention` and `message`.

        This method registers the `app_mention` and `on_message` event
        handlers with the Slack Bolt app to respond to Slack events.
        """
        self._app.event("app_mention")(self.app_mention)
        self._app.event("message")(self.on_message)

    def run(
        self,
        port: int = 3000,
        path: str = "/slack/events",
        host: Optional[str] = None,
    ) -> None:
        r"""Starts the Slack Bolt app server to listen for incoming Slack
        events.

        Args:
            port (int): The port on which the server should run (default is
                3000).
            path (str): The endpoint path for receiving Slack events
                (default is "/slack/events").
            host (Optional[str]): The hostname to bind the server (default
                is None).
        """
        self._app.start(port=port, path=path, host=host)

    async def handle_request(
        self, request: requests.Request
    ) -> responses.Response:
        r"""Handles incoming requests from Slack through the request
        handler.

        Args:
            request (Request): A Starlette request object representing the
                incoming request.

        Returns:
            The response generated by the Slack Bolt handler.
        """
        return await self._handler.handle(request)

    async def app_mention(
        self,
        context: "AsyncBoltContext",
        client: "AsyncWebClient",
        event: Dict[str, Any],
        body: Dict[str, Any],
        say: "AsyncSay",
    ) -> None:
        r"""Event handler for `app_mention` events.

        This method is triggered when someone mentions the app in Slack.

        Args:
            context (AsyncBoltContext): The Slack Bolt context for the
                event.
            client (AsyncWebClient): The Slack Web API client.
            event (Dict[str, Any]): The event data for the app mention.
            body (Dict[str, Any]): The full request body from Slack.
            say (AsyncSay): A function to send a response back to the
                channel.
        """
        event_profile = SlackAppMentionEventProfile(**event)
        event_body = SlackAppMentionEventBody(**body)

        logger.info(f"app_mention, context: {context}")
        logger.info(f"app_mention, client: {client}")
        logger.info(f"app_mention, event_profile: {event_profile}")
        logger.info(f"app_mention, event_body: {event_body}")
        logger.info(f"app_mention, say: {say}")

    async def on_message(
        self,
        context: "AsyncBoltContext",
        client: "AsyncWebClient",
        event: Dict[str, Any],
        body: Dict[str, Any],
        say: "AsyncSay",
    ) -> None:
        r"""Event handler for `message` events.

        This method is triggered when the app receives a message in Slack.

        Args:
            context (AsyncBoltContext): The Slack Bolt context for the
                event.
            client (AsyncWebClient): The Slack Web API client.
            event (Dict[str, Any]): The event data for the message.
            body (Dict[str, Any]): The full request body from Slack.
            say (AsyncSay): A function to send a response back to the
                channel.
        """
        await context.ack()

        event_profile = SlackEventProfile(**event)
        event_body = SlackEventBody(**body)

        logger.info(f"on_message, context: {context}")
        logger.info(f"on_message, client: {client}")
        logger.info(f"on_message, event_profile: {event_profile}")
        logger.info(f"on_message, event_body: {event_body}")
        logger.info(f"on_message, say: {say}")

        logger.info(f"Received message: {event_profile.text}")

    def mention_me(
        self, context: "AsyncBoltContext", body: SlackEventBody
    ) -> bool:
        r"""Check if the bot is mentioned in the message.

        Args:
            context (AsyncBoltContext): The Slack Bolt context for the
                event.
            body (SlackEventBody): The body of the Slack event.

        Returns:
            bool: True if the bot is mentioned in the message, False
                otherwise.
        """
        message = body.event.text
        bot_user_id = context.bot_user_id
        mention = f"<@{bot_user_id}>"
        return mention in message
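
A sketch exposing `handle_request` through a Starlette ASGI app, as an alternative to the built-in `run()` server; it assumes `SLACK_TOKEN`, `SLACK_SCOPES`, and `SLACK_SIGNING_SECRET` are set and would be served with e.g. `uvicorn module:api`:

# Sketch: mounting SlackApp behind a Starlette route.
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.routing import Route

from camel.bots.slack import SlackApp

slack_app = SlackApp()


async def slack_events(request: Request):
    # Delegate signature verification and dispatch to the Bolt handler.
    return await slack_app.handle_request(request)


api = Starlette(
    routes=[Route("/slack/events", slack_events, methods=["POST"])]
)
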
camel/bots/telegram_bot.py
ADDED
@@ -0,0 +1,82 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import os
from typing import TYPE_CHECKING, Optional

from camel.agents import ChatAgent
from camel.messages import BaseMessage
from camel.utils import dependencies_required

# Conditionally import telebot types only for type checking
if TYPE_CHECKING:
    from telebot.types import (  # type: ignore[import-untyped]
        Message,
    )


class TelegramBot:
    r"""Represents a Telegram bot that is powered by an agent.

    Attributes:
        chat_agent (ChatAgent): Chat agent that will power the bot.
        telegram_token (str, optional): The bot token.
    """

    @dependencies_required('telebot')
    def __init__(
        self,
        chat_agent: ChatAgent,
        telegram_token: Optional[str] = None,
    ) -> None:
        self.chat_agent = chat_agent

        if not telegram_token:
            self.token = os.getenv('TELEGRAM_TOKEN')
            if not self.token:
                raise ValueError(
                    "`TELEGRAM_TOKEN` not found in environment variables. "
                    "Get it from t.me/BotFather."
                )
        else:
            self.token = telegram_token

        import telebot  # type: ignore[import-untyped]

        self.bot = telebot.TeleBot(token=self.token)

        # Register the message handler within the constructor
        self.bot.message_handler(func=lambda message: True)(self.on_message)

    def run(self) -> None:
        r"""Start the Telegram bot."""
        print("Telegram bot is running...")
        self.bot.infinity_polling()

    def on_message(self, message: 'Message') -> None:
        r"""Handles incoming messages from the user.

        Args:
            message (types.Message): The incoming message object.
        """
        self.chat_agent.reset()

        if not message.text:
            return

        user_msg = BaseMessage.make_user_message(
            role_name="User", content=message.text
        )
        assistant_response = self.chat_agent.step(user_msg)

        self.bot.reply_to(message, assistant_response.msg.content)
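
A wiring sketch for the bot above; it assumes `TELEGRAM_TOKEN` is set and that `ChatAgent` accepts a plain system-message string (true in recent CAMEL versions, hedge accordingly):

# Sketch: running TelegramBot on top of a ChatAgent.
from camel.agents import ChatAgent
from camel.bots import TelegramBot

agent = ChatAgent(system_message="You are a helpful assistant.")
TelegramBot(agent).run()  # blocks, long-polling Telegram for messages
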
camel/configs/__init__.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from .anthropic_config import ANTHROPIC_API_PARAMS, AnthropicConfig
from .base_config import BaseConfig
from .cohere_config import COHERE_API_PARAMS, CohereConfig
from .deepseek_config import DEEPSEEK_API_PARAMS, DeepSeekConfig
from .gemini_config import Gemini_API_PARAMS, GeminiConfig
from .groq_config import GROQ_API_PARAMS, GroqConfig
from .internlm_config import INTERNLM_API_PARAMS, InternLMConfig
from .litellm_config import LITELLM_API_PARAMS, LiteLLMConfig
from .mistral_config import MISTRAL_API_PARAMS, MistralConfig
from .nvidia_config import NVIDIA_API_PARAMS, NvidiaConfig
from .ollama_config import OLLAMA_API_PARAMS, OllamaConfig
from .openai_config import OPENAI_API_PARAMS, ChatGPTConfig
from .qwen_config import QWEN_API_PARAMS, QwenConfig
from .reka_config import REKA_API_PARAMS, RekaConfig
from .openrouter_config import OPENROUTER_API_PARAMS, OpenRouterConfig
from .samba_config import (
    SAMBA_CLOUD_API_PARAMS,
    SAMBA_VERSE_API_PARAMS,
    SambaCloudAPIConfig,
    SambaVerseAPIConfig,
)
from .sglang_config import SGLANG_API_PARAMS, SGLangConfig
from .togetherai_config import TOGETHERAI_API_PARAMS, TogetherAIConfig
from .vllm_config import VLLM_API_PARAMS, VLLMConfig
from .yi_config import YI_API_PARAMS, YiConfig
from .zhipuai_config import ZHIPUAI_API_PARAMS, ZhipuAIConfig

__all__ = [
    'BaseConfig',
    'ChatGPTConfig',
    'OPENAI_API_PARAMS',
    'AnthropicConfig',
    'ANTHROPIC_API_PARAMS',
    'GROQ_API_PARAMS',
    'GroqConfig',
    'LiteLLMConfig',
    'LITELLM_API_PARAMS',
    'NvidiaConfig',
    'NVIDIA_API_PARAMS',
    'OllamaConfig',
    'OLLAMA_API_PARAMS',
    'ZhipuAIConfig',
    'ZHIPUAI_API_PARAMS',
    'GeminiConfig',
    'Gemini_API_PARAMS',
    'VLLMConfig',
    'VLLM_API_PARAMS',
    'SGLangConfig',
    'SGLANG_API_PARAMS',
    'MistralConfig',
    'MISTRAL_API_PARAMS',
    'RekaConfig',
    'REKA_API_PARAMS',
    'SambaVerseAPIConfig',
    'SAMBA_VERSE_API_PARAMS',
    'SambaCloudAPIConfig',
    'SAMBA_CLOUD_API_PARAMS',
    'TogetherAIConfig',
    'TOGETHERAI_API_PARAMS',
    'CohereConfig',
    'COHERE_API_PARAMS',
    'YiConfig',
    'YI_API_PARAMS',
    'QwenConfig',
    'QWEN_API_PARAMS',
    'DeepSeekConfig',
    'DEEPSEEK_API_PARAMS',
    'InternLMConfig',
    'INTERNLM_API_PARAMS',
    'OPENROUTER_API_PARAMS',
    'OpenRouterConfig',
]
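
As a usage sketch (not part of the commit): each exported `*_API_PARAMS` set enumerates the field names a given backend accepts, so callers can filter a generic options dict before dispatching a request. A minimal example, assuming only that the package is importable as `camel` and that `ChatGPTConfig` exposes a `temperature` field like the sibling configs below:

from camel.configs import OPENAI_API_PARAMS, ChatGPTConfig

# Keep only the keys the OpenAI backend understands.
config = ChatGPTConfig(temperature=0.4)
request_kwargs = {
    k: v for k, v in config.as_dict().items() if k in OPENAI_API_PARAMS
}
print(sorted(OPENAI_API_PARAMS))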
camel/configs/anthropic_config.py
ADDED
@@ -0,0 +1,71 @@

# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations

from typing import Any, ClassVar, List, Union

from camel.configs.base_config import BaseConfig
from camel.types import NotGiven


class AnthropicConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    Anthropic API.

    See: https://docs.anthropic.com/claude/reference/complete_post

    Args:
        max_tokens (int, optional): The maximum number of tokens to
            generate before stopping. Note that Anthropic models may stop
            before reaching this maximum. This parameter only specifies the
            absolute maximum number of tokens to generate.
            (default: :obj:`8192`)
        stop_sequences (List[str], optional): Sequences that will cause the
            model to stop generating completion text. Anthropic models stop
            on "\n\nHuman:", and may include additional built-in stop
            sequences in the future. By providing the stop_sequences
            parameter, you may include additional strings that will cause
            the model to stop generating. (default: :obj:`[]`)
        temperature (float, optional): Amount of randomness injected into
            the response. Ranges from 0 to 1. Use a temperature closer to 0
            for analytical / multiple-choice tasks, and closer to 1 for
            creative and generative tasks. (default: :obj:`1`)
        top_p (float, optional): Use nucleus sampling. In nucleus sampling,
            we compute the cumulative distribution over all the options for
            each subsequent token in decreasing probability order and cut
            it off once it reaches a particular probability specified by
            `top_p`. You should alter either `temperature` or `top_p`,
            but not both. (default: :obj:`0.7`)
        top_k (int, optional): Only sample from the top K options for each
            subsequent token. Used to remove "long tail" low-probability
            responses. (default: :obj:`5`)
        metadata: An object describing metadata about the request.
        stream (bool, optional): Whether to incrementally stream the
            response using server-sent events. (default: :obj:`False`)
    """

    max_tokens: int = 8192
    stop_sequences: ClassVar[Union[List[str], NotGiven]] = []
    temperature: float = 1
    top_p: Union[float, NotGiven] = 0.7
    top_k: Union[int, NotGiven] = 5
    stream: bool = False

    def as_dict(self) -> dict[str, Any]:
        config_dict = super().as_dict()
        if "tools" in config_dict:
            del config_dict["tools"]  # TODO: Support tool calling.
        return config_dict


ANTHROPIC_API_PARAMS = {param for param in AnthropicConfig.model_fields.keys()}
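
A short check of the behavior defined above, hedged to exactly what the code shows: `as_dict` removes the `tools` key, so the resulting payload never advertises tool calling to the Anthropic backend:

from camel.configs import AnthropicConfig

config = AnthropicConfig(temperature=0.2, max_tokens=1024)
payload = config.as_dict()
assert "tools" not in payload  # stripped by AnthropicConfig.as_dict
print(payload["max_tokens"])  # 1024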
camel/configs/base_config.py
ADDED
@@ -0,0 +1,89 @@

# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations

from abc import ABC
from typing import Any, List, Optional

from pydantic import BaseModel, ConfigDict, field_validator


class BaseConfig(ABC, BaseModel):
    r"""Base configuration class for all models.

    This class provides a common interface for all models, ensuring that all
    models have a consistent set of attributes and methods.
    """

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
        extra="forbid",
        frozen=True,
        # UserWarning: conflict with protected namespace "model_"
        protected_namespaces=(),
    )

    tools: Optional[List[Any]] = None
    """A list of tools the model may call. Currently, only functions are
    supported as a tool. Use this to provide a list of functions the model
    may generate JSON inputs for. A max of 128 functions are supported.
    """

    @field_validator("tools", mode="before")
    @classmethod
    def fields_type_checking(cls, tools):
        r"""Validate the type of tools in the configuration.

        This method ensures that the tools provided in the configuration are
        instances of `FunctionTool`. If any tool is not an instance of
        `FunctionTool`, it raises a ValueError.
        """
        if tools is not None:
            from camel.toolkits import FunctionTool

            for tool in tools:
                if not isinstance(tool, FunctionTool):
                    raise ValueError(
                        f"The tool {tool} should "
                        "be an instance of `FunctionTool`."
                    )
        return tools

    def as_dict(self) -> dict[str, Any]:
        r"""Convert the current configuration to a dictionary.

        This method converts the current configuration object to a dictionary
        representation, which can be used for serialization or other purposes.

        Returns:
            dict[str, Any]: A dictionary representation of the current
                configuration.
        """
        config_dict = self.model_dump()

        tools_schema = None
        if self.tools:
            from camel.toolkits import FunctionTool

            tools_schema = []
            for tool in self.tools:
                if not isinstance(tool, FunctionTool):
                    raise ValueError(
                        f"The tool {tool} should "
                        "be an instance of `FunctionTool`."
                    )
                tools_schema.append(tool.get_openai_tool_schema())
        config_dict["tools"] = tools_schema
        return config_dict
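
Because every provider config in this commit inherits from `BaseConfig`, adding a new backend is mostly a matter of declaring fields with defaults. A hypothetical sketch; `ExampleConfig` and its fields are illustrative only, not part of this commit:

from typing import Optional

from camel.configs.base_config import BaseConfig


class ExampleConfig(BaseConfig):
    # Illustrative fields; declare whatever the target API accepts.
    temperature: float = 1.0
    max_tokens: Optional[int] = None


# Derive the accepted-parameter set the same way the real configs do.
EXAMPLE_API_PARAMS = {param for param in ExampleConfig.model_fields.keys()}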
camel/configs/cohere_config.py
ADDED
@@ -0,0 +1,76 @@

# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
from __future__ import annotations

from typing import List, Optional

from camel.configs.base_config import BaseConfig


class CohereConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    Cohere API.

    Args:
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more random,
            while lower values make it more focused and deterministic.
            (default: :obj:`0.2`)
        documents (list, optional): A list of relevant documents that the
            model can cite to generate a more accurate reply. Each document
            is either a string or a document object with content and
            metadata. (default: :obj:`None`)
        max_tokens (int, optional): The maximum number of tokens the model
            will generate as part of the response. (default: :obj:`None`)
        stop_sequences (List[str], optional): A list of up to 5 strings that
            the model will use to stop generation. If the model generates a
            string that matches any of the strings in the list, it will stop
            generating tokens and return the generated text up to that
            point, not including the stop sequence. (default: :obj:`None`)
        seed (int, optional): If specified, the backend will make a best
            effort to sample tokens deterministically, such that repeated
            requests with the same seed and parameters should return the
            same result. However, determinism cannot be totally guaranteed.
            (default: :obj:`None`)
        frequency_penalty (float, optional): Min value of `0.0`, max value
            of `1.0`. Used to reduce repetitiveness of generated tokens. The
            higher the value, the stronger a penalty is applied to
            previously present tokens, proportional to how many times they
            have already appeared in the prompt or prior generation.
            (default: :obj:`0.0`)
        presence_penalty (float, optional): Min value of `0.0`, max value of
            `1.0`. Used to reduce repetitiveness of generated tokens.
            Similar to `frequency_penalty`, except that this penalty is
            applied equally to all tokens that have already appeared,
            regardless of their exact frequencies. (default: :obj:`0.0`)
        k (int, optional): Ensures only the top k most likely tokens are
            considered for generation at each step. Min value of `0`, max
            value of `500`. (default: :obj:`0`)
        p (float, optional): Ensures that only the most likely tokens, with
            total probability mass of `p`, are considered for generation at
            each step. If both k and p are enabled, `p` acts after `k`. Min
            value of `0.01`, max value of `0.99`. (default: :obj:`0.75`)
    """

    temperature: Optional[float] = 0.2
    documents: Optional[list] = None
    max_tokens: Optional[int] = None
    stop_sequences: Optional[List[str]] = None
    seed: Optional[int] = None
    frequency_penalty: Optional[float] = 0.0
    presence_penalty: Optional[float] = 0.0
    k: Optional[int] = 0
    p: Optional[float] = 0.75


COHERE_API_PARAMS = {param for param in CohereConfig.model_fields.keys()}
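
As with the other configs, the parameter set is derived from the declared model fields, so it stays in sync with the class automatically. A quick sketch:

from camel.configs import COHERE_API_PARAMS, CohereConfig

config = CohereConfig(k=40, p=0.9)
assert "frequency_penalty" in COHERE_API_PARAMS
print(config.as_dict()["p"])  # 0.9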
camel/configs/deepseek_config.py
ADDED
@@ -0,0 +1,134 @@

# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

from __future__ import annotations

from typing import Any, Optional, Sequence, Type, Union

from pydantic import BaseModel

from camel.configs.base_config import BaseConfig
from camel.types import NOT_GIVEN, NotGiven


class DeepSeekConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    DeepSeek API.

    Args:
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more random,
            while lower values make it more focused and deterministic.
            (default: :obj:`1.0`)
        top_p (float, optional): Controls the diversity and focus of the
            generated results. Higher values make the output more diverse,
            while lower values make it more focused. (default: :obj:`1.0`)
        response_format (object, optional): Specifies the format of the
            returned content. The available values are `{"type": "text"}` or
            `{"type": "json_object"}`. Setting it to `{"type": "json_object"}`
            will output a standard JSON string.
            (default: :obj:`{"type": "text"}`)
        stream (bool, optional): If set, partial message deltas will be sent.
            Tokens will be sent as data-only server-sent events (SSE) as
            they become available, with the stream terminated by a
            data: [DONE] message. (default: :obj:`False`)
        stop (Union[str, list[str]], optional): Up to 16 sequences where
            the API will stop generating further tokens.
            (default: :obj:`None`)
        max_tokens (int, optional): The maximum number of tokens that can
            be generated in the chat completion. The total length of input
            tokens and generated tokens is limited by the model's context
            length. (default: :obj:`None`)
        presence_penalty (float, optional): Number between -2.0 and 2.0.
            Positive values penalize new tokens based on whether they
            appear in the text so far, increasing the model's likelihood
            to talk about new topics. (default: :obj:`0.0`)
        frequency_penalty (float, optional): Number between -2.0 and 2.0.
            Positive values penalize new tokens based on their existing
            frequency in the text so far, decreasing the model's likelihood
            to repeat the same line verbatim. (default: :obj:`0.0`)
        tools (list[FunctionTool], optional): A list of tools the model may
            call. Currently, only functions are supported as a tool. Use
            this to provide a list of functions the model may generate JSON
            inputs for. A max of 128 functions are supported.
            (default: :obj:`None`)
        tool_choice (Union[dict[str, str], str], optional): Controls which
            (if any) tool is called by the model. "none" means the model
            will not call any tool and instead generates a message. "auto"
            means the model can pick between generating a message or calling
            one or more tools. "required" means the model must call one or
            more tools. Specifying a particular tool via
            {"type": "function", "function": {"name": "my_function"}} forces
            the model to call that tool. "none" is the default when no tools
            are present. "auto" is the default if tools are present.
            (default: :obj:`"auto"`)
        logprobs (bool, optional): Whether to return log probabilities of
            the output tokens or not. If true, returns the log probabilities
            of each output token returned in the content of message.
            (default: :obj:`False`)
        top_logprobs (int, optional): An integer between 0 and 20 specifying
            the number of most likely tokens to return at each token
            position, each with an associated log probability. logprobs
            must be set to true if this parameter is used.
            (default: :obj:`None`)
        include_usage (bool, optional): When streaming, specifies whether to
            include usage information in `stream_options`.
            (default: :obj:`True`)
    """

    temperature: float = 1.0  # deepseek default: 1.0
    top_p: float = 1.0
    stream: bool = False
    stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
    max_tokens: Union[int, NotGiven] = NOT_GIVEN
    presence_penalty: float = 0.0
    response_format: Union[Type[BaseModel], dict, NotGiven] = NOT_GIVEN
    frequency_penalty: float = 0.0
    tool_choice: Optional[Union[dict[str, str], str]] = None
    logprobs: bool = False
    top_logprobs: Optional[int] = None

    def __init__(self, include_usage: bool = True, **kwargs):
        super().__init__(**kwargs)
        # Only set stream_options when stream is True.
        # Otherwise, it will raise an error when calling the API.
        if self.stream:
            self.stream_options = {"include_usage": include_usage}

    def as_dict(self) -> dict[str, Any]:
        r"""Convert the current configuration to a dictionary.

        This method converts the current configuration object to a dictionary
        representation, which can be used for serialization or other purposes.

        Returns:
            dict[str, Any]: A dictionary representation of the current
                configuration.
        """
        config_dict = self.model_dump()
        if self.tools:
            from camel.toolkits import FunctionTool

            tools_schema = []
            for tool in self.tools:
                if not isinstance(tool, FunctionTool):
                    raise ValueError(
                        f"The tool {tool} should "
                        "be an instance of `FunctionTool`."
                    )
                tools_schema.append(tool.get_openai_tool_schema())
        config_dict["tools"] = NOT_GIVEN
        return config_dict


DEEPSEEK_API_PARAMS = {param for param in DeepSeekConfig.model_fields.keys()}
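
Two details from the code above are worth calling out: `stream_options` is attached in `__init__` only when `stream=True`, and `as_dict` validates any `tools` but then replaces them with `NOT_GIVEN`, so tool schemas are never sent to DeepSeek. A minimal sketch of the non-streaming path, using only names defined above:

from camel.configs import DEEPSEEK_API_PARAMS, DeepSeekConfig

config = DeepSeekConfig(temperature=0.7)
payload = config.as_dict()
assert "max_tokens" in DEEPSEEK_API_PARAMS
assert payload["temperature"] == 0.7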
camel/configs/gemini_config.py
ADDED
@@ -0,0 +1,114 @@

# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

from __future__ import annotations

from typing import Any, Optional, Sequence, Type, Union

from pydantic import BaseModel

from camel.configs.base_config import BaseConfig
from camel.types import NOT_GIVEN, NotGiven


class GeminiConfig(BaseConfig):
    r"""Defines the parameters for generating chat completions using the
    Gemini API.

    Args:
        temperature (float, optional): Sampling temperature to use, between
            :obj:`0` and :obj:`2`. Higher values make the output more random,
            while lower values make it more focused and deterministic.
            (default: :obj:`0.2`)
        top_p (float, optional): An alternative to sampling with temperature,
            called nucleus sampling, where the model considers the results of
            the tokens with top_p probability mass. So :obj:`0.1` means only
            the tokens comprising the top 10% probability mass are
            considered. (default: :obj:`1.0`)
        n (int, optional): How many chat completion choices to generate for
            each input message. (default: :obj:`1`)
        response_format (object, optional): An object specifying the format
            that the model must output. Compatible with GPT-4 Turbo and all
            GPT-3.5 Turbo models newer than gpt-3.5-turbo-1106. Setting to
            {"type": "json_object"} enables JSON mode, which guarantees the
            message the model generates is valid JSON. Important: when using
            JSON mode, you must also instruct the model to produce JSON
            yourself via a system or user message. Without this, the model
            may generate an unending stream of whitespace until the
            generation reaches the token limit, resulting in a long-running
            and seemingly "stuck" request. Also note that the message content
            may be partially cut off if finish_reason="length", which
            indicates the generation exceeded max_tokens or the conversation
            exceeded the max context length.
        stream (bool, optional): If True, partial message deltas will be sent
            as data-only server-sent events as they become available.
            (default: :obj:`False`)
        stop (str or list, optional): Up to :obj:`4` sequences where the API
            will stop generating further tokens. (default: :obj:`None`)
        max_tokens (int, optional): The maximum number of tokens to generate
            in the chat completion. The total length of input tokens and
            generated tokens is limited by the model's context length.
            (default: :obj:`None`)
        tools (list[FunctionTool], optional): A list of tools the model may
            call. Currently, only functions are supported as a tool. Use this
            to provide a list of functions the model may generate JSON inputs
            for. A max of 128 functions are supported.
        tool_choice (Union[dict[str, str], str], optional): Controls which
            (if any) tool is called by the model. :obj:`"none"` means the
            model will not call any tool and instead generates a message.
            :obj:`"auto"` means the model can pick between generating a
            message or calling one or more tools. :obj:`"required"` means the
            model must call one or more tools. Specifying a particular tool
            via {"type": "function", "function": {"name": "my_function"}}
            forces the model to call that tool. :obj:`"none"` is the default
            when no tools are present. :obj:`"auto"` is the default if tools
            are present.
    """

    temperature: float = 0.2  # openai default: 1.0
    top_p: float = 1.0
    n: int = 1
    stream: bool = False
    stop: Union[str, Sequence[str], NotGiven] = NOT_GIVEN
    max_tokens: Union[int, NotGiven] = NOT_GIVEN
    response_format: Union[Type[BaseModel], dict, NotGiven] = NOT_GIVEN
    tool_choice: Optional[Union[dict[str, str], str, NotGiven]] = NOT_GIVEN

    def as_dict(self) -> dict[str, Any]:
        r"""Convert the current configuration to a dictionary.

        This method converts the current configuration object to a dictionary
        representation, which can be used for serialization or other purposes.

        Returns:
            dict[str, Any]: A dictionary representation of the current
                configuration.
        """
        config_dict = self.model_dump()
        if self.tools:
            from camel.toolkits import FunctionTool

            tools_schema = []
            for tool in self.tools:
                if not isinstance(tool, FunctionTool):
                    raise ValueError(
                        f"The tool {tool} should "
                        "be an instance of `FunctionTool`."
                    )
                tools_schema.append(tool.get_openai_tool_schema())
        config_dict["tools"] = NOT_GIVEN
        return config_dict


Gemini_API_PARAMS = {param for param in GeminiConfig.model_fields.keys()}
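
Finally, `GeminiConfig` mirrors the OpenAI-style parameters; a minimal sketch of constructing it and inspecting the derived parameter set, using only names defined in this file:

from camel.configs import Gemini_API_PARAMS, GeminiConfig

config = GeminiConfig(temperature=0.2, n=1)
assert "tool_choice" in Gemini_API_PARAMS
print(config.as_dict()["top_p"])  # 1.0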