Fix InternLM2ForCausalLM does not support Flash Attention 2.0 yet (#3)
- Fix InternLM2ForCausalLM does not support Flash Attention 2.0 yet (6b6271256e90d4f97f1aa954ad3a046313b5f5d9)
Co-authored-by: kosung <[email protected]>
- modeling_internlm2.py +2 -0
modeling_internlm2.py
CHANGED
@@ -709,6 +709,8 @@ class InternLM2PreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ['InternLM2DecoderLayer']
     _skip_keys_device_placement = 'past_key_values'
+    _supports_flash_attn_2 = True
+

     def _init_weights(self, module):
         std = self.config.initializer_range
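With _supports_flash_attn_2 = True set on InternLM2PreTrainedModel, transformers no longer raises "does not support Flash Attention 2.0 yet" when the Flash Attention 2 backend is requested. A minimal loading sketch, assuming the flash-attn package is installed and using the internlm/internlm2-chat-7b checkpoint id purely for illustration:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Checkpoint id used only for illustration; substitute the repo this commit belongs to.
model_id = "internlm/internlm2-chat-7b"

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,               # Flash Attention 2 requires fp16 or bf16 weights
    attn_implementation="flash_attention_2",  # accepted now that _supports_flash_attn_2 is set
    trust_remote_code=True,
)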