Katsumata420 commited on
Commit
170ce12
1 Parent(s): 5118081

Upload RetrievaBertForMaskedLM

Browse files
README.md ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+ This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
config.json ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "/data/katsumata/sandbox/megatron/checkpoints/tsubame/hf-release/RetrievaBERT-seq2048-iter454k",
3
+ "architectures": [
4
+ "RetrievaBertForMaskedLM"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.1,
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_retrieva_bert.RetrievaBertConfig",
9
+ "AutoModelForMaskedLM": "modeling_retrieva_bert.RetrievaBertForMaskedLM"
10
+ },
11
+ "bos_token_id": 1,
12
+ "eos_token_id": 2,
13
+ "hidden_act": "silu",
14
+ "hidden_dropout_prob": 0.1,
15
+ "hidden_size": 1536,
16
+ "initializer_range": 0.02,
17
+ "intermediate_size": 4096,
18
+ "layer_norm_eps": 1e-12,
19
+ "lm_head_hidden_act": "gelu",
20
+ "max_position_embeddings": 2048,
21
+ "max_sequence_length": 2048,
22
+ "mlp_bias": true,
23
+ "model_type": "retrieva-bert",
24
+ "num_attention_heads": 24,
25
+ "num_hidden_layers": 48,
26
+ "num_key_value_heads": 1,
27
+ "pad_token_id": 4,
28
+ "position_embedding_type": "absolute",
29
+ "rope_theta": 10000.0,
30
+ "rotary_percent": 1.0,
31
+ "torch_dtype": "bfloat16",
32
+ "transformers_version": "4.41.2",
33
+ "type_vocab_size": 0,
34
+ "use_cache": true,
35
+ "vocab_size": 99584
36
+ }
configuration_retrieva_bert.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021- NVIDIA Corporation and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """RetrievaBERT model configuration"""
16
+
17
+ from transformers.configuration_utils import PretrainedConfig
18
+ from transformers.utils import logging
19
+
20
+
21
+ logger = logging.get_logger(__name__)
22
+
23
+
24
+ class RetrievaBertConfig(PretrainedConfig):
25
+ r"""
26
+ This is the configuration class to store the configuration of a [`RetrievaBertModel`]. It is used to instantiate a
27
+ RETRIEVA_BERT model according to the specified arguments, defining the model architecture. Instantiating a
28
+ configuration with the defaults will yield a similar configuration to that of the RETRIEVA_BERT
29
+ [nvidia/megatron-bert-uncased-345m](https://huggingface.co/nvidia/megatron-bert-uncased-345m) architecture.
30
+
31
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
32
+ documentation from [`PretrainedConfig`] for more information.
33
+
34
+
35
+ Args:
36
+ vocab_size (`int`, *optional*, defaults to 29056):
37
+ Vocabulary size of the RETRIEVA_BERT model. Defines the number of different tokens that can be represented
38
+ by the `inputs_ids` passed when calling [`RetrievaBertModel`].
39
+ hidden_size (`int`, *optional*, defaults to 1024):
40
+ Dimensionality of the encoder layers and the pooler layer.
41
+ num_hidden_layers (`int`, *optional*, defaults to 24):
42
+ Number of hidden layers in the Transformer encoder.
43
+ num_attention_heads (`int`, *optional*, defaults to 16):
44
+ Number of attention heads for each attention layer in the Transformer encoder.
45
+ intermediate_size (`int`, *optional*, defaults to 4096):
46
+ Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
47
+ hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
48
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
49
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
50
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
51
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
52
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
53
+ The dropout ratio for the attention probabilities.
54
+ max_position_embeddings (`int`, *optional*, defaults to 512):
55
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
56
+ just in case (e.g., 512 or 1024 or 2048).
57
+ type_vocab_size (`int`, *optional*, defaults to 2):
58
+ The vocabulary size of the `token_type_ids` passed when calling [`RetrievaBertModel`].
59
+ If set 0, `token_type_ids` is not used.
60
+ initializer_range (`float`, *optional*, defaults to 0.02):
61
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
62
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
63
+ The epsilon used by the layer normalization layers.
64
+ position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
65
+ Type of position embedding. Choose one of `"absolute"`, `"rope"`. For
66
+ positional embeddings use `"absolute"`.
67
+ is_decoder (`bool`, *optional*, defaults to `False`):
68
+ Whether the model is used as a decoder or not. If `False`, the model is used as an encoder.
69
+ use_cache (`bool`, *optional*, defaults to `True`):
70
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
71
+ relevant if `config.is_decoder=True`.
72
+
73
+ Examples:
74
+
75
+ ```python
76
+ >>> from models import RetrievaBertConfig, RetrievaBertModel
77
+
78
+ >>> # Initializing a RETRIEVA_BERT google-bert/bert-base-uncased style configuration
79
+ >>> configuration = RetrievaBertConfig()
80
+
81
+ >>> # Initializing a model (with random weights) from the google-bert/bert-base-uncased style configuration
82
+ >>> model = RetrievaBertModel(configuration)
83
+
84
+ >>> # Accessing the model configuration
85
+ >>> configuration = model.config
86
+ ```"""
87
+
88
+ model_type = "retrieva-bert"
89
+
90
+ def __init__(
91
+ self,
92
+ vocab_size=29056,
93
+ hidden_size=1024,
94
+ num_hidden_layers=24,
95
+ num_attention_heads=16,
96
+ intermediate_size=4096,
97
+ hidden_act="silu",
98
+ hidden_dropout_prob=0.1,
99
+ attention_probs_dropout_prob=0.1,
100
+ max_position_embeddings=512,
101
+ type_vocab_size=0,
102
+ initializer_range=0.02,
103
+ layer_norm_eps=1e-12,
104
+ pad_token_id=0,
105
+ position_embedding_type="absolute",
106
+ use_cache=True,
107
+ rope_theta=10000.0,
108
+ rotary_percent=1.0,
109
+ mlp_bias=False,
110
+ num_key_value_heads=None,
111
+ lm_head_hidden_act="gelu",
112
+ **kwargs,
113
+ ):
114
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
115
+
116
+ self.vocab_size = vocab_size
117
+ self.hidden_size = hidden_size
118
+ self.num_hidden_layers = num_hidden_layers
119
+ self.num_attention_heads = num_attention_heads
120
+ self.hidden_act = hidden_act
121
+ self.intermediate_size = intermediate_size
122
+ self.hidden_dropout_prob = hidden_dropout_prob
123
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
124
+ self.max_position_embeddings = max_position_embeddings
125
+ self.type_vocab_size = type_vocab_size
126
+ self.initializer_range = initializer_range
127
+ self.layer_norm_eps = layer_norm_eps
128
+ self.position_embedding_type = position_embedding_type
129
+ self.use_cache = use_cache
130
+ self.rope_theta = rope_theta
131
+ self.rotary_percent = rotary_percent
132
+ self.mlp_bias = mlp_bias
133
+
134
+ if num_key_value_heads is None:
135
+ num_key_value_heads = num_attention_heads
136
+ self.num_key_value_heads = num_key_value_heads
137
+ self.lm_head_hidden_act = lm_head_hidden_act
generation_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 1,
4
+ "eos_token_id": 2,
5
+ "pad_token_id": 4,
6
+ "transformers_version": "4.41.2"
7
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c42e82e5fd0d4fd37e4b158b8669abfc465c5d16483e3e63ffa2fd7616592ad7
3
+ size 2602880000
modeling_retrieva_bert.py ADDED
@@ -0,0 +1,1950 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018-2021, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch RetrievaBERT model.
17
+
18
+ The follwoing are the differences of the original huffingface/MegatronBERT model.
19
+ - Use RoPE instead of absolute position embeddings.
20
+ - Use Grouped Query Attention (GQA) instead of the standard self-attention.
21
+ - Use Swiglu activation function instead of GELU.
22
+
23
+ RoPE implementation is based on the huggingface's Llama and RoFormer model.
24
+ GQA/Swiglu implementation is based on the Llama model.
25
+ https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
26
+ https://github.com/huggingface/transformers/blob/main/src/transformers/models/roformer/modeling_roformer.py
27
+ """
28
+
29
+ import math
30
+ import os
31
+ import warnings
32
+ from dataclasses import dataclass
33
+ from typing import Optional, Tuple, Union
34
+
35
+ import torch
36
+ import torch.utils.checkpoint
37
+ from torch import nn
38
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
39
+
40
+ from transformers.activations import ACT2FN
41
+ from transformers.modeling_outputs import (
42
+ BaseModelOutputWithPastAndCrossAttentions,
43
+ BaseModelOutputWithPoolingAndCrossAttentions,
44
+ CausalLMOutputWithCrossAttentions,
45
+ MaskedLMOutput,
46
+ MultipleChoiceModelOutput,
47
+ NextSentencePredictorOutput,
48
+ QuestionAnsweringModelOutput,
49
+ SequenceClassifierOutput,
50
+ TokenClassifierOutput,
51
+ )
52
+ from transformers.modeling_utils import PreTrainedModel
53
+ from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
54
+ from transformers.utils import (
55
+ ModelOutput,
56
+ add_code_sample_docstrings,
57
+ add_start_docstrings,
58
+ add_start_docstrings_to_model_forward,
59
+ logging,
60
+ replace_return_docstrings,
61
+ )
62
+ from .configuration_retrieva_bert import RetrievaBertConfig
63
+
64
+
65
+ logger = logging.get_logger(__name__)
66
+
67
+ _CONFIG_FOR_DOC = "RetrievaBertConfig"
68
+ _CHECKPOINT_FOR_DOC = "nvidia/megatron-bert-cased-345m"
69
+
70
+
71
+ def load_tf_weights_in_megatron_bert(model, config, tf_checkpoint_path):
72
+ """Load tf checkpoints in a pytorch model."""
73
+ try:
74
+ import re
75
+
76
+ import numpy as np
77
+ import tensorflow as tf
78
+ except ImportError:
79
+ logger.error(
80
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
81
+ "https://www.tensorflow.org/install/ for installation instructions."
82
+ )
83
+ raise
84
+ tf_path = os.path.abspath(tf_checkpoint_path)
85
+ logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
86
+ # Load weights from TF model
87
+ init_vars = tf.train.list_variables(tf_path)
88
+ names = []
89
+ arrays = []
90
+ for name, shape in init_vars:
91
+ logger.info(f"Loading TF weight {name} with shape {shape}")
92
+ array = tf.train.load_variable(tf_path, name)
93
+ names.append(name)
94
+ arrays.append(array)
95
+
96
+ for name, array in zip(names, arrays):
97
+ name = name.split("/")
98
+ # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
99
+ # which are not required for using pretrained model
100
+ if any(
101
+ n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
102
+ for n in name
103
+ ):
104
+ logger.info(f"Skipping {'/'.join(name)}")
105
+ continue
106
+ pointer = model
107
+ for m_name in name:
108
+ if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
109
+ scope_names = re.split(r"_(\d+)", m_name)
110
+ else:
111
+ scope_names = [m_name]
112
+ if scope_names[0] == "kernel" or scope_names[0] == "gamma":
113
+ pointer = getattr(pointer, "weight")
114
+ elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
115
+ pointer = getattr(pointer, "bias")
116
+ elif scope_names[0] == "output_weights":
117
+ pointer = getattr(pointer, "weight")
118
+ elif scope_names[0] == "squad":
119
+ pointer = getattr(pointer, "classifier")
120
+ else:
121
+ try:
122
+ pointer = getattr(pointer, scope_names[0])
123
+ except AttributeError:
124
+ logger.info(f"Skipping {'/'.join(name)}")
125
+ continue
126
+ if len(scope_names) >= 2:
127
+ num = int(scope_names[1])
128
+ pointer = pointer[num]
129
+ if m_name[-11:] == "_embeddings":
130
+ pointer = getattr(pointer, "weight")
131
+ elif m_name == "kernel":
132
+ array = np.transpose(array)
133
+ if pointer.shape != array.shape:
134
+ raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
135
+ logger.info("Initialize PyTorch weight {}".format(name))
136
+ pointer.data = torch.from_numpy(array)
137
+ return model
138
+
139
+
140
+ class RotaryEmbedding(nn.Module):
141
+ """Rotary Embedding for positional encoding."""
142
+
143
+ def __init__(self, hidden_size, max_position_embeddings, theta, rotary_percent=1.0, device=None):
144
+ super().__init__()
145
+ if rotary_percent < 1.0:
146
+ hidden_size = int(hidden_size * rotary_percent)
147
+ self.hidden_size = hidden_size
148
+ self.max_position_embeddings = max_position_embeddings
149
+ self.theta = theta
150
+ inv_freq = 1.0 / (self.theta ** (torch.arange(0, hidden_size, 2, dtype=torch.int64).float().to(device) / hidden_size))
151
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
152
+
153
+ def forward(self, x, position_ids):
154
+ # x: [batch_size, num_attention_heads, seq_len, hidden_size]
155
+ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
156
+ position_ids_expanded = position_ids[:, None, :].float()
157
+ device_type = x.device.type
158
+ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
159
+ with torch.autocast(device_type=device_type, enabled=False):
160
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
161
+ emb = torch.cat((freqs, freqs), dim=-1)
162
+ cos = emb.cos()
163
+ sin = emb.sin()
164
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
165
+
166
+
167
+ def rotate_half(x):
168
+ """Rotates half the hidden dims of the input."""
169
+ x1 = x[..., : x.shape[-1] // 2]
170
+ x2 = x[..., x.shape[-1] // 2 :]
171
+ return torch.cat((-x2, x1), dim=-1)
172
+
173
+
174
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
175
+ """Applies Rotary Position Embedding to the query and key tensors.
176
+
177
+ Args:
178
+ q (`torch.Tensor`): The query tensor.
179
+ k (`torch.Tensor`): The key tensor.
180
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
181
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
182
+ position_ids (`torch.Tensor`, *optional*):
183
+ Deprecated and unused.
184
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
185
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
186
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
187
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
188
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
189
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
190
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
191
+ Returns:
192
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
193
+ """
194
+ cos = cos.unsqueeze(unsqueeze_dim)
195
+ sin = sin.unsqueeze(unsqueeze_dim)
196
+ q_embed = (q * cos) + (rotate_half(q) * sin)
197
+ k_embed = (k * cos) + (rotate_half(k) * sin)
198
+ return q_embed, k_embed
199
+
200
+
201
+ class RetrievaBertEmbeddings(nn.Module):
202
+ """Construct the embeddings from word, position and token_type embeddings."""
203
+
204
+ def __init__(self, config):
205
+ super().__init__()
206
+ self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
207
+ if config.position_embedding_type == "absolute":
208
+ self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
209
+ else:
210
+ self.position_embeddings = None
211
+ if config.type_vocab_size > 0:
212
+ self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
213
+ else:
214
+ self.token_type_embeddings = None
215
+
216
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
217
+ # any TensorFlow checkpoint file
218
+
219
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
220
+
221
+ self.register_buffer(
222
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
223
+ )
224
+ self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
225
+
226
+ def forward(
227
+ self,
228
+ input_ids: Optional[torch.LongTensor] = None,
229
+ token_type_ids: Optional[torch.LongTensor] = None,
230
+ position_ids: Optional[torch.LongTensor] = None,
231
+ inputs_embeds: Optional[torch.LongTensor] = None,
232
+ past_key_values_length: int = 0,
233
+ ) -> torch.Tensor:
234
+ if input_ids is not None:
235
+ input_shape = input_ids.size()
236
+ else:
237
+ input_shape = inputs_embeds.size()[:-1]
238
+
239
+ if inputs_embeds is None:
240
+ inputs_embeds = self.word_embeddings(input_ids)
241
+
242
+ if self.position_embeddings is not None:
243
+ if position_ids is None:
244
+ position_ids = self.position_ids[:, past_key_values_length : past_key_values_length + input_shape[1]]
245
+ position_embeddings = self.position_embeddings(position_ids)
246
+ else:
247
+ position_embeddings = None
248
+
249
+ if self.token_type_embeddings is not None:
250
+ if token_type_ids is None:
251
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
252
+
253
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
254
+ else:
255
+ token_type_embeddings = None
256
+
257
+ if position_embeddings is not None and token_type_embeddings is not None:
258
+ embeddings = inputs_embeds + position_embeddings + token_type_embeddings
259
+ elif position_embeddings is not None:
260
+ embeddings = inputs_embeds + position_embeddings
261
+ elif token_type_embeddings is not None:
262
+ embeddings = inputs_embeds + token_type_embeddings
263
+ else:
264
+ embeddings = inputs_embeds
265
+
266
+ embeddings = self.dropout(embeddings)
267
+ return embeddings
268
+
269
+
270
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
271
+ """Repeat key/value weigts for GQA.
272
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
273
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
274
+ """
275
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
276
+ if n_rep == 1:
277
+ return hidden_states
278
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
279
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
280
+
281
+
282
+ # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->MegatronBert->RetrievaBert
283
+ class RetrievaBertSelfAttention(nn.Module):
284
+ def __init__(self, config, position_embedding_type=None):
285
+ super().__init__()
286
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
287
+ raise ValueError(
288
+ f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
289
+ f"heads ({config.num_attention_heads})"
290
+ )
291
+
292
+ self.num_attention_heads = config.num_attention_heads
293
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
294
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
295
+
296
+ self.num_key_value_heads = config.num_key_value_heads
297
+ self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
298
+
299
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
300
+ self.key = nn.Linear(config.hidden_size, self.num_key_value_heads * self.attention_head_size)
301
+ self.value = nn.Linear(config.hidden_size, self.num_key_value_heads * self.attention_head_size)
302
+
303
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
304
+
305
+ if config.position_embedding_type == "rope":
306
+ self.rope_theta = config.rope_theta
307
+ self.rope_emb = RotaryEmbedding(self.attention_head_size, config.max_position_embeddings, self.rope_theta, config.rotary_percent)
308
+ else:
309
+ self.rope_theta = None
310
+ self.rope_emb = None
311
+
312
+ self.is_decoder = config.is_decoder
313
+
314
+ def transpose_for_scores(self, x: torch.Tensor, is_query: bool) -> torch.Tensor:
315
+ if is_query:
316
+ new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
317
+ else:
318
+ new_x_shape = x.size()[:-1] + (self.num_key_value_heads, self.attention_head_size)
319
+ x = x.view(new_x_shape)
320
+ return x.permute(0, 2, 1, 3)
321
+
322
+ def forward(
323
+ self,
324
+ hidden_states: torch.Tensor,
325
+ attention_mask: Optional[torch.FloatTensor] = None,
326
+ position_ids: Optional[torch.LongTensor] = None,
327
+ head_mask: Optional[torch.FloatTensor] = None,
328
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
329
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
330
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
331
+ output_attentions: Optional[bool] = False,
332
+ ) -> Tuple[torch.Tensor]:
333
+ mixed_query_layer = self.query(hidden_states)
334
+ query_layer = self.transpose_for_scores(mixed_query_layer, is_query=True)
335
+
336
+ # If this is instantiated as a cross-attention module, the keys
337
+ # and values come from an encoder; the attention mask needs to be
338
+ # such that the encoder's padding tokens are not attended to.
339
+ is_cross_attention = encoder_hidden_states is not None
340
+
341
+ if is_cross_attention and past_key_value is not None:
342
+ # reuse k,v, cross_attentions
343
+ key_layer = past_key_value[0]
344
+ value_layer = past_key_value[1]
345
+ attention_mask = encoder_attention_mask
346
+ elif is_cross_attention:
347
+ key_layer = self.transpose_for_scores(self.key(encoder_hidden_states), is_query=False)
348
+ value_layer = self.transpose_for_scores(self.value(encoder_hidden_states), is_query=False)
349
+ attention_mask = encoder_attention_mask
350
+ else:
351
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/models/roformer/modeling_roformer.py#L254-L265
352
+ key_layer = self.transpose_for_scores(self.key(hidden_states), is_query=False)
353
+ value_layer = self.transpose_for_scores(self.value(hidden_states), is_query=False)
354
+
355
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L335-L336
356
+ if self.rope_emb is not None:
357
+ cos, sin = self.rope_emb(hidden_states, position_ids)
358
+ query_layer, key_layer = apply_rotary_pos_emb(query_layer, key_layer, cos, sin)
359
+
360
+ if past_key_value is not None:
361
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
362
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
363
+
364
+ # For GQA, we repeat the key/value weights.
365
+ key_layer = repeat_kv(key_layer, self.num_key_value_groups)
366
+ value_layer = repeat_kv(value_layer, self.num_key_value_groups)
367
+
368
+ if self.is_decoder:
369
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
370
+ # Further calls to cross_attention layer can then reuse all cross-attention
371
+ # key/value_states (first "if" case)
372
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
373
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
374
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
375
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
376
+ past_key_value = (key_layer, value_layer)
377
+
378
+ # Take the dot product between "query" and "key" to get the raw attention scores.
379
+ attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
380
+
381
+ attention_scores = attention_scores / math.sqrt(self.attention_head_size)
382
+ if attention_mask is not None:
383
+ # Apply the attention mask is (precomputed for all layers in RetrievaBertModel forward() function)
384
+ attention_scores = attention_scores + attention_mask
385
+
386
+ # Normalize the attention scores to probabilities.
387
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
388
+
389
+ # This is actually dropping out entire tokens to attend to, which might
390
+ # seem a bit unusual, but is taken from the original Transformer paper.
391
+ attention_probs = self.dropout(attention_probs)
392
+
393
+ # Mask heads if we want to
394
+ if head_mask is not None:
395
+ attention_probs = attention_probs * head_mask
396
+
397
+ context_layer = torch.matmul(attention_probs, value_layer)
398
+
399
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
400
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
401
+ context_layer = context_layer.view(new_context_layer_shape)
402
+
403
+ outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
404
+
405
+ if self.is_decoder:
406
+ outputs = outputs + (past_key_value,)
407
+ return outputs
408
+
409
+
410
+ # Based transformers.models.bert.modeling_bert.BertSelfOutput. Moved LayerNorm to RetrievaBertAttention below.
411
+ class RetrievaBertSelfOutput(nn.Module):
412
+ def __init__(self, config):
413
+ super().__init__()
414
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
415
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
416
+
417
+ def forward(self, hidden_states: torch.Tensor, residual: torch.Tensor) -> torch.Tensor:
418
+ hidden_states = self.dense(hidden_states)
419
+ hidden_states = self.dropout(hidden_states)
420
+ return residual + hidden_states
421
+
422
+
423
+ # Based transformers.models.bert.modeling_bert.BertAttention. Added LayerNorm.
424
+ class RetrievaBertAttention(nn.Module):
425
+ def __init__(self, config):
426
+ super().__init__()
427
+ self.ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
428
+ self.self = RetrievaBertSelfAttention(config)
429
+ self.output = RetrievaBertSelfOutput(config)
430
+ self.pruned_heads = set()
431
+
432
+ def prune_heads(self, heads):
433
+ if len(heads) == 0:
434
+ return
435
+ heads, index = find_pruneable_heads_and_indices(
436
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
437
+ )
438
+
439
+ # Prune linear layers
440
+ self.self.query = prune_linear_layer(self.self.query, index)
441
+ self.self.key = prune_linear_layer(self.self.key, index)
442
+ self.self.value = prune_linear_layer(self.self.value, index)
443
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
444
+
445
+ # Update hyper params and store pruned heads
446
+ self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
447
+ self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
448
+ self.pruned_heads = self.pruned_heads.union(heads)
449
+
450
+ def forward(
451
+ self,
452
+ hidden_states: torch.Tensor,
453
+ attention_mask: Optional[torch.FloatTensor] = None,
454
+ position_ids: Optional[torch.LongTensor] = None,
455
+ head_mask: Optional[torch.FloatTensor] = None,
456
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
457
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
458
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
459
+ output_attentions: Optional[bool] = False,
460
+ ) -> Tuple[torch.Tensor]:
461
+ ln_outputs = self.ln(hidden_states)
462
+ self_outputs = self.self(
463
+ ln_outputs,
464
+ attention_mask,
465
+ position_ids,
466
+ head_mask,
467
+ encoder_hidden_states,
468
+ encoder_attention_mask,
469
+ past_key_value,
470
+ output_attentions,
471
+ )
472
+ attention_output = self.output(self_outputs[0], hidden_states)
473
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
474
+ return outputs
475
+
476
+
477
+ # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->MegatronBert->RetrievaBert
478
+ class RetrievaBertIntermediate(nn.Module):
479
+ def __init__(self, config):
480
+ super().__init__()
481
+ self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=config.mlp_bias)
482
+ self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=config.mlp_bias)
483
+ if isinstance(config.hidden_act, str):
484
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
485
+ else:
486
+ self.intermediate_act_fn = config.hidden_act
487
+
488
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
489
+ gate_hidden_states = self.gate_proj(hidden_states)
490
+ gate_hidden_states = self.intermediate_act_fn(gate_hidden_states)
491
+ up_hidden_state = self.up_proj(hidden_states)
492
+ hidden_states = gate_hidden_states * up_hidden_state
493
+ return hidden_states
494
+
495
+
496
+ # Based on transformers.models.bert.modeling_bert.BertOutput. Moved LayerNorm to RetrievaBertLayer below.
497
+ class RetrievaBertOutput(nn.Module):
498
+ def __init__(self, config):
499
+ super().__init__()
500
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size, bias=config.mlp_bias) # down_proj
501
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
502
+
503
+ def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
504
+ hidden_states = self.dense(hidden_states)
505
+ hidden_states = self.dropout(hidden_states)
506
+ return input_tensor + hidden_states
507
+
508
+
509
+ # Based on transformers.models.bert.modeling_bert.BertLayer. Added LayerNorm.
510
+ class RetrievaBertLayer(nn.Module):
511
+ def __init__(self, config):
512
+ super().__init__()
513
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
514
+ self.seq_len_dim = 1
515
+ self.attention = RetrievaBertAttention(config)
516
+ self.is_decoder = config.is_decoder
517
+ self.add_cross_attention = config.add_cross_attention
518
+ if self.add_cross_attention:
519
+ if not self.is_decoder:
520
+ raise TypeError(f"{self} should be used as a decoder model if cross attention is added")
521
+ self.crossattention = RetrievaBertAttention(config)
522
+ self.ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
523
+ self.intermediate = RetrievaBertIntermediate(config)
524
+ self.output = RetrievaBertOutput(config)
525
+
526
+ def forward(
527
+ self,
528
+ hidden_states: torch.Tensor,
529
+ attention_mask: Optional[torch.FloatTensor] = None,
530
+ position_ids: Optional[torch.LongTensor] = None,
531
+ head_mask: Optional[torch.FloatTensor] = None,
532
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
533
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
534
+ past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
535
+ output_attentions: Optional[bool] = False,
536
+ ) -> Tuple[torch.Tensor]:
537
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
538
+ self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
539
+ self_attention_outputs = self.attention(
540
+ hidden_states,
541
+ attention_mask,
542
+ position_ids,
543
+ head_mask,
544
+ output_attentions=output_attentions,
545
+ past_key_value=self_attn_past_key_value,
546
+ )
547
+ attention_output = self_attention_outputs[0]
548
+
549
+ # if decoder, the last output is tuple of self-attn cache
550
+ if self.is_decoder:
551
+ outputs = self_attention_outputs[1:-1]
552
+ present_key_value = self_attention_outputs[-1]
553
+ else:
554
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
555
+
556
+ cross_attn_present_key_value = None
557
+ if self.is_decoder and encoder_hidden_states is not None:
558
+ if not hasattr(self, "crossattention"):
559
+ raise AttributeError(
560
+ f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
561
+ " by setting `config.add_cross_attention=True`"
562
+ )
563
+
564
+ # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple
565
+ cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
566
+ cross_attention_outputs = self.crossattention(
567
+ attention_output,
568
+ attention_mask,
569
+ position_ids,
570
+ head_mask,
571
+ encoder_hidden_states,
572
+ encoder_attention_mask,
573
+ cross_attn_past_key_value,
574
+ output_attentions,
575
+ )
576
+ attention_output = cross_attention_outputs[0]
577
+ outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights
578
+
579
+ # add cross-attn cache to positions 3,4 of present_key_value tuple
580
+ cross_attn_present_key_value = cross_attention_outputs[-1]
581
+ present_key_value = present_key_value + cross_attn_present_key_value
582
+
583
+ layer_output = apply_chunking_to_forward(
584
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
585
+ )
586
+ outputs = (layer_output,) + outputs
587
+
588
+ # if decoder, return the attn key/values as the last output
589
+ if self.is_decoder:
590
+ outputs = outputs + (present_key_value,)
591
+
592
+ return outputs
593
+
594
+ def feed_forward_chunk(self, attention_output):
595
+ ln_output = self.ln(attention_output)
596
+ intermediate_output = self.intermediate(ln_output)
597
+ layer_output = self.output(intermediate_output, attention_output)
598
+ return layer_output
599
+
600
+
601
+ class RetrievaBertEncoder(nn.Module):
602
+ def __init__(self, config):
603
+ super().__init__()
604
+ self.config = config
605
+ self.layer = nn.ModuleList([RetrievaBertLayer(config) for _ in range(config.num_hidden_layers)])
606
+
607
+ # The final layer norm. We removed the 1st LN, moved LN to each hidden layer and this one
608
+ # is simply the final LN (Transformer's BERT has it attached to each hidden layer).
609
+ self.ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) # final_layernorm
610
+ self.gradient_checkpointing = False
611
+
612
+ def forward(
613
+ self,
614
+ hidden_states: torch.Tensor,
615
+ attention_mask: Optional[torch.FloatTensor] = None,
616
+ position_ids: Optional[torch.LongTensor] = None,
617
+ head_mask: Optional[torch.FloatTensor] = None,
618
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
619
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
620
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
621
+ use_cache: Optional[bool] = None,
622
+ output_attentions: Optional[bool] = False,
623
+ output_hidden_states: Optional[bool] = False,
624
+ return_dict: Optional[bool] = True,
625
+ ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
626
+ if self.gradient_checkpointing and self.training:
627
+ if use_cache:
628
+ logger.warning_once(
629
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
630
+ )
631
+ use_cache = False
632
+ all_hidden_states = () if output_hidden_states else None
633
+ all_self_attentions = () if output_attentions else None
634
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
635
+
636
+ next_decoder_cache = () if use_cache else None
637
+ for i, layer_module in enumerate(self.layer):
638
+ if output_hidden_states:
639
+ all_hidden_states = all_hidden_states + (hidden_states,)
640
+
641
+ layer_head_mask = head_mask[i] if head_mask is not None else None
642
+ past_key_value = past_key_values[i] if past_key_values is not None else None
643
+
644
+ if self.gradient_checkpointing and self.training:
645
+ layer_outputs = self._gradient_checkpointing_func(
646
+ layer_module.__call__,
647
+ hidden_states,
648
+ attention_mask,
649
+ position_ids,
650
+ layer_head_mask,
651
+ encoder_hidden_states,
652
+ encoder_attention_mask,
653
+ past_key_value,
654
+ output_attentions,
655
+ )
656
+ else:
657
+ layer_outputs = layer_module(
658
+ hidden_states,
659
+ attention_mask,
660
+ position_ids,
661
+ layer_head_mask,
662
+ encoder_hidden_states,
663
+ encoder_attention_mask,
664
+ past_key_value,
665
+ output_attentions,
666
+ )
667
+
668
+ # Because we moved the layer-norm at the end of the hidden layer, we have non-normali-
669
+ # zed data here. If that's really needed, we must apply LN to match Transformer's BERT.
670
+
671
+ hidden_states = layer_outputs[0]
672
+ if use_cache:
673
+ next_decoder_cache += (layer_outputs[-1],)
674
+ if output_attentions:
675
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
676
+ if self.config.add_cross_attention:
677
+ all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
678
+
679
+ # Finalize the hidden states.
680
+ hidden_states = self.ln(hidden_states)
681
+
682
+ if output_hidden_states:
683
+ all_hidden_states = all_hidden_states + (hidden_states,)
684
+
685
+ if not return_dict:
686
+ return tuple(
687
+ v
688
+ for v in [
689
+ hidden_states,
690
+ next_decoder_cache,
691
+ all_hidden_states,
692
+ all_self_attentions,
693
+ all_cross_attentions,
694
+ ]
695
+ if v is not None
696
+ )
697
+ return BaseModelOutputWithPastAndCrossAttentions(
698
+ last_hidden_state=hidden_states,
699
+ past_key_values=next_decoder_cache,
700
+ hidden_states=all_hidden_states,
701
+ attentions=all_self_attentions,
702
+ cross_attentions=all_cross_attentions,
703
+ )
704
+
705
+
706
+ # Copied from transformers.models.bert.modeling_bert.BertPooler with Bert->MegatronBert->RetrievaBert
707
+ class RetrievaBertPooler(nn.Module):
708
+ def __init__(self, config):
709
+ super().__init__()
710
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
711
+ self.activation = nn.Tanh()
712
+
713
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
714
+ # We "pool" the model by simply taking the hidden state corresponding
715
+ # to the first token.
716
+ first_token_tensor = hidden_states[:, 0]
717
+ pooled_output = self.dense(first_token_tensor)
718
+ pooled_output = self.activation(pooled_output)
719
+ return pooled_output
720
+
721
+
722
+ # Copied from transformers.models.bert.modeling_bert.BertPredictionHeadTransform with Bert->MegatronBert->RetrievaBert
723
+ class RetrievaBertPredictionHeadTransform(nn.Module):
724
+ def __init__(self, config):
725
+ super().__init__()
726
+ # bertlmhead
727
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
728
+ if isinstance(config.lm_head_hidden_act, str):
729
+ self.transform_act_fn = ACT2FN[config.lm_head_hidden_act]
730
+ else:
731
+ self.transform_act_fn = config.lm_head_hidden_act
732
+ self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
733
+
734
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
735
+ hidden_states = self.dense(hidden_states)
736
+ hidden_states = self.transform_act_fn(hidden_states)
737
+ hidden_states = self.LayerNorm(hidden_states)
738
+ return hidden_states
739
+
740
+
741
+ # Copied from transformers.models.bert.modeling_bert.BertLMPredictionHead with Bert->MegatronBert->RetrievaBert
742
+ class RetrievaBertLMPredictionHead(nn.Module):
743
+ def __init__(self, config):
744
+ super().__init__()
745
+ self.transform = RetrievaBertPredictionHeadTransform(config)
746
+
747
+ # The output weights are the same as the input embeddings, but there is
748
+ # an output-only bias for each token.
749
+ self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
750
+
751
+ # output_layer
752
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
753
+
754
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
755
+ self.decoder.bias = self.bias
756
+
757
+ def _tie_weights(self):
758
+ self.decoder.bias = self.bias
759
+
760
+ def forward(self, hidden_states):
761
+ hidden_states = self.transform(hidden_states)
762
+ hidden_states = self.decoder(hidden_states)
763
+ return hidden_states
764
+
765
+
766
+ # Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->MegatronBert->RetrievaBert
767
+ class RetrievaBertOnlyMLMHead(nn.Module):
768
+ def __init__(self, config):
769
+ super().__init__()
770
+ self.predictions = RetrievaBertLMPredictionHead(config)
771
+
772
+ def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
773
+ prediction_scores = self.predictions(sequence_output)
774
+ return prediction_scores
775
+
776
+
777
+ # Copied from transformers.models.bert.modeling_bert.BertOnlyNSPHead with Bert->MegatronBert->RetrievaBert
778
+ class RetrievaBertOnlyNSPHead(nn.Module):
779
+ def __init__(self, config):
780
+ super().__init__()
781
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
782
+
783
+ def forward(self, pooled_output):
784
+ seq_relationship_score = self.seq_relationship(pooled_output)
785
+ return seq_relationship_score
786
+
787
+
788
+ # Copied from transformers.models.bert.modeling_bert.BertPreTrainingHeads with Bert->MegatronBert->RetrievaBert
789
+ class RetrievaBertPreTrainingHeads(nn.Module):
790
+ def __init__(self, config):
791
+ super().__init__()
792
+ self.predictions = RetrievaBertLMPredictionHead(config)
793
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
794
+
795
+ def forward(self, sequence_output, pooled_output):
796
+ prediction_scores = self.predictions(sequence_output)
797
+ seq_relationship_score = self.seq_relationship(pooled_output)
798
+ return prediction_scores, seq_relationship_score
799
+
800
+
801
+ class RetrievaBertPreTrainedModel(PreTrainedModel):
802
+ """
803
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
804
+ models.
805
+ """
806
+
807
+ config_class = RetrievaBertConfig
808
+ load_tf_weights = load_tf_weights_in_megatron_bert
809
+ base_model_prefix = "bert"
810
+ supports_gradient_checkpointing = True
811
+
812
+ def _init_weights(self, module):
813
+ """Initialize the weights"""
814
+ if isinstance(module, (nn.Linear, nn.Embedding)):
815
+ # Slightly different from the TF version which uses truncated_normal for initialization
816
+ # cf https://github.com/pytorch/pytorch/pull/5617
817
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
818
+ elif isinstance(module, nn.LayerNorm):
819
+ module.bias.data.zero_()
820
+ module.weight.data.fill_(1.0)
821
+ if isinstance(module, nn.Linear) and module.bias is not None:
822
+ module.bias.data.zero_()
823
+
824
+
825
+ @dataclass
826
+ # Copied from transformers.models.bert.modeling_bert.BertForPreTrainingOutput with Bert->MegatronBert->RetrievaBert
827
+ class RetrievaBertForPreTrainingOutput(ModelOutput):
828
+ """
829
+ Output type of [`RetrievaBertForPreTraining`].
830
+
831
+ Args:
832
+ loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
833
+ Total loss as the sum of the masked language modeling loss and the next sequence prediction
834
+ (classification) loss.
835
+ prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
836
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
837
+ seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
838
+ Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
839
+ before SoftMax).
840
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
841
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
842
+ shape `(batch_size, sequence_length, hidden_size)`.
843
+
844
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
845
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
846
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
847
+ sequence_length)`.
848
+
849
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
850
+ heads.
851
+ """
852
+
853
+ loss: Optional[torch.FloatTensor] = None
854
+ prediction_logits: torch.FloatTensor = None
855
+ seq_relationship_logits: torch.FloatTensor = None
856
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
857
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
858
+
859
+
860
+ RETRIEVA_BERT_START_DOCSTRING = r"""
861
+
862
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
863
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
864
+ etc.)
865
+
866
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
867
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
868
+ and behavior.
869
+
870
+ Parameters:
871
+ config ([`RetrievaBertConfig`]): Model configuration class with all the parameters of the model.
872
+ Initializing with a config file does not load the weights associated with the model, only the
873
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
874
+ """
875
+
876
+ RETRIEVA_BERT_INPUTS_DOCSTRING = r"""
877
+ Args:
878
+ input_ids (`torch.LongTensor` of shape `({0})`):
879
+ Indices of input sequence tokens in the vocabulary.
880
+
881
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
882
+ [`PreTrainedTokenizer.__call__`] for details.
883
+
884
+ [What are input IDs?](../glossary#input-ids)
885
+ attention_mask (`torch.FloatTensor` of shape `({0})`, *optional*):
886
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
887
+
888
+ - 1 for tokens that are **not masked**,
889
+ - 0 for tokens that are **masked**.
890
+
891
+ [What are attention masks?](../glossary#attention-mask)
892
+ token_type_ids (`torch.LongTensor` of shape `({0})`, *optional*):
893
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
894
+ 1]`:
895
+
896
+ - 0 corresponds to a *sentence A* token,
897
+ - 1 corresponds to a *sentence B* token.
898
+
899
+ [What are token type IDs?](../glossary#token-type-ids)
900
+ position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
901
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
902
+ config.max_position_embeddings - 1]`.
903
+
904
+ [What are position IDs?](../glossary#position-ids)
905
+ head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
906
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
907
+
908
+ - 1 indicates the head is **not masked**,
909
+ - 0 indicates the head is **masked**.
910
+
911
+ inputs_embeds (`torch.FloatTensor` of shape `({0}, hidden_size)`, *optional*):
912
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
913
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
914
+ model's internal embedding lookup matrix.
915
+ output_attentions (`bool`, *optional*):
916
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
917
+ tensors for more detail.
918
+ output_hidden_states (`bool`, *optional*):
919
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
920
+ more detail.
921
+ return_dict (`bool`, *optional*):
922
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
923
+ """
924
+
925
+
926
+ @add_start_docstrings(
927
+ "The bare RetrievaBert Model transformer outputting raw hidden-states without any specific head on top.",
928
+ RETRIEVA_BERT_START_DOCSTRING,
929
+ )
930
+ class RetrievaBertModel(RetrievaBertPreTrainedModel):
931
+ """
932
+
933
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
934
+ cross-attention is added between the self-attention layers, following the architecture described in [Attention is
935
+ all you need](https://arxiv.org/abs/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
936
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
937
+
938
+ To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
939
+ to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
940
+ `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
941
+ """
942
+
943
+ def __init__(self, config, add_pooling_layer=True):
944
+ super().__init__(config)
945
+ self.config = config
946
+
947
+ self.embeddings = RetrievaBertEmbeddings(config)
948
+ self.encoder = RetrievaBertEncoder(config)
949
+
950
+ self.pooler = RetrievaBertPooler(config) if add_pooling_layer else None
951
+
952
+ self.register_buffer(
953
+ "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
954
+ )
955
+
956
+ # Initialize weights and apply final processing
957
+ self.post_init()
958
+
959
+ def get_input_embeddings(self):
960
+ return self.embeddings.word_embeddings
961
+
962
+ def set_input_embeddings(self, value):
963
+ self.embeddings.word_embeddings = value
964
+
965
+ def _prune_heads(self, heads_to_prune):
966
+ """
967
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
968
+ class PreTrainedModel
969
+ """
970
+ for layer, heads in heads_to_prune.items():
971
+ self.encoder.layer[layer].attention.prune_heads(heads)
972
+
973
+ @add_start_docstrings_to_model_forward(RETRIEVA_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
974
+ @add_code_sample_docstrings(
975
+ checkpoint=_CHECKPOINT_FOR_DOC,
976
+ output_type=BaseModelOutputWithPoolingAndCrossAttentions,
977
+ config_class=_CONFIG_FOR_DOC,
978
+ )
979
+ def forward(
980
+ self,
981
+ input_ids: Optional[torch.LongTensor] = None,
982
+ attention_mask: Optional[torch.FloatTensor] = None,
983
+ token_type_ids: Optional[torch.LongTensor] = None,
984
+ position_ids: Optional[torch.LongTensor] = None,
985
+ head_mask: Optional[torch.FloatTensor] = None,
986
+ inputs_embeds: Optional[torch.FloatTensor] = None,
987
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
988
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
989
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
990
+ use_cache: Optional[bool] = None,
991
+ output_attentions: Optional[bool] = None,
992
+ output_hidden_states: Optional[bool] = None,
993
+ return_dict: Optional[bool] = None,
994
+ ) -> Union[Tuple, BaseModelOutputWithPoolingAndCrossAttentions]:
995
+ r"""
996
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
997
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
998
+ the model is configured as a decoder.
999
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
1000
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
1001
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
1002
+
1003
+ - 1 for tokens that are **not masked**,
1004
+ - 0 for tokens that are **masked**.
1005
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
1006
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1007
+
1008
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
1009
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
1010
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1011
+ use_cache (`bool`, *optional*):
1012
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1013
+ `past_key_values`).
1014
+ """
1015
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1016
+ output_hidden_states = (
1017
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1018
+ )
1019
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1020
+
1021
+ if self.config.is_decoder:
1022
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1023
+ else:
1024
+ use_cache = False
1025
+
1026
+ if input_ids is not None and inputs_embeds is not None:
1027
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
1028
+ elif input_ids is not None:
1029
+ self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
1030
+ input_shape = input_ids.size()
1031
+ elif inputs_embeds is not None:
1032
+ input_shape = inputs_embeds.size()[:-1]
1033
+ else:
1034
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
1035
+
1036
+ batch_size, seq_length = input_shape
1037
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
1038
+
1039
+ # past_key_values_length
1040
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
1041
+
1042
+ if attention_mask is None:
1043
+ attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
1044
+ if token_type_ids is None:
1045
+ token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
1046
+ if position_ids is None:
1047
+ position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
1048
+
1049
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
1050
+ # ourselves in which case we just need to make it broadcastable to all heads.
1051
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
1052
+
1053
+ # If a 2D or 3D attention mask is provided for the cross-attention
1054
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
1055
+ if self.config.is_decoder and encoder_hidden_states is not None:
1056
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
1057
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
1058
+ if encoder_attention_mask is None:
1059
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
1060
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
1061
+ else:
1062
+ encoder_extended_attention_mask = None
1063
+
1064
+ # Prepare head mask if needed
1065
+ # 1.0 in head_mask indicate we keep the head
1066
+ # attention_probs has shape bsz x n_heads x N x N
1067
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
1068
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
1069
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
1070
+
1071
+ embedding_output = self.embeddings(
1072
+ input_ids=input_ids,
1073
+ position_ids=position_ids,
1074
+ token_type_ids=token_type_ids,
1075
+ inputs_embeds=inputs_embeds,
1076
+ past_key_values_length=past_key_values_length,
1077
+ )
1078
+ encoder_outputs = self.encoder(
1079
+ embedding_output,
1080
+ attention_mask=extended_attention_mask,
1081
+ position_ids=position_ids,
1082
+ head_mask=head_mask,
1083
+ encoder_hidden_states=encoder_hidden_states,
1084
+ encoder_attention_mask=encoder_extended_attention_mask,
1085
+ past_key_values=past_key_values,
1086
+ use_cache=use_cache,
1087
+ output_attentions=output_attentions,
1088
+ output_hidden_states=output_hidden_states,
1089
+ return_dict=return_dict,
1090
+ )
1091
+ sequence_output = encoder_outputs[0]
1092
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
1093
+
1094
+ if not return_dict:
1095
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
1096
+
1097
+ return BaseModelOutputWithPoolingAndCrossAttentions(
1098
+ last_hidden_state=sequence_output,
1099
+ pooler_output=pooled_output,
1100
+ past_key_values=encoder_outputs.past_key_values,
1101
+ hidden_states=encoder_outputs.hidden_states,
1102
+ attentions=encoder_outputs.attentions,
1103
+ cross_attentions=encoder_outputs.cross_attentions,
1104
+ )
1105
+
1106
+
1107
+ @add_start_docstrings(
1108
+ """
1109
+ MegatronBert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a
1110
+ `next sentence prediction (classification)` head.
1111
+ RetrievaBert uses a `masked language modeling` only.
1112
+ """,
1113
+ RETRIEVA_BERT_START_DOCSTRING,
1114
+ )
1115
+ class RetrievaBertForPreTraining(RetrievaBertPreTrainedModel):
1116
+ _tied_weights_keys = ["cls.predictions.decoder"]
1117
+
1118
+ def __init__(self, config, add_binary_head=True):
1119
+ super().__init__(config)
1120
+
1121
+ self.bert = RetrievaBertModel(config)
1122
+ self.cls = RetrievaBertPreTrainingHeads(config)
1123
+
1124
+ # Initialize weights and apply final processing
1125
+ self.post_init()
1126
+
1127
+ def get_output_embeddings(self):
1128
+ return self.cls.predictions.decoder
1129
+
1130
+ def set_output_embeddings(self, new_embeddings):
1131
+ self.cls.predictions.decoder = new_embeddings
1132
+ self.cls.predictions.bias = new_embeddings.bias
1133
+
1134
+ @add_start_docstrings_to_model_forward(RETRIEVA_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1135
+ @replace_return_docstrings(output_type=RetrievaBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
1136
+ def forward(
1137
+ self,
1138
+ input_ids: Optional[torch.LongTensor] = None,
1139
+ attention_mask: Optional[torch.FloatTensor] = None,
1140
+ token_type_ids: Optional[torch.LongTensor] = None,
1141
+ position_ids: Optional[torch.LongTensor] = None,
1142
+ head_mask: Optional[torch.FloatTensor] = None,
1143
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1144
+ labels: Optional[torch.LongTensor] = None,
1145
+ next_sentence_label: Optional[torch.LongTensor] = None,
1146
+ output_attentions: Optional[bool] = None,
1147
+ output_hidden_states: Optional[bool] = None,
1148
+ return_dict: Optional[bool] = None,
1149
+ ) -> Union[Tuple, RetrievaBertForPreTrainingOutput]:
1150
+ r"""
1151
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1152
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
1153
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
1154
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
1155
+ next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1156
+ Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
1157
+ (see `input_ids` docstring) Indices should be in `[0, 1]`:
1158
+
1159
+ - 0 indicates sequence B is a continuation of sequence A,
1160
+ - 1 indicates sequence B is a random sequence.
1161
+ kwargs (`Dict[str, any]`, optional, defaults to *{}*):
1162
+ Used to hide legacy arguments that have been deprecated.
1163
+
1164
+ Returns:
1165
+
1166
+ Example:
1167
+
1168
+ ```python
1169
+ >>> from transformers import AutoTokenizer
1170
+ >>> from models import RetrievaBertForPreTraining
1171
+ >>> import torch
1172
+
1173
+ >>> tokenizer = AutoTokenizer.from_pretrained("nvidia/megatron-bert-cased-345m")
1174
+ >>> model = RetrievaBertForPreTraining.from_pretrained("nvidia/megatron-bert-cased-345m")
1175
+
1176
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
1177
+ >>> outputs = model(**inputs)
1178
+
1179
+ >>> prediction_logits = outputs.prediction_logits
1180
+ >>> seq_relationship_logits = outputs.seq_relationship_logits
1181
+ ```"""
1182
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1183
+
1184
+ outputs = self.bert(
1185
+ input_ids,
1186
+ attention_mask=attention_mask,
1187
+ token_type_ids=token_type_ids,
1188
+ position_ids=position_ids,
1189
+ head_mask=head_mask,
1190
+ inputs_embeds=inputs_embeds,
1191
+ output_attentions=output_attentions,
1192
+ output_hidden_states=output_hidden_states,
1193
+ return_dict=return_dict,
1194
+ )
1195
+
1196
+ sequence_output, pooled_output = outputs[:2]
1197
+ prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
1198
+
1199
+ total_loss = None
1200
+ if labels is not None and next_sentence_label is not None:
1201
+ loss_fct = CrossEntropyLoss()
1202
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1203
+ next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
1204
+ total_loss = masked_lm_loss + next_sentence_loss
1205
+
1206
+ if not return_dict:
1207
+ output = (prediction_scores, seq_relationship_score) + outputs[2:]
1208
+ return ((total_loss,) + output) if total_loss is not None else output
1209
+
1210
+ return RetrievaBertForPreTrainingOutput(
1211
+ loss=total_loss,
1212
+ prediction_logits=prediction_scores,
1213
+ seq_relationship_logits=seq_relationship_score,
1214
+ hidden_states=outputs.hidden_states,
1215
+ attentions=outputs.attentions,
1216
+ )
1217
+
1218
+
1219
+ @add_start_docstrings(
1220
+ """RetrievaBert Model with a `language modeling` head on top for CLM fine-tuning.""",
1221
+ RETRIEVA_BERT_START_DOCSTRING,
1222
+ )
1223
+ class RetrievaBertForCausalLM(RetrievaBertPreTrainedModel):
1224
+ _tied_weights_keys = ["cls.predictions.decoder"]
1225
+
1226
+ def __init__(self, config):
1227
+ super().__init__(config)
1228
+
1229
+ if not config.is_decoder:
1230
+ logger.warning("If you want to use `RetrievaBertForCausalLM` as a standalone, add `is_decoder=True.`")
1231
+
1232
+ self.bert = RetrievaBertModel(config, add_pooling_layer=False)
1233
+ self.cls = RetrievaBertOnlyMLMHead(config)
1234
+
1235
+ # Initialize weights and apply final processing
1236
+ self.post_init()
1237
+
1238
+ def get_output_embeddings(self):
1239
+ return self.cls.predictions.decoder
1240
+
1241
+ def set_output_embeddings(self, new_embeddings):
1242
+ self.cls.predictions.decoder = new_embeddings
1243
+ self.cls.predictions.bias = new_embeddings.bias
1244
+
1245
+ @add_start_docstrings_to_model_forward(RETRIEVA_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1246
+ @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
1247
+ def forward(
1248
+ self,
1249
+ input_ids: Optional[torch.LongTensor] = None,
1250
+ attention_mask: Optional[torch.FloatTensor] = None,
1251
+ token_type_ids: Optional[torch.LongTensor] = None,
1252
+ position_ids: Optional[torch.LongTensor] = None,
1253
+ head_mask: Optional[torch.FloatTensor] = None,
1254
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1255
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1256
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1257
+ labels: Optional[torch.LongTensor] = None,
1258
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
1259
+ use_cache: Optional[bool] = None,
1260
+ output_attentions: Optional[bool] = None,
1261
+ output_hidden_states: Optional[bool] = None,
1262
+ return_dict: Optional[bool] = None,
1263
+ ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
1264
+ r"""
1265
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1266
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
1267
+ the model is configured as a decoder.
1268
+ encoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
1269
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
1270
+ the cross-attention if the model is configured as a decoder. Mask values selected in `[0, 1]`:
1271
+
1272
+ - 1 for tokens that are **not masked**,
1273
+ - 0 for tokens that are **masked**.
1274
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1275
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
1276
+ `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
1277
+ ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]`
1278
+ past_key_values (`tuple(tuple(torch.FloatTensor))` of length `config.n_layers` with each tuple having 4 tensors of shape `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
1279
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1280
+
1281
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
1282
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
1283
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
1284
+ use_cache (`bool`, *optional*):
1285
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
1286
+ `past_key_values`).
1287
+
1288
+ Returns:
1289
+
1290
+ Example:
1291
+
1292
+ ```python
1293
+ >>> from transformers import AutoTokenizer
1294
+ >>> from models import RetrievaBertForCausalLM, RetrievaBertConfig
1295
+ >>> import torch
1296
+
1297
+ >>> tokenizer = AutoTokenizer.from_pretrained("nvidia/megatron-bert-cased-345m")
1298
+ >>> model = RetrievaBertForCausalLM.from_pretrained("nvidia/megatron-bert-cased-345m", is_decoder=True)
1299
+
1300
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
1301
+ >>> outputs = model(**inputs)
1302
+
1303
+ >>> prediction_logits = outputs.logits
1304
+ ```"""
1305
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1306
+ if labels is not None:
1307
+ use_cache = False
1308
+
1309
+ outputs = self.bert(
1310
+ input_ids,
1311
+ attention_mask=attention_mask,
1312
+ token_type_ids=token_type_ids,
1313
+ position_ids=position_ids,
1314
+ head_mask=head_mask,
1315
+ inputs_embeds=inputs_embeds,
1316
+ encoder_hidden_states=encoder_hidden_states,
1317
+ encoder_attention_mask=encoder_attention_mask,
1318
+ past_key_values=past_key_values,
1319
+ use_cache=use_cache,
1320
+ output_attentions=output_attentions,
1321
+ output_hidden_states=output_hidden_states,
1322
+ return_dict=return_dict,
1323
+ )
1324
+
1325
+ sequence_output = outputs[0]
1326
+ prediction_scores = self.cls(sequence_output)
1327
+
1328
+ lm_loss = None
1329
+ if labels is not None:
1330
+ # we are doing next-token prediction; shift prediction scores and input ids by one
1331
+ shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
1332
+ labels = labels[:, 1:].contiguous()
1333
+ loss_fct = CrossEntropyLoss()
1334
+ lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1335
+
1336
+ if not return_dict:
1337
+ output = (prediction_scores,) + outputs[2:]
1338
+ return ((lm_loss,) + output) if lm_loss is not None else output
1339
+
1340
+ return CausalLMOutputWithCrossAttentions(
1341
+ loss=lm_loss,
1342
+ logits=prediction_scores,
1343
+ past_key_values=outputs.past_key_values,
1344
+ hidden_states=outputs.hidden_states,
1345
+ attentions=outputs.attentions,
1346
+ cross_attentions=outputs.cross_attentions,
1347
+ )
1348
+
1349
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
1350
+ input_shape = input_ids.shape
1351
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
1352
+ if attention_mask is None:
1353
+ attention_mask = input_ids.new_ones(input_shape)
1354
+
1355
+ # cut decoder_input_ids if past_key_values is used
1356
+ if past_key_values is not None:
1357
+ past_length = past_key_values[0][0].shape[2]
1358
+
1359
+ # Some generation methods already pass only the last input ID
1360
+ if input_ids.shape[1] > past_length:
1361
+ remove_prefix_length = past_length
1362
+ else:
1363
+ # Default to old behavior: keep only final ID
1364
+ remove_prefix_length = input_ids.shape[1] - 1
1365
+
1366
+ input_ids = input_ids[:, remove_prefix_length:]
1367
+
1368
+ return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values}
1369
+
1370
+ def _reorder_cache(self, past_key_values, beam_idx):
1371
+ reordered_past = ()
1372
+ for layer_past in past_key_values:
1373
+ reordered_past += (
1374
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
1375
+ )
1376
+ return reordered_past
1377
+
1378
+
1379
+ @add_start_docstrings("""RetrievaBert Model with a `language modeling` head on top.""", RETRIEVA_BERT_START_DOCSTRING)
1380
+ class RetrievaBertForMaskedLM(RetrievaBertPreTrainedModel):
1381
+ _tied_weights_keys = ["cls.predictions.decoder"]
1382
+
1383
+ def __init__(self, config):
1384
+ super().__init__(config)
1385
+
1386
+ if config.is_decoder:
1387
+ logger.warning(
1388
+ "If you want to use `RetrievaBertForMaskedLM` make sure `config.is_decoder=False` for "
1389
+ "bi-directional self-attention."
1390
+ )
1391
+
1392
+ self.bert = RetrievaBertModel(config, add_pooling_layer=False)
1393
+ self.cls = RetrievaBertOnlyMLMHead(config)
1394
+
1395
+ # Initialize weights and apply final processing
1396
+ self.post_init()
1397
+
1398
+ def get_output_embeddings(self):
1399
+ return self.cls.predictions.decoder
1400
+
1401
+ def set_output_embeddings(self, new_embeddings):
1402
+ self.cls.predictions.decoder = new_embeddings
1403
+ self.cls.predictions.bias = new_embeddings.bias
1404
+
1405
+ @add_start_docstrings_to_model_forward(RETRIEVA_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1406
+ @add_code_sample_docstrings(
1407
+ checkpoint=_CHECKPOINT_FOR_DOC,
1408
+ output_type=MaskedLMOutput,
1409
+ config_class=_CONFIG_FOR_DOC,
1410
+ )
1411
+ def forward(
1412
+ self,
1413
+ input_ids: Optional[torch.LongTensor] = None,
1414
+ attention_mask: Optional[torch.FloatTensor] = None,
1415
+ token_type_ids: Optional[torch.LongTensor] = None,
1416
+ position_ids: Optional[torch.LongTensor] = None,
1417
+ head_mask: Optional[torch.FloatTensor] = None,
1418
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1419
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
1420
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
1421
+ labels: Optional[torch.LongTensor] = None,
1422
+ output_attentions: Optional[bool] = None,
1423
+ output_hidden_states: Optional[bool] = None,
1424
+ return_dict: Optional[bool] = None,
1425
+ ) -> Union[Tuple, MaskedLMOutput]:
1426
+ r"""
1427
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1428
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
1429
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
1430
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
1431
+ """
1432
+
1433
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1434
+
1435
+ outputs = self.bert(
1436
+ input_ids,
1437
+ attention_mask=attention_mask,
1438
+ token_type_ids=token_type_ids,
1439
+ position_ids=position_ids,
1440
+ head_mask=head_mask,
1441
+ inputs_embeds=inputs_embeds,
1442
+ encoder_hidden_states=encoder_hidden_states,
1443
+ encoder_attention_mask=encoder_attention_mask,
1444
+ output_attentions=output_attentions,
1445
+ output_hidden_states=output_hidden_states,
1446
+ return_dict=return_dict,
1447
+ )
1448
+
1449
+ sequence_output = outputs[0]
1450
+ prediction_scores = self.cls(sequence_output)
1451
+
1452
+ masked_lm_loss = None
1453
+ if labels is not None:
1454
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
1455
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1456
+
1457
+ if not return_dict:
1458
+ output = (prediction_scores,) + outputs[2:]
1459
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1460
+
1461
+ return MaskedLMOutput(
1462
+ loss=masked_lm_loss,
1463
+ logits=prediction_scores,
1464
+ hidden_states=outputs.hidden_states,
1465
+ attentions=outputs.attentions,
1466
+ )
1467
+
1468
+ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
1469
+ input_shape = input_ids.shape
1470
+ effective_batch_size = input_shape[0]
1471
+
1472
+ # add a dummy token
1473
+ if self.config.pad_token_id is None:
1474
+ raise ValueError("The PAD token should be defined for generation")
1475
+ attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
1476
+ dummy_token = torch.full(
1477
+ (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
1478
+ )
1479
+ input_ids = torch.cat([input_ids, dummy_token], dim=1)
1480
+
1481
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
1482
+
1483
+
1484
+ @add_start_docstrings(
1485
+ """RetrievaBert Model with a `next sentence prediction (classification)` head on top.""",
1486
+ RETRIEVA_BERT_START_DOCSTRING,
1487
+ )
1488
+ class RetrievaBertForNextSentencePrediction(RetrievaBertPreTrainedModel):
1489
+ def __init__(self, config):
1490
+ super().__init__(config)
1491
+
1492
+ self.bert = RetrievaBertModel(config)
1493
+ self.cls = RetrievaBertOnlyNSPHead(config)
1494
+
1495
+ # Initialize weights and apply final processing
1496
+ self.post_init()
1497
+
1498
+ @add_start_docstrings_to_model_forward(RETRIEVA_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1499
+ @replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
1500
+ def forward(
1501
+ self,
1502
+ input_ids: Optional[torch.LongTensor] = None,
1503
+ attention_mask: Optional[torch.FloatTensor] = None,
1504
+ token_type_ids: Optional[torch.LongTensor] = None,
1505
+ position_ids: Optional[torch.LongTensor] = None,
1506
+ head_mask: Optional[torch.FloatTensor] = None,
1507
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1508
+ labels: Optional[torch.LongTensor] = None,
1509
+ output_attentions: Optional[bool] = None,
1510
+ output_hidden_states: Optional[bool] = None,
1511
+ return_dict: Optional[bool] = None,
1512
+ **kwargs,
1513
+ ) -> Union[Tuple, NextSentencePredictorOutput]:
1514
+ r"""
1515
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1516
+ Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
1517
+ (see `input_ids` docstring). Indices should be in `[0, 1]`:
1518
+
1519
+ - 0 indicates sequence B is a continuation of sequence A,
1520
+ - 1 indicates sequence B is a random sequence.
1521
+
1522
+ Returns:
1523
+
1524
+ Example:
1525
+
1526
+ ```python
1527
+ >>> from transformers import AutoTokenizer
1528
+ >>> from models import RetrievaBertForNextSentencePrediction
1529
+ >>> import torch
1530
+
1531
+ >>> tokenizer = AutoTokenizer.from_pretrained("nvidia/megatron-bert-cased-345m")
1532
+ >>> model = RetrievaBertForNextSentencePrediction.from_pretrained("nvidia/megatron-bert-cased-345m")
1533
+
1534
+ >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
1535
+ >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
1536
+ >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")
1537
+
1538
+ >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
1539
+ >>> logits = outputs.logits
1540
+ >>> assert logits[0, 0] < logits[0, 1] # next sentence was random
1541
+ ```"""
1542
+
1543
+ if "next_sentence_label" in kwargs:
1544
+ warnings.warn(
1545
+ "The `next_sentence_label` argument is deprecated and will be removed in a future version, use"
1546
+ " `labels` instead.",
1547
+ FutureWarning,
1548
+ )
1549
+ labels = kwargs.pop("next_sentence_label")
1550
+
1551
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1552
+
1553
+ outputs = self.bert(
1554
+ input_ids,
1555
+ attention_mask=attention_mask,
1556
+ token_type_ids=token_type_ids,
1557
+ position_ids=position_ids,
1558
+ head_mask=head_mask,
1559
+ inputs_embeds=inputs_embeds,
1560
+ output_attentions=output_attentions,
1561
+ output_hidden_states=output_hidden_states,
1562
+ return_dict=return_dict,
1563
+ )
1564
+
1565
+ pooled_output = outputs[1]
1566
+
1567
+ seq_relationship_scores = self.cls(pooled_output)
1568
+
1569
+ next_sentence_loss = None
1570
+ if labels is not None:
1571
+ loss_fct = CrossEntropyLoss()
1572
+ next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1))
1573
+
1574
+ if not return_dict:
1575
+ output = (seq_relationship_scores,) + outputs[2:]
1576
+ return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
1577
+
1578
+ return NextSentencePredictorOutput(
1579
+ loss=next_sentence_loss,
1580
+ logits=seq_relationship_scores,
1581
+ hidden_states=outputs.hidden_states,
1582
+ attentions=outputs.attentions,
1583
+ )
1584
+
1585
+
1586
+ @add_start_docstrings(
1587
+ """
1588
+ RetrievaBert Model transformer with a sequence classification/regression head on top (a linear layer on top of the
1589
+ pooled output) e.g. for GLUE tasks.
1590
+ """,
1591
+ RETRIEVA_BERT_START_DOCSTRING,
1592
+ )
1593
+ class RetrievaBertForSequenceClassification(RetrievaBertPreTrainedModel):
1594
+ def __init__(self, config):
1595
+ super().__init__(config)
1596
+ self.num_labels = config.num_labels
1597
+
1598
+ self.bert = RetrievaBertModel(config)
1599
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1600
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1601
+
1602
+ # Initialize weights and apply final processing
1603
+ self.post_init()
1604
+
1605
+ @add_start_docstrings_to_model_forward(RETRIEVA_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1606
+ @add_code_sample_docstrings(
1607
+ checkpoint=_CHECKPOINT_FOR_DOC,
1608
+ output_type=SequenceClassifierOutput,
1609
+ config_class=_CONFIG_FOR_DOC,
1610
+ )
1611
+ def forward(
1612
+ self,
1613
+ input_ids: Optional[torch.LongTensor] = None,
1614
+ attention_mask: Optional[torch.FloatTensor] = None,
1615
+ token_type_ids: Optional[torch.LongTensor] = None,
1616
+ position_ids: Optional[torch.LongTensor] = None,
1617
+ head_mask: Optional[torch.FloatTensor] = None,
1618
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1619
+ labels: Optional[torch.LongTensor] = None,
1620
+ output_attentions: Optional[bool] = None,
1621
+ output_hidden_states: Optional[bool] = None,
1622
+ return_dict: Optional[bool] = None,
1623
+ ) -> Union[Tuple, SequenceClassifierOutput]:
1624
+ r"""
1625
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1626
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
1627
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
1628
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1629
+ """
1630
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1631
+
1632
+ outputs = self.bert(
1633
+ input_ids,
1634
+ attention_mask=attention_mask,
1635
+ token_type_ids=token_type_ids,
1636
+ position_ids=position_ids,
1637
+ head_mask=head_mask,
1638
+ inputs_embeds=inputs_embeds,
1639
+ output_attentions=output_attentions,
1640
+ output_hidden_states=output_hidden_states,
1641
+ return_dict=return_dict,
1642
+ )
1643
+
1644
+ pooled_output = outputs[1]
1645
+
1646
+ pooled_output = self.dropout(pooled_output)
1647
+ logits = self.classifier(pooled_output)
1648
+
1649
+ loss = None
1650
+ if labels is not None:
1651
+ if self.config.problem_type is None:
1652
+ if self.num_labels == 1:
1653
+ self.config.problem_type = "regression"
1654
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1655
+ self.config.problem_type = "single_label_classification"
1656
+ else:
1657
+ self.config.problem_type = "multi_label_classification"
1658
+
1659
+ if self.config.problem_type == "regression":
1660
+ loss_fct = MSELoss()
1661
+ if self.num_labels == 1:
1662
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1663
+ else:
1664
+ loss = loss_fct(logits, labels)
1665
+ elif self.config.problem_type == "single_label_classification":
1666
+ loss_fct = CrossEntropyLoss()
1667
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1668
+ elif self.config.problem_type == "multi_label_classification":
1669
+ loss_fct = BCEWithLogitsLoss()
1670
+ loss = loss_fct(logits, labels)
1671
+ if not return_dict:
1672
+ output = (logits,) + outputs[2:]
1673
+ return ((loss,) + output) if loss is not None else output
1674
+
1675
+ return SequenceClassifierOutput(
1676
+ loss=loss,
1677
+ logits=logits,
1678
+ hidden_states=outputs.hidden_states,
1679
+ attentions=outputs.attentions,
1680
+ )
1681
+
1682
+
1683
+ @add_start_docstrings(
1684
+ """
1685
+ RetrievaBert Model with a multiple choice classification head on top (a linear layer on top of the pooled output
1686
+ and a softmax) e.g. for RocStories/SWAG tasks.
1687
+ """,
1688
+ RETRIEVA_BERT_START_DOCSTRING,
1689
+ )
1690
+ class RetrievaBertForMultipleChoice(RetrievaBertPreTrainedModel):
1691
+ def __init__(self, config):
1692
+ super().__init__(config)
1693
+
1694
+ self.bert = RetrievaBertModel(config)
1695
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1696
+ self.classifier = nn.Linear(config.hidden_size, 1)
1697
+
1698
+ # Initialize weights and apply final processing
1699
+ self.post_init()
1700
+
1701
+ @add_start_docstrings_to_model_forward(
1702
+ RETRIEVA_BERT_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
1703
+ )
1704
+ @add_code_sample_docstrings(
1705
+ checkpoint=_CHECKPOINT_FOR_DOC,
1706
+ output_type=MultipleChoiceModelOutput,
1707
+ config_class=_CONFIG_FOR_DOC,
1708
+ )
1709
+ def forward(
1710
+ self,
1711
+ input_ids: Optional[torch.LongTensor] = None,
1712
+ attention_mask: Optional[torch.FloatTensor] = None,
1713
+ token_type_ids: Optional[torch.LongTensor] = None,
1714
+ position_ids: Optional[torch.LongTensor] = None,
1715
+ head_mask: Optional[torch.FloatTensor] = None,
1716
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1717
+ labels: Optional[torch.LongTensor] = None,
1718
+ output_attentions: Optional[bool] = None,
1719
+ output_hidden_states: Optional[bool] = None,
1720
+ return_dict: Optional[bool] = None,
1721
+ ) -> Union[Tuple, MultipleChoiceModelOutput]:
1722
+ r"""
1723
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1724
+ Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
1725
+ num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
1726
+ `input_ids` above)
1727
+ """
1728
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1729
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
1730
+
1731
+ input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
1732
+ attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
1733
+ token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
1734
+ position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
1735
+ inputs_embeds = (
1736
+ inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
1737
+ if inputs_embeds is not None
1738
+ else None
1739
+ )
1740
+
1741
+ outputs = self.bert(
1742
+ input_ids,
1743
+ attention_mask=attention_mask,
1744
+ token_type_ids=token_type_ids,
1745
+ position_ids=position_ids,
1746
+ head_mask=head_mask,
1747
+ inputs_embeds=inputs_embeds,
1748
+ output_attentions=output_attentions,
1749
+ output_hidden_states=output_hidden_states,
1750
+ return_dict=return_dict,
1751
+ )
1752
+
1753
+ pooled_output = outputs[1]
1754
+
1755
+ pooled_output = self.dropout(pooled_output)
1756
+ logits = self.classifier(pooled_output)
1757
+ reshaped_logits = logits.view(-1, num_choices)
1758
+
1759
+ loss = None
1760
+ if labels is not None:
1761
+ loss_fct = CrossEntropyLoss()
1762
+ loss = loss_fct(reshaped_logits, labels)
1763
+
1764
+ if not return_dict:
1765
+ output = (reshaped_logits,) + outputs[2:]
1766
+ return ((loss,) + output) if loss is not None else output
1767
+
1768
+ return MultipleChoiceModelOutput(
1769
+ loss=loss,
1770
+ logits=reshaped_logits,
1771
+ hidden_states=outputs.hidden_states,
1772
+ attentions=outputs.attentions,
1773
+ )
1774
+
1775
+
1776
+ @add_start_docstrings(
1777
+ """
1778
+ RetrievaBert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g.
1779
+ for Named-Entity-Recognition (NER) tasks.
1780
+ """,
1781
+ RETRIEVA_BERT_START_DOCSTRING,
1782
+ )
1783
+ class RetrievaBertForTokenClassification(RetrievaBertPreTrainedModel):
1784
+ def __init__(self, config):
1785
+ super().__init__(config)
1786
+ self.num_labels = config.num_labels
1787
+
1788
+ self.bert = RetrievaBertModel(config, add_pooling_layer=False)
1789
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1790
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1791
+
1792
+ # Initialize weights and apply final processing
1793
+ self.post_init()
1794
+
1795
+ @add_start_docstrings_to_model_forward(RETRIEVA_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1796
+ @add_code_sample_docstrings(
1797
+ checkpoint=_CHECKPOINT_FOR_DOC,
1798
+ output_type=TokenClassifierOutput,
1799
+ config_class=_CONFIG_FOR_DOC,
1800
+ )
1801
+ def forward(
1802
+ self,
1803
+ input_ids: Optional[torch.LongTensor] = None,
1804
+ attention_mask: Optional[torch.FloatTensor] = None,
1805
+ token_type_ids: Optional[torch.LongTensor] = None,
1806
+ position_ids: Optional[torch.LongTensor] = None,
1807
+ head_mask: Optional[torch.FloatTensor] = None,
1808
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1809
+ labels: Optional[torch.LongTensor] = None,
1810
+ output_attentions: Optional[bool] = None,
1811
+ output_hidden_states: Optional[bool] = None,
1812
+ return_dict: Optional[bool] = None,
1813
+ ) -> Union[Tuple, TokenClassifierOutput]:
1814
+ r"""
1815
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1816
+ Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
1817
+ """
1818
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1819
+
1820
+ outputs = self.bert(
1821
+ input_ids,
1822
+ attention_mask=attention_mask,
1823
+ token_type_ids=token_type_ids,
1824
+ position_ids=position_ids,
1825
+ head_mask=head_mask,
1826
+ inputs_embeds=inputs_embeds,
1827
+ output_attentions=output_attentions,
1828
+ output_hidden_states=output_hidden_states,
1829
+ return_dict=return_dict,
1830
+ )
1831
+
1832
+ sequence_output = outputs[0]
1833
+
1834
+ sequence_output = self.dropout(sequence_output)
1835
+ logits = self.classifier(sequence_output)
1836
+
1837
+ loss = None
1838
+ if labels is not None:
1839
+ loss_fct = CrossEntropyLoss()
1840
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1841
+
1842
+ if not return_dict:
1843
+ output = (logits,) + outputs[2:]
1844
+ return ((loss,) + output) if loss is not None else output
1845
+
1846
+ return TokenClassifierOutput(
1847
+ loss=loss,
1848
+ logits=logits,
1849
+ hidden_states=outputs.hidden_states,
1850
+ attentions=outputs.attentions,
1851
+ )
1852
+
1853
+
1854
+ @add_start_docstrings(
1855
+ """
1856
+ RetrievaBert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a
1857
+ linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
1858
+ """,
1859
+ RETRIEVA_BERT_START_DOCSTRING,
1860
+ )
1861
+ class RetrievaBertForQuestionAnswering(RetrievaBertPreTrainedModel):
1862
+ def __init__(self, config):
1863
+ super().__init__(config)
1864
+ self.num_labels = config.num_labels
1865
+
1866
+ self.bert = RetrievaBertModel(config, add_pooling_layer=False)
1867
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
1868
+
1869
+ # Initialize weights and apply final processing
1870
+ self.post_init()
1871
+
1872
+ @add_start_docstrings_to_model_forward(RETRIEVA_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1873
+ @add_code_sample_docstrings(
1874
+ checkpoint=_CHECKPOINT_FOR_DOC,
1875
+ output_type=QuestionAnsweringModelOutput,
1876
+ config_class=_CONFIG_FOR_DOC,
1877
+ )
1878
+ def forward(
1879
+ self,
1880
+ input_ids: Optional[torch.LongTensor] = None,
1881
+ attention_mask: Optional[torch.FloatTensor] = None,
1882
+ token_type_ids: Optional[torch.LongTensor] = None,
1883
+ position_ids: Optional[torch.LongTensor] = None,
1884
+ head_mask: Optional[torch.FloatTensor] = None,
1885
+ inputs_embeds: Optional[torch.FloatTensor] = None,
1886
+ start_positions: Optional[torch.LongTensor] = None,
1887
+ end_positions: Optional[torch.LongTensor] = None,
1888
+ output_attentions: Optional[bool] = None,
1889
+ output_hidden_states: Optional[bool] = None,
1890
+ return_dict: Optional[bool] = None,
1891
+ ) -> Union[Tuple, QuestionAnsweringModelOutput]:
1892
+ r"""
1893
+ start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1894
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1895
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1896
+ are not taken into account for computing the loss.
1897
+ end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1898
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1899
+ Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
1900
+ are not taken into account for computing the loss.
1901
+ """
1902
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1903
+
1904
+ outputs = self.bert(
1905
+ input_ids,
1906
+ attention_mask=attention_mask,
1907
+ token_type_ids=token_type_ids,
1908
+ position_ids=position_ids,
1909
+ head_mask=head_mask,
1910
+ inputs_embeds=inputs_embeds,
1911
+ output_attentions=output_attentions,
1912
+ output_hidden_states=output_hidden_states,
1913
+ return_dict=return_dict,
1914
+ )
1915
+
1916
+ sequence_output = outputs[0]
1917
+
1918
+ logits = self.qa_outputs(sequence_output)
1919
+ start_logits, end_logits = logits.split(1, dim=-1)
1920
+ start_logits = start_logits.squeeze(-1).contiguous()
1921
+ end_logits = end_logits.squeeze(-1).contiguous()
1922
+
1923
+ total_loss = None
1924
+ if start_positions is not None and end_positions is not None:
1925
+ # If we are on multi-GPU, split add a dimension
1926
+ if len(start_positions.size()) > 1:
1927
+ start_positions = start_positions.squeeze(-1)
1928
+ if len(end_positions.size()) > 1:
1929
+ end_positions = end_positions.squeeze(-1)
1930
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
1931
+ ignored_index = start_logits.size(1)
1932
+ start_positions = start_positions.clamp(0, ignored_index)
1933
+ end_positions = end_positions.clamp(0, ignored_index)
1934
+
1935
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
1936
+ start_loss = loss_fct(start_logits, start_positions)
1937
+ end_loss = loss_fct(end_logits, end_positions)
1938
+ total_loss = (start_loss + end_loss) / 2
1939
+
1940
+ if not return_dict:
1941
+ output = (start_logits, end_logits) + outputs[2:]
1942
+ return ((total_loss,) + output) if total_loss is not None else output
1943
+
1944
+ return QuestionAnsweringModelOutput(
1945
+ loss=total_loss,
1946
+ start_logits=start_logits,
1947
+ end_logits=end_logits,
1948
+ hidden_states=outputs.hidden_states,
1949
+ attentions=outputs.attentions,
1950
+ )