| from dotenv import load_dotenv | |
| import os | |
| from utils.src.utils import ppt_to_images, get_json_from_response | |
| import json | |
| import pptx | |
| from camel.models import ModelFactory | |
| from camel.types import ModelPlatformType, ModelType | |
| from camel.configs import ChatGPTConfig, QwenConfig | |
| from camel.agents import ChatAgent | |
| from utils.wei_utils import fill_content | |
| from camel.messages import BaseMessage | |
| from PIL import Image | |
| import pickle as pkl | |
| from utils.pptx_utils import * | |
| from utils.critic_utils import * | |
| from utils.wei_utils import * | |
| import importlib | |
| import yaml | |
| import os | |
| import shutil | |
| from datetime import datetime | |
| from jinja2 import Environment, StrictUndefined, Template | |
| import argparse | |
| load_dotenv() | |
def fill_poster_content(args, actor_config):
    """Fill every poster section with its content, one section at a time.

    Loads the poster content JSON and the layout checkpoint for this run,
    renders the content-filler prompt for each section (skipping the ``meta``
    outline entry) and hands it to ``fill_content``. The previous section's
    ``concatenated_code`` is passed along as the existing code for the next
    one. After each section ``poster.pptx`` is copied into ``tmp/``; after the
    last section that copy is rendered to images and a new content checkpoint
    is written.

    Args:
        args: Parsed CLI namespace; ``model_name``, ``poster_name`` and
            ``index`` are read to build the input and output file names.
        actor_config: Mapping with ``model_platform``, ``model_type`` and
            ``model_config`` entries, forwarded to ``ModelFactory.create``.

    Returns:
        A ``(total_input_token, total_output_token)`` tuple summed over the
        final log entry of every section.

    Raises:
        Exception: If the final log entry of a section reports an error.
    """
    total_input_token, total_output_token = 0, 0

    content_path = f'contents/{args.model_name}_{args.poster_name}_poster_content_{args.index}.json'
    with open(content_path, 'r') as f:
        poster_content = json.load(f)

    agent_name = 'content_filler_agent'
    with open(f"prompt_templates/{agent_name}.yaml", "r") as f:
        fill_config = yaml.safe_load(f)

    actor_model = ModelFactory.create(
        model_platform=actor_config['model_platform'],
        model_type=actor_config['model_type'],
        model_config_dict=actor_config['model_config'],
    )
    actor_agent = ChatAgent(
        system_message=fill_config['system_prompt'],
        model=actor_model,
        message_window_size=10,
    )

    # NOTE: unpickling can execute arbitrary code; only load checkpoints that
    # were produced by this pipeline on a trusted machine.
    layout_ckpt_path = f'checkpoints/{args.model_name}_{args.poster_name}_ckpt_{args.index}.pkl'
    with open(layout_ckpt_path, 'rb') as f:
        ckpt = pkl.load(f)
    logs = ckpt['logs']
    outline = ckpt['outline']
    sections = [name for name in outline if name != 'meta']

    jinja_env = Environment(undefined=StrictUndefined)
    template = jinja_env.from_string(fill_config["template"])

    content_logs = {}
    previous_section = None
    for section_name in sections:
        section_code = logs[section_name][-1]['code']
        print(f'Filling content for {section_name}')

        prompt = template.render(
            content_json=poster_content[section_name],
            function_docs=documentation,
            existing_code=section_code,
        )

        # The first section starts from nothing; later ones continue from the
        # code accumulated up to the previous section.
        if previous_section is None:
            existing_code = ''
        else:
            existing_code = content_logs[previous_section][-1]['concatenated_code']

        content_logs[section_name] = fill_content(
            actor_agent,
            prompt,
            3,  # presumably the retry budget (cf. --max_retry) -- TODO confirm
            existing_code,
        )
        last_attempt = content_logs[section_name][-1]

        # Snapshot before the error check so a failing poster is kept too.
        shutil.copy('poster.pptx', f'tmp/content_poster_<{section_name}>.pptx')

        if last_attempt['error'] is not None:
            raise Exception(f'Error in filling content for {section_name}: {last_attempt["error"]}')

        total_input_token += last_attempt['cumulative_tokens'][0]
        total_output_token += last_attempt['cumulative_tokens'][1]
        previous_section = section_name

    ppt_to_images(f'tmp/content_poster_<{sections[-1]}>.pptx', 'tmp/content_preview')

    ckpt = {
        'logs': logs,
        'content_logs': content_logs,
        'outline': outline,
        'total_input_token': total_input_token,
        'total_output_token': total_output_token,
    }

    # Close the handle explicitly so the checkpoint is fully flushed to disk
    # before anything else tries to read it back.
    content_ckpt_path = f'checkpoints/{args.model_name}_{args.poster_name}_content_ckpt_{args.index}.pkl'
    with open(content_ckpt_path, 'wb') as f:
        pkl.dump(ckpt, f)

    return total_input_token, total_output_token
