# Decomate / app.py
# Milhaud's picture
# refactor: update docstring for create_svg_viewer_html function to improve clarity
# df5263a
import gradio as gr
import xml.etree.ElementTree as ET
import re
import anthropic
import os
from dotenv import load_dotenv
from utils import svg_to_png_base64, fix_html_styles_for_preview
import pathlib
from logger import setup_logger
import html
import uuid
load_dotenv(override=True)
logger = setup_logger()
class SVGAnimationGenerator:
    """Wrap the Anthropic client and prompt templates used to decompose and animate SVGs.

    Prompt templates are read from the ``prompts/`` directory when the object
    is constructed. ``self.client`` stays ``None`` when ``ANTHROPIC_API_KEY``
    is not present in the environment; every request method reports that
    state as an error instead of attempting an API call.
    """

    # Model identifier shared by all three request methods.
    MODEL = "claude-sonnet-4-20250514"
    # Message reported whenever a request is attempted without a client.
    CLIENT_MISSING_MESSAGE = "Anthropic API client not initialized"

    def __init__(self):
        self.client = None
        self.predict_decompose_group_prompt = self._get_prompt(
            "prompts/predict_decompose_group.txt"
        )
        self.feedback_decompose_group_prompt = self._get_prompt(
            "prompts/feedback_decompose_group.txt"
        )
        self.generate_animation_prompt = self._get_prompt(
            "prompts/generate_animation.txt"
        )
        if "ANTHROPIC_API_KEY" in os.environ:
            self.client = anthropic.Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"])

    def _get_prompt(self, prompt_file_path: str) -> str:
        """Return the text of a prompt template file.

        A missing file does not raise: a fixed placeholder sentence is
        returned instead and will be used as the prompt template.
        """
        try:
            with open(prompt_file_path, "r", encoding="utf-8") as f:
                return f.read()
        except FileNotFoundError:
            return "Prompt file not found. Please check the path."

    @staticmethod
    def _extract_tag(response_text: str, tag: str):
        """Return the stripped text between ``<tag>`` and ``</tag>``, or ``None`` if absent."""
        match = re.search(rf"<{tag}>(.*?)</{tag}>", response_text, re.DOTALL)
        return match.group(1).strip() if match else None

    def parse_svg(self, svg_content: str) -> dict:
        """Normalize SVG markup before it is sent to the model.

        Every ``xmlns...="..."`` declaration is removed and every ``<svg ...>``
        opening tag is replaced by a bare ``<svg>``, so all attributes of the
        opening tag are discarded.

        Returns:
            ``{"svg_content": <normalized markup>}`` on success, or
            ``{"error": <message>}`` if the substitution raises (for example
            when ``svg_content`` is not a string).
        """
        try:
            svg_content = re.sub(r'xmlns[^=]*="[^"]*"', "", svg_content)
            svg_content = re.sub(r"<svg[^>]*>", "<svg>", svg_content)
            return {"svg_content": svg_content}
        except Exception as e:
            return {"error": f"SVG parsing error: {e}"}

    def predict_decompose_group(self, parsed_svg: dict, object_name: str) -> dict:
        """Ask the model to decompose an SVG into groups and suggest animations.

        Args:
            parsed_svg: Mapping with an ``"svg_content"`` entry, as returned by
                :meth:`parse_svg`.
            object_name: Name of the depicted object, inserted into the prompt.

        Returns:
            ``{"svg_content": ..., "animation_suggestions": ...}`` on success,
            otherwise ``{"error": <message>}``.
        """
        # Same guard as generate_animation: without it a missing API key only
        # showed up as an opaque "'NoneType' object has no attribute" error.
        if not self.client:
            return {"error": self.CLIENT_MISSING_MESSAGE}
        try:
            svg_content = parsed_svg["svg_content"]
            image_media_type, image_data = svg_to_png_base64(svg_content)
            if not image_data:
                return {"error": "Failed to convert SVG to PNG"}

            prompt = self.predict_decompose_group_prompt.format(
                object_name=object_name, svg_content=svg_content
            )
            logger.info(f"Decomposition Prompt for {object_name}:\n{prompt}")

            response = self.client.messages.create(
                model=self.MODEL,
                max_tokens=10000,
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "image",
                                "source": {
                                    "type": "base64",
                                    "media_type": image_media_type,
                                    "data": image_data,
                                },
                            },
                            {"type": "text", "text": prompt},
                        ],
                    },
                    # Pre-filled assistant turn: the reply continues from this tag.
                    {
                        "role": "assistant",
                        "content": [{"type": "text", "text": "<animation_plan>"}],
                    },
                ],
            )
            response_text = response.content[0].text
            logger.info(f"Model Response:\n{response_text}")

            decomposed_svg_text = self._extract_tag(response_text, "decomposed_svg")
            animation_suggestions = self._extract_tag(
                response_text, "animation_suggestions"
            )
            print(
                "[SVG Decompose] Decomposed SVG found", decomposed_svg_text is not None
            )
            if decomposed_svg_text is None or animation_suggestions is None:
                return {
                    "error": "Decomposed SVG and Animation Suggestion not found in response."
                }

            print("[SVG Decompose] Animation suggestions found")
            return {
                "svg_content": decomposed_svg_text,
                "animation_suggestions": animation_suggestions,
            }
        except Exception as e:
            return {"error": f"Error during MLLM prediction: {e}"}

    def feedback_decompose_group(self, svg_content: str, feedback: str) -> tuple:
        """Re-run the decomposition taking the user's feedback into account.

        Args:
            svg_content: Current decomposed SVG markup.
            feedback: Free-text description of the decomposition the user wants.

        Returns:
            ``(decomposed_svg, animation_suggestions, viewer_html)``. On any
            failure the first two items are empty strings and the third is an
            error panel built by ``create_error_html``.
        """
        # Same guard as generate_animation, reported through the error panel.
        if not self.client:
            return "", "", create_error_html(self.CLIENT_MISSING_MESSAGE)
        try:
            # Parse the SVG content first
            parsed_svg = self.parse_svg(svg_content)
            if "error" in parsed_svg:
                return "", "", create_error_html(parsed_svg["error"])

            # NOTE(review): the whole dict (not parsed_svg["svg_content"]) is
            # passed to the template. Kept as-is because the template file is
            # not visible here and may index into it — confirm against
            # prompts/feedback_decompose_group.txt.
            prompt = self.feedback_decompose_group_prompt.format(
                parsed_svg=parsed_svg, feedback=feedback
            )
            logger.info(f"Feedback Prompt:\n{prompt}")

            response = self.client.messages.create(
                model=self.MODEL,
                max_tokens=10000,
                messages=[
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": prompt},
                        ],
                    }
                ],
            )
            response_text = response.content[0].text
            logger.info(f"Model Response:\n{response_text}")

            decomposed_svg_text = self._extract_tag(response_text, "decomposed_svg")
            animation_suggestions = self._extract_tag(
                response_text, "animation_suggestions"
            )
            if decomposed_svg_text is None or animation_suggestions is None:
                error_message = (
                    "Decomposed SVG and Animation Suggestion not found in response."
                )
                return "", "", create_error_html(error_message)

            viewer_html = create_svg_viewer_html(decomposed_svg_text)
            return decomposed_svg_text, animation_suggestions, viewer_html
        except Exception as e:
            error_message = f"Error during MLLM feedback prediction: {e}"
            return "", "", create_error_html(error_message)

    def generate_animation(self, proposed_animation: str, svg_content: str) -> tuple:
        """Ask the model for an animated HTML document for the given SVG.

        Args:
            proposed_animation: Description of the desired animation.
            svg_content: SVG markup to animate.

        Returns:
            ``(response_text, html_content)``. ``html_content`` is the text
            inside ``<html_output>`` tags, or ``""`` when the tags are missing.
            When no client is configured or an exception occurs, the first item
            is a small HTML error page and the second is ``""``.
        """
        try:
            prompt = self.generate_animation_prompt.format(
                svg_content=svg_content, proposed_animation=proposed_animation
            )
            logger.info(f"Animation Generation Prompt:\n{prompt}")

            if not self.client:
                error_msg = self.CLIENT_MISSING_MESSAGE
                return f"<html><body><h3>Error: {error_msg}</h3></body></html>", ""

            response = self.client.messages.create(
                model=self.MODEL,
                max_tokens=20000,
                messages=[{"role": "user", "content": prompt}],
            )
            response_text = response.content[0].text
            logger.info(f"Model Response:\n{response_text}")

            # Extract HTML content from Claude's response
            html_content = self._extract_tag(response_text, "html_output")
            return response_text, html_content or ""
        except Exception as e:
            error_msg = f"Error generating animation: {e}"
            return f"<html><body><h3>{error_msg}</h3></body></html>", ""
