<|im_start|>
, <|im_end|>
special tokens."
)
# Preset row for multilingual MT models (mBART, NLLB) and the TowerInstruct chat template.
# Each column pairs a preset button with a short Markdown description of what it loads.
with gr.Row(equal_height=True):
    with gr.Column(scale=1):
        mbart_mmt_template = gr.Button(
            "mBART Multilingual MT", variant="secondary"
        )
        gr.Markdown(
            "Preset for the mBART Many-to-Many multilingual MT model using language tags (default: English to French)."
        )
    with gr.Column(scale=1):
        nllb_mmt_template = gr.Button(
            "NLLB Multilingual MT", variant="secondary"
        )
        gr.Markdown(
            "Preset for the NLLB 600M multilingual MT model using language tags (default: English to French)."
        )
    with gr.Column(scale=1):
        # Created disabled; the check_enable_large_models checkbox toggles its `interactive` state.
        towerinstruct_template = gr.Button(
            "Unbabel TowerInstruct", variant="secondary", interactive=False
        )
        # The special tokens are wrapped in backticks so they render as inline code and the
        # description stays one properly terminated string literal.
        gr.Markdown(
            "Preset for models using the Unbabel TowerInstruct conversational template.\nUses `<|im_start|>`, `<|im_end|>` special tokens."
        )
# Preset row for chat-template presets of larger models (Zephyr, Gemma, Mistral Instruct).
# All three buttons are created disabled; the check_enable_large_models checkbox toggles
# their `interactive` state.
with gr.Row(equal_height=True):
    # scale=1 spelled out for consistency with the sibling columns (it is Column's default).
    with gr.Column(scale=1):
        zephyr_preset = gr.Button(
            "Zephyr Template", variant="secondary", interactive=False
        )
        # Special tokens are wrapped in backticks so the literal stays on one line and they
        # render as inline code.
        gr.Markdown(
            "Preset for models using the StableLM 2 Zephyr conversational template.\nUses `<|system|>`, `<|user|>` and `<|assistant|>` special tokens."
        )
    with gr.Column(scale=1):
        gemma_template = gr.Button(
            "Gemma Chat Template", variant="secondary", interactive=False
        )
        gr.Markdown(
            "Preset for Gemma instruction-tuned models."
        )
    with gr.Column(scale=1):
        mistral_instruct_template = gr.Button(
            "Mistral Instruct", variant="secondary", interactive=False
        )
        gr.Markdown(
            "Preset for models using the Mistral Instruct template.\nUses `[INST]...[/INST]` special tokens."
        )
gr.Markdown("## ⚙️ PECoRe Parameters")
with gr.Row(equal_height=True):
    # Model selection: a Hub identifier plus a button that loads the chosen model.
    with gr.Column():
        model_name_or_path = gr.Textbox(
            label="Model",
            value="gsarti/cora_mgen",
            interactive=True,
            info="Hugging Face Hub identifier of the model to analyze with PECoRe.",
        )
        load_model_button = gr.Button("Load model", variant="secondary")

    # Metric scoring how sensitive each generated token is to the context.
    context_sensitivity_metric = gr.Dropdown(
        label="Context sensitivity metric",
        value="kl_divergence",
        choices=["probability", "logit", "kl_divergence", "contrast_logits_diff", "contrast_prob_diff", "pcxmi"],
        interactive=True,
        info="Metric to use to measure context sensitivity of generated tokens.",
    )

    # Feature-attribution method used to locate the context tokens responsible for a generation.
    attribution_method = gr.Dropdown(
        label="Attribution method",
        value="saliency",
        choices=["saliency", "input_x_gradient", "value_zeroing"],
        interactive=True,
        info="Attribution method identifier to identify relevant context tokens.",
    )

    # Function of the model logits that the attribution method explains.
    attributed_fn = gr.Dropdown(
        label="Attributed function",
        value="contrast_prob_diff",
        choices=["probability", "logit", "contrast_logits_diff", "contrast_prob_diff"],
        interactive=True,
        info="Function of model logits to use as target for the attribution method.",
    )
