import gradio as gr
import spaces  # provides the @spaces.GPU decorator on ZeroGPU hardware
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import soundfile as sf
import numpy as np
import os
import base64
from io import BytesIO
from typing import Optional
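# Expected dependencies (a plausible requirements.txt for this Space; exact
# pins are an assumption, not taken from the original file):
#   gradio, spaces, transformers, torch, soundfile, numpy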
# Model and tokenizer loading
MODEL_ID = "Qwen/Qwen-Audio-Chat"

_model = None
_tokenizer = None


def load_model():
    """Load the model and tokenizer once, caching them so repeated calls are cheap."""
    global _model, _tokenizer
    if _model is None:
        print("Loading model and tokenizer...")
        _model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        )
        _tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
        print("Model and tokenizer loaded successfully")
    return _model, _tokenizer
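# Note: Qwen-Audio-Chat pairs a roughly 7B-parameter LLM with an audio encoder,
# so the float16 weights need on the order of 16 GB of accelerator memory;
# device_map="auto" lets accelerate place layers across whatever devices exist.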
def process_audio(audio_path):
    """Read an audio file, downmix to mono float32, and return it as base64-encoded WAV."""
    try:
        print(f"Processing audio file: {audio_path}")
        # Read audio file
        audio_data, sample_rate = sf.read(audio_path)
        # Convert to mono if stereo
        if len(audio_data.shape) > 1:
            audio_data = audio_data.mean(axis=1)
        # Ensure float32 format
        audio_data = audio_data.astype(np.float32)
        # Write audio to an in-memory buffer in WAV format
        audio_buffer = BytesIO()
        sf.write(audio_buffer, audio_data, sample_rate, format='WAV')
        # Get the buffer content and encode to base64
        audio_buffer.seek(0)
        audio_base64 = base64.b64encode(audio_buffer.read()).decode('utf-8')
        print(f"Audio processed successfully. Sample rate: {sample_rate}, Shape: {audio_data.shape}")
        return {
            "audio": audio_base64,
            "sampling_rate": sample_rate
        }
    except Exception as e:
        print(f"Error processing audio: {e}")
        import traceback
        traceback.print_exc()
        return None
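# Illustrative shape of the dict process_audio returns (values are examples,
# not real output):
#   {"audio": "UklGRi4...<base64 WAV bytes>...", "sampling_rate": 44100}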
def analyze_audio(audio_path: Optional[str], question: Optional[str] = None) -> str:
    """
    Main function for audio analysis that will be exposed as a tool.

    Args:
        audio_path: Path to the audio file
        question: Optional question about the audio

    Returns:
        str: Model's response about the audio
    """
    print(f"\nStarting analysis with audio_path: {audio_path}, question: {question}")

    # Input validation
    if audio_path is None or not isinstance(audio_path, str):
        return "Please provide a valid audio file."
    if not os.path.exists(audio_path):
        return f"Audio file not found: {audio_path}"

    # Process audio
    audio_data = process_audio(audio_path)
    if audio_data is None:
        return "Failed to process the audio file. Please ensure it's a valid audio format."

    try:
        model, tokenizer = load_model()
        query = question if question else "Please describe what you hear in this audio clip."

        print("Preparing messages...")
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "audio",
                        "data": audio_data["audio"],
                        "sampling_rate": audio_data["sampling_rate"]
                    },
                    {
                        "type": "text",
                        "text": query
                    }
                ]
            }
        ]

        print("Applying chat template...")
        text = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        print(f"Generated prompt text: {text[:200]}...")

        print("Tokenizing input...")
        model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

        print("Generating response...")
        with torch.no_grad():
            outputs = model.generate(
                **model_inputs,
                max_new_tokens=512,
                temperature=0.7,
                do_sample=True,
                pad_token_id=tokenizer.pad_token_id,
                bos_token_id=tokenizer.bos_token_id,
                eos_token_id=tokenizer.eos_token_id
            )
        if outputs is None:
            print("Model generated None output")
            return "The model failed to generate a response. Please try again."

        print(f"Output shape: {outputs.shape}")
        # Decode only the newly generated tokens, not the echoed prompt
        response = tokenizer.decode(
            outputs[0][model_inputs.input_ids.shape[1]:],
            skip_special_tokens=True
        )
        print(f"Generated response: {response[:200]}...")
        return response
    except Exception as e:
        print(f"Error during processing: {str(e)}")
        import traceback
        traceback.print_exc()
        return f"An error occurred while processing: {str(e)}"
# Create Gradio interface with clear input/output specifications
demo = gr.Interface(
    fn=analyze_audio,
    inputs=[
        gr.Audio(
            type="filepath",
            label="Audio Input",
            sources=["upload", "microphone"],
            format="wav"  # keep recordings as WAV so soundfile can always read them
        ),
        gr.Textbox(
            label="Question",
            placeholder="Optional: Ask a specific question about the audio",
            value=""
        )
    ],
    outputs=gr.Textbox(label="Analysis"),
    title="Qwen Audio Analysis Tool",
    description="Upload an audio file or record from the microphone to get AI-powered analysis from the Qwen-Audio-Chat model",
    examples=[
        ["path/to/example1.wav", "What instruments do you hear?"],
        ["path/to/example2.wav", "Describe the mood of this audio."]
    ],
    cache_examples=False
)
if __name__ == "__main__":
    demo.launch()
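# Client-side sketch for querying the deployed Space, assuming the
# gradio_client package and a hypothetical Space ID "user/qwen-audio-analysis":
#
#   from gradio_client import Client, handle_file
#   client = Client("user/qwen-audio-analysis")  # hypothetical Space ID
#   result = client.predict(
#       handle_file("sample.wav"),               # audio input
#       "What instruments do you hear?",         # question
#       api_name="/predict",                     # default endpoint for gr.Interface
#   )
#   print(result)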