def _sanitize_svg(svg: str) -> str:
"""Script tag/inline on*Handler removal for safe preview"""
svg = re.sub(r"<\s*script\b[^>]*>.*?<\s*/\s*script\s*>", "", svg, flags=re.I | re.S)
svg = re.sub(r"\son[a-zA-Z]+\s*=\s*(['\"]).*?\1", "", svg, flags=re.I | re.S)
return svg
def _fix_svg_markup(svg: str) -> str:
"""SVG attribute to fit the viewer (size/ratio/problem style modification)"""
s = svg
s = s.replace('preserveAspectRatio="none"', 'preserveAspectRatio="xMidYMid meet"')
s = s.replace(
'style="display: block; overflow: hidden; position: absolute; left: 0px; top: 0px;"',
'style="display:block; width:100%; height:100%;"',
)
s = re.sub(r'width="[^"]+"', 'width="100%"', s, count=1)
s = re.sub(r'height="[^"]+"', 'height="100%"', s, count=1)
return s
# Module-level generator instance used by the UI callbacks defined below
# (process_svg, predict_decompose_group, create_animation_preview).
generator = SVGAnimationGenerator()
def create_error_html(message: str) -> str:
    """Return a dashed-border HTML panel showing ``message``.

    The message is converted to ``str`` and HTML-escaped before being
    embedded, so markup inside it is displayed literally.
    """
    escaped = html.escape(str(message))
    panel_style = (
        "padding: 40px; text-align: center; color: #666; "
        "border: 2px dashed #ddd; border-radius: 10px;"
    )
    return f"\n<div style='{panel_style}'>\n<h3>{escaped}</h3>\n</div>\n"
