# Hugging Face Space: Running on Zero (ZeroGPU hardware)
import base64
import io
import json
import os

import gradio as gr
import matplotlib.pyplot as plt
import spaces
import torch
from huggingface_hub import login
from matplotlib.colors import is_color_like
from matplotlib.figure import Figure
from matplotlib.patches import Rectangle
from PIL import Image
from transformers import MllamaForConditionalGeneration, AutoProcessor
def check_environment():
    """Ensure every required environment variable is set to a non-empty value.

    Raises:
        ValueError: If ``HF_TOKEN`` is unset, empty, or whitespace-only. The
            message lists the missing variable names.
    """
    required_vars = ["HF_TOKEN"]
    # An empty value is as unusable as a missing one: the token is passed
    # straight to login(), which cannot authenticate with "".
    missing_vars = [
        var for var in required_vars if not os.environ.get(var, "").strip()
    ]
    if missing_vars:
        raise ValueError(
            f"Missing required environment variables: {', '.join(missing_vars)}\n"
            "Please set the HF_TOKEN environment variable with your Hugging Face token"
        )
# Fail fast when the token is absent, then authenticate with Hugging Face.
check_environment()
login(token=os.environ["HF_TOKEN"], add_to_git_credential=True)

# Load the processor and weights once at import time; every request reuses
# these module-level objects instead of reloading the model.
base_model_path = "taesiri/BugsBunny-LLama-3.2-11B-Vision-BaseCaptioner-Medium-FullModel"
processor = AutoProcessor.from_pretrained(base_model_path)
model = MllamaForConditionalGeneration.from_pretrained(
    base_model_path,
    device_map="cuda",
    torch_dtype=torch.bfloat16,
)
model.tie_weights()
def describe_image_in_JSON(json_string):
    """Parse the model's JSON answer, accepting single or double encoding.

    The answer may arrive as a JSON string literal whose content is itself
    JSON. In that case the first decode yields text and a second decode
    yields the data. A plain JSON object is returned after one decode,
    instead of crashing on the second ``json.loads`` call.

    Args:
        json_string: Text extracted from the model response.

    Returns:
        The decoded value (normally a dict) on success, or a string starting
        with ``"Error parsing JSON: "`` if a decode fails.
    """
    try:
        decoded = json.loads(json_string)
        # Double-encoded payload: the first decode produced the inner JSON text.
        if isinstance(decoded, str):
            decoded = json.loads(decoded)
        return decoded
    except json.JSONDecodeError as e:
        return f"Error parsing JSON: {str(e)}"
def create_color_palette_image(colors):
    """Render a list of hex colors as a horizontal strip of swatches.

    Args:
        colors: List of color strings, each starting with ``"#"``
            (e.g. ``"#1a2b3c"``).

    Returns:
        A matplotlib ``Figure`` with one unit-wide swatch per color. Returns
        ``None`` in any of these cases:
            - ``colors`` is empty or not a list.
            - An entry is not a ``"#"``-prefixed string.
            - An entry is not a color matplotlib recognizes.
            - Rendering fails.
    """
    if not colors or not isinstance(colors, list):
        return None
    for color in colors:
        if not isinstance(color, str) or not color.startswith("#"):
            return None
    # A bare "#" or "#zz" passes the prefix test but is not a color; reject it
    # here instead of letting the renderer fail halfway through.
    if not all(is_color_like(color) for color in colors):
        return None

    try:
        # A bare Figure (unlike plt.subplots) is not registered with pyplot's
        # global figure manager, so per-request figures are not kept open
        # forever by pyplot.
        fig = Figure(figsize=(10, 2))
        ax = fig.subplots()
        for i, color in enumerate(colors):
            ax.add_patch(Rectangle((i, 0), 1, 1, facecolor=color))
        ax.set_xlim(0, len(colors))
        ax.set_ylim(0, 1)
        ax.set_xticks([])
        ax.set_yticks([])
        return fig  # returned as-is for the gr.Plot component
    except Exception as e:
        print(f"Error creating color palette: {e}")
        return None
def _error_outputs(message, error_detail=None, raw_output=""):
    """Build the 10-value result that ``inference`` returns on failure.

    The values line up with the ``outputs`` list wired to the Analyze button:
        - The three text fields show ``message``.
        - The ``gr.JSON`` and ``gr.Plot`` slots get ``None``. A free-form
          message string is not a valid plot, and gr.JSON may try to parse a
          string value as JSON.
        - The raw-output box shows ``raw_output``.
        - The error box shows ``error_detail``, or ``message`` if no detail
          is given.
        - The status reads "Error".
    """
    return (
        message,  # description
        message,  # scene description
        None,  # characters (gr.JSON)
        None,  # objects (gr.JSON)
        None,  # texture details (gr.JSON)
        message,  # lighting details
        None,  # color palette (gr.Plot)
        raw_output,  # raw JSON response
        message if error_detail is None else error_detail,  # error box
        "Error",  # status
    )


