"""
Integrate numerical values for some iterations
Typically used for loss computation / logging to tensorboard
Call finalize and create a new Integrator when you want to display/log
"""

from typing import Callable, Dict, List, Tuple, Union

import torch

from tracker.utils.logger import TensorboardLogger

# A hook receives the dict of accumulated (summed) values and returns a new
# (key, value) pair to be added before averaging.
_Hook = Callable[[Dict[str, float]], Tuple[str, Union[float, int, torch.Tensor]]]


class Integrator:
    """Accumulate scalar metrics over iterations and log their averages.

    ``add_tensor`` keeps a running sum and a count per key. ``finalize`` runs
    the registered hooks, divides every sum by its count, optionally averages
    the result across all ranks of the default process group, and logs it
    (from rank 0 only in distributed mode). Keys starting with ``'hide'`` are
    accumulated and visible to hooks but are never logged.
    """

    def __init__(self, logger: TensorboardLogger, distributed: bool = True):
        self.values: Dict[str, float] = {}  # key -> running sum of observations
        self.counts: Dict[str, int] = {}  # key -> number of observations
        self.hooks: List[_Hook] = []  # List is used here to maintain insertion order

        self.logger = logger
        self.distributed = distributed

        # Querying the process group raises when it has not been initialized,
        # so single-process use (distributed=False) falls back to rank 0 and
        # world size 1. With distributed=True we still query (and fail early)
        # because finalize needs a real process group.
        if distributed or (torch.distributed.is_available() and torch.distributed.is_initialized()):
            # NOTE: get_rank() returns the *global* rank despite the attribute name.
            self.local_rank = torch.distributed.get_rank()
            self.world_size = torch.distributed.get_world_size()
        else:
            self.local_rank = 0
            self.world_size = 1

    def add_tensor(self, key: str, tensor: Union[torch.Tensor, float, int]):
        """Add one observation for ``key`` to its running sum.

        Python numbers are added as-is; anything else (e.g. a tensor) is
        reduced with ``.mean().item()`` first, so stored values are scalars.
        """
        if isinstance(tensor, (int, float)):
            value = tensor
        else:
            value = tensor.mean().item()

        if key in self.values:
            self.values[key] += value
            self.counts[key] += 1
        else:
            self.values[key] = value
            self.counts[key] = 1

    def add_dict(self, tensor_dict: Dict[str, torch.Tensor]):
        """Call ``add_tensor`` for every (key, value) pair of ``tensor_dict``."""
        for k, v in tensor_dict.items():
            self.add_tensor(k, v)

    def add_hook(self, hook: Union[_Hook, List[_Hook]]):
        """Register a custom hook (or a list of hooks) computing a derived metric.

        At ``finalize`` time each hook is called, in registration order, with
        the dict of accumulated values. These are running *sums*, not
        averages. The hook returns a ``(k, v)`` tuple that is added as a new
        observation, e.g. for computing IoU from summed intersections/unions.
        """
        if isinstance(hook, list):
            self.hooks.extend(hook)
        else:
            self.hooks.append(hook)

    def reset_except_hooks(self):
        """Clear accumulated values and counts while keeping registered hooks."""
        self.values = {}
        self.counts = {}

    # Average and output the metrics
    def finalize(self, exp_id: str, prefix: str, it: int) -> None:
        """Average the accumulated metrics and log them.

        Hooks are applied first and their results added via ``add_tensor``.
        Calling finalize twice without a reset therefore adds hook outputs
        again. Every key not starting with ``'hide'`` is then averaged over
        its count. In distributed mode the per-rank averages are further
        averaged across ranks and only rank 0 logs.

        Args:
            exp_id: experiment identifier forwarded to the logger.
            prefix: metric-name prefix forwarded to the logger.
            it: iteration/step forwarded to the logger.
        """
        for hook in self.hooks:
            k, v = hook(self.values)
            self.add_tensor(k, v)

        # Per-rank averages of every loggable key, in insertion order.
        averages = {
            k: v / self.counts[k]
            for k, v in self.values.items()
            if not k.startswith('hide')
        }

        if not self.distributed:
            # Simple does it
            self.logger.log_metrics(exp_id, prefix, averages, it)
            return

        # NOTE(review): hook metrics are computed from this rank's sums, so the
        # logged value is the mean of per-rank ratios, not a ratio of global sums.
        outputs = self._reduce_mean_across_ranks(averages)
        if self.local_rank == 0:
            self.logger.log_metrics(exp_id, prefix, outputs, it)

    def _reduce_mean_across_ranks(self, averages: Dict[str, float]) -> Dict[str, float]:
        """Average ``averages`` over all ranks; returns the result on rank 0, ``{}`` elsewhere."""
        if not averages:
            return {}

        # All ranks must reduce the same keys in the same order. Sorting guards
        # against ranks that inserted keys in a different order, which would
        # otherwise silently sum unrelated metrics together.
        keys = sorted(averages)

        # One batched collective instead of one reduce per key.
        # NOTE(review): .cuda() assumes a CUDA-capable backend (e.g. NCCL) - confirm.
        stacked = torch.tensor([averages[k] for k in keys]).cuda()
        torch.distributed.reduce(stacked, dst=0)

        if self.local_rank != 0:
            return {}

        means = dict(zip(keys, (stacked / self.world_size).cpu().tolist()))
        # Restore insertion order so the logged metrics keep their original order.
        return {k: means[k] for k in averages}