|
|
import os
import asyncio
from typing import Dict

from dotenv import load_dotenv

from clients import get_weaviate_client, qwen_generate
from logger.custom_logger import CustomLoggerTracker
from query_utils import (
    process_query_for_rewrite,
    get_non_autism_response,
    check_answer_autism_relevance,
    get_non_autism_answer_response,
)

# The wildcard imports are expected to supply the helpers used below
# (extract_text, embed_texts, encode_query, RecursiveCharacterTextSplitter).
from rag_steps import *
from rag_utils import *

from prompt_template import (
    Prompt_template_Wisal,
    Prompt_template_User_document_prompt,
)
from configs import load_yaml_config

config = load_yaml_config("config.yaml")
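# The project's actual config.yaml is not reproduced here; the sketch below is a
# minimal layout consistent with the keys read in this module (values are
# illustrative assumptions, and the "weavaite_collection" spelling follows the
# existing key name used throughout this file):
#
#   chunking:
#     chunk_size: 1000
#     chunk_overlap: 200
#     separators: ["\n\n", "\n", " ", ""]
#   rag:
#     weavaite_collection: "documents"
#     old_doc: "old_documents"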
load_dotenv() |
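# load_dotenv() pulls the variables below from a local .env file when present.
# A minimal sketch with placeholder values (not real credentials); WEAVIATE_URL and
# WEAVIATE_API_KEY are assumed to be consumed by get_weaviate_client() rather than
# read directly in this module:
#
#   SILICONFLOW_API_KEY=sk-...
#   SILICONFLOW_URL=https://api.siliconflow.com/v1
#   SILICONFLOW_CHAT_URL=https://api.siliconflow.com/v1/chat/completions
#   WEAVIATE_URL=https://your-cluster.weaviate.network
#   WEAVIATE_API_KEY=...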
custom_log = CustomLoggerTracker()
logger = custom_log.get_logger("doc_utils")
logger.info("Logger initialized for Documents utilities module")

SESSION_ID = "default"
pending_clarifications: Dict[str, str] = {}

SILICONFLOW_API_KEY = os.getenv("SILICONFLOW_API_KEY", "")
SILICONFLOW_URL = os.getenv("SILICONFLOW_URL", "").strip()
SILICONFLOW_CHAT_URL = os.getenv(
    "SILICONFLOW_CHAT_URL", "https://api.siliconflow.com/v1/chat/completions").strip()

if not SILICONFLOW_API_KEY:
    logger.warning(
        "SILICONFLOW_API_KEY is not set. LLM/Reranker calls may fail.")
if not SILICONFLOW_URL:
    logger.warning(
        "SILICONFLOW_URL is not set. OpenAI client base_url will not work.")
last_uploaded_path = None


def get_text_splitter():
    """Factory function for the text splitter - makes testing easier."""
    return RecursiveCharacterTextSplitter(
        chunk_size=config["chunking"]["chunk_size"],
        chunk_overlap=config["chunking"]["chunk_overlap"],
        separators=config["chunking"]["separators"],
    )
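# Because the splitter is built through this factory at call time rather than at import
# time, tests can swap in a stub without a real config. A minimal sketch, assuming this
# module is importable as doc_utils and using pytest's monkeypatch fixture (the stub
# name is hypothetical):
#
#   def test_ingest_uses_splitter_factory(monkeypatch):
#       import doc_utils
#       monkeypatch.setattr(doc_utils, "get_text_splitter", lambda: StubSplitter())
#       # ... then call doc_utils.rag_dom_ingest(...) against a small fixture file ...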
def rag_dom_ingest(file_path: str) -> str:
    """Chunk a document, embed the chunks, and store them in the RAG collection."""
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")
    client = None
    try:
        raw = extract_text(file_path)
        if not raw.strip():
            raise ValueError(f"No text extracted from {file_path}")
        splitter = get_text_splitter()
        docs = splitter.split_text(raw)
        texts = [chunk for chunk in docs if chunk.strip()]
        if not texts:
            raise ValueError("No valid text chunks created")
        vectors = embed_texts(texts)
        collection_name = config["rag"]["weavaite_collection"]
        logger.info(f"RAG domain ingesting to collection: {collection_name}")
        client = get_weaviate_client()
        with client.batch.dynamic() as batch:
            for txt, vec in zip(texts, vectors):
                batch.add_object(
                    collection=collection_name,
                    properties={"text": txt},
                    vector=vec)
        logger.info(f"Successfully ingested {len(texts)} chunks from {os.path.basename(file_path)}")
        return f"Ingested {len(texts)} chunks from {os.path.basename(file_path)}"
    except Exception as e:
        logger.exception(f"Error ingesting file {file_path}: {e}")
        raise
    finally:
        # client stays None if the error occurred before the Weaviate connection opened.
        if client is not None:
            try:
                client.close()
            except Exception as close_error:
                logger.error(f"Error closing Weaviate client: {close_error}")


def rag_dom_qa(question: str) -> str:
    """Answer an autism-related question from the RAG collection."""
    if not question.strip():
        return "Please provide a valid question."
    client = None
    try:
        corrected_query, is_autism_related, _ = process_query_for_rewrite(
            question)
        if not is_autism_related:
            return get_non_autism_response()
        q_vec = embed_texts([corrected_query])[0]
        collection_name = config["rag"]["weavaite_collection"]
        logger.info(f"RAG domain QA using collection: {collection_name}")
        client = get_weaviate_client()
        documents = client.collections.get(collection_name)
        response = documents.query.near_vector(
            near_vector=q_vec,
            limit=5,
            return_metadata=["distance"])
        hits = response.objects
        if not hits:
            return "I couldn't find relevant information to answer your question."
        context = "\n\n".join(hit.properties["text"] for hit in hits)
        wisal_prompt = Prompt_template_Wisal.format(
            new_query=corrected_query,
            document=context)
        initial_answer = qwen_generate(wisal_prompt)
        answer_relevance_score = check_answer_autism_relevance(initial_answer)
        if answer_relevance_score < 50:
            return get_non_autism_answer_response()
        return initial_answer
    except Exception as e:
        logger.error(f"Error in RAG domain QA: {e}")
        return f"Sorry, I encountered an error processing your question: {str(e)}"
    finally:
        # client stays None if the error occurred before the Weaviate connection opened.
        if client is not None:
            try:
                client.close()
            except Exception as close_error:
                logger.error(f"Error closing Weaviate client: {close_error}")


async def old_doc_vdb(query: str, top_k: int = 1) -> dict:
    """Query the old-documents vector database and return matching text chunks."""
    if not query.strip():
        return {"answer": []}
    qe = encode_query(query)
    if not qe:
        return {"answer": []}
    client = None
    try:
        client = get_weaviate_client()
        coll = client.collections.get(config["rag"]["weavaite_collection"])
        res = coll.query.near_vector(
            near_vector=qe,
            limit=top_k,
            return_properties=["text"])
        if not getattr(res, "objects", None):
            return {"answer": []}
        return {"answer": [obj.properties.get("text", "[No Text]") for obj in res.objects]}
    except Exception as e:
        logger.error(f"RAG Error in old_doc_vdb: {e}")
        return {"answer": []}
    finally:
        # client stays None if get_weaviate_client() itself failed.
        if client is not None:
            try:
                client.close()
            except Exception as close_error:
                logger.error(f"Error closing Weaviate client: {close_error}")


def old_doc_ingestion(path: str) -> str:
    """Record the path of the uploaded old document for later QA."""
    global last_uploaded_path
    if not os.path.exists(path):
        raise FileNotFoundError(f"File not found: {path}")
    last_uploaded_path = path
    logger.info(f"Old document path set: {os.path.basename(path)}")
    return f"Old document ingested: {os.path.basename(path)}"


def old_doc_qa(query: str) -> str:
    """Answer a question against the old-documents vector database."""
    if not query.strip():
        return "Please provide a valid question."
    try:
        corrected_query, is_autism_related, _ = process_query_for_rewrite(
            query)
        if not is_autism_related:
            return get_non_autism_response()
        rag_resp = asyncio.run(old_doc_vdb(corrected_query))
        chunks = rag_resp.get("answer", [])
        if not chunks:
            return "Sorry, I couldn't find relevant content in the old document."
        combined_answer = "\n".join(f"- {c}" for c in chunks if c.strip())
        answer_relevance_score = check_answer_autism_relevance(combined_answer)
        if answer_relevance_score < 50:
            return get_non_autism_answer_response()
        return combined_answer
    except Exception as e:
        logger.error(f"Error in old_doc_qa: {e}")
        return f"Error processing your request: {e}"


def user_doc_ingest(file_path: str) -> str:
    """Chunk a user-uploaded document, embed it, and store it in the RAG collection."""
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")
    client = None
    try:
        raw = extract_text(file_path)
        if not raw.strip():
            raise ValueError(f"No text extracted from {file_path}")
        splitter = get_text_splitter()
        docs = splitter.split_text(raw)
        texts = [chunk for chunk in docs if chunk.strip()]
        if not texts:
            raise ValueError("No valid text chunks created")
        vectors = embed_texts(texts)
        client = get_weaviate_client()
        collection_name = config["rag"]["weavaite_collection"]
        with client.batch.dynamic() as batch:
            for txt, vec in zip(texts, vectors):
                batch.add_object(
                    collection=collection_name,
                    properties={"text": txt},
                    vector=vec)
        logger.info(
            f"Successfully ingested user document: {os.path.basename(file_path)}")
        return f"Ingested {len(texts)} chunks from {os.path.basename(file_path)}"
    except Exception as e:
        logger.exception(f"Error ingesting user document {file_path}: {e}")
        raise
    finally:
        # client stays None if the error occurred before the Weaviate connection opened.
        if client is not None:
            try:
                client.close()
            except Exception as close_error:
                logger.error(f"Error closing Weaviate client: {close_error}")


def user_doc_qa(question: str) -> str:
    """Answer a question against the user-uploaded document collection."""
    if not question.strip():
        return "Please provide a valid question."
    client = None
    try:
        corrected_query, is_autism_related, _ = process_query_for_rewrite(
            question)
        if not is_autism_related:
            return get_non_autism_response()
        q_vec = embed_texts([corrected_query])[0]
        client = get_weaviate_client()
        documents = client.collections.get(
            config["rag"]["weavaite_collection"])
        response = documents.query.near_vector(
            near_vector=q_vec,
            limit=5,
            return_metadata=["distance"])
        hits = response.objects
        if not hits:
            return "I couldn't find relevant information to answer your question."
        context = "\n\n".join(hit.properties["text"] for hit in hits)
        UserSpecificDocument_prompt = Prompt_template_User_document_prompt.format(
            new_query=corrected_query,
            document=context)
        initial_answer = qwen_generate(UserSpecificDocument_prompt)
        answer_relevance_score = check_answer_autism_relevance(initial_answer)
        if answer_relevance_score < 50:
            return get_non_autism_answer_response()
        return initial_answer
    except Exception as e:
        logger.error(f"Error in user_doc_qa: {e}")
        return f"Sorry, I encountered an error processing your question: {str(e)}"
    finally:
        # Close the client here as well so connections are not leaked on early returns.
        if client is not None:
            try:
                client.close()
            except Exception as close_error:
                logger.error(f"Error closing Weaviate client: {close_error}")


if __name__ == "__main__":
    pdf_test = "tests/Computational Requirements for Embed.pdf"
    docs_test = "tests/Computational Requirements for Embed.docx"
    txt_test = "assets/RAG_Documents/Autism_Books_1.txt"

    print("=" * 70)
    print("COMPREHENSIVE RAG DOCUMENT UTILS TEST SUITE")
    print("=" * 70)

    print(f"\n{'=' * 50}")
    print("TEST 1: RAG DOMAIN FUNCTIONS")
    print(f"{'=' * 50}")

    try:
        print(f"Testing RAG domain ingestion with: {os.path.basename(txt_test)}")
        if os.path.exists(txt_test):
            result = rag_dom_ingest(txt_test)
            print(f"✓ RAG Domain Ingestion Result: {result}")

            print("\nTesting RAG domain QA...")
            test_questions = [
                "What is autism?",
                "How can I help a child with autism?",
                "What are the symptoms of autism?",
                "Tell me about weather today"
            ]
            for question in test_questions:
                print(f"\nQ: {question}")
                answer = rag_dom_qa(question)
                print(f"A: {answer[:200]}{'...' if len(answer) > 200 else ''}")
        else:
            print(f"✗ Test file not found: {txt_test}")
    except Exception as e:
        print(f"✗ RAG Domain Test Failed: {e}")

    print(f"\n{'=' * 50}")
    print("TEST 2: OLD DOCUMENT FUNCTIONS")
    print(f"{'=' * 50}")

    try:
        print("Testing old document ingestion...")
        if os.path.exists(txt_test):
            result = old_doc_ingestion(txt_test)
            print(f"✓ Old Document Ingestion Result: {result}")

            print("\nTesting old document QA...")
            test_questions = [
                "What information is in this document?",
                "Tell me about autism interventions",
                "What is machine learning?"
            ]
            for question in test_questions:
                print(f"\nQ: {question}")
                answer = old_doc_qa(question)
                print(f"A: {answer[:200]}{'...' if len(answer) > 200 else ''}")
        else:
            print(f"✗ Test file not found: {txt_test}")
    except Exception as e:
        print(f"✗ Old Document Test Failed: {e}")

    print(f"\n{'=' * 50}")
    print("TEST 3: USER DOCUMENT FUNCTIONS")
    print(f"{'=' * 50}")

    try:
        print("Testing user document ingestion...")
        if os.path.exists(txt_test):
            result = user_doc_ingest(txt_test)
            print(f"✓ User Document Ingestion Result: {result}")

            print("\nTesting user document QA...")
            test_questions = [
                "What does this document say about autism?",
                "Are there any treatment recommendations?",
                "What's the capital of France?"
            ]
            for question in test_questions:
                print(f"\nQ: {question}")
                answer = user_doc_qa(question)
                print(f"A: {answer[:200]}{'...' if len(answer) > 200 else ''}")
        else:
            print(f"✗ Test file not found: {txt_test}")
    except Exception as e:
        print(f"✗ User Document Test Failed: {e}")

    print(f"\n{'=' * 50}")
    print("TEST 4: MULTIPLE FILE FORMAT SUPPORT")
    print(f"{'=' * 50}")

    test_files = [
        (pdf_test, "PDF"),
        (docs_test, "DOCX"),
        (txt_test, "TXT")
    ]
    for file_path, file_type in test_files:
        print(f"\nTesting {file_type} file: {os.path.basename(file_path)}")
        if os.path.exists(file_path):
            try:
                text = extract_text(file_path)
                if text:
                    print(f"✓ {file_type} text extraction successful: {len(text)} characters")
                    print(f"  Preview: {text[:100]}...")
                    result = rag_dom_ingest(file_path)
                    print(f"✓ {file_type} ingestion successful: {result}")
                else:
                    print(f"✗ {file_type} text extraction returned empty")
            except Exception as e:
                print(f"✗ {file_type} processing failed: {e}")
        else:
            print(f"✗ {file_type} file not found: {file_path}")

    print(f"\n{'=' * 50}")
    print("TEST 5: ERROR HANDLING")
    print(f"{'=' * 50}")

    print("Testing with non-existent file...")
    try:
        result = rag_dom_ingest("non_existent_file.txt")
        print(f"✗ Should have failed: {result}")
    except FileNotFoundError:
        print("✓ Correctly handled non-existent file")
    except Exception as e:
        print(f"✓ Handled error: {e}")

    print("\nTesting with empty query...")
    empty_result = rag_dom_qa("")
    print(f"✓ Empty query handled: {empty_result}")

    print("\nTesting with very long query...")
    long_query = "autism " * 100 + "what is it?"
    long_result = rag_dom_qa(long_query)
    print(f"✓ Long query handled: {long_result[:100]}...")

    print(f"\n{'=' * 50}")
    print("TEST 6: OLD DOCUMENT VECTOR DB")
    print(f"{'=' * 50}")

    try:
        print("Testing old document vector database query...")
        vdb_result = asyncio.run(old_doc_vdb("autism interventions", top_k=3))
        print(f"✓ Vector DB query successful: {len(vdb_result.get('answer', []))} results")
        for i, answer in enumerate(vdb_result.get('answer', [])[:2]):
            print(f"  Result {i+1}: {answer[:100]}...")
    except Exception as e:
        print(f"✗ Vector DB test failed: {e}")

    print(f"\n{'=' * 50}")
    print("TEST 7: CONFIGURATION AND ENVIRONMENT")
    print(f"{'=' * 50}")

    print("Checking environment variables...")
    env_vars = [
        "SILICONFLOW_API_KEY",
        "SILICONFLOW_URL",
        "SILICONFLOW_CHAT_URL",
        "WEAVIATE_URL",
        "WEAVIATE_API_KEY"
    ]
    for var in env_vars:
        value = os.getenv(var)
        if value:
            print(f"✓ {var}: Set (length: {len(value)})")
        else:
            print(f"✗ {var}: Not set")

    print("\nChecking configuration...")
    try:
        print(f"✓ Chunk size: {config['chunking']['chunk_size']}")
        print(f"✓ Chunk overlap: {config['chunking']['chunk_overlap']}")
        print(f"✓ RAG collection: {config['rag']['weavaite_collection']}")
        print(f"✓ Old doc collection: {config['rag']['old_doc']}")
    except Exception as e:
        print(f"✗ Configuration error: {e}")

    print(f"\n{'=' * 50}")
    print("TEST 8: TEXT SPLITTER")
    print(f"{'=' * 50}")

    try:
        splitter = get_text_splitter()
        sample_text = "This is a sample text. " * 100
        chunks = splitter.split_text(sample_text)
        print(f"✓ Text splitter created {len(chunks)} chunks")
        print(f"✓ Average chunk size: {sum(len(c) for c in chunks) / len(chunks):.0f} characters")
    except Exception as e:
        print(f"✗ Text splitter test failed: {e}")

    print(f"\n{'=' * 70}")
    print("TEST SUMMARY")
    print(f"{'=' * 70}")
    print("✓ All major functions tested")
    print("✓ Error handling verified")
    print("✓ Multiple file formats supported")
    print("✓ Configuration checked")
    print("✓ Vector database operations tested")
    print(f"{'=' * 70}")
    print("TEST SUITE COMPLETED")
    print(f"{'=' * 70}")