Spaces:
Running
Running
| import gradio as gr | |
| import time | |
| import base64 | |
| import os | |
| from io import BytesIO | |
| from PIL import Image | |
# Remove proxy settings from the environment to avoid conflicts with the
# OpenAI client, which is imported below (after this cleanup).
# Every variable whose name contains "proxy" in any letter case is removed
# (e.g. HTTP_PROXY, https_proxy, NO_PROXY). The key list is materialized
# first so os.environ is not mutated while it is being iterated.
for key in [name for name in os.environ if "proxy" in name.lower()]:
    del os.environ[key]
| # 导入 OpenAI(在清理环境变量后) | |
| from openai import OpenAI | |
# Configuration

# Base URL of the StepFun API, used both by the OpenAI SDK client and by the
# raw-HTTP fallback (which posts to f"{BASE_URL}/chat/completions").
BASE_URL = "https://api.stepfun.com/v1"

# API key taken from the environment; an empty string means "not configured".
STEP_API_KEY = os.getenv("STEP_API_KEY", "")

# Models offered in the UI dropdown; the first entry is the default selection.
MODELS = [
    "step-3",
    "step-r1-v-mini",
]
def image_to_base64(image):
    """Encode a PIL image as PNG and return the bytes as a base64 ``str``.

    Returns None when ``image`` is None or is not a ``PIL.Image.Image``.
    """
    if image is None or not isinstance(image, Image.Image):
        return None
    with BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    return base64.b64encode(png_bytes).decode()
def create_client():
    """Create an OpenAI SDK client pointed at the StepFun endpoint.

    Up to three construction strategies are tried in turn, and the first
    client that can be built is returned. Returns None when every strategy
    raises, so the caller can fall back to plain HTTP requests.

    Failures are deliberately swallowed (best effort). Only ``Exception``
    subclasses are caught, so KeyboardInterrupt/SystemExit still propagate.
    """
    # Strategy 1: pass only the required arguments explicitly.
    try:
        return OpenAI(api_key=STEP_API_KEY, base_url=BASE_URL)
    except Exception:
        pass

    # Strategy 2: configure the client through environment variables.
    # NOTE: this mutates os.environ for the whole process.
    try:
        os.environ['OPENAI_API_KEY'] = STEP_API_KEY
        os.environ['OPENAI_BASE_URL'] = BASE_URL
        return OpenAI()
    except Exception:
        pass

    # Strategy 3: inject our own httpx client instead of letting the SDK
    # build one (presumably a workaround for openai/httpx version
    # incompatibilities -- TODO confirm).
    try:
        import httpx
        http_client = httpx.Client()
    except Exception:
        return None
    try:
        return OpenAI(
            api_key=STEP_API_KEY,
            base_url=BASE_URL,
            http_client=http_client,
        )
    except Exception:
        # Release the connection pool of a client that will never be used.
        http_client.close()
        return None
_REASONING_OPEN = "<reasoning>"
_REASONING_CLOSE = "</reasoning>"


def _split_reasoning(text):
    """Split model output text into a ``(reasoning, answer)`` pair.

    Text inside ``<reasoning>...</reasoning>`` sections is concatenated into
    ``reasoning``, and everything outside them forms ``answer``. If an opening
    tag has no closing tag yet (a stream still in progress), the remaining
    text goes to ``reasoning``. The function is applied to the whole
    accumulated text rather than to single stream chunks, so tags split
    across chunks and several tags in one chunk are handled correctly.
    """
    reasoning_parts = []
    answer_parts = []
    rest = text
    while True:
        start = rest.find(_REASONING_OPEN)
        if start == -1:
            answer_parts.append(rest)
            break
        answer_parts.append(rest[:start])
        rest = rest[start + len(_REASONING_OPEN):]
        end = rest.find(_REASONING_CLOSE)
        if end == -1:
            # Unclosed section: everything left is (still) reasoning.
            reasoning_parts.append(rest)
            break
        reasoning_parts.append(rest[:end])
        rest = rest[end + len(_REASONING_CLOSE):]
    return "".join(reasoning_parts), "".join(answer_parts)


def _build_messages(image, prompt):
    """Build the chat ``messages`` list for ``prompt`` plus an optional image.

    Returns ``(messages, error)``. Exactly one of the two is None, and
    ``error`` is a user-facing message when image encoding fails.
    """
    if image is None:
        # Text-only request.
        content = prompt
    else:
        try:
            base64_image = image_to_base64(image)
        except Exception as e:
            return None, f"❌ 图片处理错误: {str(e)}"
        if base64_image is None:
            return None, "❌ 图片处理失败"
        content = [
            {
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/png;base64,{base64_image}",
                    "detail": "high"
                }
            },
            {
                "type": "text",
                "text": prompt
            }
        ]
    return [{"role": "user", "content": content}], None


def _request_without_sdk(messages, model, temperature, max_tokens):
    """Fallback used when no SDK client could be created.

    POSTs a non-streaming request to ``{BASE_URL}/chat/completions`` with
    ``requests`` and returns a ``(reasoning, answer)`` pair for the UI.
    Errors are returned as ``("", "❌ ...")`` instead of being raised.
    """
    try:
        import requests
        headers = {
            "Authorization": f"Bearer {STEP_API_KEY}",
            "Content-Type": "application/json"
        }
        data = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "stream": False
        }
        response = requests.post(
            f"{BASE_URL}/chat/completions",
            headers=headers,
            json=data,
            timeout=60
        )
        if response.status_code != 200:
            return "", f"❌ API 请求失败: {response.status_code}"
        result = response.json()
        choices = result.get("choices")
        message = choices[0].get("message") if choices else None
        if not message:
            return "", "❌ API 返回格式错误"
        # content may be null in some responses; treat that as empty text.
        return _split_reasoning(message.get("content") or "")
    except Exception as e:
        return "", f"❌ 请求失败: {str(e)}"


