Spaces · Running
adamtayzzz committed on
Commit 1076673
Parent(s): e28c379
Upload 41 files
- .gitattributes +2 -0
- app.py +84 -0
- glue_data/STS-B/LICENSE.txt +136 -0
- glue_data/STS-B/cached_dev_albert-base-v2_128_sts-b +3 -0
- glue_data/STS-B/cached_train_albert-base-v2_128_sts-b +3 -0
- glue_data/STS-B/dev.tsv +0 -0
- glue_data/STS-B/original/sts-dev.tsv +0 -0
- glue_data/STS-B/original/sts-test.tsv +0 -0
- glue_data/STS-B/original/sts-train.tsv +0 -0
- glue_data/STS-B/readme.txt +174 -0
- glue_data/STS-B/test.tsv +0 -0
- glue_data/STS-B/train.tsv +0 -0
- outputs/train/albert/STS-B/checkpoint-600/config.json +50 -0
- outputs/train/albert/STS-B/checkpoint-600/eval_results.txt +1 -0
- outputs/train/albert/STS-B/checkpoint-600/optimizer.pt +3 -0
- outputs/train/albert/STS-B/checkpoint-600/pytorch_model.bin +3 -0
- outputs/train/albert/STS-B/checkpoint-600/scheduler.pt +3 -0
- outputs/train/albert/STS-B/checkpoint-600/special_tokens_map.json +15 -0
- outputs/train/albert/STS-B/checkpoint-600/spiece.model +3 -0
- outputs/train/albert/STS-B/checkpoint-600/tokenizer_config.json +24 -0
- outputs/train/albert/STS-B/checkpoint-600/training_args.bin +3 -0
- outputs/train/albert/STS-B/config.json +50 -0
- outputs/train/albert/STS-B/eval_results.txt +1 -0
- outputs/train/albert/STS-B/pytorch_model.bin +3 -0
- outputs/train/albert/STS-B/special_tokens_map.json +15 -0
- outputs/train/albert/STS-B/spiece.model +3 -0
- outputs/train/albert/STS-B/tokenizer_config.json +24 -0
- outputs/train/albert/STS-B/training_args.bin +3 -0
- pabee/__pycache__/modeling_albert.cpython-37.pyc +0 -0
- pabee/__pycache__/modeling_albert.cpython-39.pyc +0 -0
- pabee/__pycache__/modeling_bert.cpython-37.pyc +0 -0
- pabee/__pycache__/modeling_bert.cpython-39.pyc +0 -0
- pabee/configuration_albert.py +146 -0
- pabee/configuration_bert.py +142 -0
- pabee/modeling_albert.py +1085 -0
- pabee/modeling_bert.py +1663 -0
- run_glue.py +772 -0
- whitebox_utils/__pycache__/attack.cpython-37.pyc +0 -0
- whitebox_utils/__pycache__/classifier.cpython-37.pyc +0 -0
- whitebox_utils/__pycache__/metric.cpython-37.pyc +0 -0
- whitebox_utils/classifier.py +117 -0
.gitattributes
CHANGED
@@ -32,3 +32,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+glue_data/STS-B/cached_dev_albert-base-v2_128_sts-b filter=lfs diff=lfs merge=lfs -text
+glue_data/STS-B/cached_train_albert-base-v2_128_sts-b filter=lfs diff=lfs merge=lfs -text
app.py
ADDED
@@ -0,0 +1,84 @@
import gradio as gr
import requests
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
# from badnet_m import BadNet

import timm

# model = timm.create_model("hf_hub:nateraw/resnet18-random", pretrained=True)
# model.train()

# model = BadNet(3, 10)
# pipeline = pipeline.to('cuda:0')


import os

import logging

from transformers import WEIGHTS_NAME,AdamW,AlbertConfig,AlbertTokenizer,BertConfig,BertTokenizer
from pabee.modeling_albert import AlbertForSequenceClassification
from pabee.modeling_bert import BertForSequenceClassification
from transformers import glue_output_modes as output_modes
from transformers import glue_processors as processors

import datasets
from whitebox_utils.classifier import MyClassifier

import ssl
ssl._create_default_https_context = ssl._create_unverified_context

import random
import numpy as np
import torch
import argparse

def random_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

logger = logging.getLogger(__name__)

# TODO: dataset model tokenizer etc.
best_model_path = {
    'albert_STS-B':'./outputs/train/albert/SST-2/checkpoint-7500',
}

MODEL_CLASSES = {
    "bert": (BertConfig, BertForSequenceClassification, BertTokenizer),
    "albert": (AlbertConfig, AlbertForSequenceClassification, AlbertTokenizer),
}

model = 'albert'
dataset = 'STS-B'
task_name = f'{dataset}'.lower()
if task_name not in processors:
    raise ValueError("Task not found: %s" % (task_name))
processor = processors[task_name]()  # transformers package-preprocessor
output_mode = output_modes[task_name]  # output type
label_list = processor.get_labels()
num_labels = len(label_list)

output_dir = f'./PABEE/outputs/train/{model}/{dataset}'
data_dir = f'./PABEE/glue_data/{dataset}'

config_class, model_class, tokenizer_class = MODEL_CLASSES[model]
tokenizer = tokenizer_class.from_pretrained(output_dir, do_lower_case=True)
model = model_class.from_pretrained(best_model_path[f'{model}_{dataset}'])

exit_type='patience'
exit_value=3

classifier = MyClassifier(model,tokenizer,label_list,output_mode,exit_type,exit_value,model)

def greet(text,text2,exit_pos):
    text_input = [(text,text2)]
    classifier.get_prob_time(text_input,exit_position=exit_pos)

iface = gr.Interface(fn=greet, inputs='text', outputs="image")
iface.launch()
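As committed, greet takes three inputs and returns nothing, while the gr.Interface call declares a single text input and an image output, and best_model_path['albert_STS-B'] points at an SST-2 checkpoint directory that is not part of this upload. A minimal sketch of a consistent wiring is shown below; the widget choices and the assumption that MyClassifier.get_prob_time returns something Gradio can render are illustrative, not defined by this commit.

# Hypothetical rewiring of the app above (a sketch, not the committed behavior).
import gradio as gr

def greet(text, text2, exit_pos):
    text_input = [(text, text2)]
    # Assumption: get_prob_time returns a displayable object (e.g. a matplotlib
    # figure) for the requested early-exit position.
    return classifier.get_prob_time(text_input, exit_position=int(exit_pos))

iface = gr.Interface(
    fn=greet,
    inputs=[
        gr.Textbox(label="Sentence 1"),
        gr.Textbox(label="Sentence 2"),
        gr.Slider(1, 12, step=1, label="Exit layer"),
    ],
    outputs=gr.Plot(),
)
iface.launch()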
glue_data/STS-B/LICENSE.txt
ADDED
@@ -0,0 +1,136 @@
Notes on datasets and licenses
------------------------------

If using this data in your research please cite the following paper
and the url of the STS website: http://ixa2.si.ehu.eus/stswiki:

Eneko Agirre, Daniel Cer, Mona Diab, Iñigo Lopez-Gazpio, Lucia
Specia. Semeval-2017 Task 1: Semantic Textual Similarity
Multilingual and Crosslingual Focused Evaluation. Proceedings of
SemEval 2017.

The scores are released under a "Commons Attribution - Share Alike 4.0
International License" http://creativecommons.org/licenses/by-sa/4.0/

The text of each dataset has a license of its own, as follows:

- MSR-Paraphrase, Microsoft Research Paraphrase Corpus. In order to use
  MSRpar, researchers need to agree with the license terms from
  Microsoft Research:
  http://research.microsoft.com/en-us/downloads/607d14d9-20cd-47e3-85bc-a2f65cd28042/

- headlines: Mined from several news sources by European Media Monitor
  (Best et al. 2005). using the RSS feed. European Media Monitor (EMM)
  Real Time News Clusters are the top news stories for the last 4
  hours, updated every ten minutes. The article clustering is fully
  automatic. The selection and placement of stories are determined
  automatically by a computer program. This site is a joint project of
  DG-JRC and DG-COMM. The information on this site is subject to a
  disclaimer (see
  http://europa.eu/geninfo/legal_notices_en.htm). Please acknowledge
  EMM when (re)using this material.
  http://emm.newsbrief.eu/rss?type=rtn&language=en&duplicates=false

- deft-news: A subset of news article data in the DEFT
  project.

- MSR-Video, Microsoft Research Video Description Corpus. In order to
  use MSRvideo, researchers need to agree with the license terms from
  Microsoft Research:
  http://research.microsoft.com/en-us/downloads/38cf15fd-b8df-477e-a4e4-a4680caa75af/

- image: The Image Descriptions data set is a subset of
  the PASCAL VOC-2008 data set (Rashtchian et al., 2010). PASCAL
  VOC-2008 data set consists of 1,000 images and has been used by a
  number of image description systems. The image captions of the data
  set are released under a CreativeCommons Attribution-ShareAlike
  license, the descriptions itself are free.

- track5.en-en: This text is a subset of the Stanford Natural
  Language Inference (SNLI) corpus, by The Stanford NLP Group is
  licensed under a Creative Commons Attribution-ShareAlike 4.0
  International License. Based on a work at
  http://shannon.cs.illinois.edu/DenotationGraph/.
  https://creativecommons.org/licenses/by-sa/4.0/

- answers-answers: user content from stack-exchange. Check the license
  below in ======ANSWERS-ANSWERS======

- answers-forums: user content from stack-exchange. Check the license
  below in ======ANSWERS-FORUMS======


======ANSWER-ANSWER======

Creative Commons Attribution-ShareAlike 3.0 Unported (CC BY-SA 3.0)
http://creativecommons.org/licenses/by-sa/3.0/

Attribution Requirements:

"* Visually display or otherwise indicate the source of the content
as coming from the Stack Exchange Network. This requirement is
satisfied with a discreet text blurb, or some other unobtrusive but
clear visual indication.

* Ensure that any Internet use of the content includes a hyperlink
directly to the original question on the source site on the Network
(e.g., http://stackoverflow.com/questions/12345)

* Visually display or otherwise clearly indicate the author names for
every question and answer used

* Ensure that any Internet use of the content includes a hyperlink for
each author name directly back to his or her user profile page on the
source site on the Network (e.g.,
http://stackoverflow.com/users/12345/username), directly to the Stack
Exchange domain, in standard HTML (i.e. not through a Tinyurl or other
such indirect hyperlink, form of obfuscation or redirection), without
any “nofollow” command or any other such means of avoiding detection by
search engines, and visible even with JavaScript disabled."

(https://archive.org/details/stackexchange)


======ANSWERS-FORUMS======

Stack Exchange Inc. generously made the data used to construct the STS 2015 answer-answer statement pairs available under a Creative Commons Attribution-ShareAlike (cc-by-sa) 3.0 license.

The license is reproduced below from: https://archive.org/details/stackexchange

The STS.input.answers-forums.txt file should be redistributed with this LICENSE text and the accompanying files in LICENSE.answers-forums.zip. The tsv files in the zip file contain the additional information that's needed to comply with the license.

--

All user content contributed to the Stack Exchange network is cc-by-sa 3.0 licensed, intended to be shared and remixed. We even provide all our data as a convenient data dump.

http://creativecommons.org/licenses/by-sa/3.0/

But our cc-by-sa 3.0 licensing, while intentionally permissive, does *require attribution*:

"Attribution — You must attribute the work in the manner specified by the author or licensor (but not in any way that suggests that they endorse you or your use of the work)."

Specifically the attribution requirements are as follows:

1. Visually display or otherwise indicate the source of the content as coming from the Stack Exchange Network. This requirement is satisfied with a discreet text blurb, or some other unobtrusive but clear visual indication.

2. Ensure that any Internet use of the content includes a hyperlink directly to the original question on the source site on the Network (e.g., http://stackoverflow.com/questions/12345)

3. Visually display or otherwise clearly indicate the author names for every question and answer so used.

4. Ensure that any Internet use of the content includes a hyperlink for each author name directly back to his or her user profile page on the source site on the Network (e.g., http://stackoverflow.com/users/12345/username), directly to the Stack Exchange domain, in standard HTML (i.e. not through a Tinyurl or other such indirect hyperlink, form of obfuscation or redirection), without any “nofollow” command or any other such means of avoiding detection by search engines, and visible even with JavaScript disabled.

Our goal is to maintain the spirit of fair attribution. That means attribution to the website, and more importantly, to the individuals who so generously contributed their time to create that content in the first place!

For more information, see the Stack Exchange Terms of Service: http://stackexchange.com/legal/terms-of-service
glue_data/STS-B/cached_dev_albert-base-v2_128_sts-b
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7372a96bfde5bf6e8bd8d82e254bfceab72c1ebfabaf6ddf7386de0f54a5afb5
size 1249575
glue_data/STS-B/cached_train_albert-base-v2_128_sts-b
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:af95cda957d025baa9d0ee0aef3fc22e3c58947b21079fadbcbbce15f1dec3a0
size 4787691
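Both cached feature files above are stored as Git LFS pointers (a version line, an oid line, and a size line) rather than the raw tensors. A small sketch for reading such a pointer in Python; the helper name is ours, only the file path comes from this commit.

# Parse a Git LFS pointer file into a dict of its three fields.
def read_lfs_pointer(path):
    fields = {}
    with open(path, encoding="utf-8") as fh:
        for line in fh:
            key, _, value = line.strip().partition(" ")
            if key:
                fields[key] = value
    return fields

# e.g. read_lfs_pointer("glue_data/STS-B/cached_train_albert-base-v2_128_sts-b")
# -> {"version": "https://git-lfs.github.com/spec/v1", "oid": "sha256:af95...", "size": "4787691"}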
glue_data/STS-B/dev.tsv
ADDED
The diff for this file is too large to render.
See raw diff
glue_data/STS-B/original/sts-dev.tsv
ADDED
The diff for this file is too large to render.
See raw diff
glue_data/STS-B/original/sts-test.tsv
ADDED
The diff for this file is too large to render.
See raw diff
glue_data/STS-B/original/sts-train.tsv
ADDED
The diff for this file is too large to render.
See raw diff
glue_data/STS-B/readme.txt
ADDED
@@ -0,0 +1,174 @@
STS Benchmark: Main English dataset

Semantic Textual Similarity 2012-2017 Dataset

http://ixa2.si.ehu.eus/stswiki


STS Benchmark comprises a selection of the English datasets used in
the STS tasks organized by us in the context of SemEval between 2012
and 2017.

In order to provide a standard benchmark to compare among systems, we
organized it into train, development and test. The development part
can be used to develop and tune hyperparameters of the systems, and
the test part should be only used once for the final system.

The benchmark comprises 8628 sentence pairs. This is the breakdown
according to genres and train-dev-test splits:

           train   dev  test  total
  -----------------------------
  news      3299   500   500   4299
  caption   2000   625   525   3250
  forum      450   375   254   1079
  -----------------------------
  total     5749  1500  1379   8628

For reference, this is the breakdown according to the original names
and task years of the datasets:

  genre     file            years    train  dev  test
  ------------------------------------------------
  news      MSRpar          2012      1000  250   250
  news      headlines       2013-16   1999  250   250
  news      deft-news       2014       300    0     0
  captions  MSRvid          2012      1000  250   250
  captions  images          2014-15   1000  250   250
  captions  track5.en-en    2017         0  125   125
  forum     deft-forum      2014       450    0     0
  forum     answers-forums  2015         0  375     0
  forum     answer-answer   2016         0    0   254

In addition to the standard benchmark, we also include other datasets
(see readme.txt in "companion" directory).


Introduction
------------

Given two sentences of text, s1 and s2, the systems need to compute
how similar s1 and s2 are, returning a similarity score between 0 and
5. The dataset comprises naturally occurring pairs of sentences drawn
from several domains and genres, annotated by crowdsourcing. See
papers by Agirre et al. (2012; 2013; 2014; 2015; 2016; 2017).

Format
------

Each file is encoded in utf-8 (a superset of ASCII), and has the
following tab separated fields:

  genre filename year score sentence1 sentence2

optionally there might be some license-related fields after sentence2.

NOTE: Given that some sentence pairs have been reused here and
elsewhere, systems should NOT use the following datasets to develop or
train their systems (see below for more details on datasets):

- Any of the datasets in Semeval STS competitions, including Semeval
  2014 task 1 (also known as SICK).
- The test part of MSR-Paraphrase (development and train are fine).
- The text of the videos in MSR-Video.


Evaluation script
-----------------

The official evaluation is the Pearson correlation coefficient. Given
an output file comprising the system scores (one per line) in a file
called sys.txt, you can use the evaluation script as follows:

  $ perl correlation.pl sts-dev.txt sys.txt


Other
-----

Please check http://ixa2.si.ehu.eus/stswiki

We recommend that interested researchers join the (low traffic)
mailing list:

  http://groups.google.com/group/STS-semeval

Notes on datasets and licenses
------------------------------

If using this data in your research please cite (Agirre et al. 2017)
and the STS website: http://ixa2.si.ehu.eus/stswiki.

Please see LICENSE.txt


Organizers of tasks by year
---------------------------

2012 Eneko Agirre, Daniel Cer, Mona Diab, Aitor Gonzalez-Agirre

2013 Eneko Agirre, Daniel Cer, Mona Diab, Aitor Gonzalez-Agirre,
     WeiWei Guo

2014 Eneko Agirre, Carmen Banea, Claire Cardie, Daniel Cer, Mona Diab,
     Aitor Gonzalez-Agirre, Weiwei Guo, Rada Mihalcea, German Rigau,
     Janyce Wiebe

2015 Eneko Agirre, Carmen Banea, Claire Cardie, Daniel Cer, Mona Diab,
     Aitor Gonzalez-Agirre, Weiwei Guo, Inigo Lopez-Gazpio, Montse
     Maritxalar, Rada Mihalcea, German Rigau, Larraitz Uria, Janyce
     Wiebe

2016 Eneko Agirre, Carmen Banea, Daniel Cer, Mona Diab, Aitor
     Gonzalez-Agirre, Rada Mihalcea, German Rigau, Janyce
     Wiebe

2017 Eneko Agirre, Daniel Cer, Mona Diab, Iñigo Lopez-Gazpio, Lucia
     Specia


References
----------

Eneko Agirre, Daniel Cer, Mona Diab, Aitor Gonzalez-Agirre. Task 6: A
  Pilot on Semantic Textual Similarity. Proceedings of SemEval 2012.

Eneko Agirre, Daniel Cer, Mona Diab, Aitor Gonzalez-Agirre, WeiWei
  Guo. *SEM 2013 shared task: Semantic Textual
  Similarity. Proceedings of *SEM 2013.

Eneko Agirre, Carmen Banea, Claire Cardie, Daniel Cer, Mona Diab,
  Aitor Gonzalez-Agirre, Weiwei Guo, Rada Mihalcea, German Rigau,
  Janyce Wiebe. Task 10: Multilingual Semantic Textual
  Similarity. Proceedings of SemEval 2014.

Eneko Agirre, Carmen Banea, Claire Cardie, Daniel Cer, Mona Diab,
  Aitor Gonzalez-Agirre, Weiwei Guo, Inigo Lopez-Gazpio, Montse
  Maritxalar, Rada Mihalcea, German Rigau, Larraitz Uria, Janyce
  Wiebe. Task 2: Semantic Textual Similarity, English, Spanish and
  Pilot on Interpretability. Proceedings of SemEval 2015.

Eneko Agirre, Carmen Banea, Daniel Cer, Mona Diab, Aitor
  Gonzalez-Agirre, Rada Mihalcea, German Rigau, Janyce
  Wiebe. Semeval-2016 Task 1: Semantic Textual Similarity,
  Monolingual and Cross-Lingual Evaluation. Proceedings of SemEval
  2016.

Eneko Agirre, Daniel Cer, Mona Diab, Iñigo Lopez-Gazpio, Lucia
  Specia. Semeval-2017 Task 1: Semantic Textual Similarity
  Multilingual and Crosslingual Focused Evaluation. Proceedings of
  SemEval 2017.

Clive Best, Erik van der Goot, Ken Blackler, Tefilo Garcia, and David
  Horby. 2005. Europe media monitor - system description. In EUR
  Report 22173-En, Ispra, Italy.

Cyrus Rashtchian, Peter Young, Micah Hodosh, and Julia Hockenmaier.
  Collecting Image Annotations Using Amazon's Mechanical Turk. In
  Proceedings of the NAACL HLT 2010 Workshop on Creating Speech and
  Language Data with Amazon's Mechanical Turk.
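The official metric is the Pearson correlation between the gold scores (the fourth tab-separated field of sts-dev.tsv) and the one-score-per-line sys.txt. If the perl script is not at hand, a rough Python equivalent is sketched below; it assumes scipy is installed and that the sys.txt lines align one-to-one with the dev pairs.

# Pearson correlation between gold STS scores and system scores,
# mirroring `perl correlation.pl sts-dev.txt sys.txt`.
from scipy.stats import pearsonr

def load_gold(path):
    # tab-separated: genre, filename, year, score, sentence1, sentence2[, license fields]
    with open(path, encoding="utf-8") as fh:
        return [float(line.split("\t")[3]) for line in fh if line.strip()]

def load_sys(path):
    with open(path, encoding="utf-8") as fh:
        return [float(line) for line in fh if line.strip()]

gold = load_gold("glue_data/STS-B/original/sts-dev.tsv")
system = load_sys("sys.txt")
print("Pearson r = %.4f" % pearsonr(gold, system)[0])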
glue_data/STS-B/test.tsv
ADDED
The diff for this file is too large to render.
See raw diff
glue_data/STS-B/train.tsv
ADDED
The diff for this file is too large to render.
See raw diff
outputs/train/albert/STS-B/checkpoint-600/config.json
ADDED
@@ -0,0 +1,50 @@
{
  "_name_or_path": "albert-base-v2",
  "architectures": [
    "AlbertForSequenceClassification"
  ],
  "attention_probs_dropout_prob": 0,
  "bos_token_id": 2,
  "classifier_dropout_prob": 0.1,
  "down_scale_factor": 1,
  "embedding_size": 128,
  "eos_token_id": 3,
  "finetuning_task": "sts-b",
  "gap_size": 0,
  "hidden_act": "gelu_new",
  "hidden_dropout_prob": 0,
  "hidden_size": 768,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2",
    "3": "LABEL_3",
    "4": "LABEL_4",
    "5": "LABEL_5"
  },
  "initializer_range": 0.02,
  "inner_group_num": 1,
  "intermediate_size": 3072,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_2": 2,
    "LABEL_3": 3,
    "LABEL_4": 4,
    "LABEL_5": 5
  },
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "albert",
  "net_structure_type": 0,
  "num_attention_heads": 12,
  "num_hidden_groups": 1,
  "num_hidden_layers": 12,
  "num_memory_blocks": 0,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "torch_dtype": "float32",
  "transformers_version": "4.26.1",
  "type_vocab_size": 2,
  "vocab_size": 30000
}
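This config records albert-base-v2 sizes fine-tuned on sts-b with a six-entry label map (LABEL_0 through LABEL_5). A quick way to inspect it with transformers, assuming the checkpoint directory is available locally as in this repo:

# Load and inspect the checkpoint's configuration.
from transformers import AlbertConfig

config = AlbertConfig.from_pretrained("outputs/train/albert/STS-B/checkpoint-600")
print(config.finetuning_task)                        # "sts-b"
print(config.num_labels)                             # 6 (LABEL_0 .. LABEL_5)
print(config.hidden_size, config.num_hidden_layers)  # 768 12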
outputs/train/albert/STS-B/checkpoint-600/eval_results.txt
ADDED
@@ -0,0 +1 @@
acc = 0.592
outputs/train/albert/STS-B/checkpoint-600/optimizer.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4dfd6bf08df6da49349982cc0ab76838a6ffad68ec2a923fbc55722a5a5d9f49
size 93939655
outputs/train/albert/STS-B/checkpoint-600/pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:28fb7cdd36c39e7158f17e1aa8254e0ac22cf9c3d8cb199b71738e7fa371e002
size 46976071
outputs/train/albert/STS-B/checkpoint-600/scheduler.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9a7f50584e3f70293ee9c0a9d963b05655869ad22d2d8ef97958b8d690dc508c
size 627
outputs/train/albert/STS-B/checkpoint-600/special_tokens_map.json
ADDED
@@ -0,0 +1,15 @@
{
  "bos_token": "[CLS]",
  "cls_token": "[CLS]",
  "eos_token": "[SEP]",
  "mask_token": {
    "content": "[MASK]",
    "lstrip": true,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "pad_token": "<pad>",
  "sep_token": "[SEP]",
  "unk_token": "<unk>"
}
outputs/train/albert/STS-B/checkpoint-600/spiece.model
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fefb02b667a6c5c2fe27602d28e5fb3428f66ab89c7d6f388e7c8d44a02d0336
size 760289
outputs/train/albert/STS-B/checkpoint-600/tokenizer_config.json
ADDED
@@ -0,0 +1,24 @@
{
  "bos_token": "[CLS]",
  "cls_token": "[CLS]",
  "do_lower_case": true,
  "eos_token": "[SEP]",
  "keep_accents": false,
  "mask_token": {
    "__type": "AddedToken",
    "content": "[MASK]",
    "lstrip": true,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "model_max_length": 512,
  "name_or_path": "albert-base-v2",
  "pad_token": "<pad>",
  "remove_space": true,
  "sep_token": "[SEP]",
  "sp_model_kwargs": {},
  "special_tokens_map_file": null,
  "tokenizer_class": "AlbertTokenizer",
  "unk_token": "<unk>"
}
outputs/train/albert/STS-B/checkpoint-600/training_args.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:91be98fab55128e207d96f3343dd0adfe1ccc97af4dfea9febedf9badfb20498
size 1595
outputs/train/albert/STS-B/config.json
ADDED
@@ -0,0 +1,50 @@
{
  "_name_or_path": "albert-base-v2",
  "architectures": [
    "AlbertForSequenceClassification"
  ],
  "attention_probs_dropout_prob": 0,
  "bos_token_id": 2,
  "classifier_dropout_prob": 0.1,
  "down_scale_factor": 1,
  "embedding_size": 128,
  "eos_token_id": 3,
  "finetuning_task": "sts-b",
  "gap_size": 0,
  "hidden_act": "gelu_new",
  "hidden_dropout_prob": 0,
  "hidden_size": 768,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2",
    "3": "LABEL_3",
    "4": "LABEL_4",
    "5": "LABEL_5"
  },
  "initializer_range": 0.02,
  "inner_group_num": 1,
  "intermediate_size": 3072,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1,
    "LABEL_2": 2,
    "LABEL_3": 3,
    "LABEL_4": 4,
    "LABEL_5": 5
  },
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "albert",
  "net_structure_type": 0,
  "num_attention_heads": 12,
  "num_hidden_groups": 1,
  "num_hidden_layers": 12,
  "num_memory_blocks": 0,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "torch_dtype": "float32",
  "transformers_version": "4.26.1",
  "type_vocab_size": 2,
  "vocab_size": 30000
}
outputs/train/albert/STS-B/eval_results.txt
ADDED
@@ -0,0 +1 @@
acc = 0.5846666666666667
outputs/train/albert/STS-B/pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ea08c98bcdccf7f9f4f8d7c12c8603d420e5524d508722bdb357e75b4b81a3ef
size 46976071
outputs/train/albert/STS-B/special_tokens_map.json
ADDED
@@ -0,0 +1,15 @@
{
  "bos_token": "[CLS]",
  "cls_token": "[CLS]",
  "eos_token": "[SEP]",
  "mask_token": {
    "content": "[MASK]",
    "lstrip": true,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "pad_token": "<pad>",
  "sep_token": "[SEP]",
  "unk_token": "<unk>"
}
outputs/train/albert/STS-B/spiece.model
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fefb02b667a6c5c2fe27602d28e5fb3428f66ab89c7d6f388e7c8d44a02d0336
size 760289
outputs/train/albert/STS-B/tokenizer_config.json
ADDED
@@ -0,0 +1,24 @@
{
  "bos_token": "[CLS]",
  "cls_token": "[CLS]",
  "do_lower_case": true,
  "eos_token": "[SEP]",
  "keep_accents": false,
  "mask_token": {
    "__type": "AddedToken",
    "content": "[MASK]",
    "lstrip": true,
    "normalized": false,
    "rstrip": false,
    "single_word": false
  },
  "model_max_length": 512,
  "name_or_path": "albert-base-v2",
  "pad_token": "<pad>",
  "remove_space": true,
  "sep_token": "[SEP]",
  "sp_model_kwargs": {},
  "special_tokens_map_file": null,
  "tokenizer_class": "AlbertTokenizer",
  "unk_token": "<unk>"
}
outputs/train/albert/STS-B/training_args.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:91be98fab55128e207d96f3343dd0adfe1ccc97af4dfea9febedf9badfb20498
size 1595
pabee/__pycache__/modeling_albert.cpython-37.pyc
ADDED
Binary file (36.7 kB). View file
pabee/__pycache__/modeling_albert.cpython-39.pyc
ADDED
Binary file (36.6 kB). View file
pabee/__pycache__/modeling_bert.cpython-37.pyc
ADDED
Binary file (59.4 kB). View file
pabee/__pycache__/modeling_bert.cpython-39.pyc
ADDED
Binary file (59 kB). View file
pabee/configuration_albert.py
ADDED
@@ -0,0 +1,146 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" ALBERT model configuration """

from transformers.configuration_utils import PretrainedConfig


ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "albert-base-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-config.json",
    "albert-large-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-config.json",
    "albert-xlarge-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-config.json",
    "albert-xxlarge-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-config.json",
    "albert-base-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-config.json",
    "albert-large-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-config.json",
    "albert-xlarge-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v2-config.json",
    "albert-xxlarge-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v2-config.json",
}


class AlbertConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of an :class:`~transformers.AlbertModel`.
    It is used to instantiate an ALBERT model according to the specified arguments, defining the model
    architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
    the ALBERT `xxlarge <https://huggingface.co/albert-xxlarge-v2>`__ architecture.

    Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used
    to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig`
    for more information.


    Args:
        vocab_size (:obj:`int`, optional, defaults to 30000):
            Vocabulary size of the ALBERT model. Defines the different tokens that
            can be represented by the `inputs_ids` passed to the forward method of :class:`~transformers.AlbertModel`.
        embedding_size (:obj:`int`, optional, defaults to 128):
            Dimensionality of vocabulary embeddings.
        hidden_size (:obj:`int`, optional, defaults to 4096):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (:obj:`int`, optional, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_hidden_groups (:obj:`int`, optional, defaults to 1):
            Number of groups for the hidden layers, parameters in the same group are shared.
        num_attention_heads (:obj:`int`, optional, defaults to 64):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (:obj:`int`, optional, defaults to 16384):
            The dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        inner_group_num (:obj:`int`, optional, defaults to 1):
            The number of inner repetition of attention and ffn.
        hidden_act (:obj:`str` or :obj:`function`, optional, defaults to "gelu_new"):
            The non-linear activation function (function or string) in the encoder and pooler.
            If string, "gelu", "relu", "swish" and "gelu_new" are supported.
        hidden_dropout_prob (:obj:`float`, optional, defaults to 0):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (:obj:`float`, optional, defaults to 0):
            The dropout ratio for the attention probabilities.
        max_position_embeddings (:obj:`int`, optional, defaults to 512):
            The maximum sequence length that this model might ever be used with. Typically set this to something
            large (e.g., 512 or 1024 or 2048).
        type_vocab_size (:obj:`int`, optional, defaults to 2):
            The vocabulary size of the `token_type_ids` passed into :class:`~transformers.AlbertModel`.
        initializer_range (:obj:`float`, optional, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (:obj:`float`, optional, defaults to 1e-12):
            The epsilon used by the layer normalization layers.
        classifier_dropout_prob (:obj:`float`, optional, defaults to 0.1):
            The dropout ratio for attached classifiers.

    Example::

        from transformers import AlbertConfig, AlbertModel
        # Initializing an ALBERT-xxlarge style configuration
        albert_xxlarge_configuration = AlbertConfig()

        # Initializing an ALBERT-base style configuration
        albert_base_configuration = AlbertConfig(
            hidden_size=768,
            num_attention_heads=12,
            intermediate_size=3072,
        )

        # Initializing a model from the ALBERT-base style configuration
        model = AlbertModel(albert_xxlarge_configuration)

        # Accessing the model configuration
        configuration = model.config

    Attributes:
        pretrained_config_archive_map (Dict[str, str]):
            A dictionary containing all the available pre-trained checkpoints.
    """

    pretrained_config_archive_map = ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP
    model_type = "albert"

    def __init__(
        self,
        vocab_size=30000,
        embedding_size=128,
        hidden_size=4096,
        num_hidden_layers=12,
        num_hidden_groups=1,
        num_attention_heads=64,
        intermediate_size=16384,
        inner_group_num=1,
        hidden_act="gelu_new",
        hidden_dropout_prob=0,
        attention_probs_dropout_prob=0,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        classifier_dropout_prob=0.1,
        **kwargs
    ):
        super().__init__(**kwargs)

        self.vocab_size = vocab_size
        self.embedding_size = embedding_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_hidden_groups = num_hidden_groups
        self.num_attention_heads = num_attention_heads
        self.inner_group_num = inner_group_num
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
        self.classifier_dropout_prob = classifier_dropout_prob
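The defaults of this class describe ALBERT-xxlarge, while the checkpoints in this commit use albert-base-v2 sizes. A short sketch of building a matching config with the class above; the values mirror the config.json files in this upload.

# An AlbertConfig with albert-base-v2 sizes, matching the config.json files in this repo.
from pabee.configuration_albert import AlbertConfig

base_v2 = AlbertConfig(
    vocab_size=30000,
    embedding_size=128,
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
    classifier_dropout_prob=0.1,
)
print(base_v2.hidden_size, base_v2.num_attention_heads)  # 768 12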
pabee/configuration_bert.py
ADDED
@@ -0,0 +1,142 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" BERT model configuration """


import logging

from transformers.configuration_utils import PretrainedConfig


logger = logging.getLogger(__name__)

BERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    "bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
    "bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",
    "bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json",
    "bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-config.json",
    "bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-config.json",
    "bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-config.json",
    "bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json",
    "bert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-config.json",
    "bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-config.json",
    "bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-config.json",
    "bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-config.json",
    "bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-config.json",
    "bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-config.json",
    "bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-config.json",
    "bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-config.json",
    "bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-config.json",
    "bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking-config.json",
    "bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-config.json",
    "bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-config.json",
    "bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/config.json",
    "bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/config.json",
    "bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/config.json",
}


class BertConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a :class:`~transformers.BertModel`.
    It is used to instantiate a BERT model according to the specified arguments, defining the model
    architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of
    the BERT `bert-base-uncased <https://huggingface.co/bert-base-uncased>`__ architecture.

    Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used
    to control the model outputs. Read the documentation from :class:`~transformers.PretrainedConfig`
    for more information.


    Args:
        vocab_size (:obj:`int`, optional, defaults to 30522):
            Vocabulary size of the BERT model. Defines the different tokens that
            can be represented by the `inputs_ids` passed to the forward method of :class:`~transformers.BertModel`.
        hidden_size (:obj:`int`, optional, defaults to 768):
            Dimensionality of the encoder layers and the pooler layer.
        num_hidden_layers (:obj:`int`, optional, defaults to 12):
            Number of hidden layers in the Transformer encoder.
        num_attention_heads (:obj:`int`, optional, defaults to 12):
            Number of attention heads for each attention layer in the Transformer encoder.
        intermediate_size (:obj:`int`, optional, defaults to 3072):
            Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
        hidden_act (:obj:`str` or :obj:`function`, optional, defaults to "gelu"):
            The non-linear activation function (function or string) in the encoder and pooler.
            If string, "gelu", "relu", "swish" and "gelu_new" are supported.
        hidden_dropout_prob (:obj:`float`, optional, defaults to 0.1):
            The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
        attention_probs_dropout_prob (:obj:`float`, optional, defaults to 0.1):
            The dropout ratio for the attention probabilities.
        max_position_embeddings (:obj:`int`, optional, defaults to 512):
            The maximum sequence length that this model might ever be used with.
            Typically set this to something large just in case (e.g., 512 or 1024 or 2048).
        type_vocab_size (:obj:`int`, optional, defaults to 2):
            The vocabulary size of the `token_type_ids` passed into :class:`~transformers.BertModel`.
        initializer_range (:obj:`float`, optional, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        layer_norm_eps (:obj:`float`, optional, defaults to 1e-12):
            The epsilon used by the layer normalization layers.

    Example::

        from transformers import BertModel, BertConfig

        # Initializing a BERT bert-base-uncased style configuration
        configuration = BertConfig()

        # Initializing a model from the bert-base-uncased style configuration
        model = BertModel(configuration)

        # Accessing the model configuration
        configuration = model.config

    Attributes:
        pretrained_config_archive_map (Dict[str, str]):
            A dictionary containing all the available pre-trained checkpoints.
    """
    pretrained_config_archive_map = BERT_PRETRAINED_CONFIG_ARCHIVE_MAP
    model_type = "bert"

    def __init__(
        self,
        vocab_size=30522,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=512,
        type_vocab_size=2,
        initializer_range=0.02,
        layer_norm_eps=1e-12,
        **kwargs
    ):
        super().__init__(**kwargs)

        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act
        self.intermediate_size = intermediate_size
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.max_position_embeddings = max_position_embeddings
        self.type_vocab_size = type_vocab_size
        self.initializer_range = initializer_range
        self.layer_norm_eps = layer_norm_eps
pabee/modeling_albert.py
ADDED
@@ -0,0 +1,1085 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 Google AI, Google Brain and the HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
"""PyTorch ALBERT model. """
|
16 |
+
|
17 |
+
import logging
|
18 |
+
import math
|
19 |
+
import os
|
20 |
+
|
21 |
+
import torch
|
22 |
+
import torch.nn as nn
|
23 |
+
from torch.nn import CrossEntropyLoss, MSELoss
|
24 |
+
from datetime import datetime
|
25 |
+
|
26 |
+
from transformers.models.albert.configuration_albert import AlbertConfig
|
27 |
+
from transformers.models.bert.modeling_bert import ACT2FN,BertEmbeddings, BertSelfAttention, prune_linear_layer
|
28 |
+
# from transformers.configuration_albert import AlbertConfig
|
29 |
+
# from transformers.modeling_bert import ACT2FN, BertEmbeddings, BertSelfAttention, prune_linear_layer
|
30 |
+
from transformers.modeling_utils import PreTrainedModel
|
31 |
+
|
32 |
+
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
|
33 |
+
|
34 |
+
|
35 |
+
logger = logging.getLogger(__name__)
|
36 |
+
|
37 |
+
|
38 |
+
ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
39 |
+
"albert-base-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-pytorch_model.bin",
|
40 |
+
"albert-large-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-pytorch_model.bin",
|
41 |
+
"albert-xlarge-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-pytorch_model.bin",
|
42 |
+
"albert-xxlarge-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-pytorch_model.bin",
|
43 |
+
"albert-base-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-base-v2-pytorch_model.bin",
|
44 |
+
"albert-large-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-large-v2-pytorch_model.bin",
|
45 |
+
"albert-xlarge-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xlarge-v2-pytorch_model.bin",
|
46 |
+
"albert-xxlarge-v2": "https://s3.amazonaws.com/models.huggingface.co/bert/albert-xxlarge-v2-pytorch_model.bin",
|
47 |
+
}
|
48 |
+
|
49 |
+
# load pretrained weights from tensorflow
|
50 |
+
def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
|
51 |
+
""" Load tf checkpoints in a pytorch model."""
|
52 |
+
try:
|
53 |
+
import re
|
54 |
+
import numpy as np
|
55 |
+
import tensorflow as tf
|
56 |
+
except ImportError:
|
57 |
+
logger.error(
|
58 |
+
"Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
|
59 |
+
"https://www.tensorflow.org/install/ for installation instructions."
|
60 |
+
)
|
61 |
+
raise
|
62 |
+
tf_path = os.path.abspath(tf_checkpoint_path)
|
63 |
+
logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
|
64 |
+
# Load weights from TF model
|
65 |
+
init_vars = tf.train.list_variables(tf_path)
|
66 |
+
names = []
|
67 |
+
arrays = []
|
68 |
+
for name, shape in init_vars:
|
69 |
+
logger.info("Loading TF weight {} with shape {}".format(name, shape))
|
70 |
+
array = tf.train.load_variable(tf_path, name)
|
71 |
+
names.append(name)
|
72 |
+
arrays.append(array)
|
73 |
+
|
74 |
+
for name, array in zip(names, arrays):
|
75 |
+
print(name)
|
76 |
+
|
77 |
+
for name, array in zip(names, arrays):
|
78 |
+
original_name = name
|
79 |
+
|
80 |
+
# If saved from the TF HUB module
|
81 |
+
name = name.replace("module/", "")
|
82 |
+
|
83 |
+
# Renaming and simplifying
|
84 |
+
name = name.replace("ffn_1", "ffn")
|
85 |
+
name = name.replace("bert/", "albert/")
|
86 |
+
name = name.replace("attention_1", "attention")
|
87 |
+
name = name.replace("transform/", "")
|
88 |
+
name = name.replace("LayerNorm_1", "full_layer_layer_norm")
|
89 |
+
name = name.replace("LayerNorm", "attention/LayerNorm")
|
90 |
+
name = name.replace("transformer/", "")
|
91 |
+
|
92 |
+
# The feed forward layer had an 'intermediate' step which has been abstracted away
|
93 |
+
name = name.replace("intermediate/dense/", "")
|
94 |
+
name = name.replace("ffn/intermediate/output/dense/", "ffn_output/")
|
95 |
+
|
96 |
+
# ALBERT attention was split between self and output which have been abstracted away
|
97 |
+
name = name.replace("/output/", "/")
|
98 |
+
name = name.replace("/self/", "/")
|
99 |
+
|
100 |
+
# The pooler is a linear layer
|
101 |
+
name = name.replace("pooler/dense", "pooler")
|
102 |
+
|
103 |
+
# The classifier was simplified to predictions from cls/predictions
|
104 |
+
name = name.replace("cls/predictions", "predictions")
|
105 |
+
name = name.replace("predictions/attention", "predictions")
|
106 |
+
|
107 |
+
# Naming was changed to be more explicit
|
108 |
+
name = name.replace("embeddings/attention", "embeddings")
|
109 |
+
name = name.replace("inner_group_", "albert_layers/")
|
110 |
+
name = name.replace("group_", "albert_layer_groups/")
|
111 |
+
|
112 |
+
# Classifier
|
113 |
+
if len(name.split("/")) == 1 and ("output_bias" in name or "output_weights" in name):
|
114 |
+
name = "classifier/" + name
|
115 |
+
|
116 |
+
# No ALBERT model currently handles the next sentence prediction task
|
117 |
+
if "seq_relationship" in name:
|
118 |
+
continue
|
119 |
+
|
120 |
+
name = name.split("/")
|
121 |
+
|
122 |
+
# Ignore the gradients applied by the LAMB/ADAM optimizers.
|
123 |
+
if (
|
124 |
+
"adam_m" in name
|
125 |
+
or "adam_v" in name
|
126 |
+
or "AdamWeightDecayOptimizer" in name
|
127 |
+
or "AdamWeightDecayOptimizer_1" in name
|
128 |
+
or "global_step" in name
|
129 |
+
):
|
130 |
+
logger.info("Skipping {}".format("/".join(name)))
|
131 |
+
continue
|
132 |
+
|
133 |
+
pointer = model
|
134 |
+
for m_name in name:
|
135 |
+
if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
|
136 |
+
scope_names = re.split(r"_(\d+)", m_name)
|
137 |
+
else:
|
138 |
+
scope_names = [m_name]
|
139 |
+
|
140 |
+
if scope_names[0] == "kernel" or scope_names[0] == "gamma":
|
141 |
+
pointer = getattr(pointer, "weight")
|
142 |
+
elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
|
143 |
+
pointer = getattr(pointer, "bias")
|
144 |
+
elif scope_names[0] == "output_weights":
|
145 |
+
pointer = getattr(pointer, "weight")
|
146 |
+
elif scope_names[0] == "squad":
|
147 |
+
pointer = getattr(pointer, "classifier")
|
148 |
+
else:
|
149 |
+
try:
|
150 |
+
pointer = getattr(pointer, scope_names[0])
|
151 |
+
except AttributeError:
|
152 |
+
logger.info("Skipping {}".format("/".join(name)))
|
153 |
+
continue
|
154 |
+
if len(scope_names) >= 2:
|
155 |
+
num = int(scope_names[1])
|
156 |
+
pointer = pointer[num]
|
157 |
+
|
158 |
+
if m_name[-11:] == "_embeddings":
|
159 |
+
pointer = getattr(pointer, "weight")
|
160 |
+
elif m_name == "kernel":
|
161 |
+
array = np.transpose(array)
|
162 |
+
try:
|
163 |
+
assert pointer.shape == array.shape
|
164 |
+
except AssertionError as e:
|
165 |
+
e.args += (pointer.shape, array.shape)
|
166 |
+
raise
|
167 |
+
print("Initialize PyTorch weight {} from {}".format(name, original_name))
|
168 |
+
pointer.data = torch.from_numpy(array)
|
169 |
+
|
170 |
+
return model
|
171 |
+
|
172 |
+
|
173 |
+
class AlbertEmbeddings(BertEmbeddings):
|
174 |
+
"""
|
175 |
+
Construct the embeddings from word, position and token_type embeddings.
|
176 |
+
"""
|
177 |
+
|
178 |
+
def __init__(self, config):
|
179 |
+
super().__init__(config)
|
180 |
+
|
181 |
+
self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=0)
|
182 |
+
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size)
|
183 |
+
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size)
|
184 |
+
self.LayerNorm = torch.nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
|
185 |
+
|
186 |
+
|
187 |
+
class AlbertAttention(BertSelfAttention):
|
188 |
+
def __init__(self, config):
|
189 |
+
super().__init__(config)
|
190 |
+
|
191 |
+
self.output_attentions = config.output_attentions
|
192 |
+
self.num_attention_heads = config.num_attention_heads
|
193 |
+
self.hidden_size = config.hidden_size
|
194 |
+
self.attention_head_size = config.hidden_size // config.num_attention_heads
|
195 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
196 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
197 |
+
self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
198 |
+
self.pruned_heads = set()
|
199 |
+
|
200 |
+
def prune_heads(self, heads):
|
201 |
+
if len(heads) == 0:
|
202 |
+
return
|
203 |
+
mask = torch.ones(self.num_attention_heads, self.attention_head_size)
|
204 |
+
heads = set(heads) - self.pruned_heads  # Convert to set and remove already pruned heads
|
205 |
+
for head in heads:
|
206 |
+
# Compute how many pruned heads are before the head and move the index accordingly
|
207 |
+
head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
|
208 |
+
mask[head] = 0
|
209 |
+
mask = mask.view(-1).contiguous().eq(1)
|
210 |
+
index = torch.arange(len(mask))[mask].long()
|
211 |
+
|
212 |
+
# Prune linear layers
|
213 |
+
self.query = prune_linear_layer(self.query, index)
|
214 |
+
self.key = prune_linear_layer(self.key, index)
|
215 |
+
self.value = prune_linear_layer(self.value, index)
|
216 |
+
self.dense = prune_linear_layer(self.dense, index, dim=1)
|
217 |
+
|
218 |
+
# Update hyper params and store pruned heads
|
219 |
+
self.num_attention_heads = self.num_attention_heads - len(heads)
|
220 |
+
self.all_head_size = self.attention_head_size * self.num_attention_heads
|
221 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
222 |
+
|
223 |
+
def forward(self, input_ids, attention_mask=None, head_mask=None):
|
224 |
+
mixed_query_layer = self.query(input_ids)
|
225 |
+
mixed_key_layer = self.key(input_ids)
|
226 |
+
mixed_value_layer = self.value(input_ids)
|
227 |
+
|
228 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
229 |
+
key_layer = self.transpose_for_scores(mixed_key_layer)
|
230 |
+
value_layer = self.transpose_for_scores(mixed_value_layer)
|
231 |
+
|
232 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
233 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
234 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
235 |
+
if attention_mask is not None:
|
236 |
+
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
|
237 |
+
attention_scores = attention_scores + attention_mask
|
238 |
+
|
239 |
+
# Normalize the attention scores to probabilities.
|
240 |
+
attention_probs = nn.Softmax(dim=-1)(attention_scores)
|
241 |
+
|
242 |
+
# This is actually dropping out entire tokens to attend to, which might
|
243 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
244 |
+
attention_probs = self.dropout(attention_probs)
|
245 |
+
|
246 |
+
# Mask heads if we want to
|
247 |
+
if head_mask is not None:
|
248 |
+
attention_probs = attention_probs * head_mask
|
249 |
+
|
250 |
+
context_layer = torch.matmul(attention_probs, value_layer)
|
251 |
+
|
252 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
253 |
+
|
254 |
+
# Should find a better way to do this
|
255 |
+
w = (
|
256 |
+
self.dense.weight.t()
|
257 |
+
.view(self.num_attention_heads, self.attention_head_size, self.hidden_size)
|
258 |
+
.to(context_layer.dtype)
|
259 |
+
)
|
260 |
+
b = self.dense.bias.to(context_layer.dtype)
|
261 |
+
|
262 |
+
projected_context_layer = torch.einsum("bfnd,ndh->bfh", context_layer, w) + b
|
263 |
+
projected_context_layer_dropout = self.dropout(projected_context_layer)
|
264 |
+
layernormed_context_layer = self.LayerNorm(input_ids + projected_context_layer_dropout)
|
265 |
+
return (layernormed_context_layer, attention_probs) if self.output_attentions else (layernormed_context_layer,)
|
266 |
+
|
267 |
+
|
268 |
+
class AlbertLayer(nn.Module):
|
269 |
+
def __init__(self, config):
|
270 |
+
super().__init__()
|
271 |
+
|
272 |
+
self.config = config
|
273 |
+
self.full_layer_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
274 |
+
self.attention = AlbertAttention(config)
|
275 |
+
self.ffn = nn.Linear(config.hidden_size, config.intermediate_size)
|
276 |
+
self.ffn_output = nn.Linear(config.intermediate_size, config.hidden_size)
|
277 |
+
self.activation = ACT2FN[config.hidden_act]
|
278 |
+
|
279 |
+
def forward(self, hidden_states, attention_mask=None, head_mask=None):
|
280 |
+
attention_output = self.attention(hidden_states, attention_mask, head_mask)
|
281 |
+
ffn_output = self.ffn(attention_output[0])
|
282 |
+
ffn_output = self.activation(ffn_output)
|
283 |
+
ffn_output = self.ffn_output(ffn_output)
|
284 |
+
hidden_states = self.full_layer_layer_norm(ffn_output + attention_output[0])
|
285 |
+
|
286 |
+
return (hidden_states,) + attention_output[1:] # add attentions if we output them
|
287 |
+
|
288 |
+
|
289 |
+
class AlbertLayerGroup(nn.Module):
|
290 |
+
def __init__(self, config):
|
291 |
+
super().__init__()
|
292 |
+
|
293 |
+
self.output_attentions = config.output_attentions
|
294 |
+
self.output_hidden_states = config.output_hidden_states
|
295 |
+
self.albert_layers = nn.ModuleList([AlbertLayer(config) for _ in range(config.inner_group_num)])
|
296 |
+
|
297 |
+
def forward(self, hidden_states, attention_mask=None, head_mask=None):
|
298 |
+
layer_hidden_states = ()
|
299 |
+
layer_attentions = ()
|
300 |
+
|
301 |
+
for layer_index, albert_layer in enumerate(self.albert_layers):
|
302 |
+
layer_output = albert_layer(hidden_states, attention_mask, head_mask[layer_index])
|
303 |
+
hidden_states = layer_output[0]
|
304 |
+
|
305 |
+
if self.output_attentions:
|
306 |
+
layer_attentions = layer_attentions + (layer_output[1],)
|
307 |
+
|
308 |
+
if self.output_hidden_states:
|
309 |
+
layer_hidden_states = layer_hidden_states + (hidden_states,)
|
310 |
+
|
311 |
+
outputs = (hidden_states,)
|
312 |
+
if self.output_hidden_states:
|
313 |
+
outputs = outputs + (layer_hidden_states,)
|
314 |
+
if self.output_attentions:
|
315 |
+
outputs = outputs + (layer_attentions,)
|
316 |
+
return outputs # last-layer hidden state, (layer hidden states), (layer attentions)
|
317 |
+
|
318 |
+
|
319 |
+
class AlbertTransformer(nn.Module):
|
320 |
+
def __init__(self, config):
|
321 |
+
super().__init__()
|
322 |
+
|
323 |
+
self.config = config
|
324 |
+
self.output_attentions = config.output_attentions
|
325 |
+
self.output_hidden_states = config.output_hidden_states
|
326 |
+
self.embedding_hidden_mapping_in = nn.Linear(config.embedding_size, config.hidden_size)
|
327 |
+
self.albert_layer_groups = nn.ModuleList([AlbertLayerGroup(config) for _ in range(config.num_hidden_groups)])
|
328 |
+
|
329 |
+
def forward(self, hidden_states, attention_mask=None, head_mask=None):
|
330 |
+
hidden_states = self.embedding_hidden_mapping_in(hidden_states)
|
331 |
+
|
332 |
+
all_attentions = ()
|
333 |
+
|
334 |
+
if self.output_hidden_states:
|
335 |
+
all_hidden_states = (hidden_states,)
|
336 |
+
|
337 |
+
for i in range(self.config.num_hidden_layers):
|
338 |
+
# Number of layers in a hidden group
|
339 |
+
layers_per_group = int(self.config.num_hidden_layers / self.config.num_hidden_groups)
|
340 |
+
|
341 |
+
# Index of the hidden group
|
342 |
+
group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups))
|
343 |
+
|
344 |
+
layer_group_output = self.albert_layer_groups[group_idx](
|
345 |
+
hidden_states,
|
346 |
+
attention_mask,
|
347 |
+
head_mask[group_idx * layers_per_group : (group_idx + 1) * layers_per_group],
|
348 |
+
)
|
349 |
+
hidden_states = layer_group_output[0]
|
350 |
+
|
351 |
+
if self.output_attentions:
|
352 |
+
all_attentions = all_attentions + layer_group_output[-1]
|
353 |
+
|
354 |
+
if self.output_hidden_states:
|
355 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
356 |
+
|
357 |
+
outputs = (hidden_states,)
|
358 |
+
if self.output_hidden_states:
|
359 |
+
outputs = outputs + (all_hidden_states,)
|
360 |
+
if self.output_attentions:
|
361 |
+
outputs = outputs + (all_attentions,)
|
362 |
+
return outputs # last-layer hidden state, (all hidden states), (all attentions)
|
363 |
+
|
364 |
+
def adaptive_forward(self, hidden_states, current_layer, attention_mask=None, head_mask=None):
|
365 |
+
if current_layer == 0:
|
366 |
+
hidden_states = self.embedding_hidden_mapping_in(hidden_states)
|
367 |
+
else:
|
368 |
+
hidden_states = hidden_states[0]
|
369 |
+
|
370 |
+
layers_per_group = int(self.config.num_hidden_layers / self.config.num_hidden_groups)
|
371 |
+
|
372 |
+
# Index of the hidden group
|
373 |
+
group_idx = int(current_layer / (self.config.num_hidden_layers / self.config.num_hidden_groups))
|
374 |
+
|
375 |
+
# Index of the layer inside the group
|
376 |
+
layer_idx = int(current_layer - group_idx * layers_per_group)
|
377 |
+
|
378 |
+
layer_group_output = self.albert_layer_groups[group_idx](hidden_states, attention_mask, head_mask[group_idx * layers_per_group:(group_idx + 1) * layers_per_group])
|
379 |
+
hidden_states = layer_group_output[0]
|
380 |
+
|
381 |
+
return (hidden_states,)
|
382 |
+
|
383 |
+
class AlbertPreTrainedModel(PreTrainedModel):
|
384 |
+
""" An abstract class to handle weights initialization and
|
385 |
+
a simple interface for downloading and loading pretrained models.
|
386 |
+
"""
|
387 |
+
|
388 |
+
config_class = AlbertConfig
|
389 |
+
pretrained_model_archive_map = ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
390 |
+
base_model_prefix = "albert"
|
391 |
+
|
392 |
+
def _init_weights(self, module):
|
393 |
+
""" Initialize the weights.
|
394 |
+
"""
|
395 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
396 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
397 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
398 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
399 |
+
if isinstance(module, (nn.Linear)) and module.bias is not None:
|
400 |
+
module.bias.data.zero_()
|
401 |
+
elif isinstance(module, nn.LayerNorm):
|
402 |
+
module.bias.data.zero_()
|
403 |
+
module.weight.data.fill_(1.0)
|
404 |
+
|
405 |
+
|
406 |
+
ALBERT_START_DOCSTRING = r"""
|
407 |
+
|
408 |
+
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
|
409 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general
|
410 |
+
usage and behavior.
|
411 |
+
|
412 |
+
Args:
|
413 |
+
config (:class:`~transformers.AlbertConfig`): Model configuration class with all the parameters of the model.
|
414 |
+
Initializing with a config file does not load the weights associated with the model, only the configuration.
|
415 |
+
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
|
416 |
+
"""
|
417 |
+
|
418 |
+
ALBERT_INPUTS_DOCSTRING = r"""
|
419 |
+
Args:
|
420 |
+
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
|
421 |
+
Indices of input sequence tokens in the vocabulary.
|
422 |
+
|
423 |
+
Indices can be obtained using :class:`transformers.AlbertTokenizer`.
|
424 |
+
See :func:`transformers.PreTrainedTokenizer.encode` and
|
425 |
+
:func:`transformers.PreTrainedTokenizer.encode_plus` for details.
|
426 |
+
|
427 |
+
`What are input IDs? <../glossary.html#input-ids>`__
|
428 |
+
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
429 |
+
Mask to avoid performing attention on padding token indices.
|
430 |
+
Mask values selected in ``[0, 1]``:
|
431 |
+
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
432 |
+
|
433 |
+
`What are attention masks? <../glossary.html#attention-mask>`__
|
434 |
+
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
435 |
+
Segment token indices to indicate first and second portions of the inputs.
|
436 |
+
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
|
437 |
+
corresponds to a `sentence B` token
|
438 |
+
|
439 |
+
`What are token type IDs? <../glossary.html#token-type-ids>`_
|
440 |
+
position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
441 |
+
Indices of positions of each input sequence tokens in the position embeddings.
|
442 |
+
Selected in the range ``[0, config.max_position_embeddings - 1]``.
|
443 |
+
|
444 |
+
`What are position IDs? <../glossary.html#position-ids>`_
|
445 |
+
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
|
446 |
+
Mask to nullify selected heads of the self-attention modules.
|
447 |
+
Mask values selected in ``[0, 1]``:
|
448 |
+
:obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
|
449 |
+
input_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
|
450 |
+
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
451 |
+
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
452 |
+
than the model's internal embedding lookup matrix.
|
453 |
+
"""
|
454 |
+
|
455 |
+
|
456 |
+
@add_start_docstrings(
|
457 |
+
"The bare ALBERT Model transformer outputting raw hidden-states without any specific head on top.",
|
458 |
+
ALBERT_START_DOCSTRING,
|
459 |
+
)
|
460 |
+
class AlbertModel(AlbertPreTrainedModel):
|
461 |
+
|
462 |
+
config_class = AlbertConfig
|
463 |
+
pretrained_model_archive_map = ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
464 |
+
load_tf_weights = load_tf_weights_in_albert
|
465 |
+
base_model_prefix = "albert"
|
466 |
+
|
467 |
+
def __init__(self, config):
|
468 |
+
super().__init__(config)
|
469 |
+
|
470 |
+
self.config = config
|
471 |
+
self.embeddings = AlbertEmbeddings(config)
|
472 |
+
self.encoder = AlbertTransformer(config)
|
473 |
+
self.pooler = nn.Linear(config.hidden_size, config.hidden_size)
|
474 |
+
self.pooler_activation = nn.Tanh()
|
475 |
+
|
476 |
+
self.init_weights()
|
477 |
+
# hyper-param for patience-based adaptive inference
|
478 |
+
self.patience = 0
|
479 |
+
# threshold for confidence-based adaptive inference
|
480 |
+
self.confidence_threshold = 0.8
|
481 |
+
# mode for fast inference: 'patience' (patience-based) / 'confi' (confidence-based) / 'all' (run every classifier) / 'last' (last classifier only) / 'exact' (exit at a fixed layer)
|
482 |
+
self.mode = 'patience' # [patience/confi/all/last]
|
483 |
+
|
484 |
+
self.inference_instances_num = 0
|
485 |
+
self.inference_layers_num = 0
|
486 |
+
|
487 |
+
# exits count log
|
488 |
+
self.exits_count_list = [0] * self.config.num_hidden_layers
|
489 |
+
# exits time log
|
490 |
+
self.exits_time_list = [[] for _ in range(self.config.num_hidden_layers)]
|
491 |
+
|
492 |
+
self.regression_threshold = 0
|
493 |
+
|
494 |
+
def set_regression_threshold(self, threshold):
|
495 |
+
self.regression_threshold = threshold
|
496 |
+
|
497 |
+
def set_mode(self, patience='patience'):
|
498 |
+
self.mode = patience # mode for test-time inference
|
499 |
+
|
500 |
+
def set_patience(self, patience):
|
501 |
+
self.patience = patience
|
502 |
+
|
503 |
+
def set_exit_pos(self, exit_pos):
|
504 |
+
self.exit_pos = exit_pos
|
505 |
+
|
506 |
+
def set_confi_threshold(self, confidence_threshold):
|
507 |
+
self.confidence_threshold = confidence_threshold
|
508 |
+
|
509 |
+
def reset_stats(self):
|
510 |
+
self.inference_instances_num = 0
|
511 |
+
self.inference_layers_num = 0
|
512 |
+
self.exits_count_list = [0] * self.config.num_hidden_layers
|
513 |
+
self.exits_time_list = [[] for _ in range(self.config.num_hidden_layers)]
|
514 |
+
|
515 |
+
def log_stats(self):
|
516 |
+
avg_inf_layers = self.inference_layers_num / self.inference_instances_num
|
517 |
+
message = f'*** Patience = {self.patience} Avg. Inference Layers = {avg_inf_layers:.2f} Speed Up = {1 - avg_inf_layers / self.config.num_hidden_layers:.2f} ***'
|
518 |
+
print(message)
|
519 |
+
|
520 |
+
def get_input_embeddings(self):
|
521 |
+
return self.embeddings.word_embeddings
|
522 |
+
|
523 |
+
def set_input_embeddings(self, value):
|
524 |
+
self.embeddings.word_embeddings = value
|
525 |
+
|
526 |
+
def _resize_token_embeddings(self, new_num_tokens):
|
527 |
+
old_embeddings = self.embeddings.word_embeddings
|
528 |
+
new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
|
529 |
+
self.embeddings.word_embeddings = new_embeddings
|
530 |
+
return self.embeddings.word_embeddings
|
531 |
+
|
532 |
+
def _prune_heads(self, heads_to_prune):
|
533 |
+
""" Prunes heads of the model.
|
534 |
+
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
|
535 |
+
ALBERT has a different architecture in that its layers are shared across groups, which then has inner groups.
|
536 |
+
If an ALBERT model has 12 hidden layers and 2 hidden groups, with two inner groups, there
|
537 |
+
is a total of 4 different layers.
|
538 |
+
|
539 |
+
These layers are flattened: the indices [0,1] correspond to the two inner groups of the first hidden layer,
|
540 |
+
while [2,3] correspond to the two inner groups of the second hidden layer.
|
541 |
+
|
542 |
+
Any layer with an index other than [0,1,2,3] will result in an error.
|
543 |
+
See base class PreTrainedModel for more information about head pruning
|
544 |
+
"""
|
545 |
+
for layer, heads in heads_to_prune.items():
|
546 |
+
group_idx = int(layer / self.config.inner_group_num)
|
547 |
+
inner_group_idx = int(layer - group_idx * self.config.inner_group_num)
|
548 |
+
self.encoder.albert_layer_groups[group_idx].albert_layers[inner_group_idx].attention.prune_heads(heads)
|
549 |
+
|
550 |
+
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING)
|
551 |
+
def forward(
|
552 |
+
self,
|
553 |
+
input_ids=None,
|
554 |
+
attention_mask=None,
|
555 |
+
token_type_ids=None,
|
556 |
+
position_ids=None,
|
557 |
+
head_mask=None,
|
558 |
+
inputs_embeds=None,
|
559 |
+
output_dropout=None,
|
560 |
+
output_layers=None,
|
561 |
+
regression=False
|
562 |
+
):
|
563 |
+
r"""
|
564 |
+
Return:
|
565 |
+
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.AlbertConfig`) and inputs:
|
566 |
+
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
567 |
+
Sequence of hidden-states at the output of the last layer of the model.
|
568 |
+
pooler_output (:obj:`torch.FloatTensor`: of shape :obj:`(batch_size, hidden_size)`):
|
569 |
+
Last layer hidden-state of the first token of the sequence (classification token)
|
570 |
+
further processed by a Linear layer and a Tanh activation function. The Linear
|
571 |
+
layer weights are trained from the next sentence prediction (classification)
|
572 |
+
objective during pre-training.
|
573 |
+
|
574 |
+
This output is usually *not* a good summary
|
575 |
+
of the semantic content of the input, you're often better with averaging or pooling
|
576 |
+
the sequence of hidden-states for the whole input sequence.
|
577 |
+
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
|
578 |
+
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
579 |
+
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
580 |
+
|
581 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
582 |
+
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
|
583 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
584 |
+
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
585 |
+
|
586 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
587 |
+
heads.
|
588 |
+
|
589 |
+
Example::
|
590 |
+
|
591 |
+
from transformers import AlbertModel, AlbertTokenizer
|
592 |
+
import torch
|
593 |
+
|
594 |
+
tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
|
595 |
+
model = AlbertModel.from_pretrained('albert-base-v2')
|
596 |
+
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
|
597 |
+
outputs = model(input_ids)
|
598 |
+
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
599 |
+
|
600 |
+
"""
|
601 |
+
|
602 |
+
if input_ids is not None and inputs_embeds is not None:
|
603 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
604 |
+
elif input_ids is not None:
|
605 |
+
input_shape = input_ids.size()
|
606 |
+
elif inputs_embeds is not None:
|
607 |
+
input_shape = inputs_embeds.size()[:-1]
|
608 |
+
else:
|
609 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
610 |
+
|
611 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
612 |
+
|
613 |
+
if attention_mask is None:
|
614 |
+
attention_mask = torch.ones(input_shape, device=device)
|
615 |
+
if token_type_ids is None:
|
616 |
+
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
617 |
+
|
618 |
+
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
619 |
+
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
620 |
+
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
621 |
+
if head_mask is not None:
|
622 |
+
if head_mask.dim() == 1:
|
623 |
+
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
624 |
+
head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
|
625 |
+
elif head_mask.dim() == 2:
|
626 |
+
head_mask = (
|
627 |
+
head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
|
628 |
+
) # We can specify head_mask for each layer
|
629 |
+
head_mask = head_mask.to(
|
630 |
+
dtype=next(self.parameters()).dtype
|
631 |
+
) # switch to float if needed + fp16 compatibility
|
632 |
+
else:
|
633 |
+
head_mask = [None] * self.config.num_hidden_layers
|
634 |
+
|
635 |
+
embedding_output = self.embeddings(
|
636 |
+
input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
|
637 |
+
)
|
638 |
+
encoder_outputs = embedding_output
|
639 |
+
|
640 |
+
if self.training:
|
641 |
+
res = []
|
642 |
+
for i in range(self.config.num_hidden_layers):
|
643 |
+
encoder_outputs = self.encoder.adaptive_forward(encoder_outputs,
|
644 |
+
current_layer=i,
|
645 |
+
attention_mask=extended_attention_mask,
|
646 |
+
head_mask=head_mask
|
647 |
+
)
|
648 |
+
|
649 |
+
pooled_output = self.pooler_activation(self.pooler(encoder_outputs[0][:, 0]))
|
650 |
+
logits = output_layers[i](output_dropout(pooled_output))
|
651 |
+
res.append(logits)
|
652 |
+
elif self.mode == 'last': # Use all layers for inference [last classifier]
|
653 |
+
encoder_outputs = self.encoder(encoder_outputs,
|
654 |
+
extended_attention_mask,
|
655 |
+
head_mask=head_mask)
|
656 |
+
pooled_output = self.pooler_activation(self.pooler(encoder_outputs[0][:, 0]))
|
657 |
+
res = [output_layers[self.config.num_hidden_layers - 1](pooled_output)]
|
658 |
+
elif self.mode == 'exact':
|
659 |
+
res = []
|
660 |
+
for i in range(self.exit_pos):
|
661 |
+
encoder_outputs = self.encoder.adaptive_forward(encoder_outputs,
|
662 |
+
current_layer=i,
|
663 |
+
attention_mask=extended_attention_mask,
|
664 |
+
head_mask=head_mask
|
665 |
+
)
|
666 |
+
pooled_output = self.pooler_activation(self.pooler(encoder_outputs[0][:, 0]))
|
667 |
+
logits = output_layers[i](output_dropout(pooled_output))
|
668 |
+
res.append(logits)
|
669 |
+
elif self.mode == 'all':
|
670 |
+
tic = datetime.now()
|
671 |
+
res = []
|
672 |
+
for i in range(self.config.num_hidden_layers):
|
673 |
+
encoder_outputs = self.encoder.adaptive_forward(encoder_outputs,
|
674 |
+
current_layer=i,
|
675 |
+
attention_mask=extended_attention_mask,
|
676 |
+
head_mask=head_mask
|
677 |
+
)
|
678 |
+
pooled_output = self.pooler_activation(self.pooler(encoder_outputs[0][:, 0]))
|
679 |
+
logits = output_layers[i](output_dropout(pooled_output))
|
680 |
+
toc = datetime.now()
|
681 |
+
exit_time = (toc - tic).total_seconds()
|
682 |
+
res.append(logits)
|
683 |
+
self.exits_time_list[i].append(exit_time)
|
684 |
+
elif self.mode=='patience': # fast inference for patience-based
|
685 |
+
if self.patience <=0:
|
686 |
+
raise ValueError("Patience must be greater than 0")
|
687 |
+
|
688 |
+
patient_counter = 0
|
689 |
+
patient_result = None
|
690 |
+
calculated_layer_num = 0
|
691 |
+
# tic = datetime.now()
|
692 |
+
for i in range(self.config.num_hidden_layers):
|
693 |
+
calculated_layer_num += 1
|
694 |
+
encoder_outputs = self.encoder.adaptive_forward(encoder_outputs,
|
695 |
+
current_layer=i,
|
696 |
+
attention_mask=extended_attention_mask,
|
697 |
+
head_mask=head_mask
|
698 |
+
)
|
699 |
+
|
700 |
+
pooled_output = self.pooler_activation(self.pooler(encoder_outputs[0][:, 0]))
|
701 |
+
logits = output_layers[i](pooled_output)
|
702 |
+
if regression:
|
703 |
+
labels = logits.detach()
|
704 |
+
if patient_result is not None:
|
705 |
+
patient_labels = patient_result.detach()
|
706 |
+
if (patient_result is not None) and torch.abs(patient_result - labels) < self.regression_threshold:
|
707 |
+
patient_counter += 1
|
708 |
+
else:
|
709 |
+
patient_counter = 0
|
710 |
+
else:
|
711 |
+
labels = logits.detach().argmax(dim=1)
|
712 |
+
if patient_result is not None:
|
713 |
+
patient_labels = patient_result.detach().argmax(dim=1)
|
714 |
+
if (patient_result is not None) and torch.all(labels.eq(patient_labels)):
|
715 |
+
patient_counter += 1
|
716 |
+
else:
|
717 |
+
patient_counter = 0
|
718 |
+
|
719 |
+
patient_result = logits
|
720 |
+
if patient_counter == self.patience:
|
721 |
+
break
|
722 |
+
# toc = datetime.now()
|
723 |
+
# self.exit_time = (toc - tic).total_seconds()
|
724 |
+
res = [patient_result]
|
725 |
+
self.inference_layers_num += calculated_layer_num
|
726 |
+
self.inference_instances_num += 1
|
727 |
+
self.current_exit_layer = calculated_layer_num
|
728 |
+
# LOG EXIT POINTS COUNTS
|
729 |
+
self.exits_count_list[calculated_layer_num-1] += 1
|
730 |
+
elif self.mode == 'confi':
|
731 |
+
if self.confidence_threshold<0 or self.confidence_threshold>1:
|
732 |
+
raise ValueError('Confidence Threshold must be set within the range 0-1')
|
733 |
+
calculated_layer_num = 0
|
734 |
+
tic = datetime.now()
|
735 |
+
for i in range(self.config.num_hidden_layers):
|
736 |
+
calculated_layer_num += 1
|
737 |
+
encoder_outputs = self.encoder.adaptive_forward(encoder_outputs,
|
738 |
+
current_layer=i,
|
739 |
+
attention_mask=extended_attention_mask,
|
740 |
+
head_mask=head_mask
|
741 |
+
)
|
742 |
+
|
743 |
+
pooled_output = self.pooler_activation(self.pooler(encoder_outputs[0][:, 0]))
|
744 |
+
logits = output_layers[i](pooled_output)
|
745 |
+
labels = logits.detach().argmax(dim=1)
|
746 |
+
logits_max,_ = logits.detach().softmax(dim=1).max(dim=1)
|
747 |
+
|
748 |
+
confi_result = logits
|
749 |
+
if torch.all(logits_max.gt(self.confidence_threshold)):
|
750 |
+
break
|
751 |
+
toc = datetime.now()
|
752 |
+
self.exit_time = (toc - tic).total_seconds()
|
753 |
+
res = [confi_result]
|
754 |
+
self.inference_layers_num += calculated_layer_num
|
755 |
+
self.inference_instances_num += 1
|
756 |
+
self.current_exit_layer = calculated_layer_num
|
757 |
+
# LOG EXIT POINTS COUNTS
|
758 |
+
self.exits_count_list[calculated_layer_num-1] += 1
|
759 |
+
return res
|
760 |
+
|
761 |
+
|
762 |
+
class AlbertMLMHead(nn.Module):
|
763 |
+
def __init__(self, config):
|
764 |
+
super().__init__()
|
765 |
+
|
766 |
+
self.LayerNorm = nn.LayerNorm(config.embedding_size)
|
767 |
+
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
768 |
+
self.dense = nn.Linear(config.hidden_size, config.embedding_size)
|
769 |
+
self.decoder = nn.Linear(config.embedding_size, config.vocab_size)
|
770 |
+
self.activation = ACT2FN[config.hidden_act]
|
771 |
+
|
772 |
+
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
773 |
+
self.decoder.bias = self.bias
|
774 |
+
|
775 |
+
def forward(self, hidden_states):
|
776 |
+
hidden_states = self.dense(hidden_states)
|
777 |
+
hidden_states = self.activation(hidden_states)
|
778 |
+
hidden_states = self.LayerNorm(hidden_states)
|
779 |
+
hidden_states = self.decoder(hidden_states)
|
780 |
+
|
781 |
+
prediction_scores = hidden_states
|
782 |
+
|
783 |
+
return prediction_scores
|
784 |
+
|
785 |
+
|
786 |
+
@add_start_docstrings(
|
787 |
+
"Albert Model with a `language modeling` head on top.", ALBERT_START_DOCSTRING,
|
788 |
+
)
|
789 |
+
class AlbertForMaskedLM(AlbertPreTrainedModel):
|
790 |
+
def __init__(self, config):
|
791 |
+
super().__init__(config)
|
792 |
+
|
793 |
+
self.albert = AlbertModel(config)
|
794 |
+
self.predictions = AlbertMLMHead(config)
|
795 |
+
|
796 |
+
self.init_weights()
|
797 |
+
self.tie_weights()
|
798 |
+
|
799 |
+
def tie_weights(self):
|
800 |
+
self._tie_or_clone_weights(self.predictions.decoder, self.albert.embeddings.word_embeddings)
|
801 |
+
|
802 |
+
def get_output_embeddings(self):
|
803 |
+
return self.predictions.decoder
|
804 |
+
|
805 |
+
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING)
|
806 |
+
def forward(
|
807 |
+
self,
|
808 |
+
input_ids=None,
|
809 |
+
attention_mask=None,
|
810 |
+
token_type_ids=None,
|
811 |
+
position_ids=None,
|
812 |
+
head_mask=None,
|
813 |
+
inputs_embeds=None,
|
814 |
+
masked_lm_labels=None,
|
815 |
+
):
|
816 |
+
r"""
|
817 |
+
masked_lm_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
818 |
+
Labels for computing the masked language modeling loss.
|
819 |
+
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
|
820 |
+
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with
|
821 |
+
labels in ``[0, ..., config.vocab_size]``
|
822 |
+
|
823 |
+
Returns:
|
824 |
+
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.AlbertConfig`) and inputs:
|
825 |
+
loss (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
826 |
+
Masked language modeling loss.
|
827 |
+
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
|
828 |
+
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
829 |
+
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
|
830 |
+
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
831 |
+
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
832 |
+
|
833 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
834 |
+
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
|
835 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
836 |
+
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
837 |
+
|
838 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
839 |
+
heads.
|
840 |
+
|
841 |
+
Example::
|
842 |
+
|
843 |
+
from transformers import AlbertTokenizer, AlbertForMaskedLM
|
844 |
+
import torch
|
845 |
+
|
846 |
+
tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
|
847 |
+
model = AlbertForMaskedLM.from_pretrained('albert-base-v2')
|
848 |
+
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
|
849 |
+
outputs = model(input_ids, masked_lm_labels=input_ids)
|
850 |
+
loss, prediction_scores = outputs[:2]
|
851 |
+
|
852 |
+
"""
|
853 |
+
outputs = self.albert(
|
854 |
+
input_ids=input_ids,
|
855 |
+
attention_mask=attention_mask,
|
856 |
+
token_type_ids=token_type_ids,
|
857 |
+
position_ids=position_ids,
|
858 |
+
head_mask=head_mask,
|
859 |
+
inputs_embeds=inputs_embeds,
|
860 |
+
)
|
861 |
+
sequence_outputs = outputs[0]
|
862 |
+
|
863 |
+
prediction_scores = self.predictions(sequence_outputs)
|
864 |
+
|
865 |
+
outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
|
866 |
+
if masked_lm_labels is not None:
|
867 |
+
loss_fct = CrossEntropyLoss()
|
868 |
+
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
|
869 |
+
outputs = (masked_lm_loss,) + outputs
|
870 |
+
|
871 |
+
return outputs
|
872 |
+
|
873 |
+
|
874 |
+
@add_start_docstrings(
|
875 |
+
"""Albert Model transformer with a sequence classification/regression head on top (a linear layer on top of
|
876 |
+
the pooled output) e.g. for GLUE tasks. """,
|
877 |
+
ALBERT_START_DOCSTRING,
|
878 |
+
)
|
879 |
+
class AlbertForSequenceClassification(AlbertPreTrainedModel):
|
880 |
+
def __init__(self, config):
|
881 |
+
super().__init__(config)
|
882 |
+
self.num_labels = config.num_labels
|
883 |
+
|
884 |
+
self.albert = AlbertModel(config)
|
885 |
+
self.dropout = nn.Dropout(config.classifier_dropout_prob)
|
886 |
+
self.classifiers = nn.ModuleList([nn.Linear(config.hidden_size, self.config.num_labels) for _ in range(config.num_hidden_layers)])
|
887 |
+
|
888 |
+
self.init_weights()
|
889 |
+
|
890 |
+
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING)
|
891 |
+
def forward(
|
892 |
+
self,
|
893 |
+
input_ids=None,
|
894 |
+
attention_mask=None,
|
895 |
+
token_type_ids=None,
|
896 |
+
position_ids=None,
|
897 |
+
head_mask=None,
|
898 |
+
inputs_embeds=None,
|
899 |
+
labels=None,
|
900 |
+
):
|
901 |
+
r"""
|
902 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
903 |
+
Labels for computing the sequence classification/regression loss.
|
904 |
+
Indices should be in ``[0, ..., config.num_labels - 1]``.
|
905 |
+
If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
|
906 |
+
If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
|
907 |
+
|
908 |
+
Returns:
|
909 |
+
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.AlbertConfig`) and inputs:
|
910 |
+
loss: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
911 |
+
Classification (or regression if config.num_labels==1) loss.
|
912 |
+
logits ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)``
|
913 |
+
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
914 |
+
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
|
915 |
+
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
916 |
+
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
917 |
+
|
918 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
919 |
+
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
|
920 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
921 |
+
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
922 |
+
|
923 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
924 |
+
heads.
|
925 |
+
|
926 |
+
Examples::
|
927 |
+
|
928 |
+
from transformers import AlbertTokenizer, AlbertForSequenceClassification
|
929 |
+
import torch
|
930 |
+
|
931 |
+
tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
|
932 |
+
model = AlbertForSequenceClassification.from_pretrained('albert-base-v2')
|
933 |
+
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
|
934 |
+
labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
|
935 |
+
outputs = model(input_ids, labels=labels)
|
936 |
+
loss, logits = outputs[:2]
|
937 |
+
|
938 |
+
"""
|
939 |
+
|
940 |
+
logits = self.albert(
|
941 |
+
input_ids=input_ids,
|
942 |
+
attention_mask=attention_mask,
|
943 |
+
token_type_ids=token_type_ids,
|
944 |
+
position_ids=position_ids,
|
945 |
+
head_mask=head_mask,
|
946 |
+
inputs_embeds=inputs_embeds,
|
947 |
+
output_dropout=self.dropout,
|
948 |
+
output_layers=self.classifiers,
|
949 |
+
regression=self.num_labels == 1
|
950 |
+
)
|
951 |
+
|
952 |
+
if self.albert.mode == 'all':
|
953 |
+
outputs = (logits,)
|
954 |
+
else:
|
955 |
+
outputs = (logits[-1],)
|
956 |
+
|
957 |
+
if labels is not None:
|
958 |
+
total_loss = None
|
959 |
+
total_weights = 0
|
960 |
+
for ix, logits_item in enumerate(logits):
|
961 |
+
if self.num_labels == 1:
|
962 |
+
# We are doing regression
|
963 |
+
loss_fct = MSELoss()
|
964 |
+
loss = loss_fct(logits_item.view(-1), labels.view(-1))
|
965 |
+
else:
|
966 |
+
loss_fct = CrossEntropyLoss()
|
967 |
+
loss = loss_fct(logits_item.view(-1, self.num_labels), labels.view(-1))
|
968 |
+
if total_loss is None:
|
969 |
+
total_loss = loss
|
970 |
+
else:
|
971 |
+
total_loss += loss * (ix + 1)  # deeper exit classifiers get linearly larger loss weights
|
972 |
+
total_weights += ix + 1
|
973 |
+
outputs = (total_loss / total_weights,) + outputs
|
974 |
+
|
975 |
+
return outputs # (loss), logits, (hidden_states), (attentions)
|
976 |
+
|
977 |
+
|
978 |
+
@add_start_docstrings(
|
979 |
+
"""Albert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
|
980 |
+
the hidden-states output to compute `span start logits` and `span end logits`). """,
|
981 |
+
ALBERT_START_DOCSTRING,
|
982 |
+
)
|
983 |
+
class AlbertForQuestionAnswering(AlbertPreTrainedModel):
|
984 |
+
def __init__(self, config):
|
985 |
+
super().__init__(config)
|
986 |
+
self.num_labels = config.num_labels
|
987 |
+
|
988 |
+
self.albert = AlbertModel(config)
|
989 |
+
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
|
990 |
+
|
991 |
+
self.init_weights()
|
992 |
+
|
993 |
+
@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING)
|
994 |
+
def forward(
|
995 |
+
self,
|
996 |
+
input_ids=None,
|
997 |
+
attention_mask=None,
|
998 |
+
token_type_ids=None,
|
999 |
+
position_ids=None,
|
1000 |
+
head_mask=None,
|
1001 |
+
inputs_embeds=None,
|
1002 |
+
start_positions=None,
|
1003 |
+
end_positions=None,
|
1004 |
+
):
|
1005 |
+
r"""
|
1006 |
+
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
1007 |
+
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
1008 |
+
Positions are clamped to the length of the sequence (`sequence_length`).
|
1009 |
+
Position outside of the sequence are not taken into account for computing the loss.
|
1010 |
+
end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
1011 |
+
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
1012 |
+
Positions are clamped to the length of the sequence (`sequence_length`).
|
1013 |
+
Position outside of the sequence are not taken into account for computing the loss.
|
1014 |
+
|
1015 |
+
Returns:
|
1016 |
+
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.AlbertConfig`) and inputs:
|
1017 |
+
loss: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
1018 |
+
Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
|
1019 |
+
start_scores ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
|
1020 |
+
Span-start scores (before SoftMax).
|
1021 |
+
end_scores: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
|
1022 |
+
Span-end scores (before SoftMax).
|
1023 |
+
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
|
1024 |
+
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
1025 |
+
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
1026 |
+
|
1027 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
1028 |
+
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
|
1029 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
1030 |
+
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
1031 |
+
|
1032 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
1033 |
+
heads.
|
1034 |
+
|
1035 |
+
Examples::
|
1036 |
+
|
1037 |
+
# The checkpoint albert-base-v2 is not fine-tuned for question answering. Please see the
|
1038 |
+
# examples/run_squad.py example to see how to fine-tune a model to a question answering task.
|
1039 |
+
|
1040 |
+
from transformers import AlbertTokenizer, AlbertForQuestionAnswering
|
1041 |
+
import torch
|
1042 |
+
|
1043 |
+
tokenizer = AlbertTokenizer.from_pretrained('albert-base-v2')
|
1044 |
+
model = AlbertForQuestionAnswering.from_pretrained('albert-base-v2')
|
1045 |
+
question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
|
1046 |
+
input_dict = tokenizer.encode_plus(question, text, return_tensors='pt')
|
1047 |
+
start_scores, end_scores = model(**input_dict)
|
1048 |
+
|
1049 |
+
"""
|
1050 |
+
|
1051 |
+
outputs = self.albert(
|
1052 |
+
input_ids=input_ids,
|
1053 |
+
attention_mask=attention_mask,
|
1054 |
+
token_type_ids=token_type_ids,
|
1055 |
+
position_ids=position_ids,
|
1056 |
+
head_mask=head_mask,
|
1057 |
+
inputs_embeds=inputs_embeds,
|
1058 |
+
)
|
1059 |
+
|
1060 |
+
sequence_output = outputs[0]
|
1061 |
+
|
1062 |
+
logits = self.qa_outputs(sequence_output)
|
1063 |
+
start_logits, end_logits = logits.split(1, dim=-1)
|
1064 |
+
start_logits = start_logits.squeeze(-1)
|
1065 |
+
end_logits = end_logits.squeeze(-1)
|
1066 |
+
|
1067 |
+
outputs = (start_logits, end_logits,) + outputs[2:]
|
1068 |
+
if start_positions is not None and end_positions is not None:
|
1069 |
+
# If we are on multi-GPU, split adds a dimension
|
1070 |
+
if len(start_positions.size()) > 1:
|
1071 |
+
start_positions = start_positions.squeeze(-1)
|
1072 |
+
if len(end_positions.size()) > 1:
|
1073 |
+
end_positions = end_positions.squeeze(-1)
|
1074 |
+
# Sometimes the start/end positions are outside our model inputs; we ignore these terms
|
1075 |
+
ignored_index = start_logits.size(1)
|
1076 |
+
start_positions.clamp_(0, ignored_index)
|
1077 |
+
end_positions.clamp_(0, ignored_index)
|
1078 |
+
|
1079 |
+
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
1080 |
+
start_loss = loss_fct(start_logits, start_positions)
|
1081 |
+
end_loss = loss_fct(end_logits, end_positions)
|
1082 |
+
total_loss = (start_loss + end_loss) / 2
|
1083 |
+
outputs = (total_loss,) + outputs
|
1084 |
+
|
1085 |
+
return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
|
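The classes above implement PABEE-style multi-exit inference on top of ALBERT: each encoder layer feeds its own classifier, and at test time the model stops as soon as 'patience' consecutive classifiers agree (or, in 'confi' mode, as soon as the softmax confidence clears 'confidence_threshold'). The snippet below is a minimal sketch of how these hooks fit together at inference time; it assumes this file is importable as pabee.modeling_albert and uses a randomly initialized config instead of a trained checkpoint, so the exact exit layer is illustrative only.

import torch
from transformers.models.albert.configuration_albert import AlbertConfig
from pabee.modeling_albert import AlbertForSequenceClassification  # the class defined above

# Toy config; a real run would load a fine-tuned checkpoint via from_pretrained instead.
config = AlbertConfig(num_hidden_layers=12, num_labels=2)
model = AlbertForSequenceClassification(config)
model.eval()

# Patience-based early exit: stop once 6 consecutive internal classifiers agree.
model.albert.set_mode('patience')
model.albert.set_patience(6)

input_ids = torch.randint(0, config.vocab_size, (1, 16))  # one toy sequence, batch size 1
with torch.no_grad():
    logits = model(input_ids=input_ids)[0]  # logits from the classifier at the exit layer

print(logits.shape)                     # torch.Size([1, 2])
print(model.albert.current_exit_layer)  # number of encoder layers actually computed
model.albert.log_stats()                # average exit depth and speed-up over the queries seen so far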
pabee/modeling_bert.py
ADDED
@@ -0,0 +1,1663 @@
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""PyTorch BERT model. """
|
17 |
+
|
18 |
+
|
19 |
+
import logging
|
20 |
+
import math
|
21 |
+
import os
|
22 |
+
|
23 |
+
import torch
|
24 |
+
from torch import nn
|
25 |
+
from torch.nn import CrossEntropyLoss, MSELoss
|
26 |
+
|
27 |
+
from transformers.activations import gelu, gelu_new #, swish
|
28 |
+
from transformers.models.bert.configuration_bert import BertConfig
|
29 |
+
from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward
|
30 |
+
from transformers.modeling_utils import PreTrainedModel, prune_linear_layer
|
31 |
+
from datetime import datetime
|
32 |
+
|
33 |
+
logger = logging.getLogger(__name__)
|
34 |
+
|
35 |
+
BERT_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
36 |
+
"bert-base-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-pytorch_model.bin",
|
37 |
+
"bert-large-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-pytorch_model.bin",
|
38 |
+
"bert-base-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-pytorch_model.bin",
|
39 |
+
"bert-large-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-pytorch_model.bin",
|
40 |
+
"bert-base-multilingual-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-uncased-pytorch_model.bin",
|
41 |
+
"bert-base-multilingual-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-multilingual-cased-pytorch_model.bin",
|
42 |
+
"bert-base-chinese": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-pytorch_model.bin",
|
43 |
+
"bert-base-german-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-cased-pytorch_model.bin",
|
44 |
+
"bert-large-uncased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-pytorch_model.bin",
|
45 |
+
"bert-large-cased-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-pytorch_model.bin",
|
46 |
+
"bert-large-uncased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-whole-word-masking-finetuned-squad-pytorch_model.bin",
|
47 |
+
"bert-large-cased-whole-word-masking-finetuned-squad": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-pytorch_model.bin",
|
48 |
+
"bert-base-cased-finetuned-mrpc": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin",
|
49 |
+
"bert-base-german-dbmdz-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-cased-pytorch_model.bin",
|
50 |
+
"bert-base-german-dbmdz-uncased": "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-german-dbmdz-uncased-pytorch_model.bin",
|
51 |
+
"bert-base-japanese": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-pytorch_model.bin",
|
52 |
+
"bert-base-japanese-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-whole-word-masking-pytorch_model.bin",
|
53 |
+
"bert-base-japanese-char": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-pytorch_model.bin",
|
54 |
+
"bert-base-japanese-char-whole-word-masking": "https://s3.amazonaws.com/models.huggingface.co/bert/cl-tohoku/bert-base-japanese-char-whole-word-masking-pytorch_model.bin",
|
55 |
+
"bert-base-finnish-cased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-cased-v1/pytorch_model.bin",
|
56 |
+
"bert-base-finnish-uncased-v1": "https://s3.amazonaws.com/models.huggingface.co/bert/TurkuNLP/bert-base-finnish-uncased-v1/pytorch_model.bin",
|
57 |
+
"bert-base-dutch-cased": "https://s3.amazonaws.com/models.huggingface.co/bert/wietsedv/bert-base-dutch-cased/pytorch_model.bin",
|
58 |
+
}
|
59 |
+
|
60 |
+
|
61 |
+
def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
|
62 |
+
""" Load tf checkpoints in a pytorch model.
|
63 |
+
"""
|
64 |
+
try:
|
65 |
+
import re
|
66 |
+
import numpy as np
|
67 |
+
import tensorflow as tf
|
68 |
+
except ImportError:
|
69 |
+
logger.error(
|
70 |
+
"Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
|
71 |
+
"https://www.tensorflow.org/install/ for installation instructions."
|
72 |
+
)
|
73 |
+
raise
|
74 |
+
tf_path = os.path.abspath(tf_checkpoint_path)
|
75 |
+
logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
|
76 |
+
# Load weights from TF model
|
77 |
+
init_vars = tf.train.list_variables(tf_path)
|
78 |
+
names = []
|
79 |
+
arrays = []
|
80 |
+
for name, shape in init_vars:
|
81 |
+
logger.info("Loading TF weight {} with shape {}".format(name, shape))
|
82 |
+
array = tf.train.load_variable(tf_path, name)
|
83 |
+
names.append(name)
|
84 |
+
arrays.append(array)
|
85 |
+
|
86 |
+
for name, array in zip(names, arrays):
|
87 |
+
name = name.split("/")
|
88 |
+
# adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v,
|
89 |
+
# which are not required when using the pretrained model
|
90 |
+
if any(
|
91 |
+
n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
|
92 |
+
for n in name
|
93 |
+
):
|
94 |
+
logger.info("Skipping {}".format("/".join(name)))
|
95 |
+
continue
|
96 |
+
pointer = model
|
97 |
+
for m_name in name:
|
98 |
+
if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
|
99 |
+
scope_names = re.split(r"_(\d+)", m_name)
|
100 |
+
else:
|
101 |
+
scope_names = [m_name]
|
102 |
+
if scope_names[0] == "kernel" or scope_names[0] == "gamma":
|
103 |
+
pointer = getattr(pointer, "weight")
|
104 |
+
elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
|
105 |
+
pointer = getattr(pointer, "bias")
|
106 |
+
elif scope_names[0] == "output_weights":
|
107 |
+
pointer = getattr(pointer, "weight")
|
108 |
+
elif scope_names[0] == "squad":
|
109 |
+
pointer = getattr(pointer, "classifier")
|
110 |
+
else:
|
111 |
+
try:
|
112 |
+
pointer = getattr(pointer, scope_names[0])
|
113 |
+
except AttributeError:
|
114 |
+
logger.info("Skipping {}".format("/".join(name)))
|
115 |
+
continue
|
116 |
+
if len(scope_names) >= 2:
|
117 |
+
num = int(scope_names[1])
|
118 |
+
pointer = pointer[num]
|
119 |
+
if m_name[-11:] == "_embeddings":
|
120 |
+
pointer = getattr(pointer, "weight")
|
121 |
+
elif m_name == "kernel":
|
122 |
+
array = np.transpose(array)
|
123 |
+
try:
|
124 |
+
assert pointer.shape == array.shape
|
125 |
+
except AssertionError as e:
|
126 |
+
e.args += (pointer.shape, array.shape)
|
127 |
+
raise
|
128 |
+
logger.info("Initialize PyTorch weight {}".format(name))
|
129 |
+
pointer.data = torch.from_numpy(array)
|
130 |
+
return model
|
131 |
+
|
132 |
+
|
133 |
+
def mish(x):
|
134 |
+
return x * torch.tanh(nn.functional.softplus(x))
|
135 |
+
|
136 |
+
|
137 |
+
# ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish, "gelu_new": gelu_new, "mish": mish}
|
138 |
+
ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "gelu_new": gelu_new, "mish": mish}
|
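# ACT2FN maps the string values accepted for config.hidden_act (e.g. "gelu", "relu") to the
# corresponding activation callables; BertIntermediate and BertPredictionHeadTransform look the
# activation up here when config.hidden_act is given as a string.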
139 |
+
|
140 |
+
|
141 |
+
BertLayerNorm = torch.nn.LayerNorm
|
142 |
+
|
143 |
+
|
144 |
+
class BertEmbeddings(nn.Module):
|
145 |
+
"""Construct the embeddings from word, position and token_type embeddings.
|
146 |
+
"""
|
147 |
+
|
148 |
+
def __init__(self, config):
|
149 |
+
super().__init__()
|
150 |
+
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
|
151 |
+
self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
|
152 |
+
self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
|
153 |
+
|
154 |
+
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
155 |
+
# any TensorFlow checkpoint file
|
156 |
+
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
157 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
158 |
+
|
159 |
+
def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
|
160 |
+
if input_ids is not None:
|
161 |
+
input_shape = input_ids.size()
|
162 |
+
else:
|
163 |
+
input_shape = inputs_embeds.size()[:-1]
|
164 |
+
|
165 |
+
seq_length = input_shape[1]
|
166 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
167 |
+
if position_ids is None:
|
168 |
+
position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
|
169 |
+
position_ids = position_ids.unsqueeze(0).expand(input_shape)
|
170 |
+
if token_type_ids is None:
|
171 |
+
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
172 |
+
|
173 |
+
if inputs_embeds is None:
|
174 |
+
inputs_embeds = self.word_embeddings(input_ids)
|
175 |
+
position_embeddings = self.position_embeddings(position_ids)
|
176 |
+
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
177 |
+
|
178 |
+
embeddings = inputs_embeds + position_embeddings + token_type_embeddings
|
179 |
+
embeddings = self.LayerNorm(embeddings)
|
180 |
+
embeddings = self.dropout(embeddings)
|
181 |
+
return embeddings
|
182 |
+
|
183 |
+
|
184 |
+
class BertSelfAttention(nn.Module):
|
185 |
+
def __init__(self, config):
|
186 |
+
super().__init__()
|
187 |
+
if config.hidden_size % config.num_attention_heads != 0:
|
188 |
+
raise ValueError(
|
189 |
+
"The hidden size (%d) is not a multiple of the number of attention "
|
190 |
+
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
|
191 |
+
)
|
192 |
+
self.output_attentions = config.output_attentions
|
193 |
+
|
194 |
+
self.num_attention_heads = config.num_attention_heads
|
195 |
+
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
196 |
+
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
197 |
+
|
198 |
+
self.query = nn.Linear(config.hidden_size, self.all_head_size)
|
199 |
+
self.key = nn.Linear(config.hidden_size, self.all_head_size)
|
200 |
+
self.value = nn.Linear(config.hidden_size, self.all_head_size)
|
201 |
+
|
202 |
+
self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
|
203 |
+
|
204 |
+
def transpose_for_scores(self, x):
|
205 |
+
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
|
206 |
+
x = x.view(*new_x_shape)
|
207 |
+
return x.permute(0, 2, 1, 3)
|
208 |
+
|
209 |
+
def forward(
|
210 |
+
self,
|
211 |
+
hidden_states,
|
212 |
+
attention_mask=None,
|
213 |
+
head_mask=None,
|
214 |
+
encoder_hidden_states=None,
|
215 |
+
encoder_attention_mask=None,
|
216 |
+
):
|
217 |
+
mixed_query_layer = self.query(hidden_states)
|
218 |
+
|
219 |
+
# If this is instantiated as a cross-attention module, the keys
|
220 |
+
# and values come from an encoder; the attention mask needs to be
|
221 |
+
# such that the encoder's padding tokens are not attended to.
|
222 |
+
if encoder_hidden_states is not None:
|
223 |
+
mixed_key_layer = self.key(encoder_hidden_states)
|
224 |
+
mixed_value_layer = self.value(encoder_hidden_states)
|
225 |
+
attention_mask = encoder_attention_mask
|
226 |
+
else:
|
227 |
+
mixed_key_layer = self.key(hidden_states)
|
228 |
+
mixed_value_layer = self.value(hidden_states)
|
229 |
+
|
230 |
+
query_layer = self.transpose_for_scores(mixed_query_layer)
|
231 |
+
key_layer = self.transpose_for_scores(mixed_key_layer)
|
232 |
+
value_layer = self.transpose_for_scores(mixed_value_layer)
|
233 |
+
|
234 |
+
# Take the dot product between "query" and "key" to get the raw attention scores.
|
235 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
236 |
+
attention_scores = attention_scores / math.sqrt(self.attention_head_size)
|
237 |
+
if attention_mask is not None:
|
238 |
+
# Apply the attention mask (precomputed for all layers in the BertModel forward() function)
|
239 |
+
attention_scores = attention_scores + attention_mask
|
240 |
+
|
241 |
+
# Normalize the attention scores to probabilities.
|
242 |
+
attention_probs = nn.Softmax(dim=-1)(attention_scores)
|
243 |
+
|
244 |
+
# This is actually dropping out entire tokens to attend to, which might
|
245 |
+
# seem a bit unusual, but is taken from the original Transformer paper.
|
246 |
+
attention_probs = self.dropout(attention_probs)
|
247 |
+
|
248 |
+
# Mask heads if we want to
|
249 |
+
if head_mask is not None:
|
250 |
+
attention_probs = attention_probs * head_mask
|
251 |
+
|
252 |
+
context_layer = torch.matmul(attention_probs, value_layer)
|
253 |
+
|
254 |
+
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
255 |
+
new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
|
256 |
+
context_layer = context_layer.view(*new_context_layer_shape)
|
257 |
+
|
258 |
+
outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,)
|
259 |
+
return outputs
|
260 |
+
|
261 |
+
|
262 |
+
class BertSelfOutput(nn.Module):
|
263 |
+
def __init__(self, config):
|
264 |
+
super().__init__()
|
265 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
266 |
+
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
267 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
268 |
+
|
269 |
+
def forward(self, hidden_states, input_tensor):
|
270 |
+
hidden_states = self.dense(hidden_states)
|
271 |
+
hidden_states = self.dropout(hidden_states)
|
272 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
273 |
+
return hidden_states
|
274 |
+
|
275 |
+
|
276 |
+
class BertAttention(nn.Module):
|
277 |
+
def __init__(self, config):
|
278 |
+
super().__init__()
|
279 |
+
self.self = BertSelfAttention(config)
|
280 |
+
self.output = BertSelfOutput(config)
|
281 |
+
self.pruned_heads = set()
|
282 |
+
|
283 |
+
def prune_heads(self, heads):
|
284 |
+
if len(heads) == 0:
|
285 |
+
return
|
286 |
+
mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size)
|
287 |
+
heads = set(heads) - self.pruned_heads # Convert to set and remove already pruned heads
|
288 |
+
for head in heads:
|
289 |
+
# Compute how many pruned heads are before the head and move the index accordingly
|
290 |
+
head = head - sum(1 if h < head else 0 for h in self.pruned_heads)
|
291 |
+
mask[head] = 0
|
292 |
+
mask = mask.view(-1).contiguous().eq(1)
|
293 |
+
index = torch.arange(len(mask))[mask].long()
|
294 |
+
|
295 |
+
# Prune linear layers
|
296 |
+
self.self.query = prune_linear_layer(self.self.query, index)
|
297 |
+
self.self.key = prune_linear_layer(self.self.key, index)
|
298 |
+
self.self.value = prune_linear_layer(self.self.value, index)
|
299 |
+
self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
|
300 |
+
|
301 |
+
# Update hyper params and store pruned heads
|
302 |
+
self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
|
303 |
+
self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
|
304 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
305 |
+
|
306 |
+
def forward(
|
307 |
+
self,
|
308 |
+
hidden_states,
|
309 |
+
attention_mask=None,
|
310 |
+
head_mask=None,
|
311 |
+
encoder_hidden_states=None,
|
312 |
+
encoder_attention_mask=None,
|
313 |
+
):
|
314 |
+
self_outputs = self.self(
|
315 |
+
hidden_states, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask
|
316 |
+
)
|
317 |
+
attention_output = self.output(self_outputs[0], hidden_states)
|
318 |
+
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
319 |
+
return outputs
|
320 |
+
|
321 |
+
|
322 |
+
class BertIntermediate(nn.Module):
|
323 |
+
def __init__(self, config):
|
324 |
+
super().__init__()
|
325 |
+
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
326 |
+
if isinstance(config.hidden_act, str):
|
327 |
+
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
328 |
+
else:
|
329 |
+
self.intermediate_act_fn = config.hidden_act
|
330 |
+
|
331 |
+
def forward(self, hidden_states):
|
332 |
+
hidden_states = self.dense(hidden_states)
|
333 |
+
hidden_states = self.intermediate_act_fn(hidden_states)
|
334 |
+
return hidden_states
|
335 |
+
|
336 |
+
|
337 |
+
class BertOutput(nn.Module):
|
338 |
+
def __init__(self, config):
|
339 |
+
super().__init__()
|
340 |
+
self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
341 |
+
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
342 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
343 |
+
|
344 |
+
def forward(self, hidden_states, input_tensor):
|
345 |
+
hidden_states = self.dense(hidden_states)
|
346 |
+
hidden_states = self.dropout(hidden_states)
|
347 |
+
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
348 |
+
return hidden_states
|
349 |
+
|
350 |
+
|
351 |
+
class BertLayer(nn.Module):
|
352 |
+
def __init__(self, config):
|
353 |
+
super().__init__()
|
354 |
+
self.attention = BertAttention(config)
|
355 |
+
self.is_decoder = config.is_decoder
|
356 |
+
if self.is_decoder:
|
357 |
+
self.crossattention = BertAttention(config)
|
358 |
+
self.intermediate = BertIntermediate(config)
|
359 |
+
self.output = BertOutput(config)
|
360 |
+
|
361 |
+
def forward(
|
362 |
+
self,
|
363 |
+
hidden_states,
|
364 |
+
attention_mask=None,
|
365 |
+
head_mask=None,
|
366 |
+
encoder_hidden_states=None,
|
367 |
+
encoder_attention_mask=None,
|
368 |
+
):
|
369 |
+
self_attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
|
370 |
+
attention_output = self_attention_outputs[0]
|
371 |
+
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
372 |
+
|
373 |
+
if self.is_decoder and encoder_hidden_states is not None:
|
374 |
+
cross_attention_outputs = self.crossattention(
|
375 |
+
attention_output, attention_mask, head_mask, encoder_hidden_states, encoder_attention_mask
|
376 |
+
)
|
377 |
+
attention_output = cross_attention_outputs[0]
|
378 |
+
outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights
|
379 |
+
|
380 |
+
intermediate_output = self.intermediate(attention_output)
|
381 |
+
layer_output = self.output(intermediate_output, attention_output)
|
382 |
+
outputs = (layer_output,) + outputs
|
383 |
+
return outputs
|
384 |
+
|
385 |
+
|
386 |
+
class BertEncoder(nn.Module):
|
387 |
+
def __init__(self, config):
|
388 |
+
super().__init__()
|
389 |
+
self.config = config
|
390 |
+
self.output_attentions = config.output_attentions
|
391 |
+
self.output_hidden_states = config.output_hidden_states
|
392 |
+
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
|
393 |
+
|
394 |
+
def forward(
|
395 |
+
self,
|
396 |
+
hidden_states,
|
397 |
+
attention_mask=None,
|
398 |
+
head_mask=None,
|
399 |
+
encoder_hidden_states=None,
|
400 |
+
encoder_attention_mask=None,
|
401 |
+
):
|
402 |
+
all_hidden_states = ()
|
403 |
+
all_attentions = ()
|
404 |
+
for i, layer_module in enumerate(self.layer):
|
405 |
+
if self.output_hidden_states:
|
406 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
407 |
+
|
408 |
+
layer_outputs = layer_module(
|
409 |
+
hidden_states, attention_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask
|
410 |
+
)
|
411 |
+
hidden_states = layer_outputs[0]
|
412 |
+
|
413 |
+
if self.output_attentions:
|
414 |
+
all_attentions = all_attentions + (layer_outputs[1],)
|
415 |
+
|
416 |
+
# Add last layer
|
417 |
+
if self.output_hidden_states:
|
418 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
419 |
+
|
420 |
+
outputs = (hidden_states,)
|
421 |
+
if self.output_hidden_states:
|
422 |
+
outputs = outputs + (all_hidden_states,)
|
423 |
+
if self.output_attentions:
|
424 |
+
outputs = outputs + (all_attentions,)
|
425 |
+
return outputs # last-layer hidden state, (all hidden states), (all attentions)
|
426 |
+
|
427 |
+
def adaptive_forward(self, hidden_states, current_layer, attention_mask=None, head_mask=None):
|
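# adaptive_forward runs a single encoder layer (index `current_layer`) instead of the full stack,
# so BertModel.forward can attach an internal classifier after each layer and decide on the fly
# whether to exit early or to continue with the next layer.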
428 |
+
layer_outputs = self.layer[current_layer](hidden_states, attention_mask, head_mask[current_layer])
|
429 |
+
|
430 |
+
hidden_states = layer_outputs[0]
|
431 |
+
|
432 |
+
return hidden_states
|
433 |
+
|
434 |
+
|
435 |
+
class BertPooler(nn.Module):
|
436 |
+
def __init__(self, config):
|
437 |
+
super().__init__()
|
438 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
439 |
+
self.activation = nn.Tanh()
|
440 |
+
|
441 |
+
def forward(self, hidden_states):
|
442 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
443 |
+
# to the first token.
|
444 |
+
first_token_tensor = hidden_states[:, 0]
|
445 |
+
pooled_output = self.dense(first_token_tensor)
|
446 |
+
pooled_output = self.activation(pooled_output)
|
447 |
+
return pooled_output
|
448 |
+
|
449 |
+
|
450 |
+
class BertPredictionHeadTransform(nn.Module):
|
451 |
+
def __init__(self, config):
|
452 |
+
super().__init__()
|
453 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
454 |
+
if isinstance(config.hidden_act, str):
|
455 |
+
self.transform_act_fn = ACT2FN[config.hidden_act]
|
456 |
+
else:
|
457 |
+
self.transform_act_fn = config.hidden_act
|
458 |
+
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
459 |
+
|
460 |
+
def forward(self, hidden_states):
|
461 |
+
hidden_states = self.dense(hidden_states)
|
462 |
+
hidden_states = self.transform_act_fn(hidden_states)
|
463 |
+
hidden_states = self.LayerNorm(hidden_states)
|
464 |
+
return hidden_states
|
465 |
+
|
466 |
+
|
467 |
+
class BertLMPredictionHead(nn.Module):
|
468 |
+
def __init__(self, config):
|
469 |
+
super().__init__()
|
470 |
+
self.transform = BertPredictionHeadTransform(config)
|
471 |
+
|
472 |
+
# The output weights are the same as the input embeddings, but there is
|
473 |
+
# an output-only bias for each token.
|
474 |
+
self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
|
475 |
+
|
476 |
+
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
|
477 |
+
|
478 |
+
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
|
479 |
+
self.decoder.bias = self.bias
|
480 |
+
|
481 |
+
def forward(self, hidden_states):
|
482 |
+
hidden_states = self.transform(hidden_states)
|
483 |
+
hidden_states = self.decoder(hidden_states)
|
484 |
+
return hidden_states
|
485 |
+
|
486 |
+
|
487 |
+
class BertOnlyMLMHead(nn.Module):
|
488 |
+
def __init__(self, config):
|
489 |
+
super().__init__()
|
490 |
+
self.predictions = BertLMPredictionHead(config)
|
491 |
+
|
492 |
+
def forward(self, sequence_output):
|
493 |
+
prediction_scores = self.predictions(sequence_output)
|
494 |
+
return prediction_scores
|
495 |
+
|
496 |
+
|
497 |
+
class BertOnlyNSPHead(nn.Module):
|
498 |
+
def __init__(self, config):
|
499 |
+
super().__init__()
|
500 |
+
self.seq_relationship = nn.Linear(config.hidden_size, 2)
|
501 |
+
|
502 |
+
def forward(self, pooled_output):
|
503 |
+
seq_relationship_score = self.seq_relationship(pooled_output)
|
504 |
+
return seq_relationship_score
|
505 |
+
|
506 |
+
|
507 |
+
class BertPreTrainingHeads(nn.Module):
|
508 |
+
def __init__(self, config):
|
509 |
+
super().__init__()
|
510 |
+
self.predictions = BertLMPredictionHead(config)
|
511 |
+
self.seq_relationship = nn.Linear(config.hidden_size, 2)
|
512 |
+
|
513 |
+
def forward(self, sequence_output, pooled_output):
|
514 |
+
prediction_scores = self.predictions(sequence_output)
|
515 |
+
seq_relationship_score = self.seq_relationship(pooled_output)
|
516 |
+
return prediction_scores, seq_relationship_score
|
517 |
+
|
518 |
+
|
519 |
+
class BertPreTrainedModel(PreTrainedModel):
|
520 |
+
""" An abstract class to handle weights initialization and
|
521 |
+
a simple interface for downloading and loading pretrained models.
|
522 |
+
"""
|
523 |
+
|
524 |
+
config_class = BertConfig
|
525 |
+
pretrained_model_archive_map = BERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
526 |
+
load_tf_weights = load_tf_weights_in_bert
|
527 |
+
base_model_prefix = "bert"
|
528 |
+
|
529 |
+
def _init_weights(self, module):
|
530 |
+
""" Initialize the weights """
|
531 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
532 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
533 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
534 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
535 |
+
elif isinstance(module, BertLayerNorm):
|
536 |
+
module.bias.data.zero_()
|
537 |
+
module.weight.data.fill_(1.0)
|
538 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
539 |
+
module.bias.data.zero_()
|
540 |
+
|
541 |
+
|
542 |
+
BERT_START_DOCSTRING = r"""
|
543 |
+
This model is a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`_ sub-class.
|
544 |
+
Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general
|
545 |
+
usage and behavior.
|
546 |
+
|
547 |
+
Parameters:
|
548 |
+
config (:class:`~transformers.BertConfig`): Model configuration class with all the parameters of the model.
|
549 |
+
Initializing with a config file does not load the weights associated with the model, only the configuration.
|
550 |
+
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
|
551 |
+
"""
|
552 |
+
|
553 |
+
BERT_INPUTS_DOCSTRING = r"""
|
554 |
+
Args:
|
555 |
+
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
|
556 |
+
Indices of input sequence tokens in the vocabulary.
|
557 |
+
|
558 |
+
Indices can be obtained using :class:`transformers.BertTokenizer`.
|
559 |
+
See :func:`transformers.PreTrainedTokenizer.encode` and
|
560 |
+
:func:`transformers.PreTrainedTokenizer.encode_plus` for details.
|
561 |
+
|
562 |
+
`What are input IDs? <../glossary.html#input-ids>`__
|
563 |
+
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
564 |
+
Mask to avoid performing attention on padding token indices.
|
565 |
+
Mask values selected in ``[0, 1]``:
|
566 |
+
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
567 |
+
|
568 |
+
`What are attention masks? <../glossary.html#attention-mask>`__
|
569 |
+
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
570 |
+
Segment token indices to indicate first and second portions of the inputs.
|
571 |
+
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
|
572 |
+
corresponds to a `sentence B` token
|
573 |
+
|
574 |
+
`What are token type IDs? <../glossary.html#token-type-ids>`_
|
575 |
+
position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
576 |
+
Indices of positions of each input sequence tokens in the position embeddings.
|
577 |
+
Selected in the range ``[0, config.max_position_embeddings - 1]``.
|
578 |
+
|
579 |
+
`What are position IDs? <../glossary.html#position-ids>`_
|
580 |
+
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`, defaults to :obj:`None`):
|
581 |
+
Mask to nullify selected heads of the self-attention modules.
|
582 |
+
Mask values selected in ``[0, 1]``:
|
583 |
+
:obj:`1` indicates the head is **not masked**, :obj:`0` indicates the head is **masked**.
|
584 |
+
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
|
585 |
+
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
586 |
+
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
|
587 |
+
than the model's internal embedding lookup matrix.
|
588 |
+
encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
|
589 |
+
Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
|
590 |
+
if the model is configured as a decoder.
|
591 |
+
encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
592 |
+
Mask to avoid performing attention on the padding token indices of the encoder input. This mask
|
593 |
+
is used in the cross-attention if the model is configured as a decoder.
|
594 |
+
Mask values selected in ``[0, 1]``:
|
595 |
+
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
596 |
+
"""
|
597 |
+
|
598 |
+
|
599 |
+
@add_start_docstrings(
|
600 |
+
"The bare Bert Model transformer outputting raw hidden-states without any specific head on top.",
|
601 |
+
BERT_START_DOCSTRING,
|
602 |
+
)
|
603 |
+
class BertModel(BertPreTrainedModel):
|
604 |
+
"""
|
605 |
+
|
606 |
+
The model can behave as an encoder (with only self-attention) as well
|
607 |
+
as a decoder, in which case a layer of cross-attention is added between
|
608 |
+
the self-attention layers, following the architecture described in `Attention is all you need`_ by Ashish Vaswani,
|
609 |
+
Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
|
610 |
+
|
611 |
+
To behave as a decoder the model needs to be initialized with the
|
612 |
+
:obj:`is_decoder` argument of the configuration set to :obj:`True`; an
|
613 |
+
:obj:`encoder_hidden_states` is expected as an input to the forward pass.
|
614 |
+
|
615 |
+
.. _`Attention is all you need`:
|
616 |
+
https://arxiv.org/abs/1706.03762
|
617 |
+
|
618 |
+
"""
|
619 |
+
|
620 |
+
def __init__(self, config):
|
621 |
+
super().__init__(config)
|
622 |
+
self.config = config
|
623 |
+
|
624 |
+
self.embeddings = BertEmbeddings(config)
|
625 |
+
self.encoder = BertEncoder(config)
|
626 |
+
self.pooler = BertPooler(config)
|
627 |
+
|
628 |
+
self.init_weights()
|
629 |
+
# hyper-param for patience-based adaptive inference
|
630 |
+
self.patience = 0
|
631 |
+
# threshold for confidence-based adaptive inference
|
632 |
+
self.confidence_threshold = 0.8
|
633 |
+
# mode for fast inference: 'patience' (patience-based) / 'confi' (confidence-based) / 'all' (all classifiers) / 'last' (last classifier only)
|
634 |
+
self.mode = 'patience' # [patience/confi/all/last]
|
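# Inference modes used by forward():
#   'patience' - PABEE-style early exit: stop once `patience` consecutive internal classifiers agree
#   'confi'    - confidence-based early exit: stop once the max softmax probability clears `confidence_threshold`
#   'all'      - run every layer and return the logits of every internal classifier (also logs per-layer times)
#   'last'     - run every layer but return only the final classifier's logits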
635 |
+
|
636 |
+
self.inference_instances_num = 0
|
637 |
+
self.inference_layers_num = 0
|
638 |
+
|
639 |
+
# exits count log
|
640 |
+
self.exits_count_list = [0] * self.config.num_hidden_layers
|
641 |
+
# exits time log
|
642 |
+
self.exits_time_list = [[] for _ in range(self.config.num_hidden_layers)]
|
643 |
+
|
644 |
+
self.regression_threshold = 0
|
645 |
+
|
646 |
+
def set_regression_threshold(self, threshold):
|
647 |
+
self.regression_threshold = threshold
|
648 |
+
|
649 |
+
def set_mode(self, patience='patience'):
|
650 |
+
self.mode = patience # test-time inference mode: 'patience', 'confi', 'all', or 'last'
|
651 |
+
|
652 |
+
def set_patience(self, patience):
|
653 |
+
self.patience = patience
|
654 |
+
|
655 |
+
def set_confi_threshold(self, confidence_threshold):
|
656 |
+
self.confidence_threshold = confidence_threshold
|
657 |
+
|
658 |
+
def reset_stats(self):
|
659 |
+
self.inference_instances_num = 0
|
660 |
+
self.inference_layers_num = 0
|
661 |
+
self.exits_count_list = [0] * self.config.num_hidden_layers
|
662 |
+
self.exits_time_list = [[] for _ in range(self.config.num_hidden_layers)]
|
663 |
+
|
664 |
+
def log_stats(self):
|
665 |
+
avg_inf_layers = self.inference_layers_num / self.inference_instances_num
|
666 |
+
message = f'*** Patience = {self.patience} Avg. Inference Layers = {avg_inf_layers:.2f} Speed Up = {1 - avg_inf_layers / self.config.num_hidden_layers:.2f} ***'
|
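# "Speed Up" here is the average fraction of layers skipped relative to the full num_hidden_layers stack.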
667 |
+
print(message)
|
668 |
+
|
669 |
+
def get_input_embeddings(self):
|
670 |
+
return self.embeddings.word_embeddings
|
671 |
+
|
672 |
+
def set_input_embeddings(self, value):
|
673 |
+
self.embeddings.word_embeddings = value
|
674 |
+
|
675 |
+
def _prune_heads(self, heads_to_prune):
|
676 |
+
""" Prunes heads of the model.
|
677 |
+
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
|
678 |
+
See base class PreTrainedModel
|
679 |
+
"""
|
680 |
+
for layer, heads in heads_to_prune.items():
|
681 |
+
self.encoder.layer[layer].attention.prune_heads(heads)
|
682 |
+
|
683 |
+
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING)
|
684 |
+
def forward(
|
685 |
+
self,
|
686 |
+
input_ids=None,
|
687 |
+
attention_mask=None,
|
688 |
+
token_type_ids=None,
|
689 |
+
position_ids=None,
|
690 |
+
head_mask=None,
|
691 |
+
inputs_embeds=None,
|
692 |
+
encoder_hidden_states=None,
|
693 |
+
encoder_attention_mask=None,
|
694 |
+
output_dropout=None,
|
695 |
+
output_layers=None,
|
696 |
+
regression=False
|
697 |
+
):
|
698 |
+
r"""
|
699 |
+
Return:
|
700 |
+
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
701 |
+
last_hidden_state (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`):
|
702 |
+
Sequence of hidden-states at the output of the last layer of the model.
|
703 |
+
pooler_output (:obj:`torch.FloatTensor`: of shape :obj:`(batch_size, hidden_size)`):
|
704 |
+
Last layer hidden-state of the first token of the sequence (classification token)
|
705 |
+
further processed by a Linear layer and a Tanh activation function. The Linear
|
706 |
+
layer weights are trained from the next sentence prediction (classification)
|
707 |
+
objective during pre-training.
|
708 |
+
|
709 |
+
This output is usually *not* a good summary
|
710 |
+
of the semantic content of the input; you're often better off averaging or pooling
|
711 |
+
the sequence of hidden-states for the whole input sequence.
|
712 |
+
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
|
713 |
+
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
714 |
+
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
715 |
+
|
716 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
717 |
+
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
|
718 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
719 |
+
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
720 |
+
|
721 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
722 |
+
heads.
|
723 |
+
|
724 |
+
Examples::
|
725 |
+
|
726 |
+
from transformers import BertModel, BertTokenizer
|
727 |
+
import torch
|
728 |
+
|
729 |
+
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
730 |
+
model = BertModel.from_pretrained('bert-base-uncased')
|
731 |
+
|
732 |
+
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
|
733 |
+
outputs = model(input_ids)
|
734 |
+
|
735 |
+
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
736 |
+
|
737 |
+
"""
|
738 |
+
|
739 |
+
if input_ids is not None and inputs_embeds is not None:
|
740 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
741 |
+
elif input_ids is not None:
|
742 |
+
input_shape = input_ids.size()
|
743 |
+
elif inputs_embeds is not None:
|
744 |
+
input_shape = inputs_embeds.size()[:-1]
|
745 |
+
else:
|
746 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
747 |
+
|
748 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
749 |
+
|
750 |
+
if attention_mask is None:
|
751 |
+
attention_mask = torch.ones(input_shape, device=device)
|
752 |
+
if token_type_ids is None:
|
753 |
+
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
754 |
+
|
755 |
+
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
756 |
+
# ourselves in which case we just need to make it broadcastable to all heads.
|
757 |
+
if attention_mask.dim() == 3:
|
758 |
+
extended_attention_mask = attention_mask[:, None, :, :]
|
759 |
+
elif attention_mask.dim() == 2:
|
760 |
+
# Provided a padding mask of dimensions [batch_size, seq_length]
|
761 |
+
# - if the model is a decoder, apply a causal mask in addition to the padding mask
|
762 |
+
# - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
763 |
+
if self.config.is_decoder:
|
764 |
+
batch_size, seq_length = input_shape
|
765 |
+
seq_ids = torch.arange(seq_length, device=device)
|
766 |
+
causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
|
767 |
+
causal_mask = causal_mask.to(
|
768 |
+
attention_mask.dtype
|
769 |
+
) # causal and attention masks must have same type with pytorch version < 1.3
|
770 |
+
extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
|
771 |
+
else:
|
772 |
+
extended_attention_mask = attention_mask[:, None, None, :]
|
773 |
+
else:
|
774 |
+
raise ValueError(
|
775 |
+
"Wrong shape for input_ids (shape {}) or attention_mask (shape {})".format(
|
776 |
+
input_shape, attention_mask.shape
|
777 |
+
)
|
778 |
+
)
|
779 |
+
|
780 |
+
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
781 |
+
# masked positions, this operation will create a tensor which is 0.0 for
|
782 |
+
# positions we want to attend and -10000.0 for masked positions.
|
783 |
+
# Since we are adding it to the raw scores before the softmax, this is
|
784 |
+
# effectively the same as removing these entirely.
|
785 |
+
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
786 |
+
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
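# e.g. a padding mask [1, 1, 0] becomes [0.0, 0.0, -10000.0]; added to the raw attention scores
# before the softmax, the masked position's probability is driven to (effectively) zero.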
787 |
+
|
788 |
+
# If a 2D or 3D attention mask is provided for the cross-attention
|
789 |
+
# we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
790 |
+
if self.config.is_decoder and encoder_hidden_states is not None:
|
791 |
+
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
792 |
+
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
793 |
+
if encoder_attention_mask is None:
|
794 |
+
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
795 |
+
|
796 |
+
if encoder_attention_mask.dim() == 3:
|
797 |
+
encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
|
798 |
+
elif encoder_attention_mask.dim() == 2:
|
799 |
+
encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
|
800 |
+
else:
|
801 |
+
raise ValueError(
|
802 |
+
"Wrong shape for encoder_hidden_shape (shape {}) or encoder_attention_mask (shape {})".format(
|
803 |
+
encoder_hidden_shape, encoder_attention_mask.shape
|
804 |
+
)
|
805 |
+
)
|
806 |
+
|
807 |
+
encoder_extended_attention_mask = encoder_extended_attention_mask.to(
|
808 |
+
dtype=next(self.parameters()).dtype
|
809 |
+
) # fp16 compatibility
|
810 |
+
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
|
811 |
+
else:
|
812 |
+
encoder_extended_attention_mask = None
|
813 |
+
|
814 |
+
# Prepare head mask if needed
|
815 |
+
# 1.0 in head_mask indicate we keep the head
|
816 |
+
# attention_probs has shape bsz x n_heads x N x N
|
817 |
+
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
818 |
+
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
819 |
+
if head_mask is not None:
|
820 |
+
if head_mask.dim() == 1:
|
821 |
+
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
822 |
+
head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
|
823 |
+
elif head_mask.dim() == 2:
|
824 |
+
head_mask = (
|
825 |
+
head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
|
826 |
+
) # We can specify head_mask for each layer
|
827 |
+
head_mask = head_mask.to(
|
828 |
+
dtype=next(self.parameters()).dtype
|
829 |
+
) # switch to float if needed + fp16 compatibility
|
830 |
+
else:
|
831 |
+
head_mask = [None] * self.config.num_hidden_layers
|
832 |
+
|
833 |
+
embedding_output = self.embeddings(
|
834 |
+
input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
|
835 |
+
)
|
836 |
+
encoder_outputs = embedding_output
|
837 |
+
|
838 |
+
if self.training:
|
839 |
+
res = []
|
840 |
+
for i in range(self.config.num_hidden_layers):
|
841 |
+
encoder_outputs = self.encoder.adaptive_forward(encoder_outputs,
|
842 |
+
current_layer=i,
|
843 |
+
attention_mask=extended_attention_mask,
|
844 |
+
head_mask=head_mask
|
845 |
+
)
|
846 |
+
|
847 |
+
pooled_output = self.pooler(encoder_outputs)
|
848 |
+
logits = output_layers[i](output_dropout(pooled_output))
|
849 |
+
res.append(logits)
|
850 |
+
elif self.mode == 'last': # run all layers but use only the final classifier's output
|
851 |
+
encoder_outputs = self.encoder(encoder_outputs,
|
852 |
+
extended_attention_mask,
|
853 |
+
head_mask=head_mask)
|
854 |
+
pooled_output = self.pooler(encoder_outputs[0])
|
855 |
+
res = [output_layers[self.config.num_hidden_layers - 1](pooled_output)]
|
856 |
+
elif self.mode == 'all': # run all layers and collect every internal classifier's logits (with per-layer timing)
|
857 |
+
tic = datetime.now()
|
858 |
+
res = []
|
859 |
+
for i in range(self.config.num_hidden_layers):
|
860 |
+
encoder_outputs = self.encoder.adaptive_forward(encoder_outputs,
|
861 |
+
current_layer=i,
|
862 |
+
attention_mask=extended_attention_mask,
|
863 |
+
head_mask=head_mask
|
864 |
+
)
|
865 |
+
|
866 |
+
pooled_output = self.pooler(encoder_outputs)
|
867 |
+
logits = output_layers[i](output_dropout(pooled_output))
|
868 |
+
toc = datetime.now()
|
869 |
+
exit_time = (toc - tic).total_seconds()
|
870 |
+
res.append(logits)
|
871 |
+
self.exits_time_list[i].append(exit_time)
|
872 |
+
elif self.mode == 'patience':
|
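# Patience-based early exit (PABEE): each layer's internal classifier makes a prediction;
# `patient_counter` tracks how many consecutive layers agreed with the previous prediction
# (or, for regression, stayed within `regression_threshold` of it). Inference stops as soon
# as the counter reaches `self.patience`; otherwise the loop falls through to the last layer.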
873 |
+
if self.patience <= 0:
|
874 |
+
raise ValueError("Patience must be greater than 0")
|
875 |
+
patient_counter = 0
|
876 |
+
patient_result = None
|
877 |
+
calculated_layer_num = 0
|
878 |
+
tic = datetime.now()
|
879 |
+
for i in range(self.config.num_hidden_layers):
|
880 |
+
calculated_layer_num += 1
|
881 |
+
encoder_outputs = self.encoder.adaptive_forward(encoder_outputs,
|
882 |
+
current_layer=i,
|
883 |
+
attention_mask=extended_attention_mask,
|
884 |
+
head_mask=head_mask
|
885 |
+
)
|
886 |
+
|
887 |
+
pooled_output = self.pooler(encoder_outputs)
|
888 |
+
logits = output_layers[i](pooled_output)
|
889 |
+
if regression:
|
890 |
+
labels = logits.detach()
|
891 |
+
if patient_result is not None:
|
892 |
+
patient_labels = patient_result.detach()
|
893 |
+
if (patient_result is not None) and torch.abs(patient_result - labels) < self.regression_threshold:
|
894 |
+
patient_counter += 1
|
895 |
+
else:
|
896 |
+
patient_counter = 0
|
897 |
+
else:
|
898 |
+
labels = logits.detach().argmax(dim=1)
|
899 |
+
if patient_result is not None:
|
900 |
+
patient_labels = patient_result.detach().argmax(dim=1)
|
901 |
+
if (patient_result is not None) and torch.all(labels.eq(patient_labels)):
|
902 |
+
patient_counter += 1
|
903 |
+
else:
|
904 |
+
patient_counter = 0
|
905 |
+
|
906 |
+
patient_result = logits
|
907 |
+
if patient_counter == self.patience:
|
908 |
+
break
|
909 |
+
toc = datetime.now()
|
910 |
+
self.exit_time = (toc - tic).total_seconds()
|
911 |
+
res = [patient_result]
|
912 |
+
self.inference_layers_num += calculated_layer_num
|
913 |
+
self.inference_instances_num += 1
|
914 |
+
self.current_exit_layer = calculated_layer_num
|
915 |
+
# LOG EXIT POINTS COUNTS
|
916 |
+
self.exits_count_list[calculated_layer_num-1] += 1
|
917 |
+
elif self.mode == 'confi':
|
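# Confidence-based early exit: after each layer, take the max softmax probability of that
# layer's classifier and stop at the first layer where every example in the batch exceeds
# `confidence_threshold`; otherwise fall through to the last layer.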
918 |
+
if self.confidence_threshold < 0 or self.confidence_threshold > 1:
|
919 |
+
raise ValueError('Confidence threshold must be set within the range [0, 1]')
|
920 |
+
calculated_layer_num = 0
|
921 |
+
tic = datetime.now()
|
922 |
+
for i in range(self.config.num_hidden_layers):
|
923 |
+
calculated_layer_num += 1
|
924 |
+
encoder_outputs = self.encoder.adaptive_forward(encoder_outputs,
|
925 |
+
current_layer=i,
|
926 |
+
attention_mask=extended_attention_mask,
|
927 |
+
head_mask=head_mask
|
928 |
+
)
|
929 |
+
|
930 |
+
pooled_output = self.pooler(encoder_outputs)
|
931 |
+
logits = output_layers[i](pooled_output)
|
932 |
+
labels = logits.detach().argmax(dim=1)
|
933 |
+
logits_max, _ = logits.detach().softmax(dim=1).max(dim=1)
|
934 |
+
|
935 |
+
confi_result = logits
|
936 |
+
if torch.all(logits_max.gt(self.confidence_threshold)):
|
937 |
+
break
|
938 |
+
toc = datetime.now()
|
939 |
+
self.exit_time = (toc - tic).total_seconds()
|
940 |
+
res = [confi_result]
|
941 |
+
self.inference_layers_num += calculated_layer_num
|
942 |
+
self.inference_instances_num += 1
|
943 |
+
self.current_exit_layer = calculated_layer_num
|
944 |
+
# LOG EXIT POINTS COUNTS
|
945 |
+
self.exits_count_list[calculated_layer_num-1] += 1
|
946 |
+
return res
|
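# A minimal usage sketch (illustrative only, not part of this module): assuming a downstream
# classification wrapper that exposes this BertModel as `.bert` and supplies its per-layer
# classifiers via `output_layers`, early exit is configured like this:
#
#   model.bert.set_mode('patience')   # or 'confi' / 'all' / 'last'
#   model.bert.set_patience(3)        # exit once 3 consecutive internal classifiers agree
#   model.bert.reset_stats()
#   outputs = model(input_ids, attention_mask=attention_mask)
#   model.bert.log_stats()            # prints avg. inference layers and the resulting speed-up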
947 |
+
|
948 |
+
|
949 |
+
@add_start_docstrings(
|
950 |
+
"""Bert Model with two heads on top as done during the pre-training: a `masked language modeling` head and
|
951 |
+
a `next sentence prediction (classification)` head. """,
|
952 |
+
BERT_START_DOCSTRING,
|
953 |
+
)
|
954 |
+
class BertForPreTraining(BertPreTrainedModel):
|
955 |
+
def __init__(self, config):
|
956 |
+
super().__init__(config)
|
957 |
+
|
958 |
+
self.bert = BertModel(config)
|
959 |
+
self.cls = BertPreTrainingHeads(config)
|
960 |
+
|
961 |
+
self.init_weights()
|
962 |
+
|
963 |
+
def get_output_embeddings(self):
|
964 |
+
return self.cls.predictions.decoder
|
965 |
+
|
966 |
+
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING)
|
967 |
+
def forward(
|
968 |
+
self,
|
969 |
+
input_ids=None,
|
970 |
+
attention_mask=None,
|
971 |
+
token_type_ids=None,
|
972 |
+
position_ids=None,
|
973 |
+
head_mask=None,
|
974 |
+
inputs_embeds=None,
|
975 |
+
masked_lm_labels=None,
|
976 |
+
next_sentence_label=None,
|
977 |
+
):
|
978 |
+
r"""
|
979 |
+
masked_lm_labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`, defaults to :obj:`None`):
|
980 |
+
Labels for computing the masked language modeling loss.
|
981 |
+
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
|
982 |
+
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
|
983 |
+
in ``[0, ..., config.vocab_size]``
|
984 |
+
next_sentence_label (``torch.LongTensor`` of shape ``(batch_size,)``, `optional`, defaults to :obj:`None`):
|
985 |
+
Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see :obj:`input_ids` docstring)
|
986 |
+
Indices should be in ``[0, 1]``.
|
987 |
+
``0`` indicates sequence B is a continuation of sequence A,
|
988 |
+
``1`` indicates sequence B is a random sequence.
|
989 |
+
|
990 |
+
Returns:
|
991 |
+
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
992 |
+
loss (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
993 |
+
Total loss as the sum of the masked language modeling loss and the next sequence prediction (classification) loss.
|
994 |
+
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
|
995 |
+
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
996 |
+
seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, 2)`):
|
997 |
+
Prediction scores of the next sequence prediction (classification) head (scores of True/False
|
998 |
+
continuation before SoftMax).
|
999 |
+
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
|
1000 |
+
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
1001 |
+
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
1002 |
+
|
1003 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
1004 |
+
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
|
1005 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
1006 |
+
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
1007 |
+
|
1008 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
1009 |
+
heads.
|
1010 |
+
|
1011 |
+
|
1012 |
+
Examples::
|
1013 |
+
|
1014 |
+
from transformers import BertTokenizer, BertForPreTraining
|
1015 |
+
import torch
|
1016 |
+
|
1017 |
+
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
1018 |
+
model = BertForPreTraining.from_pretrained('bert-base-uncased')
|
1019 |
+
|
1020 |
+
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
|
1021 |
+
outputs = model(input_ids)
|
1022 |
+
|
1023 |
+
prediction_scores, seq_relationship_scores = outputs[:2]
|
1024 |
+
|
1025 |
+
"""
|
1026 |
+
|
1027 |
+
outputs = self.bert(
|
1028 |
+
input_ids,
|
1029 |
+
attention_mask=attention_mask,
|
1030 |
+
token_type_ids=token_type_ids,
|
1031 |
+
position_ids=position_ids,
|
1032 |
+
head_mask=head_mask,
|
1033 |
+
inputs_embeds=inputs_embeds,
|
1034 |
+
)
|
1035 |
+
|
1036 |
+
sequence_output, pooled_output = outputs[:2]
|
1037 |
+
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
|
1038 |
+
|
1039 |
+
outputs = (prediction_scores, seq_relationship_score,) + outputs[
|
1040 |
+
2:
|
1041 |
+
] # add hidden states and attention if they are here
|
1042 |
+
|
1043 |
+
if masked_lm_labels is not None and next_sentence_label is not None:
|
1044 |
+
loss_fct = CrossEntropyLoss()
|
1045 |
+
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
|
1046 |
+
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
|
1047 |
+
total_loss = masked_lm_loss + next_sentence_loss
|
1048 |
+
outputs = (total_loss,) + outputs
|
1049 |
+
|
1050 |
+
return outputs # (loss), prediction_scores, seq_relationship_score, (hidden_states), (attentions)
|
1051 |
+
|
1052 |
+
|
1053 |
+
@add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
|
1054 |
+
class BertForMaskedLM(BertPreTrainedModel):
|
1055 |
+
def __init__(self, config):
|
1056 |
+
super().__init__(config)
|
1057 |
+
|
1058 |
+
self.bert = BertModel(config)
|
1059 |
+
self.cls = BertOnlyMLMHead(config)
|
1060 |
+
|
1061 |
+
self.init_weights()
|
1062 |
+
|
1063 |
+
def get_output_embeddings(self):
|
1064 |
+
return self.cls.predictions.decoder
|
1065 |
+
|
1066 |
+
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING)
|
1067 |
+
def forward(
|
1068 |
+
self,
|
1069 |
+
input_ids=None,
|
1070 |
+
attention_mask=None,
|
1071 |
+
token_type_ids=None,
|
1072 |
+
position_ids=None,
|
1073 |
+
head_mask=None,
|
1074 |
+
inputs_embeds=None,
|
1075 |
+
masked_lm_labels=None,
|
1076 |
+
encoder_hidden_states=None,
|
1077 |
+
encoder_attention_mask=None,
|
1078 |
+
lm_labels=None,
|
1079 |
+
):
|
1080 |
+
r"""
|
1081 |
+
masked_lm_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
1082 |
+
Labels for computing the masked language modeling loss.
|
1083 |
+
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
|
1084 |
+
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
|
1085 |
+
in ``[0, ..., config.vocab_size]``
|
1086 |
+
lm_labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
1087 |
+
Labels for computing the left-to-right language modeling loss (next word prediction).
|
1088 |
+
Indices should be in ``[-100, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
|
1089 |
+
Tokens with indices set to ``-100`` are ignored (masked), the loss is only computed for the tokens with labels
|
1090 |
+
in ``[0, ..., config.vocab_size]``
|
1091 |
+
|
1092 |
+
Returns:
|
1093 |
+
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
1094 |
+
masked_lm_loss (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
1095 |
+
Masked language modeling loss.
|
1096 |
+
ltr_lm_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`lm_labels` is provided):
|
1097 |
+
Next token prediction loss.
|
1098 |
+
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
|
1099 |
+
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
1100 |
+
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
|
1101 |
+
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
1102 |
+
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
1103 |
+
|
1104 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
1105 |
+
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
|
1106 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
1107 |
+
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
1108 |
+
|
1109 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
1110 |
+
heads.
|
1111 |
+
|
1112 |
+
Examples::
|
1113 |
+
|
1114 |
+
from transformers import BertTokenizer, BertForMaskedLM
|
1115 |
+
import torch
|
1116 |
+
|
1117 |
+
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
1118 |
+
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
|
1119 |
+
|
1120 |
+
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
|
1121 |
+
outputs = model(input_ids, masked_lm_labels=input_ids)
|
1122 |
+
|
1123 |
+
loss, prediction_scores = outputs[:2]
|
1124 |
+
|
1125 |
+
"""
|
1126 |
+
|
1127 |
+
outputs = self.bert(
|
1128 |
+
input_ids,
|
1129 |
+
attention_mask=attention_mask,
|
1130 |
+
token_type_ids=token_type_ids,
|
1131 |
+
position_ids=position_ids,
|
1132 |
+
head_mask=head_mask,
|
1133 |
+
inputs_embeds=inputs_embeds,
|
1134 |
+
encoder_hidden_states=encoder_hidden_states,
|
1135 |
+
encoder_attention_mask=encoder_attention_mask,
|
1136 |
+
)
|
1137 |
+
|
1138 |
+
sequence_output = outputs[0]
|
1139 |
+
prediction_scores = self.cls(sequence_output)
|
1140 |
+
|
1141 |
+
outputs = (prediction_scores,) + outputs[2:] # Add hidden states and attention if they are here
|
1142 |
+
|
1143 |
+
# Although this may seem awkward, BertForMaskedLM supports two scenarios:
|
1144 |
+
# 1. If a tensor that contains the indices of masked labels is provided,
|
1145 |
+
# the cross-entropy is the MLM cross-entropy that measures the likelihood
|
1146 |
+
# of predictions for masked words.
|
1147 |
+
# 2. If `lm_labels` is provided we are in a causal scenario where we
|
1148 |
+
# try to predict the next token for each input in the decoder.
|
1149 |
+
if masked_lm_labels is not None:
|
1150 |
+
loss_fct = CrossEntropyLoss() # -100 index = padding token
|
1151 |
+
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
|
1152 |
+
outputs = (masked_lm_loss,) + outputs
|
1153 |
+
|
1154 |
+
if lm_labels is not None:
|
1155 |
+
# we are doing next-token prediction; shift prediction scores and input ids by one
|
1156 |
+
prediction_scores = prediction_scores[:, :-1, :].contiguous()
|
1157 |
+
lm_labels = lm_labels[:, 1:].contiguous()
|
1158 |
+
loss_fct = CrossEntropyLoss()
|
1159 |
+
ltr_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), lm_labels.view(-1))
|
1160 |
+
outputs = (ltr_lm_loss,) + outputs
|
1161 |
+
|
1162 |
+
return outputs # (masked_lm_loss), (ltr_lm_loss), prediction_scores, (hidden_states), (attentions)
|
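# --- Illustrative sketch (not part of the uploaded file): exercising the `lm_labels`
# branch above. Unlike the masked-LM example in the docstring, this path shifts the
# prediction scores and labels by one position and returns a left-to-right
# (next-token) loss. The model/tokenizer names mirror the docstring example and are
# assumptions, not something this repo pins down.
def _ltr_lm_loss_sketch():
    import torch
    from transformers import BertTokenizer
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertForMaskedLM.from_pretrained('bert-base-uncased')
    input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)
    outputs = model(input_ids, lm_labels=input_ids)   # next-token prediction
    ltr_lm_loss, prediction_scores = outputs[:2]      # loss first, then vocab scores
    return ltr_lm_loss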
1163 |
+
|
1164 |
+
|
1165 |
+
@add_start_docstrings(
|
1166 |
+
"""Bert Model with a `next sentence prediction (classification)` head on top. """, BERT_START_DOCSTRING,
|
1167 |
+
)
|
1168 |
+
class BertForNextSentencePrediction(BertPreTrainedModel):
|
1169 |
+
def __init__(self, config):
|
1170 |
+
super().__init__(config)
|
1171 |
+
|
1172 |
+
self.bert = BertModel(config)
|
1173 |
+
self.cls = BertOnlyNSPHead(config)
|
1174 |
+
|
1175 |
+
self.init_weights()
|
1176 |
+
|
1177 |
+
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING)
|
1178 |
+
def forward(
|
1179 |
+
self,
|
1180 |
+
input_ids=None,
|
1181 |
+
attention_mask=None,
|
1182 |
+
token_type_ids=None,
|
1183 |
+
position_ids=None,
|
1184 |
+
head_mask=None,
|
1185 |
+
inputs_embeds=None,
|
1186 |
+
next_sentence_label=None,
|
1187 |
+
):
|
1188 |
+
r"""
|
1189 |
+
next_sentence_label (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
1190 |
+
Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring)
|
1191 |
+
Indices should be in ``[0, 1]``.
|
1192 |
+
``0`` indicates sequence B is a continuation of sequence A,
|
1193 |
+
``1`` indicates sequence B is a random sequence.
|
1194 |
+
|
1195 |
+
Returns:
|
1196 |
+
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
1197 |
+
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`next_sentence_label` is provided):
|
1198 |
+
Next sequence prediction (classification) loss.
|
1199 |
+
seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
|
1200 |
+
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax).
|
1201 |
+
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
|
1202 |
+
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
1203 |
+
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
1204 |
+
|
1205 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
1206 |
+
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
|
1207 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
1208 |
+
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
1209 |
+
|
1210 |
+
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
|
1211 |
+
heads.
|
1212 |
+
|
1213 |
+
Examples::
|
1214 |
+
|
1215 |
+
from transformers import BertTokenizer, BertForNextSentencePrediction
|
1216 |
+
import torch
|
1217 |
+
|
1218 |
+
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
1219 |
+
model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
|
1220 |
+
|
1221 |
+
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
|
1222 |
+
outputs = model(input_ids)
|
1223 |
+
|
1224 |
+
seq_relationship_scores = outputs[0]
|
1225 |
+
|
1226 |
+
"""
|
1227 |
+
|
1228 |
+
outputs = self.bert(
|
1229 |
+
input_ids,
|
1230 |
+
attention_mask=attention_mask,
|
1231 |
+
token_type_ids=token_type_ids,
|
1232 |
+
position_ids=position_ids,
|
1233 |
+
head_mask=head_mask,
|
1234 |
+
inputs_embeds=inputs_embeds,
|
1235 |
+
)
|
1236 |
+
|
1237 |
+
pooled_output = outputs[1]
|
1238 |
+
|
1239 |
+
seq_relationship_score = self.cls(pooled_output)
|
1240 |
+
|
1241 |
+
outputs = (seq_relationship_score,) + outputs[2:] # add hidden states and attention if they are here
|
1242 |
+
if next_sentence_label is not None:
|
1243 |
+
loss_fct = CrossEntropyLoss()
|
1244 |
+
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
|
1245 |
+
outputs = (next_sentence_loss,) + outputs
|
1246 |
+
|
1247 |
+
return outputs # (next_sentence_loss), seq_relationship_score, (hidden_states), (attentions)
|
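# --- Illustrative sketch (not part of the uploaded file): computing the NSP loss above
# with an explicit sentence pair and label (0 = sequence B follows sequence A,
# 1 = sequence B is random, as in the docstring). The model/tokenizer names mirror the
# docstring example and are assumptions.
def _nsp_loss_sketch():
    import torch
    from transformers import BertTokenizer
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
    input_ids = torch.tensor(
        tokenizer.encode("Jim Henson was a puppeteer", "He built the Muppets", add_special_tokens=True)
    ).unsqueeze(0)
    next_sentence_label = torch.tensor([0])  # sequence B really does follow sequence A
    next_sentence_loss, seq_relationship_score = model(input_ids, next_sentence_label=next_sentence_label)[:2]
    return next_sentence_loss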
1248 |
+
|
1249 |
+
|
1250 |
+
@add_start_docstrings(
|
1251 |
+
"""Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of
|
1252 |
+
classifier per layer on the pooled output, for patience-based early exit) e.g. for GLUE tasks. """,
|
1253 |
+
BERT_START_DOCSTRING,
|
1254 |
+
)
|
1255 |
+
class BertForSequenceClassification(BertPreTrainedModel):
|
1256 |
+
def __init__(self, config):
|
1257 |
+
super().__init__(config)
|
1258 |
+
self.num_labels = config.num_labels
|
1259 |
+
|
1260 |
+
self.bert = BertModel(config)
|
1261 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
1262 |
+
self.classifiers = nn.ModuleList([nn.Linear(config.hidden_size, self.config.num_labels) for _ in range(config.num_hidden_layers)])
|
1263 |
+
|
1264 |
+
self.init_weights()
|
1265 |
+
|
1266 |
+
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING)
|
1267 |
+
def forward(
|
1268 |
+
self,
|
1269 |
+
input_ids=None,
|
1270 |
+
attention_mask=None,
|
1271 |
+
token_type_ids=None,
|
1272 |
+
position_ids=None,
|
1273 |
+
head_mask=None,
|
1274 |
+
inputs_embeds=None,
|
1275 |
+
labels=None,
|
1276 |
+
):
|
1277 |
+
r"""
|
1278 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
1279 |
+
Labels for computing the sequence classification/regression loss.
|
1280 |
+
Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
|
1281 |
+
If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
|
1282 |
+
If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
|
1283 |
+
|
1284 |
+
Returns:
|
1285 |
+
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
1286 |
+
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
|
1287 |
+
Classification (or regression if config.num_labels==1) loss.
|
1288 |
+
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
|
1289 |
+
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
1290 |
+
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
|
1291 |
+
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
1292 |
+
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
1293 |
+
|
1294 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
1295 |
+
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
|
1296 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
1297 |
+
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
1298 |
+
|
1299 |
+
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
|
1300 |
+
heads.
|
1301 |
+
|
1302 |
+
Examples::
|
1303 |
+
|
1304 |
+
from transformers import BertTokenizer, BertForSequenceClassification
|
1305 |
+
import torch
|
1306 |
+
|
1307 |
+
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
1308 |
+
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')
|
1309 |
+
|
1310 |
+
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
|
1311 |
+
labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
|
1312 |
+
outputs = model(input_ids, labels=labels)
|
1313 |
+
|
1314 |
+
loss, logits = outputs[:2]
|
1315 |
+
|
1316 |
+
"""
|
1317 |
+
|
1318 |
+
logits = self.bert(
|
1319 |
+
input_ids=input_ids,
|
1320 |
+
attention_mask=attention_mask,
|
1321 |
+
token_type_ids=token_type_ids,
|
1322 |
+
position_ids=position_ids,
|
1323 |
+
head_mask=head_mask,
|
1324 |
+
inputs_embeds=inputs_embeds,
|
1325 |
+
output_dropout=self.dropout,
|
1326 |
+
output_layers=self.classifiers,
|
1327 |
+
regression=self.num_labels == 1
|
1328 |
+
)
|
1329 |
+
|
1330 |
+
if self.bert.mode == 'all':
|
1331 |
+
outputs = (logits,)
|
1332 |
+
else:
|
1333 |
+
outputs = (logits[-1],)
|
1334 |
+
|
1335 |
+
if labels is not None:
|
1336 |
+
total_loss = None
|
1337 |
+
total_weights = 0
|
1338 |
+
for ix, logits_item in enumerate(logits):
|
1339 |
+
if self.num_labels == 1:
|
1340 |
+
# We are doing regression
|
1341 |
+
loss_fct = MSELoss()
|
1342 |
+
loss = loss_fct(logits_item.view(-1), labels.view(-1))
|
1343 |
+
else:
|
1344 |
+
loss_fct = CrossEntropyLoss()
|
1345 |
+
loss = loss_fct(logits_item.view(-1, self.num_labels), labels.view(-1))
|
1346 |
+
if total_loss is None:
|
1347 |
+
total_loss = loss
|
1348 |
+
else:
|
1349 |
+
total_loss += loss * (ix + 1)
|
1350 |
+
total_weights += ix + 1
|
1351 |
+
outputs = (total_loss / total_weights,) + outputs
|
1352 |
+
|
1353 |
+
return outputs # (loss), logits, (hidden_states), (attentions)
|
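# --- Illustrative sketch (not part of the uploaded file): the training loss above is a
# weighted average of the per-exit losses, with the classifier at exit i weighted by
# i + 1, so deeper exits dominate. The helper below reproduces just that arithmetic on a
# plain list of per-layer losses (the helper name and the 12-exit example are assumptions).
def _weighted_exit_loss_sketch(per_layer_losses):
    total_loss, total_weights = None, 0
    for ix, loss in enumerate(per_layer_losses):
        total_loss = loss if total_loss is None else total_loss + loss * (ix + 1)
        total_weights += ix + 1
    return total_loss / total_weights
# e.g. _weighted_exit_loss_sketch([1.0] * 12) == 1.0, while an outlier loss at exit 12
# shifts the average twelve times more than the same loss at exit 1.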
1354 |
+
|
1355 |
+
|
1356 |
+
@add_start_docstrings(
|
1357 |
+
"""Bert Model with a multiple choice classification head on top (a linear layer on top of
|
1358 |
+
the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """,
|
1359 |
+
BERT_START_DOCSTRING,
|
1360 |
+
)
|
1361 |
+
class BertForMultipleChoice(BertPreTrainedModel):
|
1362 |
+
def __init__(self, config):
|
1363 |
+
super().__init__(config)
|
1364 |
+
|
1365 |
+
self.bert = BertModel(config)
|
1366 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
1367 |
+
self.classifier = nn.Linear(config.hidden_size, 1)
|
1368 |
+
|
1369 |
+
self.init_weights()
|
1370 |
+
|
1371 |
+
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING)
|
1372 |
+
def forward(
|
1373 |
+
self,
|
1374 |
+
input_ids=None,
|
1375 |
+
attention_mask=None,
|
1376 |
+
token_type_ids=None,
|
1377 |
+
position_ids=None,
|
1378 |
+
head_mask=None,
|
1379 |
+
inputs_embeds=None,
|
1380 |
+
labels=None,
|
1381 |
+
):
|
1382 |
+
r"""
|
1383 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
1384 |
+
Labels for computing the multiple choice classification loss.
|
1385 |
+
Indices should be in ``[0, ..., num_choices - 1]`` where `num_choices` is the size of the second dimension
|
1386 |
+
of the input tensors. (see `input_ids` above)
|
1387 |
+
|
1388 |
+
Returns:
|
1389 |
+
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
1390 |
+
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
|
1391 |
+
Classification loss.
|
1392 |
+
classification_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
|
1393 |
+
`num_choices` is the second dimension of the input tensors. (see `input_ids` above).
|
1394 |
+
|
1395 |
+
Classification scores (before SoftMax).
|
1396 |
+
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
|
1397 |
+
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
1398 |
+
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
1399 |
+
|
1400 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
1401 |
+
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
|
1402 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
1403 |
+
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
1404 |
+
|
1405 |
+
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
|
1406 |
+
heads.
|
1407 |
+
|
1408 |
+
Examples::
|
1409 |
+
|
1410 |
+
from transformers import BertTokenizer, BertForMultipleChoice
|
1411 |
+
import torch
|
1412 |
+
|
1413 |
+
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
1414 |
+
model = BertForMultipleChoice.from_pretrained('bert-base-uncased')
|
1415 |
+
choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
|
1416 |
+
|
1417 |
+
input_ids = torch.tensor([tokenizer.encode(s, add_special_tokens=True) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices
|
1418 |
+
labels = torch.tensor(1).unsqueeze(0) # Batch size 1
|
1419 |
+
outputs = model(input_ids, labels=labels)
|
1420 |
+
|
1421 |
+
loss, classification_scores = outputs[:2]
|
1422 |
+
|
1423 |
+
"""
|
1424 |
+
num_choices = input_ids.shape[1]
|
1425 |
+
|
1426 |
+
input_ids = input_ids.view(-1, input_ids.size(-1))
|
1427 |
+
attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
|
1428 |
+
token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
|
1429 |
+
position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
|
1430 |
+
|
1431 |
+
outputs = self.bert(
|
1432 |
+
input_ids,
|
1433 |
+
attention_mask=attention_mask,
|
1434 |
+
token_type_ids=token_type_ids,
|
1435 |
+
position_ids=position_ids,
|
1436 |
+
head_mask=head_mask,
|
1437 |
+
inputs_embeds=inputs_embeds,
|
1438 |
+
)
|
1439 |
+
|
1440 |
+
pooled_output = outputs[1]
|
1441 |
+
|
1442 |
+
pooled_output = self.dropout(pooled_output)
|
1443 |
+
logits = self.classifier(pooled_output)
|
1444 |
+
reshaped_logits = logits.view(-1, num_choices)
|
1445 |
+
|
1446 |
+
outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here
|
1447 |
+
|
1448 |
+
if labels is not None:
|
1449 |
+
loss_fct = CrossEntropyLoss()
|
1450 |
+
loss = loss_fct(reshaped_logits, labels)
|
1451 |
+
outputs = (loss,) + outputs
|
1452 |
+
|
1453 |
+
return outputs # (loss), reshaped_logits, (hidden_states), (attentions)
|
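# --- Illustrative sketch (not part of the uploaded file): the reshaping convention used
# above. Choices are flattened to (batch * num_choices, seq_len) so each one is scored
# independently, and the single score per choice is folded back into (batch, num_choices)
# before the cross-entropy over choices. All sizes below are made-up example values.
def _multiple_choice_shapes_sketch():
    import torch
    batch, num_choices, seq_len, hidden = 2, 4, 16, 8
    input_ids = torch.zeros(batch, num_choices, seq_len, dtype=torch.long)
    flat_input_ids = input_ids.view(-1, input_ids.size(-1))      # (8, 16)
    pooled_output = torch.randn(flat_input_ids.size(0), hidden)  # stand-in for BERT's pooled output
    logits = torch.nn.Linear(hidden, 1)(pooled_output)           # one score per flattened choice
    reshaped_logits = logits.view(-1, num_choices)               # back to (2, 4)
    return reshaped_logits.shape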
1454 |
+
|
1455 |
+
|
1456 |
+
@add_start_docstrings(
|
1457 |
+
"""Bert Model with a token classification head on top (a linear layer on top of
|
1458 |
+
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """,
|
1459 |
+
BERT_START_DOCSTRING,
|
1460 |
+
)
|
1461 |
+
class BertForTokenClassification(BertPreTrainedModel):
|
1462 |
+
def __init__(self, config):
|
1463 |
+
super().__init__(config)
|
1464 |
+
self.num_labels = config.num_labels
|
1465 |
+
|
1466 |
+
self.bert = BertModel(config)
|
1467 |
+
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
1468 |
+
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
1469 |
+
|
1470 |
+
self.init_weights()
|
1471 |
+
|
1472 |
+
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING)
|
1473 |
+
def forward(
|
1474 |
+
self,
|
1475 |
+
input_ids=None,
|
1476 |
+
attention_mask=None,
|
1477 |
+
token_type_ids=None,
|
1478 |
+
position_ids=None,
|
1479 |
+
head_mask=None,
|
1480 |
+
inputs_embeds=None,
|
1481 |
+
labels=None,
|
1482 |
+
):
|
1483 |
+
r"""
|
1484 |
+
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
1485 |
+
Labels for computing the token classification loss.
|
1486 |
+
Indices should be in ``[0, ..., config.num_labels - 1]``.
|
1487 |
+
|
1488 |
+
Returns:
|
1489 |
+
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
1490 |
+
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided) :
|
1491 |
+
Classification loss.
|
1492 |
+
scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`)
|
1493 |
+
Classification scores (before SoftMax).
|
1494 |
+
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
|
1495 |
+
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
1496 |
+
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
1497 |
+
|
1498 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
1499 |
+
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
|
1500 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
1501 |
+
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
1502 |
+
|
1503 |
+
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
|
1504 |
+
heads.
|
1505 |
+
|
1506 |
+
Examples::
|
1507 |
+
|
1508 |
+
from transformers import BertTokenizer, BertForTokenClassification
|
1509 |
+
import torch
|
1510 |
+
|
1511 |
+
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
1512 |
+
model = BertForTokenClassification.from_pretrained('bert-base-uncased')
|
1513 |
+
|
1514 |
+
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0) # Batch size 1
|
1515 |
+
labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0) # Batch size 1
|
1516 |
+
outputs = model(input_ids, labels=labels)
|
1517 |
+
|
1518 |
+
loss, scores = outputs[:2]
|
1519 |
+
|
1520 |
+
"""
|
1521 |
+
|
1522 |
+
outputs = self.bert(
|
1523 |
+
input_ids,
|
1524 |
+
attention_mask=attention_mask,
|
1525 |
+
token_type_ids=token_type_ids,
|
1526 |
+
position_ids=position_ids,
|
1527 |
+
head_mask=head_mask,
|
1528 |
+
inputs_embeds=inputs_embeds,
|
1529 |
+
)
|
1530 |
+
|
1531 |
+
sequence_output = outputs[0]
|
1532 |
+
|
1533 |
+
sequence_output = self.dropout(sequence_output)
|
1534 |
+
logits = self.classifier(sequence_output)
|
1535 |
+
|
1536 |
+
outputs = (logits,) + outputs[2:] # add hidden states and attention if they are here
|
1537 |
+
if labels is not None:
|
1538 |
+
loss_fct = CrossEntropyLoss()
|
1539 |
+
# Only keep active parts of the loss
|
1540 |
+
if attention_mask is not None:
|
1541 |
+
active_loss = attention_mask.view(-1) == 1
|
1542 |
+
active_logits = logits.view(-1, self.num_labels)[active_loss]
|
1543 |
+
active_labels = labels.view(-1)[active_loss]
|
1544 |
+
loss = loss_fct(active_logits, active_labels)
|
1545 |
+
else:
|
1546 |
+
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
1547 |
+
outputs = (loss,) + outputs
|
1548 |
+
|
1549 |
+
return outputs # (loss), scores, (hidden_states), (attentions)
|
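# --- Illustrative sketch (not part of the uploaded file): the attention-mask filtering
# above. Only positions with attention_mask == 1 reach the cross-entropy; padding
# positions are dropped entirely. The tensors below are made-up example values.
def _active_token_loss_sketch():
    import torch
    num_labels = 3
    logits = torch.randn(1, 4, num_labels)           # (batch, seq_len, num_labels)
    labels = torch.tensor([[2, 0, 1, 0]])
    attention_mask = torch.tensor([[1, 1, 1, 0]])    # last position is padding
    active_loss = attention_mask.view(-1) == 1
    active_logits = logits.view(-1, num_labels)[active_loss]  # 3 real tokens kept
    active_labels = labels.view(-1)[active_loss]
    return torch.nn.CrossEntropyLoss()(active_logits, active_labels)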
1550 |
+
|
1551 |
+
|
1552 |
+
@add_start_docstrings(
|
1553 |
+
"""Bert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
|
1554 |
+
layers on top of the hidden-states output to compute `span start logits` and `span end logits`). """,
|
1555 |
+
BERT_START_DOCSTRING,
|
1556 |
+
)
|
1557 |
+
class BertForQuestionAnswering(BertPreTrainedModel):
|
1558 |
+
def __init__(self, config):
|
1559 |
+
super().__init__(config)
|
1560 |
+
self.num_labels = config.num_labels
|
1561 |
+
|
1562 |
+
self.bert = BertModel(config)
|
1563 |
+
self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
|
1564 |
+
|
1565 |
+
self.init_weights()
|
1566 |
+
|
1567 |
+
@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING)
|
1568 |
+
def forward(
|
1569 |
+
self,
|
1570 |
+
input_ids=None,
|
1571 |
+
attention_mask=None,
|
1572 |
+
token_type_ids=None,
|
1573 |
+
position_ids=None,
|
1574 |
+
head_mask=None,
|
1575 |
+
inputs_embeds=None,
|
1576 |
+
start_positions=None,
|
1577 |
+
end_positions=None,
|
1578 |
+
):
|
1579 |
+
r"""
|
1580 |
+
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
1581 |
+
Labels for position (index) of the start of the labelled span for computing the token classification loss.
|
1582 |
+
Positions are clamped to the length of the sequence (`sequence_length`).
|
1583 |
+
Positions outside of the sequence are not taken into account for computing the loss.
|
1584 |
+
end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
1585 |
+
Labels for position (index) of the end of the labelled span for computing the token classification loss.
|
1586 |
+
Positions are clamped to the length of the sequence (`sequence_length`).
|
1587 |
+
Positions outside of the sequence are not taken into account for computing the loss.
|
1588 |
+
|
1589 |
+
Returns:
|
1590 |
+
:obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.BertConfig`) and inputs:
|
1591 |
+
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
|
1592 |
+
Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
|
1593 |
+
start_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
|
1594 |
+
Span-start scores (before SoftMax).
|
1595 |
+
end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
|
1596 |
+
Span-end scores (before SoftMax).
|
1597 |
+
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
|
1598 |
+
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
1599 |
+
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
1600 |
+
|
1601 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
1602 |
+
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
|
1603 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
|
1604 |
+
:obj:`(batch_size, num_heads, sequence_length, sequence_length)`.
|
1605 |
+
|
1606 |
+
Attention weights after the attention softmax, used to compute the weighted average in the self-attention
|
1607 |
+
heads.
|
1608 |
+
|
1609 |
+
Examples::
|
1610 |
+
|
1611 |
+
from transformers import BertTokenizer, BertForQuestionAnswering
|
1612 |
+
import torch
|
1613 |
+
|
1614 |
+
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
1615 |
+
model = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')
|
1616 |
+
|
1617 |
+
question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet"
|
1618 |
+
input_ids = tokenizer.encode(question, text)
|
1619 |
+
token_type_ids = [0 if i <= input_ids.index(102) else 1 for i in range(len(input_ids))]
|
1620 |
+
start_scores, end_scores = model(torch.tensor([input_ids]), token_type_ids=torch.tensor([token_type_ids]))
|
1621 |
+
|
1622 |
+
all_tokens = tokenizer.convert_ids_to_tokens(input_ids)
|
1623 |
+
answer = ' '.join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores)+1])
|
1624 |
+
|
1625 |
+
assert answer == "a nice puppet"
|
1626 |
+
|
1627 |
+
"""
|
1628 |
+
|
1629 |
+
outputs = self.bert(
|
1630 |
+
input_ids,
|
1631 |
+
attention_mask=attention_mask,
|
1632 |
+
token_type_ids=token_type_ids,
|
1633 |
+
position_ids=position_ids,
|
1634 |
+
head_mask=head_mask,
|
1635 |
+
inputs_embeds=inputs_embeds,
|
1636 |
+
)
|
1637 |
+
|
1638 |
+
sequence_output = outputs[0]
|
1639 |
+
|
1640 |
+
logits = self.qa_outputs(sequence_output)
|
1641 |
+
start_logits, end_logits = logits.split(1, dim=-1)
|
1642 |
+
start_logits = start_logits.squeeze(-1)
|
1643 |
+
end_logits = end_logits.squeeze(-1)
|
1644 |
+
|
1645 |
+
outputs = (start_logits, end_logits,) + outputs[2:]
|
1646 |
+
if start_positions is not None and end_positions is not None:
|
1647 |
+
# If we are on multi-GPU, the start/end position tensors may carry an extra dimension; squeeze it away
|
1648 |
+
if len(start_positions.size()) > 1:
|
1649 |
+
start_positions = start_positions.squeeze(-1)
|
1650 |
+
if len(end_positions.size()) > 1:
|
1651 |
+
end_positions = end_positions.squeeze(-1)
|
1652 |
+
# sometimes the start/end positions are outside our model inputs, we ignore these terms
|
1653 |
+
ignored_index = start_logits.size(1)
|
1654 |
+
start_positions.clamp_(0, ignored_index)
|
1655 |
+
end_positions.clamp_(0, ignored_index)
|
1656 |
+
|
1657 |
+
loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
|
1658 |
+
start_loss = loss_fct(start_logits, start_positions)
|
1659 |
+
end_loss = loss_fct(end_logits, end_positions)
|
1660 |
+
total_loss = (start_loss + end_loss) / 2
|
1661 |
+
outputs = (total_loss,) + outputs
|
1662 |
+
|
1663 |
+
return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
|
run_glue.py
ADDED
@@ -0,0 +1,772 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
|
3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
""" Finetuning the library models for sequence classification on GLUE (Bert, XLM, XLNet, RoBERTa, Albert, XLM-RoBERTa)."""
|
17 |
+
|
18 |
+
|
19 |
+
import argparse
|
20 |
+
import glob
|
21 |
+
import json
|
22 |
+
import logging
|
23 |
+
import os
|
24 |
+
import random
|
25 |
+
|
26 |
+
import numpy as np
|
27 |
+
import torch
|
28 |
+
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
|
29 |
+
from torch.utils.data.distributed import DistributedSampler
|
30 |
+
from tqdm import tqdm, trange
|
31 |
+
|
32 |
+
from transformers import WEIGHTS_NAME, AdamW, AlbertConfig, AlbertTokenizer, BertConfig, BertTokenizer, DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer, FlaubertConfig, FlaubertForSequenceClassification, FlaubertTokenizer, RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer, XLMConfig, XLMForSequenceClassification, XLMRobertaConfig, XLMRobertaForSequenceClassification, XLMRobertaTokenizer, XLMTokenizer, XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer, get_linear_schedule_with_warmup
|
33 |
+
|
34 |
+
# from transformers import (
|
35 |
+
# WEIGHTS_NAME,
|
36 |
+
# AdamW,
|
37 |
+
# AlbertConfig,
|
38 |
+
# AlbertTokenizer,
|
39 |
+
# BertConfig,
|
40 |
+
# BertTokenizer,
|
41 |
+
# DistilBertConfig,
|
42 |
+
# DistilBertForSequenceClassification,
|
43 |
+
# DistilBertTokenizer,
|
44 |
+
# FlaubertConfig,
|
45 |
+
# FlaubertForSequenceClassification,
|
46 |
+
# FlaubertTokenizer,
|
47 |
+
# RobertaConfig,
|
48 |
+
# RobertaForSequenceClassification,
|
49 |
+
# RobertaTokenizer,
|
50 |
+
# XLMConfig,
|
51 |
+
# XLMForSequenceClassification,
|
52 |
+
# XLMRobertaConfig,
|
53 |
+
# XLMRobertaForSequenceClassification,
|
54 |
+
# XLMRobertaTokenizer,
|
55 |
+
# XLMTokenizer,
|
56 |
+
# XLNetConfig,
|
57 |
+
# XLNetForSequenceClassification,
|
58 |
+
# XLNetTokenizer,
|
59 |
+
# get_linear_schedule_with_warmup,
|
60 |
+
# )
|
61 |
+
|
62 |
+
from pabee.modeling_albert import AlbertForSequenceClassification
|
63 |
+
from pabee.modeling_bert import BertForSequenceClassification
|
64 |
+
from transformers import glue_compute_metrics as compute_metrics
|
65 |
+
from transformers import glue_convert_examples_to_features as convert_examples_to_features
|
66 |
+
from transformers import glue_output_modes as output_modes
|
67 |
+
from transformers import glue_processors as processors
|
68 |
+
|
69 |
+
from torch.utils.tensorboard import SummaryWriter
|
70 |
+
|
71 |
+
|
72 |
+
logger = logging.getLogger(__name__)
|
73 |
+
|
74 |
+
# ALL_MODELS = sum(
|
75 |
+
# (
|
76 |
+
# tuple(conf.pretrained_config_archive_map.keys())
|
77 |
+
# for conf in (
|
78 |
+
# BertConfig,
|
79 |
+
# XLNetConfig,
|
80 |
+
# XLMConfig,
|
81 |
+
# RobertaConfig,
|
82 |
+
# DistilBertConfig,
|
83 |
+
# AlbertConfig,
|
84 |
+
# XLMRobertaConfig,
|
85 |
+
# FlaubertConfig,
|
86 |
+
# )
|
87 |
+
# ),
|
88 |
+
# (),
|
89 |
+
# )
|
90 |
+
|
91 |
+
MODEL_CLASSES = {
|
92 |
+
"bert": (BertConfig, BertForSequenceClassification, BertTokenizer),
|
93 |
+
# "xlnet": (XLNetConfig, XLNetForSequenceClassification, XLNetTokenizer),
|
94 |
+
# "xlm": (XLMConfig, XLMForSequenceClassification, XLMTokenizer),
|
95 |
+
# "roberta": (RobertaConfig, RobertaForSequenceClassification, RobertaTokenizer),
|
96 |
+
# "distilbert": (DistilBertConfig, DistilBertForSequenceClassification, DistilBertTokenizer),
|
97 |
+
"albert": (AlbertConfig, AlbertForSequenceClassification, AlbertTokenizer),
|
98 |
+
# "xlmroberta": (XLMRobertaConfig, XLMRobertaForSequenceClassification, XLMRobertaTokenizer),
|
99 |
+
# "flaubert": (FlaubertConfig, FlaubertForSequenceClassification, FlaubertTokenizer),
|
100 |
+
}
|
101 |
+
|
102 |
+
|
103 |
+
def set_seed(args):
|
104 |
+
random.seed(args.seed)
|
105 |
+
np.random.seed(args.seed)
|
106 |
+
torch.manual_seed(args.seed)
|
107 |
+
if args.n_gpu > 0:
|
108 |
+
torch.cuda.manual_seed_all(args.seed)
|
109 |
+
|
110 |
+
|
111 |
+
def train(args, train_dataset, model, tokenizer):
|
112 |
+
""" Train the model """
|
113 |
+
if args.local_rank in [-1, 0]:
|
114 |
+
tb_writer = SummaryWriter()
|
115 |
+
|
116 |
+
args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
|
117 |
+
train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
|
118 |
+
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
|
119 |
+
|
120 |
+
if args.max_steps > 0:
|
121 |
+
t_total = args.max_steps
|
122 |
+
args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
|
123 |
+
else:
|
124 |
+
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
|
125 |
+
|
126 |
+
# Prepare optimizer and schedule (linear warmup and decay)
|
127 |
+
no_decay = ["bias", "LayerNorm.weight"]
|
128 |
+
optimizer_grouped_parameters = [
|
129 |
+
{
|
130 |
+
"params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
|
131 |
+
"weight_decay": args.weight_decay,
|
132 |
+
},
|
133 |
+
{"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
|
134 |
+
]
|
135 |
+
|
136 |
+
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
|
137 |
+
scheduler = get_linear_schedule_with_warmup(
|
138 |
+
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
|
139 |
+
)
|
140 |
+
|
141 |
+
# Check if saved optimizer or scheduler states exist
|
142 |
+
if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
|
143 |
+
os.path.join(args.model_name_or_path, "scheduler.pt")
|
144 |
+
):
|
145 |
+
# Load in optimizer and scheduler states
|
146 |
+
optimizer.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
|
147 |
+
scheduler.load_state_dict(torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))
|
148 |
+
|
149 |
+
# if args.fp16:
|
150 |
+
# try:
|
151 |
+
# from apex import amp
|
152 |
+
# except ImportError:
|
153 |
+
# raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
|
154 |
+
# model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)
|
155 |
+
|
156 |
+
# multi-gpu training (should be after apex fp16 initialization)
|
157 |
+
if args.n_gpu > 1:
|
158 |
+
model = torch.nn.DataParallel(model)
|
159 |
+
|
160 |
+
# Distributed training (should be after apex fp16 initialization)
|
161 |
+
if args.local_rank != -1:
|
162 |
+
model = torch.nn.parallel.DistributedDataParallel(
|
163 |
+
model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True,
|
164 |
+
)
|
165 |
+
|
166 |
+
# Train!
|
167 |
+
logger.info("***** Running training *****")
|
168 |
+
logger.info(" Num examples = %d", len(train_dataset))
|
169 |
+
logger.info(" Num Epochs = %d", args.num_train_epochs)
|
170 |
+
logger.info(" Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
|
171 |
+
logger.info(
|
172 |
+
" Total train batch size (w. parallel, distributed & accumulation) = %d",
|
173 |
+
args.train_batch_size
|
174 |
+
* args.gradient_accumulation_steps
|
175 |
+
* (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
|
176 |
+
)
|
177 |
+
logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
|
178 |
+
logger.info(" Total optimization steps = %d", t_total)
|
179 |
+
|
180 |
+
global_step = 0
|
181 |
+
epochs_trained = 0
|
182 |
+
steps_trained_in_current_epoch = 0
|
183 |
+
# Check if continuing training from a checkpoint
|
184 |
+
if os.path.exists(args.model_name_or_path):
|
185 |
+
# set global_step to the global_step of the last saved checkpoint from the model path
|
186 |
+
global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
|
187 |
+
epochs_trained = global_step // (len(train_dataloader) // args.gradient_accumulation_steps)
|
188 |
+
steps_trained_in_current_epoch = global_step % (len(train_dataloader) // args.gradient_accumulation_steps)
|
189 |
+
|
190 |
+
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
|
191 |
+
logger.info(" Continuing training from epoch %d", epochs_trained)
|
192 |
+
logger.info(" Continuing training from global step %d", global_step)
|
193 |
+
logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
|
194 |
+
|
195 |
+
tr_loss, logging_loss = 0.0, 0.0
|
196 |
+
model.zero_grad()
|
197 |
+
train_iterator = trange(
|
198 |
+
epochs_trained, int(args.num_train_epochs), desc="Epoch", disable=args.local_rank not in [-1, 0],
|
199 |
+
)
|
200 |
+
set_seed(args)  # Added here for reproducibility
|
201 |
+
for _ in train_iterator:
|
202 |
+
epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
|
203 |
+
for step, batch in enumerate(epoch_iterator):
|
204 |
+
|
205 |
+
# Skip past any already trained steps if resuming training
|
206 |
+
if steps_trained_in_current_epoch > 0:
|
207 |
+
steps_trained_in_current_epoch -= 1
|
208 |
+
continue
|
209 |
+
|
210 |
+
model.train()
|
211 |
+
batch = tuple(t.to(args.device) for t in batch)
|
212 |
+
inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
|
213 |
+
if args.model_type != "distilbert":
|
214 |
+
inputs["token_type_ids"] = (
|
215 |
+
batch[2] if args.model_type in ["bert", "xlnet", "albert"] else None
|
216 |
+
) # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids
|
217 |
+
outputs = model(**inputs)
|
218 |
+
loss = outputs[0] # model outputs are always tuple in transformers (see doc)
|
219 |
+
|
220 |
+
if args.n_gpu > 1:
|
221 |
+
loss = loss.mean() # mean() to average on multi-gpu parallel training
|
222 |
+
if args.gradient_accumulation_steps > 1:
|
223 |
+
loss = loss / args.gradient_accumulation_steps
|
224 |
+
|
225 |
+
# fp16 (apex) support is commented out in this copy, so always fall back to the plain backward pass.
# if args.fp16:
#     with amp.scale_loss(loss, optimizer) as scaled_loss:
#         scaled_loss.backward()
loss.backward()
|
230 |
+
|
231 |
+
tr_loss += loss.item()
|
232 |
+
if (step + 1) % args.gradient_accumulation_steps == 0:
|
233 |
+
# if args.fp16:
|
234 |
+
# torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
|
235 |
+
# else:
|
236 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
237 |
+
|
238 |
+
optimizer.step()
|
239 |
+
scheduler.step() # Update learning rate schedule
|
240 |
+
model.zero_grad()
|
241 |
+
global_step += 1
|
242 |
+
|
243 |
+
if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
|
244 |
+
logs = {}
|
245 |
+
if (
|
246 |
+
args.local_rank == -1 and args.evaluate_during_training
|
247 |
+
): # Only evaluate when single GPU otherwise metrics may not average well
|
248 |
+
results = evaluate(args, model, tokenizer)
|
249 |
+
for key, value in results.items():
|
250 |
+
eval_key = "eval_{}".format(key)
|
251 |
+
logs[eval_key] = value
|
252 |
+
|
253 |
+
loss_scalar = (tr_loss - logging_loss) / args.logging_steps
|
254 |
+
learning_rate_scalar = scheduler.get_lr()[0]
|
255 |
+
logs["learning_rate"] = learning_rate_scalar
|
256 |
+
logs["loss"] = loss_scalar
|
257 |
+
logging_loss = tr_loss
|
258 |
+
|
259 |
+
for key, value in logs.items():
|
260 |
+
tb_writer.add_scalar(key, value, global_step)
|
261 |
+
print(json.dumps({**logs, **{"step": global_step}}))
|
262 |
+
|
263 |
+
if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
|
264 |
+
# Save model checkpoint
|
265 |
+
output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step))
|
266 |
+
if not os.path.exists(output_dir):
|
267 |
+
os.makedirs(output_dir)
|
268 |
+
model_to_save = (
|
269 |
+
model.module if hasattr(model, "module") else model
|
270 |
+
) # Take care of distributed/parallel training
|
271 |
+
model_to_save.save_pretrained(output_dir)
|
272 |
+
tokenizer.save_pretrained(output_dir)
|
273 |
+
|
274 |
+
torch.save(args, os.path.join(output_dir, "training_args.bin"))
|
275 |
+
logger.info("Saving model checkpoint to %s", output_dir)
|
276 |
+
|
277 |
+
torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
|
278 |
+
torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
|
279 |
+
logger.info("Saving optimizer and scheduler states to %s", output_dir)
|
280 |
+
|
281 |
+
if args.max_steps > 0 and global_step > args.max_steps:
|
282 |
+
epoch_iterator.close()
|
283 |
+
break
|
284 |
+
if args.max_steps > 0 and global_step > args.max_steps:
|
285 |
+
train_iterator.close()
|
286 |
+
break
|
287 |
+
|
288 |
+
if args.local_rank in [-1, 0]:
|
289 |
+
tb_writer.close()
|
290 |
+
|
291 |
+
return global_step, tr_loss / global_step
|
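# --- Illustrative sketch (not part of the uploaded file): how the quantities logged in
# train() relate to each other on a single node. All concrete numbers below are
# assumptions chosen only for the example.
def _training_steps_sketch(num_examples=5000, per_gpu_train_batch_size=8, n_gpu=1,
                           gradient_accumulation_steps=4, num_train_epochs=3):
    train_batch_size = per_gpu_train_batch_size * max(1, n_gpu)
    batches_per_epoch = (num_examples + train_batch_size - 1) // train_batch_size
    effective_batch_size = train_batch_size * gradient_accumulation_steps          # examples per optimizer step
    t_total = batches_per_epoch // gradient_accumulation_steps * num_train_epochs  # optimizer steps overall
    return effective_batch_size, t_total  # (32, 468) for the defaults above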
292 |
+
|
293 |
+
|
294 |
+
def evaluate(args, model, tokenizer, prefix="", patience=0):
|
295 |
+
|
296 |
+
if args.model_type == 'albert':
|
297 |
+
model.albert.set_regression_threshold(args.regression_threshold)
|
298 |
+
if args.do_train:
|
299 |
+
model.albert.set_mode('last')
|
300 |
+
elif args.eval_mode == 'patience':
|
301 |
+
model.albert.set_mode('patience')
|
302 |
+
model.albert.set_patience(patience)
|
303 |
+
elif args.eval_mode == 'confi':
|
304 |
+
model.albert.set_mode('confi')
|
305 |
+
model.albert.set_confi_threshold(patience)
|
306 |
+
model.albert.reset_stats()
|
307 |
+
elif args.model_type == 'bert':
|
308 |
+
model.bert.set_regression_threshold(args.regression_threshold)
|
309 |
+
if args.do_train:
|
310 |
+
model.bert.set_mode('last')
|
311 |
+
elif args.eval_mode == 'patience':
|
312 |
+
model.bert.set_mode('patience')
|
313 |
+
model.bert.set_patience(patience)
|
314 |
+
elif args.eval_mode == 'confi':
|
315 |
+
model.bert.set_mode('confi')
|
316 |
+
model.bert.set_confi_threshold(patience)
|
317 |
+
model.bert.reset_stats()
|
318 |
+
else:
|
319 |
+
raise NotImplementedError()
|
320 |
+
|
321 |
+
# Loop to handle MNLI double evaluation (matched, mis-matched)
|
322 |
+
eval_task_names = ("mnli", "mnli-mm") if args.task_name == "mnli" else (args.task_name,)
|
323 |
+
eval_outputs_dirs = (args.output_dir, args.output_dir + "-MM") if args.task_name == "mnli" else (args.output_dir,)
|
324 |
+
|
325 |
+
results = {}
|
326 |
+
for eval_task, eval_output_dir in zip(eval_task_names, eval_outputs_dirs):
|
327 |
+
eval_dataset = load_and_cache_examples(args, eval_task, tokenizer, evaluate=True)
|
328 |
+
|
329 |
+
if not os.path.exists(eval_output_dir) and args.local_rank in [-1, 0]:
|
330 |
+
os.makedirs(eval_output_dir)
|
331 |
+
|
332 |
+
args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
|
333 |
+
# Note that DistributedSampler samples randomly
|
334 |
+
eval_sampler = SequentialSampler(eval_dataset)
|
335 |
+
eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
|
336 |
+
|
337 |
+
# multi-gpu eval
|
338 |
+
if args.n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
|
339 |
+
model = torch.nn.DataParallel(model)
|
340 |
+
|
341 |
+
# Eval!
|
342 |
+
logger.info("***** Running evaluation {} *****".format(prefix))
|
343 |
+
logger.info(" Num examples = %d", len(eval_dataset))
|
344 |
+
logger.info(" Batch size = %d", args.eval_batch_size)
|
345 |
+
eval_loss = 0.0
|
346 |
+
nb_eval_steps = 0
|
347 |
+
preds = None
|
348 |
+
out_label_ids = None
|
349 |
+
for batch in tqdm(eval_dataloader, desc="Evaluating"):
|
350 |
+
model.eval()
|
351 |
+
batch = tuple(t.to(args.device) for t in batch)
|
352 |
+
|
353 |
+
with torch.no_grad():
|
354 |
+
inputs = {"input_ids": batch[0], "attention_mask": batch[1], "labels": batch[3]}
|
355 |
+
if args.model_type != "distilbert":
|
356 |
+
inputs["token_type_ids"] = (
|
357 |
+
batch[2] if args.model_type in ["bert", "xlnet", "albert"] else None
|
358 |
+
) # XLM, DistilBERT, RoBERTa, and XLM-RoBERTa don't use segment_ids
|
359 |
+
outputs = model(**inputs)
|
360 |
+
tmp_eval_loss, logits = outputs[:2]
|
361 |
+
|
362 |
+
eval_loss += tmp_eval_loss.mean().item()
|
363 |
+
nb_eval_steps += 1
|
364 |
+
if preds is None:
|
365 |
+
preds = logits.detach().cpu().numpy()
|
366 |
+
out_label_ids = inputs["labels"].detach().cpu().numpy()
|
367 |
+
else:
|
368 |
+
preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
|
369 |
+
out_label_ids = np.append(out_label_ids, inputs["labels"].detach().cpu().numpy(), axis=0)
|
370 |
+
|
371 |
+
eval_loss = eval_loss / nb_eval_steps
|
372 |
+
if args.output_mode == "classification":
|
373 |
+
preds = np.argmax(preds, axis=1)
|
374 |
+
elif args.output_mode == "regression":
|
375 |
+
preds = np.squeeze(preds)
|
376 |
+
result = compute_metrics(eval_task, preds, out_label_ids)
|
377 |
+
results.update(result)
|
378 |
+
|
379 |
+
output_eval_file = os.path.join(eval_output_dir, prefix, "eval_results.txt")
|
380 |
+
with open(output_eval_file, "w") as writer:
|
381 |
+
logger.info("***** Eval results {} *****".format(prefix))
|
382 |
+
for key in sorted(result.keys()):
|
383 |
+
logger.info(" %s = %s", key, str(result[key]))
|
384 |
+
print(" %s = %s" % (key, str(result[key])))
|
385 |
+
writer.write("%s = %s\n" % (key, str(result[key])))
|
386 |
+
|
387 |
+
if args.eval_all_checkpoints and patience != 0:
|
388 |
+
if args.model_type == 'albert':
|
389 |
+
model.albert.log_stats()
|
390 |
+
elif args.model_type == 'bert':
|
391 |
+
model.bert.log_stats()
|
392 |
+
else:
|
393 |
+
raise NotImplementedError()
|
394 |
+
|
395 |
+
return results
|
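# --- Illustrative sketch (not part of the uploaded file): the early-exit rule that the
# `patience` / `confi` settings configured at the top of evaluate() control inside
# pabee/modeling_*.py. Under the patience rule an example exits once `patience`
# consecutive internal classifiers repeat the previous prediction; under the confidence
# rule it exits once an internal classifier's softmax confidence crosses the threshold.
# The exact bookkeeping lives in the pabee models; this only shows the idea for a single
# example with one (1, num_labels) logits tensor per exit.
def _early_exit_layer_sketch(per_layer_logits, patience=3, confi_threshold=None):
    import torch
    prev_pred, agree_count = None, 0
    for layer_ix, layer_logits in enumerate(per_layer_logits):
        probs = torch.softmax(layer_logits, dim=-1)
        pred = int(probs.argmax(dim=-1))
        if confi_threshold is not None:                # 'confi' mode
            if float(probs.max()) >= confi_threshold:
                return layer_ix
        else:                                          # 'patience' mode
            agree_count = agree_count + 1 if pred == prev_pred else 0
            prev_pred = pred
            if agree_count >= patience:
                return layer_ix
    return len(per_layer_logits) - 1                   # no early exit: use the last classifier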
396 |
+
|
397 |
+
|
398 |
+
def load_and_cache_examples(args, task, tokenizer, evaluate=False):
|
399 |
+
if args.local_rank not in [-1, 0] and not evaluate:
|
400 |
+
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
401 |
+
|
402 |
+
processor = processors[task]()
|
403 |
+
output_mode = output_modes[task]
|
404 |
+
# Load data features from cache or dataset file
|
405 |
+
cached_features_file = os.path.join(
|
406 |
+
args.data_dir,
|
407 |
+
"cached_{}_{}_{}_{}".format(
|
408 |
+
"dev" if evaluate else "train",
|
409 |
+
list(filter(None, args.model_name_or_path.split("/"))).pop(),
|
410 |
+
str(args.max_seq_length),
|
411 |
+
str(task),
|
412 |
+
),
|
413 |
+
)
|
414 |
+
if os.path.exists(cached_features_file) and not args.overwrite_cache:
|
415 |
+
logger.info("Loading features from cached file %s", cached_features_file)
|
416 |
+
features = torch.load(cached_features_file)
|
417 |
+
else:
|
418 |
+
logger.info("Creating features from dataset file at %s", args.data_dir)
|
419 |
+
label_list = processor.get_labels()
|
420 |
+
if task in ["mnli", "mnli-mm"] and args.model_type in ["roberta", "xlmroberta"]:
|
421 |
+
# HACK(label indices are swapped in RoBERTa pretrained model)
|
422 |
+
label_list[1], label_list[2] = label_list[2], label_list[1]
|
423 |
+
examples = (
|
424 |
+
processor.get_dev_examples(args.data_dir) if evaluate else processor.get_train_examples(args.data_dir)
|
425 |
+
)
|
426 |
+
# convert examples into features
|
427 |
+
features = convert_examples_to_features(
|
428 |
+
examples,
|
429 |
+
tokenizer,
|
430 |
+
label_list=label_list,
|
431 |
+
max_length=args.max_seq_length,
|
432 |
+
output_mode=output_mode,
|
433 |
+
# pad_on_left=bool(args.model_type in ["xlnet"]), # pad on the left for xlnet
|
434 |
+
# pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
|
435 |
+
# pad_token_segment_id=4 if args.model_type in ["xlnet"] else 0,
|
436 |
+
)
|
437 |
+
if args.local_rank in [-1, 0]:
|
438 |
+
logger.info("Saving features into cached file %s", cached_features_file)
|
439 |
+
torch.save(features, cached_features_file)
|
440 |
+
|
441 |
+
if args.local_rank == 0 and not evaluate:
|
442 |
+
torch.distributed.barrier() # Make sure only the first process in distributed training process the dataset, and the others will use the cache
|
443 |
+
|
444 |
+
# Convert to Tensors and build dataset
|
445 |
+
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
|
446 |
+
all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
|
447 |
+
all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
|
448 |
+
if output_mode == "classification":
|
449 |
+
all_labels = torch.tensor([f.label for f in features], dtype=torch.long)
|
450 |
+
elif output_mode == "regression":
|
451 |
+
all_labels = torch.tensor([f.label for f in features], dtype=torch.float)
|
452 |
+
|
453 |
+
dataset = TensorDataset(all_input_ids, all_attention_mask, all_token_type_ids, all_labels)
|
454 |
+
return dataset
|
455 |
+
|
456 |
+
|
457 |
+
def main():
|
458 |
+
parser = argparse.ArgumentParser()
|
459 |
+
|
460 |
+
# Required parameters
|
461 |
+
parser.add_argument(
|
462 |
+
"--data_dir",
|
463 |
+
default=None,
|
464 |
+
type=str,
|
465 |
+
required=True,
|
466 |
+
help="The input data dir. Should contain the .tsv files (or other data files) for the task.",
|
467 |
+
)
|
468 |
+
parser.add_argument(
|
469 |
+
"--model_type",
|
470 |
+
default=None,
|
471 |
+
type=str,
|
472 |
+
required=True,
|
473 |
+
help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
|
474 |
+
)
|
475 |
+
parser.add_argument(
|
476 |
+
"--model_name_or_path",
|
477 |
+
default=None,
|
478 |
+
type=str,
|
479 |
+
required=True,
|
480 |
+
help="Path to pre-trained model or shortcut name selected in the list: " # + ", ".join(ALL_MODELS),
|
481 |
+
)
|
482 |
+
parser.add_argument(
|
483 |
+
"--task_name",
|
484 |
+
default=None,
|
485 |
+
type=str,
|
486 |
+
required=True,
|
487 |
+
help="The name of the task to train selected in the list: " + ", ".join(processors.keys())
|
488 |
+
)
|
489 |
+
parser.add_argument(
|
490 |
+
"--output_dir",
|
491 |
+
default=None,
|
492 |
+
type=str,
|
493 |
+
required=True,
|
494 |
+
help="The output directory where the model predictions and checkpoints will be written.",
|
495 |
+
)
|
496 |
+
parser.add_argument(
|
497 |
+
"--patience",
|
498 |
+
default='0',
|
499 |
+
type=str,
|
500 |
+
required=False,
|
501 |
+
)
|
502 |
+
parser.add_argument(
|
503 |
+
"--regression_threshold",
|
504 |
+
default=0,
|
505 |
+
type=float,
|
506 |
+
required=False,
|
507 |
+
)
|
508 |
+
# Other parameters
|
509 |
+
parser.add_argument(
|
510 |
+
"--config_name", default="", type=str, help="Pretrained config name or path if not the same as model_name",
|
511 |
+
)
|
512 |
+
parser.add_argument(
|
513 |
+
"--tokenizer_name",
|
514 |
+
default="",
|
515 |
+
type=str,
|
516 |
+
help="Pretrained tokenizer name or path if not the same as model_name",
|
517 |
+
)
|
518 |
+
parser.add_argument(
|
519 |
+
"--cache_dir",
|
520 |
+
default="",
|
521 |
+
type=str,
|
522 |
+
help="Where do you want to store the pre-trained models downloaded from s3",
|
523 |
+
)
|
524 |
+
parser.add_argument(
|
525 |
+
"--max_seq_length",
|
526 |
+
default=128,
|
527 |
+
type=int,
|
528 |
+
help="The maximum total input sequence length after tokenization. Sequences longer "
|
529 |
+
"than this will be truncated, sequences shorter will be padded.",
|
530 |
+
)
|
531 |
+
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
|
532 |
+
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
|
533 |
+
parser.add_argument(
|
534 |
+
"--evaluate_during_training", action="store_true", help="Run evaluation during training at each logging step.",
|
535 |
+
)
|
536 |
+
parser.add_argument(
|
537 |
+
"--do_lower_case", action="store_true", help="Set this flag if you are using an uncased model.",
|
538 |
+
)
|
539 |
+
|
540 |
+
parser.add_argument(
|
541 |
+
"--per_gpu_train_batch_size", default=8, type=int, help="Batch size per GPU/CPU for training.",
|
542 |
+
)
|
543 |
+
parser.add_argument(
|
544 |
+
"--per_gpu_eval_batch_size", default=8, type=int, help="Batch size per GPU/CPU for evaluation.",
|
545 |
+
)
|
546 |
+
parser.add_argument(
|
547 |
+
"--gradient_accumulation_steps",
|
548 |
+
type=int,
|
549 |
+
default=1,
|
550 |
+
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
551 |
+
)
|
552 |
+
parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
|
553 |
+
parser.add_argument("--weight_decay", default=0.0, type=float, help="Weight decay if we apply some.")
|
554 |
+
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
|
555 |
+
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
556 |
+
parser.add_argument(
|
557 |
+
"--num_train_epochs", default=3.0, type=float, help="Total number of training epochs to perform.",
|
558 |
+
)
|
559 |
+
parser.add_argument(
|
560 |
+
"--max_steps",
|
561 |
+
default=-1,
|
562 |
+
type=int,
|
563 |
+
help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
|
564 |
+
)
|
565 |
+
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
|
566 |
+
|
567 |
+
parser.add_argument("--logging_steps", type=int, default=500, help="Log every X updates steps.")
|
568 |
+
parser.add_argument("--save_steps", type=int, default=500, help="Save checkpoint every X updates steps.")
|
569 |
+
parser.add_argument(
|
570 |
+
"--eval_all_checkpoints",
|
571 |
+
action="store_true",
|
572 |
+
help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number",
|
573 |
+
)
|
574 |
+
parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
|
575 |
+
parser.add_argument(
|
576 |
+
"--overwrite_output_dir", action="store_true", help="Overwrite the content of the output directory",
|
577 |
+
)
|
578 |
+
parser.add_argument(
|
579 |
+
"--overwrite_cache", action="store_true", help="Overwrite the cached training and evaluation sets",
|
580 |
+
)
|
581 |
+
parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
|
582 |
+
|
583 |
+
parser.add_argument(
|
584 |
+
"--fp16",
|
585 |
+
action="store_true",
|
586 |
+
help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
|
587 |
+
)
|
588 |
+
parser.add_argument(
|
589 |
+
"--fp16_opt_level",
|
590 |
+
type=str,
|
591 |
+
default="O1",
|
592 |
+
help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
|
593 |
+
"See details at https://nvidia.github.io/apex/amp.html",
|
594 |
+
)
|
595 |
+
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
596 |
+
parser.add_argument("--server_ip", type=str, default="", help="For distant debugging.")
|
597 |
+
parser.add_argument("--server_port", type=str, default="", help="For distant debugging.")
|
598 |
+
|
599 |
+
parser.add_argument("--eval_mode",type=str,default="patience",help='the evaluation mode for the multi-exit BERT patience|confi')
|
    args = parser.parse_args()

    if (
        os.path.exists(args.output_dir)
        and os.listdir(args.output_dir)
        and args.do_train
        and not args.overwrite_output_dir
    ):
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(
                args.output_dir
            )
        )

    # Setup distant debugging if needed
    # if args.server_ip and args.server_port:
    #     # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
    #     import ptvsd

    #     print("Waiting for debugger attach")
    #     ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
    #     ptvsd.wait_for_attach()

    # TODO: is this part wrong? Distributed
    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or args.no_cuda:
        print(f'CUDA STATUS: {torch.cuda.is_available()}')
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend="nccl")
        args.n_gpu = 1
    args.device = device

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        args.local_rank,
        device,
        args.n_gpu,
        bool(args.local_rank != -1),
        args.fp16,
    )

    # Set seed
    set_seed(args)

    # Prepare GLUE task
    args.task_name = args.task_name.lower()
    if args.task_name not in processors:
        raise ValueError("Task not found: %s" % (args.task_name))
    processor = processors[args.task_name]()  # transformers package-preprocessor
    args.output_mode = output_modes[args.task_name]  # output type
    label_list = processor.get_labels()
    num_labels = len(label_list)
    print(f'num labels: {num_labels}')

    # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab

    args.model_type = args.model_type.lower()
    config_class, model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    config = config_class.from_pretrained(
        args.config_name if args.config_name else args.model_name_or_path,
        num_labels=num_labels,
        finetuning_task=args.task_name,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )
    tokenizer = tokenizer_class.from_pretrained(
        args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
        do_lower_case=args.do_lower_case,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )
    model = model_class.from_pretrained(
        args.model_name_or_path,
        from_tf=bool(".ckpt" in args.model_name_or_path),
        config=config,
        cache_dir=args.cache_dir if args.cache_dir else None,
    )

    if args.local_rank == 0:
        torch.distributed.barrier()  # Make sure only the first process in distributed training will download model & vocab

    model.to(args.device)

    print('Total Model Parameters:', sum(param.numel() for param in model.parameters()))
    output_layers_param_num = sum(param.numel() for param in model.classifiers.parameters())
    print('Output Layers Parameters:', output_layers_param_num)
    single_output_layer_param_num = sum(param.numel() for param in model.classifiers[0].parameters())
    print('Added Output Layers Parameters:', output_layers_param_num - single_output_layer_param_num)

    logger.info("Training/evaluation parameters %s", args)

    # Training
    if args.do_train:
        train_dataset = load_and_cache_examples(args, args.task_name, tokenizer, evaluate=False)
        global_step, tr_loss = train(args, train_dataset, model, tokenizer)
        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)

    # Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
    if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        # Create output directory if needed
        if not os.path.exists(args.output_dir) and args.local_rank in [-1, 0]:
            os.makedirs(args.output_dir)

        logger.info("Saving model checkpoint to %s", args.output_dir)
        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        model_to_save = (
            model.module if hasattr(model, "module") else model
        )  # Take care of distributed/parallel training
        model_to_save.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)

        # Good practice: save your training arguments together with the trained model
        torch.save(args, os.path.join(args.output_dir, "training_args.bin"))

        # Load a trained model and vocabulary that you have fine-tuned
        model = model_class.from_pretrained(args.output_dir)
        tokenizer = tokenizer_class.from_pretrained(args.output_dir)
        model.to(args.device)

    # Evaluation
    results = {}
    if args.do_eval and args.local_rank in [-1, 0]:
        if args.eval_mode == 'patience':
            patience_list = [int(x) for x in args.patience.split(',')]
        elif args.eval_mode == 'confi':
            patience_list = [float(x) for x in args.patience.split(',')]
        tokenizer = tokenizer_class.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
        checkpoints = [args.output_dir]
        if args.eval_all_checkpoints:
            checkpoints = list(
                os.path.dirname(c) for c in sorted(glob.glob(args.output_dir + "/**/" + WEIGHTS_NAME, recursive=True))
            )
        logging.getLogger("transformers.modeling_utils").setLevel(logging.WARN)  # Reduce logging
        logger.info("Evaluate the following checkpoints: %s", checkpoints)

        for checkpoint in checkpoints:
            if '600' not in checkpoint:
                continue

            global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
            prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""

            model = model_class.from_pretrained(checkpoint)
            model.to(args.device)

            print(f'Evaluation for checkpoint {prefix}')
            for patience in patience_list:
                print(f'------ Patience Threshold: {patience} ------')
                result = evaluate(args, model, tokenizer, prefix=prefix, patience=patience)
                result = dict((k + "_{}".format(global_step), v) for k, v in result.items())
                results.update(result)
                if args.model_type == 'albert':
                    print(f'Exits Distribution: {model.albert.exits_count_list}')
                elif args.model_type == 'bert':
                    print(f'Exits Distribution: {model.bert.exits_count_list}')

    return results


if __name__ == "__main__":
    main()
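Note on the two evaluation modes above: --eval_mode only changes how the comma-separated args.patience value used above is interpreted, apparently as integer patience values for "patience" (PABEE-style early exit) and as float confidence thresholds for "confi". A minimal illustrative sketch of that parsing, not part of the uploaded files (the helper name parse_patience_list is hypothetical):

# Illustrative only -- mirrors the parsing done in main() above; not in the repository.
def parse_patience_list(eval_mode: str, patience: str):
    if eval_mode == "patience":
        return [int(x) for x in patience.split(",")]    # e.g. --eval_mode patience --patience 3,4,5
    elif eval_mode == "confi":
        return [float(x) for x in patience.split(",")]  # e.g. --eval_mode confi --patience 0.7,0.8,0.9
    raise ValueError(f"unknown eval_mode: {eval_mode}")

assert parse_patience_list("patience", "3,4,5") == [3, 4, 5]
assert parse_patience_list("confi", "0.7,0.9") == [0.7, 0.9]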
whitebox_utils/__pycache__/attack.cpython-37.pyc
ADDED
Binary file (3.95 kB)

whitebox_utils/__pycache__/classifier.cpython-37.pyc
ADDED
Binary file (3.84 kB)

whitebox_utils/__pycache__/metric.cpython-37.pyc
ADDED
Binary file (1.37 kB)

whitebox_utils/classifier.py
ADDED
@@ -0,0 +1,117 @@
import numpy as np
import torch

from transformers import glue_convert_examples_to_features as convert_examples_to_features
from transformers import InputExample

class MyClassifier():
    def __init__(self, model, tokenizer, label_list, output_mode, exit_type, exit_value, model_type='albert', max_length=128):
        self.model = model
        self.model.eval()
        self.model_type = model_type
        self.tokenizer = tokenizer
        self.label_list = label_list
        self.output_mode = output_mode
        self.max_length = max_length
        self.exit_type = exit_type
        self.exit_value = exit_value
        self.count = 0
        self.reset_status(mode='all', stats=True)
        if exit_type == 'patience':
            self.set_patience(patience=exit_value)
        elif exit_type == 'confi':
            self.set_threshold(confidence_threshold=exit_value)

    def tokenize(self, input_, idx):
        examples = []
        guid = f"dev_{idx}"
        if input_[1] == "<none>":
            text_a = input_[0]
            text_b = None
        else:
            text_a = input_[0]
            text_b = input_[1]
        # print(f'len: {len(input_)}\t text_a: {text_a}\t text_b:{text_b}')
        label = None
        examples.append(InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        # print(examples)
        features = convert_examples_to_features(
            examples,
            self.tokenizer,
            label_list=self.label_list,
            max_length=self.max_length,
            output_mode=self.output_mode,
        )
        # print(features)
        all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
        all_attention_mask = torch.tensor([f.attention_mask for f in features], dtype=torch.long)
        all_token_type_ids = torch.tensor([f.token_type_ids for f in features], dtype=torch.long)
        return all_input_ids, all_attention_mask, all_token_type_ids

    def set_threshold(self, confidence_threshold):
        if self.model_type == 'albert':
            self.model.albert.set_confi_threshold(confidence_threshold)
        elif self.model_type == 'bert':
            self.model.bert.set_confi_threshold(confidence_threshold)

    def set_patience(self, patience):
        if self.model_type == 'albert':
            self.model.albert.set_patience(patience)
        elif self.model_type == 'bert':
            self.model.bert.set_patience(patience)

    def set_exit_position(self, exit_pos):
        if self.model_type == 'albert':
            self.model.albert.set_exit_pos = exit_pos

    def reset_status(self, mode, stats=False):
        if self.model_type == 'albert':
            self.model.albert.set_mode(mode)
            if stats:
                self.model.albert.reset_stats()
        elif self.model_type == 'bert':
            self.model.bert.set_mode(mode)
            if stats:
                self.model.bert.reset_stats()

    def get_exit_number(self):
        if self.model_type == 'albert':
            return self.model.albert.config.num_hidden_layers
        elif self.model_type == 'bert':
            return self.model.bert.config.num_hidden_layers

    def get_current_exit(self):
        if self.model_type == 'albert':
            return self.model.albert.current_exit_layer
        elif self.model_type == 'bert':
            return self.model.bert.current_exit_layer

    # TODO: tweak the prediction algorithm to obtain the prediction results
    def get_pred(self, input_):
        # print(self.get_prob(input_).argmax(axis=2).shape)
        return self.get_prob(input_).argmax(axis=2)

    def get_prob(self, input_):
        self.reset_status(mode=self.exit_type, stats=False)  # set the exit mode (patience/confi)
        ret = []
        for sent in input_:
            self.count += 1
            batch = self.tokenize(sent, idx=self.count)
            inputs = {"input_ids": batch[0], "attention_mask": batch[1], "token_type_ids": batch[2]}
            outputs = self.model(**inputs)[0]  # get all logits
            output_ = torch.softmax(outputs, dim=1)[0].detach().cpu().numpy()
            ret.append(output_)
        return np.array(ret)

    def get_prob_time(self, input_, exit_position):
        self.reset_status(mode='exact', stats=False)  # exit at an exact layer
        self.set_exit_position(exit_position)
        ret = []
        for sent in input_:
            self.count += 1
            batch = self.tokenize(sent, idx=self.count)
            inputs = {"input_ids": batch[0], "attention_mask": batch[1], "token_type_ids": batch[2]}
            outputs = self.model(**inputs)[0]  # get all logits
            output_ = [torch.softmax(output, dim=1)[0].detach().cpu().numpy() for output in outputs]
            ret.append(output_)
        return np.array(ret)
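MyClassifier wraps a fine-tuned multi-exit model and its tokenizer so that the whitebox attack utilities can query softmax scores and predictions one sentence (pair) at a time. A minimal usage sketch, illustrative only and not part of the uploaded files; it assumes model, tokenizer, label_list and output_mode have already been loaded as in run_glue.py, and the example sentences are placeholders:

# Illustrative only -- not in the repository.
from whitebox_utils.classifier import MyClassifier

clf = MyClassifier(
    model=model,              # fine-tuned multi-exit model, loaded as in run_glue.py
    tokenizer=tokenizer,
    label_list=label_list,    # processor.get_labels()
    output_mode=output_mode,  # output_modes[task_name]
    exit_type="patience",     # or "confi"
    exit_value=3,             # patience steps, or a confidence threshold for "confi"
    model_type="albert",
)

# Sentence pairs; use "<none>" as the second element for single-sentence tasks.
inputs = [("A man is playing a guitar.", "A person plays an instrument.")]
probs = clf.get_prob(inputs)   # softmax scores for each input
preds = clf.get_pred(inputs)   # class predictions derived from get_prob
print(probs.shape, preds, clf.get_current_exit())

The exit_type/exit_value pair appears to play the same role here as the --eval_mode/--patience flags in run_glue.py above.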