gmastrapas committed
Commit 5b03f30 (1 parent: 2646361)

feat: expose configuration of use_reentrant

configuration_xlm_roberta.py CHANGED
@@ -25,6 +25,7 @@ class XLMRobertaFlashConfig(PretrainedConfig):
         position_embedding_type: str = "rotary",
         rotary_emb_base: float = 10000.0,
         use_cache: bool = True,
+        use_reentrant: bool = False,
         classifier_dropout: Optional[float] = None,
         lora_adaptations: Optional[List[str]] = None,
         lora_prompts: Optional[Dict[str, str]] = None,
@@ -62,6 +63,7 @@ class XLMRobertaFlashConfig(PretrainedConfig):
         position_embedding_type (str): Type of position embeddings. Options are 'absolute', 'alibi', or 'rotary'.
         rotary_emb_base (float): Base for rotary embeddings.
         use_cache (bool): Whether or not the model should return the last key/values attentions (not used by all models).
+        use_reentrant (bool): Whether or not the model should enable the 'use_reentrant' flag in gradient checkpointing.
         classifier_dropout (Optional[float]): The dropout ratio for the classification head.
         lora_adaptations (Optional[List[str]]): LoRA adaptations configuration.
         lora_prompts (Optional[Dict[str, str]]): LoRA prompts configuration.
@@ -100,6 +102,7 @@ class XLMRobertaFlashConfig(PretrainedConfig):
         self.position_embedding_type = position_embedding_type
         self.rotary_emb_base = rotary_emb_base
         self.use_cache = use_cache
+        self.use_reentrant = use_reentrant
         self.classifier_dropout = classifier_dropout
         self.load_trained_adapters = load_trained_adapters
         self.lora_adaptations = lora_adaptations
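
With this change, use_reentrant is configurable like any other XLMRobertaFlashConfig field: pass it to the constructor or override it on a loaded config before building the model. A minimal sketch, assuming the checkpoint ships this configuration class as custom code (the path below is a placeholder):

    from transformers import AutoConfig, AutoModel

    checkpoint_dir = "path/to/checkpoint"  # placeholder

    config = AutoConfig.from_pretrained(checkpoint_dir, trust_remote_code=True)
    config.use_reentrant = True  # defaults to False, matching the previously hard-coded value

    model = AutoModel.from_pretrained(checkpoint_dir, config=config, trust_remote_code=True)
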
modeling_xlm_roberta.py CHANGED
@@ -181,6 +181,7 @@ class XLMRobertaEncoder(nn.Module):
     def __init__(self, config: XLMRobertaFlashConfig):
         super().__init__()
         self.use_flash_attn = get_use_flash_attn(config)
+        self.use_reentrant = config.use_reentrant
         self.layers = nn.ModuleList(
             [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
         )
@@ -210,7 +211,7 @@
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     layer,
                     hidden_states,
-                    use_reentrant=False,
+                    use_reentrant=self.use_reentrant,
                     mixer_kwargs=mixer_kwargs,
                 )
             else:
@@ -234,7 +235,7 @@
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     layer,
                     hidden_states,
-                    use_reentrant=False,
+                    use_reentrant=self.use_reentrant,
                     mixer_kwargs=mixer_kwargs,
                 )
             else:
@@ -246,7 +247,7 @@
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     layer,
                     hidden_states,
-                    use_reentrant=False,
+                    use_reentrant=self.use_reentrant,
                     mixer_kwargs=mixer_kwargs,
                 )
             else:
@@ -284,7 +285,7 @@
                 torch.utils.checkpoint.checkpoint(
                     self.layers[-1],
                     hidden_states_subset,
-                    use_reentrant=False,
+                    use_reentrant=self.use_reentrant,
                     mixer_kwargs=mixer_kwargs,
                 )
             else:
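
For context, use_reentrant selects between PyTorch's two gradient-checkpointing implementations; the encoder previously hard-coded the newer non-reentrant one, and the flag only matters on the checkpointed path (the else: branches above are the plain forward calls). A self-contained sketch of the underlying torch.utils.checkpoint call, independent of this repository:

    import torch
    import torch.nn as nn
    from torch.utils.checkpoint import checkpoint

    layer = nn.Sequential(nn.Linear(16, 64), nn.GELU(), nn.Linear(64, 16))
    x = torch.randn(4, 16, requires_grad=True)

    # Non-reentrant checkpointing (the previous hard-coded setting and the new default):
    # activations are recomputed during backward, and keyword arguments for the wrapped
    # callable may be passed through checkpoint().
    y = checkpoint(layer, x, use_reentrant=False)
    y.sum().backward()

    # Reentrant checkpointing (what config.use_reentrant=True now selects): the legacy
    # implementation built on a custom autograd Function. Note that recent PyTorch
    # releases do not accept extra keyword arguments for the wrapped callable in this mode.
    x.grad = None
    y = checkpoint(layer, x, use_reentrant=True)
    y.sum().backward()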