import gradio as gr
import os
import allin1
import time
import json
import torch
import librosa
import numpy as np
from pathlib import Path

HEADER = """
<header style="text-align: center;">
  <h1>
    All-In-One Music Structure Analyzer 🔮
  </h1>
  <p>
    <a href="https://github.com/mir-aidj/all-in-one">[Python Package]</a>
    <a href="https://arxiv.org/abs/2307.16425">[Paper]</a>
    <a href="https://taejun.kim/music-dissector/">[Visual Demo]</a>
  </p>
</header>
<main
  style="display: flex; justify-content: center;"
>
  <div
    style="display: inline-block;"
  >
    <p>
      This Space demonstrates the music structure analyzer, which predicts:
      <ul
        style="padding-left: 1rem;"
      >
        <li>BPM</li>
        <li>Beats</li>
        <li>Downbeats</li>
        <li>Functional segment boundaries</li>
        <li>Functional segment labels (e.g. intro, verse, chorus, bridge, outro)</li>
      </ul>
    </p>
    <p>
      For more information, please visit the links above ✨🧸
    </p>
  </div>
</main>
"""
CACHE_EXAMPLES = os.getenv('CACHE_EXAMPLES', '1') == '1'
base_dir = "/tmp/gradio/"
# Sample rate for voice activity detection (must be a multiple of 8 kHz).
SAMPLING_RATE = 32000
torch.set_num_threads(1)

# Load the Silero VAD model used for voice detection.
model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad:v4.0',
                              model='silero_vad',
                              force_reload=True)
(get_speech_timestamps,
 save_audio,
 read_audio,
 VADIterator,
 collect_chunks) = utils
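
# NOTE: the helpers unpacked above come from the silero-vad repo. add_voice_label() below
# computes per-window speech probabilities manually; for reference, the bundled helper can
# do it in one call, roughly as sketched here (keyword names may vary between silero-vad
# versions, and 'song.wav' is a placeholder path):
#
#   wav = read_audio('song.wav', sampling_rate=SAMPLING_RATE)
#   speech_timestamps = get_speech_timestamps(wav, model, sampling_rate=SAMPLING_RATE)
#   # -> [{'start': <sample index>, 'end': <sample index>}, ...]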


def analyze(path):
    # Measure time for inference
    start = time.time()
    string_path = path
    path = Path(path)
    result = allin1.analyze(
        path,
        out_dir='./struct',
        multiprocess=False,
        keep_byproducts=True,  # TODO: remove this
    )
    # Locate the structure JSON that allin1 wrote to ./struct for this track.
    json_structure_output = None
    for root, dirs, files in os.walk("./struct"):
        for file_path in files:
            json_structure_output = os.path.join(root, file_path)
            print(json_structure_output)
    # Annotate the JSON with vocal activity segments.
    add_voice_label(json_structure_output, string_path)
    fig = allin1.visualize(
        result,
        multiprocess=False,
    )
    fig.set_dpi(300)

    #allin1.sonify(
    #    result,
    #    out_dir='./sonif',
    #    multiprocess=False,
    #)
    #sonif_path = Path(f'./sonif/{path.stem}.sonif{path.suffix}').resolve().as_posix()

    # Measure time for inference
    end = time.time()
    elapsed_time = end - start

    # Get the base name of the file
    file_name = os.path.basename(path)
    # Remove the extension from the file name
    file_name_without_extension = os.path.splitext(file_name)[0]
    print(file_name_without_extension)
    # Collect the separated stems written by the demixing step (HTDemucs).
    bass_path, drums_path, other_path, vocals_path = None, None, None, None
    for root, dirs, files in os.walk(f"./demix/htdemucs/{file_name_without_extension}"):
        for file_path in files:
            file_path = os.path.join(root, file_path)
            print(file_path)
            if "bass.wav" in file_path:
                bass_path = file_path
            if "vocals.wav" in file_path:
                vocals_path = file_path
            if "other.wav" in file_path:
                other_path = file_path
            if "drums.wav" in file_path:
                drums_path = file_path

    #return result.bpm, fig, sonif_path, elapsed_time
    return result.bpm, fig, elapsed_time, json_structure_output, bass_path, drums_path, other_path, vocals_path
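
# Example of calling analyze() directly, outside of Gradio (hypothetical local file path);
# it returns the same tuple that is wired to the UI outputs below:
#
#   bpm, fig, secs, json_path, bass, drums, other, vocals = analyze('./assets/song.mp3')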


def aggregate_vocal_times(vocal_times):
    """
    Aggregates multiple vocal segments into a single segment. This is done because
    the detected segments are usually very short (<3 seconds) sections of audio.
    """
    # NEXT_SEGMENT_SECONDS is a hyperparameter for the aggregation: segments keep being
    # merged until we find one whose start_time is more than NEXT_SEGMENT_SECONDS after
    # the end_time of the previous segment.
    NEXT_SEGMENT_SECONDS = 5
    try:
        start_time = 0.0
        end_time = 0.0
        begin_seq = True
        compressed_vocal_times = []
        for vocal_time in vocal_times:
            if begin_seq:
                start_time = vocal_time['start_time']
                end_time = vocal_time['end_time']
                begin_seq = False
                continue
            if float(vocal_time['start_time']) < float(end_time) + NEXT_SEGMENT_SECONDS:
                end_time = vocal_time['end_time']
            else:
                print(start_time, end_time)
                compressed_vocal_times.append({
                    "start_time": f"{start_time}",
                    "end_time": f"{end_time}"
                })
                start_time = vocal_time['start_time']
                end_time = vocal_time['end_time']
        # Flush the last open segment (skipped when the input list is empty).
        if not begin_seq:
            compressed_vocal_times.append({
                "start_time": f"{start_time}",
                "end_time": f"{end_time}"
            })
    except Exception as e:
        print(f"An exception occurred: {e}")
    return compressed_vocal_times
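
# A quick illustration of the aggregation above (hypothetical timings): with
# NEXT_SEGMENT_SECONDS = 5, segments separated by gaps shorter than 5 seconds are merged.
#
#   aggregate_vocal_times([
#       {"start_time": "0.00", "end_time": "2.00"},
#       {"start_time": "4.00", "end_time": "6.50"},    # starts < 2.00 + 5 -> merged
#       {"start_time": "20.00", "end_time": "22.00"},  # starts > 6.50 + 5 -> new segment
#   ])
#   # -> [{'start_time': '0.00', 'end_time': '6.50'},
#   #     {'start_time': '20.00', 'end_time': '22.00'}]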


def add_voice_label(json_file, audio_path):
    # THRESHOLD_PROBABILITY is a hyperparameter that determines whether a window is
    # considered voice or non-voice.
    THRESHOLD_PROBABILITY = 0.75

    # Load the JSON file
    with open(json_file, 'r') as f:
        data = json.load(f)

    # Create VAD object
    vad_iterator = VADIterator(model)
    # Read input audio file
    wav, _ = librosa.load(audio_path, sr=SAMPLING_RATE, mono=True)
    speech_probs = []
    # Size of the window the speech probability is computed on. This is a hyperparameter
    # of the detection and can be changed to obtain different results; I found this value
    # to work well. 32000 / 4 = 8000 samples, i.e. 0.25 s per window at SAMPLING_RATE.
    window_size_samples = int(SAMPLING_RATE / 4)
    for i in range(0, len(wav), window_size_samples):
        chunk = torch.from_numpy(wav[i: i + window_size_samples])
        if len(chunk) < window_size_samples:
            break
        speech_prob = model(chunk, SAMPLING_RATE).item()
        speech_probs.append(speech_prob)
    vad_iterator.reset_states()  # reset model states after each audio

    # Indices of the windows whose speech probability is above the threshold.
    voice_idxs = np.where(np.array(speech_probs) >= THRESHOLD_PROBABILITY)[0]
    print(len(voice_idxs))
    if len(voice_idxs) == 0:
        print("NO VOICE SEGMENTS DETECTED!")

    try:
        begin_seq = True
        start_idx = 0
        vocal_times = []
        # Group consecutive voiced windows into (start_time, end_time) segments.
        for i in range(len(voice_idxs) - 1):
            if begin_seq:
                start_idx = voice_idxs[i]
                begin_seq = False
            if voice_idxs[i + 1] == voice_idxs[i] + 1:
                continue
            start_time = float((start_idx * window_size_samples) / SAMPLING_RATE)
            end_time = float((voice_idxs[i] * window_size_samples) / SAMPLING_RATE)
            vocal_times.append({
                "start_time": f"{start_time:.2f}",
                "end_time": f"{end_time:.2f}"
            })
            begin_seq = True
        # Flush the final run if the audio ends during a voiced segment.
        if not begin_seq and len(voice_idxs) > 0:
            start_time = float((start_idx * window_size_samples) / SAMPLING_RATE)
            end_time = float((voice_idxs[-1] * window_size_samples) / SAMPLING_RATE)
            vocal_times.append({
                "start_time": f"{start_time:.2f}",
                "end_time": f"{end_time:.2f}"
            })
        vocal_times = aggregate_vocal_times(vocal_times)
        data['vocal_times'] = vocal_times
    except Exception as e:
        print(f"An exception occurred: {e}")

    with open(json_file, 'w') as f:
        print("writing_to_json...")
        json.dump(data, f, indent=4)
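
# After add_voice_label() runs, the structure JSON written by allin1 gains a 'vocal_times'
# key alongside the analyzer's own fields. A sketch of the resulting file (the other keys
# come from allin1 and may vary by version; values are illustrative):
#
#   {
#       "bpm": 100,
#       "beats": [...],
#       "downbeats": [...],
#       "segments": [...],
#       "vocal_times": [
#           {"start_time": "12.25", "end_time": "45.50"},
#           ...
#       ]
#   }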


with gr.Blocks() as demo:
    gr.HTML(HEADER)
    input_audio_path = gr.Audio(
        label='Input',
        type='filepath',
        format='mp3',
        show_download_button=False,
    )
    button = gr.Button('Analyze', variant='primary')
    output_viz = gr.Plot(label='Visualization')
    with gr.Row():
        output_bpm = gr.Textbox(label='BPM', scale=1)
        #output_sonif = gr.Audio(
        #    label='Sonification',
        #    type='filepath',
        #    format='mp3',
        #    show_download_button=False,
        #    scale=9,
        #)
        elapsed_time = gr.Textbox(label='Overall inference time', scale=1)
    json_structure_output = gr.File(label="Json structure")
    with gr.Column():
        bass = gr.Audio(label='bass', show_share_button=False)
        vocals = gr.Audio(label='vocals', show_share_button=False)
        other = gr.Audio(label='other', show_share_button=False)
        drums = gr.Audio(label='drums', show_share_button=False)
    #bass_path = gr.Textbox(label='bass_path', scale=1)
    #drums_path = gr.Textbox(label='drums_path', scale=1)
    #other_path = gr.Textbox(label='other_path', scale=1)
    #vocals_path = gr.Textbox(label='vocals_path', scale=1)

    #gr.Examples(
    #    examples=[
    #        './assets/NewJeans - Super Shy.mp3',
    #        './assets/Bruno Mars - 24k Magic.mp3'
    #    ],
    #    inputs=input_audio_path,
    #    outputs=[output_bpm, output_viz, output_sonif],
    #    fn=analyze,
    #    cache_examples=CACHE_EXAMPLES,
    #)

    button.click(
        fn=analyze,
        inputs=input_audio_path,
        #outputs=[output_bpm, output_viz, output_sonif, elapsed_time],
        outputs=[output_bpm, output_viz, elapsed_time, json_structure_output, bass, drums, other, vocals],
        api_name='analyze',
    )

if __name__ == '__main__':
    demo.launch()