import json
import logging
import uuid
import asyncio
from typing import  Optional
import gradio as gr
from api import request_sse_stream_parsed, stop_chat
from utils import contains_chinese, replace_chinese_punctuation
# Module-level logger for this UI module.
logger = logging.getLogger(__name__)
# Registry of in-flight runs: task_id (a str(uuid.uuid4()) created in
# gradio_run) -> the asyncio task streaming that run. stop_current looks the
# task up here to cancel it; gradio_run removes its entry when it finishes.
running_tasks = {}
from typing import Optional
# ========================= Gradio Integration =========================
def _init_render_state():
    """Return a fresh, empty render state for a single streaming run.

    Keys:
        agent_order: agent ids in the order they were first started.
        agents: agent_id -> {"agent_name": str, "tool_call_order": [...],
            "tools": {tool_call_id: {...}}}.
        current_agent_id: id of the agent receiving events, or None.
        errors: error texts collected from "error" events.
    """
    fresh_state = dict(
        agent_order=[],
        agents={},
        current_agent_id=None,
        errors=[],
    )
    return fresh_state
def _append_show_text(tool_entry: dict, delta: str):
    """Append *delta* to ``tool_entry["content"]`` in place.

    A missing "content" key counts as the empty string.
    """
    tool_entry["content"] = "".join((tool_entry.get("content", ""), delta))
def _is_empty_payload(value) -> bool:
    """Tell whether *value* carries nothing worth rendering.

    Empty means: None; a string that is blank, "{}" or "[]" once surrounding
    whitespace is stripped; or an empty dict/list/tuple/set. Everything else
    (including 0, False and frozenset()) is considered non-empty.
    """
    if isinstance(value, str):
        return value.strip() in ("", "{}", "[]")
    if isinstance(value, (dict, list, tuple, set)):
        return len(value) == 0
    # Not a string or container: only None counts as empty.
    return value is None
def _render_tool_call(call_id, call: dict) -> list:
    """Return the markdown lines for one entry of an agent's ``tools`` map.

    ``show_text`` / ``message`` entries render their accumulated text inline.
    Other tools render as a bare title when they have neither input nor
    output, otherwise as a collapsible ``<details>`` section holding the
    pretty-printed JSON input and/or output.
    """
    tool_name = call.get("tool_name", "unknown_tool")
    if tool_name in ("show_text", "message"):
        content = call.get("content", "")
        return [content] if content else []

    tool_input = call.get("input")
    tool_output = call.get("output")
    has_input = not _is_empty_payload(tool_input)
    has_output = not _is_empty_payload(tool_output)
    # "Partial Summary" gets its own icon; every other tool shares one.
    # NOTE(review): these emoji literals look mis-encoded (UTF-8 emoji bytes
    # shown as Thai TIS-620/cp874 text) - confirm and restore the glyphs.
    icon = "๐ก" if tool_name == "Partial Summary" else "๐ง"

    if not has_input and not has_output:
        # No parameters: only the icon and tool name, on a line of their own.
        return [f"\n{icon}{tool_name}\n"]

    # call_id is None when the event omitted its id; str() keeps the slice
    # from raising TypeError.
    summary = f"{icon}{tool_name} ({str(call_id)[:8]})"
    lines = [f"\n<details><summary>{summary}</summary>\n"]
    if has_input:
        pretty = json.dumps(tool_input, ensure_ascii=False, indent=2)
        lines += ["\n**Input**:\n", f"```json\n{pretty}\n```"]
    if has_output:
        pretty = json.dumps(tool_output, ensure_ascii=False, indent=2)
        lines += ["\n**Output**:\n", f"```json\n{pretty}\n```"]
    lines.append("</details>\n")
    return lines


