Spaces:
Runtime error
Runtime error
"""A server that provides OpenAI-compatible RESTful APIs. It supports: | |
- Chat Completions. (Reference: https://platform.openai.com/docs/api-reference/chat)
- Completions. (Reference: https://platform.openai.com/docs/api-reference/completions)
- Embeddings. (Reference: https://platform.openai.com/docs/api-reference/embeddings)

Usage:
    python3 -m fastchat.serve.openai_api_server
"""
import argparse
import asyncio
import json
import logging
import os
import secrets
from typing import AsyncGenerator, Generator, Optional, Union, Dict, List, Any

import aiohttp
import fastapi
from fastapi import Depends, HTTPException
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBearer
import httpx
from pydantic import BaseSettings
import shortuuid
import tiktoken
import uvicorn

from fastchat.constants import (
    WORKER_API_TIMEOUT,
    WORKER_API_EMBEDDING_BATCH_SIZE,
    ErrorCode,
)
from fastchat.conversation import Conversation, SeparatorStyle
from fastchat.protocol.openai_api_protocol import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionResponseStreamChoice,
    ChatCompletionStreamResponse,
    ChatMessage,
    ChatCompletionResponseChoice,
    CompletionRequest,
    CompletionResponse,
    CompletionResponseChoice,
    DeltaMessage,
    CompletionResponseStreamChoice,
    CompletionStreamResponse,
    EmbeddingsRequest,
    EmbeddingsResponse,
    ErrorResponse,
    LogProbs,
    ModelCard,
    ModelList,
    ModelPermission,
    UsageInfo,
)
from fastchat.protocol.api_protocol import (
    APIChatCompletionRequest,
    APITokenCheckRequest,
    APITokenCheckResponse,
    APITokenCheckResponseItem,
)
# Logger for this module.
logger = logging.getLogger(__name__)

# Conversation templates already fetched from workers, keyed by
# (worker_addr, model_name); filled lazily by get_conv().
conv_template_map = {}

# Total time budget for each fetch_remote() request: three hours, in seconds.
fetch_timeout = aiohttp.ClientTimeout(total=3 * 60 * 60)
async def fetch_remote(url, pload=None, name=None):
    """POST ``pload`` as JSON to ``url`` and return the (decoded) reply.

    Args:
        url: Endpoint to call (a controller or worker route).
        pload: JSON-serializable request body, or ``None``.
        name: Controls post-processing of a successful (HTTP 200) reply:
            ``None`` returns the raw response bytes; ``""`` returns the whole
            JSON-decoded body; any other string returns ``body[name]``.

    Returns:
        On HTTP 200, the value selected by ``name`` as described above.
        On any other status, an error payload
        ``{"text": <reason>, "error_code": ErrorCode.INTERNAL_ERROR}``:
        returned as a dict when ``name == ""`` (those callers index the
        result directly, e.g. ``content["error_code"]``), and JSON-encoded
        as a string otherwise.

    Raises:
        json.JSONDecodeError: if ``name`` is not None and the body is not JSON.
        KeyError: if ``name`` names a field missing from the JSON body.
    """
    async with aiohttp.ClientSession(timeout=fetch_timeout) as session:
        async with session.post(url, json=pload) as response:
            if response.status != 200:
                ret = {
                    "text": f"{response.reason}",
                    "error_code": ErrorCode.INTERNAL_ERROR,
                }
                if name == "":
                    # Callers asking for the whole body treat it as a dict;
                    # a JSON string here would make them fail with TypeError.
                    return ret
                # NOTE(review): callers asking for a specific field still get
                # a JSON string on failure and will not find that field.
                return json.dumps(ret)
            chunks = [chunk async for chunk, _ in response.content.iter_chunks()]
    output = b"".join(chunks)

    if name is None:
        return output
    res = json.loads(output)
    return res if name == "" else res[name]
class AppSettings(BaseSettings):
    """Runtime settings for the API server.

    Both fields are overwritten from command-line arguments in
    create_openai_api_server().
    """

    # The address of the model controller.
    controller_address: str = "http://localhost:21001"
    # Accepted bearer tokens; None or an empty list disables the API-key check
    # performed by check_api_key().
    api_keys: Optional[List[str]] = None


