import argparse
import json
import re
from argparse import RawTextHelpFormatter

import fsspec
import numpy as np
import torch

from .models.lstm import LSTMSpeakerEncoder
from .config import SpeakerEncoderConfig
from .utils.audio import AudioProcessor
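
# NOTE: the relative imports above mean this file must be run as a module
# inside its package (python -m <package>.<module>); the exact package path
# depends on the repository layout.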


def read_json(json_path):
    """Read a JSON config file, falling back to the comment-tolerant
    reader for older configs that are not strict JSON."""
    config_dict = {}
    try:
        with fsspec.open(json_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except json.decoder.JSONDecodeError:
        # backwards compat.
        data = read_json_with_comments(json_path)
    config_dict.update(data)
    return config_dict


def read_json_with_comments(json_path):
    """Read JSON that may contain line continuations and // comments,
    for backward compatibility with older config files."""
    with fsspec.open(json_path, "r", encoding="utf-8") as f:
        input_str = f.read()
    # strip escaped line continuations, then // comments
    input_str = re.sub(r"\\\n", "", input_str)
    input_str = re.sub(r"//.*\n", "\n", input_str)
    data = json.loads(input_str)
    return data


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="""Compute embedding vectors for each wav file in a dataset.""",
        formatter_class=RawTextHelpFormatter,
    )
    parser.add_argument("model_path", type=str, help="Path to model checkpoint file.")
    parser.add_argument(
        "config_path",
        type=str,
        help="Path to model config file.",
    )
    parser.add_argument("-s", "--source", help="input wave file", dest="source")
    parser.add_argument(
        "-t", "--target", help="output 256-d speaker embedding (.npy)", dest="target"
    )
    # type=bool would treat any non-empty string (even "False") as True, so
    # use real boolean flags instead (requires Python 3.9+).
    parser.add_argument(
        "--use_cuda", action=argparse.BooleanOptionalAction, default=True, help="flag to use cuda."
    )
    parser.add_argument(
        "--eval", action=argparse.BooleanOptionalAction, default=True, help="load checkpoint in eval mode."
    )
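
    # Example invocation (paths and module name are illustrative):
    #   python -m <package>.<module> model.pth config.json \
    #       -s speaker.wav -t speaker_embedding.npy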
    args = parser.parse_args()
    source_file = args.source
    target_file = args.target

    # config
    config_dict = read_json(args.config_path)

    # model
    config = SpeakerEncoderConfig()
    config.from_dict(config_dict)
    speaker_encoder = LSTMSpeakerEncoder(
        config.model_params["input_dim"],
        config.model_params["proj_dim"],
        config.model_params["lstm_dim"],
        config.model_params["num_lstm_layers"],
    )
    # wire the --eval flag through instead of hard-coding eval=True
    speaker_encoder.load_checkpoint(args.model_path, eval=args.eval, use_cuda=args.use_cuda)
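
    # After load_checkpoint the encoder should already be in inference mode
    # (and on the GPU when use_cuda is set), so no further .eval()/.cuda()
    # calls are needed here.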

    # preprocess
    speaker_encoder_ap = AudioProcessor(**config.audio)
    # normalize the input audio level and trim silences
    speaker_encoder_ap.do_sound_norm = True
    speaker_encoder_ap.do_trim_silence = True

    # compute the speaker embedding
    waveform = speaker_encoder_ap.load_wav(
        source_file, sr=speaker_encoder_ap.sample_rate
    )
    spec = speaker_encoder_ap.melspectrogram(waveform)
    spec = torch.from_numpy(spec.T)
    if args.use_cuda:
        spec = spec.cuda()
    spec = spec.unsqueeze(0)
    embed = speaker_encoder.compute_embedding(spec).detach().cpu().numpy()
    embed = embed.squeeze()
    np.save(target_file, embed, allow_pickle=False)
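
    # The result is a 1-D float vector of length proj_dim; np.save appends a
    # ".npy" extension if target_file does not already have one.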

    # also re-save a slim checkpoint containing only the model weights,
    # unwrapping a (Distributed)DataParallel container if present
    if hasattr(speaker_encoder, "module"):
        state_dict = speaker_encoder.module.state_dict()
    else:
        state_dict = speaker_encoder.state_dict()
    torch.save({"model": state_dict}, "model_small.pth")