def _render_markdown(state: dict) -> str:
    """Render the run state built by _update_state_with_event as markdown.

    Collected errors come first. Then there is one section per agent in the
    order agents were first started, listing that agent's tool calls and
    streamed text in arrival order. Returns "Waiting..." when there is
    nothing to show yet.
    """
    # NOTE(review): these emoji literals look mis-encoded (UTF-8 emoji bytes
    # shown as Thai TIS-620/cp874 text) - confirm and restore the glyphs.
    emoji_cycle = ["๐ง ", "๐", "๐ ๏ธ", "๐", "๐ค", "๐งช", "๐", "๐งญ", "โ๏ธ", "๐งฎ"]
    lines = []

    errors = state.get("errors")
    if errors:
        lines.append("### โ Errors")
        lines.extend(f"- **Error {idx}**: {err}" for idx, err in enumerate(errors, start=1))
        lines.append("\n---\n")

    agents = state.get("agents", {})
    for idx, agent_id in enumerate(state.get("agent_order", [])):
        agent = agents.get(agent_id, {})
        emoji = emoji_cycle[idx % len(emoji_cycle)]
        lines.append(f"### {emoji} Agent: {agent.get('agent_name', 'unknown')}")
        tools = agent.get("tools", {})
        for call_id in agent.get("tool_call_order", []):
            lines.extend(_render_tool_call(call_id, tools.get(call_id, {})))
        lines.append("\n---\n")

    return "\n".join(lines) if lines else "Waiting..."
def _resolve_agent(state: dict):
    """Return the agent entry that tool/message events attach to, or None.

    Uses the current agent. After end_of_agent it falls back to the most
    recently started agent. Returns None when no agent has started yet.
    """
    agent_order = state["agent_order"]
    agent_id = state.get("current_agent_id") or (agent_order[-1] if agent_order else None)
    if not agent_id:
        return None
    return state["agents"].setdefault(
        agent_id, {"agent_name": "unknown", "tool_call_order": [], "tools": {}}
    )


def _ensure_entry(agent: dict, entry_id, tool_name: str) -> dict:
    """Return ``agent["tools"][entry_id]``, creating it and recording its order on first sight."""
    tools = agent["tools"]
    if entry_id not in tools:
        tools[entry_id] = {"tool_name": tool_name}
        agent["tool_call_order"].append(entry_id)
    return tools[entry_id]


def _text_field(container, key: str) -> str:
    """Return ``container[key]`` if *container* is a dict and that value is a str, else ""."""
    if isinstance(container, dict):
        value = container.get(key)
        if isinstance(value, str):
            return value
    return ""