def create_svg_viewer_html(svg_content: str) -> str:
    """Render a decomposed SVG inside a sandboxed iframe with hover inspection.

    The SVG is sanitized and resized, then embedded in a standalone HTML
    document whose script highlights the nearest ancestor element carrying an
    ``id`` under the pointer and shows that id in a tooltip.

    Args:
        svg_content: SVG markup; after stripping whitespace it must start with
            ``<svg`` and end with ``</svg>``.

    Returns:
        HTML for a wrapper ``<div>`` containing the iframe, or an error panel
        from ``create_error_html`` when the markup fails the shape check.
    """
    svg = (svg_content or "").strip()
    if not (svg.startswith("<svg") and svg.endswith("</svg>")):
        return create_error_html("Invalid SVG format")
    svg = _sanitize_svg(svg)
    svg = _fix_svg_markup(svg)
    uid = f"svg-preview-{uuid.uuid4().hex}"
    doc = f"""<!doctype html>
<html>
<head>
<meta charset="utf-8">
<style>
:root {{
--hl-color: #ff3b30; /* 🔧 원하는 색으로 바꿔도 됨 */
--hl-width: 3.5px;
}}
#{uid} svg {{
max-width: 100%;
max-height: 100%;
display: block;
}}
/* 포인터 이벤트는 도형 요소만 */
#{uid} svg g,
#{uid} svg path,
#{uid} svg rect,
#{uid} svg circle,
#{uid} svg ellipse,
#{uid} svg polygon,
#{uid} svg polyline,
#{uid} svg line,
#{uid} svg text {{
cursor: pointer;
pointer-events: visiblePainted;
transition: filter .12s ease, stroke-width .12s ease;
}}
/* 보조 요소는 포인터/효과 제외 */
#{uid} svg defs,
#{uid} svg clipPath,
#{uid} svg mask,
#{uid} svg title,
#{uid} svg desc {{
pointer-events: none !important;
}}
/* ✅ 외곽선 하이라이트: 기존 stroke가 있어도 강제로 덮어쓰기 */
#{uid} .hl {{
stroke: var(--hl-color) !important;
stroke-width: var(--hl-width) !important;
paint-order: stroke fill; /* stroke를 위로 */
vector-effect: non-scaling-stroke;
filter: drop-shadow(0 0 6px rgba(0,0,0,.35));
}}
/* 필요하다면 다른 요소 미세 디밍 (기본 OFF) */
/* #{uid} .dim {{ opacity: .65; }} */
#{uid}-wrap {{
position:relative; padding:20px; background:#fff;
border:1px solid #eee; border-radius:8px; height:100%; box-sizing:border-box;
}}
#{uid} {{
border:1px solid #ddd; border-radius:8px; background:#fafafa;
height:100%; min-height:360px; display:flex; align-items:center;
justify-content:center; padding:20px; position:relative; box-sizing:border-box;
}}
#{uid}-tooltip {{
position: absolute; display: none; pointer-events: none;
background: rgba(0,0,0,0.9); color: #fff; border: 2px solid #fff;
border-radius: 6px; padding: 6px 10px; font-size: 12px; font-weight: 600;
white-space: nowrap; z-index: 10;
}}
</style>
</head>
<body style="margin:0;height:100vh;">
<div id="{uid}-wrap">
<div id="{uid}" class="svg-container">
{svg}
<div class="tooltip" id="{uid}-tooltip"></div>
</div>
</div>
<script>
(function() {{
const root = document.getElementById("{uid}");
if (!root) return;
const svg = root.querySelector("svg");
if (!svg) return;
const tooltip = document.getElementById("{uid}-tooltip");
const GEOM_SEL = "g,path,rect,circle,ellipse,polygon,polyline,line,text";
const DIM_OTHERS = false; // 🔧 true로 바꾸면 나머지 살짝 디밍
function allGeom() {{
return svg.querySelectorAll(GEOM_SEL);
}}
function closestWithId(node) {{
let cur = node;
while (cur && cur !== svg) {{
if (cur.id) return cur;
cur = cur.parentNode;
}}
return null;
}}
function clearMarks() {{
svg.querySelectorAll(".hl").forEach(el => el.classList.remove("hl"));
if (DIM_OTHERS) svg.querySelectorAll(".dim").forEach(el => el.classList.remove("dim"));
}}
function highlightOwner(owner) {{
// 그룹이면 하위 도형 전체에 hl, 도형이면 자기 자신에만
const targets = owner.matches(GEOM_SEL)
? [owner, ...owner.querySelectorAll(GEOM_SEL)]
: [...owner.querySelectorAll(GEOM_SEL)];
targets.forEach(el => el.classList.add("hl"));
}}
function dimExcept(owner) {{
if (!DIM_OTHERS) return;
const keep = new Set([owner, ...owner.querySelectorAll(GEOM_SEL)]);
allGeom().forEach(el => {{
if (!keep.has(el)) el.classList.add("dim");
}});
}}
function moveTooltip(e) {{
const rect = root.getBoundingClientRect();
tooltip.style.left = (e.clientX - rect.left + 10) + "px";
tooltip.style.top = (e.clientY - rect.top + 10) + "px";
}}
svg.addEventListener("pointerover", (e) => {{
const owner = closestWithId(e.target);
clearMarks();
if (owner && owner.id) {{
highlightOwner(owner);
dimExcept(owner);
tooltip.textContent = owner.id;
tooltip.style.display = "block";
moveTooltip(e);
}} else {{
tooltip.style.display = "none";
}}
}});
svg.addEventListener("pointermove", (e) => {{
if (tooltip.style.display === "block") moveTooltip(e);
}});
svg.addEventListener("pointerleave", () => {{
clearMarks();
tooltip.style.display = "none";
}});
}})();
</script>
</body>
</html>
"""
    # The srcdoc attribute value is entity-decoded once before being parsed as
    # a document, so "&", "<" and ">" must be escaped along with the quotes;
    # escaping only '"' turned entities in the SVG (e.g. "&lt;") into raw
    # markup characters inside the frame.
    safe_doc = html.escape(doc, quote=True)
    # allow-same-origin is intentionally absent: together with allow-scripts
    # it would let framed content reach the embedding page. The viewer script
    # above only touches its own document.
    return f"""
<div style='padding: 20px; width:100%; height:520px; background: #fff; border: 1px solid #eee; border-radius: 8px; display: block;'>
<iframe
srcdoc="{safe_doc}"
width="100%"
height="100%"
style="border:none; border-radius:8px; overflow:hidden;"
sandbox="allow-scripts">
</iframe>
</div>
"""
def _extract_path_from_gradio_file(svg_file) -> str | None:
if isinstance(svg_file, (str, pathlib.Path)):
return str(svg_file)
if isinstance(svg_file, dict) and "name" in svg_file:
return svg_file["name"]
if hasattr(svg_file, "name"):
return svg_file.name
return None
def process_svg(svg_file):
    """Read an uploaded SVG file and return its normalized markup.

    The file is read as UTF-8 and passed through ``generator.parse_svg``; the
    ``"svg_content"`` entry of the result is returned (``""`` if absent).

    Failures are not raised: a human-readable message string is returned in
    place of the markup, so callers receive either SVG text or an error
    sentence through the same return value.
    """
    if svg_file is None:
        return "Please upload an SVG file"

    try:
        source_path = _extract_path_from_gradio_file(svg_file)
        if not source_path:
            return "Invalid file input. Please upload a valid SVG file."

        with open(source_path, "r", encoding="utf-8") as handle:
            raw_markup = handle.read()

        normalized = generator.parse_svg(raw_markup)
        return normalized.get("svg_content", "")
    except FileNotFoundError:
        return "File not found. Please upload a valid SVG file."
    except ET.ParseError:
        return "Invalid SVG file format. Please upload a valid SVG file."
    except Exception as exc:
        return f"Error processing file: {exc}"
