SauravMaheshkar's picture
feat: initial commit
bf29adc unverified
raw
history blame
2.66 kB
from functools import lru_cache

import cv2
import gradio as gr
import numpy as np
from gradio_image_annotation import image_annotator
from PIL import Image
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

from src.plot_utils import show_masks
# Lookup table for the model dropdown: checkpoint name -> [config file, checkpoint path].
# The two-item lists are unpacked positionally by callers, so the order inside
# each list (config first, checkpoint second) must be kept.
choice_mapping = dict(
    tiny=["sam2_hiera_t.yaml", "assets/checkpoints/sam2_hiera_tiny.pt"],
    small=["sam2_hiera_s.yaml", "assets/checkpoints/sam2_hiera_small.pt"],
    base_plus=["sam2_hiera_b+.yaml", "assets/checkpoints/sam2_hiera_base_plus.pt"],
    large=["sam2_hiera_l.yaml", "assets/checkpoints/sam2_hiera_large.pt"],
)
@lru_cache(maxsize=1)
def _load_sam2_model(model_choice: str):
    """Build the SAM2 model for ``model_choice``, caching the most recent one.

    Building the model from its checkpoint is the expensive part of a
    request, so repeated clicks with the same checkpoint reuse the cached
    instance. Only one model is retained to keep memory use bounded.

    Raises:
        KeyError: If ``model_choice`` is not a key of ``choice_mapping``.
    """
    config_file, ckpt_path = choice_mapping[model_choice]
    return build_sam2(config_file, ckpt_path, device="cpu")


def predict(model_choice: str, annotations, image):
    """Segment the region inside the first annotated bounding box of ``image``.

    Args:
        model_choice: Key of ``choice_mapping`` selecting the config file and
            checkpoint to load.
        annotations: Value of the annotator component; ``annotations["boxes"][0]``
            must provide ``xmin``, ``ymin``, ``xmax`` and ``ymax``.
        image: Image handed to ``SAM2ImagePredictor.set_image``.

    Returns:
        A two-item list: the result of ``show_masks`` for the predicted mask,
        and a visible ``gr.DownloadButton`` pointing at the saved ``mask.png``.

    Raises:
        gr.Error: If no image or no bounding box was supplied, or the mask
            file could not be written.
        KeyError: If ``model_choice`` is not a key of ``choice_mapping``.
    """
    # Fail fast with a readable message instead of an IndexError/TypeError
    # surfacing from deep inside the handler.
    if image is None:
        raise gr.Error("Please upload an image first.")
    boxes = annotations.get("boxes") if annotations else None
    if not boxes:
        raise gr.Error("Please draw a bounding box on the image first.")

    # The predictor keeps per-image state, so create a fresh one per request
    # and only share the (cached) model underneath it.
    predictor = SAM2ImagePredictor(_load_sam2_model(str(model_choice)))
    predictor.set_image(image)

    box = boxes[0]
    coordinates = np.array(
        [int(box[key]) for key in ("xmin", "ymin", "xmax", "ymax")]
    )

    masks, scores, _ = predictor.predict(
        point_coords=None,
        point_labels=None,
        box=coordinates[None, :],
        multimask_output=False,
    )

    # Move the leading axis last so OpenCV receives the mask as an image,
    # then scale it to the 0-255 range of an 8-bit PNG.
    mask = masks.transpose(1, 2, 0)
    mask_image = (mask * 255).astype(np.uint8)
    # NOTE(review): this fixed path is shared by every session — presumably
    # requests are processed one at a time; confirm before raising concurrency.
    if not cv2.imwrite("mask.png", mask_image):
        raise gr.Error("Could not save the segmentation mask.")

    return [
        show_masks(image, masks, scores, box_coords=coordinates),
        gr.DownloadButton("Download Mask", value="mask.png", visible=True),
    ]
def _annotator_value(image):
    """Return an annotator value that shows ``image`` with no boxes.

    ``None`` is passed through so that clearing the input image also clears
    the annotator.
    """
    return None if image is None else {"image": image}


# Three-step UI: choose a checkpoint, provide an image, draw a box, then
# request the segmentation mask.
with gr.Blocks(delete_cache=(30, 30)) as demo:
    gr.Markdown("\n# 1. Choose Model Checkpoint\n")
    with gr.Row():
        model = gr.Dropdown(
            choices=["tiny", "small", "base_plus", "large"],
            value="tiny",
            label="Model Checkpoint",
            info="Which model checkpoint to load?",
        )

    gr.Markdown("\n# 2. Upload an Image\n")
    with gr.Row():
        img = gr.Image(value="./assets/img.png", type="numpy", label="Input Image")

    gr.Markdown("\n# 3. Draw Bounding Box\n")
    # Seed the annotator with the same default image as the input component.
    annotator = image_annotator(
        value={"image": img.value["path"]},
        disable_edit_boxes=True,
        single_box=True,
        label="Draw a bounding box",
    )
    # Keep the annotator in sync with the input image; otherwise the box is
    # drawn on the stale default image while `predict` segments the new one.
    img.change(fn=_annotator_value, inputs=img, outputs=annotator)

    btn = gr.Button("Get Segmentation Mask")
    # Hidden until `predict` returns a visible replacement with the mask file.
    download_btn = gr.DownloadButton("Download Mask", value="mask.png", visible=False)
    # Created after the download button so it occupies the same layout slot
    # as before.
    mask_plot = gr.Plot()

    btn.click(
        fn=predict,
        inputs=[model, annotator, img],
        outputs=[mask_plot, download_btn],
    )

demo.launch()