import re
import textwrap
from dataclasses import dataclass, field
from typing import List, Optional

from cot_reasoning import AnthropicAPI, VisualizationConfig


@dataclass
class ToTNode:
    """A single thought (node) in a Tree of Thoughts.

    Attributes:
        id: Node identifier taken from the model response.
        content: The thought's text.
        parent_id: Identifier of the parent node, or ``None`` for a root.
        children: Child nodes, appended while the tree is assembled.
        is_answer: True if this node is flagged as leading to the final answer.
    """
    id: str
    content: str
    parent_id: Optional[str] = None
    # default_factory gives every node its own list rather than a shared one.
    children: List['ToTNode'] = field(default_factory=list)
    is_answer: bool = False

    def __post_init__(self):
        # Callers may still pass ``children=None`` explicitly; normalise it.
        if self.children is None:
            self.children = []


@dataclass
class ToTResponse:
    """A parsed Tree of Thoughts response.

    Attributes:
        question: The question the tree was generated for.
        root: Root node of the thought tree, or ``None`` when the response
            contained no usable root node.
        answer: Final answer text, if the response contained one.
    """
    question: str
    root: Optional['ToTNode']
    answer: Optional[str] = None


def parse_tot_response(response_text: str, question: str) -> ToTResponse:
    """Parse a Tree of Thoughts response and assemble the thought tree.

    NOTE(review): the markup below is assumed; confirm it matches the ToT
    prompt sent to the model::

        <node id="1">First thought</node>
        <node id="2" parent="1">Follow-up thought</node>
        <answer>Final answer</answer>

    A node with no ``parent`` attribute, or an empty / "none" / "null" one,
    is a root candidate. The first such node in the text becomes the root.
    Nodes whose parent id does not match any node are left out of the tree.

    Args:
        response_text: Raw model output to parse.
        question: The question that was asked; copied onto the result.

    Returns:
        A ToTResponse holding the assembled root (``None`` if no root node
        was found) and the stripped ``<answer>`` text (``None`` if absent).
        Nodes whose text appears verbatim in the answer get ``is_answer``.
    """
    node_pattern = r'<node\s+id="([^"]+)"(?:\s+parent="([^"]*)")?\s*>\s*(.*?)\s*</node>'
    answer_pattern = r'<answer>\s*(.*?)\s*</answer>'
    # Parent values treated as "no parent" (compared case-insensitively).
    no_parent = {'', 'none', 'null'}

    # First pass: create every node, keyed by id, in document order.
    nodes_dict = {}
    for match in re.finditer(node_pattern, response_text, re.DOTALL):
        node_id, parent_id, content = match.groups()
        node_id = node_id.strip()
        if parent_id is not None:
            parent_id = parent_id.strip()
            if parent_id.lower() in no_parent:
                parent_id = None
        nodes_dict[node_id] = ToTNode(id=node_id, content=content.strip(), parent_id=parent_id)

    # Second pass: link children to their parents and pick the root.
    root = None
    for node in nodes_dict.values():
        if node.parent_id is None:
            # Keep the first parentless node; a later one would otherwise
            # silently replace it.
            if root is None:
                root = node
        else:
            parent = nodes_dict.get(node.parent_id)
            if parent is not None:
                parent.children.append(node)

    answer_match = re.search(answer_pattern, response_text, re.DOTALL)
    answer = answer_match.group(1).strip() if answer_match else None

    if answer:
        # Heuristic: flag nodes whose text is quoted verbatim in the answer.
        # Empty nodes are skipped because '' is a substring of any string.
        for node in nodes_dict.values():
            if node.content and node.content in answer:
                node.is_answer = True

    return ToTResponse(question=question, root=root, answer=answer)