# Settings instance read by the request handlers in this module.
app_settings = AppSettings()
# The FastAPI application; CORS middleware is attached in create_openai_api_server().
app = fastapi.FastAPI()
# Request headers sent to model workers by generate_completion_stream().
headers = {"User-Agent": "FastChat API Server"}
# Bearer-token extractor used by check_api_key(); auto_error=False, so
# check_api_key() is responsible for handling a missing credential (None).
get_bearer_token = HTTPBearer(auto_error=False)
async def check_api_key(
    auth: Optional[HTTPAuthorizationCredentials] = Depends(get_bearer_token),
) -> Optional[str]:
    """FastAPI dependency enforcing bearer-token authentication.

    When ``app_settings.api_keys`` is unset or empty, every request is
    allowed and ``None`` is returned. Otherwise the request must carry a
    bearer token equal to one of the configured keys; that token is returned.

    Raises:
        HTTPException: 401 with an OpenAI-style ``invalid_api_key`` error body
            when the credential is missing or not recognised.
    """
    if not app_settings.api_keys:
        # api_keys not set; allow all
        return None

    token = auth.credentials if auth is not None else None
    # compare_digest runs in constant time, so response timing does not
    # reveal how much of a submitted key matches a configured one.
    if token is None or not any(
        secrets.compare_digest(token.encode(), key.encode())
        for key in app_settings.api_keys
    ):
        raise HTTPException(
            status_code=401,
            detail={
                "error": {
                    "message": "Incorrect or missing API key.",
                    "type": "invalid_request_error",
                    "param": None,
                    "code": "invalid_api_key",
                }
            },
        )
    return token
def create_error_response(code: int, message: str) -> JSONResponse:
    """Build a JSON error reply carrying ``code`` and ``message``.

    The body is an ``ErrorResponse``; the HTTP status is always 400,
    whatever ``code`` is.
    """
    body = ErrorResponse(message=message, code=code).dict()
    return JSONResponse(body, status_code=400)
async def validation_exception_handler(request, exc):
    """Report a request-validation failure as a VALIDATION_TYPE_ERROR reply.

    ``request`` is unused; the exception's text becomes the error message.
    """
    # NOTE(review): no exception-handler registration is visible in this copy
    # of the file; confirm how this handler is installed on the app.
    return create_error_response(
        code=ErrorCode.VALIDATION_TYPE_ERROR,
        message=str(exc),
    )
async def check_model(request) -> Optional[JSONResponse]:
    """Return an error reply if ``request.model`` is not served, else ``None``.

    The model list is fetched from the controller's ``/list_models``
    endpoint on every call.
    """
    available = await fetch_remote(
        app_settings.controller_address + "/list_models", None, "models"
    )
    if request.model in available:
        return None
    return create_error_response(
        ErrorCode.INVALID_MODEL,
        f"Only {'&&'.join(available)} allowed now, your model {request.model}",
    )
async def check_length(request, prompt, max_tokens, worker_addr):
    """Cap ``max_tokens`` by the room left in the model's context window.

    The worker at ``worker_addr`` is asked for the context length of
    ``request.model`` and for the token count of ``prompt``; the result is
    ``min(max_tokens, context_len - token_num)``.

    NOTE(review): the result is zero or negative when the prompt alone fills
    the context window, and callers in this file forward it without a
    positivity check — confirm how workers treat such a value.
    """
    # The model worker does not support max_tokens=None, so a missing,
    # non-integer or non-positive limit becomes a large ceiling and the
    # context window is the effective limit.
    if not (isinstance(max_tokens, int) and max_tokens > 0):
        max_tokens = 1024 * 1024

    details_url = worker_addr + "/model_details"
    context_len = await fetch_remote(
        details_url, {"model": request.model}, "context_length"
    )
    count_url = worker_addr + "/count_token"
    token_num = await fetch_remote(
        count_url, {"model": request.model, "prompt": prompt}, "count"
    )
    remaining = context_len - token_num
    return min(max_tokens, remaining)
def check_requests(request) -> Optional[JSONResponse]:
    """Validate the sampling parameters of a completion-style request.

    Parameters left as ``None`` are not checked. Returns a
    PARAM_OUT_OF_RANGE error reply for the first offending parameter, or
    ``None`` when every parameter is acceptable.
    """
    # Check all params
    if request.max_tokens is not None and request.max_tokens <= 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.max_tokens} is less than the minimum of 1 - 'max_tokens'",
        )
    if request.n is not None and request.n <= 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.n} is less than the minimum of 1 - 'n'",
        )
    if request.temperature is not None and request.temperature < 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.temperature} is less than the minimum of 0 - 'temperature'",
        )
    if request.temperature is not None and request.temperature > 2:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.temperature} is greater than the maximum of 2 - 'temperature'",
        )
    if request.top_p is not None and request.top_p < 0:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.top_p} is less than the minimum of 0 - 'top_p'",
        )
    if request.top_p is not None and request.top_p > 1:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.top_p} is greater than the maximum of 1 - 'top_p'",
        )
    # Exactly -1 or any value >= 1 is accepted, as the error message states;
    # previously values below -1 slipped through.
    if request.top_k is not None and request.top_k != -1 and request.top_k < 1:
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.top_k} is out of Range. Either set top_k to -1 or >=1.",
        )
    if request.stop is not None and not isinstance(request.stop, (str, list)):
        return create_error_response(
            ErrorCode.PARAM_OUT_OF_RANGE,
            f"{request.stop} is not valid under any of the given schemas - 'stop'",
        )

    return None
