Decomate / app.py
Milhaud's picture
refactor: remove commented code and improve function documentation
cb78f99
raw
history blame
27.3 kB
import gradio as gr
import xml.etree.ElementTree as ET
import re
import anthropic
import os
from dotenv import load_dotenv
from utils import svg_to_png_base64, fix_html_styles_for_preview
import pathlib
from logger import setup_logger
import html
import uuid
# Load variables from a local .env file before anything reads os.environ;
# override=True lets values from .env replace variables that are already set.
load_dotenv(override=True)
# Module-wide logger built by the project's logger helper.
logger = setup_logger()
class SVGAnimationGenerator:
def __init__(self):
self.client = None
self.predict_decompose_group_prompt = self._get_prompt(
"prompts/predict_decompose_group.txt"
)
self.feedback_decompose_group_prompt = self._get_prompt(
"prompts/feedback_decompose_group.txt"
)
self.generate_animation_prompt = self._get_prompt(
"prompts/generate_animation.txt"
)
if "ANTHROPIC_API_KEY" in os.environ:
self.client = anthropic.Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"])
def _get_prompt(self, prompt_file_path: str) -> str:
try:
with open(prompt_file_path, "r", encoding="utf-8") as f:
return f.read()
except FileNotFoundError:
return "Prompt file not found. Please check the path."
def parse_svg(self, svg_content: str) -> dict:
try:
svg_content = re.sub(r'xmlns[^=]*="[^"]*"', "", svg_content)
svg_content = re.sub(r"<svg[^>]*>", "<svg>", svg_content)
return {"svg_content": svg_content}
except Exception as e:
return {"error": f"SVG parsing error: {e}"}
def predict_decompose_group(self, parsed_svg: dict, object_name: str) -> dict:
try:
svg_content = parsed_svg["svg_content"]
image_media_type, image_data = svg_to_png_base64(svg_content)
if not image_data:
return {"error": "Failed to convert SVG to PNG"}
prompt = self.predict_decompose_group_prompt.format(
object_name=object_name, svg_content=svg_content
)
logger.info(f"Decomposition Prompt for {object_name}:\n{prompt}")
response = self.client.messages.create(
model="claude-sonnet-4-20250514",
max_tokens=10000,
messages=[
{
"role": "user",
"content": [
{
"type": "image",
"source": {
"type": "base64",
"media_type": image_media_type,
"data": image_data,
},
},
{"type": "text", "text": prompt},
],
},
{
"role": "assistant",
"content": [{"type": "text", "text": "<animation_plan>"}],
},
],
)
response_text = response.content[0].text
logger.info(f"Model Response:\n{response_text}")
decomposed_svg_match = re.search(
r"<decomposed_svg>(.*?)</decomposed_svg>", response_text, re.DOTALL
)
animation_suggestions_match = re.search(
r"<animation_suggestions>(.*?)</animation_suggestions>",
response_text,
re.DOTALL,
)
print(
"[SVG Decompose] Decomposed SVG found", decomposed_svg_match is not None
)
if decomposed_svg_match and animation_suggestions_match:
decomposed_svg_text = decomposed_svg_match.group(1).strip()
animation_suggestions = animation_suggestions_match.group(1).strip()
print("[SVG Decompose] Animation suggestions found")
return {
"svg_content": decomposed_svg_text,
"animation_suggestions": animation_suggestions,
}
else:
return {
"error": "Decomposed SVG and Animation Suggestion not found in response."
}
except Exception as e:
return {"error": f"Error during MLLM prediction: {e}"}
def feedback_decompose_group(self, svg_content: str, feedback: str) -> tuple:
try:
# Parse the SVG content first
parsed_svg = self.parse_svg(svg_content)
if "error" in parsed_svg:
error_message = parsed_svg["error"]
error_html = create_error_html(error_message)
return "", "", error_html
prompt = self.feedback_decompose_group_prompt.format(
parsed_svg=parsed_svg, feedback=feedback
)
logger.info(f"Feedback Prompt:\n{prompt}")
response = self.client.messages.create(
model="claude-sonnet-4-20250514",
max_tokens=10000,
messages=[
{
"role": "user",
"content": [
{"type": "text", "text": prompt},
],
}
],
)
response_text = response.content[0].text
logger.info(f"Model Response:\n{response_text}")
decomposed_svg_match = re.search(
r"<decomposed_svg>(.*?)</decomposed_svg>", response_text, re.DOTALL
)
animation_suggestions_match = re.search(
r"<animation_suggestions>(.*?)</animation_suggestions>",
response_text,
re.DOTALL,
)
if decomposed_svg_match and animation_suggestions_match:
decomposed_svg_text = decomposed_svg_match.group(1).strip()
animation_suggestions = animation_suggestions_match.group(1).strip()
viewer_html = create_svg_viewer_html(decomposed_svg_text)
return decomposed_svg_text, animation_suggestions, viewer_html
else:
error_message = (
"Decomposed SVG and Animation Suggestion not found in response."
)
error_html = create_error_html(error_message)
return "", "", error_html
except Exception as e:
error_message = f"Error during MLLM feedback prediction: {e}"
error_html = create_error_html(error_message)
return "", "", error_html
def generate_animation(self, proposed_animation: str, svg_content: str) -> tuple:
try:
prompt = self.generate_animation_prompt.format(
svg_content=svg_content, proposed_animation=proposed_animation
)
logger.info(f"Animation Generation Prompt:\n{prompt}")
if self.client:
response = self.client.messages.create(
model="claude-sonnet-4-20250514",
max_tokens=20000,
messages=[{"role": "user", "content": prompt}],
)
response_text = response.content[0].text
logger.info(f"Model Response:\n{response_text}")
# Extract HTML content from Claude's response
html_match = re.search(
r"<html_output>(.*?)</html_output>", response_text, re.DOTALL
)
if html_match:
html_content = html_match.group(1).strip()
return response_text, html_content
else:
return response_text, ""
else:
error_msg = "Anthropic API client not initialized"
return f"<html><body><h3>Error: {error_msg}</h3></body></html>", ""
except Exception as e:
error_msg = f"Error generating animation: {e}"
return f"<html><body><h3>{error_msg}</h3></body></html>", ""
def _sanitize_svg(svg: str) -> str:
    """Strip script-capable constructs from SVG markup before previewing it.

    Removes, case-insensitively:
      * ``<script>...</script>`` elements, plus any leftover self-closing or
        unterminated ``<script ...>`` / ``</script>`` tags;
      * inline ``on*`` event-handler attributes, quoted or unquoted;
      * ``href`` / ``xlink:href`` attributes whose value starts with
        ``javascript:``.

    This is a best-effort regex filter, not a full HTML/SVG sanitizer.
    """
    flags = re.I | re.S
    # Paired script elements first, so their bodies are removed as well.
    svg = re.sub(r"<\s*script\b[^>]*>.*?<\s*/\s*script\s*>", "", svg, flags=flags)
    # Then whatever script tags remain (e.g. <script xlink:href="x.js"/>).
    svg = re.sub(r"<\s*/?\s*script\b[^>]*>", "", svg, flags=flags)
    # Event handlers with quoted values, then with unquoted values.
    svg = re.sub(r"\son[a-zA-Z]+\s*=\s*(['\"]).*?\1", "", svg, flags=flags)
    svg = re.sub(r"\son[a-zA-Z]+\s*=\s*[^\s'\">]+", "", svg, flags=flags)
    # Links that would execute script when activated.
    svg = re.sub(
        r"\s(?:xlink:)?href\s*=\s*(['\"])\s*javascript:.*?\1", "", svg, flags=flags
    )
    return svg
