Spaces:
Build error
Build error
| import torch | |
| from deepafx_st.processors.proxy.proxy_system import ProxySystem | |
| from deepafx_st.utils import DSPMode | |
class ProxyChannel(torch.nn.Module):
    """Chain of ``ProxySystem`` processors applied in series to a signal.

    Proxies are either loaded from ``ProxySystem`` checkpoints, or, when no
    checkpoints are given, newly constructed: a "peq" proxy followed by a
    "comp" proxy (``num_tcns == 2``), or a single "channel" proxy
    (``num_tcns == 1``).

    Args:
        proxy_system_ckpts: Checkpoint paths passed to
            ``ProxySystem.load_from_checkpoint``, loaded in order. When empty,
            new proxies are constructed from the ``tcn_*`` arguments instead.
        freeze_proxies: If True, set ``requires_grad = False`` on every
            parameter of the checkpoint-loaded proxies. Newly constructed
            proxies are not frozen.
        dsp_mode: Stored as ``self.dsp_mode``. Note that ``forward`` uses its
            own ``dsp_mode`` argument, not this attribute.
        num_tcns: Number of proxies to construct when no checkpoints are
            given; must be 1 or 2. Ignored when checkpoints are given.
        tcn_nblocks: ``nblocks`` for newly constructed proxies.
        tcn_dilation_growth: ``dilation_growth`` for newly constructed proxies.
        tcn_channel_width: ``channel_width`` for newly constructed proxies.
        tcn_kernel_size: ``kernel_size`` for newly constructed proxies.
        sample_rate: ``sample_rate`` for newly constructed proxies.

    Raises:
        ValueError: If no checkpoints are given and ``num_tcns`` is not 1 or 2.
    """

    def __init__(
        self,
        proxy_system_ckpts: list,
        freeze_proxies: bool = True,
        dsp_mode: DSPMode = DSPMode.NONE,
        num_tcns: int = 2,
        tcn_nblocks: int = 4,
        tcn_dilation_growth: int = 8,
        tcn_channel_width: int = 64,
        tcn_kernel_size: int = 13,
        sample_rate: int = 24000,
    ):
        super().__init__()
        self.freeze_proxies = freeze_proxies
        self.dsp_mode = dsp_mode
        self.num_tcns = num_tcns

        # Proxies in the order they are applied by ``forward``.
        self.proxies = torch.nn.ModuleList()
        # Total number of control-parameter columns consumed from ``p``.
        self.num_control_params = 0
        # Port specs gathered from each proxy's ``processor.ports``.
        self.ports = []

        for proxy_system_ckpt in proxy_system_ckpts:
            proxy = ProxySystem.load_from_checkpoint(proxy_system_ckpt)
            if freeze_proxies:
                for param in proxy.parameters():
                    param.requires_grad = False
            self._register_proxy(
                proxy, is_channel=proxy.hparams.processor == "channel"
            )

        if not proxy_system_ckpts:
            tcn_kwargs = dict(
                nblocks=tcn_nblocks,
                dilation_growth=tcn_dilation_growth,
                kernel_size=tcn_kernel_size,
                channel_width=tcn_channel_width,
                sample_rate=sample_rate,
            )
            if self.num_tcns == 2:
                peq_proxy = ProxySystem(
                    processor="peq", output_gain=False, **tcn_kwargs
                )
                self._register_proxy(peq_proxy, is_channel=False)
                comp_proxy = ProxySystem(
                    processor="comp", output_gain=True, **tcn_kwargs
                )
                self._register_proxy(comp_proxy, is_channel=False)
            elif self.num_tcns == 1:
                channel_proxy = ProxySystem(
                    processor="channel", output_gain=True, **tcn_kwargs
                )
                self._register_proxy(channel_proxy, is_channel=True)
            else:
                raise ValueError(
                    f"num_tcns must be 1 or 2. Asked for {self.num_tcns}."
                )

    def _register_proxy(self, proxy: ProxySystem, is_channel: bool) -> None:
        """Append ``proxy`` and record its ports and control-parameter count.

        Args:
            proxy: The proxy to add to the processing chain.
            is_channel: True when ``proxy.processor.ports`` is itself a list of
                per-sub-processor port lists that should be flattened into
                ``self.ports`` (the "channel" processor).
        """
        self.proxies.append(proxy)
        if is_channel:
            # Extend instead of assigning: keeps ports from earlier proxies and
            # avoids aliasing (and later mutating) the processor's own list.
            self.ports.extend(proxy.processor.ports)
        else:
            self.ports.append(proxy.processor.ports)
        self.num_control_params += proxy.processor.num_control_params

    def forward(
        self,
        x: torch.Tensor,
        p: torch.Tensor,
        dsp_mode: DSPMode = DSPMode.NONE,
        sample_rate: int = 24000,
        **kwargs,
    ):
        """Run ``x`` through every proxy in order.

        Args:
            x: Input to the first proxy; each proxy's output feeds the next.
            p: Control parameters, sliced along dim 1 in proxy order, taking
                ``proxy.processor.num_control_params`` columns per proxy
                (assumes ``p`` is at least 2-D, e.g. (batch, params) — confirm
                against callers).
            dsp_mode: ``NONE`` calls each proxy with ``use_dsp=False``;
                ``INFER`` calls it with ``use_dsp=True``; ``TRAIN_INFER``
                returns the ``use_dsp=True`` output in the forward pass while
                routing gradients through the ``use_dsp=False`` output.
            sample_rate: Passed to the proxy calls in ``INFER`` and
                ``TRAIN_INFER`` modes.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The output of the last proxy (``x`` unchanged if there are none).

        Raises:
            ValueError: If ``dsp_mode`` is not NONE, INFER or TRAIN_INFER.
        """
        # Modes are compared by ``.name`` (as in the original implementation),
        # presumably so enum members from a separately imported DSPMode match.
        mode = dsp_mode.name
        valid_modes = (
            DSPMode.NONE.name,
            DSPMode.INFER.name,
            DSPMode.TRAIN_INFER.name,
        )
        if mode not in valid_modes:
            # Explicit raise: an ``assert`` would be stripped under ``-O`` and
            # let the input pass through unprocessed.
            raise ValueError(f"invalid dsp mode for proxy: {dsp_mode}")

        stop_idx = 0
        for proxy in self.proxies:
            start_idx = stop_idx
            stop_idx += proxy.processor.num_control_params
            p_subset = p[:, start_idx:stop_idx]
            if mode == DSPMode.NONE.name:
                x = proxy(x, p_subset, use_dsp=False)
            elif mode == DSPMode.INFER.name:
                x = proxy(x, p_subset, use_dsp=True, sample_rate=sample_rate)
            else:  # TRAIN_INFER
                # Mimic gumbel softmax implementation to replace grads similar to
                # https://gist.github.com/yzh119/fd2146d2aeb329d067568a493b20172f
                # Forward value equals x_hard; gradients flow through x_soft.
                x_hard = proxy(
                    x, p_subset, use_dsp=True, sample_rate=sample_rate
                )
                x_soft = proxy(
                    x, p_subset, use_dsp=False, sample_rate=sample_rate
                )
                x = (x_hard - x_soft).detach() + x_soft
        return x