# ProPainter/model/canny/filter.py
# Last commit: 320e465 "init code" (sczhou).
from typing import List
import torch
import torch.nn.functional as F
from .kernels import normalize_kernel2d
def _compute_padding(kernel_size: List[int]) -> List[int]:
"""Compute padding tuple."""
# 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom)
# https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
if len(kernel_size) < 2:
raise AssertionError(kernel_size)
computed = [k - 1 for k in kernel_size]
# for even kernels we need to do asymmetric padding :(
out_padding = 2 * len(kernel_size) * [0]
for i in range(len(kernel_size)):
computed_tmp = computed[-(i + 1)]
pad_front = computed_tmp // 2
pad_rear = computed_tmp - pad_front
out_padding[2 * i + 0] = pad_front
out_padding[2 * i + 1] = pad_rear
return out_padding
def filter2d(
    input: torch.Tensor,
    kernel: torch.Tensor,
    border_type: str = 'reflect',
    normalized: bool = False,
    padding: str = 'same',
) -> torch.Tensor:
    r"""Convolve a tensor with a 2d kernel.

    The kernel is applied independently to every channel of the input. With
    ``padding='same'`` the input is first padded using ``border_type`` so the
    output keeps the spatial size of the input. With ``padding='valid'`` no
    padding is applied, and the output shrinks by ``kH - 1`` rows and ``kW - 1``
    columns. Like ``F.conv2d``, this computes a cross-correlation (the kernel is
    not flipped).

    Args:
        input: the input tensor with shape of :math:`(B, C, H, W)`.
        kernel: the kernel to be convolved with the input tensor. The kernel
            shape must be :math:`(1, kH, kW)` (one kernel shared by the whole
            batch) or :math:`(B, kH, kW)` (one kernel per batch element).
        border_type: the padding mode to be applied before convolving.
            The expected modes are: ``'constant'``, ``'reflect'``,
            ``'replicate'`` or ``'circular'``. Only used when ``padding='same'``.
        normalized: If True, kernel will be L1 normalized.
        padding: This defines the type of padding.
            2 modes available ``'same'`` or ``'valid'``.

    Return:
        torch.Tensor: the convolved tensor with the same number of channels as
        the input. Its shape is :math:`(B, C, H, W)` for ``'same'`` and
        :math:`(B, C, H - kH + 1, W - kW + 1)` for ``'valid'``.

    Raises:
        TypeError: if ``input``/``kernel`` are not tensors or ``border_type``/
            ``padding`` are not strings.
        ValueError: if ``border_type`` or ``padding`` is not a supported mode,
            ``input`` is not 4D, or ``kernel`` is not 3D with a leading
            dimension of 1 or B.

    Example:
        >>> input = torch.tensor([[[
        ...    [0., 0., 0., 0., 0.],
        ...    [0., 0., 0., 0., 0.],
        ...    [0., 0., 5., 0., 0.],
        ...    [0., 0., 0., 0., 0.],
        ...    [0., 0., 0., 0., 0.],]]])
        >>> kernel = torch.ones(1, 3, 3)
        >>> filter2d(input, kernel, padding='same')
        tensor([[[[0., 0., 0., 0., 0.],
                  [0., 5., 5., 5., 0.],
                  [0., 5., 5., 5., 0.],
                  [0., 5., 5., 5., 0.],
                  [0., 0., 0., 0., 0.]]]])
    """
    if not isinstance(input, torch.Tensor):
        raise TypeError(f"Input input is not torch.Tensor. Got {type(input)}")
    if not isinstance(kernel, torch.Tensor):
        raise TypeError(f"Input kernel is not torch.Tensor. Got {type(kernel)}")
    if not isinstance(border_type, str):
        raise TypeError(f"Input border_type is not string. Got {type(border_type)}")
    if border_type not in ['constant', 'reflect', 'replicate', 'circular']:
        raise ValueError(
            "Invalid border type, we expect 'constant', 'reflect', 'replicate', "
            f"'circular'. Got: {border_type}"
        )
    if not isinstance(padding, str):
        raise TypeError(f"Input padding is not string. Got {type(padding)}")
    if padding not in ['valid', 'same']:
        raise ValueError(f"Invalid padding mode, we expect 'valid' or 'same'. Got: {padding}")
    if input.dim() != 4:
        raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}")
    # The kernel must be 3D, and its leading dim must be 1 (shared) or B (per
    # batch element). Any other leading size would be silently mis-paired with
    # the batch by the group reshape below.
    if kernel.dim() != 3 or kernel.shape[0] not in (1, input.shape[0]):
        raise ValueError(f"Invalid kernel shape, we expect 1xHxW or BxHxW. Got: {kernel.shape}")

    b, c = input.shape[0], input.shape[1]
    # (Bk, kH, kW) -> (Bk, 1, kH, kW), cast to input's dtype/device.
    tmp_kernel: torch.Tensor = kernel.unsqueeze(1).to(input)
    if normalized:
        tmp_kernel = normalize_kernel2d(tmp_kernel)
    # Repeat the kernel for every channel: (Bk, C, kH, kW).
    tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)
    height, width = tmp_kernel.shape[-2:]

    if padding == 'same':
        padding_shape: List[int] = _compute_padding([height, width])
        input = F.pad(input, padding_shape, mode=border_type)

    # Fold batch/channels into conv groups so each single-channel kernel sees
    # exactly one (batch, channel) plane. When Bk == 1, the batch stays as
    # conv2d's batch dimension. reshape (unlike view) also accepts
    # non-contiguous inputs such as channels_last tensors.
    tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
    input = input.reshape(-1, tmp_kernel.size(0), input.size(-2), input.size(-1))

    output = F.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)
    # Unfold back to (B, C, H_out, W_out). H_out/W_out equal H/W for 'same'
    # and H - kH + 1 / W - kW + 1 for 'valid'.
    return output.reshape(b, c, output.size(-2), output.size(-1))
