import os
import asyncio
from typing import Dict

from dotenv import load_dotenv

from clients import get_weaviate_client, qwen_generate
from configs import load_yaml_config
from logger.custom_logger import CustomLoggerTracker
from prompt_template import (
    Prompt_template_Wisal,
    Prompt_template_User_document_prompt,
)
from query_utils import (
    process_query_for_rewrite,
    get_non_autism_response,
    check_answer_autism_relevance,
    get_non_autism_answer_response,
)
# These wildcard imports supply extract_text, embed_texts, encode_query and
# RecursiveCharacterTextSplitter used below.
from rag_steps import *
from rag_utils import *

config = load_yaml_config("config.yaml")

# Load .env early
load_dotenv()

# ---------------------------
# Custom Logger Initialization
# ---------------------------
custom_log = CustomLoggerTracker()
logger = custom_log.get_logger("doc_utils")
logger.info("Logger initialized for Documents utilities module")

# ---------------------------
# Environment & Globals
# ---------------------------
# Optional startup check (skipped by default):
# client = get_weaviate_client()
# if client is None:
#     logger.info("Weaviate client not connected. Please check your WEAVIATE_URL and WEAVIATE_API_KEY.")
# else:
#     logger.info("Weaviate client connected (startup checks skipped).")

SESSION_ID = "default"
pending_clarifications: Dict[str, str] = {}

SILICONFLOW_API_KEY = os.getenv("SILICONFLOW_API_KEY", "")
SILICONFLOW_URL = os.getenv("SILICONFLOW_URL", "").strip()
SILICONFLOW_CHAT_URL = os.getenv(
    "SILICONFLOW_CHAT_URL",
    "https://api.siliconflow.com/v1/chat/completions").strip()

if not SILICONFLOW_API_KEY:
    logger.warning("SILICONFLOW_API_KEY is not set. LLM/Reranker calls may fail.")
if not SILICONFLOW_URL:
    logger.warning("SILICONFLOW_URL is not set. OpenAI client base_url will not work.")

# Global state -- consider moving to a config class
last_uploaded_path = None


def get_text_splitter():
    """Factory function for the text splitter -- makes testing easier."""
    return RecursiveCharacterTextSplitter(
        chunk_size=config["chunking"]["chunk_size"],
        chunk_overlap=config["chunking"]["chunk_overlap"],
        separators=config["chunking"]["separators"],
    )
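# Expected shape of config.yaml, as read by this module. This is an
# illustrative sketch only -- the real values live in the repo's config.yaml,
# and only the keys this module actually reads are shown. Note that
# "weavaite_collection" (sic) is the key the code uses, so the spelling must
# be mirrored in the YAML:
#
#   chunking:
#     chunk_size: 1000        # illustrative value
#     chunk_overlap: 200      # illustrative value
#     separators: ["\n\n", "\n", " ", ""]
#   rag:
#     weavaite_collection: "documents"   # illustrative name
#     old_doc: "old_documents"           # read by the test suite below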
# ---------------------------
# RAG DOMAIN FUNCTIONS
# ---------------------------
def rag_dom_ingest(file_path: str) -> str:
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")
    client = None
    try:
        raw = extract_text(file_path)
        if not raw.strip():
            raise ValueError(f"No text extracted from {file_path}")

        splitter = get_text_splitter()
        docs = splitter.split_text(raw)

        # Filter out empty chunks
        texts = [chunk for chunk in docs if chunk.strip()]
        if not texts:
            raise ValueError("No valid text chunks created")

        vectors = embed_texts(texts)
        collection_name = config["rag"]["weavaite_collection"]
        logger.info(f"RAG domain ingesting to collection: {collection_name}")
        client = get_weaviate_client()

        # Batch insert with error handling
        with client.batch.dynamic() as batch:
            for txt, vec in zip(texts, vectors):
                batch.add_object(
                    collection=collection_name,
                    properties={"text": txt},
                    vector=vec)

        logger.info(f"Successfully ingested {len(texts)} chunks from {os.path.basename(file_path)}")
        return f"Ingested {len(texts)} chunks from {os.path.basename(file_path)}"
    except Exception as e:
        logger.exception(f"Error ingesting file {file_path}: {e}")
        return f"Error ingesting {os.path.basename(file_path)}: {e}"
    finally:
        if client is not None:
            try:
                client.close()
            except Exception as close_error:
                logger.error(f"Error closing Weaviate client: {close_error}")


def rag_dom_qa(question: str) -> str:
    if not question.strip():
        return "Please provide a valid question."
    client = None
    try:
        corrected_query, is_autism_related, _ = process_query_for_rewrite(question)
        if not is_autism_related:
            return get_non_autism_response()

        q_vec = embed_texts([corrected_query])[0]
        collection_name = config["rag"]["weavaite_collection"]
        logger.info(f"RAG domain QA using collection: {collection_name}")
        client = get_weaviate_client()
        documents = client.collections.get(collection_name)
        response = documents.query.near_vector(
            near_vector=q_vec, limit=5, return_metadata=["distance"])
        hits = response.objects
        if not hits:
            return "I couldn't find relevant information to answer your question."

        context = "\n\n".join(hit.properties["text"] for hit in hits)
        wisal_prompt = Prompt_template_Wisal.format(
            new_query=corrected_query, document=context)
        initial_answer = qwen_generate(wisal_prompt)

        answer_relevance_score = check_answer_autism_relevance(initial_answer)
        if answer_relevance_score < 50:
            return get_non_autism_answer_response()
        return initial_answer
    except Exception as e:
        logger.error(f"Error in RAG domain QA: {e}")
        return f"Sorry, I encountered an error processing your question: {str(e)}"
    finally:
        if client is not None:
            try:
                client.close()
            except Exception as close_error:
                logger.error(f"Error closing Weaviate client: {close_error}")


# ---------------------------
# OLD DOCUMENTS
# ---------------------------
async def old_doc_vdb(query: str, top_k: int = 1) -> dict:
    """Query the old-documents vector database."""
    if not query.strip():
        return {"answer": []}
    qe = encode_query(query)
    if not qe:
        return {"answer": []}
    client = None
    try:
        client = get_weaviate_client()
        coll = client.collections.get(config["rag"]["weavaite_collection"])  # old_documents
        res = coll.query.near_vector(
            near_vector=qe, limit=top_k, return_properties=["text"])
        if not getattr(res, "objects", None):
            return {"answer": []}
        return {"answer": [obj.properties.get("text", "[No Text]") for obj in res.objects]}
    except Exception as e:
        logger.error(f"RAG Error in old_doc_vdb: {e}")
        return {"answer": []}
    finally:
        if client is not None:
            try:
                client.close()
            except Exception as close_error:
                logger.error(f"Error closing Weaviate client: {close_error}")


def old_doc_ingestion(path: str) -> str:
    global last_uploaded_path
    if not os.path.exists(path):
        raise FileNotFoundError(f"File not found: {path}")
    last_uploaded_path = path
    logger.info(f"Old document path set: {os.path.basename(path)}")
    return f"Old document ingested: {os.path.basename(path)}"
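# The open/try/finally/close dance above repeats in every function that touches
# Weaviate. A minimal sketch of how it could be factored out (illustration only,
# not yet wired into the functions in this module; it assumes only that
# get_weaviate_client() returns an object with a .close() method, as the code
# above already does):
from contextlib import contextmanager


@contextmanager
def weaviate_session():
    """Yield a Weaviate client and always close it, logging close failures."""
    client = get_weaviate_client()
    try:
        yield client
    finally:
        if client is not None:
            try:
                client.close()
            except Exception as close_error:
                logger.error(f"Error closing Weaviate client: {close_error}")
# Usage sketch: `with weaviate_session() as client: ...` would replace the
# manual finally blocks in rag_dom_ingest, rag_dom_qa, old_doc_vdb and friends.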
def old_doc_qa(query: str) -> str:
    if not query.strip():
        return "Please provide a valid question."
    try:
        corrected_query, is_autism_related, _ = process_query_for_rewrite(query)
        if not is_autism_related:
            return get_non_autism_response()

        rag_resp = asyncio.run(old_doc_vdb(corrected_query))
        chunks = rag_resp.get("answer", [])
        if not chunks:
            return "Sorry, I couldn't find relevant content in the old document."

        combined_answer = "\n".join(f"- {c}" for c in chunks if c.strip())
        answer_relevance_score = check_answer_autism_relevance(combined_answer)
        if answer_relevance_score < 50:
            return get_non_autism_answer_response()
        return combined_answer
    except Exception as e:
        logger.error(f"Error in old_doc_qa: {e}")
        return f"Error processing your request: {e}"


# ---------------------------
# USER SPECIFIC DOCUMENTS
# ---------------------------
def user_doc_ingest(file_path: str) -> str:
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")
    client = None
    try:
        raw = extract_text(file_path)
        if not raw.strip():
            raise ValueError(f"No text extracted from {file_path}")

        splitter = get_text_splitter()
        docs = splitter.split_text(raw)
        texts = [chunk for chunk in docs if chunk.strip()]
        if not texts:
            raise ValueError("No valid text chunks created")

        vectors = embed_texts(texts)
        client = get_weaviate_client()
        collection_name = config["rag"]["weavaite_collection"]

        # Batch insert
        with client.batch.dynamic() as batch:
            for txt, vec in zip(texts, vectors):
                batch.add_object(
                    collection=collection_name,
                    properties={"text": txt},
                    vector=vec)

        logger.info(f"Successfully ingested user document: {os.path.basename(file_path)}")
        return f"Ingested {len(texts)} chunks from {os.path.basename(file_path)}"
    except Exception as e:
        logger.exception(f"Error ingesting user document {file_path}: {e}")
        return f"Error ingesting {os.path.basename(file_path)}: {e}"
    finally:
        if client is not None:
            try:
                client.close()
            except Exception as close_error:
                logger.error(f"Error closing Weaviate client: {close_error}")
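# embed_texts() and encode_query(), used by the ingestion and QA functions
# above, come from the rag_utils wildcard import. For reference, a minimal
# sketch of what an embed_texts implementation against SiliconFlow's
# OpenAI-compatible endpoint could look like (illustration only -- the real
# implementation lives in rag_utils, and the model id here is a placeholder,
# not necessarily what the project uses):
#
# from openai import OpenAI
#
# def embed_texts_sketch(texts: list[str]) -> list[list[float]]:
#     client = OpenAI(api_key=SILICONFLOW_API_KEY, base_url=SILICONFLOW_URL)
#     resp = client.embeddings.create(model="BAAI/bge-m3", input=texts)
#     return [item.embedding for item in resp.data]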
def user_doc_qa(question: str) -> str:
    if not question.strip():
        return "Please provide a valid question."
    client = None
    try:
        corrected_query, is_autism_related, _ = process_query_for_rewrite(question)
        if not is_autism_related:
            return get_non_autism_response()

        q_vec = embed_texts([corrected_query])[0]
        client = get_weaviate_client()
        documents = client.collections.get(config["rag"]["weavaite_collection"])
        response = documents.query.near_vector(
            near_vector=q_vec, limit=5, return_metadata=["distance"])
        hits = response.objects
        if not hits:
            return "I couldn't find relevant information to answer your question."

        context = "\n\n".join(hit.properties["text"] for hit in hits)
        user_document_prompt = Prompt_template_User_document_prompt.format(
            new_query=corrected_query, document=context)
        initial_answer = qwen_generate(user_document_prompt)

        answer_relevance_score = check_answer_autism_relevance(initial_answer)
        if answer_relevance_score < 50:
            return get_non_autism_answer_response()
        return initial_answer
    except Exception as e:
        logger.error(f"Error in user_doc_qa: {e}")
        return f"Sorry, I encountered an error processing your question: {str(e)}"
    finally:
        # Always close the Weaviate client
        if client is not None:
            try:
                client.close()
            except Exception as close_error:
                logger.error(f"Error closing Weaviate client: {close_error}")
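# One-time setup sketch: the ingestion functions above assume the target
# collection already exists and accepts externally supplied vectors. If it does
# not, something like the following would create it (illustration only, using
# the weaviate-client v4 API that the rest of this module targets; the
# collection name comes from config.yaml as elsewhere):
from weaviate.classes.config import Configure, DataType, Property


def ensure_collection(client, name: str) -> None:
    """Create the collection with a single text property and no built-in vectorizer."""
    if not client.collections.exists(name):
        client.collections.create(
            name,
            vectorizer_config=Configure.Vectorizer.none(),  # vectors are supplied by embed_texts()
            properties=[Property(name="text", data_type=DataType.TEXT)],
        )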
if __name__ == "__main__":
    # Test file paths
    pdf_test = "tests/Computational Requirements for Embed.pdf"
    docs_test = "tests/Computational Requirements for Embed.docx"
    txt_test = "assets/RAG_Documents/Autism_Books_1.txt"

    print("=" * 70)
    print("COMPREHENSIVE RAG DOCUMENT UTILS TEST SUITE")
    print("=" * 70)

    # ===========================
    # Test 1: RAG Domain Functions
    # ===========================
    print(f"\n{'=' * 50}")
    print("TEST 1: RAG DOMAIN FUNCTIONS")
    print(f"{'=' * 50}")
    try:
        print(f"Testing RAG domain ingestion with: {os.path.basename(txt_test)}")
        if os.path.exists(txt_test):
            result = rag_dom_ingest(txt_test)
            print(f"✓ RAG Domain Ingestion Result: {result}")

            # Test RAG domain QA
            print("\nTesting RAG domain QA...")
            test_questions = [
                "What is autism?",
                "How can I help a child with autism?",
                "What are the symptoms of autism?",
                "Tell me about weather today",  # Non-autism related
            ]
            for question in test_questions:
                print(f"\nQ: {question}")
                answer = rag_dom_qa(question)
                print(f"A: {answer[:200]}{'...' if len(answer) > 200 else ''}")
        else:
            print(f"✗ Test file not found: {txt_test}")
    except Exception as e:
        print(f"✗ RAG Domain Test Failed: {e}")

    # ===========================
    # Test 2: Old Document Functions
    # ===========================
    print(f"\n{'=' * 50}")
    print("TEST 2: OLD DOCUMENT FUNCTIONS")
    print(f"{'=' * 50}")
    try:
        print("Testing old document ingestion...")
        if os.path.exists(txt_test):
            result = old_doc_ingestion(txt_test)
            print(f"✓ Old Document Ingestion Result: {result}")

            # Test old document QA
            print("\nTesting old document QA...")
            test_questions = [
                "What information is in this document?",
                "Tell me about autism interventions",
                "What is machine learning?",  # Non-autism related
            ]
            for question in test_questions:
                print(f"\nQ: {question}")
                answer = old_doc_qa(question)
                print(f"A: {answer[:200]}{'...' if len(answer) > 200 else ''}")
        else:
            print(f"✗ Test file not found: {txt_test}")
    except Exception as e:
        print(f"✗ Old Document Test Failed: {e}")

    # ===========================
    # Test 3: User Document Functions
    # ===========================
    print(f"\n{'=' * 50}")
    print("TEST 3: USER DOCUMENT FUNCTIONS")
    print(f"{'=' * 50}")
    try:
        print("Testing user document ingestion...")
        if os.path.exists(txt_test):
            result = user_doc_ingest(txt_test)
            print(f"✓ User Document Ingestion Result: {result}")

            # Test user document QA
            print("\nTesting user document QA...")
            test_questions = [
                "What does this document say about autism?",
                "Are there any treatment recommendations?",
                "What's the capital of France?",  # Non-autism related
            ]
            for question in test_questions:
                print(f"\nQ: {question}")
                answer = user_doc_qa(question)
                print(f"A: {answer[:200]}{'...' if len(answer) > 200 else ''}")
        else:
            print(f"✗ Test file not found: {txt_test}")
    except Exception as e:
        print(f"✗ User Document Test Failed: {e}")

    # ===========================
    # Test 4: Multiple File Format Support
    # ===========================
    print(f"\n{'=' * 50}")
    print("TEST 4: MULTIPLE FILE FORMAT SUPPORT")
    print(f"{'=' * 50}")
    test_files = [
        (pdf_test, "PDF"),
        (docs_test, "DOCX"),
        (txt_test, "TXT"),
    ]
    for file_path, file_type in test_files:
        print(f"\nTesting {file_type} file: {os.path.basename(file_path)}")
        if os.path.exists(file_path):
            try:
                # Test extraction
                text = extract_text(file_path)
                if text:
                    print(f"✓ {file_type} text extraction successful: {len(text)} characters")
                    print(f"  Preview: {text[:100]}...")
                    # Test ingestion
                    result = rag_dom_ingest(file_path)
                    print(f"✓ {file_type} ingestion successful: {result}")
                else:
                    print(f"✗ {file_type} text extraction returned empty")
            except Exception as e:
                print(f"✗ {file_type} processing failed: {e}")
        else:
            print(f"✗ {file_type} file not found: {file_path}")

    # ===========================
    # Test 5: Error Handling
    # ===========================
    print(f"\n{'=' * 50}")
    print("TEST 5: ERROR HANDLING")
    print(f"{'=' * 50}")

    # Test with non-existent file
    print("Testing with non-existent file...")
    try:
        result = rag_dom_ingest("non_existent_file.txt")
        print(f"✗ Should have failed: {result}")
    except FileNotFoundError:
        print("✓ Correctly handled non-existent file")
    except Exception as e:
        print(f"✓ Handled error: {e}")

    # Test with empty query
    print("\nTesting with empty query...")
    empty_result = rag_dom_qa("")
    print(f"✓ Empty query handled: {empty_result}")

    # Test with very long query
    print("\nTesting with very long query...")
    long_query = "autism " * 100 + "what is it?"
    long_result = rag_dom_qa(long_query)
    print(f"✓ Long query handled: {long_result[:100]}...")

    # ===========================
    # Test 6: Old Document Vector DB
    # ===========================
    print(f"\n{'=' * 50}")
    print("TEST 6: OLD DOCUMENT VECTOR DB")
    print(f"{'=' * 50}")
    try:
        print("Testing old document vector database query...")
        vdb_result = asyncio.run(old_doc_vdb("autism interventions", top_k=3))
        print(f"✓ Vector DB query successful: {len(vdb_result.get('answer', []))} results")
        for i, answer in enumerate(vdb_result.get('answer', [])[:2]):
            print(f"  Result {i + 1}: {answer[:100]}...")
    except Exception as e:
        print(f"✗ Vector DB test failed: {e}")

    # ===========================
    # Test 7: Configuration and Environment
    # ===========================
    print(f"\n{'=' * 50}")
    print("TEST 7: CONFIGURATION AND ENVIRONMENT")
    print(f"{'=' * 50}")

    print("Checking environment variables...")
    env_vars = [
        "SILICONFLOW_API_KEY",
        "SILICONFLOW_URL",
        "SILICONFLOW_CHAT_URL",
        "WEAVIATE_URL",
        "WEAVIATE_API_KEY",
    ]
    for var in env_vars:
        value = os.getenv(var)
        if value:
            print(f"✓ {var}: Set (length: {len(value)})")
        else:
            print(f"✗ {var}: Not set")

    print("\nChecking configuration...")
    try:
        print(f"✓ Chunk size: {config['chunking']['chunk_size']}")
        print(f"✓ Chunk overlap: {config['chunking']['chunk_overlap']}")
        print(f"✓ RAG collection: {config['rag']['weavaite_collection']}")
        print(f"✓ Old doc collection: {config['rag']['old_doc']}")
    except Exception as e:
        print(f"✗ Configuration error: {e}")

    # ===========================
    # Test 8: Text Splitter
    # ===========================
    print(f"\n{'=' * 50}")
    print("TEST 8: TEXT SPLITTER")
    print(f"{'=' * 50}")
    try:
        splitter = get_text_splitter()
        sample_text = "This is a sample text. " * 100  # Create long text
        chunks = splitter.split_text(sample_text)
        print(f"✓ Text splitter created {len(chunks)} chunks")
        print(f"✓ Average chunk size: {sum(len(c) for c in chunks) / len(chunks):.0f} characters")
    except Exception as e:
        print(f"✗ Text splitter test failed: {e}")

    # ===========================
    # Test Summary
    # ===========================
    print(f"\n{'=' * 70}")
    print("TEST SUMMARY")
    print(f"{'=' * 70}")
    print("✓ All major functions tested")
    print("✓ Error handling verified")
    print("✓ Multiple file formats supported")
    print("✓ Configuration checked")
    print("✓ Vector database operations tested")
    print(f"{'=' * 70}")
    print("TEST SUITE COMPLETED")
    print(f"{'=' * 70}")