@spaces.GPU  # on ZeroGPU hardware a GPU is attached only inside decorated calls
def inference(image):
    """Caption one image as JSON and map the fields onto the UI components.

    Args:
        image: A PIL image (the upload widget uses ``type="pil"``), or anything
            ``PIL.Image.fromarray`` accepts.

    Returns:
        Ten values, always, in the order of the button's ``outputs`` list:
            - description
            - scene description
            - characters (JSON text)
            - objects (JSON text)
            - texture details (JSON text)
            - lighting details
            - color-palette figure
            - raw JSON text
            - error message
            - status
    """
    if image is None:
        return _error_outputs("Please provide an image")
    if not isinstance(image, Image.Image):
        try:
            image = Image.fromarray(image)
        except Exception as e:
            print(f"Image conversion error: {e}")
            return _error_outputs("Invalid image format", error_detail=str(e))

    # Single-turn chat: one image placeholder plus the fixed instruction.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": "Describe the image in JSON"},
            ],
        }
    ]
    input_text = processor.apply_chat_template(messages, add_generation_prompt=True)

    try:
        # Preprocess image + prompt and move the tensors to the model's device.
        inputs = processor(
            image, input_text, add_special_tokens=False, return_tensors="pt"
        ).to(model.device)
        with torch.no_grad():
            output = model.generate(**inputs, max_new_tokens=2048)
    except Exception as e:
        print(f"Inference error: {e}")
        return _error_outputs("Error during inference", error_detail=str(e))
    finally:
        # Release cached CUDA memory whether or not generation succeeded.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    result = processor.decode(output[0], skip_special_tokens=True)
    print("DEBUG: Full decoded output:", result)

    # The decoded text still contains the prompt; the answer is everything
    # after the first "assistant\n" marker. partition keeps the whole tail,
    # even if the answer itself contains the marker again.
    _, sep, answer = result.strip().partition("assistant\n")
    if not sep:
        print("DEBUG: Error splitting response: 'assistant\\n' marker not found")
        return _error_outputs(
            "Error extracting JSON from response",
            error_detail="Failed to extract JSON",
            raw_output=result,
        )
    json_str = answer.strip()
    print("DEBUG: Extracted JSON string after split:", json_str)

    parsed_json = describe_image_in_JSON(json_str)
    # describe_image_in_JSON returns an error *string* when decoding fails.
    # Only a non-empty object can fill the fields below.
    if not isinstance(parsed_json, dict) or not parsed_json:
        is_parse_error = isinstance(parsed_json, str) and parsed_json.startswith(
            "Error parsing JSON"
        )
        return _error_outputs(
            "Error parsing response",
            error_detail=parsed_json if is_parse_error else "Failed to parse JSON",
            raw_output=json_str,
        )

    colors = parsed_json.get("color_palette", [])
    color_image = create_color_palette_image(colors)
    # The list fields are serialized to JSON text for the gr.JSON components.
    character_list = json.dumps(parsed_json.get("character_list", []))
    object_list = json.dumps(parsed_json.get("object_list", []))
    texture_details = json.dumps(parsed_json.get("texture_details", []))
    return (
        parsed_json.get("description", "Not available"),
        parsed_json.get("scene_description", "Not available"),
        character_list,
        object_list,
        texture_details,
        parsed_json.get("lighting_details", "Not available"),
        color_image,
        json_str,
        "",  # Error box
        "Analysis complete",  # Status
    )
# Gradio UI: upload column and trigger on the left, result tabs beside it.
with gr.Blocks() as demo:
    gr.Markdown("# BugsBunny-LLama-3.2-11B-Base-Medium Demo")

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Image", elem_id="large-image")
            submit_btn = gr.Button("Analyze Image", variant="primary")

        with gr.Tabs():
            with gr.Tab("Structured Results"):
                with gr.Column(scale=1):
                    description_output = gr.Textbox(label="Description", lines=4)
                    scene_output = gr.Textbox(label="Scene Description", lines=2)
                    characters_output = gr.JSON(label="Characters")
                    objects_output = gr.JSON(label="Objects")
                    textures_output = gr.JSON(label="Texture Details")
                    lighting_output = gr.Textbox(label="Lighting Details", lines=2)
                    color_palette_output = gr.Plot(label="Color Palette")

            with gr.Tab("Raw Output"):
                raw_output = gr.Textbox(label="Raw JSON Response", lines=25, max_lines=30)

    # Target for the error-message slot of the results; hidden on the page.
    error_box = gr.Textbox(label="Error Messages", visible=False)

    with gr.Row():
        status_text = gr.Textbox(label="Status", value="Ready", interactive=False)

    # inference()'s success path returns one value per component, in this order.
    analysis_outputs = [
        description_output,
        scene_output,
        characters_output,
        objects_output,
        textures_output,
        lighting_output,
        color_palette_output,
        raw_output,
        error_box,
        status_text,
    ]
    submit_btn.click(
        fn=inference,
        inputs=[image_input],
        outputs=analysis_outputs,
        api_name="analyze",
    )

demo.launch(share=True)