|
|
from logger.custom_logger import CustomLoggerTracker |
|
|
from dotenv import load_dotenv |
|
|
import requests |
|
|
from langdetect import detect |
|
|
from web_search import search_autism |
|
|
from rag_utils import rag_autism |
|
|
from clients import qwen_generate |
|
|
from query_utils import process_query_for_rewrite |
|
|
from rag_utils import is_greeting_or_thank |
|
|
from prompt_template import * |
|
|
import os |
|
|
import re |
|
|
import time |
|
|
import asyncio |
|
|
from typing import List, Dict, Optional |
|
|
from configs import load_yaml_config |
|
|
from query_utils import * |
|
|
|
|
|
# Project-wide configuration (model names, endpoints) loaded from the repo's YAML file.
config = load_yaml_config("config.yaml")


# Pull environment variables (API keys, service URLs) from a local .env file into os.environ.
load_dotenv()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Central logger for this module, routed through the project's custom tracker.
custom_log = CustomLoggerTracker()
logger = custom_log.get_logger("Pipeline Query")
logger.info("Logger initialized for Pipeline Query module")


# Default session identifier used when callers do not supply one.
SESSION_ID = "default"
# Maps session_id -> rewritten query awaiting a yes/no confirmation from the user.
pending_clarifications: Dict[str, str] = {}
# SiliconFlow credentials/endpoints, read once at import time from the environment.
SILICONFLOW_API_KEY = os.getenv("SILICONFLOW_API_KEY", "")
SILICONFLOW_URL = os.getenv("SILICONFLOW_URL", "").strip()
SILICONFLOW_CHAT_URL = os.getenv(
    "SILICONFLOW_CHAT_URL", "https://api.siliconflow.com/v1/chat/completions").strip()

# Warn early (but do not fail) so downstream API errors are easier to diagnose.
if not SILICONFLOW_API_KEY:
    logger.warning(
        "SILICONFLOW_API_KEY is not set. LLM/Reranker calls may fail.")
if not SILICONFLOW_URL:
    logger.warning(
        "SILICONFLOW_URL is not set. OpenAI client base_url will not work.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import re |
|
|
|
|
|
def clean_pipeline_result(result: str) -> str:
    """Strip LLM/markup artifacts from a pipeline answer.

    Removes hidden ``<think>...</think>`` reasoning blocks, ``<div>``
    wrappers and ``<br>`` tags, collapses excess blank lines, and falls
    back to an apology message when the cleaned text is missing or too
    short to be useful.
    """
    if not result:
        return "I apologize, but I couldn't generate a response. Please try again."

    text = str(result)

    # Drop chain-of-thought sections emitted by some models.
    text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)

    # Unwrap simple HTML markup; <br> becomes a newline.
    for pattern, replacement in (
        (r'<div[^>]*>', ''),
        (r'</div>', ''),
        (r'<br\s*/?>', '\n'),
    ):
        text = re.sub(pattern, replacement, text)

    # Collapse runs of three-or-more newlines into one blank line.
    text = re.sub(r'\n\s*\n\s*\n', '\n\n', text)
    text = text.strip()

    if len(text.strip()) < 10:
        return "I apologize, but there was an issue generating a complete response. Please try again."
    return text
|
|
|
|
|
|
|
|
def clean_hallucination_score(raw_score_text: str) -> int:
    """Extract a 1-5 hallucination score from a raw LLM reply.

    Tolerates decorated responses such as "Score: 5**" or "**Score: 4**".
    Returns the neutral score 3 when no digits are present or when
    parsing fails for any reason.
    """
    try:
        match = re.search(r'\d+', str(raw_score_text))
        if match is None:
            logger.warning(f"No numbers found in hallucination score: {raw_score_text}")
            return 3
        # Clamp the first number found into the valid 1..5 range.
        return min(5, max(1, int(match.group())))
    except Exception as e:
        logger.error(f"Error parsing hallucination score '{raw_score_text}': {e}")
        return 3