def _fix_svg_markup(svg: str) -> str:
    """Adjust SVG attributes so the graphic fits the preview container.

    * ``preserveAspectRatio="none"`` becomes ``"xMidYMid meet"``.
    * One specific absolute-positioning ``style`` value is swapped for a
      block style that fills the container.
    * ``width`` and ``height`` on the first opening ``<svg ...>`` tag are set
      to ``100%``.

    The size rewrite is confined to the root tag and skips hyphenated
    attributes: a root without explicit size must not cause the first
    ``stroke-width="..."`` or a child's ``width``/``height`` to be overwritten.
    """
    s = svg.replace('preserveAspectRatio="none"', 'preserveAspectRatio="xMidYMid meet"')
    s = s.replace(
        'style="display: block; overflow: hidden; position: absolute; left: 0px; top: 0px;"',
        'style="display:block; width:100%; height:100%;"',
    )

    def _resize_root(match):
        tag = match.group(0)
        # The lookbehind rejects names such as stroke-width or data-height.
        tag = re.sub(r'(?<![\w:-])width="[^"]+"', 'width="100%"', tag, count=1)
        tag = re.sub(r'(?<![\w:-])height="[^"]+"', 'height="100%"', tag, count=1)
        return tag

    return re.sub(r"<svg\b[^>]*>", _resize_root, s, count=1)
