BK-Lee committed on
Commit
cbb062a
1 Parent(s): ab82892
Files changed (1)
  1. meteor/arch/modeling_internlm2.py +198 -14
meteor/arch/modeling_internlm2.py CHANGED
@@ -43,18 +43,18 @@ from .configuration_internlm2 import InternLM2Config
 logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = 'InternLM2Config'
-# flash_attn_func, flash_attn_varlen_func = None, None
-# pad_input, index_first_axis, unpad_input = None, None, None
-# def _import_flash_attn():
-#     global flash_attn_func, flash_attn_varlen_func
-#     global pad_input, index_first_axis, unpad_input
-#     try:
-#         from flash_attn import flash_attn_func as _flash_attn_func, flash_attn_varlen_func as _flash_attn_varlen_func
-#         from flash_attn.bert_padding import pad_input as _pad_input, index_first_axis as _index_first_axis, unpad_input as _unpad_input
-#         flash_attn_func, flash_attn_varlen_func = _flash_attn_func, _flash_attn_varlen_func
-#         pad_input, index_first_axis, unpad_input = _pad_input, _index_first_axis, _unpad_input
-#     except ImportError:
-#         raise ImportError("flash_attn is not installed.")
+flash_attn_func, flash_attn_varlen_func = None, None
+pad_input, index_first_axis, unpad_input = None, None, None
+def _import_flash_attn():
+    global flash_attn_func, flash_attn_varlen_func
+    global pad_input, index_first_axis, unpad_input
+    try:
+        from flash_attn import flash_attn_func as _flash_attn_func, flash_attn_varlen_func as _flash_attn_varlen_func
+        from flash_attn.bert_padding import pad_input as _pad_input, index_first_axis as _index_first_axis, unpad_input as _unpad_input
+        flash_attn_func, flash_attn_varlen_func = _flash_attn_func, _flash_attn_varlen_func
+        pad_input, index_first_axis, unpad_input = _pad_input, _index_first_axis, _unpad_input
+    except ImportError:
+        raise ImportError("flash_attn is not installed.")
 
 # Copied from transformers.models.llama.modeling_llama._get_unpad_data
 def _get_unpad_data(attention_mask):
@@ -492,12 +492,196 @@ class InternLM2Attention(nn.Module):
 
         return attn_output, attn_weights, past_key_value
 
+
+class InternLM2FlashAttention2(InternLM2Attention):
+    """InternLM2 flash attention module.
+
+    This module inherits from `InternLM2Attention` as the weights of the module
+    stays untouched. The only required change would be on the forward pass
+    where it needs to correctly call the public API of flash attention and deal
+    with padding tokens in case the input contains any of them.
+    """
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Tuple[torch.Tensor]] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+        im_mask: Optional[Tuple[torch.Tensor]] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor],
+               Optional[Tuple[torch.Tensor]]]:
+        # InternLM2FlashAttention2 attention does not support output_attentions
+        if 'padding_mask' in kwargs:
+            warnings.warn(
+                'Passing `padding_mask` is deprecated and will be removed in v4.37. '
+                'Please make sure use `attention_mask` instead.`')
+
+            # overwrite attention_mask with padding_mask
+            attention_mask = kwargs.pop('padding_mask')
+
+        output_attentions = False
+
+        bsz, q_len, _ = hidden_states.size()
+
+        qkv_states = self.wqkv(hidden_states, im_mask)
+
+        qkv_states = rearrange(
+            qkv_states,
+            'b q (h gs d) -> b q h gs d',
+            gs=2 + self.num_key_value_groups,
+            d=self.head_dim,
+            q=q_len,
+        )
+
+        query_states = qkv_states[..., :self.num_key_value_groups, :]
+        query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d')
+        key_states = qkv_states[..., -2, :]
+        value_states = qkv_states[..., -1, :]
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            kv_seq_len += past_key_value[0].shape[-2]
+
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+
+        query_states, key_states = apply_rotary_pos_emb(
+            query_states, key_states, cos, sin, position_ids)
+
+        if past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+        past_key_value = (key_states, value_states) if use_cache else None
+
+        query_states = query_states.transpose(1, 2)
+        key_states = key_states.transpose(1, 2)
+        value_states = value_states.transpose(1, 2)
+
+        attn_output = self._flash_attention_forward(
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            q_len)
+
+        attn_output = attn_output.reshape(bsz, q_len,
+                                          self.hidden_size).contiguous()
+        attn_output = self.wo(attn_output, im_mask)
+
+        if not output_attentions:
+            attn_weights = None
+
+        return attn_output, attn_weights, past_key_value
+
+    def _flash_attention_forward(
+            self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
+    ):
+        """
+        Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token
+        first unpad the input, then computes the attention scores and pad the final attention scores.
+        Args:
+            query_states (`torch.Tensor`):
+                Input query states to be passed to Flash Attention API
+            key_states (`torch.Tensor`):
+                Input key states to be passed to Flash Attention API
+            value_states (`torch.Tensor`):
+                Input value states to be passed to Flash Attention API
+            attention_mask (`torch.Tensor`):
+                The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the
+                position of padding tokens and 1 for the position of non-padding tokens.
+            dropout (`int`, *optional*):
+                Attention dropout
+            softmax_scale (`float`, *optional*):
+                The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim)
+        """
+        # Contains at least one padding token in the sequence
+        causal = self.is_causal and query_length != 1
+        if attention_mask is not None:
+            batch_size = query_states.shape[0]
+            query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._unpad_input(
+                query_states, key_states, value_states, attention_mask, query_length
+            )
+
+            cu_seqlens_q, cu_seqlens_k = cu_seq_lens
+            max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
+
+            attn_output_unpad = flash_attn_varlen_func(
+                query_states,
+                key_states,
+                value_states,
+                cu_seqlens_q=cu_seqlens_q,
+                cu_seqlens_k=cu_seqlens_k,
+                max_seqlen_q=max_seqlen_in_batch_q,
+                max_seqlen_k=max_seqlen_in_batch_k,
+                dropout_p=dropout,
+                softmax_scale=softmax_scale,
+                causal=causal,
+            )
+
+            attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length)
+        else:
+            attn_output = flash_attn_func(
+                query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal
+            )
+
+        return attn_output
+
+    def _unpad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
+        indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask)
+        batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
+
+        key_layer = index_first_axis(
+            key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+        )
+        value_layer = index_first_axis(
+            value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
+        )
+
+        if query_length == kv_seq_len:
+            query_layer = index_first_axis(
+                query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k
+            )
+            cu_seqlens_q = cu_seqlens_k
+            max_seqlen_in_batch_q = max_seqlen_in_batch_k
+            indices_q = indices_k
+        elif query_length == 1:
+            max_seqlen_in_batch_q = 1
+            cu_seqlens_q = torch.arange(
+                batch_size + 1, dtype=torch.int32, device=query_layer.device
+            )  # There is a memcpy here, that is very bad.
+            indices_q = cu_seqlens_q[:-1]
+            query_layer = query_layer.squeeze(1)
+        else:
+            # The -q_len: slice assumes left padding.
+            attention_mask = attention_mask[:, -query_length:]
+            query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
+
+        return (
+            query_layer,
+            key_layer,
+            value_layer,
+            indices_q.to(torch.int64),
+            (cu_seqlens_q, cu_seqlens_k),
+            (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
+        )
+
 class InternLM2DecoderLayer(nn.Module):
 
     def __init__(self, config: InternLM2Config):
         super().__init__()
         self.hidden_size = config.hidden_size
-        self.attention = InternLM2Attention(config=config)
+        self.attention = (
+            InternLM2Attention(config=config)
+            if not getattr(config, 'attn_implementation')=="flash_attention_2" else
+            InternLM2FlashAttention2(config=config))
         self.feed_forward = InternLM2MLP(config)
         self.attention_norm = InternLM2RMSNorm(
             config.hidden_size, eps=config.rms_norm_eps)
@@ -762,7 +946,7 @@ class InternLM2Model(InternLM2PreTrainedModel):
 
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
-        # if self.config.attn_implementation: _import_flash_attn()
+        if self.config.attn_implementation: _import_flash_attn()
 
         # retrieve input_ids and inputs_embeds
         if input_ids is not None and inputs_embeds is not None:
 
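For context, here is a minimal usage sketch (not part of the commit) of how the new `attn_implementation` switch is expected to be exercised. The import paths and config keyword arguments are assumptions inferred from the file path and the relative import shown in the first hunk; `flash_attn` only needs to be installed once a forward pass triggers `_import_flash_attn()`.

# Hypothetical sketch -- module paths and config kwargs are assumptions, not
# guaranteed by this commit.
from meteor.arch.configuration_internlm2 import InternLM2Config
from meteor.arch.modeling_internlm2 import InternLM2DecoderLayer

# A small config; attn_implementation drives the dispatch added in
# InternLM2DecoderLayer.__init__ (hunk @@ -492,12 +492,196 @@).
config = InternLM2Config(
    hidden_size=256,
    intermediate_size=512,
    num_hidden_layers=2,
    num_attention_heads=8,
    num_key_value_heads=8,
    attn_implementation='flash_attention_2',  # any other value keeps InternLM2Attention
)

layer = InternLM2DecoderLayer(config)
print(type(layer.attention).__name__)  # expected: InternLM2FlashAttention2

# flash_attn itself is imported lazily: InternLM2Model.forward() calls
# _import_flash_attn() when config.attn_implementation is truthy
# (hunk @@ -762,7 +946,7 @@), so the construction above does not require it.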