Qwen2-Audio-7B / app.py
desiree's picture
Upload app.py
c69cd11 verified
raw
history blame
5.7 kB
import base64
import functools
import os
import sys
from io import BytesIO
from pathlib import Path

# NOTE: third-party import order kept as in the original (spaces before
# transformers/torch) — presumably relevant for Hugging Face ZeroGPU; confirm
# before reordering.
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import soundfile as sf
import numpy as np
# Model and Tokenizer Loading
# Hugging Face Hub id of the checkpoint that load_model() loads. Both the model
# and the tokenizer are loaded with trust_remote_code=True, i.e. the classes
# come from Python code shipped in this repository rather than from
# transformers itself.
# NOTE(review): the Space is titled "Qwen2-Audio-7B" but this is the
# first-generation Qwen-Audio-Chat checkpoint — confirm which one is intended.
MODEL_ID = "Qwen/Qwen-Audio-Chat"
@functools.lru_cache(maxsize=1)
def load_model():
    """Load the model and tokenizer for ``MODEL_ID`` once and reuse them.

    ``analyze_audio`` calls this on every request. Memoizing the result means
    the weights are read and placed on the device only on the first call; later
    calls return the same ``(model, tokenizer)`` pair.

    NOTE(review): if ``@spaces.GPU`` runs the caller in a separate worker
    process (ZeroGPU), the cache may not persist between requests — consider
    loading at import time there.

    Returns:
        tuple: ``(model, tokenizer)``. The model is loaded in float16 with
        ``device_map="auto"``. Both objects use ``trust_remote_code=True``
        because the checkpoint ships its own model/tokenizer code.
    """
    print("Loading model and tokenizer...")
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
    print("Model and tokenizer loaded successfully")
    return model, tokenizer
def process_audio(audio_path):
    """Read an audio file and re-encode it as a base64 mono WAV payload.

    Multi-channel audio is down-mixed to mono by averaging the channels. The
    samples are cast to float32 and written to an in-memory WAV file.

    Args:
        audio_path: Path to an audio file readable by ``soundfile``.

    Returns:
        dict | None: ``{"audio": <base64-encoded WAV str>, "sampling_rate":
        <sample rate reported by soundfile>}`` on success. Returns ``None`` if
        the file contains no samples or cannot be read/encoded. Errors are
        printed with a traceback instead of raised, so callers only need a
        ``None`` check.
    """
    try:
        print(f"Processing audio file: {audio_path}")
        audio_data, sample_rate = sf.read(audio_path)

        # Reject empty clips (e.g. a recording stopped immediately) here, so the
        # caller reports a processing failure instead of passing along a
        # zero-length WAV.
        if audio_data.size == 0:
            print(f"Audio file contains no samples: {audio_path}")
            return None

        # soundfile returns (frames, channels) for multi-channel input; average
        # across channels to get a mono (frames,) signal.
        if audio_data.ndim > 1:
            audio_data = audio_data.mean(axis=1)
        audio_data = audio_data.astype(np.float32)

        # Encode into an in-memory WAV and base64 it; the buffer is closed
        # as soon as its bytes have been taken.
        with BytesIO() as audio_buffer:
            sf.write(audio_buffer, audio_data, sample_rate, format='WAV')
            audio_base64 = base64.b64encode(audio_buffer.getvalue()).decode('utf-8')

        print(f"Audio processed successfully. Sample rate: {sample_rate}, Shape: {audio_data.shape}")
        return {
            "audio": audio_base64,
            "sampling_rate": sample_rate,
        }
    except Exception as e:
        # Deliberate best-effort boundary: the caller turns None into a
        # user-facing message.
        print(f"Error processing audio: {e}")
        import traceback
        traceback.print_exc()
        return None
@spaces.GPU
def analyze_audio(audio_path: str, question: str = None) -> str:
    """Answer a question about an audio clip using Qwen-Audio-Chat.

    This is the function exposed through the Gradio interface.

    Args:
        audio_path: Path to the audio file on local disk.
        question: Optional question about the audio. ``None``, ``""`` or
            whitespace-only input falls back to a generic "describe" prompt.

    Returns:
        str: The model's reply. On invalid input or any failure, a
        human-readable error message is returned instead of raising, so the UI
        always has text to display.
    """
    print(f"\nStarting analysis with audio_path: {audio_path}, question: {question}")

    # Input validation
    if audio_path is None or not isinstance(audio_path, str):
        return "Please provide a valid audio file."
    if not os.path.exists(audio_path):
        return f"Audio file not found: {audio_path}"

    # Decode the file up front so unreadable or empty uploads get a clear
    # message before any model work happens. The model itself is given the
    # file path below, so the returned payload is not needed.
    if process_audio(audio_path) is None:
        return "Failed to process the audio file. Please ensure it's a valid audio format."

    try:
        model, tokenizer = load_model()
        query_text = (question or "").strip() or "Please describe what you hear in this audio clip."

        # Use the multimodal API documented for Qwen-Audio-Chat. The
        # remote-code tokenizer builds the prompt from {"audio": path} /
        # {"text": ...} segments, and model.chat() returns only the newly
        # generated reply, not the prompt.
        # A plain apply_chat_template + generate call on the tokenizer has no
        # way to feed the audio to the model's audio encoder.
        # NOTE(review): no sampling parameters are passed here; model.chat
        # presumably uses the checkpoint's generation config — confirm, and
        # pass generation_config= if specific settings are required.
        print("Building multimodal query...")
        query = tokenizer.from_list_format([
            {"audio": audio_path},
            {"text": query_text},
        ])

        print("Generating response...")
        with torch.no_grad():
            response, _history = model.chat(tokenizer, query=query, history=None)

        if not response:
            print("Model generated an empty response")
            return "The model failed to generate a response. Please try again."

        print(f"Generated response: {response[:200]}...")
        return response
    except Exception as e:
        # UI boundary: log the full traceback, show a short message to the user.
        print(f"Error during processing: {str(e)}")
        import traceback
        traceback.print_exc()
        return f"An error occurred while processing: {str(e)}"
# Example clips for the UI. These paths are placeholders: only entries whose
# audio file actually exists are handed to Gradio. Gradio processes example
# files when it builds the interface, so a missing path can fail at startup.
_EXAMPLE_CANDIDATES = [
    ["path/to/example1.wav", "What instruments do you hear?"],
    ["path/to/example2.wav", "Describe the mood of this audio."],
]
_EXAMPLES = [example for example in _EXAMPLE_CANDIDATES if Path(example[0]).is_file()] or None

# Create Gradio interface with clear input/output specifications
demo = gr.Interface(
    fn=analyze_audio,
    inputs=[
        gr.Audio(
            type="filepath",
            label="Audio Input",
            sources=["upload", "microphone"],
            # Lossless WAV is readable by every libsndfile build that
            # process_audio may run on (MP3 needs libsndfile >= 1.1). It also
            # avoids lossy re-encoding of the clip before analysis.
            format="wav",
        ),
        gr.Textbox(
            label="Question",
            placeholder="Optional: Ask a specific question about the audio",
            value="",
        ),
    ],
    outputs=gr.Textbox(label="Analysis"),
    title="Qwen Audio Analysis Tool",
    description="Upload an audio file or record from microphone to get AI-powered analysis using Qwen-Audio-Chat model",
    examples=_EXAMPLES,
    cache_examples=False,
)

if __name__ == "__main__":
    demo.launch()