import logging
from collections import OrderedDict
from pathlib import Path
from typing import Dict, List, Sequence, Union

import torch
import torchvision


def check_is_valid_torchvision_architecture(architecture: str) -> None:
    """Raise a ValueError if ``architecture`` is not an available torchvision model.

    A name counts as available when it is a lowercase, non-dunder, callable
    attribute of ``torchvision.models``.

    Args:
        architecture: Candidate model name, e.g. ``"resnet50"``.

    Raises:
        ValueError: If ``architecture`` is not among the available names. The
            message lists every available name.
    """
    available = sorted(
        name
        for name, obj in torchvision.models.__dict__.items()
        if name.islower() and not name.startswith("__") and callable(obj)
    )
    if architecture not in available:
        raise ValueError(f"{architecture} not in {available}")


def build_base_model(arch: str):
    """Build a pretrained torchvision backbone without its classification head.

    Args:
        arch: Name of a model constructor in ``torchvision.models``. Supported
            are ``mobilenet_v2``, ``densenet121/161/169`` and every name that
            contains ``"resne"`` (ResNet / ResNeXt variants).

    Returns:
        A tuple ``(model, nfeatures)`` where ``model`` is a
        ``torch.nn.Sequential`` ending in global average pooling and a flatten
        layer, and ``nfeatures`` is the ``in_features`` of the classification
        layer that was removed.

    Raises:
        KeyError: If ``arch`` is not an attribute of ``torchvision.models``.
        NotImplementedError: If the architecture exists but no head-removal
            rule is defined for it here.
    """
    # NOTE(review): newer torchvision releases deprecate ``pretrained=`` in
    # favour of ``weights=`` -- confirm against the pinned torchvision version.
    model = torchvision.models.__dict__[arch](pretrained=True)

    # get input dimension before classification layer
    if arch in ["mobilenet_v2"]:
        nfeatures = model.classifier[-1].in_features
        model = torch.nn.Sequential(*list(model.children())[:-1])
    elif arch in ["densenet121", "densenet161", "densenet169"]:
        # NOTE(review): torchvision's DenseNet.forward appears to apply a ReLU
        # between ``features`` and pooling which is not reproduced here --
        # confirm this is intended.
        nfeatures = model.classifier.in_features
        model = torch.nn.Sequential(*list(model.children())[:-1])
    elif "resne" in arch:
        # usually all ResNet variants; drops the original avgpool and fc
        nfeatures = model.fc.in_features
        model = torch.nn.Sequential(*list(model.children())[:-2])
    else:
        raise NotImplementedError(
            f"No feature extractor defined for architecture: {arch}"
        )

    # Assigning modules as attributes registers them after the existing
    # children, so the Sequential runs them last: pool to 1x1, then flatten.
    model.avgpool = torch.nn.AdaptiveAvgPool2d(1)
    model.flatten = torch.nn.Flatten(start_dim=1)
    return model, nfeatures


def _strip_prefix(key: str, prefix: str) -> str:
    """Return ``key`` without a leading ``prefix``; unchanged if it is absent."""
    return key[len(prefix):] if key.startswith(prefix) else key


def load_weights_if_available(
    model: torch.nn.Module, classifier: torch.nn.Module, weights_path: Union[str, Path]
):
    """Load backbone weights from a checkpoint into ``model``.

    The checkpoint must contain a ``"state_dict"`` mapping. Entries whose key
    starts with ``"model"`` are loaded into ``model`` after the leading
    ``"model."`` prefix is removed. Entries starting with ``"classifier"`` are
    recognised but skipped; any other key triggers a warning.

    Args:
        model: Module receiving the backbone weights (loaded with
            ``strict=True``).
        classifier: Returned unchanged.
        weights_path: Location of the checkpoint file.

    Returns:
        The tuple ``(model, classifier)``.

    Raises:
        RuntimeError: Propagated from ``load_state_dict`` when the keys do not
            match ``model`` exactly.
    """
    # SECURITY: torch.load unpickles the file, which can execute arbitrary
    # code. Only use this with checkpoints from a trusted source.
    # The map_location callable keeps every tensor on CPU storage.
    checkpoint = torch.load(weights_path, map_location=lambda storage, loc: storage)

    state_dict_features = OrderedDict()
    for key, weight in checkpoint["state_dict"].items():
        if key.startswith("model"):
            # Strip only the leading prefix: str.replace would also delete
            # "model." inside nested names such as "model.submodel.weight".
            state_dict_features[_strip_prefix(key, "model.")] = weight
        elif key.startswith("classifier"):
            # NOTE(review): classifier weights have never been loaded by this
            # function; confirm that restoring only the backbone is intended.
            continue
        else:
            logging.warning("Unexpected prefix in state_dict: %s", key)
    model.load_state_dict(state_dict_features, strict=True)
    return model, classifier


