efederici committed
Commit
0792bb4
1 Parent(s): d25a04f

Update modeling_mpt.py

Files changed (1)
  1. modeling_mpt.py +55 -51
modeling_mpt.py CHANGED
@@ -4,26 +4,31 @@ Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
 """
 import math
 import warnings
-from typing import List, Optional, Tuple, Union
+from typing import Any, Dict, List, Mapping, MutableMapping, Optional, Tuple, Union
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
+from transformers import PreTrainedModel, PreTrainedTokenizerBase
 from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
 from .attention import attn_bias_shape, build_attn_bias
 from .blocks import MPTBlock
 from .custom_embedding import SharedEmbedding
+from .fc import FC_CLASS_REGISTRY as FC_CLASS_REGISTRY
+from .ffn import FFN_CLASS_REGISTRY as FFN_CLASS_REGISTRY
+from .ffn import MPTMLP as MPTMLP
+from .ffn import build_ffn as build_ffn
 from .norm import NORM_CLASS_REGISTRY
 from .configuration_mpt import MPTConfig
 from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising
 from .hf_prefixlm_converter import add_bidirectional_mask_if_missing, convert_hf_causal_lm_to_prefix_lm
 from .meta_init_context import init_empty_weights
-from .param_init_fns import MODEL_INIT_REGISTRY, generic_param_init_fn_
+from .param_init_fns import generic_param_init_fn_, MODEL_INIT_REGISTRY
 try:
-    from .flash_attn_triton import flash_attn_func
+    from .flash_attn_triton import flash_attn_func as flash_attn_func
 except:
     pass
-Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
+import logging
+log = logging.getLogger(__name__)
 
 class MPTPreTrainedModel(PreTrainedModel):
     config_class = MPTConfig
@@ -40,6 +45,7 @@ class MPTModel(MPTPreTrainedModel):
         self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
         self.alibi = config.attn_config['alibi']
         self.alibi_bias_max = config.attn_config['alibi_bias_max']
+        self.learned_pos_emb = config.learned_pos_emb
        if config.init_device == 'mixed':
             if dist.get_local_rank() == 0:
                 config.init_device = 'cpu'
@@ -51,13 +57,13 @@ class MPTModel(MPTPreTrainedModel):
         norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()]
         self.embedding_fraction = config.embedding_fraction
         self.wte = SharedEmbedding(config.vocab_size, config.d_model, device=config.init_device)
-        if not self.alibi:
+        if self.learned_pos_emb:
            self.wpe = torch.nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device)
         self.emb_drop = nn.Dropout(config.emb_pdrop)
         self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
         self.norm_f = norm_class(config.d_model, device=config.init_device)
         if config.init_device != 'meta':
-            print(f'You are using config.init_device={config.init_device!r}, but you can also use config.init_device="meta" with Composer + FSDP for fast initialization.')
+            log.info(f'We recommend using config.init_device="meta" with Composer + FSDP for faster initialization.')
         self.apply(self.param_init_fn)
         self.is_causal = not self.prefix_lm
         self._attn_bias_initialized = False
@@ -66,25 +72,22 @@ class MPTModel(MPTPreTrainedModel):
         if config.no_bias:
             for module in self.modules():
                 if hasattr(module, 'bias') and isinstance(module.bias, nn.Parameter):
-                    if config.verbose:
-                        warnings.warn(f'Removing bias ({module.bias}) from {module}.')
+                    log.info(f'Removing bias ({module.bias}) from {module}.')
                     module.register_parameter('bias', None)
-        if config.verbose and config.verbose > 2:
-            print(self)
-        if 'verbose' not in self.config.init_config:
-            self.config.init_config['verbose'] = self.config.verbose
-        if self.config.init_config['verbose'] > 1:
-            init_fn_name = self.config.init_config['name']
-            warnings.warn(f'Using {init_fn_name} initialization.')
+                if hasattr(module, 'use_bias'):
+                    log.info(f'Setting use_bias=False for {module}.')
+                    module.use_bias = False
+        log.debug(self)
+        log.debug(f"Using {self.config.init_config['name']} initialization.")
 
-    def get_input_embeddings(self):
+    def get_input_embeddings(self) -> nn.Embedding:
         return self.wte
 
-    def set_input_embeddings(self, value):
+    def set_input_embeddings(self, value: nn.Embedding) -> None:
         self.wte = value
 
     @torch.no_grad()
-    def _attn_bias(self, device, dtype, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None):
+    def _attn_bias(self, device: torch.device, dtype: torch.dtype, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None) -> Tuple[Optional[torch.Tensor], Optional[torch.ByteTensor]]:
         if not self._attn_bias_initialized:
             if self.attn_bias_shape:
                 self.attn_bias = torch.zeros(self.attn_bias_shape, device=device, dtype=dtype)
@@ -115,7 +118,7 @@ class MPTModel(MPTPreTrainedModel):
         attn_bias = attn_bias.masked_fill(~attention_mask.view(-1, 1, 1, s_k), min_val)
         return (attn_bias, None)
 
-    def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor):
+    def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor) -> torch.Tensor:
         (s_k, s_q) = attn_bias.shape[-2:]
         if s_k != self.config.max_seq_len or s_q != self.config.max_seq_len:
             raise ValueError('attn_bias does not match the expected shape. ' + f'The last two dimensions should both be {self.config.max_length} ' + f'but are {s_k} and {s_q}.')
@@ -130,7 +133,7 @@ class MPTModel(MPTPreTrainedModel):
         attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
         return attn_bias
 
-    def _apply_sequence_id(self, attn_bias: torch.Tensor, sequence_id: torch.LongTensor):
+    def _apply_sequence_id(self, attn_bias: torch.Tensor, sequence_id: torch.LongTensor) -> torch.Tensor:
         seq_len = sequence_id.shape[-1]
         if seq_len > self.config.max_seq_len:
             raise ValueError(f'sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}')
@@ -140,7 +143,7 @@ class MPTModel(MPTPreTrainedModel):
         attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
         return attn_bias
 
-    def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.Tensor]=None):
+    def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.Tensor]=None) -> BaseModelOutputWithPast:
         return_dict = return_dict if return_dict is not None else self.config.return_dict
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         if attention_mask is not None:
@@ -152,7 +155,7 @@ class MPTModel(MPTPreTrainedModel):
         if output_attentions:
             if self.attn_impl != 'torch':
                 raise NotImplementedError('output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`.')
-        if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training:
+        if self.training and attention_mask is not None and (attention_mask[:, 0].sum() != attention_mask.shape[0]):
             raise NotImplementedError('MPT does not support training with left padding.')
         if self.prefix_lm and prefix_mask is None:
             raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.')
@@ -166,9 +169,7 @@ class MPTModel(MPTPreTrainedModel):
         S = input_ids.size(1)
         assert S <= self.config.max_seq_len, f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'
         tok_emb = self.wte(input_ids)
-        if self.alibi:
-            x = tok_emb
-        else:
+        if self.learned_pos_emb:
             past_position = 0
             if past_key_values is not None:
                 if len(past_key_values) != self.config.n_layers:
@@ -177,12 +178,14 @@ class MPTModel(MPTPreTrainedModel):
             if self.attn_impl == 'torch':
                 past_position = past_key_values[0][0].size(3)
             if S + past_position > self.config.max_seq_len:
-                raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.')
+                raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length ' + f'{S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.')
             pos = torch.arange(past_position, S + past_position, dtype=torch.long, device=input_ids.device).unsqueeze(0)
             if attention_mask is not None:
                 pos = torch.clamp(pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[:, past_position:], min=0)
             pos_emb = self.wpe(pos)
             x = tok_emb + pos_emb
+        else:
+            x = tok_emb
         if self.embedding_fraction == 1:
             x = self.emb_drop(x)
         else:
@@ -190,6 +193,7 @@ class MPTModel(MPTPreTrainedModel):
             assert isinstance(self.emb_drop, nn.Module)
             x = self.emb_drop(x_shrunk)
         (attn_bias, attention_mask) = self._attn_bias(device=x.device, dtype=torch.float32, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id)
+        presents = () if use_cache else None
         if use_cache and past_key_values is None:
             past_key_values = [() for _ in range(self.config.n_layers)]
         all_hidden_states = () if output_hidden_states else None
@@ -199,9 +203,9 @@ class MPTModel(MPTPreTrainedModel):
                 assert all_hidden_states is not None
                 all_hidden_states = all_hidden_states + (x,)
             past_key_value = past_key_values[b_idx] if past_key_values is not None else None
-            (x, attn_weights, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal)
-            if past_key_values is not None:
-                past_key_values[b_idx] = past_key_value
+            (x, attn_weights, present) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal, output_attentions=bool(output_attentions))
+            if presents is not None:
+                presents += (present,)
             if output_attentions:
                 assert all_self_attns is not None
                 all_self_attns = all_self_attns + (attn_weights,)
@@ -209,16 +213,16 @@ class MPTModel(MPTPreTrainedModel):
         if output_hidden_states:
             assert all_hidden_states is not None
             all_hidden_states = all_hidden_states + (x,)
-        return BaseModelOutputWithPast(last_hidden_state=x, past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns)
+        return BaseModelOutputWithPast(last_hidden_state=x, past_key_values=presents, hidden_states=all_hidden_states, attentions=all_self_attns)
 
-    def param_init_fn(self, module):
+    def param_init_fn(self, module: nn.Module) -> None:
         init_fn_name = self.config.init_config['name']
         MODEL_INIT_REGISTRY[init_fn_name](module=module, n_layers=self.config.n_layers, d_model=self.config.d_model, **self.config.init_config)
 
-    def fsdp_wrap_fn(self, module):
+    def fsdp_wrap_fn(self, module: nn.Module) -> bool:
         return isinstance(module, MPTBlock)
 
-    def activation_checkpointing_fn(self, module):
+    def activation_checkpointing_fn(self, module: nn.Module) -> bool:
         return isinstance(module, MPTBlock)
 
 class MPTForCausalLM(MPTPreTrainedModel):
@@ -227,8 +231,8 @@ class MPTForCausalLM(MPTPreTrainedModel):
         super().__init__(config)
         if not config.tie_word_embeddings:
             raise ValueError('MPTForCausalLM only supports tied word embeddings')
-        print(f'Instantiating an MPTForCausalLM model from {__file__}')
-        self.transformer = MPTModel(config)
+        log.info(f'Instantiating an MPTForCausalLM model from {__file__}')
+        self.transformer: MPTModel = MPTModel(config)
         for child in self.transformer.children():
             if isinstance(child, torch.nn.ModuleList):
                 continue
@@ -244,25 +248,25 @@ class MPTForCausalLM(MPTPreTrainedModel):
             raise ValueError(f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
         self.logit_scale = logit_scale
 
-    def get_input_embeddings(self):
+    def get_input_embeddings(self) -> nn.Embedding:
         return self.transformer.wte
 
-    def set_input_embeddings(self, value):
+    def set_input_embeddings(self, value: Union[SharedEmbedding, nn.Embedding]) -> None:
         self.transformer.wte = value
 
-    def get_output_embeddings(self):
+    def get_output_embeddings(self) -> nn.Embedding:
         return self.transformer.wte
 
-    def set_output_embeddings(self, new_embeddings):
+    def set_output_embeddings(self, new_embeddings: Union[SharedEmbedding, nn.Embedding]) -> None:
         self.transformer.wte = new_embeddings
 
-    def set_decoder(self, decoder):
+    def set_decoder(self, decoder: MPTModel) -> None:
         self.transformer = decoder
 
-    def get_decoder(self):
+    def get_decoder(self) -> MPTModel:
         return self.transformer
 
-    def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.FloatTensor]=None):
+    def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, inputs_embeds: Optional[torch.FloatTensor]=None) -> CausalLMOutputWithPast:
         return_dict = return_dict if return_dict is not None else self.config.return_dict
         use_cache = use_cache if use_cache is not None else self.config.use_cache
         if inputs_embeds is not None:
@@ -275,22 +279,22 @@ class MPTForCausalLM(MPTPreTrainedModel):
             logits *= self.logit_scale
         loss = None
         if labels is not None:
-            labels = torch.roll(labels, shifts=-1)
-            labels[:, -1] = -100
-            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1))
+            _labels = torch.roll(labels, shifts=-1)
+            _labels[:, -1] = -100
+            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), _labels.to(logits.device).view(-1))
         return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
 
-    def param_init_fn(self, module):
+    def param_init_fn(self, module: nn.Module) -> None:
         init_fn_name = self.config.init_config['name']
         MODEL_INIT_REGISTRY[init_fn_name](module=module, n_layers=self.config.n_layers, d_model=self.config.d_model, **self.config.init_config)
 
-    def fsdp_wrap_fn(self, module):
+    def fsdp_wrap_fn(self, module: nn.Module) -> bool:
         return isinstance(module, MPTBlock)
 
-    def activation_checkpointing_fn(self, module):
+    def activation_checkpointing_fn(self, module: nn.Module) -> bool:
         return isinstance(module, MPTBlock)
 
-    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
+    def prepare_inputs_for_generation(self, input_ids: torch.Tensor, past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]]=None, inputs_embeds: Optional[torch.Tensor]=None, **kwargs: Any) -> Dict[str, Any]:
         if inputs_embeds is not None:
             raise NotImplementedError('inputs_embeds is not implemented for MPT yet')
         attention_mask = kwargs['attention_mask'].bool()
@@ -311,7 +315,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
         return {'input_ids': input_ids, 'attention_mask': attention_mask, 'prefix_mask': prefix_mask, 'sequence_id': sequence_id, 'past_key_values': past_key_values, 'use_cache': kwargs.get('use_cache', True)}
 
     @staticmethod
-    def _reorder_cache(past_key_values, beam_idx):
+    def _reorder_cache(past_key_values: List[Tuple[torch.Tensor, torch.Tensor]], beam_idx: torch.LongTensor) -> List[Tuple[torch.Tensor, ...]]:
         """Used by HuggingFace generate when using beam search with kv-caching.
 
         See https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1122-L1133
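
To see what the clamped position computation in the learned_pos_emb branch does with left padding, here is a minimal standalone sketch; the toy mask and sizes are illustrative, not taken from this file:

    import torch

    # Left-padded batch of one row: two pad slots, then three real tokens.
    attention_mask = torch.tensor([[0, 0, 1, 1, 1]], dtype=torch.bool)
    (past_position, S) = (0, attention_mask.shape[1])
    pos = torch.arange(past_position, S + past_position, dtype=torch.long).unsqueeze(0)
    # Subtract the running count of pad tokens so real tokens get positions 0, 1, 2, ...
    pos = torch.clamp(pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[:, past_position:], min=0)
    print(pos)  # tensor([[0, 0, 0, 1, 2]]): pad slots collapse to position 0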
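The loss change above renames the shifted tensor (labels to _labels); the objective is unchanged: the usual next-token cross-entropy with the final position masked out. A standalone sketch with toy shapes chosen only for illustration:

    import torch
    import torch.nn.functional as F

    (batch, seq_len, vocab) = (2, 8, 32)  # toy sizes
    logits = torch.randn(batch, seq_len, vocab)
    labels = torch.randint(0, vocab, (batch, seq_len))
    # Shift targets left by one so the logits at position t are scored against token t+1.
    _labels = torch.roll(labels, shifts=-1)
    # The last position has no next token; -100 is F.cross_entropy's default ignore_index.
    _labels[:, -1] = -100
    loss = F.cross_entropy(logits.view(-1, vocab), _labels.view(-1))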
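The _reorder_cache body is cut off at the section boundary; the GPT-2 code its docstring links to reorders each layer's cached key/value tensors along the batch dimension so the cache follows the surviving beams. A sketch of that pattern, assumed from the linked reference rather than from the remainder of this file:

    from typing import List, Tuple
    import torch

    def reorder_cache(past_key_values: List[Tuple[torch.Tensor, ...]], beam_idx: torch.LongTensor) -> List[Tuple[torch.Tensor, ...]]:
        # For each layer, pick the cache rows (dim 0 = batch * beams) selected by beam search.
        return [tuple(past_state.index_select(0, beam_idx) for past_state in layer_past) for layer_past in past_key_values]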
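Since MPT ships as remote code, this updated file only takes effect when the model is loaded with trust_remote_code=True. A minimal generation sketch; the repo id below is a placeholder, not taken from this commit:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    repo = "your-org/your-mpt-model"  # placeholder; substitute the actual Hub repo id
    tokenizer = AutoTokenizer.from_pretrained(repo)
    model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)
    inputs = tokenizer("Hello", return_tensors="pt")
    # use_cache=True exercises the new presents-based KV cache in MPTModel.forward
    output = model.generate(**inputs, max_new_tokens=16, use_cache=True)
    print(tokenizer.decode(output[0]))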