CY committed · Commit 7d35d1e · 1 parent: a073f0a

Added jam space
Files changed:
- TangoFlux.py +0 -58 (deleted)
- app.py +64 -151
- gt0.json +1 -0 (added)
- model.py +162 -493
- requirements.txt +34 -10
TangoFlux.py
DELETED
@@ -1,58 +0,0 @@
-from diffusers import AutoencoderOobleck
-import torch
-from transformers import T5EncoderModel,T5TokenizerFast
-from diffusers import FluxTransformer2DModel
-from torch import nn
-from typing import List
-from diffusers import FlowMatchEulerDiscreteScheduler
-from diffusers.training_utils import compute_density_for_timestep_sampling
-import copy
-import torch.nn.functional as F
-import numpy as np
-from model import TangoFlux
-from huggingface_hub import snapshot_download
-from tqdm import tqdm
-from typing import Optional,Union,List
-from datasets import load_dataset, Audio
-from math import pi
-import json
-import inspect
-import yaml
-from safetensors.torch import load_file
-import os
-print(os.environ['HOME'])
-
-class TangoFluxInference:
-
-    def __init__(self,name='declare-lab/TangoFlux',device="cuda"):
-
-        self.vae = AutoencoderOobleck.from_pretrained("stabilityai/stable-audio-open-1.0",subfolder='vae',token=os.environ['HF_TOKEN'])
-
-        paths = snapshot_download(repo_id=name,token=os.environ['HF_TOKEN'])
-        weights = load_file("{}/model_1.safetensors".format(paths))
-
-        with open('{}/config.json'.format(paths),'r') as f:
-            config = json.load(f)
-        self.model = TangoFlux(config)
-        self.model.load_state_dict(weights,strict=False)
-        # _IncompatibleKeys(missing_keys=['text_encoder.encoder.embed_tokens.weight'], unexpected_keys=[]) this behaviour is expected
-        self.vae.to(device)
-        self.model.to(device)
-
-    def generate(self,prompt,steps=25,duration=10,guidance_scale=4.5):
-
-        with torch.no_grad():
-            latents = self.model.inference_flow(prompt,
-                duration=duration,
-                num_inference_steps=steps,
-                guidance_scale=guidance_scale)
-
-        wave = self.vae.decode(latents.transpose(2,1)).sample.cpu()[0]
-        return wave
-
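For reference, the deleted wrapper was consumed by the old app.py (shown under app.py below) roughly as in this sketch. The import path and the 44.1 kHz torchaudio.save call mirror that old code; the HF_TOKEN placeholder is an assumption based on the class reading os.environ['HF_TOKEN']:

import os
import torchaudio
from TangoFlux import TangoFluxInference   # assumes the class was imported from this module

os.environ.setdefault("HF_TOKEN", "<your-hf-token>")   # placeholder; the class reads HF_TOKEN for gated downloads

tangoflux = TangoFluxInference(name="declare-lab/TangoFlux", device="cuda")
wave = tangoflux.generate("A dog barking", steps=25, duration=10, guidance_scale=4.5)  # (channels, samples) tensor
torchaudio.save("temp.wav", wave, 44100)   # same save pattern as the removed gradio_generate handler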
app.py
CHANGED
@@ -1,155 +1,68 @@
-import spaces
 import gradio as gr
-[old lines 3-25 removed; content truncated in the diff view]
-@spaces.GPU(duration=15)
-def gradio_generate(prompt, steps, guidance, duration=10):

-    output = tangoflux.generate(prompt, steps=steps, guidance_scale=guidance, duration=duration)
-    #output = output.to(torch.float32).div(torch.max(torch.abs(output))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()

-    #wavio.write(output_filename, output_wave, rate=44100, sampwidth=2)
-    filename = 'temp.wav'
-    #print(f"Saving audio to file: {unique_filename}")

-    # Save to file
-    torchaudio.save(filename, output, 44100)

-    # Return the path to the generated audio file
-    return filename

-    #if (output_format == "mp3"):
-    #    AudioSegment.from_wav("temp.wav").export("temp.mp3", format = "mp3")
-    #    output_filename = "temp.mp3"

-    #return output_filename

-description_text = """
-Generate high quality and faithful audio in just a few seconds using <b>TangoFlux</b> by providing a text prompt. <b>TangoFlux</b> was trained from scratch and underwent alignment to follow human instructions using a new method called <b>CLAP-Ranked Preference Optimization (CRPO)</b>.
-<div style="display: flex; gap: 10px; align-items: center;">
-    <a href="https://arxiv.org/abs/2412.21037">
-        <img src="https://img.shields.io/badge/Read_the_Paper-blue?link=https%3A%2F%2Fopenreview.net%2Fattachment%3Fid%3DtpJPlFTyxd%26name%3Dpdf" alt="arXiv">
-    </a>
-    <a href="https://huggingface.co/declare-lab/TangoFlux">
-        <img src="https://img.shields.io/badge/TangoFlux-Huggingface-violet?logo=huggingface&link=https%3A%2F%2Fhuggingface.co%2Fdeclare-lab%2FTangoFlux" alt="Static Badge">
-    </a>
-    <a href="https://tangoflux.github.io/">
-        <img src="https://img.shields.io/badge/Demos-declare--lab-brightred?style=flat" alt="Static Badge">
-    </a>
-    <a href="https://huggingface.co/spaces/declare-lab/TangoFlux">
-        <img src="https://img.shields.io/badge/TangoFlux-Huggingface_Space-8A2BE2?logo=huggingface&link=https%3A%2F%2Fhuggingface.co%2Fspaces%2Fdeclare-lab%2FTangoFlux" alt="Static Badge">
-    </a>
-    <a href="https://huggingface.co/datasets/declare-lab/CRPO">
-        <img src="https://img.shields.io/badge/TangoFlux_Dataset-Huggingface-red?logo=huggingface&link=https%3A%2F%2Fhuggingface.co%2Fdatasets%2Fdeclare-lab%2FTangoFlux" alt="Static Badge">
-    </a>
-    <a href="https://github.com/declare-lab/TangoFlux">
-        <img src="https://img.shields.io/badge/Github-brown?logo=github&link=https%3A%2F%2Fgithub.com%2Fdeclare-lab%2FTangoFlux" alt="Static Badge">
-    </a>
-</div>
-"""
-# Gradio input and output components
-input_text = gr.Textbox(lines=2, label="Prompt")
-#output_format = gr.Radio(label = "Output format", info = "The file you can download", choices = ["wav"], value = "wav")
-output_audio = gr.Audio(label="Generated Audio", type="filepath")
-denoising_steps = gr.Slider(minimum=10, maximum=100, value=25, step=5, label="Steps", interactive=True)
-guidance_scale = gr.Slider(minimum=1, maximum=10, value=4.5, step=0.5, label="Guidance Scale", interactive=True)
-duration_scale = gr.Slider(minimum=1, maximum=30, value=10, step=1, label="Duration", interactive=True)

 # Gradio interface
-[old lines 83-123 removed; the gr.Interface(...) call and the first examples are truncated in the diff view]
-        ["Massive stadium crowd cheering as thunder crashes and lightning strikes"],
-        ["Heavy helicopter blades chopping through air with engine and wind noise"],
-        ["Dog barking excitedly and man shouting as race car engine roars past"],
-        ["Quiet speech and then and airplane flying away"],
-        ["A bicycle peddling on dirt and gravel followed by a man speaking then laughing"],
-        ["Ducks quack and water splashes with some animal screeching in the background"],
-        ["Describe the sound of the ocean"],
-        ["A woman and a baby are having a conversation"],
-        ["A man speaks followed by a popping noise and laughter"],
-        ["A cup is filled from a faucet"],
-        ["An audience cheering and clapping"],
-        ["Rolling thunder with lightning strikes"],
-        ["A dog barking and a cat mewing and a racing car passes by"],
-        ["Gentle water stream, birds chirping and sudden gun shot"],
-        ["A dog barking"],
-        ["A cat meowing"],
-        ["Wooden table tapping sound while water pouring"],
-        ["Applause from a crowd with distant clicking and a man speaking over a loudspeaker"],
-        ["two gunshots followed by birds flying away while chirping"],
-        ["Whistling with birds chirping"],
-        ["A person snoring"],
-        ["Motor vehicles are driving with loud engines and a person whistles"],
-        ["People cheering in a stadium while thunder and lightning strikes"],
-        ["A helicopter is in flight"],
-        ["A dog barking and a man talking and a racing car passes by"],
-    ],
-    cache_examples="lazy",  # Turn on to cache.
-)

-gr_interface.queue(15).launch()
 import gradio as gr
+import os
+from model import Jamify
+
+# Initialize the Jamify model once
+print("Initializing Jamify model...")
+jamify_model = Jamify()
+print("Jamify model ready.")
+
+def generate_song(reference_audio, lyrics_file, style_prompt, duration):
+    # We need to save the uploaded files to temporary paths to pass to the model
+    ref_audio_path = reference_audio.name if reference_audio else None
+    lyrics_path = lyrics_file.name
+
+    # The model expects paths, so we write the prompt to a temp file if needed
+    # (This part of the model could be improved to accept the string directly)
+
+    output_path = jamify_model.predict(
+        reference_audio_path=ref_audio_path,
+        lyrics_json_path=lyrics_path,
+        style_prompt=style_prompt,
+        duration_sec=duration
+    )

+    return output_path

 # Gradio interface
+with gr.Blocks() as demo:
+    gr.Markdown("# Jamify: Music Generation from Lyrics and Style")
+    gr.Markdown("Provide your lyrics, a style reference (either an audio file or a text prompt), and a desired duration to generate a song.")
+
+    with gr.Row():
+        with gr.Column():
+            gr.Markdown("### Inputs")
+            lyrics_file = gr.File(label="Lyrics File (.json)", type="filepath")
+            duration_slider = gr.Slider(minimum=5, maximum=180, value=30, step=1, label="Duration (seconds)")
+
+            with gr.Tab("Style from Audio"):
+                reference_audio = gr.File(label="Reference Audio (.mp3, .wav)", type="filepath")
+            with gr.Tab("Style from Text"):
+                style_prompt = gr.Textbox(label="Style Prompt", lines=3, placeholder="e.g., A high-energy electronic dance track with a strong bassline and euphoric synths.")
+
+            generate_button = gr.Button("Generate Song", variant="primary")
+
+        with gr.Column():
+            gr.Markdown("### Output")
+            output_audio = gr.Audio(label="Generated Song")
+
+    generate_button.click(
+        fn=generate_song,
+        inputs=[reference_audio, lyrics_file, style_prompt, duration_slider],
+        outputs=output_audio,
+        api_name="generate_song"
+    )
+
+    gr.Markdown("### Example Usage")
+    gr.Examples(
+        examples=[
+            [None, "jamify/inputs/Jade Bird - Avalanche.json", "A sad, slow, acoustic country song", 30],
+            ["jamify/inputs/Rizzle Kicks, Rachel Chinouriri - Follow Excitement!.mp3", "jamify/inputs/Rizzle Kicks, Rachel Chinouriri - Follow Excitement!.json", "", 45],
+        ],
+        inputs=[reference_audio, lyrics_file, style_prompt, duration_slider],
+        outputs=output_audio,
+        fn=generate_song,
+        cache_examples=True
+    )
+
+demo.queue().launch()
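Because the click handler registers api_name="generate_song", the rebuilt Space can also be called programmatically. A minimal sketch with gradio_client; the Space id below is a placeholder (point it at wherever this app is actually hosted), and handle_file assumes a recent gradio_client release:

from gradio_client import Client, handle_file

client = Client("declare-lab/jamify-space")        # placeholder Space id
result = client.predict(
    None,                                          # reference_audio: optional style reference file
    handle_file("gt0.json"),                       # lyrics_file: word-level alignment JSON
    "A sad, slow, acoustic country song",          # style_prompt, used when no reference audio is given
    30,                                            # duration in seconds
    api_name="/generate_song",
)
print(result)                                      # local path to the generated audio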
gt0.json
ADDED
@@ -0,0 +1 @@
+ (single added line: word-level lyric alignment JSON, shown wrapped below)
[{"word": "Every", "start_offset": 259, "end_offset": 267, "start": 20.72, "end": 21.36, "phoneme": "\u025bv\u025di|_"}, {"word": "night", "start_offset": 267, "end_offset": 275, "start": 21.36, "end": 22.0, "phoneme": "na\u026at|_"}, {"word": "in", "start_offset": 279, "end_offset": 283, "start": 22.32, "end": 22.64, "phoneme": "\u026an|_"}, {"word": "my", "start_offset": 283, "end_offset": 287, "start": 22.64, "end": 22.96, "phoneme": "ma\u026a|_"}, {"word": "dreams,", "start_offset": 287, "end_offset": 301, "start": 22.96, "end": 24.080000000000002, "phoneme": "dri\u02d0mz,"}, {"word": "I", "start_offset": 309, "end_offset": 313, "start": 24.72, "end": 25.04, "phoneme": "a\u026a|_"}, {"word": "see", "start_offset": 317, "end_offset": 321, "start": 25.36, "end": 25.68, "phoneme": "si\u02d0|_"}, {"word": "you,", "start_offset": 321, "end_offset": 325, "start": 25.68, "end": 26.0, "phoneme": "ju\u02d0,"}, {"word": "I", "start_offset": 340, "end_offset": 344, "start": 27.2, "end": 27.52, "phoneme": "a\u026a|_"}, {"word": "feel", "start_offset": 348, "end_offset": 352, "start": 27.84, "end": 28.16, "phoneme": "fi\u02d0l|_"}, {"word": "you.", "start_offset": 358, "end_offset": 362, "start": 28.64, "end": 28.96, "phoneme": "ju\u02d0."}, {"word": "That", "start_offset": 377, "end_offset": 381, "start": 30.16, "end": 30.48, "phoneme": "\u00f0\u00e6t|_"}, {"word": "is", "start_offset": 385, "end_offset": 389, "start": 30.8, "end": 31.12, "phoneme": "\u026az"}, {"word": "how", "start_offset": 393, "end_offset": 397, "start": 31.44, "end": 31.76, "phoneme": "ha\u028a|_"}, {"word": "I", "start_offset": 401, "end_offset": 405, "start": 32.08, "end": 32.4, "phoneme": "a\u026a|_"}, {"word": "know", "start_offset": 405, "end_offset": 409, "start": 32.4, "end": 32.72, "phoneme": "no\u028a|_"}, {"word": "you", "start_offset": 413, "end_offset": 417, "start": 33.04, "end": 33.36, "phoneme": "ju\u02d0|_"}, {"word": "go", "start_offset": 428, "end_offset": 431, "start": 34.24, "end": 34.480000000000004, "phoneme": "go\u028a|_"}, {"word": "far", "start_offset": 495, "end_offset": 503, "start": 39.6, "end": 40.24, "phoneme": "f\u0251\u02d0r"}, {"word": "across", "start_offset": 507, "end_offset": 517, "start": 40.56, "end": 41.36, "phoneme": "\u0259kr\u0254s|_"}, {"word": "the", "start_offset": 519, "end_offset": 523, "start": 41.52, "end": 41.84, "phoneme": "\u00f0\u0259|_"}, {"word": "distance", "start_offset": 527, "end_offset": 538, "start": 42.160000000000004, "end": 43.04, "phoneme": "d\u026ast\u0259ns|_"}, {"word": "and", "start_offset": 552, "end_offset": 556, "start": 44.160000000000004, "end": 44.480000000000004, "phoneme": "\u0259nd"}, {"word": "spaces", "start_offset": 556, "end_offset": 572, "start": 44.480000000000004, "end": 45.76, "phoneme": "spe\u026as\u0259z"}, {"word": "between", "start_offset": 583, "end_offset": 587, "start": 46.64, "end": 46.96, "phoneme": "b\u026atwi\u02d0n|_"}, {"word": "us.", "start_offset": 602, "end_offset": 606, "start": 48.160000000000004, "end": 48.480000000000004, "phoneme": "\u028cs."}, {"word": "You", "start_offset": 621, "end_offset": 625, "start": 49.68, "end": 50.0, "phoneme": "ju\u02d0|_"}, {"word": "have", "start_offset": 629, "end_offset": 633, "start": 50.32, "end": 50.64, "phoneme": "h\u00e6v"}, {"word": "come", "start_offset": 633, "end_offset": 637, "start": 50.64, "end": 50.96, "phoneme": "k\u028cm|_"}, {"word": "to", "start_offset": 641, "end_offset": 645, "start": 51.28, "end": 51.6, "phoneme": "tu\u02d0|_"}, {"word": "show", "start_offset": 649, 
"end_offset": 653, "start": 51.92, "end": 52.24, "phoneme": "\u0283o\u028a|_"}, {"word": "you", "start_offset": 655, "end_offset": 659, "start": 52.4, "end": 52.72, "phoneme": "ju\u02d0|_"}, {"word": "go", "start_offset": 673, "end_offset": 676, "start": 53.84, "end": 54.08, "phoneme": "go\u028a|_"}, {"word": "near,", "start_offset": 738, "end_offset": 745, "start": 59.04, "end": 59.6, "phoneme": "n\u026ar,"}, {"word": "far,", "start_offset": 768, "end_offset": 776, "start": 61.44, "end": 62.08, "phoneme": "f\u0251\u02d0r,"}, {"word": "wherever", "start_offset": 794, "end_offset": 806, "start": 63.52, "end": 64.48, "phoneme": "w\u025br\u025bv\u025d"}, {"word": "you", "start_offset": 822, "end_offset": 826, "start": 65.76, "end": 66.08, "phoneme": "ju\u02d0|_"}, {"word": "are.", "start_offset": 826, "end_offset": 830, "start": 66.08, "end": 66.4, "phoneme": "\u0251\u02d0r."}, {"word": "I", "start_offset": 849, "end_offset": 852, "start": 67.92, "end": 68.16, "phoneme": "a\u026a|_"}, {"word": "believe", "start_offset": 856, "end_offset": 868, "start": 68.48, "end": 69.44, "phoneme": "b\u026ali\u02d0v"}, {"word": "that", "start_offset": 875, "end_offset": 878, "start": 70.0, "end": 70.24, "phoneme": "\u00f0\u00e6t|_"}, {"word": "the", "start_offset": 886, "end_offset": 890, "start": 70.88, "end": 71.2, "phoneme": "\u00f0\u0259|_"}, {"word": "heart", "start_offset": 890, "end_offset": 898, "start": 71.2, "end": 71.84, "phoneme": "h\u0251\u02d0rt|_"}, {"word": "does", "start_offset": 898, "end_offset": 901, "start": 71.84, "end": 72.08, "phoneme": "d\u028cz"}, {"word": "go", "start_offset": 916, "end_offset": 920, "start": 73.28, "end": 73.60000000000001, "phoneme": "go\u028a|_"}, {"word": "on", "start_offset": 982, "end_offset": 985, "start": 78.56, "end": 78.8, "phoneme": "\u0251\u02d0n|_"}, {"word": "small.", "start_offset": 1009, "end_offset": 1017, "start": 80.72, "end": 81.36, "phoneme": "sm\u0254l."}, {"word": "You", "start_offset": 1037, "end_offset": 1041, "start": 82.96000000000001, "end": 83.28, "phoneme": "ju\u02d0|_"}, {"word": "open", "start_offset": 1045, "end_offset": 1049, "start": 83.60000000000001, "end": 83.92, "phoneme": "o\u028ap\u0259n|_"}, {"word": "the", "start_offset": 1065, "end_offset": 1069, "start": 85.2, "end": 85.52, "phoneme": "\u00f0\u0259|_"}, {"word": "door,", "start_offset": 1069, "end_offset": 1076, "start": 85.52, "end": 86.08, "phoneme": "d\u0254r,"}, {"word": "and", "start_offset": 1090, "end_offset": 1094, "start": 87.2, "end": 87.52, "phoneme": "\u0259nd"}, {"word": "you'll", "start_offset": 1094, "end_offset": 1100, "start": 87.52, "end": 88.0, "phoneme": "j\u028c\u028al|_"}, {"word": "hear", "start_offset": 1103, "end_offset": 1108, "start": 88.24, "end": 88.64, "phoneme": "hi\u02d0r"}, {"word": "in", "start_offset": 1119, "end_offset": 1122, "start": 89.52, "end": 89.76, "phoneme": "\u026an|_"}, {"word": "my", "start_offset": 1126, "end_offset": 1130, "start": 90.08, "end": 90.4, "phoneme": "ma\u026a|_"}, {"word": "heart,", "start_offset": 1130, "end_offset": 1138, "start": 90.4, "end": 91.04, "phoneme": "h\u0251\u02d0rt,"}, {"word": "and", "start_offset": 1141, "end_offset": 1145, "start": 91.28, "end": 91.60000000000001, "phoneme": "\u0259nd"}, {"word": "my", "start_offset": 1157, "end_offset": 1161, "start": 92.56, "end": 92.88, "phoneme": "ma\u026a|_"}, {"word": "heart", "start_offset": 1165, "end_offset": 1173, "start": 93.2, "end": 93.84, "phoneme": "h\u0251\u02d0rt|_"}, {"word": "will", "start_offset": 1173, "end_offset": 1177, "start": 
93.84, "end": 94.16, "phoneme": "w\u026al|_"}, {"word": "go", "start_offset": 1185, "end_offset": 1189, "start": 94.8, "end": 95.12, "phoneme": "go\u028a|_"}, {"word": "and", "start_offset": 1211, "end_offset": 1215, "start": 96.88, "end": 97.2, "phoneme": "\u0259nd"}, {"word": "dawn.", "start_offset": 1223, "end_offset": 1233, "start": 97.84, "end": 98.64, "phoneme": "d\u0254n."}, {"word": "Love", "start_offset": 1345, "end_offset": 1353, "start": 107.60000000000001, "end": 108.24000000000001, "phoneme": "l\u028cv"}, {"word": "can", "start_offset": 1356, "end_offset": 1360, "start": 108.48, "end": 108.8, "phoneme": "k\u00e6n|_"}, {"word": "touch", "start_offset": 1360, "end_offset": 1366, "start": 108.8, "end": 109.28, "phoneme": "t\u028ct\u0283|_"}, {"word": "us", "start_offset": 1369, "end_offset": 1373, "start": 109.52, "end": 109.84, "phoneme": "\u028cs|_"}, {"word": "one", "start_offset": 1376, "end_offset": 1380, "start": 110.08, "end": 110.4, "phoneme": "w\u028cn|_"}, {"word": "time", "start_offset": 1384, "end_offset": 1388, "start": 110.72, "end": 111.04, "phoneme": "ta\u026am|_"}, {"word": "and", "start_offset": 1399, "end_offset": 1402, "start": 111.92, "end": 112.16, "phoneme": "\u0259nd"}, {"word": "last", "start_offset": 1406, "end_offset": 1410, "start": 112.48, "end": 112.8, "phoneme": "l\u00e6st|_"}, {"word": "for", "start_offset": 1416, "end_offset": 1420, "start": 113.28, "end": 113.60000000000001, "phoneme": "f\u0254r"}, {"word": "a", "start_offset": 1431, "end_offset": 1435, "start": 114.48, "end": 114.8, "phoneme": "\u0259|_"}, {"word": "lifetime", "start_offset": 1435, "end_offset": 1458, "start": 114.8, "end": 116.64, "phoneme": "la\u026afta\u026am|_"}, {"word": "and", "start_offset": 1471, "end_offset": 1475, "start": 117.68, "end": 118.0, "phoneme": "\u0259nd"}, {"word": "never", "start_offset": 1479, "end_offset": 1483, "start": 118.32000000000001, "end": 118.64, "phoneme": "n\u025bv\u025d"}, {"word": "let", "start_offset": 1487, "end_offset": 1491, "start": 118.96000000000001, "end": 119.28, "phoneme": "l\u025bt|_"}, {"word": "go", "start_offset": 1495, "end_offset": 1499, "start": 119.60000000000001, "end": 119.92, "phoneme": "go\u028a|_"}, {"word": "till", "start_offset": 1503, "end_offset": 1511, "start": 120.24000000000001, "end": 120.88, "phoneme": "t\u026al|_"}, {"word": "we're", "start_offset": 1521, "end_offset": 1528, "start": 121.68, "end": 122.24000000000001, "phoneme": "w\u025d\u02d0|_"}, {"word": "gone.", "start_offset": 1528, "end_offset": 1536, "start": 122.24000000000001, "end": 122.88, "phoneme": "g\u0254n."}, {"word": "Love", "start_offset": 1587, "end_offset": 1596, "start": 126.96000000000001, "end": 127.68, "phoneme": "l\u028cv"}, {"word": "was", "start_offset": 1599, "end_offset": 1603, "start": 127.92, "end": 128.24, "phoneme": "w\u0251\u02d0z"}, {"word": "when", "start_offset": 1607, "end_offset": 1611, "start": 128.56, "end": 128.88, "phoneme": "w\u025bn|_"}, {"word": "I", "start_offset": 1611, "end_offset": 1615, "start": 128.88, "end": 129.2, "phoneme": "a\u026a|_"}, {"word": "loved", "start_offset": 1615, "end_offset": 1626, "start": 129.2, "end": 130.08, "phoneme": "l\u028cvd"}, {"word": "you", "start_offset": 1626, "end_offset": 1630, "start": 130.08, "end": 130.4, "phoneme": "ju\u02d0|_"}, {"word": "one", "start_offset": 1641, "end_offset": 1644, "start": 131.28, "end": 131.52, "phoneme": "w\u028cn|_"}, {"word": "true", "start_offset": 1648, "end_offset": 1656, "start": 131.84, "end": 132.48, "phoneme": "tru\u02d0|_"}, {"word": 
"time.", "start_offset": 1656, "end_offset": 1660, "start": 132.48, "end": 132.8, "phoneme": "ta\u026am."}, {"word": "I", "start_offset": 1672, "end_offset": 1675, "start": 133.76, "end": 134.0, "phoneme": "a\u026a|_"}, {"word": "hold", "start_offset": 1679, "end_offset": 1687, "start": 134.32, "end": 134.96, "phoneme": "ho\u028ald"}, {"word": "to", "start_offset": 1691, "end_offset": 1693, "start": 135.28, "end": 135.44, "phoneme": "tu\u02d0|_"}, {"word": "in", "start_offset": 1712, "end_offset": 1716, "start": 136.96, "end": 137.28, "phoneme": "\u026an|_"}, {"word": "my", "start_offset": 1720, "end_offset": 1724, "start": 137.6, "end": 137.92000000000002, "phoneme": "ma\u026a|_"}, {"word": "life", "start_offset": 1724, "end_offset": 1728, "start": 137.92000000000002, "end": 138.24, "phoneme": "la\u026af|_"}, {"word": "will", "start_offset": 1731, "end_offset": 1733, "start": 138.48, "end": 138.64000000000001, "phoneme": "w\u026al|_"}, {"word": "always", "start_offset": 1743, "end_offset": 1747, "start": 139.44, "end": 139.76, "phoneme": "\u0254lwe\u026az"}, {"word": "go", "start_offset": 1763, "end_offset": 1767, "start": 141.04, "end": 141.36, "phoneme": "go\u028a|_"}, {"word": "near", "start_offset": 1830, "end_offset": 1836, "start": 146.4, "end": 146.88, "phoneme": "n\u026ar"}, {"word": "far", "start_offset": 1859, "end_offset": 1867, "start": 148.72, "end": 149.36, "phoneme": "f\u0251\u02d0r"}, {"word": "wherever", "start_offset": 1884, "end_offset": 1896, "start": 150.72, "end": 151.68, "phoneme": "w\u025br\u025bv\u025d"}, {"word": "you", "start_offset": 1914, "end_offset": 1918, "start": 153.12, "end": 153.44, "phoneme": "ju\u02d0|_"}, {"word": "are.", "start_offset": 1918, "end_offset": 1922, "start": 153.44, "end": 153.76, "phoneme": "\u0251\u02d0r."}, {"word": "I", "start_offset": 1940, "end_offset": 1943, "start": 155.20000000000002, "end": 155.44, "phoneme": "a\u026a|_"}, {"word": "believe", "start_offset": 1947, "end_offset": 1959, "start": 155.76, "end": 156.72, "phoneme": "b\u026ali\u02d0v"}, {"word": "that", "start_offset": 1966, "end_offset": 1970, "start": 157.28, "end": 157.6, "phoneme": "\u00f0\u00e6t|_"}, {"word": "the", "start_offset": 1974, "end_offset": 1977, "start": 157.92000000000002, "end": 158.16, "phoneme": "\u00f0\u0259|_"}, {"word": "heart", "start_offset": 1981, "end_offset": 1986, "start": 158.48, "end": 158.88, "phoneme": "h\u0251\u02d0rt|_"}, {"word": "does", "start_offset": 1990, "end_offset": 1993, "start": 159.20000000000002, "end": 159.44, "phoneme": "d\u028cz"}, {"word": "go", "start_offset": 2008, "end_offset": 2011, "start": 160.64000000000001, "end": 160.88, "phoneme": "go\u028a|_"}, {"word": "small.", "start_offset": 2099, "end_offset": 2111, "start": 167.92000000000002, "end": 168.88, "phoneme": "sm\u0254l."}, {"word": "You", "start_offset": 2127, "end_offset": 2131, "start": 170.16, "end": 170.48, "phoneme": "ju\u02d0|_"}, {"word": "open", "start_offset": 2136, "end_offset": 2140, "start": 170.88, "end": 171.20000000000002, "phoneme": "o\u028ap\u0259n|_"}, {"word": "the", "start_offset": 2156, "end_offset": 2160, "start": 172.48, "end": 172.8, "phoneme": "\u00f0\u0259|_"}, {"word": "door", "start_offset": 2160, "end_offset": 2167, "start": 172.8, "end": 173.36, "phoneme": "d\u0254r"}, {"word": "and", "start_offset": 2181, "end_offset": 2185, "start": 174.48, "end": 174.8, "phoneme": "\u0259nd"}, {"word": "you", "start_offset": 2185, "end_offset": 2187, "start": 174.8, "end": 174.96, "phoneme": "ju\u02d0|_"}, {"word": "hear", "start_offset": 
2195, "end_offset": 2203, "start": 175.6, "end": 176.24, "phoneme": "hi\u02d0r"}, {"word": "in", "start_offset": 2209, "end_offset": 2213, "start": 176.72, "end": 177.04, "phoneme": "\u026an|_"}, {"word": "my", "start_offset": 2217, "end_offset": 2221, "start": 177.36, "end": 177.68, "phoneme": "ma\u026a|_"}, {"word": "heart,", "start_offset": 2221, "end_offset": 2230, "start": 177.68, "end": 178.4, "phoneme": "h\u0251\u02d0rt,"}, {"word": "and", "start_offset": 2232, "end_offset": 2236, "start": 178.56, "end": 178.88, "phoneme": "\u0259nd"}, {"word": "my", "start_offset": 2248, "end_offset": 2251, "start": 179.84, "end": 180.08, "phoneme": "ma\u026a|_"}, {"word": "heart", "start_offset": 2255, "end_offset": 2263, "start": 180.4, "end": 181.04, "phoneme": "h\u0251\u02d0rt|_"}, {"word": "will", "start_offset": 2263, "end_offset": 2266, "start": 181.04, "end": 181.28, "phoneme": "w\u026al|_"}, {"word": "go", "start_offset": 2278, "end_offset": 2282, "start": 182.24, "end": 182.56, "phoneme": "go\u028a|_"}, {"word": "on.", "start_offset": 2286, "end_offset": 2289, "start": 182.88, "end": 183.12, "phoneme": "\u0251\u02d0n."}, {"word": "You", "start_offset": 2557, "end_offset": 2559, "start": 204.56, "end": 204.72, "phoneme": "ju\u02d0|_"}, {"word": "hear", "start_offset": 2587, "end_offset": 2594, "start": 206.96, "end": 207.52, "phoneme": "hi\u02d0r"}, {"word": "there's", "start_offset": 2610, "end_offset": 2620, "start": 208.8, "end": 209.6, "phoneme": "\u00f0\u025brz"}, {"word": "nothing", "start_offset": 2620, "end_offset": 2632, "start": 209.6, "end": 210.56, "phoneme": "n\u028c\u03b8\u026a\u014b|_"}, {"word": "I", "start_offset": 2640, "end_offset": 2644, "start": 211.20000000000002, "end": 211.52, "phoneme": "a\u026a|_"}, {"word": "fear,", "start_offset": 2644, "end_offset": 2651, "start": 211.52, "end": 212.08, "phoneme": "f\u026ar,"}, {"word": "and", "start_offset": 2666, "end_offset": 2669, "start": 213.28, "end": 213.52, "phoneme": "\u0259nd"}, {"word": "I", "start_offset": 2673, "end_offset": 2677, "start": 213.84, "end": 214.16, "phoneme": "a\u026a|_"}, {"word": "know", "start_offset": 2677, "end_offset": 2681, "start": 214.16, "end": 214.48000000000002, "phoneme": "no\u028a|_"}, {"word": "that", "start_offset": 2693, "end_offset": 2697, "start": 215.44, "end": 215.76, "phoneme": "\u00f0\u00e6t|_"}, {"word": "my", "start_offset": 2701, "end_offset": 2705, "start": 216.08, "end": 216.4, "phoneme": "ma\u026a|_"}, {"word": "heart", "start_offset": 2705, "end_offset": 2713, "start": 216.4, "end": 217.04, "phoneme": "h\u0251\u02d0rt|_"}, {"word": "will", "start_offset": 2717, "end_offset": 2721, "start": 217.36, "end": 217.68, "phoneme": "w\u026al|_"}, {"word": "go", "start_offset": 2733, "end_offset": 2736, "start": 218.64000000000001, "end": 218.88, "phoneme": "go\u028a|_"}, {"word": "forever", "start_offset": 2852, "end_offset": 2863, "start": 228.16, "end": 229.04, "phoneme": "f\u025d\u025bv\u025d"}, {"word": "this", "start_offset": 2881, "end_offset": 2883, "start": 230.48000000000002, "end": 230.64000000000001, "phoneme": "\u00f0\u026as|_"}, {"word": "way.", "start_offset": 2888, "end_offset": 2892, "start": 231.04, "end": 231.36, "phoneme": "we\u026a."}, {"word": "You", "start_offset": 2908, "end_offset": 2911, "start": 232.64000000000001, "end": 232.88, "phoneme": "ju\u02d0|_"}, {"word": "are", "start_offset": 2914, "end_offset": 2918, "start": 233.12, "end": 233.44, "phoneme": "\u0251\u02d0r"}, {"word": "safe", "start_offset": 2928, "end_offset": 2935, "start": 234.24, "end": 
234.8, "phoneme": "se\u026af|_"}, {"word": "in", "start_offset": 2938, "end_offset": 2942, "start": 235.04, "end": 235.36, "phoneme": "\u026an|_"}, {"word": "my", "start_offset": 2942, "end_offset": 2946, "start": 235.36, "end": 235.68, "phoneme": "ma\u026a|_"}, {"word": "heart,", "start_offset": 2950, "end_offset": 2957, "start": 236.0, "end": 236.56, "phoneme": "h\u0251\u02d0rt,"}, {"word": "and", "start_offset": 2959, "end_offset": 2963, "start": 236.72, "end": 237.04, "phoneme": "\u0259nd"}, {"word": "my", "start_offset": 2975, "end_offset": 2978, "start": 238.0, "end": 238.24, "phoneme": "ma\u026a|_"}, {"word": "heart", "start_offset": 2982, "end_offset": 2990, "start": 238.56, "end": 239.20000000000002, "phoneme": "h\u0251\u02d0rt|_"}, {"word": "will", "start_offset": 2990, "end_offset": 2994, "start": 239.20000000000002, "end": 239.52, "phoneme": "w\u026al|_"}, {"word": "go", "start_offset": 3002, "end_offset": 3005, "start": 240.16, "end": 240.4, "phoneme": "go\u028a|_"}, {"word": "on", "start_offset": 3009, "end_offset": 3012, "start": 240.72, "end": 240.96, "phoneme": "\u0251\u02d0n|_"}, {"word": "there.", "start_offset": 3028, "end_offset": 3032, "start": 242.24, "end": 242.56, "phoneme": "\u00f0\u025br."}]
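gt0.json is the lyrics format that Jamify.predict expects: a flat JSON list of word entries carrying the word text, token offsets, start/end times in seconds, and an IPA phoneme string. A small sketch for sanity-checking such a file (field names taken from the entries above):

import json

with open("gt0.json") as f:
    words = json.load(f)   # flat list of {"word", "start_offset", "end_offset", "start", "end", "phoneme"}

print(len(words), "words")
print(" ".join(w["word"] for w in words[:12]))            # opening of the lyric
print(f"last word ends at {words[-1]['end']:.2f} s")      # rough lower bound on the song duration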
model.py
CHANGED
@@ -1,511 +1,180 @@
-from transformers import T5EncoderModel,T5TokenizerFast
-import torch
-from diffusers import FluxTransformer2DModel
-from torch import nn

-[old lines 6-9 removed; content truncated in the diff view]
-import torch.nn.functional as F
 import numpy as np
-[old lines 12-34 removed; content truncated in the diff view]
-[old lines 35-117 removed; only the tail of the retrieve_timesteps helper is visible below]
-        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
-        if not accept_sigmas:
-            raise ValueError(
-                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
-                f" sigmas schedules. Please check whether you are using the correct scheduler."
-            )
-        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
-        timesteps = scheduler.timesteps
-        num_inference_steps = len(timesteps)
-    else:
-        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
-        timesteps = scheduler.timesteps
-    return timesteps, num_inference_steps

-class TangoFlux(nn.Module):

-    def __init__(self,config,initialize_reference_model=False):

-        super().__init__()

-        self.num_layers = config.get('num_layers', 6)
-        self.num_single_layers = config.get('num_single_layers', 18)
-        self.in_channels = config.get('in_channels', 64)
-        self.attention_head_dim = config.get('attention_head_dim', 128)
-        self.joint_attention_dim = config.get('joint_attention_dim', 1024)
-        self.num_attention_heads = config.get('num_attention_heads', 8)
-        self.audio_seq_len = config.get('audio_seq_len', 645)
-        self.max_duration = config.get('max_duration', 30)
-        self.uncondition = config.get('uncondition', False)
-        self.text_encoder_name = config.get('text_encoder_name', "google/flan-t5-large")

-        self.noise_scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000)
-        self.noise_scheduler_copy = copy.deepcopy(self.noise_scheduler)
-        self.max_text_seq_len = 64
-        self.text_encoder = T5EncoderModel.from_pretrained(self.text_encoder_name)
-        self.tokenizer = T5TokenizerFast.from_pretrained(self.text_encoder_name)
-        self.text_embedding_dim = self.text_encoder.config.d_model

-        self.fc = nn.Sequential(nn.Linear(self.text_embedding_dim,self.joint_attention_dim),nn.ReLU())
-        self.duration_emebdder = DurationEmbedder(self.text_embedding_dim,min_value=0,max_value=self.max_duration)

-        self.transformer = FluxTransformer2DModel(
-            in_channels=self.in_channels,
-            num_layers=self.num_layers,
-            num_single_layers=self.num_single_layers,
-            attention_head_dim=self.attention_head_dim,
-            num_attention_heads=self.num_attention_heads,
-            joint_attention_dim=self.joint_attention_dim,
-            pooled_projection_dim=self.text_embedding_dim,
-            guidance_embeds=False)

-        self.beta_dpo = 2000  ## this is used for dpo training

-    def get_sigmas(self,timesteps, n_dim=3, dtype=torch.float32):
-        device = self.text_encoder.device
-        sigmas = self.noise_scheduler_copy.sigmas.to(device=device, dtype=dtype)

-        schedule_timesteps = self.noise_scheduler_copy.timesteps.to(device)
-        timesteps = timesteps.to(device)
-        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

-        sigma = sigmas[step_indices].flatten()
-        while len(sigma.shape) < n_dim:
-            sigma = sigma.unsqueeze(-1)
-        return sigma

-    def encode_text_classifier_free(self, prompt: List[str], num_samples_per_prompt=1):
-        device = self.text_encoder.device
-        batch = self.tokenizer(
-            prompt, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt"
-        )
-        input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(device)

-        with torch.no_grad():
-            prompt_embeds = self.text_encoder(
-                input_ids=input_ids, attention_mask=attention_mask
-            )[0]

-        prompt_embeds = prompt_embeds.repeat_interleave(num_samples_per_prompt, 0)
-        attention_mask = attention_mask.repeat_interleave(num_samples_per_prompt, 0)

-        # get unconditional embeddings for classifier free guidance
-        uncond_tokens = [""]

-        max_length = prompt_embeds.shape[1]
-        uncond_batch = self.tokenizer(
-            uncond_tokens, max_length=max_length, padding='max_length', truncation=True, return_tensors="pt",
-        )
-        uncond_input_ids = uncond_batch.input_ids.to(device)
-        uncond_attention_mask = uncond_batch.attention_mask.to(device)

-        with torch.no_grad():
-            negative_prompt_embeds = self.text_encoder(
-                input_ids=uncond_input_ids, attention_mask=uncond_attention_mask
-            )[0]

-        negative_prompt_embeds = negative_prompt_embeds.repeat_interleave(num_samples_per_prompt, 0)
-        uncond_attention_mask = uncond_attention_mask.repeat_interleave(num_samples_per_prompt, 0)

-        # For classifier free guidance, we need to do two forward passes.
-        # We concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes

-        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
-        prompt_mask = torch.cat([uncond_attention_mask, attention_mask])
-        boolean_prompt_mask = (prompt_mask == 1).to(device)

-        return prompt_embeds, boolean_prompt_mask

-    @torch.no_grad()
-    def encode_text(self, prompt):
-        device = self.text_encoder.device
-        batch = self.tokenizer(
-            prompt, max_length=self.max_text_seq_len, padding=True, truncation=True, return_tensors="pt")
-        input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(device)

-        encoder_hidden_states = self.text_encoder(
-            input_ids=input_ids, attention_mask=attention_mask)[0]

-        boolean_encoder_mask = (attention_mask == 1).to(device)

-        return encoder_hidden_states, boolean_encoder_mask

-    def encode_duration(self,duration):
-        return self.duration_emebdder(duration)

-    @torch.no_grad()
-    def inference_flow(self, prompt,
-                       num_inference_steps=50,
-                       timesteps=None,
-                       guidance_scale=3,
-                       duration=10,
-                       disable_progress=False,
-                       num_samples_per_prompt=1):

-        '''Only tested for single inference. Haven't test for batch inference'''

-        bsz = num_samples_per_prompt
-        device = self.transformer.device
-        scheduler = self.noise_scheduler

-        if not isinstance(prompt,list):
-            prompt = [prompt]
-        if not isinstance(duration,torch.Tensor):
-            duration = torch.tensor([duration],device=device)
-        classifier_free_guidance = guidance_scale > 1.0
-        duration_hidden_states = self.encode_duration(duration)
-        if classifier_free_guidance:
-            bsz = 2 * num_samples_per_prompt

-            encoder_hidden_states, boolean_encoder_mask = self.encode_text_classifier_free(prompt, num_samples_per_prompt=num_samples_per_prompt)
-            duration_hidden_states = duration_hidden_states.repeat(bsz,1,1)

         else:

-[old lines 296-300 removed; content truncated in the diff view]
-        pooled = torch.nanmean(masked_data, dim=1)
-        pooled_projection = self.fc(pooled)

-        encoder_hidden_states = torch.cat([encoder_hidden_states,duration_hidden_states],dim=1)  ## (bs,seq_len,dim)

-        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
-        timesteps, num_inference_steps = retrieve_timesteps(
-            scheduler,
-            num_inference_steps,
-            device,
-            timesteps,
-            sigmas
-        )

-        latents = torch.randn(num_samples_per_prompt,self.audio_seq_len,64)
-        weight_dtype = latents.dtype

-        progress_bar = tqdm(range(num_inference_steps), disable=disable_progress)

-        txt_ids = torch.zeros(bsz,encoder_hidden_states.shape[1],3).to(device)
-        audio_ids = torch.arange(self.audio_seq_len).unsqueeze(0).unsqueeze(-1).repeat(bsz,1,3).to(device)

-        latents = latents.to(device)
-        encoder_hidden_states = encoder_hidden_states.to(device)

-        latents_input = torch.cat([latents] * 2) if classifier_free_guidance else latents

-[old lines 333-342 removed; content truncated in the diff view]
-                img_ids=audio_ids,
-                return_dict=False,
-            )[0]

-            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

-[old lines 351-354 removed; content truncated in the diff view]
-        return latents

-    def forward(self,
-                latents,
-                prompt,
-                duration=torch.tensor([10]),
-                sft=True
-                ):

-[old lines 365-371 removed; content truncated in the diff view]
-        duration_hidden_states = self.encode_duration(duration)

-        masked_data = torch.where(mask_expanded, encoder_hidden_states, torch.tensor(float('nan')))
-        pooled = torch.nanmean(masked_data, dim=1)
-        pooled_projection = self.fc(pooled)

-        txt_ids = torch.zeros(bsz,encoder_hidden_states.shape[1],3).to(device)
-        audio_ids = torch.arange(audio_seq_length).unsqueeze(0).unsqueeze(-1).repeat(bsz,1,3).to(device)

-        if self.uncondition:
-            mask_indices = [k for k in range(len(prompt)) if random.random() < 0.1]
-            if len(mask_indices) > 0:
-                encoder_hidden_states[mask_indices] = 0

-        noise = torch.randn_like(latents)

-        u = compute_density_for_timestep_sampling(
-            weighting_scheme='logit_normal',
-            batch_size=bsz,
-            logit_mean=0,
-            logit_std=1,
-            mode_scale=None,
-        )

-        indices = (u * self.noise_scheduler_copy.config.num_train_timesteps).long()
-        timesteps = self.noise_scheduler_copy.timesteps[indices].to(device=latents.device)
-        sigmas = self.get_sigmas(timesteps, n_dim=latents.ndim, dtype=latents.dtype)

-        noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise

-[old lines 411-413 removed; content truncated in the diff view]
-            model_pred = self.transformer(
-                hidden_states=noisy_model_input,
-                encoder_hidden_states=encoder_hidden_states,
-                pooled_projections=pooled_projection,
-                img_ids=audio_ids,
-                txt_ids=txt_ids,
-                guidance=None,
-                # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
-                timestep=timesteps/1000,
-                return_dict=False)[0]

-            target = noise - latents
-            loss = torch.mean(
-                ((model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
-                1,
-            )
-            loss = loss.mean()
-            raw_model_loss, raw_ref_loss, implicit_acc, epsilon_diff = 0, 0, 0, 0  ## default this to 0 if doing sft

-        else:
-            encoder_hidden_states = encoder_hidden_states.repeat(2, 1, 1)
-            pooled_projection = pooled_projection.repeat(2,1)
-            noise = torch.randn_like(latents).chunk(2)[0].repeat(2, 1, 1)  ## Have to sample same noise for preferred and rejected
-            u = compute_density_for_timestep_sampling(
-                weighting_scheme='logit_normal',
-                batch_size=bsz//2,
-                logit_mean=0,
-                logit_std=1,
-                mode_scale=None,
-            )

-            indices = (u * self.noise_scheduler_copy.config.num_train_timesteps).long()
-            timesteps = self.noise_scheduler_copy.timesteps[indices].to(device=latents.device)
-            timesteps = timesteps.repeat(2)
-            sigmas = self.get_sigmas(timesteps, n_dim=latents.ndim, dtype=latents.dtype)

-            noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise

-            model_pred = self.transformer(
-                hidden_states=noisy_model_input,
-                encoder_hidden_states=encoder_hidden_states,
-                pooled_projections=pooled_projection,
-                img_ids=audio_ids,
-                txt_ids=txt_ids,
-                guidance=None,
-                # YiYi notes: divide it by 1000 for now because we scale it by 1000 in the transformer model (we should not keep it but I want to keep the inputs same for the model for testing)
-                timestep=timesteps/1000,
-                return_dict=False)[0]
-            target = noise - latents

-            model_losses = F.mse_loss(model_pred.float(), target.float(), reduction="none")
-            model_losses = model_losses.mean(dim=list(range(1, len(model_losses.shape))))
-            model_losses_w, model_losses_l = model_losses.chunk(2)
-            model_diff = model_losses_w - model_losses_l
-            raw_model_loss = 0.5 * (model_losses_w.mean() + model_losses_l.mean())

-            with torch.no_grad():
-                ref_preds = self.ref_transformer(
-                    hidden_states=noisy_model_input,
-                    encoder_hidden_states=encoder_hidden_states,
-                    pooled_projections=pooled_projection,
-                    img_ids=audio_ids,
-                    txt_ids=txt_ids,
-                    guidance=None,
-                    timestep=timesteps/1000,
-                    return_dict=False)[0]

-                ref_loss = F.mse_loss(ref_preds.float(), target.float(), reduction="none")
-                ref_loss = ref_loss.mean(dim=list(range(1, len(ref_loss.shape))))

-                ref_losses_w, ref_losses_l = ref_loss.chunk(2)
-                ref_diff = ref_losses_w - ref_losses_l
-                raw_ref_loss = ref_loss.mean()

-                epsilon_diff = torch.max(torch.zeros_like(model_losses_w),
-                                         ref_losses_w - model_losses_w).mean()

-            scale_term = -0.5 * self.beta_dpo
-            inside_term = scale_term * (model_diff - ref_diff)
-            implicit_acc = (scale_term * (model_diff - ref_diff) > 0).sum().float() / inside_term.size(0)
-            loss = -1 * F.logsigmoid(inside_term).mean() + model_losses_w.mean()

-[old lines 506-511 removed; the trailing return statement is truncated in the diff view]
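The tail of the removed forward() implements the CRPO (DPO-style) preference objective. The sketch below isolates just that arithmetic on dummy per-sample losses so the sign convention is visible; beta_dpo and the loss shape follow the removed code, everything else is illustrative. The new model.py that replaces all of the above follows.

import torch
import torch.nn.functional as F

beta_dpo = 2000.0                                   # same constant as self.beta_dpo in the removed model
model_losses_w = torch.tensor([0.40, 0.35])         # flow-matching MSE on preferred ("winner") samples
model_losses_l = torch.tensor([0.55, 0.50])         # flow-matching MSE on rejected ("loser") samples
ref_losses_w   = torch.tensor([0.45, 0.38])         # the same losses under the frozen reference transformer
ref_losses_l   = torch.tensor([0.50, 0.48])

model_diff = model_losses_w - model_losses_l
ref_diff   = ref_losses_w - ref_losses_l
inside_term = -0.5 * beta_dpo * (model_diff - ref_diff)

implicit_acc = (inside_term > 0).float().mean()                    # fraction of pairs the policy ranks correctly
loss = -F.logsigmoid(inside_term).mean() + model_losses_w.mean()   # preference term plus an SFT-style anchor on the winner
print(implicit_acc.item(), round(loss.item(), 4))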

+import torch
+import torchaudio
+from omegaconf import OmegaConf
+from huggingface_hub import snapshot_download
 import numpy as np
+import json
+import os
+from safetensors.torch import load_file
+
+# Imports from the jamify library
+from jam.model.cfm import CFM
+from jam.model.dit import DiT
+from jam.model.vae import StableAudioOpenVAE
+from jam.dataset import DiffusionWebDataset, enhance_webdataset_config
+from muq import MuQMuLan
+
+# Helper functions adapted from jamify/src/jam/infer.py
+def get_negative_style_prompt(device, file_path):
+    vocal_style = np.load(file_path)
+    vocal_style = torch.from_numpy(vocal_style).to(device)
+    return vocal_style.half()
+
+def normalize_audio(audio):
+    audio = audio - audio.mean(-1, keepdim=True)
+    audio = audio / (audio.abs().max(-1, keepdim=True).values + 1e-8)
+    return audio
+
+class Jamify:
+    def __init__(self, device="cuda" if torch.cuda.is_available() else "cpu"):
+        self.device = torch.device(device)
+
+        # --- FIX: Point to the local jamify repository for config and public files ---
+        #jamify_repo_path = "/Users/cy/Desktop/JAM/jamify"
+
+        print("Downloading main model checkpoint...")
+        model_repo_path = snapshot_download(repo_id="declare-lab/jam-0.5")
+        self.checkpoint_path = os.path.join(model_repo_path, "jam-0_5.safetensors")
+
+        # Use local config and data files
+        config_path = os.path.join(model_repo_path, "jam_infer.yaml")
+        self.negative_style_prompt_path = os.path.join(model_repo_path, "vocal.npy")
+        tokenizer_path = os.path.join(model_repo_path, "en_us_cmudict_ipa_forward.pt")
+        silence_latent_path = os.path.join(model_repo_path, "silience_latent.pt")
+        print("Loading configuration...")
+        self.config = OmegaConf.load(config_path)
+        self.config.data.train_dataset.silence_latent_path = silence_latent_path
+
+        # --- FIX: Override the relative paths in the config with absolute paths ---
+        self.config.data.train_dataset.tokenizer_path = tokenizer_path
+        self.config.evaluation.dataset.tokenizer_path = tokenizer_path
+        self.config.data.train_dataset.phonemizer_checkpoint = tokenizer_path
+
+        print("Loading VAE model...")
+        self.vae = StableAudioOpenVAE().to(self.device).eval()
+
+        print("Loading CFM model...")
+        self.cfm_model = self._load_cfm_model(self.config.model, self.checkpoint_path)
+
+        print("Loading MuQ style model...")
+        self.muq_model = MuQMuLan.from_pretrained("OpenMuQ/MuQ-MuLan-large").to(self.device).eval()
+
+        print("Setting up dataset processor...")
+        dataset_cfg = OmegaConf.merge(self.config.data.train_dataset, self.config.evaluation.dataset)
+        enhance_webdataset_config(dataset_cfg)
+        dataset_cfg.multiple_styles = False
+        self.dataset_processor = DiffusionWebDataset(**dataset_cfg)
+
+        print("Jamify model loaded successfully.")
+
+    def _load_cfm_model(self, model_config, checkpoint_path):
+        dit_config = model_config["dit"].copy()
+        if "text_num_embeds" not in dit_config:
+            dit_config["text_num_embeds"] = 256
+
+        model = CFM(
+            transformer=DiT(**dit_config),
+            **model_config["cfm"]
+        ).to(self.device)
+
+        state_dict = load_file(checkpoint_path)
+        model.load_state_dict(state_dict, strict=False)
+        return model.eval()
+
+    def _generate_style_embedding_from_audio(self, audio_path):
+        waveform, sample_rate = torchaudio.load(audio_path)
+        if sample_rate != 24000:
+            resampler = torchaudio.transforms.Resample(sample_rate, 24000)
+            waveform = resampler(waveform)
+        if waveform.shape[0] > 1:
+            waveform = waveform.mean(dim=0, keepdim=True)
+
+        waveform = waveform.squeeze(0).to(self.device)
+
+        with torch.inference_mode():
+            style_embedding = self.muq_model(wavs=waveform.unsqueeze(0)[..., :24000 * 30])
+        return style_embedding[0]
+
+    def _generate_style_embedding_from_prompt(self, prompt):
+        with torch.inference_mode():
+            style_embedding = self.muq_model(texts=[prompt]).squeeze(0)
+        return style_embedding
+
+    def predict(self, reference_audio_path, lyrics_json_path, style_prompt, duration_sec=30, steps=50):
+        print("Starting prediction...")
+
+        if reference_audio_path:
+            print(f"Generating style from audio: {reference_audio_path}")
+            style_embedding = self._generate_style_embedding_from_audio(reference_audio_path)
+        elif style_prompt:
+            print(f"Generating style from prompt: '{style_prompt}'")
+            style_embedding = self._generate_style_embedding_from_prompt(style_prompt)
         else:
+            print("No style provided, using zero embedding.")
+            style_embedding = torch.zeros(512, device=self.device)

+        print(f"Loading lyrics from: {lyrics_json_path}")
+        with open(lyrics_json_path, 'r') as f:
+            lrc_data = json.load(f)
+        if 'word' not in lrc_data:
+            lrc_data = {'word': lrc_data}

+        frame_rate = 21.5
+        num_frames = int(duration_sec * frame_rate)
+        fake_latent = torch.randn(128, num_frames)

+        sample_tuple = ("user_song", fake_latent, style_embedding, lrc_data)

+        print("Processing sample...")
+        processed_sample = self.dataset_processor.process_sample_safely(sample_tuple)
+        if processed_sample is None:
+            raise ValueError("Failed to process the provided lyrics and style.")

+        batch = self.dataset_processor.custom_collate_fn([processed_sample])

+        for key, value in batch.items():
+            if isinstance(value, torch.Tensor):
+                batch[key] = value.to(self.device)

+        print("Generating audio latent...")
+        with torch.inference_mode():
+            batch_size = 1
+            text = batch["lrc"]
+            style_prompt_tensor = batch["prompt"]
+            start_time = batch["start_time"]
+            duration_abs = batch["duration_abs"]
+            duration_rel = batch["duration_rel"]

+            cond = torch.zeros(batch_size, self.cfm_model.max_frames, 64).to(self.device)
+            pred_frames = [(0, self.cfm_model.max_frames)]

+            negative_style_prompt = get_negative_style_prompt(self.device, self.negative_style_prompt_path)
+            negative_style_prompt = negative_style_prompt.repeat(batch_size, 1)

+            sample_kwargs = self.config.evaluation.sample_kwargs
+            sample_kwargs.steps = steps
+            latents, _ = self.cfm_model.sample(
+                cond=cond, text=text, style_prompt=style_prompt_tensor,
+                duration_abs=duration_abs, duration_rel=duration_rel,
+                negative_style_prompt=negative_style_prompt, start_time=start_time,
+                latent_pred_segments=pred_frames, **sample_kwargs)

+            latent = latents[0][0]

+            print("Decoding latent to audio...")
+            latent_for_vae = latent.transpose(0, 1).unsqueeze(0)
+            pred_audio = self.vae.decode(latent_for_vae).sample.squeeze(0).detach().cpu()

+            pred_audio = normalize_audio(pred_audio)

+            sample_rate = 44100
+            trim_samples = int(duration_sec * sample_rate)
+            if pred_audio.shape[1] > trim_samples:
+                pred_audio = pred_audio[:, :trim_samples]

+        output_path = "generated_song.mp3"
+        print(f"Saving audio to {output_path}")
+        torchaudio.save(output_path, pred_audio, sample_rate, format="mp3")

+        return output_path
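model.py can be exercised directly, without the Gradio front end. A minimal sketch; the lyrics path is illustrative (any file in the gt0.json format above works), the first run downloads the declare-lab/jam-0.5 checkpoint, and a GPU is assumed by the default device argument:

from model import Jamify

model = Jamify()                                   # loads the jam-0.5 checkpoint, VAE and MuQ models
output_path = model.predict(
    reference_audio_path=None,                     # or a path to an .mp3/.wav style reference
    lyrics_json_path="gt0.json",                   # word-level alignment JSON as added in this commit
    style_prompt="A sad, slow, acoustic country song",
    duration_sec=30,
    steps=50,
)
print(output_path)                                 # "generated_song.mp3"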
requirements.txt
CHANGED
@@ -1,11 +1,35 @@
-torch
-torchaudio
-[old lines 3-8 removed; content truncated in the diff view]
 librosa
-[old lines 10-11 removed; content truncated in the diff view]

+torch
+torchaudio
+diffusers
+accelerate
+safetensors
+wandb
+gpustat
+soundfile
+muq
+pyloudnorm
+mutagen
+torchdiffeq
+x_transformers
+ema_pytorch
 librosa
+jiwer
+demucs
+audiobox-aesthetics
+
+# WebDataset
+webdataset
+webdatasetng
+wids
+omegaconf
+
+# DeepPhonemizer
+unidecode
+inflect
+
+# duration prediction
+openai
+pyphen
+syllables
+git+https://github.com/declare-lab/jamify.git
+git+https://github.com/xhhhhang/DeepPhonemizer.git