from io import BytesIO
import string

import gradio as gr
import requests
from PIL import Image

from utils import Endpoint


def encode_image(image):
    """Serialize a PIL image to an in-memory JPEG buffer for upload."""
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    buffered.seek(0)
    return buffered


def query_api(image, prompt, decoding_method, temperature, len_penalty, repetition_penalty):
    """POST the image and generation settings to the BLIP-2 endpoint."""
    url = endpoint.url
    headers = {"User-Agent": "BLIP-2 HuggingFace Space"}

    data = {
        "prompt": prompt,
        "use_nucleus_sampling": decoding_method == "Nucleus sampling",
        "temperature": temperature,
        "length_penalty": len_penalty,
        "repetition_penalty": repetition_penalty,
    }

    files = {"image": encode_image(image)}

    response = requests.post(url, data=data, files=files, headers=headers)

    if response.status_code == 200:
        return response.json()
    # Wrap the error in a list so it has the same shape the callers index into
    # (postprocess_output and inference both treat the result as a list of strings).
    return ["Error: " + response.text]


def postprocess_output(output):
    # If the last character is not punctuation, add a full stop.
    if output[0][-1] not in string.punctuation:
        output[0] += "."
    return output


def inference(
    image,
    text_input,
    decoding_method,
    temperature,
    length_penalty,
    repetition_penalty,
    history=None,
):
    # Avoid a shared mutable default: a call without history starts from an empty list.
    if history is None:
        history = []

    history.append(text_input)
    prompt = " ".join(history)

    output = query_api(image, prompt, decoding_method, temperature, length_penalty, repetition_penalty)
    output = postprocess_output(output)
    history += output

    # Pair consecutive entries as (user message, model reply) tuples for the chatbot.
    chat = [(history[i], history[i + 1]) for i in range(0, len(history) - 1, 2)]

    return chat, history


# image source: https://m.facebook.com/112483753737319/photos/112489593736735/
endpoint = Endpoint()

examples = [
    ["house.png", "How could someone get out of the house?"],
    ["sunset.png", "Write a romantic message that goes along this photo."],
]

# outputs = ["chatbot", "state"]

title = """
Disclaimer: This is a research prototype and is not intended for production use. No data, including but not restricted to text and images, is collected.
"""

# NOTE: `description` is passed to gr.Markdown below but is never defined in this
# file. The empty placeholder is an assumption that keeps the script from raising
# NameError; replace it with the intended description text.
description = ""

article = "BLIP-2: Bootstrapping Language-Image Pre-training with Frozen Image Encoders and Large Language Models"

# iface = gr.Interface(inference, inputs, outputs, title=title, description=description, article=article, examples=examples)


def reset_all(text_input, image_input, chatbot, history):
    """Clear the text box, image, chat display, and conversation history."""
    return "", None, None, []


def reset_chatbot(chatbot, history):
    """Clear the chat display and history, e.g. when the image changes."""
    return None, []


with gr.Blocks() as iface:
    state = gr.State([])

    gr.Markdown(title)
    gr.Markdown(description)
    gr.Markdown(article)

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil")
            text_input = gr.Textbox(lines=2, label="Text input")

            sampling = gr.Radio(
                choices=["Beam search", "Nucleus sampling"],
                value="Beam search",
                label="Text Decoding Method",
                interactive=True,
            )

            with gr.Row():
                temperature = gr.Slider(
                    minimum=0.5,
                    maximum=1.0,
                    value=0.8,
                    interactive=True,
                    label="Temperature",
                )

                len_penalty = gr.Slider(
                    minimum=-2.0,
                    maximum=2.0,
                    value=1.0,
                    step=0.5,
                    interactive=True,
                    label="Length Penalty",
                )

                rep_penalty = gr.Slider(
                    minimum=1.0,
                    maximum=10.0,
                    value=1.0,
                    step=0.5,
                    interactive=True,
                    label="Repetition Penalty",
                )

        with gr.Column():
            chatbot = gr.Chatbot()

            with gr.Row():
                clear_button = gr.Button(value="Clear", interactive=True)
                clear_button.click(
                    reset_all,
                    [text_input, image_input, chatbot, state],
                    [text_input, image_input, chatbot, state],
                )

                submit_button = gr.Button(value="Submit", interactive=True, variant="primary")
                # Inputs must match the positional order of inference(); rep_penalty
                # was missing, which passed `state` as the repetition penalty.
                submit_button.click(
                    inference,
                    [
                        image_input,
                        text_input,
                        sampling,
                        temperature,
                        len_penalty,
                        rep_penalty,
                        state,
                    ],
                    [chatbot, state],
                )

            # Start a fresh conversation whenever the image changes.
            image_input.change(reset_chatbot, [chatbot, state], [chatbot, state])

    examples = gr.Examples(
        examples=examples,
        inputs=[image_input, text_input],
    )

iface.queue(concurrency_count=1)
iface.launch(enable_queue=True, debug=True)