Commit: test

Files changed:
- app.py +1 -1
- diffrhythm/g2p/g2p/mandarin.py +4 -1
- diffrhythm/model/cfm.py +16 -13
- diffrhythm/model/dit.py +23 -27
app.py
CHANGED

@@ -13,7 +13,7 @@ from tqdm import tqdm
 import random
 import numpy as np
 import sys
-from diffrhythm.infer.infer_utils import (
+from huggface_diffrhythm.space.DiffRhythm.diffrhythm.infer.infer_utils import (
     get_reference_latent,
     get_lrc_token,
     get_style_prompt,
diffrhythm/g2p/g2p/mandarin.py
CHANGED

@@ -187,7 +187,10 @@ with open(
 ) as fread:
     txt_list = fread.readlines()
     for txt in txt_list:
-        word, pinyin = txt.strip().split("\t")
+        try:
+            word, pinyin = txt.strip().split("\t")
+        except:
+            print(txt.strip())
         word_pinyin_dict[word] = pinyin
 fread.close()
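The try/except above keeps a single malformed dictionary line from aborting the whole load, but it only prints the offending line: when the split fails, `word` and `pinyin` still hold their values from the previous iteration when `word_pinyin_dict[word] = pinyin` runs. A minimal standalone sketch of the same parsing pattern that skips bad lines instead (file name and variables are illustrative, not the repo's):

    word_pinyin_dict = {}

    with open("word_pinyin.txt", encoding="utf-8") as fread:
        for txt in fread:
            try:
                word, pinyin = txt.strip().split("\t")
            except ValueError:        # no tab, or wrong number of fields
                print(txt.strip())
                continue              # skip the malformed line entirely
            word_pinyin_dict[word] = pinyin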
diffrhythm/model/cfm.py
CHANGED

@@ -193,25 +193,28 @@ class CFM(nn.Module):
         # test for no ref audio
         if no_ref_audio:
             cond = torch.zeros_like(cond)
+
+        start_time_embed, positive_text_embed, positive_text_residuals = self.transformer.forward_timestep_invariant(text, step_cond.shape[1], drop_text=False, start_time=start_time)
+        _, negative_text_embed, negative_text_residuals = self.transformer.forward_timestep_invariant(text, step_cond.shape[1], drop_text=True, start_time=start_time)
+
+        text_embed = torch.cat([positive_text_embed, negative_text_embed], 0)
+        text_residuals = [torch.cat([a, b], 0) for a, b in zip(positive_text_residuals, negative_text_residuals)]
+        step_cond = torch.cat([step_cond, step_cond], 0)
+        style_prompt = torch.cat([style_prompt, negative_style_prompt], 0)
+        start_time_embed = torch.cat([start_time_embed, start_time_embed], 0)
 
 
         def fn(t, x):
-            ...
-            # step_cond = torch.where(cond_mask, cond, torch.zeros_like(cond))
-            ...
-            # predict flow
+            x = torch.cat([x, x], 0)
             pred = self.transformer(
-                x=x,
-                ...
+                x=x, text_embed=text_embed, text_residuals=text_residuals, cond=step_cond, time=t,
+                drop_audio_cond=True, drop_prompt=False, style_prompt=style_prompt, start_time=start_time_embed
             )
-            if cfg_strength < 1e-5:
-                return pred
 
-            ...
-            ...
-            ...
-            ...
-            return pred + (pred - null_pred) * cfg_strength
+            positive_pred, negative_pred = pred.chunk(2, 0)
+            cfg_pred = positive_pred + (positive_pred - negative_pred) * cfg_strength
+
+            return cfg_pred
 
         # noise input
         # to make sure batch inference result is same with different batch size, and for sure single inference
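Net effect of this hunk: the conditional and unconditional classifier-free-guidance branches are evaluated in one batched transformer call per ODE step instead of two, and the timestep-invariant text embeddings are hoisted out of `fn` entirely. A minimal sketch of the batching pattern, assuming a hypothetical `model(x, cond, t)` callable (the real transformer in cfm.py takes more arguments):

    import torch

    def make_cfg_fn(model, pos_cond, neg_cond, cfg_strength):
        # Concatenate positive and negative conditioning once, outside the solver loop.
        cond = torch.cat([pos_cond, neg_cond], dim=0)

        def fn(t, x):
            x_pair = torch.cat([x, x], dim=0)        # duplicate the state for both branches
            pred = model(x_pair, cond, t)            # one forward pass instead of two
            pos, neg = pred.chunk(2, dim=0)          # split back into cond / uncond
            return pos + (pos - neg) * cfg_strength  # classifier-free guidance combine
        return fn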
diffrhythm/model/dit.py
CHANGED

@@ -15,7 +15,7 @@ import torch
 import torch.nn.functional as F
 
 from x_transformers.x_transformers import RotaryEmbedding
-from transformers.models.llama.modeling_llama import LlamaDecoderLayer
+from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRotaryEmbedding
 from transformers.models.llama import LlamaConfig
 from torch.utils.checkpoint import checkpoint
 
@@ -28,7 +28,8 @@ from diffrhythm.model.modules import (
     precompute_freqs_cis,
     get_pos_embed_indices,
 )
-
+from liger_kernel.transformers import apply_liger_kernel_to_llama
+apply_liger_kernel_to_llama()
 
 # Text embedding
 
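`apply_liger_kernel_to_llama()` monkey-patches the Hugging Face Llama implementation in place (fused RMSNorm, SwiGLU, and RoPE kernels), so every `LlamaDecoderLayer` built afterwards picks it up. Calling it at import time also means a missing or broken `liger_kernel` install fails the whole module import, and with it the Space build. A hedged sketch of a guarded variant, which is not what this commit does:

    # Optional: patch HF Llama layers with Liger's fused Triton kernels,
    # falling back to the stock transformers implementation if unavailable.
    try:
        from liger_kernel.transformers import apply_liger_kernel_to_llama
        apply_liger_kernel_to_llama()
    except ImportError:
        pass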
@@ -134,9 +135,11 @@ class DiT(nn.Module):
         #)
         llama_config = LlamaConfig(hidden_size=dim, intermediate_size=dim * ff_mult, hidden_act='silu')
         llama_config._attn_implementation = 'sdpa'
+        #llama_config._attn_implementation = ''
         self.transformer_blocks = nn.ModuleList(
             [LlamaDecoderLayer(llama_config, layer_idx=i) for i in range(depth)]
         )
+        self.rotary_emb = LlamaRotaryEmbedding(config=llama_config)
         self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
 
         self.text_fusion_linears = nn.ModuleList(
@@ -157,60 +160,53 @@ class DiT(nn.Module):
         # if use_style_prompt:
         #     self.prompt_rnn = nn.LSTM(64, cond_dim, 1, batch_first=True)
 
+    def forward_timestep_invariant(self, text, seq_len, drop_text, start_time):
+        s_t = self.start_time_embed(start_time)
+        text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
+        text_residuals = []
+        for layer in self.text_fusion_linears:
+            text_residual = layer(text_embed)
+            text_residuals.append(text_residual)
+        return s_t, text_embed, text_residuals
+
 
     def forward(
         self,
         x: float["b n d"],  # nosied input audio  # noqa: F722
+        text_embed: int["b nt"],  # text  # noqa: F722
+        text_residuals,
         cond: float["b n d"],  # masked cond audio  # noqa: F722
-        text: int["b nt"],  # text  # noqa: F722
         time: float["b"] | float[""],  # time step  # noqa: F821 F722
         drop_audio_cond,  # cfg for cond audio
-        drop_text,  # cfg for text
         drop_prompt=False,
         style_prompt=None,  # [b d t]
-        style_prompt_lens=None,
-        mask: bool["b n"] | None = None,  # noqa: F722
-        grad_ckpt=False,
         start_time=None,
     ):
         batch, seq_len = x.shape[0], x.shape[1]
         if time.ndim == 0:
             time = time.repeat(batch)
 
-        # t: conditioning time, c: context (text + masked cond audio), x: noised input audio
         t = self.time_embed(time)
-        ...
-        c = t + s_t
-        text_embed = self.text_embed(text, seq_len, drop_text=drop_text)
+        c = t + start_time
 
-        # import pdb; pdb.set_trace()
         if drop_prompt:
             style_prompt = torch.zeros_like(style_prompt)
-        # if self.training:
-        #     packed_style_prompt = torch.nn.utils.rnn.pack_padded_sequence(style_prompt.transpose(1, 2), style_prompt_lens.cpu(), batch_first=True, enforce_sorted=False)
-        # else:
-        #     packed_style_prompt = style_prompt.transpose(1, 2)
-        #print(packed_style_prompt.shape)
-        # _, style_emb = self.prompt_rnn.forward(packed_style_prompt)
-        # _, (h_n, c_n) = self.prompt_rnn.forward(packed_style_prompt)
-        # style_emb = h_n.squeeze(0)  # 1, B, dim -> B, dim
 
-        ...
+        style_embed = style_prompt  # [b, 512]
 
-        x = self.input_embed(x, cond, text_embed,
+        x = self.input_embed(x, cond, text_embed, style_embed, c, drop_audio_cond=drop_audio_cond)
 
         if self.long_skip_connection is not None:
             residual = x
 
         pos_ids = torch.arange(x.shape[1], device=x.device)
         pos_ids = pos_ids.unsqueeze(0).repeat(x.shape[0], 1)
+        rotary_embed = self.rotary_emb(x, pos_ids)
+
         for i, block in enumerate(self.transformer_blocks):
-            ...
-                x, *_ = block(x, position_ids=pos_ids)
-            else:
-                x, *_ = checkpoint(block, x, position_ids=pos_ids, use_reentrant=False)
+            x, *_ = block(x, position_embeddings=rotary_embed)
             if i < self.depth // 2:
-                x = x +
+                x = x + text_residuals[i]
 
         if self.long_skip_connection is not None:
             x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
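Taken together, these hunks move the timestep-invariant work (text embedding, text-fusion residuals, rotary cos/sin tables) out of `forward`, so the ODE solver's repeated calls only redo what depends on `t`, and the Llama blocks reuse one precomputed `position_embeddings` tuple instead of re-deriving it from `position_ids` in every layer. A minimal sketch of that rotary pattern with made-up sizes, assuming a transformers version where `LlamaDecoderLayer` accepts `position_embeddings` and still returns a tuple:

    import torch
    from torch import nn
    from transformers.models.llama import LlamaConfig
    from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRotaryEmbedding

    dim, depth, batch, seq_len = 512, 4, 2, 128   # hypothetical sizes
    config = LlamaConfig(hidden_size=dim, intermediate_size=dim * 4,
                         num_attention_heads=8, hidden_act="silu")
    config._attn_implementation = "sdpa"

    blocks = nn.ModuleList([LlamaDecoderLayer(config, layer_idx=i) for i in range(depth)])
    rotary = LlamaRotaryEmbedding(config=config)

    x = torch.randn(batch, seq_len, dim)
    pos_ids = torch.arange(seq_len).unsqueeze(0).repeat(batch, 1)

    # Compute the rotary cos/sin tables once; every layer reuses them via
    # position_embeddings instead of recomputing from position_ids.
    position_embeddings = rotary(x, pos_ids)

    for block in blocks:
        x, *_ = block(x, position_embeddings=position_embeddings)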