"""Gradio demo: segment an image with SAM2 from user-drawn boxes or automatically."""

from typing import Any, Dict

import cv2
import gradio as gr
import numpy as np
from gradio_image_annotator import image_annotator
from sam2 import load_model
from sam2.utils.visualization import show_masks
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Image pre-loaded into the annotator when the demo starts.
EXAMPLE_IMAGE_PATH = "assets/example.png"


def _boxes_to_array(boxes):
    """Convert annotator box dicts into an array of [xmin, ymin, xmax, ymax] ints."""
    return np.array(
        [
            [int(box["xmin"]), int(box["ymin"]), int(box["xmax"]), int(box["ymax"])]
            for box in boxes
        ]
    )


def _load_example_image(path):
    """Read the image at *path* and return it in RGB order, or None if unreadable.

    ``cv2.imread`` yields BGR data (and ``None`` when the file cannot be read);
    the channels are reordered so the array is not shown with red/blue swapped.
    """
    bgr_image = cv2.imread(path)
    if bgr_image is None:
        return None
    return cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)


# @spaces.GPU()
def predict(model_choice, annotations: Dict[str, Any]):
    """Run SAM2 on the annotated image and return the ``show_masks`` rendering.

    Args:
        model_choice: Checkpoint variant name; it selects both the model
            variant and the ``assets/checkpoints/sam2_hiera_<variant>.pt`` file.
        annotations: Annotator payload. ``annotations["image"]`` is the image to
            segment; ``annotations["boxes"]`` is a list of dicts with ``xmin``,
            ``ymin``, ``xmax`` and ``ymax`` entries.

    Returns:
        Whatever ``show_masks`` returns for the computed masks. When at least
        one box is drawn the masks come from box-prompted prediction; otherwise
        they come from the automatic mask generator.
    """
    # NOTE: the checkpoint is loaded on every call, always on the CPU.
    sam2_model = load_model(
        variant=model_choice,
        ckpt_path=f"assets/checkpoints/sam2_hiera_{model_choice}.pt",
        device="cpu",
    )
    image = annotations["image"]
    # A missing "boxes" key is treated like an empty list of boxes.
    boxes = annotations.get("boxes")

    if not boxes:
        # No prompts drawn: segment the whole image automatically.
        mask_generator = SAM2AutomaticMaskGenerator(sam2_model)  # type: ignore
        masks = mask_generator.generate(image)
        return show_masks(
            image=image,
            masks=masks,  # type: ignore
            scores=None,
            only_best=False,
            autogenerated_mask=True,
        )

    predictor = SAM2ImagePredictor(sam2_model)  # type: ignore
    predictor.set_image(image)
    masks, scores, _ = predictor.predict(
        point_coords=None,
        point_labels=None,
        box=_boxes_to_array(boxes),
        multimask_output=False,
    )

    # Scores are forwarded only for a single prediction; with several boxes
    # every mask is rendered and no score is passed along.
    multi_box = len(scores) > 1
    return show_masks(
        image=image,
        masks=masks,
        scores=scores if len(scores) == 1 else None,
        only_best=not multi_box,
    )


with gr.Blocks(delete_cache=(30, 30)) as demo:
    gr.Markdown(
        """
        ## To read more about the Segment Anything Project please refer to the [Lightly AI blogpost](https://www.lightly.ai/post/segment-anything-model-and-friends)
        """
    )
    gr.Markdown(
        """
        # 1. Choose Model Checkpoint
        """
    )
    with gr.Row():
        model = gr.Dropdown(
            choices=["tiny", "small", "base_plus", "large"],
            value="tiny",
            label="Model Checkpoint",
            info="Which model checkpoint to load?",
        )

    gr.Markdown(
        """
        # 2. Upload your Image and draw bounding box(es)
        """
    )
    annotator = image_annotator(
        value={"image": _load_example_image(EXAMPLE_IMAGE_PATH)},
        disable_edit_boxes=True,
        label="Draw a bounding box",
    )
    btn = gr.Button("Get Segmentation Mask(s)")
    btn.click(
        fn=predict, inputs=[model, annotator], outputs=[gr.Image(label="Mask(s)")]
    )

demo.launch()