def process_input(model_name, inp):
    """Normalize a prompt/input field to a list of strings.

    Accepted forms:
      * a single string -> ``[inp]``
      * a list of strings -> returned unchanged
      * a list of token ids (one prompt) -> one decoded string
      * a list of token-id lists (several prompts) -> one string per list
      * an empty list -> returned unchanged (nothing to decode)
    Any other value is returned as-is.

    Token ids are decoded with ``tiktoken.model.encoding_for_model(model_name)``.
    NOTE(review): tiktoken only maps OpenAI model names to encodings; confirm
    the behavior for other model names.
    """
    if isinstance(inp, str):
        return [inp]
    if not isinstance(inp, list) or not inp:
        # An empty list previously crashed on inp[0].
        return inp

    first = inp[0]
    if isinstance(first, int):
        encoding = tiktoken.model.encoding_for_model(model_name)
        return [encoding.decode(inp)]
    if isinstance(first, list):
        encoding = tiktoken.model.encoding_for_model(model_name)
        return [encoding.decode(tokens) for tokens in inp]
    return inp
def create_openai_logprobs(logprob_dict):
    """Convert a worker's logprob dict into an OpenAI ``LogProbs`` model.

    Returns ``None`` when no logprobs were supplied.
    """
    if logprob_dict is None:
        return None
    return LogProbs(**logprob_dict)
def _add_to_set(s, new_stop): | |
if not s: | |
return | |
if isinstance(s, str): | |
new_stop.add(s) | |
else: | |
new_stop.update(s) | |
def _conversation_from_template(template: Dict[str, Any]):
    """Build a fresh ``Conversation`` from a worker-supplied template dict."""
    return Conversation(
        name=template["name"],
        system_template=template["system_template"],
        system_message=template["system_message"],
        roles=template["roles"],
        # Copy so that appending turns never mutates the cached template.
        messages=list(template["messages"]),
        offset=template["offset"],
        sep_style=SeparatorStyle(template["sep_style"]),
        sep=template["sep"],
        sep2=template["sep2"],
        stop_str=template["stop_str"],
        stop_token_ids=template["stop_token_ids"],
    )


async def get_gen_params(
    model_name: str,
    worker_addr: str,
    messages: Union[str, List[Dict[str, str]]],
    *,
    temperature: float,
    top_p: float,
    top_k: Optional[int],
    presence_penalty: Optional[float],
    frequency_penalty: Optional[float],
    max_tokens: Optional[int],
    echo: Optional[bool],
    logprobs: Optional[int] = None,
    stop: Optional[Union[str, List[str]]],
    best_of: Optional[int] = None,
    use_beam_search: Optional[bool] = None,
) -> Dict[str, Any]:
    """Build the generation payload sent to a model worker.

    ``messages`` is either a raw prompt string (used verbatim) or a list of
    ``{"role", "content"}`` dicts that is rendered with the worker's
    conversation template for ``model_name``.

    Returns:
        A dict with the prompt and sampling parameters (``max_tokens`` is
        stored as ``"max_new_tokens"``). ``"stop"`` is the union of ``stop``
        and the template's ``stop_str``, as a list in unspecified order.
        ``best_of`` / ``use_beam_search`` are included only when not None.

    Raises:
        ValueError: if a message has a role other than system/user/assistant.
    """
    conv = _conversation_from_template(await get_conv(model_name, worker_addr))

    if isinstance(messages, str):
        # A raw prompt is sent as-is; the template only contributes its
        # stop settings below.
        prompt = messages
    else:
        for message in messages:
            msg_role = message["role"]
            if msg_role == "system":
                conv.set_system_message(message["content"])
            elif msg_role == "user":
                conv.append_message(conv.roles[0], message["content"])
            elif msg_role == "assistant":
                conv.append_message(conv.roles[1], message["content"])
            else:
                raise ValueError(f"Unknown role: {msg_role}")

        # Add a blank message for the assistant.
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

    gen_params = {
        "model": model_name,
        "prompt": prompt,
        "temperature": temperature,
        "logprobs": logprobs,
        "top_p": top_p,
        "top_k": top_k,
        "presence_penalty": presence_penalty,
        "frequency_penalty": frequency_penalty,
        "max_new_tokens": max_tokens,
        "echo": echo,
        "stop_token_ids": conv.stop_token_ids,
    }

    if best_of is not None:
        gen_params.update({"best_of": best_of})
    if use_beam_search is not None:
        gen_params.update({"use_beam_search": use_beam_search})

    new_stop = set()
    _add_to_set(stop, new_stop)
    _add_to_set(conv.stop_str, new_stop)

    gen_params["stop"] = list(new_stop)

    # Lazy %-formatting: the (possibly large) dict is only rendered when
    # debug logging is actually enabled.
    logger.debug("==== request ====\n%s", gen_params)
    return gen_params
