makitanikaze committed
Commit 6394525
1 Parent(s): dcb67a0

Delete modeling_p5.py

Files changed (1)
  1. modeling_p5.py +0 -456
modeling_p5.py DELETED
@@ -1,456 +0,0 @@
-from dataclasses import dataclass
-
-from transformers.models.t5.modeling_t5 import (
-    T5Stack, T5Block, T5LayerNorm, T5LayerSelfAttention, T5LayerFF, T5LayerCrossAttention,
-    T5PreTrainedModel, T5ForConditionalGeneration
-)
-
-import torch
-import torch.nn as nn
-from torch.nn import CrossEntropyLoss
-
-from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple
-import copy
-
-from transformers.modeling_outputs import ModelOutput, BaseModelOutput, BaseModelOutputWithPast, BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput, Seq2SeqModelOutput
-from transformers.modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
-from transformers.utils import logging
-from transformers import BeamScorer, BeamSearchScorer
-
-logger = logging.get_logger(__name__)
-
-# The encoder for input token sequence
-class JointEncoder(T5Stack):
-    def __init__(self, config, embed_tokens=None):
-        super(T5Stack, self).__init__(config)
-        self.config = config
-
-        self.embed_tokens = embed_tokens
-        self.is_decoder = self.config.is_decoder
-        assert self.config.is_decoder is False
-
-        self.block = nn.ModuleList(
-            [T5Block(config, has_relative_attention_bias=(i == 0))
-             for i in range(config.num_layers)]
-        )
-        self.final_layer_norm = T5LayerNorm(
-            config.d_model, eps=config.layer_norm_epsilon)
-        self.dropout = nn.Dropout(config.dropout_rate)
-
-        ## Set maximum 512 whole words in a source text
-        self.whole_word_embeddings = nn.Embedding(
-            512, config.d_model  ## config.d_model is 768 for base
-        )
-        self.init_weights()
-        self.model_parallel = False
-        self.device_map = None
-
-    def set_input_embeddings(self, new_embeddings):
-        self.embed_tokens = new_embeddings
-
-    def forward(
-        self,
-        input_ids=None,
-        whole_word_ids=None,
-        attention_mask=None,
-        inputs_embeds=None,
-        head_mask=None,
-        past_key_values=None,
-        use_cache=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
-
-        if inputs_embeds is None:
-            assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
-            inputs_embeds = self.embed_tokens(input_ids)  ### embedding step - add HERE ###
-            if whole_word_ids is not None:
-                whole_word_embeds = self.whole_word_embeddings(whole_word_ids)
-                assert whole_word_embeds.shape[-1] == inputs_embeds.shape[-1]
-                inputs_embeds = inputs_embeds + whole_word_embeds
-
-        B, L = inputs_embeds.size()[:-1]
-
-        if attention_mask is None:
-            attention_mask = input_ids.ne(self.config.pad_token_id).to(dtype=inputs_embeds.dtype, device=inputs_embeds.device)
-
-        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] ourselves, in which case we just need to make it broadcastable to all heads.
-        extended_attention_mask = self.get_extended_attention_mask(
-            attention_mask,
-            (B, L),
-            inputs_embeds.device)
-
-        # initialize past_key_values with `None` if past does not exist
-        if past_key_values is None:
-            past_key_values = [None] * len(self.block)
-
-        # Prepare head mask if needed
-        head_mask = self.get_head_mask(head_mask, self.config.num_layers)
-        present_key_value_states = () if use_cache else None
-        all_hidden_states = () if output_hidden_states else None
-        all_attentions = () if output_attentions else None
-        all_cross_attentions = () if (output_attentions and self.is_decoder) else None
-
-        hidden_states = self.dropout(inputs_embeds)
-
-        if self.config.num_layers > 0:
-
-            assert self.block[0].layer[0].SelfAttention.has_relative_attention_bias
-
-            seq_length = L
-            q_len = seq_length
-            k_len = seq_length
-
-            # [1, n_heads, Q_len, K_len]
-            text_position_bias = self.block[0].layer[0].SelfAttention.compute_bias(
-                L, L)
-            num_heads = text_position_bias.size(1)
-            position_bias = text_position_bias.new_zeros(
-                1, num_heads, seq_length, seq_length)
-            position_bias[:, :, :L, :L] = text_position_bias
-
-            position_bias = position_bias + extended_attention_mask
-
-            for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)):
-                layer_head_mask = head_mask[i]
-                layer_outputs = layer_module(
-                    hidden_states,
-                    attention_mask=extended_attention_mask,
-                    position_bias=position_bias,
-                    encoder_hidden_states=None,
-                    encoder_attention_mask=None,
-                    encoder_decoder_position_bias=None,
-                    # head_mask=head_mask[i],
-                    layer_head_mask=layer_head_mask,
-                    past_key_value=past_key_value,
-                    use_cache=use_cache,
-                    output_attentions=output_attentions,
-                )
-
-                # layer_outputs is a tuple with:
-                # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
-                hidden_states, present_key_value_state = layer_outputs[:2]
-
-                # We share the position biases between the layers - the first layer stores them
-                # layer_outputs = hidden-states, key-value-states (self-attention weights),
-                # (self-attention position bias), (cross-attention weights), (cross-attention position bias)
-
-                # position_bias = layer_outputs[2]
-
-                # append next layer key value states
-                if use_cache:
-                    present_key_value_states = present_key_value_states + \
-                        (present_key_value_state,)
-
-        hidden_states = self.final_layer_norm(hidden_states)
-        hidden_states = self.dropout(hidden_states)
-
-        # Add last layer
-        if output_hidden_states:
-            all_hidden_states = all_hidden_states + (hidden_states,)
-
-        if not return_dict:
-            return tuple(
-                v
-                for v in [
-                    hidden_states,
-                    present_key_value_states,
-                    all_hidden_states,
-                    all_attentions,
-                    all_cross_attentions,
-                ]
-                if v is not None
-            )
-        return BaseModelOutputWithPastAndCrossAttentions(
-            last_hidden_state=hidden_states,
-            past_key_values=present_key_value_states,
-            hidden_states=all_hidden_states,
-            attentions=all_attentions,
-            cross_attentions=all_cross_attentions,
-        )
-
-
-class P5(T5ForConditionalGeneration):
-    _keys_to_ignore_on_load_missing = [
-        r"encoder\.embed_tokens\.weight",
-        r"decoder\.embed_tokens\.weight",
-        r"lm_head\.weight",
-    ]
-    _keys_to_ignore_on_load_unexpected = [
-        r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
-    ]
-
-    def __init__(self, config):
-        super(T5ForConditionalGeneration, self).__init__(config)
-
-        self.config = config
-
-        self.model_dim = config.d_model
-
-        self.shared = nn.Embedding(config.vocab_size, config.d_model)
-
-        encoder_config = copy.deepcopy(config)
-        encoder_config.is_decoder = False
-        encoder_config.use_cache = False
-        encoder_config.is_encoder_decoder = False
-
-        self.encoder = JointEncoder(encoder_config, self.shared)
-
-        decoder_config = copy.deepcopy(config)
-        decoder_config.is_decoder = True
-        decoder_config.is_encoder_decoder = False
-
-        self.decoder = T5Stack(decoder_config, self.shared)
-
-        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
-
-        self.init_weights()
-
-        self.model_parallel = False
-        self.device_map = None
-
-    def set_input_embeddings(self, new_embeddings):
-        self.shared = new_embeddings
-        self.encoder.set_input_embeddings(new_embeddings)
-        self.decoder.set_input_embeddings(new_embeddings)
-
-    def extend_vocab(self, vocab_size):
-
-        new_shared = nn.Embedding(vocab_size, self.config.d_model)
-        old_weight = self.shared.weight.data.detach().clone()
-        old_vocab_size = old_weight.size(0)
-        new_shared.weight.data[:old_vocab_size, :] = old_weight
-        self.shared = new_shared
-
-        new_lm_head = nn.Linear(self.config.d_model, vocab_size, bias=False)
-        old_weight = self.lm_head.weight.data.detach().clone()
-        old_vocab_size = old_weight.size(0)
-        new_lm_head.weight.data[:old_vocab_size, :] = old_weight
-        self.lm_head = new_lm_head
-
-        self.encoder.embed_tokens = self.shared
-        self.decoder.embed_tokens = self.shared
-
-        self.lm_head.weight = self.shared.weight
-
-        self.config.vocab_size = vocab_size
-        self.encoder.config.vocab_size = vocab_size
-        self.decoder.config.vocab_size = vocab_size
-
-    def forward(
-        self,
-        input_ids=None,
-        whole_word_ids=None,
-        attention_mask=None,
-        encoder_outputs=None,
-        decoder_input_ids=None,
-        decoder_attention_mask=None,
-        past_key_values=None,
-        use_cache=None,
-        labels=None,
-        inputs_embeds=None,
-        decoder_inputs_embeds=None,
-        head_mask=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-        reduce_loss=False,
-
-        return_hidden_state=False,
-
-        **kwargs,
-    ):
-
-        use_cache = use_cache if use_cache is not None else self.config.use_cache
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        if encoder_outputs is None:
-            encoder_outputs = self.encoder(
-                input_ids=input_ids,
-                whole_word_ids=whole_word_ids,
-                attention_mask=attention_mask,
-                inputs_embeds=inputs_embeds,
-                head_mask=head_mask,
-                output_attentions=output_attentions,
-                output_hidden_states=output_hidden_states,
-                return_dict=return_dict,
-            )
-        elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
-            encoder_outputs = BaseModelOutput(
-                last_hidden_state=encoder_outputs[0],
-                hidden_states=encoder_outputs[1] if len(
-                    encoder_outputs) > 1 else None,
-                attentions=encoder_outputs[2] if len(
-                    encoder_outputs) > 2 else None,
-            )
-
-        hidden_states = encoder_outputs[0]
-
-        if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
-            # get decoder inputs from shifting lm labels to the right
-            decoder_input_ids = self._shift_right(labels)
-
-        # If decoding with past key value states, only the last tokens
-        # should be given as an input
-        if past_key_values is not None:
-            assert labels is None, "Decoder should not use cached key value states when training."
-            if decoder_input_ids is not None:
-                decoder_input_ids = decoder_input_ids[:, -1:]
-            if decoder_inputs_embeds is not None:
-                decoder_inputs_embeds = decoder_inputs_embeds[:, -1:]
-
-        if attention_mask is None:
-            attention_mask = input_ids.ne(self.config.pad_token_id).to(dtype=hidden_states.dtype, device=hidden_states.device)
-        encoder_attention_mask = attention_mask
-
-        # Decode
-        decoder_outputs = self.decoder(
-            input_ids=decoder_input_ids,
-            attention_mask=decoder_attention_mask,
-            inputs_embeds=decoder_inputs_embeds,
-            past_key_values=past_key_values,
-
-            encoder_hidden_states=hidden_states,
-            encoder_attention_mask=encoder_attention_mask,
-
-            head_mask=head_mask,
-            use_cache=use_cache,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict,
-        )
-
-        sequence_output = decoder_outputs[0]
-
-        assert self.config.tie_word_embeddings is True
-
-        if self.config.tie_word_embeddings:
-            sequence_output = sequence_output * (self.model_dim ** -0.5)
-
-        if return_hidden_state:
-            return sequence_output
-
-        lm_logits = self.lm_head(sequence_output)
-
-        loss = None
-        if labels is not None:
-            if reduce_loss:
-                loss_fct = CrossEntropyLoss(ignore_index=-100)
-            else:
-                loss_fct = CrossEntropyLoss(ignore_index=-100, reduction='none')
-            loss = loss_fct(
-                lm_logits.view(-1, lm_logits.size(-1)),
-                labels.view(-1))
-
-        return P5Seq2SeqLMOutput(
-            loss=loss,
-            logits=lm_logits,
-            past_key_values=decoder_outputs.past_key_values,
-            decoder_last_hidden_state=decoder_outputs.last_hidden_state,
-            decoder_hidden_states=decoder_outputs.hidden_states,
-        )
-
-    def prepare_inputs_for_generation(
-            self, input_ids, past=None, attention_mask=None, use_cache=None,
-            encoder_outputs=None,
-            **kwargs):
-
-        if past is not None:
-            input_ids = input_ids[:, -1:]
-
-        output = {
-            "decoder_input_ids": input_ids,
-            "past_key_values": past,
-            "encoder_outputs": encoder_outputs,
-            "attention_mask": attention_mask,
-            "use_cache": use_cache,
-        }
-
-        return output
-
-    @staticmethod
-    def _expand_inputs_for_generation(
-        input_ids: torch.LongTensor,
-        expand_size: int = 1,
-        is_encoder_decoder: bool = False,
-        attention_mask: torch.LongTensor = None,
-        encoder_outputs: ModelOutput = None,
-        **model_kwargs
-    ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
-        expanded_return_idx = (
-            torch.arange(input_ids.shape[0]).view(-1, 1).repeat(1,
-                expand_size).view(-1).to(input_ids.device)
-        )
-        input_ids = input_ids.index_select(0, expanded_return_idx)
-
-        if "token_type_ids" in model_kwargs:
-            token_type_ids = model_kwargs["token_type_ids"]
-            model_kwargs["token_type_ids"] = token_type_ids.index_select(
-                0, expanded_return_idx)
-
-        if attention_mask is not None:
-            model_kwargs["attention_mask"] = attention_mask.index_select(
-                0, expanded_return_idx)
-
-        if is_encoder_decoder:
-            assert encoder_outputs is not None
-            encoder_outputs["last_hidden_state"] = encoder_outputs.last_hidden_state.index_select(
-                0, expanded_return_idx
-            )
-            model_kwargs["encoder_outputs"] = encoder_outputs
-
-        return input_ids, model_kwargs
-
-
-@dataclass
-class P5Seq2SeqLMOutput(ModelOutput):
-    """
-    Base class for sequence-to-sequence language model outputs.
-
-    Args:
-        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
-            Language modeling loss.
-        logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
-            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
-        past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
-            List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape
-            :obj:`(2, batch_size, num_heads, sequence_length, embed_size_per_head)`.
-
-            Contains pre-computed hidden-states (key and values in the attention blocks) of the decoder that can be
-            used (see ``past_key_values`` input) to speed up sequential decoding.
-        decoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
-            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
-            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
-
-            Hidden-states of the decoder at the output of each layer plus the initial embedding outputs.
-        decoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
-            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
-            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
-
-            Attention weights of the decoder, after the attention softmax, used to compute the weighted average in the
-            self-attention heads.
-        encoder_last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
-            Sequence of hidden-states at the output of the last layer of the encoder of the model.
-        encoder_hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
-            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
-            of shape :obj:`(batch_size, sequence_length, hidden_size)`.
-
-            Hidden-states of the encoder at the output of each layer plus the initial embedding outputs.
-        encoder_attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
-            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
-            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
-
-            Attention weights of the encoder, after the attention softmax, used to compute the weighted average in the
-            self-attention heads.
-    """
-
-    loss: Optional[torch.FloatTensor] = None
-    logits: torch.FloatTensor = None
-    past_key_values: Optional[List[torch.FloatTensor]] = None
-    decoder_last_hidden_state: Optional[Tuple[torch.FloatTensor]] = None
-    decoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
-    decoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
-    encoder_last_hidden_state: Optional[torch.FloatTensor] = None
-    encoder_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
-    encoder_attentions: Optional[Tuple[torch.FloatTensor]] = None
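
For context on what this commit removes, below is a minimal usage sketch of the P5 class defined in the deleted file. It is not taken from this repository: the "t5-base" checkpoint name, the prompt text, and the naive per-token whole_word_ids mapping are illustrative assumptions only.

# Hypothetical sketch: instantiate the deleted P5 model on top of a standard T5 config
# and run one training-style forward pass. All concrete names here are assumptions.
import torch
from transformers import T5Config, T5TokenizerFast

config = T5Config.from_pretrained("t5-base")
tokenizer = T5TokenizerFast.from_pretrained("t5-base")
model = P5(config)  # P5 as defined in the deleted modeling_p5.py above

enc = tokenizer("What star rating will user_23 give item_7391 ?", return_tensors="pt")
# JointEncoder adds a learned whole-word embedding (up to 512 positions) on top of the
# sub-word embeddings; numbering every token separately here is a placeholder, not the
# project's real whole-word grouping.
whole_word_ids = torch.arange(enc.input_ids.size(1)).unsqueeze(0)
labels = tokenizer("5", return_tensors="pt").input_ids

out = model(
    input_ids=enc.input_ids,
    whole_word_ids=whole_word_ids,
    attention_mask=enc.attention_mask,
    labels=labels,
    reduce_loss=True,   # average the token-level cross-entropy into a scalar loss
    return_dict=True,
)
print(out.loss, out.logits.shape)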