SegGPT
Overview
The SegGPT model was proposed in SegGPT: Segmenting Everything In Context by Xinlong Wang, Xiaosong Zhang, Yue Cao, Wen Wang, Chunhua Shen, Tiejun Huang. SegGPT employs a decoder-only Transformer that can generate a segmentation mask given an input image, a prompt image and its corresponding prompt mask. The model achieves remarkable one-shot results: 56.1 mIoU on COCO-20i and 85.6 mIoU on FSS-1000.
The abstract from the paper is the following:
We present SegGPT, a generalist model for segmenting everything in context. We unify various segmentation tasks into a generalist in-context learning framework that accommodates different kinds of segmentation data by transforming them into the same format of images. The training of SegGPT is formulated as an in-context coloring problem with random color mapping for each data sample. The objective is to accomplish diverse tasks according to the context, rather than relying on specific colors. After training, SegGPT can perform arbitrary segmentation tasks in images or videos via in-context inference, such as object instance, stuff, part, contour, and text. SegGPT is evaluated on a broad range of tasks, including few-shot semantic segmentation, video object segmentation, semantic segmentation, and panoptic segmentation. Our results show strong capabilities in segmenting in-domain and out-of-domain targets, either qualitatively or quantitatively.
Tips:
- One can use SegGptImageProcessor to prepare image input, prompt and mask to the model.
- One can either use segmentation maps or RGB images as prompt masks. If using the latter, make sure to set do_convert_rgb=False in the preprocess method.
- It's highly advisable to pass num_labels (not considering background) when using segmentation_maps during preprocessing and postprocessing with SegGptImageProcessor for your use case.
- When doing inference with SegGptForImageSegmentation, if your batch_size is greater than 1 you can use feature ensemble across your images by passing feature_ensemble=True in the forward method, as sketched right after this list.
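The snippet below is a minimal sketch of that few-shot setup: the same input image is paired with several prompt image/mask pairs and feature_ensemble=True is passed to the forward method. Here image_input, prompt_images, prompt_masks and num_labels are placeholders assumed to be prepared beforehand (for example from the FoodSeg103 dataset used in the example further down).

import torch
from transformers import SegGptImageProcessor, SegGptForImageSegmentation

checkpoint = "BAAI/seggpt-vit-large"
image_processor = SegGptImageProcessor.from_pretrained(checkpoint)
model = SegGptForImageSegmentation.from_pretrained(checkpoint)

# `image_input` is the image to segment; `prompt_images` and `prompt_masks` are
# lists of example images and their segmentation maps (placeholders here).
# The input image is repeated so every prompt is paired with the same image.
inputs = image_processor(
    images=[image_input] * len(prompt_images),
    prompt_images=prompt_images,
    segmentation_maps=prompt_masks,
    num_labels=num_labels,
    return_tensors="pt",
)

with torch.no_grad():
    # With more than one prompt in the batch, feature ensemble averages the
    # prompt features instead of treating each row independently.
    outputs = model(**inputs, feature_ensemble=True)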
Here's how to use the model for one-shot semantic segmentation:
import torch
from datasets import load_dataset
from transformers import SegGptImageProcessor, SegGptForImageSegmentation
checkpoint = "BAAI/seggpt-vit-large"
image_processor = SegGptImageProcessor.from_pretrained(checkpoint)
model = SegGptForImageSegmentation.from_pretrained(checkpoint)
dataset_id = "EduardoPacheco/FoodSeg103"
ds = load_dataset(dataset_id, split="train")
# Number of labels in FoodSeg103 (not including background)
num_labels = 103
image_input = ds[4]["image"]
ground_truth = ds[4]["label"]
image_prompt = ds[29]["image"]
mask_prompt = ds[29]["label"]
inputs = image_processor(
    images=image_input,
    prompt_images=image_prompt,
    segmentation_maps=mask_prompt,
    num_labels=num_labels,
    return_tensors="pt"
)

with torch.no_grad():
    outputs = model(**inputs)
target_sizes = [image_input.size[::-1]]
mask = image_processor.post_process_semantic_segmentation(outputs, target_sizes, num_labels=num_labels)[0]
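The returned mask is a (height, width) tensor of class ids at the original image resolution. As a small follow-up sketch, one can list which classes were predicted (class id 0 being the background):

predicted_classes = torch.unique(mask).tolist()
print(predicted_classes)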
This model was contributed by EduardoPacheco. The original code can be found here.
SegGptConfig
class transformers.SegGptConfig
< source >( hidden_size = 1024 num_hidden_layers = 24 num_attention_heads = 16 hidden_act = 'gelu' hidden_dropout_prob = 0.0 initializer_range = 0.02 layer_norm_eps = 1e-06 image_size = [896, 448] patch_size = 16 num_channels = 3 qkv_bias = True mlp_dim = None drop_path_rate = 0.1 pretrain_image_size = 224 decoder_hidden_size = 64 use_relative_position_embeddings = True merge_index = 2 intermediate_hidden_state_indices = [5, 11, 17, 23] beta = 0.01 **kwargs )
Parameters
- hidden_size (int, optional, defaults to 1024) — Dimensionality of the encoder layers and the pooler layer.
- num_hidden_layers (int, optional, defaults to 24) — Number of hidden layers in the Transformer encoder.
- num_attention_heads (int, optional, defaults to 16) — Number of attention heads for each attention layer in the Transformer encoder.
- hidden_act (str or function, optional, defaults to "gelu") — The non-linear activation function (function or string) in the encoder and pooler. If string, "gelu", "relu", "selu" and "gelu_new" are supported.
- hidden_dropout_prob (float, optional, defaults to 0.0) — The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
- initializer_range (float, optional, defaults to 0.02) — The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
- layer_norm_eps (float, optional, defaults to 1e-06) — The epsilon used by the layer normalization layers.
- image_size (List[int], optional, defaults to [896, 448]) — The size (resolution) of each image.
- patch_size (int, optional, defaults to 16) — The size (resolution) of each patch.
- num_channels (int, optional, defaults to 3) — The number of input channels.
- qkv_bias (bool, optional, defaults to True) — Whether to add a bias to the queries, keys and values.
- mlp_dim (int, optional) — The dimensionality of the MLP layer in the Transformer encoder. If unset, defaults to hidden_size * 4.
- drop_path_rate (float, optional, defaults to 0.1) — The drop path rate for the dropout layers.
- pretrain_image_size (int, optional, defaults to 224) — The pretrained size of the absolute position embeddings.
- decoder_hidden_size (int, optional, defaults to 64) — Hidden size for the decoder.
- use_relative_position_embeddings (bool, optional, defaults to True) — Whether to use relative position embeddings in the attention layers.
- merge_index (int, optional, defaults to 2) — The index of the encoder layer to merge the embeddings.
- intermediate_hidden_state_indices (List[int], optional, defaults to [5, 11, 17, 23]) — The indices of the encoder layers which we store as features for the decoder.
- beta (float, optional, defaults to 0.01) — Regularization factor for SegGptLoss (smooth-l1 loss).
This is the configuration class to store the configuration of a SegGptModel. It is used to instantiate a SegGPT model according to the specified arguments, defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the SegGPT BAAI/seggpt-vit-large architecture.
Configuration objects inherit from PretrainedConfig and can be used to control the model outputs. Read the documentation from PretrainedConfig for more information.
Example:
>>> from transformers import SegGptConfig, SegGptModel
>>> # Initializing a SegGPT seggpt-vit-large style configuration
>>> configuration = SegGptConfig()
>>> # Initializing a model (with random weights) from the seggpt-vit-large style configuration
>>> model = SegGptModel(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
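A custom configuration can also be built by overriding individual fields directly; the values below are arbitrary illustrations rather than recommended settings:

>>> # Shrink the encoder and pick different layers to feed the decoder
>>> configuration = SegGptConfig(
...     num_hidden_layers=12,
...     intermediate_hidden_state_indices=[2, 5, 8, 11],
... )
>>> model = SegGptModel(configuration)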
SegGptImageProcessor
class transformers.SegGptImageProcessor
< source >( do_resize: bool = True size: typing.Optional[typing.Dict[str, int]] = None resample: Resampling = <Resampling.BICUBIC: 3> do_rescale: bool = True rescale_factor: typing.Union[int, float] = 0.00392156862745098 do_normalize: bool = True image_mean: typing.Union[float, typing.List[float], NoneType] = None image_std: typing.Union[float, typing.List[float], NoneType] = None do_convert_rgb: bool = True **kwargs )
Parameters
- do_resize (bool, optional, defaults to True) — Whether to resize the image's (height, width) dimensions to the specified (size["height"], size["width"]). Can be overridden by the do_resize parameter in the preprocess method.
- size (dict, optional, defaults to {"height": 448, "width": 448}) — Size of the output image after resizing. Can be overridden by the size parameter in the preprocess method.
- resample (PILImageResampling, optional, defaults to Resampling.BICUBIC) — Resampling filter to use if resizing the image. Can be overridden by the resample parameter in the preprocess method.
- do_rescale (bool, optional, defaults to True) — Whether to rescale the image by the specified scale rescale_factor. Can be overridden by the do_rescale parameter in the preprocess method.
- rescale_factor (int or float, optional, defaults to 1/255) — Scale factor to use if rescaling the image. Can be overridden by the rescale_factor parameter in the preprocess method.
- do_normalize (bool, optional, defaults to True) — Whether to normalize the image. Can be overridden by the do_normalize parameter in the preprocess method.
- image_mean (float or List[float], optional, defaults to IMAGENET_DEFAULT_MEAN) — Mean to use if normalizing the image. This is a float or list of floats the length of the number of channels in the image. Can be overridden by the image_mean parameter in the preprocess method.
- image_std (float or List[float], optional, defaults to IMAGENET_DEFAULT_STD) — Standard deviation to use if normalizing the image. This is a float or list of floats the length of the number of channels in the image. Can be overridden by the image_std parameter in the preprocess method.
- do_convert_rgb (bool, optional, defaults to True) — Whether to convert the prompt mask to RGB format. Can be overridden by the do_convert_rgb parameter in the preprocess method.
Constructs a SegGpt image processor.
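The processor can also be constructed with explicit settings instead of loading them from the Hub; the values below simply restate the documented defaults for illustration:

from transformers import SegGptImageProcessor

# Explicit construction; these values mirror the documented defaults.
image_processor = SegGptImageProcessor(
    do_resize=True,
    size={"height": 448, "width": 448},
    do_normalize=True,
    do_convert_rgb=True,
)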
preprocess
< source >( images: typing.Union[ForwardRef('PIL.Image.Image'), numpy.ndarray, ForwardRef('torch.Tensor'), typing.List[ForwardRef('PIL.Image.Image')], typing.List[numpy.ndarray], typing.List[ForwardRef('torch.Tensor')], NoneType] = None prompt_images: typing.Union[ForwardRef('PIL.Image.Image'), numpy.ndarray, ForwardRef('torch.Tensor'), typing.List[ForwardRef('PIL.Image.Image')], typing.List[numpy.ndarray], typing.List[ForwardRef('torch.Tensor')], NoneType] = None prompt_masks: typing.Union[ForwardRef('PIL.Image.Image'), numpy.ndarray, ForwardRef('torch.Tensor'), typing.List[ForwardRef('PIL.Image.Image')], typing.List[numpy.ndarray], typing.List[ForwardRef('torch.Tensor')], NoneType] = None do_resize: typing.Optional[bool] = None size: typing.Dict[str, int] = None resample: Resampling = None do_rescale: typing.Optional[bool] = None rescale_factor: typing.Optional[float] = None do_normalize: typing.Optional[bool] = None image_mean: typing.Union[float, typing.List[float], NoneType] = None image_std: typing.Union[float, typing.List[float], NoneType] = None do_convert_rgb: typing.Optional[bool] = None num_labels: typing.Optional[int] = None return_tensors: typing.Union[str, transformers.utils.generic.TensorType, NoneType] = None data_format: typing.Union[str, transformers.image_utils.ChannelDimension] = <ChannelDimension.FIRST: 'channels_first'> input_data_format: typing.Union[str, transformers.image_utils.ChannelDimension, NoneType] = None **kwargs )
Parameters
- images (ImageInput) — Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If passing in images with pixel values between 0 and 1, set do_rescale=False.
- prompt_images (ImageInput) — Prompt image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If passing in images with pixel values between 0 and 1, set do_rescale=False.
- prompt_masks (ImageInput) — Prompt mask from the prompt image to preprocess, which specifies the prompt_masks value in the preprocessed output. Can either be in the format of segmentation maps (no channels) or RGB images. If in the format of RGB images, do_convert_rgb should be set to False. If in the format of segmentation maps, specifying num_labels is recommended so a palette can be built to map the prompt mask from a single channel to a 3 channel RGB. If num_labels is not specified, the prompt mask will be duplicated across the channel dimension.
- do_resize (bool, optional, defaults to self.do_resize) — Whether to resize the image.
- size (Dict[str, int], optional, defaults to self.size) — Dictionary in the format {"height": h, "width": w} specifying the size of the output image after resizing.
- resample (PILImageResampling filter, optional, defaults to self.resample) — PILImageResampling filter to use if resizing the image e.g. PILImageResampling.BICUBIC. Only has an effect if do_resize is set to True. Doesn't apply to the prompt mask, which is resized using nearest-neighbor interpolation.
- do_rescale (bool, optional, defaults to self.do_rescale) — Whether to rescale the image values between [0 - 1].
- rescale_factor (float, optional, defaults to self.rescale_factor) — Rescale factor to rescale the image by if do_rescale is set to True.
- do_normalize (bool, optional, defaults to self.do_normalize) — Whether to normalize the image.
- image_mean (float or List[float], optional, defaults to self.image_mean) — Image mean to use if do_normalize is set to True.
- image_std (float or List[float], optional, defaults to self.image_std) — Image standard deviation to use if do_normalize is set to True.
- do_convert_rgb (bool, optional, defaults to self.do_convert_rgb) — Whether to convert the prompt mask to RGB format. If num_labels is specified, a palette will be built to map the prompt mask from a single channel to a 3 channel RGB. If unset, the prompt mask is duplicated across the channel dimension. Must be set to False if the prompt mask is already in RGB format.
- num_labels (int, optional) — Number of classes in the segmentation task (excluding the background). If specified, a palette will be built, assuming that class_idx 0 is the background, to map the prompt mask from a plain segmentation map with no channels to a 3 channel RGB. Not specifying this will result in the prompt mask either being passed through as is if it is already in RGB format (if do_convert_rgb is false) or being duplicated across the channel dimension.
- return_tensors (str or TensorType, optional) — The type of tensors to return. Can be one of:
  - Unset: Return a list of np.ndarray.
  - TensorType.TENSORFLOW or 'tf': Return a batch of type tf.Tensor.
  - TensorType.PYTORCH or 'pt': Return a batch of type torch.Tensor.
  - TensorType.NUMPY or 'np': Return a batch of type np.ndarray.
  - TensorType.JAX or 'jax': Return a batch of type jax.numpy.ndarray.
- data_format (ChannelDimension or str, optional, defaults to ChannelDimension.FIRST) — The channel dimension format for the output image. Can be one of:
  - "channels_first" or ChannelDimension.FIRST: image in (num_channels, height, width) format.
  - "channels_last" or ChannelDimension.LAST: image in (height, width, num_channels) format.
  - Unset: Use the channel dimension format of the input image.
- input_data_format (ChannelDimension or str, optional) — The channel dimension format for the input image. If unset, the channel dimension format is inferred from the input image. Can be one of:
  - "channels_first" or ChannelDimension.FIRST: image in (num_channels, height, width) format.
  - "channels_last" or ChannelDimension.LAST: image in (height, width, num_channels) format.
  - "none" or ChannelDimension.NONE: image in (height, width) format.
Preprocess an image or batch of images.
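For example, when the prompt mask is already a 3-channel RGB image rather than a segmentation map, it should be passed with do_convert_rgb=False. A brief sketch of that path, reusing the example assets that also appear in the model examples below:

from PIL import Image
import requests
from transformers import SegGptImageProcessor

checkpoint = "BAAI/seggpt-vit-large"
image_processor = SegGptImageProcessor.from_pretrained(checkpoint)

image_input_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_2.jpg"
image_prompt_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_1.jpg"
mask_prompt_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_1_target.png"

image_input = Image.open(requests.get(image_input_url, stream=True).raw)
image_prompt = Image.open(requests.get(image_prompt_url, stream=True).raw)
# Keep the mask as an RGB image instead of a single-channel segmentation map
mask_prompt = Image.open(requests.get(mask_prompt_url, stream=True).raw).convert("RGB")

inputs = image_processor(
    images=image_input,
    prompt_images=image_prompt,
    prompt_masks=mask_prompt,
    do_convert_rgb=False,  # the prompt mask is already 3-channel RGB
    return_tensors="pt",
)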
post_process_semantic_segmentation
< source >( outputs target_sizes: typing.Optional[typing.List[typing.Tuple[int, int]]] = None num_labels: typing.Optional[int] = None ) → semantic_segmentation
Parameters
- outputs (SegGptImageSegmentationOutput) — Raw outputs of the model.
- target_sizes (List[Tuple[int, int]], optional) — List of length (batch_size), where each list item (Tuple[int, int]) corresponds to the requested final size (height, width) of each prediction. If left to None, predictions will not be resized.
- num_labels (int, optional) — Number of classes in the segmentation task (excluding the background). If specified, a palette will be built, assuming that class_idx 0 is the background, to map prediction masks from RGB values to class indices. This value should be the same used when preprocessing inputs.
Returns
semantic_segmentation
List[torch.Tensor] of length batch_size, where each item is a semantic segmentation map of shape (height, width) corresponding to the target_sizes entry (if target_sizes is specified). Each entry of each torch.Tensor corresponds to a semantic class id.
Converts the output of SegGptImageSegmentationOutput into segmentation maps. Only supports PyTorch.
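A short illustration of the call, assuming outputs comes from SegGptForImageSegmentation, image_input is the original PIL image that was preprocessed, and num_labels matches the value used during preprocessing:

# One (height, width) tensor of class ids per image in the batch
target_sizes = [image_input.size[::-1]]
masks = image_processor.post_process_semantic_segmentation(
    outputs, target_sizes=target_sizes, num_labels=num_labels
)
mask = masks[0]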
SegGptModel
class transformers.SegGptModel
< source >( config: SegGptConfig )
Parameters
- config (SegGptConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.
The bare SegGpt Model transformer outputting raw hidden-states without any specific head on top. This model is a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior.
forward
< source >( pixel_values: Tensor prompt_pixel_values: Tensor prompt_masks: Tensor bool_masked_pos: typing.Optional[torch.BoolTensor] = None feature_ensemble: typing.Optional[bool] = None embedding_type: typing.Optional[str] = None labels: typing.Optional[torch.FloatTensor] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None ) → transformers.models.seggpt.modeling_seggpt.SegGptEncoderOutput or tuple(torch.FloatTensor)
Parameters
- pixel_values (torch.FloatTensor of shape (batch_size, num_channels, height, width)) — Pixel values. Pixel values can be obtained using AutoImageProcessor. See SegGptImageProcessor.call() for details.
- prompt_pixel_values (torch.FloatTensor of shape (batch_size, num_channels, height, width)) — Prompt pixel values. Prompt pixel values can be obtained using AutoImageProcessor. See SegGptImageProcessor.call() for details.
- prompt_masks (torch.FloatTensor of shape (batch_size, num_channels, height, width)) — Prompt mask. Prompt mask can be obtained using AutoImageProcessor. See SegGptImageProcessor.call() for details.
- bool_masked_pos (torch.BoolTensor of shape (batch_size, num_patches), optional) — Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
- feature_ensemble (bool, optional) — Boolean indicating whether to use feature ensemble or not. If True, the model will use feature ensemble if we have at least two prompts. If False, the model will not use feature ensemble. This argument should be considered when doing few-shot inference on an input image i.e. more than one prompt for the same image.
- embedding_type (str, optional) — Embedding type. Indicates whether the prompt is a semantic or instance embedding. Can be either instance or semantic.
- output_attentions (bool, optional) — Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail.
- output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
- return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
- labels (torch.FloatTensor of shape (batch_size, num_channels, height, width), optional) — Ground truth mask for input images.
Returns
transformers.models.seggpt.modeling_seggpt.SegGptEncoderOutput or tuple(torch.FloatTensor)
A transformers.models.seggpt.modeling_seggpt.SegGptEncoderOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (SegGptConfig) and inputs.
- last_hidden_state (torch.FloatTensor of shape (batch_size, patch_height, patch_width, hidden_size)) — Sequence of hidden-states at the output of the last layer of the model.
- hidden_states (Tuple[torch.FloatTensor], optional, returned when config.output_hidden_states=True) — Tuple of torch.FloatTensor (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, patch_height, patch_width, hidden_size).
- attentions (Tuple[torch.FloatTensor], optional, returned when config.output_attentions=True) — Tuple of torch.FloatTensor (one for each layer) of shape (batch_size, num_heads, seq_len, seq_len).
- intermediate_hidden_states (Tuple[torch.FloatTensor], optional, returned when config.intermediate_hidden_state_indices is set) — Tuple of torch.FloatTensor of shape (batch_size, patch_height, patch_width, hidden_size). Each element in the tuple corresponds to the output of the layer specified in config.intermediate_hidden_state_indices. Additionally, each feature passes through a LayerNorm.
The SegGptModel forward method overrides the __call__ special method.
Although the recipe for forward pass needs to be defined within this function, one should call the Module
instance afterwards instead of this since the former takes care of running the pre and post processing steps while
the latter silently ignores them.
Examples:
>>> from transformers import SegGptImageProcessor, SegGptModel
>>> from PIL import Image
>>> import requests
>>> image_input_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_2.jpg"
>>> image_prompt_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_1.jpg"
>>> mask_prompt_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_1_target.png"
>>> image_input = Image.open(requests.get(image_input_url, stream=True).raw)
>>> image_prompt = Image.open(requests.get(image_prompt_url, stream=True).raw)
>>> mask_prompt = Image.open(requests.get(mask_prompt_url, stream=True).raw).convert("L")
>>> checkpoint = "BAAI/seggpt-vit-large"
>>> model = SegGptModel.from_pretrained(checkpoint)
>>> image_processor = SegGptImageProcessor.from_pretrained(checkpoint)
>>> inputs = image_processor(images=image_input, prompt_images=image_prompt, prompt_masks=mask_prompt, return_tensors="pt")
>>> outputs = model(**inputs)
>>> list(outputs.last_hidden_state.shape)
[1, 56, 28, 1024]
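The intermediate hidden states selected by config.intermediate_hidden_state_indices are also returned; they are what SegGptForImageSegmentation feeds to its decoder. With the default configuration the continuation below is expected to hold (shapes follow the documented (batch_size, patch_height, patch_width, hidden_size) layout):

>>> # One tensor per entry in config.intermediate_hidden_state_indices
>>> len(outputs.intermediate_hidden_states)
4
>>> list(outputs.intermediate_hidden_states[0].shape)
[1, 56, 28, 1024]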
SegGptForImageSegmentation
class transformers.SegGptForImageSegmentation
< source >( config: SegGptConfig )
Parameters
- config (SegGptConfig) — Model configuration class with all the parameters of the model. Initializing with a config file does not load the weights associated with the model, only the configuration. Check out the from_pretrained() method to load the model weights.
SegGpt model with a decoder on top for one-shot image segmentation. This model is a PyTorch torch.nn.Module subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and behavior.
forward
< source >( pixel_values: Tensor prompt_pixel_values: Tensor prompt_masks: Tensor bool_masked_pos: typing.Optional[torch.BoolTensor] = None feature_ensemble: typing.Optional[bool] = None embedding_type: typing.Optional[str] = None labels: typing.Optional[torch.FloatTensor] = None output_attentions: typing.Optional[bool] = None output_hidden_states: typing.Optional[bool] = None return_dict: typing.Optional[bool] = None ) → transformers.models.seggpt.modeling_seggpt.SegGptImageSegmentationOutput or tuple(torch.FloatTensor)
Parameters
- pixel_values (torch.FloatTensor of shape (batch_size, num_channels, height, width)) — Pixel values. Pixel values can be obtained using AutoImageProcessor. See SegGptImageProcessor.call() for details.
- prompt_pixel_values (torch.FloatTensor of shape (batch_size, num_channels, height, width)) — Prompt pixel values. Prompt pixel values can be obtained using AutoImageProcessor. See SegGptImageProcessor.call() for details.
- prompt_masks (torch.FloatTensor of shape (batch_size, num_channels, height, width)) — Prompt mask. Prompt mask can be obtained using AutoImageProcessor. See SegGptImageProcessor.call() for details.
- bool_masked_pos (torch.BoolTensor of shape (batch_size, num_patches), optional) — Boolean masked positions. Indicates which patches are masked (1) and which aren't (0).
- feature_ensemble (bool, optional) — Boolean indicating whether to use feature ensemble or not. If True, the model will use feature ensemble if we have at least two prompts. If False, the model will not use feature ensemble. This argument should be considered when doing few-shot inference on an input image i.e. more than one prompt for the same image.
- embedding_type (str, optional) — Embedding type. Indicates whether the prompt is a semantic or instance embedding. Can be either instance or semantic.
- output_attentions (bool, optional) — Whether or not to return the attentions tensors of all attention layers. See attentions under returned tensors for more detail.
- output_hidden_states (bool, optional) — Whether or not to return the hidden states of all layers. See hidden_states under returned tensors for more detail.
- return_dict (bool, optional) — Whether or not to return a ModelOutput instead of a plain tuple.
- labels (torch.FloatTensor of shape (batch_size, num_channels, height, width), optional) — Ground truth mask for input images.
Returns
transformers.models.seggpt.modeling_seggpt.SegGptImageSegmentationOutput or tuple(torch.FloatTensor)
A transformers.models.seggpt.modeling_seggpt.SegGptImageSegmentationOutput or a tuple of torch.FloatTensor (if return_dict=False is passed or when config.return_dict=False) comprising various elements depending on the configuration (SegGptConfig) and inputs.
- loss (torch.FloatTensor, optional, returned when labels is provided) — The loss value.
- pred_masks (torch.FloatTensor of shape (batch_size, num_channels, height, width)) — The predicted masks.
- hidden_states (Tuple[torch.FloatTensor], optional, returned when config.output_hidden_states=True) — Tuple of torch.FloatTensor (one for the output of the embeddings + one for the output of each layer) of shape (batch_size, patch_height, patch_width, hidden_size).
- attentions (Tuple[torch.FloatTensor], optional, returned when config.output_attentions=True) — Tuple of torch.FloatTensor (one for each layer) of shape (batch_size, num_heads, seq_len, seq_len).
The SegGptForImageSegmentation forward method overrides the __call__ special method.
Although the recipe for forward pass needs to be defined within this function, one should call the Module
instance afterwards instead of this since the former takes care of running the pre and post processing steps while
the latter silently ignores them.
Examples:
>>> from transformers import SegGptImageProcessor, SegGptForImageSegmentation
>>> from PIL import Image
>>> import requests
>>> image_input_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_2.jpg"
>>> image_prompt_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_1.jpg"
>>> mask_prompt_url = "https://raw.githubusercontent.com/baaivision/Painter/main/SegGPT/SegGPT_inference/examples/hmbb_1_target.png"
>>> image_input = Image.open(requests.get(image_input_url, stream=True).raw)
>>> image_prompt = Image.open(requests.get(image_prompt_url, stream=True).raw)
>>> mask_prompt = Image.open(requests.get(mask_prompt_url, stream=True).raw).convert("L")
>>> checkpoint = "BAAI/seggpt-vit-large"
>>> model = SegGptForImageSegmentation.from_pretrained(checkpoint)
>>> image_processor = SegGptImageProcessor.from_pretrained(checkpoint)
>>> inputs = image_processor(images=image_input, prompt_images=image_prompt, prompt_masks=mask_prompt, return_tensors="pt")
>>> outputs = model(**inputs)
>>> result = image_processor.post_process_semantic_segmentation(outputs, target_sizes=[(image_input.height, image_input.width)])[0]
>>> print(list(result.shape))
[170, 297]
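The raw decoder prediction is also exposed on the output before any post-processing. Under the default configuration (a 896x448 canvas holding the prompt image on top and the input image below), the continuation below is expected to hold:

>>> # Raw 3-channel prediction over the prompt+input canvas, before post-processing
>>> list(outputs.pred_masks.shape)
[1, 3, 896, 448]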