sczhou's picture
init code
320e465
raw
history blame
5.54 kB
from typing import Union, List, Dict
import torch
from tracker.inference.object_info import ObjectInfo
class ObjectManager:
"""
Object IDs are immutable. The same ID always represent the same object.
Temporary IDs are the positions of each object in the tensor. It changes as objects get removed.
Temporary IDs start from 1.
"""
def __init__(self):
self.obj_to_tmp_id: Dict[ObjectInfo, int] = {}
self.tmp_id_to_obj: Dict[int, ObjectInfo] = {}
self.obj_id_to_obj: Dict[int, ObjectInfo] = {}
self.all_historical_object_ids: List[int] = []
def _recompute_obj_id_to_obj_mapping(self) -> None:
self.obj_id_to_obj = {obj.id: obj for obj in self.obj_to_tmp_id}
def add_new_objects(
self, objects: Union[List[ObjectInfo], ObjectInfo,
List[int]]) -> (List[int], List[int]):
if not isinstance(objects, list):
objects = [objects]
corresponding_tmp_ids = []
corresponding_obj_ids = []
for obj in objects:
if isinstance(obj, int):
obj = ObjectInfo(id=obj)
if obj in self.obj_to_tmp_id:
# old object
corresponding_tmp_ids.append(self.obj_to_tmp_id[obj])
corresponding_obj_ids.append(obj.id)
else:
# new object
new_obj = ObjectInfo(id=obj)
# new object
new_tmp_id = len(self.obj_to_tmp_id) + 1
self.obj_to_tmp_id[new_obj] = new_tmp_id
self.tmp_id_to_obj[new_tmp_id] = new_obj
self.all_historical_object_ids.append(new_obj.id)
corresponding_tmp_ids.append(new_tmp_id)
corresponding_obj_ids.append(new_obj.id)
self._recompute_obj_id_to_obj_mapping()
assert corresponding_tmp_ids == sorted(corresponding_tmp_ids)
return corresponding_tmp_ids, corresponding_obj_ids
def delete_object(self, obj_ids_to_remove: Union[int, List[int]]) -> None:
# delete an object or a list of objects
# re-sort the tmp ids
if isinstance(obj_ids_to_remove, int):
obj_ids_to_remove = [obj_ids_to_remove]
new_tmp_id = 1
total_num_id = len(self.obj_to_tmp_id)
local_obj_to_tmp_id = {}
local_tmp_to_obj_id = {}
for tmp_iter in range(1, total_num_id + 1):
obj = self.tmp_id_to_obj[tmp_iter]
if obj.id not in obj_ids_to_remove:
local_obj_to_tmp_id[obj] = new_tmp_id
local_tmp_to_obj_id[new_tmp_id] = obj
new_tmp_id += 1
self.obj_to_tmp_id = local_obj_to_tmp_id
self.tmp_id_to_obj = local_tmp_to_obj_id
self._recompute_obj_id_to_obj_mapping()
def purge_inactive_objects(self,
max_missed_detection_count: int) -> (bool, List[int], List[int]):
# remove tmp ids of objects that are removed
obj_id_to_be_deleted = []
tmp_id_to_be_deleted = []
tmp_id_to_keep = []
obj_id_to_keep = []
for obj in self.obj_to_tmp_id:
if obj.poke_count > max_missed_detection_count:
obj_id_to_be_deleted.append(obj.id)
tmp_id_to_be_deleted.append(self.obj_to_tmp_id[obj])
else:
tmp_id_to_keep.append(self.obj_to_tmp_id[obj])
obj_id_to_keep.append(obj.id)
purge_activated = len(obj_id_to_be_deleted) > 0
if purge_activated:
self.delete_object(obj_id_to_be_deleted)
return purge_activated, tmp_id_to_keep, obj_id_to_keep
def tmp_to_obj_cls(self, mask) -> torch.Tensor:
# remap tmp id cls representation to the true object id representation
new_mask = torch.zeros_like(mask)
for tmp_id, obj in self.tmp_id_to_obj.items():
new_mask[mask == tmp_id] = obj.id
return new_mask
def get_tmp_to_obj_mapping(self) -> Dict[int, ObjectInfo]:
# returns the mapping in a dict format for saving it with pickle
return {obj.id: tmp_id for obj, tmp_id in self.tmp_id_to_obj.items()}
def realize_dict(self, obj_dict, dim=1) -> torch.Tensor:
# turns a dict indexed by obj id into a tensor, ordered by tmp IDs
output = []
for _, obj in self.tmp_id_to_obj.items():
if obj.id not in obj_dict:
raise NotImplementedError
output.append(obj_dict[obj.id])
output = torch.stack(output, dim=dim)
return output
def make_one_hot(self, cls_mask) -> torch.Tensor:
output = []
for _, obj in self.tmp_id_to_obj.items():
output.append(cls_mask == obj.id)
if len(output) == 0:
output = torch.zeros((0, *cls_mask.shape), dtype=torch.bool, device=cls_mask.device)
else:
output = torch.stack(output, dim=0)
return output
@property
def all_obj_ids(self) -> List[int]:
return [k.id for k in self.obj_to_tmp_id]
@property
def num_obj(self) -> int:
return len(self.obj_to_tmp_id)
def has_all(self, objects: List[int]) -> bool:
for obj in objects:
if obj not in self.obj_to_tmp_id:
return False
return True
def find_object_by_id(self, obj_id) -> ObjectInfo:
return self.obj_id_to_obj[obj_id]
def find_tmp_by_id(self, obj_id) -> int:
return self.obj_to_tmp_id[self.obj_id_to_obj[obj_id]]