Spaces:
Configuration error
| from collections import OrderedDict | |
| import torch | |
| import utils | |
| from models import SynthesizerTrn | |
def copyStateDict(state_dict):
    """Return a re-keyed copy of ``state_dict`` with a leading ``module`` segment removed.

    Only the FIRST key is inspected. If it starts with ``'module'``, the first
    dot-separated segment is dropped from EVERY key (e.g. ``'module.a.b'`` ->
    ``'a.b'``). Otherwise keys are copied unchanged. The remaining segments are
    re-joined with ``'.'`` so the dotted key structure is preserved.

    Args:
        state_dict: Mapping with string keys. Iteration order is preserved.

    Returns:
        A new ``OrderedDict`` holding the same value objects (values are not
        copied, only re-keyed). An empty input yields an empty ``OrderedDict``.
    """
    new_state_dict = OrderedDict()
    if not state_dict:
        # Guard: the first-key prefix check below would fail on an empty mapping.
        return new_state_dict
    first_key = next(iter(state_dict))
    start_idx = 1 if first_key.startswith('module') else 0
    for k, v in state_dict.items():
        # Keys are dot-separated paths, so re-join with '.' (not ',').
        name = '.'.join(k.split('.')[start_idx:])
        new_state_dict[name] = v
    return new_state_dict
def removeOptimizer(config: str, input_model: str, ishalf: bool, output_model: str):
    """Save a stripped copy of a generator checkpoint for release.

    Reads ``input_model``, keeps every ``'model'`` entry whose key does not
    contain ``"enc_q"``, optionally casts the kept floating-point tensors to
    FP16, and writes the result to ``output_model``. The saved dict also holds
    a freshly constructed AdamW optimizer state, ``iteration`` reset to ``0``
    and a fixed ``learning_rate`` of ``0.0001``.

    Args:
        config: Path of the hyper-parameter file passed to
            ``utils.get_hparams_from_file``.
        input_model: Path of the checkpoint to read. Its top-level dict must
            contain a ``'model'`` entry.
        ishalf: When true, floating-point tensors are converted with ``.half()``.
            Non-floating tensors are kept unchanged.
        output_model: Destination path passed to ``torch.save``.
    """
    hps = utils.get_hparams_from_file(config)
    # The network is instantiated only to build a fresh optimizer whose
    # state_dict fills the 'optimizer' slot without any accumulated moments.
    net_g = SynthesizerTrn(hps.data.filter_length // 2 + 1,
                           hps.train.segment_size // hps.data.hop_length,
                           **hps.model)
    optim_g = torch.optim.AdamW(net_g.parameters(),
                                hps.train.learning_rate,
                                betas=hps.train.betas,
                                eps=hps.train.eps)

    # NOTE(review): torch.load unpickles arbitrary objects; only run this on
    # checkpoints from a trusted source.
    state_dict_g = torch.load(input_model, map_location="cpu")
    model_state = copyStateDict(state_dict_g)['model']

    # Drop "enc_q" weights (presumably a training-only sub-module - confirm).
    # Only floating-point tensors are halved: casting integer buffers to fp16
    # would lose precision for values above 2048.
    release_model = {
        k: v.half() if ishalf and v.is_floating_point() else v
        for k, v in model_state.items()
        if "enc_q" not in k
    }

    torch.save(
        {
            'model': release_model,
            'iteration': 0,
            'optimizer': optim_g.state_dict(),
            'learning_rate': 0.0001,
        },
        output_model,
    )
if __name__ == "__main__":
    import argparse
    import os.path

    parser = argparse.ArgumentParser()
    parser.add_argument("-c",
                        "--config",
                        type=str,
                        default='configs/config.json',
                        help="Path to the model config file")
    # Required: without it args.input is None and the script later crashes
    # inside os.path.splitext / torch.load with an unhelpful TypeError.
    parser.add_argument("-i", "--input", type=str, required=True,
                        help="Path to the checkpoint to compress")
    parser.add_argument("-o", "--output", type=str, default=None,
                        help="Output path (default: <input>_release[_half]<ext>)")
    parser.add_argument('-hf', '--half', action='store_true', default=False, help='Save as FP16')
    args = parser.parse_args()

    output = args.output
    if output is None:
        # Derive the output path from the input path, e.g.
        # G_1000.pth -> G_1000_release.pth (or G_1000_release_half.pth with -hf).
        filename, ext = os.path.splitext(args.input)
        half = "_half" if args.half else ""
        output = filename + "_release" + half + ext
    removeOptimizer(args.config, args.input, args.half, output)