# Autism_QA / rag_utils.py
import asyncio
import logging
import os
import re
import time
import warnings
from typing import Dict, List

import google.generativeai as genai
import requests
import weaviate
from dotenv import load_dotenv
from openai import OpenAI
from weaviate.classes.init import Auth

from clients import init_weaviate_client, siliconflow_qwen_generate_content, groq_qwen_generate_content
from configs import load_yaml_config
from query_utils import (check_answer_autism_relevance, get_non_autism_answer_response,
                         get_non_autism_response, process_query_for_rewrite)
from rag_steps import embed_texts

# Suppress Protobuf version mismatch warnings
warnings.filterwarnings(
    "ignore",
    category=UserWarning,
    module="google.protobuf.runtime_version",
)

load_dotenv()

# Load config from YAML
config = load_yaml_config("config.yaml")

try:
    from logger.custom_logger import CustomLoggerTracker
    custom_log = CustomLoggerTracker()
    logger = custom_log.get_logger("rag_utils")
except ImportError:
    # Fall back to standard logging if the custom logger is not available
    logger = logging.getLogger("rag_utils")

# ---------------------------
# Environment & Globals
# ---------------------------
env = os.getenv("ENVIRONMENT", "production")
SESSION_ID = "default"
pending_clarifications: Dict[str, str] = {}
SILICONFLOW_API_KEY = os.getenv("SILICONFLOW_API_KEY", "")
SILICONFLOW_URL = os.getenv("SILICONFLOW_URL", "").strip()
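
# Optional sketch (not part of the original module): how an OpenAI-compatible client
# could be built against the SiliconFlow endpoint configured above, which is what the
# SILICONFLOW_URL warning below refers to. The helper name build_siliconflow_client is
# an assumption for illustration only.
def build_siliconflow_client() -> OpenAI | None:
    """Return an OpenAI client pointed at SILICONFLOW_URL, or None if unconfigured."""
    if not SILICONFLOW_API_KEY or not SILICONFLOW_URL:
        return None
    return OpenAI(api_key=SILICONFLOW_API_KEY, base_url=SILICONFLOW_URL)
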
if not SILICONFLOW_API_KEY:
    logger.warning("SILICONFLOW_API_KEY is not set. LLM/Reranker calls may fail.")
if not SILICONFLOW_URL:
    logger.warning("SILICONFLOW_URL is not set. OpenAI client base_url will not work.")


def encode_query(query: str) -> list[float] | None:
    """Generate a single embedding vector for a query string."""
    embs = embed_texts([query], batch_size=1)
    if embs and embs[0]:
        logger.info(f"Query embedding (first 5 dims): {embs[0][:5]}")
        return embs[0]
    logger.error("Failed to generate query embedding.")
    return None
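

# Optional sketch (not in the original module): a small local sanity check for
# encode_query output. cosine_similarity is a hypothetical helper added only for
# illustration; it assumes embed_texts returns plain Python lists of floats.
def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine similarity between two equal-length embedding vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = sum(x * x for x in a) ** 0.5
    norm_b = sum(y * y for y in b) ** 0.5
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0
    return dot / (norm_a * norm_b)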


async def rag_autism(query: str, top_k: int = 3) -> dict:
    """Retrieve the top_k most similar text chunks from Weaviate for a query."""
    qe = encode_query(query)
    if not qe:
        return {"answer": []}
    client = None
    try:
        client = init_weaviate_client()
        logger.info(
            f"Weaviate Collection : {config['rag']['weavaite_collection']}")
        coll = client.collections.get(
            config["rag"]["weavaite_collection"])  # Books
        # NOTE: a query timeout could be configured here to avoid GRPC deadline-exceeded errors
        res = coll.query.near_vector(
            near_vector=qe,
            limit=top_k,
        )
        if not getattr(res, "objects", None):
            return {"answer": []}
        return {
            "answer": [obj.properties.get("text", "[No Text]") for obj in res.objects]
        }
    except Exception as e:
        logger.error(f"❌ RAG Error: {e}")
        return {"answer": []}
    finally:
        # Always close the client connection
        if client:
            try:
                client.close()
                logger.info("Weaviate client connection closed successfully")
            except Exception as e:
                logger.warning(f"Error closing Weaviate client: {e}")


async def answer_question_async(query: str) -> str:
    """Async version of answer_question."""
    corrected_query, is_autism_related, _ = process_query_for_rewrite(query)
    if not is_autism_related:
        return get_non_autism_response()
    # Use the corrected query for retrieval
    rag_resp = await rag_autism(corrected_query)
    chunks = rag_resp.get("answer", [])
    if not chunks:
        return "Sorry, I couldn't find relevant content in the documents."
    # Combine chunks into a single answer for relevance checking
    combined_answer = "\n".join(f"- {c}" for c in chunks)
    answer_relevance_score = check_answer_autism_relevance(combined_answer)
    # If answer relevance is below 50%, refuse to answer (threshold updated for enhanced scoring)
    if answer_relevance_score < 50:
        return get_non_autism_answer_response()
    # Otherwise the answer is sufficiently autism-related, so return it
    return combined_answer


def answer_question(query: str) -> str:
    """Synchronous wrapper for answer_question_async."""
    return asyncio.run(answer_question_async(query))


def is_greeting_or_thank(text: str) -> str:
    """Classify text as a greeting, thanks, or neither (English or Arabic)."""
    if re.search(r"(?i)\b(hi|hello|hey|good (morning|afternoon|evening))\b|"
                 r"(صباح الخير|مساء الخير|أهلا|مرحبا|السلام عليكم)", text):
        return "greeting"
    elif re.search(r"(?i)\b(thank you|thanks)\b|"
                   r"(شكرا|شكرًا|مشكور|ألف شكر)", text):
        return "thanks"
    return ""


async def main():
    """Main async function to test the RAG functionality."""
    query = "Can you tell me more about autism?"
    try:
        logger.info("Starting RAG query...")
        result = await rag_autism(query=query)
        logger.info(f"RAG Result: {result}")
        # Test the answer_question_async path as well
        answer = await answer_question_async(query)
        logger.info(f"Answer: {answer}")
    except Exception as e:
        logger.error(f"Error in main: {e}")


if __name__ == "__main__":
    # Run the async main function
    asyncio.run(main())