def filter2d_separable(
    input: torch.Tensor,
    kernel_x: torch.Tensor,
    kernel_y: torch.Tensor,
    border_type: str = 'reflect',
    normalized: bool = False,
    padding: str = 'same',
) -> torch.Tensor:
    r"""Convolve a tensor with two 1d kernels, in x and y directions.

    The input is filtered along the width axis with ``kernel_x`` and then along
    the height axis with ``kernel_y``. Both passes go through ``filter2d`` with
    the same ``border_type``, ``normalized`` and ``padding`` settings, so each
    1d kernel is normalized separately when ``normalized`` is True. Each kernel
    is applied independently at every channel of the input.

    Args:
        input: the input tensor with shape of :math:`(B, C, H, W)`.
        kernel_x: the kernel to be convolved with the input
            tensor. The kernel shape must be :math:`(1, kW)` or :math:`(B, kW)`.
        kernel_y: the kernel to be convolved with the input
            tensor. The kernel shape must be :math:`(1, kH)` or :math:`(B, kH)`.
        border_type: the padding mode to be applied before convolving.
            The expected modes are: ``'constant'``, ``'reflect'``,
            ``'replicate'`` or ``'circular'``.
        normalized: If True, kernel will be L1 normalized.
        padding: This defines the type of padding.
            2 modes available ``'same'`` or ``'valid'``.

    Return:
        torch.Tensor: the convolved tensor with the same number of channels as
        the input. Its shape is :math:`(B, C, H, W)` for ``'same'`` and
        :math:`(B, C, H - kH + 1, W - kW + 1)` for ``'valid'``.

    Example:
        >>> input = torch.tensor([[[
        ...    [0., 0., 0., 0., 0.],
        ...    [0., 0., 0., 0., 0.],
        ...    [0., 0., 5., 0., 0.],
        ...    [0., 0., 0., 0., 0.],
        ...    [0., 0., 0., 0., 0.],]]])
        >>> kernel = torch.ones(1, 3)
        >>> filter2d_separable(input, kernel, kernel, padding='same')
        tensor([[[[0., 0., 0., 0., 0.],
                  [0., 5., 5., 5., 0.],
                  [0., 5., 5., 5., 0.],
                  [0., 5., 5., 5., 0.],
                  [0., 0., 0., 0., 0.]]]])
    """
    # (Bk, kW) -> (Bk, 1, kW): a one-row kernel filtering along the width axis.
    # The new axis goes at -2 so that a per-batch kernel keeps Bk as its leading
    # (batch) dim; unsqueeze(0) would have turned Bk into the kernel height.
    out_x = filter2d(input, kernel_x.unsqueeze(-2), border_type, normalized, padding)
    # (Bk, kH) -> (Bk, kH, 1): a one-column kernel filtering along the height axis.
    out = filter2d(out_x, kernel_y.unsqueeze(-1), border_type, normalized, padding)
    return out
