| """Projector that maps hidden states from the LLM component to multimodal logits.""" | |
| import torch | |
| from torch import nn | |
| from dataclasses import dataclass | |
| from typing import Optional, Tuple | |
| from .common import HiggsAudioPreTrainedModel | |
| from .configuration_higgs_audio import HiggsAudioConfig | |

@dataclass
class HiggsAudioDecoderLayerOutput:
    """Output container for the decoder projector."""

    logits: torch.FloatTensor
    audio_logits: torch.FloatTensor
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None


class HiggsAudioDecoderProjector(HiggsAudioPreTrainedModel):
    """Projection layers that map hidden states from the LLM component to audio / text logits.

    We support two types of audio heads:

    - Basic audio head:
        Directly map the hidden states to audio logits for all the codebooks.
    - RQ-Transformer audio head:
        Use an RQ-Transformer to calculate the audio logits codebook-by-codebook
        (see `label_audio_ids` in `forward`).
    """

    def __init__(self, config: HiggsAudioConfig, layer_idx: Optional[int] = None):
        super().__init__(config)
        self.text_lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
        # Each codebook gets `audio_codebook_size + 2` output slots; the two extra
        # entries are presumably reserved for special audio tokens (e.g. stream delimiters).
        self.audio_lm_head = nn.Linear(
            config.text_config.hidden_size,
            config.audio_num_codebooks * (config.audio_codebook_size + 2),
            bias=False,
        )
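
        # NOTE: `forward` below references `self.rotary_emb`, `self.transformer_layers`,
        # and `self.norm`, which are not defined in this excerpt. A minimal sketch of the
        # missing initialization, assuming Llama-style building blocks from `transformers`
        # and a Llama-compatible `config.text_config` (the actual classes may differ):
        if config.audio_decoder_proj_num_layers > 0:
            from transformers.models.llama.modeling_llama import (
                LlamaDecoderLayer,
                LlamaRMSNorm,
                LlamaRotaryEmbedding,
            )

            self.transformer_layers = nn.ModuleList(
                [LlamaDecoderLayer(config.text_config, i) for i in range(config.audio_decoder_proj_num_layers)]
            )
            self.norm = LlamaRMSNorm(config.text_config.hidden_size, eps=config.text_config.rms_norm_eps)
            self.rotary_emb = LlamaRotaryEmbedding(config=config.text_config)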

        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        hidden_states,
        audio_out_mask,
        label_audio_ids=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        output_audio_hidden_states=False,
        cache_position=None,
    ):
| """ | |
| Args: | |
| hidden_states (`torch.Tensor` of shape `(batch_size, seq_len, hidden_size)`): | |
| Hidden states from the LLM component | |
| audio_out_mask (`torch.Tensor` of shape `(batch_size, seq_len)`): | |
| Mask for identifying the audio out tokens. | |
| label_audio_ids (`torch.Tensor` of shape `(num_codebooks, num_audio_out_tokens)`): | |
| Label tokens for the audio-out part. This is used for calculating the logits if RQ-Transformer is used. | |
| attention_mask (`torch.Tensor` of shape `(batch_size, seq_len)`): | |
| Mask to avoid performing attention on padding token indices | |
| position_ids (`torch.Tensor` of shape `(batch_size, seq_len)`): | |
| Position ids for the input tokens | |
| Returns: | |
| logits (`torch.Tensor` of shape `(batch_size, seq_len, vocab_size)`): | |
| Logits for text tokens | |
| audio_logits (`torch.Tensor` of shape `(num_audio_out_tokens, audio_num_codebooks * audio_codebook_size)`): | |
| Logits for audio tokens. We ensure `num_text_tokens + num_audio_tokens == batch_size * seq_len` | |
| """ | |
        logits = self.text_lm_head(hidden_states)

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        # TODO(sxjscience) Need to check if DeepSpeed Zero3 supports zero-shape input.
        if self.config.audio_decoder_proj_num_layers > 0:
            # create position embeddings to be shared across the decoder layers
            position_embeddings = self.rotary_emb(hidden_states, position_ids)
            for decoder_layer in self.transformer_layers:
                if output_hidden_states:
                    all_hidden_states += (hidden_states,)

                if self.gradient_checkpointing and self.training:
                    layer_outputs = self._gradient_checkpointing_func(
                        decoder_layer.__call__,
                        hidden_states,
                        attention_mask,
                        position_ids,
                        past_key_values,
                        output_attentions,
                        use_cache,
                        cache_position,
                        position_embeddings,
                    )
                else:
                    layer_outputs = decoder_layer(
                        hidden_states,
                        attention_mask=attention_mask,
                        position_ids=position_ids,
                        past_key_value=past_key_values,
                        output_attentions=output_attentions,
                        use_cache=use_cache,
                        cache_position=cache_position,
                        position_embeddings=position_embeddings,
                    )

                hidden_states = layer_outputs[0]

            hidden_states = self.norm(hidden_states)

            # Collected inside this branch so that `layer_outputs` is guaranteed
            # to be bound (the loop above has run at least once).
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

        next_cache = next_decoder_cache if use_cache else None

        # Boolean indexing keeps only the audio-out positions and flattens the
        # (batch, seq) dimensions into a single token axis.
        audio_logits = self.audio_lm_head(hidden_states[audio_out_mask])

        if output_audio_hidden_states:
            audio_hidden_states = hidden_states[audio_out_mask]
        else:
            audio_hidden_states = None

        return (
            logits,
            audio_logits,
            all_self_attns,
            all_hidden_states,
            audio_hidden_states,
            next_cache,
        )
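

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the model): how the projector routes hidden
# states to text logits (all positions) and audio logits (masked positions
# only). All shapes and hyper-parameters below are invented for demonstration.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    batch_size, seq_len, hidden_size = 2, 8, 16
    vocab_size, num_codebooks, codebook_size = 100, 4, 32

    hidden_states = torch.randn(batch_size, seq_len, hidden_size)
    # Mark the last three positions of each sequence as audio-out tokens.
    audio_out_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
    audio_out_mask[:, -3:] = True

    text_lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
    audio_lm_head = nn.Linear(hidden_size, num_codebooks * (codebook_size + 2), bias=False)

    # Text logits cover every position; audio logits only the masked positions
    # (boolean indexing flattens the batch and sequence dimensions).
    logits = text_lm_head(hidden_states)
    audio_logits = audio_lm_head(hidden_states[audio_out_mask])

    assert logits.shape == (batch_size, seq_len, vocab_size)
    assert audio_logits.shape == (int(audio_out_mask.sum()), num_codebooks * (codebook_size + 2))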