handsomeguy001 committed on
Commit
8723eb9
1 Parent(s): 83f5c2c

support litellm

mindsearch/agent/__init__.py CHANGED
@@ -17,13 +17,14 @@ from mindsearch.agent.mindsearch_prompt import (
 LLM = {}
 
 
-def init_agent(lang='cn', model_format='internlm_server',search_engine='DuckDuckGoSearch'):
+def init_agent(lang='cn', model_format='internlm_server', search_engine='DuckDuckGoSearch', **kwargs):
     llm = LLM.get(model_format, None)
     if llm is None:
         llm_cfg = getattr(llm_factory, model_format)
         if llm_cfg is None:
             raise NotImplementedError
         llm_cfg = llm_cfg.copy()
+        llm_cfg.update(kwargs)
         llm = llm_cfg.pop('type')(**llm_cfg)
         LLM[model_format] = llm
 
@@ -37,9 +38,9 @@ def init_agent(lang='cn', model_format='internlm_server',search_engine='DuckDuck
         llm=llm,
         protocol=MindSearchProtocol(meta_prompt=datetime.now().strftime(
             'The current date is %Y-%m-%d.'),
-                                    interpreter_prompt=interpreter_prompt,
-                                    response_prompt=FINAL_RESPONSE_CN
-                                    if lang == 'cn' else FINAL_RESPONSE_EN),
+            interpreter_prompt=interpreter_prompt,
+            response_prompt=FINAL_RESPONSE_CN
+            if lang == 'cn' else FINAL_RESPONSE_EN),
         searcher_cfg=dict(
             llm=llm,
             plugin_executor=ActionExecutor(
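
With the new `**kwargs` pass-through, anything extra handed to `init_agent` is merged into the selected model config before the class stored in its `type` field is instantiated. A minimal sketch of that call path follows; the `'deepseek/deepseek-chat'` id is only an illustrative litellm model name, not something this commit hard-codes.

```python
# Sketch: extra kwargs override/extend fields of the chosen llm config
# before `llm_cfg.pop('type')(**llm_cfg)` instantiates the model class.
# 'litellm_completion' is the config added in models.py below.
from mindsearch.agent import init_agent

agent = init_agent(lang='en',
                   model_format='litellm_completion',
                   search_engine='DuckDuckGoSearch',
                   model_name='deepseek/deepseek-chat')  # forwarded via **kwargs
```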
mindsearch/agent/models.py CHANGED
@@ -1,3 +1,9 @@
+from lagent.llms import BaseAPIModel
+from typing import List, Optional, Union
+from litellm import completion
+
+from lagent.schema import ModelStatusCode
+from lagent.utils.util import filter_suffix
 import os
 
 from lagent.llms import (GPTAPI, INTERNLM2_META, HFTransformerCasualLM,
@@ -38,7 +44,8 @@ internlm_hf = dict(type=HFTransformerCasualLM,
 gpt4 = dict(type=GPTAPI,
             model_type='gpt-4-turbo',
             key=os.environ.get('OPENAI_API_KEY', 'YOUR OPENAI API KEY'),
-            openai_api_base=os.environ.get('OPENAI_API_BASE', 'https://api.openai.com/v1/chat/completions'),
+            openai_api_base=os.environ.get(
+                'OPENAI_API_BASE', 'https://api.openai.com/v1/chat/completions'),
             )
 
 url = 'https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation'
@@ -61,7 +68,8 @@ qwen = dict(type=GPTAPI,
 
 internlm_silicon = dict(type=GPTAPI,
                         model_type='internlm/internlm2_5-7b-chat',
-                        key=os.environ.get('SILICON_API_KEY', 'YOUR SILICON API KEY'),
+                        key=os.environ.get(
+                            'SILICON_API_KEY', 'YOUR SILICON API KEY'),
                         openai_api_base='https://api.siliconflow.cn/v1/chat/completions',
                         meta_template=[
                             dict(role='system', api_role='system'),
@@ -75,3 +83,133 @@ internlm_silicon = dict(type=GPTAPI,
                         max_new_tokens=8192,
                         repetition_penalty=1.02,
                         stop_words=['<|im_end|>'])
+
+
+class litellmCompletion(BaseAPIModel):
+    """A lagent ``BaseAPIModel`` wrapper that routes chat completions
+    through litellm, so any provider/model that litellm supports can
+    back the agent.
+
+    Args:
+        path (str): unused placeholder, kept only to satisfy the
+            ``BaseAPIModel`` constructor.
+        model_name (str): a litellm model identifier, e.g. "command-r"
+            or "deepseek/deepseek-chat".
+    """
+
+    def __init__(self,
+                 path='',
+                 model_name="command-r",
+                 **kwargs):
+        self.model_name = model_name
+        super().__init__(path, **kwargs)
+
+    def generate(self,
+                 inputs: Union[str, List[str]],
+                 do_preprocess: bool = None,
+                 skip_special_tokens: bool = False,
+                 **kwargs):
+        """Return the chat completions in non-stream mode.
+
+        Args:
+            inputs (Union[str, List[str]]): input texts to be completed.
+            do_preprocess (bool): whether to pre-process the messages.
+                Defaults to True, which means chat_template will be applied.
+            skip_special_tokens (bool): whether to remove special tokens
+                in the decoding. Defaults to False.
+        Returns:
+            (a list of/batched) text/chat completion
+        """
+        batched = True
+        if isinstance(inputs, str):
+            batched = False
+            inputs = [inputs]
+        prompts = inputs
+        messages = [{"role": "user", "content": prompt} for prompt in prompts]
+        # NOTE: sampling parameters from update_gen_params are not yet
+        # forwarded to litellm in non-stream mode.
+        gen_params = self.update_gen_params(**kwargs)
+        response = completion(model=self.model_name, messages=messages)
+        response = [resp.message.content for resp in response.choices]
+        # remove stop_words
+        response = filter_suffix(response, self.gen_params.get('stop_words'))
+        if batched:
+            return response
+        return response[0]
+
+    def stream_chat(self,
+                    inputs: List[dict],
+                    stream: bool = True,
+                    ignore_eos: bool = False,
+                    skip_special_tokens: Optional[bool] = False,
+                    timeout: int = 30,
+                    **kwargs):
+        """Start a new round of conversation and return the chat
+        completions in stream mode.
+
+        Args:
+            inputs (List[dict]): user's inputs in this round of conversation
+            stream (bool): return in a streaming format if enabled
+            ignore_eos (bool): indicator for ignoring eos
+            skip_special_tokens (bool): whether to remove special tokens
+                in the decoding. Defaults to False.
+            timeout (int): max time to wait for response
+        Yields:
+            tuple(Status, str, int): status, text/chat completion,
+                generated token number
+        """
+        gen_params = self.update_gen_params(**kwargs)
+        # litellm expects `max_tokens` rather than lagent's `max_new_tokens`.
+        max_new_tokens = gen_params.pop('max_new_tokens')
+        gen_params.update(max_tokens=max_new_tokens)
+
+        resp = ''
+        finished = False
+        stop_words = gen_params.get('stop_words')
+        if stop_words is None:
+            stop_words = []
+        messages = self.template_parser._prompt2api(inputs)
+
+        for text in completion(
+                self.model_name,
+                messages,
+                stream=stream,
+                **gen_params):
+            if not text.choices[0].delta.content:
+                continue
+            resp += text.choices[0].delta.content
+            if not resp:
+                continue
+            # remove stop_words
+            for sw in stop_words:
+                if sw in resp:
+                    resp = filter_suffix(resp, stop_words)
+                    finished = True
+                    break
+            yield ModelStatusCode.STREAM_ING, resp, None
+            if finished:
+                break
+        yield ModelStatusCode.END, resp, None
+
+
+litellm_completion = dict(type=litellmCompletion,
+                          # model_name="deepseek/deepseek-chat",
+                          meta_template=[
+                              dict(role='system', api_role='system'),
+                              dict(role='user', api_role='user'),
+                              dict(role='assistant', api_role='assistant'),
+                              dict(role='environment', api_role='system')
+                          ]
+                          )
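
For reference, a minimal smoke test of the new wrapper, assuming the API key for the chosen provider is already exported in the environment (litellm resolves keys such as DEEPSEEK_API_KEY or COHERE_API_KEY on its own); the model id is illustrative, and the construction mirrors the `llm_cfg.pop('type')(**llm_cfg)` pattern that `init_agent` uses:

```python
# Minimal smoke test of litellmCompletion, assuming the provider API key
# (e.g. DEEPSEEK_API_KEY) is set; 'deepseek/deepseek-chat' is illustrative.
from mindsearch.agent.models import litellm_completion

cfg = litellm_completion.copy()
llm = cfg.pop('type')(model_name='deepseek/deepseek-chat', **cfg)

# Non-stream: a single string in yields a single completion back.
print(llm.generate('Say hello in one short sentence.'))

# Stream: consume incremental states until ModelStatusCode.END.
for status, text, _ in llm.stream_chat(
        [dict(role='user', content='Count from 1 to 5.')]):
    print(status, text[-20:])
```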
mindsearch/app.py CHANGED
@@ -24,9 +24,11 @@ def parse_arguments():
                         type=str,
                         help='Model format')
     parser.add_argument('--search_engine',
                         default='DuckDuckGoSearch',
                         type=str,
                         help='Search engine')
+    parser.add_argument('--model_name', default='deepseek/deepseek-chat',
+                        type=str, help='litellm model name')
     return parser.parse_args()
 
 
@@ -127,7 +129,14 @@ async def run(request: GenerationParams):
     await queue.wait_closed()
 
     inputs = request.inputs
-    agent = init_agent(lang=args.lang, model_format=args.model_format,search_engine=args.search_engine)
+    if args.model_format == 'litellm_completion':
+        agent = init_agent(lang=args.lang, model_format=args.model_format,
+                           search_engine=args.search_engine,
+                           model_name=args.model_name)
+    else:
+        agent = init_agent(
+            lang=args.lang, model_format=args.model_format,
+            search_engine=args.search_engine)
     return EventSourceResponse(generate())
 
 
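With the new flag, launching the server against a litellm backend would presumably look like this (the model id is illustrative; every other `model_format` keeps its old code path, where `--model_name` is simply ignored):

```
python -m mindsearch.app --lang en --model_format litellm_completion \
    --model_name deepseek/deepseek-chat --search_engine DuckDuckGoSearch
```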
requirements.txt CHANGED
@@ -1,12 +1,218 @@
+accelerate==0.33.0
+addict==2.4.0
+aiofiles==23.2.1
+aiohappyeyeballs==2.4.0
+aiohttp==3.10.5
+aiosignal==1.3.1
+altair==5.4.1
+annotated-types==0.7.0
+anyio==4.4.0
+argon2-cffi==23.1.0
+argon2-cffi-bindings==21.2.0
+arrow==1.3.0
+arxiv==2.1.3
+asttokens==2.4.1
+async-lru==2.0.4
+attrs==24.2.0
+babel==2.16.0
+beautifulsoup4==4.12.3
+bleach==6.1.0
+blinker==1.8.2
+Brotli==1.1.0
+cachetools==5.5.0
+certifi==2024.8.30
+cffi==1.17.0
+charset-normalizer==3.3.2
+click==8.1.7
+colorama==0.4.6
+comm==0.2.2
+contourpy==1.3.0
+cycler==0.12.1
+debugpy==1.8.5
+decorator==5.1.1
+defusedxml==0.7.1
+distro==1.9.0
 duckduckgo_search==5.3.1b1
-einops
-fastapi
-git+https://github.com/InternLM/lagent.git
-gradio
-janus
-lmdeploy
-pyvis
-sse-starlette
-termcolor
+einops==0.8.0
+executing==2.0.1
+fastapi==0.112.2
+fastjsonschema==2.20.0
+feedparser==6.0.11
+ffmpy==0.4.0
+filelock==3.15.4
+fire==0.6.0
+fonttools==4.53.1
+fqdn==1.5.1
+frozenlist==1.4.1
+fsspec==2024.6.1
+func_timeout==4.3.5
+gitdb==4.0.11
+GitPython==3.1.43
+gradio==4.42.0
+gradio_client==1.3.0
+griffe==1.2.0
+h11==0.14.0
+h2==4.1.0
+hpack==4.0.0
+httpcore==1.0.5
+httpx==0.27.2
+huggingface-hub==0.24.6
+hyperframe==6.0.1
+idna==3.8
+importlib_metadata==8.4.0
+importlib_resources==6.4.4
+ipykernel==6.29.5
+ipython==8.27.0
+ipywidgets==8.1.5
+isoduration==20.11.0
+janus==1.0.0
+jedi==0.19.1
+Jinja2==3.1.4
+jiter==0.5.0
+json5==0.9.25
+jsonpickle==3.2.2
+jsonpointer==3.0.0
+jsonschema==4.23.0
+jsonschema-specifications==2023.12.1
+jupyter==1.1.1
+jupyter-console==6.6.3
+jupyter-events==0.10.0
+jupyter-lsp==2.2.5
+jupyter_client==8.6.2
+jupyter_core==5.7.2
+jupyter_server==2.14.2
+jupyter_server_terminals==0.5.3
+jupyterlab==4.2.5
+jupyterlab_pygments==0.3.0
+jupyterlab_server==2.27.3
+jupyterlab_widgets==3.0.13
+kiwisolver==1.4.5
+-e git+https://github.com/InternLM/lagent.git@906845f1af47fcb7d81c5c32ec44b0cc22204f8a#egg=lagent
+litellm==1.44.13
+lmdeploy==0.5.3
+markdown-it-py==3.0.0
+MarkupSafe==2.1.5
+matplotlib==3.9.2
+matplotlib-inline==0.1.7
+mdurl==0.1.2
+mistune==3.0.2
+mmengine-lite==0.10.4
+mpmath==1.3.0
+multidict==6.0.5
+narwhals==1.6.0
+nbclient==0.10.0
+nbconvert==7.16.4
+nbformat==5.10.4
+nest-asyncio==1.6.0
+networkx==3.3
+notebook==7.2.2
+notebook_shim==0.2.4
+numpy==1.26.4
+nvidia-cublas-cu12==12.1.3.1
+nvidia-cuda-cupti-cu12==12.1.105
+nvidia-cuda-nvrtc-cu12==12.1.105
+nvidia-cuda-runtime-cu12==12.1.105
+nvidia-cudnn-cu12==8.9.2.26
+nvidia-cufft-cu12==11.0.2.54
+nvidia-curand-cu12==10.3.2.106
+nvidia-cusolver-cu12==11.4.5.107
+nvidia-cusparse-cu12==12.1.0.106
+nvidia-nccl-cu12==2.20.5
+nvidia-nvjitlink-cu12==12.6.68
+nvidia-nvtx-cu12==12.1.105
+openai==1.43.0
+orjson==3.10.7
+overrides==7.7.0
+packaging==24.1
+pandas==2.2.2
+pandocfilters==1.5.1
+parso==0.8.4
+peft==0.11.1
+pexpect==4.9.0
+phx-class-registry==4.1.0
+pillow==10.4.0
+platformdirs==4.2.2
+prometheus_client==0.20.0
+prompt_toolkit==3.0.47
+protobuf==5.28.0
+psutil==6.0.0
+ptyprocess==0.7.0
+pure_eval==0.2.3
+pyarrow==17.0.0
+pycparser==2.22
+pydantic==2.8.2
+pydantic_core==2.20.1
+pydeck==0.9.1
+pydub==0.25.1
+Pygments==2.18.0
+pynvml==11.5.3
+pyparsing==3.1.4
+python-dateutil==2.9.0.post0
+python-dotenv==1.0.1
+python-json-logger==2.0.7
+python-multipart==0.0.9
+pytz==2024.1
+pyvis==0.3.2
+PyYAML==6.0.2
+pyzmq==26.2.0
+referencing==0.35.1
+regex==2024.7.24
+requests==2.32.3
+rfc3339-validator==0.1.4
+rfc3986-validator==0.1.1
+rich==13.8.0
+rpds-py==0.20.0
+ruff==0.6.3
+safetensors==0.4.4
+semantic-version==2.10.0
+Send2Trash==1.8.3
+sentencepiece==0.2.0
+setuptools==72.2.0
+sgmllib3k==1.0.0
+shellingham==1.5.4
+shortuuid==1.0.13
+six==1.16.0
+smmap==5.0.1
+sniffio==1.3.1
+socksio==1.0.0
+soupsieve==2.6
+sse-starlette==2.1.3
+stack-data==0.6.3
+starlette==0.38.2
+streamlit==1.38.0
+sympy==1.13.2
+tenacity==8.5.0
+termcolor==2.4.0
+terminado==0.18.1
+tiktoken==0.7.0
+timeout-decorator==0.5.0
+tinycss2==1.3.0
+tokenizers==0.19.1
+toml==0.10.2
+tomli==2.0.1
+tomlkit==0.12.0
+torch==2.3.1
+torchvision==0.18.1
+tornado==6.4.1
+tqdm==4.66.5
+traitlets==5.14.3
 transformers==4.41.0
-uvicorn
+triton==2.3.1
+typer==0.12.5
+types-python-dateutil==2.9.0.20240821
+typing_extensions==4.12.2
+tzdata==2024.1
+uri-template==1.3.0
+urllib3==2.2.2
+uvicorn==0.30.6
+watchdog==4.0.2
+wcwidth==0.2.13
+webcolors==24.8.0
+webencodings==0.5.1
+websocket-client==1.8.0
+websockets==12.0
+wheel==0.44.0
+widgetsnbextension==4.0.13
+yapf==0.40.2
+yarl==1.9.6
+zipp==3.20.1