hvaldez commited on
Commit
c18a21e
1 Parent(s): 8d42677

first commit

Browse files
.gitattributes CHANGED
@@ -33,3 +33,12 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ data/charades_ego/video/15AKPEGO.mp4 filter=lfs diff=lfs merge=lfs -text
37
+ data/charades_ego/video/184EHEGO.mp4 filter=lfs diff=lfs merge=lfs -text
38
+ data/charades_ego/video/CC0LBEGO.mp4 filter=lfs diff=lfs merge=lfs -text
39
+ data/charades_ego/video/FLY2FEGO.mp4 filter=lfs diff=lfs merge=lfs -text
40
+ data/charades_ego/video/P9SOAEGO.mp4 filter=lfs diff=lfs merge=lfs -text
41
+ data/charades_ego/video/PRODQEGO.mp4 filter=lfs diff=lfs merge=lfs -text
42
+ data/charades_ego/video/QLXEXEGO.mp4 filter=lfs diff=lfs merge=lfs -text
43
+ data/charades_ego/video/S8YZIEGO.mp4 filter=lfs diff=lfs merge=lfs -text
44
+ data/charades_ego/video/X2JTKEGO.mp4 filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ from demo import VideoCLSModel
4
+
5
+ sample_videos = [
6
+ 'data/charades_ego/video/P9SOAEGO.mp4',
7
+ 'data/charades_ego/video/6D5DHEGO.mp4',
8
+ 'data/charades_ego/video/15AKPEGO.mp4',
9
+ 'data/charades_ego/video/X2JTKEGO.mp4',
10
+ 'data/charades_ego/video/184EHEGO.mp4',
11
+ 'data/charades_ego/video/S8YZIEGO.mp4',
12
+ 'data/charades_ego/video/PRODQEGO.mp4',
13
+ 'data/charades_ego/video/QLXEXEGO.mp4',
14
+ 'data/charades_ego/video/CC0LBEGO.mp4',
15
+ 'data/charades_ego/video/FLY2FEGO.mp4'
16
+ ]
17
+
18
+ def main():
19
+ svitt = VideoCLSModel("configs/charades_ego/svitt.yml")
20
+
21
+ def predict(video_str):
22
+ video_file = video_str.split('/')[-1]
23
+ for i, item in enumerate(sample_videos):
24
+ if video_file in item:
25
+ idx = i
26
+ break
27
+
28
+ ft_action, gt_action = svitt.predict(idx)
29
+
30
+ return gt_action, ft_action
31
+
32
+ with gr.Blocks() as demo:
33
+ gr.Markdown(
34
+ """
35
+ # SViTT-Ego for Action Recognition
36
+ Choose a sample video and click predict to view the results.
37
+ """
38
+ )
39
+ with gr.Row():
40
+ idx = gr.Number(label="Idx", visible=False)
41
+ video = gr.Video(label='video', format='mp4', autoplay=True, height=256, width=256)
42
+ with gr.Row():
43
+ label = gr.Text(label="Ground Truth")
44
+ ours = gr.Text(label="SViTT-Ego prediction")
45
+ with gr.Row():
46
+ btn = gr.Button("Predict", variant="primary")
47
+ btn.click(predict, inputs=[video], outputs=[label, ours])
48
+ with gr.Column():
49
+ gr.Examples(examples=[[x] for _, x in enumerate(sample_videos)], inputs=[video])
50
+
51
+ demo.launch()
52
+
53
+
54
+ if __name__ == "__main__":
55
+ main()
56
+
ckpt/svitt-ego.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:73d9612778da3471372bc46a9e10fd6c1ed66dd2c7a715bf34472d795bd0bf58
3
+ size 2500516566
configs/base.yml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ pretrain: ""
3
+ resume: ""
4
+ timesformer_freeze_space: false
5
+ drop_path_rate: 0.1
6
+ dropout_ratio: 0.5
7
+ freeze_vis_backbone: false
8
+ freeze_txt_backbone: false
9
+ use_vn_classifier: false
10
+
11
+ data:
12
+ dataset: ek100_mir
13
+ root: datasets/EK100/video_ht256px
14
+ metadata: datasets/EK100/epic-kitchens-100-annotations/retrieval_annotations/EPIC_100_retrieval_train.csv
15
+ metadata_val: datasets/EK100/epic-kitchens-100-annotations/retrieval_annotations/EPIC_100_retrieval_test.csv
16
+ relevancy_path: datasets/EK100/epic-kitchens-100-annotations/retrieval_annotations/relevancy/caption_relevancy_EPIC_100_retrieval_test.pkl
17
+ clip_length: 16
18
+ clip_stride: 4
19
+ sparse_sample: false
20
+ num_crops: 1
21
+ num_clips: 1
configs/charades_ego/action-recognition.yaml ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ text_encoder: bert-base-uncased
3
+ bert_config: configs/config_bert.json
4
+ vit_type: beit # items in ${vit_zoo}
5
+ vit_zoo: # from huggingface
6
+ beit: microsoft/beit-base-patch16-224-pt22k-ft22k
7
+ vit_name_or_pretrained_path: ${vit_zoo[${vit_type}]}
8
+
9
+ vision_encoder_args:
10
+ token_keep_rate: 0.7
11
+ token_keep_strategy: cls_attn
12
+ token_drop_loc: [3, 6, 9]
13
+ sparse_local_attn: 1
14
+ sparse_random_attn: 5
15
+ attn_block_size: 56
16
+
17
+ image_res: 224
18
+ embed_dim: 256
19
+ video_input:
20
+ num_frames: 4
21
+ reader: decord # one of [decord, av]
22
+ sample_type: rand
23
+ num_frames_test: 16 # num_frames during inference/test
24
+ sample_type_test: middle
25
+ max_txt_l:
26
+ image: 32
27
+ video: 32
28
+
29
+ batch_size:
30
+ image: 8
31
+ video: 8
32
+ batch_size_test:
33
+ image: 8
34
+ video: 8
35
+ k_test: 128
36
+ temp: 0.18
37
+ mlm_prob: 0.5
configs/charades_ego/svitt.yml ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ pretrain: ckpt/svitt-ego.pth
3
+ freeze_vis_backbone: true
4
+ freeze_txt_backbone: true
5
+ num_frames: 16
6
+ config: configs/charades_ego/action-recognition.yaml
7
+
8
+ data:
9
+ dataset: charades_ego
10
+ root: data/charades_ego/video
11
+ metadata_val: data/charades_ego/csv/{}.csv
12
+ label_map: meta/charades_ego/charades_ego.json
13
+ clip_length: 16
14
+ sparse_sample: true
configs/config_bert.json ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "BertForMaskedLM"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "hidden_act": "gelu",
7
+ "hidden_dropout_prob": 0.1,
8
+ "hidden_size": 768,
9
+ "initializer_range": 0.02,
10
+ "intermediate_size": 3072,
11
+ "layer_norm_eps": 1e-12,
12
+ "max_position_embeddings": 512,
13
+ "model_type": "bert",
14
+ "num_attention_heads": 12,
15
+ "num_hidden_layers": 12,
16
+ "pad_token_id": 0,
17
+ "type_vocab_size": 2,
18
+ "vocab_size": 30522,
19
+ "fusion_layer": 9,
20
+ "encoder_width": 768
21
+ }
data/charades_ego/Charades_v1_classes.txt ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ c000 Holding some clothes
2
+ c001 Putting clothes somewhere
3
+ c002 Taking some clothes from somewhere
4
+ c003 Throwing clothes somewhere
5
+ c004 Tidying some clothes
6
+ c005 Washing some clothes
7
+ c006 Closing a door
8
+ c007 Fixing a door
9
+ c008 Opening a door
10
+ c009 Putting something on a table
11
+ c010 Sitting on a table
12
+ c011 Sitting at a table
13
+ c012 Tidying up a table
14
+ c013 Washing a table
15
+ c014 Working at a table
16
+ c015 Holding a phone/camera
17
+ c016 Playing with a phone/camera
18
+ c017 Putting a phone/camera somewhere
19
+ c018 Taking a phone/camera from somewhere
20
+ c019 Talking on a phone/camera
21
+ c020 Holding a bag
22
+ c021 Opening a bag
23
+ c022 Putting a bag somewhere
24
+ c023 Taking a bag from somewhere
25
+ c024 Throwing a bag somewhere
26
+ c025 Closing a book
27
+ c026 Holding a book
28
+ c027 Opening a book
29
+ c028 Putting a book somewhere
30
+ c029 Smiling at a book
31
+ c030 Taking a book from somewhere
32
+ c031 Throwing a book somewhere
33
+ c032 Watching/Reading/Looking at a book
34
+ c033 Holding a towel/s
35
+ c034 Putting a towel/s somewhere
36
+ c035 Taking a towel/s from somewhere
37
+ c036 Throwing a towel/s somewhere
38
+ c037 Tidying up a towel/s
39
+ c038 Washing something with a towel
40
+ c039 Closing a box
41
+ c040 Holding a box
42
+ c041 Opening a box
43
+ c042 Putting a box somewhere
44
+ c043 Taking a box from somewhere
45
+ c044 Taking something from a box
46
+ c045 Throwing a box somewhere
47
+ c046 Closing a laptop
48
+ c047 Holding a laptop
49
+ c048 Opening a laptop
50
+ c049 Putting a laptop somewhere
51
+ c050 Taking a laptop from somewhere
52
+ c051 Watching a laptop or something on a laptop
53
+ c052 Working/Playing on a laptop
54
+ c053 Holding a shoe/shoes
55
+ c054 Putting shoes somewhere
56
+ c055 Putting on shoe/shoes
57
+ c056 Taking shoes from somewhere
58
+ c057 Taking off some shoes
59
+ c058 Throwing shoes somewhere
60
+ c059 Sitting in a chair
61
+ c060 Standing on a chair
62
+ c061 Holding some food
63
+ c062 Putting some food somewhere
64
+ c063 Taking food from somewhere
65
+ c064 Throwing food somewhere
66
+ c065 Eating a sandwich
67
+ c066 Making a sandwich
68
+ c067 Holding a sandwich
69
+ c068 Putting a sandwich somewhere
70
+ c069 Taking a sandwich from somewhere
71
+ c070 Holding a blanket
72
+ c071 Putting a blanket somewhere
73
+ c072 Snuggling with a blanket
74
+ c073 Taking a blanket from somewhere
75
+ c074 Throwing a blanket somewhere
76
+ c075 Tidying up a blanket/s
77
+ c076 Holding a pillow
78
+ c077 Putting a pillow somewhere
79
+ c078 Snuggling with a pillow
80
+ c079 Taking a pillow from somewhere
81
+ c080 Throwing a pillow somewhere
82
+ c081 Putting something on a shelf
83
+ c082 Tidying a shelf or something on a shelf
84
+ c083 Reaching for and grabbing a picture
85
+ c084 Holding a picture
86
+ c085 Laughing at a picture
87
+ c086 Putting a picture somewhere
88
+ c087 Taking a picture of something
89
+ c088 Watching/looking at a picture
90
+ c089 Closing a window
91
+ c090 Opening a window
92
+ c091 Washing a window
93
+ c092 Watching/Looking outside of a window
94
+ c093 Holding a mirror
95
+ c094 Smiling in a mirror
96
+ c095 Washing a mirror
97
+ c096 Watching something/someone/themselves in a mirror
98
+ c097 Walking through a doorway
99
+ c098 Holding a broom
100
+ c099 Putting a broom somewhere
101
+ c100 Taking a broom from somewhere
102
+ c101 Throwing a broom somewhere
103
+ c102 Tidying up with a broom
104
+ c103 Fixing a light
105
+ c104 Turning on a light
106
+ c105 Turning off a light
107
+ c106 Drinking from a cup/glass/bottle
108
+ c107 Holding a cup/glass/bottle of something
109
+ c108 Pouring something into a cup/glass/bottle
110
+ c109 Putting a cup/glass/bottle somewhere
111
+ c110 Taking a cup/glass/bottle from somewhere
112
+ c111 Washing a cup/glass/bottle
113
+ c112 Closing a closet/cabinet
114
+ c113 Opening a closet/cabinet
115
+ c114 Tidying up a closet/cabinet
116
+ c115 Someone is holding a paper/notebook
117
+ c116 Putting their paper/notebook somewhere
118
+ c117 Taking paper/notebook from somewhere
119
+ c118 Holding a dish
120
+ c119 Putting a dish/es somewhere
121
+ c120 Taking a dish/es from somewhere
122
+ c121 Wash a dish/dishes
123
+ c122 Lying on a sofa/couch
124
+ c123 Sitting on sofa/couch
125
+ c124 Lying on the floor
126
+ c125 Sitting on the floor
127
+ c126 Throwing something on the floor
128
+ c127 Tidying something on the floor
129
+ c128 Holding some medicine
130
+ c129 Taking/consuming some medicine
131
+ c130 Putting groceries somewhere
132
+ c131 Laughing at television
133
+ c132 Watching television
134
+ c133 Someone is awakening in bed
135
+ c134 Lying on a bed
136
+ c135 Sitting in a bed
137
+ c136 Fixing a vacuum
138
+ c137 Holding a vacuum
139
+ c138 Taking a vacuum from somewhere
140
+ c139 Washing their hands
141
+ c140 Fixing a doorknob
142
+ c141 Grasping onto a doorknob
143
+ c142 Closing a refrigerator
144
+ c143 Opening a refrigerator
145
+ c144 Fixing their hair
146
+ c145 Working on paper/notebook
147
+ c146 Someone is awakening somewhere
148
+ c147 Someone is cooking something
149
+ c148 Someone is dressing
150
+ c149 Someone is laughing
151
+ c150 Someone is running somewhere
152
+ c151 Someone is going from standing to sitting
153
+ c152 Someone is smiling
154
+ c153 Someone is sneezing
155
+ c154 Someone is standing up from somewhere
156
+ c155 Someone is undressing
157
+ c156 Someone is eating something
data/charades_ego/csv/0.csv ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ id,subject,scene,quality,relevance,verified,script,objects,descriptions,actions,length,egocentric,charades_video
2
+ P9SOAEGO,Z241,Stairs,6.0,6.0,Yes,"A person holding a broom walks up and down the stairs, brushing nonchalantly on each step. They put the broom down and pull out a vial of medicine from their pocket, then toss it down the stairs.",broom;floor;medicine;stairs,,c099 0.00 23.21;c127 1.20 23.21;c102 0.00 23.21;c100 0.00 23.21;c098 0.00 23.21;c101 25.20 23.21,33.33,Yes,LY2GQ
data/charades_ego/csv/1.csv ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ id,subject,scene,quality,relevance,verified,script,objects,descriptions,actions,length,egocentric,charades_video
2
+ 6D5DHEGO,UTMU,Kitchen,5.0,5.0,Yes,"The person put a pot on the stove and then turned to the counter to grab a cup. The person drank from the cup for a little. The person then poured the rest of the contents of the drink into the sink. The person then went to leave the room, grabbing the doorknob.",food;glass;sink;stove,,c109 20.30 21.38;c107 7.70 17.90;c147 0.00 8.70;c110 7.50 13.60,34.5,Yes,
data/charades_ego/csv/2.csv ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ id,subject,scene,quality,relevance,verified,script,objects,descriptions,actions,length,egocentric,charades_video
2
+ 15AKPEGO,I2IV,Home Office / Study (A room in a house used for work),7.0,7.0,Yes,A person is putting a box on the shelf and then closing the cabinet.,box;cabinet;desk;door;shelf,,c043 4.00 10.30;c040 5.30 19.80;c112 19.10 26.70;c042 14.10 20.70;c081 14.10 20.70;c114 15.00 20.70;c006 18.80 28.70,30.83,Yes,IA5TC
data/charades_ego/csv/3.csv ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ id,subject,scene,quality,relevance,verified,script,objects,descriptions,actions,length,egocentric,charades_video
2
+ X2JTKEGO,EV0Z,Recreation room / Man cave,6.0,5.0,Yes,"Person is watching television while another person is walking towards the shelf, grasping the camera.",bed;camera;drawer;television,,c135 0.00 30.30;c132 0.00 30.50;c016 7.70 29.80;c015 0.00 25.50,31.75,Yes,D3PPI
data/charades_ego/csv/4.csv ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ id,subject,scene,quality,relevance,verified,script,objects,descriptions,actions,length,egocentric,charades_video
2
+ 184EHEGO,KFGP,Home Office / Study (A room in a house used for work),6.0,6.0,Yes,A person is eating at a table while they play on a laptop.,chair;food;laptop;table,,c011 1.20 18.00;c059 0.20 17.60;c052 2.10 23.62;c156 0.40 23.62;c010 15.70 23.62;c014 14.40 23.62,37.17,Yes,CUB69
data/charades_ego/csv/5.csv ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ id,subject,scene,quality,relevance,verified,script,objects,descriptions,actions,length,egocentric,charades_video
2
+ S8YZIEGO,I2IV,Kitchen,6.0,7.0,Yes,"A person is cooking food on the stove. The person takes a picture of the food with their phone, then puts the phone back into their pocket.",food;phone/camera;pot;stirring utensil;stove,,c147 0.00 24.67;c017 24.50 24.67;c087 16.50 23.50;c015 8.60 24.67;c016 8.10 24.67;c018 5.80 13.60;c147 0.00 10.80;c017 27.10 24.67;c154 0.00 24.67,34.62,Yes,DUEEE
data/charades_ego/csv/6.csv ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ id,subject,scene,quality,relevance,verified,script,objects,descriptions,actions,length,egocentric,charades_video
2
+ PRODQEGO,DJ17,Kitchen,6.0,7.0,Yes,A person is cooking at a stove then they begin to play with some food that's on the counter.,dish;food;shelf;stove;table,,c009 20.90 30.00;c064 22.50 29.90;c063 16.50 23.10;c147 0.00 18.60;c061 19.10 31.79,32.17,Yes,Y5826
data/charades_ego/csv/7.csv ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ id,subject,scene,quality,relevance,verified,script,objects,descriptions,actions,length,egocentric,charades_video
2
+ QLXEXEGO,4OHY,Hallway,6.0,5.0,Yes,"A person smiles as they fix the door. The person laughs hysterically, then throws their bag of tools across the room.",door,,c006 7.50 22.20;c008 32.30 29.71;c152 17.40 26.00;c007 0.00 27.40,40.33,Yes,K3T1B
data/charades_ego/csv/8.csv ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ id,subject,scene,quality,relevance,verified,script,objects,descriptions,actions,length,egocentric,charades_video
2
+ CC0LBEGO,3VLX,Recreation room / Man cave,6.0,7.0,Yes,The person is fixing a shelf in the Recreation room / Man cave. Once the person was finished the person grabbed a glass of water and took a few sips. The person then looked around and left the room.,dirt;glass;shelf;water,,c110 0.00 32.38;c110 22.00 26.90;c106 22.30 26.70;c107 20.90 26.30;c082 1.40 11.40,36.29,Yes,
data/charades_ego/csv/9.csv ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ id,subject,scene,quality,relevance,verified,script,objects,descriptions,actions,length,egocentric,charades_video
2
+ FLY2FEGO,VT5W,Kitchen,7.0,7.0,Yes,A person opens a cabinet and takes out some coffee and then closes the cabinet.,closet/cabinet;coffee,,c113 0.00 4.20;c112 3.90 9.90,31.58,Yes,EJJIO
data/charades_ego/video/15AKPEGO.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1dbe66fb96be257ae24c122a8f534e893883aa058872932ea630d20e9a609a1f
3
+ size 1150119
data/charades_ego/video/184EHEGO.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:115fea186961bddc3bc8e9f10de4472a450720b143361f598257308d64bfd1b9
3
+ size 1459016
data/charades_ego/video/6D5DHEGO.mp4 ADDED
Binary file (871 kB). View file
 
