|
|
import os
|
|
|
from dotenv import load_dotenv
|
|
|
from configs import load_yaml_config
|
|
|
import requests
|
|
|
import os
|
|
|
import logging
|
|
|
import time
|
|
|
import re
|
|
|
import os
|
|
|
from clients import qwen_generate
|
|
|
import requests
|
|
|
from typing import List
|
|
|
from prompt_template import (
|
|
|
Prompt_template_translation,
|
|
|
Prompt_template_autism_confidence,
|
|
|
Prompt_template_autism_rewriter,
|
|
|
Prompt_template_answer_autism_relevance
|
|
|
)
|
|
|
|
|
|
|
|
|
# Load variables from a local .env file *before* reading any of them; reading
# os.getenv() first (as this module previously did) means values that exist
# only in .env are never seen by the constants below.
load_dotenv()

SILICONFLOW_API_KEY = os.getenv("SILICONFLOW_API_KEY")
SILICONFLOW_EMBEDDING_URL = os.getenv("SILICONFLOW_EMBEDDING_URL")

# Project configuration. NOTE(review): relative path — presumably resolved
# against the current working directory by configs.load_yaml_config; verify.
config = load_yaml_config("config.yaml")

try:
    from logger.custom_logger import CustomLoggerTracker

    custom_log = CustomLoggerTracker()
    logger = custom_log.get_logger("query_utils")
except ImportError:
    # The project's custom logging package is not importable: fall back to a
    # standard-library logger under the same name.
    logger = logging.getLogger("query_utils")
|
|
|
|
|
|
|
|
|
def _log(process_log: List[str], message: str, level: str = "info") -> None:
    """Record *message* in *process_log* and forward it to the module logger.

    Args:
        process_log: Caller-owned list of messages; *message* is appended in place.
        message: Text to record and log.
        level: ``"error"`` or ``"warning"`` select the logger method of the same
            name; any other value logs at INFO.
    """
    process_log.append(message)
    # Only two levels are special-cased; everything else is treated as info.
    method_name = level if level in ("error", "warning") else "info"
    getattr(logger, method_name)(message)
|
|
|
|
|
|
|
|
|
|
|
|
def enhanced_autism_relevance_check(query: str) -> dict:
    """Score how autism-related *query* is and map the score to a routing tier.

    The query is sent to the LLM via ``Prompt_template_autism_confidence``. The
    first run of digits in the reply is taken as the score (0 if the reply has
    no digits), clamped into 0..100.

    Tiers (the first matching lower bound wins):
        >= 80  directly_autism_related        -> accept_as_is
        >= 60  highly_autism_relevant         -> accept_as_is
        >= 40  significantly_autism_relevant  -> rewrite_for_autism
        >= 20  moderately_autism_relevant     -> rewrite_for_autism
        >= 10  somewhat_autism_relevant       -> conditional_rewrite
        else   not_autism_relevant            -> reject

    Args:
        query: The user query to classify.

    Returns:
        Dict with keys ``score``, ``category``, ``action`` and ``reasoning``.
        On any exception the error is logged and a dict with score 0,
        category ``"error"`` and action ``"reject"`` is returned.
    """
    # (lower bound, category, action, reasoning), checked top to bottom.
    tiers = (
        (80, "directly_autism_related", "accept_as_is",
         "Directly mentions autism or autism-specific topics"),
        (60, "highly_autism_relevant", "accept_as_is",
         "Core autism symptoms or characteristics"),
        (40, "significantly_autism_relevant", "rewrite_for_autism",
         "Common comorbidity or autism-related issue"),
        (20, "moderately_autism_relevant", "rewrite_for_autism",
         "Broader developmental or family concern related to autism"),
        (10, "somewhat_autism_relevant", "conditional_rewrite",
         "General topic with potential autism applications"),
    )
    below_all_tiers = ("not_autism_relevant", "reject", "Not related to autism or autism care")

    try:
        logger.info(f"Enhanced autism relevance check for: '{query[:50]}...'")
        reply = qwen_generate(Prompt_template_autism_confidence.format(query=query))

        leading_number = re.search(r'\d+', reply)
        raw_score = int(leading_number.group()) if leading_number else 0
        score = min(100, max(0, raw_score))

        category, action, reasoning = next(
            ((cat, act, why) for floor, cat, act, why in tiers if score >= floor),
            below_all_tiers,
        )

        result = dict(score=score, category=category, action=action, reasoning=reasoning)
        logger.info(f"Enhanced relevance result: {result}")
        return result
    except Exception as exc:
        logger.error(f"Error in enhanced_autism_relevance_check: {exc}")
        return dict(
            score=0,
            category="error",
            action="reject",
            reasoning="Error during processing",
        )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def check_autism_confidence(query: str) -> int:
    """Return an LLM-estimated 0-100 confidence that *query* is about autism.

    The first run of digits in the LLM reply is used as the score (0 if the
    reply contains no digits) and clamped into 0..100.

    Args:
        query: The user query to score.

    Returns:
        The clamped score. Returns 0 if anything raises (the error is logged).
    """
    try:
        logger.info(f"Checking autism confidence for query: '{query[:50]}...'")
        reply = qwen_generate(Prompt_template_autism_confidence.format(query=query))

        first_number = re.search(r'\d+', reply)
        raw_score = int(first_number.group()) if first_number is not None else 0
        score = min(100, max(0, raw_score))

        logger.info(f"Autism confidence score: {score}")
        return score
    except Exception as exc:
        logger.error(f"Error in check_autism_confidence: {exc}")
        return 0
