# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# author: adefossez

import math
import time

import torch
from torch import nn
from torch.nn import functional as F

from .resample import downsample2, upsample2
from .utils import capture_init

# class BLSTM(nn.Module):
#     def __init__(self, dim, layers=2, bi=True):
#         super().__init__()
#         klass = nn.LSTM
#         self.lstm = klass(bidirectional=bi, num_layers=layers, hidden_size=dim, input_size=dim)
#         self.linear = None
#         if bi:
#             self.linear = nn.Linear(2 * dim, dim)
#
#     def forward(self, x, hidden=None):
#         x, hidden = self.lstm(x, hidden)
#         if self.linear:
#             x = self.linear(x)
#         return x, hidden


EPS = 1e-8


class Chomp1d(nn.Module):
    """To ensure the output length is the same as the input."""

    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        """
        Args:
            x: [M, H, Kpad]
        Returns:
            [M, H, K]
        """
        return x[:, :, :-self.chomp_size].contiguous()
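

# Illustrative sketch (added for clarity, not part of the original file):
# with "causal" padding of (kernel_size - 1) * dilation applied on both
# sides, Chomp1d trims the right-hand overhang so the output length matches
# the input length.
def _demo_chomp():
    conv = nn.Conv1d(4, 4, kernel_size=3, padding=2, dilation=1)
    x = torch.randn(1, 4, 100)
    y = Chomp1d(2)(conv(x))  # conv output has length 102; chomp drops the last 2
    assert y.shape == x.shape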


def chose_norm(norm_type, channel_size):
    """The input of normalization will be (M, C, K), where M is batch size,
    C is channel size and K is sequence length.
    """
    if norm_type == "gLN":
        return GlobalLayerNorm(channel_size)
    elif norm_type == "cLN":
        return ChannelwiseLayerNorm(channel_size)
    else:  # norm_type == "BN"
        # Given input (M, C, K), nn.BatchNorm1d(C) accumulates statistics
        # along M and K, so this BN usage is correct.
        return nn.BatchNorm1d(channel_size)
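

# Quick sanity sketch (illustrative): all three norm choices preserve the
# (M, C, K) shape, so they can be swapped freely inside the blocks below.
def _demo_chose_norm():
    x = torch.randn(2, 16, 50)
    for norm_type in ("gLN", "cLN", "BN"):
        assert chose_norm(norm_type, 16)(x).shape == x.shape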


class ChannelwiseLayerNorm(nn.Module):
    """Channel-wise Layer Normalization (cLN)"""

    def __init__(self, channel_size):
        super(ChannelwiseLayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
        self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
        self.reset_parameters()

    def reset_parameters(self):
        self.gamma.data.fill_(1)
        self.beta.data.zero_()

    def forward(self, y):
        """
        Args:
            y: [M, N, K], M is batch size, N is channel size, K is length
        Returns:
            cLN_y: [M, N, K]
        """
        mean = torch.mean(y, dim=1, keepdim=True)  # [M, 1, K]
        var = torch.var(y, dim=1, keepdim=True, unbiased=False)  # [M, 1, K]
        cLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
        return cLN_y
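

# Sketch (illustrative): cLN standardizes over the channel dimension
# independently at every time step, unlike gLN below, which pools over
# channels and time jointly.
def _demo_cln():
    x = torch.randn(2, 16, 50)
    y = ChannelwiseLayerNorm(16)(x)
    # With gamma=1 and beta=0 the per-step channel mean is ~0.
    assert y.mean(dim=1).abs().max() < 1e-5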


class DepthwiseSeparableConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride, padding, dilation, norm_type="gLN", causal=False):
        super(DepthwiseSeparableConv, self).__init__()
        # Use the `groups` option to implement a depthwise convolution.
        # [M, H, K] -> [M, H, K]
        depthwise_conv = nn.Conv1d(in_channels, in_channels, kernel_size,
                                   stride=stride, padding=padding,
                                   dilation=dilation, groups=in_channels,
                                   bias=False)
        if causal:
            chomp = Chomp1d(padding)
        prelu = nn.PReLU()
        norm = chose_norm(norm_type, in_channels)
        # [M, H, K] -> [M, B, K]
        pointwise_conv = nn.Conv1d(in_channels, out_channels, 1, bias=False)
        # Put together
        if causal:
            self.net = nn.Sequential(depthwise_conv, chomp, prelu, norm,
                                     pointwise_conv)
        else:
            self.net = nn.Sequential(depthwise_conv, prelu, norm,
                                     pointwise_conv)

    def forward(self, x):
        """
        Args:
            x: [M, H, K]
        Returns:
            result: [M, B, K]
        """
        return self.net(x)
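

# Sketch (illustrative): the depthwise + pointwise pair uses far fewer
# parameters than a full Conv1d over the same channels and kernel size
# (512*3 + 512*256 plus norm params, versus 512*256*3).
def _demo_dsconv_params():
    full = nn.Conv1d(512, 256, 3, padding=1, bias=False)
    sep = DepthwiseSeparableConv(512, 256, 3, stride=1, padding=1,
                                 dilation=1, norm_type="gLN", causal=False)
    n_full = sum(p.numel() for p in full.parameters())
    n_sep = sum(p.numel() for p in sep.parameters())
    assert n_sep < n_full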


class TemporalBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size,
                 stride, padding, dilation, norm_type="gLN", causal=False):
        super(TemporalBlock, self).__init__()
        # [M, B, K] -> [M, H, K]
        conv1x1 = nn.Conv1d(in_channels, out_channels, 1, bias=False)
        prelu = nn.PReLU()
        norm = chose_norm(norm_type, out_channels)
        # [M, H, K] -> [M, B, K]
        dsconv = DepthwiseSeparableConv(out_channels, in_channels, kernel_size,
                                        stride, padding, dilation, norm_type,
                                        causal)
        # Put together
        self.net = nn.Sequential(conv1x1, prelu, norm, dsconv)

    def forward(self, x):
        """
        Args:
            x: [M, B, K]
        Returns:
            [M, B, K]
        """
        residual = x
        out = self.net(x)
        # TODO: P = 3 works fine here, but P = 2 may need extra padding.
        return out + residual  # omitting F.relu seems to work better than keeping it
        # return F.relu(out + residual)
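

# Sketch (illustrative): with padding = (P - 1) * dilation and causal=True,
# each TemporalBlock preserves the sequence length, so blocks can be stacked
# with exponentially growing dilation.
def _demo_temporal_block():
    block = TemporalBlock(256, 512, 3, stride=1, padding=(3 - 1) * 4,
                          dilation=4, norm_type="cLN", causal=True)
    x = torch.randn(1, 256, 80)
    assert block(x).shape == x.shape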


class GlobalLayerNorm(nn.Module):
    """Global Layer Normalization (gLN)"""

    def __init__(self, channel_size):
        super(GlobalLayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
        self.beta = nn.Parameter(torch.Tensor(1, channel_size, 1))  # [1, N, 1]
        self.reset_parameters()

    def reset_parameters(self):
        self.gamma.data.fill_(1)
        self.beta.data.zero_()

    def forward(self, y):
        """
        Args:
            y: [M, N, K], M is batch size, N is channel size, K is length
        Returns:
            gLN_y: [M, N, K]
        """
        # TODO: since torch 1.0, torch.mean() supports a list of dims
        mean = y.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)  # [M, 1, 1]
        var = (torch.pow(y - mean, 2)).mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)
        gLN_y = self.gamma * (y - mean) / torch.pow(var + EPS, 0.5) + self.beta
        return gLN_y
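

# Sketch (illustrative): gLN pools mean/variance over channels *and* time,
# so each example's whole (N, K) slab is standardized at once.
def _demo_gln():
    x = torch.randn(2, 16, 50) * 3 + 1
    y = GlobalLayerNorm(16)(x)
    z = y.view(y.shape[0], -1)
    # With gamma=1 and beta=0, each example has ~zero mean and ~unit variance.
    assert z.mean(dim=1).abs().max() < 1e-5
    assert (z.var(dim=1, unbiased=False) - 1).abs().max() < 1e-3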


class TemporalConvNet(nn.Module):
    def __init__(self, N=768, B=256, H=512, P=3, X=8, R=4, C=1, norm_type="gLN", causal=1,
                 mask_nonlinear='relu'):
        """
        Args:
            N: Number of filters in autoencoder
            B: Number of channels in bottleneck 1x1-conv block
            H: Number of channels in convolutional blocks
            P: Kernel size in convolutional blocks
            X: Number of convolutional blocks in each repeat
            R: Number of repeats
            C: Number of speakers
            norm_type: BN, gLN, cLN
            causal: causal or non-causal
            mask_nonlinear: which non-linearity to use to generate the mask
        """
        super(TemporalConvNet, self).__init__()
        # Hyper-parameters
        self.C = C
        self.mask_nonlinear = mask_nonlinear
        # Components
        # [M, N, K] -> [M, N, K]
        layer_norm = ChannelwiseLayerNorm(N)
        # [M, N, K] -> [M, B, K]
        bottleneck_conv1x1 = nn.Conv1d(N, B, 1, bias=False)
        # [M, B, K] -> [M, B, K]
        repeats = []
        for r in range(R):
            blocks = []
            for x in range(X):
                dilation = 2**x
                padding = (P - 1) * dilation if causal else (P - 1) * dilation // 2
                blocks += [TemporalBlock(B, H, P, stride=1,
                                         padding=padding,
                                         dilation=dilation,
                                         norm_type=norm_type,
                                         causal=causal)]
            repeats += [nn.Sequential(*blocks)]
        temporal_conv_net = nn.Sequential(*repeats)
        # [M, B, K] -> [M, C*N, K]
        mask_conv1x1 = nn.Conv1d(B, C * N, 1, bias=False)
        # Put together
        self.network = nn.Sequential(layer_norm,
                                     bottleneck_conv1x1,
                                     temporal_conv_net,
                                     mask_conv1x1)

    def forward(self, mixture_w):
        """
        Keep this API the same as TasNet.
        Args:
            mixture_w: [M, N, K], M is batch size
        Returns:
            est_mask: [M, C, N, K] ([M, N, K] after the squeeze when C == 1)
        """
        M, N, K = mixture_w.size()
        score = self.network(mixture_w)  # [M, N, K] -> [M, C*N, K]
        score = score.view(M, self.C, N, K)  # [M, C*N, K] -> [M, C, N, K]
        if self.mask_nonlinear == 'softmax':
            est_mask = F.softmax(score, dim=1)
            est_mask = est_mask.squeeze(1)
        elif self.mask_nonlinear == 'relu':
            est_mask = F.relu(score)
            est_mask = est_mask.squeeze(1)
        else:
            raise ValueError("Unsupported mask non-linear function")
        return est_mask
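

# Sketch (illustrative): with C=1 the trailing squeeze makes the mask the
# same shape as the input features, which is what Demucs.forward relies on
# when it uses this network as its separator.
def _demo_separator_shape():
    tcn = TemporalConvNet(N=32, B=16, H=24, P=3, X=2, R=1, C=1)
    feats = torch.randn(2, 32, 40)
    assert tcn(feats).shape == feats.shape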


def rescale_conv(conv, reference):
    std = conv.weight.std().detach()
    scale = (std / reference)**0.5
    conv.weight.data /= scale
    if conv.bias is not None:
        conv.bias.data /= scale


def rescale_module(module, reference):
    for sub in module.modules():
        if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d)):
            rescale_conv(sub, reference)
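

# Sketch (illustrative): dividing the weights by (std / reference) ** 0.5
# moves each conv's weight std to the geometric mean of its current value
# and `reference` (see https://arxiv.org/abs/1911.13254).
def _demo_rescale():
    conv = nn.Conv1d(8, 8, 3)
    before = conv.weight.std().item()
    rescale_module(conv, reference=0.1)
    after = conv.weight.std().item()
    assert abs(after - math.sqrt(before * 0.1)) < 1e-6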


class Demucs(nn.Module):
    """
    Demucs speech enhancement model.
    Args:
        - chin (int): number of input channels.
        - chout (int): number of output channels.
        - hidden (int): number of initial hidden channels.
        - depth (int): number of layers.
        - kernel_size (int): kernel size for each layer.
        - stride (int): stride for each layer.
        - causal (bool): kept for compatibility; in this variant the LSTM
          bottleneck (bidirectional when non-causal) is replaced by a
          TemporalConvNet separator.
        - resample (int): amount of resampling to apply to the input/output.
          Can be one of 1, 2 or 4.
        - growth (float): number of channels is multiplied by this for every layer.
        - max_hidden (int): maximum number of channels. Can be useful to
          control the size/speed of the model.
        - normalize (bool): if true, normalize the input.
        - glu (bool): if true, uses GLU instead of ReLU in 1x1 convolutions.
        - rescale (float): controls custom weight initialization.
          See https://arxiv.org/abs/1911.13254.
        - floor (float): stability flooring when normalizing.
    """
    @capture_init
    def __init__(self,
                 chin=1,
                 chout=1,
                 hidden=48,
                 depth=5,
                 kernel_size=8,
                 stride=4,
                 causal=True,
                 resample=4,
                 growth=2,
                 max_hidden=10_000,
                 normalize=True,
                 glu=True,
                 rescale=0.1,
                 floor=1e-3):
        super().__init__()
        if resample not in [1, 2, 4]:
            raise ValueError("Resample should be 1, 2 or 4.")
        self.chin = chin
        self.chout = chout
        self.hidden = hidden
        self.depth = depth
        self.kernel_size = kernel_size
        self.stride = stride
        self.causal = causal
        self.floor = floor
        self.resample = resample
        self.normalize = normalize
        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        activation = nn.GLU(1) if glu else nn.ReLU()
        ch_scale = 2 if glu else 1
        for index in range(depth):
            encode = []
            encode += [
                nn.Conv1d(chin, hidden, kernel_size, stride),
                nn.ReLU(),
                nn.Conv1d(hidden, hidden * ch_scale, 1), activation,
            ]
            self.encoder.append(nn.Sequential(*encode))
            decode = []
            decode += [
                nn.Conv1d(hidden, ch_scale * hidden, 1), activation,
                nn.ConvTranspose1d(hidden, chout, kernel_size, stride),
            ]
            if index > 0:
                decode.append(nn.ReLU())
            self.decoder.insert(0, nn.Sequential(*decode))
            chout = hidden
            chin = hidden
            hidden = min(int(growth * hidden), max_hidden)
        # The separator replaces the original LSTM bottleneck.
        self.separator = TemporalConvNet(N=chout)
        # self.lstm = BLSTM(chin, bi=not causal)
        if rescale:
            rescale_module(self, reference=rescale)

    def valid_length(self, length):
        """
        Return the nearest valid length to use with the model so that
        no time steps are left over in the convolutions, i.e. for every
        layer, (size of the input - kernel_size) % stride == 0.
        If the mixture has a valid length, the estimated sources
        will have exactly the same length.
        """
        length = math.ceil(length * self.resample)
        for idx in range(self.depth):
            length = math.ceil((length - self.kernel_size) / self.stride) + 1
            length = max(length, 1)
        for idx in range(self.depth):
            length = (length - 1) * self.stride + self.kernel_size
        length = int(math.ceil(length / self.resample))
        return int(length)
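
    # Sketch (illustrative, not part of the original file): valid_length
    # maps any length to the nearest valid one at or above it, so it is
    # idempotent; forward() relies on this when padding its input.
    def _demo_valid_length(self, length=12345):
        valid = self.valid_length(length)
        assert valid >= length and self.valid_length(valid) == valid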

    @property
    def total_stride(self):
        # Accessed as an attribute by DemucsStreamer.
        return self.stride ** self.depth // self.resample

    def forward(self, mix):
        if mix.dim() == 2:
            mix = mix.unsqueeze(1)
        if self.normalize:
            mono = mix.mean(dim=1, keepdim=True)
            std = mono.std(dim=-1, keepdim=True)
            mix = mix / (self.floor + std)
        else:
            std = 1
        length = mix.shape[-1]
        x = mix
        x = F.pad(x, (0, self.valid_length(length) - length))
        if self.resample == 2:
            x = upsample2(x)
        elif self.resample == 4:
            x = upsample2(x)
            x = upsample2(x)
        skips = []
        for encode in self.encoder:
            x = encode(x)
            skips.append(x)
        x = self.separator(x)
        # x = x.permute(2, 0, 1)
        # x, _ = self.lstm(x)
        # x = x.permute(1, 2, 0)
        for decode in self.decoder:
            skip = skips.pop(-1)
            x = x + skip[..., :x.shape[-1]]
            x = decode(x)
        if self.resample == 2:
            x = downsample2(x)
        elif self.resample == 4:
            x = downsample2(x)
            x = downsample2(x)
        x = x[..., :length]
        return std * x


def fast_conv(conv, x):
    """
    Faster convolution evaluation if either the kernel size is 1
    or the sequence length equals the kernel size (output length 1).
    """
    batch, chin, length = x.shape
    chout, chin, kernel = conv.weight.shape
    assert batch == 1
    if kernel == 1:
        x = x.view(chin, length)
        out = torch.addmm(conv.bias.view(-1, 1),
                          conv.weight.view(chout, chin), x)
    elif length == kernel:
        x = x.view(chin * kernel, 1)
        out = torch.addmm(conv.bias.view(-1, 1),
                          conv.weight.view(chout, chin * kernel), x)
    else:
        out = conv(x)
    return out.view(batch, chout, -1)
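

# Sketch (illustrative): for a 1x1 kernel, fast_conv reduces to a single
# matrix multiply and matches nn.Conv1d to numerical precision.
def _demo_fast_conv():
    conv = nn.Conv1d(6, 4, 1)
    x = torch.randn(1, 6, 20)
    assert torch.allclose(fast_conv(conv, x), conv(x), atol=1e-6)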


class DemucsStreamer:
    """
    Streaming implementation for Demucs. It supports being fed with any amount
    of audio at a time. You will get back as much audio as possible at that
    point.

    Args:
        - demucs (Demucs): Demucs model.
        - dry (float): amount of dry (i.e. input) signal to keep. 0 is maximum
          noise removal, 1 just returns the input signal. Small values > 0
          help limit distortions.
        - num_frames (int): number of frames to process at once. Higher values
          will increase overall latency but improve the real time factor.
        - resample_lookahead (int): extra lookahead used for the resampling.
        - resample_buffer (int): size of the buffer of previous inputs/outputs
          kept for resampling.
    """
    def __init__(self, demucs,
                 dry=0,
                 num_frames=1,
                 resample_lookahead=64,
                 resample_buffer=256):
        device = next(iter(demucs.parameters())).device
        self.demucs = demucs
        self.lstm_state = None  # unused with the TCN separator; kept for compatibility
        self.conv_state = None
        self.dry = dry
        self.resample_lookahead = resample_lookahead
        self.resample_buffer = resample_buffer
        self.frame_length = demucs.valid_length(1) + demucs.total_stride * (num_frames - 1)
        self.total_length = self.frame_length + self.resample_lookahead
        self.stride = demucs.total_stride * num_frames
        self.resample_in = torch.zeros(demucs.chin, resample_buffer, device=device)
        self.resample_out = torch.zeros(demucs.chin, resample_buffer, device=device)
        self.frames = 0
        self.total_time = 0
        self.variance = 0
        self.pending = torch.zeros(demucs.chin, 0, device=device)

        bias = demucs.decoder[0][2].bias
        weight = demucs.decoder[0][2].weight
        chin, chout, kernel = weight.shape
        self._bias = bias.view(-1, 1).repeat(1, kernel).view(-1, 1)
        self._weight = weight.permute(1, 2, 0).contiguous()

    def reset_time_per_frame(self):
        self.total_time = 0
        self.frames = 0

    @property
    def time_per_frame(self):
        return self.total_time / self.frames

    def flush(self):
        """
        Flush remaining audio by padding it with zeros. Call this
        when you have no more input and want to get back the last chunk of audio.
        """
        pending_length = self.pending.shape[1]
        padding = torch.zeros(self.demucs.chin, self.total_length, device=self.pending.device)
        out = self.feed(padding)
        return out[:, :pending_length]

    def feed(self, wav):
        """
        Apply the model to the given chunk of audio, using true streaming
        evaluation. Normalization is done online, as is the resampling.
        """
        begin = time.time()
        demucs = self.demucs
        resample_buffer = self.resample_buffer
        stride = self.stride
        resample = demucs.resample
        if wav.dim() != 2:
            raise ValueError("input wav should be two dimensional.")
        chin, _ = wav.shape
        if chin != demucs.chin:
            raise ValueError(f"Expected {demucs.chin} channels, got {chin}")
        self.pending = torch.cat([self.pending, wav], dim=1)
        outs = []
        while self.pending.shape[1] >= self.total_length:
            self.frames += 1
            frame = self.pending[:, :self.total_length]
            dry_signal = frame[:, :stride]
            if demucs.normalize:
                mono = frame.mean(0)
                variance = (mono**2).mean()
                self.variance = variance / self.frames + (1 - 1 / self.frames) * self.variance
                frame = frame / (demucs.floor + math.sqrt(self.variance))
            frame = torch.cat([self.resample_in, frame], dim=-1)
            self.resample_in[:] = frame[:, stride - resample_buffer:stride]
            if resample == 4:
                frame = upsample2(upsample2(frame))
            elif resample == 2:
                frame = upsample2(frame)
            frame = frame[:, resample * resample_buffer:]  # remove pre-sampling buffer
            frame = frame[:, :resample * self.frame_length]  # remove extra samples after window

            out, extra = self._separate_frame(frame)
            padded_out = torch.cat([self.resample_out, out, extra], 1)
            self.resample_out[:] = out[:, -resample_buffer:]
            if resample == 4:
                out = downsample2(downsample2(padded_out))
            elif resample == 2:
                out = downsample2(padded_out)
            else:
                out = padded_out
            out = out[:, resample_buffer // resample:]
            out = out[:, :stride]

            if demucs.normalize:
                out *= math.sqrt(self.variance)
            out = self.dry * dry_signal + (1 - self.dry) * out
            outs.append(out)
            self.pending = self.pending[:, stride:]

        self.total_time += time.time() - begin
        if outs:
            out = torch.cat(outs, 1)
        else:
            out = torch.zeros(chin, 0, device=wav.device)
        return out

    def _separate_frame(self, frame):
        demucs = self.demucs
        skips = []
        next_state = []
        first = self.conv_state is None
        stride = self.stride * demucs.resample
        x = frame[None]
        for idx, encode in enumerate(demucs.encoder):
            stride //= demucs.stride
            length = x.shape[2]
            if idx == demucs.depth - 1:
                # This is slightly faster for the last conv
                x = fast_conv(encode[0], x)
                x = encode[1](x)
                x = fast_conv(encode[2], x)
                x = encode[3](x)
            else:
                if not first:
                    prev = self.conv_state.pop(0)
                    prev = prev[..., stride:]
                    tgt = (length - demucs.kernel_size) // demucs.stride + 1
                    missing = tgt - prev.shape[-1]
                    offset = length - demucs.kernel_size - demucs.stride * (missing - 1)
                    x = x[..., offset:]
                x = encode[1](encode[0](x))
                x = fast_conv(encode[2], x)
                x = encode[3](x)
                if not first:
                    x = torch.cat([prev, x], -1)
                next_state.append(x)
            skips.append(x)

        # The original streamer ran an LSTM here; this variant applies the
        # TemporalConvNet separator instead, statelessly on each frame.
        x = demucs.separator(x)
        # In the following, x contains only correct samples, i.e. the ones
        # for which each time position is covered by two windows of the upper
        # layer. extra contains extra samples to the right, and is used only
        # as better padding for the online resampling.
        extra = None
        for idx, decode in enumerate(demucs.decoder):
            skip = skips.pop(-1)
            x += skip[..., :x.shape[-1]]
            x = fast_conv(decode[0], x)
            x = decode[1](x)

            if extra is not None:
                skip = skip[..., x.shape[-1]:]
                extra += skip[..., :extra.shape[-1]]
                extra = decode[2](decode[1](decode[0](extra)))
            x = decode[2](x)
            next_state.append(x[..., -demucs.stride:] - decode[2].bias.view(-1, 1))
            if extra is None:
                extra = x[..., -demucs.stride:]
            else:
                extra[..., :demucs.stride] += next_state[-1]
            x = x[..., :-demucs.stride]

            if not first:
                prev = self.conv_state.pop(0)
                x[..., :demucs.stride] += prev
            if idx != demucs.depth - 1:
                x = decode[3](x)
                extra = decode[3](extra)
        self.conv_state = next_state
        return x[0], extra[0]


def test():
    import argparse
    parser = argparse.ArgumentParser(
        "denoiser.demucs",
        description="Benchmark the streaming Demucs implementation, "
                    "as well as checking the delta with the offline implementation.")
    parser.add_argument("--resample", default=4, type=int)
    parser.add_argument("--hidden", default=48, type=int)
    parser.add_argument("--device", default="cpu")
    parser.add_argument("-t", "--num_threads", type=int)
    parser.add_argument("-f", "--num_frames", type=int, default=1)
    args = parser.parse_args()
    if args.num_threads:
        torch.set_num_threads(args.num_threads)
    sr = 16_000
    sr_ms = sr / 1000
    demucs = Demucs(hidden=args.hidden, resample=args.resample).to(args.device)
    x = torch.randn(1, sr * 4).to(args.device)
    out = demucs(x[None])[0]
    streamer = DemucsStreamer(demucs, num_frames=args.num_frames)
    out_rt = []
    frame_size = streamer.total_length
    with torch.no_grad():
        while x.shape[1] > 0:
            out_rt.append(streamer.feed(x[:, :frame_size]))
            x = x[:, frame_size:]
            frame_size = streamer.demucs.total_stride
        out_rt.append(streamer.flush())
    out_rt = torch.cat(out_rt, 1)
    print(f"total lag: {streamer.total_length / sr_ms:.1f}ms, ", end='')
    print(f"stride: {streamer.stride / sr_ms:.1f}ms, ", end='')
    print(f"time per frame: {1000 * streamer.time_per_frame:.1f}ms, ", end='')
    print(f"delta: {torch.norm(out - out_rt) / torch.norm(out):.2%}, ", end='')
    print(f"RTF: {((1000 * streamer.time_per_frame) / (streamer.stride / sr_ms)):.1f}")


if __name__ == "__main__":
    test()