FIRE / src /serve /cli.py
zhangbofei
fix: src
2238fe2
"""
Chat with a model with command line interface.
Usage:
python3 -m fastchat.serve.cli --model lmsys/vicuna-7b-v1.5
python3 -m fastchat.serve.cli --model lmsys/fastchat-t5-3b-v1.0
Other commands:
- Type "!!exit" or an empty line to exit.
- Type "!!reset" to start a new conversation.
- Type "!!remove" to remove the last prompt.
- Type "!!regen" to regenerate the last message.
- Type "!!save <filename>" to save the conversation history to a json file.
- Type "!!load <filename>" to load a conversation history from a json file.
"""
import argparse
import os
import re
import sys
from prompt_toolkit import PromptSession
from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
from prompt_toolkit.completion import WordCompleter
from prompt_toolkit.history import InMemoryHistory
from prompt_toolkit.key_binding import KeyBindings
from rich.console import Console
from rich.live import Live
from rich.markdown import Markdown
import torch
from src.model.model_adapter import add_model_args
from src.modules.awq import AWQConfig
from src.modules.exllama import ExllamaConfig
from src.modules.xfastertransformer import XftConfig
from src.modules.gptq import GptqConfig
from src.serve.inference import ChatIO, chat_loop
from src.utils import str_to_torch_dtype
class SimpleChatIO(ChatIO):
    """Plain-text chat I/O built on ``input()`` and ``print()``.

    Args:
        multiline: When true, a user message may span several lines and is
            terminated by end-of-file (ctrl-d / ctrl-z) instead of Enter.
    """

    def __init__(self, multiline: bool = False):
        self._multiline = multiline

    def prompt_for_input(self, role) -> str:
        """Read one user message from standard input.

        In single-line mode this is a single ``input()`` call. In multiline
        mode, lines are collected until ``input()`` raises ``EOFError``; each
        line is stripped and the lines are joined with ``"\\n"``.

        Raises:
            EOFError: Propagated from the first ``input()`` call (either
                mode) when no line can be read at all.
        """
        if not self._multiline:
            return input(f"{role}: ")

        prompt_data = []
        line = input(f"{role} [ctrl-d/z on empty line to end]: ")
        while True:
            prompt_data.append(line.strip())
            try:
                line = input()
            except EOFError:
                # End-of-file marks the end of a multiline message.
                break
        return "\n".join(prompt_data)

    def prompt_for_output(self, role: str):
        """Print the ``"<role>: "`` prefix without a trailing newline."""
        print(f"{role}: ", end="", flush=True)

    def stream_output(self, output_stream):
        """Print streamed output word by word and return the final text.

        Every item of ``output_stream`` is indexed with ``"text"``. The code
        keeps a running word offset into the *latest* item, so each item is
        expected to hold the full text produced so far. The last word of each
        item is held back (presumably because it may still be incomplete) and
        is only printed once the stream is exhausted.

        Returns:
            The last item's text, stripped, split on single spaces and
            re-joined with single spaces. An empty stream prints an empty
            line and returns ``""``.
        """
        # Bound up front so an empty stream does not raise UnboundLocalError
        # in the final print/return below.
        output_text = []
        pre = 0  # number of words already printed
        for outputs in output_stream:
            output_text = outputs["text"].strip().split(" ")
            now = len(output_text) - 1
            if now > pre:
                print(" ".join(output_text[pre:now]), end=" ", flush=True)
                pre = now
        print(" ".join(output_text[pre:]), flush=True)
        return " ".join(output_text)

    def print_output(self, text: str):
        """Print ``text`` followed by a newline."""
        print(text)