async def get_worker_address(model_name: str) -> str:
    """
    Ask the controller which worker serves ``model_name``.
    :param model_name: Name of the model the worker must serve
    :return: The worker address reported by the controller
    :raises: :class:`ValueError`: the controller reported no worker (empty address)
    """
    endpoint = app_settings.controller_address + "/get_worker_address"
    worker_addr = await fetch_remote(endpoint, {"model": model_name}, "address")

    # An empty address means no worker is available for this model.
    if worker_addr == "":
        raise ValueError(f"No available worker for {model_name}")

    logger.debug("model_name: %s, worker_addr: %s", model_name, worker_addr)
    return worker_addr
async def get_conv(model_name: str, worker_addr: str):
    """Return the conversation template dict for ``model_name`` on ``worker_addr``.

    Templates are cached in ``conv_template_map`` under
    ``(worker_addr, model_name)``. Only successfully decoded templates
    (dicts) are cached. A failed fetch is therefore retried on the next
    request instead of being served from the cache forever.
    """
    cache_key = (worker_addr, model_name)
    conv_template = conv_template_map.get(cache_key)
    if conv_template is None:
        conv_template = await fetch_remote(
            worker_addr + "/worker_get_conv_template", {"model": model_name}, "conv"
        )
        if isinstance(conv_template, dict):
            conv_template_map[cache_key] = conv_template
    return conv_template
async def show_available_models():
    """List the models known to the controller as an OpenAI ``ModelList``.

    First POSTs to the controller's ``/refresh_all_workers`` endpoint, then
    reads ``/list_models`` and returns one ``ModelCard`` per model name,
    sorted by name.
    """
    # NOTE(review): no route decorator is visible in this copy of the file;
    # confirm how this handler is registered.
    controller_address = app_settings.controller_address
    # Called for its effect on the controller; the reply is not used.
    await fetch_remote(controller_address + "/refresh_all_workers")
    models = await fetch_remote(controller_address + "/list_models", None, "models")
    models.sort()

    # TODO: return real model permission details
    model_cards = [
        ModelCard(id=m, root=m, permission=[ModelPermission()]) for m in models
    ]
    return ModelList(data=model_cards)
async def create_chat_completion(request: ChatCompletionRequest):
    """Handle an OpenAI-compatible chat completion request.

    Validates the model and sampling parameters, renders the messages into a
    prompt, and caps ``max_new_tokens`` to the remaining context. It then
    either streams the answer (``request.stream``) or runs ``request.n``
    generations concurrently and returns a ``ChatCompletionResponse`` with
    summed token usage.

    NOTE(review): a later function in this file has the same name and
    rebinds the module-level name; no route decorator is visible in this
    copy — confirm how this handler is registered.
    """
    model_error = await check_model(request)
    if model_error is not None:
        return model_error
    param_error = check_requests(request)
    if param_error is not None:
        return param_error

    worker_addr = await get_worker_address(request.model)
    gen_params = await get_gen_params(
        request.model,
        worker_addr,
        request.messages,
        temperature=request.temperature,
        top_p=request.top_p,
        top_k=request.top_k,
        presence_penalty=request.presence_penalty,
        frequency_penalty=request.frequency_penalty,
        max_tokens=request.max_tokens,
        echo=False,
        stop=request.stop,
    )
    gen_params["max_new_tokens"] = await check_length(
        request, gen_params["prompt"], gen_params["max_new_tokens"], worker_addr
    )

    if request.stream:
        return StreamingResponse(
            chat_completion_stream_generator(
                request.model, gen_params, request.n, worker_addr
            ),
            media_type="text/event-stream",
        )

    pending = [
        asyncio.create_task(generate_completion(gen_params, worker_addr))
        for _ in range(request.n)
    ]
    try:
        results = await asyncio.gather(*pending)
    except Exception as e:
        return create_error_response(ErrorCode.INTERNAL_ERROR, str(e))

    choices = []
    usage = UsageInfo()
    for index, result in enumerate(results):
        if result["error_code"] != 0:
            return create_error_response(result["error_code"], result["text"])
        choices.append(
            ChatCompletionResponseChoice(
                index=index,
                message=ChatMessage(role="assistant", content=result["text"]),
                finish_reason=result.get("finish_reason", "stop"),
            )
        )
        if "usage" in result:
            # Accumulate every usage counter across the n generations.
            for key, value in UsageInfo.parse_obj(result["usage"]).dict().items():
                setattr(usage, key, getattr(usage, key) + value)

    return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
