from typing import Optional
from inseq.commands.attribute_context.attribute_context_args import AttributeContextArgs
from inseq.commands.attribute_context.attribute_context_helpers import (
AttributeContextOutput,
filter_rank_tokens,
get_filtered_tokens,
)
from inseq.models import HuggingfaceModel
def get_formatted_attribute_context_results(
    model: HuggingfaceModel,
    args: AttributeContextArgs,
    output: AttributeContextOutput,
) -> list[tuple[str, Optional[str]]]:
    """Format the results of the context attribution process.

    Builds a flat list of ``(text, label)`` tuples suitable for labeled
    rendering: ``label`` is ``None`` for plain text, ``"Context sensitive"``
    for the attributed output token of each example, and
    ``"Influential context"`` for top-ranked context tokens.

    Args:
        model: Model used to tokenize/filter the generated text.
        args: CLI arguments controlling filtering thresholds and which
            contexts (input/output) are present.
        output: Attribution output holding per-example CCI scores.

    Returns:
        A list of ``(text, label)`` tuples covering all examples.
    """

    def format_context_comment(
        model: HuggingfaceModel,
        has_other_context: bool,
        special_tokens_to_keep: list[str],
        context: str,
        context_scores: list[float],
        other_context_scores: Optional[list[float]] = None,
        is_target: bool = False,
    ) -> list[tuple[str, Optional[str]]]:
        """Tokenize ``context`` and tag its most influential tokens.

        When ``has_other_context`` is set, the other context's scores are
        appended so that ranking thresholds are computed over the combined
        score distribution.
        """
        context_tokens = get_filtered_tokens(
            context,
            model,
            special_tokens_to_keep,
            replace_special_characters=True,
            is_target=is_target,
        )
        context_token_tuples = [(t, None) for t in context_tokens]
        # Copy before extending: the original ``scores = context_scores``
        # followed by ``scores += other_context_scores`` mutated the caller's
        # list in place, corrupting ``cci_out.*_context_scores`` for the
        # second format_context_comment call of the same example.
        scores = list(context_scores)
        if has_other_context:
            scores += other_context_scores
        context_ranked_tokens, _ = filter_rank_tokens(
            tokens=context_tokens,
            scores=scores,
            std_threshold=args.attribution_std_threshold,
            topk=args.attribution_topk,
        )
        # Only tokens surviving the rank filter get a label; indices refer
        # to positions in the (unconcatenated) context token list.
        for idx, _, tok in context_ranked_tokens:
            context_token_tuples[idx] = (tok, "Influential context")
        return context_token_tuples

    out = []
    output_current_tokens = get_filtered_tokens(
        output.output_current,
        model,
        args.special_tokens_to_keep,
        replace_special_characters=True,
        is_target=True,
    )
    for example_idx, cci_out in enumerate(output.cci_scores, start=1):
        # Fresh copy per example: each example highlights a different
        # context-sensitive token index (cti_idx).
        curr_output_tokens = [(t, None) for t in output_current_tokens]
        cti_idx = cci_out.cti_idx
        curr_output_tokens[cti_idx] = (
            curr_output_tokens[cti_idx][0],
            "Context sensitive",
        )
        if args.has_input_context:
            input_context_tokens = format_context_comment(
                model,
                args.has_output_context,
                args.special_tokens_to_keep,
                output.input_context,
                cci_out.input_context_scores,
                cci_out.output_context_scores,
            )
        if args.has_output_context:
            output_context_tokens = format_context_comment(
                model,
                args.has_input_context,
                args.special_tokens_to_keep,
                output.output_context,
                cci_out.output_context_scores,
                cci_out.input_context_scores,
                is_target=True,
            )
        out += [
            ("\n\n" if example_idx > 1 else "", None),
            (
                f"#{example_idx}.\nGenerated output:\t",
                None,
            ),
        ]
        out += curr_output_tokens
        if args.has_input_context:
            out += [("\nInput context:\t", None)]
            out += input_context_tokens
        if args.has_output_context:
            out += [("\nOutput context:\t", None)]
            out += output_context_tokens
    return out