# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
import random

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence

from utils.data_utils import *
from models.base.base_dataset import (
    BaseCollator,
    BaseDataset,
    BaseTestDataset,
    BaseTestCollator,
)
from text import text_to_sequence


class FS2Dataset(BaseDataset):
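    """Training/validation dataset for FastSpeech2.

    Besides the features prepared by BaseDataset, it loads per-utterance phone
    durations, phone labels, optional frame- or phone-level pitch and energy,
    and a speaker map when a spk2id.json file is available.
    """
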
    def __init__(self, cfg, dataset, is_valid=False):
        BaseDataset.__init__(self, cfg, dataset, is_valid=is_valid)
        self.batch_size = cfg.train.batch_size
        cfg = cfg.preprocess

        # utt2duration
        self.utt2duration_path = {}
        for utt_info in self.metadata:
            dataset = utt_info["Dataset"]
            uid = utt_info["Uid"]
            utt = "{}_{}".format(dataset, uid)
            self.utt2duration_path[utt] = os.path.join(
                cfg.processed_dir,
                dataset,
                cfg.duration_dir,
                uid + ".npy",
            )
        self.utt2dur = self.read_duration()
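
        # utt2energy (frame-level or phone-level, depending on config)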
        if cfg.use_frame_energy:
            self.frame_utt2energy, self.energy_statistic = load_energy(
                self.metadata,
                cfg.processed_dir,
                cfg.energy_dir,
                use_log_scale=cfg.use_log_scale_energy,
                utt2spk=self.utt2spk if cfg.use_spkid else None,
                return_norm=True,
            )
        elif cfg.use_phone_energy:
            self.phone_utt2energy, self.energy_statistic = load_energy(
                self.metadata,
                cfg.processed_dir,
                cfg.phone_energy_dir,
                use_log_scale=cfg.use_log_scale_energy,
                utt2spk=self.utt2spk if cfg.use_spkid else None,
                return_norm=True,
            )
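
        # utt2pitch (frame-level or phone-level, depending on config)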
        if cfg.use_frame_pitch:
            self.frame_utt2pitch, self.pitch_statistic = load_energy(
                self.metadata,
                cfg.processed_dir,
                cfg.pitch_dir,
                use_log_scale=cfg.use_log_scale_pitch,
                utt2spk=self.utt2spk if cfg.use_spkid else None,
                return_norm=True,
            )
        elif cfg.use_phone_pitch:
            self.phone_utt2pitch, self.pitch_statistic = load_energy(
                self.metadata,
                cfg.processed_dir,
                cfg.phone_pitch_dir,
                use_log_scale=cfg.use_log_scale_pitch,
                utt2spk=self.utt2spk if cfg.use_spkid else None,
                return_norm=True,
            )
        # utt2lab
        self.utt2lab_path = {}
        for utt_info in self.metadata:
            dataset = utt_info["Dataset"]
            uid = utt_info["Uid"]
            utt = "{}_{}".format(dataset, uid)
            self.utt2lab_path[utt] = os.path.join(
                cfg.processed_dir,
                dataset,
                cfg.lab_dir,
                uid + ".txt",
            )
        # speaker map
        self.speaker_map = {}
        spk2id_file = os.path.join(cfg.processed_dir, "spk2id.json")
        if os.path.exists(spk2id_file):
            with open(spk2id_file) as f:
                self.speaker_map = json.load(f)

        self.metadata = self.check_metadata()

    def __getitem__(self, index):
        single_feature = BaseDataset.__getitem__(self, index)

        utt_info = self.metadata[index]
        dataset = utt_info["Dataset"]
        uid = utt_info["Uid"]
        utt = "{}_{}".format(dataset, uid)

        duration = self.utt2dur[utt]

        # text
        with open(self.utt2lab_path[utt], "r") as f:
            phones = f.readlines()[0].strip()
        # todo: add cleaner(chenxi)
        phones_ids = np.array(text_to_sequence(phones, ["english_cleaners"]))
        text_len = len(phones_ids)
        if self.cfg.preprocess.use_frame_pitch:
            pitch = self.frame_utt2pitch[utt]
        elif self.cfg.preprocess.use_phone_pitch:
            pitch = self.phone_utt2pitch[utt]
        if self.cfg.preprocess.use_frame_energy:
            energy = self.frame_utt2energy[utt]
        elif self.cfg.preprocess.use_phone_energy:
            energy = self.phone_utt2energy[utt]

        # speaker
        if len(self.speaker_map) > 0:
            speaker_id = self.speaker_map[utt_info["Singer"]]
        else:
            speaker_id = 0

        single_feature.update(
            {
                "durations": duration,
                "texts": phones_ids,
                "spk_id": speaker_id,
                "text_len": text_len,
                "pitch": pitch,
                "energy": energy,
                "uid": uid,
            }
        )

        return self.clip_if_too_long(single_feature)

    def read_duration(self):
        # read duration
        utt2dur = {}
        for index in range(len(self.metadata)):
            utt_info = self.metadata[index]
            dataset = utt_info["Dataset"]
            uid = utt_info["Uid"]
            utt = "{}_{}".format(dataset, uid)

            if not os.path.exists(self.utt2mel_path[utt]) or not os.path.exists(
                self.utt2duration_path[utt]
            ):
                continue

            mel = np.load(self.utt2mel_path[utt]).transpose(1, 0)
            duration = np.load(self.utt2duration_path[utt])
            assert mel.shape[0] == sum(
                duration
            ), f"{utt}: mismatch length between mel {mel.shape[0]} and sum(duration) {sum(duration)}"
            utt2dur[utt] = duration

        return utt2dur

    def __len__(self):
        return len(self.metadata)

    def random_select(self, feature_seq_len, max_seq_len, ending_ts=2812):
        """
        ending_ts: to avoid invalid Whisper features for audio longer than 30 s,
            2812 = 30 * 24000 // 256 (30 s at a 24 kHz sample rate with hop size 256)
        """
        ts = max(feature_seq_len - max_seq_len, 0)
        ts = min(ts, ending_ts - max_seq_len)

        start = random.randint(0, ts)
        end = start + max_seq_len
        return start, end

    def clip_if_too_long(self, sample, max_seq_len=1000):
        """
        sample:
            {
                "spk_id": (1,),
                "target_len": int,
                "mel": (seq_len, dim),
                "frame_pitch": (seq_len,),
                "frame_energy": (seq_len,),
                "content_vector_feat": (seq_len, dim),
            }
        """
        if sample["target_len"] <= max_seq_len:
            return sample

        start, end = self.random_select(sample["target_len"], max_seq_len)
        sample["target_len"] = end - start

        for k in sample.keys():
            if k not in ["spk_id", "target_len"]:
                sample[k] = sample[k][start:end]

        return sample

    def check_metadata(self):
        new_metadata = []
        for utt_info in self.metadata:
            dataset = utt_info["Dataset"]
            uid = utt_info["Uid"]
            utt = "{}_{}".format(dataset, uid)
            if not os.path.exists(self.utt2duration_path[utt]) or not os.path.exists(
                self.utt2mel_path[utt]
            ):
                continue
            else:
                new_metadata.append(utt_info)
        return new_metadata


class FS2Collator(BaseCollator):
    """Zero-pads model inputs and targets based on number of frames per step"""

    def __init__(self, cfg):
        BaseCollator.__init__(self, cfg)
        self.sort = cfg.train.sort_sample
        self.batch_size = cfg.train.batch_size
        self.drop_last = cfg.train.drop_last

    def __call__(self, batch):
        # mel: [b, T, n_mels]
        # frame_pitch, frame_energy: [1, T]
        # target_len: [1]
        # spk_id: [b, 1]
        # mask: [b, T, 1]

        packed_batch_features = dict()
        for key in batch[0].keys():
            if key == "target_len":
                packed_batch_features["target_len"] = torch.LongTensor(
                    [b["target_len"] for b in batch]
                )
                masks = [
                    torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch
                ]
                packed_batch_features["mask"] = pad_sequence(
                    masks, batch_first=True, padding_value=0
                )
            elif key == "text_len":
                packed_batch_features["text_len"] = torch.LongTensor(
                    [b["text_len"] for b in batch]
                )
                masks = [
                    torch.ones((b["text_len"], 1), dtype=torch.long) for b in batch
                ]
                packed_batch_features["text_mask"] = pad_sequence(
                    masks, batch_first=True, padding_value=0
                )
            elif key == "spk_id":
                packed_batch_features["spk_id"] = torch.LongTensor(
                    [b["spk_id"] for b in batch]
                )
            elif key == "uid":
                packed_batch_features[key] = [b["uid"] for b in batch]
            else:
                values = [torch.from_numpy(b[key]) for b in batch]
                packed_batch_features[key] = pad_sequence(
                    values, batch_first=True, padding_value=0
                )

        return packed_batch_features


class FS2TestDataset(BaseTestDataset):
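    """Inference dataset for FastSpeech2: loads phone labels (and speaker ids,
    when a speaker map exists) for the utterances listed in a test metafile or
    in a predefined testing-set JSON.
    """
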
    def __init__(self, args, cfg, infer_type=None):
        datasets = cfg.dataset
        cfg = cfg.preprocess

        is_bigdata = False

        assert len(datasets) >= 1
        if len(datasets) > 1:
            datasets.sort()
            bigdata_version = "_".join(datasets)
            processed_data_dir = os.path.join(cfg.processed_dir, bigdata_version)
            is_bigdata = True
        else:
            processed_data_dir = os.path.join(cfg.processed_dir, args.dataset)

        if args.test_list_file:
            self.metafile_path = args.test_list_file
            self.metadata = self.get_metadata()
        else:
            assert args.testing_set
            source_metafile_path = os.path.join(
                cfg.processed_dir,
                args.dataset,
                "{}.json".format(args.testing_set),
            )
            with open(source_metafile_path, "r") as f:
                self.metadata = json.load(f)

        self.cfg = cfg
        self.datasets = datasets
        self.data_root = processed_data_dir
        self.is_bigdata = is_bigdata
        self.source_dataset = args.dataset

        ######### Load source acoustic features #########
        if cfg.use_spkid:
            spk2id_path = os.path.join(self.data_root, cfg.spk2id)
            utt2sp_path = os.path.join(self.data_root, cfg.utt2spk)
            self.spk2id, self.utt2spk = get_spk_map(spk2id_path, utt2sp_path, datasets)

        # utt2lab
        self.utt2lab_path = {}
        for utt_info in self.metadata:
            dataset = utt_info["Dataset"]
            uid = utt_info["Uid"]
            utt = "{}_{}".format(dataset, uid)
            self.utt2lab_path[utt] = os.path.join(
                cfg.processed_dir,
                dataset,
                cfg.lab_dir,
                uid + ".txt",
            )
        # speaker map
        self.speaker_map = {}
        spk2id_file = os.path.join(cfg.processed_dir, "spk2id.json")
        if os.path.exists(spk2id_file):
            with open(spk2id_file) as f:
                self.speaker_map = json.load(f)

    def __getitem__(self, index):
        single_feature = {}

        utt_info = self.metadata[index]
        dataset = utt_info["Dataset"]
        uid = utt_info["Uid"]
        utt = "{}_{}".format(dataset, uid)

        # text
        with open(self.utt2lab_path[utt], "r") as f:
            phones = f.readlines()[0].strip()
        phones_ids = np.array(text_to_sequence(phones, self.cfg.text_cleaners))
        text_len = len(phones_ids)

        # speaker
        if len(self.speaker_map) > 0:
            speaker_id = self.speaker_map[utt_info["Singer"]]
        else:
            speaker_id = 0

        single_feature.update(
            {
                "texts": phones_ids,
                "spk_id": speaker_id,
                "text_len": text_len,
            }
        )

        return single_feature

    def __len__(self):
        return len(self.metadata)

    def get_metadata(self):
        with open(self.metafile_path, "r", encoding="utf-8") as f:
            metadata = json.load(f)

        return metadata


class FS2TestCollator(BaseTestCollator):
    """Zero-pads model inputs and targets based on number of frames per step"""

    def __init__(self, cfg):
        self.cfg = cfg

    def __call__(self, batch):
        packed_batch_features = dict()

        # mel: [b, T, n_mels]
        # frame_pitch, frame_energy: [1, T]
        # target_len: [1]
        # spk_id: [b, 1]
        # mask: [b, T, 1]

        for key in batch[0].keys():
            if key == "target_len":
                packed_batch_features["target_len"] = torch.LongTensor(
                    [b["target_len"] for b in batch]
                )
                masks = [
                    torch.ones((b["target_len"], 1), dtype=torch.long) for b in batch
                ]
                packed_batch_features["mask"] = pad_sequence(
                    masks, batch_first=True, padding_value=0
                )
            elif key == "text_len":
                packed_batch_features["text_len"] = torch.LongTensor(
                    [b["text_len"] for b in batch]
                )
                masks = [
                    torch.ones((b["text_len"], 1), dtype=torch.long) for b in batch
                ]
                packed_batch_features["text_mask"] = pad_sequence(
                    masks, batch_first=True, padding_value=0
                )
            elif key == "spk_id":
                packed_batch_features["spk_id"] = torch.LongTensor(
                    [b["spk_id"] for b in batch]
                )
            else:
                values = [torch.from_numpy(b[key]) for b in batch]
                packed_batch_features[key] = pad_sequence(
                    values, batch_first=True, padding_value=0
                )

        return packed_batch_features
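

# A minimal usage sketch (not part of the original module). It assumes `cfg`
# is the parsed Amphion configuration whose `cfg.train.*` / `cfg.preprocess.*`
# fields are read above, and "LJSpeech" stands in for a preprocessed dataset
# name; both are illustrative assumptions rather than guaranteed values.
#
#   from torch.utils.data import DataLoader
#
#   train_set = FS2Dataset(cfg, "LJSpeech", is_valid=False)
#   train_loader = DataLoader(
#       train_set,
#       batch_size=cfg.train.batch_size,
#       shuffle=True,
#       collate_fn=FS2Collator(cfg),
#   )
#   for batch in train_loader:
#       # e.g., batch["texts"]: (B, max_text_len), batch["mask"]: (B, max_target_len, 1)
#       break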