File size: 2,461 Bytes
7c08dc3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
import yaml
import json
from utils.wei_utils import account_token
from jinja2 import Environment, StrictUndefined
from camel.models import ModelFactory
from camel.agents import ChatAgent
from camel.messages import BaseMessage
from utils.src.utils import get_json_from_response
def no_tree_get_layout(poster_width, poster_height, panels, figures, agent_config, max_trials=None):
    """Ask an LLM planner agent for a poster layout, retrying until the reply validates.

    The prompt comes from ``prompt_templates/ablation_no_tree_layout.yaml``.
    That YAML must provide a ``template`` (Jinja2 source) and a
    ``system_prompt``. The template is rendered with the poster size and the
    JSON-serialised ``panels`` and ``figures``.

    A reply is accepted only if the parsed JSON is non-empty and has the keys
    ``panel_arrangement``, ``figure_arrangement`` and ``text_arrangement``.
    ``panel_arrangement`` must have ``len(panels)`` entries and
    ``figure_arrangement`` must have ``len(figures)`` entries. Any other reply
    is discarded and the agent is asked again.

    Args:
        poster_width: Poster width, forwarded to the prompt template.
        poster_height: Poster height, forwarded to the prompt template.
        panels: JSON-serialisable sized collection describing the panels.
        figures: JSON-serialisable sized collection describing the figures.
        agent_config: Mapping with ``model_platform``, ``model_type`` and
            ``model_config``, passed to ``ModelFactory.create``.
        max_trials: Optional upper bound on the number of generation attempts.
            ``None`` (the default) retries indefinitely, matching the original
            behaviour.

    Returns:
        A tuple ``(panel_arrangement, figure_arrangement, text_arrangement,
        total_input_token, total_output_token)``. The token counts are summed
        over every attempt, including rejected ones.

    Raises:
        RuntimeError: If ``max_trials`` is set and no valid layout was
            produced within that many attempts.
    """
    total_input_token, total_output_token = 0, 0

    agent_name = 'ablation_no_tree_layout'
    with open(f"prompt_templates/{agent_name}.yaml", "r", encoding="utf-8") as f:
        planner_config = yaml.safe_load(f)

    # StrictUndefined makes rendering fail loudly if the template references a
    # variable that is not supplied below, instead of rendering it as empty.
    jinja_env = Environment(undefined=StrictUndefined)
    template = jinja_env.from_string(planner_config["template"])
    planner_jinja_args = {
        'poster_width': poster_width,
        'poster_height': poster_height,
        'panels': json.dumps(panels, indent=4),
        'figures': json.dumps(figures, indent=4),
    }

    planner_model = ModelFactory.create(
        model_platform=agent_config['model_platform'],
        model_type=agent_config['model_type'],
        model_config_dict=agent_config['model_config'],
    )
    planner_agent = ChatAgent(
        system_message=planner_config['system_prompt'],
        model=planner_model,
        message_window_size=None,
    )
    # The prompt is identical for every attempt, so render it once.
    planner_prompt = template.render(**planner_jinja_args)

    required_keys = ('panel_arrangement', 'figure_arrangement', 'text_arrangement')
    num_trials = 0
    while max_trials is None or num_trials < max_trials:
        num_trials += 1
        print(f"Trial {num_trials}: Generating layout...")
        # Reset before every attempt; presumably this clears the agent's
        # conversation state so earlier rejected replies do not influence
        # the retry.
        planner_agent.reset()
        response = planner_agent.step(planner_prompt)

        # Count tokens for every attempt, including ones rejected below.
        input_token, output_token = account_token(response)
        total_input_token += input_token
        total_output_token += output_token

        if not response.msgs:
            print('Error: Empty response, retrying...')
            continue
        arrangements = get_json_from_response(response.msgs[0].content)

        # `not arrangements` also covers a None result, which previously
        # crashed on len(None).
        if not arrangements:
            print('Error: Empty response, retrying...')
            continue
        if any(key not in arrangements for key in required_keys):
            print('Error: Invalid response, retrying...')
            continue
        if (len(arrangements['panel_arrangement']) != len(panels)
                or len(arrangements['figure_arrangement']) != len(figures)):
            print('Error: Invalid response, retrying...')
            continue

        # Return the accumulated totals, not just the last attempt's counts.
        return (
            arrangements['panel_arrangement'],
            arrangements['figure_arrangement'],
            arrangements['text_arrangement'],
            total_input_token,
            total_output_token,
        )

    raise RuntimeError(
        f"Failed to generate a valid layout after {num_trials} trials"
    )