# Autism_QA / docs_utils.py
# A7m0d's picture
# Upload folder using huggingface_hub
# 712579e verified
from query_utils import process_query_for_rewrite, get_non_autism_response
from logger.custom_logger import CustomLoggerTracker
from dotenv import load_dotenv
from query_utils import check_answer_autism_relevance, get_non_autism_answer_response
from clients import get_weaviate_client, qwen_generate
from query_utils import process_query_for_rewrite
from rag_steps import *
from rag_utils import *
from prompt_template import (
Prompt_template_Wisal,
Prompt_template_User_document_prompt)
import os
import asyncio
from typing import Dict
from configs import load_yaml_config
# Project configuration parsed from config.yaml. The functions below read its
# "chunking" and "rag" sections.
config = load_yaml_config("config.yaml")
# Load variables from a .env file so the os.getenv() lookups further down can
# see them.
# NOTE(review): this runs after load_yaml_config() above; if config loading
# ever reads environment variables it will not see .env values - confirm.
load_dotenv()
# ---------------------------
# Custom Logger Initialization
# ---------------------------
custom_log = CustomLoggerTracker()
# Module-wide logger used by every function in this file.
logger = custom_log.get_logger("doc_utils")
logger.info("Logger initialized for Documents utilities module")
# ---------------------------
# Environment & Globals
# ---------------------------
# Disabled startup connectivity check. No module-level client exists; each
# function below obtains its own client via get_weaviate_client().
# client = get_weaviate_client()
# if client is None:
#     logger.info("Weaviate client not connected. Please check your WEAVIATE_URL and WEAVIATE_API_KEY.")
# else:
#     logger.info("Weaviate client connected (startup checks skipped).")
# SESSION_ID and pending_clarifications are not referenced by any function
# visible in this file.
SESSION_ID = "default"
pending_clarifications: Dict[str, str] = {}
# SiliconFlow settings are read once, at import time. A missing value only
# produces a warning here; it does not stop the import.
SILICONFLOW_API_KEY = os.getenv("SILICONFLOW_API_KEY", "")
SILICONFLOW_URL = os.getenv("SILICONFLOW_URL", "").strip()
SILICONFLOW_CHAT_URL = os.getenv(
    "SILICONFLOW_CHAT_URL", "https://api.siliconflow.com/v1/chat/completions").strip()
if not SILICONFLOW_API_KEY:
    logger.warning(
        "SILICONFLOW_API_KEY is not set. LLM/Reranker calls may fail.")
if not SILICONFLOW_URL:
    logger.warning(
        "SILICONFLOW_URL is not set. OpenAI client base_url will not work.")
# Global variables - consider moving to a config class
# Path most recently passed to old_doc_ingestion(); None until it is called.
last_uploaded_path = None
def get_text_splitter():
    """Build a ``RecursiveCharacterTextSplitter`` from the ``chunking`` config.

    Implemented as a factory so every caller (and every test) gets its own
    splitter instance. Chunk size, chunk overlap and separators are all read
    from ``config["chunking"]``.
    """
    chunking_cfg = config["chunking"]
    splitter_kwargs = {
        "chunk_size": chunking_cfg["chunk_size"],
        "chunk_overlap": chunking_cfg["chunk_overlap"],
        "separators": chunking_cfg["separators"],
    }
    return RecursiveCharacterTextSplitter(**splitter_kwargs)
