# File size: 8,437 bytes; 712579e (presumably the revision/commit id of this listing).
import os
from dotenv import load_dotenv
from configs import load_yaml_config
import requests
import os
import logging
import time
import re
import os
from clients import qwen_generate
import requests
from typing import List
from prompt_template import (
Prompt_template_translation,
Prompt_template_autism_confidence,
Prompt_template_autism_rewriter,
Prompt_template_answer_autism_relevance
)
# Load variables from a local .env file BEFORE reading them. The environment
# lookups below run at import time, so calling load_dotenv() afterwards would
# leave them as None whenever the values exist only in the .env file.
load_dotenv()

# SiliconFlow API key and embedding endpoint (None when the variable is unset).
SILICONFLOW_API_KEY = os.getenv("SILICONFLOW_API_KEY")
SILICONFLOW_EMBEDDING_URL = os.getenv("SILICONFLOW_EMBEDDING_URL")

# Load config from YAML.
config = load_yaml_config("config.yaml")

try:
    from logger.custom_logger import CustomLoggerTracker
    custom_log = CustomLoggerTracker()
    logger = custom_log.get_logger("query_utils")
except ImportError:
    # Fallback to standard logging if custom logger not available
    logger = logging.getLogger("query_utils")
def _log(process_log: List[str], message: str, level: str = "info") -> None:
    """Record ``message`` in ``process_log`` and forward it to the module logger.

    Args:
        process_log: List that receives the message (mutated in place).
        message: Text to record and log.
        level: ``"error"`` or ``"warning"`` select the logger method of the
            same name; any other value is logged at info level.
    """
    process_log.append(message)

    # Pick the logger method first, then emit once.
    if level == "error":
        emit = logger.error
    elif level == "warning":
        emit = logger.warning
    else:
        emit = logger.info
    emit(message)
def enhanced_autism_relevance_check(query: str) -> dict:
    """Score a query's autism relevance and map the score to an action.

    The query is inserted into ``Prompt_template_autism_confidence`` and sent
    to ``qwen_generate``. The first run of digits in the reply is taken as the
    score (0 when the reply contains no digits) and clamped to 0..100. The
    score is then matched against the tier table below.

    Args:
        query: Query text to evaluate.

    Returns:
        A dict with keys ``"score"``, ``"category"``, ``"action"`` and
        ``"reasoning"``. If anything raises, the error is logged and a dict
        with score 0, category ``"error"`` and action ``"reject"`` is returned.
    """
    # (minimum score, category, action, reasoning), highest threshold first.
    tiers = (
        (80, "directly_autism_related", "accept_as_is",
         "Directly mentions autism or autism-specific topics"),
        (60, "highly_autism_relevant", "accept_as_is",
         "Core autism symptoms or characteristics"),
        (40, "significantly_autism_relevant", "rewrite_for_autism",
         "Common comorbidity or autism-related issue"),
        (20, "moderately_autism_relevant", "rewrite_for_autism",
         "Broader developmental or family concern related to autism"),
        (10, "somewhat_autism_relevant", "conditional_rewrite",
         "General topic with potential autism applications"),
    )
    try:
        logger.info(f"Enhanced autism relevance check for: '{query[:50]}...'")
        llm_reply = qwen_generate(Prompt_template_autism_confidence.format(query=query))

        digit_runs = re.findall(r'\d+', llm_reply)
        score = int(digit_runs[0]) if digit_runs else 0
        score = min(max(score, 0), 100)

        # Default applies when the score is below every tier threshold.
        category = "not_autism_relevant"
        action = "reject"
        reasoning = "Not related to autism or autism care"
        for floor, tier_category, tier_action, tier_reasoning in tiers:
            if score >= floor:
                category, action, reasoning = tier_category, tier_action, tier_reasoning
                break

        outcome = {
            "score": score,
            "category": category,
            "action": action,
            "reasoning": reasoning,
        }
        logger.info(f"Enhanced relevance result: {outcome}")
        return outcome
    except Exception as e:
        logger.error(f"Error in enhanced_autism_relevance_check: {e}")
        return {
            "score": 0,
            "category": "error",
            "action": "reject",
            "reasoning": "Error during processing",
        }
def check_autism_confidence(query: str) -> int:
    """Return a 0-100 autism confidence score for ``query``.

    The query is inserted into ``Prompt_template_autism_confidence`` and sent
    to ``qwen_generate``; the first run of digits in the reply becomes the
    score, clamped to 0..100.

    Args:
        query: Query text to evaluate.

    Returns:
        The clamped score, or 0 when the reply has no digits or when any
        exception is raised (the exception is logged).
    """
    try:
        logger.info(f"Checking autism confidence for query: '{query[:50]}...'")
        reply = qwen_generate(Prompt_template_autism_confidence.format(query=query))

        first_number = re.search(r'\d+', reply)
        score = int(first_number.group()) if first_number else 0

        # Clamp into the 0..100 range.
        if score > 100:
            score = 100
        elif score < 0:
            score = 0

        logger.info(f"Autism confidence score: {score}")
        return score
    except Exception as e:
        logger.error(f"Error in check_autism_confidence: {e}")
        return 0