# Shared generator instance used by the module-level handlers below. It is
# created at import time, so the prompt files are read when the module loads.
generator = SVGAnimationGenerator()
def create_error_html(message: str) -> str:
    """Return a dashed-border HTML panel showing ``message``.

    The message is converted to ``str`` and HTML-escaped first, so markup in
    it is displayed as text rather than interpreted.
    """
    escaped_text = html.escape(str(message))
    panel_style = (
        "padding: 40px; text-align: center; color: #666; "
        "border: 2px dashed #ddd; border-radius: 10px;"
    )
    return f"""
<div style='{panel_style}'>
<h3>{escaped_text}</h3>
</div>
"""
def create_svg_viewer_html(svg_content: str) -> str:
    """Render a decomposed SVG as an interactive preview inside an iframe.

    The SVG is sanitized, resized for the viewer and embedded in a standalone
    HTML document. Hovering an element highlights the closest ancestor that
    has an ``id`` and shows that id in a tooltip. The document is delivered
    through the iframe's ``srcdoc`` attribute.

    Args:
        svg_content: SVG markup. After stripping whitespace it must start
            with ``<svg`` and end with ``</svg>``.

    Returns:
        HTML for the preview, or an error panel when the markup does not look
        like a complete SVG element.
    """
    svg = (svg_content or "").strip()
    if not (svg.startswith("<svg") and svg.endswith("</svg>")):
        return create_error_html("Invalid SVG format")
    svg = _sanitize_svg(svg)
    svg = _fix_svg_markup(svg)
    # A unique id per render keeps the CSS and JS selectors scoped.
    uid = f"svg-preview-{uuid.uuid4().hex}"
    doc = f"""<!doctype html>
<html>
<head>
<meta charset="utf-8">
<style>
:root {{
--hl-color: #ff3b30; /* 🔧 원하는 색으로 바꿔도 됨 */
--hl-width: 3.5px;
}}
#{uid} svg {{
max-width: 100%;
max-height: 100%;
display: block;
}}
/* 포인터 이벤트는 도형 요소만 */
#{uid} svg g,
#{uid} svg path,
#{uid} svg rect,
#{uid} svg circle,
#{uid} svg ellipse,
#{uid} svg polygon,
#{uid} svg polyline,
#{uid} svg line,
#{uid} svg text {{
cursor: pointer;
pointer-events: visiblePainted;
transition: filter .12s ease, stroke-width .12s ease;
}}
/* 보조 요소는 포인터/효과 제외 */
#{uid} svg defs,
#{uid} svg clipPath,
#{uid} svg mask,
#{uid} svg title,
#{uid} svg desc {{
pointer-events: none !important;
}}
/* ✅ 외곽선 하이라이트: 기존 stroke가 있어도 강제로 덮어쓰기 */
#{uid} .hl {{
stroke: var(--hl-color) !important;
stroke-width: var(--hl-width) !important;
paint-order: stroke fill; /* stroke를 위로 */
vector-effect: non-scaling-stroke;
filter: drop-shadow(0 0 6px rgba(0,0,0,.35));
}}
/* 필요하다면 다른 요소 미세 디밍 (기본 OFF) */
/* #{uid} .dim {{ opacity: .65; }} */
#{uid}-wrap {{
position:relative; padding:20px; background:#fff;
border:1px solid #eee; border-radius:8px; height:100%; box-sizing:border-box;
}}
#{uid} {{
border:1px solid #ddd; border-radius:8px; background:#fafafa;
height:100%; min-height:360px; display:flex; align-items:center;
justify-content:center; padding:20px; position:relative; box-sizing:border-box;
}}
#{uid}-tooltip {{
position: absolute; display: none; pointer-events: none;
background: rgba(0,0,0,0.9); color: #fff; border: 2px solid #fff;
border-radius: 6px; padding: 6px 10px; font-size: 12px; font-weight: 600;
white-space: nowrap; z-index: 10;
}}
</style>
</head>
<body style="margin:0;height:100vh;">
<div id="{uid}-wrap">
<div id="{uid}" class="svg-container">
{svg}
<div class="tooltip" id="{uid}-tooltip"></div>
</div>
</div>
<script>
(function() {{
const root = document.getElementById("{uid}");
if (!root) return;
const svg = root.querySelector("svg");
if (!svg) return;
const tooltip = document.getElementById("{uid}-tooltip");
const GEOM_SEL = "g,path,rect,circle,ellipse,polygon,polyline,line,text";
const DIM_OTHERS = false; // 🔧 true로 바꾸면 나머지 살짝 디밍
function allGeom() {{
return svg.querySelectorAll(GEOM_SEL);
}}
function closestWithId(node) {{
let cur = node;
while (cur && cur !== svg) {{
if (cur.id) return cur;
cur = cur.parentNode;
}}
return null;
}}
function clearMarks() {{
svg.querySelectorAll(".hl").forEach(el => el.classList.remove("hl"));
if (DIM_OTHERS) svg.querySelectorAll(".dim").forEach(el => el.classList.remove("dim"));
}}
function highlightOwner(owner) {{
// 그룹이면 하위 도형 전체에 hl, 도형이면 자기 자신에만
const targets = owner.matches(GEOM_SEL)
? [owner, ...owner.querySelectorAll(GEOM_SEL)]
: [...owner.querySelectorAll(GEOM_SEL)];
targets.forEach(el => el.classList.add("hl"));
}}
function dimExcept(owner) {{
if (!DIM_OTHERS) return;
const keep = new Set([owner, ...owner.querySelectorAll(GEOM_SEL)]);
allGeom().forEach(el => {{
if (!keep.has(el)) el.classList.add("dim");
}});
}}
function moveTooltip(e) {{
const rect = root.getBoundingClientRect();
tooltip.style.left = (e.clientX - rect.left + 10) + "px";
tooltip.style.top = (e.clientY - rect.top + 10) + "px";
}}
svg.addEventListener("pointerover", (e) => {{
const owner = closestWithId(e.target);
clearMarks();
if (owner && owner.id) {{
highlightOwner(owner);
dimExcept(owner);
tooltip.textContent = owner.id;
tooltip.style.display = "block";
moveTooltip(e);
}} else {{
tooltip.style.display = "none";
}}
}});
svg.addEventListener("pointermove", (e) => {{
if (tooltip.style.display === "block") moveTooltip(e);
}});
svg.addEventListener("pointerleave", () => {{
clearMarks();
tooltip.style.display = "none";
}});
}})();
</script>
</body>
</html>
"""
    # The attribute value is entity-decoded once by the HTML parser, so the
    # whole document must be escaped, "&" included. Escaping only the double
    # quotes would let entities already present in the SVG (&amp;, &lt;,
    # &quot;) be decoded one level too many inside the iframe.
    safe_doc = html.escape(doc, quote=True)
    # NOTE(review): the sandbox combines allow-scripts with allow-same-origin;
    # confirm allow-same-origin is actually required for this preview.
    return f"""
<div style='padding: 20px; width:100%; height:520px; background: #fff; border: 1px solid #eee; border-radius: 8px; display: block;'>
<iframe
srcdoc="{safe_doc}"
width="100%"
height="100%"
style="border:none; border-radius:8px; overflow:hidden;"
sandbox="allow-scripts allow-same-origin">
</iframe>
</div>
"""
def _extract_path_from_gradio_file(svg_file) -> str | None:
if isinstance(svg_file, (str, pathlib.Path)):
return str(svg_file)
if isinstance(svg_file, dict) and "name" in svg_file:
return svg_file["name"]
if hasattr(svg_file, "name"):
return svg_file.name
return None
def process_svg(svg_file):
    """Read an uploaded SVG file and return its normalised markup.

    Failures are reported in-band: the return value is a plain-text message
    instead of SVG markup when no file is given, the path cannot be
    determined, the file is missing, or reading it fails. An error reported
    by ``generator.parse_svg`` results in ``""``.
    """
    if svg_file is None:
        return "Please upload an SVG file"
    try:
        source_path = _extract_path_from_gradio_file(svg_file)
        if not source_path:
            return "Invalid file input. Please upload a valid SVG file."
        raw_markup = pathlib.Path(source_path).read_text(encoding="utf-8")
        return generator.parse_svg(raw_markup).get("svg_content", "")
    except FileNotFoundError:
        return "File not found. Please upload a valid SVG file."
    except ET.ParseError:
        # Nothing in the try body parses XML, so this handler is not expected
        # to trigger.
        return "Invalid SVG file format. Please upload a valid SVG file."
    except Exception as exc:
        return f"Error processing file: {exc}"
