# -*- coding: utf-8 -*-
# Dense Block as defined in:
# Huang, Gao, Zhuang Liu, Laurens Van Der Maaten, and Kilian Q. Weinberger.
# "Densely connected convolutional networks." In Proceedings of the IEEE conference
# on computer vision and pattern recognition, pp. 4700-4708. 2017.
#
# Code Snippet adapted from HoverNet implementation (https://github.com/vqdang/hover_net)
#
# @ Fabian Hörst, fabian.hoerst@uk-essen.de
# Institute for Artificial Intelligence in Medicine,
# University Medicine Essen

from collections import OrderedDict
from typing import Sequence

import torch
import torch.nn as nn


class DenseBlock(nn.Module):
    """Dense Block as defined in:

    Huang, Gao, Zhuang Liu, Laurens Van Der Maaten, and Kilian Q. Weinberger.
    "Densely connected convolutional networks." In Proceedings of the IEEE
    conference on computer vision and pattern recognition, pp. 4700-4708. 2017.

    Only performs `valid` convolution (every convolution uses ``padding=0``),
    so the spatial size shrinks with each unit whose kernel is larger than 1.

    Args:
        in_ch: Number of channels of the block input.
        unit_ksize: Kernel sizes of the two convolutions of a unit. Only the
            entries at index 0 and 1 are used.
        unit_ch: Output channels of the two convolutions of a unit. Only the
            entries at index 0 and 1 are used.
        unit_count: Number of units stacked in the block.
        split: ``groups`` argument of the second convolution of each unit.
            Defaults to 1.
    """

    def __init__(
        self,
        in_ch: int,
        unit_ksize: Sequence[int],
        unit_ch: Sequence[int],
        unit_count: int,
        split: int = 1,
    ) -> None:
        super().__init__()
        assert len(unit_ksize) == len(unit_ch), "Unbalance Unit Info"

        self.nr_unit = unit_count
        self.in_ch = in_ch
        self.unit_ch = unit_ch

        # ! For inference only so init values for batchnorm may not match tensorflow
        unit_in_ch = in_ch
        self.units = nn.ModuleList()
        for _ in range(unit_count):
            self.units.append(
                nn.Sequential(
                    OrderedDict(
                        [
                            ("preact_bna/bn", nn.BatchNorm2d(unit_in_ch, eps=1e-5)),
                            ("preact_bna/relu", nn.ReLU(inplace=True)),
                            (
                                "conv1",
                                nn.Conv2d(
                                    unit_in_ch,
                                    unit_ch[0],
                                    unit_ksize[0],
                                    stride=1,
                                    padding=0,
                                    bias=False,
                                ),
                            ),
                            ("conv1/bn", nn.BatchNorm2d(unit_ch[0], eps=1e-5)),
                            ("conv1/relu", nn.ReLU(inplace=True)),
                            # ('conv2/pool', TFSamepaddingLayer(ksize=unit_ksize[1], stride=1)),
                            (
                                "conv2",
                                nn.Conv2d(
                                    unit_ch[0],
                                    unit_ch[1],
                                    unit_ksize[1],
                                    groups=split,
                                    stride=1,
                                    padding=0,
                                    bias=False,
                                ),
                            ),
                        ]
                    )
                )
            )
            # Each unit's output is concatenated onto its input in forward(),
            # so the next unit sees unit_ch[1] additional channels.
            unit_in_ch += unit_ch[1]

        self.blk_bna = nn.Sequential(
            OrderedDict(
                [
                    ("bn", nn.BatchNorm2d(unit_in_ch, eps=1e-5)),
                    ("relu", nn.ReLU(inplace=True)),
                ]
            )
        )

    def out_ch(self) -> int:
        """Return the number of channels produced by :meth:`forward`.

        Uses ``unit_ch[1]`` (the second convolution's width), which is the
        value the constructor accumulates per unit. This keeps the result
        correct even when ``unit_ch`` has more than two entries.
        """
        return self.in_ch + self.nr_unit * self.unit_ch[1]

    def init_weights(self) -> None:
        """Kaiming (HE) initialization for convolutional layers and constant
        initialization for normalization and linear layers.

        Normalization and linear layers are detected by class name ("norm" /
        "linear", case-insensitive). Missing (``None``) weight or bias
        parameters are skipped.
        """
        for m in self.modules():
            classname = m.__class__.__name__.lower()
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            if "norm" in classname:
                # Non-affine normalization layers have no weight/bias to set.
                if getattr(m, "weight", None) is not None:
                    nn.init.constant_(m.weight, 1)
                if getattr(m, "bias", None) is not None:
                    nn.init.constant_(m.bias, 0)
            if "linear" in classname:
                if getattr(m, "bias", None) is not None:
                    nn.init.constant_(m.bias, 0)

    def forward(self, prev_feat: torch.Tensor) -> torch.Tensor:
        """Run all units, densely concatenating their outputs.

        Args:
            prev_feat: Input feature map in NCHW layout with ``in_ch`` channels.

        Returns:
            Feature map with ``out_ch()`` channels after the final
            BatchNorm + ReLU. Its spatial size is reduced by the valid
            convolutions.
        """
        for unit in self.units:
            new_feat = unit(prev_feat)
            # Valid convolutions shrink new_feat; centre-crop the running
            # features to the same spatial size before concatenating channels.
            prev_feat = crop_to_shape(prev_feat, new_feat)
            prev_feat = torch.cat([prev_feat, new_feat], dim=1)
        return self.blk_bna(prev_feat)


# helper functions for cropping
def crop_op(x, cropping: Sequence[int], data_format: str = "NCHW"):
    """Center crop image.

    Args:
        x: input image (4-D, laid out according to ``data_format``)
        cropping: the subtracted amount as ``(height, width)``; an odd amount
            removes the extra row/column from the bottom/right
        data_format: choose either `NCHW` or `NHWC`

    Returns:
        The cropped view of ``x``. A cropping of 0 along an axis leaves that
        axis untouched.
    """
    crop_t = cropping[0] // 2
    crop_b = cropping[0] - crop_t
    crop_l = cropping[1] // 2
    crop_r = cropping[1] - crop_l

    # Use explicit end indices: a negative-index slice such as ``t:-0`` is
    # ``t:0`` and would return an empty result when nothing is cropped.
    if data_format == "NCHW":
        height, width = x.shape[2], x.shape[3]
        return x[:, :, crop_t : height - crop_b, crop_l : width - crop_r]

    height, width = x.shape[1], x.shape[2]
    return x[:, crop_t : height - crop_b, crop_l : width - crop_r, :]


def crop_to_shape(x, y, data_format: str = "NCHW"):
    """Centre crop x so that x has the spatial shape of y.

    The spatial dims of y must not be larger than those of x. Batch and
    channel dims are not compared, because only height and width are cropped.

    Args:
        x: input array
        y: array with desired shape.
        data_format: choose either `NCHW` or `NHWC`

    Returns:
        ``x`` centre-cropped to the height and width of ``y``.

    Raises:
        AssertionError: if ``y`` is spatially larger than ``x``.
    """
    x_shape = x.shape
    y_shape = y.shape
    if data_format == "NCHW":
        crop_shape = (x_shape[2] - y_shape[2], x_shape[3] - y_shape[3])
    else:
        crop_shape = (x_shape[1] - y_shape[1], x_shape[2] - y_shape[2])

    assert (
        crop_shape[0] >= 0 and crop_shape[1] >= 0
    ), "Ensure that y dimensions are smaller than x dimensions!"
    return crop_op(x, crop_shape, data_format)