Spaces:
Sleeping
Sleeping
import json

from flask import Flask, request, Response, jsonify, stream_with_context
from flask_cors import CORS

from typegpt_api import generate, model_mapping, simplified_models
from api_info import developer_info, model_providers

# NOTE(review): no @app.route decorators are visible in this file; they were
# presumably lost in extraction or the routes are registered elsewhere — confirm.
app = Flask(__name__)

# Allow cross-origin requests from any origin on every route.
# Fix: flask-cors resource options are `supports_credentials` (not
# `allow_credentials`) and `allow_headers` (not `headers`); the original
# keys were unrecognized and silently ignored.
CORS(app, resources={
    r"/*": {
        "origins": "*",
        "supports_credentials": True,
        "methods": ["*"],
        "allow_headers": ["*"],
    }
})
def health_check():
    """Liveness probe: report that the service is up and responding."""
    status_payload = {"status": "OK"}
    return jsonify(status_payload)
def get_models():
    """List all models from every configured provider in OpenAI list format.

    Returns a JSON object {"object": "list", "data": [...]} where each entry
    carries the model id, its provider, and the provider's description.
    On any failure, returns {"error": ...} with HTTP 500.
    """
    try:
        entries = [
            {
                "id": model,
                "object": "model",
                "provider": provider,
                "description": info["description"],
            }
            for provider, info in model_providers.items()
            for model in info["models"]
        ]
        return jsonify({"object": "list", "data": entries})
    except Exception as e:
        return jsonify({"error": str(e)}), 500
def chat_completions():
    """Handle an OpenAI-compatible chat completion request.

    Parses the JSON body, validates the required `model` and `messages`
    fields, and delegates to `generate`. When `stream` is truthy, provider
    chunks are relayed as Server-Sent Events terminated by the OpenAI
    "data: [DONE]" sentinel; otherwise the full completion is returned as
    a single JSON response.

    Returns 400 for a missing/invalid JSON body or missing required fields,
    and 500 (with the error message) when generation fails up front.
    """
    # Fix: get_json() returns None when the request has no JSON body (e.g.
    # wrong Content-Type), and the original try/except only guarded the parse
    # call — body.get(...) then raised AttributeError and surfaced as a 500.
    # silent=True folds both "no body" and "malformed JSON" into None.
    body = request.get_json(silent=True)
    if body is None:
        return jsonify({"error": "Invalid JSON payload"}), 400

    model = body.get("model")
    messages = body.get("messages")
    if not model:
        return jsonify({"error": "The 'model' parameter is required."}), 400
    if not messages:
        return jsonify({"error": "The 'messages' parameter is required."}), 400

    # All generation arguments except `stream` are identical for both paths,
    # so build them once instead of duplicating the call site.
    gen_kwargs = dict(
        model=model,
        messages=messages,
        temperature=body.get("temperature", 0.7),
        top_p=body.get("top_p", 1.0),
        n=body.get("n", 1),
        stop=body.get("stop"),
        max_tokens=body.get("max_tokens"),
        presence_penalty=body.get("presence_penalty", 0.0),
        frequency_penalty=body.get("frequency_penalty", 0.0),
        logit_bias=body.get("logit_bias"),
        user=body.get("user"),
        timeout=30,  # fixed upstream timeout in seconds
    )

    try:
        if body.get("stream", False):
            def generate_stream():
                # Relay each provider chunk as one SSE frame, then emit the
                # end-of-stream sentinel. Note: errors raised after streaming
                # starts cannot change the HTTP status anymore.
                for chunk in generate(stream=True, **gen_kwargs):
                    yield f"data: {json.dumps(chunk)}\n\n"
                yield "data: [DONE]\n\n"

            return Response(
                stream_with_context(generate_stream()),
                mimetype="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    "Transfer-Encoding": "chunked",
                },
            )

        response = generate(stream=False, **gen_kwargs)
        return jsonify(response)
    except Exception as e:
        # Surface provider/setup failures as a 500 with the message, as before.
        return jsonify({"error": str(e)}), 500
def get_developer_info():
    """Expose the static developer/contact metadata as a JSON response."""
    info = developer_info
    return jsonify(info)
if __name__ == "__main__":
    # Development entry point: Flask's built-in server on all interfaces,
    # port 8000. Use a production WSGI server (gunicorn/uwsgi) for deployment.
    app.run(host="0.0.0.0", port=8000)