def rewrite_query_for_autism(query: str) -> str:
    """Rewrite ``query`` using the autism rewriter prompt.

    The query is inserted into ``Prompt_template_autism_rewriter`` and sent to
    ``qwen_generate``. The stripped reply is returned unless it equals
    ``"Error"`` or is blank, in which case a fixed fallback question built
    from the lower-cased query is used.

    Args:
        query: Query text to rewrite.

    Returns:
        The rewritten query, or the fallback question when rewriting fails or
        any exception is raised (the exception is logged).
    """
    try:
        logger.info(f"Rewriting query for autism: '{query[:50]}...'")
        candidate = qwen_generate(Prompt_template_autism_rewriter.format(query=query))

        if candidate != "Error" and candidate.strip():
            candidate = candidate.strip()
        else:
            logger.warning("Rewriting failed, using fallback")
            candidate = f"How does autism relate to {query.lower()}?"

        logger.info(f"Query rewritten to: '{candidate[:50]}...'")
        return candidate
    except Exception as e:
        logger.error(f"Error in rewrite_query_for_autism: {e}")
        return f"How does autism relate to {query.lower()}?"
def check_answer_autism_relevance(answer: str) -> int:
    """Return a 0-100 autism relevance score for ``answer``.

    The answer is inserted into ``Prompt_template_answer_autism_relevance``
    and sent to ``qwen_generate``; the first run of digits in the reply
    becomes the score, clamped to 0..100.

    Args:
        answer: Answer text to evaluate.

    Returns:
        The clamped score, or 0 when the reply has no digits or when any
        exception is raised (the exception is logged).
    """
    try:
        logger.info(f"Checking answer autism relevance for: '{answer[:50]}...'")
        reply = qwen_generate(Prompt_template_answer_autism_relevance.format(answer=answer))

        first_number = re.search(r'\d+', reply)
        if first_number is None:
            score = 0
        else:
            score = int(first_number.group())
        score = min(max(score, 0), 100)

        logger.info(f"Answer autism relevance score: {score}")
        return score
    except Exception as e:
        logger.error(f"Error in check_answer_autism_relevance: {e}")
        return 0
def process_query_for_rewrite(query: str) -> tuple[str, bool, str]:
    """Translate/correct a query, check its autism relevance, and rewrite if needed.

    Steps:
      1. The query is inserted into ``Prompt_template_translation`` and sent
         to ``qwen_generate``. If the reply is ``"Error"``, blank, or not a
         string, the original query is used instead; otherwise the stripped
         reply is used.
      2. ``enhanced_autism_relevance_check`` supplies a score, an action and
         a reasoning string for the resulting query.
      3. ``"accept_as_is"`` returns the corrected query; ``"rewrite_for_autism"``
         and ``"conditional_rewrite"`` return the output of
         ``rewrite_query_for_autism``; any other action rejects the query.

    Args:
        query: Raw query text.

    Returns:
        A tuple ``(query_text, accepted, message)``. ``accepted`` is True when
        the query was accepted as-is or rewritten, and False when it was
        rejected or an exception occurred (in which case the original
        ``query`` is returned). ``message`` is always an empty string here.
    """
    try:
        logger.info(f"Processing query with enhanced confidence logic: '{query[:50]}...'")

        logger.info("Step 1: Translating/correcting query")
        corrected_query = qwen_generate(Prompt_template_translation.format(query=query))
        # Treat blank or non-string output as failure too, matching how
        # rewrite_query_for_autism handles blank model output; otherwise an
        # empty string would be scored and returned as the query.
        if (
            not isinstance(corrected_query, str)
            or corrected_query == "Error"
            or not corrected_query.strip()
        ):
            logger.warning("Translation failed, using original query")
            corrected_query = query
        else:
            corrected_query = corrected_query.strip()

        logger.info("Step 2: Enhanced autism relevance checking")
        relevance_result = enhanced_autism_relevance_check(corrected_query)
        confidence_score = relevance_result["score"]
        action = relevance_result["action"]
        reasoning = relevance_result["reasoning"]
        logger.info(f"Relevance analysis: {confidence_score}% - {reasoning}")

        if action == "accept_as_is":
            return corrected_query, True, ""
        if action in ("rewrite_for_autism", "conditional_rewrite"):
            rewritten_query = rewrite_query_for_autism(corrected_query)
            return rewritten_query, True, ""
        # Any other action (e.g. "reject") is not accepted.
        return corrected_query, False, ""
    except Exception as e:
        logger.error(f"Error in process_query_for_rewrite: {e}")
        return query, False, ""
def get_non_autism_response() -> str:
    """Return the fixed message telling the user their question is not autism-related.

    The text is always in English regardless of input language.
    """
    intro = (
        "Hi there! I appreciate you reaching out to me. I'm Wisal, and I "
        "specialize specifically in autism and Autism Spectrum Disorders. "
    )
    scope = (
        "I noticed your question isn't quite related to autism topics. I'd love "
        "to help you, but I'm most effective when answering "
        "questions about autism, ASD, autism support strategies, therapies, or "
        "related concerns.\n\n"
    )
    invitation = (
        "Could you try asking me something about autism instead? I'm here and "
        "ready to help with any autism-related questions you might have! π"
    )
    return "".join((intro, scope, invitation))
def get_non_autism_answer_response() -> str:
    """Return the fixed message used when the found information is not autism-related.

    The text is always in English regardless of input language.
    """
    sentences = (
        "I'm sorry, but the information I found in the document doesn't seem to "
        "be related to autism or Autism Spectrum Disorders. ",
        "Since I'm Wisal, your autism specialist, I want to make sure I'm "
        "providing you with relevant, autism-focused information. ",
        "Could you try asking a question that's more specifically about autism? "
        "I'm here to help with any autism-related topics! π",
    )
    return "".join(sentences)
if __name__ == "__main__":
    # Manual smoke test: send one sample question through qwen_generate and
    # log whatever it returns.
    logger.info("Starting Query utils ...")
    answer = qwen_generate("what is autism ?")
    logger.info(f"Answer: {answer}")
|