| import argparse | |
| import glob | |
| import os | |
| import librosa | |
| import numpy as np | |
| import onnx | |
| import onnxruntime | |
| import soundfile as sf | |
| import torch | |
| import tqdm | |
| from config import CONFIG | |
| parser = argparse.ArgumentParser() | |
| parser.add_argument('--onnx_path', default=None, | |
| help='path to onnx') | |
| args = parser.parse_args() | |
| if __name__ == '__main__': | |
| path = args.onnx_path | |
| window = CONFIG.DATA.window_size | |
| stride = CONFIG.DATA.stride | |
| onnx_model = onnx.load(path) | |
| options = onnxruntime.SessionOptions() | |
| options.intra_op_num_threads = 8 | |
| options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL | |
| session = onnxruntime.InferenceSession(path, options) | |
| input_names = [x.name for x in session.get_inputs()] | |
| output_names = [x.name for x in session.get_outputs()] | |
| print(input_names) | |
| print(output_names) | |
| audio_files = glob.glob(os.path.join(CONFIG.TEST.in_dir, '*.wav')) | |
| hann = torch.sqrt(torch.hann_window(window)) | |
| os.makedirs(CONFIG.TEST.out_dir, exist_ok=True) | |
| for file in tqdm.tqdm(audio_files, total=len(audio_files)): | |
| sig, _ = librosa.load(file, sr=48000) | |
| sig = torch.tensor(sig) | |
| re_im = torch.stft(sig, window, stride, window=hann, return_complex=False).permute(1, 0, 2).unsqueeze( | |
| 1).numpy().astype(np.float32) | |
| inputs = {input_names[i]: np.zeros([d.dim_value for d in _input.type.tensor_type.shape.dim], | |
| dtype=np.float32) | |
| for i, _input in enumerate(onnx_model.graph.input) | |
| } | |
| output_audio = [] | |
| for t in range(re_im.shape[0]): | |
| inputs[input_names[0]] = re_im[t] | |
| out, prev_mag, predictor_state, mlp_state = session.run(output_names, inputs) | |
| inputs[input_names[1]] = prev_mag | |
| inputs[input_names[2]] = predictor_state | |
| inputs[input_names[3]] = mlp_state | |
| output_audio.append(out) | |
| output_audio = torch.tensor(np.concatenate(output_audio, 0)) | |
| output_audio = output_audio.permute(1, 0, 2).contiguous() | |
| output_audio = torch.view_as_complex(output_audio) | |
| output_audio = torch.istft(output_audio, window, stride, window=hann) | |
| sf.write(os.path.join(CONFIG.TEST.out_dir, os.path.basename(file)), output_audio, samplerate=48000, | |
| subtype='PCM_16') | |