def stylize_poster(args, actor_config):
    """Apply styling to every content-filled poster section.

    Loads the poster content JSON and the content checkpoint for this run,
    renders the style prompt for each section (skipping the ``meta`` outline
    entry) and hands it to ``stylize``. The previous section's
    ``concatenated_code`` is passed along as the existing code for the next
    one. After each section ``poster.pptx`` is copied into ``tmp/``; after the
    last section that copy is rendered to images and a style checkpoint is
    written.

    Args:
        args: Parsed CLI namespace; ``model_name``, ``poster_name``, ``index``
            and ``max_retry`` are read.
        actor_config: Mapping with ``model_platform``, ``model_type`` and
            ``model_config`` entries, forwarded to ``ModelFactory.create``.

    Returns:
        A ``(total_input_token, total_output_token)`` tuple summed over the
        final log entry of every section.

    Raises:
        Exception: If the final log entry of a section reports an error.
    """
    total_input_token, total_output_token = 0, 0

    content_path = f'contents/{args.model_name}_{args.poster_name}_poster_content_{args.index}.json'
    with open(content_path, 'r') as f:
        poster_content = json.load(f)

    agent_name = 'style_agent'
    with open(f"prompt_templates/{agent_name}.yaml", "r") as f:
        style_config = yaml.safe_load(f)

    actor_model = ModelFactory.create(
        model_platform=actor_config['model_platform'],
        model_type=actor_config['model_type'],
        model_config_dict=actor_config['model_config'],
    )
    actor_agent = ChatAgent(
        system_message=style_config['system_prompt'],
        model=actor_model,
        message_window_size=10,
    )

    # NOTE: unpickling can execute arbitrary code; only load checkpoints that
    # were produced by this pipeline on a trusted machine.
    content_ckpt_path = f'checkpoints/{args.model_name}_{args.poster_name}_content_ckpt_{args.index}.pkl'
    with open(content_ckpt_path, 'rb') as f:
        ckpt = pkl.load(f)
    content_logs = ckpt['content_logs']
    outline = ckpt['outline']
    sections = [name for name in outline if name != 'meta']

    jinja_env = Environment(undefined=StrictUndefined)
    template = jinja_env.from_string(style_config["template"])

    style_logs = {}
    previous_section = None
    for section_name in sections:
        section_code = content_logs[section_name][-1]['code']
        print(f'Stylizing for {section_name}')

        img_ratio_json = get_img_ratio_in_section(poster_content[section_name])
        prompt = template.render(
            content_json=poster_content[section_name],
            function_docs=documentation,
            existing_code=section_code,
            image_ratio=img_ratio_json,
        )

        # The first section starts from nothing; later ones continue from the
        # code accumulated up to the previous section.
        if previous_section is None:
            existing_code = ''
        else:
            existing_code = style_logs[previous_section][-1]['concatenated_code']

        style_logs[section_name] = stylize(
            actor_agent,
            prompt,
            args.max_retry,
            existing_code,
        )
        last_attempt = style_logs[section_name][-1]

        # Snapshot before the error check so a failing poster is kept too.
        shutil.copy('poster.pptx', f'tmp/style_poster_<{section_name}>.pptx')

        if last_attempt['error'] is not None:
            # Include the underlying error so the failure can be diagnosed
            # from the exception alone.
            raise Exception(f'Error in stylizing for {section_name}: {last_attempt["error"]}')

        total_input_token += last_attempt['cumulative_tokens'][0]
        total_output_token += last_attempt['cumulative_tokens'][1]
        previous_section = section_name

    ppt_to_images(f'tmp/style_poster_<{sections[-1]}>.pptx', 'tmp/style_preview')

    ckpt = {
        'logs': ckpt['logs'],
        'content_logs': content_logs,
        'style_logs': style_logs,
        'outline': outline,
        'total_input_token': total_input_token,
        'total_output_token': total_output_token,
    }

    with open(f'checkpoints/{args.model_name}_{args.poster_name}_style_ckpt_{args.index}.pkl', 'wb') as f:
        pkl.dump(ckpt, f)

    return total_input_token, total_output_token
if __name__ == '__main__':
    # Script entry point: fill the poster content, stylize it, then report
    # the combined token usage of both stages.
    parser = argparse.ArgumentParser()
    cli_options = (
        ('--poster_name', {'type': str, 'default': None}),
        ('--model_name', {'type': str, 'default': '4o'}),
        ('--poster_path', {'type': str, 'required': True}),
        ('--index', {'type': int, 'default': 0}),
        ('--max_retry', {'type': int, 'default': 3}),
    )
    for flag, spec in cli_options:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()

    actor_config = get_agent_config(args.model_name)

    # Without an explicit name, derive one from the last path component:
    # drop '.pdf' and turn spaces into underscores.
    if args.poster_name is None:
        pdf_file = args.poster_path.split('/')[-1]
        args.poster_name = pdf_file.replace('.pdf', '').replace(' ', '_')

    fill_in, fill_out = fill_poster_content(args, actor_config)
    style_in, style_out = stylize_poster(args, actor_config)

    total_input_token = fill_in + style_in
    total_output_token = fill_out + style_out
    print(f'Token consumption: {total_input_token} -> {total_output_token}')