def call_step_api(image, prompt, model, temperature=0.7, max_tokens=2000):
    """Query the Step API with a prompt and an optional image, streaming results.

    Args:
        image: optional PIL image; None for a text-only request.
        prompt: user prompt. Empty or whitespace-only prompts are rejected.
        model: model name (the UI offers the entries of ``MODELS``).
        temperature: sampling temperature.
        max_tokens: maximum output tokens, coerced to ``int``.

    Yields:
        ``(reasoning, answer)`` string pairs. Text inside
        ``<reasoning>...</reasoning>`` goes to ``reasoning`` and everything
        else to ``answer``. Expected failures are reported as
        ``("", "❌ ...")`` rather than raised.
    """
    if not prompt or not prompt.strip():
        yield "", "❌ 请输入提示词"
        return
    if not STEP_API_KEY:
        yield "", "❌ API密钥未配置。请在 Hugging Face Space 的 Settings 中添加 STEP_API_KEY 环境变量。"
        return

    messages, error = _build_messages(image, prompt)
    if error is not None:
        yield "", error
        return

    # The UI slider may deliver a float; the API expects an integer count.
    max_tokens = int(max_tokens)

    client = create_client()
    if client is None:
        # No SDK client: fall back to a single non-streaming HTTP request.
        yield _request_without_sdk(messages, model, temperature, max_tokens)
        return

    try:
        start_time = time.time()
        stream = client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            stream=True
        )
        full_response = ""
        reasoning_content = ""
        final_answer = ""
        for chunk in stream:
            if not (chunk.choices and chunk.choices[0].delta):
                continue
            delta_text = getattr(chunk.choices[0].delta, "content", None)
            if delta_text:
                full_response += delta_text
                # Re-split the whole text so tags spanning chunks are found.
                reasoning_content, final_answer = _split_reasoning(full_response)
            # Stream the current state to the UI.
            yield reasoning_content, final_answer
        # Append the generation time.
        elapsed_time = time.time() - start_time
        final_answer += f"\n\n⏱️ 生成用时: {elapsed_time:.2f}秒"
        yield reasoning_content, final_answer
    except Exception as e:
        error_msg = str(e)
        if "api_key" in error_msg.lower():
            yield "", "❌ API密钥错误:请检查密钥是否有效"
        elif "network" in error_msg.lower() or "connection" in error_msg.lower():
            yield "", "❌ 网络连接错误:请检查网络连接"
        else:
            yield "", f"❌ API调用错误: {error_msg[:200]}"
# Create the Gradio interface. Inside each layout context (Row/Column/
# Accordion), components appear in the order they are created, so the
# statement order below defines the page layout.
with gr.Blocks(title="Step-3", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🤖 Step-3
    """)
    with gr.Row():
        # Left column: inputs, advanced settings, and action buttons.
        with gr.Column(scale=1):
            # Input area
            image_input = gr.Image(
                label="上传图片(可选)",
                type="pil",
                height=300
            )
            prompt_input = gr.Textbox(
                label="提示词",
                placeholder="输入你的问题或描述...",
                lines=3,
                value=""
            )
            # Collapsed-by-default settings passed through to call_step_api.
            with gr.Accordion("高级设置", open=False):
                model_select = gr.Dropdown(
                    choices=MODELS,
                    value=MODELS[0],
                    label="选择模型"
                )
                temperature_slider = gr.Slider(
                    minimum=0,
                    maximum=1,
                    value=0.7,
                    step=0.1,
                    label="Temperature"
                )
                max_tokens_slider = gr.Slider(
                    minimum=100,
                    maximum=4000,
                    value=2000,
                    step=100,
                    label="最大输出长度"
                )
            submit_btn = gr.Button("🚀 开始分析", variant="primary")
            clear_btn = gr.Button("🗑️ 清空", variant="secondary")
        # Right column: outputs.
        with gr.Column(scale=1):
            # Reasoning process (chain-of-thought) display
            with gr.Accordion("💭 推理过程 (CoT)", open=True):
                reasoning_output = gr.Textbox(
                    label="思考过程",
                    lines=10,
                    max_lines=15,
                    show_copy_button=True,
                    interactive=False
                )
            # Final answer display
            answer_output = gr.Textbox(
                label="📝 分析结果",
                lines=15,
                max_lines=25,
                show_copy_button=True,
                interactive=False
            )
    # Event handling: call_step_api is a generator, and each (reasoning, answer)
    # pair it yields is streamed into the two output text boxes.
    submit_btn.click(
        fn=call_step_api,
        inputs=[
            image_input,
            prompt_input,
            model_select,
            temperature_slider,
            max_tokens_slider
        ],
        outputs=[reasoning_output, answer_output],
        show_progress=True
    )
    # Reset the image, the prompt, and both output boxes.
    clear_btn.click(
        fn=lambda: (None, "", "", ""),
        inputs=[],
        outputs=[image_input, prompt_input, reasoning_output, answer_output]
    )
    # Footer
    gr.Markdown("""
    ---
    Powered by [Step-3](https://www.stepfun.com/)
    """)
# Launch the app when run as a script.
if __name__ == "__main__":
    demo.launch()