| """Semantic tokens loading logic. | |
| Copyright PolyAI Limited. | |
| """ | |
| import json | |
| import logging | |
| import random | |
| import re | |
| from logging import getLogger | |
| from pathlib import Path | |
| from typing import List, Pattern, Union | |
| import numpy as np | |
| import torch | |
| from phonemizer.backend import EspeakBackend | |
| from phonemizer.backend.espeak.language_switch import LanguageSwitch | |
| from phonemizer.backend.espeak.words_mismatch import WordMismatch | |
| from phonemizer.punctuation import Punctuation | |
| from phonemizer.separator import Separator | |
| from torch.utils.data import DataLoader, Dataset | |
| from tqdm import tqdm | |
| from data.collation import get_text_semantic_token_collater | |

class TextTokenizer:
    """Phonemize text."""

    def __init__(
        self,
        language="en-us",
        backend="espeak",
        separator=Separator(word="_", syllable="-", phone="|"),
        preserve_punctuation=True,
        punctuation_marks: Union[str, Pattern] = Punctuation.default_marks(),
        with_stress: bool = False,
        tie: Union[bool, str] = False,
        language_switch: LanguageSwitch = "keep-flags",
        words_mismatch: WordMismatch = "ignore",
    ) -> None:
        logger = getLogger("phonemizer")
        logger.setLevel(logging.ERROR)
        if backend == "espeak":
            phonemizer = EspeakBackend(
                language,
                punctuation_marks=punctuation_marks,
                preserve_punctuation=preserve_punctuation,
                with_stress=with_stress,
                tie=tie,
                language_switch=language_switch,
                words_mismatch=words_mismatch,
                logger=logger,
            )
        else:
            raise NotImplementedError(f"{backend}")

        self.backend = phonemizer
        self.separator = separator
    def to_list(self, phonemized: str) -> List[str]:
        fields = []
        for word in phonemized.split(self.separator.word):
            # Split each phonemized word (e.g. "ɐ m|iː|n?" ɹ|ɪ|z|ɜː|v; h|ɪ|z.)
            # into phones and punctuation, dropping the phone separator.
            pp = re.findall(r"\w+|[^\w\s]", word, re.UNICODE)
            fields.extend(
                [p for p in pp if p != self.separator.phone] + [self.separator.word]
            )
        # Sanity check: only the phone separators should have been dropped.
        assert len("".join(fields[:-1])) == len(phonemized) - phonemized.count(
            self.separator.phone
        )
        return fields[:-1]

    def __call__(self, text, strip=True) -> List[List[str]]:
        if isinstance(text, str):
            text = [text]

        phonemized = self.backend.phonemize(
            text, separator=self.separator, strip=strip, njobs=1
        )
        return [self.to_list(p) for p in phonemized]
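
# Illustrative usage sketch (assumes espeak-ng is installed for the espeak
# backend); the exact phoneme symbols depend on the espeak version/language:
#   tokenizer = TextTokenizer()
#   tokenizer("hello world")
#   # -> one phoneme list per input string, with "_" marking word boundaries,
#   #    e.g. roughly [['h', 'ə', 'l', 'oʊ', '_', 'w', 'ɜː', 'l', 'd']]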

class Collator:
    """Pad variable-length examples into a single batch."""

    def collate(self, batch):
        input_ids = [item["input_ids"] for item in batch]
        output_sequences = [item["labels"] for item in batch]

        # Pad sequences to the maximum length in the batch.
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=0
        )
        output_sequences = torch.nn.utils.rnn.pad_sequence(
            output_sequences, batch_first=True, padding_value=-100
        )

        # 1 - token is unmasked, 0 - token is masked.
        attention_mask = input_ids != 0

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": output_sequences,
        }
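
# Note on the padding conventions above: input_ids are right-padded with 0 and
# the attention_mask is False over that padding, while labels are padded with
# -100, the default ignore_index of torch.nn.CrossEntropyLoss, so padded
# positions do not contribute to the loss. For a hypothetical batch of two
# items with lengths 3 and 5, all three tensors come out with shape (2, 5).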

class ConcatenateSemanticDataset(Dataset):
    """Pair phonemized text with semantic token sequences from JSON manifests.

    Each manifest maps "<utterance>.wav" keys to entries with "duration",
    "text" and "phoneme" fields; the corresponding semantic tokens are loaded
    from <manifest_dir>/audios-speech-tokenizer/semantic/<utterance>.npy.
    """

    def __init__(
            self, manifest_path: List[str], symbol_table_path: str,
            n_samples: int = 0, max_duration: float = 15):
        self.data = []
        self.phonemizer = TextTokenizer()
        self.text_collater = get_text_semantic_token_collater(
            symbol_table_path)
        self.manifest_path = manifest_path
        self.n_samples = n_samples
        self.max_duration = max_duration

        if manifest_path is not None:
            self._build()
    def __len__(self):
        if self.n_samples:
            return min(self.n_samples, len(self.data))
        return len(self.data)

    def remove_unknown_symbols(self, text: List[str]) -> List[str]:
        res = []
        for sym in text:
            if sym not in self.text_collater.token2idx:
                # print(f'{sym} is unk')
                continue
            res.append(sym)
        return res
    def __getitem__(self, idx):
        item = self.data[idx]
        input_ids = item["phoneme"].split("|")
        input_ids = self.remove_unknown_symbols(input_ids)

        # Optional second segment (present when two utterances are concatenated).
        input_ids_2 = None
        if item.get("phoneme_2"):
            input_ids_2 = item["phoneme_2"].split("|")
            input_ids_2 = [self.remove_unknown_symbols(input_ids_2)]

        input_ids = self.text_collater(
            [input_ids], input_ids_2).to(dtype=torch.long)

        labels = np.load(item["semantic_path"])
        labels = [str(lbl) for lbl in labels]

        labels_2 = None
        if item.get("semantic_path_2"):
            labels_2 = np.load(item["semantic_path_2"])
            labels_2 = [[str(lbl) for lbl in labels_2]]

        labels = self.text_collater([labels], labels_2).to(dtype=torch.long)

        return {"input_ids": input_ids.squeeze(0), "labels": labels.squeeze(0)}
    # TODO: avoid loading the whole dataset into memory.
    def _build(self):
        for manifest_path in self.manifest_path:
            dataset_path = Path(manifest_path).parent
            with open(manifest_path, "r") as manifest_file:
                manifest_data = json.load(manifest_file)

            for key, value in tqdm(manifest_data.items()):
                if float(value["duration"]) > self.max_duration:
                    continue
                text = value["text"]
                phoneme = value["phoneme"]
                npy_path = f"{dataset_path}/audios-speech-tokenizer/semantic/{key.split('.wav')[0]}.npy"  # noqa
                datapoint = {
                    "text": text,
                    "semantic_path": npy_path,
                    "phoneme": phoneme,
                }
                self.data.append(datapoint)

            print(f"Total length of the dataset {manifest_path}: {len(self.data)}")

        random.shuffle(self.data)
| if __name__ == "__main__": | |
| # Create an instance of the dataset | |
| manifest_path = "datasets/ljspeech-training-data/dev.json" | |
| text_tokens_file = "ckpt/unique_text_tokens.k2symbols" | |
| seq2seq_dataset = ConcatenateSemanticDataset( | |
| [manifest_path, manifest_path], text_tokens_file) | |
| # seq2seq_dataset.phonemize_and_rewrite_manifest() | |
| batch_size = 1 # Adjust to your desired batch size | |
| dataloader = DataLoader( | |
| seq2seq_dataset, | |
| batch_size=batch_size, | |
| shuffle=True, | |
| collate_fn=Collator().collate, | |
| ) | |
| for batch in dataloader: | |
| print(batch["input_ids"]) | |
| print(batch["labels"]) | |
| print(batch["input_ids"][0].unique().max()) | |
| print(batch["input_ids"][0].unique().min()) | |
| print(batch["input_ids"].shape) | |
| print(batch["labels"].shape) | |
| break # Stop after the first batch if needed | |