File size: 1,378 Bytes
230c4b6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 |
# Copyright (c) Together
# This software is distributed under the terms of the Apache License, Version 2.0
# Author: Michael Poli
from torch import Tensor
from dataclasses import dataclass, field
from typing import Optional
# https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/utils/generation.py
@dataclass
class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference.

    ``max_seqlen`` and ``max_batch_size`` are required; every other field
    has a default. ``key_value_memory_dict`` gets a fresh dictionary per
    instance (via ``default_factory``), so instances never share it.
    """

    max_seqlen: int
    max_batch_size: int
    seqlen_offset: int = 0
    batch_size_offset: int = 0
    key_value_memory_dict: dict = field(default_factory=dict)
    lengths_per_sample: Optional[Tensor] = None

    def reset(self, max_seqlen, max_batch_size):
        """Re-arm the parameters for a new run with the given limits.

        Overwrites ``max_seqlen`` and ``max_batch_size``, sets
        ``seqlen_offset`` back to 0 and, when ``lengths_per_sample`` is set,
        zeroes that tensor in place. ``batch_size_offset`` and
        ``key_value_memory_dict`` are left untouched.
        """
        self.max_seqlen, self.max_batch_size = max_seqlen, max_batch_size
        self.seqlen_offset = 0

        lengths = self.lengths_per_sample
        if lengths is None:
            return
        # In-place zeroing keeps the same tensor object alive for any holder.
        lengths.zero_()
@dataclass
class RecurrentInferenceParams:
    """Inference parameters passed to blocks with recurrent mode.

    All fields have defaults. ``fir_state_dict`` and ``state_dict`` each get
    their own fresh dictionary per instance (via ``default_factory``).
    """

    fir_filter_length: int = 3
    state_dim: int = 16
    seqlen_offset: int = 0
    fir_state_dict: dict = field(default_factory=dict)
    state_dict: dict = field(default_factory=dict)

    def reset(self):
        """Restore the scalar fields to their default values.

        Sets ``fir_filter_length`` to 3, ``state_dim`` to 16 and
        ``seqlen_offset`` to 0, regardless of what the instance was built
        with. ``fir_state_dict`` and ``state_dict`` are not cleared.
        """
        self.fir_filter_length, self.state_dim, self.seqlen_offset = 3, 16, 0
|