data/charades_ego/video/CC0LBEGO.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b80523915026d3e23740b4eb057362809e6375e973e0b2fad42a5f11e4331629
3
+ size 1414227
data/charades_ego/video/FLY2FEGO.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b45b12ffeeeac6ae74955c37fe57bc468d309aaae188f9020cde88c3f7dd68da
3
+ size 1215686
data/charades_ego/video/P9SOAEGO.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dfc22189ed6d228eca382474ff668f52d0c2b034204241837a4ea00c6b650fd5
3
+ size 2497723
data/charades_ego/video/PRODQEGO.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5cd034ab3146bbb36264f5ddbf7b451d09e3e1002c4550e565805710e8d5a1cd
3
+ size 1318941
data/charades_ego/video/QLXEXEGO.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:29cc6ec737fe810378ba1f91363078194cf73a4e8e9dcafb45013e04e426bb94
3
+ size 1834803
data/charades_ego/video/S8YZIEGO.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d25f66f11549b8dac6a05c0f9405f1dc3a77df68a56430b76d33ab830300fd04
3
+ size 1439584
data/charades_ego/video/X2JTKEGO.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:617b7e321a96271f9603869999a7dc4ac26bf2e8d5f43ab8758428731373de6d
3
+ size 1976273
demo.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### demo.py
2
+ # Define model classes for inference.
3
+ ###
4
+ import json
5
+ import numpy as np
6
+ import os
7
+ import pandas as pd
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.backends.cudnn as cudnn
12
+ import torchvision.transforms as transforms
13
+ import torchvision.transforms._transforms_video as transforms_video
14
+ from sklearn.metrics import confusion_matrix
15
+ from einops import rearrange
16
+ from transformers import BertTokenizer
17
+
18
+ from svitt.model import SViTT
19
+ from svitt.datasets import VideoClassyDataset
20
+ from svitt.video_transforms import Permute
21
+ from svitt.config import load_cfg, setup_config
22
+ from svitt.evaluation_charades import charades_map
23
+ from svitt.evaluation import get_mean_accuracy
24
+
25
+
26
+ class VideoModel(nn.Module):
27
+ """ Base model for video understanding based on SViTT architecture. """
28
+ def __init__(self, config):
29
+ """ Initializes the model.
30
+ Parameters:
31
+ config: config file
32
+ """
33
+ super(VideoModel, self).__init__()
34
+ self.cfg = load_cfg(config)
35
+ self.model = self.build_model()
36
+ self.templates = ['{}']
37
+ self.dataset = self.cfg['data']['dataset']
38
+ self.eval()
39
+
40
+ def build_model(self):
41
+ cfg = self.cfg
42
+ if cfg['model'].get('pretrain', False):
43
+ ckpt_path = cfg['model']['pretrain']
44
+ else:
45
+ raise Exception('no checkpoint found')
46
+
47
+ if cfg['model'].get('config', False):
48
+ config_path = cfg['model']['config']
49
+ else:
50
+ raise Exception('no model config found')
51
+
52
+ self.model_cfg = setup_config(config_path)
53
+ self.tokenizer = BertTokenizer.from_pretrained(self.model_cfg.text_encoder)
54
+ model = SViTT(config=self.model_cfg, tokenizer=self.tokenizer)
55
+
56
+ print(f"Loading checkpoint from {ckpt_path}")
57
+ checkpoint = torch.load(ckpt_path, map_location="cpu")
58
+ state_dict = checkpoint["model"]
59
+
60
+ # fix for zero-shot evaluation
61
+ for key in list(state_dict.keys()):
62
+ if "bert" in key:
63
+ encoder_key = key.replace("bert.", "")
64
+ state_dict[encoder_key] = state_dict[key]
65
+
66
+ if torch.cuda.is_available():
67
+ model.cuda()
68
+
69
+ model.load_state_dict(state_dict, strict=False)
70
+
71
+ return model
72
+
73
+
74
+
75
+ def eval(self):
76
+ cudnn.benchmark = True
77
+ for p in self.model.parameters():
78
+ p.requires_grad = False
79
+ self.model.eval()
80
+
81
+
82
+ class VideoCLSModel(VideoModel):
83
+ """ Video model for video classification tasks (Charades-Ego, EGTEA). """
84
+ def __init__(self, config):
85
+ super(VideoCLSModel, self).__init__(config)
86
+ self.labels, self.mapping_vn2act = self.gen_label_map()
87
+ self.text_features = self.get_text_features()
88
+
89
+ def gen_label_map(self):
90
+ labelmap = self.cfg.get('label_map', 'meta/charades_ego/label_map.json')
91
+ if os.path.isfile(labelmap):
92
+ print(f"=> Loading label maps from {labelmap}")
93
+ meta = json.load(open(labelmap, 'r'))
94
+ labels, mapping_vn2act = meta['labels'], meta['mapping_vn2act']
95
+ else:
96
+ from svitt.preprocess import generate_label_map
97
+ labels, mapping_vn2act = generate_label_map(self.dataset)
98
+ meta = {'labels': labels, 'mapping_vn2act': mapping_vn2act}
99
+ meta_dir = f'meta/{self.dataset}'
100
+ if not os.path.exists(meta_dir):
101
+ os.makedirs(meta_dir)
102
+ json.dump(meta, open(f'{meta_dir}/label_map.json', 'w'))
103
+ print(f"=> Label map is generated and saved to {meta_dir}/label_map.json")
104
+
105
+ return labels, mapping_vn2act
106
+
107
+ def load_data(self, idx=None):
108
+ print(f"=> Creating dataset")
109
+ cfg, dataset = self.cfg, self.dataset
110
+ data_cfg = cfg['data']
111
+ crop_size = 224
112
+ val_transform = transforms.Compose([
113
+ Permute([3, 0, 1, 2]), # T H W C -> C T H W
114
+ transforms.Resize(crop_size),
115
+ transforms.CenterCrop(crop_size),
116
+ transforms_video.NormalizeVideo(
117
+ mean=[108.3272985, 116.7460125, 104.09373615000001],
118
+ std=[68.5005327, 66.6321579, 70.32316305],
119
+ ),
120
+ ])
121
+
122
+ if idx is None:
123
+ metadata_val = data_cfg['metadata_val']
124
+ else:
125
+ metadata_val = data_cfg['metadata_val'].format(idx)
126
+ if dataset in ['charades_ego', 'egtea']:
127
+ val_dataset = VideoClassyDataset(
128
+ dataset,
129
+ data_cfg['root'],
130
+ metadata_val,
131
+ transform=val_transform,
132
+ is_training=False,
133
+ label_mapping=self.mapping_vn2act,
134
+ is_trimmed=False,
135
+ num_clips=1,
136
+ clip_length=data_cfg['clip_length'],
137
+ clip_stride=data_cfg['clip_stride'],
138
+ sparse_sample=data_cfg['sparse_sample'],
139
+ )
140
+ else:
141
+ raise NotImplementedError
142
+
143
+ val_loader = torch.utils.data.DataLoader(
144
+ val_dataset, batch_size=8, shuffle=False,
145
+ num_workers=4, pin_memory=True, sampler=None, drop_last=False
146
+ )
147
+
148
+ return val_loader
149
+
150
+ @torch.no_grad()
151
+ def get_text_features(self):
152
+ print('=> Extracting text features')
153
+ embeddings = self.tokenizer(
154
+ self.labels,
155
+ padding="max_length",
156
+ truncation=True,
157
+ max_length=self.model_cfg.max_txt_l.video,
158
+ return_tensors="pt",
159
+ )
160
+ _, class_embeddings = self.model.encode_text(embeddings)
161
+ return class_embeddings
162
+
163
+ @torch.no_grad()
164
+ def forward(self, idx=None):
165
+ print('=> Start forwarding')
166
+ val_loader = self.load_data(idx)
167
+ all_outputs = []
168
+ all_targets = []
169
+ for i, values in enumerate(val_loader):
170
+ images = values[0]
171
+ target = values[1]
172
+
173
+ if torch.cuda.is_available():
174
+ images = images.cuda(non_blocking=True)
175
+ target = target.cuda(non_blocking=True)
176
+
177
+ # encode images
178
+ images = rearrange(images, 'b c k h w -> b k c h w')
179
+ dims = images.shape
180
+ images = images.reshape(-1, 4, dims[-3], dims[-2], dims[-1])
181
+
182
+ image_features, _ = self.model.encode_image(images)
183
+
184
+ if image_features.ndim == 3:
185
+ image_features = rearrange(image_features, '(b k) n d -> b (k n) d', b=1)
186
+ else:
187
+ image_features = rearrange(image_features, '(b k) d -> b k d', b=1)
188
+
189
+ # cosine similarity as logits
190
+ similarity = self.model.get_sim(image_features, self.text_features)[0]
191
+
192
+ all_outputs.append(similarity.cpu())
193
+ all_targets.append(target.cpu())
194
+
195
+ all_outputs = torch.cat(all_outputs)
196
+ all_targets = torch.cat(all_targets)
197
+
198
+ return all_outputs, all_targets
199
+
200
+ @torch.no_grad()
201
+ def predict(self, idx=0):
202
+ all_outputs, all_targets = self.forward(idx)
203
+ preds, targets = all_outputs.numpy(), all_targets.numpy()
204
+ #sel = np.where(np.cumsum(sorted(preds[0].tolist(), reverse=True)) > 0.06)[0][0]
205
+ sel = 5
206
+ df = pd.DataFrame(self.labels)
207
+ pred_action = df.iloc[preds[0].argsort()[-sel:]].values.tolist()
208
+ gt_action = df.iloc[np.where(targets[0])[0]].values.tolist()
209
+ pred_action = sorted([x[0] for x in pred_action])
210
+ gt_action = sorted([x[0] for x in gt_action])
211
+ return pred_action, gt_action
212
+
213
+ @torch.no_grad()
214
+ def evaluate(self):
215
+ all_outputs, all_targets = self.forward()
216
+ preds, targets = all_outputs.numpy(), all_targets.numpy()
217
+ if self.dataset == 'charades_ego':
218
+ m_ap, _, m_aps = charades_map(preds, targets)
219
+ print('mAP = {:.3f}'.format(m_ap))
220
+ elif self.dataset == 'egtea':
221
+ cm = confusion_matrix(targets, preds.argmax(axis=1))
222
+ mean_class_acc, acc = get_mean_accuracy(cm)
223
+ print('Mean Acc. = {:.3f}, Top-1 Acc. = {:.3f}'.format(mean_class_acc, acc))
224
+ else:
225
+ raise NotImplementedError
226
+
meta/charades_ego/label_map.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"labels": ["Holding some clothes", "Putting clothes somewhere", "Taking some clothes from somewhere", "Throwing clothes somewhere", "Tidying some clothes", "Washing some clothes", "Closing a door", "Fixing a door", "Opening a door", "Putting something on a table", "Sitting on a table", "Sitting at a table", "Tidying up a table", "Washing a table", "Working at a table", "Holding a phone/camera", "Playing with a phone/camera", "Putting a phone/camera somewhere", "Taking a phone/camera from somewhere", "Talking on a phone/camera", "Holding a bag", "Opening a bag", "Putting a bag somewhere", "Taking a bag from somewhere", "Throwing a bag somewhere", "Closing a book", "Holding a book", "Opening a book", "Putting a book somewhere", "Smiling at a book", "Taking a book from somewhere", "Throwing a book somewhere", "Watching/Reading/Looking at a book", "Holding a towel/s", "Putting a towel/s somewhere", "Taking a towel/s from somewhere", "Throwing a towel/s somewhere", "Tidying up a towel/s", "Washing something with a towel", "Closing a box", "Holding a box", "Opening a box", "Putting a box somewhere", "Taking a box from somewhere", "Taking something from a box", "Throwing a box somewhere", "Closing a laptop", "Holding a laptop", "Opening a laptop", "Putting a laptop somewhere", "Taking a laptop from somewhere", "Watching a laptop or something on a laptop", "Working/Playing on a laptop", "Holding a shoe/shoes", "Putting shoes somewhere", "Putting on shoe/shoes", "Taking shoes from somewhere", "Taking off some shoes", "Throwing shoes somewhere", "Sitting in a chair", "Standing on a chair", "Holding some food", "Putting some food somewhere", "Taking food from somewhere", "Throwing food somewhere", "Eating a sandwich", "Making a sandwich", "Holding a sandwich", "Putting a sandwich somewhere", "Taking a sandwich from somewhere", "Holding a blanket", "Putting a blanket somewhere", "Snuggling with a blanket", "Taking a blanket from somewhere", "Throwing a blanket somewhere", "Tidying up a blanket/s", "Holding a pillow", "Putting a pillow somewhere", "Snuggling with a pillow", "Taking a pillow from somewhere", "Throwing a pillow somewhere", "Putting something on a shelf", "Tidying a shelf or something on a shelf", "Reaching for and grabbing a picture", "Holding a picture", "Laughing at a picture", "Putting a picture somewhere", "Taking a picture of something", "Watching/looking at a picture", "Closing a window", "Opening a window", "Washing a window", "Watching/Looking outside of a window", "Holding a mirror", "Smiling in a mirror", "Washing a mirror", "Watching something/someone/themselves in a mirror", "Walking through a doorway", "Holding a broom", "Putting a broom somewhere", "Taking a broom from somewhere", "Throwing a broom somewhere", "Tidying up with a broom", "Fixing a light", "Turning on a light", "Turning off a light", "Drinking from a cup/glass/bottle", "Holding a cup/glass/bottle of something", "Pouring something into a cup/glass/bottle", "Putting a cup/glass/bottle somewhere", "Taking a cup/glass/bottle from somewhere", "Washing a cup/glass/bottle", "Closing a closet/cabinet", "Opening a closet/cabinet", "Tidying up a closet/cabinet", "Someone is holding a paper/notebook", "Putting their paper/notebook somewhere", "Taking paper/notebook from somewhere", "Holding a dish", "Putting a dish/es somewhere", "Taking a dish/es from somewhere", "Wash a dish/dishes", "Lying on a sofa/couch", "Sitting on sofa/couch", "Lying on the floor", "Sitting on the floor", "Throwing something on the floor", "Tidying something on the floor", "Holding some medicine", "Taking/consuming some medicine", "Putting groceries somewhere", "Laughing at television", "Watching television", "Someone is awakening in bed", "Lying on a bed", "Sitting in a bed", "Fixing a vacuum", "Holding a vacuum", "Taking a vacuum from somewhere", "Washing their hands", "Fixing a doorknob", "Grasping onto a doorknob", "Closing a refrigerator", "Opening a refrigerator", "Fixing their hair", "Working on paper/notebook", "Someone is awakening somewhere", "Someone is cooking something", "Someone is dressing", "Someone is laughing", "Someone is running somewhere", "Someone is going from standing to sitting", "Someone is smiling", "Someone is sneezing", "Someone is standing up from somewhere", "Someone is undressing", "Someone is eating something"], "mapping_vn2act": {"c000": 0, "c001": 1, "c002": 2, "c003": 3, "c004": 4, "c005": 5, "c006": 6, "c007": 7, "c008": 8, "c009": 9, "c010": 10, "c011": 11, "c012": 12, "c013": 13, "c014": 14, "c015": 15, "c016": 16, "c017": 17, "c018": 18, "c019": 19, "c020": 20, "c021": 21, "c022": 22, "c023": 23, "c024": 24, "c025": 25, "c026": 26, "c027": 27, "c028": 28, "c029": 29, "c030": 30, "c031": 31, "c032": 32, "c033": 33, "c034": 34, "c035": 35, "c036": 36, "c037": 37, "c038": 38, "c039": 39, "c040": 40, "c041": 41, "c042": 42, "c043": 43, "c044": 44, "c045": 45, "c046": 46, "c047": 47, "c048": 48, "c049": 49, "c050": 50, "c051": 51, "c052": 52, "c053": 53, "c054": 54, "c055": 55, "c056": 56, "c057": 57, "c058": 58, "c059": 59, "c060": 60, "c061": 61, "c062": 62, "c063": 63, "c064": 64, "c065": 65, "c066": 66, "c067": 67, "c068": 68, "c069": 69, "c070": 70, "c071": 71, "c072": 72, "c073": 73, "c074": 74, "c075": 75, "c076": 76, "c077": 77, "c078": 78, "c079": 79, "c080": 80, "c081": 81, "c082": 82, "c083": 83, "c084": 84, "c085": 85, "c086": 86, "c087": 87, "c088": 88, "c089": 89, "c090": 90, "c091": 91, "c092": 92, "c093": 93, "c094": 94, "c095": 95, "c096": 96, "c097": 97, "c098": 98, "c099": 99, "c100": 100, "c101": 101, "c102": 102, "c103": 103, "c104": 104, "c105": 105, "c106": 106, "c107": 107, "c108": 108, "c109": 109, "c110": 110, "c111": 111, "c112": 112, "c113": 113, "c114": 114, "c115": 115, "c116": 116, "c117": 117, "c118": 118, "c119": 119, "c120": 120, "c121": 121, "c122": 122, "c123": 123, "c124": 124, "c125": 125, "c126": 126, "c127": 127, "c128": 128, "c129": 129, "c130": 130, "c131": 131, "c132": 132, "c133": 133, "c134": 134, "c135": 135, "c136": 136, "c137": 137, "c138": 138, "c139": 139, "c140": 140, "c141": 141, "c142": 142, "c143": 143, "c144": 144, "c145": 145, "c146": 146, "c147": 147, "c148": 148, "c149": 149, "c150": 150, "c151": 151, "c152": 152, "c153": 153, "c154": 154, "c155": 155, "c156": 156}}
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ torch
3
+ torchvision
4
+ scikit-learn
5
+ eva-decord
6
+ timm
7
+ einops
8
+ ftfy
9
+ regex
10
+ transformers
11
+ omegaconf
12
+ zCurve
13
+ numpy-hilbert-curve
svitt/config.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import yaml
3
+ from omegaconf import OmegaConf, DictConfig
4
+
5
+ def load_base_cfg():
6
+ with open('configs/base.yml', 'r') as fp:
7
+ cfg = yaml.load(fp, Loader=yaml.SafeLoader)
8
+ return cfg
9
+
10
+ def load_cfg(cfg_file):
11
+ cfg = load_base_cfg()
12
+ with open(cfg_file, 'r') as fp:
13
+ exp_cfg = yaml.load(fp, Loader=yaml.SafeLoader)
14
+
15
+ cfg['model'].update(exp_cfg.get('model', {}))
16
+ cfg['data'].update(exp_cfg.get('data', {}))
17
+ dataset = cfg['data'].get('dataset')
18
+ return cfg
19
+
20
+ def convert_types(config):
21
+ """Convert `'None'` (str) --> `None` (None). Only supports top-level"""
22
+ for k, v in config.items():
23
+ if isinstance(v, DictConfig):
24
+ setattr(config, k, convert_types(v))
25
+
26
+ # TODO convert types in ListConfig, right now they are ignored
27
+ # if isinstance(v, ListConfig):
28
+ # new_v = ListConfig()
29
+
30
+ if v in ["None", "none"]:
31
+ setattr(config, k, None)
32
+ return config
33
+
34
+ def setup_config(config_path):
35
+ yaml_config = OmegaConf.load(config_path)
36
+ config = convert_types(yaml_config)
37
+ return config
svitt/datasets.py ADDED
@@ -0,0 +1,526 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import csv
8
+ import glob
9
+ import json
10
+ import numpy as np
11
+ import os.path as osp
12
+ import pickle
13
+ import random
14
+
15
+ import decord
16
+ import pandas as pd
17
+ import torch
18
+
19
+
20
+ def datetime2sec(str):
21
+ hh, mm, ss = str.split(':')
22
+ return int(hh) * 3600 + int(mm) * 60 + float(ss)
23
+
24
+
25
+ def video_loader(root, vid, second, end_second=None, chunk_len=300, fps=30, clip_length=32, jitter=False):
26
+ if chunk_len == -1:
27
+ vr = decord.VideoReader(osp.join(root, '{}.mp4'.format(vid)))
28
+ second_offset = second
29
+ if end_second is not None:
30
+ end_second = min(end_second, len(vr) / vr.get_avg_fps())
31
+ else:
32
+ end_second = len(vr) / vr.get_avg_fps()
33
+ else:
34
+ chunk_start = int(second) // chunk_len * chunk_len
35
+ second_offset = second - chunk_start
36
+ vr = decord.VideoReader(osp.join(root, '{}.mp4'.format(vid), '{}.mp4'.format(chunk_start)))
37
+ if fps == -1:
38
+ fps = vr.get_avg_fps()
39
+
40
+ # calculate frame_ids
41
+ frame_offset = int(np.round(second_offset * fps))
42
+ total_duration = max(int((end_second - second) * fps), clip_length)
43
+ if chunk_len == -1:
44
+ if end_second <= second:
45
+ raise ValueError("end_second should be greater than second")
46
+ else:
47
+ frame_ids = get_frame_ids(frame_offset, min(frame_offset + total_duration, len(vr)), num_segments=clip_length, jitter=jitter)
48
+ else:
49
+ frame_ids = get_frame_ids(frame_offset, frame_offset + total_duration, num_segments=clip_length, jitter=jitter)
50
+
51
+ # load frames
52
+ if max(frame_ids) < len(vr):
53
+ try:
54
+ frames = vr.get_batch(frame_ids).asnumpy()
55
+ except decord.DECORDError as error:
56
+ print(error)
57
+ frames = vr.get_batch([0] * len(frame_ids)).asnumpy()
58
+ else:
59
+ # find the remaining frames in the next chunk
60
+ try:
61
+ frame_ids_part1 = list(filter(lambda frame_id: frame_id < len(vr), frame_ids))
62
+ frames_part1 = vr.get_batch(frame_ids_part1).asnumpy()
63
+ vr2 = decord.VideoReader(osp.join(root, '{}.mp4'.format(vid), '{}.mp4'.format(chunk_start + chunk_len)))
64
+ frame_ids_part2 = list(filter(lambda frame_id: frame_id >= len(vr), frame_ids))
65
+ frame_ids_part2 = [min(frame_id % len(vr), len(vr2) - 1) for frame_id in frame_ids_part2]
66
+ frames_part2 = vr2.get_batch(frame_ids_part2).asnumpy()
67
+ frames = np.concatenate([frames_part1, frames_part2], axis=0)
68
+ # the next chunk does not exist; the current chunk is the last one
69
+ except (RuntimeError, decord.DECORDError) as error:
70
+ print(error)
71
+ frame_ids = get_frame_ids(min(frame_offset, len(vr) - 1), len(vr), num_segments=clip_length, jitter=jitter)
72
+ frames = vr.get_batch(frame_ids).asnumpy()
73
+
74
+ frames = [torch.tensor(frame, dtype=torch.float32) for frame in frames]
75
+ return torch.stack(frames, dim=0)
76
+
77
+
78
+ def get_frame_ids(start_frame, end_frame, num_segments=32, jitter=True):
79
+ seg_size = float(end_frame - start_frame - 1) / num_segments
80
+ seq = []
81
+ for i in range(num_segments):
82
+ start = int(np.round(seg_size * i) + start_frame)
83
+ end = int(np.round(seg_size * (i + 1)) + start_frame)
84
+ end = min(end, end_frame)
85
+ if jitter:
86
+ frame_id = np.random.randint(low=start, high=(end + 1))
87
+ else:
88
+ frame_id = (start + end) // 2
89
+ seq.append(frame_id)
90
+ return seq
91
+
92
+
93
+ def video_loader_by_frames(root, vid, frame_ids):
94
+ vr = decord.VideoReader(osp.join(root, vid))
95
+ try:
96
+ frames = vr.get_batch(frame_ids).asnumpy()
97
+ frames = [torch.tensor(frame, dtype=torch.float32) for frame in frames]
98
+ except (IndexError, decord.DECORDError) as error:
99
+ print(error)
100
+ print("Erroneous video: ", vid)
101
+ frames = [torch.zeros((240, 320, 3)) for _ in range(len(frame_ids))]
102
+ return torch.stack(frames, dim=0)
103
+
104
+
105
+ class VideoCaptionDatasetBase(torch.utils.data.Dataset):
106
+ def __init__(self, dataset, root, metadata, is_trimmed=True):
107
+ self.dataset = dataset
108
+ self.root = root
109
+ self.is_trimmed = is_trimmed
110
+
111
+ if self.dataset == 'ego4d':
112
+ with open(metadata, 'rb') as f:
113
+ self.samples = pickle.load(f)
114
+ elif self.dataset == 'ego4d_mcq':
115
+ with open(metadata, 'r') as f:
116
+ self.samples = json.load(f)
117
+ elif self.dataset in ['ek100_cls', 'ek100_mir']:
118
+ video_list = glob.glob(osp.join(self.root, '*/*.MP4'))
119
+ fps_dict = {video: decord.VideoReader(video).get_avg_fps() for video in video_list}
120
+ self.samples = []
121
+ with open(metadata) as f:
122
+ csv_reader = csv.reader(f)
123
+ _ = next(csv_reader) # skip the header
124
+ for row in csv_reader:
125
+ pid, vid = row[1:3]
126
+ # start_frame, end_frame = int(row[6]), int(row[7])
127
+ # Deprecated: some videos might have fps mismatch issue
128
+ start_timestamp, end_timestamp = datetime2sec(row[4]), datetime2sec(row[5])
129
+ narration = row[8]
130
+ verb, noun = int(row[10]), int(row[12])
131
+ vid_path = '{}/{}.MP4'.format(pid, vid)
132
+ fps = fps_dict[osp.join(self.root, vid_path)]
133
+ start_frame = int(np.round(fps * start_timestamp))
134
+ end_frame = int(np.ceil(fps * end_timestamp))
135
+ self.samples.append((vid_path, start_frame, end_frame, narration, verb, noun))
136
+ if self.dataset == 'ek100_mir':
137
+ self.metadata_sentence = pd.read_csv(metadata[:metadata.index('.csv')] + '_sentence.csv')
138
+ if 'train' in metadata:
139
+ self.relevancy_mat = pickle.load(open(osp.join(osp.dirname(metadata), 'relevancy', 'caption_relevancy_EPIC_100_retrieval_train.pkl'), 'rb'))
140
+ elif 'test' in metadata:
141
+ self.relevancy_mat = pickle.load(open(osp.join(osp.dirname(metadata), 'relevancy', 'caption_relevancy_EPIC_100_retrieval_test.pkl'), 'rb'))
142
+ else:
143
+ raise ValueError('{} should contain either "train" or "test"!'.format(metadata))
144
+ self.relevancy = .1
145
+ elif self.dataset == 'egtea':
146
+ video_list = glob.glob(osp.join(self.root, '*/*'))
147
+ len_dict = {video: len(decord.VideoReader(video)) for video in video_list}
148
+
149
+ vn_list, labels = [], []
150
+ for row in open(osp.join(osp.dirname(metadata), 'action_idx.txt')):
151
+ row = row.strip()
152
+ vn = int(row.split(' ')[-1])
153
+ vn_list.append(vn)
154
+ narration = ' '.join(row.split(' ')[:-1])
155
+ labels.append(narration.replace('_', ' ').lower())
156
+ # labels.append(narration)
157
+ mapping_act2narration = {vn: narration for vn, narration in zip(vn_list, labels)}
158
+
159
+ self.samples = []
160
+ with open(metadata) as f:
161
+ for row in f:
162
+ clip_id, action_idx = row.strip().split(' ')[:2]
163
+ video_id = '-'.join(clip_id.split('-')[:3])
164
+ vid_relpath = osp.join(video_id, '{}.mp4'.format(clip_id))
165
+ vid_fullpath = osp.join(self.root, video_id, '{}.mp4'.format(clip_id))
166
+ self.samples.append((vid_relpath, 0, len_dict[vid_fullpath], mapping_act2narration[int(action_idx)]))
167
+ elif self.dataset == 'charades_ego':
168
+ video_list = glob.glob(osp.join(self.root, '*.mp4'))
169
+ fps_dict = {video: decord.VideoReader(video).get_avg_fps() for video in video_list}
170
+ self.samples = []
171
+ with open(metadata) as f:
172
+ csv_reader = csv.reader(f)
173
+ _ = next(csv_reader) # skip the header
174
+ for row in csv_reader:
175
+ video_id = row[0]
176
+ if self.is_trimmed:
177
+ for action_tuple in row[9].split(';'):
178
+ if not action_tuple:
179
+ continue
180
+ action, start_timestamp, end_timestamp = action_tuple.split(' ')
181
+ start_timestamp, end_timestamp = float(start_timestamp), float(end_timestamp)
182
+ vid_path = '{}.mp4'.format(video_id)
183
+ fps = fps_dict[osp.join(self.root, vid_path)]
184
+ start_frame = int(np.round(fps * start_timestamp))
185
+ end_frame = int(np.ceil(fps * end_timestamp))
186
+ self.samples.append((vid_path, start_frame, end_frame, action))
187
+ else:
188
+ if not row[9]:
189
+ action_list = []
190
+ else:
191
+ action_list = [action_tuple.split(' ')[0] for action_tuple in row[9].split(';')]
192
+ vid_path = '{}.mp4'.format(video_id)
193
+ fps = fps_dict[osp.join(self.root, vid_path)]
194
+ duration = fps * float(row[10])
195
+ self.samples.append((vid_path, 0, duration, action_list))
196
+ elif self.dataset == 'charades_ego_trimmed':
197
+ with open(metadata, 'rb') as f:
198
+ self.samples = pickle.load(f)
199
+ else:
200
+ raise NotImplementedError
201
+
202
+ def get_raw_item(self, i, is_training=True, num_clips=1, clip_length=32, clip_stride=2, sparse_sample=False,
203
+ narration_selection='random'):
204
+ if self.dataset == 'ego4d':
205
+ if len(self.samples[i]) == 4:
206
+ vid, start_second, end_second, narration = self.samples[i]
207
+ frames = video_loader(self.root, vid, start_second,
208
+ end_second=end_second,
209
+ clip_length=clip_length,
210
+ jitter=is_training)
211
+ if isinstance(narration, list):
212
+ if narration_selection == 'random':
213
+ narration = random.choice(narration)
214
+ elif narration_selection == 'concat':
215
+ narration = '. '.join(narration)
216
+ elif narration_selection == 'list':
217
+ narration = narration
218
+ else:
219
+ raise ValueError
220
+ return frames, narration
221
+ elif len(self.samples[i]) == 5:
222
+ # TODO: need better filtering strategy based on nll
223
+ vid, start_second, end_second, narration, _ = self.samples[i]
224
+ frames = video_loader(self.root, vid, start_second,
225
+ end_second=end_second,
226
+ clip_length=clip_length,
227
+ jitter=is_training)
228
+ if isinstance(narration, list):
229
+ if narration_selection == 'random':
230
+ narration = random.choice(narration)
231
+ elif narration_selection == 'concat':
232
+ narration = '. '.join(narration)
233
+ elif narration_selection == 'list':
234
+ narration = narration
235
+ else:
236
+ raise ValueError
237
+ return frames, narration
238
+ elif self.dataset == 'ego4d_mcq':
239
+ itemMCQ = self.samples[str(i)]
240
+ answerIndex = itemMCQ['answer']
241
+ textQuery = itemMCQ['query']['clip_text']
242
+ sampleOptions = itemMCQ['choices']
243
+ frames_options = []
244
+ narration_options = []
245
+ for option_id in range(len(sampleOptions)):
246
+ option = sampleOptions[str(option_id)]
247
+ frames = video_loader(self.root, option['video_uid'],
248
+ float(option['clip_start']), end_second=float(option['clip_end']),
249
+ clip_length=clip_length,
250
+ jitter=is_training)
251
+ frames_options.append(frames)
252
+ narration_options.append(option['clip_text'])
253
+ return textQuery, frames_options, narration_options, answerIndex, itemMCQ['types']
254
+ elif self.dataset == 'ek100_mir':
255
+ vid_path, start_frame, end_frame, narration, verb, noun = self.samples[i]
256
+ # from third_party.EgoVLP.base.base_dataset import sample_frames_start_end
257
+ # frame_ids = sample_frames_start_end(clip_length, start_frame, end_frame, sample='uniform', fix_start=None)
258
+ frame_ids = get_frame_ids(start_frame, end_frame, num_segments=clip_length, jitter=is_training)
259
+ frames = video_loader_by_frames(self.root, vid_path, frame_ids)
260
+ if is_training:
261
+ positive_list = np.where(self.relevancy_mat[i] > self.relevancy)[0].tolist()
262
+ if positive_list != []:
263
+ pos = random.sample(positive_list, min(len(positive_list), 1))[0]
264
+ if pos < len(self.metadata_sentence) and pos < self.relevancy_mat.shape[1]:
265
+ return frames, (self.metadata_sentence.iloc[pos][1], self.relevancy_mat[i][pos])
266
+ else:
267
+ return frames, (narration, 1)
268
+ elif self.dataset == 'ek100_cls':
269
+ vid_path, start_frame, end_frame, narration, verb, noun = self.samples[i]
270
+ frame_ids = get_frame_ids(start_frame, end_frame, num_segments=clip_length, jitter=is_training)
271
+ frames = video_loader_by_frames(self.root, vid_path, frame_ids)
272
+ return frames, '{}:{}'.format(verb, noun)
273
+ elif self.dataset == 'egtea':
274
+ vid_path, start_frame, end_frame, sentence = self.samples[i]
275
+ if is_training:
276
+ assert num_clips == 1
277
+ if end_frame < clip_length * clip_stride:
278
+ frames = video_loader_by_frames(self.root, vid_path, list(np.arange(0, end_frame)))
279
+ zeros = torch.zeros((clip_length * clip_stride - end_frame, *frames.shape[1:]))
280
+ frames = torch.cat((frames, zeros), dim=0)
281
+ frames = frames[::clip_stride]
282
+ else:
283
+ start_id = np.random.randint(0, end_frame - clip_length * clip_stride + 1)
284
+ frame_ids = np.arange(start_id, start_id + clip_length * clip_stride, clip_stride)
285
+ frames = video_loader_by_frames(self.root, vid_path, frame_ids)
286
+ else:
287
+ if end_frame < clip_length * clip_stride:
288
+ frames = video_loader_by_frames(self.root, vid_path, list(np.arange(0, end_frame)))
289
+ zeros = torch.zeros((clip_length * clip_stride - end_frame, *frames.shape[1:]))
290
+ frames = torch.cat((frames, zeros), dim=0)
291
+ frames = frames[::clip_stride]
292
+ frames = frames.repeat(num_clips, 1, 1, 1)
293
+ else:
294
+ frame_ids = []
295
+ for start_id in np.linspace(0, end_frame - clip_length * clip_stride, num_clips, dtype=int):
296
+ frame_ids.extend(np.arange(start_id, start_id + clip_length * clip_stride, clip_stride))
297
+ frames = video_loader_by_frames(self.root, vid_path, frame_ids)
298
+ return frames, sentence
299
+ elif self.dataset == 'charades_ego':
300
+ vid_path, start_frame, end_frame, action_list = self.samples[i]
301
+ if sparse_sample:
302
+ frame_ids = get_frame_ids(start_frame, end_frame, num_segments=num_clips * clip_length, jitter=is_training)
303
+ frames = video_loader_by_frames(self.root, vid_path, frame_ids)
304
+ else:
305
+ if end_frame < clip_length * clip_stride:
306
+ frames = video_loader_by_frames(self.root, vid_path, list(np.arange(0, end_frame)))
307
+ zeros = torch.zeros((clip_length * clip_stride - end_frame, *frames.shape[1:]))
308
+ frames = torch.cat((frames, zeros), dim=0)
309
+ frames = frames[::clip_stride]
310
+ frames = frames.repeat(num_clips, 1, 1, 1)
311
+ else:
312
+ frame_ids = []
313
+ for start_id in np.linspace(0, end_frame - clip_length * clip_stride, num_clips, dtype=int):
314
+ frame_ids.extend(np.arange(start_id, start_id + clip_length * clip_stride, clip_stride))
315
+ #print('frame_ids:', frame_ids)
316
+ frames = video_loader_by_frames(self.root, vid_path, frame_ids)
317
+ return frames, action_list, vid_path
318
+ elif self.dataset == 'charades_ego_trimmed':
319
+ vid, start_second, end_second, narration = self.samples[i]
320
+ frames = video_loader(self.root, vid, start_second,
321
+ end_second=end_second,
322
+ chunk_len=-1, # no chunk for CharadesEgo
323
+ fps=-1, # could be variable fps
324
+ clip_length=clip_length,
325
+ jitter=is_training)
326
+ return frames, narration
327
+ else:
328
+ raise NotImplementedError
329
+
330
+ def __getitem__(self, i):
331
+ raise NotImplementedError
332
+
333
+ def __len__(self):
334
+ return len(self.samples)
335
+
336
+
337
+ class VideoCaptionDatasetCLIP(VideoCaptionDatasetBase):
338
+ def __init__(self, dataset, root, metadata, transform=None,
339
+ is_training=True, tokenizer=None,
340
+ clip_length=32, clip_stride=2, sparse_sample=False,
341
+ narration_selection='random',
342
+ num_hard_negatives=0,
343
+ subsample_stride=None):
344
+ super().__init__(dataset, root, metadata)
345
+
346
+ self.full_samples = self.samples.copy()
347
+ if isinstance(subsample_stride, int):
348
+ self.samples = self.samples[::subsample_stride]
349
+ self.transform = transform
350
+ self.is_training = is_training
351
+ self.tokenizer = tokenizer
352
+ self.clip_length = clip_length
353
+ self.clip_stride = clip_stride
354
+ self.sparse_sample = sparse_sample
355
+ self.narration_selection = narration_selection
356
+ self.num_hard_negatives = num_hard_negatives
357
+ if num_hard_negatives > 0:
358
+ assert self.dataset == 'htm_aa'
359
+
360
+ def __getitem__(self, i):
361
+ frames, caption = self.get_raw_item(
362
+ i, is_training=self.is_training,
363
+ clip_length=self.clip_length,
364
+ clip_stride=self.clip_stride,
365
+ sparse_sample=self.sparse_sample,
366
+ narration_selection=self.narration_selection,
367
+ )
368
+
369
+ # ek100_mir will also output relevancy value
370
+ if isinstance(caption, tuple):
371
+ caption, relevancy = caption
372
+ else:
373
+ relevancy = 0.
374
+
375
+ # apply transformation
376
+ if self.transform is not None:
377
+ frames = self.transform(frames)
378
+
379
+ # tokenize caption
380
+ if self.tokenizer is not None:
381
+ caption = self.tokenizer(caption)
382
+
383
+ if isinstance(caption, tuple):
384
+ caption, mask = caption
385
+ return frames, caption, mask, relevancy
386
+ else:
387
+ return frames, caption, relevancy
388
+
389
+
390
+ class VideoCaptionDatasetMCQ(VideoCaptionDatasetBase):
391
+ def __init__(self, dataset, root, metadata, transform=None,
392
+ is_training=True, tokenizer=None,
393
+ clip_length=32, clip_stride=2, sparse_sample=False,
394
+ narration_selection='random'):
395
+ super().__init__(dataset, root, metadata)
396
+
397
+ self.full_samples = self.samples.copy()
398
+ self.transform = transform
399
+ self.is_training = is_training
400
+ self.tokenizer = tokenizer
401
+ self.clip_length = clip_length
402
+ self.clip_stride = clip_stride
403
+ self.sparse_sample = sparse_sample
404
+ self.narration_selection = narration_selection
405
+
406
+ def __getitem__(self, i):
407
+
408
+ textQuery, frames_options, narration_options, answerIndex, q_type = self.get_raw_item(
409
+ i, is_training=self.is_training,
410
+ clip_length=self.clip_length,
411
+ clip_stride=self.clip_stride,
412
+ sparse_sample=self.sparse_sample,
413
+ narration_selection=self.narration_selection,
414
+ )
415
+
416
+ # apply transformation
417
+ if self.transform is not None:
418
+ frames_options = [self.transform(frames) for frames in frames_options]
419
+
420
+ # tokenize caption
421
+ if self.tokenizer is not None:
422
+ textQuery = self.tokenizer(textQuery)
423
+ narration_options = self.tokenizer(narration_options)
424
+ if isinstance(textQuery, tuple):
425
+ textQuery, mask_query = textQuery
426
+ narration_options, mask_options = narration_options
427
+ return (
428
+ textQuery, torch.stack(frames_options, dim=0),
429
+ narration_options, answerIndex, q_type,
430
+ mask_query, mask_options
431
+ )
432
+ else:
433
+ return textQuery, torch.stack(frames_options, dim=0), narration_options, answerIndex, q_type
434
+
435
+
436
+ class VideoClassyDataset(VideoCaptionDatasetBase):
437
+ def __init__(
438
+ self, dataset, root, metadata, transform=None,
439
+ is_training=True, label_mapping=None,
440
+ num_clips=1,
441
+ clip_length=32, clip_stride=2,
442
+ sparse_sample=False,
443
+ is_trimmed=True,
444
+ ):
445
+ super().__init__(dataset, root, metadata, is_trimmed=is_trimmed)
446
+
447
+ self.transform = transform
448
+ self.is_training = is_training
449
+ self.label_mapping = label_mapping
450
+ self.num_clips = num_clips
451
+ self.clip_length = clip_length
452
+ self.clip_stride = clip_stride
453
+ self.sparse_sample = sparse_sample
454
+
455
+ def __getitem__(self, i):
456
+ frames, label, vid_path = self.get_raw_item(
457
+ i, is_training=self.is_training,
458
+ num_clips=self.num_clips,
459
+ clip_length=self.clip_length,
460
+ clip_stride=self.clip_stride,
461
+ sparse_sample=self.sparse_sample,
462
+ )
463
+
464
+ # apply transformation
465
+ if self.transform is not None:
466
+ frames = self.transform(frames)
467
+
468
+ if self.label_mapping is not None:
469
+ if isinstance(label, list):
470
+ # multi-label case
471
+ res_array = np.zeros(len(self.label_mapping))
472
+ for lbl in label:
473
+ res_array[self.label_mapping[lbl]] = 1.
474
+ label = res_array
475
+ else:
476
+ label = self.label_mapping[label]
477
+
478
+ return frames, label, vid_path
479
+
480
+
481
+ def get_dataset(train_transform, tokenizer, cfg, is_training=True):
482
+ narration_selection = cfg.get('narration_selection', 'random')
483
+ num_hard_neg = cfg.get('num_hard_neg', 0)
484
+ data_cfg = cfg['data']
485
+ if cfg['model']['arch'].startswith('CLIP') or cfg['model']['arch'].startswith('VCLM'):
486
+ if is_training:
487
+ metadata = data_cfg['metadata']
488
+ else:
489
+ metadata = data_cfg['metadata_val']
490
+
491
+ return VideoCaptionDatasetCLIP(
492
+ data_cfg['dataset'], data_cfg['root'], metadata, train_transform,
493
+ is_training=is_training,
494
+ tokenizer=tokenizer,
495
+ clip_length=data_cfg['clip_length'], clip_stride=data_cfg['clip_stride'],
496
+ sparse_sample=data_cfg['sparse_sample'],
497
+ narration_selection=narration_selection,
498
+ num_hard_negatives=num_hard_neg
499
+ )
500
+ else:
501
+ raise NotImplementedError
502
+
503
+
504
+ def get_downstream_dataset(transform, tokenizer, cfg, is_training=True, num_clips=0, label_mapping=None):
505
+ data_cfg = cfg['data']
506
+ n_clips = num_clips if num_clips > 0 else data_cfg['num_clips']
507
+ if is_training:
508
+ metadata = data_cfg['metadata']
509
+ return VideoClassyDataset(
510
+ data_cfg['dataset'], data_cfg['root'], metadata, transform,
511
+ is_training=True, label_mapping=label_mapping,
512
+ num_clips=n_clips,
513
+ clip_length=data_cfg['clip_length'], clip_stride=data_cfg['clip_stride'],
514
+ sparse_sample=data_cfg['sparse_sample'],
515
+ )
516
+ else:
517
+ metadata = data_cfg['metadata_val']
518
+ return VideoClassyDataset(
519
+ data_cfg['dataset'], data_cfg['root'], metadata, transform,
520
+ is_training=False, label_mapping=label_mapping,
521
+ num_clips=n_clips,
522
+ clip_length=data_cfg['clip_length'], clip_stride=data_cfg['clip_stride'],
523
+ sparse_sample=data_cfg['sparse_sample'],
524
+ is_trimmed=not data_cfg['dataset'] == 'charades_ego'
525
+ )
526
+
svitt/evaluation.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import numpy as np
8
+ import torch
9
+
10
+
11
+ def accuracy(output, target, topk=(1,)):
12
+ """Computes the accuracy over the k top predictions for the specified values of k"""
13
+ with torch.no_grad():
14
+ maxk = max(topk)
15
+ batch_size = target.size(0)
16
+
17
+ _, pred = output.topk(maxk, 1, True, True)
18
+ pred = pred.t()
19
+ correct = pred.eq(target.reshape(1, -1).expand_as(pred))
20
+
21
+ res = []
22
+ for k in topk:
23
+ correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
24
+ res.append(correct_k.mul_(100.0 / batch_size))
25
+ return res
26
+
27
+
28
+ def get_mean_accuracy(cm):
29
+ list_acc = []
30
+ for i in range(len(cm)):
31
+ acc = 0
32
+ if cm[i, :].sum() > 0:
33
+ acc = cm[i, i] / cm[i, :].sum()
34
+ list_acc.append(acc)
35
+
36
+ return 100 * np.mean(list_acc), 100 * np.trace(cm) / np.sum(cm)
svitt/evaluation_charades.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import numpy as np
8
+
9
+
10
+ def compute_map(submission_array, gt_array):
11
+ """ Returns mAP, weighted mAP, and AP array """
12
+ m_aps = []
13
+ n_classes = submission_array.shape[1]
14
+ for oc_i in range(n_classes):
15
+ sorted_idxs = np.argsort(-submission_array[:, oc_i])
16
+ tp = gt_array[:, oc_i][sorted_idxs] == 1
17
+ fp = np.invert(tp)
18
+ n_pos = tp.sum()
19
+ if n_pos < 0.1:
20
+ m_aps.append(float('nan'))
21
+ continue
22
+ fp.sum()
23
+ f_pcs = np.cumsum(fp)
24
+ t_pcs = np.cumsum(tp)
25
+ prec = t_pcs / (f_pcs+t_pcs).astype(float)
26
+ avg_prec = 0
27
+ for i in range(submission_array.shape[0]):
28
+ if tp[i]:
29
+ avg_prec += prec[i]
30
+ m_aps.append(avg_prec / n_pos.astype(float))
31
+ m_aps = np.array(m_aps)
32
+ #m_ap = np.mean(m_aps)
33
+ m_ap = m_aps[~np.isnan(m_aps)]
34
+ print(f'num of available classes: {len(m_ap)}')
35
+ m_ap = m_ap.mean() # compute mean w/o nan
36
+ w_ap = (m_aps * gt_array.sum(axis=0) / gt_array.sum().sum().astype(float))
37
+ return m_ap, w_ap, m_aps
38
+
39
+
40
+ def charades_map(submission_array, gt_array):
41
+ """
42
+ Approximate version of the charades evaluation function
43
+ For precise numbers, use the submission file with the official matlab script
44
+ """
45
+ fix = submission_array.copy()
46
+ empty = np.sum(gt_array, axis=1) == 0
47
+ fix[empty, :] = np.NINF
48
+ return compute_map(fix, gt_array)
49
+
50
+
51
+ def create_submission(video_list, predictions, out_file):
52
+ assert len(video_list) == predictions.shape[0]
53
+ with open(out_file, 'w') as f:
54
+ for i, video_id in enumerate(video_list):
55
+ pred_str = ' '.join(map(lambda x: str(x), predictions[i].tolist()))
56
+ f.write('{} {}\n\n'.format(video_id, pred_str))
svitt/model.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from svitt.utils import (
2
+ interpolate_pos_embed,
3
+ interpolate_pos_relative_bias_beit_3d,
4
+ )
5
+ from omegaconf import OmegaConf
6
+ from transformers import ViTModel, ViTConfig
7
+ from svitt.sparse_config import BertConfig, BeitConfig
8
+ from svitt.sparse_xbeit import BeitModel
9
+ from svitt.sparse_xbert import BertModel, BertForMaskedLM
10
+
11
+ import torch
12
+ from torch import nn
13
+ import torch.nn.functional as F
14
+
15
+
16
+ class SViTT(nn.Module):
17
+ """Common utils shared by pretraining and downstream retrieval"""
18
+ def __init__(self, config=None, tokenizer=None, pretrain=True, **kwargs):
19
+ super().__init__()
20
+ self.config = config
21
+ self.tokenizer = tokenizer
22
+ self.embed_dim = config.embed_dim
23
+ self.vision_width = 768
24
+ self.text_width = 768
25
+ self.pretrain = pretrain
26
+
27
+ self.vision_encoder, self.vision_layernorm = self.build_vision_encoder()
28
+ self.text_encoder = self.build_text_encoder()
29
+
30
+ self.vision_proj = nn.Linear(self.vision_width, self.embed_dim)
31
+ self.text_proj = nn.Linear(self.text_width, self.embed_dim)
32
+
33
+ self.temp = nn.Parameter(torch.ones([]) * config.temp)
34
+ self.itm_head = nn.Linear(self.text_width, 2)
35
+
36
+
37
+ def build_text_encoder(self):
38
+
39
+ bert_config = BertConfig.from_json_file(self.config.bert_config)
40
+
41
+ # Override params for sparse vision encoder
42
+ model_args = getattr(self.config, 'text_encoder_args', {})
43
+ if model_args:
44
+ model_args = OmegaConf.to_object(model_args)
45
+ bert_config.update(model_args)
46
+
47
+ if self.pretrain:
48
+ text_encoder, _ = BertForMaskedLM.from_pretrained(
49
+ self.config.text_encoder, config=bert_config,
50
+ output_loading_info=True
51
+ )
52
+ else:
53
+ text_encoder, _ = BertModel.from_pretrained(
54
+ self.config.text_encoder, config=bert_config,
55
+ add_pooling_layer=False, output_loading_info=True
56
+ )
57
+ return text_encoder
58
+
59
+ def build_vision_encoder(self):
60
+ # if self.config.vit_type in ["beit", "deit", "vit", "vit32"]:
61
+ if self.config.vit_type in ["beit"]:
62
+ vision_encoder = self.build_huggingface_vit_with_image_size(
63
+ self.config.vit_name_or_pretrained_path, self.config.image_res,)
64
+ else:
65
+ raise ValueError(f"Unknown vit type {self.config.vit_type}")
66
+
67
+ # add layernorm for normalizing BEiT outputs hidden states
68
+ vision_layernorm = None
69
+ if self.config.vit_type == "beit":
70
+ vision_layernorm = nn.LayerNorm(self.vision_width, eps=1e-12)
71
+ return vision_encoder, vision_layernorm
72
+
73
+ # @classmethod
74
+ # def build_huggingface_vit_with_image_size(cls, model_card: str, image_size: int):
75
+ def build_huggingface_vit_with_image_size(self, model_card: str, image_size: int):
76
+ """Build a vit model from huggingface hub, also interpolate pos_embed when needed.
77
+
78
+ Args:
79
+ model_card: name in huggingface hub, e.g., `facebook/deit-base-patch16-224`
80
+ image_size: new image size, may be different from pre-training image_size of `model_card`
81
+
82
+ ref: https://github.com/huggingface/transformers/issues/12167#issuecomment-861356232
83
+ """
84
+ is_beit = "beit" in model_card
85
+ if "beit" in model_card:
86
+ model_cls, config_cls = BeitModel, BeitConfig
87
+ elif "deit" in model_card or "vit" in model_card:
88
+ # the deit model we use is loaded in vit arch,
89
+ # see https://huggingface.co/facebook/deit-base-patch16-224#how-to-use
90
+ model_cls, config_cls = ViTModel, ViTConfig
91
+ else:
92
+ raise ValueError(f"Unexpected model_card: {model_card}")
93
+
94
+ # BEiT uses average pooled tokens instead of [CLS] used by other models
95
+ tmp_model = model_cls.from_pretrained(model_card, add_pooling_layer=is_beit)
96
+ state_dict = tmp_model.state_dict()
97
+ del tmp_model
98
+
99
+ # Override params for sparse vision encoder
100
+ model_args = getattr(self.config, 'vision_encoder_args', {})
101
+ if model_args:
102
+ model_args = OmegaConf.to_object(model_args)
103
+ model_config = config_cls.from_pretrained(
104
+ model_card,
105
+ image_size=image_size,
106
+ **model_args,
107
+ )
108
+ model = model_cls(config=model_config, add_pooling_layer=is_beit, num_frames=self.config.video_input.num_frames)
109
+ if is_beit:
110
+ # interpolate relative pos bias
111
+ state_dict = interpolate_pos_relative_bias_beit_3d(
112
+ state_dict_old=state_dict,
113
+ state_dict_new=model.state_dict(),
114
+ patch_shape_new=model.window_size
115
+ )
116
+ else:
117
+ # interpolate pos_embed and load weights to new model
118
+ state_dict["embeddings.position_embeddings"] = interpolate_pos_embed(
119
+ pos_embed_old=state_dict["embeddings.position_embeddings"],
120
+ pos_embed_new=model.embeddings.position_embeddings,
121
+ num_patches_new=model.embeddings.patch_embeddings.num_patches
122
+ )
123
+ msg = model.load_state_dict(state_dict, strict=False)
124
+ return model
125
+
126
+ def get_text_encoder(self):
127
+ """get text encoder, used for text and cross-modal encoding"""
128
+ encoder = self.text_encoder
129
+ return encoder.bert if hasattr(encoder, "bert") else encoder
130
+
131
+ def encode_image(self, video, output_token_idx=False, output_attentions=False):
132
+ video_embeds = self.vision_encoder(video, output_token_idx=output_token_idx, output_attentions=output_attentions) # (bsz, seq_len, d)
133
+ if self.vision_layernorm is not None: # only for BEiT mean-pooling
134
+ video_embeds.last_hidden_state = self.vision_layernorm(video_embeds.last_hidden_state)
135
+ if output_token_idx:
136
+ token_idx = video_embeds.token_idx
137
+
138
+ if output_attentions:
139
+ attentions = video_embeds.attentions
140
+
141
+ if self.config.vit_type == "beit":
142
+ pooled_video_embeds = video_embeds.pooler_output # (bsz*num_frms, d)
143
+ video_embeds = video_embeds.last_hidden_state # (bsz*num_frms, L, d)
144
+ else:
145
+ video_embeds = video_embeds.last_hidden_state
146
+ pooled_video_embeds = video_embeds[:, 0]
147
+
148
+ outputs = (video_embeds, pooled_video_embeds)
149
+
150
+ if output_token_idx:
151
+ outputs += (token_idx,)
152
+
153
+ if output_attentions:
154
+ outputs += (attentions,)
155
+
156
+ return outputs
157
+
158
+ def _encode_image(self, image):
159
+ bsz, num_frms, c, h, w = image.shape # `num_frms` could be changing for image (=1) or video (e.g., =4)
160
+ image = image.view(bsz*num_frms, c, h, w)
161
+ image_embeds = self.vision_encoder(image)
162
+ if self.vision_layernorm is not None: # only for BEiT mean-pooling
163
+ image_embeds.last_hidden_state = self.vision_layernorm(image_embeds.last_hidden_state)
164
+
165
+ if self.config.vit_type == "beit":
166
+ pooled_image_embeds = image_embeds.pooler_output # (bsz*num_frms, d)
167
+ image_embeds = image_embeds.last_hidden_state # (bsz*num_frms, L, d)
168
+ else:
169
+ image_embeds = image_embeds.last_hidden_state
170
+ pooled_image_embeds = image_embeds[:, 0]
171
+
172
+ image_embeds = image_embeds.view(bsz, num_frms, -1, self.vision_width) # (bsz, num_frms, L, d)
173
+ pooled_image_embeds = pooled_image_embeds.view(bsz, num_frms, self.vision_width) \
174
+ if pooled_image_embeds is not None else None # (bsz, num_frms, d)
175
+ return image_embeds, pooled_image_embeds
176
+
177
+ def encode_text(self, text):
178
+ text_output = self.get_text_encoder()(
179
+ text.input_ids,
180
+ attention_mask=text.attention_mask,
181
+ return_dict=True,
182
+ mode='text'
183
+ )
184
+ text_embeds = text_output.last_hidden_state
185
+ pooled_text_embeds = text_embeds[:, 0]
186
+ return text_embeds, pooled_text_embeds
187
+
188
+ @torch.no_grad()
189
+ def clip_contrastive_temperature(self, min_val=0.001, max_val=0.5):
190
+ """Seems only used during pre-training"""
191
+ self.temp.clamp_(min_val, max_val)
192
+
193
+ @torch.no_grad()
194
+ def get_mask(self, sim, idx=None, normalize=False):
195
+ """
196
+ sim: (N, N)
197
+ idx: (N, )
198
+ normalize: bool, make row sum equal to 1
199
+ """
200
+ if idx is not None:
201
+ idx = idx.view(-1, 1)
202
+ mask = torch.eq(idx, idx.T).to(sim.dtype)
203
+ if normalize:
204
+ mask = mask / mask.sum(1, keepdim=True)
205
+ else:
206
+ mask = torch.zeros_like(sim)
207
+ mask.fill_diagonal_(1)
208
+ return mask # `1` mark valid/matched location
209
+
210
+ def get_contrastive_loss(self, pooled_image_embeds, pooled_text_embeds, idx=None):
211
+ sim_i2t, sim_t2i = self.get_sim(
212
+ pooled_image_embeds, pooled_text_embeds, t=self.temp)
213
+
214
+ with torch.no_grad():
215
+ sim_i2t_targets = self.get_mask(sim_i2t, idx=idx, normalize=True)
216
+ sim_t2i_targets = sim_i2t_targets
217
+
218
+ loss_i2t = -torch.sum(
219
+ F.log_softmax(sim_i2t, dim=1) * sim_i2t_targets, dim=1).mean()
220
+ loss_t2i = -torch.sum(
221
+ F.log_softmax(sim_t2i, dim=1) * sim_t2i_targets, dim=1).mean()
222
+
223
+ loss_ita = (loss_i2t + loss_t2i) / 2
224
+ return loss_ita, sim_i2t, sim_t2i
225
+
226
+ def get_sim(self, pooled_image_embeds, pooled_text_embeds, t=1):
227
+ """
228
+ Args:
229
+ pooled_image_embeds: (bsz, num_frms, d)
230
+ pooled_text_embeds: (bsz, d)
231
+ t: temperature
232
+ """
233
+ image_proj = self.vision_proj
234
+ text_proj = self.text_proj
235
+
236
+ image_feat = F.normalize(image_proj(pooled_image_embeds), dim=-1)
237
+ text_feat = F.normalize(text_proj(pooled_text_embeds), dim=-1)
238
+
239
+ if image_feat.ndim == 3:
240
+ sim_i2t = torch.einsum("mld,nd->mln", image_feat, text_feat).mean(1) / t # (N, N)
241
+ else:
242
+ sim_i2t = torch.einsum("md,nd ->mn", image_feat, text_feat) / t # (N, N)
243
+ sim_t2i = sim_i2t.T
244
+ return sim_i2t, sim_t2i
245
+
246
+ def get_itm_loss(self,
247
+ sim_i2t,
248
+ sim_t2i,
249
+ text_embeds,
250
+ text_atts,
251
+ image_embeds,
252
+ image_atts,
253
+ idx=None,
254
+ ):
255
+ """
256
+ sim_i2t, sim_t2i: (N, N)
257
+ text_embeds, text_atts, image_embeds, image_atts: (N, *)
258
+ idx: (N, )
259
+ """
260
+ bsz = len(sim_i2t)
261
+
262
+ with torch.no_grad():
263
+ weights_i2t = F.softmax(sim_i2t+1e-4, dim=1) # (N, N)
264
+ weights_t2i = F.softmax(sim_t2i+1e-4, dim=1)
265
+
266
+ mask = self.get_mask(sim_i2t, idx=idx).bool()
267
+ weights_i2t.masked_fill_(mask, 0)
268
+ weights_t2i.masked_fill_(mask, 0)
269
+
270
+ # select a negative image for each text
271
+ if self.config.itm_hard_neg:
272
+ img_neg_indices = torch.multinomial(weights_t2i, 1).squeeze() #RuntimeError: invalid multinomial distribution (sum of probabilities <= 0)
273
+ else:
274
+ img_neg_indices = self.get_rand_indices(mask, 1).squeeze()
275
+
276
+ image_embeds_neg = image_embeds[img_neg_indices]
277
+
278
+ # select a negative text for each image
279
+ if self.config.itm_hard_neg:
280
+ txt_neg_indices = torch.multinomial(weights_i2t, 1).squeeze()
281
+ else:
282
+ txt_neg_indices = self.get_rand_indices(mask, 1).squeeze()
283
+
284
+ text_embeds_neg = text_embeds[txt_neg_indices]
285
+ text_atts_neg = text_atts[txt_neg_indices] # (N, L, d)
286
+
287
+ # embedding on local gpu
288
+ _text_embeds = text_embeds
289
+ _text_atts = text_atts
290
+ _image_embeds = image_embeds
291
+ _image_atts = image_atts
292
+ # concat embeddings
293
+ text_embeds_all = torch.cat([_text_embeds, _text_embeds, text_embeds_neg], dim=0)
294
+ text_atts_all = torch.cat([_text_atts, _text_atts, text_atts_neg], dim=0)
295
+ image_embeds_all = torch.cat([_image_embeds, image_embeds_neg, _image_embeds], dim=0)
296
+ image_atts_all = torch.cat([_image_atts, _image_atts, _image_atts], dim=0)
297
+
298
+ text_encoder = self.get_text_encoder()
299
+ output = text_encoder(
300
+ encoder_embeds=text_embeds_all,
301
+ attention_mask=text_atts_all,
302
+ encoder_hidden_states=image_embeds_all,
303
+ encoder_attention_mask=image_atts_all,
304
+ return_dict=True,
305
+ mode='fusion',
306
+ )
307
+
308
+ itm_embeds = output.last_hidden_state[:, 0] # pos (N, d) + neg (2N, d)
309
+
310
+ loss_itm = self._get_itm_loss(itm_embeds, enc=self.itm_head)
311
+ itm_embeds_pos = itm_embeds[:bsz] # (N, d)
312
+
313
+ return loss_itm, itm_embeds_pos
314
+
315
+ def _get_itm_loss(self, itm_embeds, enc):
316
+ """
317
+ itm_embeds: (3*N, D)
318
+ enc: nn.Module that projects cls_embeds
319
+ """
320
+ itm_scores = enc(itm_embeds) # (3*N, 2)
321
+ bs = itm_scores.size(0) // 3
322
+ itm_labels = itm_scores.new_ones(3*bs, dtype=torch.long)
323
+ itm_labels[bs:] = 0
324
+ loss_itm = F.cross_entropy(itm_scores, itm_labels)
325
+ return loss_itm
326
+
327
+ def get_rand_indices(self, mask, k):
328
+ """
329
+ Args:
330
+ mask: (N, L) 0 indicates the positions that we can sample, 1 otherwise
331
+ k: #indices to sample at each row
332
+ Returns:
333
+ (N, k) indices
334
+ """
335
+ mask = mask.float()
336
+ mask = mask - 10000 * mask
337
+ mask += torch.randn_like(mask)
338
+ _, indices = torch.sort(mask, dim=1, descending=True)
339
+ indices = indices[:, :k].contiguous()
340
+ return indices
svitt/preprocess.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import csv
8
+
9
+ from lavila.models.tokenizer import MyBertTokenizer, MyDistilBertTokenizer, MyGPT2Tokenizer, SimpleTokenizer
10
+
11
+
12
+ def generate_label_map(dataset):
13
+ if dataset == 'ek100_cls':
14
+ print("Preprocess ek100 action label space")
15
+ vn_list = []
16
+ mapping_vn2narration = {}
17
+ for f in [
18
+ '/data/EK100/epic-kitchens-100-annotations/EPIC_100_train.csv',
19
+ '/data/EK100/epic-kitchens-100-annotations/EPIC_100_validation.csv',
20
+ ]:
21
+ csv_reader = csv.reader(open(f))
22
+ _ = next(csv_reader) # skip the header
23
+ for row in csv_reader:
24
+ vn = '{}:{}'.format(int(row[10]), int(row[12]))
25
+ narration = row[8]
26
+ if vn not in vn_list:
27
+ vn_list.append(vn)
28
+ if vn not in mapping_vn2narration:
29
+ mapping_vn2narration[vn] = [narration]
30
+ else:
31
+ mapping_vn2narration[vn].append(narration)
32
+ # mapping_vn2narration[vn] = [narration]
33
+ vn_list = sorted(vn_list)
34
+ print('# of action= {}'.format(len(vn_list)))
35
+ mapping_vn2act = {vn: i for i, vn in enumerate(vn_list)}
36
+ labels = [list(set(mapping_vn2narration[vn_list[i]])) for i in range(len(mapping_vn2act))]
37
+ print(labels[:5])
38
+ elif dataset == 'charades_ego':
39
+ print("=> preprocessing charades_ego action label space")
40
+ vn_list = []
41
+ labels = []
42
+ with open('data/charades_ego/Charades_v1_classes.txt') as f:
43
+ csv_reader = csv.reader(f)
44
+ for row in csv_reader:
45
+ vn = row[0][:4]
46
+ vn_list.append(vn)
47
+ narration = row[0][5:]
48
+ labels.append(narration)
49
+ mapping_vn2act = {vn: i for i, vn in enumerate(vn_list)}
50
+ print(labels[:5])
51
+ elif dataset == 'egtea':
52
+ print("=> preprocessing egtea action label space")
53
+ labels = []
54
+ with open('/data/EGTEA/action_idx.txt') as f:
55
+ for row in f:
56
+ row = row.strip()
57
+ narration = ' '.join(row.split(' ')[:-1])
58
+ labels.append(narration.replace('_', ' ').lower())
59
+ # labels.append(narration)
60
+ mapping_vn2act = {label: i for i, label in enumerate(labels)}
61
+ print(len(labels), labels[:5])
62
+ else:
63
+ raise NotImplementedError
64
+ return labels, mapping_vn2act
65
+
66
+
67
+ def generate_tokenizer(model):
68
+ if model.endswith('DISTILBERT_BASE'):
69
+ tokenizer = MyDistilBertTokenizer('distilbert-base-uncased')
70
+ elif model.endswith('BERT_BASE'):
71
+ tokenizer = MyBertTokenizer('bert-base-uncased')
72
+ elif model.endswith('BERT_LARGE'):
73
+ tokenizer = MyBertTokenizer('bert-large-uncased')
74
+ elif model.endswith('GPT2'):
75
+ tokenizer = MyGPT2Tokenizer('gpt2', add_bos=True)
76
+ elif model.endswith('GPT2_MEDIUM'):
77
+ tokenizer = MyGPT2Tokenizer('gpt2-medium', add_bos=True)
78
+ elif model.endswith('GPT2_LARGE'):
79
+ tokenizer = MyGPT2Tokenizer('gpt2-large', add_bos=True)
80
+ elif model.endswith('GPT2_XL'):
81
+ tokenizer = MyGPT2Tokenizer('gpt2-xl', add_bos=True)
82
+ else:
83
+ print("Using SimpleTokenizer because of model '{}'. "
84
+ "Please check if this is what you want".format(model))
85
+ tokenizer = SimpleTokenizer()
86
+ return tokenizer
svitt/sparse_config.py ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright Microsoft Research and The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from collections import OrderedDict
17
+ from typing import Mapping
18
+
19
+ from transformers.configuration_utils import PretrainedConfig
20
+ from transformers.onnx import OnnxConfig
21
+
22
+
23
+ BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
24
+ "bert-base-uncased": "https://huggingface.co/bert-base-uncased/resolve/main/config.json",
25
+ "bert-large-uncased": "https://huggingface.co/bert-large-uncased/resolve/main/config.json",
26
+ "bert-base-cased": "https://huggingface.co/bert-base-cased/resolve/main/config.json",
27
+ "bert-large-cased": "https://huggingface.co/bert-large-cased/resolve/main/config.json",
28
+ "bert-base-multilingual-uncased": "https://huggingface.co/bert-base-multilingual-uncased/resolve/main/config.json",
29
+ "bert-base-multilingual-cased": "https://huggingface.co/bert-base-multilingual-cased/resolve/main/config.json",
30
+ "bert-base-chinese": "https://huggingface.co/bert-base-chinese/resolve/main/config.json",
31
+ "bert-base-german-cased": "https://huggingface.co/bert-base-german-cased/resolve/main/config.json",
32
+ "bert-large-uncased-whole-word-masking": "https://huggingface.co/bert-large-uncased-whole-word-masking/resolve/main/config.json",
33
+ "bert-large-cased-whole-word-masking": "https://huggingface.co/bert-large-cased-whole-word-masking/resolve/main/config.json",
34
+ "bert-large-uncased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad/resolve/main/config.json",
35
+ "bert-large-cased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/config.json",
36
+ "bert-base-cased-finetuned-mrpc": "https://huggingface.co/bert-base-cased-finetuned-mrpc/resolve/main/config.json",
37
+ "bert-base-german-dbmdz-cased": "https://huggingface.co/bert-base-german-dbmdz-cased/resolve/main/config.json",
38
+ "bert-base-german-dbmdz-uncased": "https://huggingface.co/bert-base-german-dbmdz-uncased/resolve/main/config.json",
39
+ "cl-tohoku/bert-base-japanese": "https://huggingface.co/cl-tohoku/bert-base-japanese/resolve/main/config.json",
40
+ "cl-tohoku/bert-base-japanese-whole-word-masking": "https://huggingface.co/cl-tohoku/bert-base-japanese-whole-word-masking/resolve/main/config.json",
41
+ "cl-tohoku/bert-base-japanese-char": "https://huggingface.co/cl-tohoku/bert-base-japanese-char/resolve/main/config.json",
42
+ "cl-tohoku/bert-base-japanese-char-whole-word-masking": "https://huggingface.co/cl-tohoku/bert-base-japanese-char-whole-word-masking/resolve/main/config.json",
43
+ "TurkuNLP/bert-base-finnish-cased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-cased-v1/resolve/main/config.json",
44
+ "TurkuNLP/bert-base-finnish-uncased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-uncased-v1/resolve/main/config.json",
45
+ "wietsedv/bert-base-dutch-cased": "https://huggingface.co/wietsedv/bert-base-dutch-cased/resolve/main/config.json",
46
+ # See all BERT models at https://huggingface.co/models?filter=bert
47
+ }
48
+
49
+
50
+ class BertConfig(PretrainedConfig):
51
+ r"""
52
+ This is the configuration class to store the configuration of a [`BertModel`] or a
53
+ [`TFBertModel`]. It is used to instantiate a BERT model according to the specified arguments,
54
+ defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration
55
+ to that of the BERT [bert-base-uncased](https://huggingface.co/bert-base-uncased) architecture.
56
+
57
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model
58
+ outputs. Read the documentation from [`PretrainedConfig`] for more information.
59
+
60
+
61
+ Args:
62
+ vocab_size (`int`, *optional*, defaults to 30522):
63
+ Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the
64
+ `inputs_ids` passed when calling [`BertModel`] or
65
+ [`TFBertModel`].
66
+ hidden_size (`int`, *optional*, defaults to 768):
67
+ Dimensionality of the encoder layers and the pooler layer.
68
+ num_hidden_layers (`int`, *optional*, defaults to 12):
69
+ Number of hidden layers in the Transformer encoder.
70
+ num_attention_heads (`int`, *optional*, defaults to 12):
71
+ Number of attention heads for each attention layer in the Transformer encoder.
72
+ intermediate_size (`int`, *optional*, defaults to 3072):
73
+ Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
74
+ hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
75
+ The non-linear activation function (function or string) in the encoder and pooler. If string,
76
+ `"gelu"`, `"relu"`, `"silu"` and `"gelu_new"` are supported.
77
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
78
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
79
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
80
+ The dropout ratio for the attention probabilities.
81
+ max_position_embeddings (`int`, *optional*, defaults to 512):
82
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
83
+ just in case (e.g., 512 or 1024 or 2048).
84
+ type_vocab_size (`int`, *optional*, defaults to 2):
85
+ The vocabulary size of the `token_type_ids` passed when calling [`BertModel`] or
86
+ [`TFBertModel`].
87
+ initializer_range (`float`, *optional*, defaults to 0.02):
88
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
89
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
90
+ The epsilon used by the layer normalization layers.
91
+ position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
92
+ Type of position embedding. Choose one of `"absolute"`, `"relative_key"`,
93
+ `"relative_key_query"`. For positional embeddings use `"absolute"`. For more information on
94
+ `"relative_key"`, please refer to [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155). For more information on `"relative_key_query"`, please refer to
95
+ *Method 4* in [Improve Transformer Models with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
96
+ use_cache (`bool`, *optional*, defaults to `True`):
97
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
98
+ relevant if `config.is_decoder=True`.
99
+ classifier_dropout (`float`, *optional*):
100
+ The dropout ratio for the classification head.
101
+
102
+ Examples:
103
+
104
+ ```python
105
+ >>> from transformers import BertModel, BertConfig
106
+
107
+ >>> # Initializing a BERT bert-base-uncased style configuration
108
+ >>> configuration = BertConfig()
109
+
110
+ >>> # Initializing a model from the bert-base-uncased style configuration
111
+ >>> model = BertModel(configuration)
112
+
113
+ >>> # Accessing the model configuration
114
+ >>> configuration = model.config
115
+ ```"""
116
+ model_type = "bert"
117
+
118
+ def __init__(
119
+ self,
120
+ vocab_size=30522,
121
+ hidden_size=768,
122
+ num_hidden_layers=12,
123
+ num_attention_heads=12,
124
+ intermediate_size=3072,
125
+ hidden_act="gelu",
126
+ hidden_dropout_prob=0.1,
127
+ attention_probs_dropout_prob=0.1,
128
+ max_position_embeddings=512,
129
+ type_vocab_size=2,
130
+ initializer_range=0.02,
131
+ layer_norm_eps=1e-12,
132
+ pad_token_id=0,
133
+ position_embedding_type="absolute",
134
+ use_cache=True,
135
+ classifier_dropout=None,
136
+ token_keep_rate=1,
137
+ token_keep_strategy='cls_attn',
138
+ token_drop_loc=[9],
139
+ **kwargs
140
+ ):
141
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
142
+
143
+ self.vocab_size = vocab_size
144
+ self.hidden_size = hidden_size
145
+ self.num_hidden_layers = num_hidden_layers
146
+ self.num_attention_heads = num_attention_heads
147
+ self.hidden_act = hidden_act
148
+ self.intermediate_size = intermediate_size
149
+ self.hidden_dropout_prob = hidden_dropout_prob
150
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
151
+ self.max_position_embeddings = max_position_embeddings
152
+ self.type_vocab_size = type_vocab_size
153
+ self.initializer_range = initializer_range
154
+ self.layer_norm_eps = layer_norm_eps
155
+ self.position_embedding_type = position_embedding_type
156
+ self.use_cache = use_cache
157
+ self.classifier_dropout = classifier_dropout
158
+ self.token_keep_rate = token_keep_rate
159
+ self.token_keep_strategy = token_keep_strategy
160
+ self.token_drop_loc = token_drop_loc
161
+
162
+
163
+ class BertOnnxConfig(OnnxConfig):
164
+ @property
165
+ def inputs(self) -> Mapping[str, Mapping[int, str]]:
166
+ return OrderedDict(
167
+ [
168
+ ("input_ids", {0: "batch", 1: "sequence"}),
169
+ ("attention_mask", {0: "batch", 1: "sequence"}),
170
+ ("token_type_ids", {0: "batch", 1: "sequence"}),
171
+ ]
172
+ )
173
+
174
+
175
+ BEIT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
176
+ "microsoft/beit-base-patch16-224-in22k": "https://huggingface.co/microsoft/beit-base-patch16-224-in22k/resolve/main/config.json",
177
+ # See all BEiT models at https://huggingface.co/models?filter=beit
178
+ }
179
+
180
+
181
+ class BeitConfig(PretrainedConfig):
182
+ r"""
183
+ This is the configuration class to store the configuration of a [`BeitModel`]. It is used to
184
+ instantiate an BEiT model according to the specified arguments, defining the model architecture. Instantiating a
185
+ configuration with the defaults will yield a similar configuration to that of the BEiT
186
+ [microsoft/beit-base-patch16-224-in22k](https://huggingface.co/microsoft/beit-base-patch16-224-in22k)
187
+ architecture.
188
+
189
+ Args:
190
+ vocab_size (`int`, *optional*, defaults to 8092):
191
+ Vocabulary size of the BEiT model. Defines the number of different image tokens that can be used during
192
+ pre-training.
193
+ hidden_size (`int`, *optional*, defaults to 768):
194
+ Dimensionality of the encoder layers and the pooler layer.
195
+ num_hidden_layers (`int`, *optional*, defaults to 12):
196
+ Number of hidden layers in the Transformer encoder.
197
+ num_attention_heads (`int`, *optional*, defaults to 12):
198
+ Number of attention heads for each attention layer in the Transformer encoder.
199
+ intermediate_size (`int`, *optional*, defaults to 3072):
200
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
201
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
202
+ The non-linear activation function (function or string) in the encoder and pooler. If string,
203
+ `"gelu"`, `"relu"`, `"selu"` and `"gelu_new"` are supported.
204
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
205
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
206
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
207
+ The dropout ratio for the attention probabilities.
208
+ initializer_range (`float`, *optional*, defaults to 0.02):
209
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
210
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
211
+ The epsilon used by the layer normalization layers.
212
+ image_size (`int`, *optional*, defaults to `224`):
213
+ The size (resolution) of each image.
214
+ patch_size (`int`, *optional*, defaults to `16`):
215
+ The size (resolution) of each patch.
216
+ num_channels (`int`, *optional*, defaults to `3`):
217
+ The number of input channels.
218
+ use_mask_token (`bool`, *optional*, defaults to `False`):
219
+ Whether to use a mask token for masked image modeling.
220
+ use_absolute_position_embeddings (`bool`, *optional*, defaults to `False`):
221
+ Whether to use BERT-style absolute position embeddings.
222
+ use_relative_position_bias (`bool`, *optional*, defaults to `False`):
223
+ Whether to use T5-style relative position embeddings in the self-attention layers.
224
+ use_shared_relative_position_bias (`bool`, *optional*, defaults to `False`):
225
+ Whether to use the same relative position embeddings across all self-attention layers of the Transformer.
226
+ layer_scale_init_value (`float`, *optional*, defaults to 0.1):
227
+ Scale to use in the self-attention layers. 0.1 for base, 1e-5 for large. Set 0 to disable layer scale.
228
+ drop_path_rate (`float`, *optional*, defaults to 0.1):
229
+ Stochastic depth rate per sample (when applied in the main path of residual layers).
230
+ use_mean_pooling (`bool`, *optional*, defaults to `True`):
231
+ Whether to mean pool the final hidden states of the patches instead of using the final hidden state of the
232
+ CLS token, before applying the classification head.
233
+ out_indices (`List[int]`, *optional*, defaults to `[3, 5, 7, 11]`):
234
+ Indices of the feature maps to use for semantic segmentation.
235
+ pool_scales (`Tuple[int]`, *optional*, defaults to `[1, 2, 3, 6]`):
236
+ Pooling scales used in Pooling Pyramid Module applied on the last feature map.
237
+ use_auxiliary_head (`bool`, *optional*, defaults to `True`):
238
+ Whether to use an auxiliary head during training.
239
+ auxiliary_loss_weight (`float`, *optional*, defaults to 0.4):
240
+ Weight of the cross-entropy loss of the auxiliary head.
241
+ auxiliary_channels (`int`, *optional*, defaults to 256):
242
+ Number of channels to use in the auxiliary head.
243
+ auxiliary_num_convs (`int`, *optional*, defaults to 1):
244
+ Number of convolutional layers to use in the auxiliary head.
245
+ auxiliary_concat_input (`bool`, *optional*, defaults to `False`):
246
+ Whether to concatenate the output of the auxiliary head with the input before the classification layer.
247
+ semantic_loss_ignore_index (`int`, *optional*, defaults to 255):
248
+ The index that is ignored by the loss function of the semantic segmentation model.
249
+
250
+ Example:
251
+
252
+ ```python
253
+ >>> from transformers import BeitModel, BeitConfig
254
+
255
+ >>> # Initializing a BEiT beit-base-patch16-224-in22k style configuration
256
+ >>> configuration = BeitConfig()
257
+
258
+ >>> # Initializing a model from the beit-base-patch16-224-in22k style configuration
259
+ >>> model = BeitModel(configuration)
260
+
261
+ >>> # Accessing the model configuration
262
+ >>> configuration = model.config
263
+ ```"""
264
+ model_type = "beit"
265
+
266
+ def __init__(
267
+ self,
268
+ vocab_size=8192,
269
+ hidden_size=768,
270
+ num_hidden_layers=12,
271
+ num_attention_heads=12,
272
+ intermediate_size=3072,
273
+ hidden_act="gelu",
274
+ hidden_dropout_prob=0.0,
275
+ attention_probs_dropout_prob=0.0,
276
+ initializer_range=0.02,
277
+ layer_norm_eps=1e-12,
278
+ is_encoder_decoder=False,
279
+ image_size=224,
280
+ patch_size=16,
281
+ num_channels=3,
282
+ use_mask_token=False,
283
+ use_absolute_position_embeddings=False,
284
+ use_relative_position_bias=False,
285
+ use_shared_relative_position_bias=False,
286
+ layer_scale_init_value=0.1,
287
+ drop_path_rate=0.1,
288
+ use_mean_pooling=True,
289
+ out_indices=[3, 5, 7, 11],
290
+ pool_scales=[1, 2, 3, 6],
291
+ use_auxiliary_head=True,
292
+ auxiliary_loss_weight=0.4,
293
+ auxiliary_channels=256,
294
+ auxiliary_num_convs=1,
295
+ auxiliary_concat_input=False,
296
+ semantic_loss_ignore_index=255,
297
+ token_keep_rate=1,
298
+ token_keep_strategy='cls_attn',
299
+ token_drop_loc=[3, 6, 9],
300
+ sparse_random_attn=None,
301
+ sparse_local_attn=1,
302
+ attn_block_size=1,
303
+ num_cls_tokens=1,
304
+ token_3d_order='none',
305
+ **kwargs
306
+ ):
307
+ super().__init__(**kwargs)
308
+
309
+ self.vocab_size = vocab_size
310
+ self.hidden_size = hidden_size
311
+ self.num_hidden_layers = num_hidden_layers
312
+ self.num_attention_heads = num_attention_heads
313
+ self.intermediate_size = intermediate_size
314
+ self.hidden_act = hidden_act
315
+ self.hidden_dropout_prob = hidden_dropout_prob
316
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
317
+ self.initializer_range = initializer_range
318
+ self.layer_norm_eps = layer_norm_eps
319
+
320
+ self.image_size = image_size
321
+ self.patch_size = patch_size
322
+ self.num_channels = num_channels
323
+ self.use_mask_token = use_mask_token
324
+ self.use_absolute_position_embeddings = use_absolute_position_embeddings
325
+ self.use_relative_position_bias = use_relative_position_bias
326
+ self.use_shared_relative_position_bias = use_shared_relative_position_bias
327
+ self.layer_scale_init_value = layer_scale_init_value
328
+ self.drop_path_rate = drop_path_rate
329
+ self.use_mean_pooling = use_mean_pooling
330
+ # decode head attributes (semantic segmentation)
331
+ self.out_indices = out_indices
332
+ self.pool_scales = pool_scales
333
+ # auxiliary head attributes (semantic segmentation)
334
+ self.use_auxiliary_head = use_auxiliary_head
335
+ self.auxiliary_loss_weight = auxiliary_loss_weight
336
+ self.auxiliary_channels = auxiliary_channels
337
+ self.auxiliary_num_convs = auxiliary_num_convs
338
+ self.auxiliary_concat_input = auxiliary_concat_input
339
+ self.semantic_loss_ignore_index = semantic_loss_ignore_index
340
+
341
+ # node sparsification
342
+ self.token_keep_rate = token_keep_rate
343
+ self.token_keep_strategy = token_keep_strategy
344
+ self.token_drop_loc = token_drop_loc
345
+ # edge sparsification
346
+ self.sparse_random_attn = sparse_random_attn
347
+ self.sparse_local_attn = sparse_local_attn
348
+ self.attn_block_size = attn_block_size
349
+ self.num_cls_tokens = num_cls_tokens
350
+ # token order
351
+ self.token_3d_order = token_3d_order
svitt/sparse_xbeit.py ADDED
@@ -0,0 +1,1585 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 Google AI, Ross Wightman, The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch BEiT model. """
16
+
17
+
18
+ import collections.abc
19
+ import math
20
+ import numpy as np
21
+ from dataclasses import dataclass
22
+ from typing import Optional, Tuple
23
+ import zCurve
24
+ import hilbert
25
+
26
+ import torch
27
+ import torch.utils.checkpoint
28
+ from torch import nn
29
+ from torch.nn import CrossEntropyLoss, MSELoss
30
+ from einops import rearrange, repeat
31
+
32
+ from transformers.activations import ACT2FN
33
+ from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
34
+ from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, MaskedLMOutput, SequenceClassifierOutput
35
+ from transformers.modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
36
+ from svitt.sparse_config import BeitConfig
37
+
38
+
39
+ _CONFIG_FOR_DOC = "BeitConfig"
40
+ _CHECKPOINT_FOR_DOC = "microsoft/beit-base-patch16-224"
41
+
42
+ BEIT_PRETRAINED_MODEL_ARCHIVE_LIST = [
43
+ "microsoft/beit-base-patch16-224",
44
+ # See all BEiT models at https://huggingface.co/models?filter=beit
45
+ ]
46
+
47
+
48
+ @dataclass
49
+ class BeitModelOutputWithPooling(BaseModelOutputWithPooling):
50
+ """
51
+ Class for outputs of :class:`~transformers.BeitModel`.
52
+
53
+ Args:
54
+ last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
55
+ Sequence of hidden-states at the output of the last layer of the model.
56
+ pooler_output (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, hidden_size)`):
57
+ Average of the last layer hidden states of the patch tokens (excluding the `[CLS]` token) if
58
+ `config.use_mean_pooling` is set to True. If set to False, then the final hidden state of the `[CLS]` token
59
+ will be returned.
60
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
61
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
62
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
63
+
64
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
65
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
66
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
67
+ sequence_length, sequence_length)`.
68
+
69
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
70
+ heads.
71
+ """
72
+ token_idx: Optional[Tuple[torch.LongTensor]] = None
73
+
74
+
75
+ @dataclass
76
+ class BeitModelOutput(BaseModelOutput):
77
+ token_idx: Optional[Tuple[torch.LongTensor]] = None
78
+
79
+
80
+ # Inspired by
81
+ # https://github.com/rwightman/pytorch-image-models/blob/b9bd960a032c75ca6b808ddeed76bee5f3ed4972/timm/models/layers/helpers.py
82
+ # From PyTorch internals
83
+ def to_2tuple(x):
84
+ if isinstance(x, collections.abc.Iterable):
85
+ return x
86
+ return (x, x)
87
+
88
+
89
+ # Based on https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py
90
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
91
+ """
92
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
93
+
94
+ Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks,
95
+ however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
96
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the
97
+ layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the
98
+ argument.
99
+ """
100
+ if drop_prob == 0.0 or not training:
101
+ return x
102
+ keep_prob = 1 - drop_prob
103
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
104
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
105
+ random_tensor.floor_() # binarize
106
+ output = x.div(keep_prob) * random_tensor
107
+ return output
108
+
109
+
110
+ class DropPath(nn.Module):
111
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
112
+
113
+ def __init__(self, drop_prob=None):
114
+ super().__init__()
115
+ self.drop_prob = drop_prob
116
+
117
+ def forward(self, x):
118
+ return drop_path(x, self.drop_prob, self.training)
119
+
120
+ def extra_repr(self) -> str:
121
+ return "p={}".format(self.drop_prob)
122
+
123
+
124
+ # Based on timm implementation, which can be found here:
125
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
126
+ class BeitEmbeddings(nn.Module):
127
+ """
128
+ Construct the CLS token, position and patch embeddings. Optionally, also the mask token.
129
+
130
+ """
131
+
132
+ def __init__(self, config):
133
+ super().__init__()
134
+
135
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
136
+ if config.use_mask_token:
137
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
138
+ else:
139
+ self.mask_token = None
140
+ self.patch_embeddings = PatchEmbeddings(
141
+ image_size=config.image_size,
142
+ patch_size=config.patch_size,
143
+ num_channels=config.num_channels,
144
+ embed_dim=config.hidden_size,
145
+ )
146
+ num_patches = self.patch_embeddings.num_patches
147
+ if config.use_absolute_position_embeddings:
148
+ self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
149
+ else:
150
+ self.position_embeddings = None
151
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
152
+
153
+ def forward(self, pixel_values, bool_masked_pos=None):
154
+
155
+ if pixel_values.ndim == 5: # video input=
156
+ embeddings = self.patch_embeddings(pixel_values.flatten(0, 1))
157
+ embeddings = rearrange(embeddings, '(b m) n d -> b (m n) d', m=pixel_values.shape[1])
158
+ else: # image input
159
+ embeddings = self.patch_embeddings(pixel_values)
160
+
161
+ batch_size, seq_len, _ = embeddings.size()
162
+
163
+ cls_tokens = self.cls_token.expand(batch_size, -1, -1)
164
+ if bool_masked_pos is not None:
165
+ mask_tokens = self.mask_token.expand(batch_size, seq_len, -1)
166
+ # replace the masked visual tokens by mask_tokens
167
+ w = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens)
168
+ embeddings = embeddings * (1 - w) + mask_tokens * w
169
+
170
+ embeddings = torch.cat((cls_tokens, embeddings), dim=1)
171
+ if self.position_embeddings is not None:
172
+ embeddings = embeddings + self.position_embeddings
173
+ embeddings = self.dropout(embeddings)
174
+
175
+ return embeddings
176
+
177
+
178
+ # Based on timm implementation, which can be found here:
179
+ # https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
180
+ class PatchEmbeddings(nn.Module):
181
+ """
182
+ Image to Patch Embedding.
183
+ """
184
+
185
+ def __init__(self, image_size=224, patch_size=16, num_channels=3, embed_dim=768):
186
+ super().__init__()
187
+ image_size = to_2tuple(image_size)
188
+ patch_size = to_2tuple(patch_size)
189
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
190
+ patch_shape = (image_size[0] // patch_size[0], image_size[1] // patch_size[1])
191
+ self.image_size = image_size
192
+ self.patch_size = patch_size
193
+ self.num_patches = num_patches
194
+ self.patch_shape = patch_shape
195
+
196
+ self.projection = nn.Conv2d(num_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
197
+
198
+ def forward(self, pixel_values):
199
+ batch_size, num_channels, height, width = pixel_values.shape
200
+ # FIXME look at relaxing size constraints
201
+ if height != self.image_size[0] or width != self.image_size[1]:
202
+ raise ValueError(
203
+ f"Input image size ({height}*{width}) doesn't match model ({self.image_size[0]}*{self.image_size[1]})."
204
+ )
205
+ x = self.projection(pixel_values).flatten(2).transpose(1, 2)
206
+
207
+ return x
208
+
209
+
210
+ class BeitSelfAttention(nn.Module):
211
+ def __init__(self, config, window_size=None):
212
+ super().__init__()
213
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
214
+ raise ValueError(
215
+ f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
216
+ f"heads {config.num_attention_heads}."
217
+ )
218
+
219
+ self.num_attention_heads = config.num_attention_heads
220
+ self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
221
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
222
+
223
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
224
+ self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
225
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
226
+
227
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
228
+
229
+ # sparse params
230
+ self.random_attn = config.sparse_random_attn
231
+ self.local_attn = config.sparse_local_attn
232
+ self.block_size = config.attn_block_size
233
+ self.num_cls_tokens = config.num_cls_tokens
234
+ if self.local_attn is not None and self.random_attn is not None:
235
+ self.num_kv_blocks = self.local_attn + self.random_attn
236
+
237
+ if window_size:
238
+ self.relative_position_bias = BeitRelativePositionBias3D(config, window_size=window_size)
239
+ else:
240
+ self.relative_position_bias = None
241
+
242
+ def split_heads(self, x):
243
+ return rearrange(x, 'b n (h d) -> b h n d', h=self.num_attention_heads)
244
+
245
+ def join_heads(self, x):
246
+ return rearrange(x, 'b h n d -> b n (h d)')
247
+
248
+ def blockify(self, x):
249
+ assert x.dim() == 4, f"Unsupported input shape {x.shape}"
250
+ seq_len = x.shape[2]
251
+ if seq_len % self.block_size > 0: # seq_len not divisible by block_size, zero pad
252
+ pad_len = self.block_size - seq_len % self.block_size
253
+ x = nn.functional.pad(x, (0, 0, 0, pad_len))
254
+ else:
255
+ pad_len = 0
256
+ x = rearrange(x, 'b h (m n) d -> b h m n d', n=self.block_size)
257
+ return x, pad_len
258
+
259
+ def dense_attention(self, q, k, v, head_mask=None, relative_position_bias=None, q_idx=None, k_idx=None):
260
+ # q, k, v: (bsz, num_heads, seq_len, dims)
261
+ assert k.shape[2] == v.shape[2], "Key and value shapes mismatch"
262
+ sim = torch.einsum('b h i d, b h j d -> b h i j', q, k)
263
+ sim = sim / math.sqrt(self.attention_head_size)
264
+
265
+ # Add relative position bias if present.
266
+ if self.relative_position_bias is not None:
267
+ if q_idx is not None and q_idx.ndim == 2:
268
+ assert k_idx is not None and len(q_idx) == len(k_idx)
269
+ bias = torch.stack([
270
+ self.relative_position_bias(from_idx=q_idx_, to_idx=k_idx_)
271
+ for q_idx_, k_idx_ in zip(q_idx, k_idx)
272
+ ])
273
+ else:
274
+ bias = self.relative_position_bias(from_idx=q_idx, to_idx=k_idx).unsqueeze(0)
275
+ sim = sim + bias
276
+
277
+ # Add shared relative position bias if provided.
278
+ if relative_position_bias is not None:
279
+ sim = sim + relative_position_bias
280
+
281
+ # Normalize the attention scores to probabilities.
282
+ attn = sim.softmax(dim=-1)
283
+ attn = self.dropout(attn)
284
+ if head_mask is not None:
285
+ attn = attn * head_mask
286
+
287
+ out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
288
+ return out, attn
289
+
290
+ def _sparse_attn_relative_position_bias(self, q_idx, pad_q, attn_idx, group_len):
291
+ q_idx_blk = nn.functional.pad(q_idx, (0, pad_q)).view(-1, self.block_size)
292
+ attn_idx_flt = rearrange(q_idx_blk[attn_idx], 'm n j -> m (n j)') # (seq_len, num_kv_blocks * group_len)
293
+ cls_idx = torch.arange(self.num_cls_tokens, device=q_idx.device)
294
+ cls_idx = repeat(cls_idx, 'n -> m n', m=len(attn_idx_flt))
295
+ attn_idx_flt = torch.cat((cls_idx, attn_idx_flt), dim=1)
296
+ attn_idx_flt = repeat(attn_idx_flt, 'm n -> (m i) n', i=group_len)
297
+ if pad_q > 0:
298
+ attn_idx_flt = attn_idx_flt[:-pad_q]
299
+ bias_flt = self.relative_position_bias(from_idx=q_idx, to_idx=attn_idx_flt)
300
+ if pad_q > 0:
301
+ bias_flt = nn.functional.pad(bias_flt, (0, 0, 0, pad_q))
302
+ return rearrange(bias_flt, 'h (m i) n -> h m i n', i=group_len) # num_heads, seq_len, group_len, (num_kv_blocks * group_len + num_cls_tokens)
303
+
304
+ def sparse_attention(self, q, k, v, head_mask=None, relative_position_bias=None, q_idx=None, mimic_full=False):
305
+ assert self.local_attn == 0 or self.local_attn % 2 == 1, "Even local window size not supported"
306
+ assert k.shape[2] == v.shape[2], "Key and value shapes mismatch"
307
+
308
+
309
+ if not mimic_full:
310
+ cls_k, k = k[..., :self.num_cls_tokens, :], k[..., self.num_cls_tokens:, :] # cls_k: (bsz, num_heads, num_cls_tokens, dims)
311
+ cls_v, v = v[..., :self.num_cls_tokens, :], v[..., self.num_cls_tokens:, :]
312
+
313
+ # pad token sequence to multiples of block_size
314
+ if mimic_full:
315
+ bsz, num_heads, seq_len, dims = q.shape
316
+ else:
317
+ q, pad_q = self.blockify(q) # q: (bsz, num_heads, seq_len, group_len, dims)
318
+ k, pad_k = self.blockify(k)
319
+ v, pad_v = self.blockify(v)
320
+ bsz, num_heads, seq_len, group_len, dims = q.shape
321
+
322
+ # global attention
323
+ cls_sim = torch.einsum('b h n i d, b h j d -> b h n i j', q, cls_k) # (bsz, num_heads, seq_len, group_len, num_cls_tokens)
324
+
325
+ if mimic_full:
326
+ sim = torch.einsum('b h i d, b h j d -> b h i j', q, k)
327
+ sim = sim / math.sqrt(self.attention_head_size)
328
+ sim = sim + self.relative_position_bias(from_idx=q_idx).unsqueeze(0)
329
+
330
+ else:
331
+ # initialize empty sim matrix
332
+ sim = torch.empty((bsz, num_heads, seq_len, self.num_kv_blocks, group_len, group_len), device=q.device)
333
+ attn_idx = torch.zeros((seq_len, self.num_kv_blocks), dtype=torch.int64, device=q.device)
334
+
335
+ # local window attention
336
+ cnt = 0
337
+ if self.local_attn > 0:
338
+ num_rolls = self.local_attn // 2
339
+ for r in range(-num_rolls, num_rolls + 1):
340
+ sim[..., cnt, :, :] = torch.einsum('b h n i d, b h n j d -> b h n i j', q, k.roll(-r, dims=2))
341
+ attn_idx[:, cnt] = torch.arange(seq_len, device=q.device).roll(r)
342
+ cnt += 1
343
+
344
+ # random attention
345
+ if self.random_attn > 0:
346
+ # generate random attention pattern
347
+ rand = torch.rand((seq_len, seq_len), device=q.device)
348
+ if self.local_attn > 0:
349
+ # avoid overlap with local attention
350
+ for r in range(-num_rolls, num_rolls + 1):
351
+ tgt_idx = list(i % seq_len for i in range(r, seq_len + r))
352
+ rand[range(seq_len), tgt_idx] = 0
353
+ _, idx = rand.topk(self.random_attn, dim=-1) # seq_len, random_attn
354
+ idx, _ = torch.sort(idx, dim=1)
355
+ attn_idx[:, cnt:] = idx
356
+
357
+ idx_ = repeat(idx, 'n m -> b h n m i d', b=bsz, h=num_heads, i=group_len, d=dims)
358
+
359
+ for r in range(self.random_attn):
360
+ sim[..., cnt, :, :] = torch.einsum('b h n i d, b h n j d -> b h n i j', q, k.gather(2, idx_[..., r, :, :]))
361
+ cnt += 1
362
+
363
+ sim = rearrange(sim, 'b h m n i j -> b h m i (n j)') # (bsz, num_heads, seq_len, group_len, num_kv_blocks * group_len)
364
+ sim = torch.cat((cls_sim, sim), -1)
365
+ sim = sim / math.sqrt(self.attention_head_size)
366
+
367
+ # Add relative position bias if present.
368
+ # NOTE: we assume q and k (excluding cls) use same token indexing, for relative position embedding
369
+ if self.relative_position_bias is not None:
370
+ assert q_idx is not None, "query index required for relative position bias"
371
+ if q_idx.ndim == 2:
372
+ # different indices for each sample
373
+ bias = torch.stack([
374
+ self._sparse_attn_relative_position_bias(q_idx_, pad_q, attn_idx, group_len)
375
+ for q_idx_ in q_idx
376
+ ])
377
+ else:
378
+ bias = self._sparse_attn_relative_position_bias(q_idx, pad_q, attn_idx, group_len).unsqueeze(0)
379
+ sim = sim + bias
380
+
381
+ # Add shared relative position bias if provided.
382
+ if relative_position_bias is not None:
383
+ raise NotImplementedError
384
+ sim = sim + relative_position_bias
385
+
386
+ attn = sim.softmax(dim=-1)
387
+ attn = self.dropout(attn)
388
+ if head_mask is not None:
389
+ attn = attn * head_mask
390
+
391
+ # block attention
392
+ if mimic_full:
393
+ out = torch.einsum('b h i j, b h j d -> b h i d', attn, v)
394
+
395
+ else:
396
+ out = torch.empty((bsz, num_heads, seq_len, group_len, dims), device=q.device)
397
+ for m in range(seq_len):
398
+ v_row = torch.index_select(v, 2, attn_idx[m])
399
+ v_row = rearrange(v_row, 'b h n j d -> b h (n j) d') # (bsz, num_heads, num_kv_blocks * group_len, dims)
400
+ v_row = torch.cat((cls_v, v_row), 2)
401
+ out[..., m, :, :] = torch.einsum('b h i j, b h j d -> b h i d', attn[..., m, :, :], v_row)
402
+ out = rearrange(out, 'b h n i d -> b h (n i) d')
403
+ if pad_q > 0:
404
+ out = out[..., :-pad_q, :]
405
+
406
+ return out, attn
407
+
408
+ def forward(self, hidden_states, head_mask=None, output_attentions=False, relative_position_bias=None, token_idx=None):
409
+ # compute qkv
410
+ q = self.split_heads(self.query(hidden_states))
411
+ k = self.split_heads(self.key(hidden_states))
412
+ v = self.split_heads(self.value(hidden_states))
413
+
414
+ # combine local token_idx with cls tokens
415
+ # NOTE: assume token_idx starts from 0
416
+ cls_q_idx = torch.arange(self.num_cls_tokens, device=q.device)
417
+ if token_idx is not None:
418
+ if token_idx.ndim == 2:
419
+ cls_q_idx = repeat(cls_q_idx, 'n -> b n', b=q.shape[0])
420
+ all_token_idx = torch.cat((cls_q_idx, token_idx + self.num_cls_tokens), dim=-1)
421
+ else:
422
+ all_token_idx = None
423
+
424
+ if self.random_attn is None:
425
+ outputs, attention_probs = self.dense_attention(q, k, v, head_mask=head_mask,
426
+ relative_position_bias=relative_position_bias,
427
+ q_idx=all_token_idx,
428
+ k_idx=all_token_idx)
429
+ cls_attention_probs = attention_probs[..., :self.num_cls_tokens, :]
430
+
431
+ else:
432
+ cls_q, q = q[..., :self.num_cls_tokens, :], q[..., self.num_cls_tokens:, :]
433
+
434
+ # dense global attention (num_cls_tokens, seq_len)
435
+ cls_outputs, cls_attention_probs = self.dense_attention(cls_q, k, v, head_mask=head_mask,
436
+ relative_position_bias=relative_position_bias,
437
+ q_idx=cls_q_idx,
438
+ k_idx=all_token_idx)
439
+
440
+ # sparse local attention (local_seq_len, seq_len)
441
+ if token_idx is None:
442
+ token_idx = torch.arange(q.shape[-2], device=q.device)
443
+ outputs, attention_probs = self.sparse_attention(q, k, v, head_mask=head_mask,
444
+ relative_position_bias=relative_position_bias,
445
+ q_idx=token_idx + self.num_cls_tokens)
446
+
447
+ outputs = torch.cat((cls_outputs, outputs), dim=2)
448
+
449
+ outputs = self.join_heads(outputs)
450
+
451
+ outputs = (outputs, cls_attention_probs) if output_attentions else (outputs,)
452
+
453
+ return outputs
454
+
455
+
456
+ class BeitSelfOutput(nn.Module):
457
+ """
458
+ The residual connection is defined in BeitLayer instead of here (as is the case with other models), due to the
459
+ layernorm applied before each block.
460
+ """
461
+
462
+ def __init__(self, config):
463
+ super().__init__()
464
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
465
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
466
+
467
+ def forward(self, hidden_states, input_tensor, gamma=None):
468
+ hidden_states = self.dense(hidden_states)
469
+ hidden_states = self.dropout(hidden_states)
470
+
471
+ return hidden_states
472
+
473
+
474
+ class BeitAttention(nn.Module):
475
+ def __init__(self, config, window_size=None):
476
+ super().__init__()
477
+ self.attention = BeitSelfAttention(config, window_size=window_size)
478
+ self.output = BeitSelfOutput(config)
479
+ self.pruned_heads = set()
480
+
481
+ def prune_heads(self, heads):
482
+ if len(heads) == 0:
483
+ return
484
+ heads, index = find_pruneable_heads_and_indices(
485
+ heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads
486
+ )
487
+
488
+ # Prune linear layers
489
+ self.attention.query = prune_linear_layer(self.attention.query, index)
490
+ self.attention.key = prune_linear_layer(self.attention.key, index)
491
+ self.attention.value = prune_linear_layer(self.attention.value, index)
492
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
493
+
494
+ # Update hyper params and store pruned heads
495
+ self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
496
+ self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
497
+ self.pruned_heads = self.pruned_heads.union(heads)
498
+
499
+ def forward(self, hidden_states, head_mask=None, output_attentions=False, relative_position_bias=None, token_idx=None):
500
+ self_outputs = self.attention(hidden_states, head_mask, output_attentions, relative_position_bias, token_idx)
501
+
502
+ attention_output = self.output(self_outputs[0], hidden_states)
503
+
504
+ outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
505
+ return outputs
506
+
507
+
508
+ class BeitIntermediate(nn.Module):
509
+ def __init__(self, config):
510
+ super().__init__()
511
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
512
+ if isinstance(config.hidden_act, str):
513
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
514
+ else:
515
+ self.intermediate_act_fn = config.hidden_act
516
+
517
+ def forward(self, hidden_states):
518
+ hidden_states = self.dense(hidden_states)
519
+ hidden_states = self.intermediate_act_fn(hidden_states)
520
+
521
+ return hidden_states
522
+
523
+
524
+ class BeitOutput(nn.Module):
525
+ def __init__(self, config):
526
+ super().__init__()
527
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
528
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
529
+
530
+ def forward(self, hidden_states):
531
+ hidden_states = self.dense(hidden_states)
532
+ hidden_states = self.dropout(hidden_states)
533
+
534
+ return hidden_states
535
+
536
+
537
+ class BeitLayer(nn.Module):
538
+ """This corresponds to the Block class in the timm implementation."""
539
+
540
+ def __init__(self, config, window_size=None, drop_path_rate=0.0,
541
+ token_keep_rate=1.0):
542
+ super().__init__()
543
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
544
+ self.seq_len_dim = 1
545
+ self.attention = BeitAttention(config, window_size=window_size)
546
+ self.intermediate = BeitIntermediate(config)
547
+ self.output = BeitOutput(config)
548
+ self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
549
+ self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
550
+ self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
551
+
552
+ # sparse params
553
+ self.token_keep_rate = token_keep_rate
554
+ self.token_keep_strategy = config.token_keep_strategy
555
+ self.num_cls_tokens = config.num_cls_tokens
556
+
557
+ init_values = config.layer_scale_init_value
558
+ if init_values > 0:
559
+ self.lambda_1 = nn.Parameter(init_values * torch.ones((config.hidden_size)), requires_grad=True)
560
+ self.lambda_2 = nn.Parameter(init_values * torch.ones((config.hidden_size)), requires_grad=True)
561
+ else:
562
+ self.lambda_1, self.lambda_2 = None, None
563
+
564
+ def sparsify(self, x, attn):
565
+ x_cls, x_ = x[:, :self.num_cls_tokens], x[:, self.num_cls_tokens:]
566
+ assert 0 < self.token_keep_rate <= 1, "Expected keep rate in range (0, 1]"
567
+ left_tokens = math.ceil(self.token_keep_rate * x_.size(1))
568
+
569
+ if self.token_keep_strategy == 'cls_attn':
570
+ if len(attn.shape) == 4:
571
+ attn = attn.mean(1) # pool over attention heads
572
+ cls_attn = attn[:, 0, self.num_cls_tokens:]
573
+ _, idx = torch.topk(cls_attn, left_tokens, dim=1) # [B, left_tokens]
574
+
575
+ elif self.token_keep_strategy == 'random':
576
+ rand = torch.rand(x_.shape[:2], device=x_.device)
577
+ _, idx = torch.topk(rand, left_tokens, dim=1) # [B, left_tokens]
578
+
579
+ else:
580
+ raise NotImplementedError(f"Sparse strategy {self.token_keep_strategy} is not implemented")
581
+
582
+ idx, _ = torch.sort(idx, dim=1)
583
+ index = idx.unsqueeze(-1).expand(-1, -1, x_.size(-1)) # [B, left_tokens, C]
584
+ outputs = torch.cat((x_cls, x_.gather(1, index)), dim=1).contiguous()
585
+ return outputs, idx
586
+
587
+ def forward(self, hidden_states, head_mask=None, output_attentions=False, relative_position_bias=None, token_idx=None):
588
+ self_attention_outputs = self.attention(
589
+ self.layernorm_before(hidden_states), # in BEiT, layernorm is applied before self-attention
590
+ head_mask,
591
+ output_attentions=(output_attentions or self.token_keep_rate < 1),
592
+ relative_position_bias=relative_position_bias,
593
+ token_idx=token_idx
594
+ )
595
+ attention_output = self_attention_outputs[0]
596
+ outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
597
+
598
+ # apply lambda_1 if present
599
+ if self.lambda_1 is not None:
600
+ attention_output = self.lambda_1 * attention_output
601
+
602
+ # first residual connection
603
+ hidden_states = self.drop_path(attention_output) + hidden_states
604
+
605
+ # in BEiT, layernorm is also applied after self-attention
606
+ layer_output = self.layernorm_after(hidden_states)
607
+
608
+ layer_output = self.intermediate(layer_output)
609
+ layer_output = self.output(layer_output)
610
+
611
+ if self.lambda_2 is not None:
612
+ layer_output = self.lambda_2 * layer_output
613
+
614
+ # second residual connection
615
+ layer_output = self.drop_path(layer_output) + hidden_states
616
+
617
+ # node sparsification
618
+ if self.token_keep_rate < 1:
619
+ layer_output, token_keep_idx = self.sparsify(layer_output, outputs[0])
620
+ if token_idx is not None:
621
+ if token_idx.ndim == 1:
622
+ token_idx = repeat(token_idx, 'n -> b n', b=len(token_keep_idx))
623
+ token_keep_idx = token_idx.gather(1, token_keep_idx)
624
+ outputs = outputs + (token_keep_idx,)
625
+
626
+ outputs = (layer_output,) + outputs
627
+
628
+ return outputs
629
+
630
+
631
+ class BeitRelativePositionBias(nn.Module):
632
+ def __init__(self, config, window_size):
633
+ super().__init__()
634
+ self.window_size = window_size
635
+ self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
636
+ self.relative_position_bias_table = nn.Parameter(
637
+ torch.zeros(self.num_relative_distance, config.num_attention_heads)
638
+ ) # 2*Wh-1 * 2*Ww-1, nH
639
+ # cls to token & token 2 cls & cls to cls
640
+
641
+ # get pair-wise relative position index for each token inside the window
642
+ coords_h = torch.arange(window_size[0])
643
+ coords_w = torch.arange(window_size[1])
644
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
645
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
646
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
647
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
648
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
649
+ relative_coords[:, :, 1] += window_size[1] - 1
650
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
651
+ relative_position_index = torch.zeros(
652
+ size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype
653
+ )
654
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
655
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
656
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
657
+ relative_position_index[0, 0] = self.num_relative_distance - 1
658
+
659
+ self.register_buffer("relative_position_index", relative_position_index, persistent=False)
660
+
661
+ def forward(self):
662
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
663
+ self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1
664
+ ) # Wh*Ww,Wh*Ww,nH
665
+
666
+ return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
667
+
668
+
669
+ class BeitRelativePositionBias3D(nn.Module):
670
+ """
671
+ 3D relative position bias
672
+ """
673
+ def __init__(self, config, window_size, num_cls_tokens=1):
674
+ super().__init__()
675
+ self.window_size = window_size
676
+ self.num_cls_tokens = num_cls_tokens
677
+
678
+ relative_size = [w * 2 - 1 for w in window_size]
679
+ self.num_relative_distance = np.prod(relative_size) + 2 * num_cls_tokens + num_cls_tokens ** 2
680
+
681
+ self.relative_position_bias_table = nn.Parameter(
682
+ torch.zeros(self.num_relative_distance, config.num_attention_heads)
683
+ )
684
+
685
+ # get pair-wise relative position index for each token inside the window
686
+ coords_range = [torch.arange(w) for w in window_size]
687
+ coords_flatten = torch.stack(torch.meshgrid(coords_range)).flatten(1)
688
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
689
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous()
690
+
691
+ for i, w in enumerate(window_size):
692
+ relative_coords[:, :, i] += w - 1 # shift to start from 0
693
+
694
+ for i, r in enumerate(relative_size[1:]):
695
+ relative_coords[:, :, :i + 1] *= r
696
+
697
+ self.seq_len = np.prod(window_size) + num_cls_tokens
698
+ relative_position_index = torch.zeros((self.seq_len, self.seq_len), dtype=relative_coords.dtype)
699
+ relative_position_index[num_cls_tokens:, num_cls_tokens:] = relative_coords.sum(-1)
700
+
701
+ start = np.prod(relative_size)
702
+ cls2loc = torch.arange(num_cls_tokens).unsqueeze(1) + start
703
+ relative_position_index[:num_cls_tokens, num_cls_tokens:] = cls2loc
704
+ start += num_cls_tokens
705
+
706
+ loc2cls = torch.arange(num_cls_tokens).unsqueeze(0) + start
707
+ relative_position_index[num_cls_tokens:, :num_cls_tokens] = loc2cls
708
+ start += num_cls_tokens
709
+
710
+ cls2cls = torch.arange(num_cls_tokens ** 2).view(num_cls_tokens, num_cls_tokens) + start
711
+ relative_position_index[:num_cls_tokens, :num_cls_tokens] = cls2cls
712
+
713
+ self.register_buffer("relative_position_index", relative_position_index)
714
+
715
+ def forward(self, from_idx=None, to_idx=None):
716
+ """
717
+ from_idx: indices of query tokens (1-dim)
718
+ to_idx: indices of key/value tokens (1-dim, or 2-dim w/ one row per query)
719
+ """
720
+ attn_idx = self.relative_position_index
721
+
722
+ # query indices
723
+ if from_idx is not None:
724
+ attn_idx = attn_idx[from_idx]
725
+
726
+ # key indices
727
+ if to_idx is not None:
728
+ assert to_idx.ndim in (1, 2), "to_idx must be 1- or 2-dimensional tensors"
729
+ if to_idx.ndim == 1:
730
+ attn_idx = attn_idx[:, to_idx]
731
+ else:
732
+ attn_idx = attn_idx.gather(1, to_idx)
733
+
734
+ rows, cols = attn_idx.shape
735
+ relative_position_bias = self.relative_position_bias_table[attn_idx.flatten()]
736
+ relative_position_bias = rearrange(relative_position_bias, '(i j) h -> h i j', i=rows, j=cols)
737
+ return relative_position_bias.contiguous()
738
+
739
+
740
+ class BeitEncoder(nn.Module):
741
+ def __init__(self, config, window_size=None):
742
+ super().__init__()
743
+ self.config = config
744
+ if config.use_shared_relative_position_bias:
745
+ self.relative_position_bias = BeitRelativePositionBias3D(config, window_size=window_size)
746
+ else:
747
+ self.relative_position_bias = None
748
+
749
+ self._register_token_order(window_size)
750
+
751
+ # stochastic depth decay rule
752
+ dpr = [x.item() for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)]
753
+
754
+ # node sparsification
755
+ token_keep_rate = [1] * config.num_hidden_layers
756
+ for loc in config.token_drop_loc:
757
+ token_keep_rate[loc] = config.token_keep_rate
758
+
759
+ self.layer = nn.ModuleList(
760
+ [
761
+ BeitLayer(
762
+ config,
763
+ window_size=window_size if config.use_relative_position_bias else None,
764
+ drop_path_rate=dpr[i], token_keep_rate=token_keep_rate[i]
765
+ )
766
+ for i in range(config.num_hidden_layers)
767
+ ]
768
+ )
769
+
770
+ self.gradient_checkpointing = False
771
+
772
+ def _register_token_order(self, shape):
773
+ if self.config.token_3d_order == 'none':
774
+ order = None
775
+ elif self.config.token_3d_order == 'zcurve':
776
+ nbits = max(shape).bit_length()
777
+ coords = list(np.ndindex(*shape))
778
+ order = zCurve.par_interlace(coords, len(shape), nbits)
779
+ order = torch.tensor(np.argsort(order))
780
+ elif self.config.token_3d_order == 'hilbert':
781
+ nbits = max(shape).bit_length()
782
+ coords = list(np.ndindex(*shape))
783
+ order = hilbert.encode(np.stack(coords), len(shape), nbits)
784
+ order = torch.tensor(np.argsort(order))
785
+ else:
786
+ raise NotImplementedError(f"Token ordering {self.config.token_3d_order} not supported")
787
+
788
+ if order is not None:
789
+ self.register_buffer('token_order', order, persistent=False)
790
+ else:
791
+ self.token_order = None
792
+
793
+ def forward(
794
+ self,
795
+ hidden_states,
796
+ head_mask=None,
797
+ output_attentions=False,
798
+ output_hidden_states=False,
799
+ output_token_idx=False,
800
+ return_dict=True,
801
+ ):
802
+ all_hidden_states = () if output_hidden_states else None
803
+ all_self_attentions = () if output_attentions else None
804
+ all_token_idx = () if output_token_idx else None
805
+
806
+ token_idx = self.token_order
807
+ if token_idx is not None:
808
+ cls_states, local_states = hidden_states[:, :self.config.num_cls_tokens], hidden_states[:, self.config.num_cls_tokens:]
809
+ local_states = torch.index_select(local_states, dim=1, index=token_idx)
810
+ hidden_states = torch.cat((cls_states, local_states), 1)
811
+
812
+ for i, layer_module in enumerate(self.layer):
813
+ if output_hidden_states:
814
+ all_hidden_states = all_hidden_states + (hidden_states,)
815
+
816
+ layer_head_mask = head_mask[i] if head_mask is not None else None
817
+
818
+ if self.gradient_checkpointing and self.training:
819
+
820
+ def create_custom_forward(module):
821
+ def custom_forward(*inputs):
822
+ return module(*inputs, output_attentions)
823
+
824
+ return custom_forward
825
+
826
+ layer_outputs = torch.utils.checkpoint.checkpoint(
827
+ create_custom_forward(layer_module),
828
+ hidden_states,
829
+ layer_head_mask,
830
+ )
831
+ else:
832
+ relative_position_bias = (
833
+ self.relative_position_bias() if self.relative_position_bias is not None else None
834
+ )
835
+ layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions, relative_position_bias, token_idx)
836
+
837
+ hidden_states = layer_outputs[0]
838
+
839
+ if layer_module.token_keep_rate < 1:
840
+ token_idx = layer_outputs[-1]
841
+
842
+ if output_token_idx:
843
+ all_token_idx = all_token_idx + (token_idx,)
844
+
845
+ if output_attentions:
846
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
847
+
848
+ if output_hidden_states:
849
+ all_hidden_states = all_hidden_states + (hidden_states,)
850
+
851
+ if not return_dict:
852
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
853
+ return BeitModelOutput(
854
+ last_hidden_state=hidden_states,
855
+ hidden_states=all_hidden_states,
856
+ attentions=all_self_attentions,
857
+ token_idx=all_token_idx
858
+ )
859
+
860
+
861
+ class BeitPreTrainedModel(PreTrainedModel):
862
+ """
863
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
864
+ models.
865
+ """
866
+
867
+ config_class = BeitConfig
868
+ base_model_prefix = "beit"
869
+ supports_gradient_checkpointing = True
870
+
871
+ def _init_weights(self, module):
872
+ """Initialize the weights"""
873
+ if isinstance(module, (nn.Linear, nn.Conv2d, nn.ConvTranspose2d)):
874
+ # Slightly different from the TF version which uses truncated_normal for initialization
875
+ # cf https://github.com/pytorch/pytorch/pull/5617
876
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
877
+ if module.bias is not None:
878
+ module.bias.data.zero_()
879
+ elif isinstance(module, nn.Embedding):
880
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
881
+ if module.padding_idx is not None:
882
+ module.weight.data[module.padding_idx].zero_()
883
+ elif isinstance(module, nn.LayerNorm):
884
+ module.bias.data.zero_()
885
+ module.weight.data.fill_(1.0)
886
+
887
+ def _set_gradient_checkpointing(self, module, value=False):
888
+ if isinstance(module, BeitEncoder):
889
+ module.gradient_checkpointing = value
890
+
891
+
892
+ BEIT_START_DOCSTRING = r"""
893
+ This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ subclass. Use
894
+ it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
895
+ behavior.
896
+
897
+ Parameters:
898
+ config (:class:`~transformers.BeitConfig`): Model configuration class with all the parameters of the model.
899
+ Initializing with a config file does not load the weights associated with the model, only the
900
+ configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
901
+ weights.
902
+ """
903
+
904
+ BEIT_INPUTS_DOCSTRING = r"""
905
+ Args:
906
+ pixel_values (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_channels, height, width)`):
907
+ Pixel values. Pixel values can be obtained using :class:`~transformers.BeitFeatureExtractor`. See
908
+ :meth:`transformers.BeitFeatureExtractor.__call__` for details.
909
+
910
+ head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
911
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
912
+
913
+ - 1 indicates the head is **not masked**,
914
+ - 0 indicates the head is **masked**.
915
+
916
+ output_attentions (:obj:`bool`, `optional`):
917
+ Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
918
+ tensors for more detail.
919
+ output_hidden_states (:obj:`bool`, `optional`):
920
+ Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
921
+ more detail.
922
+ return_dict (:obj:`bool`, `optional`):
923
+ Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
924
+ """
925
+
926
+
927
+ @add_start_docstrings(
928
+ "The bare Beit Model transformer outputting raw hidden-states without any specific head on top.",
929
+ BEIT_START_DOCSTRING,
930
+ )
931
+ class BeitModel(BeitPreTrainedModel):
932
+ def __init__(self, config, add_pooling_layer=True, num_frames=None):
933
+ super().__init__(config)
934
+ self.config = config
935
+
936
+ self.embeddings = BeitEmbeddings(config)
937
+ self.window_size = self.embeddings.patch_embeddings.patch_shape
938
+ if num_frames is not None:
939
+ self.window_size = (num_frames,) + self.window_size
940
+ self.encoder = BeitEncoder(config, window_size=self.window_size)
941
+
942
+ self.layernorm = (
943
+ nn.Identity() if config.use_mean_pooling else nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
944
+ )
945
+ self.pooler = BeitPooler(config) if add_pooling_layer else None
946
+
947
+ # Initialize weights and apply final processing
948
+ self.post_init()
949
+
950
+ def get_input_embeddings(self):
951
+ return self.embeddings.patch_embeddings
952
+
953
+ def _prune_heads(self, heads_to_prune):
954
+ """
955
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
956
+ class PreTrainedModel
957
+ """
958
+ for layer, heads in heads_to_prune.items():
959
+ self.encoder.layer[layer].attention.prune_heads(heads)
960
+
961
+ @add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING)
962
+ @replace_return_docstrings(output_type=BeitModelOutputWithPooling, config_class=_CONFIG_FOR_DOC)
963
+ def forward(
964
+ self,
965
+ pixel_values=None,
966
+ bool_masked_pos=None,
967
+ head_mask=None,
968
+ output_attentions=None,
969
+ output_hidden_states=None,
970
+ output_token_idx=None,
971
+ return_dict=None,
972
+ ):
973
+ r"""
974
+ Returns:
975
+
976
+ Examples::
977
+
978
+ >>> from transformers import BeitFeatureExtractor, BeitModel
979
+ >>> from PIL import Image
980
+ >>> import requests
981
+
982
+ >>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
983
+ >>> image = Image.open(requests.get(url, stream=True).raw)
984
+
985
+ >>> feature_extractor = BeitFeatureExtractor.from_pretrained('microsoft/beit-base-patch16-224-pt22k-ft22k')
986
+ >>> model = BeitModel.from_pretrained('microsoft/beit-base-patch16-224-pt22k-ft22k')
987
+
988
+ >>> inputs = feature_extractor(images=image, return_tensors="pt")
989
+ >>> outputs = model(**inputs)
990
+ >>> last_hidden_states = outputs.last_hidden_state
991
+ """
992
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
993
+ output_hidden_states = (
994
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
995
+ )
996
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
997
+
998
+ if pixel_values is None:
999
+ raise ValueError("You have to specify pixel_values")
1000
+
1001
+ # Prepare head mask if needed
1002
+ # 1.0 in head_mask indicate we keep the head
1003
+ # attention_probs has shape bsz x n_heads x N x N
1004
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
1005
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
1006
+ head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
1007
+
1008
+ embedding_output = self.embeddings(pixel_values, bool_masked_pos)
1009
+
1010
+ encoder_outputs = self.encoder(
1011
+ embedding_output,
1012
+ head_mask=head_mask,
1013
+ output_attentions=output_attentions,
1014
+ output_hidden_states=output_hidden_states,
1015
+ output_token_idx=output_token_idx,
1016
+ return_dict=return_dict,
1017
+ )
1018
+ sequence_output = encoder_outputs[0]
1019
+ sequence_output = self.layernorm(sequence_output)
1020
+ pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
1021
+
1022
+ if not return_dict:
1023
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
1024
+
1025
+ return BeitModelOutputWithPooling(
1026
+ last_hidden_state=sequence_output,
1027
+ pooler_output=pooled_output,
1028
+ hidden_states=encoder_outputs.hidden_states,
1029
+ attentions=encoder_outputs.attentions,
1030
+ token_idx=encoder_outputs.token_idx,
1031
+ )
1032
+
1033
+
1034
+ class BeitPooler(nn.Module):
1035
+ def __init__(self, config):
1036
+ super().__init__()
1037
+ self.layernorm = (
1038
+ nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) if config.use_mean_pooling else None
1039
+ )
1040
+
1041
+ def forward(self, hidden_states):
1042
+ if self.layernorm is not None:
1043
+ # Mean pool the final hidden states of the patch tokens
1044
+ patch_tokens = hidden_states[:, 1:, :]
1045
+ pooled_output = self.layernorm(patch_tokens.mean(1))
1046
+ else:
1047
+ # Pool by simply taking the final hidden state of the [CLS] token
1048
+ pooled_output = hidden_states[:, 0]
1049
+
1050
+ return pooled_output
1051
+
1052
+
1053
+ @add_start_docstrings(
1054
+ "Beit Model transformer with a 'language' modeling head on top (to predict visual tokens).",
1055
+ BEIT_START_DOCSTRING,
1056
+ )
1057
+ class BeitForMaskedImageModeling(BeitPreTrainedModel):
1058
+ def __init__(self, config):
1059
+ super().__init__(config)
1060
+
1061
+ self.num_labels = config.num_labels
1062
+ self.beit = BeitModel(config, add_pooling_layer=False)
1063
+
1064
+ # Classifier head
1065
+ self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
1066
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
1067
+
1068
+ # Initialize weights and apply final processing
1069
+ self.post_init()
1070
+
1071
+ @add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING)
1072
+ @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
1073
+ def forward(
1074
+ self,
1075
+ pixel_values=None,
1076
+ bool_masked_pos=None,
1077
+ head_mask=None,
1078
+ labels=None,
1079
+ output_attentions=None,
1080
+ output_hidden_states=None,
1081
+ return_dict=None,
1082
+ ):
1083
+ r"""
1084
+ bool_masked_pos (:obj:`torch.BoolTensor` of shape :obj:`(batch_size, num_patches)`):
1085
+ Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
1086
+
1087
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1088
+ Labels for computing the image classification/regression loss. Indices should be in :obj:`[0, ...,
1089
+ config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
1090
+ If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1091
+
1092
+ Returns:
1093
+
1094
+ Examples::
1095
+
1096
+ >>> from transformers import BeitFeatureExtractor, BeitForMaskedImageModeling
1097
+ >>> from PIL import Image
1098
+ >>> import requests
1099
+
1100
+ >>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
1101
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1102
+
1103
+ >>> feature_extractor = BeitFeatureExtractor.from_pretrained('microsoft/beit-base-patch16-224-pt22k')
1104
+ >>> model = BeitForMaskedImageModeling.from_pretrained('microsoft/beit-base-patch16-224-pt22k')
1105
+
1106
+ >>> inputs = feature_extractor(images=image, return_tensors="pt")
1107
+ >>> outputs = model(**inputs)
1108
+ >>> logits = outputs.logits
1109
+ """
1110
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1111
+
1112
+ outputs = self.beit(
1113
+ pixel_values,
1114
+ bool_masked_pos=bool_masked_pos,
1115
+ head_mask=head_mask,
1116
+ output_attentions=output_attentions,
1117
+ output_hidden_states=output_hidden_states,
1118
+ return_dict=return_dict,
1119
+ )
1120
+
1121
+ sequence_output = outputs[0]
1122
+ sequence_output = self.layernorm(sequence_output)
1123
+ prediction_scores = self.lm_head(sequence_output[:, 1:])
1124
+
1125
+ masked_lm_loss = None
1126
+ if labels is not None:
1127
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
1128
+ masked_lm_loss = loss_fct(prediction_scores[bool_masked_pos], labels)
1129
+
1130
+ if not return_dict:
1131
+ output = (prediction_scores,) + outputs[2:]
1132
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1133
+
1134
+ return MaskedLMOutput(
1135
+ loss=masked_lm_loss,
1136
+ logits=prediction_scores,
1137
+ hidden_states=outputs.hidden_states,
1138
+ attentions=outputs.attentions,
1139
+ )
1140
+
1141
+
1142
+ @add_start_docstrings(
1143
+ """
1144
+ Beit Model transformer with an image classification head on top (a linear layer on top of the average of the final
1145
+ hidden states of the patch tokens) e.g. for ImageNet.
1146
+ """,
1147
+ BEIT_START_DOCSTRING,
1148
+ )
1149
+ class BeitForImageClassification(BeitPreTrainedModel):
1150
+ def __init__(self, config):
1151
+ super().__init__(config)
1152
+
1153
+ self.num_labels = config.num_labels
1154
+ self.beit = BeitModel(config, add_pooling_layer=True)
1155
+
1156
+ # Classifier head
1157
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
1158
+
1159
+ # Initialize weights and apply final processing
1160
+ self.post_init()
1161
+
1162
+ @add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING)
1163
+ @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
1164
+ def forward(
1165
+ self,
1166
+ pixel_values=None,
1167
+ head_mask=None,
1168
+ labels=None,
1169
+ output_attentions=None,
1170
+ output_hidden_states=None,
1171
+ return_dict=None,
1172
+ ):
1173
+ r"""
1174
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1175
+ Labels for computing the image classification/regression loss. Indices should be in :obj:`[0, ...,
1176
+ config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
1177
+ If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1178
+
1179
+ Returns:
1180
+
1181
+ Examples::
1182
+
1183
+ >>> from transformers import BeitFeatureExtractor, BeitForImageClassification
1184
+ >>> from PIL import Image
1185
+ >>> import requests
1186
+
1187
+ >>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
1188
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1189
+
1190
+ >>> feature_extractor = BeitFeatureExtractor.from_pretrained('microsoft/beit-base-patch16-224')
1191
+ >>> model = BeitForImageClassification.from_pretrained('microsoft/beit-base-patch16-224')
1192
+
1193
+ >>> inputs = feature_extractor(images=image, return_tensors="pt")
1194
+ >>> outputs = model(**inputs)
1195
+ >>> logits = outputs.logits
1196
+ >>> # model predicts one of the 1000 ImageNet classes
1197
+ >>> predicted_class_idx = logits.argmax(-1).item()
1198
+ >>> print("Predicted class:", model.config.id2label[predicted_class_idx])
1199
+ """
1200
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1201
+
1202
+ outputs = self.beit(
1203
+ pixel_values,
1204
+ head_mask=head_mask,
1205
+ output_attentions=output_attentions,
1206
+ output_hidden_states=output_hidden_states,
1207
+ return_dict=return_dict,
1208
+ )
1209
+
1210
+ pooled_output = outputs.pooler_output if return_dict else outputs[1]
1211
+
1212
+ logits = self.classifier(pooled_output)
1213
+
1214
+ loss = None
1215
+ if labels is not None:
1216
+ if self.num_labels == 1:
1217
+ # We are doing regression
1218
+ loss_fct = MSELoss()
1219
+ loss = loss_fct(logits.view(-1), labels.view(-1))
1220
+ else:
1221
+ loss_fct = CrossEntropyLoss()
1222
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1223
+
1224
+ if not return_dict:
1225
+ output = (logits,) + outputs[2:]
1226
+ return ((loss,) + output) if loss is not None else output
1227
+
1228
+ return SequenceClassifierOutput(
1229
+ loss=loss,
1230
+ logits=logits,
1231
+ hidden_states=outputs.hidden_states,
1232
+ attentions=outputs.attentions,
1233
+ )
1234
+
1235
+
1236
+ class BeitConvModule(nn.Module):
1237
+ """
1238
+ A convolutional block that bundles conv/norm/activation layers. This block simplifies the usage of convolution
1239
+ layers, which are commonly used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
1240
+
1241
+ Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
1242
+ """
1243
+
1244
+ def __init__(self, in_channels, out_channels, kernel_size, padding=0, bias=False, dilation=1):
1245
+ super().__init__()
1246
+ self.conv = nn.Conv2d(
1247
+ in_channels=in_channels,
1248
+ out_channels=out_channels,
1249
+ kernel_size=kernel_size,
1250
+ padding=padding,
1251
+ bias=bias,
1252
+ dilation=dilation,
1253
+ )
1254
+ self.bn = nn.BatchNorm2d(out_channels)
1255
+ self.activation = nn.ReLU()
1256
+
1257
+ def forward(self, input):
1258
+ output = self.conv(input)
1259
+ output = self.bn(output)
1260
+ output = self.activation(output)
1261
+
1262
+ return output
1263
+
1264
+
1265
+ class BeitPyramidPoolingModule(nn.ModuleList):
1266
+ """
1267
+ Pyramid Pooling Module (PPM) used in PSPNet.
1268
+
1269
+ Args:
1270
+ pool_scales (tuple[int]): Pooling scales used in Pooling Pyramid
1271
+ Module.
1272
+ in_channels (int): Input channels.
1273
+ channels (int): Channels after modules, before conv_seg.
1274
+ align_corners (bool): align_corners argument of F.interpolate.
1275
+
1276
+ Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
1277
+ """
1278
+
1279
+ def __init__(self, pool_scales, in_channels, channels, align_corners):
1280
+ super().__init__()
1281
+ self.pool_scales = pool_scales
1282
+ self.align_corners = align_corners
1283
+ self.in_channels = in_channels
1284
+ self.channels = channels
1285
+ for pool_scale in pool_scales:
1286
+ self.append(
1287
+ nn.Sequential(
1288
+ nn.AdaptiveAvgPool2d(pool_scale),
1289
+ BeitConvModule(self.in_channels, self.channels, kernel_size=1),
1290
+ )
1291
+ )
1292
+
1293
+ def forward(self, x):
1294
+ ppm_outs = []
1295
+ for ppm in self:
1296
+ ppm_out = ppm(x)
1297
+ upsampled_ppm_out = nn.functional.interpolate(
1298
+ ppm_out, size=x.size()[2:], mode="bilinear", align_corners=self.align_corners
1299
+ )
1300
+ ppm_outs.append(upsampled_ppm_out)
1301
+ return ppm_outs
1302
+
1303
+
1304
+ class BeitUperHead(nn.Module):
1305
+ """
1306
+ Unified Perceptual Parsing for Scene Understanding. This head is the implementation of `UPerNet
1307
+ <https://arxiv.org/abs/1807.10221>`_.
1308
+
1309
+ Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
1310
+ """
1311
+
1312
+ def __init__(self, config):
1313
+ super().__init__()
1314
+
1315
+ self.pool_scales = config.pool_scales # e.g. (1, 2, 3, 6)
1316
+ self.in_channels = [config.hidden_size] * 4 # e.g. [768, 768, 768, 768]
1317
+ self.channels = config.hidden_size
1318
+ self.align_corners = False
1319
+ self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1)
1320
+
1321
+ # PSP Module
1322
+ self.psp_modules = BeitPyramidPoolingModule(
1323
+ self.pool_scales,
1324
+ self.in_channels[-1],
1325
+ self.channels,
1326
+ align_corners=self.align_corners,
1327
+ )
1328
+ self.bottleneck = BeitConvModule(
1329
+ self.in_channels[-1] + len(self.pool_scales) * self.channels,
1330
+ self.channels,
1331
+ kernel_size=3,
1332
+ padding=1,
1333
+ )
1334
+ # FPN Module
1335
+ self.lateral_convs = nn.ModuleList()
1336
+ self.fpn_convs = nn.ModuleList()
1337
+ for in_channels in self.in_channels[:-1]: # skip the top layer
1338
+ l_conv = BeitConvModule(in_channels, self.channels, kernel_size=1)
1339
+ fpn_conv = BeitConvModule(self.channels, self.channels, kernel_size=3, padding=1)
1340
+ self.lateral_convs.append(l_conv)
1341
+ self.fpn_convs.append(fpn_conv)
1342
+
1343
+ self.fpn_bottleneck = BeitConvModule(
1344
+ len(self.in_channels) * self.channels,
1345
+ self.channels,
1346
+ kernel_size=3,
1347
+ padding=1,
1348
+ )
1349
+
1350
+ def psp_forward(self, inputs):
1351
+ x = inputs[-1]
1352
+ psp_outs = [x]
1353
+ psp_outs.extend(self.psp_modules(x))
1354
+ psp_outs = torch.cat(psp_outs, dim=1)
1355
+ output = self.bottleneck(psp_outs)
1356
+
1357
+ return output
1358
+
1359
+ def forward(self, encoder_hidden_states):
1360
+ # build laterals
1361
+ laterals = [lateral_conv(encoder_hidden_states[i]) for i, lateral_conv in enumerate(self.lateral_convs)]
1362
+
1363
+ laterals.append(self.psp_forward(encoder_hidden_states))
1364
+
1365
+ # build top-down path
1366
+ used_backbone_levels = len(laterals)
1367
+ for i in range(used_backbone_levels - 1, 0, -1):
1368
+ prev_shape = laterals[i - 1].shape[2:]
1369
+ laterals[i - 1] = laterals[i - 1] + nn.functional.interpolate(
1370
+ laterals[i], size=prev_shape, mode="bilinear", align_corners=self.align_corners
1371
+ )
1372
+
1373
+ # build outputs
1374
+ fpn_outs = [self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels - 1)]
1375
+ # append psp feature
1376
+ fpn_outs.append(laterals[-1])
1377
+
1378
+ for i in range(used_backbone_levels - 1, 0, -1):
1379
+ fpn_outs[i] = nn.functional.interpolate(
1380
+ fpn_outs[i], size=fpn_outs[0].shape[2:], mode="bilinear", align_corners=self.align_corners
1381
+ )
1382
+ fpn_outs = torch.cat(fpn_outs, dim=1)
1383
+ output = self.fpn_bottleneck(fpn_outs)
1384
+ output = self.classifier(output)
1385
+
1386
+ return output
1387
+
1388
+
1389
+ class BeitFCNHead(nn.Module):
1390
+ """
1391
+ Fully Convolution Networks for Semantic Segmentation. This head is implemented of `FCNNet
1392
+ <https://arxiv.org/abs/1411.4038>`_.
1393
+
1394
+ Args:
1395
+ config (BeitConfig): Configuration.
1396
+ in_channels
1397
+ kernel_size (int): The kernel size for convs in the head. Default: 3.
1398
+ dilation (int): The dilation rate for convs in the head. Default: 1.
1399
+
1400
+
1401
+ Based on OpenMMLab's implementation, found in https://github.com/open-mmlab/mmsegmentation.
1402
+ """
1403
+
1404
+ def __init__(self, config, in_index=2, kernel_size=3, dilation=1):
1405
+ super().__init__()
1406
+ self.in_channels = config.hidden_size
1407
+ self.channels = config.auxiliary_channels
1408
+ self.num_convs = config.auxiliary_num_convs
1409
+ self.concat_input = config.auxiliary_concat_input
1410
+ self.in_index = in_index
1411
+
1412
+ conv_padding = (kernel_size // 2) * dilation
1413
+ convs = []
1414
+ convs.append(
1415
+ BeitConvModule(
1416
+ self.in_channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation
1417
+ )
1418
+ )
1419
+ for i in range(self.num_convs - 1):
1420
+ convs.append(
1421
+ BeitConvModule(
1422
+ self.channels, self.channels, kernel_size=kernel_size, padding=conv_padding, dilation=dilation
1423
+ )
1424
+ )
1425
+ if self.num_convs == 0:
1426
+ self.convs = nn.Identity()
1427
+ else:
1428
+ self.convs = nn.Sequential(*convs)
1429
+ if self.concat_input:
1430
+ self.conv_cat = BeitConvModule(
1431
+ self.in_channels + self.channels, self.channels, kernel_size=kernel_size, padding=kernel_size // 2
1432
+ )
1433
+
1434
+ self.classifier = nn.Conv2d(self.channels, config.num_labels, kernel_size=1)
1435
+
1436
+ def forward(self, encoder_hidden_states):
1437
+ # just take the relevant feature maps
1438
+ hidden_states = encoder_hidden_states[self.in_index]
1439
+ output = self.convs(hidden_states)
1440
+ if self.concat_input:
1441
+ output = self.conv_cat(torch.cat([hidden_states, output], dim=1))
1442
+ output = self.classifier(output)
1443
+ return output
1444
+
1445
+
1446
+ @add_start_docstrings(
1447
+ """
1448
+ Beit Model transformer with a semantic segmentation head on top e.g. for ADE20k, CityScapes.
1449
+ """,
1450
+ BEIT_START_DOCSTRING,
1451
+ )
1452
+ class BeitForSemanticSegmentation(BeitPreTrainedModel):
1453
+ def __init__(self, config):
1454
+ super().__init__(config)
1455
+
1456
+ self.num_labels = config.num_labels
1457
+ self.beit = BeitModel(config, add_pooling_layer=False)
1458
+
1459
+ # FPNs
1460
+ self.fpn1 = nn.Sequential(
1461
+ nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
1462
+ nn.BatchNorm2d(config.hidden_size),
1463
+ nn.GELU(),
1464
+ nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
1465
+ )
1466
+ self.fpn2 = nn.Sequential(
1467
+ nn.ConvTranspose2d(config.hidden_size, config.hidden_size, kernel_size=2, stride=2),
1468
+ )
1469
+ self.fpn3 = nn.Identity()
1470
+ self.fpn4 = nn.MaxPool2d(kernel_size=2, stride=2)
1471
+
1472
+ # Semantic segmentation head(s)
1473
+ self.decode_head = BeitUperHead(config)
1474
+ self.auxiliary_head = BeitFCNHead(config) if config.use_auxiliary_head else None
1475
+
1476
+ # Initialize weights and apply final processing
1477
+ self.post_init()
1478
+
1479
+ def compute_loss(self, logits, auxiliary_logits, labels):
1480
+ # upsample logits to the images' original size
1481
+ upsampled_logits = nn.functional.interpolate(
1482
+ logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
1483
+ )
1484
+ if auxiliary_logits is not None:
1485
+ upsampled_auxiliary_logits = nn.functional.interpolate(
1486
+ auxiliary_logits, size=labels.shape[-2:], mode="bilinear", align_corners=False
1487
+ )
1488
+ # compute weighted loss
1489
+ loss_fct = CrossEntropyLoss(ignore_index=self.config.semantic_loss_ignore_index)
1490
+ main_loss = loss_fct(upsampled_logits, labels)
1491
+ auxiliary_loss = loss_fct(upsampled_auxiliary_logits, labels)
1492
+ loss = main_loss + self.config.auxiliary_loss_weight * auxiliary_loss
1493
+
1494
+ return loss
1495
+
1496
+ @add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING)
1497
+ @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
1498
+ def forward(
1499
+ self,
1500
+ pixel_values=None,
1501
+ head_mask=None,
1502
+ labels=None,
1503
+ output_attentions=None,
1504
+ output_hidden_states=None,
1505
+ return_dict=None,
1506
+ ):
1507
+ r"""
1508
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, height, width)`, `optional`):
1509
+ Ground truth semantic segmentation maps for computing the loss. Indices should be in :obj:`[0, ...,
1510
+ config.num_labels - 1]`. If :obj:`config.num_labels > 1`, a classification loss is computed
1511
+ (Cross-Entropy).
1512
+
1513
+ Returns:
1514
+
1515
+ Examples::
1516
+
1517
+ >>> from transformers import BeitFeatureExtractor, BeitForSemanticSegmentation
1518
+ >>> from PIL import Image
1519
+ >>> import requests
1520
+
1521
+ >>> url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
1522
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1523
+
1524
+ >>> feature_extractor = BeitFeatureExtractor.from_pretrained('microsoft/beit-base-finetuned-ade-640-640')
1525
+ >>> model = BeitForSemanticSegmentation.from_pretrained('microsoft/beit-base-finetuned-ade-640-640')
1526
+
1527
+ >>> inputs = feature_extractor(images=image, return_tensors="pt")
1528
+ >>> outputs = model(**inputs)
1529
+ >>> # logits are of shape (batch_size, num_labels, height/4, width/4)
1530
+ >>> logits = outputs.logits
1531
+ """
1532
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1533
+ output_hidden_states = (
1534
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1535
+ )
1536
+
1537
+ outputs = self.beit(
1538
+ pixel_values,
1539
+ head_mask=head_mask,
1540
+ output_attentions=output_attentions,
1541
+ output_hidden_states=True, # we need the intermediate hidden states
1542
+ return_dict=return_dict,
1543
+ )
1544
+
1545
+ encoder_hidden_states = outputs.hidden_states if return_dict else outputs[2]
1546
+
1547
+ # only keep certain features, and reshape
1548
+ # note that we do +1 as the encoder_hidden_states also includes the initial embeddings
1549
+ features = [feature for idx, feature in enumerate(encoder_hidden_states) if idx + 1 in self.config.out_indices]
1550
+ batch_size = pixel_values.shape[0]
1551
+ patch_resolution = self.config.image_size // self.config.patch_size
1552
+ features = [
1553
+ x[:, 1:, :].permute(0, 2, 1).reshape(batch_size, -1, patch_resolution, patch_resolution) for x in features
1554
+ ]
1555
+
1556
+ # apply FPNs
1557
+ ops = [self.fpn1, self.fpn2, self.fpn3, self.fpn4]
1558
+ for i in range(len(features)):
1559
+ features[i] = ops[i](features[i])
1560
+
1561
+ logits = self.decode_head(features)
1562
+ auxiliary_logits = None
1563
+ if self.auxiliary_head is not None:
1564
+ auxiliary_logits = self.auxiliary_head(features)
1565
+
1566
+ loss = None
1567
+ if labels is not None:
1568
+ if self.config.num_labels == 1:
1569
+ raise ValueError("The number of labels should be greater than one")
1570
+ else:
1571
+ loss = self.compute_loss(logits, auxiliary_logits, labels)
1572
+
1573
+ if not return_dict:
1574
+ if output_hidden_states:
1575
+ output = (logits,) + outputs[2:]
1576
+ else:
1577
+ output = (logits,) + outputs[3:]
1578
+ return ((loss,) + output) if loss is not None else output
1579
+
1580
+ return SequenceClassifierOutput(
1581
+ loss=loss,
1582
+ logits=logits,
1583
+ hidden_states=outputs.hidden_states if output_hidden_states else None,
1584
+ attentions=outputs.attentions,
1585
+ )
svitt/sparse_xbert.py ADDED
@@ -0,0 +1,2039 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch BERT model. """
17
+
18
+ import math
19
+ import os
20
+ import warnings
21
+ from dataclasses import dataclass
22
+ from typing import Optional, Tuple
23
+
24
+ import torch
25
+ from torch import Tensor, device, nn
26
+ import torch.utils.checkpoint
27
+ from torch import nn
28
+ from torch.nn import CrossEntropyLoss, MSELoss
29
+ import torch.nn.functional as F
30
+
31
+ from transformers.activations import ACT2FN
32
+ from transformers.file_utils import (
33
+ ModelOutput,
34
+ add_start_docstrings,
35
+ add_start_docstrings_to_model_forward,
36
+ replace_return_docstrings,
37
+ )
38
+ from transformers.modeling_outputs import (
39
+ BaseModelOutputWithPastAndCrossAttentions,
40
+ BaseModelOutputWithPoolingAndCrossAttentions,
41
+ CausalLMOutputWithCrossAttentions,
42
+ MaskedLMOutput,
43
+ MultipleChoiceModelOutput,
44
+ NextSentencePredictorOutput,
45
+ QuestionAnsweringModelOutput,
46
+ SequenceClassifierOutput,
47
+ TokenClassifierOutput,
48
+ )
49
+ from transformers.modeling_utils import (
50
+ PreTrainedModel,
51
+ apply_chunking_to_forward,
52
+ find_pruneable_heads_and_indices,
53
+ prune_linear_layer,
54
+ )
55
+ from svitt.sparse_config import BertConfig
56
+
57
+ import transformers
58
+ transformers.logging.set_verbosity_error()
59
+
60
+
61
+ _CONFIG_FOR_DOC = "BertConfig"
62
+ _TOKENIZER_FOR_DOC = "BertTokenizer"
63
+
64
+ BERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
65
+ "bert-base-uncased",
66
+ "bert-large-uncased",
67
+ "bert-base-cased",
68
+ "bert-large-cased",
69
+ "bert-base-multilingual-uncased",
70
+ "bert-base-multilingual-cased",
71
+ "bert-base-chinese",
72
+ "bert-base-german-cased",
73
+ "bert-large-uncased-whole-word-masking",
74
+ "bert-large-cased-whole-word-masking",
75
+ "bert-large-uncased-whole-word-masking-finetuned-squad",
76
+ "bert-large-cased-whole-word-masking-finetuned-squad",
77
+ "bert-base-cased-finetuned-mrpc",
78
+ "bert-base-german-dbmdz-cased",
79
+ "bert-base-german-dbmdz-uncased",
80
+ "cl-tohoku/bert-base-japanese",
81
+ "cl-tohoku/bert-base-japanese-whole-word-masking",
82
+ "cl-tohoku/bert-base-japanese-char",
83
+ "cl-tohoku/bert-base-japanese-char-whole-word-masking",
84
+ "TurkuNLP/bert-base-finnish-cased-v1",
85
+ "TurkuNLP/bert-base-finnish-uncased-v1",
86
+ "wietsedv/bert-base-dutch-cased",
87
+ # See all BERT models at https://huggingface.co/models?filter=bert
88
+ ]
89
+
90
+
91
+ @dataclass
92
+ class BertModelOutputWithPastAndCrossAttentions(BaseModelOutputWithPastAndCrossAttentions):
93
+ token_idx: Optional[Tuple[torch.LongTensor]] = None
94
+
95
+
96
+ @dataclass
97
+ class BertModelOutputWithPoolingAndCrossAttentions(BaseModelOutputWithPoolingAndCrossAttentions):
98
+ token_idx: Optional[Tuple[torch.LongTensor]] = None
99
+
100
+
101
+ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
102
+ """Load tf checkpoints in a pytorch model."""
103
+ try:
104
+ import re
105
+
106
+ import numpy as np
107
+ import tensorflow as tf
108
+ except ImportError:
109
+ print(
110
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
111
+ "https://www.tensorflow.org/install/ for installation instructions."
112
+ )
113
+ raise
114
+ tf_path = os.path.abspath(tf_checkpoint_path)
115
+ # Load weights from TF model
116
+ init_vars = tf.train.list_variables(tf_path)
117
+ names = []
118
+ arrays = []
119
+ for name, shape in init_vars:
120
+ array = tf.train.load_variable(tf_path, name)
121
+ names.append(name)
122
+ arrays.append(array)
123
+
124
+ for name, array in zip(names, arrays):
125
+ name = name.split("/")
126
+ # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
127
+ # which are not required for using pretrained model
128
+ if any(
129
+ n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer",
130
+ "AdamWeightDecayOptimizer_1", "global_step"]
131
+ for n in name
132
+ ):
133
+ continue
134
+ pointer = model
135
+ for m_name in name:
136
+ if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
137
+ scope_names = re.split(r"_(\d+)", m_name)
138
+ else:
139
+ scope_names = [m_name]
140
+ if scope_names[0] == "kernel" or scope_names[0] == "gamma":
141
+ pointer = getattr(pointer, "weight")
142
+ elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
143
+ pointer = getattr(pointer, "bias")
144
+ elif scope_names[0] == "output_weights":
145
+ pointer = getattr(pointer, "weight")
146
+ elif scope_names[0] == "squad":
147
+ pointer = getattr(pointer, "classifier")
148
+ else:
149
+ try:
150
+ pointer = getattr(pointer, scope_names[0])
151
+ except AttributeError:
152
+ continue
153
+ if len(scope_names) >= 2:
154
+ num = int(scope_names[1])
155
+ pointer = pointer[num]
156
+ if m_name[-11:] == "_embeddings":
157
+ pointer = getattr(pointer, "weight")
158
+ elif m_name == "kernel":
159
+ array = np.transpose(array)
160
+ try:
161
+ assert (
162
+ pointer.shape == array.shape
163
+ ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
164
+ except AssertionError as e:
165
+ e.args += (pointer.shape, array.shape)
166
+ raise
167
+ pointer.data = torch.from_numpy(array)
168
+ return model
169
+
170
+
171
+ class BertEmbeddings(nn.Module):
172
+ """Construct the embeddings from word, position and token_type embeddings."""
173
+
174
+ def __init__(self, config):
175
+ super().__init__()
176
+ self.word_embeddings = nn.Embedding(
177
+ config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
178
+ self.position_embeddings = nn.Embedding(
179
+ config.max_position_embeddings, config.hidden_size)
180
+ self.token_type_embeddings = nn.Embedding(
181
+ config.type_vocab_size, config.hidden_size)
182
+
183
+ # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
184
+ # any TensorFlow checkpoint file
185
+ self.LayerNorm = nn.LayerNorm(
186
+ config.hidden_size, eps=config.layer_norm_eps)
187
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
188
+
189
+ # position_ids (1, len position emb) is contiguous in memory and exported when serialized
190
+ self.register_buffer("position_ids", torch.arange(
191
+ config.max_position_embeddings).expand((1, -1)))
192
+ self.position_embedding_type = getattr(
193
+ config, "position_embedding_type", "absolute")
194
+
195
+ self.config = config
196
+
197
+ def forward(
198
+ self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
199
+ ):
200
+ if input_ids is not None:
201
+ input_shape = input_ids.size()
202
+ else:
203
+ input_shape = inputs_embeds.size()[:-1]
204
+
205
+ seq_length = input_shape[1]
206
+
207
+ if position_ids is None:
208
+ position_ids = self.position_ids[:,
209
+ past_key_values_length: seq_length + past_key_values_length]
210
+
211
+ if token_type_ids is None:
212
+ token_type_ids = torch.zeros(
213
+ input_shape, dtype=torch.long, device=self.position_ids.device)
214
+
215
+ if inputs_embeds is None:
216
+ inputs_embeds = self.word_embeddings(input_ids)
217
+
218
+ token_type_embeddings = self.token_type_embeddings(token_type_ids)
219
+
220
+ embeddings = inputs_embeds + token_type_embeddings
221
+ if self.position_embedding_type == "absolute":
222
+ position_embeddings = self.position_embeddings(position_ids)
223
+ embeddings += position_embeddings
224
+ embeddings = self.LayerNorm(embeddings)
225
+ embeddings = self.dropout(embeddings)
226
+ return embeddings
227
+
228
+
229
+ class BertSelfAttention(nn.Module):
230
+ def __init__(self, config, is_cross_attention):
231
+ super().__init__()
232
+ self.config = config
233
+ if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
234
+ raise ValueError(
235
+ "The hidden size (%d) is not a multiple of the number of attention "
236
+ "heads (%d)" % (config.hidden_size, config.num_attention_heads)
237
+ )
238
+
239
+ self.num_attention_heads = config.num_attention_heads
240
+ self.attention_head_size = int(
241
+ config.hidden_size / config.num_attention_heads)
242
+ self.all_head_size = self.num_attention_heads * self.attention_head_size
243
+
244
+ self.query = nn.Linear(config.hidden_size, self.all_head_size)
245
+ if is_cross_attention:
246
+ self.key = nn.Linear(config.encoder_width, self.all_head_size)
247
+ self.value = nn.Linear(config.encoder_width, self.all_head_size)
248
+ else:
249
+ self.key = nn.Linear(config.hidden_size, self.all_head_size)
250
+ self.value = nn.Linear(config.hidden_size, self.all_head_size)
251
+
252
+ self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
253
+ self.position_embedding_type = getattr(
254
+ config, "position_embedding_type", "absolute")
255
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
256
+ self.max_position_embeddings = config.max_position_embeddings
257
+ self.distance_embedding = nn.Embedding(
258
+ 2 * config.max_position_embeddings - 1, self.attention_head_size)
259
+ self.save_attention = False
260
+
261
+ def save_attn_gradients(self, attn_gradients):
262
+ self.attn_gradients = attn_gradients
263
+
264
+ def get_attn_gradients(self):
265
+ return self.attn_gradients
266
+
267
+ def save_attention_map(self, attention_map):
268
+ self.attention_map = attention_map
269
+
270
+ def get_attention_map(self):
271
+ return self.attention_map
272
+
273
+ def transpose_for_scores(self, x):
274
+ new_x_shape = x.size()[
275
+ :-1] + (self.num_attention_heads, self.attention_head_size)
276
+ x = x.view(*new_x_shape)
277
+ return x.permute(0, 2, 1, 3)
278
+
279
+ def forward(
280
+ self,
281
+ hidden_states,
282
+ attention_mask=None,
283
+ head_mask=None,
284
+ encoder_hidden_states=None,
285
+ encoder_attention_mask=None,
286
+ past_key_value=None,
287
+ output_attentions=False,
288
+ ):
289
+ mixed_query_layer = self.query(hidden_states)
290
+
291
+ # If this is instantiated as a cross-attention module, the keys
292
+ # and values come from an encoder; the attention mask needs to be
293
+ # such that the encoder's padding tokens are not attended to.
294
+ is_cross_attention = encoder_hidden_states is not None
295
+
296
+ if is_cross_attention:
297
+ key_layer = self.transpose_for_scores(
298
+ self.key(encoder_hidden_states))
299
+ value_layer = self.transpose_for_scores(
300
+ self.value(encoder_hidden_states))
301
+ attention_mask = encoder_attention_mask
302
+ elif past_key_value is not None:
303
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
304
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
305
+ key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
306
+ value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
307
+ else:
308
+ key_layer = self.transpose_for_scores(self.key(hidden_states))
309
+ value_layer = self.transpose_for_scores(self.value(hidden_states))
310
+
311
+ query_layer = self.transpose_for_scores(mixed_query_layer)
312
+
313
+ past_key_value = (key_layer, value_layer)
314
+
315
+ # Take the dot product between "query" and "key" to get the raw attention scores.
316
+ attention_scores = torch.matmul(
317
+ query_layer, key_layer.transpose(-1, -2))
318
+
319
+ if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
320
+ seq_length = hidden_states.size()[1]
321
+ position_ids_l = torch.arange(
322
+ seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
323
+ position_ids_r = torch.arange(
324
+ seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
325
+ distance = position_ids_l - position_ids_r
326
+ positional_embedding = self.distance_embedding(
327
+ distance + self.max_position_embeddings - 1)
328
+ positional_embedding = positional_embedding.to(
329
+ dtype=query_layer.dtype) # fp16 compatibility
330
+
331
+ if self.position_embedding_type == "relative_key":
332
+ relative_position_scores = torch.einsum(
333
+ "bhld,lrd->bhlr", query_layer, positional_embedding)
334
+ attention_scores = attention_scores + relative_position_scores
335
+ elif self.position_embedding_type == "relative_key_query":
336
+ relative_position_scores_query = torch.einsum(
337
+ "bhld,lrd->bhlr", query_layer, positional_embedding)
338
+ relative_position_scores_key = torch.einsum(
339
+ "bhrd,lrd->bhlr", key_layer, positional_embedding)
340
+ attention_scores = attention_scores + \
341
+ relative_position_scores_query + relative_position_scores_key
342
+
343
+ attention_scores = attention_scores / \
344
+ math.sqrt(self.attention_head_size)
345
+ if attention_mask is not None:
346
+ # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
347
+ attention_scores = attention_scores + attention_mask
348
+
349
+ # Normalize the attention scores to probabilities.
350
+ attention_probs = nn.Softmax(dim=-1)(attention_scores)
351
+
352
+ if is_cross_attention and self.save_attention:
353
+ self.save_attention_map(attention_probs)
354
+ attention_probs.register_hook(self.save_attn_gradients)
355
+
356
+ # This is actually dropping out entire tokens to attend to, which might
357
+ # seem a bit unusual, but is taken from the original Transformer paper.
358
+ attention_probs_dropped = self.dropout(attention_probs)
359
+
360
+ # Mask heads if we want to
361
+ if head_mask is not None:
362
+ attention_probs_dropped = attention_probs_dropped * head_mask
363
+
364
+ context_layer = torch.matmul(attention_probs_dropped, value_layer)
365
+
366
+ context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
367
+ new_context_layer_shape = context_layer.size()[
368
+ :-2] + (self.all_head_size,)
369
+ context_layer = context_layer.view(*new_context_layer_shape)
370
+
371
+ # added `attention_scores` to return tuple
372
+ outputs = (context_layer, attention_probs, attention_scores) if output_attentions else (
373
+ context_layer,)
374
+
375
+ outputs = outputs + (past_key_value,)
376
+ return outputs
377
+
378
+
379
+ class BertSelfOutput(nn.Module):
380
+ def __init__(self, config):
381
+ super().__init__()
382
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
383
+ self.LayerNorm = nn.LayerNorm(
384
+ config.hidden_size, eps=config.layer_norm_eps)
385
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
386
+
387
+ def forward(self, hidden_states, input_tensor):
388
+ hidden_states = self.dense(hidden_states)
389
+ hidden_states = self.dropout(hidden_states)
390
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
391
+ return hidden_states
392
+
393
+
394
+ class BertAttention(nn.Module):
395
+ def __init__(self, config, is_cross_attention=False):
396
+ super().__init__()
397
+ self.self = BertSelfAttention(config, is_cross_attention)
398
+ self.output = BertSelfOutput(config)
399
+ self.pruned_heads = set()
400
+
401
+ def prune_heads(self, heads):
402
+ if len(heads) == 0:
403
+ return
404
+ heads, index = find_pruneable_heads_and_indices(
405
+ heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
406
+ )
407
+
408
+ # Prune linear layers
409
+ self.self.query = prune_linear_layer(self.self.query, index)
410
+ self.self.key = prune_linear_layer(self.self.key, index)
411
+ self.self.value = prune_linear_layer(self.self.value, index)
412
+ self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
413
+
414
+ # Update hyper params and store pruned heads
415
+ self.self.num_attention_heads = self.self.num_attention_heads - \
416
+ len(heads)
417
+ self.self.all_head_size = self.self.attention_head_size * \
418
+ self.self.num_attention_heads
419
+ self.pruned_heads = self.pruned_heads.union(heads)
420
+
421
+ def forward(
422
+ self,
423
+ hidden_states,
424
+ attention_mask=None,
425
+ head_mask=None,
426
+ encoder_hidden_states=None,
427
+ encoder_attention_mask=None,
428
+ past_key_value=None,
429
+ output_attentions=False,
430
+ ):
431
+ self_outputs = self.self(
432
+ hidden_states,
433
+ attention_mask,
434
+ head_mask,
435
+ encoder_hidden_states,
436
+ encoder_attention_mask,
437
+ past_key_value,
438
+ output_attentions,
439
+ )
440
+ attention_output = self.output(self_outputs[0], hidden_states)
441
+ # add attentions if we output them
442
+ outputs = (attention_output,) + self_outputs[1:]
443
+ return outputs # (context_layer, attention_probs, attention_scores, past_key_value,)
444
+
445
+
446
+ class BertIntermediate(nn.Module):
447
+ def __init__(self, config):
448
+ super().__init__()
449
+ self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
450
+ if isinstance(config.hidden_act, str):
451
+ self.intermediate_act_fn = ACT2FN[config.hidden_act]
452
+ else:
453
+ self.intermediate_act_fn = config.hidden_act
454
+
455
+ def forward(self, hidden_states):
456
+ hidden_states = self.dense(hidden_states)
457
+ hidden_states = self.intermediate_act_fn(hidden_states)
458
+ return hidden_states
459
+
460
+
461
+ class BertOutput(nn.Module):
462
+ def __init__(self, config):
463
+ super().__init__()
464
+ self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
465
+ self.LayerNorm = nn.LayerNorm(
466
+ config.hidden_size, eps=config.layer_norm_eps)
467
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
468
+
469
+ def forward(self, hidden_states, input_tensor):
470
+ hidden_states = self.dense(hidden_states)
471
+ hidden_states = self.dropout(hidden_states)
472
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
473
+ return hidden_states
474
+
475
+
476
+ class BertLayer(nn.Module):
477
+ def __init__(self, config, layer_num, token_keep_rate=1.0):
478
+ super().__init__()
479
+ self.config = config
480
+ self.chunk_size_feed_forward = config.chunk_size_feed_forward
481
+ self.seq_len_dim = 1
482
+ self.attention = BertAttention(config)
483
+
484
+ self.has_cross_attention = (layer_num >= config.fusion_layer)
485
+ if self.has_cross_attention:
486
+ self.layer_num = layer_num
487
+ self.crossattention = BertAttention(
488
+ config, is_cross_attention=True)
489
+
490
+ # sparse params
491
+ self.token_keep_rate = token_keep_rate
492
+ self.token_keep_strategy = config.token_keep_strategy
493
+ self.encoder_num_cls_tokens = 1 # multiple cls tokens
494
+
495
+ self.intermediate = BertIntermediate(config)
496
+ self.output = BertOutput(config)
497
+
498
+ def sparsify(self, x, attn, mask=None):
499
+ x_cls, x_ = x[:, :self.encoder_num_cls_tokens], x[:, self.encoder_num_cls_tokens:]
500
+ assert 0 < self.token_keep_rate <= 1, "Expected keep rate in range (0, 1]"
501
+ left_tokens = math.ceil(self.token_keep_rate * x_.size(1))
502
+ if len(attn.shape) == 4:
503
+ attn = attn.mean(1) # pool over attention heads
504
+
505
+ if self.token_keep_strategy == 'cls_attn':
506
+ cls_attn = attn[:, 0, self.encoder_num_cls_tokens:]
507
+ _, idx = torch.topk(cls_attn, left_tokens, dim=1) # [B, left_tokens]
508
+
509
+ elif self.token_keep_strategy == 'avg_attn':
510
+ avg_attn = attn.mean(1)[:, self.encoder_num_cls_tokens:]
511
+ _, idx = torch.topk(avg_attn, left_tokens, dim=1) # [B, left_tokens]
512
+
513
+ elif self.token_keep_strategy == 'random':
514
+ rand = torch.rand(x_.shape[:2], device=x_.device)
515
+ _, idx = torch.topk(rand, left_tokens, dim=1) # [B, left_tokens]
516
+
517
+ else:
518
+ raise NotImplementedError(f"Sparse strategy {self.token_keep_strategy} is not implemented")
519
+
520
+ idx, _ = torch.sort(idx, dim=1)
521
+ index = idx.unsqueeze(-1).expand(-1, -1, x_.size(-1)) # [B, left_tokens, C]
522
+ outputs = torch.cat((x_cls, x_.gather(1, index)), dim=1).contiguous()
523
+ if mask is not None:
524
+ mask_cls, mask_ = mask[..., :self.encoder_num_cls_tokens], mask[..., self.encoder_num_cls_tokens:]
525
+ index = idx.unsqueeze(1).unsqueeze(1) # [B, 1, 1, left_tokens]
526
+ mask = torch.cat((mask_cls, mask_.gather(-1, index)), dim=-1).contiguous()
527
+ return outputs, mask, idx
528
+
529
+ def forward(
530
+ self,
531
+ hidden_states,
532
+ attention_mask=None,
533
+ head_mask=None,
534
+ encoder_hidden_states=None,
535
+ encoder_attention_mask=None,
536
+ past_key_value=None,
537
+ output_attentions=False,
538
+ ):
539
+ # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
540
+ self_attn_past_key_value = past_key_value[:
541
+ 2] if past_key_value is not None else None
542
+ self_attention_outputs = self.attention(
543
+ hidden_states,
544
+ attention_mask,
545
+ head_mask,
546
+ output_attentions=output_attentions,
547
+ past_key_value=self_attn_past_key_value,
548
+ ) # (context_layer, attention_probs, attention_scores, past_key_value,)
549
+ attention_output = self_attention_outputs[0]
550
+
551
+ outputs = self_attention_outputs[1:-1]
552
+ present_key_value = self_attention_outputs[-1]
553
+
554
+ if self.has_cross_attention:
555
+ assert encoder_hidden_states is not None, "encoder_hidden_states must be given for cross-attention layers"
556
+ output_attentions = (output_attentions or self.token_keep_rate < 1)
557
+
558
+ if type(encoder_hidden_states) == list:
559
+ cross_attention_outputs = self.crossattention(
560
+ attention_output,
561
+ attention_mask,
562
+ head_mask,
563
+ encoder_hidden_states[(
564
+ self.layer_num-self.config.fusion_layer) % len(encoder_hidden_states)],
565
+ encoder_attention_mask[(
566
+ self.layer_num-self.config.fusion_layer) % len(encoder_hidden_states)],
567
+ output_attentions=output_attentions,
568
+ )
569
+ attention_output = cross_attention_outputs[0]
570
+ outputs = outputs + cross_attention_outputs[1:-1]
571
+
572
+ else:
573
+ cross_attention_outputs = self.crossattention(
574
+ attention_output,
575
+ attention_mask,
576
+ head_mask,
577
+ encoder_hidden_states,
578
+ encoder_attention_mask,
579
+ output_attentions=output_attentions,
580
+ ) # (context_layer, attention_probs, attention_scores, past_key_value,)
581
+ attention_output = cross_attention_outputs[0]
582
+
583
+ # add cross attentions if we output attention weights
584
+ outputs = outputs + cross_attention_outputs[1:-1]
585
+
586
+ # node sparsification
587
+ if self.token_keep_rate < 1:
588
+ encoder_hidden_states, encoder_attention_mask, token_keep_idx = self.sparsify(
589
+ encoder_hidden_states, cross_attention_outputs[1], encoder_attention_mask)
590
+ outputs = outputs + (encoder_hidden_states, encoder_attention_mask, token_keep_idx)
591
+
592
+ layer_output = apply_chunking_to_forward(
593
+ self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
594
+ )
595
+ outputs = (layer_output,) + outputs
596
+
597
+ outputs = outputs + (present_key_value,)
598
+
599
+ return outputs
600
+
601
+ def feed_forward_chunk(self, attention_output):
602
+ intermediate_output = self.intermediate(attention_output)
603
+ layer_output = self.output(intermediate_output, attention_output)
604
+ return layer_output
605
+
606
+
607
+ class BertEncoder(nn.Module):
608
+ def __init__(self, config):
609
+ super().__init__()
610
+ self.config = config
611
+
612
+ # node sparsification
613
+ token_keep_rate = [1] * config.num_hidden_layers
614
+ for loc in config.token_drop_loc:
615
+ token_keep_rate[loc] = config.token_keep_rate
616
+
617
+ self.layer = nn.ModuleList([BertLayer(config, i, token_keep_rate[i])
618
+ for i in range(config.num_hidden_layers)])
619
+
620
+ def forward(
621
+ self,
622
+ hidden_states,
623
+ attention_mask=None,
624
+ head_mask=None,
625
+ encoder_hidden_states=None,
626
+ encoder_attention_mask=None,
627
+ past_key_values=None,
628
+ use_cache=None,
629
+ output_attentions=False,
630
+ output_hidden_states=False,
631
+ output_token_idx=False,
632
+ return_dict=True,
633
+ mode='multi_modal',
634
+ normalize_attention=True
635
+ ):
636
+ all_hidden_states = () if output_hidden_states else None
637
+ all_self_attentions = () if output_attentions else None
638
+ all_cross_attentions = () if output_attentions else None
639
+ all_token_idx = () if output_token_idx else None
640
+
641
+ next_decoder_cache = () if use_cache else None
642
+
643
+ if mode == 'text':
644
+ start_layer = 0
645
+ output_layer = self.config.fusion_layer
646
+
647
+ elif mode == 'fusion':
648
+ start_layer = self.config.fusion_layer
649
+ output_layer = self.config.num_hidden_layers
650
+
651
+ elif mode == 'multi_modal':
652
+ start_layer = 0
653
+ output_layer = self.config.num_hidden_layers
654
+
655
+ for i in range(start_layer, output_layer):
656
+ layer_module = self.layer[i]
657
+ if output_hidden_states:
658
+ all_hidden_states = all_hidden_states + (hidden_states,)
659
+
660
+ layer_head_mask = head_mask[i] if head_mask is not None else None
661
+ past_key_value = past_key_values[i] if past_key_values is not None else None
662
+
663
+ if getattr(self.config, "gradient_checkpointing", False) and self.training:
664
+
665
+ if use_cache:
666
+ use_cache = False
667
+
668
+ def create_custom_forward(module):
669
+ def custom_forward(*inputs):
670
+ return module(*inputs, past_key_value, output_attentions)
671
+
672
+ return custom_forward
673
+
674
+ layer_outputs = torch.utils.checkpoint.checkpoint(
675
+ create_custom_forward(layer_module),
676
+ hidden_states,
677
+ attention_mask,
678
+ layer_head_mask,
679
+ encoder_hidden_states,
680
+ encoder_attention_mask,
681
+ )
682
+ else:
683
+ layer_outputs = layer_module(
684
+ hidden_states,
685
+ attention_mask,
686
+ layer_head_mask,
687
+ encoder_hidden_states,
688
+ encoder_attention_mask,
689
+ past_key_value,
690
+ output_attentions,
691
+ ) # (context_layer, attention_probs, attention_scores, past_key_value,)
692
+ hidden_states = layer_outputs[0]
693
+ # update visual sequence
694
+ if mode == 'fusion' and layer_module.token_keep_rate < 1:
695
+ encoder_hidden_states, encoder_attention_mask, token_idx = layer_outputs[-4:-1]
696
+
697
+ if output_token_idx:
698
+ all_token_idx = all_token_idx + (token_idx,)
699
+
700
+ if use_cache:
701
+ next_decoder_cache += (layer_outputs[-1],)
702
+ if output_attentions:
703
+ # whether to output normalized attention,
704
+ # note for unnormalized attention, there is a mask added
705
+ offset = int(normalize_attention)
706
+ all_self_attentions = all_self_attentions + (layer_outputs[2-offset], )
707
+ if hasattr(layer_module, "crossattention"):
708
+ all_cross_attentions = all_cross_attentions + (layer_outputs[4-offset], )
709
+
710
+ if output_hidden_states:
711
+ all_hidden_states = all_hidden_states + (hidden_states,)
712
+
713
+ if not return_dict:
714
+ return tuple(
715
+ v
716
+ for v in [
717
+ hidden_states,
718
+ next_decoder_cache,
719
+ all_hidden_states,
720
+ all_self_attentions,
721
+ all_cross_attentions,
722
+ ]
723
+ if v is not None
724
+ )
725
+ return BertModelOutputWithPastAndCrossAttentions(
726
+ last_hidden_state=hidden_states,
727
+ past_key_values=next_decoder_cache,
728
+ hidden_states=all_hidden_states,
729
+ attentions=all_self_attentions,
730
+ cross_attentions=all_cross_attentions,
731
+ token_idx=all_token_idx
732
+ )
733
+
734
+
735
+ class BertPooler(nn.Module):
736
+ def __init__(self, config):
737
+ super().__init__()
738
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
739
+ self.activation = nn.Tanh()
740
+
741
+ def forward(self, hidden_states):
742
+ # We "pool" the model by simply taking the hidden state corresponding
743
+ # to the first token.
744
+ first_token_tensor = hidden_states[:, 0]
745
+ pooled_output = self.dense(first_token_tensor)
746
+ pooled_output = self.activation(pooled_output)
747
+ return pooled_output
748
+
749
+
750
+ class BertPredictionHeadTransform(nn.Module):
751
+ def __init__(self, config):
752
+ super().__init__()
753
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
754
+ if isinstance(config.hidden_act, str):
755
+ self.transform_act_fn = ACT2FN[config.hidden_act]
756
+ else:
757
+ self.transform_act_fn = config.hidden_act
758
+ self.LayerNorm = nn.LayerNorm(
759
+ config.hidden_size, eps=config.layer_norm_eps)
760
+
761
+ def forward(self, hidden_states):
762
+ hidden_states = self.dense(hidden_states)
763
+ hidden_states = self.transform_act_fn(hidden_states)
764
+ hidden_states = self.LayerNorm(hidden_states)
765
+ return hidden_states
766
+
767
+
768
+ class BertLMPredictionHead(nn.Module):
769
+ def __init__(self, config):
770
+ super().__init__()
771
+ self.transform = BertPredictionHeadTransform(config)
772
+
773
+ # The output weights are the same as the input embeddings, but there is
774
+ # an output-only bias for each token.
775
+ self.decoder = nn.Linear(
776
+ config.hidden_size, config.vocab_size, bias=False)
777
+
778
+ self.bias = nn.Parameter(torch.zeros(config.vocab_size))
779
+
780
+ # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
781
+ self.decoder.bias = self.bias
782
+
783
+ def forward(self, hidden_states):
784
+ hidden_states = self.transform(hidden_states)
785
+ hidden_states = self.decoder(hidden_states)
786
+ return hidden_states
787
+
788
+
789
+ class BertOnlyMLMHead(nn.Module):
790
+ def __init__(self, config):
791
+ super().__init__()
792
+ self.predictions = BertLMPredictionHead(config)
793
+
794
+ def forward(self, sequence_output):
795
+ prediction_scores = self.predictions(sequence_output)
796
+ return prediction_scores
797
+
798
+
799
+ class BertOnlyNSPHead(nn.Module):
800
+ def __init__(self, config):
801
+ super().__init__()
802
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
803
+
804
+ def forward(self, pooled_output):
805
+ seq_relationship_score = self.seq_relationship(pooled_output)
806
+ return seq_relationship_score
807
+
808
+
809
+ class BertPreTrainingHeads(nn.Module):
810
+ def __init__(self, config):
811
+ super().__init__()
812
+ self.predictions = BertLMPredictionHead(config)
813
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
814
+
815
+ def forward(self, sequence_output, pooled_output):
816
+ prediction_scores = self.predictions(sequence_output)
817
+ seq_relationship_score = self.seq_relationship(pooled_output)
818
+ return prediction_scores, seq_relationship_score
819
+
820
+
821
+ class BertPreTrainedModel(PreTrainedModel):
822
+ """
823
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
824
+ models.
825
+ """
826
+
827
+ config_class = BertConfig
828
+ load_tf_weights = load_tf_weights_in_bert
829
+ base_model_prefix = "bert"
830
+ _keys_to_ignore_on_load_missing = [r"position_ids"]
831
+
832
+ def _init_weights(self, module):
833
+ """ Initialize the weights """
834
+ if isinstance(module, (nn.Linear, nn.Embedding)):
835
+ # Slightly different from the TF version which uses truncated_normal for initialization
836
+ # cf https://github.com/pytorch/pytorch/pull/5617
837
+ module.weight.data.normal_(
838
+ mean=0.0, std=self.config.initializer_range)
839
+ elif isinstance(module, nn.LayerNorm):
840
+ module.bias.data.zero_()
841
+ module.weight.data.fill_(1.0)
842
+ if isinstance(module, nn.Linear) and module.bias is not None:
843
+ module.bias.data.zero_()
844
+
845
+
846
+ @dataclass
847
+ class BertForPreTrainingOutput(ModelOutput):
848
+ """
849
+ Output type of :class:`~transformers.BertForPreTraining`.
850
+ Args:
851
+ loss (`optional`, returned when ``labels`` is provided, ``torch.FloatTensor`` of shape :obj:`(1,)`):
852
+ Total loss as the sum of the masked language modeling loss and the next sequence prediction
853
+ (classification) loss.
854
+ prediction_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`):
855
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
856
+ seq_relationship_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
857
+ Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
858
+ before SoftMax).
859
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
860
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
861
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
862
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
863
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
864
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
865
+ sequence_length, sequence_length)`.
866
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
867
+ heads.
868
+ """
869
+
870
+ loss: Optional[torch.FloatTensor] = None
871
+ prediction_logits: torch.FloatTensor = None
872
+ seq_relationship_logits: torch.FloatTensor = None
873
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
874
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
875
+
876
+
877
+ BERT_START_DOCSTRING = r"""
878
+ This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
879
+ methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,
880
+ pruning heads etc.)
881
+ This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
882
+ subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
883
+ general usage and behavior.
884
+ Parameters:
885
+ config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
886
+ Initializing with a config file does not load the weights associated with the model, only the
887
+ configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
888
+ weights.
889
+ """
890
+
891
+ BERT_INPUTS_DOCSTRING = r"""
892
+ Args:
893
+ input_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`):
894
+ Indices of input sequence tokens in the vocabulary.
895
+ Indices can be obtained using :class:`~transformers.BertTokenizer`. See
896
+ :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
897
+ details.
898
+ `What are input IDs? <../glossary.html#input-ids>`__
899
+ attention_mask (:obj:`torch.FloatTensor` of shape :obj:`({0})`, `optional`):
900
+ Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
901
+ - 1 for tokens that are **not masked**,
902
+ - 0 for tokens that are **masked**.
903
+ `What are attention masks? <../glossary.html#attention-mask>`__
904
+ token_type_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
905
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
906
+ 1]``:
907
+ - 0 corresponds to a `sentence A` token,
908
+ - 1 corresponds to a `sentence B` token.
909
+ `What are token type IDs? <../glossary.html#token-type-ids>`_
910
+ position_ids (:obj:`torch.LongTensor` of shape :obj:`({0})`, `optional`):
911
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
912
+ config.max_position_embeddings - 1]``.
913
+ `What are position IDs? <../glossary.html#position-ids>`_
914
+ head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
915
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
916
+ - 1 indicates the head is **not masked**,
917
+ - 0 indicates the head is **masked**.
918
+ inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`({0}, hidden_size)`, `optional`):
919
+ Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
920
+ This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
921
+ vectors than the model's internal embedding lookup matrix.
922
+ output_attentions (:obj:`bool`, `optional`):
923
+ Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
924
+ tensors for more detail.
925
+ output_hidden_states (:obj:`bool`, `optional`):
926
+ Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
927
+ more detail.
928
+ return_dict (:obj:`bool`, `optional`):
929
+ Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
930
+ """
931
+
932
+
933
+ @add_start_docstrings(
934
+ "The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
935
+ BERT_START_DOCSTRING,
936
+ )
937
+ class BertModel(BertPreTrainedModel):
938
+ """
939
+ The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
940
+ cross-attention is added between the self-attention layers, following the architecture described in `Attention is
941
+ all you need <https://arxiv.org/abs/1706.03762>`__ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
942
+ Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
943
+ argument and :obj:`add_cross_attention` set to :obj:`True`; an :obj:`encoder_hidden_states` is then expected as an
944
+ input to the forward pass.
945
+ """
946
+
947
+ def __init__(self, config, add_pooling_layer=True):
948
+ super().__init__(config)
949
+ self.config = config
950
+
951
+ self.embeddings = BertEmbeddings(config)
952
+
953
+ self.encoder = BertEncoder(config)
954
+
955
+ self.pooler = BertPooler(config) if add_pooling_layer else None
956
+
957
+ self.init_weights()
958
+
959
+ def get_input_embeddings(self):
960
+ return self.embeddings.word_embeddings
961
+
962
+ def set_input_embeddings(self, value):
963
+ self.embeddings.word_embeddings = value
964
+
965
+ def _prune_heads(self, heads_to_prune):
966
+ """
967
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
968
+ class PreTrainedModel
969
+ """
970
+ for layer, heads in heads_to_prune.items():
971
+ self.encoder.layer[layer].attention.prune_heads(heads)
972
+
973
+ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape: Tuple[int], device: device, is_decoder: bool) -> Tensor:
974
+ """
975
+ Makes broadcastable attention and causal masks so that future and masked tokens are ignored.
976
+
977
+ Arguments:
978
+ attention_mask (:obj:`torch.Tensor`):
979
+ Mask with ones indicating tokens to attend to, zeros for tokens to ignore.
980
+ input_shape (:obj:`Tuple[int]`):
981
+ The shape of the input to the model.
982
+ device: (:obj:`torch.device`):
983
+ The device of the input to the model.
984
+
985
+ Returns:
986
+ :obj:`torch.Tensor` The extended attention mask, with a the same dtype as :obj:`attention_mask.dtype`.
987
+ """
988
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
989
+ # ourselves in which case we just need to make it broadcastable to all heads.
990
+ if attention_mask.dim() == 3:
991
+ extended_attention_mask = attention_mask[:, None, :, :]
992
+ elif attention_mask.dim() == 2:
993
+ # Provided a padding mask of dimensions [batch_size, seq_length]
994
+ # - if the model is a decoder, apply a causal mask in addition to the padding mask
995
+ # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
996
+ if is_decoder:
997
+ batch_size, seq_length = input_shape
998
+ seq_ids = torch.arange(seq_length, device=device)
999
+ causal_mask = seq_ids[None, None, :].repeat(
1000
+ batch_size, seq_length, 1) <= seq_ids[None, :, None]
1001
+ # in case past_key_values are used we need to add a prefix ones mask to the causal mask
1002
+ # causal and attention masks must have same type with pytorch version < 1.3
1003
+ causal_mask = causal_mask.to(attention_mask.dtype)
1004
+
1005
+ if causal_mask.shape[1] < attention_mask.shape[1]:
1006
+ prefix_seq_len = attention_mask.shape[1] - \
1007
+ causal_mask.shape[1]
1008
+ causal_mask = torch.cat(
1009
+ [
1010
+ torch.ones(
1011
+ (batch_size, seq_length,
1012
+ prefix_seq_len), device=device, dtype=causal_mask.dtype
1013
+ ),
1014
+ causal_mask,
1015
+ ],
1016
+ axis=-1,
1017
+ )
1018
+
1019
+ extended_attention_mask = causal_mask[:, None,
1020
+ :, :] * attention_mask[:, None, None, :]
1021
+ else:
1022
+ extended_attention_mask = attention_mask[:, None, None, :]
1023
+ else:
1024
+ raise ValueError(
1025
+ "Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
1026
+ input_shape, attention_mask.shape
1027
+ )
1028
+ )
1029
+
1030
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
1031
+ # masked positions, this operation will create a tensor which is 0.0 for
1032
+ # positions we want to attend and -10000.0 for masked positions.
1033
+ # Since we are adding it to the raw scores before the softmax, this is
1034
+ # effectively the same as removing these entirely.
1035
+ extended_attention_mask = extended_attention_mask.to(
1036
+ dtype=self.dtype) # fp16 compatibility
1037
+ extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
1038
+ return extended_attention_mask
1039
+
1040
+ def forward(
1041
+ self,
1042
+ input_ids=None,
1043
+ attention_mask=None,
1044
+ token_type_ids=None,
1045
+ position_ids=None,
1046
+ head_mask=None,
1047
+ inputs_embeds=None,
1048
+ encoder_embeds=None,
1049
+ encoder_hidden_states=None,
1050
+ encoder_attention_mask=None,
1051
+ past_key_values=None,
1052
+ use_cache=None,
1053
+ output_attentions=None,
1054
+ output_hidden_states=None,
1055
+ output_token_idx=None,
1056
+ return_dict=None,
1057
+ is_decoder=False,
1058
+ mode='multi_modal',
1059
+ normalize_attention=True,
1060
+ ):
1061
+ r"""
1062
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
1063
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
1064
+ the model is configured as a decoder.
1065
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1066
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
1067
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
1068
+ - 1 for tokens that are **not masked**,
1069
+ - 0 for tokens that are **masked**.
1070
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
1071
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1072
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
1073
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
1074
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
1075
+ use_cache (:obj:`bool`, `optional`):
1076
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
1077
+ decoding (see :obj:`past_key_values`).
1078
+ """
1079
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1080
+ output_hidden_states = (
1081
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1082
+ )
1083
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1084
+
1085
+ if is_decoder:
1086
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
1087
+ else:
1088
+ use_cache = False
1089
+
1090
+ if input_ids is not None and inputs_embeds is not None:
1091
+ raise ValueError(
1092
+ "You cannot specify both input_ids and inputs_embeds at the same time")
1093
+ elif input_ids is not None:
1094
+ input_shape = input_ids.size()
1095
+ batch_size, seq_length = input_shape
1096
+ device = input_ids.device
1097
+ elif inputs_embeds is not None:
1098
+ input_shape = inputs_embeds.size()[:-1]
1099
+ batch_size, seq_length = input_shape
1100
+ device = inputs_embeds.device
1101
+ elif encoder_embeds is not None:
1102
+ input_shape = encoder_embeds.size()[:-1]
1103
+ batch_size, seq_length = input_shape
1104
+ device = encoder_embeds.device
1105
+ else:
1106
+ raise ValueError(
1107
+ "You have to specify either input_ids or inputs_embeds or encoder_embeds")
1108
+
1109
+ # past_key_values_length
1110
+ past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
1111
+
1112
+ if attention_mask is None:
1113
+ attention_mask = torch.ones(
1114
+ ((batch_size, seq_length + past_key_values_length)), device=device)
1115
+ if token_type_ids is None:
1116
+ token_type_ids = torch.zeros(
1117
+ input_shape, dtype=torch.long, device=device)
1118
+
1119
+ # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
1120
+ # ourselves in which case we just need to make it broadcastable to all heads.
1121
+ extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape,
1122
+ device, is_decoder)
1123
+
1124
+ # If a 2D or 3D attention mask is provided for the cross-attention
1125
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
1126
+ if encoder_hidden_states is not None:
1127
+ if type(encoder_hidden_states) == list:
1128
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states[0].size(
1129
+ )
1130
+ else:
1131
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
1132
+ encoder_hidden_shape = (
1133
+ encoder_batch_size, encoder_sequence_length)
1134
+
1135
+ if type(encoder_attention_mask) == list:
1136
+ encoder_extended_attention_mask = [
1137
+ self.invert_attention_mask(mask) for mask in encoder_attention_mask]
1138
+ elif encoder_attention_mask is None:
1139
+ encoder_attention_mask = torch.ones(
1140
+ encoder_hidden_shape, device=device)
1141
+ encoder_extended_attention_mask = self.invert_attention_mask(
1142
+ encoder_attention_mask)
1143
+ else:
1144
+ encoder_extended_attention_mask = self.invert_attention_mask(
1145
+ encoder_attention_mask)
1146
+ else:
1147
+ encoder_extended_attention_mask = None
1148
+
1149
+ # Prepare head mask if needed
1150
+ # 1.0 in head_mask indicate we keep the head
1151
+ # attention_probs has shape bsz x n_heads x N x N
1152
+ # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
1153
+ # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
1154
+ head_mask = self.get_head_mask(
1155
+ head_mask, self.config.num_hidden_layers)
1156
+
1157
+ if encoder_embeds is None:
1158
+ embedding_output = self.embeddings(
1159
+ input_ids=input_ids,
1160
+ position_ids=position_ids,
1161
+ token_type_ids=token_type_ids,
1162
+ inputs_embeds=inputs_embeds,
1163
+ past_key_values_length=past_key_values_length,
1164
+ )
1165
+ else:
1166
+ embedding_output = encoder_embeds
1167
+
1168
+ encoder_outputs = self.encoder(
1169
+ embedding_output,
1170
+ attention_mask=extended_attention_mask,
1171
+ head_mask=head_mask,
1172
+ encoder_hidden_states=encoder_hidden_states,
1173
+ encoder_attention_mask=encoder_extended_attention_mask,
1174
+ past_key_values=past_key_values,
1175
+ use_cache=use_cache,
1176
+ output_attentions=output_attentions,
1177
+ output_hidden_states=output_hidden_states,
1178
+ output_token_idx=output_token_idx,
1179
+ return_dict=return_dict,
1180
+ mode=mode,
1181
+ normalize_attention=normalize_attention,
1182
+
1183
+ )
1184
+ sequence_output = encoder_outputs[0]
1185
+ pooled_output = self.pooler(
1186
+ sequence_output) if self.pooler is not None else None
1187
+
1188
+ if not return_dict:
1189
+ return (sequence_output, pooled_output) + encoder_outputs[1:]
1190
+
1191
+ return BertModelOutputWithPoolingAndCrossAttentions(
1192
+ last_hidden_state=sequence_output,
1193
+ pooler_output=pooled_output,
1194
+ past_key_values=encoder_outputs.past_key_values,
1195
+ hidden_states=encoder_outputs.hidden_states,
1196
+ attentions=encoder_outputs.attentions,
1197
+ cross_attentions=encoder_outputs.cross_attentions,
1198
+ token_idx=encoder_outputs.token_idx,
1199
+ )
1200
+
1201
+
1202
+ @add_start_docstrings(
1203
+ """
1204
+ Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
1205
+ sentence prediction (classification)` head.
1206
+ """,
1207
+ BERT_START_DOCSTRING,
1208
+ )
1209
+ class BertForPreTraining(BertPreTrainedModel):
1210
+ def __init__(self, config):
1211
+ super().__init__(config)
1212
+
1213
+ self.bert = BertModel(config)
1214
+ self.cls = BertPreTrainingHeads(config)
1215
+
1216
+ self.init_weights()
1217
+
1218
+ def get_output_embeddings(self):
1219
+ return self.cls.predictions.decoder
1220
+
1221
+ def set_output_embeddings(self, new_embeddings):
1222
+ self.cls.predictions.decoder = new_embeddings
1223
+
1224
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1225
+ @replace_return_docstrings(output_type=BertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
1226
+ def forward(
1227
+ self,
1228
+ input_ids=None,
1229
+ attention_mask=None,
1230
+ token_type_ids=None,
1231
+ position_ids=None,
1232
+ head_mask=None,
1233
+ inputs_embeds=None,
1234
+ labels=None,
1235
+ next_sentence_label=None,
1236
+ output_attentions=None,
1237
+ output_hidden_states=None,
1238
+ return_dict=None,
1239
+ ):
1240
+ r"""
1241
+ labels (:obj:`torch.LongTensor` of shape ``(batch_size, sequence_length)``, `optional`):
1242
+ Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
1243
+ config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
1244
+ (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
1245
+ next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`):
1246
+ Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
1247
+ (see :obj:`input_ids` docstring) Indices should be in ``[0, 1]``:
1248
+ - 0 indicates sequence B is a continuation of sequence A,
1249
+ - 1 indicates sequence B is a random sequence.
1250
+ kwargs (:obj:`Dict[str, any]`, optional, defaults to `{}`):
1251
+ Used to hide legacy arguments that have been deprecated.
1252
+ Returns:
1253
+ Example::
1254
+ >>> from transformers import BertTokenizer, BertForPreTraining
1255
+ >>> import torch
1256
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
1257
+ >>> model = BertForPreTraining.from_pretrained('bert-base-uncased')
1258
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
1259
+ >>> outputs = model(**inputs)
1260
+ >>> prediction_logits = outputs.prediction_logits
1261
+ >>> seq_relationship_logits = outputs.seq_relationship_logits
1262
+ """
1263
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1264
+
1265
+ outputs = self.bert(
1266
+ input_ids,
1267
+ attention_mask=attention_mask,
1268
+ token_type_ids=token_type_ids,
1269
+ position_ids=position_ids,
1270
+ head_mask=head_mask,
1271
+ inputs_embeds=inputs_embeds,
1272
+ output_attentions=output_attentions,
1273
+ output_hidden_states=output_hidden_states,
1274
+ return_dict=return_dict,
1275
+ )
1276
+
1277
+ sequence_output, pooled_output = outputs[:2]
1278
+ prediction_scores, seq_relationship_score = self.cls(
1279
+ sequence_output, pooled_output)
1280
+
1281
+ total_loss = None
1282
+ if labels is not None and next_sentence_label is not None:
1283
+ loss_fct = CrossEntropyLoss()
1284
+ masked_lm_loss = loss_fct(
1285
+ prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1286
+ next_sentence_loss = loss_fct(
1287
+ seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
1288
+ total_loss = masked_lm_loss + next_sentence_loss
1289
+
1290
+ if not return_dict:
1291
+ output = (prediction_scores, seq_relationship_score) + outputs[2:]
1292
+ return ((total_loss,) + output) if total_loss is not None else output
1293
+
1294
+ return BertForPreTrainingOutput(
1295
+ loss=total_loss,
1296
+ prediction_logits=prediction_scores,
1297
+ seq_relationship_logits=seq_relationship_score,
1298
+ hidden_states=outputs.hidden_states,
1299
+ attentions=outputs.attentions,
1300
+ )
1301
+
1302
+
1303
+ @add_start_docstrings(
1304
+ """Bert Model with a `language modeling` head on top for CLM fine-tuning. """, BERT_START_DOCSTRING
1305
+ )
1306
+ class BertLMHeadModel(BertPreTrainedModel):
1307
+
1308
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1309
+ _keys_to_ignore_on_load_missing = [
1310
+ r"position_ids", r"predictions.decoder.bias"]
1311
+
1312
+ def __init__(self, config):
1313
+ super().__init__(config)
1314
+
1315
+ self.bert = BertModel(config, add_pooling_layer=False)
1316
+ self.cls = BertOnlyMLMHead(config)
1317
+
1318
+ self.init_weights()
1319
+
1320
+ def get_output_embeddings(self):
1321
+ return self.cls.predictions.decoder
1322
+
1323
+ def set_output_embeddings(self, new_embeddings):
1324
+ self.cls.predictions.decoder = new_embeddings
1325
+
1326
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1327
+ @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
1328
+ def forward(
1329
+ self,
1330
+ input_ids=None,
1331
+ attention_mask=None,
1332
+ token_type_ids=None,
1333
+ position_ids=None,
1334
+ head_mask=None,
1335
+ inputs_embeds=None,
1336
+ encoder_hidden_states=None,
1337
+ encoder_attention_mask=None,
1338
+ labels=None,
1339
+ past_key_values=None,
1340
+ use_cache=None,
1341
+ output_attentions=None,
1342
+ output_hidden_states=None,
1343
+ return_dict=None,
1344
+ is_decoder=True,
1345
+ reduction='mean',
1346
+ mode='multi_modal',
1347
+ normalize_attention=True,
1348
+ soft_labels=None,
1349
+ alpha=0,
1350
+ return_logits=False,
1351
+ ):
1352
+ r"""
1353
+ encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
1354
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention if
1355
+ the model is configured as a decoder.
1356
+ encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1357
+ Mask to avoid performing attention on the padding token indices of the encoder input. This mask is used in
1358
+ the cross-attention if the model is configured as a decoder. Mask values selected in ``[0, 1]``:
1359
+ - 1 for tokens that are **not masked**,
1360
+ - 0 for tokens that are **masked**.
1361
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1362
+ Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
1363
+ ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are
1364
+ ignored (masked), the loss is only computed for the tokens with labels n ``[0, ..., config.vocab_size]``
1365
+ past_key_values (:obj:`tuple(tuple(torch.FloatTensor))` of length :obj:`config.n_layers` with each tuple having 4 tensors of shape :obj:`(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
1366
+ Contains precomputed key and value hidden states of the attention blocks. Can be used to speed up decoding.
1367
+ If :obj:`past_key_values` are used, the user can optionally input only the last :obj:`decoder_input_ids`
1368
+ (those that don't have their past key value states given to this model) of shape :obj:`(batch_size, 1)`
1369
+ instead of all :obj:`decoder_input_ids` of shape :obj:`(batch_size, sequence_length)`.
1370
+ use_cache (:obj:`bool`, `optional`):
1371
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
1372
+ decoding (see :obj:`past_key_values`).
1373
+ Returns:
1374
+ Example::
1375
+ >>> from transformers import BertTokenizer, BertLMHeadModel, BertConfig
1376
+ >>> import torch
1377
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
1378
+ >>> config = BertConfig.from_pretrained("bert-base-cased")
1379
+ >>> model = BertLMHeadModel.from_pretrained('bert-base-cased', config=config)
1380
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
1381
+ >>> outputs = model(**inputs)
1382
+ >>> prediction_logits = outputs.logits
1383
+ """
1384
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1385
+ if labels is not None:
1386
+ use_cache = False
1387
+
1388
+ outputs = self.bert(
1389
+ input_ids,
1390
+ attention_mask=attention_mask,
1391
+ token_type_ids=token_type_ids,
1392
+ position_ids=position_ids,
1393
+ head_mask=head_mask,
1394
+ inputs_embeds=inputs_embeds,
1395
+ encoder_hidden_states=encoder_hidden_states,
1396
+ encoder_attention_mask=encoder_attention_mask,
1397
+ past_key_values=past_key_values,
1398
+ use_cache=use_cache,
1399
+ output_attentions=output_attentions,
1400
+ output_hidden_states=output_hidden_states,
1401
+ return_dict=return_dict,
1402
+ is_decoder=is_decoder,
1403
+ mode=mode,
1404
+ normalize_attention=normalize_attention,
1405
+ )
1406
+
1407
+ sequence_output = outputs[0]
1408
+ prediction_scores = self.cls(sequence_output)
1409
+
1410
+ if return_logits:
1411
+ return prediction_scores[:, :-1, :].contiguous()
1412
+
1413
+ lm_loss = None
1414
+ if labels is not None:
1415
+ # we are doing next-token prediction; shift prediction scores and input ids by one
1416
+ shifted_prediction_scores = prediction_scores[:,
1417
+ :-1, :].contiguous()
1418
+ labels = labels[:, 1:].contiguous()
1419
+ loss_fct = CrossEntropyLoss(reduction=reduction)
1420
+ lm_loss = loss_fct(
1421
+ shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1422
+ lm_loss = lm_loss.view(prediction_scores.size(0), -1).sum(1)
1423
+
1424
+ if soft_labels is not None:
1425
+ loss_distill = - \
1426
+ torch.sum(F.log_softmax(shifted_prediction_scores,
1427
+ dim=1)*soft_labels, dim=-1)
1428
+ loss_distill = (loss_distill * (labels != -100)).sum(1)
1429
+ lm_loss = (1-alpha)*lm_loss + alpha*loss_distill
1430
+
1431
+ if not return_dict:
1432
+ output = (prediction_scores,) + outputs[2:]
1433
+ return ((lm_loss,) + output) if lm_loss is not None else output
1434
+
1435
+ return CausalLMOutputWithCrossAttentions(
1436
+ loss=lm_loss,
1437
+ logits=prediction_scores,
1438
+ past_key_values=outputs.past_key_values,
1439
+ hidden_states=outputs.hidden_states,
1440
+ attentions=outputs.attentions,
1441
+ cross_attentions=outputs.cross_attentions,
1442
+ )
1443
+
1444
+ def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, **model_kwargs):
1445
+ input_shape = input_ids.shape
1446
+ # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
1447
+ if attention_mask is None:
1448
+ attention_mask = input_ids.new_ones(input_shape)
1449
+
1450
+ # cut decoder_input_ids if past is used
1451
+ if past is not None:
1452
+ input_ids = input_ids[:, -1:]
1453
+
1454
+ return {
1455
+ "input_ids": input_ids,
1456
+ "attention_mask": attention_mask,
1457
+ "past_key_values": past,
1458
+ "encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
1459
+ "encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
1460
+ "is_decoder": True,
1461
+ }
1462
+
1463
+ def _reorder_cache(self, past, beam_idx):
1464
+ reordered_past = ()
1465
+ for layer_past in past:
1466
+ reordered_past += (tuple(past_state.index_select(0, beam_idx)
1467
+ for past_state in layer_past),)
1468
+ return reordered_past
1469
+
1470
+
1471
+ @dataclass
1472
+ class MaskedLMOutputWithDistill(MaskedLMOutput):
1473
+ loss_aux: Optional[torch.FloatTensor] = None
1474
+ loss_distill: Optional[torch.FloatTensor] = None
1475
+
1476
+
1477
+ @add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
1478
+ class BertForMaskedLM(BertPreTrainedModel):
1479
+
1480
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1481
+ _keys_to_ignore_on_load_missing = [
1482
+ r"position_ids", r"predictions.decoder.bias"]
1483
+
1484
+ def __init__(self, config):
1485
+ super().__init__(config)
1486
+
1487
+ self.bert = BertModel(config, add_pooling_layer=False)
1488
+ self.cls = BertOnlyMLMHead(config)
1489
+
1490
+ self.init_weights()
1491
+
1492
+ def tie_aux_decoder_weights(self, module, aux_modules):
1493
+ """Tie decoder weights of all `aux_modules` to `module`, (not bias)"""
1494
+ for m in aux_modules:
1495
+ m.predictions.decoder.weight = module.predictions.decoder.weight
1496
+
1497
+ def get_output_embeddings(self):
1498
+ return self.cls.predictions.decoder
1499
+
1500
+ def set_output_embeddings(self, new_embeddings):
1501
+ self.cls.predictions.decoder = new_embeddings
1502
+
1503
+ def forward(
1504
+ self,
1505
+ input_ids=None,
1506
+ attention_mask=None,
1507
+ token_type_ids=None,
1508
+ position_ids=None,
1509
+ head_mask=None,
1510
+ inputs_embeds=None,
1511
+ encoder_embeds=None,
1512
+ encoder_hidden_states=None,
1513
+ encoder_attention_mask=None,
1514
+ labels=None,
1515
+ output_attentions=None,
1516
+ output_hidden_states=None,
1517
+ return_dict=None,
1518
+ is_decoder=False,
1519
+ mode='multi_modal',
1520
+ normalize_attention=True,
1521
+ soft_labels=None,
1522
+ alpha=0,
1523
+ return_logits=False,
1524
+ ):
1525
+ r"""
1526
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1527
+ Labels for computing the masked language modeling loss. Indices should be in ``[-100, 0, ...,
1528
+ config.vocab_size]`` (see ``input_ids`` docstring) Tokens with indices set to ``-100`` are ignored
1529
+ (masked), the loss is only computed for the tokens with labels in ``[0, ..., config.vocab_size]``
1530
+ """
1531
+
1532
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1533
+
1534
+ outputs = self.bert(
1535
+ input_ids,
1536
+ attention_mask=attention_mask,
1537
+ token_type_ids=token_type_ids,
1538
+ position_ids=position_ids,
1539
+ head_mask=head_mask,
1540
+ inputs_embeds=inputs_embeds,
1541
+ encoder_embeds=encoder_embeds,
1542
+ encoder_hidden_states=encoder_hidden_states,
1543
+ encoder_attention_mask=encoder_attention_mask,
1544
+ output_attentions=output_attentions,
1545
+ output_hidden_states=output_hidden_states,
1546
+ return_dict=return_dict,
1547
+ is_decoder=is_decoder,
1548
+ mode=mode,
1549
+ normalize_attention=normalize_attention
1550
+ )
1551
+
1552
+ sequence_output = outputs[0]
1553
+ prediction_scores = self.cls(sequence_output)
1554
+
1555
+ if return_logits:
1556
+ return prediction_scores
1557
+
1558
+ masked_lm_loss = None
1559
+ masked_lm_loss_aux = 0.
1560
+ if labels is not None:
1561
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
1562
+ masked_lm_loss = loss_fct(
1563
+ prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
1564
+
1565
+ if soft_labels is not None:
1566
+ loss_distill = - \
1567
+ torch.sum(F.log_softmax(prediction_scores, dim=1)
1568
+ * soft_labels, dim=-1)
1569
+ loss_distill = loss_distill[labels != -100].mean()
1570
+ masked_lm_loss = (1-alpha)*masked_lm_loss + alpha*loss_distill
1571
+
1572
+ if not return_dict:
1573
+ output = (prediction_scores,) + outputs[2:]
1574
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
1575
+
1576
+ # changed from MaskedLMOutput to MaskedLMOutputWithDistill
1577
+ return MaskedLMOutputWithDistill(
1578
+ loss=masked_lm_loss,
1579
+ loss_aux=masked_lm_loss_aux,
1580
+ logits=prediction_scores,
1581
+ hidden_states=outputs.hidden_states,
1582
+ attentions=outputs.attentions,
1583
+ )
1584
+
1585
+ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
1586
+ input_shape = input_ids.shape
1587
+ effective_batch_size = input_shape[0]
1588
+
1589
+ # add a dummy token
1590
+ assert self.config.pad_token_id is not None, "The PAD token should be defined for generation"
1591
+ attention_mask = torch.cat(
1592
+ [attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
1593
+ dummy_token = torch.full(
1594
+ (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
1595
+ )
1596
+ input_ids = torch.cat([input_ids, dummy_token], dim=1)
1597
+
1598
+ return {"input_ids": input_ids, "attention_mask": attention_mask}
1599
+
1600
+
1601
+ @add_start_docstrings(
1602
+ """Bert Model with a `next sentence prediction (classification)` head on top. """,
1603
+ BERT_START_DOCSTRING,
1604
+ )
1605
+ class BertForNextSentencePrediction(BertPreTrainedModel):
1606
+ def __init__(self, config):
1607
+ super().__init__(config)
1608
+
1609
+ self.bert = BertModel(config)
1610
+ self.cls = BertOnlyNSPHead(config)
1611
+
1612
+ self.init_weights()
1613
+
1614
+ @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1615
+ @replace_return_docstrings(output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC)
1616
+ def forward(
1617
+ self,
1618
+ input_ids=None,
1619
+ attention_mask=None,
1620
+ token_type_ids=None,
1621
+ position_ids=None,
1622
+ head_mask=None,
1623
+ inputs_embeds=None,
1624
+ labels=None,
1625
+ output_attentions=None,
1626
+ output_hidden_states=None,
1627
+ return_dict=None,
1628
+ **kwargs
1629
+ ):
1630
+ r"""
1631
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1632
+ Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
1633
+ (see ``input_ids`` docstring). Indices should be in ``[0, 1]``:
1634
+ - 0 indicates sequence B is a continuation of sequence A,
1635
+ - 1 indicates sequence B is a random sequence.
1636
+ Returns:
1637
+ Example::
1638
+ >>> from transformers import BertTokenizer, BertForNextSentencePrediction
1639
+ >>> import torch
1640
+ >>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
1641
+ >>> model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
1642
+ >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
1643
+ >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
1644
+ >>> encoding = tokenizer(prompt, next_sentence, return_tensors='pt')
1645
+ >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
1646
+ >>> logits = outputs.logits
1647
+ >>> assert logits[0, 0] < logits[0, 1] # next sentence was random
1648
+ """
1649
+
1650
+ if "next_sentence_label" in kwargs:
1651
+ warnings.warn(
1652
+ "The `next_sentence_label` argument is deprecated and will be removed in a future version, use `labels` instead.",
1653
+ FutureWarning,
1654
+ )
1655
+ labels = kwargs.pop("next_sentence_label")
1656
+
1657
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1658
+
1659
+ outputs = self.bert(
1660
+ input_ids,
1661
+ attention_mask=attention_mask,
1662
+ token_type_ids=token_type_ids,
1663
+ position_ids=position_ids,
1664
+ head_mask=head_mask,
1665
+ inputs_embeds=inputs_embeds,
1666
+ output_attentions=output_attentions,
1667
+ output_hidden_states=output_hidden_states,
1668
+ return_dict=return_dict,
1669
+ )
1670
+
1671
+ pooled_output = outputs[1]
1672
+
1673
+ seq_relationship_scores = self.cls(pooled_output)
1674
+
1675
+ next_sentence_loss = None
1676
+ if labels is not None:
1677
+ loss_fct = CrossEntropyLoss()
1678
+ next_sentence_loss = loss_fct(
1679
+ seq_relationship_scores.view(-1, 2), labels.view(-1))
1680
+
1681
+ if not return_dict:
1682
+ output = (seq_relationship_scores,) + outputs[2:]
1683
+ return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
1684
+
1685
+ return NextSentencePredictorOutput(
1686
+ loss=next_sentence_loss,
1687
+ logits=seq_relationship_scores,
1688
+ hidden_states=outputs.hidden_states,
1689
+ attentions=outputs.attentions,
1690
+ )
1691
+
1692
+
1693
+ @add_start_docstrings(
1694
+ """
1695
+ Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
1696
+ output) e.g. for GLUE tasks.
1697
+ """,
1698
+ BERT_START_DOCSTRING,
1699
+ )
1700
+ class BertForSequenceClassification(BertPreTrainedModel):
1701
+ def __init__(self, config):
1702
+ super().__init__(config)
1703
+ self.num_labels = config.num_labels
1704
+
1705
+ self.bert = BertModel(config)
1706
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1707
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1708
+
1709
+ self.init_weights()
1710
+
1711
+ def forward(
1712
+ self,
1713
+ input_ids=None,
1714
+ attention_mask=None,
1715
+ token_type_ids=None,
1716
+ position_ids=None,
1717
+ head_mask=None,
1718
+ inputs_embeds=None,
1719
+ labels=None,
1720
+ output_attentions=None,
1721
+ output_hidden_states=None,
1722
+ return_dict=None,
1723
+ ):
1724
+ r"""
1725
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1726
+ Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
1727
+ config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
1728
+ If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
1729
+ """
1730
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1731
+
1732
+ outputs = self.bert(
1733
+ input_ids,
1734
+ attention_mask=attention_mask,
1735
+ token_type_ids=token_type_ids,
1736
+ position_ids=position_ids,
1737
+ head_mask=head_mask,
1738
+ inputs_embeds=inputs_embeds,
1739
+ output_attentions=output_attentions,
1740
+ output_hidden_states=output_hidden_states,
1741
+ return_dict=return_dict,
1742
+ )
1743
+
1744
+ pooled_output = outputs[1]
1745
+
1746
+ pooled_output = self.dropout(pooled_output)
1747
+ logits = self.classifier(pooled_output)
1748
+
1749
+ loss = None
1750
+ if labels is not None:
1751
+ if self.num_labels == 1:
1752
+ # We are doing regression
1753
+ loss_fct = MSELoss()
1754
+ loss = loss_fct(logits.view(-1), labels.view(-1))
1755
+ else:
1756
+ loss_fct = CrossEntropyLoss()
1757
+ loss = loss_fct(
1758
+ logits.view(-1, self.num_labels), labels.view(-1))
1759
+
1760
+ if not return_dict:
1761
+ output = (logits,) + outputs[2:]
1762
+ return ((loss,) + output) if loss is not None else output
1763
+
1764
+ return SequenceClassifierOutput(
1765
+ loss=loss,
1766
+ logits=logits,
1767
+ hidden_states=outputs.hidden_states,
1768
+ attentions=outputs.attentions,
1769
+ )
1770
+
1771
+
1772
+ @add_start_docstrings(
1773
+ """
1774
+ Bert Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
1775
+ softmax) e.g. for RocStories/SWAG tasks.
1776
+ """,
1777
+ BERT_START_DOCSTRING,
1778
+ )
1779
+ class BertForMultipleChoice(BertPreTrainedModel):
1780
+ def __init__(self, config):
1781
+ super().__init__(config)
1782
+
1783
+ self.bert = BertModel(config)
1784
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1785
+ self.classifier = nn.Linear(config.hidden_size, 1)
1786
+
1787
+ self.init_weights()
1788
+
1789
+ def forward(
1790
+ self,
1791
+ input_ids=None,
1792
+ attention_mask=None,
1793
+ token_type_ids=None,
1794
+ position_ids=None,
1795
+ head_mask=None,
1796
+ inputs_embeds=None,
1797
+ labels=None,
1798
+ output_attentions=None,
1799
+ output_hidden_states=None,
1800
+ return_dict=None,
1801
+ ):
1802
+ r"""
1803
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1804
+ Labels for computing the multiple choice classification loss. Indices should be in ``[0, ...,
1805
+ num_choices-1]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
1806
+ :obj:`input_ids` above)
1807
+ """
1808
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1809
+ num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
1810
+
1811
+ input_ids = input_ids.view(-1, input_ids.size(-1)
1812
+ ) if input_ids is not None else None
1813
+ attention_mask = attention_mask.view(
1814
+ -1, attention_mask.size(-1)) if attention_mask is not None else None
1815
+ token_type_ids = token_type_ids.view(
1816
+ -1, token_type_ids.size(-1)) if token_type_ids is not None else None
1817
+ position_ids = position_ids.view(-1, position_ids.size(-1)
1818
+ ) if position_ids is not None else None
1819
+ inputs_embeds = (
1820
+ inputs_embeds.view(-1, inputs_embeds.size(-2),
1821
+ inputs_embeds.size(-1))
1822
+ if inputs_embeds is not None
1823
+ else None
1824
+ )
1825
+
1826
+ outputs = self.bert(
1827
+ input_ids,
1828
+ attention_mask=attention_mask,
1829
+ token_type_ids=token_type_ids,
1830
+ position_ids=position_ids,
1831
+ head_mask=head_mask,
1832
+ inputs_embeds=inputs_embeds,
1833
+ output_attentions=output_attentions,
1834
+ output_hidden_states=output_hidden_states,
1835
+ return_dict=return_dict,
1836
+ )
1837
+
1838
+ pooled_output = outputs[1]
1839
+
1840
+ pooled_output = self.dropout(pooled_output)
1841
+ logits = self.classifier(pooled_output)
1842
+ reshaped_logits = logits.view(-1, num_choices)
1843
+
1844
+ loss = None
1845
+ if labels is not None:
1846
+ loss_fct = CrossEntropyLoss()
1847
+ loss = loss_fct(reshaped_logits, labels)
1848
+
1849
+ if not return_dict:
1850
+ output = (reshaped_logits,) + outputs[2:]
1851
+ return ((loss,) + output) if loss is not None else output
1852
+
1853
+ return MultipleChoiceModelOutput(
1854
+ loss=loss,
1855
+ logits=reshaped_logits,
1856
+ hidden_states=outputs.hidden_states,
1857
+ attentions=outputs.attentions,
1858
+ )
1859
+
1860
+
1861
+ @add_start_docstrings(
1862
+ """
1863
+ Bert Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1864
+ Named-Entity-Recognition (NER) tasks.
1865
+ """,
1866
+ BERT_START_DOCSTRING,
1867
+ )
1868
+ class BertForTokenClassification(BertPreTrainedModel):
1869
+
1870
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1871
+
1872
+ def __init__(self, config):
1873
+ super().__init__(config)
1874
+ self.num_labels = config.num_labels
1875
+
1876
+ self.bert = BertModel(config, add_pooling_layer=False)
1877
+ self.dropout = nn.Dropout(config.hidden_dropout_prob)
1878
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
1879
+
1880
+ self.init_weights()
1881
+
1882
+ def forward(
1883
+ self,
1884
+ input_ids=None,
1885
+ attention_mask=None,
1886
+ token_type_ids=None,
1887
+ position_ids=None,
1888
+ head_mask=None,
1889
+ inputs_embeds=None,
1890
+ labels=None,
1891
+ output_attentions=None,
1892
+ output_hidden_states=None,
1893
+ return_dict=None,
1894
+ ):
1895
+ r"""
1896
+ labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
1897
+ Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
1898
+ 1]``.
1899
+ """
1900
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1901
+
1902
+ outputs = self.bert(
1903
+ input_ids,
1904
+ attention_mask=attention_mask,
1905
+ token_type_ids=token_type_ids,
1906
+ position_ids=position_ids,
1907
+ head_mask=head_mask,
1908
+ inputs_embeds=inputs_embeds,
1909
+ output_attentions=output_attentions,
1910
+ output_hidden_states=output_hidden_states,
1911
+ return_dict=return_dict,
1912
+ )
1913
+
1914
+ sequence_output = outputs[0]
1915
+
1916
+ sequence_output = self.dropout(sequence_output)
1917
+ logits = self.classifier(sequence_output)
1918
+
1919
+ loss = None
1920
+ if labels is not None:
1921
+ loss_fct = CrossEntropyLoss()
1922
+ # Only keep active parts of the loss
1923
+ if attention_mask is not None:
1924
+ active_loss = attention_mask.view(-1) == 1
1925
+ active_logits = logits.view(-1, self.num_labels)
1926
+ active_labels = torch.where(
1927
+ active_loss, labels.view(-1), torch.tensor(
1928
+ loss_fct.ignore_index).type_as(labels)
1929
+ )
1930
+ loss = loss_fct(active_logits, active_labels)
1931
+ else:
1932
+ loss = loss_fct(
1933
+ logits.view(-1, self.num_labels), labels.view(-1))
1934
+
1935
+ if not return_dict:
1936
+ output = (logits,) + outputs[2:]
1937
+ return ((loss,) + output) if loss is not None else output
1938
+
1939
+ return TokenClassifierOutput(
1940
+ loss=loss,
1941
+ logits=logits,
1942
+ hidden_states=outputs.hidden_states,
1943
+ attentions=outputs.attentions,
1944
+ )
1945
+
1946
+
1947
+ @add_start_docstrings(
1948
+ """
1949
+ Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
1950
+ layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
1951
+ """,
1952
+ BERT_START_DOCSTRING,
1953
+ )
1954
+ class BertForQuestionAnswering(BertPreTrainedModel):
1955
+
1956
+ _keys_to_ignore_on_load_unexpected = [r"pooler"]
1957
+
1958
+ def __init__(self, config):
1959
+ super().__init__(config)
1960
+ self.num_labels = config.num_labels
1961
+
1962
+ self.bert = BertModel(config, add_pooling_layer=False)
1963
+ self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
1964
+
1965
+ self.init_weights()
1966
+
1967
+ def forward(
1968
+ self,
1969
+ input_ids=None,
1970
+ attention_mask=None,
1971
+ token_type_ids=None,
1972
+ position_ids=None,
1973
+ head_mask=None,
1974
+ inputs_embeds=None,
1975
+ start_positions=None,
1976
+ end_positions=None,
1977
+ output_attentions=None,
1978
+ output_hidden_states=None,
1979
+ return_dict=None,
1980
+ ):
1981
+ r"""
1982
+ start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1983
+ Labels for position (index) of the start of the labelled span for computing the token classification loss.
1984
+ Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
1985
+ sequence are not taken into account for computing the loss.
1986
+ end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
1987
+ Labels for position (index) of the end of the labelled span for computing the token classification loss.
1988
+ Positions are clamped to the length of the sequence (:obj:`sequence_length`). Position outside of the
1989
+ sequence are not taken into account for computing the loss.
1990
+ """
1991
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1992
+
1993
+ outputs = self.bert(
1994
+ input_ids,
1995
+ attention_mask=attention_mask,
1996
+ token_type_ids=token_type_ids,
1997
+ position_ids=position_ids,
1998
+ head_mask=head_mask,
1999
+ inputs_embeds=inputs_embeds,
2000
+ output_attentions=output_attentions,
2001
+ output_hidden_states=output_hidden_states,
2002
+ return_dict=return_dict,
2003
+ )
2004
+
2005
+ sequence_output = outputs[0]
2006
+
2007
+ logits = self.qa_outputs(sequence_output)
2008
+ start_logits, end_logits = logits.split(1, dim=-1)
2009
+ start_logits = start_logits.squeeze(-1)
2010
+ end_logits = end_logits.squeeze(-1)
2011
+
2012
+ total_loss = None
2013
+ if start_positions is not None and end_positions is not None:
2014
+ # If we are on multi-GPU, split add a dimension
2015
+ if len(start_positions.size()) > 1:
2016
+ start_positions = start_positions.squeeze(-1)
2017
+ if len(end_positions.size()) > 1:
2018
+ end_positions = end_positions.squeeze(-1)
2019
+ # sometimes the start/end positions are outside our model inputs, we ignore these terms
2020
+ ignored_index = start_logits.size(1)
2021
+ start_positions.clamp_(0, ignored_index)
2022
+ end_positions.clamp_(0, ignored_index)
2023
+
2024
+ loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
2025
+ start_loss = loss_fct(start_logits, start_positions)
2026
+ end_loss = loss_fct(end_logits, end_positions)
2027
+ total_loss = (start_loss + end_loss) / 2
2028
+
2029
+ if not return_dict:
2030
+ output = (start_logits, end_logits) + outputs[2:]
2031
+ return ((total_loss,) + output) if total_loss is not None else output
2032
+
2033
+ return QuestionAnsweringModelOutput(
2034
+ loss=total_loss,
2035
+ start_logits=start_logits,
2036
+ end_logits=end_logits,
2037
+ hidden_states=outputs.hidden_states,
2038
+ attentions=outputs.attentions,
2039
+ )
svitt/tokenization_bert.py ADDED
@@ -0,0 +1,546 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes for Bert."""
16
+
17
+
18
+ import collections
19
+ import os
20
+ import unicodedata
21
+ from typing import List, Optional, Tuple
22
+
23
+ from transformers.tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
24
+ from transformers.utils import logging
25
+
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
30
+
31
+ PRETRAINED_VOCAB_FILES_MAP = {
32
+ "vocab_file": {
33
+ "bert-base-uncased": "https://huggingface.co/bert-base-uncased/resolve/main/vocab.txt",
34
+ "bert-large-uncased": "https://huggingface.co/bert-large-uncased/resolve/main/vocab.txt",
35
+ "bert-base-cased": "https://huggingface.co/bert-base-cased/resolve/main/vocab.txt",
36
+ "bert-large-cased": "https://huggingface.co/bert-large-cased/resolve/main/vocab.txt",
37
+ "bert-base-multilingual-uncased": "https://huggingface.co/bert-base-multilingual-uncased/resolve/main/vocab.txt",
38
+ "bert-base-multilingual-cased": "https://huggingface.co/bert-base-multilingual-cased/resolve/main/vocab.txt",
39
+ "bert-base-chinese": "https://huggingface.co/bert-base-chinese/resolve/main/vocab.txt",
40
+ "bert-base-german-cased": "https://huggingface.co/bert-base-german-cased/resolve/main/vocab.txt",
41
+ "bert-large-uncased-whole-word-masking": "https://huggingface.co/bert-large-uncased-whole-word-masking/resolve/main/vocab.txt",
42
+ "bert-large-cased-whole-word-masking": "https://huggingface.co/bert-large-cased-whole-word-masking/resolve/main/vocab.txt",
43
+ "bert-large-uncased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-uncased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt",
44
+ "bert-large-cased-whole-word-masking-finetuned-squad": "https://huggingface.co/bert-large-cased-whole-word-masking-finetuned-squad/resolve/main/vocab.txt",
45
+ "bert-base-cased-finetuned-mrpc": "https://huggingface.co/bert-base-cased-finetuned-mrpc/resolve/main/vocab.txt",
46
+ "bert-base-german-dbmdz-cased": "https://huggingface.co/bert-base-german-dbmdz-cased/resolve/main/vocab.txt",
47
+ "bert-base-german-dbmdz-uncased": "https://huggingface.co/bert-base-german-dbmdz-uncased/resolve/main/vocab.txt",
48
+ "TurkuNLP/bert-base-finnish-cased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-cased-v1/resolve/main/vocab.txt",
49
+ "TurkuNLP/bert-base-finnish-uncased-v1": "https://huggingface.co/TurkuNLP/bert-base-finnish-uncased-v1/resolve/main/vocab.txt",
50
+ "wietsedv/bert-base-dutch-cased": "https://huggingface.co/wietsedv/bert-base-dutch-cased/resolve/main/vocab.txt",
51
+ }
52
+ }
53
+
54
+ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
55
+ "bert-base-uncased": 512,
56
+ "bert-large-uncased": 512,
57
+ "bert-base-cased": 512,
58
+ "bert-large-cased": 512,
59
+ "bert-base-multilingual-uncased": 512,
60
+ "bert-base-multilingual-cased": 512,
61
+ "bert-base-chinese": 512,
62
+ "bert-base-german-cased": 512,
63
+ "bert-large-uncased-whole-word-masking": 512,
64
+ "bert-large-cased-whole-word-masking": 512,
65
+ "bert-large-uncased-whole-word-masking-finetuned-squad": 512,
66
+ "bert-large-cased-whole-word-masking-finetuned-squad": 512,
67
+ "bert-base-cased-finetuned-mrpc": 512,
68
+ "bert-base-german-dbmdz-cased": 512,
69
+ "bert-base-german-dbmdz-uncased": 512,
70
+ "TurkuNLP/bert-base-finnish-cased-v1": 512,
71
+ "TurkuNLP/bert-base-finnish-uncased-v1": 512,
72
+ "wietsedv/bert-base-dutch-cased": 512,
73
+ }
74
+
75
+ PRETRAINED_INIT_CONFIGURATION = {
76
+ "bert-base-uncased": {"do_lower_case": True},
77
+ "bert-large-uncased": {"do_lower_case": True},
78
+ "bert-base-cased": {"do_lower_case": False},
79
+ "bert-large-cased": {"do_lower_case": False},
80
+ "bert-base-multilingual-uncased": {"do_lower_case": True},
81
+ "bert-base-multilingual-cased": {"do_lower_case": False},
82
+ "bert-base-chinese": {"do_lower_case": False},
83
+ "bert-base-german-cased": {"do_lower_case": False},
84
+ "bert-large-uncased-whole-word-masking": {"do_lower_case": True},
85
+ "bert-large-cased-whole-word-masking": {"do_lower_case": False},
86
+ "bert-large-uncased-whole-word-masking-finetuned-squad": {"do_lower_case": True},
87
+ "bert-large-cased-whole-word-masking-finetuned-squad": {"do_lower_case": False},
88
+ "bert-base-cased-finetuned-mrpc": {"do_lower_case": False},
89
+ "bert-base-german-dbmdz-cased": {"do_lower_case": False},
90
+ "bert-base-german-dbmdz-uncased": {"do_lower_case": True},
91
+ "TurkuNLP/bert-base-finnish-cased-v1": {"do_lower_case": False},
92
+ "TurkuNLP/bert-base-finnish-uncased-v1": {"do_lower_case": True},
93
+ "wietsedv/bert-base-dutch-cased": {"do_lower_case": False},
94
+ }
95
+
96
+
97
+ def load_vocab(vocab_file):
98
+ """Loads a vocabulary file into a dictionary."""
99
+ vocab = collections.OrderedDict()
100
+ with open(vocab_file, "r", encoding="utf-8") as reader:
101
+ tokens = reader.readlines()
102
+ for index, token in enumerate(tokens):
103
+ token = token.rstrip("\n")
104
+ vocab[token] = index
105
+ return vocab
106
+
107
+
108
+ def whitespace_tokenize(text):
109
+ """Runs basic whitespace cleaning and splitting on a piece of text."""
110
+ text = text.strip()
111
+ if not text:
112
+ return []
113
+ tokens = text.split()
114
+ return tokens
115
+
116
+
117
+ class BertTokenizer(PreTrainedTokenizer):
118
+ r"""
119
+ Construct a BERT tokenizer. Based on WordPiece.
120
+ This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods.
121
+ Users should refer to this superclass for more information regarding those methods.
122
+ Args:
123
+ vocab_file (:obj:`str`):
124
+ File containing the vocabulary.
125
+ do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
126
+ Whether or not to lowercase the input when tokenizing.
127
+ do_basic_tokenize (:obj:`bool`, `optional`, defaults to :obj:`True`):
128
+ Whether or not to do basic tokenization before WordPiece.
129
+ never_split (:obj:`Iterable`, `optional`):
130
+ Collection of tokens which will never be split during tokenization. Only has an effect when
131
+ :obj:`do_basic_tokenize=True`
132
+ unk_token (:obj:`str`, `optional`, defaults to :obj:`"[UNK]"`):
133
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
134
+ token instead.
135
+ sep_token (:obj:`str`, `optional`, defaults to :obj:`"[SEP]"`):
136
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for
137
+ sequence classification or for a text and a question for question answering. It is also used as the last
138
+ token of a sequence built with special tokens.
139
+ pad_token (:obj:`str`, `optional`, defaults to :obj:`"[PAD]"`):
140
+ The token used for padding, for example when batching sequences of different lengths.
141
+ cls_token (:obj:`str`, `optional`, defaults to :obj:`"[CLS]"`):
142
+ The classifier token which is used when doing sequence classification (classification of the whole sequence
143
+ instead of per-token classification). It is the first token of the sequence when built with special tokens.
144
+ mask_token (:obj:`str`, `optional`, defaults to :obj:`"[MASK]"`):
145
+ The token used for masking values. This is the token used when training this model with masked language
146
+ modeling. This is the token which the model will try to predict.
147
+ tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
148
+ Whether or not to tokenize Chinese characters.
149
+ This should likely be deactivated for Japanese (see this `issue
150
+ <https://github.com/huggingface/transformers/issues/328>`__).
151
+ strip_accents: (:obj:`bool`, `optional`):
152
+ Whether or not to strip all accents. If this option is not specified, then it will be determined by the
153
+ value for :obj:`lowercase` (as in the original BERT).
154
+ """
155
+
156
+ vocab_files_names = VOCAB_FILES_NAMES
157
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
158
+ pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
159
+ max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
160
+
161
+ def __init__(
162
+ self,
163
+ vocab_file,
164
+ do_lower_case=True,
165
+ do_basic_tokenize=True,
166
+ never_split=None,
167
+ unk_token="[UNK]",
168
+ sep_token="[SEP]",
169
+ pad_token="[PAD]",
170
+ cls_token="[CLS]",
171
+ mask_token="[MASK]",
172
+ tokenize_chinese_chars=True,
173
+ strip_accents=None,
174
+ **kwargs
175
+ ):
176
+ super().__init__(
177
+ do_lower_case=do_lower_case,
178
+ do_basic_tokenize=do_basic_tokenize,
179
+ never_split=never_split,
180
+ unk_token=unk_token,
181
+ sep_token=sep_token,
182
+ pad_token=pad_token,
183
+ cls_token=cls_token,
184
+ mask_token=mask_token,
185
+ tokenize_chinese_chars=tokenize_chinese_chars,
186
+ strip_accents=strip_accents,
187
+ **kwargs,
188
+ )
189
+
190
+ if not os.path.isfile(vocab_file):
191
+ raise ValueError(
192
+ "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
193
+ "model use `tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(
194
+ vocab_file)
195
+ )
196
+ self.vocab = load_vocab(vocab_file)
197
+ self.ids_to_tokens = collections.OrderedDict(
198
+ [(ids, tok) for tok, ids in self.vocab.items()])
199
+ self.do_basic_tokenize = do_basic_tokenize
200
+ if do_basic_tokenize:
201
+ self.basic_tokenizer = BasicTokenizer(
202
+ do_lower_case=do_lower_case,
203
+ never_split=never_split,
204
+ tokenize_chinese_chars=tokenize_chinese_chars,
205
+ strip_accents=strip_accents,
206
+ )
207
+ self.wordpiece_tokenizer = WordpieceTokenizer(
208
+ vocab=self.vocab, unk_token=self.unk_token)
209
+
210
+ @property
211
+ def do_lower_case(self):
212
+ return self.basic_tokenizer.do_lower_case
213
+
214
+ @property
215
+ def vocab_size(self):
216
+ return len(self.vocab)
217
+
218
+ def get_vocab(self):
219
+ return dict(self.vocab, **self.added_tokens_encoder)
220
+
221
+ def _tokenize(self, text):
222
+ split_tokens = []
223
+ if self.do_basic_tokenize:
224
+ for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
225
+
226
+ # If the token is part of the never_split set
227
+ if token in self.basic_tokenizer.never_split:
228
+ split_tokens.append(token)
229
+ else:
230
+ split_tokens += self.wordpiece_tokenizer.tokenize(token)
231
+ else:
232
+ split_tokens = self.wordpiece_tokenizer.tokenize(text)
233
+ return split_tokens
234
+
235
+ def _convert_token_to_id(self, token):
236
+ """ Converts a token (str) in an id using the vocab. """
237
+ return self.vocab.get(token, self.vocab.get(self.unk_token))
238
+
239
+ def _convert_id_to_token(self, index):
240
+ """Converts an index (integer) in a token (str) using the vocab."""
241
+ return self.ids_to_tokens.get(index, self.unk_token)
242
+
243
+ def convert_tokens_to_string(self, tokens):
244
+ """ Converts a sequence of tokens (string) in a single string. """
245
+ out_string = " ".join(tokens).replace(" ##", "").strip()
246
+ return out_string
247
+
248
+ def build_inputs_with_special_tokens(
249
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
250
+ ) -> List[int]:
251
+ """
252
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
253
+ adding special tokens. A BERT sequence has the following format:
254
+ - single sequence: ``[CLS] X ``
255
+ - pair of sequences: ``[CLS] A [SEP] B [SEP]``
256
+ Args:
257
+ token_ids_0 (:obj:`List[int]`):
258
+ List of IDs to which the special tokens will be added.
259
+ token_ids_1 (:obj:`List[int]`, `optional`):
260
+ Optional second list of IDs for sequence pairs.
261
+ Returns:
262
+ :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
263
+ """
264
+ if token_ids_1 is None:
265
+ return [self.cls_token_id] + token_ids_0
266
+ cls = [self.cls_token_id]
267
+ sep = [self.sep_token_id]
268
+ return cls + token_ids_0 + sep + token_ids_1 + sep
269
+
270
+ def get_special_tokens_mask(
271
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
272
+ ) -> List[int]:
273
+ """
274
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
275
+ special tokens using the tokenizer ``prepare_for_model`` method.
276
+ Args:
277
+ token_ids_0 (:obj:`List[int]`):
278
+ List of IDs.
279
+ token_ids_1 (:obj:`List[int]`, `optional`):
280
+ Optional second list of IDs for sequence pairs.
281
+ already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
282
+ Whether or not the token list is already formatted with special tokens for the model.
283
+ Returns:
284
+ :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
285
+ """
286
+
287
+ if already_has_special_tokens:
288
+ if token_ids_1 is not None:
289
+ raise ValueError(
290
+ "You should not supply a second sequence if the provided sequence of "
291
+ "ids is already formatted with special tokens for the model."
292
+ )
293
+ return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
294
+
295
+ if token_ids_1 is not None:
296
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
297
+ return [1] + ([0] * len(token_ids_0)) + [1]
298
+
299
+ def create_token_type_ids_from_sequences(
300
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
301
+ ) -> List[int]:
302
+ """
303
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. A BERT sequence
304
+ pair mask has the following format:
305
+ ::
306
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
307
+ | first sequence | second sequence |
308
+ If :obj:`token_ids_1` is :obj:`None`, this method only returns the first portion of the mask (0s).
309
+ Args:
310
+ token_ids_0 (:obj:`List[int]`):
311
+ List of IDs.
312
+ token_ids_1 (:obj:`List[int]`, `optional`):
313
+ Optional second list of IDs for sequence pairs.
314
+ Returns:
315
+ :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
316
+ sequence(s).
317
+ """
318
+ sep = [self.sep_token_id]
319
+ cls = [self.cls_token_id]
320
+ if token_ids_1 is None:
321
+ return len(cls + token_ids_0 + sep) * [0]
322
+ return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
323
+
324
+ def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
325
+ index = 0
326
+ if os.path.isdir(save_directory):
327
+ vocab_file = os.path.join(
328
+ save_directory, (filename_prefix + "-" if filename_prefix else "") +
329
+ VOCAB_FILES_NAMES["vocab_file"]
330
+ )
331
+ else:
332
+ vocab_file = (filename_prefix +
333
+ "-" if filename_prefix else "") + save_directory
334
+ with open(vocab_file, "w", encoding="utf-8") as writer:
335
+ for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
336
+ if index != token_index:
337
+ logger.warning(
338
+ "Saving vocabulary to {}: vocabulary indices are not consecutive."
339
+ " Please check that the vocabulary is not corrupted!".format(
340
+ vocab_file)
341
+ )
342
+ index = token_index
343
+ writer.write(token + "\n")
344
+ index += 1
345
+ return (vocab_file,)
346
+
347
+
348
+ class BasicTokenizer(object):
349
+ """
350
+ Constructs a BasicTokenizer that will run basic tokenization (punctuation splitting, lower casing, etc.).
351
+ Args:
352
+ do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`True`):
353
+ Whether or not to lowercase the input when tokenizing.
354
+ never_split (:obj:`Iterable`, `optional`):
355
+ Collection of tokens which will never be split during tokenization. Only has an effect when
356
+ :obj:`do_basic_tokenize=True`
357
+ tokenize_chinese_chars (:obj:`bool`, `optional`, defaults to :obj:`True`):
358
+ Whether or not to tokenize Chinese characters.
359
+ This should likely be deactivated for Japanese (see this `issue
360
+ <https://github.com/huggingface/transformers/issues/328>`__).
361
+ strip_accents: (:obj:`bool`, `optional`):
362
+ Whether or not to strip all accents. If this option is not specified, then it will be determined by the
363
+ value for :obj:`lowercase` (as in the original BERT).
364
+ """
365
+
366
+ def __init__(self, do_lower_case=True, never_split=None, tokenize_chinese_chars=True, strip_accents=None):
367
+ if never_split is None:
368
+ never_split = []
369
+ self.do_lower_case = do_lower_case
370
+ self.never_split = set(never_split)
371
+ self.tokenize_chinese_chars = tokenize_chinese_chars
372
+ self.strip_accents = strip_accents
373
+
374
+ def tokenize(self, text, never_split=None):
375
+ """
376
+ Basic Tokenization of a piece of text. Split on "white spaces" only, for sub-word tokenization, see
377
+ WordPieceTokenizer.
378
+ Args:
379
+ **never_split**: (`optional`) list of str
380
+ Kept for backward compatibility purposes. Now implemented directly at the base class level (see
381
+ :func:`PreTrainedTokenizer.tokenize`) List of token not to split.
382
+ """
383
+ # union() returns a new set by concatenating the two sets.
384
+ never_split = self.never_split.union(
385
+ set(never_split)) if never_split else self.never_split
386
+ text = self._clean_text(text)
387
+
388
+ # This was added on November 1st, 2018 for the multilingual and Chinese
389
+ # models. This is also applied to the English models now, but it doesn't
390
+ # matter since the English models were not trained on any Chinese data
391
+ # and generally don't have any Chinese data in them (there are Chinese
392
+ # characters in the vocabulary because Wikipedia does have some Chinese
393
+ # words in the English Wikipedia.).
394
+ if self.tokenize_chinese_chars:
395
+ text = self._tokenize_chinese_chars(text)
396
+ orig_tokens = whitespace_tokenize(text)
397
+ split_tokens = []
398
+ for token in orig_tokens:
399
+ if token not in never_split:
400
+ if self.do_lower_case:
401
+ token = token.lower()
402
+ if self.strip_accents is not False:
403
+ token = self._run_strip_accents(token)
404
+ elif self.strip_accents:
405
+ token = self._run_strip_accents(token)
406
+ split_tokens.extend(self._run_split_on_punc(token, never_split))
407
+
408
+ output_tokens = whitespace_tokenize(" ".join(split_tokens))
409
+ return output_tokens
410
+
411
+ def _run_strip_accents(self, text):
412
+ """Strips accents from a piece of text."""
413
+ text = unicodedata.normalize("NFD", text)
414
+ output = []
415
+ for char in text:
416
+ cat = unicodedata.category(char)
417
+ if cat == "Mn":
418
+ continue
419
+ output.append(char)
420
+ return "".join(output)
421
+
422
+ def _run_split_on_punc(self, text, never_split=None):
423
+ """Splits punctuation on a piece of text."""
424
+ if never_split is not None and text in never_split:
425
+ return [text]
426
+ chars = list(text)
427
+ i = 0
428
+ start_new_word = True
429
+ output = []
430
+ while i < len(chars):
431
+ char = chars[i]
432
+ if _is_punctuation(char):
433
+ output.append([char])
434
+ start_new_word = True
435
+ else:
436
+ if start_new_word:
437
+ output.append([])
438
+ start_new_word = False
439
+ output[-1].append(char)
440
+ i += 1
441
+
442
+ return ["".join(x) for x in output]
443
+
444
+ def _tokenize_chinese_chars(self, text):
445
+ """Adds whitespace around any CJK character."""
446
+ output = []
447
+ for char in text:
448
+ cp = ord(char)
449
+ if self._is_chinese_char(cp):
450
+ output.append(" ")
451
+ output.append(char)
452
+ output.append(" ")
453
+ else:
454
+ output.append(char)
455
+ return "".join(output)
456
+
457
+ def _is_chinese_char(self, cp):
458
+ """Checks whether CP is the codepoint of a CJK character."""
459
+ # This defines a "chinese character" as anything in the CJK Unicode block:
460
+ # https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
461
+ #
462
+ # Note that the CJK Unicode block is NOT all Japanese and Korean characters,
463
+ # despite its name. The modern Korean Hangul alphabet is a different block,
464
+ # as is Japanese Hiragana and Katakana. Those alphabets are used to write
465
+ # space-separated words, so they are not treated specially and handled
466
+ # like the all of the other languages.
467
+ if (
468
+ (cp >= 0x4E00 and cp <= 0x9FFF)
469
+ or (cp >= 0x3400 and cp <= 0x4DBF) #
470
+ or (cp >= 0x20000 and cp <= 0x2A6DF) #
471
+ or (cp >= 0x2A700 and cp <= 0x2B73F) #
472
+ or (cp >= 0x2B740 and cp <= 0x2B81F) #
473
+ or (cp >= 0x2B820 and cp <= 0x2CEAF) #
474
+ or (cp >= 0xF900 and cp <= 0xFAFF)
475
+ or (cp >= 0x2F800 and cp <= 0x2FA1F) #
476
+ ): #
477
+ return True
478
+
479
+ return False
480
+
481
+ def _clean_text(self, text):
482
+ """Performs invalid character removal and whitespace cleanup on text."""
483
+ output = []
484
+ for char in text:
485
+ cp = ord(char)
486
+ if cp == 0 or cp == 0xFFFD or _is_control(char):
487
+ continue
488
+ if _is_whitespace(char):
489
+ output.append(" ")
490
+ else:
491
+ output.append(char)
492
+ return "".join(output)
493
+
494
+
495
+ class WordpieceTokenizer(object):
496
+ """Runs WordPiece tokenization."""
497
+
498
+ def __init__(self, vocab, unk_token, max_input_chars_per_word=100):
499
+ self.vocab = vocab
500
+ self.unk_token = unk_token
501
+ self.max_input_chars_per_word = max_input_chars_per_word
502
+
503
+ def tokenize(self, text):
504
+ """
505
+ Tokenizes a piece of text into its word pieces. This uses a greedy longest-match-first algorithm to perform
506
+ tokenization using the given vocabulary.
507
+ For example, :obj:`input = "unaffable"` wil return as output :obj:`["un", "##aff", "##able"]`.
508
+ Args:
509
+ text: A single token or whitespace separated tokens. This should have
510
+ already been passed through `BasicTokenizer`.
511
+ Returns:
512
+ A list of wordpiece tokens.
513
+ """
514
+
515
+ output_tokens = []
516
+ for token in whitespace_tokenize(text):
517
+ chars = list(token)
518
+ if len(chars) > self.max_input_chars_per_word:
519
+ output_tokens.append(self.unk_token)
520
+ continue
521
+
522
+ is_bad = False
523
+ start = 0
524
+ sub_tokens = []
525
+ while start < len(chars):
526
+ end = len(chars)
527
+ cur_substr = None
528
+ while start < end:
529
+ substr = "".join(chars[start:end])
530
+ if start > 0:
531
+ substr = "##" + substr
532
+ if substr in self.vocab:
533
+ cur_substr = substr
534
+ break
535
+ end -= 1
536
+ if cur_substr is None:
537
+ is_bad = True
538
+ break
539
+ sub_tokens.append(cur_substr)
540
+ start = end
541
+
542
+ if is_bad:
543
+ output_tokens.append(self.unk_token)
544
+ else:
545
+ output_tokens.extend(sub_tokens)
546
+ return output_tokens
svitt/utils.py ADDED
@@ -0,0 +1,235 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torch.nn as nn
4
+ from scipy import interpolate
5
+ import numpy as np
6
+ from einops import rearrange, repeat
7
+
8
+
9
+ def _init_transformer_weights(module, initializer_range=0.02):
10
+ """Initialize the weights. Copied from transformers ViT/Bert model init"""
11
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
12
+ # Slightly different from the TF version which uses truncated_normal for initialization
13
+ # cf https://github.com/pytorch/pytorch/pull/5617
14
+ module.weight.data.normal_(mean=0.0, std=initializer_range)
15
+ if module.bias is not None:
16
+ module.bias.data.zero_()
17
+ elif isinstance(module, nn.Embedding):
18
+ module.weight.data.normal_(mean=0.0, std=initializer_range)
19
+ if module.padding_idx is not None:
20
+ module.weight.data[module.padding_idx].zero_()
21
+ elif isinstance(module, nn.LayerNorm):
22
+ module.bias.data.zero_()
23
+ module.weight.data.fill_(1.0)
24
+
25
+
26
+ def interpolate_pos_embed(pos_embed_old, pos_embed_new, num_patches_new):
27
+ """
28
+ Args:
29
+ pos_embed_old: (1, L_old, d), pre-trained
30
+ pos_embed_new: (1, L_new, d), newly initialized, to be replaced by interpolated weights
31
+ num_patches_new:
32
+ """
33
+ # interpolate position embedding
34
+ embedding_size = pos_embed_old.shape[-1]
35
+ num_extra_tokens = pos_embed_new.shape[-2] - num_patches_new
36
+ # height (== width) for the checkpoint position embedding
37
+ orig_size = int((pos_embed_old.shape[-2] - num_extra_tokens) ** 0.5)
38
+ # height (== width) for the new position embedding
39
+ new_size = int(num_patches_new ** 0.5)
40
+
41
+ if orig_size != new_size:
42
+ # class_token and dist_token are kept unchanged
43
+ # the extra tokens seems always at the beginning of the position embedding
44
+ extra_tokens = pos_embed_old[:, :num_extra_tokens]
45
+ # only the position tokens are interpolated
46
+ pos_tokens = pos_embed_old[:, num_extra_tokens:]
47
+ pos_tokens = pos_tokens.reshape(
48
+ -1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
49
+ pos_tokens = torch.nn.functional.interpolate(
50
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
51
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
52
+ interpolated_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
53
+ return interpolated_pos_embed
54
+ else:
55
+ return pos_embed_old
56
+
57
+
58
+ def interpolate_pos_relative_bias_beit(state_dict_old, state_dict_new, patch_shape_new):
59
+ """
60
+ Args:
61
+ state_dict_old: loaded state dict
62
+ state_dict_new: state dict for model with new image size
63
+ patch_shape_new: new model patch_shape
64
+ ref: https://github.com/microsoft/unilm/blob/master/beit/run_class_finetuning.py
65
+ """
66
+ all_keys = list(state_dict_old.keys())
67
+ for key in all_keys:
68
+ if "relative_position_index" in key:
69
+ state_dict_old.pop(key)
70
+
71
+ if "relative_position_bias_table" in key:
72
+ rel_pos_bias = state_dict_old[key]
73
+ src_num_pos, num_attn_heads = rel_pos_bias.size()
74
+ dst_num_pos, _ = state_dict_new[key].size()
75
+ dst_patch_shape = patch_shape_new
76
+ if dst_patch_shape[0] != dst_patch_shape[1]:
77
+ raise NotImplementedError()
78
+ num_extra_tokens = dst_num_pos - (dst_patch_shape[0] * 2 - 1) * (dst_patch_shape[1] * 2 - 1)
79
+ src_size = int((src_num_pos - num_extra_tokens) ** 0.5)
80
+ dst_size = int((dst_num_pos - num_extra_tokens) ** 0.5)
81
+ if src_size != dst_size:
82
+ extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
83
+ rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
84
+
85
+ def geometric_progression(a, r, n):
86
+ return a * (1.0 - r ** n) / (1.0 - r)
87
+
88
+ left, right = 1.01, 1.5
89
+ while right - left > 1e-6:
90
+ q = (left + right) / 2.0
91
+ gp = geometric_progression(1, q, src_size // 2)
92
+ if gp > dst_size // 2:
93
+ right = q
94
+ else:
95
+ left = q
96
+
97
+ # if q > 1.090307:
98
+ # q = 1.090307
99
+
100
+ dis = []
101
+ cur = 1
102
+ for i in range(src_size // 2):
103
+ dis.append(cur)
104
+ cur += q ** (i + 1)
105
+
106
+ r_ids = [-_ for _ in reversed(dis)]
107
+
108
+ x = r_ids + [0] + dis
109
+ y = r_ids + [0] + dis
110
+
111
+ t = dst_size // 2.0
112
+ dx = np.arange(-t, t + 0.1, 1.0)
113
+ dy = np.arange(-t, t + 0.1, 1.0)
114
+
115
+ all_rel_pos_bias = []
116
+
117
+ for i in range(num_attn_heads):
118
+ z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()
119
+ f = interpolate.interp2d(x, y, z, kind='cubic')
120
+ all_rel_pos_bias.append(
121
+ torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device))
122
+
123
+ rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
124
+
125
+ new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)
126
+ state_dict_old[key] = new_rel_pos_bias
127
+ return state_dict_old
128
+
129
+
130
+ def interpolate_pos_relative_bias_beit_3d(state_dict_old, state_dict_new, patch_shape_new, src_t_size=1):
131
+ """
132
+ Args:
133
+ state_dict_old: loaded state dict
134
+ state_dict_new: state dict for model with new image size
135
+ patch_shape_new: new model patch_shape
136
+ ref: https://github.com/microsoft/unilm/blob/master/beit/run_class_finetuning.py
137
+ """
138
+ all_keys = list(state_dict_old.keys())
139
+ for key in all_keys:
140
+ if "relative_position_index" in key:
141
+ state_dict_old.pop(key)
142
+
143
+ if "relative_position_bias_table" in key:
144
+ src_num_pos, num_attn_heads = state_dict_old[key].size()
145
+ dst_num_pos, _ = state_dict_new[key].size()
146
+ if src_num_pos == dst_num_pos:
147
+ continue
148
+
149
+ num_extra_tokens = dst_num_pos - np.prod([w * 2 - 1 for w in patch_shape_new])
150
+
151
+ src_s_size = int((src_num_pos - num_extra_tokens) / src_t_size)
152
+ src_size = int(src_s_size ** 0.5)
153
+ dst_size = patch_shape_new[-1] * 2 - 1
154
+
155
+ if src_size != dst_size:
156
+ # Spatial interpolation
157
+ rel_pos_bias = state_dict_old[key]
158
+ extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
159
+ rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
160
+
161
+ def geometric_progression(a, r, n):
162
+ return a * (1.0 - r ** n) / (1.0 - r)
163
+
164
+ left, right = 1.01, 1.5
165
+ while right - left > 1e-6:
166
+ q = (left + right) / 2.0
167
+ gp = geometric_progression(1, q, src_size // 2)
168
+ if gp > dst_size // 2:
169
+ right = q
170
+ else:
171
+ left = q
172
+
173
+ # if q > 1.090307:
174
+ # q = 1.090307
175
+
176
+ dis = []
177
+ cur = 1
178
+ for i in range(src_size // 2):
179
+ dis.append(cur)
180
+ cur += q ** (i + 1)
181
+
182
+ r_ids = [-_ for _ in reversed(dis)]
183
+
184
+ x = r_ids + [0] + dis
185
+ y = r_ids + [0] + dis
186
+
187
+ t = dst_size // 2.0
188
+ dx = np.arange(-t, t + 0.1, 1.0)
189
+ dy = np.arange(-t, t + 0.1, 1.0)
190
+
191
+ all_rel_pos_bias = []
192
+
193
+ for i in range(num_attn_heads):
194
+ z = rel_pos_bias[:, i].view(src_size, src_size).float().numpy()
195
+ f = interpolate.interp2d(x, y, z, kind='cubic')
196
+ all_rel_pos_bias.append(
197
+ torch.Tensor(f(dx, dy)).contiguous().view(-1, 1).to(rel_pos_bias.device))
198
+
199
+ rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
200
+
201
+ new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)
202
+ state_dict_old[key] = new_rel_pos_bias
203
+
204
+ dst_t_size = patch_shape_new[0] * 2 - 1
205
+ if src_t_size != dst_t_size:
206
+ # Temporal interpolation
207
+ rel_pos_bias = state_dict_old[key]
208
+ extra_tokens = rel_pos_bias[-num_extra_tokens:, :]
209
+ rel_pos_bias = rel_pos_bias[:-num_extra_tokens, :]
210
+
211
+ if src_t_size == 1:
212
+ rel_pos_bias = repeat(rel_pos_bias, 's d -> (t s) d', t=dst_t_size)
213
+ else:
214
+ rel_pos_bias = rearrange(rel_pos_bias, '(t s) d -> s d t', t=src_t_size)
215
+ rel_pos_bias = F.interpolate(rel_pos_bias, dst_t_size, mode='nearest')
216
+ rel_pos_bias = rearrange(rel_pos_bias, 's d t -> (t s) d')
217
+ new_rel_pos_bias = torch.cat((rel_pos_bias, extra_tokens), dim=0)
218
+ state_dict_old[key] = new_rel_pos_bias
219
+
220
+ return state_dict_old
221
+
222
+
223
+ def tile(x, dim, n_tile):
224
+ init_dim = x.size(dim)
225
+ repeat_idx = [1] * x.dim()
226
+ repeat_idx[dim] = n_tile
227
+ x = x.repeat(*repeat_idx)
228
+ order_index = torch.LongTensor(np.concatenate(
229
+ [init_dim * np.arange(n_tile) + i for i in range(init_dim)]))
230
+ return torch.index_select(x, dim, order_index.to(x.device))
231
+
232
+
233
+ def mask_logits(target, mask):
234
+ return target * mask + (1 - mask) * (-1e10)
235
+
svitt/video_transforms.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import math
8
+ from typing import Sequence
9
+ import torch
10
+ import torch.nn as nn
11
+ from torchvision import transforms
12
+
13
+
14
+ class Permute(nn.Module):
15
+ """
16
+ Permutation as an op
17
+ """
18
+
19
+ def __init__(self, ordering):
20
+ super().__init__()
21
+ self.ordering = ordering
22
+
23
+ def forward(self, frames):
24
+ """
25
+ Args:
26
+ frames in some ordering, by default (C, T, H, W)
27
+ Returns:
28
+ frames in the ordering that was specified
29
+ """
30
+ return frames.permute(self.ordering)
31
+
32
+
33
+ class TemporalCrop(nn.Module):
34
+ """
35
+ Convert the video into smaller clips temporally.
36
+ """
37
+
38
+ def __init__(
39
+ self, frames_per_clip: int = 8, stride: int = 8, frame_stride: int = 1
40
+ ):
41
+ super().__init__()
42
+ self.frames = frames_per_clip
43
+ self.stride = stride
44
+ self.frame_stride = frame_stride
45
+
46
+ def forward(self, video):
47
+ assert video.ndim == 4, "Must be (C, T, H, W)"
48
+ res = []
49
+ for start in range(
50
+ 0, video.size(1) - (self.frames * self.frame_stride) + 1, self.stride
51
+ ):
52
+ end = start + (self.frames) * self.frame_stride
53
+ res.append(video[:, start: end: self.frame_stride, ...])
54
+ return res
55
+
56
+
57
+ def crop_boxes(boxes, x_offset, y_offset):
58
+ """
59
+ Peform crop on the bounding boxes given the offsets.
60
+ Args:
61
+ boxes (ndarray or None): bounding boxes to peform crop. The dimension
62
+ is `num boxes` x 4.
63
+ x_offset (int): cropping offset in the x axis.
64
+ y_offset (int): cropping offset in the y axis.
65
+ Returns:
66
+ cropped_boxes (ndarray or None): the cropped boxes with dimension of
67
+ `num boxes` x 4.
68
+ """
69
+ cropped_boxes = boxes.copy()
70
+ cropped_boxes[:, [0, 2]] = boxes[:, [0, 2]] - x_offset
71
+ cropped_boxes[:, [1, 3]] = boxes[:, [1, 3]] - y_offset
72
+
73
+ return cropped_boxes
74
+
75
+
76
+ def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):
77
+ """
78
+ Perform uniform spatial sampling on the images and corresponding boxes.
79
+ Args:
80
+ images (tensor): images to perform uniform crop. The dimension is
81
+ `num frames` x `channel` x `height` x `width`.
82
+ size (int): size of height and weight to crop the images.
83
+ spatial_idx (int): 0, 1, or 2 for left, center, and right crop if width
84
+ is larger than height. Or 0, 1, or 2 for top, center, and bottom
85
+ crop if height is larger than width.
86
+ boxes (ndarray or None): optional. Corresponding boxes to images.
87
+ Dimension is `num boxes` x 4.
88
+ scale_size (int): optinal. If not None, resize the images to scale_size before
89
+ performing any crop.
90
+ Returns:
91
+ cropped (tensor): images with dimension of
92
+ `num frames` x `channel` x `size` x `size`.
93
+ cropped_boxes (ndarray or None): the cropped boxes with dimension of
94
+ `num boxes` x 4.
95
+ """
96
+ assert spatial_idx in [0, 1, 2]
97
+ ndim = len(images.shape)
98
+ if ndim == 3:
99
+ images = images.unsqueeze(0)
100
+ height = images.shape[2]
101
+ width = images.shape[3]
102
+
103
+ if scale_size is not None:
104
+ if width <= height:
105
+ width, height = scale_size, int(height / width * scale_size)
106
+ else:
107
+ width, height = int(width / height * scale_size), scale_size
108
+ images = torch.nn.functional.interpolate(
109
+ images,
110
+ size=(height, width),
111
+ mode="bilinear",
112
+ align_corners=False,
113
+ )
114
+
115
+ y_offset = int(math.ceil((height - size) / 2))
116
+ x_offset = int(math.ceil((width - size) / 2))
117
+
118
+ if height > width:
119
+ if spatial_idx == 0:
120
+ y_offset = 0
121
+ elif spatial_idx == 2:
122
+ y_offset = height - size
123
+ else:
124
+ if spatial_idx == 0:
125
+ x_offset = 0
126
+ elif spatial_idx == 2:
127
+ x_offset = width - size
128
+ cropped = images[:, :, y_offset: y_offset + size, x_offset: x_offset + size]
129
+ cropped_boxes = crop_boxes(boxes, x_offset, y_offset) if boxes is not None else None
130
+ if ndim == 3:
131
+ cropped = cropped.squeeze(0)
132
+ return cropped, cropped_boxes
133
+
134
+
135
+ class SpatialCrop(nn.Module):
136
+ """
137
+ Convert the video into 3 smaller clips spatially. Must be used after the
138
+ temporal crops to get spatial crops, and should be used with
139
+ -2 in the spatial crop at the slowfast augmentation stage (so full
140
+ frames are passed in here). Will return a larger list with the
141
+ 3x spatial crops as well. It's useful for 3x4 testing (eg in SwinT)
142
+ or 3x10 testing in SlowFast etc.
143
+ """
144
+
145
+ def __init__(self, crop_size: int = 224, num_crops: int = 3):
146
+ super().__init__()
147
+ self.crop_size = crop_size
148
+ if num_crops == 6:
149
+ self.crops_to_ext = [0, 1, 2]
150
+ # I guess Swin uses 5 crops without flipping, but that doesn't
151
+ # make sense given they first resize to 224 and take 224 crops.
152
+ # (pg 6 of https://arxiv.org/pdf/2106.13230.pdf)
153
+ # So I'm assuming we can use flipped crops and that will add sth..
154
+ self.flipped_crops_to_ext = [0, 1, 2]
155
+ elif num_crops == 3:
156
+ self.crops_to_ext = [0, 1, 2]
157
+ self.flipped_crops_to_ext = []
158
+ elif num_crops == 1:
159
+ self.crops_to_ext = [1]
160
+ self.flipped_crops_to_ext = []
161
+ else:
162
+ raise NotImplementedError(
163
+ "Nothing else supported yet, "
164
+ "slowfast only takes 0, 1, 2 as arguments"
165
+ )
166
+
167
+ def forward(self, videos: Sequence[torch.Tensor]):
168
+ """
169
+ Args:
170
+ videos: A list of C, T, H, W videos.
171
+ Returns:
172
+ videos: A list with 3x the number of elements. Each video converted
173
+ to C, T, H', W' by spatial cropping.
174
+ """
175
+ assert isinstance(videos, list), "Must be a list of videos after temporal crops"
176
+ assert all([video.ndim == 4 for video in videos]), "Must be (C,T,H,W)"
177
+ res = []
178
+ for video in videos:
179
+ for spatial_idx in self.crops_to_ext:
180
+ res.append(uniform_crop(video, self.crop_size, spatial_idx)[0])
181
+ if not self.flipped_crops_to_ext:
182
+ continue
183
+ flipped_video = transforms.functional.hflip(video)
184
+ for spatial_idx in self.flipped_crops_to_ext:
185
+ res.append(uniform_crop(flipped_video, self.crop_size, spatial_idx)[0])
186
+ return res