# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from torch.nn import functional as F
from cosyvoice2.utils.mask import make_pad_mask
from cosyvoice2.flow.flow_matching import CausalConditionalCFM
from cosyvoice2.transformer.upsample_encoder_v2 import UpsampleConformerEncoderV2


class CausalMaskedDiffWithXvec(torch.nn.Module):
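    """Causal mel-spectrogram generator conditioned on speech tokens and a speaker x-vector.

    Speech tokens are embedded, encoded with an upsampling conformer encoder, projected to
    the mel dimension, and decoded into mel frames by a causal conditional flow-matching
    decoder. Supports full-utterance `inference` as well as streaming synthesis via
    `setup_cache` followed by repeated `inference_chunk` calls.
    """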
def __init__(self,
input_size: int = 512,
output_size: int = 80,
spk_embed_dim: int = 192,
output_type: str = "mel",
vocab_size: int = 5121,
encoder: UpsampleConformerEncoderV2 = None,
decoder: CausalConditionalCFM = None,
input_embedding: torch.nn.Module = None,
):
super().__init__()
self.input_size = input_size
self.output_size = output_size
self.vocab_size = vocab_size
self.output_type = output_type
self.pre_lookahead_len = int(encoder.pre_lookahead_layer.pre_lookahead_len)
self.up_rate = int(encoder.up_layer.stride)
if input_embedding is None:
self.input_embedding = nn.Embedding(vocab_size, input_size)
else:
self.input_embedding = input_embedding
self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
self.encoder = encoder
self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
self.decoder = decoder
# xvec projection with CUDA Graph optimization
        # Initialize CUDA Graph related variables
self.enable_cuda_graph = False
self.static_embedding = None
self.static_output = None
self.graph = None
        self.embedding_shape = None

    def scatter_cuda_graph(self, enable_cuda_graph: bool):
self.enable_cuda_graph = enable_cuda_graph
if self.enable_cuda_graph:
# self.encoder.scatter_cuda_graph(enable_cuda_graph)
            self.decoder.scatter_cuda_graph(enable_cuda_graph)

    @torch.inference_mode()
def inference(self,
token,
token_len,
prompt_token,
prompt_token_len,
prompt_feat,
prompt_feat_len,
embedding,
n_timesteps: int = 10,
):
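        """
        Args:
            token: shape (b, t), speech tokens to synthesize
            token_len: shape (b,), valid length of token
            prompt_token: shape (b, t'), prompt speech tokens
            prompt_token_len: shape (b,), valid length of prompt_token
            prompt_feat: shape (b, t'', c), prompt mel features used as condition
            prompt_feat_len: valid length of prompt_feat (unused here)
            embedding: shape (b, 192), speaker embedding
            n_timesteps: number of flow matching steps
        Returns:
            feat: shape (b, c, t_mel), generated mel frames with the prompt part removed
        """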
assert token.shape[0] == 1
# xvec projection
embedding = F.normalize(embedding, dim=1)
embedding = self.spk_embed_affine_layer(embedding)
# concat text and prompt_text
token_len = prompt_token_len + token_len
token = torch.concat([prompt_token, token], dim=1)
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
token = self.input_embedding(torch.clamp(token, min=0)) * mask
# token encode
h, _ = self.encoder.forward(token, token_len)
h = self.encoder_proj(h)
# condition
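        # the first mel_len1 frames are conditioned on the prompt mel; the remaining mel_len2 frames are generated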
mel_len1 = prompt_feat.shape[1]
mel_len2 = h.shape[1] - prompt_feat.shape[1]
conds = torch.zeros_like(h)
conds[:, :mel_len1] = prompt_feat
conds = conds.transpose(1, 2).contiguous()
mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
feat = self.decoder.forward(
mu=h.transpose(1, 2).contiguous(),
mask=mask.unsqueeze(1),
spks=embedding,
cond=conds,
n_timesteps=n_timesteps,
)
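        # strip the prompt segment; keep only the newly generated frames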
feat = feat[:, :, mel_len1:]
assert feat.shape[2] == mel_len2
        return feat

    @torch.inference_mode()
def setup_cache(self,
token: torch.Tensor,
mel: torch.Tensor,
spk: torch.Tensor,
n_timesteps: int = 10,
):
"""
Args:
token: shape (b, t), with look ahead tokens
mel: shape (b, t, c), groundtruth mel
spk: shape (b, 192), speaker embedding
Returns:
cache: dict {
                'conformer_cnn_cache': xxx, 'conformer_att_cache': xxx,
                'estimator_cnn_cache': xxx, 'estimator_att_cache': xxx
}
"""
# check if look ahead token included
assert (token.shape[1] - self.pre_lookahead_len) * self.up_rate == mel.shape[1], (token.shape, mel.shape)
# xvec projection
spk = F.normalize(spk, dim=1)
spk = self.spk_embed_affine_layer(spk)
token = self.input_embedding(token)
# NOTE encoder.forward_chunk will strip the look ahead part
h, conformer_cnn_cache, conformer_att_cache = self.encoder.forward_chunk(
xs = token,
last_chunk = False,
cnn_cache = None,
att_cache = None,
)
h = self.encoder_proj(h)
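        # run the estimator over the prompt mel to prime its cnn/att caches; the produced feat is discarded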
feat, estimator_cnn_cache, estimator_att_cache = self.decoder.forward_chunk(
mu = h.transpose(1, 2).contiguous(),
spks = spk,
cond = mel.transpose(1, 2).contiguous(),
n_timesteps = n_timesteps,
temperature = 1.0,
cnn_cache = None,
att_cache = None,
)
cache = {
'conformer_cnn_cache': conformer_cnn_cache,
'conformer_att_cache': conformer_att_cache,
'estimator_cnn_cache': estimator_cnn_cache,
'estimator_att_cache': estimator_att_cache,
}
        return cache

    @torch.inference_mode()
def inference_chunk(self,
token: torch.Tensor,
spk: torch.Tensor,
cache: dict,
last_chunk: bool = False,
n_timesteps: int = 10,
):
"""
Args:
token: shape (b, t), with look ahead tokens
spk: shape (b, 192), speaker embedding
            cache: dict {
                'conformer_cnn_cache': xxx, 'conformer_att_cache': xxx,
                'estimator_cnn_cache': xxx, 'estimator_att_cache': xxx
            }
        Returns:
            feat: shape (b, c, t_mel), mel frames generated for this chunk
            new_cache: dict with the same keys as cache, updated for the next chunk
        """
# unpack cache
conformer_cnn_cache = cache['conformer_cnn_cache']
conformer_att_cache = cache['conformer_att_cache']
estimator_cnn_cache = cache['estimator_cnn_cache']
estimator_att_cache = cache['estimator_att_cache']
# xvec projection
spk = F.normalize(spk, dim=1)
spk = self.spk_embed_affine_layer(spk)
token = self.input_embedding(token)
        # if not the last chunk, h is shorter than xs by lookahead_length * stride (6)
h, conformer_cnn_cache, conformer_att_cache = self.encoder.forward_chunk(
xs = token,
last_chunk = last_chunk,
cnn_cache = conformer_cnn_cache,
att_cache = conformer_att_cache,
)
h = self.encoder_proj(h)
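        # streaming chunks carry no prompt mel, so the decoder condition is all zeros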
cond = torch.zeros_like(h)
# forward estimator
feat, estimator_cnn_cache, estimator_att_cache = self.decoder.forward_chunk(
mu = h.transpose(1, 2).contiguous(),
spks = spk,
cond = cond.transpose(1, 2).contiguous(),
n_timesteps = n_timesteps,
temperature = 1.0,
cnn_cache = estimator_cnn_cache,
att_cache = estimator_att_cache,
)
new_cache = {
'conformer_cnn_cache': conformer_cnn_cache,
'conformer_att_cache': conformer_att_cache,
'estimator_cnn_cache': estimator_cnn_cache,
'estimator_att_cache': estimator_att_cache,
}
return feat, new_cache
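

# A minimal streaming-usage sketch (not part of the original model code): `flow` is an
# already-constructed CausalMaskedDiffWithXvec; prompt tokens/mel, the speaker embedding,
# and the token chunks (each carrying the look-ahead tokens the encoder expects) are
# assumed to be provided by the caller.
def _streaming_synthesis_example(flow, prompt_tokens, prompt_mel, spk_emb, token_chunks,
                                 n_timesteps: int = 10):
    # prime the conformer / estimator caches on the prompt
    cache = flow.setup_cache(prompt_tokens, prompt_mel, spk_emb, n_timesteps=n_timesteps)
    mels = []
    for i, chunk in enumerate(token_chunks):
        last_chunk = (i == len(token_chunks) - 1)
        mel_chunk, cache = flow.inference_chunk(chunk, spk_emb, cache,
                                                last_chunk=last_chunk, n_timesteps=n_timesteps)
        mels.append(mel_chunk)
    # concatenate along the time axis: (b, c, t_mel)
    return torch.cat(mels, dim=2)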