from fastapi import FastAPI, File, UploadFile, HTTPException, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse, StreamingResponse
from pydantic import BaseModel
import uvicorn
import os
import tempfile
import shutil
from typing import List, Optional, Dict, Any, Iterator
import pathlib
import asyncio
import logging
import time
import traceback
import uuid
import json

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Import our RAG components
from rag import RetrievalAugmentedQAPipeline, process_file, setup_vector_db

# Add the local aimakerspace module to the path
import sys
sys.path.append(os.path.join(os.path.dirname(os.path.dirname(__file__)), ""))

# Import from the local aimakerspace module
from aimakerspace.utils.session_manager import SessionManager

# Load environment variables
from dotenv import load_dotenv
load_dotenv()
app = FastAPI(
    title="RAG Application",
    description="Retrieval Augmented Generation with FastAPI and React",
    version="0.1.0",
    root_path="",  # Important for proxy environments
)
# Custom middleware for request logging, timing, and error handling behind a proxy
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse

class ProxyMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request, call_next):
        # Log request details for debugging
        logger.info(f"Request path: {request.url.path}")
        logger.info(f"Request headers: {request.headers}")

        # Time the request and surface unhandled errors as JSON
        try:
            start_time = time.time()
            response = await call_next(request)
            process_time = time.time() - start_time
            response.headers["X-Process-Time"] = str(process_time)
            return response
        except Exception as e:
            logger.error(f"Request failed: {str(e)}")
            logger.error(traceback.format_exc())
            return JSONResponse(
                status_code=500,
                content={"detail": f"Internal server error: {str(e)}"},
            )

# Add custom middleware
app.add_middleware(ProxyMiddleware)
# Configure CORS - more specific configuration for Hugging Face
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, you should restrict this
    allow_credentials=True,
    allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
    allow_headers=["*"],
    expose_headers=["Content-Length", "X-Process-Time"],
    max_age=600,  # 10 minute cache for preflight requests
)
# Initialize session manager
session_manager = SessionManager()

class QueryRequest(BaseModel):
    session_id: str
    query: str

class QueryResponse(BaseModel):
    response: str
    session_id: str

# Set file size limit to 10MB - adjust as needed
FILE_SIZE_LIMIT = 10 * 1024 * 1024  # 10MB
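# The size limit is enforced in upload_file while the upload is read in 1 MB
# chunks, so oversized files are rejected before the whole payload is buffered
# in memory.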
async def process_file_background(temp_path: str, filename: str, session_id: str):
    """Process file in background and set up the RAG pipeline"""
    try:
        start_time = time.time()
        logger.info(f"Background processing started for file: {filename} (session: {session_id})")

        # Set max processing time (5 minutes)
        max_processing_time = 300  # seconds

        # Process the file
        logger.info(f"Starting text extraction for file: {filename}")
        try:
            texts = process_file(temp_path, filename)
            logger.info(f"Processed file into {len(texts)} text chunks (took {time.time() - start_time:.2f}s)")

            # Check if processing is taking too long already
            if time.time() - start_time > max_processing_time / 2:
                logger.warning("Text extraction took more than half the allowed time. Limiting chunks...")
                # Limit to a smaller number if extraction took a long time
                max_chunks = 50
                if len(texts) > max_chunks:
                    logger.warning(f"Limiting text chunks from {len(texts)} to {max_chunks}")
                    texts = texts[:max_chunks]
        except Exception as e:
            logger.error(f"Error during text extraction: {str(e)}")
            logger.error(traceback.format_exc())
            session_manager.update_session(session_id, "failed")
            os.unlink(temp_path)
            return

        # Set up the vector database - this is the step most likely to hang
        logger.info(f"Starting vector DB creation for {len(texts)} chunks")
        embedding_start = time.time()

        # Run the setup with an overall timeout
        try:
            async def setup_with_timeout():
                return await setup_vector_db(texts)

            # Wait for vector DB setup with timeout
            vector_db = await asyncio.wait_for(
                setup_with_timeout(),
                timeout=max_processing_time - (time.time() - start_time)
            )

            # Get document count - check if the documents property is available
            if hasattr(vector_db, 'documents'):
                doc_count = len(vector_db.documents)
            else:
                # Fall back to the original VectorDatabase implementation that uses a vectors dict
                doc_count = len(vector_db.vectors) if hasattr(vector_db, 'vectors') else 0
            logger.info(f"Created vector database with {doc_count} documents (took {time.time() - embedding_start:.2f}s)")

            # Create RAG pipeline
            logger.info(f"Creating RAG pipeline for session {session_id}")
            rag_pipeline = RetrievalAugmentedQAPipeline(vector_db_retriever=vector_db)

            # Store pipeline in session manager
            session_manager.update_session(session_id, rag_pipeline)
            logger.info(f"Updated session {session_id} with processed pipeline (total time: {time.time() - start_time:.2f}s)")
        except asyncio.TimeoutError:
            logger.error(f"Vector database creation timed out after {time.time() - embedding_start:.2f}s")
            session_manager.update_session(session_id, "failed")
        except Exception as e:
            logger.error(f"Error in vector database creation: {str(e)}")
            logger.error(traceback.format_exc())
            session_manager.update_session(session_id, "failed")

        # Clean up temp file
        os.unlink(temp_path)
        logger.info(f"Removed temporary file: {temp_path}")
    except Exception as e:
        logger.error(f"Error in background processing for session {session_id}: {str(e)}")
        logger.error(traceback.format_exc())  # Log the full error traceback

        # Mark the session as failed rather than removing it
        session_manager.update_session(session_id, "failed")

        # Try to clean up the temp file if it exists
        try:
            if os.path.exists(temp_path):
                os.unlink(temp_path)
                logger.info(f"Cleaned up temporary file after error: {temp_path}")
        except Exception as cleanup_error:
            logger.error(f"Error cleaning up temp file: {str(cleanup_error)}")
# NOTE: route path is assumed here; adjust it to match the frontend's API calls
@app.post("/api/upload")
async def upload_file(background_tasks: BackgroundTasks, file: UploadFile = File(...)):
    try:
        logger.info(f"Received upload request for file: {file.filename}")
        logger.info(f"Content type: {file.content_type}")

        # Create a unique ID for this upload
        upload_id = str(uuid.uuid4())
        logger.info(f"Assigned upload ID: {upload_id}")

        # Track file size while reading
        file_size = 0
        chunk_size = 1024 * 1024  # 1MB chunks for reading
        contents = bytearray()

        # Read the file in chunks to avoid memory issues
        try:
            while True:
                chunk = await asyncio.wait_for(file.read(chunk_size), timeout=60.0)
                if not chunk:
                    break
                file_size += len(chunk)
                contents.extend(chunk)

                # Check size limit
                if file_size > FILE_SIZE_LIMIT:
                    logger.warning(f"File too large: {file_size/1024/1024:.2f}MB exceeds limit of {FILE_SIZE_LIMIT/1024/1024}MB")
                    raise HTTPException(
                        status_code=413,
                        detail=f"File too large. Maximum size is {FILE_SIZE_LIMIT/1024/1024}MB"
                    )

                # Log progress for large files
                if file_size % (5 * 1024 * 1024) == 0:  # Log every 5MB
                    logger.info(f"Upload progress: {file_size/1024/1024:.2f}MB read so far...")
        except asyncio.TimeoutError:
            logger.error(f"Timeout reading file: {file.filename}")
            raise HTTPException(
                status_code=408,
                detail="Request timeout while reading file. Please try again."
            )

        logger.info(f"File size: {file_size/1024/1024:.2f}MB")
        file_content = bytes(contents)

        # Create a temporary file
        suffix = f".{file.filename.split('.')[-1]}"
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as temp_file:
            # Write file content to the temp file
            temp_file.write(file_content)
            temp_path = temp_file.name
        logger.info(f"Created temporary file at: {temp_path}")

        # Generate session ID and create session
        session_id = session_manager.create_session("processing")
        logger.info(f"Created session ID: {session_id}")

        # Process the file in the background
        background_tasks.add_task(
            process_file_background,
            temp_path,
            file.filename,
            session_id
        )

        return {"session_id": session_id, "message": "File uploaded and processing started", "upload_id": upload_id}
    except HTTPException:
        # Re-raise HTTP errors (413/408 above) instead of wrapping them as 500s
        raise
    except Exception as e:
        logger.error(f"Error processing upload: {str(e)}")
        logger.error(traceback.format_exc())  # Log the full error traceback
        raise HTTPException(status_code=500, detail=f"Error processing file: {str(e)}")
# NOTE: route path assumed
@app.post("/api/query")
async def process_query(request: QueryRequest):
    logger.info(f"Received query request for session: {request.session_id}")

    # Check if the session exists
    if not session_manager.session_exists(request.session_id):
        logger.warning(f"Session not found: {request.session_id}")
        raise HTTPException(status_code=404, detail="Session not found. Please upload a document first.")

    # Get session data
    session_data = session_manager.get_session(request.session_id)

    # Check if processing is still ongoing
    if session_data == "processing":
        logger.info(f"Document still processing for session: {request.session_id}")
        raise HTTPException(status_code=409, detail="Document is still being processed. Please try again in a moment.")

    # Check if processing failed
    if session_data == "failed":
        logger.error(f"Processing failed for session: {request.session_id}")
        raise HTTPException(status_code=500, detail="Document processing failed. Please try uploading again.")

    try:
        logger.info(f"Processing query: '{request.query}' for session: {request.session_id}")

        # Get response from the RAG pipeline
        start_time = time.time()
        result = await session_data.arun_pipeline(request.query)

        # Stream the response - this is key for the Star Wars effect
        async def stream_response():
            try:
                async for chunk in result["response"]:
                    # Add a small delay between chunks for dramatic effect
                    await asyncio.sleep(0.01)
                    # Stream each chunk as plain text
                    yield chunk
                logger.info(f"Completed streaming response (took {time.time() - start_time:.2f}s)")
            except Exception as e:
                logger.error(f"Error in streaming: {str(e)}")
                yield f"Error during streaming: {str(e)}"

        # Return a streaming response
        return StreamingResponse(
            stream_response(),
            media_type="text/plain",
        )
    except Exception as e:
        logger.error(f"Error processing query for session {request.session_id}: {str(e)}")
        logger.error(traceback.format_exc())  # Log the full error traceback
        raise HTTPException(status_code=500, detail=f"Error processing query: {str(e)}")
# NOTE: route path assumed
@app.get("/api/health")
def health_check():
    return {"status": "healthy"}
# NOTE: route path assumed
@app.get("/api/test")
def test_endpoint():
    return {"message": "Backend is accessible"}
# NOTE: route path assumed
@app.get("/api/session/{session_id}/status")
async def session_status(session_id: str):
    """Check if a session exists and its processing status"""
    logger.info(f"Checking status for session: {session_id}")

    if not session_manager.session_exists(session_id):
        logger.warning(f"Session not found: {session_id}")
        return {"exists": False, "status": "not_found"}

    session_data = session_manager.get_session(session_id)

    if session_data == "processing":
        logger.info(f"Session {session_id} is still processing")
        return {"exists": True, "status": "processing"}

    if session_data == "failed":
        logger.error(f"Session {session_id} processing failed")
        return {"exists": True, "status": "failed"}

    logger.info(f"Session {session_id} is ready")
    return {"exists": True, "status": "ready"}
# NOTE: route path assumed
@app.get("/api/debug/sessions")
async def debug_sessions():
    """Return debug information about all sessions - for diagnostic use only"""
    logger.info("Accessed debug sessions endpoint")
    # Get a summary of all sessions
    sessions_summary = session_manager.get_sessions_summary()
    return sessions_summary
# For Hugging Face Spaces deployment, serve the static files from the React build
frontend_path = pathlib.Path(__file__).parent.parent / "frontend" / "build"
if frontend_path.exists():
    app.mount("/", StaticFiles(directory=str(frontend_path), html=True), name="frontend")

    # Fallback helper; with html=True the mount above already serves index.html
    async def serve_frontend():
        return FileResponse(str(frontend_path / "index.html"))
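# Because the API routes above are registered before this mount, Starlette matches
# them first; every other path falls through to the static React build when it exists.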
if __name__ == "__main__":
    # Get the port from the environment variable or use the default
    port = int(os.environ.get("PORT", 8000))

    # For Hugging Face Spaces deployment
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        port=port,
        proxy_headers=True,      # Tell uvicorn to trust the X-Forwarded-* headers
        forwarded_allow_ips="*"  # Allow forwarded requests from any IP
    )
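# Local development sketch (assumes this module is saved as main.py):
#   PORT=8000 python main.py
# or, with auto-reload:
#   uvicorn main:app --reload --port 8000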