import sys
from copy import deepcopy
from typing import Optional, Tuple

import torch
from flash_attn import flash_attn_func
from transformers.modeling_outputs import CausalLMOutput

from ..ops.streaming_kernel import TritonMultiStageDotProductionAttention

class CudaCache:
    def __init__(self, num_units, unit_size, dtype):
        self.num_units = num_units
        self.unit_size = unit_size
        self.dtype = dtype
        self.data = torch.empty((num_units, unit_size), device="cuda", dtype=dtype)
        self.idle_set = set(list(range(num_units)))

    def alloc(self):
        assert len(self.idle_set) > 0
        idx = self.idle_set.pop()
        return self.data[idx], idx

    def delete(self, idx):
        assert idx not in self.idle_set
        self.idle_set.add(idx)
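
# Illustrative sketch (example only, not used by the module): how CudaCache hands out
# fixed-size slots and reclaims them. It assumes a CUDA device is available; the sizes
# are made-up example values.
def _example_cuda_cache_roundtrip():
    cache = CudaCache(num_units=4, unit_size=8, dtype=torch.float16)
    slot, slot_id = cache.alloc()  # take one (unit_size,) slot out of the idle set
    slot.fill_(0.0)                # the caller owns the slot until delete()
    cache.delete(slot_id)          # return the slot to the idle set for reuse
    return slot_id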

class MemoryUnit:
    def __init__(
        self,
        kv: Tuple[torch.Tensor, torch.Tensor],
        cache: CudaCache,
        load_to_cache: bool = False,
        pin_memory: bool = False,
    ):
        self.cache = cache

        if kv[0].is_cuda:
            cpu_data = tuple(_t.contiguous().to("cpu", non_blocking=True) for _t in kv)
        else:
            cpu_data = tuple(_t.contiguous() for _t in kv)

        if pin_memory:
            cpu_data = tuple(_t.pin_memory() for _t in cpu_data)

        if load_to_cache:
            gpu_data, gpu_data_id = cache.alloc()
            gpu_data = gpu_data.view((2,) + kv[0].shape)
            gpu_data[0].copy_(kv[0], non_blocking=True)
            gpu_data[1].copy_(kv[1], non_blocking=True)
            event = torch.cuda.Event()
            event.record(torch.cuda.current_stream())
        else:
            gpu_data, gpu_data_id = None, None
            event = None

        self.cpu_data = cpu_data
        self.gpu_data = gpu_data
        self.gpu_data_id = gpu_data_id
        self.event = event

    def load(
        self, target: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
    ) -> Tuple[bool, Optional[torch.cuda.Event]]:
        if self.gpu_data is not None:
            if target is not None:
                target[0].copy_(self.gpu_data[0], non_blocking=True)
                target[1].copy_(self.gpu_data[1], non_blocking=True)
                target_event = torch.cuda.Event()
                target_event.record(torch.cuda.current_stream())
            else:
                target_event = None

            return False, target_event

        gpu_data, gpu_data_id = self.cache.alloc()
        gpu_data = gpu_data.view((2,) + self.cpu_data[0].shape)
        if target is not None:
            target[0].copy_(self.cpu_data[0], non_blocking=True)
            target[1].copy_(self.cpu_data[1], non_blocking=True)
            target_event = torch.cuda.Event()
            target_event.record(torch.cuda.current_stream())
            gpu_data[0].copy_(target[0], non_blocking=True)
            gpu_data[1].copy_(target[1], non_blocking=True)
        else:
            gpu_data[0].copy_(self.cpu_data[0], non_blocking=True)
            gpu_data[1].copy_(self.cpu_data[1], non_blocking=True)
            target_event = None  # no staging buffer was used on this path

        event = torch.cuda.Event()
        event.record(torch.cuda.current_stream())
        self.event = event
        self.gpu_data = gpu_data
        self.gpu_data_id = gpu_data_id

        return True, target_event

    def get(self):
        assert self.gpu_data is not None
        self.event.wait()
        return self.gpu_data

    def offload(self):
        assert self.gpu_data is not None
        self.event.wait()
        self.gpu_data = None
        self.cache.delete(self.gpu_data_id)
        self.gpu_data_id = None
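
# Illustrative sketch (example only, not used by the module): a MemoryUnit round-trips a
# KV block through the CudaCache. The block is staged on the CPU at construction, loaded
# back into a freshly allocated cache slot on demand, and offloaded again to free the
# slot. Assumes a CUDA device; the shapes are made-up example values.
def _example_memory_unit_roundtrip():
    num_heads_kv, block_size, dim_head = 2, 4, 8
    cache = CudaCache(
        num_units=1,
        unit_size=2 * num_heads_kv * block_size * dim_head,  # room for one (K, V) pair
        dtype=torch.float16,
    )
    k = torch.randn(num_heads_kv, block_size, dim_head, device="cuda", dtype=torch.float16)
    v = torch.randn_like(k)
    unit = MemoryUnit((k, v), cache, load_to_cache=False, pin_memory=False)
    unit.load()             # copy the CPU copy into a newly allocated cache slot
    kv_on_gpu = unit.get()  # (2, num_heads_kv, block_size, dim_head) view of the slot
    unit.offload()          # release the slot back to the cache's idle set
    return kv_on_gpu.shape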

class VectorTensor:
    def __init__(self, hidden_size, element_dtype):
        init_cached_size = 16
        self.data = torch.empty(
            (init_cached_size, hidden_size), dtype=element_dtype, device="cuda"
        )
        self.length = 0
        self.cache_size = init_cached_size
        self.hidden_size = hidden_size

    def append_cache(self):
        new_cache_size = self.cache_size * 2
        data_shape = self.data.shape
        new_data = torch.empty(
            (new_cache_size,) + data_shape[1:], device="cuda", dtype=self.data.dtype
        )
        new_data[: self.cache_size, ...].copy_(self.data)
        self.data = new_data
        self.cache_size = new_cache_size

    def append(self, tensor: torch.Tensor):
        assert tensor.dtype == self.data.dtype
        assert tensor.size(1) == self.hidden_size
        assert tensor.is_contiguous()

        append_l = tensor.size(0)

        while self.length + append_l > self.cache_size:
            self.append_cache()

        self.data[self.length : self.length + append_l, ...].copy_(tensor)
        self.length += append_l

    def get_data(self):
        return self.data[: self.length, ...]

    def get_topk(self, tensor: torch.Tensor, topk):
        # inner product
        assert tensor.dim() == 1 and tensor.size(0) == self.hidden_size
        logits = torch.matmul(self.data[: self.length], tensor[:, None]).squeeze(dim=-1)
        assert logits.dim() == 1 and logits.size(0) == self.length
        return logits.topk(topk, dim=0).indices.cpu().tolist()

    def __len__(self):
        return self.length
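
# Illustrative sketch (example only, not used by the module): the block-key index
# interface, here with the VectorTensor backend. Representative keys are appended one
# row per block, and get_topk ranks blocks by inner product against a query vector.
# Assumes a CUDA device; sizes are made-up example values.
def _example_block_key_index():
    hidden_size, num_blocks, topk = 16, 8, 3
    index = VectorTensor(hidden_size, torch.float16)
    block_keys = torch.randn(num_blocks, hidden_size, device="cuda", dtype=torch.float16)
    index.append(block_keys)            # one representative key per block
    query = torch.randn(hidden_size, device="cuda", dtype=torch.float16)
    return index.get_topk(query, topk)  # indices of the top-k scoring blocks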

class Faiss:
    def __init__(self, hidden_size, element_dtype):
        import faiss

        # We use the CPU index here because the GPU index requires a long initialization time
        self.index = faiss.IndexFlatIP(hidden_size)
        self.hidden_size = hidden_size

    def append(self, tensor: torch.Tensor):
        assert tensor.dim() == 2 and tensor.size(1) == self.hidden_size
        self.index.add(tensor.cpu().float().numpy().astype("float32"))

    def get_data(self):
        raise ValueError

    def get_topk(self, tensor: torch.Tensor, topk):
        assert tensor.dim() == 1 and tensor.size(0) == self.hidden_size
        xq = tensor[None, :].cpu().float().numpy().astype("float32")
        topk_index = self.index.search(xq, topk)[1][0].tolist()
        return topk_index

    def __len__(self):
        return self.index.ntotal


GLOBAL_STREAM = None
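
# Illustrative sketch (example only, not used by the module): the ranking rule both
# index backends above implement. A block's representative key (one flattened vector
# per block) is scored by its inner product with a mean-pooled query, and the top-k
# blocks are selected. Runs on CPU; sizes are made-up example values.
def _example_block_ranking():
    num_heads, dim_head, num_blocks, topk = 4, 8, 10, 3
    chunk_q = torch.randn(num_heads, 16, dim_head)              # query states of one chunk
    block_keys = torch.randn(num_blocks, num_heads * dim_head)  # one representative key per block
    mean_q = chunk_q.mean(dim=1).reshape(num_heads * dim_head)  # average over chunk positions
    scores = block_keys @ mean_q                                # inner-product relevance per block
    return scores.topk(topk).indices.tolist()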

class ContextManager:
    def __init__(
        self,
        position_embedding,
        n_init,
        n_local,
        block_size,
        max_cached_block,
        topk,
        exc_block_size,
        score_decay: Optional[float] = None,
        repr_topk: int = 1,
        cache_strategy="lru",
        chunk_topk_calc: Optional[int] = None,
        async_global_stream: bool = False,
        pin_memory: bool = False,
        faiss: bool = False,
        perhead: bool = False,
        dense_decoding: bool = False,
    ):
        self.length = 0
        self.position_embedding = position_embedding
        self.n_init = n_init
        self.n_local = n_local
        self.block_size = block_size
        self.max_cached_block = max_cached_block
        self.exc_block_size = exc_block_size
        self.score_decay = score_decay
        assert exc_block_size <= n_local  # no global token in input
        self.topk = topk
        self.Attn = TritonMultiStageDotProductionAttention
        self.initialized = False
        self.repr_topk = repr_topk
        self.cache_strategy = cache_strategy
        self.load_count = 0
        self.chunk_topk_calc = chunk_topk_calc
        self.async_global_stream = async_global_stream
        self.pin_memory = pin_memory
        self.faiss = faiss
        self.perhead = perhead
        self.dense_decoding = dense_decoding

        global GLOBAL_STREAM
        if self.async_global_stream and GLOBAL_STREAM is None:
            GLOBAL_STREAM = torch.cuda.Stream()

        assert cache_strategy in ["lru", "lru-s"]

        if cache_strategy == "lru-s":
            self.calc_block_score = True
        else:
            self.calc_block_score = False

    def remove_lru_blocks(self, u, num_remove: Optional[int] = None, ignore_blocks=None):
        if num_remove is None:
            num_remove = len(self.cached_blocks[u]) - self.max_cached_block

        if num_remove <= 0:
            return

        lst = list(self.cached_blocks[u].items())
        lst.sort(key=lambda x: x[1])

        removed = 0
        for i in range(len(lst)):
            idx = lst[i][0]
            if ignore_blocks is None or (idx not in ignore_blocks):
                self.global_blocks[u][idx].offload()
                self.cached_blocks[u].pop(idx)
                removed += 1

            if removed >= num_remove:
                return

    def get_block_k(self, k, score):
        assert isinstance(score, torch.Tensor)
        assert k.dim() >= 2
        k = self.from_group_kv(k)
        assert k.shape[:-1] == score.shape
        assert k.shape[-2] == self.block_size
        score_topk = score.topk(self.repr_topk, dim=-1).indices
        assert score_topk.shape == (self.num_units, self.unit_size, self.repr_topk)
        ret = torch.gather(
            k,
            -2,
            score_topk[:, :, :, None].expand(
                self.num_units, self.unit_size, self.repr_topk, self.dim_head
            ),
        )
        return ret

    def from_group_kv(self, tensor):
        assert tensor.dim() == 4
        assert tensor.size(1) == self.num_heads_kv
        if self.num_heads == self.num_heads_kv:
            return tensor
        _, _, length, dim_head = tensor.shape
        num_group = self.num_heads // self.num_heads_kv
        tensor = tensor.view((self.num_units, self.unit_size_kv, 1, length, dim_head))
        tensor = tensor.expand(
            (self.num_units, self.unit_size_kv, num_group, length, dim_head)
        ).reshape((self.num_units, self.num_heads, length, dim_head))
        return tensor

    def init(self, local_q, local_k, local_v, global_q, global_k, global_v):
        assert local_q.dim() == 4
        batch_size, num_heads, len_q, dim_head = local_q.shape
        num_heads_kv = local_k.size(1)

        for _t in [local_q, local_k, local_v, global_q, global_k, global_v]:
            assert _t.size(0) == batch_size
            assert _t.size(1) == num_heads or _t.size(1) == num_heads_kv
            assert _t.size(2) == len_q
            assert _t.size(3) == dim_head
            assert _t.is_cuda

        self.batch_size = batch_size
        self.num_heads = num_heads
        self.num_heads_kv = num_heads_kv
        self.dim_head = dim_head
        self.num_units = batch_size
        self.unit_size = num_heads
        self.unit_size_kv = num_heads_kv

        self.global_blocks = [[] for _ in range(self.num_units)]  # [[memory_unit]]
        self.cached_blocks = [{} for _ in range(self.num_units)]  # [{block_id: block_score}]
        self.num_global_block = 0

        if self.faiss:
            self.block_k = [
                Faiss(dim_head * self.unit_size, global_k.dtype)
                for _ in range(self.num_units)
            ]
        else:
            self.block_k = [
                VectorTensor(dim_head * self.unit_size, global_k.dtype)
                for _ in range(self.num_units)
            ]

        self.local_k = torch.empty(
            (self.num_units, self.unit_size_kv, 0, dim_head),
            dtype=local_k.dtype,
            device=local_k.device,
        )
        self.local_v = torch.empty(
            (self.num_units, self.unit_size_kv, 0, dim_head),
            dtype=local_v.dtype,
            device=local_v.device,
        )

        if self.dense_decoding:
            self.dense_k = torch.empty(
                (self.num_units, self.unit_size_kv, 0, dim_head),
                dtype=local_k.dtype,
                device=local_k.device,
            )
            self.dense_v = torch.empty(
                (self.num_units, self.unit_size_kv, 0, dim_head),
                dtype=local_v.dtype,
                device=local_v.device,
            )

        self.global_remainder = (
            torch.empty(
                (self.num_units, self.unit_size_kv, 0, dim_head),
                dtype=global_k.dtype,
                device=global_k.device,
            ),
            torch.empty(
                (self.num_units, self.unit_size_kv, 0, dim_head),
                dtype=global_v.dtype,
                device=global_v.device,
            ),
        )

        self.global_remainder_local_score = torch.empty(
            (self.num_units, self.unit_size, 0),
            dtype=global_k.dtype,
            device=global_k.device,
        )

        self.init_k = torch.empty(
            (self.num_units, self.unit_size_kv, 0, dim_head),
            dtype=global_k.dtype,
            device=global_k.device,
        )
        self.init_v = torch.empty(
            (self.num_units, self.unit_size_kv, 0, dim_head),
            dtype=global_k.dtype,
            device=global_k.device,
        )
        self.init_exc = False
        self.dtype = local_q.dtype
        self.position_embedding._update_cos_sin_tables_len(
            self.n_local + self.exc_block_size + 1, local_k.device, local_k.dim()
        )

        buffer_len = (
            self.topk * self.block_size
            + self.exc_block_size
            + self.block_size
            + self.n_init
        )
        self.global_buffer = torch.zeros(
            (2, self.num_units, self.unit_size_kv, buffer_len, dim_head),
            dtype=global_k.dtype,
            device=global_k.device,
        )
        self.global_buffer_block_id_list = [[-1] * self.topk for _ in range(self.num_units)]
        self.global_buffer_init_st = 0
        self.global_buffer_init_ed = 0

        self.cuda_cache = CudaCache(
            self.max_cached_block * self.num_units,
            self.unit_size_kv * self.block_size * dim_head * 2,
            local_k.dtype,
        )

        self.initialized = True

    def calc_block_topk(self, global_h_q):
        if not self._use_chunk_topk:
            if self.num_global_block <= self.topk:
                return [
                    list(range(len(self.global_blocks[0])))
                    for _ in range(self.num_units)
                ]

            global_h_q = global_h_q.mean(dim=2, keepdim=False)
            assert global_h_q.shape == (self.num_units, self.unit_size, self.dim_head)
            global_h_q = global_h_q.reshape(self.num_units, self.dim_head * self.unit_size)
            ret = []
            for u in range(self.num_units):
                ret.append(self.block_k[u].get_topk(global_h_q[u], self.topk))
        else:
            return self._cached_topk[self._topk_cur]

        return ret

    def get_global_hidden_and_mask(self, len_q, block_topk):
        assert len(block_topk) == self.num_units
        global_block_map = [[] for _ in range(self.num_units)]
        global_remainder_len = max(
            self._global_remainder_ed - self._global_remainder_st + len_q - self.n_local,
            0,
        )
        init_len = self.init_k.size(-2)
        sliding_window = None

        global_h_k = self.global_buffer[0]
        global_h_v = self.global_buffer[1]

        block_num = len(block_topk[0])
        for u in range(self.num_units):
            assert len(block_topk[u]) == block_num

            block_topk[u].sort()
            global_block_map[u] = deepcopy(self.global_buffer_block_id_list[u])
            for b_idx in block_topk[u]:
                if b_idx in global_block_map[u]:
                    continue

                st = -1
                ed = -1
                for j in range(self.topk):
                    if (
                        global_block_map[u][j] == -1
                        or global_block_map[u][j] not in block_topk[u]
                    ):
                        st = j * self.block_size
                        ed = st + self.block_size
                        global_block_map[u][j] = b_idx
                        break

                assert b_idx in self.cached_blocks[u]
                self.global_blocks[u][b_idx].load(
                    (global_h_k[u, :, st:ed, :], global_h_v[u, :, st:ed, :])
                )

        init_st = block_num * self.block_size
        init_ed = init_st + init_len
        if self.global_buffer_init_st != init_st or self.global_buffer_init_ed != init_ed:
            global_h_k[:, :, init_st:init_ed, :].copy_(self.init_k, non_blocking=True)
            global_h_v[:, :, init_st:init_ed, :].copy_(self.init_v, non_blocking=True)

        ed = init_ed

        rmd_st = init_ed
        rmd_ed = rmd_st + global_remainder_len
        ed = rmd_ed
        global_h_k[:, :, rmd_st:rmd_ed, :].copy_(
            self.global_remainder[0][
                :,
                :,
                self._global_remainder_st : self._global_remainder_st + global_remainder_len,
                :,
            ],
            non_blocking=True,
        )
        global_h_v[:, :, rmd_st:rmd_ed, :].copy_(
            self.global_remainder[1][
                :,
                :,
                self._global_remainder_st : self._global_remainder_st + global_remainder_len,
                :,
            ],
            non_blocking=True,
        )

        sliding_window = (self.global_remainder[0].size(-2) + rmd_st, self.n_local)

        self.global_buffer_block_id_list = deepcopy(global_block_map)
        self.global_buffer_init_st = init_st
        self.global_buffer_init_ed = init_ed

        for u in range(self.num_units):
            assert max(global_block_map[u][block_num:] + [-1]) == -1
            assert min(global_block_map[u][:block_num] + [0]) > -1
            global_block_map[u] = list(global_block_map[u][:block_num])

        global_h_k = global_h_k[:, :, :ed, :]
        global_h_v = global_h_v[:, :, :ed, :]
        return global_h_k, global_h_v, sliding_window, global_block_map, block_num

    def update_block_score(
        self, global_score: torch.FloatTensor, global_block_map, global_block_num
    ):
        if global_score is not None:
            global_score = global_score[:, :, : global_block_num * self.block_size]
            assert global_score.shape == (
                self.num_units,
                self.unit_size,
                global_block_num * self.block_size,
            )
            global_score = global_score.view(
                self.num_units, self.unit_size, global_block_num, self.block_size
            )
            global_score = global_score.sum(dim=-1).sum(dim=1)
            assert global_score.shape == (self.num_units, global_block_num)
            global_score = global_score.to(
                device="cpu", non_blocking=False
            )  # (num_units, global_block_num)
            for u in range(self.num_units):
                for k, v in self.cached_blocks[u].items():
                    self.cached_blocks[u][k] = v * self.score_decay
                score = global_score[u].tolist()
                assert len(score) >= len(global_block_map[u])
                for s, i in zip(score, global_block_map[u]):
                    self.cached_blocks[u][i] += s

    def _append(self, local_q, local_k, local_v, global_q):
        # get local_h_q, local_h_k, local_h_v
        local_h_q, local_h_k = self.position_embedding(local_q, local_k)
        local_h_v = local_v

        # calc local result first to overlap host-device communication
        attn = self.Attn(local_h_q.shape, local_h_q.dtype, local_h_q.device)
        attn.append(
            local_h_q, local_h_k, local_h_v, get_score=True, sliding_window=self.n_local
        )

        # calc topk global repr k and load cache
        with torch.cuda.stream(GLOBAL_STREAM):
            block_topk = self.calc_block_topk(global_q)

            for u in range(self.num_units):
                num_remove = len(self.cached_blocks[u]) - self.max_cached_block
                for bidx in block_topk[u]:
                    if bidx not in self.cached_blocks[u]:
                        num_remove += 1

                # update cache
                self.remove_lru_blocks(u, num_remove, block_topk[u])

            if self.cache_strategy == "lru":
                self.load_count += 1
                for u in range(self.num_units):
                    for bidx in block_topk[u]:
                        self.cached_blocks[u][bidx] = self.load_count
            elif self.cache_strategy == "lru-s":
                for u in range(self.num_units):
                    for bidx in block_topk[u]:
                        self.cached_blocks[u][bidx] = 0
            else:
                raise ValueError

            # get global_h_k, global_h_v, global_mask
            # Because exc_block_size <= n_local, no global_k, global_v used in global part
            global_h_q = global_q
            (
                global_h_k,
                global_h_v,
                global_sliding_window,
                global_block_map,
                global_block_num,
            ) = self.get_global_hidden_and_mask(local_h_q.size(-2), block_topk)

        if self.async_global_stream:
            torch.cuda.current_stream().wait_stream(GLOBAL_STREAM)

        # calc global result
        attn.append(
            global_h_q,
            global_h_k,
            global_h_v,
            end=True,
            get_score=self.calc_block_score,
            sliding_window=global_sliding_window,
            complement_sliding_window=True,
        )

        o, score_list = attn.get_result()
        loc_score = score_list[0]
        glb_score = score_list[1]

        if self.async_global_stream:
            GLOBAL_STREAM.wait_stream(torch.cuda.current_stream())

        # update global score
        with torch.cuda.stream(GLOBAL_STREAM):
            self.update_block_score(glb_score, global_block_map, global_block_num)

        return o.view((self.batch_size, self.num_heads, -1, self.dim_head)), loc_score

    def get_batched_topk(self, global_q):
        length = global_q.shape[2]
        exc_num = (length + self.exc_block_size - 1) // self.exc_block_size
        exc_block_num = length // self.exc_block_size
        ret = []
        if self.num_global_block <= self.topk:
            for _ in range(exc_num):
                ret.append(
                    [
                        list(range(len(self.global_blocks[0])))
                        for _ in range(self.num_units)
                    ]
                )
            return ret

        global_h_q = global_q
        assert global_h_q.dim() == 4
        assert global_h_q.shape[:2] == (self.num_units, self.unit_size)
        assert global_h_q.shape[3] == self.dim_head

        block_k = torch.cat(
            [self.block_k[u].get_data()[None, :, :] for u in range(self.num_units)],
            dim=0,
        )
        assert block_k.shape == (
            self.num_units,
            self.num_global_block,
            self.dim_head * self.unit_size,
        )
        block_k = (
            block_k.reshape(
                self.num_units, self.num_global_block, self.unit_size, self.dim_head
            )
            .permute(0, 2, 1, 3)
            .contiguous()
        )

        if exc_block_num > 0:
            tmp_global_h_q = (
                global_h_q[:, :, : exc_block_num * self.exc_block_size, :]
                .reshape(
                    self.num_units,
                    self.unit_size,
                    exc_block_num,
                    self.exc_block_size,
                    self.dim_head,
                )
                .mean(dim=-2)
            )
            assert tmp_global_h_q.shape == (
                self.num_units,
                self.unit_size,
                exc_block_num,
                self.dim_head,
            )
            block_score = torch.matmul(tmp_global_h_q, block_k.transpose(-1, -2)).mean(
                dim=1
            )  # (num_units, exc_block_num, num_global_block)
            assert block_score.shape == (
                self.num_units,
                exc_block_num,
                self.num_global_block,
            )

            indices = block_score.topk(self.topk, dim=-1).indices.cpu()
            for b in range(exc_block_num):
                tmp = []
                for u in range(self.num_units):
                    tmp.append(indices[u, b].tolist())
                    assert len(tmp[-1]) == self.topk

                ret.append(tmp)

        if exc_block_num != exc_num:
            tmp_global_h_q = (
                global_h_q[:, :, exc_block_num * self.exc_block_size :, :]
                .reshape(
                    self.num_units,
                    self.unit_size,
                    length - exc_block_num * self.exc_block_size,
                    self.dim_head,
                )
                .mean(dim=-2, keepdim=True)
            )
            assert tmp_global_h_q.shape == (
                self.num_units,
                self.unit_size,
                1,
                self.dim_head,
            )
            block_score = torch.matmul(tmp_global_h_q, block_k.transpose(-1, -2))
            assert block_score.shape == (
                self.num_units,
                self.unit_size,
                1,
                self.num_global_block,
            )
            block_score = block_score.squeeze(dim=2).mean(dim=1)
            assert block_score.shape == (self.num_units, self.num_global_block)
            indices = block_score.topk(self.topk, dim=-1).indices.cpu()
            tmp = []
            for u in range(self.num_units):
                tmp.append(indices[u].tolist())
                assert len(tmp[-1]) == self.topk

            ret.append(tmp)

        return ret

    def append_global(self, exc_length, kv_length, local_score):
        global_remainder_ed = self._global_remainder_ed + exc_length
        global_remainder_st = self._global_remainder_st

        global_remainder_len = global_remainder_ed - global_remainder_st

        assert local_score.shape[:3] == (self.num_units, self.unit_size, kv_length)
        local_score = local_score[:, :, -exc_length - self.n_local :]
        self.global_remainder_local_score[
            :, :, global_remainder_ed - local_score.size(-1) : global_remainder_ed
        ].add_(local_score)

        if not self.init_exc and global_remainder_len > self.n_local:
            global_k = self.global_remainder[0]
            global_v = self.global_remainder[1]

            append_init_len = min(
                self.n_init - self.init_k.size(-2), global_remainder_len - self.n_local
            )
            self.init_k = torch.cat(
                (
                    self.init_k,
                    global_k[
                        :, :, global_remainder_st : global_remainder_st + append_init_len, :
                    ],
                ),
                dim=-2,
            )
            self.init_v = torch.cat(
                (
                    self.init_v,
                    global_v[
                        :, :, global_remainder_st : global_remainder_st + append_init_len, :
                    ],
                ),
                dim=-2,
            )
            global_remainder_st += append_init_len
            global_remainder_len -= append_init_len

            if self.init_k.size(-2) == self.n_init:
                self.init_exc = True

        while global_remainder_len - self.block_size >= self.n_local:
            global_remainder_len -= self.block_size
            for u in range(self.num_units):
                self.global_blocks[u].append(
                    MemoryUnit(
                        (
                            self.global_remainder[0][
                                u, :, global_remainder_st : global_remainder_st + self.block_size, :
                            ],
                            self.global_remainder[1][
                                u, :, global_remainder_st : global_remainder_st + self.block_size, :
                            ],
                        ),
                        self.cuda_cache,
                        False,
                        self.pin_memory,
                    )
                )

            global_block_k = self.get_block_k(
                self.global_remainder[0][
                    :, :, global_remainder_st : global_remainder_st + self.block_size, :
                ],
                self.global_remainder_local_score[
                    :, :, global_remainder_st : global_remainder_st + self.block_size
                ],
            )
            assert global_block_k.shape == (
                self.num_units,
                self.unit_size,
                self.repr_topk,
                self.dim_head,
            )
            global_block_k = global_block_k.mean(dim=-2, keepdim=False)
            global_block_k = global_block_k.reshape(
                self.num_units, self.unit_size * self.dim_head
            )
            global_block_k = global_block_k[:, None, :]

            self.num_global_block += 1
            for u in range(self.num_units):
                self.block_k[u].append(global_block_k[u])
            global_remainder_st += self.block_size

        self._global_remainder_ed = global_remainder_ed
        self._global_remainder_st = global_remainder_st

    def append(self, local_q, local_k, local_v, global_q, global_k, global_v):
        batch_size = local_q.size(0)
        input_length = local_q.size(-2)

        if self.perhead:
            num_heads = local_q.size(1)
            num_heads_kv = local_v.size(1)

            def repeat_kv(t):
                t = t.view(batch_size, num_heads_kv, 1, input_length, -1)
                t = t.expand(
                    batch_size, num_heads_kv, num_heads // num_heads_kv, input_length, -1
                )
                t = t.reshape(batch_size * num_heads, 1, input_length, -1)
                return t

            local_q = local_q.view(batch_size * num_heads, 1, input_length, -1)
            local_k = repeat_kv(local_k)
            local_v = repeat_kv(local_v)
            global_q = global_q.view(batch_size * num_heads, 1, input_length, -1)
            global_k = repeat_kv(global_k)
            global_v = repeat_kv(global_v)

        if not self.initialized:
            self.init(local_q, local_k, local_v, global_q, global_k, global_v)

        input_length = local_q.size(-2)

        if self.async_global_stream:
            GLOBAL_STREAM.wait_stream(torch.cuda.current_stream())

        # append local and global tensor
        self.local_k = torch.cat((self.local_k, local_k), dim=-2)
        self.local_v = torch.cat((self.local_v, local_v), dim=-2)
        kv_length = self.local_k.size(-2)

        if self.dense_decoding:
            self.dense_k = torch.cat((self.dense_k, local_k), dim=-2)
            self.dense_v = torch.cat((self.dense_v, local_v), dim=-2)

        # append global remainder
        with torch.cuda.stream(GLOBAL_STREAM):
            self._global_remainder_st = 0
            self._global_remainder_ed = self.global_remainder[0].size(-2)

            self.global_remainder = (
                torch.cat((self.global_remainder[0], global_k), dim=-2),
                torch.cat((self.global_remainder[1], global_v), dim=-2),
            )

            self.global_remainder_local_score = torch.cat(
                (
                    self.global_remainder_local_score,
                    torch.zeros(
                        (self.num_units, self.unit_size, global_k.size(-2)),
                        dtype=global_k.dtype,
                        device=global_k.device,
                    ),
                ),
                dim=-1,
            )

        with torch.cuda.stream(GLOBAL_STREAM):
            global_q = self.position_embedding.apply_rotary_pos_emb_one_angle(
                global_q, self.n_local
            )

        use_chunk_topk = self.chunk_topk_calc is not None and input_length > 1
        self._use_chunk_topk = use_chunk_topk
        if use_chunk_topk:
            exc_block_num = input_length // self.exc_block_size
            exc_block_per_topk_chunk = self.chunk_topk_calc // self.exc_block_size
            calc_cur_list = [
                i * self.exc_block_size
                for i in range(0, exc_block_num + 1, exc_block_per_topk_chunk)
            ]
            if calc_cur_list[-1] < input_length:
                calc_cur_list.append(input_length)
            self._topk_cur = 0
            self._topk_calc_cur = -1

        o_list = []

        for st in range(0, input_length, self.exc_block_size):
            ed = min(st + self.exc_block_size, input_length)
            if use_chunk_topk and calc_cur_list[self._topk_calc_cur + 1] < ed:
                # calculate topk and sync with host here
                assert ed <= calc_cur_list[self._topk_calc_cur + 2]
                self._topk_calc_cur += 1
                with torch.cuda.stream(GLOBAL_STREAM):
                    self._cached_topk = self.get_batched_topk(
                        global_q[
                            :,
                            :,
                            calc_cur_list[self._topk_calc_cur] : calc_cur_list[
                                self._topk_calc_cur + 1
                            ],
                            :,
                        ]
                    )
                self._topk_cur = 0

            kv_st = max(kv_length + st - input_length - self.n_local, 0)
            kv_ed = kv_length + ed - input_length
            chunk_o, local_score = self._append(
                local_q[:, :, st:ed, :],
                self.local_k[:, :, kv_st:kv_ed, :],
                self.local_v[:, :, kv_st:kv_ed, :],
                global_q[:, :, st:ed, :],
            )
            o_list.append(chunk_o)

            # append global
            with torch.cuda.stream(GLOBAL_STREAM):
                self.append_global(ed - st, kv_ed - kv_st, local_score)

            if self.async_global_stream:
                torch.cuda.current_stream().wait_stream(GLOBAL_STREAM)

            if use_chunk_topk:
                self._topk_cur += 1

        self.length += input_length

        # update local and global tensor
        if self.local_k.size(-2) >= self.n_local:
            self.local_k = self.local_k[:, :, -self.n_local :, :]
            self.local_v = self.local_v[:, :, -self.n_local :, :]

        assert self._global_remainder_ed == self.global_remainder[0].size(-2)
        with torch.cuda.stream(GLOBAL_STREAM):
            self.global_remainder = (
                self.global_remainder[0][:, :, self._global_remainder_st :, :],
                self.global_remainder[1][:, :, self._global_remainder_st :, :],
            )
            self.global_remainder_local_score = self.global_remainder_local_score[
                :, :, self._global_remainder_st :
            ]

        ret = torch.cat(o_list, dim=-2)

        if self.perhead:
            ret = ret.view(batch_size, num_heads, input_length, -1)

        return ret

    def size(self, *args, **kwargs):
        return self.length
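
# Illustrative sketches (examples only, not used by the module) of two pieces of
# bookkeeping ContextManager performs. The first mirrors the "lru-s" policy: cached
# block scores are decayed, the blocks attended to in the current chunk accumulate
# their new attention mass, and eviction removes the lowest scores first (as in
# remove_lru_blocks). The second mirrors the chunked walk in append(): the input is
# processed in execution blocks of exc_block_size tokens, and each chunk attends to
# the local-KV slice [kv_st, kv_ed) covering the chunk plus at most n_local preceding
# tokens. All numbers are made up for illustration.
def _example_lru_s_update(score_decay=0.5):
    cached_blocks = {0: 4.0, 1: 1.0, 2: 2.5}  # block_id -> accumulated score
    chunk_scores = {0: 3.0, 2: 0.2}           # attention mass from the current chunk
    for block_id in cached_blocks:
        cached_blocks[block_id] *= score_decay
    for block_id, s in chunk_scores.items():
        cached_blocks[block_id] += s
    return sorted(cached_blocks, key=cached_blocks.get)  # eviction order, least valuable first


def _example_chunk_schedule(input_length=10, exc_block_size=4, n_local=6, kv_length=10):
    schedule = []
    for st in range(0, input_length, exc_block_size):
        ed = min(st + exc_block_size, input_length)
        kv_st = max(kv_length + st - input_length - n_local, 0)
        kv_ed = kv_length + ed - input_length
        schedule.append((st, ed, kv_st, kv_ed))
    return schedule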

def inf_llm_forward(
    n_local,
    n_init,
    topk,
    block_size,
    max_cached_block,
    exc_block_size,
    repr_topk: int = 1,
    cache_strategy="lru",
    score_decay=None,
    chunk_topk_calc=None,
    async_global_stream=True,
    pin_memory=False,
    faiss=False,
    perhead=False,
    dense_decoding=False,
    *args,
    **kwargs
):
    def forward(
        self,
        query: torch.Tensor,
        key_value: torch.Tensor,
        position_bias: Optional[torch.Tensor],
        use_cache: bool,
        past_key_value,
        project_q,
        project_k,
        project_v,
        attention_out,
        dim_head,
        num_heads,
        num_heads_kv,
    ):
        batch_size = query.size(0)
        len_q = query.size(1)
        len_k = key_value.size(1)

        # assert use_cache

        h_q = project_q(query)  # (batch, len_q, num_heads * dim_head)
        h_k = project_k(key_value)  # (batch, len_k, num_heads * dim_head)
        h_v = project_v(key_value)  # (batch, len_k, num_heads * dim_head)

        h_q = (
            h_q.view(batch_size, len_q, num_heads, dim_head)
            .permute(0, 2, 1, 3)
            .contiguous()
        )  # (batch, num_heads, len_q, dim_head)
        h_k = (
            h_k.view(batch_size, len_k, num_heads_kv, dim_head)
            .permute(0, 2, 1, 3)
            .contiguous()
        )  # (batch, num_heads_kv, len_k, dim_head)
        h_v = (
            h_v.view(batch_size, len_k, num_heads_kv, dim_head)
            .permute(0, 2, 1, 3)
            .contiguous()
        )  # (batch, num_heads_kv, len_k, dim_head)

        if len_q == 1 and dense_decoding:
            past_k = past_key_value.dense_k
            past_v = past_key_value.dense_v

            h_k = torch.cat((past_k, h_k), dim=-2)
            h_v = torch.cat((past_v, h_v), dim=-2)

            past_key_value.dense_k = h_k
            past_key_value.dense_v = h_v

            h_q, h_k = position_bias(h_q, h_k)

            # (batch_size, seqlen, nheads, headdim)
            h_q = h_q.transpose(1, 2)
            h_k = h_k.transpose(1, 2)
            h_v = h_v.transpose(1, 2)

            # (batch_size, seqlen, nheads, headdim)
            o = flash_attn_func(h_q, h_k, h_v, causal=True)

            o = o.reshape(batch_size, len_q, dim_head * num_heads)
            o = attention_out(o)

            if use_cache:
                return o, past_key_value
            else:
                return o

        if past_key_value is None:
            past_key_value = ContextManager(
                position_bias,
                n_init,
                n_local,
                block_size,
                max_cached_block,
                topk,
                exc_block_size,
                score_decay,
                repr_topk,
                cache_strategy,
                chunk_topk_calc,
                async_global_stream,
                pin_memory,
                faiss,
                perhead,
                dense_decoding=dense_decoding,
            )

        local_q, local_k, local_v = h_q, h_k, h_v
        global_q, global_k, global_v = h_q, h_k, h_v

        o = past_key_value.append(
            local_q, local_k, local_v, global_q, global_k, global_v
        )

        o = o.view(batch_size, num_heads, len_q, dim_head).permute(0, 2, 1, 3)
        o = o.reshape(batch_size, len_q, dim_head * num_heads)
        o = attention_out(o)

        if use_cache:
            return o, past_key_value
        else:
            return o

    return forward
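
# Illustrative sketch (example only, not used by the module): inf_llm_forward is a
# factory that closes over the InfLLM hyper-parameters and returns a `forward` function
# which a patched attention layer is expected to call with its own projections and
# rotary embedding. The numbers below are illustrative values, not defaults taken from
# this file.
def _example_build_forward():
    return inf_llm_forward(
        n_local=4096,         # sliding local attention window
        n_init=128,           # initial tokens kept as always-visible global context
        topk=16,              # memory blocks retrieved per execution chunk
        block_size=128,       # tokens per offloaded memory block
        max_cached_block=32,  # GPU-resident blocks per unit before LRU eviction
        exc_block_size=512,   # execution chunk size (must be <= n_local)
    )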

class GreedySearch:
    def __init__(self, model, tokenizer):
        model.eval()
        self.device = model.device
        self.model = model
        self.tokenizer = tokenizer
        self.past_kv = None

    def clear(self):
        self.past_kv = None

    def _process_texts(self, input_text):
        model_inputs = {}
        input_ids = self.tokenizer.encode(input_text)

        model_inputs["input_ids"] = input_ids
        model_inputs["attention_mask"] = [1] * len(model_inputs["input_ids"])

        for key in model_inputs:
            model_inputs[key] = torch.tensor(model_inputs[key]).int().unsqueeze(0).cuda()

        return model_inputs

    def generate(self, text=None, input_ids=None, **kwargs):
        if input_ids is None:
            model_inputs = self._process_texts(text)
            input_ids = model_inputs["input_ids"]

        with torch.inference_mode():
            result = self._decode(input_ids, **kwargs)
        self.clear()
        return result

    def _decode(
        self,
        input_ids,
        max_length=100,
        extra_end_token_ids=[],
        chunk_size: int = 4096,
        output=False,
    ):
        if input_ids.dim() == 1:
            input_ids = input_ids[None, :]
        input_ids = input_ids.cuda()
        attention_mask = torch.ones_like(input_ids)
        assert input_ids.size(0) == 1
        length = input_ids.size(1)
        end_token_ids = extra_end_token_ids + [self.tokenizer.eos_token_id]
        logits = None
        past_key_values = self.past_kv
        if output:
            output_text = ""

        for i in range(max_length + 1):
            if i == 0:
                if chunk_size is None:
                    chunk_size = input_ids.size(1)

                # prefill the prompt (all but the last token) in chunks
                for st in range(0, input_ids.size(1) - 1, chunk_size):
                    ed = min(input_ids.size(1) - 1, st + chunk_size)
                    out = self.model(
                        input_ids=input_ids[:, st:ed],
                        attention_mask=attention_mask[:, :ed],
                        use_cache=True,
                        return_dict=True,
                        past_key_values=past_key_values,
                    )
                    logits, past_key_values = out.logits, out.past_key_values

                out = self.model(
                    input_ids=input_ids[:, -1:],
                    attention_mask=attention_mask,
                    use_cache=True,
                    return_dict=True,
                    past_key_values=past_key_values,
                )
                logits, past_key_values = out.logits, out.past_key_values
            else:
                out = self.model(
                    input_ids=input_ids[:, -1:],
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    use_cache=True,
                    return_dict=True,
                )
                logits, past_key_values = out.logits, out.past_key_values

            logits = logits[:, -1, :]
            word = logits.argmax(dim=-1)
            if word.item() in end_token_ids or i == max_length:
                break

            input_ids = torch.cat((input_ids, word.view(1, 1)), dim=-1)
            attention_mask = torch.cat(
                (
                    attention_mask,
                    torch.ones(
                        (attention_mask.size(0), 1),
                        dtype=torch.int,
                        device=attention_mask.device,
                    ),
                ),
                dim=-1,
            )
            if output:
                tmp = self.tokenizer.decode(input_ids.squeeze(0)[length:])
                if len(tmp) > len(output_text):
                    sys.stdout.write(tmp[len(output_text) :])
                    sys.stdout.flush()
                    output_text = tmp

        self.past_kv = past_key_values

        if output:
            sys.stdout.write("\n")
            sys.stdout.flush()

        # return [self.tokenizer.decode(input_ids.squeeze(0)[length:])]
        return input_ids
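
# Illustrative sketch (example only, not used by the module): how _decode above prefills
# the prompt. Every token except the last is fed in chunks of `chunk_size`, then the
# final prompt token is sent on its own so that its logits can seed greedy decoding.
# Numbers are made-up example values.
def _example_prefill_ranges(prompt_len=10, chunk_size=4):
    chunk_ranges = [
        (st, min(prompt_len - 1, st + chunk_size))
        for st in range(0, prompt_len - 1, chunk_size)
    ]
    last_token_range = (prompt_len - 1, prompt_len)
    return chunk_ranges, last_token_range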

class InfLLMGenerator(GreedySearch):
    def generate(
        self, input_ids=None, generation_config=None, pad_token_id=None, max_new_tokens=None
    ):
        if max_new_tokens is None:
            max_new_tokens = generation_config.max_new_tokens
        return super().generate(
            text=None,
            input_ids=input_ids,
            max_length=max_new_tokens,
            chunk_size=8192,
            extra_end_token_ids=[pad_token_id] if pad_token_id is not None else [],
        )

    @torch.no_grad()
    def __call__(self, input_ids=None, *args, **kwargs):
        # chunked forward
        chunk_size = 8192
        all_logits = torch.empty(0, dtype=torch.bfloat16).to(input_ids.device)
        for st in range(0, input_ids.size(1), chunk_size):
            torch.cuda.empty_cache()
            ed = min(input_ids.size(1), st + chunk_size)
            out = self.model(input_ids=input_ids[:, st:ed])
            logits = out.logits.to(torch.bfloat16)
            all_logits = torch.cat((all_logits, logits), dim=1)

        return CausalLMOutput(logits=all_logits)
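
# Illustrative sketch (example only, never called here): how the wrappers above might be
# driven with a Hugging Face causal LM whose attention has been patched to use
# inf_llm_forward. `model` and `tokenizer` are assumed to be provided by the caller.
def _example_greedy_generation(model, tokenizer, prompt: str):
    generator = InfLLMGenerator(model, tokenizer)
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()
    # greedy decoding, stopping at EOS (or at the pad token, if one is passed)
    output_ids = generator.generate(
        input_ids=input_ids, max_new_tokens=64, pad_token_id=tokenizer.pad_token_id
    )
    # _decode returns the full sequence; strip the prompt to get only the new tokens
    return tokenizer.decode(output_ids.squeeze(0)[input_ids.size(1) :])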