from typing import Dict, List, Optional, Literal, Tuple
from collections import defaultdict

import torch

def _add_last_dim(dictionary, key, new_value, prepend=False):
    # append/prepend a new value to the last dimension of a tensor in a dictionary
    # if the key does not exist, put the new value in
    # append by default; prepend is used when the new value belongs to the
    # permanent part, which is stored before the temporary part
    if key in dictionary:
        if prepend:
            dictionary[key] = torch.cat([new_value, dictionary[key]], -1)
        else:
            dictionary[key] = torch.cat([dictionary[key], new_value], -1)
    else:
        dictionary[key] = new_value

class KeyValueMemoryStore:
    """
    Works for key/value pairs type storage
    e.g., working and long-term memory
    """
    def __init__(self, save_selection: bool = False, save_usage: bool = False):
        """
        We store keys and values of objects that first appear in the same frame in a bucket.
        Each bucket contains a set of object ids.
        Each bucket is associated with a single key tensor
        and a dictionary of value tensors indexed by object id.

        The keys and values are stored as the concatenation of a permanent part and a temporary part.
        """
        self.save_selection = save_selection
        self.save_usage = save_usage

        self.global_bucket_id = 0  # does not reduce even if buckets are removed
        self.buckets: Dict[int, List[int]] = {}  # indexed by bucket id
        self.k: Dict[int, torch.Tensor] = {}  # indexed by bucket id
        self.v: Dict[int, torch.Tensor] = {}  # indexed by object id

        # indexed by bucket id; the end point of permanent memory
        self.perm_end_pt: Dict[int, int] = defaultdict(int)

        # shrinkage and selection are just like the keys
        self.s = {}
        if self.save_selection:
            self.e = {}  # does not contain the permanent memory part

        # usage
        if self.save_usage:
            self.use_cnt = {}  # indexed by bucket id, does not contain the permanent memory part
            self.life_cnt = {}  # indexed by bucket id, does not contain the permanent memory part
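
    # Illustrative note (added for clarity; the object ids below are only an example):
    # if objects 3 and 4 first appear together in the same frame, they share one bucket b, i.e.
    #   self.buckets[b]          == [3, 4]
    #   self.k[b]                is a single key tensor of shape (1/2)*C*N, with the permanent
    #                            part first and the temporary part after it, split at
    #                            self.perm_end_pt[b]
    #   self.v[3] and self.v[4]  are value tensors whose last dimension matches self.k[b]
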
    def add(self,
            key: torch.Tensor,
            values: Dict[int, torch.Tensor],
            shrinkage: torch.Tensor,
            selection: torch.Tensor,
            supposed_bucket_id: int = -1,
            as_permanent: Literal['no', 'first', 'all'] = 'no') -> None:
        """
        key: (1/2)*C*N
        values: dict of values ((1/2)*C*N), object ids are used as keys
        shrinkage: (1/2)*1*N
        selection: (1/2)*C*N

        supposed_bucket_id: used to sync the bucket id between working and long-term memory
                            if provided, the input should all be in a single bucket indexed by this id
        as_permanent: whether to store the input as permanent memory
            'no': don't
            'first': only store it as permanent memory if the bucket is empty
            'all': always store it as permanent memory
        """
        bs = key.shape[0]
        ne = key.shape[-1]
        assert len(key.shape) == 3
        assert len(shrinkage.shape) == 3
        assert not self.save_selection or len(selection.shape) == 3
        assert as_permanent in ['no', 'first', 'all']

        # add the value and create new buckets if necessary
        if supposed_bucket_id >= 0:
            enabled_buckets = [supposed_bucket_id]
            bucket_exist = supposed_bucket_id in self.buckets
            for obj, value in values.items():
                if bucket_exist:
                    assert obj in self.v
                    assert obj in self.buckets[supposed_bucket_id]
                    _add_last_dim(self.v, obj, value, prepend=(as_permanent == 'all'))
                else:
                    assert obj not in self.v
                    self.v[obj] = value
            self.buckets[supposed_bucket_id] = list(values.keys())
        else:
            new_bucket_id = None
            enabled_buckets = set()
            for obj, value in values.items():
                assert len(value.shape) == 3
                if obj in self.v:
                    _add_last_dim(self.v, obj, value, prepend=(as_permanent == 'all'))
                    bucket_used = [
                        bucket_id for bucket_id, object_ids in self.buckets.items()
                        if obj in object_ids
                    ]
                    assert len(bucket_used) == 1  # each object should only be in one bucket
                    enabled_buckets.add(bucket_used[0])
                else:
                    self.v[obj] = value
                    if new_bucket_id is None:
                        # create new bucket
                        new_bucket_id = self.global_bucket_id
                        self.global_bucket_id += 1
                        self.buckets[new_bucket_id] = []
                    # put the new object into the corresponding bucket
                    self.buckets[new_bucket_id].append(obj)
                    enabled_buckets.add(new_bucket_id)

        # increment the permanent size if necessary
        add_as_permanent = {}  # indexed by bucket id
        for bucket_id in enabled_buckets:
            add_as_permanent[bucket_id] = False
            if as_permanent == 'all':
                self.perm_end_pt[bucket_id] += ne
                add_as_permanent[bucket_id] = True
            elif as_permanent == 'first':
                if self.perm_end_pt[bucket_id] == 0:
                    self.perm_end_pt[bucket_id] = ne
                    add_as_permanent[bucket_id] = True

        # create new counters for usage if necessary
        if self.save_usage and as_permanent != 'all':
            new_count = torch.zeros((bs, ne), device=key.device, dtype=torch.float32)
            new_life = torch.zeros((bs, ne), device=key.device, dtype=torch.float32) + 1e-7

        # add the key to every bucket
        for bucket_id in self.buckets:
            if bucket_id not in enabled_buckets:
                # if we are not adding new values to a bucket, we should skip it
                continue

            _add_last_dim(self.k, bucket_id, key, prepend=add_as_permanent[bucket_id])
            _add_last_dim(self.s, bucket_id, shrinkage, prepend=add_as_permanent[bucket_id])
            if not add_as_permanent[bucket_id]:
                if self.save_selection:
                    _add_last_dim(self.e, bucket_id, selection)
                if self.save_usage:
                    _add_last_dim(self.use_cnt, bucket_id, new_count)
                    _add_last_dim(self.life_cnt, bucket_id, new_life)
    def update_bucket_usage(self, bucket_id: int, usage: torch.Tensor) -> None:
        # increase all life count by 1
        # increase use of indexed elements
        if not self.save_usage:
            return

        usage = usage[:, self.perm_end_pt[bucket_id]:]
        if usage.shape[-1] == 0:
            # if there is no temporary memory, we don't need to update
            return
        self.use_cnt[bucket_id] += usage.view_as(self.use_cnt[bucket_id])
        self.life_cnt[bucket_id] += 1
    def sieve_by_range(self, bucket_id: int, start: int, end: int, min_size: int) -> None:
        # keep only the temporary elements *outside* of this range (with some boundary conditions)
        # the permanent elements are ignored in this computation
        # i.e., concat (a[:start], a[end:])
        # buckets with size <= min_size are not modified

        assert start >= 0
        assert end <= 0

        object_ids = self.buckets[bucket_id]
        bucket_num_elements = self.k[bucket_id].shape[-1] - self.perm_end_pt[bucket_id]
        if bucket_num_elements <= min_size:
            return

        if end == 0:
            # negative 0 would not work as the end index!
            # effectively make the second part an empty slice
            end = self.k[bucket_id].shape[-1] + 1

        p_size = self.perm_end_pt[bucket_id]
        start = start + p_size

        k = self.k[bucket_id]
        s = self.s[bucket_id]
        if self.save_selection:
            e = self.e[bucket_id]
        if self.save_usage:
            use_cnt = self.use_cnt[bucket_id]
            life_cnt = self.life_cnt[bucket_id]

        self.k[bucket_id] = torch.cat([k[:, :, :start], k[:, :, end:]], -1)
        self.s[bucket_id] = torch.cat([s[:, :, :start], s[:, :, end:]], -1)
        if self.save_selection:
            self.e[bucket_id] = torch.cat([e[:, :, :start - p_size], e[:, :, end:]], -1)
        if self.save_usage:
            self.use_cnt[bucket_id] = torch.cat([use_cnt[:, :start - p_size], use_cnt[:, end:]], -1)
            self.life_cnt[bucket_id] = torch.cat([life_cnt[:, :start - p_size], life_cnt[:, end:]],
                                                 -1)
        for obj_id in object_ids:
            v = self.v[obj_id]
            self.v[obj_id] = torch.cat([v[:, :, :start], v[:, :, end:]], -1)
    def remove_old_memory(self, bucket_id: int, max_len: int) -> None:
        # drop the oldest temporary elements so that at most max_len of them remain;
        # the permanent part of the memory is left untouched
        self.sieve_by_range(bucket_id, 0, -max_len, max_len)
    def remove_obsolete_features(self, bucket_id: int, max_size: int) -> None:
        # for long-term memory only
        object_ids = self.buckets[bucket_id]

        assert self.perm_end_pt[bucket_id] == 0  # permanent memory should be empty in LT memory

        # normalize with life duration
        usage = self.get_usage(bucket_id)
        bs = usage.shape[0]

        survivals = []
        for bi in range(bs):
            _, survived = torch.topk(usage[bi], k=max_size)
            survivals.append(survived.flatten())
            assert survived.shape[-1] == survivals[0].shape[-1]

        self.k[bucket_id] = torch.stack(
            [self.k[bucket_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0)
        self.s[bucket_id] = torch.stack(
            [self.s[bucket_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0)

        if self.save_selection:
            # Long-term memory does not store selection so this should not be needed
            self.e[bucket_id] = torch.stack(
                [self.e[bucket_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0)
        for obj_id in object_ids:
            self.v[obj_id] = torch.stack(
                [self.v[obj_id][bi, :, survived] for bi, survived in enumerate(survivals)], 0)

        self.use_cnt[bucket_id] = torch.stack(
            [self.use_cnt[bucket_id][bi, survived] for bi, survived in enumerate(survivals)], 0)
        self.life_cnt[bucket_id] = torch.stack(
            [self.life_cnt[bucket_id][bi, survived] for bi, survived in enumerate(survivals)], 0)
    def get_usage(self, bucket_id: int) -> torch.Tensor:
        # return normalized usage
        if not self.save_usage:
            raise RuntimeError('I did not count usage!')
        else:
            usage = self.use_cnt[bucket_id] / self.life_cnt[bucket_id]
            return usage
    def get_all_sliced(
        self, bucket_id: int, start: int, end: int
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Dict[int, torch.Tensor], torch.Tensor]:
        # return k, sk, ek, value, normalized usage in order, sliced by start and end
        # this only queries the temporary memory

        assert start >= 0
        assert end <= 0

        p_size = self.perm_end_pt[bucket_id]
        start = start + p_size

        if end == 0:
            # negative 0 would not work as the end index!
            k = self.k[bucket_id][:, :, start:]
            sk = self.s[bucket_id][:, :, start:]
            ek = self.e[bucket_id][:, :, start - p_size:] if self.save_selection else None
            value = {obj_id: self.v[obj_id][:, :, start:] for obj_id in self.buckets[bucket_id]}
            usage = self.get_usage(bucket_id)[:, start - p_size:] if self.save_usage else None
        else:
            k = self.k[bucket_id][:, :, start:end]
            sk = self.s[bucket_id][:, :, start:end]
            ek = self.e[bucket_id][:, :, start - p_size:end] if self.save_selection else None
            value = {obj_id: self.v[obj_id][:, :, start:end] for obj_id in self.buckets[bucket_id]}
            usage = self.get_usage(bucket_id)[:, start - p_size:end] if self.save_usage else None

        return k, sk, ek, value, usage
    def purge_except(self, obj_keep_idx: List[int]):
        # purge all objects from the memory except the ones listed
        obj_keep_idx = set(obj_keep_idx)

        # remove objects that are not in the keep list from the buckets
        buckets_to_remove = []
        for bucket_id, object_ids in self.buckets.items():
            self.buckets[bucket_id] = [obj_id for obj_id in object_ids if obj_id in obj_keep_idx]
            if len(self.buckets[bucket_id]) == 0:
                buckets_to_remove.append(bucket_id)

        # remove object values that are not in the keep list
        self.v = {k: v for k, v in self.v.items() if k in obj_keep_idx}

        # remove buckets that are empty
        for bucket_id in buckets_to_remove:
            del self.buckets[bucket_id]
            del self.k[bucket_id]
            del self.s[bucket_id]
            if self.save_selection:
                del self.e[bucket_id]
            if self.save_usage:
                del self.use_cnt[bucket_id]
                del self.life_cnt[bucket_id]
    def clear_non_permanent_memory(self):
        # clear all non-permanent memory
        for bucket_id in self.buckets:
            self.sieve_by_range(bucket_id, 0, 0, 0)
    def get_v_size(self, obj_id: int) -> int:
        return self.v[obj_id].shape[-1]

    def size(self, bucket_id: int) -> int:
        if bucket_id not in self.k:
            return 0
        else:
            return self.k[bucket_id].shape[-1]

    def perm_size(self, bucket_id: int) -> int:
        return self.perm_end_pt[bucket_id]

    def non_perm_size(self, bucket_id: int) -> int:
        return self.size(bucket_id) - self.perm_size(bucket_id)

    def engaged(self, bucket_id: Optional[int] = None) -> bool:
        if bucket_id is None:
            return len(self.buckets) > 0
        else:
            return bucket_id in self.buckets

    def num_objects(self) -> int:
        return len(self.v)

    def key(self) -> Dict[int, torch.Tensor]:
        return self.k

    def value(self) -> Dict[int, torch.Tensor]:
        return self.v

    def shrinkage(self) -> Dict[int, torch.Tensor]:
        return self.s

    def selection(self) -> Dict[int, torch.Tensor]:
        return self.e

    def __contains__(self, key):
        return key in self.v
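

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original store implementation): it only
# illustrates the expected tensor shapes and the bucket bookkeeping. All sizes
# (B=1 key batch, C=64 key channels, CV=512 value channels, N=16 elements) and
# the object ids 1 and 2 are made up for this example.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    B, C, CV, N = 1, 64, 512, 16

    store = KeyValueMemoryStore(save_selection=True, save_usage=True)

    key = torch.randn(B, C, N)
    shrinkage = torch.rand(B, 1, N)
    selection = torch.rand(B, C, N)
    # objects 1 and 2 first appear together, so they share a single new bucket
    values = {1: torch.randn(B, CV, N), 2: torch.randn(B, CV, N)}

    store.add(key, values, shrinkage, selection)

    bucket_id = next(iter(store.buckets))
    print('bucket size:', store.size(bucket_id))           # N
    print('permanent size:', store.perm_size(bucket_id))   # 0, nothing was stored as permanent
    print('num objects:', store.num_objects())             # 2

    # slice out the whole temporary part of the bucket
    k, sk, ek, value, usage = store.get_all_sliced(bucket_id, 0, 0)
    print(k.shape, sk.shape, ek.shape, value[1].shape, usage.shape)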