Commit 7c08dc3 · committed by ZaynZhu · 0 parent(s)

Clean version without large assets
This view is limited to 50 files because it contains too many changes. See the raw diff for the complete change set.
- .gitattributes +40 -0
- .gitignore +18 -0
- Paper2Poster/.gitignore +17 -0
- Paper2Poster/LICENSE +21 -0
- Paper2Poster/Paper2Poster-eval/create_paper_questions.py +40 -0
- Paper2Poster/Paper2Poster-eval/eval_poster_pipeline.py +479 -0
- Paper2Poster/Paper2Poster-eval/eval_qa_fix.py +114 -0
- Paper2Poster/PosterAgent/LLM_direct_generate.py +103 -0
- Paper2Poster/PosterAgent/LLM_direct_generate_beamer.py +189 -0
- Paper2Poster/PosterAgent/__init__.py +16 -0
- Paper2Poster/PosterAgent/apply_theme.py +281 -0
- Paper2Poster/PosterAgent/beamer_pipeline.py +182 -0
- Paper2Poster/PosterAgent/create_dataset.py +69 -0
- Paper2Poster/PosterAgent/deoverflow.py +234 -0
- Paper2Poster/PosterAgent/deoverflow_parallel.py +485 -0
- Paper2Poster/PosterAgent/fill_and_style.py +215 -0
- Paper2Poster/PosterAgent/gen_beamer_code.py +299 -0
- Paper2Poster/PosterAgent/gen_outline_layout.py +851 -0
- Paper2Poster/PosterAgent/gen_outline_layout_parallel.py +949 -0
- Paper2Poster/PosterAgent/gen_poster_content.py +529 -0
- Paper2Poster/PosterAgent/gen_pptx_code.py +249 -0
- Paper2Poster/PosterAgent/new_pipeline.py +547 -0
- Paper2Poster/PosterAgent/parse_raw.py +237 -0
- Paper2Poster/PosterAgent/poster_gen_pipeline.py +101 -0
- Paper2Poster/PosterAgent/tree_split_layout.py +750 -0
- Paper2Poster/README.md +215 -0
- Paper2Poster/__init__.py +3 -0
- Paper2Poster/camel/__init__.py +25 -0
- Paper2Poster/camel/agents/__init__.py +44 -0
- Paper2Poster/camel/agents/base.py +29 -0
- Paper2Poster/camel/agents/chat_agent.py +1539 -0
- Paper2Poster/camel/agents/critic_agent.py +202 -0
- Paper2Poster/camel/agents/deductive_reasoner_agent.py +303 -0
- Paper2Poster/camel/agents/embodied_agent.py +201 -0
- Paper2Poster/camel/agents/knowledge_graph_agent.py +259 -0
- Paper2Poster/camel/agents/multi_hop_generator_agent.py +117 -0
- Paper2Poster/camel/agents/programmed_agent_instruction.py +203 -0
- Paper2Poster/camel/agents/role_assignment_agent.py +141 -0
- Paper2Poster/camel/agents/search_agent.py +133 -0
- Paper2Poster/camel/agents/task_agent.py +410 -0
- Paper2Poster/camel/agents/tool_agents/__init__.py +20 -0
- Paper2Poster/camel/agents/tool_agents/base.py +39 -0
- Paper2Poster/camel/agents/tool_agents/hugging_face_tool_agent.py +206 -0
- Paper2Poster/camel/benchmarks/__init__.py +30 -0
- Paper2Poster/camel/benchmarks/apibank.py +565 -0
- Paper2Poster/camel/benchmarks/apibench.py +500 -0
- Paper2Poster/camel/benchmarks/base.py +152 -0
- Paper2Poster/camel/benchmarks/gaia.py +478 -0
- Paper2Poster/camel/benchmarks/nexus.py +518 -0
- Paper2Poster/camel/benchmarks/ragbench.py +333 -0
.gitattributes
ADDED
@@ -0,0 +1,40 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
*.png filter=lfs diff=lfs merge=lfs -text
*.jpg filter=lfs diff=lfs merge=lfs -text
*.pdf filter=lfs diff=lfs merge=lfs -text
*.wav filter=lfs diff=lfs merge=lfs -text
*.mp4 filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,18 @@
input/
output/Paper2Poster/assets/
Paper2Video/assets/
posterbuilder/latex_proj/figures/
*.png
*.pdf
*.jpg
*.wav
*.mp4
__pycache__/
*.png
*.jpg
*.pdf
*.wav
*.mp4
Paper2Poster/assets/
Paper2Video/assets/
posterbuilder/latex_proj/figures/
Paper2Poster/.gitignore
ADDED
@@ -0,0 +1,17 @@
.env
.vscode/
ablations/
**/__pycache__/
*_generated_posters/
*_images_and_tables/
contents/
tmp/
tree_splits/
eval_results/
Paper2Poster-data/
Example/
*.ipynb
eval_time_detail_parallel/
*.sh
.claude
CLAUDE.md
Paper2Poster/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2025 Paper2Poster

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
Paper2Poster/Paper2Poster-eval/create_paper_questions.py
ADDED
@@ -0,0 +1,40 @@
from utils.poster_eval_utils import *
import argparse
import os
import json

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--paper_folder', type=str, default=None)
    parser.add_argument('--model_name', type=str, default='o3')
    args = parser.parse_args()

    paper_text = get_poster_text(os.path.join(args.paper_folder, 'paper.pdf'))

    if args.model_name == '4o':
        model_type = ModelType.GPT_4O
    elif args.model_name == 'o3':
        model_type = ModelType.O3
    detail_qa = get_questions(paper_text, 'detail', model_type)
    understanding_qa = get_questions(paper_text, 'understanding', model_type)

    detail_q, detail_a, detail_aspects = get_answers_and_remove_answers(detail_qa)
    understanding_q, understanding_a, understanding_aspects = get_answers_and_remove_answers(understanding_qa)

    final_qa = {}
    detail_qa = {
        'questions': detail_q,
        'answers': detail_a,
        'aspects': detail_aspects,
    }

    understanding_qa = {
        'questions': understanding_q,
        'answers': understanding_a,
        'aspects': understanding_aspects,
    }
    final_qa['detail'] = detail_qa
    final_qa['understanding'] = understanding_qa

    with open(os.path.join(args.paper_folder, f'{args.model_name}_qa.json'), 'w') as f:
        json.dump(final_qa, f, indent=4)
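Note: the script writes one QA file per paper next to paper.pdf. A minimal sketch of the resulting <model_name>_qa.json layout, taken from the final_qa dictionary above; the element values are placeholders, and the exact item types depend on get_questions and get_answers_and_remove_answers in utils.poster_eval_utils:

# Hypothetical example of the generated o3_qa.json; only the key layout comes from the code above.
example_final_qa = {
    "detail": {
        "questions": ["<question text>"],   # from get_questions(paper_text, 'detail', model_type)
        "answers": ["<reference answer>"],  # stripped out by get_answers_and_remove_answers
        "aspects": ["<aspect label>"],
    },
    "understanding": {
        "questions": ["<question text>"],
        "answers": ["<reference answer>"],
        "aspects": ["<aspect label>"],
    },
}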
Paper2Poster/Paper2Poster-eval/eval_poster_pipeline.py
ADDED
@@ -0,0 +1,479 @@
from utils.poster_eval_utils import *
import json
from utils.wei_utils import get_agent_config
import argparse
from dotenv import load_dotenv
import tempfile
import shutil
import os
import glob
import re

load_dotenv()

def run_qa_and_update_results(
    args,
    raw_folder,
    gen_poster_path,
    save_path,
    single_model_name=None,
    del_model_name=None,
):
    """
    If single_model_name is provided, run QA for that one model only,
    but update an existing JSON file (which already contains the other
    models' results) and re-compute the overall averages.

    If single_model_name is None, run QA for all models in all_model_names
    and write a new JSON file.

    :param raw_folder: Path to folder with 'o3_qa.json'.
    :param gen_poster_path: Path to the generated poster image.
    :param save_path: Directory where overall_qa_result.json is saved or should be written.
    :param all_model_names: List of model names (e.g. ['vllm_qwen_vl', '4o', 'o3']).
    :param single_model_name: Optional single model name.
    """

    # Load the QA data (questions, answers, aspects)
    qa_dict = json.load(open(os.path.join(raw_folder, 'o3_qa.json'), 'r'))
    detail_qa = qa_dict['detail']
    understanding_qa = qa_dict['understanding']

    # Option A: Single model case
    if single_model_name is not None:
        qa_input_token, qa_output_token = 0, 0
        # Load the existing JSON with all previously computed results
        existing_path = os.path.join(save_path, "overall_qa_result.json")
        with open(existing_path, 'r') as f:
            overall_qa_result = json.load(f)

        if del_model_name is not None:
            # Remove the specified model from the existing results
            if del_model_name in overall_qa_result['qa_result']:
                del overall_qa_result['qa_result'][del_model_name]
                print(f"Removed model {del_model_name} from existing results.")

        if single_model_name in overall_qa_result['qa_result']:
            print(f"Model {single_model_name} already evaluated. Skipping.")
            return

        # Evaluate QA for the single_model_name
        print(f"Running QA for single model: {single_model_name}")
        agent_config = get_agent_config(single_model_name)

        if args.poster_method == 'paper':
            poster_images = open_folder_images(gen_folder, args.paper_name.replace(' ', '_'), format='jpg')
        else:
            poster_images = [Image.open(gen_poster_path)]

        poster_images = [ensure_under_limit_pil(image) for image in poster_images]

        detail_accuracy, detail_aspect_accuracy, detail_agent_answers, input_token, output_token = eval_qa_get_answer(
            poster_input=poster_images,
            questions=detail_qa['questions'],
            answers=detail_qa['answers'],
            aspects=detail_qa['aspects'],
            input_type='image',
            agent_config=agent_config
        )
        qa_input_token += input_token
        qa_output_token += output_token
        print('Detail QA accuracy:', detail_accuracy)

        understanding_accuracy, understanding_aspect_accuracy, understanding_agent_answers, input_token, output_token = eval_qa_get_answer(
            poster_input=poster_images,
            questions=understanding_qa['questions'],
            answers=understanding_qa['answers'],
            aspects=understanding_qa['aspects'],
            input_type='image',
            agent_config=agent_config
        )
        qa_input_token += input_token
        qa_output_token += output_token
        print('Understanding QA accuracy:', understanding_accuracy)

        # Update QA result for this one model
        # overall_qa_result["qa_result"] is assumed to already have the others
        overall_qa_result['qa_result'][single_model_name] = {
            'detail_accuracy': detail_accuracy,
            'detail_aspect_accuracy': detail_aspect_accuracy,
            'detail_agent_answers': detail_agent_answers,
            'understanding_accuracy': understanding_accuracy,
            'understanding_aspect_accuracy': understanding_aspect_accuracy,
            'understanding_agent_answers': understanding_agent_answers
        }

        # Now re-compute the averages across all models present in the JSON
        # Grab all model entries from overall_qa_result['qa_result']
        all_models_in_file = list(overall_qa_result['qa_result'].keys())
        detail_accs = []
        understanding_accs = []
        for m in all_models_in_file:
            detail_accs.append(overall_qa_result['qa_result'][m]['detail_accuracy'])
            understanding_accs.append(overall_qa_result['qa_result'][m]['understanding_accuracy'])

        avg_detail_accuracy = float(np.mean(detail_accs)) if detail_accs else 0.0
        avg_understanding_accuracy = float(np.mean(understanding_accs)) if understanding_accs else 0.0

        overall_qa_result['avg_detail_accuracy'] = avg_detail_accuracy
        overall_qa_result['avg_understanding_accuracy'] = avg_understanding_accuracy

        # Finally, overwrite the same JSON file with the updated results
        with open(existing_path, 'w') as f:
            json.dump(overall_qa_result, f, indent=4)

        print(f'Input tokens: {qa_input_token}')
        print(f'Output tokens: {qa_output_token}')

        print('Updated overall_qa_result.json with single-model results.')
        print('New average detail accuracy:', avg_detail_accuracy)
        print('New average understanding accuracy:', avg_understanding_accuracy)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--paper_name', type=str)
    parser.add_argument('--base_dir', type=str, default='Paper2Poster-data')
    parser.add_argument('--poster_method', type=str)
    parser.add_argument('--poster_image_name', type=str, default='poster.png', choices=['poster.png'])
    parser.add_argument('--metric', type=str, choices=['stats', 'qa', 'judge', 'word_count', 'token_count', 'figure_count', 'aesthetic_judge'], default='stats')
    parser.add_argument('--fix', type=str, default=None)
    parser.add_argument('--del_model_name', type=str, default=None)

    args = parser.parse_args()

    raw_poster_path = f'{args.base_dir}/{args.paper_name}/poster.png'
    raw_folder = f'{args.base_dir}/{args.paper_name}'

    gen_poster_path = f'{args.poster_method}/{args.base_dir}/{args.paper_name}/{args.poster_image_name}'
    gen_folder = f'{args.poster_method}/{args.base_dir}/{args.paper_name}'

    save_path = f'eval_results/{args.paper_name}/{args.poster_method}'
    os.makedirs(save_path, exist_ok=True)

    if args.poster_method == 'paper':
        if args.metric == 'qa' and args.fix is not None:
            overall_qa_result = json.load(open(f'{save_path}/overall_qa_result.json', 'r'))
            if args.fix in overall_qa_result['qa_result']:
                print(f"Model {args.fix} already evaluated. Skipping.")
                exit(0)
        # create a temp folder to store the paper
        # 1) Create a unique temp folder
        temp_dir = tempfile.mkdtemp(prefix="eval_temp", suffix="_data")

        # 2) Build your source directory path, replacing spaces
        paper_slug = args.paper_name.replace(' ', '_')
        source_dir = os.path.join('<4o_vllm_qwen>_images_and_tables', paper_slug)

        # 3) Sequentially copy files named "<paper_slug>-<index>.png"
        index = 1
        while True:
            filename = f"{paper_slug}-{index}.png"
            src_path = os.path.join(source_dir, filename)
            if not os.path.isfile(src_path):
                # stop once the next index is missing
                break
            shutil.copy2(src_path, os.path.join(temp_dir, filename))
            index += 1
            if index > 20 and args.metric != 'word_count' and args.metric != 'token_count':
                break

        gen_folder = temp_dir
        gen_poster_path = f'{args.base_dir}/{args.paper_name}/paper.pdf'


    print('Evaluating poster:', args.paper_name)

    if args.metric == 'stats':
        stats_file = os.path.join(save_path, 'stats_result.json')

        # 1) load existing results if there are any
        if os.path.exists(stats_file):
            with open(stats_file, 'r') as f:
                stats_result = json.load(f)
            print(f"Loaded existing stats from {stats_file}")
        else:
            stats_result = {}

        # 2) CLIP similarity
        if 'CLIP_similarity' not in stats_result:
            _, cos_sim = compare_folders_with_clip(raw_folder, gen_folder)
            stats_result['CLIP_similarity'] = cos_sim
            print(f'CLIP similarity: {cos_sim}')
        else:
            print(f"Skipping CLIP similarity (already {stats_result['CLIP_similarity']})")

        # 3) we only need to regenerate markdown+images if any of the text/image metrics is missing
        need_eval = any(k not in stats_result for k in ('textual_ppl', 'mixtual_ppl', 'visual_relevance', 'visual_ppl'))
        if need_eval:
            images, poster_text, raw_markdown, new_markdown = gen_eval_markdown(
                args.paper_name,
                args.poster_method,
                gen_poster_path
            )

            # textual PPL
            if 'textual_ppl' not in stats_result:
                textual_ppl = get_ppl(poster_text)
                stats_result['textual_ppl'] = textual_ppl
                print(f'Textual PPL: {textual_ppl}')
            else:
                print(f"Skipping textual PPL (already {stats_result['textual_ppl']})")

            # mixtual PPL
            if 'mixtual_ppl' not in stats_result:
                mixtual_ppl = get_ppl(new_markdown)
                stats_result['mixtual_ppl'] = mixtual_ppl
                print(f'Mixtual PPL: {mixtual_ppl}')
            else:
                print(f"Skipping mixtual PPL (already {stats_result['mixtual_ppl']})")

            # visual relevance
            if 'visual_relevance' not in stats_result:
                if images:
                    sims = [
                        compute_cosine_similarity(v['image_clip_embedding'],
                                                  v['section_text_clip_embedding'])
                        for v in images.values()
                    ]
                    avg_sim = float(np.mean(sims))
                    stats_result['visual_relevance'] = avg_sim
                    print(f'Average cosine similarity: {avg_sim}')
                else:
                    stats_result['visual_relevance'] = 0.0
                    print('No images found in the poster. Set visual_relevance to 0.')
            else:
                print(f"Skipping visual relevance (already {stats_result['visual_relevance']})")

            if 'visual_ppl' not in stats_result or math.isnan(stats_result['visual_ppl']):
                visual_ppls = []
                for relative_path, v in images.items():
                    image_path = os.path.join('eval_poster_markdown', args.paper_name, args.poster_method, relative_path)
                    image = Image.open(image_path)
                    visual_ppl = get_visual_ppl(image, poster_text)
                    visual_ppls.append(visual_ppl)
                avg_visual_ppl = float(np.mean(visual_ppls))
                stats_result['visual_ppl'] = avg_visual_ppl
                print(f'Average visual PPL: {avg_visual_ppl}')
        else:
            print("All textual and visual metrics already computed; skipping gen_eval_markdown.")

        if 'interleaved_ppl' not in stats_result:
            interleaved_ppl = compute_interleaved_ppl(args.paper_name, args.poster_method)
            stats_result['interleaved_ppl'] = interleaved_ppl
            print(f'Interleaved PPL: {interleaved_ppl}')
        else:
            print(f"Skipping interleaved PPL (already {stats_result['interleaved_ppl']})")

        if 'poster_image_ppl' not in stats_result:
            if args.poster_method == 'paper':
                poster_images = open_folder_images(gen_folder, args.paper_name.replace(' ', '_'), format='jpg')
            else:
                poster_images = [Image.open(gen_poster_path)]
            poster_image_ppl = compute_poster_image_ppl(poster_images)
            stats_result['poster_image_ppl'] = poster_image_ppl
            print(f'Poster image PPL: {poster_image_ppl}')
        else:
            print(f"Skipping poster image PPL (already {stats_result['poster_image_ppl']})")

        # 4) write back updated file
        with open(stats_file, 'w') as f:
            json.dump(stats_result, f, indent=4)
        print(f"Updated stats written to {stats_file}")
    elif args.metric == 'figure_count':
        save_file_path = os.path.join(save_path, 'figure_count.json')
        if os.path.exists(save_file_path):
            print(f"Figure count already exists at {save_file_path}. Skipping.")
        else:
            figure_count = gen_eval_markdown(
                args.paper_name,
                args.poster_method,
                gen_poster_path,
                figure_count_only=True
            )
            with open(save_file_path, 'w') as f:
                json.dump({'figure_count': figure_count}, f, indent=4)
            print(f"Figure count saved to {save_file_path}")
    elif args.metric == 'qa':
        if args.fix is not None:
            run_qa_and_update_results(
                args,
                raw_folder,
                gen_poster_path,
                save_path,
                single_model_name=args.fix,
                del_model_name=args.del_model_name
            )
        else:
            overall_qa_result = {}
            qa_result = {}
            qa_dict = json.load(open(os.path.join(raw_folder, 'o3_qa.json'), 'r'))
            detail_qa = qa_dict['detail']
            understanding_qa = qa_dict['understanding']
            model_names = [
                '4o',
                'o3',
                '4o-mini'
            ]
            if args.poster_method == 'paper':
                poster_images = open_folder_images(gen_folder, args.paper_name.replace(' ', '_'))
            else:
                poster_images = [Image.open(gen_poster_path)]

            poster_images = [ensure_under_limit_pil(image) for image in poster_images]

            for model_name in model_names:
                qa_input_token, qa_output_token = 0, 0
                print('QA model:', model_name)
                agent_config = get_agent_config(model_name)
                detail_accuracy, detail_aspect_accuracy, detail_agent_answers, input_token, output_token = eval_qa_get_answer(
                    poster_input=poster_images,
                    questions=detail_qa['questions'],
                    answers=detail_qa['answers'],
                    aspects=detail_qa['aspects'],
                    input_type='image',
                    agent_config=agent_config
                )
                print(f'{model_name} Detail QA accuracy:', detail_accuracy)
                qa_input_token += input_token
                qa_output_token += output_token

                understanding_accuracy, understanding_aspect_accuracy, understanding_agent_answers, input_token, output_token = eval_qa_get_answer(
                    poster_input=poster_images,
                    questions=understanding_qa['questions'],
                    answers=understanding_qa['answers'],
                    aspects=understanding_qa['aspects'],
                    input_type='image',
                    agent_config=agent_config
                )
                print(f'{model_name} Understanding QA accuracy:', understanding_accuracy)
                qa_input_token += input_token
                qa_output_token += output_token

                qa_result[model_name] = {
                    'detail_accuracy': detail_accuracy,
                    'detail_aspect_accuracy': detail_aspect_accuracy,
                    'detail_agent_answers': detail_agent_answers,
                    'understanding_accuracy': understanding_accuracy,
                    'understanding_aspect_accuracy': understanding_aspect_accuracy,
                    'understanding_agent_answers': understanding_agent_answers
                }

                print(f'{model_name} Input tokens:', qa_input_token)
                print(f'{model_name} Output tokens:', qa_output_token)

            # average the results
            avg_detail_accuracy = np.mean([qa_result[model_name]['detail_accuracy'] for model_name in model_names])
            avg_understanding_accuracy = np.mean([qa_result[model_name]['understanding_accuracy'] for model_name in model_names])

            print('Average detail accuracy:', avg_detail_accuracy)
            print('Average understanding accuracy:', avg_understanding_accuracy)

            overall_qa_result['avg_detail_accuracy'] = avg_detail_accuracy
            overall_qa_result['avg_understanding_accuracy'] = avg_understanding_accuracy
            overall_qa_result['qa_result'] = qa_result

            with open(f'{save_path}/overall_qa_result.json', 'w') as f:
                json.dump(overall_qa_result, f, indent=4)

    elif args.metric == 'word_count':
        if args.poster_method == 'paper':
            # loop through all images in the folder
            image_paths = open_folder_images(gen_folder, args.paper_name.replace(' ', '_'), return_path=True)
            word_count = 0
            for image_path in image_paths:
                # count words in each image
                word_count += count_words_in_image(image_path)
        else:
            word_count = count_words_in_image(gen_poster_path)
        # save to json
        with open(f'{save_path}/word_count.json', 'w') as f:
            json.dump({'word_count': word_count}, f, indent=4)

    elif args.metric == 'token_count':
        if args.poster_method == 'paper':
            # loop through all images in the folder
            image_paths = open_folder_images(gen_folder, args.paper_name.replace(' ', '_'), return_path=True)
            token_count = 0
            for image_path in image_paths:
                # count tokens in each image
                token_count += count_tokens_in_image(image_path)
        else:
            token_count = count_tokens_in_image(gen_poster_path)
        # save to json
        with open(f'{save_path}/token_count.json', 'w') as f:
            json.dump({'token_count': token_count}, f, indent=4)
    elif args.metric == 'judge':
        agent_config = get_agent_config('4o')

        if args.poster_method == 'paper':
            poster_images = open_folder_images(gen_folder, args.paper_name.replace(' ', '_'))
        else:
            poster_images = [Image.open(gen_poster_path)]

        results = eval_vlm_as_judge(
            poster_image_list=poster_images,
            agent_config=agent_config,
        )

        aesthetic_aspects = [
            'aesthetic_element',
            'aesthetic_engagement',
            'aesthetic_layout'
        ]

        information_aspects = [
            'information_low_level',
            'information_logic',
            'information_content',
        ]

        # compute average scores for all, for aesthetic, and for information
        overall_average = np.mean([results[aspect]['score'] for aspect in results])
        aesthetic_average = np.mean([results[aspect]['score'] for aspect in results if aspect in aesthetic_aspects])
        information_average = np.mean([results[aspect]['score'] for aspect in results if aspect in information_aspects])

        judge_result = {
            'overall_average': overall_average,
            'aesthetic_average': aesthetic_average,
            'information_average': information_average,
            'results': results
        }

        # save to json
        with open(f'{save_path}/judge_result.json', 'w') as f:
            json.dump(judge_result, f, indent=4)
    elif args.metric == 'aesthetic_judge':
        agent_config = get_agent_config('4o')

        if args.poster_method == 'paper':
            poster_images = open_folder_images(gen_folder, args.paper_name.replace(' ', '_'))
        else:
            poster_images = [Image.open(gen_poster_path)]

        results = eval_vlm_as_judge(
            poster_image_list=poster_images,
            agent_config=agent_config,
            aspect='aesthetic'
        )

        aesthetic_aspects = [
            'aesthetic_element',
            'aesthetic_engagement',
            'aesthetic_layout'
        ]

        aesthetic_average = np.mean([results[aspect]['score'] for aspect in results if aspect in aesthetic_aspects])

        judge_result = {
            'aesthetic_average': aesthetic_average,
            'results': results
        }

        # save to json
        with open(f'{save_path}/aesthetic_judge_result.json', 'w') as f:
            json.dump(judge_result, f, indent=4)

    if args.poster_method == 'paper':
        # remove the temp folder
        shutil.rmtree(temp_dir)
        print(f"Removed temporary folder {temp_dir}")
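Note: every metric branch above writes its output under eval_results/<paper_name>/<poster_method>/. A minimal sketch for gathering those files afterwards; the file names and keys come from the code above, while the helper itself and the paths passed to it are hypothetical:

import json
import os

def load_eval_outputs(paper_name, poster_method, base="eval_results"):
    # Hypothetical collector: picks up whichever metric files eval_poster_pipeline.py produced.
    out_dir = os.path.join(base, paper_name, poster_method)
    results = {}
    for fname in ("stats_result.json", "overall_qa_result.json", "judge_result.json",
                  "aesthetic_judge_result.json", "word_count.json", "token_count.json",
                  "figure_count.json"):
        path = os.path.join(out_dir, fname)
        if os.path.exists(path):
            with open(path) as f:
                results[fname] = json.load(f)
    return results

overall_qa_result.json in particular carries avg_detail_accuracy, avg_understanding_accuracy, and a per-model qa_result dictionary, which is what the --fix path updates in place.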
Paper2Poster/Paper2Poster-eval/eval_qa_fix.py
ADDED
@@ -0,0 +1,114 @@
#!/usr/bin/env python3
"""
Run eval_poster_pipeline.py for every sub-folder in poster_sum_100,
using up to 10 threads. poster_method and fix are now taken from
command-line arguments.

Example:
    python run_eval_threads.py \
        --poster_method poster_sum_50 \
        --fix llama-3-70b-vl

"""
import argparse
import concurrent.futures as cf
import pathlib
import signal
import subprocess
import sys

BASE_DIR = pathlib.Path("poster_sum_100")  # directory holding the papers  # number of worker threads

# ── Argument parsing ───────────────────────────────────────────────────────────
parser = argparse.ArgumentParser(
    description="Run eval_poster_pipeline.py concurrently on all papers."
)
parser.add_argument(
    "--poster_method",
    default="poster_sum_100",
    help="Name of the poster-generation method to evaluate (default: %(default)s)",
)
parser.add_argument(
    "--fix",
    default="qwen-2.5-vl-72b",
    help="Value to pass to --fix in eval_poster_pipeline.py (default: %(default)s)",
)

parser.add_argument(
    '--max_workers',
    type=int,
    default=1,
)

parser.add_argument('--del_model_name', type=str)
args = parser.parse_args()
# ───────────────────────────────────────────────────────────────────────────────


MAX_WORKERS = args.max_workers

def run_pipeline(subfolder: str, poster_method: str, fix: str) -> None:
    """Invoke eval_poster_pipeline.py for a single paper."""
    cmd = [
        "python",
        "eval_poster_pipeline.py",
        "--paper_name",
        subfolder,
        "--poster_method",
        poster_method,
        "--poster_image_name",
        "poster.png",
        "--metric",
        "qa",
        "--fix",
        fix,
    ]
    if args.del_model_name:
        cmd += ["--del_model_name", args.del_model_name]
    subprocess.run(cmd, check=True)


MAX_RETRIES = 50

def run_with_retries(folder: str, poster_method, fix) -> None:
    """
    Tries to run_pipeline up to MAX_RETRIES times before giving up.
    """
    for attempt in range(1, MAX_RETRIES + 1):
        try:
            run_pipeline(folder, poster_method, fix)
            return
        except Exception as e:
            if attempt < MAX_RETRIES:
                print(f"⚠️ {folder}: attempt {attempt} failed ({e!r}), retrying…")
            else:
                # Last attempt also failed, re-raise so the pool will catch it
                raise

def main() -> None:
    folders = sorted(p.name for p in BASE_DIR.iterdir() if p.is_dir())

    with cf.ThreadPoolExecutor(max_workers=MAX_WORKERS) as pool:
        futures = {
            pool.submit(run_with_retries, f, args.poster_method, args.fix): f
            for f in folders
        }
        for fut in cf.as_completed(futures):
            paper = futures[fut]
            try:
                fut.result()
                print(f"✓ {paper} done")
            except Exception as e:
                print(f"✗ {paper} failed after {MAX_RETRIES} attempts: {e}", file=sys.stderr)
# ── Graceful shutdown on Ctrl-C / SIGTERM ──────────────────────────────────────
def _handle_signal(signum, frame):
    print("\nReceived signal, shutting down…", file=sys.stderr)
    sys.exit(1)


signal.signal(signal.SIGINT, _handle_signal)
signal.signal(signal.SIGTERM, _handle_signal)

# ── Entry point ────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    main()
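Note: each worker ultimately shells out to the evaluation script once per paper. A sketch of the equivalent single invocation, convenient for checking one paper by hand; the paper name is a placeholder and the other values mirror this script's defaults:

import subprocess

# Hypothetical one-off run matching the cmd list built in run_pipeline above.
subprocess.run(
    ["python", "eval_poster_pipeline.py",
     "--paper_name", "Some_Paper_Folder",   # placeholder folder name under poster_sum_100
     "--poster_method", "poster_sum_100",
     "--poster_image_name", "poster.png",
     "--metric", "qa",
     "--fix", "qwen-2.5-vl-72b"],
    check=True,
)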
Paper2Poster/PosterAgent/LLM_direct_generate.py
ADDED
@@ -0,0 +1,103 @@
from dotenv import load_dotenv
from utils.src.utils import get_json_from_response

from camel.models import ModelFactory
from camel.agents import ChatAgent


from utils.wei_utils import account_token, get_agent_config, html_to_png

from utils.pptx_utils import *
from utils.critic_utils import *
import yaml
import time
from jinja2 import Environment, StrictUndefined
from utils.poster_eval_utils import get_poster_text
import argparse
import json
import os

load_dotenv()

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--paper_path', type=str)
    parser.add_argument('--model_name', type=str, default='4o')

    args = parser.parse_args()

    # get current directory
    current_dir = os.path.dirname(os.path.abspath(__file__))

    meta_dir = args.paper_path.replace('paper.pdf', 'meta.json')
    meta = json.load(open(meta_dir, 'r'))
    poster_width = meta['width']
    poster_height = meta['height']

    output_dir = f"{args.model_name}_HTML/{args.paper_path.replace('paper.pdf', '')}"
    os.makedirs(output_dir, exist_ok=True)

    total_input_token = 0
    total_output_token = 0

    start_time = time.time()
    model_config = get_agent_config(args.model_name)
    model = ModelFactory.create(
        model_platform=model_config['model_platform'],
        model_type=model_config['model_type'],
        model_config_dict=model_config['model_config'],
    )
    paper_text = get_poster_text(args.paper_path)

    actor_agent_name = 'LLM_gen_HTML'

    with open(f'prompt_templates/{actor_agent_name}.yaml', "r") as f:
        content_config = yaml.safe_load(f)
    jinja_env = Environment(undefined=StrictUndefined)
    template = jinja_env.from_string(content_config["template"])

    actor_sys_msg = content_config['system_prompt']
    actor_agent = ChatAgent(
        system_message=actor_sys_msg,
        model=model,
        message_window_size=None
    )

    jinja_args = {
        'document_markdown': paper_text,
        'poster_width': poster_width,
        'poster_height': poster_height,
    }
    prompt = template.render(**jinja_args)

    actor_agent.reset()
    response = actor_agent.step(prompt)
    input_token, output_token = account_token(response)
    total_input_token += input_token
    total_output_token += output_token
    result_json = get_json_from_response(response.msgs[0].content)
    html_str = result_json['HTML']

    # write to poster.html
    with open(f'{output_dir}/poster.html', 'w') as f:
        f.write(html_str)

    html_to_png(
        os.path.join(current_dir, output_dir, 'poster.html'),
        poster_width,
        poster_height,
        os.path.join(current_dir, output_dir, 'poster.png')
    )


    end_time = time.time()
    elapsed_time = end_time - start_time

    log = {
        'input_token': total_input_token,
        'output_token': total_output_token,
        'time_taken': elapsed_time
    }

    with open(f'{output_dir}/log.json', 'w') as f:
        json.dump(log, f, indent=4)
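Note: the script expects a meta.json beside paper.pdf carrying the target canvas size, and a prompt_templates/LLM_gen_HTML.yaml providing system_prompt and template keys. A minimal sketch of a matching meta.json; the pixel values and the path are made up:

import json

# Hypothetical meta.json; only the 'width' and 'height' keys are read by the script above.
meta = {"width": 3840, "height": 2160}
with open("Paper2Poster-data/Some_Paper/meta.json", "w") as f:  # placeholder path
    json.dump(meta, f, indent=4)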
Paper2Poster/PosterAgent/LLM_direct_generate_beamer.py
ADDED
@@ -0,0 +1,189 @@
import os
import json
import time
from dotenv import load_dotenv
from jinja2 import Environment, StrictUndefined

from utils.src.utils import get_json_from_response, account_token, html_to_png
from utils.config_utils import load_poster_yaml_config

from camel.models import ModelFactory
from camel.agents import ChatAgent
from camel.configs import ChatGPTConfig
from camel.types import ModelPlatformType, ModelType

load_dotenv()

def gen_beamer_poster_direct(
    paper_text: str,
    poster_width_cm: float = 120,
    poster_height_cm: float = 90,
    beamer_theme: str = "default",
    output_dir: str = "output",
    model_name: str = "4o"
):
    """
    Generate Beamer poster directly from paper text using LLM.

    Args:
        paper_text: Extracted text from the paper
        poster_width_cm: Poster width in centimeters
        poster_height_cm: Poster height in centimeters
        beamer_theme: Beamer theme name
        output_dir: Output directory
        model_name: Model name for generation
    """
    start_time = time.time()
    total_input_token, total_output_token = 0, 0

    # Load configuration
    config_path = "utils/prompt_templates/LLM_gen_Beamer.yaml"
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)

    # Create model and agent
    actor_model = ModelFactory.create(
        model_platform=ModelPlatformType.OPENAI,
        model_type=ModelType.GPT_4O,
        model_config_dict=ChatGPTConfig().as_dict(),
    )

    actor_agent = ChatAgent(
        system_message=config['system_prompt'],
        model=actor_model,
        message_window_size=None
    )

    # Prepare template arguments
    jinja_args = {
        'document_markdown': paper_text,
        'poster_width_cm': poster_width_cm,
        'poster_height_cm': poster_height_cm,
        'beamer_theme': beamer_theme,
        'aspect_ratio': "169",
        'title_color': "[47, 85, 151]",
        'text_color': "[0, 0, 0]"
    }

    # Render template
    jinja_env = Environment(undefined=StrictUndefined)
    template = jinja_env.from_string(config["template"])
    prompt = template.render(**jinja_args)

    # Generate Beamer code
    actor_agent.reset()
    response = actor_agent.step(prompt)
    input_token, output_token = account_token(response)
    total_input_token += input_token
    total_output_token += output_token

    # Extract LaTeX code
    result_json = get_json_from_response(response.msgs[0].content)
    latex_str = result_json['LATEX']

    # Save LaTeX file
    os.makedirs(output_dir, exist_ok=True)
    tex_path = os.path.join(output_dir, 'poster.tex')
    with open(tex_path, 'w', encoding='utf-8') as f:
        f.write(latex_str)

    # Compile to PDF
    print("Compiling LaTeX to PDF...")
    success = compile_beamer_to_pdf(tex_path, output_dir)

    if success:
        print(f"✅ Beamer poster generated successfully: {tex_path}")
    else:
        print("❌ Failed to compile LaTeX to PDF")

    # Save log
    end_time = time.time()
    elapsed_time = end_time - start_time

    log = {
        'input_token': total_input_token,
        'output_token': total_output_token,
        'time_taken': elapsed_time,
        'output_format': 'beamer',
        'beamer_theme': beamer_theme
    }

    with open(os.path.join(output_dir, 'log.json'), 'w') as f:
        json.dump(log, f, indent=4)

    return tex_path, success

def compile_beamer_to_pdf(tex_path: str, output_dir: str = "."):
    """
    Compile Beamer .tex file to PDF using pdflatex.

    Args:
        tex_path: Path to .tex file
        output_dir: Output directory for PDF
    """
    import subprocess

    try:
        # Run pdflatex twice for proper cross-references
        result1 = subprocess.run(
            ['pdflatex', '-output-directory', output_dir, tex_path],
            capture_output=True,
            text=True,
            timeout=60
        )

        result2 = subprocess.run(
            ['pdflatex', '-output-directory', output_dir, tex_path],
            capture_output=True,
            text=True,
            timeout=60
        )

        if result1.returncode == 0 and result2.returncode == 0:
            print(f"Successfully compiled {tex_path} to PDF")
            return True
        else:
            print(f"Error compiling {tex_path}:")
            print(result1.stderr)
            print(result2.stderr)
            return False

    except subprocess.TimeoutExpired:
        print(f"Timeout while compiling {tex_path}")
        return False
    except Exception as e:
        print(f"Error compiling {tex_path}: {e}")
        return False

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description='Generate Beamer poster directly from paper')
    parser.add_argument('--paper_path', required=True, help='Path to paper PDF')
    parser.add_argument('--output_dir', default='beamer_output', help='Output directory')
    parser.add_argument('--poster_width_cm', type=float, default=120, help='Poster width in cm')
    parser.add_argument('--poster_height_cm', type=float, default=90, help='Poster height in cm')
    parser.add_argument('--beamer_theme', default='default', help='Beamer theme')
    parser.add_argument('--model_name', default='4o', help='Model name')

    args = parser.parse_args()

    # Extract text from paper (you'll need to implement this)
    # For now, using placeholder text
    paper_text = "This is placeholder text. In practice, you would extract text from the PDF."

    # Generate Beamer poster
    tex_path, success = gen_beamer_poster_direct(
        paper_text=paper_text,
        poster_width_cm=args.poster_width_cm,
        poster_height_cm=args.poster_height_cm,
        beamer_theme=args.beamer_theme,
        output_dir=args.output_dir,
        model_name=args.model_name
    )

    if success:
        print(f"Beamer poster generated at: {tex_path}")
    else:
        print("Failed to generate Beamer poster")
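Note: the __main__ block above still feeds placeholder text into the generator. A sketch of wiring in real paper text, under the assumption that the same extractor used by the HTML variant applies here (get_poster_text comes from utils.poster_eval_utils, as imported in LLM_direct_generate.py; the paper path is a placeholder):

from utils.poster_eval_utils import get_poster_text
from PosterAgent.LLM_direct_generate_beamer import gen_beamer_poster_direct

# Hypothetical usage: real extracted text instead of the placeholder string.
paper_text = get_poster_text("Paper2Poster-data/Some_Paper/paper.pdf")  # placeholder path
tex_path, success = gen_beamer_poster_direct(
    paper_text=paper_text,
    poster_width_cm=120,
    poster_height_cm=90,
    beamer_theme="default",
    output_dir="beamer_output",
)

As committed, gen_beamer_poster_direct calls yaml.safe_load without importing yaml, so an import yaml needs to be added to that module before this sketch will run.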
Paper2Poster/PosterAgent/__init__.py
ADDED
@@ -0,0 +1,16 @@
from . import (
    apply_theme,
    create_dataset,
    deoverflow,
    deoverflow_parallel,
    fill_and_style,
    gen_outline_layout_parallel,
    gen_outline_layout,
    gen_poster_content,
    gen_pptx_code,
    LLM_direct_generate,
    new_pipeline,
    parse_raw,
    poster_gen_pipeline,
    tree_split_layout
)
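Note: with this __init__.py in place the submodules load through the package. A minimal sketch, assuming the Paper2Poster directory is on sys.path:

# Hypothetical usage; importing the package triggers the submodule imports listed above.
from PosterAgent import new_pipeline, tree_split_layout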
Paper2Poster/PosterAgent/apply_theme.py
ADDED
@@ -0,0 +1,281 @@
from dotenv import load_dotenv
from utils.src.utils import ppt_to_images, get_json_from_response
import json
import shutil

from camel.models import ModelFactory
from camel.agents import ChatAgent

from utils.wei_utils import *

from camel.messages import BaseMessage
from PIL import Image
import pickle as pkl
from utils.pptx_utils import *
from utils.critic_utils import *
import yaml
from jinja2 import Environment, StrictUndefined
from pdf2image import convert_from_path
import argparse

load_dotenv()

def poster_apply_theme(args, actor_config, critic_config):
    total_input_token, total_output_token = 0, 0
    extract_input_token, extract_output_token = 0, 0
    gen_input_token, gen_output_token = 0, 0
    non_overlap_ckpt = pkl.load(open(f'checkpoints/{args.model_name}_{args.poster_name}_non_overlap_ckpt_{args.index}.pkl', 'rb'))
    non_overlap_code = non_overlap_ckpt['final_code_by_section']
    sections = list(non_overlap_code.keys())
    sections = [s for s in sections if s != 'meta']
    template_img = convert_from_path(args.template_path)[0]
    image_bytes = io.BytesIO()
    template_img.save(image_bytes, format="PNG")
    image_bytes.seek(0)

    # Reload the image from memory as a standard PIL.Image.Image
    template_img = Image.open(image_bytes)


    title_actor_agent_name = 'theme_agent_title'
    with open(f"prompt_templates/{title_actor_agent_name}.yaml", "r") as f:
        title_theme_actor_config = yaml.safe_load(f)

    section_actor_agent_name = 'theme_agent_section'
    with open(f"prompt_templates/{section_actor_agent_name}.yaml", "r") as f:
        section_theme_actor_config = yaml.safe_load(f)

    title_actor_model = ModelFactory.create(
        model_platform=actor_config['model_platform'],
        model_type=actor_config['model_type'],
        model_config_dict=actor_config['model_config'],  # [Optional] the config for model
    )

    title_actor_sys_msg = title_theme_actor_config['system_prompt']

    title_actor_agent = ChatAgent(
        system_message=title_actor_sys_msg,
        model=title_actor_model,
        message_window_size=10,  # [Optional] the length for chat memory
    )

    section_actor_model = ModelFactory.create(
        model_platform=actor_config['model_platform'],
        model_type=actor_config['model_type'],
        model_config_dict=actor_config['model_config'],  # [Optional] the config for model
    )

    section_actor_sys_msg = section_theme_actor_config['system_prompt']

    section_actor_agent = ChatAgent(
        system_message=section_actor_sys_msg,
        model=section_actor_model,
        message_window_size=10,  # [Optional] the length for chat memory
    )

    critic_model = ModelFactory.create(
        model_platform=critic_config['model_platform'],
        model_type=critic_config['model_type'],
        model_config_dict=critic_config['model_config'],
    )

    critic_sys_msg = 'You are a helpful assistant.'

    critic_agent = ChatAgent(
        system_message=critic_sys_msg,
        model=critic_model,
        message_window_size=None,
    )

    theme_aspects = {
        'background': ['background'],
        'title': ['title_author', 'title_author_border'],
        'section': ['section_body', 'section_title', 'section_border']
    }

    theme_styles = {}
    for aspect in theme_aspects.keys():
        theme_styles[aspect] = {}

    for aspect, prompt_types in theme_aspects.items():
        for prompt_type in prompt_types:
            print(f'Getting style for {prompt_type}')
            with open(f"prompt_templates/theme_templates/theme_{prompt_type}.txt", "r") as f:
                prompt = f.read()
            msg = BaseMessage.make_user_message(
                role_name="User",
                content=prompt,
                image_list=[template_img],
            )

            critic_agent.reset()
            response = critic_agent.step(msg)
            input_token, output_token = account_token(response)
            total_input_token += input_token
            total_output_token += output_token
            extract_input_token += input_token
            extract_output_token += output_token
            theme_style = get_json_from_response(response.msgs[0].content)
            theme_styles[aspect][prompt_type] = theme_style

    if 'fontStyle' in theme_styles['section']['section_body']:
        del theme_styles['section']['section_body']['fontStyle']

    outline_path = f'outlines/{args.model_name}_{args.poster_name}_outline_{args.index}.json'
    outline = json.load(open(outline_path, 'r'))
    outline_skeleton = {}
|
| 127 |
+
for key, val in outline.items():
|
| 128 |
+
if key == 'meta':
|
| 129 |
+
continue
|
| 130 |
+
if not 'subsections' in val:
|
| 131 |
+
outline_skeleton[key] = {
|
| 132 |
+
'section': key
|
| 133 |
+
}
|
| 134 |
+
else:
|
| 135 |
+
for subsection_name, subsection_dict in val['subsections'].items():
|
| 136 |
+
outline_skeleton[subsection_dict['name']] = {
|
| 137 |
+
'section': key
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
for key in outline_skeleton.keys():
|
| 141 |
+
if 'title' in key.lower() or 'author' in key.lower():
|
| 142 |
+
outline_skeleton[key]['style'] = theme_styles['section']['section_title']
|
| 143 |
+
else:
|
| 144 |
+
outline_skeleton[key]['style'] = theme_styles['section']['section_body']
|
| 145 |
+
|
| 146 |
+
outline_skeleton_list = []
|
| 147 |
+
for section in sections[1:]:
|
| 148 |
+
# append all subsections whose section key is the current section
|
| 149 |
+
for key, val in outline_skeleton.items():
|
| 150 |
+
if val['section'] == section:
|
| 151 |
+
outline_skeleton_list.append({key: val})
|
| 152 |
+
|
| 153 |
+
theme_logs = {}
|
| 154 |
+
theme_code = {}
|
| 155 |
+
concatenated_code = {}
|
| 156 |
+
|
| 157 |
+
# Title
|
| 158 |
+
jinja_env = Environment(undefined=StrictUndefined)
|
| 159 |
+
|
| 160 |
+
title_actor_template = jinja_env.from_string(title_theme_actor_config["template"])
|
| 161 |
+
|
| 162 |
+
# Title section
|
| 163 |
+
print(f'Processing section {sections[0]}')
|
| 164 |
+
curr_title_code = non_overlap_code[sections[0]]
|
| 165 |
+
for style in ['background', 'title']:
|
| 166 |
+
for sub_style in theme_styles[style].keys():
|
| 167 |
+
print(f' Applying theme for {sub_style}')
|
| 168 |
+
jinja_args = {
|
| 169 |
+
'style_json': {sub_style: theme_styles[style][sub_style]},
|
| 170 |
+
'function_docs': documentation,
|
| 171 |
+
'existing_code': curr_title_code
|
| 172 |
+
}
|
| 173 |
+
actor_prompt = title_actor_template.render(**jinja_args)
|
| 174 |
+
log = apply_theme(title_actor_agent, actor_prompt, args.max_retry, existing_code='')
|
| 175 |
+
if log[-1]['error'] is not None:
|
| 176 |
+
raise Exception(log[-1]['error'])
|
| 177 |
+
|
| 178 |
+
input_token, output_token = log[-1]['cumulative_tokens']
|
| 179 |
+
total_input_token += input_token
|
| 180 |
+
total_output_token += output_token
|
| 181 |
+
gen_input_token += input_token
|
| 182 |
+
gen_output_token += output_token
|
| 183 |
+
|
| 184 |
+
shutil.copy('poster.pptx', f'tmp/theme_poster_<{sections[0]}>_<{style}>_<{sub_style}>.pptx')
|
| 185 |
+
|
| 186 |
+
if not style in theme_logs:
|
| 187 |
+
theme_logs[style] = {}
|
| 188 |
+
|
| 189 |
+
theme_logs[style][sub_style] = log
|
| 190 |
+
curr_title_code = log[-1]['code']
|
| 191 |
+
|
| 192 |
+
theme_code[sections[0]] = curr_title_code
|
| 193 |
+
concatenated_code[sections[0]] = log[-1]['concatenated_code']
|
| 194 |
+
|
| 195 |
+
# Remaining sections
|
| 196 |
+
|
| 197 |
+
jinja_env = Environment(undefined=StrictUndefined)
|
| 198 |
+
|
| 199 |
+
section_actor_template = jinja_env.from_string(section_theme_actor_config["template"])
|
| 200 |
+
|
| 201 |
+
prev_section = None
|
| 202 |
+
for style_dict in outline_skeleton_list:
|
| 203 |
+
curr_subsection = list(style_dict.keys())[0]
|
| 204 |
+
curr_section = style_dict[curr_subsection]['section']
|
| 205 |
+
section_index = sections.index(curr_section)
|
| 206 |
+
print(f'Processing section {curr_section}')
|
| 207 |
+
if prev_section != curr_section:
|
| 208 |
+
prev_section = curr_section
|
| 209 |
+
curr_section_code = non_overlap_code[curr_section]
|
| 210 |
+
print(f' Applying theme for {curr_subsection}')
|
| 211 |
+
jinja_args = {
|
| 212 |
+
'style_json': json.dumps({curr_subsection: style_dict[curr_subsection]['style']}, indent=4),
|
| 213 |
+
'function_docs': documentation,
|
| 214 |
+
'existing_code': curr_section_code
|
| 215 |
+
}
|
| 216 |
+
actor_prompt = section_actor_template.render(**jinja_args)
|
| 217 |
+
existing_code = concatenated_code[sections[section_index - 1]]
|
| 218 |
+
log = apply_theme(section_actor_agent, actor_prompt, args.max_retry, existing_code=existing_code)
|
| 219 |
+
if log[-1]['error'] is not None:
|
| 220 |
+
raise Exception(log[-1]['error'])
|
| 221 |
+
|
| 222 |
+
input_token, output_token = log[-1]['cumulative_tokens']
|
| 223 |
+
total_input_token += input_token
|
| 224 |
+
total_output_token += output_token
|
| 225 |
+
gen_input_token += input_token
|
| 226 |
+
gen_output_token += output_token
|
| 227 |
+
|
| 228 |
+
shutil.copy('poster.pptx', f'tmp/theme_poster_<{curr_section}>_<{curr_subsection}>.pptx')
|
| 229 |
+
|
| 230 |
+
if not style in theme_logs:
|
| 231 |
+
theme_logs[style] = {}
|
| 232 |
+
|
| 233 |
+
theme_logs[style][sub_style] = log
|
| 234 |
+
curr_section_code = log[-1]['code']
|
| 235 |
+
|
| 236 |
+
theme_code[curr_section] = curr_section_code
|
| 237 |
+
concatenated_code[curr_section] = log[-1]['concatenated_code']
|
| 238 |
+
|
| 239 |
+
ppt_to_images(f'poster.pptx', 'tmp/theme_preview')
|
| 240 |
+
|
| 241 |
+
result_dir = f'results/{args.poster_name}/{args.model_name}/{args.index}'
|
| 242 |
+
shutil.copy('poster.pptx', f'{result_dir}/theme_poster.pptx')
|
| 243 |
+
ppt_to_images(f'poster.pptx', f'{result_dir}/theme_poster_preview')
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
ckpt = {
|
| 247 |
+
'theme_styles': theme_styles,
|
| 248 |
+
'theme_logs': theme_logs,
|
| 249 |
+
'theme_code': theme_code,
|
| 250 |
+
'concatenated_code': concatenated_code,
|
| 251 |
+
'total_input_token': total_input_token,
|
| 252 |
+
'total_output_token': total_output_token,
|
| 253 |
+
'extract_input_token': extract_input_token,
|
| 254 |
+
'extract_output_token': extract_output_token,
|
| 255 |
+
'gen_input_token': gen_input_token,
|
| 256 |
+
'gen_output_token': gen_output_token
|
| 257 |
+
}
|
| 258 |
+
|
| 259 |
+
pkl.dump(ckpt, open(f'checkpoints/{args.model_name}_{args.poster_name}_theme_ckpt.pkl', 'wb'))
|
| 260 |
+
|
| 261 |
+
return total_input_token, total_output_token
|
| 262 |
+
|
| 263 |
+
if __name__ == '__main__':
|
| 264 |
+
parser = argparse.ArgumentParser()
|
| 265 |
+
parser.add_argument('--poster_name', type=str, default=None)
|
| 266 |
+
parser.add_argument('--model_name', type=str, default='4o')
|
| 267 |
+
parser.add_argument('--poster_path', type=str, required=True)
|
| 268 |
+
parser.add_argument('--index', type=int, default=0)
|
| 269 |
+
parser.add_argument('--template_path', type=str)
|
| 270 |
+
parser.add_argument('--max_retry', type=int, default=3)
|
| 271 |
+
args = parser.parse_args()
|
| 272 |
+
|
| 273 |
+
actor_config = get_agent_config(args.model_name)
|
| 274 |
+
critic_config = get_agent_config(args.model_name)
|
| 275 |
+
|
| 276 |
+
if args.poster_name is None:
|
| 277 |
+
args.poster_name = args.poster_path.split('/')[-1].replace('.pdf', '').replace(' ', '_')
|
| 278 |
+
|
| 279 |
+
input_token, output_token = poster_apply_theme(args, actor_config, critic_config)
|
| 280 |
+
|
| 281 |
+
print(f'Token consumption: {input_token} -> {output_token}')
|
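The __main__ block above drives poster_apply_theme through a small CLI. A minimal programmatic driver, sketched under the assumption that the de-overflow checkpoint and the outline JSON from the earlier pipeline stages already exist under checkpoints/ and outlines/; the paper path, template path, and model key below are placeholders, not values taken from the repository.

# Hypothetical driver for poster_apply_theme(); mirrors the __main__ block above.
from argparse import Namespace

from PosterAgent.apply_theme import poster_apply_theme
from utils.wei_utils import get_agent_config

args = Namespace(
    poster_name=None,
    model_name='4o',                                          # placeholder model key
    poster_path='Paper2Poster-data/ExamplePaper/paper.pdf',   # placeholder paper path
    template_path='templates/reference_poster.pdf',           # placeholder template path
    index=0,
    max_retry=3,
)
if args.poster_name is None:
    # same normalization as the CLI entry point above
    args.poster_name = args.poster_path.split('/')[-1].replace('.pdf', '').replace(' ', '_')

actor_config = get_agent_config(args.model_name)
critic_config = get_agent_config(args.model_name)
input_token, output_token = poster_apply_theme(args, actor_config, critic_config)
print(f'Token consumption: {input_token} -> {output_token}')
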
Paper2Poster/PosterAgent/beamer_pipeline.py
ADDED
|
@@ -0,0 +1,182 @@
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import argparse
|
| 4 |
+
from typing import Dict, Any, List
|
| 5 |
+
|
| 6 |
+
# Import existing modules
|
| 7 |
+
from PosterAgent.gen_beamer_code import (
|
| 8 |
+
generate_beamer_poster_code,
|
| 9 |
+
convert_pptx_layout_to_beamer,
|
| 10 |
+
save_beamer_code,
|
| 11 |
+
compile_beamer_to_pdf
|
| 12 |
+
)
|
| 13 |
+
from PosterAgent.gen_pptx_code import generate_poster_code
|
| 14 |
+
from utils.wei_utils import run_code
|
| 15 |
+
from utils.theme_utils import get_default_theme, create_theme_with_alignment
|
| 16 |
+
|
| 17 |
+
def generate_beamer_poster(
|
| 18 |
+
panel_arrangement_inches: List[Dict[str, Any]],
|
| 19 |
+
text_arrangement_inches: List[Dict[str, Any]],
|
| 20 |
+
figure_arrangement_inches: List[Dict[str, Any]],
|
| 21 |
+
bullet_content: List[Dict[str, Any]],
|
| 22 |
+
poster_info: Dict[str, str],
|
| 23 |
+
args,
|
| 24 |
+
width_cm: float = 120,
|
| 25 |
+
height_cm: float = 90,
|
| 26 |
+
theme: str = "default"
|
| 27 |
+
):
|
| 28 |
+
"""
|
| 29 |
+
Generate Beamer poster instead of PowerPoint.
|
| 30 |
+
|
| 31 |
+
Args:
|
| 32 |
+
panel_arrangement_inches: Panel layout data
|
| 33 |
+
text_arrangement_inches: Text layout data
|
| 34 |
+
figure_arrangement_inches: Figure layout data
|
| 35 |
+
bullet_content: Content for text boxes
|
| 36 |
+
poster_info: Poster metadata (title, author, institute)
|
| 37 |
+
args: Command line arguments
|
| 38 |
+
width_cm: Poster width in centimeters
|
| 39 |
+
height_cm: Poster height in centimeters
|
| 40 |
+
theme: Beamer theme name
|
| 41 |
+
"""
|
| 42 |
+
print("\n🎯 Generating Beamer poster code...", flush=True)
|
| 43 |
+
|
| 44 |
+
# Convert layout data to Beamer format
|
| 45 |
+
beamer_data = convert_pptx_layout_to_beamer({
|
| 46 |
+
'text_arrangement': text_arrangement_inches,
|
| 47 |
+
'figure_arrangement': figure_arrangement_inches
|
| 48 |
+
})
|
| 49 |
+
|
| 50 |
+
# Update poster info
|
| 51 |
+
beamer_data['poster_info'].update(poster_info)
|
| 52 |
+
|
| 53 |
+
# Generate Beamer code
|
| 54 |
+
beamer_code = generate_beamer_poster_code(
|
| 55 |
+
sections=beamer_data['sections'],
|
| 56 |
+
figures=beamer_data['figures'],
|
| 57 |
+
poster_info=beamer_data['poster_info'],
|
| 58 |
+
width_cm=width_cm,
|
| 59 |
+
height_cm=height_cm,
|
| 60 |
+
theme=theme
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
# Save Beamer code
|
| 64 |
+
tex_path = f'{args.tmp_dir}/poster.tex'
|
| 65 |
+
save_beamer_code(beamer_code, tex_path)
|
| 66 |
+
|
| 67 |
+
# Compile to PDF
|
| 68 |
+
print("\n📄 Compiling Beamer to PDF...", flush=True)
|
| 69 |
+
success = compile_beamer_to_pdf(tex_path, args.tmp_dir)
|
| 70 |
+
|
| 71 |
+
if not success:
|
| 72 |
+
raise RuntimeError('Error in compiling Beamer to PDF')
|
| 73 |
+
|
| 74 |
+
print(f"✅ Beamer poster generated successfully: {tex_path}")
|
| 75 |
+
return tex_path
|
| 76 |
+
|
| 77 |
+
def modify_new_pipeline_for_beamer(args):
|
| 78 |
+
"""
|
| 79 |
+
Modified version of new_pipeline.py to support Beamer output.
|
| 80 |
+
This function replaces the PowerPoint generation part with Beamer generation.
|
| 81 |
+
"""
|
| 82 |
+
# Import the original pipeline components
|
| 83 |
+
from PosterAgent.new_pipeline import (
|
| 84 |
+
parse_paper_content,
|
| 85 |
+
gen_outline_layout_parallel,
|
| 86 |
+
gen_poster_content,
|
| 87 |
+
deoverflow_parallel,
|
| 88 |
+
apply_theme
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
# ... (keep all the existing pipeline steps until poster generation)
|
| 92 |
+
|
| 93 |
+
# At the poster generation step, replace PowerPoint with Beamer:
|
| 94 |
+
|
| 95 |
+
# === Beamer Poster Generation ===
|
| 96 |
+
print("\n🎯 Generating Beamer poster...", flush=True)
|
| 97 |
+
|
| 98 |
+
# Extract poster information from content
|
| 99 |
+
poster_info = {
|
| 100 |
+
'title': 'Research Poster Title', # Extract from paper content
|
| 101 |
+
'author': 'Author Name', # Extract from paper content
|
| 102 |
+
'institute': 'Institute Name' # Extract from paper content
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
# Convert inches to centimeters (1 inch = 2.54 cm)
|
| 106 |
+
width_cm = args.poster_width_inches * 2.54
|
| 107 |
+
height_cm = args.poster_height_inches * 2.54
|
| 108 |
+
|
| 109 |
+
# Generate Beamer poster
|
| 110 |
+
tex_path = generate_beamer_poster(
|
| 111 |
+
panel_arrangement_inches=panel_arrangement_inches,
|
| 112 |
+
text_arrangement_inches=text_arrangement_inches,
|
| 113 |
+
figure_arrangement_inches=figure_arrangement_inches,
|
| 114 |
+
bullet_content=bullet_content,
|
| 115 |
+
poster_info=poster_info,
|
| 116 |
+
args=args,
|
| 117 |
+
width_cm=width_cm,
|
| 118 |
+
height_cm=height_cm,
|
| 119 |
+
theme=getattr(args, 'beamer_theme', 'default')
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
# Copy output to final directory
|
| 123 |
+
output_dir = f'<{args.model_name_t}_{args.model_name_v}>_generated_posters/{args.poster_path.replace("paper.pdf", "")}'
|
| 124 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 125 |
+
|
| 126 |
+
# Copy generated files
|
| 127 |
+
import shutil
|
| 128 |
+
shutil.copy(tex_path, f'{output_dir}/poster.tex')
|
| 129 |
+
shutil.copy(f'{args.tmp_dir}/poster.pdf', f'{output_dir}/poster.pdf')
|
| 130 |
+
|
| 131 |
+
print(f"✅ Beamer poster saved to: {output_dir}")
|
| 132 |
+
return output_dir
|
| 133 |
+
|
| 134 |
+
def add_beamer_arguments(parser):
|
| 135 |
+
"""Add Beamer-specific command line arguments."""
|
| 136 |
+
parser.add_argument(
|
| 137 |
+
'--output_format',
|
| 138 |
+
choices=['pptx', 'beamer'],
|
| 139 |
+
default='pptx',
|
| 140 |
+
help='Output format: pptx (PowerPoint) or beamer (LaTeX)'
|
| 141 |
+
)
|
| 142 |
+
parser.add_argument(
|
| 143 |
+
'--beamer_theme',
|
| 144 |
+
default='default',
|
| 145 |
+
help='Beamer theme name (default, Madrid, Warsaw, etc.)'
|
| 146 |
+
)
|
| 147 |
+
parser.add_argument(
|
| 148 |
+
'--beamer_width_cm',
|
| 149 |
+
type=float,
|
| 150 |
+
default=120,
|
| 151 |
+
help='Beamer poster width in centimeters'
|
| 152 |
+
)
|
| 153 |
+
parser.add_argument(
|
| 154 |
+
'--beamer_height_cm',
|
| 155 |
+
type=float,
|
| 156 |
+
default=90,
|
| 157 |
+
help='Beamer poster height in centimeters'
|
| 158 |
+
)
|
| 159 |
+
return parser
|
| 160 |
+
|
| 161 |
+
# Example integration with existing pipeline
|
| 162 |
+
def integrate_beamer_with_existing_pipeline():
|
| 163 |
+
"""
|
| 164 |
+
Example of how to integrate Beamer generation with the existing pipeline.
|
| 165 |
+
"""
|
| 166 |
+
# This would be added to the main pipeline function
|
| 167 |
+
pass
|
| 168 |
+
|
| 169 |
+
if __name__ == "__main__":
|
| 170 |
+
parser = argparse.ArgumentParser(description='Generate Beamer poster from paper')
|
| 171 |
+
parser = add_beamer_arguments(parser)
|
| 172 |
+
|
| 173 |
+
# Add other existing arguments...
|
| 174 |
+
|
| 175 |
+
args = parser.parse_args()
|
| 176 |
+
|
| 177 |
+
if args.output_format == 'beamer':
|
| 178 |
+
modify_new_pipeline_for_beamer(args)
|
| 179 |
+
else:
|
| 180 |
+
# Use original PowerPoint pipeline
|
| 181 |
+
pass
|
| 182 |
+
|
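add_beamer_arguments above only attaches the four Beamer flags, so it can be layered onto an existing parser. A small parsing sketch follows; the argv list is hypothetical and simply exercises the defaults declared in the file.

# Hypothetical parse of the flags defined by add_beamer_arguments().
import argparse

from PosterAgent.beamer_pipeline import add_beamer_arguments

parser = argparse.ArgumentParser(description='Generate Beamer poster from paper')
parser = add_beamer_arguments(parser)
args = parser.parse_args(['--output_format', 'beamer', '--beamer_theme', 'Madrid'])
print(args.output_format, args.beamer_theme, args.beamer_width_cm, args.beamer_height_cm)
# -> beamer Madrid 120 90 (width/height fall back to the declared defaults)
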
Paper2Poster/PosterAgent/create_dataset.py
ADDED
|
@@ -0,0 +1,69 @@
| 1 |
+
from datasets import load_dataset
|
| 2 |
+
import os
|
| 3 |
+
import subprocess
|
| 4 |
+
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import json
|
| 7 |
+
|
| 8 |
+
def generate_meta_json(base_dir='Paper2Poster-data'):
|
| 9 |
+
# Loop over each item in the specified base directory
|
| 10 |
+
for folder_name in os.listdir(base_dir):
|
| 11 |
+
subfolder_path = os.path.join(base_dir, folder_name)
|
| 12 |
+
|
| 13 |
+
# Ensure the item is a directory
|
| 14 |
+
if os.path.isdir(subfolder_path):
|
| 15 |
+
poster_path = os.path.join(subfolder_path, 'poster.png')
|
| 16 |
+
|
| 17 |
+
# Check if the poster.png exists in the subfolder
|
| 18 |
+
if os.path.exists(poster_path):
|
| 19 |
+
try:
|
| 20 |
+
# Open the image and get size (width, height)
|
| 21 |
+
with Image.open(poster_path) as img:
|
| 22 |
+
width, height = img.size
|
| 23 |
+
|
| 24 |
+
# Prepare metadata dictionary
|
| 25 |
+
metadata = {
|
| 26 |
+
'width': width,
|
| 27 |
+
'height': height
|
| 28 |
+
}
|
| 29 |
+
|
| 30 |
+
# Write metadata to meta.json in the same subfolder
|
| 31 |
+
meta_json_path = os.path.join(subfolder_path, 'meta.json')
|
| 32 |
+
with open(meta_json_path, 'w') as json_file:
|
| 33 |
+
json.dump(metadata, json_file)
|
| 34 |
+
|
| 35 |
+
print(f"Metadata for '{folder_name}' saved successfully.")
|
| 36 |
+
except Exception as e:
|
| 37 |
+
print(f"Error processing image in folder '{folder_name}': {e}")
|
| 38 |
+
else:
|
| 39 |
+
print(f"No poster.png found in folder '{folder_name}'.")
|
| 40 |
+
|
| 41 |
+
if __name__ == "__main__":
|
| 42 |
+
dataset = load_dataset("Paper2Poster/Paper2Poster", split="train")
|
| 43 |
+
os.makedirs('Paper2Poster-data', exist_ok=True)
|
| 44 |
+
for data in dataset:
|
| 45 |
+
paper_title = data['title']
|
| 46 |
+
paper_url = data['paper_url']
|
| 47 |
+
poster_url = data['image_url']
|
| 48 |
+
qa = data['qa']
|
| 49 |
+
|
| 50 |
+
os.makedirs(f'Paper2Poster-data/{paper_title}', exist_ok=True)
|
| 51 |
+
|
| 52 |
+
paper_output_path = os.path.join('Paper2Poster-data', paper_title, 'paper.pdf')
|
| 53 |
+
poster_output_path = os.path.join('Paper2Poster-data', paper_title, 'poster.png')
|
| 54 |
+
qa_path = os.path.join('Paper2Poster-data', paper_title, 'o3_qa.json')
|
| 55 |
+
|
| 56 |
+
qa_dict = json.loads(qa)
|
| 57 |
+
with open(qa_path, 'w') as f:
|
| 58 |
+
json.dump(qa_dict, f, indent=4)
|
| 59 |
+
print(f"Saved QA for {paper_title} into {qa_path}")
|
| 60 |
+
|
| 61 |
+
try:
|
| 62 |
+
subprocess.run(['wget', paper_url, '-O', paper_output_path], check=True)
|
| 63 |
+
subprocess.run(['wget', poster_url, '-O', poster_output_path], check=True)
|
| 64 |
+
print(f"Downloaded {poster_url} into {poster_output_path}")
|
| 65 |
+
print(f"Downloaded {paper_url} into {paper_output_path}")
|
| 66 |
+
except subprocess.CalledProcessError as e:
|
| 67 |
+
print(f"Error downloading {paper_url} or {poster_url}: {e}")
|
| 68 |
+
|
| 69 |
+
generate_meta_json('Paper2Poster-data')
|
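After the download loop finishes, each Paper2Poster-data/<title>/ folder holds paper.pdf, poster.png, o3_qa.json, and the meta.json written by generate_meta_json. A small reader sketch for that layout; the folder name is a placeholder.

# Hypothetical consumer of one per-paper folder produced by create_dataset.py.
import json
import os

paper_dir = 'Paper2Poster-data/ExamplePaper'        # placeholder folder name
with open(os.path.join(paper_dir, 'meta.json')) as f:
    meta = json.load(f)                             # {'width': ..., 'height': ...} of poster.png
with open(os.path.join(paper_dir, 'o3_qa.json')) as f:
    qa = json.load(f)                               # QA pairs stored by the download loop

print(meta['width'], meta['height'], len(qa))
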
Paper2Poster/PosterAgent/deoverflow.py
ADDED
|
@@ -0,0 +1,234 @@
|
| 1 |
+
from dotenv import load_dotenv
|
| 2 |
+
from utils.src.utils import ppt_to_images, get_json_from_response
|
| 3 |
+
import json
|
| 4 |
+
|
| 5 |
+
from camel.models import ModelFactory
|
| 6 |
+
from camel.agents import ChatAgent
|
| 7 |
+
|
| 8 |
+
from utils.wei_utils import *
|
| 9 |
+
|
| 10 |
+
from camel.messages import BaseMessage
|
| 11 |
+
from PIL import Image
|
| 12 |
+
import pickle as pkl
|
| 13 |
+
from utils.pptx_utils import *
|
| 14 |
+
from utils.critic_utils import *
|
| 15 |
+
import yaml
|
| 16 |
+
import argparse
|
| 17 |
+
import shutil
|
| 18 |
+
from jinja2 import Environment, StrictUndefined
|
| 19 |
+
|
| 20 |
+
load_dotenv()
|
| 21 |
+
|
| 22 |
+
MAX_ATTEMPTS = 5
|
| 23 |
+
|
| 24 |
+
def deoverflow(args, actor_config, critic_config):
|
| 25 |
+
total_input_token, total_output_token = 0, 0
|
| 26 |
+
style_ckpt = pkl.load(open(f'checkpoints/{args.model_name}_{args.poster_name}_style_ckpt_{args.index}.pkl', 'rb'))
|
| 27 |
+
logs_ckpt = pkl.load(open(f'checkpoints/{args.model_name}_{args.poster_name}_ckpt_{args.index}.pkl', 'rb'))
|
| 28 |
+
|
| 29 |
+
style_logs = style_ckpt['style_logs']
|
| 30 |
+
sections = list(style_logs.keys())
|
| 31 |
+
sections = [s for s in sections if s != 'meta']
|
| 32 |
+
|
| 33 |
+
slide_width = style_ckpt['outline']['meta']['width']
|
| 34 |
+
slide_height = style_ckpt['outline']['meta']['height']
|
| 35 |
+
|
| 36 |
+
content = json.load(open(f'contents/{args.model_name}_{args.poster_name}_poster_content_{args.index}.json', 'r'))
|
| 37 |
+
outline = logs_ckpt['outline']
|
| 38 |
+
|
| 39 |
+
name_to_hierarchy = get_hierarchy(outline, 1)
|
| 40 |
+
|
| 41 |
+
critic_agent_name = 'critic_overlap_agent'
|
| 42 |
+
with open(f"prompt_templates/{critic_agent_name}.yaml", "r") as f:
|
| 43 |
+
deoverflow_critic_config = yaml.safe_load(f)
|
| 44 |
+
|
| 45 |
+
actor_agent_name = 'actor_editor_agent'
|
| 46 |
+
|
| 47 |
+
with open(f"prompt_templates/{actor_agent_name}.yaml", "r") as f:
|
| 48 |
+
deoverflow_actor_config = yaml.safe_load(f)
|
| 49 |
+
|
| 50 |
+
actor_model = ModelFactory.create(
|
| 51 |
+
model_platform=actor_config['model_platform'],
|
| 52 |
+
model_type=actor_config['model_type'],
|
| 53 |
+
model_config_dict=actor_config['model_config'],
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
actor_sys_msg = deoverflow_actor_config['system_prompt']
|
| 57 |
+
|
| 58 |
+
actor_agent = ChatAgent(
|
| 59 |
+
system_message=actor_sys_msg,
|
| 60 |
+
model=actor_model,
|
| 61 |
+
message_window_size=10,
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
critic_model = ModelFactory.create(
|
| 65 |
+
model_platform=critic_config['model_platform'],
|
| 66 |
+
model_type=critic_config['model_type'],
|
| 67 |
+
model_config_dict=critic_config['model_config'],
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
critic_sys_msg = deoverflow_critic_config['system_prompt']
|
| 71 |
+
|
| 72 |
+
critic_agent = ChatAgent(
|
| 73 |
+
system_message=critic_sys_msg,
|
| 74 |
+
model=critic_model,
|
| 75 |
+
message_window_size=None,
|
| 76 |
+
)
|
| 77 |
+
|
| 78 |
+
jinja_env = Environment(undefined=StrictUndefined)
|
| 79 |
+
|
| 80 |
+
actor_template = jinja_env.from_string(deoverflow_actor_config["template"])
|
| 81 |
+
critic_template = jinja_env.from_string(deoverflow_critic_config["template"])
|
| 82 |
+
|
| 83 |
+
critic_logs = {}
|
| 84 |
+
actor_logs = {}
|
| 85 |
+
img_logs = {}
|
| 86 |
+
|
| 87 |
+
# Load neg and pos examples
|
| 88 |
+
neg_img = Image.open('overflow_example/neg.jpg')
|
| 89 |
+
pos_img = Image.open('overflow_example/pos.jpg')
|
| 90 |
+
|
| 91 |
+
for section_index in range(len(sections)):
|
| 92 |
+
section_name = sections[section_index]
|
| 93 |
+
section_code = style_logs[section_name][-1]['code']
|
| 94 |
+
|
| 95 |
+
if 'subsections' in content[section_name]:
|
| 96 |
+
subsections = list(content[section_name]['subsections'].keys())
|
| 97 |
+
else:
|
| 98 |
+
subsections = [section_name]
|
| 99 |
+
|
| 100 |
+
log = []
|
| 101 |
+
|
| 102 |
+
for leaf_section in subsections:
|
| 103 |
+
if leaf_section in outline:
|
| 104 |
+
leaf_name = outline[leaf_section]['name']
|
| 105 |
+
else:
|
| 106 |
+
leaf_name = outline[section_name]['subsections'][leaf_section]['name']
|
| 107 |
+
num_rounds = 0
|
| 108 |
+
while True:
|
| 109 |
+
print(f"Section: {section_name}, Leaf Section: {leaf_section}, Round: {num_rounds}")
|
| 110 |
+
num_rounds += 1
|
| 111 |
+
if num_rounds > MAX_ATTEMPTS:
|
| 112 |
+
break
|
| 113 |
+
poster = create_poster(slide_width, slide_height)
|
| 114 |
+
add_blank_slide(poster)
|
| 115 |
+
save_presentation(poster, file_name='poster.pptx')
|
| 116 |
+
curr_location, zoomed_in_img, zoomed_in_img_path = get_snapshot_from_section(
|
| 117 |
+
leaf_section,
|
| 118 |
+
section_name,
|
| 119 |
+
name_to_hierarchy,
|
| 120 |
+
leaf_name,
|
| 121 |
+
section_code
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
if not leaf_section in img_logs:
|
| 125 |
+
img_logs[leaf_section] = []
|
| 126 |
+
img_logs[leaf_section].append(zoomed_in_img)
|
| 127 |
+
|
| 128 |
+
jinja_args = {
|
| 129 |
+
'content_json': content[leaf_section] if leaf_section in content else content[section_name]['subsections'][leaf_section],
|
| 130 |
+
'existing_code': section_code,
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
critic_prompt = critic_template.render(**jinja_args)
|
| 134 |
+
|
| 135 |
+
critic_msg = BaseMessage.make_user_message(
|
| 136 |
+
role_name="User",
|
| 137 |
+
content=critic_prompt,
|
| 138 |
+
image_list=[neg_img, pos_img, zoomed_in_img],
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
critic_agent.reset()
|
| 142 |
+
response = critic_agent.step(critic_msg)
|
| 143 |
+
resp = response.msgs[0].content
|
| 144 |
+
input_token, output_token = account_token(response)
|
| 145 |
+
total_input_token += input_token
|
| 146 |
+
total_output_token += output_token
|
| 147 |
+
if not leaf_section in critic_logs:
|
| 148 |
+
critic_logs[leaf_section] = []
|
| 149 |
+
|
| 150 |
+
critic_logs[leaf_section].append(response)
|
| 151 |
+
|
| 152 |
+
if type(resp) == str:
|
| 153 |
+
if resp in ['NO', 'NO.', '"NO"', "'NO'"]:
|
| 154 |
+
break
|
| 155 |
+
|
| 156 |
+
feedback = get_json_from_response(resp)
|
| 157 |
+
print(feedback)
|
| 158 |
+
jinja_args = {
|
| 159 |
+
'content_json': content[leaf_section] if leaf_section in content else content[section_name]['subsections'][leaf_section],
|
| 160 |
+
'function_docs': documentation,
|
| 161 |
+
'existing_code': section_code,
|
| 162 |
+
'suggestion_json': feedback,
|
| 163 |
+
}
|
| 164 |
+
|
| 165 |
+
actor_prompt = actor_template.render(**jinja_args)
|
| 166 |
+
|
| 167 |
+
log = edit_code(actor_agent, actor_prompt, 3, existing_code='')
|
| 168 |
+
if log[-1]['error'] is not None:
|
| 169 |
+
raise Exception(log[-1]['error'])
|
| 170 |
+
|
| 171 |
+
input_token = log[-1]['cumulative_tokens'][0]
|
| 172 |
+
output_token = log[-1]['cumulative_tokens'][1]
|
| 173 |
+
total_input_token += input_token
|
| 174 |
+
total_output_token += output_token
|
| 175 |
+
|
| 176 |
+
section_code = log[-1]['code']
|
| 177 |
+
|
| 178 |
+
if not leaf_section in actor_logs:
|
| 179 |
+
actor_logs[leaf_section] = []
|
| 180 |
+
|
| 181 |
+
actor_logs[leaf_section].append(log)
|
| 182 |
+
if len(log) > 0:
|
| 183 |
+
style_logs[section_name].append(log[-1])
|
| 184 |
+
|
| 185 |
+
final_code = ''
|
| 186 |
+
for section in sections:
|
| 187 |
+
final_code += style_logs[section][-1]['code'] + '\n'
|
| 188 |
+
|
| 189 |
+
run_code_with_utils(final_code, utils_functions)
|
| 190 |
+
ppt_to_images(f'poster.pptx', 'tmp/non_overlap_preview')
|
| 191 |
+
|
| 192 |
+
result_dir = f'results/{args.poster_name}/{args.model_name}/{args.index}'
|
| 193 |
+
if not os.path.exists(result_dir):
|
| 194 |
+
os.makedirs(result_dir)
|
| 195 |
+
shutil.copy('poster.pptx', f'{result_dir}/non_overlap_poster.pptx')
|
| 196 |
+
ppt_to_images(f'poster.pptx', f'{result_dir}/non_overlap_poster_preview')
|
| 197 |
+
|
| 198 |
+
final_code_by_section = {}
|
| 199 |
+
for section in sections:
|
| 200 |
+
final_code_by_section[section] = style_logs[section][-1]['code']
|
| 201 |
+
|
| 202 |
+
non_overlap_ckpt = {
|
| 203 |
+
'critic_logs': critic_logs,
|
| 204 |
+
'actor_logs': actor_logs,
|
| 205 |
+
'img_logs': img_logs,
|
| 206 |
+
'name_to_hierarchy': name_to_hierarchy,
|
| 207 |
+
'final_code': final_code,
|
| 208 |
+
'final_code_by_section': final_code_by_section,
|
| 209 |
+
'total_input_token': total_input_token,
|
| 210 |
+
'total_output_token': total_output_token
|
| 211 |
+
}
|
| 212 |
+
|
| 213 |
+
pkl.dump(non_overlap_ckpt, open(f'checkpoints/{args.model_name}_{args.poster_name}_non_overlap_ckpt_{args.index}.pkl', 'wb'))
|
| 214 |
+
|
| 215 |
+
return total_input_token, total_output_token
|
| 216 |
+
|
| 217 |
+
if __name__ == '__main__':
|
| 218 |
+
parser = argparse.ArgumentParser()
|
| 219 |
+
parser.add_argument('--poster_name', type=str, default=None)
|
| 220 |
+
parser.add_argument('--model_name', type=str, default='4o')
|
| 221 |
+
parser.add_argument('--poster_path', type=str, required=True)
|
| 222 |
+
parser.add_argument('--index', type=int, default=0)
|
| 223 |
+
parser.add_argument('--max_retry', type=int, default=3)
|
| 224 |
+
args = parser.parse_args()
|
| 225 |
+
|
| 226 |
+
actor_config = get_agent_config(args.model_name)
|
| 227 |
+
critic_config = get_agent_config(args.model_name)
|
| 228 |
+
|
| 229 |
+
if args.poster_name is None:
|
| 230 |
+
args.poster_name = args.poster_path.split('/')[-1].replace('.pdf', '').replace(' ', '_')
|
| 231 |
+
|
| 232 |
+
input_token, output_token = deoverflow(args, actor_config, critic_config)
|
| 233 |
+
|
| 234 |
+
print(f'Token consumption: {input_token} -> {output_token}')
|
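The per-leaf loop in deoverflow() follows a fixed critic-then-actor pattern: render a zoomed-in snapshot, ask the critic whether the box overflows, stop on a plain NO, otherwise hand the critic's JSON suggestion to the actor and retry, for at most MAX_ATTEMPTS rounds. A stripped-down sketch of that control flow; critique() and revise() are hypothetical stand-ins for the critic_agent and actor_agent calls above.

# Schematic of the overflow-refinement loop; not the agents themselves.
MAX_ATTEMPTS = 5

def refine_leaf(section_code, critique, revise):
    for _ in range(MAX_ATTEMPTS):
        verdict = critique(section_code)                  # snapshot + prompt -> 'NO' or JSON feedback
        if verdict in ('NO', 'NO.', '"NO"', "'NO'"):
            break                                         # no overflow detected, keep current code
        feedback = verdict                                # critic's suggestion JSON
        section_code = revise(section_code, feedback)     # actor rewrites the section code
    return section_code
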
Paper2Poster/PosterAgent/deoverflow_parallel.py
ADDED
|
@@ -0,0 +1,485 @@
|
| 1 |
+
from dotenv import load_dotenv
|
| 2 |
+
from utils.src.utils import ppt_to_images, get_json_from_response
|
| 3 |
+
import json
|
| 4 |
+
|
| 5 |
+
from camel.models import ModelFactory
|
| 6 |
+
from camel.agents import ChatAgent
|
| 7 |
+
|
| 8 |
+
from utils.wei_utils import *
|
| 9 |
+
|
| 10 |
+
from camel.messages import BaseMessage
|
| 11 |
+
from PIL import Image
|
| 12 |
+
import pickle as pkl
|
| 13 |
+
from utils.pptx_utils import *
|
| 14 |
+
from utils.critic_utils import *
|
| 15 |
+
import yaml
|
| 16 |
+
import argparse
|
| 17 |
+
import shutil
|
| 18 |
+
from jinja2 import Environment, StrictUndefined
|
| 19 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 20 |
+
import copy
|
| 21 |
+
|
| 22 |
+
load_dotenv()
|
| 23 |
+
|
| 24 |
+
MAX_ATTEMPTS = 5
|
| 25 |
+
|
| 26 |
+
def process_leaf_section(
|
| 27 |
+
leaf_section,
|
| 28 |
+
section_name,
|
| 29 |
+
outline,
|
| 30 |
+
content,
|
| 31 |
+
style_logs,
|
| 32 |
+
critic_logs,
|
| 33 |
+
actor_logs,
|
| 34 |
+
img_logs,
|
| 35 |
+
slide_width,
|
| 36 |
+
slide_height,
|
| 37 |
+
name_to_hierarchy,
|
| 38 |
+
critic_template,
|
| 39 |
+
actor_template,
|
| 40 |
+
critic_agent,
|
| 41 |
+
actor_agent,
|
| 42 |
+
neg_img,
|
| 43 |
+
pos_img,
|
| 44 |
+
MAX_ATTEMPTS,
|
| 45 |
+
documentation,
|
| 46 |
+
total_input_token,
|
| 47 |
+
total_output_token,
|
| 48 |
+
):
|
| 49 |
+
"""
|
| 50 |
+
Handles the logic for a single leaf_section within a section_name.
|
| 51 |
+
|
| 52 |
+
Returns a dictionary of updated logs and tokens.
|
| 53 |
+
"""
|
| 54 |
+
section_code = style_logs[section_name][-1]['code'] # current code for this section
|
| 55 |
+
log = []
|
| 56 |
+
leaf_name = None
|
| 57 |
+
if leaf_section in outline:
|
| 58 |
+
leaf_name = outline[leaf_section]['name']
|
| 59 |
+
else:
|
| 60 |
+
leaf_name = outline[section_name]['subsections'][leaf_section]['name']
|
| 61 |
+
|
| 62 |
+
num_rounds = 0
|
| 63 |
+
while True:
|
| 64 |
+
print(f"Section: {section_name}, Leaf Section: {leaf_section}, Round: {num_rounds}")
|
| 65 |
+
num_rounds += 1
|
| 66 |
+
if num_rounds > MAX_ATTEMPTS:
|
| 67 |
+
break
|
| 68 |
+
|
| 69 |
+
poster = create_poster(slide_width, slide_height)
|
| 70 |
+
add_blank_slide(poster)
|
| 71 |
+
empty_poster_path = f'tmp/empty_poster_{section_name}_{leaf_section}.pptx'
|
| 72 |
+
save_presentation(poster, file_name=empty_poster_path)
|
| 73 |
+
|
| 74 |
+
curr_location, zoomed_in_img, zoomed_in_img_path = get_snapshot_from_section(
|
| 75 |
+
leaf_section,
|
| 76 |
+
section_name,
|
| 77 |
+
name_to_hierarchy,
|
| 78 |
+
leaf_name,
|
| 79 |
+
section_code,
|
| 80 |
+
empty_poster_path
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
if leaf_section not in img_logs:
|
| 84 |
+
img_logs[leaf_section] = []
|
| 85 |
+
img_logs[leaf_section].append(zoomed_in_img)
|
| 86 |
+
|
| 87 |
+
jinja_args = {
|
| 88 |
+
'content_json': content[leaf_section] if leaf_section in content
|
| 89 |
+
else content[section_name]['subsections'][leaf_section],
|
| 90 |
+
'existing_code': section_code,
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
critic_prompt = critic_template.render(**jinja_args)
|
| 94 |
+
|
| 95 |
+
critic_msg = BaseMessage.make_user_message(
|
| 96 |
+
role_name="User",
|
| 97 |
+
content=critic_prompt,
|
| 98 |
+
image_list=[neg_img, pos_img, zoomed_in_img],
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
critic_agent.reset()
|
| 102 |
+
response = critic_agent.step(critic_msg)
|
| 103 |
+
resp = response.msgs[0].content
|
| 104 |
+
|
| 105 |
+
# Track tokens
|
| 106 |
+
input_token, output_token = account_token(response)
|
| 107 |
+
total_input_token += input_token
|
| 108 |
+
total_output_token += output_token
|
| 109 |
+
|
| 110 |
+
if leaf_section not in critic_logs:
|
| 111 |
+
critic_logs[leaf_section] = []
|
| 112 |
+
critic_logs[leaf_section].append(response)
|
| 113 |
+
|
| 114 |
+
# Stop condition
|
| 115 |
+
if isinstance(resp, str):
|
| 116 |
+
if resp in ['NO', 'NO.', '"NO"', "'NO'"]:
|
| 117 |
+
break
|
| 118 |
+
|
| 119 |
+
feedback = get_json_from_response(resp)
|
| 120 |
+
print(feedback)
|
| 121 |
+
|
| 122 |
+
jinja_args = {
|
| 123 |
+
'content_json': content[leaf_section] if leaf_section in content
|
| 124 |
+
else content[section_name]['subsections'][leaf_section],
|
| 125 |
+
'function_docs': documentation,
|
| 126 |
+
'existing_code': section_code,
|
| 127 |
+
'suggestion_json': feedback,
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
actor_prompt = actor_template.render(**jinja_args)
|
| 131 |
+
|
| 132 |
+
leaf_log = edit_code(actor_agent, actor_prompt, 3, existing_code='')
|
| 133 |
+
if leaf_log[-1]['error'] is not None:
|
| 134 |
+
raise Exception(leaf_log[-1]['error'])
|
| 135 |
+
|
| 136 |
+
# Track tokens
|
| 137 |
+
in_tok = leaf_log[-1]['cumulative_tokens'][0]
|
| 138 |
+
out_tok = leaf_log[-1]['cumulative_tokens'][1]
|
| 139 |
+
total_input_token += in_tok
|
| 140 |
+
total_output_token += out_tok
|
| 141 |
+
|
| 142 |
+
section_code = leaf_log[-1]['code']
|
| 143 |
+
|
| 144 |
+
if leaf_section not in actor_logs:
|
| 145 |
+
actor_logs[leaf_section] = []
|
| 146 |
+
actor_logs[leaf_section].append(leaf_log)
|
| 147 |
+
|
| 148 |
+
log.extend(leaf_log)
|
| 149 |
+
|
| 150 |
+
return {
|
| 151 |
+
"section_code": section_code,
|
| 152 |
+
"log": log,
|
| 153 |
+
"img_logs": img_logs,
|
| 154 |
+
"critic_logs": critic_logs,
|
| 155 |
+
"actor_logs": actor_logs,
|
| 156 |
+
"total_input_token": total_input_token,
|
| 157 |
+
"total_output_token": total_output_token,
|
| 158 |
+
}
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def process_section(
|
| 162 |
+
section_name,
|
| 163 |
+
content,
|
| 164 |
+
outline,
|
| 165 |
+
sections,
|
| 166 |
+
style_logs,
|
| 167 |
+
critic_logs,
|
| 168 |
+
actor_logs,
|
| 169 |
+
img_logs,
|
| 170 |
+
slide_width,
|
| 171 |
+
slide_height,
|
| 172 |
+
name_to_hierarchy,
|
| 173 |
+
critic_template,
|
| 174 |
+
actor_template,
|
| 175 |
+
critic_agent,
|
| 176 |
+
actor_agent,
|
| 177 |
+
neg_img,
|
| 178 |
+
pos_img,
|
| 179 |
+
MAX_ATTEMPTS,
|
| 180 |
+
documentation,
|
| 181 |
+
total_input_token,
|
| 182 |
+
total_output_token,
|
| 183 |
+
):
|
| 184 |
+
"""
|
| 185 |
+
Handles processing of a single section and its subsections (leaf sections).
|
| 186 |
+
Returns updated logs and token counters for this section.
|
| 187 |
+
"""
|
| 188 |
+
results_per_leaf = []
|
| 189 |
+
|
| 190 |
+
# Grab the current code for this section
|
| 191 |
+
section_code = style_logs[section_name][-1]['code']
|
| 192 |
+
|
| 193 |
+
# Determine which leaf sections to process
|
| 194 |
+
if 'subsections' in content[section_name]:
|
| 195 |
+
subsections = list(content[section_name]['subsections'].keys())
|
| 196 |
+
else:
|
| 197 |
+
subsections = [section_name]
|
| 198 |
+
|
| 199 |
+
all_logs_for_section = []
|
| 200 |
+
|
| 201 |
+
for leaf_section in subsections:
|
| 202 |
+
# Process this leaf section
|
| 203 |
+
leaf_result = process_leaf_section(
|
| 204 |
+
leaf_section,
|
| 205 |
+
section_name,
|
| 206 |
+
outline,
|
| 207 |
+
content,
|
| 208 |
+
style_logs,
|
| 209 |
+
critic_logs,
|
| 210 |
+
actor_logs,
|
| 211 |
+
img_logs,
|
| 212 |
+
slide_width,
|
| 213 |
+
slide_height,
|
| 214 |
+
name_to_hierarchy,
|
| 215 |
+
critic_template,
|
| 216 |
+
actor_template,
|
| 217 |
+
critic_agent,
|
| 218 |
+
actor_agent,
|
| 219 |
+
neg_img,
|
| 220 |
+
pos_img,
|
| 221 |
+
MAX_ATTEMPTS,
|
| 222 |
+
documentation,
|
| 223 |
+
total_input_token,
|
| 224 |
+
total_output_token,
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
# Update logs/tokens
|
| 228 |
+
section_code = leaf_result["section_code"]
|
| 229 |
+
all_logs_for_section.extend(leaf_result["log"])
|
| 230 |
+
img_logs = leaf_result["img_logs"]
|
| 231 |
+
critic_logs = leaf_result["critic_logs"]
|
| 232 |
+
actor_logs = leaf_result["actor_logs"]
|
| 233 |
+
total_input_token = leaf_result["total_input_token"]
|
| 234 |
+
total_output_token = leaf_result["total_output_token"]
|
| 235 |
+
|
| 236 |
+
# If we have any logs from the last leaf in this section, append them
|
| 237 |
+
if all_logs_for_section:
|
| 238 |
+
style_logs[section_name].append(all_logs_for_section[-1])
|
| 239 |
+
|
| 240 |
+
# Return updated state for merging back in the main thread
|
| 241 |
+
return {
|
| 242 |
+
"section_name": section_name,
|
| 243 |
+
"style_logs": style_logs,
|
| 244 |
+
"critic_logs": critic_logs,
|
| 245 |
+
"actor_logs": actor_logs,
|
| 246 |
+
"img_logs": img_logs,
|
| 247 |
+
"total_input_token": total_input_token,
|
| 248 |
+
"total_output_token": total_output_token
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
def parallel_by_sections(
|
| 252 |
+
sections,
|
| 253 |
+
content,
|
| 254 |
+
outline,
|
| 255 |
+
style_logs,
|
| 256 |
+
critic_logs,
|
| 257 |
+
actor_logs,
|
| 258 |
+
img_logs,
|
| 259 |
+
slide_width,
|
| 260 |
+
slide_height,
|
| 261 |
+
name_to_hierarchy,
|
| 262 |
+
critic_template,
|
| 263 |
+
actor_template,
|
| 264 |
+
critic_agent,
|
| 265 |
+
actor_agent,
|
| 266 |
+
neg_img,
|
| 267 |
+
pos_img,
|
| 268 |
+
MAX_ATTEMPTS,
|
| 269 |
+
documentation,
|
| 270 |
+
total_input_token,
|
| 271 |
+
total_output_token,
|
| 272 |
+
max_workers=4
|
| 273 |
+
):
|
| 274 |
+
"""
|
| 275 |
+
Main entry point to parallelize processing across sections.
|
| 276 |
+
|
| 277 |
+
Returns the merged logs and token counters after processing all sections in parallel.
|
| 278 |
+
"""
|
| 279 |
+
# Because we’ll be modifying dictionaries (like style_logs, etc.),
|
| 280 |
+
# it can be safer to create a copy for the workers, then merge results
|
| 281 |
+
# after. (Below is a simple approach—depending on your scale, consider
|
| 282 |
+
# explicit concurrency controls or a database-backed approach.)
|
| 283 |
+
|
| 284 |
+
# Summaries from each future
|
| 285 |
+
results = []
|
| 286 |
+
|
| 287 |
+
# We’ll store fresh copies for each section to avoid concurrency collisions
|
| 288 |
+
# on dictionary updates. If the data is large, you might want a more
|
| 289 |
+
# sophisticated synchronization or partition approach rather than naive copies.
|
| 290 |
+
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
| 291 |
+
futures = []
|
| 292 |
+
|
| 293 |
+
for section_name in sections:
|
| 294 |
+
# Make shallow copies or deep copies of logs
|
| 295 |
+
_style_logs = copy.deepcopy(style_logs)
|
| 296 |
+
_critic_logs = copy.deepcopy(critic_logs)
|
| 297 |
+
_actor_logs = copy.deepcopy(actor_logs)
|
| 298 |
+
_img_logs = copy.deepcopy(img_logs)
|
| 299 |
+
|
| 300 |
+
futures.append(executor.submit(
|
| 301 |
+
process_section,
|
| 302 |
+
section_name,
|
| 303 |
+
content,
|
| 304 |
+
outline,
|
| 305 |
+
sections,
|
| 306 |
+
_style_logs,
|
| 307 |
+
_critic_logs,
|
| 308 |
+
_actor_logs,
|
| 309 |
+
_img_logs,
|
| 310 |
+
slide_width,
|
| 311 |
+
slide_height,
|
| 312 |
+
name_to_hierarchy,
|
| 313 |
+
critic_template,
|
| 314 |
+
actor_template,
|
| 315 |
+
critic_agent,
|
| 316 |
+
actor_agent,
|
| 317 |
+
neg_img,
|
| 318 |
+
pos_img,
|
| 319 |
+
MAX_ATTEMPTS,
|
| 320 |
+
documentation,
|
| 321 |
+
total_input_token,
|
| 322 |
+
total_output_token
|
| 323 |
+
))
|
| 324 |
+
|
| 325 |
+
for future in futures:
|
| 326 |
+
results.append(future.result())
|
| 327 |
+
|
| 328 |
+
# The code below merges the results. The method of merging depends on how
|
| 329 |
+
# you prefer to aggregate. For a minimal approach, we’ll pick the logs from
|
| 330 |
+
# each section, then overwrite or update them in the main dictionaries.
|
| 331 |
+
|
| 332 |
+
for res in results:
|
| 333 |
+
sec_name = res["section_name"]
|
| 334 |
+
# Overwrite or merge logs as needed
|
| 335 |
+
style_logs[sec_name] = res["style_logs"][sec_name]
|
| 336 |
+
critic_logs.update(res["critic_logs"])
|
| 337 |
+
actor_logs.update(res["actor_logs"])
|
| 338 |
+
img_logs.update(res["img_logs"])
|
| 339 |
+
total_input_token = res["total_input_token"]
|
| 340 |
+
total_output_token = res["total_output_token"]
|
| 341 |
+
|
| 342 |
+
return style_logs, critic_logs, actor_logs, img_logs, total_input_token, total_output_token
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
def deoverflow(args, actor_config, critic_config):
|
| 346 |
+
total_input_token, total_output_token = 0, 0
|
| 347 |
+
style_ckpt = pkl.load(open(f'checkpoints/{args.model_name}_{args.poster_name}_style_ckpt_{args.index}.pkl', 'rb'))
|
| 348 |
+
logs_ckpt = pkl.load(open(f'checkpoints/{args.model_name}_{args.poster_name}_ckpt_{args.index}.pkl', 'rb'))
|
| 349 |
+
|
| 350 |
+
style_logs = style_ckpt['style_logs']
|
| 351 |
+
sections = list(style_logs.keys())
|
| 352 |
+
sections = [s for s in sections if s != 'meta']
|
| 353 |
+
|
| 354 |
+
slide_width = style_ckpt['outline']['meta']['width']
|
| 355 |
+
slide_height = style_ckpt['outline']['meta']['height']
|
| 356 |
+
|
| 357 |
+
content = json.load(open(f'contents/{args.model_name}_{args.poster_name}_poster_content_{args.index}.json', 'r'))
|
| 358 |
+
outline = logs_ckpt['outline']
|
| 359 |
+
|
| 360 |
+
name_to_hierarchy = get_hierarchy(outline, 1)
|
| 361 |
+
|
| 362 |
+
critic_agent_name = 'critic_overlap_agent'
|
| 363 |
+
with open(f"prompt_templates/{critic_agent_name}.yaml", "r") as f:
|
| 364 |
+
deoverflow_critic_config = yaml.safe_load(f)
|
| 365 |
+
|
| 366 |
+
actor_agent_name = 'actor_editor_agent'
|
| 367 |
+
|
| 368 |
+
with open(f"prompt_templates/{actor_agent_name}.yaml", "r") as f:
|
| 369 |
+
deoverflow_actor_config = yaml.safe_load(f)
|
| 370 |
+
|
| 371 |
+
actor_model = ModelFactory.create(
|
| 372 |
+
model_platform=actor_config['model_platform'],
|
| 373 |
+
model_type=actor_config['model_type'],
|
| 374 |
+
model_config_dict=actor_config['model_config'],
|
| 375 |
+
)
|
| 376 |
+
|
| 377 |
+
actor_sys_msg = deoverflow_actor_config['system_prompt']
|
| 378 |
+
|
| 379 |
+
actor_agent = ChatAgent(
|
| 380 |
+
system_message=actor_sys_msg,
|
| 381 |
+
model=actor_model,
|
| 382 |
+
message_window_size=10,
|
| 383 |
+
)
|
| 384 |
+
|
| 385 |
+
critic_model = ModelFactory.create(
|
| 386 |
+
model_platform=critic_config['model_platform'],
|
| 387 |
+
model_type=critic_config['model_type'],
|
| 388 |
+
model_config_dict=critic_config['model_config'],
|
| 389 |
+
)
|
| 390 |
+
|
| 391 |
+
critic_sys_msg = deoverflow_critic_config['system_prompt']
|
| 392 |
+
|
| 393 |
+
critic_agent = ChatAgent(
|
| 394 |
+
system_message=critic_sys_msg,
|
| 395 |
+
model=critic_model,
|
| 396 |
+
message_window_size=None,
|
| 397 |
+
)
|
| 398 |
+
|
| 399 |
+
jinja_env = Environment(undefined=StrictUndefined)
|
| 400 |
+
|
| 401 |
+
actor_template = jinja_env.from_string(deoverflow_actor_config["template"])
|
| 402 |
+
critic_template = jinja_env.from_string(deoverflow_critic_config["template"])
|
| 403 |
+
|
| 404 |
+
critic_logs = {}
|
| 405 |
+
actor_logs = {}
|
| 406 |
+
img_logs = {}
|
| 407 |
+
|
| 408 |
+
# Load neg and pos examples
|
| 409 |
+
neg_img = Image.open('overflow_example/neg.jpg')
|
| 410 |
+
pos_img = Image.open('overflow_example/pos.jpg')
|
| 411 |
+
|
| 412 |
+
style_logs, critic_logs, actor_logs, img_logs, total_input_token, total_output_token = parallel_by_sections(
|
| 413 |
+
sections=sections,
|
| 414 |
+
content=content,
|
| 415 |
+
outline=outline,
|
| 416 |
+
style_logs=style_logs,
|
| 417 |
+
critic_logs=critic_logs,
|
| 418 |
+
actor_logs=actor_logs,
|
| 419 |
+
img_logs=img_logs,
|
| 420 |
+
slide_width=slide_width,
|
| 421 |
+
slide_height=slide_height,
|
| 422 |
+
name_to_hierarchy=name_to_hierarchy,
|
| 423 |
+
critic_template=critic_template,
|
| 424 |
+
actor_template=actor_template,
|
| 425 |
+
critic_agent=critic_agent,
|
| 426 |
+
actor_agent=actor_agent,
|
| 427 |
+
neg_img=neg_img,
|
| 428 |
+
pos_img=pos_img,
|
| 429 |
+
MAX_ATTEMPTS=MAX_ATTEMPTS,
|
| 430 |
+
documentation=documentation,
|
| 431 |
+
total_input_token=total_input_token,
|
| 432 |
+
total_output_token=total_output_token,
|
| 433 |
+
max_workers=100, # or however many worker threads you want
|
| 434 |
+
)
|
| 435 |
+
|
| 436 |
+
final_code = ''
|
| 437 |
+
for section in sections:
|
| 438 |
+
final_code += style_logs[section][-1]['code'] + '\n'
|
| 439 |
+
|
| 440 |
+
run_code_with_utils(final_code, utils_functions)
|
| 441 |
+
ppt_to_images(f'poster.pptx', 'tmp/non_overlap_preview')
|
| 442 |
+
|
| 443 |
+
result_dir = f'results/{args.poster_name}/{args.model_name}/{args.index}'
|
| 444 |
+
if not os.path.exists(result_dir):
|
| 445 |
+
os.makedirs(result_dir)
|
| 446 |
+
shutil.copy('poster.pptx', f'{result_dir}/non_overlap_poster.pptx')
|
| 447 |
+
ppt_to_images(f'poster.pptx', f'{result_dir}/non_overlap_poster_preview')
|
| 448 |
+
|
| 449 |
+
final_code_by_section = {}
|
| 450 |
+
for section in sections:
|
| 451 |
+
final_code_by_section[section] = style_logs[section][-1]['code']
|
| 452 |
+
|
| 453 |
+
non_overlap_ckpt = {
|
| 454 |
+
'critic_logs': critic_logs,
|
| 455 |
+
'actor_logs': actor_logs,
|
| 456 |
+
'img_logs': img_logs,
|
| 457 |
+
'name_to_hierarchy': name_to_hierarchy,
|
| 458 |
+
'final_code': final_code,
|
| 459 |
+
'final_code_by_section': final_code_by_section,
|
| 460 |
+
'total_input_token': total_input_token,
|
| 461 |
+
'total_output_token': total_output_token
|
| 462 |
+
}
|
| 463 |
+
|
| 464 |
+
pkl.dump(non_overlap_ckpt, open(f'checkpoints/{args.model_name}_{args.poster_name}_non_overlap_ckpt_{args.index}.pkl', 'wb'))
|
| 465 |
+
|
| 466 |
+
return total_input_token, total_output_token
|
| 467 |
+
|
| 468 |
+
if __name__ == '__main__':
|
| 469 |
+
parser = argparse.ArgumentParser()
|
| 470 |
+
parser.add_argument('--poster_name', type=str, default=None)
|
| 471 |
+
parser.add_argument('--model_name', type=str, default='4o')
|
| 472 |
+
parser.add_argument('--poster_path', type=str, required=True)
|
| 473 |
+
parser.add_argument('--index', type=int, default=0)
|
| 474 |
+
parser.add_argument('--max_retry', type=int, default=3)
|
| 475 |
+
args = parser.parse_args()
|
| 476 |
+
|
| 477 |
+
actor_config = get_agent_config(args.model_name)
|
| 478 |
+
critic_config = get_agent_config(args.model_name)
|
| 479 |
+
|
| 480 |
+
if args.poster_name is None:
|
| 481 |
+
args.poster_name = args.poster_path.split('/')[-1].replace('.pdf', '').replace(' ', '_')
|
| 482 |
+
|
| 483 |
+
input_token, output_token = deoverflow(args, actor_config, critic_config)
|
| 484 |
+
|
| 485 |
+
print(f'Token consumption: {input_token} -> {output_token}')
|
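parallel_by_sections fans the per-section refinement out over a ThreadPoolExecutor, handing each worker deep copies of the mutable log dictionaries and merging the per-section results back afterwards. A minimal sketch of that copy-per-worker pattern; process() is a hypothetical stand-in for process_section() and is assumed to return the section name plus its updated logs.

# Copy-per-worker fan-out/merge pattern used by parallel_by_sections().
import copy
from concurrent.futures import ThreadPoolExecutor

def fan_out(sections, logs, process, max_workers=4):
    merged = dict(logs)
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            executor.submit(process, name, copy.deepcopy(logs))   # isolate each worker's state
            for name in sections
        ]
        for future in futures:
            name, section_logs = future.result()
            merged[name] = section_logs[name]                     # take each worker's own section back
    return merged
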
Paper2Poster/PosterAgent/fill_and_style.py
ADDED
|
@@ -0,0 +1,215 @@
|
from dotenv import load_dotenv
import os
from utils.src.utils import ppt_to_images, get_json_from_response
import json
import pptx

from camel.models import ModelFactory
from camel.types import ModelPlatformType, ModelType
from camel.configs import ChatGPTConfig, QwenConfig
from camel.agents import ChatAgent

from utils.wei_utils import fill_content

from camel.messages import BaseMessage
from PIL import Image
import pickle as pkl
from utils.pptx_utils import *
from utils.critic_utils import *
from utils.wei_utils import *
import importlib
import yaml
import os
import shutil
from datetime import datetime
from jinja2 import Environment, StrictUndefined, Template
import argparse

load_dotenv()

def fill_poster_content(args, actor_config):
    total_input_token, total_output_token = 0, 0
    poster_content = json.load(open(f'contents/{args.model_name}_{args.poster_name}_poster_content_{args.index}.json', 'r'))
    agent_name = 'content_filler_agent'

    with open(f"prompt_templates/{agent_name}.yaml", "r") as f:
        fill_config = yaml.safe_load(f)

    actor_model = ModelFactory.create(
        model_platform=actor_config['model_platform'],
        model_type=actor_config['model_type'],
        model_config_dict=actor_config['model_config'],
    )

    actor_sys_msg = fill_config['system_prompt']

    actor_agent = ChatAgent(
        system_message=actor_sys_msg,
        model=actor_model,
        message_window_size=10,
    )

    ckpt = pkl.load(open(f'checkpoints/{args.model_name}_{args.poster_name}_ckpt_{args.index}.pkl', 'rb'))
    logs = ckpt['logs']
    outline = ckpt['outline']

    sections = list(outline.keys())
    sections = [s for s in sections if s != 'meta']

    jinja_env = Environment(undefined=StrictUndefined)

    template = jinja_env.from_string(fill_config["template"])
    content_logs = {}

    for section_index in range(len(sections)):
        section_name = sections[section_index]
        section_code = logs[section_name][-1]['code']

        print(f'Filling content for {section_name}')

        jinja_args = {
            'content_json': poster_content[section_name],
            'function_docs': documentation,
            'existing_code': section_code
        }

        prompt = template.render(**jinja_args)
        if section_index == 0:
            existing_code = ''
        else:
            existing_code = content_logs[sections[section_index - 1]][-1]['concatenated_code']
        content_logs[section_name] = fill_content(
            actor_agent,
            prompt,
            3,
            existing_code
        )

        shutil.copy('poster.pptx', f'tmp/content_poster_<{section_name}>.pptx')

        if content_logs[section_name][-1]['error'] is not None:
            raise Exception(f'Error in filling content for {section_name}: {content_logs[section_name][-1]["error"]}')

        total_input_token += content_logs[section_name][-1]['cumulative_tokens'][0]
        total_output_token += content_logs[section_name][-1]['cumulative_tokens'][1]

    ppt_to_images(f'tmp/content_poster_<{sections[-1]}>.pptx', 'tmp/content_preview')

    ckpt = {
        'logs': logs,
        'content_logs': content_logs,
        'outline': outline,
        'total_input_token': total_input_token,
        'total_output_token': total_output_token
    }

    pkl.dump(ckpt, open(f'checkpoints/{args.model_name}_{args.poster_name}_content_ckpt_{args.index}.pkl', 'wb'))

    return total_input_token, total_output_token

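Note: a minimal sketch for reviewing the per-section results that fill_poster_content pickles above; the file name is a placeholder and the keys mirror the ckpt dictionary built in this function:

import pickle as pkl

with open('checkpoints/4o_my_paper_content_ckpt_0.pkl', 'rb') as f:  # placeholder path
    ckpt = pkl.load(f)

for section_name, section_logs in ckpt['content_logs'].items():
    last_attempt = section_logs[-1]
    print(section_name, 'error:', last_attempt['error'],
          'tokens:', last_attempt['cumulative_tokens'])
print('totals:', ckpt['total_input_token'], '->', ckpt['total_output_token'])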
def stylize_poster(args, actor_config):
    total_input_token, total_output_token = 0, 0
    poster_content = json.load(open(f'contents/{args.model_name}_{args.poster_name}_poster_content_{args.index}.json', 'r'))
    agent_name = 'style_agent'

    with open(f"prompt_templates/{agent_name}.yaml", "r") as f:
        style_config = yaml.safe_load(f)

    actor_model = ModelFactory.create(
        model_platform=actor_config['model_platform'],
        model_type=actor_config['model_type'],
        model_config_dict=actor_config['model_config'],
    )

    actor_sys_msg = style_config['system_prompt']

    actor_agent = ChatAgent(
        system_message=actor_sys_msg,
        model=actor_model,
        message_window_size=10,
    )

    ckpt = pkl.load(open(f'checkpoints/{args.model_name}_{args.poster_name}_content_ckpt_{args.index}.pkl', 'rb'))
    content_logs = ckpt['content_logs']
    outline = ckpt['outline']

    sections = list(outline.keys())
    sections = [s for s in sections if s != 'meta']

    jinja_env = Environment(undefined=StrictUndefined)

    template = jinja_env.from_string(style_config["template"])
    style_logs = {}

    for section_index in range(len(sections)):
        section_name = sections[section_index]
        section_outline = json.dumps(outline[section_name])
        section_code = content_logs[section_name][-1]['code']

        print(f'Stylizing for {section_name}')

        img_ratio_json = get_img_ratio_in_section(poster_content[section_name])

        jinja_args = {
            'content_json': poster_content[section_name],
            'function_docs': documentation,
            'existing_code': section_code,
            'image_ratio': img_ratio_json,
        }

        prompt = template.render(**jinja_args)
        if section_index == 0:
            existing_code = ''
        else:
            existing_code = style_logs[sections[section_index - 1]][-1]['concatenated_code']
        style_logs[section_name] = stylize(
            actor_agent,
            prompt,
            args.max_retry,
            existing_code
        )

        shutil.copy('poster.pptx', f'tmp/style_poster_<{section_name}>.pptx')

        if style_logs[section_name][-1]['error'] is not None:
            raise Exception(f'Error in stylizing for {section_name}')

        total_input_token += style_logs[section_name][-1]['cumulative_tokens'][0]
        total_output_token += style_logs[section_name][-1]['cumulative_tokens'][1]

    ppt_to_images(f'tmp/style_poster_<{sections[-1]}>.pptx', 'tmp/style_preview')
    ckpt = {
        'logs': ckpt['logs'],
        'content_logs': content_logs,
        'style_logs': style_logs,
        'outline': outline,
        'total_input_token': total_input_token,
        'total_output_token': total_output_token
    }

    with open(f'checkpoints/{args.model_name}_{args.poster_name}_style_ckpt_{args.index}.pkl', 'wb') as f:
        pkl.dump(ckpt, f)

    return total_input_token, total_output_token

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--poster_name', type=str, default=None)
    parser.add_argument('--model_name', type=str, default='4o')
    parser.add_argument('--poster_path', type=str, required=True)
    parser.add_argument('--index', type=int, default=0)
    parser.add_argument('--max_retry', type=int, default=3)
    args = parser.parse_args()

    actor_config = get_agent_config(args.model_name)

    if args.poster_name is None:
        args.poster_name = args.poster_path.split('/')[-1].replace('.pdf', '').replace(' ', '_')

    fill_total_input_token, fill_total_output_token = fill_poster_content(args, actor_config)
    style_total_input_token, style_total_output_token = stylize_poster(args, actor_config)

    total_input_token = fill_total_input_token + style_total_input_token
    total_output_token = fill_total_output_token + style_total_output_token

    print(f'Token consumption: {total_input_token} -> {total_output_token}')
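Note: the two stages can also be driven programmatically instead of via the CLI above; a minimal sketch using argparse.Namespace with placeholder values (it assumes the same working directory, checkpoints and prompt templates as the CLI path):

from argparse import Namespace

args = Namespace(poster_name=None, model_name='4o',
                 poster_path='paper.pdf', index=0, max_retry=3)  # placeholder values
if args.poster_name is None:
    args.poster_name = args.poster_path.split('/')[-1].replace('.pdf', '').replace(' ', '_')

actor_config = get_agent_config(args.model_name)
fill_in, fill_out = fill_poster_content(args, actor_config)
style_in, style_out = stylize_poster(args, actor_config)
print(f'Token consumption: {fill_in + style_in} -> {fill_out + style_out}')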
Paper2Poster/PosterAgent/gen_beamer_code.py ADDED
@@ -0,0 +1,299 @@
import re
import json
import os
from typing import List, Dict, Any

def sanitize_for_latex(name):
    """Convert any character that is not alphanumeric into underscore for LaTeX compatibility."""
    return re.sub(r'[^0-9a-zA-Z_]+', '_', name)

def initialize_beamer_document(width_cm=120, height_cm=90, theme="default"):
    """
    Initialize a Beamer document with specified dimensions and theme.

    Args:
        width_cm: Width in centimeters (default 120cm for poster)
        height_cm: Height in centimeters (default 90cm for poster)
        theme: Beamer theme name (default, Madrid, Warsaw, etc.)
    """
    code = f'''\\documentclass[aspectratio=169]{{beamer}}
\\usepackage[utf8]{{inputenc}}
\\usepackage[T1]{{fontenc}}
\\usepackage{{graphicx}}
\\usepackage{{tikz}}
\\usepackage{{xcolor}}
\\usepackage{{geometry}}
\\usepackage{{multicol}}
\\usepackage{{array}}
\\usepackage{{booktabs}}
\\usepackage{{adjustbox}}

% Set page dimensions for poster
\\geometry{{paperwidth={width_cm}cm, paperheight={height_cm}cm, margin=1cm}}

% Beamer theme
\\usetheme{{{theme}}}
\\usecolortheme{{default}}

% Custom colors
\\definecolor{{titlecolor}}{{RGB}}{{47, 85, 151}}
\\definecolor{{textcolor}}{{RGB}}{{0, 0, 0}}
\\definecolor{{bgcolor}}{{RGB}}{{255, 255, 255}}

% Remove navigation symbols
\\setbeamertemplate{{navigation symbols}}{{}}

% Custom title page
\\setbeamertemplate{{title page}}{{
\\begin{{center}}
\\vspace{{1cm}}
{{\\color{{titlecolor}}\\Huge\\textbf{{\\inserttitle}}}}
\\vspace{{0.5cm}}
\\Large{{\\insertauthor}}
\\vspace{{0.3cm}}
\\normalsize{{\\insertinstitute}}
\\end{{center}}
}}

% Custom frame title
\\setbeamertemplate{{frametitle}}{{
\\vspace{{0.5cm}}
\\begin{{flushleft}}
{{\\color{{titlecolor}}\\Large\\textbf{{\\insertframetitle}}}}
\\end{{flushleft}}
\\vspace{{0.3cm}}
}}

\\begin{{document}}

% Title frame
\\title{{POSTER_TITLE_PLACEHOLDER}}
\\author{{POSTER_AUTHOR_PLACEHOLDER}}
\\institute{{POSTER_INSTITUTE_PLACEHOLDER}}
\\date{{\\today}}

\\begin{{frame}}[plain]
\\titlepage
\\end{{frame}}

'''
    return code

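Note: a quick sketch of what initialize_beamer_document produces; the dimensions and theme are arbitrary example values:

preamble = initialize_beamer_document(width_cm=84, height_cm=118, theme='Madrid')
assert '\\geometry{paperwidth=84cm, paperheight=118cm, margin=1cm}' in preamble
assert 'POSTER_TITLE_PLACEHOLDER' in preamble  # replaced later by generate_beamer_poster_code
print(preamble.splitlines()[0])  # \documentclass[aspectratio=169]{beamer}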
def generate_beamer_section_code(section_data: Dict[str, Any], section_index: int):
    """
    Compatible with the Paper2Poster bullet JSON:
    - section_data contains title_blocks / textbox1_blocks / textbox2_blocks
    - each *_blocks is a list[ {bullet: bool, runs: [{text: str, ...}], ...} ]
    """
    def blocks_to_lines(blocks):
        """Convert blocks into a list of {text, bullet} entries."""
        lines = []
        for blk in blocks or []:
            text = " ".join([r.get("text","") for r in blk.get("runs", [])]).strip()
            if not text:
                continue
            lines.append({
                "text": text,
                "bullet": bool(blk.get("bullet", False))
            })
        return lines

    # Frame title: prefer the text of title_blocks, then title_str, otherwise "Untitled"
    if isinstance(section_data.get("title_blocks"), list) and section_data["title_blocks"]:
        frame_title = " ".join([r.get("text","") for r in section_data["title_blocks"][0].get("runs", [])]).strip()
    else:
        frame_title = section_data.get("title_str") or "Untitled"

    frame_title = frame_title.replace("{","\\{").replace("}","\\}")  # simple escaping in case the title contains braces

    code = f"\n% ===== Section {section_index} =====\n"
    code += f"\\begin{{frame}}[t]{{{frame_title}}}\n"
    code += " \\vspace{-0.5cm}\n"

    for key in ["textbox1_blocks", "textbox2_blocks"]:
        lines = blocks_to_lines(section_data.get(key, []))
        if not lines:
            continue

        # If every line is a bullet, merge them into one itemize; otherwise handle lines individually
        if all(l["bullet"] for l in lines):
            code += " \\begin{itemize}\n"
            for l in lines:
                code += f" \\item {l['text']}\n"
            code += " \\end{itemize}\n"
        else:
            for l in lines:
                if l["bullet"]:
                    code += f" \\begin{{itemize}}\\item {l['text']}\\end{{itemize}}\n"
                else:
                    code += f" {l['text']}\\\\\n"

    code += "\\end{frame}\n\n"
    return code


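Note: a small worked example of the bullet-JSON shape that generate_beamer_section_code expects; all field values are made up:

example_section = {
    'title_blocks': [{'runs': [{'text': 'Method'}]}],
    'textbox1_blocks': [
        {'bullet': True, 'runs': [{'text': 'Two-stage layout generation'}]},
        {'bullet': True, 'runs': [{'text': 'Actor-critic refinement'}]},
    ],
    'textbox2_blocks': [
        {'bullet': False, 'runs': [{'text': 'See the pipeline figure for details.'}]},
    ],
}
frame_code = generate_beamer_section_code(example_section, section_index=0)
# textbox1 becomes a single itemize (all bullets); textbox2 falls back to a plain line.
print(frame_code)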
def generate_beamer_figure_code(figure_data: Dict[str, Any], figure_index: int):
    """
    Generate Beamer code for including figures.

    Args:
        figure_data: Dictionary containing figure information
        figure_index: Index of the figure
    """
    figure_name = sanitize_for_latex(figure_data.get('figure_name', f'figure_{figure_index}'))
    figure_path = figure_data.get('figure_path', '')

    # Convert inches to centimeters (1 inch = 2.54 cm)
    width_cm = figure_data.get('width', 10) * 2.54
    height_cm = figure_data.get('height', 8) * 2.54

    code = f'''
% Figure: {figure_name}
\\begin{{frame}}[t]{{{figure_data.get('title', 'Figure')}}}
\\vspace{{-0.5cm}}
\\begin{{center}}
\\includegraphics[width={width_cm:.2f}cm, height={height_cm:.2f}cm]{{{figure_path}}}
\\end{{center}}
\\vspace{{0.3cm}}
\\begin{{center}}
\\small{{\\textbf{{{figure_data.get('caption', 'Figure Caption')}}}}}
\\end{{center}}
\\end{{frame}}

'''
    return code

def generate_beamer_poster_code(
    sections: List[Dict[str, Any]],
    figures: List[Dict[str, Any]],
    poster_info: Dict[str, str],
    width_cm: float = 120,
    height_cm: float = 90,
    theme: str = "default",
    output_path: str = "poster.tex"
):
    """
    Generate complete Beamer poster code.

    Args:
        sections: List of section dictionaries
        figures: List of figure dictionaries
        poster_info: Dictionary with title, author, institute
        width_cm: Poster width in centimeters
        height_cm: Poster height in centimeters
        theme: Beamer theme name
        output_path: Output .tex file path
    """
    code = initialize_beamer_document(width_cm, height_cm, theme)

    # Replace placeholders with actual content
    code = code.replace('POSTER_TITLE_PLACEHOLDER', poster_info.get('title', 'Poster Title'))
    code = code.replace('POSTER_AUTHOR_PLACEHOLDER', poster_info.get('author', 'Author Name'))
    code = code.replace('POSTER_INSTITUTE_PLACEHOLDER', poster_info.get('institute', 'Institute Name'))

    # Add sections
    for i, section in enumerate(sections):
        code += generate_beamer_section_code(section, i)

    # Add figures
    for i, figure in enumerate(figures):
        code += generate_beamer_figure_code(figure, i)

    # Close document
    code += '''
\\end{document}
'''

    return code

def save_beamer_code(code: str, output_path: str):
    """Save Beamer code to file."""
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(code)

def compile_beamer_to_pdf(tex_path: str, output_dir: str = "."):
    """
    Compile Beamer .tex file to PDF using pdflatex.

    Args:
        tex_path: Path to .tex file
        output_dir: Output directory for PDF
    """
    import subprocess

    try:
        # Run pdflatex twice for proper cross-references
        result1 = subprocess.run(
            ['pdflatex', '-output-directory', output_dir, tex_path],
            capture_output=True,
            text=True,
            timeout=60
        )

        result2 = subprocess.run(
            ['pdflatex', '-output-directory', output_dir, tex_path],
            capture_output=True,
            text=True,
            timeout=60
        )

        if result1.returncode == 0 and result2.returncode == 0:
            print(f"Successfully compiled {tex_path} to PDF")
            return True
        else:
            print(f"Error compiling {tex_path}:")
            print(result1.stderr)
            print(result2.stderr)
            return False

    except subprocess.TimeoutExpired:
        print(f"Timeout while compiling {tex_path}")
        return False
    except Exception as e:
        print(f"Error compiling {tex_path}: {e}")
        return False

# Example usage functions
def convert_pptx_layout_to_beamer(pptx_layout_data: Dict[str, Any]) -> Dict[str, Any]:
    """
    Convert PowerPoint layout data to Beamer-compatible format.

    Args:
        pptx_layout_data: Layout data from PowerPoint generation
    """
    beamer_data = {
        'sections': [],
        'figures': [],
        'poster_info': {
            'title': 'Default Title',
            'author': 'Default Author',
            'institute': 'Default Institute'
        }
    }

    # Convert text arrangements to sections
    if 'text_arrangement' in pptx_layout_data:
        for i, text_item in enumerate(pptx_layout_data['text_arrangement']):
            section = {
                'section_name': text_item.get('textbox_name', f'section_{i}'),
                'title': text_item.get('title', f'Section {i+1}'),
                'content': text_item.get('content', 'Content placeholder')
            }
            beamer_data['sections'].append(section)

    # Convert figure arrangements to figures
    if 'figure_arrangement' in pptx_layout_data:
        for i, figure_item in enumerate(pptx_layout_data['figure_arrangement']):
            figure = {
                'figure_name': figure_item.get('figure_name', f'figure_{i}'),
                'figure_path': figure_item.get('figure_path', ''),
                'width': figure_item.get('width', 10),
                'height': figure_item.get('height', 8),
                'title': figure_item.get('title', f'Figure {i+1}'),
                'caption': figure_item.get('caption', 'Figure caption')
            }
            beamer_data['figures'].append(figure)

    return beamer_data
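Note: end to end, the helpers in this file compose as sketched below; the content and paths are placeholders, and compiling requires a local pdflatex installation. convert_pptx_layout_to_beamer can be used first to map PowerPoint-style layout dictionaries into the sections/figures/poster_info structure consumed here.

poster_info = {'title': 'Paper2Poster Demo', 'author': 'A. Author', 'institute': 'Example Lab'}
sections = [{'title_str': 'Overview',
             'textbox1_blocks': [{'bullet': True, 'runs': [{'text': 'One-sentence takeaway'}]}]}]
figures = [{'figure_name': 'pipeline', 'figure_path': 'figures/pipeline.png',
            'width': 10, 'height': 6, 'title': 'Pipeline', 'caption': 'System overview'}]

tex = generate_beamer_poster_code(sections, figures, poster_info,
                                  width_cm=120, height_cm=90, theme='default')
save_beamer_code(tex, 'poster.tex')
compile_beamer_to_pdf('poster.tex', output_dir='.')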
Paper2Poster/PosterAgent/gen_outline_layout.py ADDED
@@ -0,0 +1,851 @@
from dotenv import load_dotenv
import os
import json
import copy
import yaml
from jinja2 import Environment, StrictUndefined

from utils.src.utils import ppt_to_images, get_json_from_response

from camel.models import ModelFactory
from camel.agents import ChatAgent
from camel.messages import BaseMessage

from utils.pptx_utils import *
from utils.wei_utils import *

import pickle as pkl
import argparse

load_dotenv()

IMAGE_SCALE_RATIO_MIN = 50
IMAGE_SCALE_RATIO_MAX = 40
TABLE_SCALE_RATIO_MIN = 100
TABLE_SCALE_RATIO_MAX = 80

def compute_tp(raw_content_json):
    total_length = 0
    for section in raw_content_json['sections']:
        total_length += len(section['content'])

    for i in range(len(raw_content_json['sections'])):
        raw_content_json['sections'][i]['tp'] = len(raw_content_json['sections'][i]['content']) / total_length
        raw_content_json['sections'][i]['text_len'] = len(raw_content_json['sections'][i]['content'])

def compute_gp(table_info, image_info):
    total_area = 0
    for k, v in table_info.items():
        total_area += v['figure_size']

    for k, v in image_info.items():
        total_area += v['figure_size']

    for k, v in table_info.items():
        v['gp'] = v['figure_size'] / total_area

    for k, v in image_info.items():
        v['gp'] = v['figure_size'] / total_area

def get_outline_location(outline, subsection=False):
    outline_location = {}
    for k, v in outline.items():
        if k == 'meta':
            continue
        outline_location[k] = {
            'location': v['location'],
        }
        if subsection:
            if 'subsections' in v:
                outline_location[k]['subsections'] = get_outline_location(v['subsections'])
    return outline_location

def apply_outline_location(outline, location, subsection=False):
    new_outline = {}
    for k, v in outline.items():
        if k == 'meta':
            new_outline[k] = v
            continue
        new_outline[k] = copy.deepcopy(v)
        new_outline[k]['location'] = location[k]['location']
        if subsection:
            if 'subsections' in v:
                new_outline[k]['subsections'] = apply_outline_location(v['subsections'], location[k]['subsections'])

    return new_outline

def fill_location(outline, section_name, location_dict):
    new_outline = copy.deepcopy(outline)
    if 'subsections' not in new_outline[section_name]:
        return new_outline
    for k, v in new_outline[section_name]['subsections'].items():
        v['location'] = location_dict[k]['location']
    return new_outline

def recover_name_and_location(outline_no_name, outline):
    new_outline = copy.deepcopy(outline_no_name)
    for k, v in outline_no_name.items():
        if k == 'meta':
            continue
        new_outline[k]['name'] = outline[k]['name']
        if type(new_outline[k]['location']) == list:
            new_outline[k]['location'] = {
                'left': v['location'][0],
                'top': v['location'][1],
                'width': v['location'][2],
                'height': v['location'][3]
            }
        if 'subsections' in v:
            for k_sub, v_sub in v['subsections'].items():
                new_outline[k]['subsections'][k_sub]['name'] = outline[k]['subsections'][k_sub]['name']
                if type(new_outline[k]['subsections'][k_sub]['location']) == list:
                    new_outline[k]['subsections'][k_sub]['location'] = {
                        'left': v_sub['location'][0],
                        'top': v_sub['location'][1],
                        'width': v_sub['location'][2],
                        'height': v_sub['location'][3]
                    }
    return new_outline


def validate_and_adjust_subsections(section_bbox, subsection_bboxes):
    """
    Validate that the given subsections collectively occupy the entire section.
    If not, return an adjusted version that fixes the layout.

    We assume all subsections are intended to be stacked vertically with no gaps,
    spanning the full width of the section.

    :param section_bbox: dict with keys ["left", "top", "width", "height"]
    :param subsection_bboxes: dict of subsection_name -> bounding_box (each also
                              with keys ["left", "top", "width", "height"])
    :return: (is_valid, revised_subsections)
             where is_valid is True/False,
             and revised_subsections is either the same as subsection_bboxes if valid,
             or a new dict of adjusted bounding boxes if invalid.
    """

    # Helper functions
    def _right(bbox):
        return bbox["left"] + bbox["width"]

    def _bottom(bbox):
        return bbox["top"] + bbox["height"]

    section_left = section_bbox["left"]
    section_top = section_bbox["top"]
    section_right = section_left + section_bbox["width"]
    section_bottom = section_top + section_bbox["height"]

    # Convert dictionary to a list of (subsection_name, bbox) pairs
    items = list(subsection_bboxes.items())
    if not items:
        # No subsections is definitely not valid if we want to fill the section
        return False, None

    # Sort subsections by their 'top' coordinate
    items_sorted = sorted(items, key=lambda x: x[1]["top"])

    # ---------------------------
    # Step 1: Validate
    # ---------------------------
    # We'll check:
    #   1. left/right boundaries match the section for each subsection
    #   2. The first subsection's top == section_top
    #   3. The last subsection's bottom == section_bottom
    #   4. Each pair of consecutive subsections lines up exactly
    #      (previous bottom == current top) with no gap or overlap.

    is_valid = True

    # Check left/right for each
    for name, bbox in items_sorted:
        if bbox["left"] != section_left or _right(bbox) != section_right:
            is_valid = False
            break

    # Check alignment for the first and last
    if is_valid:
        first_sub_name, first_sub_bbox = items_sorted[0]
        if first_sub_bbox["top"] != section_top:
            is_valid = False

    if is_valid:
        last_sub_name, last_sub_bbox = items_sorted[-1]
        if _bottom(last_sub_bbox) != section_bottom:
            is_valid = False

    # Check consecutive alignment
    if is_valid:
        for i in range(len(items_sorted) - 1):
            _, current_bbox = items_sorted[i]
            _, next_bbox = items_sorted[i + 1]
            if _bottom(current_bbox) != next_bbox["top"]:
                is_valid = False
                break

    # If everything passed, we return
    if is_valid:
        return True, subsection_bboxes

    # ---------------------------
    # Step 2: Revise
    # ---------------------------
    # We will adjust all subsection bboxes so that they occupy
    # the entire section exactly, preserving each original bbox's
    # height *ratio* if possible.

    # 2a. Compute total original height (in the order of sorted items)
    original_heights = [bbox["height"] for _, bbox in items_sorted]
    total_original_height = sum(original_heights)

    # Avoid divide-by-zero if somehow there's a 0 height
    if total_original_height <= 0:
        # Fallback: split the section equally among subsections
        # to avoid zero or negative heights
        chunk_height = section_bbox["height"] / len(items_sorted)
        scale_heights = [chunk_height] * len(items_sorted)
    else:
        # Scale each original height by the ratio of
        # (section total height / sum of original heights)
        scale = section_bbox["height"] / total_original_height
        scale_heights = [h * scale for h in original_heights]

    # 2b. Assign bounding boxes top->bottom, ensuring no gap
    revised = {}
    current_top = section_top
    for i, (name, original_bbox) in enumerate(items_sorted):
        revised_height = scale_heights[i]
        # If there's floating error, we can clamp in the last iteration
        # so that the bottom exactly matches section_bottom.
        # But for simplicity, we'll keep it straightforward unless needed.

        revised[name] = {
            "left": section_left,
            "top": current_top,
            "width": section_bbox["width"],
            "height": revised_height
        }
        # Update current_top for next subsection
        current_top += revised_height

    # Due to potential float rounding, we can enforce the last subsection
    # to exactly end at section_bottom:
    last_name = items_sorted[-1][0]
    # Recompute the actual bottom after the above assignment
    new_bottom = revised[last_name]["top"] + revised[last_name]["height"]
    diff = new_bottom - section_bottom
    if abs(diff) > 1e-9:
        # Adjust the last subsection's height
        revised[last_name]["height"] -= diff

    # Return the revised dictionary
    return False, revised

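Note: a worked example of the repair path in validate_and_adjust_subsections; two stacked subsections that leave a gap inside their section are rescaled so that they tile it exactly (all numbers are made up):

section = {'left': 0, 'top': 0, 'width': 20, 'height': 30}
subsections = {
    'background': {'left': 0, 'top': 0,  'width': 20, 'height': 10},
    'motivation': {'left': 0, 'top': 10, 'width': 20, 'height': 10},  # ends at 20, not 30
}
is_valid, revised = validate_and_adjust_subsections(section, subsections)
# is_valid == False; both heights are rescaled by 30 / 20 = 1.5:
#   background -> top 0,  height 15
#   motivation -> top 15, height 15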
def filter_image_table(args, filter_config):
    images = json.load(open(f'<{args.model_name_t}_{args.model_name_v}>_images_and_tables/{args.poster_name}_images.json', 'r'))
    tables = json.load(open(f'<{args.model_name_t}_{args.model_name_v}>_images_and_tables/{args.poster_name}_tables.json', 'r'))
    doc_json = json.load(open(f'contents/<{args.model_name_t}_{args.model_name_v}>_{args.poster_name}_raw_content.json', 'r'))
    agent_filter = 'image_table_filter_agent'
    with open(f"utils/prompt_templates/{agent_filter}.yaml", "r", encoding="utf-8") as f:
        config_filter = yaml.safe_load(f)

    image_information = {}
    for k, v in images.items():
        image_information[k] = copy.deepcopy(v)
        image_information[k]['min_width'] = v['width'] // IMAGE_SCALE_RATIO_MIN
        image_information[k]['min_height'] = v['height'] // IMAGE_SCALE_RATIO_MIN
        image_information[k]['max_width'] = v['width'] // IMAGE_SCALE_RATIO_MAX
        image_information[k]['max_height'] = v['height'] // IMAGE_SCALE_RATIO_MAX

    table_information = {}
    for k, v in tables.items():
        table_information[k] = copy.deepcopy(v)
        table_information[k]['min_width'] = v['width'] // TABLE_SCALE_RATIO_MIN
        table_information[k]['min_height'] = v['height'] // TABLE_SCALE_RATIO_MIN
        table_information[k]['max_width'] = v['width'] // TABLE_SCALE_RATIO_MAX
        table_information[k]['max_height'] = v['height'] // TABLE_SCALE_RATIO_MAX

    filter_actor_sys_msg = config_filter['system_prompt']

    if args.model_name_t.startswith('vllm_qwen'):
        filter_model = ModelFactory.create(
            model_platform=filter_config['model_platform'],
            model_type=filter_config['model_type'],
            model_config_dict=filter_config['model_config'],
            url=filter_config['url'],
        )
    else:
        filter_model = ModelFactory.create(
            model_platform=filter_config['model_platform'],
            model_type=filter_config['model_type'],
            model_config_dict=filter_config['model_config'],
        )

    filter_actor_agent = ChatAgent(
        system_message=filter_actor_sys_msg,
        model=filter_model,
        message_window_size=10,
    )

    filter_jinja_args = {
        'json_content': doc_json,
        'table_information': json.dumps(table_information, indent=4),
        'image_information': json.dumps(image_information, indent=4),
    }
    jinja_env = Environment(undefined=StrictUndefined)
    filter_prompt = jinja_env.from_string(config_filter["template"])
    filter_actor_agent.reset()
    response = filter_actor_agent.step(filter_prompt.render(**filter_jinja_args))
    input_token, output_token = account_token(response)
    response_json = get_json_from_response(response.msgs[0].content)
    table_information = response_json['table_information']
    image_information = response_json['image_information']
    json.dump(images, open(f'<{args.model_name_t}_{args.model_name_v}>_images_and_tables/{args.poster_name}_images_filtered.json', 'w'), indent=4)
    json.dump(tables, open(f'<{args.model_name_t}_{args.model_name_v}>_images_and_tables/{args.poster_name}_tables_filtered.json', 'w'), indent=4)

    return input_token, output_token

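Note: the min/max sizes handed to the filter agent above come from integer division of the raw pixel dimensions by the scale-ratio constants defined at the top of this file; for a hypothetical 1000x800 image:

width, height = 1000, 800
print(width // IMAGE_SCALE_RATIO_MIN, height // IMAGE_SCALE_RATIO_MIN)  # 20 16 (min width/height)
print(width // IMAGE_SCALE_RATIO_MAX, height // IMAGE_SCALE_RATIO_MAX)  # 25 20 (max width/height)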
def gen_outline_layout_v2(args, actor_config):
    total_input_token, total_output_token = 0, 0
    agent_name = 'poster_planner_new_v2'
    doc_json = json.load(open(f'contents/<{args.model_name_t}_{args.model_name_v}>_{args.poster_name}_raw_content.json', 'r'))
    filtered_table_information = json.load(open(f'<{args.model_name_t}_{args.model_name_v}>_images_and_tables/{args.poster_name}_tables_filtered.json', 'r'))
    filtered_image_information = json.load(open(f'<{args.model_name_t}_{args.model_name_v}>_images_and_tables/{args.poster_name}_images_filtered.json', 'r'))

    filtered_table_information_captions = {}
    filtered_image_information_captions = {}

    for k, v in filtered_table_information.items():
        filtered_table_information_captions[k] = {
            v['caption']
        }

    for k, v in filtered_image_information.items():
        filtered_image_information_captions[k] = {
            v['caption']
        }

    with open(f"utils/prompt_templates/{agent_name}.yaml", "r", encoding="utf-8") as f:
        planner_config = yaml.safe_load(f)

    compute_tp(doc_json)

    jinja_env = Environment(undefined=StrictUndefined)
    outline_template = jinja_env.from_string(planner_config["template"])
    planner_jinja_args = {
        'json_content': doc_json,
        'table_information': filtered_table_information_captions,
        'image_information': filtered_image_information_captions,
    }

    if args.model_name_t.startswith('vllm_qwen'):
        planner_model = ModelFactory.create(
            model_platform=actor_config['model_platform'],
            model_type=actor_config['model_type'],
            model_config_dict=actor_config['model_config'],
            url=actor_config['url'],
        )
    else:
        planner_model = ModelFactory.create(
            model_platform=actor_config['model_platform'],
            model_type=actor_config['model_type'],
            model_config_dict=actor_config['model_config'],
        )


    planner_agent = ChatAgent(
        system_message=planner_config['system_prompt'],
        model=planner_model,
        message_window_size=10,
    )

    print(f'Generating outline...')
    planner_prompt = outline_template.render(**planner_jinja_args)
    planner_agent.reset()
    response = planner_agent.step(planner_prompt)
    input_token, output_token = account_token(response)
    total_input_token += input_token
    total_output_token += output_token

    figure_arrangement = get_json_from_response(response.msgs[0].content)

    print(f'Figure arrangement: {json.dumps(figure_arrangement, indent=4)}')

    arranged_images = {}
    arranged_tables = {}
    assigned_images = set()
    assigned_tables = set()

    for section_name, figure in figure_arrangement.items():
        if 'image' in figure:
            image_id = str(figure['image'])
            if image_id in assigned_images:
                continue
            if image_id in filtered_image_information:
                arranged_images[image_id] = filtered_image_information[image_id]
                assigned_images.add(image_id)
        if 'table' in figure:
            table_id = str(figure['table'])
            if table_id in assigned_tables:
                continue
            if table_id in filtered_table_information:
                arranged_tables[table_id] = filtered_table_information[table_id]
                assigned_tables.add(table_id)

    compute_gp(arranged_tables, arranged_images)

    # Obtain panel input
    paper_panels = []
    for i in range(len(doc_json['sections'])):
        section = doc_json['sections'][i]
        panel = {}
        panel['panel_id'] = i
        panel['section_name'] = section['title']
        panel['tp'] = section['tp']
        panel['text_len'] = section['text_len']
        panel['gp'] = 0
        panel['figure_size'] = 0
        panel['figure_aspect'] = 1
        if section['title'] in figure_arrangement:
            curr_arrangement = figure_arrangement[section['title']]
            if 'table' in curr_arrangement:
                table_id = str(curr_arrangement['table'])
                if table_id in arranged_tables:
                    panel['gp'] = arranged_tables[table_id]['gp']
                    panel['figure_size'] = arranged_tables[table_id]['figure_size']
                    panel['figure_aspect'] = arranged_tables[table_id]['figure_aspect']
            elif 'image' in curr_arrangement:
                image_id = str(curr_arrangement['image'])
                if image_id in arranged_images:
                    panel['gp'] = arranged_images[image_id]['gp']
                    panel['figure_size'] = arranged_images[image_id]['figure_size']
                    panel['figure_aspect'] = arranged_images[image_id]['figure_aspect']

        paper_panels.append(panel)

    return total_input_token, total_output_token, paper_panels, figure_arrangement

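Note: gen_outline_layout_v2 returns one panel dictionary per section for the downstream layout step; the entry below illustrates the shape, with made-up values ('tp' and 'text_len' come from compute_tp, while 'gp', 'figure_size' and 'figure_aspect' come from the assigned figure or table and default to 0/0/1 when none is assigned):

example_panel = {
    'panel_id': 0,
    'section_name': 'Introduction',
    'tp': 0.18,           # this section's share of the paper's text
    'text_len': 1450,
    'gp': 0.25,           # this section's share of the total figure area
    'figure_size': 6000,
    'figure_aspect': 1.33,
}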
def gen_outline_layout(args, actor_config, critic_config):
    poster_log_path = f'log/{args.model_name}_{args.poster_name}_poster_{args.index}'
    if not os.path.exists(poster_log_path):
        os.mkdir(poster_log_path)
    total_input_token, total_output_token = 0, 0
    consumption_log = {
        'outline': [],
        'h1_actor': [],
        'h2_actor': [],
        'h1_critic': [],
        'gen_layout': []
    }
    jinja_env = Environment(undefined=StrictUndefined)
    outline_file_path = f'outlines/{args.model_name}_{args.poster_name}_outline_{args.index}.json'
    agent_name = 'poster_planner_new'
    agent_init_name = 'layout_agent_init'
    agent_new_section_name = 'layout_agent_new_section'
    h1_critic_name = 'critic_layout_hierarchy_1'
    h2_actor_name = 'actor_layout_hierarchy_2'

    doc_json = json.load(open(f'contents/{args.model_name}_{args.poster_name}_raw_content.json', 'r'))
    filtered_table_information = json.load(open(f'images_and_tables/{args.poster_name}_tables_filtered.json', 'r'))
    filtered_image_information = json.load(open(f'images_and_tables/{args.poster_name}_images_filtered.json', 'r'))

    with open(f"utils/prompt_templates/{agent_name}.yaml", "r", encoding="utf-8") as f:
        planner_config = yaml.safe_load(f)

    with open(f"utils/prompt_templates/{agent_init_name}.yaml", "r", encoding="utf-8") as f:
        config_init = yaml.safe_load(f)

    with open(f"utils/prompt_templates/{agent_new_section_name}.yaml", "r", encoding="utf-8") as f:
        config_new_section = yaml.safe_load(f)

    with open(f"utils/prompt_templates/{h1_critic_name}.yaml", "r", encoding="utf-8") as f:
        config_h1_critic = yaml.safe_load(f)

    with open(f"utils/prompt_templates/{h2_actor_name}.yaml", "r", encoding="utf-8") as f:
        config_h2_actor = yaml.safe_load(f)

    planner_model = ModelFactory.create(
        model_platform=actor_config['model_platform'],
        model_type=actor_config['model_type'],
        model_config_dict=actor_config['model_config'],
    )

    planner_agent = ChatAgent(
        system_message=planner_config['system_prompt'],
        model=planner_model,
        message_window_size=10,
    )

    outline_template = jinja_env.from_string(planner_config["template"])

    planner_jinja_args = {
        'json_content': doc_json,
        'table_information': filtered_table_information,
        'image_information': filtered_image_information,
    }

    actor_model = ModelFactory.create(
        model_platform=actor_config['model_platform'],
        model_type=actor_config['model_type'],
        model_config_dict=actor_config['model_config'],
    )

    init_actor_sys_msg = config_init['system_prompt']

    init_actor_agent = ChatAgent(
        system_message=init_actor_sys_msg,
        model=actor_model,
        message_window_size=10,
    )

    new_section_actor_sys_msg = config_new_section['system_prompt']
    new_section_actor_agent = ChatAgent(
        system_message=new_section_actor_sys_msg,
        model=actor_model,
        message_window_size=10,
    )

    h1_critic_model = ModelFactory.create(
        model_platform=critic_config['model_platform'],
        model_type=critic_config['model_type'],
        model_config_dict=critic_config['model_config'],
    )

    h1_critic_sys_msg = config_h1_critic['system_prompt']

    h1_critic_agent = ChatAgent(
        system_message=h1_critic_sys_msg,
        model=h1_critic_model,
        message_window_size=None,
    )

    h1_pos_example = Image.open('assets/h1_example/h1_pos.jpg')
    h1_neg_example = Image.open('assets/h1_example/h1_neg.jpg')

    h2_actor_model = ModelFactory.create(
        model_platform=actor_config['model_platform'],
        model_type=actor_config['model_type'],
        model_config_dict=actor_config['model_config'],
    )

    h2_actor_sys_msg = config_h2_actor['system_prompt']

    h2_actor_agent = ChatAgent(
        system_message=h2_actor_sys_msg,
        model=h2_actor_model,
        message_window_size=10,
    )

    attempt = 0
    while True:
        print(f'Generating outline attempt {attempt}...')
        planner_prompt = outline_template.render(**planner_jinja_args)
        planner_agent.reset()
        response = planner_agent.step(planner_prompt)
        input_token, output_token = account_token(response)
        consumption_log['outline'].append((input_token, output_token))
        total_input_token += input_token
        total_output_token += output_token

        outline = get_json_from_response(response.msgs[0].content)
        name_to_hierarchy = get_hierarchy(outline)

        sections = list(outline.keys())
        sections = [x for x in sections if x != 'meta']
        init_template = jinja_env.from_string(config_init["template"])
        new_section_template = jinja_env.from_string(config_new_section["template"])
        h1_critic_template = jinja_env.from_string(config_h1_critic["template"])
        init_outline = {'meta': outline['meta'], sections[0]: outline[sections[0]]}

        new_outline = outline

        init_jinja_args = {
            'json_outline': init_outline,
            'function_docs': documentation
        }

        init_prompt = init_template.render(**init_jinja_args)

        # hierarchy 1 only
        outline_location = get_outline_location(outline, subsection=False)
        logs = {}
        curr_section = sections[0]

        layout_cumulative_input_token = 0
        layout_cumulative_output_token = 0

        print('Generating h1 layout...\n')
        print(f'Generating h1 layout for section {curr_section}...')
        logs[curr_section] = gen_layout(
            init_actor_agent,
            init_prompt,
            args.max_retry,
            name_to_hierarchy,
            visual_identifier=curr_section
        )

        if logs[curr_section][-1]['error'] is not None:
            raise ValueError(f'Failed to generate layout for section {curr_section}.')

        layout_cumulative_input_token += logs[curr_section][-1]['cumulative_tokens'][0]
        layout_cumulative_output_token += logs[curr_section][-1]['cumulative_tokens'][1]

        for section_index in range(1, len(sections)):
            curr_section = sections[section_index]
            print(f'generating h1 layout for section {curr_section}...')
            new_section_outline = {curr_section: new_outline[curr_section]}
            new_section_jinja_args = {
                'json_outline': new_section_outline,
                'function_docs': documentation
            }
            new_section_prompt = new_section_template.render(**new_section_jinja_args)

            logs[curr_section] = gen_layout(
                new_section_actor_agent,
                new_section_prompt,
                args.max_retry,
                name_to_hierarchy,
                visual_identifier=curr_section,
                existing_code = logs[sections[section_index - 1]][-1]['concatenated_code']
            )
            if logs[curr_section][-1]['error'] is not None:
                raise ValueError(f'Failed to generate layout for section {curr_section}.')

            layout_cumulative_input_token += logs[curr_section][-1]['cumulative_tokens'][0]
            layout_cumulative_output_token += logs[curr_section][-1]['cumulative_tokens'][1]

        consumption_log['h1_actor'].append((layout_cumulative_input_token, layout_cumulative_output_token))
        total_input_token += layout_cumulative_input_token
        total_output_token += layout_cumulative_output_token

        h1_path = f'tmp/poster_<{sections[-1]}>_hierarchy_1.pptx'
        h2_path = f'tmp/poster_<{sections[-1]}>_hierarchy_2.pptx'

        h1_filled_path = f'tmp/poster_<{sections[-1]}>_hierarchy_1_filled.pptx'
        h2_filled_path = f'tmp/poster_<{sections[-1]}>_hierarchy_2_filled.pptx'

        ppt_to_images(h1_path, 'tmp/layout_h1')
        ppt_to_images(h2_path, 'tmp/layout_h2')
        ppt_to_images(h1_filled_path, 'tmp/layout_h1_filled')
        ppt_to_images(h2_filled_path, 'tmp/layout_h2_filled')

        h1_img = Image.open('tmp/layout_h1/slide_0001.jpg')
        h2_img = Image.open('tmp/layout_h2/slide_0001.jpg')
        h1_filled_img = Image.open('tmp/layout_h1_filled/slide_0001.jpg')
        h2_filled_img = Image.open('tmp/layout_h2_filled/slide_0001.jpg')

        h1_critic_msg = BaseMessage.make_user_message(
            role_name='User',
            content=h1_critic_template.render(),
            image_list=[h1_neg_example, h1_pos_example, h1_filled_img]
        )

        outline_bbox_dict = {}
        for k, v in outline_location.items():
            outline_bbox_dict[k] = v['location']

        bbox_check_result = check_bounding_boxes(
            outline_bbox_dict,
            new_outline['meta']['width'],
            new_outline['meta']['height']
        )

        if len(bbox_check_result) != 0:
            print(bbox_check_result)
            attempt += 1
            continue

        h1_critic_agent.reset()
        response = h1_critic_agent.step(h1_critic_msg)
        input_token, output_token = account_token(response)
        consumption_log['h1_critic'].append((input_token, output_token))
        total_input_token += input_token
        total_output_token += output_token
        if response.msgs[0].content == 'T':
            print('Blank area detected.')
            attempt += 1
            continue

        break

    outline_bbox_dict = {}
    for k, v in outline_location.items():
        outline_bbox_dict[k] = v['location']

    # Generate subsection locations
    outline_no_sub_locations = copy.deepcopy(new_outline)
    if 'meta' in outline_no_sub_locations:
        outline_no_sub_locations.pop('meta')

    for k, v in outline_no_sub_locations.items():
        if 'subsections' in v:
            subsections = v['subsections']
            for k_sub, v_sub in subsections.items():
                del v_sub['location']
                del v_sub['name']

    h2_actor_template = jinja_env.from_string(config_h2_actor["template"])

    h2_cumulative_input_token = 0
    h2_cumulative_output_token = 0

    for section in sections:
        while True:
            print(f'generating h2 for section {section}...')
            section_outline = {section: outline_no_sub_locations[section]}
            section_jinja_args = {
                'section_outline': json.dumps(section_outline, indent=4),
            }

            section_prompt = h2_actor_template.render(**section_jinja_args)

            h2_actor_agent.reset()
            response = h2_actor_agent.step(section_prompt)
            input_token, output_token = account_token(response)
            h2_cumulative_input_token += input_token
            h2_cumulative_output_token += output_token
            subsection_location = get_json_from_response(response.msgs[0].content)

            sec_bbox = outline_no_sub_locations[section]['location']
            subsection_location_dict = {}
            for k, v in subsection_location.items():
                subsection_location_dict[k] = {
                    'left': v['location'][0],
                    'top': v['location'][1],
                    'width': v['location'][2],
                    'height': v['location'][3]
                }

            is_valid, revised = validate_and_adjust_subsections(sec_bbox, subsection_location_dict)
            if not is_valid:
                is_valid, revised = validate_and_adjust_subsections(sec_bbox, revised)
                assert is_valid, "Failed to adjust subsections to fit section"
                outline_no_sub_locations = fill_location(outline_no_sub_locations, section, revised)
            else:
                outline_no_sub_locations = fill_location(outline_no_sub_locations, section, subsection_location)
            break

    consumption_log['h2_actor'].append((h2_cumulative_input_token, h2_cumulative_output_token))
    total_input_token += h2_cumulative_input_token
    total_output_token += h2_cumulative_output_token

    outline_no_sub_locations['meta'] = outline['meta']
    outline_no_sub_locations_with_name = recover_name_and_location(outline_no_sub_locations, new_outline)
    new_outline = outline_no_sub_locations_with_name

    ### Outline finalized, actually generate layout

    logs = {}

    gen_layout_cumulative_input_token = 0
    gen_layout_cumulative_output_token = 0
    curr_section = sections[0]

    init_outline = {'meta': new_outline['meta'], sections[0]: new_outline[sections[0]]}

    init_jinja_args = {
        'json_outline': init_outline,
        'function_docs': documentation
    }

    init_prompt = init_template.render(**init_jinja_args)
    logs[curr_section] = gen_layout(
        init_actor_agent,
        init_prompt,
        args.max_retry,
        name_to_hierarchy,
        visual_identifier=curr_section
    )

    if logs[curr_section][-1]['error'] is not None:
|
| 762 |
+
raise ValueError(f'Failed to generate layout for section {curr_section}.')
|
| 763 |
+
|
| 764 |
+
gen_layout_cumulative_input_token += logs[curr_section][-1]['cumulative_tokens'][0]
|
| 765 |
+
gen_layout_cumulative_output_token += logs[curr_section][-1]['cumulative_tokens'][1]
|
| 766 |
+
|
| 767 |
+
for section_index in range(1, len(sections)):
|
| 768 |
+
curr_section = sections[section_index]
|
| 769 |
+
print(f'generating section {curr_section}...')
|
| 770 |
+
new_section_outline = {curr_section: new_outline[curr_section]}
|
| 771 |
+
new_section_jinja_args = {
|
| 772 |
+
'json_outline': new_section_outline,
|
| 773 |
+
'function_docs': documentation
|
| 774 |
+
}
|
| 775 |
+
new_section_prompt = new_section_template.render(**new_section_jinja_args)
|
| 776 |
+
|
| 777 |
+
logs[curr_section] = gen_layout(
|
| 778 |
+
new_section_actor_agent,
|
| 779 |
+
new_section_prompt,
|
| 780 |
+
args.max_retry,
|
| 781 |
+
name_to_hierarchy,
|
| 782 |
+
visual_identifier=curr_section,
|
| 783 |
+
existing_code = logs[sections[section_index - 1]][-1]['concatenated_code']
|
| 784 |
+
)
|
| 785 |
+
if logs[curr_section][-1]['error'] is not None:
|
| 786 |
+
raise ValueError(f'Failed to generate layout for section {curr_section}.')
|
| 787 |
+
|
| 788 |
+
gen_layout_cumulative_input_token += logs[curr_section][-1]['cumulative_tokens'][0]
|
| 789 |
+
gen_layout_cumulative_output_token += logs[curr_section][-1]['cumulative_tokens'][1]
|
| 790 |
+
|
| 791 |
+
consumption_log['gen_layout'].append((gen_layout_cumulative_input_token, gen_layout_cumulative_output_token))
|
| 792 |
+
total_input_token += gen_layout_cumulative_input_token
|
| 793 |
+
total_output_token += gen_layout_cumulative_output_token
|
| 794 |
+
|
| 795 |
+
h1_path = f'tmp/poster_<{sections[-1]}>_hierarchy_1.pptx'
|
| 796 |
+
h2_path = f'tmp/poster_<{sections[-1]}>_hierarchy_2.pptx'
|
| 797 |
+
|
| 798 |
+
h1_filled_path = f'tmp/poster_<{sections[-1]}>_hierarchy_1_filled.pptx'
|
| 799 |
+
h2_filled_path = f'tmp/poster_<{sections[-1]}>_hierarchy_2_filled.pptx'
|
| 800 |
+
|
| 801 |
+
ppt_to_images(h1_path, f'{poster_log_path}/layout_h1')
|
| 802 |
+
ppt_to_images(h2_path, f'{poster_log_path}/layout_h2')
|
| 803 |
+
ppt_to_images(h1_filled_path, f'{poster_log_path}/layout_h1_filled')
|
| 804 |
+
ppt_to_images(h2_filled_path, f'{poster_log_path}/layout_h2_filled')
|
| 805 |
+
|
| 806 |
+
h1_img = Image.open(f'{poster_log_path}/layout_h1/slide_0001.jpg')
|
| 807 |
+
h2_img = Image.open(f'{poster_log_path}/layout_h2/slide_0001.jpg')
|
| 808 |
+
h1_filled_img = Image.open(f'{poster_log_path}/layout_h1_filled/slide_0001.jpg')
|
| 809 |
+
h2_filled_img = Image.open(f'{poster_log_path}/layout_h2_filled/slide_0001.jpg')
|
| 810 |
+
|
| 811 |
+
ckpt = {
|
| 812 |
+
'logs': logs,
|
| 813 |
+
'outline': new_outline,
|
| 814 |
+
'name_to_hierarchy': name_to_hierarchy,
|
| 815 |
+
'consumption_log': consumption_log,
|
| 816 |
+
'total_input_token': total_input_token,
|
| 817 |
+
'total_output_token': total_output_token,
|
| 818 |
+
}
|
| 819 |
+
|
| 820 |
+
with open(f'checkpoints/{args.model_name}_{args.poster_name}_ckpt_{args.index}.pkl', 'wb') as f:
|
| 821 |
+
pkl.dump(ckpt, f)
|
| 822 |
+
|
| 823 |
+
json.dump(
|
| 824 |
+
new_outline,
|
| 825 |
+
open(outline_file_path, "w"),
|
| 826 |
+
ensure_ascii=False,
|
| 827 |
+
indent=4,
|
| 828 |
+
)
|
| 829 |
+
|
| 830 |
+
return total_input_token, total_output_token
|
| 831 |
+
|
| 832 |
+
if __name__ == '__main__':
|
| 833 |
+
parser = argparse.ArgumentParser()
|
| 834 |
+
parser.add_argument('--poster_name', type=str, default=None)
|
| 835 |
+
parser.add_argument('--model_name', type=str, default='4o')
|
| 836 |
+
parser.add_argument('--poster_path', type=str, required=True)
|
| 837 |
+
parser.add_argument('--index', type=int, default=0)
|
| 838 |
+
parser.add_argument('--max_retry', type=int, default=3)
|
| 839 |
+
args = parser.parse_args()
|
| 840 |
+
|
| 841 |
+
actor_config = get_agent_config(args.model_name)
|
| 842 |
+
critic_config = get_agent_config(args.model_name)
|
| 843 |
+
|
| 844 |
+
if args.poster_name is None:
|
| 845 |
+
args.poster_name = args.poster_path.split('/')[-1].replace('.pdf', '').replace(' ', '_')
|
| 846 |
+
|
| 847 |
+
input_token, output_token = filter_image_table(args, actor_config)
|
| 848 |
+
print(f'Token consumption: {input_token} -> {output_token}')
|
| 849 |
+
|
| 850 |
+
input_token, output_token = gen_outline_layout(args, actor_config, critic_config)
|
| 851 |
+
print(f'Token consumption: {input_token} -> {output_token}')
|
Paper2Poster/PosterAgent/gen_outline_layout_parallel.py
ADDED
|
@@ -0,0 +1,949 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dotenv import load_dotenv
|
| 2 |
+
import os
|
| 3 |
+
import json
|
| 4 |
+
import copy
|
| 5 |
+
import yaml
|
| 6 |
+
import logging
|
| 7 |
+
import time
|
| 8 |
+
from jinja2 import Environment, StrictUndefined
|
| 9 |
+
|
| 10 |
+
from utils.src.utils import ppt_to_images, get_json_from_response
|
| 11 |
+
|
| 12 |
+
from camel.models import ModelFactory
|
| 13 |
+
from camel.agents import ChatAgent
|
| 14 |
+
from camel.messages import BaseMessage
|
| 15 |
+
|
| 16 |
+
from utils.pptx_utils import *
|
| 17 |
+
from utils.wei_utils import *
|
| 18 |
+
|
| 19 |
+
import pickle as pkl
|
| 20 |
+
import argparse
|
| 21 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 22 |
+
from concurrent.futures import ProcessPoolExecutor, as_completed
|
| 23 |
+
import concurrent.futures
|
| 24 |
+
import sys
|
| 25 |
+
|
| 26 |
+
load_dotenv()
|
| 27 |
+
|
| 28 |
+
logging.basicConfig(
|
| 29 |
+
level=logging.DEBUG,
|
| 30 |
+
format='%(threadName)s: %(message)s',
|
| 31 |
+
stream=sys.stdout
|
| 32 |
+
)
|
| 33 |
+
logger = logging.getLogger(__name__)
|
| 34 |
+
|
| 35 |
+
IMAGE_SCALE_RATIO_MIN = 50
|
| 36 |
+
IMAGE_SCALE_RATIO_MAX = 40
|
| 37 |
+
TABLE_SCALE_RATIO_MIN = 100
|
| 38 |
+
TABLE_SCALE_RATIO_MAX = 80
|
| 39 |
+
|
| 40 |
+
def layout_process_section_wrapped(
|
| 41 |
+
sections,
|
| 42 |
+
new_outline,
|
| 43 |
+
init_template,
|
| 44 |
+
new_section_template,
|
| 45 |
+
init_actor_sys_msg,
|
| 46 |
+
new_section_actor_sys_msg,
|
| 47 |
+
actor_config,
|
| 48 |
+
documentation,
|
| 49 |
+
max_retry,
|
| 50 |
+
slide_width,
|
| 51 |
+
slide_height
|
| 52 |
+
):
|
| 53 |
+
logs = {}
|
| 54 |
+
parallel_results = {}
|
| 55 |
+
total_input_token, total_output_token = 0, 0
|
| 56 |
+
|
| 57 |
+
# Switch from ThreadPoolExecutor to ProcessPoolExecutor
|
| 58 |
+
with ThreadPoolExecutor() as executor:
|
| 59 |
+
futures = []
|
| 60 |
+
|
| 61 |
+
for section_index in range(len(sections)):
|
| 62 |
+
if section_index == 0:
|
| 63 |
+
sys_msg = init_actor_sys_msg
|
| 64 |
+
prompt_template = init_template
|
| 65 |
+
else:
|
| 66 |
+
sys_msg = new_section_actor_sys_msg
|
| 67 |
+
prompt_template = new_section_template
|
| 68 |
+
|
| 69 |
+
actor_model = ModelFactory.create(
|
| 70 |
+
model_platform=actor_config['model_platform'],
|
| 71 |
+
model_type=actor_config['model_type'],
|
| 72 |
+
model_config_dict=actor_config['model_config'],
|
| 73 |
+
)
|
| 74 |
+
|
| 75 |
+
future = executor.submit(
|
| 76 |
+
layout_process_section,
|
| 77 |
+
section_index,
|
| 78 |
+
sections,
|
| 79 |
+
new_outline,
|
| 80 |
+
prompt_template,
|
| 81 |
+
documentation,
|
| 82 |
+
sys_msg,
|
| 83 |
+
actor_model,
|
| 84 |
+
10,
|
| 85 |
+
max_retry,
|
| 86 |
+
slide_width,
|
| 87 |
+
slide_height
|
| 88 |
+
)
|
| 89 |
+
futures.append(future)
|
| 90 |
+
|
| 91 |
+
# Collect results as processes complete
|
| 92 |
+
for future in as_completed(futures):
|
| 93 |
+
section_index, section_logs, in_toks, out_toks = future.result()
|
| 94 |
+
|
| 95 |
+
# Store logs by section index
|
| 96 |
+
parallel_results[section_index] = section_logs
|
| 97 |
+
|
| 98 |
+
# Update token counters
|
| 99 |
+
total_input_token += in_toks
|
| 100 |
+
total_output_token += out_toks
|
| 101 |
+
|
| 102 |
+
# Merge results back into `logs`
|
| 103 |
+
for section_index, section_logs in parallel_results.items():
|
| 104 |
+
curr_section = sections[section_index]
|
| 105 |
+
logs[curr_section] = section_logs
|
| 106 |
+
|
| 107 |
+
return logs, total_input_token, total_output_token
|
| 108 |
+
|
| 109 |
+
def create_agent_fn(sys_msg, agent_model, window_size=10):
|
| 110 |
+
agent = ChatAgent(
|
| 111 |
+
system_message=sys_msg,
|
| 112 |
+
model=agent_model,
|
| 113 |
+
message_window_size=window_size,
|
| 114 |
+
)
|
| 115 |
+
return agent
|
| 116 |
+
|
| 117 |
+
def layout_h2_process_section(
|
| 118 |
+
section,
|
| 119 |
+
outline_no_sub_locations,
|
| 120 |
+
h2_actor_template,
|
| 121 |
+
create_h2_actor_agent, # If you need a fresh agent for each thread
|
| 122 |
+
):
|
| 123 |
+
"""
|
| 124 |
+
Run the logic for a single section.
|
| 125 |
+
Returns a tuple containing:
|
| 126 |
+
- section name (or id),
|
| 127 |
+
- updated subsection-location dict,
|
| 128 |
+
- input token count,
|
| 129 |
+
- output token count
|
| 130 |
+
"""
|
| 131 |
+
print(f'Generating h2 for section {section}...', flush=True)
|
| 132 |
+
|
| 133 |
+
# 1) Create the prompt
|
| 134 |
+
section_outline = {section: outline_no_sub_locations[section]}
|
| 135 |
+
section_jinja_args = {
|
| 136 |
+
'section_outline': json.dumps(section_outline, indent=4),
|
| 137 |
+
}
|
| 138 |
+
section_prompt = h2_actor_template.render(**section_jinja_args)
|
| 139 |
+
|
| 140 |
+
# 2) Prepare a fresh agent or reuse existing (thread-safe?) agent
|
| 141 |
+
# If your h2_actor_agent is not thread-safe, instantiate a new one here:
|
| 142 |
+
h2_actor_agent = create_h2_actor_agent()
|
| 143 |
+
h2_actor_agent.reset()
|
| 144 |
+
|
| 145 |
+
# 3) Get response
|
| 146 |
+
response = h2_actor_agent.step(section_prompt)
|
| 147 |
+
input_token, output_token = account_token(response)
|
| 148 |
+
|
| 149 |
+
# 4) Parse JSON
|
| 150 |
+
subsection_location = get_json_from_response(response.msgs[0].content)
|
| 151 |
+
|
| 152 |
+
# 5) Create a dict from the sub-locations
|
| 153 |
+
sec_bbox = outline_no_sub_locations[section]['location']
|
| 154 |
+
subsection_location_dict = {}
|
| 155 |
+
for k, v in subsection_location.items():
|
| 156 |
+
subsection_location_dict[k] = {
|
| 157 |
+
'left': v['location'][0],
|
| 158 |
+
'top': v['location'][1],
|
| 159 |
+
'width': v['location'][2],
|
| 160 |
+
'height': v['location'][3]
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
# 6) Validate and possibly revise
|
| 164 |
+
is_valid, revised = validate_and_adjust_subsections(sec_bbox, subsection_location_dict)
|
| 165 |
+
if not is_valid:
|
| 166 |
+
# Try once more
|
| 167 |
+
is_valid, revised = validate_and_adjust_subsections(sec_bbox, revised)
|
| 168 |
+
assert is_valid, "Failed to adjust subsections to fit section"
|
| 169 |
+
final_sub_loc = revised
|
| 170 |
+
else:
|
| 171 |
+
final_sub_loc = subsection_location
|
| 172 |
+
|
| 173 |
+
# Return all data needed by the main thread
|
| 174 |
+
return section, final_sub_loc, input_token, output_token
|
| 175 |
+
|
| 176 |
+
def layout_process_section(
|
| 177 |
+
section_index,
|
| 178 |
+
sections,
|
| 179 |
+
new_outline,
|
| 180 |
+
new_section_template,
|
| 181 |
+
documentation,
|
| 182 |
+
sys_msg,
|
| 183 |
+
agent_model,
|
| 184 |
+
window_size,
|
| 185 |
+
max_retry,
|
| 186 |
+
slide_width,
|
| 187 |
+
slide_height
|
| 188 |
+
):
|
| 189 |
+
"""
|
| 190 |
+
Runs the 'gen_layout' logic for a single section_index.
|
| 191 |
+
|
| 192 |
+
Returns a tuple:
|
| 193 |
+
(section_index, updated_log, input_tokens, output_tokens)
|
| 194 |
+
"""
|
| 195 |
+
curr_section = sections[section_index]
|
| 196 |
+
print(f'Generating h1 layout for section {curr_section}...')
|
| 197 |
+
|
| 198 |
+
# Build outline JSON just for current section
|
| 199 |
+
new_section_outline = {curr_section: new_outline[curr_section]}
|
| 200 |
+
if section_index == 0:
|
| 201 |
+
new_section_outline = {'meta': new_outline['meta'], curr_section: new_outline[curr_section]}
|
| 202 |
+
new_section_jinja_args = {
|
| 203 |
+
'json_outline': new_section_outline,
|
| 204 |
+
'function_docs': documentation,
|
| 205 |
+
'file_name': f'poster_{section_index}.pptx'
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
# Render prompt
|
| 209 |
+
new_section_prompt = new_section_template.render(**new_section_jinja_args)
|
| 210 |
+
|
| 211 |
+
existing_code = '' # Or fetch from a stable location that is not dependent on real-time results
|
| 212 |
+
|
| 213 |
+
# Call gen_layout
|
| 214 |
+
section_logs = gen_layout_parallel(
|
| 215 |
+
create_agent_fn(
|
| 216 |
+
sys_msg,
|
| 217 |
+
agent_model,
|
| 218 |
+
window_size
|
| 219 |
+
),
|
| 220 |
+
new_section_prompt,
|
| 221 |
+
max_retry,
|
| 222 |
+
existing_code=existing_code,
|
| 223 |
+
slide_width=slide_width,
|
| 224 |
+
slide_height=slide_height,
|
| 225 |
+
tmp_name=section_index
|
| 226 |
+
)
|
| 227 |
+
|
| 228 |
+
if section_logs[-1]['error'] is not None:
|
| 229 |
+
print(f'Failed to generate layout for section {curr_section}.')
|
| 230 |
+
return None
|
| 231 |
+
|
| 232 |
+
in_toks, out_toks = section_logs[-1]['cumulative_tokens']
|
| 233 |
+
return (section_index, section_logs, in_toks, out_toks)
|
| 234 |
+
|
| 235 |
+
def get_outline_location(outline, subsection=False):
|
| 236 |
+
outline_location = {}
|
| 237 |
+
for k, v in outline.items():
|
| 238 |
+
if k == 'meta':
|
| 239 |
+
continue
|
| 240 |
+
outline_location[k] = {
|
| 241 |
+
'location': v['location'],
|
| 242 |
+
}
|
| 243 |
+
if subsection:
|
| 244 |
+
if 'subsections' in v:
|
| 245 |
+
outline_location[k]['subsections'] = get_outline_location(v['subsections'])
|
| 246 |
+
return outline_location
|
| 247 |
+
|
| 248 |
+
def apply_outline_location(outline, location, subsection=False):
|
| 249 |
+
new_outline = {}
|
| 250 |
+
for k, v in outline.items():
|
| 251 |
+
if k == 'meta':
|
| 252 |
+
new_outline[k] = v
|
| 253 |
+
continue
|
| 254 |
+
new_outline[k] = copy.deepcopy(v)
|
| 255 |
+
new_outline[k]['location'] = location[k]['location']
|
| 256 |
+
if subsection:
|
| 257 |
+
if 'subsections' in v:
|
| 258 |
+
new_outline[k]['subsections'] = apply_outline_location(v['subsections'], location[k]['subsections'])
|
| 259 |
+
|
| 260 |
+
return new_outline
|
| 261 |
+
|
| 262 |
+
def fill_location(outline, section_name, location_dict):
|
| 263 |
+
new_outline = copy.deepcopy(outline)
|
| 264 |
+
if 'subsections' not in new_outline[section_name]:
|
| 265 |
+
return new_outline
|
| 266 |
+
for k, v in new_outline[section_name]['subsections'].items():
|
| 267 |
+
v['location'] = location_dict[k]['location']
|
| 268 |
+
return new_outline
|
| 269 |
+
|
| 270 |
+
def recover_name_and_location(outline_no_name, outline):
|
| 271 |
+
new_outline = copy.deepcopy(outline_no_name)
|
| 272 |
+
for k, v in outline_no_name.items():
|
| 273 |
+
if k == 'meta':
|
| 274 |
+
continue
|
| 275 |
+
new_outline[k]['name'] = outline[k]['name']
|
| 276 |
+
if type(new_outline[k]['location']) == list:
|
| 277 |
+
new_outline[k]['location'] = {
|
| 278 |
+
'left': v['location'][0],
|
| 279 |
+
'top': v['location'][1],
|
| 280 |
+
'width': v['location'][2],
|
| 281 |
+
'height': v['location'][3]
|
| 282 |
+
}
|
| 283 |
+
if 'subsections' in v:
|
| 284 |
+
for k_sub, v_sub in v['subsections'].items():
|
| 285 |
+
new_outline[k]['subsections'][k_sub]['name'] = outline[k]['subsections'][k_sub]['name']
|
| 286 |
+
if type(new_outline[k]['subsections'][k_sub]['location']) == list:
|
| 287 |
+
new_outline[k]['subsections'][k_sub]['location'] = {
|
| 288 |
+
'left': v_sub['location'][0],
|
| 289 |
+
'top': v_sub['location'][1],
|
| 290 |
+
'width': v_sub['location'][2],
|
| 291 |
+
'height': v_sub['location'][3]
|
| 292 |
+
}
|
| 293 |
+
return new_outline
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
def validate_and_adjust_subsections(section_bbox, subsection_bboxes):
|
| 297 |
+
"""
|
| 298 |
+
Validate that the given subsections collectively occupy the entire section.
|
| 299 |
+
If not, return an adjusted version that fixes the layout.
|
| 300 |
+
|
| 301 |
+
We assume all subsections are intended to be stacked vertically with no gaps,
|
| 302 |
+
spanning the full width of the section.
|
| 303 |
+
|
| 304 |
+
:param section_bbox: dict with keys ["left", "top", "width", "height"]
|
| 305 |
+
:param subsection_bboxes: dict of subsection_name -> bounding_box (each also
|
| 306 |
+
with keys ["left", "top", "width", "height"])
|
| 307 |
+
:return: (is_valid, revised_subsections)
|
| 308 |
+
where is_valid is True/False,
|
| 309 |
+
and revised_subsections is either the same as subsection_bboxes if valid,
|
| 310 |
+
or a new dict of adjusted bounding boxes if invalid.
|
| 311 |
+
"""
|
| 312 |
+
|
| 313 |
+
# Helper functions
|
| 314 |
+
def _right(bbox):
|
| 315 |
+
return bbox["left"] + bbox["width"]
|
| 316 |
+
|
| 317 |
+
def _bottom(bbox):
|
| 318 |
+
return bbox["top"] + bbox["height"]
|
| 319 |
+
|
| 320 |
+
section_left = section_bbox["left"]
|
| 321 |
+
section_top = section_bbox["top"]
|
| 322 |
+
section_right = section_left + section_bbox["width"]
|
| 323 |
+
section_bottom = section_top + section_bbox["height"]
|
| 324 |
+
|
| 325 |
+
# Convert dictionary to a list of (subsection_name, bbox) pairs
|
| 326 |
+
items = list(subsection_bboxes.items())
|
| 327 |
+
if not items:
|
| 328 |
+
# No subsections is definitely not valid if we want to fill the section
|
| 329 |
+
return False, None
|
| 330 |
+
|
| 331 |
+
# Sort subsections by their 'top' coordinate
|
| 332 |
+
items_sorted = sorted(items, key=lambda x: x[1]["top"])
|
| 333 |
+
|
| 334 |
+
# ---------------------------
|
| 335 |
+
# Step 1: Validate
|
| 336 |
+
# ---------------------------
|
| 337 |
+
# We'll check:
|
| 338 |
+
# 1. left/right boundaries match the section for each subsection
|
| 339 |
+
# 2. The first subsection's top == section_top
|
| 340 |
+
# 3. The last subsection's bottom == section_bottom
|
| 341 |
+
# 4. Each pair of consecutive subsections lines up exactly
|
| 342 |
+
# (previous bottom == current top) with no gap or overlap.
|
| 343 |
+
|
| 344 |
+
is_valid = True
|
| 345 |
+
|
| 346 |
+
# Check left/right for each
|
| 347 |
+
for name, bbox in items_sorted:
|
| 348 |
+
if bbox["left"] != section_left or _right(bbox) != section_right:
|
| 349 |
+
is_valid = False
|
| 350 |
+
break
|
| 351 |
+
|
| 352 |
+
# Check alignment for the first and last
|
| 353 |
+
if is_valid:
|
| 354 |
+
first_sub_name, first_sub_bbox = items_sorted[0]
|
| 355 |
+
if first_sub_bbox["top"] != section_top:
|
| 356 |
+
is_valid = False
|
| 357 |
+
|
| 358 |
+
if is_valid:
|
| 359 |
+
last_sub_name, last_sub_bbox = items_sorted[-1]
|
| 360 |
+
if _bottom(last_sub_bbox) != section_bottom:
|
| 361 |
+
is_valid = False
|
| 362 |
+
|
| 363 |
+
# Check consecutive alignment
|
| 364 |
+
if is_valid:
|
| 365 |
+
for i in range(len(items_sorted) - 1):
|
| 366 |
+
_, current_bbox = items_sorted[i]
|
| 367 |
+
_, next_bbox = items_sorted[i + 1]
|
| 368 |
+
if _bottom(current_bbox) != next_bbox["top"]:
|
| 369 |
+
is_valid = False
|
| 370 |
+
break
|
| 371 |
+
|
| 372 |
+
# If everything passed, we return
|
| 373 |
+
if is_valid:
|
| 374 |
+
return True, subsection_bboxes
|
| 375 |
+
|
| 376 |
+
# ---------------------------
|
| 377 |
+
# Step 2: Revise
|
| 378 |
+
# ---------------------------
|
| 379 |
+
# We will adjust all subsection bboxes so that they occupy
|
| 380 |
+
# the entire section exactly, preserving each original bbox's
|
| 381 |
+
# height *ratio* if possible.
|
| 382 |
+
|
| 383 |
+
# 2a. Compute total original height (in the order of sorted items)
|
| 384 |
+
original_heights = [bbox["height"] for _, bbox in items_sorted]
|
| 385 |
+
total_original_height = sum(original_heights)
|
| 386 |
+
|
| 387 |
+
# Avoid divide-by-zero if somehow there's a 0 height
|
| 388 |
+
if total_original_height <= 0:
|
| 389 |
+
# Fallback: split the section equally among subsections
|
| 390 |
+
# to avoid zero or negative heights
|
| 391 |
+
chunk_height = section_bbox["height"] / len(items_sorted)
|
| 392 |
+
scale_heights = [chunk_height] * len(items_sorted)
|
| 393 |
+
else:
|
| 394 |
+
# Scale each original height by the ratio of
|
| 395 |
+
# (section total height / sum of original heights)
|
| 396 |
+
scale = section_bbox["height"] / total_original_height
|
| 397 |
+
scale_heights = [h * scale for h in original_heights]
|
| 398 |
+
|
| 399 |
+
# 2b. Assign bounding boxes top->bottom, ensuring no gap
|
| 400 |
+
revised = {}
|
| 401 |
+
current_top = section_top
|
| 402 |
+
for i, (name, original_bbox) in enumerate(items_sorted):
|
| 403 |
+
revised_height = scale_heights[i]
|
| 404 |
+
# If there's floating error, we can clamp in the last iteration
|
| 405 |
+
# so that the bottom exactly matches section_bottom.
|
| 406 |
+
# But for simplicity, we'll keep it straightforward unless needed.
|
| 407 |
+
|
| 408 |
+
revised[name] = {
|
| 409 |
+
"left": section_left,
|
| 410 |
+
"top": current_top,
|
| 411 |
+
"width": section_bbox["width"],
|
| 412 |
+
"height": revised_height
|
| 413 |
+
}
|
| 414 |
+
# Update current_top for next subsection
|
| 415 |
+
current_top += revised_height
|
| 416 |
+
|
| 417 |
+
# Due to potential float rounding, we can enforce the last subsection
|
| 418 |
+
# to exactly end at section_bottom:
|
| 419 |
+
last_name = items_sorted[-1][0]
|
| 420 |
+
# Recompute the actual bottom after the above assignment
|
| 421 |
+
new_bottom = revised[last_name]["top"] + revised[last_name]["height"]
|
| 422 |
+
diff = new_bottom - section_bottom
|
| 423 |
+
if abs(diff) > 1e-9:
|
| 424 |
+
# Adjust the last subsection's height
|
| 425 |
+
revised[last_name]["height"] -= diff
|
| 426 |
+
|
| 427 |
+
# Return the revised dictionary
|
| 428 |
+
return False, revised
|
| 429 |
+
|
| 430 |
+
def filter_image_table(args, filter_config):
|
| 431 |
+
images = json.load(open(f'images_and_tables/{args.poster_name}_images.json', 'r'))
|
| 432 |
+
tables = json.load(open(f'images_and_tables/{args.poster_name}_tables.json', 'r'))
|
| 433 |
+
doc_json = json.load(open(f'contents/{args.model_name}_{args.poster_name}_raw_content.json', 'r'))
|
| 434 |
+
agent_filter = 'image_table_filter_agent'
|
| 435 |
+
with open(f"prompt_templates/{agent_filter}.yaml", "r") as f:
|
| 436 |
+
config_filter = yaml.safe_load(f)
|
| 437 |
+
|
| 438 |
+
image_information = {}
|
| 439 |
+
for k, v in images.items():
|
| 440 |
+
image_information[k] = copy.deepcopy(v)
|
| 441 |
+
image_information[k]['min_width'] = v['width'] // IMAGE_SCALE_RATIO_MIN
|
| 442 |
+
image_information[k]['min_height'] = v['height'] // IMAGE_SCALE_RATIO_MIN
|
| 443 |
+
image_information[k]['max_width'] = v['width'] // IMAGE_SCALE_RATIO_MAX
|
| 444 |
+
image_information[k]['max_height'] = v['height'] // IMAGE_SCALE_RATIO_MAX
|
| 445 |
+
|
| 446 |
+
table_information = {}
|
| 447 |
+
for k, v in tables.items():
|
| 448 |
+
table_information[k] = copy.deepcopy(v)
|
| 449 |
+
table_information[k]['min_width'] = v['width'] // TABLE_SCALE_RATIO_MIN
|
| 450 |
+
table_information[k]['min_height'] = v['height'] // TABLE_SCALE_RATIO_MIN
|
| 451 |
+
table_information[k]['max_width'] = v['width'] // TABLE_SCALE_RATIO_MAX
|
| 452 |
+
table_information[k]['max_height'] = v['height'] // TABLE_SCALE_RATIO_MAX
|
| 453 |
+
|
| 454 |
+
filter_actor_sys_msg = config_filter['system_prompt']
|
| 455 |
+
|
| 456 |
+
filter_model = ModelFactory.create(
|
| 457 |
+
model_platform=filter_config['model_platform'],
|
| 458 |
+
model_type=filter_config['model_type'],
|
| 459 |
+
model_config_dict=filter_config['model_config'],
|
| 460 |
+
)
|
| 461 |
+
filter_actor_agent = ChatAgent(
|
| 462 |
+
system_message=filter_actor_sys_msg,
|
| 463 |
+
model=filter_model,
|
| 464 |
+
message_window_size=10, # [Optional] the length for chat memory
|
| 465 |
+
)
|
| 466 |
+
|
| 467 |
+
filter_jinja_args = {
|
| 468 |
+
'json_content': doc_json,
|
| 469 |
+
'table_information': table_information,
|
| 470 |
+
'image_information': image_information,
|
| 471 |
+
}
|
| 472 |
+
jinja_env = Environment(undefined=StrictUndefined)
|
| 473 |
+
filter_prompt = jinja_env.from_string(config_filter["template"])
|
| 474 |
+
response = filter_actor_agent.step(filter_prompt.render(**filter_jinja_args))
|
| 475 |
+
input_token, output_token = account_token(response)
|
| 476 |
+
response_json = get_json_from_response(response.msgs[0].content)
|
| 477 |
+
table_information = response_json['table_information']
|
| 478 |
+
image_information = response_json['image_information']
|
| 479 |
+
json.dump(images, open(f'images_and_tables/{args.poster_name}_images_filtered.json', 'w'), indent=4)
|
| 480 |
+
json.dump(tables, open(f'images_and_tables/{args.poster_name}_tables_filtered.json', 'w'), indent=4)
|
| 481 |
+
|
| 482 |
+
return input_token, output_token
|
| 483 |
+
|
| 484 |
+
def gen_outline_layout(args, actor_config, critic_config):
|
| 485 |
+
poster_log_path = f'log/{args.model_name}_{args.poster_name}_poster_{args.index}'
|
| 486 |
+
if not os.path.exists(poster_log_path):
|
| 487 |
+
os.mkdir(poster_log_path)
|
| 488 |
+
total_input_token, total_output_token = 0, 0
|
| 489 |
+
consumption_log = {
|
| 490 |
+
'outline': [],
|
| 491 |
+
'h1_actor': [],
|
| 492 |
+
'h2_actor': [],
|
| 493 |
+
'h1_critic': [],
|
| 494 |
+
'gen_layout': []
|
| 495 |
+
}
|
| 496 |
+
jinja_env = Environment(undefined=StrictUndefined)
|
| 497 |
+
outline_file_path = f'outlines/{args.model_name}_{args.poster_name}_outline_{args.index}.json'
|
| 498 |
+
agent_name = 'poster_planner_new'
|
| 499 |
+
agent_init_name = 'layout_agent_init_parallel'
|
| 500 |
+
agent_new_section_name = 'layout_agent_new_section_parallel'
|
| 501 |
+
h1_critic_name = 'critic_layout_hierarchy_1'
|
| 502 |
+
h2_actor_name = 'actor_layout_hierarchy_2'
|
| 503 |
+
|
| 504 |
+
doc_json = json.load(open(f'contents/{args.model_name}_{args.poster_name}_raw_content.json', 'r'))
|
| 505 |
+
filtered_table_information = json.load(open(f'images_and_tables/{args.poster_name}_tables_filtered.json', 'r'))
|
| 506 |
+
filtered_image_information = json.load(open(f'images_and_tables/{args.poster_name}_images_filtered.json', 'r'))
|
| 507 |
+
|
| 508 |
+
with open(f"prompt_templates/{agent_name}.yaml", "r") as f:
|
| 509 |
+
planner_config = yaml.safe_load(f)
|
| 510 |
+
|
| 511 |
+
with open(f"prompt_templates/{agent_init_name}.yaml", "r") as f:
|
| 512 |
+
config_init = yaml.safe_load(f)
|
| 513 |
+
|
| 514 |
+
with open(f"prompt_templates/{agent_new_section_name}.yaml", "r") as f:
|
| 515 |
+
config_new_section = yaml.safe_load(f)
|
| 516 |
+
|
| 517 |
+
with open(f"prompt_templates/{h1_critic_name}.yaml", "r") as f:
|
| 518 |
+
config_h1_critic = yaml.safe_load(f)
|
| 519 |
+
|
| 520 |
+
with open(f"prompt_templates/{h2_actor_name}.yaml", "r") as f:
|
| 521 |
+
config_h2_actor = yaml.safe_load(f)
|
| 522 |
+
|
| 523 |
+
planner_model = ModelFactory.create(
|
| 524 |
+
model_platform=actor_config['model_platform'],
|
| 525 |
+
model_type=actor_config['model_type'],
|
| 526 |
+
model_config_dict=actor_config['model_config'],
|
| 527 |
+
)
|
| 528 |
+
|
| 529 |
+
planner_agent = ChatAgent(
|
| 530 |
+
system_message=planner_config['system_prompt'],
|
| 531 |
+
model=planner_model,
|
| 532 |
+
message_window_size=10,
|
| 533 |
+
)
|
| 534 |
+
|
| 535 |
+
outline_template = jinja_env.from_string(planner_config["template"])
|
| 536 |
+
|
| 537 |
+
planner_jinja_args = {
|
| 538 |
+
'json_content': doc_json,
|
| 539 |
+
'table_information': filtered_table_information,
|
| 540 |
+
'image_information': filtered_image_information,
|
| 541 |
+
}
|
| 542 |
+
|
| 543 |
+
actor_model = ModelFactory.create(
|
| 544 |
+
model_platform=actor_config['model_platform'],
|
| 545 |
+
model_type=actor_config['model_type'],
|
| 546 |
+
model_config_dict=actor_config['model_config'],
|
| 547 |
+
)
|
| 548 |
+
|
| 549 |
+
init_actor_sys_msg = config_init['system_prompt']
|
| 550 |
+
|
| 551 |
+
def create_init_actor_agent():
|
| 552 |
+
actor_model = ModelFactory.create(
|
| 553 |
+
model_platform=actor_config['model_platform'],
|
| 554 |
+
model_type=actor_config['model_type'],
|
| 555 |
+
model_config_dict=actor_config['model_config'],
|
| 556 |
+
)
|
| 557 |
+
init_actor_agent = ChatAgent(
|
| 558 |
+
system_message=init_actor_sys_msg,
|
| 559 |
+
model=actor_model,
|
| 560 |
+
message_window_size=10,
|
| 561 |
+
)
|
| 562 |
+
return init_actor_agent
|
| 563 |
+
|
| 564 |
+
new_section_actor_sys_msg = config_new_section['system_prompt']
|
| 565 |
+
|
| 566 |
+
def create_new_section_actor_agent():
|
| 567 |
+
actor_model = ModelFactory.create(
|
| 568 |
+
model_platform=actor_config['model_platform'],
|
| 569 |
+
model_type=actor_config['model_type'],
|
| 570 |
+
model_config_dict=actor_config['model_config'],
|
| 571 |
+
)
|
| 572 |
+
new_section_actor_agent = ChatAgent(
|
| 573 |
+
system_message=new_section_actor_sys_msg,
|
| 574 |
+
model=actor_model,
|
| 575 |
+
message_window_size=10,
|
| 576 |
+
)
|
| 577 |
+
return new_section_actor_agent
|
| 578 |
+
|
| 579 |
+
h1_critic_model = ModelFactory.create(
|
| 580 |
+
model_platform=critic_config['model_platform'],
|
| 581 |
+
model_type=critic_config['model_type'],
|
| 582 |
+
model_config_dict=critic_config['model_config'],
|
| 583 |
+
)
|
| 584 |
+
|
| 585 |
+
h1_critic_sys_msg = config_h1_critic['system_prompt']
|
| 586 |
+
|
| 587 |
+
h1_critic_agent = ChatAgent(
|
| 588 |
+
system_message=h1_critic_sys_msg,
|
| 589 |
+
model=h1_critic_model,
|
| 590 |
+
message_window_size=None,
|
| 591 |
+
)
|
| 592 |
+
|
| 593 |
+
h1_pos_example = Image.open('h1_example/h1_pos.jpg')
|
| 594 |
+
h1_neg_example = Image.open('h1_example/h1_neg.jpg')
|
| 595 |
+
|
| 596 |
+
h2_actor_model = ModelFactory.create(
|
| 597 |
+
model_platform=actor_config['model_platform'],
|
| 598 |
+
model_type=actor_config['model_type'],
|
| 599 |
+
model_config_dict=actor_config['model_config'],
|
| 600 |
+
)
|
| 601 |
+
|
| 602 |
+
h2_actor_sys_msg = config_h2_actor['system_prompt']
|
| 603 |
+
|
| 604 |
+
def create_h2_actor_agent():
|
| 605 |
+
h2_actor_model = ModelFactory.create(
|
| 606 |
+
model_platform=actor_config['model_platform'],
|
| 607 |
+
model_type=actor_config['model_type'],
|
| 608 |
+
model_config_dict=actor_config['model_config'],
|
| 609 |
+
)
|
| 610 |
+
h2_actor_agent = ChatAgent(
|
| 611 |
+
system_message=h2_actor_sys_msg,
|
| 612 |
+
model=h2_actor_model,
|
| 613 |
+
message_window_size=10,
|
| 614 |
+
)
|
| 615 |
+
return h2_actor_agent
|
| 616 |
+
|
| 617 |
+
|
| 618 |
+
init_template = jinja_env.from_string(config_init["template"])
|
| 619 |
+
new_section_template = jinja_env.from_string(config_new_section["template"])
|
| 620 |
+
h1_critic_template = jinja_env.from_string(config_h1_critic["template"])
|
| 621 |
+
|
| 622 |
+
attempt = 0
|
| 623 |
+
while True:
|
| 624 |
+
print(f'Generating outline attempt {attempt}...', flush=True)
|
| 625 |
+
planner_prompt = outline_template.render(**planner_jinja_args)
|
| 626 |
+
planner_agent.reset()
|
| 627 |
+
response = planner_agent.step(planner_prompt)
|
| 628 |
+
outline = get_json_from_response(response.msgs[0].content)
|
| 629 |
+
input_token, output_token = account_token(response)
|
| 630 |
+
sections = list(outline.keys())
|
| 631 |
+
sections = [x for x in sections if x != 'meta']
|
| 632 |
+
slide_width = outline['meta']['width']
|
| 633 |
+
slide_height = outline['meta']['height']
|
| 634 |
+
name_to_hierarchy = get_hierarchy(outline)
|
| 635 |
+
consumption_log['outline'].append((input_token, output_token))
|
| 636 |
+
total_input_token += input_token
|
| 637 |
+
total_output_token += output_token
|
| 638 |
+
init_outline = {'meta': outline['meta'], sections[0]: outline[sections[0]]}
|
| 639 |
+
|
| 640 |
+
new_outline = outline
|
| 641 |
+
|
| 642 |
+
init_jinja_args = {
|
| 643 |
+
'json_outline': init_outline,
|
| 644 |
+
'function_docs': documentation
|
| 645 |
+
}
|
| 646 |
+
|
| 647 |
+
init_prompt = init_template.render(**init_jinja_args)
|
| 648 |
+
|
| 649 |
+
# hierarchy 1 only
|
| 650 |
+
outline_location = get_outline_location(outline, subsection=False)
|
| 651 |
+
|
| 652 |
+
logs, layout_cumulative_input_token, layout_cumulative_output_token = layout_process_section_wrapped(
|
| 653 |
+
sections,
|
| 654 |
+
new_outline,
|
| 655 |
+
init_template,
|
| 656 |
+
new_section_template,
|
| 657 |
+
init_actor_sys_msg,
|
| 658 |
+
new_section_actor_sys_msg,
|
| 659 |
+
actor_config,
|
| 660 |
+
documentation,
|
| 661 |
+
args.max_retry,
|
| 662 |
+
slide_width,
|
| 663 |
+
slide_height
|
| 664 |
+
)
|
| 665 |
+
|
| 666 |
+
concatenated_code = utils_functions
|
| 667 |
+
for section_index in range(len(sections)):
|
| 668 |
+
section = sections[section_index]
|
| 669 |
+
concatenated_code += '\n' + logs[section][-1]['code']
|
| 670 |
+
presentation_object_name = logs[section][-1]['output'].replace('\n', '')
|
| 671 |
+
concatenated_code += '\n' + f'save_presentation({presentation_object_name}, file_name="poster_{section_index + 1}.pptx")'
|
| 672 |
+
|
| 673 |
+
concatenated_code += f'''
|
| 674 |
+
name_to_hierarchy = {name_to_hierarchy}
|
| 675 |
+
identifier = "parallel"
|
| 676 |
+
poster_path = "poster_{section_index + 1}.pptx"
|
| 677 |
+
get_visual_cues(name_to_hierarchy, identifier, poster_path)
|
| 678 |
+
'''
|
| 679 |
+
output, error = run_code_with_utils(concatenated_code, utils_functions)
|
| 680 |
+
if error is not None:
|
| 681 |
+
print(error, flush=True)
|
| 682 |
+
attempt += 1
|
| 683 |
+
continue
|
| 684 |
+
|
| 685 |
+
consumption_log['h1_actor'].append((layout_cumulative_input_token, layout_cumulative_output_token))
|
| 686 |
+
total_input_token += layout_cumulative_input_token
|
| 687 |
+
total_output_token += layout_cumulative_output_token
|
| 688 |
+
|
| 689 |
+
h1_path = f'tmp/poster_<parallel>_hierarchy_1.pptx'
|
| 690 |
+
h2_path = f'tmp/poster_<parallel>_hierarchy_2.pptx'
|
| 691 |
+
|
| 692 |
+
h1_filled_path = f'tmp/poster_<parallel>_hierarchy_1_filled.pptx'
|
| 693 |
+
h2_filled_path = f'tmp/poster_<parallel>_hierarchy_2_filled.pptx'
|
| 694 |
+
|
| 695 |
+
ppt_to_images(h1_path, 'tmp/layout_h1')
|
| 696 |
+
ppt_to_images(h2_path, 'tmp/layout_h2')
|
| 697 |
+
ppt_to_images(h1_filled_path, 'tmp/layout_h1_filled')
|
| 698 |
+
ppt_to_images(h2_filled_path, 'tmp/layout_h2_filled')
|
| 699 |
+
|
| 700 |
+
h1_img = Image.open('tmp/layout_h1/slide_0001.jpg')
|
| 701 |
+
h2_img = Image.open('tmp/layout_h2/slide_0001.jpg')
|
| 702 |
+
h1_filled_img = Image.open('tmp/layout_h1_filled/slide_0001.jpg')
|
| 703 |
+
h2_filled_img = Image.open('tmp/layout_h2_filled/slide_0001.jpg')
|
| 704 |
+
|
| 705 |
+
h1_critic_msg = BaseMessage.make_user_message(
|
| 706 |
+
role_name='User',
|
| 707 |
+
content=h1_critic_template.render(),
|
| 708 |
+
image_list=[h1_neg_example, h1_pos_example, h1_filled_img]
|
| 709 |
+
)
|
| 710 |
+
|
| 711 |
+
outline_bbox_dict = {}
|
| 712 |
+
for k, v in outline_location.items():
|
| 713 |
+
outline_bbox_dict[k] = v['location']
|
| 714 |
+
|
| 715 |
+
bbox_check_result = check_bounding_boxes(
|
| 716 |
+
outline_bbox_dict,
|
| 717 |
+
new_outline['meta']['width'],
|
| 718 |
+
new_outline['meta']['height']
|
| 719 |
+
)
|
| 720 |
+
|
| 721 |
+
if len(bbox_check_result) != 0:
|
| 722 |
+
print(bbox_check_result, flush=True)
|
| 723 |
+
attempt += 1
|
| 724 |
+
continue
|
| 725 |
+
|
| 726 |
+
h1_critic_agent.reset()
|
| 727 |
+
response = h1_critic_agent.step(h1_critic_msg)
|
| 728 |
+
input_token, output_token = account_token(response)
|
| 729 |
+
consumption_log['h1_critic'].append((input_token, output_token))
|
| 730 |
+
total_input_token += input_token
|
| 731 |
+
total_output_token += output_token
|
| 732 |
+
if response.msgs[0].content == 'T':
|
| 733 |
+
print('Blank area detected.', flush=True)
|
| 734 |
+
attempt += 1
|
| 735 |
+
continue
|
| 736 |
+
|
| 737 |
+
print('Sucessfully generated outline.', flush=True)
|
| 738 |
+
|
| 739 |
+
break
|
| 740 |
+
|
| 741 |
+
outline_bbox_dict = {}
|
| 742 |
+
for k, v in outline_location.items():
|
| 743 |
+
outline_bbox_dict[k] = v['location']
|
| 744 |
+
|
| 745 |
+
# Generate subsection locations
|
| 746 |
+
outline_no_sub_locations = copy.deepcopy(new_outline)
|
| 747 |
+
if 'meta' in outline_no_sub_locations:
|
| 748 |
+
outline_no_sub_locations.pop('meta')
|
| 749 |
+
|
| 750 |
+
for k, v in outline_no_sub_locations.items():
|
| 751 |
+
if 'subsections' in v:
|
| 752 |
+
subsections = v['subsections']
|
| 753 |
+
for k_sub, v_sub in subsections.items():
|
| 754 |
+
del v_sub['location']
|
| 755 |
+
del v_sub['name']
|
| 756 |
+
|
| 757 |
+
h2_actor_template = jinja_env.from_string(config_h2_actor["template"])
|
| 758 |
+
|
| 759 |
+
h2_cumulative_input_token = 0
|
| 760 |
+
h2_cumulative_output_token = 0
|
| 761 |
+
|
| 762 |
+
updated_sections = []
|
| 763 |
+
|
| 764 |
+
with ThreadPoolExecutor() as executor:
|
| 765 |
+
# Kick off all tasks
|
| 766 |
+
future_to_section = {
|
| 767 |
+
executor.submit(
|
| 768 |
+
layout_h2_process_section,
|
| 769 |
+
section,
|
| 770 |
+
outline_no_sub_locations,
|
| 771 |
+
h2_actor_template,
|
| 772 |
+
create_h2_actor_agent # pass the factory function
|
| 773 |
+
): section
|
| 774 |
+
for section in sections
|
| 775 |
+
}
|
| 776 |
+
|
| 777 |
+
# Gather results as they complete
|
| 778 |
+
for future in concurrent.futures.as_completed(future_to_section):
|
| 779 |
+
section = future_to_section[future]
|
| 780 |
+
sec, final_sub_loc, in_toks, out_toks = future.result()
|
| 781 |
+
|
| 782 |
+
# Accumulate token usage
|
| 783 |
+
h2_cumulative_input_token += in_toks
|
| 784 |
+
h2_cumulative_output_token += out_toks
|
| 785 |
+
|
| 786 |
+
# Stash the final sub-loc for merging
|
| 787 |
+
updated_sections.append((sec, final_sub_loc))
|
| 788 |
+
|
| 789 |
+
# Now merge each updated subsection location back into outline_no_sub_locations
|
| 790 |
+
for (section, final_sub_loc) in updated_sections:
|
| 791 |
+
outline_no_sub_locations = fill_location(
|
| 792 |
+
outline_no_sub_locations,
|
| 793 |
+
section,
|
| 794 |
+
final_sub_loc
|
| 795 |
+
)
|
| 796 |
+
|
| 797 |
+
consumption_log['h2_actor'].append((h2_cumulative_input_token, h2_cumulative_output_token))
|
| 798 |
+
total_input_token += h2_cumulative_input_token
|
| 799 |
+
total_output_token += h2_cumulative_output_token
|
| 800 |
+
|
| 801 |
+
outline_no_sub_locations['meta'] = outline['meta']
|
| 802 |
+
outline_no_sub_locations_with_name = recover_name_and_location(outline_no_sub_locations, new_outline)
|
| 803 |
+
new_outline = outline_no_sub_locations_with_name
|
| 804 |
+
|
| 805 |
+
### Outline finalized, actually generate layout
|
| 806 |
+
|
| 807 |
+
logs = {}
|
| 808 |
+
|
| 809 |
+
gen_layout_cumulative_input_token = 0
|
| 810 |
+
gen_layout_cumulative_output_token = 0
|
| 811 |
+
init_outline = {'meta': outline['meta'], sections[0]: outline[sections[0]]}
|
| 812 |
+
|
| 813 |
+
new_outline = outline
|
| 814 |
+
|
| 815 |
+
init_jinja_args = {
|
| 816 |
+
'json_outline': init_outline,
|
| 817 |
+
'function_docs': documentation
|
| 818 |
+
}
|
| 819 |
+
|
| 820 |
+
outline_location = get_outline_location(outline, subsection=False)
|
| 821 |
+
logs = {}
|
| 822 |
+
|
| 823 |
+
# We'll store all updated logs here, keyed by section_index.
|
| 824 |
+
parallel_results = {}
|
| 825 |
+
|
| 826 |
+
with concurrent.futures.ThreadPoolExecutor() as executor:
|
| 827 |
+
futures = []
|
| 828 |
+
for section_index in range(len(sections)):
|
| 829 |
+
if section_index == 0:
|
| 830 |
+
create_agent_fn = create_init_actor_agent
|
| 831 |
+
prompt_template = init_template
|
| 832 |
+
else:
|
| 833 |
+
create_agent_fn = create_new_section_actor_agent
|
| 834 |
+
prompt_template = new_section_template
|
| 835 |
+
future = executor.submit(
|
| 836 |
+
layout_process_section,
|
| 837 |
+
section_index,
|
| 838 |
+
sections,
|
| 839 |
+
new_outline,
|
| 840 |
+
prompt_template,
|
| 841 |
+
documentation,
|
| 842 |
+
create_agent_fn,
|
| 843 |
+
args.max_retry,
|
| 844 |
+
name_to_hierarchy,
|
| 845 |
+
slide_width,
|
| 846 |
+
slide_height
|
| 847 |
+
)
|
| 848 |
+
futures.append(future)
|
| 849 |
+
|
| 850 |
+
# Collect the results as they come in
|
| 851 |
+
for future in concurrent.futures.as_completed(futures):
|
| 852 |
+
try:
|
| 853 |
+
section_index, section_logs, in_toks, out_toks = future.result()
|
| 854 |
+
|
| 855 |
+
# Store these logs in a dictionary keyed by the section index
|
| 856 |
+
parallel_results[section_index] = section_logs
|
| 857 |
+
|
| 858 |
+
# Update token counters
|
| 859 |
+
gen_layout_cumulative_input_token += in_toks
|
| 860 |
+
gen_layout_cumulative_output_token += out_toks
|
| 861 |
+
|
| 862 |
+
except Exception as exc:
|
| 863 |
+
print(f"[ERROR] A section failed: {exc}", flush=True)
|
| 864 |
+
# Possibly re-raise if you want to stop everything on error
|
| 865 |
+
# raise
|
| 866 |
+
|
| 867 |
+
# After all tasks complete, merge the results back into `logs`
|
| 868 |
+
for section_index, section_logs in parallel_results.items():
|
| 869 |
+
curr_section = sections[section_index]
|
| 870 |
+
logs[curr_section] = section_logs
|
| 871 |
+
|
| 872 |
+
concatenated_code = utils_functions
|
| 873 |
+
for section_index in range(len(sections)):
|
| 874 |
+
section = sections[section_index]
|
| 875 |
+
concatenated_code += '\n' + logs[section][-1]['code']
|
| 876 |
+
concatenated_code += '\n' + f'save_presentation(presentation, file_name="poster_{section_index + 1}.pptx")'
|
| 877 |
+
|
| 878 |
+
concatenated_code += f'''
|
| 879 |
+
name_to_hierarchy = {name_to_hierarchy}
|
| 880 |
+
identifier = "parallel"
|
| 881 |
+
poster_path = "poster_{section_index + 1}.pptx"
|
| 882 |
+
get_visual_cues(name_to_hierarchy, identifier, poster_path)
|
| 883 |
+
'''
|
| 884 |
+
output, error = run_code(concatenated_code)
|
| 885 |
+
if error is not None:
|
| 886 |
+
print(f'Failed to generate layout for section {curr_section}.')
|
| 887 |
+
|
| 888 |
+
consumption_log['h1_actor'].append((layout_cumulative_input_token, layout_cumulative_output_token))
|
| 889 |
+
total_input_token += gen_layout_cumulative_input_token
|
| 890 |
+
total_output_token += gen_layout_cumulative_output_token
|
| 891 |
+
|
| 892 |
+
h1_path = f'tmp/poster_<parallel>_hierarchy_1.pptx'
|
| 893 |
+
h2_path = f'tmp/poster_<parallel>_hierarchy_2.pptx'
|
| 894 |
+
|
| 895 |
+
h1_filled_path = f'tmp/poster_<parallel>_hierarchy_1_filled.pptx'
|
| 896 |
+
h2_filled_path = f'tmp/poster_<parallel>_hierarchy_2_filled.pptx'
|
| 897 |
+
|
| 898 |
+
ppt_to_images(h1_path, 'tmp/layout_h1')
|
| 899 |
+
ppt_to_images(h2_path, 'tmp/layout_h2')
|
| 900 |
+
ppt_to_images(h1_filled_path, 'tmp/layout_h1_filled')
|
| 901 |
+
ppt_to_images(h2_filled_path, 'tmp/layout_h2_filled')
|
| 902 |
+
|
| 903 |
+
h1_img = Image.open('tmp/layout_h1/slide_0001.jpg')
|
| 904 |
+
h2_img = Image.open('tmp/layout_h2/slide_0001.jpg')
|
| 905 |
+
h1_filled_img = Image.open('tmp/layout_h1_filled/slide_0001.jpg')
|
| 906 |
+
h2_filled_img = Image.open('tmp/layout_h2_filled/slide_0001.jpg')
|
| 907 |
+
|
| 908 |
+
|
| 909 |
+
ckpt = {
|
| 910 |
+
'logs': logs,
|
| 911 |
+
'outline': new_outline,
|
| 912 |
+
'name_to_hierarchy': name_to_hierarchy,
|
| 913 |
+
'consumption_log': consumption_log,
|
| 914 |
+
'total_input_token': total_input_token,
|
| 915 |
+
'total_output_token': total_output_token,
|
| 916 |
+
}
|
| 917 |
+
|
| 918 |
+
with open(f'checkpoints/{args.model_name}_{args.poster_name}_ckpt_{args.index}.pkl', 'wb') as f:
|
| 919 |
+
pkl.dump(ckpt, f)
|
| 920 |
+
|
| 921 |
+
json.dump(
|
| 922 |
+
new_outline,
|
| 923 |
+
open(outline_file_path, "w"),
|
| 924 |
+
ensure_ascii=False,
|
| 925 |
+
indent=4,
|
| 926 |
+
)
|
| 927 |
+
|
| 928 |
+
return total_input_token, total_output_token
|
| 929 |
+
|
| 930 |
+
if __name__ == '__main__':
|
| 931 |
+
parser = argparse.ArgumentParser()
|
| 932 |
+
parser.add_argument('--poster_name', type=str, default=None)
|
| 933 |
+
parser.add_argument('--model_name', type=str, default='4o')
|
| 934 |
+
parser.add_argument('--poster_path', type=str, required=True)
|
| 935 |
+
parser.add_argument('--index', type=int, default=0)
|
| 936 |
+
parser.add_argument('--max_retry', type=int, default=3)
|
| 937 |
+
args = parser.parse_args()
|
| 938 |
+
|
| 939 |
+
actor_config = get_agent_config(args.model_name)
|
| 940 |
+
critic_config = get_agent_config(args.model_name)
|
| 941 |
+
|
| 942 |
+
if args.poster_name is None:
|
| 943 |
+
args.poster_name = args.poster_path.split('/')[-1].replace('.pdf', '').replace(' ', '_')
|
| 944 |
+
|
| 945 |
+
input_token, output_token = filter_image_table(args, actor_config)
|
| 946 |
+
print(f'Token consumption: {input_token} -> {output_token}', flush=True)
|
| 947 |
+
|
| 948 |
+
input_token, output_token = gen_outline_layout(args, actor_config, critic_config)
|
| 949 |
+
print(f'Token consumption: {input_token} -> {output_token}', flush=True)
|
Paper2Poster/PosterAgent/gen_poster_content.py
ADDED
|
@@ -0,0 +1,529 @@
| 1 |
+
import tempfile
|
| 2 |
+
import shutil
|
| 3 |
+
from dotenv import load_dotenv
|
| 4 |
+
from utils.src.utils import get_json_from_response
|
| 5 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 6 |
+
import json
import copy  # used by render_textbox's deepcopy; not guaranteed to arrive via the wildcard imports below
|
| 7 |
+
|
| 8 |
+
from camel.models import ModelFactory
|
| 9 |
+
from PosterAgent.gen_pptx_code import generate_poster_code
|
| 10 |
+
from camel.agents import ChatAgent
|
| 11 |
+
from camel.messages import BaseMessage
|
| 12 |
+
from utils.src.utils import ppt_to_images
|
| 13 |
+
from PIL import Image
|
| 14 |
+
|
| 15 |
+
from utils.wei_utils import *
|
| 16 |
+
|
| 17 |
+
from utils.pptx_utils import *
|
| 18 |
+
from utils.critic_utils import *
|
| 19 |
+
import yaml
|
| 20 |
+
from jinja2 import Environment, StrictUndefined
|
| 21 |
+
import argparse
|
| 22 |
+
|
| 23 |
+
load_dotenv()
|
| 24 |
+
MAX_ATTEMPT = 10
|
| 25 |
+
|
| 26 |
+
def gen_content_process_section(
|
| 27 |
+
section_name,
|
| 28 |
+
outline,
|
| 29 |
+
raw_content,
|
| 30 |
+
raw_outline,
|
| 31 |
+
template,
|
| 32 |
+
create_actor_agent,
|
| 33 |
+
MAX_ATTEMPT
|
| 34 |
+
):
|
| 35 |
+
"""
|
| 36 |
+
Process a single section in its own thread or process.
|
| 37 |
+
Returns (section_name, result_json, total_input_token, total_output_token).
|
| 38 |
+
"""
|
| 39 |
+
# Create a fresh ActorAgent instance for each parallel call
|
| 40 |
+
actor_agent = create_actor_agent()
|
| 41 |
+
|
| 42 |
+
section_outline = ''
|
| 43 |
+
num_attempts = 0
|
| 44 |
+
total_input_token = 0
|
| 45 |
+
total_output_token = 0
|
| 46 |
+
result_json = None
|
| 47 |
+
|
| 48 |
+
while True:
|
| 49 |
+
print(f"[Thread] Generating content for section: {section_name}")
|
| 50 |
+
|
| 51 |
+
if len(section_outline) == 0:
|
| 52 |
+
# Initialize the section outline
|
| 53 |
+
section_outline = json.dumps(outline[section_name], indent=4)
|
| 54 |
+
|
| 55 |
+
# Render prompt using Jinja template
|
| 56 |
+
jinja_args = {
|
| 57 |
+
'json_outline': section_outline,
|
| 58 |
+
'json_content': raw_content,
|
| 59 |
+
}
|
| 60 |
+
prompt = template.render(**jinja_args)
|
| 61 |
+
|
| 62 |
+
# Step the actor_agent and track tokens
|
| 63 |
+
response = actor_agent.step(prompt)
|
| 64 |
+
input_token, output_token = account_token(response)
|
| 65 |
+
total_input_token += input_token
|
| 66 |
+
total_output_token += output_token
|
| 67 |
+
|
| 68 |
+
# Parse JSON and possibly adjust text length
|
| 69 |
+
result_json = get_json_from_response(response.msgs[0].content)
|
| 70 |
+
new_section_outline, suggested = generate_length_suggestions(
|
| 71 |
+
result_json,
|
| 72 |
+
json.dumps(outline[section_name]),
|
| 73 |
+
raw_outline[section_name]
|
| 74 |
+
)
|
| 75 |
+
section_outline = json.dumps(new_section_outline, indent=4)
|
| 76 |
+
|
| 77 |
+
if not suggested:
|
| 78 |
+
# No more adjustments needed
|
| 79 |
+
break
|
| 80 |
+
|
| 81 |
+
print(f"[Thread] Adjusting text length for section: {section_name}...")
|
| 82 |
+
|
| 83 |
+
num_attempts += 1
|
| 84 |
+
if num_attempts >= MAX_ATTEMPT:
|
| 85 |
+
break
|
| 86 |
+
|
| 87 |
+
return section_name, result_json, total_input_token, total_output_token
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def gen_content_parallel_process_sections(
|
| 91 |
+
sections,
|
| 92 |
+
outline,
|
| 93 |
+
raw_content,
|
| 94 |
+
raw_outline,
|
| 95 |
+
template,
|
| 96 |
+
create_actor_agent,
|
| 97 |
+
MAX_ATTEMPT=3
|
| 98 |
+
):
|
| 99 |
+
"""
|
| 100 |
+
Parallelize the section processing using ThreadPoolExecutor.
|
| 101 |
+
"""
|
| 102 |
+
poster_content = {}
|
| 103 |
+
total_input_token = 0
|
| 104 |
+
total_output_token = 0
|
| 105 |
+
|
| 106 |
+
# Create a pool of worker threads (or processes)
|
| 107 |
+
with ThreadPoolExecutor() as executor:
|
| 108 |
+
futures = []
|
| 109 |
+
|
| 110 |
+
# Submit each section to be processed in parallel
|
| 111 |
+
for section_name in sections:
|
| 112 |
+
futures.append(
|
| 113 |
+
executor.submit(
|
| 114 |
+
gen_content_process_section,
|
| 115 |
+
section_name,
|
| 116 |
+
outline,
|
| 117 |
+
raw_content,
|
| 118 |
+
raw_outline,
|
| 119 |
+
template,
|
| 120 |
+
create_actor_agent,
|
| 121 |
+
MAX_ATTEMPT
|
| 122 |
+
)
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
# Collect results as they complete
|
| 126 |
+
for future in as_completed(futures):
|
| 127 |
+
section_name, result_json, sec_input_token, sec_output_token = future.result()
|
| 128 |
+
poster_content[section_name] = result_json
|
| 129 |
+
total_input_token += sec_input_token
|
| 130 |
+
total_output_token += sec_output_token
|
| 131 |
+
|
| 132 |
+
return poster_content, total_input_token, total_output_token
|
| 133 |
+
|
| 134 |
+
def render_textbox(text_arrangement, textbox_content, tmp_dir):
|
| 135 |
+
arrangement = copy.deepcopy(text_arrangement)
|
| 136 |
+
arrangement['x'] = 1
|
| 137 |
+
arrangement['y'] = 1
|
| 138 |
+
|
| 139 |
+
poster_code = generate_poster_code(
|
| 140 |
+
[],
|
| 141 |
+
[arrangement],
|
| 142 |
+
[],
|
| 143 |
+
presentation_object_name='poster_presentation',
|
| 144 |
+
slide_object_name='poster_slide',
|
| 145 |
+
utils_functions=utils_functions,
|
| 146 |
+
slide_width=text_arrangement['width'] + 3,
|
| 147 |
+
slide_height=text_arrangement['height'] + 3,
|
| 148 |
+
img_path='placeholder.jpg',
|
| 149 |
+
save_path=f'{tmp_dir}/poster.pptx',
|
| 150 |
+
visible=True,
|
| 151 |
+
content=textbox_content,
|
| 152 |
+
check_overflow=True,
|
| 153 |
+
tmp_dir=tmp_dir,
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
output, err = run_code(poster_code)
|
| 157 |
+
ppt_to_images(f'{tmp_dir}/poster.pptx', tmp_dir, output_type='jpg')
|
| 158 |
+
img = Image.open(f'{tmp_dir}/poster.jpg')
|
| 159 |
+
|
| 160 |
+
return img
|
| 161 |
+
|
| 162 |
+
def gen_poster_title_content(args, actor_config):
|
| 163 |
+
total_input_token, total_output_token = 0, 0
|
| 164 |
+
raw_content = json.load(open(f'contents/<{args.model_name_t}_{args.model_name_v}>_{args.poster_name}_raw_content.json', 'r'))
|
| 165 |
+
actor_agent_name = 'poster_title_agent'
|
| 166 |
+
|
| 167 |
+
title_string = raw_content['meta']
|
| 168 |
+
|
| 169 |
+
with open(f'utils/prompt_templates/{actor_agent_name}.yaml', "r") as f:
|
| 170 |
+
content_config = yaml.safe_load(f)
|
| 171 |
+
jinja_env = Environment(undefined=StrictUndefined)
|
| 172 |
+
template = jinja_env.from_string(content_config["template"])
|
| 173 |
+
|
| 174 |
+
if args.model_name_t == 'vllm_qwen':
|
| 175 |
+
actor_model = ModelFactory.create(
|
| 176 |
+
model_platform=actor_config['model_platform'],
|
| 177 |
+
model_type=actor_config['model_type'],
|
| 178 |
+
model_config_dict=actor_config['model_config'],
|
| 179 |
+
url=actor_config['url'],
|
| 180 |
+
)
|
| 181 |
+
else:
|
| 182 |
+
actor_model = ModelFactory.create(
|
| 183 |
+
model_platform=actor_config['model_platform'],
|
| 184 |
+
model_type=actor_config['model_type'],
|
| 185 |
+
model_config_dict=actor_config['model_config']
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
actor_sys_msg = content_config['system_prompt']
|
| 189 |
+
actor_agent = ChatAgent(
|
| 190 |
+
system_message=actor_sys_msg,
|
| 191 |
+
model=actor_model,
|
| 192 |
+
message_window_size=30
|
| 193 |
+
)
|
| 194 |
+
|
| 195 |
+
jinja_args = {
|
| 196 |
+
'title_string': title_string,
|
| 197 |
+
'title_font_size': getattr(args, 'poster_title_font_size', None) or getattr(args, 'title_font_size', None),
|
| 198 |
+
'author_font_size': getattr(args, 'poster_author_font_size', None) or getattr(args, 'author_font_size', None),
|
| 199 |
+
}
|
| 200 |
+
prompt = template.render(**jinja_args)
|
| 201 |
+
# Step the actor_agent and track tokens
|
| 202 |
+
actor_agent.reset()
|
| 203 |
+
response = actor_agent.step(prompt)
|
| 204 |
+
input_token, output_token = account_token(response)
|
| 205 |
+
total_input_token += input_token
|
| 206 |
+
total_output_token += output_token
|
| 207 |
+
result_json = get_json_from_response(response.msgs[0].content)
|
| 208 |
+
|
| 209 |
+
return result_json, total_input_token, total_output_token
|
| 210 |
+
|
| 211 |
+
def gen_bullet_point_content(args, actor_config, critic_config, agent_modify=True, tmp_dir='tmp'):
|
| 212 |
+
import json, yaml, copy, threading
|
| 213 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 214 |
+
from PIL import Image
|
| 215 |
+
from jinja2 import Environment, StrictUndefined
|
| 216 |
+
|
| 217 |
+
# ----------------------- Load data & configs -----------------------
|
| 218 |
+
total_input_token_t = total_output_token_t = 0
|
| 219 |
+
total_input_token_v = total_output_token_v = 0
|
| 220 |
+
|
| 221 |
+
raw_content = json.load(open(f'contents/<{args.model_name_t}_{args.model_name_v}>_{args.poster_name}_raw_content.json', 'r'))
|
| 222 |
+
with open(f'tree_splits/<{args.model_name_t}_{args.model_name_v}>_{args.poster_name}_tree_split_{args.index}.json', 'r') as f:
|
| 223 |
+
tree_split_results = json.load(f)
|
| 224 |
+
|
| 225 |
+
panels = tree_split_results['panels']
|
| 226 |
+
text_arrangement_list = tree_split_results['text_arrangement_inches']
|
| 227 |
+
|
| 228 |
+
actor_agent_name = 'bullet_point_agent'
|
| 229 |
+
if args.model_name_v == 'vllm_qwen_vl':
|
| 230 |
+
critic_agent_name = 'critic_overlap_agent_v3_short'
|
| 231 |
+
else:
|
| 232 |
+
critic_agent_name = 'critic_overlap_agent_v3'
|
| 233 |
+
|
| 234 |
+
with open(f"utils/prompt_templates/{actor_agent_name}.yaml", "r") as f:
|
| 235 |
+
content_config = yaml.safe_load(f)
|
| 236 |
+
with open(f"utils/prompt_templates/{critic_agent_name}.yaml", "r") as f:
|
| 237 |
+
critic_content_config = yaml.safe_load(f)
|
| 238 |
+
|
| 239 |
+
jinja_env = Environment(undefined=StrictUndefined)
|
| 240 |
+
template = jinja_env.from_string(content_config["template"])
|
| 241 |
+
critic_template = jinja_env.from_string(critic_content_config["template"])
|
| 242 |
+
|
| 243 |
+
# Preload images once (each worker can reopen if needed, or just pass paths)
|
| 244 |
+
neg_img_path = 'assets/overflow_example_v2/neg.jpg'
|
| 245 |
+
pos_img_path = 'assets/overflow_example_v2/pos.jpg'
|
| 246 |
+
|
| 247 |
+
# Group text arrangements by panel_id for O(1) lookup in workers
|
| 248 |
+
from collections import defaultdict
|
| 249 |
+
textboxes_by_panel = defaultdict(list)
|
| 250 |
+
for ta in text_arrangement_list:
|
| 251 |
+
textboxes_by_panel[ta['panel_id']].append(ta)
|
| 252 |
+
# Ensure deterministic order inside each panel
|
| 253 |
+
for k in textboxes_by_panel:
|
| 254 |
+
textboxes_by_panel[k] = sorted(textboxes_by_panel[k], key=lambda x: x.get('textbox_id', 0))
|
| 255 |
+
|
| 256 |
+
# ----------------------- Worker (defined INSIDE main fn) -----------------------
|
| 257 |
+
def _process_section(i):
|
| 258 |
+
"""
|
| 259 |
+
Returns:
|
| 260 |
+
(i, result_json, t_in, t_out, v_in, v_out)
|
| 261 |
+
"""
|
| 262 |
+
local_t_in = local_t_out = 0
|
| 263 |
+
local_v_in = local_v_out = 0
|
| 264 |
+
|
| 265 |
+
arrangement = panels[i]
|
| 266 |
+
num_textboxes = 2 if arrangement.get('gp', 0) > 0 else 1
|
| 267 |
+
|
| 268 |
+
local_tmp_dir = tempfile.mkdtemp(prefix=f"sec_{i}_", dir=tmp_dir)
|
| 269 |
+
|
| 270 |
+
jinja_args = {
|
| 271 |
+
'summary_of_section': raw_content['sections'][i]['content'],
|
| 272 |
+
'number_of_textboxes': num_textboxes,
|
| 273 |
+
'section_title': raw_content['sections'][i]['title'],
|
| 274 |
+
'bullet_font_size': args.bullet_font_size,
|
| 275 |
+
'section_title_font_size': args.section_title_font_size,
|
| 276 |
+
}
|
| 277 |
+
|
| 278 |
+
target_textboxes = textboxes_by_panel[i][1:] # skip first (section title)
|
| 279 |
+
total_expected_length = sum(tb['num_chars'] for tb in target_textboxes)
|
| 280 |
+
|
| 281 |
+
# Create fresh models & agents per thread for safety
|
| 282 |
+
if args.model_name_t.startswith('vllm_qwen'):
|
| 283 |
+
actor_model = ModelFactory.create(
|
| 284 |
+
model_platform=actor_config['model_platform'],
|
| 285 |
+
model_type=actor_config['model_type'],
|
| 286 |
+
model_config_dict=actor_config['model_config'],
|
| 287 |
+
url=actor_config['url'],
|
| 288 |
+
)
|
| 289 |
+
else:
|
| 290 |
+
actor_model = ModelFactory.create(
|
| 291 |
+
model_platform=actor_config['model_platform'],
|
| 292 |
+
model_type=actor_config['model_type'],
|
| 293 |
+
model_config_dict=actor_config['model_config']
|
| 294 |
+
)
|
| 295 |
+
if args.model_name_v.startswith('vllm_qwen'):
|
| 296 |
+
critic_model = ModelFactory.create(
|
| 297 |
+
model_platform=critic_config['model_platform'],
|
| 298 |
+
model_type=critic_config['model_type'],
|
| 299 |
+
model_config_dict=critic_config['model_config'],
|
| 300 |
+
url=critic_config['url'],
|
| 301 |
+
)
|
| 302 |
+
else:
|
| 303 |
+
critic_model = ModelFactory.create(
|
| 304 |
+
model_platform=critic_config['model_platform'],
|
| 305 |
+
model_type=critic_config['model_type'],
|
| 306 |
+
model_config_dict=critic_config['model_config']
|
| 307 |
+
)
|
| 308 |
+
|
| 309 |
+
actor_agent = ChatAgent(system_message=content_config['system_prompt'], model=actor_model, message_window_size=30)
|
| 310 |
+
critic_agent = ChatAgent(system_message=critic_content_config['system_prompt'], model=critic_model, message_window_size=10)
|
| 311 |
+
|
| 312 |
+
prompt = template.render(**jinja_args)
|
| 313 |
+
actor_agent.reset()
|
| 314 |
+
response = actor_agent.step(prompt)
|
| 315 |
+
t_in, t_out = account_token(response)
|
| 316 |
+
local_t_in += t_in
|
| 317 |
+
local_t_out += t_out
|
| 318 |
+
|
| 319 |
+
result_json = get_json_from_response(response.msgs[0].content)
|
| 320 |
+
|
| 321 |
+
max_attempts = 5
|
| 322 |
+
num_attempts = 0
|
| 323 |
+
old_result_json = copy.deepcopy(result_json)
|
| 324 |
+
|
| 325 |
+
# Length control loop
|
| 326 |
+
while args.estimate_chars:
|
| 327 |
+
num_attempts += 1
|
| 328 |
+
if num_attempts > max_attempts:
|
| 329 |
+
result_json = old_result_json
|
| 330 |
+
break
|
| 331 |
+
try:
|
| 332 |
+
total_bullet_length = 0
|
| 333 |
+
for j in range(num_textboxes):
|
| 334 |
+
bullet_content_key = f'textbox{j + 1}'
|
| 335 |
+
total_bullet_length += compute_bullet_length(result_json[bullet_content_key])
|
| 336 |
+
except Exception:
|
| 337 |
+
result_json = old_result_json
|
| 338 |
+
break
|
| 339 |
+
|
| 340 |
+
if total_bullet_length > total_expected_length:
|
| 341 |
+
percentage_to_shrink = int((total_bullet_length - total_expected_length) / total_bullet_length * 100)
|
| 342 |
+
percentage_to_shrink = min(90, percentage_to_shrink + 10)
|
| 343 |
+
old_result_json = copy.deepcopy(result_json)
|
| 344 |
+
response = actor_agent.step('Too long, please shorten the bullet points by ' + str(percentage_to_shrink) + '%.')
|
| 345 |
+
t_in, t_out = account_token(response)
|
| 346 |
+
local_t_in += t_in
|
| 347 |
+
local_t_out += t_out
|
| 348 |
+
result_json = get_json_from_response(response.msgs[0].content)
|
| 349 |
+
else:
|
| 350 |
+
break
|
| 351 |
+
|
| 352 |
+
critic_prompt = critic_template.render()
|
| 353 |
+
bullet_contents = ['textbox1'] + (['textbox2'] if num_textboxes == 2 else [])
|
| 354 |
+
|
| 355 |
+
# Visual overflow/blank detection & correction
|
| 356 |
+
for j, text_arrangement in enumerate(target_textboxes[:num_textboxes]):
|
| 357 |
+
bullet_content = bullet_contents[j]
|
| 358 |
+
curr_round = 0
|
| 359 |
+
while True:
|
| 360 |
+
if args.ablation_no_commenter:
|
| 361 |
+
break
|
| 362 |
+
curr_round += 1
|
| 363 |
+
img = render_textbox(text_arrangement, result_json[bullet_content], local_tmp_dir)
|
| 364 |
+
if args.model_name_v.startswith('vllm_qwen') or args.ablation_no_example:
|
| 365 |
+
critic_msg = BaseMessage.make_user_message(
|
| 366 |
+
role_name="User",
|
| 367 |
+
content=critic_prompt,
|
| 368 |
+
image_list=[img],
|
| 369 |
+
)
|
| 370 |
+
else:
|
| 371 |
+
critic_msg = BaseMessage.make_user_message(
|
| 372 |
+
role_name="User",
|
| 373 |
+
content=critic_prompt,
|
| 374 |
+
image_list=[Image.open(neg_img_path), Image.open(pos_img_path), img],
|
| 375 |
+
)
|
| 376 |
+
|
| 377 |
+
critic_agent.reset()
|
| 378 |
+
response = critic_agent.step(critic_msg)
|
| 379 |
+
v_in, v_out = account_token(response)
|
| 380 |
+
local_v_in += v_in
|
| 381 |
+
local_v_out += v_out
|
| 382 |
+
|
| 383 |
+
decision = response.msgs[0].content.lower()
|
| 384 |
+
if decision in ['1', '1.', '"1"', "'1'"]:
|
| 385 |
+
if curr_round > 10:
|
| 386 |
+
print(f'Section {i}: Too many rounds of modification, breaking...')
|
| 387 |
+
break
|
| 388 |
+
if agent_modify:
|
| 389 |
+
print(f'Section {i}: Text overflow detected, modifying...')
|
| 390 |
+
modify_message = f'{bullet_content} is too long, please shorten that part, other content should stay the same. Return the entire modified JSON.'
|
| 391 |
+
response = actor_agent.step(modify_message)
|
| 392 |
+
t_in, t_out = account_token(response)
|
| 393 |
+
local_t_in += t_in
|
| 394 |
+
local_t_out += t_out
|
| 395 |
+
result_json = get_json_from_response(response.msgs[0].content)
|
| 396 |
+
else:
|
| 397 |
+
# naive truncate
|
| 398 |
+
result_json[bullet_content] = result_json[bullet_content][:-1]
|
| 399 |
+
continue
|
| 400 |
+
elif decision in ['2', '2.', '"2"', "'2'"]:
|
| 401 |
+
if args.no_blank_detection:
|
| 402 |
+
print(f'Section {i}: No blank space detection, skipping...')
|
| 403 |
+
break
|
| 404 |
+
if curr_round > 10:
|
| 405 |
+
print(f'Section {i}: Too many rounds of modification, breaking...')
|
| 406 |
+
break
|
| 407 |
+
print(f'Section {i}: Too much blank space detected, modifying...')
|
| 408 |
+
modify_message = f'{bullet_content} is too short, please add one more bullet point, other content should stay the same. Return the entire modified JSON.'
|
| 409 |
+
response = actor_agent.step(modify_message)
|
| 410 |
+
t_in, t_out = account_token(response)
|
| 411 |
+
local_t_in += t_in
|
| 412 |
+
local_t_out += t_out
|
| 413 |
+
result_json = get_json_from_response(response.msgs[0].content)
|
| 414 |
+
else:
|
| 415 |
+
break
|
| 416 |
+
|
| 417 |
+
# Clean up temp dir
|
| 418 |
+
if local_tmp_dir:
|
| 419 |
+
try:
|
| 420 |
+
print(f'Section {i}: Cleaning up temp dir {local_tmp_dir}')
|
| 421 |
+
shutil.rmtree(local_tmp_dir)
|
| 422 |
+
except Exception as e:
|
| 423 |
+
print(f"Error cleaning up temp dir {local_tmp_dir}: {e}")
|
| 424 |
+
return i, result_json, local_t_in, local_t_out, local_v_in, local_v_out
|
| 425 |
+
|
| 426 |
+
# ----------------------- Parallel execution -----------------------
|
| 427 |
+
max_workers = getattr(args, 'max_workers', 4)
|
| 428 |
+
results = {}
|
| 429 |
+
lock = threading.Lock()
|
| 430 |
+
|
| 431 |
+
with ThreadPoolExecutor(max_workers=max_workers) as ex:
|
| 432 |
+
futures = {
|
| 433 |
+
ex.submit(_process_section, i): i
|
| 434 |
+
for i in range(1, len(raw_content['sections']))
|
| 435 |
+
}
|
| 436 |
+
for fut in as_completed(futures):
|
| 437 |
+
i, rjson, t_in, t_out, v_in, v_out = fut.result()
|
| 438 |
+
with lock:
|
| 439 |
+
results[i] = rjson
|
| 440 |
+
total_input_token_t += t_in
|
| 441 |
+
total_output_token_t += t_out
|
| 442 |
+
total_input_token_v += v_in
|
| 443 |
+
total_output_token_v += v_out
|
| 444 |
+
|
| 445 |
+
# ----------------------- Title generation (sequential) -----------------------
|
| 446 |
+
title_json, title_input_token, title_output_token = gen_poster_title_content(args, actor_config)
|
| 447 |
+
total_input_token_t += title_input_token
|
| 448 |
+
total_output_token_t += title_output_token
|
| 449 |
+
|
| 450 |
+
# ----------------------- Assemble & save -----------------------
|
| 451 |
+
bullet_point_content = [title_json]
|
| 452 |
+
for idx in range(1, len(raw_content['sections'])):
|
| 453 |
+
bullet_point_content.append(results[idx])
|
| 454 |
+
|
| 455 |
+
json.dump(
|
| 456 |
+
bullet_point_content,
|
| 457 |
+
open(f'contents/<{args.model_name_t}_{args.model_name_v}>_{args.poster_name}_bullet_point_content_{args.index}.json', 'w'),
|
| 458 |
+
indent=2
|
| 459 |
+
)
|
| 460 |
+
|
| 461 |
+
return total_input_token_t, total_output_token_t, total_input_token_v, total_output_token_v
|
| 462 |
+
|
| 463 |
+
def gen_poster_content(args, actor_config):
|
| 464 |
+
total_input_token, total_output_token = 0, 0
|
| 465 |
+
raw_content = json.load(open(f'contents/{args.model_name}_{args.poster_name}_raw_content.json', 'r'))
|
| 466 |
+
agent_name = 'poster_content_agent'
|
| 467 |
+
|
| 468 |
+
with open(f"utils/prompt_templates/{agent_name}.yaml", "r") as f:
|
| 469 |
+
content_config = yaml.safe_load(f)
|
| 470 |
+
|
| 471 |
+
actor_model = ModelFactory.create(
|
| 472 |
+
model_platform=actor_config['model_platform'],
|
| 473 |
+
model_type=actor_config['model_type'],
|
| 474 |
+
model_config_dict=actor_config['model_config']
|
| 475 |
+
)
|
| 476 |
+
|
| 477 |
+
actor_sys_msg = content_config['system_prompt']
|
| 478 |
+
|
| 479 |
+
def create_actor_agent():
|
| 480 |
+
actor_agent = ChatAgent(
|
| 481 |
+
system_message=actor_sys_msg,
|
| 482 |
+
model=actor_model,
|
| 483 |
+
message_window_size=10
|
| 484 |
+
)
|
| 485 |
+
return actor_agent
|
| 486 |
+
|
| 487 |
+
outline = json.load(open(f'outlines/{args.model_name}_{args.poster_name}_outline_{args.index}.json', 'r'))
|
| 488 |
+
raw_outline = json.loads(json.dumps(outline))
|
| 489 |
+
outline_estimate_num_chars(outline)
|
| 490 |
+
outline = remove_hierarchy_and_id(outline)
|
| 491 |
+
|
| 492 |
+
sections = list(outline.keys())
|
| 493 |
+
sections = [s for s in sections if s != 'meta']
|
| 494 |
+
|
| 495 |
+
jinja_env = Environment(undefined=StrictUndefined)
|
| 496 |
+
|
| 497 |
+
template = jinja_env.from_string(content_config["template"])
|
| 498 |
+
|
| 499 |
+
poster_content = {}
|
| 500 |
+
|
| 501 |
+
poster_content, total_input_token, total_output_token = gen_content_parallel_process_sections(
|
| 502 |
+
sections,
|
| 503 |
+
outline,
|
| 504 |
+
raw_content,
|
| 505 |
+
raw_outline,
|
| 506 |
+
template,
|
| 507 |
+
create_actor_agent,
|
| 508 |
+
MAX_ATTEMPT=5
|
| 509 |
+
)
|
| 510 |
+
|
| 511 |
+
json.dump(poster_content, open(f'contents/{args.model_name}_{args.poster_name}_poster_content_{args.index}.json', 'w'), indent=2)
|
| 512 |
+
return total_input_token, total_output_token
|
| 513 |
+
|
| 514 |
+
if __name__ == '__main__':
|
| 515 |
+
parser = argparse.ArgumentParser()
|
| 516 |
+
parser.add_argument('--poster_name', type=str, default=None)
|
| 517 |
+
parser.add_argument('--model_name', type=str, default='4o')
|
| 518 |
+
parser.add_argument('--poster_path', type=str, required=True)
|
| 519 |
+
parser.add_argument('--index', type=int, default=0)
|
| 520 |
+
parser.add_argument('--max_retry', type=int, default=3)
|
| 521 |
+
args = parser.parse_args()
|
| 522 |
+
|
| 523 |
+
actor_config = get_agent_config(args.model_name)
|
| 524 |
+
if args.poster_name is None:
|
| 525 |
+
args.poster_name = args.poster_path.split('/')[-1].replace('.pdf', '').replace(' ', '_')
|
| 526 |
+
|
| 527 |
+
input_token, output_token = gen_poster_content(args, actor_config)
|
| 528 |
+
|
| 529 |
+
print(f'Token consumption: {input_token} -> {output_token}')
|
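Note on gen_poster_content.py above: both gen_content_parallel_process_sections and gen_bullet_point_content fan sections out with ThreadPoolExecutor and sum token counts as futures complete. A self-contained sketch of that submit/as_completed pattern, with a dummy worker standing in for the actual agent call (all names and numbers here are illustrative):

    from concurrent.futures import ThreadPoolExecutor, as_completed

    def process_section(name):
        # Stand-in for the per-section actor-agent call; mirrors the real worker's
        # (section_name, result_json, input_tokens, output_tokens) return shape.
        return name, {'content': f'bullets for {name}'}, 100, 40

    sections = ['Introduction', 'Method', 'Results']
    poster_content, total_in, total_out = {}, 0, 0
    with ThreadPoolExecutor() as pool:
        futures = [pool.submit(process_section, s) for s in sections]
        for fut in as_completed(futures):
            name, result, t_in, t_out = fut.result()
            poster_content[name] = result
            total_in += t_in
            total_out += t_out
    print(total_in, total_out)  # 300 120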
Paper2Poster/PosterAgent/gen_pptx_code.py
ADDED
|
@@ -0,0 +1,249 @@
| 1 |
+
import re
|
| 2 |
+
import json
|
| 3 |
+
|
| 4 |
+
def sanitize_for_var(name):
|
| 5 |
+
# Convert any character that is not alphanumeric or underscore into underscore.
|
| 6 |
+
return re.sub(r'[^0-9a-zA-Z_]+', '_', name)
|
| 7 |
+
|
| 8 |
+
def initialize_poster_code(width, height, slide_object_name, presentation_object_name, utils_functions):
|
| 9 |
+
code = utils_functions
|
| 10 |
+
code += fr'''
|
| 11 |
+
# Poster: {presentation_object_name}
|
| 12 |
+
{presentation_object_name} = create_poster(width_inch={width}, height_inch={height})
|
| 13 |
+
|
| 14 |
+
# Slide: {slide_object_name}
|
| 15 |
+
{slide_object_name} = add_blank_slide({presentation_object_name})
|
| 16 |
+
'''
|
| 17 |
+
|
| 18 |
+
return code
|
| 19 |
+
|
| 20 |
+
def save_poster_code(output_file, utils_functions, presentation_object_name):
|
| 21 |
+
code = utils_functions
|
| 22 |
+
# Append the save stanza rather than overwriting the prelude, matching the other generators
code += fr'''
|
| 23 |
+
# Save the presentation
|
| 24 |
+
save_presentation({presentation_object_name}, file_name="{output_file}")
|
| 25 |
+
'''
|
| 26 |
+
return code
|
| 27 |
+
|
| 28 |
+
def generate_panel_code(panel_dict, utils_functions, slide_object_name, visible=False, theme=None):
|
| 29 |
+
code = utils_functions
|
| 30 |
+
raw_name = panel_dict["panel_name"]
|
| 31 |
+
var_name = 'var_' + sanitize_for_var(raw_name)
|
| 32 |
+
|
| 33 |
+
code += fr'''
|
| 34 |
+
# Panel: {raw_name}
|
| 35 |
+
{var_name} = add_textbox(
|
| 36 |
+
{slide_object_name},
|
| 37 |
+
'{var_name}',
|
| 38 |
+
{panel_dict['x']},
|
| 39 |
+
{panel_dict['y']},
|
| 40 |
+
{panel_dict['width']},
|
| 41 |
+
{panel_dict['height']},
|
| 42 |
+
text="",
|
| 43 |
+
word_wrap=True,
|
| 44 |
+
font_size=40,
|
| 45 |
+
bold=False,
|
| 46 |
+
italic=False,
|
| 47 |
+
alignment="left",
|
| 48 |
+
fill_color=None,
|
| 49 |
+
font_name="Arial"
|
| 50 |
+
)
|
| 51 |
+
'''
|
| 52 |
+
|
| 53 |
+
if visible:
|
| 54 |
+
if theme is None:
|
| 55 |
+
code += fr'''
|
| 56 |
+
# Make border visible
|
| 57 |
+
style_shape_border({var_name}, color=(0, 0, 0), thickness=5, line_style="solid")
|
| 58 |
+
'''
|
| 59 |
+
else:
|
| 60 |
+
code += fr'''
|
| 61 |
+
# Make border visible
|
| 62 |
+
style_shape_border({var_name}, color={theme['color']}, thickness={theme['thickness']}, line_style="{theme['line_style']}")
|
| 63 |
+
'''
|
| 64 |
+
|
| 65 |
+
return code
|
| 66 |
+
|
| 67 |
+
def generate_textbox_code(
|
| 68 |
+
text_dict,
|
| 69 |
+
utils_functions,
|
| 70 |
+
slide_object_name,
|
| 71 |
+
visible=False,
|
| 72 |
+
content=None,
|
| 73 |
+
theme=None,
|
| 74 |
+
tmp_dir='tmp',
|
| 75 |
+
is_title=False,
|
| 76 |
+
):
|
| 77 |
+
code = utils_functions
|
| 78 |
+
raw_name = text_dict["textbox_name"]
|
| 79 |
+
var_name = sanitize_for_var(raw_name)
|
| 80 |
+
|
| 81 |
+
code += fr'''
|
| 82 |
+
# Textbox: {raw_name}
|
| 83 |
+
{var_name} = add_textbox(
|
| 84 |
+
{slide_object_name},
|
| 85 |
+
'{var_name}',
|
| 86 |
+
{text_dict['x']},
|
| 87 |
+
{text_dict['y']},
|
| 88 |
+
{text_dict['width']},
|
| 89 |
+
{text_dict['height']},
|
| 90 |
+
text="",
|
| 91 |
+
word_wrap=True,
|
| 92 |
+
font_size=40,
|
| 93 |
+
bold=False,
|
| 94 |
+
italic=False,
|
| 95 |
+
alignment="left",
|
| 96 |
+
fill_color=None,
|
| 97 |
+
font_name="Arial"
|
| 98 |
+
)
|
| 99 |
+
'''
|
| 100 |
+
if visible:
|
| 101 |
+
# Extract textbox_theme from full theme if needed
|
| 102 |
+
textbox_border_theme = None
|
| 103 |
+
if theme is not None and isinstance(theme, dict):
|
| 104 |
+
textbox_border_theme = theme.get('textbox_theme')
|
| 105 |
+
|
| 106 |
+
if textbox_border_theme is None:
|
| 107 |
+
code += fr'''
|
| 108 |
+
# Make border visible
|
| 109 |
+
style_shape_border({var_name}, color=(255, 0, 0), thickness=5, line_style="solid")
|
| 110 |
+
'''
|
| 111 |
+
else:
|
| 112 |
+
code += fr'''
|
| 113 |
+
# Make border visible
|
| 114 |
+
style_shape_border({var_name}, color={textbox_border_theme['color']}, thickness={textbox_border_theme['thickness']}, line_style="{textbox_border_theme['line_style']}")
|
| 115 |
+
'''
|
| 116 |
+
|
| 117 |
+
if content is not None:
|
| 118 |
+
tmp_name = f'{tmp_dir}/{var_name}_content.json'
|
| 119 |
+
json.dump(content, open(tmp_name, 'w'), indent=4)
|
| 120 |
+
|
| 121 |
+
# Determine vertical alignment
|
| 122 |
+
vertical_anchor = None
|
| 123 |
+
if is_title and theme is not None and 'section_title_vertical_align' in theme:
|
| 124 |
+
vertical_anchor = theme['section_title_vertical_align']
|
| 125 |
+
|
| 126 |
+
if vertical_anchor:
|
| 127 |
+
code += fr'''
|
| 128 |
+
fill_textframe({var_name}, json.load(open('{tmp_name}', 'r')), vertical_anchor="{vertical_anchor}")
|
| 129 |
+
'''
|
| 130 |
+
else:
|
| 131 |
+
code += fr'''
|
| 132 |
+
fill_textframe({var_name}, json.load(open('{tmp_name}', 'r')))
|
| 133 |
+
'''
|
| 134 |
+
|
| 135 |
+
return code
|
| 136 |
+
|
| 137 |
+
def generate_figure_code(figure_dict, utils_functions, slide_object_name, img_path, visible=False, theme=None):
|
| 138 |
+
code = utils_functions
|
| 139 |
+
raw_name = figure_dict["figure_name"]
|
| 140 |
+
var_name = sanitize_for_var(raw_name)
|
| 141 |
+
|
| 142 |
+
code += fr'''
|
| 143 |
+
# Figure: {raw_name}
|
| 144 |
+
{var_name} = add_image(
|
| 145 |
+
{slide_object_name},
|
| 146 |
+
'{var_name}',
|
| 147 |
+
{figure_dict['x']},
|
| 148 |
+
{figure_dict['y']},
|
| 149 |
+
{figure_dict['width']},
|
| 150 |
+
{figure_dict['height']},
|
| 151 |
+
image_path="{img_path}"
|
| 152 |
+
)
|
| 153 |
+
'''
|
| 154 |
+
|
| 155 |
+
if visible:
|
| 156 |
+
if theme is None:
|
| 157 |
+
code += fr'''
|
| 158 |
+
# Make border visible
|
| 159 |
+
style_shape_border({var_name}, color=(0, 0, 255), thickness=5, line_style="long_dash_dot")
|
| 160 |
+
'''
|
| 161 |
+
else:
|
| 162 |
+
code += fr'''
|
| 163 |
+
# Make border visible
|
| 164 |
+
style_shape_border({var_name}, color={theme['color']}, thickness={theme['thickness']}, line_style="{theme['line_style']}")
|
| 165 |
+
'''
|
| 166 |
+
|
| 167 |
+
return code
|
| 168 |
+
|
| 169 |
+
def generate_poster_code(
|
| 170 |
+
panel_arrangement_list,
|
| 171 |
+
text_arrangement_list,
|
| 172 |
+
figure_arrangement_list,
|
| 173 |
+
presentation_object_name,
|
| 174 |
+
slide_object_name,
|
| 175 |
+
utils_functions,
|
| 176 |
+
slide_width,
|
| 177 |
+
slide_height,
|
| 178 |
+
img_path,
|
| 179 |
+
save_path,
|
| 180 |
+
visible=False,
|
| 181 |
+
content=None,
|
| 182 |
+
check_overflow=False,
|
| 183 |
+
theme=None,
|
| 184 |
+
tmp_dir='tmp',
|
| 185 |
+
):
|
| 186 |
+
code = ''
|
| 187 |
+
code += initialize_poster_code(slide_width, slide_height, slide_object_name, presentation_object_name, utils_functions)
|
| 188 |
+
|
| 189 |
+
if theme is None:
|
| 190 |
+
panel_visible = visible
|
| 191 |
+
textbox_visible = visible
|
| 192 |
+
figure_visible = visible
|
| 193 |
+
|
| 194 |
+
panel_theme, textbox_theme, figure_theme = None, None, None
|
| 195 |
+
else:
|
| 196 |
+
panel_visible = theme['panel_visible']
|
| 197 |
+
textbox_visible = theme['textbox_visible']
|
| 198 |
+
figure_visible = theme['figure_visible']
|
| 199 |
+
panel_theme = theme['panel_theme']
|
| 200 |
+
textbox_theme = theme['textbox_theme']
|
| 201 |
+
figure_theme = theme['figure_theme']
|
| 202 |
+
|
| 203 |
+
for p in panel_arrangement_list:
|
| 204 |
+
code += generate_panel_code(p, '', slide_object_name, panel_visible, panel_theme)
|
| 205 |
+
|
| 206 |
+
if check_overflow:
|
| 207 |
+
t = text_arrangement_list[0]
|
| 208 |
+
# Pass full theme for consistency
|
| 209 |
+
code += generate_textbox_code(t, '', slide_object_name, textbox_visible, content, theme, tmp_dir, is_title=False)
|
| 210 |
+
else:
|
| 211 |
+
all_content = []
|
| 212 |
+
title_indices = set() # Track which indices are section titles
|
| 213 |
+
if content is not None:
|
| 214 |
+
idx = 0
|
| 215 |
+
for section_content in content:
|
| 216 |
+
if 'title' in section_content:
|
| 217 |
+
all_content.append(section_content['title'])
|
| 218 |
+
title_indices.add(idx) # Mark this index as a title
|
| 219 |
+
idx += 1
|
| 220 |
+
if len(section_content) == 2:
|
| 221 |
+
all_content.append(section_content['textbox1'])
|
| 222 |
+
idx += 1
|
| 223 |
+
elif len(section_content) == 3:
|
| 224 |
+
all_content.append(section_content['textbox1'])
|
| 225 |
+
all_content.append(section_content['textbox2'])
|
| 226 |
+
idx += 2
|
| 227 |
+
else:
|
| 228 |
+
raise ValueError(f"Unexpected content length: {len(section_content)}")
|
| 229 |
+
|
| 230 |
+
for i in range(len(text_arrangement_list)):
|
| 231 |
+
t = text_arrangement_list[i]
|
| 232 |
+
if content is not None:
|
| 233 |
+
textbox_content = all_content[i]
|
| 234 |
+
is_title = i in title_indices
|
| 235 |
+
else:
|
| 236 |
+
textbox_content = None
|
| 237 |
+
is_title = False
|
| 238 |
+
# Pass full theme (not textbox_theme) so vertical alignment config is available
|
| 239 |
+
code += generate_textbox_code(t, '', slide_object_name, textbox_visible, textbox_content, theme, tmp_dir, is_title=is_title)
|
| 240 |
+
|
| 241 |
+
for f in figure_arrangement_list:
|
| 242 |
+
if img_path is None:
|
| 243 |
+
code += generate_figure_code(f, '', slide_object_name, f['figure_path'], figure_visible, figure_theme)
|
| 244 |
+
else:
|
| 245 |
+
code += generate_figure_code(f, '', slide_object_name, img_path, figure_visible, figure_theme)
|
| 246 |
+
|
| 247 |
+
code += save_poster_code(save_path, '', presentation_object_name)
|
| 248 |
+
|
| 249 |
+
return code
|
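Note on gen_pptx_code.py above: each generate_* helper appends a python-pptx-style snippet to one big script string, which the pipeline later executes (see run_code in render_textbox). sanitize_for_var is what keeps the generated variable names valid Python; a quick self-contained check of its behavior:

    import re

    def sanitize_for_var(name):
        # Same regex as above: any run of non-alphanumeric, non-underscore characters becomes '_'
        return re.sub(r'[^0-9a-zA-Z_]+', '_', name)

    print(sanitize_for_var('Results & Ablations (v2)'))  # Results_Ablations_v2_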
Paper2Poster/PosterAgent/new_pipeline.py
ADDED
|
@@ -0,0 +1,547 @@
| 1 |
+
import os
|
| 2 |
+
print("Initializing...")
|
| 3 |
+
from PosterAgent.parse_raw import parse_raw, gen_image_and_table
|
| 4 |
+
from PosterAgent.gen_outline_layout import filter_image_table, gen_outline_layout_v2
|
| 5 |
+
from utils.wei_utils import get_agent_config, utils_functions, run_code, scale_to_target_area, char_capacity
|
| 6 |
+
from PosterAgent.tree_split_layout import main_train, main_inference, get_arrangments_in_inches, split_textbox, to_inches
|
| 7 |
+
# from PosterAgent.gen_pptx_code import generate_poster_code
|
| 8 |
+
# from utils.src.utils import ppt_to_images
|
| 9 |
+
# from PosterAgent.gen_poster_content import gen_bullet_point_content
|
| 10 |
+
from utils.ablation_utils import no_tree_get_layout
|
| 11 |
+
|
| 12 |
+
# Import refactored utilities
|
| 13 |
+
from utils.logo_utils import LogoManager, add_logos_to_poster_code
|
| 14 |
+
# from utils.config_utils import (
|
| 15 |
+
# load_poster_yaml_config, extract_font_sizes, extract_colors,
|
| 16 |
+
# extract_vertical_alignment, extract_section_title_symbol, normalize_config_values
|
| 17 |
+
# )
|
| 18 |
+
# from utils.style_utils import apply_all_styles
|
| 19 |
+
# from utils.theme_utils import get_default_theme, create_theme_with_alignment, resolve_colors
|
| 20 |
+
|
| 21 |
+
# from PosterAgent.gen_beamer_code import (
|
| 22 |
+
# generate_beamer_poster_code,
|
| 23 |
+
# save_beamer_code,
|
| 24 |
+
# compile_beamer_to_pdf,
|
| 25 |
+
# convert_pptx_layout_to_beamer
|
| 26 |
+
# )
|
| 27 |
+
|
| 28 |
+
import argparse
|
| 29 |
+
import json
|
| 30 |
+
|
| 31 |
+
import time
|
| 32 |
+
import shutil
|
| 33 |
+
|
| 34 |
+
units_per_inch = 25
|
| 35 |
+
|
| 36 |
+
if __name__ == '__main__':
|
| 37 |
+
|
| 38 |
+
parser = argparse.ArgumentParser(description='Poster Generation Pipeline with Logo Support')
|
| 39 |
+
parser.add_argument('--poster_path', type=str)
|
| 40 |
+
parser.add_argument('--model_name_t', type=str, default='4o')
|
| 41 |
+
parser.add_argument('--model_name_v', type=str, default='4o')
|
| 42 |
+
parser.add_argument('--index', type=int, default=0)
|
| 43 |
+
parser.add_argument('--poster_name', type=str, default=None)
|
| 44 |
+
parser.add_argument('--tmp_dir', type=str, default='tmp')
|
| 45 |
+
parser.add_argument('--estimate_chars', action='store_true')
|
| 46 |
+
parser.add_argument('--max_workers', type=int, default=10)
|
| 47 |
+
parser.add_argument('--poster_width_inches', type=int, default=None)
|
| 48 |
+
parser.add_argument('--poster_height_inches', type=int, default=None)
|
| 49 |
+
parser.add_argument('--no_blank_detection', action='store_true', help='When overflow is severe, try this option.')
|
| 50 |
+
parser.add_argument('--ablation_no_tree_layout', action='store_true', help='Ablation study: no tree layout')
|
| 51 |
+
parser.add_argument('--ablation_no_commenter', action='store_true', help='Ablation study: no commenter')
|
| 52 |
+
parser.add_argument('--ablation_no_example', action='store_true', help='Ablation study: no example')
|
| 53 |
+
|
| 54 |
+
# Logo-related arguments
|
| 55 |
+
parser.add_argument('--conference_venue', type=str, default=None,
|
| 56 |
+
help='Conference name for automatic logo search (e.g., "NeurIPS", "CVPR")')
|
| 57 |
+
parser.add_argument('--institution_logo_path', type=str, default=None,
|
| 58 |
+
help='Custom path to institution logo (auto-searches from paper metadata if not provided)')
|
| 59 |
+
parser.add_argument('--conference_logo_path', type=str, default=None,
|
| 60 |
+
help='Custom path to conference logo (auto-searches if venue specified)')
|
| 61 |
+
parser.add_argument('--use_google_search', action='store_true',
|
| 62 |
+
help='Use Google Custom Search API for logo search (requires API keys in .env)')
|
| 63 |
+
|
| 64 |
+
args = parser.parse_args()
|
| 65 |
+
|
| 66 |
+
start_time = time.time()
|
| 67 |
+
|
| 68 |
+
os.makedirs(args.tmp_dir, exist_ok=True)
|
| 69 |
+
|
| 70 |
+
detail_log = {}
|
| 71 |
+
|
| 72 |
+
agent_config_t = get_agent_config(args.model_name_t)
|
| 73 |
+
agent_config_v = get_agent_config(args.model_name_v)
|
| 74 |
+
poster_name = args.poster_path.split('/')[-2].replace(' ', '_')
|
| 75 |
+
if args.poster_name is None:
|
| 76 |
+
args.poster_name = poster_name
|
| 77 |
+
else:
|
| 78 |
+
poster_name = args.poster_name
|
| 79 |
+
meta_json_path = args.poster_path.replace('paper.pdf', 'meta.json')
|
| 80 |
+
if args.poster_width_inches is not None and args.poster_height_inches is not None:
|
| 81 |
+
poster_width = args.poster_width_inches * units_per_inch
|
| 82 |
+
poster_height = args.poster_height_inches * units_per_inch
|
| 83 |
+
elif os.path.exists(meta_json_path):
|
| 84 |
+
meta_json = json.load(open(meta_json_path, 'r'))
|
| 85 |
+
poster_width = meta_json['width']
|
| 86 |
+
poster_height = meta_json['height']
|
| 87 |
+
else:
|
| 88 |
+
poster_width = 48 * units_per_inch
|
| 89 |
+
poster_height = 36 * units_per_inch
|
| 90 |
+
|
| 91 |
+
poster_width, poster_height = scale_to_target_area(poster_width, poster_height)
|
| 92 |
+
poster_width_inches = to_inches(poster_width, units_per_inch)
|
| 93 |
+
poster_height_inches = to_inches(poster_height, units_per_inch)
|
| 94 |
+
|
| 95 |
+
if poster_width_inches > 56 or poster_height_inches > 56:
|
| 96 |
+
# Work out which side is longer, then compute a single scale factor
|
| 97 |
+
if poster_width_inches >= poster_height_inches:
|
| 98 |
+
scale_factor = 56 / poster_width_inches
|
| 99 |
+
else:
|
| 100 |
+
scale_factor = 56 / poster_height_inches
|
| 101 |
+
|
| 102 |
+
poster_width_inches *= scale_factor
|
| 103 |
+
poster_height_inches *= scale_factor
|
| 104 |
+
|
| 105 |
+
# convert back to internal units
|
| 106 |
+
poster_width = poster_width_inches * units_per_inch
|
| 107 |
+
poster_height = poster_height_inches * units_per_inch
|
| 108 |
+
|
| 109 |
+
print(f'Poster size: {poster_width_inches} x {poster_height_inches} inches')
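A worked sketch of the 56-inch cap applied just above: both sides are scaled by a single factor so the longer side lands on 56 inches and the aspect ratio is preserved (cap_to_max_side is an illustrative name, not a function from this repo):

    def cap_to_max_side(width_in, height_in, max_side=56):
        # Scale both dimensions by one factor so the longer side equals max_side
        longest = max(width_in, height_in)
        if longest <= max_side:
            return width_in, height_in
        factor = max_side / longest
        return width_in * factor, height_in * factor

    print(cap_to_max_side(72, 48))  # (56.0, 37.33...)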
|
| 110 |
+
|
| 111 |
+
total_input_tokens_t, total_output_tokens_t = 0, 0
|
| 112 |
+
total_input_tokens_v, total_output_tokens_v = 0, 0
|
| 113 |
+
|
| 114 |
+
# Step 1: Parse the raw poster
|
| 115 |
+
input_token, output_token, raw_result = parse_raw(args, agent_config_t, version=2)
|
| 116 |
+
total_input_tokens_t += input_token
|
| 117 |
+
total_output_tokens_t += output_token
|
| 118 |
+
|
| 119 |
+
_, _, images, tables = gen_image_and_table(args, raw_result)
|
| 120 |
+
|
| 121 |
+
print(f'Parsing token consumption: {input_token} -> {output_token}')
|
| 122 |
+
|
| 123 |
+
parser_time_taken = time.time() - start_time
|
| 124 |
+
print(f'Parser time: {parser_time_taken:.2f} seconds')
|
| 125 |
+
detail_log['parser_time'] = parser_time_taken
|
| 126 |
+
|
| 127 |
+
parser_time = time.time()
|
| 128 |
+
|
| 129 |
+
detail_log['parser_in_t'] = input_token
|
| 130 |
+
detail_log['parser_out_t'] = output_token
|
| 131 |
+
|
| 132 |
+
# Initialize LogoManager
|
| 133 |
+
logo_manager = LogoManager()
|
| 134 |
+
institution_logo_path = args.institution_logo_path
|
| 135 |
+
conference_logo_path = args.conference_logo_path
|
| 136 |
+
|
| 137 |
+
# Auto-detect institution from paper if not provided
|
| 138 |
+
# Now using the raw_result directly instead of reading from file
|
| 139 |
+
if not institution_logo_path:
|
| 140 |
+
print("\n" + "="*60)
|
| 141 |
+
print("🔍 AUTO-DETECTING INSTITUTION FROM PAPER")
|
| 142 |
+
print("="*60)
|
| 143 |
+
|
| 144 |
+
# Use the raw_result we already have from the parser
|
| 145 |
+
if raw_result:
|
| 146 |
+
print(f"📄 Using parsed paper content")
|
| 147 |
+
# Extract text content from the ConversionResult object
|
| 148 |
+
try:
|
| 149 |
+
paper_text = raw_result.document.export_to_markdown()
|
| 150 |
+
except Exception:
|
| 151 |
+
# Fallback: try to get text content in another way
|
| 152 |
+
paper_text = str(raw_result)
|
| 153 |
+
|
| 154 |
+
print("🔎 Searching for FIRST AUTHOR's institution...")
|
| 155 |
+
first_author_inst = logo_manager.extract_first_author_institution(paper_text)
|
| 156 |
+
|
| 157 |
+
if first_author_inst:
|
| 158 |
+
print(f"\n✅ FIRST AUTHOR INSTITUTION: {first_author_inst}")
|
| 159 |
+
print(f"🔍 Searching for logo: {first_author_inst}")
|
| 160 |
+
|
| 161 |
+
inst_logo_path = logo_manager.get_logo_path(first_author_inst, category="institute", use_google=args.use_google_search)
|
| 162 |
+
if inst_logo_path:
|
| 163 |
+
institution_logo_path = str(inst_logo_path)
|
| 164 |
+
print(f"✅ Institution logo found: {institution_logo_path}")
|
| 165 |
+
else:
|
| 166 |
+
print(f"❌ Could not find/download logo for: {first_author_inst}")
|
| 167 |
+
else:
|
| 168 |
+
print("❌ No first author institution detected or matched with available logos")
|
| 169 |
+
else:
|
| 170 |
+
print("❌ No parsed content available")
|
| 171 |
+
print("="*60 + "\n")
|
| 172 |
+
|
| 173 |
+
# Handle conference logo
|
| 174 |
+
if args.conference_venue and not conference_logo_path:
|
| 175 |
+
print("\n" + "="*60)
|
| 176 |
+
print("🏛️ SEARCHING FOR CONFERENCE LOGO")
|
| 177 |
+
print("="*60)
|
| 178 |
+
print(f"📍 Conference: {args.conference_venue}")
|
| 179 |
+
print(f"🔍 Searching for logo...")
|
| 180 |
+
|
| 181 |
+
conf_logo_path = logo_manager.get_logo_path(args.conference_venue, category="conference", use_google=args.use_google_search)
|
| 182 |
+
if conf_logo_path:
|
| 183 |
+
conference_logo_path = str(conf_logo_path)
|
| 184 |
+
print(f"✅ Conference logo found: {conference_logo_path}")
|
| 185 |
+
else:
|
| 186 |
+
print(f"❌ Could not find/download logo for: {args.conference_venue}")
|
| 187 |
+
# Note: Web search is now handled inside get_logo_path automatically
|
| 188 |
+
print("="*60 + "\n")
|
| 189 |
+
|
| 190 |
+
# Step 2: Filter unnecessary images and tables
|
| 191 |
+
input_token, output_token = filter_image_table(args, agent_config_t)
|
| 192 |
+
total_input_tokens_t += input_token
|
| 193 |
+
total_output_tokens_t += output_token
|
| 194 |
+
print(f'Filter figures token consumption: {input_token} -> {output_token}')
|
| 195 |
+
|
| 196 |
+
filter_time_taken = time.time() - parser_time
|
| 197 |
+
print(f'Filter time: {filter_time_taken:.2f} seconds')
|
| 198 |
+
detail_log['filter_time'] = filter_time_taken
|
| 199 |
+
|
| 200 |
+
filter_time = time.time()
|
| 201 |
+
|
| 202 |
+
detail_log['filter_in_t'] = input_token
|
| 203 |
+
detail_log['filter_out_t'] = output_token
|
| 204 |
+
|
| 205 |
+
# Step 3: Generate outline
|
| 206 |
+
input_token, output_token, panels, figures = gen_outline_layout_v2(args, agent_config_t)
|
| 207 |
+
total_input_tokens_t += input_token
|
| 208 |
+
total_output_tokens_t += output_token
|
| 209 |
+
print(f'Outline token consumption: {input_token} -> {output_token}')
|
| 210 |
+
|
| 211 |
+
outline_time_taken = time.time() - filter_time
|
| 212 |
+
print(f'Outline time: {outline_time_taken:.2f} seconds')
|
| 213 |
+
detail_log['outline_time'] = outline_time_taken
|
| 214 |
+
|
| 215 |
+
outline_time = time.time()
|
| 216 |
+
|
| 217 |
+
detail_log['outline_in_t'] = input_token
|
| 218 |
+
detail_log['outline_out_t'] = output_token
|
| 219 |
+
|
| 220 |
+
if args.ablation_no_tree_layout:
|
| 221 |
+
panel_arrangement, figure_arrangement, text_arrangement, input_token, output_token = no_tree_get_layout(
|
| 222 |
+
poster_width,
|
| 223 |
+
poster_height,
|
| 224 |
+
panels,
|
| 225 |
+
figures,
|
| 226 |
+
agent_config_t
|
| 227 |
+
)
|
| 228 |
+
total_input_tokens_t += input_token
|
| 229 |
+
total_output_tokens_t += output_token
|
| 230 |
+
print(f'No tree layout token consumption: {input_token} -> {output_token}')
|
| 231 |
+
detail_log['no_tree_layout_in_t'] = input_token
|
| 232 |
+
detail_log['no_tree_layout_out_t'] = output_token
|
| 233 |
+
else:
|
| 234 |
+
|
| 235 |
+
# Step 4: Learn and generate layout
|
| 236 |
+
panel_model_params, figure_model_params = main_train()
|
| 237 |
+
|
| 238 |
+
panel_arrangement, figure_arrangement, text_arrangement = main_inference(
|
| 239 |
+
panels,
|
| 240 |
+
panel_model_params,
|
| 241 |
+
figure_model_params,
|
| 242 |
+
poster_width,
|
| 243 |
+
poster_height,
|
| 244 |
+
shrink_margin=3
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
text_arrangement_title = text_arrangement[0]
|
| 248 |
+
text_arrangement = text_arrangement[1:]
|
| 249 |
+
# Split the title textbox into two parts
|
| 250 |
+
text_arrangement_title_top, text_arrangement_title_bottom = split_textbox(
|
| 251 |
+
text_arrangement_title,
|
| 252 |
+
0.8
|
| 253 |
+
)
|
| 254 |
+
# Add the split textboxes back to the list
|
| 255 |
+
text_arrangement = [text_arrangement_title_top, text_arrangement_title_bottom] + text_arrangement
|
| 256 |
+
|
| 257 |
+
for i in range(len(figure_arrangement)):
|
| 258 |
+
panel_id = figure_arrangement[i]['panel_id']
|
| 259 |
+
panel_section_name = panels[panel_id]['section_name']
|
| 260 |
+
figure_info = figures[panel_section_name]
|
| 261 |
+
if 'image' in figure_info:
|
| 262 |
+
figure_id = figure_info['image']
|
| 263 |
+
if figure_id not in images:
|
| 264 |
+
figure_path = images[str(figure_id)]['image_path']
|
| 265 |
+
else:
|
| 266 |
+
figure_path = images[figure_id]['image_path']
|
| 267 |
+
elif 'table' in figure_info:
|
| 268 |
+
figure_id = figure_info['table']
|
| 269 |
+
if figure_id not in tables:
|
| 270 |
+
figure_path = tables[str(figure_id)]['table_path']
|
| 271 |
+
else:
|
| 272 |
+
figure_path = tables[figure_id]['table_path']
|
| 273 |
+
|
| 274 |
+
figure_arrangement[i]['figure_path'] = figure_path
|
| 275 |
+
|
| 276 |
+
for text_arrangement_item in text_arrangement:
|
| 277 |
+
num_chars = char_capacity(
|
| 278 |
+
bbox=(text_arrangement_item['x'], text_arrangement_item['y'], text_arrangement_item['height'], text_arrangement_item['width'])
|
| 279 |
+
)
|
| 280 |
+
text_arrangement_item['num_chars'] = num_chars
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
width_inch, height_inch, panel_arrangement_inches, figure_arrangement_inches, text_arrangement_inches = get_arrangments_in_inches(
|
| 284 |
+
poster_width, poster_height, panel_arrangement, figure_arrangement, text_arrangement, units_per_inch
|
| 285 |
+
)
|
| 286 |
+
|
| 287 |
+
# Save to file
|
| 288 |
+
tree_split_results = {
|
| 289 |
+
'poster_width': poster_width,
|
| 290 |
+
'poster_height': poster_height,
|
| 291 |
+
'poster_width_inches': width_inch,
|
| 292 |
+
'poster_height_inches': height_inch,
|
| 293 |
+
'panels': panels,
|
| 294 |
+
'panel_arrangement': panel_arrangement,
|
| 295 |
+
'figure_arrangement': figure_arrangement,
|
| 296 |
+
'text_arrangement': text_arrangement,
|
| 297 |
+
'panel_arrangement_inches': panel_arrangement_inches,
|
| 298 |
+
'figure_arrangement_inches': figure_arrangement_inches,
|
| 299 |
+
'text_arrangement_inches': text_arrangement_inches,
|
| 300 |
+
}
|
| 301 |
+
os.makedirs('tree_splits', exist_ok=True)
|
| 302 |
+
with open(f'tree_splits/<{args.model_name_t}_{args.model_name_v}>_{args.poster_name}_tree_split_{args.index}.json', 'w') as f:
|
| 303 |
+
json.dump(tree_split_results, f, indent=4)
|
| 304 |
+
|
| 305 |
+
layout_time_taken = time.time() - outline_time
|
| 306 |
+
print(f'Layout time: {layout_time_taken:.2f} seconds')
|
| 307 |
+
detail_log['layout_time'] = layout_time_taken
|
| 308 |
+
|
| 309 |
+
layout_time = time.time()
|
| 310 |
+
|
| 311 |
+
# # === Configuration Loading ===
|
| 312 |
+
# print("\n📋 Loading configuration from YAML files...", flush=True)
|
| 313 |
+
# yaml_cfg = load_poster_yaml_config(args.poster_path)
|
| 314 |
+
|
| 315 |
+
# # Extract configuration values
|
| 316 |
+
# bullet_fs, title_fs, poster_title_fs, poster_author_fs = extract_font_sizes(yaml_cfg)
|
| 317 |
+
# title_text_color, title_fill_color, main_text_color, main_text_fill_color = extract_colors(yaml_cfg)
|
| 318 |
+
# section_title_vertical_align = extract_vertical_alignment(yaml_cfg)
|
| 319 |
+
# section_title_symbol = extract_section_title_symbol(yaml_cfg)
|
| 320 |
+
|
| 321 |
+
# # Normalize configuration values
|
| 322 |
+
# bullet_fs, title_fs, poster_title_fs, poster_author_fs, \
|
| 323 |
+
# title_text_color, title_fill_color, main_text_color, main_text_fill_color = normalize_config_values(
|
| 324 |
+
# bullet_fs, title_fs, poster_title_fs, poster_author_fs,
|
| 325 |
+
# title_text_color, title_fill_color, main_text_color, main_text_fill_color
|
| 326 |
+
# )
|
| 327 |
+
|
| 328 |
+
# # Store configuration in args
|
| 329 |
+
# setattr(args, 'bullet_font_size', bullet_fs)
|
| 330 |
+
# setattr(args, 'section_title_font_size', title_fs)
|
| 331 |
+
# setattr(args, 'poster_title_font_size', poster_title_fs)
|
| 332 |
+
# setattr(args, 'poster_author_font_size', poster_author_fs)
|
| 333 |
+
# setattr(args, 'title_text_color', title_text_color)
|
| 334 |
+
# setattr(args, 'title_fill_color', title_fill_color)
|
| 335 |
+
# setattr(args, 'main_text_color', main_text_color)
|
| 336 |
+
# setattr(args, 'main_text_fill_color', main_text_fill_color)
|
| 337 |
+
# setattr(args, 'section_title_vertical_align', section_title_vertical_align)
|
| 338 |
+
|
| 339 |
+
# # Step 5: Generate content
|
| 340 |
+
# print(f"\n✍️ Generating poster content (max_workers={args.max_workers})...", flush=True)
|
| 341 |
+
# # --- Step 1: Check for cached content ---
|
| 342 |
+
# content_cache_path = f'contents/<{args.model_name_t}_{args.model_name_v}>_{args.poster_name}_bullet_point_content_{args.index}.json'
|
| 343 |
+
|
| 344 |
+
# if os.path.exists(content_cache_path):
|
| 345 |
+
# print(f"🧩 Cache found: {content_cache_path}")
|
| 346 |
+
# print("⚡ Skipping model generation, loading from cache...")
|
| 347 |
+
# bullet_content = json.load(open(content_cache_path, 'r'))
|
| 348 |
+
# input_token_t = output_token_t = input_token_v = output_token_v = 0
|
| 349 |
+
# else:
|
| 350 |
+
# print("🧠 Running model to generate poster content...")
|
| 351 |
+
# input_token_t, output_token_t, input_token_v, output_token_v = gen_bullet_point_content(
|
| 352 |
+
# args, agent_config_t, agent_config_v, tmp_dir=args.tmp_dir
|
| 353 |
+
# )
|
| 354 |
+
# bullet_content = json.load(open(content_cache_path, 'r'))
|
| 355 |
+
|
| 356 |
+
# input_token_t, output_token_t, input_token_v, output_token_v = gen_bullet_point_content(args, agent_config_t, agent_config_v, tmp_dir=args.tmp_dir)
|
| 357 |
+
# total_input_tokens_t += input_token
|
| 358 |
+
# total_output_tokens_t += output_token
|
| 359 |
+
# total_input_tokens_v += input_token_v
|
| 360 |
+
# total_output_tokens_v += output_token_v
|
| 361 |
+
# print(f'Content generation token consumption T: {input_token_t} -> {output_token_t}')
|
| 362 |
+
# print(f'Content generation token consumption V: {input_token_v} -> {output_token_v}')
|
| 363 |
+
|
| 364 |
+
# content_time_taken = time.time() - layout_time
|
| 365 |
+
# print(f'Content generation time: {content_time_taken:.2f} seconds')
|
| 366 |
+
# detail_log['content_time'] = content_time_taken
|
| 367 |
+
|
| 368 |
+
# content_time = time.time()
|
| 369 |
+
|
| 370 |
+
# bullet_content = json.load(open(f'contents/<{args.model_name_t}_{args.model_name_v}>_{args.poster_name}_bullet_point_content_{args.index}.json', 'r'))
|
| 371 |
+
|
| 372 |
+
# detail_log['content_in_t'] = input_token_t
|
| 373 |
+
# detail_log['content_out_t'] = output_token_t
|
| 374 |
+
# detail_log['content_in_v'] = input_token_v
|
| 375 |
+
# detail_log['content_out_v'] = output_token_v
|
| 376 |
+
|
| 377 |
+
# # === Style Application ===
|
| 378 |
+
# print("\n🎨 Applying styles and colors...", flush=True)
|
| 379 |
+
|
| 380 |
+
# # Resolve colors with fallbacks
|
| 381 |
+
# final_title_text_color, final_title_fill_color, final_main_text_color, final_main_text_fill_color = resolve_colors(
|
| 382 |
+
# getattr(args, 'title_text_color', None),
|
| 383 |
+
# getattr(args, 'title_fill_color', None),
|
| 384 |
+
# getattr(args, 'main_text_color', None),
|
| 385 |
+
# getattr(args, 'main_text_fill_color', None)
|
| 386 |
+
# )
|
| 387 |
+
|
| 388 |
+
# # Apply all styles in one go
|
| 389 |
+
# bullet_content = apply_all_styles(
|
| 390 |
+
# bullet_content,
|
| 391 |
+
# title_text_color=final_title_text_color,
|
| 392 |
+
# title_fill_color=final_title_fill_color,
|
| 393 |
+
# main_text_color=final_main_text_color,
|
| 394 |
+
# main_text_fill_color=final_main_text_fill_color,
|
| 395 |
+
# section_title_symbol=section_title_symbol,
|
| 396 |
+
# main_text_font_size=bullet_fs
|
| 397 |
+
# )
|
| 398 |
+
|
| 399 |
+
# # === Poster Generation ===
|
| 400 |
+
# # print("\n🎯 Generating PowerPoint code...", flush=True)
|
| 401 |
+
|
| 402 |
+
# # Create theme with alignment
|
| 403 |
+
# base_theme = get_default_theme()
|
| 404 |
+
# theme_with_alignment = create_theme_with_alignment(
|
| 405 |
+
# base_theme,
|
| 406 |
+
# getattr(args, 'section_title_vertical_align', None)
|
| 407 |
+
# )
|
| 408 |
+
|
| 409 |
+
# # poster_code = generate_poster_code(
|
| 410 |
+
# # panel_arrangement_inches,
|
| 411 |
+
# # text_arrangement_inches,
|
| 412 |
+
# # figure_arrangement_inches,
|
| 413 |
+
# # presentation_object_name='poster_presentation',
|
| 414 |
+
# # slide_object_name='poster_slide',
|
| 415 |
+
# # utils_functions=utils_functions,
|
| 416 |
+
# # slide_width=width_inch,
|
| 417 |
+
# # slide_height=height_inch,
|
| 418 |
+
# # img_path=None,
|
| 419 |
+
# # save_path=f'{args.tmp_dir}/poster.pptx',
|
| 420 |
+
# # visible=False,
|
| 421 |
+
# # content=bullet_content,
|
| 422 |
+
# # theme=theme_with_alignment,
|
| 423 |
+
# # tmp_dir=args.tmp_dir,
|
| 424 |
+
# # )
|
| 425 |
+
# print("\n🎯 Generating Beamer poster (LaTeX)...", flush=True)
|
| 426 |
+
|
| 427 |
+
# # --- 1. Extract poster_info ---
|
| 428 |
+
# poster_info = {
|
| 429 |
+
# "title": args.poster_name,
|
| 430 |
+
# "author": "AutoGen",
|
| 431 |
+
# "institute": "Auto-detected Institution"
|
| 432 |
+
# }
|
| 433 |
+
# if isinstance(bullet_content, list) and len(bullet_content) > 0:
|
| 434 |
+
# first_section = bullet_content[0]
|
| 435 |
+
# if isinstance(first_section, dict):
|
| 436 |
+
# if "poster_title" in first_section:
|
| 437 |
+
# poster_info["title"] = first_section["poster_title"]
|
| 438 |
+
# elif "title" in first_section:
|
| 439 |
+
# poster_info["title"] = first_section["title"]
|
| 440 |
+
|
| 441 |
+
# --- 2. Build the Beamer data structures ---
|
| 442 |
+
# layout_data = {
|
| 443 |
+
# "text_arrangement": text_arrangement,
|
| 444 |
+
# "figure_arrangement": figure_arrangement
|
| 445 |
+
# }
|
| 446 |
+
# beamer_data = convert_pptx_layout_to_beamer(layout_data)
|
| 447 |
+
|
| 448 |
+
# Map bullet_content into the sections
|
| 449 |
+
# for i, section in enumerate(beamer_data["sections"]):
|
| 450 |
+
# if i < len(bullet_content):
|
| 451 |
+
# section_data = bullet_content[i]
|
| 452 |
+
# if isinstance(section_data, dict):
|
| 453 |
+
# section["content"] = section_data.get("textbox1") or section_data.get("title") or json.dumps(section_data)
|
| 454 |
+
# else:
|
| 455 |
+
# section["content"] = str(section_data)
|
| 456 |
+
|
| 457 |
+
# --- 3. Generate the LaTeX file ---
|
| 458 |
+
# poster_info = {k: (str(v) if not isinstance(v, str) else v) for k, v in poster_info.items()}
|
| 459 |
+
|
| 460 |
+
# beamer_code = generate_beamer_poster_code(
|
| 461 |
+
# beamer_data["sections"],
|
| 462 |
+
# beamer_data["figures"],
|
| 463 |
+
# poster_info,
|
| 464 |
+
# width_cm=poster_width_inches * 2.54,
|
| 465 |
+
# height_cm=poster_height_inches * 2.54,
|
| 466 |
+
# theme="Madrid",
|
| 467 |
+
# output_path=f"{args.tmp_dir}/{poster_name}.tex"
|
| 468 |
+
# )
|
| 469 |
+
# save_beamer_code(beamer_code, f"{args.tmp_dir}/{poster_name}.tex")
|
| 470 |
+
|
| 471 |
+
|
| 472 |
+
# --- 4. Compile to PDF ---
|
| 473 |
+
# output_dir = f'<{args.model_name_t}_{args.model_name_v}>_generated_beamer_posters/{args.poster_path.replace("paper.pdf", "")}'
|
| 474 |
+
# compile_beamer_to_pdf(f"{args.tmp_dir}/{poster_name}.tex", output_dir=args.tmp_dir)
|
| 475 |
+
# pdf_path = os.path.join(args.tmp_dir, f"{poster_name}.pdf")
|
| 476 |
+
# os.makedirs(output_dir, exist_ok=True)
|
| 477 |
+
# os.rename(pdf_path, os.path.join(output_dir, f"{poster_name}.pdf"))
|
| 478 |
+
|
| 479 |
+
# print(f"✅ Beamer poster PDF saved to {output_dir}")
|
| 480 |
+
# Add logos to the poster
|
| 481 |
+
# print("\n🖼️ Adding logos to poster...", flush=True)
|
| 482 |
+
# poster_code = add_logos_to_poster_code(
|
| 483 |
+
# poster_code,
|
| 484 |
+
# width_inch,
|
| 485 |
+
# height_inch,
|
| 486 |
+
# institution_logo_path=institution_logo_path,
|
| 487 |
+
# conference_logo_path=conference_logo_path
|
| 488 |
+
# )
|
| 489 |
+
|
| 490 |
+
# output, err = run_code(poster_code)
|
| 491 |
+
# if err is not None:
|
| 492 |
+
# raise RuntimeError(f'Error in generating PowerPoint: {err}')
|
| 493 |
+
|
| 494 |
+
# # Step 8: Create a folder in the output directory
|
| 495 |
+
# output_dir = f'<{args.model_name_t}_{args.model_name_v}>_generated_posters/{args.poster_path.replace("paper.pdf", "")}'
|
| 496 |
+
# os.makedirs(output_dir, exist_ok=True)
|
| 497 |
+
|
| 498 |
+
# # Copy logos to output directory for reference
|
| 499 |
+
# logos_dir = os.path.join(output_dir, 'logos')
|
| 500 |
+
# if institution_logo_path or conference_logo_path:
|
| 501 |
+
# os.makedirs(logos_dir, exist_ok=True)
|
| 502 |
+
# if institution_logo_path and os.path.exists(institution_logo_path):
|
| 503 |
+
# shutil.copy2(institution_logo_path, os.path.join(logos_dir, 'institution_logo' + os.path.splitext(institution_logo_path)[1]))
|
| 504 |
+
# if conference_logo_path and os.path.exists(conference_logo_path):
|
| 505 |
+
# shutil.copy2(conference_logo_path, os.path.join(logos_dir, 'conference_logo' + os.path.splitext(conference_logo_path)[1]))
|
| 506 |
+
|
| 507 |
+
# # Step 9: Move poster.pptx to the output directory
|
| 508 |
+
# pptx_path = os.path.join(output_dir, f'{poster_name}.pptx')
|
| 509 |
+
# os.rename(f'{args.tmp_dir}/poster.pptx', pptx_path)
|
| 510 |
+
# print(f'Poster PowerPoint saved to {pptx_path}')
|
| 511 |
+
# # Step 10: Convert the PowerPoint to images
|
| 512 |
+
# ppt_to_images(pptx_path, output_dir)
|
| 513 |
+
# print(f'Poster images saved to {output_dir}')
|
| 514 |
+
|
| 515 |
+
# end_time = time.time()
|
| 516 |
+
# time_taken = end_time - start_time
|
| 517 |
+
|
| 518 |
+
# render_time_taken = time.time() - content_time
|
| 519 |
+
# print(f'Render time: {render_time_taken:.2f} seconds')
|
| 520 |
+
# detail_log['render_time'] = render_time_taken
|
| 521 |
+
|
| 522 |
+
# # log
|
| 523 |
+
# log_file = os.path.join(output_dir, 'log.json')
|
| 524 |
+
# with open(log_file, 'w') as f:
|
| 525 |
+
# log_data = {
|
| 526 |
+
# 'input_tokens_t': total_input_tokens_t,
|
| 527 |
+
# 'output_tokens_t': total_output_tokens_t,
|
| 528 |
+
# 'input_tokens_v': total_input_tokens_v,
|
| 529 |
+
# 'output_tokens_v': total_output_tokens_v,
|
| 530 |
+
# 'time_taken': time_taken,
|
| 531 |
+
# 'institution_logo': institution_logo_path,
|
| 532 |
+
# 'conference_logo': conference_logo_path,
|
| 533 |
+
# }
|
| 534 |
+
# json.dump(log_data, f, indent=4)
|
| 535 |
+
|
| 536 |
+
# detail_log_file = os.path.join(output_dir, 'detail_log.json')
|
| 537 |
+
# with open(detail_log_file, 'w') as f:
|
| 538 |
+
# json.dump(detail_log, f, indent=4)
|
| 539 |
+
|
| 540 |
+
# print(f'\nTotal time: {time_taken:.2f} seconds')
|
| 541 |
+
# print(f'Total text model tokens: {total_input_tokens_t} -> {total_output_tokens_t}')
|
| 542 |
+
# print(f'Total vision model tokens: {total_input_tokens_v} -> {total_output_tokens_v}')
|
| 543 |
+
|
| 544 |
+
# if institution_logo_path:
|
| 545 |
+
# print(f'Institution logo added: {institution_logo_path}')
|
| 546 |
+
# if conference_logo_path:
|
| 547 |
+
# print(f'Conference logo added: {conference_logo_path}')
|
Paper2Poster/PosterAgent/parse_raw.py
ADDED
|
@@ -0,0 +1,237 @@
|
| 1 |
+
from dotenv import load_dotenv
|
| 2 |
+
from utils.src.utils import get_json_from_response
|
| 3 |
+
from utils.src.model_utils import parse_pdf
|
| 4 |
+
import json
|
| 5 |
+
import random
|
| 6 |
+
|
| 7 |
+
from camel.models import ModelFactory
|
| 8 |
+
from camel.agents import ChatAgent
|
| 9 |
+
from tenacity import retry, stop_after_attempt
|
| 10 |
+
from docling_core.types.doc import ImageRefMode, PictureItem, TableItem
|
| 11 |
+
|
| 12 |
+
from docling.datamodel.base_models import InputFormat
|
| 13 |
+
from docling.datamodel.pipeline_options import PdfPipelineOptions
|
| 14 |
+
from docling.document_converter import DocumentConverter, PdfFormatOption
|
| 15 |
+
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
|
| 18 |
+
import PIL
|
| 19 |
+
|
| 20 |
+
from marker.models import create_model_dict
|
| 21 |
+
|
| 22 |
+
from utils.wei_utils import *
|
| 23 |
+
|
| 24 |
+
from utils.pptx_utils import *
|
| 25 |
+
from utils.critic_utils import *
|
| 26 |
+
import torch
|
| 27 |
+
from jinja2 import Template
|
| 28 |
+
import re
|
| 29 |
+
import argparse
|
| 30 |
+
|
| 31 |
+
load_dotenv()
|
| 32 |
+
IMAGE_RESOLUTION_SCALE = 5.0
|
| 33 |
+
|
| 34 |
+
pipeline_options = PdfPipelineOptions()
|
| 35 |
+
pipeline_options.images_scale = IMAGE_RESOLUTION_SCALE
|
| 36 |
+
pipeline_options.generate_page_images = True
|
| 37 |
+
pipeline_options.generate_picture_images = True
|
| 38 |
+
|
| 39 |
+
doc_converter = DocumentConverter(
|
| 40 |
+
format_options={
|
| 41 |
+
InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options)
|
| 42 |
+
}
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
@retry(stop=stop_after_attempt(5))
|
| 46 |
+
def parse_raw(args, actor_config, version=2):
|
| 47 |
+
raw_source = args.poster_path
|
| 48 |
+
markdown_clean_pattern = re.compile(r"<!--[\s\S]*?-->")
|
| 49 |
+
|
| 50 |
+
raw_result = doc_converter.convert(raw_source)
|
| 51 |
+
|
| 52 |
+
raw_markdown = raw_result.document.export_to_markdown()
|
| 53 |
+
text_content = markdown_clean_pattern.sub("", raw_markdown)
|
| 54 |
+
|
| 55 |
+
if len(text_content) < 500:
|
| 56 |
+
print('\nParsing with docling failed, using marker instead\n')
|
| 57 |
+
parser_model = create_model_dict(device='cuda', dtype=torch.float16)
|
| 58 |
+
text_content, rendered = parse_pdf(raw_source, model_lst=parser_model, save_file=False)
|
| 59 |
+
|
| 60 |
+
if version == 1:
|
| 61 |
+
template = Template(open("utils/prompts/gen_poster_raw_content.txt").read())
|
| 62 |
+
elif version == 2:
|
| 63 |
+
print('Using v2 prompt template')
|
| 64 |
+
template = Template(open("utils/prompts/gen_poster_raw_content_v2.txt").read())
|
| 65 |
+
|
| 66 |
+
if args.model_name_t.startswith('vllm_qwen'):
|
| 67 |
+
actor_model = ModelFactory.create(
|
| 68 |
+
model_platform=actor_config['model_platform'],
|
| 69 |
+
model_type=actor_config['model_type'],
|
| 70 |
+
model_config_dict=actor_config['model_config'],
|
| 71 |
+
url=actor_config['url'],
|
| 72 |
+
)
|
| 73 |
+
else:
|
| 74 |
+
actor_model = ModelFactory.create(
|
| 75 |
+
model_platform=actor_config['model_platform'],
|
| 76 |
+
model_type=actor_config['model_type'],
|
| 77 |
+
model_config_dict=actor_config['model_config'],
|
| 78 |
+
)
|
| 79 |
+
|
| 80 |
+
actor_sys_msg = 'You are the author of the paper, and you will create a poster for the paper.'
|
| 81 |
+
|
| 82 |
+
actor_agent = ChatAgent(
|
| 83 |
+
system_message=actor_sys_msg,
|
| 84 |
+
model=actor_model,
|
| 85 |
+
message_window_size=10,
|
| 86 |
+
token_limit=actor_config.get('token_limit', None)
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
while True:
|
| 90 |
+
prompt = template.render(
|
| 91 |
+
markdown_document=text_content,
|
| 92 |
+
)
|
| 93 |
+
actor_agent.reset()
|
| 94 |
+
response = actor_agent.step(prompt)
|
| 95 |
+
input_token, output_token = account_token(response)
|
| 96 |
+
|
| 97 |
+
content_json = get_json_from_response(response.msgs[0].content)
|
| 98 |
+
|
| 99 |
+
if len(content_json) > 0:
|
| 100 |
+
break
|
| 101 |
+
print('Error: Empty response, retrying...')
|
| 102 |
+
if args.model_name_t.startswith('vllm_qwen'):
|
| 103 |
+
text_content = text_content[:80000]
|
| 104 |
+
|
| 105 |
+
if len(content_json['sections']) > 9:
|
| 106 |
+
# First 2 sections + randomly select 5 sections + last 2 sections
|
| 107 |
+
selected_sections = content_json['sections'][:2] + random.sample(content_json['sections'][2:-2], 5) + content_json['sections'][-2:]
|
| 108 |
+
content_json['sections'] = selected_sections
|
| 109 |
+
|
| 110 |
+
has_title = False
|
| 111 |
+
|
| 112 |
+
for section in content_json['sections']:
|
| 113 |
+
if not isinstance(section, dict) or 'title' not in section or 'content' not in section:
|
| 114 |
+
print(f"Ouch! The response is invalid, the LLM is not following the format :(")
|
| 115 |
+
print('Trying again...')
|
| 116 |
+
raise
|
| 117 |
+
if 'title' in section['title'].lower():
|
| 118 |
+
has_title = True
|
| 119 |
+
|
| 120 |
+
if not has_title:
|
| 121 |
+
print('Ouch! The response is invalid, the LLM is not following the format :(')
|
| 122 |
+
raise
|
| 123 |
+
|
| 124 |
+
os.makedirs('contents', exist_ok=True)
|
| 125 |
+
json.dump(content_json, open(f'contents/<{args.model_name_t}_{args.model_name_v}>_{args.poster_name}_raw_content.json', 'w'), indent=4)
|
| 126 |
+
return input_token, output_token, raw_result
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def gen_image_and_table(args, conv_res):
|
| 130 |
+
input_token, output_token = 0, 0
|
| 131 |
+
raw_source = args.poster_path
|
| 132 |
+
|
| 133 |
+
output_dir = Path(f'<{args.model_name_t}_{args.model_name_v}>_images_and_tables/{args.poster_name}')
|
| 134 |
+
|
| 135 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
| 136 |
+
doc_filename = args.poster_name
|
| 137 |
+
|
| 138 |
+
# Save page images
|
| 139 |
+
for page_no, page in conv_res.document.pages.items():
|
| 140 |
+
page_no = page.page_no
|
| 141 |
+
page_image_filename = output_dir / f"{doc_filename}-{page_no}.png"
|
| 142 |
+
with page_image_filename.open("wb") as fp:
|
| 143 |
+
page.image.pil_image.save(fp, format="PNG")
|
| 144 |
+
|
| 145 |
+
# Save images of figures and tables
|
| 146 |
+
table_counter = 0
|
| 147 |
+
picture_counter = 0
|
| 148 |
+
for element, _level in conv_res.document.iterate_items():
|
| 149 |
+
if isinstance(element, TableItem):
|
| 150 |
+
table_counter += 1
|
| 151 |
+
element_image_filename = (
|
| 152 |
+
output_dir / f"{doc_filename}-table-{table_counter}.png"
|
| 153 |
+
)
|
| 154 |
+
with element_image_filename.open("wb") as fp:
|
| 155 |
+
element.get_image(conv_res.document).save(fp, "PNG")
|
| 156 |
+
|
| 157 |
+
if isinstance(element, PictureItem):
|
| 158 |
+
picture_counter += 1
|
| 159 |
+
element_image_filename = (
|
| 160 |
+
output_dir / f"{doc_filename}-picture-{picture_counter}.png"
|
| 161 |
+
)
|
| 162 |
+
with element_image_filename.open("wb") as fp:
|
| 163 |
+
element.get_image(conv_res.document).save(fp, "PNG")
|
| 164 |
+
|
| 165 |
+
# Save markdown with embedded pictures
|
| 166 |
+
md_filename = output_dir / f"{doc_filename}-with-images.md"
|
| 167 |
+
conv_res.document.save_as_markdown(md_filename, image_mode=ImageRefMode.EMBEDDED)
|
| 168 |
+
|
| 169 |
+
# Save markdown with externally referenced pictures
|
| 170 |
+
md_filename = output_dir / f"{doc_filename}-with-image-refs.md"
|
| 171 |
+
conv_res.document.save_as_markdown(md_filename, image_mode=ImageRefMode.REFERENCED)
|
| 172 |
+
|
| 173 |
+
# Save HTML with externally referenced pictures
|
| 174 |
+
html_filename = output_dir / f"{doc_filename}-with-image-refs.html"
|
| 175 |
+
conv_res.document.save_as_html(html_filename, image_mode=ImageRefMode.REFERENCED)
|
| 176 |
+
|
| 177 |
+
tables = {}
|
| 178 |
+
|
| 179 |
+
table_index = 1
|
| 180 |
+
for table in conv_res.document.tables:
|
| 181 |
+
caption = table.caption_text(conv_res.document)
|
| 182 |
+
if len(caption) > 0:
|
| 183 |
+
table_img_path = f'<{args.model_name_t}_{args.model_name_v}>_images_and_tables/{args.poster_name}/{args.poster_name}-table-{table_index}.png'
|
| 184 |
+
table_img = PIL.Image.open(table_img_path)
|
| 185 |
+
tables[str(table_index)] = {
|
| 186 |
+
'caption': caption,
|
| 187 |
+
'table_path': table_img_path,
|
| 188 |
+
'width': table_img.width,
|
| 189 |
+
'height': table_img.height,
|
| 190 |
+
'figure_size': table_img.width * table_img.height,
|
| 191 |
+
'figure_aspect': table_img.width / table_img.height,
|
| 192 |
+
}
|
| 193 |
+
|
| 194 |
+
table_index += 1
|
| 195 |
+
|
| 196 |
+
images = {}
|
| 197 |
+
image_index = 1
|
| 198 |
+
for image in conv_res.document.pictures:
|
| 199 |
+
caption = image.caption_text(conv_res.document)
|
| 200 |
+
if len(caption) > 0:
|
| 201 |
+
image_img_path = f'<{args.model_name_t}_{args.model_name_v}>_images_and_tables/{args.poster_name}/{args.poster_name}-picture-{image_index}.png'
|
| 202 |
+
image_img = PIL.Image.open(image_img_path)
|
| 203 |
+
images[str(image_index)] = {
|
| 204 |
+
'caption': caption,
|
| 205 |
+
'image_path': image_img_path,
|
| 206 |
+
'width': image_img.width,
|
| 207 |
+
'height': image_img.height,
|
| 208 |
+
'figure_size': image_img.width * image_img.height,
|
| 209 |
+
'figure_aspect': image_img.width / image_img.height,
|
| 210 |
+
}
|
| 211 |
+
image_index += 1
|
| 212 |
+
|
| 213 |
+
json.dump(images, open(f'<{args.model_name_t}_{args.model_name_v}>_images_and_tables/{args.poster_name}_images.json', 'w'), indent=4)
|
| 214 |
+
json.dump(tables, open(f'<{args.model_name_t}_{args.model_name_v}>_images_and_tables/{args.poster_name}_tables.json', 'w'), indent=4)
|
| 215 |
+
|
| 216 |
+
return input_token, output_token, images, tables
|
| 217 |
+
|
| 218 |
+
if __name__ == '__main__':
|
| 219 |
+
parser = argparse.ArgumentParser()
|
| 220 |
+
parser.add_argument('--poster_name', type=str, default=None)
|
| 221 |
+
parser.add_argument('--model_name', type=str, default='4o')
|
| 222 |
+
parser.add_argument('--poster_path', type=str, required=True)
|
| 223 |
+
parser.add_argument('--index', type=int, default=0)
|
| 224 |
+
args = parser.parse_args()
|
| 225 |
+
|
| 226 |
+
agent_config = get_agent_config(args.model_name)
|
| 227 |
+
|
| 228 |
+
if args.poster_name is None:
|
| 229 |
+
args.poster_name = args.poster_path.split('/')[-1].replace('.pdf', '').replace(' ', '_')
|
| 230 |
+
|
| 231 |
+
# Parse raw content
|
| 232 |
+
input_token, output_token, conv_res = parse_raw(args, agent_config)
|
| 233 |
+
|
| 234 |
+
# Generate images and tables
|
| 235 |
+
_, _, images, tables = gen_image_and_table(args, conv_res)
|
| 236 |
+
|
| 237 |
+
print(f'Token consumption: {input_token} -> {output_token}')
|
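Besides the CLI entry point above, parse_raw and gen_image_and_table can be driven programmatically. The sketch below is illustrative only: the paths, poster name, and SimpleNamespace fields are assumptions (parse_raw reads model_name_t/model_name_v from args), and running it requires the repository's utils package plus valid model credentials in .env.

from types import SimpleNamespace
from utils.wei_utils import get_agent_config
from PosterAgent.parse_raw import parse_raw, gen_image_and_table

args = SimpleNamespace(
    poster_name="example_paper",              # hypothetical name
    poster_path="dataset/example/paper.pdf",  # hypothetical path
    model_name_t="4o",
    model_name_v="4o",
    index=0,
)
agent_config = get_agent_config("4o")

# parse_raw returns token counts plus the docling conversion result,
# which gen_image_and_table reuses to crop captioned figures and tables.
in_tok, out_tok, conv_res = parse_raw(args, agent_config)
_, _, images, tables = gen_image_and_table(args, conv_res)
print(f"Parsing used {in_tok} -> {out_tok} tokens; {len(images)} figures, {len(tables)} tables")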
Paper2Poster/PosterAgent/poster_gen_pipeline.py
ADDED
|
@@ -0,0 +1,101 @@
|
| 1 |
+
import argparse
|
| 2 |
+
import time
|
| 3 |
+
|
| 4 |
+
from utils.wei_utils import get_agent_config
|
| 5 |
+
from PosterAgent.parse_raw import parse_raw, gen_image_and_table
|
| 6 |
+
from PosterAgent.gen_outline_layout import filter_image_table, gen_outline_layout
|
| 7 |
+
from PosterAgent.gen_poster_content import gen_poster_content
|
| 8 |
+
from PosterAgent.fill_and_style import fill_poster_content, stylize_poster
|
| 9 |
+
from PosterAgent.deoverflow import deoverflow
|
| 10 |
+
from PosterAgent.apply_theme import poster_apply_theme
|
| 11 |
+
|
| 12 |
+
if __name__ == '__main__':
|
| 13 |
+
parser = argparse.ArgumentParser()
|
| 14 |
+
parser.add_argument('--poster_name', type=str, default=None)
|
| 15 |
+
parser.add_argument('--model_name', type=str, default='4o')
|
| 16 |
+
parser.add_argument('--poster_path', type=str, required=True)
|
| 17 |
+
parser.add_argument('--index', type=int, default=0)
|
| 18 |
+
parser.add_argument('--template_path', type=str, default=None)
|
| 19 |
+
parser.add_argument('--max_retry', type=int, default=3)
|
| 20 |
+
args = parser.parse_args()
|
| 21 |
+
|
| 22 |
+
start_time = time.time()
|
| 23 |
+
|
| 24 |
+
actor_config = get_agent_config(args.model_name)
|
| 25 |
+
critic_config = get_agent_config(args.model_name)
|
| 26 |
+
|
| 27 |
+
if args.poster_name is None:
|
| 28 |
+
args.poster_name = args.poster_path.split('/')[-1].replace('.pdf', '').replace(' ', '_')
|
| 29 |
+
|
| 30 |
+
total_input_token, total_output_token = 0, 0
|
| 31 |
+
|
| 32 |
+
# Parse raw content
|
| 33 |
+
input_token, output_token, conv_res = parse_raw(args, actor_config)
|
| 34 |
+
total_input_token += input_token
|
| 35 |
+
total_output_token += output_token
|
| 36 |
+
|
| 37 |
+
# Generate images and tables
|
| 38 |
+
_, _, images, tables = gen_image_and_table(args, conv_res)
|
| 39 |
+
|
| 40 |
+
print()
|
| 41 |
+
print(f'Parsing token consumption: {input_token} -> {output_token}')
|
| 42 |
+
|
| 43 |
+
input_token, output_token = filter_image_table(args, actor_config)
|
| 44 |
+
total_input_token += input_token
|
| 45 |
+
total_output_token += output_token
|
| 46 |
+
print()
|
| 47 |
+
print(f'Filter images and tables token consumption: {input_token} -> {output_token}')
|
| 48 |
+
|
| 49 |
+
input_token, output_token = gen_outline_layout(args, actor_config, critic_config)
|
| 50 |
+
total_input_token += input_token
|
| 51 |
+
total_output_token += output_token
|
| 52 |
+
print()
|
| 53 |
+
print(f'Generate outline and layout token consumption: {input_token} -> {output_token}')
|
| 54 |
+
|
| 55 |
+
input_token, output_token = gen_poster_content(args, actor_config)
|
| 56 |
+
total_input_token += input_token
|
| 57 |
+
total_output_token += output_token
|
| 58 |
+
print()
|
| 59 |
+
print(f'Generate poster content token consumption: {input_token} -> {output_token}')
|
| 60 |
+
|
| 61 |
+
input_token, output_token = fill_poster_content(args, actor_config)
|
| 62 |
+
total_input_token += input_token
|
| 63 |
+
total_output_token += output_token
|
| 64 |
+
print()
|
| 65 |
+
print(f'Fill poster content token consumption: {input_token} -> {output_token}')
|
| 66 |
+
|
| 67 |
+
input_token, output_token = stylize_poster(args, actor_config)
|
| 68 |
+
total_input_token += input_token
|
| 69 |
+
total_output_token += output_token
|
| 70 |
+
print()
|
| 71 |
+
print(f'Stylize poster token consumption: {input_token} -> {output_token}')
|
| 72 |
+
|
| 73 |
+
input_token, output_token = deoverflow(args, actor_config, critic_config)
|
| 74 |
+
total_input_token += input_token
|
| 75 |
+
total_output_token += output_token
|
| 76 |
+
print()
|
| 77 |
+
print(f'Deoverflow token consumption: {input_token} -> {output_token}')
|
| 78 |
+
|
| 79 |
+
if args.template_path is not None:
|
| 80 |
+
input_token, output_token = poster_apply_theme(args, actor_config, critic_config)
|
| 81 |
+
total_input_token += input_token
|
| 82 |
+
total_output_token += output_token
|
| 83 |
+
print()
|
| 84 |
+
print(f'Apply theme token consumption: {input_token} -> {output_token}')
|
| 85 |
+
|
| 86 |
+
print()
|
| 87 |
+
print(f'Total token consumption: {total_input_token} -> {total_output_token}')
|
| 88 |
+
|
| 89 |
+
end_time = time.time()
|
| 90 |
+
elapsed_time = end_time - start_time
|
| 91 |
+
# Convert to hh:mm:ss format
|
| 92 |
+
hours, rem = divmod(elapsed_time, 3600)
|
| 93 |
+
minutes, seconds = divmod(rem, 60)
|
| 94 |
+
|
| 95 |
+
print(f"Execution Time: {int(hours):02}:{int(minutes):02}:{int(seconds):02}")
|
| 96 |
+
|
| 97 |
+
log_path = f'log/{args.model_name}_{args.poster_name}_{args.index}_log.txt'
|
| 98 |
+
with open(log_path, 'w') as f:
|
| 99 |
+
f.write(f'Total token consumption: {total_input_token} -> {total_output_token}\n')
|
| 100 |
+
f.write(f'Execution Time: {int(hours):02}:{int(minutes):02}:{int(seconds):02}\n')
|
| 101 |
+
|
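One practical caveat in the script above: the final log write targets a log/ directory and raises FileNotFoundError if it does not exist. A small guard (an addition for illustration, not part of the diff) placed before the open call avoids that:

import os
os.makedirs('log', exist_ok=True)  # ensure the log directory exists before writing the summary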
Paper2Poster/PosterAgent/tree_split_layout.py
ADDED
|
@@ -0,0 +1,750 @@
|
| 1 |
+
from lxml import etree
|
| 2 |
+
import os
|
| 3 |
+
import copy
|
| 4 |
+
import glob
|
| 5 |
+
import numpy as np
|
| 6 |
+
from sklearn.linear_model import LinearRegression, LogisticRegression
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
import matplotlib.patches as patches
|
| 9 |
+
|
| 10 |
+
def parse_xml_with_recovery(xml_file_path):
|
| 11 |
+
parser = etree.XMLParser(recover=True)
|
| 12 |
+
tree = etree.parse(xml_file_path, parser)
|
| 13 |
+
return tree.getroot()
|
| 14 |
+
|
| 15 |
+
def parse_poster_xml(xml_file):
|
| 16 |
+
"""
|
| 17 |
+
Parse an XML describing a single poster layout, e.g.:
|
| 18 |
+
|
| 19 |
+
<Poster Width="685" Height="968">
|
| 20 |
+
<Panel left="5" right="160" width="674" height="123">
|
| 21 |
+
<Text>Introduction</Text>
|
| 22 |
+
<Figure left="567" right="178" width="81" height="99" no="1" ... />
|
| 23 |
+
</Panel>
|
| 24 |
+
...
|
| 25 |
+
</Poster>
|
| 26 |
+
|
| 27 |
+
Returns a dict with:
|
| 28 |
+
{
|
| 29 |
+
'poster_width': float,
|
| 30 |
+
'poster_height': float,
|
| 31 |
+
'panels': [
|
| 32 |
+
{
|
| 33 |
+
'x': float,
|
| 34 |
+
'y': float,
|
| 35 |
+
'width': float,
|
| 36 |
+
'height': float,
|
| 37 |
+
'text_blocks': [string, string, ...],
|
| 38 |
+
'figure_blocks': [(fx, fy, fw, fh), ...]
|
| 39 |
+
},
|
| 40 |
+
...
|
| 41 |
+
]
|
| 42 |
+
}
|
| 43 |
+
"""
|
| 44 |
+
root = parse_xml_with_recovery(xml_file)
|
| 45 |
+
|
| 46 |
+
# Poster dimensions
|
| 47 |
+
poster_w = float(root.get("Width", "1"))
|
| 48 |
+
poster_h = float(root.get("Height", "1"))
|
| 49 |
+
|
| 50 |
+
panels_data = []
|
| 51 |
+
|
| 52 |
+
# Iterate <Panel> elements
|
| 53 |
+
for panel_node in root.findall("Panel"):
|
| 54 |
+
x = float(panel_node.get("left", "0"))
|
| 55 |
+
y = float(panel_node.get("right", "0"))
|
| 56 |
+
w = float(panel_node.get("width", "0"))
|
| 57 |
+
h = float(panel_node.get("height", "0"))
|
| 58 |
+
|
| 59 |
+
# Gather text blocks
|
| 60 |
+
text_blocks = []
|
| 61 |
+
for text_node in panel_node.findall("Text"):
|
| 62 |
+
txt = text_node.text or ""
|
| 63 |
+
txt = txt.strip()
|
| 64 |
+
if txt:
|
| 65 |
+
text_blocks.append(txt)
|
| 66 |
+
|
| 67 |
+
# Gather figure blocks
|
| 68 |
+
figure_blocks = []
|
| 69 |
+
for fig_node in panel_node.findall("Figure"):
|
| 70 |
+
fx = float(fig_node.get("left", "0"))
|
| 71 |
+
fy = float(fig_node.get("right", "0"))
|
| 72 |
+
fw = float(fig_node.get("width", "0"))
|
| 73 |
+
fh = float(fig_node.get("height", "0"))
|
| 74 |
+
figure_blocks.append((fx, fy, fw, fh))
|
| 75 |
+
|
| 76 |
+
panel_info = {
|
| 77 |
+
"x": x,
|
| 78 |
+
"y": y,
|
| 79 |
+
"width": w,
|
| 80 |
+
"height": h,
|
| 81 |
+
"text_blocks": text_blocks,
|
| 82 |
+
"figure_blocks": figure_blocks
|
| 83 |
+
}
|
| 84 |
+
panels_data.append(panel_info)
|
| 85 |
+
|
| 86 |
+
return {
|
| 87 |
+
"poster_width": poster_w,
|
| 88 |
+
"poster_height": poster_h,
|
| 89 |
+
"panels": panels_data
|
| 90 |
+
}
|
| 91 |
+
|
| 92 |
+
def compute_panel_attributes(poster_data):
|
| 93 |
+
"""
|
| 94 |
+
Given poster_data, compute:
|
| 95 |
+
- tp: ratio of text length for each panel
|
| 96 |
+
- gp: ratio of figure area for each panel
|
| 97 |
+
- sp: ratio of panel area to total poster area
|
| 98 |
+
- rp: aspect ratio (width / height)
|
| 99 |
+
|
| 100 |
+
Returns a list of dicts, each:
|
| 101 |
+
{
|
| 102 |
+
'tp': float,
|
| 103 |
+
'gp': float,
|
| 104 |
+
'sp': float,
|
| 105 |
+
'rp': float
|
| 106 |
+
}
|
| 107 |
+
"""
|
| 108 |
+
|
| 109 |
+
poster_w = poster_data["poster_width"]
|
| 110 |
+
poster_h = poster_data["poster_height"]
|
| 111 |
+
panels = poster_data["panels"]
|
| 112 |
+
|
| 113 |
+
poster_area = max(poster_w * poster_h, 1.0) # avoid zero
|
| 114 |
+
|
| 115 |
+
# 1) Compute total text length across all panels
|
| 116 |
+
# 2) Compute total figure area across all panels
|
| 117 |
+
total_text_length = 0
|
| 118 |
+
total_figure_area = 0
|
| 119 |
+
|
| 120 |
+
# We'll store partial info about each panel so we don't parse multiple times
|
| 121 |
+
panel_list = []
|
| 122 |
+
for p in panels:
|
| 123 |
+
# Combine all text
|
| 124 |
+
panel_text_joined = " ".join(p["text_blocks"])
|
| 125 |
+
panel_text_len = len(panel_text_joined)
|
| 126 |
+
|
| 127 |
+
# Sum area of figure blocks
|
| 128 |
+
panel_fig_area = 0.0
|
| 129 |
+
for (fx, fy, fw, fh) in p["figure_blocks"]:
|
| 130 |
+
panel_fig_area += (fw * fh)
|
| 131 |
+
|
| 132 |
+
panel_list.append({
|
| 133 |
+
"x": p["x"],
|
| 134 |
+
"y": p["y"],
|
| 135 |
+
"width": p["width"],
|
| 136 |
+
"height": p["height"],
|
| 137 |
+
"text_len": panel_text_len,
|
| 138 |
+
"fig_area": panel_fig_area
|
| 139 |
+
})
|
| 140 |
+
|
| 141 |
+
total_text_length += panel_text_len
|
| 142 |
+
total_figure_area += panel_fig_area
|
| 143 |
+
|
| 144 |
+
# Avoid divide by zero
|
| 145 |
+
if total_text_length < 1:
|
| 146 |
+
total_text_length = 1
|
| 147 |
+
if total_figure_area < 1e-9:
|
| 148 |
+
total_figure_area = 1e-9
|
| 149 |
+
|
| 150 |
+
# 3) Compute attributes
|
| 151 |
+
results = []
|
| 152 |
+
for pinfo in panel_list:
|
| 153 |
+
pw = pinfo["width"]
|
| 154 |
+
ph = pinfo["height"]
|
| 155 |
+
|
| 156 |
+
panel_area = pw * ph
|
| 157 |
+
sp = panel_area / poster_area # fraction of total area
|
| 158 |
+
rp = (pw / ph) if ph > 0 else 1.0
|
| 159 |
+
|
| 160 |
+
tp = pinfo["text_len"] / float(total_text_length)
|
| 161 |
+
gp = pinfo["fig_area"] / float(total_figure_area)
|
| 162 |
+
|
| 163 |
+
results.append({
|
| 164 |
+
"tp": tp,
|
| 165 |
+
"gp": gp,
|
| 166 |
+
"sp": sp,
|
| 167 |
+
"rp": rp
|
| 168 |
+
})
|
| 169 |
+
|
| 170 |
+
return results
|
| 171 |
+
|
| 172 |
+
def train_panel_attribute_inference(panel_records):
|
| 173 |
+
"""
|
| 174 |
+
The training data `panel_records` is a list of dicts, each containing:
|
| 175 |
+
{
|
| 176 |
+
'tp': float,
|
| 177 |
+
'gp': float,
|
| 178 |
+
'sp': float, # (label for the sp regression)
|
| 179 |
+
'rp': float # (label for the rp regression)
|
| 180 |
+
}
|
| 181 |
+
|
| 182 |
+
We'll train two linear regressors:
|
| 183 |
+
sp = w_s * [tp, gp, 1]
|
| 184 |
+
rp = w_r * [tp, gp, 1]
|
| 185 |
+
|
| 186 |
+
Returns dict with learned parameters:
|
| 187 |
+
{
|
| 188 |
+
'w_s': array, # shape (3,) => sp = w_s[0]*tp + w_s[1]*gp + w_s[2]
|
| 189 |
+
'sigma_s': float, # variance of residual for sp
|
| 190 |
+
'w_r': array,
|
| 191 |
+
'sigma_r': float
|
| 192 |
+
}
|
| 193 |
+
"""
|
| 194 |
+
# Build data arrays
|
| 195 |
+
X_list = []
|
| 196 |
+
sp_list = []
|
| 197 |
+
rp_list = []
|
| 198 |
+
|
| 199 |
+
for rec in panel_records:
|
| 200 |
+
tp = rec['tp']
|
| 201 |
+
gp = rec['gp']
|
| 202 |
+
sp = rec['sp']
|
| 203 |
+
rp = rec['rp']
|
| 204 |
+
# X = [tp, gp, 1]
|
| 205 |
+
X_list.append([tp, gp, 1.0])
|
| 206 |
+
sp_list.append(sp)
|
| 207 |
+
rp_list.append(rp)
|
| 208 |
+
|
| 209 |
+
X_array = np.array(X_list, dtype=float)
|
| 210 |
+
y_sp = np.array(sp_list, dtype=float)
|
| 211 |
+
y_rp = np.array(rp_list, dtype=float)
|
| 212 |
+
|
| 213 |
+
# Fit linear regression for sp
|
| 214 |
+
linreg_sp = LinearRegression(fit_intercept=False)
|
| 215 |
+
linreg_sp.fit(X_array, y_sp)
|
| 216 |
+
w_s = linreg_sp.coef_
|
| 217 |
+
pred_sp = linreg_sp.predict(X_array)
|
| 218 |
+
residual_sp = y_sp - pred_sp
|
| 219 |
+
sigma_s = np.var(residual_sp, ddof=1)
|
| 220 |
+
|
| 221 |
+
# Fit linear regression for rp
|
| 222 |
+
linreg_rp = LinearRegression(fit_intercept=False)
|
| 223 |
+
linreg_rp.fit(X_array, y_rp)
|
| 224 |
+
w_r = linreg_rp.coef_
|
| 225 |
+
pred_rp = linreg_rp.predict(X_array)
|
| 226 |
+
residual_rp = y_rp - pred_rp
|
| 227 |
+
sigma_r = np.var(residual_rp, ddof=1)
|
| 228 |
+
|
| 229 |
+
model_params = {
|
| 230 |
+
"w_s": w_s,
|
| 231 |
+
"sigma_s": sigma_s,
|
| 232 |
+
"w_r": w_r,
|
| 233 |
+
"sigma_r": sigma_r
|
| 234 |
+
}
|
| 235 |
+
return model_params
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def parse_poster_xml_for_figures(xml_path):
|
| 239 |
+
root = parse_xml_with_recovery(xml_path)
|
| 240 |
+
|
| 241 |
+
poster_w = float(root.get("Width", "1"))
|
| 242 |
+
poster_h = float(root.get("Height", "1"))
|
| 243 |
+
poster_area = poster_w * poster_h
|
| 244 |
+
|
| 245 |
+
records = []
|
| 246 |
+
|
| 247 |
+
for panel in root.findall("Panel"):
|
| 248 |
+
px, py = float(panel.get("left", 0)), float(panel.get("right", 0))
|
| 249 |
+
pw, ph = float(panel.get("width", 1)), float(panel.get("height", 1))
|
| 250 |
+
panel_area = pw * ph
|
| 251 |
+
sp = panel_area / poster_area
|
| 252 |
+
rp = pw / ph if ph > 0 else 1.0
|
| 253 |
+
|
| 254 |
+
lp = sum(len(t.text.strip()) for t in panel.findall("Text") if t.text)
|
| 255 |
+
|
| 256 |
+
for fig in panel.findall("Figure"):
|
| 257 |
+
fx, fy = float(fig.get("left", 0)), float(fig.get("right", 0))
|
| 258 |
+
fw, fh = float(fig.get("width", 1)), float(fig.get("height", 1))
|
| 259 |
+
|
| 260 |
+
sg = (fw * fh) / poster_area
|
| 261 |
+
rg = fw / fh if fh > 0 else 1.0
|
| 262 |
+
ug = fw / pw if pw > 0 else 0.1
|
| 263 |
+
|
| 264 |
+
panel_center_x = px + pw / 2
|
| 265 |
+
fig_center_x = fx + fw / 2
|
| 266 |
+
delta_x = fig_center_x - panel_center_x
|
| 267 |
+
|
| 268 |
+
hg = 0 if delta_x < -pw / 6 else (2 if delta_x > pw / 6 else 1)
|
| 269 |
+
|
| 270 |
+
record = {"sp": sp, "rp": rp, "lp": lp, "sg": sg, "rg": rg, "hg": hg, "ug": ug}
|
| 271 |
+
records.append(record)
|
| 272 |
+
|
| 273 |
+
return records
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def train_figure_model(figure_records):
|
| 277 |
+
X_hg, y_hg, X_ug, y_ug = [], [], [], []
|
| 278 |
+
for r in figure_records:
|
| 279 |
+
feats = [r["sp"], r["lp"], r["sg"], 1.0]
|
| 280 |
+
X_hg.append(feats)
|
| 281 |
+
y_hg.append(r["hg"])
|
| 282 |
+
X_ug.append(feats)
|
| 283 |
+
y_ug.append(r["ug"])
|
| 284 |
+
|
| 285 |
+
clf_hg = LogisticRegression(multi_class="multinomial", solver="lbfgs", fit_intercept=False)
|
| 286 |
+
clf_hg.fit(X_hg, y_hg)
|
| 287 |
+
|
| 288 |
+
lin_ug = LinearRegression(fit_intercept=False)
|
| 289 |
+
lin_ug.fit(X_ug, y_ug)
|
| 290 |
+
residuals = y_ug - lin_ug.predict(X_ug)
|
| 291 |
+
sigma_u = np.var(residuals, ddof=1)
|
| 292 |
+
|
| 293 |
+
return {
|
| 294 |
+
"clf_hg": clf_hg,
|
| 295 |
+
"w_u": lin_ug.coef_,
|
| 296 |
+
"sigma_u": sigma_u
|
| 297 |
+
}
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
def main_train():
|
| 301 |
+
poster_dataset_path = 'assets/poster_data/Train'
|
| 302 |
+
# loop through all folders in the dataset
|
| 303 |
+
xml_files = []
|
| 304 |
+
for folder in os.listdir(poster_dataset_path):
|
| 305 |
+
folder_path = os.path.join(poster_dataset_path, folder)
|
| 306 |
+
if os.path.isdir(folder_path):
|
| 307 |
+
# find all layout XML files (stored with a .txt extension) in this folder
|
| 308 |
+
xml_files.extend(glob.glob(os.path.join(folder_path, "*.txt")))
|
| 309 |
+
|
| 310 |
+
all_panel_records = []
|
| 311 |
+
for xml_file in xml_files:
|
| 312 |
+
poster_data = parse_poster_xml(xml_file)
|
| 313 |
+
# compute tp, gp, sp, rp
|
| 314 |
+
panel_attrs = compute_panel_attributes(poster_data)
|
| 315 |
+
# each panel_attrs entry is {tp, gp, sp, rp}
|
| 316 |
+
all_panel_records.extend(panel_attrs)
|
| 317 |
+
|
| 318 |
+
all_figure_records = []
|
| 319 |
+
for xml_path in xml_files:
|
| 320 |
+
recs = parse_poster_xml_for_figures(xml_path)
|
| 321 |
+
all_figure_records.extend(recs)
|
| 322 |
+
|
| 323 |
+
panel_model_params = train_panel_attribute_inference(all_panel_records)
|
| 324 |
+
figure_model_params = train_figure_model(all_figure_records)
|
| 325 |
+
|
| 326 |
+
return panel_model_params, figure_model_params
|
| 327 |
+
|
| 328 |
+
def place_text_and_figures_exact(panel_dict, figure_model_params, section_title_height=32):
|
| 329 |
+
"""
|
| 330 |
+
Lay out text and figure boxes inside a panel.
|
| 331 |
+
|
| 332 |
+
The figure’s aspect ratio (width / height) is now enforced strictly:
|
| 333 |
+
• width ≤ panel width
|
| 334 |
+
• height ≤ 0.60 × panel height (empirical upper bound)
|
| 335 |
+
• width / height == panel_dict["figure_aspect"]
|
| 336 |
+
"""
|
| 337 |
+
# ---------------- Constants used for text layout -----------------
|
| 338 |
+
char_width_px = 7
|
| 339 |
+
line_height_px = 16
|
| 340 |
+
chars_per_line = max(int(panel_dict["width"] / char_width_px), 1)
|
| 341 |
+
|
| 342 |
+
total_lines_text = np.ceil(panel_dict["text_len"] / chars_per_line)
|
| 343 |
+
total_text_height = total_lines_text * line_height_px
|
| 344 |
+
|
| 345 |
+
x_p, y_p = panel_dict["x"], panel_dict["y"]
|
| 346 |
+
w_p, h_p = panel_dict["width"], panel_dict["height"]
|
| 347 |
+
|
| 348 |
+
figure_boxes, text_boxes = [], []
|
| 349 |
+
|
| 350 |
+
panel_name_lower = panel_dict["panel_name"].lower()
|
| 351 |
+
has_title_in_name = "title" in panel_name_lower
|
| 352 |
+
|
| 353 |
+
# -------------------------------------------------------
|
| 354 |
+
# Helper to build a text‑box dict
|
| 355 |
+
# -------------------------------------------------------
|
| 356 |
+
def make_text_box(panel_id, x, y, width, height, textbox_id, textbox_name):
|
| 357 |
+
return {
|
| 358 |
+
"panel_id": panel_id,
|
| 359 |
+
"x": float(x),
|
| 360 |
+
"y": float(y),
|
| 361 |
+
"width": float(width),
|
| 362 |
+
"height": float(height),
|
| 363 |
+
"textbox_id": textbox_id,
|
| 364 |
+
"textbox_name": textbox_name,
|
| 365 |
+
}
|
| 366 |
+
|
| 367 |
+
# -----------------------------------------------------------------------
|
| 368 |
+
# Case 1 — no figure in this panel
|
| 369 |
+
# -----------------------------------------------------------------------
|
| 370 |
+
if panel_dict["figure_size"] <= 0:
|
| 371 |
+
if has_title_in_name:
|
| 372 |
+
text_boxes.append(
|
| 373 |
+
make_text_box(panel_dict["panel_id"], x_p, y_p, w_p, h_p,
|
| 374 |
+
textbox_id=0,
|
| 375 |
+
textbox_name=f'p<{panel_dict["panel_name"]}>_t0')
|
| 376 |
+
)
|
| 377 |
+
else:
|
| 378 |
+
title_h = min(section_title_height, h_p)
|
| 379 |
+
text_boxes.extend([
|
| 380 |
+
make_text_box(panel_dict["panel_id"], x_p, y_p, w_p, title_h,
|
| 381 |
+
textbox_id=0,
|
| 382 |
+
textbox_name=f'p<{panel_dict["panel_name"]}>_t0'),
|
| 383 |
+
make_text_box(panel_dict["panel_id"], x_p, y_p + title_h, w_p, h_p - title_h,
|
| 384 |
+
textbox_id=1,
|
| 385 |
+
textbox_name=f'p<{panel_dict["panel_name"]}>_t1'),
|
| 386 |
+
])
|
| 387 |
+
return text_boxes, figure_boxes # early‑return (simpler branch)
|
| 388 |
+
|
| 389 |
+
# -----------------------------------------------------------------------
|
| 390 |
+
# Case 2 — there *is* a figure
|
| 391 |
+
# -----------------------------------------------------------------------
|
| 392 |
+
# 1. Sample horizontal‑alignment class (hg) and raw width fraction (ug)
|
| 393 |
+
feat = np.array([panel_dict["sp"],
|
| 394 |
+
panel_dict["text_len"],
|
| 395 |
+
panel_dict["figure_size"],
|
| 396 |
+
1.0]).reshape(1, -1)
|
| 397 |
+
|
| 398 |
+
clf_hg = figure_model_params["clf_hg"]
|
| 399 |
+
hg_sample = int(np.argmax(clf_hg.predict_proba(feat)[0]))
|
| 400 |
+
|
| 401 |
+
mean_ug = float(np.dot(figure_model_params["w_u"], feat.flatten()))
|
| 402 |
+
sigma_u = float(np.sqrt(figure_model_params["sigma_u"]))
|
| 403 |
+
ug_sample = float(np.clip(np.random.normal(mean_ug, sigma_u), 0.10, 0.80)) # 10‑80 % of width
|
| 404 |
+
|
| 405 |
+
# 2. **Size the figure while *preserving* aspect ratio**
|
| 406 |
+
aspect = float(panel_dict["figure_aspect"]) # width / height
|
| 407 |
+
fig_w = ug_sample * w_p # preliminary width
|
| 408 |
+
fig_h = fig_w / aspect
|
| 409 |
+
|
| 410 |
+
max_fig_h = 0.60 * h_p # same limit you had
|
| 411 |
+
if fig_h > max_fig_h: # too tall → scale down
|
| 412 |
+
scale = max_fig_h / fig_h
|
| 413 |
+
fig_w *= scale
|
| 414 |
+
fig_h = max_fig_h # (ratio still intact)
|
| 415 |
+
|
| 416 |
+
# 3. Horizontal placement
|
| 417 |
+
if hg_sample == 0: # left
|
| 418 |
+
fig_x = x_p
|
| 419 |
+
elif hg_sample == 2: # right
|
| 420 |
+
fig_x = x_p + w_p - fig_w
|
| 421 |
+
else: # center
|
| 422 |
+
fig_x = x_p + 0.5 * (w_p - fig_w)
|
| 423 |
+
# Vertical centering
|
| 424 |
+
fig_y = y_p + 0.5 * (h_p - fig_h)
|
| 425 |
+
|
| 426 |
+
# 4. Split text into “top” and “bottom” areas around the figure
|
| 427 |
+
top_text_h = (fig_y - y_p)
|
| 428 |
+
bottom_text_h = (y_p + h_p) - (fig_y + fig_h)
|
| 429 |
+
|
| 430 |
+
# --- build top‑text boxes
|
| 431 |
+
if has_title_in_name:
|
| 432 |
+
text_boxes.append(
|
| 433 |
+
make_text_box(panel_dict["panel_id"], x_p, y_p, w_p, top_text_h,
|
| 434 |
+
textbox_id=0,
|
| 435 |
+
textbox_name=f'p<{panel_dict["panel_name"]}>_t0')
|
| 436 |
+
)
|
| 437 |
+
next_id = 1
|
| 438 |
+
else:
|
| 439 |
+
title_h = min(section_title_height, top_text_h)
|
| 440 |
+
text_boxes.extend([
|
| 441 |
+
make_text_box(panel_dict["panel_id"], x_p, y_p, w_p, title_h,
|
| 442 |
+
textbox_id=0,
|
| 443 |
+
textbox_name=f'p<{panel_dict["panel_name"]}>_t0'),
|
| 444 |
+
make_text_box(panel_dict["panel_id"], x_p, y_p + title_h, w_p, top_text_h - title_h,
|
| 445 |
+
textbox_id=1,
|
| 446 |
+
textbox_name=f'p<{panel_dict["panel_name"]}>_t1'),
|
| 447 |
+
])
|
| 448 |
+
next_id = 2
|
| 449 |
+
|
| 450 |
+
# --- bottom text box
|
| 451 |
+
text_boxes.append(
|
| 452 |
+
make_text_box(panel_dict["panel_id"], x_p, fig_y + fig_h, w_p, bottom_text_h,
|
| 453 |
+
textbox_id=next_id,
|
| 454 |
+
textbox_name=f'p<{panel_dict["panel_name"]}>_t{next_id}')
|
| 455 |
+
)
|
| 456 |
+
|
| 457 |
+
# 5. Figure box
|
| 458 |
+
figure_boxes.append({
|
| 459 |
+
"panel_id": panel_dict["panel_id"],
|
| 460 |
+
"x": float(fig_x),
|
| 461 |
+
"y": float(fig_y),
|
| 462 |
+
"width": float(fig_w),
|
| 463 |
+
"height": float(fig_h),
|
| 464 |
+
"figure_id": 0,
|
| 465 |
+
"figure_name": f'p<{panel_dict["panel_name"]}>_f0',
|
| 466 |
+
})
|
| 467 |
+
|
| 468 |
+
return text_boxes, figure_boxes
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
def to_inches(value_in_units, units_per_inch=72):
|
| 472 |
+
"""
|
| 473 |
+
Convert a single coordinate or dimension from 'units' to inches.
|
| 474 |
+
For example, if your units are 'points' (72 points = 1 inch),
|
| 475 |
+
then units_per_inch=72.
|
| 476 |
+
If your units are 'pixels' at 96 DPI, then units_per_inch=96.
|
| 477 |
+
"""
|
| 478 |
+
return value_in_units / units_per_inch
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
def from_inches(value_in_inches, units_per_inch=72):
|
| 482 |
+
"""
|
| 483 |
+
Convert from inches back to the original 'units'.
|
| 484 |
+
"""
|
| 485 |
+
return value_in_inches * units_per_inch
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
def softmax(logits):
|
| 489 |
+
s = sum(np.exp(logits))
|
| 490 |
+
return [np.exp(l)/s for l in logits]
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
def infer_panel_attrs(panel_model, tp, gp):
|
| 494 |
+
# sp = w_s dot [tp, gp, 1]
|
| 495 |
+
# rp = w_r dot [tp, gp, 1]
|
| 496 |
+
vec = np.array([tp, gp, 1.0])
|
| 497 |
+
w_s = panel_model["w_s"]
|
| 498 |
+
w_r = panel_model["w_r"]
|
| 499 |
+
sp = np.dot(w_s, vec)
|
| 500 |
+
rp = np.dot(w_r, vec)
|
| 501 |
+
# clamp
|
| 502 |
+
sp = max(sp, 0.01)
|
| 503 |
+
rp = max(rp, 0.05)
|
| 504 |
+
return sp, rp
|
| 505 |
+
|
| 506 |
+
|
| 507 |
+
def panel_layout_generation(panels, x, y, w, h):
|
| 508 |
+
# If only 1 panel, place it entirely
|
| 509 |
+
if len(panels) == 1:
|
| 510 |
+
p = panels[0]
|
| 511 |
+
cur_rp = (w/h) if h>1e-9 else p["rp"]
|
| 512 |
+
loss = abs(p["rp"] - cur_rp)
|
| 513 |
+
arrangement = [{
|
| 514 |
+
"panel_name": p["section_name"],
|
| 515 |
+
"panel_id": p["panel_id"],
|
| 516 |
+
"x": x, "y": y,
|
| 517 |
+
"width": w, "height": h
|
| 518 |
+
}]
|
| 519 |
+
return loss, arrangement
|
| 520 |
+
|
| 521 |
+
best_loss = float('inf')
|
| 522 |
+
best_arr = []
|
| 523 |
+
total_sp = sum(pp["sp"] for pp in panels)
|
| 524 |
+
n = len(panels)
|
| 525 |
+
|
| 526 |
+
for i in range(1, n):
|
| 527 |
+
subset1 = panels[:i]
|
| 528 |
+
subset2 = panels[i:]
|
| 529 |
+
sp1 = sum(pp["sp"] for pp in subset1)
|
| 530 |
+
ratio = sp1 / total_sp
|
| 531 |
+
|
| 532 |
+
# horizontal
|
| 533 |
+
h_top = ratio * h
|
| 534 |
+
if 0 < h_top < h:
|
| 535 |
+
l1, a1 = panel_layout_generation(subset1, x, y, w, h_top)
|
| 536 |
+
l2, a2 = panel_layout_generation(subset2, x, y + h_top, w, h - h_top)
|
| 537 |
+
if (l1 + l2) < best_loss:
|
| 538 |
+
best_loss = l1 + l2
|
| 539 |
+
best_arr = a1 + a2
|
| 540 |
+
|
| 541 |
+
# vertical
|
| 542 |
+
w_left = ratio * w
|
| 543 |
+
if 0 < w_left < w:
|
| 544 |
+
l1, a1 = panel_layout_generation(subset1, x, y, w_left, h)
|
| 545 |
+
l2, a2 = panel_layout_generation(subset2, x + w_left, y, w - w_left, h)
|
| 546 |
+
if (l1 + l2) < best_loss:
|
| 547 |
+
best_loss = l1 + l2
|
| 548 |
+
best_arr = a1 + a2
|
| 549 |
+
|
| 550 |
+
return best_loss, best_arr
|
| 551 |
+
|
| 552 |
+
def split_textbox(textbox, ratio):
|
| 553 |
+
"""
|
| 554 |
+
Splits a textbox dictionary horizontally into two parts.
|
| 555 |
+
|
| 556 |
+
Parameters:
|
| 557 |
+
textbox (dict): A dictionary with the keys
|
| 558 |
+
'panel_id', 'x', 'y', 'width', 'height', 'textbox_id', 'textbox_name'
|
| 559 |
+
ratio (float or int): Ratio of top height to bottom height.
|
| 560 |
+
For example, if ratio is 3, then:
|
| 561 |
+
top_height = (3/4) * height
|
| 562 |
+
bottom_height = (1/4) * height
|
| 563 |
+
|
| 564 |
+
Returns:
|
| 565 |
+
tuple: Two dictionaries corresponding to the top and bottom split textboxes.
|
| 566 |
+
"""
|
| 567 |
+
# Calculate the new heights
|
| 568 |
+
total_ratio = ratio + 1 # because the ratio represents top:bottom as (ratio):(1)
|
| 569 |
+
top_height = textbox['height'] * ratio / total_ratio
|
| 570 |
+
bottom_height = textbox['height'] * 1 / total_ratio
|
| 571 |
+
|
| 572 |
+
# Derive the base textbox name by splitting off the existing _t suffix if present.
|
| 573 |
+
# This assumes the original textbox_name ends with "_t<number>".
|
| 574 |
+
base_name = textbox['textbox_name'].rsplit('_t', 1)[0]
|
| 575 |
+
|
| 576 |
+
# Create the top textbox dictionary
|
| 577 |
+
top_box = dict(textbox) # make a shallow copy
|
| 578 |
+
top_box['height'] = top_height
|
| 579 |
+
# y remains the same for the top textbox
|
| 580 |
+
top_box['textbox_name'] = f"{base_name}_t0" # rename with _t0
|
| 581 |
+
|
| 582 |
+
# Create the bottom textbox dictionary
|
| 583 |
+
bottom_box = dict(textbox) # make a shallow copy
|
| 584 |
+
bottom_box['y'] = textbox['y'] + top_height # adjust the y position
|
| 585 |
+
bottom_box['height'] = bottom_height
|
| 586 |
+
bottom_box['textbox_name'] = f"{base_name}_t1" # rename with _t1
|
| 587 |
+
|
| 588 |
+
return top_box, bottom_box
|
| 589 |
+
|
| 590 |
+
def generate_constrained_layout(paper_panels, poster_w, poster_h, title_height_ratio=0.1):
|
| 591 |
+
# Find title panel explicitly
|
| 592 |
+
try:
|
| 593 |
+
title_panel = next(p for p in paper_panels if ('title' in p["section_name"].lower()))
|
| 594 |
+
other_panels = [p for p in paper_panels if ('title' not in p["section_name"].lower())]
|
| 595 |
+
except StopIteration:
|
| 596 |
+
print('Oops, no title found, please try again.')
|
| 597 |
+
raise
|
| 598 |
+
|
| 599 |
+
title_h = poster_h * title_height_ratio
|
| 600 |
+
title_layout = {
|
| 601 |
+
"panel_name": title_panel["section_name"],
|
| 602 |
+
"panel_id": title_panel["panel_id"],
|
| 603 |
+
"x": 0, "y": 0,
|
| 604 |
+
"width": poster_w, "height": title_h
|
| 605 |
+
}
|
| 606 |
+
|
| 607 |
+
# Generate recursive layout on remaining space for other panels
|
| 608 |
+
layout_loss, remaining_layout = panel_layout_generation(
|
| 609 |
+
other_panels,
|
| 610 |
+
x=0, y=title_h,
|
| 611 |
+
w=poster_w, h=poster_h - title_h
|
| 612 |
+
)
|
| 613 |
+
|
| 614 |
+
# Combine title panel with others
|
| 615 |
+
complete_layout = [title_layout] + remaining_layout
|
| 616 |
+
return layout_loss, complete_layout
|
| 617 |
+
|
| 618 |
+
|
| 619 |
+
def main_inference(
|
| 620 |
+
paper_panels,
|
| 621 |
+
panel_model_params,
|
| 622 |
+
figure_model_params,
|
| 623 |
+
poster_width=1200,
|
| 624 |
+
poster_height=800,
|
| 625 |
+
shrink_margin=0
|
| 626 |
+
):
|
| 627 |
+
for p in paper_panels:
|
| 628 |
+
sp, rp = infer_panel_attrs(panel_model_params, p["tp"], p["gp"])
|
| 629 |
+
p["sp"] = sp
|
| 630 |
+
p["rp"] = rp
|
| 631 |
+
|
| 632 |
+
layout_loss, panel_arrangement = generate_constrained_layout(paper_panels, poster_width, poster_height, title_height_ratio=0.1)
|
| 633 |
+
print("Panel layout cost:", layout_loss)
|
| 634 |
+
for p in panel_arrangement:
|
| 635 |
+
print("Panel:", p)
|
| 636 |
+
|
| 637 |
+
panel_map = {}
|
| 638 |
+
for p in paper_panels:
|
| 639 |
+
panel_map[p["panel_id"]] = p
|
| 640 |
+
|
| 641 |
+
final_panels = []
|
| 642 |
+
for pa in panel_arrangement:
|
| 643 |
+
# Merge bounding box with the original sp,rp data
|
| 644 |
+
pid = pa["panel_id"]
|
| 645 |
+
merged_panel = {
|
| 646 |
+
"panel_id": pid,
|
| 647 |
+
"panel_name": pa['panel_name'],
|
| 648 |
+
"x": pa["x"] + shrink_margin,
|
| 649 |
+
"y": pa["y"] + shrink_margin,
|
| 650 |
+
"width": pa["width"] - 2 * shrink_margin,
|
| 651 |
+
"height": pa["height"] - 2 * shrink_margin,
|
| 652 |
+
"sp": panel_map[pid]["sp"],
|
| 653 |
+
"rp": panel_map[pid]["rp"],
|
| 654 |
+
"text_len": panel_map[pid]["text_len"],
|
| 655 |
+
"figure_size": panel_map[pid]["figure_size"],
|
| 656 |
+
"figure_aspect": panel_map[pid]["figure_aspect"]
|
| 657 |
+
}
|
| 658 |
+
final_panels.append(merged_panel)
|
| 659 |
+
|
| 660 |
+
text_arrangement = []
|
| 661 |
+
figure_arrangement = []
|
| 662 |
+
|
| 663 |
+
for p in final_panels:
|
| 664 |
+
text_boxes, fig_boxes = place_text_and_figures_exact(p, figure_model_params)
|
| 665 |
+
text_arrangement.extend(text_boxes) # text arrangement
|
| 666 |
+
figure_arrangement.extend(fig_boxes) # figure arrangement
|
| 667 |
+
|
| 668 |
+
return panel_arrangement, figure_arrangement, text_arrangement
|
| 669 |
+
|
| 670 |
+
def visualize_complete_layout(
|
| 671 |
+
panels, text_boxes, figure_boxes, poster_width, poster_height
|
| 672 |
+
):
|
| 673 |
+
fig, ax = plt.subplots(figsize=(12,8))
|
| 674 |
+
ax.set_xlim(0, poster_width)
|
| 675 |
+
ax.set_ylim(0, poster_height)
|
| 676 |
+
ax.set_aspect('equal')
|
| 677 |
+
|
| 678 |
+
# Draw panels
|
| 679 |
+
for panel in panels:
|
| 680 |
+
rect = patches.Rectangle(
|
| 681 |
+
(panel["x"], panel["y"]), panel["width"], panel["height"],
|
| 682 |
+
linewidth=1, edgecolor='black', facecolor='none'
|
| 683 |
+
)
|
| 684 |
+
ax.add_patch(rect)
|
| 685 |
+
ax.text(
|
| 686 |
+
panel["x"] + 5, panel["y"] + panel["height"] - 5,
|
| 687 |
+
f'Panel {panel["panel_id"]}', fontsize=8, va='top', color='black'
|
| 688 |
+
)
|
| 689 |
+
|
| 690 |
+
# Draw text boxes
|
| 691 |
+
for txt in text_boxes:
|
| 692 |
+
rect = patches.Rectangle(
|
| 693 |
+
(txt["x"], txt["y"]), txt["width"], txt["height"],
|
| 694 |
+
linewidth=1, edgecolor='green', linestyle='-.', facecolor='none'
|
| 695 |
+
)
|
| 696 |
+
ax.add_patch(rect)
|
| 697 |
+
ax.text(
|
| 698 |
+
txt["x"] + 2, txt["y"] + txt["height"] - 2,
|
| 699 |
+
f'Text {txt["panel_id"]}', fontsize=7, color='green', va='top'
|
| 700 |
+
)
|
| 701 |
+
|
| 702 |
+
# Draw figures
|
| 703 |
+
for fig_box in figure_boxes:
|
| 704 |
+
rect = patches.Rectangle(
|
| 705 |
+
(fig_box["x"], fig_box["y"]), fig_box["width"], fig_box["height"],
|
| 706 |
+
linewidth=1, edgecolor='blue', linestyle='--', facecolor='none'
|
| 707 |
+
)
|
| 708 |
+
ax.add_patch(rect)
|
| 709 |
+
ax.text(
|
| 710 |
+
fig_box["x"] + 2, fig_box["y"] + 2,
|
| 711 |
+
f'Fig {fig_box["panel_id"]}', fontsize=7, color='blue', va='bottom'
|
| 712 |
+
)
|
| 713 |
+
|
| 714 |
+
plt.gca().invert_yaxis() # optional: invert y-axis if needed
|
| 715 |
+
plt.show()
|
| 716 |
+
|
| 717 |
+
|
| 718 |
+
def get_arrangments_in_inches(
|
| 719 |
+
width,
|
| 720 |
+
height,
|
| 721 |
+
panel_arrangement,
|
| 722 |
+
figure_arrangement,
|
| 723 |
+
text_arrangement,
|
| 724 |
+
units_per_inch=72
|
| 725 |
+
):
|
| 726 |
+
|
| 727 |
+
panel_arrangement_inches = copy.deepcopy(panel_arrangement)
|
| 728 |
+
figure_arrangement_inches = copy.deepcopy(figure_arrangement)
|
| 729 |
+
text_arrangement_inches = copy.deepcopy(text_arrangement)
|
| 730 |
+
|
| 731 |
+
for p in panel_arrangement_inches:
|
| 732 |
+
p["x"] = to_inches(p["x"], units_per_inch)
|
| 733 |
+
p["y"] = to_inches(p["y"], units_per_inch)
|
| 734 |
+
p["width"] = to_inches(p["width"], units_per_inch)
|
| 735 |
+
p["height"] = to_inches(p["height"], units_per_inch)
|
| 736 |
+
|
| 737 |
+
for f in figure_arrangement_inches:
|
| 738 |
+
f["x"] = to_inches(f["x"], units_per_inch)
|
| 739 |
+
f["y"] = to_inches(f["y"], units_per_inch)
|
| 740 |
+
f["width"] = to_inches(f["width"], units_per_inch)
|
| 741 |
+
f["height"] = to_inches(f["height"], units_per_inch)
|
| 742 |
+
|
| 743 |
+
for t in text_arrangement_inches:
|
| 744 |
+
t["x"] = to_inches(t["x"], units_per_inch)
|
| 745 |
+
t["y"] = to_inches(t["y"], units_per_inch)
|
| 746 |
+
t["width"] = to_inches(t["width"], units_per_inch)
|
| 747 |
+
t["height"] = to_inches(t["height"], units_per_inch)
|
| 748 |
+
|
| 749 |
+
width_inch, height_inch = to_inches(width, units_per_inch), to_inches(height, units_per_inch)
|
| 750 |
+
return width_inch, height_inch, panel_arrangement_inches, figure_arrangement_inches, text_arrangement_inches
|
Paper2Poster/README.md
ADDED
|
@@ -0,0 +1,215 @@
| 1 |
+
# 🎓Paper2Poster: Multimodal Poster Automation from Scientific Papers
|
| 2 |
+
# Automatically Generate Academic Posters from Scientific Papers
|
| 3 |
+
|
| 4 |
+
<p align="center">
|
| 5 |
+
<a href="https://arxiv.org/abs/2505.21497" target="_blank"><img src="https://img.shields.io/badge/arXiv-2505.21497-red"></a>
|
| 6 |
+
<a href="https://paper2poster.github.io/" target="_blank"><img src="https://img.shields.io/badge/Project-Page-brightgreen"></a>
|
| 7 |
+
<a href="https://huggingface.co/datasets/Paper2Poster/Paper2Poster" target="_blank"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Dataset-orange"></a>
|
| 8 |
+
<a href="https://huggingface.co/papers/2505.21497" target="_blank"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Daily Papers-red"></a>
|
| 9 |
+
<a href="https://x.com/_akhaliq/status/1927721150584390129" target="_blank"><img alt="X (formerly Twitter) URL" src="https://img.shields.io/twitter/url?url=https%3A%2F%2Fx.com%2F_akhaliq%2Fstatus%2F1927721150584390129"></a>
|
| 10 |
+
</p>
|
| 11 |
+
|
| 12 |
+
We address **how to create a poster from a paper** and **how to evaluate a poster**.
|
| 13 |
+
|
| 14 |
+

|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
## 🔥 Update
|
| 18 |
+
- [x] [2025.10.13] Added automatic **logo support** for conferences and institutions, **YAML-based style customization**, and a new default theme.
|
| 19 |
+
- [x] [2025.9.18] Paper2Poster has been accepted to the **NeurIPS 2025 Datasets and Benchmarks Track**.
|
| 20 |
+
- [x] [2025.9.3] We now support generating per-section content in **parallel** for faster generation, by simply specifying `--max_workers`.
|
| 21 |
+
- [x] [2025.5.27] We released the [arXiv](https://arxiv.org/abs/2505.21497), [code](https://github.com/Paper2Poster/Paper2Poster) and [`dataset`](https://huggingface.co/datasets/Paper2Poster/Paper2Poster).
|
| 22 |
+
|
| 23 |
+
<!--## 📚 Introduction-->
|
| 24 |
+
|
| 25 |
+
**PosterAgent** is a top-down, visual-in-the-loop multi-agent system that turns `paper.pdf` into an **editable** `poster.pptx`.
|
| 26 |
+
|
| 27 |
+

|
| 28 |
+
|
| 29 |
+
<!--A Top-down, visual-in-the-loop, efficient multi-agent pipeline, which includes (a) Parser distills the paper into a structured asset library; the (b) Planner aligns text–visual pairs into a binary‐tree layout that preserves reading order and spatial balance; and the (c) Painter-Commentor loop refines each panel by executing rendering code and using VLM feedback to eliminate overflow and ensure alignment.-->
|
| 30 |
+
|
| 31 |
+
<!---->
|
| 32 |
+
|
| 33 |
+
<!--**Paper2Poster:** A benchmark for paper to poster generation, paired with human generated poster, with a comprehensive evaluation suite, including metrics like **Visual Quality**, **Textual Coherence**, **VLM-as-Judge** and **PaperQuiz**. Notably, PaperQuiz is a novel evaluation which assume A Good poster should convey core paper content visually.-->
|
| 34 |
+
|
| 35 |
+
## 📋 Table of Contents
|
| 36 |
+
|
| 37 |
+
<!--- [📚 Introduction](#-introduction)-->
|
| 38 |
+
- [🛠️ Installation](#-installation)
|
| 39 |
+
- [🚀 Quick Start](#-quick-start)
|
| 40 |
+
- [🔮 Evaluation](#-evaluation)
|
| 41 |
+
---
|
| 42 |
+
|
| 43 |
+
## 🛠️ Installation
|
| 44 |
+
Our Paper2Poster supports both local deployment (via [vLLM](https://docs.vllm.ai/en/v0.6.6/getting_started/installation.html)) and API-based access (e.g., GPT-4o).
|
| 45 |
+
|
| 46 |
+
**Python Environment**
|
| 47 |
+
```bash
|
| 48 |
+
pip install -r requirements.txt
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
**Install LibreOffice**
|
| 52 |
+
```bash
|
| 53 |
+
sudo apt install libreoffice
|
| 54 |
+
```
|
| 55 |
+
|
| 56 |
+
Or, if you do **not** have sudo access, download the `soffice` executable directly from https://www.libreoffice.org/download/download-libreoffice/ and add the executable directory to your `$PATH`.
|
| 57 |
+
|
| 58 |
+
**Install poppler**
|
| 59 |
+
```bash
|
| 60 |
+
conda install -c conda-forge poppler
|
| 61 |
+
```
|
| 62 |
+
|
| 63 |
+
**API Key**
|
| 64 |
+
|
| 65 |
+
Create a `.env` file in the project root and add your OpenAI API key:
|
| 66 |
+
|
| 67 |
+
```bash
|
| 68 |
+
OPENAI_API_KEY=<your_openai_api_key>
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
**Optional: Google Search API (for logo search)**
|
| 72 |
+
|
| 73 |
+
To use Google Custom Search for more reliable logo search, add these to your `.env` file:
|
| 74 |
+
|
| 75 |
+
```bash
|
| 76 |
+
GOOGLE_SEARCH_API_KEY=<your_google_search_api_key>
|
| 77 |
+
GOOGLE_SEARCH_ENGINE_ID=<your_search_engine_id>
|
| 78 |
+
```
|
| 79 |
+
|
| 80 |
+
---
|
| 81 |
+
|
| 82 |
+
## 🚀 Quick Start
|
| 83 |
+
Create a folder named `{paper_name}` under `{dataset_dir}`, and place your paper inside it as a PDF file named `paper.pdf`.
|
| 84 |
+
```
|
| 85 |
+
📁 {dataset_dir}/
|
| 86 |
+
└── 📁 {paper_name}/
|
| 87 |
+
└── 📄 paper.pdf
|
| 88 |
+
```
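For example, a minimal setup for a single paper might look like the sketch below (the directory and file names are placeholders):

```bash
# Hypothetical example: prepare the expected folder layout for one paper
dataset_dir="Paper2Poster-data"   # any directory works for generation
paper_name="my_paper"             # any folder name works
mkdir -p "${dataset_dir}/${paper_name}"
cp ~/Downloads/my_paper.pdf "${dataset_dir}/${paper_name}/paper.pdf"
```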
|
| 89 |
+
To use open-source models, first deploy them with [vLLM](https://docs.vllm.ai/en/v0.6.6/getting_started/installation.html), making sure the port is correctly specified in the `get_agent_config()` function in [`utils/wei_utils.py`](utils/wei_utils.py).
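As a rough sketch (the model name and port below are assumptions; use whatever port `get_agent_config()` expects), an OpenAI-compatible vLLM server can be started with:

```bash
# Hypothetical example: serve Qwen-2.5-7B-Instruct behind an OpenAI-compatible API
# (adjust the model and --port to match get_agent_config() in utils/wei_utils.py)
vllm serve Qwen/Qwen2.5-7B-Instruct --port 8000
```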
|
| 90 |
+
|
| 91 |
+
- [High Performance] Generate a poster with `GPT-4o`:
|
| 92 |
+
|
| 93 |
+
```bash
|
| 94 |
+
python -m PosterAgent.new_pipeline \
|
| 95 |
+
--poster_path="${dataset_dir}/${paper_name}/paper.pdf" \
|
| 96 |
+
--model_name_t="4o" \ # LLM
|
| 97 |
+
--model_name_v="4o" \ # VLM
|
| 98 |
+
--poster_width_inches=48 \
|
| 99 |
+
--poster_height_inches=36
|
| 100 |
+
```
|
| 101 |
+
|
| 102 |
+
- [Economic] Generate a poster with `Qwen-2.5-7B-Instruct` and `GPT-4o`:
|
| 103 |
+
|
| 104 |
+
```bash
|
| 105 |
+
python -m PosterAgent.new_pipeline \
|
| 106 |
+
--poster_path="${dataset_dir}/${paper_name}/paper.pdf" \
|
| 107 |
+
--model_name_t="vllm_qwen" \ # LLM
|
| 108 |
+
--model_name_v="4o" \ # VLM
|
| 109 |
+
--poster_width_inches=48 \
|
| 110 |
+
--poster_height_inches=36 \
|
| 111 |
+
--no_blank_detection # An option to disable blank detection
|
| 112 |
+
```
|
| 113 |
+
|
| 114 |
+
- [Local] Generate a poster with `Qwen-2.5-7B-Instruct`:
|
| 115 |
+
|
| 116 |
+
```bash
|
| 117 |
+
python -m PosterAgent.new_pipeline \
|
| 118 |
+
--poster_path="${dataset_dir}/${paper_name}/paper.pdf" \
|
| 119 |
+
--model_name_t="vllm_qwen" \ # LLM
|
| 120 |
+
--model_name_v="vllm_qwen_vl" \ # VLM
|
| 121 |
+
--poster_width_inches=48 \
|
| 122 |
+
--poster_height_inches=36
|
| 123 |
+
```
|
| 124 |
+
|
| 125 |
+
PosterAgent **supports flexible LLM / VLM combinations**; feel free to try other options or customize your own settings in `get_agent_config()` in [`utils/wei_utils.py`](utils/wei_utils.py).
|
| 126 |
+
|
| 127 |
+
### Adding Logos to Posters
|
| 128 |
+
|
| 129 |
+
You can automatically add institutional and conference logos to your posters:
|
| 130 |
+
|
| 131 |
+
```bash
|
| 132 |
+
python -m PosterAgent.new_pipeline \
|
| 133 |
+
--poster_path="${dataset_dir}/${paper_name}/paper.pdf" \
|
| 134 |
+
--model_name_t="4o" \
|
| 135 |
+
--model_name_v="4o" \
|
| 136 |
+
--poster_width_inches=48 \
|
| 137 |
+
--poster_height_inches=36 \
|
| 138 |
+
--conference_venue="NeurIPS" # Automatically searches for conference logo
|
| 139 |
+
```
|
| 140 |
+
|
| 141 |
+
**Logo Search Strategy:**
|
| 142 |
+
1. **Local search**: First checks the provided logo store (`logo_store/institutes/` and `logo_store/conferences/`)
|
| 143 |
+
2. **Web search**: If no logo is found locally, performs an online search
|
| 144 |
+
- By default, uses DuckDuckGo (no API key required)
|
| 145 |
+
- For more reliable results, use `--use_google_search` (requires `GOOGLE_SEARCH_API_KEY` and `GOOGLE_SEARCH_ENGINE_ID` in `.env`)
|
| 146 |
+
|
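For instance, a run that searches for the venue logo through Google Custom Search could look like the sketch below (all flags appear elsewhere in this README; the venue is just an example):

```bash
python -m PosterAgent.new_pipeline \
    --poster_path="${dataset_dir}/${paper_name}/paper.pdf" \
    --model_name_t="4o" \
    --model_name_v="4o" \
    --poster_width_inches=48 \
    --poster_height_inches=36 \
    --conference_venue="NeurIPS" \
    --use_google_search
```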
| 147 |
+
You can also specify custom logo paths to skip auto-detection:
|
| 148 |
+
```bash
|
| 149 |
+
--institution_logo_path="path/to/institution_logo.png" \
|
| 150 |
+
--conference_logo_path="path/to/conference_logo.png"
|
| 151 |
+
```
|
| 152 |
+
|
| 153 |
+
### YAML Style Customization
|
| 154 |
+
|
| 155 |
+
Customize the poster appearance via YAML configuration files (a minimal example is sketched below):
|
| 156 |
+
- **Global defaults**: `config/poster.yaml` (applies to all posters)
|
| 157 |
+
- **Per-poster override**: Place `poster.yaml` next to your `paper.pdf` for custom styling
|
| 158 |
+
|
| 159 |
+
|
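A minimal per-poster override might look like the sketch below; the YAML keys are purely illustrative assumptions, so consult `config/poster.yaml` for the schema the pipeline actually reads:

```bash
# Hypothetical example: write a per-poster style override next to paper.pdf
# (the YAML keys are placeholders, not the verified schema)
cat > "${dataset_dir}/${paper_name}/poster.yaml" <<'EOF'
theme:
  primary_color: "#1f4e79"
  font_family: "Arial"
EOF
```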
| 160 |
+
## 🔮 Evaluation
|
| 161 |
+
Download Paper2Poster evaluation dataset via:
|
| 162 |
+
```bash
|
| 163 |
+
python -m PosterAgent.create_dataset
|
| 164 |
+
```
|
| 165 |
+
|
| 166 |
+
During evaluation, papers are stored under a directory called `Paper2Poster-data`.
|
| 167 |
+
|
| 168 |
+
To evaluate a generated poster with **PaperQuiz**:
|
| 169 |
+
```bash
|
| 170 |
+
python -m Paper2Poster-eval.eval_poster_pipeline \
|
| 171 |
+
--paper_name="${paper_name}" \
|
| 172 |
+
--poster_method="${model_t}_${model_v}_generated_posters" \
|
| 173 |
+
--metric=qa # PaperQuiz
|
| 174 |
+
```
|
| 175 |
+
|
| 176 |
+
To evaluate a generated poster with **VLM-as-Judge**:
|
| 177 |
+
```bash
|
| 178 |
+
python -m Paper2Poster-eval.eval_poster_pipeline \
|
| 179 |
+
--paper_name="${paper_name}" \
|
| 180 |
+
--poster_method="${model_t}_${model_v}_generated_posters" \
|
| 181 |
+
--metric=judge # VLM-as-Judge
|
| 182 |
+
```
|
| 183 |
+
|
| 184 |
+
To evaluate a generated poster with other statistical metrics (such as visual similarity, PPL, etc.):
|
| 185 |
+
```bash
|
| 186 |
+
python -m Paper2Poster-eval.eval_poster_pipeline \
|
| 187 |
+
--paper_name="${paper_name}" \
|
| 188 |
+
--poster_method="${model_t}_${model_v}_generated_posters" \
|
| 189 |
+
--metric=stats # statistical measures
|
| 190 |
+
```
|
| 191 |
+
|
| 192 |
+
If you want to create a PaperQuiz for your own paper:
|
| 193 |
+
```bash
|
| 194 |
+
python -m Paper2Poster-eval.create_paper_questions \
|
| 195 |
+
--paper_folder="Paper2Poster-data/${paper_name}"
|
| 196 |
+
```
|
| 197 |
+
|
| 198 |
+
## ❤ Acknowledgement
|
| 199 |
+
We extend our gratitude to [🐫CAMEL](https://github.com/camel-ai/camel), [🦉OWL](https://github.com/camel-ai/owl), [Docling](https://github.com/docling-project/docling), [PPTAgent](https://github.com/icip-cas/PPTAgent) for providing their codebases.
|
| 200 |
+
|
| 201 |
+
## 📖 Citation
|
| 202 |
+
|
| 203 |
+
Please kindly cite our paper if you find this project helpful.
|
| 204 |
+
|
| 205 |
+
```bibtex
|
| 206 |
+
@misc{paper2poster,
|
| 207 |
+
title={Paper2Poster: Towards Multimodal Poster Automation from Scientific Papers},
|
| 208 |
+
author={Wei Pang and Kevin Qinghong Lin and Xiangru Jian and Xi He and Philip Torr},
|
| 209 |
+
year={2025},
|
| 210 |
+
eprint={2505.21497},
|
| 211 |
+
archivePrefix={arXiv},
|
| 212 |
+
primaryClass={cs.CV},
|
| 213 |
+
url={https://arxiv.org/abs/2505.21497},
|
| 214 |
+
}
|
| 215 |
+
```
|
Paper2Poster/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
| 1 |
+
import sys, os
|
| 2 |
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "camel"))
|
| 3 |
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "docling"))
|
Paper2Poster/camel/__init__.py
ADDED
|
@@ -0,0 +1,25 @@
| 1 |
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
#
|
| 6 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 7 |
+
#
|
| 8 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 9 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 10 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 11 |
+
# See the License for the specific language governing permissions and
|
| 12 |
+
# limitations under the License.
|
| 13 |
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
| 14 |
+
|
| 15 |
+
from camel.logger import disable_logging, enable_logging, set_log_level
|
| 16 |
+
|
| 17 |
+
__version__ = '0.2.19'
|
| 18 |
+
|
| 19 |
+
__all__ = [
|
| 20 |
+
'__version__',
|
| 21 |
+
'camel',
|
| 22 |
+
'disable_logging',
|
| 23 |
+
'enable_logging',
|
| 24 |
+
'set_log_level',
|
| 25 |
+
]
|
Paper2Poster/camel/agents/__init__.py
ADDED
|
@@ -0,0 +1,44 @@
| 1 |
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
#
|
| 6 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 7 |
+
#
|
| 8 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 9 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 10 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 11 |
+
# See the License for the specific language governing permissions and
|
| 12 |
+
# limitations under the License.
|
| 13 |
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
| 14 |
+
from .base import BaseAgent
|
| 15 |
+
from .chat_agent import ChatAgent
|
| 16 |
+
from .critic_agent import CriticAgent
|
| 17 |
+
from .embodied_agent import EmbodiedAgent
|
| 18 |
+
from .knowledge_graph_agent import KnowledgeGraphAgent
|
| 19 |
+
from .role_assignment_agent import RoleAssignmentAgent
|
| 20 |
+
from .search_agent import SearchAgent
|
| 21 |
+
from .task_agent import (
|
| 22 |
+
TaskCreationAgent,
|
| 23 |
+
TaskPlannerAgent,
|
| 24 |
+
TaskPrioritizationAgent,
|
| 25 |
+
TaskSpecifyAgent,
|
| 26 |
+
)
|
| 27 |
+
from .tool_agents.base import BaseToolAgent
|
| 28 |
+
from .tool_agents.hugging_face_tool_agent import HuggingFaceToolAgent
|
| 29 |
+
|
| 30 |
+
__all__ = [
|
| 31 |
+
'BaseAgent',
|
| 32 |
+
'ChatAgent',
|
| 33 |
+
'TaskSpecifyAgent',
|
| 34 |
+
'TaskPlannerAgent',
|
| 35 |
+
'TaskCreationAgent',
|
| 36 |
+
'TaskPrioritizationAgent',
|
| 37 |
+
'CriticAgent',
|
| 38 |
+
'BaseToolAgent',
|
| 39 |
+
'HuggingFaceToolAgent',
|
| 40 |
+
'EmbodiedAgent',
|
| 41 |
+
'RoleAssignmentAgent',
|
| 42 |
+
'SearchAgent',
|
| 43 |
+
'KnowledgeGraphAgent',
|
| 44 |
+
]
|
Paper2Poster/camel/agents/base.py
ADDED
|
@@ -0,0 +1,29 @@
| 1 |
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
#
|
| 6 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 7 |
+
#
|
| 8 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 9 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 10 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 11 |
+
# See the License for the specific language governing permissions and
|
| 12 |
+
# limitations under the License.
|
| 13 |
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
| 14 |
+
from abc import ABC, abstractmethod
|
| 15 |
+
from typing import Any
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class BaseAgent(ABC):
|
| 19 |
+
r"""An abstract base class for all CAMEL agents."""
|
| 20 |
+
|
| 21 |
+
@abstractmethod
|
| 22 |
+
def reset(self, *args: Any, **kwargs: Any) -> Any:
|
| 23 |
+
r"""Resets the agent to its initial state."""
|
| 24 |
+
pass
|
| 25 |
+
|
| 26 |
+
@abstractmethod
|
| 27 |
+
def step(self, *args: Any, **kwargs: Any) -> Any:
|
| 28 |
+
r"""Performs a single step of the agent."""
|
| 29 |
+
pass
|
Paper2Poster/camel/agents/chat_agent.py
ADDED
|
@@ -0,0 +1,1539 @@
| 1 |
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
#
|
| 6 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 7 |
+
#
|
| 8 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 9 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 10 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 11 |
+
# See the License for the specific language governing permissions and
|
| 12 |
+
# limitations under the License.
|
| 13 |
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import json
|
| 17 |
+
import logging
|
| 18 |
+
import re
|
| 19 |
+
import uuid
|
| 20 |
+
from collections import defaultdict
|
| 21 |
+
from typing import (
|
| 22 |
+
TYPE_CHECKING,
|
| 23 |
+
Any,
|
| 24 |
+
Callable,
|
| 25 |
+
Dict,
|
| 26 |
+
List,
|
| 27 |
+
Optional,
|
| 28 |
+
Tuple,
|
| 29 |
+
Type,
|
| 30 |
+
Union,
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
from openai.types.chat import ChatCompletionMessageToolCall
|
| 34 |
+
from openai.types.chat.chat_completion_message_tool_call import Function
|
| 35 |
+
from pydantic import BaseModel, ValidationError
|
| 36 |
+
|
| 37 |
+
from camel.agents.base import BaseAgent
|
| 38 |
+
from camel.memories import (
|
| 39 |
+
AgentMemory,
|
| 40 |
+
ChatHistoryMemory,
|
| 41 |
+
MemoryRecord,
|
| 42 |
+
ScoreBasedContextCreator,
|
| 43 |
+
)
|
| 44 |
+
from camel.messages import BaseMessage, FunctionCallingMessage, OpenAIMessage
|
| 45 |
+
from camel.models import (
|
| 46 |
+
BaseModelBackend,
|
| 47 |
+
ModelFactory,
|
| 48 |
+
ModelManager,
|
| 49 |
+
ModelProcessingError,
|
| 50 |
+
)
|
| 51 |
+
from camel.responses import ChatAgentResponse
|
| 52 |
+
from camel.types import (
|
| 53 |
+
ChatCompletion,
|
| 54 |
+
ChatCompletionChunk,
|
| 55 |
+
ModelPlatformType,
|
| 56 |
+
ModelType,
|
| 57 |
+
OpenAIBackendRole,
|
| 58 |
+
RoleType,
|
| 59 |
+
)
|
| 60 |
+
from camel.utils import (
|
| 61 |
+
func_string_to_callable,
|
| 62 |
+
generate_prompt_for_structured_output,
|
| 63 |
+
get_model_encoding,
|
| 64 |
+
get_pydantic_object_schema,
|
| 65 |
+
json_to_function_code,
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
if TYPE_CHECKING:
|
| 69 |
+
from openai import Stream
|
| 70 |
+
|
| 71 |
+
from camel.terminators import ResponseTerminator
|
| 72 |
+
from camel.toolkits import FunctionTool
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
logger = logging.getLogger(__name__)
|
| 76 |
+
|
| 77 |
+
# AgentOps decorator setting
|
| 78 |
+
try:
|
| 79 |
+
import os
|
| 80 |
+
|
| 81 |
+
if os.getenv("AGENTOPS_API_KEY") is not None:
|
| 82 |
+
from agentops import track_agent
|
| 83 |
+
else:
|
| 84 |
+
raise ImportError
|
| 85 |
+
except (ImportError, AttributeError):
|
| 86 |
+
from camel.utils import track_agent
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class FunctionCallingRecord(BaseModel):
|
| 90 |
+
r"""Historical records of functions called in the conversation.
|
| 91 |
+
|
| 92 |
+
Attributes:
|
| 93 |
+
func_name (str): The name of the function being called.
|
| 94 |
+
args (Dict[str, Any]): The dictionary of arguments passed to
|
| 95 |
+
the function.
|
| 96 |
+
result (Any): The execution result of calling this function.
|
| 97 |
+
tool_call_id (str): The ID of the tool call, if available.
|
| 98 |
+
"""
|
| 99 |
+
|
| 100 |
+
func_name: str
|
| 101 |
+
args: Dict[str, Any]
|
| 102 |
+
result: Any
|
| 103 |
+
tool_call_id: str
|
| 104 |
+
|
| 105 |
+
def __str__(self) -> str:
|
| 106 |
+
r"""Overridden version of the string function.
|
| 107 |
+
|
| 108 |
+
Returns:
|
| 109 |
+
str: Modified string to represent the function calling.
|
| 110 |
+
"""
|
| 111 |
+
return (
|
| 112 |
+
f"Function Execution: {self.func_name}\n"
|
| 113 |
+
f"\tArgs: {self.args}\n"
|
| 114 |
+
f"\tResult: {self.result}\n"
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
def as_dict(self) -> dict[str, Any]:
|
| 118 |
+
r"""Returns the function calling record as a dictionary.
|
| 119 |
+
|
| 120 |
+
Returns:
|
| 121 |
+
dict[str, Any]: The function calling record as a dictionary.
|
| 122 |
+
"""
|
| 123 |
+
return self.model_dump()
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
@track_agent(name="ChatAgent")
|
| 127 |
+
class ChatAgent(BaseAgent):
|
| 128 |
+
r"""Class for managing conversations of CAMEL Chat Agents.
|
| 129 |
+
|
| 130 |
+
Args:
|
| 131 |
+
system_message (Union[BaseMessage, str], optional): The system message
|
| 132 |
+
for the chat agent.
|
| 133 |
+
model (BaseModelBackend, optional): The model backend to use for
|
| 134 |
+
generating responses. (default: :obj:`ModelPlatformType.DEFAULT`
|
| 135 |
+
with `ModelType.DEFAULT`)
|
| 136 |
+
memory (AgentMemory, optional): The agent memory for managing chat
|
| 137 |
+
messages. If `None`, a :obj:`ChatHistoryMemory` will be used.
|
| 138 |
+
(default: :obj:`None`)
|
| 139 |
+
message_window_size (int, optional): The maximum number of previous
|
| 140 |
+
messages to include in the context window. If `None`, no windowing
|
| 141 |
+
is performed. (default: :obj:`None`)
|
| 142 |
+
token_limit (int, optional): The maximum number of tokens in a context.
|
| 143 |
+
The context will be automatically pruned to satisfy this limit.
|
| 144 |
+
If `None`, it will be set according to the backend model.
|
| 145 |
+
(default: :obj:`None`)
|
| 146 |
+
output_language (str, optional): The language to be output by the
|
| 147 |
+
agent. (default: :obj:`None`)
|
| 148 |
+
tools (Optional[List[Union[FunctionTool, Callable]]], optional): List
|
| 149 |
+
of available :obj:`FunctionTool` or :obj:`Callable`. (default:
|
| 150 |
+
:obj:`None`)
|
| 151 |
+
external_tools (Optional[List[Union[FunctionTool, Callable]]],
|
| 152 |
+
optional): List of external tools (:obj:`FunctionTool` or
|
| 153 |
+
:obj:`Callable`) bind to one chat agent. When these tools are
|
| 154 |
+
called, the agent will directly return the request instead of
|
| 155 |
+
processing it. (default: :obj:`None`)
|
| 156 |
+
response_terminators (List[ResponseTerminator], optional): List of
|
| 157 |
+
:obj:`ResponseTerminator` bind to one chat agent.
|
| 158 |
+
(default: :obj:`None`)
|
| 159 |
+
scheduling_strategy (str): name of function that defines how to select
|
| 160 |
+
the next model in ModelManager. (default: :str:`round_robin`)
|
| 161 |
+
single_iteration (bool): Whether to let the agent perform only one
|
| 162 |
+
model calling at each step. (default: :obj:`False`)
|
| 163 |
+
"""
|
| 164 |
+
|
| 165 |
+
def __init__(
|
| 166 |
+
self,
|
| 167 |
+
system_message: Optional[Union[BaseMessage, str]] = None,
|
| 168 |
+
model: Optional[
|
| 169 |
+
Union[BaseModelBackend, List[BaseModelBackend]]
|
| 170 |
+
] = None,
|
| 171 |
+
memory: Optional[AgentMemory] = None,
|
| 172 |
+
message_window_size: Optional[int] = None,
|
| 173 |
+
token_limit: Optional[int] = None,
|
| 174 |
+
output_language: Optional[str] = None,
|
| 175 |
+
tools: Optional[List[Union[FunctionTool, Callable]]] = None,
|
| 176 |
+
external_tools: Optional[List[Union[FunctionTool, Callable]]] = None,
|
| 177 |
+
response_terminators: Optional[List[ResponseTerminator]] = None,
|
| 178 |
+
scheduling_strategy: str = "round_robin",
|
| 179 |
+
single_iteration: bool = False,
|
| 180 |
+
) -> None:
|
| 181 |
+
# Initialize the system message, converting string to BaseMessage if needed
|
| 182 |
+
if isinstance(system_message, str):
|
| 183 |
+
system_message = BaseMessage.make_assistant_message(
|
| 184 |
+
role_name='Assistant', content=system_message
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
self.orig_sys_message: Optional[BaseMessage] = system_message
|
| 188 |
+
self._system_message: Optional[BaseMessage] = system_message
|
| 189 |
+
self.role_name: str = (
|
| 190 |
+
getattr(system_message, 'role_name', None) or "assistant"
|
| 191 |
+
)
|
| 192 |
+
self.role_type: RoleType = (
|
| 193 |
+
getattr(system_message, 'role_type', None) or RoleType.ASSISTANT
|
| 194 |
+
)
|
| 195 |
+
self.model_backend = ModelManager(
|
| 196 |
+
model
|
| 197 |
+
if model is not None
|
| 198 |
+
else ModelFactory.create(
|
| 199 |
+
model_platform=ModelPlatformType.DEFAULT,
|
| 200 |
+
model_type=ModelType.DEFAULT,
|
| 201 |
+
),
|
| 202 |
+
scheduling_strategy=scheduling_strategy,
|
| 203 |
+
)
|
| 204 |
+
self.model_type = self.model_backend.model_type
|
| 205 |
+
|
| 206 |
+
# Initialize tools
|
| 207 |
+
self.tools: List[FunctionTool] = (
|
| 208 |
+
self._initialize_tools(tools) if tools else []
|
| 209 |
+
)
|
| 210 |
+
self.external_tools: List[FunctionTool] = (
|
| 211 |
+
self._initialize_tools(external_tools) if external_tools else []
|
| 212 |
+
)
|
| 213 |
+
self.external_tool_names: List[str] = [
|
| 214 |
+
tool.get_function_name() for tool in self.external_tools
|
| 215 |
+
]
|
| 216 |
+
self.all_tools = self.tools + self.external_tools or []
|
| 217 |
+
|
| 218 |
+
# Create tool dictionaries and configure backend tools if necessary
|
| 219 |
+
self.tool_dict = {
|
| 220 |
+
tool.get_function_name(): tool for tool in self.all_tools
|
| 221 |
+
}
|
| 222 |
+
|
| 223 |
+
# If the user set tools from `ChatAgent`, it will override the
|
| 224 |
+
# configured tools in `BaseModelBackend`.
|
| 225 |
+
if self.all_tools:
|
| 226 |
+
logger.warning(
|
| 227 |
+
"Overriding the configured tools in `BaseModelBackend` with the tools from `ChatAgent`."
|
| 228 |
+
)
|
| 229 |
+
tool_schema_list = [
|
| 230 |
+
tool.get_openai_tool_schema() for tool in self.all_tools
|
| 231 |
+
]
|
| 232 |
+
self.model_backend.model_config_dict['tools'] = tool_schema_list
|
| 233 |
+
|
| 234 |
+
self.model_token_limit = token_limit or self.model_backend.token_limit
|
| 235 |
+
context_creator = ScoreBasedContextCreator(
|
| 236 |
+
self.model_backend.token_counter,
|
| 237 |
+
self.model_token_limit,
|
| 238 |
+
)
|
| 239 |
+
self.memory: AgentMemory = memory or ChatHistoryMemory(
|
| 240 |
+
context_creator, window_size=message_window_size
|
| 241 |
+
)
|
| 242 |
+
|
| 243 |
+
self.output_language: Optional[str] = output_language
|
| 244 |
+
if self.output_language is not None:
|
| 245 |
+
self.set_output_language(self.output_language)
|
| 246 |
+
|
| 247 |
+
self.terminated: bool = False
|
| 248 |
+
self.response_terminators = response_terminators or []
|
| 249 |
+
self.init_messages()
|
| 250 |
+
self.tool_prompt_added = False
|
| 251 |
+
self.single_iteration = single_iteration
|
| 252 |
+
|
| 253 |
+
def _initialize_tools(
|
| 254 |
+
self, tools: List[Union[FunctionTool, Callable]]
|
| 255 |
+
) -> List[FunctionTool]:
|
| 256 |
+
r"""Helper method to initialize tools as FunctionTool instances."""
|
| 257 |
+
from camel.toolkits import FunctionTool
|
| 258 |
+
|
| 259 |
+
func_tools = []
|
| 260 |
+
for tool in tools:
|
| 261 |
+
if not isinstance(tool, FunctionTool):
|
| 262 |
+
tool = FunctionTool(tool)
|
| 263 |
+
func_tools.append(tool)
|
| 264 |
+
return func_tools
|
| 265 |
+
|
| 266 |
+
def add_tool(
|
| 267 |
+
self, tool: Union[FunctionTool, Callable], is_external: bool = False
|
| 268 |
+
) -> None:
|
| 269 |
+
r"""Add a tool to the agent, specifying if it's an external tool."""
|
| 270 |
+
# Initialize the tool
|
| 271 |
+
initialized_tool = self._initialize_tools([tool])
|
| 272 |
+
|
| 273 |
+
# Update tools or external tools based on is_external flag
|
| 274 |
+
if is_external:
|
| 275 |
+
self.external_tools = self.external_tools + initialized_tool
|
| 276 |
+
self.external_tool_names.extend(
|
| 277 |
+
tool.get_function_name() for tool in initialized_tool
|
| 278 |
+
)
|
| 279 |
+
else:
|
| 280 |
+
self.tools = self.tools + initialized_tool
|
| 281 |
+
|
| 282 |
+
# Rebuild all_tools, and tool_dict
|
| 283 |
+
self.all_tools = self.tools + self.external_tools
|
| 284 |
+
self.tool_dict = {
|
| 285 |
+
tool.get_function_name(): tool for tool in self.all_tools
|
| 286 |
+
}
|
| 287 |
+
|
| 288 |
+
tool_schema_list = [
|
| 289 |
+
tool.get_openai_tool_schema() for tool in self.all_tools
|
| 290 |
+
]
|
| 291 |
+
self.model_backend.model_config_dict['tools'] = tool_schema_list
|
| 292 |
+
|
| 293 |
+
def remove_tool(self, tool_name: str, is_external: bool = False) -> bool:
|
| 294 |
+
r"""Remove a tool by name, specifying if it's an external tool."""
|
| 295 |
+
tool_list = self.external_tools if is_external else self.tools
|
| 296 |
+
if not tool_list:
|
| 297 |
+
return False
|
| 298 |
+
|
| 299 |
+
for tool in tool_list:
|
| 300 |
+
if tool.get_function_name() == tool_name:
|
| 301 |
+
tool_list.remove(tool)
|
| 302 |
+
if is_external:
|
| 303 |
+
self.external_tool_names.remove(tool_name)
|
| 304 |
+
# Reinitialize the tool dictionary
|
| 305 |
+
self.all_tools = (self.tools or []) + (
|
| 306 |
+
self.external_tools or []
|
| 307 |
+
)
|
| 308 |
+
self.tool_dict = {
|
| 309 |
+
tool.get_function_name(): tool for tool in self.all_tools
|
| 310 |
+
}
|
| 311 |
+
tool_schema_list = [
|
| 312 |
+
tool.get_openai_tool_schema() for tool in self.all_tools
|
| 313 |
+
]
|
| 314 |
+
self.model_backend.model_config_dict['tools'] = (
|
| 315 |
+
tool_schema_list
|
| 316 |
+
)
|
| 317 |
+
return True
|
| 318 |
+
return False
|
| 319 |
+
|
| 320 |
+
def list_tools(self) -> dict:
|
| 321 |
+
r"""List all tools, separated into normal and external tools."""
|
| 322 |
+
normal_tools = [
|
| 323 |
+
tool.get_function_name() for tool in (self.tools or [])
|
| 324 |
+
]
|
| 325 |
+
external_tools = [
|
| 326 |
+
tool.get_function_name() for tool in (self.external_tools or [])
|
| 327 |
+
]
|
| 328 |
+
|
| 329 |
+
return {"normal_tools": normal_tools, "external_tools": external_tools}
|
| 330 |
+
|
| 331 |
+
# ruff: noqa: E501
|
| 332 |
+
def _generate_tool_prompt(self, tool_schema_list: List[Dict]) -> str:
|
| 333 |
+
r"""Generates a tool prompt based on the provided tool schema list.
|
| 334 |
+
|
| 335 |
+
Args:
|
| 336 |
+
tool_schema_list (List[Dict]): A list of dictionaries, each
|
| 337 |
+
containing a tool schema.
|
| 338 |
+
|
| 339 |
+
Returns:
|
| 340 |
+
str: A string representing the tool prompt.
|
| 341 |
+
"""
|
| 342 |
+
tool_prompts = []
|
| 343 |
+
|
| 344 |
+
for tool in tool_schema_list:
|
| 345 |
+
tool_info = tool['function']
|
| 346 |
+
tool_name = tool_info['name']
|
| 347 |
+
tool_description = tool_info['description']
|
| 348 |
+
tool_json = json.dumps(tool_info, indent=4)
|
| 349 |
+
|
| 350 |
+
prompt = f"Use the function '{tool_name}' to '{tool_description}':\n{tool_json}\n"
|
| 351 |
+
tool_prompts.append(prompt)
|
| 352 |
+
|
| 353 |
+
tool_prompt_str = "\n".join(tool_prompts)
|
| 354 |
+
|
| 355 |
+
final_prompt = f"""
|
| 356 |
+
You have access to the following functions:
|
| 357 |
+
|
| 358 |
+
{tool_prompt_str}
|
| 359 |
+
|
| 360 |
+
If you choose to call a function ONLY reply in the following format with no
|
| 361 |
+
prefix or suffix:
|
| 362 |
+
|
| 363 |
+
<function=example_function_name>{{"example_name": "example_value"}}</function>
|
| 364 |
+
|
| 365 |
+
Reminder:
|
| 366 |
+
- Function calls MUST follow the specified format, start with <function= and end with </function>
|
| 367 |
+
- Required parameters MUST be specified
|
| 368 |
+
- Only call one function at a time
|
| 369 |
+
- Put the entire function call reply on one line
|
| 370 |
+
- If there is no function call available, answer the question like normal
|
| 371 |
+
with your current knowledge and do not tell the user about function calls
|
| 372 |
+
"""
|
| 373 |
+
return final_prompt
|
| 374 |
+
|
| 375 |
+
def _parse_tool_response(self, response: str):
|
| 376 |
+
r"""Parses the tool response to extract the function name and
|
| 377 |
+
arguments.
|
| 378 |
+
|
| 379 |
+
Args:
|
| 380 |
+
response (str): The response from the model containing the
|
| 381 |
+
function call.
|
| 382 |
+
|
| 383 |
+
Returns:
|
| 384 |
+
Optional[Dict[str, Any]]: The parsed function name and arguments
|
| 385 |
+
if found, otherwise :obj:`None`.
|
| 386 |
+
"""
|
| 387 |
+
function_regex = r"<function=(\w+)>(.*?)</function>"
|
| 388 |
+
match = re.search(function_regex, response)
|
| 389 |
+
|
| 390 |
+
if match:
|
| 391 |
+
function_name, args_string = match.groups()
|
| 392 |
+
try:
|
| 393 |
+
args = json.loads(args_string)
|
| 394 |
+
return {"function": function_name, "arguments": args}
|
| 395 |
+
except json.JSONDecodeError as error:
|
| 396 |
+
logger.error(f"Error parsing function arguments: {error}")
|
| 397 |
+
return None
|
| 398 |
+
return None
|
| 399 |
+
|
| 400 |
+
def reset(self):
|
| 401 |
+
r"""Resets the :obj:`ChatAgent` to its initial state."""
|
| 402 |
+
self.terminated = False
|
| 403 |
+
self.init_messages()
|
| 404 |
+
for terminator in self.response_terminators:
|
| 405 |
+
terminator.reset()
|
| 406 |
+
|
| 407 |
+
@property
|
| 408 |
+
def system_message(self) -> Optional[BaseMessage]:
|
| 409 |
+
r"""The getter method for the property :obj:`system_message`.
|
| 410 |
+
|
| 411 |
+
Returns:
|
| 412 |
+
Optional[BaseMessage]: The system message of this agent if set,
|
| 413 |
+
else :obj:`None`.
|
| 414 |
+
"""
|
| 415 |
+
return self._system_message
|
| 416 |
+
|
| 417 |
+
@system_message.setter
|
| 418 |
+
def system_message(self, message: BaseMessage) -> None:
|
| 419 |
+
r"""The setter method for the property :obj:`system_message`.
|
| 420 |
+
|
| 421 |
+
Args:
|
| 422 |
+
message (BaseMessage): The message to be set as the
|
| 423 |
+
new system message of this agent.
|
| 424 |
+
"""
|
| 425 |
+
self._system_message = message
|
| 426 |
+
|
| 427 |
+
def is_tools_added(self) -> bool:
|
| 428 |
+
r"""Whether tool calling is enabled for this agent.
|
| 429 |
+
|
| 430 |
+
Returns:
|
| 431 |
+
bool: Whether tool calling is enabled for this agent, determined
|
| 432 |
+
by whether the dictionary of tools is empty.
|
| 433 |
+
"""
|
| 434 |
+
return len(self.tool_dict) > 0
|
| 435 |
+
|
| 436 |
+
def update_memory(
|
| 437 |
+
self, message: BaseMessage, role: OpenAIBackendRole
|
| 438 |
+
) -> None:
|
| 439 |
+
r"""Updates the agent memory with a new message.
|
| 440 |
+
|
| 441 |
+
Args:
|
| 442 |
+
message (BaseMessage): The new message to add to the stored
|
| 443 |
+
messages.
|
| 444 |
+
role (OpenAIBackendRole): The backend role type.
|
| 445 |
+
"""
|
| 446 |
+
self.memory.write_record(
|
| 447 |
+
MemoryRecord(message=message, role_at_backend=role)
|
| 448 |
+
)
|
| 449 |
+
|
| 450 |
+
def set_output_language(self, output_language: str) -> BaseMessage:
|
| 451 |
+
r"""Sets the output language for the system message. This method
|
| 452 |
+
updates the output language for the system message. The output
|
| 453 |
+
language determines the language in which the output text should be
|
| 454 |
+
generated.
|
| 455 |
+
|
| 456 |
+
Args:
|
| 457 |
+
output_language (str): The desired output language.
|
| 458 |
+
|
| 459 |
+
Returns:
|
| 460 |
+
BaseMessage: The updated system message object.
|
| 461 |
+
"""
|
| 462 |
+
self.output_language = output_language
|
| 463 |
+
language_prompt = (
|
| 464 |
+
"\nRegardless of the input language, "
|
| 465 |
+
f"you must output text in {output_language}."
|
| 466 |
+
)
|
| 467 |
+
if self.orig_sys_message is not None:
|
| 468 |
+
content = self.orig_sys_message.content + language_prompt
|
| 469 |
+
self._system_message = self.orig_sys_message.create_new_instance(
|
| 470 |
+
content
|
| 471 |
+
)
|
| 472 |
+
else:
|
| 473 |
+
self._system_message = BaseMessage.make_assistant_message(
|
| 474 |
+
role_name="Assistant",
|
| 475 |
+
content=language_prompt,
|
| 476 |
+
)
|
| 477 |
+
|
| 478 |
+
system_record = MemoryRecord(
|
| 479 |
+
message=self._system_message,
|
| 480 |
+
role_at_backend=OpenAIBackendRole.SYSTEM,
|
| 481 |
+
)
|
| 482 |
+
self.memory.clear()
|
| 483 |
+
self.memory.write_record(system_record)
|
| 484 |
+
return self._system_message
|
| 485 |
+
|
| 486 |
+
def get_info(
|
| 487 |
+
self,
|
| 488 |
+
session_id: Optional[str],
|
| 489 |
+
usage: Optional[Dict[str, int]],
|
| 490 |
+
termination_reasons: List[str],
|
| 491 |
+
num_tokens: int,
|
| 492 |
+
tool_calls: List[FunctionCallingRecord],
|
| 493 |
+
external_tool_request: Optional[ChatCompletionMessageToolCall] = None,
|
| 494 |
+
) -> Dict[str, Any]:
|
| 495 |
+
r"""Returns a dictionary containing information about the chat session.
|
| 496 |
+
|
| 497 |
+
Args:
|
| 498 |
+
session_id (str, optional): The ID of the chat session.
|
| 499 |
+
usage (Dict[str, int], optional): Information about the usage of
|
| 500 |
+
the LLM.
|
| 501 |
+
termination_reasons (List[str]): The reasons for the termination
|
| 502 |
+
of the chat session.
|
| 503 |
+
num_tokens (int): The number of tokens used in the chat session.
|
| 504 |
+
tool_calls (List[FunctionCallingRecord]): The list of function
|
| 505 |
+
calling records, containing the information of called tools.
|
| 506 |
+
external_tool_request
|
| 507 |
+
(Optional[ChatCompletionMessageToolCall], optional):
|
| 508 |
+
The tool calling request of external tools from the model.
|
| 509 |
+
These requests are directly returned to the user instead of
|
| 510 |
+
being processed by the agent automatically.
|
| 511 |
+
(default: :obj:`None`)
|
| 512 |
+
|
| 513 |
+
Returns:
|
| 514 |
+
Dict[str, Any]: The chat session information.
|
| 515 |
+
"""
|
| 516 |
+
return {
|
| 517 |
+
"id": session_id,
|
| 518 |
+
"usage": usage,
|
| 519 |
+
"termination_reasons": termination_reasons,
|
| 520 |
+
"num_tokens": num_tokens,
|
| 521 |
+
"tool_calls": tool_calls,
|
| 522 |
+
"external_tool_request": external_tool_request,
|
| 523 |
+
}
|
| 524 |
+
|
| 525 |
+
def init_messages(self) -> None:
|
| 526 |
+
r"""Initializes the stored messages list with the current system
|
| 527 |
+
message.
|
| 528 |
+
"""
|
| 529 |
+
if self._system_message is not None:
|
| 530 |
+
system_record = MemoryRecord(
|
| 531 |
+
message=self._system_message,
|
| 532 |
+
role_at_backend=OpenAIBackendRole.SYSTEM,
|
| 533 |
+
)
|
| 534 |
+
self.memory.clear()
|
| 535 |
+
self.memory.write_record(system_record)
|
| 536 |
+
else:
|
| 537 |
+
self.memory.clear()
|
| 538 |
+
|
| 539 |
+
def record_message(self, message: BaseMessage) -> None:
|
| 540 |
+
r"""Records the externally provided message into the agent memory as if
|
| 541 |
+
it were an answer of the :obj:`ChatAgent` from the backend. Currently,
|
| 542 |
+
the choice of the critic is submitted with this method.
|
| 543 |
+
|
| 544 |
+
Args:
|
| 545 |
+
message (BaseMessage): An external message to be recorded in the
|
| 546 |
+
memory.
|
| 547 |
+
"""
|
| 548 |
+
self.update_memory(message, OpenAIBackendRole.ASSISTANT)
|
| 549 |
+
|
| 550 |
+
def step(
|
| 551 |
+
self,
|
| 552 |
+
input_message: Union[BaseMessage, str],
|
| 553 |
+
response_format: Optional[Type[BaseModel]] = None,
|
| 554 |
+
) -> ChatAgentResponse:
|
| 555 |
+
r"""Executes a single step in the chat session, generating a response
|
| 556 |
+
to the input message.
|
| 557 |
+
|
| 558 |
+
Args:
|
| 559 |
+
input_message (Union[BaseMessage, str]): The input message for the
|
| 560 |
+
agent. If provided as a BaseMessage, the `role` is adjusted to
|
| 561 |
+
`user` to indicate an external message.
|
| 562 |
+
response_format (Optional[Type[BaseModel]], optional): A Pydantic
|
| 563 |
+
model defining the expected structure of the response. Used to
|
| 564 |
+
generate a structured response if provided. (default:
|
| 565 |
+
:obj:`None`)
|
| 566 |
+
|
| 567 |
+
Returns:
|
| 568 |
+
ChatAgentResponse: Contains output messages, a termination status
|
| 569 |
+
flag, and session information.
|
| 570 |
+
"""
|
| 571 |
+
|
| 572 |
+
if (
|
| 573 |
+
self.model_backend.model_config_dict.get("response_format")
|
| 574 |
+
and response_format
|
| 575 |
+
):
|
| 576 |
+
raise ValueError(
|
| 577 |
+
"The `response_format` parameter cannot be set both in "
|
| 578 |
+
"the model configuration and in the ChatAgent step."
|
| 579 |
+
)
|
| 580 |
+
|
| 581 |
+
self.original_model_dict = self.model_backend.model_config_dict
|
| 582 |
+
model_response_format_modified = False
|
| 583 |
+
if (
|
| 584 |
+
response_format
|
| 585 |
+
and self.model_type.support_native_structured_output
|
| 586 |
+
):
|
| 587 |
+
self.model_backend.model_config_dict = (
|
| 588 |
+
self.original_model_dict.copy()
|
| 589 |
+
)
|
| 590 |
+
self.model_backend.model_config_dict["response_format"] = (
|
| 591 |
+
response_format
|
| 592 |
+
)
|
| 593 |
+
model_response_format_modified = True
|
| 594 |
+
|
| 595 |
+
# Convert input message to BaseMessage if necessary
|
| 596 |
+
if isinstance(input_message, str):
|
| 597 |
+
input_message = BaseMessage.make_user_message(
|
| 598 |
+
role_name='User', content=input_message
|
| 599 |
+
)
|
| 600 |
+
|
| 601 |
+
# Handle tool prompt injection if needed
|
| 602 |
+
if (
|
| 603 |
+
self.is_tools_added()
|
| 604 |
+
and not self.model_type.support_native_tool_calling
|
| 605 |
+
and not self.tool_prompt_added
|
| 606 |
+
):
|
| 607 |
+
self._inject_tool_prompt()
|
| 608 |
+
|
| 609 |
+
# Add user input to memory
|
| 610 |
+
self.update_memory(input_message, OpenAIBackendRole.USER)
|
| 611 |
+
|
| 612 |
+
try:
|
| 613 |
+
return self._handle_step(response_format, self.single_iteration)
|
| 614 |
+
finally:
|
| 615 |
+
if model_response_format_modified:
|
| 616 |
+
# Reset model config back to original state
|
| 617 |
+
self.model_backend.model_config_dict = self.original_model_dict
|
| 618 |
+
|
| 619 |
+
def _inject_tool_prompt(self) -> None:
|
| 620 |
+
r"""Generate and add the tool prompt to memory."""
|
| 621 |
+
tool_prompt = self._generate_tool_prompt(
|
| 622 |
+
self.model_backend.model_config_dict["tools"]
|
| 623 |
+
)
|
| 624 |
+
tool_msg = BaseMessage.make_assistant_message(
|
| 625 |
+
role_name="Assistant", content=tool_prompt
|
| 626 |
+
)
|
| 627 |
+
self.update_memory(tool_msg, OpenAIBackendRole.SYSTEM)
|
| 628 |
+
self.tool_prompt_added = True
|
| 629 |
+
|
| 630 |
+
def _handle_step(
|
| 631 |
+
self,
|
| 632 |
+
response_format: Optional[Type[BaseModel]],
|
| 633 |
+
single_step: bool,
|
| 634 |
+
) -> ChatAgentResponse:
|
| 635 |
+
r"""Handles a single or multi-step interaction."""
|
| 636 |
+
|
| 637 |
+
if (
|
| 638 |
+
self.model_backend.model_config_dict.get("tool_choice")
|
| 639 |
+
== "required"
|
| 640 |
+
and not single_step
|
| 641 |
+
):
|
| 642 |
+
raise ValueError(
|
| 643 |
+
"`tool_choice` cannot be set to `required` for multi-step"
|
| 644 |
+
" mode. To proceed, set `single_iteration` to `True`."
|
| 645 |
+
)
|
| 646 |
+
|
| 647 |
+
# Record function calls made during the session
|
| 648 |
+
tool_call_records: List[FunctionCallingRecord] = []
|
| 649 |
+
|
| 650 |
+
external_tool_request = None
|
| 651 |
+
|
| 652 |
+
while True:
|
| 653 |
+
try:
|
| 654 |
+
openai_messages, num_tokens = self.memory.get_context()
|
| 655 |
+
except RuntimeError as e:
|
| 656 |
+
self.model_backend.model_config_dict = self.original_model_dict
|
| 657 |
+
return self._step_token_exceed(
|
| 658 |
+
e.args[1], tool_call_records, "max_tokens_exceeded"
|
| 659 |
+
)
|
| 660 |
+
|
| 661 |
+
# Prompt engineering approach for structured output for non-native tool calling models
|
| 662 |
+
inject_prompt_for_structured_output = (
|
| 663 |
+
response_format
|
| 664 |
+
and not self.model_type.support_native_structured_output
|
| 665 |
+
)
|
| 666 |
+
|
| 667 |
+
if inject_prompt_for_structured_output:
|
| 668 |
+
# update last openai message
|
| 669 |
+
usr_msg = openai_messages.pop()
|
| 670 |
+
usr_msg["content"] = generate_prompt_for_structured_output(
|
| 671 |
+
response_format,
|
| 672 |
+
usr_msg["content"], # type: ignore [arg-type]
|
| 673 |
+
)
|
| 674 |
+
openai_messages.append(usr_msg)
|
| 675 |
+
|
| 676 |
+
# Process model response
|
| 677 |
+
(
|
| 678 |
+
response,
|
| 679 |
+
output_messages,
|
| 680 |
+
finish_reasons,
|
| 681 |
+
usage_dict,
|
| 682 |
+
response_id,
|
| 683 |
+
) = self._step_model_response(openai_messages, num_tokens)
|
| 684 |
+
|
| 685 |
+
# Try to parse structured output to return a Pydantic object
|
| 686 |
+
if inject_prompt_for_structured_output and isinstance(
|
| 687 |
+
response, ChatCompletion
|
| 688 |
+
):
|
| 689 |
+
content = response.choices[0].message.content
|
| 690 |
+
try:
|
| 691 |
+
json_content = json.loads(str(content))
|
| 692 |
+
output_messages[0].parsed = response_format(**json_content) # type: ignore [assignment, misc]
|
| 693 |
+
except json.JSONDecodeError as e:
|
| 694 |
+
logger.error(
|
| 695 |
+
f"Failed in parsing the output into JSON: {e}"
|
| 696 |
+
)
|
| 697 |
+
output_messages[0].parsed = None
|
| 698 |
+
except ValidationError as e:
|
| 699 |
+
logger.warning(
|
| 700 |
+
"Successfully generating JSON response, "
|
| 701 |
+
"but failed in parsing it into Pydantic object :"
|
| 702 |
+
f"{e}, return the JSON response in parsed field"
|
| 703 |
+
)
|
| 704 |
+
output_messages[0].parsed = json_content
|
| 705 |
+
|
| 706 |
+
# Finalize on standard response in multi-step mode
|
| 707 |
+
if self._is_standard_response(response):
|
| 708 |
+
break
|
| 709 |
+
|
| 710 |
+
# Handle tool requests
|
| 711 |
+
tool_request = self._extract_tool_call(response)
|
| 712 |
+
if isinstance(response, ChatCompletion) and tool_request:
|
| 713 |
+
response.choices[0].message.tool_calls = [tool_request]
|
| 714 |
+
tool_call_records.append(
|
| 715 |
+
self._step_tool_call_and_update(response)
|
| 716 |
+
)
|
| 717 |
+
|
| 718 |
+
if tool_request.function.name in self.external_tool_names:
|
| 719 |
+
external_tool_request = tool_request
|
| 720 |
+
info = self._step_get_info(
|
| 721 |
+
output_messages,
|
| 722 |
+
finish_reasons,
|
| 723 |
+
usage_dict,
|
| 724 |
+
response_id,
|
| 725 |
+
tool_call_records,
|
| 726 |
+
num_tokens,
|
| 727 |
+
tool_request,
|
| 728 |
+
)
|
| 729 |
+
self._log_final_output(output_messages)
|
| 730 |
+
self.model_backend.model_config_dict = (
|
| 731 |
+
self.original_model_dict
|
| 732 |
+
)
|
| 733 |
+
return ChatAgentResponse(
|
| 734 |
+
msgs=output_messages,
|
| 735 |
+
terminated=self.terminated,
|
| 736 |
+
info=info,
|
| 737 |
+
)
|
| 738 |
+
|
| 739 |
+
# Single-step mode ends after one iteration
|
| 740 |
+
if single_step:
|
| 741 |
+
break
|
| 742 |
+
|
| 743 |
+
# Optional structured output via function calling
|
| 744 |
+
if (
|
| 745 |
+
response_format
|
| 746 |
+
and not inject_prompt_for_structured_output
|
| 747 |
+
and self.model_type
|
| 748 |
+
not in {
|
| 749 |
+
"gpt-4o",
|
| 750 |
+
"gpt-4o-mini",
|
| 751 |
+
}
|
| 752 |
+
):
|
| 753 |
+
(
|
| 754 |
+
output_messages,
|
| 755 |
+
finish_reasons,
|
| 756 |
+
usage_dict,
|
| 757 |
+
response_id,
|
| 758 |
+
tool_call,
|
| 759 |
+
num_tokens,
|
| 760 |
+
) = self._structure_output_with_function(response_format)
|
| 761 |
+
tool_call_records.append(tool_call)
|
| 762 |
+
|
| 763 |
+
# Final info and response
|
| 764 |
+
info = self._step_get_info(
|
| 765 |
+
output_messages,
|
| 766 |
+
finish_reasons,
|
| 767 |
+
usage_dict,
|
| 768 |
+
response_id,
|
| 769 |
+
tool_call_records,
|
| 770 |
+
num_tokens,
|
| 771 |
+
external_tool_request,
|
| 772 |
+
)
|
| 773 |
+
self._log_final_output(output_messages)
|
| 774 |
+
self.model_backend.model_config_dict = self.original_model_dict
|
| 775 |
+
return ChatAgentResponse(
|
| 776 |
+
msgs=output_messages, terminated=self.terminated, info=info
|
| 777 |
+
)
|
| 778 |
+
|
| 779 |
+
def _extract_tool_call(
|
| 780 |
+
self, response: Any
|
| 781 |
+
) -> Optional[ChatCompletionMessageToolCall]:
|
| 782 |
+
r"""Extract the tool call from the model response, if present.
|
| 783 |
+
|
| 784 |
+
Args:
|
| 785 |
+
response (Any): The model's response object.
|
| 786 |
+
|
| 787 |
+
Returns:
|
| 788 |
+
Optional[ChatCompletionMessageToolCall]: The parsed tool call if
|
| 789 |
+
present, otherwise None.
|
| 790 |
+
"""
|
| 791 |
+
# Check if the response contains tool calls
|
| 792 |
+
if (
|
| 793 |
+
self.is_tools_added()
|
| 794 |
+
and not self.model_type.support_native_tool_calling
|
| 795 |
+
and "</function>" in response.choices[0].message.content
|
| 796 |
+
):
|
| 797 |
+
parsed_content = self._parse_tool_response(
|
| 798 |
+
response.choices[0].message.content
|
| 799 |
+
)
|
| 800 |
+
if parsed_content:
|
| 801 |
+
return ChatCompletionMessageToolCall(
|
| 802 |
+
id=str(uuid.uuid4()),
|
| 803 |
+
function=Function(
|
| 804 |
+
arguments=str(parsed_content["arguments"]).replace(
|
| 805 |
+
"'", '"'
|
| 806 |
+
),
|
| 807 |
+
name=str(parsed_content["function"]),
|
| 808 |
+
),
|
| 809 |
+
type="function",
|
| 810 |
+
)
|
| 811 |
+
elif (
|
| 812 |
+
self.is_tools_added()
|
| 813 |
+
and self.model_type.support_native_tool_calling
|
| 814 |
+
and response.choices[0].message.tool_calls
|
| 815 |
+
):
|
| 816 |
+
return response.choices[0].message.tool_calls[0]
|
| 817 |
+
|
| 818 |
+
# No tool call found
|
| 819 |
+
return None
|
| 820 |
+
|
| 821 |
+
def _is_standard_response(self, response: Any) -> bool:
|
| 822 |
+
r"""Determine if the provided response is a standard reply without
|
| 823 |
+
tool calls.
|
| 824 |
+
|
| 825 |
+
Args:
|
| 826 |
+
response (Any): The response object to evaluate.
|
| 827 |
+
|
| 828 |
+
Returns:
|
| 829 |
+
bool: `True` if the response is a standard reply, `False`
|
| 830 |
+
otherwise.
|
| 831 |
+
"""
|
| 832 |
+
if not self.is_tools_added():
|
| 833 |
+
return True
|
| 834 |
+
|
| 835 |
+
if not isinstance(response, ChatCompletion):
|
| 836 |
+
return True
|
| 837 |
+
|
| 838 |
+
if self.model_type.support_native_tool_calling:
|
| 839 |
+
return not response.choices[0].message.tool_calls
|
| 840 |
+
|
| 841 |
+
return "</function>" not in str(
|
| 842 |
+
response.choices[0].message.content or ""
|
| 843 |
+
)
|
| 844 |
+
|
| 845 |
+
def _log_final_output(self, output_messages: List[BaseMessage]) -> None:
|
| 846 |
+
r"""Log final messages or warnings about multiple responses."""
|
| 847 |
+
if len(output_messages) == 1:
|
| 848 |
+
self.record_message(output_messages[0])
|
| 849 |
+
else:
|
| 850 |
+
logger.warning(
|
| 851 |
+
"Multiple messages returned in `step()`. Record "
|
| 852 |
+
"selected message manually using `record_message()`."
|
| 853 |
+
)
|
| 854 |
+
|
| 855 |
+
async def step_async(
|
| 856 |
+
self,
|
| 857 |
+
input_message: Union[BaseMessage, str],
|
| 858 |
+
response_format: Optional[Type[BaseModel]] = None,
|
| 859 |
+
) -> ChatAgentResponse:
|
| 860 |
+
r"""Performs a single step in the chat session by generating a response
|
| 861 |
+
to the input message. This agent step can call async function calls.
|
| 862 |
+
|
| 863 |
+
Args:
|
| 864 |
+
input_message (Union[BaseMessage, str]): The input message to the
|
| 865 |
+
agent. For BaseMessage input, its `role` field (the backend
|
| 866 |
+
role) may be either `user` or `assistant`, but it
|
| 867 |
+
will be set to `user` regardless, since any incoming message
|
| 868 |
+
is external to this agent. For str input, the `role_name`
|
| 869 |
+
would be `User`.
|
| 870 |
+
response_format (Optional[Type[BaseModel]], optional): A pydantic
|
| 871 |
+
model class that includes value types and field descriptions
|
| 872 |
+
used to generate a structured response by LLM. This schema
|
| 873 |
+
helps in defining the expected output format. (default:
|
| 874 |
+
:obj:`None`)
|
| 875 |
+
|
| 876 |
+
Returns:
|
| 877 |
+
ChatAgentResponse: A struct containing the output messages,
|
| 878 |
+
a boolean indicating whether the chat session has terminated,
|
| 879 |
+
and information about the chat session.
|
| 880 |
+
"""
|
| 881 |
+
if isinstance(input_message, str):
|
| 882 |
+
input_message = BaseMessage.make_user_message(
|
| 883 |
+
role_name='User', content=input_message
|
| 884 |
+
)
|
| 885 |
+
|
| 886 |
+
self.update_memory(input_message, OpenAIBackendRole.USER)
|
| 887 |
+
|
| 888 |
+
tool_call_records: List[FunctionCallingRecord] = []
|
| 889 |
+
while True:
|
| 890 |
+
try:
|
| 891 |
+
openai_messages, num_tokens = self.memory.get_context()
|
| 892 |
+
except RuntimeError as e:
|
| 893 |
+
return self._step_token_exceed(
|
| 894 |
+
e.args[1], tool_call_records, "max_tokens_exceeded"
|
| 895 |
+
)
|
| 896 |
+
|
| 897 |
+
(
|
| 898 |
+
response,
|
| 899 |
+
output_messages,
|
| 900 |
+
finish_reasons,
|
| 901 |
+
usage_dict,
|
| 902 |
+
response_id,
|
| 903 |
+
) = self._step_model_response(openai_messages, num_tokens)
|
| 904 |
+
|
| 905 |
+
if (
|
| 906 |
+
not self.is_tools_added()
|
| 907 |
+
or not isinstance(response, ChatCompletion)
|
| 908 |
+
or not response.choices[0].message.tool_calls
|
| 909 |
+
):
|
| 910 |
+
break
|
| 911 |
+
|
| 912 |
+
# Check for external tool call
|
| 913 |
+
external_tool_request = response.choices[0].message.tool_calls[0]
|
| 914 |
+
if external_tool_request.function.name in self.external_tool_names:
|
| 915 |
+
# if model calls an external tool, directly return the request
|
| 916 |
+
info = self._step_get_info(
|
| 917 |
+
output_messages,
|
| 918 |
+
finish_reasons,
|
| 919 |
+
usage_dict,
|
| 920 |
+
response_id,
|
| 921 |
+
tool_call_records,
|
| 922 |
+
num_tokens,
|
| 923 |
+
external_tool_request,
|
| 924 |
+
)
|
| 925 |
+
return ChatAgentResponse(
|
| 926 |
+
msgs=output_messages, terminated=self.terminated, info=info
|
| 927 |
+
)
|
| 928 |
+
|
| 929 |
+
# Normal function calling
|
| 930 |
+
tool_call_records.append(
|
| 931 |
+
await self._step_tool_call_and_update_async(response)
|
| 932 |
+
)
|
| 933 |
+
|
| 934 |
+
if (
|
| 935 |
+
response_format is not None
|
| 936 |
+
and self.model_type.support_native_tool_calling
|
| 937 |
+
):
|
| 938 |
+
(
|
| 939 |
+
output_messages,
|
| 940 |
+
finish_reasons,
|
| 941 |
+
usage_dict,
|
| 942 |
+
response_id,
|
| 943 |
+
tool_call_record,
|
| 944 |
+
num_tokens,
|
| 945 |
+
) = self._structure_output_with_function(response_format)
|
| 946 |
+
tool_call_records.append(tool_call_record)
|
| 947 |
+
|
| 948 |
+
info = self._step_get_info(
|
| 949 |
+
output_messages,
|
| 950 |
+
finish_reasons,
|
| 951 |
+
usage_dict,
|
| 952 |
+
response_id,
|
| 953 |
+
tool_call_records,
|
| 954 |
+
num_tokens,
|
| 955 |
+
)
|
| 956 |
+
|
| 957 |
+
if len(output_messages) == 1:
|
| 958 |
+
# Auto record if the output result is a single message
|
| 959 |
+
self.record_message(output_messages[0])
|
| 960 |
+
else:
|
| 961 |
+
logger.warning(
|
| 962 |
+
"Multiple messages returned in `step()`, message won't be "
|
| 963 |
+
"recorded automatically. Please call `record_message()` to "
|
| 964 |
+
"record the selected message manually."
|
| 965 |
+
)
|
| 966 |
+
|
| 967 |
+
return ChatAgentResponse(
|
| 968 |
+
msgs=output_messages, terminated=self.terminated, info=info
|
| 969 |
+
)
|
| 970 |
+
|
| 971 |
+
def _step_tool_call_and_update(
|
| 972 |
+
self, response: ChatCompletion
|
| 973 |
+
) -> FunctionCallingRecord:
|
| 974 |
+
r"""Processes a function call within the chat completion response,
|
| 975 |
+
records the function call in the provided list of tool calls and
|
| 976 |
+
updates the memory of the current agent.
|
| 977 |
+
|
| 978 |
+
Args:
|
| 979 |
+
response (ChatCompletion): The response object from the chat
|
| 980 |
+
completion.
|
| 981 |
+
|
| 982 |
+
Returns:
|
| 983 |
+
FunctionCallingRecord: The record of calling the function.
|
| 984 |
+
"""
|
| 985 |
+
|
| 986 |
+
# Perform function calling
|
| 987 |
+
func_assistant_msg, func_result_msg, tool_call_record = (
|
| 988 |
+
self._step_tool_call(response)
|
| 989 |
+
)
|
| 990 |
+
|
| 991 |
+
# Update the messages
|
| 992 |
+
self.update_memory(func_assistant_msg, OpenAIBackendRole.ASSISTANT)
|
| 993 |
+
self.update_memory(func_result_msg, OpenAIBackendRole.FUNCTION)
|
| 994 |
+
|
| 995 |
+
return tool_call_record
|
| 996 |
+
|
| 997 |
+
async def _step_tool_call_and_update_async(
|
| 998 |
+
self, response: ChatCompletion
|
| 999 |
+
) -> FunctionCallingRecord:
|
| 1000 |
+
(
|
| 1001 |
+
func_assistant_msg,
|
| 1002 |
+
func_result_msg,
|
| 1003 |
+
func_record,
|
| 1004 |
+
) = await self.step_tool_call_async(response)
|
| 1005 |
+
|
| 1006 |
+
self.update_memory(func_assistant_msg, OpenAIBackendRole.ASSISTANT)
|
| 1007 |
+
self.update_memory(func_result_msg, OpenAIBackendRole.FUNCTION)
|
| 1008 |
+
|
| 1009 |
+
return func_record
|
| 1010 |
+
|
| 1011 |
+
def _structure_output_with_function(
|
| 1012 |
+
self, response_format: Type[BaseModel]
|
| 1013 |
+
) -> Tuple[
|
| 1014 |
+
List[BaseMessage],
|
| 1015 |
+
List[str],
|
| 1016 |
+
Dict[str, int],
|
| 1017 |
+
str,
|
| 1018 |
+
FunctionCallingRecord,
|
| 1019 |
+
int,
|
| 1020 |
+
]:
|
| 1021 |
+
r"""Internal function of structuring the output of the agent based on
|
| 1022 |
+
the given output schema.
|
| 1023 |
+
|
| 1024 |
+
Args:
|
| 1025 |
+
response_format (Type[BaseModel]): The output schema to use for
|
| 1026 |
+
structuring the output.
|
| 1027 |
+
|
| 1028 |
+
Returns:
|
| 1029 |
+
Tuple[List[BaseMessage], List[str], Dict[str, int], str,
|
| 1030 |
+
FunctionCallingRecord, int]:
|
| 1031 |
+
A tuple containing the output messages, finish reasons, usage
|
| 1032 |
+
dictionary, response ID, function calling record, and number of
|
| 1033 |
+
tokens.
|
| 1034 |
+
"""
|
| 1035 |
+
from camel.toolkits import FunctionTool
|
| 1036 |
+
|
| 1037 |
+
schema_json = get_pydantic_object_schema(response_format)
|
| 1038 |
+
func_str = json_to_function_code(schema_json)
|
| 1039 |
+
func_callable = func_string_to_callable(func_str)
|
| 1040 |
+
func = FunctionTool(func_callable)
|
| 1041 |
+
|
| 1042 |
+
original_model_dict = self.model_backend.model_config_dict
|
| 1043 |
+
|
| 1044 |
+
# Replace the original tools with the structuring function
|
| 1045 |
+
self.tool_dict = {func.get_function_name(): func}
|
| 1046 |
+
self.model_backend.model_config_dict = original_model_dict.copy()
|
| 1047 |
+
self.model_backend.model_config_dict["tools"] = [
|
| 1048 |
+
func.get_openai_tool_schema()
|
| 1049 |
+
]
|
| 1050 |
+
self.model_backend.model_config_dict["tool_choice"] = "required"
|
| 1051 |
+
|
| 1052 |
+
openai_messages, num_tokens = self.memory.get_context()
|
| 1053 |
+
(
|
| 1054 |
+
response,
|
| 1055 |
+
output_messages,
|
| 1056 |
+
finish_reasons,
|
| 1057 |
+
usage_dict,
|
| 1058 |
+
response_id,
|
| 1059 |
+
) = self._step_model_response(openai_messages, num_tokens)
|
| 1060 |
+
|
| 1061 |
+
if isinstance(response, ChatCompletion):
|
| 1062 |
+
tool_call_record = self._step_tool_call_and_update(response)
|
| 1063 |
+
else:
|
| 1064 |
+
raise ValueError(
|
| 1065 |
+
"Structured output is not supported for stream responses."
|
| 1066 |
+
)
|
| 1067 |
+
|
| 1068 |
+
for base_message_item in output_messages:
|
| 1069 |
+
base_message_item.content = json.dumps(tool_call_record.result)
|
| 1070 |
+
|
| 1071 |
+
# Recover the original tools
|
| 1072 |
+
self.model_backend.model_config_dict = original_model_dict
|
| 1073 |
+
|
| 1074 |
+
return (
|
| 1075 |
+
output_messages,
|
| 1076 |
+
finish_reasons,
|
| 1077 |
+
usage_dict,
|
| 1078 |
+
response_id,
|
| 1079 |
+
tool_call_record,
|
| 1080 |
+
num_tokens,
|
| 1081 |
+
)
|
| 1082 |
+
|
| 1083 |
+
def _step_model_response(
|
| 1084 |
+
self,
|
| 1085 |
+
openai_messages: List[OpenAIMessage],
|
| 1086 |
+
num_tokens: int,
|
| 1087 |
+
) -> tuple[
|
| 1088 |
+
Union[ChatCompletion, Stream],
|
| 1089 |
+
List[BaseMessage],
|
| 1090 |
+
List[str],
|
| 1091 |
+
Dict[str, int],
|
| 1092 |
+
str,
|
| 1093 |
+
]:
|
| 1094 |
+
r"""Internal function for agent step model response."""
|
| 1095 |
+
|
| 1096 |
+
response = None
|
| 1097 |
+
# Obtain the model's response
|
| 1098 |
+
for _ in range(len(self.model_backend.models)):
|
| 1099 |
+
try:
|
| 1100 |
+
response = self.model_backend.run(openai_messages)
|
| 1101 |
+
break
|
| 1102 |
+
except Exception as exc:
|
| 1103 |
+
logger.error(
|
| 1104 |
+
f"An error occurred while running model "
|
| 1105 |
+
f"{self.model_backend.model_type}, "
|
| 1106 |
+
f"index: {self.model_backend.current_model_index}",
|
| 1107 |
+
exc_info=exc,
|
| 1108 |
+
)
|
| 1109 |
+
continue
|
| 1110 |
+
if not response:
|
| 1111 |
+
raise ModelProcessingError(
|
| 1112 |
+
"Unable to process messages: none of the provided models "
|
| 1113 |
+
"run succesfully."
|
| 1114 |
+
)
|
| 1115 |
+
|
| 1116 |
+
logger.info(
|
| 1117 |
+
f"Model {self.model_backend.model_type}, "
|
| 1118 |
+
f"index {self.model_backend.current_model_index}, "
|
| 1119 |
+
f"processed these messages: {openai_messages}"
|
| 1120 |
+
)
|
| 1121 |
+
|
| 1122 |
+
if isinstance(response, ChatCompletion):
|
| 1123 |
+
output_messages, finish_reasons, usage_dict, response_id = (
|
| 1124 |
+
self.handle_batch_response(response)
|
| 1125 |
+
)
|
| 1126 |
+
else:
|
| 1127 |
+
output_messages, finish_reasons, usage_dict, response_id = (
|
| 1128 |
+
self.handle_stream_response(response, num_tokens)
|
| 1129 |
+
)
|
| 1130 |
+
return (
|
| 1131 |
+
response,
|
| 1132 |
+
output_messages,
|
| 1133 |
+
finish_reasons,
|
| 1134 |
+
usage_dict,
|
| 1135 |
+
response_id,
|
| 1136 |
+
)
|
| 1137 |
+
|
| 1138 |
+
def _step_get_info(
|
| 1139 |
+
self,
|
| 1140 |
+
output_messages: List[BaseMessage],
|
| 1141 |
+
finish_reasons: List[str],
|
| 1142 |
+
usage_dict: Dict[str, int],
|
| 1143 |
+
response_id: str,
|
| 1144 |
+
tool_calls: List[FunctionCallingRecord],
|
| 1145 |
+
num_tokens: int,
|
| 1146 |
+
external_tool_request: Optional[ChatCompletionMessageToolCall] = None,
|
| 1147 |
+
) -> Dict[str, Any]:
|
| 1148 |
+
r"""Process the output of a chat step and gather information about the
|
| 1149 |
+
step.
|
| 1150 |
+
|
| 1151 |
+
This method checks for termination conditions, updates the agent's
|
| 1152 |
+
state, and collects information about the chat step, including tool
|
| 1153 |
+
calls and termination reasons.
|
| 1154 |
+
|
| 1155 |
+
Args:
|
| 1156 |
+
output_messages (List[BaseMessage]): The messages generated in
|
| 1157 |
+
this step.
|
| 1158 |
+
finish_reasons (List[str]): The reasons for finishing the
|
| 1159 |
+
generation for each message.
|
| 1160 |
+
usage_dict (Dict[str, int]): Dictionary containing token usage
|
| 1161 |
+
information.
|
| 1162 |
+
response_id (str): The ID of the response from the model.
|
| 1163 |
+
tool_calls (List[FunctionCallingRecord]): Records of function calls
|
| 1164 |
+
made during this step.
|
| 1165 |
+
num_tokens (int): The number of tokens used in this step.
|
| 1166 |
+
external_tool_request (Optional[ChatCompletionMessageToolCall]):
|
| 1167 |
+
Any external tool request made during this step.
|
| 1168 |
+
(default: :obj:`None`)
|
| 1169 |
+
|
| 1170 |
+
Returns:
|
| 1171 |
+
Dict[str, Any]: A dictionary containing information about the chat
|
| 1172 |
+
step, including termination status, reasons, and tool call
|
| 1173 |
+
information.
|
| 1174 |
+
|
| 1175 |
+
Note:
|
| 1176 |
+
This method iterates over all response terminators and checks if
|
| 1177 |
+
any of them signal termination. If a terminator signals
|
| 1178 |
+
termination, the agent's state is updated accordingly, and the
|
| 1179 |
+
termination reason is recorded.
|
| 1180 |
+
"""
|
| 1181 |
+
termination = [
|
| 1182 |
+
terminator.is_terminated(output_messages)
|
| 1183 |
+
for terminator in self.response_terminators
|
| 1184 |
+
]
|
| 1185 |
+
# Terminate the agent if any of the terminators signals termination
|
| 1186 |
+
self.terminated, termination_reason = next(
|
| 1187 |
+
(
|
| 1188 |
+
(terminated, termination_reason)
|
| 1189 |
+
for terminated, termination_reason in termination
|
| 1190 |
+
if terminated
|
| 1191 |
+
),
|
| 1192 |
+
(False, None),
|
| 1193 |
+
)
|
| 1194 |
+
# For now only retain the first termination reason
|
| 1195 |
+
if self.terminated and termination_reason is not None:
|
| 1196 |
+
finish_reasons = [termination_reason] * len(finish_reasons)
|
| 1197 |
+
|
| 1198 |
+
info = self.get_info(
|
| 1199 |
+
response_id,
|
| 1200 |
+
usage_dict,
|
| 1201 |
+
finish_reasons,
|
| 1202 |
+
num_tokens,
|
| 1203 |
+
tool_calls,
|
| 1204 |
+
external_tool_request,
|
| 1205 |
+
)
|
| 1206 |
+
return info
|
| 1207 |
+
|
| 1208 |
+
def handle_batch_response(
|
| 1209 |
+
self, response: ChatCompletion
|
| 1210 |
+
) -> Tuple[List[BaseMessage], List[str], Dict[str, int], str]:
|
| 1211 |
+
r"""Process a batch response from the model and extract the necessary
|
| 1212 |
+
information.
|
| 1213 |
+
|
| 1214 |
+
Args:
|
| 1215 |
+
response (ChatCompletion): Model response.
|
| 1216 |
+
|
| 1217 |
+
Returns:
|
| 1218 |
+
tuple: A tuple of list of output `ChatMessage`, list of
|
| 1219 |
+
finish reasons, usage dictionary, and response id.
|
| 1220 |
+
"""
|
| 1221 |
+
output_messages: List[BaseMessage] = []
|
| 1222 |
+
for choice in response.choices:
|
| 1223 |
+
chat_message = BaseMessage(
|
| 1224 |
+
role_name=self.role_name,
|
| 1225 |
+
role_type=self.role_type,
|
| 1226 |
+
meta_dict=dict(),
|
| 1227 |
+
content=choice.message.content or "",
|
| 1228 |
+
parsed=getattr(choice.message, 'parsed', None),
|
| 1229 |
+
)
|
| 1230 |
+
# Process log probabilities and append to the message meta information
|
| 1231 |
+
if choice.logprobs is not None:
|
| 1232 |
+
tokens_logprobs = choice.logprobs.content
|
| 1233 |
+
|
| 1234 |
+
if tokens_logprobs is not None:
|
| 1235 |
+
# Extract and structure logprob information
|
| 1236 |
+
logprobs_info = [
|
| 1237 |
+
{
|
| 1238 |
+
"token": token_logprob.token,
|
| 1239 |
+
"logprob": token_logprob.logprob,
|
| 1240 |
+
"top_logprobs": [
|
| 1241 |
+
(top_logprob.token, top_logprob.logprob)
|
| 1242 |
+
for top_logprob in token_logprob.top_logprobs
|
| 1243 |
+
],
|
| 1244 |
+
}
|
| 1245 |
+
for token_logprob in tokens_logprobs
|
| 1246 |
+
]
|
| 1247 |
+
# Ensure meta_dict exists before adding logprobs info
|
| 1248 |
+
if chat_message.meta_dict is None:
|
| 1249 |
+
chat_message.meta_dict = {}
|
| 1250 |
+
chat_message.meta_dict["logprobs_info"] = logprobs_info
|
| 1251 |
+
# Append the processed chat message to output
|
| 1252 |
+
output_messages.append(chat_message)
|
| 1253 |
+
|
| 1254 |
+
finish_reasons = [
|
| 1255 |
+
str(choice.finish_reason) for choice in response.choices
|
| 1256 |
+
]
|
| 1257 |
+
usage = (
|
| 1258 |
+
self._safe_model_dump(response.usage)
|
| 1259 |
+
if response.usage is not None
|
| 1260 |
+
else {}
|
| 1261 |
+
)
|
| 1262 |
+
return (
|
| 1263 |
+
output_messages,
|
| 1264 |
+
finish_reasons,
|
| 1265 |
+
usage,
|
| 1266 |
+
response.id,
|
| 1267 |
+
)
|
| 1268 |
+
|
| 1269 |
+
def _safe_model_dump(self, obj) -> dict:
|
| 1270 |
+
r"""Safely dump a Pydantic model to a dictionary.
|
| 1271 |
+
|
| 1272 |
+
This method attempts to use the `model_dump` method if available,
|
| 1273 |
+
otherwise it falls back to the `dict` method.
|
| 1274 |
+
|
| 1275 |
+
Args:
|
| 1276 |
+
obj: The Pydantic model instance to be dumped.
|
| 1277 |
+
|
| 1278 |
+
Returns:
|
| 1279 |
+
dict: A dictionary representation of the Pydantic model.
|
| 1280 |
+
"""
|
| 1281 |
+
# Check if the `model_dump` method exists (Pydantic v2)
|
| 1282 |
+
if hasattr(obj, 'model_dump'):
|
| 1283 |
+
return obj.model_dump()
|
| 1284 |
+
# Fallback to `dict()` method (Pydantic v1)
|
| 1285 |
+
elif hasattr(obj, 'dict'):
|
| 1286 |
+
return obj.dict()
|
| 1287 |
+
else:
|
| 1288 |
+
raise TypeError("The object is not a Pydantic model")
|
| 1289 |
+
|
| 1290 |
+
def handle_stream_response(
|
| 1291 |
+
self,
|
| 1292 |
+
response: Stream[ChatCompletionChunk],
|
| 1293 |
+
prompt_tokens: int,
|
| 1294 |
+
) -> Tuple[List[BaseMessage], List[str], Dict[str, int], str]:
|
| 1295 |
+
r"""Process a stream response from the model and extract the necessary
|
| 1296 |
+
information.
|
| 1297 |
+
|
| 1298 |
+
Args:
|
| 1299 |
+
response (Stream[ChatCompletionChunk]): Model response.
|
| 1300 |
+
prompt_tokens (int): Number of input prompt tokens.
|
| 1301 |
+
|
| 1302 |
+
Returns:
|
| 1303 |
+
tuple: A tuple of list of output `ChatMessage`, list of
|
| 1304 |
+
finish reasons, usage dictionary, and response id.
|
| 1305 |
+
"""
|
| 1306 |
+
content_dict: defaultdict = defaultdict(lambda: "")
|
| 1307 |
+
finish_reasons_dict: defaultdict = defaultdict(lambda: "")
|
| 1308 |
+
output_messages: List[BaseMessage] = []
|
| 1309 |
+
response_id: str = ""
|
| 1310 |
+
# All choices in one response share one role
|
| 1311 |
+
for chunk in response:
|
| 1312 |
+
response_id = chunk.id
|
| 1313 |
+
for choice in chunk.choices:
|
| 1314 |
+
index = choice.index
|
| 1315 |
+
delta = choice.delta
|
| 1316 |
+
if delta.content is not None:
|
| 1317 |
+
# Accumulate content while the response has not been stopped
|
| 1318 |
+
# Note that only the first chunk carries the "role" field
|
| 1319 |
+
content_dict[index] += delta.content
|
| 1320 |
+
if choice.finish_reason:
|
| 1321 |
+
finish_reasons_dict[index] = choice.finish_reason
|
| 1322 |
+
chat_message = BaseMessage(
|
| 1323 |
+
role_name=self.role_name,
|
| 1324 |
+
role_type=self.role_type,
|
| 1325 |
+
meta_dict=dict(),
|
| 1326 |
+
content=content_dict[index],
|
| 1327 |
+
)
|
| 1328 |
+
output_messages.append(chat_message)
|
| 1329 |
+
finish_reasons = [
|
| 1330 |
+
finish_reasons_dict[i] for i in range(len(finish_reasons_dict))
|
| 1331 |
+
]
|
| 1332 |
+
usage_dict = self.get_usage_dict(output_messages, prompt_tokens)
|
| 1333 |
+
return output_messages, finish_reasons, usage_dict, response_id
|
| 1334 |
+
|
| 1335 |
+
def _step_token_exceed(
|
| 1336 |
+
self,
|
| 1337 |
+
num_tokens: int,
|
| 1338 |
+
tool_calls: List[FunctionCallingRecord],
|
| 1339 |
+
termination_reason: str,
|
| 1340 |
+
) -> ChatAgentResponse:
|
| 1341 |
+
r"""Return trivial response containing number of tokens and information
|
| 1342 |
+
information about the called functions when the token limit is exceeded.
|
| 1343 |
+
|
| 1344 |
+
Args:
|
| 1345 |
+
num_tokens (int): Number of tokens in the messages.
|
| 1346 |
+
tool_calls (List[FunctionCallingRecord]): List of information
|
| 1347 |
+
objects of functions called in the current step.
|
| 1348 |
+
termination_reason (str): String of termination reason.
|
| 1349 |
+
|
| 1350 |
+
Returns:
|
| 1351 |
+
ChatAgentResponse: The struct containing trivial outputs and
|
| 1352 |
+
information about token number and called functions.
|
| 1353 |
+
"""
|
| 1354 |
+
self.terminated = True
|
| 1355 |
+
output_messages: List[BaseMessage] = []
|
| 1356 |
+
|
| 1357 |
+
info = self.get_info(
|
| 1358 |
+
None,
|
| 1359 |
+
None,
|
| 1360 |
+
[termination_reason],
|
| 1361 |
+
num_tokens,
|
| 1362 |
+
tool_calls,
|
| 1363 |
+
)
|
| 1364 |
+
|
| 1365 |
+
return ChatAgentResponse(
|
| 1366 |
+
msgs=output_messages,
|
| 1367 |
+
terminated=self.terminated,
|
| 1368 |
+
info=info,
|
| 1369 |
+
)
|
| 1370 |
+
|
| 1371 |
+
def _step_tool_call(
|
| 1372 |
+
self,
|
| 1373 |
+
response: ChatCompletion,
|
| 1374 |
+
) -> Tuple[
|
| 1375 |
+
FunctionCallingMessage, FunctionCallingMessage, FunctionCallingRecord
|
| 1376 |
+
]:
|
| 1377 |
+
r"""Execute the function with arguments following the model's response.
|
| 1378 |
+
|
| 1379 |
+
Args:
|
| 1380 |
+
response (ChatCompletion): The response obtained by calling the
|
| 1381 |
+
model.
|
| 1382 |
+
|
| 1383 |
+
Returns:
|
| 1384 |
+
tuple: A tuple consisting of two obj:`FunctionCallingMessage`,
|
| 1385 |
+
one about the arguments and the other about the execution
|
| 1386 |
+
result, and a struct for logging information about this
|
| 1387 |
+
function call.
|
| 1388 |
+
"""
|
| 1389 |
+
choice = response.choices[0]
|
| 1390 |
+
if choice.message.tool_calls is None:
|
| 1391 |
+
raise RuntimeError("Tool call is None")
|
| 1392 |
+
func_name = choice.message.tool_calls[0].function.name
|
| 1393 |
+
|
| 1394 |
+
arguments_str = choice.message.tool_calls[0].function.arguments
|
| 1395 |
+
args = self._safe_json_loads(arguments_str)
|
| 1396 |
+
|
| 1397 |
+
tool = self.tool_dict[func_name]
|
| 1398 |
+
result = tool(**args)
|
| 1399 |
+
tool_call_id = choice.message.tool_calls[0].id
|
| 1400 |
+
|
| 1401 |
+
assist_msg = FunctionCallingMessage(
|
| 1402 |
+
role_name=self.role_name,
|
| 1403 |
+
role_type=self.role_type,
|
| 1404 |
+
meta_dict=None,
|
| 1405 |
+
content="",
|
| 1406 |
+
func_name=func_name,
|
| 1407 |
+
args=args,
|
| 1408 |
+
tool_call_id=tool_call_id,
|
| 1409 |
+
)
|
| 1410 |
+
func_msg = FunctionCallingMessage(
|
| 1411 |
+
role_name=self.role_name,
|
| 1412 |
+
role_type=self.role_type,
|
| 1413 |
+
meta_dict=None,
|
| 1414 |
+
content="",
|
| 1415 |
+
func_name=func_name,
|
| 1416 |
+
result=result,
|
| 1417 |
+
tool_call_id=tool_call_id,
|
| 1418 |
+
)
|
| 1419 |
+
|
| 1420 |
+
# Record information about this function call
|
| 1421 |
+
func_record = FunctionCallingRecord(
|
| 1422 |
+
func_name=func_name,
|
| 1423 |
+
args=args,
|
| 1424 |
+
result=result,
|
| 1425 |
+
tool_call_id=tool_call_id,
|
| 1426 |
+
)
|
| 1427 |
+
return assist_msg, func_msg, func_record
|
| 1428 |
+
|
| 1429 |
+
def _safe_json_loads(self, arguments_str):
|
| 1430 |
+
# Naively replace Python literals with their JSON equivalents; this may also affect values containing these words
|
| 1431 |
+
arguments_str = arguments_str.replace("None", "null")
|
| 1432 |
+
arguments_str = arguments_str.replace("True", "true")
|
| 1433 |
+
arguments_str = arguments_str.replace("False", "false")
|
| 1434 |
+
|
| 1435 |
+
# Attempt to parse the corrected string
|
| 1436 |
+
try:
|
| 1437 |
+
return json.loads(arguments_str)
|
| 1438 |
+
except json.JSONDecodeError as e:
|
| 1439 |
+
raise ValueError(f"Invalid JSON format: {e}")
|
| 1440 |
+
|
| 1441 |
+
async def step_tool_call_async(
|
| 1442 |
+
self,
|
| 1443 |
+
response: ChatCompletion,
|
| 1444 |
+
) -> Tuple[
|
| 1445 |
+
FunctionCallingMessage, FunctionCallingMessage, FunctionCallingRecord
|
| 1446 |
+
]:
|
| 1447 |
+
r"""Execute the async function with arguments following the model's
|
| 1448 |
+
response.
|
| 1449 |
+
|
| 1450 |
+
Args:
|
| 1451 |
+
response (ChatCompletion): The response obtained by calling the
|
| 1452 |
+
model.
|
| 1453 |
+
|
| 1454 |
+
Returns:
|
| 1455 |
+
tuple: A tuple consisting of two obj:`FunctionCallingMessage`,
|
| 1456 |
+
one about the arguments and the other about the execution
|
| 1457 |
+
result, and a struct for logging information about this
|
| 1458 |
+
function call.
|
| 1459 |
+
"""
|
| 1460 |
+
# Note that when function calling is enabled, `n` is set to 1.
|
| 1461 |
+
choice = response.choices[0]
|
| 1462 |
+
if choice.message.tool_calls is None:
|
| 1463 |
+
raise RuntimeError("Tool call is None")
|
| 1464 |
+
func_name = choice.message.tool_calls[0].function.name
|
| 1465 |
+
|
| 1466 |
+
args = json.loads(choice.message.tool_calls[0].function.arguments)
|
| 1467 |
+
tool = self.tool_dict[func_name]
|
| 1468 |
+
result = await tool(**args)
|
| 1469 |
+
tool_call_id = choice.message.tool_calls[0].id
|
| 1470 |
+
|
| 1471 |
+
assist_msg = FunctionCallingMessage(
|
| 1472 |
+
role_name=self.role_name,
|
| 1473 |
+
role_type=self.role_type,
|
| 1474 |
+
meta_dict=None,
|
| 1475 |
+
content="",
|
| 1476 |
+
func_name=func_name,
|
| 1477 |
+
args=args,
|
| 1478 |
+
tool_call_id=tool_call_id,
|
| 1479 |
+
)
|
| 1480 |
+
func_msg = FunctionCallingMessage(
|
| 1481 |
+
role_name=self.role_name,
|
| 1482 |
+
role_type=self.role_type,
|
| 1483 |
+
meta_dict=None,
|
| 1484 |
+
content="",
|
| 1485 |
+
func_name=func_name,
|
| 1486 |
+
result=result,
|
| 1487 |
+
tool_call_id=tool_call_id,
|
| 1488 |
+
)
|
| 1489 |
+
|
| 1490 |
+
# Record information about this function call
|
| 1491 |
+
func_record = FunctionCallingRecord(
|
| 1492 |
+
func_name=func_name,
|
| 1493 |
+
args=args,
|
| 1494 |
+
result=result,
|
| 1495 |
+
tool_call_id=tool_call_id,
|
| 1496 |
+
)
|
| 1497 |
+
return assist_msg, func_msg, func_record
|
| 1498 |
+
|
| 1499 |
+
def get_usage_dict(
|
| 1500 |
+
self, output_messages: List[BaseMessage], prompt_tokens: int
|
| 1501 |
+
) -> Dict[str, int]:
|
| 1502 |
+
r"""Get usage dictionary when using the stream mode.
|
| 1503 |
+
|
| 1504 |
+
Args:
|
| 1505 |
+
output_messages (list): List of output messages.
|
| 1506 |
+
prompt_tokens (int): Number of input prompt tokens.
|
| 1507 |
+
|
| 1508 |
+
Returns:
|
| 1509 |
+
dict: Usage dictionary.
|
| 1510 |
+
"""
|
| 1511 |
+
encoding = get_model_encoding(self.model_type.value_for_tiktoken)
|
| 1512 |
+
completion_tokens = 0
|
| 1513 |
+
for message in output_messages:
|
| 1514 |
+
completion_tokens += len(encoding.encode(message.content))
|
| 1515 |
+
usage_dict = dict(
|
| 1516 |
+
completion_tokens=completion_tokens,
|
| 1517 |
+
prompt_tokens=prompt_tokens,
|
| 1518 |
+
total_tokens=completion_tokens + prompt_tokens,
|
| 1519 |
+
)
|
| 1520 |
+
return usage_dict
|
| 1521 |
+
|
| 1522 |
+
def add_model_scheduling_strategy(self, name: str, strategy_fn: Callable):
|
| 1523 |
+
r"""Add a scheduling strategy method provided by user to ModelManger.
|
| 1524 |
+
|
| 1525 |
+
Args:
|
| 1526 |
+
name (str): The name of the strategy.
|
| 1527 |
+
strategy_fn (Callable): The scheduling strategy function.
|
| 1528 |
+
"""
|
| 1529 |
+
self.model_backend.add_strategy(name, strategy_fn)
|
| 1530 |
+
|
| 1531 |
+
def __repr__(self) -> str:
|
| 1532 |
+
r"""Returns a string representation of the :obj:`ChatAgent`.
|
| 1533 |
+
|
| 1534 |
+
Returns:
|
| 1535 |
+
str: The string representation of the :obj:`ChatAgent`.
|
| 1536 |
+
"""
|
| 1537 |
+
return (
|
| 1538 |
+
f"ChatAgent({self.role_name}, {self.role_type}, {self.model_type})"
|
| 1539 |
+
)
|
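A minimal usage sketch for the `ChatAgent.step` API added above; the role names, prompt text, and the `Summary` schema are illustrative assumptions, and the import paths assume `ChatAgent` is re-exported from `camel.agents` as in this commit's package layout.

from pydantic import BaseModel

from camel.agents import ChatAgent
from camel.messages import BaseMessage


class Summary(BaseModel):
    # Hypothetical schema, used only to illustrate `response_format`.
    title: str
    key_points: list


sys_msg = BaseMessage.make_assistant_message(
    role_name="Assistant", content="You summarize research papers."
)
agent = ChatAgent(sys_msg)

# A plain string is converted to a user BaseMessage inside `step()`.
response = agent.step("Summarize the main idea of Paper2Poster.")
print(response.msgs[0].content)

# With a schema, models lacking native structured output take the
# prompt-injection path and populate `parsed` from the JSON reply.
structured = agent.step(
    "Summarize the main idea of Paper2Poster.", response_format=Summary
)
print(structured.msgs[0].parsed)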
Paper2Poster/camel/agents/critic_agent.py
ADDED
|
@@ -0,0 +1,202 @@
|
| 1 |
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
#
|
| 6 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 7 |
+
#
|
| 8 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 9 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 10 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 11 |
+
# See the License for the specific language governing permissions and
|
| 12 |
+
# limitations under the License.
|
| 13 |
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
| 14 |
+
import random
|
| 15 |
+
import warnings
|
| 16 |
+
from typing import Any, Dict, Optional, Sequence
|
| 17 |
+
|
| 18 |
+
from colorama import Fore
|
| 19 |
+
|
| 20 |
+
from camel.agents.chat_agent import ChatAgent
|
| 21 |
+
from camel.memories import AgentMemory
|
| 22 |
+
from camel.messages import BaseMessage
|
| 23 |
+
from camel.models import BaseModelBackend
|
| 24 |
+
from camel.responses import ChatAgentResponse
|
| 25 |
+
from camel.utils import get_first_int, print_text_animated
|
| 26 |
+
|
| 27 |
+
# AgentOps decorator setting
|
| 28 |
+
try:
|
| 29 |
+
import os
|
| 30 |
+
|
| 31 |
+
if os.getenv("AGENTOPS_API_KEY") is not None:
|
| 32 |
+
from agentops import track_agent
|
| 33 |
+
else:
|
| 34 |
+
raise ImportError
|
| 35 |
+
except (ImportError, AttributeError):
|
| 36 |
+
from camel.utils import track_agent
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@track_agent(name="CriticAgent")
|
| 40 |
+
class CriticAgent(ChatAgent):
|
| 41 |
+
r"""A class for the critic agent that assists in selecting an option.
|
| 42 |
+
|
| 43 |
+
Args:
|
| 44 |
+
system_message (BaseMessage): The system message for the critic
|
| 45 |
+
agent.
|
| 46 |
+
model (BaseModelBackend, optional): The model backend to use for
|
| 47 |
+
generating responses. (default: :obj:`OpenAIModel` with
|
| 48 |
+
`GPT_4O_MINI`)
|
| 49 |
+
message_window_size (int, optional): The maximum number of previous
|
| 50 |
+
messages to include in the context window. If `None`, no windowing
|
| 51 |
+
is performed. (default: :obj:`6`)
|
| 52 |
+
retry_attempts (int, optional): The number of retry attempts if the
|
| 53 |
+
critic fails to return a valid option. (default: :obj:`2`)
|
| 54 |
+
verbose (bool, optional): Whether to print the critic's messages.
|
| 55 |
+
logger_color (Any): The color of the menu options displayed to the
|
| 56 |
+
user. (default: :obj:`Fore.MAGENTA`)
|
| 57 |
+
"""
|
| 58 |
+
|
| 59 |
+
def __init__(
|
| 60 |
+
self,
|
| 61 |
+
system_message: BaseMessage,
|
| 62 |
+
model: Optional[BaseModelBackend] = None,
|
| 63 |
+
memory: Optional[AgentMemory] = None,
|
| 64 |
+
message_window_size: int = 6,
|
| 65 |
+
retry_attempts: int = 2,
|
| 66 |
+
verbose: bool = False,
|
| 67 |
+
logger_color: Any = Fore.MAGENTA,
|
| 68 |
+
) -> None:
|
| 69 |
+
super().__init__(
|
| 70 |
+
system_message,
|
| 71 |
+
model=model,
|
| 72 |
+
memory=memory,
|
| 73 |
+
message_window_size=message_window_size,
|
| 74 |
+
)
|
| 75 |
+
self.options_dict: Dict[str, str] = dict()
|
| 76 |
+
self.retry_attempts = retry_attempts
|
| 77 |
+
self.verbose = verbose
|
| 78 |
+
self.logger_color = logger_color
|
| 79 |
+
|
| 80 |
+
def flatten_options(self, messages: Sequence[BaseMessage]) -> str:
|
| 81 |
+
r"""Flattens the options to the critic.
|
| 82 |
+
|
| 83 |
+
Args:
|
| 84 |
+
messages (Sequence[BaseMessage]): A list of `BaseMessage` objects.
|
| 85 |
+
|
| 86 |
+
Returns:
|
| 87 |
+
str: A string containing the flattened options to the critic.
|
| 88 |
+
"""
|
| 89 |
+
options = [message.content for message in messages]
|
| 90 |
+
flatten_options = (
|
| 91 |
+
f"> Proposals from "
|
| 92 |
+
f"{messages[0].role_name} ({messages[0].role_type}). "
|
| 93 |
+
"Please choose an option:\n"
|
| 94 |
+
)
|
| 95 |
+
for index, option in enumerate(options):
|
| 96 |
+
flatten_options += f"Option {index + 1}:\n{option}\n\n"
|
| 97 |
+
self.options_dict[str(index + 1)] = option
|
| 98 |
+
format = (
|
| 99 |
+
f"Please first enter your choice ([1-{len(self.options_dict)}]) "
|
| 100 |
+
"and then your explanation and comparison: "
|
| 101 |
+
)
|
| 102 |
+
return flatten_options + format
|
| 103 |
+
|
| 104 |
+
def get_option(self, input_message: BaseMessage) -> str:
|
| 105 |
+
r"""Gets the option selected by the critic.
|
| 106 |
+
|
| 107 |
+
Args:
|
| 108 |
+
input_message (BaseMessage): A `BaseMessage` object representing
|
| 109 |
+
the input message.
|
| 110 |
+
|
| 111 |
+
Returns:
|
| 112 |
+
str: The option selected by the critic.
|
| 113 |
+
"""
|
| 114 |
+
# TODO: Add support for editing options by the critic.
|
| 115 |
+
msg_content = input_message.content
|
| 116 |
+
i = 0
|
| 117 |
+
while i < self.retry_attempts:
|
| 118 |
+
critic_response = self.step(input_message)
|
| 119 |
+
|
| 120 |
+
if critic_response.msgs is None or len(critic_response.msgs) == 0:
|
| 121 |
+
raise RuntimeError("Got None critic messages.")
|
| 122 |
+
if critic_response.terminated:
|
| 123 |
+
raise RuntimeError("Critic step failed.")
|
| 124 |
+
|
| 125 |
+
critic_msg = critic_response.msg
|
| 126 |
+
if self.verbose:
|
| 127 |
+
print_text_animated(
|
| 128 |
+
self.logger_color + "\n> Critic response: "
|
| 129 |
+
f"\x1b[3m{critic_msg.content}\x1b[0m\n"
|
| 130 |
+
)
|
| 131 |
+
choice = self.parse_critic(critic_msg)
|
| 132 |
+
|
| 133 |
+
if choice in self.options_dict:
|
| 134 |
+
return self.options_dict[choice]
|
| 135 |
+
else:
|
| 136 |
+
input_message = BaseMessage(
|
| 137 |
+
role_name=input_message.role_name,
|
| 138 |
+
role_type=input_message.role_type,
|
| 139 |
+
meta_dict=input_message.meta_dict,
|
| 140 |
+
content="> Invalid choice. Please choose again.\n"
|
| 141 |
+
+ msg_content,
|
| 142 |
+
)
|
| 143 |
+
i += 1
|
| 144 |
+
warnings.warn(
|
| 145 |
+
"Critic failed to get a valid option. "
|
| 146 |
+
f"After {self.retry_attempts} attempts. "
|
| 147 |
+
"Returning a random option."
|
| 148 |
+
)
|
| 149 |
+
return random.choice(list(self.options_dict.values()))
|
| 150 |
+
|
| 151 |
+
def parse_critic(self, critic_msg: BaseMessage) -> Optional[str]:
|
| 152 |
+
r"""Parses the critic's message and extracts the choice.
|
| 153 |
+
|
| 154 |
+
Args:
|
| 155 |
+
critic_msg (BaseMessage): A `BaseMessage` object representing the
|
| 156 |
+
critic's response.
|
| 157 |
+
|
| 158 |
+
Returns:
|
| 159 |
+
Optional[str]: The critic's choice as a string, or None if the
|
| 160 |
+
message could not be parsed.
|
| 161 |
+
"""
|
| 162 |
+
choice = str(get_first_int(critic_msg.content))
|
| 163 |
+
return choice
|
| 164 |
+
|
| 165 |
+
def reduce_step(
|
| 166 |
+
self,
|
| 167 |
+
input_messages: Sequence[BaseMessage],
|
| 168 |
+
) -> ChatAgentResponse:
|
| 169 |
+
r"""Performs one step of the conversation by flattening options to the
|
| 170 |
+
critic, getting the option, and parsing the choice.
|
| 171 |
+
|
| 172 |
+
Args:
|
| 173 |
+
input_messages (Sequence[BaseMessage]): A list of BaseMessage
|
| 174 |
+
objects.
|
| 175 |
+
|
| 176 |
+
Returns:
|
| 177 |
+
ChatAgentResponse: A `ChatAgentResponse` object includes the
|
| 178 |
+
critic's choice.
|
| 179 |
+
"""
|
| 180 |
+
meta_chat_message = BaseMessage(
|
| 181 |
+
role_name=input_messages[0].role_name,
|
| 182 |
+
role_type=input_messages[0].role_type,
|
| 183 |
+
meta_dict=input_messages[0].meta_dict,
|
| 184 |
+
content="",
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
flatten_options = self.flatten_options(input_messages)
|
| 188 |
+
if self.verbose:
|
| 189 |
+
print_text_animated(
|
| 190 |
+
self.logger_color + f"\x1b[3m{flatten_options}\x1b[0m\n"
|
| 191 |
+
)
|
| 192 |
+
input_msg = meta_chat_message.create_new_instance(flatten_options)
|
| 193 |
+
|
| 194 |
+
option = self.get_option(input_msg)
|
| 195 |
+
output_msg = meta_chat_message.create_new_instance(option)
|
| 196 |
+
|
| 197 |
+
# TODO: The return `info` can be improved.
|
| 198 |
+
return ChatAgentResponse(
|
| 199 |
+
msgs=[output_msg],
|
| 200 |
+
terminated=False,
|
| 201 |
+
info={},
|
| 202 |
+
)
|
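A minimal sketch of driving `CriticAgent.reduce_step` defined above; the system prompt and the two candidate proposals are made-up placeholders, and the import assumes `CriticAgent` is re-exported from `camel.agents`.

from camel.agents import CriticAgent
from camel.messages import BaseMessage

critic_sys = BaseMessage.make_assistant_message(
    role_name="Critic", content="You pick the stronger proposal."
)
critic = CriticAgent(system_message=critic_sys, verbose=True)

# Two candidate proposals produced by the same assistant role.
proposals = [
    BaseMessage.make_assistant_message(
        role_name="Assistant", content="Proposal text A."
    ),
    BaseMessage.make_assistant_message(
        role_name="Assistant", content="Proposal text B."
    ),
]

# reduce_step flattens the options, queries the critic, and returns the
# chosen option wrapped in a ChatAgentResponse.
choice = critic.reduce_step(proposals)
print(choice.msg.content)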
Paper2Poster/camel/agents/deductive_reasoner_agent.py
ADDED
|
@@ -0,0 +1,303 @@
|
| 1 |
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
#
|
| 6 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 7 |
+
#
|
| 8 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 9 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 10 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 11 |
+
# See the License for the specific language governing permissions and
|
| 12 |
+
# limitations under the License.
|
| 13 |
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
| 14 |
+
import re
|
| 15 |
+
from typing import Dict, List, Optional, Union
|
| 16 |
+
|
| 17 |
+
from camel.agents.chat_agent import ChatAgent
|
| 18 |
+
from camel.logger import get_logger
|
| 19 |
+
from camel.messages import BaseMessage
|
| 20 |
+
from camel.models import BaseModelBackend
|
| 21 |
+
from camel.prompts import TextPrompt
|
| 22 |
+
from camel.types import RoleType
|
| 23 |
+
|
| 24 |
+
logger = get_logger(__name__)
|
| 25 |
+
|
| 26 |
+
# AgentOps decorator setting
|
| 27 |
+
try:
|
| 28 |
+
import os
|
| 29 |
+
|
| 30 |
+
if os.getenv("AGENTOPS_API_KEY") is not None:
|
| 31 |
+
from agentops import track_agent
|
| 32 |
+
else:
|
| 33 |
+
raise ImportError
|
| 34 |
+
except (ImportError, AttributeError):
|
| 35 |
+
from camel.utils import track_agent
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@track_agent(name="DeductiveReasonerAgent")
|
| 39 |
+
class DeductiveReasonerAgent(ChatAgent):
|
| 40 |
+
r"""An agent responsible for deductive reasoning. Model of deductive
|
| 41 |
+
reasoning:
|
| 42 |
+
- L: A ⊕ C -> q * B
|
| 43 |
+
- A represents the known starting state.
|
| 44 |
+
- B represents the known target state.
|
| 45 |
+
- C represents the conditions required to transition from A to B.
|
| 46 |
+
- Q represents the quality or effectiveness of the transition from
|
| 47 |
+
A to B.
|
| 48 |
+
- L represents the path or process from A to B.
|
| 49 |
+
|
| 50 |
+
Args:
|
| 51 |
+
model (BaseModelBackend, optional): The model backend to use for
|
| 52 |
+
generating responses. (default: :obj:`OpenAIModel` with
|
| 53 |
+
`GPT_4O_MINI`)
|
| 54 |
+
"""
|
| 55 |
+
|
| 56 |
+
def __init__(
|
| 57 |
+
self,
|
| 58 |
+
model: Optional[BaseModelBackend] = None,
|
| 59 |
+
) -> None:
|
| 60 |
+
system_message = BaseMessage(
|
| 61 |
+
role_name="Insight Agent",
|
| 62 |
+
role_type=RoleType.ASSISTANT,
|
| 63 |
+
meta_dict=None,
|
| 64 |
+
content="You assign roles based on tasks.",
|
| 65 |
+
)
|
| 66 |
+
super().__init__(system_message, model=model)
|
| 67 |
+
|
| 68 |
+
def deduce_conditions_and_quality(
|
| 69 |
+
self,
|
| 70 |
+
starting_state: str,
|
| 71 |
+
target_state: str,
|
| 72 |
+
role_descriptions_dict: Optional[Dict[str, str]] = None,
|
| 73 |
+
) -> Dict[str, Union[List[str], Dict[str, str]]]:
|
| 74 |
+
r"""Derives the conditions and quality from the starting state and the
|
| 75 |
+
target state based on the model of the deductive reasoning and the
|
| 76 |
+
knowledge base. It can optionally consider the roles involved in the
|
| 77 |
+
scenario, which allows tailoring the output more closely to the AI
|
| 78 |
+
agent's environment.
|
| 79 |
+
|
| 80 |
+
Args:
|
| 81 |
+
starting_state (str): The initial or starting state from which
|
| 82 |
+
conditions are deduced.
|
| 83 |
+
target_state (str): The target state of the task.
|
| 84 |
+
role_descriptions_dict (Optional[Dict[str, str]], optional): The
|
| 85 |
+
descriptions of the roles. (default: :obj:`None`)
|
| 86 |
+
More specifically, this argument is a
|
| 87 |
+
dictionary describing the roles involved in the scenario. This
|
| 88 |
+
is optional and can be used to provide context for
|
| 89 |
+
CAMEL's role-playing, enabling the generation of more relevant
|
| 90 |
+
and tailored conditions and quality assessments. This could be
|
| 91 |
+
generated using a `RoleAssignmentAgent()` or defined manually
|
| 92 |
+
by the user.
|
| 93 |
+
|
| 94 |
+
Returns:
|
| 95 |
+
Dict[str, Union[List[str], Dict[str, str]]]: A dictionary with the
|
| 96 |
+
extracted data from the message. The dictionary contains three
|
| 97 |
+
keys:
|
| 98 |
+
- 'conditions': A list where each key is a condition ID and
|
| 99 |
+
each value is the corresponding condition text.
|
| 100 |
+
- 'labels': A list of label strings extracted from the message.
|
| 101 |
+
- 'quality': A string of quality assessment strings extracted
|
| 102 |
+
from the message.
|
| 103 |
+
"""
|
| 104 |
+
self.reset()
|
| 105 |
+
|
| 106 |
+
deduce_prompt = """You are a deductive reasoner. You are tasked to
|
| 107 |
+
complete the TASK based on the THOUGHT OF DEDUCTIVE REASONING, the
|
| 108 |
+
STARTING STATE A and the TARGET STATE B. You are given the CONTEXT
|
| 109 |
+
CONTENT to help you complete the TASK.
|
| 110 |
+
Your answer MUST strictly adhere to the structure of ANSWER TEMPLATE, ONLY
|
| 111 |
+
fill in the BLANKs, and DO NOT alter or modify any other part of the template.
|
| 112 |
+
|
| 113 |
+
===== MODELING OF DEDUCTIVE REASONING =====
|
| 114 |
+
You are tasked with understanding a mathematical model based on the components
|
| 115 |
+
${A, B, C, Q, L}$. In this model: ``L: A ⊕ C -> q * B``.
|
| 116 |
+
- $A$ represents the known starting state.
|
| 117 |
+
- $B$ represents the known target state.
|
| 118 |
+
- $C$ represents the conditions required to transition from $A$ to $B$.
|
| 119 |
+
- $Q$ represents the quality or effectiveness of the transition from $A$ to
|
| 120 |
+
$B$.
|
| 121 |
+
- $L$ represents the path or process from $A$ to $B$.
|
| 122 |
+
|
| 123 |
+
===== THOUGHT OF DEDUCTIVE REASONING =====
|
| 124 |
+
1. Define the Parameters of A and B:
|
| 125 |
+
- Characterization: Before delving into transitions, thoroughly understand
|
| 126 |
+
the nature and boundaries of both $A$ and $B$. This includes the type,
|
| 127 |
+
properties, constraints, and possible interactions between the two.
|
| 128 |
+
- Contrast and Compare: Highlight the similarities and differences between
|
| 129 |
+
$A$ and $B$. This comparative analysis will give an insight into what
|
| 130 |
+
needs changing and what remains constant.
|
| 131 |
+
2. Historical & Empirical Analysis:
|
| 132 |
+
- Previous Transitions according to the Knowledge Base of GPT: (if
|
| 133 |
+
applicable) Extract conditions and patterns from the historical instances
|
| 134 |
+
where a similar transition from a state comparable to $A$ moved towards
|
| 135 |
+
$B$.
|
| 136 |
+
- Scientific Principles: (if applicable) Consider the underlying
|
| 137 |
+
scientific principles governing or related to the states and their
|
| 138 |
+
transition. For example, if $A$ and $B$ are physical states, laws of
|
| 139 |
+
physics might apply.
|
| 140 |
+
3. Logical Deduction of Conditions ($C$):
|
| 141 |
+
- Direct Path Analysis: What are the immediate and direct conditions
|
| 142 |
+
required to move from $A$ to $B$?
|
| 143 |
+
- Intermediate States: Are there states between $A$ and $B$ that must be
|
| 144 |
+
traversed or can be used to make the transition smoother or more
|
| 145 |
+
efficient? If yes, what is the content?
|
| 146 |
+
- Constraints & Limitations: Identify potential barriers or restrictions
|
| 147 |
+
in moving from $A$ to $B$. These can be external (e.g., environmental
|
| 148 |
+
factors) or internal (properties of $A$ or $B$).
|
| 149 |
+
- Resource and Information Analysis: What resources and information are
|
| 150 |
+
required for the transition? This could be time, entity, factor, code
|
| 151 |
+
language, software platform, unknowns, etc.
|
| 152 |
+
- External Influences: Consider socio-economic, political, or
|
| 153 |
+
environmental factors (if applicable) that could influence the transition
|
| 154 |
+
conditions.
|
| 155 |
+
- Creative/Heuristic Reasoning: Open your mind to multiple possible $C$'s,
|
| 156 |
+
no matter how unconventional they might seem. Utilize analogies,
|
| 157 |
+
metaphors, or brainstorming techniques to envision possible conditions or
|
| 158 |
+
paths from $A$ to $B$.
|
| 159 |
+
- The conditions $C$ should be multiple but in one sentence. And each
|
| 160 |
+
condition should be concerned with one aspect/entity.
|
| 161 |
+
4. Entity/Label Recognition of Conditions ($C$):
|
| 162 |
+
- Identify and categorize entities of Conditions ($C$) such as the names,
|
| 163 |
+
locations, dates, specific technical terms or contextual parameters that
|
| 164 |
+
might be associated with events, innovations post-2022.
|
| 165 |
+
- The output of the entities/labels will be used as tags or labels for
|
| 166 |
+
semantic similarity searches. The entities/labels may be the words, or
|
| 167 |
+
phrases, each of them should contain valuable, high information entropy
|
| 168 |
+
information, and should be independent.
|
| 169 |
+
- Ensure that the identified entities are formatted in a manner suitable
|
| 170 |
+
for database indexing and retrieval. Organize the entities into
|
| 171 |
+
categories, and combine the category with its instance into a continuous
|
| 172 |
+
phrase, without using colons or other separators.
|
| 173 |
+
- Format these entities for database indexing: output the category rather
|
| 174 |
+
than its instance/content into a continuous phrase. For example, instead
|
| 175 |
+
of "Jan. 02", identify it as "Event time".
|
| 176 |
+
5. Quality Assessment ($Q$):
|
| 177 |
+
- Efficiency: How efficient is the transition from $A$ to $B$, which
|
| 178 |
+
measures the resources used versus the desired outcome?
|
| 179 |
+
- Effectiveness: Did the transition achieve the desired outcome or was the
|
| 180 |
+
target state achieved as intended?
|
| 181 |
+
- Safety & Risks: Assess any risks associated with the transition and the
|
| 182 |
+
measures to mitigate them.
|
| 183 |
+
- Feedback Mechanisms: Incorporate feedback loops to continuously monitor
|
| 184 |
+
and adjust the quality of transition, making it more adaptive.
|
| 185 |
+
6. Iterative Evaluation:
|
| 186 |
+
- Test & Refine: Based on the initially deduced conditions and assessed
|
| 187 |
+
quality, iterate the process to refine and optimize the transition. This
|
| 188 |
+
might involve tweaking conditions, employing different paths, or changing
|
| 189 |
+
resources.
|
| 190 |
+
- Feedback Integration: Use feedback to make improvements and increase the
|
| 191 |
+
quality of the transition.
|
| 192 |
+
7. Real-world scenarios often present challenges that may not be captured by
|
| 193 |
+
models and frameworks. While using the model, maintain an adaptive mindset:
|
| 194 |
+
- Scenario Exploration: Continuously imagine various possible scenarios,
|
| 195 |
+
both positive and negative, to prepare for unexpected events.
|
| 196 |
+
- Flexibility: Be prepared to modify conditions ($C$) or alter the path/
|
| 197 |
+
process ($L$) if unforeseen challenges arise.
|
| 198 |
+
- Feedback Integration: Rapidly integrate feedback from actual
|
| 199 |
+
implementations to adjust the model's application, ensuring relevancy and
|
| 200 |
+
effectiveness.
|
| 201 |
+
|
| 202 |
+
===== TASK =====
|
| 203 |
+
Given the starting state $A$ and the target state $B$, assuming that a path
|
| 204 |
+
$L$ always exists between $A$ and $B$, how can one deduce or identify the
|
| 205 |
+
necessary conditions $C$ and the quality $Q$ of the transition?
|
| 206 |
+
|
| 207 |
+
===== STARTING STATE $A$ =====
|
| 208 |
+
{starting_state}
|
| 209 |
+
|
| 210 |
+
===== TARGET STATE $B$ =====
|
| 211 |
+
{target_state}
|
| 212 |
+
|
| 213 |
+
{role_with_description_prompt}
|
| 214 |
+
===== ANSWER TEMPLATE =====
|
| 215 |
+
- Characterization and comparison of $A$ and $B$:\n<BLANK>
|
| 216 |
+
- Historical & Empirical Analysis:\n<BLANK>/None
|
| 217 |
+
- Logical Deduction of Conditions ($C$) (multiple conditions can be deduced):
|
| 218 |
+
condition <NUM>:
|
| 219 |
+
<BLANK>.
|
| 220 |
+
- Entity/Label Recognition of Conditions:\n[<BLANK>, <BLANK>, ...] (include
|
| 221 |
+
square brackets)
|
| 222 |
+
- Quality Assessment ($Q$) (do not use symbols):
|
| 223 |
+
<BLANK>.
|
| 224 |
+
- Iterative Evaluation:\n<BLANK>/None"""
|
| 225 |
+
|
| 226 |
+
if role_descriptions_dict is not None:
|
| 227 |
+
role_names = role_descriptions_dict.keys()
|
| 228 |
+
role_with_description_prompt = (
|
| 229 |
+
"===== ROLES WITH DESCRIPTIONS =====\n"
|
| 230 |
+
+ "\n".join(
|
| 231 |
+
f"{role_name}:\n{role_descriptions_dict[role_name]}\n"
|
| 232 |
+
for role_name in role_names
|
| 233 |
+
)
|
| 234 |
+
+ "\n\n"
|
| 235 |
+
)
|
| 236 |
+
else:
|
| 237 |
+
role_with_description_prompt = ""
|
| 238 |
+
deduce_prompt = TextPrompt(deduce_prompt)
|
| 239 |
+
|
| 240 |
+
deduce = deduce_prompt.format(
|
| 241 |
+
starting_state=starting_state,
|
| 242 |
+
target_state=target_state,
|
| 243 |
+
role_with_description_prompt=role_with_description_prompt,
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
conditions_and_quality_generation_msg = BaseMessage.make_user_message(
|
| 247 |
+
role_name="Deductive Reasoner", content=deduce
|
| 248 |
+
)
|
| 249 |
+
|
| 250 |
+
response = self.step(
|
| 251 |
+
input_message=conditions_and_quality_generation_msg
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
if response.terminated:
|
| 255 |
+
raise RuntimeError(
|
| 256 |
+
"Deduction failed. Error:\n" + f"{response.info}"
|
| 257 |
+
)
|
| 258 |
+
msg: BaseMessage = response.msg
|
| 259 |
+
logger.info(f"Message content:\n{msg.content}")
|
| 260 |
+
|
| 261 |
+
# Extract the conditions from the message
|
| 262 |
+
conditions_dict = {
|
| 263 |
+
f"condition {i}": cdt.replace("<", "")
|
| 264 |
+
.replace(">", "")
|
| 265 |
+
.strip()
|
| 266 |
+
.strip('\n')
|
| 267 |
+
for i, cdt in re.findall(
|
| 268 |
+
r"condition (\d+):\s*(.+?)(?=condition \d+|- Entity)",
|
| 269 |
+
msg.content,
|
| 270 |
+
re.DOTALL,
|
| 271 |
+
)
|
| 272 |
+
}
|
| 273 |
+
|
| 274 |
+
# Extract the labels from the message
|
| 275 |
+
labels = [
|
| 276 |
+
label.strip().strip('\n').strip("\"'")
|
| 277 |
+
for label in re.findall(
|
| 278 |
+
r"Entity/Label Recognition of Conditions:\n\[(.+?)\]",
|
| 279 |
+
msg.content,
|
| 280 |
+
re.DOTALL,
|
| 281 |
+
)[0].split(",")
|
| 282 |
+
]
|
| 283 |
+
|
| 284 |
+
# Extract the quality from the message
|
| 285 |
+
quality = next(
|
| 286 |
+
q.strip().strip('\n')
|
| 287 |
+
for q in re.findall(
|
| 288 |
+
r"Quality Assessment \(\$Q\$\) \(do not use symbols\):"
|
| 289 |
+
r"\n(.+?)- Iterative",
|
| 290 |
+
msg.content,
|
| 291 |
+
re.DOTALL,
|
| 292 |
+
)
|
| 293 |
+
)
|
| 294 |
+
|
| 295 |
+
# Convert them into JSON format
|
| 296 |
+
conditions_and_quality_json: Dict[
|
| 297 |
+
str, Union[List[str], Dict[str, str]]
|
| 298 |
+
] = {}
|
| 299 |
+
conditions_and_quality_json["conditions"] = conditions_dict
|
| 300 |
+
conditions_and_quality_json["labels"] = labels
|
| 301 |
+
conditions_and_quality_json["evaluate_quality"] = quality
|
| 302 |
+
|
| 303 |
+
return conditions_and_quality_json
|
Paper2Poster/camel/agents/embodied_agent.py
ADDED
|
@@ -0,0 +1,201 @@
| 1 |
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
#
|
| 6 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 7 |
+
#
|
| 8 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 9 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 10 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 11 |
+
# See the License for the specific language governing permissions and
|
| 12 |
+
# limitations under the License.
|
| 13 |
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
| 14 |
+
from typing import Any, List, Optional
|
| 15 |
+
|
| 16 |
+
from colorama import Fore
|
| 17 |
+
|
| 18 |
+
from camel.agents.chat_agent import ChatAgent
|
| 19 |
+
from camel.agents.tool_agents.base import BaseToolAgent
|
| 20 |
+
from camel.interpreters import (
|
| 21 |
+
BaseInterpreter,
|
| 22 |
+
InternalPythonInterpreter,
|
| 23 |
+
SubprocessInterpreter,
|
| 24 |
+
)
|
| 25 |
+
from camel.messages import BaseMessage
|
| 26 |
+
from camel.models import BaseModelBackend
|
| 27 |
+
from camel.responses import ChatAgentResponse
|
| 28 |
+
from camel.utils import print_text_animated
|
| 29 |
+
|
| 30 |
+
# AgentOps decorator setting
|
| 31 |
+
try:
|
| 32 |
+
import os
|
| 33 |
+
|
| 34 |
+
if os.getenv("AGENTOPS_API_KEY") is not None:
|
| 35 |
+
from agentops import track_agent
|
| 36 |
+
else:
|
| 37 |
+
raise ImportError
|
| 38 |
+
except (ImportError, AttributeError):
|
| 39 |
+
from camel.utils import track_agent
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@track_agent(name="EmbodiedAgent")
|
| 43 |
+
class EmbodiedAgent(ChatAgent):
|
| 44 |
+
r"""Class for managing conversations of CAMEL Embodied Agents.
|
| 45 |
+
|
| 46 |
+
Args:
|
| 47 |
+
system_message (BaseMessage): The system message for the chat agent.
|
| 48 |
+
model (BaseModelBackend, optional): The model backend to use for
|
| 49 |
+
generating responses. (default: :obj:`OpenAIModel` with
|
| 50 |
+
`GPT_4O_MINI`)
|
| 51 |
+
message_window_size (int, optional): The maximum number of previous
|
| 52 |
+
messages to include in the context window. If `None`, no windowing
|
| 53 |
+
is performed. (default: :obj:`None`)
|
| 54 |
+
tool_agents (List[BaseToolAgent], optional): The tool agents to use in
|
| 55 |
+
the embodied agent. (default: :obj:`None`)
|
| 56 |
+
code_interpreter (BaseInterpreter, optional): The code interpreter to
|
| 57 |
+
execute code. If `code_interpreter` and `tool_agents` are both
|
| 58 |
+
`None`, default to `SubprocessInterpreter`. If `code_interpreter`
|
| 59 |
+
is `None` and `tool_agents` is not `None`, default to
|
| 60 |
+
`InternalPythonInterpreter`. (default: :obj:`None`)
|
| 61 |
+
verbose (bool, optional): Whether to print the agent's messages.
|
| 62 |
+
logger_color (Any): The color of the logger displayed to the user.
|
| 63 |
+
(default: :obj:`Fore.MAGENTA`)
|
| 64 |
+
"""
|
| 65 |
+
|
| 66 |
+
def __init__(
|
| 67 |
+
self,
|
| 68 |
+
system_message: BaseMessage,
|
| 69 |
+
model: Optional[BaseModelBackend] = None,
|
| 70 |
+
message_window_size: Optional[int] = None,
|
| 71 |
+
tool_agents: Optional[List[BaseToolAgent]] = None,
|
| 72 |
+
code_interpreter: Optional[BaseInterpreter] = None,
|
| 73 |
+
verbose: bool = False,
|
| 74 |
+
logger_color: Any = Fore.MAGENTA,
|
| 75 |
+
) -> None:
|
| 76 |
+
self.tool_agents = tool_agents
|
| 77 |
+
self.code_interpreter: BaseInterpreter
|
| 78 |
+
if code_interpreter is not None:
|
| 79 |
+
self.code_interpreter = code_interpreter
|
| 80 |
+
elif self.tool_agents:
|
| 81 |
+
self.code_interpreter = InternalPythonInterpreter()
|
| 82 |
+
else:
|
| 83 |
+
self.code_interpreter = SubprocessInterpreter()
|
| 84 |
+
|
| 85 |
+
if self.tool_agents:
|
| 86 |
+
system_message = self._set_tool_agents(system_message)
|
| 87 |
+
self.verbose = verbose
|
| 88 |
+
self.logger_color = logger_color
|
| 89 |
+
super().__init__(
|
| 90 |
+
system_message=system_message,
|
| 91 |
+
model=model,
|
| 92 |
+
message_window_size=message_window_size,
|
| 93 |
+
)
|
| 94 |
+
|
| 95 |
+
def _set_tool_agents(self, system_message: BaseMessage) -> BaseMessage:
|
| 96 |
+
action_space_prompt = self._get_tool_agents_prompt()
|
| 97 |
+
result_message = system_message.create_new_instance(
|
| 98 |
+
content=system_message.content.format(
|
| 99 |
+
action_space=action_space_prompt
|
| 100 |
+
)
|
| 101 |
+
)
|
| 102 |
+
if self.tool_agents is not None:
|
| 103 |
+
self.code_interpreter.update_action_space(
|
| 104 |
+
{tool.name: tool for tool in self.tool_agents}
|
| 105 |
+
)
|
| 106 |
+
return result_message
|
| 107 |
+
|
| 108 |
+
def _get_tool_agents_prompt(self) -> str:
|
| 109 |
+
r"""Returns the action space prompt.
|
| 110 |
+
|
| 111 |
+
Returns:
|
| 112 |
+
str: The action space prompt.
|
| 113 |
+
"""
|
| 114 |
+
if self.tool_agents is not None:
|
| 115 |
+
return "\n".join(
|
| 116 |
+
[
|
| 117 |
+
f"*** {tool.name} ***:\n {tool.description}"
|
| 118 |
+
for tool in self.tool_agents
|
| 119 |
+
]
|
| 120 |
+
)
|
| 121 |
+
else:
|
| 122 |
+
return ""
|
| 123 |
+
|
| 124 |
+
def get_tool_agent_names(self) -> List[str]:
|
| 125 |
+
r"""Returns the names of tool agents.
|
| 126 |
+
|
| 127 |
+
Returns:
|
| 128 |
+
List[str]: The names of tool agents.
|
| 129 |
+
"""
|
| 130 |
+
if self.tool_agents is not None:
|
| 131 |
+
return [tool.name for tool in self.tool_agents]
|
| 132 |
+
else:
|
| 133 |
+
return []
|
| 134 |
+
|
| 135 |
+
# ruff: noqa: E501
|
| 136 |
+
def step(self, input_message: BaseMessage) -> ChatAgentResponse: # type: ignore[override]
|
| 137 |
+
r"""Performs a step in the conversation.
|
| 138 |
+
|
| 139 |
+
Args:
|
| 140 |
+
input_message (BaseMessage): The input message.
|
| 141 |
+
|
| 142 |
+
Returns:
|
| 143 |
+
ChatAgentResponse: A struct containing the output messages,
|
| 144 |
+
a boolean indicating whether the chat session has terminated,
|
| 145 |
+
and information about the chat session.
|
| 146 |
+
"""
|
| 147 |
+
response = super().step(input_message)
|
| 148 |
+
|
| 149 |
+
if response.msgs is None or len(response.msgs) == 0:
|
| 150 |
+
raise RuntimeError("Got None output messages.")
|
| 151 |
+
if response.terminated:
|
| 152 |
+
raise RuntimeError(f"{self.__class__.__name__} step failed.")
|
| 153 |
+
|
| 154 |
+
# NOTE: Only single output messages are supported
|
| 155 |
+
explanations, codes = response.msg.extract_text_and_code_prompts()
|
| 156 |
+
|
| 157 |
+
if self.verbose:
|
| 158 |
+
for explanation, code in zip(explanations, codes):
|
| 159 |
+
print_text_animated(
|
| 160 |
+
self.logger_color + f"> Explanation:\n{explanation}"
|
| 161 |
+
)
|
| 162 |
+
print_text_animated(self.logger_color + f"> Code:\n{code}")
|
| 163 |
+
|
| 164 |
+
if len(explanations) > len(codes):
|
| 165 |
+
print_text_animated(
|
| 166 |
+
self.logger_color + f"> Explanation:\n{explanations[-1]}"
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
content = response.msg.content
|
| 170 |
+
|
| 171 |
+
if codes is not None:
|
| 172 |
+
try:
|
| 173 |
+
content = "\n> Executed Results:\n"
|
| 174 |
+
for block_idx, code in enumerate(codes):
|
| 175 |
+
executed_output = self.code_interpreter.run(
|
| 176 |
+
code, code.code_type
|
| 177 |
+
)
|
| 178 |
+
content += (
|
| 179 |
+
f"Executing code block {block_idx}: {{\n"
|
| 180 |
+
+ executed_output
|
| 181 |
+
+ "}\n"
|
| 182 |
+
)
|
| 183 |
+
except InterruptedError as e:
|
| 184 |
+
content = (
|
| 185 |
+
f"\n> Running code failed: {e}\n"
|
| 186 |
+
"Please regenerate the code."
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
# TODO: Handle errors
|
| 190 |
+
content = input_message.content + f"\n> Embodied Actions:\n{content}"
|
| 191 |
+
message = BaseMessage(
|
| 192 |
+
input_message.role_name,
|
| 193 |
+
input_message.role_type,
|
| 194 |
+
input_message.meta_dict,
|
| 195 |
+
content,
|
| 196 |
+
)
|
| 197 |
+
return ChatAgentResponse(
|
| 198 |
+
msgs=[message],
|
| 199 |
+
terminated=response.terminated,
|
| 200 |
+
info=response.info,
|
| 201 |
+
)
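A minimal usage sketch for `EmbodiedAgent` as defined above; the role names, message contents, and reliance on the default model backend and the `SubprocessInterpreter` fallback are illustrative assumptions:

    from camel.agents.embodied_agent import EmbodiedAgent
    from camel.messages import BaseMessage

    # No tool_agents given, so code blocks run via SubprocessInterpreter.
    sys_msg = BaseMessage.make_assistant_message(
        role_name="Coder",
        content="You write short Python snippets to answer questions.",
    )
    agent = EmbodiedAgent(system_message=sys_msg, verbose=True)

    user_msg = BaseMessage.make_user_message(
        role_name="User",
        content="Compute the sum of the first 10 squares and print it.",
    )
    response = agent.step(user_msg)
    print(response.msg.content)  # original request plus "> Embodied Actions"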
|
Paper2Poster/camel/agents/knowledge_graph_agent.py
ADDED
|
@@ -0,0 +1,259 @@
| 1 |
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
#
|
| 6 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 7 |
+
#
|
| 8 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 9 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 10 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 11 |
+
# See the License for the specific language governing permissions and
|
| 12 |
+
# limitations under the License.
|
| 13 |
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
| 14 |
+
from typing import TYPE_CHECKING, Optional, Union
|
| 15 |
+
|
| 16 |
+
if TYPE_CHECKING:
|
| 17 |
+
from unstructured.documents.elements import Element
|
| 18 |
+
|
| 19 |
+
from camel.agents import ChatAgent
|
| 20 |
+
from camel.messages import BaseMessage
|
| 21 |
+
from camel.models import BaseModelBackend
|
| 22 |
+
from camel.prompts import TextPrompt
|
| 23 |
+
from camel.storages.graph_storages.graph_element import (
|
| 24 |
+
GraphElement,
|
| 25 |
+
Node,
|
| 26 |
+
Relationship,
|
| 27 |
+
)
|
| 28 |
+
from camel.types import RoleType
|
| 29 |
+
|
| 30 |
+
# AgentOps decorator setting
|
| 31 |
+
try:
|
| 32 |
+
import os
|
| 33 |
+
|
| 34 |
+
if os.getenv("AGENTOPS_API_KEY") is not None:
|
| 35 |
+
from agentops import track_agent
|
| 36 |
+
else:
|
| 37 |
+
raise ImportError
|
| 38 |
+
except (ImportError, AttributeError):
|
| 39 |
+
from camel.utils import track_agent
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
text_prompt = """
|
| 43 |
+
You are tasked with extracting nodes and relationships from given content and
|
| 44 |
+
structuring them into Node and Relationship objects. Here's the outline of what
|
| 45 |
+
you need to do:
|
| 46 |
+
|
| 47 |
+
Content Extraction:
|
| 48 |
+
You should be able to process input content and identify entities mentioned
|
| 49 |
+
within it.
|
| 50 |
+
Entities can be any noun phrases or concepts that represent distinct entities
|
| 51 |
+
in the context of the given content.
|
| 52 |
+
|
| 53 |
+
Node Extraction:
|
| 54 |
+
For each identified entity, you should create a Node object.
|
| 55 |
+
Each Node object should have a unique identifier (id) and a type (type).
|
| 56 |
+
Additional properties associated with the node can also be extracted and
|
| 57 |
+
stored.
|
| 58 |
+
|
| 59 |
+
Relationship Extraction:
|
| 60 |
+
You should identify relationships between entities mentioned in the content.
|
| 61 |
+
For each relationship, create a Relationship object.
|
| 62 |
+
A Relationship object should have a subject (subj) and an object (obj) which
|
| 63 |
+
are Node objects representing the entities involved in the relationship.
|
| 64 |
+
Each relationship should also have a type (type), and additional properties if
|
| 65 |
+
applicable.
|
| 66 |
+
|
| 67 |
+
Output Formatting:
|
| 68 |
+
The extracted nodes and relationships should be formatted as instances of the
|
| 69 |
+
provided Node and Relationship classes.
|
| 70 |
+
Ensure that the extracted data adheres to the structure defined by the classes.
|
| 71 |
+
Output the structured data in a format that can be easily validated against
|
| 72 |
+
the provided code.
|
| 73 |
+
|
| 74 |
+
Instructions for you:
|
| 75 |
+
Read the provided content thoroughly.
|
| 76 |
+
Identify distinct entities mentioned in the content and categorize them as
|
| 77 |
+
nodes.
|
| 78 |
+
Determine relationships between these entities and represent them as directed
|
| 79 |
+
relationships.
|
| 80 |
+
Provide the extracted nodes and relationships in the specified format below.
|
| 81 |
+
Example for you:
|
| 82 |
+
|
| 83 |
+
Example Content:
|
| 84 |
+
"John works at XYZ Corporation. He is a software engineer. The company is
|
| 85 |
+
located in New York City."
|
| 86 |
+
|
| 87 |
+
Expected Output:
|
| 88 |
+
|
| 89 |
+
Nodes:
|
| 90 |
+
|
| 91 |
+
Node(id='John', type='Person')
|
| 92 |
+
Node(id='XYZ Corporation', type='Organization')
|
| 93 |
+
Node(id='New York City', type='Location')
|
| 94 |
+
|
| 95 |
+
Relationships:
|
| 96 |
+
|
| 97 |
+
Relationship(subj=Node(id='John', type='Person'), obj=Node(id='XYZ
|
| 98 |
+
Corporation', type='Organization'), type='WorksAt')
|
| 99 |
+
Relationship(subj=Node(id='XYZ Corporation', type='Organization'), obj=Node(id='New York City',
type='Location'), type='LocatedIn')
|
| 101 |
+
|
| 102 |
+
===== TASK =====
|
| 103 |
+
Please extract nodes and relationships from the given content and structure them
|
| 104 |
+
into Node and Relationship objects.
|
| 105 |
+
|
| 106 |
+
{task}
|
| 107 |
+
"""
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
@track_agent(name="KnowledgeGraphAgent")
|
| 111 |
+
class KnowledgeGraphAgent(ChatAgent):
|
| 112 |
+
r"""An agent that can extract node and relationship information for
|
| 113 |
+
different entities from given `Element` content.
|
| 114 |
+
|
| 115 |
+
Attributes:
|
| 116 |
+
task_prompt (TextPrompt): A prompt for the agent to extract node and
|
| 117 |
+
relationship information for different entities.
|
| 118 |
+
"""
|
| 119 |
+
|
| 120 |
+
def __init__(
|
| 121 |
+
self,
|
| 122 |
+
model: Optional[BaseModelBackend] = None,
|
| 123 |
+
) -> None:
|
| 124 |
+
r"""Initialize the `KnowledgeGraphAgent`.
|
| 125 |
+
|
| 126 |
+
Args:
|
| 127 |
+
model (BaseModelBackend, optional): The model backend to use for
|
| 128 |
+
generating responses. (default: :obj:`OpenAIModel` with
|
| 129 |
+
`GPT_4O_MINI`)
|
| 130 |
+
"""
|
| 131 |
+
system_message = BaseMessage(
|
| 132 |
+
role_name="Graphify",
|
| 133 |
+
role_type=RoleType.ASSISTANT,
|
| 134 |
+
meta_dict=None,
|
| 135 |
+
content="Your mission is to transform unstructured content "
|
| 136 |
+
"into structured graph data. Extract nodes and relationships with "
|
| 137 |
+
"precision, and let the connections unfold. Your graphs will "
|
| 138 |
+
"illuminate the hidden connections within the chaos of "
|
| 139 |
+
"information.",
|
| 140 |
+
)
|
| 141 |
+
super().__init__(system_message, model=model)
|
| 142 |
+
|
| 143 |
+
def run(
|
| 144 |
+
self,
|
| 145 |
+
element: "Element",
|
| 146 |
+
parse_graph_elements: bool = False,
|
| 147 |
+
) -> Union[str, GraphElement]:
|
| 148 |
+
r"""Run the agent to extract node and relationship information.
|
| 149 |
+
|
| 150 |
+
Args:
|
| 151 |
+
element (Element): The input element.
|
| 152 |
+
parse_graph_elements (bool, optional): Whether to parse into
|
| 153 |
+
`GraphElement`. Defaults to `False`.
|
| 154 |
+
|
| 155 |
+
Returns:
|
| 156 |
+
Union[str, GraphElement]: The extracted node and relationship
|
| 157 |
+
information. If `parse_graph_elements` is `True` then return
|
| 158 |
+
`GraphElement`, else return `str`.
|
| 159 |
+
"""
|
| 160 |
+
self.reset()
|
| 161 |
+
self.element = element
|
| 162 |
+
|
| 163 |
+
knowledge_graph_prompt = TextPrompt(text_prompt)
|
| 164 |
+
knowledge_graph_generation = knowledge_graph_prompt.format(
|
| 165 |
+
task=str(element)
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
knowledge_graph_generation_msg = BaseMessage.make_user_message(
|
| 169 |
+
role_name="Graphify", content=knowledge_graph_generation
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
response = self.step(input_message=knowledge_graph_generation_msg)
|
| 173 |
+
|
| 174 |
+
content = response.msg.content
|
| 175 |
+
|
| 176 |
+
if parse_graph_elements:
|
| 177 |
+
content = self._parse_graph_elements(content)
|
| 178 |
+
|
| 179 |
+
return content
|
| 180 |
+
|
| 181 |
+
def _validate_node(self, node: Node) -> bool:
|
| 182 |
+
r"""Validate if the object is a valid Node.
|
| 183 |
+
|
| 184 |
+
Args:
|
| 185 |
+
node (Node): Object to be validated.
|
| 186 |
+
|
| 187 |
+
Returns:
|
| 188 |
+
bool: True if the object is a valid Node, False otherwise.
|
| 189 |
+
"""
|
| 190 |
+
return (
|
| 191 |
+
isinstance(node, Node)
|
| 192 |
+
and isinstance(node.id, (str, int))
|
| 193 |
+
and isinstance(node.type, str)
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
def _validate_relationship(self, relationship: Relationship) -> bool:
|
| 197 |
+
r"""Validate if the object is a valid Relationship.
|
| 198 |
+
|
| 199 |
+
Args:
|
| 200 |
+
relationship (Relationship): Object to be validated.
|
| 201 |
+
|
| 202 |
+
Returns:
|
| 203 |
+
bool: True if the object is a valid Relationship, False otherwise.
|
| 204 |
+
"""
|
| 205 |
+
return (
|
| 206 |
+
isinstance(relationship, Relationship)
|
| 207 |
+
and self._validate_node(relationship.subj)
|
| 208 |
+
and self._validate_node(relationship.obj)
|
| 209 |
+
and isinstance(relationship.type, str)
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
def _parse_graph_elements(self, input_string: str) -> GraphElement:
|
| 213 |
+
r"""Parses graph elements from given content.
|
| 214 |
+
|
| 215 |
+
Args:
|
| 216 |
+
input_string (str): The input content.
|
| 217 |
+
|
| 218 |
+
Returns:
|
| 219 |
+
GraphElement: The parsed graph elements.
|
| 220 |
+
"""
|
| 221 |
+
import re
|
| 222 |
+
|
| 223 |
+
# Regular expressions to extract nodes and relationships
|
| 224 |
+
node_pattern = r"Node\(id='(.*?)', type='(.*?)'\)"
|
| 225 |
+
rel_pattern = (
|
| 226 |
+
r"Relationship\(subj=Node\(id='(.*?)', type='(.*?)'\), "
|
| 227 |
+
r"obj=Node\(id='(.*?)', type='(.*?)'\), type='(.*?)'\)"
|
| 228 |
+
)
|
| 229 |
+
|
| 230 |
+
nodes = {}
|
| 231 |
+
relationships = []
|
| 232 |
+
|
| 233 |
+
# Extract nodes
|
| 234 |
+
for match in re.finditer(node_pattern, input_string):
|
| 235 |
+
id, type = match.groups()
|
| 236 |
+
properties = {'source': 'agent_created'}
|
| 237 |
+
if id not in nodes:
|
| 238 |
+
node = Node(id=id, type=type, properties=properties)
|
| 239 |
+
if self._validate_node(node):
|
| 240 |
+
nodes[id] = node
|
| 241 |
+
|
| 242 |
+
# Extract relationships
|
| 243 |
+
for match in re.finditer(rel_pattern, input_string):
|
| 244 |
+
subj_id, subj_type, obj_id, obj_type, rel_type = match.groups()
|
| 245 |
+
properties = {'source': 'agent_created'}
|
| 246 |
+
if subj_id in nodes and obj_id in nodes:
|
| 247 |
+
subj = nodes[subj_id]
|
| 248 |
+
obj = nodes[obj_id]
|
| 249 |
+
relationship = Relationship(
|
| 250 |
+
subj=subj, obj=obj, type=rel_type, properties=properties
|
| 251 |
+
)
|
| 252 |
+
if self._validate_relationship(relationship):
|
| 253 |
+
relationships.append(relationship)
|
| 254 |
+
|
| 255 |
+
return GraphElement(
|
| 256 |
+
nodes=list(nodes.values()),
|
| 257 |
+
relationships=relationships,
|
| 258 |
+
source=self.element,
|
| 259 |
+
)
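A minimal usage sketch for `KnowledgeGraphAgent`; the `unstructured` `Text` element and the sample sentence are assumptions for illustration:

    from unstructured.documents.elements import Text

    from camel.agents.knowledge_graph_agent import KnowledgeGraphAgent

    agent = KnowledgeGraphAgent()
    element = Text("John works at XYZ Corporation in New York City.")

    # parse_graph_elements=True returns a GraphElement instead of raw text.
    graph = agent.run(element, parse_graph_elements=True)
    print(graph.nodes)
    print(graph.relationships)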
|
Paper2Poster/camel/agents/multi_hop_generator_agent.py
ADDED
|
@@ -0,0 +1,117 @@
| 1 |
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
#
|
| 6 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 7 |
+
#
|
| 8 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 9 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 10 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 11 |
+
# See the License for the specific language governing permissions and
|
| 12 |
+
# limitations under the License.
|
| 13 |
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
| 14 |
+
|
| 15 |
+
import textwrap
|
| 16 |
+
from typing import Any
|
| 17 |
+
|
| 18 |
+
from pydantic import ConfigDict
|
| 19 |
+
|
| 20 |
+
from camel.agents.programmed_agent_instruction import (
|
| 21 |
+
ProgrammableChatAgent,
|
| 22 |
+
ProgrammedAgentInstructionResult,
|
| 23 |
+
programmable_capability,
|
| 24 |
+
)
|
| 25 |
+
from camel.datagen.source2synth.models import (
|
| 26 |
+
ContextPrompt,
|
| 27 |
+
MultiHopQA,
|
| 28 |
+
)
|
| 29 |
+
from camel.messages import BaseMessage
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class MultiHopGeneratorAgent(ProgrammableChatAgent):
|
| 33 |
+
r"""An agent specialized in generating multi-hop question-answer pairs.
|
| 34 |
+
|
| 35 |
+
This agent is designed to create complex questions that require multiple
|
| 36 |
+
steps of reasoning to answer. It analyzes context to identify related
|
| 37 |
+
facts and generates questions that require connecting these facts
|
| 38 |
+
logically.
|
| 39 |
+
|
| 40 |
+
Attributes:
|
| 41 |
+
model_config (ConfigDict): Configuration for model behavior.
|
| 42 |
+
system_message (BaseMessage): System message defining agent's role and
|
| 43 |
+
instructions.
|
| 44 |
+
"""
|
| 45 |
+
|
| 46 |
+
model_config = ConfigDict(arbitrary_types_allowed=True)
|
| 47 |
+
|
| 48 |
+
def __init__(self, **kwargs: Any) -> None:
|
| 49 |
+
r"""Initialize the MultiHopGeneratorAgent.
|
| 50 |
+
|
| 51 |
+
Args:
|
| 52 |
+
**kwargs (Any): Additional keyword arguments to pass to parent
|
| 53 |
+
class.
|
| 54 |
+
"""
|
| 55 |
+
super().__init__(**kwargs)
|
| 56 |
+
|
| 57 |
+
system_text: str = textwrap.dedent(
|
| 58 |
+
"""\
|
| 59 |
+
You are an expert at generating
|
| 60 |
+
multi-hop question-answer pairs.
|
| 61 |
+
For each context, you should:
|
| 62 |
+
1. Identify multiple related facts or pieces of information
|
| 63 |
+
2. Create questions that require reasoning across these multiple pieces
|
| 64 |
+
3. Ensure the reasoning chain is clear and logical
|
| 65 |
+
4. Generate questions that require at least 2-3 steps of reasoning
|
| 66 |
+
5. Include the reasoning steps in the answer
|
| 67 |
+
|
| 68 |
+
Give your response with this information:
|
| 69 |
+
Question: [Complex question requiring multiple reasoning steps]
|
| 70 |
+
Reasoning Steps:
|
| 71 |
+
1. [First reasoning step]
|
| 72 |
+
2. [Second reasoning step]
|
| 73 |
+
3. [Final reasoning step]
|
| 74 |
+
Answer: [Final answer]
|
| 75 |
+
Supporting Facts: [List of relevant text segments used]
|
| 76 |
+
""" # noqa: E501
|
| 77 |
+
)
|
| 78 |
+
self.system_message = BaseMessage.make_assistant_message(
|
| 79 |
+
role_name='Assistant', content=system_text
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
@programmable_capability
|
| 83 |
+
def generate_multi_hop_qa(
|
| 84 |
+
self, context: str
|
| 85 |
+
) -> ProgrammedAgentInstructionResult[MultiHopQA]:
|
| 86 |
+
r"""Generate a multi-hop question-answer pair from given context.
|
| 87 |
+
|
| 88 |
+
Args:
|
| 89 |
+
context (str): The input text context to generate QA from.
|
| 90 |
+
|
| 91 |
+
Returns:
|
| 92 |
+
ProgrammedAgentInstructionResult[MultiHopQA]: Result containing the
|
| 93 |
+
generated question, reasoning steps, answer, and supporting
|
| 94 |
+
facts.
|
| 95 |
+
|
| 96 |
+
Raises:
|
| 97 |
+
RuntimeError: If the agent fails to generate a response.
|
| 98 |
+
"""
|
| 99 |
+
context_prompt = ContextPrompt(
|
| 100 |
+
main_context=context, related_contexts=None
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
user_message = BaseMessage.make_user_message(
|
| 104 |
+
content=context_prompt.model_dump_json(), role_name="User"
|
| 105 |
+
)
|
| 106 |
+
response = self.step(
|
| 107 |
+
input_message=user_message, response_format=MultiHopQA
|
| 108 |
+
)
|
| 109 |
+
value = MultiHopQA.model_validate_json(response.msgs[0].content)
|
| 110 |
+
|
| 111 |
+
if response.msgs:
|
| 112 |
+
return ProgrammedAgentInstructionResult(
|
| 113 |
+
user_message=user_message,
|
| 114 |
+
agent_message=response.msgs[0],
|
| 115 |
+
value=value,
|
| 116 |
+
)
|
| 117 |
+
raise RuntimeError("No response from agent")
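A minimal usage sketch for `generate_multi_hop_qa`; the context string is an illustrative assumption and the call relies on the default model backend:

    from camel.agents.multi_hop_generator_agent import MultiHopGeneratorAgent

    agent = MultiHopGeneratorAgent()
    context = (
        "Alice founded Acme in 2010. Acme acquired Beta Corp in 2018. "
        "Beta Corp builds solar inverters."
    )
    result = agent.generate_multi_hop_qa(context)
    qa = result.value  # a MultiHopQA with question, reasoning steps, answer
    print(qa.model_dump_json(indent=2))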
|
Paper2Poster/camel/agents/programmed_agent_instruction.py
ADDED
|
@@ -0,0 +1,203 @@
| 1 |
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
#
|
| 6 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 7 |
+
#
|
| 8 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 9 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 10 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 11 |
+
# See the License for the specific language governing permissions and
|
| 12 |
+
# limitations under the License.
|
| 13 |
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
| 14 |
+
import abc
|
| 15 |
+
import threading
|
| 16 |
+
from enum import Enum
|
| 17 |
+
from functools import wraps
|
| 18 |
+
from typing import Any, Callable, Generic, Optional, TypeVar
|
| 19 |
+
|
| 20 |
+
from pydantic import BaseModel, ConfigDict
|
| 21 |
+
|
| 22 |
+
from camel.agents import ChatAgent
|
| 23 |
+
from camel.messages import BaseMessage
|
| 24 |
+
|
| 25 |
+
T = TypeVar('T')
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class ProgrammableAgentRequirement(Enum):
|
| 29 |
+
r"""Requirements for programmable agent state.
|
| 30 |
+
|
| 31 |
+
Defines the possible requirements that can be used to repair the state
|
| 32 |
+
of a programmable agent.
|
| 33 |
+
|
| 34 |
+
Attributes:
|
| 35 |
+
LAST_MESSAGE_NOT_USER (str): Requires that the last message in the
|
| 36 |
+
conversation was not from the user.
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
LAST_MESSAGE_NOT_USER = "LAST_MESSAGE_NOT_USER"
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class ProgrammedAgentInstructionResult(BaseModel, Generic[T]):
|
| 43 |
+
r"""Result of a programmable agent instruction execution.
|
| 44 |
+
|
| 45 |
+
Contains the messages exchanged during execution and the computed value.
|
| 46 |
+
The value type is specified by the generic type parameter T.
|
| 47 |
+
|
| 48 |
+
Attributes:
|
| 49 |
+
user_message (BaseMessage): The message sent by the user.
|
| 50 |
+
agent_message (BaseMessage): The message sent by the agent.
|
| 51 |
+
value (T): The computed result value of type T.
|
| 52 |
+
"""
|
| 53 |
+
|
| 54 |
+
user_message: BaseMessage
|
| 55 |
+
agent_message: BaseMessage
|
| 56 |
+
value: T
|
| 57 |
+
|
| 58 |
+
model_config = ConfigDict(arbitrary_types_allowed=True)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class AbstractProgrammableAgent(abc.ABC):
|
| 62 |
+
r"""Abstract class for a programmable agent.
|
| 63 |
+
|
| 64 |
+
A programmable agent is an agent that can be programmed to perform a
|
| 65 |
+
specific function or task. This class defines the interface for a
|
| 66 |
+
programmable agent.
|
| 67 |
+
|
| 68 |
+
These methods should be implemented in order to ensure the agent supports
|
| 69 |
+
the necessary guarantees to enable a programming interface while
|
| 70 |
+
maintaining compatibility in a multi-agent system.
|
| 71 |
+
|
| 72 |
+
A programmable agent is responsible for providing and maintaining a
|
| 73 |
+
programming interface for its functionality.
|
| 74 |
+
"""
|
| 75 |
+
|
| 76 |
+
@abc.abstractmethod
|
| 77 |
+
def run_atomic(
|
| 78 |
+
self, callback: Callable[[], ProgrammedAgentInstructionResult[T]]
|
| 79 |
+
) -> ProgrammedAgentInstructionResult[T]:
|
| 80 |
+
r"""Run an atomic operation on the agent.
|
| 81 |
+
|
| 82 |
+
An atomic operation is an operation that is guaranteed to
|
| 83 |
+
be executed without interruption by any other operation.
|
| 84 |
+
|
| 85 |
+
Args:
|
| 86 |
+
callback (Callable[[], ProgrammedAgentInstructionResult[T]]): The
|
| 87 |
+
operation to execute atomically.
|
| 88 |
+
|
| 89 |
+
Returns:
|
| 90 |
+
ProgrammedAgentInstructionResult[T]: The result of the operation.
|
| 91 |
+
|
| 92 |
+
Raises:
|
| 93 |
+
RuntimeError: If an operation is already in progress.
|
| 94 |
+
"""
|
| 95 |
+
raise NotImplementedError
|
| 96 |
+
|
| 97 |
+
@abc.abstractmethod
|
| 98 |
+
def repair_state(self, requirement: ProgrammableAgentRequirement) -> None:
|
| 99 |
+
r"""Repair the state of the agent.
|
| 100 |
+
|
| 101 |
+
Agents may have other non-atomic interfaces, such as a user interface,
|
| 102 |
+
or chat between other agents. This method should restore the agent to
|
| 103 |
+
a state where it can perform operations according to the specified
|
| 104 |
+
requirement.
|
| 105 |
+
|
| 106 |
+
Args:
|
| 107 |
+
requirement (ProgrammableAgentRequirement): The requirement to
|
| 108 |
+
repair the state for.
|
| 109 |
+
"""
|
| 110 |
+
raise NotImplementedError
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def programmable_capability(
|
| 114 |
+
func: Callable[..., ProgrammedAgentInstructionResult[T]],
|
| 115 |
+
) -> Callable[..., ProgrammedAgentInstructionResult[T]]:
|
| 116 |
+
r"""Decorator for programmable agent capabilities.
|
| 117 |
+
|
| 118 |
+
This decorator ensures that the decorated method is executed atomically
|
| 119 |
+
and maintains the agent's state guarantees.
|
| 120 |
+
|
| 121 |
+
Args:
|
| 122 |
+
func (Callable[..., ProgrammedAgentInstructionResult[T]]): The method
|
| 123 |
+
to decorate.
|
| 124 |
+
|
| 125 |
+
Returns:
|
| 126 |
+
Callable[..., ProgrammedAgentInstructionResult[T]]: The decorated
|
| 127 |
+
method that ensures atomic execution.
|
| 128 |
+
"""
|
| 129 |
+
|
| 130 |
+
@wraps(func)
|
| 131 |
+
def wrapper(
|
| 132 |
+
self, *args: Any, **kwargs: Any
|
| 133 |
+
) -> ProgrammedAgentInstructionResult[T]:
|
| 134 |
+
return self.run_atomic(lambda: func(self, *args, **kwargs))
|
| 135 |
+
|
| 136 |
+
return wrapper
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
class ProgrammableChatAgent(ChatAgent, AbstractProgrammableAgent):
|
| 140 |
+
r"""A chat agent that can be programmed to perform specific tasks.
|
| 141 |
+
|
| 142 |
+
Provides a default implementation of atomic execution using threading locks
|
| 143 |
+
and basic state tracking for message roles. Implementing classes need to
|
| 144 |
+
provide specific repair logic for their use cases.
|
| 145 |
+
|
| 146 |
+
Attributes:
|
| 147 |
+
_operation_lock (threading.Lock): Lock for ensuring atomic operations.
|
| 148 |
+
_last_message_role (Optional[str]): Role of the last message in the
|
| 149 |
+
conversation.
|
| 150 |
+
"""
|
| 151 |
+
|
| 152 |
+
def __init__(self, **kwargs: Any) -> None:
|
| 153 |
+
r"""Initialize the ProgrammableChatAgent.
|
| 154 |
+
|
| 155 |
+
Args:
|
| 156 |
+
**kwargs (Any): Additional keyword arguments to pass to parent
|
| 157 |
+
class.
|
| 158 |
+
"""
|
| 159 |
+
super().__init__(**kwargs)
|
| 160 |
+
self._operation_lock = threading.Lock()
|
| 161 |
+
self._last_message_role: Optional[str] = None
|
| 162 |
+
|
| 163 |
+
def run_atomic(
|
| 164 |
+
self, callback: Callable[[], ProgrammedAgentInstructionResult[T]]
|
| 165 |
+
) -> ProgrammedAgentInstructionResult[T]:
|
| 166 |
+
r"""Run an atomic operation on the agent.
|
| 167 |
+
|
| 168 |
+
Ensures thread-safe execution of the callback function by using a lock.
|
| 169 |
+
|
| 170 |
+
Args:
|
| 171 |
+
callback (Callable[[], ProgrammedAgentInstructionResult[T]]): The
|
| 172 |
+
operation to execute atomically.
|
| 173 |
+
|
| 174 |
+
Returns:
|
| 175 |
+
ProgrammedAgentInstructionResult[T]: The result of the operation.
|
| 176 |
+
|
| 177 |
+
Raises:
|
| 178 |
+
RuntimeError: If an operation is already in progress.
|
| 179 |
+
"""
|
| 180 |
+
if not self._operation_lock.acquire(blocking=False):
|
| 181 |
+
raise RuntimeError("Operation already in progress")
|
| 182 |
+
|
| 183 |
+
try:
|
| 184 |
+
result = callback()
|
| 185 |
+
self._last_message_role = result.agent_message.role_name
|
| 186 |
+
return result
|
| 187 |
+
finally:
|
| 188 |
+
self._operation_lock.release()
|
| 189 |
+
|
| 190 |
+
def repair_state(self, requirement: ProgrammableAgentRequirement) -> None:
|
| 191 |
+
r"""Repair the state of the agent.
|
| 192 |
+
|
| 193 |
+
Implements basic state repair for message role requirements.
|
| 194 |
+
|
| 195 |
+
Args:
|
| 196 |
+
requirement (ProgrammableAgentRequirement): The requirement to
|
| 197 |
+
repair the state for.
|
| 198 |
+
"""
|
| 199 |
+
if requirement == ProgrammableAgentRequirement.LAST_MESSAGE_NOT_USER:
|
| 200 |
+
if self._last_message_role == "user":
|
| 201 |
+
raise NotImplementedError(
|
| 202 |
+
"Must implement repair for LAST_MESSAGE_NOT_USER"
|
| 203 |
+
)
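A minimal sketch of defining a custom capability on `ProgrammableChatAgent`; the `EchoAgent` class and its echo behaviour are illustrative assumptions, not part of this module:

    from camel.agents.programmed_agent_instruction import (
        ProgrammableChatAgent,
        ProgrammedAgentInstructionResult,
        programmable_capability,
    )
    from camel.messages import BaseMessage

    class EchoAgent(ProgrammableChatAgent):
        @programmable_capability
        def echo(self, text: str) -> ProgrammedAgentInstructionResult[str]:
            # The decorator routes this call through run_atomic(), so it
            # executes under the agent's operation lock.
            user_msg = BaseMessage.make_user_message(role_name="User", content=text)
            response = self.step(user_msg)
            return ProgrammedAgentInstructionResult(
                user_message=user_msg,
                agent_message=response.msgs[0],
                value=response.msgs[0].content,
            )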
|
Paper2Poster/camel/agents/role_assignment_agent.py
ADDED
|
@@ -0,0 +1,141 @@
| 1 |
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
#
|
| 6 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 7 |
+
#
|
| 8 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 9 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 10 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 11 |
+
# See the License for the specific language governing permissions and
|
| 12 |
+
# limitations under the License.
|
| 13 |
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
| 14 |
+
import re
|
| 15 |
+
from typing import Dict, Optional, Union
|
| 16 |
+
|
| 17 |
+
from camel.agents.chat_agent import ChatAgent
|
| 18 |
+
from camel.messages import BaseMessage
|
| 19 |
+
from camel.models import BaseModelBackend
|
| 20 |
+
from camel.prompts import TextPrompt
|
| 21 |
+
from camel.types import RoleType
|
| 22 |
+
|
| 23 |
+
# AgentOps decorator setting
|
| 24 |
+
try:
|
| 25 |
+
import os
|
| 26 |
+
|
| 27 |
+
if os.getenv("AGENTOPS_API_KEY") is not None:
|
| 28 |
+
from agentops import track_agent
|
| 29 |
+
else:
|
| 30 |
+
raise ImportError
|
| 31 |
+
except (ImportError, AttributeError):
|
| 32 |
+
from camel.utils import track_agent
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@track_agent(name="RoleAssignmentAgent")
|
| 36 |
+
class RoleAssignmentAgent(ChatAgent):
|
| 37 |
+
r"""An agent that generates role names based on the task prompt.
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
model (BaseModelBackend, optional): The model backend to use for
|
| 41 |
+
generating responses. (default: :obj:`OpenAIModel` with
|
| 42 |
+
`GPT_4O_MINI`)
|
| 43 |
+
|
| 44 |
+
Attributes:
|
| 45 |
+
role_assignment_prompt (TextPrompt): A prompt for the agent to generate
|
| 46 |
+
role names.
|
| 47 |
+
"""
|
| 48 |
+
|
| 49 |
+
def __init__(
|
| 50 |
+
self,
|
| 51 |
+
model: Optional[BaseModelBackend] = None,
|
| 52 |
+
) -> None:
|
| 53 |
+
system_message = BaseMessage(
|
| 54 |
+
role_name="Role Assigner",
|
| 55 |
+
role_type=RoleType.ASSISTANT,
|
| 56 |
+
meta_dict=None,
|
| 57 |
+
content="You assign roles based on tasks.",
|
| 58 |
+
)
|
| 59 |
+
super().__init__(system_message, model=model)
|
| 60 |
+
|
| 61 |
+
def run(
|
| 62 |
+
self,
|
| 63 |
+
task_prompt: Union[str, TextPrompt],
|
| 64 |
+
num_roles: int = 2,
|
| 65 |
+
) -> Dict[str, str]:
|
| 66 |
+
r"""Generate role names based on the input task prompt.
|
| 67 |
+
|
| 68 |
+
Args:
|
| 69 |
+
task_prompt (Union[str, TextPrompt]): The prompt
|
| 70 |
+
for the task based on which the roles are to be generated.
|
| 71 |
+
num_roles (int, optional): The number of roles to generate.
|
| 72 |
+
(default: :obj:`2`)
|
| 73 |
+
|
| 74 |
+
Returns:
|
| 75 |
+
Dict[str, str]: A dictionary mapping role names to their
|
| 76 |
+
descriptions.
|
| 77 |
+
"""
|
| 78 |
+
self.reset()
|
| 79 |
+
|
| 80 |
+
expert_prompt = "===== ANSWER PROMPT =====\n" + "\n".join(
|
| 81 |
+
f"Domain expert {i + 1}: <BLANK>\n"
|
| 82 |
+
f"Associated competencies, characteristics, duties "
|
| 83 |
+
f"and workflows: <BLANK>. End."
|
| 84 |
+
for i in range(num_roles or 0)
|
| 85 |
+
)
|
| 86 |
+
role_assignment_generation_prompt = TextPrompt(
|
| 87 |
+
"You are a role assignment agent, and you're in charge of "
|
| 88 |
+
+ "recruiting {num_roles} experts for the following task."
|
| 89 |
+
+ "\n==== TASK =====\n {task}\n\n"
|
| 90 |
+
+ "Identify the domain experts you'd recruit and detail their "
|
| 91 |
+
+ "associated competencies, characteristics, duties and workflows "
|
| 92 |
+
+ "to complete the task.\n "
|
| 93 |
+
+ "Your answer MUST adhere to the format of ANSWER PROMPT, and "
|
| 94 |
+
+ "ONLY answer the BLANKs.\n"
|
| 95 |
+
+ expert_prompt
|
| 96 |
+
)
|
| 97 |
+
role_assignment_generation = role_assignment_generation_prompt.format(
|
| 98 |
+
num_roles=num_roles, task=task_prompt
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
role_assignment_generation_msg = BaseMessage.make_user_message(
|
| 102 |
+
role_name="Role Assigner", content=role_assignment_generation
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
response = self.step(input_message=role_assignment_generation_msg)
|
| 106 |
+
|
| 107 |
+
msg = response.msg # type: BaseMessage
|
| 108 |
+
terminated = response.terminated
|
| 109 |
+
|
| 110 |
+
# Distribute the output completions into role names and descriptions
|
| 111 |
+
role_names = [
|
| 112 |
+
desc.replace("<|", "").replace("|>", "")
|
| 113 |
+
for desc in re.findall(
|
| 114 |
+
r"Domain expert \d: (.+?)\nAssociated competencies,",
|
| 115 |
+
msg.content,
|
| 116 |
+
re.DOTALL,
|
| 117 |
+
)
|
| 118 |
+
]
|
| 119 |
+
role_descriptions = [
|
| 120 |
+
desc.replace("<|", "").replace("|>", "")
|
| 121 |
+
for desc in re.findall(
|
| 122 |
+
r"Associated competencies, characteristics, "
|
| 123 |
+
r"duties and workflows: (.+?) End.",
|
| 124 |
+
msg.content,
|
| 125 |
+
re.DOTALL,
|
| 126 |
+
)
|
| 127 |
+
]
|
| 128 |
+
|
| 129 |
+
if len(role_names) != num_roles or len(role_descriptions) != num_roles:
|
| 130 |
+
raise RuntimeError(
|
| 131 |
+
"Got None or insufficient information of roles."
|
| 132 |
+
)
|
| 133 |
+
if terminated:
|
| 134 |
+
raise RuntimeError("Role assignment failed.")
|
| 135 |
+
|
| 136 |
+
role_descriptions_dict = {
|
| 137 |
+
role_name: description
|
| 138 |
+
for role_name, description in zip(role_names, role_descriptions)
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
return role_descriptions_dict
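A minimal usage sketch for `RoleAssignmentAgent.run`; the task prompt and role count are illustrative assumptions:

    from camel.agents.role_assignment_agent import RoleAssignmentAgent

    agent = RoleAssignmentAgent()
    roles = agent.run(
        task_prompt="Turn a research paper into a conference poster.",
        num_roles=3,
    )
    for name, description in roles.items():
        print(f"{name}: {description}")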
|
Paper2Poster/camel/agents/search_agent.py
ADDED
|
@@ -0,0 +1,133 @@
| 1 |
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
#
|
| 6 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 7 |
+
#
|
| 8 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 9 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 10 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 11 |
+
# See the License for the specific language governing permissions and
|
| 12 |
+
# limitations under the License.
|
| 13 |
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
| 14 |
+
from typing import Optional
|
| 15 |
+
|
| 16 |
+
from camel.agents.chat_agent import ChatAgent
|
| 17 |
+
from camel.messages import BaseMessage
|
| 18 |
+
from camel.models import BaseModelBackend
|
| 19 |
+
from camel.prompts import TextPrompt
|
| 20 |
+
from camel.types import RoleType
|
| 21 |
+
from camel.utils import create_chunks
|
| 22 |
+
|
| 23 |
+
# AgentOps decorator setting
|
| 24 |
+
try:
|
| 25 |
+
import os
|
| 26 |
+
|
| 27 |
+
if os.getenv("AGENTOPS_API_KEY") is not None:
|
| 28 |
+
from agentops import track_agent
|
| 29 |
+
else:
|
| 30 |
+
raise ImportError
|
| 31 |
+
except (ImportError, AttributeError):
|
| 32 |
+
from camel.utils import track_agent
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@track_agent(name="SearchAgent")
|
| 36 |
+
class SearchAgent(ChatAgent):
|
| 37 |
+
r"""An agent that summarizes text based on a query and evaluates the
|
| 38 |
+
relevance of an answer.
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
model (BaseModelBackend, optional): The model backend to use for
|
| 42 |
+
generating responses. (default: :obj:`OpenAIModel` with
|
| 43 |
+
`GPT_4O_MINI`)
|
| 44 |
+
"""
|
| 45 |
+
|
| 46 |
+
def __init__(
|
| 47 |
+
self,
|
| 48 |
+
model: Optional[BaseModelBackend] = None,
|
| 49 |
+
) -> None:
|
| 50 |
+
system_message = BaseMessage(
|
| 51 |
+
role_name="Assistant",
|
| 52 |
+
role_type=RoleType.ASSISTANT,
|
| 53 |
+
meta_dict=None,
|
| 54 |
+
content="You are a helpful assistant.",
|
| 55 |
+
)
|
| 56 |
+
super().__init__(system_message, model=model)
|
| 57 |
+
|
| 58 |
+
def summarize_text(self, text: str, query: str) -> str:
|
| 59 |
+
r"""Summarize the information from the text, based on the query.
|
| 60 |
+
|
| 61 |
+
Args:
|
| 62 |
+
text (str): Text to summarize.
|
| 63 |
+
query (str): What information you want.
|
| 64 |
+
|
| 65 |
+
Returns:
|
| 66 |
+
str: Strings with information.
|
| 67 |
+
"""
|
| 68 |
+
self.reset()
|
| 69 |
+
|
| 70 |
+
summary_prompt = TextPrompt(
|
| 71 |
+
'''Gather information from this text that is relevant to the
|
| 72 |
+
question, but do not directly answer the question.\nquestion:
|
| 73 |
+
{query}\ntext '''
|
| 74 |
+
)
|
| 75 |
+
summary_prompt = summary_prompt.format(query=query)
|
| 76 |
+
# Max length of each chunk
|
| 77 |
+
max_len = 3000
|
| 78 |
+
results = ""
|
| 79 |
+
chunks = create_chunks(text, max_len)
|
| 80 |
+
# Summarize
|
| 81 |
+
for i, chunk in enumerate(chunks, start=1):
|
| 82 |
+
prompt = summary_prompt + str(i) + ": " + chunk
|
| 83 |
+
user_msg = BaseMessage.make_user_message(
|
| 84 |
+
role_name="User",
|
| 85 |
+
content=prompt,
|
| 86 |
+
)
|
| 87 |
+
result = self.step(user_msg).msg.content
|
| 88 |
+
results += result + "\n"
|
| 89 |
+
|
| 90 |
+
# Final summarization
|
| 91 |
+
final_prompt = TextPrompt(
|
| 92 |
+
'''Here are some summarized texts which were split from one text. Use
|
| 93 |
+
the information to answer the question. If you can't find the answer,
|
| 94 |
+
you must answer "I can not find the answer to the query" and
|
| 95 |
+
explain why.\n Query:\n{query}.\n\nText:\n'''
|
| 96 |
+
)
|
| 97 |
+
final_prompt = final_prompt.format(query=query)
|
| 98 |
+
prompt = final_prompt + results
|
| 99 |
+
|
| 100 |
+
user_msg = BaseMessage.make_user_message(
|
| 101 |
+
role_name="User",
|
| 102 |
+
content=prompt,
|
| 103 |
+
)
|
| 104 |
+
response = self.step(user_msg).msg.content
|
| 105 |
+
|
| 106 |
+
return response
|
| 107 |
+
|
| 108 |
+
def continue_search(self, query: str, answer: str) -> bool:
|
| 109 |
+
r"""Ask whether to continue search or not based on the provided answer.
|
| 110 |
+
|
| 111 |
+
Args:
|
| 112 |
+
query (str): The question.
|
| 113 |
+
answer (str): The answer to the question.
|
| 114 |
+
|
| 115 |
+
Returns:
|
| 116 |
+
bool: `True` if the user wants to continue the search, `False`
|
| 117 |
+
otherwise.
|
| 118 |
+
"""
|
| 119 |
+
prompt = TextPrompt(
|
| 120 |
+
"Do you think the ANSWER can answer the QUERY? "
|
| 121 |
+
"Use only 'yes' or 'no' to answer.\n"
|
| 122 |
+
"===== QUERY =====\n{query}\n\n"
|
| 123 |
+
"===== ANSWER =====\n{answer}"
|
| 124 |
+
)
|
| 125 |
+
prompt = prompt.format(query=query, answer=answer)
|
| 126 |
+
user_msg = BaseMessage.make_user_message(
|
| 127 |
+
role_name="User",
|
| 128 |
+
content=prompt,
|
| 129 |
+
)
|
| 130 |
+
response = self.step(user_msg).msg.content
|
| 131 |
+
if "yes" in str(response).lower():
|
| 132 |
+
return False
|
| 133 |
+
return True
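A minimal usage sketch for `SearchAgent`; the document text and query are illustrative assumptions:

    from camel.agents.search_agent import SearchAgent

    agent = SearchAgent()
    document = "..."  # long text fetched from a search result
    query = "Who received the 2023 award?"

    answer = agent.summarize_text(text=document, query=query)
    if agent.continue_search(query=query, answer=answer):
        print("Answer not found yet; fetch more sources.")
    else:
        print(answer)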
|
Paper2Poster/camel/agents/task_agent.py
ADDED
|
@@ -0,0 +1,410 @@
| 1 |
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
#
|
| 6 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 7 |
+
#
|
| 8 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 9 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 10 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 11 |
+
# See the License for the specific language governing permissions and
|
| 12 |
+
# limitations under the License.
|
| 13 |
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
| 14 |
+
from typing import Any, Dict, List, Optional, Union
|
| 15 |
+
|
| 16 |
+
from camel.agents.chat_agent import ChatAgent
|
| 17 |
+
from camel.messages import BaseMessage
|
| 18 |
+
from camel.models import BaseModelBackend
|
| 19 |
+
from camel.prompts import PromptTemplateGenerator, TextPrompt
|
| 20 |
+
from camel.types import RoleType, TaskType
|
| 21 |
+
from camel.utils import get_task_list
|
| 22 |
+
|
| 23 |
+
# AgentOps decorator setting
|
| 24 |
+
try:
|
| 25 |
+
import os
|
| 26 |
+
|
| 27 |
+
if os.getenv("AGENTOPS_API_KEY") is not None:
|
| 28 |
+
from agentops import track_agent
|
| 29 |
+
else:
|
| 30 |
+
raise ImportError
|
| 31 |
+
except (ImportError, AttributeError):
|
| 32 |
+
from camel.utils import track_agent
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@track_agent(name="TaskSpecifyAgent")
|
| 36 |
+
class TaskSpecifyAgent(ChatAgent):
|
| 37 |
+
r"""An agent that specifies a given task prompt by prompting the user to
|
| 38 |
+
provide more details.
|
| 39 |
+
|
| 40 |
+
Attributes:
|
| 41 |
+
DEFAULT_WORD_LIMIT (int): The default word limit for the task prompt.
|
| 42 |
+
task_specify_prompt (TextPrompt): The prompt for specifying the task.
|
| 43 |
+
|
| 44 |
+
Args:
|
| 45 |
+
model (BaseModelBackend, optional): The model backend to use for
|
| 46 |
+
generating responses. (default: :obj:`OpenAIModel` with
|
| 47 |
+
`GPT_4O_MINI`)
|
| 48 |
+
task_type (TaskType, optional): The type of task for which to generate
|
| 49 |
+
a prompt. (default: :obj:`TaskType.AI_SOCIETY`)
|
| 50 |
+
task_specify_prompt (Union[str, TextPrompt], optional): The prompt for
|
| 51 |
+
specifying the task. (default: :obj:`None`)
|
| 52 |
+
word_limit (int, optional): The word limit for the task prompt.
|
| 53 |
+
(default: :obj:`50`)
|
| 54 |
+
output_language (str, optional): The language to be output by the
|
| 55 |
+
agent. (default: :obj:`None`)
|
| 56 |
+
"""
|
| 57 |
+
|
| 58 |
+
DEFAULT_WORD_LIMIT = 50
|
| 59 |
+
|
| 60 |
+
def __init__(
|
| 61 |
+
self,
|
| 62 |
+
model: Optional[BaseModelBackend] = None,
|
| 63 |
+
task_type: TaskType = TaskType.AI_SOCIETY,
|
| 64 |
+
task_specify_prompt: Optional[Union[str, TextPrompt]] = None,
|
| 65 |
+
word_limit: int = DEFAULT_WORD_LIMIT,
|
| 66 |
+
output_language: Optional[str] = None,
|
| 67 |
+
) -> None:
|
| 68 |
+
self.task_specify_prompt: Union[str, TextPrompt]
|
| 69 |
+
if task_specify_prompt is None:
|
| 70 |
+
task_specify_prompt_template = (
|
| 71 |
+
PromptTemplateGenerator().get_task_specify_prompt(task_type)
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
self.task_specify_prompt = task_specify_prompt_template.format(
|
| 75 |
+
word_limit=word_limit
|
| 76 |
+
)
|
| 77 |
+
else:
|
| 78 |
+
self.task_specify_prompt = TextPrompt(task_specify_prompt)
|
| 79 |
+
|
| 80 |
+
system_message = BaseMessage(
|
| 81 |
+
role_name="Task Specifier",
|
| 82 |
+
role_type=RoleType.ASSISTANT,
|
| 83 |
+
meta_dict=None,
|
| 84 |
+
content="You can make a task more specific.",
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
super().__init__(
|
| 88 |
+
system_message,
|
| 89 |
+
model=model,
|
| 90 |
+
output_language=output_language,
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
def run(
|
| 94 |
+
self,
|
| 95 |
+
task_prompt: Union[str, TextPrompt],
|
| 96 |
+
meta_dict: Optional[Dict[str, Any]] = None,
|
| 97 |
+
) -> TextPrompt:
|
| 98 |
+
r"""Specify the given task prompt by providing more details.
|
| 99 |
+
|
| 100 |
+
Args:
|
| 101 |
+
task_prompt (Union[str, TextPrompt]): The original task
|
| 102 |
+
prompt.
|
| 103 |
+
meta_dict (Dict[str, Any], optional): A dictionary containing
|
| 104 |
+
additional information to include in the prompt.
|
| 105 |
+
(default: :obj:`None`)
|
| 106 |
+
|
| 107 |
+
Returns:
|
| 108 |
+
TextPrompt: The specified task prompt.
|
| 109 |
+
"""
|
| 110 |
+
self.reset()
|
| 111 |
+
task_specify_prompt = self.task_specify_prompt.format(task=task_prompt)
|
| 112 |
+
|
| 113 |
+
if meta_dict is not None:
|
| 114 |
+
task_specify_prompt = task_specify_prompt.format(**meta_dict)
|
| 115 |
+
task_msg = BaseMessage.make_user_message(
|
| 116 |
+
role_name="Task Specifier", content=task_specify_prompt
|
| 117 |
+
)
|
| 118 |
+
specifier_response = self.step(task_msg)
|
| 119 |
+
|
| 120 |
+
if specifier_response.terminated:
|
| 121 |
+
raise RuntimeError("Task specification failed.")
|
| 122 |
+
if len(specifier_response.msgs) == 0:
|
| 123 |
+
raise RuntimeError("Got no specification message.")
|
| 124 |
+
|
| 125 |
+
specified_task_msg = specifier_response.msgs[0]
|
| 126 |
+
|
| 127 |
+
return TextPrompt(specified_task_msg.content)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
@track_agent(name="TaskPlannerAgent")
|
| 131 |
+
class TaskPlannerAgent(ChatAgent):
|
| 132 |
+
r"""An agent that helps divide a task into subtasks based on the input
|
| 133 |
+
task prompt.
|
| 134 |
+
|
| 135 |
+
Attributes:
|
| 136 |
+
task_planner_prompt (TextPrompt): A prompt for the agent to divide
|
| 137 |
+
the task into subtasks.
|
| 138 |
+
|
| 139 |
+
Args:
|
| 140 |
+
model (BaseModelBackend, optional): The model backend to use for
|
| 141 |
+
generating responses. (default: :obj:`OpenAIModel` with
|
| 142 |
+
`GPT_4O_MINI`)
|
| 143 |
+
output_language (str, optional): The language to be output by the
|
| 144 |
+
agent. (default: :obj:`None`)
|
| 145 |
+
"""
|
| 146 |
+
|
| 147 |
+
def __init__(
|
| 148 |
+
self,
|
| 149 |
+
model: Optional[BaseModelBackend] = None,
|
| 150 |
+
output_language: Optional[str] = None,
|
| 151 |
+
) -> None:
|
| 152 |
+
self.task_planner_prompt = TextPrompt(
|
| 153 |
+
"Divide this task into subtasks: {task}. Be concise."
|
| 154 |
+
)
|
| 155 |
+
system_message = BaseMessage(
|
| 156 |
+
role_name="Task Planner",
|
| 157 |
+
role_type=RoleType.ASSISTANT,
|
| 158 |
+
meta_dict=None,
|
| 159 |
+
content="You are a helpful task planner.",
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
super().__init__(
|
| 163 |
+
system_message,
|
| 164 |
+
model=model,
|
| 165 |
+
output_language=output_language,
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
def run(
|
| 169 |
+
self,
|
| 170 |
+
task_prompt: Union[str, TextPrompt],
|
| 171 |
+
) -> TextPrompt:
|
| 172 |
+
r"""Generate subtasks based on the input task prompt.
|
| 173 |
+
|
| 174 |
+
Args:
|
| 175 |
+
task_prompt (Union[str, TextPrompt]): The prompt for the task to
|
| 176 |
+
be divided into subtasks.
|
| 177 |
+
|
| 178 |
+
Returns:
|
| 179 |
+
TextPrompt: A prompt for the subtasks generated by the agent.
|
| 180 |
+
"""
|
| 181 |
+
# TODO: Maybe include roles information.
|
| 182 |
+
self.reset()
|
| 183 |
+
task_planner_prompt = self.task_planner_prompt.format(task=task_prompt)
|
| 184 |
+
|
| 185 |
+
task_msg = BaseMessage.make_user_message(
|
| 186 |
+
role_name="Task Planner", content=task_planner_prompt
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
task_response = self.step(task_msg)
|
| 190 |
+
|
| 191 |
+
if task_response.terminated:
|
| 192 |
+
raise RuntimeError("Task planning failed.")
|
| 193 |
+
if len(task_response.msgs) == 0:
|
| 194 |
+
raise RuntimeError("Got no task planning message.")
|
| 195 |
+
|
| 196 |
+
sub_tasks_msg = task_response.msgs[0]
|
| 197 |
+
return TextPrompt(sub_tasks_msg.content)
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
@track_agent(name="TaskCreationAgent")
|
| 201 |
+
class TaskCreationAgent(ChatAgent):
|
| 202 |
+
r"""An agent that helps create new tasks based on the objective
|
| 203 |
+
and last completed task. Compared to :obj:`TaskPlannerAgent`,
|
| 204 |
+
it's still a task planner, but it has more context information
|
| 205 |
+
like the last task and the incomplete task list. Modified from
|
| 206 |
+
`BabyAGI <https://github.com/yoheinakajima/babyagi>`_.
|
| 207 |
+
|
| 208 |
+
Attributes:
|
| 209 |
+
task_creation_prompt (TextPrompt): A prompt for the agent to
|
| 210 |
+
create new tasks.
|
| 211 |
+
|
| 212 |
+
Args:
|
| 213 |
+
role_name (str): The role name of the Agent to create the task.
|
| 214 |
+
objective (Union[str, TextPrompt]): The objective of the Agent to
|
| 215 |
+
perform the task.
|
| 216 |
+
model (BaseModelBackend, optional): The LLM backend to use for
|
| 217 |
+
generating responses. (default: :obj:`OpenAIModel` with
|
| 218 |
+
`GPT_4O_MINI`)
|
| 219 |
+
output_language (str, optional): The language to be output by the
|
| 220 |
+
agent. (default: :obj:`None`)
|
| 221 |
+
message_window_size (int, optional): The maximum number of previous
|
| 222 |
+
messages to include in the context window. If `None`, no windowing
|
| 223 |
+
is performed. (default: :obj:`None`)
|
| 224 |
+
max_task_num (int, optional): The maximum number of planned
|
| 225 |
+
tasks in one round. (default: :obj:`3`)
|
| 226 |
+
"""
|
| 227 |
+
|
| 228 |
+
def __init__(
|
| 229 |
+
self,
|
| 230 |
+
role_name: str,
|
| 231 |
+
objective: Union[str, TextPrompt],
|
| 232 |
+
model: Optional[BaseModelBackend] = None,
|
| 233 |
+
output_language: Optional[str] = None,
|
| 234 |
+
message_window_size: Optional[int] = None,
|
| 235 |
+
max_task_num: Optional[int] = 3,
|
| 236 |
+
) -> None:
|
| 237 |
+
task_creation_prompt = TextPrompt(
|
| 238 |
+
"""Create new a task with the following objective: {objective}.
|
| 239 |
+
Never forget you are a Task Creator of {role_name}.
|
| 240 |
+
You must instruct me based on my expertise and your needs to solve the task.
|
| 241 |
+
You should consider past solved tasks and in-progress tasks: {task_list}.
|
| 242 |
+
The newly created tasks must not overlap with these past tasks.
|
| 243 |
+
The result must be a numbered list in the format:
|
| 244 |
+
|
| 245 |
+
#. First Task
|
| 246 |
+
#. Second Task
|
| 247 |
+
#. Third Task
|
| 248 |
+
|
| 249 |
+
You can only give me up to {max_task_num} tasks at a time. \
|
| 250 |
+
Each task should be concise, concrete and doable for a {role_name}.
|
| 251 |
+
You should make a task plan and not ask me questions.
|
| 252 |
+
If you think no new tasks are needed right now, write "No tasks to add."
|
| 253 |
+
Now start to give me new tasks one by one. No more than three tasks.
|
| 254 |
+
Be concrete.
|
| 255 |
+
"""
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
self.task_creation_prompt = task_creation_prompt.format(
|
| 259 |
+
objective=objective, role_name=role_name, max_task_num=max_task_num
|
| 260 |
+
)
|
| 261 |
+
self.objective = objective
|
| 262 |
+
|
| 263 |
+
system_message = BaseMessage(
|
| 264 |
+
role_name="Task Creator",
|
| 265 |
+
role_type=RoleType.ASSISTANT,
|
| 266 |
+
meta_dict=None,
|
| 267 |
+
content="You are a helpful task creator.",
|
| 268 |
+
)
|
| 269 |
+
|
| 270 |
+
super().__init__(
|
| 271 |
+
system_message,
|
| 272 |
+
model=model,
|
| 273 |
+
output_language=output_language,
|
| 274 |
+
message_window_size=message_window_size,
|
| 275 |
+
)
|
| 276 |
+
|
| 277 |
+
def run(
|
| 278 |
+
self,
|
| 279 |
+
task_list: List[str],
|
| 280 |
+
) -> List[str]:
|
| 281 |
+
r"""Generate subtasks based on the previous task results and
|
| 282 |
+
incomplete task list.
|
| 283 |
+
|
| 284 |
+
Args:
|
| 285 |
+
task_list (List[str]): The completed or in-progress
|
| 286 |
+
tasks which should not overlap with newly created tasks.
|
| 287 |
+
|
| 288 |
+
Returns:
|
| 289 |
+
List[str]: The new task list generated by the Agent.
|
| 290 |
+
"""
|
| 291 |
+
|
| 292 |
+
if len(task_list) > 0:
|
| 293 |
+
task_creation_prompt = self.task_creation_prompt.format(
|
| 294 |
+
task_list=task_list
|
| 295 |
+
)
|
| 296 |
+
else:
|
| 297 |
+
task_creation_prompt = self.task_creation_prompt.format(
|
| 298 |
+
task_list=""
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
task_msg = BaseMessage.make_user_message(
|
| 302 |
+
role_name="Task Creator", content=task_creation_prompt
|
| 303 |
+
)
|
| 304 |
+
task_response = self.step(task_msg)
|
| 305 |
+
|
| 306 |
+
if task_response.terminated:
|
| 307 |
+
raise RuntimeError("Task creation failed.")
|
| 308 |
+
if len(task_response.msgs) == 0:
|
| 309 |
+
raise RuntimeError("Got no task creation message.")
|
| 310 |
+
|
| 311 |
+
sub_tasks_msg = task_response.msgs[0]
|
| 312 |
+
return get_task_list(sub_tasks_msg.content)
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
@track_agent(name="TaskPrioritizationAgent")
|
| 316 |
+
class TaskPrioritizationAgent(ChatAgent):
|
| 317 |
+
r"""An agent that helps re-prioritize the task list and
|
| 318 |
+
returns a numbered prioritized list. Modified from
|
| 319 |
+
`BabyAGI <https://github.com/yoheinakajima/babyagi>`_.
|
| 320 |
+
|
| 321 |
+
Attributes:
|
| 322 |
+
task_prioritization_prompt (TextPrompt): A prompt for the agent to
|
| 323 |
+
prioritize tasks.
|
| 324 |
+
|
| 325 |
+
Args:
|
| 326 |
+
objective (Union[str, TextPrompt]): The objective of the Agent to
|
| 327 |
+
perform the task.
|
| 328 |
+
model (BaseModelBackend, optional): The LLM backend to use for
|
| 329 |
+
generating responses. (default: :obj:`OpenAIModel` with
|
| 330 |
+
`GPT_4O_MINI`)
|
| 331 |
+
output_language (str, optional): The language to be output by the
|
| 332 |
+
agent. (default: :obj:`None`)
|
| 333 |
+
message_window_size (int, optional): The maximum number of previous
|
| 334 |
+
messages to include in the context window. If `None`, no windowing
|
| 335 |
+
is performed. (default: :obj:`None`)
|
| 336 |
+
"""
|
| 337 |
+
|
| 338 |
+
def __init__(
|
| 339 |
+
self,
|
| 340 |
+
objective: Union[str, TextPrompt],
|
| 341 |
+
model: Optional[BaseModelBackend] = None,
|
| 342 |
+
output_language: Optional[str] = None,
|
| 343 |
+
message_window_size: Optional[int] = None,
|
| 344 |
+
) -> None:
|
| 345 |
+
task_prioritization_prompt = TextPrompt(
|
| 346 |
+
"""Prioritize the following tasks : {task_list}.
|
| 347 |
+
Consider the ultimate objective of you: {objective}.
|
| 348 |
+
Tasks should be sorted from highest to lowest priority, where higher-priority \
|
| 349 |
+
tasks are those that act as pre-requisites or are more essential for meeting \
|
| 350 |
+
the objective. Return one task per line in your response.
|
| 351 |
+
Do not remove or modify any tasks.
|
| 352 |
+
The result must be a numbered list in the format:
|
| 353 |
+
|
| 354 |
+
#. First task
|
| 355 |
+
#. Second task
|
| 356 |
+
|
| 357 |
+
The entries must be consecutively numbered, starting with 1.
|
| 358 |
+
The number of each entry must be followed by a period.
|
| 359 |
+
Do not include any headers before your ranked list or follow your list \
|
| 360 |
+
with any other output."""
|
| 361 |
+
)
|
| 362 |
+
|
| 363 |
+
self.task_prioritization_prompt = task_prioritization_prompt.format(
|
| 364 |
+
objective=objective
|
| 365 |
+
)
|
| 366 |
+
self.objective = objective
|
| 367 |
+
|
| 368 |
+
system_message = BaseMessage(
|
| 369 |
+
role_name="Task Prioritizer",
|
| 370 |
+
role_type=RoleType.ASSISTANT,
|
| 371 |
+
meta_dict=None,
|
| 372 |
+
content="You are a helpful task prioritizer.",
|
| 373 |
+
)
|
| 374 |
+
|
| 375 |
+
super().__init__(
|
| 376 |
+
system_message,
|
| 377 |
+
model=model,
|
| 378 |
+
output_language=output_language,
|
| 379 |
+
message_window_size=message_window_size,
|
| 380 |
+
)
|
| 381 |
+
|
| 382 |
+
def run(
|
| 383 |
+
self,
|
| 384 |
+
task_list: List[str],
|
| 385 |
+
) -> List[str]:
|
| 386 |
+
r"""Prioritize the task list given the agent objective.
|
| 387 |
+
|
| 388 |
+
Args:
|
| 389 |
+
task_list (List[str]): The unprioritized tasks of the agent.
|
| 390 |
+
|
| 391 |
+
Returns:
|
| 392 |
+
List[str]: The new prioritized task list generated by the Agent.
|
| 393 |
+
"""
|
| 394 |
+
task_prioritization_prompt = self.task_prioritization_prompt.format(
|
| 395 |
+
task_list=task_list
|
| 396 |
+
)
|
| 397 |
+
|
| 398 |
+
task_msg = BaseMessage.make_user_message(
|
| 399 |
+
role_name="Task Prioritizer", content=task_prioritization_prompt
|
| 400 |
+
)
|
| 401 |
+
|
| 402 |
+
task_response = self.step(task_msg)
|
| 403 |
+
|
| 404 |
+
if task_response.terminated:
|
| 405 |
+
raise RuntimeError("Task prioritization failed.")
|
| 406 |
+
if len(task_response.msgs) == 0:
|
| 407 |
+
raise RuntimeError("Got no task prioritization message.")
|
| 408 |
+
|
| 409 |
+
sub_tasks_msg = task_response.msgs[0]
|
| 410 |
+
return get_task_list(sub_tasks_msg.content)
|
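A minimal usage sketch (not part of the commit) of chaining the four task agents defined above. It assumes a default OpenAI-compatible model backend is configured (e.g. via `OPENAI_API_KEY`); the objective text and role name are illustrative only.

```python
# Sketch only: assumes a configured OpenAI-compatible backend; the objective
# text and role name below are illustrative, not taken from the repository.
from camel.agents.task_agent import (
    TaskCreationAgent,
    TaskPlannerAgent,
    TaskPrioritizationAgent,
    TaskSpecifyAgent,
)
from camel.utils import get_task_list

objective = "Write a literature survey on tool-augmented LLMs"

# 1. Make the vague task concrete (word_limit bounds the rewritten prompt).
specified = TaskSpecifyAgent(word_limit=50).run(objective)

# 2. Split the specified task into subtasks.
subtask_prompt = TaskPlannerAgent().run(specified)
# get_task_list parses numbered "1. ..." lines; the planner output may vary,
# so fall back to treating the whole response as a single task.
subtasks = get_task_list(str(subtask_prompt)) or [str(subtask_prompt)]

# 3. Create follow-up tasks that should not overlap with existing ones.
creator = TaskCreationAgent(role_name="Researcher", objective=specified)
new_tasks = creator.run(task_list=subtasks)

# 4. Re-order the open tasks by priority against the objective.
ordered = TaskPrioritizationAgent(objective=specified).run(task_list=new_tasks)
print(ordered)
```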
Paper2Poster/camel/agents/tool_agents/__init__.py
ADDED
|
@@ -0,0 +1,20 @@
| 1 |
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
#
|
| 6 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 7 |
+
#
|
| 8 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 9 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 10 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 11 |
+
# See the License for the specific language governing permissions and
|
| 12 |
+
# limitations under the License.
|
| 13 |
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
| 14 |
+
from .base import BaseToolAgent
|
| 15 |
+
from .hugging_face_tool_agent import HuggingFaceToolAgent
|
| 16 |
+
|
| 17 |
+
__all__ = [
|
| 18 |
+
'BaseToolAgent',
|
| 19 |
+
'HuggingFaceToolAgent',
|
| 20 |
+
]
|
Paper2Poster/camel/agents/tool_agents/base.py
ADDED
|
@@ -0,0 +1,39 @@
| 1 |
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
#
|
| 6 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 7 |
+
#
|
| 8 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 9 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 10 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 11 |
+
# See the License for the specific language governing permissions and
|
| 12 |
+
# limitations under the License.
|
| 13 |
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
| 14 |
+
from camel.agents import BaseAgent
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class BaseToolAgent(BaseAgent):
|
| 18 |
+
r"""Creates a :obj:`BaseToolAgent` object with the specified name and
|
| 19 |
+
description.
|
| 20 |
+
|
| 21 |
+
Args:
|
| 22 |
+
name (str): The name of the tool agent.
|
| 23 |
+
description (str): The description of the tool agent.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
def __init__(self, name: str, description: str) -> None:
|
| 27 |
+
self.name = name
|
| 28 |
+
self.description = description
|
| 29 |
+
|
| 30 |
+
def reset(self) -> None:
|
| 31 |
+
r"""Resets the agent to its initial state."""
|
| 32 |
+
pass
|
| 33 |
+
|
| 34 |
+
def step(self) -> None:
|
| 35 |
+
r"""Performs a single step of the agent."""
|
| 36 |
+
pass
|
| 37 |
+
|
| 38 |
+
def __str__(self) -> str:
|
| 39 |
+
return f"{self.name}: {self.description}"
|
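`BaseToolAgent` above is a thin interface: a name, a description, and no-op `reset()`/`step()` hooks. A hypothetical subclass, shown only to illustrate that contract (not part of the commit), could look like this:

```python
# Hypothetical subclass, for illustration only.
from camel.agents.tool_agents import BaseToolAgent


class EchoToolAgent(BaseToolAgent):
    r"""A toy tool agent that returns its input unchanged."""

    def __init__(self) -> None:
        super().__init__(
            name="echo_agent",
            description="Returns whatever text it is given, unchanged.",
        )

    # The base class leaves step() as a no-op; concrete agents accept their
    # own arguments, as HuggingFaceToolAgent in the next file also does.
    def step(self, text: str) -> str:
        return text


agent = EchoToolAgent()
print(agent)               # echo_agent: Returns whatever text it is given, unchanged.
print(agent.step("ping"))  # ping
```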
Paper2Poster/camel/agents/tool_agents/hugging_face_tool_agent.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
#
|
| 6 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 7 |
+
#
|
| 8 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 9 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 10 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 11 |
+
# See the License for the specific language governing permissions and
|
| 12 |
+
# limitations under the License.
|
| 13 |
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
| 14 |
+
from typing import Any, Optional
|
| 15 |
+
|
| 16 |
+
from camel.agents.tool_agents.base import BaseToolAgent
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# flake8: noqa :E501
|
| 20 |
+
class HuggingFaceToolAgent(BaseToolAgent):
|
| 21 |
+
r"""Tool agent for calling HuggingFace models. This agent is a wrapper
|
| 22 |
+
around agents from the `transformers` library. For more information
|
| 23 |
+
about the available models, please see the `transformers` documentation
|
| 24 |
+
at https://huggingface.co/docs/transformers/transformers_agents.
|
| 25 |
+
|
| 26 |
+
Args:
|
| 27 |
+
name (str): The name of the agent.
|
| 28 |
+
*args (Any): Additional positional arguments to pass to the underlying
|
| 29 |
+
Agent class.
|
| 30 |
+
remote (bool, optional): Flag indicating whether to run the agent
|
| 31 |
+
remotely. (default: :obj:`True`)
|
| 32 |
+
**kwargs (Any): Additional keyword arguments to pass to the underlying
|
| 33 |
+
Agent class.
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
def __init__(
|
| 37 |
+
self,
|
| 38 |
+
name: str,
|
| 39 |
+
*args: Any,
|
| 40 |
+
remote: bool = True,
|
| 41 |
+
**kwargs: Any,
|
| 42 |
+
) -> None:
|
| 43 |
+
try:
|
| 44 |
+
# TODO: Support other tool agents
|
| 45 |
+
import transformers
|
| 46 |
+
from packaging import version
|
| 47 |
+
|
| 48 |
+
if version.parse(transformers.__version__) < version.parse(
|
| 49 |
+
"4.31.0"
|
| 50 |
+
):
|
| 51 |
+
raise ValueError(
|
| 52 |
+
"The version of \"transformers\" package should >= 4.31.0"
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
from transformers.tools import OpenAiAgent
|
| 56 |
+
from transformers.tools.agent_types import AgentImage
|
| 57 |
+
except (ImportError, ValueError):
|
| 58 |
+
raise ValueError(
|
| 59 |
+
"Could not import transformers tool agents. "
|
| 60 |
+
"Please setup the environment with "
|
| 61 |
+
"pip install huggingface_hub==0.14.1 transformers==4.31.0 diffusers accelerate==0.20.3 datasets torch soundfile sentencepiece opencv-python"
|
| 62 |
+
)
|
| 63 |
+
self.agent_image_type = AgentImage
|
| 64 |
+
self.agent = OpenAiAgent(*args, **kwargs)
|
| 65 |
+
description = f"""The `{name}` is a tool agent that can perform a variety of tasks including:
|
| 66 |
+
- Document question answering: given a document (such as a PDF) in image format, answer a question on this document
|
| 67 |
+
- Text question answering: given a long text and a question, answer the question in the text
|
| 68 |
+
- Unconditional image captioning: Caption the image!
|
| 69 |
+
- Image question answering: given an image, answer a question on this image
|
| 70 |
+
- Image segmentation: given an image and a prompt, output the segmentation mask of that prompt
|
| 71 |
+
- Speech to text: given an audio recording of a person talking, transcribe the speech into text
|
| 72 |
+
- Text to speech: convert text to speech
|
| 73 |
+
- Zero-shot text classification: given a text and a list of labels, identify to which label the text corresponds the most
|
| 74 |
+
- Text summarization: summarize a long text in one or a few sentences
|
| 75 |
+
- Translation: translate the text into a given language
|
| 76 |
+
- Text downloading: to download a text from a web URL
|
| 77 |
+
- Text to image: generate an image according to a prompt, leveraging stable diffusion
|
| 78 |
+
- Image transformation: modify an image given an initial image and a prompt, leveraging instruct pix2pix stable diffusion
|
| 79 |
+
- Text to video: generate a small video according to a prompt
|
| 80 |
+
|
| 81 |
+
Here are some python code examples of what you can do with this agent:
|
| 82 |
+
|
| 83 |
+
Single execution (step) mode, the single execution method is when using the step() method of the agent:
|
| 84 |
+
```
|
| 85 |
+
# Text to image
|
| 86 |
+
rivers_and_lakes_image = {name}.step("Draw me a picture of rivers and lakes.")
|
| 87 |
+
rivers_and_lakes_image.save("./rivers_and_lakes_image.png")
|
| 88 |
+
|
| 89 |
+
# Text to image -> Image transformation
|
| 90 |
+
sea_add_island_image = {name}.step("Draw me a picture of the sea then transform the picture to add an island")
|
| 91 |
+
sea_add_island_image.save("./sea_add_island_image.png")
|
| 92 |
+
|
| 93 |
+
# If you'd like to keep a state across executions or to pass non-text objects to the agent,
|
| 94 |
+
# you can do so by specifying variables that you would like the agent to use. For example,
|
| 95 |
+
# you could generate the first image of rivers and lakes, and ask the model to update that picture to add an island by doing the following:
|
| 96 |
+
picture = {name}.step("Generate a picture of rivers and lakes.")
|
| 97 |
+
picture.save("./picture.png")
|
| 98 |
+
updated_picture = {name}.step("Transform the image in `picture` to add an island to it.", picture=picture)
|
| 99 |
+
updated_picture.save("./updated_picture.png")
|
| 100 |
+
|
| 101 |
+
capybara_sea_image = {name}.step("Draw me a picture of the `prompt`", prompt="a capybara swimming in the sea")
|
| 102 |
+
capybara_sea_image.save("./capybara_sea_image.png")
|
| 103 |
+
|
| 104 |
+
# Document question answering
|
| 105 |
+
answer = {name}.step(
|
| 106 |
+
"In the following `document`, where will the TRRF Scientific Advisory Council Meeting take place?",
|
| 107 |
+
document=document,
|
| 108 |
+
)
|
| 109 |
+
print(answer)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
# Text to image
|
| 113 |
+
boat_image = {name}.step("Generate an image of a boat in the water")
|
| 114 |
+
boat_image.save("./boat_image.png")
|
| 115 |
+
|
| 116 |
+
# Unconditional image captioning
|
| 117 |
+
boat_image_caption = {name}.step("Can you caption the `boat_image`?", boat_image=boat_image)
|
| 118 |
+
print(boat_image_caption)
|
| 119 |
+
|
| 120 |
+
# Text to image -> Unconditional image captioning -> Text to speech
|
| 121 |
+
boat_audio = {name}.step("Can you generate an image of a boat? Please read out loud the contents of the image afterwards")
|
| 122 |
+
|
| 123 |
+
# Text downloading
|
| 124 |
+
document = {name}.step("Download the text from http://hf.co")
|
| 125 |
+
print(document)
|
| 126 |
+
|
| 127 |
+
# Text summarization
|
| 128 |
+
summary = {name}.step("Summarize the following text: `document`", document=document)
|
| 129 |
+
print(summary)
|
| 130 |
+
|
| 131 |
+
# Text downloading -> Text summarization -> Text to speech
|
| 132 |
+
audio = {name}.step("Read out loud the summary of http://hf.co")
|
| 133 |
+
```
|
| 134 |
+
|
| 135 |
+
Chat-based execution (chat), the agent also has a chat-based approach, using the chat() method:
|
| 136 |
+
```
|
| 137 |
+
# Clean the chat history
|
| 138 |
+
{name}.reset()
|
| 139 |
+
|
| 140 |
+
# Text to image
|
| 141 |
+
capybara_image = {name}.chat("Show me an image of a capybara")
|
| 142 |
+
capybara_image.save("./capybara_image.png")
|
| 143 |
+
|
| 144 |
+
# Image transformation
|
| 145 |
+
transformed_capybara_image = {name}.chat("Transform the image so that it snows")
|
| 146 |
+
transformed_capybara_image.save("./transformed_capybara_image.png")
|
| 147 |
+
|
| 148 |
+
# Image segmentation
|
| 149 |
+
segmented_transformed_capybara_image = {name}.chat("Show me a mask of the snowy capybaras")
|
| 150 |
+
segmented_transformed_capybara_image.save("./segmented_transformed_capybara_image.png")
|
| 151 |
+
```
|
| 152 |
+
"""
|
| 153 |
+
super(HuggingFaceToolAgent, self).__init__(name, description)
|
| 154 |
+
self.remote = remote
|
| 155 |
+
|
| 156 |
+
def reset(self) -> None:
|
| 157 |
+
r"""Resets the chat history of the agent."""
|
| 158 |
+
self.agent.prepare_for_new_chat()
|
| 159 |
+
|
| 160 |
+
def step(
|
| 161 |
+
self,
|
| 162 |
+
*args: Any,
|
| 163 |
+
remote: Optional[bool] = None,
|
| 164 |
+
**kwargs: Any,
|
| 165 |
+
) -> Any:
|
| 166 |
+
r"""Runs the agent in single execution mode.
|
| 167 |
+
|
| 168 |
+
Args:
|
| 169 |
+
*args (Any): Positional arguments to pass to the agent.
|
| 170 |
+
remote (bool, optional): Flag indicating whether to run the agent
|
| 171 |
+
remotely. Overrides the default setting. (default: :obj:`None`)
|
| 172 |
+
**kwargs (Any): Keyword arguments to pass to the agent.
|
| 173 |
+
|
| 174 |
+
Returns:
|
| 175 |
+
str: The response from the agent.
|
| 176 |
+
"""
|
| 177 |
+
if remote is None:
|
| 178 |
+
remote = self.remote
|
| 179 |
+
agent_output = self.agent.run(*args, remote=remote, **kwargs)
|
| 180 |
+
if isinstance(agent_output, self.agent_image_type):
|
| 181 |
+
agent_output = agent_output.to_raw()
|
| 182 |
+
return agent_output
|
| 183 |
+
|
| 184 |
+
def chat(
|
| 185 |
+
self,
|
| 186 |
+
*args: Any,
|
| 187 |
+
remote: Optional[bool] = None,
|
| 188 |
+
**kwargs: Any,
|
| 189 |
+
) -> Any:
|
| 190 |
+
r"""Runs the agent in a chat conversation mode.
|
| 191 |
+
|
| 192 |
+
Args:
|
| 193 |
+
*args (Any): Positional arguments to pass to the agent.
|
| 194 |
+
remote (bool, optional): Flag indicating whether to run the agent
|
| 195 |
+
remotely. Overrides the default setting. (default: :obj:`None`)
|
| 196 |
+
**kwargs (Any): Keyword arguments to pass to the agent.
|
| 197 |
+
|
| 198 |
+
Returns:
|
| 199 |
+
str: The response from the agent.
|
| 200 |
+
"""
|
| 201 |
+
if remote is None:
|
| 202 |
+
remote = self.remote
|
| 203 |
+
agent_output = self.agent.chat(*args, remote=remote, **kwargs)
|
| 204 |
+
if isinstance(agent_output, self.agent_image_type):
|
| 205 |
+
agent_output = agent_output.to_raw()
|
| 206 |
+
return agent_output
|
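A usage sketch for the wrapper above (not part of the commit). Keyword arguments other than `remote` are forwarded to transformers' `OpenAiAgent`; the `model` keyword follows the transformers 4.31 tool-agent API and should be treated as an assumption, as should the configured OpenAI API key.

```python
# Sketch only: `model` is forwarded to transformers' OpenAiAgent (4.31 API,
# assumed); an OpenAI API key is assumed to be configured in the environment.
from camel.agents.tool_agents import HuggingFaceToolAgent

hf_agent = HuggingFaceToolAgent(
    "hugging_face_agent",   # the name interpolated into the description above
    model="gpt-3.5-turbo",  # assumption: OpenAiAgent model argument
    remote=True,            # run the underlying HF tools remotely
)

# Single-shot mode: every call is independent.
summary = hf_agent.step(
    "Summarize the following `text`.",
    text="CAMEL wraps transformers tool agents behind one common interface.",
)
print(summary)

# Chat mode keeps history between calls until reset() clears it.
hf_agent.reset()
picture = hf_agent.chat("Draw me a picture of rivers and lakes.")
```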
Paper2Poster/camel/benchmarks/__init__.py
ADDED
|
@@ -0,0 +1,30 @@
| 1 |
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
#
|
| 6 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 7 |
+
#
|
| 8 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 9 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 10 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 11 |
+
# See the License for the specific language governing permissions and
|
| 12 |
+
# limitations under the License.
|
| 13 |
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
| 14 |
+
|
| 15 |
+
from .apibank import APIBankBenchmark
|
| 16 |
+
from .apibench import APIBenchBenchmark
|
| 17 |
+
from .base import BaseBenchmark
|
| 18 |
+
from .gaia import DefaultGAIARetriever, GAIABenchmark
|
| 19 |
+
from .nexus import NexusBenchmark
|
| 20 |
+
from .ragbench import RAGBenchBenchmark
|
| 21 |
+
|
| 22 |
+
__all__ = [
|
| 23 |
+
"BaseBenchmark",
|
| 24 |
+
"GAIABenchmark",
|
| 25 |
+
"DefaultGAIARetriever",
|
| 26 |
+
"NexusBenchmark",
|
| 27 |
+
"APIBenchBenchmark",
|
| 28 |
+
"APIBankBenchmark",
|
| 29 |
+
"RAGBenchBenchmark",
|
| 30 |
+
]
|
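The package initializer above only re-exports the benchmark entry points, so downstream code can import them from one place; a trivial check (not part of the commit):

```python
# The re-exports above let callers import everything from camel.benchmarks.
from camel.benchmarks import APIBankBenchmark, BaseBenchmark

assert issubclass(APIBankBenchmark, BaseBenchmark)
```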
Paper2Poster/camel/benchmarks/apibank.py
ADDED
|
@@ -0,0 +1,565 @@
| 1 |
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
#
|
| 6 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 7 |
+
#
|
| 8 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 9 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 10 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 11 |
+
# See the License for the specific language governing permissions and
|
| 12 |
+
# limitations under the License.
|
| 13 |
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
| 14 |
+
|
| 15 |
+
import json
|
| 16 |
+
import logging
|
| 17 |
+
import os
|
| 18 |
+
import random
|
| 19 |
+
import re
|
| 20 |
+
import sys
|
| 21 |
+
from pathlib import Path
|
| 22 |
+
from typing import Any, Dict, List, Literal, Optional
|
| 23 |
+
|
| 24 |
+
import numpy as np
|
| 25 |
+
from rouge import Rouge
|
| 26 |
+
from tqdm import tqdm
|
| 27 |
+
|
| 28 |
+
from camel.agents import ChatAgent
|
| 29 |
+
from camel.benchmarks.base import BaseBenchmark
|
| 30 |
+
from camel.messages import BaseMessage
|
| 31 |
+
from camel.utils import download_github_subdirectory
|
| 32 |
+
|
| 33 |
+
logger = logging.getLogger(__name__)
|
| 34 |
+
|
| 35 |
+
# Add current folder to sys.path to enable relative import
|
| 36 |
+
current_folder = os.getcwd()
|
| 37 |
+
if current_folder not in sys.path:
|
| 38 |
+
sys.path.append(current_folder)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def process_messages(
|
| 42 |
+
chat_history: List[Dict[str, Any]],
|
| 43 |
+
prompt: str,
|
| 44 |
+
) -> List[Dict[str, str]]:
|
| 45 |
+
"""
|
| 46 |
+
Processes chat history into a structured format for further use.
|
| 47 |
+
|
| 48 |
+
Args:
|
| 49 |
+
chat_history (List[Dict[str, Any]]):
|
| 50 |
+
A list of dictionaries representing the chat history.
|
| 51 |
+
prompt (str): A prompt to be set as the system message.
|
| 52 |
+
|
| 53 |
+
Returns:
|
| 54 |
+
List[Dict[str, str]]: A list of dictionaries representing
|
| 55 |
+
the processed messages, where each dictionary has:
|
| 56 |
+
- 'role': The role of the message ('system', 'user', or 'assistant').
|
| 57 |
+
- 'content': The content of the message, including formatted
|
| 58 |
+
API responses when applicable.
|
| 59 |
+
"""
|
| 60 |
+
messages = [{'role': 'system', 'content': prompt}]
|
| 61 |
+
for item in chat_history:
|
| 62 |
+
role_map = {'User': 'user', 'AI': 'assistant', 'API': 'system'}
|
| 63 |
+
chat_role = role_map.get(
|
| 64 |
+
item['role'], 'unknown'
|
| 65 |
+
) # default role to 'unknown'
|
| 66 |
+
if item['role'] == 'API':
|
| 67 |
+
chat_content = '[{}({})] Response: {}'.format(
|
| 68 |
+
item['api_name'],
|
| 69 |
+
', '.join(
|
| 70 |
+
[
|
| 71 |
+
'{}=\'{}\''.format(k, v)
|
| 72 |
+
for k, v in item['param_dict'].items()
|
| 73 |
+
]
|
| 74 |
+
),
|
| 75 |
+
str(item['result']['output']),
|
| 76 |
+
)
|
| 77 |
+
else:
|
| 78 |
+
chat_content = item['text']
|
| 79 |
+
messages.append({'role': chat_role, 'content': chat_content})
|
| 80 |
+
return messages
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class APIBankBenchmark(BaseBenchmark):
|
| 84 |
+
r"""API-Bank Benchmark adapted from `API-Bank:
|
| 85 |
+
A Comprehensive Benchmark for Tool-Augmented LLMs`
|
| 86 |
+
<https://github.com/AlibabaResearch/DAMO-ConvAI/tree/main/api-bank>.
|
| 87 |
+
|
| 88 |
+
Args:
|
| 89 |
+
save_to (str): The file to save the results.
|
| 90 |
+
processes (int, optional): The number of processes to use.
|
| 91 |
+
(default: :obj:`1`)
|
| 92 |
+
"""
|
| 93 |
+
|
| 94 |
+
def __init__(
|
| 95 |
+
self,
|
| 96 |
+
save_to: str,
|
| 97 |
+
processes: int = 1,
|
| 98 |
+
):
|
| 99 |
+
r"""Initialize the APIBank benchmark.
|
| 100 |
+
|
| 101 |
+
Args:
|
| 102 |
+
save_to (str): The file to save the results.
|
| 103 |
+
processes (int, optional): The number of processes to use for
|
| 104 |
+
parallel processing. (default: :obj:`1`)
|
| 105 |
+
"""
|
| 106 |
+
# Predefine data_dir for better import management
|
| 107 |
+
super().__init__("apibank", "api_bank", save_to, processes)
|
| 108 |
+
self._data: Dict[str, List[APIBankSample]] = dict() # type: ignore[assignment]
|
| 109 |
+
|
| 110 |
+
def download(self):
|
| 111 |
+
r"""Download APIBank dataset and code from Github."""
|
| 112 |
+
|
| 113 |
+
repo = "AlibabaResearch/DAMO-ConvAI"
|
| 114 |
+
subdir = "api-bank"
|
| 115 |
+
data_dir = self.data_dir
|
| 116 |
+
|
| 117 |
+
download_github_subdirectory(repo, subdir, data_dir)
|
| 118 |
+
|
| 119 |
+
sys.path.insert(0, self.data_dir)
|
| 120 |
+
logger.info("Download completed.")
|
| 121 |
+
|
| 122 |
+
def load(self, level: str, force_download: bool = False): # type: ignore[override]
|
| 123 |
+
r"""Load the APIBank Benchmark dataset.
|
| 124 |
+
|
| 125 |
+
Args:
|
| 126 |
+
level (str): Level to run benchmark on.
|
| 127 |
+
force_download (bool, optional): Whether to
|
| 128 |
+
force download the data.
|
| 129 |
+
"""
|
| 130 |
+
if force_download:
|
| 131 |
+
logger.info("Force downloading data.")
|
| 132 |
+
self.download()
|
| 133 |
+
|
| 134 |
+
if level == "level-1":
|
| 135 |
+
file_path = Path("api_bank/lv1-lv2-samples/level-1-given-desc")
|
| 136 |
+
elif level == 'level-2':
|
| 137 |
+
file_path = Path("api_bank/lv1-lv2-samples/level-2-toolsearcher")
|
| 138 |
+
jsonl_files = [
|
| 139 |
+
f for f in os.listdir(file_path) if f.endswith('.jsonl')
|
| 140 |
+
]
|
| 141 |
+
for file in tqdm(jsonl_files, desc="Processing files"):
|
| 142 |
+
history = []
|
| 143 |
+
with open(file_path / file, 'r') as f:
|
| 144 |
+
for line in f:
|
| 145 |
+
history.append(json.loads(line))
|
| 146 |
+
samples = APIBankSample.from_chat_history(history)
|
| 147 |
+
self._data[file.rsplit('.', 1)[0]] = samples
|
| 148 |
+
|
| 149 |
+
# Change import to relative import in the downloaded python files
|
| 150 |
+
def process_files(folder_path, replacements):
|
| 151 |
+
r"""Replace absolute imports in downloaded files with
|
| 152 |
+
relative import."""
|
| 153 |
+
for file in os.listdir(folder_path):
|
| 154 |
+
if file.endswith(".py"):
|
| 155 |
+
file_path = os.path.join(folder_path, file)
|
| 156 |
+
try:
|
| 157 |
+
with open(file_path, "r", encoding="utf-8") as file:
|
| 158 |
+
content = file.read()
|
| 159 |
+
|
| 160 |
+
original_content = content
|
| 161 |
+
|
| 162 |
+
for pattern, replacement in replacements:
|
| 163 |
+
content = re.sub(pattern, replacement, content)
|
| 164 |
+
|
| 165 |
+
if content != original_content:
|
| 166 |
+
with open(
|
| 167 |
+
file_path, "w", encoding="utf-8"
|
| 168 |
+
) as file:
|
| 169 |
+
file.write(content)
|
| 170 |
+
logger.info(f"Updated file: {file_path}")
|
| 171 |
+
|
| 172 |
+
except Exception as e:
|
| 173 |
+
logger.info(f"Error processing file {file_path}: {e}")
|
| 174 |
+
|
| 175 |
+
api_bank_folder = "api_bank"
|
| 176 |
+
apis_folder = os.path.join(api_bank_folder, "apis")
|
| 177 |
+
|
| 178 |
+
apis_replacements = [
|
| 179 |
+
(r"from apis.api", "from .api"),
|
| 180 |
+
(r"from apis import", "from .api import"),
|
| 181 |
+
]
|
| 182 |
+
|
| 183 |
+
api_bank_replacements = [
|
| 184 |
+
(r"from apis", "from .apis"),
|
| 185 |
+
(r"from api_call_extraction", "from .api_call_extraction"),
|
| 186 |
+
(r"f'{basename}", r"f'api_bank.{basename}"),
|
| 187 |
+
]
|
| 188 |
+
|
| 189 |
+
process_files(apis_folder, apis_replacements)
|
| 190 |
+
process_files(api_bank_folder, api_bank_replacements)
|
| 191 |
+
|
| 192 |
+
def run( # type: ignore[override, return]
|
| 193 |
+
self,
|
| 194 |
+
agent: ChatAgent,
|
| 195 |
+
level: Literal["level-1", "level-2"],
|
| 196 |
+
api_test_enabled=True,
|
| 197 |
+
randomize: bool = False,
|
| 198 |
+
subset: Optional[int] = None,
|
| 199 |
+
) -> Dict[str, Any]:
|
| 200 |
+
r"""Run the benchmark.
|
| 201 |
+
|
| 202 |
+
Args:
|
| 203 |
+
agent (ChatAgent): The agent to run the
|
| 204 |
+
benchmark.
|
| 205 |
+
level (Literal['level-1', 'level-2']):
|
| 206 |
+
The level to run the benchmark on.
|
| 207 |
+
randomize (bool, optional): Whether to
|
| 208 |
+
randomize the data.
|
| 209 |
+
api_test_enabled (bool): Whether to test
|
| 210 |
+
API calling (`True`) or response (`False`)
|
| 211 |
+
(default: :obj:`True`)
|
| 212 |
+
subset (Optional[int], optional):
|
| 213 |
+
The subset of data to run.
|
| 214 |
+
(default: :obj:`None`)
|
| 215 |
+
|
| 216 |
+
Returns:
|
| 217 |
+
Dict[str, Any]: The results of the benchmark.
|
| 218 |
+
"""
|
| 219 |
+
logger.info(f"Running APIBench benchmark on {level}.")
|
| 220 |
+
self.load(level)
|
| 221 |
+
datas = self._data
|
| 222 |
+
|
| 223 |
+
# Shuffle and subset data if necessary
|
| 224 |
+
if randomize:
|
| 225 |
+
randomized_items = list(datas.items())
|
| 226 |
+
random.shuffle(randomized_items)
|
| 227 |
+
datas = dict(randomized_items)
|
| 228 |
+
if subset:
|
| 229 |
+
datas = dict(list(datas.items())[:subset])
|
| 230 |
+
|
| 231 |
+
logger.info(f"Number of tasks: {len(datas)}")
|
| 232 |
+
|
| 233 |
+
# Initialize results storage
|
| 234 |
+
self._results = []
|
| 235 |
+
|
| 236 |
+
# The following code are adapted from the evaluator
|
| 237 |
+
# from the original repo:
|
| 238 |
+
tool_search_enabled = level == "level-2"
|
| 239 |
+
dialog_test_enabled = not api_test_enabled
|
| 240 |
+
total_api_calls, correct_api_calls, rougel_scores = 0, 0, []
|
| 241 |
+
|
| 242 |
+
with open(self.save_to, "w") as f:
|
| 243 |
+
for test in tqdm(datas, desc="Running"):
|
| 244 |
+
samples = self._data[test]
|
| 245 |
+
evaluator = Evaluator(samples) # type: ignore[arg-type]
|
| 246 |
+
|
| 247 |
+
for sample_id in evaluator.get_all_sample_ids():
|
| 248 |
+
# Process sample and generate response
|
| 249 |
+
sample = evaluator.dataset[sample_id]
|
| 250 |
+
|
| 251 |
+
if (
|
| 252 |
+
sample.ground_truth['role'] == 'API'
|
| 253 |
+
and api_test_enabled
|
| 254 |
+
):
|
| 255 |
+
if tool_search_enabled:
|
| 256 |
+
_, chat_history = evaluator.get_model_input(
|
| 257 |
+
sample_id
|
| 258 |
+
)
|
| 259 |
+
api_descriptions = evaluator.get_api_description(
|
| 260 |
+
'ToolSearcher'
|
| 261 |
+
)
|
| 262 |
+
else:
|
| 263 |
+
api_descriptions, chat_history = (
|
| 264 |
+
evaluator.get_model_input(sample_id)
|
| 265 |
+
)
|
| 266 |
+
messages = process_messages(
|
| 267 |
+
chat_history, API_CALL_PROMPT + api_descriptions
|
| 268 |
+
)
|
| 269 |
+
model_output = agent_call(messages, agent)
|
| 270 |
+
api_call = get_api_call(model_output)
|
| 271 |
+
|
| 272 |
+
# Evaluate API call
|
| 273 |
+
if api_call:
|
| 274 |
+
try:
|
| 275 |
+
correct, model_output_result = (
|
| 276 |
+
evaluator.evaluate(sample_id, api_call)
|
| 277 |
+
)
|
| 278 |
+
except AssertionError as e:
|
| 279 |
+
if 'The API name is not correct.' not in str(
|
| 280 |
+
e
|
| 281 |
+
):
|
| 282 |
+
raise e
|
| 283 |
+
logging.info('AssertionError: {}'.format(e))
|
| 284 |
+
correct = False
|
| 285 |
+
else:
|
| 286 |
+
model_output_result = 'No API call found'
|
| 287 |
+
correct = False
|
| 288 |
+
if correct:
|
| 289 |
+
correct_api_calls += 1
|
| 290 |
+
logging.info(
|
| 291 |
+
'Correct API call: {} Ground truth: {}'.format(
|
| 292 |
+
api_call, sample.ground_truth
|
| 293 |
+
)
|
| 294 |
+
)
|
| 295 |
+
else:
|
| 296 |
+
logging.info(
|
| 297 |
+
'Incorrect model output: {} Result: {} \
|
| 298 |
+
Ground truth: {} File: {} Sample ID: {} \
|
| 299 |
+
Messages: {}'.format(
|
| 300 |
+
model_output.replace('\n', ' '),
|
| 301 |
+
model_output_result,
|
| 302 |
+
sample.ground_truth,
|
| 303 |
+
test,
|
| 304 |
+
sample_id,
|
| 305 |
+
messages[1:],
|
| 306 |
+
)
|
| 307 |
+
)
|
| 308 |
+
total_api_calls += 1
|
| 309 |
+
self._results.append(
|
| 310 |
+
{
|
| 311 |
+
'Role': 'API',
|
| 312 |
+
'Model_output': model_output,
|
| 313 |
+
'Model_output_result': model_output_result,
|
| 314 |
+
'Ground_truth': sample.ground_truth,
|
| 315 |
+
'Test': test,
|
| 316 |
+
'Correct': correct,
|
| 317 |
+
}
|
| 318 |
+
)
|
| 319 |
+
f.write(json.dumps(self._results[-1], indent=2) + "\n")
|
| 320 |
+
|
| 321 |
+
elif (
|
| 322 |
+
sample.ground_truth['role'] == 'AI'
|
| 323 |
+
and dialog_test_enabled
|
| 324 |
+
):
|
| 325 |
+
# Process sample and generate response
|
| 326 |
+
api_descriptions, chat_history = (
|
| 327 |
+
evaluator.get_model_input(sample_id)
|
| 328 |
+
)
|
| 329 |
+
|
| 330 |
+
messages = process_messages(
|
| 331 |
+
chat_history, RESPONSE_PROMPT + api_descriptions
|
| 332 |
+
)
|
| 333 |
+
model_output = agent_call(messages, agent)
|
| 334 |
+
|
| 335 |
+
# Evaluate model response
|
| 336 |
+
if model_output:
|
| 337 |
+
score = evaluator.evaluate(sample_id, model_output)
|
| 338 |
+
else:
|
| 339 |
+
score = 0
|
| 340 |
+
rougel_scores.append(score)
|
| 341 |
+
if score < 0.2:
|
| 342 |
+
logging.info(
|
| 343 |
+
'Low score: {} Score: {} Ground truth: {} \
|
| 344 |
+
Test: {} Sample ID: {} \
|
| 345 |
+
Messages: {}'.format(
|
| 346 |
+
model_output.replace('\n', ' '),
|
| 347 |
+
score,
|
| 348 |
+
sample.ground_truth,
|
| 349 |
+
test,
|
| 350 |
+
sample_id,
|
| 351 |
+
messages[1:],
|
| 352 |
+
)
|
| 353 |
+
)
|
| 354 |
+
|
| 355 |
+
self._results.append(
|
| 356 |
+
{
|
| 357 |
+
'Role': 'AI',
|
| 358 |
+
'Model_output': model_output,
|
| 359 |
+
'Score': score,
|
| 360 |
+
'Ground_truth': sample.ground_truth,
|
| 361 |
+
'Test': test,
|
| 362 |
+
}
|
| 363 |
+
)
|
| 364 |
+
f.write(json.dumps(self._results[-1], indent=2) + "\n")
|
| 365 |
+
|
| 366 |
+
f.flush()
|
| 367 |
+
|
| 368 |
+
if api_test_enabled:
|
| 369 |
+
return {
|
| 370 |
+
'total': total_api_calls,
|
| 371 |
+
'correct': correct_api_calls,
|
| 372 |
+
"accuracy": correct_api_calls / total_api_calls
|
| 373 |
+
if total_api_calls
|
| 374 |
+
else 0,
|
| 375 |
+
}
|
| 376 |
+
elif dialog_test_enabled:
|
| 377 |
+
return {'Dialog_score': np.mean(rougel_scores)}
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
# The following code are migrated from the original repo:
|
| 381 |
+
# https://github.com/AlibabaResearch/DAMO-ConvAI/tree/main/api-bank
|
| 382 |
+
def agent_call(messages: List[Dict], agent: ChatAgent):
|
| 383 |
+
r"""Add messages to agent memory and get response."""
|
| 384 |
+
for i, msg in enumerate(messages):
|
| 385 |
+
if msg['role'] == 'user':
|
| 386 |
+
message = BaseMessage.make_user_message(
|
| 387 |
+
role_name="CAMEL User", content=msg['content']
|
| 388 |
+
)
|
| 389 |
+
elif msg['role'] == 'assistant':
|
| 390 |
+
message = BaseMessage.make_assistant_message(
|
| 391 |
+
role_name="CAMEL Assistant", content=msg['content']
|
| 392 |
+
)
|
| 393 |
+
elif msg['role'] == 'system':
|
| 394 |
+
message = BaseMessage.make_assistant_message(
|
| 395 |
+
role_name="System", content=msg['content']
|
| 396 |
+
)
|
| 397 |
+
else:
|
| 398 |
+
raise ValueError(f"Unrecognized role: {msg['role']}")
|
| 399 |
+
|
| 400 |
+
if i == len(messages) - 1:
|
| 401 |
+
break
|
| 402 |
+
agent.record_message(message)
|
| 403 |
+
|
| 404 |
+
response = agent.step(message)
|
| 405 |
+
model_output = response.msgs[0].content
|
| 406 |
+
agent.reset()
|
| 407 |
+
return model_output
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
def calculate_rouge_l_score(reference, hypothesis):
|
| 411 |
+
r"""Calculate rouge l score between hypothesis and reference."""
|
| 412 |
+
rouge = Rouge()
|
| 413 |
+
scores = rouge.get_scores(hypothesis, reference)
|
| 414 |
+
rouge_l_score = scores[0]['rouge-l']['f']
|
| 415 |
+
return rouge_l_score
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
def get_api_call(model_output):
|
| 419 |
+
r"""Parse api call from model output."""
|
| 420 |
+
api_call_pattern = r"\[(\w+)\((.*)\)\]"
|
| 421 |
+
api_call_pattern = re.compile(api_call_pattern)
|
| 422 |
+
match = api_call_pattern.search(model_output)
|
| 423 |
+
if match:
|
| 424 |
+
return match.group(0)
|
| 425 |
+
else:
|
| 426 |
+
return None
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
class APIBankSample:
|
| 430 |
+
r"""APIBank sample used to load the datasets."""
|
| 431 |
+
|
| 432 |
+
def __init__(self, chat_history, apis, ground_truth):
|
| 433 |
+
self.chat_history = chat_history
|
| 434 |
+
self.apis = apis
|
| 435 |
+
self.ground_truth = ground_truth
|
| 436 |
+
|
| 437 |
+
def __repr__(self):
|
| 438 |
+
return 'Sample(chat_history={}, apis={}, ground_truth={})'.format(
|
| 439 |
+
self.chat_history, self.apis, self.ground_truth
|
| 440 |
+
)
|
| 441 |
+
|
| 442 |
+
@classmethod
|
| 443 |
+
def from_chat_history(cls, chat_history):
|
| 444 |
+
apis = set()
|
| 445 |
+
api_positions = []
|
| 446 |
+
for i, item in enumerate(chat_history):
|
| 447 |
+
if item['role'] == 'API':
|
| 448 |
+
apis.add(item['api_name'])
|
| 449 |
+
api_positions.append(i)
|
| 450 |
+
|
| 451 |
+
samples = []
|
| 452 |
+
for i in api_positions:
|
| 453 |
+
sample = cls(chat_history[:i], apis, chat_history[i])
|
| 454 |
+
samples.append(sample)
|
| 455 |
+
sample = cls(chat_history[: i + 1], apis, chat_history[i + 1])
|
| 456 |
+
samples.append(sample)
|
| 457 |
+
|
| 458 |
+
return samples
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
class Evaluator:
|
| 462 |
+
r"""Evaluator for APIBank benchmark."""
|
| 463 |
+
|
| 464 |
+
def __init__(self, samples: List[APIBankSample]):
|
| 465 |
+
# Place holder for import as the import
|
| 466 |
+
# only works after the files have been downloaded
|
| 467 |
+
try:
|
| 468 |
+
from api_bank.tool_manager import ( # type: ignore[import-not-found]
|
| 469 |
+
ToolManager,
|
| 470 |
+
)
|
| 471 |
+
except Exception as e:
|
| 472 |
+
logger.info(f"{e}, Module will be imported after download.")
|
| 473 |
+
self.dataset = samples
|
| 474 |
+
self.sample_ids = list(range(len(self.dataset)))
|
| 475 |
+
os.chdir("api_bank")
|
| 476 |
+
self.tool_manager = ToolManager("apis")
|
| 477 |
+
os.chdir("..")
|
| 478 |
+
|
| 479 |
+
def get_all_sample_ids(self):
|
| 480 |
+
return self.sample_ids
|
| 481 |
+
|
| 482 |
+
def get_api_description(self, api_name):
|
| 483 |
+
return self.tool_manager.get_api_description(api_name)
|
| 484 |
+
|
| 485 |
+
def get_model_input(self, sample_id: int):
|
| 486 |
+
sample = self.dataset[sample_id]
|
| 487 |
+
apis = sample.apis
|
| 488 |
+
chat_history = sample.chat_history
|
| 489 |
+
api_descriptions = []
|
| 490 |
+
for api_name in apis:
|
| 491 |
+
api_descriptions.append(
|
| 492 |
+
self.tool_manager.get_api_description(api_name)
|
| 493 |
+
)
|
| 494 |
+
api_description = '\n'.join(api_descriptions)
|
| 495 |
+
return api_description, chat_history
|
| 496 |
+
|
| 497 |
+
def evaluate(self, sample_id, model_output):
|
| 498 |
+
try:
|
| 499 |
+
from api_bank.api_call_extraction import ( # type: ignore[import-not-found]
|
| 500 |
+
parse_api_call,
|
| 501 |
+
)
|
| 502 |
+
except Exception as e:
|
| 503 |
+
logger.info(f"{e}, Module will be imported after download.")
|
| 504 |
+
sample = self.dataset[sample_id]
|
| 505 |
+
ground_truth = sample.ground_truth
|
| 506 |
+
if ground_truth['role'] == 'API':
|
| 507 |
+
api_name, param_dict = parse_api_call(model_output)
|
| 508 |
+
if api_name != ground_truth['api_name']:
|
| 509 |
+
return False, 'API Name Mismatch: {} vs {}'.format(
|
| 510 |
+
api_name, ground_truth['api_name']
|
| 511 |
+
)
|
| 512 |
+
try:
|
| 513 |
+
result = self.tool_manager.api_call(api_name, **param_dict)
|
| 514 |
+
except Exception as e:
|
| 515 |
+
return False, str(e)
|
| 516 |
+
api = self.tool_manager.init_tool(api_name)
|
| 517 |
+
try:
|
| 518 |
+
correct = api.check_api_call_correctness(
|
| 519 |
+
result, ground_truth['result']
|
| 520 |
+
)
|
| 521 |
+
except KeyError:
|
| 522 |
+
correct = False
|
| 523 |
+
result = 'KeyError' + str(result)
|
| 524 |
+
return correct, result
|
| 525 |
+
elif ground_truth['role'] == 'AI':
|
| 526 |
+
score = calculate_rouge_l_score(ground_truth['text'], model_output)
|
| 527 |
+
return round(score, 4)
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
API_CALL_PROMPT = '''
|
| 531 |
+
Based on the given API description and the existing \
|
| 532 |
+
conversation history 1..t, please generate the API request \
|
| 533 |
+
that the AI should call in step t+1 and output it in the \
|
| 534 |
+
format of [ApiName(key1='value1', key2='value2', ...)], \
|
| 535 |
+
replace the ApiName with the actual API name, and \
|
| 536 |
+
replace the key and value with the actual parameters. \
|
| 537 |
+
Your output should start with a square bracket "[" \
|
| 538 |
+
and end with a square bracket "]". Do not output any \
|
| 539 |
+
other explanation or prompt or the result of the API call in your output.
|
| 540 |
+
This year is 2023.
|
| 541 |
+
Input:
|
| 542 |
+
User: [User's utterance]
|
| 543 |
+
AI: [AI's utterance]
|
| 544 |
+
|
| 545 |
+
Expected output:
|
| 546 |
+
[ApiName(key1='value1', key2='value2', ...)]
|
| 547 |
+
|
| 548 |
+
API descriptions:
|
| 549 |
+
'''
|
| 550 |
+
|
| 551 |
+
RESPONSE_PROMPT = '''
|
| 552 |
+
Based on the given API description and the existing \
|
| 553 |
+
conversation history 1..t, please generate the next \
|
| 554 |
+
dialog that the AI should respond with after the API call t.
|
| 555 |
+
This year is 2023.
|
| 556 |
+
Input:
|
| 557 |
+
User: [User's utterance]
|
| 558 |
+
AI: [AI's utterance]
|
| 559 |
+
[ApiName(key1='value1', key2='value2', …)]
|
| 560 |
+
|
| 561 |
+
Expected output:
|
| 562 |
+
AI: [AI's utterance]
|
| 563 |
+
|
| 564 |
+
API descriptions:
|
| 565 |
+
'''
|
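A minimal end-to-end sketch of driving the benchmark above, plus a quick check of the `get_api_call` helper (not part of the commit). It assumes network access for the GitHub download and a configured model backend behind `ChatAgent`; the API name, output file name, and subset size are illustrative only.

```python
# Minimal driver sketch; assumes network access for download() and a
# configured model backend behind ChatAgent. The API name, file name,
# and subset size are illustrative.
from camel.agents import ChatAgent
from camel.benchmarks import APIBankBenchmark
from camel.benchmarks.apibank import get_api_call
from camel.messages import BaseMessage

# get_api_call extracts the first "[ApiName(...)]" span from raw model text.
raw = "Sure, calling [GetWeather(city='London', unit='celsius')] now."
print(get_api_call(raw))  # -> [GetWeather(city='London', unit='celsius')]

sys_msg = BaseMessage.make_assistant_message(
    role_name="Assistant", content="You are a tool-using assistant."
)
agent = ChatAgent(sys_msg)

benchmark = APIBankBenchmark(save_to="apibank_results.jsonl")
benchmark.download()          # fetches the api-bank subdirectory from GitHub
results = benchmark.run(
    agent,
    level="level-1",          # "level-2" switches to the tool-searcher setting
    api_test_enabled=True,    # score API calls rather than dialog responses
    subset=5,                 # keep the smoke test small
)
print(results)                # e.g. {'total': ..., 'correct': ..., 'accuracy': ...}
```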
Paper2Poster/camel/benchmarks/apibench.py
ADDED
|
@@ -0,0 +1,500 @@
| 1 |
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
#
|
| 6 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 7 |
+
#
|
| 8 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 9 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 10 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 11 |
+
# See the License for the specific language governing permissions and
|
| 12 |
+
# limitations under the License.
|
| 13 |
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
| 14 |
+
|
| 15 |
+
import json
|
| 16 |
+
import logging
|
| 17 |
+
import random
|
| 18 |
+
from pathlib import Path
|
| 19 |
+
from typing import Any, Dict, Literal, Optional
|
| 20 |
+
|
| 21 |
+
import tree_sitter_python as tspython
|
| 22 |
+
from tqdm import tqdm
|
| 23 |
+
from tree_sitter import Language, Parser
|
| 24 |
+
|
| 25 |
+
from camel.agents import ChatAgent
|
| 26 |
+
from camel.benchmarks.base import BaseBenchmark
|
| 27 |
+
from camel.messages import BaseMessage
|
| 28 |
+
from camel.utils import download_github_subdirectory
|
| 29 |
+
|
| 30 |
+
logger = logging.getLogger(__name__)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# Mapping of dataset names to file names
|
| 34 |
+
# The 'Oracle' retriever is used here, which means the full
|
| 35 |
+
# API documentation will be included in the prompt
|
| 36 |
+
dataset_mapping = {
|
| 37 |
+
"huggingface": {
|
| 38 |
+
"api": "huggingface_api.jsonl",
|
| 39 |
+
"eval": "huggingface_eval.json",
|
| 40 |
+
"train": "huggingface_train.json",
|
| 41 |
+
"questions": "questions_huggingface_oracle.jsonl",
|
| 42 |
+
},
|
| 43 |
+
"tensorflowhub": {
|
| 44 |
+
"api": "tensorflowhub_api.jsonl",
|
| 45 |
+
"eval": "tensorflow_eval.json",
|
| 46 |
+
"train": "tensorflow_train.json",
|
| 47 |
+
"questions": "questions_tensorflowhub_oracle.jsonl",
|
| 48 |
+
},
|
| 49 |
+
"torchhub": {
|
| 50 |
+
"api": "torchhub_api.jsonl",
|
| 51 |
+
"eval": "torchhub_eval.json",
|
| 52 |
+
"train": "torchhub_train.json",
|
| 53 |
+
"questions": "questions_torchhub_oracle.jsonl",
|
| 54 |
+
},
|
| 55 |
+
}
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# This function is migrated from the original repo:
|
| 59 |
+
# https://github.com/ShishirPatil/gorilla
|
| 60 |
+
def encode_question(question: str, dataset_name: str) -> str:
|
| 61 |
+
r"""Encode multiple prompt instructions into a single string."""
|
| 62 |
+
|
| 63 |
+
if dataset_name == "torchhub":
|
| 64 |
+
domains = "1. $DOMAIN is inferred from the task description and \
|
| 65 |
+
should include one of {Classification, Semantic Segmentation, \
|
| 66 |
+
Object Detection, Audio Separation, Video Classification, \
|
| 67 |
+
Text-to-Speech}."
|
| 68 |
+
elif dataset_name == "huggingface":
|
| 69 |
+
domains = "1. $DOMAIN should include one of {Multimodal Feature \
|
| 70 |
+
Extraction, Multimodal Text-to-Image, Multimodal \
|
| 71 |
+
Image-to-Text, Multimodal Text-to-Video, \
|
| 72 |
+
Multimodal Visual Question Answering, Multimodal Document \
|
| 73 |
+
Question Answer, Multimodal Graph Machine Learning, \
|
| 74 |
+
Computer Vision Depth Estimation, Computer Vision Image \
|
| 75 |
+
Classification, Computer Vision Object Detection, \
|
| 76 |
+
Computer Vision Image Segmentation, Computer Vision \
|
| 77 |
+
Image-to-Image, Computer Vision Unconditional \
|
| 78 |
+
Image Generation, Computer Vision Video Classification, \
|
| 79 |
+
Computer Vision Zero-Shor Image Classification, \
|
| 80 |
+
Natural Language Processing Text Classification, \
|
| 81 |
+
Natural Language Processing Token Classification, \
|
| 82 |
+
Natural Language Processing Table Question Answering, \
|
| 83 |
+
Natural Language Processing Question Answering, \
|
| 84 |
+
Natural Language Processing, Zero-Shot Classification \
|
| 85 |
+
Natural Language Processing Translation, Natural Language \
|
| 86 |
+
Processing Summarization, Natural Language Processing \
|
| 87 |
+
Conversational, Natural Language Processing Text \
|
| 88 |
+
Generation, Natural Language Processing Fill-Mask, \
|
| 89 |
+
Natural Language Processing Text2Text Generation, \
|
| 90 |
+
Natural Language Processing Sentence Similarity, \
|
| 91 |
+
Audio Text-to-Speech, Audio Automatic Speech Recognition, \
|
| 92 |
+
Audio Audio-to-Audio, Audio Audio Classification, \
|
| 93 |
+
Audio Voice Activity Detection, Tabular Tabular \
|
| 94 |
+
Classification, Tabular Tabular Regression, \
|
| 95 |
+
Reinforcement Learning Reinforcement Learning, \
|
| 96 |
+
Reinforcement Learning Robotics }"
|
| 97 |
+
elif dataset_name == "tensorflowhub":
|
| 98 |
+
domains = "1. $DOMAIN is inferred from the task description \
|
| 99 |
+
and should include one of {text-sequence-alignment, \
|
| 100 |
+
text-embedding, text-language-model, text-preprocessing, \
|
| 101 |
+
text-classification, text-generation, text-question-answering, \
|
| 102 |
+
text-retrieval-question-answering, text-segmentation, \
|
| 103 |
+
text-to-mel, image-classification, image-feature-vector, \
|
| 104 |
+
image-object-detection, image-segmentation, \
|
| 105 |
+
image-generator, image-pose-detection, image-rnn-agent, \
|
| 106 |
+
image-augmentation, image-classifier, image-style-transfer, \
|
| 107 |
+
image-aesthetic-quality, image-depth-estimation, \
|
| 108 |
+
image-super-resolution, image-deblurring, image-extrapolation, \
|
| 109 |
+
image-text-recognition, image-dehazing, image-deraining, \
|
| 110 |
+
image-enhancemenmt, image-classification-logits, \
|
| 111 |
+
image-frame-interpolation, image-text-detection, image-denoising, \
|
| 112 |
+
image-others, video-classification, video-feature-extraction, \
|
| 113 |
+
video-generation, video-audio-text, video-text, \
|
| 114 |
+
audio-embedding, audio-event-classification, audio-command-detection, \
|
| 115 |
+
audio-paralinguists-classification, audio-speech-to-text, \
|
| 116 |
+
audio-speech-synthesis, audio-synthesis, audio-pitch-extraction}"
|
| 117 |
+
else:
|
| 118 |
+
logger.info("Error: API name is not supported.")
|
| 119 |
+
|
| 120 |
+
prompt = (
|
| 121 |
+
question
|
| 122 |
+
+ "\nWrite a python program in 1 to 2 lines to call API in "
|
| 123 |
+
+ dataset_name
|
| 124 |
+
+ ".\n\nThe answer should follow the format: <<<domain>>> $DOMAIN, \
|
| 125 |
+
<<<api_call>>>: $API_CALL, <<<api_provider>>>: $API_PROVIDER, \
|
| 126 |
+
<<<explanation>>>: $EXPLANATION, <<<code>>>: $CODE}. \
|
| 127 |
+
Here are the requirements:\n"
|
| 128 |
+
+ domains
|
| 129 |
+
+ "\n2. The $API_CALL should have only 1 line of code \
|
| 130 |
+
that calls api.\n 3. The $API_PROVIDER should be the \
|
| 131 |
+
programming framework used.\n4. $EXPLANATION should be \
|
| 132 |
+
a step-by-step explanation.\n5. The $CODE is the python code.\n6. \
|
| 133 |
+
Do not repeat the format in your answer."
|
| 134 |
+
)
|
| 135 |
+
return prompt
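# For illustration: encode_question simply wraps a raw user request with the
# answer-format instructions and the domain list for the chosen dataset.
# The question text below is made up; any natural-language task description works.
example_prompt = encode_question(
    "I need an API that detects objects in a street photo.", "torchhub"
)
# `example_prompt` now ends with the torchhub domain list and the
# <<<domain>>>/<<<api_call>>>/<<<api_provider>>>/... formatting requirements above.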
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class APIBenchBenchmark(BaseBenchmark):
|
| 139 |
+
r"""APIBench Benchmark adopted from `Gorilla: Large Language Model
|
| 140 |
+
Connected with Massive APIs`
|
| 141 |
+
<https://huggingface.co/datasets/gorilla-llm/APIBench>.
|
| 142 |
+
|
| 143 |
+
Args:
|
| 144 |
+
data_dir (str): The directory to save the data.
|
| 145 |
+
save_to (str): The file to save the results.
|
| 146 |
+
processes (int, optional): The number of processes to use.
|
| 147 |
+
(default: :obj:`1`)
|
| 148 |
+
"""
|
| 149 |
+
|
| 150 |
+
# TODO: Integrate retriever (pending)
|
| 151 |
+
|
| 152 |
+
def __init__(
|
| 153 |
+
self,
|
| 154 |
+
data_dir: str,
|
| 155 |
+
save_to: str,
|
| 156 |
+
processes: int = 1,
|
| 157 |
+
):
|
| 158 |
+
r"""Initialize the APIBench benchmark.
|
| 159 |
+
|
| 160 |
+
Args:
|
| 161 |
+
data_dir (str): The directory to save the data.
|
| 162 |
+
save_to (str): The file to save the results.
|
| 163 |
+
processes (int, optional): The number of processes to use for
|
| 164 |
+
parallel processing. (default: :obj:`1`)
|
| 165 |
+
"""
|
| 166 |
+
super().__init__("apibench", data_dir, save_to, processes)
|
| 167 |
+
|
| 168 |
+
def download(self):
|
| 169 |
+
r"""Download the APIBench dataset."""
|
| 170 |
+
from huggingface_hub import snapshot_download
|
| 171 |
+
|
| 172 |
+
snapshot_download(
|
| 173 |
+
repo_id="gorilla-llm/APIBench",
|
| 174 |
+
repo_type="dataset",
|
| 175 |
+
local_dir=self.data_dir,
|
| 176 |
+
local_dir_use_symlinks=True,
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
repo = "ShishirPatil/gorilla"
|
| 180 |
+
subdir = "/gorilla/eval/eval-data/questions"
|
| 181 |
+
data_dir = self.data_dir
|
| 182 |
+
|
| 183 |
+
download_github_subdirectory(repo, subdir, data_dir)
|
| 184 |
+
|
| 185 |
+
def load(self, dataset_name: str, force_download: bool = False): # type: ignore[override]
|
| 186 |
+
r"""Load the APIBench Benchmark dataset.
|
| 187 |
+
|
| 188 |
+
Args:
|
| 189 |
+
dataset_name (str): Name of the specific dataset to be loaded.
|
| 190 |
+
force_download (bool, optional): Whether to force
|
| 191 |
+
download the data. (default: :obj:`False`)
|
| 192 |
+
"""
|
| 193 |
+
|
| 194 |
+
if force_download:
|
| 195 |
+
logger.info("Force downloading data.")
|
| 196 |
+
self.download()
|
| 197 |
+
|
| 198 |
+
def load_json_lines(file_path: Path):
|
| 199 |
+
r"""Helper function to load JSON lines from a file."""
|
| 200 |
+
try:
|
| 201 |
+
with open(file_path, "r") as f:
|
| 202 |
+
return [json.loads(line) for line in f]
|
| 203 |
+
except FileNotFoundError:
|
| 204 |
+
raise FileNotFoundError(f"File not found: {file_path}")
|
| 205 |
+
except json.JSONDecodeError as e:
|
| 206 |
+
raise ValueError(
|
| 207 |
+
f"Error decoding JSON in file {file_path}: {e}"
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
dataset_path = self.data_dir / dataset_name
|
| 211 |
+
if not dataset_path.exists():
|
| 212 |
+
raise FileNotFoundError(
|
| 213 |
+
f"Dataset directory does not exist: {dataset_path}"
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
for label in ['api', 'eval', 'questions']:
|
| 217 |
+
file_name = dataset_mapping[dataset_name][label]
|
| 218 |
+
file_path = (
|
| 219 |
+
dataset_path / file_name
|
| 220 |
+
if label == 'questions'
|
| 221 |
+
else self.data_dir / file_name
|
| 222 |
+
)
|
| 223 |
+
|
| 224 |
+
# Load data based on label type
|
| 225 |
+
if label in ['api', 'questions', 'eval']:
|
| 226 |
+
data = load_json_lines(file_path)
|
| 227 |
+
|
| 228 |
+
if label == 'eval':
|
| 229 |
+
# Extract 'api_data' specifically for eval label
|
| 230 |
+
data = [item['api_data'] for item in data]
|
| 231 |
+
|
| 232 |
+
self._data[label] = data
|
| 233 |
+
else:
|
| 234 |
+
raise ValueError(f"Unknown label: {label}")
|
| 235 |
+
|
| 236 |
+
ast_database = []
|
| 237 |
+
for data in self._data['api']:
|
| 238 |
+
ast_tree = ast_parse(data['api_call'])
|
| 239 |
+
ast_database.append(ast_tree)
|
| 240 |
+
self._data['ast'] = ast_database
|
| 241 |
+
|
| 242 |
+
def run( # type: ignore[override]
|
| 243 |
+
self,
|
| 244 |
+
agent: ChatAgent,
|
| 245 |
+
dataset_name: Literal["huggingface", "tensorflowhub", "torchhub"],
|
| 246 |
+
randomize: bool = False,
|
| 247 |
+
subset: Optional[int] = None,
|
| 248 |
+
) -> Dict[str, Any]:
|
| 249 |
+
r"""Run the benchmark.
|
| 250 |
+
|
| 251 |
+
Args:
|
| 252 |
+
agent (ChatAgent): The agent to run the
|
| 253 |
+
benchmark.
|
| 254 |
+
dataset_name (Literal["huggingface",
|
| 255 |
+
"tensorflowhub", "torchhub"]):
|
| 256 |
+
The dataset to run the benchmark on.
|
| 257 |
+
randomize (bool, optional): Whether to randomize the data.
|
| 258 |
+
(default: :obj:`False`)
|
| 259 |
+
subset (Optional[int], optional): The subset of data to run.
|
| 260 |
+
(default: :obj:`None`)
|
| 261 |
+
"""
|
| 262 |
+
|
| 263 |
+
if dataset_name not in dataset_mapping:
|
| 264 |
+
raise ValueError(f"Invalid value for dataset: {dataset_name}.")
|
| 265 |
+
|
| 266 |
+
logger.info(f"Running APIBench benchmark on {dataset_name}.")
|
| 267 |
+
self.load(dataset_name)
|
| 268 |
+
datas = self._data['questions']
|
| 269 |
+
|
| 270 |
+
# Shuffle and subset data if necessary
|
| 271 |
+
if randomize:
|
| 272 |
+
random.shuffle(datas)
|
| 273 |
+
if subset:
|
| 274 |
+
datas = datas[:subset]
|
| 275 |
+
|
| 276 |
+
logger.info(f"Number of tasks: {len(datas)}")
|
| 277 |
+
|
| 278 |
+
# Initialize results storage
|
| 279 |
+
self._results = []
|
| 280 |
+
|
| 281 |
+
with open(self.save_to, "w") as f:
|
| 282 |
+
for question in tqdm(datas, desc="Running"):
|
| 283 |
+
prompt = encode_question(question["text"], dataset_name)
|
| 284 |
+
msg = BaseMessage.make_user_message(
|
| 285 |
+
role_name="User", content=prompt
|
| 286 |
+
)
|
| 287 |
+
try:
|
| 288 |
+
# Generate response
|
| 289 |
+
responses = agent.step(msg)
|
| 290 |
+
response = responses.msgs[0].content
|
| 291 |
+
api_database = self._data['api']
|
| 292 |
+
qa_pairs = self._data['eval']
|
| 293 |
+
ast_database = self._data['ast']
|
| 294 |
+
question_id = question['question_id']
|
| 295 |
+
|
| 296 |
+
# Evaluate response
|
| 297 |
+
error, correct, hallucination = evaluate_response(
|
| 298 |
+
response,
|
| 299 |
+
question_id,
|
| 300 |
+
dataset_name,
|
| 301 |
+
api_database,
|
| 302 |
+
qa_pairs,
|
| 303 |
+
ast_database,
|
| 304 |
+
)
|
| 305 |
+
self._results.append(
|
| 306 |
+
{
|
| 307 |
+
"question": question,
|
| 308 |
+
"agent_response": response,
|
| 309 |
+
"correct": correct,
|
| 310 |
+
"hallucination": hallucination,
|
| 311 |
+
"error": str(error) if error else None,
|
| 312 |
+
}
|
| 313 |
+
)
|
| 314 |
+
except Exception as e:
|
| 315 |
+
logger.warning(
|
| 316 |
+
f"Error in processing task: {question}: {e}"
|
| 317 |
+
)
|
| 318 |
+
self._results.append(
|
| 319 |
+
{
|
| 320 |
+
"question": question,
|
| 321 |
+
"agent_response": None,
|
| 322 |
+
"correct": False,
|
| 323 |
+
"hallucination": False,
|
| 324 |
+
"error": str(e),
|
| 325 |
+
}
|
| 326 |
+
)
|
| 327 |
+
|
| 328 |
+
agent.reset()
|
| 329 |
+
|
| 330 |
+
f.write(json.dumps(self._results[-1], indent=2) + "\n")
|
| 331 |
+
f.flush()
|
| 332 |
+
|
| 333 |
+
total = len(self._results)
|
| 334 |
+
correct = sum(r["correct"] for r in self.results)
|
| 335 |
+
hallucination = sum(r["hallucination"] for r in self.results)
|
| 336 |
+
|
| 337 |
+
return {
|
| 338 |
+
"total": total,
|
| 339 |
+
"correct": correct,
|
| 340 |
+
"hallucination": hallucination,
|
| 341 |
+
"accuracy": correct / total if total else "N/A",
|
| 342 |
+
"hallucination rate": hallucination / total if total else "N/A",
|
| 343 |
+
}
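# A hedged usage sketch for the class above; the paths are placeholders and
# `agent` is any configured CAMEL ChatAgent.
def run_apibench(agent: ChatAgent) -> Dict[str, Any]:
    benchmark = APIBenchBenchmark(
        data_dir="datasets/apibench", save_to="apibench_results.jsonl"
    )
    benchmark.download()  # one-time fetch of gorilla-llm/APIBench and the question files
    return benchmark.run(agent, "huggingface", randomize=False, subset=50)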
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
# This code is modified from the
|
| 347 |
+
# evaluators in the original repo
|
| 348 |
+
# https://github.com/ShishirPatil/gorilla
|
| 349 |
+
# Get all the subtrees given a root_node
|
| 350 |
+
def get_all_sub_trees(root_node):
|
| 351 |
+
node_stack = []
|
| 352 |
+
sub_tree_sexp_list = []
|
| 353 |
+
depth = 1
|
| 354 |
+
# text = root_node.text
|
| 355 |
+
node_stack.append([root_node, depth])
|
| 356 |
+
while len(node_stack) != 0:
|
| 357 |
+
cur_node, cur_depth = node_stack.pop()
|
| 358 |
+
if cur_node.child_count > 0:
|
| 359 |
+
sub_tree_sexp_list.append(
|
| 360 |
+
[
|
| 361 |
+
str(cur_node),
|
| 362 |
+
cur_depth,
|
| 363 |
+
cur_node,
|
| 364 |
+
cur_node.children[0].text,
|
| 365 |
+
]
|
| 366 |
+
)
|
| 367 |
+
else:
|
| 368 |
+
sub_tree_sexp_list.append(
|
| 369 |
+
[str(cur_node), cur_depth, cur_node, None]
|
| 370 |
+
)
|
| 371 |
+
for child_node in cur_node.children:
|
| 372 |
+
if len(child_node.children) != 0:
|
| 373 |
+
depth = cur_depth + 1
|
| 374 |
+
node_stack.append([child_node, depth])
|
| 375 |
+
return sub_tree_sexp_list
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
# Parse the program into AST trees
|
| 379 |
+
def ast_parse(candidate):
|
| 380 |
+
PY_LANGUAGE = Language(tspython.language())
|
| 381 |
+
parser = Parser(PY_LANGUAGE)
|
| 382 |
+
|
| 383 |
+
candidate_tree = parser.parse(bytes(candidate, "utf8")).root_node
|
| 384 |
+
return candidate_tree
|
| 385 |
+
|
| 386 |
+
|
| 387 |
+
# Get all the arguments in the ast tree
|
| 388 |
+
def get_args(node, dataset_name):
|
| 389 |
+
if node.child_count == 0:
|
| 390 |
+
return []
|
| 391 |
+
args_list = []
|
| 392 |
+
if dataset_name == "huggingface":
|
| 393 |
+
for child in node.children[0].children[0].children[1].children:
|
| 394 |
+
if "=" in child.text.decode():
|
| 395 |
+
args_list.append(child.children[2].text)
|
| 396 |
+
elif (
|
| 397 |
+
child.text.decode() != "("
|
| 398 |
+
and child.text.decode() != ")"
|
| 399 |
+
and child.text.decode() != ","
|
| 400 |
+
):
|
| 401 |
+
args_list.append(child.text)
|
| 402 |
+
elif dataset_name == "tensorflowhub":
|
| 403 |
+
for child in node.children[0].children[0].children[1].children:
|
| 404 |
+
if (
|
| 405 |
+
'model=' in child.text.decode()
|
| 406 |
+
or 'model =' in child.text.decode()
|
| 407 |
+
):
|
| 408 |
+
args_list.append(child.children[2].text)
|
| 409 |
+
elif (
|
| 410 |
+
child.text.decode() != "("
|
| 411 |
+
and child.text.decode() != ")"
|
| 412 |
+
and child.text.decode() != ","
|
| 413 |
+
):
|
| 414 |
+
args_list.append(child.text)
|
| 415 |
+
elif dataset_name == "torchhub":
|
| 416 |
+
for child in node.children[0].children[0].children[1].children:
|
| 417 |
+
if (
|
| 418 |
+
"repo_or_dir" in child.text.decode()
|
| 419 |
+
or "model" in child.text.decode()
|
| 420 |
+
):
|
| 421 |
+
args_list.append(child.children[2].text)
|
| 422 |
+
return args_list
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
# Check if there is an api match
|
| 426 |
+
def ast_check(candidate_subtree_list, base_tree_list, dataset_name):
|
| 427 |
+
for idx, base_tree in enumerate(base_tree_list):
|
| 428 |
+
if base_tree.children[0].children[0].child_count == 0:
|
| 429 |
+
continue
|
| 430 |
+
api_name = base_tree.children[0].children[0].children[0].text
|
| 431 |
+
for candidate_tree in candidate_subtree_list:
|
| 432 |
+
if candidate_tree[3] == api_name:
|
| 433 |
+
break
|
| 434 |
+
# Now we have a sub-tree
|
| 435 |
+
candidate_tree = candidate_tree[2]
|
| 436 |
+
args_list = get_args(base_tree, dataset_name)
|
| 437 |
+
if len(args_list) == 0:
|
| 438 |
+
continue
|
| 439 |
+
ast_match = True
|
| 440 |
+
for arg in args_list:
|
| 441 |
+
if (
|
| 442 |
+
arg.decode().lstrip("'").rstrip("'")
|
| 443 |
+
not in candidate_tree.text.decode()
|
| 444 |
+
):
|
| 445 |
+
ast_match = False
|
| 446 |
+
break
|
| 447 |
+
if ast_match:
|
| 448 |
+
return idx
|
| 449 |
+
return -1
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
def evaluate_response(
|
| 453 |
+
response, question_id, dataset_name, api_database, qa_pairs, ast_database
|
| 454 |
+
):
|
| 455 |
+
try:
|
| 456 |
+
# Index the "api_call" domain
|
| 457 |
+
output = response.split("api_call")
|
| 458 |
+
if len(output) == 1:
|
| 459 |
+
api_call = output[0]
|
| 460 |
+
else:
|
| 461 |
+
# Parse the output
|
| 462 |
+
output = output[1].split("api_provider")[0]
|
| 463 |
+
if ":" not in output:
|
| 464 |
+
start = 0
|
| 465 |
+
else:
|
| 466 |
+
start = output.index(":")
|
| 467 |
+
if ")" not in output:
|
| 468 |
+
end = -2
|
| 469 |
+
else:
|
| 470 |
+
end = output.rindex(")")
|
| 471 |
+
api_call = output[start + 2 : end + 1]
|
| 472 |
+
|
| 473 |
+
try:
|
| 474 |
+
ast_tree = ast_parse(api_call)
|
| 475 |
+
except Exception as parse_error:
|
| 476 |
+
print(f"Error parsing api_call: {api_call}, error: {parse_error}")
|
| 477 |
+
return parse_error, False, False
|
| 478 |
+
# Search for a subtree
|
| 479 |
+
ast_subtree_list = get_all_sub_trees(ast_tree)
|
| 480 |
+
# Check which ast tree is matching
|
| 481 |
+
database_index = ast_check(
|
| 482 |
+
ast_subtree_list, ast_database, dataset_name
|
| 483 |
+
)
|
| 484 |
+
# We cannot index this ast in our database
|
| 485 |
+
if database_index == -1:
|
| 486 |
+
hallucination = True
|
| 487 |
+
correct = False
|
| 488 |
+
# We index our reference api_call
|
| 489 |
+
ref_api_call = api_database[database_index]
|
| 490 |
+
# Check for functionality
|
| 491 |
+
if ref_api_call['domain'] == qa_pairs[question_id - 1]['domain']:
|
| 492 |
+
correct = True
|
| 493 |
+
hallucination = False
|
| 494 |
+
else:
|
| 495 |
+
return None, False, False
|
| 496 |
+
except Exception as e:
|
| 497 |
+
print(f'Error parsing response: {response}, error: {e}')
|
| 498 |
+
return e, False, False
|
| 499 |
+
|
| 500 |
+
return None, correct, hallucination
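# A small, hedged illustration of the AST-matching helpers above. The API call
# strings are invented; in the benchmark they come from the model output and
# from the reference API database assembled in `load`.
def _ast_matching_example() -> int:
    reference_calls = ["pipeline('image-classification', model='resnet50')"]
    reference_trees = [ast_parse(call) for call in reference_calls]

    candidate = "pipeline('image-classification', model='resnet50')"
    candidate_subtrees = get_all_sub_trees(ast_parse(candidate))
    # Returns the index of the matching reference call, or -1 (a hallucination).
    return ast_check(candidate_subtrees, reference_trees, "huggingface")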
|
Paper2Poster/camel/benchmarks/base.py
ADDED
|
@@ -0,0 +1,152 @@
| 1 |
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
#
|
| 6 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 7 |
+
#
|
| 8 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 9 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 10 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 11 |
+
# See the License for the specific language governing permissions and
|
| 12 |
+
# limitations under the License.
|
| 13 |
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
| 14 |
+
|
| 15 |
+
import logging
|
| 16 |
+
from abc import ABC, abstractmethod
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
from typing import Any, Dict, List, Literal, Optional
|
| 19 |
+
|
| 20 |
+
from camel.agents import ChatAgent
|
| 21 |
+
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class BaseBenchmark(ABC):
|
| 26 |
+
r"""Base class for benchmarks.
|
| 27 |
+
|
| 28 |
+
Attributes:
|
| 29 |
+
name (str): Name of the benchmark.
|
| 30 |
+
data_dir (str): Path to the data directory.
|
| 31 |
+
save_to (str): Path to save the results.
|
| 32 |
+
processes (int): Number of processes to use for parallel
|
| 33 |
+
processing. (default: :obj:`1`)
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
def __init__(
|
| 37 |
+
self, name: str, data_dir: str, save_to: str, processes: int = 1
|
| 38 |
+
):
|
| 39 |
+
r"""Initialize the benchmark.
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
name (str): Name of the benchmark.
|
| 43 |
+
data_dir (str): Path to the data directory.
|
| 44 |
+
save_to (str): Path to save the results.
|
| 45 |
+
processes (int): Number of processes to use for parallel
|
| 46 |
+
processing. (default: :obj:`1`)
|
| 47 |
+
|
| 48 |
+
"""
|
| 49 |
+
self.name = name
|
| 50 |
+
self.data_dir = Path(data_dir)
|
| 51 |
+
self.processes = processes
|
| 52 |
+
self.save_to = save_to
|
| 53 |
+
if not self.data_dir.exists():
|
| 54 |
+
logger.info(
|
| 55 |
+
f"Data directory {data_dir} does not exist. Creating it."
|
| 56 |
+
)
|
| 57 |
+
self.data_dir.mkdir(parents=True, exist_ok=True)
|
| 58 |
+
if not self.data_dir.is_dir():
|
| 59 |
+
raise NotADirectoryError(
|
| 60 |
+
f"Data directory {data_dir} is not a directory"
|
| 61 |
+
)
|
| 62 |
+
self._data: Dict[str, List[Dict[str, Any]]] = dict()
|
| 63 |
+
self._results: List[Dict[str, Any]] = []
|
| 64 |
+
|
| 65 |
+
@abstractmethod
|
| 66 |
+
def download(self) -> "BaseBenchmark":
|
| 67 |
+
r"""Download the benchmark data.
|
| 68 |
+
|
| 69 |
+
Returns:
|
| 70 |
+
BaseBenchmark: The benchmark instance.
|
| 71 |
+
"""
|
| 72 |
+
pass
|
| 73 |
+
|
| 74 |
+
@abstractmethod
|
| 75 |
+
def load(self, force_download: bool = False) -> "BaseBenchmark":
|
| 76 |
+
r"""Load the benchmark data.
|
| 77 |
+
|
| 78 |
+
Args:
|
| 79 |
+
force_download (bool): Whether to force download the data.
|
| 80 |
+
|
| 81 |
+
Returns:
|
| 82 |
+
BaseBenchmark: The benchmark instance.
|
| 83 |
+
"""
|
| 84 |
+
pass
|
| 85 |
+
|
| 86 |
+
@property
|
| 87 |
+
def train(self) -> List[Dict[str, Any]]:
|
| 88 |
+
r"""Get the training data.
|
| 89 |
+
|
| 90 |
+
Returns:
|
| 91 |
+
List[Dict[str, Any]]: The training data.
|
| 92 |
+
"""
|
| 93 |
+
if not self._data:
|
| 94 |
+
logger.info("Data not loaded. Loading data.")
|
| 95 |
+
self.load()
|
| 96 |
+
return self._data["train"]
|
| 97 |
+
|
| 98 |
+
@property
|
| 99 |
+
def valid(self) -> List[Dict[str, Any]]:
|
| 100 |
+
r"""Get the validation data.
|
| 101 |
+
|
| 102 |
+
Returns:
|
| 103 |
+
List[Dict[str, Any]]: The validation data.
|
| 104 |
+
"""
|
| 105 |
+
if not self._data:
|
| 106 |
+
logger.info("Data not loaded. Loading data.")
|
| 107 |
+
self.load()
|
| 108 |
+
return self._data["valid"]
|
| 109 |
+
|
| 110 |
+
@property
|
| 111 |
+
def test(self) -> List[Dict[str, Any]]:
|
| 112 |
+
r"""Get the test data.
|
| 113 |
+
|
| 114 |
+
Returns:
|
| 115 |
+
List[Dict[str, Any]]: The test data.
|
| 116 |
+
"""
|
| 117 |
+
if not self._data:
|
| 118 |
+
logger.info("Data not loaded. Loading data.")
|
| 119 |
+
self.load()
|
| 120 |
+
return self._data["test"]
|
| 121 |
+
|
| 122 |
+
@abstractmethod
|
| 123 |
+
def run(
|
| 124 |
+
self,
|
| 125 |
+
agent: ChatAgent,
|
| 126 |
+
on: Literal["train", "valid", "test"],
|
| 127 |
+
randomize: bool = False,
|
| 128 |
+
subset: Optional[int] = None,
|
| 129 |
+
*args,
|
| 130 |
+
**kwargs,
|
| 131 |
+
) -> "BaseBenchmark":
|
| 132 |
+
r"""Run the benchmark.
|
| 133 |
+
|
| 134 |
+
Args:
|
| 135 |
+
agent (ChatAgent): The chat agent.
|
| 136 |
+
on (str): The data split to run the benchmark on.
|
| 137 |
+
randomize (bool): Whether to randomize the data.
|
| 138 |
+
subset (int): The subset of the data to run the benchmark on.
|
| 139 |
+
|
| 140 |
+
Returns:
|
| 141 |
+
BaseBenchmark: The benchmark instance.
|
| 142 |
+
"""
|
| 143 |
+
pass
|
| 144 |
+
|
| 145 |
+
@property
|
| 146 |
+
def results(self) -> List[Dict[str, Any]]:
|
| 147 |
+
r"""Get the results.
|
| 148 |
+
|
| 149 |
+
Returns:
|
| 150 |
+
List[Dict[str, Any]]: The results.
|
| 151 |
+
"""
|
| 152 |
+
return self._results
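# A minimal, hedged sketch of a concrete subclass of BaseBenchmark. The data is
# fabricated in place so the example stays self-contained; real subclasses
# download and parse actual datasets in `download`/`load`. The agent is driven
# with a BaseMessage, mirroring the benchmarks elsewhere in this package, and
# `randomize` is ignored in this toy example.
from camel.messages import BaseMessage


class ToyArithmeticBenchmark(BaseBenchmark):
    def download(self) -> "BaseBenchmark":
        return self  # nothing to fetch for the toy data

    def load(self, force_download: bool = False) -> "BaseBenchmark":
        self._data = {
            "train": [{"question": "1 + 1 = ?", "answer": "2"}],
            "valid": [{"question": "2 + 2 = ?", "answer": "4"}],
            "test": [{"question": "3 + 3 = ?", "answer": "6"}],
        }
        return self

    def run(
        self,
        agent: ChatAgent,
        on: Literal["train", "valid", "test"] = "test",
        randomize: bool = False,
        subset: Optional[int] = None,
        *args,
        **kwargs,
    ) -> "BaseBenchmark":
        for item in self._data[on][:subset]:
            msg = BaseMessage.make_user_message(
                role_name="User", content=item["question"]
            )
            reply = agent.step(msg).msgs[0].content
            self._results.append(
                {"question": item["question"], "correct": item["answer"] in reply}
            )
        return self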
|
Paper2Poster/camel/benchmarks/gaia.py
ADDED
|
@@ -0,0 +1,478 @@
| 1 |
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
| 2 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 3 |
+
# you may not use this file except in compliance with the License.
|
| 4 |
+
# You may obtain a copy of the License at
|
| 5 |
+
#
|
| 6 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 7 |
+
#
|
| 8 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 9 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 10 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 11 |
+
# See the License for the specific language governing permissions and
|
| 12 |
+
# limitations under the License.
|
| 13 |
+
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
|
| 14 |
+
|
| 15 |
+
import json
|
| 16 |
+
import logging
|
| 17 |
+
import os
|
| 18 |
+
import random
|
| 19 |
+
import re
|
| 20 |
+
import string
|
| 21 |
+
import uuid
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
|
| 24 |
+
|
| 25 |
+
from tqdm import tqdm
|
| 26 |
+
|
| 27 |
+
from camel.agents import ChatAgent
|
| 28 |
+
from camel.benchmarks.base import BaseBenchmark
|
| 29 |
+
from camel.messages import BaseMessage
|
| 30 |
+
from camel.retrievers.auto_retriever import AutoRetriever
|
| 31 |
+
|
| 32 |
+
logger = logging.getLogger(__name__)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class RetrieverProtocol(Protocol):
|
| 36 |
+
r"""Protocol for the retriever class. Any retriever class implementing
|
| 37 |
+
this protocol can be used in the benchmark class.
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
def retrieve(
|
| 41 |
+
self, query: str, contents: List[str], **kwargs: Dict[str, Any]
|
| 42 |
+
) -> Dict[str, Any]:
|
| 43 |
+
r"""Retrieve the relevant content for the query.
|
| 44 |
+
|
| 45 |
+
Args:
|
| 46 |
+
query (str): The query to retrieve the content for.
|
| 47 |
+
contents (List[str]): The list of contents to search in.
|
| 48 |
+
**kwargs (Dict[str, Any]): Additional keyword arguments.
|
| 49 |
+
|
| 50 |
+
Returns:
|
| 51 |
+
Dict[str, Any]: The relevant content for the query.
|
| 52 |
+
"""
|
| 53 |
+
...
|
| 54 |
+
|
| 55 |
+
def reset(self, **kwargs) -> bool:
|
| 56 |
+
r"""Reset the retriever.
|
| 57 |
+
Some benchmarks may require resetting the retriever
|
| 58 |
+
after each query.
|
| 59 |
+
|
| 60 |
+
Args:
|
| 61 |
+
**kwargs: Additional keyword arguments.
|
| 62 |
+
|
| 63 |
+
Returns:
|
| 64 |
+
bool: True if the reset was successful, False otherwise.
|
| 65 |
+
"""
|
| 66 |
+
...
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class DefaultGAIARetriever(AutoRetriever):
|
| 70 |
+
r"""Default retriever for the GAIA benchmark.
|
| 71 |
+
This retriever uses AutoRetriever in camel to retrieve the content based on
|
| 72 |
+
the query.
|
| 73 |
+
"""
|
| 74 |
+
|
| 75 |
+
def retrieve(
|
| 76 |
+
self, query: str, contents: List[str], **kwargs: Any
|
| 77 |
+
) -> Dict[str, Any]:
|
| 78 |
+
r"""Retrieve the content based on the query.
|
| 79 |
+
|
| 80 |
+
Args:
|
| 81 |
+
query (str): The query to search for.
|
| 82 |
+
contents (List[str]): The list of contents to search from.
|
| 83 |
+
**kwargs (Any): The keyword arguments to pass to the
|
| 84 |
+
retriever.
|
| 85 |
+
|
| 86 |
+
Returns:
|
| 87 |
+
Dict[str, Any]: The retrieved content.
|
| 88 |
+
"""
|
| 89 |
+
return self.run_vector_retriever(query, contents, **kwargs) # type: ignore[arg-type]
|
| 90 |
+
|
| 91 |
+
def reset(self, **kwargs: Any) -> bool:
|
| 92 |
+
r"""Reset the retriever.
|
| 93 |
+
|
| 94 |
+
Args:
|
| 95 |
+
**kwargs (Any): The keyword arguments to pass to the
|
| 96 |
+
retriever.
|
| 97 |
+
|
| 98 |
+
Returns:
|
| 99 |
+
bool: Whether the reset was successful.
|
| 100 |
+
"""
|
| 101 |
+
path = Path(self.vector_storage_local_path or os.getcwd())
|
| 102 |
+
task_id = str(kwargs.get("task_id", uuid.uuid4()))
|
| 103 |
+
retriever_dir = path / task_id
|
| 104 |
+
if not retriever_dir.exists():
|
| 105 |
+
try:
|
| 106 |
+
retriever_dir.mkdir(parents=True)
|
| 107 |
+
except Exception as e:
|
| 108 |
+
logger.error(
|
| 109 |
+
"Error in creating directory: " + f"{retriever_dir}: {e!s}"
|
| 110 |
+
)
|
| 111 |
+
return False
|
| 112 |
+
self.vector_storage_local_path = str(retriever_dir)
|
| 113 |
+
return True
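# Any object with this shape satisfies RetrieverProtocol and can be passed to
# GAIABenchmark below. The "retrieval" here is a deliberately naive keyword
# filter, included only to show the expected inputs and the
# {"Retrieved Context": [{"text": ...}]} output consumed by the benchmark.
class KeywordRetriever:
    def retrieve(
        self, query: str, contents: List[str], **kwargs: Any
    ) -> Dict[str, Any]:
        terms = query.lower().split()
        hits = [c for c in contents if any(t in str(c).lower() for t in terms)]
        return {"Retrieved Context": [{"text": str(h)} for h in hits]}

    def reset(self, **kwargs: Any) -> bool:
        return True  # nothing to clear for this stateless retriever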
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
class GAIABenchmark(BaseBenchmark):
|
| 117 |
+
r"""GAIA Benchmark adapted from `"GAIA: a benchmark for General AI
|
| 118 |
+
Assistants"
|
| 119 |
+
<https://huggingface.co/datasets/gaia-benchmark/GAIA>`_.
|
| 120 |
+
|
| 121 |
+
Args:
|
| 122 |
+
data_dir (str): The directory to save the data.
|
| 123 |
+
save_to (str): The file to save the results.
|
| 124 |
+
retriever (Optional[RetrieverProtocol]): The retriever to use.
|
| 125 |
+
(default: :obj:`None`)
|
| 126 |
+
processes (int, optional): The number of processes to use.
|
| 127 |
+
(default: :obj:`1`)
|
| 128 |
+
"""
|
| 129 |
+
|
| 130 |
+
def __init__(
|
| 131 |
+
self,
|
| 132 |
+
data_dir: str,
|
| 133 |
+
save_to: str,
|
| 134 |
+
retriever: Optional[RetrieverProtocol] = None,
|
| 135 |
+
processes: int = 1,
|
| 136 |
+
):
|
| 137 |
+
r"""Initialize the GAIA benchmark.
|
| 138 |
+
|
| 139 |
+
Args:
|
| 140 |
+
data_dir (str): The directory to save the data.
|
| 141 |
+
save_to (str): The file to save the results.
|
| 142 |
+
retriever (Optional[RetrieverProtocol], optional): The retriever to
|
| 143 |
+
use. (default: :obj:`None`)
|
| 144 |
+
processes (int, optional): The number of processes to use for
|
| 145 |
+
parallel processing. (default: :obj:`1`)
|
| 146 |
+
"""
|
| 147 |
+
super().__init__("gaia", data_dir, save_to, processes)
|
| 148 |
+
self.retriever = retriever or DefaultGAIARetriever()
|
| 149 |
+
|
| 150 |
+
def download(self):
|
| 151 |
+
r"""Download the GAIA dataset."""
|
| 152 |
+
from huggingface_hub import snapshot_download
|
| 153 |
+
|
| 154 |
+
snapshot_download(
|
| 155 |
+
repo_id="gaia-benchmark/GAIA",
|
| 156 |
+
repo_type="dataset",
|
| 157 |
+
local_dir=self.data_dir,
|
| 158 |
+
local_dir_use_symlinks=True,
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
def load(self, force_download=False):
|
| 162 |
+
r"""Load the GAIA dataset.
|
| 163 |
+
|
| 164 |
+
Args:
|
| 165 |
+
force_download (bool, optional): Whether to
|
| 166 |
+
force download the data.
|
| 167 |
+
"""
|
| 168 |
+
if force_download:
|
| 169 |
+
logger.info("Force downloading data.")
|
| 170 |
+
self.download()
|
| 171 |
+
|
| 172 |
+
# Define validation and test directories
|
| 173 |
+
valid_dir = self.data_dir / "2023/validation"
|
| 174 |
+
test_dir = self.data_dir / "2023/test"
|
| 175 |
+
|
| 176 |
+
# Check if directories exist; if not, download the data
|
| 177 |
+
if not valid_dir.is_dir() or not test_dir.is_dir():
|
| 178 |
+
logger.info("Data not found. Downloading data.")
|
| 179 |
+
self.download()
|
| 180 |
+
|
| 181 |
+
# Load metadata for both validation and test datasets
|
| 182 |
+
for path, label in zip([valid_dir, test_dir], ["valid", "test"]):
|
| 183 |
+
self._data[label] = []
|
| 184 |
+
with open(path / "metadata.jsonl", "r") as f:
|
| 185 |
+
lines = f.readlines()
|
| 186 |
+
for line in lines:
|
| 187 |
+
data = json.loads(line)
|
| 188 |
+
if data["task_id"] == "0-0-0-0-0":
|
| 189 |
+
continue
|
| 190 |
+
if data["file_name"]:
|
| 191 |
+
data["file_name"] = path / data["file_name"]
|
| 192 |
+
self._data[label].append(data)
|
| 193 |
+
return self
|
| 194 |
+
|
| 195 |
+
@property
|
| 196 |
+
def train(self):
|
| 197 |
+
r"""Get the training set."""
|
| 198 |
+
raise NotImplementedError("GAIA does not have a training set.")
|
| 199 |
+
|
| 200 |
+
def run( # type: ignore[override]
|
| 201 |
+
self,
|
| 202 |
+
agent: ChatAgent,
|
| 203 |
+
on: Literal["train", "valid", "test"],
|
| 204 |
+
level: Union[int, List[int], Literal["all"]],
|
| 205 |
+
randomize: bool = False,
|
| 206 |
+
subset: Optional[int] = None,
|
| 207 |
+
) -> Dict[str, Any]:
|
| 208 |
+
r"""Run the benchmark.
|
| 209 |
+
|
| 210 |
+
Args:
|
| 211 |
+
agent (ChatAgent): The agent to run the benchmark.
|
| 212 |
+
on (Literal["valid", "test"]): The set to run the benchmark.
|
| 213 |
+
level (Union[int, List[int], Literal["all"]]): The level to run
|
| 214 |
+
the benchmark.
|
| 215 |
+
randomize (bool, optional): Whether to randomize the data.
|
| 216 |
+
(default: :obj:`False`)
|
| 217 |
+
subset (Optional[int], optional): The subset of data to run.
|
| 218 |
+
(default: :obj:`None`)
|
| 219 |
+
|
| 220 |
+
Returns:
|
| 221 |
+
Dict[str, Any]: The results of the benchmark.
|
| 222 |
+
"""
|
| 223 |
+
# Validate inputs
|
| 224 |
+
if on not in ["valid", "test"]:
|
| 225 |
+
raise ValueError(
|
| 226 |
+
f"Invalid value for `on`: {on}, expected 'valid' or 'test'."
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
+
levels = (
|
| 230 |
+
[1, 2, 3]
|
| 231 |
+
if level == "all"
|
| 232 |
+
else [level]
|
| 233 |
+
if isinstance(level, int)
|
| 234 |
+
else level
|
| 235 |
+
)
|
| 236 |
+
if not all(
|
| 237 |
+
isinstance(level, int) and level in [1, 2, 3] for level in levels
|
| 238 |
+
):
|
| 239 |
+
raise ValueError(
|
| 240 |
+
f"Invalid value for `level`: {level}, expected 1, 2, 3 "
|
| 241 |
+
"or 'all'."
|
| 242 |
+
)
|
| 243 |
+
|
| 244 |
+
logger.info(f"Running benchmark on {on} set at levels {levels}.")
|
| 245 |
+
datas = [data for data in self._data[on] if data["Level"] in levels]
|
| 246 |
+
|
| 247 |
+
# Shuffle and subset data if necessary
|
| 248 |
+
if randomize:
|
| 249 |
+
random.shuffle(datas)
|
| 250 |
+
if subset:
|
| 251 |
+
datas = datas[:subset]
|
| 252 |
+
|
| 253 |
+
logger.info(f"Number of tasks: {len(datas)}")
|
| 254 |
+
|
| 255 |
+
# Initialize results storage
|
| 256 |
+
self._results = []
|
| 257 |
+
|
| 258 |
+
# Process tasks
|
| 259 |
+
with open(self.save_to, "w") as f:
|
| 260 |
+
for task in tqdm(datas, desc="Running"):
|
| 261 |
+
if not self._prepare_task(task):
|
| 262 |
+
continue
|
| 263 |
+
|
| 264 |
+
try:
|
| 265 |
+
result = agent.step(self._create_user_message(task))
|
| 266 |
+
self._process_result(agent, task, result, f)
|
| 267 |
+
except Exception as e:
|
| 268 |
+
self._handle_error(task, e, f)
|
| 269 |
+
finally:
|
| 270 |
+
agent.reset()
|
| 271 |
+
|
| 272 |
+
return self._generate_summary()
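# A hedged usage sketch for the class above; the paths are placeholders and the
# agent construction is deliberately omitted because it depends on the model
# backend being used.
def run_gaia_level1(agent: ChatAgent) -> Dict[str, Any]:
    benchmark = GAIABenchmark(data_dir="datasets/gaia", save_to="gaia_results.jsonl")
    benchmark.load()  # downloads gaia-benchmark/GAIA if the files are missing
    return benchmark.run(agent, on="valid", level=1, subset=10)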
|
| 273 |
+
|
| 274 |
+
def _prepare_task(self, task: Dict[str, Any]) -> bool:
|
| 275 |
+
r"""Prepare the task by validating and enriching its data."""
|
| 276 |
+
if task["file_name"]:
|
| 277 |
+
file_path = Path(task["file_name"])
|
| 278 |
+
if not file_path.exists():
|
| 279 |
+
logger.info(
|
| 280 |
+
f"Skipping task because file not found: {file_path}"
|
| 281 |
+
)
|
| 282 |
+
return False
|
| 283 |
+
if file_path.suffix in [".pdf", ".docx", ".doc", ".txt"]:
|
| 284 |
+
if not self.retriever.reset(task_id=task["task_id"]):
|
| 285 |
+
return False
|
| 286 |
+
retrieved_info = self.retriever.retrieve(
|
| 287 |
+
query=task["Question"], contents=[task["file_name"]]
|
| 288 |
+
)
|
| 289 |
+
retrieved_content = [
|
| 290 |
+
item["text"]
|
| 291 |
+
for item in retrieved_info.get("Retrieved Context", [])
|
| 292 |
+
]
|
| 293 |
+
if retrieved_content:
|
| 294 |
+
task["Question"] += "\n" + "\n".join(retrieved_content)
|
| 295 |
+
else:
|
| 296 |
+
logger.info(
|
| 297 |
+
f"Skipping task due to unsupported file "
|
| 298 |
+
f"format: {file_path.suffix}"
|
| 299 |
+
)
|
| 300 |
+
return False
|
| 301 |
+
return True
|
| 302 |
+
|
| 303 |
+
def _create_user_message(self, task: Dict[str, Any]) -> BaseMessage:
|
| 304 |
+
r"""Create a user message from a task."""
|
| 305 |
+
return BaseMessage.make_user_message(
|
| 306 |
+
role_name="User",
|
| 307 |
+
content=task["Question"],
|
| 308 |
+
)
|
| 309 |
+
|
| 310 |
+
def _process_result(
|
| 311 |
+
self,
|
| 312 |
+
agent: ChatAgent,
|
| 313 |
+
task: Dict[str, Any],
|
| 314 |
+
result: Any,
|
| 315 |
+
file_obj: Any,
|
| 316 |
+
) -> None:
|
| 317 |
+
r"""Process and store the result of a task."""
|
| 318 |
+
model_answer = self.get_final_answer(result.msgs[0].content)
|
| 319 |
+
final_answer = task["Final answer"]
|
| 320 |
+
score = self.question_scorer(model_answer, final_answer)
|
| 321 |
+
tool_calls = result.info.get("tool_calls", [])
|
| 322 |
+
|
| 323 |
+
result_data = {
|
| 324 |
+
"task_id": task["task_id"],
|
| 325 |
+
"question": task["Question"],
|
| 326 |
+
"level": task["Level"],
|
| 327 |
+
"model_answer": model_answer,
|
| 328 |
+
"ground_truth": final_answer,
|
| 329 |
+
"tool_calls": [tool.model_dump() for tool in tool_calls],
|
| 330 |
+
"error": None,
|
| 331 |
+
"score": int(score),
|
| 332 |
+
"history": agent.memory.get_context(),
|
| 333 |
+
}
|
| 334 |
+
self._results.append(result_data)
|
| 335 |
+
file_obj.write(json.dumps(result_data, indent=2) + "\n")
|
| 336 |
+
file_obj.flush()
|
| 337 |
+
|
| 338 |
+
def _handle_error(
|
| 339 |
+
self, task: Dict[str, Any], error: Exception, file_obj: Any
|
| 340 |
+
) -> None:
|
| 341 |
+
r"""Handle errors encountered during task processing."""
|
| 342 |
+
logger.warning(f"Error processing task {task['task_id']}: {error}")
|
| 343 |
+
error_data = {
|
| 344 |
+
"task_id": task["task_id"],
|
| 345 |
+
"question": task["Question"],
|
| 346 |
+
"level": task["Level"],
|
| 347 |
+
"model_answer": "ERROR",
|
| 348 |
+
"ground_truth": task["Final answer"],
|
| 349 |
+
"tool_calls": [],
|
| 350 |
+
"error": str(error),
|
| 351 |
+
"score": 0,
|
| 352 |
+
}
|
| 353 |
+
self._results.append(error_data)
|
| 354 |
+
file_obj.write(json.dumps(error_data, indent=2) + "\n")
|
| 355 |
+
file_obj.flush()
|
| 356 |
+
|
| 357 |
+
def _generate_summary(self) -> Dict[str, Any]:
|
| 358 |
+
r"""Generate and return a summary of the benchmark results."""
|
| 359 |
+
return {
|
| 360 |
+
"total": len(self._results),
|
| 361 |
+
"correct": sum(result["score"] for result in self._results),
|
| 362 |
+
"results": self._results,
|
| 363 |
+
}
|
| 364 |
+
|
| 365 |
+
def question_scorer(self, model_answer: str, ground_truth: str) -> bool:
|
| 366 |
+
r"""Scorer for the GAIA benchmark.
|
| 367 |
+
https://huggingface.co/spaces/gaia-benchmark/leaderboard/blob/main/
|
| 368 |
+
scorer.py
|
| 369 |
+
|
| 370 |
+
Args:
|
| 371 |
+
model_answer (str): The model answer.
|
| 372 |
+
ground_truth (str): The ground truth answer.
|
| 373 |
+
|
| 374 |
+
Returns:
|
| 375 |
+
bool: The score of the model
|
| 376 |
+
"""
|
| 377 |
+
|
| 378 |
+
def is_float(element: Any) -> bool:
|
| 379 |
+
try:
|
| 380 |
+
float(element)
|
| 381 |
+
return True
|
| 382 |
+
except ValueError:
|
| 383 |
+
return False
|
| 384 |
+
|
| 385 |
+
if is_float(ground_truth):
|
| 386 |
+
logger.info(f"Evaluating {model_answer} as a number.")
|
| 387 |
+
normalized_answer = self.normalize_number_str(model_answer)
|
| 388 |
+
return normalized_answer == float(ground_truth)
|
| 389 |
+
|
| 390 |
+
elif any(char in ground_truth for char in [",", ";"]):
|
| 391 |
+
logger.info(
|
| 392 |
+
f"Evaluating {model_answer} as a comma separated list."
|
| 393 |
+
)
|
| 394 |
+
gt_elems = self.split_string(ground_truth)
|
| 395 |
+
ma_elems = self.split_string(model_answer)
|
| 396 |
+
|
| 397 |
+
if len(gt_elems) != len(ma_elems):
|
| 398 |
+
logger.warning(
|
| 399 |
+
"Answer lists have different lengths, returning False.",
|
| 400 |
+
|
| 401 |
+
)
|
| 402 |
+
return False
|
| 403 |
+
|
| 404 |
+
comparisons = []
|
| 405 |
+
for ma_elem, gt_elem in zip(ma_elems, gt_elems):
|
| 406 |
+
if is_float(gt_elem):
|
| 407 |
+
normalized_ma_elem = self.normalize_number_str(ma_elem)
|
| 408 |
+
comparisons.append(normalized_ma_elem == float(gt_elem))
|
| 409 |
+
else:
|
| 410 |
+
ma_elem = self.normalize_str(ma_elem, remove_punct=False)
|
| 411 |
+
gt_elem = self.normalize_str(gt_elem, remove_punct=False)
|
| 412 |
+
comparisons.append(ma_elem == gt_elem)
|
| 413 |
+
return all(comparisons)
|
| 414 |
+
else:
|
| 415 |
+
logger.info(f"Evaluating {model_answer} as a string.")
|
| 416 |
+
ma_elem = self.normalize_str(model_answer)
|
| 417 |
+
gt_elem = self.normalize_str(ground_truth)
|
| 418 |
+
return ma_elem == gt_elem
|
| 419 |
+
|
| 420 |
+
def normalize_number_str(self, number_str: str) -> float:
|
| 421 |
+
for char in ["$", "%", ","]:
|
| 422 |
+
number_str = number_str.replace(char, "")
|
| 423 |
+
try:
|
| 424 |
+
return float(number_str)
|
| 425 |
+
except ValueError:
|
| 426 |
+
logger.error(
|
| 427 |
+
f"String {number_str} cannot be normalized to number str."
|
| 428 |
+
)
|
| 429 |
+
return float("inf")
|
| 430 |
+
|
| 431 |
+
def split_string(
|
| 432 |
+
self, s: str, char_list: Optional[List[str]] = None
|
| 433 |
+
) -> list[str]:
|
| 434 |
+
r"""Split a string based on a list of characters.
|
| 435 |
+
|
| 436 |
+
Args:
|
| 437 |
+
s (str): The string to split.
|
| 438 |
+
char_list (Optional[List[str]], optional): The
|
| 439 |
+
list of characters to split on.
|
| 440 |
+
(default: :obj:`None`)
|
| 441 |
+
"""
|
| 442 |
+
if char_list is None:
|
| 443 |
+
char_list = [",", ";"]
|
| 444 |
+
pattern = f"[{''.join(char_list)}]"
|
| 445 |
+
return re.split(pattern, s)
|
| 446 |
+
|
| 447 |
+
def normalize_str(self, input_str, remove_punct=True) -> str:
|
| 448 |
+
r"""Normalize a string.
|
| 449 |
+
|
| 450 |
+
Args:
|
| 451 |
+
input_str: The input string to normalize.
|
| 452 |
+
remove_punct: Whether to remove punctuation.
|
| 453 |
+
|
| 454 |
+
Returns:
|
| 455 |
+
str: The normalized string.
|
| 456 |
+
"""
|
| 457 |
+
no_spaces = re.sub(r"\s", "", input_str)
|
| 458 |
+
if remove_punct:
|
| 459 |
+
translator = str.maketrans("", "", string.punctuation)
|
| 460 |
+
return no_spaces.lower().translate(translator)
|
| 461 |
+
else:
|
| 462 |
+
return no_spaces.lower()
|
| 463 |
+
|
| 464 |
+
def get_final_answer(self, content: str) -> str:
|
| 465 |
+
r"""Get the final answer from the content.
|
| 466 |
+
|
| 467 |
+
Args:
|
| 468 |
+
content (str): The content to extract the final answer from.
|
| 469 |
+
|
| 470 |
+
Returns:
|
| 471 |
+
str: The final answer.
|
| 472 |
+
"""
|
| 473 |
+
final_answer_index = content.find("FINAL ANSWER")
|
| 474 |
+
if final_answer_index == -1:
|
| 475 |
+
return "FINAL ANSWER not found"
|
| 476 |
+
start_index = final_answer_index + len("FINAL ANSWER: ")
|
| 477 |
+
final_answer_content = content[start_index:].strip()
|
| 478 |
+
return final_answer_content
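# A small, hedged illustration of the scoring helpers above; the answer strings
# are invented and the data/save paths are placeholders.
def _scoring_example() -> None:
    benchmark = GAIABenchmark(
        data_dir="datasets/gaia", save_to="gaia_results.jsonl"
    )
    content = "Some chain of reasoning...\nFINAL ANSWER: 42"
    answer = benchmark.get_final_answer(content)  # -> "42"
    assert benchmark.question_scorer(answer, "42")  # numeric comparison
    assert not benchmark.question_scorer("3, 5", "3; 4")  # element-wise mismatch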
|
Paper2Poster/camel/benchmarks/nexus.py
ADDED
|
@@ -0,0 +1,518 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

import ast
import json
import logging
import os
import random
import textwrap
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

import pandas as pd
from datasets import load_dataset
from tqdm import tqdm

from camel.agents import ChatAgent
from camel.benchmarks.base import BaseBenchmark
from camel.messages import BaseMessage

logger = logging.getLogger(__name__)


# Define the data class
@dataclass
class NexusSample:
    r"""Nexus benchmark dataset sample."""

    input: str
    output: str


@dataclass
class NexusTool:
    r"""Nexus benchmark tool"""

    function_calls: str
    descriptions: str


dataset_mapping = {
    "NVDLibrary": "Nexusflow/NVDLibraryBenchmark",
    "VirusTotal": "Nexusflow/VirusTotalBenchmark",
    "PlacesAPI": "Nexusflow/PlacesAPIBenchmark",
    "ClimateAPI": "Nexusflow/ClimateAPIBenchmark",
    "OTX": "Nexusflow/OTXAPIBenchmark",
    "VirusTotal-NestedCalls": "Nexusflow/vt_multiapi",
    "VirusTotal-ParallelCalls": "Nexusflow/vt_multiapi",
    "NVDLibrary-NestedCalls": "Nexusflow/CVECPEAPIBenchmark",
}

TOOL_CALLING_PROMPT = """
You are given multiple functions and a user query.

Please proceed with generating a function call for the function \
with the proper arguments that best answers the given prompt.

Respond with nothing but the function call ONLY, such that I can \
directly execute your function call without any post processing \
necessary from my end. Do not use variables.
If there are more than two function calls, separate them with a semicolon (;).

{tools}

Question: {input}
"""


class NexusBenchmark(BaseBenchmark):
    r"""Nexus Function Calling Benchmark adapted from `NexusRaven V2
    Function Calling Benchmark`
    <https://huggingface.co/collections/Nexusflow/nexusraven-v2-function-calling-benchmark-657a597fb84dbe7a09ebfc3e>.

    Args:
        data_dir (str): The directory to save the data.
        save_to (str): The file to save the results.
        processes (int, optional): The number of processes to use.
            (default: :obj:`1`)
    """

    def __init__(
        self,
        data_dir: str,
        save_to: str,
        processes: int = 1,
    ):
        r"""Initialize the Nexus Function Calling benchmark.

        Args:
            data_dir (str): The directory to save the data.
            save_to (str): The file to save the results.
            processes (int, optional): The number of processes to use for
                parallel processing. (default: :obj:`1`)
        """
        super().__init__("nexus", data_dir, save_to, processes)
        self._data: List[NexusSample] = []  # type: ignore[assignment]

    def download(self):
        r"""Download the Nexus Functional Calling Benchmark dataset."""
        from huggingface_hub import snapshot_download

        for dataset_name, repo_id in dataset_mapping.items():
            local_dir = self.data_dir / dataset_name
            snapshot_download(
                repo_id=repo_id,
                repo_type="dataset",
                local_dir=local_dir,
                local_dir_use_symlinks=True,
            )

    def load(self, dataset_name: str, force_download: bool = False):  # type: ignore[override]
        r"""Load the Nexus Benchmark dataset.

        Args:
            dataset_name (str): Name of the specific dataset to be loaded.
            force_download (bool): Whether to force download the data.
        """

        def _load_csv_data(dataset_dir: Path) -> List:
            r"""Load datasets from CSV files."""
            dataset = []
            for file_name in os.listdir(dataset_dir):
                file_path = dataset_dir / file_name
                if file_name.endswith(".csv"):
                    data = pd.read_csv(file_path)
                    for _, sample in data.iterrows():
                        dataset.append(
                            NexusSample(
                                sample["Input"], "".join(sample["Output"])
                            )
                        )
                    continue

                logger.warning(f"Skipping unsupported file: {file_name}")
            return dataset

        def _load_parquet_data(data_dir: Path, dataset_name: str) -> List:
            r"""Load datasets from Parquet files."""
            dataset = []
            if not data_dir.exists():
                raise FileNotFoundError(
                    f"Data directory '{data_dir}' does not exist."
                )

            for file_name in os.listdir(data_dir):
                file_path = data_dir / file_name
                if file_name.endswith(".parquet"):
                    data = pd.read_parquet(file_path)
                    dataset.extend(_process_parquet_data(data, dataset_name))
                    continue

                logger.warning(f"Skipping unsupported file: {file_name}")

            return dataset

        def _process_parquet_data(
            data: pd.DataFrame, dataset_name: str
        ) -> List:
            r"""Process data from Parquet files based on dataset name."""
            dataset: List = []
            dataset_handlers = {
                "NVDLibrary": _process_nvdlibrary,
                "VirusTotal": _process_simple,
                "PlacesAPI": _process_simple,
                "ClimateAPI": _process_simple,
                "OTX": _process_simple,
                "VirusTotal-NestedCalls": _process_nested_calls,
                "VirusTotal-ParallelCalls": _process_parallel_calls,
            }

            if dataset_name not in dataset_handlers:
                logger.warning(
                    f"No specific handler for dataset: {dataset_name}"
                )
                return dataset

            handler = dataset_handlers[dataset_name]
            for _, sample in data.iterrows():
                processed_sample = handler(sample)
                if processed_sample:
                    dataset.append(processed_sample)
            return dataset

        def _process_nvdlibrary(sample) -> NexusSample:
            r"""Process samples for the NVDLibrary dataset."""
            return NexusSample(
                sample["Input"], sample["Output"].replace("r = nvdlib.", "")
            )

        def _process_simple(sample) -> NexusSample:
            r"""Process samples for simple datasets (e.g., VirusTotal)."""
            return NexusSample(sample["Input"], sample["Output"])

        def _process_nested_calls(sample) -> Union[NexusSample, None]:
            r"""Process samples for VirusTotal-NestedCalls dataset."""
            if len(sample["fncall"]) == 1:
                return NexusSample(
                    sample["generated_question"], "".join(sample["fncall"])
                )
            return None

        def _process_parallel_calls(sample) -> Union[NexusSample, None]:
            r"""Process samples for VirusTotal-ParallelCalls dataset."""
            if len(sample["fncall"]) > 1:
                return NexusSample(
                    sample["generated_question"], "; ".join(sample["fncall"])
                )
            return None

        if force_download:
            logger.info("Force downloading data.")
            self.download()

        # Validate dataset name
        if dataset_name not in dataset_mapping:
            available_datasets = list(dataset_mapping.keys())
            raise ValueError(
                f"Dataset '{dataset_name}' is not recognized. "
                f"Available datasets: {available_datasets}"
            )

        # Get the dataset directory
        dataset_dir = self.data_dir / dataset_name
        if not dataset_dir.exists():
            raise FileNotFoundError(
                f"The dataset directory for '{dataset_name}' \
                does not exist at {dataset_dir}. "
                "Please download it first."
            )

        # Load the dataset
        if dataset_name == "NVDLibrary-NestedCalls":
            self._data = _load_csv_data(dataset_dir)
        else:
            self._data = _load_parquet_data(dataset_dir / "data", dataset_name)

    @property
    def train(self):
        r"""Get the training set."""
        raise NotImplementedError(
            "Nexus Functional Calling has only a single 'train' set."
        )

    def run(  # type: ignore[override, return]
        self,
        agent: ChatAgent,
        task: Literal[
            "NVDLibrary",
            "VirusTotal",
            "OTX",
            "PlacesAPI",
            "ClimateAPI",
            "VirusTotal-ParallelCalls",
            "VirusTotal-NestedCalls",
            "NVDLibrary-NestedCalls",
        ],
        randomize: bool = False,
        subset: Optional[int] = None,
    ) -> Dict[str, Any]:
        r"""Run the benchmark.

        Args:
            agent (ChatAgent): The agent to run the benchmark.
            task (Literal["NVDLibrary", "VirusTotal", "OTX",
                "PlacesAPI", "ClimateAPI", "VirusTotal-ParallelCalls",
                "VirusTotal-NestedCalls",
                "NVDLibrary-NestedCalls"]): The task to run the benchmark.
            randomize (bool, optional): Whether to randomize the data.
                (default: :obj:`False`)
            subset (Optional[int], optional): The subset of data to run.
                (default: :obj:`None`)

        Returns:
            Dict[str, Any]: The results of the benchmark.
        """

        if task not in dataset_mapping:
            raise ValueError(f"Invalid value for dataset: {task}.")

        logger.info(f"Running Nexus Function Calling benchmark on {task}.")
        self.load(task)
        datas = self._data

        # Shuffle and subset data if necessary
        if randomize:
            random.shuffle(datas)
        if subset:
            datas = datas[:subset]

        logger.info(f"Number of tasks: {len(datas)}")

        # Initialize results storage
        self._results = []

        # Process samples
        tools = construct_tool_descriptions(task)
        with open(self.save_to, "w") as f:
            for sample in tqdm(datas, desc="Running"):
                prompt = construct_prompt(input=sample.input, tools=tools)
                msg = BaseMessage.make_user_message(
                    role_name="User", content=prompt
                )
                ground_truth_call = sample.output
                try:
                    # Generate response
                    response = agent.step(msg)
                    agent_call = response.msgs[0].content

                    # Evaluate response
                    if agent_call:
                        result = compare_function_calls(
                            agent_call=agent_call,
                            ground_truth_call=ground_truth_call,
                        )
                        self._results.append(
                            {
                                "input": sample.input,
                                "agent_call": agent_call,
                                "ground_truth_call": ground_truth_call,
                                "result": result,
                                "error": None,
                            }
                        )
                except Exception as e:
                    logger.warning(f"Error in processing task: {sample.input}")
                    self._results.append(
                        {
                            "input": sample.input,
                            "agent_call": None,
                            "ground_truth_call": ground_truth_call,
                            "result": 0,
                            "error": str(e),
                        }
                    )

                agent.reset()

                f.write(json.dumps(self._results[-1], indent=2) + "\n")
                f.flush()

        total = len(self._results)
        correct = sum(r["result"] for r in self._results)

        return {
            "total": total,
            "correct": correct,
            "accuracy": correct / total,
        }


# Utility functions
def construct_tool_descriptions(dataset_name: str) -> str:
    r"""Construct tool descriptions from function definitions and
    descriptions."""
    tool_dataset_mapping = {
        "NVDLibrary": "CVECPE",
        "VirusTotal": "VirusTotal",
        "PlacesAPI": "Places",
        "ClimateAPI": "Climate",
        "OTX": "OTX",
        "VirusTotal-NestedCalls": "VT_Multi (Nested)",
        "VirusTotal-ParallelCalls": "VT_Multi (Parallel)",
        "NVDLibrary-NestedCalls": "CVECPE_Multi (Nested)",
    }

    if dataset_name not in tool_dataset_mapping:
        raise ValueError(
            f"Dataset '{dataset_name}' is not recognized. "
            f"Available datasets: {list(dataset_mapping.keys())}"
        )

    # Load the dataset based on the dataset name
    dataset = load_dataset(
        "Nexusflow/Function_Call_Definitions",
        name=tool_dataset_mapping[dataset_name],
    )["train"]

    # Construct tool descriptions
    tools = [
        NexusTool(tool["function_calls"], tool["descriptions"])
        for tool in dataset
    ]

    # Generate the tool prompt
    tool_prompt = "".join(
        f"Function:\ndef {tool.function_calls}:\n"
        + "\"\"\"\n"
        + f"{tool.descriptions}\n"
        + "\"\"\"\n"
        for tool in tools
    )

    return tool_prompt


def construct_prompt(input: str, tools: str) -> str:
    r"Construct prompt from tools and input."
    return TOOL_CALLING_PROMPT.format(tools=tools, input=input)


# Functions for function call evaluation
def parse_function_call(
    call: str,
) -> Tuple[Optional[str], Optional[List[Any]], Optional[Dict[str, Any]]]:
    r"""Parse a function call string to extract the function name,
    positional arguments, and keyword arguments, including
    nested function calls.

    Args:
        call (str): A string in the format `func(arg1, arg2, kwarg=value)`.

    Returns:
        tuple: (function_name (str), positional_args (list),
            keyword_args (dict)) or (None, None, None).
    """

    def preprocess_input(call: str) -> str:
        r"""Remove formatting like code blocks and whitespace."""
        if call.strip().startswith("```python"):
            call = call.strip().removeprefix("```python").removesuffix("```")
        return textwrap.dedent(call).strip()

    def evaluate_arg(arg):
        r"""Recursively evaluate arguments, including nested calls."""
        if isinstance(arg, ast.Call):
            # Recursively parse nested calls
            func_name, args, kwargs = parse_function_call(ast.unparse(arg))
            return func_name, args, kwargs
        elif isinstance(
            arg, ast.Constant
        ):  # Handle literals like numbers, strings, etc.
            return arg.value
        elif isinstance(arg, ast.List):  # Handle list literals
            return [evaluate_arg(el) for el in arg.elts]
        elif isinstance(arg, ast.Dict):  # Handle dictionary literals
            return {
                evaluate_arg(k): evaluate_arg(v)
                for k, v in zip(arg.keys, arg.values)
            }
        elif isinstance(arg, ast.Tuple):  # Handle tuple literals
            return tuple(evaluate_arg(el) for el in arg.elts)
        else:
            return ast.literal_eval(arg)  # Safely evaluate other types

    call = preprocess_input(call)
    parsed_calls = []

    try:
        # Parse the string into an AST
        parsed_calls = call.split(";")
        for single_call in parsed_calls:
            tree = ast.parse(single_call, mode='eval')

            # Ensure it's a function call
            if isinstance(tree.body, ast.Call):
                # Extract function name
                if isinstance(
                    tree.body.func, ast.Name
                ):  # Simple function call
                    func_name = tree.body.func.id
                elif isinstance(
                    tree.body.func, ast.Attribute
                ):  # Attribute function call
                    func_name = (
                        f"{tree.body.func.value.id}.{tree.body.func.attr}"  # type: ignore[attr-defined]
                    )
                else:
                    raise ValueError(f"Unsupported function call: {call}")

                # Extract positional arguments
                args = [evaluate_arg(arg) for arg in tree.body.args]

                # Extract keyword arguments
                kwargs: Dict[str, Any] = {
                    kw.arg: evaluate_arg(kw.value)
                    for kw in tree.body.keywords
                    if kw.arg is not None
                }
                logger.info("Valid call.")
                return func_name, args, kwargs
            else:
                raise ValueError(f"Not a valid function call: {call}")
    except Exception as e:
        logger.info(f"Error parsing call: {call}, {e}")
        return None, None, None


def compare_function_calls(agent_call: str, ground_truth_call: str) -> bool:
    r"""Compare the function name and arguments of
    agent_call and ground_truth_call.
    Args:
        agent_call (str): Function call by agent.
        ground_truth_call (str): Ground truth function call.

    Returns:
        - `True` if the function names and arguments match.
        - `False` otherwise.
    """
    # Parse both calls
    agent_parsed = parse_function_call(agent_call)
    gt_parsed = parse_function_call(ground_truth_call)

    if agent_parsed and gt_parsed:
        return agent_parsed == gt_parsed
    else:
        return False
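
For orientation, a minimal usage sketch of the NexusBenchmark defined in nexus.py above. It is not part of the committed file; the agent's system prompt, the local paths, the chosen task, and the subset size are illustrative assumptions.

    from camel.agents import ChatAgent
    from camel.benchmarks.nexus import NexusBenchmark

    # Hypothetical driver script; paths and task name are placeholders.
    agent = ChatAgent("You translate user queries into single function calls.")
    benchmark = NexusBenchmark(data_dir="nexus_data", save_to="nexus_results.jsonl")
    benchmark.download()  # fetches every dataset listed in dataset_mapping
    report = benchmark.run(agent, task="VirusTotal", subset=10)
    print(report)  # e.g. {"total": 10, "correct": 7, "accuracy": 0.7}

Each per-sample record is also streamed to the save_to file as it is produced, so partial results survive an interrupted run.
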
Paper2Poster/camel/benchmarks/ragbench.py
ADDED
@@ -0,0 +1,333 @@
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========

from typing import Any, Callable, Dict, List, Literal, Optional, Sequence

import numpy as np
from datasets import Dataset, load_dataset

from camel.agents import ChatAgent
from camel.benchmarks import BaseBenchmark
from camel.logger import get_logger
from camel.retrievers import AutoRetriever

logger = get_logger(__name__)


class RagasFields:
    r"""Constants for RAGAS evaluation field names."""

    INPUT_CONTEXT = "contexts"
    INPUT_QUESTION = "question"
    INPUT_ANSWER = "answer"


def annotate_dataset(
    dataset: Dataset,
    context_call: Optional[Callable[[Dict[str, Any]], List[str]]],
    answer_call: Optional[Callable[[Dict[str, Any]], str]],
) -> Dataset:
    r"""Annotate the dataset by adding context and answers using the provided
    functions.

    Args:
        dataset (Dataset): The input dataset to annotate.
        context_call (Optional[Callable[[Dict[str, Any]], List[str]]]):
            Function to generate context for each example.
        answer_call (Optional[Callable[[Dict[str, Any]], str]]): Function to
            generate answer for each example.

    Returns:
        Dataset: The annotated dataset with added contexts and/or answers.
    """

    def process_example(example: Dict[str, Any]) -> Dict[str, Any]:
        if context_call:
            example["contexts"] = context_call(example)
        if answer_call:
            example["answer"] = answer_call(example)
        return example

    return dataset.map(process_example)


def rmse(
    input_trues: Sequence[float],
    input_preds: Sequence[float],
) -> Optional[float]:
    r"""Calculate Root Mean Squared Error (RMSE).

    Args:
        input_trues (Sequence[float]): Ground truth values.
        input_preds (Sequence[float]): Predicted values.

    Returns:
        Optional[float]: RMSE value, or None if inputs have different lengths.
    """
    if len(input_trues) != len(input_preds):
        logger.warning("Input lengths mismatch in RMSE calculation")
        return None

    trues = np.array(input_trues)
    preds = np.array(input_preds, dtype=float)

    # Ignore NaN values in predictions
    eval_idx = ~np.isnan(preds)
    if not np.any(eval_idx):
        logger.warning("No valid predictions for RMSE calculation")
        return None

    trues = trues[eval_idx]
    preds = preds[eval_idx]

    return float(np.sqrt(np.mean((preds - trues) ** 2)))


def auroc(trues: Sequence[bool], preds: Sequence[float]) -> float:
    r"""Calculate Area Under Receiver Operating Characteristic Curve (AUROC).

    Args:
        trues (Sequence[bool]): Ground truth binary values.
        preds (Sequence[float]): Predicted probability values.

    Returns:
        float: AUROC score.
    """
    from sklearn.metrics import roc_auc_score  # type: ignore[import-untyped]

    eval_idx = ~np.isnan(preds)
    if not np.any(eval_idx):
        logger.warning("No valid predictions for AUROC calculation")
        return 0.5  # Return random classifier score

    return float(
        roc_auc_score(np.array(trues)[eval_idx], np.array(preds)[eval_idx])
    )


def ragas_calculate_metrics(
    dataset: Dataset,
    pred_context_relevance_field: Optional[str],
    pred_faithfulness_field: Optional[str],
    metrics_to_evaluate: Optional[List[str]] = None,
    ground_truth_context_relevance_field: str = "relevance_score",
    ground_truth_faithfulness_field: str = "adherence_score",
) -> Dict[str, Optional[float]]:
    r"""Calculate RAGAS evaluation metrics.

    Args:
        dataset (Dataset): The dataset containing predictions and ground truth.
        pred_context_relevance_field (Optional[str]): Field name for predicted
            context relevance.
        pred_faithfulness_field (Optional[str]): Field name for predicted
            faithfulness.
        metrics_to_evaluate (Optional[List[str]]): List of metrics to evaluate.
        ground_truth_context_relevance_field (str): Field name for ground truth
            relevance.
        ground_truth_faithfulness_field (str): Field name for ground truth
            adherence.

    Returns:
        Dict[str, Optional[float]]: Dictionary of calculated metrics.
    """
    metrics_to_evaluate = metrics_to_evaluate or [
        "context_relevancy",
        "faithfulness",
    ]
    calculated_metrics: Dict[str, Optional[float]] = {}

    if (
        "context_relevancy" in metrics_to_evaluate
        and pred_context_relevance_field
    ):
        trues_relevance = dataset[ground_truth_context_relevance_field]
        preds_relevance = dataset[pred_context_relevance_field]
        calculated_metrics["relevance_rmse"] = rmse(
            trues_relevance, preds_relevance
        )

    if "faithfulness" in metrics_to_evaluate and pred_faithfulness_field:
        trues_hallucination = ~np.array(
            dataset[ground_truth_faithfulness_field]
        )
        preds_hallucination = 1 - np.array(
            dataset[pred_faithfulness_field], dtype=float
        )
        calculated_metrics["hallucination_auroc"] = auroc(
            trues_hallucination.tolist(), preds_hallucination.tolist()
        )

    return calculated_metrics


def ragas_evaluate_dataset(
    dataset: Dataset,
    contexts_field_name: Optional[str],
    answer_field_name: Optional[str],
    metrics_to_evaluate: Optional[List[str]] = None,
) -> Dataset:
    r"""Evaluate the dataset using RAGAS metrics.

    Args:
        dataset (Dataset): Input dataset to evaluate.
        contexts_field_name (Optional[str]): Field name containing contexts.
        answer_field_name (Optional[str]): Field name containing answers.
        metrics_to_evaluate (Optional[List[str]]): List of metrics to evaluate.

    Returns:
        Dataset: Dataset with added evaluation metrics.
    """
    from ragas import evaluate
    from ragas.metrics import (  # type: ignore[import-untyped]
        context_relevancy,
        faithfulness,
    )

    metrics_to_evaluate = metrics_to_evaluate or [
        "context_relevancy",
        "faithfulness",
    ]

    # Rename fields if necessary
    if (
        contexts_field_name
        and contexts_field_name != RagasFields.INPUT_CONTEXT
    ):
        dataset = dataset.rename_column(
            contexts_field_name, RagasFields.INPUT_CONTEXT
        )
    if answer_field_name and answer_field_name != RagasFields.INPUT_ANSWER:
        dataset = dataset.rename_column(
            answer_field_name, RagasFields.INPUT_ANSWER
        )

    metrics = []
    if "context_relevancy" in metrics_to_evaluate:
        metrics.append(context_relevancy)
    if "faithfulness" in metrics_to_evaluate:
        metrics.append(faithfulness)

    ragas_result = evaluate(dataset, metrics=metrics)
    return Dataset.from_pandas(ragas_result.to_pandas())


class RAGBenchBenchmark(BaseBenchmark):
    r"""RAGBench Benchmark for evaluating RAG performance.

    This benchmark uses the rungalileo/ragbench dataset to evaluate
    retrieval-augmented generation (RAG) systems. It measures context
    relevancy and faithfulness metrics as described in
    https://arxiv.org/abs/2407.11005.

    Args:
        processes (int, optional): Number of processes for parallel processing.
        subset (str, optional): Dataset subset to use (e.g., "hotpotqa").
        split (str, optional): Dataset split to use (e.g., "test").
    """

    def __init__(
        self,
        processes: int = 1,
        subset: Literal[
            "covidqa",
            "cuad",
            "delucionqa",
            "emanual",
            "expertqa",
            "finqa",
            "hagrid",
            "hotpotqa",
            "msmarco",
            "pubmedqa",
            "tatqa",
            "techqa",
        ] = "hotpotqa",
        split: Literal["train", "test", "validation"] = "test",
    ) -> None:
        super().__init__("ragbench", "rag_bench", "", processes)
        self.subset = subset
        self.split = split
        self.dataset: Optional[Dataset] = None

    def download(self):
        r"""Download the RAGBench dataset."""
        try:
            self.dataset = load_dataset(
                "rungalileo/ragbench", self.subset, split=self.split
            )
        except Exception as e:
            logger.error(f"Failed to download dataset: {e}")
            raise

    def load(self, force_download: bool = False):
        r"""Load the RAGBench dataset.

        Args:
            force_download (bool, optional): Whether to force download the
                data.
        """
        if force_download or self.dataset is None:
            logger.info(
                "%s dataset",
                "Force downloading" if force_download else "Loading",
            )
            self.download()

    def run(  # type: ignore[override, return]
        self,
        agent: ChatAgent,
        auto_retriever: AutoRetriever,
    ) -> Dict[str, Optional[float]]:
        r"""Run the benchmark evaluation.

        Args:
            agent (ChatAgent): Chat agent for generating answers.
            auto_retriever (AutoRetriever): Retriever for finding relevant
                contexts.

        Returns:
            Dict[str, Optional[float]]: Dictionary of evaluation metrics.
        """

        def context_call(example):
            retrieved_info = auto_retriever.run_vector_retriever(
                query=example['question'],
                contents=example['documents'],
                top_k=1,
                return_detailed_info=True,
                similarity_threshold=0.5,
            )
            return [c['text'] for c in retrieved_info['Retrieved Context']]

        def answer_call(example: Dict[str, Any]) -> str:
            user_msg = str(example)
            assistant_response = agent.step(user_msg)
            return assistant_response.msg.content

        # Annotate the dataset
        annotated_ds = annotate_dataset(
            self.dataset, context_call, answer_call
        )
        evaluated_ds = ragas_evaluate_dataset(
            annotated_ds,
            contexts_field_name="contexts",
            answer_field_name="answer",
            metrics_to_evaluate=["context_relevancy", "faithfulness"],
        )

        return ragas_calculate_metrics(
            evaluated_ds,
            pred_context_relevance_field="context_relevancy",
            pred_faithfulness_field="faithfulness",
        )
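
As with the Nexus benchmark, a short usage sketch may help; it is not part of the commit, and the agent prompt and retriever construction shown here are assumptions (any configured ChatAgent and AutoRetriever should work).

    from camel.agents import ChatAgent
    from camel.benchmarks.ragbench import RAGBenchBenchmark
    from camel.retrievers import AutoRetriever

    # Hypothetical driver; subset and split mirror the defaults above.
    benchmark = RAGBenchBenchmark(subset="hotpotqa", split="test")
    benchmark.load()  # downloads rungalileo/ragbench on first use

    agent = ChatAgent("Answer the question using only the given documents.")
    retriever = AutoRetriever()  # storage and embedding backends left at their defaults
    metrics = benchmark.run(agent, retriever)
    print(metrics)  # {"relevance_rmse": ..., "hallucination_auroc": ...}

The returned keys follow ragas_calculate_metrics: RMSE against the dataset's relevance_score and AUROC for hallucination detection derived from adherence_score.
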