|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import numpy as np |
|
|
|
|
|
def gaussian_kernel(kernel_size, sigma):
    """Build a normalized 2-D Gaussian kernel as a NumPy array.

    Args:
        kernel_size: Side length of the square kernel. The Gaussian is
            centered at ``(kernel_size - 1) / 2`` along both axes.
        sigma: Standard deviation of the Gaussian; must be non-zero.

    Returns:
        A ``(kernel_size, kernel_size)`` float array whose entries sum to 1.
    """
    # Float row/column index grids -- exactly what np.fromfunction feeds its callable.
    rows, cols = np.indices((kernel_size, kernel_size), dtype=float)
    center = (kernel_size - 1) / 2
    sq_dist = (rows - center) ** 2 + (cols - center) ** 2
    # The continuous-density prefactor cancels in the final normalization;
    # it is kept so the arithmetic matches the original formula exactly.
    weights = (1 / (2 * np.pi * sigma ** 2)) * np.exp(-sq_dist / (2 * sigma ** 2))
    return weights / weights.sum()
|
|
|
|
|
class GaussianBlur(nn.Module):
    """Apply the same 2-D Gaussian blur to each channel independently.

    A normalized ``kernel_size x kernel_size`` Gaussian from
    ``gaussian_kernel`` is convolved with every input channel using a grouped
    convolution (``groups=channels``), so channels are never mixed.

    Args:
        channels: Number of input channels. ``forward`` passes this as
            ``groups`` to ``F.conv2d``, so it must equal the channel
            dimension of the input.
        kernel_size: Side length of the square kernel.
        sigma: Standard deviation of the Gaussian.
    """

    def __init__(self, channels, kernel_size, sigma):
        super().__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        self.sigma = sigma
        # "Same" padding for odd kernel sizes. For an even kernel_size the
        # output is one pixel larger than the input in each spatial dimension.
        self.padding = kernel_size // 2

        kernel = torch.tensor(gaussian_kernel(kernel_size, sigma), dtype=torch.float32)
        # Shape (channels, 1, k, k): one single-input-channel filter per group.
        # Use repeat() rather than expand(). An expanded view aliases a single
        # k*k storage across all channels; registered as a buffer, it breaks
        # in-place writes such as load_state_dict's copy_ (overlapping-memory
        # error).
        kernel = kernel.view(1, 1, kernel_size, kernel_size).repeat(channels, 1, 1, 1)
        self.register_buffer('kernel', kernel)

    def forward(self, x):
        """Blur ``x`` channel-wise.

        Args:
            x: Input to ``F.conv2d``, presumably ``(N, channels, H, W)``
                (TODO confirm against callers); the channel dimension must
                equal ``self.channels``.

        Returns:
            The blurred tensor. The kernel is cast to ``x``'s dtype and
            device before convolving.
        """
        return F.conv2d(x, self.kernel.to(x), padding=self.padding, groups=self.channels)
|
|
|
|
|
# Module-level instance: blurs 4-channel inputs with a 7x7 Gaussian of sigma 0.8.
gaussian_filter_2d = GaussianBlur(channels=4, kernel_size=7, sigma=0.8)
|
|