Spaces:
Running
Running
| import re | |
| import requests | |
| import textwrap | |
| from dataclasses import dataclass | |
| from typing import List, Optional | |
@dataclass
class CoTStep:
    """A single numbered reasoning step of a Chain-of-Thought response."""

    number: int  # step number as given in the response
    content: str  # text of the step
@dataclass
class CoTResponse:
    """A complete Chain-of-Thought response: question, reasoning steps, final answer."""

    question: str  # the question that was asked
    steps: List["CoTStep"]  # parsed reasoning steps
    answer: Optional[str] = None  # final answer text; None when absent
@dataclass
class VisualizationConfig:
    """Text-layout settings for rendering Chain-of-Thought diagram labels."""

    max_chars_per_line: int = 40  # wrap width, in characters
    max_lines: int = 4  # maximum number of wrapped lines per label
    truncation_suffix: str = "..."  # marker string for truncated labels
class AnthropicAPI:
    """Class to handle interactions with the Anthropic API (POST /v1/messages)."""

    def __init__(self, api_key: str, model: str = "claude-3-opus-20240229"):
        """Store credentials and prebuild the request headers.

        Args:
            api_key: Anthropic API key, sent in the ``x-api-key`` header.
            model: Model identifier placed in every request body.
        """
        self.api_key = api_key
        self.model = model
        self.base_url = "https://api.anthropic.com/v1/messages"
        self.headers = {
            "x-api-key": api_key,
            "anthropic-version": "2023-06-01",
            "content-type": "application/json"
        }

    def generate_response(
        self,
        prompt: str,
        max_tokens: int = 1024,
        prompt_format: Optional[str] = None,
        timeout: Optional[float] = 300.0,
    ) -> str:
        """Send a single-turn user message and return the reply text.

        Args:
            prompt: User prompt. If ``prompt_format`` is given, the prompt is
                substituted for ``{question}`` in it; otherwise it is sent as-is.
            max_tokens: Sent as ``max_tokens`` in the request body.
            prompt_format: Optional ``str.format`` template containing ``{question}``.
            timeout: Seconds passed to ``requests.post``. ``None`` means wait
                indefinitely (the previous behavior).

        Returns:
            The ``text`` field of the first element of the response's ``content`` list.

        Raises:
            Exception: ``"API call failed: ..."`` for any network, HTTP-status or
                response-shape error. The underlying error is chained as ``__cause__``.
        """
        # NOTE: without prompt_format the prompt is sent raw. The built-in
        # <step>/<answer> template in _format_prompt is only produced when
        # _format_prompt is called directly without a format.
        formatted_prompt = self._format_prompt(prompt, prompt_format) if prompt_format else prompt
        data = {
            "model": self.model,
            "messages": [{"role": "user", "content": formatted_prompt}],
            "max_tokens": max_tokens
        }
        try:
            # A finite timeout keeps a stalled connection from blocking forever.
            response = requests.post(self.base_url, headers=self.headers, json=data, timeout=timeout)
            response.raise_for_status()
            return response.json()["content"][0]["text"]
        except Exception as e:
            # Chain the cause so the original traceback is not lost.
            raise Exception(f"API call failed: {str(e)}") from e

    def _format_prompt(self, question: str, prompt_format: Optional[str] = None) -> str:
        """Build the prompt text for ``question``.

        With ``prompt_format``, returns ``prompt_format.format(question=question)``.
        Otherwise returns a built-in template asking for numbered
        ``<step number="N">`` blocks followed by an ``<answer>`` block.
        """
        if prompt_format:
            return prompt_format.format(question=question)
        # Default format if none provided
        return f"""Please answer the question using the following format, with each step clearly marked:
Question: {question}
Let's solve this step by step:
<step number="1">
[First step of reasoning]
</step>
<step number="2">
[Second step of reasoning]
</step>
<step number="3">
[Third step of reasoning]
</step>
(add more steps as needed)
<answer>
[Final answer]
</answer>
Note:
1. Each step must be wrapped in XML tags <step>
2. Each step must have a number attribute
3. The final answer must be wrapped in <answer> tags
"""
def wrap_text(text: str, config: "VisualizationConfig") -> str:
    """Wrap text to fit within box constraints of a diagram node label.

    Newlines become spaces and double quotes become single quotes, because a
    double quote would close the ``["..."]`` label the result is embedded in.
    The text is wrapped at ``config.max_chars_per_line`` columns and the lines
    are joined with ``<br>``. If wrapping yields more than ``config.max_lines``
    lines, only the first ``max_lines`` are kept. The last kept line is then
    cut so that it plus ``config.truncation_suffix`` still fits the line width.

    Args:
        text: Raw label text.
        config: Supplies ``max_chars_per_line``, ``max_lines`` and
            ``truncation_suffix``.

    Returns:
        The wrapped label, or ``""`` when ``text`` is empty or whitespace.
    """
    text = text.replace('\n', ' ').replace('"', "'")
    width = config.max_chars_per_line
    wrapped_lines = textwrap.wrap(text, width=width)
    if len(wrapped_lines) > config.max_lines:
        suffix = config.truncation_suffix
        wrapped_lines = wrapped_lines[:config.max_lines]
        # Leave room for the suffix; clamp so an over-long suffix cannot
        # produce a negative slice bound (which would keep the wrong chars).
        keep = max(width - len(suffix), 0)
        wrapped_lines[-1] = wrapped_lines[-1][:keep] + suffix
    return "<br>".join(wrapped_lines)
