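"""Helpers to format inseq ``attribute-context`` results for token-level
highlighting.

The main entry point, ``get_formatted_attribute_context_results``, flattens an
``AttributeContextOutput`` into ``(token, label)`` tuples in which
context-sensitive generated tokens are labeled "Context sensitive" and the
context tokens driving them "Influential context" (a format usable, for
instance, by highlighted-text UI components; the exact consumer is an
assumption here).
"""
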
from typing import Optional

from inseq.commands.attribute_context.attribute_context_args import AttributeContextArgs
from inseq.commands.attribute_context.attribute_context_helpers import (
    AttributeContextOutput,
    filter_rank_tokens,
    get_filtered_tokens,
)
from inseq.models import HuggingfaceModel


def get_formatted_attribute_context_results(
    model: HuggingfaceModel,
    args: AttributeContextArgs,
    output: AttributeContextOutput,
) -> list[tuple[str, Optional[str]]]:
    """Format the results of the context attribution process.

    Returns a flat list of ``(token, label)`` tuples (plus separator chunks)
    covering the generated output and any input/output context, where a
    ``None`` label marks an unhighlighted token.
    """

    def format_context_comment(
        model: HuggingfaceModel,
        has_other_context: bool,
        special_tokens_to_keep: list[str],
        context: str,
        context_scores: list[float],
        other_context_scores: Optional[list[float]] = None,
        is_target: bool = False,
    ) -> list[tuple[str, Optional[str]]]:
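        """Tokenize a context string and tag its most influential tokens.

        Returns ``(token, label)`` tuples where tokens retained by
        ``filter_rank_tokens`` are labeled "Influential context" and all
        others carry ``None``.
        """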
        context_tokens = get_filtered_tokens(
            context,
            model,
            special_tokens_to_keep,
            replace_special_characters=True,
            is_target=is_target,
        )
        context_token_tuples = [(t, None) for t in context_tokens]
        # Copy before concatenating: `scores = context_scores` followed by
        # `scores += ...` would extend the caller's list in place, corrupting
        # the same score list when it is reused for the other context.
        scores = list(context_scores)
        if has_other_context:
            scores += other_context_scores
        context_ranked_tokens, _ = filter_rank_tokens(
            tokens=context_tokens,
            scores=scores,
            std_threshold=args.attribution_std_threshold,
            topk=args.attribution_topk,
        )
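        # Relabel the tokens that survive the std-threshold / top-k filter.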
        for idx, _, tok in context_ranked_tokens:
            context_token_tuples[idx] = (tok, "Influential context")
        return context_token_tuples

    out = []
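    # Tokenize the generated output once; a fresh tagged copy is made per example.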
    output_current_tokens = get_filtered_tokens(
        output.output_current,
        model,
        args.special_tokens_to_keep,
        replace_special_characters=True,
        is_target=True,
    )
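    # Emit one formatted block per context-sensitive token found in the output.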
    for example_idx, cci_out in enumerate(output.cci_scores, start=1):
        curr_output_tokens = [(t, None) for t in output_current_tokens]
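        # Relabel the context-sensitive token identified by the CTI step.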
        cti_idx = cci_out.cti_idx
        curr_output_tokens[cti_idx] = (
            curr_output_tokens[cti_idx][0],
            "Context sensitive",
        )
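        # Tag influential context tokens using the CCI attribution scores.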
        if args.has_input_context:
            input_context_tokens = format_context_comment(
                model,
                args.has_output_context,
                args.special_tokens_to_keep,
                output.input_context,
                cci_out.input_context_scores,
                cci_out.output_context_scores,
            )
        if args.has_output_context:
            output_context_tokens = format_context_comment(
                model,
                args.has_input_context,
                args.special_tokens_to_keep,
                output.output_context,
                cci_out.output_context_scores,
                cci_out.input_context_scores,
                is_target=True,
            )
        out += [
            ("\n\n" if example_idx > 1 else "", None),
            (
                f"#{example_idx}.\nGenerated output:\t",
                None,
            ),
        ]
        out += curr_output_tokens
        if args.has_input_context:
            out += [("\nInput context:\t", None)]
            out += input_context_tokens
        if args.has_output_context:
            out += [("\nOutput context:\t", None)]
            out += output_context_tokens
    return out
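

# Minimal usage sketch (hypothetical wiring, not part of this module): the
# returned (token, label) tuples match the value format expected by e.g.
# `gradio.HighlightedText`, where None-labeled entries render as plain text.
# Here `model`, `args`, and `output` are assumed to come from a prior run of
# inseq's attribute-context pipeline:
#
#     tuples = get_formatted_attribute_context_results(model, args, output)
#     highlighted = gr.HighlightedText(value=tuples, combine_adjacent=False)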