Hansimov committed on
Commit 1b9f698
1 Parent(s): c95d47e

:zap: [Enhance] Auto calculate max_tokens if not set

apis/chat_api.py CHANGED
@@ -56,7 +56,7 @@ class ChatAPIApp:
             if api_key.startswith("hf_"):
                 return api_key
             else:
-                logger.warn(f"Invalid HF Token")
+                logger.warn(f"Invalid HF Token!")
         else:
             logger.warn("Not provide HF Token!")
         return None
@@ -71,11 +71,11 @@ class ChatAPIApp:
             description="(list) Messages",
         )
         temperature: float = Field(
-            default=0.01,
+            default=0,
             description="(float) Temperature",
         )
         max_tokens: int = Field(
-            default=4096,
+            default=-1,
             description="(int) Max tokens",
         )
         stream: bool = Field(
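
With these defaults, a request that omits temperature or max_tokens now deserializes to 0 and -1 respectively, and -1 signals the streamer below to auto-calculate the generation budget. A minimal client sketch; the host, port, and route are assumptions for illustration, not taken from this diff:

import requests

# Hypothetical local deployment; adjust the URL to the actual server.
url = "http://127.0.0.1:23333/chat/completions"
payload = {
    "model": "mixtral-8x7b",
    "messages": [{"role": "user", "content": "Hello!"}],
    # "temperature" and "max_tokens" are omitted: the Field defaults
    # (0 and -1) apply, and -1 means "auto-calculate max_new_tokens".
}
response = requests.post(url, json=payload)
print(response.json())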
networks/message_streamer.py CHANGED
@@ -1,6 +1,7 @@
 import json
 import re
 import requests
+from tiktoken import get_encoding as tiktoken_get_encoding
 from messagers.message_outputer import OpenaiStreamOutputer
 from utils.logger import logger
 from utils.enver import enver
@@ -22,6 +23,12 @@ class MessageStreamer:
         "mistral-7b": "</s>",
         "openchat-3.5": "<|end_of_turn|>",
     }
+    TOKEN_LIMIT_MAP = {
+        "mixtral-8x7b": 32768,
+        "mistral-7b": 32768,
+        "openchat-3.5": 8192,
+    }
+    TOKEN_RESERVED = 32

     def __init__(self, model: str):
         if model in self.MODEL_MAP.keys():
@@ -30,6 +37,7 @@ class MessageStreamer:
             self.model = "default"
         self.model_fullname = self.MODEL_MAP[self.model]
         self.message_outputer = OpenaiStreamOutputer()
+        self.tokenizer = tiktoken_get_encoding("cl100k_base")

     def parse_line(self, line):
         line = line.decode("utf-8")
@@ -38,11 +46,17 @@ class MessageStreamer:
             content = data["token"]["text"]
             return content

+    def count_tokens(self, text):
+        tokens = self.tokenizer.encode(text)
+        token_count = len(tokens)
+        logger.note(f"Prompt Token Count: {token_count}")
+        return token_count
+
     def chat_response(
         self,
         prompt: str = None,
         temperature: float = 0.01,
-        max_new_tokens: int = 8192,
+        max_new_tokens: int = None,
         api_key: str = None,
     ):
         # https://huggingface.co/docs/api-inference/detailed_parameters?code=curl
@@ -60,6 +74,19 @@ class MessageStreamer:
         )
         self.request_headers["Authorization"] = f"Bearer {api_key}"

+        token_limit = (
+            self.TOKEN_LIMIT_MAP[self.model]
+            - self.TOKEN_RESERVED
+            - self.count_tokens(prompt)
+        )
+        if token_limit <= 0:
+            raise ValueError("Prompt exceeded token limit!")
+
+        if max_new_tokens is None or max_new_tokens <= 0:
+            max_new_tokens = token_limit
+        else:
+            max_new_tokens = min(max_new_tokens, token_limit)
+
         # References:
         #   huggingface_hub/inference/_client.py:
         #     class InferenceClient > def text_generation()
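
The budget arithmetic is: the model's context window, minus the 32-token reserve, minus the measured prompt length; whatever remains is available for generation. A worked example with the new constants (the 1000-token prompt length is a made-up figure):

# For "mixtral-8x7b", TOKEN_LIMIT_MAP gives a 32768-token context window.
# Assume count_tokens(prompt) returns 1000:
token_limit = 32768 - 32 - 1000       # = 31736 tokens left for output

# The new default (max_new_tokens=None, or -1 via the API) takes the
# whole remaining budget:
max_new_tokens = token_limit          # 31736

# An explicit value is clamped so prompt + output never exceed the window:
max_new_tokens = min(40000, token_limit)  # still 31736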
requirements.txt CHANGED
@@ -6,5 +6,6 @@ pydantic
 requests
 sse_starlette
 termcolor
+tiktoken
 uvicorn
 websockets
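
tiktoken is the new dependency backing count_tokens. One caveat worth noting: cl100k_base is an OpenAI encoding, so token counts for the Mistral and OpenChat models served here are approximations rather than exact tokenizer output; the TOKEN_RESERVED cushion presumably absorbs some of that mismatch. A quick standalone check:

from tiktoken import get_encoding

tokenizer = get_encoding("cl100k_base")
print(len(tokenizer.encode("Hello, world!")))  # prompt length as the streamer would measure it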