def predict_decompose_group(svg_file, svg_text, object_name):
    """UI handler: decompose the supplied SVG and build its preview.

    An uploaded file takes precedence over pasted markup.

    Args:
        svg_file: Uploaded file value, or ``None``.
        svg_text: Pasted SVG markup, used when no file is given.
        object_name: Name of the depicted object; must not be blank.

    Returns:
        ``(svg_state, summary_text, animation_suggestions, viewer_html)``.
        On success the first two elements are the decomposed SVG. On failure
        they are ``""`` and the error message, the suggestions are ``""`` and
        the viewer is an error panel.
    """

    def _failure(message: str) -> tuple:
        return "", message, "", create_error_html(message)

    if not (object_name or "").strip():
        return _failure("Please enter a valid object name for the SVG")
    if svg_file is not None:
        svg_content_inner = process_svg(svg_file)
    else:
        svg_content_inner = (svg_text or "").strip()
    if not svg_content_inner:
        return _failure("Please upload an SVG file or enter SVG markup")
    if not re.search(r"<svg\b", svg_content_inner, re.IGNORECASE):
        # process_svg reports failures as a plain-text message in place of
        # markup. Surface that message rather than sending it to the model as
        # if it were an SVG. Long text is assumed to be a non-SVG file.
        if svg_file is not None and len(svg_content_inner) <= 200:
            return _failure(svg_content_inner)
        return _failure("Input does not contain an <svg> element")
    parsed_svg = generator.parse_svg(svg_content_inner)
    if "error" in parsed_svg:
        return _failure(parsed_svg["error"])
    decompose_result = generator.predict_decompose_group(parsed_svg, object_name)
    if "error" in decompose_result:
        return _failure(decompose_result["error"])
    decomposed_svg = decompose_result["svg_content"]
    animation_suggestions = decompose_result["animation_suggestions"]
    decomposed_svg_viewer = create_svg_viewer_html(decomposed_svg)
    return (
        decomposed_svg,
        decomposed_svg,
        animation_suggestions,
        decomposed_svg_viewer,
    )
