# Step-Audio-2-mini/cosyvoice2/transformer/upsample_encoder_v2.py
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
# 2024 Alibaba Inc (Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Encoder definition."""
from typing import Tuple, List
import torch
from torch import nn
from torch.nn import functional as F
from cosyvoice2.transformer.encoder_layer import ConformerEncoderLayer
from cosyvoice2.transformer.positionwise_feed_forward import PositionwiseFeedForward
from cosyvoice2.utils.class_utils import (
COSYVOICE_EMB_CLASSES,
COSYVOICE_SUBSAMPLE_CLASSES,
COSYVOICE_ATTENTION_CLASSES,
COSYVOICE_ACTIVATION_CLASSES,
)
from cosyvoice2.utils.mask import (
make_pad_mask,
)
import torch._dynamo
torch._dynamo.config.suppress_errors = True
torch._dynamo.config.cache_size_limit = 128
class Upsample1D(nn.Module):
"""A 1D upsampling layer with an optional convolution.
Parameters:
channels (`int`):
number of channels in the inputs and outputs.
use_conv (`bool`, default `False`):
option to use a convolution.
use_conv_transpose (`bool`, default `False`):
option to use a convolution transpose.
out_channels (`int`, optional):
number of output channels. Defaults to `channels`.
"""
def __init__(self, channels: int, out_channels: int, stride: int = 2, scale_factor: float = None):
super().__init__()
self.channels = channels
self.out_channels = out_channels
self.stride = stride
        # In this mode, first nearest-neighbour interpolate, then conv with stride=1;
        # the left padding of stride * 2 in forward() keeps the kernel (stride * 2 + 1) causal and length-preserving.
self.conv = nn.Conv1d(self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0)
self.scale_factor = float(self.stride) if scale_factor is None else float(scale_factor)
def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor):
outputs = F.interpolate(inputs, scale_factor=self.scale_factor, mode="nearest")
outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0)
outputs = self.conv(outputs)
return outputs, input_lengths * self.stride
    def forward_chunk(self, inputs: torch.Tensor, input_lengths: torch.Tensor, cache: torch.Tensor = None):
        """
        Args:
            inputs(torch.Tensor): shape (b, c, t)
            input_lengths(torch.Tensor): shape (b), can be None
            cache(torch.Tensor): shape (b, c, cache_t), where cache_t = stride * 2.
                If None, a zero cache is used (first chunk).
        """
outputs = F.interpolate(inputs, scale_factor=self.scale_factor, mode="nearest")
if cache is None:
cache = inputs.new_zeros(inputs.shape[0], inputs.shape[1], self.stride * 2)
outputs = torch.cat([cache, outputs], dim=2)
new_cache = outputs[..., -self.stride*2:]
outputs = self.conv(outputs)
if input_lengths is not None:
input_lengths = input_lengths * self.stride
return outputs, input_lengths, new_cache
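# Illustrative sketch, not part of the original model code: a quick check of
# Upsample1D's shape behaviour under the default stride of 2 and an assumed
# channel size of 512. With scale_factor == stride, the output length is exactly
# stride * input length, and the chunked path with a zero cache matches forward(),
# whose left padding is also zero.
def _upsample1d_shape_sketch():
    layer = Upsample1D(channels=512, out_channels=512, stride=2)
    x = torch.randn(1, 512, 50)                  # (b, c, t)
    lens = torch.tensor([50])
    y, y_lens = layer(x, lens)                   # full-sequence path
    assert y.shape == (1, 512, 100) and int(y_lens[0]) == 100
    # chunked path: an all-zero cache of length stride * 2 plays the role of the zero left padding
    y_c, _, new_cache = layer.forward_chunk(x, lens, cache=x.new_zeros(1, 512, 4))
    assert new_cache.shape == (1, 512, 4)
    assert torch.allclose(y, y_c)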
class PreLookaheadLayer(nn.Module):
def __init__(self, channels: int, pre_lookahead_len: int = 1):
super().__init__()
self.channels = channels
self.pre_lookahead_len = pre_lookahead_len
self.conv1 = nn.Conv1d(
channels, channels,
kernel_size=pre_lookahead_len + 1,
stride=1, padding=0,
)
self.conv2 = nn.Conv1d(
channels, channels,
kernel_size=3, stride=1, padding=0,
)
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
"""
inputs: (batch_size, seq_len, channels)
"""
outputs = inputs.transpose(1, 2).contiguous()
# look ahead
outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0)
outputs = F.leaky_relu(self.conv1(outputs))
        # causal conv2: left-pad by 2 so the kernel-3 convolution keeps the sequence length
outputs = F.pad(outputs, (2, 0), mode='constant', value=0.0)
outputs = self.conv2(outputs)
outputs = outputs.transpose(1, 2).contiguous()
# residual connection
outputs = outputs + inputs
return outputs
def forward_chunk(self, inputs: torch.Tensor, cache: torch.Tensor = None):
"""
Args:
inputs(torch.Tensor): shape (b, t, c)
cache(torch.Tensor): shape (b, c, cache_t=2), c = channels
"""
outputs = inputs.transpose(1, 2).contiguous()
outputs = F.leaky_relu(self.conv1(outputs))
# the length of outputs is input length - pre_lookahead_len
if cache is None:
cache = outputs.new_zeros(outputs.shape[0], outputs.shape[1], 2)
        # NOTE: cache the last 2 conv1 frames (taken before concatenation) as left context for the next chunk
new_cache = outputs[..., -2:]
outputs = torch.cat([cache, outputs], dim=2)
outputs = self.conv2(outputs)
outputs = outputs.transpose(1, 2).contiguous()
# residual connection
outputs = outputs + inputs[:, :-self.pre_lookahead_len]
return outputs, new_cache
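# Illustrative sketch, not part of the original model code: with a zero initial
# cache, PreLookaheadLayer.forward_chunk returns t - pre_lookahead_len frames that
# match the first t - pre_lookahead_len frames of forward() on the same input,
# since forward() only differs in the zero right-padding that feeds the trailing
# lookahead frames.
def _pre_lookahead_sketch():
    layer = PreLookaheadLayer(channels=512, pre_lookahead_len=3)
    x = torch.randn(1, 20, 512)                  # (b, t, c)
    full = layer(x)                              # (1, 20, 512)
    chunk, new_cache = layer.forward_chunk(x, cache=None)
    assert chunk.shape == (1, 17, 512) and new_cache.shape == (1, 512, 2)
    assert torch.allclose(full[:, :17], chunk, atol=1e-6)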
"""Customize each sample's chunk attention mask
"""
class UpsampleConformerEncoderV2(torch.nn.Module):
def __init__(
self,
# input & output
input_size: int,
output_size: int = 256,
input_layer: str = "linear",
pre_lookahead_len: int = 3,
# size
num_blocks: int = 6,
num_up_blocks: int = 4,
# upsampling
up_stride: int = 2,
up_scale_factor: float = 2,
# attention
attention_heads: int = 4,
pos_enc_layer_type: str = "rel_pos_espnet",
selfattention_layer_type: str = "rel_selfattn",
key_bias: bool = True,
# mlp
linear_units: int = 2048,
# dropouts
dropout_rate: float = 0.1,
positional_dropout_rate: float = 0.1,
attention_dropout_rate: float = 0.0,
# other
normalize_before: bool = True,
activation_type: str = "swish",
**kwargs,
):
super().__init__()
self._output_size = output_size
self.embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
input_size,
output_size,
dropout_rate,
COSYVOICE_EMB_CLASSES[pos_enc_layer_type](
output_size,
positional_dropout_rate
),
)
self.normalize_before = normalize_before
self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
# self-attention module definition
encoder_selfattn_layer_args = (
attention_heads,
output_size,
attention_dropout_rate,
key_bias,
)
# feed-forward module definition
positionwise_layer_args = (
output_size,
linear_units,
dropout_rate,
activation,
)
self.pre_lookahead_layer = PreLookaheadLayer(
channels=output_size,
pre_lookahead_len=pre_lookahead_len
)
self.encoders = torch.nn.ModuleList([
ConformerEncoderLayer(
output_size,
COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
*encoder_selfattn_layer_args
),
PositionwiseFeedForward(*positionwise_layer_args),
None,
None,
dropout_rate,
normalize_before,
) for _ in range(num_blocks)
])
self.up_layer = Upsample1D(
channels=output_size,
out_channels=output_size,
stride=up_stride,
scale_factor=up_scale_factor
)
self.up_embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
input_size,
output_size,
dropout_rate,
COSYVOICE_EMB_CLASSES[pos_enc_layer_type](
output_size,
positional_dropout_rate
),
)
self.up_encoders = torch.nn.ModuleList([
ConformerEncoderLayer(
output_size,
COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
*encoder_selfattn_layer_args
),
PositionwiseFeedForward(*positionwise_layer_args),
None,
None,
dropout_rate,
normalize_before,
) for _ in range(num_up_blocks)
])
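        # CUDA-graph replay state: captured graphs and their static input/output
        # buffers, keyed by static sequence length (populated by _init_cuda_graph)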
self.enable_cuda_graph = False
self.use_cuda_graph = False
self.graph_encoder = {}
self.graph_up_encoder = {}
self.inference_buffers_encoder = {}
self.inference_buffers_up_encoder = {}
self.max_static_time = 1500
# FIXME(sfy) revert hard-coded bfloat16
# this method is skipped in CausalMaskedDiffWithXvec.scatter_cuda_graph
def scatter_cuda_graph(self, enable_cuda_graph: bool):
self.enable_cuda_graph = enable_cuda_graph
if self.enable_cuda_graph:
self._init_cuda_graph()
def _init_cuda_graph(self):
"""初始化 CUDA Graph"""
for l in range(100, 1500, 10):
static_x = torch.zeros((1, l, 512),
dtype=torch.float32, device=torch.device('cuda'))
static_mask = torch.ones((1, 1, l),
dtype=torch.bool, device=torch.device('cuda'))
static_pos_emb = torch.zeros((1, 2*l-1, 512),
dtype=torch.float32, device=torch.device('cuda'))
static_inputs = [
static_x,
static_mask,
static_pos_emb,
]
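            # warm-up pass outside the graph so CUDA kernels and workspaces are
            # allocated before capture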
self._forward_impl_encoder(
static_inputs[0],
static_inputs[1],
static_inputs[2],
)
graph = torch.cuda.CUDAGraph()
with torch.no_grad():
with torch.cuda.graph(graph):
static_out_x = self._forward_impl_encoder(
static_inputs[0],
static_inputs[1],
static_inputs[2]
)
self.graph_encoder[l] = graph
static_outputs = [
static_out_x,
]
self.inference_buffers_encoder[l] = {
'static_inputs': static_inputs,
'static_outputs': static_outputs
}
for l in range(100, 1500, 10):
static_x = torch.zeros((1, l, 512),
dtype=torch.float32, device=torch.device('cuda'))
static_mask = torch.ones((1, 1, l),
dtype=torch.bool, device=torch.device('cuda'))
static_pos_emb = torch.zeros((1, 2*l-1, 512),
dtype=torch.float32, device=torch.device('cuda'))
static_inputs = [
static_x,
static_mask,
static_pos_emb,
]
self._forward_impl_up_encoder(
static_inputs[0],
static_inputs[1],
static_inputs[2],
)
graph = torch.cuda.CUDAGraph()
with torch.no_grad():
with torch.cuda.graph(graph):
static_out_x = self._forward_impl_up_encoder(
static_inputs[0],
static_inputs[1],
static_inputs[2]
)
self.graph_up_encoder[l] = graph
static_outputs = [
static_out_x,
]
self.inference_buffers_up_encoder[l] = {
'static_inputs': static_inputs,
'static_outputs': static_outputs
}
self.use_cuda_graph = True
print("CUDA Graph initialized successfully for encoder and up_encoder")
# @torch.compile(dynamic=True,backend="eager")
def _forward_impl_encoder(self,
x: torch.Tensor,
mask: torch.Tensor,
pos_emb: torch.Tensor):
for layer in self.encoders:
x, _, _, _ = layer(x, mask, pos_emb)
return x
# @torch.compile(dynamic=True,backend="eager")
def _forward_impl_up_encoder(self,
x: torch.Tensor,
mask: torch.Tensor,
pos_emb: torch.Tensor):
for layer in self.up_encoders:
x, _, _, _ = layer(x, mask, pos_emb)
return x
def output_size(self) -> int:
return self._output_size
# @torch.compile(dynamic=True,backend="eager")
def forward(
self,
xs: torch.Tensor,
xs_lens: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
# (sfy) chunk training strategy should not be open-sourced
T = xs.size(1)
masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
xs, pos_emb, masks = self.embed(xs, masks)
# lookahead
xs = self.pre_lookahead_layer(xs)
# conformer block
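        # replay a pre-captured CUDA graph when one exists for this exact sequence
        # length; otherwise fall back to the eager layer loop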
if self.enable_cuda_graph and xs.shape[1] in self.graph_encoder:
self.inference_buffers_encoder[xs.shape[1]]['static_inputs'][0].copy_(xs)
self.inference_buffers_encoder[xs.shape[1]]['static_inputs'][1].copy_(masks)
self.inference_buffers_encoder[xs.shape[1]]['static_inputs'][2].copy_(pos_emb)
self.graph_encoder[xs.shape[1]].replay()
xs = self.inference_buffers_encoder[xs.shape[1]]['static_outputs'][0]
else:
xs = self._forward_impl_encoder(xs, masks, pos_emb)
# upsample
xs = xs.transpose(1, 2).contiguous()
xs, xs_lens = self.up_layer(xs, xs_lens)
xs = xs.transpose(1, 2).contiguous()
# 2nd conformer block
T = xs.size(1)
masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
xs, pos_emb, masks = self.up_embed(xs, masks)
if self.enable_cuda_graph and xs.shape[1] in self.graph_up_encoder:
self.inference_buffers_up_encoder[xs.shape[1]]['static_inputs'][0].copy_(xs)
self.inference_buffers_up_encoder[xs.shape[1]]['static_inputs'][1].copy_(masks)
self.inference_buffers_up_encoder[xs.shape[1]]['static_inputs'][2].copy_(pos_emb)
self.graph_up_encoder[xs.shape[1]].replay()
xs = self.inference_buffers_up_encoder[xs.shape[1]]['static_outputs'][0]
else:
xs = self._forward_impl_up_encoder(xs, masks, pos_emb)
# post norm
if self.normalize_before:
xs = self.after_norm(xs)
return xs, masks
    @torch.compile(dynamic=True, backend="eager")
def forward_chunk(self,
xs: torch.Tensor,
last_chunk: bool = False,
cnn_cache: torch.Tensor = None,
att_cache: torch.Tensor = None,
):
"""
Args:
xs: shape (b, dt, c)
last_chunk: bool. If last chunk, will pad input with lookaheads
att_cache: shape (depth1+depth2, b, nh, 2*t1, c).
cnn_cache: shape (b, c, t1+t2). Where t1=2 (pre_lookahead_layer), t2=4 (up_layer)
"""
if att_cache is not None:
assert att_cache.shape[3] % 2 == 0, att_cache.shape
if cnn_cache is not None:
assert cnn_cache.shape[2] == 2+self.up_layer.stride*2, cnn_cache.shape
# unpack caches
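        # both attention-cache stacks are packed into one tensor: layers [0, depth1)
        # belong to the pre-upsample encoders and use only the first half of the
        # time axis; layers [depth1, depth1 + depth2) belong to the up_encoders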
offset1 = att_cache.shape[3] // 2 if att_cache is not None else 0
att_cache1 = att_cache[:len(self.encoders), :, :, :offset1] if att_cache is not None else [None] * len(self.encoders)
att_cache2 = att_cache[len(self.encoders):] if att_cache is not None else [None] * len(self.encoders)
cnn_cache1 = cnn_cache[:, :, :2] if cnn_cache is not None else None
cnn_cache2 = cnn_cache[:, :, 2:] if cnn_cache is not None else None
xs, _, _ = self.embed(xs, None)
if last_chunk:
xs = F.pad(xs, (0, 0, 0, self.pre_lookahead_layer.pre_lookahead_len))
# this_cnn_cache: shape (b=1, c=512, t=2)
xs, new_cnn_cache1 = self.pre_lookahead_layer.forward_chunk(xs, cache=cnn_cache1)
# remake pos_emb, offset param is ignored by position_encoding
pos_emb = self.embed.position_encoding(offset=None, size=offset1 + xs.shape[1])
# first conformer
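        # an empty (0, 0, 0) mask tensor tells the attention layers to skip mask
        # application, i.e. full attention over the cached and current frames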
chunk_masks = torch.zeros((0, 0, 0))
new_att_cache1 = []
for idx, layer in enumerate(self.encoders):
# this_att_cache: shape (b, nh, t, c * 2)
xs, _, this_new_att_cache1, _ = layer(xs, chunk_masks, pos_emb, att_cache=att_cache1[idx])
new_att_cache1.append(this_new_att_cache1)
new_att_cache1 = torch.stack(new_att_cache1, dim=0)
# upsample + conformer encoder, xs: (b, t, c) -> (b, c, t)
xs = xs.transpose(1, 2).contiguous()
# this_cnn_cache: shape (b=1, c=512, t=2*2)
xs, _, new_cnn_cache2 = self.up_layer.forward_chunk(xs, None, cache=cnn_cache2)
xs = xs.transpose(1, 2).contiguous()
        # xs has now been upsampled: its length is multiplied by up_layer.stride
xs, _, _ = self.up_embed(xs, None)
# remake pos_emb
pos_emb = self.embed.position_encoding(offset=None, size=offset1 * self.up_layer.stride + xs.shape[1])
# second conformer
        chunk_masks = torch.zeros((0, 0, 0), dtype=torch.bfloat16)
new_att_cache2 = []
for idx, layer in enumerate(self.up_encoders):
xs, _, this_new_att_cache2, _ = layer(xs, chunk_masks, pos_emb, att_cache=att_cache2[idx])
new_att_cache2.append(this_new_att_cache2)
new_att_cache2 = torch.stack(new_att_cache2, dim=0)
if self.normalize_before:
xs = self.after_norm(xs)
# pack new cache
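        # the pre-upsample cache is tiled x2 along the time axis only so that both
        # stacks share one time length and can be concatenated along the layer dim;
        # the unpacking above slices the first half back out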
new_att_cache = torch.cat([new_att_cache1.repeat(1, 1, 1, 2, 1), new_att_cache2], dim=0)
new_cnn_cache = torch.cat([new_cnn_cache1, new_cnn_cache2], dim=2)
return xs, new_cnn_cache, new_att_cache
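# Illustrative usage sketch, not part of the original model code. It assumes the
# default registry entries ("linear", "rel_pos_espnet", "rel_selfattn", "swish")
# resolve to the standard CosyVoice implementations; with up_stride=2 the output
# sequence should be twice as long as the input.
if __name__ == "__main__":
    encoder = UpsampleConformerEncoderV2(input_size=512, output_size=512)
    encoder.eval()
    xs = torch.randn(1, 100, 512)                # (b, t, input_size)
    xs_lens = torch.tensor([100])
    with torch.no_grad():
        ys, masks = encoder(xs, xs_lens)
    # expected: ys (1, 200, 512), masks (1, 1, 200)
    print(ys.shape, masks.shape)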