# Hugging Face Space: Face-Landmark-ControlNet (page status at capture: Runtime error)
# File: annotator/uniformer/mmseg/models/decode_heads/cascade_decode_head.py
from abc import ABCMeta, abstractmethod | |
from .decode_head import BaseDecodeHead | |
class BaseCascadeDecodeHead(BaseDecodeHead, metaclass=ABCMeta):
    """Base class for cascade decode heads used in
    :class:`CascadeEncoderDecoder`.

    Unlike a plain decode head, a cascade decode head's ``forward`` takes,
    in addition to the image features, the output of the previous decode
    head. Subclasses must implement :meth:`forward`; the training and
    testing entry points below are defined in terms of it.
    """

    def __init__(self, *args, **kwargs):
        # All construction arguments are forwarded unchanged to the parent.
        super(BaseCascadeDecodeHead, self).__init__(*args, **kwargs)

    @abstractmethod
    def forward(self, inputs, prev_output):
        """Placeholder of forward function; must be overridden.

        Declared abstract so that a subclass which forgets to implement it
        fails at instantiation time instead of silently returning ``None``
        into :meth:`losses` from :meth:`forward_train`.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            prev_output (Tensor): The output of previous decode head.
        """
        pass

    def forward_train(self, inputs, prev_output, img_metas, gt_semantic_seg,
                      train_cfg):
        """Forward function for training.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            prev_output (Tensor): The output of previous decode head.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:Collect`.
                Not used by this implementation.
            gt_semantic_seg (Tensor): Semantic segmentation masks
                used if the architecture supports semantic segmentation task.
            train_cfg (dict): The training config. Not used by this
                implementation.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        seg_logits = self.forward(inputs, prev_output)
        # Loss computation is delegated to the parent class's ``losses``.
        losses = self.losses(seg_logits, gt_semantic_seg)
        return losses

    def forward_test(self, inputs, prev_output, img_metas, test_cfg):
        """Forward function for testing.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            prev_output (Tensor): The output of previous decode head.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:Collect`.
                Not used by this implementation.
            test_cfg (dict): The testing config. Not used by this
                implementation.

        Returns:
            Tensor: Output segmentation map, i.e. whatever :meth:`forward`
            returns for ``inputs`` and ``prev_output``.
        """
        return self.forward(inputs, prev_output)