|
|
|
import numpy as np |
|
from skimage.metrics import peak_signal_noise_ratio, structural_similarity |
|
from typing import Optional |
|
|
|
def ssim(
    gt: np.ndarray, pred: np.ndarray, data_range: Optional[float] = None
) -> np.ndarray:
    """Compute the Structural Similarity Index Metric (SSIM) of a volume.

    SSIM is computed independently for every 2-D slice along axis 0 of the
    inputs, and the per-slice scores are averaged.

    Args:
        gt: Ground-truth volume; must have exactly three dimensions.
        pred: Predicted volume; must have the same shape as ``gt``.
        data_range: Forwarded to ``structural_similarity`` as ``data_range``.
            Defaults to ``gt.max()``, which equals the true dynamic range
            (max - min) only when the ground truth's minimum is 0.

    Returns:
        A one-element array holding the mean per-slice SSIM.

    Raises:
        ValueError: If ``gt`` is not 3-D, or if ``pred`` differs from ``gt``
            in number of dimensions or in shape.
    """
    if not gt.ndim == 3:
        raise ValueError("Unexpected number of dimensions in ground truth.")
    if not gt.ndim == pred.ndim:
        raise ValueError("Ground truth dimensions does not match pred.")
    # Matching ndim is not enough: with a different slice count the loop
    # would silently drop extra pred slices or fail with an IndexError.
    if gt.shape != pred.shape:
        raise ValueError(
            f"Ground truth shape {gt.shape} does not match pred shape "
            f"{pred.shape}."
        )

    data_range = gt.max() if data_range is None else data_range

    # Accumulate into a 1-element array so the result keeps the historical
    # return shape (1,) rather than collapsing to a scalar.
    total = np.array([0])
    for gt_slice, pred_slice in zip(gt, pred):
        total = total + structural_similarity(
            gt_slice, pred_slice, data_range=data_range
        )

    return total / gt.shape[0]
|
|
|
def psnr(
    gt: np.ndarray, pred: np.ndarray, data_range: Optional[float] = None
) -> np.ndarray:
    """Return the Peak Signal-to-Noise Ratio (PSNR) of ``pred`` versus ``gt``.

    The whole arrays are compared in one call (no per-slice averaging).

    Args:
        gt: Ground-truth array.
        pred: Predicted array, compared element-wise against ``gt``.
        data_range: Forwarded to ``peak_signal_noise_ratio`` as
            ``data_range``. When omitted, ``gt.max()`` is used instead.

    Returns:
        The value produced by ``peak_signal_noise_ratio``.
    """
    if data_range is None:
        # Fall back to the ground-truth peak when the caller gives no range.
        data_range = gt.max()
    return peak_signal_noise_ratio(gt, pred, data_range=data_range)
|
|