Updated
Browse files- app.py +309 -29
- contents.py +53 -0
- requirements.txt +3 -1
- style.py +19 -0
- utils.py +110 -0
app.py
CHANGED
@@ -1,39 +1,319 @@
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
|
|
3 |
from inseq.commands.attribute_context.attribute_context import (
|
4 |
AttributeContextArgs,
|
5 |
attribute_context,
|
6 |
-
visualize_attribute_context,
|
7 |
)
|
8 |
|
9 |
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
show_viz=False,
|
19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
)
|
21 |
-
out = attribute_context(
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
)
|
39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
|
4 |
import gradio as gr
|
5 |
+
import spaces
|
6 |
+
from contents import (
|
7 |
+
citation,
|
8 |
+
description,
|
9 |
+
examples,
|
10 |
+
how_it_works,
|
11 |
+
how_to_use,
|
12 |
+
subtitle,
|
13 |
+
title,
|
14 |
+
)
|
15 |
+
from gradio_highlightedtextbox import HighlightedTextbox
|
16 |
+
from style import custom_css
|
17 |
+
from utils import get_tuples_from_output
|
18 |
|
19 |
+
from inseq import list_feature_attribution_methods, list_step_functions
|
20 |
from inseq.commands.attribute_context.attribute_context import (
|
21 |
AttributeContextArgs,
|
22 |
attribute_context,
|
|
|
23 |
)
|
24 |
|
25 |
|
26 |
+
@spaces.GPU()
|
27 |
+
def pecore(
|
28 |
+
input_current_text: str,
|
29 |
+
input_context_text: str,
|
30 |
+
output_current_text: str,
|
31 |
+
output_context_text: str,
|
32 |
+
model_name_or_path: str,
|
33 |
+
attribution_method: str,
|
34 |
+
attributed_fn: str | None,
|
35 |
+
context_sensitivity_metric: str,
|
36 |
+
context_sensitivity_std_threshold: float,
|
37 |
+
context_sensitivity_topk: int,
|
38 |
+
attribution_std_threshold: float,
|
39 |
+
attribution_topk: int,
|
40 |
+
input_template: str,
|
41 |
+
input_current_text_template: str,
|
42 |
+
output_template: str,
|
43 |
+
special_tokens_to_keep: str | list[str] | None,
|
44 |
+
model_kwargs: str,
|
45 |
+
tokenizer_kwargs: str,
|
46 |
+
generation_kwargs: str,
|
47 |
+
attribution_kwargs: str,
|
48 |
+
):
|
49 |
+
formatted_input_current_text = input_current_text_template.format(
|
50 |
+
current=input_current_text
|
51 |
+
)
|
52 |
+
pecore_args = AttributeContextArgs(
|
53 |
+
show_intermediate_outputs=False,
|
54 |
+
save_path=os.path.join(os.path.dirname(__file__), "outputs/output.json"),
|
55 |
+
add_output_info=True,
|
56 |
+
viz_path=os.path.join(os.path.dirname(__file__), "outputs/output.html"),
|
57 |
show_viz=False,
|
58 |
+
model_name_or_path=model_name_or_path,
|
59 |
+
attribution_method=attribution_method,
|
60 |
+
attributed_fn=attributed_fn,
|
61 |
+
attribution_selectors=None,
|
62 |
+
attribution_aggregators=None,
|
63 |
+
normalize_attributions=True,
|
64 |
+
model_kwargs=json.loads(model_kwargs),
|
65 |
+
tokenizer_kwargs=json.loads(tokenizer_kwargs),
|
66 |
+
generation_kwargs=json.loads(generation_kwargs),
|
67 |
+
attribution_kwargs=json.loads(attribution_kwargs),
|
68 |
+
context_sensitivity_metric=context_sensitivity_metric,
|
69 |
+
align_output_context_auto=False,
|
70 |
+
prompt_user_for_contextless_output_next_tokens=False,
|
71 |
+
special_tokens_to_keep=special_tokens_to_keep,
|
72 |
+
context_sensitivity_std_threshold=context_sensitivity_std_threshold,
|
73 |
+
context_sensitivity_topk=context_sensitivity_topk
|
74 |
+
if context_sensitivity_topk > 0
|
75 |
+
else None,
|
76 |
+
attribution_std_threshold=attribution_std_threshold,
|
77 |
+
attribution_topk=attribution_topk if attribution_topk > 0 else None,
|
78 |
+
input_current_text=formatted_input_current_text,
|
79 |
+
input_context_text=input_context_text if input_context_text else None,
|
80 |
+
input_template=input_template,
|
81 |
+
output_current_text=output_current_text if output_current_text else None,
|
82 |
+
output_context_text=output_context_text if output_context_text else None,
|
83 |
+
output_template=output_template,
|
84 |
)
|
85 |
+
out = attribute_context(pecore_args)
|
86 |
+
return get_tuples_from_output(out), gr.Button(visible=True), gr.Button(visible=True)
|
87 |
+
|
88 |
+
|
89 |
+
with gr.Blocks(css=custom_css) as demo:
|
90 |
+
gr.Markdown(title)
|
91 |
+
gr.Markdown(subtitle)
|
92 |
+
gr.Markdown(description)
|
93 |
+
with gr.Tab("π Attributing Context"):
|
94 |
+
with gr.Row():
|
95 |
+
with gr.Column():
|
96 |
+
input_current_text = gr.Textbox(
|
97 |
+
label="Input query", placeholder="Your input query..."
|
98 |
+
)
|
99 |
+
input_context_text = gr.Textbox(
|
100 |
+
label="Input context", lines=4, placeholder="Your input context..."
|
101 |
+
)
|
102 |
+
attribute_input_button = gr.Button("Submit", variant="primary")
|
103 |
+
with gr.Column():
|
104 |
+
pecore_output_highlights = HighlightedTextbox(
|
105 |
+
value=[
|
106 |
+
("This output will contain ", None),
|
107 |
+
("context sensitive", "Context sensitive"),
|
108 |
+
(" generated tokens and ", None),
|
109 |
+
("influential context", "Influential context"),
|
110 |
+
(" tokens.", None),
|
111 |
+
],
|
112 |
+
color_map={
|
113 |
+
"Context sensitive": "green",
|
114 |
+
"Influential context": "blue",
|
115 |
+
},
|
116 |
+
show_legend=True,
|
117 |
+
label="PECoRe Output",
|
118 |
+
combine_adjacent=True,
|
119 |
+
interactive=False,
|
120 |
+
)
|
121 |
+
with gr.Row(equal_height=True):
|
122 |
+
download_output_file_button = gr.Button(
|
123 |
+
"β Download output",
|
124 |
+
visible=False,
|
125 |
+
link=os.path.join(
|
126 |
+
os.path.dirname(__file__), "/file=outputs/output.json"
|
127 |
+
),
|
128 |
+
)
|
129 |
+
download_output_html_button = gr.Button(
|
130 |
+
"π Download HTML",
|
131 |
+
visible=False,
|
132 |
+
link=os.path.join(
|
133 |
+
os.path.dirname(__file__), "/file=outputs/output.html"
|
134 |
+
),
|
135 |
+
)
|
136 |
+
|
137 |
+
attribute_input_examples = gr.Examples(
|
138 |
+
examples,
|
139 |
+
inputs=[input_current_text, input_context_text],
|
140 |
+
outputs=pecore_output_highlights,
|
141 |
+
)
|
142 |
+
with gr.Tab("βοΈ Parameters"):
|
143 |
+
gr.Markdown("## βοΈ PECoRe Parameters")
|
144 |
+
with gr.Row(equal_height=True):
|
145 |
+
model_name_or_path = gr.Textbox(
|
146 |
+
value="gsarti/cora_mgen",
|
147 |
+
label="Model",
|
148 |
+
info="Hugging Face Hub identifier of the model to analyze with PECoRe.",
|
149 |
+
interactive=True,
|
150 |
+
)
|
151 |
+
context_sensitivity_metric = gr.Dropdown(
|
152 |
+
value="kl_divergence",
|
153 |
+
label="Context sensitivity metric",
|
154 |
+
info="Metric to use to measure context sensitivity of generated tokens.",
|
155 |
+
choices=list_step_functions(),
|
156 |
+
interactive=True,
|
157 |
+
)
|
158 |
+
attribution_method = gr.Dropdown(
|
159 |
+
value="saliency",
|
160 |
+
label="Attribution method",
|
161 |
+
info="Attribution method identifier to identify relevant context tokens.",
|
162 |
+
choices=list_feature_attribution_methods(),
|
163 |
+
interactive=True,
|
164 |
+
)
|
165 |
+
attributed_fn = gr.Dropdown(
|
166 |
+
value="contrast_prob_diff",
|
167 |
+
label="Attributed function",
|
168 |
+
info="Function of model logits to use as target for the attribution method.",
|
169 |
+
choices=list_step_functions(),
|
170 |
+
interactive=True,
|
171 |
+
)
|
172 |
+
gr.Markdown("#### Results Selection Parameters")
|
173 |
+
with gr.Row(equal_height=True):
|
174 |
+
context_sensitivity_std_threshold = gr.Number(
|
175 |
+
value=1.0,
|
176 |
+
label="Context sensitivity threshold",
|
177 |
+
info="Select N to keep context sensitive tokens with scores above N * std. 0 = above mean.",
|
178 |
+
precision=1,
|
179 |
+
minimum=0.0,
|
180 |
+
maximum=5.0,
|
181 |
+
step=0.5,
|
182 |
+
interactive=True,
|
183 |
+
)
|
184 |
+
context_sensitivity_topk = gr.Number(
|
185 |
+
value=0,
|
186 |
+
label="Context sensitivity top-k",
|
187 |
+
info="Select N to keep top N context sensitive tokens. 0 = keep all.",
|
188 |
+
interactive=True,
|
189 |
+
precision=0,
|
190 |
+
minimum=0,
|
191 |
+
maximum=10,
|
192 |
+
)
|
193 |
+
attribution_std_threshold = gr.Number(
|
194 |
+
value=1.0,
|
195 |
+
label="Attribution threshold",
|
196 |
+
info="Select N to keep attributed tokens with scores above N * std. 0 = above mean.",
|
197 |
+
precision=1,
|
198 |
+
minimum=0.0,
|
199 |
+
maximum=5.0,
|
200 |
+
step=0.5,
|
201 |
+
interactive=True,
|
202 |
+
)
|
203 |
+
attribution_topk = gr.Number(
|
204 |
+
value=0,
|
205 |
+
label="Attribution top-k",
|
206 |
+
info="Select N to keep top N attributed tokens in the context. 0 = keep all.",
|
207 |
+
interactive=True,
|
208 |
+
precision=0,
|
209 |
+
minimum=0,
|
210 |
+
maximum=50,
|
211 |
+
)
|
212 |
+
|
213 |
+
gr.Markdown("#### Text Format Parameters")
|
214 |
+
with gr.Row(equal_height=True):
|
215 |
+
input_template = gr.Textbox(
|
216 |
+
value="{current} <P>:{context}",
|
217 |
+
label="Input template",
|
218 |
+
info="Template to format the input for the model. Use {current} and {context} placeholders.",
|
219 |
+
interactive=True,
|
220 |
+
)
|
221 |
+
output_template = gr.Textbox(
|
222 |
+
value="{current}",
|
223 |
+
label="Output template",
|
224 |
+
info="Template to format the output from the model. Use {current} and {context} placeholders.",
|
225 |
+
interactive=True,
|
226 |
+
)
|
227 |
+
input_current_text_template = gr.Textbox(
|
228 |
+
value="<Q>:{current}",
|
229 |
+
label="Input current text template",
|
230 |
+
info="Template to format the input query for the model. Use {current} placeholder.",
|
231 |
+
interactive=True,
|
232 |
+
)
|
233 |
+
special_tokens_to_keep = gr.Dropdown(
|
234 |
+
label="Special tokens to keep",
|
235 |
+
info="Special tokens to keep in the attribution. If empty, all special tokens are ignored.",
|
236 |
+
value=None,
|
237 |
+
multiselect=True,
|
238 |
+
allow_custom_value=True,
|
239 |
+
)
|
240 |
+
|
241 |
+
gr.Markdown("## βοΈ Generation Parameters")
|
242 |
+
with gr.Row(equal_height=True):
|
243 |
+
output_current_text = gr.Textbox(
|
244 |
+
label="Generation output",
|
245 |
+
info="Specifies an output to force-decoded during generation. If blank, the model will generate freely.",
|
246 |
+
interactive=True,
|
247 |
+
)
|
248 |
+
output_context_text = gr.Textbox(
|
249 |
+
label="Generation context",
|
250 |
+
info="If specified, this context is used as starting point for generation. Useful for e.g. chain-of-thought reasoning.",
|
251 |
+
interactive=True,
|
252 |
+
)
|
253 |
+
generation_kwargs = gr.Code(
|
254 |
+
value="{}",
|
255 |
+
language="json",
|
256 |
+
label="Generation kwargs",
|
257 |
+
interactive=True,
|
258 |
+
lines=1,
|
259 |
+
)
|
260 |
+
gr.Markdown("## βοΈ Other Parameters")
|
261 |
+
with gr.Row(equal_height=True):
|
262 |
+
model_kwargs = gr.Code(
|
263 |
+
value="{}",
|
264 |
+
language="json",
|
265 |
+
label="Model kwargs",
|
266 |
+
interactive=True,
|
267 |
+
lines=1,
|
268 |
+
)
|
269 |
+
tokenizer_kwargs = gr.Code(
|
270 |
+
value="{}",
|
271 |
+
language="json",
|
272 |
+
label="Tokenizer kwargs",
|
273 |
+
interactive=True,
|
274 |
+
lines=1,
|
275 |
+
)
|
276 |
+
attribution_kwargs = gr.Code(
|
277 |
+
value="{}",
|
278 |
+
language="json",
|
279 |
+
label="Attribution kwargs",
|
280 |
+
interactive=True,
|
281 |
+
lines=1,
|
282 |
+
)
|
283 |
+
|
284 |
+
gr.Markdown(how_it_works)
|
285 |
+
gr.Markdown(how_to_use)
|
286 |
+
gr.Markdown(citation)
|
287 |
+
|
288 |
+
attribute_input_button.click(
|
289 |
+
pecore,
|
290 |
+
inputs=[
|
291 |
+
input_current_text,
|
292 |
+
input_context_text,
|
293 |
+
output_current_text,
|
294 |
+
output_context_text,
|
295 |
+
model_name_or_path,
|
296 |
+
attribution_method,
|
297 |
+
attributed_fn,
|
298 |
+
context_sensitivity_metric,
|
299 |
+
context_sensitivity_std_threshold,
|
300 |
+
context_sensitivity_topk,
|
301 |
+
attribution_std_threshold,
|
302 |
+
attribution_topk,
|
303 |
+
input_template,
|
304 |
+
input_current_text_template,
|
305 |
+
output_template,
|
306 |
+
special_tokens_to_keep,
|
307 |
+
model_kwargs,
|
308 |
+
tokenizer_kwargs,
|
309 |
+
generation_kwargs,
|
310 |
+
attribution_kwargs,
|
311 |
+
],
|
312 |
+
outputs=[
|
313 |
+
pecore_output_highlights,
|
314 |
+
download_output_file_button,
|
315 |
+
download_output_html_button,
|
316 |
+
],
|
317 |
+
)
|
318 |
+
|
319 |
+
demo.launch(allowed_paths=["outputs/"])
|
contents.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
title = "<h1 class='demo-title'>π Plausibility Evaluation of Context Reliance (PECoRe) π</h1>"
|
2 |
+
|
3 |
+
subtitle = "<h2 class='demo-subtitle'>An Interpretability Framework to Detect and Attribute Context Reliance in Language Models</h2>"
|
4 |
+
|
5 |
+
description = """
|
6 |
+
Given a query and a context passed as inputs to a LM, PECoRe will identify which tokens in the generated
|
7 |
+
response were dependant on context, and match them with context tokens contributing to their prediction.
|
8 |
+
For more information, check out our <a href="https://openreview.net/forum?id=XTHfNGI3zT" target='_blank'>ICLR 2024 paper</a>.
|
9 |
+
"""
|
10 |
+
|
11 |
+
how_it_works = r"""
|
12 |
+
<details>
|
13 |
+
<summary><h3 class="summary-label">βοΈ How Does It Work?</h3></summary>
|
14 |
+
<br/>
|
15 |
+
PECoRe uses a contrastive approach to attribute context reliance in language models.
|
16 |
+
It compares the model's predictions when the context is present and when it is absent, and attributes the difference in predictions to the context tokens.
|
17 |
+
</details>
|
18 |
+
"""
|
19 |
+
|
20 |
+
how_to_use = r"""
|
21 |
+
<details>
|
22 |
+
<summary><h3 class="summary-label">π§ How to Use PECoRe</h3></summary>
|
23 |
+
|
24 |
+
</details>
|
25 |
+
"""
|
26 |
+
|
27 |
+
citation = r"""
|
28 |
+
<details>
|
29 |
+
<summary><h3 class="summary-label">π Citing PECoRe</h3></summary>
|
30 |
+
|
31 |
+
@inproceedings{sarti-etal-2023-quantifying,
|
32 |
+
title = "Quantifying the Plausibility of Context Reliance in Neural Machine Translation",
|
33 |
+
author = "Sarti, Gabriele and
|
34 |
+
Chrupa{\l}a, Grzegorz and
|
35 |
+
Nissim, Malvina and
|
36 |
+
Bisazza, Arianna",
|
37 |
+
booktitle = "The Twelfth International Conference on Learning Representations (ICLR 2024)",
|
38 |
+
month = may,
|
39 |
+
year = "2024",
|
40 |
+
address = "Vienna, Austria",
|
41 |
+
publisher = "OpenReview",
|
42 |
+
url = "https://openreview.net/forum?id=XTHfNGI3zT"
|
43 |
+
}
|
44 |
+
|
45 |
+
</details>
|
46 |
+
"""
|
47 |
+
|
48 |
+
examples = [
|
49 |
+
[
|
50 |
+
"When was Banff National Park established?",
|
51 |
+
"Banff National Park is Canada's oldest national park, established in 1885 as Rocky Mountains Park. Located in Alberta's Rocky Mountains, 110β180 kilometres (68β112 mi) west of Calgary, Banff encompasses 6,641 square kilometres (2,564 sq mi) of mountainous terrain.",
|
52 |
+
]
|
53 |
+
]
|
requirements.txt
CHANGED
@@ -1 +1,3 @@
|
|
1 |
-
|
|
|
|
|
|
1 |
+
spaces
|
2 |
+
git+https://github.com/inseq-team/inseq.git@main
|
3 |
+
gradio_highlightedtextbox
|
style.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
custom_css = """
|
2 |
+
.demo-title {
|
3 |
+
text-align: center;
|
4 |
+
display: block;
|
5 |
+
margin-bottom: 0;
|
6 |
+
font-size: 2em;
|
7 |
+
}
|
8 |
+
|
9 |
+
.demo-subtitle {
|
10 |
+
text-align: center;
|
11 |
+
display: block;
|
12 |
+
margin-top: 0;
|
13 |
+
font-size: 1.5em;
|
14 |
+
}
|
15 |
+
|
16 |
+
.summary-label {
|
17 |
+
display: inline;
|
18 |
+
}
|
19 |
+
"""
|
utils.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from copy import deepcopy
|
2 |
+
from typing import Optional
|
3 |
+
|
4 |
+
from inseq import load_model
|
5 |
+
from inseq.commands.attribute_context.attribute_context_args import AttributeContextArgs
|
6 |
+
from inseq.commands.attribute_context.attribute_context_helpers import (
|
7 |
+
AttributeContextOutput,
|
8 |
+
filter_rank_tokens,
|
9 |
+
get_filtered_tokens,
|
10 |
+
)
|
11 |
+
from inseq.models import HuggingfaceModel
|
12 |
+
|
13 |
+
|
14 |
+
def get_formatted_attribute_context_results(
|
15 |
+
model: HuggingfaceModel,
|
16 |
+
args: AttributeContextArgs,
|
17 |
+
output: AttributeContextOutput,
|
18 |
+
) -> str:
|
19 |
+
"""Format the results of the context attribution process."""
|
20 |
+
|
21 |
+
def format_context_comment(
|
22 |
+
model: HuggingfaceModel,
|
23 |
+
has_other_context: bool,
|
24 |
+
special_tokens_to_keep: list[str],
|
25 |
+
context: str,
|
26 |
+
context_scores: list[float],
|
27 |
+
other_context_scores: Optional[list[float]] = None,
|
28 |
+
is_target: bool = False,
|
29 |
+
) -> str:
|
30 |
+
context_tokens = get_filtered_tokens(
|
31 |
+
context,
|
32 |
+
model,
|
33 |
+
special_tokens_to_keep,
|
34 |
+
replace_special_characters=True,
|
35 |
+
is_target=is_target,
|
36 |
+
)
|
37 |
+
context_token_tuples = [(t, None) for t in context_tokens]
|
38 |
+
scores = context_scores
|
39 |
+
if has_other_context:
|
40 |
+
scores += other_context_scores
|
41 |
+
context_ranked_tokens, _ = filter_rank_tokens(
|
42 |
+
tokens=context_tokens,
|
43 |
+
scores=scores,
|
44 |
+
std_threshold=args.attribution_std_threshold,
|
45 |
+
topk=args.attribution_topk,
|
46 |
+
)
|
47 |
+
for idx, _, tok in context_ranked_tokens:
|
48 |
+
context_token_tuples[idx] = (tok, "Influential context")
|
49 |
+
return context_token_tuples
|
50 |
+
|
51 |
+
out = []
|
52 |
+
output_current_tokens = get_filtered_tokens(
|
53 |
+
output.output_current,
|
54 |
+
model,
|
55 |
+
args.special_tokens_to_keep,
|
56 |
+
replace_special_characters=True,
|
57 |
+
is_target=True,
|
58 |
+
)
|
59 |
+
for example_idx, cci_out in enumerate(output.cci_scores, start=1):
|
60 |
+
curr_output_tokens = [(t, None) for t in output_current_tokens]
|
61 |
+
cti_idx = cci_out.cti_idx
|
62 |
+
curr_output_tokens[cti_idx] = (
|
63 |
+
curr_output_tokens[cti_idx][0],
|
64 |
+
"Context sensitive",
|
65 |
+
)
|
66 |
+
if args.has_input_context:
|
67 |
+
input_context_tokens = format_context_comment(
|
68 |
+
model,
|
69 |
+
args.has_output_context,
|
70 |
+
args.special_tokens_to_keep,
|
71 |
+
output.input_context,
|
72 |
+
cci_out.input_context_scores,
|
73 |
+
cci_out.output_context_scores,
|
74 |
+
)
|
75 |
+
if args.has_output_context:
|
76 |
+
output_context_tokens = format_context_comment(
|
77 |
+
model,
|
78 |
+
args.has_input_context,
|
79 |
+
args.special_tokens_to_keep,
|
80 |
+
output.output_context,
|
81 |
+
cci_out.output_context_scores,
|
82 |
+
cci_out.input_context_scores,
|
83 |
+
is_target=True,
|
84 |
+
context_type="Output",
|
85 |
+
)
|
86 |
+
out += [
|
87 |
+
("\n\n" if example_idx > 1 else "", None),
|
88 |
+
(
|
89 |
+
f"#{example_idx}.\nGenerated output:\t",
|
90 |
+
None,
|
91 |
+
),
|
92 |
+
]
|
93 |
+
out += curr_output_tokens
|
94 |
+
if args.has_input_context:
|
95 |
+
out += [("\nInput context:\t", None)]
|
96 |
+
out += input_context_tokens
|
97 |
+
if args.has_output_context:
|
98 |
+
out += [("\\Output context:\t", None)]
|
99 |
+
out += output_context_tokens
|
100 |
+
return out
|
101 |
+
|
102 |
+
|
103 |
+
def get_tuples_from_output(output: AttributeContextOutput):
|
104 |
+
model = load_model(
|
105 |
+
output.info.model_name_or_path,
|
106 |
+
output.info.attribution_method,
|
107 |
+
model_kwargs=deepcopy(output.info.model_kwargs),
|
108 |
+
tokenizer_kwargs=deepcopy(output.info.tokenizer_kwargs),
|
109 |
+
)
|
110 |
+
return get_formatted_attribute_context_results(model, output.info, output)
|