|
import math |
|
from typing import Optional, Callable |
|
import torch |
|
import torch.nn as nn |
|
from torch import Tensor |
|
|
|
|
|
def make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
    """
    Round ``v`` to a nearby multiple of ``divisor`` without going below ``min_value``.

    Adapted from the TensorFlow MobileNet reference implementation, where it is used
    to keep layer channel counts divisible by a fixed number (8 there):
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py

    :param v: value to round (e.g. a scaled channel count)
    :param divisor: the result is built from multiples of this number
    :param min_value: lower bound for the result; falls back to ``divisor`` when None
    :return: the rounded value, bumped up by one ``divisor`` if rounding would
        otherwise leave it below 90% of ``v``
    """
    lower_bound = divisor if min_value is None else min_value

    # Adding half a divisor before truncating and flooring picks the closest multiple.
    closest_multiple = int(v + divisor / 2) // divisor * divisor
    rounded = max(lower_bound, closest_multiple)

    # Rounding down must not remove more than 10% of the requested value.
    return rounded + divisor if rounded < 0.9 * v else rounded
|
|
|
|
|
def cnn_out_size(in_size, padding, dilation, kernel, stride):
    """
    Evaluate ``floor((in_size + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1)``.

    This has the form of the usual output-length formula for a strided, dilated
    convolution along a single axis.

    :param in_size: input length along the axis
    :param padding: padding added on each side of the axis
    :param dilation: spacing between kernel elements
    :param kernel: kernel length along the axis
    :param stride: step between successive kernel positions
    :return: the floored result as an int
    """
    numerator = in_size + 2 * padding - dilation * (kernel - 1) - 1
    positions = numerator / stride + 1
    return math.floor(positions)
|
|
|
|
|
def collapse_dim(x: Tensor, dim: int, mode: str = "pool", pool_fn: Callable[[Tensor, int], Tensor] = torch.mean,
                 combine_dim: Optional[int] = None):
    """
    Collapse one dimension of a tensor, either by pooling over it or by folding it into another dimension.

    :param x: input Tensor
    :param dim: index of the dimension to collapse
    :param mode: 'pool' to reduce ``dim`` with ``pool_fn``, 'combine' to merge ``dim`` into ``combine_dim``
    :param pool_fn: reduction called as ``pool_fn(x, dim)`` in 'pool' mode
    :param combine_dim: index of the dimension that absorbs ``dim``; required in 'combine' mode
    :return: in 'pool' mode, whatever ``pool_fn(x, dim)`` returns; in 'combine' mode, a view of ``x``
        in which ``combine_dim`` is multiplied by the length of ``dim`` and ``dim`` is kept with length 1
    :raises TypeError: if ``mode`` is 'combine' and ``combine_dim`` is None
    :raises ValueError: if ``mode`` is not 'pool' or 'combine', or if ``dim`` and ``combine_dim``
        refer to the same axis
    """
    if mode == "pool":
        return pool_fn(x, dim)

    if mode == "combine":
        if combine_dim is None:
            raise TypeError("collapse_dim: 'combine_dim' is required when mode='combine'")

        shape = list(x.size())
        # Indexing first lets out-of-range axes fail with IndexError before any arithmetic.
        merged_length = shape[combine_dim] * shape[dim]
        rank = len(shape)
        if dim % rank == combine_dim % rank:
            raise ValueError("collapse_dim: 'dim' and 'combine_dim' must refer to different axes")

        # Scale by the LENGTH of the collapsed axis, not by its index, so the element
        # count is preserved for any shape (including dim=0 and negative indices).
        shape[combine_dim] = merged_length
        shape[dim] = 1
        # NOTE: view() keeps the existing memory order, so this is a true flatten of the
        # two axes only when they are adjacent and x is contiguous.
        return x.view(shape)

    raise ValueError(f"collapse_dim: unknown mode {mode!r}, expected 'pool' or 'combine'")
|
|
|
|
|
class CollapseDim(nn.Module):
    """
    Module wrapper around ``collapse_dim``: the collapse settings are fixed at
    construction time and applied to every input passed to ``forward``.
    """

    def __init__(self, dim: int, mode: str = "pool", pool_fn: Callable[[Tensor, int], Tensor] = torch.mean,
                 combine_dim: int = None):
        """
        :param dim: dimension to collapse
        :param mode: 'pool' or 'combine'
        :param pool_fn: function handed to ``collapse_dim`` for the pooling case
        :param combine_dim: dimension handed to ``collapse_dim`` for the combining case
        """
        super().__init__()
        # Stored as plain attributes; none of these are parameters or sub-modules.
        self.dim = dim
        self.mode = mode
        self.pool_fn = pool_fn
        self.combine_dim = combine_dim

    def forward(self, x):
        """Apply ``collapse_dim`` to ``x`` with the stored settings and return its result."""
        return collapse_dim(x, self.dim, self.mode, self.pool_fn, self.combine_dim)
|
|