Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """ | |
| This code contains the spectrogram and Hybrid version of Demucs. | |
| """ | |
| from copy import deepcopy | |
| import math | |
| import typing as tp | |
| import torch | |
| from torch import nn | |
| from torch.nn import functional as F | |
| from .filtering import wiener | |
| from .demucs import DConv, rescale_module | |
| from .states import capture_init | |
| from .spec import spectro, ispectro | |
def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
    """Pad the last dimension of `x`, tolerating reflect padding on short inputs.

    This is a thin wrapper around `F.pad`. When `mode` is `'reflect'` and the
    input is not longer than the largest requested padding, zeros are first
    appended (to the right in priority, then to the left) so that the
    reflection is computed on a long enough signal. The requested reflect
    paddings are reduced by the number of zeros already inserted, so the final
    length is always `x.shape[-1] + paddings[0] + paddings[1]`.

    Args:
        x: tensor to pad along its last dimension.
        paddings: `(left, right)` number of samples to add.
        mode: padding mode forwarded to `F.pad`.
        value: fill value forwarded to `F.pad`.

    Returns:
        The padded tensor; the original content sits at `[left, left + length)`.
    """
    original = x
    size = x.shape[-1]
    left, right = paddings
    largest = max(left, right)
    if mode == 'reflect' and size <= largest:
        # Not enough samples to reflect on: insert just enough zeros first.
        missing = largest - size + 1
        zeros_right = min(right, missing)
        zeros_left = missing - zeros_right
        x = F.pad(x, (zeros_left, zeros_right))
        paddings = (left - zeros_left, right - zeros_right)
    padded = F.pad(x, paddings, mode, value)
    # Sanity checks: expected output length, and input preserved in place.
    assert padded.shape[-1] == size + left + right
    assert (padded[..., left: left + size] == original).all()
    return padded
