Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| # author: adiyoss | |
| import argparse | |
| from concurrent.futures import ProcessPoolExecutor | |
| import json | |
| import logging | |
| import sys | |
| from pesq import pesq | |
| from pystoi import stoi | |
| import torch | |
| from .data import NoisyCleanSet | |
| from .enhance import add_flags, get_estimate | |
| from . import distrib, pretrained | |
| from .utils import bold, LogProgress | |
logger = logging.getLogger(__name__)

# Command-line interface. `add_flags` (from .enhance) registers the shared
# model/enhancement options; the flags below are specific to evaluation.
parser = argparse.ArgumentParser(
    'denoiser.evaluate',
    description='Speech enhancement using Demucs - Evaluate model performance')
add_flags(parser)
parser.add_argument('--data_dir', help='directory including noisy.json and clean.json files')
parser.add_argument('--matching', default="sort", help='set this to dns for the dns dataset.')
# `--no_pesq` stores False into `args.pesq`; PESQ is computed by default.
parser.add_argument('--no_pesq', action="store_false", dest="pesq", default=True,
                    help="Don't compute PESQ.")
# `-v` switches the level passed to logging.basicConfig in main() from INFO to DEBUG.
parser.add_argument('-v', '--verbose', action='store_const', const=logging.DEBUG,
                    default=logging.INFO, help="More logging")
def evaluate(args, model=None, data_loader=None):
    """Evaluate a speech-enhancement model with PESQ and STOI.

    Estimates are produced for every batch of ``data_loader`` and the metrics
    are computed in a process pool. On CPU, both the forward pass and the
    metrics run in the worker processes. On any other device, the forward
    pass runs in this process and only the metric computation is offloaded.

    Args:
        args: parsed command-line namespace. This function reads ``device``,
            ``num_workers``, ``data_dir``, ``matching`` and ``sample_rate``;
            the helpers it calls read further fields.
        model: model to evaluate. When ``None``, one is loaded with
            ``pretrained.get_model(args)`` and moved to ``args.device``.
        data_loader: iterable yielding ``(noisy, clean)`` pairs. When ``None``,
            a batch-size-1 loader over ``NoisyCleanSet(args.data_dir, ...)``
            is built.

    Returns:
        ``(pesq, stoi)`` per-utterance averages as returned by
        ``distrib.average``. The local mean is taken as 0.0 when this
        process evaluated no utterances.
    """
    total_pesq = 0
    total_stoi = 0
    total_cnt = 0
    updates = 5  # passed to LogProgress for the metrics loop (presumably the number of progress log lines)

    # Load model. Test against None explicitly (as for data_loader below):
    # the truthiness of a module object is not a reliable "was it given" check.
    if model is None:
        model = pretrained.get_model(args).to(args.device)
    model.eval()

    # Load data
    if data_loader is None:
        dataset = NoisyCleanSet(args.data_dir, matching=args.matching, sample_rate=args.sample_rate)
        data_loader = distrib.loader(dataset, batch_size=1, num_workers=2)

    pendings = []
    with ProcessPoolExecutor(args.num_workers) as pool:
        with torch.no_grad():
            iterator = LogProgress(logger, data_loader, name="Eval estimates")
            for data in iterator:
                # Get batch data
                noisy, clean = [x.to(args.device) for x in data]
                if args.device == 'cpu':
                    # On CPU, run the forward pass and the metrics together in
                    # a worker process so utterances are evaluated in parallel.
                    pendings.append(
                        pool.submit(_estimate_and_run_metrics, clean, model, noisy, args))
                else:
                    # Otherwise enhance here and ship only CPU tensors to the
                    # workers for metric computation.
                    estimate = get_estimate(model, noisy, args)
                    estimate = estimate.cpu()
                    clean = clean.cpu()
                    pendings.append(
                        pool.submit(_run_metrics, clean, estimate, args))
                total_cnt += clean.shape[0]
        # Each future yields the batch *sums* of PESQ and STOI.
        for pending in LogProgress(logger, pendings, updates, name="Eval metrics"):
            pesq_i, stoi_i = pending.result()
            total_pesq += pesq_i
            total_stoi += stoi_i

    # Per-utterance means for this process. Guard against an empty loader,
    # which previously raised ZeroDivisionError. distrib.average is still
    # called (with weight total_cnt == 0) in case it performs a cross-process
    # reduction that every process must join -- presumably; verify in distrib.
    metrics = [total_pesq, total_stoi]
    means = [m / total_cnt if total_cnt else 0.0 for m in metrics]
    # Distinct names so the imported pesq/stoi functions are not shadowed.
    mean_pesq, mean_stoi = distrib.average(means, total_cnt)
    logger.info(bold(f'Test set performance:PESQ={mean_pesq}, STOI={mean_stoi}.'))
    return mean_pesq, mean_stoi
