|
from transformers import pipeline |
|
import gradio as gr |
|
|
|
|
|
# Hugging Face checkpoint behind each selectable model, keyed by the short name
# that is shown in the UI dropdown. Order here is the order of the choices.
_SUMMARIZATION_CHECKPOINTS = {
    "gsarti": "gsarti/it5-base-wiki-summarization",
    "facebook": "facebook/bart-large-cnn",
    "lincoln": "lincoln/mbart-mlsum-automatic-summarization",
    "t5-small": "t5-small",
}

# Short name -> ready-to-call "summarization" pipeline. Every pipeline is built
# eagerly, one after another, when this module is imported.
MODELS = {
    short_name: pipeline("summarization", model=checkpoint)
    for short_name, checkpoint in _SUMMARIZATION_CHECKPOINTS.items()
}
|
|
|
|
|
def predict(prompt, model_name, max_length):
    """Summarize *prompt* with the pipeline registered under *model_name*.

    Args:
        prompt: Text to summarize. Newlines are replaced with spaces before the
            text is handed to the model. Empty or whitespace-only input yields
            an empty summary without invoking the model.
        model_name: Key into ``MODELS``. ``None`` or ``""`` (nothing selected
            in the dropdown) falls back to ``"t5-small"``.
        max_length: Upper bound on the generated summary length, measured in
            model tokens (not characters). Gradio's ``Number`` component may
            deliver a float, so it is coerced to ``int``; ``None`` (a cleared
            field) leaves the pipeline's own default in place.

    Returns:
        The ``summary_text`` of the first result returned by the pipeline.

    Raises:
        KeyError: if *model_name* is non-empty but not a key of ``MODELS``.
    """
    if not prompt or not prompt.strip():
        return ""

    model = MODELS[model_name or "t5-small"]
    prompt = prompt.replace("\n", " ")

    # Generation options must be passed by keyword: a transformers pipeline
    # call ignores extra *positional* arguments (it only logs a warning), so
    # the original positional max_length never reached the model.
    generate_kwargs = {}
    if max_length is not None:
        generate_kwargs["max_length"] = int(max_length)

    summary = model(prompt, **generate_kwargs)[0]["summary_text"]
    return summary
|
|
|
|
|
# Dropdown choices: the short names registered in MODELS, in registration order.
options_1 = list(MODELS.keys())

with gr.Blocks() as demo:
    # Preselect the same model predict() falls back to when nothing is chosen.
    drop_down = gr.Dropdown(choices=options_1, value="t5-small", label="model")
    textbox = gr.Textbox(placeholder="Enter text block to summarize", lines=4)
    # predict() forwards this as the pipeline's max_length, which counts model
    # tokens rather than characters.
    length = gr.Number(value=100, label="Max summary length (tokens)")
    summarize_button = gr.Button("Summarize")
    summary_box = gr.Textbox(label="summary")

    # Wire the components to predict() with a Blocks event listener. The
    # original nested a gr.Interface inside this Blocks and handed it
    # components that already belong to `demo`; an explicit click handler is
    # the idiomatic Blocks equivalent.
    summarize_button.click(
        fn=predict,
        inputs=[textbox, drop_down, length],
        outputs=summary_box,
    )

# Only start the web server when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()
|
|