Duplicate from MatrixYao/how_many_data_points_zh
Co-authored-by: Matrix Yao <[email protected]>
- .gitattributes +34 -0
- .gitignore +6 -0
- Dockerfile +11 -0
- README.md +12 -0
- naacl_demo/demo_utils.py +514 -0
- naacl_demo/main.py +293 -0
- naacl_demo/text.md +90 -0
- naacl_demo/text.py +171 -0
- requirements.txt +22 -0
.gitattributes
ADDED
@@ -0,0 +1,34 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,6 @@
__pycache__/
*.py[cod]
*$py.class


.env/
Dockerfile
ADDED
@@ -0,0 +1,11 @@
FROM python:3.7

WORKDIR /code

COPY ./requirements.txt /code/requirements.txt

RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt

COPY . .

CMD ["bokeh", "serve", "naacl_demo", "--allow-websocket-origin=*"]
README.md
ADDED
@@ -0,0 +1,12 @@
---
title: How Many Data Points
emoji: 🦀
colorFrom: red
colorTo: yellow
sdk: docker
pinned: false
app_port: 5006
duplicated_from: MatrixYao/how_many_data_points_zh
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
naacl_demo/demo_utils.py
ADDED
@@ -0,0 +1,514 @@
import math

import pandas as pd
import numpy as np
from itertools import product
import shapely
from bokeh.models import Span, Label, ColumnDataSource, Whisker
from bokeh.plotting import figure, show
from shapely.geometry import Polygon
import matplotlib as mpl
import matplotlib.pyplot as plt
import seaborn

# Per-task experiment metadata: which PET patterns were run, how many seeds,
# which pattern performed best, and which metric each task reports.
task_patterns = {
    "CB": [0, 3],
    "RTE": [0, 3],
    "BoolQ": [0, 3, 5],
    "MNLI": [0, 3],
    "COPA": [0, 1],
    "WSC": [0, 1, 2],
    "WiC": [0, 1],
    "MultiRC": [0, 1, 2],
}
task_reps = {"CB": 4, "RTE": 4, "BoolQ": 4, "MNLI": 4, "COPA": 4, "WSC": 4, "WiC": 4, "MultiRC": 4}
task_best_pattern = {"CB": 0, "RTE": 0, "BoolQ": 0, "MNLI": 0, "COPA": 1, "WSC": 0, "WiC": 0, "MultiRC": 1}
task_metric_short = {
    "CB": "f1-macro",
    "RTE": "acc",
    "BoolQ": "acc",
    "MNLI": "acc",
    "COPA": "acc",
    "WSC": "acc",
    "WiC": "acc",
    "MultiRC": "f1",
}
task_metrics = {
    "CB": "F1-macro",
    "RTE": "accuracy",
    "BoolQ": "accuracy",
    "MNLI": "accuracy",
    "COPA": "accuracy",
    "WSC": "accuracy",
    "WiC": "accuracy",
    "MultiRC": "F1",
}
task_neutral = {
    "CB": True,
    "RTE": True,
    "BoolQ": True,
    "MNLI": True,
    "COPA": False,
    "WSC": False,
    "multirc": True,
    "WiC": True,
    "MultiRC": True,
}
neutral_tasks = [
    "BoolQ",
    "CB",
    "MNLI",
    "MultiRC",
    "RTE",
    "WiC",
]
tasks = sorted(task_patterns.keys())

# Color triples (line, fill, light fill) for the prompting, null-verbalizer control and head plots.
pvp_colors = ["goldenrod", "blanchedalmond", "floralwhite"]
ctl_colors = ["crimson", "salmon", "mistyrose"]
clf_colors = ["indigo", "plum", "thistle"]


def prompt_boolq(passage, question, pattern):
    # Renders a BoolQ example into one of the three HTML prompt patterns.
    if pattern == 0:
        return f"""<span style="color: #0c593d">{passage}</span> <span style="color: #910713"><b>Based on the previous passage,</b></span> <span style="color: #031154">{question}</span> <span style="color: #ba9004"><b>[YES/NO]</b></span>"""
    if pattern == 1:
        return f"""<span style="color: #0c593d">{passage}</span><span style="color: #910713"><b> Question:</b></span> <span style="color: #031154">{question}</span><span style="color: #910713"><b> Answer: </b></span><span style="color: #ba9004"><b>[YES/NO]</b></span>"""
    if pattern == 2:
        return f"""<span style="color: #910713"><b>Based on the following passage,</b></span> <span style="color: #031154">{question}</span><span style="color: #ba9004"><b> [YES/NO]</b></span> <span style="color: #0c593d">{passage}</span>"""


def advantage_text(advantage):
    model_type = (
        """<span style="color: #4B0082">head</span>"""
        if advantage < 0
        else """<span style="color: #daa520">prompting</span>"""
    )
    return f"""<b>{model_type}</b> advantage: <b>{abs(advantage):.2f}</b> data points"""


def average_advantage_text(advantage):
    model_type = (
        """<span style="color: #4B0082">head</span>"""
        if advantage < 0
        else """<span style="color: #daa520">prompting</span>"""
    )
    return f"""<b>Average {model_type}</b> advantage: <b>{abs(advantage):.2f}</b> data points"""


def naming_convention(task, seed, pvp_index=None, neutral=False):
    # Reconstructs the column names used in the exported wandb CSVs.
    method = f"PVP {pvp_index}" if pvp_index is not None else "CLF"
    model = "roberta"
    if neutral:
        verbalizer = "neutral"
    else:
        verbalizer = None
    return (
        f"{method} {model}"
        + (f" {verbalizer} verbalizer" if verbalizer is not None else "")
        + f" seed {seed} - test-{task_metric_short[task]}-all-p"
    )


def get_data(task):
    url = f"https://raw.githubusercontent.com/TevenLeScao/pet/master/exported_results/{task.lower()}/wandb_export.csv"
    df = pd.read_csv(url)
    training_points = df["training_points"]

    head_performances = np.transpose(np.array([df[naming_convention(task, i)] for i in range(task_reps[task])]))
    pattern_performances = {}
    for pattern in task_patterns[task]:
        pattern_performances[pattern] = {
            "normal": np.transpose(np.array([df[naming_convention(task, i, pattern)] for i in range(task_reps[task])]))
        }
        if task_neutral[task]:
            pattern_performances[pattern]["neutral"] = np.transpose(
                np.array([df[naming_convention(task, i, pattern, True)] for i in range(task_reps[task])])
            )

    return training_points, head_performances, pattern_performances


def reduct(performances, reduction="accmax", final_pattern=0, verbalizer="normal", exclude=None):
    # Combining the different runs for each experimental set-up
    reducted = None

    if isinstance(performances, dict):
        performances = performances[final_pattern][verbalizer]
    if exclude is not None:
        performances = np.delete(performances, exclude, axis=1)

    if reduction == "avg":
        # Average
        reducted = np.nanmean(performances, axis=1)

    if reduction == "std":
        # Standard deviation
        reducted = np.nanstd(performances, axis=1)

    if reduction == "max":
        # Maximum
        reducted = np.nanmax(performances, axis=1)

    if reduction == "accmax":
        # This makes the maximum curve monotonic
        max_performance = np.nanmax(performances, axis=1)
        reducted = np.maximum.accumulate(max_performance)

    assert reducted is not None, "unrecognized reduction method"
    return reducted


def find_surrounding_points(perf, clf_results, pvp_results):
    # For each curve, returns the index of the last point below `perf`,
    # i.e. the segment on which to interpolate.
    for i, clf_result in enumerate(clf_results):
        if i - 1 > 0 and clf_result == clf_results[i - 1]:
            continue
        if clf_result > perf:
            if i == 0:
                raise ValueError(f"value {perf} too small")
            else:
                break
    for j, pvp_result in enumerate(pvp_results):
        if j - 1 > 0 and pvp_result == pvp_results[j - 1]:
            continue
        if pvp_result > perf:
            if j == 0:
                raise ValueError(f"value {perf} too small")
            else:
                break
    return i - 1, j - 1


def interpolate(perf, x1, x2, y1, y2):
    return x1 + (perf - y1) * (x2 - x1) / (y2 - y1)


def interpolate_from_idx(perf, idx, results, training_points):
    return interpolate(perf, training_points[idx], training_points[idx + 1], results[idx], results[idx + 1])


def interpolate_from_perf(perf, overlapping_range, training_points, clf_results, pvp_results):
    if not overlapping_range[0] <= perf <= overlapping_range[1]:
        raise ValueError(f"perf {perf} not in acceptable bounds {overlapping_range}")
    clf_idx, pvp_idx = find_surrounding_points(perf, clf_results, pvp_results)
    return interpolate_from_idx(perf, clf_idx, clf_results, training_points), interpolate_from_idx(
        perf, pvp_idx, pvp_results, training_points
    )


def data_difference(perf, overlapping_range, training_points, clf_results, pvp_results):
    # Horizontal gap between the two curves at a given performance level, i.e. the
    # data advantage in number of training points (positive favours prompting).
    x1, x2 = interpolate_from_perf(perf, overlapping_range, training_points, clf_results, pvp_results)
    return x1 - x2


def calculate_overlap(clf_results, pvp_results, full_range=False):
    if full_range:
        return (min(min(clf_results), min(pvp_results)), max(max(clf_results), max(pvp_results)))
    else:
        return (max(min(clf_results), min(pvp_results)), min(max(clf_results), max(pvp_results)))


def calculate_range(overlapping_range, number_of_points):
    integral_range = (
        overlapping_range[0] + i / (number_of_points + 1) * (overlapping_range[1] - overlapping_range[0])
        for i in range(1, number_of_points + 1)
    )
    return integral_range


def calculate_differences(integral_range, overlapping_range, training_points, clf_results, pvp_results):
    differences = [
        data_difference(y, overlapping_range, training_points, clf_results, pvp_results) for y in integral_range
    ]
    return differences


def calculate_offset(training_points, clf_results, pvp_results, number_of_points=1000):
    # Average data advantage over the whole comparison window.
    overlapping_range = calculate_overlap(clf_results, pvp_results)
    integral_range = calculate_range(overlapping_range, number_of_points)
    differences = calculate_differences(integral_range, overlapping_range, training_points, clf_results, pvp_results)
    offset = sum(differences) / number_of_points
    return offset


def intersection_with_range(training_points, results, band):
    result_polygon = Polygon(
        [(training_points[i], results[i]) for i in range(len(training_points))]
        + [(training_points[-1], 0), (training_points[0], 0)]
    )
    return result_polygon.intersection(band)


def fill_polygon(fig, polygon, color, label=None, alpha=1.0):
    if polygon.is_empty or isinstance(polygon, shapely.geometry.LineString):
        return
    if isinstance(polygon, Polygon):
        xs, ys = polygon.exterior.xy
        fig.patch(xs, ys, color=color, alpha=alpha)
    else:
        for geom in polygon.geoms:
            if isinstance(geom, shapely.geometry.LineString):
                continue
            xs, ys = geom.exterior.xy
            fig.patch(xs, ys, color=color, alpha=alpha)
            label = None


label_order = {
    "head run": 0,
    "head advantage": 1,
    "control run": 2,
    "optimization advantage": 3,
    "prompting run": 4,
    "semantics advantage": 5,
    "region of comparison": 6,
}


def metric_tap(
    event, overlapping_range, training_points, clf_results, pvp_results, advantage_box, advantage_plot
):
    # Tap handler: reads the clicked performance level and displays the data advantage there.
    _, metric_value = event.x, event.y
    try:
        advantage_value = data_difference(metric_value, overlapping_range, training_points, clf_results, pvp_results)
        advantage_box.text = advantage_text(advantage_value)
        if not isinstance(advantage_plot.renderers[-1], Span):
            metric_line = Span(
                location=metric_value,
                line_alpha=0.7,
                dimension="width",
                line_color=clf_colors[0] if advantage_value < 0 else pvp_colors[0],
                line_dash="dashed",
                line_width=1,
            )
            advantage_plot.renderers.extend([metric_line])
        else:
            advantage_plot.renderers[-1].location = metric_value
            advantage_plot.renderers[-1].line_color = clf_colors[0] if advantage_value < 0 else pvp_colors[0]
    # clicking outside the region
    except ValueError:
        pass


def plot_polygons_bokeh(task, training_points, clf_results, pvp_results, clf_colors, pvp_colors, x_log_scale=False):
    overlapping_range = calculate_overlap(clf_results, pvp_results, False)
    full_range = calculate_overlap(clf_results, pvp_results, True)
    middle_y = (full_range[0] + full_range[1]) / 2

    fig = figure(plot_height=400, plot_width=800, max_height=400, max_width=800,
                 x_axis_type="log" if x_log_scale else "linear",
                 title="Performance of the head and prompting approaches over training subset sizes")

    fig.circle(training_points, clf_results, color=clf_colors[0], legend="head")
    fig.circle(training_points, pvp_results, color=pvp_colors[0], legend="prompting")
    fig.line(training_points, clf_results, color=clf_colors[0], alpha=1)
    fig.line(training_points, pvp_results, color=pvp_colors[0], alpha=1)
    fig.xaxis.axis_label = "Training subset size"
    fig.yaxis.axis_label = task_metrics[task]
    fig.patch(
        [training_points[0], training_points[0], training_points[-1], training_points[-1]],
        [overlapping_range[0], overlapping_range[1], overlapping_range[1], overlapping_range[0]],
        color="black",
        fill_alpha=0,
        line_width=0,
        legend="region of comparison",
        hatch_alpha=0.14,
        hatch_scale=40,
        hatch_pattern="/",
    )

    band = Polygon(
        [
            (training_points[0], overlapping_range[0]),
            (training_points[0], overlapping_range[1]),
            (training_points[-1], overlapping_range[1]),
            (training_points[-1], overlapping_range[0]),
        ]
    )
    full_band = Polygon(
        [
            (training_points[0], full_range[0]),
            (training_points[0], full_range[1]),
            (training_points[-1], full_range[1]),
            (training_points[-1], full_range[0]),
        ]
    )
    clf_polygon = intersection_with_range(training_points, clf_results, band)
    pvp_polygon = intersection_with_range(training_points, pvp_results, band)
    full_clf_polygon = intersection_with_range(training_points, clf_results, full_band)
    full_pvp_polygon = intersection_with_range(training_points, pvp_results, full_band)

    clf_inside_area = clf_polygon.difference(pvp_polygon)
    pvp_inside_area = pvp_polygon.difference(clf_polygon)
    clf_outside_area = (full_clf_polygon.difference(full_pvp_polygon)).difference(clf_inside_area)
    pvp_outside_area = (full_pvp_polygon.difference(full_clf_polygon)).difference(pvp_inside_area)

    fill_polygon(fig, clf_outside_area, clf_colors[1], alpha=0.13)
    fill_polygon(fig, pvp_outside_area, pvp_colors[1], alpha=0.18)
    fill_polygon(
        fig, clf_inside_area, clf_colors[1], alpha=0.4, label="head advantage" if task == "WiC" else None
    )
    fill_polygon(fig, pvp_inside_area, pvp_colors[1], alpha=0.4, label="prompting advantage")

    fig.line([training_points[0], training_points[-1]], [overlapping_range[0], overlapping_range[0]], color="dimgrey")
    fig.line([training_points[0], training_points[-1]], [overlapping_range[1], overlapping_range[1]], color="dimgrey")

    vline = Span(
        location=training_points[-1], dimension="height", line_color="black", line_width=2.5, line_dash="dashed"
    )
    end_label = Label(
        x=training_points[-1], y=middle_y, text="full dataset size", angle=90, angle_units="deg", text_align="center"
    )
    fig.renderers.extend([vline, end_label])

    fig.legend.location = "bottom_right"

    return fig


def plot_three_polygons_bokeh(
    task, training_points, clf_results, pvp_results, ctl_results, clf_colors, pvp_colors, ctl_colors,
    x_log_scale=False
):
    overlapping_range = calculate_overlap(clf_results, pvp_results, False)
    full_range = calculate_overlap(clf_results, pvp_results, True)
    middle_y = (full_range[0] + full_range[1]) / 2

    fig = figure(plot_height=400, plot_width=800, max_height=400, max_width=800,
                 x_axis_type="log" if x_log_scale else "linear",
                 title="Performance of the head, prompting and null-verbalizer prompting approaches over training subset sizes")
    fig.xaxis.axis_label = "Training subset size"
    fig.yaxis.axis_label = task_metrics[task]
    fig.circle(training_points, clf_results, color=clf_colors[0], legend="head")
    fig.circle(training_points, pvp_results, color=pvp_colors[0], legend="prompting")
    fig.circle(training_points, ctl_results, color=ctl_colors[0], legend="null-verbalizer prompting")
    fig.line(training_points, clf_results, color=clf_colors[0], alpha=1)
    fig.line(training_points, pvp_results, color=pvp_colors[0], alpha=1)
    fig.line(training_points, ctl_results, color=ctl_colors[0], alpha=1)

    fig.patch(
        [training_points[0], training_points[0], training_points[-1], training_points[-1]],
        [overlapping_range[0], overlapping_range[1], overlapping_range[1], overlapping_range[0]],
        color="black",
        fill_alpha=0,
        line_width=0,
        legend="region of comparison",
        hatch_alpha=0.14,
        hatch_scale=40,
        hatch_pattern="/",
    )

    band = Polygon(
        [
            (training_points[0], overlapping_range[0]),
            (training_points[0], overlapping_range[1]),
            (training_points[-1], overlapping_range[1]),
            (training_points[-1], overlapping_range[0]),
        ]
    )
    full_band = Polygon(
        [
            (training_points[0], full_range[0]),
            (training_points[0], full_range[1]),
            (training_points[-1], full_range[1]),
            (training_points[-1], full_range[0]),
        ]
    )

    clf_polygon = intersection_with_range(training_points, clf_results, band)
    pvp_polygon = intersection_with_range(training_points, pvp_results, band)
    ctl_polygon = intersection_with_range(training_points, ctl_results, band)

    full_clf_polygon = intersection_with_range(training_points, clf_results, full_band)
    full_pvp_polygon = intersection_with_range(training_points, pvp_results, full_band)
    full_ctl_polygon = intersection_with_range(training_points, ctl_results, full_band)

    clf_inside_area = clf_polygon.difference(ctl_polygon)
    pvp_inside_area = pvp_polygon.difference(clf_polygon).difference(ctl_polygon)
    ctl_inside_area = ctl_polygon.difference(clf_polygon)

    clf_outside_area = (full_clf_polygon.difference(full_ctl_polygon)).difference(clf_inside_area)
    pvp_outside_area = (full_pvp_polygon.difference(full_clf_polygon).difference(ctl_polygon)).difference(
        pvp_inside_area
    )
    ctl_outside_area = (full_ctl_polygon.difference(full_clf_polygon)).difference(pvp_inside_area)

    fill_polygon(
        fig, clf_inside_area, clf_colors[1], alpha=0.4, label="head advantage" if task == "WiC" else None
    )
    fill_polygon(fig, pvp_inside_area, pvp_colors[1], alpha=0.4, label="prompting advantage")
    fill_polygon(fig, ctl_inside_area, ctl_colors[1], alpha=0.4, label="null verbalizer advantage")
    fill_polygon(fig, clf_outside_area, clf_colors[1], alpha=0.13)
    fill_polygon(fig, pvp_outside_area, pvp_colors[1], alpha=0.18)
    fill_polygon(fig, ctl_outside_area, ctl_colors[1], alpha=0.13)

    fig.line([training_points[0], training_points[-1]], [overlapping_range[0], overlapping_range[0]], color="dimgrey")
    fig.line([training_points[0], training_points[-1]], [overlapping_range[1], overlapping_range[1]], color="dimgrey")

    vline = Span(
        location=training_points[-1], dimension="height", line_color="black", line_width=2.5, line_dash="dashed"
    )
    end_label = Label(
        x=training_points[-1], y=middle_y, text="full dataset size", angle=90, angle_units="deg", text_align="center"
    )
    fig.renderers.extend([vline, end_label])

    fig.legend.location = "bottom_right"

    return fig


def pattern_graph(task):
    fig = figure(plot_height=400, plot_width=800, max_height=400, max_width=800, x_axis_type="log",
                 title="Performance over training subset sizes of different prompt patterns")
    fig.xaxis.axis_label = "Training subset size"
    fig.yaxis.axis_label = task_metrics[task]
    url = f"https://raw.githubusercontent.com/TevenLeScao/pet/master/exported_results/{task.lower()}/wandb_export.csv"
    df = pd.read_csv(url)
    expanded_training_points = np.array(list(df["training_points"]) * task_reps[task] * len(task_patterns[task]))
    data = np.array(df[[naming_convention(task, seed, pattern) for pattern in task_patterns[task] for seed in
                        range(task_reps[task])]])
    data = data.reshape(-1, task_reps[task])
    col_med = np.nanmean(data, axis=1)
    # Find indices that you need to replace
    inds = np.where(np.isnan(data))
    # Place column means in the indices. Align the arrays using take
    data[inds] = np.take(col_med, inds[0])
    data = data.reshape(len(df["training_points"]), -1)
    data = data.transpose().reshape(-1)
    data = data + np.random.normal(0, 0.01, len(data))  # slight jitter so overlapping runs stay visible
    pattern = np.array([i // (len(data) // len(task_patterns[task])) for i in range(len(data))])
    seed = np.array([0, 1, 2, 3] * (len(data) // task_reps[task]))
    long_df = pd.DataFrame(np.stack((expanded_training_points, pattern, seed, data), axis=1),
                           columns=["training_points", "pattern", "seed", task_metrics[task]])
    long_df['pattern'] = long_df['pattern'].astype(int).astype(str)
    gby_pattern = long_df.groupby('pattern')
    pattern_colors = ["royalblue", "darkturquoise", "darkviolet"]

    for i, (pattern, pattern_df) in enumerate(gby_pattern):
        gby_training_points = pattern_df.groupby('training_points')
        x = [training_point for training_point, training_point_df in gby_training_points]
        y_max = list([np.max(training_point_df[task_metrics[task]]) for training_point, training_point_df in gby_training_points])
        y_min = list([np.min(training_point_df[task_metrics[task]]) for training_point, training_point_df in gby_training_points])
        y = list([np.median(training_point_df[task_metrics[task]]) for training_point, training_point_df in gby_training_points])
        fig.circle(x, y, color=pattern_colors[i], alpha=1, legend=f"pattern {i}")
        fig.line(x, y, color=pattern_colors[i], alpha=1)
        fig.varea(x=x, y1=y_max, y2=y_min, color=pattern_colors[i], alpha=0.11)
        # source = ColumnDataSource(data=dict(base=x, lower=y_min, upper=y_max))
        # w = Whisker(source=source, base="base", upper="upper", lower="lower", line_color=pattern_colors[i], line_alpha=0.3)
        # w.upper_head.line_color = pattern_colors[i]
        # w.lower_head.line_color = pattern_colors[i]
        # fig.add_layout(w)

    return fig


# Easing curves for the "integrate" animation.
def cubic_easing(t):
    if t < 0.5:
        return 4 * t * t * t
    p = 2 * t - 2
    return 0.5 * p * p * p + 1


def circ_easing(t):
    if t < 0.5:
        return 0.5 * (1 - math.sqrt(1 - 4 * (t * t)))
    return 0.5 * (math.sqrt(-((2 * t) - 3) * ((2 * t) - 1)) + 1)
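Taken together, these utilities implement the data-advantage analysis end to end. A minimal usage sketch of how they combine (illustrative only, not a file in this commit; it assumes the wandb CSV exports referenced in `get_data` are reachable):

```python
from demo_utils import get_data, reduct, calculate_offset, task_best_pattern

task = "BoolQ"
training_points, head_perf, pattern_perf = get_data(task)
clf_curve = reduct(head_perf, "accmax")  # monotonic best-of-seeds head curve
pvp_curve = reduct(pattern_perf, "accmax", task_best_pattern[task], "normal")
advantage = calculate_offset(list(training_points), clf_curve, pvp_curve)
print(f"average prompting advantage on {task}: {advantage:.0f} data points")
```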
naacl_demo/main.py
ADDED
@@ -0,0 +1,293 @@
from bokeh.events import Tap
from bokeh.io import curdoc
from bokeh.layouts import column
from bokeh.models import Div, TextInput, RadioButtonGroup, TextAreaInput, Span, Button, Panel, Tabs
from bokeh.models.tools import CrosshairTool

from demo_utils import (
    get_data,
    prompt_boolq,
    pvp_colors,
    ctl_colors,
    clf_colors,
    reduct,
    task_best_pattern,
    plot_polygons_bokeh,
    advantage_text,
    data_difference,
    calculate_overlap,
    circ_easing,
    average_advantage_text,
    plot_three_polygons_bokeh,
    tasks,
    metric_tap,
    neutral_tasks, pattern_graph,
)
from text import text1, text2, text3, text4, initial_passage, initial_question, text5

########################################################################################################################
# Basic dimensions
########################################################################################################################

plot_width = 1200
plot_height = 400
sidebar_width = 400
in_text_plot_height = 300
text_width = 800
widget_size = 400

########################################################################################################################
# Patternification widget
########################################################################################################################

passage = TextAreaInput(title="Passage", rows=3, value=initial_passage, max_width=text_width)
passage.align = "center"
question = TextInput(title="Question", value=initial_question, max_width=text_width)
question.align = "center"
radio_button_group = RadioButtonGroup(labels=["Pattern 1", "Pattern 2", "Pattern 3"], active=0, max_width=text_width)
radio_button_group.align = "center"

box_style = {
    "display": "block",
    "margin": "0 auto",
    "width": f"{text_width}px",
    "text-align": "center",
    "white-space": "pre-wrap",
    "background": "#f4f4f4",
    "border": "1px solid #ddd",
    # "border-left": "3px solid #4d4945",
    "color": "#666",
    "page-break-inside": "avoid",
    # "font-family": "monospace",
    "font-size": "15px",
    "line-height": "1.6",
    "max-width": "100%",
    "overflow": "hidden",
    "min-height": "30px",
    "word-wrap": "break-word",
}

prompt_box = Div(
    text=prompt_boolq(passage.value, question.value, radio_button_group.active),
    width=text_width,
    style=box_style,
    sizing_mode="scale_width",
)
prompt_box.align = "center"


def update_prompt(attrname, old, new):
    prompt_box.text = prompt_boolq(passage.value, question.value, radio_button_group.active)


passage.on_change("value", update_prompt)
question.on_change("value", update_prompt)
radio_button_group.on_change("active", update_prompt)

patternification = column(passage, question, radio_button_group, prompt_box, sizing_mode="scale_width")
patternification.align = "center"

########################################################################################################################
# Advantage diagram
########################################################################################################################

advantage_plots_per_task = []
overlapping_range_per_task = []
training_points_per_task = []
clf_results_per_task = []
pvp_results_per_task = []
advantage_tabs = []
advantage_all_figures = Tabs(tabs=advantage_tabs)

advantage_box = Div(
    text="Click inside the region of comparison to compute the data advantage at that performance level",
    width=text_width,
    style=box_style,
    sizing_mode="scale_width",
)
advantage_box.align = "center"

for task in tasks:
    training_points, classifier_performances, pattern_performances = get_data(task)
    training_points_per_task.append(list(training_points))
    clf_results_per_task.append(reduct(classifier_performances, "accmax"))
    pvp_results_per_task.append(reduct(pattern_performances, "accmax", task_best_pattern[task], "normal"))
    advantage_plots_per_task.append(plot_polygons_bokeh(
        task, training_points_per_task[-1], clf_results_per_task[-1], pvp_results_per_task[-1], clf_colors,
        pvp_colors
    ))
    advantage_plots_per_task[-1].align = "center"
    advantage_plots_per_task[-1].add_tools(CrosshairTool(dimensions="width", line_alpha=0.2))
    overlapping_range_per_task.append(calculate_overlap(clf_results_per_task[-1], pvp_results_per_task[-1]))
    advantage_tabs.append(Panel(child=advantage_plots_per_task[-1], title=task))

    # The lambda reads advantage_all_figures.active when the event fires,
    # so the callback always follows the currently selected tab.
    advantage_plots_per_task[-1].on_event(
        Tap,
        lambda event: metric_tap(
            event,
            overlapping_range_per_task[advantage_all_figures.active],
            training_points_per_task[advantage_all_figures.active],
            clf_results_per_task[advantage_all_figures.active],
            pvp_results_per_task[advantage_all_figures.active],
            advantage_box,
            advantage_plots_per_task[advantage_all_figures.active],
        ),
    )

    if task == "MNLI":
        training_points_per_task.append(list(training_points))
        clf_results_per_task.append(reduct(classifier_performances, "accmax"))
        pvp_results_per_task.append(reduct(pattern_performances, "accmax", task_best_pattern[task], "normal"))
        advantage_plots_per_task.append(plot_polygons_bokeh(
            task, training_points_per_task[-1], clf_results_per_task[-1], pvp_results_per_task[-1], clf_colors,
            pvp_colors, x_log_scale=True
        ))
        advantage_plots_per_task[-1].align = "center"
        advantage_plots_per_task[-1].add_tools(CrosshairTool(dimensions="width", line_alpha=0.2))
        overlapping_range_per_task.append(calculate_overlap(clf_results_per_task[-1], pvp_results_per_task[-1]))
        advantage_tabs.append(Panel(child=advantage_plots_per_task[-1], title="MNLI (log scale)"))

        advantage_plots_per_task[-1].on_event(
            Tap,
            lambda event: metric_tap(
                event,
                overlapping_range_per_task[advantage_all_figures.active],
                training_points_per_task[advantage_all_figures.active],
                clf_results_per_task[advantage_all_figures.active],
                pvp_results_per_task[advantage_all_figures.active],
                advantage_box,
                advantage_plots_per_task[advantage_all_figures.active],
            ),
        )

advantage_all_figures = Tabs(tabs=advantage_tabs)
advantage_all_figures.align = "center"


def on_integrate_click():
    frames = 200
    initial_placement = overlapping_range_per_task[advantage_all_figures.active][0]

    if not isinstance(advantage_plots_per_task[advantage_all_figures.active].renderers[-1], Span):
        metric_line = Span(
            location=initial_placement,
            line_alpha=0.7,
            dimension="width",
            line_color=clf_colors[0] if initial_placement < 0 else pvp_colors[0],
            line_dash="dashed",
            line_width=1,
        )
        advantage_plots_per_task[advantage_all_figures.active].renderers.extend([metric_line])
    else:
        advantage_plots_per_task[advantage_all_figures.active].renderers[-1].location = initial_placement
        advantage_plots_per_task[advantage_all_figures.active].renderers[-1].line_color = clf_colors[
            0] if initial_placement < 0 else pvp_colors[0]

    average_advantage = 0
    for i in range(1, frames):
        metric_value = overlapping_range_per_task[advantage_all_figures.active][0] + (
                overlapping_range_per_task[advantage_all_figures.active][1] -
                overlapping_range_per_task[advantage_all_figures.active][0]) * (i / frames)
        advantage_value = data_difference(metric_value, overlapping_range_per_task[advantage_all_figures.active],
                                          training_points_per_task[advantage_all_figures.active],
                                          clf_results_per_task[advantage_all_figures.active],
                                          pvp_results_per_task[advantage_all_figures.active])
        average_advantage = ((i - 1) * average_advantage + advantage_value) / i

        advantage_plots_per_task[advantage_all_figures.active].renderers[-1].location = metric_value
        advantage_plots_per_task[advantage_all_figures.active].renderers[-1].line_color = clf_colors[
            0] if advantage_value < 0 else pvp_colors[0]
    advantage_box.text = average_advantage_text(average_advantage)


integrate = Button(width=175, max_width=175, label="Integrate over the whole region!")
integrate.align = "center"
integrate.on_click(on_integrate_click)


def on_tab_change(attr, old, new):
    advantage_box.text = "Click inside the region of comparison to compute the data advantage at that performance level"


advantage_all_figures.on_change('active', on_tab_change)

advantage_column = column(advantage_all_figures, advantage_box, integrate, sizing_mode="scale_width")

########################################################################################################################
# Null verbalizer diagram
########################################################################################################################

null_tabs = []
null_all_figures = Tabs(tabs=null_tabs)

for task in neutral_tasks:
    training_points, classifier_performances, pattern_performances = get_data(task)
    training_points = list(training_points)
    clf_results = reduct(classifier_performances, "accmax")
    pvp_results = reduct(pattern_performances, "accmax", task_best_pattern[task], "normal")
    ctl_results = reduct(pattern_performances, "accmax", task_best_pattern[task], "neutral")
    null_plot = plot_three_polygons_bokeh(task, training_points, clf_results, pvp_results, ctl_results, clf_colors,
                                          pvp_colors, ctl_colors)
    null_plot.align = "center"
    null_plot.add_tools(CrosshairTool(dimensions="width", line_alpha=0.2))
    null_tabs.append(Panel(child=null_plot, title=task))

    if task == "MNLI":
        null_plot = plot_three_polygons_bokeh(task, training_points, clf_results, pvp_results, ctl_results, clf_colors,
                                              pvp_colors, ctl_colors, x_log_scale=True)
        null_plot.align = "center"
        null_plot.add_tools(CrosshairTool(dimensions="width", line_alpha=0.2))
        null_tabs.append(Panel(child=null_plot, title="MNLI (log scale)"))

null_all_figures = Tabs(tabs=null_tabs)
null_all_figures.align = "center"

########################################################################################################################
# Patterns diagram
########################################################################################################################

pattern_tabs = []
pattern_all_figures = Tabs(tabs=pattern_tabs)

for task in tasks:
    pattern_plot = pattern_graph(task)
    pattern_plot.align = "center"
    pattern_plot.add_tools(CrosshairTool(dimensions="width", line_alpha=0.2))
    pattern_tabs.append(Panel(child=pattern_plot, title=task))

pattern_all_figures = Tabs(tabs=pattern_tabs)
pattern_all_figures.align = "center"

########################################################################################################################
# Add write-up text
########################################################################################################################

main_text_style = {
    "min-height": "100px",
    "overflow": "hidden",
    "display": "block",
    "margin": "auto",
    "width": f"{text_width}px",
    "font-size": "18px",
}

textbox1 = Div(text=text1, style=main_text_style)
textbox2 = Div(text=text2, style=main_text_style)
textbox3 = Div(text=text3, style=main_text_style)
textbox4 = Div(text=text4, style=main_text_style)
textbox5 = Div(text=text5, style=main_text_style)
textbox1.align = "center"
textbox2.align = "center"
textbox3.align = "center"
textbox4.align = "center"
textbox5.align = "center"

########################################################################################################################
# Set up layouts and add to document
########################################################################################################################

main_body = column(textbox1, patternification, textbox2, advantage_column, textbox3, null_all_figures, textbox4, pattern_all_figures, textbox5, sizing_mode="scale_width")
main_body.align = "center"

curdoc().add_root(main_body)
curdoc().title = "How many data points is a prompt worth?"
naacl_demo/text.md
ADDED
@@ -0,0 +1,90 @@
# How many data points is a prompt worth?

The dominant approach in today's NLP applications is to fine-tune the classification head of a pretrained language model separately for each target task. As language models grow larger, a variety of alternatives have emerged to challenge the classifier-head method popularized by [BERT](https://arxiv.org/abs/1810.04805), [UniLM](https://arxiv.org/abs/1905.03197) and [GPT](https://openai.com/research/language-unsupervised). In particular, GPT-3 popularized prompting: natural-language input that steers a pretrained language model to solve the task on its own, with no additional classification head.

Prompts are interesting because they let the user feed information to the model in a way that is very different from standard ML supervision. In our [NAACL 2021 paper](https://arxiv.org/abs/2103.08493) with [Alexander Rush](http://rush-nlp.com/), we study prompt-based fine-tuning as an alternative to the current standard of supervised fine-tuning, and our experiments show that prompting is often beneficial, making it a promising method. In analyzing this advantage, we argue that the prompt brings extra information to the model, which led us to quantify that benefit in data points, that is: **how many data points is a prompt worth?**

## Prompting

To adapt a pretrained language model to a specific task, the current mainstream approach replaces the model's last layer, the word-prediction layer, with a randomly-initialized linear classification head. The modified model is then trained on supervised task data with backpropagation, mainly to learn the weights of the new head, although the weights of the other layers may also be updated. We call this the *head* approach.

A competing approach is *prompting*: these methods try to use the original language model to "answer" the classification question by predicting the word corresponding to the target class, rather than "predicting" a class label in the traditional way. This lets us perform the classification task with the language model itself. Here, the *prompt* is the carefully designed input text that leads the model to generate the desired answer text.

This may sound abstract, but it is exactly the very natural method humans use for textual reasoning in everyday life: school exercises, for example, tend to present a text input (say, an article about Mars) together with a question ("Is there life on Mars?") and expect a natural-language answer ("No"<sup>1</sup>) that maps to a class of the classification task (here, "no" maps to false and "yes" to true, making this a binary classification problem). In this paradigm, much as in a grammar exercise, we feed task-specific data to the model, and the model, like a student, fills in the blank in a fixed way. Prompting hopes to exploit the pretraining information contained in the language model explicitly, instead of only implicitly, through hidden representations fed to a linear classification head.

Here is an example from the [BoolQ](https://arxiv.org/abs/1905.10044) task of [SuperGLUE](https://arxiv.org/abs/1905.00537), a yes/no task in which each example consists of a <span style="color: #0c593d">passage</span> and a corresponding <span style="color: #031154">question</span> whose answer is a boolean, either true or false. Each example can be assembled with a <span style="color: #910713">**pattern**</span> into a text sequence with a single <span style="color: #ba9004">**masked word**</span> to predict. Once the masked word is predicted, a pre-set *verbalizer* converts it into a class; that is, the *verbalizer* provides the mapping between output words and classes: the probabilities of the word being mapped to *true* and to *false* are compared, and the final prediction is true if *true* is more likely, and false otherwise.

![image](mockups/boolqpatterns.png)

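To make the pattern/verbalizer machinery concrete, here is a minimal sketch (illustrative only; the pattern wording follows `prompt_boolq` in `demo_utils.py`, and the `word_probs` argument stands in for the language model's predictions at the mask):

```python
# Minimal BoolQ-style pattern + verbalizer sketch.
def pattern(passage: str, question: str) -> str:
    # Pattern 0: passage, linking phrase, question, then the word to predict.
    return f"{passage} Based on the previous passage, {question} [MASK]"

verbalizer = {"yes": True, "no": False}  # predicted word -> class

def classify(word_probs: dict) -> bool:
    # Compare the probabilities the language model assigns to the verbalizer words.
    return word_probs.get("yes", 0.0) > word_probs.get("no", 0.0)

print(pattern("Mars is a cold desert.", "is there life on Mars?"))
print(classify({"yes": 0.1, "no": 0.8}))  # False
```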
## Fine-tuning

This turns a generic language model into a task-specific classifier. There are many ways to use these prompt-based language-model classifiers:

- The language-modeling ability retained from pretraining lets them perform without any additional data, in contrast to *linear head* models, which **start from a random initialization and therefore from random performance**. Many papers have therefore used them for [zero-shot classification](https://arxiv.org/abs/1912.10165).

- To bring supervised task data into the model, we can fine-tune with backpropagation and the cross-entropy loss of language modeling, treating the verbalizer word associated with the correct class as the correct prediction (a minimal sketch of this objective follows the list). [PET](https://arxiv.org/abs/2001.07676) uses this method, and it is also the objective of [T5](https://arxiv.org/abs/1910.10683), although T5 uses task prefixes to indicate the task rather than describing it with a natural-language prompt.

- Another approach is *priming*, in which several correct examples of the problem at hand are prepended to the input text. No backpropagation is involved, so the weights of the language model are never modified; instead, it relies on attention over the correct examples at inference time. [GPT-3](https://arxiv.org/abs/2005.14165) uses this method.

- Finally, PET additionally uses prompted models to predict soft labels (pseudo-labels) for unlabeled data, which then serve as labels to train a linear classification head.

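A minimal sketch of the cross-entropy objective from the second bullet (PyTorch is used here purely for illustration; the Space itself only visualizes pre-computed results, so none of this code appears in the repo): the logits at the masked position are restricted to the verbalizer words and trained with standard cross-entropy against the correct class.

```python
import torch
import torch.nn.functional as F

def prompt_finetuning_loss(mask_logits: torch.Tensor,
                           verbalizer_token_ids: torch.Tensor,
                           label: int) -> torch.Tensor:
    # mask_logits: (vocab_size,) logits at the [MASK] position.
    # Keep only the logits of the verbalizer words ("yes"/"no", etc.) and
    # apply cross-entropy with the word of the correct class as the target.
    class_logits = mask_logits[verbalizer_token_ids].unsqueeze(0)  # (1, num_classes)
    return F.cross_entropy(class_logits, torch.tensor([label]))

# toy example: 2 verbalizer words in a 10-word vocabulary
loss = prompt_finetuning_loss(torch.randn(10), torch.tensor([3, 7]), label=0)
```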
In this post, we want a fair comparison between the prompting and head approaches, so we fine-tune with backpropagation in both cases.

## How many data points is a prompt worth?

As we have seen, both the head and prompting approaches can be used for supervised fine-tuning on downstream tasks. The core difference is that, besides the labeled examples, the prompted model is given a sentence that roughly describes the task. In some sense this sentence is supervised data too, since it tells the model about the task, yet it is fundamentally different from standard supervised data in machine learning. How should we think about this new kind of supervision? And how do we quantify how "zero-shot" such a method really is?

We try to answer these questions by comparing the *head* and *prompting* approaches on the SuperGLUE tasks and MNLI. Concretely: for each task, we select subsets of growing size from the dataset, fine-tune [`RoBERTa-large`](https://arxiv.org/abs/1907.11692) on each subset with both methods while holding every other setting constant, and evaluate the resulting models. For fairness, we first tune the hyper-parameters of the baseline head models until they reach the performance of BERT++ on the SuperGLUE leaderboard, then reuse the same hyper-parameters in the corresponding *prompted* models.

The figure below plots, for each task <sup>2</sup>, final performance (the metric varies by task) against dataset size. With it we can compare how much data the two methods need to reach a given level of performance on a given task. We call this difference the *data advantage* of one method over the other at that level of performance, and the range of performance reached by both methods the *comparison window*. By integrating over that range we obtain the *average data advantage* of one method over the other on the task. Graphically, this is the area between the two curves divided by the height of the comparison window.<sup>3</sup>

![image](mockups/advantage.png)

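In the demo, this quantity is computed by `calculate_offset` in `demo_utils.py`. The following self-contained sketch (assuming strictly increasing performance curves) shows the same idea, using `np.interp` to invert each curve at a grid of performance levels inside the comparison window:

```python
import numpy as np

def average_data_advantage(train_sizes, head_perf, prompt_perf, n=1000):
    # Comparison window: performance levels reached by both curves.
    lo = max(min(head_perf), min(prompt_perf))
    hi = min(max(head_perf), max(prompt_perf))
    levels = np.linspace(lo, hi, n + 2)[1:-1]  # interior grid of the window
    # Invert performance(size) into size(performance) on each curve.
    head_x = np.interp(levels, head_perf, train_sizes)
    prompt_x = np.interp(levels, prompt_perf, train_sizes)
    # Positive result: the head needs more data, i.e. a prompting advantage.
    return float(np.mean(head_x - prompt_x))

sizes = [10, 50, 250, 1000]
print(average_data_advantage(sizes, [0.50, 0.60, 0.70, 0.80],
                             [0.60, 0.70, 0.80, 0.85]))  # > 0
```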
The table below summarizes the average data advantage of prompting over the head approach on each task, with error bounds obtained by bootstrapping: for each dataset size we perform 4 head runs and 4 prompting runs (i.e. 16 combinations per size) and compute the standard deviation of the results. Results vary a lot across tasks, and even across datasets of the same task, e.g. MNLI and RTE, which are both entailment datasets yet behave quite differently. Still, the overall trend is clear: except on WiC <sup>4</sup>, prompting has a significant advantage on every task. **The additional information provided by the prompt is consistently worth several hundred data points.**

| | MNLI | BoolQ | CB | COPA | MultiRC<sup>5</sup> | RTE | WiC | WSC |
|----------------|----------|--------|------|---------|----------|--------|---------|---------|
| Prompting vs head | 3506±536 | 752±46 | 90±2 | 288±242 | 384±378 | 282±34 | -424±74 | 281±137 |


## Patterns and verbalizers

#### Controlling for the verbalizer

Prompting is currently used mostly as a tool for zero-shot classification, a natural use case. In practice, however, zero-shot tends to be tricky, because the prompt and the verbalizer have to be aligned perfectly. Above, we have shown that prompting applies in a much wider range of settings, including the full-data regime. To separate the zero-shot character of prompting from its adaptability, we consider a *null verbalizer*, a control that is completely decorrelated from the task. For tasks that only require filling in one word (which excludes COPA and WSC), we replace the verbalizer mapping (e.g. "yes", "no", "maybe", "right" or "wrong") with a random one. The prompted model then becomes unusable without training data, just like a head model. We run the same advantage analysis for the null-verbalizer setup, with the curves below:

![image](mockups/nullverbalizer.png)

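Concretely, only the word-to-class mapping changes; a sketch (the random control words here are made up for illustration, the actual ones follow the paper's setup):

```python
informative_verbalizer = {"yes": True, "no": False}
# Null verbalizer: same pattern, but the mapped words carry no task semantics,
# so the mapping itself has to be learned from the training data.
null_verbalizer = {"plane": True, "grass": False}
```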
| | MNLI | BoolQ | CB | MultiRC<sup>5</sup> | RTE | WiC |
|----------------|----------|--------|------|----------|--------|---------|
| Prompting vs head | 3506±536 | 752±46 | 90±2 | 384±378 | 282±34 | -424±74 |
| Prompting vs null verbalizer | 150±252 | 299±81 | 78±2 | 74±56 | 404±68 | -354±166 |
| Null verbalizer vs head | 3355±612 | 453±90 | 12±1 | 309±320 | -122±62 | -70±160 |

These data advantages are noisier than those of direct prompting versus the head. However, we find that even with a null verbalizer the language model can adapt to the task: with only a few data points and a proper pattern, it performs comparably to or better than the head model. We can therefore conclude that the inductive bias of prompting is beneficial even without an informative verbalizer.

#### Influence of the pattern choice

Another factor that could make or break prompting in the zero-shot setting is the choice of the prompt pattern; here we look at how much it matters for us. We reuse the patterns from PET (two or three quite different ones per task) and run experiments with every pattern on every task, with results below. We find that the choice of prompt has no significant influence: its variance is smaller than the variance across random seeds.

![image](mockups/prompts.png)

## Final words

In this work, we studied a new fine-tuning method based on natural-language prompts, whose aim is to use the language-modeling ability of pretrained models explicitly, through word prediction, rather than implicitly, through a linear classifier on the model's intermediate representations. For a fair comparison, we framed the problem as fine-tuning prompt-based classifier language models with backpropagation, and found that prompting usually outperforms the standard approach of fine-tuning a linear classification head. We estimated this advantage in data points, as a measure of the additional information the human conveys through the prompt, and found that **writing a prompt is consistently worth hundreds of data points**. Moreover, the advantage persists even without the information carried by the verbalizer (i.e. with a null verbalizer), and the method is quite robust to the choice of prompt.

For practitioners, we believe prompt-based fine-tuning should, and will, become a standard tool: especially for small and medium-sized task-specific datasets, designing your own prompt is a small effort for a considerable data advantage. For researchers, we think many questions in this space remain open: why is the same prompt worth 3500 data points on MNLI but only 282 on RTE? How do prompts relate to standard ML supervision? Since they have some zero-shot character, do they react differently to adversarial or out-of-domain examples?

<sup>1</sup>: Or rather, at least as far as we know, no.

<sup>2</sup>: The sharp-eyed reader will notice that all these curves are monotonic. We performed 4 runs per experiment (i.e. for each task and each dataset size, we ran the head and prompting methods 4 times each and evaluated the resulting models). For clarity, and because fine-tuning occasionally fails for both methods, producing negative outliers, we report at each dataset size the maximum performance obtained at that size or any smaller one, which we call the *cumulative maximum* aggregation. Besides reducing variance, this does not change the reported data advantages much, and the interpretation of the figures holds even for non-monotonic curves.

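The *cumulative maximum* aggregation of footnote 2 is what the `"accmax"` branch of `reduct` in `demo_utils.py` computes; in NumPy it is just:

```python
import numpy as np

runs = np.array([[0.60, 0.55, 0.58],    # rows: dataset sizes, columns: seeds
                 [0.59, 0.65, np.nan],  # a failed fine-tuning run shows up as NaN
                 [0.63, 0.62, 0.61]])
best_per_size = np.nanmax(runs, axis=1)           # best seed at each size
monotonic = np.maximum.accumulate(best_per_size)  # best at this size or smaller
print(monotonic)  # [0.6  0.65 0.65]
```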
<sup>3</sup>: In computing the data advantage we give every data point of each metric the same weight; we could also re-parameterize the y-axis for each task. Whether this would help or hurt the data advantage of prompting over the head is unclear and depends on the dataset: for example, emphasizing gains close to convergence increases the prompting advantage on CB and MNLI but decreases it on COPA or BoolQ.

<sup>4</sup>: Interestingly, PET had already found prompting to be ineffective on this dataset.

<sup>5</sup>: The comparison window of MultiRC is too small, because the head baseline fails to learn anything beyond the majority class; we use the whole region to obtain a lower bound on the true result.

> Original English post: <url> https://huggingface.co/spaces/teven-projects/how_many_data_points </url>
> Author: Teven Le Scao
> Translator: Matrix Yao (姚伟峰), deep learning engineer at Intel, working on the application of transformer-family models to data of various modalities and on large-scale model training and inference.
naacl_demo/text.py
ADDED
@@ -0,0 +1,171 @@
text1 = """<h1 id="how-many-data-points-is-a-prompt-worth">How many data points is a prompt worth?</h1>
<img class='center' style='height: 5em; float: right;' src='https://raw.githubusercontent.com/TevenLeScao/transformer-xl/master/pytorch/assets/avatar_logo_joint.png' alt='avatar'>
<h4>Published on April 6, 2021</h4>
<h4>Teven Le Scao, researcher at Hugging Face • <a href="https://twitter.com/Fluke_Ellington">@Fluke_Ellington</a> </h4>
<p>The dominant approach in today's NLP applications is to fine-tune the classification head of a pretrained language model separately for each target task. As language models grow larger, a variety of alternatives have emerged to challenge the classifier-head method popularized by <a href="https://arxiv.org/abs/1810.04805">BERT</a>, <a href="https://arxiv.org/abs/1905.03197">UniLM</a> and <a href="https://openai.com/research/language-unsupervised">GPT</a>. In particular, GPT-3 popularized prompting: natural-language input that steers a pretrained language model to solve the task on its own, with no additional classification head.</p>
<p>Prompts are interesting because they let the user feed information to the model in a way that is very different from standard ML supervision. In our NAACL 2021 <a href="https://arxiv.org/abs/2103.08493">paper</a> with <a href="http://rush-nlp.com/">Sasha Rush</a>, we study prompt-based fine-tuning as an alternative to the current standard of supervised fine-tuning, and our experiments show that prompting is often beneficial, making it a promising method. In analyzing this advantage, we argue that the prompt brings extra information to the model, which led us to quantify that benefit in data points, that is: <strong>how many data points is a prompt worth?</strong> </p>

<h2 id="prompting">Prompting</h2>
<p>To adapt a pretrained language model to a specific task, the current mainstream approach replaces the model's last layer, the word-prediction layer, with a randomly-initialized linear classification head. The modified model is then trained on supervised task data with backpropagation, mainly to learn the weights of the new head, although the weights of the other layers may also be updated. We call this the <em>head</em> approach.</p>
<p>A competing approach is <em>prompting</em>: these methods try to use the original language model to "answer" the classification question by predicting the word corresponding to the target class, rather than "predicting" a class label in the traditional way. This lets us perform the classification task with the language model itself. Here, the <em>prompt</em> is the carefully designed input text that leads the model to generate the desired answer text.</p>
<p id="footnote1back">This may sound abstract, but it is exactly the very natural method humans use for textual reasoning in everyday life: school exercises, for example, tend to present a text input (e.g., an article about Mars) together with a question ("Is there life on Mars?") and expect a natural-language answer ("No"<a href="#footnote1"><sup>1</sup></a>) that maps to a class of the classification task (here, "no" maps to <code>False</code> and "yes" to <code>True</code>, making this a binary classification problem). In this paradigm, much as in a grammar exercise, we feed task-specific data to the model, and the model, like a student, fills in the blank in a fixed way. Prompting hopes to exploit the pretraining information contained in the language model explicitly, instead of only implicitly, through hidden representations fed to a linear classification head.</p>
<p>Here is an example from the <a href="https://arxiv.org/abs/1905.10044">BoolQ</a> task of <a href="https://arxiv.org/abs/1905.00537">SuperGLUE</a>, a yes/no task in which each example consists of a <span style="color: #0c593d">passage</span> and a corresponding <span style="color: #031154">question</span> whose answer is a boolean, either true or false. Each example can be assembled with a <span style="color: #910713"><strong>pattern</strong></span> into a text sequence with a single <span style="color: #ba9004"><strong>masked word</strong></span> to predict. Once the masked word is predicted, a pre-set <em>verbalizer</em> converts it into a class; that is, the <em>verbalizer</em> provides the mapping between output words and classes: the probabilities of the word being mapped to <em>yes</em> and to <em>no</em> are compared, and the final prediction is <code>True</code> if <em>yes</em> is more likely, and <code>False</code> otherwise.
</p>
"""

text2 = """<h2 id="fine-tuning">Fine-tuning</h2>
<p>This turns a generic language model into a task-specific classifier. There are many ways to use these prompt-based language-model classifiers: </p>
<ul>
<li>The language-modeling ability retained from pretraining lets them perform without any additional data, in contrast to <em>linear head</em> models, which <strong>start from a random initialization and therefore from random performance</strong>. Many papers have therefore used them for <a href="https://arxiv.org/abs/1912.10165">zero-shot classification</a>.</li>
<li>To bring supervised task data into the model, we can fine-tune with backpropagation and the cross-entropy loss of language modeling, treating the verbalizer word associated with the correct class as the correct prediction. <a href="https://arxiv.org/abs/2001.07676">PET</a> uses this method, and it is also the objective of <a href="https://arxiv.org/abs/1910.10683">T5</a>, although T5 uses task prefixes to indicate the task rather than describing it with a natural-language prompt.</li>
<li>Another approach is <em>priming</em>, in which several correct examples of the problem at hand are prepended to the input text. No backpropagation is involved, so the weights of the language model are never modified; instead, it relies on attention over the correct examples at inference time. <a href="https://arxiv.org/abs/2005.14165">GPT-3</a> uses this method.</li>
<li>Finally, PET additionally uses prompted models to predict soft labels (pseudo-labels) for unlabeled data, which then serve as labels to train a linear classification head.</li>
</ul>
<p>In this post, we want a fair comparison between the prompting and head approaches, so we fine-tune with backpropagation in both cases.</p>
<h2 id="how-many-data-points-is-a-prompt-worth-">How many data points is a prompt worth?</h2>
<p>As we have seen, both the head and prompting approaches can be used for supervised fine-tuning on downstream tasks. The core difference is that, besides the labeled examples, the prompted model is given a sentence that roughly describes the task. In some sense this sentence is supervised data too, since it tells the model about the task, yet it is fundamentally different from standard supervised data in machine learning. How should we think about this new kind of supervision? And how do we quantify how "zero-shot" such a method really is?</p>
<p>We try to answer these questions by comparing the <em>head</em> and <em>prompting</em> approaches on the SuperGLUE tasks and MNLI. Concretely: for each task, we select subsets of growing size from the dataset, fine-tune <a href="https://arxiv.org/abs/1907.11692"><code>RoBERTa-large</code></a> on each subset with both methods while holding every other setting constant, and evaluate the resulting models. For fairness, we first tune the hyper-parameters of the baseline head models until they reach the performance of BERT++ on the SuperGLUE leaderboard, then reuse the same hyper-parameters in the corresponding <em>prompted</em> models.</p>
<p id="footnote2back">The figure below plots, for each task <a href="#footnote2"><sup>2</sup></a>, final performance (the metric varies by task) against dataset size. With it we can compare how much data the two methods need to reach a given level of performance on a given task. We call this difference the <em>data advantage</em> of one method over the other at that level of performance, and the range of performance reached by both methods the <em>comparison window</em>. By integrating over that range we obtain the <em>average data advantage</em> of one method over the other on the task. Graphically, this is the area between the two curves divided by the height of the comparison window.<a href="#footnote3"><sup>3</sup></a> </p>
"""

text3 = """<html>
<head>
<style>
table, th, td {
  border: 1px solid black;
  border-collapse: collapse;
}
.styled-table {
    margin-left: auto;
    margin-right: auto;
}
.styled-table {
    border-collapse: collapse;
    font-size: 1em;
    font-family: sans-serif;
    min-width: 400px;
    box-shadow: 0 0 20px rgba(0, 0, 0, 0.15);
}
.styled-table thead tr {
    background-color: #ffebcd;
    color: #000000;
    text-align: left;
}
.styled-table th,
.styled-table td {
    padding: 6px 8px;
    font-size: 13px;
}
.styled-table tbody tr {
    border-bottom: 1px solid #dddddd;
}

.styled-table tbody tr:nth-of-type(even) {
    background-color: #f3f3f3;
}

.styled-table tbody tr:last-of-type {
    border-bottom: 2px solid #29004a;
}


</style>
</head>
<body>
<p id="footnote4back">The table below summarizes the average data advantage of prompting over the head approach on each task, with error bounds obtained by bootstrapping: for each dataset size we perform 4 head runs and 4 prompting runs (i.e. 16 combinations per size) and compute the standard deviation of the results. Results vary a lot across tasks, and even across datasets of the same task, e.g. MNLI and RTE, which are both entailment datasets yet behave quite differently. Still, the overall trend is clear: except on WiC <a href="#footnote4"><sup>4</sup></a>, prompting has a significant advantage on every task. <strong>The additional information provided by the prompt is consistently worth several hundred data points.</strong></p>
<table id="footnote5back" class="styled-table">
<thead>
<tr>
<th></th>
<th><a href="https://arxiv.org/abs/1704.05426">MNLI</a></th>
<th><a href="https://arxiv.org/abs/1905.10044">BoolQ</a></th>
<th><a href="https://ojs.ub.uni-konstanz.de/sub/index.php/sub/article/view/601">CB</a></th>
<th><a href="https://people.ict.usc.edu/~gordon/publications/AAAI-SPRING11A.PDF">COPA</a></th>
<th><a href="https://www.aclweb.org/anthology/N18-1023/">MultiRC</a><sup><a href="#footnote5">5</a></sup></th>
<th><a href="https://link.springer.com/chapter/10.1007/978-94-024-0881-2_42">RTE</a></th>
<th><a href="https://arxiv.org/abs/1808.09121">WiC</a></th>
<th><a href="https://arxiv.org/abs/1808.09121">WSC</a></th>
</tr>
</thead>
<tbody>
<tr>
<td>Prompting vs head</td>
<td>3506±536</td>
<td>752±46</td>
<td>90±2</td>
<td>288±242</td>
<td>384±378</td>
<td>282±34</td>
<td>-424±74</td>
<td>281±137</td>
</tr>
</tbody>
</table>
<h2 id="patterns-and-verbalizers">Patterns and verbalizers</h2>
<h4 id="control-verbalizers">Controlling for the verbalizer</h4>
<p>Prompting is currently used mostly as a tool for zero-shot classification, a natural use case. In practice, however, zero-shot tends to be tricky, because the prompt and the verbalizer have to be aligned perfectly. Above, we have shown that prompting applies in a much wider range of settings, including the full-data regime. To separate the zero-shot character of prompting from its adaptability, we consider a <em>null verbalizer</em>, a control that is completely decorrelated from the task. For tasks that only require filling in one word (which excludes COPA and WSC), we replace the verbalizer mapping (e.g. "yes", "no", "maybe", "right" or "wrong") with a random one. The prompted model then becomes unusable without training data, just like a head model. We run the same advantage analysis for the null-verbalizer setup, with the curves below:</p>
</body>
</html>
"""

text4 = """<table id="footnote6back" class="styled-table">
<thead>
<tr>
<th></th>
<th>MNLI</th>
<th>BoolQ</th>
<th>CB</th>
<th>MultiRC<a href="#footnote5"><sup>6</sup></a></th>
<th>RTE</th>
<th>WiC</th>
</tr>
</thead>
<tbody>
<tr>
<td>Prompting vs head</td>
<td>3506±536</td>
<td>752±46</td>
<td>90±2</td>
<td>384±378</td>
<td>282±34</td>
<td>-424±74</td>
</tr>
<tr>
<td>Prompting vs null verbalizer</td>
<td>150±252</td>
<td>299±81</td>
<td>78±2</td>
<td>74±56</td>
<td>404±68</td>
<td>-354±166</td>
</tr>
<tr>
<td>Null verbalizer vs head</td>
<td>3355±612</td>
<td>453±90</td>
<td>12±1</td>
<td>309±320</td>
<td>-122±62</td>
<td>-70±160</td>
</tr>
</tbody>
</table>
<p>These data advantages are noisier than those of direct prompting versus the head. However, we find that even with a null verbalizer the language model can adapt to the task: with only a few data points and a proper pattern, it performs comparably to or better than the head model. We can therefore conclude that the inductive bias of prompting is beneficial even without an informative verbalizer.</p>
<h4 id="influence-of-the-pattern-choice">Influence of the pattern choice</h4>
<p>Another factor that could make or break prompting in the zero-shot setting is the choice of the prompt pattern; here we look at how much it matters for us. We reuse the patterns from PET (two or three quite different ones per task) and run experiments with every pattern on every task, with results below. We find that the choice of prompt has no significant influence: its variance is smaller than the variance across random seeds.</p>
"""

text5 = """<h2 id="mot-de-la-fin">Final words</h2>
<p>In this work, we studied a new fine-tuning method based on natural-language prompts, whose aim is to use the language-modeling ability of pretrained models explicitly, through word prediction, rather than implicitly, through a linear classifier on the model's intermediate representations. For a fair comparison, we framed the problem as fine-tuning prompt-based classifier language models with backpropagation, and found that prompting usually outperforms the standard approach of fine-tuning a linear classification head. We estimated this advantage in data points, as a measure of the additional information the human conveys through the prompt, and found that <strong>writing a prompt is consistently worth hundreds of data points</strong>. Moreover, the advantage persists even without the information carried by the verbalizer (i.e. with a null verbalizer), and the method is quite robust to the choice of prompt.</p>
<p>For practitioners, we believe prompt-based fine-tuning should, and will, become a standard tool: especially for small and medium-sized task-specific datasets, designing your own prompt is a small effort for a considerable data advantage. For researchers, we think many questions in this space remain open: why is the same prompt worth 3500 data points on MNLI but only 282 on RTE? How do prompts relate to standard ML supervision? Since they have some zero-shot character, do they react differently to adversarial or out-of-domain examples?</p>
<p id="footnote1"><sup><a href="#footnote1back">1</a></sup>: Or rather, at least as far as we know, no.</p>
<p id="footnote2"><sup><a href="#footnote2back">2</a></sup>: The sharp-eyed reader will notice that all these curves are monotonic. We performed 4 runs per experiment (i.e. for each task and each dataset size, we ran the head and prompting methods 4 times each and evaluated the resulting models). For clarity, and because fine-tuning occasionally fails for both methods, producing negative outliers, we report at each dataset size the maximum performance obtained at that size or any smaller one, which we call the <em>cumulative maximum</em> aggregation. Besides reducing variance, this does not change the reported data advantages much, and the interpretation of the figures holds even for non-monotonic curves. </p>
<p id="footnote3"><sup><a href="#footnote2back">3</a></sup>: In computing the data advantage we give every data point of each metric the same weight; we could also re-parameterize the y-axis for each task. Whether this would help or hurt the data advantage of prompting over the head is unclear and depends on the dataset: for example, emphasizing gains close to convergence increases the prompting advantage on CB and MNLI but decreases it on COPA or BoolQ. </p>
<p id="footnote4"><sup><a href="#footnote4back">4</a></sup>: Interestingly, PET had already found prompting to be ineffective on this dataset.</p>
<p id="footnote5"><sup><a href="#footnote5back">5</a> <a href="#footnote6back">6</a></sup>: The comparison window of MultiRC is too small, because the head baseline fails to learn anything beyond the majority class; we use the whole region to obtain a lower bound on the true result.</p>
"""

initial_passage = "In informal games, it is customary to announce 'check' when making a move that puts the opponent's king in check. In formal competitions, however, check is rarely announced."

initial_question = "do you always have to say check in chess?"
requirements.txt
ADDED
@@ -0,0 +1,22 @@
bokeh==2.3.0
cycler==0.10.0
Jinja2==2.11.2
kiwisolver==1.3.1
MarkupSafe==1.1.1
matplotlib==3.4.1
numpy==1.18.4
packaging==20.4
pandas==1.0.3
Pillow==7.1.2
pyparsing==2.4.7
python-dateutil==2.8.1
pytz==2020.1
PyYAML==5.3.1
randomcolor==0.4.4.5
scipy==1.4.1
seaborn==0.11.1
Shapely==1.7.1
six==1.15.0
tornado==6.0.4
typing-extensions==3.7.4.2
virtualenv-clone==0.5.4