|
|
|
|
|
|
|
import torch |
|
|
|
def gaussian_kernel(size, sigma=1.0):
    """
    Generates a normalized 2D Gaussian kernel using PyTorch.

    Parameters:
    - size: The size of the kernel (an integer). It's recommended to use an odd
            number to have a central pixel. The sampling grid spans
            -(size // 2) .. (size // 2) on each axis, so an even size yields a
            (size + 1) x (size + 1) kernel.
    - sigma: The standard deviation of the Gaussian distribution. Must be
             positive.

    Returns:
    - A 2D PyTorch tensor representing the Gaussian kernel, scaled so that its
      entries sum to 1.

    Raises:
    - ValueError: If sigma is not positive (a zero sigma would otherwise
      produce a kernel filled with NaN).
    """
    if sigma <= 0:
        raise ValueError("Sigma must be positive.")

    half = int(size) // 2
    offsets = torch.arange(-half, half + 1)

    # Broadcast a column of offsets against a row of offsets instead of calling
    # torch.meshgrid, which warns on recent PyTorch releases when no `indexing`
    # argument is supplied. The squared distances are identical.
    sq_dist = offsets[:, None] ** 2 + offsets[None, :] ** 2

    g = torch.exp(-sq_dist / (2 * sigma**2))
    return g / g.sum()
|
|
|
def laplacian_kernel(size, scale=1.0):
    """
    Creates a Laplacian kernel for edge detection with an adjustable size and scale factor.

    Before scaling, the four direct neighbours of the centre are set to 1.0 and,
    for sizes larger than 3, every cell on the outer border is set to 1.0 as
    well; all remaining off-centre cells are 0. After multiplying by scale, the
    centre value is chosen so that the whole kernel sums to zero.

    Parameters:
    - size: The size of the kernel (an odd integer, at least 3).
    - scale: A float that adjusts the intensity of the edge detection effect.

    Returns:
    - A 2D float32 PyTorch tensor of shape (size, size) representing the scaled
      Laplacian kernel.

    Raises:
    - ValueError: If size is even, or if size is smaller than 3 (there would be
      no room for the centre's neighbours).
    """
    if size % 2 == 0:
        raise ValueError("Size must be odd.")
    if size < 3:
        raise ValueError("Size must be at least 3.")

    center = size // 2
    kernel = torch.zeros((size, size), dtype=torch.float32)

    # Classic 4-neighbour stencil around the centre.
    kernel[center, center] = -4.0
    kernel[center, center - 1] = kernel[center, center + 1] = 1.0
    kernel[center - 1, center] = kernel[center + 1, center] = 1.0

    # Larger kernels additionally get a frame of ones on the outer border.
    if size > 3:
        kernel[0, :] = 1.0
        kernel[-1, :] = 1.0
        kernel[:, 0] = 1.0
        kernel[:, -1] = 1.0

    kernel *= scale

    # Subtract the current total from the centre so the kernel sums to zero.
    kernel[center, center] = -torch.sum(kernel) + kernel[center, center]

    return kernel
|
|
|
def fftshift(input):
    """
    Reorients the FFT output so the zero-frequency component is at the center.

    Parameters:
    - input: A 2D tensor representing the FFT output. Only dimensions 0 and 1
      are shifted.

    Returns:
    - A 2D tensor with the zero-frequency component shifted to the center,
      i.e. the element at index [0, 0] ends up at [rows // 2, cols // 2].
    """
    # Shift each axis by floor(n / 2). Using ceil(n / 2) instead is the inverse
    # shift and leaves the zero-frequency bin one past the centre whenever the
    # axis length is odd.
    shifts = (input.shape[0] // 2, input.shape[1] // 2)
    return torch.roll(input, shifts=shifts, dims=(0, 1))
|
|