|
import os |
|
import gradio as gr |
|
from together import Together |
|
from PIL import Image |
|
import io |
|
import base64 |
|
|
|
|
|
# Module-wide Together API client; the key is taken from the
# TOGETHER_API_KEY environment variable (None when it is unset).
client = Together(api_key=os.getenv("TOGETHER_API_KEY"))
|
|
|
def encode_image(image):
    """Return *image* serialized as PNG and encoded as a base64 string.

    Args:
        image: Object exposing ``save(fp, format=...)`` (a PIL image at the
            call site in this file).

    Returns:
        The base64 text of the PNG bytes.
    """
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    encoded = base64.b64encode(png_bytes)
    return encoded.decode('utf-8')
|
|
|
def chat_with_image(message, image, history):
    """Stream a chat completion for *message*, optionally with an image.

    Args:
        message: Text of the current user turn.
        image: Something ``PIL.Image.open`` accepts (the UI passes a file
            path), or ``None`` for a text-only turn.
        history: Earlier turns as ``(user_text, assistant_text)`` pairs.

    Yields:
        The reply accumulated so far, once per streamed chunk that carries
        text. Each yielded value is cumulative, not a delta.

    Returns:
        The complete reply, as the generator's ``StopIteration`` value.
    """
    if image is not None:
        # Close the underlying file as soon as the image has been re-encoded.
        with Image.open(image) as opened_image:
            encoded_image = encode_image(opened_image)
        image_message = {
            "role": "user",
            "content": [
                {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{encoded_image}"}},
                {"type": "text", "text": message},
            ],
        }
    else:
        image_message = {"role": "user", "content": message}

    messages = [{"role": "system", "content": "You are a helpful assistant."}]
    for human, assistant in history:
        # A turn can be left half-filled (reply still None) when an earlier
        # request failed; never forward a None content to the API.
        if human is not None:
            messages.append({"role": "user", "content": human})
        if assistant is not None:
            messages.append({"role": "assistant", "content": assistant})
    messages.append(image_message)

    response = client.chat.completions.create(
        model="meta-llama/Llama-Vision-Free",
        messages=messages,
        max_tokens=512,
        temperature=0.7,
        top_p=0.7,
        top_k=50,
        repetition_penalty=1,
        stop=["<|eot_id|>", "<|eom_id|>"],
        stream=True,
    )

    full_response = ""
    for chunk in response:
        # Guard against chunks that arrive without any choices, which would
        # otherwise raise on the [0] lookup.
        if not chunk.choices:
            continue
        delta_text = chunk.choices[0].delta.content
        if delta_text is not None:
            full_response += delta_text
            yield full_response

    return full_response
|
|
|
|
|
# Chat UI: a chatbot pane, a text box, an optional image input and a clear
# button. Submitting the text box first records the user's turn, then streams
# the model's reply into it.
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    image = gr.Image(type="filepath")
    clear = gr.Button("Clear")

    def user(user_message, image, history):
        """Record the user's turn.

        Returns the new textbox value (emptied), the image unchanged, and the
        history extended with ``[user_message, None]`` (reply pending).
        """
        # Treat a missing history (None) as an empty conversation.
        return "", image, (history or []) + [[user_message, None]]

    def bot(history, image_path=None):
        """Stream the assistant's reply into the last turn of *history*.

        Args:
            history: Chat turns; the last one holds the pending user message.
            image_path: Current value of the image input, or ``None``.

        Yields:
            The history, updated after each partial reply.
        """
        user_message = history[-1][0]
        history[-1][1] = ""
        for partial_reply in chat_with_image(user_message, image_path, history[:-1]):
            # chat_with_image yields the cumulative text so far, so replace
            # the reply instead of appending to it (appending duplicates it).
            history[-1][1] = partial_reply
            yield history

    # The image input is passed to `bot` explicitly; the chat history only
    # stores text, so it cannot be recovered from there.
    msg.submit(user, [msg, image, chatbot], [msg, image, chatbot], queue=False).then(
        bot, [chatbot, image], chatbot
    )
    clear.click(lambda: None, None, chatbot, queue=False)
|
|
|
# Enable the request queue, then start serving the app.
demo.queue().launch()