def vectorized_gc_distance(latitudes, longitudes, latitudes_gt, longitudes_gt):
    """Compute element-wise great-circle distances with the haversine formula.

    Args:
        latitudes: Tensor of predicted latitudes in degrees.
        longitudes: Tensor of predicted longitudes in degrees.
        latitudes_gt: Tensor of reference latitudes in degrees.
        longitudes_gt: Tensor of reference longitudes in degrees.

    Returns:
        Tensor of distances on a sphere of radius 6371 (kilometres for the
        Earth), one per input coordinate pair.
    """
    R = 6371  # sphere radius; the result is in the same unit
    factor_rad = 0.01745329252  # degrees -> radians (pi / 180)

    longitudes = factor_rad * longitudes
    longitudes_gt = factor_rad * longitudes_gt
    latitudes = factor_rad * latitudes
    latitudes_gt = factor_rad * latitudes_gt

    delta_long = longitudes_gt - longitudes
    delta_lat = latitudes_gt - latitudes

    a = (
        torch.sin(delta_lat / 2) ** 2
        + torch.cos(latitudes) * torch.cos(latitudes_gt) * torch.sin(delta_long / 2) ** 2
    )
    # Rounding can push ``a`` marginally outside [0, 1] (e.g. for antipodal
    # points), which would make sqrt/asin return NaN.
    a = torch.clamp(a, min=0.0, max=1.0)
    c = 2 * torch.asin(torch.sqrt(a))
    return R * c


def gcd_threshold_eval(gc_dists, thresholds=(1, 25, 200, 750, 2500)):
    """Compute the fraction of distances at or below each threshold.

    Args:
        gc_dists: Tensor of great-circle distances.
        thresholds: Distance thresholds, in the same unit as ``gc_dists``.

    Returns:
        Dict mapping each threshold to the fraction (Python float) of entries
        in ``gc_dists`` that are ``<=`` that threshold.
    """
    # calculate accuracy for given gcd thresholds
    n_samples = len(gc_dists)
    return {
        thres: torch.true_divide(torch.sum(gc_dists <= thres), n_samples).item()
        for thres in thresholds
    }


def accuracy(output, target, partitioning_shortnames: list, topk=(1, 5, 10)):
    """Compute top-k accuracies for every partitioning.

    Args:
        output: Indexable of score tensors, one per partitioning, each of shape
            ``(batch, num_classes)``.
        target: Indexable of label tensors, one per partitioning, each with
            ``batch`` entries.
        partitioning_shortnames: Names used in the returned keys; entry ``i``
            belongs to ``output[i]`` / ``target[i]``.
        topk: Values of k to evaluate.

    Returns:
        Dict mapping ``"acc{k}_val/{pname}"`` to a one-element tensor holding
        the top-k accuracy.
    """

    def _accuracy(output, target, topk=(1,)):
        """Computes the accuracy over the k top predictions for the specified values of k"""
        with torch.no_grad():
            maxk = max(topk)
            batch_size = target.size(0)

            _, pred = output.topk(maxk, 1, True, True)
            pred = pred.t()
            correct = pred.eq(target.view(1, -1).expand_as(pred))

            res = {}
            for k in topk:
                # ``correct`` inherits the transposed (non-contiguous) layout
                # of ``pred``, so ``view`` can fail here; ``reshape`` copies
                # when necessary.
                correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
                res[k] = correct_k / batch_size
            return res

    with torch.no_grad():
        out_dict = {}
        for i, pname in enumerate(partitioning_shortnames):
            res_dict = _accuracy(output[i], target[i], topk=topk)
            for k, v in res_dict.items():
                out_dict[f"acc{k}_val/{pname}"] = v

        return out_dict


