# PosterAgent/gen_poster_content.py
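"""
Poster content generation for PosterAgent.

Given raw paper content produced by earlier pipeline stages, this module
generates the textual content of a poster:

* `gen_poster_title_content` produces the title block via a title actor agent.
* `gen_bullet_point_content` produces per-panel bullet points in parallel,
  applying character-length control and a vision critic that inspects rendered
  textboxes for overflow or excessive blank space.
* `gen_poster_content` produces outline-driven section content in parallel.

All results are written as JSON files under `contents/`.
"""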
import tempfile
import shutil
from dotenv import load_dotenv
from utils.src.utils import get_json_from_response
from concurrent.futures import ThreadPoolExecutor, as_completed
import json
import copy  # render_textbox uses copy.deepcopy
from camel.models import ModelFactory
from PosterAgent.gen_pptx_code import generate_poster_code
from camel.agents import ChatAgent
from camel.messages import BaseMessage
from utils.src.utils import ppt_to_images
from PIL import Image
from utils.wei_utils import *
from utils.pptx_utils import *
from utils.critic_utils import *
import yaml
from jinja2 import Environment, StrictUndefined
import argparse

load_dotenv()

MAX_ATTEMPT = 10


def gen_content_process_section(
    section_name,
    outline,
    raw_content,
    raw_outline,
    template,
    create_actor_agent,
    MAX_ATTEMPT
):
    """
    Process a single section in its own thread or process.
    Returns (section_name, result_json, total_input_token, total_output_token).
    """
    # Create a fresh ActorAgent instance for each parallel call
    actor_agent = create_actor_agent()

    section_outline = ''
    num_attempts = 0
    total_input_token = 0
    total_output_token = 0
    result_json = None

    while True:
        print(f"[Thread] Generating content for section: {section_name}")
        if len(section_outline) == 0:
            # Initialize the section outline
            section_outline = json.dumps(outline[section_name], indent=4)

        # Render prompt using Jinja template
        jinja_args = {
            'json_outline': section_outline,
            'json_content': raw_content,
        }
        prompt = template.render(**jinja_args)

        # Step the actor_agent and track tokens
        response = actor_agent.step(prompt)
        input_token, output_token = account_token(response)
        total_input_token += input_token
        total_output_token += output_token

        # Parse JSON and possibly adjust text length
        result_json = get_json_from_response(response.msgs[0].content)
        new_section_outline, suggested = generate_length_suggestions(
            result_json,
            json.dumps(outline[section_name]),
            raw_outline[section_name]
        )
        section_outline = json.dumps(new_section_outline, indent=4)

        if not suggested:
            # No more adjustments needed
            break

        print(f"[Thread] Adjusting text length for section: {section_name}...")
        num_attempts += 1
        if num_attempts >= MAX_ATTEMPT:
            break

    return section_name, result_json, total_input_token, total_output_token


def gen_content_parallel_process_sections(
    sections,
    outline,
    raw_content,
    raw_outline,
    template,
    create_actor_agent,
    MAX_ATTEMPT=3
):
    """
    Parallelize the section processing using ThreadPoolExecutor.
    """
    poster_content = {}
    total_input_token = 0
    total_output_token = 0

    # Create a pool of worker threads (or processes)
    with ThreadPoolExecutor() as executor:
        futures = []
        # Submit each section to be processed in parallel
        for section_name in sections:
            futures.append(
                executor.submit(
                    gen_content_process_section,
                    section_name,
                    outline,
                    raw_content,
                    raw_outline,
                    template,
                    create_actor_agent,
                    MAX_ATTEMPT
                )
            )

        # Collect results as they complete
        for future in as_completed(futures):
            section_name, result_json, sec_input_token, sec_output_token = future.result()
            poster_content[section_name] = result_json
            total_input_token += sec_input_token
            total_output_token += sec_output_token

    return poster_content, total_input_token, total_output_token


def render_textbox(text_arrangement, textbox_content, tmp_dir):
    """Render a single textbox to an image (via a temporary one-slide pptx) so the critic can inspect it."""
    arrangement = copy.deepcopy(text_arrangement)
    arrangement['x'] = 1
    arrangement['y'] = 1
    poster_code = generate_poster_code(
        [],
        [arrangement],
        [],
        presentation_object_name='poster_presentation',
        slide_object_name='poster_slide',
        utils_functions=utils_functions,
        slide_width=text_arrangement['width'] + 3,
        slide_height=text_arrangement['height'] + 3,
        img_path='placeholder.jpg',
        save_path=f'{tmp_dir}/poster.pptx',
        visible=True,
        content=textbox_content,
        check_overflow=True,
        tmp_dir=tmp_dir,
    )
    output, err = run_code(poster_code)
    ppt_to_images(f'{tmp_dir}/poster.pptx', tmp_dir, output_type='jpg')
    img = Image.open(f'{tmp_dir}/poster.jpg')
    return img


def gen_poster_title_content(args, actor_config):
    """Generate the poster title block with the title actor agent; returns (result_json, input_tokens, output_tokens)."""
    total_input_token, total_output_token = 0, 0
    raw_content = json.load(open(f'contents/<{args.model_name_t}_{args.model_name_v}>_{args.poster_name}_raw_content.json', 'r'))

    actor_agent_name = 'poster_title_agent'
    title_string = raw_content['meta']

    with open(f'utils/prompt_templates/{actor_agent_name}.yaml', "r") as f:
        content_config = yaml.safe_load(f)

    jinja_env = Environment(undefined=StrictUndefined)
    template = jinja_env.from_string(content_config["template"])

    if args.model_name_t == 'vllm_qwen':
        actor_model = ModelFactory.create(
            model_platform=actor_config['model_platform'],
            model_type=actor_config['model_type'],
            model_config_dict=actor_config['model_config'],
            url=actor_config['url'],
        )
    else:
        actor_model = ModelFactory.create(
            model_platform=actor_config['model_platform'],
            model_type=actor_config['model_type'],
            model_config_dict=actor_config['model_config']
        )

    actor_sys_msg = content_config['system_prompt']
    actor_agent = ChatAgent(
        system_message=actor_sys_msg,
        model=actor_model,
        message_window_size=30
    )

    jinja_args = {
        'title_string': title_string,
        'title_font_size': getattr(args, 'poster_title_font_size', None) or getattr(args, 'title_font_size', None),
        'author_font_size': getattr(args, 'poster_author_font_size', None) or getattr(args, 'author_font_size', None),
    }
    prompt = template.render(**jinja_args)

    # Step the actor_agent and track tokens
    actor_agent.reset()
    response = actor_agent.step(prompt)
    input_token, output_token = account_token(response)
    total_input_token += input_token
    total_output_token += output_token

    result_json = get_json_from_response(response.msgs[0].content)
    return result_json, total_input_token, total_output_token


def gen_bullet_point_content(args, actor_config, critic_config, agent_modify=True, tmp_dir='tmp'):
    """
    Generate bullet-point content for every section panel in parallel, using an
    actor agent for text generation and a vision critic agent to detect textbox
    overflow or excessive blank space. Saves the assembled content under
    `contents/` and returns (actor_in, actor_out, critic_in, critic_out) token counts.
    """
    import json, yaml, copy, threading
    from concurrent.futures import ThreadPoolExecutor, as_completed
    from PIL import Image
    from jinja2 import Environment, StrictUndefined

    # ----------------------- Load data & configs -----------------------
    total_input_token_t = total_output_token_t = 0
    total_input_token_v = total_output_token_v = 0

    raw_content = json.load(open(f'contents/<{args.model_name_t}_{args.model_name_v}>_{args.poster_name}_raw_content.json', 'r'))
    with open(f'tree_splits/<{args.model_name_t}_{args.model_name_v}>_{args.poster_name}_tree_split_{args.index}.json', 'r') as f:
        tree_split_results = json.load(f)

    panels = tree_split_results['panels']
    text_arrangement_list = tree_split_results['text_arrangement_inches']

    actor_agent_name = 'bullet_point_agent'
    if args.model_name_v == 'vllm_qwen_vl':
        critic_agent_name = 'critic_overlap_agent_v3_short'
    else:
        critic_agent_name = 'critic_overlap_agent_v3'

    with open(f"utils/prompt_templates/{actor_agent_name}.yaml", "r") as f:
        content_config = yaml.safe_load(f)
    with open(f"utils/prompt_templates/{critic_agent_name}.yaml", "r") as f:
        critic_content_config = yaml.safe_load(f)

    jinja_env = Environment(undefined=StrictUndefined)
    template = jinja_env.from_string(content_config["template"])
    critic_template = jinja_env.from_string(critic_content_config["template"])

    # Preload images once (each worker can reopen if needed, or just pass paths)
    neg_img_path = 'assets/overflow_example_v2/neg.jpg'
    pos_img_path = 'assets/overflow_example_v2/pos.jpg'

    # Group text arrangements by panel_id for O(1) lookup in workers
    from collections import defaultdict
    textboxes_by_panel = defaultdict(list)
    for ta in text_arrangement_list:
        textboxes_by_panel[ta['panel_id']].append(ta)
    # Ensure deterministic order inside each panel
    for k in textboxes_by_panel:
        textboxes_by_panel[k] = sorted(textboxes_by_panel[k], key=lambda x: x.get('textbox_id', 0))

    # ----------------------- Worker (defined INSIDE main fn) -----------------------
    def _process_section(i):
        """
        Returns:
            (i, result_json, t_in, t_out, v_in, v_out)
        """
        local_t_in = local_t_out = 0
        local_v_in = local_v_out = 0

        arrangement = panels[i]
        num_textboxes = 2 if arrangement.get('gp', 0) > 0 else 1
        local_tmp_dir = tempfile.mkdtemp(prefix=f"sec_{i}_", dir=tmp_dir)

        jinja_args = {
            'summary_of_section': raw_content['sections'][i]['content'],
            'number_of_textboxes': num_textboxes,
            'section_title': raw_content['sections'][i]['title'],
            'bullet_font_size': args.bullet_font_size,
            'section_title_font_size': args.section_title_font_size,
        }

        target_textboxes = textboxes_by_panel[i][1:]  # skip first (section title)
        total_expected_length = sum(tb['num_chars'] for tb in target_textboxes)

        # Create fresh models & agents per thread for safety
        if args.model_name_t.startswith('vllm_qwen'):
            actor_model = ModelFactory.create(
                model_platform=actor_config['model_platform'],
                model_type=actor_config['model_type'],
                model_config_dict=actor_config['model_config'],
                url=actor_config['url'],
            )
        else:
            actor_model = ModelFactory.create(
                model_platform=actor_config['model_platform'],
                model_type=actor_config['model_type'],
                model_config_dict=actor_config['model_config']
            )

        if args.model_name_v.startswith('vllm_qwen'):
            critic_model = ModelFactory.create(
                model_platform=critic_config['model_platform'],
                model_type=critic_config['model_type'],
                model_config_dict=critic_config['model_config'],
                url=critic_config['url'],
            )
        else:
            critic_model = ModelFactory.create(
                model_platform=critic_config['model_platform'],
                model_type=critic_config['model_type'],
                model_config_dict=critic_config['model_config']
            )

        actor_agent = ChatAgent(system_message=content_config['system_prompt'], model=actor_model, message_window_size=30)
        critic_agent = ChatAgent(system_message=critic_content_config['system_prompt'], model=critic_model, message_window_size=10)

        prompt = template.render(**jinja_args)
        actor_agent.reset()
        response = actor_agent.step(prompt)
        t_in, t_out = account_token(response)
        local_t_in += t_in
        local_t_out += t_out
        result_json = get_json_from_response(response.msgs[0].content)

        max_attempts = 5
        num_attempts = 0
        old_result_json = copy.deepcopy(result_json)

        # Length control loop
        while args.estimate_chars:
            num_attempts += 1
            if num_attempts > max_attempts:
                result_json = old_result_json
                break
            try:
                total_bullet_length = 0
                for j in range(num_textboxes):
                    bullet_content_key = f'textbox{j + 1}'
                    total_bullet_length += compute_bullet_length(result_json[bullet_content_key])
            except Exception:
                result_json = old_result_json
                break
            if total_bullet_length > total_expected_length:
                percentage_to_shrink = int((total_bullet_length - total_expected_length) / total_bullet_length * 100)
                percentage_to_shrink = min(90, percentage_to_shrink + 10)
                old_result_json = copy.deepcopy(result_json)
                response = actor_agent.step('Too long, please shorten the bullet points by ' + str(percentage_to_shrink) + '%.')
                t_in, t_out = account_token(response)
                local_t_in += t_in
                local_t_out += t_out
                result_json = get_json_from_response(response.msgs[0].content)
            else:
                break

        critic_prompt = critic_template.render()
        bullet_contents = ['textbox1'] + (['textbox2'] if num_textboxes == 2 else [])

        # Visual overflow/blank detection & correction
        for j, text_arrangement in enumerate(target_textboxes[:num_textboxes]):
            bullet_content = bullet_contents[j]
            curr_round = 0
            while True:
                if args.ablation_no_commenter:
                    break
                curr_round += 1
                img = render_textbox(text_arrangement, result_json[bullet_content], local_tmp_dir)
                if args.model_name_v.startswith('vllm_qwen') or args.ablation_no_example:
                    critic_msg = BaseMessage.make_user_message(
                        role_name="User",
                        content=critic_prompt,
                        image_list=[img],
                    )
                else:
                    critic_msg = BaseMessage.make_user_message(
                        role_name="User",
                        content=critic_prompt,
                        image_list=[Image.open(neg_img_path), Image.open(pos_img_path), img],
                    )
                critic_agent.reset()
                response = critic_agent.step(critic_msg)
                v_in, v_out = account_token(response)
                local_v_in += v_in
                local_v_out += v_out
                decision = response.msgs[0].content.lower()

                if decision in ['1', '1.', '"1"', "'1'"]:
                    if curr_round > 10:
                        print(f'Section {i}: Too many rounds of modification, breaking...')
                        break
                    if agent_modify:
                        print(f'Section {i}: Text overflow detected, modifying...')
                        modify_message = f'{bullet_content} is too long, please shorten that part, other content should stay the same. Return the entire modified JSON.'
                        response = actor_agent.step(modify_message)
                        t_in, t_out = account_token(response)
                        local_t_in += t_in
                        local_t_out += t_out
                        result_json = get_json_from_response(response.msgs[0].content)
                    else:
                        # naive truncate
                        result_json[bullet_content] = result_json[bullet_content][:-1]
                    continue
                elif decision in ['2', '2.', '"2"', "'2'"]:
                    if args.no_blank_detection:
                        print(f'Section {i}: No blank space detection, skipping...')
                        break
                    if curr_round > 10:
                        print(f'Section {i}: Too many rounds of modification, breaking...')
                        break
                    print(f'Section {i}: Too much blank space detected, modifying...')
                    modify_message = f'{bullet_content} is too short, please add one more bullet point, other content should stay the same. Return the entire modified JSON.'
                    response = actor_agent.step(modify_message)
                    t_in, t_out = account_token(response)
                    local_t_in += t_in
                    local_t_out += t_out
                    result_json = get_json_from_response(response.msgs[0].content)
                else:
                    break

        # Clean up temp dir
        if local_tmp_dir:
            try:
                print(f'Section {i}: Cleaning up temp dir {local_tmp_dir}')
                shutil.rmtree(local_tmp_dir)
            except Exception as e:
                print(f"Error cleaning up temp dir {local_tmp_dir}: {e}")

        return i, result_json, local_t_in, local_t_out, local_v_in, local_v_out

    # ----------------------- Parallel execution -----------------------
    max_workers = getattr(args, 'max_workers', 4)
    results = {}
    lock = threading.Lock()

    with ThreadPoolExecutor(max_workers=max_workers) as ex:
        futures = {
            ex.submit(_process_section, i): i
            for i in range(1, len(raw_content['sections']))
        }
        for fut in as_completed(futures):
            i, rjson, t_in, t_out, v_in, v_out = fut.result()
            with lock:
                results[i] = rjson
                total_input_token_t += t_in
                total_output_token_t += t_out
                total_input_token_v += v_in
                total_output_token_v += v_out

    # ----------------------- Title generation (sequential) -----------------------
    title_json, title_input_token, title_output_token = gen_poster_title_content(args, actor_config)
    total_input_token_t += title_input_token
    total_output_token_t += title_output_token

    # ----------------------- Assemble & save -----------------------
    bullet_point_content = [title_json]
    for idx in range(1, len(raw_content['sections'])):
        bullet_point_content.append(results[idx])

    json.dump(
        bullet_point_content,
        open(f'contents/<{args.model_name_t}_{args.model_name_v}>_{args.poster_name}_bullet_point_content_{args.index}.json', 'w'),
        indent=2
    )

    return total_input_token_t, total_output_token_t, total_input_token_v, total_output_token_v


def gen_poster_content(args, actor_config):
    """Generate poster content for every outline section in parallel and save it as JSON."""
    total_input_token, total_output_token = 0, 0
    raw_content = json.load(open(f'contents/{args.model_name}_{args.poster_name}_raw_content.json', 'r'))

    agent_name = 'poster_content_agent'
    with open(f"utils/prompt_templates/{agent_name}.yaml", "r") as f:
        content_config = yaml.safe_load(f)

    actor_model = ModelFactory.create(
        model_platform=actor_config['model_platform'],
        model_type=actor_config['model_type'],
        model_config_dict=actor_config['model_config']
    )
    actor_sys_msg = content_config['system_prompt']

    def create_actor_agent():
        actor_agent = ChatAgent(
            system_message=actor_sys_msg,
            model=actor_model,
            message_window_size=10
        )
        return actor_agent

    outline = json.load(open(f'outlines/{args.model_name}_{args.poster_name}_outline_{args.index}.json', 'r'))
    raw_outline = json.loads(json.dumps(outline))  # deep copy via JSON round-trip
    outline_estimate_num_chars(outline)
    outline = remove_hierarchy_and_id(outline)

    sections = list(outline.keys())
    sections = [s for s in sections if s != 'meta']

    jinja_env = Environment(undefined=StrictUndefined)
    template = jinja_env.from_string(content_config["template"])

    poster_content = {}
    poster_content, total_input_token, total_output_token = gen_content_parallel_process_sections(
        sections,
        outline,
        raw_content,
        raw_outline,
        template,
        create_actor_agent,
        MAX_ATTEMPT=5
    )

    json.dump(poster_content, open(f'contents/{args.model_name}_{args.poster_name}_poster_content_{args.index}.json', 'w'), indent=2)
    return total_input_token, total_output_token


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--poster_name', type=str, default=None)
    parser.add_argument('--model_name', type=str, default='4o')
    parser.add_argument('--poster_path', type=str, required=True)
    parser.add_argument('--index', type=int, default=0)
    parser.add_argument('--max_retry', type=int, default=3)
    args = parser.parse_args()

    actor_config = get_agent_config(args.model_name)

    if args.poster_name is None:
        args.poster_name = args.poster_path.split('/')[-1].replace('.pdf', '').replace(' ', '_')

    input_token, output_token = gen_poster_content(args, actor_config)
    print(f'Token consumption: {input_token} -> {output_token}')
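
# Usage sketch (an assumption, not a documented entry point): run from the
# repository root with it on PYTHONPATH so `utils/` and `PosterAgent/` resolve,
# passing a paper PDF (here `paper.pdf`, a placeholder path) and a model name
# that `get_agent_config` recognizes, e.g. the default '4o':
#
#   PYTHONPATH=. python PosterAgent/gen_poster_content.py \
#       --poster_path paper.pdf --model_name 4o --index 0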