Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, Request, Response | |
| from fastapi.responses import JSONResponse, StreamingResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import uvicorn | |
| import json | |
| from typegpt_api import generate, model_mapping, simplified_models | |
| from api_info import developer_info, model_providers | |
app = FastAPI()

# Wide-open CORS: every origin, method, and header is accepted.
# NOTE(review): allow_credentials=True combined with allow_origins=["*"] is
# usually a misconfiguration for credentialed requests — confirm intent.
cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **cors_settings)
async def health_check():
    """Liveness probe: always report the service as up."""
    status_payload = {"status": "OK"}
    return status_payload
async def get_models():
    """List every model from every provider in an OpenAI-style `list` payload.

    Each entry carries the model id, its provider name, and the provider's
    description. Any failure is reported as a JSON error with HTTP 500.
    """
    try:
        catalogue = [
            {
                "id": model_id,
                "object": "model",
                "provider": provider_name,
                "description": provider_info["description"],
            }
            for provider_name, provider_info in model_providers.items()
            for model_id in provider_info["models"]
        ]
        return JSONResponse(content={"object": "list", "data": catalogue})
    except Exception as exc:
        return JSONResponse(content={"error": str(exc)}, status_code=500)
async def chat_completions(request: Request):
    """OpenAI-compatible chat-completions endpoint.

    Parses the JSON request body, validates the required `model` and
    `messages` fields, and dispatches to `generate` either as a server-sent
    event stream (when `stream` is true) or as a single JSON response.

    NOTE(review): no route decorator is visible in this file for this handler
    (nor for the other endpoints) — confirm routes are registered elsewhere.

    Returns:
        400 for invalid JSON or missing required fields,
        500 with the exception text if `generate` fails,
        otherwise the completion payload (JSON or SSE stream).
    """
    try:
        body = await request.json()
    except Exception:
        return JSONResponse(content={"error": "Invalid JSON payload"}, status_code=400)

    model = body.get("model")
    messages = body.get("messages")
    if not model:
        return JSONResponse(content={"error": "The 'model' parameter is required."}, status_code=400)
    if not messages:
        return JSONResponse(content={"error": "The 'messages' parameter is required."}, status_code=400)

    stream = body.get("stream", False)

    # Single source of truth for the generation parameters — the original
    # code duplicated this entire keyword list in the streaming and
    # non-streaming branches.
    gen_kwargs = {
        "model": model,
        "messages": messages,
        "temperature": body.get("temperature", 0.7),
        "top_p": body.get("top_p", 1.0),
        "n": body.get("n", 1),
        "stop": body.get("stop"),
        "max_tokens": body.get("max_tokens"),
        "presence_penalty": body.get("presence_penalty", 0.0),
        "frequency_penalty": body.get("frequency_penalty", 0.0),
        "logit_bias": body.get("logit_bias"),
        "user": body.get("user"),
        "timeout": 30,  # fixed server-side timeout, not client-controlled
    }

    try:
        if stream:
            async def generate_stream():
                # NOTE(review): `generate` looks like a synchronous iterator;
                # iterating it here blocks the event loop per chunk — confirm.
                for chunk in generate(stream=True, **gen_kwargs):
                    yield f"data: {json.dumps(chunk)}\n\n"
                yield "data: [DONE]\n\n"

            return StreamingResponse(
                generate_stream(),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    "Transfer-Encoding": "chunked",
                },
            )
        response = generate(stream=False, **gen_kwargs)
        return JSONResponse(content=response)
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)
async def get_developer_info():
    """Expose the static developer metadata as a JSON response."""
    payload = developer_info
    return JSONResponse(content=payload)
# Script entry point: serve the app on all interfaces, port 8000.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)