def summarize_gcd_stats(pnames: List[str], outputs, hierarchy=None):
    """Aggregate validation distances into accuracy-at-threshold metrics.

    Args:
        pnames: Partitioning names; distances are read from the key
            ``"gcd_{p}_val"`` of every output.
        outputs: Iterable of dicts mapping those keys to distance tensors.
        hierarchy: If not ``None``, ``"gcd_hierarchy_val"`` is evaluated too.

    Returns:
        Dict mapping ``"{metric_name}/{threshold}"`` to the fraction of
        distances at or below that threshold.
    """
    metric_names = [f"gcd_{p}_val" for p in pnames]
    if hierarchy is not None:
        metric_names.append("gcd_hierarchy_val")

    gcd_dict: Dict[str, float] = {}
    for metric_name in metric_names:
        # concat batches of distances
        distances_flat = torch.cat([output[metric_name] for output in outputs], dim=0)
        for gcd_thres, acc in gcd_threshold_eval(distances_flat).items():
            gcd_dict[f"{metric_name}/{gcd_thres}"] = acc
    return gcd_dict


def summarize_test_gcd(pnames, outputs, hierarchy=None):
    """Aggregate test distances into accuracy-at-threshold metrics.

    Args:
        pnames: Partitioning names used as keys into each batch output. The
            list is not modified.
        outputs: Either a list of per-batch dicts (one test set) or a list of
            such lists (several test sets). Each dict maps a name to a tensor
            of distances.
        hierarchy: If not ``None``, the key ``"hierarchy"`` is evaluated too.

    Returns:
        For one test set, a dict mapping ``"acc_test/{pname}"`` to the
        threshold dict from ``gcd_threshold_eval``. For several test sets, a
        dict mapping the test-set index to such a dict.

    Raises:
        TypeError: If ``outputs[0]`` is neither a dict nor a list.
    """
    # Work on a copy: appending to ``pnames`` itself would change the caller's
    # list, and would do so once per evaluated test set.
    names = list(pnames)
    if hierarchy is not None:
        names.append("hierarchy")

    def _eval(output):
        # calculate acc@km for a list of given thresholds
        accuracy_outputs = {}
        for pname in names:
            # concat batches of distances
            distances_flat = torch.cat([x[pname] for x in output], dim=0)
            # acc for all distances
            accuracy_outputs[f"acc_test/{pname}"] = gcd_threshold_eval(distances_flat)
        return accuracy_outputs

    if isinstance(outputs[0], dict):  # only one testset
        return _eval(outputs)
    if isinstance(outputs[0], list):  # multiple testsets
        return {
            testset_index: _eval(output)
            for testset_index, output in enumerate(outputs)
        }
    raise TypeError(
        f"Expected outputs[0] to be a dict or a list, got {type(outputs[0]).__name__}"
    )


def summarize_loss_acc_stats(pnames: List[str], outputs, topk: Sequence[int] = (1, 5, 10)):
    """Average loss and accuracy metrics over all validation outputs.

    Args:
        pnames: Partitioning names.
        outputs: Sequence of dicts containing ``"loss_val/total"``,
            ``"loss_val/{p}"`` and ``"acc{k}_val/{p}"`` for every ``p`` in
            ``pnames`` and ``k`` in ``topk``.
        topk: Values of k whose accuracies are averaged.

    Returns:
        Dict mapping each metric name to its mean over ``outputs``.

    Raises:
        ZeroDivisionError: If ``outputs`` is empty.
    """
    metric_names = ["loss_val/total"]
    for k in topk:
        metric_names.extend(f"acc{k}_val/{p}" for p in pnames)
    metric_names.extend(f"loss_val/{p}" for p in pnames)

    n_outputs = len(outputs)
    return {
        metric_name: sum(output[metric_name] for output in outputs) / n_outputs
        for metric_name in metric_names
    }