async def chat_completion_stream_generator(
    model_name: str, gen_params: Dict[str, Any], n: int, worker_addr: str
) -> AsyncGenerator[str, None]:
    """Stream ``n`` chat completions as server-sent events.

    Event stream format:
    https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format

    The ``n`` generations run one after another against ``worker_addr``.
    For each one, a first chunk announces the assistant role. After that,
    every non-empty text delta is sent as its own chunk.

    Chunks that carry only a ``finish_reason`` (no new text) are held back.
    They are emitted after all generations finish, followed by
    ``data: [DONE]``. A worker error is forwarded verbatim and ends the
    stream.
    """
    # Named chunk_id rather than id to avoid shadowing the builtin.
    chunk_id = f"chatcmpl-{shortuuid.random()}"
    finish_stream_events = []
    for i in range(n):
        # First chunk with role
        choice_data = ChatCompletionResponseStreamChoice(
            index=i,
            delta=DeltaMessage(role="assistant"),
            finish_reason=None,
        )
        chunk = ChatCompletionStreamResponse(
            id=chunk_id, choices=[choice_data], model=model_name
        )
        yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"

        previous_text = ""
        async for content in generate_completion_stream(gen_params, worker_addr):
            if content["error_code"] != 0:
                yield f"data: {json.dumps(content, ensure_ascii=False)}\n\n"
                yield "data: [DONE]\n\n"
                return
            # content["text"] is treated as cumulative output; the delta is
            # whatever extends past previous_text. U+FFFD replacement
            # characters are dropped before comparing.
            decoded_unicode = content["text"].replace("\ufffd", "")
            delta_text = decoded_unicode[len(previous_text) :]
            previous_text = (
                decoded_unicode
                if len(decoded_unicode) > len(previous_text)
                else previous_text
            )

            if len(delta_text) == 0:
                delta_text = None
            choice_data = ChatCompletionResponseStreamChoice(
                index=i,
                delta=DeltaMessage(content=delta_text),
                finish_reason=content.get("finish_reason", None),
            )
            chunk = ChatCompletionStreamResponse(
                id=chunk_id, choices=[choice_data], model=model_name
            )
            if delta_text is None:
                if content.get("finish_reason", None) is not None:
                    finish_stream_events.append(chunk)
                continue
            yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
    # Held-back finish chunks have content=None; exclude_none drops that
    # field so the final delta carries no "content" key.
    for finish_chunk in finish_stream_events:
        yield f"data: {finish_chunk.json(exclude_none=True, ensure_ascii=False)}\n\n"
    yield "data: [DONE]\n\n"
async def create_completion(request: CompletionRequest):
    """Handle an OpenAI-compatible text completion request.

    ``request.prompt`` may hold one or several prompts (strings or token
    ids); each prompt gets ``request.n`` generations. ``request.max_tokens``
    is lowered where necessary so that every prompt fits the model's context
    window (and filled in when the client sent none).

    Returns a streaming response when ``request.stream`` is set. Otherwise
    returns a ``CompletionResponse`` whose choices are ordered prompt by
    prompt, ``n`` per prompt.
    """
    # NOTE(review): no route decorator is visible in this copy of the file;
    # confirm how this handler is registered.
    error_check_ret = await check_model(request)
    if error_check_ret is not None:
        return error_check_ret
    error_check_ret = check_requests(request)
    if error_check_ret is not None:
        return error_check_ret

    request.prompt = process_input(request.model, request.prompt)

    worker_addr = await get_worker_address(request.model)
    for text in request.prompt:
        max_tokens = await check_length(request, text, request.max_tokens, worker_addr)
        # request.max_tokens may be None when the client sent null; comparing
        # an int with None raised TypeError, so adopt the computed cap instead.
        if isinstance(max_tokens, int) and (
            request.max_tokens is None or max_tokens < request.max_tokens
        ):
            request.max_tokens = max_tokens

    if request.stream:
        generator = generate_completion_stream_generator(
            request, request.n, worker_addr
        )
        return StreamingResponse(generator, media_type="text/event-stream")

    text_completions = []
    for text in request.prompt:
        gen_params = await get_gen_params(
            request.model,
            worker_addr,
            text,
            temperature=request.temperature,
            top_p=request.top_p,
            top_k=request.top_k,
            frequency_penalty=request.frequency_penalty,
            presence_penalty=request.presence_penalty,
            max_tokens=request.max_tokens,
            logprobs=request.logprobs,
            echo=request.echo,
            stop=request.stop,
            best_of=request.best_of,
            use_beam_search=request.use_beam_search,
        )
        for _ in range(request.n):
            text_completions.append(
                asyncio.create_task(generate_completion(gen_params, worker_addr))
            )

    try:
        all_tasks = await asyncio.gather(*text_completions)
    except Exception as e:
        return create_error_response(ErrorCode.INTERNAL_ERROR, str(e))

    choices = []
    usage = UsageInfo()
    for i, content in enumerate(all_tasks):
        if content["error_code"] != 0:
            return create_error_response(content["error_code"], content["text"])
        choices.append(
            CompletionResponseChoice(
                index=i,
                text=content["text"],
                logprobs=create_openai_logprobs(content.get("logprobs", None)),
                finish_reason=content.get("finish_reason", "stop"),
            )
        )
        # Same guard as the chat handler: a worker reply without "usage"
        # previously raised KeyError here.
        if "usage" in content:
            task_usage = UsageInfo.parse_obj(content["usage"])
            for usage_key, usage_value in task_usage.dict().items():
                setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)

    return CompletionResponse(model=request.model, choices=choices, usage=usage)
