# final_agent_course/utils/audio_tool.py
# Last commit: "use langchain" (040a6c6) by tuan3335
"""
AUDIO PROCESSING TOOL - Groq Audio Only
Handles audio file transcription using Groq Whisper API
"""
import os
import tempfile
import requests
from typing import Dict, Any, Optional
from groq import Groq
from .state_manager import get_agent_state
class AudioTool:
    """Transcribe audio with the Groq Whisper API.

    Input may be an http(s) URL, a path to an existing local file, or a
    base64-encoded payload. The most recent result is also stored in the
    shared agent state under ``cached_data["audio_analysis"]``.
    """

    def __init__(self):
        # Falls back to an empty key when GROQ_API_KEY is unset.
        self.client = Groq(api_key=os.environ.get("GROQ_API_KEY", ""))
        self.model = "whisper-large-v3"
        print("🎵 Audio Tool (Groq Whisper) initialized")

    def process_audio(self, audio_input: str, **kwargs) -> Dict[str, Any]:
        """Transcribe ``audio_input`` and return a result dictionary.

        Args:
            audio_input: An http(s) URL, a local file path, or base64 audio.
            **kwargs: Accepted for interface compatibility; currently unused.

        Returns:
            On success: ``{"success": True, "data": {...}, "summary": ...}``.
            On failure: the dictionary produced by ``_error_result``.
            This method does not raise; every exception is converted into an
            error result.
        """
        audio_path: Optional[str] = None
        owns_file = False
        try:
            audio_path = self._prepare_audio_file(audio_input)
            if not audio_path:
                return self._error_result("Could not prepare audio file")

            # A local input is returned unchanged by _prepare_audio_file; any
            # other path is a temp file created by this call, and only those
            # may be deleted afterwards.
            owns_file = audio_path != audio_input

            # Transcribe using Groq Whisper
            transcript = self._transcribe_with_groq(audio_path)

            result = {
                "transcript": transcript,
                "source": audio_input,
                "model": self.model,
                "tool": "groq_whisper"
            }

            # Update agent state
            state = get_agent_state()
            state.cached_data["audio_analysis"] = result

            return {
                "success": True,
                "data": result,
                "summary": f"Audio transcribed: {transcript[:100]}..."
            }
        except Exception as e:
            error_msg = f"Audio processing failed: {str(e)}"
            print(f"❌ {error_msg}")
            return self._error_result(error_msg)
        finally:
            # Runs on success and on failure so temp files never leak.
            if owns_file:
                self._remove_quietly(audio_path)

    @staticmethod
    def _remove_quietly(path: Optional[str]) -> None:
        """Delete ``path`` if possible; a failed cleanup is only reported."""
        if not path:
            return
        try:
            os.unlink(path)
        except OSError as e:
            print(f"⚠️ Could not remove temp file {path}: {str(e)}")

    def _prepare_audio_file(self, audio_input: str) -> Optional[str]:
        """Resolve ``audio_input`` to a readable file path.

        Returns the path of a downloaded/decoded temp file, the input itself
        when it is an existing local path, or ``None`` when the input cannot
        be interpreted or preparation fails.
        """
        try:
            # If it's a URL, download it
            if audio_input.startswith(('http://', 'https://')):
                return self._download_audio(audio_input)
            # If it's a local file path
            if os.path.exists(audio_input):
                return audio_input
            # If it's base64, decode it
            if self._is_base64(audio_input):
                return self._decode_base64_audio(audio_input)
            return None
        except Exception as e:
            print(f"⚠️ Audio prep error: {str(e)}")
            return None

    @staticmethod
    def _suffix_from_url(url: str, default: str = '.mp3') -> str:
        """Return the file extension found in the URL path, or ``default``.

        Only the path component is inspected, so dots in the host name or in
        the query string cannot produce a bogus suffix (which could contain a
        path separator and break temp-file creation).
        """
        from urllib.parse import urlparse

        ext = os.path.splitext(urlparse(url).path)[1]
        if 1 < len(ext) <= 6 and ext[1:].isalnum():
            return ext
        return default

    def _download_audio(self, url: str) -> str:
        """Download ``url`` into a temp file and return the file's path.

        Raises whatever ``requests`` or the file system raises; a partially
        written temp file is removed before the exception propagates.
        """
        suffix = self._suffix_from_url(url)
        tmp_path: Optional[str] = None
        try:
            # The timeout keeps a stalled server from hanging the caller; the
            # context manager releases the streamed connection.
            with requests.get(url, stream=True, timeout=30) as response:
                response.raise_for_status()
                with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
                    tmp_path = tmp_file.name
                    for chunk in response.iter_content(chunk_size=8192):
                        tmp_file.write(chunk)
            return tmp_path
        except Exception:
            self._remove_quietly(tmp_path)
            raise

    def _is_base64(self, s: str) -> bool:
        """Return True when ``s`` is non-empty, canonical base64 text."""
        import base64

        try:
            if isinstance(s, str):
                s_bytes = bytes(s, 'ascii')
            elif isinstance(s, bytes):
                s_bytes = s
            else:
                return False
            # An empty payload round-trips successfully but holds no audio.
            if not s_bytes:
                return False
            # Round-trip check: only canonical base64 re-encodes to itself.
            return base64.b64encode(base64.b64decode(s_bytes)) == s_bytes
        except Exception:
            return False

    def _decode_base64_audio(self, b64_string: str) -> str:
        """Decode base64 audio into a ``.mp3`` temp file and return its path."""
        import base64

        audio_data = base64.b64decode(b64_string)
        with tempfile.NamedTemporaryFile(delete=False, suffix='.mp3') as tmp_file:
            tmp_file.write(audio_data)
            return tmp_file.name

    def _transcribe_with_groq(self, audio_path: str) -> str:
        """Send the file at ``audio_path`` to Groq and return the transcript."""
        with open(audio_path, "rb") as audio_file:
            transcript = self.client.audio.transcriptions.create(
                file=audio_file,
                model=self.model,
                language="en",  # Hard-coded to English
                response_format="text"
            )
        # The result may be a plain string or an object exposing ``.text``.
        return transcript if isinstance(transcript, str) else transcript.text

    def _error_result(self, error_msg: str) -> Dict[str, Any]:
        """Build the standard failure dictionary for ``error_msg``."""
        return {
            "success": False,
            "error": error_msg,
            "data": None,
            "summary": f"Audio processing failed: {error_msg}"
        }
