Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import requests | |
| import os | |
| import json | |
| import traceback | |
| import sys | |
| import re | |
# Tracing switch: debug_print only writes to stderr while this is True.
ENABLE_TRACING = False

# Endpoint and credentials are taken from the process environment.
API_KEY = os.environ.get("RUNPOD_API_KEY")
API_BASE_URL = os.environ.get("RUNPOD_API_URL")
API_URL = f"{API_BASE_URL}/chat/completions"

# HTTP headers attached to the completion request.
headers = dict()
headers["Authorization"] = f"Bearer {API_KEY}"
headers["Content-Type"] = "application/json"
| import re | |
def style_xml_content(text):
    """Return ``text`` as HTML with reasoning tags rendered as styled sections.

    The input is treated as plain text: ``&``, ``<`` and ``>`` are replaced by
    HTML entities so that arbitrary markup in the model output is displayed
    literally instead of being interpreted.  Only ``<thinking>`` and
    ``<reflection>`` are recognised; each closed pair becomes an open
    ``<details>`` element whose summary and footer show the tag name as
    visible text, with the body bold-italic (thinking) or bold-underlined
    (reflection).

    Args:
        text: Raw (possibly partial) model output.

    Returns:
        The escaped and styled string.
    """
    # Opening/closing emphasis markup applied to the body of each known tag.
    emphasis = {
        'thinking': ('<i><b>', '</b></i>'),
        'reflection': ('<u><b>', '</b></u>'),
    }

    def replace_content(match):
        tag = match.group(1)
        content = match.group(2)
        opening, closing = emphasis[tag]
        # The tag name is shown as escaped text so the reader sees
        # "<thinking>" / "</thinking>" rather than an unknown HTML element.
        return (
            f'<details open><summary>&lt;{tag}&gt;</summary>'
            f'{opening}{content}{closing}<br>&lt;/{tag}&gt;</details>'
        )

    # Remove blacklisted text while it is still in raw form; once the angle
    # brackets are escaped the literal token can no longer be matched.
    text = text.replace("<|im_start|>", "")

    # First, escape all markup-significant characters ('&' must go first so
    # the entities produced for '<' and '>' are not escaped a second time).
    text = text.replace('&', '&amp;').replace('<', '&lt;').replace('>', '&gt;')

    # Then, unescape the specific tags we want to process.
    for tag in emphasis:
        text = text.replace(f'&lt;{tag}&gt;', f'<{tag}>')
        text = text.replace(f'&lt;/{tag}&gt;', f'</{tag}>')

    # Apply styling to content inside closed tag pairs; DOTALL lets the body
    # span several lines.
    return re.sub(
        r'<(thinking|reflection)>(.*?)</\1>',
        replace_content,
        text,
        flags=re.DOTALL,
    )
# Fixed system prompt, sent as the first message of every conversation.
# The wording is kept exactly as-is because it is part of the model input.
SYSTEM_PROMPT = (
    "You an advanced artificial intelligence system, capable of <thinking> "
    "and then creating a length <reflection>, where you ask if you were wrong? "
    "And then you correct yourself. "
    "Always use <reflection></reflection> unless it is a trivial or wikipedia question. "
    "Finally you output a brief and small to the point <output>."
)
def debug_print(*args, **kwargs):
    """Write a diagnostic message to stderr when tracing is enabled.

    Accepts the same positional and keyword arguments as ``print``; the
    output stream is always ``sys.stderr``.  Does nothing while
    ``ENABLE_TRACING`` is falsy.
    """
    if not ENABLE_TRACING:
        return
    print(*args, file=sys.stderr, **kwargs)
def parse_sse(data):
    """Decode one line of a server-sent-event stream into its JSON payload.

    Args:
        data: A single raw line from the stream, as UTF-8 ``bytes`` (a ``str``
            is also accepted).  May be empty.

    Returns:
        The parsed JSON value of a ``data:`` line, or ``None`` when the line
        is empty, is not a ``data:`` field, is the ``[DONE]`` terminator, or
        does not contain valid JSON.
    """
    if not data:
        return None

    line = data.decode('utf-8') if isinstance(data, bytes) else data
    line = line.strip()
    debug_print(f"Raw SSE data: {line}")

    # The space after "data:" is optional in the event-stream format, so
    # match on the field name alone and strip whatever whitespace follows.
    prefix = 'data:'
    if not line.startswith(prefix):
        return None
    payload = line[len(prefix):].strip()

    # End-of-stream sentinel: there is nothing to parse.
    if payload == '[DONE]':
        return None

    try:
        return json.loads(payload)
    except json.JSONDecodeError:
        debug_print(f"Failed to parse SSE data: {payload}")
        return None