def predict_decompose_group(svg_file, svg_text, object_name):
    """UI callback: decompose the uploaded or pasted SVG and build its preview.

    Args:
        svg_file: Uploaded file value, or ``None`` to use ``svg_text``.
        svg_text: SVG markup pasted by the user.
        object_name: Name of the depicted object, passed to the model prompt.

    Returns:
        A 4-tuple ``(hidden_svg, summary, animation_suggestions, viewer_html)``.
        On success the first two items are both the decomposed SVG. On failure
        the tuple is ``("", error_message, "", error_panel_html)``.
    """

    def _failure(message):
        return "", message, "", create_error_html(message)

    if not (object_name or "").strip():
        return _failure("Please enter a valid object name for the SVG")

    if svg_file is not None:
        svg_content_inner = process_svg(svg_file)
    else:
        svg_content_inner = (svg_text or "").strip()

    if not svg_content_inner:
        return _failure("Please upload an SVG file or enter SVG markup")

    # process_svg reports failures as plain sentences through its return
    # value; without this check such a sentence would be sent to the model as
    # if it were SVG markup. Pasted text with no <svg> element is rejected
    # here as well.
    if "<svg" not in svg_content_inner:
        if svg_file is not None:
            return _failure(svg_content_inner)
        return _failure("Input does not contain an <svg> element")

    parsed_svg = generator.parse_svg(svg_content_inner)
    if "error" in parsed_svg:
        return _failure(parsed_svg["error"])

    decompose_result = generator.predict_decompose_group(parsed_svg, object_name)
    if "error" in decompose_result:
        return _failure(decompose_result["error"])

    decomposed_svg = decompose_result["svg_content"]
    animation_suggestions = decompose_result["animation_suggestions"]
    decomposed_svg_viewer = create_svg_viewer_html(decomposed_svg)
    return (
        decomposed_svg,
        decomposed_svg,
        animation_suggestions,
        decomposed_svg_viewer,
    )
