import torch
import torch.nn as nn

from ..modeling import Sam

from .amg import calculate_stability_score


class SamCoreMLModel(nn.Module):
    """
    This model should not be called directly, but is used in CoreML export.
    It wraps the prompt encoder and mask decoder of a Sam model so that mask
    prediction from precomputed image embeddings can be traced and converted.
    """

    def __init__(
        self,
        model: Sam,
        use_stability_score: bool = False,
    ) -> None:
        super().__init__()
        self.mask_decoder = model.mask_decoder
        self.model = model
        self.img_size = model.image_encoder.img_size
        self.use_stability_score = use_stability_score
        self.stability_score_offset = 1.0

    def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
        # Shift to pixel centers and normalize coordinates to [0, 1] before
        # applying the prompt encoder's positional encoding.
        point_coords = point_coords + 0.5
        point_coords = point_coords / self.img_size
        point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)
        point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)

        # Zero out padding points (label -1) and substitute the learned
        # not-a-point embedding in their place.
        point_embedding = point_embedding * (point_labels != -1)
        point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * (
            point_labels == -1
        )

        # Add the learned embedding that corresponds to each point type.
        for i in range(self.model.prompt_encoder.num_point_embeddings):
            point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[
                i
            ].weight * (point_labels == i)

        return point_embedding

    @torch.no_grad()
    def forward(
        self,
        image_embeddings: torch.Tensor,
        point_coords: torch.Tensor,
        point_labels: torch.Tensor,
    ):
        sparse_embedding = self._embed_points(point_coords, point_labels)
        # No mask prompt is provided, so the dense prompt is the learned
        # no-mask embedding broadcast over the image embedding.
        dense_embedding = self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)

        masks, scores = self.model.mask_decoder.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=self.model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embedding,
            dense_prompt_embeddings=dense_embedding,
        )

        # Optionally replace the decoder's predicted IoU scores with stability
        # scores computed directly from the mask logits.
        if self.use_stability_score:
            scores = calculate_stability_score(
                masks, self.model.mask_threshold, self.stability_score_offset
            )

        return scores, masks
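

if __name__ == "__main__":
    # Usage sketch, not part of the library API: one way this wrapper might be
    # traced and converted with coremltools. The registry key "vit_b", the
    # checkpoint path, the input shapes, and the output filename below are
    # assumptions for illustration. Because of the relative imports above,
    # run this as a module (e.g. `python -m segment_anything.utils.coreml`,
    # assuming that is where this file lives).
    import coremltools as ct

    from segment_anything import sam_model_registry

    sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
    coreml_model = SamCoreMLModel(sam, use_stability_score=False).eval()

    # Dummy inputs: precomputed image embeddings plus a single foreground click.
    embed_dim = sam.prompt_encoder.embed_dim
    embed_h, embed_w = sam.prompt_encoder.image_embedding_size
    image_embeddings = torch.randn(1, embed_dim, embed_h, embed_w)
    point_coords = torch.randint(0, sam.image_encoder.img_size, (1, 1, 2)).float()
    point_labels = torch.ones(1, 1, dtype=torch.float)

    traced = torch.jit.trace(coreml_model, (image_embeddings, point_coords, point_labels))
    mlmodel = ct.convert(
        traced,
        inputs=[
            ct.TensorType(name="image_embeddings", shape=image_embeddings.shape),
            ct.TensorType(name="point_coords", shape=point_coords.shape),
            ct.TensorType(name="point_labels", shape=point_labels.shape),
        ],
        convert_to="mlprogram",
    )
    mlmodel.save("sam_decoder.mlpackage")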