Spaces:
Runtime error
Runtime error
import math | |
from typing import Optional, Callable | |
import torch | |
import torch.nn as nn | |
from torch import Tensor | |
def make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
    """
    Round a channel count to a nearby multiple of ``divisor``.

    Adapted from the original TensorFlow MobileNet implementation:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py

    :param v: requested (possibly fractional) channel count
    :param divisor: the result is rounded to a multiple of this value
    :param min_value: lower bound for the result; defaults to ``divisor``.
        If it is larger than the rounded value it is returned as given,
        whether or not it is a multiple of ``divisor``.
    :return: the adjusted channel count
    """
    lower_bound = divisor if min_value is None else min_value
    # Round half up to the nearest multiple of `divisor`.
    nearest_multiple = int(v + divisor / 2) // divisor * divisor
    result = max(lower_bound, nearest_multiple)
    # Rounding down must not remove more than 10% of the requested value;
    # if it does, step up by one more `divisor`.
    if result < 0.9 * v:
        result += divisor
    return result
def cnn_out_size(in_size, padding, dilation, kernel, stride):
    """
    Compute the output length of a convolution along a single axis.

    Evaluates ``floor((in_size + 2*padding - dilation*(kernel - 1) - 1) / stride + 1)``.

    :param in_size: input length along the axis
    :param padding: padding applied to each side of the axis
    :param dilation: spacing between kernel elements
    :param kernel: kernel length along the axis
    :param stride: step between consecutive kernel applications
    :return: output length as returned by ``math.floor``
    """
    padded_size = in_size + 2 * padding
    # Number of positions left once the (dilated) kernel has been placed once.
    remaining = padded_size - dilation * (kernel - 1) - 1
    steps = remaining / stride + 1
    return math.floor(steps)
def collapse_dim(x: Tensor, dim: int, mode: str = "pool", pool_fn: Callable[[Tensor, int], Tensor] = torch.mean,
                 combine_dim: Optional[int] = None):
    """
    Collapses dimension of multi-dimensional tensor by pooling or combining dimensions
    :param x: input Tensor
    :param dim: dimension to collapse
    :param mode: 'pool' or 'combine'
    :param pool_fn: function to be applied in case of pooling; it is called as ``pool_fn(x, dim)``
    :param combine_dim: dimension to join 'dim' to; required when mode is 'combine'
    :return: collapsed tensor. For 'pool' this is whatever ``pool_fn`` returns. For 'combine' the
        result has the same rank as ``x``: ``combine_dim`` grows by the size of ``dim`` and ``dim``
        is kept with size 1.
    :raises TypeError: if mode is 'combine' and ``combine_dim`` is None
    :raises ValueError: if ``mode`` is unknown, or ``dim`` and ``combine_dim`` refer to the same axis
    :raises IndexError: if ``dim`` or ``combine_dim`` is out of range for ``x`` in 'combine' mode
    """
    if mode == "pool":
        return pool_fn(x, dim)

    if mode == "combine":
        if combine_dim is None:
            raise TypeError("combine_dim must be given when mode is 'combine'")

        shape = list(x.size())
        # Read both sizes first so out-of-range indices fail with IndexError
        # before any further checks.
        dim_size, target_size = shape[dim], shape[combine_dim]
        rank = len(shape)
        if dim % rank == combine_dim % rank:
            raise ValueError("dim and combine_dim must refer to different dimensions")

        # Fold the *size* of `dim` (not its index) into `combine_dim`.
        shape[combine_dim] = target_size * dim_size
        shape[dim] = 1
        # reshape (rather than view) also accepts non-contiguous inputs.
        # NOTE: this only reinterprets the row-major element order; it is a
        # true merge of the two axes when they are adjacent.
        return x.reshape(shape)

    # Fail loudly instead of silently returning None for a misspelled mode.
    raise ValueError(f"unknown mode {mode!r}; expected 'pool' or 'combine'")
class CollapseDim(nn.Module):
    """
    Module form of ``collapse_dim``: the collapse settings are stored at
    construction time and applied to every input passed to ``forward``.
    """

    def __init__(self, dim: int, mode: str = "pool", pool_fn: Callable[[Tensor, int], Tensor] = torch.mean,
                 combine_dim: int = None):
        """
        :param dim: dimension to collapse
        :param mode: 'pool' or 'combine'
        :param pool_fn: function to be applied in case of pooling
        :param combine_dim: dimension to join 'dim' to
        """
        super().__init__()
        self.dim, self.mode = dim, mode
        self.pool_fn, self.combine_dim = pool_fn, combine_dim

    def forward(self, x):
        """Apply ``collapse_dim`` to ``x`` using the stored settings."""
        return collapse_dim(x, self.dim, self.mode, self.pool_fn, self.combine_dim)