Spaces:
Running
Running
| import sys,os | |
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| import torch | |
| import numpy as np | |
| import argparse | |
| from tqdm import tqdm | |
| from functools import partial | |
| from argparse import RawTextHelpFormatter | |
| from multiprocessing.pool import ThreadPool | |
| from speaker.models.lstm import LSTMSpeakerEncoder | |
| from speaker.config import SpeakerEncoderConfig | |
| from speaker.utils.audio import AudioProcessor | |
| from speaker.infer import read_json | |
def get_spk_wavs(dataset_path, output_path):
    """Collect .wav file paths under *dataset_path* and mirror its speaker
    sub-directories into *output_path*.

    Two layouts are handled (one directory level deep, non-recursive):
      dataset_path/<speaker>/<file>.wav  -> output_path/<speaker>/ is created
      dataset_path/<file>.wav            -> flat wav, collected as-is

    Args:
        dataset_path: root directory containing speaker folders and/or wavs.
        output_path: directory that will receive the embeddings; speaker
            sub-directories are created here so later saves succeed.

    Returns:
        list[str]: paths of every .wav file found.

    Note: the original built paths as f"./{path}/…", which silently turns an
    absolute dataset_path into a relative one; os.path.join handles both.
    """
    wav_files = []
    os.makedirs(output_path, exist_ok=True)
    for entry in os.listdir(dataset_path):
        entry_path = os.path.join(dataset_path, entry)
        if os.path.isdir(entry_path):
            # Mirror the speaker folder in the output tree.
            os.makedirs(os.path.join(output_path, entry), exist_ok=True)
            for file in os.listdir(entry_path):
                if file.endswith(".wav"):
                    wav_files.append(os.path.join(entry_path, file))
        elif entry.endswith(".wav"):
            wav_files.append(entry_path)
    return wav_files
def process_wav(wav_file, dataset_path, output_path, args, speaker_encoder_ap, speaker_encoder):
    """Compute a speaker embedding for one wav file and save it to disk.

    The target path mirrors the source path with *dataset_path* swapped for
    *output_path* and ".wav" swapped for ".spk" (np.save then appends its
    usual ".npy" suffix).
    """
    # Load audio at the encoder's expected sample rate and build a
    # (frames, mel) spectrogram tensor with a leading batch dimension.
    waveform = speaker_encoder_ap.load_wav(wav_file, sr=speaker_encoder_ap.sample_rate)
    mel = torch.from_numpy(speaker_encoder_ap.melspectrogram(waveform).T)
    if args.use_cuda:
        mel = mel.cuda()
    mel = mel.unsqueeze(0)
    # Run the encoder, then strip batch dims and move back to host memory.
    embedding = speaker_encoder.compute_embedding(mel).detach().cpu().numpy().squeeze()
    target = wav_file.replace(dataset_path, output_path).replace(".wav", ".spk")
    np.save(target, embedding, allow_pickle=False)
def extract_speaker_embeddings(wav_files, dataset_path, output_path, args, speaker_encoder_ap, speaker_encoder, concurrency):
    """Embed every wav file using a thread pool of *concurrency* workers.

    Each file is handled by process_wav; tqdm reports progress as results
    stream back from pool.imap.
    """
    worker = partial(
        process_wav,
        dataset_path=dataset_path,
        output_path=output_path,
        args=args,
        speaker_encoder_ap=speaker_encoder_ap,
        speaker_encoder=speaker_encoder,
    )
    with ThreadPool(concurrency) as pool:
        # Drain the iterator so every task actually runs; results are None.
        for _ in tqdm(pool.imap(worker, wav_files), total=len(wav_files)):
            pass
if __name__ == "__main__":

    def _str2bool(value):
        """Parse a boolean CLI value properly.

        The original used `type=bool`, which treats ANY non-empty string —
        including "False" and "0" — as True.
        """
        if isinstance(value, bool):
            return value
        return value.strip().lower() in ("1", "true", "yes", "y")

    parser = argparse.ArgumentParser(
        description="""Compute embedding vectors for each wav file in a dataset.""",
        formatter_class=RawTextHelpFormatter,
    )
    parser.add_argument("dataset_path", type=str, help="Path to dataset waves.")
    parser.add_argument(
        "output_path", type=str, help="path for output speaker/speaker_wavs.npy."
    )
    # Same flag name and default as before; only the value parsing is fixed.
    parser.add_argument("--use_cuda", type=_str2bool, help="flag to set cuda.", default=True)
    parser.add_argument("-t", "--thread_count", help="thread count to process, set 0 to use all cpu cores", dest="thread_count", type=int, default=1)
    args = parser.parse_args()

    dataset_path = args.dataset_path
    output_path = args.output_path
    thread_count = args.thread_count

    # Fall back to CPU gracefully instead of crashing on .cuda() calls.
    if args.use_cuda and not torch.cuda.is_available():
        print("CUDA requested but not available; falling back to CPU.")
        args.use_cuda = False

    # Pretrained speaker-encoder checkpoint and config (fixed project paths).
    args.model_path = os.path.join("speaker_pretrain", "best_model.pth.tar")
    args.config_path = os.path.join("speaker_pretrain", "config.json")

    # Build the encoder from the pretrained config and load its weights.
    config_dict = read_json(args.config_path)
    config = SpeakerEncoderConfig(config_dict)
    config.from_dict(config_dict)
    speaker_encoder = LSTMSpeakerEncoder(
        config.model_params["input_dim"],
        config.model_params["proj_dim"],
        config.model_params["lstm_dim"],
        config.model_params["num_lstm_layers"],
    )
    speaker_encoder.load_checkpoint(args.model_path, eval=True, use_cuda=args.use_cuda)

    # Audio preprocessing: normalize the input level and trim silences.
    speaker_encoder_ap = AudioProcessor(**config.audio)
    speaker_encoder_ap.do_sound_norm = True
    speaker_encoder_ap.do_trim_silence = True

    wav_files = get_spk_wavs(dataset_path, output_path)

    # thread_count == 0 means "use every CPU core".
    process_num = thread_count or os.cpu_count()
    extract_speaker_embeddings(wav_files, dataset_path, output_path, args, speaker_encoder_ap, speaker_encoder, process_num)