# Hugging Face Spaces status: Runtime error
import gradio as gr
from lavis.models import load_model_and_preprocess
import torch

# Always build a torch.device. The original expression produced a
# torch.device for CUDA but a plain "cpu" string otherwise, so `device`
# had two different types depending on the host.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# InstructBLIP variant to serve: the LAVIS "blip2_t5_instruct" architecture
# with the "flant5xl" checkpoint type.
model_name = "blip2_t5_instruct"
model_type = "flant5xl"

# Loaded once at import time and shared by every request handled by infer().
# Only the model and the image processors are used. The third return value is
# discarded (presumably the text processors; confirm against LAVIS docs).
model, vis_processors, _ = load_model_and_preprocess(
    name=model_name,
    model_type=model_type,
    is_eval=True,
    device=device,
)
def infer(image, prompt, min_len, max_len, beam_size, len_penalty, repetition_penalty, top_p, decoding_method):
    """Generate a text response for ``image`` conditioned on ``prompt``.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user has not uploaded anything.
        prompt: Instruction text, sent to the model as ``samples["prompt"]``.
        min_len: Minimum generation length (forwarded as ``min_length``).
        max_len: Maximum generation length (forwarded as ``max_length``).
        beam_size: Number of beams (forwarded as ``num_beams``).
        len_penalty: Length penalty forwarded to ``model.generate``.
        repetition_penalty: Repetition penalty; must be strictly positive.
        top_p: Nucleus-sampling probability mass.
        decoding_method: ``"Nucleus sampling"`` enables sampling. Any other
            value (the UI offers ``"Beam search"``) leaves it disabled.

    Returns:
        The first string returned by ``model.generate``.

    Raises:
        gr.Error: If no image was provided or ``repetition_penalty`` is not
            greater than zero. Gradio shows the message to the user.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    repetition_penalty = float(repetition_penalty)
    # Hugging Face generation rejects non-positive repetition penalties with a
    # ValueError. Fail early with a readable message instead.
    if repetition_penalty <= 0:
        raise gr.Error("Repetition penalty must be greater than 0.")

    use_nucleus_sampling = decoding_method == "Nucleus sampling"

    # Uploads such as PNGs can be RGBA or palette-mode images. The eval
    # processor presumably expects 3-channel RGB input (TODO confirm), and
    # converting an image that is already RGB is harmless.
    image = vis_processors["eval"](image.convert("RGB")).unsqueeze(0).to(device)
    samples = {
        "image": image,
        "prompt": prompt,
    }
    # Slider/example values may arrive as floats. Counts and lengths must be ints.
    output = model.generate(
        samples,
        length_penalty=float(len_penalty),
        repetition_penalty=repetition_penalty,
        num_beams=int(beam_size),
        max_length=int(max_len),
        min_length=int(min_len),
        top_p=float(top_p),
        use_nucleus_sampling=use_nucleus_sampling,
    )
    return output[0]
# Visual theme for the demo: Gradio's Monochrome preset with indigo/blue
# accents, slate neutrals, small corner radii and an Open Sans font stack
# that falls back to system sans-serif fonts.
theme = gr.themes.Monochrome(
    font=[
        gr.themes.GoogleFont("Open Sans"),
        "ui-sans-serif",
        "system-ui",
        "sans-serif",
    ],
    radius_size=gr.themes.sizes.radius_sm,
    neutral_hue="slate",
    secondary_hue="blue",
    primary_hue="indigo",
)

# Hide any element carrying the "generating" CSS class (presumably Gradio's
# in-progress highlight on output components; confirm for the pinned version).
css = ".generating {visibility: hidden}"

# Preloaded example rows. Each column lines up positionally with the inputs
# handed to gr.Examples: image, prompt, min length, max length, beam size,
# length penalty, repetition penalty, top p, decoding method.
# "banff.jpg" is presumably an image shipped next to this script.
examples = [
    [
        "banff.jpg",
        "Can you tell me about this image in detail",
        1,    # min length
        200,  # max length
        5,    # beam size
        1,    # length penalty
        3,    # repetition penalty
        0.9,  # top p
        "Beam search",
    ],
]
# Page layout: title and description, then a two-column row (inputs/output on
# the left, generation settings on the right), then the examples table.
with gr.Blocks(theme=theme, analytics_enabled=False, css=css) as demo:
    gr.Markdown("## InstructBLIP: Towards General-purpose Vision-Language Models with Instruction Tuning")
    gr.Markdown(
        """
