RA-BART / app.py
MrVicente's picture
fixed gradio plot issue
c56dde4
raw
history blame contribute delete
No virus
3.9 kB
import gradio as gr
import matplotlib.pyplot as plt
from inference import RelationsInference
from attention_viz import AttentionVisualizer
from utils import KGType, Model_Type, Data_Type
# prep: fetch NLTK data when this module is imported, before any model is built.
# NOTE(review): 'popular' is presumably the NLTK "popular" data collection used
# by the inference helpers — confirm which resources are actually required.
import nltk
nltk.download('popular')
#############################
# Constants
#############################
#examples = [["What's the meaning of life?", "eli5", "constraint"],
# ["boat, water, bird", "commongen", "constraint"],
# ["What flows under a bridge?", "commonsense_qa", "constraint"]]
# Model instance used by the handlers below for the CommonGen task
# (generation capped at max_length=32).
commongen_bart = RelationsInference(
    model_path='MrVicente/commonsense_bart_commongen',
    kg_type=KGType.CONCEPTNET,
    model_type=Model_Type.RELATIONS,
    max_length=32,
)

# Model instance used by the handlers below for the ELI5 task
# (generation capped at max_length=128).
qa_bart = RelationsInference(
    model_path='MrVicente/commonsense_bart_absqa',
    kg_type=KGType.CONCEPTNET,
    model_type=Model_Type.RELATIONS,
    max_length=128,
)

# Shared attention visualizer, configured with device='cpu'.
att_viz = AttentionVisualizer(device='cpu')
#############################
# Helper
#############################
def infer_bart(context, task_type, decoding_type_str):
    """Generate text for ``context`` with the model matching ``task_type``.

    Args:
        context: Input text taken from the UI textbox.
        task_type: Value convertible with ``Data_Type(...)``; selects the model.
        decoding_type_str: ``'default'`` for plain generation; any other value
            selects constrained, knowledge-graph-backed generation. Only
            consulted for the CommonGen task; the ELI5 branch ignores it.

    Returns:
        The first element of the generated responses.

    Raises:
        NotImplementedError: If the task is neither CommonGen nor ELI5.
    """
    task = Data_Type(task_type)

    if task == Data_Type.COMMONGEN:
        if decoding_type_str != 'default':
            # The constrained entry point is given the context wrapped in a list.
            outputs = commongen_bart.generate_contrained_based_on_context(
                [context], use_kg=True, max_concepts=2
            )
        else:
            outputs = commongen_bart.generate_based_on_context(context, use_kg=False)
    elif task == Data_Type.ELI5:
        outputs = qa_bart.generate_based_on_context(context, use_kg=False)
    else:
        raise NotImplementedError()

    # Each generator call yields a 3-tuple; only the generated texts are used.
    generated, _, _ = outputs
    return generated[0]
def plot_attention(context, task_type, layer, head):
    """Build the attention figure for ``context`` at the given layer and head.

    Args:
        context: Input text taken from the UI textbox.
        task_type: Value convertible with ``Data_Type(...)``; selects the model.
        layer: Layer index forwarded to the visualizer.
        head: Attention-head index forwarded to the visualizer.

    Returns:
        Whatever ``att_viz.plot_attn_lines_concepts_ids`` returns (shown in
        the ``gr.Plot`` output).

    Raises:
        NotImplementedError: If the task is neither CommonGen nor ELI5.
    """
    task = Data_Type(task_type)

    if task == Data_Type.COMMONGEN:
        selected_model = commongen_bart
    elif task == Data_Type.ELI5:
        selected_model = qa_bart
    else:
        raise NotImplementedError()

    # The first element of the returned triple is not needed for plotting.
    _, examples, relations = selected_model.prepare_context_for_visualization(context)

    return att_viz.plot_attn_lines_concepts_ids(
        'Input text importance visualized',
        examples,
        layer, head,
        relations,
    )
#############################
# Interface
#############################
# Gradio Blocks application: declares the UI components, wires the two buttons
# to their handlers, and launches the app when run as a script.
# NOTE(review): the indentation of this section is not visible in this view, so
# which components sit inside each gr.Row()/gr.Column() below cannot be
# confirmed from here — check the nesting against the original file before
# changing the layout.
app = gr.Blocks()
with app:
# Introductory Markdown with a short usage tutorial.
gr.Markdown(
"""
# Demo
### Test Commonsense Relation-Aware BART (BART-RA) model
Tutorial: <br>
1) Select the possible model variations and tasks;<br>
2) Change the inputs and Click the buttons to produce results;<br>
3) See attention visualisations, by choosing a specific layer and head;<br>
""")
with gr.Row():
# Input text (pre-filled with a sample question) and the box that receives
# the text returned by infer_bart.
context_input = gr.Textbox(lines=2, value="What's the meaning of life?", label='Input:')
model_result_output = gr.Textbox(lines=2, label='Model result:')
with gr.Column():
# Task selector; the choices are the strings passed to Data_Type(...) in the handlers.
task_type_choice = gr.Radio(
["eli5", "commongen"], value="eli5", label="What task do you want to try?"
)
# Decoding selector; infer_bart only distinguishes 'default' from any other value.
decoding_type_choice = gr.Radio(
["default", "constraint"], value="default", label="What decoding strategy do you want to use?"
)
with gr.Row():
model_btn = gr.Button(value="See Model Results")
# Separator and heading for the attention-visualisation controls.
gr.Markdown(
"""
---
Observe Attention
"""
)
with gr.Row():
with gr.Column():
# Sliders for the layer (0-11) and head (0-15) indices, both starting at 0.
# NOTE(review): the upper bounds presumably match the model's layer/head
# counts — confirm against the loaded checkpoints.
layer = gr.Slider(0, 11, 0, step=1, label="Layer")
head = gr.Slider(0, 15, 0, step=1, label="Head")
with gr.Column():
# Receives the figure returned by plot_attention.
plot_output = gr.Plot()
with gr.Row():
vis_btn = gr.Button(value="See Attention Scores")
# Button wiring: each click passes the listed component values to the handler
# and writes the handler's return value to the listed output component.
model_btn.click(fn=infer_bart, inputs=[context_input, task_type_choice, decoding_type_choice],
outputs=[model_result_output])
vis_btn.click(fn=plot_attention, inputs=[context_input, task_type_choice, layer, head], outputs=[plot_output])
# Start the Gradio server only when this file is executed directly.
if __name__ == '__main__':
app.launch()