Spaces:
Runtime error
Runtime error
| # Copyright (c) 2021 Mobvoi Inc (Chao Yang) | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import sys | |
| import random | |
| import math | |
| import torchaudio | |
| import torch | |
| def db2amp(db): | |
| return pow(10, db / 20) | |
| def amp2db(amp): | |
| return 20 * math.log10(amp) | |
| def make_poly_distortion(conf): | |
| """Generate a db-domain ploynomial distortion function | |
| f(x) = a * x^m * (1-x)^n + x | |
| Args: | |
| conf: a dict {'a': #int, 'm': #int, 'n': #int} | |
| Returns: | |
| The ploynomial function, which could be applied on | |
| a float amplitude value | |
| """ | |
| a = conf['a'] | |
| m = conf['m'] | |
| n = conf['n'] | |
| def poly_distortion(x): | |
| abs_x = abs(x) | |
| if abs_x < 0.000001: | |
| x = x | |
| else: | |
| db_norm = amp2db(abs_x) / 100 + 1 | |
| if db_norm < 0: | |
| db_norm = 0 | |
| db_norm = a * pow(db_norm, m) * pow((1 - db_norm), n) + db_norm | |
| if db_norm > 1: | |
| db_norm = 1 | |
| db = (db_norm - 1) * 100 | |
| amp = db2amp(db) | |
| if amp >= 0.9997: | |
| amp = 0.9997 | |
| if x > 0: | |
| x = amp | |
| else: | |
| x = -amp | |
| return x | |
| return poly_distortion | |
| def make_quad_distortion(): | |
| return make_poly_distortion({'a': 1, 'm': 1, 'n': 1}) | |
| # the amplitude are set to max for all non-zero point | |
| def make_max_distortion(conf): | |
| """Generate a max distortion function | |
| Args: | |
| conf: a dict {'max_db': float } | |
| 'max_db': the maxium value. | |
| Returns: | |
| The max function, which could be applied on | |
| a float amplitude value | |
| """ | |
| max_db = conf['max_db'] | |
| if max_db: | |
| max_amp = db2amp(max_db) # < 0.997 | |
| else: | |
| max_amp = 0.997 | |
| def max_distortion(x): | |
| if x > 0: | |
| x = max_amp | |
| elif x < 0: | |
| x = -max_amp | |
| else: | |
| x = 0.0 | |
| return x | |
| return max_distortion | |
| def make_amp_mask(db_mask=None): | |
| """Get a amplitude domain mask from db domain mask | |
| Args: | |
| db_mask: Optional. A list of tuple. if None, using default value. | |
| Returns: | |
| A list of tuple. The amplitude domain mask | |
| """ | |
| if db_mask is None: | |
| db_mask = [(-110, -95), (-90, -80), (-65, -60), (-50, -30), (-15, 0)] | |
| amp_mask = [(db2amp(db[0]), db2amp(db[1])) for db in db_mask] | |
| return amp_mask | |
| default_mask = make_amp_mask() | |
| def generate_amp_mask(mask_num): | |
| """Generate amplitude domain mask randomly in [-100db, 0db] | |
| Args: | |
| mask_num: the slot number of the mask | |
| Returns: | |
| A list of tuple. each tuple defines a slot. | |
| e.g. [(-100, -80), (-65, -60), (-50, -30), (-15, 0)] | |
| for #mask_num = 4 | |
| """ | |
| a = [0] * 2 * mask_num | |
| a[0] = 0 | |
| m = [] | |
| for i in range(1, 2 * mask_num): | |
| a[i] = a[i - 1] + random.uniform(0.5, 1) | |
| max_val = a[2 * mask_num - 1] | |
| for i in range(0, mask_num): | |
| l = ((a[2 * i] - max_val) / max_val) * 100 | |
| r = ((a[2 * i + 1] - max_val) / max_val) * 100 | |
| m.append((l, r)) | |
| return make_amp_mask(m) | |
| def make_fence_distortion(conf): | |
| """Generate a fence distortion function | |
| In this fence-like shape function, the values in mask slots are | |
| set to maxium, while the values not in mask slots are set to 0. | |
| Use seperated masks for Positive and negetive amplitude. | |
| Args: | |
| conf: a dict {'mask_number': int,'max_db': float } | |
| 'mask_number': the slot number in mask. | |
| 'max_db': the maxium value. | |
| Returns: | |
| The fence function, which could be applied on | |
| a float amplitude value | |
| """ | |
| mask_number = conf['mask_number'] | |
| max_db = conf['max_db'] | |
| max_amp = db2amp(max_db) # 0.997 | |
| if mask_number <= 0: | |
| positive_mask = default_mask | |
| negative_mask = make_amp_mask([(-50, 0)]) | |
| else: | |
| positive_mask = generate_amp_mask(mask_number) | |
| negative_mask = generate_amp_mask(mask_number) | |
| def fence_distortion(x): | |
| is_in_mask = False | |
| if x > 0: | |
| for mask in positive_mask: | |
| if x >= mask[0] and x <= mask[1]: | |
| is_in_mask = True | |
| return max_amp | |
| if not is_in_mask: | |
| return 0.0 | |
| elif x < 0: | |
| abs_x = abs(x) | |
| for mask in negative_mask: | |
| if abs_x >= mask[0] and abs_x <= mask[1]: | |
| is_in_mask = True | |
| return max_amp | |
| if not is_in_mask: | |
| return 0.0 | |
| return x | |
| return fence_distortion | |
| # | |
| def make_jag_distortion(conf): | |
| """Generate a jag distortion function | |
| In this jag-like shape function, the values in mask slots are | |
| not changed, while the values not in mask slots are set to 0. | |
| Use seperated masks for Positive and negetive amplitude. | |
| Args: | |
| conf: a dict {'mask_number': #int} | |
| 'mask_number': the slot number in mask. | |
| Returns: | |
| The jag function,which could be applied on | |
| a float amplitude value | |
| """ | |
| mask_number = conf['mask_number'] | |
| if mask_number <= 0: | |
| positive_mask = default_mask | |
| negative_mask = make_amp_mask([(-50, 0)]) | |
| else: | |
| positive_mask = generate_amp_mask(mask_number) | |
| negative_mask = generate_amp_mask(mask_number) | |
| def jag_distortion(x): | |
| is_in_mask = False | |
| if x > 0: | |
| for mask in positive_mask: | |
| if x >= mask[0] and x <= mask[1]: | |
| is_in_mask = True | |
| return x | |
| if not is_in_mask: | |
| return 0.0 | |
| elif x < 0: | |
| abs_x = abs(x) | |
| for mask in negative_mask: | |
| if abs_x >= mask[0] and abs_x <= mask[1]: | |
| is_in_mask = True | |
| return x | |
| if not is_in_mask: | |
| return 0.0 | |
| return x | |
| return jag_distortion | |
| # gaining 20db means amp = amp * 10 | |
| # gaining -20db means amp = amp / 10 | |
| def make_gain_db(conf): | |
| """Generate a db domain gain function | |
| Args: | |
| conf: a dict {'db': #float} | |
| 'db': the gaining value | |
| Returns: | |
| The db gain function, which could be applied on | |
| a float amplitude value | |
| """ | |
| db = conf['db'] | |
| def gain_db(x): | |
| return min(0.997, x * pow(10, db / 20)) | |
| return gain_db | |
| def distort(x, func, rate=0.8): | |
| """Distort a waveform in sample point level | |
| Args: | |
| x: the origin wavefrom | |
| func: the distort function | |
| rate: sample point-level distort probability | |
| Returns: | |
| the distorted waveform | |
| """ | |
| for i in range(0, x.shape[1]): | |
| a = random.uniform(0, 1) | |
| if a < rate: | |
| x[0][i] = func(float(x[0][i])) | |
| return x | |
| def distort_chain(x, funcs, rate=0.8): | |
| for i in range(0, x.shape[1]): | |
| a = random.uniform(0, 1) | |
| if a < rate: | |
| for func in funcs: | |
| x[0][i] = func(float(x[0][i])) | |
| return x | |
| # x is numpy | |
| def distort_wav_conf(x, distort_type, distort_conf, rate=0.1): | |
| if distort_type == 'gain_db': | |
| gain_db = make_gain_db(distort_conf) | |
| x = distort(x, gain_db) | |
| elif distort_type == 'max_distortion': | |
| max_distortion = make_max_distortion(distort_conf) | |
| x = distort(x, max_distortion, rate=rate) | |
| elif distort_type == 'fence_distortion': | |
| fence_distortion = make_fence_distortion(distort_conf) | |
| x = distort(x, fence_distortion, rate=rate) | |
| elif distort_type == 'jag_distortion': | |
| jag_distortion = make_jag_distortion(distort_conf) | |
| x = distort(x, jag_distortion, rate=rate) | |
| elif distort_type == 'poly_distortion': | |
| poly_distortion = make_poly_distortion(distort_conf) | |
| x = distort(x, poly_distortion, rate=rate) | |
| elif distort_type == 'quad_distortion': | |
| quad_distortion = make_quad_distortion() | |
| x = distort(x, quad_distortion, rate=rate) | |
| elif distort_type == 'none_distortion': | |
| pass | |
| else: | |
| print('unsupport type') | |
| return x | |
| def distort_wav_conf_and_save(distort_type, distort_conf, rate, wav_in, | |
| wav_out): | |
| x, sr = torchaudio.load(wav_in) | |
| x = x.detach().numpy() | |
| out = distort_wav_conf(x, distort_type, distort_conf, rate) | |
| torchaudio.save(wav_out, torch.from_numpy(out), sr) | |
| if __name__ == "__main__": | |
| distort_type = sys.argv[1] | |
| wav_in = sys.argv[2] | |
| wav_out = sys.argv[3] | |
| conf = None | |
| rate = 0.1 | |
| if distort_type == 'new_jag_distortion': | |
| conf = {'mask_number': 4} | |
| elif distort_type == 'new_fence_distortion': | |
| conf = {'mask_number': 1, 'max_db': -30} | |
| elif distort_type == 'poly_distortion': | |
| conf = {'a': 4, 'm': 2, "n": 2} | |
| distort_wav_conf_and_save(distort_type, conf, rate, wav_in, wav_out) | |