"""Gradio demo: segment the object inside a user-drawn bounding box with SAM 2."""

import functools

import gradio as gr
import numpy as np
import cv2
from PIL import Image
from src.plot_utils import show_masks
from gradio_image_annotation import image_annotator
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Dropdown choice -> [SAM 2 config name, checkpoint path].
choice_mapping = {
    "tiny": ["sam2_hiera_t.yaml", "assets/checkpoints/sam2_hiera_tiny.pt"],
    "small": ["sam2_hiera_s.yaml", "assets/checkpoints/sam2_hiera_small.pt"],
    "base_plus": ["sam2_hiera_b+.yaml", "assets/checkpoints/sam2_hiera_base_plus.pt"],
    "large": ["sam2_hiera_l.yaml", "assets/checkpoints/sam2_hiera_large.pt"],
}

# Device every model is built on.
DEVICE = "cpu"

# File the predicted mask is written to and served from by the download button.
MASK_PATH = "mask.png"


@functools.lru_cache(maxsize=1)
def _load_model(model_choice: str):
    """Build (or reuse) the SAM 2 model for *model_choice*.

    Only the most recently used model is cached (``maxsize=1``), so repeated
    clicks with the same checkpoint skip the expensive rebuild while switching
    checkpoints does not keep several large models resident at once.
    """
    config_file, ckpt_path = choice_mapping[model_choice]
    return build_sam2(config_file, ckpt_path, device=DEVICE)


def predict(model_choice: str, annotations, image):
    """Segment the region inside the first drawn box and render the result.

    Args:
        model_choice: Key of ``choice_mapping`` selecting the checkpoint.
        annotations: Value of the ``image_annotator`` component; the first
            entry of ``annotations["boxes"]`` (with ``xmin``/``ymin``/``xmax``/
            ``ymax`` keys) is used as the box prompt.
        image: The image from the ``gr.Image`` input (``type="numpy"``).

    Returns:
        A two-element list: the figure produced by ``show_masks`` and a
        visible ``gr.DownloadButton`` pointing at the saved mask PNG.

    Raises:
        gr.Error: If no image is provided, no box was drawn, or the mask
            file could not be written.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    boxes = (annotations or {}).get("boxes") or []
    if not boxes:
        raise gr.Error("Please draw a bounding box around the object to segment.")
    box = boxes[0]

    # A fresh predictor per request is cheap; only the model itself is cached.
    predictor = SAM2ImagePredictor(_load_model(str(model_choice)))
    predictor.set_image(image)

    coordinates = np.array(
        [int(box[key]) for key in ("xmin", "ymin", "xmax", "ymax")]
    )
    masks, scores, _ = predictor.predict(
        point_coords=None,
        point_labels=None,
        box=coordinates[None, :],
        multimask_output=False,
    )

    # Move the mask axis last (H, W, 1) and scale to 0..255 for an 8-bit PNG.
    mask_image = (masks.transpose(1, 2, 0) * 255).astype(np.uint8)
    if not cv2.imwrite(MASK_PATH, mask_image):
        raise gr.Error("Failed to save the segmentation mask.")

    return [
        show_masks(image, masks, scores, box_coords=coordinates),
        gr.DownloadButton("Download Mask", value=MASK_PATH, visible=True),
    ]


with gr.Blocks(delete_cache=(30, 30)) as demo:
    gr.Markdown("# 1. Choose Model Checkpoint")
    with gr.Row():
        model = gr.Dropdown(
            # Derived from choice_mapping so the UI and lookup cannot drift.
            choices=list(choice_mapping),
            value="tiny",
            label="Model Checkpoint",
            info="Which model checkpoint to load?",
        )
    gr.Markdown("# 2. Upload an Image")
    with gr.Row():
        img = gr.Image(value="./assets/img.png", type="numpy", label="Input Image")
    gr.Markdown("# 3. Draw Bounding Box")
    # NOTE(review): the annotator is seeded only with the initial image; an
    # image uploaded later in `img` is not propagated here, so the box may be
    # drawn on a different image than the one segmented. Confirm intent.
    annotator = image_annotator(
        value={"image": img.value["path"]},
        disable_edit_boxes=True,
        single_box=True,
        label="Draw a bounding box",
    )
    btn = gr.Button("Get Segmentation Mask")
    download_btn = gr.DownloadButton("Download Mask", value=MASK_PATH, visible=False)
    btn.click(
        fn=predict, inputs=[model, annotator, img], outputs=[gr.Plot(), download_btn]
    )

if __name__ == "__main__":
    demo.launch()