| """T2S model definition. | |
| Copyright PolyAI Limited. | |
| """ | |
| import os | |
| import numpy as np | |
| from torch import nn | |
| from transformers import EvalPrediction, T5Config, T5ForConditionalGeneration | |
| from data.collation import get_text_semantic_token_collater | |
def compute_custom_metrics(eval_prediction: EvalPrediction):
    # eval_prediction.predictions is a tuple:
    #   predictions[0]: decoder output logits, shape (n_batch, n_semantic, n_vocab)
    #   predictions[1]: encoder outputs, shape (n_batch, n_text/n_phone, n_hidden)
    logits = eval_prediction.predictions[0]
    labels = eval_prediction.label_ids

    # Positions labelled -100 (padding / ignored targets) are excluded from the accuracies.
    mask = labels == -100

    # Top-1: the argmax token matches the label.
    top_1 = np.argmax(logits, axis=-1) == labels
    top_1[mask] = False

    # Top-5: the label appears among the 5 highest-scoring tokens.
    top_5 = np.argsort(logits, axis=-1)[:, :, -5:]
    top_5 = np.any(top_5 == np.expand_dims(labels, axis=-1), axis=-1)
    top_5[mask] = False

    # Top-10: the label appears among the 10 highest-scoring tokens.
    top_10 = np.argsort(logits, axis=-1)[:, :, -10:]
    top_10 = np.any(top_10 == np.expand_dims(labels, axis=-1), axis=-1)
    top_10[mask] = False

    top_1_accuracy = np.sum(top_1) / np.sum(~mask)
    top_5_accuracy = np.sum(top_5) / np.sum(~mask)
    top_10_accuracy = np.sum(top_10) / np.sum(~mask)

    return {
        "top_1_accuracy": top_1_accuracy,
        "top_5_accuracy": top_5_accuracy,
        "top_10_accuracy": top_10_accuracy,
    }
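# Illustrative sanity check (not part of the original module): a minimal sketch of
# how compute_custom_metrics consumes an EvalPrediction. With logits of shape
# (1, 2, 4) and labels [[1, -100]], the second position is masked out and only the
# first counts, so every metric comes out as 1.0:
#
#   preds = EvalPrediction(
#       predictions=(
#           np.array([[[0.1, 0.7, 0.1, 0.1], [0.5, 0.2, 0.2, 0.1]]]),  # decoder logits
#           np.zeros((1, 3, 8)),                                       # encoder outputs (unused here)
#       ),
#       label_ids=np.array([[1, -100]]),
#   )
#   compute_custom_metrics(preds)
#   # -> {"top_1_accuracy": 1.0, "top_5_accuracy": 1.0, "top_10_accuracy": 1.0}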
class T2S(nn.Module):
    def __init__(self, hp):
        super().__init__()
        # The collater maps text/phone and semantic tokens into a shared vocabulary.
        self.text_tokens_file = "ckpt/unique_text_tokens.k2symbols"
        self.collater = get_text_semantic_token_collater(self.text_tokens_file)

        self.model_size = hp.model_size
        self.vocab_size = len(self.collater.idx2token)

        self.config = self._define_model_config(self.model_size)
        print(f"{self.config = }")

        # Randomly initialised T5 encoder-decoder over the joint token vocabulary.
        self.t2s = T5ForConditionalGeneration(self.config)
    def _define_model_config(self, model_size):
        if model_size == "test":
            # minimal configuration for quick tests / debugging
            d_ff = 16
            d_model = 8
            d_kv = 32
            num_heads = 1
            num_decoder_layers = 1
            num_layers = 1
        elif model_size == "tiny":
            # n_params = 16M
            d_ff = 1024
            d_model = 256
            d_kv = 32
            num_heads = 4
            num_decoder_layers = 4
            num_layers = 4
        elif model_size == "t5small":
            # n_params = 60M
            d_ff = 2048
            d_model = 512
            d_kv = 64
            num_heads = 8
            num_decoder_layers = 6
            num_layers = 6
        elif model_size == "large":
            # n_params = 100M
            d_ff = 2048
            d_model = 512
            d_kv = 64
            num_heads = 8
            num_decoder_layers = 14
            num_layers = 14
        elif model_size == "Large":
            # n_params = 114M
            d_ff = 4096
            d_model = 512
            d_kv = 64
            num_heads = 8
            num_decoder_layers = 6
            num_layers = 10
        else:
            raise ValueError(f"Unknown model_size: {model_size}")

        config = T5Config(
            d_ff=d_ff,
            d_model=d_model,
            d_kv=d_kv,
            num_heads=num_heads,
            num_decoder_layers=num_decoder_layers,
            num_layers=num_layers,
            decoder_start_token_id=0,
            eos_token_id=2,
            vocab_size=self.vocab_size,
        )
        return config
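    # Usage sketch (illustrative, not part of the original module): `hp` can be any
    # object exposing a `model_size` attribute, and the collater symbols file must
    # exist at ckpt/unique_text_tokens.k2symbols for construction to succeed.
    #
    #   from argparse import Namespace
    #   model = T2S(Namespace(model_size="t5small"))
    #   # model.t2s is a freshly initialised T5ForConditionalGeneration whose
    #   # vocab_size matches the text+semantic token collater.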