class RichChatIO(ChatIO):
    """Chat I/O that reads input with prompt_toolkit and renders with rich.

    Args:
        multiline: When true, the ESC+Enter key binding (which inserts a
            newline into the input buffer) is installed on the prompt.
        mouse: Passed to the prompt as ``mouse_support``.
    """

    # Class-level key bindings shared by all instances.
    bindings = KeyBindings()

    @bindings.add("escape", "enter")
    def _(event):
        # ESC followed by Enter inserts a newline instead of submitting.
        event.app.current_buffer.newline()

    def __init__(self, multiline: bool = False, mouse: bool = False):
        self._prompt_session = PromptSession(history=InMemoryHistory())
        self._completer = WordCompleter(
            words=["!!exit", "!!reset", "!!remove", "!!regen", "!!save", "!!load"],
            pattern=re.compile("$"),
        )
        self._console = Console()
        self._multiline = multiline
        self._mouse = mouse

    def prompt_for_input(self, role) -> str:
        """Print the role header, read one message and return it."""
        self._console.print(f"[bold]{role}:")
        # TODO(suquark): multiline input has some issues. fix it later.
        # Note that ``multiline=False`` is always passed to the prompt; the
        # ``multiline`` option only controls whether the ESC+Enter binding
        # is installed.
        prompt_input = self._prompt_session.prompt(
            completer=self._completer,
            multiline=False,
            mouse_support=self._mouse,
            auto_suggest=AutoSuggestFromHistory(),
            key_bindings=self.bindings if self._multiline else None,
        )
        self._console.print()
        return prompt_input

    def prompt_for_output(self, role: str):
        """Print the role header in bold, with ``/`` replaced by ``|``."""
        self._console.print(f"[bold]{role.replace('/', '|')}:")

    def stream_output(self, output_stream):
        """Stream output from a role, re-rendering it live as Markdown.

        Falsy items are skipped; every other item is indexed with ``"text"``
        and the whole text is re-rendered on each update.

        Returns:
            The ``"text"`` of the last non-falsy item, or ``""`` when the
            stream yields no such item.
        """
        # TODO(suquark): the console flickers when there is a code block
        #  above it. We need to cut off "live" when a code block is done.

        # Bound up front so an empty (or all-falsy) stream does not raise
        # UnboundLocalError at the final ``return``.
        text = ""

        # Create a Live context for updating the console output
        with Live(console=self._console, refresh_per_second=4) as live:
            # Read lines from the stream
            for outputs in output_stream:
                if not outputs:
                    continue
                text = outputs["text"]
                # Render the accumulated text as Markdown
                # NOTE: this is a workaround for the rendering "unstandard markdown"
                #  in rich. The chatbots output treat "\n" as a new line for
                #  better compatibility with real-world text. However, rendering
                #  in markdown would break the format. It is because standard markdown
                #  treat a single "\n" in normal text as a space.
                #  Our workaround is adding two spaces at the end of each line.
                #  This is not a perfect solution, as it would
                #  introduce trailing spaces (only) in code block, but it works well
                #  especially for console output, because in general the console does not
                #  care about trailing spaces.
                lines = []
                for line in text.splitlines():
                    lines.append(line)
                    if line.startswith("```"):
                        # Code block marker - do not add trailing spaces, as it would
                        #  break the syntax highlighting
                        lines.append("\n")
                    else:
                        # Two trailing spaces are required for a Markdown hard
                        # line break; a single space is not enough.
                        lines.append("  \n")
                markdown = Markdown("".join(lines))
                # Update the Live console output
                live.update(markdown)
        self._console.print()
        return text

    def print_output(self, text: str):
        """Render ``text`` once through the same Markdown path as streaming."""
        self.stream_output([{"text": text}])
class ProgrammaticChatIO(ChatIO):
    """Chat I/O for driving the CLI from another program over stdin/stdout.

    Incoming messages are terminated by a sentinel sequence rather than by a
    newline, and every line written is prefixed with ``[!OP:<role>]: ``.
    """

    def prompt_for_input(self, role) -> str:
        """Read one message from stdin, echo it, and return it.

        Input is consumed until the accumulated text ends with the sentinel
        ``" __END_OF_A_MESSAGE_47582648__\\n"``; the sentinel is removed from
        the returned message. If stdin reaches end-of-file first, reading
        stops and whatever was received so far is returned unchanged (an
        empty string when nothing was received).
        """
        # `end_sequence` signals the end of a message. It is unlikely to occur in
        #  message content.
        end_sequence = " __END_OF_A_MESSAGE_47582648__\n"
        contents = ""
        # The sentinel ends with "\n", so a message can only ever end at a
        # line boundary: reading whole lines stops at exactly the same point
        # as checking after every single character, without the per-character
        # read and concatenation cost.
        while not contents.endswith(end_sequence):
            line = sys.stdin.readline()
            if not line:
                # An empty string means end-of-file (it is not signalled by
                # EOFError). Stop instead of spinning forever on a closed stdin.
                break
            contents += line
        if contents.endswith(end_sequence):
            contents = contents[: -len(end_sequence)]
        print(f"[!OP:{role}]: {contents}", flush=True)
        return contents

    def prompt_for_output(self, role: str):
        """Print the ``[!OP:<role>]: `` prefix without a trailing newline."""
        print(f"[!OP:{role}]: ", end="", flush=True)

    def stream_output(self, output_stream):
        """Print streamed output word by word and return the final text.

        Every item of ``output_stream`` is indexed with ``"text"``. The code
        keeps a running word offset into the *latest* item, so each item is
        expected to hold the full text produced so far. The last word of each
        item is held back until the stream is exhausted.

        Returns:
            The last item's text, stripped, split on single spaces and
            re-joined with single spaces. An empty stream prints an empty
            line and returns ``""``.
        """
        # Bound up front so an empty stream does not raise UnboundLocalError
        # in the final print/return below.
        output_text = []
        pre = 0  # number of words already printed
        for outputs in output_stream:
            output_text = outputs["text"].strip().split(" ")
            now = len(output_text) - 1
            if now > pre:
                print(" ".join(output_text[pre:now]), end=" ", flush=True)
                pre = now
        print(" ".join(output_text[pre:]), flush=True)
        return " ".join(output_text)

    def print_output(self, text: str):
        """Print ``text`` followed by a newline."""
        print(text)