def parse_cot_response(response_text: str, question: str) -> CoTResponse:
    """
    Parse CoT response text to extract steps and final answer.

    Steps are ``<step number="N">...</step>`` elements. The number may be
    wrapped in double or single quotes, and whitespace is allowed around
    ``=``. The answer is the text of the first ``<answer>...</answer>``
    element. Surrounding whitespace is stripped from every extracted text.

    Args:
        response_text: The raw response from the API
        question: The original question

    Returns:
        CoTResponse object containing the question, the steps sorted by
        number, and the answer (None when no answer element is present).
        The sort is stable, so steps sharing a number keep their order of
        appearance.
    """
    # Tolerates number='1' and number = "1" as well as number="1";
    # the backreference \1 requires the closing quote to match the opening one.
    step_pattern = re.compile(
        r'<step\s+number\s*=\s*(["\'])(\d+)\1\s*>\s*(.*?)\s*</step>', re.DOTALL
    )
    steps = [
        CoTStep(number=int(match.group(2)), content=match.group(3).strip())
        for match in step_pattern.finditer(response_text)
    ]
    steps.sort(key=lambda step: step.number)

    answer_match = re.search(r'<answer>\s*(.*?)\s*</answer>', response_text, re.DOTALL)
    answer = answer_match.group(1).strip() if answer_match else None

    return CoTResponse(question=question, steps=steps, answer=answer)
def create_mermaid_diagram(cot_response: CoTResponse, config: VisualizationConfig) -> str:
    """
    Convert CoT steps to a Mermaid top-down flowchart wrapped in ``<div class="mermaid">``.

    Nodes are ``Q`` (question), ``S1``..``Sn`` (steps, numbered by their position
    in ``cot_response.steps``) and ``A`` (answer, only if present). They are
    chained Q --> S1 --> ... --> Sn --> A. Every label is passed through
    ``wrap_text`` with ``config``.

    Args:
        cot_response: CoTResponse object containing the reasoning steps
        config: VisualizationConfig for text formatting

    Returns:
        Mermaid diagram markup as a string
    """
    diagram = ['<div class="mermaid">', 'graph TD']
    steps = cot_response.steps

    # Add question node
    question_content = wrap_text(cot_response.question, config)
    diagram.append(f' Q["{question_content}"]')

    # Node IDs come from list position, not step.number: duplicate step numbers
    # in model output would otherwise collapse two steps into one node and
    # emit a self-loop edge.
    step_ids = [f'S{position}' for position in range(1, len(steps) + 1)]

    if steps:
        diagram.append(f' Q --> {step_ids[0]}')
    for index, (node_id, step) in enumerate(zip(step_ids, steps)):
        content = wrap_text(step.content, config)
        diagram.append(f' {node_id}["{content}"]')
        if index + 1 < len(step_ids):
            diagram.append(f' {node_id} --> {step_ids[index + 1]}')

    # Add final answer node, attached to the last step (or the question)
    if cot_response.answer:
        answer = wrap_text(cot_response.answer, config)
        diagram.append(f' A["{answer}"]')
        previous_id = step_ids[-1] if step_ids else 'Q'
        diagram.append(f' {previous_id} --> A')

    # Add styles for better visualization
    diagram.extend([
        ' classDef default fill:#f9f9f9,stroke:#333,stroke-width:2px;',
        ' classDef question fill:#e3f2fd,stroke:#1976d2,stroke-width:2px;',
        ' classDef answer fill:#d4edda,stroke:#28a745,stroke-width:2px;',
        ' class Q question;',
        ' class A answer;',
        ' linkStyle default stroke:#666,stroke-width:2px;'
    ])
    diagram.append('</div>')
    return '\n'.join(diagram)