# -*- coding: utf-8 -*-
# Residual block as defined in:
# He, Kaiming, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. "Deep residual learning
# for image recognition." In Proceedings of the IEEE conference on computer vision
# and pattern recognition, pp. 770-778. 2016.
#
# Code snippet adapted from the HoverNet implementation (https://github.com/vqdang/hover_net)
#
# @ Fabian Hörst, fabian.hoerst@uk-essen.de
# Institute for Artificial Intelligence in Medicine,
# University Medicine Essen

from collections import OrderedDict

import torch
import torch.nn as nn

from models.utils.tf_utils import TFSamepaddingLayer


class ResidualBlock(nn.Module):
    """Residual block as defined in:
    He, Kaiming, Xiangyu Zhang, Shaoqing Ren, and Jian Sun. "Deep residual learning
    for image recognition." In Proceedings of the IEEE conference on computer vision
    and pattern recognition, pp. 770-778. 2016.
    """

    def __init__(self, in_ch, unit_ksize, unit_ch, unit_count, stride=1):
        """Initialize the residual block.

        Args:
            in_ch (int): Number of input channels.
            unit_ksize (list[int]): Kernel size of each convolution inside one unit.
            unit_ch (list[int]): Output channels of each convolution inside one unit.
            unit_count (int): Number of stacked residual units in this block.
            stride (int, optional): Stride of the first unit (applied to conv2). Defaults to 1.
        """
        super(ResidualBlock, self).__init__()
        assert len(unit_ksize) == len(unit_ch), "Unbalanced unit info"

        self.nr_unit = unit_count
        self.in_ch = in_ch
        self.unit_ch = unit_ch

        # ! For inference only, so init values for batchnorm may not match tensorflow
        unit_in_ch = in_ch
        self.units = nn.ModuleList()
        for idx in range(unit_count):
            unit_layer = [
                ("preact/bn", nn.BatchNorm2d(unit_in_ch, eps=1e-5)),
                ("preact/relu", nn.ReLU(inplace=True)),
                (
                    "conv1",
                    nn.Conv2d(
                        unit_in_ch,
                        unit_ch[0],
                        unit_ksize[0],
                        stride=1,
                        padding=0,
                        bias=False,
                    ),
                ),
                ("conv1/bn", nn.BatchNorm2d(unit_ch[0], eps=1e-5)),
                ("conv1/relu", nn.ReLU(inplace=True)),
                (
                    "conv2/pad",
                    TFSamepaddingLayer(
                        ksize=unit_ksize[1], stride=stride if idx == 0 else 1
                    ),
                ),
                (
                    "conv2",
                    nn.Conv2d(
                        unit_ch[0],
                        unit_ch[1],
                        unit_ksize[1],
                        stride=stride if idx == 0 else 1,
                        padding=0,
                        bias=False,
                    ),
                ),
                ("conv2/bn", nn.BatchNorm2d(unit_ch[1], eps=1e-5)),
                ("conv2/relu", nn.ReLU(inplace=True)),
                (
                    "conv3",
                    nn.Conv2d(
                        unit_ch[1],
                        unit_ch[2],
                        unit_ksize[2],
                        stride=1,
                        padding=0,
                        bias=False,
                    ),
                ),
            ]
            # * the previous block already ends with a bn-activation (blk_bna), so
            # * the first unit of this block must not start with a preact
            unit_layer = unit_layer if idx != 0 else unit_layer[2:]
            self.units.append(nn.Sequential(OrderedDict(unit_layer)))
            unit_in_ch = unit_ch[-1]

        # 1x1 projection shortcut when the channel count or spatial resolution changes
        if in_ch != unit_ch[-1] or stride != 1:
            self.shortcut = nn.Conv2d(in_ch, unit_ch[-1], 1, stride=stride, bias=False)
        else:
            self.shortcut = None

        self.blk_bna = nn.Sequential(
            OrderedDict(
                [
                    ("bn", nn.BatchNorm2d(unit_in_ch, eps=1e-5)),
                    ("relu", nn.ReLU(inplace=True)),
                ]
            )
        )

    def out_ch(self):
        """Return the number of output channels of this block."""
        return self.unit_ch[-1]

    def init_weights(self):
        """Kaiming (He) initialization for convolutional layers and constant
        initialization for normalization and linear layers.
        """
        for m in self.modules():
            classname = m.__class__.__name__

            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")

            if "norm" in classname.lower():
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

            if "linear" in classname.lower():
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, prev_feat, freeze=False):
        """Forward pass.

        Args:
            prev_feat (torch.Tensor): Input feature map with in_ch channels.
            freeze (bool, optional): If True, run the units without gradient tracking
                during training (weights are effectively frozen). Defaults to False.

        Returns:
            torch.Tensor: Output feature map with unit_ch[-1] channels.
        """
        if self.shortcut is None:
            shortcut = prev_feat
        else:
            shortcut = self.shortcut(prev_feat)

        for idx in range(0, len(self.units)):
            new_feat = prev_feat
            if self.training:
                with torch.set_grad_enabled(not freeze):
                    new_feat = self.units[idx](new_feat)
            else:
                new_feat = self.units[idx](new_feat)
            prev_feat = new_feat + shortcut
            shortcut = prev_feat
        feat = self.blk_bna(prev_feat)
        return feat
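

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): builds the block with
# an illustrative ResNet-50-style configuration and runs a dummy forward pass.
# The concrete values (64 input channels, [1, 3, 1] kernels, [64, 64, 256]
# channels, 3 units, stride 1) are assumptions chosen for the example, and the
# preserved spatial size relies on TFSamepaddingLayer reproducing TensorFlow
# "SAME" padding as provided by the repository.
if __name__ == "__main__":
    block = ResidualBlock(
        in_ch=64,
        unit_ksize=[1, 3, 1],
        unit_ch=[64, 64, 256],
        unit_count=3,
        stride=1,
    )
    block.init_weights()
    block.eval()  # skip the training-only freeze/grad branch in forward

    with torch.no_grad():
        x = torch.randn(1, 64, 32, 32)
        y = block(x)

    # With stride=1 the spatial size is preserved; channels follow unit_ch[-1].
    print(y.shape)  # expected: torch.Size([1, 256, 32, 32])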