import gradio as gr
import matplotlib.pyplot as plt
import nltk

from inference import RelationsInference
from attention_viz import AttentionVisualizer
from utils import KGType, Model_Type, Data_Type

#############################
# Prep
#############################
# Fetch the NLTK "popular" data bundle at startup.
# NOTE(review): presumably needed by the inference / KG pipeline — confirm
# which corpora are actually required and narrow this download if possible.
nltk.download('popular')

#############################
# Constants
#############################
# Candidate example rows (input text, task, decoding strategy); currently unused.
# examples = [["What's the meaning of life?", "eli5", "constraint"],
#             ["boat, water, bird", "commongen", "constraint"],
#             ["What flows under a bridge?", "commonsense_qa", "constraint"]]

# Relation-aware BART checkpoint served for the "commongen" task.
commongen_bart = RelationsInference(
    model_path='MrVicente/commonsense_bart_commongen',
    kg_type=KGType.CONCEPTNET,
    model_type=Model_Type.RELATIONS,
    max_length=32,
)

# Relation-aware BART checkpoint served for the "eli5" task.
qa_bart = RelationsInference(
    model_path='MrVicente/commonsense_bart_absqa',
    kg_type=KGType.CONCEPTNET,
    model_type=Model_Type.RELATIONS,
    max_length=128,
)

# Renders the attention plots shown in the "Observe Attention" section.
att_viz = AttentionVisualizer(device='cpu')


#############################
# Helper
#############################
def _model_for_task(task):
    """Return the inference model that serves the given ``Data_Type`` task.

    Raises:
        NotImplementedError: if no model is loaded for ``task``.
    """
    if task == Data_Type.COMMONGEN:
        return commongen_bart
    if task == Data_Type.ELI5:
        return qa_bart
    raise NotImplementedError(f"No model is available for task {task!r}")


def infer_bart(context: str, task_type: str, decoding_type_str: str):
    """Generate the model answer shown in the "Model result" textbox.

    Args:
        context: Raw user input text.
        task_type: Task name selected in the UI (a ``Data_Type`` value,
            e.g. "eli5" or "commongen").
        decoding_type_str: "default" for plain generation without the KG;
            any other value selects constrained, KG-aware decoding. Only the
            CommonGen model honours this choice — ELI5 always uses default
            decoding.

    Returns:
        The first generated response.

    Raises:
        ValueError: if ``task_type`` is not a valid ``Data_Type`` value.
        NotImplementedError: if the task has no associated model.
    """
    task = Data_Type(task_type)
    model = _model_for_task(task)
    if task == Data_Type.COMMONGEN and decoding_type_str != 'default':
        # The constrained generator takes a batch (list) of contexts.
        response, _, _ = model.generate_contrained_based_on_context(
            [context], use_kg=True, max_concepts=2
        )
    else:
        response, _, _ = model.generate_based_on_context(context, use_kg=False)
    return response[0]


def plot_attention(context: str, task_type: str, layer, head):
    """Build the attention-visualisation figure for one layer/head.

    Args:
        context: Raw user input text.
        task_type: Task name selected in the UI (a ``Data_Type`` value).
        layer: Layer index from the UI slider.
        head: Attention-head index from the UI slider.

    Returns:
        The figure produced by ``AttentionVisualizer``, rendered by ``gr.Plot``.

    Raises:
        ValueError: if ``task_type`` is not a valid ``Data_Type`` value.
        NotImplementedError: if the task has no associated model.
    """
    model = _model_for_task(Data_Type(task_type))
    _, examples, relations = model.prepare_context_for_visualization(context)
    # Slider values may arrive as floats; the visualizer uses them as indices.
    return att_viz.plot_attn_lines_concepts_ids(
        'Input text importance visualized',
        examples,
        int(layer),
        int(head),
        relations,
    )


#############################
# Interface
#############################
app = gr.Blocks()
with app:
    gr.Markdown(
        """
# Demo
### Test Commonsense Relation-Aware BART (BART-RA) model

Tutorial:

1) Select the possible model variations and tasks;
2) Change the inputs and Click the buttons to produce results;
3) See attention visualisations, by choosing a specific layer and head;
"""
    )
    with gr.Row():
        context_input = gr.Textbox(lines=2, value="What's the meaning of life?", label='Input:')
        model_result_output = gr.Textbox(lines=2, label='Model result:')
        with gr.Column():
            task_type_choice = gr.Radio(
                ["eli5", "commongen"], value="eli5", label="What task do you want to try?"
            )
            decoding_type_choice = gr.Radio(
                ["default", "constraint"], value="default", label="What decoding strategy do you want to use?"
            )
    with gr.Row():
        model_btn = gr.Button(value="See Model Results")
    gr.Markdown(
        """
---
Observe Attention
"""
    )
    with gr.Row():
        with gr.Column():
            layer = gr.Slider(0, 11, 0, step=1, label="Layer")
            head = gr.Slider(0, 15, 0, step=1, label="Head")
        with gr.Column():
            plot_output = gr.Plot()
    with gr.Row():
        vis_btn = gr.Button(value="See Attention Scores")
    model_btn.click(fn=infer_bart, inputs=[context_input, task_type_choice, decoding_type_choice], outputs=[model_result_output])
    vis_btn.click(fn=plot_attention, inputs=[context_input, task_type_choice, layer, head], outputs=[plot_output])

if __name__ == '__main__':
    app.launch()