def filter3d(
    input: torch.Tensor, kernel: torch.Tensor, border_type: str = 'replicate', normalized: bool = False
) -> torch.Tensor:
    r"""Convolve a tensor with a 3d kernel.

    The kernel is applied independently to every channel of the input. The
    input is first padded using ``border_type`` so the output keeps the
    input's shape. Like ``F.conv3d``, this computes a cross-correlation (the
    kernel is not flipped).

    Args:
        input: the input tensor with shape of :math:`(B, C, D, H, W)`.
        kernel: the kernel to be convolved with the input tensor. The kernel
            shape must be :math:`(1, kD, kH, kW)` (one kernel shared by the
            whole batch) or :math:`(B, kD, kH, kW)` (one kernel per batch
            element).
        border_type: the padding mode to be applied before convolving.
            The expected modes are: ``'constant'``,
            ``'replicate'`` or ``'circular'``. The value is forwarded to
            ``F.pad`` without further validation here.
        normalized: If True, kernel will be L1 normalized.

    Return:
        the convolved tensor of same size and numbers of channels
        as the input with shape :math:`(B, C, D, H, W)`.

    Raises:
        TypeError: if ``input``/``kernel`` are not tensors or ``border_type``
            is not a string.
        ValueError: if ``input`` is not 5D, or ``kernel`` is not 4D with a
            leading dimension of 1 or B.

    Example:
        >>> input = torch.tensor([[[
        ...    [[0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.]],
        ...    [[0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.],
        ...     [0., 0., 5., 0., 0.],
        ...     [0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.]],
        ...    [[0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.],
        ...     [0., 0., 0., 0., 0.]]
        ... ]]])
        >>> kernel = torch.ones(1, 3, 3, 3)
        >>> filter3d(input, kernel)
        tensor([[[[[0., 0., 0., 0., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 0., 0., 0., 0.]],
        <BLANKLINE>
                  [[0., 0., 0., 0., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 0., 0., 0., 0.]],
        <BLANKLINE>
                  [[0., 0., 0., 0., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 5., 5., 5., 0.],
                   [0., 0., 0., 0., 0.]]]]])
    """
    if not isinstance(input, torch.Tensor):
        raise TypeError(f"Input input is not torch.Tensor. Got {type(input)}")
    if not isinstance(kernel, torch.Tensor):
        raise TypeError(f"Input kernel is not torch.Tensor. Got {type(kernel)}")
    if not isinstance(border_type, str):
        raise TypeError(f"Input border_type is not string. Got {type(border_type)}")
    if input.dim() != 5:
        raise ValueError(f"Invalid input shape, we expect BxCxDxHxW. Got: {input.shape}")
    # The kernel must be 4D, and its leading dim must be 1 (shared) or B (per
    # batch element). Any other leading size would be mis-paired with the
    # batch by the group reshape below.
    if kernel.dim() != 4 or kernel.shape[0] not in (1, input.shape[0]):
        raise ValueError(f"Invalid kernel shape, we expect 1xDxHxW or BxDxHxW. Got: {kernel.shape}")

    b, c, d, h, w = input.shape
    # (Bk, kD, kH, kW) -> (Bk, 1, kD, kH, kW), cast to input's dtype/device.
    tmp_kernel: torch.Tensor = kernel.unsqueeze(1).to(input)
    if normalized:
        bk, dk, hk, wk = kernel.shape
        # Flatten to (Bk, kD, kH*kW) so the 2D normaliser sees the whole 3D
        # kernel, then restore the shape. Assumes normalize_kernel2d works
        # over the last two dims — confirm in .kernels.
        tmp_kernel = normalize_kernel2d(tmp_kernel.reshape(bk, dk, hk * wk)).reshape_as(tmp_kernel)
    # Repeat the kernel for every channel: (Bk, C, kD, kH, kW).
    tmp_kernel = tmp_kernel.expand(-1, c, -1, -1, -1)

    depth, height, width = tmp_kernel.shape[-3:]
    padding_shape: List[int] = _compute_padding([depth, height, width])
    input_pad: torch.Tensor = F.pad(input, padding_shape, mode=border_type)

    # Fold batch/channels into conv groups so each single-channel kernel sees
    # exactly one (batch, channel) volume. reshape tolerates non-contiguous
    # tensors where view would fail.
    tmp_kernel = tmp_kernel.reshape(-1, 1, depth, height, width)
    input_pad = input_pad.reshape(
        -1, tmp_kernel.size(0), input_pad.size(-3), input_pad.size(-2), input_pad.size(-1)
    )

    output = F.conv3d(input_pad, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)
    return output.reshape(b, c, d, h, w)