def _update_state_with_event(ui_state: dict, state: dict, message: dict):
    """Apply one parsed SSE message to the render state and return it.

    *state* (shaped like _init_render_state's result) is mutated in place and
    also returned. Handled events:
      - job_started: stores data["chat_id"] in both *state* and *ui_state*.
      - start_of_agent / end_of_agent: register an agent / clear the current one.
      - tool_call: records a tool call's input/output, or accumulates
        show_text text.
      - message: accumulates streamed text per message_id.
      - error: collects an error string for rendering.
    Any other event (e.g. heartbeats) is ignored.
    """
    event = message.get("event")
    raw_data = message.get("data", {})
    # A missing, null or non-dict payload is treated as empty so the
    # .get() calls below cannot raise; the error branch still reports raw_data.
    data = raw_data if isinstance(raw_data, dict) else {}

    if event == "job_started":
        chat_id = data.get("chat_id")
        state["chat_id"] = chat_id
        ui_state["chat_id"] = chat_id
    elif event == "start_of_agent":
        agent_id = data.get("agent_id")
        if agent_id and agent_id not in state["agents"]:
            state["agents"][agent_id] = {
                "agent_name": data.get("agent_name", "unknown"),
                "tool_call_order": [],
                "tools": {},
            }
            state["agent_order"].append(agent_id)
        state["current_agent_id"] = agent_id
    elif event == "end_of_agent":
        # Keep the agent's entries; later events fall back to the last agent.
        state["current_agent_id"] = None
    elif event == "tool_call":
        agent = _resolve_agent(state)
        if agent is None:
            return state
        tool_name = data.get("tool_name", "unknown_tool")
        entry = _ensure_entry(agent, data.get("tool_call_id"), tool_name)
        if tool_name == "show_text" and "delta_input" in data:
            # Incremental chunk of show_text output.
            delta = _text_field(data.get("delta_input"), "text")
            if delta:
                _append_show_text(entry, delta)
        elif tool_name == "show_text" and "tool_input" in data:
            tool_input = data.get("tool_input")
            if isinstance(tool_input, str):
                text = tool_input
            else:
                result = tool_input.get("result") if isinstance(tool_input, dict) else None
                text = _text_field(tool_input, "text") or _text_field(result, "text")
            if text:
                _append_show_text(entry, text)
        elif "tool_input" in data:
            # A tool call can arrive first with its arguments and later with a
            # payload holding "result", which is stored as its output.
            tool_input = data["tool_input"]
            if isinstance(tool_input, dict) and "result" in tool_input:
                entry["output"] = tool_input
            elif "input" not in entry or not _is_empty_payload(tool_input):
                # Don't let a later empty payload overwrite input already captured.
                entry["input"] = tool_input
    elif event == "message":
        # Streamed text, aggregated per message_id like show_text.
        agent = _resolve_agent(state)
        if agent is None:
            return state
        entry = _ensure_entry(agent, data.get("message_id"), "message")
        delta = _text_field(data.get("delta"), "content")
        if delta:
            _append_show_text(entry, delta)
    elif event == "error":
        # Collected here; rendered together at the top of the log.
        err_text = data.get("error")
        if not err_text:
            try:
                err_text = json.dumps(raw_data, ensure_ascii=False)
            except (TypeError, ValueError):
                err_text = str(raw_data)
        state.setdefault("errors", []).append(err_text)
    return state
def _spinner_markup(running: bool) -> str:
    """Return the "still running" indicator appended below the log, or "" when idle."""
    if not running:
        return ""
    # Plain markdown, so gr.Markdown renders it without any extra CSS.
    return "\n\n*Running...*\n\n"
def _notice_markdown(title: str, body: str) -> str:
    """Format a titled notice as a markdown blockquote followed by a blank line."""
    return f"> **{title}**\n>\n> {body}\n\n"


async def gradio_run(query: str, ui_state: Optional[dict]):
    """Stream one research run into the UI (Gradio async-generator handler).

    Every yield is a 4-tuple matching the outputs wired in build_demo:
    (log markdown, Send-button update, Stop-button update, ui_state).
    Queries containing Chinese text are rejected with a notice. Otherwise the
    run's task is registered in ``running_tasks`` under a fresh
    ``ui_state["task_id"]`` so stop_current can cancel it, and each message
    from request_sse_stream_parsed re-renders the log.

    Args:
        query: The user's question; Chinese punctuation is normalized first.
        ui_state: Per-session dict holding ``chat_id`` / ``task_id``; may be None.
    """
    query = replace_chinese_punctuation(query or "")
    if contains_chinese(query):
        # NOTE(review): the emoji in this title looks mis-encoded - confirm.
        warning = _notice_markdown(
            "๐ก Notice",
            "We only support English input for the time being. "
            "Please translate your question to English and try again.",
        )
        yield (
            warning,
            gr.update(interactive=True),
            gr.update(interactive=False),
            ui_state or {"chat_id": None},
        )
        return

    state = _init_render_state()
    task_id = str(uuid.uuid4())
    if ui_state is None:
        ui_state = {"chat_id": None}
    ui_state["task_id"] = task_id
    # Forget the previous run's chat id. Until this run's job_started event
    # arrives, Stop must not call stop_chat on an older chat.
    ui_state["chat_id"] = None

    # Initial: disable Send, enable Stop, and show the spinner under the log.
    yield (
        _render_markdown(state) + _spinner_markup(True),
        gr.update(interactive=False),
        gr.update(interactive=True),
        ui_state,
    )

    try:
        current_task = asyncio.current_task()
        running_tasks[task_id] = current_task

        async for message in request_sse_stream_parsed(query):
            # Defensive only: stop_current's task.cancel() normally surfaces
            # as CancelledError at an await, handled below.
            if current_task.cancelled():
                break
            state = _update_state_with_event(ui_state, state, message)
            yield (
                _render_markdown(state) + _spinner_markup(True),
                gr.update(interactive=False),
                gr.update(interactive=True),
                ui_state,
            )
    except asyncio.CancelledError:
        # NOTE(review): the emoji in this title looks mis-encoded - confirm.
        notice = _notice_markdown(
            "๐ Task Cancelled",
            "The current task has been cancelled successfully.",
        )
        existing_content = _render_markdown(state)
        # Keep whatever had been rendered so far below the notice.
        if existing_content and existing_content != "Waiting...":
            notice += existing_content
        yield (
            notice,
            gr.update(interactive=True),
            gr.update(interactive=False),
            ui_state,
        )
        return
    finally:
        running_tasks.pop(task_id, None)

    # End: enable Send, disable Stop, drop the spinner.
    yield (
        _render_markdown(state),
        gr.update(interactive=True),
        gr.update(interactive=False),
        ui_state,
    )
