# -*- coding: utf-8 -*-
# ZenCtrl Inpainting Playground (Baseten backend)
#import spaces
import os, json, base64, requests
from io import BytesIO
from PIL import Image, ImageDraw
import gradio as gr
import replicate
# ────────── Secrets & endpoints ──────────
# Endpoint of the Baseten deployment that performs the inpainting (None when unset).
BASETEN_MODEL_URL = os.environ.get("BASETEN_MODEL_URL")
# Optional Baseten credential; only sent as an "Api-Key" header when non-empty.
BASETEN_API_KEY = os.environ.get("BASETEN_API_KEY")
# NOTE(review): not read by the code in this file; the replicate client presumably
# picks REPLICATE_API_TOKEN up from the environment on its own — confirm.
REPLICATE_TOKEN = os.environ.get("REPLICATE_API_TOKEN")
# ────────── Globals ──────────
# Side length (pixels) both input images are resized to before being sent.
ADAPTER_SIZE = 1024
# Centers the main input column and caps its width.
css = (
    "#col-container "
    "{margin:0 auto; max-width:960px;}"
)
# Background generation via Replicate
def _gen_bg(prompt: str):
    """Generate a square background image with Replicate's ``google/imagen-4-fast`` model.

    Args:
        prompt: Text prompt. An empty or ``None`` prompt falls back to
            ``"cinematic background"``.

    Returns:
        The downloaded image as an RGB ``PIL.Image.Image``.

    Raises:
        gr.Error: if the model returns an empty output list or the image
            download fails (network error or non-2xx HTTP status).
    """
    output = replicate.run(
        "google/imagen-4-fast",
        input={"prompt": prompt or "cinematic background", "aspect_ratio": "1:1"},
    )
    # replicate.run may hand back a list of outputs; use the first one.
    if isinstance(output, list):
        if not output:
            raise gr.Error("Background generation returned no image.")
        output = output[0]
    try:
        resp = requests.get(output, timeout=120)
        # Without this check an HTTP error page would be passed to Image.open and
        # fail with an unhelpful "cannot identify image file" error.
        resp.raise_for_status()
    except requests.RequestException as err:
        raise gr.Error(f"Failed to download the generated background: {err}") from err
    return Image.open(BytesIO(resp.content)).convert("RGB")
# Main processing function
def _mask_to_green_box(adapter_image, adapter_mask):
    """Return an RGB copy of ``adapter_image`` with the sketched region filled #00FF00.

    Every non-zero pixel of ``adapter_mask`` (after greyscale conversion) counts as
    painted. The painted pixels are widened to their axis-aligned bounding rectangle,
    which is filled solid green. An all-zero mask leaves the image content unchanged.
    """
    # Image.composite needs all inputs to share one size, and the two images one
    # mode: normalise the base to RGB (RGBA/P uploads would otherwise raise) and
    # bring the mask to the base's size.
    base = adapter_image.convert("RGB")
    mask = adapter_mask.convert("L").point(lambda p: 255 if p else 0)
    if mask.size != base.size:
        mask = mask.resize(base.size, Image.NEAREST)
    bbox = mask.getbbox()
    if bbox:
        rect = Image.new("L", mask.size, 0)
        ImageDraw.Draw(rect).rectangle(bbox, fill=255)
        mask = rect
    green = Image.new("RGB", base.size, "#00FF00")
    return Image.composite(green, base, mask)


def _encode_png_b64(img):
    """Serialise a PIL image as PNG and return it base64-encoded as an ASCII str."""
    buf = BytesIO()
    img.save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode()


def _decode_result_image(data):
    """Decode the generated image from a parsed Baseten JSON response.

    The base64 image is read from ``data["raw_result"]`` when present, otherwise
    from ``data["blended"]``.

    Raises:
        gr.Error: if the response is not a dict, neither key holds a value, or the
            value is not a base64-encoded image Pillow can open.
    """
    if not isinstance(data, dict):
        raise gr.Error("Unexpected Baseten response format.")
    for key in ("raw_result", "blended"):
        encoded = data.get(key)
        if encoded:
            break
    else:
        raise gr.Error("Baseten response missing 'raw_result' / 'blended' image.")
    try:
        # ValueError covers bad base64 (binascii.Error); OSError covers
        # PIL.UnidentifiedImageError; TypeError covers a non-string payload.
        img = Image.open(BytesIO(base64.b64decode(encoded))).convert("RGB")
    except (ValueError, TypeError, OSError) as err:
        raise gr.Error(f"Failed to decode '{key}' image from Baseten response.") from err
    return img