# ---------------------------
# RAG DOMAIN FUNCTIONS
# ---------------------------
def rag_dom_ingest(file_path: str) -> str:
    """Chunk, embed and store a document in the RAG domain collection.

    The file's text is extracted, split with ``get_text_splitter()``, embedded
    with ``embed_texts()`` and batch-inserted into the Weaviate collection
    named by ``config["rag"]["weavaite_collection"]``.

    Args:
        file_path: Path of the document to ingest.

    Returns:
        A status message. On success it reports how many chunks were
        ingested; on failure it describes the error (the error is also logged
        with its traceback).

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")
    file_name = os.path.basename(file_path)
    # Bound before the try block so the finally clause can always inspect it,
    # even when extraction or embedding fails before a client is created.
    client = None
    try:
        raw = extract_text(file_path)
        if not raw or not raw.strip():
            raise ValueError(f"No text extracted from {file_path}")
        splitter = get_text_splitter()
        docs = splitter.split_text(raw)
        # Filter empty chunks
        texts = [chunk for chunk in docs if chunk.strip()]
        if not texts:
            raise ValueError("No valid text chunks created")
        vectors = embed_texts(texts)
        # "weavaite_collection" is the (misspelled) key used by the config
        # file; it must be kept as-is.
        collection_name = config['rag']['weavaite_collection']
        logger.info(f"RAG domain ingesting to collection: {collection_name}")
        client = get_weaviate_client()
        if client is None:
            raise ConnectionError("Weaviate client is not available")
        # Batch insert; the context manager flushes pending objects on exit.
        with client.batch.dynamic() as batch:
            for txt, vec in zip(texts, vectors):
                batch.add_object(
                    collection=collection_name,
                    properties={"text": txt},
                    vector=vec)
        logger.info(f"Successfully ingested {len(texts)} chunks from {file_name}")
        return f"Ingested {len(texts)} chunks from {file_name}"
    except Exception as e:
        logger.exception(f"Error ingesting file {file_path}: {e}")
        # Report the failure instead of falling through and returning None.
        return f"Error ingesting {file_name}: {e}"
    finally:
        if client is not None:
            try:
                client.close()
            except Exception as close_error:
                logger.error(f"Error closing Weaviate client: {close_error}")
def rag_dom_qa(question: str) -> str:
    """Answer a question from the RAG domain collection.

    The question is first passed through ``process_query_for_rewrite()``;
    questions it flags as not autism-related get the canned response from
    ``get_non_autism_response()``. Otherwise the corrected query is embedded,
    the five nearest chunks are fetched from the collection named by
    ``config["rag"]["weavaite_collection"]``, and ``qwen_generate()`` produces
    an answer from ``Prompt_template_Wisal``. Answers whose relevance score
    from ``check_answer_autism_relevance()`` is below 50 are replaced by
    ``get_non_autism_answer_response()``.

    Args:
        question: The user's question.

    Returns:
        The generated answer, a canned fallback response, or an error message.
        This function does not raise for processing errors; they are logged
        and reported in the returned string.
    """
    if not question or not question.strip():
        return "Please provide a valid question."
    # Bound before the try block: several return paths below exit before a
    # client is created, and the finally clause must still be able to run.
    client = None
    try:
        corrected_query, is_autism_related, _ = process_query_for_rewrite(
            question)
        if not is_autism_related:
            return get_non_autism_response()
        q_vec = embed_texts([corrected_query])[0]
        collection_name = config["rag"]["weavaite_collection"]
        logger.info(f"RAG domain QA using collection: {collection_name}")
        client = get_weaviate_client()
        if client is None:
            raise ConnectionError("Weaviate client is not available")
        documents = client.collections.get(collection_name)
        response = documents.query.near_vector(
            near_vector=q_vec,
            limit=5,
            return_metadata=["distance"])
        hits = response.objects
        if not hits:
            return "I couldn't find relevant information to answer your question."
        context = "\n\n".join(hit.properties["text"] for hit in hits)
        wisal_prompt = Prompt_template_Wisal.format(
            new_query=corrected_query,
            document=context)
        initial_answer = qwen_generate(wisal_prompt)
        # Guard against an off-topic generation even for an on-topic query.
        answer_relevance_score = check_answer_autism_relevance(initial_answer)
        if answer_relevance_score < 50:
            return get_non_autism_answer_response()
        return initial_answer
    except Exception as e:
        logger.exception(f"Error in RAG domain QA: {e}")
        return f"Sorry, I encountered an error processing your question: {str(e)}"
    finally:
        if client is not None:
            try:
                client.close()
            except Exception as close_error:
                logger.error(f"Error closing Weaviate client: {close_error}")
# ---------------------------
# OLD DOCUMENTS
# ---------------------------
async def old_doc_vdb(query: str, top_k: int = 1) -> dict:
    """Query the old-documents vector database.

    Args:
        query: Text to search for; encoded with ``encode_query()``.
        top_k: Maximum number of matching chunks to return.

    Returns:
        ``{"answer": [...]}`` where the list holds the ``text`` property of
        each hit (``"[No Text]"`` when a hit has none). The list is empty for
        a blank query, an empty encoding, no hits, or any Weaviate error
        (which is logged).
    """
    if not query or not query.strip():
        return {"answer": []}
    qe = encode_query(query)
    if not qe:
        return {"answer": []}
    # Bound before the try block so the finally clause cannot hit an unbound
    # name when get_weaviate_client() itself raises.
    client = None
    try:
        client = get_weaviate_client()
        if client is None:
            raise ConnectionError("Weaviate client is not available")
        # NOTE(review): this queries the same collection as the RAG domain
        # functions. The original inline comment said "old_documents" and the
        # config also has a config["rag"]["old_doc"] entry - confirm which
        # collection is intended.
        coll = client.collections.get(config["rag"]["weavaite_collection"])
        res = coll.query.near_vector(
            near_vector=qe,
            limit=top_k,
            return_properties=["text"])
        if not getattr(res, "objects", None):
            return {"answer": []}
        return {"answer": [obj.properties.get("text", "[No Text]") for obj in res.objects]}
    except Exception as e:
        logger.error(f"RAG Error in old_doc_vdb: {e}")
        return {"answer": []}
    finally:
        if client is not None:
            try:
                client.close()
            except Exception as close_error:
                logger.error(f"Error closing Weaviate client: {close_error}")
def old_doc_ingestion(path: str) -> str:
    """Record ``path`` as the most recently uploaded old document.

    Only the module-level ``last_uploaded_path`` is updated; no text is
    extracted or stored here.

    Args:
        path: Location of the uploaded document.

    Returns:
        A confirmation message containing the file's base name.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
    """
    global last_uploaded_path
    if os.path.exists(path):
        last_uploaded_path = path
        file_name = os.path.basename(path)
        logger.info(f"Old document path set: {file_name}")
        return f"Old document ingested: {file_name}"
    raise FileNotFoundError(f"File not found: {path}")
def old_doc_qa(query: str) -> str:
    """Answer a question with raw chunks retrieved by ``old_doc_vdb()``.

    Unlike the other QA functions in this module, no LLM generation happens
    here: the retrieved chunks are returned as a bulleted list.

    Args:
        query: The user's question.

    Returns:
        The bulleted chunks, a canned fallback response (query or answer not
        autism-related, nothing found), or an error message. Processing errors
        are logged and reported in the returned string rather than raised.
    """
    if not query or not query.strip():
        return "Please provide a valid question."
    try:
        corrected_query, is_autism_related, _ = process_query_for_rewrite(
            query)
        if not is_autism_related:
            return get_non_autism_response()
        # NOTE: asyncio.run() cannot be used from a thread that already has a
        # running event loop; in that case the RuntimeError is reported
        # through the except branch below.
        rag_resp = asyncio.run(old_doc_vdb(corrected_query))
        # Drop blank / non-text chunks *before* the emptiness check, so a
        # result made only of blanks is reported as "nothing found" instead
        # of sending an empty answer to relevance scoring.
        chunks = [
            c for c in rag_resp.get("answer", [])
            if isinstance(c, str) and c.strip()
        ]
        if not chunks:
            return "Sorry, I couldn't find relevant content in the old document."
        combined_answer = "\n".join(f"- {c}" for c in chunks)
        answer_relevance_score = check_answer_autism_relevance(combined_answer)
        if answer_relevance_score < 50:
            return get_non_autism_answer_response()
        return combined_answer
    except Exception as e:
        logger.exception(f"Error in old_doc_qa: {e}")
        return f"Error processing your request: {e}"
# ---------------------------
# USER SPECIFIC DOCUMENTS
# ---------------------------
def user_doc_ingest(file_path: str) -> str:
    """Chunk, embed and store a user-supplied document in Weaviate.

    The file's text is extracted, split with ``get_text_splitter()``, embedded
    with ``embed_texts()`` and batch-inserted into the collection named by
    ``config["rag"]["weavaite_collection"]``.

    Args:
        file_path: Path of the document to ingest.

    Returns:
        A status message. On success it reports how many chunks were
        ingested; on failure it describes the error (the error is also logged
        with its traceback).

    Raises:
        FileNotFoundError: If ``file_path`` does not exist.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")
    file_name = os.path.basename(file_path)
    # Bound before the try block so the finally clause can always inspect it,
    # even when extraction or embedding fails before a client is created.
    client = None
    try:
        raw = extract_text(file_path)
        if not raw or not raw.strip():
            raise ValueError(f"No text extracted from {file_path}")
        splitter = get_text_splitter()
        docs = splitter.split_text(raw)
        texts = [chunk for chunk in docs if chunk.strip()]
        if not texts:
            raise ValueError("No valid text chunks created")
        vectors = embed_texts(texts)
        client = get_weaviate_client()
        if client is None:
            raise ConnectionError("Weaviate client is not available")
        # "weavaite_collection" is the (misspelled) key used by the config
        # file; it must be kept as-is.
        collection_name = config["rag"]["weavaite_collection"]
        # Batch insert; the context manager flushes pending objects on exit.
        with client.batch.dynamic() as batch:
            for txt, vec in zip(texts, vectors):
                batch.add_object(
                    collection=collection_name,
                    properties={"text": txt},
                    vector=vec)
        logger.info(
            f"Successfully ingested user document: {file_name}")
        return f"Ingested {len(texts)} chunks from {file_name}"
    except Exception as e:
        logger.exception(f"Error ingesting user document {file_path}: {e}")
        # Report the failure instead of falling through and returning None.
        return f"Error ingesting {file_name}: {e}"
    finally:
        if client is not None:
            try:
                client.close()
            except Exception as close_error:
                logger.error(f"Error closing Weaviate client: {close_error}")
