Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import base64 | |
| from openai import OpenAI | |
| from PIL import Image | |
| import io | |
| import os | |
| import time | |
| import traceback | |
# --- API configuration ---
BASE_URL = "https://api.stepfun.com/v1"
# Read the key from the environment ONLY. Never hard-code a real API key as
# the fallback here: this file is public, so a literal default leaks the
# credential. With no env var set, the app starts but API calls will fail auth.
STEP_API_KEY = os.environ.get("STEP_API_KEY", "")
print(f"[DEBUG] Starting app with API key: {'Set' if STEP_API_KEY else 'Not set'}")
print(f"[DEBUG] Base URL: {BASE_URL}")
def image_to_base64(image_path):
    """Convert an image file to a base64-encoded JPEG string.

    Images with an alpha channel (RGBA, LA, transparent palette) are
    composited onto a white background; any other non-RGB mode — e.g.
    palette 'P' from GIFs, grayscale 'L', CMYK — is converted to RGB so
    the JPEG encoder accepts it. (The original code handled only RGBA,
    so GIF/grayscale uploads failed the JPEG save and returned None.)

    Returns:
        The base64 string, or None if the file cannot be processed.
    """
    try:
        with Image.open(image_path) as img:
            has_alpha = img.mode in ("RGBA", "LA") or (
                img.mode == "P" and "transparency" in img.info
            )
            if has_alpha:
                # Flatten transparency onto white — JPEG has no alpha channel.
                rgba = img.convert("RGBA")
                background = Image.new("RGB", rgba.size, (255, 255, 255))
                background.paste(rgba, mask=rgba.split()[3])
                img = background
            elif img.mode != "RGB":
                # Palette / grayscale / CMYK etc. -> plain RGB.
                img = img.convert("RGB")
            buffered = io.BytesIO()
            img.save(buffered, format="JPEG", quality=95)
        return base64.b64encode(buffered.getvalue()).decode("utf-8")
    except Exception as e:
        print(f"[ERROR] Failed to convert image: {e}")
        return None
def user_submit(message, history, images):
    """Append the user's turn to the chat history and clear the inputs.

    Returns a 5-tuple: (cleared_textbox, updated_history,
    cleared_file_input, saved_message, saved_images). The last two are
    stashed in gr.State so bot_response can consume them afterwards.
    """
    if not message and not images:
        # Nothing to send: leave the inputs untouched and clear the stash.
        return message, history, images, "", None

    shown = message or ""
    if images:
        # Prefix the bubble with how many images were attached.
        count = len(images) if isinstance(images, list) else 1
        tag = f"[{count} Image{'s' if count > 1 else ''}]"
        shown = f"{tag} {shown}" if shown else tag

    return "", history + [[shown, None]], None, message, images
def bot_response(history, saved_message, saved_images, system_prompt, temperature, max_tokens, top_p):
    """Stream the assistant's reply into the chat history.

    Generator: when nothing was stashed by user_submit, yields the
    history once, unchanged; otherwise delegates the streaming updates
    to process_message.
    """
    if not (saved_message or saved_images):
        yield history
        return

    yield from process_message(
        saved_message,
        history,
        saved_images,
        system_prompt,
        temperature,
        max_tokens,
        top_p,
    )