gr.Markdown("#### Results Selection Parameters")
with gr.Row(equal_height=True):
    # Filters for context-sensitive generated tokens: a std-based cutoff and a top-k cap.
    context_sensitivity_std_threshold = gr.Number(
        label="Context sensitivity threshold",
        info="Select N to keep context sensitive tokens with scores above N * std. 0 = above mean.",
        value=0.0,
        minimum=0.0,
        maximum=5.0,
        step=0.5,
        precision=1,
        interactive=True,
    )
    context_sensitivity_topk = gr.Number(
        label="Context sensitivity top-k",
        info="Select N to keep top N context sensitive tokens. 0 = keep all.",
        value=0,
        minimum=0,
        maximum=10,
        precision=0,
        interactive=True,
    )
    # Filters for attributed context tokens: same two mechanisms, with stricter defaults.
    attribution_std_threshold = gr.Number(
        label="Attribution threshold",
        info="Select N to keep attributed tokens with scores above N * std. 0 = above mean.",
        value=2.0,
        minimum=0.0,
        maximum=5.0,
        step=0.5,
        precision=1,
        interactive=True,
    )
    attribution_topk = gr.Number(
        label="Attribution top-k",
        info="Select N to keep top N attributed tokens in the context. 0 = keep all.",
        value=5,
        minimum=0,
        maximum=100,
        precision=0,
        interactive=True,
    )
