from dataclasses import dataclass
from typing import Optional, Tuple

import torch

@dataclass
class LongLlamaMemConfig:
    """
    Class for configuring memory caches for LongLlama model.

    Args:
        positionals (`boolean`)
            Whether to use positional embeddings in memory layer
        cache_dtype (`torch.dtype`)
            Specifies storing type for keys and values
        attention_grouping (`Tuple[int, int]`, *optional*)
            One can trade speed for memory by performing attention
            in memory layers sequentially. 
            When equal to `(4, 128)` the memory layers will process at most 4 heads and 128 queries
            from each head at once. That is at most 512 queries at once.
    """

    positionals: bool = True
    cache_dtype: torch.dtype = torch.bfloat16
    attention_grouping: Optional[Tuple[int, int]] = None
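

# Illustrative sketch (the constant name is an assumption, not part of the
# original module): a config matching the docstring example above. With
# `attention_grouping=(4, 128)`, each memory layer processes at most 4 heads
# and 128 queries per head at a time, i.e. at most 4 * 128 = 512 queries per call.
EXAMPLE_MEM_CONFIG = LongLlamaMemConfig(
    positionals=True,
    cache_dtype=torch.bfloat16,
    attention_grouping=(4, 128),
)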


@dataclass
class LongLlamaMemCache:
    """
    Class with LongLlama's memory cache

    Args:
        keys (`torch.FloatTensor` of shape `(batch_size, num_heads, mem_length, embed_size_per_head)`)
        values (`torch.FloatTensor` of shape `(batch_size, num_heads, mem_length, embed_size_per_head)`)
        masks (`torch.FloatTensor` of shape `(batch_size, 1, mem_length, 1)`)
            For masking out parts of memory
    """

    keys: torch.FloatTensor
    values: torch.FloatTensor
    masks: torch.FloatTensor


def mem_apply_update(
    prev_mem_cache: LongLlamaMemCache, new_mem_content: LongLlamaMemCache, mem_config: LongLlamaMemConfig
) -> LongLlamaMemCache:
    """Appends new_mem_content to prev_mem_cache along the mem_length dimension."""

    def update_one(prev: torch.FloatTensor, new: torch.FloatTensor) -> torch.FloatTensor:
        if len(prev.shape) != 4 or len(new.shape) != 4:
            raise ValueError(f"Memory cache content should be consistent in shape, got {prev.shape} and {new.shape}")

        # Concatenate along the mem_length dimension.
        return torch.concat([prev, new], dim=-2)

    insert_size = new_mem_content.keys.shape[-2]

    if new_mem_content.values.shape[-2] != insert_size or new_mem_content.masks.shape[-2] != insert_size:
        raise ValueError("Inconsistent mem_length in new_mem_content")

    return LongLlamaMemCache(
        keys=update_one(prev_mem_cache.keys, new_mem_content.keys),
        values=update_one(prev_mem_cache.values, new_mem_content.values),
        masks=update_one(prev_mem_cache.masks, new_mem_content.masks),
    )
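

# Usage sketch (illustrative, not part of the original module): demonstrates how
# `mem_apply_update` grows a cache along the mem_length dimension. The tensor
# shapes (batch_size=1, num_heads=4, embed_size_per_head=8) and the helper
# `_make_cache` are assumptions made for this demo only.
if __name__ == "__main__":
    config = LongLlamaMemConfig(attention_grouping=(4, 128))

    def _make_cache(mem_length: int) -> LongLlamaMemCache:
        return LongLlamaMemCache(
            keys=torch.zeros(1, 4, mem_length, 8, dtype=config.cache_dtype),
            values=torch.zeros(1, 4, mem_length, 8, dtype=config.cache_dtype),
            masks=torch.zeros(1, 1, mem_length, 1),
        )

    updated = mem_apply_update(_make_cache(16), _make_cache(4), config)
    # Keys, values, and masks are concatenated along dim=-2: 16 + 4 = 20.
    assert updated.keys.shape[-2] == 20 and updated.masks.shape[-2] == 20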