File size: 2,461 Bytes
7c08dc3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import yaml
import json
from utils.wei_utils import account_token
from jinja2 import Environment, StrictUndefined
from camel.models import ModelFactory
from camel.agents import ChatAgent
from camel.messages import BaseMessage
from utils.src.utils import get_json_from_response

def _no_tree_layout_error(arrangements, num_panels, num_figures):
    """Return the retry message for an unusable planner response, or None if it is valid.

    A response is valid when it is a non-empty dict containing the
    ``panel_arrangement``, ``figure_arrangement`` and ``text_arrangement`` keys,
    with exactly ``num_panels`` panel entries and ``num_figures`` figure entries.
    """
    # `not arrangements` also covers a None result from the JSON extractor,
    # which previously crashed on len().
    if not arrangements:
        return 'Error: Empty response, retrying...'
    required_keys = ('panel_arrangement', 'figure_arrangement', 'text_arrangement')
    if not isinstance(arrangements, dict) or any(
        key not in arrangements for key in required_keys
    ):
        return 'Error: Invalid response, retrying...'
    # Every input panel/figure must receive exactly one arrangement entry.
    if (len(arrangements['panel_arrangement']) != num_panels
            or len(arrangements['figure_arrangement']) != num_figures):
        return 'Error: Invalid response, retrying...'
    return None


def no_tree_get_layout(poster_width, poster_height, panels, figures, agent_config, max_retry=None):
    """Ask an LLM planner for a poster layout (ablation variant without the layout tree).

    Renders the ``prompt_templates/ablation_no_tree_layout.yaml`` prompt with the
    poster size and JSON-serialized panels/figures, then queries a ``ChatAgent``
    (reset before every trial) until the JSON it returns passes validation.

    Args:
        poster_width: Poster width, passed verbatim into the prompt template.
        poster_height: Poster height, passed verbatim into the prompt template.
        panels: JSON-serializable panel descriptions. A valid response must
            contain exactly ``len(panels)`` panel arrangements.
        figures: JSON-serializable figure descriptions. A valid response must
            contain exactly ``len(figures)`` figure arrangements.
        agent_config: Mapping with ``'model_platform'``, ``'model_type'`` and
            ``'model_config'`` entries forwarded to ``ModelFactory.create``.
        max_retry: Maximum number of planner calls. ``None`` (the default)
            keeps retrying until a valid response is obtained.

    Returns:
        Tuple ``(panel_arrangement, figure_arrangement, text_arrangement,
        total_input_token, total_output_token)``. The token counts are summed
        over all trials, including rejected ones.

    Raises:
        RuntimeError: If ``max_retry`` is set and no valid response was
            obtained within that many trials.
    """
    total_input_token, total_output_token = 0, 0
    agent_name = 'ablation_no_tree_layout'
    with open(f"prompt_templates/{agent_name}.yaml", "r", encoding="utf-8") as f:
        planner_config = yaml.safe_load(f)

    # StrictUndefined makes a template variable missing from the render call
    # raise instead of rendering as an empty string.
    jinja_env = Environment(undefined=StrictUndefined)
    template = jinja_env.from_string(planner_config["template"])
    planner_prompt = template.render(
        poster_width=poster_width,
        poster_height=poster_height,
        panels=json.dumps(panels, indent=4),
        figures=json.dumps(figures, indent=4),
    )

    planner_model = ModelFactory.create(
        model_platform=agent_config['model_platform'],
        model_type=agent_config['model_type'],
        model_config_dict=agent_config['model_config'],
    )

    planner_agent = ChatAgent(
        system_message=planner_config['system_prompt'],
        model=planner_model,
        message_window_size=None,
    )

    num_trials = 0
    while max_retry is None or num_trials < max_retry:
        num_trials += 1
        print(f"Trial {num_trials}: Generating layout...")
        # Reset so a retry does not carry the previously rejected answer in
        # the agent's conversation history.
        planner_agent.reset()
        response = planner_agent.step(planner_prompt)
        input_token, output_token = account_token(response)
        total_input_token += input_token
        total_output_token += output_token

        arrangements = get_json_from_response(response.msgs[0].content)
        error = _no_tree_layout_error(arrangements, len(panels), len(figures))
        if error is not None:
            print(error)
            continue

        # Report the accumulated totals: rejected trials still consumed tokens.
        return (
            arrangements['panel_arrangement'],
            arrangements['figure_arrangement'],
            arrangements['text_arrangement'],
            total_input_token,
            total_output_token,
        )

    raise RuntimeError(
        f"Failed to obtain a valid layout after {num_trials} trials"
    )