# Copyright 2024 The YourMT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Please see the details in the LICENSE file.
"""t5mod_helper.py"""
import torch
from torch import nn
from model.t5mod import T5DecoderYMT3, MultiChannelT5Decoder
from typing import Optional, Callable, Union, Literal

def task_cond_dec_generate(decoder: Union[T5DecoderYMT3, MultiChannelT5Decoder],
                           decoder_type: Literal["t5", "multi-t5"],
                           embed_tokens: nn.Embedding,
                           lm_head: nn.Module,
                           encoder_hidden_states: torch.FloatTensor,
                           shift_right_fn: Callable,
                           prefix_ids: Optional[torch.LongTensor] = None,
                           max_length: int = 1024,
                           stop_at_eos: bool = True,
                           eos_id: Optional[int] = 1,
                           pad_id: Optional[int] = 0,
                           decoder_start_token_id: Optional[int] = 0,
                           debug: bool = False) -> torch.LongTensor:
| """ | |
| Generate sequence by task conditioning on the decoder side | |
| :An extension of transofrmers.generate() function for the model with | |
| conditioning only on the decoder side. | |
| Args: | |
| decoder: T5DecoderYMT3 or MultiChannelT5Decoder, any decoder model with T5Stack architecture | |
| decoder_type: Literal["t5", "multi-t5"], type of decoder | |
| embed_tokens: nn.Embedding, embedding layer for the decoder | |
| lm_head: nn.Module, language model head | |
| encoder_hidden_states: torch.FloatTensor, (B, T, D) or (B, K, T, D) last hidden states | |
| shift_right_fn: Callable, shift_right function of the decoder | |
| prefix_ids: torch.LongTensor, (B, prefix_len) prefix ids typically used as task conditioning to decoder. | |
| max_length: int, max token length to generate (default is 1024) | |
| stop_at_eos: bool, whether to early-stop when all predictions in the batch are the <eos> token. | |
| eos_id: int, the id of the <eos> token (default is 1) | |
| pad_id: int, the id of the <pad> token (default is 0) | |
| decoder_start_token_id: int, the id of the <bos> token (default is 0) | |
| debug: bool, whether to print debug information | |
| Returns: | |
| pred_ids: torch.LongTensor, (B, task_len + N) or (B, C, task_len + N) predicted token ids | |
| """ | |
    bsz = encoder_hidden_states.shape[0]
    device = encoder_hidden_states.device

    # Prepare dec_input_shape: (B, 1) or (B, C, 1)
    if decoder_type == "t5":
        dec_input_shape = (bsz, 1)
    elif decoder_type == "multi-t5":
        dec_input_shape = (bsz, decoder.num_channels, 1)
    else:
        raise ValueError(f"decoder_type {decoder_type} is not supported.")

    # Prepare dec_input_ids: shift_right prepends <bos> to the task prefix tokens,
    # giving (B, prefix_len) or (B, C, prefix_len)
    if prefix_ids is not None and prefix_ids.numel() > 0:
        dec_input_ids = shift_right_fn(prefix_ids)
        prefix_length = prefix_ids.shape[-1]
    else:
        # If prefix_ids is None, use <bos> as the initial input
        dec_input_ids = torch.tile(torch.LongTensor([decoder_start_token_id]).to(device), dec_input_shape)
        prefix_length = 0
    dec_inputs_embeds = embed_tokens(dec_input_ids)  # (B, L, D) or (B, C, L, D)

    # Generate decoder hidden states and past_key_values from the prefix:
    # - the initial inputs_embeds can be a full sequence; no past_key_values are needed yet
    # - dec_hs: (B, L, D) or (B, C, L, D), where L is the length of dec_input_ids
    # - past_key_values: tuple of length M for the M decoder layers
    dec_hs, past_key_values = decoder(inputs_embeds=dec_inputs_embeds,
                                      encoder_hidden_states=encoder_hidden_states,
                                      return_dict=False)
    logits = lm_head(dec_hs)  # (B, L, vocab_size) or (B, C, L, vocab_size)
    pred_ids = logits.argmax(-1)  # (B, L) or (B, C, L): L = prefix_len, or 1 when there is no prefix

    # Keep track of which sequences are already finished
    unfinished_sequences = torch.ones(dec_input_shape, dtype=torch.long, device=device)
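    # (unfinished_sequences: 1 = still generating, 0 = finished after emitting <eos>)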

    # Fast generation with past_key_values for the rest of the sequence
    if decoder_type == "t5":
        dec_input_ids = pred_ids[:, -1].unsqueeze(-1)  # (B, 1)
    elif decoder_type == "multi-t5":
        dec_input_ids = pred_ids[:, :, -1].unsqueeze(-1)  # (B, C, 1)
    for i in range(max_length - prefix_length - 1):  # -1 for the <eos> token
        if debug:
            # past_key_values_length determines the positional embedding
            past_key_values_length = past_key_values[0][0].shape[2]
            print(f'i = {i}, past_key_values_length = {past_key_values_length}, pred_ids.shape = {pred_ids.shape}')

        # When past_key_values is provided, only the last predicted token is used as input
        dec_inputs_embeds = embed_tokens(dec_input_ids)  # (B, 1, D) or (B, C, 1, D)
        dec_hs, _past_key_values = decoder(inputs_embeds=dec_inputs_embeds,
                                           encoder_hidden_states=encoder_hidden_states,
                                           past_key_values=past_key_values,
                                           return_dict=False)
        logits = lm_head(dec_hs)  # (B, 1, vocab_size) or (B, C, 1, vocab_size)
        _pred_ids = logits.argmax(-1)  # (B, 1) or (B, C, 1)

        # Update input ids and past_key_values for the next iteration
        dec_input_ids = _pred_ids.clone()  # clone, since _pred_ids is modified below for finished sequences
        past_key_values = _past_key_values
        # Finished sequences should have their next token be a padding token
        if eos_id is not None:
            if pad_id is None:
                raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
            _pred_ids = _pred_ids * unfinished_sequences + pad_id * (1 - unfinished_sequences)

        # Update pred_ids
        pred_ids = torch.cat((pred_ids, _pred_ids), dim=-1)  # (B, T') or (B, C, T') with increasing T'

        # Update the state of unfinished_sequences
        if eos_id is not None:
            unfinished_sequences = unfinished_sequences * _pred_ids.ne(eos_id).long()

            # Early-stop when every sequence in the batch is finished
            if stop_at_eos is True and unfinished_sequences.max() == 0:
                break

    return pred_ids  # (B, L) or (B, C, L)
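

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (illustrative; not part of the original module).
# `_ToyDecoder` and `_toy_shift_right` are hypothetical stand-ins that only
# mimic the interface this helper expects from a T5Stack-style decoder:
# a callable returning (hidden_states, past_key_values) with return_dict=False.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class _ToyDecoder(nn.Module):
        """Hypothetical decoder mimicking the (dec_hs, past_key_values) interface."""

        def __init__(self, d_model: int):
            super().__init__()
            self.proj = nn.Linear(d_model, d_model)

        def forward(self, inputs_embeds=None, encoder_hidden_states=None,
                    past_key_values=None, return_dict=False):
            dec_hs = self.proj(inputs_embeds)  # (B, L, D)
            # Fake single-layer cache whose length grows with each step, so that
            # the debug read of past_key_values[0][0].shape[2] stays consistent.
            prev_len = 0 if past_key_values is None else past_key_values[0][0].shape[2]
            new_len = prev_len + inputs_embeds.shape[1]
            bsz = inputs_embeds.shape[0]
            past_key_values = ((torch.zeros(bsz, 1, new_len, 1),) * 4,)
            return dec_hs, past_key_values

    def _toy_shift_right(ids: torch.LongTensor) -> torch.LongTensor:
        """Prepend <bos> (id 0) and drop the last token, keeping the length."""
        return torch.cat([torch.zeros_like(ids[..., :1]), ids[..., :-1]], dim=-1)

    d_model, vocab_size = 8, 16
    embed_tokens = nn.Embedding(vocab_size, d_model)
    lm_head = nn.Linear(d_model, vocab_size, bias=False)
    decoder = _ToyDecoder(d_model)
    encoder_hidden_states = torch.randn(2, 10, d_model)  # (B, T, D)
    prefix_ids = torch.tensor([[3, 4], [3, 5]])  # (B, prefix_len) task tokens

    pred_ids = task_cond_dec_generate(decoder, "t5", embed_tokens, lm_head,
                                      encoder_hidden_states, _toy_shift_right,
                                      prefix_ids=prefix_ids, max_length=12)
    print(pred_ids.shape)  # (2, L) with L <= 12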