|
import os |
|
import cv2 |
|
import spaces |
|
import gradio as gr |
|
from PIL import Image |
|
from omegaconf import OmegaConf |
|
|
|
|
|
from utils.env_utils import set_random_seed, use_lower_vram |
|
from utils.timer_utils import Timer |
|
|
|
# Fix the random seed at import time, before the model-loading imports and
# init_model() call further down this module, so results are repeatable
# (presumably set_random_seed seeds Python/NumPy/torch RNGs — TODO confirm in
# utils.env_utils).
set_random_seed(1024)

# Module-wide timer, started immediately; process() later hands it to
# run_grounded_sam() (presumably for per-stage timing — TODO confirm in
# utils.timer_utils).
timer = Timer()
timer.start()
|
|
|
|
|
|
|
from utils.labels_utils import Labels |
|
from utils.ram_utils import ram_inference |
|
from utils.blip2_utils import blip2_caption |
|
from utils.llms_utils import pre_refinement, make_prompt, init_model |
|
from utils.grounded_sam_utils import run_grounded_sam |
|
|
|
|
|
|
|
# Thresholds forwarded unchanged to run_grounded_sam() by process().
box_threshold = 0.18
text_threshold = 0.15
iou_threshold = 0.8

# Tag-refinement LLM, loaded once when the module is imported and shared by
# every request handled by process().
llm = init_model("Meta-Llama-3-8B-Instruct")

# Per-dataset state, rebuilt by process() via load_config() only when the
# selected dataset setting differs from current_config. The empty string means
# nothing has been loaded yet, so the first request always triggers a load.
# (A `global` statement is not needed here: at module scope these names are
# already globals; process() declares `global` itself before rebinding them.)
current_config = ""
L = None
system_prompt = None
|
|
|
def load_config(config_type):
    """Build the label set and LLM system prompt for one dataset setting.

    Args:
        config_type: Base name of a YAML file in the ``configs`` directory next
            to this script (e.g. ``"COCO-81"`` -> ``configs/COCO-81.yaml``).

    Returns:
        A ``(labels, system_prompt)`` tuple: the ``Labels`` object built from the
        loaded config, and the prompt ``make_prompt`` produces from the
        comma-separated ``labels.LABELS``.
    """
    here = os.path.dirname(__file__)
    cfg_path = os.path.join(here, f"configs/{config_type}.yaml")
    labels = Labels(config=OmegaConf.load(cfg_path))
    label_list = ", ".join(labels.LABELS)
    return labels, make_prompt(label_list)
|
|
|
@spaces.GPU(duration=120)
def process(image_ori, config_type):
    """Tag, refine and segment one image, returning the mask visualisation.

    Pipeline: RAM tags plus a BLIP-2 caption are passed to the LLM
    (``pre_refinement``), the result is run through ``L.check_labels``, and the
    refined labels drive ``run_grounded_sam``; masks are drawn with
    ``L.draw_mask``.

    Args:
        image_ori: Image array from ``gr.Image(type="numpy")``, which Gradio
            delivers in RGB channel order.
        config_type: Dataset setting; selects ``configs/<config_type>.yaml``.

    Returns:
        The output of ``L.draw_mask`` converted with ``cv2.COLOR_BGR2RGB`` for
        display in Gradio.

    Raises:
        gr.Error: If no image was supplied.
    """
    global current_config, L, system_prompt

    if image_ori is None:
        raise gr.Error("Please upload or select an image first.")

    # Rebuild labels/prompt only when the dataset setting changes; the globals
    # act as a one-entry cache. current_config is updated only after a
    # successful load so a failed load is retried next time.
    if current_config != config_type:
        L, system_prompt = load_config(config_type)
        current_config = config_type

    # Gradio supplies RGB. The recognition models consume the PIL image, so it
    # must be built from the RGB array (previously it was built from a
    # channel-swapped copy). The drawing path keeps the BGR copy it always
    # received; its output is converted back to RGB on return.
    image_bgr = cv2.cvtColor(image_ori, cv2.COLOR_RGB2BGR)
    image_pil = Image.fromarray(image_ori)

    labels_ram = ram_inference(image_pil) + ": " + blip2_caption(image_pil)
    converted_labels, llm_output = pre_refinement([labels_ram], system_prompt, llm=llm)
    labels_llm = L.check_labels(converted_labels)[0]

    print("labels_ram: ", labels_ram)
    print("llm_output: ", llm_output)
    print("labels_llm: ", labels_llm)

    # Only the label map is used here; the remaining outputs are unpacked to
    # keep the arity check on run_grounded_sam's return value.
    label_res, _bboxes, _output_labels, _output_prob_maps, _output_points = run_grounded_sam(
        input_image={"image": image_pil, "mask": None},
        text_prompt=labels_llm,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        iou_threshold=iou_threshold,
        LABELS=L.LABELS,
        IDS=L.IDS,
        llm=llm,
        timer=timer,
    )

    ours = L.draw_mask(label_res, image_bgr, print_label=True, tag="Ours")
    return cv2.cvtColor(ours, cv2.COLOR_BGR2RGB)
|
|
|
|
|
if __name__ == "__main__":
    # Dataset settings offered in the UI; each value is passed to process() and
    # from there to load_config(), so it must name a configs/<value>.yaml file.
    dropdown_options = ["COCO-81", "Cityscapes", "DRAM", "VOC2012"]
    default_option = "COCO-81"

    with gr.Blocks() as demo:
        gr.HTML(
            """
            <h1 style="text-align: center; font-size: 32px; font-family: 'Times New Roman', Times, serif;">
            Training-Free Zero-Shot Semantic Segmentation with LLM Refinement
            </h1>
            <p style="text-align: center; font-size: 20px; font-family: 'Times New Roman', Times, serif;">
            <a style="text-align: center; display:inline-block"
            href="https://sky24h.github.io/websites/bmvc2024_training-free-semseg-with-LLM/">
            <img src="https://huggingface.co/datasets/huggingface/badges/raw/main/paper-page-sm.svg#center"
            alt="Paper Page">
            </a>
            <a style="text-align: center; display:inline-block" href="https://huggingface.co/spaces/sky24h/Training-Free_Zero-Shot_Semantic_Segmentation_with_LLM_Refinement?duplicate=true">
            <img src="https://huggingface.co/datasets/huggingface/badges/raw/main/duplicate-this-space-sm.svg#center" alt="Duplicate Space">
            </a>
            </p>
            """
        )
        gr.Interface(
            fn=process,
            inputs=[
                # An int height is rendered in pixels; a bare string like "384"
                # is used verbatim as a unitless CSS value and has no effect.
                gr.Image(type="numpy", height=384),
                gr.Dropdown(choices=dropdown_options, label="Refinement Type", value=default_option),
            ],
            outputs="image",
            description="""<html>
<p style="text-align:center;"> This is an online demo for the paper "Training-Free Zero-Shot Semantic Segmentation with LLM Refinement" (BMVC 2024). </p>
<p style="text-align:center;"> Usage: Please select or upload an image and choose a dataset setting for semantic segmentation refinement.</p>
</html>""",
            allow_flagging='never',
            examples=[
                ["examples/Cityscapes_eg.jpg", "Cityscapes"],
                ["examples/DRAM_eg.jpg", "DRAM"],
                ["examples/COCO-81_eg.jpg", "COCO-81"],
                ["examples/VOC2012_eg.jpg", "VOC2012"],
            ],
            # Runs process() on every example at startup so they load instantly.
            cache_examples=True,
        )

    demo.queue(max_size=10).launch()
|
|