import warnings

# Suppress Protobuf version-mismatch warnings. Registered BEFORE the
# third-party imports below: protobuf typically emits this warning while
# generated modules (pulled in by google-generativeai / weaviate) are being
# imported, so a filter installed after those imports would not catch it.
warnings.filterwarnings(
    "ignore", category=UserWarning, module="google.protobuf.runtime_version")

from typing import List, Dict
from configs import load_yaml_config
from query_utils import (get_non_autism_response,
                         get_non_autism_answer_response,
                         check_answer_autism_relevance,
                         process_query_for_rewrite)
from dotenv import load_dotenv
from clients import (init_weaviate_client,
                     siliconflow_qwen_generate_content,
                     groq_qwen_generate_content)
from openai import OpenAI
import asyncio
from weaviate.classes.init import Auth
import os
import requests
import time
import re
import google.generativeai as genai
from rag_steps import embed_texts
# NOTE(review): `ge` is not referenced in this module and importing it loads
# torch at import time; likely an accidental auto-import — confirm before
# removing.
from torch import ge
import logging
import weaviate

load_dotenv()

# Load configuration from YAML.
config = load_yaml_config("config.yaml")

try:
    from logger.custom_logger import CustomLoggerTracker
    custom_log = CustomLoggerTracker()
    logger = custom_log.get_logger("rag_utils")
except ImportError:
    # Fall back to standard logging if the custom logger is not available.
    logger = logging.getLogger("rag_utils")

# ---------------------------
# Environment & Globals
# ---------------------------
env = os.getenv("ENVIRONMENT", "production")
SESSION_ID = "default"
# Not referenced in this module; presumably shared state for importers —
# verify against callers.
pending_clarifications: Dict[str, str] = {}

SILICONFLOW_API_KEY = os.getenv("SILICONFLOW_API_KEY", "")
SILICONFLOW_URL = os.getenv("SILICONFLOW_URL", "").strip()

if not SILICONFLOW_API_KEY:
    logger.warning(
        "SILICONFLOW_API_KEY is not set. LLM/Reranker calls may fail.")
if not SILICONFLOW_URL:
    logger.warning(
        "SILICONFLOW_URL is not set. OpenAI client base_url will not work.")


def encode_query(query: str) -> list[float] | None:
    """Generate a single embedding vector for a query string.

    Args:
        query: Text to embed.

    Returns:
        The first embedding returned by ``embed_texts``, or ``None`` when no
        embedding (or an empty one) was produced.
    """
    embs = embed_texts([query], batch_size=1)
    # Explicit length checks instead of truthiness: if ``embed_texts`` ever
    # returns array-like data (e.g. NumPy), ``bool(array)`` raises "truth
    # value is ambiguous". For plain lists this is equivalent to the
    # truthiness test.
    if embs is not None and len(embs) > 0:
        first = embs[0]
        if first is not None and len(first) > 0:
            logger.info(f"Query embedding (first 5 dims): {first[:5]}")
            return first
    logger.error("Failed to generate query embedding.")
    return None


