Spaces:
Runtime error
Runtime error
from typing import Dict, Iterable, Optional, Tuple

import torch
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
# Interaction modes offered to the user.
BOX_PROMPT_MODE = "box prompt"
MASK_GENERATION_MODE = "mask generation"
MODE_NAMES = [BOX_PROMPT_MODE, MASK_GENERATION_MODE]

# Checkpoint name -> [model config file, weights file path]. The two entries
# are passed positionally to ``build_sam2`` when the models are loaded.
CHECKPOINTS = {
    "tiny": ["sam2_hiera_t.yaml", "checkpoints/sam2_hiera_tiny.pt"],
    "small": ["sam2_hiera_s.yaml", "checkpoints/sam2_hiera_small.pt"],
    "base_plus": ["sam2_hiera_b+.yaml", "checkpoints/sam2_hiera_base_plus.pt"],
    "large": ["sam2_hiera_l.yaml", "checkpoints/sam2_hiera_large.pt"],
}

# Derived from CHECKPOINTS so the two can never drift apart; dicts preserve
# insertion order, so this is ["tiny", "small", "base_plus", "large"].
CHECKPOINT_NAMES = list(CHECKPOINTS)
def load_models(
    device: torch.device,
    checkpoint_names: Optional[Iterable[str]] = None,
) -> Tuple[Dict[str, SAM2ImagePredictor], Dict[str, SAM2AutomaticMaskGenerator]]:
    """Build SAM2 models and wrap each in an image predictor and a mask generator.

    Args:
        device: Device passed to ``build_sam2`` for every model that is built.
        checkpoint_names: Keys of ``CHECKPOINTS`` to load, in the order given.
            ``None`` (the default) loads every entry, matching the previous
            behaviour. A single string is treated as one name. Duplicate names
            are loaded only once. Loading a subset avoids keeping unused models
            resident on ``device``.

    Returns:
        ``(image_predictors, mask_generators)``: two dicts keyed by checkpoint
        name. For a given name, both wrappers share the same model instance,
        so each checkpoint's weights are built only once.

    Raises:
        KeyError: If any requested name is not a key of ``CHECKPOINTS``. This is
            checked before any model is built.
    """
    if checkpoint_names is None:
        requested = list(CHECKPOINTS)
    elif isinstance(checkpoint_names, str):
        # A bare string is itself an Iterable[str]; iterating it would yield
        # single characters, so treat it as one checkpoint name instead.
        requested = [checkpoint_names]
    else:
        # dict.fromkeys drops duplicates while keeping first-seen order.
        requested = list(dict.fromkeys(checkpoint_names))

    # Validate everything up front so a typo fails fast instead of after
    # several expensive model builds.
    unknown = [name for name in requested if name not in CHECKPOINTS]
    if unknown:
        raise KeyError(
            f"Unknown checkpoint name(s) {unknown}; expected any of {list(CHECKPOINTS)}"
        )

    image_predictors: Dict[str, SAM2ImagePredictor] = {}
    mask_generators: Dict[str, SAM2AutomaticMaskGenerator] = {}
    for name in requested:
        config, checkpoint = CHECKPOINTS[name]
        model = build_sam2(config, checkpoint, device=device)
        # Both wrappers reuse the same model object rather than building it twice.
        image_predictors[name] = SAM2ImagePredictor(sam_model=model)
        mask_generators[name] = SAM2AutomaticMaskGenerator(
            model=model,
            points_per_side=32,
            points_per_batch=64,
            pred_iou_thresh=0.7,
            stability_score_thresh=0.92,
            stability_score_offset=0.7,
            crop_n_layers=1,
            box_nms_thresh=0.7,
        )
    return image_predictors, mask_generators