from typing import Union, List, Dict, Tuple

import torch

from tracker.inference.object_info import ObjectInfo


class ObjectManager:
    """
    Object IDs are immutable. The same ID always represents the same object.
    Temporary IDs are the positions of each object in the tensor. They change as objects get removed.
    Temporary IDs start from 1.
    """

    def __init__(self):
        # ObjectInfo -> temporary ID (1-based, contiguous).
        self.obj_to_tmp_id: Dict[ObjectInfo, int] = {}
        # Temporary ID -> ObjectInfo; insertion order follows ascending tmp ID
        # because IDs are appended as len+1 and rebuilt in order on deletion.
        self.tmp_id_to_obj: Dict[int, ObjectInfo] = {}
        # Object ID -> ObjectInfo, rebuilt by _recompute_obj_id_to_obj_mapping.
        self.obj_id_to_obj: Dict[int, ObjectInfo] = {}
        # Every object ID ever added, in order of first addition; never pruned.
        self.all_historical_object_ids: List[int] = []

    def _recompute_obj_id_to_obj_mapping(self) -> None:
        """Rebuild ``obj_id_to_obj`` from the currently tracked objects."""
        self.obj_id_to_obj = {obj.id: obj for obj in self.obj_to_tmp_id}

    def add_new_objects(
            self, objects: Union[List[ObjectInfo], ObjectInfo, List[int]]
    ) -> Tuple[List[int], List[int]]:
        """Register objects, assigning temporary IDs to those not yet tracked.

        Args:
            objects: a single ObjectInfo, or a list of ObjectInfo and/or integer
                object IDs. Integers are wrapped in ObjectInfo first.

        Returns:
            (tmp_ids, obj_ids): the temporary ID and object ID for each input,
            in input order. Already-tracked objects keep their current tmp ID;
            new objects receive the next free tmp ID.
        """
        if not isinstance(objects, list):
            objects = [objects]

        corresponding_tmp_ids = []
        corresponding_obj_ids = []
        for obj in objects:
            if isinstance(obj, int):
                obj = ObjectInfo(id=obj)

            # NOTE(review): membership relies on ObjectInfo's __hash__/__eq__
            # (defined outside this file) identifying objects by their id.
            if obj in self.obj_to_tmp_id:
                # old object
                corresponding_tmp_ids.append(self.obj_to_tmp_id[obj])
                corresponding_obj_ids.append(obj.id)
            else:
                # New object: store a fresh ObjectInfo built from the integer id.
                # ``obj`` is already an ObjectInfo here, so wrap ``obj.id``
                # (wrapping ``obj`` itself would make ``.id`` an ObjectInfo).
                new_obj = ObjectInfo(id=obj.id)
                new_tmp_id = len(self.obj_to_tmp_id) + 1
                self.obj_to_tmp_id[new_obj] = new_tmp_id
                self.tmp_id_to_obj[new_tmp_id] = new_obj
                self.all_historical_object_ids.append(new_obj.id)
                corresponding_tmp_ids.append(new_tmp_id)
                corresponding_obj_ids.append(new_obj.id)

        self._recompute_obj_id_to_obj_mapping()
        # Sanity check: the returned tmp IDs are expected to be ascending.
        assert corresponding_tmp_ids == sorted(corresponding_tmp_ids)
        return corresponding_tmp_ids, corresponding_obj_ids

    def delete_object(self, obj_ids_to_remove: Union[int, List[int]]) -> None:
        """Remove one or more objects and renumber the remaining temporary IDs.

        Surviving objects keep their relative order and are reassigned
        contiguous temporary IDs starting from 1.

        Args:
            obj_ids_to_remove: a single object ID or a list of object IDs.
                IDs that are not currently tracked are ignored.
        """
        if isinstance(obj_ids_to_remove, int):
            obj_ids_to_remove = [obj_ids_to_remove]
        # Build the lookup once so each membership test below is O(1).
        ids_to_remove = set(obj_ids_to_remove)

        new_obj_to_tmp_id: Dict[ObjectInfo, int] = {}
        new_tmp_id_to_obj: Dict[int, ObjectInfo] = {}
        next_tmp_id = 1
        # Walk old tmp IDs in ascending order so survivors keep their order.
        for old_tmp_id in range(1, len(self.obj_to_tmp_id) + 1):
            obj = self.tmp_id_to_obj[old_tmp_id]
            if obj.id in ids_to_remove:
                continue
            new_obj_to_tmp_id[obj] = next_tmp_id
            new_tmp_id_to_obj[next_tmp_id] = obj
            next_tmp_id += 1

        self.obj_to_tmp_id = new_obj_to_tmp_id
        self.tmp_id_to_obj = new_tmp_id_to_obj
        self._recompute_obj_id_to_obj_mapping()

    def purge_inactive_objects(
            self, max_missed_detection_count: int
    ) -> Tuple[bool, List[int], List[int]]:
        """Delete objects whose ``poke_count`` exceeds ``max_missed_detection_count``.

        ``poke_count`` is an ObjectInfo attribute defined outside this file
        (presumably a count of missed detections — verify in ObjectInfo).

        Returns:
            (purge_activated, tmp_id_to_keep, obj_id_to_keep): whether any
            object was deleted, plus the tmp IDs and object IDs of the
            survivors. The returned tmp IDs are the ones in effect *before*
            the deletion renumbered them.
        """
        obj_id_to_be_deleted = []
        tmp_id_to_be_deleted = []
        tmp_id_to_keep = []
        obj_id_to_keep = []

        for obj in self.obj_to_tmp_id:
            if obj.poke_count > max_missed_detection_count:
                obj_id_to_be_deleted.append(obj.id)
                tmp_id_to_be_deleted.append(self.obj_to_tmp_id[obj])
            else:
                tmp_id_to_keep.append(self.obj_to_tmp_id[obj])
                obj_id_to_keep.append(obj.id)

        purge_activated = len(obj_id_to_be_deleted) > 0
        if purge_activated:
            self.delete_object(obj_id_to_be_deleted)
        return purge_activated, tmp_id_to_keep, obj_id_to_keep

    def tmp_to_obj_cls(self, mask) -> torch.Tensor:
        """Remap a class mask from temporary IDs to object IDs.

        Elements whose value matches no current tmp ID (e.g. 0) become 0.
        """
        new_mask = torch.zeros_like(mask)
        for tmp_id, obj in self.tmp_id_to_obj.items():
            new_mask[mask == tmp_id] = obj.id
        return new_mask

    def get_tmp_to_obj_mapping(self) -> Dict[int, ObjectInfo]:
        """Return a copy of the temporary ID -> ObjectInfo mapping.

        Intended for saving with pickle. The copy is shallow: the dict is new,
        but the ObjectInfo values are the manager's own instances.
        """
        return {tmp_id: obj for tmp_id, obj in self.tmp_id_to_obj.items()}

    def realize_dict(self, obj_dict, dim=1) -> torch.Tensor:
        """Stack the values of a dict keyed by object ID, ordered by tmp ID.

        Args:
            obj_dict: mapping from object ID to a tensor; must contain every
                currently tracked object ID.
            dim: dimension along which the tensors are stacked.

        Raises:
            NotImplementedError: if a tracked object ID is missing from
                ``obj_dict``.
        """
        output = []
        for obj in self.tmp_id_to_obj.values():
            if obj.id not in obj_dict:
                raise NotImplementedError(
                    f"object id {obj.id} is missing from obj_dict")
            output.append(obj_dict[obj.id])
        return torch.stack(output, dim=dim)

    def make_one_hot(self, cls_mask) -> torch.Tensor:
        """Split an object-ID class mask into per-object boolean masks.

        Returns a bool tensor of shape (num_obj, *cls_mask.shape), ordered by
        tmp ID. With no tracked objects, returns an empty (0, *shape) tensor
        on the same device as ``cls_mask``.
        """
        output = [cls_mask == obj.id for obj in self.tmp_id_to_obj.values()]
        if len(output) == 0:
            return torch.zeros((0, *cls_mask.shape),
                               dtype=torch.bool,
                               device=cls_mask.device)
        return torch.stack(output, dim=0)

    @property
    def all_obj_ids(self) -> List[int]:
        """Object IDs of all currently tracked objects, in tmp-ID order."""
        return [k.id for k in self.obj_to_tmp_id]

    @property
    def num_obj(self) -> int:
        """Number of currently tracked objects."""
        return len(self.obj_to_tmp_id)

    def has_all(self, objects: List[int]) -> bool:
        """Return True if every entry of ``objects`` is currently tracked.

        NOTE(review): looks up integer IDs in a dict keyed by ObjectInfo, which
        relies on ObjectInfo hashing/comparing equal to its integer id.
        """
        return all(obj in self.obj_to_tmp_id for obj in objects)

    def find_object_by_id(self, obj_id) -> ObjectInfo:
        """Return the tracked ObjectInfo for ``obj_id`` (KeyError if absent)."""
        return self.obj_id_to_obj[obj_id]

    def find_tmp_by_id(self, obj_id) -> int:
        """Return the current tmp ID for ``obj_id`` (KeyError if absent)."""
        return self.obj_to_tmp_id[self.obj_id_to_obj[obj_id]]