def update_preview_from_html(html_content: str) -> str:
    """Update animation preview from manually edited HTML content.

    Args:
        html_content: Full HTML document text to show in the preview frame.

    Returns:
        HTML for a container holding an iframe whose ``srcdoc`` is the given
        document, or an error panel when the content is empty or ``None``.
    """
    if not html_content or not html_content.strip():
        return create_error_html("⚠️ HTML content is empty")

    # srcdoc is an attribute value and is entity-decoded once before being
    # parsed as a document. Escaping only '"' let entities such as "&lt;" or
    # "&amp;" in the edited HTML decode one level too early inside the frame.
    safe_html_content = html.escape(html_content, quote=True)
    # NOTE(review): sandbox combines allow-scripts with allow-same-origin,
    # which lets framed scripts reach the embedding page. Left unchanged
    # because generated animations may rely on it — confirm and drop
    # allow-same-origin if not needed.
    preview_html = f"""
<div style='padding: 20px; width:100%; height:520px; background: #fff; border: 1px solid #eee; border-radius: 8px; display: block;'>
<div style='display: block; align-items: center; margin-bottom: 10px;'>
</div>
<div id='animation-container' style='height: 100%; display: block; justify-content: center; align-items: center; background: #fafafa; border-radius: 4px; padding: 20px; box-sizing: border-box;'>
<iframe srcdoc="{safe_html_content}"
width="100%" height="100%"
style="border:none; border-radius:8px; overflow:hidden;"
sandbox="allow-scripts allow-same-origin">
</iframe>
</div>
</div>
"""
    return preview_html
