import gradio as gr
import os
import base64
from io import BytesIO
from mistralai import Mistral
from pydantic import BaseModel, Field, ValidationError
from datasets import load_dataset
from PIL import Image
import json
import sqlite3
from datetime import datetime, timezone
from pymongo import MongoClient

# Load the dataset; its first rows serve as few-shot caption examples.
ds = load_dataset("svjack/pokemon-blip-captions-en-zh")
ds = ds["train"]

# Load MongoDB client (feedback storage)
MONGO_URI = os.environ.get('MONGO_URI')
if not MONGO_URI:
    raise ValueError("MONGO_URI is not set in the environment variables.")

# Kept under its own name so the Mistral client created below does not
# shadow (and orphan) the MongoDB connection.
mongo_client = MongoClient(MONGO_URI)
db = mongo_client['capimg']  # Choose a database name
collection = db['feedback']  # Choose a collection name

# Load environment variables
api_key = os.environ.get('MISTRAL_API_KEY')
if not api_key:
    raise ValueError("MISTRAL_API_KEY is not set in the environment variables.")

# Create sample history. Each example is serialized as real JSON (double
# quotes, Chinese kept readable) so it matches the JSON output format the
# model is asked to produce; bounded by the dataset length.
hist = [
    json.dumps({"en": ds[i]["en_text"], "zh": ds[i]["zh_text"]}, ensure_ascii=False)
    for i in range(min(8, len(ds)))
]
hist_str = "\n".join(hist)


# Define the Caption model
class Caption(BaseModel):
    """Bilingual caption expected back from the vision model."""

    en: str = Field(..., description="English caption of image", max_length=84)
    zh: str = Field(..., description="Chinese caption of image", max_length=64)


# Initialize the Mistral client
client = Mistral(api_key=api_key)


def _image_to_base64(image):
    """Encode a PIL image as a base64 JPEG string.

    JPEG cannot hold alpha or palette data, so images in other modes
    (e.g. RGBA) are converted to RGB first; saving them directly raises
    ``OSError``.
    """
    if image.mode not in ("RGB", "L"):
        image = image.convert("RGB")
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode('utf-8')


def generate_caption(image):
    """Ask Pixtral for an English and a Chinese caption of ``image``.

    Args:
        image: PIL image to describe.

    Returns:
        A validated ``Caption``, or ``None`` if the API call fails or the
        reply is not JSON matching the ``Caption`` schema.
    """
    base64_image = _image_to_base64(image)

    messages = [
        {
            "role": "system",
            "content": f'''
You are a highly accurate image to caption transformer.
Describe the image content in English and Chinese respectively. Make sure to FOCUS on item CATEGORY and COLOR!
Do NOT provide NAMES! KEEP it SHORT!
While adhering to the following JSON schema: {Caption.model_json_schema()}
Following are some samples you should adhere to for style and tone:
{hist_str}
'''
        },
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": "Describe the image in English and Chinese"
                },
                {
                    "type": "image_url",
                    "image_url": f"data:image/jpeg;base64,{base64_image}"
                }
            ]
        }
    ]

    try:
        chat_response = client.chat.complete(
            model="pixtral-12b-2409",
            messages=messages,
            response_format={
                "type": "json_object",
            },
        )
    except Exception as e:  # UI boundary: network/auth/SDK errors -> user message
        print(f"Error calling Mistral API: {e}")
        return None

    response_content = chat_response.choices[0].message.content
    try:
        # Checks both JSON syntax and the schema (required keys, max_length),
        # so a malformed or over-long reply no longer crashes the UI.
        return Caption.model_validate_json(response_content)
    except ValidationError as e:
        print(f"Error decoding JSON: {e}")
        return None


# Persist user feedback to MongoDB
def save_feedback(image, caption):
    """Store an image/caption pair the user liked in the feedback collection.

    Args:
        image: PIL image shown to the user, or ``None`` if nothing is uploaded.
        caption: caption text currently displayed in the output box.

    Returns:
        ``None``; the user is notified through a Gradio toast.
    """
    if image is None or not caption:
        gr.Warning("Please upload an image and generate a caption first.")
        return None

    feedback_entry = {
        # Timezone-aware UTC: PyMongo treats naive datetimes as UTC, so a
        # naive local time would be stored with the wrong offset.
        "timestamp": datetime.now(timezone.utc),
        "input_data": _image_to_base64(image),
        "output_data": caption,
    }
    result = collection.insert_one(feedback_entry)
    print(f"Feedback saved with id: {result.inserted_id}")
    gr.Info("Thanks for your feedback!")
    return None


def process_image(image):
    """Caption ``image`` and format the result for the output textbox."""
    if image is None:
        return "Please upload an image first."

    result = generate_caption(image)
    if result:
        return f"English caption: {result.en}\nChinese caption: {result.zh}"
    return "Failed to generate caption. Please check the API call or network connectivity."
def thumbs_up(image, caption):
    """Record the current image/caption pair as positive feedback."""
    return save_feedback(image, caption)


with gr.Blocks() as iface:
    gr.Markdown("# Image Captioner")
    gr.Markdown("Upload an image to generate captions in English and Chinese.")
    gr.Markdown("Use the 'Thumbs Up' button if you like the result!!")

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(type="pil")
            with gr.Row():
                clear_btn = gr.Button("Clear")
                submit_btn = gr.Button("Submit")

        with gr.Column(scale=1):
            output_text = gr.Textbox()
            thumbs_up_btn = gr.Button("Thumbs Up")

    # Reset the caption together with the image so a stale caption cannot be
    # submitted as feedback for an image that is no longer shown.
    clear_btn.click(fn=lambda: (None, ""), inputs=None, outputs=[input_image, output_text])
    submit_btn.click(fn=process_image, inputs=input_image, outputs=output_text)
    thumbs_up_btn.click(fn=thumbs_up, inputs=[input_image, output_text], outputs=None)

# Launch the interface only when run as a script (e.g. `python app.py`),
# not when the module is imported.
if __name__ == "__main__":
    iface.launch()