radna commited on
Commit
cd3db47
1 Parent(s): 75397f9

Upload 13 files

Browse files
.gitattributes CHANGED
@@ -1,35 +1,35 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ datasets:
4
+ - laion/laion2B-en
5
+ - laion/laion-coco
6
+ - laion/laion2B-multi
7
+ - kakaobrain/coyo-700m
8
+ - conceptual_captions
9
+ - wanng/wukong100m
10
+ pipeline_tag: image-feature-extraction
11
+ ---
12
+
13
+ # Model Card for InternViT-6B-448px-V1-5
14
+ <p align="center">
15
+ <img src="https://cdn-uploads.huggingface.co/production/uploads/64119264f0f81eb569e0d569/AUE-3OBtfr9vDA7Elgkhd.webp" alt="Image Description" width="300" height="300">
16
+ </p>
17
+
18
+ [\[🆕 Blog\]](https://internvl.github.io/blog/) [\[📜 InternVL 1.0 Paper\]](https://arxiv.org/abs/2312.14238) [\[📜 InternVL 1.5 Report\]](https://arxiv.org/abs/2404.16821) [\[🗨️ Chat Demo\]](https://internvl.opengvlab.com/)
19
+
20
+ [\[🤗 HF Demo\]](https://huggingface.co/spaces/OpenGVLab/InternVL) [\[🚀 Quick Start\]](#model-usage) [\[🌐 Community-hosted API\]](https://rapidapi.com/adushar1320/api/internvl-chat) [\[📖 中文解读\]](https://zhuanlan.zhihu.com/p/675877376)
21
+
22
+ We develop InternViT-6B-448px-V1-5 based on the pre-training of the strong foundation of [InternViT-6B-448px-V1-2](https://huggingface.co/OpenGVLab/InternViT-6B-448px-V1-2). In this update, the resolution of training images is expanded from 448&times;448 to dynamic 448&times;448, where the basic tile size is 448&times;448 and the number of tiles ranges from 1 to 12.
23
+ Additionally, we enhance the data scale, quality, and diversity of the pre-training dataset, resulting in the powerful robustness, OCR capability, and high-resolution processing capability of our
24
+ 1.5 version model.
25
+
26
+ ## Model Details
27
+ - **Model Type:** vision foundation model, feature backbone
28
+ - **Model Stats:**
29
+ - Params (M): 5540 (the last 3 blocks are discarded)
30
+ - Image size: 448 x 448, training with 1 - 12 tiles
31
+ - **Pretrain Dataset:** LAION-en, LAION-zh, COYO, GRIT, COCO, TextCaps, Objects365, OpenImages, All-Seeing, Wukong-OCR, LaionCOCO-OCR, and other OCR-related datasets.
32
+ To enhance the OCR capability of the model, we have incorporated additional OCR data alongside the general caption datasets. Specifically, we utilized PaddleOCR to perform Chinese OCR on images from Wukong and English OCR on images from LAION-COCO.
33
+ - **Note:** InternViT-6B originally had 48 blocks, and we found that using the output after the fourth-to-last block worked best for MLLM. For ease of use and to save GPU memory, we simply discarded the last 3 blocks. Now, the model has only 45 blocks and the number of parameters has been reduced from 5.9B to 5.5B. Therefore, if you want to build a MLLM based on this model, **please make use of the features from the last layer.**
34
+
35
+ ## Released Models
36
+ ### Vision Foundation model
37
+ | Model | Date | Download | Note |
38
+ | ----------------------- | ---------- | ---------------------------------------------------------------------- | -------------------------------- |
39
+ | InternViT-6B-448px-V1-5 | 2024.04.20 | 🤗 [HF link](https://huggingface.co/OpenGVLab/InternViT-6B-448px-V1-5) | support dynamic resolution, super strong OCR (🔥new) |
40
+ | InternViT-6B-448px-V1-2 | 2024.02.11 | 🤗 [HF link](https://huggingface.co/OpenGVLab/InternViT-6B-448px-V1-2) | 448 resolution |
41
+ | InternViT-6B-448px-V1-0 | 2024.01.30 | 🤗 [HF link](https://huggingface.co/OpenGVLab/InternViT-6B-448px-V1-0) | 448 resolution |
42
+ | InternViT-6B-224px | 2023.12.22 | 🤗 [HF link](https://huggingface.co/OpenGVLab/InternViT-6B-224px) | vision foundation model |
43
+ | InternVL-14B-224px | 2023.12.22 | 🤗 [HF link](https://huggingface.co/OpenGVLab/InternVL-14B-224px) | vision-language foundation model |
44
+
45
+ ### Multimodal Large Language Model (MLLM)
46
+ | Model | Date | Download | Note |
47
+ | ----------------------- | ---------- | --------------------------------------------------------------------------- | ---------------------------------- |
48
+ | InternVL-Chat-V1-5 | 2024.04.18 | 🤗 [HF link](https://huggingface.co/OpenGVLab/InternVL-Chat-V1-5) | support 4K image; super strong OCR; Approaching the performance of GPT-4V and Gemini Pro on various benchmarks like MMMU, DocVQA, ChartQA, MathVista, etc. (🔥new)|
49
+ | InternVL-Chat-V1-2-Plus | 2024.02.21 | 🤗 [HF link](https://huggingface.co/OpenGVLab/InternVL-Chat-V1-2-Plus) | more SFT data and stronger |
50
+ | InternVL-Chat-V1-2 | 2024.02.11 | 🤗 [HF link](https://huggingface.co/OpenGVLab/InternVL-Chat-V1-2) | scaling up LLM to 34B |
51
+ | InternVL-Chat-V1-1 | 2024.01.24 | 🤗 [HF link](https://huggingface.co/OpenGVLab/InternVL-Chat-V1-1) | support Chinese and stronger OCR |
52
+
53
+ ## Model Usage (Image Embeddings)
54
+
55
+ ```python
56
+ import torch
57
+ from PIL import Image
58
+ from transformers import AutoModel, CLIPImageProcessor
59
+
60
+ model = AutoModel.from_pretrained(
61
+ 'OpenGVLab/InternViT-6B-448px-V1-5',
62
+ torch_dtype=torch.bfloat16,
63
+ low_cpu_mem_usage=True,
64
+ trust_remote_code=True).cuda().eval()
65
+
66
+ image = Image.open('./examples/image1.jpg').convert('RGB')
67
+
68
+ image_processor = CLIPImageProcessor.from_pretrained('OpenGVLab/InternViT-6B-448px-V1-5')
69
+
70
+ pixel_values = image_processor(images=image, return_tensors='pt').pixel_values
71
+ pixel_values = pixel_values.to(torch.bfloat16).cuda()
72
+
73
+ outputs = model(pixel_values)
74
+ ```
75
+
76
+ ## Citation
77
+
78
+ If you find this project useful in your research, please consider citing:
79
+
80
+ ```BibTeX
81
+ @article{chen2023internvl,
82
+ title={InternVL: Scaling up Vision Foundation Models and Aligning for Generic Visual-Linguistic Tasks},
83
+ author={Chen, Zhe and Wu, Jiannan and Wang, Wenhai and Su, Weijie and Chen, Guo and Xing, Sen and Zhong, Muyan and Zhang, Qinglong and Zhu, Xizhou and Lu, Lewei and Li, Bin and Luo, Ping and Lu, Tong and Qiao, Yu and Dai, Jifeng},
84
+ journal={arXiv preprint arXiv:2312.14238},
85
+ year={2023}
86
+ }
87
+ @article{chen2024far,
88
+ title={How Far Are We to GPT-4V? Closing the Gap to Commercial Multimodal Models with Open-Source Suites},
89
+ author={Chen, Zhe and Wang, Weiyun and Tian, Hao and Ye, Shenglong and Gao, Zhangwei and Cui, Erfei and Tong, Wenwen and Hu, Kongzhi and Luo, Jiapeng and Ma, Zheng and others},
90
+ journal={arXiv preprint arXiv:2404.16821},
91
+ year={2024}
92
+ }
93
+ ```
94
+
95
+
96
+ ## Acknowledgement
97
+
98
+ InternVL is built with reference to the code of the following projects: [OpenAI CLIP](https://github.com/openai/CLIP), [Open CLIP](https://github.com/mlfoundations/open_clip), [CLIP Benchmark](https://github.com/LAION-AI/CLIP_benchmark), [EVA](https://github.com/baaivision/EVA/tree/master), [InternImage](https://github.com/OpenGVLab/InternImage), [ViT-Adapter](https://github.com/czczup/ViT-Adapter), [MMSegmentation](https://github.com/open-mmlab/mmsegmentation), [Transformers](https://github.com/huggingface/transformers), [DINOv2](https://github.com/facebookresearch/dinov2), [BLIP-2](https://github.com/salesforce/LAVIS/tree/main/projects/blip2), [Qwen-VL](https://github.com/QwenLM/Qwen-VL/tree/master/eval_mm), and [LLaVA-1.5](https://github.com/haotian-liu/LLaVA). Thanks for their awesome work!
config.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "OpenGVLab/InternViT-6B-448px-V1-5",
3
+ "architectures": [
4
+ "InternVisionModel"
5
+ ],
6
+ "attention_dropout": 0.0,
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_intern_vit.InternVisionConfig",
9
+ "AutoModel": "modeling_intern_vit.InternVisionModel"
10
+ },
11
+ "drop_path_rate": 0.0,
12
+ "dropout": 0.0,
13
+ "hidden_act": "gelu",
14
+ "hidden_size": 3200,
15
+ "image_size": 448,
16
+ "initializer_factor": 0.1,
17
+ "initializer_range": 1e-10,
18
+ "intermediate_size": 12800,
19
+ "layer_norm_eps": 1e-06,
20
+ "model_type": "intern_vit_6b",
21
+ "num_attention_heads": 25,
22
+ "num_channels": 3,
23
+ "num_hidden_layers": 45,
24
+ "patch_size": 14,
25
+ "qk_normalization": true,
26
+ "qkv_bias": false,
27
+ "torch_dtype": "bfloat16",
28
+ "transformers_version": "4.36.2",
29
+ "use_bfloat16": true,
30
+ "use_flash_attn": true
31
+ }
configuration_intern_vit.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2023 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+ import os
7
+ from typing import Union
8
+
9
+ from transformers.configuration_utils import PretrainedConfig
10
+ from transformers.utils import logging
11
+
12
+ logger = logging.get_logger(__name__)
13
+
14
+
15
+ class InternVisionConfig(PretrainedConfig):
16
+ r"""
17
+ This is the configuration class to store the configuration of a [`InternVisionModel`]. It is used to
18
+ instantiate a vision encoder according to the specified arguments, defining the model architecture.
19
+
20
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
21
+ documentation from [`PretrainedConfig`] for more information.
22
+
23
+ Args:
24
+ num_channels (`int`, *optional*, defaults to 3):
25
+ Number of color channels in the input images (e.g., 3 for RGB).
26
+ patch_size (`int`, *optional*, defaults to 14):
27
+ The size (resolution) of each patch.
28
+ image_size (`int`, *optional*, defaults to 224):
29
+ The size (resolution) of each image.
30
+ qkv_bias (`bool`, *optional*, defaults to `False`):
31
+ Whether to add a bias to the queries and values in the self-attention layers.
32
+ hidden_size (`int`, *optional*, defaults to 3200):
33
+ Dimensionality of the encoder layers and the pooler layer.
34
+ num_attention_heads (`int`, *optional*, defaults to 25):
35
+ Number of attention heads for each attention layer in the Transformer encoder.
36
+ intermediate_size (`int`, *optional*, defaults to 12800):
37
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
38
+ qk_normalization (`bool`, *optional*, defaults to `True`):
39
+ Whether to normalize the queries and keys in the self-attention layers.
40
+ num_hidden_layers (`int`, *optional*, defaults to 48):
41
+ Number of hidden layers in the Transformer encoder.
42
+ use_flash_attn (`bool`, *optional*, defaults to `True`):
43
+ Whether to use flash attention mechanism.
44
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
45
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
46
+ `"relu"`, `"selu"` and `"gelu_new"` ``"gelu"` are supported.
47
+ layer_norm_eps (`float`, *optional*, defaults to 1e-6):
48
+ The epsilon used by the layer normalization layers.
49
+ dropout (`float`, *optional*, defaults to 0.0):
50
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
51
+ drop_path_rate (`float`, *optional*, defaults to 0.0):
52
+ Dropout rate for stochastic depth.
53
+ attention_dropout (`float`, *optional*, defaults to 0.0):
54
+ The dropout ratio for the attention probabilities.
55
+ initializer_range (`float`, *optional*, defaults to 0.02):
56
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
57
+ initializer_factor (`float`, *optional*, defaults to 0.1):
58
+ A factor for layer scale.
59
+ """
60
+
61
+ model_type = 'intern_vit_6b'
62
+
63
+ def __init__(
64
+ self,
65
+ num_channels=3,
66
+ patch_size=14,
67
+ image_size=224,
68
+ qkv_bias=False,
69
+ hidden_size=3200,
70
+ num_attention_heads=25,
71
+ intermediate_size=12800,
72
+ qk_normalization=True,
73
+ num_hidden_layers=48,
74
+ use_flash_attn=True,
75
+ hidden_act='gelu',
76
+ layer_norm_eps=1e-6,
77
+ dropout=0.0,
78
+ drop_path_rate=0.0,
79
+ attention_dropout=0.0,
80
+ initializer_range=0.02,
81
+ initializer_factor=0.1,
82
+ **kwargs,
83
+ ):
84
+ super().__init__(**kwargs)
85
+
86
+ self.hidden_size = hidden_size
87
+ self.intermediate_size = intermediate_size
88
+ self.dropout = dropout
89
+ self.drop_path_rate = drop_path_rate
90
+ self.num_hidden_layers = num_hidden_layers
91
+ self.num_attention_heads = num_attention_heads
92
+ self.num_channels = num_channels
93
+ self.patch_size = patch_size
94
+ self.image_size = image_size
95
+ self.initializer_range = initializer_range
96
+ self.initializer_factor = initializer_factor
97
+ self.attention_dropout = attention_dropout
98
+ self.layer_norm_eps = layer_norm_eps
99
+ self.hidden_act = hidden_act
100
+ self.qkv_bias = qkv_bias
101
+ self.qk_normalization = qk_normalization
102
+ self.use_flash_attn = use_flash_attn
103
+
104
+ @classmethod
105
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> 'PretrainedConfig':
106
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
107
+
108
+ if 'vision_config' in config_dict:
109
+ config_dict = config_dict['vision_config']
110
+
111
+ if 'model_type' in config_dict and hasattr(cls, 'model_type') and config_dict['model_type'] != cls.model_type:
112
+ logger.warning(
113
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
114
+ f'{cls.model_type}. This is not supported for all configurations of models and can yield errors.'
115
+ )
116
+
117
+ return cls.from_dict(config_dict, **kwargs)
flash_attention.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from einops import rearrange
4
+
5
+ try: # v1
6
+ from flash_attn.flash_attn_interface import \
7
+ flash_attn_unpadded_qkvpacked_func
8
+ except: # v2
9
+ from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func
10
+
11
+ from flash_attn.bert_padding import pad_input, unpad_input
12
+
13
+
14
+ class FlashAttention(nn.Module):
15
+ """Implement the scaled dot product attention with softmax.
16
+ Arguments
17
+ ---------
18
+ softmax_scale: The temperature to use for the softmax attention.
19
+ (default: 1/sqrt(d_keys) where d_keys is computed at
20
+ runtime)
21
+ attention_dropout: The dropout rate to apply to the attention
22
+ (default: 0.0)
23
+ """
24
+
25
+ def __init__(self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None):
26
+ super().__init__()
27
+ self.softmax_scale = softmax_scale
28
+ self.dropout_p = attention_dropout
29
+
30
+ def forward(self, qkv, key_padding_mask=None, causal=False, cu_seqlens=None,
31
+ max_s=None, need_weights=False):
32
+ """Implements the multihead softmax attention.
33
+ Arguments
34
+ ---------
35
+ qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
36
+ if unpadded: (nnz, 3, h, d)
37
+ key_padding_mask: a bool tensor of shape (B, S)
38
+ """
39
+ assert not need_weights
40
+ assert qkv.dtype in [torch.float16, torch.bfloat16]
41
+ assert qkv.is_cuda
42
+
43
+ if cu_seqlens is None:
44
+ batch_size = qkv.shape[0]
45
+ seqlen = qkv.shape[1]
46
+ if key_padding_mask is None:
47
+ qkv = rearrange(qkv, 'b s ... -> (b s) ...')
48
+ max_s = seqlen
49
+ cu_seqlens = torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32,
50
+ device=qkv.device)
51
+ output = flash_attn_unpadded_qkvpacked_func(
52
+ qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
53
+ softmax_scale=self.softmax_scale, causal=causal
54
+ )
55
+ output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
56
+ else:
57
+ nheads = qkv.shape[-2]
58
+ x = rearrange(qkv, 'b s three h d -> b s (three h d)')
59
+ x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask)
60
+ x_unpad = rearrange(x_unpad, 'nnz (three h d) -> nnz three h d', three=3, h=nheads)
61
+ output_unpad = flash_attn_unpadded_qkvpacked_func(
62
+ x_unpad, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
63
+ softmax_scale=self.softmax_scale, causal=causal
64
+ )
65
+ output = rearrange(pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'),
66
+ indices, batch_size, seqlen),
67
+ 'b s (h d) -> b s h d', h=nheads)
68
+ else:
69
+ assert max_s is not None
70
+ output = flash_attn_unpadded_qkvpacked_func(
71
+ qkv, cu_seqlens, max_s, self.dropout_p if self.training else 0.0,
72
+ softmax_scale=self.softmax_scale, causal=causal
73
+ )
74
+
75
+ return output, None
model-00001-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:331fc0e79147081bb4260491b4db121aaf4252e0f29ed3509ae5df11bd8ae41e
3
+ size 4988565944
model-00002-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4785be9bec8771f0b25a2f33c52fe9e53623068eb0e7d72aa01e410c43a91cbc
3
+ size 4937250176
model-00003-of-00003.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:95fe64ed513580d1fbd4257823adfd5b7b1a283c70cdd2771453443dd1f0b6b6
3
+ size 1147238088
model.safetensors.index.json ADDED
@@ -0,0 +1,596 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "metadata": {
3
+ "total_size": 11072992000
4
+ },
5
+ "weight_map": {
6
+ "embeddings.class_embedding": "model-00001-of-00003.safetensors",
7
+ "embeddings.patch_embedding.bias": "model-00001-of-00003.safetensors",
8
+ "embeddings.patch_embedding.weight": "model-00001-of-00003.safetensors",
9
+ "embeddings.position_embedding": "model-00001-of-00003.safetensors",
10
+ "encoder.layers.0.attn.k_norm.weight": "model-00001-of-00003.safetensors",
11
+ "encoder.layers.0.attn.proj.bias": "model-00001-of-00003.safetensors",
12
+ "encoder.layers.0.attn.proj.weight": "model-00001-of-00003.safetensors",
13
+ "encoder.layers.0.attn.q_norm.weight": "model-00001-of-00003.safetensors",
14
+ "encoder.layers.0.attn.qkv.weight": "model-00001-of-00003.safetensors",
15
+ "encoder.layers.0.ls1": "model-00001-of-00003.safetensors",
16
+ "encoder.layers.0.ls2": "model-00001-of-00003.safetensors",
17
+ "encoder.layers.0.mlp.fc1.bias": "model-00001-of-00003.safetensors",
18
+ "encoder.layers.0.mlp.fc1.weight": "model-00001-of-00003.safetensors",
19
+ "encoder.layers.0.mlp.fc2.bias": "model-00001-of-00003.safetensors",
20
+ "encoder.layers.0.mlp.fc2.weight": "model-00001-of-00003.safetensors",
21
+ "encoder.layers.0.norm1.weight": "model-00001-of-00003.safetensors",
22
+ "encoder.layers.0.norm2.weight": "model-00001-of-00003.safetensors",
23
+ "encoder.layers.1.attn.k_norm.weight": "model-00001-of-00003.safetensors",
24
+ "encoder.layers.1.attn.proj.bias": "model-00001-of-00003.safetensors",
25
+ "encoder.layers.1.attn.proj.weight": "model-00001-of-00003.safetensors",
26
+ "encoder.layers.1.attn.q_norm.weight": "model-00001-of-00003.safetensors",
27
+ "encoder.layers.1.attn.qkv.weight": "model-00001-of-00003.safetensors",
28
+ "encoder.layers.1.ls1": "model-00001-of-00003.safetensors",
29
+ "encoder.layers.1.ls2": "model-00001-of-00003.safetensors",
30
+ "encoder.layers.1.mlp.fc1.bias": "model-00001-of-00003.safetensors",
31
+ "encoder.layers.1.mlp.fc1.weight": "model-00001-of-00003.safetensors",
32
+ "encoder.layers.1.mlp.fc2.bias": "model-00001-of-00003.safetensors",
33
+ "encoder.layers.1.mlp.fc2.weight": "model-00001-of-00003.safetensors",
34
+ "encoder.layers.1.norm1.weight": "model-00001-of-00003.safetensors",
35
+ "encoder.layers.1.norm2.weight": "model-00001-of-00003.safetensors",
36
+ "encoder.layers.10.attn.k_norm.weight": "model-00001-of-00003.safetensors",
37
+ "encoder.layers.10.attn.proj.bias": "model-00001-of-00003.safetensors",
38
+ "encoder.layers.10.attn.proj.weight": "model-00001-of-00003.safetensors",
39
+ "encoder.layers.10.attn.q_norm.weight": "model-00001-of-00003.safetensors",
40
+ "encoder.layers.10.attn.qkv.weight": "model-00001-of-00003.safetensors",
41
+ "encoder.layers.10.ls1": "model-00001-of-00003.safetensors",
42
+ "encoder.layers.10.ls2": "model-00001-of-00003.safetensors",
43
+ "encoder.layers.10.mlp.fc1.bias": "model-00001-of-00003.safetensors",
44
+ "encoder.layers.10.mlp.fc1.weight": "model-00001-of-00003.safetensors",
45
+ "encoder.layers.10.mlp.fc2.bias": "model-00001-of-00003.safetensors",
46
+ "encoder.layers.10.mlp.fc2.weight": "model-00001-of-00003.safetensors",
47
+ "encoder.layers.10.norm1.weight": "model-00001-of-00003.safetensors",
48
+ "encoder.layers.10.norm2.weight": "model-00001-of-00003.safetensors",
49
+ "encoder.layers.11.attn.k_norm.weight": "model-00001-of-00003.safetensors",
50
+ "encoder.layers.11.attn.proj.bias": "model-00001-of-00003.safetensors",
51
+ "encoder.layers.11.attn.proj.weight": "model-00001-of-00003.safetensors",
52
+ "encoder.layers.11.attn.q_norm.weight": "model-00001-of-00003.safetensors",
53
+ "encoder.layers.11.attn.qkv.weight": "model-00001-of-00003.safetensors",
54
+ "encoder.layers.11.ls1": "model-00001-of-00003.safetensors",
55
+ "encoder.layers.11.ls2": "model-00001-of-00003.safetensors",
56
+ "encoder.layers.11.mlp.fc1.bias": "model-00001-of-00003.safetensors",
57
+ "encoder.layers.11.mlp.fc1.weight": "model-00001-of-00003.safetensors",
58
+ "encoder.layers.11.mlp.fc2.bias": "model-00001-of-00003.safetensors",
59
+ "encoder.layers.11.mlp.fc2.weight": "model-00001-of-00003.safetensors",
60
+ "encoder.layers.11.norm1.weight": "model-00001-of-00003.safetensors",
61
+ "encoder.layers.11.norm2.weight": "model-00001-of-00003.safetensors",
62
+ "encoder.layers.12.attn.k_norm.weight": "model-00001-of-00003.safetensors",
63
+ "encoder.layers.12.attn.proj.bias": "model-00001-of-00003.safetensors",
64
+ "encoder.layers.12.attn.proj.weight": "model-00001-of-00003.safetensors",
65
+ "encoder.layers.12.attn.q_norm.weight": "model-00001-of-00003.safetensors",
66
+ "encoder.layers.12.attn.qkv.weight": "model-00001-of-00003.safetensors",
67
+ "encoder.layers.12.ls1": "model-00001-of-00003.safetensors",
68
+ "encoder.layers.12.ls2": "model-00001-of-00003.safetensors",
69
+ "encoder.layers.12.mlp.fc1.bias": "model-00001-of-00003.safetensors",
70
+ "encoder.layers.12.mlp.fc1.weight": "model-00001-of-00003.safetensors",
71
+ "encoder.layers.12.mlp.fc2.bias": "model-00001-of-00003.safetensors",
72
+ "encoder.layers.12.mlp.fc2.weight": "model-00001-of-00003.safetensors",
73
+ "encoder.layers.12.norm1.weight": "model-00001-of-00003.safetensors",
74
+ "encoder.layers.12.norm2.weight": "model-00001-of-00003.safetensors",
75
+ "encoder.layers.13.attn.k_norm.weight": "model-00001-of-00003.safetensors",
76
+ "encoder.layers.13.attn.proj.bias": "model-00001-of-00003.safetensors",
77
+ "encoder.layers.13.attn.proj.weight": "model-00001-of-00003.safetensors",
78
+ "encoder.layers.13.attn.q_norm.weight": "model-00001-of-00003.safetensors",
79
+ "encoder.layers.13.attn.qkv.weight": "model-00001-of-00003.safetensors",
80
+ "encoder.layers.13.ls1": "model-00001-of-00003.safetensors",
81
+ "encoder.layers.13.ls2": "model-00001-of-00003.safetensors",
82
+ "encoder.layers.13.mlp.fc1.bias": "model-00001-of-00003.safetensors",
83
+ "encoder.layers.13.mlp.fc1.weight": "model-00001-of-00003.safetensors",
84
+ "encoder.layers.13.mlp.fc2.bias": "model-00001-of-00003.safetensors",
85
+ "encoder.layers.13.mlp.fc2.weight": "model-00001-of-00003.safetensors",
86
+ "encoder.layers.13.norm1.weight": "model-00001-of-00003.safetensors",
87
+ "encoder.layers.13.norm2.weight": "model-00001-of-00003.safetensors",
88
+ "encoder.layers.14.attn.k_norm.weight": "model-00001-of-00003.safetensors",
89
+ "encoder.layers.14.attn.proj.bias": "model-00001-of-00003.safetensors",
90
+ "encoder.layers.14.attn.proj.weight": "model-00001-of-00003.safetensors",
91
+ "encoder.layers.14.attn.q_norm.weight": "model-00001-of-00003.safetensors",
92
+ "encoder.layers.14.attn.qkv.weight": "model-00001-of-00003.safetensors",
93
+ "encoder.layers.14.ls1": "model-00001-of-00003.safetensors",
94
+ "encoder.layers.14.ls2": "model-00001-of-00003.safetensors",
95
+ "encoder.layers.14.mlp.fc1.bias": "model-00001-of-00003.safetensors",
96
+ "encoder.layers.14.mlp.fc1.weight": "model-00001-of-00003.safetensors",
97
+ "encoder.layers.14.mlp.fc2.bias": "model-00001-of-00003.safetensors",
98
+ "encoder.layers.14.mlp.fc2.weight": "model-00001-of-00003.safetensors",
99
+ "encoder.layers.14.norm1.weight": "model-00001-of-00003.safetensors",
100
+ "encoder.layers.14.norm2.weight": "model-00001-of-00003.safetensors",
101
+ "encoder.layers.15.attn.k_norm.weight": "model-00001-of-00003.safetensors",
102
+ "encoder.layers.15.attn.proj.bias": "model-00001-of-00003.safetensors",
103
+ "encoder.layers.15.attn.proj.weight": "model-00001-of-00003.safetensors",
104
+ "encoder.layers.15.attn.q_norm.weight": "model-00001-of-00003.safetensors",
105
+ "encoder.layers.15.attn.qkv.weight": "model-00001-of-00003.safetensors",
106
+ "encoder.layers.15.ls1": "model-00001-of-00003.safetensors",
107
+ "encoder.layers.15.ls2": "model-00001-of-00003.safetensors",
108
+ "encoder.layers.15.mlp.fc1.bias": "model-00001-of-00003.safetensors",
109
+ "encoder.layers.15.mlp.fc1.weight": "model-00001-of-00003.safetensors",
110
+ "encoder.layers.15.mlp.fc2.bias": "model-00001-of-00003.safetensors",
111
+ "encoder.layers.15.mlp.fc2.weight": "model-00001-of-00003.safetensors",
112
+ "encoder.layers.15.norm1.weight": "model-00001-of-00003.safetensors",
113
+ "encoder.layers.15.norm2.weight": "model-00001-of-00003.safetensors",
114
+ "encoder.layers.16.attn.k_norm.weight": "model-00001-of-00003.safetensors",
115
+ "encoder.layers.16.attn.proj.bias": "model-00001-of-00003.safetensors",
116
+ "encoder.layers.16.attn.proj.weight": "model-00001-of-00003.safetensors",
117
+ "encoder.layers.16.attn.q_norm.weight": "model-00001-of-00003.safetensors",
118
+ "encoder.layers.16.attn.qkv.weight": "model-00001-of-00003.safetensors",
119
+ "encoder.layers.16.ls1": "model-00001-of-00003.safetensors",
120
+ "encoder.layers.16.ls2": "model-00001-of-00003.safetensors",
121
+ "encoder.layers.16.mlp.fc1.bias": "model-00001-of-00003.safetensors",
122
+ "encoder.layers.16.mlp.fc1.weight": "model-00001-of-00003.safetensors",
123
+ "encoder.layers.16.mlp.fc2.bias": "model-00001-of-00003.safetensors",
124
+ "encoder.layers.16.mlp.fc2.weight": "model-00001-of-00003.safetensors",
125
+ "encoder.layers.16.norm1.weight": "model-00001-of-00003.safetensors",
126
+ "encoder.layers.16.norm2.weight": "model-00001-of-00003.safetensors",
127
+ "encoder.layers.17.attn.k_norm.weight": "model-00001-of-00003.safetensors",
128
+ "encoder.layers.17.attn.proj.bias": "model-00001-of-00003.safetensors",
129
+ "encoder.layers.17.attn.proj.weight": "model-00001-of-00003.safetensors",
130
+ "encoder.layers.17.attn.q_norm.weight": "model-00001-of-00003.safetensors",
131
+ "encoder.layers.17.attn.qkv.weight": "model-00001-of-00003.safetensors",
132
+ "encoder.layers.17.ls1": "model-00001-of-00003.safetensors",
133
+ "encoder.layers.17.ls2": "model-00001-of-00003.safetensors",
134
+ "encoder.layers.17.mlp.fc1.bias": "model-00001-of-00003.safetensors",
135
+ "encoder.layers.17.mlp.fc1.weight": "model-00001-of-00003.safetensors",
136
+ "encoder.layers.17.mlp.fc2.bias": "model-00001-of-00003.safetensors",
137
+ "encoder.layers.17.mlp.fc2.weight": "model-00001-of-00003.safetensors",
138
+ "encoder.layers.17.norm1.weight": "model-00001-of-00003.safetensors",
139
+ "encoder.layers.17.norm2.weight": "model-00001-of-00003.safetensors",
140
+ "encoder.layers.18.attn.k_norm.weight": "model-00001-of-00003.safetensors",
141
+ "encoder.layers.18.attn.proj.bias": "model-00001-of-00003.safetensors",
142
+ "encoder.layers.18.attn.proj.weight": "model-00001-of-00003.safetensors",
143
+ "encoder.layers.18.attn.q_norm.weight": "model-00001-of-00003.safetensors",
144
+ "encoder.layers.18.attn.qkv.weight": "model-00001-of-00003.safetensors",
145
+ "encoder.layers.18.ls1": "model-00001-of-00003.safetensors",
146
+ "encoder.layers.18.ls2": "model-00001-of-00003.safetensors",
147
+ "encoder.layers.18.mlp.fc1.bias": "model-00001-of-00003.safetensors",
148
+ "encoder.layers.18.mlp.fc1.weight": "model-00001-of-00003.safetensors",
149
+ "encoder.layers.18.mlp.fc2.bias": "model-00001-of-00003.safetensors",
150
+ "encoder.layers.18.mlp.fc2.weight": "model-00001-of-00003.safetensors",
151
+ "encoder.layers.18.norm1.weight": "model-00001-of-00003.safetensors",
152
+ "encoder.layers.18.norm2.weight": "model-00001-of-00003.safetensors",
153
+ "encoder.layers.19.attn.k_norm.weight": "model-00001-of-00003.safetensors",
154
+ "encoder.layers.19.attn.proj.bias": "model-00001-of-00003.safetensors",
155
+ "encoder.layers.19.attn.proj.weight": "model-00001-of-00003.safetensors",
156
+ "encoder.layers.19.attn.q_norm.weight": "model-00001-of-00003.safetensors",
157
+ "encoder.layers.19.attn.qkv.weight": "model-00001-of-00003.safetensors",
158
+ "encoder.layers.19.ls1": "model-00001-of-00003.safetensors",
159
+ "encoder.layers.19.ls2": "model-00001-of-00003.safetensors",
160
+ "encoder.layers.19.mlp.fc1.bias": "model-00001-of-00003.safetensors",
161
+ "encoder.layers.19.mlp.fc1.weight": "model-00001-of-00003.safetensors",
162
+ "encoder.layers.19.mlp.fc2.bias": "model-00001-of-00003.safetensors",
163
+ "encoder.layers.19.mlp.fc2.weight": "model-00001-of-00003.safetensors",
164
+ "encoder.layers.19.norm1.weight": "model-00001-of-00003.safetensors",
165
+ "encoder.layers.19.norm2.weight": "model-00001-of-00003.safetensors",
166
+ "encoder.layers.2.attn.k_norm.weight": "model-00001-of-00003.safetensors",
167
+ "encoder.layers.2.attn.proj.bias": "model-00001-of-00003.safetensors",
168
+ "encoder.layers.2.attn.proj.weight": "model-00001-of-00003.safetensors",
169
+ "encoder.layers.2.attn.q_norm.weight": "model-00001-of-00003.safetensors",
170
+ "encoder.layers.2.attn.qkv.weight": "model-00001-of-00003.safetensors",
171
+ "encoder.layers.2.ls1": "model-00001-of-00003.safetensors",
172
+ "encoder.layers.2.ls2": "model-00001-of-00003.safetensors",
173
+ "encoder.layers.2.mlp.fc1.bias": "model-00001-of-00003.safetensors",
174
+ "encoder.layers.2.mlp.fc1.weight": "model-00001-of-00003.safetensors",
175
+ "encoder.layers.2.mlp.fc2.bias": "model-00001-of-00003.safetensors",
176
+ "encoder.layers.2.mlp.fc2.weight": "model-00001-of-00003.safetensors",
177
+ "encoder.layers.2.norm1.weight": "model-00001-of-00003.safetensors",
178
+ "encoder.layers.2.norm2.weight": "model-00001-of-00003.safetensors",
179
+ "encoder.layers.20.attn.k_norm.weight": "model-00001-of-00003.safetensors",
180
+ "encoder.layers.20.attn.proj.bias": "model-00002-of-00003.safetensors",
181
+ "encoder.layers.20.attn.proj.weight": "model-00002-of-00003.safetensors",
182
+ "encoder.layers.20.attn.q_norm.weight": "model-00001-of-00003.safetensors",
183
+ "encoder.layers.20.attn.qkv.weight": "model-00001-of-00003.safetensors",
184
+ "encoder.layers.20.ls1": "model-00001-of-00003.safetensors",
185
+ "encoder.layers.20.ls2": "model-00001-of-00003.safetensors",
186
+ "encoder.layers.20.mlp.fc1.bias": "model-00002-of-00003.safetensors",
187
+ "encoder.layers.20.mlp.fc1.weight": "model-00002-of-00003.safetensors",
188
+ "encoder.layers.20.mlp.fc2.bias": "model-00002-of-00003.safetensors",
189
+ "encoder.layers.20.mlp.fc2.weight": "model-00002-of-00003.safetensors",
190
+ "encoder.layers.20.norm1.weight": "model-00002-of-00003.safetensors",
191
+ "encoder.layers.20.norm2.weight": "model-00002-of-00003.safetensors",
192
+ "encoder.layers.21.attn.k_norm.weight": "model-00002-of-00003.safetensors",
193
+ "encoder.layers.21.attn.proj.bias": "model-00002-of-00003.safetensors",
194
+ "encoder.layers.21.attn.proj.weight": "model-00002-of-00003.safetensors",
195
+ "encoder.layers.21.attn.q_norm.weight": "model-00002-of-00003.safetensors",
196
+ "encoder.layers.21.attn.qkv.weight": "model-00002-of-00003.safetensors",
197
+ "encoder.layers.21.ls1": "model-00002-of-00003.safetensors",
198
+ "encoder.layers.21.ls2": "model-00002-of-00003.safetensors",
199
+ "encoder.layers.21.mlp.fc1.bias": "model-00002-of-00003.safetensors",
200
+ "encoder.layers.21.mlp.fc1.weight": "model-00002-of-00003.safetensors",
201
+ "encoder.layers.21.mlp.fc2.bias": "model-00002-of-00003.safetensors",
202
+ "encoder.layers.21.mlp.fc2.weight": "model-00002-of-00003.safetensors",
203
+ "encoder.layers.21.norm1.weight": "model-00002-of-00003.safetensors",
204
+ "encoder.layers.21.norm2.weight": "model-00002-of-00003.safetensors",
205
+ "encoder.layers.22.attn.k_norm.weight": "model-00002-of-00003.safetensors",
206
+ "encoder.layers.22.attn.proj.bias": "model-00002-of-00003.safetensors",
207
+ "encoder.layers.22.attn.proj.weight": "model-00002-of-00003.safetensors",
208
+ "encoder.layers.22.attn.q_norm.weight": "model-00002-of-00003.safetensors",
209
+ "encoder.layers.22.attn.qkv.weight": "model-00002-of-00003.safetensors",
210
+ "encoder.layers.22.ls1": "model-00002-of-00003.safetensors",
211
+ "encoder.layers.22.ls2": "model-00002-of-00003.safetensors",
212
+ "encoder.layers.22.mlp.fc1.bias": "model-00002-of-00003.safetensors",
213
+ "encoder.layers.22.mlp.fc1.weight": "model-00002-of-00003.safetensors",
214
+ "encoder.layers.22.mlp.fc2.bias": "model-00002-of-00003.safetensors",
215
+ "encoder.layers.22.mlp.fc2.weight": "model-00002-of-00003.safetensors",
216
+ "encoder.layers.22.norm1.weight": "model-00002-of-00003.safetensors",
217
+ "encoder.layers.22.norm2.weight": "model-00002-of-00003.safetensors",
218
+ "encoder.layers.23.attn.k_norm.weight": "model-00002-of-00003.safetensors",
219
+ "encoder.layers.23.attn.proj.bias": "model-00002-of-00003.safetensors",
220
+ "encoder.layers.23.attn.proj.weight": "model-00002-of-00003.safetensors",
221
+ "encoder.layers.23.attn.q_norm.weight": "model-00002-of-00003.safetensors",
222
+ "encoder.layers.23.attn.qkv.weight": "model-00002-of-00003.safetensors",
223
+ "encoder.layers.23.ls1": "model-00002-of-00003.safetensors",
224
+ "encoder.layers.23.ls2": "model-00002-of-00003.safetensors",
225
+ "encoder.layers.23.mlp.fc1.bias": "model-00002-of-00003.safetensors",
226
+ "encoder.layers.23.mlp.fc1.weight": "model-00002-of-00003.safetensors",
227
+ "encoder.layers.23.mlp.fc2.bias": "model-00002-of-00003.safetensors",
228
+ "encoder.layers.23.mlp.fc2.weight": "model-00002-of-00003.safetensors",
229
+ "encoder.layers.23.norm1.weight": "model-00002-of-00003.safetensors",
230
+ "encoder.layers.23.norm2.weight": "model-00002-of-00003.safetensors",
231
+ "encoder.layers.24.attn.k_norm.weight": "model-00002-of-00003.safetensors",
232
+ "encoder.layers.24.attn.proj.bias": "model-00002-of-00003.safetensors",
233
+ "encoder.layers.24.attn.proj.weight": "model-00002-of-00003.safetensors",
234
+ "encoder.layers.24.attn.q_norm.weight": "model-00002-of-00003.safetensors",
235
+ "encoder.layers.24.attn.qkv.weight": "model-00002-of-00003.safetensors",
236
+ "encoder.layers.24.ls1": "model-00002-of-00003.safetensors",
237
+ "encoder.layers.24.ls2": "model-00002-of-00003.safetensors",
238
+ "encoder.layers.24.mlp.fc1.bias": "model-00002-of-00003.safetensors",
239
+ "encoder.layers.24.mlp.fc1.weight": "model-00002-of-00003.safetensors",
240
+ "encoder.layers.24.mlp.fc2.bias": "model-00002-of-00003.safetensors",
241
+ "encoder.layers.24.mlp.fc2.weight": "model-00002-of-00003.safetensors",
242
+ "encoder.layers.24.norm1.weight": "model-00002-of-00003.safetensors",
243
+ "encoder.layers.24.norm2.weight": "model-00002-of-00003.safetensors",
244
+ "encoder.layers.25.attn.k_norm.weight": "model-00002-of-00003.safetensors",
245
+ "encoder.layers.25.attn.proj.bias": "model-00002-of-00003.safetensors",
246
+ "encoder.layers.25.attn.proj.weight": "model-00002-of-00003.safetensors",
247
+ "encoder.layers.25.attn.q_norm.weight": "model-00002-of-00003.safetensors",
248
+ "encoder.layers.25.attn.qkv.weight": "model-00002-of-00003.safetensors",
249
+ "encoder.layers.25.ls1": "model-00002-of-00003.safetensors",
250
+ "encoder.layers.25.ls2": "model-00002-of-00003.safetensors",
251
+ "encoder.layers.25.mlp.fc1.bias": "model-00002-of-00003.safetensors",
252
+ "encoder.layers.25.mlp.fc1.weight": "model-00002-of-00003.safetensors",
253
+ "encoder.layers.25.mlp.fc2.bias": "model-00002-of-00003.safetensors",
254
+ "encoder.layers.25.mlp.fc2.weight": "model-00002-of-00003.safetensors",
255
+ "encoder.layers.25.norm1.weight": "model-00002-of-00003.safetensors",
256
+ "encoder.layers.25.norm2.weight": "model-00002-of-00003.safetensors",
257
+ "encoder.layers.26.attn.k_norm.weight": "model-00002-of-00003.safetensors",
258
+ "encoder.layers.26.attn.proj.bias": "model-00002-of-00003.safetensors",
259
+ "encoder.layers.26.attn.proj.weight": "model-00002-of-00003.safetensors",
260
+ "encoder.layers.26.attn.q_norm.weight": "model-00002-of-00003.safetensors",
261
+ "encoder.layers.26.attn.qkv.weight": "model-00002-of-00003.safetensors",
262
+ "encoder.layers.26.ls1": "model-00002-of-00003.safetensors",
263
+ "encoder.layers.26.ls2": "model-00002-of-00003.safetensors",
264
+ "encoder.layers.26.mlp.fc1.bias": "model-00002-of-00003.safetensors",
265
+ "encoder.layers.26.mlp.fc1.weight": "model-00002-of-00003.safetensors",
266
+ "encoder.layers.26.mlp.fc2.bias": "model-00002-of-00003.safetensors",
267
+ "encoder.layers.26.mlp.fc2.weight": "model-00002-of-00003.safetensors",
268
+ "encoder.layers.26.norm1.weight": "model-00002-of-00003.safetensors",
269
+ "encoder.layers.26.norm2.weight": "model-00002-of-00003.safetensors",
270
+ "encoder.layers.27.attn.k_norm.weight": "model-00002-of-00003.safetensors",
271
+ "encoder.layers.27.attn.proj.bias": "model-00002-of-00003.safetensors",
272
+ "encoder.layers.27.attn.proj.weight": "model-00002-of-00003.safetensors",
273
+ "encoder.layers.27.attn.q_norm.weight": "model-00002-of-00003.safetensors",
274
+ "encoder.layers.27.attn.qkv.weight": "model-00002-of-00003.safetensors",
275
+ "encoder.layers.27.ls1": "model-00002-of-00003.safetensors",
276
+ "encoder.layers.27.ls2": "model-00002-of-00003.safetensors",
277
+ "encoder.layers.27.mlp.fc1.bias": "model-00002-of-00003.safetensors",
278
+ "encoder.layers.27.mlp.fc1.weight": "model-00002-of-00003.safetensors",
279
+ "encoder.layers.27.mlp.fc2.bias": "model-00002-of-00003.safetensors",
280
+ "encoder.layers.27.mlp.fc2.weight": "model-00002-of-00003.safetensors",
281
+ "encoder.layers.27.norm1.weight": "model-00002-of-00003.safetensors",
282
+ "encoder.layers.27.norm2.weight": "model-00002-of-00003.safetensors",
283
+ "encoder.layers.28.attn.k_norm.weight": "model-00002-of-00003.safetensors",
284
+ "encoder.layers.28.attn.proj.bias": "model-00002-of-00003.safetensors",
285
+ "encoder.layers.28.attn.proj.weight": "model-00002-of-00003.safetensors",
286
+ "encoder.layers.28.attn.q_norm.weight": "model-00002-of-00003.safetensors",
287
+ "encoder.layers.28.attn.qkv.weight": "model-00002-of-00003.safetensors",
288
+ "encoder.layers.28.ls1": "model-00002-of-00003.safetensors",
289
+ "encoder.layers.28.ls2": "model-00002-of-00003.safetensors",
290
+ "encoder.layers.28.mlp.fc1.bias": "model-00002-of-00003.safetensors",
291
+ "encoder.layers.28.mlp.fc1.weight": "model-00002-of-00003.safetensors",
292
+ "encoder.layers.28.mlp.fc2.bias": "model-00002-of-00003.safetensors",
293
+ "encoder.layers.28.mlp.fc2.weight": "model-00002-of-00003.safetensors",
294
+ "encoder.layers.28.norm1.weight": "model-00002-of-00003.safetensors",
295
+ "encoder.layers.28.norm2.weight": "model-00002-of-00003.safetensors",
296
+ "encoder.layers.29.attn.k_norm.weight": "model-00002-of-00003.safetensors",
297
+ "encoder.layers.29.attn.proj.bias": "model-00002-of-00003.safetensors",
298
+ "encoder.layers.29.attn.proj.weight": "model-00002-of-00003.safetensors",
299
+ "encoder.layers.29.attn.q_norm.weight": "model-00002-of-00003.safetensors",
300
+ "encoder.layers.29.attn.qkv.weight": "model-00002-of-00003.safetensors",
301
+ "encoder.layers.29.ls1": "model-00002-of-00003.safetensors",
302
+ "encoder.layers.29.ls2": "model-00002-of-00003.safetensors",
303
+ "encoder.layers.29.mlp.fc1.bias": "model-00002-of-00003.safetensors",
304
+ "encoder.layers.29.mlp.fc1.weight": "model-00002-of-00003.safetensors",
305
+ "encoder.layers.29.mlp.fc2.bias": "model-00002-of-00003.safetensors",
306
+ "encoder.layers.29.mlp.fc2.weight": "model-00002-of-00003.safetensors",
307
+ "encoder.layers.29.norm1.weight": "model-00002-of-00003.safetensors",
308
+ "encoder.layers.29.norm2.weight": "model-00002-of-00003.safetensors",
309
+ "encoder.layers.3.attn.k_norm.weight": "model-00001-of-00003.safetensors",
310
+ "encoder.layers.3.attn.proj.bias": "model-00001-of-00003.safetensors",
311
+ "encoder.layers.3.attn.proj.weight": "model-00001-of-00003.safetensors",
312
+ "encoder.layers.3.attn.q_norm.weight": "model-00001-of-00003.safetensors",
313
+ "encoder.layers.3.attn.qkv.weight": "model-00001-of-00003.safetensors",
314
+ "encoder.layers.3.ls1": "model-00001-of-00003.safetensors",
315
+ "encoder.layers.3.ls2": "model-00001-of-00003.safetensors",
316
+ "encoder.layers.3.mlp.fc1.bias": "model-00001-of-00003.safetensors",
317
+ "encoder.layers.3.mlp.fc1.weight": "model-00001-of-00003.safetensors",
318
+ "encoder.layers.3.mlp.fc2.bias": "model-00001-of-00003.safetensors",
319
+ "encoder.layers.3.mlp.fc2.weight": "model-00001-of-00003.safetensors",
320
+ "encoder.layers.3.norm1.weight": "model-00001-of-00003.safetensors",
321
+ "encoder.layers.3.norm2.weight": "model-00001-of-00003.safetensors",
322
+ "encoder.layers.30.attn.k_norm.weight": "model-00002-of-00003.safetensors",
323
+ "encoder.layers.30.attn.proj.bias": "model-00002-of-00003.safetensors",
324
+ "encoder.layers.30.attn.proj.weight": "model-00002-of-00003.safetensors",
325
+ "encoder.layers.30.attn.q_norm.weight": "model-00002-of-00003.safetensors",
326
+ "encoder.layers.30.attn.qkv.weight": "model-00002-of-00003.safetensors",
327
+ "encoder.layers.30.ls1": "model-00002-of-00003.safetensors",
328
+ "encoder.layers.30.ls2": "model-00002-of-00003.safetensors",
329
+ "encoder.layers.30.mlp.fc1.bias": "model-00002-of-00003.safetensors",
330
+ "encoder.layers.30.mlp.fc1.weight": "model-00002-of-00003.safetensors",
331
+ "encoder.layers.30.mlp.fc2.bias": "model-00002-of-00003.safetensors",
332
+ "encoder.layers.30.mlp.fc2.weight": "model-00002-of-00003.safetensors",
333
+ "encoder.layers.30.norm1.weight": "model-00002-of-00003.safetensors",
334
+ "encoder.layers.30.norm2.weight": "model-00002-of-00003.safetensors",
335
+ "encoder.layers.31.attn.k_norm.weight": "model-00002-of-00003.safetensors",
336
+ "encoder.layers.31.attn.proj.bias": "model-00002-of-00003.safetensors",
337
+ "encoder.layers.31.attn.proj.weight": "model-00002-of-00003.safetensors",
338
+ "encoder.layers.31.attn.q_norm.weight": "model-00002-of-00003.safetensors",
339
+ "encoder.layers.31.attn.qkv.weight": "model-00002-of-00003.safetensors",
340
+ "encoder.layers.31.ls1": "model-00002-of-00003.safetensors",
341
+ "encoder.layers.31.ls2": "model-00002-of-00003.safetensors",
342
+ "encoder.layers.31.mlp.fc1.bias": "model-00002-of-00003.safetensors",
343
+ "encoder.layers.31.mlp.fc1.weight": "model-00002-of-00003.safetensors",
344
+ "encoder.layers.31.mlp.fc2.bias": "model-00002-of-00003.safetensors",
345
+ "encoder.layers.31.mlp.fc2.weight": "model-00002-of-00003.safetensors",
346
+ "encoder.layers.31.norm1.weight": "model-00002-of-00003.safetensors",
347
+ "encoder.layers.31.norm2.weight": "model-00002-of-00003.safetensors",
348
+ "encoder.layers.32.attn.k_norm.weight": "model-00002-of-00003.safetensors",
349
+ "encoder.layers.32.attn.proj.bias": "model-00002-of-00003.safetensors",
350
+ "encoder.layers.32.attn.proj.weight": "model-00002-of-00003.safetensors",
351
+ "encoder.layers.32.attn.q_norm.weight": "model-00002-of-00003.safetensors",
352
+ "encoder.layers.32.attn.qkv.weight": "model-00002-of-00003.safetensors",
353
+ "encoder.layers.32.ls1": "model-00002-of-00003.safetensors",
354
+ "encoder.layers.32.ls2": "model-00002-of-00003.safetensors",
355
+ "encoder.layers.32.mlp.fc1.bias": "model-00002-of-00003.safetensors",
356
+ "encoder.layers.32.mlp.fc1.weight": "model-00002-of-00003.safetensors",
357
+ "encoder.layers.32.mlp.fc2.bias": "model-00002-of-00003.safetensors",
358
+ "encoder.layers.32.mlp.fc2.weight": "model-00002-of-00003.safetensors",
359
+ "encoder.layers.32.norm1.weight": "model-00002-of-00003.safetensors",
360
+ "encoder.layers.32.norm2.weight": "model-00002-of-00003.safetensors",
361
+ "encoder.layers.33.attn.k_norm.weight": "model-00002-of-00003.safetensors",
362
+ "encoder.layers.33.attn.proj.bias": "model-00002-of-00003.safetensors",
363
+ "encoder.layers.33.attn.proj.weight": "model-00002-of-00003.safetensors",
364
+ "encoder.layers.33.attn.q_norm.weight": "model-00002-of-00003.safetensors",
365
+ "encoder.layers.33.attn.qkv.weight": "model-00002-of-00003.safetensors",
366
+ "encoder.layers.33.ls1": "model-00002-of-00003.safetensors",
367
+ "encoder.layers.33.ls2": "model-00002-of-00003.safetensors",
368
+ "encoder.layers.33.mlp.fc1.bias": "model-00002-of-00003.safetensors",
369
+ "encoder.layers.33.mlp.fc1.weight": "model-00002-of-00003.safetensors",
370
+ "encoder.layers.33.mlp.fc2.bias": "model-00002-of-00003.safetensors",
371
+ "encoder.layers.33.mlp.fc2.weight": "model-00002-of-00003.safetensors",
372
+ "encoder.layers.33.norm1.weight": "model-00002-of-00003.safetensors",
373
+ "encoder.layers.33.norm2.weight": "model-00002-of-00003.safetensors",
374
+ "encoder.layers.34.attn.k_norm.weight": "model-00002-of-00003.safetensors",
375
+ "encoder.layers.34.attn.proj.bias": "model-00002-of-00003.safetensors",
376
+ "encoder.layers.34.attn.proj.weight": "model-00002-of-00003.safetensors",
377
+ "encoder.layers.34.attn.q_norm.weight": "model-00002-of-00003.safetensors",
378
+ "encoder.layers.34.attn.qkv.weight": "model-00002-of-00003.safetensors",
379
+ "encoder.layers.34.ls1": "model-00002-of-00003.safetensors",
380
+ "encoder.layers.34.ls2": "model-00002-of-00003.safetensors",
381
+ "encoder.layers.34.mlp.fc1.bias": "model-00002-of-00003.safetensors",
382
+ "encoder.layers.34.mlp.fc1.weight": "model-00002-of-00003.safetensors",
383
+ "encoder.layers.34.mlp.fc2.bias": "model-00002-of-00003.safetensors",
384
+ "encoder.layers.34.mlp.fc2.weight": "model-00002-of-00003.safetensors",
385
+ "encoder.layers.34.norm1.weight": "model-00002-of-00003.safetensors",
386
+ "encoder.layers.34.norm2.weight": "model-00002-of-00003.safetensors",
387
+ "encoder.layers.35.attn.k_norm.weight": "model-00002-of-00003.safetensors",
388
+ "encoder.layers.35.attn.proj.bias": "model-00002-of-00003.safetensors",
389
+ "encoder.layers.35.attn.proj.weight": "model-00002-of-00003.safetensors",
390
+ "encoder.layers.35.attn.q_norm.weight": "model-00002-of-00003.safetensors",
391
+ "encoder.layers.35.attn.qkv.weight": "model-00002-of-00003.safetensors",
392
+ "encoder.layers.35.ls1": "model-00002-of-00003.safetensors",
393
+ "encoder.layers.35.ls2": "model-00002-of-00003.safetensors",
394
+ "encoder.layers.35.mlp.fc1.bias": "model-00002-of-00003.safetensors",
395
+ "encoder.layers.35.mlp.fc1.weight": "model-00002-of-00003.safetensors",
396
+ "encoder.layers.35.mlp.fc2.bias": "model-00002-of-00003.safetensors",
397
+ "encoder.layers.35.mlp.fc2.weight": "model-00002-of-00003.safetensors",
398
+ "encoder.layers.35.norm1.weight": "model-00002-of-00003.safetensors",
399
+ "encoder.layers.35.norm2.weight": "model-00002-of-00003.safetensors",
400
+ "encoder.layers.36.attn.k_norm.weight": "model-00002-of-00003.safetensors",
401
+ "encoder.layers.36.attn.proj.bias": "model-00002-of-00003.safetensors",
402
+ "encoder.layers.36.attn.proj.weight": "model-00002-of-00003.safetensors",
403
+ "encoder.layers.36.attn.q_norm.weight": "model-00002-of-00003.safetensors",
404
+ "encoder.layers.36.attn.qkv.weight": "model-00002-of-00003.safetensors",
405
+ "encoder.layers.36.ls1": "model-00002-of-00003.safetensors",
406
+ "encoder.layers.36.ls2": "model-00002-of-00003.safetensors",
407
+ "encoder.layers.36.mlp.fc1.bias": "model-00002-of-00003.safetensors",
408
+ "encoder.layers.36.mlp.fc1.weight": "model-00002-of-00003.safetensors",
409
+ "encoder.layers.36.mlp.fc2.bias": "model-00002-of-00003.safetensors",
410
+ "encoder.layers.36.mlp.fc2.weight": "model-00002-of-00003.safetensors",
411
+ "encoder.layers.36.norm1.weight": "model-00002-of-00003.safetensors",
412
+ "encoder.layers.36.norm2.weight": "model-00002-of-00003.safetensors",
413
+ "encoder.layers.37.attn.k_norm.weight": "model-00002-of-00003.safetensors",
414
+ "encoder.layers.37.attn.proj.bias": "model-00002-of-00003.safetensors",
415
+ "encoder.layers.37.attn.proj.weight": "model-00002-of-00003.safetensors",
416
+ "encoder.layers.37.attn.q_norm.weight": "model-00002-of-00003.safetensors",
417
+ "encoder.layers.37.attn.qkv.weight": "model-00002-of-00003.safetensors",
418
+ "encoder.layers.37.ls1": "model-00002-of-00003.safetensors",
419
+ "encoder.layers.37.ls2": "model-00002-of-00003.safetensors",
420
+ "encoder.layers.37.mlp.fc1.bias": "model-00002-of-00003.safetensors",
421
+ "encoder.layers.37.mlp.fc1.weight": "model-00002-of-00003.safetensors",
422
+ "encoder.layers.37.mlp.fc2.bias": "model-00002-of-00003.safetensors",
423
+ "encoder.layers.37.mlp.fc2.weight": "model-00002-of-00003.safetensors",
424
+ "encoder.layers.37.norm1.weight": "model-00002-of-00003.safetensors",
425
+ "encoder.layers.37.norm2.weight": "model-00002-of-00003.safetensors",
426
+ "encoder.layers.38.attn.k_norm.weight": "model-00002-of-00003.safetensors",
427
+ "encoder.layers.38.attn.proj.bias": "model-00002-of-00003.safetensors",
428
+ "encoder.layers.38.attn.proj.weight": "model-00002-of-00003.safetensors",
429
+ "encoder.layers.38.attn.q_norm.weight": "model-00002-of-00003.safetensors",
430
+ "encoder.layers.38.attn.qkv.weight": "model-00002-of-00003.safetensors",
431
+ "encoder.layers.38.ls1": "model-00002-of-00003.safetensors",
432
+ "encoder.layers.38.ls2": "model-00002-of-00003.safetensors",
433
+ "encoder.layers.38.mlp.fc1.bias": "model-00002-of-00003.safetensors",
434
+ "encoder.layers.38.mlp.fc1.weight": "model-00002-of-00003.safetensors",
435
+ "encoder.layers.38.mlp.fc2.bias": "model-00002-of-00003.safetensors",
436
+ "encoder.layers.38.mlp.fc2.weight": "model-00002-of-00003.safetensors",
437
+ "encoder.layers.38.norm1.weight": "model-00002-of-00003.safetensors",
438
+ "encoder.layers.38.norm2.weight": "model-00002-of-00003.safetensors",
439
+ "encoder.layers.39.attn.k_norm.weight": "model-00002-of-00003.safetensors",
440
+ "encoder.layers.39.attn.proj.bias": "model-00002-of-00003.safetensors",
441
+ "encoder.layers.39.attn.proj.weight": "model-00002-of-00003.safetensors",
442
+ "encoder.layers.39.attn.q_norm.weight": "model-00002-of-00003.safetensors",
443
+ "encoder.layers.39.attn.qkv.weight": "model-00002-of-00003.safetensors",
444
+ "encoder.layers.39.ls1": "model-00002-of-00003.safetensors",
445
+ "encoder.layers.39.ls2": "model-00002-of-00003.safetensors",
446
+ "encoder.layers.39.mlp.fc1.bias": "model-00002-of-00003.safetensors",
447
+ "encoder.layers.39.mlp.fc1.weight": "model-00002-of-00003.safetensors",
448
+ "encoder.layers.39.mlp.fc2.bias": "model-00002-of-00003.safetensors",
449
+ "encoder.layers.39.mlp.fc2.weight": "model-00002-of-00003.safetensors",
450
+ "encoder.layers.39.norm1.weight": "model-00002-of-00003.safetensors",
451
+ "encoder.layers.39.norm2.weight": "model-00002-of-00003.safetensors",
452
+ "encoder.layers.4.attn.k_norm.weight": "model-00001-of-00003.safetensors",
453
+ "encoder.layers.4.attn.proj.bias": "model-00001-of-00003.safetensors",
454
+ "encoder.layers.4.attn.proj.weight": "model-00001-of-00003.safetensors",
455
+ "encoder.layers.4.attn.q_norm.weight": "model-00001-of-00003.safetensors",
456
+ "encoder.layers.4.attn.qkv.weight": "model-00001-of-00003.safetensors",
457
+ "encoder.layers.4.ls1": "model-00001-of-00003.safetensors",
458
+ "encoder.layers.4.ls2": "model-00001-of-00003.safetensors",
459
+ "encoder.layers.4.mlp.fc1.bias": "model-00001-of-00003.safetensors",
460
+ "encoder.layers.4.mlp.fc1.weight": "model-00001-of-00003.safetensors",
461
+ "encoder.layers.4.mlp.fc2.bias": "model-00001-of-00003.safetensors",
462
+ "encoder.layers.4.mlp.fc2.weight": "model-00001-of-00003.safetensors",
463
+ "encoder.layers.4.norm1.weight": "model-00001-of-00003.safetensors",
464
+ "encoder.layers.4.norm2.weight": "model-00001-of-00003.safetensors",
465
+ "encoder.layers.40.attn.k_norm.weight": "model-00002-of-00003.safetensors",
466
+ "encoder.layers.40.attn.proj.bias": "model-00002-of-00003.safetensors",
467
+ "encoder.layers.40.attn.proj.weight": "model-00002-of-00003.safetensors",
468
+ "encoder.layers.40.attn.q_norm.weight": "model-00002-of-00003.safetensors",
469
+ "encoder.layers.40.attn.qkv.weight": "model-00002-of-00003.safetensors",
470
+ "encoder.layers.40.ls1": "model-00002-of-00003.safetensors",
471
+ "encoder.layers.40.ls2": "model-00002-of-00003.safetensors",
472
+ "encoder.layers.40.mlp.fc1.bias": "model-00003-of-00003.safetensors",
473
+ "encoder.layers.40.mlp.fc1.weight": "model-00003-of-00003.safetensors",
474
+ "encoder.layers.40.mlp.fc2.bias": "model-00003-of-00003.safetensors",
475
+ "encoder.layers.40.mlp.fc2.weight": "model-00003-of-00003.safetensors",
476
+ "encoder.layers.40.norm1.weight": "model-00003-of-00003.safetensors",
477
+ "encoder.layers.40.norm2.weight": "model-00003-of-00003.safetensors",
478
+ "encoder.layers.41.attn.k_norm.weight": "model-00003-of-00003.safetensors",
479
+ "encoder.layers.41.attn.proj.bias": "model-00003-of-00003.safetensors",
480
+ "encoder.layers.41.attn.proj.weight": "model-00003-of-00003.safetensors",
481
+ "encoder.layers.41.attn.q_norm.weight": "model-00003-of-00003.safetensors",
482
+ "encoder.layers.41.attn.qkv.weight": "model-00003-of-00003.safetensors",
483
+ "encoder.layers.41.ls1": "model-00003-of-00003.safetensors",
484
+ "encoder.layers.41.ls2": "model-00003-of-00003.safetensors",
485
+ "encoder.layers.41.mlp.fc1.bias": "model-00003-of-00003.safetensors",
486
+ "encoder.layers.41.mlp.fc1.weight": "model-00003-of-00003.safetensors",
487
+ "encoder.layers.41.mlp.fc2.bias": "model-00003-of-00003.safetensors",
488
+ "encoder.layers.41.mlp.fc2.weight": "model-00003-of-00003.safetensors",
489
+ "encoder.layers.41.norm1.weight": "model-00003-of-00003.safetensors",
490
+ "encoder.layers.41.norm2.weight": "model-00003-of-00003.safetensors",
491
+ "encoder.layers.42.attn.k_norm.weight": "model-00003-of-00003.safetensors",
492
+ "encoder.layers.42.attn.proj.bias": "model-00003-of-00003.safetensors",
493
+ "encoder.layers.42.attn.proj.weight": "model-00003-of-00003.safetensors",
494
+ "encoder.layers.42.attn.q_norm.weight": "model-00003-of-00003.safetensors",
495
+ "encoder.layers.42.attn.qkv.weight": "model-00003-of-00003.safetensors",
496
+ "encoder.layers.42.ls1": "model-00003-of-00003.safetensors",
497
+ "encoder.layers.42.ls2": "model-00003-of-00003.safetensors",
498
+ "encoder.layers.42.mlp.fc1.bias": "model-00003-of-00003.safetensors",
499
+ "encoder.layers.42.mlp.fc1.weight": "model-00003-of-00003.safetensors",
500
+ "encoder.layers.42.mlp.fc2.bias": "model-00003-of-00003.safetensors",
501
+ "encoder.layers.42.mlp.fc2.weight": "model-00003-of-00003.safetensors",
502
+ "encoder.layers.42.norm1.weight": "model-00003-of-00003.safetensors",
503
+ "encoder.layers.42.norm2.weight": "model-00003-of-00003.safetensors",
504
+ "encoder.layers.43.attn.k_norm.weight": "model-00003-of-00003.safetensors",
505
+ "encoder.layers.43.attn.proj.bias": "model-00003-of-00003.safetensors",
506
+ "encoder.layers.43.attn.proj.weight": "model-00003-of-00003.safetensors",
507
+ "encoder.layers.43.attn.q_norm.weight": "model-00003-of-00003.safetensors",
508
+ "encoder.layers.43.attn.qkv.weight": "model-00003-of-00003.safetensors",
509
+ "encoder.layers.43.ls1": "model-00003-of-00003.safetensors",
510
+ "encoder.layers.43.ls2": "model-00003-of-00003.safetensors",
511
+ "encoder.layers.43.mlp.fc1.bias": "model-00003-of-00003.safetensors",
512
+ "encoder.layers.43.mlp.fc1.weight": "model-00003-of-00003.safetensors",
513
+ "encoder.layers.43.mlp.fc2.bias": "model-00003-of-00003.safetensors",
514
+ "encoder.layers.43.mlp.fc2.weight": "model-00003-of-00003.safetensors",
515
+ "encoder.layers.43.norm1.weight": "model-00003-of-00003.safetensors",
516
+ "encoder.layers.43.norm2.weight": "model-00003-of-00003.safetensors",
517
+ "encoder.layers.44.attn.k_norm.weight": "model-00003-of-00003.safetensors",
518
+ "encoder.layers.44.attn.proj.bias": "model-00003-of-00003.safetensors",
519
+ "encoder.layers.44.attn.proj.weight": "model-00003-of-00003.safetensors",
520
+ "encoder.layers.44.attn.q_norm.weight": "model-00003-of-00003.safetensors",
521
+ "encoder.layers.44.attn.qkv.weight": "model-00003-of-00003.safetensors",
522
+ "encoder.layers.44.ls1": "model-00003-of-00003.safetensors",
523
+ "encoder.layers.44.ls2": "model-00003-of-00003.safetensors",
524
+ "encoder.layers.44.mlp.fc1.bias": "model-00003-of-00003.safetensors",
525
+ "encoder.layers.44.mlp.fc1.weight": "model-00003-of-00003.safetensors",
526
+ "encoder.layers.44.mlp.fc2.bias": "model-00003-of-00003.safetensors",
527
+ "encoder.layers.44.mlp.fc2.weight": "model-00003-of-00003.safetensors",
528
+ "encoder.layers.44.norm1.weight": "model-00003-of-00003.safetensors",
529
+ "encoder.layers.44.norm2.weight": "model-00003-of-00003.safetensors",
530
+ "encoder.layers.5.attn.k_norm.weight": "model-00001-of-00003.safetensors",
531
+ "encoder.layers.5.attn.proj.bias": "model-00001-of-00003.safetensors",
532
+ "encoder.layers.5.attn.proj.weight": "model-00001-of-00003.safetensors",
533
+ "encoder.layers.5.attn.q_norm.weight": "model-00001-of-00003.safetensors",
534
+ "encoder.layers.5.attn.qkv.weight": "model-00001-of-00003.safetensors",
535
+ "encoder.layers.5.ls1": "model-00001-of-00003.safetensors",
536
+ "encoder.layers.5.ls2": "model-00001-of-00003.safetensors",
537
+ "encoder.layers.5.mlp.fc1.bias": "model-00001-of-00003.safetensors",
538
+ "encoder.layers.5.mlp.fc1.weight": "model-00001-of-00003.safetensors",
539
+ "encoder.layers.5.mlp.fc2.bias": "model-00001-of-00003.safetensors",
540
+ "encoder.layers.5.mlp.fc2.weight": "model-00001-of-00003.safetensors",
541
+ "encoder.layers.5.norm1.weight": "model-00001-of-00003.safetensors",
542
+ "encoder.layers.5.norm2.weight": "model-00001-of-00003.safetensors",
543
+ "encoder.layers.6.attn.k_norm.weight": "model-00001-of-00003.safetensors",
544
+ "encoder.layers.6.attn.proj.bias": "model-00001-of-00003.safetensors",
545
+ "encoder.layers.6.attn.proj.weight": "model-00001-of-00003.safetensors",
546
+ "encoder.layers.6.attn.q_norm.weight": "model-00001-of-00003.safetensors",
547
+ "encoder.layers.6.attn.qkv.weight": "model-00001-of-00003.safetensors",
548
+ "encoder.layers.6.ls1": "model-00001-of-00003.safetensors",
549
+ "encoder.layers.6.ls2": "model-00001-of-00003.safetensors",
550
+ "encoder.layers.6.mlp.fc1.bias": "model-00001-of-00003.safetensors",
551
+ "encoder.layers.6.mlp.fc1.weight": "model-00001-of-00003.safetensors",
552
+ "encoder.layers.6.mlp.fc2.bias": "model-00001-of-00003.safetensors",
553
+ "encoder.layers.6.mlp.fc2.weight": "model-00001-of-00003.safetensors",
554
+ "encoder.layers.6.norm1.weight": "model-00001-of-00003.safetensors",
555
+ "encoder.layers.6.norm2.weight": "model-00001-of-00003.safetensors",
556
+ "encoder.layers.7.attn.k_norm.weight": "model-00001-of-00003.safetensors",
557
+ "encoder.layers.7.attn.proj.bias": "model-00001-of-00003.safetensors",
558
+ "encoder.layers.7.attn.proj.weight": "model-00001-of-00003.safetensors",
559
+ "encoder.layers.7.attn.q_norm.weight": "model-00001-of-00003.safetensors",
560
+ "encoder.layers.7.attn.qkv.weight": "model-00001-of-00003.safetensors",
561
+ "encoder.layers.7.ls1": "model-00001-of-00003.safetensors",
562
+ "encoder.layers.7.ls2": "model-00001-of-00003.safetensors",
563
+ "encoder.layers.7.mlp.fc1.bias": "model-00001-of-00003.safetensors",
564
+ "encoder.layers.7.mlp.fc1.weight": "model-00001-of-00003.safetensors",
565
+ "encoder.layers.7.mlp.fc2.bias": "model-00001-of-00003.safetensors",
566
+ "encoder.layers.7.mlp.fc2.weight": "model-00001-of-00003.safetensors",
567
+ "encoder.layers.7.norm1.weight": "model-00001-of-00003.safetensors",
568
+ "encoder.layers.7.norm2.weight": "model-00001-of-00003.safetensors",
569
+ "encoder.layers.8.attn.k_norm.weight": "model-00001-of-00003.safetensors",
570
+ "encoder.layers.8.attn.proj.bias": "model-00001-of-00003.safetensors",
571
+ "encoder.layers.8.attn.proj.weight": "model-00001-of-00003.safetensors",
572
+ "encoder.layers.8.attn.q_norm.weight": "model-00001-of-00003.safetensors",
573
+ "encoder.layers.8.attn.qkv.weight": "model-00001-of-00003.safetensors",
574
+ "encoder.layers.8.ls1": "model-00001-of-00003.safetensors",
575
+ "encoder.layers.8.ls2": "model-00001-of-00003.safetensors",
576
+ "encoder.layers.8.mlp.fc1.bias": "model-00001-of-00003.safetensors",
577
+ "encoder.layers.8.mlp.fc1.weight": "model-00001-of-00003.safetensors",
578
+ "encoder.layers.8.mlp.fc2.bias": "model-00001-of-00003.safetensors",
579
+ "encoder.layers.8.mlp.fc2.weight": "model-00001-of-00003.safetensors",
580
+ "encoder.layers.8.norm1.weight": "model-00001-of-00003.safetensors",
581
+ "encoder.layers.8.norm2.weight": "model-00001-of-00003.safetensors",
582
+ "encoder.layers.9.attn.k_norm.weight": "model-00001-of-00003.safetensors",
583
+ "encoder.layers.9.attn.proj.bias": "model-00001-of-00003.safetensors",
584
+ "encoder.layers.9.attn.proj.weight": "model-00001-of-00003.safetensors",
585
+ "encoder.layers.9.attn.q_norm.weight": "model-00001-of-00003.safetensors",
586
+ "encoder.layers.9.attn.qkv.weight": "model-00001-of-00003.safetensors",
587
+ "encoder.layers.9.ls1": "model-00001-of-00003.safetensors",
588
+ "encoder.layers.9.ls2": "model-00001-of-00003.safetensors",
589
+ "encoder.layers.9.mlp.fc1.bias": "model-00001-of-00003.safetensors",
590
+ "encoder.layers.9.mlp.fc1.weight": "model-00001-of-00003.safetensors",
591
+ "encoder.layers.9.mlp.fc2.bias": "model-00001-of-00003.safetensors",
592
+ "encoder.layers.9.mlp.fc2.weight": "model-00001-of-00003.safetensors",
593
+ "encoder.layers.9.norm1.weight": "model-00001-of-00003.safetensors",
594
+ "encoder.layers.9.norm2.weight": "model-00001-of-00003.safetensors"
595
+ }
596
+ }
modeling_intern_vit.py ADDED
@@ -0,0 +1,415 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # --------------------------------------------------------
2
+ # InternVL
3
+ # Copyright (c) 2023 OpenGVLab
4
+ # Licensed under The MIT License [see LICENSE for details]
5
+ # --------------------------------------------------------
6
+ from typing import Optional, Tuple, Union
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import torch.utils.checkpoint
11
+ from einops import rearrange
12
+ from timm.models.layers import DropPath
13
+ from torch import nn
14
+ from transformers.activations import ACT2FN
15
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
16
+ from transformers.modeling_utils import PreTrainedModel
17
+ from transformers.utils import logging
18
+
19
+ from .configuration_intern_vit import InternVisionConfig
20
+
21
+ try:
22
+ from .triton_flash_attn import attention
23
+
24
+ has_flash_attn = True
25
+ except:
26
+ print("attention is not installed.")
27
+ has_flash_attn = False
28
+
29
+
30
+ logger = logging.get_logger(__name__)
31
+
32
+
33
+ class InternRMSNorm(nn.Module):
34
+ def __init__(self, hidden_size, eps=1e-6):
35
+ super().__init__()
36
+ self.weight = nn.Parameter(torch.ones(hidden_size))
37
+ self.variance_epsilon = eps
38
+
39
+ def forward(self, hidden_states):
40
+ input_dtype = hidden_states.dtype
41
+ hidden_states = hidden_states.to(torch.float32)
42
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
43
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
44
+ return self.weight * hidden_states.to(input_dtype)
45
+
46
+
47
+ try:
48
+ from apex.normalization import FusedRMSNorm
49
+
50
+ InternRMSNorm = FusedRMSNorm # noqa
51
+
52
+ logger.info(
53
+ "Discovered apex.normalization.FusedRMSNorm - will use it instead of InternRMSNorm"
54
+ )
55
+ except ImportError:
56
+ # using the normal InternRMSNorm
57
+ pass
58
+ except Exception:
59
+ logger.warning(
60
+ "discovered apex but it failed to load, falling back to InternRMSNorm"
61
+ )
62
+ pass
63
+
64
+
65
+ class InternVisionEmbeddings(nn.Module):
66
+ def __init__(self, config: InternVisionConfig):
67
+ super().__init__()
68
+ self.config = config
69
+ self.embed_dim = config.hidden_size
70
+ self.image_size = config.image_size
71
+ self.patch_size = config.patch_size
72
+
73
+ self.class_embedding = nn.Parameter(
74
+ torch.randn(1, 1, self.embed_dim),
75
+ )
76
+
77
+ self.patch_embedding = nn.Conv2d(
78
+ in_channels=3,
79
+ out_channels=self.embed_dim,
80
+ kernel_size=self.patch_size,
81
+ stride=self.patch_size,
82
+ )
83
+
84
+ self.num_patches = (self.image_size // self.patch_size) ** 2
85
+ self.num_positions = self.num_patches + 1
86
+
87
+ self.position_embedding = nn.Parameter(
88
+ torch.randn(1, self.num_positions, self.embed_dim)
89
+ )
90
+
91
+ def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
92
+ batch_size = pixel_values.shape[0]
93
+ target_dtype = self.patch_embedding.weight.dtype
94
+ patch_embeds = self.patch_embedding(
95
+ pixel_values
96
+ ) # shape = [*, width, grid, grid]
97
+ patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
98
+ class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
99
+ embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
100
+ embeddings = embeddings + self.position_embedding.to(target_dtype)
101
+ return embeddings
102
+
103
+
104
+ class InternAttention(nn.Module):
105
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
106
+
107
+ def __init__(self, config: InternVisionConfig):
108
+ super().__init__()
109
+ self.config = config
110
+ self.embed_dim = config.hidden_size
111
+ self.num_heads = config.num_attention_heads
112
+ self.use_flash_attn = config.use_flash_attn and has_flash_attn
113
+ if config.use_flash_attn and not has_flash_attn:
114
+ print(
115
+ "Warning: Flash Attention is not available, use_flash_attn is set to False."
116
+ )
117
+ self.head_dim = self.embed_dim // self.num_heads
118
+ if self.head_dim * self.num_heads != self.embed_dim:
119
+ raise ValueError(
120
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:"
121
+ f" {self.num_heads})."
122
+ )
123
+
124
+ self.scale = self.head_dim**-0.5
125
+ self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
126
+ self.attn_drop = nn.Dropout(config.attention_dropout)
127
+ self.proj_drop = nn.Dropout(config.dropout)
128
+
129
+ self.qk_normalization = config.qk_normalization
130
+
131
+ if self.qk_normalization:
132
+ self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
133
+ self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
134
+
135
+ if self.use_flash_attn:
136
+ self.inner_attn = attention(attention_dropout=config.attention_dropout)
137
+ self.proj = nn.Linear(self.embed_dim, self.embed_dim)
138
+
139
+ def _naive_attn(self, x):
140
+ B, N, C = x.shape
141
+ qkv = (
142
+ self.qkv(x)
143
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
144
+ .permute(2, 0, 3, 1, 4)
145
+ )
146
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
147
+
148
+ if self.qk_normalization:
149
+ B_, H_, N_, D_ = q.shape
150
+ q = (
151
+ self.q_norm(q.transpose(1, 2).flatten(-2, -1))
152
+ .view(B_, N_, H_, D_)
153
+ .transpose(1, 2)
154
+ )
155
+ k = (
156
+ self.k_norm(k.transpose(1, 2).flatten(-2, -1))
157
+ .view(B_, N_, H_, D_)
158
+ .transpose(1, 2)
159
+ )
160
+
161
+ attn = (q * self.scale) @ k.transpose(-2, -1)
162
+ attn = attn.softmax(dim=-1)
163
+ attn = self.attn_drop(attn)
164
+
165
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
166
+ x = self.proj(x)
167
+ x = self.proj_drop(x)
168
+ return x
169
+
170
+ def _flash_attn(self, x, key_padding_mask=None, need_weights=False):
171
+ qkv = self.qkv(x)
172
+ qkv = rearrange(
173
+ qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads
174
+ )
175
+
176
+ if self.qk_normalization:
177
+ q, k, v = qkv.unbind(2)
178
+ q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
179
+ k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
180
+ qkv = torch.stack([q, k, v], dim=2)
181
+
182
+ context, _ = self.inner_attn(
183
+ qkv,
184
+ key_padding_mask=key_padding_mask,
185
+ need_weights=need_weights,
186
+ causal=False,
187
+ )
188
+ outs = self.proj(rearrange(context, "b s h d -> b s (h d)"))
189
+ outs = self.proj_drop(outs)
190
+ return outs
191
+
192
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
193
+ x = (
194
+ self._naive_attn(hidden_states)
195
+ if not self.use_flash_attn
196
+ else self._flash_attn(hidden_states)
197
+ )
198
+ return x
199
+
200
+
201
+ class InternMLP(nn.Module):
202
+ def __init__(self, config: InternVisionConfig):
203
+ super().__init__()
204
+ self.config = config
205
+ self.act = ACT2FN[config.hidden_act]
206
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
207
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
208
+
209
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
210
+ hidden_states = self.fc1(hidden_states)
211
+ hidden_states = self.act(hidden_states)
212
+ hidden_states = self.fc2(hidden_states)
213
+ return hidden_states
214
+
215
+
216
+ class InternVisionEncoderLayer(nn.Module):
217
+ def __init__(self, config: InternVisionConfig, drop_path_rate: float):
218
+ super().__init__()
219
+ self.embed_dim = config.hidden_size
220
+ self.intermediate_size = config.intermediate_size
221
+
222
+ self.attn = InternAttention(config)
223
+ self.mlp = InternMLP(config)
224
+ self.norm1 = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
225
+ self.norm2 = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
226
+
227
+ self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
228
+ self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
229
+ self.drop_path1 = (
230
+ DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
231
+ )
232
+ self.drop_path2 = (
233
+ DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
234
+ )
235
+
236
+ def forward(
237
+ self,
238
+ hidden_states: torch.Tensor,
239
+ ) -> Tuple[
240
+ torch.FloatTensor,
241
+ Optional[torch.FloatTensor],
242
+ Optional[Tuple[torch.FloatTensor]],
243
+ ]:
244
+ """
245
+ Args:
246
+ hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)`
247
+ """
248
+ hidden_states = hidden_states + self.drop_path1(
249
+ self.attn(self.norm1(hidden_states)) * self.ls1
250
+ )
251
+
252
+ hidden_states = hidden_states + self.drop_path2(
253
+ self.mlp(self.norm2(hidden_states)) * self.ls2
254
+ )
255
+
256
+ return hidden_states
257
+
258
+
259
+ class InternVisionEncoder(nn.Module):
260
+ """
261
+ Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
262
+ [`InternEncoderLayer`].
263
+
264
+ Args:
265
+ config (`InternConfig`):
266
+ The corresponding vision configuration for the `InternEncoder`.
267
+ """
268
+
269
+ def __init__(self, config: InternVisionConfig):
270
+ super().__init__()
271
+ self.config = config
272
+ # stochastic depth decay rule
273
+ dpr = [
274
+ x.item()
275
+ for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)
276
+ ]
277
+ self.layers = nn.ModuleList(
278
+ [
279
+ InternVisionEncoderLayer(config, dpr[idx])
280
+ for idx in range(config.num_hidden_layers)
281
+ ]
282
+ )
283
+ self.gradient_checkpointing = True
284
+
285
+ def forward(
286
+ self,
287
+ inputs_embeds,
288
+ output_hidden_states: Optional[bool] = None,
289
+ return_dict: Optional[bool] = None,
290
+ ) -> Union[Tuple, BaseModelOutput]:
291
+ r"""
292
+ Args:
293
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
294
+ Embedded representation of the inputs. Should be float, not int tokens.
295
+ output_hidden_states (`bool`, *optional*):
296
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
297
+ for more detail.
298
+ return_dict (`bool`, *optional*):
299
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
300
+ """
301
+ output_hidden_states = (
302
+ output_hidden_states
303
+ if output_hidden_states is not None
304
+ else self.config.output_hidden_states
305
+ )
306
+ return_dict = (
307
+ return_dict if return_dict is not None else self.config.use_return_dict
308
+ )
309
+
310
+ encoder_states = () if output_hidden_states else None
311
+ hidden_states = inputs_embeds
312
+
313
+ for idx, encoder_layer in enumerate(self.layers):
314
+ if output_hidden_states:
315
+ encoder_states = encoder_states + (hidden_states,)
316
+ if self.gradient_checkpointing and self.training:
317
+ layer_outputs = torch.utils.checkpoint.checkpoint(
318
+ encoder_layer, hidden_states
319
+ )
320
+ else:
321
+ layer_outputs = encoder_layer(
322
+ hidden_states,
323
+ )
324
+ hidden_states = layer_outputs
325
+
326
+ if output_hidden_states:
327
+ encoder_states = encoder_states + (hidden_states,)
328
+
329
+ if not return_dict:
330
+ return tuple(v for v in [hidden_states, encoder_states] if v is not None)
331
+ return BaseModelOutput(
332
+ last_hidden_state=hidden_states, hidden_states=encoder_states
333
+ )
334
+
335
+
336
+ class InternVisionModel(PreTrainedModel):
337
+ main_input_name = "pixel_values"
338
+ config_class = InternVisionConfig
339
+ _no_split_modules = ["InternVisionEncoderLayer"]
340
+
341
+ def __init__(self, config: InternVisionConfig):
342
+ super().__init__(config)
343
+ self.config = config
344
+
345
+ self.embeddings = InternVisionEmbeddings(config)
346
+ self.encoder = InternVisionEncoder(config)
347
+
348
+ def resize_pos_embeddings(self, old_size, new_size, patch_size):
349
+ pos_emb = self.embeddings.position_embedding
350
+ _, num_positions, embed_dim = pos_emb.shape
351
+ cls_emb = pos_emb[:, :1, :]
352
+ pos_emb = (
353
+ pos_emb[:, 1:, :]
354
+ .reshape(1, old_size // patch_size, old_size // patch_size, -1)
355
+ .permute(0, 3, 1, 2)
356
+ )
357
+ pos_emb = F.interpolate(
358
+ pos_emb.float(),
359
+ size=new_size // patch_size,
360
+ mode="bicubic",
361
+ align_corners=False,
362
+ )
363
+ pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1)
364
+ pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
365
+ self.embeddings.position_embedding = nn.Parameter(pos_emb)
366
+ logger.info(
367
+ "Resized position embeddings from {} to {}".format(old_size, new_size)
368
+ )
369
+
370
+ def get_input_embeddings(self):
371
+ return self.embeddings
372
+
373
+ def forward(
374
+ self,
375
+ pixel_values: Optional[torch.FloatTensor] = None,
376
+ output_hidden_states: Optional[bool] = None,
377
+ return_dict: Optional[bool] = None,
378
+ pixel_embeds: Optional[torch.FloatTensor] = None,
379
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
380
+ output_hidden_states = (
381
+ output_hidden_states
382
+ if output_hidden_states is not None
383
+ else self.config.output_hidden_states
384
+ )
385
+ return_dict = (
386
+ return_dict if return_dict is not None else self.config.use_return_dict
387
+ )
388
+
389
+ if pixel_values is None and pixel_embeds is None:
390
+ raise ValueError("You have to specify pixel_values or pixel_embeds")
391
+
392
+ if pixel_embeds is not None:
393
+ hidden_states = pixel_embeds
394
+ else:
395
+ if len(pixel_values.shape) == 4:
396
+ hidden_states = self.embeddings(pixel_values)
397
+ else:
398
+ raise ValueError(f"wrong pixel_values size: {pixel_values.shape}")
399
+ encoder_outputs = self.encoder(
400
+ inputs_embeds=hidden_states,
401
+ output_hidden_states=output_hidden_states,
402
+ return_dict=return_dict,
403
+ )
404
+ last_hidden_state = encoder_outputs.last_hidden_state
405
+ pooled_output = last_hidden_state[:, 0, :]
406
+
407
+ if not return_dict:
408
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
409
+
410
+ return BaseModelOutputWithPooling(
411
+ last_hidden_state=last_hidden_state,
412
+ pooler_output=pooled_output,
413
+ hidden_states=encoder_outputs.hidden_states,
414
+ attentions=encoder_outputs.attentions,
415
+ )
preprocessor_config.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "crop_size": 448,
3
+ "do_center_crop": true,
4
+ "do_normalize": true,
5
+ "do_resize": true,
6
+ "feature_extractor_type": "CLIPFeatureExtractor",
7
+ "image_mean": [
8
+ 0.485,
9
+ 0.456,
10
+ 0.406
11
+ ],
12
+ "image_std": [
13
+ 0.229,
14
+ 0.224,
15
+ 0.225
16
+ ],
17
+ "resample": 3,
18
+ "size": 448
19
+ }
triton_bert_padding.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from einops import rearrange, repeat
6
+
7
+
8
+ class IndexFirstAxis(torch.autograd.Function):
9
+ @staticmethod
10
+ def forward(ctx, input, indices):
11
+ ctx.save_for_backward(indices)
12
+ assert input.ndim >= 2
13
+ ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
14
+ second_dim = other_shape.numel()
15
+ # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
16
+ # return input[indices]
17
+ return torch.gather(
18
+ rearrange(input, "b ... -> b (...)"),
19
+ 0,
20
+ repeat(indices, "z -> z d", d=second_dim),
21
+ ).reshape(-1, *other_shape)
22
+
23
+ @staticmethod
24
+ def backward(ctx, grad_output):
25
+ (indices,) = ctx.saved_tensors
26
+ assert grad_output.ndim >= 2
27
+ other_shape = grad_output.shape[1:]
28
+ grad_output = rearrange(grad_output, "b ... -> b (...)")
29
+ grad_input = torch.zeros(
30
+ [ctx.first_axis_dim, grad_output.shape[1]],
31
+ device=grad_output.device,
32
+ dtype=grad_output.dtype,
33
+ )
34
+ # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
35
+ # grad_input[indices] = grad_output
36
+ grad_input.scatter_(
37
+ 0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output
38
+ )
39
+ return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
40
+
41
+
42
+ index_first_axis = IndexFirstAxis.apply
43
+
44
+
45
+ class IndexPutFirstAxis(torch.autograd.Function):
46
+ @staticmethod
47
+ def forward(ctx, values, indices, first_axis_dim):
48
+ ctx.save_for_backward(indices)
49
+ assert indices.ndim == 1
50
+ assert values.ndim >= 2
51
+ output = torch.zeros(
52
+ first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype
53
+ )
54
+ # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
55
+ output[indices] = values
56
+ # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
57
+ return output
58
+
59
+ @staticmethod
60
+ def backward(ctx, grad_output):
61
+ (indices,) = ctx.saved_tensors
62
+ # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
63
+ grad_values = grad_output[indices]
64
+ # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
65
+ return grad_values, None, None
66
+
67
+
68
+ index_put_first_axis = IndexPutFirstAxis.apply
69
+
70
+
71
+ class IndexFirstAxisResidual(torch.autograd.Function):
72
+ @staticmethod
73
+ def forward(ctx, input, indices):
74
+ ctx.save_for_backward(indices)
75
+ assert input.ndim >= 2
76
+ ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
77
+ second_dim = other_shape.numel()
78
+ # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
79
+ output = input[indices]
80
+ # We don't want to reshape input (b ... -> b (...)) since it could change the channel_last
81
+ # memory format to channel_first. In other words, input might not be contiguous.
82
+ # If we don't detach, Pytorch complains about output being a view and is being modified inplace
83
+ return output, input.detach()
84
+
85
+ @staticmethod
86
+ def backward(ctx, grad_output, grad_residual):
87
+ (indices,) = ctx.saved_tensors
88
+ assert grad_output.ndim >= 2
89
+ other_shape = grad_output.shape[1:]
90
+ assert grad_residual.shape[1:] == other_shape
91
+ grad_input = grad_residual
92
+ # grad_input[indices] += grad_output
93
+ indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1)))
94
+ indices = indices.expand_as(grad_output)
95
+ grad_input.scatter_add_(0, indices, grad_output)
96
+ return grad_input.reshape(ctx.first_axis_dim, *other_shape), None
97
+
98
+
99
+ index_first_axis_residual = IndexFirstAxisResidual.apply
100
+
101
+
102
+ def unpad_input(hidden_states, attention_mask):
103
+ """
104
+ Arguments:
105
+ hidden_states: (batch, seqlen, ...)
106
+ attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
107
+ Return:
108
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
109
+ indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
110
+ cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
111
+ max_seqlen_in_batch: int
112
+ """
113
+ seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
114
+ indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
115
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
116
+ cu_seqlens = F.pad(
117
+ torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
118
+ )
119
+ # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
120
+ # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
121
+ # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
122
+ # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
123
+ # so we write custom forward and backward to make it a bit faster.
124
+ return (
125
+ index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
126
+ indices,
127
+ cu_seqlens,
128
+ max_seqlen_in_batch,
129
+ )
130
+
131
+
132
+ def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length):
133
+ """
134
+ Supports concatenating short samples in one sequence. The attention_mask_in_length is utilized to mask other short samples. It helps efficient training of variant lengths-based samples (e.g., the supervised fine-tuning task in large language model).
135
+ The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286).
136
+
137
+ For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is:
138
+ ```
139
+ [
140
+ [2, 3, 0, 0, 0, 0],
141
+ [3, 2, 0, 0, 0, 0],
142
+ [6, 0, 0, 0, 0, 0]
143
+ ]
144
+ ```
145
+ , which refers to the 3D-attention mask:
146
+ ```
147
+ [
148
+ [
149
+ [1, 0, 0, 0, 0, 0],
150
+ [1, 1, 0, 0, 0, 0],
151
+ [0, 0, 1, 0, 0, 0],
152
+ [0, 0, 1, 1, 0, 0],
153
+ [0, 0, 1, 1, 1, 0],
154
+ [0, 0, 0, 0, 0, 1]
155
+ ],
156
+ [
157
+ [1, 0, 0, 0, 0, 0],
158
+ [1, 1, 0, 0, 0, 0],
159
+ [1, 1, 1, 0, 0, 0],
160
+ [0, 0, 0, 1, 0, 0],
161
+ [0, 0, 0, 1, 1, 0],
162
+ [0, 0, 0, 0, 0, 1]
163
+ ],
164
+ [
165
+ [1, 0, 0, 0, 0, 0],
166
+ [1, 1, 0, 0, 0, 0],
167
+ [1, 1, 1, 0, 0, 0],
168
+ [1, 1, 1, 1, 0, 0],
169
+ [1, 1, 1, 1, 1, 0],
170
+ [1, 1, 1, 1, 1, 1]
171
+ ]
172
+ ]
173
+ ```.
174
+
175
+ Arguments:
176
+ hidden_states: (batch, seqlen, ...)
177
+ attention_mask_in_length: (batch, seqlen), int, a nonzero number (e.g., 1, 2, 3, etc.) means length of concatenated sequence in b-th batch, and 0 means none.
178
+ Return:
179
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
180
+ indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
181
+ cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
182
+ max_seqlen_in_batch: int
183
+ """
184
+ length = attention_mask_in_length.sum(dim=-1)
185
+ seqlen = attention_mask_in_length.size(-1)
186
+ attention_mask_2d = torch.arange(
187
+ seqlen, device=length.device, dtype=length.dtype
188
+ ).expand(len(length), seqlen) < length.unsqueeze(1)
189
+ real_indices_idx = torch.nonzero(
190
+ attention_mask_in_length.flatten(), as_tuple=False
191
+ ).flatten()
192
+ seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
193
+ indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
194
+ max_seqlen_in_batch = seqlens_in_batch.max().item()
195
+ cu_seqlens = F.pad(
196
+ torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)
197
+ )
198
+ # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
199
+ # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
200
+ # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
201
+ # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
202
+ # so we write custom forward and backward to make it a bit faster.
203
+ return (
204
+ index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
205
+ indices,
206
+ cu_seqlens,
207
+ max_seqlen_in_batch,
208
+ )
209
+
210
+
211
+ def pad_input(hidden_states, indices, batch, seqlen):
212
+ """
213
+ Arguments:
214
+ hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask.
215
+ indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
216
+ batch: int, batch size for the padded sequence.
217
+ seqlen: int, maximum sequence length for the padded sequence.
218
+ Return:
219
+ hidden_states: (batch, seqlen, ...)
220
+ """
221
+ dim = hidden_states.shape[-1]
222
+ # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
223
+ # output[indices] = hidden_states
224
+ output = index_put_first_axis(hidden_states, indices, batch * seqlen)
225
+ return rearrange(output, "(b s) ... -> b s ...", b=batch)
triton_flash_attn.py ADDED
@@ -0,0 +1,964 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Fused Attention
3
+ ===============
4
+
5
+ This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf)
6
+ Credits: OpenAI kernel team
7
+
8
+ Extra Credits:
9
+ - Original flash attention paper (https://arxiv.org/abs/2205.14135)
10
+ - Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf)
11
+
12
+ """
13
+
14
+ import pytest
15
+ import torch
16
+
17
+ import triton
18
+ import triton.language as tl
19
+
20
+ # Pick the fp8 data type
21
+
22
+ # AMD E4M3B8
23
+ # Note: When picking this f8 data type, scaling is required when using f8
24
+ # for the second gemm
25
+ # TORCH_HAS_FP8E4B8 = hasattr(torch, 'float8_e4m3fnuz')
26
+
27
+ # AMD E5M2B16
28
+ TORCH_HAS_FP8E5B16 = hasattr(torch, "float8_e5m2fnuz")
29
+
30
+
31
+ @triton.jit
32
+ def _attn_fwd_inner(
33
+ acc,
34
+ l_i,
35
+ m_i,
36
+ q,
37
+ K_block_ptr,
38
+ V_block_ptr,
39
+ start_m,
40
+ BLOCK_M: tl.constexpr,
41
+ BLOCK_DMODEL: tl.constexpr,
42
+ BLOCK_N: tl.constexpr,
43
+ STAGE: tl.constexpr,
44
+ offs_m: tl.constexpr,
45
+ offs_n: tl.constexpr,
46
+ N_CTX,
47
+ pre_load_v: tl.constexpr,
48
+ ):
49
+ # range of values handled by this stage
50
+ if STAGE == 1:
51
+ lo, hi = 0, start_m * BLOCK_M
52
+ elif STAGE == 2:
53
+ lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
54
+ lo = tl.multiple_of(lo, BLOCK_M)
55
+ K_block_ptr = tl.advance(K_block_ptr, (0, lo))
56
+ V_block_ptr = tl.advance(V_block_ptr, (lo, 0))
57
+ # causal = False
58
+ else:
59
+ lo, hi = 0, N_CTX
60
+ # loop over k, v and update accumulator
61
+ for start_n in range(lo, hi, BLOCK_N):
62
+ start_n = tl.multiple_of(start_n, BLOCK_N)
63
+ # -- compute qk ----
64
+ k = tl.load(K_block_ptr)
65
+ if pre_load_v:
66
+ v = tl.load(V_block_ptr)
67
+ qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
68
+ if STAGE == 2:
69
+ mask = offs_m[:, None] >= (start_n + offs_n[None, :])
70
+ qk = tl.where(mask, qk, float("-inf"))
71
+ qk += tl.dot(q, k)
72
+ m_ij = tl.maximum(m_i, tl.max(qk, 1))
73
+ qk = qk - m_ij[:, None]
74
+ p = tl.math.exp2(qk)
75
+ # -- update output accumulator --
76
+ alpha = tl.math.exp2(m_i - m_ij)
77
+ acc = acc * alpha[:, None]
78
+ if not pre_load_v:
79
+ v = tl.load(V_block_ptr)
80
+ acc += tl.dot(p.to(v.dtype), v)
81
+ # -- update m_i and l_i
82
+ l_ij = tl.sum(p, 1)
83
+ l_i = l_i * alpha + l_ij
84
+ # update m_i and l_i
85
+ m_i = m_ij
86
+ V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))
87
+ K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
88
+ return acc, l_i, m_i
89
+
90
+
91
+ # We don't run auto-tuning everytime to keep the tutorial fast. Uncommenting
92
+ # the code below and commenting out the equivalent parameters is convenient for
93
+ # re-tuning.
94
+ @triton.autotune(
95
+ configs=[
96
+ triton.Config(
97
+ {
98
+ "BLOCK_M": 64,
99
+ "BLOCK_N": 16,
100
+ "waves_per_eu": 2,
101
+ "slice_k_tile": 0,
102
+ "pre_load_v": False,
103
+ },
104
+ num_stages=1,
105
+ num_warps=2,
106
+ ),
107
+ triton.Config(
108
+ {
109
+ "BLOCK_M": 64,
110
+ "BLOCK_N": 16,
111
+ "waves_per_eu": 2,
112
+ "slice_k_tile": 32,
113
+ "pre_load_v": False,
114
+ },
115
+ num_stages=1,
116
+ num_warps=2,
117
+ ),
118
+ triton.Config(
119
+ {
120
+ "BLOCK_M": 32,
121
+ "BLOCK_N": 32,
122
+ "waves_per_eu": 2,
123
+ "slice_k_tile": 0,
124
+ "pre_load_v": False,
125
+ },
126
+ num_stages=1,
127
+ num_warps=1,
128
+ ),
129
+ triton.Config(
130
+ {
131
+ "BLOCK_M": 32,
132
+ "BLOCK_N": 32,
133
+ "waves_per_eu": 2,
134
+ "slice_k_tile": 32,
135
+ "pre_load_v": False,
136
+ },
137
+ num_stages=1,
138
+ num_warps=1,
139
+ ),
140
+ triton.Config(
141
+ {
142
+ "BLOCK_M": 64,
143
+ "BLOCK_N": 32,
144
+ "waves_per_eu": 2,
145
+ "slice_k_tile": 0,
146
+ "pre_load_v": False,
147
+ },
148
+ num_stages=1,
149
+ num_warps=2,
150
+ ),
151
+ triton.Config(
152
+ {
153
+ "BLOCK_M": 32,
154
+ "BLOCK_N": 16,
155
+ "waves_per_eu": 3,
156
+ "slice_k_tile": 0,
157
+ "pre_load_v": True,
158
+ },
159
+ num_stages=1,
160
+ num_warps=1,
161
+ ),
162
+ triton.Config(
163
+ {
164
+ "BLOCK_M": 32,
165
+ "BLOCK_N": 16,
166
+ "waves_per_eu": 3,
167
+ "slice_k_tile": 0,
168
+ "pre_load_v": False,
169
+ },
170
+ num_stages=1,
171
+ num_warps=1,
172
+ ),
173
+ ],
174
+ key=["Z", "H", "N_CTX", "STAGE", "BLOCK_DMODEL"],
175
+ )
176
+ @triton.jit
177
+ def _attn_fwd(
178
+ Q,
179
+ K,
180
+ V,
181
+ sm_scale,
182
+ M,
183
+ Out,
184
+ stride_qz,
185
+ stride_qh,
186
+ stride_qm,
187
+ stride_qk,
188
+ stride_kz,
189
+ stride_kh,
190
+ stride_kn,
191
+ stride_kk,
192
+ stride_vz,
193
+ stride_vh,
194
+ stride_vk,
195
+ stride_vn,
196
+ stride_oz,
197
+ stride_oh,
198
+ stride_om,
199
+ stride_on,
200
+ Z,
201
+ H,
202
+ N_CTX,
203
+ BLOCK_DMODEL: tl.constexpr,
204
+ STAGE: tl.constexpr,
205
+ BLOCK_M: tl.constexpr,
206
+ BLOCK_N: tl.constexpr,
207
+ pre_load_v: tl.constexpr,
208
+ ):
209
+ start_m = tl.program_id(0)
210
+ off_hz = tl.program_id(1)
211
+ qvk_offset = off_hz * stride_qh
212
+
213
+ # block pointers
214
+ Q_block_ptr = tl.make_block_ptr(
215
+ base=Q + qvk_offset,
216
+ shape=(N_CTX, BLOCK_DMODEL),
217
+ strides=(stride_qm, stride_qk),
218
+ offsets=(start_m * BLOCK_M, 0),
219
+ block_shape=(BLOCK_M, BLOCK_DMODEL),
220
+ order=(1, 0),
221
+ )
222
+ V_block_ptr = tl.make_block_ptr(
223
+ base=V + qvk_offset,
224
+ shape=(N_CTX, BLOCK_DMODEL),
225
+ strides=(stride_vk, stride_vn),
226
+ offsets=(0, 0),
227
+ block_shape=(BLOCK_N, BLOCK_DMODEL),
228
+ order=(1, 0),
229
+ )
230
+ K_block_ptr = tl.make_block_ptr(
231
+ base=K + qvk_offset,
232
+ shape=(BLOCK_DMODEL, N_CTX),
233
+ strides=(stride_kk, stride_kn),
234
+ offsets=(0, 0),
235
+ block_shape=(BLOCK_DMODEL, BLOCK_N),
236
+ order=(0, 1),
237
+ )
238
+ O_block_ptr = tl.make_block_ptr(
239
+ base=Out + qvk_offset,
240
+ shape=(N_CTX, BLOCK_DMODEL),
241
+ strides=(stride_om, stride_on),
242
+ offsets=(start_m * BLOCK_M, 0),
243
+ block_shape=(BLOCK_M, BLOCK_DMODEL),
244
+ order=(1, 0),
245
+ )
246
+ # initialize offsets
247
+ offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
248
+ offs_n = tl.arange(0, BLOCK_N)
249
+ # initialize pointer to m and l
250
+ m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
251
+ l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
252
+ acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
253
+ # scale sm_scale by log_2(e) and use
254
+ # 2^x instead of exp in the loop because CSE and LICM
255
+ # don't work as expected with `exp` in the loop
256
+ qk_scale = sm_scale * 1.44269504
257
+ # load q: it will stay in SRAM throughout on NV GPUs but in VGPRs on AMD GPUs
258
+ q = tl.load(Q_block_ptr)
259
+ q = (q * qk_scale).to(q.dtype)
260
+ # stage 1: off-band
261
+ # For causal = True, STAGE = 3 and _attn_fwd_inner gets 1 as its STAGE
262
+ # For causal = False, STAGE = 1, and _attn_fwd_inner gets 3 as its STAGE
263
+ if STAGE & 1:
264
+ acc, l_i, m_i = _attn_fwd_inner(
265
+ acc,
266
+ l_i,
267
+ m_i,
268
+ q,
269
+ K_block_ptr,
270
+ V_block_ptr,
271
+ start_m,
272
+ BLOCK_M,
273
+ BLOCK_DMODEL,
274
+ BLOCK_N,
275
+ 4 - STAGE,
276
+ offs_m,
277
+ offs_n,
278
+ N_CTX,
279
+ pre_load_v,
280
+ )
281
+ # stage 2: on-band
282
+ if STAGE & 2:
283
+ # barrier makes it easier for compielr to schedule the
284
+ # two loops independently
285
+ tl.debug_barrier()
286
+ acc, l_i, m_i = _attn_fwd_inner(
287
+ acc,
288
+ l_i,
289
+ m_i,
290
+ q,
291
+ K_block_ptr,
292
+ V_block_ptr,
293
+ start_m,
294
+ BLOCK_M,
295
+ BLOCK_DMODEL,
296
+ BLOCK_N,
297
+ 2,
298
+ offs_m,
299
+ offs_n,
300
+ N_CTX,
301
+ pre_load_v,
302
+ )
303
+ # epilogue
304
+ # write back m
305
+ acc = acc / l_i[:, None]
306
+ m_ptrs = M + off_hz * N_CTX + offs_m
307
+ tl.store(m_ptrs, m_i + tl.math.log2(l_i))
308
+ tl.store(O_block_ptr, acc.to(Out.type.element_ty))
309
+
310
+
311
+ @triton.jit
312
+ def _attn_bwd_preprocess(
313
+ O, DO, Delta, Z, H, N_CTX, BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr
314
+ ):
315
+ off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
316
+ off_hz = tl.program_id(1)
317
+ off_n = tl.arange(0, D_HEAD)
318
+ o = tl.load(O + off_hz * D_HEAD * N_CTX + off_m[:, None] * D_HEAD + off_n[None, :])
319
+ do = tl.load(
320
+ DO + off_hz * D_HEAD * N_CTX + off_m[:, None] * D_HEAD + off_n[None, :]
321
+ ).to(tl.float32)
322
+ delta = tl.sum(o * do, axis=1)
323
+ tl.store(Delta + off_hz * N_CTX + off_m, delta)
324
+
325
+
326
+ # The main inner-loop logic for computing dK and dV.
327
+ @triton.jit
328
+ def _attn_bwd_dkdv(
329
+ dk,
330
+ dv,
331
+ Q,
332
+ k,
333
+ v,
334
+ sm_scale,
335
+ DO,
336
+ M,
337
+ D,
338
+ # shared by Q/K/V/DO.
339
+ stride_tok,
340
+ stride_d,
341
+ H,
342
+ N_CTX,
343
+ BLOCK_M1: tl.constexpr,
344
+ BLOCK_N1: tl.constexpr,
345
+ BLOCK_DMODEL: tl.constexpr,
346
+ # Filled in by the wrapper.
347
+ start_n,
348
+ start_m,
349
+ num_steps,
350
+ MASK: tl.constexpr,
351
+ ):
352
+ offs_m = start_m + tl.arange(0, BLOCK_M1)
353
+ offs_n = start_n + tl.arange(0, BLOCK_N1)
354
+ offs_k = tl.arange(0, BLOCK_DMODEL)
355
+ QT_block_ptr = tl.make_block_ptr(
356
+ base=Q,
357
+ shape=(BLOCK_DMODEL, N_CTX),
358
+ strides=(stride_d, stride_tok),
359
+ offsets=(0, start_m),
360
+ block_shape=(BLOCK_DMODEL, BLOCK_M1),
361
+ order=(0, 1),
362
+ )
363
+ DO_block_ptr = tl.make_block_ptr(
364
+ base=DO,
365
+ shape=(N_CTX, BLOCK_DMODEL),
366
+ strides=(stride_tok, stride_d),
367
+ offsets=(start_m, 0),
368
+ block_shape=(BLOCK_M1, BLOCK_DMODEL),
369
+ order=(1, 0),
370
+ )
371
+ # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
372
+ tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
373
+ curr_m = start_m
374
+ step_m = BLOCK_M1
375
+ for blk_idx in range(num_steps):
376
+ qT = tl.load(QT_block_ptr)
377
+ # Load m before computing qk to reduce pipeline stall.
378
+ offs_m = curr_m + tl.arange(0, BLOCK_M1)
379
+ m = tl.load(M + offs_m)
380
+ qkT = tl.dot(k, qT)
381
+ pT = tl.math.exp2(qkT - m[None, :])
382
+ # Autoregressive masking.
383
+ if MASK:
384
+ mask = offs_m[None, :] >= offs_n[:, None]
385
+ pT = tl.where(mask, pT, 0.0)
386
+ do = tl.load(DO_block_ptr)
387
+ # Compute dV.
388
+ ppT = pT
389
+ ppT = ppT.to(tl.float16)
390
+ dv += tl.dot(ppT, do)
391
+ # D (= delta) is pre-divided by ds_scale.
392
+ Di = tl.load(D + offs_m)
393
+ # Compute dP and dS.
394
+ dpT = tl.dot(v, tl.trans(do))
395
+ dsT = pT * (dpT - Di[None, :])
396
+ dsT = dsT.to(tl.float16)
397
+ dk += tl.dot(dsT, tl.trans(qT))
398
+ # Increment pointers.
399
+ curr_m += step_m
400
+ QT_block_ptr = tl.advance(QT_block_ptr, (0, step_m))
401
+ DO_block_ptr = tl.advance(DO_block_ptr, (step_m, 0))
402
+ return dk, dv
403
+
404
+
405
+ # the main inner-loop logic for computing dQ
406
+ @triton.jit
407
+ def _attn_bwd_dq(
408
+ dq,
409
+ q,
410
+ K,
411
+ V,
412
+ do,
413
+ m,
414
+ D,
415
+ # shared by Q/K/V/DO.
416
+ stride_tok,
417
+ stride_d,
418
+ H,
419
+ N_CTX,
420
+ BLOCK_M2: tl.constexpr,
421
+ BLOCK_N2: tl.constexpr,
422
+ BLOCK_DMODEL: tl.constexpr,
423
+ # Filled in by the wrapper.
424
+ start_m,
425
+ start_n,
426
+ num_steps,
427
+ MASK: tl.constexpr,
428
+ ):
429
+ offs_m = start_m + tl.arange(0, BLOCK_M2)
430
+ offs_n = start_n + tl.arange(0, BLOCK_N2)
431
+ offs_k = tl.arange(0, BLOCK_DMODEL)
432
+ KT_block_ptr = tl.make_block_ptr(
433
+ base=K,
434
+ shape=(BLOCK_DMODEL, N_CTX),
435
+ strides=(stride_d, stride_tok),
436
+ offsets=(0, start_n),
437
+ block_shape=(BLOCK_DMODEL, BLOCK_N2),
438
+ order=(0, 1),
439
+ )
440
+ VT_block_ptr = tl.make_block_ptr(
441
+ base=V,
442
+ shape=(BLOCK_DMODEL, N_CTX),
443
+ strides=(stride_d, stride_tok),
444
+ offsets=(0, start_n),
445
+ block_shape=(BLOCK_DMODEL, BLOCK_N2),
446
+ order=(0, 1),
447
+ )
448
+ # D (= delta) is pre-divided by ds_scale.
449
+ Di = tl.load(D + offs_m)
450
+ # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
451
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
452
+ curr_n = start_n
453
+ step_n = BLOCK_N2
454
+ for blk_idx in range(num_steps):
455
+ kT = tl.load(KT_block_ptr)
456
+ qk = tl.dot(q, kT)
457
+ p = tl.math.exp2(qk - m)
458
+ # Autoregressive masking.
459
+ if MASK:
460
+ offs_n = curr_n + tl.arange(0, BLOCK_N2)
461
+ mask = offs_m[:, None] >= offs_n[None, :]
462
+ p = tl.where(mask, p, 0.0)
463
+ # Compute dP and dS.
464
+ vT = tl.load(VT_block_ptr)
465
+ dp = tl.dot(do, vT).to(tl.float32)
466
+ ds = p * (dp - Di[:, None])
467
+ ds = ds.to(tl.float16)
468
+ # Compute dQ.
469
+ # NOTE: We need to de-scale dq in the end, because kT was pre-scaled.
470
+ dq += tl.dot(ds, tl.trans(kT))
471
+ # Increment pointers.
472
+ curr_n += step_n
473
+ KT_block_ptr = tl.advance(KT_block_ptr, (0, step_n))
474
+ VT_block_ptr = tl.advance(VT_block_ptr, (0, step_n))
475
+ return dq
476
+
477
+
478
+ @triton.autotune(
479
+ configs=[
480
+ triton.Config(
481
+ {
482
+ "BLOCK_M1": 32,
483
+ "BLOCK_N1": 64,
484
+ "BLOCK_M2": 64,
485
+ "BLOCK_N2": 32,
486
+ "BLK_SLICE_FACTOR": 1,
487
+ },
488
+ num_stages=1,
489
+ num_warps=4,
490
+ ),
491
+ triton.Config(
492
+ {
493
+ "BLOCK_M1": 32,
494
+ "BLOCK_N1": 64,
495
+ "BLOCK_M2": 64,
496
+ "BLOCK_N2": 32,
497
+ "BLK_SLICE_FACTOR": 2,
498
+ },
499
+ num_stages=1,
500
+ num_warps=4,
501
+ ),
502
+ triton.Config(
503
+ {
504
+ "BLOCK_M1": 64,
505
+ "BLOCK_N1": 128,
506
+ "BLOCK_M2": 128,
507
+ "BLOCK_N2": 64,
508
+ "BLK_SLICE_FACTOR": 1,
509
+ },
510
+ num_stages=1,
511
+ num_warps=4,
512
+ ),
513
+ triton.Config(
514
+ {
515
+ "BLOCK_M1": 64,
516
+ "BLOCK_N1": 128,
517
+ "BLOCK_M2": 128,
518
+ "BLOCK_N2": 64,
519
+ "BLK_SLICE_FACTOR": 2,
520
+ },
521
+ num_stages=1,
522
+ num_warps=4,
523
+ ),
524
+ triton.Config(
525
+ {
526
+ "BLOCK_M1": 64,
527
+ "BLOCK_N1": 64,
528
+ "BLOCK_M2": 64,
529
+ "BLOCK_N2": 64,
530
+ "BLK_SLICE_FACTOR": 1,
531
+ },
532
+ num_stages=1,
533
+ num_warps=4,
534
+ ),
535
+ triton.Config(
536
+ {
537
+ "BLOCK_M1": 64,
538
+ "BLOCK_N1": 64,
539
+ "BLOCK_M2": 64,
540
+ "BLOCK_N2": 64,
541
+ "BLK_SLICE_FACTOR": 2,
542
+ },
543
+ num_stages=1,
544
+ num_warps=4,
545
+ ),
546
+ triton.Config(
547
+ {
548
+ "BLOCK_M1": 32,
549
+ "BLOCK_N1": 128,
550
+ "BLOCK_M2": 128,
551
+ "BLOCK_N2": 32,
552
+ "BLK_SLICE_FACTOR": 1,
553
+ },
554
+ num_stages=1,
555
+ num_warps=4,
556
+ ),
557
+ triton.Config(
558
+ {
559
+ "BLOCK_M1": 32,
560
+ "BLOCK_N1": 128,
561
+ "BLOCK_M2": 128,
562
+ "BLOCK_N2": 32,
563
+ "BLK_SLICE_FACTOR": 2,
564
+ },
565
+ num_stages=1,
566
+ num_warps=4,
567
+ ),
568
+ triton.Config(
569
+ {
570
+ "BLOCK_M1": 32,
571
+ "BLOCK_N1": 128,
572
+ "BLOCK_M2": 128,
573
+ "BLOCK_N2": 32,
574
+ "BLK_SLICE_FACTOR": 2,
575
+ },
576
+ num_stages=1,
577
+ num_warps=8,
578
+ ),
579
+ ],
580
+ key=["H", "N_CTX", "BLOCK_DMODEL"],
581
+ )
582
+ @triton.jit
583
+ def _attn_bwd(
584
+ Q,
585
+ K,
586
+ V,
587
+ sm_scale,
588
+ DO,
589
+ DQ,
590
+ DK,
591
+ DV,
592
+ M,
593
+ D,
594
+ # shared by Q/K/V/DO.
595
+ stride_z,
596
+ stride_h,
597
+ stride_tok,
598
+ stride_d,
599
+ # H = 16, N_CTX = 1024
600
+ H,
601
+ N_CTX,
602
+ BLOCK_DMODEL: tl.constexpr,
603
+ BLOCK_M1: tl.constexpr,
604
+ BLOCK_N1: tl.constexpr,
605
+ BLOCK_M2: tl.constexpr,
606
+ BLOCK_N2: tl.constexpr,
607
+ BLK_SLICE_FACTOR: tl.constexpr,
608
+ ):
609
+ LN2: tl.constexpr = 0.6931471824645996 # = ln(2)
610
+
611
+ bhid = tl.program_id(2)
612
+ off_chz = (bhid * N_CTX).to(tl.int64)
613
+ adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64)
614
+ pid = tl.program_id(0)
615
+
616
+ # offset pointers for batch/head
617
+ Q += adj
618
+ K += adj
619
+ V += adj
620
+ DO += adj
621
+ DQ += adj
622
+ DK += adj
623
+ DV += adj
624
+ M += off_chz
625
+ D += off_chz
626
+
627
+ offs_k = tl.arange(0, BLOCK_DMODEL)
628
+
629
+ start_n = pid * BLOCK_N1
630
+ # This assignment is important. It is what allows us to pick the diagonal
631
+ # blocks. Later, when we want to do the lower triangular, we update start_m
632
+ # after the first dkdv call.
633
+ start_m = start_n
634
+
635
+ MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR
636
+ offs_n = start_n + tl.arange(0, BLOCK_N1)
637
+
638
+ dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
639
+ dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32)
640
+
641
+ K_block_ptr = tl.make_block_ptr(
642
+ base=K,
643
+ shape=(N_CTX, BLOCK_DMODEL),
644
+ strides=(stride_tok, stride_d),
645
+ offsets=(start_n, 0),
646
+ block_shape=(BLOCK_N1, BLOCK_DMODEL),
647
+ order=(1, 0),
648
+ )
649
+ V_block_ptr = tl.make_block_ptr(
650
+ base=V,
651
+ shape=(N_CTX, BLOCK_DMODEL),
652
+ strides=(stride_tok, stride_d),
653
+ offsets=(start_n, 0),
654
+ block_shape=(BLOCK_N1, BLOCK_DMODEL),
655
+ order=(1, 0),
656
+ )
657
+
658
+ # load K and V: they stay in SRAM throughout the inner loop for dkdv.
659
+ k = tl.load(K_block_ptr)
660
+ v = tl.load(V_block_ptr)
661
+
662
+ num_steps = BLOCK_N1 // MASK_BLOCK_M1
663
+
664
+ dk, dv = _attn_bwd_dkdv(
665
+ dk,
666
+ dv,
667
+ Q,
668
+ k,
669
+ v,
670
+ sm_scale,
671
+ DO,
672
+ M,
673
+ D,
674
+ stride_tok,
675
+ stride_d,
676
+ H,
677
+ N_CTX,
678
+ MASK_BLOCK_M1,
679
+ BLOCK_N1,
680
+ BLOCK_DMODEL,
681
+ start_n,
682
+ start_m,
683
+ num_steps,
684
+ MASK=True,
685
+ )
686
+
687
+ start_m += num_steps * MASK_BLOCK_M1
688
+ num_steps = (N_CTX - start_m) // BLOCK_M1
689
+
690
+ # Compute dK and dV for non-masked blocks.
691
+ dk, dv = _attn_bwd_dkdv(
692
+ dk,
693
+ dv,
694
+ Q,
695
+ k,
696
+ v,
697
+ sm_scale,
698
+ DO,
699
+ M,
700
+ D,
701
+ stride_tok,
702
+ stride_d,
703
+ H,
704
+ N_CTX,
705
+ BLOCK_M1,
706
+ BLOCK_N1,
707
+ BLOCK_DMODEL,
708
+ start_n,
709
+ start_m,
710
+ num_steps,
711
+ MASK=False,
712
+ )
713
+
714
+ DV_block_ptrs = tl.make_block_ptr(
715
+ base=DV,
716
+ shape=(N_CTX, BLOCK_DMODEL),
717
+ strides=(stride_tok, stride_d),
718
+ offsets=(start_n, 0),
719
+ block_shape=(BLOCK_N1, BLOCK_DMODEL),
720
+ order=(1, 0),
721
+ )
722
+ tl.store(DV_block_ptrs, dv.to(tl.float16))
723
+
724
+ # Write back dK.
725
+ dk *= sm_scale
726
+ DK_block_ptrs = tl.make_block_ptr(
727
+ base=DK,
728
+ shape=(N_CTX, BLOCK_DMODEL),
729
+ strides=(stride_tok, stride_d),
730
+ offsets=(start_n, 0),
731
+ block_shape=(BLOCK_N1, BLOCK_DMODEL),
732
+ order=(1, 0),
733
+ )
734
+ tl.store(DK_block_ptrs, dk.to(tl.float16))
735
+
736
+ # THIS BLOCK DOES DQ:
737
+ start_m = pid * BLOCK_M2
738
+ end_n = start_m + BLOCK_M2
739
+
740
+ MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR
741
+ offs_m = start_m + tl.arange(0, BLOCK_M2)
742
+
743
+ Q_block_ptr = tl.make_block_ptr(
744
+ base=Q,
745
+ shape=(N_CTX, BLOCK_DMODEL),
746
+ strides=(stride_tok, stride_d),
747
+ offsets=(start_m, 0),
748
+ block_shape=(BLOCK_M2, BLOCK_DMODEL),
749
+ order=(1, 0),
750
+ )
751
+
752
+ DO_block_ptr = tl.make_block_ptr(
753
+ base=DO,
754
+ shape=(N_CTX, BLOCK_DMODEL),
755
+ strides=(stride_tok, stride_d),
756
+ offsets=(start_m, 0),
757
+ block_shape=(BLOCK_M2, BLOCK_DMODEL),
758
+ order=(1, 0),
759
+ )
760
+ q = tl.load(Q_block_ptr)
761
+ do = tl.load(DO_block_ptr)
762
+ dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32)
763
+
764
+ m = tl.load(M + offs_m)
765
+ m = m[:, None]
766
+
767
+ # Compute dQ for masked (diagonal) blocks.
768
+ # NOTE: This code scans each row of QK^T backward (from right to left,
769
+ # but inside each call to _attn_bwd_dq, from left to right), but that's
770
+ # not due to anything important. I just wanted to reuse the loop
771
+ # structure for dK & dV above as much as possible.
772
+ num_steps = BLOCK_M2 // MASK_BLOCK_N2
773
+ dq = _attn_bwd_dq(
774
+ dq,
775
+ q,
776
+ K,
777
+ V,
778
+ do,
779
+ m,
780
+ D,
781
+ stride_tok,
782
+ stride_d,
783
+ H,
784
+ N_CTX,
785
+ BLOCK_M2,
786
+ MASK_BLOCK_N2,
787
+ BLOCK_DMODEL,
788
+ start_m,
789
+ end_n - num_steps * MASK_BLOCK_N2,
790
+ num_steps,
791
+ MASK=True,
792
+ )
793
+ end_n -= num_steps * MASK_BLOCK_N2
794
+ # stage 2
795
+ num_steps = end_n // BLOCK_N2
796
+ dq = _attn_bwd_dq(
797
+ dq,
798
+ q,
799
+ K,
800
+ V,
801
+ do,
802
+ m,
803
+ D,
804
+ stride_tok,
805
+ stride_d,
806
+ H,
807
+ N_CTX,
808
+ BLOCK_M2,
809
+ BLOCK_N2,
810
+ BLOCK_DMODEL,
811
+ start_m,
812
+ end_n - num_steps * BLOCK_N2,
813
+ num_steps,
814
+ MASK=False,
815
+ )
816
+ # Write back dQ.
817
+ DQ_block_ptr = tl.make_block_ptr(
818
+ base=DQ,
819
+ shape=(N_CTX, BLOCK_DMODEL),
820
+ strides=(stride_tok, stride_d),
821
+ offsets=(start_m, 0),
822
+ block_shape=(BLOCK_M2, BLOCK_DMODEL),
823
+ order=(1, 0),
824
+ )
825
+ dq *= LN2
826
+ tl.store(DQ_block_ptr, dq.to(tl.float16))
827
+
828
+
829
+ empty = torch.empty(128, device="cuda")
830
+
831
+
832
+ class _attention(torch.autograd.Function):
833
+
834
+ @staticmethod
835
+ def forward(ctx, q, k, v, causal, sm_scale):
836
+ # shape constraints
837
+ Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
838
+ assert Lq == Lk and Lk == Lv
839
+ assert Lk in {16, 32, 64, 128}
840
+ o = torch.empty_like(q, dtype=v.dtype)
841
+ if torch.version.hip is None:
842
+ BLOCK_M = 128
843
+ BLOCK_N = 64 if Lk <= 64 else 32
844
+ num_stages = 4 if Lk <= 64 else 3
845
+ num_warps = 4 if Lk <= 64 else 8
846
+ # Tuning for H100
847
+ if torch.cuda.get_device_capability()[0] == 9:
848
+ num_warps = 8
849
+ num_stages = 7 if Lk >= 64 else 3
850
+ stage = 3 if causal else 1
851
+
852
+ def grid(META):
853
+ return (
854
+ triton.cdiv(q.shape[2], META["BLOCK_M"]),
855
+ q.shape[0] * q.shape[1],
856
+ 1,
857
+ )
858
+
859
+ M = torch.empty(
860
+ (q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32
861
+ )
862
+ _attn_fwd[grid](
863
+ q,
864
+ k,
865
+ v,
866
+ sm_scale,
867
+ M,
868
+ o,
869
+ q.stride(0),
870
+ q.stride(1),
871
+ q.stride(2),
872
+ q.stride(3),
873
+ k.stride(0),
874
+ k.stride(1),
875
+ k.stride(2),
876
+ k.stride(3),
877
+ v.stride(0),
878
+ v.stride(1),
879
+ v.stride(2),
880
+ v.stride(3),
881
+ o.stride(0),
882
+ o.stride(1),
883
+ o.stride(2),
884
+ o.stride(3),
885
+ q.shape[0],
886
+ q.shape[1],
887
+ N_CTX=q.shape[2],
888
+ BLOCK_DMODEL=Lk,
889
+ STAGE=stage,
890
+ )
891
+
892
+ # restore the grid for bwd kernel
893
+ best_config = _attn_fwd.get_best_config()
894
+ block_m = int(best_config.__str__().split(",")[0].split("BLOCK_M:")[1])
895
+ grid = (triton.cdiv(q.shape[2], block_m), q.shape[0] * q.shape[1], 1)
896
+
897
+ ctx.save_for_backward(q, k, v, o, M)
898
+ ctx.grid = grid
899
+ ctx.sm_scale = sm_scale
900
+ ctx.BLOCK_DMODEL = Lk
901
+ ctx.causal = causal
902
+ return o
903
+
904
+ @staticmethod
905
+ def backward(ctx, do):
906
+ if torch.version.hip is not None:
907
+ BLOCK = 64
908
+ else:
909
+ BLOCK = 128
910
+ q, k, v, o, M = ctx.saved_tensors
911
+ assert do.is_contiguous()
912
+ assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()
913
+ dq = torch.empty_like(q)
914
+ dk = torch.empty_like(k)
915
+ dv = torch.empty_like(v)
916
+ BATCH, N_HEAD, N_CTX = q.shape[:3]
917
+ PRE_BLOCK = 128
918
+ NUM_WARPS, NUM_STAGES = 4, 1
919
+ BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 64, 64, 32
920
+ BLK_SLICE_FACTOR = 2
921
+ RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2)
922
+ arg_k = k
923
+ arg_k = arg_k * (ctx.sm_scale * RCP_LN2)
924
+ assert N_CTX % PRE_BLOCK == 0
925
+ pre_grid = (N_CTX // PRE_BLOCK, BATCH * N_HEAD)
926
+ delta = torch.empty_like(M)
927
+ _attn_bwd_preprocess[pre_grid](
928
+ o,
929
+ do,
930
+ delta,
931
+ BATCH,
932
+ N_HEAD,
933
+ N_CTX,
934
+ BLOCK_M=PRE_BLOCK,
935
+ D_HEAD=ctx.BLOCK_DMODEL,
936
+ )
937
+
938
+ def grid(META):
939
+ return (triton.cdiv(N_CTX, META["BLOCK_N1"]), 1, BATCH * N_HEAD)
940
+
941
+ _attn_bwd[grid](
942
+ q,
943
+ arg_k,
944
+ v,
945
+ ctx.sm_scale,
946
+ do,
947
+ dq,
948
+ dk,
949
+ dv,
950
+ M,
951
+ delta,
952
+ q.stride(0),
953
+ q.stride(1),
954
+ q.stride(2),
955
+ q.stride(3),
956
+ N_HEAD,
957
+ N_CTX,
958
+ BLOCK_DMODEL=ctx.BLOCK_DMODEL,
959
+ )
960
+
961
+ return dq, dk, dv, None, None
962
+
963
+
964
+ attention = _attention.apply