def user_doc_qa(question: str) -> str:
    """Answer a question from the user-document collection.

    The question is first passed through ``process_query_for_rewrite()``;
    questions it flags as not autism-related get the canned response from
    ``get_non_autism_response()``. Otherwise the corrected query is embedded,
    the five nearest chunks are fetched from the collection named by
    ``config["rag"]["weavaite_collection"]``, and ``qwen_generate()`` produces
    an answer from ``Prompt_template_User_document_prompt``. Answers whose
    relevance score from ``check_answer_autism_relevance()`` is below 50 are
    replaced by ``get_non_autism_answer_response()``.

    Args:
        question: The user's question.

    Returns:
        The generated answer, a canned fallback response, or an error message.
        Processing errors are logged and reported in the returned string
        rather than raised.
    """
    if not question or not question.strip():
        return "Please provide a valid question."
    # Bound before the try block: several return paths below exit before a
    # client is created, and the finally clause must still be able to run.
    client = None
    try:
        corrected_query, is_autism_related, _ = process_query_for_rewrite(
            question)
        if not is_autism_related:
            return get_non_autism_response()
        q_vec = embed_texts([corrected_query])[0]
        client = get_weaviate_client()
        if client is None:
            raise ConnectionError("Weaviate client is not available")
        documents = client.collections.get(
            config["rag"]["weavaite_collection"])
        response = documents.query.near_vector(
            near_vector=q_vec,
            limit=5,
            return_metadata=["distance"])
        hits = response.objects
        if not hits:
            return "I couldn't find relevant information to answer your question."
        context = "\n\n".join(hit.properties["text"] for hit in hits)
        user_document_prompt = Prompt_template_User_document_prompt.format(
            new_query=corrected_query,
            document=context)
        initial_answer = qwen_generate(user_document_prompt)
        # Guard against an off-topic generation even for an on-topic query.
        answer_relevance_score = check_answer_autism_relevance(initial_answer)
        if answer_relevance_score < 50:
            return get_non_autism_answer_response()
        return initial_answer
    except Exception as e:
        logger.exception(f"Error in user_doc_qa: {e}")
        return f"Sorry, I encountered an error processing your question: {str(e)}"
    finally:
        # Close the per-call client, as the other QA functions in this module
        # do; without this every call leaks a Weaviate connection.
        if client is not None:
            try:
                client.close()
            except Exception as close_error:
                logger.error(f"Error closing Weaviate client: {close_error}")
