import os
import pickle
import torch
import random
import subprocess
import re
import pretty_midi
import gradio as gr
from contextlib import nullcontext
from model import GPTConfig, GPT
from pedalboard import Pedalboard, Reverb, Compressor, Gain, Limiter
from pedalboard.io import AudioFile
import spaces
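
# Detect the Hugging Face Spaces runtime and prepare a scratch directory for outputs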
in_space = os.getenv("SYSTEM") == "spaces"
temp_dir = 'temp'
os.makedirs(temp_dir, exist_ok=True)

# Generation settings
init_from = 'resume'  # 'resume' loads a local checkpoint; 'gpt2*' loads pretrained weights
out_dir = 'checkpoints'
ckpt_load = 'model.pt'
start = "000000000000\n"  # sequence-delimiter token used as the generation prompt
num_samples = 1
max_new_tokens = 768
seed = random.randint(1, 100000)
torch.manual_seed(seed)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'
compile = False
exec(open('configurator.py').read())  # allow overrides from configurator.py
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
device_type = 'cuda' if 'cuda' in device else 'cpu'
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
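
# Load model weights: resume from a local checkpoint or initialize from pretrained GPT-2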
if init_from == 'resume':
    ckpt_path = os.path.join(out_dir, ckpt_load)
    checkpoint = torch.load(ckpt_path, map_location=device, weights_only=True)
    gptconf = GPTConfig(**checkpoint['model_args'])
    model = GPT(gptconf)
    state_dict = checkpoint['model']
    unwanted_prefix = '_orig_mod.'  # strip key prefixes left by torch.compile()
    for k, v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
    model.load_state_dict(state_dict)
elif init_from.startswith('gpt2'):
    model = GPT.from_pretrained(init_from, dict(dropout=0.0))

model.eval()
model.to(device)
if compile:
    model = torch.compile(model)
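
# Tokenizer regex matches the 12-zero sequence delimiter, two-digit note fields,
# and newlines; the stoi/itos vocabulary maps come from the dataset's meta.pkl.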
tokenizer = re.compile(r'000000000000|\d{2}|\n')
meta_path = os.path.join('data', checkpoint['config']['dataset'], 'meta.pkl')
with open(meta_path, 'rb') as f:
    meta = pickle.load(f)
stoi = meta.get('stoi', None)
itos = meta.get('itos', None)

def encode(text):
    matches = tokenizer.findall(text)
    return [stoi[c] for c in matches]

def decode(encoded):
    return ''.join([itos[i] for i in encoded])

def clear_midi(directory):
    for file in os.listdir(directory):
        if file.endswith('.mid'):
            os.remove(os.path.join(directory, file))

clear_midi(temp_dir)
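
# Sample token sequences from the model and parse each line into a note event:
# pitch (2 digits), velocity (2 digits), start tick (4 digits), end tick (4 digits).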
def generate_midi(temperature, top_k):
    start_ids = encode(start)
    x = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]
    midi_events = []
    seq_count = 0
    with torch.no_grad():
        for _ in range(num_samples):
            sequence = []
            y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
            tkn_seq = decode(y[0].tolist())
            lines = tkn_seq.splitlines()
            for event in lines:
                if event.startswith(start.strip()):
                    if sequence:
                        midi_events.append(sequence)
                        sequence = []
                    seq_count += 1
                elif event.strip() == "":
                    continue
                else:
                    try:
                        p = int(event[0:2])
                        v = int(event[2:4])
                        s = int(event[4:8])
                        e = int(event[8:12])
                    except ValueError:
                        p, v, s, e = 0, 0, 0, 0
                    sequence.append({'file_name': f'nanompc_{seq_count:02d}', 'pitch': p, 'velocity': v, 'start': s, 'end': e})
            if sequence:
                midi_events.append(sequence)
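
    # Keep only events that fall inside four 4/4 bars (1536 ticks at 96 PPQ), then
    # drop near-duplicates: same pitch starting within 12 ticks of an earlier note.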
    round_bars = []
    for sequence in midi_events:
        filtered_sequence = []
        for event in sequence:
            if event['start'] < 1536 and event['end'] <= 1536:
                filtered_sequence.append(event)
        if filtered_sequence:
            round_bars.append(filtered_sequence)
    midi_events = round_bars
    for track in midi_events:
        track.sort(key=lambda x: x['start'])
        unique_notes = []
        for note in track:
            if not any(abs(note['start'] - n['start']) < 12 and note['pitch'] == n['pitch'] for n in unique_notes):
                unique_notes.append(note)
        track[:] = unique_notes
    return midi_events
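
# Write the first generated sequence to a one-track MIDI file at the requested tempo
# (96 ticks per quarter note, 4/4 time).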
def write_single_midi(midi_events, bpm):
    midi_data = pretty_midi.PrettyMIDI(initial_tempo=bpm, resolution=96)
    midi_data.time_signature_changes.append(pretty_midi.containers.TimeSignature(4, 4, 0))
    instrument = pretty_midi.Instrument(0)
    midi_data.instruments.append(instrument)
    for event in midi_events[0]:
        pitch = event['pitch']
        velocity = event['velocity']
        start = midi_data.tick_to_time(event['start'])
        end = midi_data.tick_to_time(event['end'])
        note = pretty_midi.Note(pitch=pitch, velocity=velocity, start=start, end=end)
        instrument.notes.append(note)
    midi_path = os.path.join(temp_dir, 'output.mid')
    midi_data.write(midi_path)
    print(f"Generated: {midi_path}")
def render_wav(midi_file, uploaded_sf2=None, output_level='2.0'):
    sf2_dir = 'sf2_kits'
    audio_format = 's16'
    sample_rate = '44100'
    gain = str(output_level)
    if uploaded_sf2:
        sf2_file = uploaded_sf2
    else:
        sf2_files = [f for f in os.listdir(sf2_dir) if f.endswith('.sf2')]
        if not sf2_files:
            raise ValueError("No SoundFont (.sf2) file found in directory.")
        sf2_file = os.path.join(sf2_dir, random.choice(sf2_files))
    output_wav = os.path.join(temp_dir, 'output.wav')
    command = [
        'fluidsynth', '-ni', sf2_file, midi_file, '-F', output_wav, '-r', sample_rate,
        '-o', f'audio.file.format={audio_format}', '-g', gain
    ]
    subprocess.call(command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
    return output_wav
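
# Gradio callback: generate events, write the MIDI, render raw audio, then apply
# a light reverb + compression chain with pedalboard before returning both files.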
def generate_and_return_files(bpm, temperature, top_k, uploaded_sf2=None, output_level='2.0'):
    midi_events = generate_midi(temperature, top_k)
    if not midi_events:
        raise gr.Error("Error generating MIDI.")
    write_single_midi(midi_events, bpm)
    midi_file = os.path.join(temp_dir, 'output.mid')
    wav_raw = render_wav(midi_file, uploaded_sf2, output_level)
    wav_fx = os.path.join(temp_dir, 'output_fx.wav')
    board = Pedalboard([
        Reverb(room_size=0.01, wet_level=random.uniform(0.005, 0.01), dry_level=0.75, width=1.0),
        Compressor(threshold_db=-3.0, ratio=8.0, attack_ms=0.0, release_ms=300.0),
    ])
    with AudioFile(wav_raw) as f:
        with AudioFile(wav_fx, 'w', f.samplerate, f.num_channels) as o:
            # process the file in one-second chunks
            while f.tell() < f.frames:
                chunk = f.read(int(f.samplerate))
                effected = board(chunk, f.samplerate, reset=False)
                o.write(effected)
    return midi_file, wav_fx
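
# Custom CSS for the centered layout and the gradient Generate button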
| custom_css = """ | |
| #container { | |
| max-width: 1200px !important; | |
| margin: 0 auto !important; | |
| } | |
| #generate-btn { | |
| font-size: 18px; | |
| color: white; | |
| padding: 10px 20px; | |
| border: none; | |
| border-radius: 5px; | |
| cursor: pointer; | |
| background: linear-gradient(90deg, hsla(268, 90%, 70%, 1) 0%, hsla(260, 72%, 74%, 1) 50%, hsla(247, 73%, 65%, 1) 100%); | |
| transition: background 1s ease; | |
| } | |
| #generate-btn:hover { | |
| color: white; | |
| background: linear-gradient(90deg, hsla(268, 90%, 62%, 1) 0%, hsla(260, 70%, 70%, 1) 50%, hsla(247, 73%, 55%, 1) 100%); | |
| } | |
| #container .prose { | |
| text-align: center !important; | |
| } | |
| #container h1 { | |
| font-weight: bold; | |
| font-size: 40px; | |
| margin: 0px; | |
| } | |
| #container p { | |
| font-size: 18px; | |
| text-align: center; | |
| } | |
| """ | |
with gr.Blocks(
    css=custom_css,
    theme=gr.themes.Default(
        font=[gr.themes.GoogleFont("Roboto"), "sans-serif"],
        primary_hue="violet",
        secondary_hue="violet"
    )
) as iface:
    with gr.Column(elem_id="container"):
        gr.Markdown("<h1>nanoMPC</h1>")
        gr.Markdown("<p>nanoMPC is a MIDI transformer model that generates lo-fi and boom bap beats.</p>")
        bpm = gr.Slider(minimum=50, maximum=200, step=1, value=90, label="BPM")
        temperature = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=1.0, label="Temperature")
        top_k = gr.Slider(minimum=4, maximum=16, step=1, value=8, label="Top-k")
        output_level = gr.Slider(minimum=0, maximum=3, step=0.10, value=2.0, label="Output Gain")
        generate_button = gr.Button("Generate", elem_id="generate-btn")
        midi_file = gr.File(label="MIDI Output")
        audio_file = gr.Audio(label="Audio Output", type="filepath")
        soundfont = gr.File(label="Optional: Upload SoundFont (preset=0, bank=0)")
        generate_button.click(
            fn=generate_and_return_files,
            inputs=[bpm, temperature, top_k, soundfont, output_level],
            outputs=[midi_file, audio_file]
        )
        gr.Markdown("<p style='font-size: 16px;'>Developed by <a href='https://www.patchbanks.com/' target='_blank'><strong>Patchbanks</strong></a></p>")

iface.launch(share=True)