from fastapi import FastAPI, WebSocket, Request
from fastapi.responses import HTMLResponse
from fastapi.templating import Jinja2Templates
from fastapi.staticfiles import StaticFiles
from huggingface_hub import InferenceClient
import os
import json
import asyncio

app = FastAPI()

# Mount static files directory
app.mount("/static", StaticFiles(directory="static"), name="static")

# Setup Jinja2 templates
templates = Jinja2Templates(directory="templates")

# Initialize the Hugging Face Inference Client
client = InferenceClient()
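# Note (assumption, not part of the original listing): InferenceClient() picks up credentials
# from the HF_TOKEN environment variable (e.g. a Space secret). A token can also be passed
# explicitly:
#   client = InferenceClient(token=os.getenv("HF_TOKEN"))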
async def generate_stream_response(prompt_template: str, **kwargs):
    """
    Generate a streaming response using the Hugging Face Inference Client.

    Args:
        prompt_template (str): Name of the environment variable holding the prompt template
        **kwargs: Dynamic arguments used to format the prompt

    Yields:
        str: Streamed content chunks
    """
    # Construct the prompt (you'll need to set up environment variables or a prompt mapping)
    prompt = os.getenv(prompt_template).format(**kwargs)

    # Prepare messages for the model
    messages = [
        {"role": "user", "content": prompt}
    ]

    try:
        # Create a stream for the chat completion
        stream = client.chat.completions.create(
            model="Qwen/Qwen2.5-Math-1.5B-Instruct",
            messages=messages,
            temperature=0.7,
            max_tokens=1024,
            top_p=0.8,
            stream=True
        )

        # Stream the generated content
        for chunk in stream:
            if chunk.choices and chunk.choices[0].delta and chunk.choices[0].delta.content:
                yield chunk.choices[0].delta.content
    except Exception as e:
        yield f"Error occurred: {str(e)}"
# Register the WebSocket route (the path below is assumed; it must match the URL the frontend opens)
@app.websocket("/ws/{endpoint}")
async def websocket_endpoint(websocket: WebSocket, endpoint: str):
    """
    WebSocket endpoint for streaming responses.

    Args:
        websocket (WebSocket): The WebSocket connection
        endpoint (str): The specific endpoint/task to process
    """
    await websocket.accept()
    try:
        # Receive the initial message with parameters
        data = await websocket.receive_json()

        # Map the endpoint to the appropriate prompt template
        endpoint_prompt_map = {
            "solve": "PROMPT_SOLVE",
            "hint": "PROMPT_HINT",
            "verify": "PROMPT_VERIFY",
            "generate": "PROMPT_GENERATE",
            "explain": "PROMPT_EXPLAIN"
        }

        # Get the appropriate prompt template
        prompt_template = endpoint_prompt_map.get(endpoint)
        if not prompt_template:
            await websocket.send_json({"error": "Invalid endpoint"})
            return

        # Stream the response
        full_response = ""
        async for chunk in generate_stream_response(prompt_template, **data):
            full_response += chunk
            await websocket.send_json({"chunk": chunk})

        # Send a final message to indicate streaming is complete
        await websocket.send_json({"complete": True, "full_response": full_response})
    except Exception as e:
        await websocket.send_json({"error": str(e)})
    finally:
        await websocket.close()
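# Example exchange over /ws/solve (the field names are assumptions; the keys in the client's
# JSON must match the placeholders used in the corresponding prompt template):
#   client -> {"problem": "2x + 3 = 11"}
#   server -> {"chunk": "Subtract 3 from both sides..."}     (one message per streamed chunk)
#   server -> {"complete": true, "full_response": "..."}     (sent once streaming finishes)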
# Existing routes remain the same as in the previous implementation
# (the root-route decorator is assumed; adjust the path if the page is served elsewhere)
@app.get("/")
async def home(request: Request):
    return HTMLResponse(open("static/index.html").read())