def stream_response(message, history, max_tokens, temperature, top_p):
    """Stream a styled chat completion for ``message``.

    Builds a chat ``messages`` list from ``SYSTEM_PROMPT``, the previous
    (user, assistant) pairs in ``history`` and the new ``message``, posts it
    to ``API_URL`` with streaming enabled, and yields the styled accumulated
    text each time a new content delta arrives.  Streaming stops early once
    the accumulated text ends with ``</output>``.

    Args:
        message: The new user message.
        history: Iterable of ``(user_text, assistant_text)`` pairs.
        max_tokens: Value sent as ``max_tokens``.
        temperature: Value sent as ``temperature``.
        top_p: Value sent as ``top_p``.

    Yields:
        The output of ``style_xml_content`` for everything received so far.
        On failure a single ``"Error: ..."`` (request problems, including
        timeouts and HTTP error statuses) or ``"Unexpected error: ..."``
        string is yielded instead of raising.
    """
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    for human, assistant in history:
        messages.append({"role": "user", "content": human})
        messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": message})

    data = {
        "model": "forcemultiplier/fmx-reflective-2b",
        "messages": messages,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "system": "You are a world-class AI system, capable of complex reasoning and reflection. Reason through the query inside <thinking> tags, and then provide your final response inside <output> tags. If you detect that you made a mistake in your reasoning at any point, correct yourself inside <reflection> tags.",
        "top_p": top_p,
        "stream": True,
        # Stop sequences
        "stop": [
            "<|start_header_id|>",
            "<|end_header_id|>",
            "<|eot_id|>",
        ],
    }

    debug_print(f"Sending request to API: {API_URL}")
    debug_print(f"Request data: {json.dumps(data, indent=2)}")

    try:
        # (connect, read) timeout so a stalled endpoint cannot hang the UI
        # forever; the context manager releases the streamed connection even
        # when we break out early or the consumer abandons the generator.
        with requests.post(
            API_URL,
            headers=headers,
            json=data,
            stream=True,
            timeout=(10, 300),
        ) as response:
            debug_print(f"Response status code: {response.status_code}")
            debug_print(f"Response headers: {response.headers}")
            response.raise_for_status()

            accumulated_content = ""
            for line in response.iter_lines():
                if not line:
                    continue
                debug_print(f"Received line: {line}")
                parsed = parse_sse(line)
                if not parsed:
                    continue
                debug_print(f"Parsed SSE data: {parsed}")
                if 'choices' not in parsed or len(parsed['choices']) == 0:
                    continue
                # A chunk without a delta (or with a null one) carries no
                # text; skip it rather than aborting the whole stream.
                delta = parsed['choices'][0].get('delta') or {}
                content = delta.get('content', '')
                if not content:
                    continue
                accumulated_content += content
                yield style_xml_content(accumulated_content)
                # Check if we've reached the stop sequence
                if accumulated_content.endswith("</output>"):
                    break
    except requests.exceptions.RequestException as e:
        debug_print(f"Request exception: {str(e)}")
        debug_print(f"Request exception traceback: {traceback.format_exc()}")
        yield f"Error: {str(e)}"
    except Exception as e:
        # Last-resort boundary: surface the problem in the chat instead of
        # letting the generator raise.
        debug_print(f"Unexpected error: {str(e)}")
        debug_print(f"Error traceback: {traceback.format_exc()}")
        yield f"Unexpected error: {str(e)}"
# Extra chat controls; their order mirrors the max_tokens, temperature and
# top_p parameters of stream_response.
_max_tokens_slider = gr.Slider(
    label="Max tokens", minimum=1, maximum=2048, step=1, value=512
)
_temperature_slider = gr.Slider(
    label="Temperature", minimum=0.1, maximum=2.0, step=0.1, value=0.4
)
_top_p_slider = gr.Slider(
    label="Top-p (nucleus sampling)", minimum=0.1, maximum=1.0, step=0.05, value=0.83
)

# Chat UI that uses stream_response as its handler.
demo = gr.ChatInterface(
    fn=stream_response,
    additional_inputs=[_max_tokens_slider, _temperature_slider, _top_p_slider],
)
if __name__ == "__main__":
    # Emit the start-up configuration (only visible when tracing is on),
    # then start the Gradio app.
    startup_messages = (
        f"Starting application with API URL: {API_URL}",
        f"Using system prompt: {SYSTEM_PROMPT}",
        f"Tracing enabled: {ENABLE_TRACING}",
    )
    for startup_message in startup_messages:
        debug_print(startup_message)
    demo.launch()