first commit
- .gitattributes +9 -0
- app.py +56 -0
- ckpt/svitt-ego.pth +3 -0
- configs/base.yml +21 -0
- configs/charades_ego/action-recognition.yaml +37 -0
- configs/charades_ego/svitt.yml +14 -0
- configs/config_bert.json +21 -0
- data/charades_ego/Charades_v1_classes.txt +157 -0
- data/charades_ego/csv/0.csv +2 -0
- data/charades_ego/csv/1.csv +2 -0
- data/charades_ego/csv/2.csv +2 -0
- data/charades_ego/csv/3.csv +2 -0
- data/charades_ego/csv/4.csv +2 -0
- data/charades_ego/csv/5.csv +2 -0
- data/charades_ego/csv/6.csv +2 -0
- data/charades_ego/csv/7.csv +2 -0
- data/charades_ego/csv/8.csv +2 -0
- data/charades_ego/csv/9.csv +2 -0
- data/charades_ego/video/15AKPEGO.mp4 +3 -0
- data/charades_ego/video/184EHEGO.mp4 +3 -0
- data/charades_ego/video/6D5DHEGO.mp4 +0 -0
- data/charades_ego/video/CC0LBEGO.mp4 +3 -0
- data/charades_ego/video/FLY2FEGO.mp4 +3 -0
- data/charades_ego/video/P9SOAEGO.mp4 +3 -0
- data/charades_ego/video/PRODQEGO.mp4 +3 -0
- data/charades_ego/video/QLXEXEGO.mp4 +3 -0
- data/charades_ego/video/S8YZIEGO.mp4 +3 -0
- data/charades_ego/video/X2JTKEGO.mp4 +3 -0
- demo.py +226 -0
- meta/charades_ego/label_map.json +1 -0
- requirements.txt +13 -0
- svitt/config.py +37 -0
- svitt/datasets.py +526 -0
- svitt/evaluation.py +36 -0
- svitt/evaluation_charades.py +56 -0
- svitt/model.py +340 -0
- svitt/preprocess.py +86 -0
- svitt/sparse_config.py +351 -0
- svitt/sparse_xbeit.py +1585 -0
- svitt/sparse_xbert.py +2039 -0
- svitt/tokenization_bert.py +546 -0
- svitt/utils.py +235 -0
- svitt/video_transforms.py +186 -0
.gitattributes
CHANGED
@@ -33,3 +33,12 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+data/charades_ego/video/15AKPEGO.mp4 filter=lfs diff=lfs merge=lfs -text
+data/charades_ego/video/184EHEGO.mp4 filter=lfs diff=lfs merge=lfs -text
+data/charades_ego/video/CC0LBEGO.mp4 filter=lfs diff=lfs merge=lfs -text
+data/charades_ego/video/FLY2FEGO.mp4 filter=lfs diff=lfs merge=lfs -text
+data/charades_ego/video/P9SOAEGO.mp4 filter=lfs diff=lfs merge=lfs -text
+data/charades_ego/video/PRODQEGO.mp4 filter=lfs diff=lfs merge=lfs -text
+data/charades_ego/video/QLXEXEGO.mp4 filter=lfs diff=lfs merge=lfs -text
+data/charades_ego/video/S8YZIEGO.mp4 filter=lfs diff=lfs merge=lfs -text
+data/charades_ego/video/X2JTKEGO.mp4 filter=lfs diff=lfs merge=lfs -text
app.py
ADDED
@@ -0,0 +1,56 @@
import gradio as gr

from demo import VideoCLSModel

sample_videos = [
    'data/charades_ego/video/P9SOAEGO.mp4',
    'data/charades_ego/video/6D5DHEGO.mp4',
    'data/charades_ego/video/15AKPEGO.mp4',
    'data/charades_ego/video/X2JTKEGO.mp4',
    'data/charades_ego/video/184EHEGO.mp4',
    'data/charades_ego/video/S8YZIEGO.mp4',
    'data/charades_ego/video/PRODQEGO.mp4',
    'data/charades_ego/video/QLXEXEGO.mp4',
    'data/charades_ego/video/CC0LBEGO.mp4',
    'data/charades_ego/video/FLY2FEGO.mp4'
]

def main():
    svitt = VideoCLSModel("configs/charades_ego/svitt.yml")

    def predict(video_str):
        video_file = video_str.split('/')[-1]
        for i, item in enumerate(sample_videos):
            if video_file in item:
                idx = i
                break

        ft_action, gt_action = svitt.predict(idx)

        return gt_action, ft_action

    with gr.Blocks() as demo:
        gr.Markdown(
            """
            # SViTT-Ego for Action Recognition
            Choose a sample video and click predict to view the results.
            """
        )
        with gr.Row():
            idx = gr.Number(label="Idx", visible=False)
            video = gr.Video(label='video', format='mp4', autoplay=True, height=256, width=256)
        with gr.Row():
            label = gr.Text(label="Ground Truth")
            ours = gr.Text(label="SViTT-Ego prediction")
        with gr.Row():
            btn = gr.Button("Predict", variant="primary")
            btn.click(predict, inputs=[video], outputs=[label, ours])
        with gr.Column():
            gr.Examples(examples=[[x] for _, x in enumerate(sample_videos)], inputs=[video])

    demo.launch()


if __name__ == "__main__":
    main()
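A note on the wiring above: gr.Video hands predict() a file path whose basename matches the chosen example, and predict() maps it back to an index into sample_videos. A minimal sketch of that lookup (the temp path below is hypothetical):

    # Minimal sketch of the lookup predict() performs; the temp path is hypothetical.
    from app import sample_videos

    video_str = "/tmp/gradio/P9SOAEGO.mp4"      # e.g. the path gr.Video passes in
    video_file = video_str.split('/')[-1]       # "P9SOAEGO.mp4"
    idx = next(i for i, v in enumerate(sample_videos) if video_file in v)
    print(idx)                                  # 0 (P9SOAEGO.mp4 is the first sample)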
ckpt/svitt-ego.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:73d9612778da3471372bc46a9e10fd6c1ed66dd2c7a715bf34472d795bd0bf58
size 2500516566
configs/base.yml
ADDED
@@ -0,0 +1,21 @@
model:
  pretrain: ""
  resume: ""
  timesformer_freeze_space: false
  drop_path_rate: 0.1
  dropout_ratio: 0.5
  freeze_vis_backbone: false
  freeze_txt_backbone: false
  use_vn_classifier: false

data:
  dataset: ek100_mir
  root: datasets/EK100/video_ht256px
  metadata: datasets/EK100/epic-kitchens-100-annotations/retrieval_annotations/EPIC_100_retrieval_train.csv
  metadata_val: datasets/EK100/epic-kitchens-100-annotations/retrieval_annotations/EPIC_100_retrieval_test.csv
  relevancy_path: datasets/EK100/epic-kitchens-100-annotations/retrieval_annotations/relevancy/caption_relevancy_EPIC_100_retrieval_test.pkl
  clip_length: 16
  clip_stride: 4
  sparse_sample: false
  num_crops: 1
  num_clips: 1
configs/charades_ego/action-recognition.yaml
ADDED
@@ -0,0 +1,37 @@
text_encoder: bert-base-uncased
bert_config: configs/config_bert.json
vit_type: beit  # items in ${vit_zoo}
vit_zoo:  # from huggingface
  beit: microsoft/beit-base-patch16-224-pt22k-ft22k
vit_name_or_pretrained_path: ${vit_zoo[${vit_type}]}

vision_encoder_args:
  token_keep_rate: 0.7
  token_keep_strategy: cls_attn
  token_drop_loc: [3, 6, 9]
  sparse_local_attn: 1
  sparse_random_attn: 5
  attn_block_size: 56

image_res: 224
embed_dim: 256
video_input:
  num_frames: 4
  reader: decord  # one of [decord, av]
  sample_type: rand
  num_frames_test: 16  # num_frames during inference/test
  sample_type_test: middle
max_txt_l:
  image: 32
  video: 32

batch_size:
  image: 8
  video: 8
batch_size_test:
  image: 8
  video: 8
k_test: 128
temp: 0.18
mlm_prob: 0.5
configs/charades_ego/svitt.yml
ADDED
@@ -0,0 +1,14 @@
model:
  pretrain: ckpt/svitt-ego.pth
  freeze_vis_backbone: true
  freeze_txt_backbone: true
  num_frames: 16
  config: configs/charades_ego/action-recognition.yaml

data:
  dataset: charades_ego
  root: data/charades_ego/video
  metadata_val: data/charades_ego/csv/{}.csv
  label_map: meta/charades_ego/charades_ego.json
  clip_length: 16
  sparse_sample: true
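A minimal sketch of how this experiment config is consumed (svitt/config.py in this commit merges it over configs/base.yml, and demo.py fills the metadata_val template with a sample index):

    from svitt.config import load_cfg

    cfg = load_cfg("configs/charades_ego/svitt.yml")   # base.yml defaults + these overrides
    print(cfg['model']['pretrain'])                    # ckpt/svitt-ego.pth
    print(cfg['data']['metadata_val'].format(0))       # data/charades_ego/csv/0.csv
    print(cfg['data']['clip_stride'])                  # 4, inherited from configs/base.yml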
configs/config_bert.json
ADDED
@@ -0,0 +1,21 @@
{
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "type_vocab_size": 2,
  "vocab_size": 30522,
  "fusion_layer": 9,
  "encoder_width": 768
}
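This JSON is the text-encoder configuration referenced by bert_config in action-recognition.yaml. How the repo's sparse_xbert code consumes the extra keys (fusion_layer, encoder_width) is defined elsewhere in svitt/, so the sketch below only reads the raw file:

    # Minimal sketch: inspect the text-encoder settings.
    import json

    with open("configs/config_bert.json") as f:
        bert_cfg = json.load(f)
    print(bert_cfg["num_hidden_layers"], bert_cfg["fusion_layer"])  # 12 9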
data/charades_ego/Charades_v1_classes.txt
ADDED
@@ -0,0 +1,157 @@
c000 Holding some clothes
c001 Putting clothes somewhere
c002 Taking some clothes from somewhere
c003 Throwing clothes somewhere
c004 Tidying some clothes
c005 Washing some clothes
c006 Closing a door
c007 Fixing a door
c008 Opening a door
c009 Putting something on a table
c010 Sitting on a table
c011 Sitting at a table
c012 Tidying up a table
c013 Washing a table
c014 Working at a table
c015 Holding a phone/camera
c016 Playing with a phone/camera
c017 Putting a phone/camera somewhere
c018 Taking a phone/camera from somewhere
c019 Talking on a phone/camera
c020 Holding a bag
c021 Opening a bag
c022 Putting a bag somewhere
c023 Taking a bag from somewhere
c024 Throwing a bag somewhere
c025 Closing a book
c026 Holding a book
c027 Opening a book
c028 Putting a book somewhere
c029 Smiling at a book
c030 Taking a book from somewhere
c031 Throwing a book somewhere
c032 Watching/Reading/Looking at a book
c033 Holding a towel/s
c034 Putting a towel/s somewhere
c035 Taking a towel/s from somewhere
c036 Throwing a towel/s somewhere
c037 Tidying up a towel/s
c038 Washing something with a towel
c039 Closing a box
c040 Holding a box
c041 Opening a box
c042 Putting a box somewhere
c043 Taking a box from somewhere
c044 Taking something from a box
c045 Throwing a box somewhere
c046 Closing a laptop
c047 Holding a laptop
c048 Opening a laptop
c049 Putting a laptop somewhere
c050 Taking a laptop from somewhere
c051 Watching a laptop or something on a laptop
c052 Working/Playing on a laptop
c053 Holding a shoe/shoes
c054 Putting shoes somewhere
c055 Putting on shoe/shoes
c056 Taking shoes from somewhere
c057 Taking off some shoes
c058 Throwing shoes somewhere
c059 Sitting in a chair
c060 Standing on a chair
c061 Holding some food
c062 Putting some food somewhere
c063 Taking food from somewhere
c064 Throwing food somewhere
c065 Eating a sandwich
c066 Making a sandwich
c067 Holding a sandwich
c068 Putting a sandwich somewhere
c069 Taking a sandwich from somewhere
c070 Holding a blanket
c071 Putting a blanket somewhere
c072 Snuggling with a blanket
c073 Taking a blanket from somewhere
c074 Throwing a blanket somewhere
c075 Tidying up a blanket/s
c076 Holding a pillow
c077 Putting a pillow somewhere
c078 Snuggling with a pillow
c079 Taking a pillow from somewhere
c080 Throwing a pillow somewhere
c081 Putting something on a shelf
c082 Tidying a shelf or something on a shelf
c083 Reaching for and grabbing a picture
c084 Holding a picture
c085 Laughing at a picture
c086 Putting a picture somewhere
c087 Taking a picture of something
c088 Watching/looking at a picture
c089 Closing a window
c090 Opening a window
c091 Washing a window
c092 Watching/Looking outside of a window
c093 Holding a mirror
c094 Smiling in a mirror
c095 Washing a mirror
c096 Watching something/someone/themselves in a mirror
c097 Walking through a doorway
c098 Holding a broom
c099 Putting a broom somewhere
c100 Taking a broom from somewhere
c101 Throwing a broom somewhere
c102 Tidying up with a broom
c103 Fixing a light
c104 Turning on a light
c105 Turning off a light
c106 Drinking from a cup/glass/bottle
c107 Holding a cup/glass/bottle of something
c108 Pouring something into a cup/glass/bottle
c109 Putting a cup/glass/bottle somewhere
c110 Taking a cup/glass/bottle from somewhere
c111 Washing a cup/glass/bottle
c112 Closing a closet/cabinet
c113 Opening a closet/cabinet
c114 Tidying up a closet/cabinet
c115 Someone is holding a paper/notebook
c116 Putting their paper/notebook somewhere
c117 Taking paper/notebook from somewhere
c118 Holding a dish
c119 Putting a dish/es somewhere
c120 Taking a dish/es from somewhere
c121 Wash a dish/dishes
c122 Lying on a sofa/couch
c123 Sitting on sofa/couch
c124 Lying on the floor
c125 Sitting on the floor
c126 Throwing something on the floor
c127 Tidying something on the floor
c128 Holding some medicine
c129 Taking/consuming some medicine
c130 Putting groceries somewhere
c131 Laughing at television
c132 Watching television
c133 Someone is awakening in bed
c134 Lying on a bed
c135 Sitting in a bed
c136 Fixing a vacuum
c137 Holding a vacuum
c138 Taking a vacuum from somewhere
c139 Washing their hands
c140 Fixing a doorknob
c141 Grasping onto a doorknob
c142 Closing a refrigerator
c143 Opening a refrigerator
c144 Fixing their hair
c145 Working on paper/notebook
c146 Someone is awakening somewhere
c147 Someone is cooking something
c148 Someone is dressing
c149 Someone is laughing
c150 Someone is running somewhere
c151 Someone is going from standing to sitting
c152 Someone is smiling
c153 Someone is sneezing
c154 Someone is standing up from somewhere
c155 Someone is undressing
c156 Someone is eating something
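Each line pairs a Charades class id with its description. A minimal sketch of turning this file into the id-to-index mapping that meta/charades_ego/label_map.json stores (the actual generation lives in svitt/preprocess.py):

    labels, mapping_vn2act = [], {}
    with open("data/charades_ego/Charades_v1_classes.txt") as f:
        for i, line in enumerate(f):
            code, label = line.strip().split(' ', 1)   # e.g. "c099", "Putting a broom somewhere"
            labels.append(label)
            mapping_vn2act[code] = i                   # "c099" -> 99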
data/charades_ego/csv/0.csv
ADDED
@@ -0,0 +1,2 @@
id,subject,scene,quality,relevance,verified,script,objects,descriptions,actions,length,egocentric,charades_video
P9SOAEGO,Z241,Stairs,6.0,6.0,Yes,"A person holding a broom walks up and down the stairs, brushing nonchalantly on each step. They put the broom down and pull out a vial of medicine from their pocket, then toss it down the stairs.",broom;floor;medicine;stairs,,c099 0.00 23.21;c127 1.20 23.21;c102 0.00 23.21;c100 0.00 23.21;c098 0.00 23.21;c101 25.20 23.21,33.33,Yes,LY2GQ
data/charades_ego/csv/1.csv
ADDED
@@ -0,0 +1,2 @@
id,subject,scene,quality,relevance,verified,script,objects,descriptions,actions,length,egocentric,charades_video
6D5DHEGO,UTMU,Kitchen,5.0,5.0,Yes,"The person put a pot on the stove and then turned to the counter to grab a cup. The person drank from the cup for a little. The person then poured the rest of the contents of the drink into the sink. The person then went to leave the room, grabbing the doorknob.",food;glass;sink;stove,,c109 20.30 21.38;c107 7.70 17.90;c147 0.00 8.70;c110 7.50 13.60,34.5,Yes,
data/charades_ego/csv/2.csv
ADDED
@@ -0,0 +1,2 @@
id,subject,scene,quality,relevance,verified,script,objects,descriptions,actions,length,egocentric,charades_video
15AKPEGO,I2IV,Home Office / Study (A room in a house used for work),7.0,7.0,Yes,A person is putting a box on the shelf and then closing the cabinet.,box;cabinet;desk;door;shelf,,c043 4.00 10.30;c040 5.30 19.80;c112 19.10 26.70;c042 14.10 20.70;c081 14.10 20.70;c114 15.00 20.70;c006 18.80 28.70,30.83,Yes,IA5TC
data/charades_ego/csv/3.csv
ADDED
@@ -0,0 +1,2 @@
id,subject,scene,quality,relevance,verified,script,objects,descriptions,actions,length,egocentric,charades_video
X2JTKEGO,EV0Z,Recreation room / Man cave,6.0,5.0,Yes,"Person is watching television while another person is walking towards the shelf, grasping the camera.",bed;camera;drawer;television,,c135 0.00 30.30;c132 0.00 30.50;c016 7.70 29.80;c015 0.00 25.50,31.75,Yes,D3PPI
data/charades_ego/csv/4.csv
ADDED
@@ -0,0 +1,2 @@
id,subject,scene,quality,relevance,verified,script,objects,descriptions,actions,length,egocentric,charades_video
184EHEGO,KFGP,Home Office / Study (A room in a house used for work),6.0,6.0,Yes,A person is eating at a table while they play on a laptop.,chair;food;laptop;table,,c011 1.20 18.00;c059 0.20 17.60;c052 2.10 23.62;c156 0.40 23.62;c010 15.70 23.62;c014 14.40 23.62,37.17,Yes,CUB69
data/charades_ego/csv/5.csv
ADDED
@@ -0,0 +1,2 @@
id,subject,scene,quality,relevance,verified,script,objects,descriptions,actions,length,egocentric,charades_video
S8YZIEGO,I2IV,Kitchen,6.0,7.0,Yes,"A person is cooking food on the stove. The person takes a picture of the food with their phone, then puts the phone back into their pocket.",food;phone/camera;pot;stirring utensil;stove,,c147 0.00 24.67;c017 24.50 24.67;c087 16.50 23.50;c015 8.60 24.67;c016 8.10 24.67;c018 5.80 13.60;c147 0.00 10.80;c017 27.10 24.67;c154 0.00 24.67,34.62,Yes,DUEEE
data/charades_ego/csv/6.csv
ADDED
@@ -0,0 +1,2 @@
id,subject,scene,quality,relevance,verified,script,objects,descriptions,actions,length,egocentric,charades_video
PRODQEGO,DJ17,Kitchen,6.0,7.0,Yes,A person is cooking at a stove then they begin to play with some food that's on the counter.,dish;food;shelf;stove;table,,c009 20.90 30.00;c064 22.50 29.90;c063 16.50 23.10;c147 0.00 18.60;c061 19.10 31.79,32.17,Yes,Y5826
data/charades_ego/csv/7.csv
ADDED
@@ -0,0 +1,2 @@
id,subject,scene,quality,relevance,verified,script,objects,descriptions,actions,length,egocentric,charades_video
QLXEXEGO,4OHY,Hallway,6.0,5.0,Yes,"A person smiles as they fix the door. The person laughs hysterically, then throws their bag of tools across the room.",door,,c006 7.50 22.20;c008 32.30 29.71;c152 17.40 26.00;c007 0.00 27.40,40.33,Yes,K3T1B
data/charades_ego/csv/8.csv
ADDED
@@ -0,0 +1,2 @@
id,subject,scene,quality,relevance,verified,script,objects,descriptions,actions,length,egocentric,charades_video
CC0LBEGO,3VLX,Recreation room / Man cave,6.0,7.0,Yes,The person is fixing a shelf in the Recreation room / Man cave. Once the person was finished the person grabbed a glass of water and took a few sips. The person then looked around and left the room.,dirt;glass;shelf;water,,c110 0.00 32.38;c110 22.00 26.90;c106 22.30 26.70;c107 20.90 26.30;c082 1.40 11.40,36.29,Yes,
data/charades_ego/csv/9.csv
ADDED
@@ -0,0 +1,2 @@
id,subject,scene,quality,relevance,verified,script,objects,descriptions,actions,length,egocentric,charades_video
FLY2FEGO,VT5W,Kitchen,7.0,7.0,Yes,A person opens a cabinet and takes out some coffee and then closes the cabinet.,closet/cabinet;coffee,,c113 0.00 4.20;c112 3.90 9.90,31.58,Yes,EJJIO
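The actions column in these CSVs packs semicolon-separated "class start end" triplets, which is exactly how svitt/datasets.py (VideoCaptionDatasetBase, charades_ego branch) unpacks it. A minimal sketch using the row from csv/9.csv:

    actions = "c113 0.00 4.20;c112 3.90 9.90"      # 'actions' value from csv/9.csv
    for action_tuple in actions.split(';'):
        if not action_tuple:
            continue
        action, start, end = action_tuple.split(' ')
        print(action, float(start), float(end))    # c113 0.0 4.2, then c112 3.9 9.9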
data/charades_ego/video/15AKPEGO.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1dbe66fb96be257ae24c122a8f534e893883aa058872932ea630d20e9a609a1f
size 1150119
data/charades_ego/video/184EHEGO.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:115fea186961bddc3bc8e9f10de4472a450720b143361f598257308d64bfd1b9
size 1459016
data/charades_ego/video/6D5DHEGO.mp4
ADDED
Binary file (871 kB)
data/charades_ego/video/CC0LBEGO.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b80523915026d3e23740b4eb057362809e6375e973e0b2fad42a5f11e4331629
size 1414227
data/charades_ego/video/FLY2FEGO.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b45b12ffeeeac6ae74955c37fe57bc468d309aaae188f9020cde88c3f7dd68da
size 1215686
data/charades_ego/video/P9SOAEGO.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:dfc22189ed6d228eca382474ff668f52d0c2b034204241837a4ea00c6b650fd5
size 2497723
data/charades_ego/video/PRODQEGO.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5cd034ab3146bbb36264f5ddbf7b451d09e3e1002c4550e565805710e8d5a1cd
size 1318941
data/charades_ego/video/QLXEXEGO.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:29cc6ec737fe810378ba1f91363078194cf73a4e8e9dcafb45013e04e426bb94
size 1834803
data/charades_ego/video/S8YZIEGO.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d25f66f11549b8dac6a05c0f9405f1dc3a77df68a56430b76d33ab830300fd04
size 1439584
data/charades_ego/video/X2JTKEGO.mp4
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:617b7e321a96271f9603869999a7dc4ac26bf2e8d5f43ab8758428731373de6d
size 1976273
demo.py
ADDED
@@ -0,0 +1,226 @@
### demo.py
# Define model classes for inference.
###
import json
import numpy as np
import os
import pandas as pd

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import torchvision.transforms._transforms_video as transforms_video
from sklearn.metrics import confusion_matrix
from einops import rearrange
from transformers import BertTokenizer

from svitt.model import SViTT
from svitt.datasets import VideoClassyDataset
from svitt.video_transforms import Permute
from svitt.config import load_cfg, setup_config
from svitt.evaluation_charades import charades_map
from svitt.evaluation import get_mean_accuracy


class VideoModel(nn.Module):
    """ Base model for video understanding based on SViTT architecture. """
    def __init__(self, config):
        """ Initializes the model.
        Parameters:
            config: config file
        """
        super(VideoModel, self).__init__()
        self.cfg = load_cfg(config)
        self.model = self.build_model()
        self.templates = ['{}']
        self.dataset = self.cfg['data']['dataset']
        self.eval()

    def build_model(self):
        cfg = self.cfg
        if cfg['model'].get('pretrain', False):
            ckpt_path = cfg['model']['pretrain']
        else:
            raise Exception('no checkpoint found')

        if cfg['model'].get('config', False):
            config_path = cfg['model']['config']
        else:
            raise Exception('no model config found')

        self.model_cfg = setup_config(config_path)
        self.tokenizer = BertTokenizer.from_pretrained(self.model_cfg.text_encoder)
        model = SViTT(config=self.model_cfg, tokenizer=self.tokenizer)

        print(f"Loading checkpoint from {ckpt_path}")
        checkpoint = torch.load(ckpt_path, map_location="cpu")
        state_dict = checkpoint["model"]

        # fix for zero-shot evaluation
        for key in list(state_dict.keys()):
            if "bert" in key:
                encoder_key = key.replace("bert.", "")
                state_dict[encoder_key] = state_dict[key]

        if torch.cuda.is_available():
            model.cuda()

        model.load_state_dict(state_dict, strict=False)

        return model

    def eval(self):
        cudnn.benchmark = True
        for p in self.model.parameters():
            p.requires_grad = False
        self.model.eval()


class VideoCLSModel(VideoModel):
    """ Video model for video classification tasks (Charades-Ego, EGTEA). """
    def __init__(self, config):
        super(VideoCLSModel, self).__init__(config)
        self.labels, self.mapping_vn2act = self.gen_label_map()
        self.text_features = self.get_text_features()

    def gen_label_map(self):
        labelmap = self.cfg.get('label_map', 'meta/charades_ego/label_map.json')
        if os.path.isfile(labelmap):
            print(f"=> Loading label maps from {labelmap}")
            meta = json.load(open(labelmap, 'r'))
            labels, mapping_vn2act = meta['labels'], meta['mapping_vn2act']
        else:
            from svitt.preprocess import generate_label_map
            labels, mapping_vn2act = generate_label_map(self.dataset)
            meta = {'labels': labels, 'mapping_vn2act': mapping_vn2act}
            meta_dir = f'meta/{self.dataset}'
            if not os.path.exists(meta_dir):
                os.makedirs(meta_dir)
            json.dump(meta, open(f'{meta_dir}/label_map.json', 'w'))
            print(f"=> Label map is generated and saved to {meta_dir}/label_map.json")

        return labels, mapping_vn2act

    def load_data(self, idx=None):
        print(f"=> Creating dataset")
        cfg, dataset = self.cfg, self.dataset
        data_cfg = cfg['data']
        crop_size = 224
        val_transform = transforms.Compose([
            Permute([3, 0, 1, 2]),  # T H W C -> C T H W
            transforms.Resize(crop_size),
            transforms.CenterCrop(crop_size),
            transforms_video.NormalizeVideo(
                mean=[108.3272985, 116.7460125, 104.09373615000001],
                std=[68.5005327, 66.6321579, 70.32316305],
            ),
        ])

        if idx is None:
            metadata_val = data_cfg['metadata_val']
        else:
            metadata_val = data_cfg['metadata_val'].format(idx)
        if dataset in ['charades_ego', 'egtea']:
            val_dataset = VideoClassyDataset(
                dataset,
                data_cfg['root'],
                metadata_val,
                transform=val_transform,
                is_training=False,
                label_mapping=self.mapping_vn2act,
                is_trimmed=False,
                num_clips=1,
                clip_length=data_cfg['clip_length'],
                clip_stride=data_cfg['clip_stride'],
                sparse_sample=data_cfg['sparse_sample'],
            )
        else:
            raise NotImplementedError

        val_loader = torch.utils.data.DataLoader(
            val_dataset, batch_size=8, shuffle=False,
            num_workers=4, pin_memory=True, sampler=None, drop_last=False
        )

        return val_loader

    @torch.no_grad()
    def get_text_features(self):
        print('=> Extracting text features')
        embeddings = self.tokenizer(
            self.labels,
            padding="max_length",
            truncation=True,
            max_length=self.model_cfg.max_txt_l.video,
            return_tensors="pt",
        )
        _, class_embeddings = self.model.encode_text(embeddings)
        return class_embeddings

    @torch.no_grad()
    def forward(self, idx=None):
        print('=> Start forwarding')
        val_loader = self.load_data(idx)
        all_outputs = []
        all_targets = []
        for i, values in enumerate(val_loader):
            images = values[0]
            target = values[1]

            if torch.cuda.is_available():
                images = images.cuda(non_blocking=True)
                target = target.cuda(non_blocking=True)

            # encode images
            images = rearrange(images, 'b c k h w -> b k c h w')
            dims = images.shape
            images = images.reshape(-1, 4, dims[-3], dims[-2], dims[-1])

            image_features, _ = self.model.encode_image(images)

            if image_features.ndim == 3:
                image_features = rearrange(image_features, '(b k) n d -> b (k n) d', b=1)
            else:
                image_features = rearrange(image_features, '(b k) d -> b k d', b=1)

            # cosine similarity as logits
            similarity = self.model.get_sim(image_features, self.text_features)[0]

            all_outputs.append(similarity.cpu())
            all_targets.append(target.cpu())

        all_outputs = torch.cat(all_outputs)
        all_targets = torch.cat(all_targets)

        return all_outputs, all_targets

    @torch.no_grad()
    def predict(self, idx=0):
        all_outputs, all_targets = self.forward(idx)
        preds, targets = all_outputs.numpy(), all_targets.numpy()
        #sel = np.where(np.cumsum(sorted(preds[0].tolist(), reverse=True)) > 0.06)[0][0]
        sel = 5
        df = pd.DataFrame(self.labels)
        pred_action = df.iloc[preds[0].argsort()[-sel:]].values.tolist()
        gt_action = df.iloc[np.where(targets[0])[0]].values.tolist()
        pred_action = sorted([x[0] for x in pred_action])
        gt_action = sorted([x[0] for x in gt_action])
        return pred_action, gt_action

    @torch.no_grad()
    def evaluate(self):
        all_outputs, all_targets = self.forward()
        preds, targets = all_outputs.numpy(), all_targets.numpy()
        if self.dataset == 'charades_ego':
            m_ap, _, m_aps = charades_map(preds, targets)
            print('mAP = {:.3f}'.format(m_ap))
        elif self.dataset == 'egtea':
            cm = confusion_matrix(targets, preds.argmax(axis=1))
            mean_class_acc, acc = get_mean_accuracy(cm)
            print('Mean Acc. = {:.3f}, Top-1 Acc. = {:.3f}'.format(mean_class_acc, acc))
        else:
            raise NotImplementedError
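A minimal usage sketch of VideoCLSModel, mirroring what app.py does (it assumes the LFS checkpoint in ckpt/ is present; evaluate() would additionally need a metadata_val CSV covering a full split rather than the per-sample template in svitt.yml):

    from demo import VideoCLSModel

    svitt = VideoCLSModel("configs/charades_ego/svitt.yml")
    pred_actions, gt_actions = svitt.predict(idx=0)   # top-5 predicted vs. ground-truth labels
    print(pred_actions)
    print(gt_actions)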
meta/charades_ego/label_map.json
ADDED
@@ -0,0 +1 @@
{"labels": ["Holding some clothes", "Putting clothes somewhere", "Taking some clothes from somewhere", "Throwing clothes somewhere", "Tidying some clothes", "Washing some clothes", "Closing a door", "Fixing a door", "Opening a door", "Putting something on a table", "Sitting on a table", "Sitting at a table", "Tidying up a table", "Washing a table", "Working at a table", "Holding a phone/camera", "Playing with a phone/camera", "Putting a phone/camera somewhere", "Taking a phone/camera from somewhere", "Talking on a phone/camera", "Holding a bag", "Opening a bag", "Putting a bag somewhere", "Taking a bag from somewhere", "Throwing a bag somewhere", "Closing a book", "Holding a book", "Opening a book", "Putting a book somewhere", "Smiling at a book", "Taking a book from somewhere", "Throwing a book somewhere", "Watching/Reading/Looking at a book", "Holding a towel/s", "Putting a towel/s somewhere", "Taking a towel/s from somewhere", "Throwing a towel/s somewhere", "Tidying up a towel/s", "Washing something with a towel", "Closing a box", "Holding a box", "Opening a box", "Putting a box somewhere", "Taking a box from somewhere", "Taking something from a box", "Throwing a box somewhere", "Closing a laptop", "Holding a laptop", "Opening a laptop", "Putting a laptop somewhere", "Taking a laptop from somewhere", "Watching a laptop or something on a laptop", "Working/Playing on a laptop", "Holding a shoe/shoes", "Putting shoes somewhere", "Putting on shoe/shoes", "Taking shoes from somewhere", "Taking off some shoes", "Throwing shoes somewhere", "Sitting in a chair", "Standing on a chair", "Holding some food", "Putting some food somewhere", "Taking food from somewhere", "Throwing food somewhere", "Eating a sandwich", "Making a sandwich", "Holding a sandwich", "Putting a sandwich somewhere", "Taking a sandwich from somewhere", "Holding a blanket", "Putting a blanket somewhere", "Snuggling with a blanket", "Taking a blanket from somewhere", "Throwing a blanket somewhere", "Tidying up a blanket/s", "Holding a pillow", "Putting a pillow somewhere", "Snuggling with a pillow", "Taking a pillow from somewhere", "Throwing a pillow somewhere", "Putting something on a shelf", "Tidying a shelf or something on a shelf", "Reaching for and grabbing a picture", "Holding a picture", "Laughing at a picture", "Putting a picture somewhere", "Taking a picture of something", "Watching/looking at a picture", "Closing a window", "Opening a window", "Washing a window", "Watching/Looking outside of a window", "Holding a mirror", "Smiling in a mirror", "Washing a mirror", "Watching something/someone/themselves in a mirror", "Walking through a doorway", "Holding a broom", "Putting a broom somewhere", "Taking a broom from somewhere", "Throwing a broom somewhere", "Tidying up with a broom", "Fixing a light", "Turning on a light", "Turning off a light", "Drinking from a cup/glass/bottle", "Holding a cup/glass/bottle of something", "Pouring something into a cup/glass/bottle", "Putting a cup/glass/bottle somewhere", "Taking a cup/glass/bottle from somewhere", "Washing a cup/glass/bottle", "Closing a closet/cabinet", "Opening a closet/cabinet", "Tidying up a closet/cabinet", "Someone is holding a paper/notebook", "Putting their paper/notebook somewhere", "Taking paper/notebook from somewhere", "Holding a dish", "Putting a dish/es somewhere", "Taking a dish/es from somewhere", "Wash a dish/dishes", "Lying on a sofa/couch", "Sitting on sofa/couch", "Lying on the floor", "Sitting on the floor", "Throwing something on the floor", 
"Tidying something on the floor", "Holding some medicine", "Taking/consuming some medicine", "Putting groceries somewhere", "Laughing at television", "Watching television", "Someone is awakening in bed", "Lying on a bed", "Sitting in a bed", "Fixing a vacuum", "Holding a vacuum", "Taking a vacuum from somewhere", "Washing their hands", "Fixing a doorknob", "Grasping onto a doorknob", "Closing a refrigerator", "Opening a refrigerator", "Fixing their hair", "Working on paper/notebook", "Someone is awakening somewhere", "Someone is cooking something", "Someone is dressing", "Someone is laughing", "Someone is running somewhere", "Someone is going from standing to sitting", "Someone is smiling", "Someone is sneezing", "Someone is standing up from somewhere", "Someone is undressing", "Someone is eating something"], "mapping_vn2act": {"c000": 0, "c001": 1, "c002": 2, "c003": 3, "c004": 4, "c005": 5, "c006": 6, "c007": 7, "c008": 8, "c009": 9, "c010": 10, "c011": 11, "c012": 12, "c013": 13, "c014": 14, "c015": 15, "c016": 16, "c017": 17, "c018": 18, "c019": 19, "c020": 20, "c021": 21, "c022": 22, "c023": 23, "c024": 24, "c025": 25, "c026": 26, "c027": 27, "c028": 28, "c029": 29, "c030": 30, "c031": 31, "c032": 32, "c033": 33, "c034": 34, "c035": 35, "c036": 36, "c037": 37, "c038": 38, "c039": 39, "c040": 40, "c041": 41, "c042": 42, "c043": 43, "c044": 44, "c045": 45, "c046": 46, "c047": 47, "c048": 48, "c049": 49, "c050": 50, "c051": 51, "c052": 52, "c053": 53, "c054": 54, "c055": 55, "c056": 56, "c057": 57, "c058": 58, "c059": 59, "c060": 60, "c061": 61, "c062": 62, "c063": 63, "c064": 64, "c065": 65, "c066": 66, "c067": 67, "c068": 68, "c069": 69, "c070": 70, "c071": 71, "c072": 72, "c073": 73, "c074": 74, "c075": 75, "c076": 76, "c077": 77, "c078": 78, "c079": 79, "c080": 80, "c081": 81, "c082": 82, "c083": 83, "c084": 84, "c085": 85, "c086": 86, "c087": 87, "c088": 88, "c089": 89, "c090": 90, "c091": 91, "c092": 92, "c093": 93, "c094": 94, "c095": 95, "c096": 96, "c097": 97, "c098": 98, "c099": 99, "c100": 100, "c101": 101, "c102": 102, "c103": 103, "c104": 104, "c105": 105, "c106": 106, "c107": 107, "c108": 108, "c109": 109, "c110": 110, "c111": 111, "c112": 112, "c113": 113, "c114": 114, "c115": 115, "c116": 116, "c117": 117, "c118": 118, "c119": 119, "c120": 120, "c121": 121, "c122": 122, "c123": 123, "c124": 124, "c125": 125, "c126": 126, "c127": 127, "c128": 128, "c129": 129, "c130": 130, "c131": 131, "c132": 132, "c133": 133, "c134": 134, "c135": 135, "c136": 136, "c137": 137, "c138": 138, "c139": 139, "c140": 140, "c141": 141, "c142": 142, "c143": 143, "c144": 144, "c145": 145, "c146": 146, "c147": 147, "c148": 148, "c149": 149, "c150": 150, "c151": 151, "c152": 152, "c153": 153, "c154": 154, "c155": 155, "c156": 156}}
requirements.txt
ADDED
@@ -0,0 +1,13 @@
gradio
torch
torchvision
scikit-learn
eva-decord
timm
einops
ftfy
regex
transformers
omegaconf
zCurve
numpy-hilbert-curve
svitt/config.py
ADDED
@@ -0,0 +1,37 @@
import yaml
from omegaconf import OmegaConf, DictConfig

def load_base_cfg():
    with open('configs/base.yml', 'r') as fp:
        cfg = yaml.load(fp, Loader=yaml.SafeLoader)
    return cfg

def load_cfg(cfg_file):
    cfg = load_base_cfg()
    with open(cfg_file, 'r') as fp:
        exp_cfg = yaml.load(fp, Loader=yaml.SafeLoader)

    cfg['model'].update(exp_cfg.get('model', {}))
    cfg['data'].update(exp_cfg.get('data', {}))
    dataset = cfg['data'].get('dataset')
    return cfg

def convert_types(config):
    """Convert `'None'` (str) --> `None` (None). Only supports top-level"""
    for k, v in config.items():
        if isinstance(v, DictConfig):
            setattr(config, k, convert_types(v))

        # TODO convert types in ListConfig, right now they are ignored
        # if isinstance(v, ListConfig):
        #     new_v = ListConfig()

        if v in ["None", "none"]:
            setattr(config, k, None)
    return config

def setup_config(config_path):
    yaml_config = OmegaConf.load(config_path)
    config = convert_types(yaml_config)
    return config
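A minimal sketch of the two entry points above, chained the way demo.py uses them: load_cfg returns a plain dict (base.yml plus experiment overrides), while setup_config returns an OmegaConf object with dotted access for the model config:

    from svitt.config import load_cfg, setup_config

    cfg = load_cfg("configs/charades_ego/svitt.yml")          # dict: base.yml + overrides
    model_cfg = setup_config(cfg['model']['config'])          # OmegaConf from action-recognition.yaml
    print(model_cfg.text_encoder, model_cfg.max_txt_l.video)  # bert-base-uncased 32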
svitt/datasets.py
ADDED
@@ -0,0 +1,526 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import csv
import glob
import json
import numpy as np
import os.path as osp
import pickle
import random

import decord
import pandas as pd
import torch


def datetime2sec(str):
    hh, mm, ss = str.split(':')
    return int(hh) * 3600 + int(mm) * 60 + float(ss)


def video_loader(root, vid, second, end_second=None, chunk_len=300, fps=30, clip_length=32, jitter=False):
    if chunk_len == -1:
        vr = decord.VideoReader(osp.join(root, '{}.mp4'.format(vid)))
        second_offset = second
        if end_second is not None:
            end_second = min(end_second, len(vr) / vr.get_avg_fps())
        else:
            end_second = len(vr) / vr.get_avg_fps()
    else:
        chunk_start = int(second) // chunk_len * chunk_len
        second_offset = second - chunk_start
        vr = decord.VideoReader(osp.join(root, '{}.mp4'.format(vid), '{}.mp4'.format(chunk_start)))
    if fps == -1:
        fps = vr.get_avg_fps()

    # calculate frame_ids
    frame_offset = int(np.round(second_offset * fps))
    total_duration = max(int((end_second - second) * fps), clip_length)
    if chunk_len == -1:
        if end_second <= second:
            raise ValueError("end_second should be greater than second")
        else:
            frame_ids = get_frame_ids(frame_offset, min(frame_offset + total_duration, len(vr)), num_segments=clip_length, jitter=jitter)
    else:
        frame_ids = get_frame_ids(frame_offset, frame_offset + total_duration, num_segments=clip_length, jitter=jitter)

    # load frames
    if max(frame_ids) < len(vr):
        try:
            frames = vr.get_batch(frame_ids).asnumpy()
        except decord.DECORDError as error:
            print(error)
            frames = vr.get_batch([0] * len(frame_ids)).asnumpy()
    else:
        # find the remaining frames in the next chunk
        try:
            frame_ids_part1 = list(filter(lambda frame_id: frame_id < len(vr), frame_ids))
            frames_part1 = vr.get_batch(frame_ids_part1).asnumpy()
            vr2 = decord.VideoReader(osp.join(root, '{}.mp4'.format(vid), '{}.mp4'.format(chunk_start + chunk_len)))
            frame_ids_part2 = list(filter(lambda frame_id: frame_id >= len(vr), frame_ids))
            frame_ids_part2 = [min(frame_id % len(vr), len(vr2) - 1) for frame_id in frame_ids_part2]
            frames_part2 = vr2.get_batch(frame_ids_part2).asnumpy()
            frames = np.concatenate([frames_part1, frames_part2], axis=0)
        # the next chunk does not exist; the current chunk is the last one
        except (RuntimeError, decord.DECORDError) as error:
            print(error)
            frame_ids = get_frame_ids(min(frame_offset, len(vr) - 1), len(vr), num_segments=clip_length, jitter=jitter)
            frames = vr.get_batch(frame_ids).asnumpy()

    frames = [torch.tensor(frame, dtype=torch.float32) for frame in frames]
    return torch.stack(frames, dim=0)


def get_frame_ids(start_frame, end_frame, num_segments=32, jitter=True):
    seg_size = float(end_frame - start_frame - 1) / num_segments
    seq = []
    for i in range(num_segments):
        start = int(np.round(seg_size * i) + start_frame)
        end = int(np.round(seg_size * (i + 1)) + start_frame)
        end = min(end, end_frame)
        if jitter:
            frame_id = np.random.randint(low=start, high=(end + 1))
        else:
            frame_id = (start + end) // 2
        seq.append(frame_id)
    return seq


def video_loader_by_frames(root, vid, frame_ids):
    vr = decord.VideoReader(osp.join(root, vid))
    try:
        frames = vr.get_batch(frame_ids).asnumpy()
        frames = [torch.tensor(frame, dtype=torch.float32) for frame in frames]
    except (IndexError, decord.DECORDError) as error:
        print(error)
        print("Erroneous video: ", vid)
        frames = [torch.zeros((240, 320, 3)) for _ in range(len(frame_ids))]
    return torch.stack(frames, dim=0)


class VideoCaptionDatasetBase(torch.utils.data.Dataset):
    def __init__(self, dataset, root, metadata, is_trimmed=True):
        self.dataset = dataset
        self.root = root
        self.is_trimmed = is_trimmed

        if self.dataset == 'ego4d':
            with open(metadata, 'rb') as f:
                self.samples = pickle.load(f)
        elif self.dataset == 'ego4d_mcq':
            with open(metadata, 'r') as f:
                self.samples = json.load(f)
        elif self.dataset in ['ek100_cls', 'ek100_mir']:
            video_list = glob.glob(osp.join(self.root, '*/*.MP4'))
            fps_dict = {video: decord.VideoReader(video).get_avg_fps() for video in video_list}
            self.samples = []
            with open(metadata) as f:
                csv_reader = csv.reader(f)
                _ = next(csv_reader)  # skip the header
                for row in csv_reader:
                    pid, vid = row[1:3]
                    # start_frame, end_frame = int(row[6]), int(row[7])
                    # Deprecated: some videos might have fps mismatch issue
                    start_timestamp, end_timestamp = datetime2sec(row[4]), datetime2sec(row[5])
                    narration = row[8]
                    verb, noun = int(row[10]), int(row[12])
                    vid_path = '{}/{}.MP4'.format(pid, vid)
                    fps = fps_dict[osp.join(self.root, vid_path)]
                    start_frame = int(np.round(fps * start_timestamp))
                    end_frame = int(np.ceil(fps * end_timestamp))
                    self.samples.append((vid_path, start_frame, end_frame, narration, verb, noun))
            if self.dataset == 'ek100_mir':
                self.metadata_sentence = pd.read_csv(metadata[:metadata.index('.csv')] + '_sentence.csv')
                if 'train' in metadata:
                    self.relevancy_mat = pickle.load(open(osp.join(osp.dirname(metadata), 'relevancy', 'caption_relevancy_EPIC_100_retrieval_train.pkl'), 'rb'))
                elif 'test' in metadata:
                    self.relevancy_mat = pickle.load(open(osp.join(osp.dirname(metadata), 'relevancy', 'caption_relevancy_EPIC_100_retrieval_test.pkl'), 'rb'))
                else:
                    raise ValueError('{} should contain either "train" or "test"!'.format(metadata))
                self.relevancy = .1
        elif self.dataset == 'egtea':
            video_list = glob.glob(osp.join(self.root, '*/*'))
            len_dict = {video: len(decord.VideoReader(video)) for video in video_list}

            vn_list, labels = [], []
            for row in open(osp.join(osp.dirname(metadata), 'action_idx.txt')):
                row = row.strip()
                vn = int(row.split(' ')[-1])
                vn_list.append(vn)
                narration = ' '.join(row.split(' ')[:-1])
                labels.append(narration.replace('_', ' ').lower())
                # labels.append(narration)
            mapping_act2narration = {vn: narration for vn, narration in zip(vn_list, labels)}

            self.samples = []
            with open(metadata) as f:
                for row in f:
                    clip_id, action_idx = row.strip().split(' ')[:2]
                    video_id = '-'.join(clip_id.split('-')[:3])
                    vid_relpath = osp.join(video_id, '{}.mp4'.format(clip_id))
                    vid_fullpath = osp.join(self.root, video_id, '{}.mp4'.format(clip_id))
                    self.samples.append((vid_relpath, 0, len_dict[vid_fullpath], mapping_act2narration[int(action_idx)]))
        elif self.dataset == 'charades_ego':
            video_list = glob.glob(osp.join(self.root, '*.mp4'))
            fps_dict = {video: decord.VideoReader(video).get_avg_fps() for video in video_list}
            self.samples = []
            with open(metadata) as f:
                csv_reader = csv.reader(f)
                _ = next(csv_reader)  # skip the header
                for row in csv_reader:
                    video_id = row[0]
                    if self.is_trimmed:
                        for action_tuple in row[9].split(';'):
                            if not action_tuple:
                                continue
                            action, start_timestamp, end_timestamp = action_tuple.split(' ')
                            start_timestamp, end_timestamp = float(start_timestamp), float(end_timestamp)
                            vid_path = '{}.mp4'.format(video_id)
                            fps = fps_dict[osp.join(self.root, vid_path)]
                            start_frame = int(np.round(fps * start_timestamp))
                            end_frame = int(np.ceil(fps * end_timestamp))
                            self.samples.append((vid_path, start_frame, end_frame, action))
                    else:
                        if not row[9]:
                            action_list = []
                        else:
                            action_list = [action_tuple.split(' ')[0] for action_tuple in row[9].split(';')]
                        vid_path = '{}.mp4'.format(video_id)
                        fps = fps_dict[osp.join(self.root, vid_path)]
                        duration = fps * float(row[10])
                        self.samples.append((vid_path, 0, duration, action_list))
        elif self.dataset == 'charades_ego_trimmed':
            with open(metadata, 'rb') as f:
                self.samples = pickle.load(f)
        else:
            raise NotImplementedError

    def get_raw_item(self, i, is_training=True, num_clips=1, clip_length=32, clip_stride=2, sparse_sample=False,
                     narration_selection='random'):
        if self.dataset == 'ego4d':
            if len(self.samples[i]) == 4:
                vid, start_second, end_second, narration = self.samples[i]
                frames = video_loader(self.root, vid, start_second,
                                      end_second=end_second,
                                      clip_length=clip_length,
                                      jitter=is_training)
                if isinstance(narration, list):
                    if narration_selection == 'random':
                        narration = random.choice(narration)
                    elif narration_selection == 'concat':
                        narration = '. '.join(narration)
                    elif narration_selection == 'list':
                        narration = narration
                    else:
                        raise ValueError
                return frames, narration
            elif len(self.samples[i]) == 5:
                # TODO: need better filtering strategy based on nll
                vid, start_second, end_second, narration, _ = self.samples[i]
                frames = video_loader(self.root, vid, start_second,
                                      end_second=end_second,
                                      clip_length=clip_length,
                                      jitter=is_training)
                if isinstance(narration, list):
                    if narration_selection == 'random':
                        narration = random.choice(narration)
                    elif narration_selection == 'concat':
                        narration = '. '.join(narration)
                    elif narration_selection == 'list':
                        narration = narration
                    else:
                        raise ValueError
                return frames, narration
        elif self.dataset == 'ego4d_mcq':
            itemMCQ = self.samples[str(i)]
            answerIndex = itemMCQ['answer']
            textQuery = itemMCQ['query']['clip_text']
            sampleOptions = itemMCQ['choices']
            frames_options = []
            narration_options = []
            for option_id in range(len(sampleOptions)):
                option = sampleOptions[str(option_id)]
                frames = video_loader(self.root, option['video_uid'],
                                      float(option['clip_start']), end_second=float(option['clip_end']),
                                      clip_length=clip_length,
                                      jitter=is_training)
                frames_options.append(frames)
                narration_options.append(option['clip_text'])
            return textQuery, frames_options, narration_options, answerIndex, itemMCQ['types']
        elif self.dataset == 'ek100_mir':
            vid_path, start_frame, end_frame, narration, verb, noun = self.samples[i]
            # from third_party.EgoVLP.base.base_dataset import sample_frames_start_end
            # frame_ids = sample_frames_start_end(clip_length, start_frame, end_frame, sample='uniform', fix_start=None)
            frame_ids = get_frame_ids(start_frame, end_frame, num_segments=clip_length, jitter=is_training)
            frames = video_loader_by_frames(self.root, vid_path, frame_ids)
            if is_training:
                positive_list = np.where(self.relevancy_mat[i] > self.relevancy)[0].tolist()
                if positive_list != []:
                    pos = random.sample(positive_list, min(len(positive_list), 1))[0]
                    if pos < len(self.metadata_sentence) and pos < self.relevancy_mat.shape[1]:
                        return frames, (self.metadata_sentence.iloc[pos][1], self.relevancy_mat[i][pos])
            else:
                return frames, (narration, 1)
        elif self.dataset == 'ek100_cls':
            vid_path, start_frame, end_frame, narration, verb, noun = self.samples[i]
            frame_ids = get_frame_ids(start_frame, end_frame, num_segments=clip_length, jitter=is_training)
            frames = video_loader_by_frames(self.root, vid_path, frame_ids)
            return frames, '{}:{}'.format(verb, noun)
        elif self.dataset == 'egtea':
            vid_path, start_frame, end_frame, sentence = self.samples[i]
            if is_training:
                assert num_clips == 1
                if end_frame < clip_length * clip_stride:
                    frames = video_loader_by_frames(self.root, vid_path, list(np.arange(0, end_frame)))
                    zeros = torch.zeros((clip_length * clip_stride - end_frame, *frames.shape[1:]))
                    frames = torch.cat((frames, zeros), dim=0)
                    frames = frames[::clip_stride]
                else:
                    start_id = np.random.randint(0, end_frame - clip_length * clip_stride + 1)
                    frame_ids = np.arange(start_id, start_id + clip_length * clip_stride, clip_stride)
                    frames = video_loader_by_frames(self.root, vid_path, frame_ids)
            else:
                if end_frame < clip_length * clip_stride:
                    frames = video_loader_by_frames(self.root, vid_path, list(np.arange(0, end_frame)))
                    zeros = torch.zeros((clip_length * clip_stride - end_frame, *frames.shape[1:]))
                    frames = torch.cat((frames, zeros), dim=0)
                    frames = frames[::clip_stride]
                    frames = frames.repeat(num_clips, 1, 1, 1)
                else:
                    frame_ids = []
                    for start_id in np.linspace(0, end_frame - clip_length * clip_stride, num_clips, dtype=int):
                        frame_ids.extend(np.arange(start_id, start_id + clip_length * clip_stride, clip_stride))
                    frames = video_loader_by_frames(self.root, vid_path, frame_ids)
            return frames, sentence
        elif self.dataset == 'charades_ego':
            vid_path, start_frame, end_frame, action_list = self.samples[i]
            if sparse_sample:
                frame_ids = get_frame_ids(start_frame, end_frame, num_segments=num_clips * clip_length, jitter=is_training)
                frames = video_loader_by_frames(self.root, vid_path, frame_ids)
            else:
                if end_frame < clip_length * clip_stride:
                    frames = video_loader_by_frames(self.root, vid_path, list(np.arange(0, end_frame)))
                    zeros = torch.zeros((clip_length * clip_stride - end_frame, *frames.shape[1:]))
                    frames = torch.cat((frames, zeros), dim=0)
                    frames = frames[::clip_stride]
                    frames = frames.repeat(num_clips, 1, 1, 1)
                else:
                    frame_ids = []
                    for start_id in np.linspace(0, end_frame - clip_length * clip_stride, num_clips, dtype=int):
                        frame_ids.extend(np.arange(start_id, start_id + clip_length * clip_stride, clip_stride))
                    #print('frame_ids:', frame_ids)
                    frames = video_loader_by_frames(self.root, vid_path, frame_ids)
            return frames, action_list, vid_path
        elif self.dataset == 'charades_ego_trimmed':
            vid, start_second, end_second, narration = self.samples[i]
            frames = video_loader(self.root, vid, start_second,
                                  end_second=end_second,
                                  chunk_len=-1,  # no chunk for CharadesEgo
                                  fps=-1,  # could be variable fps
                                  clip_length=clip_length,
                                  jitter=is_training)
            return frames, narration
        else:
            raise NotImplementedError

    def __getitem__(self, i):
        raise NotImplementedError

    def __len__(self):
        return len(self.samples)


class VideoCaptionDatasetCLIP(VideoCaptionDatasetBase):
    def __init__(self, dataset, root, metadata, transform=None,
                 is_training=True, tokenizer=None,
                 clip_length=32, clip_stride=2, sparse_sample=False,
                 narration_selection='random',
                 num_hard_negatives=0,
                 subsample_stride=None):
        super().__init__(dataset, root, metadata)

        self.full_samples = self.samples.copy()
        if isinstance(subsample_stride, int):
            self.samples = self.samples[::subsample_stride]
        self.transform = transform
        self.is_training = is_training
        self.tokenizer = tokenizer
        self.clip_length = clip_length
        self.clip_stride = clip_stride
        self.sparse_sample = sparse_sample
        self.narration_selection = narration_selection
        self.num_hard_negatives = num_hard_negatives
        if num_hard_negatives > 0:
            assert self.dataset == 'htm_aa'

    def __getitem__(self, i):
        frames, caption = self.get_raw_item(
            i, is_training=self.is_training,
            clip_length=self.clip_length,
            clip_stride=self.clip_stride,
            sparse_sample=self.sparse_sample,
            narration_selection=self.narration_selection,
        )

        # ek100_mir will also output relevancy value
        if isinstance(caption, tuple):
            caption, relevancy = caption
        else:
            relevancy = 0.

        # apply transformation
        if self.transform is not None:
            frames = self.transform(frames)

        # tokenize caption
|
380 |
+
if self.tokenizer is not None:
|
381 |
+
caption = self.tokenizer(caption)
|
382 |
+
|
383 |
+
if isinstance(caption, tuple):
|
384 |
+
caption, mask = caption
|
385 |
+
return frames, caption, mask, relevancy
|
386 |
+
else:
|
387 |
+
return frames, caption, relevancy
|
388 |
+
|
389 |
+
|
390 |
+
class VideoCaptionDatasetMCQ(VideoCaptionDatasetBase):
|
391 |
+
def __init__(self, dataset, root, metadata, transform=None,
|
392 |
+
is_training=True, tokenizer=None,
|
393 |
+
clip_length=32, clip_stride=2, sparse_sample=False,
|
394 |
+
narration_selection='random'):
|
395 |
+
super().__init__(dataset, root, metadata)
|
396 |
+
|
397 |
+
self.full_samples = self.samples.copy()
|
398 |
+
self.transform = transform
|
399 |
+
self.is_training = is_training
|
400 |
+
self.tokenizer = tokenizer
|
401 |
+
self.clip_length = clip_length
|
402 |
+
self.clip_stride = clip_stride
|
403 |
+
self.sparse_sample = sparse_sample
|
404 |
+
self.narration_selection = narration_selection
|
405 |
+
|
406 |
+
def __getitem__(self, i):
|
407 |
+
|
408 |
+
textQuery, frames_options, narration_options, answerIndex, q_type = self.get_raw_item(
|
409 |
+
i, is_training=self.is_training,
|
410 |
+
clip_length=self.clip_length,
|
411 |
+
clip_stride=self.clip_stride,
|
412 |
+
sparse_sample=self.sparse_sample,
|
413 |
+
narration_selection=self.narration_selection,
|
414 |
+
)
|
415 |
+
|
416 |
+
# apply transformation
|
417 |
+
if self.transform is not None:
|
418 |
+
frames_options = [self.transform(frames) for frames in frames_options]
|
419 |
+
|
420 |
+
# tokenize caption
|
421 |
+
if self.tokenizer is not None:
|
422 |
+
textQuery = self.tokenizer(textQuery)
|
423 |
+
narration_options = self.tokenizer(narration_options)
|
424 |
+
if isinstance(textQuery, tuple):
|
425 |
+
textQuery, mask_query = textQuery
|
426 |
+
narration_options, mask_options = narration_options
|
427 |
+
return (
|
428 |
+
textQuery, torch.stack(frames_options, dim=0),
|
429 |
+
narration_options, answerIndex, q_type,
|
430 |
+
mask_query, mask_options
|
431 |
+
)
|
432 |
+
else:
|
433 |
+
return textQuery, torch.stack(frames_options, dim=0), narration_options, answerIndex, q_type
|
434 |
+
|
435 |
+
|
436 |
+
class VideoClassyDataset(VideoCaptionDatasetBase):
|
437 |
+
def __init__(
|
438 |
+
self, dataset, root, metadata, transform=None,
|
439 |
+
is_training=True, label_mapping=None,
|
440 |
+
num_clips=1,
|
441 |
+
clip_length=32, clip_stride=2,
|
442 |
+
sparse_sample=False,
|
443 |
+
is_trimmed=True,
|
444 |
+
):
|
445 |
+
super().__init__(dataset, root, metadata, is_trimmed=is_trimmed)
|
446 |
+
|
447 |
+
self.transform = transform
|
448 |
+
self.is_training = is_training
|
449 |
+
self.label_mapping = label_mapping
|
450 |
+
self.num_clips = num_clips
|
451 |
+
self.clip_length = clip_length
|
452 |
+
self.clip_stride = clip_stride
|
453 |
+
self.sparse_sample = sparse_sample
|
454 |
+
|
455 |
+
def __getitem__(self, i):
|
456 |
+
frames, label, vid_path = self.get_raw_item(
|
457 |
+
i, is_training=self.is_training,
|
458 |
+
num_clips=self.num_clips,
|
459 |
+
clip_length=self.clip_length,
|
460 |
+
clip_stride=self.clip_stride,
|
461 |
+
sparse_sample=self.sparse_sample,
|
462 |
+
)
|
463 |
+
|
464 |
+
# apply transformation
|
465 |
+
if self.transform is not None:
|
466 |
+
frames = self.transform(frames)
|
467 |
+
|
468 |
+
if self.label_mapping is not None:
|
469 |
+
if isinstance(label, list):
|
470 |
+
# multi-label case
|
471 |
+
res_array = np.zeros(len(self.label_mapping))
|
472 |
+
for lbl in label:
|
473 |
+
res_array[self.label_mapping[lbl]] = 1.
|
474 |
+
label = res_array
|
475 |
+
else:
|
476 |
+
label = self.label_mapping[label]
|
477 |
+
|
478 |
+
return frames, label, vid_path
|
479 |
+
|
480 |
+
|
481 |
+
def get_dataset(train_transform, tokenizer, cfg, is_training=True):
|
482 |
+
narration_selection = cfg.get('narration_selection', 'random')
|
483 |
+
num_hard_neg = cfg.get('num_hard_neg', 0)
|
484 |
+
data_cfg = cfg['data']
|
485 |
+
if cfg['model']['arch'].startswith('CLIP') or cfg['model']['arch'].startswith('VCLM'):
|
486 |
+
if is_training:
|
487 |
+
metadata = data_cfg['metadata']
|
488 |
+
else:
|
489 |
+
metadata = data_cfg['metadata_val']
|
490 |
+
|
491 |
+
return VideoCaptionDatasetCLIP(
|
492 |
+
data_cfg['dataset'], data_cfg['root'], metadata, train_transform,
|
493 |
+
is_training=is_training,
|
494 |
+
tokenizer=tokenizer,
|
495 |
+
clip_length=data_cfg['clip_length'], clip_stride=data_cfg['clip_stride'],
|
496 |
+
sparse_sample=data_cfg['sparse_sample'],
|
497 |
+
narration_selection=narration_selection,
|
498 |
+
num_hard_negatives=num_hard_neg
|
499 |
+
)
|
500 |
+
else:
|
501 |
+
raise NotImplementedError
|
502 |
+
|
503 |
+
|
504 |
+
def get_downstream_dataset(transform, tokenizer, cfg, is_training=True, num_clips=0, label_mapping=None):
|
505 |
+
data_cfg = cfg['data']
|
506 |
+
n_clips = num_clips if num_clips > 0 else data_cfg['num_clips']
|
507 |
+
if is_training:
|
508 |
+
metadata = data_cfg['metadata']
|
509 |
+
return VideoClassyDataset(
|
510 |
+
data_cfg['dataset'], data_cfg['root'], metadata, transform,
|
511 |
+
is_training=True, label_mapping=label_mapping,
|
512 |
+
num_clips=n_clips,
|
513 |
+
clip_length=data_cfg['clip_length'], clip_stride=data_cfg['clip_stride'],
|
514 |
+
sparse_sample=data_cfg['sparse_sample'],
|
515 |
+
)
|
516 |
+
else:
|
517 |
+
metadata = data_cfg['metadata_val']
|
518 |
+
return VideoClassyDataset(
|
519 |
+
data_cfg['dataset'], data_cfg['root'], metadata, transform,
|
520 |
+
is_training=False, label_mapping=label_mapping,
|
521 |
+
num_clips=n_clips,
|
522 |
+
clip_length=data_cfg['clip_length'], clip_stride=data_cfg['clip_stride'],
|
523 |
+
sparse_sample=data_cfg['sparse_sample'],
|
524 |
+
is_trimmed=not data_cfg['dataset'] == 'charades_ego'
|
525 |
+
)
|
526 |
+
|
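Usage note (not part of the commit): a minimal sketch of building a CharadesEgo evaluation dataset with VideoClassyDataset. The metadata path and the toy label_mapping below are placeholders, not files or values defined in this repository.

from svitt.datasets import VideoClassyDataset

val_dataset = VideoClassyDataset(
    'charades_ego',                                   # handled by get_raw_item above
    root='data/charades_ego/video',                   # directory of .mp4 files in this repo
    metadata='data/charades_ego/metadata_val.csv',    # hypothetical metadata CSV
    transform=None,
    is_training=False,
    label_mapping={'c093': 0},                        # toy code->index map; see preprocess.generate_label_map
    num_clips=1,
    clip_length=16,
    clip_stride=4,
    sparse_sample=True,
    is_trimmed=False,
)
frames, label, vid_path = val_dataset[0]              # frames: (T, H, W, C) clip tensor before transform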
svitt/evaluation.py
ADDED
@@ -0,0 +1,36 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch


def accuracy(output, target, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)

        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()
        correct = pred.eq(target.reshape(1, -1).expand_as(pred))

        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


def get_mean_accuracy(cm):
    list_acc = []
    for i in range(len(cm)):
        acc = 0
        if cm[i, :].sum() > 0:
            acc = cm[i, i] / cm[i, :].sum()
        list_acc.append(acc)

    return 100 * np.mean(list_acc), 100 * np.trace(cm) / np.sum(cm)
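Usage note (illustrative only, not part of the commit): accuracy() takes raw class scores and integer targets and returns top-k percentages; the tensors below are random stand-ins.

import torch
from svitt.evaluation import accuracy

logits = torch.randn(8, 157)           # (batch, num_classes) scores
target = torch.randint(0, 157, (8,))   # ground-truth class indices
top1, top5 = accuracy(logits, target, topk=(1, 5))  # each a 1-element tensor in percent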
svitt/evaluation_charades.py
ADDED
@@ -0,0 +1,56 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np


def compute_map(submission_array, gt_array):
    """ Returns mAP, weighted mAP, and AP array """
    m_aps = []
    n_classes = submission_array.shape[1]
    for oc_i in range(n_classes):
        sorted_idxs = np.argsort(-submission_array[:, oc_i])
        tp = gt_array[:, oc_i][sorted_idxs] == 1
        fp = np.invert(tp)
        n_pos = tp.sum()
        if n_pos < 0.1:
            m_aps.append(float('nan'))
            continue
        fp.sum()
        f_pcs = np.cumsum(fp)
        t_pcs = np.cumsum(tp)
        prec = t_pcs / (f_pcs + t_pcs).astype(float)
        avg_prec = 0
        for i in range(submission_array.shape[0]):
            if tp[i]:
                avg_prec += prec[i]
        m_aps.append(avg_prec / n_pos.astype(float))
    m_aps = np.array(m_aps)
    # m_ap = np.mean(m_aps)
    m_ap = m_aps[~np.isnan(m_aps)]
    print(f'num of available classes: {len(m_ap)}')
    m_ap = m_ap.mean()  # compute mean w/o nan
    w_ap = (m_aps * gt_array.sum(axis=0) / gt_array.sum().sum().astype(float))
    return m_ap, w_ap, m_aps


def charades_map(submission_array, gt_array):
    """
    Approximate version of the charades evaluation function
    For precise numbers, use the submission file with the official matlab script
    """
    fix = submission_array.copy()
    empty = np.sum(gt_array, axis=1) == 0
    fix[empty, :] = np.NINF
    return compute_map(fix, gt_array)


def create_submission(video_list, predictions, out_file):
    assert len(video_list) == predictions.shape[0]
    with open(out_file, 'w') as f:
        for i, video_id in enumerate(video_list):
            pred_str = ' '.join(map(lambda x: str(x), predictions[i].tolist()))
            f.write('{} {}\n\n'.format(video_id, pred_str))
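Usage note (illustrative only, not part of the commit): charades_map() expects per-video class scores and multi-label ground truth of the same shape; the random arrays below are placeholders. Classes with no positives get NaN AP and are skipped in the mean.

import numpy as np
from svitt.evaluation_charades import charades_map

scores = np.random.rand(10, 157)                      # (num_videos, num_classes) scores
gt = (np.random.rand(10, 157) > 0.95).astype(float)   # binary multi-label ground truth
m_ap, w_ap, per_class_ap = charades_map(scores, gt)   # mean AP, weighted AP, per-class AP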
svitt/model.py
ADDED
@@ -0,0 +1,340 @@
from svitt.utils import (
    interpolate_pos_embed,
    interpolate_pos_relative_bias_beit_3d,
)
from omegaconf import OmegaConf
from transformers import ViTModel, ViTConfig
from svitt.sparse_config import BertConfig, BeitConfig
from svitt.sparse_xbeit import BeitModel
from svitt.sparse_xbert import BertModel, BertForMaskedLM

import torch
from torch import nn
import torch.nn.functional as F


class SViTT(nn.Module):
    """Common utils shared by pretraining and downstream retrieval"""
    def __init__(self, config=None, tokenizer=None, pretrain=True, **kwargs):
        super().__init__()
        self.config = config
        self.tokenizer = tokenizer
        self.embed_dim = config.embed_dim
        self.vision_width = 768
        self.text_width = 768
        self.pretrain = pretrain

        self.vision_encoder, self.vision_layernorm = self.build_vision_encoder()
        self.text_encoder = self.build_text_encoder()

        self.vision_proj = nn.Linear(self.vision_width, self.embed_dim)
        self.text_proj = nn.Linear(self.text_width, self.embed_dim)

        self.temp = nn.Parameter(torch.ones([]) * config.temp)
        self.itm_head = nn.Linear(self.text_width, 2)

    def build_text_encoder(self):

        bert_config = BertConfig.from_json_file(self.config.bert_config)

        # Override params for sparse vision encoder
        model_args = getattr(self.config, 'text_encoder_args', {})
        if model_args:
            model_args = OmegaConf.to_object(model_args)
            bert_config.update(model_args)

        if self.pretrain:
            text_encoder, _ = BertForMaskedLM.from_pretrained(
                self.config.text_encoder, config=bert_config,
                output_loading_info=True
            )
        else:
            text_encoder, _ = BertModel.from_pretrained(
                self.config.text_encoder, config=bert_config,
                add_pooling_layer=False, output_loading_info=True
            )
        return text_encoder

    def build_vision_encoder(self):
        # if self.config.vit_type in ["beit", "deit", "vit", "vit32"]:
        if self.config.vit_type in ["beit"]:
            vision_encoder = self.build_huggingface_vit_with_image_size(
                self.config.vit_name_or_pretrained_path, self.config.image_res,)
        else:
            raise ValueError(f"Unknown vit type {self.config.vit_type}")

        # add layernorm for normalizing BEiT outputs hidden states
        vision_layernorm = None
        if self.config.vit_type == "beit":
            vision_layernorm = nn.LayerNorm(self.vision_width, eps=1e-12)
        return vision_encoder, vision_layernorm

    # @classmethod
    # def build_huggingface_vit_with_image_size(cls, model_card: str, image_size: int):
    def build_huggingface_vit_with_image_size(self, model_card: str, image_size: int):
        """Build a vit model from huggingface hub, also interpolate pos_embed when needed.

        Args:
            model_card: name in huggingface hub, e.g., `facebook/deit-base-patch16-224`
            image_size: new image size, may be different from pre-training image_size of `model_card`

        ref: https://github.com/huggingface/transformers/issues/12167#issuecomment-861356232
        """
        is_beit = "beit" in model_card
        if "beit" in model_card:
            model_cls, config_cls = BeitModel, BeitConfig
        elif "deit" in model_card or "vit" in model_card:
            # the deit model we use is loaded in vit arch,
            # see https://huggingface.co/facebook/deit-base-patch16-224#how-to-use
            model_cls, config_cls = ViTModel, ViTConfig
        else:
            raise ValueError(f"Unexpected model_card: {model_card}")

        # BEiT uses average pooled tokens instead of [CLS] used by other models
        tmp_model = model_cls.from_pretrained(model_card, add_pooling_layer=is_beit)
        state_dict = tmp_model.state_dict()
        del tmp_model

        # Override params for sparse vision encoder
        model_args = getattr(self.config, 'vision_encoder_args', {})
        if model_args:
            model_args = OmegaConf.to_object(model_args)
        model_config = config_cls.from_pretrained(
            model_card,
            image_size=image_size,
            **model_args,
        )
        model = model_cls(config=model_config, add_pooling_layer=is_beit, num_frames=self.config.video_input.num_frames)
        if is_beit:
            # interpolate relative pos bias
            state_dict = interpolate_pos_relative_bias_beit_3d(
                state_dict_old=state_dict,
                state_dict_new=model.state_dict(),
                patch_shape_new=model.window_size
            )
        else:
            # interpolate pos_embed and load weights to new model
            state_dict["embeddings.position_embeddings"] = interpolate_pos_embed(
                pos_embed_old=state_dict["embeddings.position_embeddings"],
                pos_embed_new=model.embeddings.position_embeddings,
                num_patches_new=model.embeddings.patch_embeddings.num_patches
            )
        msg = model.load_state_dict(state_dict, strict=False)
        return model

    def get_text_encoder(self):
        """get text encoder, used for text and cross-modal encoding"""
        encoder = self.text_encoder
        return encoder.bert if hasattr(encoder, "bert") else encoder

    def encode_image(self, video, output_token_idx=False, output_attentions=False):
        video_embeds = self.vision_encoder(video, output_token_idx=output_token_idx, output_attentions=output_attentions)  # (bsz, seq_len, d)
        if self.vision_layernorm is not None:  # only for BEiT mean-pooling
            video_embeds.last_hidden_state = self.vision_layernorm(video_embeds.last_hidden_state)
        if output_token_idx:
            token_idx = video_embeds.token_idx

        if output_attentions:
            attentions = video_embeds.attentions

        if self.config.vit_type == "beit":
            pooled_video_embeds = video_embeds.pooler_output  # (bsz*num_frms, d)
            video_embeds = video_embeds.last_hidden_state  # (bsz*num_frms, L, d)
        else:
            video_embeds = video_embeds.last_hidden_state
            pooled_video_embeds = video_embeds[:, 0]

        outputs = (video_embeds, pooled_video_embeds)

        if output_token_idx:
            outputs += (token_idx,)

        if output_attentions:
            outputs += (attentions,)

        return outputs

    def _encode_image(self, image):
        bsz, num_frms, c, h, w = image.shape  # `num_frms` could be changing for image (=1) or video (e.g., =4)
        image = image.view(bsz*num_frms, c, h, w)
        image_embeds = self.vision_encoder(image)
        if self.vision_layernorm is not None:  # only for BEiT mean-pooling
            image_embeds.last_hidden_state = self.vision_layernorm(image_embeds.last_hidden_state)

        if self.config.vit_type == "beit":
            pooled_image_embeds = image_embeds.pooler_output  # (bsz*num_frms, d)
            image_embeds = image_embeds.last_hidden_state  # (bsz*num_frms, L, d)
        else:
            image_embeds = image_embeds.last_hidden_state
            pooled_image_embeds = image_embeds[:, 0]

        image_embeds = image_embeds.view(bsz, num_frms, -1, self.vision_width)  # (bsz, num_frms, L, d)
        pooled_image_embeds = pooled_image_embeds.view(bsz, num_frms, self.vision_width) \
            if pooled_image_embeds is not None else None  # (bsz, num_frms, d)
        return image_embeds, pooled_image_embeds

    def encode_text(self, text):
        text_output = self.get_text_encoder()(
            text.input_ids,
            attention_mask=text.attention_mask,
            return_dict=True,
            mode='text'
        )
        text_embeds = text_output.last_hidden_state
        pooled_text_embeds = text_embeds[:, 0]
        return text_embeds, pooled_text_embeds

    @torch.no_grad()
    def clip_contrastive_temperature(self, min_val=0.001, max_val=0.5):
        """Seems only used during pre-training"""
        self.temp.clamp_(min_val, max_val)

    @torch.no_grad()
    def get_mask(self, sim, idx=None, normalize=False):
        """
        sim: (N, N)
        idx: (N, )
        normalize: bool, make row sum equal to 1
        """
        if idx is not None:
            idx = idx.view(-1, 1)
            mask = torch.eq(idx, idx.T).to(sim.dtype)
            if normalize:
                mask = mask / mask.sum(1, keepdim=True)
        else:
            mask = torch.zeros_like(sim)
            mask.fill_diagonal_(1)
        return mask  # `1` mark valid/matched location

    def get_contrastive_loss(self, pooled_image_embeds, pooled_text_embeds, idx=None):
        sim_i2t, sim_t2i = self.get_sim(
            pooled_image_embeds, pooled_text_embeds, t=self.temp)

        with torch.no_grad():
            sim_i2t_targets = self.get_mask(sim_i2t, idx=idx, normalize=True)
            sim_t2i_targets = sim_i2t_targets

        loss_i2t = -torch.sum(
            F.log_softmax(sim_i2t, dim=1) * sim_i2t_targets, dim=1).mean()
        loss_t2i = -torch.sum(
            F.log_softmax(sim_t2i, dim=1) * sim_t2i_targets, dim=1).mean()

        loss_ita = (loss_i2t + loss_t2i) / 2
        return loss_ita, sim_i2t, sim_t2i

    def get_sim(self, pooled_image_embeds, pooled_text_embeds, t=1):
        """
        Args:
            pooled_image_embeds: (bsz, num_frms, d)
            pooled_text_embeds: (bsz, d)
            t: temperature
        """
        image_proj = self.vision_proj
        text_proj = self.text_proj

        image_feat = F.normalize(image_proj(pooled_image_embeds), dim=-1)
        text_feat = F.normalize(text_proj(pooled_text_embeds), dim=-1)

        if image_feat.ndim == 3:
            sim_i2t = torch.einsum("mld,nd->mln", image_feat, text_feat).mean(1) / t  # (N, N)
        else:
            sim_i2t = torch.einsum("md,nd ->mn", image_feat, text_feat) / t  # (N, N)
        sim_t2i = sim_i2t.T
        return sim_i2t, sim_t2i

    def get_itm_loss(self,
                     sim_i2t,
                     sim_t2i,
                     text_embeds,
                     text_atts,
                     image_embeds,
                     image_atts,
                     idx=None,
                     ):
        """
        sim_i2t, sim_t2i: (N, N)
        text_embeds, text_atts, image_embeds, image_atts: (N, *)
        idx: (N, )
        """
        bsz = len(sim_i2t)

        with torch.no_grad():
            weights_i2t = F.softmax(sim_i2t+1e-4, dim=1)  # (N, N)
            weights_t2i = F.softmax(sim_t2i+1e-4, dim=1)

            mask = self.get_mask(sim_i2t, idx=idx).bool()
            weights_i2t.masked_fill_(mask, 0)
            weights_t2i.masked_fill_(mask, 0)

        # select a negative image for each text
        if self.config.itm_hard_neg:
            img_neg_indices = torch.multinomial(weights_t2i, 1).squeeze()  # RuntimeError: invalid multinomial distribution (sum of probabilities <= 0)
        else:
            img_neg_indices = self.get_rand_indices(mask, 1).squeeze()

        image_embeds_neg = image_embeds[img_neg_indices]

        # select a negative text for each image
        if self.config.itm_hard_neg:
            txt_neg_indices = torch.multinomial(weights_i2t, 1).squeeze()
        else:
            txt_neg_indices = self.get_rand_indices(mask, 1).squeeze()

        text_embeds_neg = text_embeds[txt_neg_indices]
        text_atts_neg = text_atts[txt_neg_indices]  # (N, L, d)

        # embedding on local gpu
        _text_embeds = text_embeds
        _text_atts = text_atts
        _image_embeds = image_embeds
        _image_atts = image_atts
        # concat embeddings
        text_embeds_all = torch.cat([_text_embeds, _text_embeds, text_embeds_neg], dim=0)
        text_atts_all = torch.cat([_text_atts, _text_atts, text_atts_neg], dim=0)
        image_embeds_all = torch.cat([_image_embeds, image_embeds_neg, _image_embeds], dim=0)
        image_atts_all = torch.cat([_image_atts, _image_atts, _image_atts], dim=0)

        text_encoder = self.get_text_encoder()
        output = text_encoder(
            encoder_embeds=text_embeds_all,
            attention_mask=text_atts_all,
            encoder_hidden_states=image_embeds_all,
            encoder_attention_mask=image_atts_all,
            return_dict=True,
            mode='fusion',
        )

        itm_embeds = output.last_hidden_state[:, 0]  # pos (N, d) + neg (2N, d)

        loss_itm = self._get_itm_loss(itm_embeds, enc=self.itm_head)
        itm_embeds_pos = itm_embeds[:bsz]  # (N, d)

        return loss_itm, itm_embeds_pos

    def _get_itm_loss(self, itm_embeds, enc):
        """
        itm_embeds: (3*N, D)
        enc: nn.Module that projects cls_embeds
        """
        itm_scores = enc(itm_embeds)  # (3*N, 2)
        bs = itm_scores.size(0) // 3
        itm_labels = itm_scores.new_ones(3*bs, dtype=torch.long)
        itm_labels[bs:] = 0
        loss_itm = F.cross_entropy(itm_scores, itm_labels)
        return loss_itm

    def get_rand_indices(self, mask, k):
        """
        Args:
            mask: (N, L) 0 indicates the positions that we can sample, 1 otherwise
            k: #indices to sample at each row
        Returns:
            (N, k) indices
        """
        mask = mask.float()
        mask = mask - 10000 * mask
        mask += torch.randn_like(mask)
        _, indices = torch.sort(mask, dim=1, descending=True)
        indices = indices[:, :k].contiguous()
        return indices
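Note (illustrative only, not part of the commit): the video-text contrastive path in SViTT.get_sim / get_contrastive_loss projects pooled frame and text features, L2-normalizes them, averages the per-frame similarities, and applies a softmax cross-entropy against the diagonal. The standalone re-implementation below uses random tensors and freshly initialized projections purely to show the shapes.

import torch
import torch.nn.functional as F

bsz, num_frms, d, embed_dim, temp = 4, 8, 768, 256, 0.07
vision_proj = torch.nn.Linear(d, embed_dim)
text_proj = torch.nn.Linear(d, embed_dim)

pooled_video = torch.randn(bsz, num_frms, d)   # per-frame pooled vision features
pooled_text = torch.randn(bsz, d)              # [CLS] text features

v = F.normalize(vision_proj(pooled_video), dim=-1)
t = F.normalize(text_proj(pooled_text), dim=-1)
sim_i2t = torch.einsum("mld,nd->mln", v, t).mean(1) / temp   # (bsz, bsz), frame-averaged
sim_t2i = sim_i2t.T

targets = torch.eye(bsz)                       # diagonal matches when no idx is given
loss_ita = (-(F.log_softmax(sim_i2t, 1) * targets).sum(1).mean()
            - (F.log_softmax(sim_t2i, 1) * targets).sum(1).mean()) / 2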
svitt/preprocess.py
ADDED
@@ -0,0 +1,86 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import csv

from lavila.models.tokenizer import MyBertTokenizer, MyDistilBertTokenizer, MyGPT2Tokenizer, SimpleTokenizer


def generate_label_map(dataset):
    if dataset == 'ek100_cls':
        print("Preprocess ek100 action label space")
        vn_list = []
        mapping_vn2narration = {}
        for f in [
            '/data/EK100/epic-kitchens-100-annotations/EPIC_100_train.csv',
            '/data/EK100/epic-kitchens-100-annotations/EPIC_100_validation.csv',
        ]:
            csv_reader = csv.reader(open(f))
            _ = next(csv_reader)  # skip the header
            for row in csv_reader:
                vn = '{}:{}'.format(int(row[10]), int(row[12]))
                narration = row[8]
                if vn not in vn_list:
                    vn_list.append(vn)
                if vn not in mapping_vn2narration:
                    mapping_vn2narration[vn] = [narration]
                else:
                    mapping_vn2narration[vn].append(narration)
                # mapping_vn2narration[vn] = [narration]
        vn_list = sorted(vn_list)
        print('# of action= {}'.format(len(vn_list)))
        mapping_vn2act = {vn: i for i, vn in enumerate(vn_list)}
        labels = [list(set(mapping_vn2narration[vn_list[i]])) for i in range(len(mapping_vn2act))]
        print(labels[:5])
    elif dataset == 'charades_ego':
        print("=> preprocessing charades_ego action label space")
        vn_list = []
        labels = []
        with open('data/charades_ego/Charades_v1_classes.txt') as f:
            csv_reader = csv.reader(f)
            for row in csv_reader:
                vn = row[0][:4]
                vn_list.append(vn)
                narration = row[0][5:]
                labels.append(narration)
        mapping_vn2act = {vn: i for i, vn in enumerate(vn_list)}
        print(labels[:5])
    elif dataset == 'egtea':
        print("=> preprocessing egtea action label space")
        labels = []
        with open('/data/EGTEA/action_idx.txt') as f:
            for row in f:
                row = row.strip()
                narration = ' '.join(row.split(' ')[:-1])
                labels.append(narration.replace('_', ' ').lower())
                # labels.append(narration)
        mapping_vn2act = {label: i for i, label in enumerate(labels)}
        print(len(labels), labels[:5])
    else:
        raise NotImplementedError
    return labels, mapping_vn2act


def generate_tokenizer(model):
    if model.endswith('DISTILBERT_BASE'):
        tokenizer = MyDistilBertTokenizer('distilbert-base-uncased')
    elif model.endswith('BERT_BASE'):
        tokenizer = MyBertTokenizer('bert-base-uncased')
    elif model.endswith('BERT_LARGE'):
        tokenizer = MyBertTokenizer('bert-large-uncased')
    elif model.endswith('GPT2'):
        tokenizer = MyGPT2Tokenizer('gpt2', add_bos=True)
    elif model.endswith('GPT2_MEDIUM'):
        tokenizer = MyGPT2Tokenizer('gpt2-medium', add_bos=True)
    elif model.endswith('GPT2_LARGE'):
        tokenizer = MyGPT2Tokenizer('gpt2-large', add_bos=True)
    elif model.endswith('GPT2_XL'):
        tokenizer = MyGPT2Tokenizer('gpt2-xl', add_bos=True)
    else:
        print("Using SimpleTokenizer because of model '{}'. "
              "Please check if this is what you want".format(model))
        tokenizer = SimpleTokenizer()
    return tokenizer
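Usage note (illustrative only, not part of the commit): generate_label_map('charades_ego') reads data/charades_ego/Charades_v1_classes.txt, which ships with this repo, so it is expected to be called from the repository root; this also assumes the lavila tokenizer dependency imported at the top of svitt/preprocess.py is installed.

from svitt.preprocess import generate_label_map

labels, mapping_vn2act = generate_label_map('charades_ego')
# labels[i]: class description string; mapping_vn2act: code like 'c093' -> class index
print(len(labels), mapping_vn2act['c093'])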
svitt/sparse_config.py
ADDED
@@ -0,0 +1,351 @@
# coding=utf-8
# Copyright Microsoft Research and The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import OrderedDict
from typing import Mapping

from transformers.configuration_utils import PretrainedConfig
from transformers.onnx import OnnxConfig


BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "bert-base-uncased": "https://huggingface.co/bert-base-uncased/resolve/main/config.json",
    "bert-large-uncased": "https://huggingface.co/bert-large-uncased/resolve/main/config.json",
    "bert-base-cased": "https://huggingface.co/bert-base-cased/resolve/main/config.json",
    "bert-large-cased": "https://huggingface.co/bert-large-cased/resolve/main/config.json",
    "bert-base-multilingual-uncased": "https://huggingface.co/bert-base-multilingual-uncased/resolve/main/config.json",
    "bert-base-multilingual-cased": "https://huggingface.co/bert-base-multilingual-cased/resolve/main/config.json",
    "bert-base-chinese": "https://huggingface.co/bert-base-chinese/resolve/main/config.json",
    "bert-base-german-cased": "https://huggingface.co/bert-base-german-cased/resolve/main/config.json",
    "bert-large-uncased-whole-word-masking": "https://huggingface.co/bert-large-uncased-whole-word-masking/resolve/main/config.json",
    "bert-large-cased-whole-word-masking": "https://huggingface.co/bert-large-cased-whole-word-masking/resolve/main/config.json",
    "bert-large-uncased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad/resolve/main/config.json",
    "bert-large-cased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/config.json",
    "bert-base-cased-finetuned-mrpc": "https://huggingface.co/bert-base-cased-finetuned-mrpc/resolve/main/config.json",
    "bert-base-german-dbmdz-cased": "https://huggingface.co/bert-base-german-dbmdz-cased/resolve/main/config.json",
    "bert-base-german-dbmdz-uncased": "https://huggingface.co/bert-base-german-dbmdz-uncased/resolve/main/config.json",
    "cl-tohoku/bert-base-japanese": "https://huggingface.co/cl-tohoku/bert-base-japanese/resolve/main/config.json",
    "cl-tohoku/bert-base-japanese-whole-word-masking": "https://huggingface.co/cl-tohoku/bert-base-japanese-whole-word-masking/resolve/main/config.json",
    "cl-tohoku/bert-base-japanese-char": "https://huggingface.co/cl-tohoku/bert-base-japanese-char/resolve/main/config.json",
    "cl-tohoku/bert-base-japanese-char-whole-word-masking": "https://huggingface.co/cl-tohoku/bert-base-japanese-char-whole-word-masking/resolve/main/config.json",
    "TurkuNLP/bert-base-finnish-cased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-cased-v1/resolve/main/config.json",
    "TurkuNLP/bert-base-finnish-uncased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-uncased-v1/resolve/main/config.json",
    "wietsedv/bert-base-dutch-cased": "https://huggingface.co/wietsedv/bert-base-dutch-cased/resolve/main/config.json",
    # See all BERT models at https://huggingface.co/models?filter=bert
}


class BertConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`BertModel`] or a
    [`TFBertModel`]. It is used to instantiate a BERT model according to the specified arguments,
    defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration
    to that of the BERT [bert-base-uncased](https://huggingface.co/bert-base-uncased) architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model
    outputs. Read the documentation from [`PretrainedConfig`] for more information.


    Args:
        vocab_size (`int`, *optional*, defaults to 30522):
            Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the
            `inputs_ids` passed when calling [`BertModel`] or
            [`TFBertModel`].
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
        hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string,
            `"gelu"`, `"relu"`, `"silu"` and `"gelu_new"` are supported.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        max_position_embeddings (`int`, *optional*, defaults to 512):
            The maximum sequence length that this model might ever be used with. Typically set this to something large
            just in case (e.g., 512 or 1024 or 2048).
        type_vocab_size (`int`, *optional*, defaults to 2):
            The vocabulary size of the `token_type_ids` passed when calling [`BertModel`] or
            [`TFBertModel`].
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
            Type of position embedding. Choose one of `"absolute"`, `"relative_key"`,
            `"relative_key_query"`. For positional embeddings use `"absolute"`. For more information on
            `"relative_key"`, please refer to [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155). For more information on `"relative_key_query"`, please refer to
            *Method 4* in [Improve Transformer Models with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether or not the model should return the last key/values attentions (not used by all models). Only
            relevant if `config.is_decoder=True`.
        classifier_dropout (`float`, *optional*):
            The dropout ratio for the classification head.

    Examples:

    ```python
    >>> from transformers import BertModel, BertConfig

    >>> # Initializing a BERT bert-base-uncased style configuration
    >>> configuration = BertConfig()

    >>> # Initializing a model from the bert-base-uncased style configuration
    >>> model = BertModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""
    model_type = "bert"

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        pad_token_id=0,
        position_embedding_type="absolute",
        use_cache=True,
        classifier_dropout=None,
        token_keep_rate=1,
        token_keep_strategy='cls_attn',
        token_drop_loc=[9],
        **kwargs
    ):
        super().__init__(pad_token_id=pad_token_id, **kwargs)

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.position_embedding_type = position_embedding_type
        self.use_cache = use_cache
        self.classifier_dropout = classifier_dropout
        self.token_keep_rate = token_keep_rate
        self.token_keep_strategy = token_keep_strategy
        self.token_drop_loc = token_drop_loc


class BertOnnxConfig(OnnxConfig):
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        return OrderedDict(
            [
                ("input_ids", {0: "batch", 1: "sequence"}),
                ("attention_mask", {0: "batch", 1: "sequence"}),
                ("token_type_ids", {0: "batch", 1: "sequence"}),
            ]
        )


BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "microsoft/beit-base-patch16-224-in22k": "https://huggingface.co/microsoft/beit-base-patch16-224-in22k/resolve/main/config.json",
    # See all BEiT models at https://huggingface.co/models?filter=beit
}


class BeitConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`BeitModel`]. It is used to
    instantiate an BEiT model according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the BEiT
    [microsoft/beit-base-patch16-224-in22k](https://huggingface.co/microsoft/beit-base-patch16-224-in22k)
    architecture.

    Args:
        vocab_size (`int`, *optional*, defaults to 8092):
            Vocabulary size of the BEiT model. Defines the number of different image tokens that can be used during
            pre-training.
        hidden_size (`int`, *optional*, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (`int`, *optional*, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (`int`, *optional*, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (`int`, *optional*, defaults to 3072):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
            The non-linear activation function (function or string) in the encoder and pooler. If string,
            `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
        hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (`float`, *optional*, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        image_size (`int`, *optional*, defaults to `224`):
            The size (resolution) of each image.
        patch_size (`int`, *optional*, defaults to `16`):
            The size (resolution) of each patch.
        num_channels (`int`, *optional*, defaults to `3`):
            The number of input channels.
        use_mask_token (`bool`, *optional*, defaults to `False`):
            Whether to use a mask token for masked image modeling.
        use_absolute_position_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to use BERT-style absolute position embeddings.
        use_relative_position_bias (`bool`, *optional*, defaults to `False`):
            Whether to use T5-style relative position embeddings in the self-attention layers.
        use_shared_relative_position_bias (`bool`, *optional*, defaults to `False`):
            Whether to use the same relative position embeddings across all self-attention layers of the Transformer.
        layer_scale_init_value (`float`, *optional*, defaults to 0.1):
            Scale to use in the self-attention layers. 0.1 for base, 1e-5 for large. Set 0 to disable layer scale.
        drop_path_rate (`float`, *optional*, defaults to 0.1):
            Stochastic depth rate per sample (when applied in the main path of residual layers).
        use_mean_pooling (`bool`, *optional*, defaults to `True`):
            Whether to mean pool the final hidden states of the patches instead of using the final hidden state of the
            CLS token, before applying the classification head.
        out_indices (`List[int]`, *optional*, defaults to `[3, 5, 7, 11]`):
            Indices of the feature maps to use for semantic segmentation.
        pool_scales (`Tuple[int]`, *optional*, defaults to `[1, 2, 3, 6]`):
            Pooling scales used in Pooling Pyramid Module applied on the last feature map.
        use_auxiliary_head (`bool`, *optional*, defaults to `True`):
            Whether to use an auxiliary head during training.
        auxiliary_loss_weight (`float`, *optional*, defaults to 0.4):
            Weight of the cross-entropy loss of the auxiliary head.
        auxiliary_channels (`int`, *optional*, defaults to 256):
            Number of channels to use in the auxiliary head.
        auxiliary_num_convs (`int`, *optional*, defaults to 1):
            Number of convolutional layers to use in the auxiliary head.
        auxiliary_concat_input (`bool`, *optional*, defaults to `False`):
            Whether to concatenate the output of the auxiliary head with the input before the classification layer.
        semantic_loss_ignore_index (`int`, *optional*, defaults to 255):
            The index that is ignored by the loss function of the semantic segmentation model.

    Example:

    ```python
    >>> from transformers import BeitModel, BeitConfig

    >>> # Initializing a BEiT beit-base-patch16-224-in22k style configuration
    >>> configuration = BeitConfig()

    >>> # Initializing a model from the beit-base-patch16-224-in22k style configuration
    >>> model = BeitModel(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""
    model_type = "beit"

    def __init__(
        self,
        vocab_size=8192,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        is_encoder_decoder=False,
        image_size=224,
        patch_size=16,
        num_channels=3,
        use_mask_token=False,
        use_absolute_position_embeddings=False,
        use_relative_position_bias=False,
        use_shared_relative_position_bias=False,
        layer_scale_init_value=0.1,
        drop_path_rate=0.1,
        use_mean_pooling=True,
        out_indices=[3, 5, 7, 11],
        pool_scales=[1, 2, 3, 6],
        use_auxiliary_head=True,
        auxiliary_loss_weight=0.4,
        auxiliary_channels=256,
        auxiliary_num_convs=1,
        auxiliary_concat_input=False,
        semantic_loss_ignore_index=255,
        token_keep_rate=1,
        token_keep_strategy='cls_attn',
        token_drop_loc=[3, 6, 9],
        sparse_random_attn=None,
        sparse_local_attn=1,
        attn_block_size=1,
        num_cls_tokens=1,
        token_3d_order='none',
        **kwargs
    ):
        super().__init__(**kwargs)

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.intermediate_size = intermediate_size
        self.hidden_act = hidden_act
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.use_mask_token = use_mask_token
        self.use_absolute_position_embeddings = use_absolute_position_embeddings
        self.use_relative_position_bias = use_relative_position_bias
        self.use_shared_relative_position_bias = use_shared_relative_position_bias
        self.layer_scale_init_value = layer_scale_init_value
        self.drop_path_rate = drop_path_rate
        self.use_mean_pooling = use_mean_pooling
        # decode head attributes (semantic segmentation)
        self.out_indices = out_indices
        self.pool_scales = pool_scales
        # auxiliary head attributes (semantic segmentation)
        self.use_auxiliary_head = use_auxiliary_head
        self.auxiliary_loss_weight = auxiliary_loss_weight
        self.auxiliary_channels = auxiliary_channels
        self.auxiliary_num_convs = auxiliary_num_convs
        self.auxiliary_concat_input = auxiliary_concat_input
        self.semantic_loss_ignore_index = semantic_loss_ignore_index

        # node sparsification
        self.token_keep_rate = token_keep_rate
        self.token_keep_strategy = token_keep_strategy
        self.token_drop_loc = token_drop_loc
        # edge sparsification
        self.sparse_random_attn = sparse_random_attn
        self.sparse_local_attn = sparse_local_attn
        self.attn_block_size = attn_block_size
        self.num_cls_tokens = num_cls_tokens
        # token order
        self.token_3d_order = token_3d_order
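Configuration note (illustrative only, not part of the commit): the extra BeitConfig fields group into node sparsification (token_keep_rate, token_keep_strategy, token_drop_loc) and edge sparsification (sparse_local_attn, sparse_random_attn, attn_block_size, num_cls_tokens). The values below are made-up examples, not the settings shipped in configs/charades_ego/svitt.yml.

from svitt.sparse_config import BeitConfig

cfg = BeitConfig(
    image_size=224,
    token_keep_rate=0.7,        # keep 70% of patch tokens at each drop location
    token_keep_strategy='cls_attn',
    token_drop_loc=[3, 6, 9],   # layers after which tokens are pruned
    sparse_local_attn=1,        # local attention window (assumed semantics)
    sparse_random_attn=2,       # extra randomly sampled attention edges (assumed semantics)
    num_cls_tokens=1,
)
print(cfg.token_keep_rate, cfg.token_drop_loc)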
svitt/sparse_xbeit.py
ADDED
@@ -0,0 +1,1585 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2021 Google AI, Ross Wightman, The HuggingFace Inc. team. All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" PyTorch BEiT model. """
|
16 |
+
|
17 |
+
|
18 |
+
import collections.abc
|
19 |
+
import math
|
20 |
+
import numpy as np
|
21 |
+
from dataclasses import dataclass
|
22 |
+
from typing import Optional, Tuple
|
23 |
+
import zCurve
|
24 |
+
import hilbert
|
25 |
+
|
26 |
+
import torch
|
27 |
+
import torch.utils.checkpoint
|
28 |
+
from torch import nn
|
29 |
+
from torch.nn import CrossEntropyLoss, MSELoss
|
30 |
+
from einops import rearrange, repeat
|
31 |
+
|
32 |
+
from transformers.activations import ACT2FN
|
33 |
+
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
|
34 |
+
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, MaskedLMOutput, SequenceClassifierOutput
|
35 |
+
from transformers.modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
|
36 |
+
from svitt.sparse_config import BeitConfig
|
37 |
+
|
38 |
+
|
39 |
+
_CONFIG_FOR_DOC = "BeitConfig"
|
40 |
+
_CHECKPOINT_FOR_DOC = "microsoft/beit-base-patch16-224"
|
41 |
+
|
42 |
+
BEIT_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
43 |
+
"microsoft/beit-base-patch16-224",
|
44 |
+
# See all BEiT models at https://huggingface.co/models?filter=beit
|
45 |
+
]
|
46 |
+
|
47 |
+
|
48 |
+
@dataclass
|
49 |
+
class BeitModelOutputWithPooling(BaseModelOutputWithPooling):
|
50 |
+
"""
|
51 |
+
Class for outputs of :class:`~transformers.BeitModel`.
|
52 |
+
|
53 |
+
Args:
|
54 |
+
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
55 |
+
Sequence of hidden-states at the output of the last layer of the model.
|
56 |
+
pooler_output (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, hidden_size)`):
|
57 |
+
Average of the last layer hidden states of the patch tokens (excluding the `[CLS]` token) if
|
58 |
+
`config.use_mean_pooling` is set to True. If set to False, then the final hidden state of the `[CLS]` token
|
59 |
+
will be returned.
|
60 |
+
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
61 |
+
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
62 |
+
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
63 |
+
|
64 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
65 |
+
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
66 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
67 |
+
sequence_length, sequence_length)`.
|
68 |
+
|
69 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
70 |
+
heads.
|
71 |
+
"""
|
72 |
+
token_idx: Optional[Tuple[torch.LongTensor]] = None
|
73 |
+
|
74 |
+
|
75 |
+
@dataclass
|
76 |
+
class BeitModelOutput(BaseModelOutput):
|
77 |
+
token_idx: Optional[Tuple[torch.LongTensor]] = None
|
78 |
+
|
79 |
+
|
80 |
+
# Inspired by
|
81 |
+
# https://github.com/rwightman/pytorch-image-models/blob/b9bd960a032c75ca6b808ddeed76bee5f3ed4972/timm/models/layers/helpers.py
|
82 |
+
# From PyTorch internals
|
83 |
+
def to_2tuple(x):
|
84 |
+
if isinstance(x, collections.abc.Iterable):
|
85 |
+
return x
|
86 |
+
return (x, x)
|
87 |
+
|
88 |
+
|
89 |
+
# Based on https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py
|
90 |
+
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
|
91 |
+
"""
|
92 |
+
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
93 |
+
|
94 |
+
Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
|
95 |
+
however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
96 |
+
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
|
97 |
+
layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
|
98 |
+
argument.
|
99 |
+
"""
|
100 |
+
if drop_prob == 0.0 or not training:
|
101 |
+
return x
|
102 |
+
keep_prob = 1 - drop_prob
|
103 |
+
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
104 |
+
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
|
105 |
+
random_tensor.floor_() # binarize
|
106 |
+
output = x.div(keep_prob) * random_tensor
|
107 |
+
return output
|
108 |
+
|
109 |
+
|
110 |
+
class DropPath(nn.Module):
|
111 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
112 |
+
|
113 |
+
def __init__(self, drop_prob=None):
|
114 |
+
super().__init__()
|
115 |
+
self.drop_prob = drop_prob
|
116 |
+
|
117 |
+
def forward(self, x):
|
118 |
+
return drop_path(x, self.drop_prob, self.training)
|
119 |
+
|
120 |
+
def extra_repr(self) -> str:
|
121 |
+
return "p={}".format(self.drop_prob)
|
122 |
+
|
123 |
+
|
124 |
+
# Based on timm implementation, which can be found here:
|
125 |
+
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
126 |
+
class BeitEmbeddings(nn.Module):
|
127 |
+
"""
|
128 |
+
Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
|
129 |
+
|
130 |
+
"""
|
131 |
+
|
132 |
+
def __init__(self, config):
|
133 |
+
super().__init__()
|
134 |
+
|
135 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
|
136 |
+
if config.use_mask_token:
|
137 |
+
self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
|
138 |
+
else:
|
139 |
+
self.mask_token = None
|
140 |
+
self.patch_embeddings = PatchEmbeddings(
|
141 |
+
image_size=config.image_size,
|
142 |
+
patch_size=config.patch_size,
|
143 |
+
num_channels=config.num_channels,
|
144 |
+
embed_dim=config.hidden_size,
|
145 |
+
)
|
146 |
+
num_patches = self.patch_embeddings.num_patches
|
147 |
+
if config.use_absolute_position_embeddings:
|
148 |
+
self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
|
149 |
+
else:
|
150 |
+
self.position_embeddings = None
|
151 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
152 |
+
|
153 |
+
def forward(self, pixel_values, bool_masked_pos=None):
|
154 |
+
|
155 |
+
if pixel_values.ndim == 5: # video input
|
156 |
+
embeddings = self.patch_embeddings(pixel_values.flatten(0, 1))
|
157 |
+
embeddings = rearrange(embeddings, '(b m) n d -> b (m n) d', m=pixel_values.shape[1])
|
158 |
+
else: # image input
|
159 |
+
embeddings = self.patch_embeddings(pixel_values)
|
160 |
+
|
161 |
+
batch_size, seq_len, _ = embeddings.size()
|
162 |
+
|
163 |
+
cls_tokens = self.cls_token.expand(batch_size, -1, -1)
|
164 |
+
if bool_masked_pos is not None:
|
165 |
+
mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
|
166 |
+
# replace the masked visual tokens by mask_tokens
|
167 |
+
w = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
|
168 |
+
embeddings = embeddings * (1 - w) + mask_tokens * w
|
169 |
+
|
170 |
+
embeddings = torch.cat((cls_tokens, embeddings), dim=1)
|
171 |
+
if self.position_embeddings is not None:
|
172 |
+
embeddings = embeddings + self.position_embeddings
|
173 |
+
embeddings = self.dropout(embeddings)
|
174 |
+
|
175 |
+
return embeddings
|
176 |
+
|
177 |
+
|
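A standalone sketch of the video branch above, showing only the reshapes (sizes are made up; the 196-token tensor stands in for the per-frame PatchEmbeddings output):

# Sketch of the 5-D (video) path in BeitEmbeddings.forward, toy sizes.
import torch
from einops import rearrange

B, M, C, H, W = 2, 4, 3, 224, 224                  # batch, frames, channels, height, width
pixel_values = torch.randn(B, M, C, H, W)

frames = pixel_values.flatten(0, 1)                # (B*M, C, H, W): frames folded into batch
per_frame_tokens = torch.randn(B * M, 196, 768)    # stand-in for patch_embeddings(frames)
tokens = rearrange(per_frame_tokens, '(b m) n d -> b (m n) d', m=M)
print(tokens.shape)                                # torch.Size([2, 784, 768])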
178 |
+
# Based on timm implementation, which can be found here:
|
179 |
+
# https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
|
180 |
+
class PatchEmbeddings(nn.Module):
|
181 |
+
"""
|
182 |
+
Image to Patch Embedding.
|
183 |
+
"""
|
184 |
+
|
185 |
+
def __init__(self, image_size=224, patch_size=16, num_channels=3, embed_dim=768):
|
186 |
+
super().__init__()
|
187 |
+
image_size = to_2tuple(image_size)
|
188 |
+
patch_size = to_2tuple(patch_size)
|
189 |
+
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
|
190 |
+
patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
|
191 |
+
self.image_size = image_size
|
192 |
+
self.patch_size = patch_size
|
193 |
+
self.num_patches = num_patches
|
194 |
+
self.patch_shape = patch_shape
|
195 |
+
|
196 |
+
self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
|
197 |
+
|
198 |
+
def forward(self, pixel_values):
|
199 |
+
batch_size, num_channels, height, width = pixel_values.shape
|
200 |
+
# FIXME look at relaxing size constraints
|
201 |
+
if height != self.image_size[0] or width != self.image_size[1]:
|
202 |
+
raise ValueError(
|
203 |
+
f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
|
204 |
+
)
|
205 |
+
x = self.projection(pixel_values).flatten(2).transpose(1, 2)
|
206 |
+
|
207 |
+
return x
|
208 |
+
|
209 |
+
|
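A quick shape check for the patch projection defined above (a strided Conv2d turns a 224x224 image into a 14x14 grid of patch tokens):

# Shape check for the Conv2d patch projection (illustrative values).
import torch
from torch import nn

proj = nn.Conv2d(3, 768, kernel_size=16, stride=16)
x = torch.randn(1, 3, 224, 224)
tokens = proj(x).flatten(2).transpose(1, 2)   # (1, 196, 768): 14*14 patches, dim 768
print(tokens.shape)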
210 |
+
class BeitSelfAttention(nn.Module):
|
211 |
+
def __init__(self, config, window_size=None):
|
212 |
+
super().__init__()
|
213 |
+
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
214 |
+
raise ValueError(
|
215 |
+
f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
|
216 |
+
f"heads {config.num_attention_heads}."
|
217 |
+
)
|
218 |
+
|
219 |
+
self.num_attention_heads = config.num_attention_heads
|
220 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
221 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
222 |
+
|
223 |
+
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
224 |
+
self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
|
225 |
+
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
226 |
+
|
227 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
228 |
+
|
229 |
+
# sparse params
|
230 |
+
self.random_attn = config.sparse_random_attn
|
231 |
+
self.local_attn = config.sparse_local_attn
|
232 |
+
self.block_size = config.attn_block_size
|
233 |
+
self.num_cls_tokens = config.num_cls_tokens
|
234 |
+
if self.local_attn is not None and self.random_attn is not None:
|
235 |
+
self.num_kv_blocks = self.local_attn + self.random_attn
|
236 |
+
|
237 |
+
if window_size:
|
238 |
+
self.relative_position_bias = BeitRelativePositionBias3D(config, window_size=window_size)
|
239 |
+
else:
|
240 |
+
self.relative_position_bias = None
|
241 |
+
|
242 |
+
def split_heads(self, x):
|
243 |
+
return rearrange(x, 'b n (h d) -> b h n d', h=self.num_attention_heads)
|
244 |
+
|
245 |
+
def join_heads(self, x):
|
246 |
+
return rearrange(x, 'b h n d -> b n (h d)')
|
247 |
+
|
248 |
+
def blockify(self, x):
|
249 |
+
assert x.dim() == 4, f"Unsupported input shape {x.shape}"
|
250 |
+
seq_len = x.shape[2]
|
251 |
+
if seq_len % self.block_size > 0: # seq_len not divisible by block_size, zero pad
|
252 |
+
pad_len = self.block_size - seq_len % self.block_size
|
253 |
+
x = nn.functional.pad(x, (0, 0, 0, pad_len))
|
254 |
+
else:
|
255 |
+
pad_len = 0
|
256 |
+
x = rearrange(x, 'b h (m n) d -> b h m n d', n=self.block_size)
|
257 |
+
return x, pad_len
|
258 |
+
|
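A standalone sketch of what blockify does: zero-pad the token axis to a multiple of block_size, then group it into blocks (sizes below are hypothetical):

# blockify sketch: pad seq_len to a multiple of block_size, then group into blocks.
import torch
from torch import nn
from einops import rearrange

block_size = 56
x = torch.randn(2, 12, 130, 64)                   # (bsz, heads, seq_len, dim)
pad_len = (-x.shape[2]) % block_size              # 38 here
if pad_len:
    x = nn.functional.pad(x, (0, 0, 0, pad_len))  # zero-pad along the seq_len axis
x = rearrange(x, 'b h (m n) d -> b h m n d', n=block_size)
print(x.shape, pad_len)                           # torch.Size([2, 12, 3, 56, 64]) 38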
259 |
+
def dense_attention(self, q, k, v, head_mask=None, relative_position_bias=None, q_idx=None, k_idx=None):
|
260 |
+
# q, k, v: (bsz, num_heads, seq_len, dims)
|
261 |
+
assert k.shape[2] == v.shape[2], "Key and value shapes mismatch"
|
262 |
+
sim = torch.einsum('b h i d, b h j d -> b h i j', q, k)
|
263 |
+
sim = sim / math.sqrt(self.attention_head_size)
|
264 |
+
|
265 |
+
# Add relative position bias if present.
|
266 |
+
if self.relative_position_bias is not None:
|
267 |
+
if q_idx is not None and q_idx.ndim == 2:
|
268 |
+
assert k_idx is not None and len(q_idx) == len(k_idx)
|
269 |
+
bias = torch.stack([
|
270 |
+
self.relative_position_bias(from_idx=q_idx_, to_idx=k_idx_)
|
271 |
+
for q_idx_, k_idx_ in zip(q_idx, k_idx)
|
272 |
+
])
|
273 |
+
else:
|
274 |
+
bias = self.relative_position_bias(from_idx=q_idx, to_idx=k_idx).unsqueeze(0)
|
275 |
+
sim = sim + bias
|
276 |
+
|
277 |
+
# Add shared relative position bias if provided.
|
278 |
+
if relative_position_bias is not None:
|
279 |
+
sim = sim + relative_position_bias
|
280 |
+
|
281 |
+
# Normalize the attention scores to probabilities.
|
282 |
+
attn = sim.softmax(dim=-1)
|
283 |
+
attn = self.dropout(attn)
|
284 |
+
if head_mask is not None:
|
285 |
+
attn = attn * head_mask
|
286 |
+
|
287 |
+
out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
|
288 |
+
return out, attn
|
289 |
+
|
290 |
+
def _sparse_attn_relative_position_bias(self, q_idx, pad_q, attn_idx, group_len):
|
291 |
+
q_idx_blk = nn.functional.pad(q_idx, (0, pad_q)).view(-1, self.block_size)
|
292 |
+
attn_idx_flt = rearrange(q_idx_blk[attn_idx], 'm n j -> m (n j)') # (seq_len, num_kv_blocks * group_len)
|
293 |
+
cls_idx = torch.arange(self.num_cls_tokens, device=q_idx.device)
|
294 |
+
cls_idx = repeat(cls_idx, 'n -> m n', m=len(attn_idx_flt))
|
295 |
+
attn_idx_flt = torch.cat((cls_idx, attn_idx_flt), dim=1)
|
296 |
+
attn_idx_flt = repeat(attn_idx_flt, 'm n -> (m i) n', i=group_len)
|
297 |
+
if pad_q > 0:
|
298 |
+
attn_idx_flt = attn_idx_flt[:-pad_q]
|
299 |
+
bias_flt = self.relative_position_bias(from_idx=q_idx, to_idx=attn_idx_flt)
|
300 |
+
if pad_q > 0:
|
301 |
+
bias_flt = nn.functional.pad(bias_flt, (0, 0, 0, pad_q))
|
302 |
+
return rearrange(bias_flt, 'h (m i) n -> h m i n', i=group_len) # num_heads, seq_len, group_len, (num_kv_blocks * group_len + num_cls_tokens)
|
303 |
+
|
304 |
+
def sparse_attention(self, q, k, v, head_mask=None, relative_position_bias=None, q_idx=None, mimic_full=False):
|
305 |
+
assert self.local_attn == 0 or self.local_attn % 2 == 1, "Even local window size not supported"
|
306 |
+
assert k.shape[2] == v.shape[2], "Key and value shapes mismatch"
|
307 |
+
|
308 |
+
|
309 |
+
if not mimic_full:
|
310 |
+
cls_k, k = k[..., :self.num_cls_tokens, :], k[..., self.num_cls_tokens:, :] # cls_k: (bsz, num_heads, num_cls_tokens, dims)
|
311 |
+
cls_v, v = v[..., :self.num_cls_tokens, :], v[..., self.num_cls_tokens:, :]
|
312 |
+
|
313 |
+
# pad token sequence to multiples of block_size
|
314 |
+
if mimic_full:
|
315 |
+
bsz, num_heads, seq_len, dims = q.shape
|
316 |
+
else:
|
317 |
+
q, pad_q = self.blockify(q) # q: (bsz, num_heads, seq_len, group_len, dims)
|
318 |
+
k, pad_k = self.blockify(k)
|
319 |
+
v, pad_v = self.blockify(v)
|
320 |
+
bsz, num_heads, seq_len, group_len, dims = q.shape
|
321 |
+
|
322 |
+
# global attention
|
323 |
+
cls_sim = torch.einsum('b h n i d, b h j d -> b h n i j', q, cls_k) # (bsz, num_heads, seq_len, group_len, num_cls_tokens)
|
324 |
+
|
325 |
+
if mimic_full:
|
326 |
+
sim = torch.einsum('b h i d, b h j d -> b h i j', q, k)
|
327 |
+
sim = sim / math.sqrt(self.attention_head_size)
|
328 |
+
sim = sim + self.relative_position_bias(from_idx=q_idx).unsqueeze(0)
|
329 |
+
|
330 |
+
else:
|
331 |
+
# initialize empty sim matrix
|
332 |
+
sim = torch.empty((bsz, num_heads, seq_len, self.num_kv_blocks, group_len, group_len), device=q.device)
|
333 |
+
attn_idx = torch.zeros((seq_len, self.num_kv_blocks), dtype=torch.int64, device=q.device)
|
334 |
+
|
335 |
+
# local window attention
|
336 |
+
cnt = 0
|
337 |
+
if self.local_attn > 0:
|
338 |
+
num_rolls = self.local_attn // 2
|
339 |
+
for r in range(-num_rolls, num_rolls + 1):
|
340 |
+
sim[..., cnt, :, :] = torch.einsum('b h n i d, b h n j d -> b h n i j', q, k.roll(-r, dims=2))
|
341 |
+
attn_idx[:, cnt] = torch.arange(seq_len, device=q.device).roll(r)
|
342 |
+
cnt += 1
|
343 |
+
|
344 |
+
# random attention
|
345 |
+
if self.random_attn > 0:
|
346 |
+
# generate random attention pattern
|
347 |
+
rand = torch.rand((seq_len, seq_len), device=q.device)
|
348 |
+
if self.local_attn > 0:
|
349 |
+
# avoid overlap with local attention
|
350 |
+
for r in range(-num_rolls, num_rolls + 1):
|
351 |
+
tgt_idx = list(i % seq_len for i in range(r, seq_len + r))
|
352 |
+
rand[range(seq_len), tgt_idx] = 0
|
353 |
+
_, idx = rand.topk(self.random_attn, dim=-1) # seq_len, random_attn
|
354 |
+
idx, _ = torch.sort(idx, dim=1)
|
355 |
+
attn_idx[:, cnt:] = idx
|
356 |
+
|
357 |
+
idx_ = repeat(idx, 'n m -> b h n m i d', b=bsz, h=num_heads, i=group_len, d=dims)
|
358 |
+
|
359 |
+
for r in range(self.random_attn):
|
360 |
+
sim[..., cnt, :, :] = torch.einsum('b h n i d, b h n j d -> b h n i j', q, k.gather(2, idx_[..., r, :, :]))
|
361 |
+
cnt += 1
|
362 |
+
|
363 |
+
sim = rearrange(sim, 'b h m n i j -> b h m i (n j)') # (bsz, num_heads, seq_len, group_len, num_kv_blocks * group_len)
|
364 |
+
sim = torch.cat((cls_sim, sim), -1)
|
365 |
+
sim = sim / math.sqrt(self.attention_head_size)
|
366 |
+
|
367 |
+
# Add relative position bias if present.
|
368 |
+
# NOTE: we assume q and k (excluding cls) use the same token indexing for relative position embedding
|
369 |
+
if self.relative_position_bias is not None:
|
370 |
+
assert q_idx is not None, "query index required for relative position bias"
|
371 |
+
if q_idx.ndim == 2:
|
372 |
+
# different indices for each sample
|
373 |
+
bias = torch.stack([
|
374 |
+
self._sparse_attn_relative_position_bias(q_idx_, pad_q, attn_idx, group_len)
|
375 |
+
for q_idx_ in q_idx
|
376 |
+
])
|
377 |
+
else:
|
378 |
+
bias = self._sparse_attn_relative_position_bias(q_idx, pad_q, attn_idx, group_len).unsqueeze(0)
|
379 |
+
sim = sim + bias
|
380 |
+
|
381 |
+
# Add shared relative position bias if provided.
|
382 |
+
if relative_position_bias is not None:
|
383 |
+
raise NotImplementedError
|
384 |
+
sim = sim + relative_position_bias
|
385 |
+
|
386 |
+
attn = sim.softmax(dim=-1)
|
387 |
+
attn = self.dropout(attn)
|
388 |
+
if head_mask is not None:
|
389 |
+
attn = attn * head_mask
|
390 |
+
|
391 |
+
# block attention
|
392 |
+
if mimic_full:
|
393 |
+
out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
|
394 |
+
|
395 |
+
else:
|
396 |
+
out = torch.empty((bsz, num_heads, seq_len, group_len, dims), device=q.device)
|
397 |
+
for m in range(seq_len):
|
398 |
+
v_row = torch.index_select(v, 2, attn_idx[m])
|
399 |
+
v_row = rearrange(v_row, 'b h n j d -> b h (n j) d') # (bsz, num_heads, num_kv_blocks * group_len, dims)
|
400 |
+
v_row = torch.cat((cls_v, v_row), 2)
|
401 |
+
out[..., m, :, :] = torch.einsum('b h i j, b h j d -> b h i d', attn[..., m, :, :], v_row)
|
402 |
+
out = rearrange(out, 'b h n i d -> b h (n i) d')
|
403 |
+
if pad_q > 0:
|
404 |
+
out = out[..., :-pad_q, :]
|
405 |
+
|
406 |
+
return out, attn
|
407 |
+
|
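The sparse pattern above is easiest to see in terms of the attn_idx bookkeeping alone: each query block attends to local_attn neighboring blocks plus random_attn random blocks chosen to avoid the local window. A toy sketch of just that index construction (no attention math, invented sizes):

# Toy sketch of the key/value block indices built in sparse_attention.
import torch

seq_len, local_attn, random_attn = 6, 3, 2        # counts are in blocks, not tokens
attn_idx = torch.zeros((seq_len, local_attn + random_attn), dtype=torch.int64)

cnt = 0
num_rolls = local_attn // 2
for r in range(-num_rolls, num_rolls + 1):        # local window, mirroring the code above
    attn_idx[:, cnt] = torch.arange(seq_len).roll(r)
    cnt += 1

rand = torch.rand(seq_len, seq_len)
for r in range(-num_rolls, num_rolls + 1):        # zero out local positions first
    rand[range(seq_len), [i % seq_len for i in range(r, seq_len + r)]] = 0
_, idx = rand.topk(random_attn, dim=-1)
attn_idx[:, cnt:] = torch.sort(idx, dim=1)[0]
print(attn_idx)                                   # per-row key/value block indices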
408 |
+
def forward(self, hidden_states, head_mask=None, output_attentions=False, relative_position_bias=None, token_idx=None):
|
409 |
+
# compute qkv
|
410 |
+
q = self.split_heads(self.query(hidden_states))
|
411 |
+
k = self.split_heads(self.key(hidden_states))
|
412 |
+
v = self.split_heads(self.value(hidden_states))
|
413 |
+
|
414 |
+
# combine local token_idx with cls tokens
|
415 |
+
# NOTE: assume token_idx starts from 0
|
416 |
+
cls_q_idx = torch.arange(self.num_cls_tokens, device=q.device)
|
417 |
+
if token_idx is not None:
|
418 |
+
if token_idx.ndim == 2:
|
419 |
+
cls_q_idx = repeat(cls_q_idx, 'n -> b n', b=q.shape[0])
|
420 |
+
all_token_idx = torch.cat((cls_q_idx, token_idx + self.num_cls_tokens), dim=-1)
|
421 |
+
else:
|
422 |
+
all_token_idx = None
|
423 |
+
|
424 |
+
if self.random_attn is None:
|
425 |
+
outputs, attention_probs = self.dense_attention(q, k, v, head_mask=head_mask,
|
426 |
+
relative_position_bias=relative_position_bias,
|
427 |
+
q_idx=all_token_idx,
|
428 |
+
k_idx=all_token_idx)
|
429 |
+
cls_attention_probs = attention_probs[..., :self.num_cls_tokens, :]
|
430 |
+
|
431 |
+
else:
|
432 |
+
cls_q, q = q[..., :self.num_cls_tokens, :], q[..., self.num_cls_tokens:, :]
|
433 |
+
|
434 |
+
# dense global attention (num_cls_tokens, seq_len)
|
435 |
+
cls_outputs, cls_attention_probs = self.dense_attention(cls_q, k, v, head_mask=head_mask,
|
436 |
+
relative_position_bias=relative_position_bias,
|
437 |
+
q_idx=cls_q_idx,
|
438 |
+
k_idx=all_token_idx)
|
439 |
+
|
440 |
+
# sparse local attention (local_seq_len, seq_len)
|
441 |
+
if token_idx is None:
|
442 |
+
token_idx = torch.arange(q.shape[-2], device=q.device)
|
443 |
+
outputs, attention_probs = self.sparse_attention(q, k, v, head_mask=head_mask,
|
444 |
+
relative_position_bias=relative_position_bias,
|
445 |
+
q_idx=token_idx + self.num_cls_tokens)
|
446 |
+
|
447 |
+
outputs = torch.cat((cls_outputs, outputs), dim=2)
|
448 |
+
|
449 |
+
outputs = self.join_heads(outputs)
|
450 |
+
|
451 |
+
outputs = (outputs, cls_attention_probs) if output_attentions else (outputs,)
|
452 |
+
|
453 |
+
return outputs
|
454 |
+
|
455 |
+
|
456 |
+
class BeitSelfOutput(nn.Module):
|
457 |
+
"""
|
458 |
+
The residual connection is defined in BeitLayer instead of here (as is the case with other models), due to the
|
459 |
+
layernorm applied before each block.
|
460 |
+
"""
|
461 |
+
|
462 |
+
def __init__(self, config):
|
463 |
+
super().__init__()
|
464 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
465 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
466 |
+
|
467 |
+
def forward(self, hidden_states, input_tensor, gamma=None):
|
468 |
+
hidden_states = self.dense(hidden_states)
|
469 |
+
hidden_states = self.dropout(hidden_states)
|
470 |
+
|
471 |
+
return hidden_states
|
472 |
+
|
473 |
+
|
474 |
+
class BeitAttention(nn.Module):
|
475 |
+
def __init__(self, config, window_size=None):
|
476 |
+
super().__init__()
|
477 |
+
self.attention = BeitSelfAttention(config, window_size=window_size)
|
478 |
+
self.output = BeitSelfOutput(config)
|
479 |
+
self.pruned_heads = set()
|
480 |
+
|
481 |
+
def prune_heads(self, heads):
|
482 |
+
if len(heads) == 0:
|
483 |
+
return
|
484 |
+
heads, index = find_pruneable_heads_and_indices(
|
485 |
+
heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
|
486 |
+
)
|
487 |
+
|
488 |
+
# Prune linear layers
|
489 |
+
self.attention.query = prune_linear_layer(self.attention.query, index)
|
490 |
+
self.attention.key = prune_linear_layer(self.attention.key, index)
|
491 |
+
self.attention.value = prune_linear_layer(self.attention.value, index)
|
492 |
+
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
493 |
+
|
494 |
+
# Update hyper params and store pruned heads
|
495 |
+
self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
|
496 |
+
self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
|
497 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
498 |
+
|
499 |
+
def forward(self, hidden_states, head_mask=None, output_attentions=False, relative_position_bias=None, token_idx=None):
|
500 |
+
self_outputs = self.attention(hidden_states, head_mask, output_attentions, relative_position_bias, token_idx)
|
501 |
+
|
502 |
+
attention_output = self.output(self_outputs[0], hidden_states)
|
503 |
+
|
504 |
+
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
505 |
+
return outputs
|
506 |
+
|
507 |
+
|
508 |
+
class BeitIntermediate(nn.Module):
|
509 |
+
def __init__(self, config):
|
510 |
+
super().__init__()
|
511 |
+
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
512 |
+
if isinstance(config.hidden_act, str):
|
513 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
514 |
+
else:
|
515 |
+
self.intermediate_act_fn = config.hidden_act
|
516 |
+
|
517 |
+
def forward(self, hidden_states):
|
518 |
+
hidden_states = self.dense(hidden_states)
|
519 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
520 |
+
|
521 |
+
return hidden_states
|
522 |
+
|
523 |
+
|
524 |
+
class BeitOutput(nn.Module):
|
525 |
+
def __init__(self, config):
|
526 |
+
super().__init__()
|
527 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
528 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
529 |
+
|
530 |
+
def forward(self, hidden_states):
|
531 |
+
hidden_states = self.dense(hidden_states)
|
532 |
+
hidden_states = self.dropout(hidden_states)
|
533 |
+
|
534 |
+
return hidden_states
|
535 |
+
|
536 |
+
|
537 |
+
class BeitLayer(nn.Module):
|
538 |
+
"""This corresponds to the Block class in the timm implementation."""
|
539 |
+
|
540 |
+
def __init__(self, config, window_size=None, drop_path_rate=0.0,
|
541 |
+
token_keep_rate=1.0):
|
542 |
+
super().__init__()
|
543 |
+
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
544 |
+
self.seq_len_dim = 1
|
545 |
+
self.attention = BeitAttention(config, window_size=window_size)
|
546 |
+
self.intermediate = BeitIntermediate(config)
|
547 |
+
self.output = BeitOutput(config)
|
548 |
+
self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
549 |
+
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
|
550 |
+
self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
551 |
+
|
552 |
+
# sparse params
|
553 |
+
self.token_keep_rate = token_keep_rate
|
554 |
+
self.token_keep_strategy = config.token_keep_strategy
|
555 |
+
self.num_cls_tokens = config.num_cls_tokens
|
556 |
+
|
557 |
+
init_values = config.layer_scale_init_value
|
558 |
+
if init_values > 0:
|
559 |
+
self.lambda_1 = nn.Parameter(init_values * torch.ones((config.hidden_size)), requires_grad=True)
|
560 |
+
self.lambda_2 = nn.Parameter(init_values * torch.ones((config.hidden_size)), requires_grad=True)
|
561 |
+
else:
|
562 |
+
self.lambda_1, self.lambda_2 = None, None
|
563 |
+
|
564 |
+
def sparsify(self, x, attn):
|
565 |
+
x_cls, x_ = x[:, :self.num_cls_tokens], x[:, self.num_cls_tokens:]
|
566 |
+
assert 0 < self.token_keep_rate <= 1, "Expected keep rate in range (0, 1]"
|
567 |
+
left_tokens = math.ceil(self.token_keep_rate * x_.size(1))
|
568 |
+
|
569 |
+
if self.token_keep_strategy == 'cls_attn':
|
570 |
+
if len(attn.shape) == 4:
|
571 |
+
attn = attn.mean(1) # pool over attention heads
|
572 |
+
cls_attn = attn[:, 0, self.num_cls_tokens:]
|
573 |
+
_, idx = torch.topk(cls_attn, left_tokens, dim=1) # [B, left_tokens]
|
574 |
+
|
575 |
+
elif self.token_keep_strategy == 'random':
|
576 |
+
rand = torch.rand(x_.shape[:2], device=x_.device)
|
577 |
+
_, idx = torch.topk(rand, left_tokens, dim=1) # [B, left_tokens]
|
578 |
+
|
579 |
+
else:
|
580 |
+
raise NotImplementedError(f"Sparse strategy {self.token_keep_strategy} is not implemented")
|
581 |
+
|
582 |
+
idx, _ = torch.sort(idx, dim=1)
|
583 |
+
index = idx.unsqueeze(-1).expand(-1, -1, x_.size(-1)) # [B, left_tokens, C]
|
584 |
+
outputs = torch.cat((x_cls, x_.gather(1, index)), dim=1).contiguous()
|
585 |
+
return outputs, idx
|
586 |
+
|
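The node-sparsification step above keeps the patch tokens that receive the most CLS attention. A toy version with invented sizes (the random cls_attn tensor stands in for attn[:, 0, num_cls:]):

# Toy version of the cls_attn token-keeping step in sparsify.
import math
import torch

keep_rate, num_cls = 0.5, 1
x = torch.randn(2, 1 + 8, 768)                    # (B, cls + patches, C)
cls_attn = torch.rand(2, 8)                       # stand-in for CLS-to-patch attention

left = math.ceil(keep_rate * 8)                   # number of patch tokens to keep
_, idx = torch.topk(cls_attn, left, dim=1)
idx, _ = torch.sort(idx, dim=1)                   # preserve the original token order
index = idx.unsqueeze(-1).expand(-1, -1, x.size(-1))
kept = torch.cat((x[:, :num_cls], x[:, num_cls:].gather(1, index)), dim=1)
print(kept.shape)                                 # torch.Size([2, 5, 768])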
587 |
+
def forward(self, hidden_states, head_mask=None, output_attentions=False, relative_position_bias=None, token_idx=None):
|
588 |
+
self_attention_outputs = self.attention(
|
589 |
+
self.layernorm_before(hidden_states), # in BEiT, layernorm is applied before self-attention
|
590 |
+
head_mask,
|
591 |
+
output_attentions=(output_attentions or self.token_keep_rate < 1),
|
592 |
+
relative_position_bias=relative_position_bias,
|
593 |
+
token_idx=token_idx
|
594 |
+
)
|
595 |
+
attention_output = self_attention_outputs[0]
|
596 |
+
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
597 |
+
|
598 |
+
# apply lambda_1 if present
|
599 |
+
if self.lambda_1 is not None:
|
600 |
+
attention_output = self.lambda_1 * attention_output
|
601 |
+
|
602 |
+
# first residual connection
|
603 |
+
hidden_states = self.drop_path(attention_output) + hidden_states
|
604 |
+
|
605 |
+
# in BEiT, layernorm is also applied after self-attention
|
606 |
+
layer_output = self.layernorm_after(hidden_states)
|
607 |
+
|
608 |
+
layer_output = self.intermediate(layer_output)
|
609 |
+
layer_output = self.output(layer_output)
|
610 |
+
|
611 |
+
if self.lambda_2 is not None:
|
612 |
+
layer_output = self.lambda_2 * layer_output
|
613 |
+
|
614 |
+
# second residual connection
|
615 |
+
layer_output = self.drop_path(layer_output) + hidden_states
|
616 |
+
|
617 |
+
# node sparsification
|
618 |
+
if self.token_keep_rate < 1:
|
619 |
+
layer_output, token_keep_idx = self.sparsify(layer_output, outputs[0])
|
620 |
+
if token_idx is not None:
|
621 |
+
if token_idx.ndim == 1:
|
622 |
+
token_idx = repeat(token_idx, 'n -> b n', b=len(token_keep_idx))
|
623 |
+
token_keep_idx = token_idx.gather(1, token_keep_idx)
|
624 |
+
outputs = outputs + (token_keep_idx,)
|
625 |
+
|
626 |
+
outputs = (layer_output,) + outputs
|
627 |
+
|
628 |
+
return outputs
|
629 |
+
|
630 |
+
|
631 |
+
class BeitRelativePositionBias(nn.Module):
|
632 |
+
def __init__(self, config, window_size):
|
633 |
+
super().__init__()
|
634 |
+
self.window_size = window_size
|
635 |
+
self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
|
636 |
+
self.relative_position_bias_table = nn.Parameter(
|
637 |
+
torch.zeros(self.num_relative_distance, config.num_attention_heads)
|
638 |
+
) # 2*Wh-1 * 2*Ww-1, nH
|
639 |
+
# cls to token & token 2 cls & cls to cls
|
640 |
+
|
641 |
+
# get pair-wise relative position index for each token inside the window
|
642 |
+
coords_h = torch.arange(window_size[0])
|
643 |
+
coords_w = torch.arange(window_size[1])
|
644 |
+
coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
|
645 |
+
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
|
646 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
|
647 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
|
648 |
+
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
|
649 |
+
relative_coords[:, :, 1] += window_size[1] - 1
|
650 |
+
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
|
651 |
+
relative_position_index = torch.zeros(
|
652 |
+
size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype
|
653 |
+
)
|
654 |
+
relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
|
655 |
+
relative_position_index[0, 0:] = self.num_relative_distance - 3
|
656 |
+
relative_position_index[0:, 0] = self.num_relative_distance - 2
|
657 |
+
relative_position_index[0, 0] = self.num_relative_distance - 1
|
658 |
+
|
659 |
+
self.register_buffer("relative_position_index", relative_position_index, persistent=False)
|
660 |
+
|
661 |
+
def forward(self):
|
662 |
+
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
|
663 |
+
self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1
|
664 |
+
) # Wh*Ww,Wh*Ww,nH
|
665 |
+
|
666 |
+
return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
|
667 |
+
|
668 |
+
|
669 |
+
class BeitRelativePositionBias3D(nn.Module):
|
670 |
+
"""
|
671 |
+
3D relative position bias
|
672 |
+
"""
|
673 |
+
def __init__(self, config, window_size, num_cls_tokens=1):
|
674 |
+
super().__init__()
|
675 |
+
self.window_size = window_size
|
676 |
+
self.num_cls_tokens = num_cls_tokens
|
677 |
+
|
678 |
+
relative_size = [w * 2 - 1 for w in window_size]
|
679 |
+
self.num_relative_distance = np.prod(relative_size) + 2 * num_cls_tokens + num_cls_tokens ** 2
|
680 |
+
|
681 |
+
self.relative_position_bias_table = nn.Parameter(
|
682 |
+
torch.zeros(self.num_relative_distance, config.num_attention_heads)
|
683 |
+
)
|
684 |
+
|
685 |
+
# get pair-wise relative position index for each token inside the window
|
686 |
+
coords_range = [torch.arange(w) for w in window_size]
|
687 |
+
coords_flatten = torch.stack(torch.meshgrid(coords_range)).flatten(1)
|
688 |
+
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
|
689 |
+
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
|
690 |
+
|
691 |
+
for i, w in enumerate(window_size):
|
692 |
+
relative_coords[:, :, i] += w - 1 # shift to start from 0
|
693 |
+
|
694 |
+
for i, r in enumerate(relative_size[1:]):
|
695 |
+
relative_coords[:, :, :i + 1] *= r
|
696 |
+
|
697 |
+
self.seq_len = np.prod(window_size) + num_cls_tokens
|
698 |
+
relative_position_index = torch.zeros((self.seq_len, self.seq_len), dtype=relative_coords.dtype)
|
699 |
+
relative_position_index[num_cls_tokens:, num_cls_tokens:] = relative_coords.sum(-1)
|
700 |
+
|
701 |
+
start = np.prod(relative_size)
|
702 |
+
cls2loc = torch.arange(num_cls_tokens).unsqueeze(1) + start
|
703 |
+
relative_position_index[:num_cls_tokens, num_cls_tokens:] = cls2loc
|
704 |
+
start += num_cls_tokens
|
705 |
+
|
706 |
+
loc2cls = torch.arange(num_cls_tokens).unsqueeze(0) + start
|
707 |
+
relative_position_index[num_cls_tokens:, :num_cls_tokens] = loc2cls
|
708 |
+
start += num_cls_tokens
|
709 |
+
|
710 |
+
cls2cls = torch.arange(num_cls_tokens ** 2).view(num_cls_tokens, num_cls_tokens) + start
|
711 |
+
relative_position_index[:num_cls_tokens, :num_cls_tokens] = cls2cls
|
712 |
+
|
713 |
+
self.register_buffer("relative_position_index", relative_position_index)
|
714 |
+
|
715 |
+
def forward(self, from_idx=None, to_idx=None):
|
716 |
+
"""
|
717 |
+
from_idx: indices of query tokens (1-dim)
|
718 |
+
to_idx: indices of key/value tokens (1-dim, or 2-dim w/ one row per query)
|
719 |
+
"""
|
720 |
+
attn_idx = self.relative_position_index
|
721 |
+
|
722 |
+
# query indices
|
723 |
+
if from_idx is not None:
|
724 |
+
attn_idx = attn_idx[from_idx]
|
725 |
+
|
726 |
+
# key indices
|
727 |
+
if to_idx is not None:
|
728 |
+
assert to_idx.ndim in (1, 2), "to_idx must be a 1- or 2-dimensional tensor"
|
729 |
+
if to_idx.ndim == 1:
|
730 |
+
attn_idx = attn_idx[:, to_idx]
|
731 |
+
else:
|
732 |
+
attn_idx = attn_idx.gather(1, to_idx)
|
733 |
+
|
734 |
+
rows, cols = attn_idx.shape
|
735 |
+
relative_position_bias = self.relative_position_bias_table[attn_idx.flatten()]
|
736 |
+
relative_position_bias = rearrange(relative_position_bias, '(i j) h -> h i j', i=rows, j=cols)
|
737 |
+
return relative_position_bias.contiguous()
|
738 |
+
|
739 |
+
|
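A quick size check for the 3-D relative position table defined above, using a toy window of 2 frames by 3x3 patches and one CLS token (illustrative numbers only):

# Size check for the 3-D relative position bias table.
import numpy as np

window_size, num_cls_tokens = (2, 3, 3), 1
relative_size = [w * 2 - 1 for w in window_size]          # [3, 5, 5]
num_relative_distance = np.prod(relative_size) + 2 * num_cls_tokens + num_cls_tokens ** 2
seq_len = np.prod(window_size) + num_cls_tokens
print(num_relative_distance, seq_len)                     # 78 19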
740 |
+
class BeitEncoder(nn.Module):
|
741 |
+
def __init__(self, config, window_size=None):
|
742 |
+
super().__init__()
|
743 |
+
self.config = config
|
744 |
+
if config.use_shared_relative_position_bias:
|
745 |
+
self.relative_position_bias = BeitRelativePositionBias3D(config, window_size=window_size)
|
746 |
+
else:
|
747 |
+
self.relative_position_bias = None
|
748 |
+
|
749 |
+
self._register_token_order(window_size)
|
750 |
+
|
751 |
+
# stochastic depth decay rule
|
752 |
+
dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
|
753 |
+
|
754 |
+
# node sparsification
|
755 |
+
token_keep_rate = [1] * config.num_hidden_layers
|
756 |
+
for loc in config.token_drop_loc:
|
757 |
+
token_keep_rate[loc] = config.token_keep_rate
|
758 |
+
|
759 |
+
self.layer = nn.ModuleList(
|
760 |
+
[
|
761 |
+
BeitLayer(
|
762 |
+
config,
|
763 |
+
window_size=window_size if config.use_relative_position_bias else None,
|
764 |
+
drop_path_rate=dpr[i], token_keep_rate=token_keep_rate[i]
|
765 |
+
)
|
766 |
+
for i in range(config.num_hidden_layers)
|
767 |
+
]
|
768 |
+
)
|
769 |
+
|
770 |
+
self.gradient_checkpointing = False
|
771 |
+
|
772 |
+
def _register_token_order(self, shape):
|
773 |
+
if self.config.token_3d_order == 'none':
|
774 |
+
order = None
|
775 |
+
elif self.config.token_3d_order == 'zcurve':
|
776 |
+
nbits = max(shape).bit_length()
|
777 |
+
coords = list(np.ndindex(*shape))
|
778 |
+
order = zCurve.par_interlace(coords, len(shape), nbits)
|
779 |
+
order = torch.tensor(np.argsort(order))
|
780 |
+
elif self.config.token_3d_order == 'hilbert':
|
781 |
+
nbits = max(shape).bit_length()
|
782 |
+
coords = list(np.ndindex(*shape))
|
783 |
+
order = hilbert.encode(np.stack(coords), len(shape), nbits)
|
784 |
+
order = torch.tensor(np.argsort(order))
|
785 |
+
else:
|
786 |
+
raise NotImplementedError(f"Token ordering {self.config.token_3d_order} not supported")
|
787 |
+
|
788 |
+
if order is not None:
|
789 |
+
self.register_buffer('token_order', order, persistent=False)
|
790 |
+
else:
|
791 |
+
self.token_order = None
|
792 |
+
|
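The ordering registered above linearizes the (frames, height, width) patch grid into a 1-D visiting order along a space-filling curve. A sketch mirroring the zcurve branch (it requires the zCurve package imported by this file; the grid shape is illustrative):

# Sketch of the zcurve branch of _register_token_order.
import numpy as np
import torch
import zCurve

shape = (4, 14, 14)                           # frames x patch grid
nbits = max(shape).bit_length()
coords = list(np.ndindex(*shape))
keys = zCurve.par_interlace(coords, len(shape), nbits)
order = torch.tensor(np.argsort(keys))        # permutation applied to the patch tokens
print(order.shape)                            # torch.Size([784])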
793 |
+
def forward(
|
794 |
+
self,
|
795 |
+
hidden_states,
|
796 |
+
head_mask=None,
|
797 |
+
output_attentions=False,
|
798 |
+
output_hidden_states=False,
|
799 |
+
output_token_idx=False,
|
800 |
+
return_dict=True,
|
801 |
+
):
|
802 |
+
all_hidden_states = () if output_hidden_states else None
|
803 |
+
all_self_attentions = () if output_attentions else None
|
804 |
+
all_token_idx = () if output_token_idx else None
|
805 |
+
|
806 |
+
token_idx = self.token_order
|
807 |
+
if token_idx is not None:
|
808 |
+
cls_states, local_states = hidden_states[:, :self.config.num_cls_tokens], hidden_states[:, self.config.num_cls_tokens:]
|
809 |
+
local_states = torch.index_select(local_states, dim=1, index=token_idx)
|
810 |
+
hidden_states = torch.cat((cls_states, local_states), 1)
|
811 |
+
|
812 |
+
for i, layer_module in enumerate(self.layer):
|
813 |
+
if output_hidden_states:
|
814 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
815 |
+
|
816 |
+
layer_head_mask = head_mask[i] if head_mask is not None else None
|
817 |
+
|
818 |
+
if self.gradient_checkpointing and self.training:
|
819 |
+
|
820 |
+
def create_custom_forward(module):
|
821 |
+
def custom_forward(*inputs):
|
822 |
+
return module(*inputs, output_attentions)
|
823 |
+
|
824 |
+
return custom_forward
|
825 |
+
|
826 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
827 |
+
create_custom_forward(layer_module),
|
828 |
+
hidden_states,
|
829 |
+
layer_head_mask,
|
830 |
+
)
|
831 |
+
else:
|
832 |
+
relative_position_bias = (
|
833 |
+
self.relative_position_bias() if self.relative_position_bias is not None else None
|
834 |
+
)
|
835 |
+
layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions, relative_position_bias, token_idx)
|
836 |
+
|
837 |
+
hidden_states = layer_outputs[0]
|
838 |
+
|
839 |
+
if layer_module.token_keep_rate < 1:
|
840 |
+
token_idx = layer_outputs[-1]
|
841 |
+
|
842 |
+
if output_token_idx:
|
843 |
+
all_token_idx = all_token_idx + (token_idx,)
|
844 |
+
|
845 |
+
if output_attentions:
|
846 |
+
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
847 |
+
|
848 |
+
if output_hidden_states:
|
849 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
850 |
+
|
851 |
+
if not return_dict:
|
852 |
+
return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
|
853 |
+
return BeitModelOutput(
|
854 |
+
last_hidden_state=hidden_states,
|
855 |
+
hidden_states=all_hidden_states,
|
856 |
+
attentions=all_self_attentions,
|
857 |
+
token_idx=all_token_idx
|
858 |
+
)
|
859 |
+
|
860 |
+
|
861 |
+
class BeitPreTrainedModel(PreTrainedModel):
|
862 |
+
"""
|
863 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
864 |
+
models.
|
865 |
+
"""
|
866 |
+
|
867 |
+
config_class = BeitConfig
|
868 |
+
base_model_prefix = "beit"
|
869 |
+
supports_gradient_checkpointing = True
|
870 |
+
|
871 |
+
def _init_weights(self, module):
|
872 |
+
"""Initialize the weights"""
|
873 |
+
if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
|
874 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
875 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
876 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
877 |
+
if module.bias is not None:
|
878 |
+
module.bias.data.zero_()
|
879 |
+
elif isinstance(module, nn.Embedding):
|
880 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
881 |
+
if module.padding_idx is not None:
|
882 |
+
module.weight.data[module.padding_idx].zero_()
|
883 |
+
elif isinstance(module, nn.LayerNorm):
|
884 |
+
module.bias.data.zero_()
|
885 |
+
module.weight.data.fill_(1.0)
|
886 |
+
|
887 |
+
def _set_gradient_checkpointing(self, module, value=False):
|
888 |
+
if isinstance(module, BeitEncoder):
|
889 |
+
module.gradient_checkpointing = value
|
890 |
+
|
891 |
+
|
892 |
+
BEIT_START_DOCSTRING = r"""
|
893 |
+
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ subclass. Use
|
894 |
+
it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
|
895 |
+
behavior.
|
896 |
+
|
897 |
+
Parameters:
|
898 |
+
config (:class:`~transformers.BeitConfig`): Model configuration class with all the parameters of the model.
|
899 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
900 |
+
configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
|
901 |
+
weights.
|
902 |
+
"""
|
903 |
+
|
904 |
+
BEIT_INPUTS_DOCSTRING = r"""
|
905 |
+
Args:
|
906 |
+
pixel_values (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_channels, height, width)`):
|
907 |
+
Pixel values. Pixel values can be obtained using :class:`~transformers.BeitFeatureExtractor`. See
|
908 |
+
:meth:`transformers.BeitFeatureExtractor.__call__` for details.
|
909 |
+
|
910 |
+
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
|
911 |
+
Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
|
912 |
+
|
913 |
+
- 1 indicates the head is **not masked**,
|
914 |
+
- 0 indicates the head is **masked**.
|
915 |
+
|
916 |
+
output_attentions (:obj:`bool`, `optional`):
|
917 |
+
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
|
918 |
+
tensors for more detail.
|
919 |
+
output_hidden_states (:obj:`bool`, `optional`):
|
920 |
+
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
|
921 |
+
more detail.
|
922 |
+
return_dict (:obj:`bool`, `optional`):
|
923 |
+
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
924 |
+
"""
|
925 |
+
|
926 |
+
|
927 |
+
@add_start_docstrings(
|
928 |
+
"The bare Beit Model transformer outputting raw hidden-states without any specific head on top.",
|
929 |
+
BEIT_START_DOCSTRING,
|
930 |
+
)
|
931 |
+
class BeitModel(BeitPreTrainedModel):
|
932 |
+
def __init__(self, config, add_pooling_layer=True, num_frames=None):
|
933 |
+
super().__init__(config)
|
934 |
+
self.config = config
|
935 |
+
|
936 |
+
self.embeddings = BeitEmbeddings(config)
|
937 |
+
self.window_size = self.embeddings.patch_embeddings.patch_shape
|
938 |
+
if num_frames is not None:
|
939 |
+
self.window_size = (num_frames,) + self.window_size
|
940 |
+
self.encoder = BeitEncoder(config, window_size=self.window_size)
|
941 |
+
|
942 |
+
self.layernorm = (
|
943 |
+
nn.Identity() if config.use_mean_pooling else nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
944 |
+
)
|
945 |
+
self.pooler = BeitPooler(config) if add_pooling_layer else None
|
946 |
+
|
947 |
+
# Initialize weights and apply final processing
|
948 |
+
self.post_init()
|
949 |
+
|
950 |
+
def get_input_embeddings(self):
|
951 |
+
return self.embeddings.patch_embeddings
|
952 |
+
|
953 |
+
def _prune_heads(self, heads_to_prune):
|
954 |
+
"""
|
955 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
956 |
+
class PreTrainedModel
|
957 |
+
"""
|
958 |
+
for layer, heads in heads_to_prune.items():
|
959 |
+
self.encoder.layer[layer].attention.prune_heads(heads)
|
960 |
+
|
961 |
+
@add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING)
|
962 |
+
@replace_return_docstrings(output_type=BeitModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
|
963 |
+
def forward(
|
964 |
+
self,
|
965 |
+
pixel_values=None,
|
966 |
+
bool_masked_pos=None,
|
967 |
+
head_mask=None,
|
968 |
+
output_attentions=None,
|
969 |
+
output_hidden_states=None,
|
970 |
+
output_token_idx=None,
|
971 |
+
return_dict=None,
|
972 |
+
):
|
973 |
+
r"""
|
974 |
+
Returns:
|
975 |
+
|
976 |
+
Examples::
|
977 |
+
|
978 |
+
>>> from transformers import BeitFeatureExtractor, BeitModel
|
979 |
+
>>> from PIL import Image
|
980 |
+
>>> import requests
|
981 |
+
|
982 |
+
>>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
|
983 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
984 |
+
|
985 |
+
>>> feature_extractor = BeitFeatureExtractor.from_pretrained('microsoft/beit-base-patch16-224-pt22k-ft22k')
|
986 |
+
>>> model = BeitModel.from_pretrained('microsoft/beit-base-patch16-224-pt22k-ft22k')
|
987 |
+
|
988 |
+
>>> inputs = feature_extractor(images=image, return_tensors="pt")
|
989 |
+
>>> outputs = model(**inputs)
|
990 |
+
>>> last_hidden_states = outputs.last_hidden_state
|
991 |
+
"""
|
992 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
993 |
+
output_hidden_states = (
|
994 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
995 |
+
)
|
996 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
997 |
+
|
998 |
+
if pixel_values is None:
|
999 |
+
raise ValueError("You have to specify pixel_values")
|
1000 |
+
|
1001 |
+
# Prepare head mask if needed
|
1002 |
+
# 1.0 in head_mask indicate we keep the head
|
1003 |
+
# attention_probs has shape bsz x n_heads x N x N
|
1004 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
1005 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
1006 |
+
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
1007 |
+
|
1008 |
+
embedding_output = self.embeddings(pixel_values, bool_masked_pos)
|
1009 |
+
|
1010 |
+
encoder_outputs = self.encoder(
|
1011 |
+
embedding_output,
|
1012 |
+
head_mask=head_mask,
|
1013 |
+
output_attentions=output_attentions,
|
1014 |
+
output_hidden_states=output_hidden_states,
|
1015 |
+
output_token_idx=output_token_idx,
|
1016 |
+
return_dict=return_dict,
|
1017 |
+
)
|
1018 |
+
sequence_output = encoder_outputs[0]
|
1019 |
+
sequence_output = self.layernorm(sequence_output)
|
1020 |
+
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
|
1021 |
+
|
1022 |
+
if not return_dict:
|
1023 |
+
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
1024 |
+
|
1025 |
+
return BeitModelOutputWithPooling(
|
1026 |
+
last_hidden_state=sequence_output,
|
1027 |
+
pooler_output=pooled_output,
|
1028 |
+
hidden_states=encoder_outputs.hidden_states,
|
1029 |
+
attentions=encoder_outputs.attentions,
|
1030 |
+
token_idx=encoder_outputs.token_idx,
|
1031 |
+
)
|
1032 |
+
|
1033 |
+
|
1034 |
+
class BeitPooler(nn.Module):
|
1035 |
+
def __init__(self, config):
|
1036 |
+
super().__init__()
|
1037 |
+
self.layernorm = (
|
1038 |
+
nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) if config.use_mean_pooling else None
|
1039 |
+
)
|
1040 |
+
|
1041 |
+
def forward(self, hidden_states):
|
1042 |
+
if self.layernorm is not None:
|
1043 |
+
# Mean pool the final hidden states of the patch tokens
|
1044 |
+
patch_tokens = hidden_states[:, 1:, :]
|
1045 |
+
pooled_output = self.layernorm(patch_tokens.mean(1))
|
1046 |
+
else:
|
1047 |
+
# Pool by simply taking the final hidden state of the [CLS] token
|
1048 |
+
pooled_output = hidden_states[:, 0]
|
1049 |
+
|
1050 |
+
return pooled_output
|
1051 |
+
|
1052 |
+
|
1053 |
+
@add_start_docstrings(
|
1054 |
+
"Beit Model transformer with a 'language' modeling head on top (to predict visual tokens).",
|
1055 |
+
BEIT_START_DOCSTRING,
|
1056 |
+
)
|
1057 |
+
class BeitForMaskedImageModeling(BeitPreTrainedModel):
|
1058 |
+
def __init__(self, config):
|
1059 |
+
super().__init__(config)
|
1060 |
+
|
1061 |
+
self.num_labels = config.num_labels
|
1062 |
+
self.beit = BeitModel(config, add_pooling_layer=False)
|
1063 |
+
|
1064 |
+
# Classifier head
|
1065 |
+
self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
1066 |
+
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
|
1067 |
+
|
1068 |
+
# Initialize weights and apply final processing
|
1069 |
+
self.post_init()
|
1070 |
+
|
1071 |
+
@add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING)
|
1072 |
+
@replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
|
1073 |
+
def forward(
|
1074 |
+
self,
|
1075 |
+
pixel_values=None,
|
1076 |
+
bool_masked_pos=None,
|
1077 |
+
head_mask=None,
|
1078 |
+
labels=None,
|
1079 |
+
output_attentions=None,
|
1080 |
+
output_hidden_states=None,
|
1081 |
+
return_dict=None,
|
1082 |
+
):
|
1083 |
+
r"""
|
1084 |
+
bool_masked_pos (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, num_patches)`):
|
1085 |
+
Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
|
1086 |
+
|
1087 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
1088 |
+
Labels for computing the image classification/regression loss. Indices should be in :obj:`[0, ...,
|
1089 |
+
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
|
1090 |
+
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
1091 |
+
|
1092 |
+
Returns:
|
1093 |
+
|
1094 |
+
Examples::
|
1095 |
+
|
1096 |
+
>>> from transformers import BeitFeatureExtractor, BeitForMaskedImageModeling
|
1097 |
+
>>> from PIL import Image
|
1098 |
+
>>> import requests
|
1099 |
+
|
1100 |
+
>>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
|
1101 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
1102 |
+
|
1103 |
+
>>> feature_extractor = BeitFeatureExtractor.from_pretrained('microsoft/beit-base-patch16-224-pt22k')
|
1104 |
+
>>> model = BeitForMaskedImageModeling.from_pretrained('microsoft/beit-base-patch16-224-pt22k')
|
1105 |
+
|
1106 |
+
>>> inputs = feature_extractor(images=image, return_tensors="pt")
|
1107 |
+
>>> outputs = model(**inputs)
|
1108 |
+
>>> logits = outputs.logits
|
1109 |
+
"""
|
1110 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1111 |
+
|
1112 |
+
outputs = self.beit(
|
1113 |
+
pixel_values,
|
1114 |
+
bool_masked_pos=bool_masked_pos,
|
1115 |
+
head_mask=head_mask,
|
1116 |
+
output_attentions=output_attentions,
|
1117 |
+
output_hidden_states=output_hidden_states,
|
1118 |
+
return_dict=return_dict,
|
1119 |
+
)
|
1120 |
+
|
1121 |
+
sequence_output = outputs[0]
|
1122 |
+
sequence_output = self.layernorm(sequence_output)
|
1123 |
+
prediction_scores = self.lm_head(sequence_output[:, 1:])
|
1124 |
+
|
1125 |
+
masked_lm_loss = None
|
1126 |
+
if labels is not None:
|
1127 |
+
loss_fct = CrossEntropyLoss() # -100 index = padding token
|
1128 |
+
masked_lm_loss = loss_fct(prediction_scores[bool_masked_pos], labels)
|
1129 |
+
|
1130 |
+
if not return_dict:
|
1131 |
+
output = (prediction_scores,) + outputs[2:]
|
1132 |
+
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
|
1133 |
+
|
1134 |
+
return MaskedLMOutput(
|
1135 |
+
loss=masked_lm_loss,
|
1136 |
+
logits=prediction_scores,
|
1137 |
+
hidden_states=outputs.hidden_states,
|
1138 |
+
attentions=outputs.attentions,
|
1139 |
+
)
|
1140 |
+
|
1141 |
+
|
1142 |
+
@add_start_docstrings(
|
1143 |
+
"""
|
1144 |
+
Beit Model transformer with an image classification head on top (a linear layer on top of the average of the final
|
1145 |
+
hidden states of the patch tokens) e.g. for ImageNet.
|
1146 |
+
""",
|
1147 |
+
BEIT_START_DOCSTRING,
|
1148 |
+
)
|
1149 |
+
class BeitForImageClassification(BeitPreTrainedModel):
|
1150 |
+
def __init__(self, config):
|
1151 |
+
super().__init__(config)
|
1152 |
+
|
1153 |
+
self.num_labels = config.num_labels
|
1154 |
+
self.beit = BeitModel(config, add_pooling_layer=True)
|
1155 |
+
|
1156 |
+
# Classifier head
|
1157 |
+
self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
|
1158 |
+
|
1159 |
+
# Initialize weights and apply final processing
|
1160 |
+
self.post_init()
|
1161 |
+
|
1162 |
+
@add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING)
|
1163 |
+
@replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
|
1164 |
+
def forward(
|
1165 |
+
self,
|
1166 |
+
pixel_values=None,
|
1167 |
+
head_mask=None,
|
1168 |
+
labels=None,
|
1169 |
+
output_attentions=None,
|
1170 |
+
output_hidden_states=None,
|
1171 |
+
return_dict=None,
|
1172 |
+
):
|
1173 |
+
r"""
|
1174 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
1175 |
+
Labels for computing the image classification/regression loss. Indices should be in :obj:`[0, ...,
|
1176 |
+
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss).
|
1177 |
+
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
1178 |
+
|
1179 |
+
Returns:
|
1180 |
+
|
1181 |
+
Examples::
|
1182 |
+
|
1183 |
+
>>> from transformers import BeitFeatureExtractor, BeitForImageClassification
|
1184 |
+
>>> from PIL import Image
|
1185 |
+
>>> import requests
|
1186 |
+
|
1187 |
+
>>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
|
1188 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
1189 |
+
|
1190 |
+
>>> feature_extractor = BeitFeatureExtractor.from_pretrained('microsoft/beit-base-patch16-224')
|
1191 |
+
>>> model = BeitForImageClassification.from_pretrained('microsoft/beit-base-patch16-224')
|
1192 |
+
|
1193 |
+
>>> inputs = feature_extractor(images=image, return_tensors="pt")
|
1194 |
+
>>> outputs = model(**inputs)
|
1195 |
+
>>> logits = outputs.logits
|
1196 |
+
>>> # model predicts one of the 1000 ImageNet classes
|
1197 |
+
>>> predicted_class_idx = logits.argmax(-1).item()
|
1198 |
+
>>> print("Predicted class:", model.config.id2label[predicted_class_idx])
|
1199 |
+
"""
|
1200 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1201 |
+
|
1202 |
+
outputs = self.beit(
|
1203 |
+
pixel_values,
|
1204 |
+
head_mask=head_mask,
|
1205 |
+
output_attentions=output_attentions,
|
1206 |
+
output_hidden_states=output_hidden_states,
|
1207 |
+
return_dict=return_dict,
|
1208 |
+
)
|
1209 |
+
|
1210 |
+
pooled_output = outputs.pooler_output if return_dict else outputs[1]
|
1211 |
+
|
1212 |
+
logits = self.classifier(pooled_output)
|
1213 |
+
|
1214 |
+
loss = None
|
1215 |
+
if labels is not None:
|
1216 |
+
if self.num_labels == 1:
|
1217 |
+
# We are doing regression
|
1218 |
+
loss_fct = MSELoss()
|
1219 |
+
loss = loss_fct(logits.view(-1), labels.view(-1))
|
1220 |
+
else:
|
1221 |
+
loss_fct = CrossEntropyLoss()
|
1222 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
1223 |
+
|
1224 |
+
if not return_dict:
|
1225 |
+
output = (logits,) + outputs[2:]
|
1226 |
+
return ((loss,) + output) if loss is not None else output
|
1227 |
+
|
1228 |
+
return SequenceClassifierOutput(
|
1229 |
+
loss=loss,
|
1230 |
+
logits=logits,
|
1231 |
+
hidden_states=outputs.hidden_states,
|
1232 |
+
attentions=outputs.attentions,
|
1233 |
+
)
|
1234 |
+
|
1235 |
+
|
1236 |
+
class BeitConvModule(nn.Module):
|
1237 |
+
"""
|
1238 |
+
A convolutional block that bundles conv/norm/activation layers. This block simplifies the usage of convolution
|
1239 |
+
layers, which are commonly used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
|
1240 |
+
|
1241 |
+
Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
|
1242 |
+
"""
|
1243 |
+
|
1244 |
+
def __init__(self, in_channels, out_channels, kernel_size, padding=0, bias=False, dilation=1):
|
1245 |
+
super().__init__()
|
1246 |
+
self.conv = nn.Conv2d(
|
1247 |
+
in_channels=in_channels,
|
1248 |
+
out_channels=out_channels,
|
1249 |
+
kernel_size=kernel_size,
|
1250 |
+
padding=padding,
|
1251 |
+
bias=bias,
|
1252 |
+
dilation=dilation,
|
1253 |
+
)
|
1254 |
+
self.bn = nn.BatchNorm2d(out_channels)
|
1255 |
+
self.activation = nn.ReLU()
|
1256 |
+
|
1257 |
+
def forward(self, input):
|
1258 |
+
output = self.conv(input)
|
1259 |
+
output = self.bn(output)
|
1260 |
+
output = self.activation(output)
|
1261 |
+
|
1262 |
+
return output
|
1263 |
+
|
1264 |
+
|
1265 |
+
class BeitPyramidPoolingModule(nn.ModuleList):
|
1266 |
+
"""
|
1267 |
+
Pyramid Pooling Module (PPM) used in PSPNet.
|
1268 |
+
|
1269 |
+
Args:
|
1270 |
+
pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
|
1271 |
+
Module.
|
1272 |
+
in_channels (int): Input channels.
|
1273 |
+
channels (int): Channels after modules, before conv_seg.
|
1274 |
+
align_corners (bool): align_corners argument of F.interpolate.
|
1275 |
+
|
1276 |
+
Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
|
1277 |
+
"""
|
1278 |
+
|
1279 |
+
def __init__(self, pool_scales, in_channels, channels, align_corners):
|
1280 |
+
super().__init__()
|
1281 |
+
self.pool_scales = pool_scales
|
1282 |
+
self.align_corners = align_corners
|
1283 |
+
self.in_channels = in_channels
|
1284 |
+
self.channels = channels
|
1285 |
+
for pool_scale in pool_scales:
|
1286 |
+
self.append(
|
1287 |
+
nn.Sequential(
|
1288 |
+
nn.AdaptiveAvgPool2d(pool_scale),
|
1289 |
+
BeitConvModule(self.in_channels, self.channels, kernel_size=1),
|
1290 |
+
)
|
1291 |
+
)
|
1292 |
+
|
1293 |
+
def forward(self, x):
|
1294 |
+
ppm_outs = []
|
1295 |
+
for ppm in self:
|
1296 |
+
ppm_out = ppm(x)
|
1297 |
+
upsampled_ppm_out = nn.functional.interpolate(
|
1298 |
+
ppm_out, size=x.size()[2:], mode="bilinear", align_corners=self.align_corners
|
1299 |
+
)
|
1300 |
+
ppm_outs.append(upsampled_ppm_out)
|
1301 |
+
return ppm_outs
|
1302 |
+
|
1303 |
+
|
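A minimal sketch of what each pyramid-pooling branch above computes, assuming hypothetical 768-channel feature maps of spatial size 40x40 and the usual pool_scales of (1, 2, 3, 6); the BatchNorm/ReLU inside BeitConvModule are omitted for brevity.

import torch
import torch.nn as nn

x = torch.randn(2, 768, 40, 40)                 # [B, C, H, W] feature map entering the PPM
for scale in (1, 2, 3, 6):
    pooled = nn.AdaptiveAvgPool2d(scale)(x)     # [2, 768, scale, scale]
    reduced = nn.Conv2d(768, 768, 1)(pooled)    # the 1x1 BeitConvModule (norm/activation skipped)
    upsampled = nn.functional.interpolate(
        reduced, size=x.shape[2:], mode="bilinear", align_corners=False)
    print(scale, upsampled.shape)               # every branch comes back as [2, 768, 40, 40]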
1304 |
+
class BeitUperHead(nn.Module):
|
1305 |
+
"""
|
1306 |
+
Unified Perceptual Parsing for Scene Understanding. This head is the implementation of `UPerNet
|
1307 |
+
<https://arxiv.org/abs/1807.10221>`_.
|
1308 |
+
|
1309 |
+
Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
|
1310 |
+
"""
|
1311 |
+
|
1312 |
+
def __init__(self, config):
|
1313 |
+
super().__init__()
|
1314 |
+
|
1315 |
+
self.pool_scales = config.pool_scales # e.g. (1, 2, 3, 6)
|
1316 |
+
self.in_channels = [config.hidden_size] * 4 # e.g. [768, 768, 768, 768]
|
1317 |
+
self.channels = config.hidden_size
|
1318 |
+
self.align_corners = False
|
1319 |
+
self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1)
|
1320 |
+
|
1321 |
+
# PSP Module
|
1322 |
+
self.psp_modules = BeitPyramidPoolingModule(
|
1323 |
+
self.pool_scales,
|
1324 |
+
self.in_channels[-1],
|
1325 |
+
self.channels,
|
1326 |
+
align_corners=self.align_corners,
|
1327 |
+
)
|
1328 |
+
self.bottleneck = BeitConvModule(
|
1329 |
+
self.in_channels[-1] + len(self.pool_scales) * self.channels,
|
1330 |
+
self.channels,
|
1331 |
+
kernel_size=3,
|
1332 |
+
padding=1,
|
1333 |
+
)
|
1334 |
+
# FPN Module
|
1335 |
+
self.lateral_convs = nn.ModuleList()
|
1336 |
+
self.fpn_convs = nn.ModuleList()
|
1337 |
+
for in_channels in self.in_channels[:-1]: # skip the top layer
|
1338 |
+
l_conv = BeitConvModule(in_channels, self.channels, kernel_size=1)
|
1339 |
+
fpn_conv = BeitConvModule(self.channels, self.channels, kernel_size=3, padding=1)
|
1340 |
+
self.lateral_convs.append(l_conv)
|
1341 |
+
self.fpn_convs.append(fpn_conv)
|
1342 |
+
|
1343 |
+
self.fpn_bottleneck = BeitConvModule(
|
1344 |
+
len(self.in_channels) * self.channels,
|
1345 |
+
self.channels,
|
1346 |
+
kernel_size=3,
|
1347 |
+
padding=1,
|
1348 |
+
)
|
1349 |
+
|
1350 |
+
def psp_forward(self, inputs):
|
1351 |
+
x = inputs[-1]
|
1352 |
+
psp_outs = [x]
|
1353 |
+
psp_outs.extend(self.psp_modules(x))
|
1354 |
+
psp_outs = torch.cat(psp_outs, dim=1)
|
1355 |
+
output = self.bottleneck(psp_outs)
|
1356 |
+
|
1357 |
+
return output
|
1358 |
+
|
1359 |
+
def forward(self, encoder_hidden_states):
|
1360 |
+
# build laterals
|
1361 |
+
laterals = [lateral_conv(encoder_hidden_states[i]) for i, lateral_conv in enumerate(self.lateral_convs)]
|
1362 |
+
|
1363 |
+
laterals.append(self.psp_forward(encoder_hidden_states))
|
1364 |
+
|
1365 |
+
# build top-down path
|
1366 |
+
used_backbone_levels = len(laterals)
|
1367 |
+
for i in range(used_backbone_levels - 1, 0, -1):
|
1368 |
+
prev_shape = laterals[i - 1].shape[2:]
|
1369 |
+
laterals[i - 1] = laterals[i - 1] + nn.functional.interpolate(
|
1370 |
+
laterals[i], size=prev_shape, mode="bilinear", align_corners=self.align_corners
|
1371 |
+
)
|
1372 |
+
|
1373 |
+
# build outputs
|
1374 |
+
fpn_outs = [self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels - 1)]
|
1375 |
+
# append psp feature
|
1376 |
+
fpn_outs.append(laterals[-1])
|
1377 |
+
|
1378 |
+
for i in range(used_backbone_levels - 1, 0, -1):
|
1379 |
+
fpn_outs[i] = nn.functional.interpolate(
|
1380 |
+
fpn_outs[i], size=fpn_outs[0].shape[2:], mode="bilinear", align_corners=self.align_corners
|
1381 |
+
)
|
1382 |
+
fpn_outs = torch.cat(fpn_outs, dim=1)
|
1383 |
+
output = self.fpn_bottleneck(fpn_outs)
|
1384 |
+
output = self.classifier(output)
|
1385 |
+
|
1386 |
+
return output
|
1387 |
+
|
1388 |
+
|
1389 |
+
class BeitFCNHead(nn.Module):
|
1390 |
+
"""
|
1391 |
+
Fully Convolutional Networks for Semantic Segmentation. This head is the implementation of `FCNNet
|
1392 |
+
<https://arxiv.org/abs/1411.4038>`_.
|
1393 |
+
|
1394 |
+
Args:
|
1395 |
+
config (BeitConfig): Configuration.
|
1396 |
+
in_channels
|
1397 |
+
kernel_size (int): The kernel size for convs in the head. Default: 3.
|
1398 |
+
dilation (int): The dilation rate for convs in the head. Default: 1.
|
1399 |
+
|
1400 |
+
|
1401 |
+
Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
|
1402 |
+
"""
|
1403 |
+
|
1404 |
+
def __init__(self, config, in_index=2, kernel_size=3, dilation=1):
|
1405 |
+
super().__init__()
|
1406 |
+
self.in_channels = config.hidden_size
|
1407 |
+
self.channels = config.auxiliary_channels
|
1408 |
+
self.num_convs = config.auxiliary_num_convs
|
1409 |
+
self.concat_input = config.auxiliary_concat_input
|
1410 |
+
self.in_index = in_index
|
1411 |
+
|
1412 |
+
conv_padding = (kernel_size // 2) * dilation
|
1413 |
+
convs = []
|
1414 |
+
convs.append(
|
1415 |
+
BeitConvModule(
|
1416 |
+
self.in_channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation
|
1417 |
+
)
|
1418 |
+
)
|
1419 |
+
for i in range(self.num_convs - 1):
|
1420 |
+
convs.append(
|
1421 |
+
BeitConvModule(
|
1422 |
+
self.channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation
|
1423 |
+
)
|
1424 |
+
)
|
1425 |
+
if self.num_convs == 0:
|
1426 |
+
self.convs = nn.Identity()
|
1427 |
+
else:
|
1428 |
+
self.convs = nn.Sequential(*convs)
|
1429 |
+
if self.concat_input:
|
1430 |
+
self.conv_cat = BeitConvModule(
|
1431 |
+
self.in_channels + self.channels, self.channels, kernel_size=kernel_size, padding=kernel_size // 2
|
1432 |
+
)
|
1433 |
+
|
1434 |
+
self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1)
|
1435 |
+
|
1436 |
+
def forward(self, encoder_hidden_states):
|
1437 |
+
# just take the relevant feature maps
|
1438 |
+
hidden_states = encoder_hidden_states[self.in_index]
|
1439 |
+
output = self.convs(hidden_states)
|
1440 |
+
if self.concat_input:
|
1441 |
+
output = self.conv_cat(torch.cat([hidden_states, output], dim=1))
|
1442 |
+
output = self.classifier(output)
|
1443 |
+
return output
|
1444 |
+
|
1445 |
+
|
1446 |
+
@add_start_docstrings(
|
1447 |
+
"""
|
1448 |
+
Beit Model transformer with a semantic segmentation head on top e.g. for ADE20k, CityScapes.
|
1449 |
+
""",
|
1450 |
+
BEIT_START_DOCSTRING,
|
1451 |
+
)
|
1452 |
+
class BeitForSemanticSegmentation(BeitPreTrainedModel):
|
1453 |
+
def __init__(self, config):
|
1454 |
+
super().__init__(config)
|
1455 |
+
|
1456 |
+
self.num_labels = config.num_labels
|
1457 |
+
self.beit = BeitModel(config, add_pooling_layer=False)
|
1458 |
+
|
1459 |
+
# FPNs
|
1460 |
+
self.fpn1 = nn.Sequential(
|
1461 |
+
nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
|
1462 |
+
nn.BatchNorm2d(config.hidden_size),
|
1463 |
+
nn.GELU(),
|
1464 |
+
nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
|
1465 |
+
)
|
1466 |
+
self.fpn2 = nn.Sequential(
|
1467 |
+
nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
|
1468 |
+
)
|
1469 |
+
self.fpn3 = nn.Identity()
|
1470 |
+
self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
|
1471 |
+
|
1472 |
+
# Semantic segmentation head(s)
|
1473 |
+
self.decode_head = BeitUperHead(config)
|
1474 |
+
self.auxiliary_head = BeitFCNHead(config) if config.use_auxiliary_head else None
|
1475 |
+
|
1476 |
+
# Initialize weights and apply final processing
|
1477 |
+
self.post_init()
|
1478 |
+
|
1479 |
+
def compute_loss(self, logits, auxiliary_logits, labels):
|
1480 |
+
# upsample logits to the images' original size
|
1481 |
+
upsampled_logits = nn.functional.interpolate(
|
1482 |
+
logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
|
1483 |
+
)
|
1484 |
+
if auxiliary_logits is not None:
|
1485 |
+
upsampled_auxiliary_logits = nn.functional.interpolate(
|
1486 |
+
auxiliary_logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
|
1487 |
+
)
|
1488 |
+
# compute weighted loss
|
1489 |
+
loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
|
1490 |
+
main_loss = loss_fct(upsampled_logits, labels)
|
1491 |
+
auxiliary_loss = loss_fct(upsampled_auxiliary_logits, labels)
|
1492 |
+
loss = main_loss + self.config.auxiliary_loss_weight * auxiliary_loss
|
1493 |
+
|
1494 |
+
return loss
|
1495 |
+
|
1496 |
+
@add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING)
|
1497 |
+
@replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
|
1498 |
+
def forward(
|
1499 |
+
self,
|
1500 |
+
pixel_values=None,
|
1501 |
+
head_mask=None,
|
1502 |
+
labels=None,
|
1503 |
+
output_attentions=None,
|
1504 |
+
output_hidden_states=None,
|
1505 |
+
return_dict=None,
|
1506 |
+
):
|
1507 |
+
r"""
|
1508 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, height, width)`, `optional`):
|
1509 |
+
Ground truth semantic segmentation maps for computing the loss. Indices should be in :obj:`[0, ...,
|
1510 |
+
config.num_labels - 1]`. If :obj:`config.num_labels > 1`, a classification loss is computed
|
1511 |
+
(Cross-Entropy).
|
1512 |
+
|
1513 |
+
Returns:
|
1514 |
+
|
1515 |
+
Examples::
|
1516 |
+
|
1517 |
+
>>> from transformers import BeitFeatureExtractor, BeitForSemanticSegmentation
|
1518 |
+
>>> from PIL import Image
|
1519 |
+
>>> import requests
|
1520 |
+
|
1521 |
+
>>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
|
1522 |
+
>>> image = Image.open(requests.get(url, stream=True).raw)
|
1523 |
+
|
1524 |
+
>>> feature_extractor = BeitFeatureExtractor.from_pretrained('microsoft/beit-base-finetuned-ade-640-640')
|
1525 |
+
>>> model = BeitForSemanticSegmentation.from_pretrained('microsoft/beit-base-finetuned-ade-640-640')
|
1526 |
+
|
1527 |
+
>>> inputs = feature_extractor(images=image, return_tensors="pt")
|
1528 |
+
>>> outputs = model(**inputs)
|
1529 |
+
>>> # logits are of shape (batch_size, num_labels, height/4, width/4)
|
1530 |
+
>>> logits = outputs.logits
|
1531 |
+
"""
|
1532 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1533 |
+
output_hidden_states = (
|
1534 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
1535 |
+
)
|
1536 |
+
|
1537 |
+
outputs = self.beit(
|
1538 |
+
pixel_values,
|
1539 |
+
head_mask=head_mask,
|
1540 |
+
output_attentions=output_attentions,
|
1541 |
+
output_hidden_states=True, # we need the intermediate hidden states
|
1542 |
+
return_dict=return_dict,
|
1543 |
+
)
|
1544 |
+
|
1545 |
+
encoder_hidden_states = outputs.hidden_states if return_dict else outputs[2]
|
1546 |
+
|
1547 |
+
# only keep certain features, and reshape
|
1548 |
+
# note that we do +1 as the encoder_hidden_states also includes the initial embeddings
|
1549 |
+
features = [feature for idx, feature in enumerate(encoder_hidden_states) if idx + 1 in self.config.out_indices]
|
1550 |
+
batch_size = pixel_values.shape[0]
|
1551 |
+
patch_resolution = self.config.image_size // self.config.patch_size
|
1552 |
+
features = [
|
1553 |
+
x[:, 1:, :].permute(0, 2, 1).reshape(batch_size, -1, patch_resolution, patch_resolution) for x in features
|
1554 |
+
]
|
1555 |
+
|
1556 |
+
# apply FPNs
|
1557 |
+
ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
|
1558 |
+
for i in range(len(features)):
|
1559 |
+
features[i] = ops[i](features[i])
|
1560 |
+
|
1561 |
+
logits = self.decode_head(features)
|
1562 |
+
auxiliary_logits = None
|
1563 |
+
if self.auxiliary_head is not None:
|
1564 |
+
auxiliary_logits = self.auxiliary_head(features)
|
1565 |
+
|
1566 |
+
loss = None
|
1567 |
+
if labels is not None:
|
1568 |
+
if self.config.num_labels == 1:
|
1569 |
+
raise ValueError("The number of labels should be greater than one")
|
1570 |
+
else:
|
1571 |
+
loss = self.compute_loss(logits, auxiliary_logits, labels)
|
1572 |
+
|
1573 |
+
if not return_dict:
|
1574 |
+
if output_hidden_states:
|
1575 |
+
output = (logits,) + outputs[2:]
|
1576 |
+
else:
|
1577 |
+
output = (logits,) + outputs[3:]
|
1578 |
+
return ((loss,) + output) if loss is not None else output
|
1579 |
+
|
1580 |
+
return SequenceClassifierOutput(
|
1581 |
+
loss=loss,
|
1582 |
+
logits=logits,
|
1583 |
+
hidden_states=outputs.hidden_states if output_hidden_states else None,
|
1584 |
+
attentions=outputs.attentions,
|
1585 |
+
)
|
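A minimal sketch of the feature preparation in BeitForSemanticSegmentation.forward above, assuming image_size=224, patch_size=16 and hidden_size=768 (a 14x14 patch grid); the fpn1-fpn4 branches then produce the stride-4/8/16/32 maps consumed by the UPerNet decoder.

import torch
import torch.nn as nn

batch, hidden, grid = 2, 768, 224 // 16
hidden_state = torch.randn(batch, 1 + grid * grid, hidden)        # [B, 1 + N, C] from the encoder

feat = hidden_state[:, 1:, :]                                      # drop the CLS token
feat = feat.permute(0, 2, 1).reshape(batch, hidden, grid, grid)    # [B, C, 14, 14]

fpn1 = nn.Sequential(                                              # 4x upsample -> stride 4
    nn.ConvTranspose2d(hidden, hidden, kernel_size=2, stride=2),
    nn.BatchNorm2d(hidden),
    nn.GELU(),
    nn.ConvTranspose2d(hidden, hidden, kernel_size=2, stride=2),
)
fpn2 = nn.ConvTranspose2d(hidden, hidden, kernel_size=2, stride=2)  # 2x upsample -> stride 8
fpn3 = nn.Identity()                                                # stride 16
fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)                        # 2x downsample -> stride 32

for op in (fpn1, fpn2, fpn3, fpn4):
    print(op(feat).shape)   # 56x56, 28x28, 14x14, 7x7 feature maps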
svitt/sparse_xbert.py
ADDED
@@ -0,0 +1,2039 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""PyTorch BERT model. """
|
17 |
+
|
18 |
+
import math
|
19 |
+
import os
|
20 |
+
import warnings
|
21 |
+
from dataclasses import dataclass
|
22 |
+
from typing import Optional, Tuple
|
23 |
+
|
24 |
+
import torch
|
25 |
+
from torch import Tensor, device, nn
|
26 |
+
import torch.utils.checkpoint
|
27 |
+
from torch import nn
|
28 |
+
from torch.nn import CrossEntropyLoss, MSELoss
|
29 |
+
import torch.nn.functional as F
|
30 |
+
|
31 |
+
from transformers.activations import ACT2FN
|
32 |
+
from transformers.file_utils import (
|
33 |
+
ModelOutput,
|
34 |
+
add_start_docstrings,
|
35 |
+
add_start_docstrings_to_model_forward,
|
36 |
+
replace_return_docstrings,
|
37 |
+
)
|
38 |
+
from transformers.modeling_outputs import (
|
39 |
+
BaseModelOutputWithPastAndCrossAttentions,
|
40 |
+
BaseModelOutputWithPoolingAndCrossAttentions,
|
41 |
+
CausalLMOutputWithCrossAttentions,
|
42 |
+
MaskedLMOutput,
|
43 |
+
MultipleChoiceModelOutput,
|
44 |
+
NextSentencePredictorOutput,
|
45 |
+
QuestionAnsweringModelOutput,
|
46 |
+
SequenceClassifierOutput,
|
47 |
+
TokenClassifierOutput,
|
48 |
+
)
|
49 |
+
from transformers.modeling_utils import (
|
50 |
+
PreTrainedModel,
|
51 |
+
apply_chunking_to_forward,
|
52 |
+
find_pruneable_heads_and_indices,
|
53 |
+
prune_linear_layer,
|
54 |
+
)
|
55 |
+
from svitt.sparse_config import BertConfig
|
56 |
+
|
57 |
+
import transformers
|
58 |
+
transformers.logging.set_verbosity_error()
|
59 |
+
|
60 |
+
|
61 |
+
_CONFIG_FOR_DOC = "BertConfig"
|
62 |
+
_TOKENIZER_FOR_DOC = "BertTokenizer"
|
63 |
+
|
64 |
+
BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
65 |
+
"bert-base-uncased",
|
66 |
+
"bert-large-uncased",
|
67 |
+
"bert-base-cased",
|
68 |
+
"bert-large-cased",
|
69 |
+
"bert-base-multilingual-uncased",
|
70 |
+
"bert-base-multilingual-cased",
|
71 |
+
"bert-base-chinese",
|
72 |
+
"bert-base-german-cased",
|
73 |
+
"bert-large-uncased-whole-word-masking",
|
74 |
+
"bert-large-cased-whole-word-masking",
|
75 |
+
"bert-large-uncased-whole-word-masking-finetuned-squad",
|
76 |
+
"bert-large-cased-whole-word-masking-finetuned-squad",
|
77 |
+
"bert-base-cased-finetuned-mrpc",
|
78 |
+
"bert-base-german-dbmdz-cased",
|
79 |
+
"bert-base-german-dbmdz-uncased",
|
80 |
+
"cl-tohoku/bert-base-japanese",
|
81 |
+
"cl-tohoku/bert-base-japanese-whole-word-masking",
|
82 |
+
"cl-tohoku/bert-base-japanese-char",
|
83 |
+
"cl-tohoku/bert-base-japanese-char-whole-word-masking",
|
84 |
+
"TurkuNLP/bert-base-finnish-cased-v1",
|
85 |
+
"TurkuNLP/bert-base-finnish-uncased-v1",
|
86 |
+
"wietsedv/bert-base-dutch-cased",
|
87 |
+
# See all BERT models at https://huggingface.co/models?filter=bert
|
88 |
+
]
|
89 |
+
|
90 |
+
|
91 |
+
@dataclass
|
92 |
+
class BertModelOutputWithPastAndCrossAttentions(BaseModelOutputWithPastAndCrossAttentions):
|
93 |
+
token_idx: Optional[Tuple[torch.LongTensor]] = None
|
94 |
+
|
95 |
+
|
96 |
+
@dataclass
|
97 |
+
class BertModelOutputWithPoolingAndCrossAttentions(BaseModelOutputWithPoolingAndCrossAttentions):
|
98 |
+
token_idx: Optional[Tuple[torch.LongTensor]] = None
|
99 |
+
|
100 |
+
|
101 |
+
def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
|
102 |
+
"""Load tf checkpoints in a pytorch model."""
|
103 |
+
try:
|
104 |
+
import re
|
105 |
+
|
106 |
+
import numpy as np
|
107 |
+
import tensorflow as tf
|
108 |
+
except ImportError:
|
109 |
+
print(
|
110 |
+
"Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
|
111 |
+
"https://www.tensorflow.org/install/ for installation instructions."
|
112 |
+
)
|
113 |
+
raise
|
114 |
+
tf_path = os.path.abspath(tf_checkpoint_path)
|
115 |
+
# Load weights from TF model
|
116 |
+
init_vars = tf.train.list_variables(tf_path)
|
117 |
+
names = []
|
118 |
+
arrays = []
|
119 |
+
for name, shape in init_vars:
|
120 |
+
array = tf.train.load_variable(tf_path, name)
|
121 |
+
names.append(name)
|
122 |
+
arrays.append(array)
|
123 |
+
|
124 |
+
for name, array in zip(names, arrays):
|
125 |
+
name = name.split("/")
|
126 |
+
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
|
127 |
+
# which are not required for using pretrained model
|
128 |
+
if any(
|
129 |
+
n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer",
|
130 |
+
"AdamWeightDecayOptimizer_1", "global_step"]
|
131 |
+
for n in name
|
132 |
+
):
|
133 |
+
continue
|
134 |
+
pointer = model
|
135 |
+
for m_name in name:
|
136 |
+
if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
|
137 |
+
scope_names = re.split(r"_(\d+)", m_name)
|
138 |
+
else:
|
139 |
+
scope_names = [m_name]
|
140 |
+
if scope_names[0] == "kernel" or scope_names[0] == "gamma":
|
141 |
+
pointer = getattr(pointer, "weight")
|
142 |
+
elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
|
143 |
+
pointer = getattr(pointer, "bias")
|
144 |
+
elif scope_names[0] == "output_weights":
|
145 |
+
pointer = getattr(pointer, "weight")
|
146 |
+
elif scope_names[0] == "squad":
|
147 |
+
pointer = getattr(pointer, "classifier")
|
148 |
+
else:
|
149 |
+
try:
|
150 |
+
pointer = getattr(pointer, scope_names[0])
|
151 |
+
except AttributeError:
|
152 |
+
continue
|
153 |
+
if len(scope_names) >= 2:
|
154 |
+
num = int(scope_names[1])
|
155 |
+
pointer = pointer[num]
|
156 |
+
if m_name[-11:] == "_embeddings":
|
157 |
+
pointer = getattr(pointer, "weight")
|
158 |
+
elif m_name == "kernel":
|
159 |
+
array = np.transpose(array)
|
160 |
+
try:
|
161 |
+
assert (
|
162 |
+
pointer.shape == array.shape
|
163 |
+
), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
|
164 |
+
except AssertionError as e:
|
165 |
+
e.args += (pointer.shape, array.shape)
|
166 |
+
raise
|
167 |
+
pointer.data = torch.from_numpy(array)
|
168 |
+
return model
|
169 |
+
|
170 |
+
|
171 |
+
class BertEmbeddings(nn.Module):
|
172 |
+
"""Construct the embeddings from word, position and token_type embeddings."""
|
173 |
+
|
174 |
+
def __init__(self, config):
|
175 |
+
super().__init__()
|
176 |
+
self.word_embeddings = nn.Embedding(
|
177 |
+
config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
178 |
+
self.position_embeddings = nn.Embedding(
|
179 |
+
config.max_position_embeddings, config.hidden_size)
|
180 |
+
self.token_type_embeddings = nn.Embedding(
|
181 |
+
config.type_vocab_size, config.hidden_size)
|
182 |
+
|
183 |
+
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
184 |
+
# any TensorFlow checkpoint file
|
185 |
+
self.LayerNorm = nn.LayerNorm(
|
186 |
+
config.hidden_size, eps=config.layer_norm_eps)
|
187 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
188 |
+
|
189 |
+
# position_ids (1, len position emb) is contiguous in memory and exported when serialized
|
190 |
+
self.register_buffer("position_ids", torch.arange(
|
191 |
+
config.max_position_embeddings).expand((1, -1)))
|
192 |
+
self.position_embedding_type = getattr(
|
193 |
+
config, "position_embedding_type", "absolute")
|
194 |
+
|
195 |
+
self.config = config
|
196 |
+
|
197 |
+
def forward(
|
198 |
+
self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
|
199 |
+
):
|
200 |
+
if input_ids is not None:
|
201 |
+
input_shape = input_ids.size()
|
202 |
+
else:
|
203 |
+
input_shape = inputs_embeds.size()[:-1]
|
204 |
+
|
205 |
+
seq_length = input_shape[1]
|
206 |
+
|
207 |
+
if position_ids is None:
|
208 |
+
position_ids = self.position_ids[:,
|
209 |
+
past_key_values_length: seq_length + past_key_values_length]
|
210 |
+
|
211 |
+
if token_type_ids is None:
|
212 |
+
token_type_ids = torch.zeros(
|
213 |
+
input_shape, dtype=torch.long, device=self.position_ids.device)
|
214 |
+
|
215 |
+
if inputs_embeds is None:
|
216 |
+
inputs_embeds = self.word_embeddings(input_ids)
|
217 |
+
|
218 |
+
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
219 |
+
|
220 |
+
embeddings = inputs_embeds + token_type_embeddings
|
221 |
+
if self.position_embedding_type == "absolute":
|
222 |
+
position_embeddings = self.position_embeddings(position_ids)
|
223 |
+
embeddings += position_embeddings
|
224 |
+
embeddings = self.LayerNorm(embeddings)
|
225 |
+
embeddings = self.dropout(embeddings)
|
226 |
+
return embeddings
|
227 |
+
|
228 |
+
|
229 |
+
class BertSelfAttention(nn.Module):
|
230 |
+
def __init__(self, config, is_cross_attention):
|
231 |
+
super().__init__()
|
232 |
+
self.config = config
|
233 |
+
if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
|
234 |
+
raise ValueError(
|
235 |
+
"The hidden size (%d) is not a multiple of the number of attention "
|
236 |
+
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
|
237 |
+
)
|
238 |
+
|
239 |
+
self.num_attention_heads = config.num_attention_heads
|
240 |
+
self.attention_head_size = int(
|
241 |
+
config.hidden_size / config.num_attention_heads)
|
242 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
243 |
+
|
244 |
+
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
245 |
+
if is_cross_attention:
|
246 |
+
self.key = nn.Linear(config.encoder_width, self.all_head_size)
|
247 |
+
self.value = nn.Linear(config.encoder_width, self.all_head_size)
|
248 |
+
else:
|
249 |
+
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
250 |
+
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
251 |
+
|
252 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
253 |
+
self.position_embedding_type = getattr(
|
254 |
+
config, "position_embedding_type", "absolute")
|
255 |
+
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
256 |
+
self.max_position_embeddings = config.max_position_embeddings
|
257 |
+
self.distance_embedding = nn.Embedding(
|
258 |
+
2 * config.max_position_embeddings - 1, self.attention_head_size)
|
259 |
+
self.save_attention = False
|
260 |
+
|
261 |
+
def save_attn_gradients(self, attn_gradients):
|
262 |
+
self.attn_gradients = attn_gradients
|
263 |
+
|
264 |
+
def get_attn_gradients(self):
|
265 |
+
return self.attn_gradients
|
266 |
+
|
267 |
+
def save_attention_map(self, attention_map):
|
268 |
+
self.attention_map = attention_map
|
269 |
+
|
270 |
+
def get_attention_map(self):
|
271 |
+
return self.attention_map
|
272 |
+
|
273 |
+
def transpose_for_scores(self, x):
|
274 |
+
new_x_shape = x.size()[
|
275 |
+
:-1] + (self.num_attention_heads, self.attention_head_size)
|
276 |
+
x = x.view(*new_x_shape)
|
277 |
+
return x.permute(0, 2, 1, 3)
|
278 |
+
|
279 |
+
def forward(
|
280 |
+
self,
|
281 |
+
hidden_states,
|
282 |
+
attention_mask=None,
|
283 |
+
head_mask=None,
|
284 |
+
encoder_hidden_states=None,
|
285 |
+
encoder_attention_mask=None,
|
286 |
+
past_key_value=None,
|
287 |
+
output_attentions=False,
|
288 |
+
):
|
289 |
+
mixed_query_layer = self.query(hidden_states)
|
290 |
+
|
291 |
+
# If this is instantiated as a cross-attention module, the keys
|
292 |
+
# and values come from an encoder; the attention mask needs to be
|
293 |
+
# such that the encoder's padding tokens are not attended to.
|
294 |
+
is_cross_attention = encoder_hidden_states is not None
|
295 |
+
|
296 |
+
if is_cross_attention:
|
297 |
+
key_layer = self.transpose_for_scores(
|
298 |
+
self.key(encoder_hidden_states))
|
299 |
+
value_layer = self.transpose_for_scores(
|
300 |
+
self.value(encoder_hidden_states))
|
301 |
+
attention_mask = encoder_attention_mask
|
302 |
+
elif past_key_value is not None:
|
303 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
304 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
305 |
+
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
|
306 |
+
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
|
307 |
+
else:
|
308 |
+
key_layer = self.transpose_for_scores(self.key(hidden_states))
|
309 |
+
value_layer = self.transpose_for_scores(self.value(hidden_states))
|
310 |
+
|
311 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
312 |
+
|
313 |
+
past_key_value = (key_layer, value_layer)
|
314 |
+
|
315 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
316 |
+
attention_scores = torch.matmul(
|
317 |
+
query_layer, key_layer.transpose(-1, -2))
|
318 |
+
|
319 |
+
if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
|
320 |
+
seq_length = hidden_states.size()[1]
|
321 |
+
position_ids_l = torch.arange(
|
322 |
+
seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
|
323 |
+
position_ids_r = torch.arange(
|
324 |
+
seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
|
325 |
+
distance = position_ids_l - position_ids_r
|
326 |
+
positional_embedding = self.distance_embedding(
|
327 |
+
distance + self.max_position_embeddings - 1)
|
328 |
+
positional_embedding = positional_embedding.to(
|
329 |
+
dtype=query_layer.dtype) # fp16 compatibility
|
330 |
+
|
331 |
+
if self.position_embedding_type == "relative_key":
|
332 |
+
relative_position_scores = torch.einsum(
|
333 |
+
"bhld,lrd->bhlr", query_layer, positional_embedding)
|
334 |
+
attention_scores = attention_scores + relative_position_scores
|
335 |
+
elif self.position_embedding_type == "relative_key_query":
|
336 |
+
relative_position_scores_query = torch.einsum(
|
337 |
+
"bhld,lrd->bhlr", query_layer, positional_embedding)
|
338 |
+
relative_position_scores_key = torch.einsum(
|
339 |
+
"bhrd,lrd->bhlr", key_layer, positional_embedding)
|
340 |
+
attention_scores = attention_scores + \
|
341 |
+
relative_position_scores_query + relative_position_scores_key
|
342 |
+
|
343 |
+
attention_scores = attention_scores / \
|
344 |
+
math.sqrt(self.attention_head_size)
|
345 |
+
if attention_mask is not None:
|
346 |
+
# Apply the attention mask (precomputed for all layers in BertModel's forward() function)
|
347 |
+
attention_scores = attention_scores + attention_mask
|
348 |
+
|
349 |
+
# Normalize the attention scores to probabilities.
|
350 |
+
attention_probs = nn.Softmax(dim=-1)(attention_scores)
|
351 |
+
|
352 |
+
if is_cross_attention and self.save_attention:
|
353 |
+
self.save_attention_map(attention_probs)
|
354 |
+
attention_probs.register_hook(self.save_attn_gradients)
|
355 |
+
|
356 |
+
# This is actually dropping out entire tokens to attend to, which might
|
357 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
358 |
+
attention_probs_dropped = self.dropout(attention_probs)
|
359 |
+
|
360 |
+
# Mask heads if we want to
|
361 |
+
if head_mask is not None:
|
362 |
+
attention_probs_dropped = attention_probs_dropped * head_mask
|
363 |
+
|
364 |
+
context_layer = torch.matmul(attention_probs_dropped, value_layer)
|
365 |
+
|
366 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
367 |
+
new_context_layer_shape = context_layer.size()[
|
368 |
+
:-2] + (self.all_head_size,)
|
369 |
+
context_layer = context_layer.view(*new_context_layer_shape)
|
370 |
+
|
371 |
+
# added `attention_scores` to return tuple
|
372 |
+
outputs = (context_layer, attention_probs, attention_scores) if output_attentions else (
|
373 |
+
context_layer,)
|
374 |
+
|
375 |
+
outputs = outputs + (past_key_value,)
|
376 |
+
return outputs
|
377 |
+
|
378 |
+
|
379 |
+
class BertSelfOutput(nn.Module):
|
380 |
+
def __init__(self, config):
|
381 |
+
super().__init__()
|
382 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
383 |
+
self.LayerNorm = nn.LayerNorm(
|
384 |
+
config.hidden_size, eps=config.layer_norm_eps)
|
385 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
386 |
+
|
387 |
+
def forward(self, hidden_states, input_tensor):
|
388 |
+
hidden_states = self.dense(hidden_states)
|
389 |
+
hidden_states = self.dropout(hidden_states)
|
390 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
391 |
+
return hidden_states
|
392 |
+
|
393 |
+
|
394 |
+
class BertAttention(nn.Module):
|
395 |
+
def __init__(self, config, is_cross_attention=False):
|
396 |
+
super().__init__()
|
397 |
+
self.self = BertSelfAttention(config, is_cross_attention)
|
398 |
+
self.output = BertSelfOutput(config)
|
399 |
+
self.pruned_heads = set()
|
400 |
+
|
401 |
+
def prune_heads(self, heads):
|
402 |
+
if len(heads) == 0:
|
403 |
+
return
|
404 |
+
heads, index = find_pruneable_heads_and_indices(
|
405 |
+
heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
|
406 |
+
)
|
407 |
+
|
408 |
+
# Prune linear layers
|
409 |
+
self.self.query = prune_linear_layer(self.self.query, index)
|
410 |
+
self.self.key = prune_linear_layer(self.self.key, index)
|
411 |
+
self.self.value = prune_linear_layer(self.self.value, index)
|
412 |
+
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
413 |
+
|
414 |
+
# Update hyper params and store pruned heads
|
415 |
+
self.self.num_attention_heads = self.self.num_attention_heads - \
|
416 |
+
len(heads)
|
417 |
+
self.self.all_head_size = self.self.attention_head_size * \
|
418 |
+
self.self.num_attention_heads
|
419 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
420 |
+
|
421 |
+
def forward(
|
422 |
+
self,
|
423 |
+
hidden_states,
|
424 |
+
attention_mask=None,
|
425 |
+
head_mask=None,
|
426 |
+
encoder_hidden_states=None,
|
427 |
+
encoder_attention_mask=None,
|
428 |
+
past_key_value=None,
|
429 |
+
output_attentions=False,
|
430 |
+
):
|
431 |
+
self_outputs = self.self(
|
432 |
+
hidden_states,
|
433 |
+
attention_mask,
|
434 |
+
head_mask,
|
435 |
+
encoder_hidden_states,
|
436 |
+
encoder_attention_mask,
|
437 |
+
past_key_value,
|
438 |
+
output_attentions,
|
439 |
+
)
|
440 |
+
attention_output = self.output(self_outputs[0], hidden_states)
|
441 |
+
# add attentions if we output them
|
442 |
+
outputs = (attention_output,) + self_outputs[1:]
|
443 |
+
return outputs # (context_layer, attention_probs, attention_scores, past_key_value,)
|
444 |
+
|
445 |
+
|
446 |
+
class BertIntermediate(nn.Module):
|
447 |
+
def __init__(self, config):
|
448 |
+
super().__init__()
|
449 |
+
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
450 |
+
if isinstance(config.hidden_act, str):
|
451 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
452 |
+
else:
|
453 |
+
self.intermediate_act_fn = config.hidden_act
|
454 |
+
|
455 |
+
def forward(self, hidden_states):
|
456 |
+
hidden_states = self.dense(hidden_states)
|
457 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
458 |
+
return hidden_states
|
459 |
+
|
460 |
+
|
461 |
+
class BertOutput(nn.Module):
|
462 |
+
def __init__(self, config):
|
463 |
+
super().__init__()
|
464 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
465 |
+
self.LayerNorm = nn.LayerNorm(
|
466 |
+
config.hidden_size, eps=config.layer_norm_eps)
|
467 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
468 |
+
|
469 |
+
def forward(self, hidden_states, input_tensor):
|
470 |
+
hidden_states = self.dense(hidden_states)
|
471 |
+
hidden_states = self.dropout(hidden_states)
|
472 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
473 |
+
return hidden_states
|
474 |
+
|
475 |
+
|
476 |
+
class BertLayer(nn.Module):
|
477 |
+
def __init__(self, config, layer_num, token_keep_rate=1.0):
|
478 |
+
super().__init__()
|
479 |
+
self.config = config
|
480 |
+
self.chunk_size_feed_forward = config.chunk_size_feed_forward
|
481 |
+
self.seq_len_dim = 1
|
482 |
+
self.attention = BertAttention(config)
|
483 |
+
|
484 |
+
self.has_cross_attention = (layer_num >= config.fusion_layer)
|
485 |
+
if self.has_cross_attention:
|
486 |
+
self.layer_num = layer_num
|
487 |
+
self.crossattention = BertAttention(
|
488 |
+
config, is_cross_attention=True)
|
489 |
+
|
490 |
+
# sparse params
|
491 |
+
self.token_keep_rate = token_keep_rate
|
492 |
+
self.token_keep_strategy = config.token_keep_strategy
|
493 |
+
self.encoder_num_cls_tokens = 1 # multiple cls tokens
|
494 |
+
|
495 |
+
self.intermediate = BertIntermediate(config)
|
496 |
+
self.output = BertOutput(config)
|
497 |
+
|
498 |
+
def sparsify(self, x, attn, mask=None):
|
499 |
+
x_cls, x_ = x[:, :self.encoder_num_cls_tokens], x[:, self.encoder_num_cls_tokens:]
|
500 |
+
assert 0 < self.token_keep_rate <= 1, "Expected keep rate in range (0, 1]"
|
501 |
+
left_tokens = math.ceil(self.token_keep_rate * x_.size(1))
|
502 |
+
if len(attn.shape) == 4:
|
503 |
+
attn = attn.mean(1) # pool over attention heads
|
504 |
+
|
505 |
+
if self.token_keep_strategy == 'cls_attn':
|
506 |
+
cls_attn = attn[:, 0, self.encoder_num_cls_tokens:]
|
507 |
+
_, idx = torch.topk(cls_attn, left_tokens, dim=1) # [B, left_tokens]
|
508 |
+
|
509 |
+
elif self.token_keep_strategy == 'avg_attn':
|
510 |
+
avg_attn = attn.mean(1)[:, self.encoder_num_cls_tokens:]
|
511 |
+
_, idx = torch.topk(avg_attn, left_tokens, dim=1) # [B, left_tokens]
|
512 |
+
|
513 |
+
elif self.token_keep_strategy == 'random':
|
514 |
+
rand = torch.rand(x_.shape[:2], device=x_.device)
|
515 |
+
_, idx = torch.topk(rand, left_tokens, dim=1) # [B, left_tokens]
|
516 |
+
|
517 |
+
else:
|
518 |
+
raise NotImplementedError(f"Sparse strategy {self.token_keep_strategy} is not implemented")
|
519 |
+
|
520 |
+
idx, _ = torch.sort(idx, dim=1)
|
521 |
+
index = idx.unsqueeze(-1).expand(-1, -1, x_.size(-1)) # [B, left_tokens, C]
|
522 |
+
outputs = torch.cat((x_cls, x_.gather(1, index)), dim=1).contiguous()
|
523 |
+
if mask is not None:
|
524 |
+
mask_cls, mask_ = mask[..., :self.encoder_num_cls_tokens], mask[..., self.encoder_num_cls_tokens:]
|
525 |
+
index = idx.unsqueeze(1).unsqueeze(1) # [B, 1, 1, left_tokens]
|
526 |
+
mask = torch.cat((mask_cls, mask_.gather(-1, index)), dim=-1).contiguous()
|
527 |
+
return outputs, mask, idx
|
528 |
+
|
529 |
+
def forward(
|
530 |
+
self,
|
531 |
+
hidden_states,
|
532 |
+
attention_mask=None,
|
533 |
+
head_mask=None,
|
534 |
+
encoder_hidden_states=None,
|
535 |
+
encoder_attention_mask=None,
|
536 |
+
past_key_value=None,
|
537 |
+
output_attentions=False,
|
538 |
+
):
|
539 |
+
# the decoder's uni-directional self-attention cached key/value tuple is at positions 1 and 2
|
540 |
+
self_attn_past_key_value = past_key_value[:
|
541 |
+
2] if past_key_value is not None else None
|
542 |
+
self_attention_outputs = self.attention(
|
543 |
+
hidden_states,
|
544 |
+
attention_mask,
|
545 |
+
head_mask,
|
546 |
+
output_attentions=output_attentions,
|
547 |
+
past_key_value=self_attn_past_key_value,
|
548 |
+
) # (context_layer, attention_probs, attention_scores, past_key_value,)
|
549 |
+
attention_output = self_attention_outputs[0]
|
550 |
+
|
551 |
+
outputs = self_attention_outputs[1:-1]
|
552 |
+
present_key_value = self_attention_outputs[-1]
|
553 |
+
|
554 |
+
if self.has_cross_attention:
|
555 |
+
assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
|
556 |
+
output_attentions = (output_attentions or self.token_keep_rate < 1)
|
557 |
+
|
558 |
+
if type(encoder_hidden_states) == list:
|
559 |
+
cross_attention_outputs = self.crossattention(
|
560 |
+
attention_output,
|
561 |
+
attention_mask,
|
562 |
+
head_mask,
|
563 |
+
encoder_hidden_states[(
|
564 |
+
self.layer_num-self.config.fusion_layer) % len(encoder_hidden_states)],
|
565 |
+
encoder_attention_mask[(
|
566 |
+
self.layer_num-self.config.fusion_layer) % len(encoder_hidden_states)],
|
567 |
+
output_attentions=output_attentions,
|
568 |
+
)
|
569 |
+
attention_output = cross_attention_outputs[0]
|
570 |
+
outputs = outputs + cross_attention_outputs[1:-1]
|
571 |
+
|
572 |
+
else:
|
573 |
+
cross_attention_outputs = self.crossattention(
|
574 |
+
attention_output,
|
575 |
+
attention_mask,
|
576 |
+
head_mask,
|
577 |
+
encoder_hidden_states,
|
578 |
+
encoder_attention_mask,
|
579 |
+
output_attentions=output_attentions,
|
580 |
+
) # (context_layer, attention_probs, attention_scores, past_key_value,)
|
581 |
+
attention_output = cross_attention_outputs[0]
|
582 |
+
|
583 |
+
# add cross attentions if we output attention weights
|
584 |
+
outputs = outputs + cross_attention_outputs[1:-1]
|
585 |
+
|
586 |
+
# node sparsification
|
587 |
+
if self.token_keep_rate < 1:
|
588 |
+
encoder_hidden_states, encoder_attention_mask, token_keep_idx = self.sparsify(
|
589 |
+
encoder_hidden_states, cross_attention_outputs[1], encoder_attention_mask)
|
590 |
+
outputs = outputs + (encoder_hidden_states, encoder_attention_mask, token_keep_idx)
|
591 |
+
|
592 |
+
layer_output = apply_chunking_to_forward(
|
593 |
+
self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
|
594 |
+
)
|
595 |
+
outputs = (layer_output,) + outputs
|
596 |
+
|
597 |
+
outputs = outputs + (present_key_value,)
|
598 |
+
|
599 |
+
return outputs
|
600 |
+
|
601 |
+
def feed_forward_chunk(self, attention_output):
|
602 |
+
intermediate_output = self.intermediate(attention_output)
|
603 |
+
layer_output = self.output(intermediate_output, attention_output)
|
604 |
+
return layer_output
|
605 |
+
|
606 |
+
|
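A minimal sketch of the 'cls_attn' keep strategy implemented by BertLayer.sparsify above, with hypothetical sizes (one CLS token plus 196 video tokens, keep rate 0.5): the visual tokens that receive the most cross-attention from the first text query token are kept, in their original order.

import math
import torch

B, num_cls, num_patches, C = 2, 1, 196, 768
keep_rate = 0.5
x = torch.randn(B, num_cls + num_patches, C)        # visual sequence (encoder_hidden_states)
attn = torch.rand(B, 12, 30, num_cls + num_patches) # cross-attention probs [B, heads, text_len, vis_len]

attn = attn.mean(1)                                 # pool over heads -> [B, text_len, vis_len]
cls_attn = attn[:, 0, num_cls:]                     # first text token's attention over patch tokens
left_tokens = math.ceil(keep_rate * num_patches)
_, idx = torch.topk(cls_attn, left_tokens, dim=1)   # [B, left_tokens]
idx, _ = torch.sort(idx, dim=1)                     # keep the surviving tokens in order

index = idx.unsqueeze(-1).expand(-1, -1, C)
pruned = torch.cat((x[:, :num_cls], x[:, num_cls:].gather(1, index)), dim=1)
print(pruned.shape)                                 # torch.Size([2, 99, 768])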
607 |
+
class BertEncoder(nn.Module):
|
608 |
+
def __init__(self, config):
|
609 |
+
super().__init__()
|
610 |
+
self.config = config
|
611 |
+
|
612 |
+
# node sparsification
|
613 |
+
token_keep_rate = [1] * config.num_hidden_layers
|
614 |
+
for loc in config.token_drop_loc:
|
615 |
+
token_keep_rate[loc] = config.token_keep_rate
|
616 |
+
|
617 |
+
self.layer = nn.ModuleList([BertLayer(config, i, token_keep_rate[i])
|
618 |
+
for i in range(config.num_hidden_layers)])
|
619 |
+
|
620 |
+
def forward(
|
621 |
+
self,
|
622 |
+
hidden_states,
|
623 |
+
attention_mask=None,
|
624 |
+
head_mask=None,
|
625 |
+
encoder_hidden_states=None,
|
626 |
+
encoder_attention_mask=None,
|
627 |
+
past_key_values=None,
|
628 |
+
use_cache=None,
|
629 |
+
output_attentions=False,
|
630 |
+
output_hidden_states=False,
|
631 |
+
output_token_idx=False,
|
632 |
+
return_dict=True,
|
633 |
+
mode='multi_modal',
|
634 |
+
normalize_attention=True
|
635 |
+
):
|
636 |
+
all_hidden_states = () if output_hidden_states else None
|
637 |
+
all_self_attentions = () if output_attentions else None
|
638 |
+
all_cross_attentions = () if output_attentions else None
|
639 |
+
all_token_idx = () if output_token_idx else None
|
640 |
+
|
641 |
+
next_decoder_cache = () if use_cache else None
|
642 |
+
|
643 |
+
if mode == 'text':
|
644 |
+
start_layer = 0
|
645 |
+
output_layer = self.config.fusion_layer
|
646 |
+
|
647 |
+
elif mode == 'fusion':
|
648 |
+
start_layer = self.config.fusion_layer
|
649 |
+
output_layer = self.config.num_hidden_layers
|
650 |
+
|
651 |
+
elif mode == 'multi_modal':
|
652 |
+
start_layer = 0
|
653 |
+
output_layer = self.config.num_hidden_layers
|
654 |
+
|
655 |
+
for i in range(start_layer, output_layer):
|
656 |
+
layer_module = self.layer[i]
|
657 |
+
if output_hidden_states:
|
658 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
659 |
+
|
660 |
+
layer_head_mask = head_mask[i] if head_mask is not None else None
|
661 |
+
past_key_value = past_key_values[i] if past_key_values is not None else None
|
662 |
+
|
663 |
+
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
664 |
+
|
665 |
+
if use_cache:
|
666 |
+
use_cache = False
|
667 |
+
|
668 |
+
def create_custom_forward(module):
|
669 |
+
def custom_forward(*inputs):
|
670 |
+
return module(*inputs, past_key_value, output_attentions)
|
671 |
+
|
672 |
+
return custom_forward
|
673 |
+
|
674 |
+
layer_outputs = torch.utils.checkpoint.checkpoint(
|
675 |
+
create_custom_forward(layer_module),
|
676 |
+
hidden_states,
|
677 |
+
attention_mask,
|
678 |
+
layer_head_mask,
|
679 |
+
encoder_hidden_states,
|
680 |
+
encoder_attention_mask,
|
681 |
+
)
|
682 |
+
else:
|
683 |
+
layer_outputs = layer_module(
|
684 |
+
hidden_states,
|
685 |
+
attention_mask,
|
686 |
+
layer_head_mask,
|
687 |
+
encoder_hidden_states,
|
688 |
+
encoder_attention_mask,
|
689 |
+
past_key_value,
|
690 |
+
output_attentions,
|
691 |
+
) # (context_layer, attention_probs, attention_scores, past_key_value,)
|
692 |
+
hidden_states = layer_outputs[0]
|
693 |
+
# update visual sequence
|
694 |
+
if mode == 'fusion' and layer_module.token_keep_rate < 1:
|
695 |
+
encoder_hidden_states, encoder_attention_mask, token_idx = layer_outputs[-4:-1]
|
696 |
+
|
697 |
+
if output_token_idx:
|
698 |
+
all_token_idx = all_token_idx + (token_idx,)
|
699 |
+
|
700 |
+
if use_cache:
|
701 |
+
next_decoder_cache += (layer_outputs[-1],)
|
702 |
+
if output_attentions:
|
703 |
+
# whether to output normalized attention,
|
704 |
+
# note that for unnormalized attention, the attention mask has already been added to the scores
|
705 |
+
offset = int(normalize_attention)
|
706 |
+
all_self_attentions = all_self_attentions + (layer_outputs[2-offset], )
|
707 |
+
if hasattr(layer_module, "crossattention"):
|
708 |
+
all_cross_attentions = all_cross_attentions + (layer_outputs[4-offset], )
|
709 |
+
|
710 |
+
if output_hidden_states:
|
711 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
712 |
+
|
713 |
+
if not return_dict:
|
714 |
+
return tuple(
|
715 |
+
v
|
716 |
+
for v in [
|
717 |
+
hidden_states,
|
718 |
+
next_decoder_cache,
|
719 |
+
all_hidden_states,
|
720 |
+
all_self_attentions,
|
721 |
+
all_cross_attentions,
|
722 |
+
]
|
723 |
+
if v is not None
|
724 |
+
)
|
725 |
+
return BertModelOutputWithPastAndCrossAttentions(
|
726 |
+
last_hidden_state=hidden_states,
|
727 |
+
past_key_values=next_decoder_cache,
|
728 |
+
hidden_states=all_hidden_states,
|
729 |
+
attentions=all_self_attentions,
|
730 |
+
cross_attentions=all_cross_attentions,
|
731 |
+
token_idx=all_token_idx
|
732 |
+
)
|
733 |
+
|
734 |
+
|
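A minimal sketch of how the `mode` argument above selects which layers of the stack are executed; the concrete numbers (12 hidden layers, fusion_layer 9) are assumptions for illustration, the real values come from the BertConfig.

def layer_range(mode, num_hidden_layers=12, fusion_layer=9):
    # mirrors the start_layer/output_layer selection in BertEncoder.forward
    if mode == 'text':
        return 0, fusion_layer                   # text-only self-attention layers
    if mode == 'fusion':
        return fusion_layer, num_hidden_layers   # cross-attention layers over (sparsified) video tokens
    if mode == 'multi_modal':
        return 0, num_hidden_layers              # both stages in a single pass
    raise ValueError(f"unknown mode: {mode}")

print(layer_range('text'))         # (0, 9)
print(layer_range('fusion'))       # (9, 12)
print(layer_range('multi_modal'))  # (0, 12)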
735 |
+
class BertPooler(nn.Module):
|
736 |
+
def __init__(self, config):
|
737 |
+
super().__init__()
|
738 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
739 |
+
self.activation = nn.Tanh()
|
740 |
+
|
741 |
+
def forward(self, hidden_states):
|
742 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
743 |
+
# to the first token.
|
744 |
+
first_token_tensor = hidden_states[:, 0]
|
745 |
+
pooled_output = self.dense(first_token_tensor)
|
746 |
+
pooled_output = self.activation(pooled_output)
|
747 |
+
return pooled_output
|
748 |
+
|
749 |
+
|
750 |
+
class BertPredictionHeadTransform(nn.Module):
|
751 |
+
def __init__(self, config):
|
752 |
+
super().__init__()
|
753 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
754 |
+
if isinstance(config.hidden_act, str):
|
755 |
+
self.transform_act_fn = ACT2FN[config.hidden_act]
|
756 |
+
else:
|
757 |
+
self.transform_act_fn = config.hidden_act
|
758 |
+
self.LayerNorm = nn.LayerNorm(
|
759 |
+
config.hidden_size, eps=config.layer_norm_eps)
|
760 |
+
|
761 |
+
def forward(self, hidden_states):
|
762 |
+
hidden_states = self.dense(hidden_states)
|
763 |
+
hidden_states = self.transform_act_fn(hidden_states)
|
764 |
+
hidden_states = self.LayerNorm(hidden_states)
|
765 |
+
return hidden_states
|
766 |
+
|
767 |
+
|
768 |
+
class BertLMPredictionHead(nn.Module):
|
769 |
+
def __init__(self, config):
|
770 |
+
super().__init__()
|
771 |
+
self.transform = BertPredictionHeadTransform(config)
|
772 |
+
|
773 |
+
# The output weights are the same as the input embeddings, but there is
|
774 |
+
# an output-only bias for each token.
|
775 |
+
self.decoder = nn.Linear(
|
776 |
+
config.hidden_size, config.vocab_size, bias=False)
|
777 |
+
|
778 |
+
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
779 |
+
|
780 |
+
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
781 |
+
self.decoder.bias = self.bias
|
782 |
+
|
783 |
+
def forward(self, hidden_states):
|
784 |
+
hidden_states = self.transform(hidden_states)
|
785 |
+
hidden_states = self.decoder(hidden_states)
|
786 |
+
return hidden_states
|
787 |
+
|
788 |
+
|
789 |
+
class BertOnlyMLMHead(nn.Module):
|
790 |
+
def __init__(self, config):
|
791 |
+
super().__init__()
|
792 |
+
self.predictions = BertLMPredictionHead(config)
|
793 |
+
|
794 |
+
def forward(self, sequence_output):
|
795 |
+
prediction_scores = self.predictions(sequence_output)
|
796 |
+
return prediction_scores
|
797 |
+
|
798 |
+
|
799 |
+
class BertOnlyNSPHead(nn.Module):
|
800 |
+
def __init__(self, config):
|
801 |
+
super().__init__()
|
802 |
+
self.seq_relationship = nn.Linear(config.hidden_size, 2)
|
803 |
+
|
804 |
+
def forward(self, pooled_output):
|
805 |
+
seq_relationship_score = self.seq_relationship(pooled_output)
|
806 |
+
return seq_relationship_score
|
807 |
+
|
808 |
+
|
809 |
+
class BertPreTrainingHeads(nn.Module):
|
810 |
+
def __init__(self, config):
|
811 |
+
super().__init__()
|
812 |
+
self.predictions = BertLMPredictionHead(config)
|
813 |
+
self.seq_relationship = nn.Linear(config.hidden_size, 2)
|
814 |
+
|
815 |
+
def forward(self, sequence_output, pooled_output):
|
816 |
+
prediction_scores = self.predictions(sequence_output)
|
817 |
+
seq_relationship_score = self.seq_relationship(pooled_output)
|
818 |
+
return prediction_scores, seq_relationship_score
|
819 |
+
|
820 |
+
|
821 |
+
class BertPreTrainedModel(PreTrainedModel):
|
822 |
+
"""
|
823 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
824 |
+
models.
|
825 |
+
"""
|
826 |
+
|
827 |
+
config_class = BertConfig
|
828 |
+
load_tf_weights = load_tf_weights_in_bert
|
829 |
+
base_model_prefix = "bert"
|
830 |
+
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
831 |
+
|
832 |
+
def _init_weights(self, module):
|
833 |
+
""" Initialize the weights """
|
834 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
835 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
836 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
837 |
+
module.weight.data.normal_(
|
838 |
+
mean=0.0, std=self.config.initializer_range)
|
839 |
+
elif isinstance(module, nn.LayerNorm):
|
840 |
+
module.bias.data.zero_()
|
841 |
+
module.weight.data.fill_(1.0)
|
842 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
843 |
+
module.bias.data.zero_()
|
844 |
+
|
845 |
+
|
846 |
+
@dataclass
|
847 |
+
class BertForPreTrainingOutput(ModelOutput):
|
848 |
+
"""
|
849 |
+
Output type of :class:`~transformers.BertForPreTraining`.
|
850 |
+
Args:
|
851 |
+
loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`):
|
852 |
+
Total loss as the sum of the masked language modeling loss and the next sequence prediction
|
853 |
+
(classification) loss.
|
854 |
+
prediction_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
|
855 |
+
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
856 |
+
seq_relationship_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
|
857 |
+
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
|
858 |
+
before SoftMax).
|
859 |
+
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
860 |
+
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
861 |
+
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
862 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
863 |
+
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
864 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
865 |
+
sequence_length, sequence_length)`.
|
866 |
+
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
|
867 |
+
heads.
|
868 |
+
"""
|
869 |
+
|
870 |
+
loss: Optional[torch.FloatTensor] = None
|
871 |
+
prediction_logits: torch.FloatTensor = None
|
872 |
+
seq_relationship_logits: torch.FloatTensor = None
|
873 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
874 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
875 |
+
|
876 |
+
|
877 |
+
BERT_START_DOCSTRING = r"""
|
878 |
+
This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
|
879 |
+
methods the library implements for all its models (such as downloading or saving, resizing the input embeddings,
|
880 |
+
pruning heads etc.)
|
881 |
+
This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
|
882 |
+
subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
|
883 |
+
general usage and behavior.
|
884 |
+
Parameters:
|
885 |
+
config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
|
886 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
887 |
+
configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
|
888 |
+
weights.
|
889 |
+
"""
|
890 |
+
|
891 |
+
BERT_INPUTS_DOCSTRING = r"""
|
892 |
+
Args:
|
893 |
+
input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
|
894 |
+
Indices of input sequence tokens in the vocabulary.
|
895 |
+
Indices can be obtained using :class:`~transformers.BertTokenizer`. See
|
896 |
+
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
|
897 |
+
details.
|
898 |
+
`What are input IDs? <../glossary.html#input-ids>`__
|
899 |
+
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
|
900 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
|
901 |
+
- 1 for tokens that are **not masked**,
|
902 |
+
- 0 for tokens that are **masked**.
|
903 |
+
`What are attention masks? <../glossary.html#attention-mask>`__
|
904 |
+
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
|
905 |
+
Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
|
906 |
+
1]``:
|
907 |
+
- 0 corresponds to a `sentence A` token,
|
908 |
+
- 1 corresponds to a `sentence B` token.
|
909 |
+
`What are token type IDs? <../glossary.html#token-type-ids>`_
|
910 |
+
position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
|
911 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
|
912 |
+
config.max_position_embeddings - 1]``.
|
913 |
+
`What are position IDs? <../glossary.html#position-ids>`_
|
914 |
+
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
|
915 |
+
Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
|
916 |
+
- 1 indicates the head is **not masked**,
|
917 |
+
- 0 indicates the head is **masked**.
|
918 |
+
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
|
919 |
+
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
920 |
+
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
|
921 |
+
vectors than the model's internal embedding lookup matrix.
|
922 |
+
output_attentions (:obj:`bool`, `optional`):
|
923 |
+
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
|
924 |
+
tensors for more detail.
|
925 |
+
output_hidden_states (:obj:`bool`, `optional`):
|
926 |
+
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
|
927 |
+
more detail.
|
928 |
+
return_dict (:obj:`bool`, `optional`):
|
929 |
+
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
930 |
+
"""
|
931 |
+
|
932 |
+
|
933 |
+
@add_start_docstrings(
|
934 |
+
"The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
|
935 |
+
BERT_START_DOCSTRING,
|
936 |
+
)
|
937 |
+
class BertModel(BertPreTrainedModel):
|
938 |
+
"""
|
939 |
+
The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
|
940 |
+
cross-attention is added between the self-attention layers, following the architecture described in `Attention is
|
941 |
+
all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
|
942 |
+
Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
|
943 |
+
To be used as a decoder the model needs to be initialized with both the :obj:`is_decoder` argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
|
944 |
+
input to the forward pass.
|
945 |
+
"""
|
946 |
+
|
947 |
+
def __init__(self, config, add_pooling_layer=True):
|
948 |
+
super().__init__(config)
|
949 |
+
self.config = config
|
950 |
+
|
951 |
+
self.embeddings = BertEmbeddings(config)
|
952 |
+
|
953 |
+
self.encoder = BertEncoder(config)
|
954 |
+
|
955 |
+
self.pooler = BertPooler(config) if add_pooling_layer else None
|
956 |
+
|
957 |
+
self.init_weights()
|
958 |
+
|
959 |
+
def get_input_embeddings(self):
|
960 |
+
return self.embeddings.word_embeddings
|
961 |
+
|
962 |
+
def set_input_embeddings(self, value):
|
963 |
+
self.embeddings.word_embeddings = value
|
964 |
+
|
965 |
+
def _prune_heads(self, heads_to_prune):
|
966 |
+
"""
|
967 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
|
968 |
+
class PreTrainedModel
|
969 |
+
"""
|
970 |
+
for layer, heads in heads_to_prune.items():
|
971 |
+
self.encoder.layer[layer].attention.prune_heads(heads)
|
972 |
+
|
973 |
+
def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
|
974 |
+
"""
|
975 |
+
Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
|
976 |
+
|
977 |
+
Arguments:
|
978 |
+
attention_mask (:obj:`torch.Tensor`):
|
979 |
+
Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
|
980 |
+
input_shape (:obj:`Tuple[int]`):
|
981 |
+
The shape of the input to the model.
|
982 |
+
device: (:obj:`torch.device`):
|
983 |
+
The device of the input to the model.
|
984 |
+
|
985 |
+
Returns:
|
986 |
+
:obj:`torch.Tensor` The extended attention mask, with the same dtype as :obj:`attention_mask.dtype`.
|
987 |
+
"""
|
988 |
+
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
989 |
+
# ourselves in which case we just need to make it broadcastable to all heads.
|
990 |
+
if attention_mask.dim() == 3:
|
991 |
+
extended_attention_mask = attention_mask[:, None, :, :]
|
992 |
+
elif attention_mask.dim() == 2:
|
993 |
+
# Provided a padding mask of dimensions [batch_size, seq_length]
|
994 |
+
# - if the model is a decoder, apply a causal mask in addition to the padding mask
|
995 |
+
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
996 |
+
if is_decoder:
|
997 |
+
batch_size, seq_length = input_shape
|
998 |
+
seq_ids = torch.arange(seq_length, device=device)
|
999 |
+
causal_mask = seq_ids[None, None, :].repeat(
|
1000 |
+
batch_size, seq_length, 1) <= seq_ids[None, :, None]
|
1001 |
+
# in case past_key_values are used we need to add a prefix ones mask to the causal mask
|
1002 |
+
# causal and attention masks must have same type with pytorch version < 1.3
|
1003 |
+
causal_mask = causal_mask.to(attention_mask.dtype)
|
1004 |
+
|
1005 |
+
if causal_mask.shape[1] < attention_mask.shape[1]:
|
1006 |
+
prefix_seq_len = attention_mask.shape[1] - \
|
1007 |
+
causal_mask.shape[1]
|
1008 |
+
causal_mask = torch.cat(
|
1009 |
+
[
|
1010 |
+
torch.ones(
|
1011 |
+
(batch_size, seq_length,
|
1012 |
+
prefix_seq_len), device=device, dtype=causal_mask.dtype
|
1013 |
+
),
|
1014 |
+
causal_mask,
|
1015 |
+
],
|
1016 |
+
axis=-1,
|
1017 |
+
)
|
1018 |
+
|
1019 |
+
extended_attention_mask = causal_mask[:, None,
|
1020 |
+
:, :] * attention_mask[:, None, None, :]
|
1021 |
+
else:
|
1022 |
+
extended_attention_mask = attention_mask[:, None, None, :]
|
1023 |
+
else:
|
1024 |
+
raise ValueError(
|
1025 |
+
"Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
|
1026 |
+
input_shape, attention_mask.shape
|
1027 |
+
)
|
1028 |
+
)
|
1029 |
+
|
1030 |
+
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
1031 |
+
# masked positions, this operation will create a tensor which is 0.0 for
|
1032 |
+
# positions we want to attend and -10000.0 for masked positions.
|
1033 |
+
# Since we are adding it to the raw scores before the softmax, this is
|
1034 |
+
# effectively the same as removing these entirely.
|
1035 |
+
extended_attention_mask = extended_attention_mask.to(
|
1036 |
+
dtype=self.dtype) # fp16 compatibility
|
1037 |
+
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
1038 |
+
return extended_attention_mask
|
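For reference, in the non-decoder, 2-D case `get_extended_attention_mask` reduces to the additive-mask trick below: padding positions are pushed to -10000.0 so they contribute essentially nothing after the softmax. A small self-contained illustration:

```python
import torch

# Padding mask [batch, seq]: 1 = attend, 0 = padding.
attention_mask = torch.tensor([[1., 1., 0.]])
extended = attention_mask[:, None, None, :]   # broadcastable to [batch, heads, q, k]
extended = (1.0 - extended) * -10000.0
print(extended)  # tensor([[[[    -0.,     -0., -10000.]]]])
```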
1039 |
+
|
1040 |
+
def forward(
|
1041 |
+
self,
|
1042 |
+
input_ids=None,
|
1043 |
+
attention_mask=None,
|
1044 |
+
token_type_ids=None,
|
1045 |
+
position_ids=None,
|
1046 |
+
head_mask=None,
|
1047 |
+
inputs_embeds=None,
|
1048 |
+
encoder_embeds=None,
|
1049 |
+
encoder_hidden_states=None,
|
1050 |
+
encoder_attention_mask=None,
|
1051 |
+
past_key_values=None,
|
1052 |
+
use_cache=None,
|
1053 |
+
output_attentions=None,
|
1054 |
+
output_hidden_states=None,
|
1055 |
+
output_token_idx=None,
|
1056 |
+
return_dict=None,
|
1057 |
+
is_decoder=False,
|
1058 |
+
mode='multi_modal',
|
1059 |
+
normalize_attention=True,
|
1060 |
+
):
|
1061 |
+
r"""
|
1062 |
+
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
1063 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
1064 |
+
the model is configured as a decoder.
|
1065 |
+
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
1066 |
+
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
1067 |
+
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
1068 |
+
- 1 for tokens that are **not masked**,
|
1069 |
+
- 0 for tokens that are **masked**.
|
1070 |
+
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
1071 |
+
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
1072 |
+
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
1073 |
+
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
1074 |
+
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
1075 |
+
use_cache (:obj:`bool`, `optional`):
|
1076 |
+
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
1077 |
+
decoding (see :obj:`past_key_values`).
|
1078 |
+
"""
|
1079 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
1080 |
+
output_hidden_states = (
|
1081 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
1082 |
+
)
|
1083 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1084 |
+
|
1085 |
+
if is_decoder:
|
1086 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
1087 |
+
else:
|
1088 |
+
use_cache = False
|
1089 |
+
|
1090 |
+
if input_ids is not None and inputs_embeds is not None:
|
1091 |
+
raise ValueError(
|
1092 |
+
"You cannot specify both input_ids and inputs_embeds at the same time")
|
1093 |
+
elif input_ids is not None:
|
1094 |
+
input_shape = input_ids.size()
|
1095 |
+
batch_size, seq_length = input_shape
|
1096 |
+
device = input_ids.device
|
1097 |
+
elif inputs_embeds is not None:
|
1098 |
+
input_shape = inputs_embeds.size()[:-1]
|
1099 |
+
batch_size, seq_length = input_shape
|
1100 |
+
device = inputs_embeds.device
|
1101 |
+
elif encoder_embeds is not None:
|
1102 |
+
input_shape = encoder_embeds.size()[:-1]
|
1103 |
+
batch_size, seq_length = input_shape
|
1104 |
+
device = encoder_embeds.device
|
1105 |
+
else:
|
1106 |
+
raise ValueError(
|
1107 |
+
"You have to specify either input_ids or inputs_embeds or encoder_embeds")
|
1108 |
+
|
1109 |
+
# past_key_values_length
|
1110 |
+
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
1111 |
+
|
1112 |
+
if attention_mask is None:
|
1113 |
+
attention_mask = torch.ones(
|
1114 |
+
((batch_size, seq_length + past_key_values_length)), device=device)
|
1115 |
+
if token_type_ids is None:
|
1116 |
+
token_type_ids = torch.zeros(
|
1117 |
+
input_shape, dtype=torch.long, device=device)
|
1118 |
+
|
1119 |
+
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
1120 |
+
# ourselves in which case we just need to make it broadcastable to all heads.
|
1121 |
+
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
|
1122 |
+
device, is_decoder)
|
1123 |
+
|
1124 |
+
# If a 2D or 3D attention mask is provided for the cross-attention
|
1125 |
+
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
1126 |
+
if encoder_hidden_states is not None:
|
1127 |
+
if type(encoder_hidden_states) == list:
|
1128 |
+
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size(
|
1129 |
+
)
|
1130 |
+
else:
|
1131 |
+
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
1132 |
+
encoder_hidden_shape = (
|
1133 |
+
encoder_batch_size, encoder_sequence_length)
|
1134 |
+
|
1135 |
+
if type(encoder_attention_mask) == list:
|
1136 |
+
encoder_extended_attention_mask = [
|
1137 |
+
self.invert_attention_mask(mask) for mask in encoder_attention_mask]
|
1138 |
+
elif encoder_attention_mask is None:
|
1139 |
+
encoder_attention_mask = torch.ones(
|
1140 |
+
encoder_hidden_shape, device=device)
|
1141 |
+
encoder_extended_attention_mask = self.invert_attention_mask(
|
1142 |
+
encoder_attention_mask)
|
1143 |
+
else:
|
1144 |
+
encoder_extended_attention_mask = self.invert_attention_mask(
|
1145 |
+
encoder_attention_mask)
|
1146 |
+
else:
|
1147 |
+
encoder_extended_attention_mask = None
|
1148 |
+
|
1149 |
+
# Prepare head mask if needed
|
1150 |
+
# 1.0 in head_mask indicate we keep the head
|
1151 |
+
# attention_probs has shape bsz x n_heads x N x N
|
1152 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
1153 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
1154 |
+
head_mask = self.get_head_mask(
|
1155 |
+
head_mask, self.config.num_hidden_layers)
|
1156 |
+
|
1157 |
+
if encoder_embeds is None:
|
1158 |
+
embedding_output = self.embeddings(
|
1159 |
+
input_ids=input_ids,
|
1160 |
+
position_ids=position_ids,
|
1161 |
+
token_type_ids=token_type_ids,
|
1162 |
+
inputs_embeds=inputs_embeds,
|
1163 |
+
past_key_values_length=past_key_values_length,
|
1164 |
+
)
|
1165 |
+
else:
|
1166 |
+
embedding_output = encoder_embeds
|
1167 |
+
|
1168 |
+
encoder_outputs = self.encoder(
|
1169 |
+
embedding_output,
|
1170 |
+
attention_mask=extended_attention_mask,
|
1171 |
+
head_mask=head_mask,
|
1172 |
+
encoder_hidden_states=encoder_hidden_states,
|
1173 |
+
encoder_attention_mask=encoder_extended_attention_mask,
|
1174 |
+
past_key_values=past_key_values,
|
1175 |
+
use_cache=use_cache,
|
1176 |
+
output_attentions=output_attentions,
|
1177 |
+
output_hidden_states=output_hidden_states,
|
1178 |
+
output_token_idx=output_token_idx,
|
1179 |
+
return_dict=return_dict,
|
1180 |
+
mode=mode,
|
1181 |
+
normalize_attention=normalize_attention,
|
1182 |
+
|
1183 |
+
)
|
1184 |
+
sequence_output = encoder_outputs[0]
|
1185 |
+
pooled_output = self.pooler(
|
1186 |
+
sequence_output) if self.pooler is not None else None
|
1187 |
+
|
1188 |
+
if not return_dict:
|
1189 |
+
return (sequence_output, pooled_output) + encoder_outputs[1:]
|
1190 |
+
|
1191 |
+
return BertModelOutputWithPoolingAndCrossAttentions(
|
1192 |
+
last_hidden_state=sequence_output,
|
1193 |
+
pooler_output=pooled_output,
|
1194 |
+
past_key_values=encoder_outputs.past_key_values,
|
1195 |
+
hidden_states=encoder_outputs.hidden_states,
|
1196 |
+
attentions=encoder_outputs.attentions,
|
1197 |
+
cross_attentions=encoder_outputs.cross_attentions,
|
1198 |
+
token_idx=encoder_outputs.token_idx,
|
1199 |
+
)
|
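For reference, a `BertModel` built from this file is typically driven in two stages, ALBEF-style: a text-only pass over layers `[0, fusion_layer)`, then a fusion pass that cross-attends to visual tokens. The sketch below is hypothetical: `text_encoder` is assumed to be an instance constructed from `configs/config_bert.json` (assumed to define `fusion_layer` and enable cross-attention), and `image_embeds` stands in for the video backbone's output; the actual wiring lives in `svitt/model.py`.

```python
import torch

# Assumed to exist: `text_encoder`, a BertModel instance from this file, and
# `image_embeds`, [batch, num_visual_tokens, hidden] features from the visual
# backbone. All sizes below are illustrative.
input_ids = torch.randint(0, 30522, (2, 16))
attention_mask = torch.ones(2, 16, dtype=torch.long)
image_atts = torch.ones(image_embeds.shape[:-1], dtype=torch.long)

# Stage 1: text-only layers [0, fusion_layer).
text_out = text_encoder(input_ids, attention_mask=attention_mask,
                        mode='text', return_dict=True)

# Stage 2: fusion layers [fusion_layer, num_hidden_layers), cross-attending to
# the visual tokens (which the sparse layers may prune via token_keep_rate).
fusion_out = text_encoder(encoder_embeds=text_out.last_hidden_state,
                          attention_mask=attention_mask,
                          encoder_hidden_states=image_embeds,
                          encoder_attention_mask=image_atts,
                          mode='fusion', return_dict=True)
print(fusion_out.last_hidden_state.shape)
```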
1200 |
+
|
1201 |
+
|
1202 |
+
@add_start_docstrings(
|
1203 |
+
"""
|
1204 |
+
Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
|
1205 |
+
sentence prediction (classification)` head.
|
1206 |
+
""",
|
1207 |
+
BERT_START_DOCSTRING,
|
1208 |
+
)
|
1209 |
+
class BertForPreTraining(BertPreTrainedModel):
|
1210 |
+
def __init__(self, config):
|
1211 |
+
super().__init__(config)
|
1212 |
+
|
1213 |
+
self.bert = BertModel(config)
|
1214 |
+
self.cls = BertPreTrainingHeads(config)
|
1215 |
+
|
1216 |
+
self.init_weights()
|
1217 |
+
|
1218 |
+
def get_output_embeddings(self):
|
1219 |
+
return self.cls.predictions.decoder
|
1220 |
+
|
1221 |
+
def set_output_embeddings(self, new_embeddings):
|
1222 |
+
self.cls.predictions.decoder = new_embeddings
|
1223 |
+
|
1224 |
+
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
1225 |
+
@replace_return_docstrings(output_type=BertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
|
1226 |
+
def forward(
|
1227 |
+
self,
|
1228 |
+
input_ids=None,
|
1229 |
+
attention_mask=None,
|
1230 |
+
token_type_ids=None,
|
1231 |
+
position_ids=None,
|
1232 |
+
head_mask=None,
|
1233 |
+
inputs_embeds=None,
|
1234 |
+
labels=None,
|
1235 |
+
next_sentence_label=None,
|
1236 |
+
output_attentions=None,
|
1237 |
+
output_hidden_states=None,
|
1238 |
+
return_dict=None,
|
1239 |
+
):
|
1240 |
+
r"""
|
1241 |
+
labels (:obj:`torch.LongTensor` of shape ``(batch_size, sequence_length)``, `optional`):
|
1242 |
+
Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
|
1243 |
+
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
|
1244 |
+
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
|
1245 |
+
next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`):
|
1246 |
+
Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
|
1247 |
+
(see :obj:`input_ids` docstring) Indices should be in ``[0, 1]``:
|
1248 |
+
- 0 indicates sequence B is a continuation of sequence A,
|
1249 |
+
- 1 indicates sequence B is a random sequence.
|
1250 |
+
kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
|
1251 |
+
Used to hide legacy arguments that have been deprecated.
|
1252 |
+
Returns:
|
1253 |
+
Example::
|
1254 |
+
>>> from transformers import BertTokenizer, BertForPreTraining
|
1255 |
+
>>> import torch
|
1256 |
+
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
1257 |
+
>>> model = BertForPreTraining.from_pretrained('bert-base-uncased')
|
1258 |
+
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
1259 |
+
>>> outputs = model(**inputs)
|
1260 |
+
>>> prediction_logits = outputs.prediction_logits
|
1261 |
+
>>> seq_relationship_logits = outputs.seq_relationship_logits
|
1262 |
+
"""
|
1263 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1264 |
+
|
1265 |
+
outputs = self.bert(
|
1266 |
+
input_ids,
|
1267 |
+
attention_mask=attention_mask,
|
1268 |
+
token_type_ids=token_type_ids,
|
1269 |
+
position_ids=position_ids,
|
1270 |
+
head_mask=head_mask,
|
1271 |
+
inputs_embeds=inputs_embeds,
|
1272 |
+
output_attentions=output_attentions,
|
1273 |
+
output_hidden_states=output_hidden_states,
|
1274 |
+
return_dict=return_dict,
|
1275 |
+
)
|
1276 |
+
|
1277 |
+
sequence_output, pooled_output = outputs[:2]
|
1278 |
+
prediction_scores, seq_relationship_score = self.cls(
|
1279 |
+
sequence_output, pooled_output)
|
1280 |
+
|
1281 |
+
total_loss = None
|
1282 |
+
if labels is not None and next_sentence_label is not None:
|
1283 |
+
loss_fct = CrossEntropyLoss()
|
1284 |
+
masked_lm_loss = loss_fct(
|
1285 |
+
prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
1286 |
+
next_sentence_loss = loss_fct(
|
1287 |
+
seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
|
1288 |
+
total_loss = masked_lm_loss + next_sentence_loss
|
1289 |
+
|
1290 |
+
if not return_dict:
|
1291 |
+
output = (prediction_scores, seq_relationship_score) + outputs[2:]
|
1292 |
+
return ((total_loss,) + output) if total_loss is not None else output
|
1293 |
+
|
1294 |
+
return BertForPreTrainingOutput(
|
1295 |
+
loss=total_loss,
|
1296 |
+
prediction_logits=prediction_scores,
|
1297 |
+
seq_relationship_logits=seq_relationship_score,
|
1298 |
+
hidden_states=outputs.hidden_states,
|
1299 |
+
attentions=outputs.attentions,
|
1300 |
+
)
|
1301 |
+
|
1302 |
+
|
1303 |
+
@add_start_docstrings(
|
1304 |
+
"""Bert Model with a `language modeling` head on top for CLM fine-tuning. """, BERT_START_DOCSTRING
|
1305 |
+
)
|
1306 |
+
class BertLMHeadModel(BertPreTrainedModel):
|
1307 |
+
|
1308 |
+
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
1309 |
+
_keys_to_ignore_on_load_missing = [
|
1310 |
+
r"position_ids", r"predictions.decoder.bias"]
|
1311 |
+
|
1312 |
+
def __init__(self, config):
|
1313 |
+
super().__init__(config)
|
1314 |
+
|
1315 |
+
self.bert = BertModel(config, add_pooling_layer=False)
|
1316 |
+
self.cls = BertOnlyMLMHead(config)
|
1317 |
+
|
1318 |
+
self.init_weights()
|
1319 |
+
|
1320 |
+
def get_output_embeddings(self):
|
1321 |
+
return self.cls.predictions.decoder
|
1322 |
+
|
1323 |
+
def set_output_embeddings(self, new_embeddings):
|
1324 |
+
self.cls.predictions.decoder = new_embeddings
|
1325 |
+
|
1326 |
+
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
1327 |
+
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
1328 |
+
def forward(
|
1329 |
+
self,
|
1330 |
+
input_ids=None,
|
1331 |
+
attention_mask=None,
|
1332 |
+
token_type_ids=None,
|
1333 |
+
position_ids=None,
|
1334 |
+
head_mask=None,
|
1335 |
+
inputs_embeds=None,
|
1336 |
+
encoder_hidden_states=None,
|
1337 |
+
encoder_attention_mask=None,
|
1338 |
+
labels=None,
|
1339 |
+
past_key_values=None,
|
1340 |
+
use_cache=None,
|
1341 |
+
output_attentions=None,
|
1342 |
+
output_hidden_states=None,
|
1343 |
+
return_dict=None,
|
1344 |
+
is_decoder=True,
|
1345 |
+
reduction='mean',
|
1346 |
+
mode='multi_modal',
|
1347 |
+
normalize_attention=True,
|
1348 |
+
soft_labels=None,
|
1349 |
+
alpha=0,
|
1350 |
+
return_logits=False,
|
1351 |
+
):
|
1352 |
+
r"""
|
1353 |
+
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
1354 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
|
1355 |
+
the model is configured as a decoder.
|
1356 |
+
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
1357 |
+
Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
|
1358 |
+
the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
|
1359 |
+
- 1 for tokens that are **not masked**,
|
1360 |
+
- 0 for tokens that are **masked**.
|
1361 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
1362 |
+
Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
|
1363 |
+
``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
|
1364 |
+
ignored (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
|
1365 |
+
past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
|
1366 |
+
Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
|
1367 |
+
If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
|
1368 |
+
(those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
|
1369 |
+
instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
|
1370 |
+
use_cache (:obj:`bool`, `optional`):
|
1371 |
+
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
1372 |
+
decoding (see :obj:`past_key_values`).
|
1373 |
+
Returns:
|
1374 |
+
Example::
|
1375 |
+
>>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
|
1376 |
+
>>> import torch
|
1377 |
+
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
|
1378 |
+
>>> config = BertConfig.from_pretrained("bert-base-cased")
|
1379 |
+
>>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
|
1380 |
+
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
1381 |
+
>>> outputs = model(**inputs)
|
1382 |
+
>>> prediction_logits = outputs.logits
|
1383 |
+
"""
|
1384 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1385 |
+
if labels is not None:
|
1386 |
+
use_cache = False
|
1387 |
+
|
1388 |
+
outputs = self.bert(
|
1389 |
+
input_ids,
|
1390 |
+
attention_mask=attention_mask,
|
1391 |
+
token_type_ids=token_type_ids,
|
1392 |
+
position_ids=position_ids,
|
1393 |
+
head_mask=head_mask,
|
1394 |
+
inputs_embeds=inputs_embeds,
|
1395 |
+
encoder_hidden_states=encoder_hidden_states,
|
1396 |
+
encoder_attention_mask=encoder_attention_mask,
|
1397 |
+
past_key_values=past_key_values,
|
1398 |
+
use_cache=use_cache,
|
1399 |
+
output_attentions=output_attentions,
|
1400 |
+
output_hidden_states=output_hidden_states,
|
1401 |
+
return_dict=return_dict,
|
1402 |
+
is_decoder=is_decoder,
|
1403 |
+
mode=mode,
|
1404 |
+
normalize_attention=normalize_attention,
|
1405 |
+
)
|
1406 |
+
|
1407 |
+
sequence_output = outputs[0]
|
1408 |
+
prediction_scores = self.cls(sequence_output)
|
1409 |
+
|
1410 |
+
if return_logits:
|
1411 |
+
return prediction_scores[:, :-1, :].contiguous()
|
1412 |
+
|
1413 |
+
lm_loss = None
|
1414 |
+
if labels is not None:
|
1415 |
+
# we are doing next-token prediction; shift prediction scores and input ids by one
|
1416 |
+
shifted_prediction_scores = prediction_scores[:,
|
1417 |
+
:-1, :].contiguous()
|
1418 |
+
labels = labels[:, 1:].contiguous()
|
1419 |
+
loss_fct = CrossEntropyLoss(reduction=reduction)
|
1420 |
+
lm_loss = loss_fct(
|
1421 |
+
shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
1422 |
+
lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)
|
1423 |
+
|
1424 |
+
if soft_labels is not None:
|
1425 |
+
loss_distill = - \
|
1426 |
+
torch.sum(F.log_softmax(shifted_prediction_scores,
|
1427 |
+
dim=1)*soft_labels, dim=-1)
|
1428 |
+
loss_distill = (loss_distill * (labels != -100)).sum(1)
|
1429 |
+
lm_loss = (1-alpha)*lm_loss + alpha*loss_distill
|
1430 |
+
|
1431 |
+
if not return_dict:
|
1432 |
+
output = (prediction_scores,) + outputs[2:]
|
1433 |
+
return ((lm_loss,) + output) if lm_loss is not None else output
|
1434 |
+
|
1435 |
+
return CausalLMOutputWithCrossAttentions(
|
1436 |
+
loss=lm_loss,
|
1437 |
+
logits=prediction_scores,
|
1438 |
+
past_key_values=outputs.past_key_values,
|
1439 |
+
hidden_states=outputs.hidden_states,
|
1440 |
+
attentions=outputs.attentions,
|
1441 |
+
cross_attentions=outputs.cross_attentions,
|
1442 |
+
)
|
1443 |
+
|
1444 |
+
def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
|
1445 |
+
input_shape = input_ids.shape
|
1446 |
+
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
1447 |
+
if attention_mask is None:
|
1448 |
+
attention_mask = input_ids.new_ones(input_shape)
|
1449 |
+
|
1450 |
+
# cut decoder_input_ids if past is used
|
1451 |
+
if past is not None:
|
1452 |
+
input_ids = input_ids[:, -1:]
|
1453 |
+
|
1454 |
+
return {
|
1455 |
+
"input_ids": input_ids,
|
1456 |
+
"attention_mask": attention_mask,
|
1457 |
+
"past_key_values": past,
|
1458 |
+
"encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
|
1459 |
+
"encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
|
1460 |
+
"is_decoder": True,
|
1461 |
+
}
|
1462 |
+
|
1463 |
+
def _reorder_cache(self, past, beam_idx):
|
1464 |
+
reordered_past = ()
|
1465 |
+
for layer_past in past:
|
1466 |
+
reordered_past += (tuple(past_state.index_select(0, beam_idx)
|
1467 |
+
for past_state in layer_past),)
|
1468 |
+
return reordered_past
|
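For reference, the causal LM loss above shifts logits and labels by one position so each token predicts its successor. A small numerical illustration (the vocabulary size of 5 is made up):

```python
import torch
from torch.nn import CrossEntropyLoss

prediction_scores = torch.randn(1, 4, 5)   # [batch, seq, vocab]
labels = torch.tensor([[2, 4, 1, 3]])

shifted_scores = prediction_scores[:, :-1, :].contiguous()  # predict token t+1 from t
shifted_labels = labels[:, 1:].contiguous()                 # drop the first label

loss = CrossEntropyLoss(reduction='none')(
    shifted_scores.view(-1, 5), shifted_labels.view(-1))
print(loss.shape)  # torch.Size([3]) -- one term per predicted position
```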
1469 |
+
|
1470 |
+
|
1471 |
+
@dataclass
|
1472 |
+
class MaskedLMOutputWithDistill(MaskedLMOutput):
|
1473 |
+
loss_aux: Optional[torch.FloatTensor] = None
|
1474 |
+
loss_distill: Optional[torch.FloatTensor] = None
|
1475 |
+
|
1476 |
+
|
1477 |
+
@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
|
1478 |
+
class BertForMaskedLM(BertPreTrainedModel):
|
1479 |
+
|
1480 |
+
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
1481 |
+
_keys_to_ignore_on_load_missing = [
|
1482 |
+
r"position_ids", r"predictions.decoder.bias"]
|
1483 |
+
|
1484 |
+
def __init__(self, config):
|
1485 |
+
super().__init__(config)
|
1486 |
+
|
1487 |
+
self.bert = BertModel(config, add_pooling_layer=False)
|
1488 |
+
self.cls = BertOnlyMLMHead(config)
|
1489 |
+
|
1490 |
+
self.init_weights()
|
1491 |
+
|
1492 |
+
def tie_aux_decoder_weights(self, module, aux_modules):
|
1493 |
+
"""Tie decoder weights of all `aux_modules` to `module`, (not bias)"""
|
1494 |
+
for m in aux_modules:
|
1495 |
+
m.predictions.decoder.weight = module.predictions.decoder.weight
|
1496 |
+
|
1497 |
+
def get_output_embeddings(self):
|
1498 |
+
return self.cls.predictions.decoder
|
1499 |
+
|
1500 |
+
def set_output_embeddings(self, new_embeddings):
|
1501 |
+
self.cls.predictions.decoder = new_embeddings
|
1502 |
+
|
1503 |
+
def forward(
|
1504 |
+
self,
|
1505 |
+
input_ids=None,
|
1506 |
+
attention_mask=None,
|
1507 |
+
token_type_ids=None,
|
1508 |
+
position_ids=None,
|
1509 |
+
head_mask=None,
|
1510 |
+
inputs_embeds=None,
|
1511 |
+
encoder_embeds=None,
|
1512 |
+
encoder_hidden_states=None,
|
1513 |
+
encoder_attention_mask=None,
|
1514 |
+
labels=None,
|
1515 |
+
output_attentions=None,
|
1516 |
+
output_hidden_states=None,
|
1517 |
+
return_dict=None,
|
1518 |
+
is_decoder=False,
|
1519 |
+
mode='multi_modal',
|
1520 |
+
normalize_attention=True,
|
1521 |
+
soft_labels=None,
|
1522 |
+
alpha=0,
|
1523 |
+
return_logits=False,
|
1524 |
+
):
|
1525 |
+
r"""
|
1526 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
1527 |
+
Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
|
1528 |
+
config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
|
1529 |
+
(masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
|
1530 |
+
"""
|
1531 |
+
|
1532 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1533 |
+
|
1534 |
+
outputs = self.bert(
|
1535 |
+
input_ids,
|
1536 |
+
attention_mask=attention_mask,
|
1537 |
+
token_type_ids=token_type_ids,
|
1538 |
+
position_ids=position_ids,
|
1539 |
+
head_mask=head_mask,
|
1540 |
+
inputs_embeds=inputs_embeds,
|
1541 |
+
encoder_embeds=encoder_embeds,
|
1542 |
+
encoder_hidden_states=encoder_hidden_states,
|
1543 |
+
encoder_attention_mask=encoder_attention_mask,
|
1544 |
+
output_attentions=output_attentions,
|
1545 |
+
output_hidden_states=output_hidden_states,
|
1546 |
+
return_dict=return_dict,
|
1547 |
+
is_decoder=is_decoder,
|
1548 |
+
mode=mode,
|
1549 |
+
normalize_attention=normalize_attention
|
1550 |
+
)
|
1551 |
+
|
1552 |
+
sequence_output = outputs[0]
|
1553 |
+
prediction_scores = self.cls(sequence_output)
|
1554 |
+
|
1555 |
+
if return_logits:
|
1556 |
+
return prediction_scores
|
1557 |
+
|
1558 |
+
masked_lm_loss = None
|
1559 |
+
masked_lm_loss_aux = 0.
|
1560 |
+
if labels is not None:
|
1561 |
+
loss_fct = CrossEntropyLoss() # -100 index = padding token
|
1562 |
+
masked_lm_loss = loss_fct(
|
1563 |
+
prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
|
1564 |
+
|
1565 |
+
if soft_labels is not None:
|
1566 |
+
loss_distill = - \
|
1567 |
+
torch.sum(F.log_softmax(prediction_scores, dim=1)
|
1568 |
+
* soft_labels, dim=-1)
|
1569 |
+
loss_distill = loss_distill[labels != -100].mean()
|
1570 |
+
masked_lm_loss = (1-alpha)*masked_lm_loss + alpha*loss_distill
|
1571 |
+
|
1572 |
+
if not return_dict:
|
1573 |
+
output = (prediction_scores,) + outputs[2:]
|
1574 |
+
return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
|
1575 |
+
|
1576 |
+
# changed from MaskedLMOutput to MaskedLMOutputWithDistill
|
1577 |
+
return MaskedLMOutputWithDistill(
|
1578 |
+
loss=masked_lm_loss,
|
1579 |
+
loss_aux=masked_lm_loss_aux,
|
1580 |
+
logits=prediction_scores,
|
1581 |
+
hidden_states=outputs.hidden_states,
|
1582 |
+
attentions=outputs.attentions,
|
1583 |
+
)
|
1584 |
+
|
1585 |
+
def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
|
1586 |
+
input_shape = input_ids.shape
|
1587 |
+
effective_batch_size = input_shape[0]
|
1588 |
+
|
1589 |
+
# add a dummy token
|
1590 |
+
assert self.config.pad_token_id is not None, "The PAD token should be defined for generation"
|
1591 |
+
attention_mask = torch.cat(
|
1592 |
+
[attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
|
1593 |
+
dummy_token = torch.full(
|
1594 |
+
(effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
|
1595 |
+
)
|
1596 |
+
input_ids = torch.cat([input_ids, dummy_token], dim=1)
|
1597 |
+
|
1598 |
+
return {"input_ids": input_ids, "attention_mask": attention_mask}
|
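For reference, when `soft_labels` are supplied the MLM loss above is blended with a distillation term weighted by `alpha`. A simplified sketch of that mixing (shapes and `alpha=0.4` are made up; the softmax here is normalized over the vocabulary dimension for clarity):

```python
import torch
import torch.nn.functional as F

alpha, vocab = 0.4, 30522
logits = torch.randn(2, 6, vocab)                        # student MLM logits
soft_labels = F.softmax(torch.randn(2, 6, vocab), -1)    # teacher probabilities
labels = torch.full((2, 6), -100)                        # -100 = not masked
labels[:, 2] = 7                                         # one masked position per row

hard_loss = F.cross_entropy(logits.view(-1, vocab), labels.view(-1))
distill = -torch.sum(F.log_softmax(logits, dim=-1) * soft_labels, dim=-1)
distill = distill[labels != -100].mean()
loss = (1 - alpha) * hard_loss + alpha * distill
print(float(loss))
```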
1599 |
+
|
1600 |
+
|
1601 |
+
@add_start_docstrings(
|
1602 |
+
"""Bert Model with a `next sentence prediction (classification)` head on top. """,
|
1603 |
+
BERT_START_DOCSTRING,
|
1604 |
+
)
|
1605 |
+
class BertForNextSentencePrediction(BertPreTrainedModel):
|
1606 |
+
def __init__(self, config):
|
1607 |
+
super().__init__(config)
|
1608 |
+
|
1609 |
+
self.bert = BertModel(config)
|
1610 |
+
self.cls = BertOnlyNSPHead(config)
|
1611 |
+
|
1612 |
+
self.init_weights()
|
1613 |
+
|
1614 |
+
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
|
1615 |
+
@replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
|
1616 |
+
def forward(
|
1617 |
+
self,
|
1618 |
+
input_ids=None,
|
1619 |
+
attention_mask=None,
|
1620 |
+
token_type_ids=None,
|
1621 |
+
position_ids=None,
|
1622 |
+
head_mask=None,
|
1623 |
+
inputs_embeds=None,
|
1624 |
+
labels=None,
|
1625 |
+
output_attentions=None,
|
1626 |
+
output_hidden_states=None,
|
1627 |
+
return_dict=None,
|
1628 |
+
**kwargs
|
1629 |
+
):
|
1630 |
+
r"""
|
1631 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
1632 |
+
Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
|
1633 |
+
(see ``input_ids`` docstring). Indices should be in ``[0, 1]``:
|
1634 |
+
- 0 indicates sequence B is a continuation of sequence A,
|
1635 |
+
- 1 indicates sequence B is a random sequence.
|
1636 |
+
Returns:
|
1637 |
+
Example::
|
1638 |
+
>>> from transformers import BertTokenizer, BertForNextSentencePrediction
|
1639 |
+
>>> import torch
|
1640 |
+
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
1641 |
+
>>> model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
|
1642 |
+
>>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
|
1643 |
+
>>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
|
1644 |
+
>>> encoding = tokenizer(prompt, next_sentence, return_tensors='pt')
|
1645 |
+
>>> outputs = model(**encoding, labels=torch.LongTensor([1]))
|
1646 |
+
>>> logits = outputs.logits
|
1647 |
+
>>> assert logits[0, 0] < logits[0, 1] # next sentence was random
|
1648 |
+
"""
|
1649 |
+
|
1650 |
+
if "next_sentence_label" in kwargs:
|
1651 |
+
warnings.warn(
|
1652 |
+
"The `next_sentence_label` argument is deprecated and will be removed in a future version, use `labels` instead.",
|
1653 |
+
FutureWarning,
|
1654 |
+
)
|
1655 |
+
labels = kwargs.pop("next_sentence_label")
|
1656 |
+
|
1657 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1658 |
+
|
1659 |
+
outputs = self.bert(
|
1660 |
+
input_ids,
|
1661 |
+
attention_mask=attention_mask,
|
1662 |
+
token_type_ids=token_type_ids,
|
1663 |
+
position_ids=position_ids,
|
1664 |
+
head_mask=head_mask,
|
1665 |
+
inputs_embeds=inputs_embeds,
|
1666 |
+
output_attentions=output_attentions,
|
1667 |
+
output_hidden_states=output_hidden_states,
|
1668 |
+
return_dict=return_dict,
|
1669 |
+
)
|
1670 |
+
|
1671 |
+
pooled_output = outputs[1]
|
1672 |
+
|
1673 |
+
seq_relationship_scores = self.cls(pooled_output)
|
1674 |
+
|
1675 |
+
next_sentence_loss = None
|
1676 |
+
if labels is not None:
|
1677 |
+
loss_fct = CrossEntropyLoss()
|
1678 |
+
next_sentence_loss = loss_fct(
|
1679 |
+
seq_relationship_scores.view(-1, 2), labels.view(-1))
|
1680 |
+
|
1681 |
+
if not return_dict:
|
1682 |
+
output = (seq_relationship_scores,) + outputs[2:]
|
1683 |
+
return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
|
1684 |
+
|
1685 |
+
return NextSentencePredictorOutput(
|
1686 |
+
loss=next_sentence_loss,
|
1687 |
+
logits=seq_relationship_scores,
|
1688 |
+
hidden_states=outputs.hidden_states,
|
1689 |
+
attentions=outputs.attentions,
|
1690 |
+
)
|
1691 |
+
|
1692 |
+
|
1693 |
+
@add_start_docstrings(
|
1694 |
+
"""
|
1695 |
+
Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
|
1696 |
+
output) e.g. for GLUE tasks.
|
1697 |
+
""",
|
1698 |
+
BERT_START_DOCSTRING,
|
1699 |
+
)
|
1700 |
+
class BertForSequenceClassification(BertPreTrainedModel):
|
1701 |
+
def __init__(self, config):
|
1702 |
+
super().__init__(config)
|
1703 |
+
self.num_labels = config.num_labels
|
1704 |
+
|
1705 |
+
self.bert = BertModel(config)
|
1706 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
1707 |
+
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
1708 |
+
|
1709 |
+
self.init_weights()
|
1710 |
+
|
1711 |
+
def forward(
|
1712 |
+
self,
|
1713 |
+
input_ids=None,
|
1714 |
+
attention_mask=None,
|
1715 |
+
token_type_ids=None,
|
1716 |
+
position_ids=None,
|
1717 |
+
head_mask=None,
|
1718 |
+
inputs_embeds=None,
|
1719 |
+
labels=None,
|
1720 |
+
output_attentions=None,
|
1721 |
+
output_hidden_states=None,
|
1722 |
+
return_dict=None,
|
1723 |
+
):
|
1724 |
+
r"""
|
1725 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
1726 |
+
Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
|
1727 |
+
config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
|
1728 |
+
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
1729 |
+
"""
|
1730 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1731 |
+
|
1732 |
+
outputs = self.bert(
|
1733 |
+
input_ids,
|
1734 |
+
attention_mask=attention_mask,
|
1735 |
+
token_type_ids=token_type_ids,
|
1736 |
+
position_ids=position_ids,
|
1737 |
+
head_mask=head_mask,
|
1738 |
+
inputs_embeds=inputs_embeds,
|
1739 |
+
output_attentions=output_attentions,
|
1740 |
+
output_hidden_states=output_hidden_states,
|
1741 |
+
return_dict=return_dict,
|
1742 |
+
)
|
1743 |
+
|
1744 |
+
pooled_output = outputs[1]
|
1745 |
+
|
1746 |
+
pooled_output = self.dropout(pooled_output)
|
1747 |
+
logits = self.classifier(pooled_output)
|
1748 |
+
|
1749 |
+
loss = None
|
1750 |
+
if labels is not None:
|
1751 |
+
if self.num_labels == 1:
|
1752 |
+
# We are doing regression
|
1753 |
+
loss_fct = MSELoss()
|
1754 |
+
loss = loss_fct(logits.view(-1), labels.view(-1))
|
1755 |
+
else:
|
1756 |
+
loss_fct = CrossEntropyLoss()
|
1757 |
+
loss = loss_fct(
|
1758 |
+
logits.view(-1, self.num_labels), labels.view(-1))
|
1759 |
+
|
1760 |
+
if not return_dict:
|
1761 |
+
output = (logits,) + outputs[2:]
|
1762 |
+
return ((loss,) + output) if loss is not None else output
|
1763 |
+
|
1764 |
+
return SequenceClassifierOutput(
|
1765 |
+
loss=loss,
|
1766 |
+
logits=logits,
|
1767 |
+
hidden_states=outputs.hidden_states,
|
1768 |
+
attentions=outputs.attentions,
|
1769 |
+
)
|
1770 |
+
|
1771 |
+
|
1772 |
+
@add_start_docstrings(
|
1773 |
+
"""
|
1774 |
+
Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
|
1775 |
+
softmax) e.g. for RocStories/SWAG tasks.
|
1776 |
+
""",
|
1777 |
+
BERT_START_DOCSTRING,
|
1778 |
+
)
|
1779 |
+
class BertForMultipleChoice(BertPreTrainedModel):
|
1780 |
+
def __init__(self, config):
|
1781 |
+
super().__init__(config)
|
1782 |
+
|
1783 |
+
self.bert = BertModel(config)
|
1784 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
1785 |
+
self.classifier = nn.Linear(config.hidden_size, 1)
|
1786 |
+
|
1787 |
+
self.init_weights()
|
1788 |
+
|
1789 |
+
def forward(
|
1790 |
+
self,
|
1791 |
+
input_ids=None,
|
1792 |
+
attention_mask=None,
|
1793 |
+
token_type_ids=None,
|
1794 |
+
position_ids=None,
|
1795 |
+
head_mask=None,
|
1796 |
+
inputs_embeds=None,
|
1797 |
+
labels=None,
|
1798 |
+
output_attentions=None,
|
1799 |
+
output_hidden_states=None,
|
1800 |
+
return_dict=None,
|
1801 |
+
):
|
1802 |
+
r"""
|
1803 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
1804 |
+
Labels for computing the multiple choice classification loss. Indices should be in ``[0, ...,
|
1805 |
+
num_choices-1]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
|
1806 |
+
:obj:`input_ids` above)
|
1807 |
+
"""
|
1808 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1809 |
+
num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
|
1810 |
+
|
1811 |
+
input_ids = input_ids.view(-1, input_ids.size(-1)
|
1812 |
+
) if input_ids is not None else None
|
1813 |
+
attention_mask = attention_mask.view(
|
1814 |
+
-1, attention_mask.size(-1)) if attention_mask is not None else None
|
1815 |
+
token_type_ids = token_type_ids.view(
|
1816 |
+
-1, token_type_ids.size(-1)) if token_type_ids is not None else None
|
1817 |
+
position_ids = position_ids.view(-1, position_ids.size(-1)
|
1818 |
+
) if position_ids is not None else None
|
1819 |
+
inputs_embeds = (
|
1820 |
+
inputs_embeds.view(-1, inputs_embeds.size(-2),
|
1821 |
+
inputs_embeds.size(-1))
|
1822 |
+
if inputs_embeds is not None
|
1823 |
+
else None
|
1824 |
+
)
|
1825 |
+
|
1826 |
+
outputs = self.bert(
|
1827 |
+
input_ids,
|
1828 |
+
attention_mask=attention_mask,
|
1829 |
+
token_type_ids=token_type_ids,
|
1830 |
+
position_ids=position_ids,
|
1831 |
+
head_mask=head_mask,
|
1832 |
+
inputs_embeds=inputs_embeds,
|
1833 |
+
output_attentions=output_attentions,
|
1834 |
+
output_hidden_states=output_hidden_states,
|
1835 |
+
return_dict=return_dict,
|
1836 |
+
)
|
1837 |
+
|
1838 |
+
pooled_output = outputs[1]
|
1839 |
+
|
1840 |
+
pooled_output = self.dropout(pooled_output)
|
1841 |
+
logits = self.classifier(pooled_output)
|
1842 |
+
reshaped_logits = logits.view(-1, num_choices)
|
1843 |
+
|
1844 |
+
loss = None
|
1845 |
+
if labels is not None:
|
1846 |
+
loss_fct = CrossEntropyLoss()
|
1847 |
+
loss = loss_fct(reshaped_logits, labels)
|
1848 |
+
|
1849 |
+
if not return_dict:
|
1850 |
+
output = (reshaped_logits,) + outputs[2:]
|
1851 |
+
return ((loss,) + output) if loss is not None else output
|
1852 |
+
|
1853 |
+
return MultipleChoiceModelOutput(
|
1854 |
+
loss=loss,
|
1855 |
+
logits=reshaped_logits,
|
1856 |
+
hidden_states=outputs.hidden_states,
|
1857 |
+
attentions=outputs.attentions,
|
1858 |
+
)
|
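For reference, the multiple-choice head flattens `[batch, num_choices, seq_len]` inputs so the encoder sees each choice as its own sequence, then folds the per-choice scores back into `[batch, num_choices]` for the cross-entropy over choices. A tiny shape walk-through with made-up sizes:

```python
import torch

batch, num_choices, seq_len = 2, 4, 8
input_ids = torch.randint(0, 30522, (batch, num_choices, seq_len))

flat_ids = input_ids.view(-1, input_ids.size(-1))   # [8, 8] -> encoder input
pooled = torch.randn(flat_ids.size(0), 768)         # stand-in for the pooler output
classifier = torch.nn.Linear(768, 1)
logits = classifier(pooled)                         # [8, 1]
reshaped_logits = logits.view(-1, num_choices)      # [2, 4] -> CE over choices
print(flat_ids.shape, reshaped_logits.shape)
```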
1859 |
+
|
1860 |
+
|
1861 |
+
@add_start_docstrings(
|
1862 |
+
"""
|
1863 |
+
Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
|
1864 |
+
Named-Entity-Recognition (NER) tasks.
|
1865 |
+
""",
|
1866 |
+
BERT_START_DOCSTRING,
|
1867 |
+
)
|
1868 |
+
class BertForTokenClassification(BertPreTrainedModel):
|
1869 |
+
|
1870 |
+
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
1871 |
+
|
1872 |
+
def __init__(self, config):
|
1873 |
+
super().__init__(config)
|
1874 |
+
self.num_labels = config.num_labels
|
1875 |
+
|
1876 |
+
self.bert = BertModel(config, add_pooling_layer=False)
|
1877 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
1878 |
+
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
1879 |
+
|
1880 |
+
self.init_weights()
|
1881 |
+
|
1882 |
+
def forward(
|
1883 |
+
self,
|
1884 |
+
input_ids=None,
|
1885 |
+
attention_mask=None,
|
1886 |
+
token_type_ids=None,
|
1887 |
+
position_ids=None,
|
1888 |
+
head_mask=None,
|
1889 |
+
inputs_embeds=None,
|
1890 |
+
labels=None,
|
1891 |
+
output_attentions=None,
|
1892 |
+
output_hidden_states=None,
|
1893 |
+
return_dict=None,
|
1894 |
+
):
|
1895 |
+
r"""
|
1896 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
1897 |
+
Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
|
1898 |
+
1]``.
|
1899 |
+
"""
|
1900 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1901 |
+
|
1902 |
+
outputs = self.bert(
|
1903 |
+
input_ids,
|
1904 |
+
attention_mask=attention_mask,
|
1905 |
+
token_type_ids=token_type_ids,
|
1906 |
+
position_ids=position_ids,
|
1907 |
+
head_mask=head_mask,
|
1908 |
+
inputs_embeds=inputs_embeds,
|
1909 |
+
output_attentions=output_attentions,
|
1910 |
+
output_hidden_states=output_hidden_states,
|
1911 |
+
return_dict=return_dict,
|
1912 |
+
)
|
1913 |
+
|
1914 |
+
sequence_output = outputs[0]
|
1915 |
+
|
1916 |
+
sequence_output = self.dropout(sequence_output)
|
1917 |
+
logits = self.classifier(sequence_output)
|
1918 |
+
|
1919 |
+
loss = None
|
1920 |
+
if labels is not None:
|
1921 |
+
loss_fct = CrossEntropyLoss()
|
1922 |
+
# Only keep active parts of the loss
|
1923 |
+
if attention_mask is not None:
|
1924 |
+
active_loss = attention_mask.view(-1) == 1
|
1925 |
+
active_logits = logits.view(-1, self.num_labels)
|
1926 |
+
active_labels = torch.where(
|
1927 |
+
active_loss, labels.view(-1), torch.tensor(
|
1928 |
+
loss_fct.ignore_index).type_as(labels)
|
1929 |
+
)
|
1930 |
+
loss = loss_fct(active_logits, active_labels)
|
1931 |
+
else:
|
1932 |
+
loss = loss_fct(
|
1933 |
+
logits.view(-1, self.num_labels), labels.view(-1))
|
1934 |
+
|
1935 |
+
if not return_dict:
|
1936 |
+
output = (logits,) + outputs[2:]
|
1937 |
+
return ((loss,) + output) if loss is not None else output
|
1938 |
+
|
1939 |
+
return TokenClassifierOutput(
|
1940 |
+
loss=loss,
|
1941 |
+
logits=logits,
|
1942 |
+
hidden_states=outputs.hidden_states,
|
1943 |
+
attentions=outputs.attentions,
|
1944 |
+
)
|
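For reference, the token-classification loss above only counts positions where `attention_mask == 1`; padded positions are mapped to the loss function's `ignore_index`. A small self-contained illustration with made-up sizes:

```python
import torch
from torch.nn import CrossEntropyLoss

num_labels = 5
logits = torch.randn(2, 4, num_labels)
labels = torch.randint(0, num_labels, (2, 4))
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])

loss_fct = CrossEntropyLoss()
active_loss = attention_mask.view(-1) == 1
active_logits = logits.view(-1, num_labels)
active_labels = torch.where(
    active_loss, labels.view(-1),
    torch.tensor(loss_fct.ignore_index).type_as(labels))
print(float(loss_fct(active_logits, active_labels)))
```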
1945 |
+
|
1946 |
+
|
1947 |
+
@add_start_docstrings(
|
1948 |
+
"""
|
1949 |
+
Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
|
1950 |
+
layer on top of the hidden-states output to compute `span start logits` and `span end logits`).
|
1951 |
+
""",
|
1952 |
+
BERT_START_DOCSTRING,
|
1953 |
+
)
|
1954 |
+
class BertForQuestionAnswering(BertPreTrainedModel):
|
1955 |
+
|
1956 |
+
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
1957 |
+
|
1958 |
+
def __init__(self, config):
|
1959 |
+
super().__init__(config)
|
1960 |
+
self.num_labels = config.num_labels
|
1961 |
+
|
1962 |
+
self.bert = BertModel(config, add_pooling_layer=False)
|
1963 |
+
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
|
1964 |
+
|
1965 |
+
self.init_weights()
|
1966 |
+
|
1967 |
+
def forward(
|
1968 |
+
self,
|
1969 |
+
input_ids=None,
|
1970 |
+
attention_mask=None,
|
1971 |
+
token_type_ids=None,
|
1972 |
+
position_ids=None,
|
1973 |
+
head_mask=None,
|
1974 |
+
inputs_embeds=None,
|
1975 |
+
start_positions=None,
|
1976 |
+
end_positions=None,
|
1977 |
+
output_attentions=None,
|
1978 |
+
output_hidden_states=None,
|
1979 |
+
return_dict=None,
|
1980 |
+
):
|
1981 |
+
r"""
|
1982 |
+
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
1983 |
+
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
1984 |
+
Positions are clamped to the length of the sequence (:obj:`sequence_length`). Positions outside of the
|
1985 |
+
sequence are not taken into account for computing the loss.
|
1986 |
+
end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
|
1987 |
+
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
1988 |
+
Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
|
1989 |
+
sequence are not taken into account for computing the loss.
|
1990 |
+
"""
|
1991 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
1992 |
+
|
1993 |
+
outputs = self.bert(
|
1994 |
+
input_ids,
|
1995 |
+
attention_mask=attention_mask,
|
1996 |
+
token_type_ids=token_type_ids,
|
1997 |
+
position_ids=position_ids,
|
1998 |
+
head_mask=head_mask,
|
1999 |
+
inputs_embeds=inputs_embeds,
|
2000 |
+
output_attentions=output_attentions,
|
2001 |
+
output_hidden_states=output_hidden_states,
|
2002 |
+
return_dict=return_dict,
|
2003 |
+
)
|
2004 |
+
|
2005 |
+
sequence_output = outputs[0]
|
2006 |
+
|
2007 |
+
logits = self.qa_outputs(sequence_output)
|
2008 |
+
start_logits, end_logits = logits.split(1, dim=-1)
|
2009 |
+
start_logits = start_logits.squeeze(-1)
|
2010 |
+
end_logits = end_logits.squeeze(-1)
|
2011 |
+
|
2012 |
+
total_loss = None
|
2013 |
+
if start_positions is not None and end_positions is not None:
|
2014 |
+
# If we are on multi-GPU, split add a dimension
|
2015 |
+
if len(start_positions.size()) > 1:
|
2016 |
+
start_positions = start_positions.squeeze(-1)
|
2017 |
+
if len(end_positions.size()) > 1:
|
2018 |
+
end_positions = end_positions.squeeze(-1)
|
2019 |
+
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
2020 |
+
ignored_index = start_logits.size(1)
|
2021 |
+
start_positions.clamp_(0, ignored_index)
|
2022 |
+
end_positions.clamp_(0, ignored_index)
|
2023 |
+
|
2024 |
+
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
2025 |
+
start_loss = loss_fct(start_logits, start_positions)
|
2026 |
+
end_loss = loss_fct(end_logits, end_positions)
|
2027 |
+
total_loss = (start_loss + end_loss) / 2
|
2028 |
+
|
2029 |
+
if not return_dict:
|
2030 |
+
output = (start_logits, end_logits) + outputs[2:]
|
2031 |
+
return ((total_loss,) + output) if total_loss is not None else output
|
2032 |
+
|
2033 |
+
return QuestionAnsweringModelOutput(
|
2034 |
+
loss=total_loss,
|
2035 |
+
start_logits=start_logits,
|
2036 |
+
end_logits=end_logits,
|
2037 |
+
hidden_states=outputs.hidden_states,
|
2038 |
+
attentions=outputs.attentions,
|
2039 |
+
)
|
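The question-answering head above is a single linear layer emitting two scores per token; the start and end logits are just the two columns of that projection, and the loss averages a clamped cross-entropy over both ends. The following standalone sketch (plain PyTorch with made-up tensor sizes, not part of the committed files) mirrors that computation:

    import torch
    import torch.nn as nn

    # Hypothetical shapes: batch of 2 sequences, 16 tokens, hidden size 768.
    hidden = torch.randn(2, 16, 768)
    qa_outputs = nn.Linear(768, 2)               # same role as self.qa_outputs above

    logits = qa_outputs(hidden)                  # (2, 16, 2)
    start_logits, end_logits = logits.split(1, dim=-1)
    start_logits = start_logits.squeeze(-1)      # (2, 16)
    end_logits = end_logits.squeeze(-1)

    # Gold span positions, clamped to the sequence length as in the forward pass;
    # positions past the end (e.g. 20, 25) are clamped to ignore_index and dropped.
    seq_len = start_logits.size(1)
    start_positions = torch.tensor([3, 20]).clamp_(0, seq_len)
    end_positions = torch.tensor([5, 25]).clamp_(0, seq_len)
    loss_fct = nn.CrossEntropyLoss(ignore_index=seq_len)
    total_loss = (loss_fct(start_logits, start_positions) +
                  loss_fct(end_logits, end_positions)) / 2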
svitt/tokenization_bert.py
ADDED
@@ -0,0 +1,546 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for Bert."""


import collections
import os
import unicodedata
from typing import List, Optional, Tuple

from transformers.tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
from transformers.utils import logging


logger = logging.get_logger(__name__)

VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}

PRETRAINED_VOCAB_FILES_MAP = {
    "vocab_file": {
        "bert-base-uncased": "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt",
        "bert-large-uncased": "https://huggingface.co/bert-large-uncased/resolve/main/vocab.txt",
        "bert-base-cased": "https://huggingface.co/bert-base-cased/resolve/main/vocab.txt",
        "bert-large-cased": "https://huggingface.co/bert-large-cased/resolve/main/vocab.txt",
        "bert-base-multilingual-uncased": "https://huggingface.co/bert-base-multilingual-uncased/resolve/main/vocab.txt",
        "bert-base-multilingual-cased": "https://huggingface.co/bert-base-multilingual-cased/resolve/main/vocab.txt",
        "bert-base-chinese": "https://huggingface.co/bert-base-chinese/resolve/main/vocab.txt",
        "bert-base-german-cased": "https://huggingface.co/bert-base-german-cased/resolve/main/vocab.txt",
        "bert-large-uncased-whole-word-masking": "https://huggingface.co/bert-large-uncased-whole-word-masking/resolve/main/vocab.txt",
        "bert-large-cased-whole-word-masking": "https://huggingface.co/bert-large-cased-whole-word-masking/resolve/main/vocab.txt",
        "bert-large-uncased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt",
        "bert-large-cased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt",
        "bert-base-cased-finetuned-mrpc": "https://huggingface.co/bert-base-cased-finetuned-mrpc/resolve/main/vocab.txt",
        "bert-base-german-dbmdz-cased": "https://huggingface.co/bert-base-german-dbmdz-cased/resolve/main/vocab.txt",
        "bert-base-german-dbmdz-uncased": "https://huggingface.co/bert-base-german-dbmdz-uncased/resolve/main/vocab.txt",
        "TurkuNLP/bert-base-finnish-cased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-cased-v1/resolve/main/vocab.txt",
        "TurkuNLP/bert-base-finnish-uncased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-uncased-v1/resolve/main/vocab.txt",
        "wietsedv/bert-base-dutch-cased": "https://huggingface.co/wietsedv/bert-base-dutch-cased/resolve/main/vocab.txt",
    }
}

PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    "bert-base-uncased": 512,
    "bert-large-uncased": 512,
    "bert-base-cased": 512,
    "bert-large-cased": 512,
    "bert-base-multilingual-uncased": 512,
    "bert-base-multilingual-cased": 512,
    "bert-base-chinese": 512,
    "bert-base-german-cased": 512,
    "bert-large-uncased-whole-word-masking": 512,
    "bert-large-cased-whole-word-masking": 512,
    "bert-large-uncased-whole-word-masking-finetuned-squad": 512,
    "bert-large-cased-whole-word-masking-finetuned-squad": 512,
    "bert-base-cased-finetuned-mrpc": 512,
    "bert-base-german-dbmdz-cased": 512,
    "bert-base-german-dbmdz-uncased": 512,
    "TurkuNLP/bert-base-finnish-cased-v1": 512,
    "TurkuNLP/bert-base-finnish-uncased-v1": 512,
    "wietsedv/bert-base-dutch-cased": 512,
}

PRETRAINED_INIT_CONFIGURATION = {
    "bert-base-uncased": {"do_lower_case": True},
    "bert-large-uncased": {"do_lower_case": True},
    "bert-base-cased": {"do_lower_case": False},
    "bert-large-cased": {"do_lower_case": False},
    "bert-base-multilingual-uncased": {"do_lower_case": True},
    "bert-base-multilingual-cased": {"do_lower_case": False},
    "bert-base-chinese": {"do_lower_case": False},
    "bert-base-german-cased": {"do_lower_case": False},
    "bert-large-uncased-whole-word-masking": {"do_lower_case": True},
    "bert-large-cased-whole-word-masking": {"do_lower_case": False},
    "bert-large-uncased-whole-word-masking-finetuned-squad": {"do_lower_case": True},
    "bert-large-cased-whole-word-masking-finetuned-squad": {"do_lower_case": False},
    "bert-base-cased-finetuned-mrpc": {"do_lower_case": False},
    "bert-base-german-dbmdz-cased": {"do_lower_case": False},
    "bert-base-german-dbmdz-uncased": {"do_lower_case": True},
    "TurkuNLP/bert-base-finnish-cased-v1": {"do_lower_case": False},
    "TurkuNLP/bert-base-finnish-uncased-v1": {"do_lower_case": True},
    "wietsedv/bert-base-dutch-cased": {"do_lower_case": False},
}


def load_vocab(vocab_file):
    """Loads a vocabulary file into a dictionary."""
    vocab = collections.OrderedDict()
    with open(vocab_file, "r", encoding="utf-8") as reader:
        tokens = reader.readlines()
    for index, token in enumerate(tokens):
        token = token.rstrip("\n")
        vocab[token] = index
    return vocab


def whitespace_tokenize(text):
    """Runs basic whitespace cleaning and splitting on a piece of text."""
    text = text.strip()
    if not text:
        return []
    tokens = text.split()
    return tokens


class BertTokenizer(PreTrainedTokenizer):
    r"""
    Construct a BERT tokenizer. Based on WordPiece.
    This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods.
    Users should refer to this superclass for more information regarding those methods.
    Args:
        vocab_file (:obj:`str`):
            File containing the vocabulary.
        do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to lowercase the input when tokenizing.
        do_basic_tokenize (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to do basic tokenization before WordPiece.
        never_split (:obj:`Iterable`, `optional`):
            Collection of tokens which will never be split during tokenization. Only has an effect when
            :obj:`do_basic_tokenize=True`
        unk_token (:obj:`str`, `optional`, defaults to :obj:`"[UNK]"`):
            The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
            token instead.
        sep_token (:obj:`str`, `optional`, defaults to :obj:`"[SEP]"`):
            The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
            sequence classification or for a text and a question for question answering. It is also used as the last
            token of a sequence built with special tokens.
        pad_token (:obj:`str`, `optional`, defaults to :obj:`"[PAD]"`):
            The token used for padding, for example when batching sequences of different lengths.
        cls_token (:obj:`str`, `optional`, defaults to :obj:`"[CLS]"`):
            The classifier token which is used when doing sequence classification (classification of the whole sequence
            instead of per-token classification). It is the first token of the sequence when built with special tokens.
        mask_token (:obj:`str`, `optional`, defaults to :obj:`"[MASK]"`):
            The token used for masking values. This is the token used when training this model with masked language
            modeling. This is the token which the model will try to predict.
        tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to tokenize Chinese characters.
            This should likely be deactivated for Japanese (see this `issue
            <https://github.com/huggingface/transformers/issues/328>`__).
        strip_accents: (:obj:`bool`, `optional`):
            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
            value for :obj:`lowercase` (as in the original BERT).
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

    def __init__(
        self,
        vocab_file,
        do_lower_case=True,
        do_basic_tokenize=True,
        never_split=None,
        unk_token="[UNK]",
        sep_token="[SEP]",
        pad_token="[PAD]",
        cls_token="[CLS]",
        mask_token="[MASK]",
        tokenize_chinese_chars=True,
        strip_accents=None,
        **kwargs
    ):
        super().__init__(
            do_lower_case=do_lower_case,
            do_basic_tokenize=do_basic_tokenize,
            never_split=never_split,
            unk_token=unk_token,
            sep_token=sep_token,
            pad_token=pad_token,
            cls_token=cls_token,
            mask_token=mask_token,
            tokenize_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
            **kwargs,
        )

        if not os.path.isfile(vocab_file):
            raise ValueError(
                "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
                "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
                    vocab_file)
            )
        self.vocab = load_vocab(vocab_file)
        self.ids_to_tokens = collections.OrderedDict(
            [(ids, tok) for tok, ids in self.vocab.items()])
        self.do_basic_tokenize = do_basic_tokenize
        if do_basic_tokenize:
            self.basic_tokenizer = BasicTokenizer(
                do_lower_case=do_lower_case,
                never_split=never_split,
                tokenize_chinese_chars=tokenize_chinese_chars,
                strip_accents=strip_accents,
            )
        self.wordpiece_tokenizer = WordpieceTokenizer(
            vocab=self.vocab, unk_token=self.unk_token)

    @property
    def do_lower_case(self):
        return self.basic_tokenizer.do_lower_case

    @property
    def vocab_size(self):
        return len(self.vocab)

    def get_vocab(self):
        return dict(self.vocab, **self.added_tokens_encoder)

    def _tokenize(self, text):
        split_tokens = []
        if self.do_basic_tokenize:
            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):

                # If the token is part of the never_split set
                if token in self.basic_tokenizer.never_split:
                    split_tokens.append(token)
                else:
                    split_tokens += self.wordpiece_tokenizer.tokenize(token)
        else:
            split_tokens = self.wordpiece_tokenizer.tokenize(text)
        return split_tokens

    def _convert_token_to_id(self, token):
        """ Converts a token (str) in an id using the vocab. """
        return self.vocab.get(token, self.vocab.get(self.unk_token))

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        return self.ids_to_tokens.get(index, self.unk_token)

    def convert_tokens_to_string(self, tokens):
        """ Converts a sequence of tokens (string) in a single string. """
        out_string = " ".join(tokens).replace(" ##", "").strip()
        return out_string

    def build_inputs_with_special_tokens(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
        adding special tokens. A BERT sequence has the following format:
        - single sequence: ``[CLS] X ``
        - pair of sequences: ``[CLS] A [SEP] B [SEP]``
        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs to which the special tokens will be added.
            token_ids_1 (:obj:`List[int]`, `optional`):
                Optional second list of IDs for sequence pairs.
        Returns:
            :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
        """
        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + token_ids_1 + sep

    def get_special_tokens_mask(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
    ) -> List[int]:
        """
        Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer ``prepare_for_model`` method.
        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs.
            token_ids_1 (:obj:`List[int]`, `optional`):
                Optional second list of IDs for sequence pairs.
            already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
                Whether or not the token list is already formatted with special tokens for the model.
        Returns:
            :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """

        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError(
                    "You should not supply a second sequence if the provided sequence of "
                    "ids is already formatted with special tokens for the model."
                )
            return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))

        if token_ids_1 is not None:
            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1]

    def create_token_type_ids_from_sequences(
        self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
    ) -> List[int]:
        """
        Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence
        pair mask has the following format:
        ::
            0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
            | first sequence    | second sequence |
        If :obj:`token_ids_1` is :obj:`None`, this method only returns the first portion of the mask (0s).
        Args:
            token_ids_0 (:obj:`List[int]`):
                List of IDs.
            token_ids_1 (:obj:`List[int]`, `optional`):
                Optional second list of IDs for sequence pairs.
        Returns:
            :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
            sequence(s).
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
        index = 0
        if os.path.isdir(save_directory):
            vocab_file = os.path.join(
                save_directory, (filename_prefix + "-" if filename_prefix else "") +
                VOCAB_FILES_NAMES["vocab_file"]
            )
        else:
            vocab_file = (filename_prefix +
                          "-" if filename_prefix else "") + save_directory
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning(
                        "Saving vocabulary to {}: vocabulary indices are not consecutive."
                        " Please check that the vocabulary is not corrupted!".format(
                            vocab_file)
                    )
                    index = token_index
                writer.write(token + "\n")
                index += 1
        return (vocab_file,)


class BasicTokenizer(object):
    """
    Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).
    Args:
        do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to lowercase the input when tokenizing.
        never_split (:obj:`Iterable`, `optional`):
            Collection of tokens which will never be split during tokenization. Only has an effect when
            :obj:`do_basic_tokenize=True`
        tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
            Whether or not to tokenize Chinese characters.
            This should likely be deactivated for Japanese (see this `issue
            <https://github.com/huggingface/transformers/issues/328>`__).
        strip_accents: (:obj:`bool`, `optional`):
            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
            value for :obj:`lowercase` (as in the original BERT).
    """

    def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True, strip_accents=None):
        if never_split is None:
            never_split = []
        self.do_lower_case = do_lower_case
        self.never_split = set(never_split)
        self.tokenize_chinese_chars = tokenize_chinese_chars
        self.strip_accents = strip_accents

    def tokenize(self, text, never_split=None):
        """
        Basic Tokenization of a piece of text. Split on "white spaces" only, for sub-word tokenization, see
        WordPieceTokenizer.
        Args:
            **never_split**: (`optional`) list of str
                Kept for backward compatibility purposes. Now implemented directly at the base class level (see
                :func:`PreTrainedTokenizer.tokenize`) List of token not to split.
        """
        # union() returns a new set by concatenating the two sets.
        never_split = self.never_split.union(
            set(never_split)) if never_split else self.never_split
        text = self._clean_text(text)

        # This was added on November 1st, 2018 for the multilingual and Chinese
        # models. This is also applied to the English models now, but it doesn't
        # matter since the English models were not trained on any Chinese data
        # and generally don't have any Chinese data in them (there are Chinese
        # characters in the vocabulary because Wikipedia does have some Chinese
        # words in the English Wikipedia.).
        if self.tokenize_chinese_chars:
            text = self._tokenize_chinese_chars(text)
        orig_tokens = whitespace_tokenize(text)
        split_tokens = []
        for token in orig_tokens:
            if token not in never_split:
                if self.do_lower_case:
                    token = token.lower()
                    if self.strip_accents is not False:
                        token = self._run_strip_accents(token)
                elif self.strip_accents:
                    token = self._run_strip_accents(token)
            split_tokens.extend(self._run_split_on_punc(token, never_split))

        output_tokens = whitespace_tokenize(" ".join(split_tokens))
        return output_tokens

    def _run_strip_accents(self, text):
        """Strips accents from a piece of text."""
        text = unicodedata.normalize("NFD", text)
        output = []
        for char in text:
            cat = unicodedata.category(char)
            if cat == "Mn":
                continue
            output.append(char)
        return "".join(output)

    def _run_split_on_punc(self, text, never_split=None):
        """Splits punctuation on a piece of text."""
        if never_split is not None and text in never_split:
            return [text]
        chars = list(text)
        i = 0
        start_new_word = True
        output = []
        while i < len(chars):
            char = chars[i]
            if _is_punctuation(char):
                output.append([char])
                start_new_word = True
            else:
                if start_new_word:
                    output.append([])
                start_new_word = False
                output[-1].append(char)
            i += 1

        return ["".join(x) for x in output]

    def _tokenize_chinese_chars(self, text):
        """Adds whitespace around any CJK character."""
        output = []
        for char in text:
            cp = ord(char)
            if self._is_chinese_char(cp):
                output.append(" ")
                output.append(char)
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)

    def _is_chinese_char(self, cp):
        """Checks whether CP is the codepoint of a CJK character."""
        # This defines a "chinese character" as anything in the CJK Unicode block:
        #   https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
        #
        # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
        # despite its name. The modern Korean Hangul alphabet is a different block,
        # as is Japanese Hiragana and Katakana. Those alphabets are used to write
        # space-separated words, so they are not treated specially and handled
        # like the all of the other languages.
        if (
            (cp >= 0x4E00 and cp <= 0x9FFF)
            or (cp >= 0x3400 and cp <= 0x4DBF)  #
            or (cp >= 0x20000 and cp <= 0x2A6DF)  #
            or (cp >= 0x2A700 and cp <= 0x2B73F)  #
            or (cp >= 0x2B740 and cp <= 0x2B81F)  #
            or (cp >= 0x2B820 and cp <= 0x2CEAF)  #
            or (cp >= 0xF900 and cp <= 0xFAFF)
            or (cp >= 0x2F800 and cp <= 0x2FA1F)  #
        ):  #
            return True

        return False

    def _clean_text(self, text):
        """Performs invalid character removal and whitespace cleanup on text."""
        output = []
        for char in text:
            cp = ord(char)
            if cp == 0 or cp == 0xFFFD or _is_control(char):
                continue
            if _is_whitespace(char):
                output.append(" ")
            else:
                output.append(char)
        return "".join(output)


class WordpieceTokenizer(object):
    """Runs WordPiece tokenization."""

    def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
        self.vocab = vocab
        self.unk_token = unk_token
        self.max_input_chars_per_word = max_input_chars_per_word

    def tokenize(self, text):
        """
        Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
        tokenization using the given vocabulary.
        For example, :obj:`input = "unaffable"` wil return as output :obj:`["un", "##aff", "##able"]`.
        Args:
            text: A single token or whitespace separated tokens. This should have
                already been passed through `BasicTokenizer`.
        Returns:
            A list of wordpiece tokens.
        """

        output_tokens = []
        for token in whitespace_tokenize(text):
            chars = list(token)
            if len(chars) > self.max_input_chars_per_word:
                output_tokens.append(self.unk_token)
                continue

            is_bad = False
            start = 0
            sub_tokens = []
            while start < len(chars):
                end = len(chars)
                cur_substr = None
                while start < end:
                    substr = "".join(chars[start:end])
                    if start > 0:
                        substr = "##" + substr
                    if substr in self.vocab:
                        cur_substr = substr
                        break
                    end -= 1
                if cur_substr is None:
                    is_bad = True
                    break
                sub_tokens.append(cur_substr)
                start = end

            if is_bad:
                output_tokens.append(self.unk_token)
            else:
                output_tokens.extend(sub_tokens)
        return output_tokens
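Because this is a local copy of the tokenizer rather than the stock transformers class, it can be exercised directly from a WordPiece vocab file. A small usage sketch follows (the vocab.txt path is a placeholder for any BERT vocabulary file); note that build_inputs_with_special_tokens here prepends only [CLS] for a single sequence, with no trailing [SEP]:

    from svitt.tokenization_bert import BertTokenizer

    # Placeholder path: any BERT WordPiece vocab.txt works, e.g. one downloaded
    # from the bert-base-uncased repo and stored locally.
    tokenizer = BertTokenizer(vocab_file="vocab.txt", do_lower_case=True)

    tokens = tokenizer.tokenize("a person is unaffable")
    # greedy longest-match-first WordPiece, e.g. ['a', 'person', 'is', 'un', '##aff', '##able']
    ids = tokenizer.convert_tokens_to_ids(tokens)
    ids_with_special = tokenizer.build_inputs_with_special_tokens(ids)  # [CLS] + ids
    print(tokenizer.convert_tokens_to_string(tokens))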
svitt/utils.py
ADDED
@@ -0,0 +1,235 @@
import torch
import torch.nn.functional as F
import torch.nn as nn
from scipy import interpolate
import numpy as np
from einops import rearrange, repeat


def _init_transformer_weights(module, initializer_range=0.02):
    """Initialize the weights. Copied from transformers ViT/Bert model init"""
    if isinstance(module, (nn.Linear, nn.Conv2d)):
        # Slightly different from the TF version which uses truncated_normal for initialization
        # cf https://github.com/pytorch/pytorch/pull/5617
        module.weight.data.normal_(mean=0.0, std=initializer_range)
        if module.bias is not None:
            module.bias.data.zero_()
    elif isinstance(module, nn.Embedding):
        module.weight.data.normal_(mean=0.0, std=initializer_range)
        if module.padding_idx is not None:
            module.weight.data[module.padding_idx].zero_()
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)


def interpolate_pos_embed(pos_embed_old, pos_embed_new, num_patches_new):
    """
    Args:
        pos_embed_old: (1, L_old, d), pre-trained
        pos_embed_new: (1, L_new, d), newly initialized, to be replaced by interpolated weights
        num_patches_new:
    """
    # interpolate position embedding
    embedding_size = pos_embed_old.shape[-1]
    num_extra_tokens = pos_embed_new.shape[-2] - num_patches_new
    # height (== width) for the checkpoint position embedding
    orig_size = int((pos_embed_old.shape[-2] - num_extra_tokens) ** 0.5)
    # height (== width) for the new position embedding
    new_size = int(num_patches_new ** 0.5)

    if orig_size != new_size:
        # class_token and dist_token are kept unchanged
        # the extra tokens seems always at the beginning of the position embedding
        extra_tokens = pos_embed_old[:, :num_extra_tokens]
        # only the position tokens are interpolated
        pos_tokens = pos_embed_old[:, num_extra_tokens:]
        pos_tokens = pos_tokens.reshape(
            -1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
        pos_tokens = torch.nn.functional.interpolate(
            pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
        pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
        interpolated_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
        return interpolated_pos_embed
    else:
        return pos_embed_old


def interpolate_pos_relative_bias_beit(state_dict_old, state_dict_new, patch_shape_new):
    """
    Args:
        state_dict_old: loaded state dict
        state_dict_new: state dict for model with new image size
        patch_shape_new: new model patch_shape
    ref: https://github.com/microsoft/unilm/blob/master/beit/run_class_finetuning.py
    """
    all_keys = list(state_dict_old.keys())
    for key in all_keys:
        if "relative_position_index" in key:
            state_dict_old.pop(key)

        if "relative_position_bias_table" in key:
            rel_pos_bias = state_dict_old[key]
            src_num_pos, num_attn_heads = rel_pos_bias.size()
            dst_num_pos, _ = state_dict_new[key].size()
            dst_patch_shape = patch_shape_new
            if dst_patch_shape[0] != dst_patch_shape[1]:
                raise NotImplementedError()
            num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1)
            src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
            dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)
            if src_size != dst_size:
                extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
                rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]

                def geometric_progression(a, r, n):
                    return a * (1.0 - r ** n) / (1.0 - r)

                left, right = 1.01, 1.5
                while right - left > 1e-6:
                    q = (left + right) / 2.0
                    gp = geometric_progression(1, q, src_size // 2)
                    if gp > dst_size // 2:
                        right = q
                    else:
                        left = q

                # if q > 1.090307:
                #     q = 1.090307

                dis = []
                cur = 1
                for i in range(src_size // 2):
                    dis.append(cur)
                    cur += q ** (i + 1)

                r_ids = [-_ for _ in reversed(dis)]

                x = r_ids + [0] + dis
                y = r_ids + [0] + dis

                t = dst_size // 2.0
                dx = np.arange(-t, t + 0.1, 1.0)
                dy = np.arange(-t, t + 0.1, 1.0)

                all_rel_pos_bias = []

                for i in range(num_attn_heads):
                    z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()
                    f = interpolate.interp2d(x, y, z, kind='cubic')
                    all_rel_pos_bias.append(
                        torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device))

                rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)

                new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)
                state_dict_old[key] = new_rel_pos_bias
    return state_dict_old


def interpolate_pos_relative_bias_beit_3d(state_dict_old, state_dict_new, patch_shape_new, src_t_size=1):
    """
    Args:
        state_dict_old: loaded state dict
        state_dict_new: state dict for model with new image size
        patch_shape_new: new model patch_shape
    ref: https://github.com/microsoft/unilm/blob/master/beit/run_class_finetuning.py
    """
    all_keys = list(state_dict_old.keys())
    for key in all_keys:
        if "relative_position_index" in key:
            state_dict_old.pop(key)

        if "relative_position_bias_table" in key:
            src_num_pos, num_attn_heads = state_dict_old[key].size()
            dst_num_pos, _ = state_dict_new[key].size()
            if src_num_pos == dst_num_pos:
                continue

            num_extra_tokens = dst_num_pos - np.prod([w * 2 - 1 for w in patch_shape_new])

            src_s_size = int((src_num_pos - num_extra_tokens) / src_t_size)
            src_size = int(src_s_size ** 0.5)
            dst_size = patch_shape_new[-1] * 2 - 1

            if src_size != dst_size:
                # Spatial interpolation
                rel_pos_bias = state_dict_old[key]
                extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
                rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]

                def geometric_progression(a, r, n):
                    return a * (1.0 - r ** n) / (1.0 - r)

                left, right = 1.01, 1.5
                while right - left > 1e-6:
                    q = (left + right) / 2.0
                    gp = geometric_progression(1, q, src_size // 2)
                    if gp > dst_size // 2:
                        right = q
                    else:
                        left = q

                # if q > 1.090307:
                #     q = 1.090307

                dis = []
                cur = 1
                for i in range(src_size // 2):
                    dis.append(cur)
                    cur += q ** (i + 1)

                r_ids = [-_ for _ in reversed(dis)]

                x = r_ids + [0] + dis
                y = r_ids + [0] + dis

                t = dst_size // 2.0
                dx = np.arange(-t, t + 0.1, 1.0)
                dy = np.arange(-t, t + 0.1, 1.0)

                all_rel_pos_bias = []

                for i in range(num_attn_heads):
                    z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()
                    f = interpolate.interp2d(x, y, z, kind='cubic')
                    all_rel_pos_bias.append(
                        torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device))

                rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)

                new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)
                state_dict_old[key] = new_rel_pos_bias

            dst_t_size = patch_shape_new[0] * 2 - 1
            if src_t_size != dst_t_size:
                # Temporal interpolation
                rel_pos_bias = state_dict_old[key]
                extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
                rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]

                if src_t_size == 1:
                    rel_pos_bias = repeat(rel_pos_bias, 's d -> (t s) d', t=dst_t_size)
                else:
                    rel_pos_bias = rearrange(rel_pos_bias, '(t s) d -> s d t', t=src_t_size)
                    rel_pos_bias = F.interpolate(rel_pos_bias, dst_t_size, mode='nearest')
                    rel_pos_bias = rearrange(rel_pos_bias, 's d t -> (t s) d')
                new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)
                state_dict_old[key] = new_rel_pos_bias

    return state_dict_old


def tile(x, dim, n_tile):
    init_dim = x.size(dim)
    repeat_idx = [1] * x.dim()
    repeat_idx[dim] = n_tile
    x = x.repeat(*repeat_idx)
    order_index = torch.LongTensor(np.concatenate(
        [init_dim * np.arange(n_tile) + i for i in range(init_dim)]))
    return torch.index_select(x, dim, order_index.to(x.device))


def mask_logits(target, mask):
    return target * mask + (1 - mask) * (-1e10)
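interpolate_pos_embed only resizes the square patch grid and keeps the leading extra tokens (e.g. the class token) untouched. A minimal usage sketch with hypothetical sizes:

    import torch
    from svitt.utils import interpolate_pos_embed

    # Hypothetical sizes: a checkpoint trained at 14x14 patches (196 + 1 cls token)
    # being loaded into a model that uses 16x16 patches (256 + 1 cls token).
    pos_embed_old = torch.randn(1, 197, 768)
    pos_embed_new = torch.zeros(1, 257, 768)
    resized = interpolate_pos_embed(pos_embed_old, pos_embed_new, num_patches_new=256)
    print(resized.shape)  # torch.Size([1, 257, 768])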
svitt/video_transforms.py
ADDED
@@ -0,0 +1,186 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import Sequence
import torch
import torch.nn as nn
from torchvision import transforms


class Permute(nn.Module):
    """
    Permutation as an op
    """

    def __init__(self, ordering):
        super().__init__()
        self.ordering = ordering

    def forward(self, frames):
        """
        Args:
            frames in some ordering, by default (C, T, H, W)
        Returns:
            frames in the ordering that was specified
        """
        return frames.permute(self.ordering)


class TemporalCrop(nn.Module):
    """
    Convert the video into smaller clips temporally.
    """

    def __init__(
        self, frames_per_clip: int = 8, stride: int = 8, frame_stride: int = 1
    ):
        super().__init__()
        self.frames = frames_per_clip
        self.stride = stride
        self.frame_stride = frame_stride

    def forward(self, video):
        assert video.ndim == 4, "Must be (C, T, H, W)"
        res = []
        for start in range(
            0, video.size(1) - (self.frames * self.frame_stride) + 1, self.stride
        ):
            end = start + (self.frames) * self.frame_stride
            res.append(video[:, start: end: self.frame_stride, ...])
        return res


def crop_boxes(boxes, x_offset, y_offset):
    """
    Peform crop on the bounding boxes given the offsets.
    Args:
        boxes (ndarray or None): bounding boxes to peform crop. The dimension
            is `num boxes` x 4.
        x_offset (int): cropping offset in the x axis.
        y_offset (int): cropping offset in the y axis.
    Returns:
        cropped_boxes (ndarray or None): the cropped boxes with dimension of
            `num boxes` x 4.
    """
    cropped_boxes = boxes.copy()
    cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset
    cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset

    return cropped_boxes


def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):
    """
    Perform uniform spatial sampling on the images and corresponding boxes.
    Args:
        images (tensor): images to perform uniform crop. The dimension is
            `num frames` x `channel` x `height` x `width`.
        size (int): size of height and weight to crop the images.
        spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
            is larger than height. Or 0, 1, or 2 for top, center, and bottom
            crop if height is larger than width.
        boxes (ndarray or None): optional. Corresponding boxes to images.
            Dimension is `num boxes` x 4.
        scale_size (int): optinal. If not None, resize the images to scale_size before
            performing any crop.
    Returns:
        cropped (tensor): images with dimension of
            `num frames` x `channel` x `size` x `size`.
        cropped_boxes (ndarray or None): the cropped boxes with dimension of
            `num boxes` x 4.
    """
    assert spatial_idx in [0, 1, 2]
    ndim = len(images.shape)
    if ndim == 3:
        images = images.unsqueeze(0)
    height = images.shape[2]
    width = images.shape[3]

    if scale_size is not None:
        if width <= height:
            width, height = scale_size, int(height / width * scale_size)
        else:
            width, height = int(width / height * scale_size), scale_size
        images = torch.nn.functional.interpolate(
            images,
            size=(height, width),
            mode="bilinear",
            align_corners=False,
        )

    y_offset = int(math.ceil((height - size) / 2))
    x_offset = int(math.ceil((width - size) / 2))

    if height > width:
        if spatial_idx == 0:
            y_offset = 0
        elif spatial_idx == 2:
            y_offset = height - size
    else:
        if spatial_idx == 0:
            x_offset = 0
        elif spatial_idx == 2:
            x_offset = width - size
    cropped = images[:, :, y_offset: y_offset + size, x_offset: x_offset + size]
    cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
    if ndim == 3:
        cropped = cropped.squeeze(0)
    return cropped, cropped_boxes


class SpatialCrop(nn.Module):
    """
    Convert the video into 3 smaller clips spatially. Must be used after the
    temporal crops to get spatial crops, and should be used with
    -2 in the spatial crop at the slowfast augmentation stage (so full
    frames are passed in here). Will return a larger list with the
    3x spatial crops as well. It's useful for 3x4 testing (eg in SwinT)
    or 3x10 testing in SlowFast etc.
    """

    def __init__(self, crop_size: int = 224, num_crops: int = 3):
        super().__init__()
        self.crop_size = crop_size
        if num_crops == 6:
            self.crops_to_ext = [0, 1, 2]
            # I guess Swin uses 5 crops without flipping, but that doesn't
            # make sense given they first resize to 224 and take 224 crops.
            # (pg 6 of https://arxiv.org/pdf/2106.13230.pdf)
            # So I'm assuming we can use flipped crops and that will add sth..
            self.flipped_crops_to_ext = [0, 1, 2]
        elif num_crops == 3:
            self.crops_to_ext = [0, 1, 2]
            self.flipped_crops_to_ext = []
        elif num_crops == 1:
            self.crops_to_ext = [1]
            self.flipped_crops_to_ext = []
        else:
            raise NotImplementedError(
                "Nothing else supported yet, "
                "slowfast only takes 0, 1, 2 as arguments"
            )

    def forward(self, videos: Sequence[torch.Tensor]):
        """
        Args:
            videos: A list of C, T, H, W videos.
        Returns:
            videos: A list with 3x the number of elements. Each video converted
                to C, T, H', W' by spatial cropping.
        """
        assert isinstance(videos, list), "Must be a list of videos after temporal crops"
        assert all([video.ndim == 4 for video in videos]), "Must be (C,T,H,W)"
        res = []
        for video in videos:
            for spatial_idx in self.crops_to_ext:
                res.append(uniform_crop(video, self.crop_size, spatial_idx)[0])
            if not self.flipped_crops_to_ext:
                continue
            flipped_video = transforms.functional.hflip(video)
            for spatial_idx in self.flipped_crops_to_ext:
                res.append(uniform_crop(flipped_video, self.crop_size, spatial_idx)[0])
        return res
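Together these transforms implement the usual multi-view evaluation protocol: slice the clip into fixed-length temporal windows, then take three spatial crops of each window. A short sketch with hypothetical tensor sizes:

    import torch
    from svitt.video_transforms import TemporalCrop, SpatialCrop

    video = torch.randn(3, 32, 256, 320)                        # (C, T, H, W), made-up clip

    clips = TemporalCrop(frames_per_clip=8, stride=8)(video)    # 4 clips of shape (3, 8, 256, 320)
    crops = SpatialCrop(crop_size=224, num_crops=3)(clips)      # 3 spatial crops per clip -> 12 views
    print(len(crops), crops[0].shape)                           # 12 torch.Size([3, 8, 224, 224])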