Spaces:
Sleeping
Sleeping
| from fastapi import Request, status | |
| from fastapi.responses import JSONResponse | |
| from starlette.middleware.base import BaseHTTPMiddleware | |
| import logging | |
| import time | |
| import traceback | |
| import uuid | |
| from .utils import get_local_time | |
| # Module-level logger used by the middleware classes in this file; it is named after this module via __name__. | |
| logger = logging.getLogger(__name__) | |
class RequestLoggingMiddleware(BaseHTTPMiddleware):
    """Log each request and its response, tagging both with a request ID.

    A fresh UUID4 is generated per request, stored on
    ``request.state.request_id`` and echoed back in the ``X-Request-ID``
    response header. The time spent in the downstream call is logged and
    returned in ``X-Process-Time`` (fractional seconds). Any exception raised
    by ``call_next`` is logged with its traceback and converted into a JSON
    500 response.
    """

    async def dispatch(self, request: Request, call_next):
        """Handle one request, logging it and timing the downstream call.

        Args:
            request: The incoming request.
            call_next: Awaitable callable that forwards the request further
                down the stack and returns the response.

        Returns:
            The downstream response with ``X-Request-ID`` and
            ``X-Process-Time`` headers added, or a JSON 500 response carrying
            the request ID if the downstream call raised.
        """
        request_id = str(uuid.uuid4())
        # Stored on request.state so code running later in the same request
        # can read the ID back.
        request.state.request_id = request_id

        # request.client may be falsy; fall back to a placeholder host.
        client_host = request.client.host if request.client else "unknown"
        logger.info(
            "Request [%s]: %s %s from %s",
            request_id,
            request.method,
            request.url.path,
            client_host,
        )

        # perf_counter() is monotonic, so the measured duration cannot be
        # distorted by wall-clock adjustments the way time.time() can.
        start_time = time.perf_counter()
        try:
            response = await call_next(request)
        except Exception as e:
            process_time = time.perf_counter() - start_time
            # logger.exception emits the message and the active traceback as
            # a single ERROR record, keeping them together in the log.
            logger.exception(
                "Error [%s] after %.4fs: %s", request_id, process_time, e
            )
            return JSONResponse(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                content={
                    "detail": "Internal server error",
                    "request_id": request_id,
                    "timestamp": get_local_time(),
                },
                # Same correlation header the success path sets.
                headers={"X-Request-ID": request_id},
            )

        process_time = time.perf_counter() - start_time
        logger.info(
            "Response [%s]: %s processed in %.4fs",
            request_id,
            response.status_code,
            process_time,
        )

        response.headers["X-Request-ID"] = request_id
        response.headers["X-Process-Time"] = str(process_time)
        return response
class ErrorHandlingMiddleware(BaseHTTPMiddleware):
    """Convert exceptions raised by the downstream call into JSON 500 responses."""

    async def dispatch(self, request: Request, call_next):
        """Forward the request and trap any exception raised downstream.

        Args:
            request: The incoming request.
            call_next: Awaitable callable that forwards the request further
                down the stack and returns the response.

        Returns:
            The downstream response unchanged, or a JSON 500 response with a
            request ID and timestamp if the downstream call raised.
        """
        try:
            return await call_next(request)
        except Exception as e:
            # Reuse the ID already stored on request.state when there is one;
            # only generate a new UUID when none has been set.
            request_id = getattr(request.state, "request_id", None)
            if request_id is None:
                request_id = str(uuid.uuid4())

            # logger.exception emits the message and the active traceback as
            # a single ERROR record.
            logger.exception("Uncaught exception [%s]: %s", request_id, e)

            return JSONResponse(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                content={
                    "detail": "Internal server error",
                    "request_id": request_id,
                    "timestamp": get_local_time(),
                },
            )
class DatabaseCheckMiddleware(BaseHTTPMiddleware):
    """Report failures on database-backed paths as JSON 503 responses.

    Requests whose path is in ``SKIP_PATHS`` are forwarded untouched. For all
    other paths, any exception raised by the downstream call is logged and
    turned into a 503 "Database connection failed" response.
    """

    # Paths forwarded without the database error handling. Built once at
    # class creation instead of on every request.
    SKIP_PATHS = frozenset({"/", "/health", "/docs", "/redoc", "/openapi.json"})

    async def dispatch(self, request: Request, call_next):
        """Forward the request, mapping downstream failures to a 503.

        Args:
            request: The incoming request.
            call_next: Awaitable callable that forwards the request further
                down the stack and returns the response.

        Returns:
            The downstream response, or a JSON 503 response with a timestamp
            if the downstream call raised on a non-skipped path.
        """
        if request.url.path in self.SKIP_PATHS:
            return await call_next(request)

        try:
            # TODO: Add checks for MongoDB and Pinecone if needed
            # PostgreSQL check is already done in route handler with get_db() method
            return await call_next(request)
        except Exception as e:
            # NOTE(review): no explicit connection check exists yet, so every
            # exception escaping the downstream call is reported as a database
            # failure. Confirm this catch-all is intended. The traceback is
            # logged so that unrelated errors can still be diagnosed.
            logger.exception("Database connection check failed: %s", e)
            return JSONResponse(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                content={
                    "detail": "Database connection failed",
                    "timestamp": get_local_time(),
                },
            )