Spaces:
Running
Running
File size: 6,176 Bytes
73c6377 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 |
"""
Middleware for Medical RAG AI Advisor API
"""
import logging
import time
from typing import Awaitable, Callable, Optional

from fastapi import Cookie, HTTPException, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
# Module-level logger, named after this module; LoggingMiddleware writes its
# request/response/error records to it.
logger = logging.getLogger(__name__)
class ProcessTimeMiddleware(BaseHTTPMiddleware):
    """Middleware that reports request handling time in a response header.

    Every response passing through gets an ``X-Process-Time`` header holding
    the seconds spent in the downstream handlers, formatted to four decimals.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Time the downstream call and attach the duration to the response.

        Args:
            request: Incoming request; forwarded unchanged.
            call_next: Callable that passes the request further down the stack.

        Returns:
            The downstream response with ``X-Process-Time`` set.
        """
        # perf_counter() is monotonic, so the measured interval cannot be
        # skewed (or go negative) when the system clock is adjusted while the
        # request is in flight, which time.time() does not guarantee.
        started = time.perf_counter()
        response = await call_next(request)
        elapsed = time.perf_counter() - started
        response.headers["X-Process-Time"] = f"{elapsed:.4f}"
        return response
class LoggingMiddleware(BaseHTTPMiddleware):
    """Middleware that logs each request and the outcome of handling it.

    One INFO record is written when the request arrives and a second one when
    a response is produced. If the downstream call raises, an ERROR record is
    written instead and the exception is re-raised unchanged.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Log the request, forward it, then log the response or the failure.

        Args:
            request: Incoming request; forwarded unchanged.
            call_next: Callable that passes the request further down the stack.

        Returns:
            The downstream response, unmodified.

        Raises:
            Exception: Whatever the downstream call raised, after logging it.
        """
        # Monotonic clock: elapsed time stays correct across system clock
        # adjustments.
        started = time.perf_counter()

        # Lazy %-style arguments: the message is only formatted when a handler
        # actually emits the record.
        logger.info("Request: %s %s", request.method, request.url)

        try:
            response = await call_next(request)
        except Exception as exc:
            logger.error(
                "Error: %s - Time: %.4fs - Path: %s",
                exc,
                time.perf_counter() - started,
                request.url.path,
            )
            raise

        logger.info(
            "Response: %s - Time: %.4fs - Path: %s",
            response.status_code,
            time.perf_counter() - started,
            request.url.path,
        )
        return response
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Simple sliding-window rate limiter keyed by client IP.

    Call timestamps are kept in an in-memory dict on this middleware
    instance (``client_calls``: IP -> list of ``time.time()`` values). A
    client that already made ``calls_per_minute`` calls within the last
    60 seconds receives a 429 response instead of reaching the application.
    """

    # Length of the sliding window, in seconds.
    _WINDOW_SECONDS = 60

    def __init__(self, app, calls_per_minute: int = 60):
        """Wrap ``app`` and set the per-client call budget.

        Args:
            app: The ASGI application to wrap.
            calls_per_minute: Maximum calls allowed per client per window.
        """
        super().__init__(app)
        self.calls_per_minute = calls_per_minute
        self.client_calls = {}

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Admit or reject the request according to the client's recent calls.

        Args:
            request: Incoming request.
            call_next: Callable that passes the request further down the stack.

        Returns:
            The downstream response, or a 429 JSON response with a
            ``{"detail": ...}`` body when the client is over its budget.
        """
        # request.client is None when the server supplies no peer address
        # (e.g. some test clients); bucket those together instead of crashing.
        client_ip = request.client.host if request.client else "unknown"
        current_time = time.time()
        window_start = current_time - self._WINDOW_SECONDS

        # Prune expired timestamps for every client and forget clients with
        # no call left in the window, so the table cannot grow without bound.
        pruned = {}
        for ip, calls in self.client_calls.items():
            recent = [stamp for stamp in calls if stamp > window_start]
            if recent:
                pruned[ip] = recent
        self.client_calls = pruned

        recent_calls = self.client_calls.get(client_ip, [])
        if len(recent_calls) >= self.calls_per_minute:
            # Return the response directly: an HTTPException raised from a
            # BaseHTTPMiddleware is not routed through FastAPI's exception
            # handlers and would surface to the client as a 500.
            return JSONResponse(
                status_code=429,
                content={"detail": "Rate limit exceeded. Please try again later."},
            )

        self.client_calls[client_ip] = recent_calls + [current_time]
        return await call_next(request)
