Spaces:
Running
on
Zero
Running
on
Zero
| from typing import Union, List, Dict | |
| import torch | |
| from tracker.inference.object_info import ObjectInfo | |
class ObjectManager:
    """
    Keep the mapping between immutable object IDs and temporary IDs.

    Object IDs are immutable. The same ID always represents the same object.
    Temporary IDs are the positions of each object in the tensor. They change as objects get removed.
    Temporary IDs start from 1 and are kept contiguous (1..num_obj).
    """

    def __init__(self):
        # ObjectInfo -> temporary ID (1-based, contiguous)
        self.obj_to_tmp_id: Dict[ObjectInfo, int] = {}
        # temporary ID -> ObjectInfo; insertion order follows increasing temporary ID
        self.tmp_id_to_obj: Dict[int, ObjectInfo] = {}
        # object ID -> ObjectInfo; rebuilt from obj_to_tmp_id after every change
        self.obj_id_to_obj: Dict[int, ObjectInfo] = {}
        # every object ID ever added, in order of first addition; never pruned here
        self.all_historical_object_ids: List[int] = []

    def _recompute_obj_id_to_obj_mapping(self) -> None:
        # Rebuild the object-ID lookup so it mirrors obj_to_tmp_id.
        self.obj_id_to_obj = {obj.id: obj for obj in self.obj_to_tmp_id}

    def add_new_objects(
            self, objects: Union[List[ObjectInfo], ObjectInfo,
                                 List[int]]) -> (List[int], List[int]):
        """
        Register objects, assigning the next free temporary IDs to unseen ones.

        Args:
            objects: a single ObjectInfo / object ID, or a list of them. Integer
                IDs are wrapped into ObjectInfo.

        Returns:
            (temporary IDs, object IDs), one entry per input in input order.
            Objects that are already tracked keep their existing temporary ID.
        """
        if not isinstance(objects, list):
            objects = [objects]

        corresponding_tmp_ids = []
        corresponding_obj_ids = []
        for obj in objects:
            if isinstance(obj, int):
                obj = ObjectInfo(id=obj)

            if obj in self.obj_to_tmp_id:
                # already tracked: reuse its temporary ID
                corresponding_tmp_ids.append(self.obj_to_tmp_id[obj])
                corresponding_obj_ids.append(obj.id)
            else:
                # Build a fresh record from the integer ID, the same way the int
                # branch above does. Passing `obj` itself as `id` would nest an
                # ObjectInfo inside another, so `.id` would no longer be an int
                # (it is later stored as an ID and written into tensors).
                new_obj = ObjectInfo(id=obj.id)
                new_tmp_id = len(self.obj_to_tmp_id) + 1
                self.obj_to_tmp_id[new_obj] = new_tmp_id
                self.tmp_id_to_obj[new_tmp_id] = new_obj
                self.all_historical_object_ids.append(new_obj.id)
                corresponding_tmp_ids.append(new_tmp_id)
                corresponding_obj_ids.append(new_obj.id)

        self._recompute_obj_id_to_obj_mapping()
        # sanity check: the temporary IDs handed back are in increasing order
        assert corresponding_tmp_ids == sorted(corresponding_tmp_ids)
        return corresponding_tmp_ids, corresponding_obj_ids

    def delete_object(self, obj_ids_to_remove: Union[int, List[int]]) -> None:
        """
        Remove one object ID or a list of object IDs and compact the temporary IDs.

        Surviving objects keep their relative order and are renumbered 1..n.
        IDs that are not currently tracked are ignored.
        """
        if isinstance(obj_ids_to_remove, int):
            obj_ids_to_remove = [obj_ids_to_remove]

        new_obj_to_tmp_id = {}
        new_tmp_id_to_obj = {}
        next_tmp_id = 1
        # walk the old temporary IDs in order so the survivors keep their order
        for old_tmp_id in range(1, len(self.obj_to_tmp_id) + 1):
            obj = self.tmp_id_to_obj[old_tmp_id]
            if obj.id not in obj_ids_to_remove:
                new_obj_to_tmp_id[obj] = next_tmp_id
                new_tmp_id_to_obj[next_tmp_id] = obj
                next_tmp_id += 1

        self.obj_to_tmp_id = new_obj_to_tmp_id
        self.tmp_id_to_obj = new_tmp_id_to_obj
        self._recompute_obj_id_to_obj_mapping()

    def purge_inactive_objects(self,
                               max_missed_detection_count: int) -> (bool, List[int], List[int]):
        """
        Delete every object whose poke_count exceeds max_missed_detection_count.

        Returns:
            (purge_activated, tmp_id_to_keep, obj_id_to_keep): whether anything
            was deleted, the temporary IDs of the survivors as they were *before*
            the deletion, and the object IDs of the survivors.
        """
        obj_id_to_be_deleted = []
        tmp_id_to_keep = []
        obj_id_to_keep = []
        for obj, tmp_id in self.obj_to_tmp_id.items():
            if obj.poke_count > max_missed_detection_count:
                obj_id_to_be_deleted.append(obj.id)
            else:
                tmp_id_to_keep.append(tmp_id)
                obj_id_to_keep.append(obj.id)

        purge_activated = len(obj_id_to_be_deleted) > 0
        if purge_activated:
            self.delete_object(obj_id_to_be_deleted)
        return purge_activated, tmp_id_to_keep, obj_id_to_keep

    def tmp_to_obj_cls(self, mask) -> torch.Tensor:
        """
        Relabel a mask from temporary IDs to object IDs.

        Entries that match no current temporary ID become 0. The output has the
        same shape and dtype as `mask` (it starts as torch.zeros_like(mask)).
        """
        new_mask = torch.zeros_like(mask)
        for tmp_id, obj in self.tmp_id_to_obj.items():
            # compare against the untouched input so relabelled values never collide
            new_mask[mask == tmp_id] = obj.id
        return new_mask

    def get_tmp_to_obj_mapping(self) -> Dict[int, int]:
        """
        Return {temporary ID: object ID} using plain ints (e.g. for pickling).

        tmp_id_to_obj.items() yields (temporary ID, ObjectInfo) pairs, so the
        ObjectInfo is the second element of each pair.
        """
        return {tmp_id: obj.id for tmp_id, obj in self.tmp_id_to_obj.items()}

    def realize_dict(self, obj_dict, dim=1) -> torch.Tensor:
        """
        Stack the entries of a dict keyed by object ID, ordered by temporary ID.

        Args:
            obj_dict: mapping from object ID to a tensor; every tracked object
                must be present.
            dim: dimension along which the entries are stacked (torch.stack).

        Raises:
            NotImplementedError: if a tracked object ID is missing from obj_dict.
        """
        output = []
        for obj in self.tmp_id_to_obj.values():
            if obj.id not in obj_dict:
                raise NotImplementedError(f'object ID {obj.id} is missing from obj_dict')
            output.append(obj_dict[obj.id])
        return torch.stack(output, dim=dim)

    def make_one_hot(self, cls_mask) -> torch.Tensor:
        """
        Split an object-ID label mask into one boolean mask per tracked object.

        Masks are ordered by temporary ID and stacked on a new leading dimension,
        giving shape (num_obj, *cls_mask.shape). With no objects tracked, an empty
        bool tensor of that shape is returned on cls_mask's device.
        """
        output = [cls_mask == obj.id for obj in self.tmp_id_to_obj.values()]
        if not output:
            return torch.zeros((0, *cls_mask.shape), dtype=torch.bool, device=cls_mask.device)
        return torch.stack(output, dim=0)

    def all_obj_ids(self) -> List[int]:
        # object IDs of all tracked objects, in temporary-ID order
        return [k.id for k in self.obj_to_tmp_id]

    def num_obj(self) -> int:
        # number of currently tracked objects (== largest temporary ID)
        return len(self.obj_to_tmp_id)

    def has_all(self, objects: List[int]) -> bool:
        """Return True if every object ID in `objects` is currently tracked."""
        # obj_id_to_obj is keyed by the int IDs directly and is kept in sync
        # with obj_to_tmp_id after every add/delete.
        return all(obj_id in self.obj_id_to_obj for obj_id in objects)

    def find_object_by_id(self, obj_id) -> ObjectInfo:
        # raises KeyError if obj_id is not tracked
        return self.obj_id_to_obj[obj_id]

    def find_tmp_by_id(self, obj_id) -> int:
        # raises KeyError if obj_id is not tracked
        return self.obj_to_tmp_id[self.obj_id_to_obj[obj_id]]