Spaces:
Running
Running
| """ | |
| Middleware for Medical RAG AI Advisor API | |
| """ | |
| import time | |
| import logging | |
| from typing import Callable, Awaitable, Optional | |
| from fastapi import Request, Response, HTTPException, Cookie | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from starlette.middleware.base import BaseHTTPMiddleware | |
# Module-level logger; LoggingMiddleware writes its request/response records here.
logger = logging.getLogger(__name__)
class ProcessTimeMiddleware(BaseHTTPMiddleware):
    """Middleware that adds an ``X-Process-Time`` header to every response.

    The header value is the number of seconds (four decimal places) spent
    awaiting the downstream application for the request.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        # perf_counter() is monotonic and high-resolution; time.time() follows
        # the wall clock, which can jump (e.g. NTP corrections) and produce
        # negative or inflated durations.
        start = time.perf_counter()
        response = await call_next(request)
        elapsed = time.perf_counter() - start
        response.headers["X-Process-Time"] = f"{elapsed:.4f}"
        return response
class LoggingMiddleware(BaseHTTPMiddleware):
    """Middleware that logs each request and its outcome.

    Logs ``Request: <method> <url>`` on arrival. It then logs either
    ``Response: <status> - Time: <s>s - Path: <path>`` on success, or an error
    record with traceback before re-raising when the downstream app raises.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        start = time.perf_counter()  # monotonic clock, safe for durations
        # NOTE(review): request.url includes the query string; confirm no
        # patient data travels in query parameters before logging it at INFO.
        logger.info("Request: %s %s", request.method, request.url)
        try:
            response = await call_next(request)
        except Exception as exc:
            # logger.exception keeps the traceback that logger.error discarded.
            logger.exception(
                "Error: %s - Time: %.4fs - Path: %s",
                exc,
                time.perf_counter() - start,
                request.url.path,
            )
            raise
        logger.info(
            "Response: %s - Time: %.4fs - Path: %s",
            response.status_code,
            time.perf_counter() - start,
            request.url.path,
        )
        return response
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Simple in-memory, per-client sliding-window rate limiter.

    Each client, keyed by ``request.client.host``, may make at most
    ``calls_per_minute`` accepted requests in any trailing 60-second window.
    Further requests get a 429 JSON response (``{"detail": ...}``) with a
    ``Retry-After`` header until older calls age out of the window.

    State is held on the middleware instance, so limits are per process and
    reset on restart.

    NOTE(review): behind a reverse proxy, ``request.client.host`` may be the
    proxy's address. In that case all users share one budget; confirm the
    server is configured to trust forwarded headers.
    """

    # Length of the sliding window, in seconds.
    WINDOW_SECONDS = 60.0

    def __init__(self, app, calls_per_minute: int = 60):
        super().__init__(app)
        self.calls_per_minute = calls_per_minute
        # client host -> time.monotonic() stamps of accepted calls, oldest first
        self.client_calls = {}
        self._last_sweep = time.monotonic()

    def _sweep_stale_clients(self, now: float, cutoff: float) -> None:
        """Forget clients with no call inside the window, at most once per window."""
        # Rebuilding the whole dict on every request made each request
        # O(number of clients); sweeping once per window bounds memory just
        # as well.
        if now - self._last_sweep < self.WINDOW_SECONDS:
            return
        self.client_calls = {
            host: calls
            for host, calls in self.client_calls.items()
            if calls and calls[-1] > cutoff  # lists are sorted; last is newest
        }
        self._last_sweep = now

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        # request.client is None when the ASGI server reports no peer address.
        client_ip = request.client.host if request.client else "unknown"
        # monotonic() cannot jump backwards/forwards like the wall clock.
        now = time.monotonic()
        cutoff = now - self.WINDOW_SECONDS
        self._sweep_stale_clients(now, cutoff)

        recent_calls = [t for t in self.client_calls.get(client_ip, ()) if t > cutoff]
        if len(recent_calls) >= self.calls_per_minute:
            # Keep the pruned list so a throttled client holds no stale stamps.
            self.client_calls[client_ip] = recent_calls
            # An HTTPException raised inside BaseHTTPMiddleware bypasses
            # FastAPI's exception handlers and reaches the client as a 500,
            # so the 429 response is built directly.
            from fastapi.responses import JSONResponse

            # Whole seconds until the oldest counted call leaves the window.
            if recent_calls:
                retry_after = int(recent_calls[0] - cutoff) + 1
            else:
                retry_after = int(self.WINDOW_SECONDS)
            return JSONResponse(
                status_code=429,
                content={"detail": "Rate limit exceeded. Please try again later."},
                headers={"Retry-After": str(retry_after)},
            )

        recent_calls.append(now)
        self.client_calls[client_ip] = recent_calls
        return await call_next(request)
class AuthenticationMiddleware(BaseHTTPMiddleware):
    """Middleware to protect endpoints with session-cookie authentication.

    Authentication is OFF by default, preserving the local-testing behaviour
    in which every request passes straight through. Pass ``enabled=True``
    (e.g. ``app.add_middleware(AuthenticationMiddleware, enabled=True)``) to
    require a valid ``session_token`` cookie on every path that is not public.

    A path is public when it equals an entry of ``PUBLIC_PATHS`` or lies
    beneath one ("/docs" covers "/docs/..."); "/" matches only the root.

    When enabled, protected requests with no cookie, or whose session fails
    ``api.routers.auth.verify_session``, receive a 401 JSON response
    ``{"detail": ...}``. Otherwise the session's ``"username"`` is stored on
    ``request.state.user``.
    """

    # Paths that don't require authentication
    PUBLIC_PATHS = [
        "/",
        "/docs",
        "/redoc",
        "/openapi.json",
        "/health",
        "/auth/login",
        "/auth/status",
        "/assess",  # Allow assess endpoint for local testing
        "/ask",  # Allow ask endpoint for local testing
    ]

    def __init__(self, app, *args, enabled: bool = False, **kwargs):
        # *args/**kwargs forward BaseHTTPMiddleware's own parameters unchanged.
        super().__init__(app, *args, **kwargs)
        # TODO: Enable authentication in production (enabled=True).
        self.enabled = enabled

    def _is_public_path(self, path: str) -> bool:
        """Return True if *path* equals, or is nested under, a public path."""
        # A bare path.startswith(public_path) test lets "/" match every path,
        # which would make the entire API public; match whole segments.
        return any(
            path == public_path
            or (public_path != "/" and path.startswith(public_path + "/"))
            for public_path in self.PUBLIC_PATHS
        )

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        if not self.enabled or self._is_public_path(request.url.path):
            return await call_next(request)

        # Responses are returned rather than HTTPException raised: exceptions
        # raised inside BaseHTTPMiddleware skip FastAPI's exception handlers
        # and would reach the client as a 500.
        from fastapi.responses import JSONResponse

        session_token = request.cookies.get("session_token")
        if not session_token:
            return JSONResponse(
                status_code=401,
                content={"detail": "Authentication required"},
            )

        # Imported lazily as in the original design; presumably this avoids
        # an import cycle with the auth router — TODO confirm.
        from api.routers.auth import verify_session

        session_data = verify_session(session_token)
        if not session_data:
            return JSONResponse(
                status_code=401,
                content={"detail": "Invalid or expired session"},
            )

        # Add user info to request state
        request.state.user = session_data.get("username")
        return await call_next(request)
def get_cors_middleware_config():
    """Build keyword arguments for ``CORSMiddleware``.

    Origins come from the comma-separated ``ALLOWED_ORIGINS`` environment
    variable; surrounding whitespace and empty entries are ignored. When the
    variable is unset or empty, every origin is allowed (``["*"]``) for local
    testing.

    Credentials (cookies) are allowed only for an explicit origin list that
    does not contain ``"*"``: the CORS protocol rejects credentialed responses
    carrying a wildcard ``Access-Control-Allow-Origin``.

    Returns:
        dict: ``allow_origins``, ``allow_credentials``, ``allow_methods``,
        ``allow_headers`` and ``expose_headers`` for ``CORSMiddleware``.
    """
    import os

    # Honour ALLOWED_ORIGINS when it is set; it used to be parsed and then
    # discarded in favour of a hard-coded ["*"].
    configured = [
        origin.strip()
        for origin in os.getenv("ALLOWED_ORIGINS", "").split(",")
        if origin.strip()
    ]
    allowed_origins = configured or ["*"]  # Allow all origins for local testing

    return {
        "allow_origins": allowed_origins,
        "allow_credentials": "*" not in allowed_origins,
        "allow_methods": ["GET", "POST", "PUT", "DELETE", "OPTIONS"],
        "allow_headers": ["*"],  # Allow all headers
        # Set-Cookie is a forbidden response header that browsers never expose
        # to scripts, so listing it had no effect. Expose the timing header
        # set by ProcessTimeMiddleware instead.
        "expose_headers": ["X-Process-Time"],
    }