""" |
|
Chat with a model with command line interface. |
|
|
|
Usage: |
|
python3 -m fastchat.serve.cli --model lmsys/vicuna-7b-v1.5 |
|
python3 -m fastchat.serve.cli --model lmsys/fastchat-t5-3b-v1.0 |
|
|
|
Other commands: |
|
- Type "!!exit" or an empty line to exit. |
|
- Type "!!reset" to start a new conversation. |
|
- Type "!!remove" to remove the last prompt. |
|
- Type "!!regen" to regenerate the last message. |
|
- Type "!!save <filename>" to save the conversation history to a json file. |
|
- Type "!!load <filename>" to load a conversation history from a json file. |
|
""" |
|
import argparse
import os
import re
import sys

from prompt_toolkit import PromptSession
from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
from prompt_toolkit.completion import WordCompleter
from prompt_toolkit.history import InMemoryHistory
from prompt_toolkit.key_binding import KeyBindings
from rich.console import Console
from rich.live import Live
from rich.markdown import Markdown
import torch

from src.model.model_adapter import add_model_args
from src.modules.awq import AWQConfig
from src.modules.exllama import ExllamaConfig
from src.modules.xfastertransformer import XftConfig
from src.modules.gptq import GptqConfig
from src.serve.inference import ChatIO, chat_loop
from src.utils import str_to_torch_dtype


class SimpleChatIO(ChatIO):
    def __init__(self, multiline: bool = False):
        self._multiline = multiline

    def prompt_for_input(self, role) -> str:
        if not self._multiline:
            return input(f"{role}: ")

        # Multiline mode: keep reading lines until EOF (ctrl-d/z on an empty line).
        prompt_data = []
        line = input(f"{role} [ctrl-d/z on empty line to end]: ")
        while True:
            prompt_data.append(line.strip())
            try:
                line = input()
            except EOFError:
                break
        return "\n".join(prompt_data)

    def prompt_for_output(self, role: str):
        print(f"{role}: ", end="", flush=True)

    def stream_output(self, output_stream):
        # Print only the words completed since the last update; the final
        # (possibly partial) word is held back until the stream ends.
        pre = 0
        for outputs in output_stream:
            output_text = outputs["text"]
            output_text = output_text.strip().split(" ")
            now = len(output_text) - 1
            if now > pre:
                print(" ".join(output_text[pre:now]), end=" ", flush=True)
                pre = now
        print(" ".join(output_text[pre:]), flush=True)
        return " ".join(output_text)

    def print_output(self, text: str):
        print(text)
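

# Illustrative usage sketch (an assumption for clarity, not part of the module's
# API surface): `stream_output` consumes an iterable of dicts whose "text" key
# holds the *cumulative* generation so far, printing newly completed words as
# they arrive and holding back the last word until the stream ends:
#
#     chatio = SimpleChatIO()
#     snapshots = iter(
#         [{"text": "Hello"}, {"text": "Hello world"}, {"text": "Hello world !"}]
#     )
#     final = chatio.stream_output(snapshots)  # prints "Hello world !" incrementally

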
class RichChatIO(ChatIO):
    bindings = KeyBindings()

    @bindings.add("escape", "enter")
    def _(event):
        event.app.current_buffer.newline()

    def __init__(self, multiline: bool = False, mouse: bool = False):
        self._prompt_session = PromptSession(history=InMemoryHistory())
        self._completer = WordCompleter(
            words=["!!exit", "!!reset", "!!remove", "!!regen", "!!save", "!!load"],
            pattern=re.compile("$"),
        )
        self._console = Console()
        self._multiline = multiline
        self._mouse = mouse

    def prompt_for_input(self, role) -> str:
        self._console.print(f"[bold]{role}:")
        prompt_input = self._prompt_session.prompt(
            completer=self._completer,
            multiline=False,
            mouse_support=self._mouse,
            auto_suggest=AutoSuggestFromHistory(),
            key_bindings=self.bindings if self._multiline else None,
        )
        self._console.print()
        return prompt_input

    def prompt_for_output(self, role: str):
        self._console.print(f"[bold]{role.replace('/', '|')}:")

    def stream_output(self, output_stream):
        """Stream output from a role."""
        # Create a Live context so the rendered Markdown updates in place.
        with Live(console=self._console, refresh_per_second=4) as live:
            for outputs in output_stream:
                if not outputs:
                    continue
                text = outputs["text"]
                # Render the accumulated text as Markdown. Chatbots often emit
                # single newlines that standard Markdown would collapse into
                # one line, so pad each line with two trailing spaces to force
                # a hard line break. Code-fence markers ("```") get a plain
                # newline so the fence itself stays intact.
                lines = []
                for line in text.splitlines():
                    lines.append(line)
                    if line.startswith("```"):
                        # Code block marker: do not force a hard break here.
                        lines.append("\n")
                    else:
                        lines.append("  \n")
                markdown = Markdown("".join(lines))
                live.update(markdown)
        self._console.print()
        return text

    def print_output(self, text: str):
        self.stream_output([{"text": text}])
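

# Illustrative sketch of the hard-line-break workaround in
# RichChatIO.stream_output above (a standalone restatement, shown only for
# clarity):
#
#     text = "I think 1. first\n2. second"
#     padded = "".join(
#         line + ("\n" if line.startswith("```") else "  \n")
#         for line in text.splitlines()
#     )
#     Markdown(padded)  # each source line now renders as its own line

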
class ProgrammaticChatIO(ChatIO):
    def prompt_for_input(self, role) -> str:
        contents = ""
        # `end_sequence` is a sentinel marking the end of a message; it is
        # chosen so that it is unlikely to occur in real message content.
        end_sequence = " __END_OF_A_MESSAGE_47582648__\n"
        len_end = len(end_sequence)
        while True:
            if len(contents) >= len_end:
                last_chars = contents[-len_end:]
                if last_chars == end_sequence:
                    break
            try:
                char = sys.stdin.read(1)
                contents = contents + char
            except EOFError:
                continue
        # Strip the sentinel before echoing and returning the message.
        contents = contents[:-len_end]
        print(f"[!OP:{role}]: {contents}", flush=True)
        return contents

    def prompt_for_output(self, role: str):
        print(f"[!OP:{role}]: ", end="", flush=True)

    def stream_output(self, output_stream):
        pre = 0
        for outputs in output_stream:
            output_text = outputs["text"]
            output_text = output_text.strip().split(" ")
            now = len(output_text) - 1
            if now > pre:
                print(" ".join(output_text[pre:now]), end=" ", flush=True)
                pre = now
        print(" ".join(output_text[pre:]), flush=True)
        return " ".join(output_text)

    def print_output(self, text: str):
        print(text)
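

# Illustrative driver sketch (an assumed parent process; nothing like this
# ships in this file): the programmatic style reads stdin until it sees the
# sentinel, so a controller can write messages like the following and parse
# the "[!OP:<role>]"-prefixed replies from stdout:
#
#     import subprocess, sys
#     proc = subprocess.Popen(
#         [sys.executable, "-m", "fastchat.serve.cli",
#          "--model", "lmsys/vicuna-7b-v1.5", "--style", "programmatic"],
#         stdin=subprocess.PIPE, stdout=subprocess.PIPE, text=True,
#     )
#     proc.stdin.write("Tell me a joke. __END_OF_A_MESSAGE_47582648__\n")
#     proc.stdin.flush()

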
def main(args):
    if args.gpus:
        if len(args.gpus.split(",")) < args.num_gpus:
            raise ValueError(
                f"--num-gpus ({args.num_gpus}) is larger than the number of "
                f"GPUs listed in --gpus ({args.gpus})!"
            )
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
        os.environ["XPU_VISIBLE_DEVICES"] = args.gpus
    if args.enable_exllama:
        exllama_config = ExllamaConfig(
            max_seq_len=args.exllama_max_seq_len,
            gpu_split=args.exllama_gpu_split,
            cache_8bit=args.exllama_cache_8bit,
        )
    else:
        exllama_config = None
    if args.enable_xft:
        xft_config = XftConfig(
            max_seq_len=args.xft_max_seq_len,
            data_type=args.xft_dtype,
        )
        if args.device != "cpu":
            print("xFasterTransformer currently supports CPU only. Resetting device to CPU.")
            args.device = "cpu"
    else:
        xft_config = None
    if args.style == "simple":
        chatio = SimpleChatIO(args.multiline)
    elif args.style == "rich":
        chatio = RichChatIO(args.multiline, args.mouse)
    elif args.style == "programmatic":
        chatio = ProgrammaticChatIO()
    else:
        raise ValueError(f"Invalid style for console: {args.style}")
    try:
        chat_loop(
            args.model_path,
            args.device,
            args.num_gpus,
            args.max_gpu_memory,
            str_to_torch_dtype(args.dtype),
            args.load_8bit,
            args.cpu_offloading,
            args.conv_template,
            args.conv_system_msg,
            args.temperature,
            args.repetition_penalty,
            args.max_new_tokens,
            chatio,
            gptq_config=GptqConfig(
                ckpt=args.gptq_ckpt or args.model_path,
                wbits=args.gptq_wbits,
                groupsize=args.gptq_groupsize,
                act_order=args.gptq_act_order,
            ),
            awq_config=AWQConfig(
                ckpt=args.awq_ckpt or args.model_path,
                wbits=args.awq_wbits,
                groupsize=args.awq_groupsize,
            ),
            exllama_config=exllama_config,
            xft_config=xft_config,
            revision=args.revision,
            judge_sent_end=args.judge_sent_end,
            debug=args.debug,
            history=not args.no_history,
        )
    except KeyboardInterrupt:
        print("exit...")
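

# Example invocations (illustrative; --gpus and --num-gpus are assumed to come
# from add_model_args, since main() reads them but the parser below does not
# define them):
#
#     python3 -m fastchat.serve.cli --model lmsys/vicuna-7b-v1.5 --style rich
#     python3 -m fastchat.serve.cli --model lmsys/vicuna-7b-v1.5 --gpus 0,1 --num-gpus 2

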
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    add_model_args(parser)
    parser.add_argument(
        "--conv-template", type=str, default=None, help="Conversation prompt template."
    )
    parser.add_argument(
        "--conv-system-msg", type=str, default=None, help="Conversation system message."
    )
    parser.add_argument("--temperature", type=float, default=0.7)
    parser.add_argument("--repetition_penalty", type=float, default=1.0)
    parser.add_argument("--max-new-tokens", type=int, default=512)
    parser.add_argument("--no-history", action="store_true")
    parser.add_argument(
        "--style",
        type=str,
        default="simple",
        choices=["simple", "rich", "programmatic"],
        help="Display style.",
    )
    parser.add_argument(
        "--multiline",
        action="store_true",
        help="Enable multiline input. Use ESC+Enter for newline.",
    )
    parser.add_argument(
        "--mouse",
        action="store_true",
        help="[Rich Style]: Enable mouse support for cursor positioning.",
    )
    parser.add_argument(
        "--judge-sent-end",
        action="store_true",
        help="Whether to enable the correction logic that interrupts sentence output at EOS.",
    )
    parser.add_argument(
        "--debug",
        action="store_true",
        help="Print useful debug information (e.g., prompts).",
    )
    args = parser.parse_args()
    main(args)