def update_preview_from_html(html_content: str) -> str:
    """Rebuild the animation preview from manually edited HTML.

    Args:
        html_content: Full HTML document to show in the preview iframe.

    Returns:
        Wrapper HTML containing an iframe whose ``srcdoc`` holds the escaped
        document, or an error panel when the input is empty or the preview
        cannot be built.
    """
    if not (html_content or "").strip():
        return create_error_html("⚠️ HTML content is empty")
    try:
        # Escape the whole document for use as an attribute value. Escaping
        # only the double quotes would leave "&" alone, so entities written in
        # the HTML (&amp;, &lt;, ...) would be decoded one level too many.
        safe_html_content = html.escape(html_content, quote=True)
        preview_html = f"""
<div style='padding: 20px; width:100%; height:520px; background: #fff; border: 1px solid #eee; border-radius: 8px; display: block;'>
<div style='display: block; align-items: center; margin-bottom: 10px;'>
</div>
<div id='animation-container' style='height: 100%; display: block; justify-content: center; align-items: center; background: #fafafa; border-radius: 4px; padding: 20px; box-sizing: border-box;'>
<iframe srcdoc="{safe_html_content}"
width="100%" height="100%"
style="border:none; border-radius:8px; overflow:hidden;"
sandbox="allow-scripts allow-same-origin">
</iframe>
</div>
</div>
"""
        return preview_html
    except Exception as e:
        return create_error_html(f"❌ Error updating preview: {str(e)}")