async def generate_completion_stream_generator(
    request: CompletionRequest, n: int, worker_addr: str
) -> AsyncGenerator[str, None]:
    """Stream completions for every prompt in ``request.prompt`` as SSE chunks.

    Each prompt gets ``n`` sequential generations. A choice's index is
    ``prompt_index * n + i``, which matches the order in which the
    non-streaming path enumerates its results; previously the index
    restarted at 0 for every prompt.

    Chunks with no new text but a ``finish_reason`` are held back and sent
    after all generations finish, followed by ``data: [DONE]``. A worker
    error is forwarded verbatim and ends the stream.
    """
    model_name = request.model
    # Named chunk_id rather than id to avoid shadowing the builtin.
    chunk_id = f"cmpl-{shortuuid.random()}"
    finish_stream_events = []
    for prompt_index, text in enumerate(request.prompt):
        # The generation parameters depend only on the prompt, so build them
        # once per prompt rather than once per generation.
        gen_params = await get_gen_params(
            request.model,
            worker_addr,
            text,
            temperature=request.temperature,
            top_p=request.top_p,
            top_k=request.top_k,
            presence_penalty=request.presence_penalty,
            frequency_penalty=request.frequency_penalty,
            max_tokens=request.max_tokens,
            logprobs=request.logprobs,
            echo=request.echo,
            stop=request.stop,
        )
        for i in range(n):
            previous_text = ""
            async for content in generate_completion_stream(gen_params, worker_addr):
                if content["error_code"] != 0:
                    yield f"data: {json.dumps(content, ensure_ascii=False)}\n\n"
                    yield "data: [DONE]\n\n"
                    return
                # content["text"] is treated as cumulative output; the delta
                # is whatever extends past previous_text.
                decoded_unicode = content["text"].replace("\ufffd", "")
                delta_text = decoded_unicode[len(previous_text) :]
                previous_text = (
                    decoded_unicode
                    if len(decoded_unicode) > len(previous_text)
                    else previous_text
                )
                choice_data = CompletionResponseStreamChoice(
                    index=prompt_index * n + i,
                    text=delta_text,
                    logprobs=create_openai_logprobs(content.get("logprobs", None)),
                    finish_reason=content.get("finish_reason", None),
                )
                chunk = CompletionStreamResponse(
                    id=chunk_id,
                    object="text_completion",
                    choices=[choice_data],
                    model=model_name,
                )
                if len(delta_text) == 0:
                    if content.get("finish_reason", None) is not None:
                        finish_stream_events.append(chunk)
                    continue
                yield f"data: {chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
    # Finish-only chunks (empty text plus finish_reason) are sent last.
    for finish_chunk in finish_stream_events:
        yield f"data: {finish_chunk.json(exclude_unset=True, ensure_ascii=False)}\n\n"
    yield "data: [DONE]\n\n"