async def rag_autism(query, top_k=3):
    """Retrieve the ``top_k`` nearest text chunks for ``query`` from Weaviate.

    The query is embedded with ``encode_query`` and searched via
    ``near_vector`` in the collection named by
    ``config["rag"]["weavaite_collection"]``.

    Args:
        query: User query text.
        top_k: Maximum number of objects to retrieve.

    Returns:
        ``{"answer": [...]}`` holding each hit's ``"text"`` property
        (``"[No Text]"`` when missing). The list is empty when embedding
        fails, nothing matches, or the Weaviate calls raise.

    Note:
        Although declared ``async``, the embedding and Weaviate calls here
        are made synchronously (nothing is awaited), so they block the
        event loop while they run.
    """
    qe = encode_query(query)
    # ``encode_query`` signals failure with None; compare against None rather
    # than relying on truthiness, which is ambiguous for array-like vectors.
    if qe is None:
        return {"answer": []}

    client = None
    try:
        client = init_weaviate_client()
        collection_name = config["rag"]["weavaite_collection"]
        logger.info(f"Weaviate Collection : {collection_name}")
        coll = client.collections.get(collection_name)  # Books
        # NOTE(review): no per-query timeout is configured here; if gRPC
        # "deadline exceeded" errors occur, configure timeouts on the client.
        res = coll.query.near_vector(near_vector=qe, limit=top_k)
        if not getattr(res, "objects", None):
            # Same "answer" key as every other return path, so callers that
            # read rag_resp["answer"] / .get("answer") see the empty result.
            return {"answer": []}
        return {
            "answer": [obj.properties.get("text", "[No Text]")
                       for obj in res.objects]}
    except Exception as e:
        logger.error(f"❌ RAG Error: {e}")
        return {"answer": []}
    finally:
        # Always release the client connection, even on early return/error.
        if client is not None:
            try:
                client.close()
                logger.info("Weaviate client connection closed successfully")
            except Exception as e:
                logger.warning(f"Error closing Weaviate client: {e}")


async def answer_question_async(query: str) -> str:
    """Answer an autism-related question from retrieved document chunks.

    Rewrites/classifies the query, retrieves chunks for the corrected query,
    then gates the combined chunks on an autism-relevance score.

    Args:
        query: Raw user question.

    Returns:
        The retrieved chunks as a ``"- "``-prefixed, newline-joined list, or
        a canned response when the query is off-topic, nothing was
        retrieved, or the retrieved text scores below 50 on relevance.
    """
    corrected_query, is_autism_related, _ = process_query_for_rewrite(query)
    if not is_autism_related:
        return get_non_autism_response()

    # Retrieve with the corrected (rewritten) query, not the raw input.
    rag_resp = await rag_autism(corrected_query)
    chunks = rag_resp.get("answer", [])
    if not chunks:
        return "Sorry, I couldn't find relevant content in the old document."

    # Combine chunks into a single answer for relevance checking.
    combined_answer = "\n".join(f"- {c}" for c in chunks)
    answer_relevance_score = check_answer_autism_relevance(combined_answer)

    # Refuse answers scoring below 50. NOTE(review): the scale is defined by
    # check_answer_autism_relevance; presumably a 0-100 percentage — confirm.
    if answer_relevance_score < 50:
        return get_non_autism_answer_response()

    # Sufficiently autism-related: return the combined chunks.
    return combined_answer


def answer_question(query: str) -> str:
    """Synchronous wrapper around ``answer_question_async``.

    Uses ``asyncio.run``, which raises RuntimeError when called while an
    event loop is already running; async callers should instead
    ``await answer_question_async(query)``.
    """
    return asyncio.run(answer_question_async(query))


def is_greeting_or_thank(text: str) -> str:
    """Classify ``text`` as a greeting or a thank-you message.

    English phrases are matched case-insensitively on word boundaries;
    Arabic phrases are matched as substrings. Greetings are checked first.

    Returns:
        ``"greeting"``, ``"thanks"``, or ``""`` when neither matches.
    """
    if re.search(r"(?i)\b(hi|hello|hey|good (morning|afternoon|evening))\b|"
                 r"(صباح الخير|مساء الخير|أهلا|مرحبا|السلام عليكم)", text):
        return "greeting"
    elif re.search(r"(?i)\b(thank you|thanks)\b|"
                   r"(شكرا|شكرًا|مشكور|ألف شكر)", text):
        return "thanks"
    return ""


async def main():
    """Smoke-test retrieval and answering with a fixed sample query."""
    query = "Can you tell me more about autism?"
    try:
        logger.info("Starting RAG query...")
        result = await rag_autism(query=query)
        logger.info(f"RAG Result: {result}")

        # Exercise the full question-answering path as well.
        answer = await answer_question_async(query)
        logger.info(f"Answer: {answer}")
    except Exception as e:
        logger.error(f"Error in main: {e}")


if __name__ == "__main__":
    # Run the async main function.
    asyncio.run(main())