class AuthenticationMiddleware(BaseHTTPMiddleware):
    """Session-authentication middleware, currently a pure pass-through.

    The session check is disabled for local testing: ``dispatch`` forwards
    every request without inspecting it. The original check is preserved as
    comments at the end of ``dispatch``.
    """

    # Paths that don't require authentication
    PUBLIC_PATHS = [
        "/",
        "/docs",
        "/redoc",
        "/openapi.json",
        "/health",
        "/auth/login",
        "/auth/status",
        "/assess",  # Allow assess endpoint for local testing
        "/ask",  # Allow ask endpoint for local testing
    ]

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Forward the request downstream without any authentication check.

        Args:
            request: Incoming request; forwarded unchanged.
            call_next: Callable that passes the request further down the stack.

        Returns:
            The downstream response, unmodified.
        """
        # For local testing, disable authentication
        # TODO: Enable authentication in production
        downstream_response = await call_next(request)
        return downstream_response

        # Original authentication code (disabled for local testing)
        #
        # NOTE(review): before re-enabling, two things need attention:
        #   * PUBLIC_PATHS contains "/", and every URL path starts with "/",
        #     so the startswith() test below treats every path as public.
        #   * The 401s are raised as HTTPException from inside a
        #     BaseHTTPMiddleware; presumably these bypass FastAPI's exception
        #     handlers and surface as 500s -- confirm, and consider returning
        #     a response object instead.
        #
        # # Check if path is public
        # path = request.url.path
        #
        # # Allow public paths
        # if any(path.startswith(public_path) for public_path in self.PUBLIC_PATHS):
        #     return await call_next(request)
        #
        # # Check for session token
        # session_token = request.cookies.get("session_token")
        #
        # if not session_token:
        #     raise HTTPException(
        #         status_code=401,
        #         detail="Authentication required"
        #     )
        #
        # # Verify session
        # from api.routers.auth import verify_session
        # session_data = verify_session(session_token)
        #
        # if not session_data:
        #     raise HTTPException(
        #         status_code=401,
        #         detail="Invalid or expired session"
        #     )
        #
        # # Add user info to request state
        # request.state.user = session_data.get("username")
        #
        # return await call_next(request)
def get_cors_middleware_config():
    """Get CORS middleware configuration.

    Allowed origins come from the ``ALLOWED_ORIGINS`` environment variable,
    read as a comma-separated list; whitespace around entries is trimmed and
    empty entries are dropped. When the variable is unset or yields no usable
    entry, all origins are allowed (``["*"]``), which keeps the permissive
    local-testing default.

    Returns:
        dict: Keyword arguments with the keys ``allow_origins``,
        ``allow_credentials``, ``allow_methods``, ``allow_headers`` and
        ``expose_headers``.
    """
    import os

    raw_origins = os.getenv("ALLOWED_ORIGINS", "")
    allowed_origins = [
        origin.strip() for origin in raw_origins.split(",") if origin.strip()
    ]
    if not allowed_origins:
        # Nothing configured: allow all origins for local testing. The former
        # hard-coded localhost / Hugging Face list also contained "*", so a
        # plain wildcard is the equivalent default.
        allowed_origins = ["*"]

    return {
        "allow_origins": allowed_origins,
        # Must be False when allow_origins is "*"; left False for explicit
        # origins as well so existing deployments see no change here.
        "allow_credentials": False,
        "allow_methods": ["GET", "POST", "PUT", "DELETE", "OPTIONS"],
        "allow_headers": ["*"],  # Allow all headers
        "expose_headers": ["Set-Cookie"],
    }
|