async def generate_completion_stream(payload: Dict[str, Any], worker_addr: str):
    """Yield the JSON messages a worker streams from ``/worker_generate_stream``.

    The response body is a sequence of JSON objects separated by NUL bytes;
    each complete object is decoded and yielded as a dict.

    If the worker replies with a non-200 status, a single error message
    ``{"text": <reason>, "error_code": ErrorCode.INTERNAL_ERROR}`` is yielded
    instead, so consumers report it through their usual ``error_code``
    check. Previously such a reply was silently dropped.
    """
    async with httpx.AsyncClient() as client:
        delimiter = b"\0"
        async with client.stream(
            "POST",
            worker_addr + "/worker_generate_stream",
            headers=headers,
            json=payload,
            timeout=WORKER_API_TIMEOUT,
        ) as response:
            if response.status_code != 200:
                yield {
                    "text": f"{response.reason_phrase}",
                    "error_code": ErrorCode.INTERNAL_ERROR,
                }
                return
            buffer = b""
            async for raw_chunk in response.aiter_raw():
                buffer += raw_chunk
                # One read may contain several messages or only part of one:
                # emit every complete message and keep the remainder buffered.
                while (chunk_end := buffer.find(delimiter)) >= 0:
                    chunk, buffer = buffer[:chunk_end], buffer[chunk_end + 1 :]
                    if not chunk:
                        continue
                    yield json.loads(chunk.decode())
async def generate_completion(payload: Dict[str, Any], worker_addr: str):
    """Run one non-streaming generation on ``worker_addr``.

    Returns the worker's ``/worker_generate`` reply decoded as JSON
    (``fetch_remote`` with ``name=""``).
    """
    endpoint = worker_addr + "/worker_generate"
    result = await fetch_remote(endpoint, payload, "")
    return result
async def create_embeddings(request: EmbeddingsRequest, model_name: str = None):
    """Return OpenAI-style embeddings for ``request.input``.

    ``model_name`` fills in ``request.model`` when the body omits it. Inputs
    are sent to the worker in batches of ``WORKER_API_EMBEDDING_BATCH_SIZE``.
    Each result's ``index`` is its position in the normalized input list.
    Prompt and total token usage are both the summed ``token_num`` of all
    batches.
    """
    # NOTE(review): no route decorator is visible in this copy of the file;
    # confirm how this handler is registered.
    if request.model is None:
        request.model = model_name
    model_error = await check_model(request)
    if model_error is not None:
        return model_error

    request.input = process_input(request.model, request.input)
    inputs = request.input
    step = WORKER_API_EMBEDDING_BATCH_SIZE

    data = []
    token_num = 0
    for start in range(0, len(inputs), step):
        embedding = await get_embedding(
            {
                "model": request.model,
                "input": inputs[start : start + step],
                "encoding_format": request.encoding_format,
            }
        )
        if "error_code" in embedding and embedding["error_code"] != 0:
            return create_error_response(embedding["error_code"], embedding["text"])
        data.extend(
            {
                "object": "embedding",
                "embedding": emb,
                "index": start + offset,
            }
            for offset, emb in enumerate(embedding["embedding"])
        )
        token_num += embedding["token_num"]

    response = EmbeddingsResponse(
        data=data,
        model=request.model,
        usage=UsageInfo(
            prompt_tokens=token_num,
            total_tokens=token_num,
            completion_tokens=None,
        ),
    )
    return response.dict(exclude_none=True)
async def get_embedding(payload: Dict[str, Any]):
    """Fetch embeddings for ``payload`` from a worker serving ``payload["model"]``.

    Returns the worker's ``/worker_get_embeddings`` reply decoded from JSON.
    On an HTTP error this is the error dict produced by ``fetch_remote``.
    """
    worker_addr = await get_worker_address(payload["model"])
    embedding = await fetch_remote(worker_addr + "/worker_get_embeddings", payload)
    return json.loads(embedding)
### GENERAL API - NOT OPENAI COMPATIBLE ### | |
async def count_tokens(request: APITokenCheckRequest):
    """
    Report, for each prompt, its token count and whether it fits the context.

    This is not part of the OpenAI API spec. For each item, the worker
    serving ``item.model`` is asked for the context length and the prompt's
    token count. ``fits`` is false when ``tokenCount + item.max_tokens``
    exceeds the context length.
    """
    # NOTE(review): no route decorator is visible in this copy of the file;
    # confirm how this handler is registered.
    results = []
    for item in request.prompts:
        worker_addr = await get_worker_address(item.model)
        query = {"prompt": item.prompt, "model": item.model}
        context_len = await fetch_remote(
            worker_addr + "/model_details", query, "context_length"
        )
        token_num = await fetch_remote(worker_addr + "/count_token", query, "count")
        results.append(
            APITokenCheckResponseItem(
                fits=not (token_num + item.max_tokens > context_len),
                contextLength=context_len,
                tokenCount=token_num,
            )
        )
    return APITokenCheckResponse(prompts=results)