|
|
|
|
|
|
|
|
def _log(process_log: List[str], message: str, level: str = "info") -> None:
    """Append to process_log AND send to the central logger."""
    process_log.append(message)
    # Dispatch to the matching logger method; anything unknown logs as info.
    if level == "error":
        emit = logger.error
    elif level == "warning":
        emit = logger.warning
    else:
        emit = logger.info
    emit(message)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _run_coro_blocking(coro):
    """Execute *coro* from synchronous code and return its result.

    asyncio.run() cannot be used while a loop is already running in this
    thread, and loop.run_until_complete() on a running loop raises
    RuntimeError (the original code did exactly that).  When a loop is
    running, the coroutine is executed on a fresh loop in a worker thread.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop running in this thread: the simple path.
        return asyncio.run(coro)
    from concurrent.futures import ThreadPoolExecutor
    with ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, coro).result()


def _rerank_documents(corrected_query, documents, process_log):
    """Rerank *documents* for *corrected_query* via the SiliconFlow API.

    Returns the highest-scoring document, or documents[0] (the directly
    generated answer) when the endpoint is unconfigured or the call fails.
    """
    fallback = documents[0]
    rerank_url = os.environ.get("SILICONFLOW_RERANKING_URL")
    if not rerank_url:
        # Previously an unguarded os.environ[...] lookup -> KeyError.
        _log(process_log, "SILICONFLOW_RERANKING_URL not set; skipping rerank.", level="warning")
        return fallback
    rerank_payload = {
        "model": config["apis_models"]["silicon_flow"]["qwen"]["rerank"],
        "query": corrected_query,
        "documents": documents,
    }
    rerank_headers = {
        "Authorization": f"Bearer {SILICONFLOW_API_KEY}",
        "Content-Type": "application/json",
    }
    _log(process_log, "Rerank: [generated, web_answer] + rag_contexts")
    _log(process_log, f"Rerank Model: {config['apis_models']['silicon_flow']['qwen']['rerank']}")
    _log(process_log, "Calling SiliconFlow rerank endpoint...")
    try:
        r = requests.post(rerank_url, json=rerank_payload, headers=rerank_headers, timeout=60)
    except requests.RequestException as exc:
        # Network failure degrades to the generated answer instead of crashing.
        _log(process_log, f"Rerank API failed: {exc}", level="warning")
        return fallback
    if not r.ok:
        _log(process_log, f"Rerank API failed: {r.text}", level="warning")
        return fallback
    ranked_docs = sorted(
        zip(r.json().get("results", []), documents),
        key=lambda pair: pair[0].get("relevance_score", 0),
        reverse=True,
    )
    _log(process_log, "Reranking succeeded.")
    _log(process_log, f"reranker docs: {ranked_docs}")
    return ranked_docs[0][1] if ranked_docs else fallback


def process_autism_pipeline(query, corrected_query, process_log, intro, start_ts):
    """Run the full autism answer pipeline for an already-validated query.

    Stages: web search -> direct LLM generation -> RAG retrieval ->
    rerank -> Wisal answer -> hallucination scoring (borderline scores
    trigger a paraphrase-and-regenerate pass) -> language detection ->
    final cleanup.

    Args:
        query: Raw user query (used only for language detection).
        corrected_query: Spell-corrected/rewritten query used for retrieval.
        process_log: Accumulator for step-by-step log lines.
        intro: Optional intro text (kept for interface compatibility).
        start_ts: ``time.time()`` captured when the request arrived.

    Returns:
        The cleaned final answer string.
    """
    step_times: Dict[str, float] = {}

    # ---- Phase 1: web search -------------------------------------------
    logger.info("Starting web search phase[1]:")
    web_search_resp = _run_coro_blocking(search_autism(corrected_query))
    web_answer = web_search_resp.get("answer", "")
    step_times["Web Search"] = time.time() - start_ts
    print("=" * 50)
    print(f"Web Answer: {web_answer}")
    print("=" * 50)
    _log(process_log, f"✅ Web Search answer: {web_answer}")

    # ---- Phase 2: direct LLM generation --------------------------------
    logger.info("Starting LLM generation phase[2]:")
    gen_prompt = Prompt_template_LLM_Generation.format(new_query=corrected_query)
    t0 = time.time()
    generated = qwen_generate(gen_prompt)
    step_times["LLM Generation"] = time.time() - t0
    _log(process_log, f"✅ LLM Generated: {generated}")

    # ---- Phase 3: RAG retrieval ----------------------------------------
    logger.info("Starting RAG retrieval phase[3]: ")
    t0 = time.time()
    rag_resp = _run_coro_blocking(rag_autism(corrected_query, top_k=3))
    rag_contexts = rag_resp.get("answer", [])
    step_times["RAG Retrieval"] = time.time() - t0
    _log(process_log, f"✅ RAG Contexts: {rag_contexts}")

    # ---- Phase 4: rerank candidates (fallback: generated answer) -------
    logger.info("Starting reranking phase")
    t0 = time.time()
    reranked = _rerank_documents(
        corrected_query, [generated, web_answer] + rag_contexts, process_log)
    step_times["Reranking"] = time.time() - t0
    _log(process_log, f"✅ Reranked doc: {reranked}")

    # ---- Phase 5: compose the Wisal answer -----------------------------
    logger.info("Generating Wisal answer")
    wisal_prompt = Prompt_template_Wisal.format(
        new_query=corrected_query, document=reranked)
    t0 = time.time()
    wisal = qwen_generate(wisal_prompt)
    step_times["Wisal Answer"] = time.time() - t0
    _log(process_log, f"✅ Wisal Answer: {wisal}")

    # ---- Phase 6: hallucination scoring --------------------------------
    logger.info("Running hallucination detection")
    halluc_prompt = Prompt_template_Halluciations.format(
        new_query=corrected_query, answer=wisal, document=reranked)
    t0 = time.time()
    halluc_raw = qwen_generate(halluc_prompt)
    step_times["Hallucination Detection"] = time.time() - t0
    _log(process_log, f"✅ Hallucination Score Raw: {halluc_raw}")

    score = clean_hallucination_score(halluc_raw)
    _log(process_log, f"✅ Cleaned Hallucination Score: {score}")

    if score in (2, 3):
        # Borderline score: paraphrase the source and regenerate the answer.
        logger.info("Hallucination detected, running paraphrasing")
        t0 = time.time()
        _log(process_log, "Score indicates paraphrasing path.")
        paraphrased = qwen_generate(
            Prompt_template_paraphrasing.format(document=reranked))
        wisal = qwen_generate(
            Prompt_template_Wisal.format(
                new_query=corrected_query, document=paraphrased))
        step_times["Paraphrasing & Re-Wisal"] = time.time() - t0
        _log(process_log, f"Paraphrased Wisal: {wisal}")

    # ---- Phase 7: language detection (output stays English) ------------
    logger.info("Checking if translation is needed")
    t0 = time.time()
    detected_lang = "en"
    if query.strip():
        try:
            detected_lang = detect(query)
        except Exception:
            # langdetect raises on very short/ambiguous text; default to English.
            detected_lang = "en"

    result = wisal
    logger.info(f"Input language detected as: {detected_lang}, but output forced to English")
    _log(process_log, f"Input language: {detected_lang}, Output language: English (forced)")
    step_times["Language Detection & Translation"] = time.time() - t0
    _log(process_log, f"✅ Final Result: {result}")

    for step, duration in step_times.items():
        _log(process_log, f"⏱️ {step} completed in {duration:.2f} seconds")
    _save_process_log(process_log)

    # RTL direction only affects the logged HTML wrapper, not the return value.
    text_dir = "rtl" if detected_lang in ["ar", "fa", "ur", "he"] else "ltr"
    cleaned_result = clean_pipeline_result(result)
    logger.info(f'<div dir="{text_dir}">{result}</div>')
    logger.info("Pipeline completed successfully")
    return cleaned_result
|
|
|
|
|
|
|
|
def _save_process_log(log_lines: List[str], filename: Optional[str] = None) -> None:
    """Persist the per-request process log under ``logs/`` next to this module.

    A timestamped file name is generated when *filename* is not supplied.
    Entries are separated by a blank line for readability.
    """
    import datetime

    target_dir = os.path.join(os.path.dirname(__file__), "logs")
    os.makedirs(target_dir, exist_ok=True)

    if not filename:
        stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        filename = f"log_{stamp}.txt"

    destination = os.path.join(target_dir, filename)
    with open(destination, "w", encoding="utf-8") as handle:
        handle.writelines(str(entry) + "\n\n" for entry in log_lines)
    logger.info(f"Process log saved to {destination}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def process_query(query: str, first_turn: bool = False, session_id: str = "default"):
    """Route one user turn through the Wisal front door.

    Handles (in order): a pending yes/no clarification for this session,
    empty first-turn greetings, greeting/thanks intents, relevance
    checking with query rewriting, and finally dispatch to the full
    autism pipeline.

    Args:
        query: Raw user input.
        first_turn: True on the first message of a conversation.
        session_id: Identifier used to track pending clarifications.

    Returns:
        Either a canned response (greeting/thanks/clarification) or the
        pipeline's cleaned final answer.
    """
    start_ts = time.time()
    intro = ""
    process_log: List[str] = []
    step_times: Dict[str, float] = {}

    logger.info(f"π Query received at {time.strftime('%Y-%m-%d %H:%M:%S')}")
    logger.info(f"π Session ID: {session_id}")
    logger.info(f"π First turn: {first_turn}")
    logger.info(f"π Query: {query}")
    logger.info(f"Processing query: {query[:100]}... (session: {session_id})")

    # A previous turn asked "Do you mean ...?" -- resolve that first.
    if session_id in pending_clarifications:
        if query.strip().lower() == "yes":
            corrected_query = pending_clarifications.pop(session_id)
            step_times["Language Detection & Translation"] = time.time() - \
                start_ts
            _log(process_log, f"User confirmed clarification. corrected_query={corrected_query}")
            return process_autism_pipeline(corrected_query, corrected_query, process_log, intro, start_ts)
        else:
            pending_clarifications.pop(session_id)
            _log(process_log, "User rejected clarification; resetting session.")
            return "Hello I'm Wisal, an AI assistant developed by Compumacy AI. Please ask a question specifically about autism."

    # Empty opening message: just greet.
    if first_turn and (not query or query.strip() == ""):
        _log(process_log, "Empty first turn; sending greeting.")
        return "Hello! I'm Wisal, an AI assistant developed by Compumacy AI. How can I help you today?"

    # Short-circuit plain greetings / thanks without running the pipeline.
    intent = is_greeting_or_thank(query)
    if intent == "greeting":
        _log(process_log, "Greeting detected.")
        return intro + "Hello! I'm Wisal, your AI assistant developed by Compumacy AI. How can I help you today?"
    elif intent == "thanks":
        _log(process_log, "Thanks detected.")
        return "You're welcome! π If you have more questions about autism, feel free to ask."

    logger.info(f"⏱️ Query preprocessing completed in {time.time() - start_ts:.2f} seconds")
    corrected_query, is_autism_related, rewritten_query = process_query_for_rewrite(query)
    _log(process_log, f"✅ Original Query: {query}")
    _log(process_log, f"✅ Corrected Query: {corrected_query}")
    _log(process_log, f"✅ Relevance Check: {'RELATED' if is_autism_related else 'NOT RELATED'}")

    if rewritten_query:
        _log(process_log, f"✅ LLM rewritten: {rewritten_query}")
        # Off-topic query with a usable rewrite: ask the user to confirm it
        # and stash the rewrite until the next turn answers yes/no.
        if not is_autism_related:
            clarification = f"""Your query was not clearly related to autism. Do you mean: "{rewritten_query}"?"""
            pending_clarifications[session_id] = rewritten_query
            _log(process_log, f"✅ Clarification prompted: {clarification}")
            return clarification

    logger.info(f"π Starting autism pipeline at {time.strftime('%Y-%m-%d %H:%M:%S')}")
    return process_autism_pipeline(query, corrected_query, process_log, intro, start_ts)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_environment_setup():
    """Test environment variables and configuration.

    Prints a per-check status line and returns True only when every
    boolean check passes (``*_error`` message entries are excluded).
    """
    print("\n" + "="*60)
    print("π§ TESTING ENVIRONMENT SETUP")
    print("="*60)

    test_results = {}

    # Required SiliconFlow settings must be non-empty.
    test_results['SILICONFLOW_API_KEY'] = bool(SILICONFLOW_API_KEY)
    test_results['SILICONFLOW_URL'] = bool(SILICONFLOW_URL)
    test_results['SILICONFLOW_CHAT_URL'] = bool(SILICONFLOW_CHAT_URL)

    # The YAML config must be loaded and contain the model section.
    try:
        test_results['config_loaded'] = bool(config)
        test_results['apis_models_config'] = 'apis_models' in config
    except Exception as e:
        test_results['config_loaded'] = False
        test_results['config_error'] = str(e)

    # The central logger must accept messages.
    try:
        logger.info("Test log message")
        test_results['logger_working'] = True
    except Exception as e:
        test_results['logger_working'] = False
        test_results['logger_error'] = str(e)

    for key, value in test_results.items():
        status = "✅" if value else "❌"
        print(f"{status} {key}: {value}")

    # *_error entries hold messages, not booleans; exclude them from the verdict.
    return all(v for k, v in test_results.items() if not k.endswith('_error'))
|
|
|
|
|
|
|
|
def test_score_cleaning():
    """Test the score cleaning function against known LLM reply shapes.

    Returns a dict mapping each input (or "empty") to pass/fail.
    """
    print("\n" + "="*60)
    print("π§Ή TESTING SCORE CLEANING FUNCTION")
    print("="*60)

    # (raw LLM output, expected cleaned score)
    test_cases = [
        ("Score: 5**", 5),
        ("**Score: 4**", 4),
        ("Score: 3", 3),
        ("The score is 2 out of 5", 2),
        ("No numbers here", 3),   # no digits -> neutral default
        ("Score: 0", 1),          # clamped up into the 1-5 range
        ("Score: 10", 5),         # clamped down into the 1-5 range
        ("", 3),                  # missing input -> neutral default
    ]

    results = {}
    for input_text, expected in test_cases:
        try:
            result = clean_hallucination_score(input_text)
            success = result == expected
            status = "✅" if success else "❌"
            print(f"{status} Input: '{input_text}' -> Got: {result}, Expected: {expected}")
            results[input_text or "empty"] = success
        except Exception as e:
            print(f"❌ Error with '{input_text}': {e}")
            results[input_text or "empty"] = False

    success_rate = sum(results.values()) / len(results)
    print(f"\nπ Score Cleaning Success Rate: {success_rate:.1%}")
    return results
|
|
|
|
|
|
|
|
def run_all_tests():
    """Run all tests and print a summary.

    Returns the dict of per-test results (bools or per-case dicts).
    """
    print("\n" + "π§ͺ" + "="*58)
    print("π§ͺ RUNNING COMPREHENSIVE PIPELINE TESTS (FIXED VERSION)")
    print("π§ͺ" + "="*58)

    test_results = {}

    print("Starting test suite...")

    test_results["Environment"] = test_environment_setup()
    test_results["Score Cleaning"] = test_score_cleaning()

    print("\n" + "="*60)
    print("π§© TESTING FIXED PIPELINE")
    print("="*60)

    # End-to-end smoke test: one real query through the full pipeline.
    try:
        test_query = "What are the early signs of autism?"
        print(f"Testing query: '{test_query}'")
        start_time = time.time()
        response = process_query(test_query, session_id="fix_test")
        duration = time.time() - start_time
        print(f"✅ SUCCESS - Pipeline completed in {duration:.2f}s")
        print(f"Response length: {len(response)} characters")
        test_results["Fixed Pipeline"] = True
    except Exception as e:
        print(f"❌ FAILED - Error: {e}")
        test_results["Fixed Pipeline"] = False
        import traceback
        traceback.print_exc()

    print("\n" + "π" + "="*58)
    print("π TEST SUMMARY")
    print("π" + "="*58)

    for test_name, result in test_results.items():
        if isinstance(result, bool):
            status = "✅ PASS" if result else "❌ FAIL"
            print(f"{status} {test_name}")
        elif isinstance(result, dict):
            # Per-case dicts (e.g. score cleaning) are summarized as a ratio.
            passed = sum(result.values())
            total = len(result)
            print(f"π {test_name}: {passed}/{total} ({passed/total:.1%})")
        else:
            print(f"βΉοΈ INFO {test_name}: {result}")

    print("\nπ Testing completed!")
    return test_results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def is_obvious_autism_query(query: str) -> bool:
    """Check if query is obviously autism-related to bypass heavy processing"""
    lowered = query.lower()
    signal_terms = (
        'autism', 'autistic', 'asd', 'autism spectrum', 'asperger',
        'stimming', 'stim', 'meltdown', 'sensory processing disorder',
        'special interest', 'echolalia', 'repetitive behavior',
    )
    # Plain substring match -- the same cheap heuristic as before.
    for term in signal_terms:
        if term in lowered:
            return True
    return False
|
|
|
|
|
|
|
|
def is_obvious_non_autism_query(query: str) -> bool:
    """Check if query is obviously NOT autism-related"""
    lowered = query.lower()
    # Word-bounded topic patterns that clearly fall outside the domain.
    off_topic_patterns = (
        r'\b(weather|temperature|forecast|rain|snow|sunny)\b',
        r'\b(recipe|cooking|food preparation|ingredients)\b',
        r'\b(sports|football|basketball|soccer|tennis)\b',
        r'\b(stock market|investing|cryptocurrency|trading)\b',
        r'\b(travel|vacation|hotel|flight|tourism)\b',
        r'\b(movie|film|entertainment|celebrity|actor)\b',
    )
    for pattern in off_topic_patterns:
        if re.search(pattern, lowered):
            return True
    return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_improved_pipeline():
    """Test the improved pipeline with various query types.

    Each case is (query, expected-to-be-accepted-as-autism-related);
    results are printed, nothing is returned.
    """
    test_cases = [
        # Clearly autism-related.
        ("What is autism?", True),
        ("My autistic child has meltdowns", True),
        ("Autism spectrum disorder symptoms", True),
        # Adjacent child-development topics expected to be accepted.
        ("My child has behavioral issues", True),
        ("Sleep problems in 6 year old", True),
        ("ADHD and anxiety in teenagers", True),
        ("Social skills development", True),
        ("Child development milestones", True),
        ("Family stress management", True),
        # Clearly off-topic.
        ("What's the weather today?", False),
        ("How to cook pasta?", False),
        ("Stock market trends", False),
    ]

    print("Testing improved pipeline:")
    print("-" * 50)

    for query, expected_acceptance in test_cases:
        try:
            _, is_relevant, _ = process_query_for_rewrite(query)
            result = "ACCEPTED" if is_relevant else "REJECTED"
            expected = "ACCEPTED" if expected_acceptance else "REJECTED"
            status = "✅" if (is_relevant == expected_acceptance) else "❌"
            print(f"{status} '{query[:40]}...' -> {result} (expected {expected})")
        except Exception as e:
            print(f"❌ '{query[:40]}...' -> ERROR: {e}")
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    logger.info("PipeQuery Logger Starting ....")

    # Always run the score-cleaning smoke test first.
    print("\nπ§ TESTING SCORE CLEANING FIX...")
    test_score_cleaning()

    print("\n" + "π" + "="*58)
    print("π WISAL AUTISM PIPELINE - TESTING SUITE (FIXED)")
    print("π" + "="*58)

    import sys

    if len(sys.argv) > 1:
        # Non-interactive mode selected from the command line.
        mode = sys.argv[1].lower()
        if mode == "full":
            run_all_tests()
        elif mode == "fix":
            test_score_cleaning()
        else:
            print(f"Unknown test mode: {mode}")
            print("Available modes: full, fix")
    else:
        # Interactive menu loop.
        while True:
            print("\n" + "π§" + " "*20 + "TEST MENU" + " "*20 + "π§")
            print("1. π Run All Tests")
            print("2. π§Ή Test Score Cleaning Fix")
            print("3. π§© Test Fixed Pipeline")
            print("4. π¬ Interactive Query Test")
            print("0. πͺ Exit")

            choice = input("\nEnter your choice (0-4): ").strip()

            if choice == "1":
                run_all_tests()
            elif choice == "2":
                test_score_cleaning()
            elif choice == "3":
                # Run one user-supplied query through the full pipeline.
                try:
                    test_query = input("Enter test query: ").strip()
                    if test_query:
                        print(f"\nπ Processing: {test_query}")
                        start_time = time.time()
                        response = process_query(test_query, session_id="manual_test")
                        duration = time.time() - start_time
                        print(f"\n✅ Response ({duration:.2f}s):")
                        print("-" * 50)
                        print(response)
                        print("-" * 50)
                except Exception as e:
                    print(f"\n❌ Error: {e}")
            elif choice == "4":
                # Repeated free-form queries until the user types 'quit'.
                print("\n" + "π¬" + "="*40)
                print("π¬ INTERACTIVE QUERY TESTING")
                print("π¬" + "="*40)
                print("Enter 'quit' to return to menu")

                while True:
                    query = input("\nEnter your query: ").strip()
                    if query.lower() == 'quit':
                        break
                    try:
                        print(f"\nπ Processing: {query}")
                        start_time = time.time()
                        response = process_query(query, session_id="interactive_test")
                        duration = time.time() - start_time
                        print(f"\n✅ Response ({duration:.2f}s):")
                        print("-" * 50)
                        print(response)
                        print("-" * 50)
                    except Exception as e:
                        print(f"\n❌ Error: {e}")
            elif choice == "0":
                print("\nπ Goodbye! The score cleaning fix should resolve your issue!")
                break
            else:
                print("❌ Invalid choice. Please try again.")

            input("\nPress Enter to continue...")
|
|
|
|
|
|
|
|
|