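# NOTE(review): the three lines below look like web-page residue captured
# together with the file rather than program text -- confirm and drop them.
# They are commented out here so the module parses.
# JaceWei's picture
# update
# 0f74dc7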
import os
print("Initializing...")
from PosterAgent.parse_raw import parse_raw, gen_image_and_table
from PosterAgent.gen_outline_layout import filter_image_table, gen_outline_layout_v2
from utils.wei_utils import get_agent_config, scale_to_target_area
# from PosterAgent.tree_split_layout import main_train, main_inference, get_arrangments_in_inches, split_textbox, to_inches
# from PosterAgent.gen_pptx_code import generate_poster_code
# from utils.src.utils import ppt_to_images
# from PosterAgent.gen_poster_content import gen_bullet_point_content
# from utils.ablation_utils import no_tree_get_layout
# Import refactored utilities
# from utils.logo_utils import LogoManager, add_logos_to_poster_code
# from utils.config_utils import (
# load_poster_yaml_config, extract_font_sizes, extract_colors,
# extract_vertical_alignment, extract_section_title_symbol, normalize_config_values
# )
# from utils.style_utils import apply_all_styles
# from utils.theme_utils import get_default_theme, create_theme_with_alignment, resolve_colors
# from PosterAgent.gen_beamer_code import (
# generate_beamer_poster_code,
# save_beamer_code,
# compile_beamer_to_pdf,
# convert_pptx_layout_to_beamer
# )
import argparse
import json
import time
import shutil
# Number of internal layout units that make up one inch. The script multiplies
# poster sizes given in inches by this value and passes it to to_inches() to
# convert back.
units_per_inch = 25


def to_inches(value_in_units, units_per_inch=72):
    """
    Express a length measured in arbitrary 'units' as inches.

    Args:
        value_in_units: A single coordinate or dimension in the source unit.
        units_per_inch: How many source units fit in one inch. Use 72 for
            points (the default) or 96 for pixels at 96 DPI. Note that this
            parameter shadows the module-level constant of the same name, so
            callers that want the module's scale must pass it explicitly.

    Returns:
        ``value_in_units`` divided by ``units_per_inch``.
    """
    inches = value_in_units / units_per_inch
    return inches
def _read_json(path):
    """Load and return the JSON document stored at ``path``.

    The file is opened in a ``with`` block so the handle is released even if
    parsing raises.
    """
    with open(path, 'r', encoding='utf-8') as json_file:
        return json.load(json_file)


def _lookup_figure_path(catalog, figure_id, path_key):
    """Return the path recorded for ``figure_id`` in ``catalog``, or ``None``.

    Args:
        catalog: Mapping of figure id -> info dict (the ``images`` or
            ``tables`` mapping).
        figure_id: Id taken from the outline's per-section figure assignment.
        path_key: Key that holds the path inside the info dict
            (``'image_path'`` or ``'table_path'``).

    The string form of the id is tried first and the id exactly as given
    second, so the lookup succeeds whether the catalog is keyed by strings
    (as any mapping round-tripped through JSON is) or by the raw id.
    """
    info = catalog.get(str(figure_id)) or catalog.get(figure_id) or {}
    return info.get(path_key)


def _build_simple_figure_arrangement(panels, figures, images, tables):
    """Build the reduced figure arrangement: one ``{panel_id, figure_path}`` per figure.

    Args:
        panels: Iterable of dicts carrying ``'section_name'`` and ``'panel_id'``.
        figures: Mapping of section name -> dict with an ``'image'`` or a
            ``'table'`` id.
        images: Mapping of image id -> dict with ``'image_path'``.
        tables: Mapping of table id -> dict with ``'table_path'``.

    Returns:
        A list of ``{'panel_id': ..., 'figure_path': ...}`` dicts, in the
        iteration order of ``figures``. Sections without a matching panel and
        figures whose path cannot be resolved are skipped.
    """
    # section_name -> panel_id
    panel_id_by_section = {panel['section_name']: panel['panel_id'] for panel in panels}

    arrangement = []
    for section_name, figure_info in figures.items():
        if section_name not in panel_id_by_section:
            continue
        if 'image' in figure_info:
            figure_path = _lookup_figure_path(images, figure_info['image'], 'image_path')
        elif 'table' in figure_info:
            figure_path = _lookup_figure_path(tables, figure_info['table'], 'table_path')
        else:
            figure_path = None
        if figure_path:  # only keep entries that actually have a path
            arrangement.append({
                'panel_id': panel_id_by_section[section_name],
                'figure_path': figure_path,
            })
    return arrangement


# Script entry point. Runs the first stages of the poster pipeline:
#   1. parse the paper and extract images/tables,
#   2. filter unnecessary images/tables,
#   3. generate the outline (panels + per-section figure assignment),
# then writes {panels, figure_arrangement} to tree_splits/*.json.
# Layout, content generation and rendering are currently commented out.
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Poster Generation Pipeline with Logo Support')
    # Required: every later step dereferences args.poster_path.
    parser.add_argument('--poster_path', type=str, required=True,
                        help='Path to the input paper PDF (e.g. "<paper_dir>/paper.pdf")')
    parser.add_argument('--model_name_t', type=str, default='4o')
    parser.add_argument('--model_name_v', type=str, default='4o')
    parser.add_argument('--index', type=int, default=0)
    parser.add_argument('--poster_name', type=str, default=None)
    parser.add_argument('--tmp_dir', type=str, default='tmp')
    parser.add_argument('--estimate_chars', action='store_true')
    parser.add_argument('--max_workers', type=int, default=10)
    parser.add_argument('--poster_width_inches', type=int, default=None)
    parser.add_argument('--poster_height_inches', type=int, default=None)
    parser.add_argument('--no_blank_detection', action='store_true', help='When overflow is severe, try this option.')
    parser.add_argument('--ablation_no_tree_layout', action='store_true', help='Ablation study: no tree layout')
    parser.add_argument('--ablation_no_commenter', action='store_true', help='Ablation study: no commenter')
    parser.add_argument('--ablation_no_example', action='store_true', help='Ablation study: no example')
    # Logo-related arguments
    parser.add_argument('--conference_venue', type=str, default=None,
                        help='Conference name for automatic logo search (e.g., "NeurIPS", "CVPR")')
    parser.add_argument('--institution_logo_path', type=str, default=None,
                        help='Custom path to institution logo (auto-searches from paper metadata if not provided)')
    parser.add_argument('--conference_logo_path', type=str, default=None,
                        help='Custom path to conference logo (auto-searches if venue specified)')
    parser.add_argument('--use_google_search', action='store_true',
                        help='Use Google Custom Search API for logo search (requires API keys in .env)')
    args = parser.parse_args()

    start_time = time.time()
    os.makedirs(args.tmp_dir, exist_ok=True)
    detail_log = {}
    agent_config_t = get_agent_config(args.model_name_t)
    agent_config_v = get_agent_config(args.model_name_v)

    # The poster name defaults to the paper's parent directory name. It is only
    # derived when --poster_name was not given, so an explicit name also works
    # for a path without a directory component.
    if args.poster_name is None:
        path_parts = args.poster_path.split('/')
        if len(path_parts) < 2:
            parser.error('--poster_name is required when --poster_path has no parent directory')
        args.poster_name = path_parts[-2].replace(' ', '_')
    poster_name = args.poster_name

    # Poster size, in internal units: explicit CLI size > meta.json > 48x36 in.
    meta_json_path = args.poster_path.replace('paper.pdf', 'meta.json')
    if args.poster_width_inches is not None and args.poster_height_inches is not None:
        poster_width = args.poster_width_inches * units_per_inch
        poster_height = args.poster_height_inches * units_per_inch
    elif meta_json_path != args.poster_path and os.path.exists(meta_json_path):
        # The inequality guards against paths that do not contain 'paper.pdf':
        # there the replace() is a no-op and the PDF itself would be read as JSON.
        meta_json = _read_json(meta_json_path)
        poster_width = meta_json['width']
        poster_height = meta_json['height']
    else:
        poster_width = 48 * units_per_inch
        poster_height = 36 * units_per_inch

    poster_width, poster_height = scale_to_target_area(poster_width, poster_height)
    poster_width_inches = to_inches(poster_width, units_per_inch)
    poster_height_inches = to_inches(poster_height, units_per_inch)

    # Cap the longer side at 56 inches, scaling both sides by the same factor.
    max_side_inches = 56
    longest_side_inches = max(poster_width_inches, poster_height_inches)
    if longest_side_inches > max_side_inches:
        scale_factor = max_side_inches / longest_side_inches
        poster_width_inches *= scale_factor
        poster_height_inches *= scale_factor
        # convert back to internal units
        poster_width = poster_width_inches * units_per_inch
        poster_height = poster_height_inches * units_per_inch

    print(f'Poster size: {poster_width_inches} x {poster_height_inches} inches')

    total_input_tokens_t, total_output_tokens_t = 0, 0
    total_input_tokens_v, total_output_tokens_v = 0, 0

    # Step 1: Parse the raw poster
    input_token, output_token, raw_result = parse_raw(args, agent_config_t, version=2)
    total_input_tokens_t += input_token
    total_output_tokens_t += output_token
    _, _, images, tables = gen_image_and_table(args, raw_result)
    print(f'Parsing token consumption: {input_token} -> {output_token}')
    parser_time_taken = time.time() - start_time
    print(f'Parser time: {parser_time_taken:.2f} seconds')
    detail_log['parser_time'] = parser_time_taken
    parser_time = time.time()
    detail_log['parser_in_t'] = input_token
    detail_log['parser_out_t'] = output_token

    # # Initialize LogoManager
    # logo_manager = LogoManager()
    # institution_logo_path = args.institution_logo_path
    # conference_logo_path = args.conference_logo_path
    # # Auto-detect institution from paper if not provided
    # # Now using the raw_result directly instead of reading from file
    # if not institution_logo_path:
    # print("\n" + "="*60)
    # print("🔍 AUTO-DETECTING INSTITUTION FROM PAPER")
    # print("="*60)
    # # Use the raw_result we already have from the parser
    # if raw_result:
    # print(f"📄 Using parsed paper content")
    # # Extract text content from the ConversionResult object
    # try:
    # paper_text = raw_result.document.export_to_markdown()
    # except:
    # # Fallback: try to get text content in another way
    # paper_text = str(raw_result)
    # print("🔎 Searching for FIRST AUTHOR's institution...")
    # first_author_inst = logo_manager.extract_first_author_institution(paper_text)
    # if first_author_inst:
    # print(f"\n✅ FIRST AUTHOR INSTITUTION: {first_author_inst}")
    # print(f"🔍 Searching for logo: {first_author_inst}")
    # inst_logo_path = logo_manager.get_logo_path(first_author_inst, category="institute", use_google=args.use_google_search)
    # if inst_logo_path:
    # institution_logo_path = str(inst_logo_path)
    # print(f"✅ Institution logo found: {institution_logo_path}")
    # else:
    # print(f"❌ Could not find/download logo for: {first_author_inst}")
    # else:
    # print("❌ No first author institution detected or matched with available logos")
    # else:
    # print("❌ No parsed content available")
    # print("="*60 + "\n")
    # # Handle conference logo
    # if args.conference_venue and not conference_logo_path:
    # print("\n" + "="*60)
    # print("🏛️ SEARCHING FOR CONFERENCE LOGO")
    # print("="*60)
    # print(f"📍 Conference: {args.conference_venue}")
    # print(f"🔍 Searching for logo...")
    # conf_logo_path = logo_manager.get_logo_path(args.conference_venue, category="conference", use_google=args.use_google_search)
    # if conf_logo_path:
    # conference_logo_path = str(conf_logo_path)
    # print(f"✅ Conference logo found: {conference_logo_path}")
    # else:
    # print(f"❌ Could not find/download logo for: {args.conference_venue}")
    # # Note: Web search is now handled inside get_logo_path automatically
    # print("="*60 + "\n")

    # Step 2: Filter unnecessary images and tables
    input_token, output_token = filter_image_table(args, agent_config_t)
    total_input_tokens_t += input_token
    total_output_tokens_t += output_token
    print(f'Filter figures token consumption: {input_token} -> {output_token}')
    filter_time_taken = time.time() - parser_time
    print(f'Filter time: {filter_time_taken:.2f} seconds')
    detail_log['filter_time'] = filter_time_taken
    filter_time = time.time()
    detail_log['filter_in_t'] = input_token
    detail_log['filter_out_t'] = output_token

    # Step 3: Generate outline
    input_token, output_token, panels, figures = gen_outline_layout_v2(args, agent_config_t)
    total_input_tokens_t += input_token
    total_output_tokens_t += output_token
    print(f'Outline token consumption: {input_token} -> {output_token}')
    outline_time_taken = time.time() - filter_time
    print(f'Outline time: {outline_time_taken:.2f} seconds')
    detail_log['outline_time'] = outline_time_taken
    outline_time = time.time()
    detail_log['outline_in_t'] = input_token
    detail_log['outline_out_t'] = output_token

    # if args.ablation_no_tree_layout:
    # panel_arrangement, figure_arrangement, text_arrangement, input_token, output_token = no_tree_get_layout(
    # poster_width,
    # poster_height,
    # panels,
    # figures,
    # agent_config_t
    # )
    # total_input_tokens_t += input_token
    # total_output_tokens_t += output_token
    # print(f'No tree layout token consumption: {input_token} -> {output_token}')
    # detail_log['no_tree_layout_in_t'] = input_token
    # detail_log['no_tree_layout_out_t'] = output_token
    # else:
    # # Step 4: Learn and generate layout
    # panel_model_params, figure_model_params = main_train()
    # panel_arrangement, figure_arrangement, text_arrangement = main_inference(
    # panels,
    # panel_model_params,
    # figure_model_params,
    # poster_width,
    # poster_height,
    # shrink_margin=3
    # )
    # text_arrangement_title = text_arrangement[0]
    # text_arrangement = text_arrangement[1:]
    # # Split the title textbox into two parts
    # text_arrangement_title_top, text_arrangement_title_bottom = split_textbox(
    # text_arrangement_title,
    # 0.8
    # )
    # # Add the split textboxes back to the list
    # text_arrangement = [text_arrangement_title_top, text_arrangement_title_bottom] + text_arrangement
    # for i in range(len(figure_arrangement)):
    # panel_id = figure_arrangement[i]['panel_id']
    # panel_section_name = panels[panel_id]['section_name']
    # figure_info = figures[panel_section_name]
    # if 'image' in figure_info:
    # figure_id = figure_info['image']
    # if not figure_id in images:
    # figure_path = images[str(figure_id)]['image_path']
    # else:
    # figure_path = images[figure_id]['image_path']
    # elif 'table' in figure_info:
    # figure_id = figure_info['table']
    # if not figure_id in tables:
    # figure_path = tables[str(figure_id)]['table_path']
    # else:
    # figure_path = tables[figure_id]['table_path']
    # figure_arrangement[i]['figure_path'] = figure_path
    # for text_arrangement_item in text_arrangement:
    # num_chars = char_capacity(
    # bbox=(text_arrangement_item['x'], text_arrangement_item['y'], text_arrangement_item['height'], text_arrangement_item['width'])
    # )
    # text_arrangement_item['num_chars'] = num_chars
    # width_inch, height_inch, panel_arrangement_inches, figure_arrangement_inches, text_arrangement_inches = get_arrangments_in_inches(
    # poster_width, poster_height, panel_arrangement, figure_arrangement, text_arrangement, 25
    # )
    # # Save to file
    # tree_split_results = {
    # 'poster_width': poster_width,
    # 'poster_height': poster_height,
    # 'poster_width_inches': width_inch,
    # 'poster_height_inches': height_inch,
    # 'panels': panels,
    # 'panel_arrangement': panel_arrangement,
    # 'figure_arrangement': figure_arrangement,
    # 'text_arrangement': text_arrangement,
    # 'panel_arrangement_inches': panel_arrangement_inches,
    # 'figure_arrangement_inches': figure_arrangement_inches,
    # 'text_arrangement_inches': text_arrangement_inches,
    # }

    # ============================
    # ### NEW: only build a simple figure_arrangement with {panel_id, figure_path}
    # ============================
    # `images` and `tables` are always bound by Step 1 above, so no fallback
    # load from the filtered-results JSON is needed here.
    # Reduced figure_arrangement: keeps only panel_id and figure_path.
    simple_figure_arrangement = _build_simple_figure_arrangement(panels, figures, images, tables)

    # ============================
    # ### REMOVED: no layout/train/text capacity/inches conversion
    # - removed the args.ablation_no_tree_layout branch
    # - removed main_train() / main_inference()
    # - removed the loop that filled figure_path into figure_arrangement[i]
    # - removed text_arrangement / char_capacity / get_arrangments_in_inches
    # ============================
    # Save to file (only panels + figure_arrangement are kept)
    tree_split_results = {
        'panels': panels,
        'figure_arrangement': simple_figure_arrangement,
    }
    os.makedirs('tree_splits', exist_ok=True)
    tree_split_path = f'tree_splits/<{args.model_name_t}_{args.model_name_v}>_{args.poster_name}_tree_split_{args.index}.json'
    with open(tree_split_path, 'w') as tree_split_file:
        json.dump(tree_split_results, tree_split_file, indent=4)

    layout_time_taken = time.time() - outline_time
    print(f'Layout time: {layout_time_taken:.2f} seconds')
    detail_log['layout_time'] = layout_time_taken
    layout_time = time.time()
# # === Configuration Loading ===
# print("\n📋 Loading configuration from YAML files...", flush=True)
# yaml_cfg = load_poster_yaml_config(args.poster_path)
# # Extract configuration values
# bullet_fs, title_fs, poster_title_fs, poster_author_fs = extract_font_sizes(yaml_cfg)
# title_text_color, title_fill_color, main_text_color, main_text_fill_color = extract_colors(yaml_cfg)
# section_title_vertical_align = extract_vertical_alignment(yaml_cfg)
# section_title_symbol = extract_section_title_symbol(yaml_cfg)
# # Normalize configuration values
# bullet_fs, title_fs, poster_title_fs, poster_author_fs, \
# title_text_color, title_fill_color, main_text_color, main_text_fill_color = normalize_config_values(
# bullet_fs, title_fs, poster_title_fs, poster_author_fs,
# title_text_color, title_fill_color, main_text_color, main_text_fill_color
# )
# # Store configuration in args
# setattr(args, 'bullet_font_size', bullet_fs)
# setattr(args, 'section_title_font_size', title_fs)
# setattr(args, 'poster_title_font_size', poster_title_fs)
# setattr(args, 'poster_author_font_size', poster_author_fs)
# setattr(args, 'title_text_color', title_text_color)
# setattr(args, 'title_fill_color', title_fill_color)
# setattr(args, 'main_text_color', main_text_color)
# setattr(args, 'main_text_fill_color', main_text_fill_color)
# setattr(args, 'section_title_vertical_align', section_title_vertical_align)
# # Step 5: Generate content
# print(f"\n✍️ Generating poster content (max_workers={args.max_workers})...", flush=True)
# # --- Step 1: check the cache ---
# content_cache_path = f'contents/<{args.model_name_t}_{args.model_name_v}>_{args.poster_name}_bullet_point_content_{args.index}.json'
# if os.path.exists(content_cache_path):
# print(f"🧩 Cache found: {content_cache_path}")
# print("⚡ Skipping model generation, loading from cache...")
# bullet_content = json.load(open(content_cache_path, 'r'))
# input_token_t = output_token_t = input_token_v = output_token_v = 0
# else:
# print("🧠 Running model to generate poster content...")
# input_token_t, output_token_t, input_token_v, output_token_v = gen_bullet_point_content(
# args, agent_config_t, agent_config_v, tmp_dir=args.tmp_dir
# )
# bullet_content = json.load(open(content_cache_path, 'r'))
# input_token_t, output_token_t, input_token_v, output_token_v = gen_bullet_point_content(args, agent_config_t, agent_config_v, tmp_dir=args.tmp_dir)
# total_input_tokens_t += input_token_t
# total_output_tokens_t += output_token_t
# total_input_tokens_v += input_token_v
# total_output_tokens_v += output_token_v
# print(f'Content generation token consumption T: {input_token_t} -> {output_token_t}')
# print(f'Content generation token consumption V: {input_token_v} -> {output_token_v}')
# content_time_taken = time.time() - layout_time
# print(f'Content generation time: {content_time_taken:.2f} seconds')
# detail_log['content_time'] = content_time_taken
# content_time = time.time()
# bullet_content = json.load(open(f'contents/<{args.model_name_t}_{args.model_name_v}>_{args.poster_name}_bullet_point_content_{args.index}.json', 'r'))
# detail_log['content_in_t'] = input_token_t
# detail_log['content_out_t'] = output_token_t
# detail_log['content_in_v'] = input_token_v
# detail_log['content_out_v'] = output_token_v
# # === Style Application ===
# print("\n🎨 Applying styles and colors...", flush=True)
# # Resolve colors with fallbacks
# final_title_text_color, final_title_fill_color, final_main_text_color, final_main_text_fill_color = resolve_colors(
# getattr(args, 'title_text_color', None),
# getattr(args, 'title_fill_color', None),
# getattr(args, 'main_text_color', None),
# getattr(args, 'main_text_fill_color', None)
# )
# # Apply all styles in one go
# bullet_content = apply_all_styles(
# bullet_content,
# title_text_color=final_title_text_color,
# title_fill_color=final_title_fill_color,
# main_text_color=final_main_text_color,
# main_text_fill_color=final_main_text_fill_color,
# section_title_symbol=section_title_symbol,
# main_text_font_size=bullet_fs
# )
# # === Poster Generation ===
# # print("\n🎯 Generating PowerPoint code...", flush=True)
# # Create theme with alignment
# base_theme = get_default_theme()
# theme_with_alignment = create_theme_with_alignment(
# base_theme,
# getattr(args, 'section_title_vertical_align', None)
# )
# # poster_code = generate_poster_code(
# # panel_arrangement_inches,
# # text_arrangement_inches,
# # figure_arrangement_inches,
# # presentation_object_name='poster_presentation',
# # slide_object_name='poster_slide',
# # utils_functions=utils_functions,
# # slide_width=width_inch,
# # slide_height=height_inch,
# # img_path=None,
# # save_path=f'{args.tmp_dir}/poster.pptx',
# # visible=False,
# # content=bullet_content,
# # theme=theme_with_alignment,
# # tmp_dir=args.tmp_dir,
# # )
# print("\n🎯 Generating Beamer poster (LaTeX)...", flush=True)
# # --- 1. Extract poster_info ---
# poster_info = {
# "title": args.poster_name,
# "author": "AutoGen",
# "institute": "Auto-detected Institution"
# }
# if isinstance(bullet_content, list) and len(bullet_content) > 0:
# first_section = bullet_content[0]
# if isinstance(first_section, dict):
# if "poster_title" in first_section:
# poster_info["title"] = first_section["poster_title"]
# elif "title" in first_section:
# poster_info["title"] = first_section["title"]
# --- 2. Build the Beamer data structure ---
# layout_data = {
# "text_arrangement": text_arrangement,
# "figure_arrangement": figure_arrangement
# }
# beamer_data = convert_pptx_layout_to_beamer(layout_data)
# Map bullet_content into sections
# for i, section in enumerate(beamer_data["sections"]):
# if i < len(bullet_content):
# section_data = bullet_content[i]
# if isinstance(section_data, dict):
# section["content"] = section_data.get("textbox1") or section_data.get("title") or json.dumps(section_data)
# else:
# section["content"] = str(section_data)
# --- 3. Generate the LaTeX file ---
# poster_info = {k: (str(v) if not isinstance(v, str) else v) for k, v in poster_info.items()}
# beamer_code = generate_beamer_poster_code(
# beamer_data["sections"],
# beamer_data["figures"],
# poster_info,
# width_cm=poster_width_inches * 2.54,
# height_cm=poster_height_inches * 2.54,
# theme="Madrid",
# output_path=f"{args.tmp_dir}/{poster_name}.tex"
# )
# save_beamer_code(beamer_code, f"{args.tmp_dir}/{poster_name}.tex")
# --- 4. Compile to PDF ---
# output_dir = f'<{args.model_name_t}_{args.model_name_v}>_generated_beamer_posters/{args.poster_path.replace("paper.pdf", "")}'
# compile_beamer_to_pdf(f"{args.tmp_dir}/{poster_name}.tex", output_dir=args.tmp_dir)
# pdf_path = os.path.join(args.tmp_dir, f"{poster_name}.pdf")
# os.makedirs(output_dir, exist_ok=True)
# os.rename(pdf_path, os.path.join(output_dir, f"{poster_name}.pdf"))
# print(f"✅ Beamer poster PDF saved to {output_dir}")
# Add logos to the poster
# print("\n🖼️ Adding logos to poster...", flush=True)
# poster_code = add_logos_to_poster_code(
# poster_code,
# width_inch,
# height_inch,
# institution_logo_path=institution_logo_path,
# conference_logo_path=conference_logo_path
# )
# output, err = run_code(poster_code)
# if err is not None:
# raise RuntimeError(f'Error in generating PowerPoint: {err}')
# # Step 8: Create a folder in the output directory
# output_dir = f'<{args.model_name_t}_{args.model_name_v}>_generated_posters/{args.poster_path.replace("paper.pdf", "")}'
# os.makedirs(output_dir, exist_ok=True)
# # Copy logos to output directory for reference
# logos_dir = os.path.join(output_dir, 'logos')
# if institution_logo_path or conference_logo_path:
# os.makedirs(logos_dir, exist_ok=True)
# if institution_logo_path and os.path.exists(institution_logo_path):
# shutil.copy2(institution_logo_path, os.path.join(logos_dir, 'institution_logo' + os.path.splitext(institution_logo_path)[1]))
# if conference_logo_path and os.path.exists(conference_logo_path):
# shutil.copy2(conference_logo_path, os.path.join(logos_dir, 'conference_logo' + os.path.splitext(conference_logo_path)[1]))
# # Step 9: Move poster.pptx to the output directory
# pptx_path = os.path.join(output_dir, f'{poster_name}.pptx')
# os.rename(f'{args.tmp_dir}/poster.pptx', pptx_path)
# print(f'Poster PowerPoint saved to {pptx_path}')
# # Step 10: Convert the PowerPoint to images
# ppt_to_images(pptx_path, output_dir)
# print(f'Poster images saved to {output_dir}')
# end_time = time.time()
# time_taken = end_time - start_time
# render_time_taken = time.time() - content_time
# print(f'Render time: {render_time_taken:.2f} seconds')
# detail_log['render_time'] = render_time_taken
# # log
# log_file = os.path.join(output_dir, 'log.json')
# with open(log_file, 'w') as f:
# log_data = {
# 'input_tokens_t': total_input_tokens_t,
# 'output_tokens_t': total_output_tokens_t,
# 'input_tokens_v': total_input_tokens_v,
# 'output_tokens_v': total_output_tokens_v,
# 'time_taken': time_taken,
# 'institution_logo': institution_logo_path,
# 'conference_logo': conference_logo_path,
# }
# json.dump(log_data, f, indent=4)
# detail_log_file = os.path.join(output_dir, 'detail_log.json')
# with open(detail_log_file, 'w') as f:
# json.dump(detail_log, f, indent=4)
# print(f'\nTotal time: {time_taken:.2f} seconds')
# print(f'Total text model tokens: {total_input_tokens_t} -> {total_output_tokens_t}')
# print(f'Total vision model tokens: {total_input_tokens_v} -> {total_output_tokens_v}')
# if institution_logo_path:
# print(f'Institution logo added: {institution_logo_path}')
# if conference_logo_path:
# print(f'Conference logo added: {conference_logo_path}')