async def create_chat_completion(request: APIChatCompletionRequest):
    """Handle a chat completion request of the general (non-OpenAI) API.

    The flow matches the OpenAI-compatible chat handler, plus an optional
    ``repetition_penalty`` that is forwarded to the worker when set.

    NOTE(review): this function reuses the name of the OpenAI-compatible
    handler defined earlier in this file and rebinds the module-level name;
    no route decorator is visible in this copy — confirm registration.
    """
    error_check_ret = await check_model(request)
    if error_check_ret is not None:
        return error_check_ret
    error_check_ret = check_requests(request)
    if error_check_ret is not None:
        return error_check_ret

    worker_addr = await get_worker_address(request.model)
    gen_params = await get_gen_params(
        request.model,
        worker_addr,
        request.messages,
        temperature=request.temperature,
        top_p=request.top_p,
        top_k=request.top_k,
        presence_penalty=request.presence_penalty,
        frequency_penalty=request.frequency_penalty,
        max_tokens=request.max_tokens,
        echo=False,
        stop=request.stop,
    )

    if request.repetition_penalty is not None:
        gen_params["repetition_penalty"] = request.repetition_penalty

    gen_params["max_new_tokens"] = await check_length(
        request,
        gen_params["prompt"],
        gen_params["max_new_tokens"],
        worker_addr,
    )

    if request.stream:
        generator = chat_completion_stream_generator(
            request.model, gen_params, request.n, worker_addr
        )
        return StreamingResponse(generator, media_type="text/event-stream")

    chat_completions = [
        asyncio.create_task(generate_completion(gen_params, worker_addr))
        for _ in range(request.n)
    ]
    try:
        all_tasks = await asyncio.gather(*chat_completions)
    except Exception as e:
        return create_error_response(ErrorCode.INTERNAL_ERROR, str(e))

    choices = []
    usage = UsageInfo()
    for i, content in enumerate(all_tasks):
        if content["error_code"] != 0:
            return create_error_response(content["error_code"], content["text"])
        choices.append(
            ChatCompletionResponseChoice(
                index=i,
                message=ChatMessage(role="assistant", content=content["text"]),
                finish_reason=content.get("finish_reason", "stop"),
            )
        )
        # Guard like the OpenAI-compatible handler: a reply without "usage"
        # previously raised KeyError here.
        if "usage" in content:
            task_usage = UsageInfo.parse_obj(content["usage"])
            for usage_key, usage_value in task_usage.dict().items():
                setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)

    return ChatCompletionResponse(model=request.model, choices=choices, usage=usage)
### END GENERAL API - NOT OPENAI COMPATIBLE ### | |
def create_openai_api_server():
    """Parse command-line arguments and configure the app from them.

    Installs CORS middleware on ``app``, copies the controller address and
    API keys into ``app_settings``, logs the arguments with API keys
    redacted, and returns the parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description="FastChat ChatGPT-Compatible RESTful API server."
    )
    parser.add_argument("--host", type=str, default="localhost", help="host name")
    parser.add_argument("--port", type=int, default=8000, help="port number")
    parser.add_argument(
        "--controller-address", type=str, default="http://localhost:21001"
    )
    parser.add_argument(
        "--allow-credentials", action="store_true", help="allow credentials"
    )
    parser.add_argument(
        "--allowed-origins", type=json.loads, default=["*"], help="allowed origins"
    )
    parser.add_argument(
        "--allowed-methods", type=json.loads, default=["*"], help="allowed methods"
    )
    parser.add_argument(
        "--allowed-headers", type=json.loads, default=["*"], help="allowed headers"
    )
    parser.add_argument(
        "--api-keys",
        type=lambda s: s.split(","),
        help="Optional list of comma separated API keys",
    )
    parser.add_argument(
        "--ssl",
        action="store_true",
        required=False,
        default=False,
        help="Enable SSL. Requires OS Environment variables 'SSL_KEYFILE' and 'SSL_CERTFILE'.",
    )
    args = parser.parse_args()

    app.add_middleware(
        CORSMiddleware,
        allow_origins=args.allowed_origins,
        allow_credentials=args.allow_credentials,
        allow_methods=args.allowed_methods,
        allow_headers=args.allowed_headers,
    )
    app_settings.controller_address = args.controller_address
    app_settings.api_keys = args.api_keys

    # args.api_keys holds secrets; log a copy with the keys masked so they do
    # not end up in plaintext log files.
    printable_args = argparse.Namespace(**vars(args))
    if printable_args.api_keys:
        printable_args.api_keys = ["<redacted>"] * len(printable_args.api_keys)
    logger.info("args: %s", printable_args)
    return args
if __name__ == "__main__":
    args = create_openai_api_server()
    # With --ssl, the key and certificate paths come from the environment
    # (a missing variable raises KeyError before the server starts).
    ssl_options = {}
    if args.ssl:
        ssl_options = {
            "ssl_keyfile": os.environ["SSL_KEYFILE"],
            "ssl_certfile": os.environ["SSL_CERTFILE"],
        }
    uvicorn.run(
        app, host=args.host, port=args.port, log_level="info", **ssl_options
    )