def download_audio_file(task_id: str) -> Optional[str]:
    """Download the file attached to ``task_id`` from the scoring API.

    The response body is written to a temp file whose suffix is derived from
    the ``Content-Type`` header, falling back to ``.mp3``.

    Args:
        task_id: Identifier appended to the API's ``/files/`` endpoint.

    Returns:
        Path of the temp file (the caller is responsible for deleting it), or
        ``None`` when the server does not answer 200 or any error occurs.
    """
    # Ordered (marker, suffix) pairs matched as substrings of the content
    # type. The first four keep the historical matches; the rest cover
    # registered audio types that would otherwise be saved as ".mp3".
    suffix_by_marker = (
        ('mp3', '.mp3'),
        ('wav', '.wav'),
        ('ogg', '.ogg'),
        ('m4a', '.m4a'),
        ('mpeg', '.mp3'),
        ('mp4', '.m4a'),
        ('flac', '.flac'),
        ('webm', '.webm'),
    )
    try:
        api_url = "https://agents-course-unit4-scoring.hf.space"
        file_url = f"{api_url}/files/{task_id}"
        response = requests.get(file_url, timeout=30)
        if response.status_code != 200:
            return None

        # Determine file extension
        content_type = response.headers.get('content-type', '').lower()
        suffix = '.mp3'  # Default, also used for non-audio content types
        if 'audio' in content_type:
            suffix = next(
                (ext for marker, ext in suffix_by_marker if marker in content_type),
                '.mp3',
            )

        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp_file:
            tmp_file.write(response.content)
            return tmp_file.name
    except Exception as e:
        print(f"Error downloading audio: {e}")
        return None
