File size: 5,315 Bytes
712579e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
from typing import List, Dict
from configs import load_yaml_config
from query_utils import (get_non_autism_response, get_non_autism_answer_response,
check_answer_autism_relevance, process_query_for_rewrite)
from dotenv import load_dotenv
from clients import init_weaviate_client, siliconflow_qwen_generate_content, groq_qwen_generate_content
from openai import OpenAI
import asyncio
from weaviate.classes.init import Auth
import os
import requests
import time
import re
import google.generativeai as genai
from rag_steps import embed_texts
from torch import ge
import warnings
import logging
import weaviate
# Suppress Protobuf version mismatch warnings
warnings.filterwarnings("ignore", category=UserWarning,
module="google.protobuf.runtime_version")
# ---------------------------
# Module-level setup (runs on import)
# ---------------------------
# Load environment variables from a local .env file, if present.
load_dotenv()
# load config from yaml
config = load_yaml_config("config.yaml")
try:
    # Prefer the project's custom logger when it is importable.
    from logger.custom_logger import CustomLoggerTracker
    custom_log = CustomLoggerTracker()
    logger = custom_log.get_logger("rag_utils")
except ImportError:
    # Fallback to standard logging if custom logger not available
    logger = logging.getLogger("rag_utils")
# ---------------------------
# Environment & Globals
# ---------------------------
env = os.getenv("ENVIRONMENT", "production")
SESSION_ID = "default"
# Maps session id -> pending clarification question (shared mutable state).
pending_clarifications: Dict[str, str] = {}
SILICONFLOW_API_KEY = os.getenv("SILICONFLOW_API_KEY", "")
SILICONFLOW_URL = os.getenv("SILICONFLOW_URL", "").strip()
# Warn (but do not fail) at import time when required credentials are missing,
# so downstream call failures are easier to diagnose.
if not SILICONFLOW_API_KEY:
    logger.warning(
        "SILICONFLOW_API_KEY is not set. LLM/Reranker calls may fail.")
if not SILICONFLOW_URL:
    logger.warning(
        "SILICONFLOW_URL is not set. OpenAI client base_url will not work.")
def encode_query(query: str) -> list[float] | None:
    """Embed a single query string.

    Returns the embedding vector for *query*, or None when the embedding
    backend produces nothing usable.
    """
    vectors = embed_texts([query], batch_size=1)
    # Guard against both an empty batch and an empty first vector.
    if not vectors or not vectors[0]:
        logger.error("Failed to generate query embedding.")
        return None
    logger.info(f"Query embedding (first 5 dims): {vectors[0][:5]}")
    return vectors[0]
async def rag_autism(query: str, top_k: int = 3) -> dict:
    """Retrieve the *top_k* most similar text chunks for *query* from Weaviate.

    Args:
        query: Natural-language query to embed and search with.
        top_k: Maximum number of chunks to return.

    Returns:
        A dict with key "answer" mapping to a list of chunk texts; the list
        is empty when embedding fails, nothing is found, or retrieval errors.
    """
    qe = encode_query(query)
    if not qe:
        return {"answer": []}
    client = None
    try:
        client = init_weaviate_client()
        # NOTE(review): "weavaite_collection" looks misspelled but must match
        # the key used in config.yaml — do not "fix" it here alone.
        logger.info(
            f"Weaviate Collection : {config['rag']['weavaite_collection']}")
        coll = client.collections.get(
            config["rag"]["weavaite_collection"])  # Books
        res = coll.query.near_vector(
            near_vector=qe,
            limit=top_k,
        )
        if not getattr(res, "objects", None):
            # Bug fix: this previously returned {"❌ answer": []}, a key no
            # caller reads — answer_question_async does rag_resp.get("answer").
            return {"answer": []}
        return {
            "answer": [obj.properties.get("text", "[No Text]") for obj in res.objects]}
    except Exception as e:
        logger.error(f"❌ RAG Error: {e}")
        return {"answer": []}
    finally:
        # Always close the client connection so sockets are not leaked.
        if client:
            try:
                client.close()
                logger.info("Weaviate client connection closed successfully")
            except Exception as e:
                logger.warning(f"Error closing Weaviate client: {e}")
async def answer_question_async(query: str) -> str:
    """Async version of answer_question"""
    rewritten, relevant, _ = process_query_for_rewrite(query)
    if not relevant:
        return get_non_autism_response()
    # Retrieve with the rewritten (corrected) query.
    retrieval = await rag_autism(rewritten)
    passages = retrieval.get("answer", [])
    if not passages:
        return "Sorry, I couldn't find relevant content in the old document."
    # The bulleted combination is both the candidate answer and the text
    # that gets relevance-scored.
    bullets = "\n".join(f"- {p}" for p in passages)
    score = check_answer_autism_relevance(bullets)
    # Below 50% relevance, refuse rather than answer off-topic.
    if score < 50:
        return get_non_autism_answer_response()
    return bullets
def answer_question(query: str) -> str:
    """Blocking wrapper that runs answer_question_async to completion."""
    return asyncio.run(answer_question_async(query))
def is_greeting_or_thank(text: str) -> str:
    """Classify *text* as a "greeting", "thanks", or "" for neither.

    Recognizes common English and Arabic greeting / thank-you phrases;
    greetings take precedence when both appear.
    """
    checks = (
        ("greeting",
         r"(?i)\b(hi|hello|hey|good (morning|afternoon|evening))\b|"
         r"(صباح الخير|مساء الخير|أهلا|مرحبا|السلام عليكم)"),
        ("thanks",
         r"(?i)\b(thank you|thanks)\b|"
         r"(شكرا|شكرًا|مشكور|ألف شكر)"),
    )
    for label, pattern in checks:
        if re.search(pattern, text):
            return label
    return ""
async def main():
    """Main async function to test the RAG functionality"""
    query = "Can you tell me more about autism?"
    try:
        logger.info("Starting RAG query...")
        # Raw retrieval first, then the full end-to-end answering path.
        retrieval = await rag_autism(query=query)
        logger.info(f"RAG Result: {retrieval}")
        reply = await answer_question_async(query)
        logger.info(f"Answer: {reply}")
    except Exception as e:
        logger.error(f"Error in main: {e}")
if __name__ == "__main__":
    # Run the async main function
    asyncio.run(main())
|