import abc
import functools
import io
import json
import logging
import os
import tarfile
import typing

import torch.utils.data
import torchaudio
import transformers
import vocos
from torchvision.datasets.utils import download_url

from modules.ChatTTS.ChatTTS.utils.infer_utils import (
    apply_character_map,
    count_invalid_characters,
)


class LazyDataType(typing.TypedDict):
    filepath: str
    speaker: str
    lang: str
    text: str


class DataType(LazyDataType):
    text_input_ids: torch.Tensor  # (batch_size, text_len)
    text_attention_mask: torch.Tensor  # (batch_size, text_len)
    audio_mel_specs: torch.Tensor  # (batch_size, audio_len*2, 100)
    audio_attention_mask: torch.Tensor  # (batch_size, audio_len)
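
# Shape convention used throughout this module: `audio_mel_specs` holds two
# 100-bin mel frames per `audio_attention_mask` entry, i.e. the mel length is
# `audio_len * 2` while the mask length is `audio_len`. This is why
# __getitem__ sizes the audio mask as `(len(mel) + 1) // 2` and AudioCollator
# pads mel specs to `audio_maxlen * 2`.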


class XzListTarKwargsType(typing.TypedDict):
    tokenizer: typing.Union[transformers.PreTrainedTokenizer, None]
    vocos_model: typing.Union[vocos.Vocos, None]
    device: typing.Union[str, torch.device, None]
    speakers: typing.Union[typing.Iterable[str], None]
    sample_rate: typing.Union[int, None]
    default_speaker: typing.Union[str, None]
    default_lang: typing.Union[str, None]
    tar_in_memory: typing.Union[bool, None]
    process_ahead: typing.Union[bool, None]


class AudioFolder(torch.utils.data.Dataset, abc.ABC):
    def __init__(
        self,
        root: str | io.BytesIO,
        tokenizer: transformers.PreTrainedTokenizer | None = None,
        vocos_model: vocos.Vocos | None = None,
        device: str | torch.device | None = None,
        speakers: typing.Iterable[str] | None = None,
        sample_rate: int = 24_000,
        default_speaker: str | None = None,
        default_lang: str | None = None,
        tar_path: str | None = None,
        tar_in_memory: bool = False,
        process_ahead: bool = False,
    ) -> None:
        self.root = root
        self.sample_rate = sample_rate
        self.default_speaker = default_speaker
        self.default_lang = default_lang

        self.logger = logging.getLogger(__name__)
        self.normalizer = {}

        self.tokenizer = tokenizer
        self.vocos = vocos_model
        self.vocos_device = (
            None if self.vocos is None else next(self.vocos.parameters()).device
        )
        self.device = device or self.vocos_device

        # tar -cvf ../Xz.tar *
        # tar -xf Xz.tar -C ./Xz
        self.tar_path = tar_path
        self.tar_file = None
        self.tar_io = None
        if tar_path is not None:
            if tar_in_memory:
                with open(tar_path, "rb") as f:
                    self.tar_io = io.BytesIO(f.read())
                self.tar_file = tarfile.open(fileobj=self.tar_io)
            else:
                self.tar_file = tarfile.open(tar_path)

        self.lazy_data, self.speakers = self.get_lazy_data(root, speakers)

        self.text_input_ids: dict[int, torch.Tensor] = {}
        self.audio_mel_specs: dict[int, torch.Tensor] = {}
        if process_ahead:
            for n, item in enumerate(self.lazy_data):
                self.audio_mel_specs[n] = self.preprocess_audio(item["filepath"])
                self.text_input_ids[n] = self.preprocess_text(
                    item["text"], item["lang"]
                )
            if self.tar_file is not None:
                self.tar_file.close()
            if self.tar_io is not None:
                self.tar_io.close()

    @abc.abstractmethod
    def get_raw_data(self, root: str | io.BytesIO) -> list[dict[str, str]]: ...

    @staticmethod
    @abc.abstractmethod
    def save_config(
        save_path: str, lazy_data: list[LazyDataType], rel_path: str = "./"
    ) -> None: ...

    def __len__(self):
        return len(self.lazy_data)

    def __getitem__(self, n: int) -> DataType:
        lazy_data = self.lazy_data[n]
        if n in self.audio_mel_specs:
            audio_mel_specs = self.audio_mel_specs[n]
            text_input_ids = self.text_input_ids[n]
        else:
            audio_mel_specs = self.preprocess_audio(lazy_data["filepath"])
            text_input_ids = self.preprocess_text(lazy_data["text"], lazy_data["lang"])
            self.audio_mel_specs[n] = audio_mel_specs
            self.text_input_ids[n] = text_input_ids
            if len(self.audio_mel_specs) == len(self.lazy_data):
                if self.tar_file is not None:
                    self.tar_file.close()
                if self.tar_io is not None:
                    self.tar_io.close()
        text_attention_mask = torch.ones(
            len(text_input_ids), device=text_input_ids.device
        )
        audio_attention_mask = torch.ones(
            (len(audio_mel_specs) + 1) // 2,
            device=audio_mel_specs.device,
        )
        return {
            "filepath": lazy_data["filepath"],
            "speaker": lazy_data["speaker"],
            "lang": lazy_data["lang"],
            "text": lazy_data["text"],
            "text_input_ids": text_input_ids,
            "text_attention_mask": text_attention_mask,
            "audio_mel_specs": audio_mel_specs,
            "audio_attention_mask": audio_attention_mask,
        }

    def get_lazy_data(
        self,
        root: str | io.BytesIO,
        speakers: typing.Iterable[str] | None = None,
    ) -> tuple[list[LazyDataType], set[str]]:
        if speakers is not None:
            new_speakers = set(speakers)
        else:
            new_speakers = set()
        lazy_data = []

        raw_data = self.get_raw_data(root)
        folder_path = os.path.dirname(root) if isinstance(root, str) else ""
        for item in raw_data:
            if "speaker" not in item:
                item["speaker"] = self.default_speaker
            if "lang" not in item:
                item["lang"] = self.default_lang

            if speakers is not None and item["speaker"] not in speakers:
                continue
            if speakers is None and item["speaker"] not in new_speakers:
                new_speakers.add(item["speaker"])

            if self.tar_file is None and isinstance(root, str):
                filepath = os.path.join(folder_path, item["filepath"])
            else:
                filepath = item["filepath"]

            lazy_data.append(
                {
                    "filepath": filepath,
                    "speaker": item["speaker"],
                    "lang": item["lang"].lower(),
                    "text": item["text"],
                }
            )
        return lazy_data, new_speakers

    def preprocess_text(
        self,
        text: str,
        lang: str,
    ) -> torch.Tensor:
        invalid_characters = count_invalid_characters(text)
        if len(invalid_characters):
            # self.logger.log(logging.WARNING, f'Invalid characters found! : {invalid_characters}')
            text = apply_character_map(text)

        # if not skip_refine_text:
        #     text_tokens = refine_text(self.pretrain_models, text, **params_refine_text)['ids']
        #     text_tokens = [i[i < self.pretrain_models['tokenizer'].convert_tokens_to_ids('[break_0]')] for i in text_tokens]
        #     text = self.pretrain_models['tokenizer'].batch_decode(text_tokens)
        #     if refine_text_only:
        #         return text

        text = f"[Stts][spk_emb]{text}[Ptts]"
        # text = f'[Stts][empty_spk]{text}[Ptts]'

        text_token = self.tokenizer(
            text, return_tensors="pt", add_special_tokens=False
        ).to(device=self.device)
        return text_token["input_ids"].squeeze(0)

    def preprocess_audio(self, filepath: str) -> torch.Tensor:
        if self.tar_file is not None:
            file = self.tar_file.extractfile(filepath)
            waveform, sample_rate = torchaudio.load(file)
        else:
            waveform, sample_rate = torchaudio.load(filepath)
        waveform = waveform.to(device=self.vocos_device)
        if sample_rate != self.sample_rate:
            waveform = torchaudio.functional.resample(
                waveform,
                orig_freq=sample_rate,
                new_freq=self.sample_rate,
            )
        mel_spec: torch.Tensor = self.vocos.feature_extractor(waveform)
        return (
            mel_spec.to(device=self.device).squeeze(0).transpose(0, 1)
        )  # (audio_len*2, 100)


class JsonFolder(AudioFolder):
    """
    In the json file, each item is formatted as in the following example:
    `{"filepath": "path/to/file.wav", "speaker": "John", "lang": "ZH", "text": "Hello"}`.

    `filepath` is relative to the dirname of the root json file.
    """

    def get_raw_data(self, root: str | io.BytesIO) -> list[dict[str, str]]:
        with open(root, "r", encoding="utf-8") as f:
            raw_data = json.load(f)
        return raw_data

    @staticmethod
    def save_config(
        save_path: str, lazy_data: list[LazyDataType], rel_path: str = "./"
    ) -> None:
        save_data = [item.copy() for item in lazy_data]
        for item in save_data:
            item["filepath"] = os.path.relpath(item["filepath"], rel_path)
        with open(save_path, "w", encoding="utf-8") as f:
            json.dump(save_data, f, ensure_ascii=False, indent=4)
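

# Example of a complete json config the class above can read (illustrative
# content only; paths and text are placeholders). `save_config` rewrites each
# "filepath" relative to `rel_path` before writing:
#
#   [
#       {"filepath": "speaker_A/1.wav", "speaker": "A", "lang": "ZH", "text": "你好"},
#       {"filepath": "speaker_A/2.wav", "speaker": "A", "lang": "EN", "text": "Hello"}
#   ]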


class ListFolder(AudioFolder):
    """
    In the list file, each row is formatted as `filepath|speaker|lang|text`,
    with `|` as the separator, e.g. `path/to/file.wav|John|ZH|Hello`.

    `filepath` is relative to the dirname of the root list file.
    """

    def get_raw_data(self, root: str | io.BytesIO) -> list[dict[str, str]]:
        raw_data = []
        with open(root, "r", encoding="utf-8") as f:
            for line in f.readlines():
                line = line.strip().removesuffix("\n")
                if len(line) == 0:
                    continue
                filepath, speaker, lang, text = line.split(sep="|", maxsplit=3)
                raw_data.append(
                    {
                        "text": text,
                        "filepath": filepath,
                        "speaker": speaker,
                        "lang": lang,
                    }
                )
        return raw_data

    @staticmethod
    def save_config(
        save_path: str, lazy_data: list[LazyDataType], rel_path: str = "./"
    ) -> None:
        save_data = [item.copy() for item in lazy_data]
        for item in save_data:
            item["filepath"] = os.path.relpath(item["filepath"], rel_path)
        with open(save_path, "w", encoding="utf-8") as f:
            for item in save_data:
                f.write(
                    f"{item['filepath']}|{item['speaker']}|{item['lang']}|{item['text']}\n"
                )
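

# The equivalent list-file config for the json example above (illustrative
# content only); one `filepath|speaker|lang|text` row per clip:
#
#   speaker_A/1.wav|A|ZH|你好
#   speaker_A/2.wav|A|EN|Hello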


class XzListTar(ListFolder):
    def __init__(
        self,
        *args,
        root: str | io.BytesIO,
        tar_path: str | None = None,
        **kwargs,
    ):
        if isinstance(root, io.BytesIO):
            assert tar_path is not None
        else:
            # make sure root is a list file
            if not root.endswith(".list"):  # folder case
                if os.path.isfile(root):
                    raise FileExistsError(f"{root} is a file!")
                elif not os.path.exists(root):
                    os.makedirs(root)
                root = os.path.join(root, "all.list")
        if isinstance(root, str) and not os.path.isfile(root):
            # prepare all.list
            self.concat_dataset(
                save_folder=os.path.dirname(root),
                langs=kwargs.get("langs", ["zh", "en"]),
            )

        super().__init__(root, *args, tar_path=tar_path, **kwargs)

    def concat_dataset(
        self, save_folder: str | None = None, langs: list[str] = ["zh", "en"]
    ) -> None:
        if save_folder is None:
            save_folder = os.path.dirname(self.root)
        if os.path.isfile(save_folder):
            raise FileExistsError(f"{save_folder} already exists as a file!")
        elif not os.path.exists(save_folder):
            os.makedirs(save_folder)
        lazy_data = []

        for member in self.tar_file.getmembers():
            if not member.isfile():
                continue
            if member.name.endswith(".list"):
                print(member.name)
                root_io = self.tar_file.extractfile(member)
                lazy_data += ListFolder(root_io).lazy_data
            if member.name.endswith(".json"):
                print(member.name)
                root_io = self.tar_file.extractfile(member)
                lazy_data += JsonFolder(root_io).lazy_data
        if langs is not None:
            lazy_data = [item for item in lazy_data if item["lang"] in langs]
        ListFolder.save_config(os.path.join(save_folder, "all.list"), lazy_data)
        JsonFolder.save_config(os.path.join(save_folder, "all.json"), lazy_data)
        print(f"all.list and all.json are saved to {save_folder}")


class XzListFolder(ListFolder):
    """
    [Xz乔希](https://space.bilibili.com/5859321)

    Only the basename of each filepath in the list file is used; any leading
    folder components are ignored. Files are organized as
    `[list basename]/[file basename]`.

    Example tree structure:

    [folder]
    ├── speaker_A
    │   ├── 1.wav
    │   └── 2.wav
    ├── speaker_A.list
    ├── speaker_B
    │   ├── 1.wav
    │   └── 2.wav
    └── speaker_B.list
    """

    def get_raw_data(self, root: str | io.BytesIO) -> list[dict[str, str]]:
        raw_data = super().get_raw_data(root)
        for item in raw_data:
            item["filepath"] = os.path.join(
                os.path.basename(root).removesuffix(".list"),
                os.path.basename(item["filepath"]),
            )
        return raw_data
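

# Example of XzListFolder's path remapping (illustrative values): a row such as
#
#   D:/raw/clips/001.wav|speaker_A|ZH|你好
#
# inside `speaker_A.list` resolves to `speaker_A/001.wav`, i.e.
# `[list basename without .list]/[audio basename]`, relative to the folder that
# contains the .list file (or to the tar root when a tar archive is used).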


class AudioCollator:
    def __init__(self, text_pad: int = 0, audio_pad: int = 0):
        self.text_pad = text_pad
        self.audio_pad = audio_pad

    def __call__(self, batch: list[DataType]):
        batch = [x for x in batch if x is not None]

        audio_maxlen = max(len(item["audio_attention_mask"]) for item in batch)
        text_maxlen = max(len(item["text_attention_mask"]) for item in batch)

        filepath = []
        speaker = []
        lang = []
        text = []
        text_input_ids = []
        text_attention_mask = []
        audio_mel_specs = []
        audio_attention_mask = []

        for x in batch:
            filepath.append(x["filepath"])
            speaker.append(x["speaker"])
            lang.append(x["lang"])
            text.append(x["text"])
            text_input_ids.append(
                torch.nn.functional.pad(
                    x["text_input_ids"],
                    (text_maxlen - len(x["text_input_ids"]), 0),
                    value=self.text_pad,
                )
            )
            text_attention_mask.append(
                torch.nn.functional.pad(
                    x["text_attention_mask"],
                    (text_maxlen - len(x["text_attention_mask"]), 0),
                    value=0,
                )
            )
            audio_mel_specs.append(
                torch.nn.functional.pad(
                    x["audio_mel_specs"],
                    (0, 0, 0, audio_maxlen * 2 - len(x["audio_mel_specs"])),
                    value=self.audio_pad,
                )
            )
            audio_attention_mask.append(
                torch.nn.functional.pad(
                    x["audio_attention_mask"],
                    (0, audio_maxlen - len(x["audio_attention_mask"])),
                    value=0,
                )
            )
        return {
            "filepath": filepath,
            "speaker": speaker,
            "lang": lang,
            "text": text,
            "text_input_ids": torch.stack(text_input_ids),
            "text_attention_mask": torch.stack(text_attention_mask),
            "audio_mel_specs": torch.stack(audio_mel_specs),
            "audio_attention_mask": torch.stack(audio_attention_mask),
        }
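

# Minimal sketch (not part of the original API) of wiring AudioCollator into a
# torch DataLoader; the batch size and pad values below are illustrative
# defaults, not values prescribed by this module.
def make_dataloader(
    dataset: AudioFolder,
    batch_size: int = 8,
    text_pad: int = 0,
    audio_pad: int = 0,
) -> torch.utils.data.DataLoader:
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=AudioCollator(text_pad=text_pad, audio_pad=audio_pad),
    )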


def formalize_xz_list(src_folder: str):
    for root, _, files in os.walk(src_folder):
        for file in files:
            if file.endswith(".list"):
                filepath = os.path.join(root, file)
                print(filepath)
                lazy_data = XzListFolder(filepath).lazy_data
                XzListFolder.save_config(filepath, lazy_data, rel_path=src_folder)


def concat_dataset(
    src_folder: str, save_folder: str | None = None, langs: list[str] = ["zh", "en"]
) -> None:
    if save_folder is None:
        save_folder = src_folder
    if os.path.isfile(save_folder):
        raise FileExistsError(f"{save_folder} already exists as a file!")
    elif not os.path.exists(save_folder):
        os.makedirs(save_folder)
    lazy_data = []
    same_folder = os.path.samefile(src_folder, save_folder)

    for root, _, files in os.walk(src_folder):
        for file in files:
            filepath = os.path.join(root, file)
            if same_folder and file in ("all.list", "all.json"):
                continue
            if file.endswith(".list"):
                print(filepath)
                lazy_data += ListFolder(filepath).lazy_data
            if file.endswith(".json"):
                print(filepath)
                lazy_data += JsonFolder(filepath).lazy_data
    if langs is not None:
        lazy_data = [item for item in lazy_data if item["lang"] in langs]
    ListFolder.save_config(
        os.path.join(save_folder, "all.list"), lazy_data, rel_path=save_folder
    )
    JsonFolder.save_config(
        os.path.join(save_folder, "all.json"), lazy_data, rel_path=save_folder
    )
    print(f"all.list and all.json are saved to {save_folder}")
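

# Usage sketch for the tar-based workflow. Everything below is an assumption
# made for illustration: the tokenizer/vocos objects are loaded elsewhere, the
# tar was packed as in the comments inside AudioFolder.__init__, and
# data/Xz/all.list has already been generated (e.g. by concat_dataset) with
# paths matching the tar member names.
def example_build_tar_dataset(
    tokenizer: transformers.PreTrainedTokenizer,
    vocos_model: vocos.Vocos,
) -> XzListTar:
    return XzListTar(
        root="data/Xz/all.list",  # placeholder path to the merged list file
        tar_path="data/Xz.tar",  # placeholder path to the packed audio archive
        tokenizer=tokenizer,
        vocos_model=vocos_model,
        tar_in_memory=True,  # keep the whole tar in RAM while preprocessing
        process_ahead=True,  # preprocess every clip up front, then close the tar
    )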