| import yaml | |
| import json | |
| from utils.wei_utils import account_token | |
| from jinja2 import Environment, StrictUndefined | |
| from camel.models import ModelFactory | |
| from camel.agents import ChatAgent | |
| from camel.messages import BaseMessage | |
| from utils.src.utils import get_json_from_response | |
def no_tree_get_layout(poster_width, poster_height, panels, figures, agent_config, max_trials=None):
    """Ask an LLM planner for a poster layout without the tree-based layout step.

    Renders the ``prompt_templates/ablation_no_tree_layout.yaml`` template with
    the poster size and the JSON-serialised panels and figures. It then sends
    the prompt to a CAMEL ``ChatAgent`` built from ``agent_config``. Trials are
    repeated until the JSON parsed from the reply meets two conditions:

    * it contains ``panel_arrangement``, ``figure_arrangement`` and
      ``text_arrangement``;
    * it has exactly one panel entry per panel and one figure entry per figure.

    Args:
        poster_width: Poster width, passed through to the prompt template.
        poster_height: Poster height, passed through to the prompt template.
        panels: JSON-serialisable panel descriptions. Its length is the
            expected length of ``panel_arrangement``.
        figures: JSON-serialisable figure descriptions. Its length is the
            expected length of ``figure_arrangement``.
        agent_config: Mapping with ``model_platform``, ``model_type`` and
            ``model_config`` keys, forwarded to ``ModelFactory.create``.
        max_trials: Maximum number of LLM calls before giving up. ``None``
            (the default) retries indefinitely, as the original code did.

    Returns:
        Tuple ``(panel_arrangement, figure_arrangement, text_arrangement,
        total_input_token, total_output_token)``. The token counts are summed
        over every trial, including rejected ones.

    Raises:
        RuntimeError: If ``max_trials`` is set and no valid layout was
            produced within that many trials.
    """
    total_input_token, total_output_token = 0, 0
    agent_name = 'ablation_no_tree_layout'
    with open(f"prompt_templates/{agent_name}.yaml", "r", encoding="utf-8") as f:
        planner_config = yaml.safe_load(f)

    # StrictUndefined makes a missing template variable raise an error
    # instead of silently rendering as an empty string.
    jinja_env = Environment(undefined=StrictUndefined)
    template = jinja_env.from_string(planner_config["template"])
    planner_jinja_args = {
        'poster_width': poster_width,
        'poster_height': poster_height,
        'panels': json.dumps(panels, indent=4),
        'figures': json.dumps(figures, indent=4),
    }
    planner_prompt = template.render(**planner_jinja_args)

    planner_model = ModelFactory.create(
        model_platform=agent_config['model_platform'],
        model_type=agent_config['model_type'],
        model_config_dict=agent_config['model_config'],
    )
    planner_agent = ChatAgent(
        system_message=planner_config['system_prompt'],
        model=planner_model,
        message_window_size=None,
    )

    required_keys = ('panel_arrangement', 'figure_arrangement', 'text_arrangement')
    num_trials = 0
    while max_trials is None or num_trials < max_trials:
        num_trials += 1
        print(f"Trial {num_trials}: Generating layout...")
        # Start each trial from a clean conversation so a rejected answer is
        # not part of the context for the next attempt.
        planner_agent.reset()
        response = planner_agent.step(planner_prompt)
        input_token, output_token = account_token(response)
        # Every call costs tokens, so count rejected trials too.
        total_input_token += input_token
        total_output_token += output_token

        # NOTE(review): guards against a reply with no messages; confirm
        # whether CAMEL can actually return an empty `msgs` list here.
        if not response.msgs:
            print('Error: Empty response, retrying...')
            continue
        arrangements = get_json_from_response(response.msgs[0].content)
        # `not arrangements` covers an empty result and also None.
        # Presumably None is possible when no JSON is found; verify against
        # get_json_from_response.
        if not arrangements:
            print('Error: Empty response, retrying...')
            continue
        if not isinstance(arrangements, dict) or any(key not in arrangements for key in required_keys):
            print('Error: Invalid response, retrying...')
            continue
        if (len(arrangements['panel_arrangement']) != len(panels)
                or len(arrangements['figure_arrangement']) != len(figures)):
            print('Error: Invalid response, retrying...')
            continue

        return (
            arrangements['panel_arrangement'],
            arrangements['figure_arrangement'],
            arrangements['text_arrangement'],
            total_input_token,
            total_output_token,
        )

    raise RuntimeError(f"Failed to generate a valid layout after {num_trials} trials")