Spaces:
Runtime error
Runtime error
kind of working
Browse files- app.py +12 -6
- audiocraft/models/musicgen.py +19 -6
app.py
CHANGED
|
@@ -59,6 +59,9 @@ def load_model(version='melody'):
|
|
| 59 |
|
| 60 |
|
| 61 |
def _do_predictions(texts, melodies, duration, **gen_kwargs):
|
|
|
|
|
|
|
|
|
|
| 62 |
MODEL.set_generation_params(duration=duration, **gen_kwargs)
|
| 63 |
print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
|
| 64 |
be = time.time()
|
|
@@ -76,7 +79,7 @@ def _do_predictions(texts, melodies, duration, **gen_kwargs):
|
|
| 76 |
melody = convert_audio(melody, sr, target_sr, target_ac)
|
| 77 |
processed_melodies.append(melody)
|
| 78 |
|
| 79 |
-
if
|
| 80 |
outputs = MODEL.generate_with_chroma(
|
| 81 |
descriptions=texts,
|
| 82 |
melody_wavs=processed_melodies,
|
|
@@ -110,12 +113,10 @@ def predict_batched(texts, melodies):
|
|
| 110 |
def predict_full(model, text, melody, duration, topk, topp, temperature, cfg_coef):
|
| 111 |
topk = int(topk)
|
| 112 |
load_model(model)
|
| 113 |
-
if duration > MODEL.lm.cfg.dataset.segment_duration:
|
| 114 |
-
raise gr.Error("MusicGen currently supports durations of up to 30 seconds!")
|
| 115 |
|
| 116 |
outs = _do_predictions(
|
| 117 |
[text], [melody], duration,
|
| 118 |
-
|
| 119 |
return outs[0]
|
| 120 |
|
| 121 |
|
|
@@ -138,7 +139,7 @@ def ui_full(launch_kwargs):
|
|
| 138 |
with gr.Row():
|
| 139 |
model = gr.Radio(["melody", "medium", "small", "large"], label="Model", value="melody", interactive=True)
|
| 140 |
with gr.Row():
|
| 141 |
-
duration = gr.Slider(minimum=1, maximum=
|
| 142 |
with gr.Row():
|
| 143 |
topk = gr.Number(label="Top-k", value=250, interactive=True)
|
| 144 |
topp = gr.Number(label="Top-p", value=0, interactive=True)
|
|
@@ -184,7 +185,12 @@ def ui_full(launch_kwargs):
|
|
| 184 |
### More details
|
| 185 |
|
| 186 |
The model will generate a short music extract based on the description you provided.
|
| 187 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 188 |
|
| 189 |
We present 4 model variations:
|
| 190 |
1. Melody -- a music generation model capable of generating music condition on text and melody inputs. **Note**, you can also use text only.
|
|
|
|
| 59 |
|
| 60 |
|
| 61 |
def _do_predictions(texts, melodies, duration, **gen_kwargs):
|
| 62 |
+
if duration > MODEL.lm.cfg.dataset.segment_duration:
|
| 63 |
+
raise gr.Error("MusicGen currently supports durations of up to 30 seconds!")
|
| 64 |
+
|
| 65 |
MODEL.set_generation_params(duration=duration, **gen_kwargs)
|
| 66 |
print("new batch", len(texts), texts, [None if m is None else (m[0], m[1].shape) for m in melodies])
|
| 67 |
be = time.time()
|
|
|
|
| 79 |
melody = convert_audio(melody, sr, target_sr, target_ac)
|
| 80 |
processed_melodies.append(melody)
|
| 81 |
|
| 82 |
+
if any(m is not None for m in processed_melodies):
|
| 83 |
outputs = MODEL.generate_with_chroma(
|
| 84 |
descriptions=texts,
|
| 85 |
melody_wavs=processed_melodies,
|
|
|
|
| 113 |
def predict_full(model, text, melody, duration, topk, topp, temperature, cfg_coef):
|
| 114 |
topk = int(topk)
|
| 115 |
load_model(model)
|
|
|
|
|
|
|
| 116 |
|
| 117 |
outs = _do_predictions(
|
| 118 |
[text], [melody], duration,
|
| 119 |
+
top_k=topk, top_p=topp, temperature=temperature, cfg_coef=cfg_coef)
|
| 120 |
return outs[0]
|
| 121 |
|
| 122 |
|
|
|
|
| 139 |
with gr.Row():
|
| 140 |
model = gr.Radio(["melody", "medium", "small", "large"], label="Model", value="melody", interactive=True)
|
| 141 |
with gr.Row():
|
| 142 |
+
duration = gr.Slider(minimum=1, maximum=120, value=10, label="Duration", interactive=True)
|
| 143 |
with gr.Row():
|
| 144 |
topk = gr.Number(label="Top-k", value=250, interactive=True)
|
| 145 |
topp = gr.Number(label="Top-p", value=0, interactive=True)
|
|
|
|
| 185 |
### More details
|
| 186 |
|
| 187 |
The model will generate a short music extract based on the description you provided.
|
| 188 |
+
The model can generate up to 30 seconds of audio in one pass. It is now possible
|
| 189 |
+
to extend the generation by feeding back the end of the previous chunk of audio.
|
| 190 |
+
This can take a long time, and the model might lose consistency. The model might also
|
| 191 |
+
decide at arbitrary positions that the song ends.
|
| 192 |
+
|
| 193 |
+
**WARNING:** Choosing long durations will take a long time to generate (2min might take ~10min).
|
| 194 |
|
| 195 |
We present 4 model variations:
|
| 196 |
1. Melody -- a music generation model capable of generating music condition on text and melody inputs. **Note**, you can also use text only.
|
audiocraft/models/musicgen.py
CHANGED
|
@@ -45,6 +45,7 @@ class MusicGen:
|
|
| 45 |
self.device = next(iter(lm.parameters())).device
|
| 46 |
self.generation_params: dict = {}
|
| 47 |
self.set_generation_params(duration=15) # 15 seconds by default
|
|
|
|
| 48 |
if self.device.type == 'cpu':
|
| 49 |
self.autocast = TorchAutocast(enabled=False)
|
| 50 |
else:
|
|
@@ -127,6 +128,9 @@ class MusicGen:
|
|
| 127 |
'two_step_cfg': two_step_cfg,
|
| 128 |
}
|
| 129 |
|
|
|
|
|
|
|
|
|
|
| 130 |
def generate_unconditional(self, num_samples: int, progress: bool = False) -> torch.Tensor:
|
| 131 |
"""Generate samples in an unconditional manner.
|
| 132 |
|
|
@@ -274,6 +278,10 @@ class MusicGen:
|
|
| 274 |
current_gen_offset: int = 0
|
| 275 |
|
| 276 |
def _progress_callback(generated_tokens: int, tokens_to_generate: int):
|
|
|
|
|
|
|
|
|
|
|
|
|
| 277 |
print(f'{current_gen_offset + generated_tokens: 6d} / {total_gen_len: 6d}', end='\r')
|
| 278 |
|
| 279 |
if prompt_tokens is not None:
|
|
@@ -296,11 +304,17 @@ class MusicGen:
|
|
| 296 |
# melody conditioning etc.
|
| 297 |
ref_wavs = [attr.wav['self_wav'] for attr in attributes]
|
| 298 |
all_tokens = []
|
| 299 |
-
if prompt_tokens is
|
|
|
|
|
|
|
| 300 |
all_tokens.append(prompt_tokens)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 301 |
|
| 302 |
-
|
| 303 |
-
|
| 304 |
chunk_duration = min(self.duration - time_offset, self.max_duration)
|
| 305 |
max_gen_len = int(chunk_duration * self.frame_rate)
|
| 306 |
for attr, ref_wav in zip(attributes, ref_wavs):
|
|
@@ -321,14 +335,13 @@ class MusicGen:
|
|
| 321 |
gen_tokens = self.lm.generate(
|
| 322 |
prompt_tokens, attributes,
|
| 323 |
callback=callback, max_gen_len=max_gen_len, **self.generation_params)
|
| 324 |
-
stride_tokens = int(self.frame_rate * self.extend_stride)
|
| 325 |
if prompt_tokens is None:
|
| 326 |
all_tokens.append(gen_tokens)
|
| 327 |
else:
|
| 328 |
all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1]:])
|
| 329 |
-
prompt_tokens = gen_tokens[:, :, stride_tokens]
|
|
|
|
| 330 |
current_gen_offset += stride_tokens
|
| 331 |
-
time_offset += self.extend_stride
|
| 332 |
|
| 333 |
gen_tokens = torch.cat(all_tokens, dim=-1)
|
| 334 |
|
|
|
|
| 45 |
self.device = next(iter(lm.parameters())).device
|
| 46 |
self.generation_params: dict = {}
|
| 47 |
self.set_generation_params(duration=15) # 15 seconds by default
|
| 48 |
+
self._progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None
|
| 49 |
if self.device.type == 'cpu':
|
| 50 |
self.autocast = TorchAutocast(enabled=False)
|
| 51 |
else:
|
|
|
|
| 128 |
'two_step_cfg': two_step_cfg,
|
| 129 |
}
|
| 130 |
|
| 131 |
+
def set_custom_progress_callback(self, progress_callback: tp.Optional[tp.Callable[[int, int], None]] = None):
|
| 132 |
+
self._progress_callback = progress_callback
|
| 133 |
+
|
| 134 |
def generate_unconditional(self, num_samples: int, progress: bool = False) -> torch.Tensor:
|
| 135 |
"""Generate samples in an unconditional manner.
|
| 136 |
|
|
|
|
| 278 |
current_gen_offset: int = 0
|
| 279 |
|
| 280 |
def _progress_callback(generated_tokens: int, tokens_to_generate: int):
|
| 281 |
+
generated_tokens += current_gen_offset
|
| 282 |
+
if self._progress_callback is not None:
|
| 283 |
+
self._progress_callback(generated_tokens, total_gen_len)
|
| 284 |
+
else:
|
| 285 |
print(f'{current_gen_offset + generated_tokens: 6d} / {total_gen_len: 6d}', end='\r')
|
| 286 |
|
| 287 |
if prompt_tokens is not None:
|
|
|
|
| 304 |
# melody conditioning etc.
|
| 305 |
ref_wavs = [attr.wav['self_wav'] for attr in attributes]
|
| 306 |
all_tokens = []
|
| 307 |
+
if prompt_tokens is None:
|
| 308 |
+
prompt_length = 0
|
| 309 |
+
else:
|
| 310 |
all_tokens.append(prompt_tokens)
|
| 311 |
+
prompt_length = prompt_tokens.shape[-1]
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
stride_tokens = int(self.frame_rate * self.extend_stride)
|
| 315 |
|
| 316 |
+
while current_gen_offset + prompt_length < total_gen_len:
|
| 317 |
+
time_offset = current_gen_offset / self.frame_rate
|
| 318 |
chunk_duration = min(self.duration - time_offset, self.max_duration)
|
| 319 |
max_gen_len = int(chunk_duration * self.frame_rate)
|
| 320 |
for attr, ref_wav in zip(attributes, ref_wavs):
|
|
|
|
| 335 |
gen_tokens = self.lm.generate(
|
| 336 |
prompt_tokens, attributes,
|
| 337 |
callback=callback, max_gen_len=max_gen_len, **self.generation_params)
|
|
|
|
| 338 |
if prompt_tokens is None:
|
| 339 |
all_tokens.append(gen_tokens)
|
| 340 |
else:
|
| 341 |
all_tokens.append(gen_tokens[:, :, prompt_tokens.shape[-1]:])
|
| 342 |
+
prompt_tokens = gen_tokens[:, :, stride_tokens:]
|
| 343 |
+
prompt_length = prompt_tokens.shape[-1]
|
| 344 |
current_gen_offset += stride_tokens
|
|
|
|
| 345 |
|
| 346 |
gen_tokens = torch.cat(all_tokens, dim=-1)
|
| 347 |
|