# Remainder of the UI, event wiring and app launch.
# - "Text Format Parameters":
#   - templates for the contextual and contextless inputs/outputs ({current} and {context} placeholders)
#   - the special tokens to keep
#   - the decoder input/output separator
# - "Generation Parameters": JSON generation kwargs, a forced generation output and a generation context.
# - "Other Parameters": JSON kwargs for the model, the tokenizer and `inseq_model.attribute`
#   (attribution kwargs default to {"logprob": true}).
# - Static tabs (how it works, usage guide, FAQ, citations), a footer row, and a Modal
#   holding Python/Shell code snippets.
# - Event wiring:
#   - The attribute button first hides both download buttons, then runs `pecore` with `pecore_args`.
#   - "Load model" calls `preload_model` with `load_model_args`.
#   - The large-model checkbox returns one gr.Button(interactive=checkbox) update per preset
#     button. The button values are passed in only to size that list and are otherwise ignored.
#   - Every preset first runs `set_default_preset` on `outputs_to_reset`. Non-default presets then
#     apply their own values via `.then(...)`. On success each preset calls `preload_model`
#     with cancels=load_model_event.
#   - The "show code" button fills both snippets via `update_code_snippets_fn`, then opens the Modal.
# - Finally the app is queued (api_open=False, max_size=20) and launched with
#   allowed_paths=["outputs/", "img/"] and show_api=False.
# NOTE(review): this copy appears to have lost its indentation and some markup. Confirm all of the
#   following against the original source:
#   - Several string literals span line breaks and are unterminated here: the contextless input
#     template value, the Inseq citation sentence, and the "Built by" footer.
#   - The default input templates (": {current}: {context}" and ": {current}") look like they lost
#     leading markers.
#   - The gr.HTML('') calls and the footer gr.Markdown("""""") calls are empty.
#   - The inline "# Main logic" and "# Presets" comments now comment out the rest of their
#     collapsed lines.
# NOTE(review): several gr.Column calls pass float `scale` values (0.5, 0.60, 0.30, 0.35, 0.65,
#   0.25). Recent Gradio releases expect integer scales; confirm against the installed version.
gr.Markdown("#### Text Format Parameters")
with gr.Row(equal_height=True):
input_template = gr.Textbox(
value=": {current}: {context}", label="Contextual input template", info="Template to format the input for the model. Use {current} and {context} placeholders for Input Query and Input Context, respectively.", interactive=True, ) output_template = gr.Textbox( value="{current}", label="Contextual output template", info="Template to format the output from the model. Use {current} and {context} placeholders for Generation Output and Generation Context, respectively.", interactive=True, ) contextless_input_template = gr.Textbox( value="
: {current}", label="Contextless input template", info="Template to format the input query in the non-contextual setting. Use {current} placeholder for Input Query.", interactive=True, ) contextless_output_template = gr.Textbox( value="{current}", label="Contextless output template", info="Template to format the output from the model. Use {current} placeholder for Generation Output.", interactive=True, ) with gr.Row(equal_height=True): special_tokens_to_keep = gr.Dropdown( label="Special tokens to keep", info="Special tokens to keep in the attribution. If empty, all special tokens are ignored.", value=None, multiselect=True, allow_custom_value=True, ) decoder_input_output_separator = gr.Textbox( label="Decoder input/output separator", info="Separator to use between input and output in the decoder input.", value="", interactive=True, lines=1, ) gr.Markdown("## ⚙️ Generation Parameters") with gr.Row(equal_height=True): with gr.Column(scale=0.5): gr.Markdown( "The following arguments can be used to control generation parameters and force specific model outputs." ) with gr.Column(scale=1): generation_kwargs = gr.Code( value="{}", language="json", label="Generation kwargs (JSON)", interactive=True, lines=1, ) with gr.Row(equal_height=True): output_current_text = gr.Textbox( label="Generation output", info="Specifies an output to force-decoded during generation. If blank, the model will generate freely.", interactive=True, ) output_context_text = gr.Textbox( label="Generation context", info="If specified, this context is used as starting point for generation. Useful for e.g. chain-of-thought reasoning.", interactive=True, ) gr.Markdown("## ⚙️ Other Parameters") with gr.Row(equal_height=True): with gr.Column(): gr.Markdown( "The following arguments will be passed to initialize the Hugging Face model and tokenizer, and to the `inseq_model.attribute` method." 
) with gr.Column(): model_kwargs = gr.Code( value="{}", language="json", label="Model kwargs (JSON)", interactive=True, lines=1, min_width=160, ) with gr.Column(): tokenizer_kwargs = gr.Code( value="{}", language="json", label="Tokenizer kwargs (JSON)", interactive=True, lines=1, ) with gr.Column(): attribution_kwargs = gr.Code( value='{\n\t"logprob": true\n}', language="json", label="Attribution kwargs (JSON)", interactive=True, lines=1, ) with gr.Tab("🔍 How Does It Work?"): gr.Markdown(how_it_works_intro) with gr.Row(equal_height=True): with gr.Column(scale=0.60): gr.Markdown(cti_explanation) with gr.Column(scale=0.30): gr.HTML('') with gr.Row(equal_height=True): with gr.Column(scale=0.35): gr.HTML('') with gr.Column(scale=0.65): gr.Markdown(cci_explanation) with gr.Tab("🔧 Usage Guide"): gr.Markdown(how_to_use) gr.Markdown(example_explanation) with gr.Tab("❓ FAQ"): gr.Markdown(faq) with gr.Tab("📚 Citing PECoRe"): gr.Markdown("To refer to the PECoRe framework for context usage detection, cite:") gr.Code(pecore_citation, interactive=False, label="PECoRe (Sarti et al., 2024)") gr.Markdown("If you use the Inseq implementation of PECoRe (inseq attribute-context
, including this demo), please also cite:") gr.Code(inseq_citation, interactive=False, label="Inseq (Sarti et al., 2023)") with gr.Row(elem_classes="footer-container"): with gr.Column(): gr.Markdown(powered_by) with gr.Column(): with gr.Row(elem_classes="footer-custom-block"): with gr.Column(scale=0.25, min_width=150): gr.Markdown("""Built by Gabriele Sarti
with the support of""") with gr.Column(scale=0.25, min_width=120): gr.Markdown("""""") with gr.Column(scale=0.25, min_width=120): gr.Markdown("""""") with gr.Column(scale=0.25, min_width=120): gr.Markdown("""""") with Modal(visible=False) as code_modal: gr.Markdown(show_code_modal) with gr.Row(equal_height=True): with gr.Column(scale=0.5): python_code_snippet = gr.Code( value="""Generate Python code snippet by pressing the button.""", language="python", label="Python", interactive=False, show_label=True, ) with gr.Column(scale=0.5): shell_code_snippet = gr.Code( value="""Generate Shell code snippet by pressing the button.""", language="shell", label="Shell", interactive=False, show_label=True, ) # Main logic load_model_args = [ model_name_or_path, attribution_method, model_kwargs, tokenizer_kwargs, ] pecore_args = [ input_current_text, input_context_text, output_current_text, output_context_text, model_name_or_path, attribution_method, attributed_fn, context_sensitivity_metric, context_sensitivity_std_threshold, context_sensitivity_topk, attribution_std_threshold, attribution_topk, input_template, output_template, contextless_input_template, contextless_output_template, special_tokens_to_keep, decoder_input_output_separator, model_kwargs, tokenizer_kwargs, generation_kwargs, attribution_kwargs, ] attribute_input_button.click( lambda *args: [gr.DownloadButton(visible=False), gr.DownloadButton(visible=False)], inputs=[], outputs=[download_output_file_button, download_output_html_button], ).then( pecore, inputs=pecore_args, outputs=[ pecore_output_highlights, download_output_file_button, download_output_html_button, ], ) load_model_event = load_model_button.click( preload_model, inputs=load_model_args, outputs=[], ) # Preset params check_enable_large_models.input( lambda checkbox, *buttons: [gr.Button(interactive=checkbox) for _ in buttons], inputs=[check_enable_large_models, zephyr_preset, towerinstruct_template, gemma_template, mistral_instruct_template], 
outputs=[zephyr_preset, towerinstruct_template, gemma_template, mistral_instruct_template], ) outputs_to_reset = [ model_name_or_path, input_template, output_template, contextless_input_template, contextless_output_template, special_tokens_to_keep, decoder_input_output_separator, model_kwargs, tokenizer_kwargs, generation_kwargs, attribution_kwargs, ] reset_kwargs = { "fn": set_default_preset, "inputs": None, "outputs": outputs_to_reset, } # Presets default_preset.click(**reset_kwargs).success(preload_model, inputs=load_model_args, cancels=load_model_event) cora_preset.click(**reset_kwargs).then( set_cora_preset, outputs=[model_name_or_path, input_template, contextless_input_template], ).success(preload_model, inputs=load_model_args, cancels=load_model_event) zephyr_preset.click(**reset_kwargs).then( set_zephyr_preset, outputs=[ model_name_or_path, input_template, contextless_input_template, special_tokens_to_keep, generation_kwargs, ], ).success(preload_model, inputs=load_model_args, cancels=load_model_event) mbart_mmt_template.click(**reset_kwargs).then( set_mbart_mmt_preset, outputs=[model_name_or_path, input_template, output_template, tokenizer_kwargs], ).success(preload_model, inputs=load_model_args, cancels=load_model_event) nllb_mmt_template.click(**reset_kwargs).then( set_nllb_mmt_preset, outputs=[model_name_or_path, input_template, output_template, tokenizer_kwargs], ).success(preload_model, inputs=load_model_args, cancels=load_model_event) chatml_template.click(**reset_kwargs).then( set_chatml_preset, outputs=[ model_name_or_path, input_template, contextless_input_template, special_tokens_to_keep, generation_kwargs, ], ).success(preload_model, inputs=load_model_args, cancels=load_model_event) towerinstruct_template.click(**reset_kwargs).then( set_towerinstruct_preset, outputs=[ model_name_or_path, input_template, contextless_input_template, special_tokens_to_keep, generation_kwargs, ], ).success(preload_model, inputs=load_model_args, 
cancels=load_model_event) gemma_template.click(**reset_kwargs).then( set_gemma_preset, outputs=[ model_name_or_path, input_template, contextless_input_template, special_tokens_to_keep, generation_kwargs, ], ).success(preload_model, inputs=load_model_args, cancels=load_model_event) mistral_instruct_template.click(**reset_kwargs).then( set_mistral_instruct_preset, outputs=[ model_name_or_path, input_template, contextless_input_template, generation_kwargs, ], ).success(preload_model, inputs=load_model_args, cancels=load_model_event) show_code_btn.click( update_code_snippets_fn, inputs=pecore_args, outputs=[python_code_snippet, shell_code_snippet], ).then(lambda: Modal(visible=True), None, code_modal) demo.queue(api_open=False, max_size=20).launch(allowed_paths=["outputs/", "img/"], show_api=False)