def process_image_and_text(subject_image, adapter_dict, prompt, _unused1, _unused2, size=ADAPTER_SIZE, rank=10.0):
    """Run ZenCtrl inpainting on the Baseten endpoint and return the result image.

    Args:
        subject_image: PIL image of the subject to insert.
        adapter_dict: Either a dict with an ``"image"`` and an optional ``"mask"``
            PIL image (the sketch-tool value), or a plain PIL image. When a mask is
            present, its bounding box is painted solid green on the background.
        prompt: Text prompt forwarded to the model.
        _unused1, _unused2: Placeholders matching the existing click wiring; ignored.
        size: Both images are resized (stretched, not cropped) to ``size`` x ``size``,
            and the same value is sent as the requested height and width.
        rank: Forwarded unchanged in the payload as ``"rank"``.

    Returns:
        A tuple ``(image, image)``: the decoded RGB result returned twice (the UI
        routes the two values to the visible output and a state slot).

    Raises:
        gr.Error: on missing inputs, an unset ``BASETEN_MODEL_URL``, a failed
            request, or an undecodable response.
    """
    # Fixed sampling parameters sent with every request.
    seed, guidance_scale, steps = 42, 2.5, 28

    if isinstance(adapter_dict, dict):  # sketch-tool value
        adapter_image = adapter_dict.get("image")
        adapter_mask = adapter_dict.get("mask")
    else:
        adapter_image, adapter_mask = adapter_dict, None

    if subject_image is None:
        raise gr.Error("Please upload a subject image.")
    if adapter_image is None:
        raise gr.Error("Please upload a background / mask image.")
    if not BASETEN_MODEL_URL:
        raise gr.Error("BASETEN_MODEL_URL is not configured.")

    if adapter_mask is not None:
        adapter_image = _mask_to_green_box(adapter_image, adapter_mask)

    # Plain stretch to a square; aspect ratio is not preserved.
    subj_proc = subject_image.resize((size, size), Image.LANCZOS)
    adap_proc = adapter_image.resize((size, size), Image.LANCZOS)

    # NOTE(review): the payload schema is defined by the Baseten deployment,
    # which is not visible here.
    payload = {
        "prompt": prompt,
        "subject_image": _encode_png_b64(subj_proc),
        "adapter_image": _encode_png_b64(adap_proc),
        "height": size,
        "width": size,
        "steps": steps,
        "seed": seed,
        "guidance_scale": guidance_scale,
        "rank": rank,
    }
    headers = {"Content-Type": "application/json"}
    if BASETEN_API_KEY:
        headers["Authorization"] = f"Api-Key {BASETEN_API_KEY}"
    try:
        resp = requests.post(BASETEN_MODEL_URL, headers=headers, json=payload, timeout=180)
        resp.raise_for_status()
        data = resp.json()  # ValueError (or a subclass) on a non-JSON body
    except (requests.RequestException, ValueError) as err:
        raise gr.Error(f"Baseten request failed: {err}") from err

    raw_img = _decode_result_image(data)
    return raw_img, raw_img
