import asyncio
import logging
import os
import re
import warnings
from typing import Dict, List

from dotenv import load_dotenv

from clients import init_weaviate_client
from configs import load_yaml_config
from query_utils import (get_non_autism_response, get_non_autism_answer_response,
                         check_answer_autism_relevance, process_query_for_rewrite)
from rag_steps import embed_texts

# Suppress Protobuf version mismatch warnings
warnings.filterwarnings("ignore", category=UserWarning,
                        module="google.protobuf.runtime_version")


load_dotenv()


# Load configuration from YAML
config = load_yaml_config("config.yaml")

try:
    from logger.custom_logger import CustomLoggerTracker
    custom_log = CustomLoggerTracker()
    logger = custom_log.get_logger("rag_utils")
except ImportError:
    # Fall back to standard logging if the custom logger is unavailable
    logger = logging.getLogger("rag_utils")


# ---------------------------
# Environment & Globals
# ---------------------------
env = os.getenv("ENVIRONMENT", "production")
SESSION_ID = "default"
pending_clarifications: Dict[str, str] = {}
SILICONFLOW_API_KEY = os.getenv("SILICONFLOW_API_KEY", "")
SILICONFLOW_URL = os.getenv("SILICONFLOW_URL", "").strip()

if not SILICONFLOW_API_KEY:
    logger.warning(
        "SILICONFLOW_API_KEY is not set; LLM/reranker calls may fail.")
if not SILICONFLOW_URL:
    logger.warning(
        "SILICONFLOW_URL is not set; the OpenAI-compatible client base_url cannot be configured.")


def encode_query(query: str) -> list[float] | None:
    """Generate a single embedding vector for a query string."""
    embs = embed_texts([query], batch_size=1)
    if embs and embs[0]:
        logger.info(f"Query embedding (first 5 dims): {embs[0][:5]}")
        return embs[0]
    logger.error("Failed to generate query embedding.")
    return None


async def rag_autism(query: str, top_k: int = 3) -> Dict[str, List[str]]:
    """Retrieve the top_k most similar text chunks from Weaviate for a query."""
    qe = encode_query(query)
    if not qe:
        return {"answer": []}
    client = None
    try:
        client = init_weaviate_client()
        logger.info(
            f"Weaviate collection: {config['rag']['weavaite_collection']}")
        # NOTE: 'weavaite_collection' mirrors the (misspelled) key in config.yaml
        coll = client.collections.get(
            config["rag"]["weavaite_collection"])  # Books

        # Nearest-vector search over the collection; if GRPC deadline-exceeded
        # errors occur here, see the client-timeout sketch after this function
        res = coll.query.near_vector(
            near_vector=qe,
            limit=top_k,
        )

        if not getattr(res, "objects", None):
            return {"answer": []}
        return {"answer": [obj.properties.get("text", "[No Text]")
                           for obj in res.objects]}
    except Exception as e:
        logger.error(f"❌ RAG Error: {e}")
        return {"answer": []}
    finally:
        # Always close the client connection
        if client:
            try:
                client.close()
                logger.info("Weaviate client connection closed successfully")
            except Exception as e:
                logger.warning(f"Error closing Weaviate client: {e}")


async def answer_question_async(query: str) -> str:
    """Async version of answer_question"""
    corrected_query, is_autism_related, _ = process_query_for_rewrite(query)
    if not is_autism_related:
        return get_non_autism_response()

    # Use the corrected query for retrieval
    rag_resp = await rag_autism(corrected_query)
    chunks = rag_resp.get("answer", [])
    if not chunks:
        return "Sorry, I couldn't find relevant content in the old document."

    # Combine chunks into a single answer for relevance checking
    combined_answer = "\n".join(f"- {c}" for c in chunks)
    answer_relevance_score = check_answer_autism_relevance(combined_answer)
    # If answer relevance is below 50%, refuse the answer (updated threshold for enhanced scoring)
    if answer_relevance_score < 50:
        return get_non_autism_answer_response()
    # If sufficiently autism-related, return the answer
    return combined_answer


def answer_question(query: str) -> str:
    """Synchronous wrapper for answer_question_async"""
    return asyncio.run(answer_question_async(query))
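
# NOTE: asyncio.run() raises RuntimeError when called from a thread that is
# already running an event loop (e.g. inside Jupyter or another async app);
# in that context, await answer_question_async(query) directly.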



def is_greeting_or_thank(text: str) -> str:
    """Classify text as a greeting, thanks, or neither (English and Arabic)."""
    # English greetings plus Arabic ones ("good morning", "good evening",
    # "hello", "welcome", "peace be upon you")
    if re.search(r"(?i)\b(hi|hello|hey|good (morning|afternoon|evening))\b|"
                 r"(صباح الخير|مساء الخير|أهلا|مرحبا|السلام عليكم)", text):
        return "greeting"
    # English thanks plus Arabic equivalents ("thanks", "much obliged",
    # "a thousand thanks")
    elif re.search(r"(?i)\b(thank you|thanks)\b|"
                   r"(شكرا|شكرًا|مشكور|ألف شكر)", text):
        return "thanks"
    return ""


async def main():
    """Main async function to test the RAG functionality"""
    query = "Can you tell me more about autism?"
    try:
        logger.info("Starting RAG query...")
        result = await rag_autism(query=query)
        logger.info(f"RAG Result: {result}")

        # Test the answer_question function too
        answer = await answer_question_async(query)
        logger.info(f"Answer: {answer}")
    except Exception as e:
        logger.error(f"Error in main: {e}")


if __name__ == "__main__":
    # Run the async main function
    asyncio.run(main())