class ScaledEmbedding(nn.Module):
    """
    Embedding whose stored weights are divided by `scale` and whose outputs are
    multiplied back by `scale`, which boosts the effective learning rate.
    With `smooth`, the initial embedding varies smoothly along the index axis.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int,
                 scale: float = 10., smooth=False):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        # `.data` shares storage with the parameter, so in-place edits on
        # `table` modify the embedding weights directly.
        table = self.embedding.weight.data
        if smooth:
            running = torch.cumsum(table, dim=0)
            # when summing gaussians, the scale grows as sqrt(n), so we normalize by that.
            counts = torch.arange(1, num_embeddings + 1).to(running)
            table[:] = running / counts.sqrt()[:, None]
        table /= scale
        self.scale = scale

    # NOTE(review): this is a plain method here, so callers must write
    # `module.weight()`; confirm a `@property` decorator was not intended.
    def weight(self):
        """Return the effective (rescaled) embedding matrix."""
        return self.embedding.weight * self.scale

    def forward(self, x):
        """Look up the indices `x` and rescale the result by `self.scale`."""
        return self.embedding(x) * self.scale
class HEncLayer(nn.Module):
    """Encoder layer, used by both the time and the frequency branch.

    The layer applies a strided convolution and then, unless `empty`, a
    normalization + GELU, an optional DConv residual branch and an optional
    "rewrite" convolution followed by a GLU.
    """

    def __init__(self, chin, chout, kernel_size=8, stride=4, norm_groups=1, empty=False,
                 freq=True, dconv=True, norm=True, context=0, dconv_kw=None, pad=True,
                 rewrite=True):
        """
        Args:
            chin: number of input channels.
            chout: number of output channels.
            kernel_size: kernel size of the strided convolution.
            stride: stride of the strided convolution.
            norm_groups: number of groups for group norm.
            empty: used to make a layer with just the first conv. this is used
                before merging the time and freq. branches.
            freq: this is acting on frequencies (2d convolutions with a kernel
                of size 1 along the last axis) rather than on time (1d).
            dconv: insert DConv residual branches.
            norm: use GroupNorm.
            context: context size for the 1x1 conv.
            dconv_kw: dict of kwargs for the DConv class (`None` means no kwargs).
            pad: pad the input. Padding is done so that the output size is
                always the input size / stride.
            rewrite: add 1x1 conv at the end of the layer.
        """
        super().__init__()
        # Avoid a mutable default argument shared between all instances.
        dconv_kw = {} if dconv_kw is None else dconv_kw

        def norm_fn(d):
            return nn.GroupNorm(norm_groups, d) if norm else nn.Identity()

        pad = kernel_size // 4 if pad else 0
        klass = nn.Conv1d
        self.freq = freq
        self.kernel_size = kernel_size
        self.stride = stride
        self.empty = empty
        self.norm = norm
        self.pad = pad
        if freq:
            # Convolve over the frequency axis only: kernel of size 1 on the last axis.
            kernel_size = [kernel_size, 1]
            stride = [stride, 1]
            pad = [pad, 0]
            klass = nn.Conv2d
        self.conv = klass(chin, chout, kernel_size, stride, pad)
        if self.empty:
            # Only the first convolution exists; nothing else is created.
            return
        self.norm1 = norm_fn(chout)
        self.rewrite = None
        if rewrite:
            # Outputs twice the channels because the GLU in `forward` halves them.
            self.rewrite = klass(chout, 2 * chout, 1 + 2 * context, 1, context)
            self.norm2 = norm_fn(2 * chout)

        self.dconv = None
        if dconv:
            self.dconv = DConv(chout, **dconv_kw)

    def forward(self, x, inject=None):
        """
        `inject` is used to inject the result from the time branch into the frequency branch,
        when both have the same stride.

        Returns the encoded tensor; for an `empty` layer this is the raw
        output of the first convolution.
        """
        if not self.freq:
            if x.dim() == 4:
                # Time layer fed with a 4-dim input: fold dim 2 into the channels.
                x = x.view(x.shape[0], -1, x.shape[-1])
            remainder = x.shape[-1] % self.stride
            if remainder:
                # Right-pad so that the length is a multiple of the stride.
                x = F.pad(x, (0, self.stride - remainder))
        y = self.conv(x)
        if self.empty:
            return y
        if inject is not None:
            assert inject.shape[-1] == y.shape[-1], (inject.shape, y.shape)
            if inject.dim() == 3 and y.dim() == 4:
                # Broadcast the 3-dim injected tensor over dim 2 of `y`.
                inject = inject[:, :, None]
            y = y + inject
        y = F.gelu(self.norm1(y))
        if self.dconv is not None:
            if self.freq:
                # DConv receives a 3-dim input: fold dim 2 into the batch.
                B, C, Fr, T = y.shape
                y = y.permute(0, 2, 1, 3).reshape(-1, C, T)
            y = self.dconv(y)
            if self.freq:
                y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)
        if self.rewrite is not None:
            z = self.norm2(self.rewrite(y))
            z = F.glu(z, dim=1)
        else:
            z = y
        return z
class MultiWrap(nn.Module):
    """
    Takes one layer and replicates it N times. Each replica will act
    on a frequency band. All is done so that if the N replicas have the same weights,
    then this is exactly equivalent to applying the original module on all frequencies.

    This is a bit over-engineered to avoid edge artifacts when splitting
    the frequency bands, but it is possible the naive implementation would work as well...
    """
    def __init__(self, layer, split_ratios):
        """
        Args:
            layer: module to clone, must be either HEncLayer or HDecLayer.
            split_ratios: list of float indicating which ratio to keep for each band.
        """
        super().__init__()
        self.split_ratios = split_ratios
        self.layers = nn.ModuleList()
        # True when wrapping an encoder layer, False for a decoder layer.
        self.conv = isinstance(layer, HEncLayer)
        # Only un-normalized, padded, frequency layers are supported.
        assert not layer.norm
        assert layer.freq
        assert layer.pad
        if not self.conv:
            assert not layer.context_freq

        # One replica per band: len(split_ratios) boundaries give one more band.
        for k in range(len(split_ratios) + 1):
            lay = deepcopy(layer)
            if self.conv:
                # Padding is applied by hand in `forward`, only at the outer edges.
                lay.conv.padding = (0, 0)
            else:
                # Disable the replica's own trimming; `forward` trims the outer edges.
                lay.pad = False
            # Re-initialize each replica so that the copies do not start identical.
            for m in lay.modules():
                if hasattr(m, 'reset_parameters'):
                    m.reset_parameters()
            self.layers.append(lay)

    def forward(self, x, skip=None, length=None):
        # `x` is unpacked as 4-dim (B, C, Fr, T); bands are slices over dim 2.
        # `skip` is only read in the decoder branch; `length` is accepted but
        # not used by this method.
        B, C, Fr, T = x.shape

        # The last band always extends to the end of the frequency axis.
        ratios = list(self.split_ratios) + [1]
        start = 0
        outs = []
        for ratio, layer in zip(ratios, self.layers):
            if self.conv:
                # Encoder: cut overlapping input bands so that each band yields
                # a whole number of strided-convolution frames.
                pad = layer.kernel_size // 4
                if ratio == 1:
                    limit = Fr
                    frames = -1  # not read afterwards for the last band
                else:
                    limit = int(round(Fr * ratio))
                    le = limit - start
                    if start == 0:
                        # The first band also sees the left padding.
                        le += pad
                    frames = round((le - layer.kernel_size) / layer.stride + 1)
                    # Move the limit so the band covers exactly `frames` windows.
                    limit = start + (frames - 1) * layer.stride + layer.kernel_size
                    if start == 0:
                        limit -= pad
                assert limit - start > 0, (limit, start)
                assert limit <= Fr, (limit, Fr)
                y = x[:, :, start:limit, :]
                # Zero-pad only the outer edges of the full frequency range.
                if start == 0:
                    y = F.pad(y, (0, 0, pad, 0))
                if ratio == 1:
                    y = F.pad(y, (0, 0, 0, pad))
                outs.append(layer(y))
                # The next band starts where the next convolution window would start.
                start = limit - layer.kernel_size + layer.stride
            else:
                # Decoder: bands do not overlap on the input side.
                if ratio == 1:
                    limit = Fr
                else:
                    limit = int(round(Fr * ratio))
                # Temporarily mark the layer as `last` so it skips its own GELU;
                # the activation is applied once after concatenation below.
                last = layer.last
                layer.last = True

                y = x[:, :, start:limit]
                s = skip[:, :, start:limit]
                out, _ = layer(y, s, None)
                if outs:
                    # Overlap-add the first `stride` rows of this band onto the
                    # tail of the previous one; the bias is subtracted so that
                    # it is not counted twice.
                    outs[-1][:, :, -layer.stride:] += (
                        out[:, :, :layer.stride] - layer.conv_tr.bias.view(1, -1, 1, 1))
                    out = out[:, :, layer.stride:]
                if ratio == 1:
                    # NOTE(review): `-layer.stride // 2` parses as `(-layer.stride) // 2`,
                    # which equals `-(layer.stride // 2)` only for even strides - confirm
                    # odd strides are never used here.
                    out = out[:, :, :-layer.stride // 2, :]
                if start == 0:
                    out = out[:, :, layer.stride // 2:, :]
                outs.append(out)
                layer.last = last
                start = limit
        out = torch.cat(outs, dim=2)
        # `last` is only defined in the decoder branch; the short-circuit on
        # `self.conv` keeps the encoder path from reading it.
        if not self.conv and not last:
            out = F.gelu(out)
        if self.conv:
            return out
        else:
            # Decoder layers return a pair; no second output is provided here.
            return out, None
class HDecLayer(nn.Module):
    """Decoder layer, used by both the time and the frequency branch.

    Unless `empty`, the layer adds the skip connection to its input, applies an
    optional "rewrite" convolution + GLU and an optional DConv residual branch.
    It always ends with a transposed convolution, a normalization, a trimming
    of the padding and (unless `last`) a GELU.
    """

    def __init__(self, chin, chout, last=False, kernel_size=8, stride=4, norm_groups=1,
                 empty=False, freq=True, dconv=True, norm=True, context=1, dconv_kw=None,
                 pad=True, context_freq=True, rewrite=True):
        """
        Same as HEncLayer but for decoder. See `HEncLayer` for documentation.

        Args specific to the decoder:
            last: if True, no GELU is applied to the output.
            context_freq: if True the rewrite conv uses a kernel of size
                `1 + 2 * context` on every convolved axis; otherwise the kernel
                is `[1, 1 + 2 * context]`, i.e. context over the last axis only.
            dconv_kw: dict of kwargs for the DConv class (`None` means no kwargs).
        """
        super().__init__()
        # Avoid a mutable default argument shared between all instances.
        dconv_kw = {} if dconv_kw is None else dconv_kw

        def norm_fn(d):
            return nn.GroupNorm(norm_groups, d) if norm else nn.Identity()

        pad = kernel_size // 4 if pad else 0
        self.pad = pad
        self.last = last
        self.freq = freq
        self.chin = chin
        self.empty = empty
        self.stride = stride
        self.kernel_size = kernel_size
        self.norm = norm
        self.context_freq = context_freq
        klass = nn.Conv1d
        klass_tr = nn.ConvTranspose1d
        if freq:
            # Convolve over the frequency axis only: kernel of size 1 on the last axis.
            kernel_size = [kernel_size, 1]
            stride = [stride, 1]
            klass = nn.Conv2d
            klass_tr = nn.ConvTranspose2d
        self.conv_tr = klass_tr(chin, chout, kernel_size, stride)
        self.norm2 = norm_fn(chout)
        if self.empty:
            # Only the transposed convolution (and its norm) exists.
            return
        self.rewrite = None
        if rewrite:
            # Outputs twice the channels because the GLU in `forward` halves them.
            if context_freq:
                self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context)
            else:
                self.rewrite = klass(chin, 2 * chin, [1, 1 + 2 * context], 1,
                                     [0, context])
            self.norm1 = norm_fn(2 * chin)

        self.dconv = None
        if dconv:
            self.dconv = DConv(chin, **dconv_kw)

    def forward(self, x, skip, length):
        """Decode `x`.

        Args:
            x: input tensor. A 3-dim input given to a frequency layer is
                reshaped to 4 dims with `self.chin` channels.
            skip: skip connection added to `x`; must be None for an `empty` layer.
            length: expected output length on the last axis, used only by time
                layers to trim the output.

        Returns:
            A pair `(z, y)`: `z` is the layer output and `y` is the tensor that
            was fed to the transposed convolution.
        """
        if self.freq and x.dim() == 3:
            B, C, T = x.shape
            x = x.view(B, self.chin, -1, T)

        if not self.empty:
            x = x + skip

            if self.rewrite is not None:
                y = F.glu(self.norm1(self.rewrite(x)), dim=1)
            else:
                y = x
            if self.dconv is not None:
                if self.freq:
                    # DConv receives a 3-dim input: fold dim 2 into the batch.
                    B, C, Fr, T = y.shape
                    y = y.permute(0, 2, 1, 3).reshape(-1, C, T)
                y = self.dconv(y)
                if self.freq:
                    y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)
        else:
            y = x
            assert skip is None
        z = self.norm2(self.conv_tr(y))
        if self.freq:
            # `self.pad` may be 0 or set to False externally: then nothing is trimmed.
            if self.pad:
                z = z[..., self.pad:-self.pad, :]
        else:
            z = z[..., self.pad:self.pad + length]
            assert z.shape[-1] == length, (z.shape[-1], length)
        if not self.last:
            z = F.gelu(z)
        return z, y
class HDemucs(nn.Module):
    """
    Spectrogram and hybrid Demucs model.
    The spectrogram model has the same structure as Demucs, except the first few layers are over the
    frequency axis, until there is only 1 frequency, and then it moves to time convolutions.
    Frequency layers can still access information across time steps thanks to the DConv residual.

    Hybrid models have a parallel time branch. At some layer, the time branch has the same stride
    as the frequency branch and then the two are combined. The opposite happens in the decoder.

    Models can either use naive iSTFT from masking, Wiener filtering ([Uhlich et al. 2017]),
    or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on
    Open Unmix implementation [Stoter et al. 2019].

    The loss is always on the temporal domain, by backpropagating through the above
    output methods and iSTFT. This allows to define hybrid models nicely. However, this breaks
    a bit Wiener filtering, as doing more iteration at test time will change the spectrogram
    contribution, without changing the one from the waveform, which will lead to worse performance.
    I tried using the residual option in OpenUnmix Wiener implementation, but it didn't improve.
    CaC on the other hand provides similar performance for hybrid, and works naturally with
    hybrid models.

    This model also uses frequency embeddings to improve efficiency on convolutions
    over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf).

    Unlike classic Demucs, there is no resampling here, and normalization is always applied.
    """
    # NOTE(review): `capture_init` is imported at the top of this file but is never
    # applied anywhere in it; confirm whether a `@capture_init` decorator on
    # `__init__` was lost.
    def __init__(self,
                 sources,
                 # Channels
                 audio_channels=2,
                 channels=48,
                 channels_time=None,
                 growth=2,
                 # STFT
                 nfft=4096,
                 wiener_iters=0,
                 end_iters=0,
                 wiener_residual=False,
                 cac=True,
                 # Main structure
                 depth=6,
                 rewrite=True,
                 hybrid=True,
                 hybrid_old=False,
                 # Frequency branch
                 multi_freqs=None,
                 multi_freqs_depth=2,
                 freq_emb=0.2,
                 emb_scale=10,
                 emb_smooth=True,
                 # Convolutions
                 kernel_size=8,
                 time_stride=2,
                 stride=4,
                 context=1,
                 context_enc=0,
                 # Normalization
                 norm_starts=4,
                 norm_groups=4,
                 # DConv residual branch
                 dconv_mode=1,
                 dconv_depth=2,
                 dconv_comp=4,
                 dconv_attn=4,
                 dconv_lstm=4,
                 dconv_init=1e-4,
                 # Weight init
                 rescale=0.1,
                 # Metadata
                 samplerate=44100,
                 segment=4 * 10):
        """
        Args:
            sources (list[str]): list of source names.
            audio_channels (int): input/output audio channels.
            channels (int): initial number of hidden channels.
            channels_time: if not None, use a different `channels` value for the time branch.
            growth: increase the number of hidden channels by this factor at each layer.
            nfft: number of fft bins. Note that changing this requires careful computation of
                various shape parameters and will not work out of the box for hybrid models.
            wiener_iters: when using Wiener filtering, number of iterations at test time.
            end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`.
            wiener_residual: add residual source before wiener filtering.
            cac: uses complex as channels, i.e. complex numbers are 2 channels each
                in input and output. no further processing is done before ISTFT.
            depth (int): number of layers in the encoder and in the decoder.
            rewrite (bool): add 1x1 convolution to each layer.
            hybrid (bool): make a hybrid time/frequency domain, otherwise frequency only.
            hybrid_old: some models trained for MDX had a padding bug. This replicates
                this bug to avoid retraining them.
            multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`.
            multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost
                layers will be wrapped.
            freq_emb: add frequency embedding after the first frequency layer if > 0,
                the actual value controls the weight of the embedding.
            emb_scale: equivalent to scaling the embedding learning rate
            emb_smooth: initialize the embedding with a smooth one (with respect to frequencies).
            kernel_size: kernel_size for encoder and decoder layers.
            stride: stride for encoder and decoder layers.
            time_stride: stride for the final time layer, after the merge.
            context: context for 1x1 conv in the decoder.
            context_enc: context for 1x1 conv in the encoder.
            norm_starts: layer at which group norm starts being used.
                decoder layers are numbered in reverse order.
            norm_groups: number of groups for group norm.
            dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
            dconv_depth: depth of residual DConv branch.
            dconv_comp: compression of DConv branch.
            dconv_attn: adds attention layers in DConv branch starting at this layer.
            dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
            dconv_init: initial scale for the DConv branch LayerScale.
            rescale: weight rescaling trick
            samplerate: stored on the module as metadata.
            segment: stored on the module as metadata.
        """
        super().__init__()
        self.cac = cac
        self.wiener_residual = wiener_residual
        self.audio_channels = audio_channels
        self.sources = sources
        self.kernel_size = kernel_size
        self.context = context
        self.stride = stride
        self.depth = depth
        self.channels = channels
        self.samplerate = samplerate
        self.segment = segment

        self.nfft = nfft
        self.hop_length = nfft // 4
        self.wiener_iters = wiener_iters
        self.end_iters = end_iters
        self.freq_emb = None
        self.hybrid = hybrid
        self.hybrid_old = hybrid_old
        if hybrid_old:
            assert hybrid, "hybrid_old must come with hybrid=True"
        if hybrid:
            assert wiener_iters == end_iters

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        if hybrid:
            self.tencoder = nn.ModuleList()
            self.tdecoder = nn.ModuleList()

        chin = audio_channels
        chin_z = chin  # number of channels for the freq branch
        if self.cac:
            chin_z *= 2
        chout = channels_time or channels
        chout_z = channels
        # Number of frequency bins seen by the first layer.
        freqs = nfft // 2

        for index in range(depth):
            lstm = index >= dconv_lstm
            attn = index >= dconv_attn
            norm = index >= norm_starts
            # Once a single frequency is left, layers become time layers.
            freq = freqs > 1
            stri = stride
            ker = kernel_size
            if not freq:
                assert freqs == 1
                ker = time_stride * 2
                stri = time_stride

            pad = True
            last_freq = False
            if freq and freqs <= kernel_size:
                # Last frequency layer: one un-padded kernel covers all remaining bins.
                ker = freqs
                pad = False
                last_freq = True

            kw = {
                'kernel_size': ker,
                'stride': stri,
                'freq': freq,
                'pad': pad,
                'norm': norm,
                'rewrite': rewrite,
                'norm_groups': norm_groups,
                'dconv_kw': {
                    'lstm': lstm,
                    'attn': attn,
                    'depth': dconv_depth,
                    'compress': dconv_comp,
                    'init': dconv_init,
                    'gelu': True,
                }
            }
            # Time-branch kwargs: same as `kw` but always a padded time layer
            # with the default kernel size and stride.
            kwt = dict(kw)
            kwt['freq'] = 0
            kwt['kernel_size'] = kernel_size
            kwt['stride'] = stride
            kwt['pad'] = True
            kw_dec = dict(kw)
            multi = False
            if multi_freqs and index < multi_freqs_depth:
                multi = True
                # `MultiWrap` asserts that decoder layers have no `context_freq`.
                kw_dec['context_freq'] = False

            if last_freq:
                # Both branches must have the same width where they merge.
                chout_z = max(chout, chout_z)
                chout = chout_z

            # `dconv_mode` is a bit mask: bit 1 for the encoder, bit 2 for the decoder.
            enc = HEncLayer(chin_z, chout_z,
                            dconv=dconv_mode & 1, context=context_enc, **kw)
            if hybrid and freq:
                tenc = HEncLayer(chin, chout, dconv=dconv_mode & 1, context=context_enc,
                                 empty=last_freq, **kwt)
                self.tencoder.append(tenc)

            if multi:
                enc = MultiWrap(enc, multi_freqs)
            self.encoder.append(enc)
            if index == 0:
                # The outermost decoder layer outputs one signal per source.
                chin = self.audio_channels * len(self.sources)
                chin_z = chin
                if self.cac:
                    chin_z *= 2
            dec = HDecLayer(chout_z, chin_z, dconv=dconv_mode & 2,
                            last=index == 0, context=context, **kw_dec)
            if multi:
                dec = MultiWrap(dec, multi_freqs)
            if hybrid and freq:
                tdec = HDecLayer(chout, chin, dconv=dconv_mode & 2, empty=last_freq,
                                 last=index == 0, context=context, **kwt)
                # Decoders are inserted at the front: they run in reverse order.
                self.tdecoder.insert(0, tdec)
            self.decoder.insert(0, dec)

            chin = chout
            chin_z = chout_z
            chout = int(growth * chout)
            chout_z = int(growth * chout_z)
            if freq:
                if freqs <= kernel_size:
                    freqs = 1
                else:
                    freqs //= stride
            if index == 0 and freq_emb:
                # `freqs` and `chin_z` now describe the output of the first layer.
                self.freq_emb = ScaledEmbedding(
                    freqs, chin_z, smooth=emb_smooth, scale=emb_scale)
                self.freq_emb_scale = freq_emb

        if rescale:
            rescale_module(self, reference=rescale)

    def _spec(self, x):
        """Compute the spectrogram of `x` with `spectro`, dropping the last frequency bin.

        For hybrid models the signal is re-padded so that the number of frames
        is exactly `ceil(length / hop_length)`.
        """
        hl = self.hop_length
        nfft = self.nfft
        x0 = x  # noqa

        if self.hybrid:
            # We re-pad the signal in order to keep the property
            # that the size of the output is exactly the size of the input
            # divided by the stride (here hop_length), when divisible.
            # This is achieved by padding by 1/4th of the kernel size (here nfft).
            # which is not supported by torch.stft.
            # Having all convolution operations follow this convention allow to easily
            # align the time and frequency branches later on.
            assert hl == nfft // 4
            le = int(math.ceil(x.shape[-1] / hl))
            pad = hl // 2 * 3
            # The right padding also completes the length to a multiple of `hl`.
            if not self.hybrid_old:
                x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode='reflect')
            else:
                # Legacy behavior (see `hybrid_old`): constant padding.
                x = pad1d(x, (pad, pad + le * hl - x.shape[-1]))

        z = spectro(x, nfft, hl)[..., :-1, :]
        if self.hybrid:
            # Drop the 2 extra frames on each side introduced by the padding.
            assert z.shape[-1] == le + 4, (z.shape, x.shape, le)
            z = z[..., 2:2+le]
        return z

    def _ispec(self, z, length=None, scale=0):
        """Invert `_spec` with `ispectro`.

        The hop length is `self.hop_length // 4 ** scale`. In hybrid mode
        `length` is used in arithmetic and therefore must not be None.
        """
        hl = self.hop_length // (4 ** scale)
        # Add back one (zero) frequency bin at the end of the frequency axis.
        z = F.pad(z, (0, 0, 0, 1))
        if self.hybrid:
            # Add back 2 (zero) frames on each side of the time axis.
            z = F.pad(z, (2, 2))
            pad = hl // 2 * 3
            if not self.hybrid_old:
                le = hl * int(math.ceil(length / hl)) + 2 * pad
            else:
                le = hl * int(math.ceil(length / hl))
            x = ispectro(z, hl, length=le)
            if not self.hybrid_old:
                x = x[..., pad:pad + length]
            else:
                x = x[..., :length]
        else:
            x = ispectro(z, hl, length)
        return x

    def _magnitude(self, z):
        # return the magnitude of the spectrogram, except when cac is True,
        # in which case we just move the complex dimension to the channel one.
        if self.cac:
            B, C, Fr, T = z.shape
            # (B, C, Fr, T, 2) -> (B, C, 2, Fr, T) -> (B, 2 * C, Fr, T)
            m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
            m = m.reshape(B, C * 2, Fr, T)
        else:
            m = z.abs()
        return m

    def _mask(self, z, m):
        # Apply masking given the mixture spectrogram `z` and the estimated mask `m`.
        # If `cac` is True, `m` is actually a full spectrogram and `z` is ignored.
        niters = self.wiener_iters
        if self.cac:
            B, S, C, Fr, T = m.shape
            # Split the channels into (channel, real/imag) and rebuild complex values.
            out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
            out = torch.view_as_complex(out.contiguous())
            return out
        if self.training:
            niters = self.end_iters
        if niters < 0:
            # Naive masking: scale the unit-modulus mixture spectrogram by `m`.
            z = z[:, None]
            return z / (1e-8 + z.abs()) * m
        else:
            return self._wiener(m, z, niters)

    def _wiener(self, mag_out, mix_stft, niters):
        # apply wiener filtering from OpenUnmix.
        init = mix_stft.dtype
        # Filtering is applied on consecutive windows of this many frames.
        wiener_win_len = 300
        residual = self.wiener_residual

        B, S, C, Fq, T = mag_out.shape
        # (B, S, C, Fq, T) -> (B, T, Fq, C, S)
        mag_out = mag_out.permute(0, 4, 3, 2, 1)
        # (B, C, Fq, T) -> (B, T, Fq, C), with a trailing real/imag axis.
        mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1))

        outs = []
        for sample in range(B):
            pos = 0  # immediately overwritten by the loop variable below
            out = []
            for pos in range(0, T, wiener_win_len):
                frame = slice(pos, pos + wiener_win_len)
                z_out = wiener(
                    mag_out[sample, frame], mix_stft[sample, frame], niters,
                    residual=residual)
                out.append(z_out.transpose(-1, -2))
            outs.append(torch.cat(out, dim=0))
        out = torch.view_as_complex(torch.stack(outs, 0))
        out = out.permute(0, 4, 3, 2, 1).contiguous()
        if residual:
            # Drop the last entry along dim 1 when `residual` is set.
            out = out[:, :-1]
        assert list(out.shape) == [B, S, C, Fq, T]
        return out.to(init)

    def forward(self, mix):
        """Separate the mixture `mix` into `len(self.sources)` signals.

        `mix` is the mixture waveform; in hybrid mode it is reduced over dims
        1 and 2, so it is presumably (batch, channels, time) - TODO confirm.
        The result has the same last-axis length as `mix`; in hybrid mode it is
        the sum of the time-branch and frequency-branch outputs.
        """
        x = mix
        length = x.shape[-1]

        z = self._spec(mix)
        mag = self._magnitude(z)
        x = mag

        B, C, Fq, T = x.shape

        # unlike previous Demucs, we always normalize because it is easier.
        mean = x.mean(dim=(1, 2, 3), keepdim=True)
        std = x.std(dim=(1, 2, 3), keepdim=True)
        x = (x - mean) / (1e-5 + std)
        # x will be the freq. branch input.

        if self.hybrid:
            # Prepare the time branch input.
            xt = mix
            meant = xt.mean(dim=(1, 2), keepdim=True)
            stdt = xt.std(dim=(1, 2), keepdim=True)
            xt = (xt - meant) / (1e-5 + stdt)

        # okay, this is a giant mess I know...
        saved = []  # skip connections, freq.
        saved_t = []  # skip connections, time.
        lengths = []  # saved lengths to properly remove padding, freq branch.
        lengths_t = []  # saved lengths for time branch.
        for idx, encode in enumerate(self.encoder):
            lengths.append(x.shape[-1])
            inject = None
            if self.hybrid and idx < len(self.tencoder):
                # we have not yet merged branches.
                lengths_t.append(xt.shape[-1])
                tenc = self.tencoder[idx]
                xt = tenc(xt)
                if not tenc.empty:
                    # save for skip connection
                    saved_t.append(xt)
                else:
                    # tenc contains just the first conv., so that now time and freq.
                    # branches have the same shape and can be merged.
                    inject = xt
            x = encode(x, inject)
            if idx == 0 and self.freq_emb is not None:
                # add frequency embedding to allow for non equivariant convolutions
                # over the frequency axis.
                frs = torch.arange(x.shape[-2], device=x.device)
                emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
                x = x + self.freq_emb_scale * emb

            saved.append(x)

        x = torch.zeros_like(x)
        if self.hybrid:
            xt = torch.zeros_like(x)
        # initialize everything to zero (signal will go through u-net skips).

        for idx, decode in enumerate(self.decoder):
            skip = saved.pop(-1)
            x, pre = decode(x, skip, lengths.pop(-1))
            # `pre` contains the output just before final transposed convolution,
            # which is used when the freq. and time branch separate.

            if self.hybrid:
                # Index of the first decoder layer that has a time counterpart.
                offset = self.depth - len(self.tdecoder)
            if self.hybrid and idx >= offset:
                tdec = self.tdecoder[idx - offset]
                length_t = lengths_t.pop(-1)
                if tdec.empty:
                    # Branches separate here: the time branch starts from `pre`,
                    # which must have a single frequency left.
                    assert pre.shape[2] == 1, pre.shape
                    pre = pre[:, :, 0]
                    xt, _ = tdec(pre, None, length_t)
                else:
                    skip = saved_t.pop(-1)
                    xt, _ = tdec(xt, skip, length_t)

        # Let's make sure we used all stored skip connections.
        assert len(saved) == 0
        assert len(lengths_t) == 0
        assert len(saved_t) == 0

        S = len(self.sources)
        x = x.view(B, S, -1, Fq, T)
        # Undo the input normalization on the frequency branch.
        x = x * std[:, None] + mean[:, None]

        zout = self._mask(z, x)
        x = self._ispec(zout, length)

        if self.hybrid:
            xt = xt.view(B, S, -1, length)
            # Undo the input normalization on the time branch, then sum both branches.
            xt = xt * stdt[:, None] + meant[:, None]
            x = xt + x
        return x