async def stop_current(ui_state: Optional[dict]):
    """Stop the run recorded in *ui_state*.

    Cancels the task registered in ``running_tasks`` under
    ``ui_state["task_id"]`` if it is not done yet. If ``ui_state["chat_id"]``
    is set, it then awaits ``stop_chat`` for that chat; a failure there is
    logged, not raised. Always returns updates that re-enable Send and
    disable Stop.
    """
    session = ui_state if ui_state is not None else {}

    task_id = session.get("task_id")
    task = running_tasks[task_id] if task_id and task_id in running_tasks else None
    if task and not task.done():
        task.cancel()
        logger.info("Task has been cancelled: %s", task_id)

    chat_id = session.get("chat_id")
    if chat_id:
        try:
            res = await stop_chat(chat_id)
            logger.info("Chat has been stopped: %s, res: %s", chat_id, res)
        except Exception as exc:
            logger.error("Stop chat API call failed: %s", exc)

    send_update = gr.update(interactive=True)
    stop_update = gr.update(interactive=False)
    return send_update, stop_update
def build_demo():
    """Assemble the Gradio UI and return the (not yet launched) gr.Blocks app.

    Layout: page heading, question textbox, Send/Stop buttons and a markdown
    log view. Send streams gradio_run into the log; Stop calls stop_current.
    A gr.State dict carries chat_id/task_id between the two handlers.
    """
    custom_css = """
    #log-view { border: 1px solid #ececec; padding: 12px; border-radius: 8px; scroll-behavior: smooth; }
    """
    with gr.Blocks(css=custom_css, title="MiroMind Open-Source Deep Research") as demo:
        # Page heading; a single valid string literal.
        gr.HTML("<h1>MiroMind Open-Source Deep Research</h1>")
        with gr.Row():
            inp = gr.Textbox(lines=3, label="Question (English only)")
        with gr.Row():
            run_btn = gr.Button("Send")
            # Starts disabled; gradio_run enables it while a run is streaming.
            stop_btn = gr.Button("Stop", variant="stop", interactive=False)
        out_md = gr.Markdown("", elem_id="log-view")
        ui_state = gr.State({"chat_id": None})
        # run: outputs -> markdown, run_btn(update), stop_btn(update), ui_state
        run_btn.click(fn=gradio_run, inputs=[inp, ui_state], outputs=[out_md, run_btn, stop_btn, ui_state])
        # stop: outputs -> run_btn(update), stop_btn(update)
        stop_btn.click(fn=stop_current, inputs=[ui_state], outputs=[run_btn, stop_btn])
    return demo
if __name__ == "__main__":
    # Script entry point: build the UI and serve it, using ./favicon.ico as
    # the page icon. The module-level `demo` name is kept bound here.
    demo = build_demo()
    demo.launch(favicon_path="./favicon.ico")