| """Conformer definition adjusted given the Lucidrain's repo. | |
| https://github.com/lucidrains/soundstorm-pytorch/blob/main/soundstorm_pytorch/soundstorm.py # noqa | |
| Copyright PolyAI Limited. | |
| """ | |
| from collections import namedtuple | |
| from functools import wraps | |
| from typing import Dict, Union | |
| import torch | |
| import torch.nn.functional as F | |
| from einops import rearrange, reduce | |
| from einops.layers.torch import EinMix, Rearrange | |
| from torch import einsum, nn | |
| # rotary embedding | |
class RotaryEmbedding(nn.Module):
    def __init__(self, dim, theta = 10000):
        super().__init__()
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent = False)

    @property
    def device(self):
        return next(self.buffers()).device

    def forward(self, seq_len):
        t = torch.arange(seq_len, device = self.device).type_as(self.inv_freq)
        freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
        freqs = torch.cat((freqs, freqs), dim = -1)
        return freqs


def rotate_half(x):
    x1, x2 = x.chunk(2, dim = -1)
    return torch.cat((-x2, x1), dim = -1)


def apply_rotary_pos_emb(pos, t):
    return (t * pos.cos()) + (rotate_half(t) * pos.sin())
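
# Usage note: Conformer below computes `freqs = RotaryEmbedding(dim_head)(seq_len)`
# once per forward pass, and each Attention layer applies it to q and k via
# `apply_rotary_pos_emb(freqs, q)` / `apply_rotary_pos_emb(freqs, k)` before attention.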

# constants

EfficientAttentionConfig = namedtuple(
    'EfficientAttentionConfig',
    ['enable_flash', 'enable_math', 'enable_mem_efficient']
)


# helpers

def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


def divisible_by(numer, denom):
    return (numer % denom) == 0


def calc_same_padding(kernel_size):
    pad = kernel_size // 2
    return (pad, pad - (kernel_size + 1) % 2)
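
# Worked example: calc_same_padding(31) -> (15, 15) and calc_same_padding(4)
# -> (2, 1), i.e. "same" padding for odd and even kernel sizes respectively.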

def eval_decorator(fn):
    @wraps(fn)
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()
        out = fn(model, *args, **kwargs)
        model.train(was_training)
        return out
    return inner


def once(fn):
    called = False

    @wraps(fn)
    def inner(x):
        nonlocal called
        if called:
            return
        called = True
        return fn(x)
    return inner


print_once = once(print)

# t5 relative positional bias

class T5RelativePositionBias(nn.Module):
    def __init__(
        self,
        scale = 1.,
        num_buckets = 32,
        max_distance = 128,
        heads = 8
    ):
        super().__init__()
        self.scale = scale
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    @staticmethod
    def _relative_position_bucket(
        relative_position,
        num_buckets = 32,
        max_distance = 128
    ):
        ret = 0
        n = -relative_position

        num_buckets //= 2
        ret += (n < 0).long() * num_buckets
        n = torch.abs(n)

        max_exact = num_buckets // 2
        is_small = n < max_exact

        val_if_large = max_exact + (
            torch.log(n.float() / max_exact) / math.log(
                max_distance / max_exact) * (num_buckets - max_exact)
        ).long()

        val_if_large = torch.min(
            val_if_large,
            torch.full_like(val_if_large, num_buckets - 1)
        )

        ret += torch.where(is_small, n, val_if_large)
        return ret

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, n):
        pos = torch.arange(n, device = self.device).long()
        rel_pos = rearrange(pos, 'j -> 1 j') - rearrange(pos, 'i -> i 1')

        rp_bucket = self._relative_position_bucket(
            rel_pos, num_buckets = self.num_buckets,
            max_distance = self.max_distance)
        values = self.relative_attention_bias(rp_bucket)
        bias = rearrange(values, 'i j h -> h i j')
        return bias * self.scale
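
# Shape note: the bias returned above is (heads, n, n); it is added to the
# per-head attention logits through the `attn_bias` argument of Attend.forward
# (non-flash path only, since Attend asserts attn_bias is None when flash=True).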

# main class

class Attend(nn.Module):
    def __init__(
        self,
        causal = False,
        dropout = 0.,
        flash = False
    ):
        super().__init__()
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.causal = causal
        self.flash = flash

        # determine efficient attention configs for cuda and cpu
        self.cpu_config = EfficientAttentionConfig(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device('cuda'))

        if device_properties.major == 8 and device_properties.minor == 0:
            print_once('A100 GPU detected, using flash attention if input tensor is on cuda')  # noqa
            self.cuda_config = EfficientAttentionConfig(True, True, True)
        else:
            print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')  # noqa
            self.cuda_config = EfficientAttentionConfig(False, True, True)

    def get_mask(self, i, j, device):
        return torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)  # noqa

    def flash_attn(self, q, k, v, mask = None, attn_bias = None):
        _, heads, q_len, _, k_len, is_cuda, device = *q.shape, k.shape[-2], q.is_cuda, q.device  # noqa

        # single headed key / values
        if k.ndim == 3:
            k = rearrange(k, 'b n d -> b 1 n d')

        if v.ndim == 3:
            v = rearrange(v, 'b n d -> b 1 n d')

        # Check if mask exists and expand to compatible shape
        # The mask is B L, so it would have to be expanded to B H N L
        if exists(mask) and mask.ndim != 4:
            mask = rearrange(mask, 'b j -> b 1 1 j')
            mask = mask.expand(-1, heads, q_len, -1)

        # Check if there is a compatible device for flash attention
        config = self.cuda_config if is_cuda else self.cpu_config

        causal = self.causal

        # handle attention bias
        if exists(attn_bias):
            mask_value = -torch.finfo(q.dtype).max // 2
            causal_mask = self.get_mask(q_len, k_len, device)
            attn_bias = attn_bias.masked_fill(causal_mask, mask_value)

            if exists(mask):
                attn_bias = attn_bias.masked_fill(~mask, mask_value)

            mask = attn_bias
            causal = False

        # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask = mask,
                dropout_p = self.dropout if self.training else 0.,
                is_causal = causal
            )

        return out
    def forward(self, q, k, v, mask = None, attn_bias = None):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension
        """
        q_len, k_len, device = q.shape[-2], k.shape[-2], q.device

        scale = q.shape[-1] ** -0.5

        kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'

        if self.flash:
            assert not exists(attn_bias)
            return self.flash_attn(q, k, v, mask = mask)

        # similarity
        sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale

        # attention bias
        if exists(attn_bias):
            sim = sim + attn_bias

        # causal mask
        if self.causal:
            causal_mask = self.get_mask(q_len, k_len, device)
            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # key padding mask
        if exists(mask):
            if mask.ndim != 4:
                mask = rearrange(mask, 'b j -> b 1 1 j')
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        # attention
        attn = sim.softmax(dim = -1)
        attn = self.attn_dropout(attn)

        # aggregate values
        out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v)

        return out

class Swish(nn.Module):
    def forward(self, x):
        return x * x.sigmoid()


class GLU(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        out, gate = x.chunk(2, dim = self.dim)
        return out * gate.sigmoid()


class DepthWiseConv1d(nn.Module):
    def __init__(self, chan_in, chan_out, kernel_size, padding):
        super().__init__()
        self.padding = padding
        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups = chan_in)

    def forward(self, x):
        x = F.pad(x, self.padding)
        return self.conv(x)


class Scale(nn.Module):
    def __init__(self, scale, fn):
        super().__init__()
        self.fn = fn
        self.scale = scale

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) * self.scale


class ChanLayerNorm(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(1, dim, 1))

    def forward(self, x):
        eps = 1e-6 if x.dtype == torch.float32 else 1e-4
        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
        mean = torch.mean(x, dim = 1, keepdim = True)
        return (x - mean) * var.clamp(min = eps).rsqrt() * self.gamma


class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)

class Attention(nn.Module):
    def __init__(
        self,
        dim,
        heads = 8,
        dim_head = 64,
        dropout = 0.,
        flash = True
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = Attend(
            flash = flash,
            dropout = dropout
        )

        self.dropout = nn.Dropout(dropout)
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(
        self,
        x,
        context = None,
        mask = None,
        rotary_emb = None,
        attn_bias = None
    ):
        n, device, h, has_context = x.shape[-2], x.device, self.heads, exists(context)
        context = default(context, x)

        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
        q, k, v = map(
            lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        if exists(rotary_emb):
            q = apply_rotary_pos_emb(rotary_emb, q)
            k = apply_rotary_pos_emb(rotary_emb, k)

        out = self.attend(q, k, v, mask = mask, attn_bias = attn_bias)

        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class FeedForward(nn.Module):
    def __init__(
        self,
        dim,
        mult = 4,
        dropout = 0.
    ):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult),
            Swish(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

class ConformerConvModule(nn.Module):
    def __init__(
        self,
        dim,
        causal = False,
        expansion_factor = 2,
        kernel_size = 31,
        dropout = 0.
    ):
        super().__init__()

        inner_dim = dim * expansion_factor
        padding = calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0)

        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            Rearrange('b n c -> b c n'),
            nn.Conv1d(dim, inner_dim * 2, 1),
            GLU(dim = 1),
            DepthWiseConv1d(
                inner_dim, inner_dim, kernel_size = kernel_size,
                padding = padding
            ),
            Swish(),
            ChanLayerNorm(inner_dim),
            nn.Conv1d(inner_dim, dim, 1),
            Rearrange('b c n -> b n c'),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

# Conformer Block

class ConformerBlock(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        conv_expansion_factor = 2,
        conv_kernel_size = 31,
        attn_dropout = 0.,
        attn_flash = True,
        ff_dropout = 0.,
        conv_dropout = 0.,
        conv_causal = False
    ):
        super().__init__()
        self.ff1 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
        self.attn = Attention(
            dim = dim, dim_head = dim_head, heads = heads,
            dropout = attn_dropout, flash = attn_flash
        )
        self.conv = ConformerConvModule(
            dim = dim, causal = conv_causal,
            expansion_factor = conv_expansion_factor,
            kernel_size = conv_kernel_size, dropout = conv_dropout
        )
        self.ff2 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)

        self.attn = PreNorm(dim, self.attn)
        self.ff1 = Scale(0.5, PreNorm(dim, self.ff1))
        self.ff2 = Scale(0.5, PreNorm(dim, self.ff2))

        self.post_norm = nn.LayerNorm(dim)

    def forward(
        self,
        x,
        mask = None,
        rotary_emb = None,
        attn_bias = None
    ):
        x = self.ff1(x) + x
        x = self.attn(x, mask = mask, rotary_emb = rotary_emb, attn_bias = attn_bias) + x  # noqa
        x = self.conv(x) + x
        x = self.ff2(x) + x
        x = self.post_norm(x)
        return x

# Conformer

class Conformer(nn.Module):
    def __init__(
        self,
        dim,
        *,
        num_layers,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        conv_expansion_factor = 2,
        conv_kernel_size = 31,
        attn_dropout = 0.,
        ff_dropout = 0.,
        conv_dropout = 0.,
        conv_causal = False,
        attn_flash = True,
        t5_rel_pos_bias = False
    ):
        super().__init__()

        assert not (t5_rel_pos_bias and attn_flash), 'flash attention is not compatible with learned bias'  # noqa

        self.dim = dim
        self.layers = nn.ModuleList([])

        self.rotary_emb = RotaryEmbedding(
            dim_head) if not t5_rel_pos_bias else None
        self.rel_pos_bias = T5RelativePositionBias(
            dim_head ** 0.5, heads = heads) if t5_rel_pos_bias else None

        for _ in range(num_layers):
            self.layers.append(ConformerBlock(
                dim = dim,
                dim_head = dim_head,
                heads = heads,
                ff_mult = ff_mult,
                conv_expansion_factor = conv_expansion_factor,
                conv_kernel_size = conv_kernel_size,
                attn_dropout = attn_dropout,
                ff_dropout = ff_dropout,
                conv_dropout = conv_dropout,
                conv_causal = conv_causal,
                attn_flash = attn_flash
            ))

    def forward(self, x, mask = None):
        seq_len = x.shape[-2]

        rotary_emb = self.rotary_emb(seq_len) if exists(self.rotary_emb) else None  # noqa
        attn_bias = self.rel_pos_bias(seq_len) if exists(self.rel_pos_bias) else None  # noqa

        for block in self.layers:
            x = block(
                x,
                mask = mask,
                rotary_emb = rotary_emb,
                attn_bias = attn_bias
            )

        return x
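
# Example (illustrative; the values are assumptions, not from the original file):
# a 2-layer Conformer preserves the (batch, seq, dim) shape of its input, e.g.
#   model = Conformer(dim = 256, num_layers = 2)
#   out = model(torch.randn(1, 100, 256))  # -> (1, 100, 256)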

# conformer with sum reduction across quantized tokens at the beginning,
# along with heads

class ConformerWrapper(nn.Module):
    def __init__(
        self,
        *,
        codebook_size,
        num_quantizers,
        conformer: Union[Conformer, Dict[str, Any]],
        grouped_quantizers = 1
    ):
        super().__init__()
        self.conformer = conformer

        if isinstance(conformer, dict):
            self.conformer = Conformer(**self.conformer)

        dim = self.conformer.dim

        self.embedding_proj = nn.Sequential(
            nn.Linear(dim * grouped_quantizers, dim),
            nn.LayerNorm(dim)
        ) if grouped_quantizers > 1 else nn.Identity()

        num_codes_with_mask = codebook_size + 1
        num_effective_quantizers = num_quantizers * grouped_quantizers

        self.code_embeds = nn.Embedding(
            num_codes_with_mask * num_effective_quantizers, dim)

        self.register_buffer(
            'quantizer_offsets',
            torch.arange(num_effective_quantizers) * num_codes_with_mask,
            persistent = False
        )
        self.register_buffer(
            'mask_tokens', self.quantizer_offsets + num_codes_with_mask,
            persistent = False
        )

        self.dim = dim
        self.codebook_size = codebook_size

        self.num_codes_with_mask = num_codes_with_mask
        self.num_quantizers = num_quantizers
        self.grouped_quantizers = grouped_quantizers

        self.heads = nn.Sequential(
            nn.Linear(dim, dim * num_effective_quantizers),
            Rearrange('b n (h d) -> b (n h) d', h = num_effective_quantizers)
        )

        # each quantizer codebook would require its own logits weight
        # and bias matrices
        # the amazing einops makes this easy with 'EinMix'
        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            Rearrange('b (n gq) d -> b n gq d', gq = num_effective_quantizers),
            EinMix(
                'b n gq d -> b n gq l',
                weight_shape = 'gq d l',
                bias_shape = 'gq l',
                gq = num_effective_quantizers,
                l = codebook_size,
                d = dim
            ),
            Rearrange('b ... d -> b (...) d')
        )
    def forward(
        self,
        x,
        *,
        mask = None,
        cond = None,
        sum_embeds = None,
        return_embeddings = False,
        return_logits_and_embeddings = False
    ):
        """
        einops notation:
        b - batch
        n - sequence
        g - groups
        q - quantizers
        d - feature dimension
        """
        n, q, g = x.shape[-1], self.num_quantizers, self.grouped_quantizers
        assert divisible_by(n, g * q), 'sequence must be divisible by number of quantizers'  # noqa

        x = rearrange(x, 'b (n gq) -> b n gq', gq = g * q)
        x = x + self.quantizer_offsets
        x = self.code_embeds(x)

        x = reduce(x, 'b n (g q) d -> b n (g d)', 'sum', g = g)

        x = self.embedding_proj(x)

        if exists(sum_embeds):
            x = x + sum_embeds

        if exists(cond):
            if cond.ndim == 2:
                cond = rearrange(cond, 'b d -> b 1 d')
            x = x + cond

        x = self.conformer(x, mask = mask)
        embeds = self.heads(x)

        if return_embeddings or not exists(self.to_logits):
            return embeds

        logits = self.to_logits(embeds)

        if return_logits_and_embeddings:
            return logits, embeds

        return logits
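
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; the hyperparameters below are
# assumptions chosen for a quick CPU smoke test, not values from the original
# module). `attn_flash = False` keeps attention on the plain einsum path.
if __name__ == '__main__':
    wrapper = ConformerWrapper(
        codebook_size = 1024,
        num_quantizers = 4,
        conformer = dict(dim = 256, num_layers = 2, attn_flash = False),
    )
    # codes are flattened as (batch, seq_len * num_quantizers)
    tokens = torch.randint(0, 1024, (1, 32 * 4))
    logits = wrapper(tokens)
    print(logits.shape)  # torch.Size([1, 128, 1024]) = (batch, seq * q, codebook_size)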