Spaces:
Running
Running
Upload app.py
Browse files
app.py
CHANGED
|
@@ -10,6 +10,7 @@ from contextlib import nullcontext
|
|
| 10 |
from model import GPTConfig, GPT
|
| 11 |
from pedalboard import Pedalboard, Reverb, Compressor, Gain, Limiter
|
| 12 |
from pedalboard.io import AudioFile
|
|
|
|
| 13 |
|
| 14 |
in_space = os.getenv("SYSTEM") == "spaces"
|
| 15 |
|
|
@@ -22,7 +23,7 @@ ckpt_load = 'model.pt'
|
|
| 22 |
|
| 23 |
start = "000000000000\n"
|
| 24 |
num_samples = 1
|
| 25 |
-
max_new_tokens =
|
| 26 |
|
| 27 |
seed = random.randint(1, 100000)
|
| 28 |
torch.manual_seed(seed)
|
|
@@ -58,9 +59,9 @@ model.to(device)
|
|
| 58 |
if compile:
|
| 59 |
model = torch.compile(model)
|
| 60 |
|
| 61 |
-
tokenizer = re.compile(r'000000000000|\d{
|
| 62 |
|
| 63 |
-
meta_path = os.path.join('', 'meta.pkl')
|
| 64 |
with open(meta_path, 'rb') as f:
|
| 65 |
meta = pickle.load(f)
|
| 66 |
stoi = meta.get('stoi', None)
|
|
@@ -131,7 +132,6 @@ def generate_midi(temperature, top_k):
|
|
| 131 |
return midi_events
|
| 132 |
|
| 133 |
|
| 134 |
-
|
| 135 |
def write_midi(midi_events, bpm):
|
| 136 |
midi_data = pretty_midi.PrettyMIDI(initial_tempo=bpm, resolution=96)
|
| 137 |
midi_data.time_signature_changes.append(pretty_midi.containers.TimeSignature(4, 4, 0))
|
|
@@ -152,19 +152,21 @@ def write_midi(midi_events, bpm):
|
|
| 152 |
print(f"Generated: {midi_path}")
|
| 153 |
|
| 154 |
|
| 155 |
-
def render_wav(midi_file):
|
| 156 |
-
|
| 157 |
sf2_dir = 'sf2_kits'
|
| 158 |
audio_format = 's16'
|
| 159 |
sample_rate = '44100'
|
| 160 |
gain = '2.0'
|
| 161 |
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
|
| 166 |
-
|
| 167 |
-
print(sf2_file)
|
| 168 |
output_wav = os.path.join(temp_dir, 'output.wav')
|
| 169 |
|
| 170 |
with open(os.devnull, 'w') as devnull:
|
|
@@ -177,23 +179,7 @@ def render_wav(midi_file):
|
|
| 177 |
return output_wav
|
| 178 |
|
| 179 |
|
| 180 |
-
def
|
| 181 |
-
wav_fx = wav_raw
|
| 182 |
-
|
| 183 |
-
for setting in settings:
|
| 184 |
-
board = setting['board']
|
| 185 |
-
|
| 186 |
-
with AudioFile(wav_raw) as f:
|
| 187 |
-
with AudioFile(wav_fx, 'w', f.samplerate, f.num_channels) as o:
|
| 188 |
-
while f.tell() < f.frames:
|
| 189 |
-
chunk = f.read(int(f.samplerate))
|
| 190 |
-
effected = board(chunk, f.samplerate, reset=False)
|
| 191 |
-
o.write(effected)
|
| 192 |
-
|
| 193 |
-
return wav_fx
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
def generate_and_return_files(bpm, temperature, top_k):
|
| 197 |
midi_events = generate_midi(temperature, top_k)
|
| 198 |
if not midi_events:
|
| 199 |
return "Error generating MIDI.", None, None
|
|
@@ -201,7 +187,7 @@ def generate_and_return_files(bpm, temperature, top_k):
|
|
| 201 |
write_midi(midi_events, bpm)
|
| 202 |
|
| 203 |
midi_file = os.path.join(temp_dir, 'output.mid')
|
| 204 |
-
wav_raw = render_wav(midi_file)
|
| 205 |
wav_fx = os.path.join(temp_dir, 'output_fx.wav')
|
| 206 |
|
| 207 |
sfx_settings = [
|
|
@@ -226,22 +212,45 @@ def generate_and_return_files(bpm, temperature, top_k):
|
|
| 226 |
return midi_file, wav_fx
|
| 227 |
|
| 228 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 229 |
|
| 230 |
-
iface = gr.Interface(
|
| 231 |
-
fn=generate_and_return_files,
|
| 232 |
-
inputs=[
|
| 233 |
-
gr.Slider(minimum=50, maximum=200, step=1, value=87, label="bpm"),
|
| 234 |
-
gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=1.0, label="temperature"),
|
| 235 |
-
gr.Slider(minimum=4, maximum=128, step=1, value=16, label="top_k")
|
| 236 |
-
],
|
| 237 |
-
outputs=[
|
| 238 |
-
gr.File(label="MIDI File"),
|
| 239 |
-
gr.Audio(label="Generated Audio", type="filepath")
|
| 240 |
-
],
|
| 241 |
-
title="<h1 style='font-weight: bold; text-align: center;'>nanoMPC - AI Midi Drum Sequencer</h1>",
|
| 242 |
-
description="<p style='text-align:center;'>nanoMPC is a tiny transformer model that generates MIDI drum beats inspired by Lo-Fi, Boom Bap and other styles of Hip Hop.</p>",
|
| 243 |
-
theme="soft",
|
| 244 |
-
allow_flagging="never",
|
| 245 |
-
)
|
| 246 |
-
|
| 247 |
-
iface.launch()
|
|
|
|
| 10 |
from model import GPTConfig, GPT
|
| 11 |
from pedalboard import Pedalboard, Reverb, Compressor, Gain, Limiter
|
| 12 |
from pedalboard.io import AudioFile
|
| 13 |
+
import gradio as gr
|
| 14 |
|
| 15 |
in_space = os.getenv("SYSTEM") == "spaces"
|
| 16 |
|
|
|
|
| 23 |
|
| 24 |
start = "000000000000\n"
|
| 25 |
num_samples = 1
|
| 26 |
+
max_new_tokens = 384
|
| 27 |
|
| 28 |
seed = random.randint(1, 100000)
|
| 29 |
torch.manual_seed(seed)
|
|
|
|
| 59 |
if compile:
|
| 60 |
model = torch.compile(model)
|
| 61 |
|
| 62 |
+
tokenizer = re.compile(r'000000000000|\d{2}|\n')
|
| 63 |
|
| 64 |
+
meta_path = os.path.join('data', checkpoint['config']['dataset'], 'meta.pkl')
|
| 65 |
with open(meta_path, 'rb') as f:
|
| 66 |
meta = pickle.load(f)
|
| 67 |
stoi = meta.get('stoi', None)
|
|
|
|
| 132 |
return midi_events
|
| 133 |
|
| 134 |
|
|
|
|
| 135 |
def write_midi(midi_events, bpm):
|
| 136 |
midi_data = pretty_midi.PrettyMIDI(initial_tempo=bpm, resolution=96)
|
| 137 |
midi_data.time_signature_changes.append(pretty_midi.containers.TimeSignature(4, 4, 0))
|
|
|
|
| 152 |
print(f"Generated: {midi_path}")
|
| 153 |
|
| 154 |
|
| 155 |
+
def render_wav(midi_file, uploaded_sf2=None):
|
|
|
|
| 156 |
sf2_dir = 'sf2_kits'
|
| 157 |
audio_format = 's16'
|
| 158 |
sample_rate = '44100'
|
| 159 |
gain = '2.0'
|
| 160 |
|
| 161 |
+
if uploaded_sf2:
|
| 162 |
+
sf2_file = uploaded_sf2
|
| 163 |
+
else:
|
| 164 |
+
sf2_files = [f for f in os.listdir(sf2_dir) if f.endswith('.sf2')]
|
| 165 |
+
if not sf2_files:
|
| 166 |
+
raise ValueError("No SoundFont (.sf2) file found in directory.")
|
| 167 |
+
sf2_file = os.path.join(sf2_dir, random.choice(sf2_files))
|
| 168 |
|
| 169 |
+
print(f"Using SoundFont: {sf2_file}")
|
|
|
|
| 170 |
output_wav = os.path.join(temp_dir, 'output.wav')
|
| 171 |
|
| 172 |
with open(os.devnull, 'w') as devnull:
|
|
|
|
| 179 |
return output_wav
|
| 180 |
|
| 181 |
|
| 182 |
+
def generate_and_return_files(bpm, temperature, top_k, uploaded_sf2=None):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 183 |
midi_events = generate_midi(temperature, top_k)
|
| 184 |
if not midi_events:
|
| 185 |
return "Error generating MIDI.", None, None
|
|
|
|
| 187 |
write_midi(midi_events, bpm)
|
| 188 |
|
| 189 |
midi_file = os.path.join(temp_dir, 'output.mid')
|
| 190 |
+
wav_raw = render_wav(midi_file, uploaded_sf2)
|
| 191 |
wav_fx = os.path.join(temp_dir, 'output_fx.wav')
|
| 192 |
|
| 193 |
sfx_settings = [
|
|
|
|
| 212 |
return midi_file, wav_fx
|
| 213 |
|
| 214 |
|
| 215 |
+
# --- Gradio UI --------------------------------------------------------------
# Custom styling for the generate button; Gradio injects this CSS verbatim
# into the page via gr.Blocks(css=...).
custom_css = """
#generate-btn {
    background-color: #6366f1 !important;
    color: white !important;
    border: none !important;
    font-size: 16px;
    padding: 10px 20px;
    border-radius: 5px;
    cursor: pointer;
}
#generate-btn:hover {
    background-color: #4f51c5 !important;
}
"""

with gr.Blocks(css=custom_css, theme="soft") as iface:
    gr.Markdown("<h1 style='font-weight: bold; text-align: center;'>nanoMPC - AI Midi Drum Sequencer</h1>")
    gr.Markdown("<p style='text-align:center;'>nanoMPC is a tiny transformer model that generates MIDI drum beats inspired by Lo-Fi, Boom Bap and other styles of Hip Hop.</p>")

    with gr.Row():
        # Left column: generation controls.
        with gr.Column(scale=1):
            bpm = gr.Slider(minimum=50, maximum=200, step=1, value=90, label="BPM")
            temperature = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, value=1.0, label="Temperature")
            top_k = gr.Slider(minimum=4, maximum=256, step=1, value=128, label="Top-k")
            # type="filepath" pins the contract: the callback receives a path
            # string (or None), which render_wav() uses directly as the .sf2
            # path. Without it, some Gradio versions pass a tempfile wrapper
            # object instead of a plain path.
            soundfont = gr.File(label="Optional: Upload SoundFont (preset=0, bank=0)", type="filepath")

        # Right column: outputs and the trigger button.
        with gr.Column(scale=1):
            midi_file = gr.File(label="MIDI File Output")
            audio_file = gr.Audio(label="Generated Audio Output", type="filepath")
            generate_button = gr.Button("Generate", elem_id="generate-btn")

    # Wire the button to the end-to-end pipeline:
    # (bpm, temperature, top_k, soundfont) -> (midi file, fx-processed wav).
    generate_button.click(
        fn=generate_and_return_files,
        inputs=[bpm, temperature, top_k, soundfont],
        outputs=[midi_file, audio_file]
    )

# NOTE(review): share=True is ignored (with a warning) when running inside a
# HF Space; it only matters for local runs.
iface.launch(share=True)
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
|
| 256 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|