|
|
|
|
|
|
|
|
|
|
|
|
def rewrite_query_for_autism(query: str) -> str:
    """Ask the LLM to rephrase *query* so it is framed around autism.

    If the LLM returns the ``"Error"`` string or only whitespace, or if anything
    raises, a template question built from the lower-cased query is returned.

    Args:
        query: The query to rewrite.

    Returns:
        The stripped LLM rewrite, or the template fallback question.
    """
    def fallback() -> str:
        # Template question used whenever the LLM rewrite is unusable.
        return f"How does autism relate to {query.lower()}?"

    try:
        logger.info(f"Rewriting query for autism: '{query[:50]}...'")
        candidate = qwen_generate(Prompt_template_autism_rewriter.format(query=query))

        if candidate != "Error" and candidate.strip():
            rewritten = candidate.strip()
        else:
            logger.warning("Rewriting failed, using fallback")
            rewritten = fallback()

        logger.info(f"Query rewritten to: '{rewritten[:50]}...'")
        return rewritten
    except Exception as exc:
        logger.error(f"Error in rewrite_query_for_autism: {exc}")
        return fallback()
|
|
|
|
|
|
|
|
|
|
|
|
def check_answer_autism_relevance(answer: str) -> int:
    """Return an LLM-estimated 0-100 score of how autism-relevant *answer* is.

    Uses ``Prompt_template_answer_autism_relevance``. The first run of digits in
    the reply is the score (0 if none), clamped into 0..100.

    Args:
        answer: The generated answer text to assess.

    Returns:
        The clamped score. Returns 0 if anything raises (the error is logged).
    """
    try:
        logger.info(f"Checking answer autism relevance for: '{answer[:50]}...'")
        prompt = Prompt_template_answer_autism_relevance.format(answer=answer)
        reply = qwen_generate(prompt)

        if (match := re.search(r'\d+', reply)) is None:
            raw_score = 0
        else:
            raw_score = int(match.group())
        relevance = min(100, max(0, raw_score))

        logger.info(f"Answer autism relevance score: {relevance}")
        return relevance
    except Exception as exc:
        logger.error(f"Error in check_answer_autism_relevance: {exc}")
        return 0