def process_message(message, history, images, system_prompt, temperature, max_tokens, top_p):
    """Call the Step-3 chat-completions API and stream the reply into history.

    Generator: yields the updated ``history`` (list of [user, bot] pairs)
    after every streamed chunk. Supports multimodal input (text plus one
    or more image file paths). Chain-of-thought content is rendered either
    from the streaming ``delta.reasoning`` field or, as a fallback, by
    parsing inline ``<reasoning>...</reasoning>`` tags out of the content
    stream. NOTE: only the CURRENT user turn is sent to the API — previous
    chat history is not forwarded, so the model has no conversation memory.
    """
    print(f"[DEBUG] Processing message: {message[:100] if message else 'None'}")
    print(f"[DEBUG] Has images: {images is not None}")
    print(f"[DEBUG] Images type: {type(images)}")
    if images:
        print(f"[DEBUG] Images content: {images}")

    if not message and not images:
        history[-1][1] = "Please provide a message or image."
        yield history
        return

    # Make sure the history ends with a pending user turn (bot slot == None);
    # if not, append one so the streaming updates below have a slot to fill.
    if not history or history[-1][1] is not None:
        display_message = message if message else ""
        if images:
            if isinstance(images, list):
                num_images = len(images)
                image_text = f"[{num_images} Image{'s' if num_images > 1 else ''}]"
            else:
                image_text = "[1 Image]"
            display_message = f"{image_text} {display_message}" if display_message else image_text
        history.append([display_message, None])

    # Placeholder shown until the first token arrives.
    history[-1][1] = "🤔 Thinking..."
    yield history

    try:
        # Build the multimodal content list for the user message.
        content = []

        # Images first (multiple images supported).
        if images:
            # Normalize a single path to a one-element list.
            image_list = images if isinstance(images, list) else [images]
            for image_path in image_list:
                if image_path:
                    print(f"[DEBUG] Processing image: {image_path}")
                    base64_image = image_to_base64(image_path)
                    if base64_image:
                        content.append({
                            "type": "image_url",
                            "image_url": {
                                "url": f"data:image/jpeg;base64,{base64_image}",
                                "detail": "high"
                            }
                        })
                        print(f"[DEBUG] Successfully added image to content")
                    else:
                        # Unconvertible images are skipped, not fatal.
                        print(f"[ERROR] Failed to convert image: {image_path}")

        # Then the text part, if any.
        if message:
            content.append({
                "type": "text",
                "text": message
            })
            print(f"[DEBUG] Added text to content: {message[:100]}")

        if not content:
            history[-1][1] = "❌ No valid input provided."
            yield history
            return

        # Assemble the API message list.
        messages = []
        # Optional system prompt.
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        # Only the current user content — no prior history (see docstring).
        messages.append({
            "role": "user",
            "content": content
        })
        print(f"[DEBUG] Prepared {len(messages)} messages for API")
        print(f"[DEBUG] Message structure: {[{'role': m['role'], 'content_types': [c.get('type', 'text') for c in m['content']] if isinstance(m['content'], list) else 'text'} for m in messages]}")

        # Work around proxy issues: strip every proxy-related environment
        # variable so the HTTP client never tries to tunnel the request.
        import os
        import httpx
        proxy_vars = ['HTTP_PROXY', 'HTTPS_PROXY', 'http_proxy', 'https_proxy',
                      'ALL_PROXY', 'all_proxy', 'NO_PROXY', 'no_proxy']
        for var in proxy_vars:
            if var in os.environ:
                del os.environ[var]
                print(f"[DEBUG] Removed {var} from environment")

        # Create the OpenAI-compatible client.
        try:
            # Method 1: plain construction.
            client = OpenAI(
                api_key=STEP_API_KEY,
                base_url=BASE_URL
            )
            print("[DEBUG] Client created successfully (method 1)")
        except TypeError as e:
            if 'proxies' in str(e):
                # Some openai/httpx version combinations raise a TypeError
                # about 'proxies'; retry with a proxy-blind HTTP client.
                print(f"[DEBUG] Method 1 failed with proxy error, trying method 2")
                # Method 2: explicit httpx client that ignores env proxies.
                http_client = httpx.Client(trust_env=False)
                client = OpenAI(
                    api_key=STEP_API_KEY,
                    base_url=BASE_URL,
                    http_client=http_client
                )
                print("[DEBUG] Client created successfully (method 2)")
            else:
                raise e

        print(f"[DEBUG] Making API call to {BASE_URL}")
        # Streaming completion request.
        response = client.chat.completions.create(
            model="step-3",
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            stream=True
        )
        print("[DEBUG] API call successful, starting streaming")

        # Streaming state:
        #   full_response       - raw concatenation of all content deltas
        #   in_reasoning        - currently inside a <reasoning> tag section
        #   reasoning_content   - extracted chain-of-thought text
        #   final_content       - answer text after the reasoning section
        #   has_reasoning_field - True once delta.reasoning was ever seen
        full_response = ""
        in_reasoning = False
        reasoning_content = ""
        final_content = ""
        has_reasoning_field = False

        for chunk in response:
            if chunk.choices and chunk.choices[0]:
                delta = chunk.choices[0].delta

                # Prefer the dedicated delta.reasoning field (Step-3's CoT channel).
                if hasattr(delta, 'reasoning') and delta.reasoning:
                    has_reasoning_field = True
                    reasoning_content += delta.reasoning
                    print(f"[DEBUG] CoT chunk: {delta.reasoning[:50] if len(delta.reasoning) > 50 else delta.reasoning}")
                    # Live-update the chain-of-thought section.
                    if final_content:
                        display_text = f"💭 **Chain of Thought:**\n\n{reasoning_content}\n\n---\n\n📝 **Answer:**\n\n{final_content}"
                    else:
                        display_text = f"💭 **Chain of Thought:**\n\n{reasoning_content}\n\n---\n\n📝 **Answer:**\n\n*Generating...*"
                    history[-1][1] = display_text
                    yield history

                # Handle the regular content field.
                delta_content = delta.content if hasattr(delta, 'content') else None
                if delta_content:
                    if has_reasoning_field:
                        # CoT already arrived via delta.reasoning, so content
                        # carries only the final answer.
                        final_content += delta_content
                        full_response += delta_content
                    else:
                        # Fallback: parse inline <reasoning> tags out of the stream.
                        full_response += delta_content
                        # Opening tag seen -> start collecting reasoning text.
                        if '<reasoning>' in full_response and not in_reasoning:
                            in_reasoning = True
                            parts = full_response.split('<reasoning>')
                            if len(parts) > 1:
                                reasoning_content = parts[1]
                        if in_reasoning and '</reasoning>' in full_response:
                            # Closing tag seen -> split into CoT and answer.
                            in_reasoning = False
                            parts = full_response.split('</reasoning>')
                            if len(parts) > 1:
                                reasoning_content = parts[0].split('<reasoning>')[-1]
                                final_content = parts[1]
                        elif in_reasoning:
                            reasoning_content = full_response.split('<reasoning>')[-1]
                        elif '</reasoning>' in full_response:
                            parts = full_response.split('</reasoning>')
                            if len(parts) > 1:
                                final_content = parts[1]
                        else:
                            # No reasoning tags at all: everything is the answer.
                            if '<reasoning>' not in full_response:
                                final_content = full_response

                    # Render the current state of CoT + answer.
                    if reasoning_content and final_content:
                        display_text = f"💭 **Chain of Thought:**\n\n{reasoning_content.strip()}\n\n---\n\n📝 **Answer:**\n\n{final_content.strip()}"
                    elif reasoning_content:
                        display_text = f"💭 **Chain of Thought:**\n\n{reasoning_content.strip()}\n\n---\n\n📝 **Answer:**\n\n*Generating...*"
                    else:
                        display_text = full_response
                    history[-1][1] = display_text
                    yield history

        # Final formatting once the stream ends.
        if reasoning_content or final_content:
            final_display = f"💭 **Chain of Thought:**\n\n{reasoning_content.strip()}\n\n---\n\n📝 **Answer:**\n\n{final_content.strip()}"
            history[-1][1] = final_display
        else:
            history[-1][1] = full_response
        print(f"[DEBUG] Streaming completed. Response length: {len(full_response)}")
        yield history

    except Exception as e:
        # Surface any failure (network, auth, parsing) directly in the chat.
        error_msg = f"❌ Error: {str(e)}"
        print(f"[ERROR] {error_msg}")
        traceback.print_exc()
        history[-1][1] = f"❌ Error: {str(e)}"
        yield history
