import os
import sys
import logging
import traceback
import asyncio
from typing import List

# Configure logging
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Make sure aimakerspace (one directory up) is in the path
sys.path.append(os.path.dirname(os.path.dirname(__file__)))

# Import from local aimakerspace module
from aimakerspace.text_utils import CharacterTextSplitter, TextFileLoader, PDFLoader
from aimakerspace.vectordatabase import VectorDatabase
from aimakerspace.openai_utils.embedding import EmbeddingModel
from openai import OpenAI

# Initialize OpenAI client
client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
logger.info(f"Initialized OpenAI client with API key: {'valid key' if os.getenv('OPENAI_API_KEY') else 'API KEY MISSING!'}")
class RetrievalAugmentedQAPipeline:
    def __init__(self, vector_db_retriever: VectorDatabase) -> None:
        self.vector_db_retriever = vector_db_retriever

    async def arun_pipeline(self, user_query: str):
        """
        Run the RAG pipeline for the given user query.

        Returns a dict whose "response" key holds an async stream of
        response chunks.
        """
        try:
            # 1. Retrieve relevant documents
            logger.info(f"RAG Pipeline: Retrieving documents for query: '{user_query}'")
            relevant_docs = self.vector_db_retriever.search_by_text(user_query, k=4)

            if not relevant_docs:
                logger.warning("No relevant documents found in vector database")
                documents_context = "No relevant information found in the document."
            else:
                logger.info(f"Found {len(relevant_docs)} relevant document chunks")
                # Each result is a (text, score) tuple; join the texts for the prompt
                documents_context = "\n\n".join([doc[0] for doc in relevant_docs])

                # Debug similarity scores
                doc_scores = [f"{i + 1}. Score: {doc[1]:.4f}" for i, doc in enumerate(relevant_docs)]
                logger.info(f"Document similarity scores: {', '.join(doc_scores)}")

            # 2. Create the messages payload
            messages = [
                {"role": "system", "content": f"""You are a helpful AI assistant that answers questions based on the provided document context.
If the answer is not in the context, say that you don't know based on the available information.
Use the following document extracts to answer the user's question:

{documents_context}"""},
                {"role": "user", "content": user_query}
            ]

            # 3. Call the LLM and stream the output
            async def generate_response():
                try:
                    logger.info("Initiating streaming completion from OpenAI")
                    stream = client.chat.completions.create(
                        model="gpt-3.5-turbo",
                        messages=messages,
                        temperature=0.2,
                        stream=True
                    )
                    for chunk in stream:
                        # Some stream chunks carry no content (e.g. role or finish events)
                        if chunk.choices and chunk.choices[0].delta.content:
                            yield chunk.choices[0].delta.content
                except Exception as e:
                    logger.error(f"Error generating stream: {str(e)}")
                    yield f"\n\nI apologize, but I encountered an error while generating a response: {str(e)}"

            return {
                "response": generate_response()
            }
        except Exception as e:
            logger.error(f"Error in RAG pipeline: {str(e)}")
            logger.error(traceback.format_exc())

            # Capture the message now: `e` is cleared when the except block
            # exits, so the generator below must not close over it directly.
            error_message = f"I apologize, but an error occurred: {str(e)}"

            # Return an async generator so the error path has the same shape
            # as the success path.
            async def error_response():
                yield error_message

            return {
                "response": error_response()
            }
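
# A minimal sketch of how a caller might consume the streamed response above.
# The helper name `demo_query` is illustrative, not part of the original
# module; it assumes a VectorDatabase has already been built (e.g. via
# setup_vector_db below).
async def demo_query(vector_db: VectorDatabase, question: str) -> str:
    pipeline = RetrievalAugmentedQAPipeline(vector_db_retriever=vector_db)
    result = await pipeline.arun_pipeline(question)
    answer = ""
    # result["response"] is an async generator of text chunks
    async for token in result["response"]:
        answer += token
    return answer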
def process_file(file_path: str, file_name: str) -> List[str]:
    """Process an uploaded file and convert it to text chunks - optimized for speed"""
    logger.info(f"Processing file: {file_name} at path: {file_path}")
    try:
        # Determine loader based on file extension
        if file_name.lower().endswith('.txt'):
            logger.info(f"Using TextFileLoader for {file_name}")
            loader = TextFileLoader(file_path)
            loader.load()
        elif file_name.lower().endswith('.pdf'):
            logger.info(f"Using PDFLoader for {file_name}")
            loader = PDFLoader(file_path)
            loader.load()
        else:
            logger.warning(f"Unsupported file type: {file_name}")
            return ["Unsupported file format. Please upload a .txt or .pdf file."]

        # Get documents from loader
        documents = loader.documents
        if documents:
            logger.info(f"Loaded document with {len(documents[0])} characters")
        else:
            logger.warning("No document content loaded")
            return ["No content found in the document"]

        # Split text into chunks using parallel processing
        logger.info("Splitting document with parallel processing")
        chunk_size = 1500    # Increased from 1000 for fewer chunks
        chunk_overlap = 150  # Increased from 100 for better context

        # Use 8 workers for parallel processing
        text_splitter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap, max_workers=8)
        text_chunks = text_splitter.split_texts(documents)

        # Cap the number of chunks to keep processing fast
        max_chunks = 40
        if len(text_chunks) > max_chunks:
            logger.warning(f"Too many chunks ({len(text_chunks)}), limiting to {max_chunks} for faster processing")
            text_chunks = text_chunks[:max_chunks]

        logger.info(f"Split document into {len(text_chunks)} chunks")
        return text_chunks
    except Exception as e:
        logger.error(f"Error processing file: {str(e)}")
        logger.error(traceback.format_exc())
        return [f"Error processing file: {str(e)}"]
async def setup_vector_db(texts: List[str]) -> VectorDatabase:
    """Create vector database from text chunks - optimized with parallel processing"""
    logger.info(f"Setting up vector database with {len(texts)} text chunks")

    # Create embedding model to use with VectorDatabase
    embedding_model = EmbeddingModel()

    # Use batch size of 20 for better parallelization
    vector_db = VectorDatabase(embedding_model=embedding_model, batch_size=20)
    try:
        # Limit number of chunks for faster processing
        max_chunks = 40
        if len(texts) > max_chunks:
            logger.warning(f"Limiting {len(texts)} chunks to {max_chunks} for vector embedding")
            texts = texts[:max_chunks]

        # Build vector database with batch processing, enforcing the
        # 300-second timeout that the handler below reports
        logger.info("Building vector database with batch processing")
        await asyncio.wait_for(vector_db.abuild_from_list(texts), timeout=300)

        # Add documents property for compatibility
        vector_db.documents = texts
        logger.info(f"Vector database built with {len(texts)} documents")
        return vector_db
    except asyncio.TimeoutError:
        logger.error("Vector database creation timed out after 300 seconds")

        # Create minimal fallback DB with just a few documents
        fallback_db = VectorDatabase(embedding_model=embedding_model)
        if texts:
            # Use just the first few texts for minimal functionality
            minimal_texts = texts[:3]
            for text in minimal_texts:
                fallback_db.insert(text, [0.0] * 1536)  # Zero vectors (1536 dims) for speed
            fallback_db.documents = minimal_texts
        else:
            error_text = "I'm sorry, but there was a timeout during document processing."
            fallback_db.insert(error_text, [0.0] * 1536)
            fallback_db.documents = [error_text]
        return fallback_db
    except Exception as e:
        logger.error(f"Error setting up vector database: {str(e)}")
        logger.error(traceback.format_exc())

        # Create fallback DB for this error case
        fallback_db = VectorDatabase(embedding_model=embedding_model)
        error_text = "I'm sorry, but there was an error processing the document."
        fallback_db.insert(error_text, [0.0] * 1536)
        fallback_db.documents = [error_text]
        return fallback_db
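
# A minimal end-to-end sketch tying the pieces together. The `__main__` guard
# and the "sample.txt" file name are illustrative assumptions, not part of the
# original module.
if __name__ == "__main__":
    async def _demo() -> None:
        chunks = process_file("sample.txt", "sample.txt")
        vector_db = await setup_vector_db(chunks)
        pipeline = RetrievalAugmentedQAPipeline(vector_db_retriever=vector_db)
        result = await pipeline.arun_pipeline("What is this document about?")
        # Print the streamed answer token by token
        async for token in result["response"]:
            print(token, end="", flush=True)
        print()

    asyncio.run(_demo())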