Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import librosa | |
| from transformers import AutoModelForAudioClassification, AutoFeatureExtractor | |
| import torch | |
| import numpy as np | |
| import os | |
# Genre classifier: a DistilHuBERT checkpoint fine-tuned on GTZAN, loaded once
# at import time together with its matching feature extractor.
model_name = "sanchit-gandhi/distilhubert-finetuned-gtzan"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
model = AutoModelForAudioClassification.from_pretrained(model_name)

# Genre names, indexed by the position of the winning class in the model output.
genres = [
    "blues",
    "classical",
    "country",
    "disco",
    "hiphop",
    "jazz",
    "metal",
    "pop",
    "reggae",
    "rock",
]
def _estimate_tempo(audio, sr):
    """Return the tempo estimate for *audio* (sampled at *sr* Hz) in BPM."""
    # librosa 0.10 moved ``tempo`` to ``librosa.feature.rhythm`` and deprecated
    # the ``librosa.beat.tempo`` alias; prefer the new location when it exists
    # and fall back to the old one on earlier releases.
    try:
        from librosa.feature.rhythm import tempo as tempo_fn
    except ImportError:
        tempo_fn = librosa.beat.tempo
    # The estimator returns a one-element array; unwrap it to a plain float.
    return float(tempo_fn(y=audio, sr=sr)[0])


# Function to process the uploaded audio file
def process_audio(audio_file, user_label):
    """Build a metadata record for one uploaded audio file.

    The file is loaded at its native sample rate to measure duration and
    tempo, then resampled to 16 kHz (if needed) and passed through the genre
    classifier.

    Args:
        audio_file: Path of the uploaded audio file, or ``None``/empty when
            nothing was uploaded.
        user_label: Free-text label supplied by the user; may be empty or
            ``None``.

    Returns:
        A dict with the keys ``filename``, ``duration``, ``description``,
        ``genre``, ``tempo`` and ``sample_rate`` on success, or
        ``{"error": <message>}`` when no file was given or processing failed.
    """
    # Nothing uploaded: report it clearly instead of surfacing a TypeError
    # from the path handling below.
    if not audio_file:
        return {"error": "No audio file provided."}

    try:
        # Extract filename from the uploaded file path
        filename = os.path.basename(audio_file)

        # Load the audio file with its native sample rate
        audio, sr = librosa.load(audio_file, sr=None)

        # Duration and tempo are measured on the native-rate signal
        duration = librosa.get_duration(y=audio, sr=sr)
        tempo = _estimate_tempo(audio, sr)

        # Preprocess audio for the model (resample to 16kHz if needed)
        target_sr = 16000
        if sr != target_sr:
            model_audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)
        else:
            model_audio = audio
        inputs = feature_extractor(
            model_audio, sampling_rate=target_sr, return_tensors="pt"
        )

        # Predict genre using the model
        with torch.no_grad():
            logits = model(**inputs).logits
        probabilities = torch.nn.functional.softmax(logits, dim=-1).squeeze().numpy()
        predicted_genre = genres[int(np.argmax(probabilities))]

        # Description: "<label>, <genre>, <tempo> BPM, <sample rate> Hz".
        # A blank label is dropped so the text never starts with ", " or "None".
        label = (user_label or "").strip()
        description_parts = [predicted_genre, f"{tempo:.2f} BPM", f"{sr} Hz"]
        if label:
            description_parts.insert(0, label)
        description = ", ".join(description_parts)

        # Plain Python numbers keep the JSON output independent of NumPy types.
        return {
            "filename": filename,
            "duration": float(np.round(duration, 3)),
            "description": description,
            "genre": predicted_genre,
            "tempo": float(np.round(tempo, 2)),
            "sample_rate": int(sr),
        }
    except Exception as e:
        # UI boundary: show the failure in the JSON panel rather than crashing.
        return {"error": str(e)}
# Gradio interface: an audio upload and a free-text label are passed to
# process_audio, and the dict it returns is shown in a JSON panel.
# NOTE(review): the original indentation was lost in this copy; the nesting
# below (both inputs inside one row, button and output beneath it) is a
# reconstruction — confirm against the deployed layout.
with gr.Blocks(theme="Surn/beeuty") as app:
    gr.Markdown("# Audio Classifier for MusicGen Fine Tuning")
    gr.Markdown("Upload an audio file (preferred `.wav`), provide a label, and get metadata for MusicGen training.")
    with gr.Row():
        # type="filepath" makes Gradio pass the upload to the callback as a path
        audio_input = gr.Audio(label="Upload Audio File", type="filepath")
        label_input = gr.Textbox(label="Enter Label", placeholder="e.g., A calm melody")
    submit_button = gr.Button("Classify")
    output_json = gr.JSON(label="Metadata Output")
    submit_button.click(process_audio, inputs=[audio_input, label_input], outputs=output_json)

# Launch the app
app.launch()