|
|
|
|
|
|
|
|
|
|
|
|
def process_query_for_rewrite(query: str) -> tuple[str, bool, str]:
    """Normalize a user query and decide whether it should be answered.

    Pipeline:
      1. Translate/correct the query with the LLM. The original query is used
         instead when the LLM returns the ``"Error"`` string, non-text output,
         or only whitespace.
      2. Score autism relevance with ``enhanced_autism_relevance_check``.
      3. Depending on the returned ``action``: accept the corrected query,
         rewrite it with ``rewrite_query_for_autism``, or reject it.

    Args:
        query: Raw user query.

    Returns:
        ``(processed_query, is_accepted, message)``. ``message`` is always the
        empty string in the current implementation. Any exception is logged and
        yields ``(query, False, "")``.
    """
    try:
        logger.info(f"Processing query with enhanced confidence logic: '{query[:50]}...'")

        logger.info("Step 1: Translating/correcting query")
        corrected_query = qwen_generate(Prompt_template_translation.format(query=query))
        # Treat the "Error" sentinel as well as blank/non-text output as a failed
        # translation (mirroring the check rewrite_query_for_autism applies to
        # its LLM output), so an empty string is never scored and rejected.
        if (
            not isinstance(corrected_query, str)
            or corrected_query == "Error"
            or not corrected_query.strip()
        ):
            logger.warning("Translation failed, using original query")
            corrected_query = query
        else:
            corrected_query = corrected_query.strip()

        logger.info("Step 2: Enhanced autism relevance checking")
        relevance_result = enhanced_autism_relevance_check(corrected_query)
        confidence_score = relevance_result["score"]
        action = relevance_result["action"]
        reasoning = relevance_result["reasoning"]
        logger.info(f"Relevance analysis: {confidence_score}% - {reasoning}")

        if action == "accept_as_is":
            return corrected_query, True, ""
        if action in ("rewrite_for_autism", "conditional_rewrite"):
            return rewrite_query_for_autism(corrected_query), True, ""
        # "reject" — also what enhanced_autism_relevance_check returns on error.
        return corrected_query, False, ""
    except Exception as e:
        logger.error(f"Error in process_query_for_rewrite: {e}")
        return query, False, ""
|
|
|
|
|
|
|
|
|
|
|
|
def get_non_autism_response() -> str:
    """Return the canned reply sent when a user's question is not about autism.

    The message introduces the assistant ("Wisal"), explains its focus on
    autism and Autism Spectrum Disorders, and invites an autism-related question.
    """
    # The literal previously ended in a stray Greek "pi" (U+03C0). That is how the
    # UTF-8 lead byte 0xF0 of a 4-byte emoji renders under a Greek code page
    # (e.g. Windows-1253), so the file was most likely re-encoded. The intended
    # emoji is presumed to be U+1F60A; confirm with the product copy. It is
    # written as an escape so the source stays ASCII and cannot be re-mangled.
    return ("Hi there! I appreciate you reaching out to me. I'm Wisal, and I specialize specifically in autism and Autism Spectrum Disorders. "
            "I noticed your question isn't quite related to autism topics. I'd love to help you, but I'm most effective when answering "
            "questions about autism, ASD, autism support strategies, therapies, or related concerns.\n\n"
            "Could you try asking me something about autism instead? I'm here and ready to help with any autism-related questions you might have! \U0001F60A")
|
|
|
|
|
|
def get_non_autism_answer_response() -> str:
    """Return the canned reply used when the retrieved content is not about autism.

    The message tells the user the document content found is not autism-related
    and asks for a more autism-specific question.
    """
    # The literal previously ended in a stray Greek "pi" (U+03C0), which is most
    # likely a mis-decoded UTF-8 emoji lead byte (0xF0). The intended emoji is
    # presumed to be U+1F60A; confirm. Escaped to keep the source ASCII.
    return ("I'm sorry, but the information I found in the document doesn't seem to be related to autism or Autism Spectrum Disorders. "
            "Since I'm Wisal, your autism specialist, I want to make sure I'm providing you with relevant, autism-focused information. "
            "Could you try asking a question that's more specifically about autism? I'm here to help with any autism-related topics! \U0001F60A")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: send one fixed question to the LLM client and log the reply.
    logger.info("Starting Query utils ...")

    answer = qwen_generate("what is autism ?")
    logger.info(f"Answer: {answer}")
|
|
|
|