import io
import os
import json
import base64
import random
import numpy as np
import pandas as pd
import gradio as gr
from pathlib import Path
from PIL import Image

from plots import get_pre_define_colors
from utils.load_model import load_xclip
from utils.predict import xclip_pred


DEVICE = "cpu"
XCLIP, OWLVIT_PRECESSOR = load_xclip(DEVICE)
XCLIP_DESC_PATH = "data/jsons/bs_cub_desc.json"
XCLIP_DESC = json.load(open(XCLIP_DESC_PATH, "r"))
PREPROCESS = lambda x: OWLVIT_PRECESSOR(images=x, return_tensors='pt')
IMAGES_FOLDER = "data/images"
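
# NOTE: update_selected_image() below looks up ground-truth labels in XCLIP_RESULTS,
# whose definition is missing from this file. The load below is a hypothetical
# placeholder (the path is an assumption); it should map each image file name to a
# dict containing at least a 'ground_truth' key.
XCLIP_RESULTS = json.load(open("data/jsons/xclip_results.json", "r"))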

IMAGE_FILE_LIST = json.load(open("data/jsons/file_list.json", "r"))
IMAGE_GALLERY = [Image.open(os.path.join(IMAGES_FOLDER, 'org', file_name)).convert('RGB') for file_name in IMAGE_FILE_LIST]

ORG_PART_ORDER = ['back', 'beak', 'belly', 'breast', 'crown', 'forehead', 'eyes', 'legs', 'wings', 'nape', 'tail', 'throat']
ORDERED_PARTS = ['crown', 'forehead', 'nape', 'eyes', 'beak', 'throat', 'breast', 'belly', 'back', 'wings', 'legs', 'tail']
COLORS = get_pre_define_colors(12, cmap_set=['Set2', 'tab10'])
SACHIT_COLOR = "#ADD8E6"

VISIBILITY_DICT = json.load(open("data/jsons/cub_vis_dict_binary.json", 'r'))
VISIBILITY_DICT['Eastern_Bluebird.jpg'] = dict(zip(ORDERED_PARTS, [True]*12))


def img_to_base64(img):
    img_pil = Image.fromarray(img) if isinstance(img, np.ndarray) else img
    buffered = io.BytesIO()
    img_pil.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue())
    return img_str.decode()


def create_blank_image(width=500, height=500, color=(255, 255, 255)):
    """Create a blank image of the given size and color."""
    return np.array(Image.new("RGB", (width, height), color))


def rgb_to_hex(rgb):
    return f"#{''.join(f'{x:02x}' for x in rgb)}"


def load_part_images(file_name: str) -> dict:
    """Load the pre-cropped part images ("boxes") for one gallery image and
    return them as base64-encoded JPEGs keyed by part name."""
    part_images = {}
    for part_name in ORDERED_PARTS:
        base_name = Path(file_name).stem
        part_image_path = os.path.join(IMAGES_FOLDER, "boxes", f"{base_name}_{part_name}.jpg")
        if not Path(part_image_path).exists():
            continue
        image = np.array(Image.open(part_image_path))
        part_images[part_name] = img_to_base64(image)
    return part_images


def generate_xclip_explanations(result_dict: dict, visibility: dict, part_mask: dict = dict(zip(ORDERED_PARTS, [1]*12))):
    """
    result_dict needs three keys: 'descriptions', 'pred_scores', 'file_name'
        descriptions: {part_name1: desc_1, part_name2: desc_2, ...}
        pred_scores: {part_name1: score_1, part_name2: score_2, ...}
        file_name: str
    """
    descriptions = result_dict['descriptions']
    image_name = result_dict['file_name']
    part_images = PART_IMAGES_DICT[image_name]
    MAX_LENGTH = 50
    exp_length = 400
    fontsize = 15

    svg_parts = [f'<div style="width: {exp_length}px; height: 450px; background-color: white;">',
                 "<svg width=\"100%\" height=\"100%\">"]

    y_offset = 0
    for part in ORDERED_PARTS:
        if visibility[part] and part_mask[part]:
            part_score = max(result_dict['pred_scores'][part], 0)
            bar_length = part_score * exp_length

            # Hovering a description (or its score bar) shows the corresponding part
            # crop in the overlay image; moving the mouse away hides it again.
            mouseover_action1 = f"document.getElementById('overlayImage').src = 'data:image/jpeg;base64,{part_images[part]}'; document.getElementById('overlayImage').style.opacity = 1;"
            mouseout_action1 = "document.getElementById('overlayImage').style.opacity = 0;"

            combined_mouseover = f"javascript: {mouseover_action1};"
            combined_mouseout = f"javascript: {mouseout_action1};"

            # Wrap the description into lines of at most MAX_LENGTH characters.
            num_lines = len(descriptions[part]) // MAX_LENGTH + 1
            for line in range(num_lines):
                desc_line = descriptions[part][line*MAX_LENGTH:(line+1)*MAX_LENGTH]
                y_offset += fontsize
                svg_parts.append(f"""
                    <text x="0" y="{y_offset}" font-size="{fontsize}"
                        onmouseover="{combined_mouseover}"
                        onmouseout="{combined_mouseout}">
                        {desc_line}
                    </text>
                """)

            # Score bar and numeric score, colored per part.
            svg_parts.append(f"""
                <rect x="0" y="{y_offset + 3}" width="{bar_length}" height="{fontsize*0.7}" fill="{PART_COLORS[part]}"
                    onmouseover="{combined_mouseover}"
                    onmouseout="{combined_mouseout}">
                </rect>
            """)
            svg_parts.append(f'<text x="{exp_length - 50}" y="{y_offset+fontsize+3}" font-size="{fontsize}" fill="{PART_COLORS[part]}">{part_score:.2f}</text>')

            y_offset += fontsize + 3
    svg_parts.extend(("</svg>", "</div>"))

    html = "".join(svg_parts)
    return html


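# Renders a bar-style explanation for a flat list of (score, description) pairs
# (Sachit-style descriptions). Kept for completeness; it is not wired into the
# Gradio UI below.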
def generate_sachit_explanations(result_dict: dict):
    descriptions = result_dict['descriptions']
    scores = result_dict['scores']
    MAX_LENGTH = 50
    exp_length = 400
    fontsize = 15

    # Sort description lines by score, highest first.
    descriptions = zip(scores, descriptions)
    descriptions = sorted(descriptions, key=lambda x: x[0], reverse=True)

    svg_parts = [f'<div style="width: {exp_length}px; height: 450px; background-color: white;">',
                 "<svg width=\"100%\" height=\"100%\">"]

    y_offset = 0
    for score, desc in descriptions:
        part_score = max(score, 0)
        bar_length = part_score * exp_length

        num_lines = len(desc) // MAX_LENGTH + 1
        for line in range(num_lines):
            desc_line = desc[line*MAX_LENGTH:(line+1)*MAX_LENGTH]
            y_offset += fontsize
            svg_parts.append(f"""
                <text x="0" y="{y_offset}" font-size="{fontsize}" fill="black">
                    {desc_line}
                </text>
            """)

        svg_parts.append(f"""
            <rect x="0" y="{y_offset+3}" width="{bar_length}" height="{fontsize*0.7}" fill="{SACHIT_COLOR}">
            </rect>
        """)

        svg_parts.append(f'<text x="{exp_length - 50}" y="{y_offset+fontsize+3}" font-size="{fontsize}" fill="{SACHIT_COLOR}">{part_score:.2f}</text>')

        y_offset += fontsize + 3

    svg_parts.extend(("</svg>", "</div>"))

    html = "".join(svg_parts)
    return html


BLANK_OVERLAY = img_to_base64(create_blank_image())
PART_COLORS = {part: rgb_to_hex(COLORS[i]) for i, part in enumerate(ORDERED_PARTS)}
blank_image = np.array(Image.open('data/images/final.png').convert('RGB'))
PART_IMAGES_DICT = {file_name: load_part_images(file_name) for file_name in IMAGE_FILE_LIST}
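
# PART_IMAGES_DICT (above) caches a base64-encoded crop of every available part for
# each gallery image, so the hover overlays can be injected as data URIs without
# reloading files at interaction time.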


def update_selected_image(event: gr.SelectData):
    image_height = 400
    index = event.index

    image_name = IMAGE_FILE_LIST[index]
    current_image.state = image_name
    org_image = Image.open(os.path.join(IMAGES_FOLDER, 'org', image_name)).convert('RGB')
    # The bird image plus an initially transparent overlay; the SVG explanations
    # reveal the overlay with the hovered part crop.
    img_base64 = f"""
    <div style="position: relative; height: {image_height}px; display: inline-block;">
        <img id="birdImage" src="data:image/jpeg;base64,{img_to_base64(org_image)}" style="height: {image_height}px; width: auto;">
        <img id="overlayImage" src="data:image/jpeg;base64,{BLANK_OVERLAY}" style="position:absolute; top:0; left:0; width:auto; height: {image_height}px; opacity: 0;">
    </div>
    """
    gt_label = XCLIP_RESULTS[image_name]['ground_truth']
    gt_class.state = gt_label

    # Run the model with the original (unedited) descriptions.
    out_dict = xclip_pred(new_desc=None, new_part_mask=None, new_class=None,
                          org_desc=XCLIP_DESC_PATH,
                          image=Image.open(os.path.join(IMAGES_FOLDER, 'org', current_image.state)).convert('RGB'),
                          model=XCLIP, owlvit_processor=OWLVIT_PRECESSOR, device=DEVICE,
                          image_name=current_image.state)
    xclip_label = out_dict['pred_class']
    clip_pred_scores = out_dict['pred_score']
    xclip_part_scores = out_dict['pred_desc_scores']
    result_dict = {'descriptions': dict(zip(ORG_PART_ORDER, out_dict["descriptions"])), 'pred_scores': xclip_part_scores, 'file_name': current_image.state}
    xclip_exp = generate_xclip_explanations(result_dict, VISIBILITY_DICT[current_image.state], part_mask=dict(zip(ORDERED_PARTS, [1]*12)))

    xclip_color = "green" if xclip_label.strip() == gt_label.strip() else "red"
    xclip_pred_markdown = f"""
    ### <span style='color:{xclip_color}'>XCLIP: {xclip_label} {clip_pred_scores:.4f}</span>
    """

    gt_label = f"""
    ## {gt_label}
    """
    current_predicted_class.state = xclip_label

    # Pre-fill the editable textbox with the descriptions of the predicted class.
    custom_class_name = "class name: custom"
    descs = XCLIP_DESC[xclip_label]
    descs = {k: descs[i] for i, k in enumerate(ORG_PART_ORDER)}
    descs = {k: descs[k] for k in ORDERED_PARTS}
    custom_text = [custom_class_name] + list(descs.values())
    descriptions = ";\n".join(custom_text)
    textbox = gr.Textbox.update(value=descriptions, lines=12, visible=True, label="XCLIP descriptions", interactive=True, info='Please use ";" to separate the descriptions for each part, and keep the format of {part name}: {descriptions}', show_label=False)

    return gt_label, img_base64, xclip_pred_markdown, xclip_exp, current_image, textbox


def on_edit_button_click_xclip():
    # Hide the custom explanation and re-show the textbox, pre-filled with the
    # descriptions of the currently predicted class.
    empty_exp = gr.HTML.update(visible=False)

    descs = XCLIP_DESC[current_predicted_class.state]
    descs = {k: descs[i] for i, k in enumerate(ORG_PART_ORDER)}
    descs = {k: descs[k] for k in ORDERED_PARTS}
    custom_text = ["class name: custom"] + list(descs.values())
    descriptions = ";\n".join(custom_text)
    textbox = gr.Textbox.update(value=descriptions, lines=12, visible=True, label="XCLIP descriptions", interactive=True, info='Please use ";" to separate the descriptions for each part, and keep the format of {part name}: {descriptions}', show_label=False)

    return textbox, empty_exp


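# The textbox contents are parsed back with the same ";\n"-separated format the UI
# pre-fills, e.g. (illustrative):
#   class name: custom;
#   crown: a blue crown;
#   forehead: a blue forehead;
#   ...
# Parts left out of the text get an empty description and a part_mask entry of 0.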
def convert_input_text_to_xclip_format(textbox_input: str):
    descriptions_list = textbox_input.split(";\n")

    # The first line carries the custom class name ("class name: <name>").
    class_name_line = descriptions_list[0]
    new_class_name = class_name_line.split(":")[1].strip()

    descriptions_list = descriptions_list[1:]

    descriptions_dict = {}
    for desc in descriptions_list:
        if desc.strip() == "":
            continue
        # Split on the first colon only, so descriptions may contain colons themselves.
        part_name, _ = desc.split(":", 1)
        descriptions_dict[part_name.strip()] = desc

    # Mask out parts the user removed from the textbox.
    part_mask = {}
    for part in ORDERED_PARTS:
        if part not in descriptions_dict:
            descriptions_dict[part] = ""
            part_mask[part] = 0
        else:
            part_mask[part] = 1
    return descriptions_dict, part_mask, new_class_name


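# A single xclip_pred call scores both sets of descriptions: the original prediction
# is exposed under the 'pred_*' keys and the user-edited one under 'modified_*'.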
def on_predict_button_click_xclip(textbox_input: str):
    descriptions_dict, part_mask, new_class_name = convert_input_text_to_xclip_format(textbox_input)

    out_dict = xclip_pred(new_desc=descriptions_dict, new_part_mask=part_mask, new_class=new_class_name,
                          org_desc=XCLIP_DESC_PATH,
                          image=Image.open(os.path.join(IMAGES_FOLDER, 'org', current_image.state)).convert('RGB'),
                          model=XCLIP, owlvit_processor=OWLVIT_PRECESSOR, device=DEVICE,
                          image_name=current_image.state)
    xclip_label = out_dict['pred_class']
    xclip_pred_score = out_dict['pred_score']
    xclip_part_scores = out_dict['pred_desc_scores']
    custom_label = out_dict['modified_class']
    custom_pred_score = out_dict['modified_score']
    custom_part_scores = out_dict['modified_desc_scores']

    result_dict = {'descriptions': dict(zip(ORG_PART_ORDER, out_dict["descriptions"])), 'pred_scores': xclip_part_scores, 'file_name': current_image.state}
    xclip_explanation = generate_xclip_explanations(result_dict, VISIBILITY_DICT[current_image.state], part_mask)
    modified_result_dict = {'descriptions': dict(zip(ORG_PART_ORDER, out_dict["modified_descriptions"])), 'pred_scores': custom_part_scores, 'file_name': current_image.state}
    modified_explanation = generate_xclip_explanations(modified_result_dict, VISIBILITY_DICT[current_image.state], part_mask)

    xclip_color = "green" if xclip_label.strip() == gt_class.state.strip() else "red"
    xclip_pred_markdown = f"""
    ### <span style='color:{xclip_color}'>XCLIP: {xclip_label} {xclip_pred_score:.4f}</span>
    """
    custom_color = "green" if custom_label.strip() == gt_class.state.strip() else "red"
    custom_pred_markdown = f"""
    ### <span style='color:{custom_color}'>XCLIP: {custom_label} {custom_pred_score:.4f}</span>
    """
    textbox = gr.Textbox.update(visible=False)

    modified_exp = gr.HTML.update(value=modified_explanation, visible=True)
    return textbox, xclip_pred_markdown, xclip_explanation, custom_pred_markdown, modified_exp


custom_css = """
html, body {
    margin: 0;
    padding: 0;
}

#container {
    position: relative;
    width: 400px;
    height: 400px;
    border: 1px solid #000;
    margin: 0 auto; /* This will center the container horizontally */
}

#canvas {
    position: absolute;
    top: 0;
    left: 0;
    width: 100%;
    height: 100%;
    object-fit: cover;
}
"""


with gr.Blocks(theme=gr.themes.Soft(), css=custom_css, title="PEEB") as demo:
    current_image = gr.State("")
    current_predicted_class = gr.State("")
    gt_class = gr.State("")

    with gr.Column():
        title_text = gr.Markdown("# PEEB - demo")
        gr.Markdown(
            "- In this demo, you can edit the descriptions of a class and see how the model reacts to them."
        )

    with gr.Column():
        gr.Markdown("## Select an image to start!")
        image_gallery = gr.Gallery(value=IMAGE_GALLERY, label=None, preview=False, allow_preview=False, columns=10, height=250)
        gr.Markdown("### Custom descriptions: \n The first row should be **class name: {some name};**, where you can name your descriptions. \n For the remaining descriptions, please use **;** to separate the descriptions for each part, and use the format **{part name}: {descriptions}**. \n Note that you can delete a part entirely; in that case, the corresponding part is removed from all explanations.")

    with gr.Row():
        with gr.Column():
            image_label = gr.Markdown("### Class Name")
            org_image = gr.HTML()

        with gr.Column():
            with gr.Row():
                xclip_predict_button = gr.Button(value="Predict")
                xclip_pred_label = gr.Markdown("### XCLIP:")
            xclip_explanation = gr.HTML()

        with gr.Column():
            xclip_edit_button = gr.Button(value="Reset Descriptions")
            custom_pred_label = gr.Markdown(
                "### Custom Descriptions:"
            )
            xclip_textbox = gr.Textbox(lines=12, placeholder="Edit the descriptions here", visible=False)

            custom_explanation = gr.HTML()

    gr.HTML("<br>")

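    # Event wiring: selecting a gallery image fills the viewer, prediction, and textbox;
    # the edit button resets the textbox, and the predict button re-runs the model with
    # the edited descriptions.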
    image_gallery.select(update_selected_image, inputs=None, outputs=[image_label, org_image, xclip_pred_label, xclip_explanation, current_image, xclip_textbox])
    xclip_edit_button.click(on_edit_button_click_xclip, inputs=[], outputs=[xclip_textbox, custom_explanation])
    xclip_predict_button.click(on_predict_button_click_xclip, inputs=[xclip_textbox], outputs=[xclip_textbox, xclip_pred_label, xclip_explanation, custom_pred_label, custom_explanation])


demo.launch(server_port=5000, share=True)