# src/query_service/api.py
from fastapi import FastAPI, HTTPException, UploadFile, File
from fastapi.middleware.cors import CORSMiddleware  # Import CORSMiddleware
from pydantic import BaseModel
from src.retrieval_handler.retriever import RetrievalHandler
from src.llm_integrator.llm import LLMIntegrator
from src.embedding_generator.embedder import EmbeddingGenerator
from src.vector_store_manager.chroma_manager import ChromaManager
import logging
from typing import Literal, Optional, Dict, Any, List  # Import List
from langchain_core.messages import HumanMessage, SystemMessage, AIMessage
import shutil
import uuid

logger = logging.getLogger(__name__)
# Initialize core components (these should ideally be dependency injected in a larger app).
# For simplicity in this example, we initialize them globally.
embedding_generator: Optional[EmbeddingGenerator] = None
vector_store_manager: Optional[ChromaManager] = None
retrieval_handler: Optional[RetrievalHandler] = None
llm_integrator: Optional[LLMIntegrator] = None

try:
    embedding_generator = EmbeddingGenerator()
    vector_store_manager = ChromaManager(embedding_generator)
    retrieval_handler = RetrievalHandler(embedding_generator, vector_store_manager)
    llm_integrator = LLMIntegrator()
    logger.info("Initialized core RAG components.")
except Exception as e:
    logger.critical(f"Failed to initialize core RAG components: {e}")
    # Depending on severity, you might want to exit or raise an error here.
    # For a production API, the relevant endpoints should return a 500 error when
    # components failed to initialize, rather than crashing the app at startup.
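# Illustrative sketch of the dependency-injection alternative mentioned above, using
# FastAPI's Depends with a cached provider. The provider name is an assumption and this
# is not wired into the endpoints below, which keep the simple global approach:
#
# from functools import lru_cache
# from fastapi import Depends
#
# @lru_cache
# def get_retrieval_handler() -> RetrievalHandler:
#     embedder = EmbeddingGenerator()
#     return RetrievalHandler(embedder, ChromaManager(embedder))
#
# An endpoint would then declare: handler: RetrievalHandler = Depends(get_retrieval_handler)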
app = FastAPI(
    title="Insight AI API",
    description="API for querying financial information.",
    version="1.0.0",
)
# --- CORS Middleware ---
# Add CORSMiddleware to allow cross-origin requests from your frontend.
# For development, you can allow all origins ("*").
# For production, restrict this to your frontend's specific origin(s).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allows all origins. Change this to your frontend's URL in production.
    allow_credentials=True,
    allow_methods=["*"],  # Allows all methods (GET, POST, OPTIONS, etc.)
    allow_headers=["*"],  # Allows all headers
)
# -----------------------
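# Illustrative only: one way to restrict origins in production is to read them from an
# environment variable instead of hard-coding "*". The variable name ALLOWED_ORIGINS and
# its comma-separated format are assumptions, not part of the existing configuration:
#
# import os
# allowed_origins = os.environ.get("ALLOWED_ORIGINS", "http://localhost:3000").split(",")
# app.add_middleware(
#     CORSMiddleware,
#     allow_origins=allowed_origins,
#     allow_credentials=True,
#     allow_methods=["GET", "POST", "OPTIONS"],
#     allow_headers=["*"],
# )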
class Message(BaseModel):
    role: Literal['user', 'assistant', 'system']
    content: str

class QueryRequest(BaseModel):
    query: str
    chat_history: Optional[List[Message]] = []
    filters: Optional[Dict[str, Any]] = None  # Allow passing metadata filters

# Pydantic models matching the response structure returned to the frontend.
class SourceMetadata(BaseModel):
    source: Optional[str] = None
    ruling_date: Optional[str] = None
    # Add other expected metadata fields here, e.g. topic: Optional[str] = None

class RetrievedSource(BaseModel):
    content_snippet: str
    metadata: Optional[SourceMetadata] = None

class QueryResponse(BaseModel):
    answer: str
    retrieved_sources: Optional[List[RetrievedSource]] = None

class TitleResponse(BaseModel):
    title: str

class TitleRequest(BaseModel):
    query: str
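# Example request body for the query endpoint below (shape follows QueryRequest above;
# the field values are purely illustrative):
# {
#   "query": "What does the latest ruling say about withholding tax?",
#   "chat_history": [
#     {"role": "user", "content": "Earlier question..."},
#     {"role": "assistant", "content": "Earlier answer..."}
#   ],
#   "filters": {"ruling_date": "2024-01-01"}
# }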
# The route path below is an assumption; adjust it to match the frontend's calls.
@app.post("/query", response_model=QueryResponse)
async def query_rulings(request: QueryRequest):
    """
    Receives a user query and returns a generated answer based on retrieved rulings.
    """
    logger.info(f"Received query: {request.query}")
    if request.filters:
        logger.info(f"Received filters: {request.filters}")

    # Check if the RAG components were initialized successfully.
    if not retrieval_handler or not llm_integrator:
        logger.error("RAG components not initialized.")
        raise HTTPException(status_code=500, detail="System components not ready.")
    try:
        # 1. Retrieve relevant documents based on the query and filters.
        #    Filters are passed through here, but the current simple implementation of
        #    RetrievalHandler does not use them directly in invoke; adjust
        #    RetrievalHandler.retrieve_documents if needed (see the sketch after this endpoint).
        retrieved_docs = retrieval_handler.retrieve_documents(request.query, filters=request.filters)

        if not retrieved_docs:
            logger.warning("No relevant documents retrieved for query.")
            return QueryResponse(answer="Could not find relevant rulings for your query.")

        # Convert chat_history entries into the corresponding LangChain message types.
        chat_history = []
        logger.debug(f"Raw chat history input: {request.chat_history}")
        for msg in request.chat_history:
            logger.debug(f"Processing message - Role: {msg.role}, Content: {msg.content[:50]}...")
            if msg.role == "user":
                new_msg = HumanMessage(content=msg.content)
            elif msg.role == "assistant":
                new_msg = AIMessage(content=msg.content)
            elif msg.role == "system":
                new_msg = SystemMessage(content=msg.content)
            else:
                logger.warning(f"Invalid message role: {msg.role}. Skipping message.")
                continue
            logger.debug(f"Converted to: {type(new_msg).__name__}")
            chat_history.append(new_msg)
        logger.debug(f"Final chat history types: {[type(m).__name__ for m in chat_history]}")
        # 2. Generate a response with the LLM from the query, retrieved context, and chat history.
        answer = llm_integrator.generate_response(request.query, retrieved_docs, chat_history)

        # 3. Prepare retrieved source information for the response.
        retrieved_sources = []
        for doc in retrieved_docs:
            # Ensure the structure matches the RetrievedSource Pydantic model.
            source_metadata = SourceMetadata(**doc.metadata) if doc.metadata else None
            retrieved_sources.append(RetrievedSource(
                # Snippet of content, truncated to 500 characters.
                content_snippet=(doc.page_content[:500] + "...") if len(doc.page_content) > 500 else doc.page_content,
                metadata=source_metadata,  # Include all metadata
            ))

        logger.info("Successfully processed query and generated response.")
        return QueryResponse(answer=answer, retrieved_sources=retrieved_sources)

    except Exception as e:
        logger.error(f"An error occurred during query processing: {e}")
        # Provide a more informative but secure error message to the user.
        raise HTTPException(status_code=500, detail="An internal error occurred while processing your query.")
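# Sketch of the RetrievalHandler adjustment mentioned in the comment above: forwarding the
# metadata filters to the vector store. It assumes the handler holds a LangChain Chroma
# vector store (self.vector_store) whose similarity_search() accepts a `filter` mapping;
# the attribute name, method name, and k value are assumptions to adapt to the real classes.
#
# def retrieve_documents(self, query: str, filters: Optional[Dict[str, Any]] = None, k: int = 5):
#     return self.vector_store.similarity_search(query, k=k, filter=filters or None)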
# The route path below is an assumption; adjust it to match the frontend's calls.
@app.post("/generate-title", response_model=TitleResponse)
async def generate_chat_title(request: TitleRequest):
    """Generate a short title for a chat from its first query, falling back to a default on error."""
    try:
        title = llm_integrator.generate_chat_title(request.query)
        return {"title": title}
    except Exception as e:
        logger.error(f"Title generation error: {e}")
        return {"title": "New Chat"}
# The route path below is an assumption; adjust it to match the frontend's calls.
@app.post("/upload")
async def upload_docs(files: list[UploadFile] = File(...)):
    """
    Upload new documents and trigger ingestion for them.
    """
    import os
    from src.ingestion_orchestrator.orchestrator import IngestionOrchestrator

    # Create a unique folder in /tmp for this upload batch.
    upload_id = str(uuid.uuid4())
    upload_dir = f"/tmp/ingest_{upload_id}"
    os.makedirs(upload_dir, exist_ok=True)

    saved_files = []
    for file in files:
        # basename() guards against path traversal in the client-supplied filename.
        file_path = os.path.join(upload_dir, os.path.basename(file.filename))
        with open(file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
        saved_files.append(file.filename)

    # Run the ingestion pipeline for the uploaded folder.
    try:
        orchestrator = IngestionOrchestrator()
        orchestrator.run_ingestion_pipeline(docs_folder=upload_dir)
        logger.info(f"Ingested files: {saved_files}")
        return {"status": "success", "files": saved_files}
    except Exception as e:
        logger.error(f"Ingestion failed: {e}")
        raise HTTPException(status_code=500, detail="Ingestion failed.")
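# Example client call for the upload endpoint above (illustrative only; assumes the route
# path used in the decorator and a server listening on localhost:8000):
#
#   curl -X POST "http://localhost:8000/upload" \
#        -F "files=@ruling1.pdf" \
#        -F "files=@ruling2.pdf"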
# You can add more endpoints here, e.g., /health for health checks:
# @app.get("/health")
# async def health_check():
#     # Check connectivity to ChromaDB, the LLM service, etc.
#     # This requires adding health check methods to your ChromaManager and LLMIntegrator.
#     chroma_status = vector_store_manager.check_health() if vector_store_manager else "uninitialized"
#     llm_status = llm_integrator.check_health() if llm_integrator else "uninitialized"
#     return {"chroma": chroma_status, "llm": llm_status}
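# A minimal liveness variant of the sketch above that only reports whether the global
# components were constructed; it needs no extra methods. Deeper connectivity checks
# would still require the check_health() methods mentioned in the comments above.
@app.get("/health")
async def health_check():
    return {
        "chroma": "initialized" if vector_store_manager else "uninitialized",
        "llm": "initialized" if llm_integrator else "uninitialized",
    }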