# Autism_QA / audio_utils.py
import asyncio
import base64
import collections
import io
import os
import re
import time
from typing import AsyncGenerator, Literal

import numpy as np
import soundfile as sf
from dotenv import load_dotenv

# Note: `genai` is bound to google.generativeai because transcribe_audio uses its
# file-upload API; the lower-level google.genai types are imported separately.
import google.generativeai as genai
from google.genai import types
from google.genai.types import (
    Content,
    LiveConnectConfig,
    Part,
    PrebuiltVoiceConfig,
    SpeechConfig,
    VoiceConfig,
)
from fastrtc import AsyncStreamHandler, wait_for_item

from clients import gemini_client
from pipeQuery import clean_pipeline_result, process_query
## Custom logger
from logger.custom_logger import CustomLoggerTracker
custom_log = CustomLoggerTracker()
logger = custom_log.get_logger("audio_utils")

## Load API keys from .env
load_dotenv()

## Load YAML config
from configs import load_yaml_config
config = load_yaml_config("config.yaml")
def encode_audio(data: np.ndarray) -> dict:
    """Encode int16 PCM samples as a base64 payload for the Gemini realtime API."""
    return {
        "mime_type": "audio/pcm",
        "data": base64.b64encode(data.tobytes()).decode("UTF-8"),
    }


def encode_audio2(data: np.ndarray) -> bytes:
    """Return the raw little-endian PCM bytes of an int16 array."""
    return data.tobytes()
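

# Illustrative inverse of encode_audio (added for documentation; nothing else in
# this module calls it): shows how the base64 PCM payload maps back to int16 samples.
def decode_audio(payload: dict) -> np.ndarray:
    """Decode the dict produced by encode_audio back into an int16 sample array."""
    raw = base64.b64decode(payload["data"])
    return np.frombuffer(raw, dtype=np.int16)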
def numpy_array_to_wav_bytes(audio_array: np.ndarray, sample_rate: int = 16000) -> bytes:
    """Serialize a numpy audio array to an in-memory WAV file and return its bytes."""
    buffer = io.BytesIO()
    sf.write(buffer, audio_array, sample_rate, format='WAV')
    return buffer.getvalue()
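

# Quick sanity sketch (illustrative only, not used by the pipeline): round-trip a
# short tone through the WAV helper and read it back with soundfile.
def _example_wav_roundtrip() -> bool:
    tone = (np.sin(2 * np.pi * 440 * np.linspace(0, 0.1, 1600)) * 32767).astype(np.int16)
    wav_bytes = numpy_array_to_wav_bytes(tone, 16000)
    decoded, rate = sf.read(io.BytesIO(wav_bytes), dtype='int16')
    return rate == 16000 and len(decoded) == len(tone)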
class GeminiHandler(AsyncStreamHandler):
    def __init__(
        self,
        expected_layout: Literal["mono"] = "mono",
        output_sample_rate: int = 24000,
        prompt_dict: dict | None = None,
    ) -> None:
        super().__init__(
            expected_layout,
            output_sample_rate,
            input_sample_rate=16000,
        )
        self.input_queue: asyncio.Queue = asyncio.Queue()
        self.output_queue: asyncio.Queue = asyncio.Queue()
        self.quit: asyncio.Event = asyncio.Event()
        self.is_active: bool = False
        # Avoid a shared mutable default argument; fall back to the PHQ-9 prompt.
        self.prompt_dict = prompt_dict if prompt_dict is not None else {"prompt": "PHQ-9"}
# Load from config if available, otherwise use defaults
try:
self.model = config["audio"]["model_live"]
self.t2t_model = config["audio"]["tts_model"]
self.s2t_model = config["audio"]["stt_model"]
self.VAD_RATE = config["audio"]["VAD_RATE"]
self.VAD_FRAME_MS = config["audio"]["VAD_FRAME_MS"]
padding_ms = config["audio"]["padding_ms"]
self.vad_ratio = config["audio"]["vad_ratio"]
except (KeyError, NameError):
# Fallback defaults if config not available
self.model = "gemini-2.5-flash-preview-tts"
self.t2t_model = "gemini-2.0-flash-exp"
self.s2t_model = "gemini-2.0-flash-exp"
self.VAD_RATE = 16000
self.VAD_FRAME_MS = 30
padding_ms = 300
self.vad_ratio = 0.9
# VAD Initialization
try:
import webrtcvad
self.vad = webrtcvad.Vad(3)
self.vad_available = True
except ImportError:
logger.warning("webrtcvad not available, VAD disabled")
self.vad_available = False
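        # Frame bookkeeping for webrtcvad. With the defaults above this works out to
        # 16,000 Hz * 30 ms = 480 samples per frame (960 bytes of int16 PCM) and a
        # 300 ms / 30 ms = 10-frame ring buffer used to decide speech start and end.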
self.VAD_FRAME_SAMPLES = int(self.VAD_RATE * (self.VAD_FRAME_MS / 1000.0))
self.VAD_FRAME_BYTES = self.VAD_FRAME_SAMPLES * 2
self.vad_padding_frames = padding_ms // self.VAD_FRAME_MS
self.vad_ring_buffer = collections.deque(maxlen=self.vad_padding_frames)
self.vad_triggered = False
self.wav_data = bytearray()
self.internal_buffer = bytearray()
self.end_of_speech_time: float | None = None
self.first_latency_calculated: bool = False
def copy(self) -> "GeminiHandler":
return GeminiHandler(
expected_layout="mono",
output_sample_rate=self.output_sample_rate,
prompt_dict=self.prompt_dict,
)
def stop(self) -> None:
logger.info("Stopping GeminiHandler...")
self.quit.set()
self.is_active = False
def shutdown(self) -> None:
self.stop()
def t2t_with_rag(self, text: str) -> str:
try:
response = process_query(text)
if isinstance(response, tuple):
result = clean_pipeline_result(response[0] if response[0] else response[1])
else:
result = clean_pipeline_result(str(response))
logger.info(f"RAG response generated: {result[:100]}...")
return result
except Exception as e:
logger.error(f"Error in RAG processing: {e}")
try:
response = self.chat.send_message(text)
return response.text
except Exception as fallback_error:
logger.error(f"Fallback Gemini also failed: {fallback_error}")
return "I'm sorry, I'm having trouble processing your request right now."
def s2t(self, audio) -> str:
try:
response = self.s2t_client.models.generate_content(
model=self.s2t_model,
contents=[
types.Part.from_bytes(data=audio, mime_type='audio/wav'),
'Generate a transcript of the speech.'
]
)
return response.text.strip()
except Exception as e:
logger.error(f"STT error: {e}")
return ""
async def start_up(self):
"""Initialize the handler with proper error handling"""
try:
self.is_active = True
self.t2t_bool = True # Enable RAG processing
# Initialize clients with error handling
try:
self.t2t_client = gemini_client()
self.s2t_client = gemini_client()
self.t2s_client = gemini_client()
except Exception as e:
logger.error(f"Failed to initialize Gemini clients: {e}")
return
# Chat configuration
sys_instruction = """You are Wisal, an AI assistant developed by Compumacy AI, specialized in Autism Spectrum Disorder (ASD).
Your sole purpose is to provide helpful, respectful, and easy-to-understand answers about Autism.
Always be clear, non-judgmental, and supportive."""
try:
chat_config = types.GenerateContentConfig(system_instruction=sys_instruction)
self.chat = self.t2t_client.chats.create(model=self.t2t_model, config=chat_config)
except Exception as e:
logger.error(f"Failed to create chat: {e}")
return
            # Live connect configuration (named live_config so it does not shadow
            # the module-level YAML config)
            voice_name = "Puck"
            try:
                live_config = LiveConnectConfig(
                    response_modalities=["AUDIO"],
                    speech_config=SpeechConfig(
                        voice_config=VoiceConfig(
                            prebuilt_voice_config=PrebuiltVoiceConfig(voice_name=voice_name)
                        )
                    ),
                    system_instruction=Content(parts=[Part.from_text(text=sys_instruction)])
                )
            except Exception as e:
                logger.error(f"Failed to create live config: {e}")
                return
            # Main processing loop with stop capability
            try:
                async with self.t2s_client.aio.live.connect(model=self.model, config=live_config) as session:
async for text_from_user in self.stream():
if self.quit.is_set():
break
if text_from_user and text_from_user.strip():
logger.info(f"Processing user input: {text_from_user}")
# Process through RAG pipeline
if self.t2t_bool:
processed_response = self.t2t_with_rag(text_from_user)
else:
processed_response = text_from_user
try:
await session.send_client_content(
turns=types.Content(
role='user',
parts=[types.Part(text=processed_response)]
)
)
async for resp_chunk in session.receive():
if self.quit.is_set():
break
if resp_chunk.data:
array = np.frombuffer(resp_chunk.data, dtype=np.int16)
self.output_queue.put_nowait((self.output_sample_rate, array))
except Exception as e:
logger.error(f"Error in session communication: {e}")
except Exception as e:
logger.error(f"Error in live session: {e}")
except Exception as e:
logger.error(f"Error in start_up: {e}")
finally:
self.is_active = False
async def stream(self) -> AsyncGenerator[str, None]:
"""Stream text messages with stop capability"""
while not self.quit.is_set():
try:
text_to_speak = await asyncio.wait_for(self.input_queue.get(), timeout=1.0)
if text_to_speak and not self.quit.is_set():
yield text_to_speak
except asyncio.TimeoutError:
continue
except Exception as e:
logger.error(f"Error in stream: {e}")
break
async def receive(self, frame: tuple[int, np.ndarray]) -> None:
"""Receive and process audio frames with VAD"""
if self.quit.is_set():
return
try:
sr, array = frame
audio_bytes = array.tobytes()
self.internal_buffer.extend(audio_bytes)
            # Use VAD-based segmentation when available
            if not self.vad_available:
                # Simple fallback without VAD: flush after roughly 10 frames
                # (about 300 ms of audio at the default 30 ms frame size)
                if len(self.internal_buffer) > self.VAD_FRAME_BYTES * 10:
full_utterance_np = np.frombuffer(self.internal_buffer, dtype=np.int16)
audio_input_wav = numpy_array_to_wav_bytes(full_utterance_np, sr)
text_input = self.s2t(audio_input_wav)
if text_input and text_input.strip():
self.input_queue.put_nowait(text_input)
self.internal_buffer = bytearray()
return
# Original VAD processing
while len(self.internal_buffer) >= self.VAD_FRAME_BYTES:
if self.quit.is_set():
break
vad_frame = self.internal_buffer[:self.VAD_FRAME_BYTES]
self.internal_buffer = self.internal_buffer[self.VAD_FRAME_BYTES:]
try:
is_speech = self.vad.is_speech(vad_frame, self.VAD_RATE)
except Exception as e:
logger.error(f"VAD error: {e}")
continue
if not self.vad_triggered:
self.vad_ring_buffer.append((vad_frame, is_speech))
num_voiced = len([f for f, speech in self.vad_ring_buffer if speech])
if num_voiced > self.vad_ratio * self.vad_ring_buffer.maxlen:
logger.info("Speech detected, starting to record...")
self.vad_triggered = True
for f, s in self.vad_ring_buffer:
self.wav_data.extend(f)
self.vad_ring_buffer.clear()
else:
self.wav_data.extend(vad_frame)
self.vad_ring_buffer.append((vad_frame, is_speech))
num_unvoiced = len([f for f, speech in self.vad_ring_buffer if not speech])
if num_unvoiced > self.vad_ratio * self.vad_ring_buffer.maxlen:
logger.info("End of speech detected.")
self.vad_triggered = False
try:
full_utterance_np = np.frombuffer(self.wav_data, dtype=np.int16)
audio_input_wav = numpy_array_to_wav_bytes(full_utterance_np, sr)
text_input = self.s2t(audio_input_wav)
if text_input and text_input.strip():
self.input_queue.put_nowait(text_input)
except Exception as e:
logger.error(f"Error processing speech: {e}")
self.vad_ring_buffer.clear()
self.wav_data = bytearray()
except Exception as e:
logger.error(f"Error in receive: {e}")
async def emit(self) -> tuple[int, np.ndarray] | None:
"""Emit audio output with stop capability"""
try:
return await asyncio.wait_for(wait_for_item(self.output_queue), timeout=1.0)
except asyncio.TimeoutError:
return None
except Exception as e:
logger.error(f"Error in emit: {e}")
return None
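

# A minimal wiring sketch (an assumption based on the fastrtc Stream API; kept in a
# function so it never runs on import and is not used elsewhere in this module):
def _example_gemini_stream():
    """Illustrative only: mount GeminiHandler on a bidirectional fastrtc audio stream."""
    from fastrtc import Stream  # local import keeps the module import path unchanged
    return Stream(
        handler=GeminiHandler(),
        modality="audio",
        mode="send-receive",
    )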
# ---------------------------
# Audio Transcription
# ---------------------------
def transcribe_audio(audio_filepath):
    """Yield status updates while uploading/processing, then yield the final transcript."""
    logger.info(f"Starting audio transcription for: {audio_filepath}")
api_key = os.getenv("GEMINI_API_KEY")
if not api_key:
logger.error("GEMINI_API_KEY environment variable not set.")
yield "[ERROR] API Key is missing. Please configure your environment."
return
if not audio_filepath or not os.path.exists(audio_filepath):
logger.error(f"Audio file does not exist at path: {audio_filepath}")
yield "[ERROR] Audio file not found. Please record or upload again."
return
genai.configure(api_key=api_key)
model = genai.GenerativeModel(model_name=config["audio"]["tts_model"])
logger.info(f"Uploading audio file for transcription: {audio_filepath}")
yield "Status: Uploading audio..."
audio_file = genai.upload_file(path=audio_filepath)
while audio_file.state.name == "PROCESSING":
yield "Status: Processing uploaded file..."
time.sleep(2)
audio_file = genai.get_file(audio_file.name)
if audio_file.state.name == "FAILED":
logger.error("Google AI file processing failed.")
yield "[ERROR] Audio file processing failed on the server."
return
yield "Status: Transcribing..."
response = model.generate_content(
["Please transcribe this audio recording accurately.", audio_file],
request_options={"timeout": 120})
genai.delete_file(audio_file.name)
if response and hasattr(response, 'text') and response.text:
query = response.text.strip()
logger.info(f"Transcription complete, length={len(query)}")
yield query
else:
logger.error("Transcription failed: empty/malformed response.")
yield "[ERROR] Transcription failed: The model returned an empty response."
def get_transcription_or_text(text_input, audio_input):
"""Extract text from either text input or audio input."""
if text_input and text_input.strip():
logger.info(f"Processing text query...")
return text_input.strip(), "Status: Processing text query..."
    if audio_input is not None:
        try:
            # transcribe_audio is a generator that yields status strings before the
            # final transcript, so keep only the last non-status value instead of
            # returning the first yield (which is always a status message).
            final_text = None
            for result in transcribe_audio(audio_input):
                if result.startswith("[ERROR]"):
                    return result, "error"
                if not result.startswith("Status:"):
                    final_text = result
            if final_text and final_text.strip():
                return final_text.strip(), "Status: Processing audio transcription..."
            return None, "Status: Transcription returned no text."
except Exception as e:
logger.error(f"Transcription error: {e}")
return f"[ERROR] Transcription failed: {e}", "error"
return None, "Status: Please type a question or provide an audio recording."
def generate_tts_response(cleaned_text, voice_name):
"""Generate TTS response using Gemini."""
try:
tts_config = types.GenerateContentConfig(
response_modalities=["AUDIO"],
speech_config=types.SpeechConfig(
voice_config=types.VoiceConfig(
prebuilt_voice_config=types.PrebuiltVoiceConfig(voice_name=voice_name)
)
)
)
        # gemini_client() returns a configured client instance
        client = gemini_client()
response = client.models.generate_content(
model=config["audio"]["tts_model"],
contents=cleaned_text,
config=tts_config)
if not response.candidates or not response.candidates[0].content.parts:
logger.warning("Model did not return audio content")
return None, "Status: Model did not return audio."
pcm_data = response.candidates[0].content.parts[0].inline_data.data
return (24000, np.frombuffer(pcm_data, dtype=np.int16)), "Status: Success!"
except Exception as e:
logger.error(f"TTS Error: {e}")
return None, f"Status: An error occurred during TTS: {e}"
def process_input_and_generate_speech(text_input, audio_input, voice_name, chat_history):
"""Process user input and generate speech response."""
try:
query, status = get_transcription_or_text(text_input, audio_input)
if not query:
# Return proper message format for Gradio chatbot
new_history = chat_history + [{"role": "assistant", "content": status}]
return new_history, None, status, text_input, None
is_first_turn = len(chat_history) == 0
new_history = chat_history + [{"role": "user", "content": query}]
response_html = process_query(query, first_turn=is_first_turn)
new_history.append({"role": "assistant", "content": response_html})
# Clean text for TTS
cleaned_text = re.sub('<[^<]+?>', '', response_html).strip()
if not cleaned_text:
new_history[-1]["content"] = "The pipeline returned an empty response."
return new_history, None, "Status: Error - Empty response.", "", None
# Generate TTS
audio_data, tts_status = generate_tts_response(cleaned_text, voice_name)
if not audio_data:
# Add TTS status to the response
new_history[-1]["content"] = response_html + f"<br><br><i>({tts_status})</i>"
return new_history, None, tts_status, "", None
return new_history, audio_data, tts_status, "", None
except Exception as e:
logger.error(f"Error in process_input_and_generate_speech: {e}")
error_history = chat_history + [{"role": "assistant", "content": f"An error occurred: {str(e)}"}]
return error_history, None, f"Status: Error - {str(e)}", "", None
# ---------------------------
# Testing Functions
# ---------------------------
def test_encode_audio_functions():
"""Test audio encoding functions"""
print("\n" + "="*60)
print("TESTING AUDIO ENCODING FUNCTIONS")
print("="*60)
results = {}
# Create test audio data
test_data = np.array([1, 2, 3, 4, 5], dtype=np.int16)
try:
# Test encode_audio
print("Testing encode_audio...")
result1 = encode_audio(test_data)
expected_keys = {'mime_type', 'data'}
if set(result1.keys()) == expected_keys and result1['mime_type'] == 'audio/pcm':
print("βœ… encode_audio: PASS")
results['encode_audio'] = "βœ… PASS"
else:
print("❌ encode_audio: FAIL - incorrect format")
results['encode_audio'] = "❌ FAIL"
except Exception as e:
print(f"❌ encode_audio: ERROR - {e}")
results['encode_audio'] = f"❌ ERROR: {e}"
try:
# Test encode_audio2
print("Testing encode_audio2...")
result2 = encode_audio2(test_data)
if isinstance(result2, bytes) and len(result2) > 0:
print("βœ… encode_audio2: PASS")
results['encode_audio2'] = "βœ… PASS"
else:
print("❌ encode_audio2: FAIL - not bytes or empty")
results['encode_audio2'] = "❌ FAIL"
except Exception as e:
print(f"❌ encode_audio2: ERROR - {e}")
results['encode_audio2'] = f"❌ ERROR: {e}"
return results
def test_numpy_to_wav_conversion():
"""Test numpy array to WAV conversion"""
print("\n" + "="*60)
print("TESTING NUMPY TO WAV CONVERSION")
print("="*60)
results = {}
# Create test audio data - sine wave
sample_rate = 16000
duration = 0.1 # 0.1 seconds
frequency = 440 # A4 note
t = np.linspace(0, duration, int(sample_rate * duration))
test_audio = (np.sin(2 * np.pi * frequency * t) * 32767).astype(np.int16)
try:
print("Testing numpy_array_to_wav_bytes...")
wav_bytes = numpy_array_to_wav_bytes(test_audio, sample_rate)
if isinstance(wav_bytes, bytes) and len(wav_bytes) > 44: # WAV header is 44 bytes
print(f"βœ… WAV conversion: PASS - Generated {len(wav_bytes)} bytes")
results['wav_conversion'] = "βœ… PASS"
# Check if it starts with WAV header
if wav_bytes[:4] == b'RIFF' and wav_bytes[8:12] == b'WAVE':
print("βœ… WAV header validation: PASS")
results['wav_header'] = "βœ… PASS"
else:
print("⚠️ WAV header validation: WARNING - may not be valid WAV")
results['wav_header'] = "⚠️ WARNING"
else:
print("❌ WAV conversion: FAIL - invalid output")
results['wav_conversion'] = "❌ FAIL"
except Exception as e:
print(f"❌ WAV conversion: ERROR - {e}")
results['wav_conversion'] = f"❌ ERROR: {e}"
return results
def test_gemini_handler_initialization():
"""Test GeminiHandler class initialization"""
print("\n" + "="*60)
print("TESTING GEMINI HANDLER INITIALIZATION")
print("="*60)
results = {}
try:
print("Testing GeminiHandler initialization...")
handler = GeminiHandler()
# Check basic attributes
checks = {
'input_queue': isinstance(handler.input_queue, asyncio.Queue),
'output_queue': isinstance(handler.output_queue, asyncio.Queue),
'quit_event': isinstance(handler.quit, asyncio.Event),
'vad_initialized': hasattr(handler, 'vad'),
'config_loaded': hasattr(handler, 'model') and handler.model is not None
}
passed_checks = sum(checks.values())
total_checks = len(checks)
print(f"Initialization checks: {passed_checks}/{total_checks}")
for check_name, passed in checks.items():
status = "βœ…" if passed else "❌"
print(f" {status} {check_name}")
if passed_checks == total_checks:
results['gemini_handler_init'] = "βœ… PASS"
else:
results['gemini_handler_init'] = f"⚠️ PARTIAL: {passed_checks}/{total_checks}"
except Exception as e:
print(f"❌ GeminiHandler initialization: ERROR - {e}")
results['gemini_handler_init'] = f"❌ ERROR: {e}"
try:
print("Testing GeminiHandler copy method...")
handler = GeminiHandler()
handler_copy = handler.copy()
if isinstance(handler_copy, GeminiHandler) and handler_copy is not handler:
print("βœ… Copy method: PASS")
results['gemini_handler_copy'] = "βœ… PASS"
else:
print("❌ Copy method: FAIL")
results['gemini_handler_copy'] = "❌ FAIL"
except Exception as e:
print(f"❌ Copy method: ERROR - {e}")
results['gemini_handler_copy'] = f"❌ ERROR: {e}"
return results
def test_transcription_function_validation():
"""Test transcribe_audio function validation (without actual API calls)"""
print("\n" + "="*60)
print("TESTING TRANSCRIPTION FUNCTION VALIDATION")
print("="*60)
results = {}
# Test with missing API key
print("Testing with missing API key...")
original_key = os.environ.get("GEMINI_API_KEY")
if original_key:
del os.environ["GEMINI_API_KEY"]
try:
gen = transcribe_audio("nonexistent.wav")
result = next(gen)
if result.startswith("[ERROR]") and "API Key" in result:
print("βœ… API key validation: PASS")
results['api_key_validation'] = "βœ… PASS"
else:
print("❌ API key validation: FAIL")
results['api_key_validation'] = "❌ FAIL"
except Exception as e:
print(f"❌ API key validation: ERROR - {e}")
results['api_key_validation'] = f"❌ ERROR: {e}"
# Restore API key
if original_key:
os.environ["GEMINI_API_KEY"] = original_key
# Test with nonexistent file
print("Testing with nonexistent file...")
try:
gen = transcribe_audio("definitely_nonexistent_file.wav")
result = next(gen)
if result.startswith("[ERROR]") and "not found" in result:
print("βœ… File validation: PASS")
results['file_validation'] = "βœ… PASS"
else:
print("❌ File validation: FAIL")
results['file_validation'] = "❌ FAIL"
except Exception as e:
print(f"❌ File validation: ERROR - {e}")
results['file_validation'] = f"❌ ERROR: {e}"
return results
def test_text_input_processing():
"""Test get_transcription_or_text function"""
print("\n" + "="*60)
print("TESTING TEXT INPUT PROCESSING")
print("="*60)
results = {}
# Test with text input
print("Testing with text input...")
try:
text_input = "What is autism?"
audio_input = None
query, status = get_transcription_or_text(text_input, audio_input)
if query == text_input and "text query" in status:
print("βœ… Text input processing: PASS")
results['text_input'] = "βœ… PASS"
else:
print("❌ Text input processing: FAIL")
results['text_input'] = "❌ FAIL"
except Exception as e:
print(f"❌ Text input processing: ERROR - {e}")
results['text_input'] = f"❌ ERROR: {e}"
# Test with empty inputs
print("Testing with empty inputs...")
try:
query, status = get_transcription_or_text("", None)
if query is None and "Please type" in status:
print("βœ… Empty input handling: PASS")
results['empty_input'] = "βœ… PASS"
else:
print("❌ Empty input handling: FAIL")
results['empty_input'] = "❌ FAIL"
except Exception as e:
print(f"❌ Empty input handling: ERROR - {e}")
results['empty_input'] = f"❌ ERROR: {e}"
# Test with whitespace only
print("Testing with whitespace input...")
try:
query, status = get_transcription_or_text(" \n\t ", None)
if query is None and "Please type" in status:
print("βœ… Whitespace input handling: PASS")
results['whitespace_input'] = "βœ… PASS"
else:
print("❌ Whitespace input handling: FAIL")
results['whitespace_input'] = "❌ FAIL"
except Exception as e:
print(f"❌ Whitespace input handling: ERROR - {e}")
results['whitespace_input'] = f"❌ ERROR: {e}"
return results
def test_tts_function_structure():
"""Test TTS function structure and error handling"""
print("\n" + "="*60)
print("TESTING TTS FUNCTION STRUCTURE")
print("="*60)
results = {}
# Test with invalid voice name
print("Testing TTS function error handling...")
try:
# This should fail gracefully
audio_data, status = generate_tts_response("Hello world", "invalid_voice")
if audio_data is None and "error" in status.lower():
print("βœ… TTS error handling: PASS")
results['tts_error_handling'] = "βœ… PASS"
elif audio_data is not None:
print("βœ… TTS function: UNEXPECTED SUCCESS - function worked")
results['tts_error_handling'] = "βœ… UNEXPECTED SUCCESS"
else:
print("❌ TTS error handling: FAIL")
results['tts_error_handling'] = "❌ FAIL"
except Exception as e:
# This is expected if API is not available
print(f"βœ… TTS error handling: EXPECTED ERROR - {str(e)[:100]}")
results['tts_error_handling'] = "βœ… EXPECTED ERROR"
# Test with empty text
print("Testing TTS with empty text...")
try:
audio_data, status = generate_tts_response("", "Puck")
if audio_data is None:
print("βœ… Empty text handling: PASS")
results['tts_empty_text'] = "βœ… PASS"
else:
print("⚠️ Empty text handling: WARNING - generated audio for empty text")
results['tts_empty_text'] = "⚠️ WARNING"
except Exception as e:
print(f"βœ… Empty text handling: EXPECTED ERROR - {str(e)[:100]}")
results['tts_empty_text'] = "βœ… EXPECTED ERROR"
return results
def test_main_processing_function():
"""Test the main process_input_and_generate_speech function"""
print("\n" + "="*60)
print("TESTING MAIN PROCESSING FUNCTION")
print("="*60)
results = {}
# Test with valid text input
print("Testing main processing with text input...")
try:
text_input = "What is autism?"
audio_input = None
voice_name = "Puck"
chat_history = []
        result = process_input_and_generate_speech(
            text_input, audio_input, voice_name, chat_history
        )
        # Check that the function returns the expected 5-tuple before unpacking
        expected_items = 5
        if isinstance(result, tuple) and len(result) == expected_items:
            new_history, audio_data, status, cleared_text, cleared_audio = result
            print("βœ… Return structure: PASS - correct number of return values")
# Check if history is updated
if isinstance(new_history, list) and len(new_history) >= 2:
print("βœ… Chat history update: PASS")
results['history_update'] = "βœ… PASS"
else:
print("❌ Chat history update: FAIL")
results['history_update'] = "❌ FAIL"
# Check status
if isinstance(status, str):
print("βœ… Status return: PASS")
results['status_return'] = "βœ… PASS"
else:
print("❌ Status return: FAIL")
results['status_return'] = "❌ FAIL"
else:
print(f"❌ Return structure: FAIL - expected {expected_items} items")
results['return_structure'] = "❌ FAIL"
except Exception as e:
print(f"⚠️ Main processing: EXPECTED ERROR - {str(e)[:100]}")
results['main_processing'] = "⚠️ EXPECTED ERROR (API dependency)"
# Test with empty inputs
print("Testing main processing with empty inputs...")
try:
new_history, audio_data, status, cleared_text, cleared_audio = process_input_and_generate_speech(
"", None, "Puck", []
)
if isinstance(status, str) and "Please type" in status:
print("βœ… Empty input handling: PASS")
results['empty_input_main'] = "βœ… PASS"
else:
print("❌ Empty input handling: FAIL")
results['empty_input_main'] = "❌ FAIL"
except Exception as e:
print(f"❌ Empty input handling: ERROR - {e}")
results['empty_input_main'] = f"❌ ERROR: {e}"
return results
def test_environment_and_config():
"""Test environment variables and configuration loading"""
print("\n" + "="*60)
print("TESTING ENVIRONMENT AND CONFIGURATION")
print("="*60)
results = {}
# Test configuration loading
try:
print("Testing configuration loading...")
required_config_keys = ['audio']
config_checks = {}
for key in required_config_keys:
config_checks[key] = key in config
if all(config_checks.values()):
print("βœ… Config loading: PASS")
results['config_loading'] = "βœ… PASS"
else:
failed_keys = [k for k, v in config_checks.items() if not v]
print(f"❌ Config loading: FAIL - missing keys: {failed_keys}")
results['config_loading'] = f"❌ FAIL: missing {failed_keys}"
except Exception as e:
print(f"❌ Config loading: ERROR - {e}")
results['config_loading'] = f"❌ ERROR: {e}"
# Test audio config specifically
try:
print("Testing audio configuration...")
if 'audio' in config:
audio_config = config['audio']
required_audio_keys = ['model_live', 'tts_model', 'stt_model', 'VAD_RATE', 'VAD_FRAME_MS']
audio_checks = {}
for key in required_audio_keys:
audio_checks[key] = key in audio_config
passed_audio = sum(audio_checks.values())
total_audio = len(audio_checks)
print(f"Audio config checks: {passed_audio}/{total_audio}")
for key, passed in audio_checks.items():
status = "βœ…" if passed else "❌"
print(f" {status} {key}")
if passed_audio == total_audio:
results['audio_config'] = "βœ… PASS"
else:
results['audio_config'] = f"⚠️ PARTIAL: {passed_audio}/{total_audio}"
else:
print("❌ Audio configuration: FAIL - no audio section")
results['audio_config'] = "❌ FAIL"
except Exception as e:
print(f"❌ Audio configuration: ERROR - {e}")
results['audio_config'] = f"❌ ERROR: {e}"
# Test environment variables
print("Testing environment variables...")
env_vars = ['GEMINI_API_KEY', 'SILICONFLOW_API_KEY']
env_results = {}
for var in env_vars:
value = os.getenv(var)
if value:
print(f"βœ… {var}: SET")
env_results[var] = "βœ… SET"
else:
print(f"❌ {var}: NOT SET")
env_results[var] = "❌ NOT SET"
results.update(env_results)
return results
def create_test_audio_file(filename="test_audio.wav", duration=1.0, sample_rate=16000):
"""Create a test audio file for testing purposes"""
try:
# Generate a simple sine wave
t = np.linspace(0, duration, int(sample_rate * duration))
frequency = 440 # A4 note
audio_data = (np.sin(2 * np.pi * frequency * t) * 0.3 * 32767).astype(np.int16)
# Save as WAV file
sf.write(filename, audio_data, sample_rate)
return filename
except Exception as e:
print(f"Failed to create test audio file: {e}")
return None
def run_performance_benchmarks():
"""Run performance benchmarks on key functions"""
print("\n" + "="*60)
print("RUNNING PERFORMANCE BENCHMARKS")
print("="*60)
results = {}
# Benchmark encode_audio functions
print("Benchmarking audio encoding functions...")
test_data_sizes = [1000, 10000, 100000] # Different sizes
for size in test_data_sizes:
test_data = np.random.randint(-32768, 32767, size, dtype=np.int16)
# Benchmark encode_audio
start_time = time.time()
for _ in range(100): # 100 iterations
encode_audio(test_data)
encode_audio_time = (time.time() - start_time) / 100
# Benchmark encode_audio2
start_time = time.time()
for _ in range(100):
encode_audio2(test_data)
encode_audio2_time = (time.time() - start_time) / 100
print(f"Size {size} samples:")
print(f" encode_audio: {encode_audio_time*1000:.2f}ms")
print(f" encode_audio2: {encode_audio2_time*1000:.2f}ms")
results[f'encode_audio_{size}'] = f"{encode_audio_time*1000:.2f}ms"
results[f'encode_audio2_{size}'] = f"{encode_audio2_time*1000:.2f}ms"
# Benchmark WAV conversion
print("\nBenchmarking WAV conversion...")
test_audio = np.random.randint(-32768, 32767, 16000, dtype=np.int16) # 1 second
start_time = time.time()
for _ in range(10):
numpy_array_to_wav_bytes(test_audio)
wav_time = (time.time() - start_time) / 10
print(f"WAV conversion (1s audio): {wav_time*1000:.2f}ms")
results['wav_conversion_benchmark'] = f"{wav_time*1000:.2f}ms"
return results
def run_integration_tests():
"""Run integration tests that test multiple components together"""
print("\n" + "="*60)
print("RUNNING INTEGRATION TESTS")
print("="*60)
results = {}
# Test GeminiHandler + audio encoding integration
print("Testing GeminiHandler initialization with audio encoding...")
try:
handler = GeminiHandler()
test_data = np.array([1, 2, 3, 4, 5], dtype=np.int16)
# Test if handler can work with encoded audio
encoded = encode_audio(test_data)
raw_bytes = encode_audio2(test_data)
if handler and encoded and raw_bytes:
print("βœ… Handler + Encoding integration: PASS")
results['handler_encoding'] = "βœ… PASS"
else:
print("❌ Handler + Encoding integration: FAIL")
results['handler_encoding'] = "❌ FAIL"
except Exception as e:
print(f"❌ Handler + Encoding integration: ERROR - {e}")
results['handler_encoding'] = f"❌ ERROR: {e}"
# Test text processing pipeline
print("Testing text processing pipeline...")
try:
text_input = "Hello world"
query, status = get_transcription_or_text(text_input, None)
if query == text_input and "text query" in status:
print("βœ… Text processing pipeline: PASS")
results['text_pipeline'] = "βœ… PASS"
else:
print("❌ Text processing pipeline: FAIL")
results['text_pipeline'] = "❌ FAIL"
except Exception as e:
print(f"❌ Text processing pipeline: ERROR - {e}")
results['text_pipeline'] = f"❌ ERROR: {e}"
return results
def run_all_tests():
"""Run all test functions and provide a comprehensive report"""
print("\n" + "πŸ§ͺ" + "="*58)
print("πŸ§ͺ RUNNING COMPREHENSIVE AUDIO UTILS TESTS")
print("πŸ§ͺ" + "="*58)
test_results = {}
# Run all test categories
print("Starting audio utilities test suite...")
test_results["Environment & Config"] = test_environment_and_config()
test_results["Audio Encoding"] = test_encode_audio_functions()
test_results["WAV Conversion"] = test_numpy_to_wav_conversion()
test_results["GeminiHandler"] = test_gemini_handler_initialization()
test_results["Transcription Validation"] = test_transcription_function_validation()
test_results["Text Processing"] = test_text_input_processing()
test_results["TTS Structure"] = test_tts_function_structure()
test_results["Main Processing"] = test_main_processing_function()
test_results["Performance"] = run_performance_benchmarks()
test_results["Integration"] = run_integration_tests()
# Print comprehensive summary
print("\n" + "πŸ“‹" + "="*58)
print("πŸ“‹ COMPREHENSIVE TEST SUMMARY")
print("πŸ“‹" + "="*58)
total_categories = len(test_results)
passed_categories = 0
for category, results in test_results.items():
print(f"\nπŸ”§ {category}:")
if isinstance(results, dict):
category_passed = 0
category_total = 0
for test_name, result in results.items():
category_total += 1
if result.startswith("βœ…"):
category_passed += 1
status = "PASS"
elif result.startswith("⚠️"):
status = "WARNING"
else:
status = "FAIL/ERROR"
print(f" β€’ {test_name}: {status}")
category_success_rate = category_passed / category_total if category_total > 0 else 0
if category_success_rate >= 0.8: # 80% success rate
passed_categories += 1
print(f" πŸ“Š Category Score: {category_passed}/{category_total} ({category_success_rate:.1%})")
else:
print(f" πŸ“Š {results}")
# Overall summary
overall_success_rate = passed_categories / total_categories
print(f"\nπŸ† OVERALL RESULTS:")
print(f" Categories Passed: {passed_categories}/{total_categories}")
print(f" Success Rate: {overall_success_rate:.1%}")
if overall_success_rate >= 0.8:
print(" Status: βœ… SYSTEM READY")
elif overall_success_rate >= 0.6:
print(" Status: ⚠️ NEEDS ATTENTION")
else:
print(" Status: ❌ REQUIRES FIXES")
print("\n🏁 Audio utilities testing completed!")
return test_results
if __name__ == "__main__":
logger.info("Audio utils module loaded successfully.")
# Interactive testing menu
print("\n" + "🎡" + "="*58)
print("🎡 AUDIO UTILS TESTING SUITE")
print("🎡" + "="*58)
import sys
if len(sys.argv) > 1:
# Command line mode
mode = sys.argv[1].lower()
if mode == "all":
run_all_tests()
elif mode == "encoding":
test_encode_audio_functions()
elif mode == "wav":
test_numpy_to_wav_conversion()
elif mode == "handler":
test_gemini_handler_initialization()
elif mode == "transcription":
test_transcription_function_validation()
elif mode == "text":
test_text_input_processing()
elif mode == "tts":
test_tts_function_structure()
elif mode == "main":
test_main_processing_function()
elif mode == "env":
test_environment_and_config()
elif mode == "performance":
run_performance_benchmarks()
elif mode == "integration":
run_integration_tests()
else:
print(f"Unknown test mode: {mode}")
print("Available modes: all, encoding, wav, handler, transcription, text, tts, main, env, performance, integration")
else:
# Interactive mode
while True:
print("\n" + "🎡" + " "*20 + "TEST MENU" + " "*20 + "🎡")
print("1. 🌐 Run All Tests")
print("2. πŸ”§ Environment & Config")
print("3. 🎧 Audio Encoding Functions")
print("4. 🎡 WAV Conversion")
print("5. πŸ€– GeminiHandler Tests")
print("6. 🎀 Transcription Validation")
print("7. πŸ“ Text Processing")
print("8. πŸ”Š TTS Function Structure")
print("9. πŸŽ›οΈ Main Processing Function")
print("10. ⚑ Performance Benchmarks")
print("11. πŸ”— Integration Tests")
print("12. πŸ§ͺ Create Test Audio File")
print("0. πŸšͺ Exit")
choice = input("\nEnter your choice (0-12): ").strip()
if choice == "1":
run_all_tests()
elif choice == "2":
test_environment_and_config()
elif choice == "3":
test_encode_audio_functions()
elif choice == "4":
test_numpy_to_wav_conversion()
elif choice == "5":
test_gemini_handler_initialization()
elif choice == "6":
test_transcription_function_validation()
elif choice == "7":
test_text_input_processing()
elif choice == "8":
test_tts_function_structure()
elif choice == "9":
test_main_processing_function()
elif choice == "10":
run_performance_benchmarks()
elif choice == "11":
run_integration_tests()
elif choice == "12":
filename = create_test_audio_file()
if filename:
print(f"βœ… Test audio file created: {filename}")
else:
print("❌ Failed to create test audio file")
elif choice == "0":
print("\nπŸ‘‹ Audio testing complete!")
break
else:
print("❌ Invalid choice. Please try again.")
input("\nPress Enter to continue...")