Update app.py
Browse files
app.py
CHANGED
@@ -7,6 +7,7 @@ import socketio
|
|
7 |
import requests
|
8 |
import logging
|
9 |
from threading import Event
|
|
|
10 |
|
11 |
app = Flask(__name__)
|
12 |
logging.basicConfig(level=logging.INFO)
|
@@ -74,21 +75,14 @@ def normalize_content(content):
|
|
74 |
# 如果是其他类型,返回空字符串
|
75 |
return ""
|
76 |
|
77 |
-
def
|
78 |
"""
|
79 |
-
|
80 |
-
|
81 |
-
对于英文,我们可以继续使用空格分词。
|
82 |
"""
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
if '\u4e00' <= char <= '\u9fff':
|
87 |
-
tokens += 1
|
88 |
-
else:
|
89 |
-
# 对于非中文字符,简单按空格分词
|
90 |
-
tokens += len(char.split())
|
91 |
-
return tokens
|
92 |
|
93 |
@app.route('/')
|
94 |
def root():
|
@@ -126,8 +120,8 @@ def messages():
|
|
126 |
# 使用 normalize_content 递归处理 msg['content']
|
127 |
previous_messages = "\n\n".join([normalize_content(msg['content']) for msg in json_body['messages']])
|
128 |
|
129 |
-
# 动态计算输入的 token
|
130 |
-
input_tokens =
|
131 |
|
132 |
msg_id = str(uuid.uuid4())
|
133 |
response_event = Event()
|
@@ -229,8 +223,8 @@ def messages():
|
|
229 |
if sio.connected:
|
230 |
sio.disconnect()
|
231 |
|
232 |
-
# 动态计算输出的 token
|
233 |
-
output_tokens =
|
234 |
|
235 |
yield create_event("content_block_stop", {"type": "content_block_stop", "index": 0})
|
236 |
yield create_event("message_delta", {
|
@@ -304,8 +298,8 @@ def handle_non_stream(previous_messages, msg_id, model, input_tokens):
|
|
304 |
# 等待响应完成
|
305 |
response_event.wait(timeout=30)
|
306 |
|
307 |
-
# 动态计算输出的 token
|
308 |
-
output_tokens =
|
309 |
|
310 |
# 生成完整的响应
|
311 |
full_response = {
|
|
|
7 |
import requests
|
8 |
import logging
|
9 |
from threading import Event
|
10 |
+
import tiktoken # 引入 tiktoken 库
|
11 |
|
12 |
app = Flask(__name__)
|
13 |
logging.basicConfig(level=logging.INFO)
|
|
|
75 |
# 如果是其他类型,返回空字符串
|
76 |
return ""
|
77 |
|
78 |
+
def calculate_tokens_via_tiktoken(text, model="gpt-3.5-turbo"):
|
79 |
"""
|
80 |
+
使用 tiktoken 库根据 GPT 模型计算 token 数量。
|
81 |
+
Claude 模型与 GPT 模型的 token 计算机制类似,因此可以使用 tiktoken。
|
|
|
82 |
"""
|
83 |
+
encoding = tiktoken.encoding_for_model(model) # 获取模型的编码器
|
84 |
+
tokens = encoding.encode(text) # 对文本进行 tokenization
|
85 |
+
return len(tokens)
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
|
87 |
@app.route('/')
|
88 |
def root():
|
|
|
120 |
# 使用 normalize_content 递归处理 msg['content']
|
121 |
previous_messages = "\n\n".join([normalize_content(msg['content']) for msg in json_body['messages']])
|
122 |
|
123 |
+
# 动态计算输入的 token 数量,使用 tiktoken 进行 tokenization
|
124 |
+
input_tokens = calculate_tokens_via_tiktoken(previous_messages, model="gpt-3.5-turbo")
|
125 |
|
126 |
msg_id = str(uuid.uuid4())
|
127 |
response_event = Event()
|
|
|
223 |
if sio.connected:
|
224 |
sio.disconnect()
|
225 |
|
226 |
+
# 动态计算输出的 token 数量,使用 tiktoken 进行 tokenization
|
227 |
+
output_tokens = calculate_tokens_via_tiktoken(''.join(response_text), model="gpt-3.5-turbo")
|
228 |
|
229 |
yield create_event("content_block_stop", {"type": "content_block_stop", "index": 0})
|
230 |
yield create_event("message_delta", {
|
|
|
298 |
# 等待响应完成
|
299 |
response_event.wait(timeout=30)
|
300 |
|
301 |
+
# 动态计算输出的 token 数量,使用 tiktoken 进行 tokenization
|
302 |
+
output_tokens = calculate_tokens_via_tiktoken(''.join(response_text), model="gpt-3.5-turbo")
|
303 |
|
304 |
# 生成完整的响应
|
305 |
full_response = {
|