AutoPage / ProjectPageAgent /html_generator.py
Mqleet's picture
upd code
fcaa164
raw
history blame
25.4 kB
"""
HTML generator for project page generation.
Generates the final HTML project page from planned content.
"""
import json
import yaml
import os
import io
import re
import json
import yaml
from pathlib import Path
from urllib.parse import urlparse
from datetime import datetime
from jinja2 import Environment, StrictUndefined
from camel.models import ModelFactory
from camel.agents import ChatAgent
from utils.wei_utils import get_agent_config, account_token
from utils.src.utils import get_json_from_response, extract_html_code_block
from ProjectPageAgent.css_checker import check_css
from utils.src.utils import run_sync_screenshots
from PIL import Image
from camel.messages import BaseMessage
from camel.models import ModelFactory
def to_url(input_path_or_url: str) -> str:
parsed = urlparse(input_path_or_url)
if parsed.scheme in ("http", "https", "file"):
return input_path_or_url
p = Path(input_path_or_url).expanduser().resolve()
if not p.exists():
raise FileNotFoundError(f"Input not found: {p}")
return p.as_uri() # file://...
def crop_image_to_max_size(image_path, max_bytes=8*1024*1024, output_path=None):
img = Image.open(image_path)
img_format = img.format
if output_path is None:
output_path = image_path
buffer = io.BytesIO()
img.save(buffer, format=img_format)
size = buffer.getbuffer().nbytes
if size <= max_bytes:
img.save(output_path, format=img_format)
return output_path
width, height = img.size
scale = max_bytes / size
new_height = max(int(height * scale), 1)
img_cropped = img.crop((0, 0, width, new_height))
img_cropped.save(output_path, format=img_format)
return output_path
class ProjectPageHTMLGenerator:
"""Generates HTML project pages from planned content."""
def __init__(self, agent_config,args):
self.agent_config = agent_config
self.args = args
self.html_agent = self._create_html_agent()
self.review_agent = self._create_review_agent()
self.table_agent = self._create_table_agent()
self.long_agent = self._create_long_agent()
# self.client = OpenAI(api_key=api_key,base_url=api_url)
def _create_html_agent(self):
"""Create the HTML generation agent."""
model_type = str(self.agent_config['model_type'])
# Get API key from environment variables
api_key = None
if self.args.model_name_t in ['4o', '4o-mini', 'gpt-4.1', 'gpt-4.1-mini', 'o1', 'o3', 'o3-mini']:
api_key = os.environ.get('OPENAI_API_KEY')
elif self.args.model_name_t in ['gemini', 'gemini-2.5-pro', 'gemini-2.5-flash']:
api_key = os.environ.get('GEMINI_API_KEY')
elif self.args.model_name_t in ['qwen', 'qwen-plus', 'qwen-max', 'qwen-long']:
api_key = os.environ.get('QWEN_API_KEY')
elif self.args.model_name_t.startswith('openrouter_'):
api_key = os.environ.get('OPENROUTER_API_KEY')
elif self.args.model_name_t in ['zhipuai']:
api_key = os.environ.get('ZHIPUAI_API_KEY')
if model_type.startswith('vllm_qwen') or 'vllm' in model_type.lower():
model = ModelFactory.create(
model_platform=self.agent_config['model_platform'],
model_type=self.agent_config['model_type'],
model_config_dict=self.agent_config['model_config'],
url=self.agent_config.get('url', None),
api_key=api_key,
)
else:
model = ModelFactory.create(
model_platform=self.agent_config['model_platform'],
model_type=self.agent_config['model_type'],
model_config_dict=self.agent_config['model_config'],
api_key=api_key,
)
system_message = """You are an expert web developer specializing in creating professional project pages for research papers.
You have extensive experience in HTML5, CSS3, responsive design, and academic content presentation.
Your goal is to create engaging, well-structured, and visually appealing project pages."""
return ChatAgent(
system_message=system_message,
model=model,
message_window_size=10
)
def _create_review_agent(self):
with open('utils/prompt_templates/page_templates/html_review.yaml', 'r') as f:
prompt_config = yaml.safe_load(f)
jinja_env = Environment(undefined=StrictUndefined)
system_message_template = jinja_env.from_string(prompt_config["system_prompt"])
system_message = system_message_template.render()
model_type = self.args.model_name_v
# Get API key from environment variables
api_key = None
if self.args.model_name_v in ['4o', '4o-mini', 'gpt-4.1', 'gpt-4.1-mini', 'o1', 'o3', 'o3-mini']:
api_key = os.environ.get('OPENAI_API_KEY')
elif self.args.model_name_v in ['gemini', 'gemini-2.5-pro', 'gemini-2.5-flash']:
api_key = os.environ.get('GEMINI_API_KEY')
elif self.args.model_name_v in ['qwen', 'qwen-plus', 'qwen-max', 'qwen-long']:
api_key = os.environ.get('QWEN_API_KEY')
elif self.args.model_name_v.startswith('openrouter_'):
api_key = os.environ.get('OPENROUTER_API_KEY')
elif self.args.model_name_v in ['zhipuai']:
api_key = os.environ.get('ZHIPUAI_API_KEY')
config = get_agent_config(model_type)
model = ModelFactory.create(
model_platform=config['model_platform'],
model_type=config['model_type'],
model_config_dict=config['model_config'],
url=config.get('url', None),
api_key=api_key,
)
return ChatAgent(
system_message=system_message,
model=model,
message_window_size=10
)
def _create_table_agent(self):
model_type = self.args.model_name_v
# Get API key from environment variables
api_key = None
if self.args.model_name_v in ['4o', '4o-mini', 'gpt-4.1', 'gpt-4.1-mini', 'o1', 'o3', 'o3-mini']:
api_key = os.environ.get('OPENAI_API_KEY')
elif self.args.model_name_v in ['gemini', 'gemini-2.5-pro', 'gemini-2.5-flash']:
api_key = os.environ.get('GEMINI_API_KEY')
elif self.args.model_name_v in ['qwen', 'qwen-plus', 'qwen-max', 'qwen-long']:
api_key = os.environ.get('QWEN_API_KEY')
elif self.args.model_name_v.startswith('openrouter_'):
api_key = os.environ.get('OPENROUTER_API_KEY')
elif self.args.model_name_v in ['zhipuai']:
api_key = os.environ.get('ZHIPUAI_API_KEY')
vlm_config = get_agent_config(model_type)
vlm_model = ModelFactory.create(
model_platform=vlm_config['model_platform'],
model_type=vlm_config['model_type'],
model_config_dict=vlm_config['model_config'],
url=vlm_config.get('url', None),
api_key=api_key,
)
return ChatAgent(
system_message=None,
model=vlm_model,
message_window_size=10,
)
def _create_long_agent(self):
model_type = self.args.model_name_t
# Get API key from environment variables
api_key = None
if self.args.model_name_t in ['4o', '4o-mini', 'gpt-4.1', 'gpt-4.1-mini', 'o1', 'o3', 'o3-mini']:
api_key = os.environ.get('OPENAI_API_KEY')
elif self.args.model_name_t in ['gemini', 'gemini-2.5-pro', 'gemini-2.5-flash']:
api_key = os.environ.get('GEMINI_API_KEY')
elif self.args.model_name_t in ['qwen', 'qwen-plus', 'qwen-max', 'qwen-long']:
api_key = os.environ.get('QWEN_API_KEY')
elif self.args.model_name_t.startswith('openrouter_'):
api_key = os.environ.get('OPENROUTER_API_KEY')
elif self.args.model_name_t in ['zhipuai']:
api_key = os.environ.get('ZHIPUAI_API_KEY')
long_config = get_agent_config(model_type)
long_model = ModelFactory.create(
model_platform=long_config['model_platform'],
model_type=long_config['model_type'],
model_config_dict=long_config['model_config'],
url=long_config.get('url', None),
api_key=api_key,
)
return ChatAgent(
system_message=None,
model=long_model,
message_window_size=10,
token_limit=long_config.get('token_limit', None)
)
def render_html_to_png(self, iter, html_content, project_output_dir) -> str:
import time
tmp_html = Path(project_output_dir) / f"index_iter{iter}.html"
tmp_html.write_text(html_content, encoding="utf-8")
url = tmp_html.resolve().as_uri()
image_path = str(Path(project_output_dir) / f"page_iter{iter}.png")
run_sync_screenshots(url, image_path)
return image_path
def get_revision_suggestions(self, image_path: str, html_path) -> str:
def crop_image_max_width(img, max_width=1280):
width, height = img.size
if width > max_width:
img = img.crop((0, 0, max_width, height)) # (left, top, right, bottom)
return img
img = Image.open(image_path)
img = crop_image_max_width(img, max_width=1280)
img.save(image_path,format='PNG')
crop_image_to_max_size(image_path=image_path,output_path=image_path)
img =Image.open(image_path)
message = BaseMessage.make_user_message(
role_name="User",
content = '\nHere is the image of the generated project page.',
image_list=[img]
)
response = self.review_agent.step(message)
return get_json_from_response(response.msgs[0].content.strip())
def modify_html_table(self, html_content: str,html_dir: str):
in_tokens, out_tokens = 0, 0
print("Starting table modification...")
def replace_tables_in_html(html_content, table_html_map, paper_name):
pattern = rf'<img[^>]*src="(assets/{paper_name}-table-\d+\.png)"[^>]*>'
def repl(match):
img_path = match.group(1) # e.g. assets/MambaFusion-table-10.png
if img_path in table_html_map:
return table_html_map[img_path]
return match.group(0)
return re.sub(pattern, repl, html_content)
# ============ step 1 extract table ============
pattern = rf"assets/{self.args.paper_name}-table-\d+\.png"
with open(os.path.join(self.args.output_dir,self.args.paper_name, html_dir,'index_no_modify_table.html'), 'r', encoding='utf-8') as f:
html_content = f.read()
matches = re.findall(pattern, html_content)
if matches is None:
print("No table images found, skipping modification.")
return None, 0, 0
model_type = self.args.model_name_v
print(f"Starting table modification phase 1: Table Extraction with {model_type}...")
with open('utils/prompt_templates/page_templates/extract_table.yaml', 'r') as f:
table_extraction_config = yaml.safe_load(f)
content = table_extraction_config["system_prompt"]
init_message = BaseMessage.make_user_message(
role_name="User",
content=content
)
response = self.table_agent.step(init_message)
in_tok , out_tok = account_token(response)
in_tokens += in_tok
out_tokens += out_tok
# Step 2
table_html_map = {}
matches = list(set(matches))
for match in matches:
img_path =os.path.join(self.args.output_dir,self.args.paper_name, html_dir,match)
print(f"Processing table image: {img_path}")
img = Image.open(img_path)
msg = BaseMessage.make_user_message(
role_name="User",
content=f'''Here is table image: {match}
Please output its HTML table (<table>...</table>) with an inline <style>...</style> block.
Only return pure HTML , nothing else.
''',
image_list=[img]
)
response = self.table_agent.step(msg)
in_tok , out_tok = account_token(response)
in_tokens += in_tok
out_tokens += out_tok
print(f'in:{in_tok},out:{out_tok}')
_output_html = response.msgs[0].content.strip()
table_html_map[match] = _output_html
tabel_dir = os.path.join(self.args.output_dir,self.args.paper_name, html_dir)
os.makedirs(f'{tabel_dir}/table_html', exist_ok=True)
with open(f'{tabel_dir}/table_html/{match.replace("/", "_")}.html', 'w', encoding='utf-8') as f:
f.write(table_html_map[match])
# ============ 阶段 2:HTML Merge ============
self.table_agent.reset()
img_path =os.path.join(self.args.output_dir,self.args.paper_name, html_dir,'page_final_no_modify_table.png')
img = Image.open(img_path)
with open('utils/prompt_templates/page_templates/color_suggestion.yaml','r') as f:
prompt_config = yaml.safe_load(f)
jinja_env = Environment(undefined=StrictUndefined)
init_prompt_template = jinja_env.from_string(prompt_config["system_prompt"])
init_prompt = init_prompt_template.render()
msg = BaseMessage.make_user_message(
role_name="User",
content=init_prompt,
image_list=[img]
)
color_response = self.table_agent.step(msg)
color_suggestion = color_response.msgs[0].content.strip()
in_tok , out_tok = account_token(color_response)
in_tokens += in_tok
out_tokens += out_tok
print(f"Starting table modification phase 2: HTML Merging with {model_type}...")
tables_str = "\n\n".join(
[f"Table extracted for {fname}:\n{html}" for fname, html in table_html_map.items()]
)
with open("utils/prompt_templates/page_templates/merge_html_table.yaml",'r') as f:
prompt_config = yaml.safe_load(f)
jinja_env = Environment(undefined=StrictUndefined)
template = jinja_env.from_string(prompt_config["template"])
jinja_args = {
'html_content': html_content,
'color_suggestion': color_suggestion,
'tables_str': tables_str
}
prompt = template.render(**jinja_args)
final_message = BaseMessage.make_user_message(
role_name = "User",
content = prompt
)
for i in range(3):
self.long_agent.reset()
response = self.long_agent.step(final_message)
in_tok, out_tok = account_token(response)
in_tokens += in_tok
out_tokens += out_tok
output_html = response.msgs[0].content.strip()
print(f'in:{in_tok},out:{out_tok}')
exteact_html_code = extract_html_code_block(output_html)
if exteact_html_code is not None:
break
print(f"html format is not correct, regenerate {i} turn")
return exteact_html_code, in_tokens, out_tokens
def modify_html_from_human_feedback(self, html_content: str, user_feedback: str):
"""
Modify HTML based on human feedback using the HTML agent.
Args:
html_content: Original HTML content
user_feedback: Feedback from human reviewers
Returns:
str: Modified HTML content
"""
in_tokens, out_tokens = 0, 0
print("Starting HTML modification based on human feedback...")
with open('utils/prompt_templates/page_templates/modify_html_from_human_feedback.yaml', 'r') as f:
modifier_config = yaml.safe_load(f)
jinja_env = Environment(undefined=StrictUndefined)
template = jinja_env.from_string(modifier_config["template"])
jinja_args = {
'generated_html': html_content,
'user_feedback': user_feedback
}
prompt = template.render(**jinja_args)
for i in range(3):
self.html_agent.reset()
response = self.html_agent.step(prompt)
in_tok, out_tok = account_token(response)
in_tokens += in_tok
out_tokens += out_tok
print(f'input_token: {in_tok}, output_token: {out_tok}')
modified_html = extract_html_code_block(response.msgs[0].content)
if modified_html is not None:
break
print(f"html format is not correct, regenerate {i} turn")
return modified_html, in_tokens, out_tokens
def generate_complete_html(self, args, generated_content, html_dir, html_template=None):
"""
Generate complete HTML by combining all sections, then render to PNG,
send to OpenAI API for feedback, and regenerate HTML with suggestions.
"""
# Create output directory for this specific project
project_output_dir = f"{args.output_dir}/{args.paper_name}"
html_path = os.path.join(project_output_dir, html_dir)
if args.resume != 'html_check':
with open('utils/prompt_templates/page_templates/html_generation.yaml', 'r') as f:
generator_config = yaml.safe_load(f)
jinja_env = Environment(undefined=StrictUndefined)
template = jinja_env.from_string(generator_config["template"])
jinja_args = {
'generated_content': json.dumps(generated_content, indent=2),
'html_template': html_template,
}
prompt = template.render(**jinja_args)
for i in range(3):
self.html_agent.reset()
# print(self.html_agent)
response = self.html_agent.step(prompt)
# print(response.msgs[0].content)
input_token, output_token = account_token(response)
print(f'input_token: {input_token}, output_token: {output_token}')
#print(input_token, output_token)
html_content = extract_html_code_block(response.msgs[0].content)
if html_content is not None:
break
print(f"html format is not correct, regenerate {i} turn")
# check css paths
html_content = check_css(html_content, html_template)
with open(os.path.join(html_path, 'index_init.html'),'w') as f:
f.write(html_content)
print(f"Initial HTML generation completed. Tokens: {input_token} -> {output_token}")
else:
with open(os.path.join(html_path, 'index_init.html'), 'r', encoding='utf-8') as f:
html_content = f.read()
revised_html = html_content
for i in range(self.args.html_check_times):
if i==0:
print("starting html check and revision...")
image_path = self.render_html_to_png(i, revised_html, html_path)
suggestions = self.get_revision_suggestions(image_path,os.path.join(html_path,f'index_iter{i}.html'))
# print(f"Revision suggestions from {self.args.model_name_v}:\n", suggestions)
review_path = f'project_contents/{args.paper_name}_html_review_iter{i}.json'
with open(review_path, 'w') as f:
json.dump(suggestions, f, indent=4)
self.html_agent.reset()
with open('utils/prompt_templates/page_templates/html_modify_from_suggestion.yaml', 'r') as f:
regenerator_config = yaml.safe_load(f)
jinja_env = Environment(undefined=StrictUndefined)
_template = jinja_env.from_string(regenerator_config["template"])
_jinja_args = {
'existing_html': revised_html,
'suggestions': suggestions
}
revision_prompt = _template.render(**_jinja_args)
# print(revision_prompt)
revised_response = self.html_agent.step(revision_prompt)
# print(revised_response.msgs[0].content)
revised_html = extract_html_code_block(revised_response.msgs[0].content)
print("Revised HTML generation completed.")
input_token, output_token = account_token(revised_response)
print(f'in:{input_token}, out:{output_token}')
return revised_html, input_token, output_token
def save_html_file(self, html_content, args, html_dir, output_dir="generated_project_pages"):
"""
Save the generated HTML to a file.
Args:
html_content: Generated HTML content
args: Command line arguments
output_dir: Output directory for the HTML file
Returns:html_check
str: Path to the saved HTML file
"""
os.makedirs(output_dir, exist_ok=True)
# Create output directory for this specific project
project_output_dir = f"{output_dir}/{args.paper_name}"
os.makedirs(project_output_dir, exist_ok=True)
# Save HTML file
html_file_path = f"{project_output_dir}/{html_dir}/index.html"
with open(html_file_path, 'w', encoding='utf-8') as f:
f.write(html_content)
print(f"HTML project page saved to: {html_file_path}")
return html_file_path
def create_assets_directory(self, args, html_dir, output_dir="generated_project_pages"):
"""
Create assets directory and copy images/tables.
Args:
args: Command line arguments
output_dir: Output directory
Returns:
str: Path to the assets directory
"""
project_output_dir = f"{output_dir}/{args.paper_name}"
assets_dir = os.path.join(project_output_dir, html_dir, "assets")
os.makedirs(assets_dir, exist_ok=True)
# Copy images and tables from the extracted assets
source_assets_dir = f"generated_project_pages/images_and_tables/{args.paper_name}"
if os.path.exists(source_assets_dir):
import shutil
for file in os.listdir(source_assets_dir):
if file.endswith(('.png', '.jpg', '.jpeg', '.gif')):
src_path = os.path.join(source_assets_dir, file)
dst_path = os.path.join(assets_dir, file)
shutil.copy2(src_path, dst_path)
print(f"Assets directory created at: {assets_dir}")
return assets_dir
def generate_metadata(self, generated_content, args):
"""
Generate metadata for the project page.
Args:
generated_content: Generated content
args: Command line arguments
Returns:
dict: Metadata for the project page
"""
metadata = {
'title': generated_content.get('meta', {}).get('poster_title', 'Research Project'),
'description': generated_content.get('meta', {}).get('abstract', '')[:160],
'authors': generated_content.get('meta', {}).get('authors', ''),
'affiliations': generated_content.get('meta', {}).get('affiliations', ''),
'keywords': [],
'generated_by': f"Paper2ProjectPage ({args.model_name_t}_{args.model_name_v})",
'generation_date': str(datetime.now())
}
# Extract keywords from content
content_text = json.dumps(generated_content, ensure_ascii=False)
# Simple keyword extraction (can be improved)
words = content_text.lower().split()
word_freq = {}
for word in words:
if len(word) > 4 and word.isalpha():
word_freq[word] = word_freq.get(word, 0) + 1
# Get top 10 most frequent words as keywords
sorted_words = sorted(word_freq.items(), key=lambda x: x[1], reverse=True)
metadata['keywords'] = [word for word, freq in sorted_words[:10]]
return metadata
def save_metadata(self, metadata, args, output_dir="generated_project_pages"):
"""
Save metadata to a JSON file.
Args:
metadata: Generated metadata
args: Command line arguments
output_dir: Output directory
Returns:
str: Path to the saved metadata file
"""
project_output_dir = f"{output_dir}/{args.paper_name}"
metadata_file_path = f"{project_output_dir}/metadata.json"
with open(metadata_file_path, 'w', encoding='utf-8') as f:
json.dump(metadata, f, indent=4, ensure_ascii=False)
print(f"Metadata saved to: {metadata_file_path}")
return metadata_file_path