| """Collators for T2S and S2A. | |
| Copyright PolyAI Limited. | |
| """ | |
| from pathlib import Path | |
| from typing import List, Tuple, Union | |
| import numpy as np | |
| import torch | |
| from utils.symbol_table import SymbolTable | |


class GlobalCollater:
    """Batches (speaker, acoustic codes, file name, semantic tokens) examples.

    Each batch item is a 5-tuple:
    (speaker_embedding [T, 512], quantize_input [T, n_q],
    quantize_output [T, n_q], file_name, semantic_tokens [T]).
    Sequences are padded to the longest item in the batch: acoustic codes
    with `n_codes` and semantic tokens with `n_semantic_codes`.
    """

    def __init__(self, n_codes, n_semantic_codes):
        self.n_codes = n_codes
        self.sem_mask_id = n_semantic_codes

    def collate(self, batch):
        output = {
            'speaker': [],
            'tts_quantize_input': [],
            'tts_quantize_output': [],
            'quantize_mask': [],
            'f_names': [],
            'semantic_tokens': [],
            'quantization_lengths': [],
        }

        # Get the max length of everything
        max_len_q = 0
        for _, q_s, q_e, _, _ in batch:
            if len(q_s) > max_len_q:
                max_len_q = len(q_s)
            output['quantization_lengths'].append(len(q_s))

        # Pad each element, create mask
        for spkr, qs, qe, itm_name, s_tokens in batch:
            # Deal with quantizations
            q_mask = np.array(
                [False] * len(qs) + [True] * (max_len_q - len(qs)))
            qs = np.pad(
                qs,
                [[0, max_len_q - len(qs)], [0, 0]],
                constant_values=self.n_codes
            )
            qe = np.pad(
                qe,
                [[0, max_len_q - len(qe)], [0, 0]],
                constant_values=self.n_codes
            )

            # Deal with semantics
            s_tokens = s_tokens.flatten()
            s_tokens = np.pad(
                s_tokens,
                (0, max_len_q - len(s_tokens)),
                constant_values=self.sem_mask_id
            )

            # Speaker padding
            spkr = np.concatenate(
                (spkr, np.zeros((max_len_q - len(spkr), 512))))

            # Aggregate
            output['speaker'].append(spkr)
            output['tts_quantize_input'].append(qs)
            output['tts_quantize_output'].append(qe)
            output['quantize_mask'].append(q_mask)
            output['f_names'].append(itm_name)
            output['semantic_tokens'].append(s_tokens)

        # Convert everything but the file names to tensors
        for k in output.keys():
            if k == 'f_names':
                continue
            output[k] = np.array(output[k])
            if 'mask' in k:
                output[k] = torch.BoolTensor(output[k])
            elif k in [
                'tts_quantize_input', 'tts_quantize_output',
                'semantic_tokens', 'quantization_lengths'
            ]:
                output[k] = torch.LongTensor(output[k])
            else:
                output[k] = torch.FloatTensor(output[k])

        return output


class TextTokenCollater:
    """Turns text into id tensors using a fixed symbol vocabulary.

    The vocabulary is built from `text_tokens` plus the pad/bos/eos and
    speaker-turn symbols; texts are tokenized character by character.
    When `texts_2` is given, the second speaker's text is appended after
    the `spkr_2` symbol.
    """

    def __init__(
        self,
        text_tokens: List[str],
        add_eos: bool = True,
        add_bos: bool = True,
        pad_symbol: str = "<pad>",
        bos_symbol: str = "<bos>",
        eos_symbol: str = "<eos>",
        spkr_1_symbol: str = "spkr_1",
        spkr_2_symbol: str = "spkr_2",
    ):
        self.pad_symbol = pad_symbol
        self.add_eos = add_eos
        self.add_bos = add_bos
        self.bos_symbol = bos_symbol
        self.eos_symbol = eos_symbol
        self.spkr_1_symbol = spkr_1_symbol
        self.spkr_2_symbol = spkr_2_symbol

        unique_tokens = (
            [pad_symbol]
            + ([bos_symbol] if add_bos else [])
            + ([eos_symbol] if add_eos else [])
            + [spkr_1_symbol]
            + [spkr_2_symbol]
            + sorted(text_tokens)
        )

        self.token2idx = {token: idx for idx, token in enumerate(unique_tokens)}
        self.idx2token = [token for token in unique_tokens]

    def __call__(
        self, texts: List[str], texts_2: Union[None, List[str]] = None
    ) -> torch.Tensor:
        tokens_seqs = [[p for p in text] for text in texts]

        if texts_2 is None:
            seqs = [
                ([self.bos_symbol] if self.add_bos else [])
                + [self.spkr_1_symbol]
                + list(seq)
                + ([self.eos_symbol] if self.add_eos else [])
                for seq in tokens_seqs
            ]
        else:
            tokens_seqs_2 = [[p for p in text] for text in texts_2]
            seqs = [
                ([self.bos_symbol] if self.add_bos else [])
                + [self.spkr_1_symbol]
                + list(seq)
                + [self.spkr_2_symbol]
                + list(seq_2)
                + ([self.eos_symbol] if self.add_eos else [])
                for seq, seq_2 in zip(tokens_seqs, tokens_seqs_2)
            ]

        tokens_batch = torch.from_numpy(
            np.array(
                [[self.token2idx[token] for token in seq] for seq in seqs],
                dtype=np.int64,
            )
        )

        return tokens_batch
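

# Illustrative example (hypothetical inputs): with text_tokens=["a", "b", "c"],
# calling the collater on ["ab", "cb"] returns a LongTensor whose rows encode
# <bos> spkr_1 a b <eos> and <bos> spkr_1 c b <eos>. All texts in one call
# must yield sequences of equal length, since __call__ stacks them with
# np.array and applies no padding.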


def get_text_token_collater(text_tokens_file: str) -> TextTokenCollater:
    text_tokens_path = Path(text_tokens_file)
    unique_tokens = SymbolTable.from_file(text_tokens_path)
    collater = TextTokenCollater(
        unique_tokens.symbols, add_bos=True, add_eos=True
    )
    return collater


def get_text_semantic_token_collater(
        text_tokens_file: str, n_semantic_tokens=1024) -> TextTokenCollater:
    text_tokens_path = Path(text_tokens_file)
    unique_tokens = SymbolTable.from_file(text_tokens_path)
    # Register the semantic token ids (as strings) alongside the text symbols
    for semantic_idx in range(n_semantic_tokens):
        unique_tokens.add(str(semantic_idx))

    collater = TextTokenCollater(
        unique_tokens.symbols, add_bos=True, add_eos=True
    )
    return collater


if __name__ == '__main__':
    text_tokens_file = 'ckpt/unique_text_tokens.k2symbols'
    collater = get_text_semantic_token_collater(text_tokens_file)
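
    # A minimal, illustrative usage sketch with toy placeholder inputs; the
    # shapes and names below (n_q=7, 'utt_a.wav', 'utt_b.wav') are arbitrary
    # assumptions, not values prescribed by the module.
    # Digit characters are valid text tokens here because
    # get_text_semantic_token_collater registers ids 0..1023 as string symbols.
    token_batch = collater(["012", "345"])
    print(token_batch.shape)  # torch.Size([2, 6]): <bos>, spkr_1, 3 symbols, <eos>

    # Toy batch for GlobalCollater: each item is (speaker_embedding [T, 512],
    # quantize_input [T, n_q], quantize_output [T, n_q], file_name,
    # semantic_tokens [T]); zeros are used purely for demonstration.
    acoustic_collater = GlobalCollater(n_codes=1024, n_semantic_codes=1024)
    toy_batch = [
        (
            np.zeros((4, 512), dtype=np.float32),  # speaker embedding
            np.zeros((4, 7), dtype=np.int64),      # quantize input codes
            np.zeros((4, 7), dtype=np.int64),      # quantize output codes
            "utt_a.wav",                           # item name
            np.zeros((4,), dtype=np.int64),        # semantic tokens
        ),
        (
            np.zeros((2, 512), dtype=np.float32),
            np.zeros((2, 7), dtype=np.int64),
            np.zeros((2, 7), dtype=np.int64),
            "utt_b.wav",
            np.zeros((2,), dtype=np.int64),
        ),
    ]
    collated = acoustic_collater.collate(toy_batch)
    print(collated['tts_quantize_input'].shape)  # torch.Size([2, 4, 7])
    print(collated['quantize_mask'][1])          # frames beyond length 2 are masked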