Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException, Request | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse | |
| from pydantic import BaseModel | |
| from NLP_model import chatbot | |
| import uvicorn | |
| import asyncio | |
| import time | |
| import logging | |
| from contextlib import asynccontextmanager | |
| import os | |
| # Configure logging | |
| logging.basicConfig( | |
| level=logging.INFO, | |
| format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', | |
| ) | |
| logger = logging.getLogger(__name__) | |
| # Chuẩn bị RAG model tại lúc khởi động | |
| async def lifespan(app: FastAPI): | |
| # Khởi tạo retriever sẵn khi server bắt đầu | |
| logger.info("Initializing RAG model retriever...") | |
| # Sử dụng asyncio.to_thread để không block event loop | |
| await asyncio.to_thread(chatbot.get_chain) | |
| logger.info("RAG model retriever initialized successfully") | |
| yield | |
| # Dọn dẹp khi shutdown | |
| logger.info("Shutting down RAG model...") | |
| app = FastAPI( | |
| title="Solana SuperTeam RAG API", | |
| description="API cho mô hình RAG của Solana SuperTeam", | |
| version="1.0.0", | |
| lifespan=lifespan | |
| ) | |
| # Add CORS middleware | |
| app.add_middleware( | |
| CORSMiddleware, | |
| allow_origins=["*"], | |
| allow_credentials=True, | |
| allow_methods=["*"], | |
| allow_headers=["*"], | |
| ) | |
| # Request counter để theo dõi số lượng request đang xử lý | |
| active_requests = 0 | |
| max_concurrent_requests = 5 # Giới hạn số request xử lý đồng thời | |
| request_lock = asyncio.Lock() | |
| class ChatRequest(BaseModel): | |
| query: str | |
| user_id: str = "default_user" | |
| class ChatResponse(BaseModel): | |
| response: str | |
| processing_time: float = None | |
| async def add_process_time_header(request: Request, call_next): | |
| """Middleware để đo thời gian xử lý và kiểm soát số lượng request""" | |
| global active_requests | |
| # Kiểm tra và tăng số request đang xử lý | |
| async with request_lock: | |
| # Nếu đã đạt giới hạn, từ chối request mới | |
| if active_requests >= max_concurrent_requests and request.url.path == "/chat": | |
| return JSONResponse( | |
| status_code=429, | |
| content={"detail": "Too many requests. Please try again later."} | |
| ) | |
| active_requests += 1 | |
| try: | |
| start_time = time.time() | |
| response = await call_next(request) | |
| process_time = time.time() - start_time | |
| # Thêm thời gian xử lý vào header | |
| response.headers["X-Process-Time"] = str(process_time) | |
| logger.info(f"Request processed in {process_time:.2f} seconds: {request.url.path}") | |
| return response | |
| finally: | |
| # Giảm counter khi xử lý xong | |
| async with request_lock: | |
| active_requests -= 1 | |
| async def chat_endpoint(request: ChatRequest): | |
| """ | |
| Xử lý yêu cầu chat từ người dùng | |
| """ | |
| start_time = time.time() | |
| try: | |
| # Gọi hàm chat với thông tin được cung cấp | |
| response = await asyncio.to_thread(chatbot.chat, request.query, int(request.user_id)) | |
| process_time = time.time() - start_time | |
| return ChatResponse( | |
| response=response, | |
| processing_time=process_time | |
| ) | |
| except Exception as e: | |
| logger.error(f"Error processing chat request: {e}") | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def health_check(): | |
| """ | |
| Kiểm tra trạng thái của API | |
| """ | |
| # Kiểm tra xem retriever đã được khởi tạo chưa | |
| retriever = chatbot.get_chain() | |
| if retriever: | |
| status = "healthy" | |
| else: | |
| status = "degraded" | |
| return { | |
| "status": status, | |
| "active_requests": active_requests, | |
| "cache_size": len(chatbot.response_cache) | |
| } | |
| async def clear_user_memory(user_id: str): | |
| """ | |
| Xóa lịch sử trò chuyện của một người dùng | |
| """ | |
| try: | |
| result = await asyncio.to_thread(chatbot.clear_memory, user_id) | |
| return {"status": "success", "message": result} | |
| except Exception as e: | |
| logger.error(f"Error clearing memory for user {user_id}: {e}") | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| if __name__ == "__main__": | |
| import nest_asyncio | |
| nest_asyncio.apply() | |
| uvicorn.run(app, host="0.0.0.0", port=7860) |