Spaces:
Runtime error
Runtime error
from collections import OrderedDict | |
import annotator.uniformer.mmcv as mmcv | |
import numpy as np | |
import torch | |
def f_score(precision, recall, beta=1):
    """Compute the F-beta score from precision and recall.

    The score is ``(1 + beta**2) * P * R / (beta**2 * P + R)``; a larger
    ``beta`` gives recall more weight relative to precision.

    Args:
        precision (float | torch.Tensor): The precision value(s).
        recall (float | torch.Tensor): The recall value(s).
        beta (int): Weight of recall in the combined score. Default: 1.

    Returns:
        float | torch.Tensor: The F-beta score (a tensor when the inputs are
        tensors). Plain floats with a zero denominator raise
        ``ZeroDivisionError``; tensors yield NaN instead.
    """
    beta_sq = beta**2
    numerator = (1 + beta_sq) * (precision * recall)
    denominator = beta_sq * precision + recall
    return numerator / denominator
def intersect_and_union(pred_label,
                        label,
                        num_classes,
                        ignore_index,
                        label_map=None,
                        reduce_zero_label=False):
    """Calculate per-class intersection and union areas for one image.

    Args:
        pred_label (ndarray | str): Prediction segmentation map, or the
            filename of a ``.npy`` file holding it (read with ``np.load``).
        label (ndarray | str): Ground truth segmentation map, or the filename
            of an image holding it.
        num_classes (int): Number of categories.
        ignore_index (int): Label value excluded from evaluation.
        label_map (dict, optional): Mapping from old label ids to new label
            ids, applied to ``label`` whether it is given as an array or a
            filename. Every entry is matched against the original labels, so
            chained entries such as ``{1: 2, 2: 0}`` are not applied
            transitively. Default: None (no remapping).
        reduce_zero_label (bool): Whether to ignore label 0 and shift the
            remaining labels down by one. Default: False.

    Returns:
        torch.Tensor: The intersection of prediction and ground truth
            histogram on all classes.
        torch.Tensor: The union of prediction and ground truth histogram on
            all classes.
        torch.Tensor: The prediction histogram on all classes.
        torch.Tensor: The ground truth histogram on all classes.

    The caller's ``label`` array is never modified.
    """
    if isinstance(pred_label, str):
        pred_label = torch.from_numpy(np.load(pred_label))
    else:
        pred_label = torch.from_numpy(pred_label)

    if isinstance(label, str):
        label = torch.from_numpy(
            mmcv.imread(label, flag='unchanged', backend='pillow'))
    else:
        label = torch.from_numpy(label)

    # ``torch.from_numpy`` shares memory with the source ndarray, so all
    # in-place edits below are done on copies to keep the caller's data intact.
    if label_map:
        remapped = label.clone()
        for old_id, new_id in label_map.items():
            # Match against the un-remapped labels so the output of one entry
            # is never re-mapped by a later entry.
            remapped[label == old_id] = new_id
        label = remapped
    if reduce_zero_label:
        label = label.clone()
        # Send 0 to 255, shift everything down by one, then restore 254 -> 255
        # so both the former 0 and the former 255 end up as 255.
        # NOTE(review): 255 is hard-coded here; reduced-zero pixels are only
        # excluded below if ignore_index == 255 — confirm with callers.
        label[label == 0] = 255
        label = label - 1
        label[label == 254] = 255

    mask = (label != ignore_index)
    pred_label = pred_label[mask]
    label = label[mask]

    intersect = pred_label[pred_label == label]
    # One histogram bin per class id 0..num_classes-1; values outside that
    # range are dropped by histc.
    area_intersect = torch.histc(
        intersect.float(), bins=(num_classes), min=0, max=num_classes - 1)
    area_pred_label = torch.histc(
        pred_label.float(), bins=(num_classes), min=0, max=num_classes - 1)
    area_label = torch.histc(
        label.float(), bins=(num_classes), min=0, max=num_classes - 1)
    area_union = area_pred_label + area_label - area_intersect
    return area_intersect, area_union, area_pred_label, area_label
