BK-Lee committed
Commit
79c8c4a
1 Parent(s): 3ce2b2b
Files changed (1)
  1. meteor/arch/modeling_internlm2.py +14 -198
meteor/arch/modeling_internlm2.py CHANGED
@@ -43,18 +43,18 @@ from .configuration_internlm2 import InternLM2Config
 logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = 'InternLM2Config'
-flash_attn_func, flash_attn_varlen_func = None, None
-pad_input, index_first_axis, unpad_input = None, None, None
-def _import_flash_attn():
-    global flash_attn_func, flash_attn_varlen_func
-    global pad_input, index_first_axis, unpad_input
-    try:
-        from flash_attn import flash_attn_func as _flash_attn_func, flash_attn_varlen_func as _flash_attn_varlen_func
-        from flash_attn.bert_padding import pad_input as _pad_input, index_first_axis as _index_first_axis, unpad_input as _unpad_input
-        flash_attn_func, flash_attn_varlen_func = _flash_attn_func, _flash_attn_varlen_func
-        pad_input, index_first_axis, unpad_input = _pad_input, _index_first_axis, _unpad_input
-    except ImportError:
-        raise ImportError("flash_attn is not installed.")
+# flash_attn_func, flash_attn_varlen_func = None, None
+# pad_input, index_first_axis, unpad_input = None, None, None
+# def _import_flash_attn():
+#     global flash_attn_func, flash_attn_varlen_func
+#     global pad_input, index_first_axis, unpad_input
+#     try:
+#         from flash_attn import flash_attn_func as _flash_attn_func, flash_attn_varlen_func as _flash_attn_varlen_func
+#         from flash_attn.bert_padding import pad_input as _pad_input, index_first_axis as _index_first_axis, unpad_input as _unpad_input
+#         flash_attn_func, flash_attn_varlen_func = _flash_attn_func, _flash_attn_varlen_func
+#         pad_input, index_first_axis, unpad_input = _pad_input, _index_first_axis, _unpad_input
+#     except ImportError:
+#         raise ImportError("flash_attn is not installed.")
 
 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
 def _get_unpad_data(attention_mask):
@@ -492,196 +492,12 @@ class InternLM2Attention(nn.Module):
 
         return attn_output, attn_weights, past_key_value
 
-
-class InternLM2FlashAttention2(InternLM2Attention):
-    """InternLM2 flash attention module.
-
-    This module inherits from `InternLM2Attention` as the weights of the module
-    stays untouched. The only required change would be on the forward pass
-    where it needs to correctly call the public API of flash attention and deal
-    with padding tokens in case the input contains any of them.
-    """
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        attention_mask: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        past_key_value: Optional[Tuple[torch.Tensor]] = None,
-        output_attentions: bool = False,
-        use_cache: bool = False,
-        im_mask: Optional[Tuple[torch.Tensor]] = None,
-        **kwargs,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
-               Optional[Tuple[torch.Tensor]]]:
-        # InternLM2FlashAttention2 attention does not support output_attentions
-        if 'padding_mask' in kwargs:
-            warnings.warn(
-                'Passing `padding_mask` is deprecated and will be removed in v4.37. '
-                'Please make sure use `attention_mask` instead.`')
-
-            # overwrite attention_mask with padding_mask
-            attention_mask = kwargs.pop('padding_mask')
-
-        output_attentions = False
-
-        bsz, q_len, _ = hidden_states.size()
-
-        qkv_states = self.wqkv(hidden_states, im_mask)
-
-        qkv_states = rearrange(
-            qkv_states,
-            'b q (h gs d) -> b q h gs d',
-            gs=2 + self.num_key_value_groups,
-            d=self.head_dim,
-            q=q_len,
-        )
-
-        query_states = qkv_states[..., :self.num_key_value_groups, :]
-        query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d')
-        key_states = qkv_states[..., -2, :]
-        value_states = qkv_states[..., -1, :]
-        query_states = query_states.transpose(1, 2)
-        key_states = key_states.transpose(1, 2)
-        value_states = value_states.transpose(1, 2)
-
-        kv_seq_len = key_states.shape[-2]
-        if past_key_value is not None:
-            kv_seq_len += past_key_value[0].shape[-2]
-
-        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-
-        query_states, key_states = apply_rotary_pos_emb(
-            query_states, key_states, cos, sin, position_ids)
-
-        if past_key_value is not None:
-            # reuse k, v, self_attention
-            key_states = torch.cat([past_key_value[0], key_states], dim=2)
-            value_states = torch.cat([past_key_value[1], value_states], dim=2)
-
-        past_key_value = (key_states, value_states) if use_cache else None
-
-        query_states = query_states.transpose(1, 2)
-        key_states = key_states.transpose(1, 2)
-        value_states = value_states.transpose(1, 2)
-
-        attn_output = self._flash_attention_forward(
-            query_states,
-            key_states,
-            value_states,
-            attention_mask,
-            q_len)
-
-        attn_output = attn_output.reshape(bsz, q_len,
-                                          self.hidden_size).contiguous()
-        attn_output = self.wo(attn_output, im_mask)
-
-        if not output_attentions:
-            attn_weights = None
-
-        return attn_output, attn_weights, past_key_value
-
-    def _flash_attention_forward(
-        self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
-    ):
-        """
-        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
-        first unpad the input, then computes the attention scores and pad the final attention scores.
-        Args:
-            query_states (`torch.Tensor`):
-                Input query states to be passed to Flash Attention API
-            key_states (`torch.Tensor`):
-                Input key states to be passed to Flash Attention API
-            value_states (`torch.Tensor`):
-                Input value states to be passed to Flash Attention API
-            attention_mask (`torch.Tensor`):
-                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
-                position of padding tokens and 1 for the position of non-padding tokens.
-            dropout (`int`, *optional*):
-                Attention dropout
-            softmax_scale (`float`, *optional*):
-                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
-        """
-        # Contains at least one padding token in the sequence
-        causal = self.is_causal and query_length != 1
-        if attention_mask is not None:
-            batch_size = query_states.shape[0]
-            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._unpad_input(
-                query_states, key_states, value_states, attention_mask, query_length
-            )
-
-            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
-            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
-
-            attn_output_unpad = flash_attn_varlen_func(
-                query_states,
-                key_states,
-                value_states,
-                cu_seqlens_q=cu_seqlens_q,
-                cu_seqlens_k=cu_seqlens_k,
-                max_seqlen_q=max_seqlen_in_batch_q,
-                max_seqlen_k=max_seqlen_in_batch_k,
-                dropout_p=dropout,
-                softmax_scale=softmax_scale,
-                causal=causal,
-            )
-
-            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
-        else:
-            attn_output = flash_attn_func(
-                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
-            )
-
-        return attn_output
-
-    def _unpad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
-        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
-        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
-
-        key_layer = index_first_axis(
-            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
-        )
-        value_layer = index_first_axis(
-            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
-        )
-
-        if query_length == kv_seq_len:
-            query_layer = index_first_axis(
-                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
-            )
-            cu_seqlens_q = cu_seqlens_k
-            max_seqlen_in_batch_q = max_seqlen_in_batch_k
-            indices_q = indices_k
-        elif query_length == 1:
-            max_seqlen_in_batch_q = 1
-            cu_seqlens_q = torch.arange(
-                batch_size + 1, dtype=torch.int32, device=query_layer.device
-            )  # There is a memcpy here, that is very bad.
-            indices_q = cu_seqlens_q[:-1]
-            query_layer = query_layer.squeeze(1)
-        else:
-            # The -q_len: slice assumes left padding.
-            attention_mask = attention_mask[:, -query_length:]
-            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
-
-        return (
-            query_layer,
-            key_layer,
-            value_layer,
-            indices_q.to(torch.int64),
-            (cu_seqlens_q, cu_seqlens_k),
-            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
-        )
-
 class InternLM2DecoderLayer(nn.Module):
 
     def __init__(self, config: InternLM2Config):
         super().__init__()
         self.hidden_size = config.hidden_size
-        self.attention = (
-            InternLM2Attention(config=config)
-            if not getattr(config, 'attn_implementation')=="flash_attention_2" else
-            InternLM2FlashAttention2(config=config))
+        self.attention = InternLM2Attention(config=config)
         self.feed_forward = InternLM2MLP(config)
         self.attention_norm = InternLM2RMSNorm(
             config.hidden_size, eps=config.rms_norm_eps)
@@ -946,7 +762,7 @@ class InternLM2Model(InternLM2PreTrainedModel):
 
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        if self.config.attn_implementation: _import_flash_attn()
+        # if self.config.attn_implementation: _import_flash_attn()
 
         # retrieve input_ids and inputs_embeds
         if input_ids is not None and inputs_embeds is not None:
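For context, the `_import_flash_attn` helper that this commit comments out performs a lazy import of flash-attn and raises ImportError when the package is missing. Below is a minimal sketch of the same lazy-import pattern, written here for illustration only; the helper name `_try_import_flash_attn` is not part of this repository, and the sketch returns False (so a caller could fall back to the eager InternLM2Attention path) instead of raising.

import importlib.util

flash_attn_func = None

def _try_import_flash_attn():
    """Return True if flash_attn is importable and cached, False otherwise."""
    global flash_attn_func
    if flash_attn_func is not None:
        return True  # already imported on a previous call
    if importlib.util.find_spec('flash_attn') is None:
        return False  # flash_attn not installed; caller keeps eager attention
    from flash_attn import flash_attn_func as _flash_attn_func
    flash_attn_func = _flash_attn_func
    return True

Checking `find_spec` first keeps model construction from raising when flash-attn is absent, which is one way to retain an `attn_implementation` switch rather than hard-coding a single attention class.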