sczhou's picture
init code
320e465
raw
history blame
14.5 kB
from typing import Dict, List, Optional, Literal
from collections import defaultdict
import torch
def _add_last_dim(dictionary: Dict[int, torch.Tensor],
                  key: int,
                  new_value: torch.Tensor,
                  prepend: bool = False) -> None:
    """Concatenate ``new_value`` onto ``dictionary[key]`` along the last dimension, in place.

    If ``key`` is not present yet, ``new_value`` is stored as-is (no copy).

    Args:
        dictionary: mapping of tensors; updated in place.
        key: entry to extend or create.
        new_value: tensor matching ``dictionary[key]`` in every dimension but the last.
        prepend: if True, ``new_value`` is placed *before* the existing data; otherwise
            (the default) it is appended. Permanent memory occupies the front
            ``[0, perm_end_pt)`` of each buffer, so permanent additions must be prepended.
    """
    if key not in dictionary:
        dictionary[key] = new_value
    elif prepend:
        dictionary[key] = torch.cat([new_value, dictionary[key]], -1)
    else:
        dictionary[key] = torch.cat([dictionary[key], new_value], -1)


class KeyValueMemoryStore:
    """
    Works for key/value pairs type storage
    e.g., working and long-term memory

    Objects that first appear in the same ``add`` call share a bucket. Each bucket owns one
    key tensor and one shrinkage tensor (plus, optionally, selection and usage counters),
    while values are stored per object id. Along the last dimension, key/shrinkage/value
    buffers are laid out as [permanent part | temporary part], with
    ``perm_end_pt[bucket_id]`` marking the boundary. Selection and usage buffers hold only
    the temporary part.
    """

    def __init__(self, save_selection: bool = False, save_usage: bool = False):
        """
        We store keys and values of objects that first appear in the same frame in a bucket.
        Each bucket contains a set of object ids.
        Each bucket is associated with a single key tensor
        and a dictionary of value tensors indexed by object id.
        The keys and values are stored as the concatenation of a permanent part and a temporary part.

        save_selection: also store the selection tensor passed to ``add`` (attribute ``e``).
        save_usage: also keep per-element use/life counters (``use_cnt`` / ``life_cnt``).
        """
        self.save_selection = save_selection
        self.save_usage = save_usage

        self.global_bucket_id = 0  # does not reduce even if buckets are removed
        self.buckets: Dict[int, List[int]] = {}  # indexed by bucket id
        self.k: Dict[int, torch.Tensor] = {}  # indexed by bucket id
        self.v: Dict[int, torch.Tensor] = {}  # indexed by object id

        # indexed by bucket id; the end point of permanent memory
        self.perm_end_pt: Dict[int, int] = defaultdict(int)

        # shrinkage and selection are just like the keys
        self.s = {}
        if self.save_selection:
            self.e = {}  # does not contain the permanent memory part

        # usage
        if self.save_usage:
            self.use_cnt = {}  # indexed by bucket id, does not contain the permanent memory part
            self.life_cnt = {}  # indexed by bucket id, does not contain the permanent memory part

    def add(self,
            key: torch.Tensor,
            values: Dict[int, torch.Tensor],
            shrinkage: torch.Tensor,
            selection: torch.Tensor,
            supposed_bucket_id: int = -1,
            as_permanent: Literal['no', 'first', 'all'] = 'no') -> None:
        """
        key: (1/2)*C*N
        values: dict of values ((1/2)*C*N), object ids are used as keys
        shrinkage: (1/2)*1*N
        selection: (1/2)*C*N
        supposed_bucket_id: used to sync the bucket id between working and long-term memory
            if provided, the input should all be in a single bucket indexed by this id
        as_permanent: whether to store the input as permanent memory
            'no': don't
            'first': only store it as permanent memory if the bucket has no permanent
                     memory yet
            'all': always store it as permanent memory

        Permanent entries are prepended to the key, shrinkage and value buffers of their
        bucket, so they stay inside [0, perm_end_pt). Temporary entries are appended and
        also get selection / fresh usage counters when those are enabled.
        """
        bs = key.shape[0]
        ne = key.shape[-1]
        assert len(key.shape) == 3
        assert len(shrinkage.shape) == 3
        assert not self.save_selection or len(selection.shape) == 3
        assert as_permanent in ['no', 'first', 'all']

        # 1. Resolve (or create) the bucket of every incoming object. Values are written in
        #    step 3, because prepend vs. append depends on the bucket's permanence (step 2).
        obj_bucket: Dict[int, int] = {}
        if supposed_bucket_id >= 0:
            enabled_buckets = [supposed_bucket_id]
            bucket_exist = supposed_bucket_id in self.buckets
            for obj in values:
                if bucket_exist:
                    assert obj in self.v
                    assert obj in self.buckets[supposed_bucket_id]
                else:
                    assert obj not in self.v
                obj_bucket[obj] = supposed_bucket_id
            if not bucket_exist:
                self.buckets[supposed_bucket_id] = list(values.keys())
        else:
            new_bucket_id = None
            enabled_buckets = set()
            for obj, value in values.items():
                assert len(value.shape) == 3
                if obj in self.v:
                    bucket_used = [
                        bucket_id for bucket_id, object_ids in self.buckets.items()
                        if obj in object_ids
                    ]
                    assert len(bucket_used) == 1  # each object should only be in one bucket
                    obj_bucket[obj] = bucket_used[0]
                else:
                    if new_bucket_id is None:
                        # create new bucket for all first-time objects of this call
                        new_bucket_id = self.global_bucket_id
                        self.global_bucket_id += 1
                        self.buckets[new_bucket_id] = []
                    self.buckets[new_bucket_id].append(obj)
                    obj_bucket[obj] = new_bucket_id
                enabled_buckets.add(obj_bucket[obj])

        # 2. Decide per bucket whether this entry is permanent, growing perm_end_pt if so.
        add_as_permanent: Dict[int, bool] = {}
        for bucket_id in enabled_buckets:
            if as_permanent == 'all' or (as_permanent == 'first'
                                         and self.perm_end_pt[bucket_id] == 0):
                self.perm_end_pt[bucket_id] += ne
                add_as_permanent[bucket_id] = True
            else:
                add_as_permanent[bucket_id] = False

        # 3. Write the values, using the same prepend/append decision as the bucket's keys
        #    so that key index i and value index i always refer to the same element.
        for obj, value in values.items():
            _add_last_dim(self.v, obj, value, prepend=add_as_permanent[obj_bucket[obj]])

        # 4. Add keys/shrinkage to each touched bucket; temporary entries additionally get
        #    selection and usage counters (those buffers exclude the permanent part).
        for bucket_id in enabled_buckets:
            permanent = add_as_permanent[bucket_id]
            _add_last_dim(self.k, bucket_id, key, prepend=permanent)
            _add_last_dim(self.s, bucket_id, shrinkage, prepend=permanent)
            if permanent:
                continue
            if self.save_selection:
                _add_last_dim(self.e, bucket_id, selection)
            if self.save_usage:
                # Allocate fresh counters per bucket: update_bucket_usage modifies them in
                # place, so a tensor shared between buckets would leak usage across them.
                new_count = torch.zeros((bs, ne), device=key.device, dtype=torch.float32)
                # tiny non-zero life so that get_usage (use / life) is defined before updates
                new_life = torch.zeros((bs, ne), device=key.device, dtype=torch.float32) + 1e-7
                _add_last_dim(self.use_cnt, bucket_id, new_count)
                _add_last_dim(self.life_cnt, bucket_id, new_life)

    def update_bucket_usage(self, bucket_id: int, usage: torch.Tensor) -> None:
        """Increase all life counts by 1 and add ``usage`` to the use counts (temporary part only).

        ``usage`` is sliced as ``usage[:, perm_end_pt:]`` and reshaped to the counters'
        shape, so it is presumably (bs, total bucket size) -- confirm against callers.
        No-op when usage is not tracked or the bucket has no temporary memory.
        """
        if not self.save_usage:
            return

        usage = usage[:, self.perm_end_pt[bucket_id]:]
        if usage.shape[-1] == 0:
            # if there is no temporary memory, we don't need to update
            return
        self.use_cnt[bucket_id] += usage.view_as(self.use_cnt[bucket_id])
        self.life_cnt[bucket_id] += 1

    def sieve_by_range(self, bucket_id: int, start: int, end: int, min_size: int) -> None:
        """Keep only the temporary elements *outside* of a range, i.e. concat(a[:start], a[end:]).

        The permanent elements are never touched. ``start`` (>= 0) is an offset from the
        beginning of the temporary memory; ``end`` (<= 0) is an offset from the end of the
        buffer, with 0 meaning "to the very end" (so the second part is empty).
        Buckets whose temporary size is <= ``min_size`` are not modified.
        """
        assert start >= 0
        assert end <= 0

        object_ids = self.buckets[bucket_id]
        p_size = self.perm_end_pt[bucket_id]
        bucket_num_elements = self.k[bucket_id].shape[-1] - p_size
        if bucket_num_elements <= min_size:
            return

        if end == 0:
            # negative 0 would not work as the end index!
            # effectively make the second part an empty slice
            end = self.k[bucket_id].shape[-1] + 1

        # Key/shrinkage/value buffers include the permanent prefix, so their start is shifted
        # by p_size; selection/usage buffers are temporary-only and use `start` directly.
        # `end` is negative (or past the end) and thus addresses the same tail in both.
        full_start = start + p_size

        k = self.k[bucket_id]
        s = self.s[bucket_id]
        self.k[bucket_id] = torch.cat([k[:, :, :full_start], k[:, :, end:]], -1)
        self.s[bucket_id] = torch.cat([s[:, :, :full_start], s[:, :, end:]], -1)

        if self.save_selection:
            e = self.e[bucket_id]
            self.e[bucket_id] = torch.cat([e[:, :, :start], e[:, :, end:]], -1)
        if self.save_usage:
            use_cnt = self.use_cnt[bucket_id]
            life_cnt = self.life_cnt[bucket_id]
            self.use_cnt[bucket_id] = torch.cat([use_cnt[:, :start], use_cnt[:, end:]], -1)
            self.life_cnt[bucket_id] = torch.cat([life_cnt[:, :start], life_cnt[:, end:]], -1)

        for obj_id in object_ids:
            v = self.v[obj_id]
            self.v[obj_id] = torch.cat([v[:, :, :full_start], v[:, :, end:]], -1)

    def remove_old_memory(self, bucket_id: int, max_len: int) -> None:
        """Keep only the newest ``max_len`` temporary elements of the bucket (plus permanent ones)."""
        self.sieve_by_range(bucket_id, 0, -max_len, max_len)

    def remove_obsolete_features(self, bucket_id: int, max_size: int) -> None:
        """Keep, per batch row, the ``max_size`` elements with the highest normalized usage.

        For long-term memory only (asserts the bucket has no permanent part). Survivors are
        reordered by decreasing usage, as returned by ``torch.topk``. If the bucket holds
        fewer than ``max_size`` elements, all of them are kept.
        Raises RuntimeError (from get_usage) when usage is not tracked.
        """
        object_ids = self.buckets[bucket_id]
        assert self.perm_end_pt[bucket_id] == 0  # permanent memory should be empty in LT memory

        # normalize with life duration
        usage = self.get_usage(bucket_id)
        bs = usage.shape[0]
        # topk raises when k exceeds the number of elements; clamp to keep everything then
        num_keep = min(max_size, usage.shape[-1])

        survivals = []
        for bi in range(bs):
            _, survived = torch.topk(usage[bi], k=num_keep)
            survivals.append(survived.flatten())

        def _keep(t: torch.Tensor) -> torch.Tensor:
            # 3-D buffers (bs, C, N) -> (bs, C, num_keep)
            return torch.stack([t[bi, :, idx] for bi, idx in enumerate(survivals)], 0)

        def _keep_counter(t: torch.Tensor) -> torch.Tensor:
            # 2-D counters (bs, N) -> (bs, num_keep)
            return torch.stack([t[bi, idx] for bi, idx in enumerate(survivals)], 0)

        self.k[bucket_id] = _keep(self.k[bucket_id])
        self.s[bucket_id] = _keep(self.s[bucket_id])

        if self.save_selection:
            # Long-term memory does not store selection so this should not be needed
            self.e[bucket_id] = _keep(self.e[bucket_id])

        for obj_id in object_ids:
            self.v[obj_id] = _keep(self.v[obj_id])

        self.use_cnt[bucket_id] = _keep_counter(self.use_cnt[bucket_id])
        self.life_cnt[bucket_id] = _keep_counter(self.life_cnt[bucket_id])

    def get_usage(self, bucket_id: int) -> torch.Tensor:
        """Return the normalized usage (use count / life count) of the temporary elements.

        Raises RuntimeError if the store was created with save_usage=False.
        """
        if not self.save_usage:
            raise RuntimeError('I did not count usage!')
        return self.use_cnt[bucket_id] / self.life_cnt[bucket_id]

    def get_all_sliced(
        self, bucket_id: int, start: int, end: int
    ) -> (torch.Tensor, torch.Tensor, torch.Tensor, Dict[int, torch.Tensor], torch.Tensor):
        """Return k, sk, ek, value, normalized usage in order, sliced by start and end.

        This only queries the temporary memory. ``start`` (>= 0) is an offset from the
        beginning of the temporary memory; ``end`` (<= 0) is an offset from the end, with 0
        meaning "to the very end". ``ek`` / usage are None when not stored.
        """
        assert start >= 0
        assert end <= 0

        p_size = self.perm_end_pt[bucket_id]
        full_start = start + p_size  # index into buffers that include the permanent prefix
        # negative 0 would not work as the end index, so 0 maps to an open-ended slice
        stop = None if end == 0 else end

        k = self.k[bucket_id][:, :, full_start:stop]
        sk = self.s[bucket_id][:, :, full_start:stop]
        ek = self.e[bucket_id][:, :, start:stop] if self.save_selection else None
        value = {
            obj_id: self.v[obj_id][:, :, full_start:stop]
            for obj_id in self.buckets[bucket_id]
        }
        usage = self.get_usage(bucket_id)[:, start:stop] if self.save_usage else None
        return k, sk, ek, value, usage

    def purge_except(self, obj_keep_idx: List[int]):
        """Remove every object not listed in ``obj_keep_idx``; drop buckets that become empty."""
        obj_keep_idx = set(obj_keep_idx)

        # remove objects that are not in the keep list from the buckets
        buckets_to_remove = []
        for bucket_id in list(self.buckets):
            kept = [obj_id for obj_id in self.buckets[bucket_id] if obj_id in obj_keep_idx]
            self.buckets[bucket_id] = kept
            if not kept:
                buckets_to_remove.append(bucket_id)

        # remove object values that are not in the keep list
        self.v = {k: v for k, v in self.v.items() if k in obj_keep_idx}

        # remove buckets that are empty
        for bucket_id in buckets_to_remove:
            del self.buckets[bucket_id]
            del self.k[bucket_id]
            del self.s[bucket_id]
            self.perm_end_pt.pop(bucket_id, None)
            # A bucket that only ever received permanent entries has no selection/usage
            # buffers (add skips them for permanent entries), hence pop instead of del.
            if self.save_selection:
                self.e.pop(bucket_id, None)
            if self.save_usage:
                self.use_cnt.pop(bucket_id, None)
                self.life_cnt.pop(bucket_id, None)

    def clear_non_permanent_memory(self):
        """Drop all temporary elements of every bucket, keeping only permanent memory."""
        for bucket_id in self.buckets:
            self.sieve_by_range(bucket_id, 0, 0, 0)

    def get_v_size(self, obj_id: int) -> int:
        """Number of stored elements (last dimension) of the value of ``obj_id``."""
        return self.v[obj_id].shape[-1]

    def size(self, bucket_id: int) -> int:
        """Total number of elements (permanent + temporary) in the bucket; 0 if unknown."""
        if bucket_id not in self.k:
            return 0
        return self.k[bucket_id].shape[-1]

    def perm_size(self, bucket_id: int) -> int:
        """Number of permanent elements in the bucket (0 if none)."""
        # .get avoids inserting a default entry into the defaultdict for unknown buckets
        return self.perm_end_pt.get(bucket_id, 0)

    def non_perm_size(self, bucket_id: int) -> int:
        """Number of temporary elements in the bucket."""
        return self.size(bucket_id) - self.perm_size(bucket_id)

    def engaged(self, bucket_id: Optional[int] = None) -> bool:
        """Whether any bucket exists (or, if given, whether ``bucket_id`` exists)."""
        if bucket_id is None:
            return len(self.buckets) > 0
        return bucket_id in self.buckets

    @property
    def num_objects(self) -> int:
        return len(self.v)

    @property
    def key(self) -> Dict[int, torch.Tensor]:
        return self.k

    @property
    def value(self) -> Dict[int, torch.Tensor]:
        return self.v

    @property
    def shrinkage(self) -> Dict[int, torch.Tensor]:
        return self.s

    @property
    def selection(self) -> Dict[int, torch.Tensor]:
        # only exists when the store was created with save_selection=True
        return self.e

    def __contains__(self, key):
        return key in self.v