def create_mermaid_diagram(tot_response: ToTResponse, config: VisualizationConfig) -> str:
    """Convert a ToT response into Mermaid flowchart markup.

    The chart starts at a question node ``Q`` linked to the tree root and
    draws an edge from every thought to each of its children. When an answer
    is present, every leaf is linked to a final ``Answer`` node. Nodes flagged
    ``is_answer`` are given the ``answer`` class.

    Args:
        tot_response: The parsed tree to draw.
        config: Label-size limits, passed through to ``wrap_text``.

    Returns:
        The Mermaid source, one statement per line, wrapped in a
        ``<div class="mermaid">`` element.
    """
    # NOTE(review): assumes the result is embedded in an HTML page where
    # mermaid.js renders elements with class "mermaid"; confirm with callers.
    diagram = ['<div class="mermaid">', 'graph TD']

    question_content = wrap_text(tot_response.question, config)
    diagram.append(f'    Q["{question_content}"]')

    leaf_ids = []    # Mermaid ids of childless nodes; these link to Answer.
    answer_ids = []  # Mermaid ids of nodes flagged is_answer.
    # Model-supplied ids may contain spaces or punctuation, may be Mermaid
    # keywords such as "end", or may collide with "Q"/"Answer", so every node
    # gets a generated id instead. The map is keyed by object identity, so a
    # node reached twice is drawn once and a cyclic reference cannot loop forever.
    mermaid_ids = {}

    # Iterative pre-order walk; each entry is (node, Mermaid id of its parent).
    stack = [(tot_response.root, 'Q')] if tot_response.root is not None else []
    while stack:
        node, parent_id = stack.pop()
        existing = mermaid_ids.get(id(node))
        if existing is not None:
            diagram.append(f'    {parent_id} --> {existing}')
            continue

        node_id = f'N{len(mermaid_ids)}'
        mermaid_ids[id(node)] = node_id
        content = wrap_text(node.content, config)
        diagram.append(f'    {node_id}["{content}"]')
        diagram.append(f'    {parent_id} --> {node_id}')

        if node.is_answer:
            answer_ids.append(node_id)
        if node.children:
            # Push in reverse so the children are popped in their original order.
            stack.extend((child, node_id) for child in reversed(node.children))
        else:
            leaf_ids.append(node_id)

    if answer_ids:
        # Apply the `answer` style (declared with the other classDefs below).
        diagram.append(f'    class {",".join(answer_ids)} answer;')

    if tot_response.answer:
        answer_content = wrap_text(tot_response.answer, config)
        diagram.append(f'    Answer["{answer_content}"]')
        for leaf_id in leaf_ids:
            diagram.append(f'    {leaf_id} --> Answer')
        diagram.append('    class Answer final_answer;')

    diagram.extend([
        '    classDef default fill:#f9f9f9,stroke:#333,stroke-width:2px;',
        '    classDef question fill:#e3f2fd,stroke:#1976d2,stroke-width:2px;',
        '    classDef answer fill:#d4edda,stroke:#28a745,stroke-width:2px;',
        '    classDef final_answer fill:#d4edda,stroke:#28a745,stroke-width:2px;',
        '    class Q question;',
        '    linkStyle default stroke:#666,stroke-width:2px;'
    ])

    diagram.append('</div>')
    return '\n'.join(diagram)


def wrap_text(text: str, config: 'VisualizationConfig') -> str:
    """Wrap text into a Mermaid node label that fits the configured box.

    Newlines become spaces and double quotes become single quotes, so the
    result can sit inside a ``["..."]`` Mermaid label. The text is wrapped
    at ``config.max_chars_per_line`` characters. If that yields more than
    ``config.max_lines`` lines (at least one line is always kept), the extra
    lines are dropped. The last kept line is then shortened and suffixed
    with ``"..."``.

    Args:
        text: Raw text to display.
        config: Object exposing ``max_chars_per_line`` and ``max_lines``.

    Returns:
        The wrapped lines joined with ``<br>``, Mermaid's in-label line break.
    """
    text = text.replace('\n', ' ').replace('"', "'")
    width = config.max_chars_per_line
    wrapped_lines = textwrap.wrap(text, width=width)

    if len(wrapped_lines) > config.max_lines:
        # Keep at least one line so there is somewhere to put the ellipsis.
        wrapped_lines = wrapped_lines[:max(config.max_lines, 1)]
        # Reserve room for "..." while never producing a negative slice bound.
        wrapped_lines[-1] = wrapped_lines[-1][:max(width - 3, 0)] + "..."

    return "<br>".join(wrapped_lines)