# ────────── Header HTML ──────────
# Title text shown at the top of the page (passed to gr.HTML below).
header_html = "\nZenCtrl Inpainting Beta\n"
# ────────── Gradio UI ──────────
with gr.Blocks(css=css, title="ZenCtrl Inpainting") as demo:
    # Receives the second value returned by process_image_and_text.
    raw_state = gr.State()
    gr.HTML(header_html)
    gr.Markdown(
        "**Generate context-aware images of your subject with ZenCtrl’s inpainting playground.** Upload a subject + optional mask, write a prompt, and hit **Generate**.   \n"
        "Open *Advanced Settings* for an AI-generated background.  \n\n"
        "**Note:** The model was trained mainly on interior scenes and other *rigid* objects. Results on people or highly deformable items may contain visual distortions. \n"
        "In case of High traffic , your requests might be queued and processed one by one by our backend server"
    )
    with gr.Row():
        # Left column: inputs and controls.
        with gr.Column(scale=2, elem_id="col-container"):
            subj_img = gr.Image(type="pil", label="Subject image")
            # Sketch-enabled image; strokes are turned into a green fill box by
            # process_image_and_text.
            ref_img = gr.Image(type="pil", label="Background / Mask image", tool="sketch", brush_color="#00FF00")
            ref_img_ex = gr.Image(type="pil", visible=False)
            promptbox = gr.Textbox(label="Generation prompt", value="furniture", lines=2)
            run_btn = gr.Button("Generate", variant="primary")
            with gr.Accordion("Advanced Settings", open=False):
                bgprompt = gr.Textbox(label="Background Prompt", value="Scandinavian living room …")
                bg_btn = gr.Button("Generate BG")
        # Right column: results.
        with gr.Column(scale=2):
            output_img = gr.Image(label="Output Image")
            bg_img = gr.Image(label="Background", visible=True)

    # ---------- Example wrapper ---------------------------------
    def _load_and_show(subj_path, bg_path, prompt_text, out_path=None):
        """Load one Examples row into the four example widgets.

        Each row carries four values (subject path, background path, prompt,
        pre-rendered result path), matching the four ``inputs`` of gr.Examples.
        ``out_path`` is optional so 3-value callers still work; when omitted it is
        derived from the subject path as ``<subject>_out.png``.

        Expects raw file paths (i.e. un-preprocessed example values).
        NOTE(review): with cache_examples=False and no run_on_click, Gradio should
        not call this on example clicks — confirm for the pinned Gradio version.

        Returns a 4-tuple:
          1. subject PIL image               -> subj_img
          2. dict for the sketch component   -> ref_img
          3. prompt string                   -> promptbox
          4. pre-rendered result PIL         -> output_img
        """
        if out_path is None:
            out_path = subj_path.replace(".png", "_out.png")
        return (
            Image.open(subj_path),
            {"image": Image.open(bg_path), "mask": None},
            prompt_text,
            Image.open(out_path),
        )

    # Alternative 3-value loader; not referenced anywhere in this script.
    def ex(subj, bg, prompt):
        return [
            Image.open(subj),
            {"image": Image.open(bg), "mask": None},
            prompt
        ]

    # ---------- Examples ----------------------------------------
    gr.Examples(
        examples=[
            ["examples/sofa1_1.png", "examples/sofa1_bg.png", "add the sofa", "examples/sofa1_out.png"],
            ["examples/sofa2.png", "examples/sofa2_bg.png", "add this sofa", "examples/sofa2_out.png"],
            ["examples/chair1.png", "examples/chair1_bg.png", "add the chair", "examples/chair1_out.png"],
            ["examples/console_table.png", "examples/console_table_bg.png", "Scandinavian console table against a gallery-style wall filled with abstract framed art,", "examples/console_table_out.png"],
            ["examples/office_chair.png", "examples/office_chair_bg.png", "office chair", "examples/office_chair_out.png"],
            ["examples/office_chair1.png", "examples/office_chair1_bg.png", "Executive mesh chair in a modern home office, with matte black wall panels, built-in shelves, ", "examples/office_chair1_out.png"],
            ["examples/bed.png", "examples/bed_in.png", "Low platform bed in a Japandi-style bedroom, surrounded by floating nightstands", "examples/bed_out.png"],
            ["examples/car.png", "examples/car_bg.png", "car on the road", "examples/car_out.png"],
        ],
        inputs        = [subj_img, ref_img, promptbox, output_img],
        outputs       = [subj_img, ref_img, promptbox, output_img],
        fn            = _load_and_show,
        cache_examples=False
        )

    # ---------- Buttons & interactions --------------------------
    # The two gr.State values fill process_image_and_text's ignored placeholders.
    run_btn.click(
        process_image_and_text,
        inputs=[subj_img, ref_img, promptbox, gr.State(False), gr.State("")],
        outputs=[output_img, raw_state]
    )
    bg_btn.click(_gen_bg, inputs=[bgprompt], outputs=[bg_img])
# ---------------- Launch ---------------------------------------
if __name__ == "__main__":
    # The UI text tells users requests "might be queued and processed one by one";
    # Gradio 3.x only queues events when queue() is called (later versions enable it
    # by default, where this call is harmless).
    demo.queue().launch()