ammarnasr committed · Commit 348089f · verified · 1 Parent(s): 97d679c

Upload T5MIMOconvForConditionalGeneration
README.md ADDED
@@ -0,0 +1,199 @@
+ ---
+ library_name: transformers
+ tags: []
+ ---
+
+ # Model Card for Model ID
+
+ <!-- Provide a quick summary of what the model is/does. -->
+
+
+
+ ## Model Details
+
+ ### Model Description
+
+ <!-- Provide a longer summary of what this model is. -->
+
+ This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
+
+ - **Developed by:** [More Information Needed]
+ - **Funded by [optional]:** [More Information Needed]
+ - **Shared by [optional]:** [More Information Needed]
+ - **Model type:** [More Information Needed]
+ - **Language(s) (NLP):** [More Information Needed]
+ - **License:** [More Information Needed]
+ - **Finetuned from model [optional]:** [More Information Needed]
+
+ ### Model Sources [optional]
+
+ <!-- Provide the basic links for the model. -->
+
+ - **Repository:** [More Information Needed]
+ - **Paper [optional]:** [More Information Needed]
+ - **Demo [optional]:** [More Information Needed]
+
+ ## Uses
+
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
+
+ ### Direct Use
+
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
+
+ [More Information Needed]
+
+ ### Downstream Use [optional]
+
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
+
+ [More Information Needed]
+
+ ### Out-of-Scope Use
+
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
+
+ [More Information Needed]
+
+ ## Bias, Risks, and Limitations
+
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
+
+ [More Information Needed]
+
+ ### Recommendations
+
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
+
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
+
+ ## How to Get Started with the Model
+
+ Use the code below to get started with the model.
+
+ [More Information Needed]
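A minimal loading sketch (an editorial assumption, not part of the original card): it relies on the `auto_map` entries in this repo's `config.json`, which route `AutoConfig`/`AutoModelForSeq2SeqLM` to the custom classes in the repo and therefore require `trust_remote_code=True`. The repo id below is a placeholder.

```python
# Hypothetical usage sketch; the repo id is a placeholder and input handling is an assumption.
from transformers import AutoConfig, AutoModelForSeq2SeqLM

repo_id = "ammarnasr/<this-repo>"  # placeholder for the actual Hub repo id

config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained(repo_id, trust_remote_code=True)

# MIMO-specific hyperparameters stored in config.json
print(config.num_seqs, config.num_filters)  # 3, 64

# The checkpoint is ~33.6 MB in float32, i.e. roughly 8.4M parameters.
print(sum(p.numel() for p in model.parameters()))
```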
+
+ ## Training Details
+
+ ### Training Data
+
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
+
+ [More Information Needed]
+
+ ### Training Procedure
+
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
+
+ #### Preprocessing [optional]
+
+ [More Information Needed]
+
+
+ #### Training Hyperparameters
+
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
+
+ #### Speeds, Sizes, Times [optional]
+
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
+
+ [More Information Needed]
+
+ ## Evaluation
+
+ <!-- This section describes the evaluation protocols and provides the results. -->
+
+ ### Testing Data, Factors & Metrics
+
+ #### Testing Data
+
+ <!-- This should link to a Dataset Card if possible. -->
+
+ [More Information Needed]
+
+ #### Factors
+
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
+
+ [More Information Needed]
+
+ #### Metrics
+
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
+
+ [More Information Needed]
+
+ ### Results
+
+ [More Information Needed]
+
+ #### Summary
+
+
+
+ ## Model Examination [optional]
+
+ <!-- Relevant interpretability work for the model goes here -->
+
+ [More Information Needed]
+
+ ## Environmental Impact
+
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
+
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
+
+ - **Hardware Type:** [More Information Needed]
+ - **Hours used:** [More Information Needed]
+ - **Cloud Provider:** [More Information Needed]
+ - **Compute Region:** [More Information Needed]
+ - **Carbon Emitted:** [More Information Needed]
+
+ ## Technical Specifications [optional]
+
+ ### Model Architecture and Objective
+
+ [More Information Needed]
+
+ ### Compute Infrastructure
+
+ [More Information Needed]
+
+ #### Hardware
+
+ [More Information Needed]
+
+ #### Software
+
+ [More Information Needed]
+
+ ## Citation [optional]
+
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
+
+ **BibTeX:**
+
+ [More Information Needed]
+
+ **APA:**
+
+ [More Information Needed]
+
+ ## Glossary [optional]
+
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
+
+ [More Information Needed]
+
+ ## More Information [optional]
+
+ [More Information Needed]
+
+ ## Model Card Authors [optional]
+
+ [More Information Needed]
+
+ ## Model Card Contact
+
+ [More Information Needed]
config.json ADDED
@@ -0,0 +1,35 @@
+ {
+   "architectures": [
+     "T5MIMOconvForConditionalGeneration"
+   ],
+   "auto_map": {
+     "AutoConfig": "configuration_t5mimoconv.T5MIMOconvConfig",
+     "AutoModelForSeq2SeqLM": "modeling_t5mimoconv.T5MIMOconvForConditionalGeneration"
+   },
+   "classifier_dropout": 0.0,
+   "d_ff": 1024,
+   "d_kv": 64,
+   "d_model": 256,
+   "decoder_start_token_id": 0,
+   "dense_act_fn": "relu",
+   "dropout_rate": 0.1,
+   "eos_token_id": 1,
+   "feed_forward_proj": "relu",
+   "initializer_factor": 0.05,
+   "is_encoder_decoder": true,
+   "is_gated_act": false,
+   "layer_norm_epsilon": 1e-06,
+   "model_type": "t5mimoconv",
+   "num_decoder_layers": 4,
+   "num_filters": 64,
+   "num_heads": 4,
+   "num_layers": 4,
+   "num_seqs": 3,
+   "pad_token_id": 0,
+   "relative_attention_max_distance": 128,
+   "relative_attention_num_buckets": 32,
+   "torch_dtype": "float32",
+   "transformers_version": "4.41.1",
+   "use_cache": true,
+   "vocab_size": 4096
+ }
configuration_t5mimoconv.py ADDED
@@ -0,0 +1,152 @@
+ from typing import Mapping
+ from transformers.configuration_utils import PretrainedConfig
+ from transformers.onnx import OnnxSeq2SeqConfigWithPast
+ from transformers.utils import logging
+
+
+ logger = logging.get_logger(__name__)
+
+
+ class T5MIMOconvConfig(PretrainedConfig):
+     r"""
+     This is the configuration class to store the configuration of a [`T5Model`] or a [`TFT5Model`]. It is used to
+     instantiate a T5 model according to the specified arguments, defining the model architecture. Instantiating a
+     configuration with the defaults will yield a similar configuration to that of the T5
+     [google-t5/t5-small](https://huggingface.co/google-t5/t5-small) architecture.
+
+     Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
+     documentation from [`PretrainedConfig`] for more information.
+
+     Arguments:
+         vocab_size (`int`, *optional*, defaults to 32128):
+             Vocabulary size of the T5 model. Defines the number of different tokens that can be represented by the
+             `inputs_ids` passed when calling [`T5Model`] or [`TFT5Model`].
+         d_model (`int`, *optional*, defaults to 512):
+             Size of the encoder layers and the pooler layer.
+         d_kv (`int`, *optional*, defaults to 64):
+             Size of the key, query, value projections per attention head. The `inner_dim` of the projection layer will
+             be defined as `num_heads * d_kv`.
+         d_ff (`int`, *optional*, defaults to 2048):
+             Size of the intermediate feed forward layer in each `T5Block`.
+         num_layers (`int`, *optional*, defaults to 6):
+             Number of hidden layers in the Transformer encoder.
+         num_decoder_layers (`int`, *optional*):
+             Number of hidden layers in the Transformer decoder. Will use the same value as `num_layers` if not set.
+         num_heads (`int`, *optional*, defaults to 8):
+             Number of attention heads for each attention layer in the Transformer encoder.
+         relative_attention_num_buckets (`int`, *optional*, defaults to 32):
+             The number of buckets to use for each attention layer.
+         relative_attention_max_distance (`int`, *optional*, defaults to 128):
+             The maximum distance of the longer sequences for the bucket separation.
+         dropout_rate (`float`, *optional*, defaults to 0.1):
+             The ratio for all dropout layers.
+         classifier_dropout (`float`, *optional*, defaults to 0.0):
+             The dropout ratio for the classifier.
+         layer_norm_epsilon (`float`, *optional*, defaults to 1e-6):
+             The epsilon used by the layer normalization layers.
+         initializer_factor (`float`, *optional*, defaults to 1):
+             A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
+             testing).
+         feed_forward_proj (`string`, *optional*, defaults to `"relu"`):
+             Type of feed forward layer to be used. Should be one of `"relu"` or `"gated-gelu"`. T5v1.1 uses the
+             `"gated-gelu"` feed forward projection. Original T5 uses `"relu"`.
+         use_cache (`bool`, *optional*, defaults to `True`):
+             Whether or not the model should return the last key/values attentions (not used by all models).
+     """
+
+     model_type = "t5mimoconv"
+     keys_to_ignore_at_inference = ["past_key_values"]
+     attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"}
+
+     def __init__(
+         self,
+         vocab_size=32128,
+         d_model=512,
+         d_kv=64,
+         d_ff=2048,
+         num_layers=6,
+         num_decoder_layers=None,
+         num_heads=8,
+         relative_attention_num_buckets=32,
+         relative_attention_max_distance=128,
+         dropout_rate=0.1,
+         layer_norm_epsilon=1e-6,
+         initializer_factor=1.0,
+         feed_forward_proj="relu",
+         is_encoder_decoder=True,
+         use_cache=True,
+         pad_token_id=0,
+         eos_token_id=1,
+         decoder_start_token_id=0,
+         classifier_dropout=0.0,
+         num_seqs=3,
+         num_filters=64,
+         **kwargs,
+     ):
+         self.vocab_size = vocab_size
+         self.d_model = d_model
+         self.d_kv = d_kv
+         self.d_ff = d_ff
+         self.num_layers = num_layers
+         self.num_decoder_layers = (
+             num_decoder_layers if num_decoder_layers is not None else self.num_layers
+         )  # default = symmetry
+         self.num_heads = num_heads
+         self.relative_attention_num_buckets = relative_attention_num_buckets
+         self.relative_attention_max_distance = relative_attention_max_distance
+         self.dropout_rate = dropout_rate
+         self.classifier_dropout = classifier_dropout
+         self.layer_norm_epsilon = layer_norm_epsilon
+         self.initializer_factor = initializer_factor
+         self.feed_forward_proj = feed_forward_proj
+         self.use_cache = use_cache
+         self.num_seqs = num_seqs
+         self.num_filters = num_filters
+
+         act_info = self.feed_forward_proj.split("-")
+         self.dense_act_fn = act_info[-1]
+         self.is_gated_act = act_info[0] == "gated"
+
+         if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2:
+             raise ValueError(
+                 f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer. "
+                 "Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. "
+                 "'gated-gelu' or 'relu'"
+             )
+
+         # for backwards compatibility
+         if feed_forward_proj == "gated-gelu":
+             self.dense_act_fn = "gelu_new"
+
+         super().__init__(
+             pad_token_id=pad_token_id,
+             eos_token_id=eos_token_id,
+             decoder_start_token_id=decoder_start_token_id,
+             is_encoder_decoder=is_encoder_decoder,
+             **kwargs,
+         )
+
+
+ class T5MIMOOnnxConfig(OnnxSeq2SeqConfigWithPast):
+     @property
+     def inputs(self) -> Mapping[str, Mapping[int, str]]:
+         common_inputs = {
+             "input_ids": {0: "batch", 1: "encoder_sequence"},
+             "attention_mask": {0: "batch", 1: "encoder_sequence"},
+         }
+         if self.use_past:
+             common_inputs["attention_mask"][1] = "past_encoder_sequence + sequence"
+             common_inputs["decoder_input_ids"] = {0: "batch"}
+             common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
+         else:
+             common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
+             common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}
+
+         if self.use_past:
+             self.fill_with_past_key_values_(common_inputs, direction="inputs")
+
+         return common_inputs
+
+     @property
+     def default_onnx_opset(self) -> int:
+         return 13
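As a quick illustration (a sketch, assuming `configuration_t5mimoconv.py` is importable from the working directory), here is how the hyperparameters stored in this repo's `config.json` map onto the class above, including the attributes derived in `__init__`:

```python
# Sketch only: mirrors the values stored in this repo's config.json.
from configuration_t5mimoconv import T5MIMOconvConfig  # assumption: file is on sys.path

config = T5MIMOconvConfig(
    vocab_size=4096,
    d_model=256,
    d_ff=1024,
    d_kv=64,
    num_layers=4,          # num_decoder_layers is left unset and defaults to num_layers
    num_heads=4,
    num_seqs=3,            # number of parallel (MIMO) sequences
    num_filters=64,        # channels used by the convolutional block
    initializer_factor=0.05,
    feed_forward_proj="relu",
)

# Derived attributes computed in __init__:
assert config.dense_act_fn == "relu" and config.is_gated_act is False
assert config.num_decoder_layers == config.num_layers  # symmetry default
print(config.model_type)  # "t5mimoconv"
```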
generation_config.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "_from_model_config": true,
+   "decoder_start_token_id": 0,
+   "eos_token_id": 1,
+   "pad_token_id": 0,
+   "transformers_version": "4.41.1"
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ca0b738f4c07afb251d5167ef3c7cc80e0257fa069b463836231b434b434f825
+ size 33649068
modeling_t5mimoconv.py ADDED
@@ -0,0 +1,1752 @@
1
+ import copy
2
+ import math
3
+ import warnings
4
+ from typing import Optional, Tuple, Union
5
+ import torch
6
+ from torch import nn
7
+ from torch.nn import CrossEntropyLoss
8
+ from transformers.activations import ACT2FN
9
+ from transformers.modeling_outputs import (
10
+ BaseModelOutput,
11
+ BaseModelOutputWithPastAndCrossAttentions,
12
+ Seq2SeqLMOutput,
13
+ Seq2SeqModelOutput,
14
+ )
15
+ from transformers.modeling_utils import PreTrainedModel
16
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer
17
+ from transformers.utils import (
18
+ DUMMY_INPUTS,
19
+ DUMMY_MASK,
20
+ is_torch_fx_proxy,
21
+ logging,
22
+ )
23
+ from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
24
+ from .configuration_t5mimoconv import T5MIMOconvConfig
25
+
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+
30
+
31
+ class T5LayerNorm(nn.Module):
32
+ def __init__(self, hidden_size, eps=1e-6):
33
+ """
34
+ Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
35
+ """
36
+ super().__init__()
37
+ self.weight = nn.Parameter(torch.ones(hidden_size))
38
+ self.variance_epsilon = eps
39
+
40
+ def forward(self, hidden_states):
41
+ # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
42
+ # Square Layer Normalization https://arxiv.org/abs/1910.07467, thus variance is calculated
43
+ # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
44
+ # half-precision inputs is done in fp32
45
+
46
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
47
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
48
+
49
+ # convert into half-precision if necessary
50
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
51
+ hidden_states = hidden_states.to(self.weight.dtype)
52
+
53
+ return self.weight * hidden_states
54
+
55
+
56
+ ALL_LAYERNORM_LAYERS.append(T5LayerNorm)
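To make the RMS-style normalization above concrete, a small numerical sketch (illustration only, using the `T5LayerNorm` class defined in this file):

```python
# Illustration: T5LayerNorm rescales by the root-mean-square of the features,
# without subtracting the mean and without a bias term.
import torch  # assumes T5LayerNorm from this file is in scope

hidden = torch.randn(2, 5, 8)                      # (batch, seq_len, hidden_size)
norm = T5LayerNorm(hidden_size=8)                  # weight initialized to ones
out = norm(hidden)

variance = hidden.pow(2).mean(-1, keepdim=True)    # mean of squares, no mean subtraction
manual = hidden * torch.rsqrt(variance + 1e-6)     # same computation as the forward pass
assert torch.allclose(out, manual, atol=1e-5)
```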
57
+
58
+
59
+ class T5DenseActDense(nn.Module):
60
+ def __init__(self, config: T5MIMOconvConfig):
61
+ super().__init__()
62
+ self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
63
+ self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
64
+ self.dropout = nn.Dropout(config.dropout_rate)
65
+ self.act = ACT2FN[config.dense_act_fn]
66
+
67
+ def forward(self, hidden_states):
68
+ hidden_states = self.wi(hidden_states)
69
+ hidden_states = self.act(hidden_states)
70
+ hidden_states = self.dropout(hidden_states)
71
+ if (
72
+ isinstance(self.wo.weight, torch.Tensor)
73
+ and hidden_states.dtype != self.wo.weight.dtype
74
+ and self.wo.weight.dtype != torch.int8
75
+ ):
76
+ hidden_states = hidden_states.to(self.wo.weight.dtype)
77
+ hidden_states = self.wo(hidden_states)
78
+ return hidden_states
79
+
80
+
81
+ class T5DenseGatedActDense(nn.Module):
82
+ def __init__(self, config: T5MIMOconvConfig):
83
+ super().__init__()
84
+ self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
85
+ self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
86
+ self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
87
+ self.dropout = nn.Dropout(config.dropout_rate)
88
+ self.act = ACT2FN[config.dense_act_fn]
89
+
90
+ def forward(self, hidden_states):
91
+ hidden_gelu = self.act(self.wi_0(hidden_states))
92
+ hidden_linear = self.wi_1(hidden_states)
93
+ hidden_states = hidden_gelu * hidden_linear
94
+ hidden_states = self.dropout(hidden_states)
95
+
96
+ # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32.
97
+ # See https://github.com/huggingface/transformers/issues/20287
98
+ # we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None``
99
+ if (
100
+ isinstance(self.wo.weight, torch.Tensor)
101
+ and hidden_states.dtype != self.wo.weight.dtype
102
+ and self.wo.weight.dtype != torch.int8
103
+ ):
104
+ hidden_states = hidden_states.to(self.wo.weight.dtype)
105
+
106
+ hidden_states = self.wo(hidden_states)
107
+ return hidden_states
108
+
109
+
110
+ class T5LayerFF(nn.Module):
111
+ def __init__(self, config: T5MIMOconvConfig):
112
+ super().__init__()
113
+ if config.is_gated_act:
114
+ self.DenseReluDense = T5DenseGatedActDense(config)
115
+ else:
116
+ self.DenseReluDense = T5DenseActDense(config)
117
+
118
+ self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
119
+ self.dropout = nn.Dropout(config.dropout_rate)
120
+
121
+ def forward(self, hidden_states):
122
+ forwarded_states = self.layer_norm(hidden_states)
123
+ forwarded_states = self.DenseReluDense(forwarded_states)
124
+ hidden_states = hidden_states + self.dropout(forwarded_states)
125
+ return hidden_states
126
+
127
+
128
+
129
+ class MultivariateConvBlock(nn.Module):
130
+ def __init__(self, config: T5MIMOconvConfig, kernel_size=3, stride=1, padding=1):
131
+ super().__init__()
132
+ # 2D Convolution across sequences and time
133
+ self.conv1 = nn.Conv2d(
134
+ in_channels=config.num_seqs,
135
+ out_channels=config.num_filters,
136
+ kernel_size=kernel_size, # Kernel spans across time and all features
137
+ stride=1, # Stride across time, no stride across features
138
+ padding=1 # Padding to preserve sequence length, no padding across features
139
+ )
140
+
141
+ # Batch normalization for stabilization and faster convergence
142
+ self.bn1 = nn.BatchNorm2d(config.num_filters)
143
+
144
+ # Second convolution layer to further model interactions and temporal patterns
145
+ self.conv2 = nn.Conv2d(
146
+ in_channels=config.num_filters,
147
+ out_channels=config.num_filters,
148
+ kernel_size=(kernel_size, 1), # Focus only on temporal patterns
149
+ stride=(stride, 1),
150
+ padding=(padding, 0)
151
+ )
152
+
153
+ # Batch normalization after second convolution
154
+ self.bn2 = nn.BatchNorm2d(config.num_filters)
155
+
156
+ # 1x1 Convolution to reduce the channel dimension back to num_seqs
157
+ self.conv3 = nn.Conv2d(
158
+ in_channels=config.num_filters,
159
+ out_channels=config.num_seqs, # Back to the original number of sequences (channels)
160
+ kernel_size=(1, 1)
161
+ )
162
+
163
+ def forward(self, x):
164
+ """
165
+ Forward pass of the multivariate convolutional block.
166
+
167
+ Args:
168
+ x (torch.Tensor): Input tensor of shape [batch_size, num_seqs, seq_len, model_dim].
169
+
170
+ Returns:
171
+ torch.Tensor: Output tensor of shape [batch_size, num_seqs, seq_len, model_dim].
172
+ """
173
+ # Permute to [batch_size, num_seqs, seq_len, model_dim] -> [batch_size, num_seqs, model_dim, seq_len]
174
+ x = x.permute(0, 1, 3, 2)
175
+
176
+ # Apply first convolution and activation
177
+ x = nn.functional.relu(self.bn1(self.conv1(x)))
178
+ # Apply second convolution and activation
179
+ x = nn.functional.relu(self.bn2(self.conv2(x)))
180
+
181
+ # Reduce channel dimension back to num_seqs
182
+ x = self.conv3(x)
183
+
184
+ # Permute back to original shape [batch_size, num_seqs, seq_len, model_dim]
185
+ x = x.permute(0, 1, 3, 2)
186
+
187
+ return x
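A shape-preservation sketch for the block above (illustrative; the config values mirror this repo's `config.json`, and the local import path is an assumption):

```python
# Illustration: the conv block mixes information across the num_seqs channel and the
# time axis, then projects back, so the output shape matches the input shape.
import torch
from configuration_t5mimoconv import T5MIMOconvConfig  # assumption: local import path

cfg = T5MIMOconvConfig(d_model=256, num_seqs=3, num_filters=64)
block = MultivariateConvBlock(cfg)                   # class defined above

x = torch.randn(4, cfg.num_seqs, 16, cfg.d_model)    # (batch, num_seqs, seq_len, d_model)
y = block(x)
assert y.shape == x.shape                            # (4, 3, 16, 256)
```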
188
+
189
+
190
+
191
+ class T5Attention(nn.Module):
192
+ def __init__(self, config: T5MIMOconvConfig, has_relative_attention_bias=False):
193
+ super().__init__()
194
+ self.is_decoder = config.is_decoder
195
+ self.has_relative_attention_bias = has_relative_attention_bias
196
+ self.relative_attention_num_buckets = config.relative_attention_num_buckets
197
+ self.relative_attention_max_distance = config.relative_attention_max_distance
198
+ self.d_model = config.d_model
199
+ self.key_value_proj_dim = config.d_kv
200
+ self.n_heads = config.num_heads
201
+ self.dropout = config.dropout_rate
202
+ self.inner_dim = self.n_heads * self.key_value_proj_dim
203
+
204
+ # Mesh TensorFlow initialization to avoid scaling before softmax
205
+ self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
206
+ self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
207
+ self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
208
+ self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
209
+
210
+ if self.has_relative_attention_bias:
211
+ self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
212
+ self.pruned_heads = set()
213
+ self.gradient_checkpointing = False
214
+
215
+ def prune_heads(self, heads):
216
+ if len(heads) == 0:
217
+ return
218
+ heads, index = find_pruneable_heads_and_indices(
219
+ heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
220
+ )
221
+ # Prune linear layers
222
+ self.q = prune_linear_layer(self.q, index)
223
+ self.k = prune_linear_layer(self.k, index)
224
+ self.v = prune_linear_layer(self.v, index)
225
+ self.o = prune_linear_layer(self.o, index, dim=1)
226
+ # Update hyper params
227
+ self.n_heads = self.n_heads - len(heads)
228
+ self.inner_dim = self.key_value_proj_dim * self.n_heads
229
+ self.pruned_heads = self.pruned_heads.union(heads)
230
+
231
+ @staticmethod
232
+ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
233
+ """
234
+ Adapted from Mesh Tensorflow:
235
+ https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
236
+
237
+ Translate relative position to a bucket number for relative attention. The relative position is defined as
238
+ memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
239
+ position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
240
+ small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
241
+ positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
242
+ This should allow for more graceful generalization to longer sequences than the model has been trained on
243
+
244
+ Args:
245
+ relative_position: an int32 Tensor
246
+ bidirectional: a boolean - whether the attention is bidirectional
247
+ num_buckets: an integer
248
+ max_distance: an integer
249
+
250
+ Returns:
251
+ a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
252
+ """
253
+ relative_buckets = 0
254
+ if bidirectional:
255
+ num_buckets //= 2
256
+ relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
257
+ relative_position = torch.abs(relative_position)
258
+ else:
259
+ relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
260
+ # now relative_position is in the range [0, inf)
261
+
262
+ # half of the buckets are for exact increments in positions
263
+ max_exact = num_buckets // 2
264
+ is_small = relative_position < max_exact
265
+
266
+ # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
267
+ relative_position_if_large = max_exact + (
268
+ torch.log(relative_position.float() / max_exact)
269
+ / math.log(max_distance / max_exact)
270
+ * (num_buckets - max_exact)
271
+ ).to(torch.long)
272
+ relative_position_if_large = torch.min(
273
+ relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
274
+ )
275
+
276
+ relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
277
+ return relative_buckets
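A tiny sketch of the bucketing behaviour (illustrative; the values follow directly from the static method above with the defaults used in this repo, `num_buckets=32`, `max_distance=128`):

```python
# Illustration of how token distances map to buckets (bidirectional encoder case).
import torch  # assumes T5Attention from this file is in scope

rel = torch.tensor([0, 3, -3, 100, -200])
buckets = T5Attention._relative_position_bucket(
    rel, bidirectional=True, num_buckets=32, max_distance=128
)
# Small offsets get their own bucket (shifted by 16 for "future" positions);
# large offsets share logarithmically sized buckets and saturate at the edge.
print(buckets.tolist())  # [0, 19, 3, 31, 15]
```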
278
+
279
+ def compute_bias(self, query_length, key_length,multivar_dim=-1, device=None):
280
+ """Compute binned relative position bias"""
281
+ if device is None:
282
+ device = self.relative_attention_bias.weight.device
283
+ context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
284
+ memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
285
+ relative_position = memory_position - context_position # shape (query_length, key_length)
286
+ relative_position_bucket = self._relative_position_bucket(
287
+ relative_position, # shape (query_length, key_length)
288
+ bidirectional=(not self.is_decoder),
289
+ num_buckets=self.relative_attention_num_buckets,
290
+ max_distance=self.relative_attention_max_distance,
291
+ )
292
+ values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
293
+ values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
294
+ if multivar_dim !=-1: # shape (1, multivar_dim, num_heads, query_length, key_length) (copy across)
295
+ values = values.expand(1, multivar_dim, -1, -1, -1)
296
+
297
+ return values
298
+
299
+ def forward(
300
+ self,
301
+ hidden_states,
302
+ mask=None,
303
+ key_value_states=None,
304
+ position_bias=None,
305
+ past_key_value=None,
306
+ layer_head_mask=None,
307
+ query_length=None,
308
+ use_cache=False,
309
+ output_attentions=False,
310
+ ):
311
+ """
312
+ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
313
+ """
314
+ # Input is (batch_size, seq_length, dim)
315
+ # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
316
+ # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
317
+ if len(hidden_states.shape) == 3:
318
+ batch_size, seq_length = hidden_states.shape[:2]
319
+ else:
320
+ batch_size, seq_length = hidden_states.shape[0],hidden_states.shape[2]
321
+ multivar_dim = hidden_states.shape[1]
322
+ real_seq_length = seq_length
323
+
324
+ if past_key_value is not None:
325
+ if len(past_key_value) != 2:
326
+ raise ValueError(
327
+ f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
328
+ )
329
+ real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
330
+
331
+ if len(hidden_states.shape) == 3:
332
+ key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
333
+ else:
334
+ key_length = real_seq_length if key_value_states is None else key_value_states.shape[2]
335
+
336
+
337
+ def shape(states):
338
+ """projection"""
339
+ # states: torch.Size([3, 16, 512]) -> query_states: torch.Size([3, 8, 16, 64])
340
+ # states: torch.Size([3, 6, 16, 512]) -> query_states: torch.Size([3, 6, 8 , 16, 64])
341
+ if len(states.shape) == 3:
342
+ return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
343
+ else:
344
+ return states.view(batch_size, multivar_dim, -1, self.n_heads, self.key_value_proj_dim).transpose(2, 3)
345
+
346
+
347
+ def unshape(states):
348
+ """reshape"""
349
+ if len(states.shape) == 4:
350
+ return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
351
+ else:
352
+ return states.transpose(2, 3).contiguous().view(batch_size, multivar_dim, -1, self.inner_dim)
353
+
354
+ def project(hidden_states, proj_layer, key_value_states, past_key_value):
355
+ """projects hidden states correctly to key/query states"""
356
+ if key_value_states is None:
357
+ # self-attn
358
+ # (batch_size, n_heads, seq_length, dim_per_head)
359
+ hidden_states = shape(proj_layer(hidden_states))
360
+ elif past_key_value is None:
361
+ # cross-attn
362
+ # (batch_size, n_heads, seq_length, dim_per_head)
363
+ hidden_states = shape(proj_layer(key_value_states))
364
+
365
+ if past_key_value is not None:
366
+ if key_value_states is None:
367
+ # self-attn
368
+ # (batch_size, n_heads, key_length, dim_per_head)
369
+ hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
370
+ elif past_key_value.shape[2] != key_value_states.shape[1]:
371
+ # checking that the `sequence_length` of the `past_key_value` is the same as
372
+ # the provided `key_value_states` to support prefix tuning
373
+ # cross-attn
374
+ # (batch_size, n_heads, seq_length, dim_per_head)
375
+ hidden_states = shape(proj_layer(key_value_states))
376
+ else:
377
+ # cross-attn
378
+ hidden_states = past_key_value
379
+ return hidden_states
380
+
381
+ # get query states
382
+ query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head)
383
+
384
+
385
+ # get key/value states
386
+ key_states = project(
387
+ hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
388
+ )
389
+ value_states = project(
390
+ hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
391
+ )
392
+
393
+
394
+
395
+ # compute scores
396
+ if len(hidden_states.shape) == 3:
397
+ scores = torch.matmul(
398
+ query_states, key_states.transpose(3, 2)
399
+ ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
400
+ else:
401
+ scores = torch.matmul(
402
+ query_states, key_states.transpose(4, 3)
403
+ )
404
+
405
+
406
+
407
+
408
+
409
+ if position_bias is None:
410
+ if not self.has_relative_attention_bias:
411
+
412
+ if len(hidden_states.shape) == 3:
413
+ position_bias = torch.zeros(
414
+ (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
415
+ )
416
+ else:
417
+ position_bias = torch.zeros(
418
+ (1,multivar_dim, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
419
+ )
420
+ if self.gradient_checkpointing and self.training:
421
+ position_bias.requires_grad = True
422
+ else:
423
+
424
+ if len(hidden_states.shape) == 3:
425
+ position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)
426
+ else:
427
+ position_bias = self.compute_bias(real_seq_length, key_length,multivar_dim=multivar_dim, device=scores.device)
428
+
429
+ # if key and values are already calculated
430
+ # we want only the last query position bias
431
+ if past_key_value is not None:
432
+ position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
433
+
434
+ if mask is not None:
435
+ position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length)
436
+
437
+
438
+
439
+ if self.pruned_heads:
440
+ mask = torch.ones(position_bias.shape[1])
441
+ mask[list(self.pruned_heads)] = 0
442
+ position_bias_masked = position_bias[:, mask.bool()]
443
+ else:
444
+ position_bias_masked = position_bias
445
+
446
+
447
+ scores += position_bias_masked
448
+ attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
449
+ scores
450
+ ) # (batch_size, n_heads, seq_length, key_length)
451
+ attn_weights = nn.functional.dropout(
452
+ attn_weights, p=self.dropout, training=self.training
453
+ ) # (batch_size, n_heads, seq_length, key_length)
454
+
455
+ # Mask heads if we want to
456
+ if layer_head_mask is not None:
457
+ attn_weights = attn_weights * layer_head_mask
458
+
459
+
460
+ if len(hidden_states.shape) == 3:
461
+ attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim)
462
+ else:
463
+ attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, multivar_dim, seq_length, dim)
464
+ attn_output = self.o(attn_output)
465
+
466
+
467
+ present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
468
+ outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
469
+
470
+
471
+ if output_attentions:
472
+ outputs = outputs + (attn_weights,)
473
+
474
+ return outputs
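The `shape`/`unshape` helpers above handle both the standard 3-D input and the multivariate 5-D (MIMO) input; a small sketch of the tensor layouts (illustration only, matching the inline comments in `shape`):

```python
# Illustration: head splitting for univariate vs. multivariate (MIMO) inputs.
import torch

batch, num_seqs, seq_len, n_heads, d_kv = 3, 6, 16, 8, 64

# Standard T5 path: (batch, seq_len, d_model) -> (batch, n_heads, seq_len, d_kv)
uni = torch.randn(batch, seq_len, n_heads * d_kv)
uni_heads = uni.view(batch, -1, n_heads, d_kv).transpose(1, 2)
assert uni_heads.shape == (batch, n_heads, seq_len, d_kv)

# MIMO path: (batch, num_seqs, seq_len, d_model) -> (batch, num_seqs, n_heads, seq_len, d_kv)
multi = torch.randn(batch, num_seqs, seq_len, n_heads * d_kv)
multi_heads = multi.view(batch, num_seqs, -1, n_heads, d_kv).transpose(2, 3)
assert multi_heads.shape == (batch, num_seqs, n_heads, seq_len, d_kv)
```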
475
+
476
+
477
+ class T5LayerSelfAttention(nn.Module):
478
+ def __init__(self, config, has_relative_attention_bias=False):
479
+ super().__init__()
480
+ self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias)
481
+ self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
482
+ self.dropout = nn.Dropout(config.dropout_rate)
483
+
484
+ def forward(
485
+ self,
486
+ hidden_states,
487
+ attention_mask=None,
488
+ position_bias=None,
489
+ layer_head_mask=None,
490
+ past_key_value=None,
491
+ use_cache=False,
492
+ output_attentions=False,
493
+ ):
494
+ normed_hidden_states = self.layer_norm(hidden_states)
495
+ attention_output = self.SelfAttention(
496
+ normed_hidden_states,
497
+ mask=attention_mask,
498
+ position_bias=position_bias,
499
+ layer_head_mask=layer_head_mask,
500
+ past_key_value=past_key_value,
501
+ use_cache=use_cache,
502
+ output_attentions=output_attentions,
503
+ )
504
+
505
+ hidden_states = hidden_states + self.dropout(attention_output[0])
506
+ outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
507
+ return outputs
508
+
509
+
510
+ class T5LayerCrossAttention(nn.Module):
511
+ def __init__(self, config):
512
+ super().__init__()
513
+ self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False)
514
+ self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
515
+ self.dropout = nn.Dropout(config.dropout_rate)
516
+
517
+ def forward(
518
+ self,
519
+ hidden_states,
520
+ key_value_states,
521
+ attention_mask=None,
522
+ position_bias=None,
523
+ layer_head_mask=None,
524
+ past_key_value=None,
525
+ use_cache=False,
526
+ query_length=None,
527
+ output_attentions=False,
528
+ ):
529
+
530
+ normed_hidden_states = self.layer_norm(hidden_states)
531
+ attention_output = self.EncDecAttention(
532
+ normed_hidden_states,
533
+ mask=attention_mask,
534
+ key_value_states=key_value_states,
535
+ position_bias=position_bias,
536
+ layer_head_mask=layer_head_mask,
537
+ past_key_value=past_key_value,
538
+ use_cache=use_cache,
539
+ query_length=query_length,
540
+ output_attentions=output_attentions,
541
+ )
542
+ layer_output = hidden_states + self.dropout(attention_output[0])
543
+ outputs = (layer_output,) + attention_output[1:] # add attentions if we output them
544
+ return outputs
545
+
546
+
547
+ class T5Block(nn.Module):
548
+ def __init__(self, config, has_relative_attention_bias=False):
549
+ super().__init__()
550
+ self.is_decoder = config.is_decoder
551
+ self.layer = nn.ModuleList()
552
+ self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias))
553
+ if self.is_decoder:
554
+ self.layer.append(T5LayerCrossAttention(config))
555
+
556
+ self.layer.append(T5LayerFF(config))
557
+
558
+ def forward(
559
+ self,
560
+ hidden_states,
561
+ attention_mask=None,
562
+ position_bias=None,
563
+ encoder_hidden_states=None,
564
+ encoder_attention_mask=None,
565
+ encoder_decoder_position_bias=None,
566
+ layer_head_mask=None,
567
+ cross_attn_layer_head_mask=None,
568
+ past_key_value=None,
569
+ use_cache=False,
570
+ output_attentions=False,
571
+ return_dict=True,
572
+ ):
573
+ if past_key_value is not None:
574
+ if not self.is_decoder:
575
+ logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.")
576
+ expected_num_past_key_values = 2 if encoder_hidden_states is None else 4
577
+
578
+ if len(past_key_value) != expected_num_past_key_values:
579
+ raise ValueError(
580
+ f"There should be {expected_num_past_key_values} past states. "
581
+ f"{'2 (key / value) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
582
+ f"Got {len(past_key_value)} past key / value states"
583
+ )
584
+
585
+ self_attn_past_key_value = past_key_value[:2]
586
+ cross_attn_past_key_value = past_key_value[2:]
587
+ else:
588
+ self_attn_past_key_value, cross_attn_past_key_value = None, None
589
+
590
+ self_attention_outputs = self.layer[0](
591
+ hidden_states,
592
+ attention_mask=attention_mask,
593
+ position_bias=position_bias,
594
+ layer_head_mask=layer_head_mask,
595
+ past_key_value=self_attn_past_key_value,
596
+ use_cache=use_cache,
597
+ output_attentions=output_attentions,
598
+ )
599
+ hidden_states, present_key_value_state = self_attention_outputs[:2]
600
+ attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights
601
+
602
+ # clamp inf values to enable fp16 training
603
+ if hidden_states.dtype == torch.float16:
604
+ clamp_value = torch.where(
605
+ torch.isinf(hidden_states).any(),
606
+ torch.finfo(hidden_states.dtype).max - 1000,
607
+ torch.finfo(hidden_states.dtype).max,
608
+ )
609
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
610
+
611
+ do_cross_attention = self.is_decoder and encoder_hidden_states is not None
612
+ if do_cross_attention:
613
+ # the actual query length is unknown for cross attention
614
+ # if using past key value states. Need to inject it here
615
+ if present_key_value_state is not None:
616
+ query_length = present_key_value_state[0].shape[2]
617
+ else:
618
+ query_length = None
619
+
620
+ cross_attention_outputs = self.layer[1](
621
+ hidden_states,
622
+ key_value_states=encoder_hidden_states,
623
+ attention_mask=encoder_attention_mask,
624
+ position_bias=encoder_decoder_position_bias,
625
+ layer_head_mask=cross_attn_layer_head_mask,
626
+ past_key_value=cross_attn_past_key_value,
627
+ query_length=query_length,
628
+ use_cache=use_cache,
629
+ output_attentions=output_attentions,
630
+ )
631
+ hidden_states = cross_attention_outputs[0]
632
+
633
+ # clamp inf values to enable fp16 training
634
+ if hidden_states.dtype == torch.float16:
635
+ clamp_value = torch.where(
636
+ torch.isinf(hidden_states).any(),
637
+ torch.finfo(hidden_states.dtype).max - 1000,
638
+ torch.finfo(hidden_states.dtype).max,
639
+ )
640
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
641
+
642
+ # Combine self attn and cross attn key value states
643
+ if present_key_value_state is not None:
644
+ present_key_value_state = present_key_value_state + cross_attention_outputs[1]
645
+
646
+ # Keep cross-attention outputs and relative position weights
647
+ attention_outputs = attention_outputs + cross_attention_outputs[2:]
648
+
649
+ # Apply Feed Forward layer
650
+ hidden_states = self.layer[-1](hidden_states)
651
+
652
+ # clamp inf values to enable fp16 training
653
+ if hidden_states.dtype == torch.float16:
654
+ clamp_value = torch.where(
655
+ torch.isinf(hidden_states).any(),
656
+ torch.finfo(hidden_states.dtype).max - 1000,
657
+ torch.finfo(hidden_states.dtype).max,
658
+ )
659
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
660
+
661
+ outputs = (hidden_states,)
662
+
663
+ if use_cache:
664
+ outputs = outputs + (present_key_value_state,) + attention_outputs
665
+ else:
666
+ outputs = outputs + attention_outputs
667
+
668
+ return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
669
+
670
+
671
+ class T5ClassificationHead(nn.Module):
672
+ """Head for sentence-level classification tasks."""
673
+
674
+ def __init__(self, config: T5MIMOconvConfig):
675
+ super().__init__()
676
+ self.dense = nn.Linear(config.d_model, config.d_model)
677
+ self.dropout = nn.Dropout(p=config.classifier_dropout)
678
+ self.out_proj = nn.Linear(config.d_model, config.num_labels)
679
+
680
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
681
+ hidden_states = self.dropout(hidden_states)
682
+ hidden_states = self.dense(hidden_states)
683
+ hidden_states = torch.tanh(hidden_states)
684
+ hidden_states = self.dropout(hidden_states)
685
+ hidden_states = self.out_proj(hidden_states)
686
+ return hidden_states
687
+
688
+
689
+ class T5PreTrainedModel(PreTrainedModel):
690
+ """
691
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
692
+ models.
693
+ """
694
+
695
+ config_class = T5MIMOconvConfig
696
+ base_model_prefix = "transformer"
697
+ is_parallelizable = True
698
+ supports_gradient_checkpointing = True
699
+ _no_split_modules = ["T5Block"]
700
+ _keep_in_fp32_modules = ["wo"]
701
+
702
+ @property
703
+ def dummy_inputs(self):
704
+ input_ids = torch.tensor(DUMMY_INPUTS)
705
+ input_mask = torch.tensor(DUMMY_MASK)
706
+ dummy_inputs = {
707
+ "decoder_input_ids": input_ids,
708
+ "input_ids": input_ids,
709
+ "decoder_attention_mask": input_mask,
710
+ }
711
+ return dummy_inputs
712
+
713
+ def _init_weights(self, module):
714
+ """Initialize the weights"""
715
+ factor = self.config.initializer_factor # Used for testing weights initialization
716
+ if isinstance(module, T5LayerNorm):
717
+ module.weight.data.fill_(factor * 1.0)
718
+ elif isinstance(
719
+ module,
720
+ (T5MIMOconvModel, T5MIMOconvForConditionalGeneration, T5MIMOEncoderModel),
721
+ ):
722
+ # Mesh TensorFlow embeddings initialization
723
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
724
+ module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
725
+ if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
726
+ module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0)
727
+ if hasattr(module, "qa_outputs"):
728
+ module.qa_outputs.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
729
+ module.qa_outputs.bias.data.zero_()
730
+ elif isinstance(module, T5ClassificationHead):
731
+ module.dense.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
732
+ if hasattr(module.dense, "bias") and module.dense.bias is not None:
733
+ module.dense.bias.data.zero_()
734
+ module.out_proj.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
735
+ if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None:
736
+ module.out_proj.bias.data.zero_()
737
+ elif isinstance(module, T5DenseActDense):
738
+ # Mesh TensorFlow FF initialization
739
+ # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
740
+ # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
741
+ module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
742
+ if hasattr(module.wi, "bias") and module.wi.bias is not None:
743
+ module.wi.bias.data.zero_()
744
+ module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
745
+ if hasattr(module.wo, "bias") and module.wo.bias is not None:
746
+ module.wo.bias.data.zero_()
747
+ elif isinstance(module, T5DenseGatedActDense):
748
+ module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
749
+ if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
750
+ module.wi_0.bias.data.zero_()
751
+ module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
752
+ if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None:
753
+ module.wi_1.bias.data.zero_()
754
+ module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
755
+ if hasattr(module.wo, "bias") and module.wo.bias is not None:
756
+ module.wo.bias.data.zero_()
757
+ elif isinstance(module, T5Attention):
758
+ # Mesh TensorFlow attention initialization to avoid scaling before softmax
759
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
760
+ d_model = self.config.d_model
761
+ key_value_proj_dim = self.config.d_kv
762
+ n_heads = self.config.num_heads
763
+ module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
764
+ module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
765
+ module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
766
+ module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
767
+ if module.has_relative_attention_bias:
768
+ module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
769
+
770
+ def _shift_right(self, input_ids):
771
+ decoder_start_token_id = self.config.decoder_start_token_id
772
+ pad_token_id = self.config.pad_token_id
773
+
774
+ if decoder_start_token_id is None:
775
+ raise ValueError(
776
+ "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. "
777
+ "See T5 docs for more information."
778
+ )
779
+
780
+ # shift inputs to the right
781
+ if is_torch_fx_proxy(input_ids):
782
+ # Item assignment is not supported natively for proxies.
783
+ shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
784
+ shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
785
+ else:
786
+ shifted_input_ids = input_ids.new_zeros(input_ids.shape)
787
+ shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
788
+ shifted_input_ids[..., 0] = decoder_start_token_id
789
+
790
+ if pad_token_id is None:
791
+ raise ValueError("self.model.config.pad_token_id has to be defined.")
792
+ # replace possible -100 values in labels by `pad_token_id`
793
+ shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
794
+
795
+ return shifted_input_ids
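A quick sketch of what `_shift_right` produces for the decoder (illustrative; `decoder_start_token_id=0` and `pad_token_id=0` as in this repo's config):

```python
# Illustration: labels are shifted one position to the right, the start token is
# prepended, and ignored label positions (-100) become the pad token.
import torch

labels = torch.tensor([[42, 7, 13, -100]])
decoder_start_token_id, pad_token_id = 0, 0

shifted = labels.new_zeros(labels.shape)
shifted[..., 1:] = labels[..., :-1].clone()
shifted[..., 0] = decoder_start_token_id
shifted.masked_fill_(shifted == -100, pad_token_id)

print(shifted.tolist())  # [[0, 42, 7, 13]]
```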
796
+
797
+
798
+ class T5Stack(T5PreTrainedModel):
799
+ def __init__(self, config, embed_tokens=None):
800
+ super().__init__(config)
801
+
802
+ self.embed_tokens = embed_tokens
803
+ self.is_decoder = config.is_decoder
804
+
805
+ self.block = nn.ModuleList(
806
+ [T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)]
807
+ )
808
+ self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
809
+ self.dropout = nn.Dropout(config.dropout_rate)
810
+
811
+ # Initialize weights and apply final processing
812
+ self.post_init()
813
+ # Model parallel
814
+ self.model_parallel = False
815
+ self.device_map = None
816
+ self.gradient_checkpointing = False
817
+
818
+ def parallelize(self, device_map=None):
819
+ warnings.warn(
820
+ "`T5Stack.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your model"
821
+ " with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
822
+ " `device_map` but it needs to be a dictionary module_name to device, so for instance {'block.0': 0,"
823
+ " 'block.1': 1, ...}",
824
+ FutureWarning,
825
+ )
826
+ # Check validity of device_map
827
+ self.device_map = (
828
+ get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map
829
+ )
830
+ assert_device_map(self.device_map, len(self.block))
831
+ self.model_parallel = True
832
+ self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
833
+ self.last_device = "cuda:" + str(max(self.device_map.keys()))
834
+ # Load onto devices
835
+ for k, v in self.device_map.items():
836
+ for layer in v:
837
+ cuda_device = "cuda:" + str(k)
838
+ self.block[layer] = self.block[layer].to(cuda_device)
839
+
840
+ # Set embed_tokens to first layer
841
+ self.embed_tokens = self.embed_tokens.to(self.first_device)
842
+ # Set final layer norm to last device
843
+ self.final_layer_norm = self.final_layer_norm.to(self.last_device)
844
+
845
+
846
+ def deparallelize(self):
847
+ warnings.warn(
848
+ "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
849
+ FutureWarning,
850
+ )
851
+ self.model_parallel = False
852
+ self.device_map = None
853
+ self.first_device = "cpu"
854
+ self.last_device = "cpu"
855
+ for i in range(len(self.block)):
856
+ self.block[i] = self.block[i].to("cpu")
857
+ self.embed_tokens = self.embed_tokens.to("cpu")
858
+ self.final_layer_norm = self.final_layer_norm.to("cpu")
859
+ torch.cuda.empty_cache()
860
+
861
+ def get_input_embeddings(self):
862
+ return self.embed_tokens
863
+
864
+ def set_input_embeddings(self, new_embeddings):
865
+ self.embed_tokens = new_embeddings
866
+
867
+ def forward(
868
+ self,
869
+ input_ids=None,
870
+ attention_mask=None,
871
+ encoder_hidden_states=None,
872
+ encoder_attention_mask=None,
873
+ inputs_embeds=None,
874
+ head_mask=None,
875
+ cross_attn_head_mask=None,
876
+ past_key_values=None,
877
+ use_cache=None,
878
+ output_attentions=None,
879
+ output_hidden_states=None,
880
+ return_dict=None,
881
+ ):
882
+ # Model parallel
883
+ if self.model_parallel:
884
+ torch.cuda.set_device(self.first_device)
885
+ self.embed_tokens = self.embed_tokens.to(self.first_device)
886
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
887
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
888
+ output_hidden_states = (
889
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
890
+ )
891
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
892
+
893
+ if input_ids is not None and inputs_embeds is not None:
894
+ err_msg_prefix = "decoder_" if self.is_decoder else ""
895
+ raise ValueError(
896
+ f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
897
+ )
898
+ elif input_ids is not None:
899
+ input_shape = input_ids.size()
900
+ # input_ids = input_ids.view(-1, input_shape[-1])
901
+ elif inputs_embeds is not None:
902
+ input_shape = inputs_embeds.size()[:-1]
903
+ else:
904
+ err_msg_prefix = "decoder_" if self.is_decoder else ""
905
+ raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
906
+
907
+ if inputs_embeds is None:
908
+ if self.embed_tokens is None:
909
+ raise ValueError("You have to initialize the model with valid token embeddings")
910
+ inputs_embeds = self.embed_tokens(input_ids)
911
+
912
+ if len(input_shape) == 3:
913
+ batch_size, multivar_seqs ,seq_length = input_shape
914
+ else:
915
+ batch_size, seq_length = input_shape
916
+
917
+ # required mask seq length can be calculated via length of past
918
+ mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
919
+
920
+ if use_cache is True:
921
+ if not self.is_decoder:
922
+ raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder")
923
+
924
+ # initialize past_key_values with `None` if past does not exist
925
+ if past_key_values is None:
926
+ past_key_values = [None] * len(self.block)
927
+
928
+ if attention_mask is None:
929
+ if len(input_shape) == 2:
930
+ attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
931
+ else:
932
+ attention_mask = torch.ones(batch_size, multivar_seqs, mask_seq_length, device=inputs_embeds.device)
933
+
934
+
935
+
936
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
937
+ # ourselves in which case we just need to make it broadcastable to all heads.
938
+ if len(input_shape) == 2:
939
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
940
+ else:
941
+ extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
942
+ # permute from [batch_size, 1, multivar_seqs, seq_length] to [batch_size, multivar_seqs, 1, seq_length]
943
+ extended_attention_mask = extended_attention_mask.permute(0, 2, 1, 3)
944
+ # Now make it [batch_size, multivar_seqs, 1, 1, seq_length]
945
+ extended_attention_mask = extended_attention_mask.unsqueeze(3)
946
+
947
+ # If a 2D or 3D attention mask is provided for the cross-attention
948
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
949
+ if self.is_decoder and encoder_hidden_states is not None:
950
+
951
+ if len(encoder_hidden_states.size()) == 3 :
952
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
953
+ else:
954
+ encoder_batch_size, multivar_dem, encoder_sequence_length, _ = encoder_hidden_states.size()
955
+
956
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
957
+ if encoder_attention_mask is None:
958
+ encoder_attention_mask = torch.ones(
959
+ encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long
960
+ )
961
+ if len(input_shape) == 2:
962
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
963
+ else:
964
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
965
+ multivar_dim = extended_attention_mask.shape[1]
966
+ encoder_extended_attention_mask = encoder_extended_attention_mask.unsqueeze(1)
967
+ encoder_extended_attention_mask = encoder_extended_attention_mask.permute(0, 3, 1, 2, 4)
968
+
969
+ else:
970
+ encoder_extended_attention_mask = None
971
+
972
+
973
+
974
+ if self.gradient_checkpointing and self.training:
975
+ if use_cache:
976
+ logger.warning_once(
977
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
978
+ )
979
+ use_cache = False
980
+
981
+ # Prepare head mask if needed
982
+ head_mask = self.get_head_mask(head_mask, self.config.num_layers)
983
+ cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
984
+ present_key_value_states = () if use_cache else None
985
+ all_hidden_states = () if output_hidden_states else None
986
+ all_attentions = () if output_attentions else None
987
+ all_cross_attentions = () if (output_attentions and self.is_decoder) else None
988
+ position_bias = None
989
+ encoder_decoder_position_bias = None
990
+
991
+ hidden_states = self.dropout(inputs_embeds)
992
+
993
+ for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
994
+ layer_head_mask = head_mask[i]
995
+ cross_attn_layer_head_mask = cross_attn_head_mask[i]
996
+ # Model parallel
997
+ if self.model_parallel:
998
+ torch.cuda.set_device(hidden_states.device)
999
+ # Ensure that attention_mask is always on the same device as hidden_states
1000
+ if attention_mask is not None:
1001
+ attention_mask = attention_mask.to(hidden_states.device)
1002
+ if position_bias is not None:
1003
+ position_bias = position_bias.to(hidden_states.device)
1004
+ if encoder_hidden_states is not None:
1005
+ encoder_hidden_states = encoder_hidden_states.to(hidden_states.device)
1006
+ if encoder_extended_attention_mask is not None:
1007
+ encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
1008
+ if encoder_decoder_position_bias is not None:
1009
+ encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
1010
+ if layer_head_mask is not None:
1011
+ layer_head_mask = layer_head_mask.to(hidden_states.device)
1012
+ if cross_attn_layer_head_mask is not None:
1013
+ cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device)
1014
+ if output_hidden_states:
1015
+ all_hidden_states = all_hidden_states + (hidden_states,)
1016
+
1017
+ if self.gradient_checkpointing and self.training:
1018
+ layer_outputs = self._gradient_checkpointing_func(
1019
+ layer_module.forward,
1020
+ hidden_states,
1021
+ extended_attention_mask,
1022
+ position_bias,
1023
+ encoder_hidden_states,
1024
+ encoder_extended_attention_mask,
1025
+ encoder_decoder_position_bias,
1026
+ layer_head_mask,
1027
+ cross_attn_layer_head_mask,
1028
+ None, # past_key_value is always None with gradient checkpointing
1029
+ use_cache,
1030
+ output_attentions,
1031
+ )
1032
+ else:
1033
+ layer_outputs = layer_module(
1034
+ hidden_states,
1035
+ attention_mask=extended_attention_mask,
1036
+ position_bias=position_bias,
1037
+ encoder_hidden_states=encoder_hidden_states,
1038
+ encoder_attention_mask=encoder_extended_attention_mask,
1039
+ encoder_decoder_position_bias=encoder_decoder_position_bias,
1040
+ layer_head_mask=layer_head_mask,
1041
+ cross_attn_layer_head_mask=cross_attn_layer_head_mask,
1042
+ past_key_value=past_key_value,
1043
+ use_cache=use_cache,
1044
+ output_attentions=output_attentions,
1045
+ )
1046
+
1047
+ # layer_outputs is a tuple with:
1048
+ # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
1049
+ if use_cache is False:
1050
+ layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
1051
+
1052
+ hidden_states, present_key_value_state = layer_outputs[:2]
1053
+
1054
+ # We share the position biases between the layers - the first layer store them
1055
+ # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
1056
+ # (cross-attention position bias), (cross-attention weights)
1057
+ position_bias = layer_outputs[2]
1058
+ if self.is_decoder and encoder_hidden_states is not None:
1059
+ encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
1060
+ # append next layer key value states
1061
+ if use_cache:
1062
+ present_key_value_states = present_key_value_states + (present_key_value_state,)
1063
+
1064
+ if output_attentions:
1065
+ all_attentions = all_attentions + (layer_outputs[3],)
1066
+ if self.is_decoder:
1067
+ all_cross_attentions = all_cross_attentions + (layer_outputs[5],)
1068
+
1069
+ # Model Parallel: If it's the last layer for that device, put things on the next device
1070
+ if self.model_parallel:
1071
+ for k, v in self.device_map.items():
1072
+ if i == v[-1] and "cuda:" + str(k) != self.last_device:
1073
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
1074
+
1075
+ hidden_states = self.final_layer_norm(hidden_states)
1076
+ hidden_states = self.dropout(hidden_states)
1077
+
1078
+ # Add last layer
1079
+ if output_hidden_states:
1080
+ all_hidden_states = all_hidden_states + (hidden_states,)
1081
+
1082
+ if not return_dict:
1083
+ return tuple(
1084
+ v
1085
+ for v in [
1086
+ hidden_states,
1087
+ present_key_value_states,
1088
+ all_hidden_states,
1089
+ all_attentions,
1090
+ all_cross_attentions,
1091
+ ]
1092
+ if v is not None
1093
+ )
1094
+ return BaseModelOutputWithPastAndCrossAttentions(
1095
+ last_hidden_state=hidden_states,
1096
+ past_key_values=present_key_value_states,
1097
+ hidden_states=all_hidden_states,
1098
+ attentions=all_attentions,
1099
+ cross_attentions=all_cross_attentions,
1100
+ )
1101
+
1102
+
1103
+
1104
+ class T5MIMOconvModel(T5PreTrainedModel):
1105
+ config_class = T5MIMOconvConfig
1106
+
1107
+ _keys_to_ignore_on_load_unexpected = [
1108
+ "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
1109
+ ]
1110
+ _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
1111
+
1112
+ def __init__(self, config: T5MIMOconvConfig):
1113
+ super().__init__(config)
1114
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
1115
+
1116
+ encoder_config = copy.deepcopy(config)
1117
+ encoder_config.is_decoder = False
1118
+ encoder_config.use_cache = False
1119
+ encoder_config.is_encoder_decoder = False
1120
+ self.encoder = T5Stack(encoder_config, self.shared)
1121
+
1122
+ decoder_config = copy.deepcopy(config)
1123
+ decoder_config.is_decoder = True
1124
+ decoder_config.is_encoder_decoder = False
1125
+ decoder_config.num_layers = config.num_decoder_layers
1126
+ self.decoder = T5Stack(decoder_config, self.shared)
1127
+
1128
+ self.conv_block = MultivariateConvBlock(config)
1129
+
1130
+ # Initialize weights and apply final processing
1131
+ self.post_init()
1132
+
1133
+ # Model parallel
1134
+ self.model_parallel = False
1135
+ self.device_map = None
1136
+
1137
+
1138
+ def parallelize(self, device_map=None):
1139
+ warnings.warn(
1140
+ "`T5Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your model"
1141
+ " with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
1142
+ " `device_map` but it needs to be a dictionary module_name to device, so for instance {'encoder.block.0':"
1143
+ " 0, 'encoder.block.1': 1, ...}",
1144
+ FutureWarning,
1145
+ )
1146
+ self.device_map = (
1147
+ get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
1148
+ if device_map is None
1149
+ else device_map
1150
+ )
1151
+ assert_device_map(self.device_map, len(self.encoder.block))
1152
+ self.encoder.parallelize(self.device_map)
1153
+ self.decoder.parallelize(self.device_map)
1154
+ self.model_parallel = True
1155
+
1156
+
1157
+ def deparallelize(self):
1158
+ warnings.warn(
1159
+ "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
1160
+ FutureWarning,
1161
+ )
1162
+ self.encoder.deparallelize()
1163
+ self.decoder.deparallelize()
1164
+ self.encoder = self.encoder.to("cpu")
1165
+ self.decoder = self.decoder.to("cpu")
1166
+ self.model_parallel = False
1167
+ self.device_map = None
1168
+ torch.cuda.empty_cache()
1169
+
1170
+ def get_input_embeddings(self):
1171
+ return self.shared
1172
+
1173
+ def set_input_embeddings(self, new_embeddings):
1174
+ self.shared = new_embeddings
1175
+ self.encoder.set_input_embeddings(new_embeddings)
1176
+ self.decoder.set_input_embeddings(new_embeddings)
1177
+
1178
+ def _tie_weights(self):
1179
+ if self.config.tie_word_embeddings:
1180
+ self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
1181
+ self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
1182
+
1183
+ def get_encoder(self):
1184
+ return self.encoder
1185
+
1186
+ def get_decoder(self):
1187
+ return self.decoder
1188
+
1189
+ def _prune_heads(self, heads_to_prune):
1190
+ """
1191
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
1192
+ class PreTrainedModel
1193
+ """
1194
+ for layer, heads in heads_to_prune.items():
1195
+ self.encoder.layer[layer].attention.prune_heads(heads)
1196
+
1197
+ def forward(
1198
+ self,
1199
+ input_ids: Optional[torch.LongTensor] = None,
1200
+ attention_mask: Optional[torch.FloatTensor] = None,
1201
+ decoder_input_ids: Optional[torch.LongTensor] = None,
1202
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
1203
+ head_mask: Optional[torch.FloatTensor] = None,
1204
+ decoder_head_mask: Optional[torch.FloatTensor] = None,
1205
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1206
+ encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1207
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
1208
+ inputs_embeds: Optional[torch.Tensor] = None,
1209
+ decoder_inputs_embeds: Optional[torch.Tensor] = None,
1210
+ use_cache: Optional[bool] = None,
1211
+ output_attentions: Optional[bool] = None,
1212
+ output_hidden_states: Optional[bool] = None,
1213
+ return_dict: Optional[bool] = None,
1214
+ ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
1215
+ r"""
1216
+ Returns:
1217
+
1218
+ Example:
1219
+
1220
+ ```python
1221
+ >>> from transformers import AutoTokenizer, T5Model
1222
+
1223
+ >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
1224
+ >>> model = T5Model.from_pretrained("google-t5/t5-small")
1225
+
1226
+ >>> input_ids = tokenizer(
1227
+ ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
1228
+ ... ).input_ids # Batch size 1
1229
+ >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
1230
+
1231
+ >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for T5Model.
1232
+ >>> # This is not needed for torch's T5ForConditionalGeneration as it does this internally using labels arg.
1233
+ >>> decoder_input_ids = model._shift_right(decoder_input_ids)
1234
+
1235
+ >>> # forward pass
1236
+ >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
1237
+ >>> last_hidden_states = outputs.last_hidden_state
1238
+ ```"""
1239
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1240
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1241
+
1242
+ # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
1243
+ if head_mask is not None and decoder_head_mask is None:
1244
+ if self.config.num_layers == self.config.num_decoder_layers:
1245
+ decoder_head_mask = head_mask
1246
+
1247
+ # Encode if needed (training, first prediction pass)
1248
+ if encoder_outputs is None:
1249
+ encoder_outputs = self.encoder(
1250
+ input_ids=input_ids,
1251
+ attention_mask=attention_mask,
1252
+ inputs_embeds=inputs_embeds,
1253
+ head_mask=head_mask,
1254
+ output_attentions=output_attentions,
1255
+ output_hidden_states=output_hidden_states,
1256
+ return_dict=return_dict,
1257
+ )
1258
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
1259
+ encoder_outputs = BaseModelOutput(
1260
+ last_hidden_state=encoder_outputs[0],
1261
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
1262
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
1263
+ )
1264
+
1265
+ hidden_states = encoder_outputs[0]
1266
+
1267
+ # Set device for model parallelism
1268
+ if self.model_parallel:
1269
+ torch.cuda.set_device(self.decoder.first_device)
1270
+ hidden_states = hidden_states.to(self.decoder.first_device)
1271
+ if decoder_input_ids is not None:
1272
+ decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
1273
+ if attention_mask is not None:
1274
+ attention_mask = attention_mask.to(self.decoder.first_device)
1275
+ if decoder_attention_mask is not None:
1276
+ decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
1277
+
1278
+ # Decode
1279
+ decoder_outputs = self.decoder(
1280
+ input_ids=decoder_input_ids,
1281
+ attention_mask=decoder_attention_mask,
1282
+ inputs_embeds=decoder_inputs_embeds,
1283
+ past_key_values=past_key_values,
1284
+ encoder_hidden_states=hidden_states,
1285
+ encoder_attention_mask=attention_mask,
1286
+ head_mask=decoder_head_mask,
1287
+ cross_attn_head_mask=cross_attn_head_mask,
1288
+ use_cache=use_cache,
1289
+ output_attentions=output_attentions,
1290
+ output_hidden_states=output_hidden_states,
1291
+ return_dict=return_dict,
1292
+ )
1293
+
1294
+ decoder_outputs = self.conv_block(decoder_outputs)
1295
+
1296
+
1297
+ if not return_dict:
1298
+ return decoder_outputs + encoder_outputs
1299
+
1300
+ return Seq2SeqModelOutput(
1301
+ last_hidden_state=decoder_outputs.last_hidden_state,
1302
+ past_key_values=decoder_outputs.past_key_values,
1303
+ decoder_hidden_states=decoder_outputs.hidden_states,
1304
+ decoder_attentions=decoder_outputs.attentions,
1305
+ cross_attentions=decoder_outputs.cross_attentions,
1306
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
1307
+ encoder_hidden_states=encoder_outputs.hidden_states,
1308
+ encoder_attentions=encoder_outputs.attentions,
1309
+ )
1310
+
1311
+
1312
+
1313
+ class T5MIMOconvForConditionalGeneration(T5PreTrainedModel):
1314
+ config_class = T5MIMOconvConfig
1315
+
1316
+ _keys_to_ignore_on_load_unexpected = [
1317
+ "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
1318
+ ]
1319
+ _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
1320
+
1321
+ def __init__(self, config: T5MIMOconvConfig):
1322
+ super().__init__(config)
1323
+ self.model_dim = config.d_model
1324
+
1325
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
1326
+
1327
+ encoder_config = copy.deepcopy(config)
1328
+ encoder_config.is_decoder = False
1329
+ encoder_config.use_cache = False
1330
+ encoder_config.is_encoder_decoder = False
1331
+ self.encoder = T5Stack(encoder_config, self.shared)
1332
+
1333
+ decoder_config = copy.deepcopy(config)
1334
+ decoder_config.is_decoder = True
1335
+ decoder_config.is_encoder_decoder = False
1336
+ decoder_config.num_layers = config.num_decoder_layers
1337
+ self.decoder = T5Stack(decoder_config, self.shared)
1338
+
1339
+ self.conv_block = MultivariateConvBlock(config)
1340
+
1341
+ self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
1342
+
1343
+ # Initialize weights and apply final processing
1344
+ self.post_init()
1345
+
1346
+ # Model parallel
1347
+ self.model_parallel = False
1348
+ self.device_map = None
1349
+
1350
+
1351
+ def parallelize(self, device_map=None):
1352
+ warnings.warn(
1353
+ "`T5ForConditionalGeneration.parallelize` is deprecated and will be removed in v5 of Transformers, you"
1354
+ " should load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also"
1355
+ " provide your own `device_map` but it needs to be a dictionary module_name to device, so for instance"
1356
+ " {'encoder.block.0': 0, 'encoder.block.1': 1, ...}",
1357
+ FutureWarning,
1358
+ )
1359
+ self.device_map = (
1360
+ get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
1361
+ if device_map is None
1362
+ else device_map
1363
+ )
1364
+ assert_device_map(self.device_map, len(self.encoder.block))
1365
+ self.encoder.parallelize(self.device_map)
1366
+ self.decoder.parallelize(self.device_map)
1367
+ self.lm_head = self.lm_head.to(self.decoder.first_device)
1368
+ self.model_parallel = True
1369
+
1370
+
1371
+ def deparallelize(self):
1372
+ warnings.warn(
1373
+ "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
1374
+ FutureWarning,
1375
+ )
1376
+ self.encoder.deparallelize()
1377
+ self.decoder.deparallelize()
1378
+ self.encoder = self.encoder.to("cpu")
1379
+ self.decoder = self.decoder.to("cpu")
1380
+ self.lm_head = self.lm_head.to("cpu")
1381
+ self.model_parallel = False
1382
+ self.device_map = None
1383
+ torch.cuda.empty_cache()
1384
+
1385
+ def get_input_embeddings(self):
1386
+ return self.shared
1387
+
1388
+ def set_input_embeddings(self, new_embeddings):
1389
+ self.shared = new_embeddings
1390
+ self.encoder.set_input_embeddings(new_embeddings)
1391
+ self.decoder.set_input_embeddings(new_embeddings)
1392
+
1393
+ def _tie_weights(self):
1394
+ if self.config.tie_word_embeddings:
1395
+ self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
1396
+ self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
1397
+
1398
+ def set_output_embeddings(self, new_embeddings):
1399
+ self.lm_head = new_embeddings
1400
+
1401
+ def get_output_embeddings(self):
1402
+ return self.lm_head
1403
+
1404
+ def get_encoder(self):
1405
+ return self.encoder
1406
+
1407
+ def get_decoder(self):
1408
+ return self.decoder
1409
+
1410
+ def forward(
1411
+ self,
1412
+ input_ids: Optional[torch.LongTensor] = None,
1413
+ attention_mask: Optional[torch.FloatTensor] = None,
1414
+ decoder_input_ids: Optional[torch.LongTensor] = None,
1415
+ decoder_attention_mask: Optional[torch.BoolTensor] = None,
1416
+ head_mask: Optional[torch.FloatTensor] = None,
1417
+ decoder_head_mask: Optional[torch.FloatTensor] = None,
1418
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
1419
+ encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1420
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1421
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1422
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
1423
+ labels: Optional[torch.LongTensor] = None,
1424
+ use_cache: Optional[bool] = None,
1425
+ output_attentions: Optional[bool] = None,
1426
+ output_hidden_states: Optional[bool] = None,
1427
+ return_dict: Optional[bool] = None,
1428
+ ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
1429
+ r"""
1430
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1431
+ Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
1432
+ config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
1433
+ labels in `[0, ..., config.vocab_size]`
1434
+
1435
+ Returns:
1436
+
1437
+ Examples:
1438
+
1439
+ ```python
1440
+ >>> from transformers import AutoTokenizer, T5ForConditionalGeneration
1441
+
1442
+ >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
1443
+ >>> model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small")
1444
+
1445
+ >>> # training
1446
+ >>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
1447
+ >>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
1448
+ >>> outputs = model(input_ids=input_ids, labels=labels)
1449
+ >>> loss = outputs.loss
1450
+ >>> logits = outputs.logits
1451
+
1452
+ >>> # inference
1453
+ >>> input_ids = tokenizer(
1454
+ ... "summarize: studies have shown that owning a dog is good for you", return_tensors="pt"
1455
+ ... ).input_ids # Batch size 1
1456
+ >>> outputs = model.generate(input_ids)
1457
+ >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
1458
+ >>> # studies have shown that owning a dog is good for you.
1459
+ ```"""
1460
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1461
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1462
+
1463
+ # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
1464
+ if head_mask is not None and decoder_head_mask is None:
1465
+ if self.config.num_layers == self.config.num_decoder_layers:
1466
+ decoder_head_mask = head_mask
1467
+
1468
+ # Encode if needed (training, first prediction pass)
1469
+ if encoder_outputs is None:
1470
+ # Convert encoder inputs in embeddings if needed
1471
+ encoder_outputs = self.encoder(
1472
+ input_ids=input_ids,
1473
+ attention_mask=attention_mask,
1474
+ inputs_embeds=inputs_embeds,
1475
+ head_mask=head_mask,
1476
+ output_attentions=output_attentions,
1477
+ output_hidden_states=output_hidden_states,
1478
+ return_dict=return_dict,
1479
+ )
1480
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
1481
+ encoder_outputs = BaseModelOutput(
1482
+ last_hidden_state=encoder_outputs[0],
1483
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
1484
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
1485
+ )
1486
+
1487
+ hidden_states = encoder_outputs[0]
1488
+
1489
+ if self.model_parallel:
1490
+ torch.cuda.set_device(self.decoder.first_device)
1491
+
1492
+ if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
1493
+ # get decoder inputs from shifting lm labels to the right
1494
+ decoder_input_ids = self._shift_right(labels)
1495
+
1496
+ # Set device for model parallelism
1497
+ if self.model_parallel:
1498
+ torch.cuda.set_device(self.decoder.first_device)
1499
+ hidden_states = hidden_states.to(self.decoder.first_device)
1500
+ if decoder_input_ids is not None:
1501
+ decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
1502
+ if attention_mask is not None:
1503
+ attention_mask = attention_mask.to(self.decoder.first_device)
1504
+ if decoder_attention_mask is not None:
1505
+ decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
1506
+
1507
+ # Decode
1508
+ decoder_outputs = self.decoder(
1509
+ input_ids=decoder_input_ids,
1510
+ attention_mask=decoder_attention_mask,
1511
+ inputs_embeds=decoder_inputs_embeds,
1512
+ past_key_values=past_key_values,
1513
+ encoder_hidden_states=hidden_states,
1514
+ encoder_attention_mask=attention_mask,
1515
+ head_mask=decoder_head_mask,
1516
+ cross_attn_head_mask=cross_attn_head_mask,
1517
+ use_cache=use_cache,
1518
+ output_attentions=output_attentions,
1519
+ output_hidden_states=output_hidden_states,
1520
+ return_dict=return_dict,
1521
+ )
1522
+
1523
+ sequence_output = decoder_outputs[0]
1524
+
1525
+ sequence_output = self.conv_block(sequence_output)
1526
+
1527
+
1528
+ # Set device for model parallelism
1529
+ if self.model_parallel:
1530
+ torch.cuda.set_device(self.encoder.first_device)
1531
+ self.lm_head = self.lm_head.to(self.encoder.first_device)
1532
+ sequence_output = sequence_output.to(self.lm_head.weight.device)
1533
+
1534
+ if self.config.tie_word_embeddings:
1535
+ # Rescale output before projecting on vocab
1536
+ # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
1537
+ sequence_output = sequence_output * (self.model_dim**-0.5)
1538
+
1539
+ lm_logits = self.lm_head(sequence_output)
1540
+
1541
+ loss = None
1542
+ if labels is not None:
1543
+ loss_fct = CrossEntropyLoss(ignore_index=-100)
1544
+ # move labels to correct device to enable PP
1545
+ labels = labels.to(lm_logits.device)
1546
+ if len(labels.shape) == 2:
1547
+ loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
1548
+ else:
1549
+ loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.reshape(-1))
1550
+ # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
1551
+
1552
+ if not return_dict:
1553
+ output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
1554
+ return ((loss,) + output) if loss is not None else output
1555
+
1556
+ return Seq2SeqLMOutput(
1557
+ loss=loss,
1558
+ logits=lm_logits,
1559
+ past_key_values=decoder_outputs.past_key_values,
1560
+ decoder_hidden_states=decoder_outputs.hidden_states,
1561
+ decoder_attentions=decoder_outputs.attentions,
1562
+ cross_attentions=decoder_outputs.cross_attentions,
1563
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
1564
+ encoder_hidden_states=encoder_outputs.hidden_states,
1565
+ encoder_attentions=encoder_outputs.attentions,
1566
+ )
1567
+
1568
+ def prepare_inputs_for_generation(
1569
+ self,
1570
+ input_ids,
1571
+ past_key_values=None,
1572
+ attention_mask=None,
1573
+ head_mask=None,
1574
+ decoder_head_mask=None,
1575
+ decoder_attention_mask=None,
1576
+ cross_attn_head_mask=None,
1577
+ use_cache=None,
1578
+ encoder_outputs=None,
1579
+ **kwargs,
1580
+ ):
1581
+ # cut decoder_input_ids if past_key_values is used
1582
+ if past_key_values is not None:
1583
+ past_length = past_key_values[0][0].shape[2]
1584
+
1585
+ # Some generation methods already pass only the last input ID
1586
+ if input_ids.shape[1] > past_length:
1587
+ remove_prefix_length = past_length
1588
+ else:
1589
+ # Default to old behavior: keep only final ID
1590
+ remove_prefix_length = input_ids.shape[1] - 1
1591
+
1592
+ input_ids = input_ids[:, remove_prefix_length:]
1593
+
1594
+ return {
1595
+ "decoder_input_ids": input_ids,
1596
+ "past_key_values": past_key_values,
1597
+ "encoder_outputs": encoder_outputs,
1598
+ "attention_mask": attention_mask,
1599
+ "head_mask": head_mask,
1600
+ "decoder_head_mask": decoder_head_mask,
1601
+ "decoder_attention_mask": decoder_attention_mask,
1602
+ "cross_attn_head_mask": cross_attn_head_mask,
1603
+ "use_cache": use_cache,
1604
+ }
1605
+
1606
+ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
1607
+ return self._shift_right(labels)
1608
+
1609
+ def _reorder_cache(self, past_key_values, beam_idx):
1610
+ # if decoder past is not included in output
1611
+ # speedy decoding is disabled and no need to reorder
1612
+ if past_key_values is None:
1613
+ logger.warning("You might want to consider setting `use_cache=True` to speed up decoding")
1614
+ return past_key_values
1615
+
1616
+ reordered_decoder_past = ()
1617
+ for layer_past_states in past_key_values:
1618
+ # get the correct batch idx from layer past batch dim
1619
+ # batch dim of `past` is at 2nd position
1620
+ reordered_layer_past_states = ()
1621
+ for layer_past_state in layer_past_states:
1622
+ # need to set correct `past` for each of the four key / value states
1623
+ reordered_layer_past_states = reordered_layer_past_states + (
1624
+ layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)),
1625
+ )
1626
+
1627
+ if reordered_layer_past_states[0].shape != layer_past_states[0].shape:
1628
+ raise ValueError(
1629
+ f"reordered_layer_past_states[0] shape {reordered_layer_past_states[0].shape} and layer_past_states[0] shape {layer_past_states[0].shape} mismatched"
1630
+ )
1631
+ if len(reordered_layer_past_states) != len(layer_past_states):
1632
+ raise ValueError(
1633
+ f"length of reordered_layer_past_states {len(reordered_layer_past_states)} and length of layer_past_states {len(layer_past_states)} mismatched"
1634
+ )
1635
+
1636
+ reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,)
1637
+ return reordered_decoder_past
1638
+
1639
+
1640
+
1641
+ class T5MIMOEncoderModel(T5PreTrainedModel):
1642
+ _tied_weights_keys = ["encoder.embed_tokens.weight"]
1643
+ _keys_to_ignore_on_load_unexpected = [r"decoder"]
1644
+
1645
+ def __init__(self, config: T5MIMOconvConfig):
1646
+ super().__init__(config)
1647
+ self.shared = nn.Embedding(config.vocab_size, config.d_model)
1648
+
1649
+ encoder_config = copy.deepcopy(config)
1650
+ encoder_config.use_cache = False
1651
+ encoder_config.is_encoder_decoder = False
1652
+ self.encoder = T5Stack(encoder_config, self.shared)
1653
+
1654
+ # Initialize weights and apply final processing
1655
+ self.post_init()
1656
+
1657
+ # Model parallel
1658
+ self.model_parallel = False
1659
+ self.device_map = None
1660
+
1661
+ def parallelize(self, device_map=None):
1662
+ warnings.warn(
1663
+ "`T5EncoderModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load"
1664
+ " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
1665
+ " `device_map` but it needs to be a dictionary module_name to device, so for instance {'block.0': 0,"
1666
+ " 'block.1': 1, ...}",
1667
+ FutureWarning,
1668
+ )
1669
+ self.device_map = (
1670
+ get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
1671
+ if device_map is None
1672
+ else device_map
1673
+ )
1674
+ assert_device_map(self.device_map, len(self.encoder.block))
1675
+ self.encoder.parallelize(self.device_map)
1676
+ self.model_parallel = True
1677
+
1678
+ def deparallelize(self):
1679
+ warnings.warn(
1680
+ "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
1681
+ FutureWarning,
1682
+ )
1683
+ self.encoder.deparallelize()
1684
+ self.encoder = self.encoder.to("cpu")
1685
+ self.model_parallel = False
1686
+ self.device_map = None
1687
+ torch.cuda.empty_cache()
1688
+
1689
+ def get_input_embeddings(self):
1690
+ return self.shared
1691
+
1692
+ def set_input_embeddings(self, new_embeddings):
1693
+ self.shared = new_embeddings
1694
+ self.encoder.set_input_embeddings(new_embeddings)
1695
+
1696
+ def _tie_weights(self):
1697
+ if self.config.tie_word_embeddings:
1698
+ self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
1699
+
1700
+ def get_encoder(self):
1701
+ return self.encoder
1702
+
1703
+ def _prune_heads(self, heads_to_prune):
1704
+ """
1705
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
1706
+ class PreTrainedModel
1707
+ """
1708
+ for layer, heads in heads_to_prune.items():
1709
+ self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)
1710
+
1711
+ def forward(
1712
+ self,
1713
+ input_ids: Optional[torch.LongTensor] = None,
1714
+ attention_mask: Optional[torch.FloatTensor] = None,
1715
+ head_mask: Optional[torch.FloatTensor] = None,
1716
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1717
+ output_attentions: Optional[bool] = None,
1718
+ output_hidden_states: Optional[bool] = None,
1719
+ return_dict: Optional[bool] = None,
1720
+ ) -> Union[Tuple[torch.FloatTensor], BaseModelOutput]:
1721
+ r"""
1722
+ Returns:
1723
+
1724
+ Example:
1725
+
1726
+ ```python
1727
+ >>> from transformers import AutoTokenizer, T5EncoderModel
1728
+
1729
+ >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
1730
+ >>> model = T5EncoderModel.from_pretrained("google-t5/t5-small")
1731
+ >>> input_ids = tokenizer(
1732
+ ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
1733
+ ... ).input_ids # Batch size 1
1734
+ >>> outputs = model(input_ids=input_ids)
1735
+ >>> last_hidden_states = outputs.last_hidden_state
1736
+ ```"""
1737
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1738
+
1739
+ encoder_outputs = self.encoder(
1740
+ input_ids=input_ids,
1741
+ attention_mask=attention_mask,
1742
+ inputs_embeds=inputs_embeds,
1743
+ head_mask=head_mask,
1744
+ output_attentions=output_attentions,
1745
+ output_hidden_states=output_hidden_states,
1746
+ return_dict=return_dict,
1747
+ )
1748
+
1749
+ return encoder_outputs
1750
+
1751
+
1752
+