hvaldez's picture
first commit
c18a21e verified
raw
history blame
1.05 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
import torch
def accuracy(output, target, topk=(1,)):
    """Return top-k accuracy percentages for every k in ``topk``.

    Args:
        output: Score tensor; the top-k ranking is taken along dim 1, so it
            is presumably shaped (batch, num_classes) -- confirm with callers.
        target: Ground-truth class indices; ``target.size(0)`` is used as
            the batch size.
        topk: Iterable of k values to evaluate (default: top-1 only).

    Returns:
        A list holding one single-element float tensor per k, in the order
        of ``topk``. Each tensor is the percentage (0-100) of samples whose
        target is among the k highest-scoring entries.
    """
    with torch.no_grad():
        largest_k = max(topk)
        num_samples = target.size(0)
        _, ranked = output.topk(largest_k, dim=1, largest=True, sorted=True)
        # hits[b, j] is True when the j-th ranked prediction of sample b
        # equals that sample's target (target broadcast as a column).
        hits = ranked.eq(target.reshape(-1, 1))
        scale = 100.0 / num_samples
        return [
            hits[:, :k].reshape(-1).float().sum(0, keepdim=True).mul_(scale)
            for k in topk
        ]
def get_mean_accuracy(cm):
    """Compute mean per-class accuracy and overall accuracy from a confusion matrix.

    Per-class accuracy for class ``i`` is ``cm[i, i] / cm[i, :].sum()``, the
    diagonal entry divided by its row total. A class whose row is all zeros
    contributes 0 to the mean; it is not skipped.
    NOTE(review): this assumes rows index the true class (the sklearn
    ``confusion_matrix`` convention), which makes the per-class value a
    recall. Confirm with callers.

    Args:
        cm: Square 2-D array-like of counts (NumPy array, nested lists, ...).

    Returns:
        Tuple ``(mean_per_class_acc, overall_acc)``, both as percentages
        (0-100). ``overall_acc`` is ``100 * trace(cm) / sum(cm)``. It is nan
        (with a NumPy RuntimeWarning) when the matrix sums to zero, and the
        mean is nan for a 0x0 matrix.

    Raises:
        ValueError: If ``cm`` is not a square 2-D matrix.
    """
    cm = np.asarray(cm, dtype=np.float64)
    if cm.ndim != 2 or cm.shape[0] != cm.shape[1]:
        raise ValueError(f"cm must be a square 2-D matrix, got shape {cm.shape}")

    support = cm.sum(axis=1)
    # Divide only where a row has samples; empty rows keep the 0 from `out`,
    # matching the original per-class loop's convention.
    per_class = np.divide(
        np.diag(cm), support, out=np.zeros_like(support), where=support > 0
    )
    return 100 * per_class.mean(), 100 * np.trace(cm) / np.sum(cm)