Spaces:
Paused
Paused
add api server and openapi
Browse files- Dockerfile +1 -0
- README.md +1 -1
- api_server.py +188 -0
- entrypoint.sh +7 -1
- main.py +0 -76
- protocol.py +232 -0
- serving_chat.py +265 -0
- serving_completion.py +290 -0
- serving_engine.py +133 -0
Dockerfile
CHANGED
@@ -14,6 +14,7 @@ RUN pip3 install "torch==2.1.1"
|
|
14 |
# This build is slow but NVIDIA does not provide binaries. Increase MAX_JOBS as needed.
|
15 |
# RUN pip3 install "git+https://github.com/stanford-futuredata/megablocks.git"
|
16 |
RUN pip3 install vllm
|
|
|
17 |
RUN pip3 install "xformers==0.0.23" "transformers==4.36.0" "fschat[model_worker]==0.2.34"
|
18 |
|
19 |
RUN git clone https://github.com/NVIDIA/apex && \
|
|
|
14 |
# This build is slow but NVIDIA does not provide binaries. Increase MAX_JOBS as needed.
|
15 |
# RUN pip3 install "git+https://github.com/stanford-futuredata/megablocks.git"
|
16 |
RUN pip3 install vllm
|
17 |
+
RUN pip3 install openai
|
18 |
RUN pip3 install "xformers==0.0.23" "transformers==4.36.0" "fschat[model_worker]==0.2.34"
|
19 |
|
20 |
RUN git clone https://github.com/NVIDIA/apex && \
|
README.md
CHANGED
@@ -2,7 +2,7 @@
|
|
2 |
title: Test Docker
|
3 |
emoji: 🔥
|
4 |
colorFrom: purple
|
5 |
-
colorTo:
|
6 |
sdk: docker
|
7 |
pinned: false
|
8 |
license: mit
|
|
|
2 |
title: Test Docker
|
3 |
emoji: 🔥
|
4 |
colorFrom: purple
|
5 |
+
colorTo: white
|
6 |
sdk: docker
|
7 |
pinned: false
|
8 |
license: mit
|
api_server.py
ADDED
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import asyncio
|
3 |
+
import json
|
4 |
+
from contextlib import asynccontextmanager
|
5 |
+
from aioprometheus import MetricsMiddleware
|
6 |
+
from aioprometheus.asgi.starlette import metrics
|
7 |
+
import fastapi
|
8 |
+
import uvicorn
|
9 |
+
from http import HTTPStatus
|
10 |
+
from fastapi import Request
|
11 |
+
from fastapi.exceptions import RequestValidationError
|
12 |
+
from fastapi.middleware.cors import CORSMiddleware
|
13 |
+
from fastapi.responses import JSONResponse, StreamingResponse, Response
|
14 |
+
|
15 |
+
from vllm.engine.arg_utils import AsyncEngineArgs
|
16 |
+
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
17 |
+
from vllm.engine.metrics import add_global_metrics_labels
|
18 |
+
from protocol import CompletionRequest, ChatCompletionRequest, ErrorResponse
|
19 |
+
from vllm.logger import init_logger
|
20 |
+
from serving_chat import OpenAIServingChat
|
21 |
+
from serving_completion import OpenAIServingCompletion
|
22 |
+
|
23 |
+
TIMEOUT_KEEP_ALIVE = 5 # seconds
|
24 |
+
|
25 |
+
openai_serving_chat: OpenAIServingChat = None
|
26 |
+
openai_serving_completion: OpenAIServingCompletion = None
|
27 |
+
logger = init_logger(__name__)
|
28 |
+
|
29 |
+
|
30 |
+
@asynccontextmanager
|
31 |
+
async def lifespan(app: fastapi.FastAPI):
|
32 |
+
|
33 |
+
async def _force_log():
|
34 |
+
while True:
|
35 |
+
await asyncio.sleep(10)
|
36 |
+
await engine.do_log_stats()
|
37 |
+
|
38 |
+
if not engine_args.disable_log_stats:
|
39 |
+
asyncio.create_task(_force_log())
|
40 |
+
|
41 |
+
yield
|
42 |
+
|
43 |
+
|
44 |
+
app = fastapi.FastAPI(lifespan=lifespan)
|
45 |
+
|
46 |
+
|
47 |
+
def parse_args():
|
48 |
+
parser = argparse.ArgumentParser(
|
49 |
+
description="vLLM OpenAI-Compatible RESTful API server.")
|
50 |
+
parser.add_argument("--host", type=str, default=None, help="host name")
|
51 |
+
parser.add_argument("--port", type=int, default=8000, help="port number")
|
52 |
+
parser.add_argument("--allow-credentials",
|
53 |
+
action="store_true",
|
54 |
+
help="allow credentials")
|
55 |
+
parser.add_argument("--allowed-origins",
|
56 |
+
type=json.loads,
|
57 |
+
default=["*"],
|
58 |
+
help="allowed origins")
|
59 |
+
parser.add_argument("--allowed-methods",
|
60 |
+
type=json.loads,
|
61 |
+
default=["*"],
|
62 |
+
help="allowed methods")
|
63 |
+
parser.add_argument("--allowed-headers",
|
64 |
+
type=json.loads,
|
65 |
+
default=["*"],
|
66 |
+
help="allowed headers")
|
67 |
+
parser.add_argument("--served-model-name",
|
68 |
+
type=str,
|
69 |
+
default=None,
|
70 |
+
help="The model name used in the API. If not "
|
71 |
+
"specified, the model name will be the same as "
|
72 |
+
"the huggingface name.")
|
73 |
+
parser.add_argument("--chat-template",
|
74 |
+
type=str,
|
75 |
+
default=None,
|
76 |
+
help="The file path to the chat template, "
|
77 |
+
"or the template in single-line form "
|
78 |
+
"for the specified model")
|
79 |
+
parser.add_argument("--response-role",
|
80 |
+
type=str,
|
81 |
+
default="assistant",
|
82 |
+
help="The role name to return if "
|
83 |
+
"`request.add_generation_prompt=true`.")
|
84 |
+
parser.add_argument("--ssl-keyfile",
|
85 |
+
type=str,
|
86 |
+
default=None,
|
87 |
+
help="The file path to the SSL key file")
|
88 |
+
parser.add_argument("--ssl-certfile",
|
89 |
+
type=str,
|
90 |
+
default=None,
|
91 |
+
help="The file path to the SSL cert file")
|
92 |
+
parser.add_argument(
|
93 |
+
"--root-path",
|
94 |
+
type=str,
|
95 |
+
default=None,
|
96 |
+
help="FastAPI root_path when app is behind a path based routing proxy")
|
97 |
+
|
98 |
+
parser = AsyncEngineArgs.add_cli_args(parser)
|
99 |
+
return parser.parse_args()
|
100 |
+
|
101 |
+
|
102 |
+
app.add_middleware(MetricsMiddleware) # Trace HTTP server metrics
|
103 |
+
app.add_route("/metrics", metrics) # Exposes HTTP metrics
|
104 |
+
|
105 |
+
|
106 |
+
@app.exception_handler(RequestValidationError)
|
107 |
+
async def validation_exception_handler(_, exc):
|
108 |
+
err = openai_serving_chat.create_error_response(message=str(exc))
|
109 |
+
return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST)
|
110 |
+
|
111 |
+
|
112 |
+
@app.get("/health")
|
113 |
+
async def health() -> Response:
|
114 |
+
"""Health check."""
|
115 |
+
return Response(status_code=200)
|
116 |
+
|
117 |
+
|
118 |
+
@app.get("/api/v1/models")
|
119 |
+
async def show_available_models():
|
120 |
+
models = await openai_serving_chat.show_available_models()
|
121 |
+
return JSONResponse(content=models.model_dump())
|
122 |
+
|
123 |
+
|
124 |
+
@app.post("/api/v1/chat/completions")
|
125 |
+
async def create_chat_completion(request: ChatCompletionRequest,
|
126 |
+
raw_request: Request):
|
127 |
+
generator = await openai_serving_chat.create_chat_completion(
|
128 |
+
request, raw_request)
|
129 |
+
if isinstance(generator, ErrorResponse):
|
130 |
+
return JSONResponse(content=generator.model_dump(),
|
131 |
+
status_code=generator.code)
|
132 |
+
if request.stream:
|
133 |
+
return StreamingResponse(content=generator,
|
134 |
+
media_type="text/event-stream")
|
135 |
+
else:
|
136 |
+
return JSONResponse(content=generator.model_dump())
|
137 |
+
|
138 |
+
|
139 |
+
@app.post("/api/v1/completions")
|
140 |
+
async def create_completion(request: CompletionRequest, raw_request: Request):
|
141 |
+
generator = await openai_serving_completion.create_completion(
|
142 |
+
request, raw_request)
|
143 |
+
if isinstance(generator, ErrorResponse):
|
144 |
+
return JSONResponse(content=generator.model_dump(),
|
145 |
+
status_code=generator.code)
|
146 |
+
if request.stream:
|
147 |
+
return StreamingResponse(content=generator,
|
148 |
+
media_type="text/event-stream")
|
149 |
+
else:
|
150 |
+
return JSONResponse(content=generator.model_dump())
|
151 |
+
|
152 |
+
|
153 |
+
if __name__ == "__main__":
|
154 |
+
args = parse_args()
|
155 |
+
|
156 |
+
app.add_middleware(
|
157 |
+
CORSMiddleware,
|
158 |
+
allow_origins=args.allowed_origins,
|
159 |
+
allow_credentials=args.allow_credentials,
|
160 |
+
allow_methods=args.allowed_methods,
|
161 |
+
allow_headers=args.allowed_headers,
|
162 |
+
)
|
163 |
+
|
164 |
+
logger.info(f"args: {args}")
|
165 |
+
|
166 |
+
if args.served_model_name is not None:
|
167 |
+
served_model = args.served_model_name
|
168 |
+
else:
|
169 |
+
served_model = args.model
|
170 |
+
|
171 |
+
engine_args = AsyncEngineArgs.from_cli_args(args)
|
172 |
+
engine = AsyncLLMEngine.from_engine_args(engine_args)
|
173 |
+
openai_serving_chat = OpenAIServingChat(engine, served_model,
|
174 |
+
args.response_role,
|
175 |
+
args.chat_template)
|
176 |
+
openai_serving_completion = OpenAIServingCompletion(engine, served_model)
|
177 |
+
|
178 |
+
# Register labels for metrics
|
179 |
+
add_global_metrics_labels(model_name=engine_args.model)
|
180 |
+
|
181 |
+
app.root_path = args.root_path
|
182 |
+
uvicorn.run(app,
|
183 |
+
host=args.host,
|
184 |
+
port=args.port,
|
185 |
+
log_level="info",
|
186 |
+
timeout_keep_alive=TIMEOUT_KEEP_ALIVE,
|
187 |
+
ssl_keyfile=args.ssl_keyfile,
|
188 |
+
ssl_certfile=args.ssl_certfile)
|
entrypoint.sh
CHANGED
@@ -30,7 +30,13 @@ if [[ ! -z "${ROOT_PATH}" ]]; then
|
|
30 |
fi
|
31 |
|
32 |
# Run the provided command
|
33 |
-
exec python3 -u -m vllm.entrypoints.openai.api_server \
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
--model "${HF_MODEL}" \
|
35 |
--host 0.0.0.0 \
|
36 |
--port 7860 \
|
|
|
30 |
fi
|
31 |
|
32 |
# Run the provided command
|
33 |
+
# exec python3 -u -m vllm.entrypoints.openai.api_server \
|
34 |
+
# --model "${HF_MODEL}" \
|
35 |
+
# --host 0.0.0.0 \
|
36 |
+
# --port 7860 \
|
37 |
+
# ${additional_args}
|
38 |
+
|
39 |
+
exec python3 -u api_server.py \
|
40 |
--model "${HF_MODEL}" \
|
41 |
--host 0.0.0.0 \
|
42 |
--port 7860 \
|
main.py
DELETED
@@ -1,76 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import copy
|
3 |
-
import time
|
4 |
-
import llama_cpp
|
5 |
-
from llama_cpp import Llama
|
6 |
-
from huggingface_hub import hf_hub_download
|
7 |
-
|
8 |
-
import uvicorn
|
9 |
-
from fastapi import FastAPI, Request
|
10 |
-
|
11 |
-
|
12 |
-
llm = Llama(
|
13 |
-
model_path=hf_hub_download(
|
14 |
-
repo_id=os.environ.get("REPO_ID", "TheBloke/Llama-2-7b-Chat-GGUF"),
|
15 |
-
filename=os.environ.get("MODEL_FILE", "llama-2-7b-chat.Q5_0.gguf"),
|
16 |
-
),
|
17 |
-
n_ctx=2048,
|
18 |
-
n_gpu_layers=50, # change n_gpu_layers if you have more or less VRAM
|
19 |
-
)
|
20 |
-
|
21 |
-
history = []
|
22 |
-
|
23 |
-
system_message = """
|
24 |
-
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
|
25 |
-
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.
|
26 |
-
"""
|
27 |
-
|
28 |
-
|
29 |
-
def generate_text(message, history):
|
30 |
-
temp = ""
|
31 |
-
input_prompt = f"[INST] <<SYS>>\n{system_message}\n<</SYS>>\n\n "
|
32 |
-
for interaction in history:
|
33 |
-
input_prompt = input_prompt + str(interaction[0]) + " [/INST] " + str(interaction[1]) + " </s><s> [INST] "
|
34 |
-
|
35 |
-
input_prompt = input_prompt + str(message) + " [/INST] "
|
36 |
-
|
37 |
-
output = llm(
|
38 |
-
input_prompt,
|
39 |
-
temperature=0.15,
|
40 |
-
top_p=0.1,
|
41 |
-
top_k=40,
|
42 |
-
repeat_penalty=1.1,
|
43 |
-
max_tokens=1024,
|
44 |
-
stop=[
|
45 |
-
"<|prompter|>",
|
46 |
-
"<|endoftext|>",
|
47 |
-
"<|endoftext|> \n",
|
48 |
-
"ASSISTANT:",
|
49 |
-
"USER:",
|
50 |
-
"SYSTEM:",
|
51 |
-
],
|
52 |
-
)
|
53 |
-
# for out in output:
|
54 |
-
# stream = copy.deepcopy(out)
|
55 |
-
# temp += stream["choices"][0]["text"]
|
56 |
-
# yield temp
|
57 |
-
|
58 |
-
history = ["init", input_prompt]
|
59 |
-
|
60 |
-
print(history)
|
61 |
-
print(output)
|
62 |
-
return output
|
63 |
-
|
64 |
-
app = FastAPI()
|
65 |
-
|
66 |
-
@app.post("/api/generate")
|
67 |
-
async def generate(request: Request):
|
68 |
-
# Receive the request as JSON
|
69 |
-
data = await request.json()
|
70 |
-
# Check if the event is a completed order
|
71 |
-
if data['message']:
|
72 |
-
response = generate_text(data['message'], history)
|
73 |
-
return {"status": "success", "data":response}
|
74 |
-
else:
|
75 |
-
# If the event is not what we're looking for, ignore it
|
76 |
-
return {"status": "ignored"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
protocol.py
ADDED
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from
|
2 |
+
# https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/protocol/openai_api_protocol.py
|
3 |
+
import time
|
4 |
+
from typing import Dict, List, Literal, Optional, Union
|
5 |
+
|
6 |
+
from pydantic import BaseModel, Field
|
7 |
+
|
8 |
+
from vllm.utils import random_uuid
|
9 |
+
from vllm.sampling_params import SamplingParams
|
10 |
+
|
11 |
+
|
12 |
+
class ErrorResponse(BaseModel):
|
13 |
+
object: str = "error"
|
14 |
+
message: str
|
15 |
+
type: str
|
16 |
+
param: Optional[str] = None
|
17 |
+
code: int
|
18 |
+
|
19 |
+
|
20 |
+
class ModelPermission(BaseModel):
|
21 |
+
id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}")
|
22 |
+
object: str = "model_permission"
|
23 |
+
created: int = Field(default_factory=lambda: int(time.time()))
|
24 |
+
allow_create_engine: bool = False
|
25 |
+
allow_sampling: bool = True
|
26 |
+
allow_logprobs: bool = True
|
27 |
+
allow_search_indices: bool = False
|
28 |
+
allow_view: bool = True
|
29 |
+
allow_fine_tuning: bool = False
|
30 |
+
organization: str = "*"
|
31 |
+
group: Optional[str] = None
|
32 |
+
is_blocking: str = False
|
33 |
+
|
34 |
+
|
35 |
+
class ModelCard(BaseModel):
|
36 |
+
id: str
|
37 |
+
object: str = "model"
|
38 |
+
created: int = Field(default_factory=lambda: int(time.time()))
|
39 |
+
owned_by: str = "vllm"
|
40 |
+
root: Optional[str] = None
|
41 |
+
parent: Optional[str] = None
|
42 |
+
permission: List[ModelPermission] = Field(default_factory=list)
|
43 |
+
|
44 |
+
|
45 |
+
class ModelList(BaseModel):
|
46 |
+
object: str = "list"
|
47 |
+
data: List[ModelCard] = Field(default_factory=list)
|
48 |
+
|
49 |
+
|
50 |
+
class UsageInfo(BaseModel):
|
51 |
+
prompt_tokens: int = 0
|
52 |
+
total_tokens: int = 0
|
53 |
+
completion_tokens: Optional[int] = 0
|
54 |
+
|
55 |
+
|
56 |
+
class ChatCompletionRequest(BaseModel):
|
57 |
+
model: str
|
58 |
+
messages: Union[str, List[Dict[str, str]]]
|
59 |
+
temperature: Optional[float] = 0.7
|
60 |
+
top_p: Optional[float] = 1.0
|
61 |
+
n: Optional[int] = 1
|
62 |
+
max_tokens: Optional[int] = None
|
63 |
+
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
|
64 |
+
stream: Optional[bool] = False
|
65 |
+
presence_penalty: Optional[float] = 0.0
|
66 |
+
frequency_penalty: Optional[float] = 0.0
|
67 |
+
logit_bias: Optional[Dict[str, float]] = None
|
68 |
+
user: Optional[str] = None
|
69 |
+
# Additional parameters supported by vLLM
|
70 |
+
best_of: Optional[int] = None
|
71 |
+
top_k: Optional[int] = -1
|
72 |
+
ignore_eos: Optional[bool] = False
|
73 |
+
use_beam_search: Optional[bool] = False
|
74 |
+
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
|
75 |
+
skip_special_tokens: Optional[bool] = True
|
76 |
+
spaces_between_special_tokens: Optional[bool] = True
|
77 |
+
add_generation_prompt: Optional[bool] = True
|
78 |
+
echo: Optional[bool] = False
|
79 |
+
repetition_penalty: Optional[float] = 1.0
|
80 |
+
min_p: Optional[float] = 0.0
|
81 |
+
|
82 |
+
def to_sampling_params(self) -> SamplingParams:
|
83 |
+
return SamplingParams(
|
84 |
+
n=self.n,
|
85 |
+
presence_penalty=self.presence_penalty,
|
86 |
+
frequency_penalty=self.frequency_penalty,
|
87 |
+
repetition_penalty=self.repetition_penalty,
|
88 |
+
temperature=self.temperature,
|
89 |
+
top_p=self.top_p,
|
90 |
+
min_p=self.min_p,
|
91 |
+
stop=self.stop,
|
92 |
+
stop_token_ids=self.stop_token_ids,
|
93 |
+
max_tokens=self.max_tokens,
|
94 |
+
best_of=self.best_of,
|
95 |
+
top_k=self.top_k,
|
96 |
+
ignore_eos=self.ignore_eos,
|
97 |
+
use_beam_search=self.use_beam_search,
|
98 |
+
skip_special_tokens=self.skip_special_tokens,
|
99 |
+
spaces_between_special_tokens=self.spaces_between_special_tokens,
|
100 |
+
)
|
101 |
+
|
102 |
+
|
103 |
+
class CompletionRequest(BaseModel):
|
104 |
+
model: str
|
105 |
+
# a string, array of strings, array of tokens, or array of token arrays
|
106 |
+
prompt: Union[List[int], List[List[int]], str, List[str]]
|
107 |
+
suffix: Optional[str] = None
|
108 |
+
max_tokens: Optional[int] = 16
|
109 |
+
temperature: Optional[float] = 1.0
|
110 |
+
top_p: Optional[float] = 1.0
|
111 |
+
n: Optional[int] = 1
|
112 |
+
stream: Optional[bool] = False
|
113 |
+
logprobs: Optional[int] = None
|
114 |
+
echo: Optional[bool] = False
|
115 |
+
stop: Optional[Union[str, List[str]]] = Field(default_factory=list)
|
116 |
+
presence_penalty: Optional[float] = 0.0
|
117 |
+
frequency_penalty: Optional[float] = 0.0
|
118 |
+
best_of: Optional[int] = None
|
119 |
+
logit_bias: Optional[Dict[str, float]] = None
|
120 |
+
user: Optional[str] = None
|
121 |
+
# Additional parameters supported by vLLM
|
122 |
+
top_k: Optional[int] = -1
|
123 |
+
ignore_eos: Optional[bool] = False
|
124 |
+
use_beam_search: Optional[bool] = False
|
125 |
+
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
|
126 |
+
skip_special_tokens: Optional[bool] = True
|
127 |
+
spaces_between_special_tokens: Optional[bool] = True
|
128 |
+
repetition_penalty: Optional[float] = 1.0
|
129 |
+
min_p: Optional[float] = 0.0
|
130 |
+
|
131 |
+
def to_sampling_params(self):
|
132 |
+
echo_without_generation = self.echo and self.max_tokens == 0
|
133 |
+
|
134 |
+
return SamplingParams(
|
135 |
+
n=self.n,
|
136 |
+
best_of=self.best_of,
|
137 |
+
presence_penalty=self.presence_penalty,
|
138 |
+
frequency_penalty=self.frequency_penalty,
|
139 |
+
repetition_penalty=self.repetition_penalty,
|
140 |
+
temperature=self.temperature,
|
141 |
+
top_p=self.top_p,
|
142 |
+
top_k=self.top_k,
|
143 |
+
min_p=self.min_p,
|
144 |
+
stop=self.stop,
|
145 |
+
stop_token_ids=self.stop_token_ids,
|
146 |
+
ignore_eos=self.ignore_eos,
|
147 |
+
max_tokens=self.max_tokens if not echo_without_generation else 1,
|
148 |
+
logprobs=self.logprobs,
|
149 |
+
use_beam_search=self.use_beam_search,
|
150 |
+
prompt_logprobs=self.logprobs if self.echo else None,
|
151 |
+
skip_special_tokens=self.skip_special_tokens,
|
152 |
+
spaces_between_special_tokens=(self.spaces_between_special_tokens),
|
153 |
+
)
|
154 |
+
|
155 |
+
|
156 |
+
class LogProbs(BaseModel):
|
157 |
+
text_offset: List[int] = Field(default_factory=list)
|
158 |
+
token_logprobs: List[Optional[float]] = Field(default_factory=list)
|
159 |
+
tokens: List[str] = Field(default_factory=list)
|
160 |
+
top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None
|
161 |
+
|
162 |
+
|
163 |
+
class CompletionResponseChoice(BaseModel):
|
164 |
+
index: int
|
165 |
+
text: str
|
166 |
+
logprobs: Optional[LogProbs] = None
|
167 |
+
finish_reason: Optional[Literal["stop", "length"]] = None
|
168 |
+
|
169 |
+
|
170 |
+
class CompletionResponse(BaseModel):
|
171 |
+
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
|
172 |
+
object: str = "text_completion"
|
173 |
+
created: int = Field(default_factory=lambda: int(time.time()))
|
174 |
+
model: str
|
175 |
+
choices: List[CompletionResponseChoice]
|
176 |
+
usage: UsageInfo
|
177 |
+
|
178 |
+
|
179 |
+
class CompletionResponseStreamChoice(BaseModel):
|
180 |
+
index: int
|
181 |
+
text: str
|
182 |
+
logprobs: Optional[LogProbs] = None
|
183 |
+
finish_reason: Optional[Literal["stop", "length"]] = None
|
184 |
+
|
185 |
+
|
186 |
+
class CompletionStreamResponse(BaseModel):
|
187 |
+
id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}")
|
188 |
+
object: str = "text_completion"
|
189 |
+
created: int = Field(default_factory=lambda: int(time.time()))
|
190 |
+
model: str
|
191 |
+
choices: List[CompletionResponseStreamChoice]
|
192 |
+
usage: Optional[UsageInfo] = Field(default=None)
|
193 |
+
|
194 |
+
|
195 |
+
class ChatMessage(BaseModel):
|
196 |
+
role: str
|
197 |
+
content: str
|
198 |
+
|
199 |
+
|
200 |
+
class ChatCompletionResponseChoice(BaseModel):
|
201 |
+
index: int
|
202 |
+
message: ChatMessage
|
203 |
+
finish_reason: Optional[Literal["stop", "length"]] = None
|
204 |
+
|
205 |
+
|
206 |
+
class ChatCompletionResponse(BaseModel):
|
207 |
+
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
|
208 |
+
object: str = "chat.completion"
|
209 |
+
created: int = Field(default_factory=lambda: int(time.time()))
|
210 |
+
model: str
|
211 |
+
choices: List[ChatCompletionResponseChoice]
|
212 |
+
usage: UsageInfo
|
213 |
+
|
214 |
+
|
215 |
+
class DeltaMessage(BaseModel):
|
216 |
+
role: Optional[str] = None
|
217 |
+
content: Optional[str] = None
|
218 |
+
|
219 |
+
|
220 |
+
class ChatCompletionResponseStreamChoice(BaseModel):
|
221 |
+
index: int
|
222 |
+
delta: DeltaMessage
|
223 |
+
finish_reason: Optional[Literal["stop", "length"]] = None
|
224 |
+
|
225 |
+
|
226 |
+
class ChatCompletionStreamResponse(BaseModel):
|
227 |
+
id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}")
|
228 |
+
object: str = "chat.completion.chunk"
|
229 |
+
created: int = Field(default_factory=lambda: int(time.time()))
|
230 |
+
model: str
|
231 |
+
choices: List[ChatCompletionResponseStreamChoice]
|
232 |
+
usage: Optional[UsageInfo] = Field(default=None)
|
serving_chat.py
ADDED
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
import codecs
|
3 |
+
from fastapi import Request
|
4 |
+
from typing import AsyncGenerator, AsyncIterator, Union
|
5 |
+
from vllm.logger import init_logger
|
6 |
+
from vllm.utils import random_uuid
|
7 |
+
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
8 |
+
from protocol import (
|
9 |
+
ChatCompletionRequest, ChatCompletionResponse,
|
10 |
+
ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
|
11 |
+
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse,
|
12 |
+
UsageInfo)
|
13 |
+
from vllm.outputs import RequestOutput
|
14 |
+
from serving_engine import OpenAIServing
|
15 |
+
|
16 |
+
logger = init_logger(__name__)
|
17 |
+
|
18 |
+
|
19 |
+
class OpenAIServingChat(OpenAIServing):
|
20 |
+
|
21 |
+
def __init__(self,
|
22 |
+
engine: AsyncLLMEngine,
|
23 |
+
served_model: str,
|
24 |
+
response_role: str,
|
25 |
+
chat_template=None):
|
26 |
+
super().__init__(engine=engine, served_model=served_model)
|
27 |
+
self.response_role = response_role
|
28 |
+
self._load_chat_template(chat_template)
|
29 |
+
|
30 |
+
async def create_chat_completion(
|
31 |
+
self, request: ChatCompletionRequest, raw_request: Request
|
32 |
+
) -> Union[ErrorResponse, AsyncGenerator[str, None],
|
33 |
+
ChatCompletionResponse]:
|
34 |
+
"""Completion API similar to OpenAI's API.
|
35 |
+
|
36 |
+
See https://platform.openai.com/docs/api-reference/chat/create
|
37 |
+
for the API specification. This API mimics the OpenAI ChatCompletion API.
|
38 |
+
|
39 |
+
NOTE: Currently we do not support the following features:
|
40 |
+
- function_call (Users should implement this by themselves)
|
41 |
+
- logit_bias (to be supported by vLLM engine)
|
42 |
+
"""
|
43 |
+
error_check_ret = await self._check_model(request)
|
44 |
+
if error_check_ret is not None:
|
45 |
+
return error_check_ret
|
46 |
+
|
47 |
+
if request.logit_bias is not None and len(request.logit_bias) > 0:
|
48 |
+
# TODO: support logit_bias in vLLM engine.
|
49 |
+
return self.create_error_response(
|
50 |
+
"logit_bias is not currently supported")
|
51 |
+
|
52 |
+
try:
|
53 |
+
prompt = self.tokenizer.apply_chat_template(
|
54 |
+
conversation=request.messages,
|
55 |
+
tokenize=False,
|
56 |
+
add_generation_prompt=request.add_generation_prompt)
|
57 |
+
except Exception as e:
|
58 |
+
logger.error(
|
59 |
+
f"Error in applying chat template from request: {str(e)}")
|
60 |
+
return self.create_error_response(str(e))
|
61 |
+
|
62 |
+
request_id = f"cmpl-{random_uuid()}"
|
63 |
+
try:
|
64 |
+
token_ids = self._validate_prompt_and_tokenize(request,
|
65 |
+
prompt=prompt)
|
66 |
+
sampling_params = request.to_sampling_params()
|
67 |
+
except ValueError as e:
|
68 |
+
return self.create_error_response(str(e))
|
69 |
+
|
70 |
+
result_generator = self.engine.generate(prompt, sampling_params,
|
71 |
+
request_id, token_ids)
|
72 |
+
# Streaming response
|
73 |
+
if request.stream:
|
74 |
+
return self.chat_completion_stream_generator(
|
75 |
+
request, result_generator, request_id)
|
76 |
+
else:
|
77 |
+
return await self.chat_completion_full_generator(
|
78 |
+
request, raw_request, result_generator, request_id)
|
79 |
+
|
80 |
+
def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
|
81 |
+
if request.add_generation_prompt:
|
82 |
+
return self.response_role
|
83 |
+
else:
|
84 |
+
return request.messages[-1].role
|
85 |
+
|
86 |
+
async def chat_completion_stream_generator(
|
87 |
+
self, request: ChatCompletionRequest,
|
88 |
+
result_generator: AsyncIterator[RequestOutput], request_id: str
|
89 |
+
) -> Union[ErrorResponse, AsyncGenerator[str, None]]:
|
90 |
+
|
91 |
+
model_name = request.model
|
92 |
+
created_time = int(time.monotonic())
|
93 |
+
chunk_object_type = "chat.completion.chunk"
|
94 |
+
|
95 |
+
# Send first response for each request.n (index) with the role
|
96 |
+
role = self.get_chat_request_role(request)
|
97 |
+
for i in range(request.n):
|
98 |
+
choice_data = ChatCompletionResponseStreamChoice(
|
99 |
+
index=i, delta=DeltaMessage(role=role), finish_reason=None)
|
100 |
+
chunk = ChatCompletionStreamResponse(id=request_id,
|
101 |
+
object=chunk_object_type,
|
102 |
+
created=created_time,
|
103 |
+
choices=[choice_data],
|
104 |
+
model=model_name)
|
105 |
+
data = chunk.model_dump_json(exclude_unset=True)
|
106 |
+
yield f"data: {data}\n\n"
|
107 |
+
|
108 |
+
# Send response to echo the input portion of the last message
|
109 |
+
if request.echo:
|
110 |
+
last_msg_content = ""
|
111 |
+
if request.messages and isinstance(
|
112 |
+
request.messages, list) and request.messages[-1].get(
|
113 |
+
"content") and request.messages[-1].get(
|
114 |
+
"role") == role:
|
115 |
+
last_msg_content = request.messages[-1]["content"]
|
116 |
+
if last_msg_content:
|
117 |
+
for i in range(request.n):
|
118 |
+
choice_data = ChatCompletionResponseStreamChoice(
|
119 |
+
index=i,
|
120 |
+
delta=DeltaMessage(content=last_msg_content),
|
121 |
+
finish_reason=None)
|
122 |
+
chunk = ChatCompletionStreamResponse(
|
123 |
+
id=request_id,
|
124 |
+
object=chunk_object_type,
|
125 |
+
created=created_time,
|
126 |
+
choices=[choice_data],
|
127 |
+
model=model_name)
|
128 |
+
data = chunk.model_dump_json(exclude_unset=True)
|
129 |
+
yield f"data: {data}\n\n"
|
130 |
+
|
131 |
+
# Send response for each token for each request.n (index)
|
132 |
+
previous_texts = [""] * request.n
|
133 |
+
previous_num_tokens = [0] * request.n
|
134 |
+
finish_reason_sent = [False] * request.n
|
135 |
+
async for res in result_generator:
|
136 |
+
res: RequestOutput
|
137 |
+
for output in res.outputs:
|
138 |
+
i = output.index
|
139 |
+
|
140 |
+
if finish_reason_sent[i]:
|
141 |
+
continue
|
142 |
+
|
143 |
+
delta_text = output.text[len(previous_texts[i]):]
|
144 |
+
previous_texts[i] = output.text
|
145 |
+
previous_num_tokens[i] = len(output.token_ids)
|
146 |
+
|
147 |
+
if output.finish_reason is None:
|
148 |
+
# Send token-by-token response for each request.n
|
149 |
+
choice_data = ChatCompletionResponseStreamChoice(
|
150 |
+
index=i,
|
151 |
+
delta=DeltaMessage(content=delta_text),
|
152 |
+
finish_reason=None)
|
153 |
+
chunk = ChatCompletionStreamResponse(
|
154 |
+
id=request_id,
|
155 |
+
object=chunk_object_type,
|
156 |
+
created=created_time,
|
157 |
+
choices=[choice_data],
|
158 |
+
model=model_name)
|
159 |
+
data = chunk.model_dump_json(exclude_unset=True)
|
160 |
+
yield f"data: {data}\n\n"
|
161 |
+
else:
|
162 |
+
# Send the finish response for each request.n only once
|
163 |
+
prompt_tokens = len(res.prompt_token_ids)
|
164 |
+
final_usage = UsageInfo(
|
165 |
+
prompt_tokens=prompt_tokens,
|
166 |
+
completion_tokens=previous_num_tokens[i],
|
167 |
+
total_tokens=prompt_tokens + previous_num_tokens[i],
|
168 |
+
)
|
169 |
+
choice_data = ChatCompletionResponseStreamChoice(
|
170 |
+
index=i,
|
171 |
+
delta=DeltaMessage(content=delta_text),
|
172 |
+
finish_reason=output.finish_reason)
|
173 |
+
chunk = ChatCompletionStreamResponse(
|
174 |
+
id=request_id,
|
175 |
+
object=chunk_object_type,
|
176 |
+
created=created_time,
|
177 |
+
choices=[choice_data],
|
178 |
+
model=model_name)
|
179 |
+
if final_usage is not None:
|
180 |
+
chunk.usage = final_usage
|
181 |
+
data = chunk.model_dump_json(exclude_unset=True,
|
182 |
+
exclude_none=True)
|
183 |
+
yield f"data: {data}\n\n"
|
184 |
+
finish_reason_sent[i] = True
|
185 |
+
# Send the final done message after all response.n are finished
|
186 |
+
yield "data: [DONE]\n\n"
|
187 |
+
|
188 |
+
async def chat_completion_full_generator(
|
189 |
+
self, request: ChatCompletionRequest, raw_request: Request,
|
190 |
+
result_generator: AsyncIterator[RequestOutput],
|
191 |
+
request_id: str) -> Union[ErrorResponse, ChatCompletionResponse]:
|
192 |
+
|
193 |
+
model_name = request.model
|
194 |
+
created_time = int(time.monotonic())
|
195 |
+
final_res: RequestOutput = None
|
196 |
+
|
197 |
+
async for res in result_generator:
|
198 |
+
if await raw_request.is_disconnected():
|
199 |
+
# Abort the request if the client disconnects.
|
200 |
+
await self.engine.abort(request_id)
|
201 |
+
return self.create_error_response("Client disconnected")
|
202 |
+
final_res = res
|
203 |
+
assert final_res is not None
|
204 |
+
|
205 |
+
choices = []
|
206 |
+
role = self.get_chat_request_role(request)
|
207 |
+
for output in final_res.outputs:
|
208 |
+
choice_data = ChatCompletionResponseChoice(
|
209 |
+
index=output.index,
|
210 |
+
message=ChatMessage(role=role, content=output.text),
|
211 |
+
finish_reason=output.finish_reason,
|
212 |
+
)
|
213 |
+
choices.append(choice_data)
|
214 |
+
|
215 |
+
if request.echo:
|
216 |
+
last_msg_content = ""
|
217 |
+
if request.messages and isinstance(
|
218 |
+
request.messages, list) and request.messages[-1].get(
|
219 |
+
"content") and request.messages[-1].get(
|
220 |
+
"role") == role:
|
221 |
+
last_msg_content = request.messages[-1]["content"]
|
222 |
+
|
223 |
+
for choice in choices:
|
224 |
+
full_message = last_msg_content + choice.message.content
|
225 |
+
choice.message.content = full_message
|
226 |
+
|
227 |
+
num_prompt_tokens = len(final_res.prompt_token_ids)
|
228 |
+
num_generated_tokens = sum(
|
229 |
+
len(output.token_ids) for output in final_res.outputs)
|
230 |
+
usage = UsageInfo(
|
231 |
+
prompt_tokens=num_prompt_tokens,
|
232 |
+
completion_tokens=num_generated_tokens,
|
233 |
+
total_tokens=num_prompt_tokens + num_generated_tokens,
|
234 |
+
)
|
235 |
+
response = ChatCompletionResponse(
|
236 |
+
id=request_id,
|
237 |
+
created=created_time,
|
238 |
+
model=model_name,
|
239 |
+
choices=choices,
|
240 |
+
usage=usage,
|
241 |
+
)
|
242 |
+
|
243 |
+
return response
|
244 |
+
|
245 |
+
def _load_chat_template(self, chat_template):
|
246 |
+
if chat_template is not None:
|
247 |
+
try:
|
248 |
+
with open(chat_template, "r") as f:
|
249 |
+
self.tokenizer.chat_template = f.read()
|
250 |
+
except OSError:
|
251 |
+
# If opening a file fails, set chat template to be args to
|
252 |
+
# ensure we decode so our escape are interpreted correctly
|
253 |
+
self.tokenizer.chat_template = codecs.decode(
|
254 |
+
chat_template, "unicode_escape")
|
255 |
+
|
256 |
+
logger.info(
|
257 |
+
f"Using supplied chat template:\n{self.tokenizer.chat_template}"
|
258 |
+
)
|
259 |
+
elif self.tokenizer.chat_template is not None:
|
260 |
+
logger.info(
|
261 |
+
f"Using default chat template:\n{self.tokenizer.chat_template}"
|
262 |
+
)
|
263 |
+
else:
|
264 |
+
logger.warning(
|
265 |
+
"No chat template provided. Chat API will not work.")
|
serving_completion.py
ADDED
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import time
|
2 |
+
from fastapi import Request
|
3 |
+
from typing import AsyncGenerator, AsyncIterator
|
4 |
+
from vllm.logger import init_logger
|
5 |
+
from vllm.utils import random_uuid
|
6 |
+
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
7 |
+
from .protocol import (
|
8 |
+
CompletionRequest,
|
9 |
+
CompletionResponse,
|
10 |
+
CompletionResponseChoice,
|
11 |
+
CompletionResponseStreamChoice,
|
12 |
+
CompletionStreamResponse,
|
13 |
+
LogProbs,
|
14 |
+
UsageInfo,
|
15 |
+
)
|
16 |
+
from vllm.outputs import RequestOutput
|
17 |
+
from serving_engine import OpenAIServing
|
18 |
+
|
19 |
+
logger = init_logger(__name__)
|
20 |
+
|
21 |
+
|
22 |
+
async def completion_stream_generator(
|
23 |
+
request: CompletionRequest,
|
24 |
+
result_generator: AsyncIterator[RequestOutput],
|
25 |
+
echo_without_generation, create_logprobs_fn, request_id, created_time,
|
26 |
+
model_name) -> AsyncGenerator[str, None]:
|
27 |
+
previous_texts = [""] * request.n
|
28 |
+
previous_num_tokens = [0] * request.n
|
29 |
+
has_echoed = [False] * request.n
|
30 |
+
|
31 |
+
async for res in result_generator:
|
32 |
+
# TODO: handle client disconnect for streaming
|
33 |
+
for output in res.outputs:
|
34 |
+
i = output.index
|
35 |
+
delta_text = output.text[len(previous_texts[i]):]
|
36 |
+
token_ids = output.token_ids[previous_num_tokens[i]:]
|
37 |
+
if request.logprobs is not None:
|
38 |
+
top_logprobs = output.logprobs[previous_num_tokens[i]:]
|
39 |
+
else:
|
40 |
+
top_logprobs = None
|
41 |
+
offsets = len(previous_texts[i])
|
42 |
+
if request.echo and not has_echoed[i]:
|
43 |
+
if not echo_without_generation:
|
44 |
+
delta_text = res.prompt + delta_text
|
45 |
+
token_ids = res.prompt_token_ids + token_ids
|
46 |
+
if top_logprobs:
|
47 |
+
top_logprobs = res.prompt_logprobs + top_logprobs
|
48 |
+
else: # only just return the prompt
|
49 |
+
delta_text = res.prompt
|
50 |
+
token_ids = res.prompt_token_ids
|
51 |
+
if top_logprobs:
|
52 |
+
top_logprobs = res.prompt_logprobs
|
53 |
+
has_echoed[i] = True
|
54 |
+
if request.logprobs is not None:
|
55 |
+
logprobs = create_logprobs_fn(
|
56 |
+
token_ids=token_ids,
|
57 |
+
top_logprobs=top_logprobs,
|
58 |
+
num_output_top_logprobs=request.logprobs,
|
59 |
+
initial_text_offset=offsets,
|
60 |
+
)
|
61 |
+
else:
|
62 |
+
logprobs = None
|
63 |
+
previous_texts[i] = output.text
|
64 |
+
previous_num_tokens[i] = len(output.token_ids)
|
65 |
+
finish_reason = output.finish_reason
|
66 |
+
response_json = CompletionStreamResponse(
|
67 |
+
id=request_id,
|
68 |
+
created=created_time,
|
69 |
+
model=model_name,
|
70 |
+
choices=[
|
71 |
+
CompletionResponseStreamChoice(
|
72 |
+
index=i,
|
73 |
+
text=delta_text,
|
74 |
+
logprobs=logprobs,
|
75 |
+
finish_reason=finish_reason,
|
76 |
+
)
|
77 |
+
]).model_dump_json(exclude_unset=True)
|
78 |
+
yield f"data: {response_json}\n\n"
|
79 |
+
|
80 |
+
if output.finish_reason is not None:
|
81 |
+
logprobs = LogProbs() if request.logprobs is not None else None
|
82 |
+
prompt_tokens = len(res.prompt_token_ids)
|
83 |
+
completion_tokens = len(output.token_ids)
|
84 |
+
final_usage = UsageInfo(
|
85 |
+
prompt_tokens=prompt_tokens,
|
86 |
+
completion_tokens=completion_tokens,
|
87 |
+
total_tokens=prompt_tokens + completion_tokens,
|
88 |
+
)
|
89 |
+
response_json = CompletionStreamResponse(
|
90 |
+
id=request_id,
|
91 |
+
created=created_time,
|
92 |
+
model=model_name,
|
93 |
+
choices=[
|
94 |
+
CompletionResponseStreamChoice(
|
95 |
+
index=i,
|
96 |
+
text="",
|
97 |
+
logprobs=logprobs,
|
98 |
+
finish_reason=output.finish_reason,
|
99 |
+
)
|
100 |
+
],
|
101 |
+
usage=final_usage,
|
102 |
+
).model_dump_json(exclude_unset=True)
|
103 |
+
yield f"data: {response_json}\n\n"
|
104 |
+
|
105 |
+
yield "data: [DONE]\n\n"
|
106 |
+
|
107 |
+
|
108 |
+
def parse_prompt_format(prompt) -> tuple[bool, list]:
|
109 |
+
# get the prompt, openai supports the following
|
110 |
+
# "a string, array of strings, array of tokens, or array of token arrays."
|
111 |
+
prompt_is_tokens = False
|
112 |
+
prompts = [prompt] # case 1: a string
|
113 |
+
if isinstance(prompt, list):
|
114 |
+
if len(prompt) == 0:
|
115 |
+
raise ValueError("please provide at least one prompt")
|
116 |
+
elif isinstance(prompt[0], str):
|
117 |
+
prompt_is_tokens = False
|
118 |
+
prompts = prompt # case 2: array of strings
|
119 |
+
elif isinstance(prompt[0], int):
|
120 |
+
prompt_is_tokens = True
|
121 |
+
prompts = [prompt] # case 3: array of tokens
|
122 |
+
elif isinstance(prompt[0], list) and isinstance(prompt[0][0], int):
|
123 |
+
prompt_is_tokens = True
|
124 |
+
prompts = prompt # case 4: array of token arrays
|
125 |
+
else:
|
126 |
+
raise ValueError(
|
127 |
+
"prompt must be a string, array of strings, array of tokens, or array of token arrays"
|
128 |
+
)
|
129 |
+
return prompt_is_tokens, prompts
|
130 |
+
|
131 |
+
|
132 |
+
def request_output_to_completion_response(final_res: RequestOutput, request,
|
133 |
+
echo_without_generation,
|
134 |
+
create_logprobs_fn, request_id,
|
135 |
+
created_time,
|
136 |
+
model_name) -> CompletionResponse:
|
137 |
+
assert final_res is not None
|
138 |
+
choices = []
|
139 |
+
prompt_token_ids = final_res.prompt_token_ids
|
140 |
+
prompt_logprobs = final_res.prompt_logprobs
|
141 |
+
prompt_text = final_res.prompt
|
142 |
+
for output in final_res.outputs:
|
143 |
+
if request.logprobs is not None:
|
144 |
+
if not echo_without_generation:
|
145 |
+
token_ids = output.token_ids
|
146 |
+
top_logprobs = output.logprobs
|
147 |
+
if request.echo:
|
148 |
+
token_ids = prompt_token_ids + token_ids
|
149 |
+
top_logprobs = prompt_logprobs + top_logprobs
|
150 |
+
else:
|
151 |
+
token_ids = prompt_token_ids
|
152 |
+
top_logprobs = prompt_logprobs
|
153 |
+
logprobs = create_logprobs_fn(
|
154 |
+
token_ids=token_ids,
|
155 |
+
top_logprobs=top_logprobs,
|
156 |
+
num_output_top_logprobs=request.logprobs,
|
157 |
+
)
|
158 |
+
else:
|
159 |
+
logprobs = None
|
160 |
+
if not echo_without_generation:
|
161 |
+
output_text = output.text
|
162 |
+
if request.echo:
|
163 |
+
output_text = prompt_text + output_text
|
164 |
+
else:
|
165 |
+
output_text = prompt_text
|
166 |
+
choice_data = CompletionResponseChoice(
|
167 |
+
index=output.index,
|
168 |
+
text=output_text,
|
169 |
+
logprobs=logprobs,
|
170 |
+
finish_reason=output.finish_reason,
|
171 |
+
)
|
172 |
+
choices.append(choice_data)
|
173 |
+
|
174 |
+
num_prompt_tokens = len(final_res.prompt_token_ids)
|
175 |
+
num_generated_tokens = sum(
|
176 |
+
len(output.token_ids) for output in final_res.outputs)
|
177 |
+
usage = UsageInfo(
|
178 |
+
prompt_tokens=num_prompt_tokens,
|
179 |
+
completion_tokens=num_generated_tokens,
|
180 |
+
total_tokens=num_prompt_tokens + num_generated_tokens,
|
181 |
+
)
|
182 |
+
|
183 |
+
return CompletionResponse(
|
184 |
+
id=request_id,
|
185 |
+
created=created_time,
|
186 |
+
model=model_name,
|
187 |
+
choices=choices,
|
188 |
+
usage=usage,
|
189 |
+
)
|
190 |
+
|
191 |
+
|
192 |
+
class OpenAIServingCompletion(OpenAIServing):
|
193 |
+
|
194 |
+
def __init__(self, engine: AsyncLLMEngine, served_model: str):
|
195 |
+
super().__init__(engine=engine, served_model=served_model)
|
196 |
+
|
197 |
+
async def create_completion(self, request: CompletionRequest,
|
198 |
+
raw_request: Request):
|
199 |
+
"""Completion API similar to OpenAI's API.
|
200 |
+
|
201 |
+
See https://platform.openai.com/docs/api-reference/completions/create
|
202 |
+
for the API specification. This API mimics the OpenAI Completion API.
|
203 |
+
|
204 |
+
NOTE: Currently we do not support the following features:
|
205 |
+
- suffix (the language models we currently support do not support
|
206 |
+
suffix)
|
207 |
+
- logit_bias (to be supported by vLLM engine)
|
208 |
+
"""
|
209 |
+
error_check_ret = await self._check_model(request)
|
210 |
+
if error_check_ret is not None:
|
211 |
+
return error_check_ret
|
212 |
+
|
213 |
+
# OpenAI API supports echoing the prompt when max_tokens is 0.
|
214 |
+
echo_without_generation = request.echo and request.max_tokens == 0
|
215 |
+
|
216 |
+
# Return error for unsupported features.
|
217 |
+
if request.suffix is not None:
|
218 |
+
return self.create_error_response(
|
219 |
+
"suffix is not currently supported")
|
220 |
+
if request.logit_bias is not None and len(request.logit_bias) > 0:
|
221 |
+
return self.create_error_response(
|
222 |
+
"logit_bias is not currently supported")
|
223 |
+
|
224 |
+
model_name = request.model
|
225 |
+
request_id = f"cmpl-{random_uuid()}"
|
226 |
+
created_time = int(time.monotonic())
|
227 |
+
|
228 |
+
# Schedule the request and get the result generator.
|
229 |
+
try:
|
230 |
+
sampling_params = request.to_sampling_params()
|
231 |
+
|
232 |
+
prompt_is_tokens, prompts = parse_prompt_format(request.prompt)
|
233 |
+
|
234 |
+
if len(prompts) > 1:
|
235 |
+
raise ValueError(
|
236 |
+
"Batching in completion API is not supported.")
|
237 |
+
prompt = prompts[0]
|
238 |
+
|
239 |
+
if prompt_is_tokens:
|
240 |
+
input_ids = self._validate_prompt_and_tokenize(
|
241 |
+
request, prompt_ids=prompt)
|
242 |
+
else:
|
243 |
+
input_ids = self._validate_prompt_and_tokenize(request,
|
244 |
+
prompt=prompt)
|
245 |
+
|
246 |
+
result_generator = self.engine.generate(None,
|
247 |
+
sampling_params,
|
248 |
+
request_id,
|
249 |
+
prompt_token_ids=input_ids)
|
250 |
+
except ValueError as e:
|
251 |
+
return self.create_error_response(str(e))
|
252 |
+
|
253 |
+
# Similar to the OpenAI API, when n != best_of, we do not stream the
|
254 |
+
# results. In addition, we do not stream the results when use beam search.
|
255 |
+
stream = (request.stream
|
256 |
+
and (request.best_of is None or request.n == request.best_of)
|
257 |
+
and not request.use_beam_search)
|
258 |
+
|
259 |
+
# Streaming response
|
260 |
+
if stream:
|
261 |
+
return completion_stream_generator(request, result_generator,
|
262 |
+
echo_without_generation,
|
263 |
+
self._create_logprobs,
|
264 |
+
request_id, created_time,
|
265 |
+
model_name)
|
266 |
+
|
267 |
+
# Non-streaming response
|
268 |
+
final_res: RequestOutput = None
|
269 |
+
async for res in result_generator:
|
270 |
+
if await raw_request.is_disconnected():
|
271 |
+
# Abort the request if the client disconnects.
|
272 |
+
await self.engine.abort(request_id)
|
273 |
+
return self.create_error_response("Client disconnected")
|
274 |
+
final_res = res
|
275 |
+
response = request_output_to_completion_response(
|
276 |
+
final_res, request, echo_without_generation, self._create_logprobs,
|
277 |
+
request_id, created_time, model_name)
|
278 |
+
|
279 |
+
# When user requests streaming but we don't stream, we still need to
|
280 |
+
# return a streaming response with a single event.
|
281 |
+
if request.stream:
|
282 |
+
response_json = response.model_dump_json()
|
283 |
+
|
284 |
+
async def fake_stream_generator() -> AsyncGenerator[str, None]:
|
285 |
+
yield f"data: {response_json}\n\n"
|
286 |
+
yield "data: [DONE]\n\n"
|
287 |
+
|
288 |
+
return fake_stream_generator()
|
289 |
+
|
290 |
+
return response
|
serving_engine.py
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import asyncio
|
2 |
+
from http import HTTPStatus
|
3 |
+
from typing import Dict, List, Optional, Union
|
4 |
+
from vllm.logger import init_logger
|
5 |
+
from vllm.transformers_utils.tokenizer import get_tokenizer
|
6 |
+
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
7 |
+
from protocol import (CompletionRequest,
|
8 |
+
ChatCompletionRequest,
|
9 |
+
ErrorResponse, LogProbs,
|
10 |
+
ModelCard, ModelList,
|
11 |
+
ModelPermission)
|
12 |
+
|
13 |
+
logger = init_logger(__name__)
|
14 |
+
|
15 |
+
|
16 |
+
class OpenAIServing:
|
17 |
+
|
18 |
+
def __init__(self, engine: AsyncLLMEngine, served_model: str):
|
19 |
+
self.engine = engine
|
20 |
+
self.served_model = served_model
|
21 |
+
|
22 |
+
self.max_model_len = 0
|
23 |
+
self.tokenizer = None
|
24 |
+
|
25 |
+
try:
|
26 |
+
event_loop = asyncio.get_running_loop()
|
27 |
+
except RuntimeError:
|
28 |
+
event_loop = None
|
29 |
+
|
30 |
+
if event_loop is not None and event_loop.is_running(
|
31 |
+
): # If the current is instanced by Ray Serve, there is already a running event loop
|
32 |
+
event_loop.create_task(self._post_init())
|
33 |
+
else: # When using single vLLM without engine_use_ray
|
34 |
+
asyncio.run(self._post_init())
|
35 |
+
|
36 |
+
async def _post_init(self):
|
37 |
+
engine_model_config = await self.engine.get_model_config()
|
38 |
+
self.max_model_len = engine_model_config.max_model_len
|
39 |
+
|
40 |
+
# A separate tokenizer to map token IDs to strings.
|
41 |
+
self.tokenizer = get_tokenizer(
|
42 |
+
engine_model_config.tokenizer,
|
43 |
+
tokenizer_mode=engine_model_config.tokenizer_mode,
|
44 |
+
trust_remote_code=engine_model_config.trust_remote_code)
|
45 |
+
|
46 |
+
async def show_available_models(self) -> ModelList:
|
47 |
+
"""Show available models. Right now we only have one model."""
|
48 |
+
model_cards = [
|
49 |
+
ModelCard(id=self.served_model,
|
50 |
+
root=self.served_model,
|
51 |
+
permission=[ModelPermission()])
|
52 |
+
]
|
53 |
+
return ModelList(data=model_cards)
|
54 |
+
|
55 |
+
def _create_logprobs(
|
56 |
+
self,
|
57 |
+
token_ids: List[int],
|
58 |
+
top_logprobs: Optional[List[Optional[Dict[int, float]]]] = None,
|
59 |
+
num_output_top_logprobs: Optional[int] = None,
|
60 |
+
initial_text_offset: int = 0,
|
61 |
+
) -> LogProbs:
|
62 |
+
"""Create OpenAI-style logprobs."""
|
63 |
+
logprobs = LogProbs()
|
64 |
+
last_token_len = 0
|
65 |
+
if num_output_top_logprobs:
|
66 |
+
logprobs.top_logprobs = []
|
67 |
+
for i, token_id in enumerate(token_ids):
|
68 |
+
step_top_logprobs = top_logprobs[i]
|
69 |
+
if step_top_logprobs is not None:
|
70 |
+
token_logprob = step_top_logprobs[token_id]
|
71 |
+
else:
|
72 |
+
token_logprob = None
|
73 |
+
token = self.tokenizer.convert_ids_to_tokens(token_id)
|
74 |
+
logprobs.tokens.append(token)
|
75 |
+
logprobs.token_logprobs.append(token_logprob)
|
76 |
+
if len(logprobs.text_offset) == 0:
|
77 |
+
logprobs.text_offset.append(initial_text_offset)
|
78 |
+
else:
|
79 |
+
logprobs.text_offset.append(logprobs.text_offset[-1] +
|
80 |
+
last_token_len)
|
81 |
+
last_token_len = len(token)
|
82 |
+
|
83 |
+
if num_output_top_logprobs:
|
84 |
+
logprobs.top_logprobs.append({
|
85 |
+
self.tokenizer.convert_ids_to_tokens(i): p
|
86 |
+
for i, p in step_top_logprobs.items()
|
87 |
+
} if step_top_logprobs else None)
|
88 |
+
return logprobs
|
89 |
+
|
90 |
+
def create_error_response(
|
91 |
+
self,
|
92 |
+
message: str,
|
93 |
+
err_type: str = "BadRequestError",
|
94 |
+
status_code: HTTPStatus = HTTPStatus.BAD_REQUEST) -> ErrorResponse:
|
95 |
+
return ErrorResponse(message=message,
|
96 |
+
type=err_type,
|
97 |
+
code=status_code.value)
|
98 |
+
|
99 |
+
async def _check_model(self, request) -> Optional[ErrorResponse]:
|
100 |
+
if request.model == self.served_model:
|
101 |
+
return
|
102 |
+
return self.create_error_response(
|
103 |
+
message=f"The model `{request.model}` does not exist.",
|
104 |
+
err_type="NotFoundError",
|
105 |
+
status_code=HTTPStatus.NOT_FOUND)
|
106 |
+
|
107 |
+
def _validate_prompt_and_tokenize(
|
108 |
+
self,
|
109 |
+
request: Union[ChatCompletionRequest, CompletionRequest],
|
110 |
+
prompt: Optional[str] = None,
|
111 |
+
prompt_ids: Optional[List[int]] = None) -> List[int]:
|
112 |
+
if not (prompt or prompt_ids):
|
113 |
+
raise ValueError("Either prompt or prompt_ids should be provided.")
|
114 |
+
if (prompt and prompt_ids):
|
115 |
+
raise ValueError(
|
116 |
+
"Only one of prompt or prompt_ids should be provided.")
|
117 |
+
|
118 |
+
input_ids = prompt_ids if prompt_ids is not None else self.tokenizer(
|
119 |
+
prompt).input_ids
|
120 |
+
token_num = len(input_ids)
|
121 |
+
|
122 |
+
if request.max_tokens is None:
|
123 |
+
request.max_tokens = self.max_model_len - token_num
|
124 |
+
|
125 |
+
if token_num + request.max_tokens > self.max_model_len:
|
126 |
+
raise ValueError(
|
127 |
+
f"This model's maximum context length is {self.max_model_len} tokens. "
|
128 |
+
f"However, you requested {request.max_tokens + token_num} tokens "
|
129 |
+
f"({token_num} in the messages, "
|
130 |
+
f"{request.max_tokens} in the completion). "
|
131 |
+
f"Please reduce the length of the messages or completion.", )
|
132 |
+
else:
|
133 |
+
return input_ids
|