def create_animation_preview(animation_desc: str, svg_content: str) -> tuple:
    """Generate an animation for the SVG and build its preview.

    The generated HTML is also written to ``output/animation_preview.html``.

    Args:
        animation_desc: Description of the desired animation.
        svg_content: Decomposed SVG markup to animate.

    Returns:
        ``(preview_html, html_text)``. On success ``html_text`` is the
        generated document after ``fix_html_styles_for_preview``. When the
        model returns no HTML it is the raw model response. For invalid input
        or unexpected errors it is ``""`` and ``preview_html`` is an error
        panel.
    """
    if not (svg_content or "").strip():
        error_html = create_error_html("⚠️ Please process SVG first")
        return error_html, ""
    if not (animation_desc or "").strip():
        error_html = create_error_html("⚠️ Please describe the animation you want")
        return error_html, ""
    try:
        animation_response, html_content = generator.generate_animation(
            animation_desc, svg_content
        )
        if not html_content:
            error_html = create_error_html("❌ Failed to generate animation HTML")
            return error_html, animation_response
        output_dir = "output"
        os.makedirs(output_dir, exist_ok=True)
        html_path = os.path.join(output_dir, "animation_preview.html")
        with open(html_path, "w", encoding="utf-8") as f:
            f.write(html_content)
        print(f"Animation preview saved to: {html_path}")
        fixed_html_content = fix_html_styles_for_preview(html_content)
        # Escape the whole document for the srcdoc attribute. Replacing only
        # the double quotes would let "&" entities in the generated HTML be
        # decoded one level too many inside the iframe.
        safe_html_content = html.escape(fixed_html_content, quote=True)
        preview_html = f"""
<div style='padding: 20px; width:100%; background: #fff; border: 1px solid #eee; border-radius: 8px; position: relative;'>
<div style='border:1px solid #ddd; border-radius:8px; background:#fafafa; display:flex; align-items:center; justify-content:center; padding:20px; position:relative; box-sizing:border-box;'>
<iframe
srcdoc="{safe_html_content}"
style="width: 360px; height: 360px; border:none; border-radius:8px; overflow:hidden;"
sandbox="allow-scripts allow-same-origin">
</iframe>
</div>
</div>
"""
        return preview_html, fixed_html_content
    except Exception as e:
        error_html = create_error_html(f"❌ Error creating animation: {str(e)}")
        return error_html, ""
# Bundled example SVGs grouped by category. Each entry is [path, object_name];
# paths are resolved relative to this file's directory.
example_list = {
    category: [
        [os.path.join(os.path.dirname(__file__), f"examples/{name}.svg"), name]
        for name in names
    ]
    for category, names in (
        ("Animals", ("corgi", "duck", "bunny")),
        ("Objects", ("rocket",)),
    )
}
def load_example(example_choice):
    """Look up a bundled example by its object name.

    Returns ``(path, name)`` for the first entry in ``example_list`` whose
    name equals ``example_choice``, or ``(None, None)`` when none matches.
    """
    hits = (
        (svg_path, name)
        for entries in example_list.values()
        for svg_path, name in entries
        if name == example_choice
    )
    return next(hits, (None, None))