def main(args):
    """Configure the environment from parsed CLI ``args`` and run the chat loop.

    Steps, in order: optionally restrict the visible GPUs, build the optional
    ExLlama and xFasterTransformer configs, pick the console I/O
    implementation for ``args.style``, then call ``chat_loop``. A
    ``KeyboardInterrupt`` raised while the chat loop (or its arguments) is
    being evaluated is caught and reported as ``exit...``.

    Raises:
        ValueError: If ``--gpus`` lists fewer devices than ``--num-gpus``, or
            if ``args.style`` is not one of the known console styles.
    """
    gpu_ids = args.gpus
    if gpu_ids:
        if args.num_gpus > len(gpu_ids.split(",")):
            raise ValueError(
                f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
            )
        # Expose the same device list under both environment variables.
        for env_name in ("CUDA_VISIBLE_DEVICES", "XPU_VISIBLE_DEVICES"):
            os.environ[env_name] = gpu_ids

    # Optional backends default to "disabled" (None).
    exllama_config = None
    if args.enable_exllama:
        exllama_config = ExllamaConfig(
            max_seq_len=args.exllama_max_seq_len,
            gpu_split=args.exllama_gpu_split,
            cache_8bit=args.exllama_cache_8bit,
        )

    xft_config = None
    if args.enable_xft:
        xft_config = XftConfig(
            max_seq_len=args.xft_max_seq_len,
            data_type=args.xft_dtype,
        )
        if args.device != "cpu":
            # Enabling xFasterTransformer forces the device back to CPU.
            print("xFasterTransformer now is only support CPUs. Reset device to CPU")
            args.device = "cpu"

    # Factories are lazy so only the selected console class is instantiated.
    chatio_factories = (
        ("simple", lambda: SimpleChatIO(args.multiline)),
        ("rich", lambda: RichChatIO(args.multiline, args.mouse)),
        ("programmatic", lambda: ProgrammaticChatIO()),
    )
    for style_name, make_chatio in chatio_factories:
        if args.style == style_name:
            chatio = make_chatio()
            break
    else:
        raise ValueError(f"Invalid style for console: {args.style}")

    try:
        torch_dtype = str_to_torch_dtype(args.dtype)
        # Quantization checkpoints fall back to the model path when unset.
        gptq_config = GptqConfig(
            ckpt=args.gptq_ckpt or args.model_path,
            wbits=args.gptq_wbits,
            groupsize=args.gptq_groupsize,
            act_order=args.gptq_act_order,
        )
        awq_config = AWQConfig(
            ckpt=args.awq_ckpt or args.model_path,
            wbits=args.awq_wbits,
            groupsize=args.awq_groupsize,
        )
        chat_loop(
            args.model_path,
            args.device,
            args.num_gpus,
            args.max_gpu_memory,
            torch_dtype,
            args.load_8bit,
            args.cpu_offloading,
            args.conv_template,
            args.conv_system_msg,
            args.temperature,
            args.repetition_penalty,
            args.max_new_tokens,
            chatio,
            gptq_config=gptq_config,
            awq_config=awq_config,
            exllama_config=exllama_config,
            xft_config=xft_config,
            revision=args.revision,
            judge_sent_end=args.judge_sent_end,
            debug=args.debug,
            history=not args.no_history,
        )
    except KeyboardInterrupt:
        print("exit...")
if __name__ == "__main__":
    # Script entry point: register the shared model options, then the chat
    # options (in the same order as they appear in --help), parse, and run.
    parser = argparse.ArgumentParser()
    add_model_args(parser)

    # Optional string overrides for the conversation setup.
    for option, description in (
        ("--conv-template", "Conversation prompt template."),
        ("--conv-system-msg", "Conversation system message."),
    ):
        parser.add_argument(option, type=str, default=None, help=description)

    # Generation settings: (flag, value type, default).
    for option, value_type, fallback in (
        ("--temperature", float, 0.7),
        ("--repetition_penalty", float, 1.0),
        ("--max-new-tokens", int, 512),
    ):
        parser.add_argument(option, type=value_type, default=fallback)

    parser.add_argument("--no-history", action="store_true")
    parser.add_argument(
        "--style",
        type=str,
        default="simple",
        choices=["simple", "rich", "programmatic"],
        help="Display style.",
    )

    # Boolean switches, all off unless given on the command line.
    for option, description in (
        ("--multiline", "Enable multiline input. Use ESC+Enter for newline."),
        ("--mouse", "[Rich Style]: Enable mouse support for cursor positioning."),
        (
            "--judge-sent-end",
            "Whether enable the correction logic that interrupts the output of sentences due to EOS.",
        ),
        ("--debug", "Print useful debug information (e.g., prompts)"),
    ):
        parser.add_argument(option, action="store_true", help=description)

    args = parser.parse_args()
    main(args)