def create_animation_preview(animation_desc: str, svg_content: str) -> tuple:
    """Create animation preview from description and SVG content.

    The generated HTML is also written to ``output/animation_preview.html``
    (relative to the working directory).

    Args:
        animation_desc: Description of the desired animation.
        svg_content: Decomposed SVG markup to animate.

    Returns:
        ``(preview_html, html_text)``. On success ``html_text`` is the
        generated document after ``fix_html_styles_for_preview``. When the
        model reply contains no HTML, ``html_text`` is the raw reply; on other
        failures it is ``""``. ``preview_html`` is an error panel in every
        failure case.
    """
    if not (svg_content or "").strip():
        error_html = create_error_html("⚠️ Please process SVG first")
        return error_html, ""
    if not (animation_desc or "").strip():
        error_html = create_error_html("⚠️ Please describe the animation you want")
        return error_html, ""

    try:
        animation_response, html_content = generator.generate_animation(
            animation_desc, svg_content
        )
        if not html_content:
            error_html = create_error_html("❌ Failed to generate animation HTML")
            return error_html, animation_response

        output_dir = "output"
        os.makedirs(output_dir, exist_ok=True)
        html_path = os.path.join(output_dir, "animation_preview.html")
        with open(html_path, "w", encoding="utf-8") as f:
            f.write(html_content)
        print(f"Animation preview saved to: {html_path}")

        fixed_html_content = fix_html_styles_for_preview(html_content)
        # srcdoc is an attribute value and is entity-decoded once before being
        # parsed as a document, so "&", "<" and ">" must be escaped along with
        # the quotes; escaping only '"' corrupted entities in the generated HTML.
        safe_html_content = html.escape(fixed_html_content, quote=True)
        # NOTE(review): sandbox combines allow-scripts with allow-same-origin,
        # which lets framed scripts reach the embedding page. Left unchanged
        # because generated animations may rely on it — confirm and drop
        # allow-same-origin if not needed.
        preview_html = f"""
<div style='padding: 20px; width:100%; background: #fff; border: 1px solid #eee; border-radius: 8px; position: relative;'>
<div style='border:1px solid #ddd; border-radius:8px; background:#fafafa; display:flex; align-items:center; justify-content:center; padding:20px; position:relative; box-sizing:border-box;'>
<iframe
srcdoc="{safe_html_content}"
style="width: 360px; height: 360px; border:none; border-radius:8px; overflow:hidden;"
sandbox="allow-scripts allow-same-origin">
</iframe>
</div>
</div>
"""
        return preview_html, fixed_html_content
    except Exception as e:
        error_html = create_error_html(f"❌ Error creating animation: {str(e)}")
        return error_html, ""