if __name__ == "__main__":
    # Manual smoke-test suite for this module. It talks to the real embedding,
    # LLM and Weaviate services, so it is meant to be run by hand, not in CI.
    # Every check prints a pass/fail line; failures are collected so the final
    # summary reflects what actually happened.

    # Test file paths
    pdf_test = "tests/Computational Requirements for Embed.pdf"
    docs_test = "tests/Computational Requirements for Embed.docx"
    txt_test = "assets/RAG_Documents/Autism_Books_1.txt"

    PASS_MARK = "✓"
    FAIL_MARK = "✗"
    failures = []

    def _section(title, width=50):
        """Print a section banner preceded by a blank line."""
        rule = "=" * width
        print(f"\n{rule}")
        print(title)
        print(rule)

    def _fail(message):
        """Print a failure line and remember it for the summary."""
        failures.append(message)
        print(f"{FAIL_MARK} {message}")

    def _ask(qa_fn, questions):
        """Run each question through qa_fn and print a truncated answer."""
        for question in questions:
            print(f"\nQ: {question}")
            answer = str(qa_fn(question))
            print(f"A: {answer[:200]}{'...' if len(answer) > 200 else ''}")

    def _ingest_then_ask(label, qa_label, intro, ingest_fn, qa_fn, questions):
        """Ingest txt_test with ingest_fn, then query it with qa_fn."""
        try:
            print(intro)
            if not os.path.exists(txt_test):
                _fail(f"Test file not found: {txt_test}")
                return
            result = ingest_fn(txt_test)
            print(f"{PASS_MARK} {label} Ingestion Result: {result}")
            print(f"\nTesting {qa_label} QA...")
            _ask(qa_fn, questions)
        except Exception as e:
            _fail(f"{label} Test Failed: {e}")

    print("=" * 70)
    print("COMPREHENSIVE RAG DOCUMENT UTILS TEST SUITE")
    print("=" * 70)

    # ===========================
    # Test 1: RAG Domain Functions
    # ===========================
    _section("TEST 1: RAG DOMAIN FUNCTIONS")
    _ingest_then_ask(
        "RAG Domain",
        "RAG domain",
        f"Testing RAG domain ingestion with: {os.path.basename(txt_test)}",
        rag_dom_ingest,
        rag_dom_qa,
        [
            "What is autism?",
            "How can I help a child with autism?",
            "What are the symptoms of autism?",
            "Tell me about weather today",  # Non-autism related
        ],
    )

    # ===========================
    # Test 2: Old Document Functions
    # ===========================
    _section("TEST 2: OLD DOCUMENT FUNCTIONS")
    _ingest_then_ask(
        "Old Document",
        "old document",
        "Testing old document ingestion...",
        old_doc_ingestion,
        old_doc_qa,
        [
            "What information is in this document?",
            "Tell me about autism interventions",
            "What is machine learning?",  # Non-autism related
        ],
    )

    # ===========================
    # Test 3: User Document Functions
    # ===========================
    _section("TEST 3: USER DOCUMENT FUNCTIONS")
    _ingest_then_ask(
        "User Document",
        "user document",
        "Testing user document ingestion...",
        user_doc_ingest,
        user_doc_qa,
        [
            "What does this document say about autism?",
            "Are there any treatment recommendations?",
            "What's the capital of France?",  # Non-autism related
        ],
    )

    # ===========================
    # Test 4: Multiple File Format Support
    # ===========================
    _section("TEST 4: MULTIPLE FILE FORMAT SUPPORT")
    test_files = [
        (pdf_test, "PDF"),
        (docs_test, "DOCX"),
        (txt_test, "TXT"),
    ]
    for file_path, file_type in test_files:
        print(f"\nTesting {file_type} file: {os.path.basename(file_path)}")
        if not os.path.exists(file_path):
            _fail(f"{file_type} file not found: {file_path}")
            continue
        try:
            # Test extraction
            text = extract_text(file_path)
            if not text:
                _fail(f"{file_type} text extraction returned empty")
                continue
            print(f"{PASS_MARK} {file_type} text extraction successful: {len(text)} characters")
            print(f" Preview: {text[:100]}...")
            # Test ingestion
            result = rag_dom_ingest(file_path)
            print(f"{PASS_MARK} {file_type} ingestion successful: {result}")
        except Exception as e:
            _fail(f"{file_type} processing failed: {e}")

    # ===========================
    # Test 5: Error Handling
    # ===========================
    _section("TEST 5: ERROR HANDLING")
    # Test with non-existent file
    print("Testing with non-existent file...")
    try:
        result = rag_dom_ingest("non_existent_file.txt")
    except FileNotFoundError:
        print(f"{PASS_MARK} Correctly handled non-existent file")
    except Exception as e:
        print(f"{PASS_MARK} Handled error: {e}")
    else:
        _fail(f"Should have failed: {result}")

    # Test with empty query. Guarded so one failing call cannot abort the
    # remaining checks.
    print("\nTesting with empty query...")
    try:
        empty_result = rag_dom_qa("")
        print(f"{PASS_MARK} Empty query handled: {empty_result}")
    except Exception as e:
        _fail(f"Empty query test failed: {e}")

    # Test with very long query
    print("\nTesting with very long query...")
    try:
        long_query = "autism " * 100 + "what is it?"
        long_result = str(rag_dom_qa(long_query))
        print(f"{PASS_MARK} Long query handled: {long_result[:100]}...")
    except Exception as e:
        _fail(f"Long query test failed: {e}")

    # ===========================
    # Test 6: Old Document Vector DB
    # ===========================
    _section("TEST 6: OLD DOCUMENT VECTOR DB")
    try:
        print("Testing old document vector database query...")
        vdb_result = asyncio.run(old_doc_vdb("autism interventions", top_k=3))
        answers = vdb_result.get("answer", [])
        print(f"{PASS_MARK} Vector DB query successful: {len(answers)} results")
        for i, answer in enumerate(answers[:2], start=1):
            print(f" Result {i}: {answer[:100]}...")
    except Exception as e:
        _fail(f"Vector DB test failed: {e}")

    # ===========================
    # Test 7: Configuration and Environment
    # ===========================
    _section("TEST 7: CONFIGURATION AND ENVIRONMENT")
    print("Checking environment variables...")
    env_vars = [
        "SILICONFLOW_API_KEY",
        "SILICONFLOW_URL",
        "SILICONFLOW_CHAT_URL",
        "WEAVIATE_URL",
        "WEAVIATE_API_KEY",
    ]
    for var in env_vars:
        value = os.getenv(var)
        if value:
            # Only the length is printed so secrets never reach the console.
            print(f"{PASS_MARK} {var}: Set (length: {len(value)})")
        else:
            _fail(f"{var}: Not set")
    print("\nChecking configuration...")
    try:
        print(f"{PASS_MARK} Chunk size: {config['chunking']['chunk_size']}")
        print(f"{PASS_MARK} Chunk overlap: {config['chunking']['chunk_overlap']}")
        print(f"{PASS_MARK} RAG collection: {config['rag']['weavaite_collection']}")
        print(f"{PASS_MARK} Old doc collection: {config['rag']['old_doc']}")
    except Exception as e:
        _fail(f"Configuration error: {e}")

    # ===========================
    # Test 8: Text Splitter
    # ===========================
    _section("TEST 8: TEXT SPLITTER")
    try:
        splitter = get_text_splitter()
        sample_text = "This is a sample text. " * 100  # Create long text
        chunks = splitter.split_text(sample_text)
        print(f"{PASS_MARK} Text splitter created {len(chunks)} chunks")
        print(f"{PASS_MARK} Average chunk size: {sum(len(c) for c in chunks) / len(chunks):.0f} characters")
    except Exception as e:
        _fail(f"Text splitter test failed: {e}")

    # ===========================
    # Test Summary
    # ===========================
    _section("TEST SUMMARY", width=70)
    if failures:
        print(f"{FAIL_MARK} {len(failures)} check(s) failed:")
        for message in failures:
            print(f"  - {message}")
    else:
        print(f"{PASS_MARK} All major functions tested")
        print(f"{PASS_MARK} Error handling verified")
        print(f"{PASS_MARK} Multiple file formats supported")
        print(f"{PASS_MARK} Configuration checked")
        print(f"{PASS_MARK} Vector database operations tested")
    print("=" * 70)
    print("TEST SUITE COMPLETED")
    print("=" * 70)
    # No module-level Weaviate client exists to close here: each function
    # above opens and closes its own client.