def total_intersect_and_union(results,
                              gt_seg_maps,
                              num_classes,
                              ignore_index,
                              label_map=None,
                              reduce_zero_label=False):
    """Calculate Total Intersection and Union summed over all images.

    Args:
        results (list[ndarray] | list[str]): List of prediction segmentation
            maps or list of prediction result filenames.
        gt_seg_maps (list[ndarray] | list[str]): List of ground truth
            segmentation maps or list of label filenames. Must have the same
            length as ``results``.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        label_map (dict, optional): Mapping old labels to new labels.
            Default: None (no remapping).
        reduce_zero_label (bool): Whether to ignore zero label.
            Default: False.

    Returns:
        torch.Tensor: float64 intersection histogram on all classes,
            shape (num_classes, ).
        torch.Tensor: float64 union histogram on all classes.
        torch.Tensor: float64 prediction histogram on all classes.
        torch.Tensor: float64 ground truth histogram on all classes.
    """
    num_imgs = len(results)
    assert len(gt_seg_maps) == num_imgs
    total_area_intersect = torch.zeros((num_classes, ), dtype=torch.float64)
    total_area_union = torch.zeros((num_classes, ), dtype=torch.float64)
    total_area_pred_label = torch.zeros((num_classes, ), dtype=torch.float64)
    total_area_label = torch.zeros((num_classes, ), dtype=torch.float64)
    for result, gt_seg_map in zip(results, gt_seg_maps):
        area_intersect, area_union, area_pred_label, area_label = \
            intersect_and_union(
                result, gt_seg_map, num_classes, ignore_index,
                label_map, reduce_zero_label)
        total_area_intersect += area_intersect
        total_area_union += area_union
        total_area_pred_label += area_pred_label
        total_area_label += area_label
    return total_area_intersect, total_area_union, total_area_pred_label, \
        total_area_label
def mean_iou(results,
             gt_seg_maps,
             num_classes,
             ignore_index,
             nan_to_num=None,
             label_map=None,
             reduce_zero_label=False):
    """Calculate Mean Intersection and Union (mIoU).

    Thin wrapper around ``eval_metrics`` with ``metrics=['mIoU']``.

    Args:
        results (list[ndarray] | list[str]): List of prediction segmentation
            maps or list of prediction result filenames.
        gt_seg_maps (list[ndarray] | list[str]): List of ground truth
            segmentation maps or list of label filenames.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.
        label_map (dict, optional): Mapping old labels to new labels.
            Default: None (no remapping).
        reduce_zero_label (bool): Whether to ignore zero label.
            Default: False.

    Returns:
        dict[str, ndarray]:
            <aAcc> 0-d ndarray: Overall accuracy on all images.
            <IoU> ndarray: Per category IoU, shape (num_classes, ).
            <Acc> ndarray: Per category accuracy, shape (num_classes, ).
    """
    iou_result = eval_metrics(
        results=results,
        gt_seg_maps=gt_seg_maps,
        num_classes=num_classes,
        ignore_index=ignore_index,
        metrics=['mIoU'],
        nan_to_num=nan_to_num,
        label_map=label_map,
        reduce_zero_label=reduce_zero_label)
    return iou_result
def mean_dice(results,
              gt_seg_maps,
              num_classes,
              ignore_index,
              nan_to_num=None,
              label_map=None,
              reduce_zero_label=False):
    """Calculate Mean Dice (mDice).

    Thin wrapper around ``eval_metrics`` with ``metrics=['mDice']``.

    Args:
        results (list[ndarray] | list[str]): List of prediction segmentation
            maps or list of prediction result filenames.
        gt_seg_maps (list[ndarray] | list[str]): List of ground truth
            segmentation maps or list of label filenames.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.
        label_map (dict, optional): Mapping old labels to new labels.
            Default: None (no remapping).
        reduce_zero_label (bool): Whether to ignore zero label.
            Default: False.

    Returns:
        dict[str, ndarray]:
            <aAcc> 0-d ndarray: Overall accuracy on all images.
            <Dice> ndarray: Per category dice, shape (num_classes, ).
            <Acc> ndarray: Per category accuracy, shape (num_classes, ).
    """
    dice_result = eval_metrics(
        results=results,
        gt_seg_maps=gt_seg_maps,
        num_classes=num_classes,
        ignore_index=ignore_index,
        metrics=['mDice'],
        nan_to_num=nan_to_num,
        label_map=label_map,
        reduce_zero_label=reduce_zero_label)
    return dice_result