# Bundled example SVGs, grouped by category. Each entry is
# [absolute path to the SVG next to this file, object name].
_EXAMPLES_BASE_DIR = os.path.dirname(__file__)
example_list = {
    "Animals": [
        [os.path.join(_EXAMPLES_BASE_DIR, f"examples/{animal}.svg"), animal]
        for animal in ("corgi", "duck", "bunny")
    ],
    "Objects": [
        [os.path.join(_EXAMPLES_BASE_DIR, "examples/rocket.svg"), "rocket"],
    ],
}
def load_example(example_choice):
    """Return ``(svg_path, object_name)`` for the example named ``example_choice``.

    Categories in ``example_list`` are searched in order and the first entry
    whose name matches wins. ``(None, None)`` is returned when nothing matches.
    """
    every_entry = (
        entry for entries in example_list.values() for entry in entries
    )
    found = next(
        (entry for entry in every_entry if entry[1] == example_choice), None
    )
    if found is None:
        return None, None
    return found[0], found[1]
# Flat list of example names across all categories, used as dropdown choices.
example_choices = [
    name for entries in example_list.values() for _path, name in entries
]
# Gradio UI definition: page layout first, then the event wiring that connects
# the buttons and inputs to the callbacks defined above.
demo = gr.Blocks(title="SVG Animation Generator", theme=gr.themes.Soft())
with demo:
gr.Markdown("# 🎨 SVG Decomposition & Animation Generator")
gr.Markdown(
"Intelligent SVG decomposition and animation generation powered by MLLM. This tool decomposes SVG structures and generates animations based on your descriptions."
)
with gr.Column():
with gr.Row(scale=2):
with gr.Column(scale=1):
# Input area: file upload, pasted markup, and bundled examples.
gr.Markdown("## 📤 Input SVG")
with gr.Row(scale=2):
svg_file = gr.File(label="Upload SVG File", file_types=[".svg"])
svg_text = gr.Textbox(
label="Or Paste SVG Code",
lines=8.4,
)
example_dropdown = gr.Dropdown(
choices=example_choices, label="Try an Example", value=None
)
with gr.Column(scale=1):
with gr.Row(scale=1):
with gr.Column(scale=1):
gr.Markdown("## 🔍 SVG Analysis")
object_name = gr.Textbox(
label="Name Your Object",
placeholder="Give a name to your SVG (e.g., 'dove', 'robot')",
value="corgi",
)
# Choosing an example fills in both the file input and the object name.
example_dropdown.change(
fn=load_example,
inputs=[example_dropdown],
outputs=[svg_file, object_name],
)
process_btn = gr.Button("🔄 Decompose Structure", variant="primary")
groups_summary = gr.Textbox(
label="Decomposition Results",
placeholder="MLLM will analyze and decompose the SVG structure...",
lines=6,
interactive=False,
)
with gr.Column(scale=1):
# NOTE(review): "Recommeded" in the heading below is misspelled ("Recommended").
gr.Markdown("## 🎯 Recommeded Animation")
animation_suggestion = gr.Textbox(
label="AI Suggestions",
placeholder="MLLM will suggest animations based on the decomposed structure...",
lines=14.5,
)
with gr.Row(scale=1):
with gr.Column(scale=1):
gr.Markdown("## 💡 Refine Decomposition")
groups_feedback = gr.Textbox(
label="Element Structure",
placeholder="If you have specific decomposition in mind, describe it here...",
lines=2,
)
groups_feedback_btn = gr.Button(
"💭 Apply Decomposition Feedback", variant="primary"
)
with gr.Column(scale=1):
gr.Markdown("## ✨ Create Animation")
describe_animation = gr.Textbox(
label="Animation Description",
placeholder="Describe your desired animation (e.g., 'gentle floating motion')",
lines=2,
)
animate_btn = gr.Button("🎬 Generate Animation", variant="primary")
with gr.Row(scale=3):
with gr.Column(scale=1):
# Invisible textbox holding the current decomposed SVG between callbacks.
svg_content_hidden = gr.Textbox(visible=False)
gr.Markdown("## 🖼️ Decomposed Structure")
decomposed_svg_viewer = gr.HTML(
label="Decomposed SVG",
value="""
<div style='padding: 40px; text-align: center; color: #666; border: 2px dashed #ddd; border-radius: 10px;'>
<div id='decomposed-svg-container' style='min-height: 150px; display: flex; justify-content: center; align-items: center; border-radius: 4px; padding: 10px;'>
<div style='color: #999; text-align: center;'>Decomposed SVG structure will appear here</div>
</div>
</div>
""",
)
with gr.Column(scale=1):
gr.Markdown("## 🎭 Animation Preview")
animation_preview = gr.HTML(
label="Live Preview",
value="""
<div style='padding: 40px; text-align: center; color: #666; border: 2px dashed #ddd; border-radius: 10px;'>
<div id='animation-container' style='min-height: 150px; display: flex; justify-content: center; align-items: center; border-radius: 4px; padding: 10px;'>
<div style='color: #999; text-align: center;'>Animation preview will appear here</div>
</div>
</div>
""",
)
with gr.Column():
with gr.Row(scale=1):
with gr.Column(scale=1):
gr.Markdown("## 📂 HTML Output")
output_html = gr.Textbox(
label="Output HTML",
lines=10,
placeholder="Generated HTML will appear here. You can edit this HTML and see live preview.",
interactive=True,
)
# Decompose: the four return values of predict_decompose_group map onto the
# hidden SVG store, the summary box, the suggestions box, and the viewer.
process_btn.click(
fn=predict_decompose_group,
inputs=[svg_file, svg_text, object_name],
outputs=[
svg_content_hidden,
groups_summary,
animation_suggestion,
decomposed_svg_viewer,
],
)
# Feedback: re-decompose the stored SVG; groups_summary is not among the outputs.
groups_feedback_btn.click(
fn=generator.feedback_decompose_group,
inputs=[
svg_content_hidden,
groups_feedback,
],
outputs=[
svg_content_hidden,
animation_suggestion,
decomposed_svg_viewer,
],
)
# Animate: generate HTML from the description and the stored SVG.
animate_btn.click(
fn=create_animation_preview,
inputs=[
describe_animation,
svg_content_hidden,
],
outputs=[
animation_preview,
output_html,
],
)
# Re-render the preview whenever the HTML textbox value changes.
output_html.change(
fn=update_preview_from_html,
inputs=[output_html],
outputs=[animation_preview],
)
# Launch the app only when this file is run as a script.
# NOTE(review): share=True requests a publicly reachable share link — confirm
# this is intended for the deployment environment.
if __name__ == "__main__":
demo.launch(share=True)