Unofficial demo for InstructBLIP. InstructBLIP is a new vision-language instruction-tuning framework by Salesforce that uses BLIP-2 models, achieving state-of-the-art zero-shot generalization performance on a wide range of vision-language tasks.
The demo is based on the official <a href="https://github.com/salesforce/LAVIS/tree/main/projects/instructblip" style="text-decoration: underline;" target="_blank"> Github </a> implementation
"""
    )
    gr.HTML("<p>You can duplicate this Space to run it privately without a queue for shorter queue times : <a style='display:inline-block' href='https://huggingface.co/spaces/RamAnanth1/InstructBLIP?duplicate=true'><img src='https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=&logoWidth=14' alt='Duplicate Space'></a> </p>")

    with gr.Row():
        # Left column: image and prompt inputs, the output box, and the Run button.
        with gr.Column(scale=3):
            image_input = gr.Image(type="pil")
            prompt_textbox = gr.Textbox(label="Prompt:", placeholder="prompt", lines=2)
            output = gr.Textbox(label="Output")
            submit = gr.Button("Run", variant="primary")

        # Right column: generation settings forwarded to infer().
        with gr.Column(scale=1):
            min_len = gr.Slider(
                minimum=1,
                maximum=50,
                value=1,
                step=1,
                interactive=True,
                label="Min Length",
            )
            max_len = gr.Slider(
                minimum=10,
                maximum=500,
                value=250,
                step=5,
                interactive=True,
                label="Max Length",
            )
            sampling = gr.Radio(
                choices=["Beam search", "Nucleus sampling"],
                value="Beam search",
                label="Text Decoding Method",
                interactive=True,
            )
            top_p = gr.Slider(
                minimum=0.5,
                maximum=1.0,
                value=0.9,
                step=0.1,
                interactive=True,
                label="Top p",
            )
            beam_size = gr.Slider(
                minimum=1,
                maximum=10,
                value=5,
                step=1,
                interactive=True,
                label="Beam Size",
            )
            len_penalty = gr.Slider(
                minimum=-1,
                maximum=2,
                value=1,
                step=0.2,
                interactive=True,
                label="Length Penalty",
            )
            repetition_penalty = gr.Slider(
                # Hugging Face generation rejects repetition penalties <= 0.
                # The former lower bound of -1 let users pick values that crash
                # generation. 0.2 keeps the 0.2 step grid, so 1.0 and 3.0 are
                # still reachable.
                minimum=0.2,
                maximum=3,
                value=1,
                step=0.2,
                interactive=True,
                label="Repetition Penalty",
            )

    # Shared by the examples table and the Run button. The order must match
    # infer()'s positional parameters and the columns of `examples`.
    infer_inputs = [
        image_input,
        prompt_textbox,
        min_len,
        max_len,
        beam_size,
        len_penalty,
        repetition_penalty,
        top_p,
        sampling,
    ]

    gr.Examples(
        examples=examples,
        inputs=infer_inputs,
        cache_examples=False,
        fn=infer,
        outputs=[output],
    )
    submit.click(infer, inputs=infer_inputs, outputs=[output])
# Start the web server only when executed as a script (presumably how Spaces
# runs the app), so importing this module does not block in launch().
if __name__ == "__main__":
    # NOTE(review): `concurrency_count` is a Gradio 3.x queue() argument. It was
    # removed in Gradio 4, so confirm it against the pinned Gradio version. It
    # allows up to 16 requests to run at once against the single shared model.
    demo.queue(concurrency_count=16).launch(debug=True)