def mean_fscore(results,
                gt_seg_maps,
                num_classes,
                ignore_index,
                nan_to_num=None,
                label_map=None,
                reduce_zero_label=False,
                beta=1):
    """Calculate Mean F-Score (mFscore) together with precision and recall.

    Thin wrapper around ``eval_metrics`` with ``metrics=['mFscore']``.

    Args:
        results (list[ndarray] | list[str]): List of prediction segmentation
            maps or list of prediction result filenames.
        gt_seg_maps (list[ndarray] | list[str]): List of ground truth
            segmentation maps or list of label filenames.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.
        label_map (dict, optional): Mapping old labels to new labels.
            Default: None (no remapping).
        reduce_zero_label (bool): Whether to ignore zero label.
            Default: False.
        beta (int): Determines the weight of recall in the combined score.
            Default: 1.

    Returns:
        dict[str, ndarray]:
            <aAcc> 0-d ndarray: Overall accuracy on all images.
            <Fscore> ndarray: Per category f-score, shape (num_classes, ).
            <Precision> ndarray: Per category precision, shape (num_classes, ).
            <Recall> ndarray: Per category recall, shape (num_classes, ).
    """
    fscore_result = eval_metrics(
        results=results,
        gt_seg_maps=gt_seg_maps,
        num_classes=num_classes,
        ignore_index=ignore_index,
        metrics=['mFscore'],
        nan_to_num=nan_to_num,
        label_map=label_map,
        reduce_zero_label=reduce_zero_label,
        beta=beta)
    return fscore_result
def eval_metrics(results,
                 gt_seg_maps,
                 num_classes,
                 ignore_index,
                 metrics=None,
                 nan_to_num=None,
                 label_map=None,
                 reduce_zero_label=False,
                 beta=1):
    """Calculate evaluation metrics.

    Args:
        results (list[ndarray] | list[str]): List of prediction segmentation
            maps or list of prediction result filenames.
        gt_seg_maps (list[ndarray] | list[str]): List of ground truth
            segmentation maps or list of label filenames.
        num_classes (int): Number of categories.
        ignore_index (int): Index that will be ignored in evaluation.
        metrics (list[str] | str, optional): Metrics to be evaluated, any of
            'mIoU', 'mDice' and 'mFscore'. Default: None, meaning ['mIoU'].
        nan_to_num (int, optional): If specified, NaN values will be replaced
            by the numbers defined by the user. Default: None.
        label_map (dict, optional): Mapping old labels to new labels.
            Default: None (no remapping).
        reduce_zero_label (bool): Whether to ignore zero label.
            Default: False.
        beta (int): Weight of recall in the F-score ('mFscore' only).
            Default: 1.

    Returns:
        OrderedDict[str, ndarray]: ``'aAcc'`` (0-d, overall accuracy on all
        images) followed by the per-class entries of each requested metric,
        each of shape (num_classes, ):
            'mIoU'    -> 'IoU', 'Acc'
            'mDice'   -> 'Dice', 'Acc'
            'mFscore' -> 'Fscore', 'Precision', 'Recall'
        Classes with a zero denominator give NaN unless ``nan_to_num`` is set.

    Raises:
        KeyError: If an unsupported metric is requested.
    """
    if metrics is None:
        metrics = ['mIoU']
    elif isinstance(metrics, str):
        metrics = [metrics]
    allowed_metrics = ['mIoU', 'mDice', 'mFscore']
    if not set(metrics).issubset(set(allowed_metrics)):
        raise KeyError('metrics {} is not supported'.format(metrics))

    total_area_intersect, total_area_union, total_area_pred_label, \
        total_area_label = total_intersect_and_union(
            results, gt_seg_maps, num_classes, ignore_index, label_map,
            reduce_zero_label)
    all_acc = total_area_intersect.sum() / total_area_label.sum()
    ret_metrics = OrderedDict({'aAcc': all_acc})
    for metric in metrics:
        if metric == 'mIoU':
            iou = total_area_intersect / total_area_union
            acc = total_area_intersect / total_area_label
            ret_metrics['IoU'] = iou
            ret_metrics['Acc'] = acc
        elif metric == 'mDice':
            dice = 2 * total_area_intersect / (
                total_area_pred_label + total_area_label)
            acc = total_area_intersect / total_area_label
            ret_metrics['Dice'] = dice
            ret_metrics['Acc'] = acc
        elif metric == 'mFscore':
            precision = total_area_intersect / total_area_pred_label
            recall = total_area_intersect / total_area_label
            # f_score is element-wise, so it applies to the whole per-class
            # tensors at once (float64, same values as a per-element loop).
            f_value = f_score(precision, recall, beta)
            ret_metrics['Fscore'] = f_value
            ret_metrics['Precision'] = precision
            ret_metrics['Recall'] = recall

    # Keep an OrderedDict on both paths so the return type is consistent.
    ret_metrics = OrderedDict(
        (metric, value.numpy()) for metric, value in ret_metrics.items())
    if nan_to_num is not None:
        ret_metrics = OrderedDict(
            (metric, np.nan_to_num(metric_value, nan=nan_to_num))
            for metric, metric_value in ret_metrics.items())
    return ret_metrics