def transcribe_audio_groq(task_id: str = "", audio_path: str = "", language: str = "en") -> str:
    """Transcribe audio with the Groq Whisper API (model ``whisper-large-v3``).

    An existing ``audio_path`` takes precedence; otherwise the file for
    ``task_id`` is downloaded and removed again once the call finishes.

    Args:
        task_id: ID used to download the file from the API.
        audio_path: Path to a local audio file (if available).
        language: Transcription language (default: "en").

    Returns:
        The transcribed text, or a message starting with ``"Error:"`` /
        ``"Audio transcription error:"`` on failure. This function does not
        raise.
    """
    downloaded_path = None  # Set only when this call fetched the file itself
    try:
        groq_api_key = os.environ.get("GROQ_API_KEY")
        if not groq_api_key:
            return "Error: GROQ_API_KEY not found in environment variables"

        # Resolve the audio path
        if audio_path and os.path.exists(audio_path):
            target_audio_path = audio_path
        elif task_id:
            downloaded_path = download_audio_file(task_id)
            if not downloaded_path:
                return "Error: Could not download audio file"
            target_audio_path = downloaded_path
        else:
            return "Error: No audio path or task_id provided"

        # Make sure the audio file exists
        if not os.path.exists(target_audio_path):
            return "Error: Audio file not found"

        # The client is created only once there is something to transcribe
        groq_client = Groq(api_key=groq_api_key)

        # Transcribe with Groq Whisper
        with open(target_audio_path, "rb") as audio_file:
            transcription = groq_client.audio.transcriptions.create(
                file=(os.path.basename(target_audio_path), audio_file.read()),
                model="whisper-large-v3",
                response_format="text",
                language=language,
                temperature=0.0  # Deterministic results
            )

        # The response may be an object exposing ``.text`` or a plain value
        if hasattr(transcription, 'text'):
            result = transcription.text
        else:
            result = str(transcription)
        return result.strip()
    except Exception as e:
        return f"Audio transcription error: {str(e)}"
    finally:
        # Remove the downloaded file on success and on failure alike; a
        # caller-supplied audio_path is never deleted.
        if downloaded_path:
            try:
                os.unlink(downloaded_path)
            except OSError:
                pass  # Best-effort cleanup; the result matters more
def transcribe_audio_with_details(task_id: str = "", audio_path: str = "", language: str = "en") -> dict:
    """Transcribe audio and return the text together with basic metadata.

    Args:
        task_id: ID forwarded to ``transcribe_audio_groq``.
        audio_path: Local audio file path forwarded to ``transcribe_audio_groq``.
        language: Transcription language (default: "en").

    Returns:
        A dict with ``"transcription"``, ``"metadata"`` and ``"success"``.
        ``metadata`` gains ``file_size`` and ``file_path`` when ``audio_path``
        exists. This function does not raise.
    """
    # Both prefixes are produced by transcribe_audio_groq for failures: the
    # first for validation problems, the second for exceptions it caught.
    # NOTE: a genuine transcript starting with one of these is misclassified.
    failure_prefixes = ("Error:", "Audio transcription error:")
    try:
        # Get the transcription
        text = transcribe_audio_groq(task_id, audio_path, language)

        # Basic metadata
        metadata = {
            "model": "whisper-large-v3",
            "language": language,
            "provider": "groq"
        }

        # Add file details when a local file is available
        if audio_path and os.path.exists(audio_path):
            metadata["file_size"] = os.path.getsize(audio_path)
            metadata["file_path"] = audio_path

        return {
            "transcription": text,
            "metadata": metadata,
            "success": not text.startswith(failure_prefixes)
        }
    except Exception as e:
        return {
            "transcription": f"Error: {str(e)}",
            "metadata": {},
            "success": False
        }
# Fallback used when Groq is not available
def fallback_audio_info(task_id: str = "", audio_path: str = "") -> str:
    """Describe the audio file when it cannot be transcribed.

    An existing ``audio_path`` takes precedence; otherwise the file for
    ``task_id`` is downloaded and removed again before returning.

    Args:
        task_id: ID used to download the file from the API.
        audio_path: Path to a local audio file (if available).

    Returns:
        A message containing the file size, or a message starting with
        ``"Error:"`` / ``"Audio processing error:"`` on failure. This
        function does not raise.
    """
    downloaded_path = None  # Set only when this call fetched the file itself
    try:
        if audio_path and os.path.exists(audio_path):
            target_audio_path = audio_path
        elif task_id:
            downloaded_path = download_audio_file(task_id)
            if not downloaded_path:
                return "Error: Could not download audio file"
            target_audio_path = downloaded_path
        else:
            return "Error: No audio path or task_id provided"

        # Basic file info
        file_size = os.path.getsize(target_audio_path)
        return f"Audio file detected - Size: {file_size} bytes. Groq transcription not available. Please describe the audio content."
    except Exception as e:
        return f"Audio processing error: {str(e)}"
    finally:
        # Remove the downloaded file even when reading its size failed; a
        # caller-supplied audio_path is never deleted.
        if downloaded_path:
            try:
                os.unlink(downloaded_path)
            except OSError:
                pass  # Best-effort cleanup
# Manual smoke test, run only when the module is executed as a script
if __name__ == "__main__":
    # Try a local audio file first (only when it is present)
    sample_path = "/path/to/test/audio.mp3"
    if not os.path.exists(sample_path):
        print("No test audio found")
    else:
        transcription = transcribe_audio_groq(audio_path=sample_path)
        print("Transcription Result:", transcription)
    # Testing with a task_id requires an API key:
    # result = transcribe_audio_groq(task_id="some_task_id")
    # print("Transcription Result:", result)