# Custom CSS for the Gradio UI: forces the gr.File upload widget down to the
# same 52px height as the message textbox so the input row lines up as one bar.
# (The rules are aggressive — ID selectors, !important, and a child wildcard —
# because Gradio's own styles otherwise expand the File component.)
css = """
/* 强制设置File组件容器高度 */
.compact-file, .compact-file > * {
height: 52px !important;
max-height: 52px !important;
min-height: 52px !important;
}
/* 使用ID选择器确保优先级 */
#image-upload {
height: 52px !important;
max-height: 52px !important;
min-height: 52px !important;
}
#image-upload > div,
#image-upload .wrap,
#image-upload .block,
#image-upload .container {
height: 52px !important;
max-height: 52px !important;
min-height: 52px !important;
padding: 0 !important;
margin: 0 !important;
}
/* 文件上传按钮样式 */
#image-upload button,
.compact-file button {
height: 50px !important;
max-height: 50px !important;
min-height: 50px !important;
font-size: 13px !important;
padding: 0 12px !important;
margin: 1px !important;
}
/* 文件预览区域 */
#image-upload .file-preview,
.compact-file .file-preview {
height: 50px !important;
max-height: 50px !important;
overflow-y: auto !important;
font-size: 12px !important;
padding: 4px !important;
}
/* 隐藏标签 */
#image-upload label,
.compact-file label {
display: none !important;
}
/* 确保input元素也是正确高度 */
#image-upload input[type="file"],
.compact-file input[type="file"] {
height: 50px !important;
max-height: 50px !important;
}
/* 文本框参考高度 */
#message-textbox textarea {
min-height: 52px !important;
max-height: 52px !important;
}
/* 使用通配符确保所有子元素 */
#image-upload * {
max-height: 52px !important;
}
"""
# Build the Gradio interface: chat column (chatbot + message box + image
# upload + action buttons) on the left, settings accordion on the right.
with gr.Blocks(title="Step-3", theme=gr.themes.Soft(), css=css) as demo:
    gr.Markdown("""
    # <img src="https://huggingface.co/stepfun-ai/step3/resolve/main/figures/stepfun-logo.png" alt="StepFun Logo" style="height: 30px; vertical-align: middle; margin-right: 8px;"> Step-3
    Welcome to Step-3, an advanced multimodal AI assistant by <a href="https://stepfun.com/" target="_blank" style="color: #0969da;">StepFun</a>.
    """)

    # State holders: user_submit stashes the message/images here so
    # bot_response can read them after the inputs have been cleared.
    saved_msg = gr.State("")
    saved_imgs = gr.State([])

    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(
                height=600,
                show_label=False,
                elem_id="chatbot",
                bubble_full_width=False,
                avatar_images=None,
                render_markdown=True
            )
            # Input row: message box, image upload, send button.
            with gr.Row():
                with gr.Column(scale=8):
                    msg = gr.Textbox(
                        label="Message",
                        placeholder="Type your message here...",
                        lines=2,
                        max_lines=10,
                        show_label=False,
                        elem_id="message-textbox"
                    )
                with gr.Column(scale=2):
                    image_input = gr.File(
                        label="Upload Images",
                        file_count="multiple",
                        file_types=[".png", ".jpg", ".jpeg", ".gif", ".webp"],
                        interactive=True,
                        show_label=False,
                        elem_classes="compact-file",  # sized by the custom CSS above
                        elem_id="image-upload"
                    )
                with gr.Column(scale=1, min_width=100):
                    submit_btn = gr.Button("Send", variant="primary")
            # Bottom action buttons.
            with gr.Row():
                clear_btn = gr.Button("🗑️ Clear", scale=1)
                undo_btn = gr.Button("↩️ Undo", scale=1)
                retry_btn = gr.Button("🔄 Retry", scale=1)
        with gr.Column(scale=1):
            # Settings panel (collapsed by default).
            with gr.Accordion("⚙️ Settings", open=False):
                system_prompt = gr.Textbox(
                    label="System Prompt",
                    placeholder="Set a system prompt (optional)",
                    lines=3
                )
                temperature_slider = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=0.7,
                    step=0.1,
                    label="Temperature"
                )
                max_tokens_slider = gr.Slider(
                    minimum=100,
                    maximum=8000,
                    value=2000,
                    step=100,
                    label="Max Tokens"
                )
                top_p_slider = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=0.95,
                    step=0.05,
                    label="Top P"
                )

    # Event wiring: every submission runs user_submit first (synchronous,
    # queue=False) to echo the turn and stash inputs, then bot_response
    # streams the reply into the chatbot.
    submit_event = msg.submit(
        user_submit,
        [msg, chatbot, image_input],
        [msg, chatbot, image_input, saved_msg, saved_imgs],
        queue=False
    ).then(
        bot_response,
        [chatbot, saved_msg, saved_imgs, system_prompt, temperature_slider, max_tokens_slider, top_p_slider],
        chatbot
    )
    submit_btn.click(
        user_submit,
        [msg, chatbot, image_input],
        [msg, chatbot, image_input, saved_msg, saved_imgs],
        queue=False
    ).then(
        bot_response,
        [chatbot, saved_msg, saved_imgs, system_prompt, temperature_slider, max_tokens_slider, top_p_slider],
        chatbot
    )
    # Clear wipes the whole chat display.
    clear_btn.click(lambda: None, None, chatbot, queue=False)
    # Undo drops the last turn.
    undo_btn.click(
        lambda h: h[:-1] if h else h,
        chatbot,
        chatbot,
        queue=False
    )
    # Retry drops the last completed turn, then regenerates from the stash.
    # NOTE(review): this retries the last *stashed* message/images, which may
    # differ from the removed turn if the user typed since — confirm intent.
    retry_btn.click(
        lambda h: h[:-1] if h and h[-1][1] is not None else h,
        chatbot,
        chatbot,
        queue=False
    ).then(
        bot_response,
        [chatbot, saved_msg, saved_imgs, system_prompt, temperature_slider, max_tokens_slider, top_p_slider],
        chatbot
    )
# Application entry point: queue limits concurrent pending requests, then
# the server binds all interfaces on port 7860 (the Hugging Face Spaces
# convention).
if __name__ == "__main__":
    print(f"[DEBUG] Starting app with API key: {'Set' if STEP_API_KEY else 'Not set'}")
    print(f"[DEBUG] Base URL: {BASE_URL}")
    demo.queue(max_size=10)  # cap the request queue at 10 waiting jobs
    demo.launch(
        server_name="0.0.0.0",  # listen on all interfaces
        server_port=7860,
        share=False,
        debug=False
    )