# Copyright (c) OpenMMLab. All rights reserved.
"""LLM client."""
import argparse
import json

import pytoml
import requests
from loguru import logger


class ChatClient:
    """A class to handle client-side interactions with a chat service.

    This class is responsible for loading configurations from a given path,
    building prompts, and generating responses by interacting with the chat
    service.
    """

    def __init__(self, config_path: str) -> None:
        """Initialize the ChatClient with the path of the configuration
        file."""
        self.config_path = config_path

    def load_config(self):
        """Load the 'llm' section of the configuration from the provided
        path."""
        with open(self.config_path, encoding='utf8') as f:
            config = pytoml.load(f)
        return config['llm']

    def load_llm_config(self):
        """Load the 'server' section of the 'llm' configuration from the
        provided path."""
        with open(self.config_path, encoding='utf8') as f:
            config = pytoml.load(f)
        return config['llm']['server']
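
    # Configuration keys this client reads from config.ini (illustrative
    # sketch only; the values and exact file layout are assumptions, adjust
    # them to your deployment):
    #
    #   [llm]
    #   client_url = "http://127.0.0.1:8888/inference"   # hypothetical URL
    #   enable_local = 1
    #   enable_remote = 0
    #
    #   [llm.server]
    #   remote_type = ""                     # e.g. "kimi" when remote is used
    #   remote_api_key = ""
    #   remote_llm_model = ""
    #   local_llm_max_text_length = 3000     # hypothetical limit
    #   remote_llm_max_text_length = 30000   # hypothetical limit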

    def build_prompt(self,
                     history_pair,
                     instruction: str,
                     template: str,
                     context: str = '',
                     reject: str = '<reject>'):
        """Build a prompt for interaction.

        Args:
            history_pair (list): List of previous interactions.
            instruction (str): Instruction for the current interaction.
            template (str): Template for constructing the interaction.
            context (str, optional): Context of the interaction. Defaults
                to ''.
            reject (str, optional): Text that indicates a rejected
                interaction. Defaults to '<reject>'.

        Returns:
            tuple: A tuple containing the constructed instruction and the
                filtered history.
        """
        if context is not None and len(context) > 0:
            instruction = template.format(context, instruction)

        real_history = []
        for pair in history_pair:
            # Skip rejected, incomplete or empty question/answer pairs.
            if pair[1] == reject:
                continue
            if pair[0] is None or pair[1] is None:
                continue
            if len(pair[0]) < 1 or len(pair[1]) < 1:
                continue
            real_history.append(pair)
        return instruction, real_history
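
    # Illustrative use of build_prompt (a sketch; the template and history
    # values below are made up, not taken from the project config):
    #
    #   client = ChatClient(config_path='config.ini')
    #   instruction, history = client.build_prompt(
    #       history_pair=[('previous question', '<reject>')],
    #       instruction='current question',
    #       template='Background: "{}"\nQuestion: "{}"',
    #       context='retrieved background text')
    #   # instruction is now:
    #   #   'Background: "retrieved background text"\nQuestion: "current question"'
    #   # history is now [] because the rejected pair is dropped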

    def generate_response(self, prompt, backend, history=[]):
        """Generate a response from the chat service.

        Args:
            prompt (str): The prompt to send to the chat service.
            backend (str): Which LLM backend to call, either 'local' or
                'remote'.
            history (list, optional): List of previous interactions.
                Defaults to [].

        Returns:
            str: Generated response from the chat service.
        """
        llm_config = self.load_config()
        url, enable_local, enable_remote = (llm_config['client_url'],
                                            llm_config['enable_local'],
                                            llm_config['enable_remote'])
        # A remote backend is usable only when its type, API key and model
        # name are all configured.
        type_given = llm_config['server']['remote_type'] != ''
        api_given = llm_config['server']['remote_api_key'] != ''
        llm_given = llm_config['server']['remote_llm_model'] != ''

        if backend == 'local' and enable_local:
            max_length = llm_config['server']['local_llm_max_text_length']
        elif (backend == 'remote' and enable_remote and type_given
              and api_given and llm_given):
            max_length = llm_config['server']['remote_llm_max_text_length']
            backend = llm_config['server']['remote_type']
        else:
            raise ValueError('Invalid backend or backend is not enabled')

        # Previous backend-selection logic, kept for reference:
        # remote = False
        # if backend != 'local':
        #     remote = True
        # if remote and not enable_remote:
        #     # if use remote LLM (for example, kimi) and disable enable_remote
        #     # auto fixed to local LLM
        #     remote = False
        #     logger.warning(
        #         'disable remote LLM while choose remote LLM, auto fixed')
        # elif not enable_local and not remote:
        #     remote = True
        #     backend = 'remote'
        #     logger.warning(
        #         'disable local LLM while using local LLM, auto fixed')
        # if remote:
        #     if backend == 'remote':
        #         backend = llm_config['server']['remote_type']
        #     max_length = llm_config['server']['remote_llm_max_text_length']
        # else:
        #     backend = 'local'
        #     max_length = llm_config['server']['local_llm_max_text_length']

        if len(prompt) > max_length:
            logger.warning(
                f'prompt length {len(prompt)} > max_length {max_length}, truncated'  # noqa E501
            )
            prompt = prompt[0:max_length]

        try:
            header = {'Content-Type': 'application/json'}
            data_history = []
            for item in history:
                data_history.append([item[0], item[1]])
            data = {
                'prompt': prompt,
                'history': data_history,
                'backend': backend
            }
            resp = requests.post(url,
                                 headers=header,
                                 data=json.dumps(data),
                                 timeout=300)
            if resp.status_code != 200:
                raise Exception(str((resp.status_code, resp.reason)))
            return resp.json()['text']
        except Exception as e:
            logger.error(str(e))
            logger.error(
                'Did you forget `--standalone` when running `python3 -m huixiangdou.main`?'  # noqa E501
            )
            return ''
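
    # HTTP contract generate_response relies on (derived from the code above,
    # not from separate server documentation): POST a JSON body to the
    # configured `client_url` and read the 'text' field of the JSON reply.
    # Roughly equivalent to:
    #
    #   curl -X POST <client_url> \
    #        -H 'Content-Type: application/json' \
    #        -d '{"prompt": "...", "history": [], "backend": "local"}'
    #   # -> {"text": "..."}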


def parse_args():
    """Parse command-line arguments."""
    parser = argparse.ArgumentParser(
        description='Client for hybrid llm service.')
    parser.add_argument(
        '--config_path',
        default='config.ini',
        help='Configuration path. Default value is config.ini')
    args = parser.parse_args()
    return args


if __name__ == '__main__':
    args = parse_args()
    client = ChatClient(config_path=args.config_path)

    # Chinese prompt, roughly: "{question}"\nRead the question above
    # carefully, extract the entity words from it, and return the result
    # directly as a list, without explanation. The inner question asks how to
    # print triviaqa 5-shot results in the summarizer.
    question = '“{}”\n请仔细阅读以上问题,提取其中的实体词,结果直接用 list 表示,不要解释。'.format(
        '请问triviaqa 5shot结果怎么在summarizer里输出呢')
    print(client.generate_response(prompt=question, backend='local'))

    # Remote example, roughly: "What is the full name of ncnn?", with one
    # prior turn: Q "What is ncnn", A "The n in ncnn stands for nihui, cnn
    # stands for convolutional neural network."
    print(
        client.generate_response(prompt='请问 ncnn 的全称是什么',
                                 history=[('ncnn 是什么',
                                           'ncnn中的n代表nihui,cnn代表卷积神经网络。')],
                                 backend='remote'))