|
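"""RAG utilities for the autism question-answering pipeline: query embedding,
Weaviate vector retrieval, answer relevance filtering, and greeting detection."""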
|
from typing import List, Dict

import asyncio
import logging
import os
import re
import time
import warnings

import google.generativeai as genai
import requests
import weaviate
from dotenv import load_dotenv
from openai import OpenAI
from weaviate.classes.init import Auth

from clients import (init_weaviate_client, siliconflow_qwen_generate_content,
                     groq_qwen_generate_content)
from configs import load_yaml_config
from query_utils import (get_non_autism_response, get_non_autism_answer_response,
                         check_answer_autism_relevance, process_query_for_rewrite)
from rag_steps import embed_texts
# Suppress noisy protobuf runtime-version UserWarnings.
warnings.filterwarnings("ignore", category=UserWarning,
                        module="google.protobuf.runtime_version")

# Load environment variables from a local .env file, if present.
load_dotenv()

config = load_yaml_config("config.yaml")

# Prefer the project's custom logger; fall back to the standard library logger.
try:
    from logger.custom_logger import CustomLoggerTracker

    custom_log = CustomLoggerTracker()
    logger = custom_log.get_logger("rag_utils")
except ImportError:
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("rag_utils")

env = os.getenv("ENVIRONMENT", "production")
SESSION_ID = "default"
# Presumably maps session IDs to pending clarification prompts for callers of this module.
pending_clarifications: Dict[str, str] = {}

SILICONFLOW_API_KEY = os.getenv("SILICONFLOW_API_KEY", "")
SILICONFLOW_URL = os.getenv("SILICONFLOW_URL", "").strip()

if not SILICONFLOW_API_KEY:
    logger.warning(
        "SILICONFLOW_API_KEY is not set. LLM/Reranker calls may fail.")
if not SILICONFLOW_URL:
    logger.warning(
        "SILICONFLOW_URL is not set. OpenAI client base_url will not work.")


def encode_query(query: str) -> list[float] | None:
    """Generate a single embedding vector for a query string."""
    embs = embed_texts([query], batch_size=1)
    if embs and embs[0]:
        logger.info(f"Query embedding (first 5 dims): {embs[0][:5]}")
        return embs[0]
    logger.error("Failed to generate query embedding.")
    return None


async def rag_autism(query: str, top_k: int = 3) -> dict:
    """Retrieve the top_k most similar text chunks for a query from Weaviate."""
    qe = encode_query(query)
    if not qe:
        return {"answer": []}

    client = None
    try:
        client = init_weaviate_client()
        logger.info(
            f"Weaviate Collection : {config['rag']['weavaite_collection']}")
        coll = client.collections.get(config["rag"]["weavaite_collection"])
        res = coll.query.near_vector(near_vector=qe, limit=top_k)
        if not getattr(res, "objects", None):
            return {"answer": []}
        return {
            "answer": [obj.properties.get("text", "[No Text]") for obj in res.objects]}
    except Exception as e:
        logger.error(f"❌ RAG Error: {e}")
        return {"answer": []}
    finally:
        if client:
            try:
                client.close()
                logger.info("Weaviate client connection closed successfully")
            except Exception as e:
                logger.warning(f"Error closing Weaviate client: {e}")


async def answer_question_async(query: str) -> str:
    """Async version of answer_question."""
    corrected_query, is_autism_related, _ = process_query_for_rewrite(query)
    if not is_autism_related:
        return get_non_autism_response()

    rag_resp = await rag_autism(corrected_query)
    chunks = rag_resp.get("answer", [])
    if not chunks:
        return "Sorry, I couldn't find relevant content in the documents."

    combined_answer = "\n".join(f"- {c}" for c in chunks)
    answer_relevance_score = check_answer_autism_relevance(combined_answer)
    if answer_relevance_score < 50:
        return get_non_autism_answer_response()

    return combined_answer


def answer_question(query: str) -> str:
    """Synchronous wrapper for answer_question_async."""
    return asyncio.run(answer_question_async(query))


def is_greeting_or_thank(text: str) -> str:
    """Classify text as a "greeting" or "thanks" (English or Arabic); return an empty string otherwise."""
    if re.search(r"(?i)\b(hi|hello|hey|good (morning|afternoon|evening))\b|"
                 r"(صباح الخير|مساء الخير|أهلا|مرحبا|السلام عليكم)", text):
        return "greeting"
    elif re.search(r"(?i)\b(thank you|thanks)\b|"
                   r"(شكرا|شكرًا|مشكور|ألف شكر)", text):
        return "thanks"
    return ""


async def main():
    """Main async function to test the RAG functionality."""
    query = "Can you tell me more about autism?"
    try:
        logger.info("Starting RAG query...")
        result = await rag_autism(query=query)
        logger.info(f"RAG Result: {result}")

        answer = await answer_question_async(query)
        logger.info(f"Answer: {answer}")
    except Exception as e:
        logger.error(f"Error in main: {e}")


if __name__ == "__main__":
    asyncio.run(main())
|
|
|