# Flat list of example object names, in category order, for the dropdown.
example_choices = [
    name for entries in example_list.values() for _svg_path, name in entries
]
# Gradio UI: component layout followed by the event wiring.
# NOTE(review): the Row/Column nesting below was reconstructed from a listing
# whose indentation had been stripped - confirm it against the intended layout.
demo = gr.Blocks(title="SVG Animation Generator", theme=gr.themes.Soft())
with demo:
    gr.Markdown("# 🎨 SVG Decomposition & Animation Generator")
    gr.Markdown(
        "Intelligent SVG decomposition and animation generation powered by MLLM. This tool decomposes SVG structures and generates animations based on your descriptions."
    )
    with gr.Column():
        with gr.Row(scale=2):
            # Input: file upload, pasted markup, or a bundled example.
            with gr.Column(scale=1):
                gr.Markdown("## 📤 Input SVG")
                with gr.Row(scale=2):
                    svg_file = gr.File(label="Upload SVG File", file_types=[".svg"])
                    svg_text = gr.Textbox(
                        label="Or Paste SVG Code",
                        lines=8.4,
                    )
                example_dropdown = gr.Dropdown(
                    choices=example_choices, label="Try an Example", value=None
                )
            with gr.Column(scale=1):
                with gr.Row(scale=1):
                    # Analysis: object name, decompose button, result text.
                    with gr.Column(scale=1):
                        gr.Markdown("## 🔍 SVG Analysis")
                        object_name = gr.Textbox(
                            label="Name Your Object",
                            placeholder="Give a name to your SVG (e.g., 'dove', 'robot')",
                            value="corgi",
                        )
                        # Choosing an example fills both the file input and
                        # the object name.
                        example_dropdown.change(
                            fn=load_example,
                            inputs=[example_dropdown],
                            outputs=[svg_file, object_name],
                        )
                        process_btn = gr.Button("🔄 Decompose Structure", variant="primary")
                        groups_summary = gr.Textbox(
                            label="Decomposition Results",
                            placeholder="MLLM will analyze and decompose the SVG structure...",
                            lines=6,
                            interactive=False,
                        )
                    # Animation suggestions returned by the model.
                    with gr.Column(scale=1):
                        gr.Markdown("## 🎯 Recommeded Animation")
                        animation_suggestion = gr.Textbox(
                            label="AI Suggestions",
                            placeholder="MLLM will suggest animations based on the decomposed structure...",
                            lines=14.5,
                        )
                with gr.Row(scale=1):
                    # Free-text feedback used to refine the decomposition.
                    with gr.Column(scale=1):
                        gr.Markdown("## 💡 Refine Decomposition")
                        groups_feedback = gr.Textbox(
                            label="Element Structure",
                            placeholder="If you have specific decomposition in mind, describe it here...",
                            lines=2,
                        )
                        groups_feedback_btn = gr.Button(
                            "💭 Apply Decomposition Feedback", variant="primary"
                        )
                    # Animation request.
                    with gr.Column(scale=1):
                        gr.Markdown("## ✨ Create Animation")
                        describe_animation = gr.Textbox(
                            label="Animation Description",
                            placeholder="Describe your desired animation (e.g., 'gentle floating motion')",
                            lines=2,
                        )
                        animate_btn = gr.Button("🎬 Generate Animation", variant="primary")
        with gr.Row(scale=3):
            with gr.Column(scale=1):
                # Hidden textbox holding the current decomposed SVG; it is
                # read and written by the handlers wired below.
                svg_content_hidden = gr.Textbox(visible=False)
                gr.Markdown("## 🖼️ Decomposed Structure")
                decomposed_svg_viewer = gr.HTML(
                    label="Decomposed SVG",
                    value="""
<div style='padding: 40px; text-align: center; color: #666; border: 2px dashed #ddd; border-radius: 10px;'>
<div id='decomposed-svg-container' style='min-height: 150px; display: flex; justify-content: center; align-items: center; border-radius: 4px; padding: 10px;'>
<div style='color: #999; text-align: center;'>Decomposed SVG structure will appear here</div>
</div>
</div>
""",
                )
            with gr.Column(scale=1):
                gr.Markdown("## 🎭 Animation Preview")
                animation_preview = gr.HTML(
                    label="Live Preview",
                    value="""
<div style='padding: 40px; text-align: center; color: #666; border: 2px dashed #ddd; border-radius: 10px;'>
<div id='animation-container' style='min-height: 150px; display: flex; justify-content: center; align-items: center; border-radius: 4px; padding: 10px;'>
<div style='color: #999; text-align: center;'>Animation preview will appear here</div>
</div>
</div>
""",
                )
    with gr.Column():
        with gr.Row(scale=1):
            with gr.Column(scale=1):
                gr.Markdown("## 📂 HTML Output")
                output_html = gr.Textbox(
                    label="Output HTML",
                    lines=10,
                    placeholder="Generated HTML will appear here. You can edit this HTML and see live preview.",
                    interactive=True,
                )
    # Decompose: fills the hidden SVG state, the summary, the suggestions and
    # the viewer.
    process_btn.click(
        fn=predict_decompose_group,
        inputs=[svg_file, svg_text, object_name],
        outputs=[
            svg_content_hidden,
            groups_summary,
            animation_suggestion,
            decomposed_svg_viewer,
        ],
    )
    # Feedback: re-decomposes the SVG held in the hidden state.
    groups_feedback_btn.click(
        fn=generator.feedback_decompose_group,
        inputs=[
            svg_content_hidden,
            groups_feedback,
        ],
        outputs=[
            svg_content_hidden,
            animation_suggestion,
            decomposed_svg_viewer,
        ],
    )
    # Animate: generates HTML from the description and the hidden SVG state.
    animate_btn.click(
        fn=create_animation_preview,
        inputs=[
            describe_animation,
            svg_content_hidden,
        ],
        outputs=[
            animation_preview,
            output_html,
        ],
    )
    # Any change to the HTML textbox re-renders the preview from its content.
    output_html.change(
        fn=update_preview_from_html,
        inputs=[output_html],
        outputs=[animation_preview],
    )
# Launch the app only when this file is executed as a script.
if __name__ == "__main__":
    # share=True asks Gradio to also create a public share link.
    demo.launch(share=True)