def _estimate_and_run_metrics(clean, model, noisy, args):
    """Worker task: enhance ``noisy`` with ``model`` and score it against ``clean``.

    Submitted to the process pool by ``evaluate`` when running on CPU, so the
    forward pass also happens inside the worker. Returns the ``(pesq, stoi)``
    batch sums produced by ``_run_metrics``.
    """
    return _run_metrics(clean, get_estimate(model, noisy, args), args)
def _run_metrics(clean, estimate, args):
    """Compute PESQ and STOI for one batch, summed over its utterances.

    Args:
        clean: reference tensor. ``.numpy()`` is called on it, so it must
            live on the CPU. Only index 0 of dim 1 is scored (presumably a
            ``[B, C, T]`` layout -- confirm).
        estimate: enhanced tensor with the same layout as ``clean``.
        args: namespace. ``args.pesq`` toggles PESQ and ``args.sample_rate``
            is forwarded to both metrics.

    Returns:
        ``(pesq_sum, stoi_sum)``; ``pesq_sum`` is 0 when PESQ is disabled.
    """
    # Convert the estimate first, then the reference (same order as before).
    est_np = estimate.numpy()[:, 0]
    ref_np = clean.numpy()[:, 0]
    rate = args.sample_rate
    pesq_sum = get_pesq(ref_np, est_np, sr=rate) if args.pesq else 0
    stoi_sum = get_stoi(ref_np, est_np, sr=rate)
    return pesq_sum, stoi_sum
def get_pesq(ref_sig, out_sig, sr, mode='wb'):
    """Calculate PESQ, summed over the batch.

    Args:
        ref_sig: numpy.ndarray, [B, T] reference (clean) signals.
        out_sig: numpy.ndarray, [B, T] enhanced signals; row ``i`` is scored
            against ``ref_sig[i]``.
        sr: sample rate forwarded to ``pesq``.
        mode: PESQ mode forwarded to ``pesq``. ``'wb'`` (wide-band) is the
            default and the previously hard-coded value; ``'nb'`` selects
            narrow-band. NOTE(review): per the pesq package docs, wide-band
            needs sr == 16000 and narrow-band accepts 8000 or 16000 -- confirm.
    Returns:
        Sum of the per-utterance PESQ scores, not their mean. Callers divide
        by the utterance count themselves.
    """
    pesq_val = 0
    for i, ref in enumerate(ref_sig):
        # Index out_sig (rather than zip) so a shorter out_sig still raises.
        pesq_val += pesq(sr, ref, out_sig[i], mode)
    return pesq_val
def get_stoi(ref_sig, out_sig, sr, extended=False):
    """Calculate STOI, summed over the batch.

    Args:
        ref_sig: numpy.ndarray, [B, T] reference (clean) signals.
        out_sig: numpy.ndarray, [B, T] enhanced signals; row ``i`` is scored
            against ``ref_sig[i]``.
        sr: sample rate forwarded to ``stoi``.
        extended: forwarded to ``stoi``. False (the default and previously
            hard-coded value) computes classic STOI; True requests extended
            STOI (ESTOI).
    Returns:
        Sum of the per-utterance STOI scores, not their mean. Callers divide
        by the utterance count themselves.
    """
    stoi_val = 0
    for i, ref in enumerate(ref_sig):
        # Index out_sig (rather than zip) so a shorter out_sig still raises.
        stoi_val += stoi(ref, out_sig[i], sr, extended=extended)
    return stoi_val
def main():
    """CLI entry point: run the evaluation and print ``{"pesq": ..., "stoi": ...}`` as one JSON line on stdout."""
    cli_args = parser.parse_args()
    logging.basicConfig(stream=sys.stderr, level=cli_args.verbose)
    logger.debug(cli_args)
    pesq_score, stoi_score = evaluate(cli_args)
    summary = {'pesq': pesq_score, 'stoi': stoi_score}
    # print() appends the trailing newline after the JSON document.
    print(json.dumps(summary))


if __name__ == '__main__':
    main()