File size: 1,622 Bytes
8df3985
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0f710a2
 
 
8df3985
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from tclogger import logger
from transformers import AutoTokenizer

from constants.models import MODEL_MAP, TOKEN_LIMIT_MAP, TOKEN_RESERVED


class TokenChecker:
    """Check whether a prompt fits into a model's token budget.

    The prompt is tokenized with the tokenizer of the selected model, and the
    resulting count is compared against
    ``TOKEN_LIMIT_MAP[model] - TOKEN_RESERVED``.
    """

    # Model used when the requested name is not a key of MODEL_MAP.
    DEFAULT_MODEL = "mixtral-8x7b"

    # As some models are gated, we need to fetch tokenizers from alternatives
    GATED_MODEL_MAP = {
        "llama3-70b": "NousResearch/Meta-Llama-3-70B",
        "gemma-7b": "unsloth/gemma-7b",
        "mistral-7b": "dfurman/Mistral-7B-Instruct-v0.2",
        "mixtral-8x7b": "dfurman/Mixtral-8x7B-Instruct-v0.1",
    }

    def __init__(self, input_str: str, model: str):
        """Store the prompt and load the tokenizer for ``model``.

        Args:
            input_str: Prompt text whose tokens will be counted.
            model: Short model name. Names that are not keys of ``MODEL_MAP``
                silently fall back to ``DEFAULT_MODEL``.
        """
        self.input_str = input_str
        self.model = model if model in MODEL_MAP else self.DEFAULT_MODEL
        self.model_fullname = MODEL_MAP[self.model]

        # Prefer the ungated mirror when one is listed; otherwise load the
        # tokenizer from the model's own repository name.
        tokenizer_source = self.GATED_MODEL_MAP.get(self.model, self.model_fullname)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_source)

    def count_tokens(self) -> int:
        """Return the number of tokens the tokenizer produces for the prompt.

        Every call re-encodes the prompt and logs the resulting count.
        """
        token_count = len(self.tokenizer.encode(self.input_str))
        logger.note(f"Prompt Token Count: {token_count}")
        return token_count

    def get_token_limit(self):
        """Return the ``TOKEN_LIMIT_MAP`` entry for the selected model."""
        return TOKEN_LIMIT_MAP[self.model]

    def _get_token_budget(self):
        """Return the model's token limit minus the reserved token amount."""
        return self.get_token_limit() - TOKEN_RESERVED

    def get_token_redundancy(self) -> int:
        """Return how many tokens remain after the prompt and the reserve.

        A result of zero or less means the prompt does not fit.
        """
        return int(self._get_token_budget() - self.count_tokens())

    def check_token_limit(self) -> bool:
        """Validate that the prompt leaves at least one spare token.

        Returns:
            ``True`` when the prompt fits.

        Raises:
            ValueError: If the prompt token count reaches or exceeds the
                token limit minus the reserved amount.
        """
        # Tokenize once: encoding is the expensive step and each call to
        # count_tokens() also emits a log line.
        token_count = self.count_tokens()
        token_limit = self.get_token_limit()
        token_budget = token_limit - TOKEN_RESERVED

        if int(token_budget - token_count) <= 0:
            # Report the effective budget, not the raw limit: the prompt can
            # be below the raw limit and still fail because of the reserve.
            raise ValueError(
                f"Prompt exceeded token limit: {token_count} >= {token_budget} "
                f"(limit {token_limit} - reserved {TOKEN_RESERVED})"
            )
        return True