"""Send a test message.""" import argparse import json import requests from fastchat.model.model_adapter import get_conversation_template from fastchat.conversation import get_conv_template def main(): model_name = args.model_name conv_template = args.conv_template if args.worker_address: worker_addr = args.worker_address else: controller_addr = args.controller_address ret = requests.post(controller_addr + "/refresh_all_workers") ret = requests.post(controller_addr + "/list_models") models = ret.json()["models"] models.sort() print(f"Models: {models}") ret = requests.post( controller_addr + "/get_worker_address", json={"model": model_name} ) worker_addr = ret.json()["address"] print(f"worker_addr: {worker_addr}") if worker_addr == "": print(f"No available workers for {model_name}") return # conv = get_conversation_template(model_name) conv = get_conv_template(conv_template) conv.append_message(conv.roles[0], args.message) conv.append_message(conv.roles[1], None) prompt = conv.get_prompt() headers = {"User-Agent": "FastChat Client"} gen_params = { "model": model_name, "prompt": prompt, "temperature": args.temperature, "max_new_tokens": args.max_new_tokens, "stop": conv.stop_str, "stop_token_ids": conv.stop_token_ids, "echo": False, } response = requests.post( worker_addr + "/worker_generate_stream", headers=headers, json=gen_params, stream=True, ) print(f"{conv.roles[0]}: {args.message}") print(f"{conv.roles[1]}: ", end="") prev = 0 for chunk in response.iter_lines(decode_unicode=False, delimiter=b"\0"): if chunk: data = json.loads(chunk.decode()) output = data["text"].strip() print(output[prev:], end="", flush=True) prev = len(output) print("") if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument( "--controller-address", type=str, default="http://localhost:21001" ) parser.add_argument("--worker-address", type=str) parser.add_argument("--model-name", type=str, required=True) parser.add_argument("--conv-template", type=str, required=True) parser.add_argument("--temperature", type=float, default=0.0) parser.add_argument("--max-new-tokens", type=int, default=32) parser.add_argument( "--message", type=str, default="Tell me a story with more than 1000 words." ) args = parser.parse_args() main()