pseudotensor committed on
Commit 1e8c453
1 Parent(s): 663f03d

Update with h2oGPT hash 236c95819e80ab122193bfb843b55618ae285c39

client_test.py CHANGED
@@ -97,6 +97,8 @@ def get_args(prompt, prompt_type, chat=False, stream_output=False,
97
  chunk_size=512,
98
  document_choice=[DocumentChoices.All_Relevant.name],
99
  )
 
 
100
  if chat:
101
  # add chatbot output on end. Assumes serialize=False
102
  kwargs.update(dict(chatbot=[]))
@@ -105,8 +107,8 @@ def get_args(prompt, prompt_type, chat=False, stream_output=False,
105
 
106
 
107
  @pytest.mark.skip(reason="For manual use against some server, no server launched")
108
- def test_client_basic():
109
- return run_client_nochat(prompt='Who are you?', prompt_type='human_bot', max_new_tokens=50)
110
 
111
 
112
  def run_client_nochat(prompt, prompt_type, max_new_tokens):
@@ -122,12 +124,12 @@ def run_client_nochat(prompt, prompt_type, max_new_tokens):
122
  res_dict = dict(prompt=kwargs['instruction_nochat'], iinput=kwargs['iinput_nochat'],
123
  response=md_to_text(res))
124
  print(res_dict)
125
- return res_dict
126
 
127
 
128
  @pytest.mark.skip(reason="For manual use against some server, no server launched")
129
- def test_client_basic_api():
130
- return run_client_nochat_api(prompt='Who are you?', prompt_type='human_bot', max_new_tokens=50)
131
 
132
 
133
  def run_client_nochat_api(prompt, prompt_type, max_new_tokens):
@@ -144,12 +146,12 @@ def run_client_nochat_api(prompt, prompt_type, max_new_tokens):
144
  response=md_to_text(ast.literal_eval(res)['response']),
145
  sources=ast.literal_eval(res)['sources'])
146
  print(res_dict)
147
- return res_dict
148
 
149
 
150
  @pytest.mark.skip(reason="For manual use against some server, no server launched")
151
- def test_client_basic_api_lean():
152
- return run_client_nochat_api_lean(prompt='Who are you?', prompt_type='human_bot', max_new_tokens=50)
153
 
154
 
155
  def run_client_nochat_api_lean(prompt, prompt_type, max_new_tokens):
@@ -166,21 +168,21 @@ def run_client_nochat_api_lean(prompt, prompt_type, max_new_tokens):
166
  response=md_to_text(ast.literal_eval(res)['response']),
167
  sources=ast.literal_eval(res)['sources'])
168
  print(res_dict)
169
- return res_dict
170
 
171
 
172
  @pytest.mark.skip(reason="For manual use against some server, no server launched")
173
- def test_client_basic_api_lean_morestuff():
174
- return run_client_nochat_api_lean_morestuff(prompt='Who are you?', prompt_type='human_bot', max_new_tokens=50)
175
 
176
 
177
- def run_client_nochat_api_lean_morestuff(prompt, prompt_type, max_new_tokens):
178
  kwargs = dict(
179
  instruction='',
180
  iinput='',
181
  context='',
182
  stream_output=False,
183
- prompt_type='human_bot',
184
  temperature=0.1,
185
  top_p=0.75,
186
  top_k=40,
@@ -211,12 +213,19 @@ def run_client_nochat_api_lean_morestuff(prompt, prompt_type, max_new_tokens):
211
  response=md_to_text(ast.literal_eval(res)['response']),
212
  sources=ast.literal_eval(res)['sources'])
213
  print(res_dict)
214
- return res_dict
215
 
216
 
217
  @pytest.mark.skip(reason="For manual use against some server, no server launched")
218
- def test_client_chat():
219
- return run_client_chat(prompt='Who are you?', prompt_type='human_bot', stream_output=False, max_new_tokens=50,
 
 
220
  langchain_mode='Disabled')
221
 
222
 
@@ -229,6 +238,7 @@ def run_client_chat(prompt, prompt_type, stream_output, max_new_tokens, langchai
229
 
230
 
231
  def run_client(client, prompt, args, kwargs, do_md_to_text=True, verbose=False):
 
232
  res = client.predict(*tuple(args), api_name='/instruction')
233
  args[-1] += [res[-1]]
234
 
@@ -262,6 +272,46 @@ def run_client(client, prompt, args, kwargs, do_md_to_text=True, verbose=False):
262
  return res_dict, client
263
 
264
 
 
265
  def md_to_text(md, do_md_to_text=True):
266
  if not do_md_to_text:
267
  return md
@@ -271,8 +321,16 @@ def md_to_text(md, do_md_to_text=True):
271
  return soup.get_text()
272
 
273
 
 
 
274
  if __name__ == '__main__':
275
- test_client_basic()
276
- test_client_basic_api()
277
- test_client_basic_api_lean()
278
- test_client_basic_api_lean_morestuff()
 
97
  chunk_size=512,
98
  document_choice=[DocumentChoices.All_Relevant.name],
99
  )
100
+ from generate import eval_func_param_names
101
+ assert len(set(eval_func_param_names).difference(set(list(kwargs.keys())))) == 0
102
  if chat:
103
  # add chatbot output on end. Assumes serialize=False
104
  kwargs.update(dict(chatbot=[]))
 
107
 
108
 
109
  @pytest.mark.skip(reason="For manual use against some server, no server launched")
110
+ def test_client_basic(prompt_type='human_bot'):
111
+ return run_client_nochat(prompt='Who are you?', prompt_type=prompt_type, max_new_tokens=50)
112
 
113
 
114
  def run_client_nochat(prompt, prompt_type, max_new_tokens):
 
124
  res_dict = dict(prompt=kwargs['instruction_nochat'], iinput=kwargs['iinput_nochat'],
125
  response=md_to_text(res))
126
  print(res_dict)
127
+ return res_dict, client
128
 
129
 
130
  @pytest.mark.skip(reason="For manual use against some server, no server launched")
131
+ def test_client_basic_api(prompt_type='human_bot'):
132
+ return run_client_nochat_api(prompt='Who are you?', prompt_type=prompt_type, max_new_tokens=50)
133
 
134
 
135
  def run_client_nochat_api(prompt, prompt_type, max_new_tokens):
 
146
  response=md_to_text(ast.literal_eval(res)['response']),
147
  sources=ast.literal_eval(res)['sources'])
148
  print(res_dict)
149
+ return res_dict, client
150
 
151
 
152
  @pytest.mark.skip(reason="For manual use against some server, no server launched")
153
+ def test_client_basic_api_lean(prompt_type='human_bot'):
154
+ return run_client_nochat_api_lean(prompt='Who are you?', prompt_type=prompt_type, max_new_tokens=50)
155
 
156
 
157
  def run_client_nochat_api_lean(prompt, prompt_type, max_new_tokens):
 
168
  response=md_to_text(ast.literal_eval(res)['response']),
169
  sources=ast.literal_eval(res)['sources'])
170
  print(res_dict)
171
+ return res_dict, client
172
 
173
 
174
  @pytest.mark.skip(reason="For manual use against some server, no server launched")
175
+ def test_client_basic_api_lean_morestuff(prompt_type='human_bot'):
176
+ return run_client_nochat_api_lean_morestuff(prompt='Who are you?', prompt_type=prompt_type, max_new_tokens=50)
177
 
178
 
179
+ def run_client_nochat_api_lean_morestuff(prompt, prompt_type='human_bot', max_new_tokens=512):
180
  kwargs = dict(
181
  instruction='',
182
  iinput='',
183
  context='',
184
  stream_output=False,
185
+ prompt_type=prompt_type,
186
  temperature=0.1,
187
  top_p=0.75,
188
  top_k=40,
 
213
  response=md_to_text(ast.literal_eval(res)['response']),
214
  sources=ast.literal_eval(res)['sources'])
215
  print(res_dict)
216
+ return res_dict, client
217
 
218
 
219
  @pytest.mark.skip(reason="For manual use against some server, no server launched")
220
+ def test_client_chat(prompt_type='human_bot'):
221
+ return run_client_chat(prompt='Who are you?', prompt_type=prompt_type, stream_output=False, max_new_tokens=50,
222
+ langchain_mode='Disabled')
223
+
224
+
225
+ @pytest.mark.skip(reason="For manual use against some server, no server launched")
226
+ def test_client_chat_stream(prompt_type='human_bot'):
227
+ return run_client_chat(prompt="Tell a very long kid's story about birds.", prompt_type=prompt_type,
228
+ stream_output=True, max_new_tokens=512,
229
  langchain_mode='Disabled')
230
 
231
 
 
238
 
239
 
240
  def run_client(client, prompt, args, kwargs, do_md_to_text=True, verbose=False):
241
+ assert kwargs['chat'], "Chat mode only"
242
  res = client.predict(*tuple(args), api_name='/instruction')
243
  args[-1] += [res[-1]]
244
 
 
272
  return res_dict, client
273
 
274
 
275
+ @pytest.mark.skip(reason="For manual use against some server, no server launched")
276
+ def test_client_nochat_stream(prompt_type='human_bot'):
277
+ return run_client_nochat_gen(prompt="Tell a very long kid's story about birds.", prompt_type=prompt_type,
278
+ stream_output=True, max_new_tokens=512,
279
+ langchain_mode='Disabled')
280
+
281
+
282
+ def run_client_nochat_gen(prompt, prompt_type, stream_output, max_new_tokens, langchain_mode):
283
+ client = get_client(serialize=False)
284
+
285
+ kwargs, args = get_args(prompt, prompt_type, chat=False, stream_output=stream_output,
286
+ max_new_tokens=max_new_tokens, langchain_mode=langchain_mode)
287
+ return run_client_gen(client, prompt, args, kwargs)
288
+
289
+
290
+ def run_client_gen(client, prompt, args, kwargs, do_md_to_text=True, verbose=False):
291
+ res_dict = kwargs
292
+ res_dict['prompt'] = prompt
293
+ if not kwargs['stream_output']:
294
+ res = client.predict(str(dict(kwargs)), api_name='/submit_nochat_api')
295
+ res_dict['response'] = res[0]
296
+ print(md_to_text(res_dict['response'], do_md_to_text=do_md_to_text))
297
+ return res_dict, client
298
+ else:
299
+ job = client.submit(str(dict(kwargs)), api_name='/submit_nochat_api')
300
+ while not job.done():
301
+ outputs_list = job.communicator.job.outputs
302
+ if outputs_list:
303
+ res = job.communicator.job.outputs[-1]
304
+ res_dict = ast.literal_eval(res)
305
+ print('Stream: %s' % res_dict['response'])
306
+ time.sleep(0.1)
307
+ res_list = job.outputs()
308
+ assert len(res_list) > 0, "No response, check server"
309
+ res = res_list[-1]
310
+ res_dict = ast.literal_eval(res)
311
+ print('Final: %s' % res_dict['response'])
312
+ return res_dict, client
313
+
314
+
315
  def md_to_text(md, do_md_to_text=True):
316
  if not do_md_to_text:
317
  return md
 
321
  return soup.get_text()
322
 
323
 
324
+ def run_client_many(prompt_type='human_bot'):
325
+ ret1, _ = test_client_chat(prompt_type=prompt_type)
326
+ ret2, _ = test_client_chat_stream(prompt_type=prompt_type)
327
+ ret3, _ = test_client_nochat_stream(prompt_type=prompt_type)
328
+ ret4, _ = test_client_basic(prompt_type=prompt_type)
329
+ ret5, _ = test_client_basic_api(prompt_type=prompt_type)
330
+ ret6, _ = test_client_basic_api_lean(prompt_type=prompt_type)
331
+ ret7, _ = test_client_basic_api_lean_morestuff(prompt_type=prompt_type)
332
+ return ret1, ret2, ret3, ret4, ret5, ret6, ret7
333
+
334
+
335
  if __name__ == '__main__':
336
+ run_client_many()
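Note: a minimal sketch (not part of this commit) of driving the new /submit_nochat_api endpoint directly with gradio_client, assuming an h2oGPT server is already running at http://localhost:7860 and that omitted generation kwargs fall back to server defaults:

import ast
import time
from gradio_client import Client

client = Client("http://localhost:7860")
kwargs = dict(instruction_nochat="Who are you?", iinput_nochat='',
              prompt_type='human_bot', stream_output=True, max_new_tokens=128)
job = client.submit(str(kwargs), api_name='/submit_nochat_api')
while not job.done():
    time.sleep(0.1)                   # poll until the server-side generator finishes
res = job.outputs()[-1]               # each yield is a str(dict) payload; keep the last
res_dict = ast.literal_eval(res)      # expected to hold 'response' and 'sources'
print(res_dict['response'])
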
 
 
 
create_data.py CHANGED
@@ -567,7 +567,7 @@ def test_show_prompts():
567
  from prompter import generate_prompt
568
  for data_points in file_points:
569
  for data_point in data_points:
570
- print(generate_prompt(data_point, 'plain', '', False, False)[0])
571
 
572
 
573
  def test_get_open_datasets():
@@ -1571,7 +1571,7 @@ def test_check_stats_data():
1571
 
1572
  llama_type = False
1573
  tokenizer_base_model = base_model = 'h2oai/h2ogpt-oasst1-512-20b'
1574
- model_loader, tokenizer_loader = get_loaders(llama_type=llama_type, model_name=base_model, reward_type=False)
1575
  local_files_only = False
1576
  resume_download = True
1577
  use_auth_token = False
 
567
  from prompter import generate_prompt
568
  for data_points in file_points:
569
  for data_point in data_points:
570
+ print(generate_prompt(data_point, 'plain', '', False, False, False)[0])
571
 
572
 
573
  def test_get_open_datasets():
 
1571
 
1572
  llama_type = False
1573
  tokenizer_base_model = base_model = 'h2oai/h2ogpt-oasst1-512-20b'
1574
+ model_loader, tokenizer_loader = get_loaders(model_name=base_model, reward_type=False, llama_type=llama_type)
1575
  local_files_only = False
1576
  resume_download = True
1577
  use_auth_token = False
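As a side note, a hedged sketch of the updated call conventions above (names mirror the repo; exact signatures may differ slightly): get_loaders is now called with keyword arguments so later signature reordering cannot silently break call sites, and generate_prompt takes an extra trailing flag.

from loaders import get_loaders
from prompter import generate_prompt

# keyword arguments guard against positional reordering of get_loaders' signature
model_loader, tokenizer_loader = get_loaders(model_name='h2oai/h2ogpt-oasst1-512-20b',
                                             reward_type=False, llama_type=False)
# generate_prompt now takes a sixth positional flag; index [0] is still the prompt text
data_point = dict(instruction='Who are you?', input='', output='')  # hypothetical minimal record
prompt = generate_prompt(data_point, 'plain', '', False, False, False)[0]
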
enums.py CHANGED
@@ -22,6 +22,12 @@ class PromptType(Enum):
22
  wizard2 = 16
23
  wizard3 = 17
24
  instruct_simple = 18
 
 
25
 
26
 
27
  class DocumentChoices(Enum):
@@ -44,3 +50,35 @@ class LangChainMode(Enum):
44
  MY_DATA = "MyData"
45
  GITHUB_H2OGPT = "github h2oGPT"
46
  H2O_DAI_DOCS = "DriverlessAI docs"
 
 
22
  wizard2 = 16
23
  wizard3 = 17
24
  instruct_simple = 18
25
+ wizard_vicuna = 19
26
+ openai = 20
27
+ openai_chat = 21
28
+ gptj = 22
29
+ prompt_answer_openllama = 23
30
+ vicuna11 = 24
31
 
32
 
33
  class DocumentChoices(Enum):
 
50
  MY_DATA = "MyData"
51
  GITHUB_H2OGPT = "github h2oGPT"
52
  H2O_DAI_DOCS = "DriverlessAI docs"
53
+
54
+
55
+ no_server_str = no_lora_str = no_model_str = '[None/Remove]'
56
+
57
+
58
+ # from site-packages/langchain/llms/openai.py, but needed since ChatOpenAI doesn't have this information
59
+ model_token_mapping = {
60
+ "gpt-4": 8192,
61
+ "gpt-4-0314": 8192,
62
+ "gpt-4-32k": 32768,
63
+ "gpt-4-32k-0314": 32768,
64
+ "gpt-3.5-turbo": 4096,
65
+ "gpt-3.5-turbo-16k": 16*1024,
66
+ "gpt-3.5-turbo-0301": 4096,
67
+ "text-ada-001": 2049,
68
+ "ada": 2049,
69
+ "text-babbage-001": 2040,
70
+ "babbage": 2049,
71
+ "text-curie-001": 2049,
72
+ "curie": 2049,
73
+ "davinci": 2049,
74
+ "text-davinci-003": 4097,
75
+ "text-davinci-002": 4097,
76
+ "code-davinci-002": 8001,
77
+ "code-davinci-001": 8001,
78
+ "code-cushman-002": 2048,
79
+ "code-cushman-001": 2048,
80
+ }
81
+
82
+
83
+ source_prefix = "Sources [Score | Link]:"
84
+ source_postfix = "End Sources<p>"
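For illustration, a small sketch (not part of the commit) of how the new model_token_mapping table can be used to budget generation against an OpenAI model's context window; the helper below is hypothetical:

from enums import model_token_mapping

def max_new_tokens_for(model_name, prompt_tokens, floor=256):
    # fall back to a conservative window for models not in the table
    context_len = model_token_mapping.get(model_name, 2048)
    return max(floor, context_len - prompt_tokens)

print(max_new_tokens_for("gpt-3.5-turbo", prompt_tokens=1500))  # 4096 - 1500 = 2596
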
finetune.py CHANGED
@@ -5,6 +5,9 @@ from typing import List, Union
5
  import fire
6
  import numpy as np
7
 
 
 
 
8
  from loaders import get_loaders, get_tokenizer
9
  from prompter import generate_prompt, prompt_types, PromptType
10
  from utils import get_githash, copy_code
@@ -182,7 +185,7 @@ def train(
182
  log("num_gpus: %d" % gpus)
183
  log("max mem: %s" % max_memory)
184
 
185
- model_loader, tokenizer_loader = get_loaders(llama_type=llama_type, model_name=base_model, reward_type=False)
186
 
187
  model = model_loader.from_pretrained(
188
  base_model,
@@ -556,13 +559,6 @@ def train(
556
  )
557
  model.config.use_cache = False
558
 
559
- old_state_dict = model.state_dict
560
- from peft import get_peft_model_state_dict
561
-
562
- model.state_dict = (
563
- lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
564
- ).__get__(model, type(model))
565
-
566
  if torch.__version__ >= "2" and sys.platform != "win32":
567
  model = torch.compile(model)
568
  # WIP (not generally replacing layers until pytorch 2.1)
@@ -621,10 +617,10 @@ def generate_and_tokenize_prompt(data_point, prompt_type=None, train_on_inputs=F
621
  assert tokenizer is not None
622
  prompt_dict = '' # only for custom prompt_type
623
  assert prompt_type != PromptType.custom.name, "custom not setup for finetune"
624
- full_prompt, _, _, _ = generate_prompt(data_point, prompt_type, prompt_dict, False, False)
625
  tokenized_full_prompt = tokenize(full_prompt, tokenizer, cutoff_len, add_eos_token=add_eos_token)
626
  if not train_on_inputs:
627
- user_prompt, _, _, _ = generate_prompt({**data_point, "output": ""}, prompt_type, prompt_dict, False, False)
628
  tokenized_user_prompt = tokenize(user_prompt, tokenizer, cutoff_len, add_eos_token=add_eos_token)
629
  user_prompt_len = len(tokenized_user_prompt["input_ids"])
630
  if add_eos_token:
@@ -643,7 +639,7 @@ def test_debug():
643
  fire.Fire(train)
644
 
645
 
646
- if __name__ == "__main__":
647
  CONFIG = "NCCL_P2P_LEVEL=LOC WORLD_SIZE=5 torchrun --nnodes=5 --master_addr=10.10.10.2 --master_port=1111 --nproc_per_node=1"
648
  CMD = "finetune.py --data_path=config.json --num_epochs=1 --base_model=decapoda-research/llama-13b-hf"
649
  log(f"""
@@ -674,3 +670,7 @@ NCCL_P2P_LEVEL=LOC WORLD_SIZE=7 CUDA_VISIBLE_DEVICES="0,1" torchrun --node_rank
674
  "CUDA_VISIBLE_DEVICES") is not None, "Run python script using: torchrun finetune.py OR set CUDA_VISIBLE_DEVICES to single GPU"
675
 
676
  fire.Fire(train)
 
 
5
  import fire
6
  import numpy as np
7
 
8
+ if os.path.dirname(os.path.abspath(__file__)) not in sys.path:
9
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
10
+
11
  from loaders import get_loaders, get_tokenizer
12
  from prompter import generate_prompt, prompt_types, PromptType
13
  from utils import get_githash, copy_code
 
185
  log("num_gpus: %d" % gpus)
186
  log("max mem: %s" % max_memory)
187
 
188
+ model_loader, tokenizer_loader = get_loaders(model_name=base_model, reward_type=False, llama_type=llama_type)
189
 
190
  model = model_loader.from_pretrained(
191
  base_model,
 
559
  )
560
  model.config.use_cache = False
561
 
 
 
562
  if torch.__version__ >= "2" and sys.platform != "win32":
563
  model = torch.compile(model)
564
  # WIP (not generally replacing layers until pytorch 2.1)
 
617
  assert tokenizer is not None
618
  prompt_dict = '' # only for custom prompt_type
619
  assert prompt_type != PromptType.custom.name, "custom not setup for finetune"
620
+ full_prompt, _, _, _, _ = generate_prompt(data_point, prompt_type, prompt_dict, False, False, False)
621
  tokenized_full_prompt = tokenize(full_prompt, tokenizer, cutoff_len, add_eos_token=add_eos_token)
622
  if not train_on_inputs:
623
+ user_prompt, _, _, _, _ = generate_prompt({**data_point, "output": ""}, prompt_type, prompt_dict, False, False, False)
624
  tokenized_user_prompt = tokenize(user_prompt, tokenizer, cutoff_len, add_eos_token=add_eos_token)
625
  user_prompt_len = len(tokenized_user_prompt["input_ids"])
626
  if add_eos_token:
 
639
  fire.Fire(train)
640
 
641
 
642
+ def entrypoint_main():
643
  CONFIG = "NCCL_P2P_LEVEL=LOC WORLD_SIZE=5 torchrun --nnodes=5 --master_addr=10.10.10.2 --master_port=1111 --nproc_per_node=1"
644
  CMD = "finetune.py --data_path=config.json --num_epochs=1 --base_model=decapoda-research/llama-13b-hf"
645
  log(f"""
 
670
  "CUDA_VISIBLE_DEVICES") is not None, "Run python script using: torchrun finetune.py OR set CUDA_VISIBLE_DEVICES to single GPU"
671
 
672
  fire.Fire(train)
673
+
674
+
675
+ if __name__ == "__main__":
676
+ entrypoint_main()
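With the entrypoint_main() wrapper above, finetune.py can now be imported without immediately invoking fire.Fire(train); a hedged usage sketch (argument names taken from the CMD string above, values are placeholders):

import finetune

# programmatic fine-tuning instead of the CLI; other train() arguments keep their defaults
finetune.train(base_model='decapoda-research/llama-13b-hf',
               data_path='config.json',
               num_epochs=1)
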
generate.py CHANGED
@@ -1,27 +1,37 @@
1
  import ast
 
2
  import functools
3
  import glob
4
  import inspect
5
  import queue
6
- import shutil
7
  import sys
8
  import os
9
  import time
10
  import traceback
 
11
  import typing
12
  import warnings
13
  from datetime import datetime
14
  import filelock
 
15
  import psutil
 
 
 
16
 
17
  os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
18
  os.environ['BITSANDBYTES_NOWELCOME'] = '1'
19
  warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
20
 
21
- from enums import DocumentChoices, LangChainMode
 
22
  from loaders import get_loaders
23
  from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, wrapped_partial, EThread, get_githash, \
24
- import_matplotlib, get_device, makedirs, get_kwargs, start_faulthandler
25
 
26
  start_faulthandler()
27
  import_matplotlib()
@@ -34,9 +44,8 @@ from typing import Union
34
  import fire
35
  import torch
36
  from transformers import GenerationConfig, AutoModel, TextIteratorStreamer
37
- from accelerate import init_empty_weights, infer_auto_device_map
38
 
39
- from prompter import Prompter, inv_prompt_type_to_model_lower, non_hf_types, PromptType, get_prompt
40
  from stopping import get_stopping
41
 
42
  eval_extra_columns = ['prompt', 'response', 'score']
@@ -56,9 +65,15 @@ def main(
56
  lora_weights: str = "",
57
  gpu_id: int = 0,
58
  compile_model: bool = True,
59
-
 
60
  prompt_type: Union[int, str] = None,
61
  prompt_dict: typing.Dict = None,
 
 
62
  # input to generation
63
  temperature: float = None,
64
  top_p: float = None,
@@ -88,14 +103,13 @@ def main(
88
  cli: bool = False,
89
  cli_loop: bool = True,
90
  gradio: bool = True,
91
- gradio_avoid_processing_markdown: bool = False,
92
  gradio_offline_level: int = 0,
93
  chat: bool = True,
94
  chat_context: bool = False,
95
  stream_output: bool = True,
96
  show_examples: bool = None,
97
  verbose: bool = False,
98
- h2ocolors: bool = False,
99
  height: int = 600,
100
  show_lora: bool = True,
101
  login_mode_if_model0: bool = False,
@@ -104,16 +118,19 @@ def main(
104
  api_open: bool = False,
105
  allow_api: bool = True,
106
  input_lines: int = 1,
 
107
  auth: typing.List[typing.Tuple[str, str]] = None,
 
 
108
 
109
- sanitize_user_prompt: bool = True,
110
- sanitize_bot_response: bool = True,
111
 
112
  extra_model_options: typing.List[str] = [],
113
  extra_lora_options: typing.List[str] = [],
 
114
 
115
  score_model: str = 'OpenAssistant/reward-model-deberta-v3-large-v2',
116
- auto_score: bool = True,
117
 
118
  eval_filename: str = None,
119
  eval_prompts_only_num: int = 0,
@@ -121,6 +138,7 @@ def main(
121
  eval_as_output: bool = False,
122
 
123
  langchain_mode: str = 'Disabled',
 
124
  visible_langchain_modes: list = ['UserData', 'MyData'],
125
  document_choice: list = [DocumentChoices.All_Relevant.name],
126
  user_path: str = None,
@@ -138,7 +156,10 @@ def main(
138
  enable_sources_list: bool = True,
139
  chunk: bool = True,
140
  chunk_size: int = 512,
141
- top_k_docs: int = 3, # FIXME: Can go back to 4 once https://github.com/h2oai/h2ogpt/issues/192 fixed
 
 
 
142
  n_jobs: int = -1,
143
  enable_captions: bool = True,
144
  captions_model: str = "Salesforce/blip-image-captioning-base",
@@ -157,8 +178,31 @@ def main(
157
  :param lora_weights: LORA weights path/HF link
158
  :param gpu_id: if infer_devices, then use gpu_id for cuda device ID, or auto mode if gpu_id != -1
159
  :param compile_model Whether to compile the model
 
 
160
  :param prompt_type: type of prompt, usually matched to fine-tuned model or plain for foundational model
161
  :param prompt_dict: If prompt_type=custom, then expects (some) items returned by get_prompt(..., return_dict=True)
 
 
162
  :param temperature: generation temperature
163
  :param top_p: generation top_p
164
  :param top_k: generation top_k
@@ -184,13 +228,13 @@ def main(
184
  :param cli: whether to use CLI (non-gradio) interface.
185
  :param cli_loop: whether to loop for CLI (False usually only for testing)
186
  :param gradio: whether to enable gradio, or to enable benchmark mode
187
- :param gradio_avoid_processing_markdown:
188
  :param gradio_offline_level: > 0, then change fonts so full offline
189
  == 1 means backend won't need internet for fonts, but front-end UI might if font not cached
190
  == 2 means backend and frontend don't need internet to download any fonts.
191
  Note: Some things always disabled include HF telemetry, gradio telemetry, chromadb posthog that involve uploading.
192
  This option further disables google fonts for downloading, which is less intrusive than uploading,
193
  but still required in air-gapped case. The fonts don't look as nice as google fonts, but ensure full offline behavior.
 
194
  :param chat: whether to enable chat mode with chat history
195
  :param chat_context: whether to use extra helpful context if human_bot
196
  :param stream_output: whether to stream output from generate
@@ -205,20 +249,25 @@ def main(
205
  :param api_open: If False, don't let API calls skip gradio queue
206
  :param allow_api: whether to allow API calls at all to gradio server
207
  :param input_lines: how many input lines to show for chat box (>1 forces shift-enter for submit, else enter is submit)
 
 
208
  :param auth: gradio auth for launcher in form [(user1, pass1), (user2, pass2), ...]
209
  e.g. --auth=[('jon','password')] with no spaces
210
- :param sanitize_user_prompt: whether to remove profanity from user input
211
- :param sanitize_bot_response: whether to remove profanity and repeat lines from bot output
 
 
212
  :param extra_model_options: extra models to show in list in gradio
213
  :param extra_lora_options: extra LORA to show in list in gradio
 
214
  :param score_model: which model to score responses (None means no scoring)
215
- :param auto_score: whether to automatically score responses
216
  :param eval_filename: json file to use for evaluation, if None is sharegpt
217
  :param eval_prompts_only_num: for no gradio benchmark, if using eval_filename prompts for eval instead of examples
218
  :param eval_prompts_only_seed: for no gradio benchmark, seed for eval_filename sampling
219
  :param eval_as_output: for no gradio benchmark, whether to test eval_filename output itself
220
  :param langchain_mode: Data source to include. Choose "UserData" to only consume files from make_db.py.
221
  WARNING: wiki_full requires extra data processing via read_wiki_full.py and requires really good workstation to generate db, unless already present.
 
222
  :param user_path: user path to glob from to generate db for vector search, for 'UserData' langchain mode.
223
  If already have db, any new/changed files are added automatically if path set, does not have to be same path used for prior db sources
224
  :param detect_user_path_changes_every_query: whether to detect if any files changed or added every similarity search (by file hashes).
@@ -248,12 +297,17 @@ def main(
248
  :param chunk: Whether to chunk data (True unless know data is already optimally chunked)
249
  :param chunk_size: Size of chunks, with typically top-4 passed to LLM, so needs to be in context length
250
  :param top_k_docs: number of chunks to give LLM
 
 
251
  :param n_jobs: Number of processors to use when consuming documents (-1 = all, is default)
252
  :param enable_captions: Whether to support captions using BLIP for image files as documents, then preloads that model
253
  :param captions_model: Which model to use for captions.
254
- captions_model: int = "Salesforce/blip-image-captioning-base", # continue capable
255
  captions_model: str = "Salesforce/blip2-flan-t5-xl", # question/answer capable, 16GB state
256
- captions_model: int = "Salesforce/blip2-flan-t5-xxl", # question/answer capable, 60GB state
257
  Note: opt-based blip2 are not permissive license due to opt and Meta license restrictions
258
  :param pre_load_caption_model: Whether to preload caption model, or load after forking parallel doc loader
259
  parallel loading disabled if preload and have images, to prevent deadlocking on cuda context
@@ -262,6 +316,32 @@ def main(
262
  :param enable_ocr: Whether to support OCR on images
263
  :return:
264
  """
 
 
265
  is_hf = bool(int(os.getenv("HUGGINGFACE_SPACES", '0')))
266
  is_gpth2oai = bool(int(os.getenv("GPT_H2O_AI", '0')))
267
  is_public = is_hf or is_gpth2oai # multi-user case with fixed model and disclaimer
@@ -276,7 +356,8 @@ def main(
276
 
277
  # allow set token directly
278
  use_auth_token = os.environ.get("HUGGINGFACE_API_TOKEN", use_auth_token)
279
- allow_upload_to_user_data = bool(int(os.environ.get("allow_upload_to_user_data", str(int(allow_upload_to_user_data)))))
 
280
  allow_upload_to_my_data = bool(int(os.environ.get("allow_upload_to_my_data", str(int(allow_upload_to_my_data)))))
281
  height = int(os.environ.get("HEIGHT", height))
282
  h2ocolors = bool(int(os.getenv('h2ocolors', h2ocolors)))
@@ -289,6 +370,12 @@ def main(
289
  if langchain_mode not in visible_langchain_modes and langchain_mode in langchain_modes:
290
  visible_langchain_modes += [langchain_mode]
291
 
 
292
  if is_public:
293
  allow_upload_to_user_data = False
294
  input_lines = 1 # ensure set, for ease of use
@@ -297,26 +384,50 @@ def main(
297
  top_k = 70 if top_k is None else top_k
298
  if is_hf:
299
  do_sample = True if do_sample is None else do_sample
 
300
  else:
301
  # by default don't sample, too chatty
302
  do_sample = False if do_sample is None else do_sample
 
303
 
304
  if memory_restriction_level == 2:
305
- if not base_model:
306
  base_model = 'h2oai/h2ogpt-oasst1-512-12b'
307
  # don't set load_8bit if passed base_model, doesn't always work so can't just override
308
  load_8bit = True
309
  load_4bit = False # FIXME - consider using 4-bit instead of 8-bit
310
- else:
311
- base_model = 'h2oai/h2ogpt-oasst1-512-20b' if not base_model else base_model
312
  if memory_restriction_level >= 2:
313
  load_8bit = True
314
  load_4bit = False # FIXME - consider using 4-bit instead of 8-bit
315
  if hf_embedding_model is None:
316
  hf_embedding_model = "sentence-transformers/all-MiniLM-L6-v2"
 
 
317
  if is_hf:
318
  # must override share if in spaces
319
  share = False
 
 
320
  save_dir = os.getenv('SAVE_DIR', save_dir)
321
  score_model = os.getenv('SCORE_MODEL', score_model)
322
  if score_model == 'None' or score_model is None:
@@ -335,7 +446,7 @@ def main(
335
  torch.backends.cudnn.benchmark = True
336
  torch.backends.cudnn.enabled = False
337
  torch.set_default_dtype(torch.float32)
338
- if psutil.virtual_memory().available < 94 * 1024 ** 3:
339
  # 12B uses ~94GB
340
  # 6.9B uses ~47GB
341
  base_model = 'h2oai/h2ogpt-oig-oasst1-512-6_9b' if not base_model else base_model
@@ -360,8 +471,8 @@ def main(
360
 
361
  if offload_folder:
362
  makedirs(offload_folder)
363
-
364
- user_set_max_new_tokens = max_new_tokens is not None
365
 
366
  placeholder_instruction, placeholder_input, \
367
  stream_output, show_examples, \
@@ -386,11 +497,12 @@ def main(
386
  verbose,
387
  )
388
 
 
389
  locals_dict = locals()
390
  locals_print = '\n'.join(['%s: %s' % (k, v) for k, v in locals_dict.items()])
391
  if verbose:
392
  print(f"Generating model with params:\n{locals_print}", flush=True)
393
- print("Command: %s\nHash: %s" % (str(' '.join(sys.argv)), get_githash()), flush=True)
394
 
395
  if langchain_mode != "Disabled":
396
  # SECOND PLACE where LangChain referenced, but all imports are kept local so not required
@@ -404,7 +516,7 @@ def main(
404
  for gpath1 in glob.glob(os.path.join(scratch_base_dir, 'db_dir_%s*' % langchain_mode1)):
405
  if os.path.isdir(gpath1):
406
  print("Removing old MyData: %s" % gpath1, flush=True)
407
- shutil.rmtree(gpath1)
408
  continue
409
  if langchain_mode1 in ['All']:
410
  # FIXME: All should be avoided until scans over each db, shouldn't be separate db
@@ -430,6 +542,10 @@ def main(
430
  assert 'gpt_langchain' not in sys.modules, "Dev bug, import of langchain when should not have"
431
  assert 'langchain' not in sys.modules, "Dev bug, import of langchain when should not have"
432
 
 
 
 
 
433
  if cli:
434
  from cli import run_cli
435
  return run_cli(**get_kwargs(run_cli, exclude_names=['model_state0'], **locals()))
@@ -441,20 +557,68 @@ def main(
441
  from gradio_runner import go_gradio
442
 
443
  # get default model
444
- all_kwargs = locals().copy()
445
- if all_kwargs.get('base_model') and not all_kwargs['login_mode_if_model0']:
446
- model0, tokenizer0, device = get_model(reward_type=False,
447
- **get_kwargs(get_model, exclude_names=['reward_type'], **all_kwargs))
448
- else:
449
- # if empty model, then don't load anything, just get gradio up
450
- model0, tokenizer0, device = None, None, None
451
- model_state0 = [model0, tokenizer0, device, all_kwargs['base_model']]
 
 
 
 
452
 
453
  # get score model
 
454
  smodel, stokenizer, sdevice = get_score_model(reward_type=True,
455
  **get_kwargs(get_score_model, exclude_names=['reward_type'],
456
  **all_kwargs))
457
- score_model_state0 = [smodel, stokenizer, sdevice, score_model]
 
 
458
 
459
  if enable_captions:
460
  if pre_load_caption_model:
@@ -469,34 +633,33 @@ def main(
469
  go_gradio(**locals())
470
 
471
 
472
- def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type,
473
- gpu_id=0,
474
- use_auth_token=False,
475
- trust_remote_code=True,
476
- offload_folder=None,
477
- triton_attn=False,
478
- long_sequence=True,
479
- ):
480
- """
481
- Ensure model gets on correct device
482
- :param base_model:
483
- :param model_loader:
484
- :param load_half:
485
- :param model_kwargs:
486
- :param reward_type:
487
- :param gpu_id:
488
- :param use_auth_token:
489
- :param trust_remote_code:
490
- :param offload_folder:
491
- :param triton_attn:
492
- :param long_sequence:
493
- :return:
494
- """
495
  with init_empty_weights():
496
  from transformers import AutoConfig
497
- config = AutoConfig.from_pretrained(base_model, use_auth_token=use_auth_token,
498
- trust_remote_code=trust_remote_code,
499
- offload_folder=offload_folder)
 
 
500
  if triton_attn and 'mpt-' in base_model.lower():
501
  config.attn_config['attn_impl'] = 'triton'
502
  if long_sequence:
@@ -504,18 +667,36 @@ def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward
504
  config.update({"max_seq_len": 83968})
505
  if 'mosaicml/mpt-7b-chat' in base_model.lower():
506
  config.update({"max_seq_len": 4096})
507
- if issubclass(config.__class__, tuple(AutoModel._model_mapping.keys())):
 
 
 
508
  model = AutoModel.from_config(
509
  config,
 
510
  )
511
  else:
512
  # can't infer
513
  model = None
 
 
514
 
515
  if model is not None:
516
  # NOTE: Can specify max_memory={0: max_mem, 1: max_mem}, to shard model
517
  # NOTE: Some models require avoiding sharding some layers,
518
  # then would pass no_split_module_classes and give list of those layers.
 
519
  device_map = infer_auto_device_map(
520
  model,
521
  dtype=torch.float16 if load_half else torch.float32,
@@ -567,12 +748,59 @@ def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward
567
  return model
568
 
569
 
 
 
570
  def get_model(
571
  load_8bit: bool = False,
572
  load_4bit: bool = False,
573
  load_half: bool = True,
574
  infer_devices: bool = True,
575
  base_model: str = '',
 
576
  tokenizer_base_model: str = '',
577
  lora_weights: str = "",
578
  gpu_id: int = 0,
@@ -596,6 +824,7 @@ def get_model(
596
  For non-LORA case, False will spread shards across multiple GPUs, but this can lead to cuda:x cuda:y mismatches
597
  So it is not the default
598
  :param base_model: name/path of base model
 
599
  :param tokenizer_base_model: name/path of tokenizer
600
  :param lora_weights: name/path
601
  :param gpu_id: which GPU (0..n_gpus-1) or allow all GPUs if relevant (-1)
@@ -611,11 +840,120 @@ def get_model(
611
  """
612
  if verbose:
613
  print("Get %s model" % base_model, flush=True)
 
 
 
 
614
  if base_model in non_hf_types:
615
  from gpt4all_llm import get_model_tokenizer_gpt4all
616
  model, tokenizer, device = get_model_tokenizer_gpt4all(base_model)
617
  return model, tokenizer, device
618
 
 
 
 
619
  if lora_weights is not None and lora_weights.strip():
620
  if verbose:
621
  print("Get %s lora weights" % lora_weights, flush=True)
@@ -630,30 +968,13 @@ def get_model(
630
  "Please choose a base model with --base_model (CLI) or load one from Models Tab (gradio)"
631
  )
632
 
633
- from transformers import AutoConfig
634
- config = AutoConfig.from_pretrained(base_model, use_auth_token=use_auth_token,
635
- trust_remote_code=trust_remote_code,
636
- offload_folder=offload_folder)
637
- llama_type_from_config = 'llama' in str(config).lower()
638
- llama_type_from_name = "llama" in base_model.lower()
639
- llama_type = llama_type_from_config or llama_type_from_name
640
- if llama_type:
641
- if verbose:
642
- print("Detected as llama type from"
643
- " config (%s) or name (%s)" % (llama_type_from_config, llama_type_from_name), flush=True)
644
 
645
- model_loader, tokenizer_loader = get_loaders(llama_type=llama_type, model_name=base_model, reward_type=reward_type)
646
- if not tokenizer_base_model:
647
- tokenizer_base_model = base_model
648
 
649
  if tokenizer_loader is not None and not isinstance(tokenizer_loader, str):
650
  tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model,
651
- local_files_only=local_files_only,
652
- resume_download=resume_download,
653
- use_auth_token=use_auth_token,
654
- trust_remote_code=trust_remote_code,
655
- offload_folder=offload_folder,
656
- )
657
  else:
658
  tokenizer = tokenizer_loader
659
 
@@ -677,7 +998,7 @@ def get_model(
677
  load_in_4bit=load_4bit,
678
  device_map={"": 0} if (load_8bit or load_4bit) and device == 'cuda' else "auto",
679
  ))
680
- if 'mpt-' in base_model.lower() and gpu_id >= 0:
681
  model_kwargs.update(dict(device_map={"": gpu_id} if device == 'cuda' else "cpu"))
682
 
683
  if 'OpenAssistant/reward-model'.lower() in base_model.lower():
@@ -688,25 +1009,30 @@ def get_model(
688
 
689
  if not lora_weights:
690
  with torch.device(device):
 
691
  if infer_devices:
 
692
  model = get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type,
 
693
  gpu_id=gpu_id,
694
- use_auth_token=use_auth_token,
695
- trust_remote_code=trust_remote_code,
696
- offload_folder=offload_folder,
697
  )
698
  else:
 
699
  if load_half and not (load_8bit or load_4bit):
700
  model = model_loader.from_pretrained(
701
  base_model,
 
702
  **model_kwargs).half()
703
  else:
704
  model = model_loader.from_pretrained(
705
  base_model,
 
706
  **model_kwargs)
707
  elif load_8bit or load_4bit:
 
708
  model = model_loader.from_pretrained(
709
  base_model,
 
710
  **model_kwargs
711
  )
712
  from peft import PeftModel # loads cuda, so avoid in global scope
@@ -723,8 +1049,10 @@ def get_model(
723
  )
724
  else:
725
  with torch.device(device):
 
726
  model = model_loader.from_pretrained(
727
  base_model,
 
728
  **model_kwargs
729
  )
730
  from peft import PeftModel # loads cuda, so avoid in global scope
@@ -758,6 +1086,15 @@ def get_model(
758
  if torch.__version__ >= "2" and sys.platform != "win32" and compile_model:
759
  model = torch.compile(model)
760
 
 
 
761
  if hasattr(config, 'max_seq_len') and isinstance(config.max_seq_len, int):
762
  tokenizer.model_max_length = config.max_seq_len
763
  elif hasattr(config, 'max_position_embeddings') and isinstance(config.max_position_embeddings, int):
@@ -767,8 +1104,9 @@ def get_model(
767
  if verbose:
768
  print("Could not determine model_max_length, setting to 2048", flush=True)
769
  tokenizer.model_max_length = 2048
770
-
771
- return model, tokenizer, device
 
772
 
773
 
774
  def pop_unused_model_kwargs(model_kwargs):
@@ -790,6 +1128,7 @@ def get_score_model(score_model: str = None,
790
  load_half: bool = True,
791
  infer_devices: bool = True,
792
  base_model: str = '',
 
793
  tokenizer_base_model: str = '',
794
  lora_weights: str = "",
795
  gpu_id: int = 0,
@@ -811,6 +1150,7 @@ def get_score_model(score_model: str = None,
811
  base_model = score_model.strip()
812
  tokenizer_base_model = ''
813
  lora_weights = ''
 
814
  llama_type = False
815
  compile_model = False
816
  smodel, stokenizer, sdevice = get_model(reward_type=True,
@@ -877,9 +1217,12 @@ def evaluate_from_str(
877
  debug=False,
878
  concurrency_count=None,
879
  save_dir=None,
880
- sanitize_bot_response=True,
881
  model_state0=None,
882
  memory_restriction_level=None,
 
 
 
883
  raise_generate_gpu_exceptions=None,
884
  chat_context=None,
885
  lora_weights=None,
@@ -890,20 +1233,26 @@ def evaluate_from_str(
890
  use_openai_embedding=None,
891
  use_openai_model=None,
892
  hf_embedding_model=None,
893
- chunk=None,
894
- chunk_size=None,
895
  db_type=None,
896
  n_jobs=None,
897
  first_para=None,
898
  text_limit=None,
899
  verbose=False,
900
  cli=False,
 
 
901
  ):
902
  if isinstance(user_kwargs, str):
903
  user_kwargs = ast.literal_eval(user_kwargs)
904
  # only used for submit_nochat_api
905
  user_kwargs['chat'] = False
906
- user_kwargs['stream_output'] = False
 
907
  if 'langchain_mode' not in user_kwargs:
908
  # if user doesn't specify, then assume disabled, not use default
909
  user_kwargs['langchain_mode'] = 'Disabled'
@@ -926,6 +1275,9 @@ def evaluate_from_str(
926
  sanitize_bot_response=sanitize_bot_response,
927
  model_state0=model_state0,
928
  memory_restriction_level=memory_restriction_level,
 
 
 
929
  raise_generate_gpu_exceptions=raise_generate_gpu_exceptions,
930
  chat_context=chat_context,
931
  lora_weights=lora_weights,
@@ -942,6 +1294,13 @@ def evaluate_from_str(
942
  text_limit=text_limit,
943
  verbose=verbose,
944
  cli=cli,
 
 
945
  )
946
  try:
947
  for ret1 in ret:
@@ -986,9 +1345,12 @@ def evaluate(
986
  debug=False,
987
  concurrency_count=None,
988
  save_dir=None,
989
- sanitize_bot_response=True,
990
  model_state0=None,
991
  memory_restriction_level=None,
 
 
 
992
  raise_generate_gpu_exceptions=None,
993
  chat_context=None,
994
  lora_weights=None,
@@ -1005,6 +1367,13 @@ def evaluate(
1005
  text_limit=None,
1006
  verbose=False,
1007
  cli=False,
 
 
1008
  ):
1009
  # ensure passed these
1010
  assert concurrency_count is not None
@@ -1025,29 +1394,58 @@ def evaluate(
1025
  locals_dict = locals().copy()
1026
  locals_dict.pop('model_state', None)
1027
  locals_dict.pop('model_state0', None)
 
1028
  print(locals_dict)
1029
 
1030
- no_model_msg = "Please choose a base model with --base_model (CLI) or load in Models Tab (gradio).\nThen start New Conversation"
 
1031
 
 
 
1032
  if model_state0 is None:
1033
  # e.g. for no gradio case, set dummy value, else should be set
1034
- model_state0 = [None, None, None, None]
1035
-
1036
- if model_state is not None and len(model_state) == 4 and not isinstance(model_state[0], str):
1037
- # try to free-up original model (i.e. list was passed as reference)
1038
- if model_state0 is not None and model_state0[0] is not None:
1039
- model_state0[0].cpu()
1040
- model_state0[0] = None
1041
- # try to free-up original tokenizer (i.e. list was passed as reference)
1042
- if model_state0 is not None and model_state0[1] is not None:
1043
- model_state0[1] = None
1044
- clear_torch_cache()
1045
- model, tokenizer, device, base_model = model_state
1046
- elif model_state0 is not None and len(model_state0) == 4 and model_state0[0] is not None:
1047
- assert isinstance(model_state[0], str)
1048
- model, tokenizer, device, base_model = model_state0
 
 
1049
  else:
1050
  raise AssertionError(no_model_msg)
 
 
1051
 
1052
  if base_model is None:
1053
  raise AssertionError(no_model_msg)
@@ -1061,10 +1459,48 @@ def evaluate(
1061
  instruction = instruction_nochat
1062
  iinput = iinput_nochat
1063
 
 
 
1064
  if not context:
1065
  # get hidden context if have one
1066
  context = get_context(chat_context, prompt_type)
1067
 
 
 
1068
  prompter = Prompter(prompt_type, prompt_dict, debug=debug, chat=chat, stream_output=stream_output)
1069
  data_point = dict(context=context, instruction=instruction, input=iinput)
1070
  prompt = prompter.generate_prompt(data_point)
@@ -1077,13 +1513,30 @@ def evaluate(
1077
  db1 = dbs[langchain_mode]
1078
  else:
1079
  db1 = None
1080
- if langchain_mode not in [False, 'Disabled', 'ChatLLM', 'LLM'] and db1 is not None or base_model in non_hf_types:
 
 
 
 
1081
  query = instruction if not iinput else "%s\n%s" % (instruction, iinput)
1082
  outr = ""
1083
  # use smaller cut_distanct for wiki_full since so many matches could be obtained, and often irrelevant unless close
1084
  from gpt_langchain import run_qa_db
 
 
1085
  for r in run_qa_db(query=query,
1086
  model_name=base_model, model=model, tokenizer=tokenizer,
 
1087
  stream_output=stream_output,
1088
  prompter=prompter,
1089
  load_db_if_exists=load_db_if_exists,
@@ -1103,29 +1556,31 @@ def evaluate(
1103
  db_type=db_type,
1104
  top_k_docs=top_k_docs,
1105
 
1106
- # gen_hyper:
1107
- do_sample=do_sample,
1108
- temperature=temperature,
1109
- repetition_penalty=repetition_penalty,
1110
- top_k=top_k,
1111
- top_p=top_p,
1112
- num_beams=num_beams,
1113
- min_new_tokens=min_new_tokens,
1114
- max_new_tokens=max_new_tokens,
1115
- early_stopping=early_stopping,
1116
- max_time=max_time,
1117
- num_return_sequences=num_return_sequences,
1118
 
1119
  prompt_type=prompt_type,
1120
  prompt_dict=prompt_dict,
1121
  n_jobs=n_jobs,
1122
  verbose=verbose,
1123
  cli=cli,
 
 
1124
  ):
1125
  outr, extra = r # doesn't accumulate, new answer every yield, so only save that full answer
1126
  yield dict(response=outr, sources=extra)
1127
  if save_dir:
1128
- save_generate_output(output=outr, base_model=base_model, save_dir=save_dir)
 
 
1129
  if verbose:
1130
  print(
1131
  'Post-Generate Langchain: %s decoded_output: %s' % (str(datetime.now()), len(outr) if outr else -1),
@@ -1138,6 +1593,266 @@ def evaluate(
1138
  clear_torch_cache()
1139
  return
1140
 
 
 
 
 
1141
  if isinstance(tokenizer, str):
1142
  # pipeline
1143
  if tokenizer == "summarization":
@@ -1151,37 +1866,37 @@ def evaluate(
1151
  assert src_lang is not None
1152
  tokenizer.src_lang = languages_covered()[src_lang]
1153
 
1154
- if chat:
1155
- # override, ignore user change
1156
- num_return_sequences = 1
1157
- stopping_criteria = get_stopping(prompt_type, prompt_dict, tokenizer, device)
1158
- _, _, max_length_tokenize, max_prompt_length = get_cutoffs(memory_restriction_level,
1159
- model_max_length=tokenizer.model_max_length)
1160
- prompt = prompt[-max_prompt_length:]
1161
- inputs = tokenizer(prompt,
1162
- return_tensors="pt",
1163
- truncation=True,
1164
- max_length=max_length_tokenize)
1165
- if inputs['input_ids'].shape[1] >= max_length_tokenize - 1:
1166
- print("Cutting off input: %s %s" % (inputs['input_ids'].shape[1], max_length_tokenize), flush=True)
1167
  if debug and len(inputs["input_ids"]) > 0:
1168
  print('input_ids length', len(inputs["input_ids"][0]), flush=True)
1169
  input_ids = inputs["input_ids"].to(device)
1170
  # CRITICAL LIMIT else will fail
1171
  max_max_tokens = tokenizer.model_max_length
1172
- max_input_tokens = max_max_tokens - max_new_tokens
 
1173
  input_ids = input_ids[:, -max_input_tokens:]
1174
- generation_config = GenerationConfig(
1175
- temperature=float(temperature),
1176
- top_p=float(top_p),
1177
- top_k=top_k,
1178
- num_beams=num_beams,
1179
- do_sample=do_sample,
1180
- repetition_penalty=float(repetition_penalty),
1181
- num_return_sequences=num_return_sequences,
1182
- renormalize_logits=True,
1183
- remove_invalid_values=True,
1184
- )
 
 
1185
 
1186
  gen_kwargs = dict(input_ids=input_ids,
1187
  generation_config=generation_config,
@@ -1200,7 +1915,10 @@ def evaluate(
1200
  tgt_lang = languages_covered()[tgt_lang]
1201
  gen_kwargs.update(dict(forced_bos_token_id=tokenizer.lang_code_to_id[tgt_lang]))
1202
  else:
1203
- gen_kwargs.update(dict(pad_token_id=tokenizer.eos_token_id))
 
 
 
1204
 
1205
  decoder_kwargs = dict(skip_special_tokens=True,
1206
  clean_up_tokenization_spaces=True)
@@ -1216,7 +1934,8 @@ def evaluate(
1216
  )
1217
 
1218
  with torch.no_grad():
1219
- context_class_cast = NullContext if device == 'cpu' or lora_weights else torch.autocast
 
1220
  with context_class_cast(device):
1221
  # protection for gradio not keeping track of closed users,
1222
  # else hit bitsandbytes lack of thread safety:
@@ -1299,7 +2018,11 @@ def evaluate(
1299
  if outputs and len(outputs) >= 1:
1300
  decoded_output = prompt + outputs[0]
1301
  if save_dir and decoded_output:
1302
- save_generate_output(output=decoded_output, base_model=base_model, save_dir=save_dir)
 
 
 
 
1303
  if verbose:
1304
  print('Post-Generate: %s decoded_output: %s' % (
1305
  str(datetime.now()), len(decoded_output) if decoded_output else -1), flush=True)
@@ -1318,6 +2041,7 @@ def get_cutoffs(memory_restriction_level, for_context=False, model_max_length=20
1318
  if memory_restriction_level > 0:
1319
  max_length_tokenize = 768 - 256 if memory_restriction_level <= 2 else 512 - 256
1320
  else:
 
1321
  max_length_tokenize = model_max_length - 256
1322
  cutoff_len = max_length_tokenize * 4 # if reaches limit, then can't generate new tokens
1323
  output_smallest = 30 * 4
@@ -1422,7 +2146,8 @@ def get_generate_params(model_lower, chat,
1422
  if model_lower:
1423
  print(f"Using Model {model_lower}", flush=True)
1424
  else:
1425
- print("No model defined yet", flush=True)
 
1426
 
1427
  min_new_tokens = min_new_tokens if min_new_tokens is not None else 0
1428
  early_stopping = early_stopping if early_stopping is not None else False
@@ -1478,12 +2203,14 @@ Philipp: ok, ok you can find everything here. https://huggingface.co/blog/the-pa
1478
  use_defaults = True
1479
  else:
1480
  if chat:
1481
- placeholder_instruction = "Enter a question or imperative."
1482
  else:
1483
  placeholder_instruction = "Give detailed answer for whether Einstein or Newton is smarter."
1484
  placeholder_input = ""
1485
- if model_lower:
1486
- # default is plain, because might relly upon trust_remote_code to handle prompting
 
 
1487
  prompt_type = prompt_type or 'plain'
1488
  else:
1489
  prompt_type = ''
@@ -1591,7 +2318,7 @@ y = np.random.randint(0, 1, 100)
1591
 
1592
  # get prompt_dict from prompt_type, so user can see in UI etc., or for custom do nothing except check format
1593
  prompt_dict, error0 = get_prompt(prompt_type, prompt_dict,
1594
- chat=False, context='', reduced=False, return_dict=True)
1595
  if error0:
1596
  raise RuntimeError("Prompt wrong: %s" % error0)
1597
 
@@ -1644,7 +2371,8 @@ def score_qa(smodel, stokenizer, max_length_tokenize, question, answer, cutoff_l
1644
  if 'Expected all tensors to be on the same device' in str(e) or \
1645
  'expected scalar type Half but found Float' in str(e) or \
1646
  'probability tensor contains either' in str(e) or \
1647
- 'cublasLt ran into an error!' in str(e):
 
1648
  print("GPU Error: question: %s answer: %s exception: %s" % (question, answer, str(e)),
1649
  flush=True)
1650
  traceback.print_exc()
@@ -1677,25 +2405,101 @@ def check_locals(**kwargs):
1677
  assert k in kwargs, "Missing %s" % k
1678
 
1679
 
 
 
1680
  def get_max_max_new_tokens(model_state, **kwargs):
1681
- if kwargs['max_new_tokens'] and kwargs['user_set_max_new_tokens']:
1682
- max_max_new_tokens = kwargs['max_new_tokens']
 
 
1683
  elif kwargs['memory_restriction_level'] == 1:
1684
- max_max_new_tokens = 768
1685
  elif kwargs['memory_restriction_level'] == 2:
1686
- max_max_new_tokens = 512
1687
  elif kwargs['memory_restriction_level'] >= 3:
1688
- max_max_new_tokens = 256
1689
  else:
1690
- if not isinstance(model_state[1], str):
1691
- max_max_new_tokens = model_state[1].model_max_length
1692
- else:
1693
- # FIXME: Need to update after new model loaded, so user can control with slider
1694
- max_max_new_tokens = 2048
1695
- return max_max_new_tokens
1696
 
1697
 
1698
- if __name__ == "__main__":
 
 
 
 
1699
  """
1700
  Examples:
1701
 
@@ -1726,3 +2530,7 @@ if __name__ == "__main__":
1726
  python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6_9b
1727
  """
1728
  fire.Fire(main)
 
 
1
  import ast
2
+ import copy
3
  import functools
4
  import glob
5
  import inspect
6
  import queue
 
7
  import sys
8
  import os
9
  import time
10
  import traceback
11
+ import types
12
  import typing
13
  import warnings
14
  from datetime import datetime
15
  import filelock
16
+ import requests
17
  import psutil
18
+ from requests import ConnectTimeout, JSONDecodeError
19
+ from urllib3.exceptions import ConnectTimeoutError, MaxRetryError, ConnectionError
20
+ from requests.exceptions import ConnectionError as ConnectionError2
21
+ from requests.exceptions import ReadTimeout as ReadTimeout2
22
+
23
+ if os.path.dirname(os.path.abspath(__file__)) not in sys.path:
24
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
25
 
26
  os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
27
  os.environ['BITSANDBYTES_NOWELCOME'] = '1'
28
  warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
29
 
30
+ from enums import DocumentChoices, LangChainMode, no_lora_str, model_token_mapping, no_model_str, source_prefix, \
31
+ source_postfix
32
  from loaders import get_loaders
33
  from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, wrapped_partial, EThread, get_githash, \
34
+ import_matplotlib, get_device, makedirs, get_kwargs, start_faulthandler, get_hf_server, FakeTokenizer, remove
35
 
36
  start_faulthandler()
37
  import_matplotlib()
 
44
  import fire
45
  import torch
46
  from transformers import GenerationConfig, AutoModel, TextIteratorStreamer
 
47
 
48
+ from prompter import Prompter, inv_prompt_type_to_model_lower, non_hf_types, PromptType, get_prompt, generate_prompt
49
  from stopping import get_stopping
50
 
51
  eval_extra_columns = ['prompt', 'response', 'score']
 
65
  lora_weights: str = "",
66
  gpu_id: int = 0,
67
  compile_model: bool = True,
68
+ use_cache: bool = None,
69
+ inference_server: str = "",
70
  prompt_type: Union[int, str] = None,
71
  prompt_dict: typing.Dict = None,
72
+
73
+ model_lock: typing.List[typing.Dict[str, str]] = None,
74
+ model_lock_columns: int = None,
75
+ fail_if_cannot_connect: bool = False,
76
+
77
  # input to generation
78
  temperature: float = None,
79
  top_p: float = None,
 
103
  cli: bool = False,
104
  cli_loop: bool = True,
105
  gradio: bool = True,
 
106
  gradio_offline_level: int = 0,
107
  chat: bool = True,
108
  chat_context: bool = False,
109
  stream_output: bool = True,
110
  show_examples: bool = None,
111
  verbose: bool = False,
112
+ h2ocolors: bool = True,
113
  height: int = 600,
114
  show_lora: bool = True,
115
  login_mode_if_model0: bool = False,
 
118
  api_open: bool = False,
119
  allow_api: bool = True,
120
  input_lines: int = 1,
121
+ gradio_size: str = None,
122
  auth: typing.List[typing.Tuple[str, str]] = None,
123
+ max_max_time=None,
124
+ max_max_new_tokens=None,
125
 
126
+ sanitize_user_prompt: bool = False,
127
+ sanitize_bot_response: bool = False,
128
 
129
  extra_model_options: typing.List[str] = [],
130
  extra_lora_options: typing.List[str] = [],
131
+ extra_server_options: typing.List[str] = [],
132
 
133
  score_model: str = 'OpenAssistant/reward-model-deberta-v3-large-v2',
 
134
 
135
  eval_filename: str = None,
136
  eval_prompts_only_num: int = 0,
 
138
  eval_as_output: bool = False,
139
 
140
  langchain_mode: str = 'Disabled',
141
+ force_langchain_evaluate: bool = False,
142
  visible_langchain_modes: list = ['UserData', 'MyData'],
143
  document_choice: list = [DocumentChoices.All_Relevant.name],
144
  user_path: str = None,
 
156
  enable_sources_list: bool = True,
157
  chunk: bool = True,
158
  chunk_size: int = 512,
159
+ top_k_docs: int = None,
160
+ reverse_docs: bool = True,
161
+ auto_reduce_chunks: bool = True,
162
+ max_chunks: int = 100,
163
  n_jobs: int = -1,
164
  enable_captions: bool = True,
165
  captions_model: str = "Salesforce/blip-image-captioning-base",
 
178
  :param lora_weights: LORA weights path/HF link
179
  :param gpu_id: if infer_devices, then use gpu_id for cuda device ID, or auto mode if gpu_id != -1
180
  :param compile_model Whether to compile the model
181
+ :param use_cache: Whether to use caching in model (some models fail when multiple threads use)
182
+ :param inference_server: Consume base_model as type of model at this address
183
+ Address can be text-generation-server hosting that base_model
184
+ e.g. python generate.py --inference_server="http://192.168.1.46:6112" --base_model=h2oai/h2ogpt-oasst1-512-12b
185
+ Or Address can be "openai_chat" or "openai" for OpenAI API
186
+ e.g. python generate.py --inference_server="openai_chat" --base_model=gpt-3.5-turbo
187
+ e.g. python generate.py --inference_server="openai" --base_model=text-davinci-003
188
  :param prompt_type: type of prompt, usually matched to fine-tuned model or plain for foundational model
189
  :param prompt_dict: If prompt_type=custom, then expects (some) items returned by get_prompt(..., return_dict=True)
190
+ :param model_lock: Lock models to specific combinations, for ease of use and extending to many models
191
+ Only used if gradio = True
192
+ List of dicts, each dict has base_model, tokenizer_base_model, lora_weights, inference_server, prompt_type, and prompt_dict
193
+ If all models share the same prompt_type and prompt_dict, those can still be specified once on the CLI outside model_lock as the default for each dict (a usage sketch follows below)
194
+ Can specify model_lock instead of those items on CLI
195
+ As with the CLI itself, prompt_type and prompt_dict can be inferred from base_model if it is listed in prompter.py.
196
+ Also, tokenizer_base_model and lora_weights are optional.
197
+ Also, inference_server is optional if loading model from local system.
198
+ All models provided will automatically appear in compare model mode
199
+ Model loading-unloading and related choices will be disabled. Model/lora/server adding will be disabled
200
+ :param model_lock_columns: How many columns to show if locking models (and so showing all at once)
201
+ If None, then defaults to up to 3
202
+ if -1, then all goes into 1 row
203
+ Maximum value is 4 due to non-dynamic gradio rendering elements
204
+ :param fail_if_cannot_connect: if doing model locking (e.g. with many models), fail when a model cannot be reached if True; otherwise skip it.
205
+ Skipping is useful when there are many endpoints and one just wants to see what works, though each failed connection still waits for its timeout.
206
  :param temperature: generation temperature
207
  :param top_p: generation top_p
208
  :param top_k: generation top_k
 
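As a companion to the model_lock description above, here is a minimal sketch of the structure main() expects; the endpoint address and model names simply reuse the docstring's own examples and are illustrative only. Missing keys default to empty strings, and prompt_type/prompt_dict can be inferred from base_model when it is known to prompter.py:

    # hypothetical two-model lock: one local TGI endpoint plus the OpenAI chat API
    model_lock = [
        dict(base_model='h2oai/h2ogpt-oasst1-512-12b', inference_server='http://192.168.1.46:6112'),
        dict(base_model='gpt-3.5-turbo', inference_server='openai_chat'),
    ]
    # the same list can be passed on the CLI as --model_lock="[{...}, ...]" or via the model_lock
    # environment variable, which main() reads with os.getenv and parses with ast.literal_eval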
228
  :param cli: whether to use CLI (non-gradio) interface.
229
  :param cli_loop: whether to loop for CLI (False usually only for testing)
230
  :param gradio: whether to enable gradio, or to enable benchmark mode
 
231
  :param gradio_offline_level: > 0, then change fonts so full offline
232
  == 1 means backend won't need internet for fonts, but front-end UI might if font not cached
233
  == 2 means backend and frontend don't need internet to download any fonts.
234
  Note: Some things always disabled include HF telemetry, gradio telemetry, chromadb posthog that involve uploading.
235
  This option further disables google fonts for downloading, which is less intrusive than uploading,
236
  but still required in air-gapped case. The fonts don't look as nice as google fonts, but ensure full offline behavior.
237
+ Also set --share=False to avoid sharing a gradio live link.
238
  :param chat: whether to enable chat mode with chat history
239
  :param chat_context: whether to use extra helpful context if human_bot
240
  :param stream_output: whether to stream output from generate
 
249
  :param api_open: If False, don't let API calls skip gradio queue
250
  :param allow_api: whether to allow API calls at all to gradio server
251
  :param input_lines: how many input lines to show for chat box (>1 forces shift-enter for submit, else enter is submit)
252
+ :param gradio_size: Overall size of text and spaces: "xsmall", "small", "medium", "large".
253
+ "small" is useful when showing many chatbots in model_lock mode
254
  :param auth: gradio auth for launcher in form [(user1, pass1), (user2, pass2), ...]
255
  e.g. --auth=[('jon','password')] with no spaces
256
+ :param max_max_time: Maximum max_time for gradio slider
257
+ :param max_max_new_tokens: Maximum max_new_tokens for gradio slider
258
+ :param sanitize_user_prompt: whether to remove profanity from user input (slows down input processing)
259
+ :param sanitize_bot_response: whether to remove profanity and repeat lines from bot output (about 2x slower generation for long streaming cases due to better_profanity being slow)
260
  :param extra_model_options: extra models to show in list in gradio
261
  :param extra_lora_options: extra LORA to show in list in gradio
262
+ :param extra_server_options: extra servers to show in list in gradio
263
  :param score_model: which model to score responses (None means no scoring)
 
264
  :param eval_filename: json file to use for evaluation, if None is sharegpt
265
  :param eval_prompts_only_num: for no gradio benchmark, if using eval_filename prompts for eval instead of examples
266
  :param eval_prompts_only_seed: for no gradio benchmark, seed for eval_filename sampling
267
  :param eval_as_output: for no gradio benchmark, whether to test eval_filename output itself
268
  :param langchain_mode: Data source to include. Choose "UserData" to only consume files from make_db.py.
269
  WARNING: wiki_full requires extra data processing via read_wiki_full.py and a powerful workstation to generate the db, unless it is already present.
270
+ :param force_langchain_evaluate: Whether to force langchain LLM use even if not doing langchain, mostly for testing.
271
  :param user_path: user path to glob from to generate db for vector search, for 'UserData' langchain mode.
272
  If already have db, any new/changed files are added automatically if path set, does not have to be same path used for prior db sources
273
  :param detect_user_path_changes_every_query: whether to detect if any files changed or added every similarity search (by file hashes).
 
297
  :param chunk: Whether to chunk data (True unless know data is already optimally chunked)
298
  :param chunk_size: Size of chunks, with typically top-4 passed to LLM, so needs to fit within the context length
299
  :param top_k_docs: number of chunks to give LLM
300
+ :param reverse_docs: whether to reverse document order so the most relevant chunk is closest to the question.
301
+ Best choice for a sufficiently capable model; since truncation drops the oldest context, it is also the best choice when truncation occurs.
302
+ But smaller 6.9B models fail to use the newest context and can get stuck on old information.
303
+ :param auto_reduce_chunks: Whether to automatically reduce top_k_docs to fit context given prompt
304
+ :param max_chunks: If top_k_docs=-1, maximum number of chunks to allow
305
  :param n_jobs: Number of processors to use when consuming documents (-1 = all, is default)
306
  :param enable_captions: Whether to support captions using BLIP for image files as documents, then preloads that model
307
  :param captions_model: Which model to use for captions.
308
+ captions_model: str = "Salesforce/blip-image-captioning-base", # continue capable
309
  captions_model: str = "Salesforce/blip2-flan-t5-xl", # question/answer capable, 16GB state
310
+ captions_model: str = "Salesforce/blip2-flan-t5-xxl", # question/answer capable, 60GB state
311
  Note: opt-based blip2 are not permissive license due to opt and Meta license restrictions
312
  :param pre_load_caption_model: Whether to preload caption model, or load after forking parallel doc loader
313
  parallel loading disabled if preload and have images, to prevent deadlocking on cuda context
 
316
  :param enable_ocr: Whether to support OCR on images
317
  :return:
318
  """
319
+ if base_model is None:
320
+ base_model = ''
321
+ if tokenizer_base_model is None:
322
+ tokenizer_base_model = ''
323
+ if lora_weights is None:
324
+ lora_weights = ''
325
+ if inference_server is None:
326
+ inference_server = ''
327
+
328
+ # listen to env if set
329
+ model_lock = os.getenv('model_lock', str(model_lock))
330
+ model_lock = ast.literal_eval(model_lock)
331
+
332
+ if model_lock:
333
+ assert gradio, "model_lock only supported for gradio=True"
334
+ if len(model_lock) > 1:
335
+ assert chat, "model_lock with multiple models only works for chat=True"
336
+ assert not cli, "model_lock only supported for cli=False"
337
+ assert not (not cli and not gradio), "model_lock not supported for eval mode (cli=gradio=False)"
338
+ assert not base_model, "Don't specify model_lock and base_model"
339
+ assert not tokenizer_base_model, "Don't specify model_lock and tokenizer_base_model"
340
+ assert not lora_weights, "Don't specify model_lock and lora_weights"
341
+ assert not inference_server, "Don't specify model_lock and inference_server"
342
+ # assert not prompt_type, "Don't specify model_lock and prompt_type"
343
+ # assert not prompt_dict, "Don't specify model_lock and prompt_dict"
344
+
345
  is_hf = bool(int(os.getenv("HUGGINGFACE_SPACES", '0')))
346
  is_gpth2oai = bool(int(os.getenv("GPT_H2O_AI", '0')))
347
  is_public = is_hf or is_gpth2oai # multi-user case with fixed model and disclaimer
 
356
 
357
  # allow set token directly
358
  use_auth_token = os.environ.get("HUGGINGFACE_API_TOKEN", use_auth_token)
359
+ allow_upload_to_user_data = bool(
360
+ int(os.environ.get("allow_upload_to_user_data", str(int(allow_upload_to_user_data)))))
361
  allow_upload_to_my_data = bool(int(os.environ.get("allow_upload_to_my_data", str(int(allow_upload_to_my_data)))))
362
  height = int(os.environ.get("HEIGHT", height))
363
  h2ocolors = bool(int(os.getenv('h2ocolors', h2ocolors)))
 
370
  if langchain_mode not in visible_langchain_modes and langchain_mode in langchain_modes:
371
  visible_langchain_modes += [langchain_mode]
372
 
373
+ # if specifically chose not to show My or User Data, disable upload, so gradio elements are simpler
374
+ if LangChainMode.MY_DATA.value not in visible_langchain_modes:
375
+ allow_upload_to_my_data = False
376
+ if LangChainMode.USER_DATA.value not in visible_langchain_modes:
377
+ allow_upload_to_user_data = False
378
+
379
  if is_public:
380
  allow_upload_to_user_data = False
381
  input_lines = 1 # ensure set, for ease of use
 
384
  top_k = 70 if top_k is None else top_k
385
  if is_hf:
386
  do_sample = True if do_sample is None else do_sample
387
+ top_k_docs = 3 if top_k_docs is None else top_k_docs
388
  else:
389
  # by default don't sample, too chatty
390
  do_sample = False if do_sample is None else do_sample
391
+ top_k_docs = 4 if top_k_docs is None else top_k_docs
392
 
393
  if memory_restriction_level == 2:
394
+ if not base_model and not inference_server:
395
  base_model = 'h2oai/h2ogpt-oasst1-512-12b'
396
  # don't set load_8bit if passed base_model, doesn't always work so can't just override
397
  load_8bit = True
398
  load_4bit = False # FIXME - consider using 4-bit instead of 8-bit
399
+ elif not inference_server:
400
+ top_k_docs = 10 if top_k_docs is None else top_k_docs
401
  if memory_restriction_level >= 2:
402
  load_8bit = True
403
  load_4bit = False # FIXME - consider using 4-bit instead of 8-bit
404
  if hf_embedding_model is None:
405
  hf_embedding_model = "sentence-transformers/all-MiniLM-L6-v2"
406
+ top_k_docs = 3 if top_k_docs is None else top_k_docs
407
+ if top_k_docs is None:
408
+ top_k_docs = 3
409
+ if is_public:
410
+ if not max_time:
411
+ max_time = 60 * 2
412
+ if not max_max_time:
413
+ max_max_time = max_time
414
+ if not max_new_tokens:
415
+ max_new_tokens = 256
416
+ if not max_max_new_tokens:
417
+ max_max_new_tokens = 256
418
+ else:
419
+ if not max_max_time:
420
+ max_max_time = 60 * 20
421
+ if not max_max_new_tokens:
422
+ max_max_new_tokens = 512
423
  if is_hf:
424
  # must override share if in spaces
425
  share = False
426
+ if not max_time:
427
+ max_time = 60 * 1
428
+ if not max_max_time:
429
+ max_max_time = max_time
430
+ # HF accounted for later in get_max_max_new_tokens()
431
  save_dir = os.getenv('SAVE_DIR', save_dir)
432
  score_model = os.getenv('SCORE_MODEL', score_model)
433
  if score_model == 'None' or score_model is None:
 
446
  torch.backends.cudnn.benchmark = True
447
  torch.backends.cudnn.enabled = False
448
  torch.set_default_dtype(torch.float32)
449
+ if psutil.virtual_memory().available < 94 * 1024 ** 3 and not inference_server:
450
  # 12B uses ~94GB
451
  # 6.9B uses ~47GB
452
  base_model = 'h2oai/h2ogpt-oig-oasst1-512-6_9b' if not base_model else base_model
 
471
 
472
  if offload_folder:
473
  makedirs(offload_folder)
474
+ if user_path:
475
+ makedirs(user_path)
476
 
477
  placeholder_instruction, placeholder_input, \
478
  stream_output, show_examples, \
 
497
  verbose,
498
  )
499
 
500
+ git_hash = get_githash()
501
  locals_dict = locals()
502
  locals_print = '\n'.join(['%s: %s' % (k, v) for k, v in locals_dict.items()])
503
  if verbose:
504
  print(f"Generating model with params:\n{locals_print}", flush=True)
505
+ print("Command: %s\nHash: %s" % (str(' '.join(sys.argv)), git_hash), flush=True)
506
 
507
  if langchain_mode != "Disabled":
508
  # SECOND PLACE where LangChain referenced, but all imports are kept local so not required
 
516
  for gpath1 in glob.glob(os.path.join(scratch_base_dir, 'db_dir_%s*' % langchain_mode1)):
517
  if os.path.isdir(gpath1):
518
  print("Removing old MyData: %s" % gpath1, flush=True)
519
+ remove(gpath1)
520
  continue
521
  if langchain_mode1 in ['All']:
522
  # FIXME: All should be avoided until scans over each db, shouldn't be separate db
 
542
  assert 'gpt_langchain' not in sys.modules, "Dev bug, import of langchain when should not have"
543
  assert 'langchain' not in sys.modules, "Dev bug, import of langchain when should not have"
544
 
545
+ model_state_none = dict(model=None, tokenizer=None, device=None,
546
+ base_model=None, tokenizer_base_model=None, lora_weights=None,
547
+ inference_server=None, prompt_type=None, prompt_dict=None)
548
+
549
  if cli:
550
  from cli import run_cli
551
  return run_cli(**get_kwargs(run_cli, exclude_names=['model_state0'], **locals()))
 
557
  from gradio_runner import go_gradio
558
 
559
  # get default model
560
+ model_states = []
561
+ model_list = [dict(base_model=base_model, tokenizer_base_model=tokenizer_base_model, lora_weights=lora_weights,
562
+ inference_server=inference_server, prompt_type=prompt_type, prompt_dict=prompt_dict)]
563
+ model_list0 = copy.deepcopy(model_list) # just strings, safe to deepcopy
564
+ model_state0 = model_state_none.copy()
565
+ assert len(model_state_none) == len(model_state0)
566
+ if model_lock:
567
+ model_list = model_lock
568
+ for model_dict in reversed(model_list):
569
+ # do reverse, so first is default base_model etc., so some logic works in go_gradio() more easily
570
+ # handles defaults user didn't have to pass
571
+ model_dict['base_model'] = base_model = model_dict.get('base_model', '')
572
+ model_dict['tokenizer_base_model'] = tokenizer_base_model = model_dict.get('tokenizer_base_model', '')
573
+ model_dict['lora_weights'] = lora_weights = model_dict.get('lora_weights', '')
574
+ model_dict['inference_server'] = inference_server = model_dict.get('inference_server', '')
575
+ prompt_type = model_dict.get('prompt_type', model_list0[0]['prompt_type']) # don't use mutated value
576
+ # try to infer, ignore empty initial state leading to get_generate_params -> 'plain'
577
+ if model_dict.get('prompt_type') is None:
578
+ model_lower = base_model.lower()
579
+ if model_lower in inv_prompt_type_to_model_lower:
580
+ prompt_type = inv_prompt_type_to_model_lower[model_lower]
581
+ prompt_dict, error0 = get_prompt(prompt_type, '',
582
+ chat=False, context='', reduced=False, making_context=False,
583
+ return_dict=True)
584
+ model_dict['prompt_type'] = prompt_type
585
+ model_dict['prompt_dict'] = prompt_dict = model_dict.get('prompt_dict', prompt_dict)
586
+ all_kwargs = locals().copy()
587
+ if base_model and not login_mode_if_model0:
588
+ model0, tokenizer0, device = get_model(reward_type=False,
589
+ **get_kwargs(get_model, exclude_names=['reward_type'],
590
+ **all_kwargs))
591
+ else:
592
+ # if empty model, then don't load anything, just get gradio up
593
+ model0, tokenizer0, device = None, None, None
594
+ if model0 is None:
595
+ if fail_if_cannot_connect:
596
+ raise RuntimeError("Could not connect, see logs")
597
+ # skip
598
+ if isinstance(model_lock, list):
599
+ model_lock.remove(model_dict)
600
+ continue
601
+ model_state_trial = dict(model=model0, tokenizer=tokenizer0, device=device)
602
+ model_state_trial.update(model_dict)
603
+ assert len(model_state_none) == len(model_state_trial)
604
+ print("Model %s" % model_dict, flush=True)
605
+ if model_lock:
606
+ # last in iteration will be first
607
+ model_states.insert(0, model_state_trial)
608
+ # fill model_state0 so go_gradio() easier, manage model_states separately
609
+ model_state0 = model_state_trial.copy()
610
+ else:
611
+ model_state0 = model_state_trial.copy()
612
+ assert len(model_state_none) == len(model_state0)
613
 
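For reference, after this loop each entry of model_states (and model_state0) is a dict with the same keys as model_state_none but filled with live objects; an illustrative, hypothetical example:

    model_state_trial = dict(model=model0, tokenizer=tokenizer0, device='cuda',
                             base_model='h2oai/h2ogpt-oasst1-512-12b', tokenizer_base_model='',
                             lora_weights='', inference_server='',
                             prompt_type='human_bot', prompt_dict=prompt_dict)
    assert len(model_state_trial) == len(model_state_none)  # same key count, as the loop checks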
614
  # get score model
615
+ all_kwargs = locals().copy()
616
  smodel, stokenizer, sdevice = get_score_model(reward_type=True,
617
  **get_kwargs(get_score_model, exclude_names=['reward_type'],
618
  **all_kwargs))
619
+ score_model_state0 = dict(model=smodel, tokenizer=stokenizer, device=sdevice,
620
+ base_model=score_model, tokenizer_base_model='', lora_weights='',
621
+ inference_server='', prompt_type='', prompt_dict='')
622
 
623
  if enable_captions:
624
  if pre_load_caption_model:
 
633
  go_gradio(**locals())
634
 
635
 
636
+ def get_config(base_model,
637
+ use_auth_token=False,
638
+ trust_remote_code=True,
639
+ offload_folder=None,
640
+ triton_attn=False,
641
+ long_sequence=True,
642
+ return_model=False,
643
+ raise_exception=False,
644
+ ):
645
+ from accelerate import init_empty_weights
 
646
  with init_empty_weights():
647
  from transformers import AutoConfig
648
+ try:
649
+ config = AutoConfig.from_pretrained(base_model, use_auth_token=use_auth_token,
650
+ trust_remote_code=trust_remote_code,
651
+ offload_folder=offload_folder)
652
+ except OSError as e:
653
+ if raise_exception:
654
+ raise
655
+ if 'not a local folder and is not a valid model identifier listed on' in str(
656
+ e) or '404 Client Error' in str(e):
657
+ # e.g. llama, gptj, etc.
658
+ # e.g. HF TGI but not model on HF or private etc.
659
+ # an HF TGI server should really only require prompt_type, not HF model state
660
+ return None, None
661
+ else:
662
+ raise
663
  if triton_attn and 'mpt-' in base_model.lower():
664
  config.attn_config['attn_impl'] = 'triton'
665
  if long_sequence:
 
667
  config.update({"max_seq_len": 83968})
668
  if 'mosaicml/mpt-7b-chat' in base_model.lower():
669
  config.update({"max_seq_len": 4096})
670
+ if 'mpt-30b' in base_model.lower():
671
+ config.update({"max_seq_len": 2 * 8192})
672
+ if return_model and \
673
+ issubclass(config.__class__, tuple(AutoModel._model_mapping.keys())):
674
  model = AutoModel.from_config(
675
  config,
676
+ trust_remote_code=trust_remote_code,
677
  )
678
  else:
679
  # can't infer
680
  model = None
681
+ if 'falcon' in base_model.lower():
682
+ config.use_cache = False
683
+
684
+ return config, model
685
+
686
+
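A brief sketch of how get_config() above is typically invoked, mirroring the call made later inside get_model(); the model name is just the repo's usual example:

    config_kwargs = dict(use_auth_token=False, trust_remote_code=True, offload_folder=None,
                         triton_attn=False, long_sequence=True)
    config, _ = get_config('h2oai/h2ogpt-oasst1-512-12b', raise_exception=False, **config_kwargs)
    # config comes back as None for non-HF model types (e.g. llama/gptj via gpt4all) or for
    # TGI endpoints whose base_model is not a valid HF identifier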
687
+ def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type,
688
+ config, model,
689
+ gpu_id=0,
690
+ ):
691
+ """
692
+ Ensure model gets on correct device
693
+ """
694
 
695
  if model is not None:
696
  # NOTE: Can specify max_memory={0: max_mem, 1: max_mem}, to shard model
697
  # NOTE: Some models require avoiding sharding some layers,
698
  # then would pass no_split_module_classes and give list of those layers.
699
+ from accelerate import infer_auto_device_map
700
  device_map = infer_auto_device_map(
701
  model,
702
  dtype=torch.float16 if load_half else torch.float32,
 
748
  return model
749
 
750
 
751
+ def get_client_from_inference_server(inference_server, raise_connection_exception=False):
752
+ inference_server, headers = get_hf_server(inference_server)
753
+ # preload client since slow for gradio case especially
754
+ from gradio_utils.grclient import GradioClient
755
+ gr_client = None
756
+ hf_client = None
757
+ if headers is None:
758
+ try:
759
+ print("GR Client Begin: %s" % inference_server, flush=True)
760
+ # first do sanity check if alive, else gradio client takes too long by default
761
+ requests.get(inference_server, timeout=int(os.getenv('REQUEST_TIMEOUT', '30')))
762
+ gr_client = GradioClient(inference_server)
763
+ print("GR Client End: %s" % inference_server, flush=True)
764
+ except (OSError, ValueError) as e:
765
+ # Occurs when wrong endpoint and should have been HF client, so don't hard raise, just move to HF
766
+ gr_client = None
767
+ print("GR Client Failed %s: %s" % (inference_server, str(e)), flush=True)
768
+ except (ConnectTimeoutError, ConnectTimeout, MaxRetryError, ConnectionError, ConnectionError2,
769
+ JSONDecodeError, ReadTimeout2, KeyError) as e:
770
+ t, v, tb = sys.exc_info()
771
+ ex = ''.join(traceback.format_exception(t, v, tb))
772
+ print("GR Client Failed %s: %s" % (inference_server, str(ex)), flush=True)
773
+ if raise_connection_exception:
774
+ raise
775
+
776
+ if gr_client is None:
777
+ res = None
778
+ from text_generation import Client as HFClient
779
+ print("HF Client Begin: %s" % inference_server)
780
+ try:
781
+ hf_client = HFClient(inference_server, headers=headers, timeout=int(os.getenv('REQUEST_TIMEOUT', '30')))
782
+ # quick check valid TGI endpoint
783
+ res = hf_client.generate('What?', max_new_tokens=1)
784
+ hf_client = HFClient(inference_server, headers=headers, timeout=300)
785
+ except (ConnectTimeoutError, ConnectTimeout, MaxRetryError, ConnectionError, ConnectionError2,
786
+ JSONDecodeError, ReadTimeout2, KeyError) as e:
787
+ hf_client = None
788
+ t, v, tb = sys.exc_info()
789
+ ex = ''.join(traceback.format_exception(t, v, tb))
790
+ print("HF Client Failed %s: %s" % (inference_server, str(ex)))
791
+ if raise_connection_exception:
792
+ raise
793
+ print("HF Client End: %s %s" % (inference_server, res))
794
+ return inference_server, gr_client, hf_client
795
+
796
+
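For orientation, a minimal usage sketch for the helper above; the URL is the same illustrative address used in the docstring, and normally at most one of the two returned clients is non-None:

    inference_server, gr_client, hf_client = get_client_from_inference_server('http://192.168.1.46:6112')
    client = gr_client or hf_client  # Gradio client preferred, text-generation (TGI) client as fallback
    if client is None:
        raise RuntimeError("Could not connect, see logs")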
797
  def get_model(
798
  load_8bit: bool = False,
799
  load_4bit: bool = False,
800
  load_half: bool = True,
801
  infer_devices: bool = True,
802
  base_model: str = '',
803
+ inference_server: str = "",
804
  tokenizer_base_model: str = '',
805
  lora_weights: str = "",
806
  gpu_id: int = 0,
 
824
  For non-LORA case, False will spread shards across multiple GPUs, but this can lead to cuda:x cuda:y mismatches
825
  So it is not the default
826
  :param base_model: name/path of base model
827
+ :param inference_server: whether base_model is hosted locally ('') or via http (url)
828
  :param tokenizer_base_model: name/path of tokenizer
829
  :param lora_weights: name/path
830
  :param gpu_id: which GPU (0..n_gpus-1) or allow all GPUs if relevant (-1)
 
840
  """
841
  if verbose:
842
  print("Get %s model" % base_model, flush=True)
843
+
844
+ triton_attn = False
845
+ long_sequence = True
846
+ config_kwargs = dict(use_auth_token=use_auth_token,
847
+ trust_remote_code=trust_remote_code,
848
+ offload_folder=offload_folder,
849
+ triton_attn=triton_attn,
850
+ long_sequence=long_sequence)
851
+ config, _ = get_config(base_model, **config_kwargs, raise_exception=False)
852
+
853
+ if base_model in non_hf_types:
854
+ assert config is None, "Expected config None for %s" % base_model
855
+
856
+ llama_type_from_config = 'llama' in str(config).lower()
857
+ llama_type_from_name = "llama" in base_model.lower()
858
+ llama_type = llama_type_from_config or llama_type_from_name
859
+ if "xgen" in base_model.lower():
860
+ llama_type = False
861
+ if llama_type:
862
+ if verbose:
863
+ print("Detected as llama type from"
864
+ " config (%s) or name (%s)" % (llama_type_from_config, llama_type_from_name), flush=True)
865
+
866
+ model_loader, tokenizer_loader = get_loaders(model_name=base_model, reward_type=reward_type, llama_type=llama_type)
867
+
868
+ tokenizer_kwargs = dict(local_files_only=local_files_only,
869
+ resume_download=resume_download,
870
+ use_auth_token=use_auth_token,
871
+ trust_remote_code=trust_remote_code,
872
+ offload_folder=offload_folder,
873
+ padding_side='left',
874
+ config=config,
875
+ )
876
+ if not tokenizer_base_model:
877
+ tokenizer_base_model = base_model
878
+
879
+ if config is not None and tokenizer_loader is not None and not isinstance(tokenizer_loader, str):
880
+ tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model, **tokenizer_kwargs)
881
+ # sets raw (no cushion) limit
882
+ set_model_max_len(config, tokenizer, verbose=False)
883
+ # if using fake tokenizer, not really accurate when lots of numbers, give a bit of buffer, else get:
884
+ # Generation Failed: Input validation error: `inputs` must have less than 2048 tokens. Given: 2233
885
+ tokenizer.model_max_length = tokenizer.model_max_length - 50
886
+ else:
887
+ tokenizer = FakeTokenizer()
888
+
889
+ if isinstance(inference_server, str) and inference_server.startswith("http"):
890
+ inference_server, gr_client, hf_client = get_client_from_inference_server(inference_server)
891
+ client = gr_client or hf_client
892
+ # Don't return None, None for model, tokenizer, so that downstream code treats this as a loaded model
893
+ return client, tokenizer, 'http'
894
+ if isinstance(inference_server, str) and inference_server.startswith('openai'):
895
+ assert os.getenv('OPENAI_API_KEY'), "Set environment for OPENAI_API_KEY"
896
+ # Don't return None, None for model, tokenizer, so that downstream code treats this as a loaded model
897
+ # include small token cushion
898
+ tokenizer = FakeTokenizer(model_max_length=model_token_mapping[base_model] - 50)
899
+ return inference_server, tokenizer, inference_server
900
+ assert not inference_server, "Malformed inference_server=%s" % inference_server
901
  if base_model in non_hf_types:
902
  from gpt4all_llm import get_model_tokenizer_gpt4all
903
  model, tokenizer, device = get_model_tokenizer_gpt4all(base_model)
904
  return model, tokenizer, device
905
 
906
+ # get local torch-HF model
907
+ return get_hf_model(load_8bit=load_8bit,
908
+ load_4bit=load_4bit,
909
+ load_half=load_half,
910
+ infer_devices=infer_devices,
911
+ base_model=base_model,
912
+ tokenizer_base_model=tokenizer_base_model,
913
+ lora_weights=lora_weights,
914
+ gpu_id=gpu_id,
915
+
916
+ reward_type=reward_type,
917
+ local_files_only=local_files_only,
918
+ resume_download=resume_download,
919
+ use_auth_token=use_auth_token,
920
+ trust_remote_code=trust_remote_code,
921
+ offload_folder=offload_folder,
922
+ compile_model=compile_model,
923
+
924
+ llama_type=llama_type,
925
+ config_kwargs=config_kwargs,
926
+ tokenizer_kwargs=tokenizer_kwargs,
927
+
928
+ verbose=verbose)
929
+
930
+
931
+ def get_hf_model(load_8bit: bool = False,
932
+ load_4bit: bool = False,
933
+ load_half: bool = True,
934
+ infer_devices: bool = True,
935
+ base_model: str = '',
936
+ tokenizer_base_model: str = '',
937
+ lora_weights: str = "",
938
+ gpu_id: int = 0,
939
+
940
+ reward_type: bool = None,
941
+ local_files_only: bool = False,
942
+ resume_download: bool = True,
943
+ use_auth_token: Union[str, bool] = False,
944
+ trust_remote_code: bool = True,
945
+ offload_folder: str = None,
946
+ compile_model: bool = True,
947
+
948
+ llama_type: bool = False,
949
+ config_kwargs=None,
950
+ tokenizer_kwargs=None,
951
+
952
+ verbose: bool = False,
953
+ ):
954
+ assert config_kwargs is not None
955
+ assert tokenizer_kwargs is not None
956
+
957
  if lora_weights is not None and lora_weights.strip():
958
  if verbose:
959
  print("Get %s lora weights" % lora_weights, flush=True)
 
968
  "Please choose a base model with --base_model (CLI) or load one from Models Tab (gradio)"
969
  )
970
 
971
+ model_loader, tokenizer_loader = get_loaders(model_name=base_model, reward_type=reward_type, llama_type=llama_type)
 
972
 
973
+ config, _ = get_config(base_model, return_model=False, raise_exception=True, **config_kwargs)
 
 
974
 
975
  if tokenizer_loader is not None and not isinstance(tokenizer_loader, str):
976
  tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model,
977
+ **tokenizer_kwargs)
 
 
 
 
 
978
  else:
979
  tokenizer = tokenizer_loader
980
 
 
998
  load_in_4bit=load_4bit,
999
  device_map={"": 0} if (load_8bit or load_4bit) and device == 'cuda' else "auto",
1000
  ))
1001
+ if 'mpt-' in base_model.lower() and gpu_id is not None and gpu_id >= 0:
1002
  model_kwargs.update(dict(device_map={"": gpu_id} if device == 'cuda' else "cpu"))
1003
 
1004
  if 'OpenAssistant/reward-model'.lower() in base_model.lower():
 
1009
 
1010
  if not lora_weights:
1011
  with torch.device(device):
1012
+
1013
  if infer_devices:
1014
+ config, model = get_config(base_model, return_model=True, raise_exception=True, **config_kwargs)
1015
  model = get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type,
1016
+ config, model,
1017
  gpu_id=gpu_id,
 
 
 
1018
  )
1019
  else:
1020
+ config, _ = get_config(base_model, **config_kwargs)
1021
  if load_half and not (load_8bit or load_4bit):
1022
  model = model_loader.from_pretrained(
1023
  base_model,
1024
+ config=config,
1025
  **model_kwargs).half()
1026
  else:
1027
  model = model_loader.from_pretrained(
1028
  base_model,
1029
+ config=config,
1030
  **model_kwargs)
1031
  elif load_8bit or load_4bit:
1032
+ config, _ = get_config(base_model, **config_kwargs)
1033
  model = model_loader.from_pretrained(
1034
  base_model,
1035
+ config=config,
1036
  **model_kwargs
1037
  )
1038
  from peft import PeftModel # loads cuda, so avoid in global scope
 
1049
  )
1050
  else:
1051
  with torch.device(device):
1052
+ config, _ = get_config(base_model, raise_exception=True, **config_kwargs)
1053
  model = model_loader.from_pretrained(
1054
  base_model,
1055
+ config=config,
1056
  **model_kwargs
1057
  )
1058
  from peft import PeftModel # loads cuda, so avoid in global scope
 
1086
  if torch.__version__ >= "2" and sys.platform != "win32" and compile_model:
1087
  model = torch.compile(model)
1088
 
1089
+ set_model_max_len(config, tokenizer, verbose=False, reward_type=reward_type)
1090
+
1091
+ return model, tokenizer, device
1092
+
1093
+
1094
+ def set_model_max_len(config, tokenizer, verbose=False, reward_type=False):
1095
+ if reward_type:
1096
+ # limit deberta, else uses too much memory and not worth response score
1097
+ tokenizer.model_max_length = 512
1098
  if hasattr(config, 'max_seq_len') and isinstance(config.max_seq_len, int):
1099
  tokenizer.model_max_length = config.max_seq_len
1100
  elif hasattr(config, 'max_position_embeddings') and isinstance(config.max_position_embeddings, int):
 
1104
  if verbose:
1105
  print("Could not determine model_max_length, setting to 2048", flush=True)
1106
  tokenizer.model_max_length = 2048
1107
+ # for bug in HF transformers
1108
+ if tokenizer.model_max_length > 100000000:
1109
+ tokenizer.model_max_length = 2048
1110
 
1111
 
1112
  def pop_unused_model_kwargs(model_kwargs):
 
1128
  load_half: bool = True,
1129
  infer_devices: bool = True,
1130
  base_model: str = '',
1131
+ inference_server: str = '',
1132
  tokenizer_base_model: str = '',
1133
  lora_weights: str = "",
1134
  gpu_id: int = 0,
 
1150
  base_model = score_model.strip()
1151
  tokenizer_base_model = ''
1152
  lora_weights = ''
1153
+ inference_server = ''
1154
  llama_type = False
1155
  compile_model = False
1156
  smodel, stokenizer, sdevice = get_model(reward_type=True,
 
1217
  debug=False,
1218
  concurrency_count=None,
1219
  save_dir=None,
1220
+ sanitize_bot_response=False,
1221
  model_state0=None,
1222
  memory_restriction_level=None,
1223
+ max_max_new_tokens=None,
1224
+ is_public=None,
1225
+ max_max_time=None,
1226
  raise_generate_gpu_exceptions=None,
1227
  chat_context=None,
1228
  lora_weights=None,
 
1233
  use_openai_embedding=None,
1234
  use_openai_model=None,
1235
  hf_embedding_model=None,
 
 
1236
  db_type=None,
1237
  n_jobs=None,
1238
  first_para=None,
1239
  text_limit=None,
1240
  verbose=False,
1241
  cli=False,
1242
+ reverse_docs=True,
1243
+ use_cache=None,
1244
+ auto_reduce_chunks=None,
1245
+ max_chunks=None,
1246
+ model_lock=None,
1247
+ force_langchain_evaluate=None,
1248
+ model_state_none=None,
1249
  ):
1250
  if isinstance(user_kwargs, str):
1251
  user_kwargs = ast.literal_eval(user_kwargs)
1252
  # only used for submit_nochat_api
1253
  user_kwargs['chat'] = False
1254
+ if 'stream_output' not in user_kwargs:
1255
+ user_kwargs['stream_output'] = False
1256
  if 'langchain_mode' not in user_kwargs:
1257
  # if user doesn't specify, then assume disabled, not use default
1258
  user_kwargs['langchain_mode'] = 'Disabled'
 
1275
  sanitize_bot_response=sanitize_bot_response,
1276
  model_state0=model_state0,
1277
  memory_restriction_level=memory_restriction_level,
1278
+ max_max_new_tokens=max_max_new_tokens,
1279
+ is_public=is_public,
1280
+ max_max_time=max_max_time,
1281
  raise_generate_gpu_exceptions=raise_generate_gpu_exceptions,
1282
  chat_context=chat_context,
1283
  lora_weights=lora_weights,
 
1294
  text_limit=text_limit,
1295
  verbose=verbose,
1296
  cli=cli,
1297
+ reverse_docs=reverse_docs,
1298
+ use_cache=use_cache,
1299
+ auto_reduce_chunks=auto_reduce_chunks,
1300
+ max_chunks=max_chunks,
1301
+ model_lock=model_lock,
1302
+ force_langchain_evaluate=force_langchain_evaluate,
1303
+ model_state_none=model_state_none,
1304
  )
1305
  try:
1306
  for ret1 in ret:
 
1345
  debug=False,
1346
  concurrency_count=None,
1347
  save_dir=None,
1348
+ sanitize_bot_response=False,
1349
  model_state0=None,
1350
  memory_restriction_level=None,
1351
+ max_max_new_tokens=None,
1352
+ is_public=None,
1353
+ max_max_time=None,
1354
  raise_generate_gpu_exceptions=None,
1355
  chat_context=None,
1356
  lora_weights=None,
 
1367
  text_limit=None,
1368
  verbose=False,
1369
  cli=False,
1370
+ reverse_docs=True,
1371
+ use_cache=None,
1372
+ auto_reduce_chunks=None,
1373
+ max_chunks=None,
1374
+ model_lock=None,
1375
+ force_langchain_evaluate=None,
1376
+ model_state_none=None,
1377
  ):
1378
  # ensure passed these
1379
  assert concurrency_count is not None
 
1394
  locals_dict = locals().copy()
1395
  locals_dict.pop('model_state', None)
1396
  locals_dict.pop('model_state0', None)
1397
+ locals_dict.pop('model_states', None)
1398
  print(locals_dict)
1399
 
1400
+ no_model_msg = "Please choose a base model with --base_model (CLI) or load in Models Tab (gradio).\n" \
1401
+ "Then start New Conversation"
1402
 
1403
+ if model_state is None:
1404
+ model_state = model_state_none.copy()
1405
  if model_state0 is None:
1406
  # e.g. for no gradio case, set dummy value, else should be set
1407
+ model_state0 = model_state_none.copy()
1408
+
1409
+ # model_state['model'] is the string 'model' only if model_state0 should be used
1410
+ # model could also be None
1411
+ have_model_lock = model_lock is not None
1412
+ have_fresh_model = model_state['model'] not in [None, 'model', no_model_str]
1413
+ # for gradio UI control, expect model_state and model_state0 to match, so if have_model_lock=True, then should have_fresh_model=True
1414
+ # but gradio API control will only use nochat api etc. and won't use fresh model, so can't assert in general
1415
+ # if have_model_lock:
1416
+ # assert have_fresh_model, "Expected model_state and model_state0 to match if have_model_lock"
1417
+ have_cli_model = model_state0['model'] not in [None, 'model', no_model_str]
1418
+
1419
+ if have_fresh_model:
1420
+ # USE FRESH MODEL
1421
+ if not have_model_lock:
1422
+ # model_state0 is just one of model_state if model_lock, so don't nuke
1423
+ # try to free-up original model (i.e. list was passed as reference)
1424
+ if model_state0['model'] and hasattr(model_state0['model'], 'cpu'):
1425
+ model_state0['model'].cpu()
1426
+ model_state0['model'] = None
1427
+ # try to free-up original tokenizer (i.e. list was passed as reference)
1428
+ if model_state0['tokenizer']:
1429
+ model_state0['tokenizer'] = None
1430
+ clear_torch_cache()
1431
+ chosen_model_state = model_state
1432
+ elif have_cli_model:
1433
+ # USE MODEL SETUP AT CLI
1434
+ assert isinstance(model_state['model'], str) # expect no fresh model
1435
+ chosen_model_state = model_state0
1436
  else:
1437
  raise AssertionError(no_model_msg)
1438
+ # get variables
1439
+ model = chosen_model_state['model']
1440
+ tokenizer = chosen_model_state['tokenizer']
1441
+ device = chosen_model_state['device']
1442
+ base_model = chosen_model_state['base_model']
1443
+ tokenizer_base_model = chosen_model_state['tokenizer_base_model']
1444
+ lora_weights = chosen_model_state['lora_weights']
1445
+ inference_server = chosen_model_state['inference_server']
1446
+ # prefer use input from API over model state
1447
+ prompt_type = prompt_type or chosen_model_state['prompt_type']
1448
+ prompt_dict = prompt_dict or chosen_model_state['prompt_dict']
1449
 
1450
  if base_model is None:
1451
  raise AssertionError(no_model_msg)
 
1459
  instruction = instruction_nochat
1460
  iinput = iinput_nochat
1461
 
1462
+ # in some cases, like lean nochat API, don't want to force sending prompt_type, allow default choice
1463
+ model_lower = base_model.lower()
1464
+ if not prompt_type and model_lower in inv_prompt_type_to_model_lower:
1465
+ prompt_type = inv_prompt_type_to_model_lower[model_lower]
1466
+ if verbose:
1467
+ print("Auto-selecting prompt_type=%s for %s" % (prompt_type, model_lower), flush=True)
1468
+ assert prompt_type is not None, "prompt_type was None"
1469
+
1470
+ # Control generation hyperparameters
1471
+ # adjust for bad inputs, e.g. in case also come from API that doesn't get constrained by gradio sliders
1472
+ # below is for TGI server, not required for HF transformers
1473
+ # limits are chosen similar to gradio_runner.py sliders/numbers
1474
+ top_p = min(max(1e-3, top_p), 1.0 - 1e-3)
1475
+ top_k = min(max(1, int(top_k)), 100)
1476
+ temperature = min(max(0.01, temperature), 2.0)
1477
+ # FIXME: https://github.com/h2oai/h2ogpt/issues/106
1478
+ num_beams = 1 if stream_output else num_beams # See max_beams in gradio_runner
1479
+ max_max_new_tokens = get_max_max_new_tokens(chosen_model_state,
1480
+ memory_restriction_level=memory_restriction_level,
1481
+ max_new_tokens=max_new_tokens,
1482
+ max_max_new_tokens=max_max_new_tokens)
1483
+ model_max_length = get_model_max_length(chosen_model_state)
1484
+ max_new_tokens = min(max(1, int(max_new_tokens)), max_max_new_tokens)
1485
+ min_new_tokens = min(max(0, int(min_new_tokens)), max_new_tokens)
1486
+ max_time = min(max(0, max_time), max_max_time)
1487
+ repetition_penalty = min(max(0.01, repetition_penalty), 3.0)
1488
+ num_return_sequences = 1 if chat else min(max(1, int(num_return_sequences)), 10)
1489
+ min_top_k_docs, max_top_k_docs, label_top_k_docs = get_minmax_top_k_docs(is_public)
1490
+ top_k_docs = min(max(min_top_k_docs, int(top_k_docs)), max_top_k_docs)
1491
+ chunk_size = min(max(128, int(chunk_size)), 2048)
1492
  if not context:
1493
  # get hidden context if have one
1494
  context = get_context(chat_context, prompt_type)
1495
 
1496
+ # restrict instruction, typically what has large input
1497
+ from h2oai_pipeline import H2OTextGenerationPipeline
1498
+ instruction, num_prompt_tokens1 = H2OTextGenerationPipeline.limit_prompt(instruction, tokenizer)
1499
+ context, num_prompt_tokens2 = H2OTextGenerationPipeline.limit_prompt(context, tokenizer)
1500
+ iinput, num_prompt_tokens3 = H2OTextGenerationPipeline.limit_prompt(iinput, tokenizer)
1501
+ num_prompt_tokens = (num_prompt_tokens1 or 0) + (num_prompt_tokens2 or 0) + (num_prompt_tokens3 or 0)
1502
+
1503
+ # get prompt
1504
  prompter = Prompter(prompt_type, prompt_dict, debug=debug, chat=chat, stream_output=stream_output)
1505
  data_point = dict(context=context, instruction=instruction, input=iinput)
1506
  prompt = prompter.generate_prompt(data_point)
 
1513
  db1 = dbs[langchain_mode]
1514
  else:
1515
  db1 = None
1516
+ do_langchain_path = langchain_mode not in [False, 'Disabled', 'ChatLLM', 'LLM'] and \
1517
+ db1 is not None or \
1518
+ base_model in non_hf_types or \
1519
+ force_langchain_evaluate
1520
+ if do_langchain_path:
1521
  query = instruction if not iinput else "%s\n%s" % (instruction, iinput)
1522
  outr = ""
1523
  # use smaller cut_distanct for wiki_full since so many matches could be obtained, and often irrelevant unless close
1524
  from gpt_langchain import run_qa_db
1525
+ gen_hyper_langchain = dict(do_sample=do_sample,
1526
+ temperature=temperature,
1527
+ repetition_penalty=repetition_penalty,
1528
+ top_k=top_k,
1529
+ top_p=top_p,
1530
+ num_beams=num_beams,
1531
+ min_new_tokens=min_new_tokens,
1532
+ max_new_tokens=max_new_tokens,
1533
+ early_stopping=early_stopping,
1534
+ max_time=max_time,
1535
+ num_return_sequences=num_return_sequences,
1536
+ )
1537
  for r in run_qa_db(query=query,
1538
  model_name=base_model, model=model, tokenizer=tokenizer,
1539
+ inference_server=inference_server,
1540
  stream_output=stream_output,
1541
  prompter=prompter,
1542
  load_db_if_exists=load_db_if_exists,
 
1556
  db_type=db_type,
1557
  top_k_docs=top_k_docs,
1558
 
1559
+ **gen_hyper_langchain,
 
1560
 
1561
  prompt_type=prompt_type,
1562
  prompt_dict=prompt_dict,
1563
  n_jobs=n_jobs,
1564
  verbose=verbose,
1565
  cli=cli,
1566
+ sanitize_bot_response=sanitize_bot_response,
1567
+ reverse_docs=reverse_docs,
1568
+
1569
+ lora_weights=lora_weights,
1570
+
1571
+ auto_reduce_chunks=auto_reduce_chunks,
1572
+ max_chunks=max_chunks,
1573
  ):
1574
  outr, extra = r # doesn't accumulate, new answer every yield, so only save that full answer
1575
  yield dict(response=outr, sources=extra)
1576
  if save_dir:
1577
+ extra_dict = gen_hyper_langchain.copy()
1578
+ extra_dict.update(prompt_type=prompt_type, inference_server=inference_server,
1579
+ langchain_mode=langchain_mode, document_choice=document_choice,
1580
+ num_prompt_tokens=num_prompt_tokens)
1581
+ save_generate_output(prompt=query, output=outr, base_model=base_model, save_dir=save_dir,
1582
+ where_from='run_qa_db',
1583
+ extra_dict=extra_dict)
1584
  if verbose:
1585
  print(
1586
  'Post-Generate Langchain: %s decoded_output: %s' % (str(datetime.now()), len(outr) if outr else -1),
 
1593
  clear_torch_cache()
1594
  return
1595
 
1596
+ if inference_server.startswith('openai') or inference_server.startswith('http'):
1597
+ if inference_server.startswith('openai'):
1598
+ import openai
1599
+ where_from = "openai_client"
1600
+
1601
+ openai.api_key = os.getenv("OPENAI_API_KEY")
1602
+ stop_sequences = list(set(prompter.terminate_response + [prompter.PreResponse]))
1603
+ # OpenAI will complain if asked for too many new tokens; it treats the value as a minimum in some sense, wrongly so.
1604
+ max_new_tokens_openai = min(max_new_tokens, model_max_length - num_prompt_tokens)
1605
+ gen_server_kwargs = dict(temperature=temperature if do_sample else 0,
1606
+ max_tokens=max_new_tokens_openai,
1607
+ top_p=top_p if do_sample else 1,
1608
+ frequency_penalty=0,
1609
+ n=num_return_sequences,
1610
+ presence_penalty=1.07 - repetition_penalty + 0.6, # so good default
1611
+ )
1612
+ if inference_server == 'openai':
1613
+ response = openai.Completion.create(
1614
+ model=base_model,
1615
+ prompt=prompt,
1616
+ **gen_server_kwargs,
1617
+ stop=stop_sequences,
1618
+ stream=stream_output,
1619
+ )
1620
+ if not stream_output:
1621
+ text = response['choices'][0]['text']
1622
+ yield dict(response=prompter.get_response(prompt + text, prompt=prompt,
1623
+ sanitize_bot_response=sanitize_bot_response),
1624
+ sources='')
1625
+ else:
1626
+ collected_events = []
1627
+ text = ''
1628
+ for event in response:
1629
+ collected_events.append(event) # save the event response
1630
+ event_text = event['choices'][0]['text'] # extract the text
1631
+ text += event_text # append the text
1632
+ yield dict(response=prompter.get_response(prompt + text, prompt=prompt,
1633
+ sanitize_bot_response=sanitize_bot_response),
1634
+ sources='')
1635
+ elif inference_server == 'openai_chat':
1636
+ response = openai.ChatCompletion.create(
1637
+ model=base_model,
1638
+ messages=[
1639
+ {"role": "system", "content": "You are a helpful assistant."},
1640
+ {'role': 'user',
1641
+ 'content': prompt,
1642
+ }
1643
+ ],
1644
+ stream=stream_output,
1645
+ **gen_server_kwargs,
1646
+ )
1647
+ if not stream_output:
1648
+ text = response["choices"][0]["message"]["content"]
1649
+ yield dict(response=prompter.get_response(prompt + text, prompt=prompt,
1650
+ sanitize_bot_response=sanitize_bot_response),
1651
+ sources='')
1652
+ else:
1653
+ text = ""
1654
+ for chunk in response:
1655
+ delta = chunk["choices"][0]["delta"]
1656
+ if 'content' in delta:
1657
+ text += delta['content']
1658
+ yield dict(response=prompter.get_response(prompt + text, prompt=prompt,
1659
+ sanitize_bot_response=sanitize_bot_response),
1660
+ sources='')
1661
+ else:
1662
+ raise RuntimeError("No such OpenAI mode: %s" % inference_server)
1663
+ elif inference_server.startswith('http'):
1664
+ inference_server, headers = get_hf_server(inference_server)
1665
+ from gradio_utils.grclient import GradioClient
1666
+ from text_generation import Client as HFClient
1667
+ if isinstance(model, GradioClient):
1668
+ gr_client = model
1669
+ hf_client = None
1670
+ elif isinstance(model, HFClient):
1671
+ gr_client = None
1672
+ hf_client = model
1673
+ else:
1674
+ inference_server, gr_client, hf_client = get_client_from_inference_server(inference_server)
1675
+
1676
+ # quick sanity check to avoid long timeouts, just see if can reach server
1677
+ requests.get(inference_server, timeout=int(os.getenv('REQUEST_TIMEOUT_FAST', '10')))
1678
+
1679
+ if gr_client is not None:
1680
+ # Note: h2oGPT gradio server could handle input token size issues for prompt,
1681
+ # but best to handle here so send less data to server
1682
+
1683
+ chat_client = False
1684
+ where_from = "gr_client"
1685
+ client_langchain_mode = 'Disabled'
1686
+ gen_server_kwargs = dict(temperature=temperature,
1687
+ top_p=top_p,
1688
+ top_k=top_k,
1689
+ num_beams=num_beams,
1690
+ max_new_tokens=max_new_tokens,
1691
+ min_new_tokens=min_new_tokens,
1692
+ early_stopping=early_stopping,
1693
+ max_time=max_time,
1694
+ repetition_penalty=repetition_penalty,
1695
+ num_return_sequences=num_return_sequences,
1696
+ do_sample=do_sample,
1697
+ chat=chat_client,
1698
+ )
1699
+ # account for gradio-into-gradio where the server handles prompting; avoid duplicating the prompter's prompt injection
1700
+ if prompt_type in [None, '', PromptType.plain.name, PromptType.plain.value,
1701
+ str(PromptType.plain.value)]:
1702
+ # if our prompt is plain, assume either correct or gradio server knows different prompt type,
1703
+ # so pass empty prompt_type
1704
+ gr_prompt_type = ''
1705
+ gr_prompt_dict = ''
1706
+ gr_prompt = prompt # already prepared prompt
1707
+ gr_context = ''
1708
+ gr_iinput = ''
1709
+ else:
1710
+ # if already have prompt_type that is not plain, None, or '', then already applied some prompting
1711
+ # But assume server can handle prompting, and need to avoid double-up.
1712
+ # Also assume server can do better job of using stopping.py to stop early, so avoid local prompting, let server handle
1713
+ # So avoid "prompt" and let gradio server reconstruct from prompt_type we passed
1714
+ # Note it's ok that prompter.get_response() has prompt+text, prompt=prompt passed,
1715
+ # because just means extra processing and removal of prompt, but that has no human-bot prompting doesn't matter
1716
+ # since those won't appear
1717
+ gr_context = context
1718
+ gr_prompt = instruction
1719
+ gr_iinput = iinput
1720
+ gr_prompt_type = prompt_type
1721
+ gr_prompt_dict = prompt_dict
1722
+ client_kwargs = dict(instruction=gr_prompt if chat_client else '', # only for chat=True
1723
+ iinput=gr_iinput, # only for chat=True
1724
+ context=gr_context,
1725
+ # streaming output is supported, loops over and outputs each generation in streaming mode
1726
+ # but leave stream_output=False for simple input/output mode
1727
+ stream_output=stream_output,
1728
+
1729
+ **gen_server_kwargs,
1730
+
1731
+ prompt_type=gr_prompt_type,
1732
+ prompt_dict=gr_prompt_dict,
1733
+
1734
+ instruction_nochat=gr_prompt if not chat_client else '',
1735
+ iinput_nochat=gr_iinput, # only for chat=False
1736
+ langchain_mode=client_langchain_mode,
1737
+ top_k_docs=top_k_docs,
1738
+ chunk=chunk,
1739
+ chunk_size=chunk_size,
1740
+ document_choice=[DocumentChoices.All_Relevant.name],
1741
+ )
1742
+ api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing
1743
+ if not stream_output:
1744
+ res = gr_client.predict(str(dict(client_kwargs)), api_name=api_name)
1745
+ res_dict = ast.literal_eval(res)
1746
+ text = res_dict['response']
1747
+ sources = res_dict['sources']
1748
+ yield dict(response=prompter.get_response(prompt + text, prompt=prompt,
1749
+ sanitize_bot_response=sanitize_bot_response),
1750
+ sources=sources)
1751
+ else:
1752
+ job = gr_client.submit(str(dict(client_kwargs)), api_name=api_name)
1753
+ text = ''
1754
+ sources = ''
1755
+ res_dict = dict(response=text, sources=sources)
1756
+ while not job.done():
1757
+ outputs_list = job.communicator.job.outputs
1758
+ if outputs_list:
1759
+ res = job.communicator.job.outputs[-1]
1760
+ res_dict = ast.literal_eval(res)
1761
+ text = res_dict['response']
1762
+ sources = res_dict['sources']
1763
+ if gr_prompt_type == 'plain':
1764
+ # then gradio server passes back full prompt + text
1765
+ prompt_and_text = text
1766
+ else:
1767
+ prompt_and_text = prompt + text
1768
+ yield dict(response=prompter.get_response(prompt_and_text, prompt=prompt,
1769
+ sanitize_bot_response=sanitize_bot_response),
1770
+ sources=sources)
1771
+ time.sleep(0.01)
1772
+ # ensure get last output to avoid race
1773
+ res_all = job.outputs()
1774
+ if len(res_all) > 0:
1775
+ res = res_all[-1]
1776
+ res_dict = ast.literal_eval(res)
1777
+ text = res_dict['response']
1778
+ sources = res_dict['sources']
1779
+ else:
1780
+ # go with old text if last call didn't work
1781
+ e = job.future._exception
1782
+ if e is not None:
1783
+ stre = str(e)
1784
+ strex = ''.join(traceback.format_tb(e.__traceback__))
1785
+ else:
1786
+ stre = ''
1787
+ strex = ''
1788
+
1789
+ print("Bad final response: %s %s %s %s %s: %s %s" % (base_model, inference_server,
1790
+ res_all, prompt, text, stre, strex),
1791
+ flush=True)
1792
+ if gr_prompt_type == 'plain':
1793
+ # then gradio server passes back full prompt + text
1794
+ prompt_and_text = text
1795
+ else:
1796
+ prompt_and_text = prompt + text
1797
+ yield dict(response=prompter.get_response(prompt_and_text, prompt=prompt,
1798
+ sanitize_bot_response=sanitize_bot_response),
1799
+ sources=sources)
1800
+ elif hf_client:
1801
+ # HF inference server needs control over input tokens
1802
+ where_from = "hf_client"
1803
+
1804
+ # prompt must include all human-bot like tokens, already added by prompt
1805
+ # https://github.com/huggingface/text-generation-inference/tree/main/clients/python#types
1806
+ stop_sequences = list(set(prompter.terminate_response + [prompter.PreResponse]))
1807
+ gen_server_kwargs = dict(do_sample=do_sample,
1808
+ max_new_tokens=max_new_tokens,
1809
+ # best_of=None,
1810
+ repetition_penalty=repetition_penalty,
1811
+ return_full_text=True,
1812
+ seed=SEED,
1813
+ stop_sequences=stop_sequences,
1814
+ temperature=temperature,
1815
+ top_k=top_k,
1816
+ top_p=top_p,
1817
+ # truncate=False, # behaves oddly
1818
+ # typical_p=top_p,
1819
+ # watermark=False,
1820
+ # decoder_input_details=False,
1821
+ )
1822
+ # work-around for timeout at constructor time, will be issue if multi-threading,
1823
+ # so just do something reasonable or max_time if larger
1824
+ # lower bound because client is re-used if multi-threading
1825
+ hf_client.timeout = max(300, max_time)
1826
+ if not stream_output:
1827
+ text = hf_client.generate(prompt, **gen_server_kwargs).generated_text
1828
+ yield dict(response=prompter.get_response(text, prompt=prompt,
1829
+ sanitize_bot_response=sanitize_bot_response),
1830
+ sources='')
1831
+ else:
1832
+ text = ""
1833
+ for response in hf_client.generate_stream(prompt, **gen_server_kwargs):
1834
+ if not response.token.special:
1835
+ # stop_sequences
1836
+ text_chunk = response.token.text
1837
+ text += text_chunk
1838
+ yield dict(response=prompter.get_response(prompt + text, prompt=prompt,
1839
+ sanitize_bot_response=sanitize_bot_response),
1840
+ sources='')
1841
+ else:
1842
+ raise RuntimeError("Failed to get client: %s" % inference_server)
1843
+ else:
1844
+ raise RuntimeError("No such inference_server %s" % inference_server)
1845
+
1846
+ if save_dir and text:
1847
+ # save prompt + new text
1848
+ extra_dict = gen_server_kwargs.copy()
1849
+ extra_dict.update(dict(inference_server=inference_server, num_prompt_tokens=num_prompt_tokens))
1850
+ save_generate_output(prompt=prompt, output=text, base_model=base_model, save_dir=save_dir,
1851
+ where_from=where_from, extra_dict=extra_dict)
1852
+ return
1853
+ else:
1854
+ assert not inference_server, "inference_server=%s not supported" % inference_server
1855
+
1856
  if isinstance(tokenizer, str):
1857
  # pipeline
1858
  if tokenizer == "summarization":
 
1866
  assert src_lang is not None
1867
  tokenizer.src_lang = languages_covered()[src_lang]
1868
 
1869
+ stopping_criteria = get_stopping(prompt_type, prompt_dict, tokenizer, device,
1870
+ model_max_length=tokenizer.model_max_length)
1871
+
1872
+ inputs = tokenizer(prompt, return_tensors="pt")
 
1873
  if debug and len(inputs["input_ids"]) > 0:
1874
  print('input_ids length', len(inputs["input_ids"][0]), flush=True)
1875
  input_ids = inputs["input_ids"].to(device)
1876
  # CRITICAL LIMIT else will fail
1877
  max_max_tokens = tokenizer.model_max_length
1878
+ max_input_tokens = max_max_tokens - min_new_tokens
1879
+ # NOTE: Don't limit up front due to max_new_tokens, let go up to max or reach max_max_tokens in stopping.py
1880
  input_ids = input_ids[:, -max_input_tokens:]
1881
+ # required for falcon if multiple threads or asyncio access the model during generation
1882
+ if use_cache is None:
1883
+ use_cache = False if 'falcon' in base_model else True
1884
+ gen_config_kwargs = dict(temperature=float(temperature),
1885
+ top_p=float(top_p),
1886
+ top_k=top_k,
1887
+ num_beams=num_beams,
1888
+ do_sample=do_sample,
1889
+ repetition_penalty=float(repetition_penalty),
1890
+ num_return_sequences=num_return_sequences,
1891
+ renormalize_logits=True,
1892
+ remove_invalid_values=True,
1893
+ use_cache=use_cache,
1894
+ )
1895
+ token_ids = ['eos_token_id', 'pad_token_id', 'bos_token_id', 'cls_token_id', 'sep_token_id']
1896
+ for token_id in token_ids:
1897
+ if hasattr(tokenizer, token_id) and getattr(tokenizer, token_id) is not None:
1898
+ gen_config_kwargs.update({token_id: getattr(tokenizer, token_id)})
1899
+ generation_config = GenerationConfig(**gen_config_kwargs)
1900
 
1901
  gen_kwargs = dict(input_ids=input_ids,
1902
  generation_config=generation_config,
 
1915
  tgt_lang = languages_covered()[tgt_lang]
1916
  gen_kwargs.update(dict(forced_bos_token_id=tokenizer.lang_code_to_id[tgt_lang]))
1917
  else:
1918
+ token_ids = ['eos_token_id', 'bos_token_id', 'pad_token_id']
1919
+ for token_id in token_ids:
1920
+ if hasattr(tokenizer, token_id) and getattr(tokenizer, token_id) is not None:
1921
+ gen_kwargs.update({token_id: getattr(tokenizer, token_id)})
1922
 
1923
  decoder_kwargs = dict(skip_special_tokens=True,
1924
  clean_up_tokenization_spaces=True)
 
1934
  )
1935
 
1936
  with torch.no_grad():
1937
+ have_lora_weights = lora_weights not in [no_lora_str, '', None]
1938
+ context_class_cast = NullContext if device == 'cpu' or have_lora_weights else torch.autocast
1939
  with context_class_cast(device):
1940
  # protection for gradio not keeping track of closed users,
1941
  # else hit bitsandbytes lack of thread safety:
 
2018
  if outputs and len(outputs) >= 1:
2019
  decoded_output = prompt + outputs[0]
2020
  if save_dir and decoded_output:
2021
+ extra_dict = gen_config_kwargs.copy()
2022
+ extra_dict.update(dict(num_prompt_tokens=num_prompt_tokens))
2023
+ save_generate_output(prompt=prompt, output=decoded_output, base_model=base_model, save_dir=save_dir,
2024
+ where_from="evaluate_%s" % str(stream_output),
2025
+ extra_dict=extra_dict)
2026
  if verbose:
2027
  print('Post-Generate: %s decoded_output: %s' % (
2028
  str(datetime.now()), len(decoded_output) if decoded_output else -1), flush=True)
 
2041
  if memory_restriction_level > 0:
2042
  max_length_tokenize = 768 - 256 if memory_restriction_level <= 2 else 512 - 256
2043
  else:
2044
+ # at least give room for 1 paragraph output
2045
  max_length_tokenize = model_max_length - 256
2046
  cutoff_len = max_length_tokenize * 4 # if reaches limit, then can't generate new tokens
2047
  output_smallest = 30 * 4
 
2146
  if model_lower:
2147
  print(f"Using Model {model_lower}", flush=True)
2148
  else:
2149
+ if verbose:
2150
+ print("No model defined yet", flush=True)
2151
 
2152
  min_new_tokens = min_new_tokens if min_new_tokens is not None else 0
2153
  early_stopping = early_stopping if early_stopping is not None else False
 
2203
  use_defaults = True
2204
  else:
2205
  if chat:
2206
+ placeholder_instruction = ""
2207
  else:
2208
  placeholder_instruction = "Give detailed answer for whether Einstein or Newton is smarter."
2209
  placeholder_input = ""
2210
+ if model_lower in inv_prompt_type_to_model_lower:
2211
+ prompt_type = inv_prompt_type_to_model_lower[model_lower]
2212
+ elif model_lower:
2213
+ # default is plain, because might rely upon trust_remote_code to handle prompting
2214
  prompt_type = prompt_type or 'plain'
2215
  else:
2216
  prompt_type = ''
 
2318
 
2319
  # get prompt_dict from prompt_type, so user can see in UI etc., or for custom do nothing except check format
2320
  prompt_dict, error0 = get_prompt(prompt_type, prompt_dict,
2321
+ chat=False, context='', reduced=False, making_context=False, return_dict=True)
2322
  if error0:
2323
  raise RuntimeError("Prompt wrong: %s" % error0)
2324
 
 
2371
  if 'Expected all tensors to be on the same device' in str(e) or \
2372
  'expected scalar type Half but found Float' in str(e) or \
2373
  'probability tensor contains either' in str(e) or \
2374
+ 'cublasLt ran into an error!' in str(e) or \
2375
+ 'device-side assert triggered' in str(e):
2376
  print("GPU Error: question: %s answer: %s exception: %s" % (question, answer, str(e)),
2377
  flush=True)
2378
  traceback.print_exc()
 
2405
  assert k in kwargs, "Missing %s" % k
2406
 
2407
 
2408
+ def get_model_max_length(model_state):
2409
+ if not isinstance(model_state['tokenizer'], (str, types.NoneType)):
2410
+ return model_state['tokenizer'].model_max_length
2411
+ else:
2412
+ return 2048
2413
+
2414
+
2415
  def get_max_max_new_tokens(model_state, **kwargs):
2416
+ if not isinstance(model_state['tokenizer'], (str, types.NoneType)):
2417
+ max_max_new_tokens = model_state['tokenizer'].model_max_length
2418
+ else:
2419
+ max_max_new_tokens = None
2420
+
2421
+ if kwargs['max_max_new_tokens'] is not None and max_max_new_tokens is not None:
2422
+ return min(max_max_new_tokens, kwargs['max_max_new_tokens'])
2423
+ elif kwargs['max_max_new_tokens'] is not None:
2424
+ return kwargs['max_max_new_tokens']
2425
  elif kwargs['memory_restriction_level'] == 1:
2426
+ return 768
2427
  elif kwargs['memory_restriction_level'] == 2:
2428
+ return 512
2429
  elif kwargs['memory_restriction_level'] >= 3:
2430
+ return 256
2431
  else:
2432
+ # FIXME: Need to update after new model loaded, so user can control with slider
2433
+ return 2048
 
 
 
 
2434
 
2435
 
2436
+ def get_minmax_top_k_docs(is_public):
2437
+ if is_public:
2438
+ min_top_k_docs = 1
2439
+ max_top_k_docs = 3
2440
+ label_top_k_docs = "Number of document chunks"
2441
+ else:
2442
+ min_top_k_docs = -1
2443
+ max_top_k_docs = 100
2444
+ label_top_k_docs = "Number of document chunks (-1 = auto fill model context)"
2445
+ return min_top_k_docs, max_top_k_docs, label_top_k_docs
2446
+
2447
+
2448
+ def history_to_context(history, langchain_mode1, prompt_type1, prompt_dict1, chat1, model_max_length1,
2449
+ memory_restriction_level1, keep_sources_in_context1):
2450
+ """
2451
+ consumes all history up to (but not including) latest history item that is presumed to be an [instruction, None] pair
2452
+ :param history:
2453
+ :param langchain_mode1:
2454
+ :param prompt_type1:
2455
+ :param prompt_dict1:
2456
+ :param chat1:
2457
+ :param model_max_length1:
2458
+ :param memory_restriction_level1:
2459
+ :param keep_sources_in_context1:
2460
+ :return:
2461
+ """
2462
+ # ensure output will be unique to models
2463
+ _, _, _, max_prompt_length = get_cutoffs(memory_restriction_level1,
2464
+ for_context=True, model_max_length=model_max_length1)
2465
+ context1 = ''
2466
+ if max_prompt_length is not None and langchain_mode1 not in ['LLM']:
2467
+ context1 = ''
2468
+ # - 1 below because current instruction already in history from user()
2469
+ for histi in range(0, len(history) - 1):
2470
+ data_point = dict(instruction=history[histi][0], input='', output=history[histi][1])
2471
+ prompt, pre_response, terminate_response, chat_sep, chat_turn_sep = generate_prompt(data_point,
2472
+ prompt_type1,
2473
+ prompt_dict1,
2474
+ chat1,
2475
+ reduced=True,
2476
+ making_context=True)
2477
+ # md -> back to text, maybe not super important if model trained enough
2478
+ if not keep_sources_in_context1 and langchain_mode1 != 'Disabled' and prompt.find(source_prefix) >= 0:
2479
+ # FIXME: This is relatively slow even for small amount of text, like 0.3s each history item
2480
+ import re
2481
+ prompt = re.sub(f'{re.escape(source_prefix)}.*?{re.escape(source_postfix)}', '', prompt,
2482
+ flags=re.DOTALL)
2483
+ if prompt.endswith('\n<p>'):
2484
+ prompt = prompt[:-4]
2485
+ prompt = prompt.replace('<br>', chat_turn_sep)
2486
+ if not prompt.endswith(chat_turn_sep):
2487
+ prompt += chat_turn_sep
2488
+ # most recent first, add older if can
2489
+ # only include desired chat history
2490
+ if len(prompt + context1) > max_prompt_length:
2491
+ break
2492
+ context1 += prompt
2493
+
2494
+ _, pre_response, terminate_response, chat_sep, chat_turn_sep = generate_prompt({}, prompt_type1, prompt_dict1,
2495
+ chat1, reduced=True,
2496
+ making_context=True)
2497
+ if context1 and not context1.endswith(chat_turn_sep):
2498
+ context1 += chat_turn_sep # ensure if terminates abruptly, then human continues on next line
2499
+ return context1
2500
+
2501
+
2502
+ def entrypoint_main():
2503
  """
2504
  Examples:
2505
 
 
2530
  python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6_9b
2531
  """
2532
  fire.Fire(main)
2533
+
2534
+
2535
+ if __name__ == "__main__":
2536
+ entrypoint_main()
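For orientation, a minimal standalone sketch of how the new token-limit helpers above (get_model_max_length, get_max_max_new_tokens) resolve the generation cap. FakeTok and the helper name max_new_tokens_cap are illustrative stand-ins, not part of the commit, and type(None) is used instead of types.NoneType so the sketch also runs on Python < 3.10:

# Sketch only: mirrors the clamping order of get_max_max_new_tokens() in the diff above.
class FakeTok:
    model_max_length = 2048

def max_new_tokens_cap(model_state, max_max_new_tokens=None, memory_restriction_level=0):
    tok = model_state['tokenizer']
    tok_limit = tok.model_max_length if not isinstance(tok, (str, type(None))) else None
    if max_max_new_tokens is not None and tok_limit is not None:
        return min(tok_limit, max_max_new_tokens)  # user cap clamped to tokenizer limit
    if max_max_new_tokens is not None:
        return max_max_new_tokens
    if memory_restriction_level >= 3:
        return 256
    return {1: 768, 2: 512}.get(memory_restriction_level, 2048)

assert max_new_tokens_cap(dict(tokenizer=FakeTok()), max_max_new_tokens=4096) == 2048
assert max_new_tokens_cap(dict(tokenizer=None), memory_restriction_level=2) == 512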
gpt4all_llm.py CHANGED
@@ -1,24 +1,13 @@
1
  import inspect
2
  import os
3
- import sys
4
  from typing import Dict, Any, Optional, List
5
  from langchain.callbacks.manager import CallbackManagerForLLMRun
6
  from pydantic import root_validator
7
  from langchain.llms import gpt4all
8
  from dotenv import dotenv_values
9
 
10
-
11
- class FakeTokenizer:
12
- model_max_length = 2048
13
-
14
- def encode(self, x, *args, **kwargs):
15
- return dict(input_ids=[x])
16
-
17
- def decode(self, x, *args, **kwargs):
18
- return x
19
-
20
- def __call__(self, x, *args, **kwargs):
21
- return self.encode(x, *args, **kwargs)
22
 
23
 
24
  def get_model_tokenizer_gpt4all(base_model, **kwargs):
@@ -74,9 +63,9 @@ class H2OStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
74
  pass
75
 
76
 
77
- def get_model_kwargs(env_kwargs, default_kwargs, cls):
78
  # default from class
79
- model_kwargs = {k: v.default for k, v in dict(inspect.signature(cls).parameters).items()}
80
  # from our defaults
81
  model_kwargs.update(default_kwargs)
82
  # from user defaults
@@ -94,10 +83,14 @@ def get_llm_gpt4all(model_name,
94
  repetition_penalty=1.0,
95
  top_k=40,
96
  top_p=0.7,
97
- verbose=False):
 
 
 
 
 
98
  env_gpt4all_file = ".env_gpt4all"
99
  env_kwargs = dotenv_values(env_gpt4all_file)
100
- callbacks = [H2OStreamingStdOutCallbackHandler()]
101
  n_ctx = env_kwargs.pop('n_ctx', 2048 - max_new_tokens)
102
  default_kwargs = dict(context_erase=0.5,
103
  n_batch=1,
@@ -114,21 +107,23 @@ def get_llm_gpt4all(model_name,
114
  if model_name == 'llama':
115
  cls = H2OLlamaCpp
116
  model_path = env_kwargs.pop('model_path_llama') if model is None else model
117
- model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls)
118
- model_kwargs.update(dict(model_path=model_path, callbacks=callbacks))
119
  llm = cls(**model_kwargs)
120
  llm.client.verbose = verbose
121
  elif model_name == 'gpt4all_llama':
122
  cls = H2OGPT4All
123
  model_path = env_kwargs.pop('model_path_gpt4all_llama') if model is None else model
124
- model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls)
125
- model_kwargs.update(dict(model=model_path, backend='llama', callbacks=callbacks))
 
126
  llm = cls(**model_kwargs)
127
  elif model_name == 'gptj':
128
  cls = H2OGPT4All
129
  model_path = env_kwargs.pop('model_path_gptj') if model is None else model
130
- model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls)
131
- model_kwargs.update(dict(model=model_path, backend='gptj', callbacks=callbacks))
 
132
  llm = cls(**model_kwargs)
133
  else:
134
  raise RuntimeError("No such model_name %s" % model_name)
@@ -137,6 +132,7 @@ def get_llm_gpt4all(model_name,
137
 
138
  class H2OGPT4All(gpt4all.GPT4All):
139
  model: Any
 
140
  """Path to the pre-trained GPT4All model file."""
141
 
142
  @root_validator()
@@ -156,9 +152,16 @@ class H2OGPT4All(gpt4all.GPT4All):
156
  model_type=values["backend"],
157
  allow_download=False,
158
  )
 
 
 
159
  else:
160
  values["client"] = values["model"]
161
- values["backend"] = values["client"].model.model_type
 
 
 
 
162
 
163
  except ImportError:
164
  raise ValueError(
@@ -172,12 +175,19 @@ class H2OGPT4All(gpt4all.GPT4All):
172
  prompt: str,
173
  stop: Optional[List[str]] = None,
174
  run_manager: Optional[CallbackManagerForLLMRun] = None,
 
175
  ) -> str:
176
  # Roughly 4 chars per token if natural language
177
  prompt = prompt[-self.n_ctx * 4:]
 
 
 
 
 
178
  verbose = False
179
  if verbose:
180
  print("_call prompt: %s" % prompt, flush=True)
 
181
  return super()._call(prompt, stop=stop, run_manager=run_manager)
182
 
183
 
@@ -186,6 +196,7 @@ from langchain.llms import LlamaCpp
186
 
187
  class H2OLlamaCpp(LlamaCpp):
188
  model_path: Any
 
189
  """Path to the pre-trained GPT4All model file."""
190
 
191
  @root_validator()
@@ -237,6 +248,7 @@ class H2OLlamaCpp(LlamaCpp):
237
  prompt: str,
238
  stop: Optional[List[str]] = None,
239
  run_manager: Optional[CallbackManagerForLLMRun] = None,
 
240
  ) -> str:
241
  verbose = False
242
  # tokenize twice, just to count tokens, since llama cpp python wrapper has no way to truncate
@@ -253,6 +265,33 @@ class H2OLlamaCpp(LlamaCpp):
253
  prompt_tokens2 = self.client.tokenize(b" " + prompt.encode("utf-8"))
254
  num_prompt_tokens2 = len(prompt_tokens2)
255
  print("reduced tokens from %d -> %d" % (num_prompt_tokens, num_prompt_tokens2), flush=True)
 
 
 
 
 
256
  if verbose:
257
  print("_call prompt: %s" % prompt, flush=True)
258
- return super()._call(prompt, stop=stop, run_manager=run_manager)
1
  import inspect
2
  import os
3
+ from functools import partial
4
  from typing import Dict, Any, Optional, List
5
  from langchain.callbacks.manager import CallbackManagerForLLMRun
6
  from pydantic import root_validator
7
  from langchain.llms import gpt4all
8
  from dotenv import dotenv_values
9
 
10
+ from utils import FakeTokenizer
 
 
 
 
 
 
 
 
 
 
 
11
 
12
 
13
  def get_model_tokenizer_gpt4all(base_model, **kwargs):
 
63
  pass
64
 
65
 
66
+ def get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=[]):
67
  # default from class
68
+ model_kwargs = {k: v.default for k, v in dict(inspect.signature(cls).parameters).items() if k not in exclude_list}
69
  # from our defaults
70
  model_kwargs.update(default_kwargs)
71
  # from user defaults
 
83
  repetition_penalty=1.0,
84
  top_k=40,
85
  top_p=0.7,
86
+ streaming=False,
87
+ callbacks=None,
88
+ prompter=None,
89
+ verbose=False,
90
+ ):
91
+ assert prompter is not None
92
  env_gpt4all_file = ".env_gpt4all"
93
  env_kwargs = dotenv_values(env_gpt4all_file)
 
94
  n_ctx = env_kwargs.pop('n_ctx', 2048 - max_new_tokens)
95
  default_kwargs = dict(context_erase=0.5,
96
  n_batch=1,
 
107
  if model_name == 'llama':
108
  cls = H2OLlamaCpp
109
  model_path = env_kwargs.pop('model_path_llama') if model is None else model
110
+ model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=['lc_kwargs'])
111
+ model_kwargs.update(dict(model_path=model_path, callbacks=callbacks, streaming=streaming, prompter=prompter))
112
  llm = cls(**model_kwargs)
113
  llm.client.verbose = verbose
114
  elif model_name == 'gpt4all_llama':
115
  cls = H2OGPT4All
116
  model_path = env_kwargs.pop('model_path_gpt4all_llama') if model is None else model
117
+ model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=['lc_kwargs'])
118
+ model_kwargs.update(
119
+ dict(model=model_path, backend='llama', callbacks=callbacks, streaming=streaming, prompter=prompter))
120
  llm = cls(**model_kwargs)
121
  elif model_name == 'gptj':
122
  cls = H2OGPT4All
123
  model_path = env_kwargs.pop('model_path_gptj') if model is None else model
124
+ model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=['lc_kwargs'])
125
+ model_kwargs.update(
126
+ dict(model=model_path, backend='gptj', callbacks=callbacks, streaming=streaming, prompter=prompter))
127
  llm = cls(**model_kwargs)
128
  else:
129
  raise RuntimeError("No such model_name %s" % model_name)
 
132
 
133
  class H2OGPT4All(gpt4all.GPT4All):
134
  model: Any
135
+ prompter: Any
136
  """Path to the pre-trained GPT4All model file."""
137
 
138
  @root_validator()
 
152
  model_type=values["backend"],
153
  allow_download=False,
154
  )
155
+ if values["n_threads"] is not None:
156
+ # set n_threads
157
+ values["client"].model.set_thread_count(values["n_threads"])
158
  else:
159
  values["client"] = values["model"]
160
+ try:
161
+ values["backend"] = values["client"].model_type
162
+ except AttributeError:
163
+ # The below is for compatibility with GPT4All Python bindings <= 0.2.3.
164
+ values["backend"] = values["client"].model.model_type
165
 
166
  except ImportError:
167
  raise ValueError(
 
175
  prompt: str,
176
  stop: Optional[List[str]] = None,
177
  run_manager: Optional[CallbackManagerForLLMRun] = None,
178
+ **kwargs,
179
  ) -> str:
180
  # Roughly 4 chars per token if natural language
181
  prompt = prompt[-self.n_ctx * 4:]
182
+
183
+ # use instruct prompting
184
+ data_point = dict(context='', instruction=prompt, input='')
185
+ prompt = self.prompter.generate_prompt(data_point)
186
+
187
  verbose = False
188
  if verbose:
189
  print("_call prompt: %s" % prompt, flush=True)
190
+ # FIXME: GPT4ALl doesn't support yield during generate, so cannot support streaming except via itself to stdout
191
  return super()._call(prompt, stop=stop, run_manager=run_manager)
192
 
193
 
 
196
 
197
  class H2OLlamaCpp(LlamaCpp):
198
  model_path: Any
199
+ prompter: Any
200
  """Path to the pre-trained GPT4All model file."""
201
 
202
  @root_validator()
 
248
  prompt: str,
249
  stop: Optional[List[str]] = None,
250
  run_manager: Optional[CallbackManagerForLLMRun] = None,
251
+ **kwargs,
252
  ) -> str:
253
  verbose = False
254
  # tokenize twice, just to count tokens, since llama cpp python wrapper has no way to truncate
 
265
  prompt_tokens2 = self.client.tokenize(b" " + prompt.encode("utf-8"))
266
  num_prompt_tokens2 = len(prompt_tokens2)
267
  print("reduced tokens from %d -> %d" % (num_prompt_tokens, num_prompt_tokens2), flush=True)
268
+
269
+ # use instruct prompting
270
+ data_point = dict(context='', instruction=prompt, input='')
271
+ prompt = self.prompter.generate_prompt(data_point)
272
+
273
  if verbose:
274
  print("_call prompt: %s" % prompt, flush=True)
275
+
276
+ if self.streaming:
277
+ text_callback = None
278
+ if run_manager:
279
+ text_callback = partial(
280
+ run_manager.on_llm_new_token, verbose=self.verbose
281
+ )
282
+ # parent handler of streamer expects to see prompt first else output="" and lose if prompt=None in prompter
283
+ if text_callback:
284
+ text_callback(prompt)
285
+ text = ""
286
+ for token in self.stream(prompt=prompt, stop=stop, run_manager=run_manager):
287
+ text_chunk = token["choices"][0]["text"]
288
+ # self.stream already calls text_callback
289
+ # if text_callback:
290
+ # text_callback(text_chunk)
291
+ text += text_chunk
292
+ return text
293
+ else:
294
+ params = self._get_parameters(stop)
295
+ params = {**params, **kwargs}
296
+ result = self.client(prompt=prompt, **params)
297
+ return result["choices"][0]["text"]
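As context for the streaming branch added to H2OLlamaCpp._call above, a dependency-free sketch of the pattern it uses: send the prompt to the callback first (the parent handler expects to see the prompt, otherwise the parsed output starts empty), then accumulate token chunks as they arrive. fake_stream() and the chunks.append callback are placeholders for the real llama-cpp stream and run_manager.on_llm_new_token; in the commit itself the per-chunk callback is handled by self.stream rather than the loop:

# Sketch only: stand-in for the token stream consumed in H2OLlamaCpp._call.
def fake_stream(prompt):
    for tok in ("Paris", " is", " the", " capital", "."):
        yield {"choices": [{"text": tok}]}

def stream_generate(prompt, text_callback=None):
    if text_callback:
        text_callback(prompt)  # emit prompt first so downstream prompt-stripping still works
    text = ""
    for token in fake_stream(prompt):
        chunk = token["choices"][0]["text"]
        text += chunk
        if text_callback:
            text_callback(chunk)
    return text

chunks = []
print(stream_generate("Q: capital of France?\nA:", text_callback=chunks.append))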
gpt_langchain.py CHANGED
@@ -1,31 +1,34 @@
 
1
  import glob
2
  import inspect
3
  import os
4
  import pathlib
5
  import pickle
6
- import queue
7
- import random
8
  import shutil
9
  import subprocess
10
- import sys
11
  import tempfile
 
12
  import traceback
 
13
  import uuid
14
  import zipfile
15
  from collections import defaultdict
16
  from datetime import datetime
17
  from functools import reduce
18
  from operator import concat
 
19
 
20
- from joblib import Parallel, delayed
 
21
  from langchain.embeddings import HuggingFaceInstructEmbeddings
22
  from tqdm import tqdm
23
 
24
- from enums import DocumentChoices
25
- from generate import gen_hyper
26
- from prompter import non_hf_types, PromptType
27
  from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
28
- get_device, ProgressParallel, remove, hash_file, clear_torch_cache
 
29
 
30
  import_matplotlib()
31
 
@@ -40,11 +43,11 @@ from langchain.chains.qa_with_sources import load_qa_with_sources_chain
40
  from langchain.document_loaders import PyPDFLoader, TextLoader, CSVLoader, PythonLoader, TomlLoader, \
41
  UnstructuredURLLoader, UnstructuredHTMLLoader, UnstructuredWordDocumentLoader, UnstructuredMarkdownLoader, \
42
  EverNoteLoader, UnstructuredEmailLoader, UnstructuredODTLoader, UnstructuredPowerPointLoader, \
43
- UnstructuredEPubLoader, UnstructuredImageLoader, UnstructuredRTFLoader, ArxivLoader
44
- from langchain.text_splitter import RecursiveCharacterTextSplitter
45
  from langchain.chains.question_answering import load_qa_chain
46
  from langchain.docstore.document import Document
47
- from langchain import PromptTemplate
48
  from langchain.vectorstores import Chroma
49
 
50
 
@@ -142,7 +145,7 @@ def add_to_db(db, sources, db_type='faiss',
142
  return db, num_new_sources, []
143
  db.add_documents(documents=sources)
144
  elif db_type == 'chroma':
145
- collection = db.get()
146
  # files we already have:
147
  metadata_files = set([x['source'] for x in collection['metadatas']])
148
  if avoid_dup_by_file:
@@ -157,6 +160,9 @@ def add_to_db(db, sources, db_type='faiss',
157
  [x['hashid'] for x in collection['metadatas'] if 'hashid' in x and x['hashid'] not in ["None", None]])
158
  # avoid sources with same hash
159
  sources = [x for x in sources if x.metadata.get('hashid') not in metadata_hash_ids]
 
 
 
160
  # get new file names that match existing file names. delete existing files we are overridding
161
  dup_metadata_files = set([x.metadata['source'] for x in sources if x.metadata['source'] in metadata_files])
162
  print("Removing %s duplicate files from db because ingesting those as new documents" % len(
@@ -233,7 +239,7 @@ def get_embedding(use_openai_embedding, hf_embedding_model="sentence-transformer
233
  if use_openai_embedding:
234
  assert os.getenv("OPENAI_API_KEY") is not None, "Set ENV OPENAI_API_KEY"
235
  from langchain.embeddings import OpenAIEmbeddings
236
- embedding = OpenAIEmbeddings()
237
  else:
238
  # to ensure can fork without deadlock
239
  from langchain.embeddings import HuggingFaceEmbeddings
@@ -260,8 +266,315 @@ def get_answer_from_sources(chain, sources, question):
260
  )["output_text"]
261
 
262
 
263
- def get_llm(use_openai_model=False, model_name=None, model=None,
264
- tokenizer=None, stream_output=False,
265
  do_sample=False,
266
  temperature=0.1,
267
  top_k=40,
@@ -276,47 +589,142 @@ def get_llm(use_openai_model=False, model_name=None, model=None,
276
  prompt_type=None,
277
  prompt_dict=None,
278
  prompter=None,
 
279
  verbose=False,
280
  ):
281
- if use_openai_model:
282
- from langchain.llms import OpenAI
283
- llm = OpenAI(temperature=0)
284
- model_name = 'openai'
285
- streamer = None
286
- prompt_type = 'plain'
287
  elif model_name in non_hf_types:
288
  from gpt4all_llm import get_llm_gpt4all
289
  llm = get_llm_gpt4all(model_name, model=model, max_new_tokens=max_new_tokens,
290
  temperature=temperature,
291
  repetition_penalty=repetition_penalty,
292
  top_k=top_k,
293
  top_p=top_p,
 
294
  verbose=verbose,
 
 
295
  )
296
- streamer = None
297
- prompt_type = 'plain'
298
  else:
299
- from transformers import AutoTokenizer, AutoModelForCausalLM
300
-
301
  if model is None:
302
  # only used if didn't pass model in
303
  assert tokenizer is None
304
  prompt_type = 'human_bot'
305
- model_name = 'h2oai/h2ogpt-oasst1-512-12b'
306
- # model_name = 'h2oai/h2ogpt-oig-oasst1-512-6_9b'
307
- # model_name = 'h2oai/h2ogpt-oasst1-512-20b'
308
- tokenizer = AutoTokenizer.from_pretrained(model_name)
309
- device, torch_dtype, context_class = get_device_dtype()
310
-
311
- with context_class(device):
312
- load_8bit = True
313
- # FIXME: for now not to spread across hetero GPUs
314
- # device_map={"": 0} if load_8bit and device == 'cuda' else "auto"
315
- device_map = {"": 0} if device == 'cuda' else "auto"
316
- model = AutoModelForCausalLM.from_pretrained(model_name,
317
- device_map=device_map,
318
- torch_dtype=torch_dtype,
319
- load_in_8bit=load_8bit)
320
 
321
  max_max_tokens = tokenizer.model_max_length
322
  gen_kwargs = dict(do_sample=do_sample,
@@ -331,7 +739,7 @@ def get_llm(use_openai_model=False, model_name=None, model=None,
331
  repetition_penalty=repetition_penalty,
332
  num_return_sequences=num_return_sequences,
333
  return_full_text=True,
334
- handle_long_generation='hole')
335
  assert len(set(gen_hyper).difference(gen_kwargs.keys())) == 0
336
 
337
  if stream_output:
@@ -348,10 +756,11 @@ def get_llm(use_openai_model=False, model_name=None, model=None,
348
  prompter=prompter,
349
  prompt_type=prompt_type,
350
  prompt_dict=prompt_dict,
351
- sanitize_bot_response=True,
352
  chat=False, stream_output=stream_output,
353
  tokenizer=tokenizer,
354
- max_input_tokens=max_max_tokens - max_new_tokens,
 
355
  **gen_kwargs)
356
  # pipe.task = "text-generation"
357
  # below makes it listen only to our prompt removal,
@@ -396,7 +805,7 @@ def get_wiki_data(title, first_paragraph_only, text_limit=None, take_head=True):
396
  data = json.load(open(filename, "rt"))
397
  page_content = list(data["query"]["pages"].values())[0]["extract"]
398
  if take_head is not None and text_limit is not None:
399
- page_content = page_content[:text_limit] if take_head else page_content[:-text_limit]
400
  title_url = str(title).replace(' ', '_')
401
  return Document(
402
  page_content=page_content,
@@ -518,6 +927,21 @@ try:
518
  except (pkg_resources.DistributionNotFound, AssertionError):
519
  have_pymupdf = False
520
 
521
  image_types = ["png", "jpg", "jpeg"]
522
  non_image_types = ["pdf", "txt", "csv", "toml", "py", "rst", "rtf",
523
  "md", "html",
@@ -535,7 +959,7 @@ file_types = non_image_types + image_types
535
  def add_meta(docs1, file):
536
  file_extension = pathlib.Path(file).suffix
537
  hashid = hash_file(file)
538
- if not isinstance(docs1, list):
539
  docs1 = [docs1]
540
  [x.metadata.update(dict(input_type=file_extension, date=str(datetime.now), hashid=hashid)) for x in docs1]
541
 
@@ -577,8 +1001,24 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
577
  else:
578
  docs1 = []
579
  else:
 
 
580
  docs1 = UnstructuredURLLoader(urls=[file]).load()
 
 
 
 
 
 
 
 
 
 
 
 
 
581
  [x.metadata.update(dict(input_type='url', date=str(datetime.now))) for x in docs1]
 
582
  doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
583
  elif is_txt:
584
  base_path = "user_paste"
@@ -588,10 +1028,12 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
588
  f.write(file)
589
  metadata = dict(source=source_file, date=str(datetime.now()), input_type='pasted txt')
590
  doc1 = Document(page_content=file, metadata=metadata)
 
591
  elif file.lower().endswith('.html') or file.lower().endswith('.mhtml'):
592
  docs1 = UnstructuredHTMLLoader(file_path=file).load()
593
  add_meta(docs1, file)
594
- doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
 
595
  elif (file.lower().endswith('.docx') or file.lower().endswith('.doc')) and have_libreoffice:
596
  docs1 = UnstructuredWordDocumentLoader(file_path=file).load()
597
  add_meta(docs1, file)
@@ -603,12 +1045,14 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
603
  elif file.lower().endswith('pptx') or file.lower().endswith('ppt'):
604
  docs1 = UnstructuredPowerPointLoader(file_path=file).load()
605
  add_meta(docs1, file)
 
606
  doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
607
  elif file.lower().endswith('.txt'):
608
  # use UnstructuredFileLoader ?
609
  docs1 = TextLoader(file, encoding="utf8", autodetect_encoding=True).load()
610
  # makes just one, but big one
611
  doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
 
612
  add_meta(doc1, file)
613
  elif file.lower().endswith('.rtf'):
614
  docs1 = UnstructuredRTFLoader(file).load()
@@ -617,7 +1061,8 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
617
  elif file.lower().endswith('.md'):
618
  docs1 = UnstructuredMarkdownLoader(file).load()
619
  add_meta(docs1, file)
620
- doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
 
621
  elif file.lower().endswith('.enex'):
622
  docs1 = EverNoteLoader(file).load()
623
  add_meta(doc1, file)
@@ -682,6 +1127,7 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
682
  with open(file, "r") as f:
683
  doc1 = Document(page_content=f.read(), metadata={"source": file})
684
  add_meta(doc1, file)
 
685
  elif file.lower().endswith('.pdf'):
686
  env_gpt4all_file = ".env_gpt4all"
687
  from dotenv import dotenv_values
@@ -692,11 +1138,17 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
692
  from langchain.document_loaders import PyMuPDFLoader
693
  # load() still chunks by pages, but every page has title at start to help
694
  doc1 = PyMuPDFLoader(file).load()
 
 
 
 
695
  else:
696
  # open-source fallback
697
  # load() still chunks by pages, but every page has title at start to help
698
  doc1 = PyPDFLoader(file).load()
 
699
  # Some PDFs return nothing or junk from PDFMinerLoader
 
700
  add_meta(doc1, file)
701
  elif file.lower().endswith('.csv'):
702
  doc1 = CSVLoader(file).load()
@@ -704,6 +1156,7 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
704
  elif file.lower().endswith('.py'):
705
  doc1 = PythonLoader(file).load()
706
  add_meta(doc1, file)
 
707
  elif file.lower().endswith('.toml'):
708
  doc1 = TomlLoader(file).load()
709
  add_meta(doc1, file)
@@ -794,15 +1247,16 @@ def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=
794
  existing_files=[],
795
  existing_hash_ids={},
796
  ):
 
797
  globs_image_types = []
798
  globs_non_image_types = []
799
  if not path_or_paths and not url and not text:
800
  return []
801
  elif url:
802
- globs_non_image_types = [url]
803
  elif text:
804
- globs_non_image_types = [text]
805
- elif isinstance(path_or_paths, str):
806
  # single path, only consume allowed files
807
  path = path_or_paths
808
  # Below globs should match patterns in file_to_doc()
@@ -811,8 +1265,11 @@ def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=
811
  [globs_non_image_types.extend(glob.glob(os.path.join(path, "./**/*.%s" % ftype), recursive=True))
812
  for ftype in non_image_types]
813
  else:
 
 
814
  # list/tuple of files (consume what can, and exception those that selected but cannot consume so user knows)
815
- assert isinstance(path_or_paths, (list, tuple)), "Wrong type for path_or_paths: %s" % type(path_or_paths)
 
816
  # reform out of allowed types
817
  globs_image_types.extend(flatten_list([[x for x in path_or_paths if x.endswith(y)] for y in image_types]))
818
  # could do below:
@@ -972,7 +1429,7 @@ def check_update_chroma_embedding(db, use_openai_embedding, hf_embedding_model,
972
  if load_embed(db) != (use_openai_embedding, hf_embedding_model):
973
  print("Detected new embedding, updating db: %s" % langchain_mode, flush=True)
974
  # handle embedding changes
975
- db_get = db.get()
976
  sources = [Document(page_content=result[0], metadata=result[1] or {})
977
  for result in zip(db_get['documents'], db_get['metadatas'])]
978
  # delete index, has to be redone
@@ -1023,14 +1480,17 @@ def get_existing_db(db, persist_directory, load_db_if_exists, db_type, use_opena
1023
  if changed_db:
1024
  db = db_trial
1025
  # only call persist if really changed db, else takes too long for large db
1026
- db.persist()
1027
- clear_embedding(db)
 
1028
  save_embed(db, use_openai_embedding, hf_embedding_model)
1029
  return db
1030
  return None
1031
 
1032
 
1033
  def clear_embedding(db):
 
 
1034
  # don't keep on GPU, wastes memory, push back onto CPU and only put back on GPU once again embed
1035
  db._embedding_function.client.cpu()
1036
  clear_torch_cache()
@@ -1052,9 +1512,10 @@ def make_db(**langchain_kwargs):
1052
 
1053
 
1054
  def save_embed(db, use_openai_embedding, hf_embedding_model):
1055
- embed_info_file = os.path.join(db._persist_directory, 'embed_info')
1056
- with open(embed_info_file, 'wb') as f:
1057
- pickle.dump((use_openai_embedding, hf_embedding_model), f)
 
1058
  return use_openai_embedding, hf_embedding_model
1059
 
1060
 
@@ -1201,28 +1662,90 @@ def _make_db(use_openai_embedding=False,
1201
  return db, len(new_sources_metadata), new_sources_metadata
1202
 
1203
 
1204
  def get_existing_files(db):
1205
- collection = db.get()
1206
- metadata_sources = set([x['source'] for x in collection['metadatas']])
1207
  return metadata_sources
1208
 
1209
 
1210
  def get_existing_hash_ids(db):
1211
- collection = db.get()
1212
  # assume consistency, that any prior hashed source was single hashed file at the time among all source chunks
1213
- metadata_hash_ids = {x['source']: x.get('hashid') for x in collection['metadatas']}
1214
  return metadata_hash_ids
1215
 
1216
 
1217
- source_prefix = "Sources [Score | Link]:"
1218
- source_postfix = "End Sources<p>"
1219
-
1220
-
1221
  def run_qa_db(**kwargs):
1222
  func_names = list(inspect.signature(_run_qa_db).parameters)
1223
  # hard-coded defaults
1224
  kwargs['answer_with_sources'] = True
1225
- kwargs['sanitize_bot_response'] = True
1226
  kwargs['show_rank'] = False
1227
  missing_kwargs = [x for x in func_names if x not in kwargs]
1228
  assert not missing_kwargs, "Missing kwargs: %s" % missing_kwargs
@@ -1240,7 +1763,7 @@ def _run_qa_db(query=None,
1240
  user_path=None,
1241
  detect_user_path_changes_every_query=False,
1242
  db_type='faiss',
1243
- model_name=None, model=None, tokenizer=None,
1244
  hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
1245
  stream_output=False,
1246
  prompter=None,
@@ -1248,7 +1771,7 @@ def _run_qa_db(query=None,
1248
  prompt_dict=None,
1249
  answer_with_sources=True,
1250
  cut_distanct=1.1,
1251
- sanitize_bot_response=True,
1252
  show_rank=False,
1253
  load_db_if_exists=False,
1254
  db=None,
@@ -1267,7 +1790,12 @@ def _run_qa_db(query=None,
1267
  document_choice=[DocumentChoices.All_Relevant.name],
1268
  n_jobs=-1,
1269
  verbose=False,
1270
- cli=False):
 
 
 
 
 
1271
  """
1272
 
1273
  :param query:
@@ -1286,6 +1814,8 @@ def _run_qa_db(query=None,
1286
  :param answer_with_sources
1287
  :return:
1288
  """
 
 
1289
  assert query is not None
1290
  assert prompter is not None or prompt_type is not None or model is None # if model is None, then will generate
1291
  if prompter is not None:
@@ -1299,7 +1829,9 @@ def _run_qa_db(query=None,
1299
  prompt_dict = ''
1300
  assert len(set(gen_hyper).difference(inspect.signature(get_llm).parameters)) == 0
1301
  llm, model_name, streamer, prompt_type_out = get_llm(use_openai_model=use_openai_model, model_name=model_name,
1302
- model=model, tokenizer=tokenizer,
 
 
1303
  stream_output=stream_output,
1304
  do_sample=do_sample,
1305
  temperature=temperature,
@@ -1315,13 +1847,10 @@ def _run_qa_db(query=None,
1315
  prompt_type=prompt_type,
1316
  prompt_dict=prompt_dict,
1317
  prompter=prompter,
 
1318
  verbose=verbose,
1319
  )
1320
 
1321
- if model_name in non_hf_types:
1322
- # FIXME: for now, streams to stdout/stderr currently
1323
- stream_output = False
1324
-
1325
  use_context = False
1326
  scores = []
1327
  chain = None
@@ -1349,43 +1878,49 @@ def _run_qa_db(query=None,
1349
  # can only return if HF type
1350
  return
1351
 
1352
- if stream_output:
1353
- answer = None
1354
- assert streamer is not None
1355
- import queue
1356
- bucket = queue.Queue()
1357
- thread = EThread(target=chain, streamer=streamer, bucket=bucket)
1358
- thread.start()
1359
- outputs = ""
1360
- prompt = None # FIXME
1361
- try:
1362
- for new_text in streamer:
1363
- # print("new_text: %s" % new_text, flush=True)
1364
- if bucket.qsize() > 0 or thread.exc:
1365
- thread.join()
1366
- outputs += new_text
1367
- if prompter: # and False: # FIXME: pipeline can already use prompter
1368
- output1 = prompter.get_response(outputs, prompt=prompt,
1369
- sanitize_bot_response=sanitize_bot_response)
1370
- yield output1, ''
1371
- else:
1372
- yield outputs, ''
1373
- except BaseException:
1374
- # if any exception, raise that exception if was from thread, first
1375
- if thread.exc:
1376
- raise thread.exc
1377
- raise
1378
- finally:
1379
- # in case no exception and didn't join with thread yet, then join
1380
- if not thread.exc:
1381
- answer = thread.join()
1382
- # in case raise StopIteration or broke queue loop in streamer, but still have exception
1383
- if thread.exc:
1384
- raise thread.exc
1385
- # FIXME: answer is not string outputs from streamer. How to get actual final output?
1386
- # answer = outputs
1387
- else:
1388
- answer = chain()
 
 
 
 
 
 
1389
 
1390
  if not use_context:
1391
  ret = answer['output_text']
@@ -1404,6 +1939,7 @@ def get_similarity_chain(query=None,
1404
  detect_user_path_changes_every_query=False,
1405
  db_type='faiss',
1406
  model_name=None,
 
1407
  hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
1408
  prompt_type=None,
1409
  prompt_dict=None,
@@ -1415,8 +1951,14 @@ def get_similarity_chain(query=None,
1415
  n_jobs=-1,
1416
  # beyond run_db_query:
1417
  llm=None,
 
1418
  verbose=False,
1419
  cmd=None,
 
 
 
 
 
1420
  ):
1421
  # determine whether use of context out of docs is planned
1422
  if not use_openai_model and prompt_type not in ['plain'] or model_name in non_hf_types:
@@ -1431,7 +1973,11 @@ def get_similarity_chain(query=None,
1431
  # FIXME: Seems to way to get size of chroma db to limit top_k_docs to avoid
1432
  # Chroma collection MyData contains fewer than 4 elements.
1433
  # type logger error
1434
- k_db = 1000 if db_type == 'chroma' else top_k_docs # top_k_docs=100 works ok too for
 
 
 
 
1435
 
1436
  # FIXME: For All just go over all dbs instead of a separate db for All
1437
  if not detect_user_path_changes_every_query and db is not None:
@@ -1452,6 +1998,29 @@ def get_similarity_chain(query=None,
1452
  n_jobs=n_jobs,
1453
  verbose=verbose)
1454
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1455
  if db and use_context:
1456
  if not isinstance(db, Chroma):
1457
  # only chroma supports filtering
@@ -1472,17 +2041,84 @@ def get_similarity_chain(query=None,
1472
  docs = []
1473
  scores = []
1474
  elif cmd == DocumentChoices.Only_All_Sources.name:
1475
- if isinstance(db, Chroma):
1476
- db_get = db._collection.get(where=filter_kwargs.get('filter'))
1477
- else:
1478
- db_get = db.get()
1479
  # similar to langchain's chroma's _results_to_docs_and_scores
1480
  docs_with_score = [(Document(page_content=result[0], metadata=result[1] or {}), 0)
1481
- for result in zip(db_get['documents'], db_get['metadatas'])][:top_k_docs]
1482
  docs = [x[0] for x in docs_with_score]
1483
  scores = [x[1] for x in docs_with_score]
1484
  else:
1485
- docs_with_score = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)[:top_k_docs]
1486
  # cut off so no high distance docs/sources considered
1487
  docs = [x[0] for x in docs_with_score if x[1] < cut_distanct]
1488
  scores = [x[1] for x in docs_with_score if x[1] < cut_distanct]
@@ -1517,19 +2153,11 @@ def get_similarity_chain(query=None,
1517
  if len(docs) == 0:
1518
  # avoid context == in prompt then
1519
  use_context = False
 
1520
 
1521
- if not use_openai_model and prompt_type not in ['plain'] or model_name in non_hf_types:
1522
  # instruct-like, rather than few-shot prompt_type='plain' as default
1523
  # but then sources confuse the model with how inserted among rest of text, so avoid
1524
- prefix = ""
1525
- if langchain_mode in ['Disabled', 'ChatLLM', 'LLM'] or not use_context:
1526
- template = """%s{context}{question}""" % prefix
1527
- else:
1528
- template = """%s
1529
- ==
1530
- {context}
1531
- ==
1532
- {question}""" % prefix
1533
  prompt = PromptTemplate(
1534
  # input_variables=["summaries", "question"],
1535
  input_variables=["context", "question"],
@@ -1589,17 +2217,32 @@ def get_sources_answer(query, answer, scores, show_rank, answer_with_sources, ve
1589
  return ret, extra
1590
 
1591
 
1592
- def chunk_sources(sources, chunk=True, chunk_size=512):
 
 
 
 
 
 
 
 
1593
  if not chunk:
1594
  return sources
1595
- source_chunks = []
1596
- # Below for known separator
1597
- # splitter = CharacterTextSplitter(separator=" ", chunk_size=chunk_size, chunk_overlap=0)
1598
- splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=0)
1599
- for source in sources:
1600
- # print(source.metadata['source'], flush=True)
1601
- for chunky in splitter.split_text(source.page_content):
1602
- source_chunks.append(Document(page_content=chunky, metadata=source.metadata))
 
 
 
 
 
 
 
1603
  return source_chunks
1604
 
1605
 
@@ -1647,15 +2290,17 @@ def _create_local_weaviate_client():
1647
  WEAVIATE_SCOPE = os.getenv('WEAVIATE_SCOPE', "offline_access")
1648
 
1649
  resource_owner_config = None
1650
- if WEAVIATE_USERNAME is not None and WEAVIATE_PASSWORD is not None:
1651
- resource_owner_config = weaviate.AuthClientPassword(
1652
- username=WEAVIATE_USERNAME,
1653
- password=WEAVIATE_PASSWORD,
1654
- scope=WEAVIATE_SCOPE
1655
- )
1656
-
1657
  try:
 
 
 
 
 
 
 
 
1658
  client = weaviate.Client(WEAVIATE_URL, auth_client_secret=resource_owner_config)
 
1659
  except Exception as e:
1660
  print(f"Failed to create Weaviate client: {e}")
1661
  return None
 
1
+ import ast
2
  import glob
3
  import inspect
4
  import os
5
  import pathlib
6
  import pickle
 
 
7
  import shutil
8
  import subprocess
 
9
  import tempfile
10
+ import time
11
  import traceback
12
+ import types
13
  import uuid
14
  import zipfile
15
  from collections import defaultdict
16
  from datetime import datetime
17
  from functools import reduce
18
  from operator import concat
19
+ import filelock
20
 
21
+ from joblib import delayed
22
+ from langchain.callbacks import streaming_stdout
23
  from langchain.embeddings import HuggingFaceInstructEmbeddings
24
  from tqdm import tqdm
25
 
26
+ from enums import DocumentChoices, no_lora_str, model_token_mapping, source_prefix, source_postfix
27
+ from generate import gen_hyper, get_model, SEED
28
+ from prompter import non_hf_types, PromptType, Prompter
29
  from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
30
+ get_device, ProgressParallel, remove, hash_file, clear_torch_cache, NullContext, get_hf_server, FakeTokenizer
31
+ from utils_langchain import StreamingGradioCallbackHandler
32
 
33
  import_matplotlib()
34
 
 
43
  from langchain.document_loaders import PyPDFLoader, TextLoader, CSVLoader, PythonLoader, TomlLoader, \
44
  UnstructuredURLLoader, UnstructuredHTMLLoader, UnstructuredWordDocumentLoader, UnstructuredMarkdownLoader, \
45
  EverNoteLoader, UnstructuredEmailLoader, UnstructuredODTLoader, UnstructuredPowerPointLoader, \
46
+ UnstructuredEPubLoader, UnstructuredImageLoader, UnstructuredRTFLoader, ArxivLoader, UnstructuredPDFLoader
47
+ from langchain.text_splitter import RecursiveCharacterTextSplitter, Language
48
  from langchain.chains.question_answering import load_qa_chain
49
  from langchain.docstore.document import Document
50
+ from langchain import PromptTemplate, HuggingFaceTextGenInference
51
  from langchain.vectorstores import Chroma
52
 
53
 
 
145
  return db, num_new_sources, []
146
  db.add_documents(documents=sources)
147
  elif db_type == 'chroma':
148
+ collection = get_documents(db)
149
  # files we already have:
150
  metadata_files = set([x['source'] for x in collection['metadatas']])
151
  if avoid_dup_by_file:
 
160
  [x['hashid'] for x in collection['metadatas'] if 'hashid' in x and x['hashid'] not in ["None", None]])
161
  # avoid sources with same hash
162
  sources = [x for x in sources if x.metadata.get('hashid') not in metadata_hash_ids]
163
+ num_nohash = len([x for x in sources if not x.metadata.get('hashid')])
164
+ print("Found %s new sources (%d have no hash in original source,"
165
+ " so have to reprocess for migration to sources with hash)" % (len(sources), num_nohash), flush=True)
166
  # get new file names that match existing file names. delete existing files we are overridding
167
  dup_metadata_files = set([x.metadata['source'] for x in sources if x.metadata['source'] in metadata_files])
168
  print("Removing %s duplicate files from db because ingesting those as new documents" % len(
 
239
  if use_openai_embedding:
240
  assert os.getenv("OPENAI_API_KEY") is not None, "Set ENV OPENAI_API_KEY"
241
  from langchain.embeddings import OpenAIEmbeddings
242
+ embedding = OpenAIEmbeddings(disallowed_special=())
243
  else:
244
  # to ensure can fork without deadlock
245
  from langchain.embeddings import HuggingFaceEmbeddings
 
266
  )["output_text"]
267
 
268
 
269
+ """Wrapper around Huggingface text generation inference API."""
270
+ from functools import partial
271
+ from typing import Any, Dict, List, Optional, Set
272
+
273
+ from pydantic import Extra, Field, root_validator
274
+
275
+ from langchain.callbacks.manager import CallbackManagerForLLMRun
276
+
277
+ """Wrapper around Huggingface text generation inference API."""
278
+ from functools import partial
279
+ from typing import Any, Dict, List, Optional
280
+
281
+ from pydantic import Extra, Field, root_validator
282
+
283
+ from langchain.callbacks.manager import CallbackManagerForLLMRun
284
+ from langchain.llms.base import LLM
285
+
286
+
287
+ class GradioInference(LLM):
288
+ """
289
+ Gradio generation inference API.
290
+ """
291
+ inference_server_url: str = ""
292
+
293
+ temperature: float = 0.8
294
+ top_p: Optional[float] = 0.95
295
+ top_k: Optional[int] = None
296
+ num_beams: Optional[int] = 1
297
+ max_new_tokens: int = 512
298
+ min_new_tokens: int = 1
299
+ early_stopping: bool = False
300
+ max_time: int = 180
301
+ repetition_penalty: Optional[float] = None
302
+ num_return_sequences: Optional[int] = 1
303
+ do_sample: bool = False
304
+ chat_client: bool = False
305
+
306
+ return_full_text: bool = True
307
+ stream: bool = False
308
+ sanitize_bot_response: bool = False
309
+
310
+ prompter: Any = None
311
+ client: Any = None
312
+
313
+ class Config:
314
+ """Configuration for this pydantic object."""
315
+
316
+ extra = Extra.forbid
317
+
318
+ @root_validator()
319
+ def validate_environment(cls, values: Dict) -> Dict:
320
+ """Validate that python package exists in environment."""
321
+
322
+ try:
323
+ if values['client'] is None:
324
+ import gradio_client
325
+ values["client"] = gradio_client.Client(
326
+ values["inference_server_url"]
327
+ )
328
+ except ImportError:
329
+ raise ImportError(
330
+ "Could not import gradio_client python package. "
331
+ "Please install it with `pip install gradio_client`."
332
+ )
333
+ return values
334
+
335
+ @property
336
+ def _llm_type(self) -> str:
337
+ """Return type of llm."""
338
+ return "gradio_inference"
339
+
340
+ def _call(
341
+ self,
342
+ prompt: str,
343
+ stop: Optional[List[str]] = None,
344
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
345
+ **kwargs: Any,
346
+ ) -> str:
347
+ # NOTE: prompt here has no prompt_type (e.g. human: bot:) prompt injection,
348
+ # so server should get prompt_type or '', not plain
349
+ # This is good, so gradio server can also handle stopping.py conditions
350
+ # this is different than TGI server that uses prompter to inject prompt_type prompting
351
+ stream_output = self.stream
352
+ gr_client = self.client
353
+ client_langchain_mode = 'Disabled'
354
+ top_k_docs = 1
355
+ chunk = True
356
+ chunk_size = 512
357
+ client_kwargs = dict(instruction=prompt if self.chat_client else '', # only for chat=True
358
+ iinput='', # only for chat=True
359
+ context='',
360
+ # streaming output is supported, loops over and outputs each generation in streaming mode
361
+ # but leave stream_output=False for simple input/output mode
362
+ stream_output=stream_output,
363
+ prompt_type=self.prompter.prompt_type,
364
+ prompt_dict='',
365
+
366
+ temperature=self.temperature,
367
+ top_p=self.top_p,
368
+ top_k=self.top_k,
369
+ num_beams=self.num_beams,
370
+ max_new_tokens=self.max_new_tokens,
371
+ min_new_tokens=self.min_new_tokens,
372
+ early_stopping=self.early_stopping,
373
+ max_time=self.max_time,
374
+ repetition_penalty=self.repetition_penalty,
375
+ num_return_sequences=self.num_return_sequences,
376
+ do_sample=self.do_sample,
377
+ chat=self.chat_client,
378
+
379
+ instruction_nochat=prompt if not self.chat_client else '',
380
+ iinput_nochat='', # only for chat=False
381
+ langchain_mode=client_langchain_mode,
382
+ top_k_docs=top_k_docs,
383
+ chunk=chunk,
384
+ chunk_size=chunk_size,
385
+ document_choice=[DocumentChoices.All_Relevant.name],
386
+ )
387
+ api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing
388
+ if not stream_output:
389
+ res = gr_client.predict(str(dict(client_kwargs)), api_name=api_name)
390
+ res_dict = ast.literal_eval(res)
391
+ text = res_dict['response']
392
+ return self.prompter.get_response(prompt + text, prompt=prompt,
393
+ sanitize_bot_response=self.sanitize_bot_response)
394
+ else:
395
+ text_callback = None
396
+ if run_manager:
397
+ text_callback = partial(
398
+ run_manager.on_llm_new_token, verbose=self.verbose
399
+ )
400
+
401
+ job = gr_client.submit(str(dict(client_kwargs)), api_name=api_name)
402
+ text0 = ''
403
+ while not job.done():
404
+ outputs_list = job.communicator.job.outputs
405
+ if outputs_list:
406
+ res = job.communicator.job.outputs[-1]
407
+ res_dict = ast.literal_eval(res)
408
+ text = res_dict['response']
409
+ text = self.prompter.get_response(prompt + text, prompt=prompt,
410
+ sanitize_bot_response=self.sanitize_bot_response)
411
+ # FIXME: derive chunk from full for now
412
+ text_chunk = text[len(text0):]
413
+ # save old
414
+ text0 = text
415
+
416
+ if text_callback:
417
+ text_callback(text_chunk)
418
+
419
+ time.sleep(0.01)
420
+
421
+ # ensure get last output to avoid race
422
+ res_all = job.outputs()
423
+ if len(res_all) > 0:
424
+ res = res_all[-1]
425
+ res_dict = ast.literal_eval(res)
426
+ text = res_dict['response']
427
+ # FIXME: derive chunk from full for now
428
+ else:
429
+ # go with old if failure
430
+ text = text0
431
+ text_chunk = text[len(text0):]
432
+ if text_callback:
433
+ text_callback(text_chunk)
434
+ return self.prompter.get_response(prompt + text, prompt=prompt,
435
+ sanitize_bot_response=self.sanitize_bot_response)
436
+
437
+
438
+ class H2OHuggingFaceTextGenInference(HuggingFaceTextGenInference):
439
+ max_new_tokens: int = 512
440
+ do_sample: bool = False
441
+ top_k: Optional[int] = None
442
+ top_p: Optional[float] = 0.95
443
+ typical_p: Optional[float] = 0.95
444
+ temperature: float = 0.8
445
+ repetition_penalty: Optional[float] = None
446
+ return_full_text: bool = False
447
+ stop_sequences: List[str] = Field(default_factory=list)
448
+ seed: Optional[int] = None
449
+ inference_server_url: str = ""
450
+ timeout: int = 300
451
+ headers: dict = None
452
+ stream: bool = False
453
+ sanitize_bot_response: bool = False
454
+ prompter: Any = None
455
+ tokenizer: Any = None
456
+ client: Any = None
457
+
458
+ @root_validator()
459
+ def validate_environment(cls, values: Dict) -> Dict:
460
+ """Validate that python package exists in environment."""
461
+
462
+ try:
463
+ if values['client'] is None:
464
+ import text_generation
465
+
466
+ values["client"] = text_generation.Client(
467
+ values["inference_server_url"],
468
+ timeout=values["timeout"],
469
+ headers=values["headers"],
470
+ )
471
+ except ImportError:
472
+ raise ImportError(
473
+ "Could not import text_generation python package. "
474
+ "Please install it with `pip install text_generation`."
475
+ )
476
+ return values
477
+
478
+ def _call(
479
+ self,
480
+ prompt: str,
481
+ stop: Optional[List[str]] = None,
482
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
483
+ **kwargs: Any,
484
+ ) -> str:
485
+ if stop is None:
486
+ stop = self.stop_sequences
487
+ else:
488
+ stop += self.stop_sequences
489
+
490
+ # HF inference server needs control over input tokens
491
+ assert self.tokenizer is not None
492
+ from h2oai_pipeline import H2OTextGenerationPipeline
493
+ prompt, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt, self.tokenizer)
494
+
495
+ # NOTE: TGI server does not add prompting, so must do here
496
+ data_point = dict(context='', instruction=prompt, input='')
497
+ prompt = self.prompter.generate_prompt(data_point)
498
+
499
+ gen_server_kwargs = dict(do_sample=self.do_sample,
500
+ stop_sequences=stop,
501
+ max_new_tokens=self.max_new_tokens,
502
+ top_k=self.top_k,
503
+ top_p=self.top_p,
504
+ typical_p=self.typical_p,
505
+ temperature=self.temperature,
506
+ repetition_penalty=self.repetition_penalty,
507
+ return_full_text=self.return_full_text,
508
+ seed=self.seed,
509
+ )
510
+ gen_server_kwargs.update(kwargs)
511
+
512
+ # lower bound because client is re-used if multi-threading
513
+ self.client.timeout = max(300, self.timeout)
514
+
515
+ if not self.stream:
516
+ res = self.client.generate(
517
+ prompt,
518
+ **gen_server_kwargs,
519
+ )
520
+ if self.return_full_text:
521
+ gen_text = res.generated_text[len(prompt):]
522
+ else:
523
+ gen_text = res.generated_text
524
+ # remove stop sequences from the end of the generated text
525
+ for stop_seq in stop:
526
+ if stop_seq in gen_text:
527
+ gen_text = gen_text[:gen_text.index(stop_seq)]
528
+ text = prompt + gen_text
529
+ text = self.prompter.get_response(text, prompt=prompt,
530
+ sanitize_bot_response=self.sanitize_bot_response)
531
+ else:
532
+ text_callback = None
533
+ if run_manager:
534
+ text_callback = partial(
535
+ run_manager.on_llm_new_token, verbose=self.verbose
536
+ )
537
+ # parent handler of streamer expects to see prompt first else output="" and lose if prompt=None in prompter
538
+ if text_callback:
539
+ text_callback(prompt)
540
+ text = ""
541
+ # Note: Streaming ignores return_full_text=True
542
+ for response in self.client.generate_stream(prompt, **gen_server_kwargs):
543
+ text_chunk = response.token.text
544
+ text += text_chunk
545
+ text = self.prompter.get_response(prompt + text, prompt=prompt,
546
+ sanitize_bot_response=self.sanitize_bot_response)
547
+ # stream part
548
+ is_stop = False
549
+ for stop_seq in stop:
550
+ if stop_seq in response.token.text:
551
+ is_stop = True
552
+ break
553
+ if is_stop:
554
+ break
555
+ if not response.token.special:
556
+ if text_callback:
557
+ text_callback(response.token.text)
558
+ return text
559
+
560
+
561
+ from langchain.chat_models import ChatOpenAI
562
+
563
+
564
+ class H2OChatOpenAI(ChatOpenAI):
565
+ @classmethod
566
+ def all_required_field_names(cls) -> Set:
567
+ all_required_field_names = super(ChatOpenAI, cls).all_required_field_names()
568
+ all_required_field_names.update({'top_p', 'frequency_penalty', 'presence_penalty'})
569
+ return all_required_field_names
570
+
571
+
572
+ def get_llm(use_openai_model=False,
573
+ model_name=None,
574
+ model=None,
575
+ tokenizer=None,
576
+ inference_server=None,
577
+ stream_output=False,
578
  do_sample=False,
579
  temperature=0.1,
580
  top_k=40,
 
589
  prompt_type=None,
590
  prompt_dict=None,
591
  prompter=None,
592
+ sanitize_bot_response=False,
593
  verbose=False,
594
  ):
595
+ if use_openai_model or inference_server in ['openai', 'openai_chat']:
596
+ if use_openai_model and model_name is None:
597
+ model_name = "gpt-3.5-turbo"
598
+ if inference_server == 'openai':
599
+ from langchain.llms import OpenAI
600
+ cls = OpenAI
601
+ else:
602
+ cls = H2OChatOpenAI
603
+ callbacks = [StreamingGradioCallbackHandler()]
604
+ llm = cls(model_name=model_name,
605
+ temperature=temperature if do_sample else 0,
606
+ # FIXME: Need to count tokens and reduce max_new_tokens to fit like in generate.py
607
+ max_tokens=max_new_tokens,
608
+ top_p=top_p if do_sample else 1,
609
+ frequency_penalty=0,
610
+ presence_penalty=1.07 - repetition_penalty + 0.6, # so good default
611
+ callbacks=callbacks if stream_output else None,
612
+ )
613
+ streamer = callbacks[0] if stream_output else None
614
+ if inference_server in ['openai', 'openai_chat']:
615
+ prompt_type = inference_server
616
+ else:
617
+ prompt_type = prompt_type or 'plain'
618
+ elif inference_server:
619
+ assert inference_server.startswith(
620
+ 'http'), "Malformed inference_server=%s. Did you add http:// in front?" % inference_server
621
+
622
+ from gradio_utils.grclient import GradioClient
623
+ from text_generation import Client as HFClient
624
+ if isinstance(model, GradioClient):
625
+ gr_client = model
626
+ hf_client = None
627
+ else:
628
+ gr_client = None
629
+ hf_client = model
630
+ assert isinstance(hf_client, HFClient)
631
+
632
+ inference_server, headers = get_hf_server(inference_server)
633
+
634
+ # quick sanity check to avoid long timeouts, just see if can reach server
635
+ requests.get(inference_server, timeout=int(os.getenv('REQUEST_TIMEOUT_FAST', '10')))
636
+
637
+ callbacks = [StreamingGradioCallbackHandler()]
638
+ assert prompter is not None
639
+ stop_sequences = list(set(prompter.terminate_response + [prompter.PreResponse]))
640
+
641
+ if gr_client:
642
+ chat_client = False
643
+ llm = GradioInference(
644
+ inference_server_url=inference_server,
645
+ return_full_text=True,
646
+
647
+ temperature=temperature,
648
+ top_p=top_p,
649
+ top_k=top_k,
650
+ num_beams=num_beams,
651
+ max_new_tokens=max_new_tokens,
652
+ min_new_tokens=min_new_tokens,
653
+ early_stopping=early_stopping,
654
+ max_time=max_time,
655
+ repetition_penalty=repetition_penalty,
656
+ num_return_sequences=num_return_sequences,
657
+ do_sample=do_sample,
658
+ chat_client=chat_client,
659
+
660
+ callbacks=callbacks if stream_output else None,
661
+ stream=stream_output,
662
+ prompter=prompter,
663
+ client=gr_client,
664
+ sanitize_bot_response=sanitize_bot_response,
665
+ )
666
+ elif hf_client:
667
+ llm = H2OHuggingFaceTextGenInference(
668
+ inference_server_url=inference_server,
669
+ do_sample=do_sample,
670
+ max_new_tokens=max_new_tokens,
671
+ repetition_penalty=repetition_penalty,
672
+ return_full_text=True,
673
+ seed=SEED,
674
+
675
+ stop_sequences=stop_sequences,
676
+ temperature=temperature,
677
+ top_k=top_k,
678
+ top_p=top_p,
679
+ # typical_p=top_p,
680
+ callbacks=callbacks if stream_output else None,
681
+ stream=stream_output,
682
+ prompter=prompter,
683
+ tokenizer=tokenizer,
684
+ client=hf_client,
685
+ timeout=max_time,
686
+ sanitize_bot_response=sanitize_bot_response,
687
+ )
688
+ else:
689
+ raise RuntimeError("No defined client")
690
+ streamer = callbacks[0] if stream_output else None
691
  elif model_name in non_hf_types:
692
+ if model_name == 'llama':
693
+ callbacks = [StreamingGradioCallbackHandler()]
694
+ streamer = callbacks[0] if stream_output else None
695
+ else:
696
+ # stream_output = False
697
+ # doesn't stream properly as generator, but at least
698
+ callbacks = [streaming_stdout.StreamingStdOutCallbackHandler()]
699
+ streamer = None
700
+ if prompter:
701
+ prompt_type = prompter.prompt_type
702
+ else:
703
+ prompter = Prompter(prompt_type, prompt_dict, debug=False, chat=False, stream_output=stream_output)
704
+ pass # assume inputted prompt_type is correct
705
  from gpt4all_llm import get_llm_gpt4all
706
  llm = get_llm_gpt4all(model_name, model=model, max_new_tokens=max_new_tokens,
707
  temperature=temperature,
708
  repetition_penalty=repetition_penalty,
709
  top_k=top_k,
710
  top_p=top_p,
711
+ callbacks=callbacks,
712
  verbose=verbose,
713
+ streaming=stream_output,
714
+ prompter=prompter,
715
  )
 
 
716
  else:
 
 
717
  if model is None:
718
  # only used if didn't pass model in
719
  assert tokenizer is None
720
  prompt_type = 'human_bot'
721
+ if model_name is None:
722
+ model_name = 'h2oai/h2ogpt-oasst1-512-12b'
723
+ # model_name = 'h2oai/h2ogpt-oig-oasst1-512-6_9b'
724
+ # model_name = 'h2oai/h2ogpt-oasst1-512-20b'
725
+ inference_server = ''
726
+ model, tokenizer, device = get_model(load_8bit=True, base_model=model_name,
727
+ inference_server=inference_server, gpu_id=0)
 
728
 
729
  max_max_tokens = tokenizer.model_max_length
730
  gen_kwargs = dict(do_sample=do_sample,
 
739
  repetition_penalty=repetition_penalty,
740
  num_return_sequences=num_return_sequences,
741
  return_full_text=True,
742
+ handle_long_generation=None)
743
  assert len(set(gen_hyper).difference(gen_kwargs.keys())) == 0
744
 
745
  if stream_output:
 
756
  prompter=prompter,
757
  prompt_type=prompt_type,
758
  prompt_dict=prompt_dict,
759
+ sanitize_bot_response=sanitize_bot_response,
760
  chat=False, stream_output=stream_output,
761
  tokenizer=tokenizer,
762
+ # leave some room for 1 paragraph, even if min_new_tokens=0
763
+ max_input_tokens=max_max_tokens - max(min_new_tokens, 256),
764
  **gen_kwargs)
765
  # pipe.task = "text-generation"
766
  # below makes it listen only to our prompt removal,
 
805
  data = json.load(open(filename, "rt"))
806
  page_content = list(data["query"]["pages"].values())[0]["extract"]
807
  if take_head is not None and text_limit is not None:
808
+ page_content = page_content[:text_limit] if take_head else page_content[-text_limit:]
809
  title_url = str(title).replace(' ', '_')
810
  return Document(
811
  page_content=page_content,
 
927
  except (pkg_resources.DistributionNotFound, AssertionError):
928
  have_pymupdf = False
929
 
930
+ try:
931
+ assert pkg_resources.get_distribution('selenium') is not None
932
+ have_selenium = True
933
+ except (pkg_resources.DistributionNotFound, AssertionError):
934
+ have_selenium = False
935
+
936
+ try:
937
+ assert pkg_resources.get_distribution('playwright') is not None
938
+ have_playwright = True
939
+ except (pkg_resources.DistributionNotFound, AssertionError):
940
+ have_playwright = False
941
+
942
+ # disable, hangs too often
943
+ have_playwright = False
944
+
945
  image_types = ["png", "jpg", "jpeg"]
946
  non_image_types = ["pdf", "txt", "csv", "toml", "py", "rst", "rtf",
947
  "md", "html",
 
959
  def add_meta(docs1, file):
960
  file_extension = pathlib.Path(file).suffix
961
  hashid = hash_file(file)
962
+ if not isinstance(docs1, (list, tuple, types.GeneratorType)):
963
  docs1 = [docs1]
964
  [x.metadata.update(dict(input_type=file_extension, date=str(datetime.now()), hashid=hashid)) for x in docs1]
965
 
 
1001
  else:
1002
  docs1 = []
1003
  else:
1004
+ if not (file.startswith("http://") or file.startswith("file://") or file.startswith("https://")):
1005
+ file = 'http://' + file
1006
  docs1 = UnstructuredURLLoader(urls=[file]).load()
1007
+ if len(docs1) == 0 and have_playwright:
1008
+ # then something went wrong, try another loader:
1009
+ from langchain.document_loaders import PlaywrightURLLoader
1010
+ docs1 = PlaywrightURLLoader(urls=[file]).load()
1011
+ if len(docs1) == 0 and have_selenium:
1012
+ # then something went wrong, try another loader:
1013
+ # but requires Chrome binary, else get: selenium.common.exceptions.WebDriverException: Message: unknown error: cannot find Chrome binary
1014
+ from langchain.document_loaders import SeleniumURLLoader
1015
+ from selenium.common.exceptions import WebDriverException
1016
+ try:
1017
+ docs1 = SeleniumURLLoader(urls=[file]).load()
1018
+ except WebDriverException as e:
1019
+ print("No web driver: %s" % str(e), flush=True)
1020
  [x.metadata.update(dict(input_type='url', date=str(datetime.now()))) for x in docs1]
1021
+ docs1 = clean_doc(docs1)
1022
  doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
1023
  elif is_txt:
1024
  base_path = "user_paste"
 
1028
  f.write(file)
1029
  metadata = dict(source=source_file, date=str(datetime.now()), input_type='pasted txt')
1030
  doc1 = Document(page_content=file, metadata=metadata)
1031
+ doc1 = clean_doc(doc1)
1032
  elif file.lower().endswith('.html') or file.lower().endswith('.mhtml'):
1033
  docs1 = UnstructuredHTMLLoader(file_path=file).load()
1034
  add_meta(docs1, file)
1035
+ docs1 = clean_doc(docs1)
1036
+ doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size, language=Language.HTML)
1037
  elif (file.lower().endswith('.docx') or file.lower().endswith('.doc')) and have_libreoffice:
1038
  docs1 = UnstructuredWordDocumentLoader(file_path=file).load()
1039
  add_meta(docs1, file)
 
1045
  elif file.lower().endswith('pptx') or file.lower().endswith('ppt'):
1046
  docs1 = UnstructuredPowerPointLoader(file_path=file).load()
1047
  add_meta(docs1, file)
1048
+ docs1 = clean_doc(docs1)
1049
  doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
1050
  elif file.lower().endswith('.txt'):
1051
  # use UnstructuredFileLoader ?
1052
  docs1 = TextLoader(file, encoding="utf8", autodetect_encoding=True).load()
1053
  # makes just one, but big one
1054
  doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
1055
+ doc1 = clean_doc(doc1)
1056
  add_meta(doc1, file)
1057
  elif file.lower().endswith('.rtf'):
1058
  docs1 = UnstructuredRTFLoader(file).load()
 
1061
  elif file.lower().endswith('.md'):
1062
  docs1 = UnstructuredMarkdownLoader(file).load()
1063
  add_meta(docs1, file)
1064
+ docs1 = clean_doc(docs1)
1065
+ doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size, language=Language.MARKDOWN)
1066
  elif file.lower().endswith('.enex'):
1067
  docs1 = EverNoteLoader(file).load()
1068
  add_meta(doc1, file)
 
1127
  with open(file, "r") as f:
1128
  doc1 = Document(page_content=f.read(), metadata={"source": file})
1129
  add_meta(doc1, file)
1130
+ doc1 = chunk_sources(doc1, chunk=chunk, chunk_size=chunk_size, language=Language.RST)
1131
  elif file.lower().endswith('.pdf'):
1132
  env_gpt4all_file = ".env_gpt4all"
1133
  from dotenv import dotenv_values
 
1138
  from langchain.document_loaders import PyMuPDFLoader
1139
  # load() still chunks by pages, but every page has title at start to help
1140
  doc1 = PyMuPDFLoader(file).load()
1141
+ doc1 = clean_doc(doc1)
1142
+ elif pdf_class_name == 'UnstructuredPDFLoader':
1143
+ doc1 = UnstructuredPDFLoader(file).load()
1144
+ # seems to not need cleaning in most cases
1145
  else:
1146
  # open-source fallback
1147
  # load() still chunks by pages, but every page has title at start to help
1148
  doc1 = PyPDFLoader(file).load()
1149
+ doc1 = clean_doc(doc1)
1150
  # Some PDFs return nothing or junk from PDFMinerLoader
1151
+ doc1 = chunk_sources(doc1, chunk=chunk, chunk_size=chunk_size)
1152
  add_meta(doc1, file)
1153
  elif file.lower().endswith('.csv'):
1154
  doc1 = CSVLoader(file).load()
 
1156
  elif file.lower().endswith('.py'):
1157
  doc1 = PythonLoader(file).load()
1158
  add_meta(doc1, file)
1159
+ doc1 = chunk_sources(doc1, chunk=chunk, chunk_size=chunk_size, language=Language.PYTHON)
1160
  elif file.lower().endswith('.toml'):
1161
  doc1 = TomlLoader(file).load()
1162
  add_meta(doc1, file)
 
1247
  existing_files=[],
1248
  existing_hash_ids={},
1249
  ):
1250
+ # path_or_paths could be str, list, tuple, generator
1251
  globs_image_types = []
1252
  globs_non_image_types = []
1253
  if not path_or_paths and not url and not text:
1254
  return []
1255
  elif url:
1256
+ globs_non_image_types = url if isinstance(url, (list, tuple, types.GeneratorType)) else [url]
1257
  elif text:
1258
+ globs_non_image_types = text if isinstance(text, (list, tuple, types.GeneratorType)) else [text]
1259
+ elif isinstance(path_or_paths, str) and os.path.isdir(path_or_paths):
1260
  # single path, only consume allowed files
1261
  path = path_or_paths
1262
  # Below globs should match patterns in file_to_doc()
 
1265
  [globs_non_image_types.extend(glob.glob(os.path.join(path, "./**/*.%s" % ftype), recursive=True))
1266
  for ftype in non_image_types]
1267
  else:
1268
+ if isinstance(path_or_paths, str) and (os.path.isfile(path_or_paths) or os.path.isdir(path_or_paths)):
1269
+ path_or_paths = [path_or_paths]
1270
  # list/tuple of files (consume what we can, and raise for files that were selected but cannot be consumed, so the user knows)
1271
+ assert isinstance(path_or_paths, (list, tuple, types.GeneratorType)), "Wrong type for path_or_paths: %s" % type(
1272
+ path_or_paths)
1273
  # reform out of allowed types
1274
  globs_image_types.extend(flatten_list([[x for x in path_or_paths if x.endswith(y)] for y in image_types]))
1275
  # could do below:
 
1429
  if load_embed(db) != (use_openai_embedding, hf_embedding_model):
1430
  print("Detected new embedding, updating db: %s" % langchain_mode, flush=True)
1431
  # handle embedding changes
1432
+ db_get = get_documents(db)
1433
  sources = [Document(page_content=result[0], metadata=result[1] or {})
1434
  for result in zip(db_get['documents'], db_get['metadatas'])]
1435
  # delete index, has to be redone
 
1480
  if changed_db:
1481
  db = db_trial
1482
  # only call persist if really changed db, else takes too long for large db
1483
+ if db is not None:
1484
+ db.persist()
1485
+ clear_embedding(db)
1486
  save_embed(db, use_openai_embedding, hf_embedding_model)
1487
  return db
1488
  return None
1489
 
1490
 
1491
  def clear_embedding(db):
1492
+ if db is None:
1493
+ return
1494
  # don't keep on GPU, wastes memory, push back onto CPU and only put back on GPU once again embed
1495
  db._embedding_function.client.cpu()
1496
  clear_torch_cache()
 
1512
 
1513
 
1514
  def save_embed(db, use_openai_embedding, hf_embedding_model):
1515
+ if db is not None:
1516
+ embed_info_file = os.path.join(db._persist_directory, 'embed_info')
1517
+ with open(embed_info_file, 'wb') as f:
1518
+ pickle.dump((use_openai_embedding, hf_embedding_model), f)
1519
  return use_openai_embedding, hf_embedding_model
1520
 
1521
 
 
1662
  return db, len(new_sources_metadata), new_sources_metadata
1663
 
1664
 
1665
+ def get_metadatas(db):
1666
+ from langchain.vectorstores import FAISS
1667
+ if isinstance(db, FAISS):
1668
+ metadatas = [v.metadata for k, v in db.docstore._dict.items()]
1669
+ elif isinstance(db, Chroma):
1670
+ metadatas = get_documents(db)['metadatas']
1671
+ else:
1672
+ # FIXME: Hack due to https://github.com/weaviate/weaviate/issues/1947
1673
+ # seems no way to get all metadata, so need to avoid this approach for weaviate
1674
+ metadatas = [x.metadata for x in db.similarity_search("", k=10000)]
1675
+ return metadatas
1676
+
1677
+
1678
+ def get_documents(db):
1679
+ if hasattr(db, '_persist_directory'):
1680
+ name_path = os.path.basename(db._persist_directory)
1681
+ base_path = 'locks'
1682
+ makedirs(base_path)
1683
+ with filelock.FileLock(os.path.join(base_path, "getdb_%s.lock" % name_path)):
1684
+ # get segfaults and other errors when multiple threads access this
1685
+ return _get_documents(db)
1686
+ else:
1687
+ return _get_documents(db)
1688
+
1689
+
1690
+ def _get_documents(db):
1691
+ from langchain.vectorstores import FAISS
1692
+ if isinstance(db, FAISS):
1693
+ documents = [v for k, v in db.docstore._dict.items()]
1694
+ elif isinstance(db, Chroma):
1695
+ documents = db.get()
1696
+ else:
1697
+ # FIXME: Hack due to https://github.com/weaviate/weaviate/issues/1947
1698
+ # seems no way to get all metadata, so need to avoid this approach for weaviate
1699
+ documents = [x for x in db.similarity_search("", k=10000)]
1700
+ return documents
1701
+
1702
+
1703
+ def get_docs_and_meta(db, top_k_docs, filter_kwargs={}):
1704
+ if hasattr(db, '_persist_directory'):
1705
+ name_path = os.path.basename(db._persist_directory)
1706
+ base_path = 'locks'
1707
+ makedirs(base_path)
1708
+ with filelock.FileLock(os.path.join(base_path, "getdb_%s.lock" % name_path)):
1709
+ return _get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs)
1710
+ else:
1711
+ return _get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs)
1712
+
1713
+
1714
+ def _get_docs_and_meta(db, top_k_docs, filter_kwargs={}):
1715
+ from langchain.vectorstores import FAISS
1716
+ if isinstance(db, Chroma):
1717
+ db_get = db._collection.get(where=filter_kwargs.get('filter'))
1718
+ db_metadatas = db_get['metadatas']
1719
+ db_documents = db_get['documents']
1720
+ elif isinstance(db, FAISS):
1721
+ import itertools
1722
+ db_metadatas = get_metadatas(db)
1723
+ # FIXME: FAISS has no filter
1724
+ # slice dict first
1725
+ db_documents = list(dict(itertools.islice(db.docstore._dict.items(), top_k_docs)).values())
1726
+ else:
1727
+ db_metadatas = get_metadatas(db)
1728
+ db_documents = get_documents(db)
1729
+ return db_documents, db_metadatas
1730
+
1731
+
1732
  def get_existing_files(db):
1733
+ metadatas = get_metadatas(db)
1734
+ metadata_sources = set([x['source'] for x in metadatas])
1735
  return metadata_sources
1736
 
1737
 
1738
  def get_existing_hash_ids(db):
1739
+ metadatas = get_metadatas(db)
1740
  # assume consistency, that any prior hashed source was single hashed file at the time among all source chunks
1741
+ metadata_hash_ids = {x['source']: x.get('hashid') for x in metadatas}
1742
  return metadata_hash_ids
1743
 
1744
 
1745
  def run_qa_db(**kwargs):
1746
  func_names = list(inspect.signature(_run_qa_db).parameters)
1747
  # hard-coded defaults
1748
  kwargs['answer_with_sources'] = True
 
1749
  kwargs['show_rank'] = False
1750
  missing_kwargs = [x for x in func_names if x not in kwargs]
1751
  assert not missing_kwargs, "Missing kwargs: %s" % missing_kwargs
 
1763
  user_path=None,
1764
  detect_user_path_changes_every_query=False,
1765
  db_type='faiss',
1766
+ model_name=None, model=None, tokenizer=None, inference_server=None,
1767
  hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
1768
  stream_output=False,
1769
  prompter=None,
 
1771
  prompt_dict=None,
1772
  answer_with_sources=True,
1773
  cut_distanct=1.1,
1774
+ sanitize_bot_response=False,
1775
  show_rank=False,
1776
  load_db_if_exists=False,
1777
  db=None,
 
1790
  document_choice=[DocumentChoices.All_Relevant.name],
1791
  n_jobs=-1,
1792
  verbose=False,
1793
+ cli=False,
1794
+ reverse_docs=True,
1795
+ lora_weights='',
1796
+ auto_reduce_chunks=True,
1797
+ max_chunks=100,
1798
+ ):
1799
  """
1800
 
1801
  :param query:
 
1814
  :param answer_with_sources
1815
  :return:
1816
  """
1817
+ if model is not None:
1818
+ assert model_name is not None # require so can make decisions
1819
  assert query is not None
1820
  assert prompter is not None or prompt_type is not None or model is None # if model is None, then will generate
1821
  if prompter is not None:
 
1829
  prompt_dict = ''
1830
  assert len(set(gen_hyper).difference(inspect.signature(get_llm).parameters)) == 0
1831
  llm, model_name, streamer, prompt_type_out = get_llm(use_openai_model=use_openai_model, model_name=model_name,
1832
+ model=model,
1833
+ tokenizer=tokenizer,
1834
+ inference_server=inference_server,
1835
  stream_output=stream_output,
1836
  do_sample=do_sample,
1837
  temperature=temperature,
 
1847
  prompt_type=prompt_type,
1848
  prompt_dict=prompt_dict,
1849
  prompter=prompter,
1850
+ sanitize_bot_response=sanitize_bot_response,
1851
  verbose=verbose,
1852
  )
1853
 
1854
  use_context = False
1855
  scores = []
1856
  chain = None
 
1878
  # can only return if HF type
1879
  return
1880
 
1881
+ # context stuff similar to used in evaluate()
1882
+ import torch
1883
+ device, torch_dtype, context_class = get_device_dtype()
1884
+ with torch.no_grad():
1885
+ have_lora_weights = lora_weights not in [no_lora_str, '', None]
1886
+ context_class_cast = NullContext if device == 'cpu' or have_lora_weights else torch.autocast
1887
+ with context_class_cast(device):
1888
+ if stream_output and streamer:
1889
+ answer = None
1890
+ import queue
1891
+ bucket = queue.Queue()
1892
+ thread = EThread(target=chain, streamer=streamer, bucket=bucket)
1893
+ thread.start()
1894
+ outputs = ""
1895
+ prompt = None # FIXME
1896
+ try:
1897
+ for new_text in streamer:
1898
+ # print("new_text: %s" % new_text, flush=True)
1899
+ if bucket.qsize() > 0 or thread.exc:
1900
+ thread.join()
1901
+ outputs += new_text
1902
+ if prompter: # and False: # FIXME: pipeline can already use prompter
1903
+ output1 = prompter.get_response(outputs, prompt=prompt,
1904
+ sanitize_bot_response=sanitize_bot_response)
1905
+ yield output1, ''
1906
+ else:
1907
+ yield outputs, ''
1908
+ except BaseException:
1909
+ # if any exception, raise that exception if was from thread, first
1910
+ if thread.exc:
1911
+ raise thread.exc
1912
+ raise
1913
+ finally:
1914
+ # in case no exception and didn't join with thread yet, then join
1915
+ if not thread.exc:
1916
+ answer = thread.join()
1917
+ # in case raise StopIteration or broke queue loop in streamer, but still have exception
1918
+ if thread.exc:
1919
+ raise thread.exc
1920
+ # FIXME: answer is not string outputs from streamer. How to get actual final output?
1921
+ # answer = outputs
1922
+ else:
1923
+ answer = chain()
1924
 
1925
  if not use_context:
1926
  ret = answer['output_text']
 
1939
  detect_user_path_changes_every_query=False,
1940
  db_type='faiss',
1941
  model_name=None,
1942
+ inference_server='',
1943
  hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
1944
  prompt_type=None,
1945
  prompt_dict=None,
 
1951
  n_jobs=-1,
1952
  # beyond run_db_query:
1953
  llm=None,
1954
+ tokenizer=None,
1955
  verbose=False,
1956
  cmd=None,
1957
+ reverse_docs=True,
1958
+
1959
+ # local
1960
+ auto_reduce_chunks=True,
1961
+ max_chunks=100,
1962
  ):
1963
  # determine whether use of context out of docs is planned
1964
  if not use_openai_model and prompt_type not in ['plain'] or model_name in non_hf_types:
 
1973
  # FIXME: Seems no way to get size of chroma db to limit top_k_docs to avoid
1974
  # Chroma collection MyData contains fewer than 4 elements.
1975
  # type logger error
1976
+ if top_k_docs == -1:
1977
+ k_db = 1000 if db_type == 'chroma' else 100
1978
+ else:
1979
+ # top_k_docs=100 works ok too
1980
+ k_db = 1000 if db_type == 'chroma' else top_k_docs
1981
 
1982
  # FIXME: For All just go over all dbs instead of a separate db for All
1983
  if not detect_user_path_changes_every_query and db is not None:
 
1998
  n_jobs=n_jobs,
1999
  verbose=verbose)
2000
 
2001
+ if 'falcon' in model_name:
2002
+ extra = "According to only the information in the document sources provided within the context above, "
2003
+ prefix = "Pay attention and remember information below, which will help to answer the question or imperative after the context ends."
2004
+ elif inference_server in ['openai', 'openai_chat']:
2005
+ extra = "According to (primarily) the information in the document sources provided within context above, "
2006
+ prefix = "Pay attention and remember information below, which will help to answer the question or imperative after the context ends. If the answer cannot be primarily obtained from information within the context, then respond that the answer does not appear in the context of the documents."
2007
+ else:
2008
+ extra = ""
2009
+ prefix = ""
2010
+ if langchain_mode in ['Disabled', 'ChatLLM', 'LLM'] or not use_context:
2011
+ template_if_no_docs = template = """%s{context}{question}""" % prefix
2012
+ else:
2013
+ template = """%s
2014
+ \"\"\"
2015
+ {context}
2016
+ \"\"\"
2017
+ %s{question}""" % (prefix, extra)
2018
+ template_if_no_docs = """%s{context}%s{question}""" % (prefix, extra)
2019
+ if not use_openai_model and prompt_type not in ['plain'] or model_name in non_hf_types:
2020
+ use_template = True
2021
+ else:
2022
+ use_template = False
2023
+
2024
  if db and use_context:
2025
  if not isinstance(db, Chroma):
2026
  # only chroma supports filtering
 
2041
  docs = []
2042
  scores = []
2043
  elif cmd == DocumentChoices.Only_All_Sources.name:
2044
+ db_documents, db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs)
 
 
 
2045
  # similar to langchain's chroma's _results_to_docs_and_scores
2046
  docs_with_score = [(Document(page_content=result[0], metadata=result[1] or {}), 0)
2047
+ for result in zip(db_documents, db_metadatas)][:top_k_docs]
2048
  docs = [x[0] for x in docs_with_score]
2049
  scores = [x[1] for x in docs_with_score]
2050
  else:
2051
+ if top_k_docs == -1 or auto_reduce_chunks:
2052
+ # docs_with_score = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)[:top_k_docs]
2053
+ top_k_docs_tokenize = 100
2054
+ base_path = 'locks'
2055
+ makedirs(base_path)
2056
+ if hasattr(db, '_persist_directory'):
2057
+ name_path = "sim_%s.lock" % os.path.basename(db._persist_directory)
2058
+ else:
2059
+ name_path = "sim.lock"
2060
+ with filelock.FileLock(os.path.join(base_path, name_path)):
2061
+ docs_with_score = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)[
2062
+ :top_k_docs_tokenize]
2063
+ if hasattr(llm, 'pipeline') and hasattr(llm.pipeline, 'tokenizer'):
2064
+ # more accurate
2065
+ tokens = [len(llm.pipeline.tokenizer(x[0].page_content)['input_ids']) for x in docs_with_score]
2066
+ template_tokens = len(llm.pipeline.tokenizer(template)['input_ids'])
2067
+ elif inference_server in ['openai', 'openai_chat'] or use_openai_model or db_type in ['faiss',
2068
+ 'weaviate']:
2069
+ # use tiktoken for faiss since embedding is called differently
2070
+ tokens = [llm.get_num_tokens(x[0].page_content) for x in docs_with_score]
2071
+ template_tokens = llm.get_num_tokens(template)
2072
+ elif isinstance(tokenizer, FakeTokenizer):
2073
+ tokens = [tokenizer.num_tokens_from_string(x[0].page_content) for x in docs_with_score]
2074
+ template_tokens = tokenizer.num_tokens_from_string(template)
2075
+ else:
2076
+ # in case model is not our pipeline with HF tokenizer
2077
+ tokens = [db._embedding_function.client.tokenize([x[0].page_content])['input_ids'].shape[1] for x in
2078
+ docs_with_score]
2079
+ template_tokens = db._embedding_function.client.tokenize([template])['input_ids'].shape[1]
2080
+ tokens_cumsum = np.cumsum(tokens)
2081
+ if hasattr(llm, 'pipeline') and hasattr(llm.pipeline, 'max_input_tokens'):
2082
+ max_input_tokens = llm.pipeline.max_input_tokens
2083
+ elif inference_server in ['openai']:
2084
+ max_tokens = llm.modelname_to_contextsize(model_name)
2085
+ # leave some room for 1 paragraph, even if min_new_tokens=0
2086
+ max_input_tokens = max_tokens - 256
2087
+ elif inference_server in ['openai_chat']:
2088
+ max_tokens = model_token_mapping[model_name]
2089
+ # leave some room for 1 paragraph, even if min_new_tokens=0
2090
+ max_input_tokens = max_tokens - 256
2091
+ elif isinstance(tokenizer, FakeTokenizer):
2092
+ max_input_tokens = tokenizer.model_max_length - 256
2093
+ else:
2094
+ # leave some room for 1 paragraph, even if min_new_tokens=0
2095
+ max_input_tokens = 2048 - 256
2096
+ max_input_tokens -= template_tokens
2097
+ # FIXME: Doesn't account for query, == context, or new lines between contexts
2098
+ where_res = np.where(tokens_cumsum < max_input_tokens)[0]
2099
+ if where_res.shape[0] == 0:
2100
+ # then no chunk can fit, still do first one
2101
+ top_k_docs_trial = 1
2102
+ else:
2103
+ top_k_docs_trial = 1 + where_res[-1]
2104
+ if 0 < top_k_docs_trial < max_chunks:
2105
+ # avoid craziness
2106
+ if top_k_docs == -1:
2107
+ top_k_docs = top_k_docs_trial
2108
+ else:
2109
+ top_k_docs = min(top_k_docs, top_k_docs_trial)
2110
+ if top_k_docs == -1:
2111
+ # if here, means 0 and just do best with 1 doc
2112
+ print("Unexpected large chunks and can't add to context, will add 1 anyways", flush=True)
2113
+ top_k_docs = 1
2114
+ docs_with_score = docs_with_score[:top_k_docs]
2115
+ else:
2116
+ docs_with_score = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)[:top_k_docs]
2117
+ # put most relevant chunks closest to question,
2118
+ # esp. if truncation occurs, it will be the "oldest" / "farthest from response" text that gets truncated
2119
+ # BUT: for small models, e.g. 6_9 pythia, if it sees some h2oGPT-related text first, it can latch onto that and ignore the rest
2120
+ if reverse_docs:
2121
+ docs_with_score.reverse()
2122
  # cut off so no high distance docs/sources considered
2123
  docs = [x[0] for x in docs_with_score if x[1] < cut_distanct]
2124
  scores = [x[1] for x in docs_with_score if x[1] < cut_distanct]
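The chunk-selection logic above packs retrieved chunks into the model's input budget by cumulatively summing per-chunk token counts (np.cumsum) and keeping chunks only while the running total stays under max_input_tokens, falling back to a single chunk if nothing fits. A minimal sketch of that selection rule, with made-up token counts and budget used purely for illustration:

import numpy as np

def select_docs_by_token_budget(docs_with_score, tokens_per_doc, max_input_tokens):
    # cumulative token count of the candidate chunks, in retrieval order
    tokens_cumsum = np.cumsum(tokens_per_doc)
    # indices whose running total still fits the budget
    where_res = np.where(tokens_cumsum < max_input_tokens)[0]
    # keep at least one chunk even if nothing fits
    top_k_docs = 1 if where_res.shape[0] == 0 else 1 + where_res[-1]
    return docs_with_score[:top_k_docs]

# illustrative values only
docs = [("chunk_a", 0.2), ("chunk_b", 0.3), ("chunk_c", 0.5)]
print(select_docs_by_token_budget(docs, [300, 400, 900], max_input_tokens=768))  # keeps the first two chunks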
 
2153
  if len(docs) == 0:
2154
  # avoid context == in prompt then
2155
  use_context = False
2156
+ template = template_if_no_docs
2157
 
2158
+ if use_template:
2159
  # instruct-like, rather than few-shot prompt_type='plain' as default
2160
  # but then sources confuse the model with how inserted among rest of text, so avoid
 
2161
  prompt = PromptTemplate(
2162
  # input_variables=["summaries", "question"],
2163
  input_variables=["context", "question"],
 
2217
  return ret, extra
2218
 
2219
 
2220
+ def clean_doc(docs1):
2221
+ if not isinstance(docs1, (list, tuple, types.GeneratorType)):
2222
+ docs1 = [docs1]
2223
+ for doci, doc in enumerate(docs1):
2224
+ docs1[doci].page_content = '\n'.join([x.strip() for x in doc.page_content.split("\n") if x.strip()])
2225
+ return docs1
2226
+
2227
+
2228
+ def chunk_sources(sources, chunk=True, chunk_size=512, language=None):
2229
  if not chunk:
2230
  return sources
2231
+ if not isinstance(sources, (list, tuple, types.GeneratorType)) and not callable(sources):
2232
+ # if just one document
2233
+ sources = [sources]
2234
+ if language and False:
2235
+ # Bug in langchain, keep separator=True not working
2236
+ # https://github.com/hwchase17/langchain/issues/2836
2237
+ # so avoid this for now
2238
+ keep_separator = True
2239
+ separators = RecursiveCharacterTextSplitter.get_separators_for_language(language)
2240
+ else:
2241
+ separators = ["\n\n", "\n", " ", ""]
2242
+ keep_separator = False
2243
+ splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=0, keep_separator=keep_separator,
2244
+ separators=separators)
2245
+ source_chunks = splitter.split_documents(sources)
2246
  return source_chunks
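The loader branches earlier in this file repeatedly pair these two helpers: clean_doc to drop blank lines, then chunk_sources with an optional language hint for the splitter. A minimal usage sketch, assuming it runs inside this module (so Document, Language, clean_doc and chunk_sources are all in scope) and using a made-up markdown snippet:

docs1 = [Document(page_content="# Title\n\n\nSome body text\n", metadata={"source": "example.md"})]
docs1 = clean_doc(docs1)  # strips the empty lines
chunks = chunk_sources(docs1, chunk=True, chunk_size=512, language=Language.MARKDOWN)
print(len(chunks), chunks[0].page_content)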
2247
 
2248
 
 
2290
  WEAVIATE_SCOPE = os.getenv('WEAVIATE_SCOPE', "offline_access")
2291
 
2292
  resource_owner_config = None
 
2293
  try:
2294
+ import weaviate
2295
+ if WEAVIATE_USERNAME is not None and WEAVIATE_PASSWORD is not None:
2296
+ resource_owner_config = weaviate.AuthClientPassword(
2297
+ username=WEAVIATE_USERNAME,
2298
+ password=WEAVIATE_PASSWORD,
2299
+ scope=WEAVIATE_SCOPE
2300
+ )
2301
+
2302
  client = weaviate.Client(WEAVIATE_URL, auth_client_secret=resource_owner_config)
2303
+ return client
2304
  except Exception as e:
2305
  print(f"Failed to create Weaviate client: {e}")
2306
  return None
gradio_runner.py CHANGED
The diff for this file is too large to render. See raw diff
 
gradio_themes.py CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
3
  from typing import Iterable
4
 
5
  from gradio.themes.soft import Soft
6
- from gradio.themes import Color
7
  from gradio.themes.utils import colors, sizes, fonts
8
 
9
  h2o_yellow = Color(
@@ -36,6 +36,42 @@ h2o_gray = Color(
36
  )
37
 
38
 
39
  class H2oTheme(Soft):
40
  def __init__(
41
  self,
@@ -158,19 +194,23 @@ h2o_logo = '<svg id="Layer_1" data-name="Layer 1" xmlns="http://www.w3.org/2000/
158
  '11.93S497.56,252.55,497.56,246.06Zm2.52,21.47h20.68v71.31H500.08Z"/></svg>'
159
 
160
 
161
- def get_h2o_title(title):
162
- return f"""<div style="display:flex; justify-content:center; margin-bottom:30px;">
163
  <div style="height: 60px; width: 60px; margin-right:20px;">{h2o_logo}</div>
164
  <h1 style="line-height:60px">{title}</h1>
165
  </div>
166
  <div style="float:right; height: 80px; width: 80px; margin-top:-100px">
167
- <img src=https://raw.githubusercontent.com/h2oai/h2ogpt/main/docs/h2o-qr.png></img>
168
  </div>
169
  """
170
 
171
 
172
- def get_simple_title(title):
173
- return f"""<h1 align="center"> {title}</h1>"""
174
 
175
 
176
  def get_dark_js():
 
3
  from typing import Iterable
4
 
5
  from gradio.themes.soft import Soft
6
+ from gradio.themes import Color, Size
7
  from gradio.themes.utils import colors, sizes, fonts
8
 
9
  h2o_yellow = Color(
 
36
  )
37
 
38
 
39
+ text_xsm = Size(
40
+ name="text_xsm",
41
+ xxs="4px",
42
+ xs="5px",
43
+ sm="6px",
44
+ md="7px",
45
+ lg="8px",
46
+ xl="10px",
47
+ xxl="12px",
48
+ )
49
+
50
+
51
+ spacing_xsm = Size(
52
+ name="spacing_xsm",
53
+ xxs="1px",
54
+ xs="1px",
55
+ sm="1px",
56
+ md="2px",
57
+ lg="3px",
58
+ xl="5px",
59
+ xxl="7px",
60
+ )
61
+
62
+
63
+ radius_xsm = Size(
64
+ name="radius_xsm",
65
+ xxs="1px",
66
+ xs="1px",
67
+ sm="1px",
68
+ md="2px",
69
+ lg="3px",
70
+ xl="5px",
71
+ xxl="7px",
72
+ )
73
+
74
+
75
  class H2oTheme(Soft):
76
  def __init__(
77
  self,
 
194
  '11.93S497.56,252.55,497.56,246.06Zm2.52,21.47h20.68v71.31H500.08Z"/></svg>'
195
 
196
 
197
+ def get_h2o_title(title, description):
198
+ # NOTE: Check full width desktop, smallest width browser desktop, iPhone browsers to ensure no overlap etc.
199
+ return f"""<div style="float:left; justify-content:left; height: 80px; width: 195px; margin-top:0px">
200
+ {description}
201
+ </div>
202
+ <div style="display:flex; justify-content:center; margin-bottom:30px; margin-right:330px;">
203
  <div style="height: 60px; width: 60px; margin-right:20px;">{h2o_logo}</div>
204
  <h1 style="line-height:60px">{title}</h1>
205
  </div>
206
  <div style="float:right; height: 80px; width: 80px; margin-top:-100px">
207
+ <img src="https://raw.githubusercontent.com/h2oai/h2ogpt/main/docs/h2o-qr.png">
208
  </div>
209
  """
210
 
211
 
212
+ def get_simple_title(title, description):
213
+ return f"""{description}<h1 align="center"> {title}</h1>"""
214
 
215
 
216
  def get_dark_js():
gradio_utils/__pycache__/css.cpython-310.pyc ADDED
Binary file (1.53 kB). View file
 
gradio_utils/__pycache__/grclient.cpython-310.pyc ADDED
Binary file (2.69 kB). View file
 
gradio_utils/__pycache__/prompt_form.cpython-310.pyc ADDED
Binary file (3.59 kB). View file
 
gradio_utils/css.py ADDED
@@ -0,0 +1,53 @@
1
+ def get_css(kwargs) -> str:
2
+ if kwargs['h2ocolors']:
3
+ css_code = """footer {visibility: hidden;}
4
+ body{background:linear-gradient(#f5f5f5,#e5e5e5);}
5
+ body.dark{background:linear-gradient(#000000,#0d0d0d);}
6
+ """
7
+ else:
8
+ css_code = """footer {visibility: hidden}"""
9
+
10
+ css_code += make_css_base()
11
+ return css_code
12
+
13
+
14
+ def make_css_base() -> str:
15
+ return """
16
+ @import url('https://fonts.googleapis.com/css2?family=Source+Sans+Pro:wght@400;600&display=swap');
17
+
18
+ body.dark{#warning {background-color: #555555};}
19
+
20
+ #small_btn {
21
+ margin: 0.6em 0em 0.55em 0;
22
+ max-width: 20em;
23
+ min-width: 5em !important;
24
+ height: 5em;
25
+ font-size: 14px !important;
26
+ }
27
+
28
+ #prompt-form {
29
+ border: 1px solid var(--primary-500) !important;
30
+ }
31
+
32
+ #prompt-form.block {
33
+ border-radius: var(--block-radius) !important;
34
+ }
35
+
36
+ #prompt-form textarea {
37
+ border: 1px solid rgb(209, 213, 219);
38
+ }
39
+
40
+ #prompt-form label > div {
41
+ margin-top: 4px;
42
+ }
43
+
44
+ button.primary:hover {
45
+ background-color: var(--primary-600) !important;
46
+ transition: .2s;
47
+ }
48
+
49
+ #prompt-form-area {
50
+ margin-bottom: 2.5rem;
51
+ }
52
+ .chatsmall chatbot {font-size: 10px !important}
53
+ """
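A minimal sketch of wiring this helper into a gradio app; gr.Blocks accepts a css string, though the kwargs dict below is a made-up subset of the app's real settings:

import gradio as gr
from gradio_utils.css import get_css

kwargs = {'h2ocolors': True}  # hypothetical subset of the real kwargs
with gr.Blocks(css=get_css(kwargs)) as demo:
    gr.Markdown("styled app")
demo.launch()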
gradio_utils/grclient.py ADDED
@@ -0,0 +1,82 @@
1
+ import traceback
2
+ from typing import Callable
3
+ import os
4
+
5
+ from gradio_client.client import Job
6
+
7
+ os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
8
+
9
+ from gradio_client import Client
10
+
11
+
12
+ class GradioClient(Client):
13
+ """
14
+ Wrapper around the gradio Client
15
+ Automatically refreshes the client if it detects that the gradio server has changed
16
+ """
17
+
18
+ def __init__(self, *args, **kwargs):
19
+ self.args = args
20
+ self.kwargs = kwargs
21
+ super().__init__(*args, **kwargs)
22
+ self.server_hash = self.get_server_hash()
23
+
24
+ def get_server_hash(self):
25
+ """
26
+ Get server hash using super without any refresh action triggered
27
+ Returns: git hash of gradio server
28
+ """
29
+ return super().submit(api_name='/system_hash').result()
30
+
31
+ def refresh_client_if_should(self):
32
+ # get current hash in order to update api_name -> fn_index map in case gradio server changed
33
+ # FIXME: Could add cli api as hash
34
+ server_hash = self.get_server_hash()
35
+ if self.server_hash != server_hash:
36
+ self.refresh_client()
37
+ self.server_hash = server_hash
38
+ else:
39
+ self.reset_session()
40
+
41
+ def refresh_client(self):
42
+ """
43
+ Ensure every client call is independent
44
+ Also ensure map between api_name and fn_index is updated in case server changed (e.g. restarted with new code)
45
+ Returns:
46
+ """
47
+ # need session hash to be new every time, to avoid "generator already executing"
48
+ self.reset_session()
49
+
50
+ client = Client(*self.args, **self.kwargs)
51
+ for k, v in client.__dict__.items():
52
+ setattr(self, k, v)
53
+
54
+ def submit(
55
+ self,
56
+ *args,
57
+ api_name: str | None = None,
58
+ fn_index: int | None = None,
59
+ result_callbacks: Callable | list[Callable] | None = None,
60
+ ) -> Job:
61
+ # Note predict calls submit
62
+ try:
63
+ self.refresh_client_if_should()
64
+ job = super().submit(*args, api_name=api_name, fn_index=fn_index)
65
+ except Exception as e:
66
+ print("Hit e=%s" % str(e), flush=True)
67
+ # force reconfig in case only that
68
+ self.refresh_client()
69
+ job = super().submit(*args, api_name=api_name, fn_index=fn_index)
70
+
71
+ # see if immediately failed
72
+ e = job.future._exception
73
+ if e is not None:
74
+ print("GR job failed: %s %s" % (str(e), ''.join(traceback.format_tb(e.__traceback__))), flush=True)
75
+ # force reconfig in case only that
76
+ self.refresh_client()
77
+ job = super().submit(*args, api_name=api_name, fn_index=fn_index)
78
+ e2 = job.future._exception
79
+ if e2 is not None:
80
+ print("GR job failed again: %s\n%s" % (str(e2), ''.join(traceback.format_tb(e2.__traceback__))), flush=True)
81
+
82
+ return job
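A minimal usage sketch of the auto-refreshing client; the server URL and api_name below are placeholders, not endpoints guaranteed by any particular server:

from gradio_utils.grclient import GradioClient

client = GradioClient("http://localhost:7860")  # hypothetical gradio server
# submit() re-syncs the api_name -> fn_index map if the server was restarted with new code
job = client.submit("hello", api_name="/submit")  # hypothetical endpoint
print(job.result())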
gradio_utils/prompt_form.py ADDED
@@ -0,0 +1,118 @@
1
+ import os
2
+ import math
3
+
4
+ import gradio as gr
5
+
6
+
7
+ def make_chatbots(output_label0, output_label0_model2, **kwargs):
8
+ text_outputs = []
9
+ chat_kwargs = []
10
+ for model_state_lock in kwargs['model_states']:
11
+ if os.environ.get('DEBUG_MODEL_LOCK'):
12
+ model_name = model_state_lock["base_model"] + " : " + model_state_lock["inference_server"]
13
+ else:
14
+ model_name = model_state_lock["base_model"]
15
+ output_label = f'h2oGPT [{model_name}]'
16
+ min_width = 250 if kwargs['gradio_size'] in ['small', 'large', 'medium'] else 160
17
+ chat_kwargs.append(dict(label=output_label, visible=kwargs['model_lock'], elem_classes='chatsmall',
18
+ height=kwargs['height'] or 400, min_width=min_width))
19
+
20
+ if kwargs['model_lock_columns'] == -1:
21
+ kwargs['model_lock_columns'] = len(kwargs['model_states'])
22
+ if kwargs['model_lock_columns'] is None:
23
+ kwargs['model_lock_columns'] = 3
24
+
25
+ ncols = kwargs['model_lock_columns']
26
+ if len(kwargs['model_states']) == 0:
27
+ nrows = 0
28
+ else:
29
+ nrows = math.ceil(len(kwargs['model_states']) / kwargs['model_lock_columns'])
30
+
31
+ if kwargs['model_lock_columns'] == 0:
32
+ # not using model_lock
33
+ pass
34
+ elif nrows <= 1:
35
+ with gr.Row():
36
+ for chat_kwargs1, model_state_lock in zip(chat_kwargs, kwargs['model_states']):
37
+ text_outputs.append(gr.Chatbot(**chat_kwargs1))
38
+ elif nrows == kwargs['model_states']:
39
+ with gr.Row():
40
+ for chat_kwargs1, model_state_lock in zip(chat_kwargs, kwargs['model_states']):
41
+ text_outputs.append(gr.Chatbot(**chat_kwargs1))
42
+ elif nrows == 2:
43
+ with gr.Row():
44
+ for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
45
+ if mii >= len(kwargs['model_states']) / 2:
46
+ continue
47
+ text_outputs.append(gr.Chatbot(**chat_kwargs1))
48
+ with gr.Row():
49
+ for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
50
+ if mii < len(kwargs['model_states']) / 2:
51
+ continue
52
+ text_outputs.append(gr.Chatbot(**chat_kwargs1))
53
+ elif nrows == 3:
54
+ with gr.Row():
55
+ for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
56
+ if mii >= 1 * len(kwargs['model_states']) / 3:
57
+ continue
58
+ text_outputs.append(gr.Chatbot(**chat_kwargs1))
59
+ with gr.Row():
60
+ for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
61
+ if mii < 1 * len(kwargs['model_states']) / 3 or mii >= 2 * len(kwargs['model_states']) / 3:
62
+ continue
63
+ text_outputs.append(gr.Chatbot(**chat_kwargs1))
64
+ with gr.Row():
65
+ for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
66
+ if mii < 2 * len(kwargs['model_states']) / 3:
67
+ continue
68
+ text_outputs.append(gr.Chatbot(**chat_kwargs1))
69
+ elif nrows >= 4:
70
+ with gr.Row():
71
+ for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
72
+ if mii >= 1 * len(kwargs['model_states']) / 4:
73
+ continue
74
+ text_outputs.append(gr.Chatbot(**chat_kwargs1))
75
+ with gr.Row():
76
+ for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
77
+ if mii < 1 * len(kwargs['model_states']) / 4 or mii >= 2 * len(kwargs['model_states']) / 4:
78
+ continue
79
+ text_outputs.append(gr.Chatbot(**chat_kwargs1))
80
+ with gr.Row():
81
+ for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
82
+ if mii < 2 * len(kwargs['model_states']) / 4 or mii >= 3 * len(kwargs['model_states']) / 4:
83
+ continue
84
+ text_outputs.append(gr.Chatbot(**chat_kwargs1))
85
+ with gr.Row():
86
+ for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
87
+ if mii < 3 * len(kwargs['model_states']) / 4:
88
+ continue
89
+ text_outputs.append(gr.Chatbot(**chat_kwargs1))
90
+
91
+ with gr.Row():
92
+ text_output = gr.Chatbot(label=output_label0, visible=not kwargs['model_lock'], height=kwargs['height'] or 400)
93
+ text_output2 = gr.Chatbot(label=output_label0_model2,
94
+ visible=False and not kwargs['model_lock'], height=kwargs['height'] or 400)
95
+ return text_output, text_output2, text_outputs
96
+
97
+
98
+ def make_prompt_form(kwargs):
99
+ if kwargs['input_lines'] > 1:
100
+ instruction_label = "Shift-Enter to Submit, Enter for more lines"
101
+ else:
102
+ instruction_label = "Enter to Submit, Shift-Enter for more lines"
103
+
104
+ with gr.Row():  # elem_id='prompt-form-area'
105
+ with gr.Column(scale=50):
106
+ instruction = gr.Textbox(
107
+ lines=kwargs['input_lines'],
108
+ label='Ask anything',
109
+ placeholder=instruction_label,
110
+ info=None,
111
+ elem_id='prompt-form',
112
+ container=True,
113
+ )
114
+ with gr.Row():
115
+ submit = gr.Button(value='Submit', variant='primary', scale=0, size='sm')
116
+ stop_btn = gr.Button(value="Stop", variant='secondary', scale=0, size='sm')
117
+
118
+ return instruction, submit, stop_btn
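The nrows branches in make_chatbots above all apply the same partitioning rule: chatbot index mii goes to row r when r*n/nrows <= mii < (r+1)*n/nrows. A condensed sketch of that rule (pure Python, gradio omitted), with made-up inputs:

import math

def split_into_rows(items, ncols):
    nrows = math.ceil(len(items) / ncols) if items else 0
    n = len(items)
    # row r holds items with r*n/nrows <= index < (r+1)*n/nrows, mirroring the mii checks above
    return [[x for mii, x in enumerate(items) if r * n / nrows <= mii < (r + 1) * n / nrows]
            for r in range(nrows)]

print(split_into_rows(['bot%d' % i for i in range(7)], ncols=3))
# [['bot0', 'bot1', 'bot2'], ['bot3', 'bot4'], ['bot5', 'bot6']]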
h2oai_pipeline.py CHANGED
@@ -9,7 +9,7 @@ from prompter import Prompter, PromptType
9
 
10
  class H2OTextGenerationPipeline(TextGenerationPipeline):
11
  def __init__(self, *args, debug=False, chat=False, stream_output=False,
12
- sanitize_bot_response=True,
13
  use_prompter=True, prompter=None,
14
  prompt_type=None, prompt_dict=None,
15
  max_input_tokens=2048 - 256, **kwargs):
@@ -51,25 +51,37 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
51
  self.sanitize_bot_response = sanitize_bot_response
52
  self.max_input_tokens = max_input_tokens # not for generate, so ok that not kwargs
53
 
54
- def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs):
55
- if hasattr(self.tokenizer, 'model_max_length'):
 
 
 
56
  # model_max_length only defined for generate.py, not raw use of h2oai_pipeline.py
57
- model_max_length = self.tokenizer.model_max_length
 
 
 
 
 
 
 
 
 
58
  else:
59
  # unknown
60
  model_max_length = None
61
 
62
- verbose = bool(int(os.getenv('VERBOSE_PIPELINE', '0')))
63
  if model_max_length is not None:
64
- num_prompt_tokens = None
65
  # can't wait for "hole" if not plain prompt_type, since would lose prefix like <human>:
66
  # For https://github.com/h2oai/h2ogpt/issues/192
67
  for trial in range(0, 3):
68
- prompt_tokens = self.tokenizer(prompt_text)['input_ids']
69
  num_prompt_tokens = len(prompt_tokens)
70
  if num_prompt_tokens > model_max_length:
71
  # conservative by using int()
72
  chars_per_token = int(len(prompt_text) / num_prompt_tokens)
 
73
  prompt_text = prompt_text[-model_max_length * chars_per_token:]
74
  if verbose:
75
  print("reducing %s tokens, assuming average of %s chars/token for %s characters" % (
@@ -79,20 +91,27 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
79
  print("using %s tokens with %s chars" % (num_prompt_tokens, len(prompt_text)), flush=True)
80
  break
81
 
82
- # if input prompt is some number of tokens, despite user request, can't have max_new_tokens more
83
- #
84
- if self.prompt_type not in [PromptType.plain.name, PromptType.plain.value]:
85
- # then give room for prompt
86
- fudge = 20
87
- else:
88
- fudge = 0
89
- assert num_prompt_tokens is not None
90
- max_new_tokens = max(0, min(generate_kwargs['max_new_tokens'],
91
- model_max_length - (num_prompt_tokens + fudge)))
92
- if max_new_tokens < generate_kwargs['max_new_tokens']:
93
- if verbose:
94
- print("Reduced max_new_tokens from %s -> %s" % (generate_kwargs['max_new_tokens'], max_new_tokens))
95
- generate_kwargs['max_new_tokens'] = max_new_tokens
 
 
 
 
 
 
 
96
 
97
  data_point = dict(context='', instruction=prompt_text, input='')
98
  if self.prompter is not None:
@@ -100,7 +119,7 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
100
  self.prompt_text = prompt_text
101
  if handle_long_generation is None:
102
  # forces truncation of inputs to avoid critical failure
103
- handle_long_generation = 'hole'
104
  return super().preprocess(prompt_text, prefix=prefix, handle_long_generation=handle_long_generation,
105
  **generate_kwargs)
106
 
@@ -113,7 +132,7 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
113
  outputs = self.prompter.get_response(outputs, prompt=self.prompt_text,
114
  sanitize_bot_response=self.sanitize_bot_response)
115
  elif self.bot and self.human:
116
- outputs = rec['generated_text'].split(self.bot)[1].strip().split(self.human)[0].strip()
117
  else:
118
  outputs = rec['generated_text']
119
  rec['generated_text'] = outputs
@@ -123,7 +142,8 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
123
  if self.can_stop:
124
  stopping_criteria = get_stopping(self.prompt_type, self.prompt_dict,
125
  self.tokenizer, self.device,
126
- human=self.human, bot=self.bot)
 
127
  generate_kwargs['stopping_criteria'] = stopping_criteria
128
  # return super()._forward(model_inputs, **generate_kwargs)
129
  return self.__forward(model_inputs, **generate_kwargs)
 
9
 
10
  class H2OTextGenerationPipeline(TextGenerationPipeline):
11
  def __init__(self, *args, debug=False, chat=False, stream_output=False,
12
+ sanitize_bot_response=False,
13
  use_prompter=True, prompter=None,
14
  prompt_type=None, prompt_dict=None,
15
  max_input_tokens=2048 - 256, **kwargs):
 
51
  self.sanitize_bot_response = sanitize_bot_response
52
  self.max_input_tokens = max_input_tokens # not for generate, so ok that not kwargs
53
 
54
+ @staticmethod
55
+ def limit_prompt(prompt_text, tokenizer, max_prompt_length=None):
56
+ verbose = bool(int(os.getenv('VERBOSE_PIPELINE', '0')))
57
+
58
+ if hasattr(tokenizer, 'model_max_length'):
59
  # model_max_length only defined for generate.py, not raw use of h2oai_pipeline.py
60
+ model_max_length = tokenizer.model_max_length
61
+ if max_prompt_length is not None:
62
+ model_max_length = min(model_max_length, max_prompt_length)
63
+ # cut at some upper likely limit to avoid excessive tokenization etc
64
+ # upper bound of 10 chars/token, e.g. special chars sometimes are long
65
+ if len(prompt_text) > model_max_length * 10:
66
+ len0 = len(prompt_text)
67
+ prompt_text = prompt_text[-model_max_length * 10:]
68
+ if verbose:
69
+ print("Cut of input: %s -> %s" % (len0, len(prompt_text)), flush=True)
70
  else:
71
  # unknown
72
  model_max_length = None
73
 
74
+ num_prompt_tokens = None
75
  if model_max_length is not None:
 
76
  # can't wait for "hole" if not plain prompt_type, since would lose prefix like <human>:
77
  # For https://github.com/h2oai/h2ogpt/issues/192
78
  for trial in range(0, 3):
79
+ prompt_tokens = tokenizer(prompt_text)['input_ids']
80
  num_prompt_tokens = len(prompt_tokens)
81
  if num_prompt_tokens > model_max_length:
82
  # conservative by using int()
83
  chars_per_token = int(len(prompt_text) / num_prompt_tokens)
84
+ # keep tail, where question is if using langchain
85
  prompt_text = prompt_text[-model_max_length * chars_per_token:]
86
  if verbose:
87
  print("reducing %s tokens, assuming average of %s chars/token for %s characters" % (
 
91
  print("using %s tokens with %s chars" % (num_prompt_tokens, len(prompt_text)), flush=True)
92
  break
93
 
94
+ # Why Below False: don't limit max_new_tokens more, just rely upon stopping to reach limit of model
95
+ if False:
96
+ # if input prompt is some number of tokens, despite user request, can't have max_new_tokens more
97
+ #
98
+ assert num_prompt_tokens is not None
99
+ if self.prompt_type not in [PromptType.plain.name, PromptType.plain.value]:
100
+ # then give room for prompt
101
+ fudge = 20
102
+ else:
103
+ fudge = 0
104
+ max_new_tokens = max(0, min(generate_kwargs['max_new_tokens'],
105
+ model_max_length - (num_prompt_tokens + fudge)))
106
+ if max_new_tokens < generate_kwargs['max_new_tokens']:
107
+ if verbose:
108
+ print("Reduced max_new_tokens from %s -> %s" % (
109
+ generate_kwargs['max_new_tokens'], max_new_tokens))
110
+ generate_kwargs['max_new_tokens'] = max_new_tokens
111
+ return prompt_text, num_prompt_tokens
112
+
113
+ def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs):
114
+ prompt_text, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt_text, self.tokenizer)
115
 
116
  data_point = dict(context='', instruction=prompt_text, input='')
117
  if self.prompter is not None:
 
119
  self.prompt_text = prompt_text
120
  if handle_long_generation is None:
121
  # forces truncation of inputs to avoid critical failure
122
+ handle_long_generation = None # disable with new approaches
123
  return super().preprocess(prompt_text, prefix=prefix, handle_long_generation=handle_long_generation,
124
  **generate_kwargs)
125
 
 
132
  outputs = self.prompter.get_response(outputs, prompt=self.prompt_text,
133
  sanitize_bot_response=self.sanitize_bot_response)
134
  elif self.bot and self.human:
135
+ outputs = rec['generated_text'].split(self.bot)[1].split(self.human)[0]
136
  else:
137
  outputs = rec['generated_text']
138
  rec['generated_text'] = outputs
 
142
  if self.can_stop:
143
  stopping_criteria = get_stopping(self.prompt_type, self.prompt_dict,
144
  self.tokenizer, self.device,
145
+ human=self.human, bot=self.bot,
146
+ model_max_length=self.tokenizer.model_max_length)
147
  generate_kwargs['stopping_criteria'] = stopping_criteria
148
  # return super()._forward(model_inputs, **generate_kwargs)
149
  return self.__forward(model_inputs, **generate_kwargs)
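limit_prompt above trims over-long prompts by estimating characters-per-token from one tokenization pass and keeping only the tail (where the question sits when the prompt comes from langchain). A standalone sketch of that heuristic; the whitespace "tokenizer" is a stand-in purely for illustration:

def limit_prompt_sketch(prompt_text, tokenize, model_max_length, trials=3):
    # re-measure and cut the tail until the prompt fits (or trials run out)
    for _ in range(trials):
        num_tokens = len(tokenize(prompt_text))
        if num_tokens <= model_max_length:
            break
        chars_per_token = int(len(prompt_text) / num_tokens)  # conservative via int()
        prompt_text = prompt_text[-model_max_length * chars_per_token:]
    return prompt_text

text = "word " * 5000
print(len(limit_prompt_sketch(text, str.split, model_max_length=2048)))  # 10240 characters kept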
iterators/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .timeout_iterator import TimeoutIterator, AsyncTimeoutIterator
2
+ from .iterator_pipe import IteratorPipe, AsyncIteratorPipe
3
+
4
+ __all__ = ["TimeoutIterator", "AsyncTimeoutIterator", "IteratorPipe", "AsyncIteratorPipe"]
iterators/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (337 Bytes). View file
 
iterators/__pycache__/iterator_pipe.cpython-310.pyc ADDED
Binary file (2.71 kB). View file
 
iterators/__pycache__/timeout_iterator.cpython-310.pyc ADDED
Binary file (5.63 kB). View file
 
iterators/iterator_pipe.py ADDED
@@ -0,0 +1,93 @@
1
+ import queue
2
+ import asyncio
3
+
4
+
5
+ class IteratorPipe:
6
+ """
7
+ Iterator Pipe creates an iterator that can be fed in data from another block of code or thread of execution
8
+ """
9
+
10
+ def __init__(self, sentinel=object()):
11
+ self._q = queue.Queue()
12
+ self._sentinel = sentinel
13
+ self._sentinel_pushed = False
14
+ self._closed = False
15
+
16
+ def __iter__(self):
17
+ return self
18
+
19
+ def __next__(self):
20
+ if self._closed:
21
+ raise StopIteration
22
+
23
+ data = self._q.get(block=True)
24
+ if data is self._sentinel:
25
+ self._closed = True
26
+ raise StopIteration
27
+
28
+ return data
29
+
30
+ def put(self, data) -> bool:
31
+ """
32
+ Pushes next item to Iterator and returns True
33
+ If iterator has been closed via close(), doesn't push anything and returns False
34
+ """
35
+ if self._sentinel_pushed:
36
+ return False
37
+
38
+ self._q.put(data)
39
+ return True
40
+
41
+ def close(self):
42
+ """
43
+ Close is idempotent. Calling close multiple times is safe
44
+ Iterator will raise StopIteration only after all elements pushed before close have been iterated
45
+ """
46
+ # make close idempotent
47
+ if not self._sentinel_pushed:
48
+ self._sentinel_pushed = True
49
+ self._q.put(self._sentinel)
50
+
51
+
52
+ class AsyncIteratorPipe:
53
+
54
+ def __init__(self, sentinel=object()):
55
+ self._q = asyncio.Queue()
56
+ self._sentinel = sentinel
57
+ self._sentinel_pushed = False
58
+ self._closed = False
59
+
60
+ def __aiter__(self):
61
+ return self
62
+
63
+ async def __anext__(self):
64
+ if self._closed:
65
+ raise StopAsyncIteration
66
+
67
+ data = await self._q.get()
68
+ if data is self._sentinel:
69
+ self._closed = True
70
+ raise StopAsyncIteration
71
+
72
+ return data
73
+
74
+ async def put(self, data) -> bool:
75
+ """
76
+ Pushes next item to Iterator and returns True
77
+ If iterator has been closed via close(), doesn't push anything and returns False
78
+ """
79
+ if self._sentinel_pushed:
80
+ return False
81
+
82
+ await self._q.put(data)
83
+ return True
84
+
85
+ async def close(self):
86
+ """
87
+ Close is idempotent. Calling close multiple times is safe
88
+ Iterator will raise StopIteration only after all elements pushed before close have been iterated
89
+ """
90
+ # make close idempotent
91
+ if not self._sentinel_pushed:
92
+ self._sentinel_pushed = True
93
+ await self._q.put(self._sentinel)
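A minimal sketch of feeding an IteratorPipe from a producer thread while the main thread consumes it as an ordinary iterator:

import threading
from iterators import IteratorPipe

pipe = IteratorPipe()

def producer():
    for i in range(3):
        pipe.put(i)
    pipe.close()  # the consumer loop ends once everything pushed before close() is drained

threading.Thread(target=producer).start()
for item in pipe:
    print(item)  # 0, 1, 2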
iterators/timeout_iterator.py ADDED
@@ -0,0 +1,170 @@
1
+ import queue
2
+ import asyncio
3
+ import threading
4
+ import traceback
5
+
6
+
7
+ class TimeoutIterator:
8
+ """
9
+ Wrapper class to add a timeout feature to synchronous iterators
10
+ - timeout: timeout for next(). Default=ZERO_TIMEOUT, i.e. no timeout (blocking calls to next()). Updated using set_timeout()
11
+ - sentinel: the object returned by iterator when timeout happens
12
+ - reset_on_next: if set to True, timeout is reset to the value of ZERO_TIMEOUT on each iteration
13
+
14
+ TimeoutIterator uses a thread internally.
15
+ The thread stops once the iterator exhausts or raises an exception during iteration.
16
+
17
+ Any exceptions raised within the wrapped iterator are propagated as-is.
18
+ An exception is raised only after all elements generated by the underlying iterator before the exception have been consumed.
19
+ The timeout can be set dynamically before each iteration.
20
+ """
21
+ ZERO_TIMEOUT = 0.0
22
+
23
+ def __init__(self, iterator, timeout=0.0, sentinel=object(), reset_on_next=False, raise_on_exception=True):
24
+ self._iterator = iterator
25
+ self._timeout = timeout
26
+ self._sentinel = sentinel
27
+ self._reset_on_next = reset_on_next
28
+ self._raise_on_exception = raise_on_exception
29
+
30
+ self._interrupt = False
31
+ self._done = False
32
+ self._buffer = queue.Queue()
33
+ self._thread = threading.Thread(target=self.__lookahead)
34
+ self._thread.start()
35
+
36
+ def get_sentinel(self):
37
+ return self._sentinel
38
+
39
+ def set_reset_on_next(self, reset_on_next):
40
+ self._reset_on_next = reset_on_next
41
+
42
+ def set_timeout(self, timeout: float):
43
+ """
44
+ Set timeout for next iteration
45
+ """
46
+ self._timeout = timeout
47
+
48
+ def interrupt(self):
49
+ """
50
+ Interrupt and stop the underlying thread.
51
+ The thread actually dies only after interrupt has been set and
52
+ the underlying iterator yields a value after that.
53
+ """
54
+ self._interrupt = True
55
+
56
+ def __iter__(self):
57
+ return self
58
+
59
+ def __next__(self):
60
+ """
61
+ yield the result from iterator
62
+ if timeout > 0:
63
+ yield data if available.
64
+ otherwise yield the sentinel
65
+ """
66
+ if self._done:
67
+ raise StopIteration
68
+
69
+ data = self._sentinel
70
+ try:
71
+ if self._timeout > self.ZERO_TIMEOUT:
72
+ data = self._buffer.get(timeout=self._timeout)
73
+ else:
74
+ data = self._buffer.get()
75
+ except queue.Empty:
76
+ pass
77
+ finally:
78
+ # see if timeout needs to be reset
79
+ if self._reset_on_next:
80
+ self._timeout = self.ZERO_TIMEOUT
81
+
82
+ # propagate any exceptions including StopIteration
83
+ if isinstance(data, BaseException):
84
+ self._done = True
85
+ if isinstance(data, StopIteration):
86
+ raise data
87
+ ex = ''.join(traceback.format_tb(data.__traceback__))
88
+ print("Generation Failed: %s %s" % (str(data), str(ex)), flush=True)
89
+ if self._raise_on_exception:
90
+ raise data
91
+ else:
92
+ return data
93
+
94
+ return data
95
+
96
+ def __lookahead(self):
97
+ try:
98
+ while True:
99
+ self._buffer.put(next(self._iterator))
100
+ if self._interrupt:
101
+ raise StopIteration()
102
+ except BaseException as e:
103
+ self._buffer.put(e)
104
+
105
+
106
+ class AsyncTimeoutIterator:
107
+ """
108
+ Async version of TimeoutIterator. See method documentation of TimeoutIterator
109
+ """
110
+ ZERO_TIMEOUT = 0.0
111
+
112
+ def __init__(self, iterator, timeout=0.0, sentinel=object(), reset_on_next=False):
113
+ self._iterator = iterator
114
+ self._timeout = timeout
115
+ self._sentinel = sentinel
116
+ self._reset_on_next = reset_on_next
117
+
118
+ self._interrupt = False
119
+ self._done = False
120
+ self._buffer = asyncio.Queue()
121
+ self._task = asyncio.get_event_loop().create_task(self.__lookahead())
122
+
123
+ def get_sentinel(self):
124
+ return self._sentinel
125
+
126
+ def set_reset_on_next(self, reset_on_next):
127
+ self._reset_on_next = reset_on_next
128
+
129
+ def set_timeout(self, timeout: float):
130
+ self._timeout = timeout
131
+
132
+ def interrupt(self):
133
+ self._interrupt = True
134
+
135
+ def __aiter__(self):
136
+ return self
137
+
138
+ async def __anext__(self):
139
+ if self._done:
140
+ raise StopAsyncIteration
141
+
142
+ data = self._sentinel
143
+ try:
144
+ if self._timeout > self.ZERO_TIMEOUT:
145
+ data = await asyncio.wait_for(self._buffer.get(), self._timeout)
146
+ else:
147
+ data = await self._buffer.get()
148
+ except asyncio.TimeoutError:
149
+ pass
150
+ finally:
151
+ # see if timeout needs to be reset
152
+ if self._reset_on_next:
153
+ self._timeout = self.ZERO_TIMEOUT
154
+
155
+ # propagate any exceptions including StopIteration
156
+ if isinstance(data, BaseException):
157
+ self._done = True
158
+ raise data
159
+
160
+ return data
161
+
162
+ async def __lookahead(self):
163
+ try:
164
+ while True:
165
+ data = await self._iterator.__anext__()
166
+ await self._buffer.put(data)
167
+ if self._interrupt:
168
+ raise StopAsyncIteration()
169
+ except BaseException as e:
170
+ await self._buffer.put(e)
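A minimal usage sketch of the synchronous TimeoutIterator above (the generator, the 0.25 s timeout, and the printing are illustrative assumptions, not part of this commit; it assumes the synchronous class takes (iterator, timeout=...) like its async counterpart):

import time

def slow_tokens():
    # stand-in for a slow streaming source
    for tok in ["Hello", " world"]:
        time.sleep(1.0)
        yield tok

# TimeoutIterator as defined in the module above
it = TimeoutIterator(slow_tokens(), timeout=0.25)
sentinel = it.get_sentinel()
for item in it:
    if item is sentinel:
        # nothing arrived within 0.25 s; the caller can do other work and poll again
        continue
    print(item, end='', flush=True)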
loaders.py CHANGED
@@ -1,6 +1,8 @@
1
- def get_loaders(llama_type, model_name, reward_type):
2
  # NOTE: Some models need specific new prompt_type
3
  # E.g. t5_xxl_true_nli_mixture has input format: "premise: PREMISE_TEXT hypothesis: HYPOTHESIS_TEXT".)
 
 
4
  if llama_type:
5
  from transformers import LlamaForCausalLM, LlamaTokenizer
6
  model_loader = LlamaForCausalLM
@@ -39,7 +41,8 @@ def get_tokenizer(tokenizer_loader, tokenizer_base_model, local_files_only, resu
39
  tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model,
40
  local_files_only=local_files_only,
41
  resume_download=resume_download,
42
- use_auth_token=use_auth_token)
 
43
 
44
  tokenizer.pad_token_id = 0 # different from the eos token
45
  # when generating, we will use the logits of right-most token to predict the next token
 
1
+ def get_loaders(model_name, reward_type, llama_type=None):
2
  # NOTE: Some models need specific new prompt_type
3
  # E.g. t5_xxl_true_nli_mixture has input format: "premise: PREMISE_TEXT hypothesis: HYPOTHESIS_TEXT".)
4
+ if llama_type is None:
5
+ llama_type = "llama" in model_name.lower()
6
  if llama_type:
7
  from transformers import LlamaForCausalLM, LlamaTokenizer
8
  model_loader = LlamaForCausalLM
 
41
  tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model,
42
  local_files_only=local_files_only,
43
  resume_download=resume_download,
44
+ use_auth_token=use_auth_token,
45
+ padding_side='left')
46
 
47
  tokenizer.pad_token_id = 0 # different from the eos token
48
  # when generating, we will use the logits of right-most token to predict the next token
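A hedged call sketch of the updated loaders (the model name is illustrative, and it assumes get_loaders returns the (model_loader, tokenizer_loader) pair as used elsewhere in h2oGPT):

from loaders import get_loaders

base_model = 'h2oai/h2ogpt-oig-oasst1-512-6.9b'
# llama_type is now optional and inferred from the model name when omitted
model_loader, tokenizer_loader = get_loaders(model_name=base_model, reward_type=False)
tokenizer = tokenizer_loader.from_pretrained(base_model,
                                             local_files_only=False,
                                             resume_download=True,
                                             padding_side='left')  # left padding, as set above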
prompter.py CHANGED
@@ -1,10 +1,10 @@
 
1
  import ast
2
  import time
3
  from enums import PromptType # also supports imports from this file from other files
4
 
5
  non_hf_types = ['gpt4all_llama', 'llama', 'gptj']
6
 
7
-
8
  prompt_type_to_model_name = {
9
  'plain': [
10
  'EleutherAI/gpt-j-6B',
@@ -25,23 +25,29 @@ prompt_type_to_model_name = {
25
  'mosaicml/mpt-7b-storywriter',
26
  'mosaicml/mpt-7b-instruct', # internal code handles instruct
27
  'mosaicml/mpt-7b-chat', # NC, internal code handles instruct
28
- 'gptj', # internally handles prompting
29
- 'llama', # plain, or need to choose prompt_type for given TheBloke model
30
- 'gpt4all_llama', # internally handles prompting
31
  ],
 
32
  'prompt_answer': [
33
  'h2oai/h2ogpt-gm-oasst1-en-1024-20b',
34
  'h2oai/h2ogpt-gm-oasst1-en-1024-12b',
35
  'h2oai/h2ogpt-gm-oasst1-multilang-1024-20b',
36
- 'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt',
37
- 'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2',
38
- 'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-700bt',
39
- 'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b',
40
  'h2oai/h2ogpt-gm-oasst1-multilang-2048-falcon-7b',
41
  'h2oai/h2ogpt-gm-oasst1-multilang-2048-falcon-7b-v2',
 
42
  'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b',
43
  'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v2',
44
  'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-40b-v1',
 
 
45
  ],
46
  'instruct': [],
47
  'instruct_with_end': ['databricks/dolly-v2-12b'],
@@ -54,6 +60,7 @@ prompt_type_to_model_name = {
54
  'h2oai/h2ogpt-oig-oasst1-256-6.9b', # legacy
55
  'h2oai/h2ogpt-oig-oasst1-512-6.9b', # legacy
56
  'h2oai/h2ogpt-research-oasst1-512-30b',
 
57
  'h2oai/h2ogpt-oasst1-falcon-40b',
58
  'h2oai/h2ogpt-oig-oasst1-falcon-40b',
59
  ],
@@ -66,7 +73,16 @@ prompt_type_to_model_name = {
66
  "wizard_lm": ['ehartford/WizardLM-7B-Uncensored', 'ehartford/WizardLM-13B-Uncensored'],
67
  "wizard_mega": ['openaccess-ai-collective/wizard-mega-13b'],
68
  "instruct_simple": ['JosephusCheung/Guanaco'],
 
 
 
 
69
  }
 
 
 
 
 
70
 
71
  inv_prompt_type_to_model_name = {v.strip(): k for k, l in prompt_type_to_model_name.items() for v in l}
72
  inv_prompt_type_to_model_lower = {v.strip().lower(): k for k, l in prompt_type_to_model_name.items() for v in l}
@@ -80,18 +96,29 @@ for p in PromptType:
80
  prompt_types.extend([p.name, p.value, str(p.value)])
81
 
82
 
83
- def get_prompt(prompt_type, prompt_dict, chat, context, reduced, return_dict=False):
84
  prompt_dict_error = ''
 
 
85
  if prompt_type == PromptType.custom.name and not isinstance(prompt_dict, dict):
86
  try:
87
  prompt_dict = ast.literal_eval(prompt_dict)
88
  except BaseException as e:
89
  prompt_dict_error = str(e)
90
- if prompt_dict_error:
91
- return dict(), prompt_dict_error
92
-
93
- if prompt_type in [PromptType.custom.value, str(PromptType.custom.value),
94
- PromptType.custom.name]:
 
 
95
  promptA = prompt_dict.get('promptA', '')
96
 promptB = prompt_dict.get('promptB', '')
97
  PreInstruct = prompt_dict.get('PreInstruct', '')
@@ -99,21 +126,23 @@ def get_prompt(prompt_type, prompt_dict, chat, context, reduced, return_dict=Fal
99
  PreResponse = prompt_dict.get('PreResponse', '')
100
  terminate_response = prompt_dict.get('terminate_response', None)
101
  chat_sep = prompt_dict.get('chat_sep', '\n')
 
102
  humanstr = prompt_dict.get('humanstr', '')
103
  botstr = prompt_dict.get('botstr', '')
104
  elif prompt_type in [PromptType.plain.value, str(PromptType.plain.value),
105
  PromptType.plain.name]:
106
- promptA = promptB = PreInstruct = PreInput = PreResponse = ''
107
  terminate_response = []
108
- chat_sep = ''
109
- humanstr = ''
110
- botstr = ''
 
111
  elif prompt_type == 'simple_instruct':
112
  promptA = promptB = PreInstruct = PreInput = PreResponse = None
113
  terminate_response = []
114
- chat_sep = '\n'
115
- humanstr = ''
116
- botstr = ''
117
  elif prompt_type in [PromptType.instruct.value, str(PromptType.instruct.value),
118
  PromptType.instruct.name] + [PromptType.instruct_with_end.value,
119
  str(PromptType.instruct_with_end.value),
@@ -139,7 +168,7 @@ def get_prompt(prompt_type, prompt_dict, chat, context, reduced, return_dict=Fal
139
  terminate_response = ['### End']
140
  else:
141
  terminate_response = None
142
- chat_sep = '\n'
143
  humanstr = PreInstruct
144
  botstr = PreResponse
145
  elif prompt_type in [PromptType.quality.value, str(PromptType.quality.value),
@@ -161,7 +190,7 @@ def get_prompt(prompt_type, prompt_dict, chat, context, reduced, return_dict=Fal
161
  ### Response:
162
  """
163
  terminate_response = None
164
- chat_sep = '\n'
165
  humanstr = PreInstruct # first thing human says
166
  botstr = PreResponse # first thing bot says
167
  elif prompt_type in [PromptType.human_bot.value, str(PromptType.human_bot.value),
@@ -183,14 +212,14 @@ Current Time: {}
183
 
184
  """
185
  preprompt = PRE_PROMPT.format(cur_date, cur_time)
186
- start = human
187
- promptB = promptA = '%s%s ' % (preprompt, start)
188
 
189
- PreInstruct = ""
190
 
191
  PreInput = None
192
 
193
- if reduced:
194
  # when making context, want it to appear as-if LLM generated, which starts with space after :
195
  PreResponse = bot + ' '
196
  else:
@@ -198,10 +227,11 @@ Current Time: {}
198
  # if add space here, non-unique tokenization will often make LLM produce wrong output
199
  PreResponse = bot
200
 
201
- terminate_response = [start, PreResponse]
202
- chat_sep = '\n'
203
  humanstr = human # tag before human talks
204
  botstr = bot # tag before bot talks
 
205
  elif prompt_type in [PromptType.dai_faq.value, str(PromptType.dai_faq.value),
206
  PromptType.dai_faq.name]:
207
  promptA = ''
@@ -217,7 +247,7 @@ Current Time: {}
217
  ### Driverless AI documentation answer:
218
  """
219
  terminate_response = ['\n\n']
220
- chat_sep = terminate_response
221
  humanstr = PreInstruct
222
  botstr = PreResponse
223
  elif prompt_type in [PromptType.summarize.value, str(PromptType.summarize.value),
@@ -226,7 +256,7 @@ Current Time: {}
226
  PreInstruct = '## Main Text\n\n'
227
  PreResponse = '\n\n## Summary\n\n'
228
  terminate_response = None
229
- chat_sep = '\n'
230
  humanstr = PreInstruct
231
  botstr = PreResponse
232
  elif prompt_type in [PromptType.instruct_vicuna.value, str(PromptType.instruct_vicuna.value),
@@ -246,7 +276,7 @@ Current Time: {}
246
  """
247
  terminate_response = [
248
  '### Human:'] # but only allow terminate after prompt is found correctly, else can't terminate
249
- chat_sep = '\n'
250
  humanstr = PreInstruct
251
  botstr = PreResponse
252
  elif prompt_type in [PromptType.prompt_answer.value, str(PromptType.prompt_answer.value),
@@ -254,33 +284,50 @@ Current Time: {}
254
  preprompt = ''
255
  prompt_tokens = "<|prompt|>"
256
  answer_tokens = "<|answer|>"
257
- start = prompt_tokens
258
  promptB = promptA = '%s%s' % (preprompt, start)
259
- PreInstruct = ""
260
  PreInput = None
261
  PreResponse = answer_tokens
262
  eos = '<|endoftext|>' # neox eos
263
- terminate_response = [start, PreResponse, eos]
264
- chat_sep = eos
265
  humanstr = prompt_tokens
266
  botstr = answer_tokens
 
 
 
 
267
  elif prompt_type in [PromptType.open_assistant.value, str(PromptType.open_assistant.value),
268
  PromptType.open_assistant.name]:
269
  # From added_tokens.json
270
  preprompt = ''
271
  prompt_tokens = "<|prompter|>"
272
  answer_tokens = "<|assistant|>"
273
- start = prompt_tokens
274
  promptB = promptA = '%s%s' % (preprompt, start)
275
- PreInstruct = ""
276
  PreInput = None
277
  PreResponse = answer_tokens
278
  pend = "<|prefix_end|>"
279
  eos = "</s>"
280
- terminate_response = [start, PreResponse, pend, eos]
281
- chat_sep = eos
282
  humanstr = prompt_tokens
283
  botstr = answer_tokens
 
 
284
  elif prompt_type in [PromptType.wizard_lm.value, str(PromptType.wizard_lm.value),
285
  PromptType.wizard_lm.name]:
286
  # https://github.com/ehartford/WizardLM/blob/main/src/train_freeform.py
@@ -292,7 +339,7 @@ Current Time: {}
292
  PreResponse = "\n\n### Response\n"
293
  eos = "</s>"
294
  terminate_response = [PreResponse, eos]
295
- chat_sep = eos
296
  humanstr = promptA
297
  botstr = PreResponse
298
  elif prompt_type in [PromptType.wizard_mega.value, str(PromptType.wizard_mega.value),
@@ -308,13 +355,12 @@ Current Time: {}
308
  ### Assistant:
309
  """
310
  terminate_response = [PreResponse]
311
- chat_sep = '\n'
312
  humanstr = PreInstruct
313
  botstr = PreResponse
314
  elif prompt_type in [PromptType.instruct_vicuna2.value, str(PromptType.instruct_vicuna2.value),
315
  PromptType.instruct_vicuna2.name]:
316
- promptA = promptB = "" if not (
317
- chat and reduced) else ''
318
 
319
  PreInstruct = """
320
  HUMAN:
@@ -327,13 +373,12 @@ ASSISTANT:
327
  """
328
  terminate_response = [
329
  'HUMAN:'] # but only allow terminate after prompt is found correctly, else can't terminate
330
- chat_sep = '\n'
331
  humanstr = PreInstruct
332
  botstr = PreResponse
333
  elif prompt_type in [PromptType.instruct_vicuna3.value, str(PromptType.instruct_vicuna3.value),
334
  PromptType.instruct_vicuna3.name]:
335
- promptA = promptB = "" if not (
336
- chat and reduced) else ''
337
 
338
  PreInstruct = """
339
  ### User:
@@ -346,13 +391,14 @@ ASSISTANT:
346
  """
347
  terminate_response = [
348
  '### User:'] # but only allow terminate after prompt is found correctly, else can't terminate
349
- chat_sep = '\n'
350
  humanstr = PreInstruct
351
  botstr = PreResponse
352
  elif prompt_type in [PromptType.wizard2.value, str(PromptType.wizard2.value),
353
  PromptType.wizard2.name]:
354
  # https://huggingface.co/TheBloke/WizardLM-7B-uncensored-GGML
355
- preprompt = """Below is an instruction that describes a task. Write a response that appropriately completes the request."""
 
356
  start = ''
357
  promptB = promptA = '%s%s' % (preprompt, start)
358
  PreInstruct = """
@@ -363,27 +409,39 @@ ASSISTANT:
363
  ### Response:
364
  """
365
  terminate_response = [PreResponse]
366
- chat_sep = '\n'
367
  humanstr = PreInstruct
368
  botstr = PreResponse
369
  elif prompt_type in [PromptType.wizard3.value, str(PromptType.wizard3.value),
370
  PromptType.wizard3.name]:
371
  # https://huggingface.co/TheBloke/wizardLM-13B-1.0-GGML
372
- preprompt = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."""
 
373
  start = ''
374
  promptB = promptA = '%s%s' % (preprompt, start)
375
  PreInstruct = """USER: """
376
  PreInput = None
377
  PreResponse = """ASSISTANT: """
378
  terminate_response = [PreResponse]
379
- chat_sep = '\n'
 
 
 
 
 
 
 
 
 
 
 
 
380
  humanstr = PreInstruct
381
  botstr = PreResponse
382
 
383
  elif prompt_type in [PromptType.instruct_simple.value, str(PromptType.instruct_simple.value),
384
  PromptType.instruct_simple.name]:
385
- promptA = '' if not (chat and reduced) else ''
386
- promptB = '' if not (chat and reduced) else ''
387
 
388
  PreInstruct = """
389
  ### Instruction:
@@ -397,21 +455,90 @@ ASSISTANT:
397
  ### Response:
398
  """
399
  terminate_response = None
400
- chat_sep = '\n'
401
  humanstr = PreInstruct
402
  botstr = PreResponse
 
 
403
  else:
404
  raise RuntimeError("No such prompt_type=%s" % prompt_type)
405
 
406
- if return_dict:
407
- return dict(promptA=promptA, promptB=promptB, PreInstruct=PreInstruct, PreInput=PreInput,
 
 
408
  PreResponse=PreResponse, terminate_response=terminate_response, chat_sep=chat_sep,
409
- humanstr=humanstr, botstr=botstr), ''
 
 
 
 
 
410
  else:
411
- return promptA, promptB, PreInstruct, PreInput, PreResponse, terminate_response, chat_sep, humanstr, botstr
412
 
413
 
414
- def generate_prompt(data_point, prompt_type, prompt_dict, chat, reduced):
415
  context = data_point.get('context')
416
  if context is None:
417
  context = ''
@@ -422,9 +549,12 @@ def generate_prompt(data_point, prompt_type, prompt_dict, chat, reduced):
422
  prompt_dict = data_point.get('prompt_dict', prompt_dict)
423
  assert prompt_type in prompt_types, "Bad prompt type: %s" % prompt_type
424
  promptA, promptB, PreInstruct, PreInput, PreResponse, \
425
- terminate_response, chat_sep, humanstr, botstr = get_prompt(prompt_type, prompt_dict, chat, context, reduced)
 
 
426
 
427
- prompt = context if not reduced else ''
 
428
 
429
  if input and promptA:
430
  prompt += f"""{promptA}"""
@@ -433,37 +563,37 @@ def generate_prompt(data_point, prompt_type, prompt_dict, chat, reduced):
433
 
434
  if instruction and PreInstruct is not None and input and PreInput is not None:
435
  prompt += f"""{PreInstruct}{instruction}{PreInput}{input}"""
436
- prompt = inject_newline(prompt_type, prompt)
437
  elif instruction and input and PreInstruct is None and PreInput is not None:
438
  prompt += f"""{PreInput}{instruction}
439
  {input}"""
440
- prompt = inject_newline(prompt_type, prompt)
441
  elif input and instruction and PreInput is None and PreInstruct is not None:
442
  prompt += f"""{PreInstruct}{instruction}
443
  {input}"""
444
- prompt = inject_newline(prompt_type, prompt)
445
  elif instruction and PreInstruct is not None:
446
  prompt += f"""{PreInstruct}{instruction}"""
447
- prompt = inject_newline(prompt_type, prompt)
448
  elif input and PreInput is not None:
449
  prompt += f"""{PreInput}{input}"""
450
- prompt = inject_newline(prompt_type, prompt)
451
  elif input and instruction and PreInput is not None:
452
  prompt += f"""{PreInput}{instruction}{input}"""
453
- prompt = inject_newline(prompt_type, prompt)
454
  elif input and instruction and PreInstruct is not None:
455
  prompt += f"""{PreInstruct}{instruction}{input}"""
456
- prompt = inject_newline(prompt_type, prompt)
457
  elif input and instruction:
458
  # i.e. for simple_instruct
459
  prompt += f"""{instruction}: {input}"""
460
- prompt = inject_newline(prompt_type, prompt)
461
  elif input:
462
  prompt += f"""{input}"""
463
- prompt = inject_newline(prompt_type, prompt)
464
  elif instruction:
465
  prompt += f"""{instruction}"""
466
- prompt = inject_newline(prompt_type, prompt)
467
 
468
  if PreResponse is not None:
469
  prompt += f"""{PreResponse}"""
@@ -474,13 +604,13 @@ def generate_prompt(data_point, prompt_type, prompt_dict, chat, reduced):
474
  if output:
475
  prompt += f"""{output}"""
476
 
477
- return prompt, pre_response, terminate_response, chat_sep
478
 
479
 
480
- def inject_newline(prompt_type, prompt):
481
- if prompt_type not in [-1, '-1', 'plain', 'simple_instruct']:
482
  # only add new line if structured prompt, while 'plain' is just generation of next tokens from input
483
- prompt += '\n'
484
  return prompt
485
 
486
 
@@ -489,9 +619,6 @@ class Prompter(object):
489
  allowed_repeat_line_length=10):
490
  self.prompt_type = prompt_type
491
  self.prompt_dict = prompt_dict
492
- data_point = dict(instruction='', input='', output='')
493
- _, self.pre_response, self.terminate_response, self.chat_sep = \
494
- generate_prompt(data_point, self.prompt_type, self.prompt_dict, chat, False)
495
  self.debug = debug
496
  self.chat = chat
497
  self.stream_output = stream_output
@@ -500,23 +627,41 @@ class Prompter(object):
500
  self.prompt = None
501
  context = "" # not for chat context
502
  reduced = False # not for chat context
 
503
  self.promptA, self.promptB, self.PreInstruct, self.PreInput, self.PreResponse, \
504
- self.terminate_response, self.chat_sep, self.humanstr, self.botstr = \
505
- get_prompt(self.prompt_type, self.prompt_dict, chat, context, reduced)
506
-
507
- def generate_prompt(self, data_point):
508
- reduced = False
509
- prompt, _, _, _ = generate_prompt(data_point, self.prompt_type, self.prompt_dict, self.chat, reduced)
 
 
510
  if self.debug:
511
- print("prompt: ", prompt, flush=True)
 
 
512
  self.prompt = prompt
513
  return prompt
514
 
515
- def get_response(self, outputs, prompt=None, sanitize_bot_response=True):
516
  if isinstance(outputs, str):
517
  outputs = [outputs]
518
  if self.debug:
519
- print("output:\n", '\n\n'.join(outputs), flush=True)
520
  if prompt is not None:
521
  self.prompt = prompt
522
 
@@ -527,7 +672,8 @@ class Prompter(object):
527
  if sanitize_bot_response:
528
  from better_profanity import profanity
529
  response = profanity.censor(response)
530
- response = response.strip("\n")
 
531
  return response
532
 
533
  def clean_repeats(response):
@@ -549,12 +695,12 @@ class Prompter(object):
549
  # then use most basic parsing like pipeline
550
  if self.botstr in output:
551
  if self.humanstr:
552
- output = clean_response(output.split(self.botstr)[1].strip().split(self.humanstr)[0].strip())
553
  else:
554
  # i.e. use after bot but only up to next bot
555
- output = clean_response(output.split(self.botstr)[1].strip().split(self.botstr)[0].strip())
556
  else:
557
- # output = clean_response(output.strip())
558
  # assume just not printed yet
559
  output = ""
560
  else:
@@ -581,9 +727,9 @@ class Prompter(object):
581
  allow_terminate = True
582
  output = output[len(prompt):]
583
  # clean after subtract prompt out, so correct removal of pre_response
584
- output = clean_response(output).strip()
585
  if self.repeat_penalty:
586
- output = clean_repeats(output).strip()
587
  if self.terminate_response and allow_terminate:
588
  finds = []
589
  for term in self.terminate_response:
@@ -591,11 +737,9 @@ class Prompter(object):
591
  finds = [x for x in finds if x >= 0]
592
  if len(finds) > 0:
593
  termi = finds[0]
594
- output = output[:termi].strip()
595
  else:
596
- output = output.strip()
597
- else:
598
- output = output.strip()
599
  if multi_output:
600
  # prefix with output counter
601
  output = "\n=========== Output %d\n\n" % (1 + oi) + output
@@ -606,5 +750,5 @@ class Prompter(object):
606
  # join all outputs, only one extra new line between outputs
607
  output = '\n'.join(outputs)
608
  if self.debug:
609
- print("outputclean:\n", '\n\n'.join(outputs), flush=True)
610
  return output
 
1
+ import os
2
  import ast
3
  import time
4
  from enums import PromptType # also supports imports from this file from other files
5
 
6
  non_hf_types = ['gpt4all_llama', 'llama', 'gptj']
7
 
 
8
  prompt_type_to_model_name = {
9
  'plain': [
10
  'EleutherAI/gpt-j-6B',
 
25
  'mosaicml/mpt-7b-storywriter',
26
  'mosaicml/mpt-7b-instruct', # internal code handles instruct
27
  'mosaicml/mpt-7b-chat', # NC, internal code handles instruct
28
+ 'mosaicml/mpt-30b-instruct', # internal code handles instruct
 
 
29
  ],
30
+ 'gptj': ['gptj', 'gpt4all_llama'],
31
  'prompt_answer': [
32
  'h2oai/h2ogpt-gm-oasst1-en-1024-20b',
33
  'h2oai/h2ogpt-gm-oasst1-en-1024-12b',
34
  'h2oai/h2ogpt-gm-oasst1-multilang-1024-20b',
 
 
 
 
35
  'h2oai/h2ogpt-gm-oasst1-multilang-2048-falcon-7b',
36
  'h2oai/h2ogpt-gm-oasst1-multilang-2048-falcon-7b-v2',
37
+ 'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3',
38
  'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b',
39
  'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v2',
40
  'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-40b-v1',
41
+ 'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-40b-v2',
42
+ 'h2oai/h2ogpt-gm-oasst1-en-xgen-7b-8k',
43
+ 'h2oai/h2ogpt-gm-oasst1-multilang-xgen-7b-8k',
44
+ ],
45
+ 'prompt_answer_openllama': [
46
+ 'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt',
47
+ 'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2',
48
+ 'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-700bt',
49
+ 'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b',
50
+ 'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-13b',
51
  ],
52
  'instruct': [],
53
  'instruct_with_end': ['databricks/dolly-v2-12b'],
 
60
  'h2oai/h2ogpt-oig-oasst1-256-6.9b', # legacy
61
  'h2oai/h2ogpt-oig-oasst1-512-6.9b', # legacy
62
  'h2oai/h2ogpt-research-oasst1-512-30b',
63
+ 'h2oai/h2ogpt-research-oasst1-llama-65b',
64
  'h2oai/h2ogpt-oasst1-falcon-40b',
65
  'h2oai/h2ogpt-oig-oasst1-falcon-40b',
66
  ],
 
73
  "wizard_lm": ['ehartford/WizardLM-7B-Uncensored', 'ehartford/WizardLM-13B-Uncensored'],
74
  "wizard_mega": ['openaccess-ai-collective/wizard-mega-13b'],
75
  "instruct_simple": ['JosephusCheung/Guanaco'],
76
+ "wizard_vicuna": ['ehartford/Wizard-Vicuna-13B-Uncensored'],
77
+ "wizard2": ['llama', 'mosaicml/mpt-30b-instruct'],
78
+ "vicuna11": ['lmsys/vicuna-33b-v1.3'],
79
+ # could be plain, but default is correct prompt_type for default TheBloke model ggml-wizardLM-7B.q4_2.bin
80
  }
81
+ if os.getenv('OPENAI_API_KEY'):
82
+ prompt_type_to_model_name.update({
83
+ "openai": ["text-davinci-003", "text-curie-001", "text-babbage-001", "text-ada-001"],
84
+ "openai_chat": ["gpt-3.5-turbo", "gpt-3.5-turbo-16k"],
85
+ })
86
 
87
  inv_prompt_type_to_model_name = {v.strip(): k for k, l in prompt_type_to_model_name.items() for v in l}
88
  inv_prompt_type_to_model_lower = {v.strip().lower(): k for k, l in prompt_type_to_model_name.items() for v in l}
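As a hedged illustration of how these inverse maps are typically used (the lookup below relies only on entries visible in the table above):

# map a model name to its default prompt_type, falling back to 'plain'
model_lower = 'databricks/dolly-v2-12b'.lower()
prompt_type = inv_prompt_type_to_model_lower.get(model_lower, 'plain')
# -> 'instruct_with_end' per the mapping above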
 
96
  prompt_types.extend([p.name, p.value, str(p.value)])
97
 
98
 
99
+ def get_prompt(prompt_type, prompt_dict, chat, context, reduced, making_context, return_dict=False):
100
  prompt_dict_error = ''
101
+ generates_leading_space = False
102
+
103
  if prompt_type == PromptType.custom.name and not isinstance(prompt_dict, dict):
104
  try:
105
  prompt_dict = ast.literal_eval(prompt_dict)
106
  except BaseException as e:
107
  prompt_dict_error = str(e)
108
+ if prompt_dict_error:
109
+ promptA = None
110
+ promptB = None
111
+ PreInstruct = None
112
+ PreInput = ''
113
+ PreResponse = ''
114
+ terminate_response = None
115
+ chat_sep = ''
116
+ chat_turn_sep = ''
117
+ humanstr = ''
118
+ botstr = ''
119
+ generates_leading_space = False
120
+ elif prompt_type in [PromptType.custom.value, str(PromptType.custom.value),
121
+ PromptType.custom.name]:
122
  promptA = prompt_dict.get('promptA', '')
123
 promptB = prompt_dict.get('promptB', '')
124
  PreInstruct = prompt_dict.get('PreInstruct', '')
 
126
  PreResponse = prompt_dict.get('PreResponse', '')
127
  terminate_response = prompt_dict.get('terminate_response', None)
128
  chat_sep = prompt_dict.get('chat_sep', '\n')
129
+ chat_turn_sep = prompt_dict.get('chat_turn_sep', '\n')
130
  humanstr = prompt_dict.get('humanstr', '')
131
  botstr = prompt_dict.get('botstr', '')
132
  elif prompt_type in [PromptType.plain.value, str(PromptType.plain.value),
133
  PromptType.plain.name]:
134
+ promptA = promptB = PreInstruct = PreInput = PreResponse = None
135
  terminate_response = []
136
+ chat_turn_sep = chat_sep = ''
137
+ # plain should have None for human/bot, so nothing truncated out, not '' that would truncate after first token
138
+ humanstr = None
139
+ botstr = None
140
  elif prompt_type == 'simple_instruct':
141
  promptA = promptB = PreInstruct = PreInput = PreResponse = None
142
  terminate_response = []
143
+ chat_turn_sep = chat_sep = '\n'
144
+ humanstr = None
145
+ botstr = None
146
  elif prompt_type in [PromptType.instruct.value, str(PromptType.instruct.value),
147
  PromptType.instruct.name] + [PromptType.instruct_with_end.value,
148
  str(PromptType.instruct_with_end.value),
 
168
  terminate_response = ['### End']
169
  else:
170
  terminate_response = None
171
+ chat_turn_sep = chat_sep = '\n'
172
  humanstr = PreInstruct
173
  botstr = PreResponse
174
  elif prompt_type in [PromptType.quality.value, str(PromptType.quality.value),
 
190
  ### Response:
191
  """
192
  terminate_response = None
193
+ chat_turn_sep = chat_sep = '\n'
194
  humanstr = PreInstruct # first thing human says
195
  botstr = PreResponse # first thing bot says
196
  elif prompt_type in [PromptType.human_bot.value, str(PromptType.human_bot.value),
 
212
 
213
  """
214
  preprompt = PRE_PROMPT.format(cur_date, cur_time)
215
+ start = ''
216
+ promptB = promptA = '%s%s' % (preprompt, start)
217
 
218
+ PreInstruct = human + ' '
219
 
220
  PreInput = None
221
 
222
+ if making_context:
223
  # when making context, want it to appear as-if LLM generated, which starts with space after :
224
  PreResponse = bot + ' '
225
  else:
 
227
  # if add space here, non-unique tokenization will often make LLM produce wrong output
228
  PreResponse = bot
229
 
230
+ terminate_response = ['\n' + human, '\n' + bot, human, bot, PreResponse]
231
+ chat_turn_sep = chat_sep = '\n'
232
  humanstr = human # tag before human talks
233
  botstr = bot # tag before bot talks
234
+ generates_leading_space = True
235
  elif prompt_type in [PromptType.dai_faq.value, str(PromptType.dai_faq.value),
236
  PromptType.dai_faq.name]:
237
  promptA = ''
 
247
  ### Driverless AI documentation answer:
248
  """
249
  terminate_response = ['\n\n']
250
+ chat_turn_sep = chat_sep = terminate_response
251
  humanstr = PreInstruct
252
  botstr = PreResponse
253
  elif prompt_type in [PromptType.summarize.value, str(PromptType.summarize.value),
 
256
  PreInstruct = '## Main Text\n\n'
257
  PreResponse = '\n\n## Summary\n\n'
258
  terminate_response = None
259
+ chat_turn_sep = chat_sep = '\n'
260
  humanstr = PreInstruct
261
  botstr = PreResponse
262
  elif prompt_type in [PromptType.instruct_vicuna.value, str(PromptType.instruct_vicuna.value),
 
276
  """
277
  terminate_response = [
278
  '### Human:'] # but only allow terminate after prompt is found correctly, else can't terminate
279
+ chat_turn_sep = chat_sep = '\n'
280
  humanstr = PreInstruct
281
  botstr = PreResponse
282
  elif prompt_type in [PromptType.prompt_answer.value, str(PromptType.prompt_answer.value),
 
284
  preprompt = ''
285
  prompt_tokens = "<|prompt|>"
286
  answer_tokens = "<|answer|>"
287
+ start = ''
288
  promptB = promptA = '%s%s' % (preprompt, start)
289
+ PreInstruct = prompt_tokens
290
  PreInput = None
291
  PreResponse = answer_tokens
292
  eos = '<|endoftext|>' # neox eos
 
 
293
  humanstr = prompt_tokens
294
  botstr = answer_tokens
295
+ terminate_response = [humanstr, PreResponse, eos]
296
+ chat_sep = ''
297
+ chat_turn_sep = eos
298
+ elif prompt_type in [PromptType.prompt_answer_openllama.value, str(PromptType.prompt_answer_openllama.value),
299
+ PromptType.prompt_answer_openllama.name]:
300
+ preprompt = ''
301
+ prompt_tokens = "<|prompt|>"
302
+ answer_tokens = "<|answer|>"
303
+ start = ''
304
+ promptB = promptA = '%s%s' % (preprompt, start)
305
+ PreInstruct = prompt_tokens
306
+ PreInput = None
307
+ PreResponse = answer_tokens
308
+ eos = '</s>' # llama eos
309
+ humanstr = prompt_tokens
310
+ botstr = answer_tokens
311
+ terminate_response = [humanstr, PreResponse, eos]
312
+ chat_sep = ''
313
+ chat_turn_sep = eos
314
  elif prompt_type in [PromptType.open_assistant.value, str(PromptType.open_assistant.value),
315
  PromptType.open_assistant.name]:
316
  # From added_tokens.json
317
  preprompt = ''
318
  prompt_tokens = "<|prompter|>"
319
  answer_tokens = "<|assistant|>"
320
+ start = ''
321
  promptB = promptA = '%s%s' % (preprompt, start)
322
+ PreInstruct = prompt_tokens
323
  PreInput = None
324
  PreResponse = answer_tokens
325
  pend = "<|prefix_end|>"
326
  eos = "</s>"
 
 
327
  humanstr = prompt_tokens
328
  botstr = answer_tokens
329
+ terminate_response = [humanstr, PreResponse, pend, eos]
330
+ chat_turn_sep = chat_sep = eos
331
  elif prompt_type in [PromptType.wizard_lm.value, str(PromptType.wizard_lm.value),
332
  PromptType.wizard_lm.name]:
333
  # https://github.com/ehartford/WizardLM/blob/main/src/train_freeform.py
 
339
  PreResponse = "\n\n### Response\n"
340
  eos = "</s>"
341
  terminate_response = [PreResponse, eos]
342
+ chat_turn_sep = chat_sep = eos
343
  humanstr = promptA
344
  botstr = PreResponse
345
  elif prompt_type in [PromptType.wizard_mega.value, str(PromptType.wizard_mega.value),
 
355
  ### Assistant:
356
  """
357
  terminate_response = [PreResponse]
358
+ chat_turn_sep = chat_sep = '\n'
359
  humanstr = PreInstruct
360
  botstr = PreResponse
361
  elif prompt_type in [PromptType.instruct_vicuna2.value, str(PromptType.instruct_vicuna2.value),
362
  PromptType.instruct_vicuna2.name]:
363
+ promptA = promptB = "" if not (chat and reduced) else ''
 
364
 
365
  PreInstruct = """
366
  HUMAN:
 
373
  """
374
  terminate_response = [
375
  'HUMAN:'] # but only allow terminate after prompt is found correctly, else can't terminate
376
+ chat_turn_sep = chat_sep = '\n'
377
  humanstr = PreInstruct
378
  botstr = PreResponse
379
  elif prompt_type in [PromptType.instruct_vicuna3.value, str(PromptType.instruct_vicuna3.value),
380
  PromptType.instruct_vicuna3.name]:
381
+ promptA = promptB = "" if not (chat and reduced) else ''
 
382
 
383
  PreInstruct = """
384
  ### User:
 
391
  """
392
  terminate_response = [
393
  '### User:'] # but only allow terminate after prompt is found correctly, else can't terminate
394
+ chat_turn_sep = chat_sep = '\n'
395
  humanstr = PreInstruct
396
  botstr = PreResponse
397
  elif prompt_type in [PromptType.wizard2.value, str(PromptType.wizard2.value),
398
  PromptType.wizard2.name]:
399
  # https://huggingface.co/TheBloke/WizardLM-7B-uncensored-GGML
400
+ preprompt = """Below is an instruction that describes a task. Write a response that appropriately completes the request.""" if not (
401
+ chat and reduced) else ''
402
  start = ''
403
  promptB = promptA = '%s%s' % (preprompt, start)
404
  PreInstruct = """
 
409
  ### Response:
410
  """
411
  terminate_response = [PreResponse]
412
+ chat_turn_sep = chat_sep = '\n'
413
  humanstr = PreInstruct
414
  botstr = PreResponse
415
  elif prompt_type in [PromptType.wizard3.value, str(PromptType.wizard3.value),
416
  PromptType.wizard3.name]:
417
  # https://huggingface.co/TheBloke/wizardLM-13B-1.0-GGML
418
+ preprompt = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.""" if not (
419
+ chat and reduced) else ''
420
  start = ''
421
  promptB = promptA = '%s%s' % (preprompt, start)
422
  PreInstruct = """USER: """
423
  PreInput = None
424
  PreResponse = """ASSISTANT: """
425
  terminate_response = [PreResponse]
426
+ chat_turn_sep = chat_sep = '\n'
427
+ humanstr = PreInstruct
428
+ botstr = PreResponse
429
+ elif prompt_type in [PromptType.wizard_vicuna.value, str(PromptType.wizard_vicuna.value),
430
+ PromptType.wizard_vicuna.name]:
431
+ preprompt = ''
432
+ start = ''
433
+ promptB = promptA = '%s%s' % (preprompt, start)
434
+ PreInstruct = """USER: """
435
+ PreInput = None
436
+ PreResponse = """ASSISTANT: """
437
+ terminate_response = [PreResponse]
438
+ chat_turn_sep = chat_sep = '\n'
439
  humanstr = PreInstruct
440
  botstr = PreResponse
441
 
442
  elif prompt_type in [PromptType.instruct_simple.value, str(PromptType.instruct_simple.value),
443
  PromptType.instruct_simple.name]:
444
+ promptB = promptA = '' if not (chat and reduced) else ''
 
445
 
446
  PreInstruct = """
447
  ### Instruction:
 
455
  ### Response:
456
  """
457
  terminate_response = None
458
+ chat_turn_sep = chat_sep = '\n'
459
  humanstr = PreInstruct
460
  botstr = PreResponse
461
+ elif prompt_type in [PromptType.openai.value, str(PromptType.openai.value),
462
+ PromptType.openai.name]:
463
+ preprompt = """The following is a conversation with an AI assistant. The assistant is helpful, creative, clever, and very friendly.""" if not (
464
+ chat and reduced) else ''
465
+ start = ''
466
+ promptB = promptA = '%s%s' % (preprompt, start)
467
+ PreInstruct = "\nHuman: "
468
+ PreInput = None
469
+ PreResponse = "\nAI:"
470
+ terminate_response = [PreResponse] + [" Human:", " AI:"]
471
+ chat_turn_sep = chat_sep = '\n'
472
+ humanstr = PreInstruct
473
+ botstr = PreResponse
474
+ elif prompt_type in [PromptType.gptj.value, str(PromptType.gptj.value),
475
+ PromptType.gptj.name]:
476
+ preprompt = "### Instruction:\n The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response." if not (
477
+ chat and reduced) else ''
478
+ start = ''
479
+ promptB = promptA = '%s%s' % (preprompt, start)
480
+ PreInstruct = "\n### Prompt: "
481
+ PreInput = None
482
+ PreResponse = "\n### Response: "
483
+ terminate_response = [PreResponse] + ["Prompt:", "Response:"]
484
+ chat_turn_sep = chat_sep = '\n'
485
+ humanstr = PreInstruct
486
+ botstr = PreResponse
487
+ elif prompt_type in [PromptType.openai_chat.value, str(PromptType.openai_chat.value),
488
+ PromptType.openai_chat.name]:
489
+ # prompting and termination all handled by endpoint
490
+ preprompt = """"""
491
+ start = ''
492
+ promptB = promptA = '%s%s' % (preprompt, start)
493
+ PreInstruct = ""
494
+ PreInput = None
495
+ PreResponse = ""
496
+ terminate_response = []
497
+ chat_turn_sep = chat_sep = '\n'
498
+ humanstr = None
499
+ botstr = None
500
+ elif prompt_type in [PromptType.vicuna11.value, str(PromptType.vicuna11.value),
501
+ PromptType.vicuna11.name]:
502
+ preprompt = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. """ if not (
503
+ chat and reduced) else ''
504
+ start = ''
505
+ promptB = promptA = '%s%s' % (preprompt, start)
506
+ eos = '</s>'
507
+ PreInstruct = """USER: """
508
+ PreInput = None
509
+ PreResponse = """ASSISTANT:"""
510
+ terminate_response = [PreResponse]
511
+ chat_sep = ' '
512
+ chat_turn_sep = eos
513
+ humanstr = PreInstruct
514
+ botstr = PreResponse
515
+
516
+ if making_context:
517
+ # when making context, want it to appear as-if LLM generated, which starts with space after :
518
+ PreResponse = PreResponse + ' '
519
+ else:
520
+ # normally the LLM adds a space after this, because that is how it was trained.
521
+ # if add space here, non-unique tokenization will often make LLM produce wrong output
522
+ PreResponse = PreResponse
523
  else:
524
  raise RuntimeError("No such prompt_type=%s" % prompt_type)
525
 
526
+ if isinstance(terminate_response, (tuple, list)):
527
+ assert '' not in terminate_response, "Bad terminate_response"
528
+
529
+ ret_dict = dict(promptA=promptA, promptB=promptB, PreInstruct=PreInstruct, PreInput=PreInput,
530
  PreResponse=PreResponse, terminate_response=terminate_response, chat_sep=chat_sep,
531
+ chat_turn_sep=chat_turn_sep,
532
+ humanstr=humanstr, botstr=botstr,
533
+ generates_leading_space=generates_leading_space)
534
+
535
+ if return_dict:
536
+ return ret_dict, prompt_dict_error
537
  else:
538
+ return tuple(list(ret_dict.values()))
539
 
540
 
541
+ def generate_prompt(data_point, prompt_type, prompt_dict, chat, reduced, making_context):
542
  context = data_point.get('context')
543
  if context is None:
544
  context = ''
 
549
  prompt_dict = data_point.get('prompt_dict', prompt_dict)
550
  assert prompt_type in prompt_types, "Bad prompt type: %s" % prompt_type
551
  promptA, promptB, PreInstruct, PreInput, PreResponse, \
552
+ terminate_response, chat_sep, chat_turn_sep, humanstr, botstr, \
553
+ generates_leading_space = get_prompt(prompt_type, prompt_dict, chat,
554
+ context, reduced, making_context)
555
 
556
+ # could avoid if reduce=True, but too complex for parent functions to handle
557
+ prompt = context
558
 
559
  if input and promptA:
560
  prompt += f"""{promptA}"""
 
563
 
564
  if instruction and PreInstruct is not None and input and PreInput is not None:
565
  prompt += f"""{PreInstruct}{instruction}{PreInput}{input}"""
566
+ prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
567
  elif instruction and input and PreInstruct is None and PreInput is not None:
568
  prompt += f"""{PreInput}{instruction}
569
  {input}"""
570
+ prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
571
  elif input and instruction and PreInput is None and PreInstruct is not None:
572
  prompt += f"""{PreInstruct}{instruction}
573
  {input}"""
574
+ prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
575
  elif instruction and PreInstruct is not None:
576
  prompt += f"""{PreInstruct}{instruction}"""
577
+ prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
578
  elif input and PreInput is not None:
579
  prompt += f"""{PreInput}{input}"""
580
+ prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
581
  elif input and instruction and PreInput is not None:
582
  prompt += f"""{PreInput}{instruction}{input}"""
583
+ prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
584
  elif input and instruction and PreInstruct is not None:
585
  prompt += f"""{PreInstruct}{instruction}{input}"""
586
+ prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
587
  elif input and instruction:
588
  # i.e. for simple_instruct
589
  prompt += f"""{instruction}: {input}"""
590
+ prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
591
  elif input:
592
  prompt += f"""{input}"""
593
+ prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
594
  elif instruction:
595
  prompt += f"""{instruction}"""
596
+ prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
597
 
598
  if PreResponse is not None:
599
  prompt += f"""{PreResponse}"""
 
604
  if output:
605
  prompt += f"""{output}"""
606
 
607
+ return prompt, pre_response, terminate_response, chat_sep, chat_turn_sep
608
 
609
 
610
+ def inject_chatsep(prompt_type, prompt, chat_sep=None):
611
+ if chat_sep:
612
  # only add new line if structured prompt, while 'plain' is just generation of next tokens from input
613
+ prompt += chat_sep
614
  return prompt
615
 
616
 
 
619
  allowed_repeat_line_length=10):
620
  self.prompt_type = prompt_type
621
  self.prompt_dict = prompt_dict
 
 
 
622
  self.debug = debug
623
  self.chat = chat
624
  self.stream_output = stream_output
 
627
  self.prompt = None
628
  context = "" # not for chat context
629
  reduced = False # not for chat context
630
+ making_context = False # not for chat context
631
  self.promptA, self.promptB, self.PreInstruct, self.PreInput, self.PreResponse, \
632
+ self.terminate_response, self.chat_sep, self.chat_turn_sep, self.humanstr, self.botstr, \
633
+ self.generates_leading_space = \
634
+ get_prompt(self.prompt_type, self.prompt_dict, chat, context, reduced, making_context)
635
+ self.pre_response = self.PreResponse
636
+
637
+ def generate_prompt(self, data_point, reduced=None):
638
+ """
639
+ data_point['context'] is assumed to be like a system prompt or pre-conversation, not inserted after user prompt
640
+ :param data_point:
641
+ :param reduced:
642
+ :return:
643
+ """
644
+ reduced = data_point.get('context') not in ['', None] if reduced is None else reduced
645
+ making_context = False # whether really making final prompt or just generating context
646
+ prompt, _, _, _, _ = generate_prompt(data_point, self.prompt_type, self.prompt_dict, self.chat, reduced,
647
+ making_context)
648
  if self.debug:
649
+ print("prompt: %s" % prompt, flush=True)
650
+ # if there is context, reduced should already be True, and we only prepend promptA/B here
651
+ if data_point.get('context'):
652
+ if data_point.get('input') and self.promptA:
653
+ prompt = self.promptA + prompt
654
+ elif self.promptB:
655
+ prompt = self.promptB + prompt
656
+
657
  self.prompt = prompt
658
  return prompt
659
 
660
+ def get_response(self, outputs, prompt=None, sanitize_bot_response=False):
661
  if isinstance(outputs, str):
662
  outputs = [outputs]
663
  if self.debug:
664
+ print("output:\n%s" % '\n\n'.join(outputs), flush=True)
665
  if prompt is not None:
666
  self.prompt = prompt
667
 
 
672
  if sanitize_bot_response:
673
  from better_profanity import profanity
674
  response = profanity.censor(response)
675
+ if self.generates_leading_space and isinstance(response, str) and len(response) > 0 and response[0] == ' ':
676
+ response = response[1:]
677
  return response
678
 
679
  def clean_repeats(response):
 
695
  # then use most basic parsing like pipeline
696
  if self.botstr in output:
697
  if self.humanstr:
698
+ output = clean_response(output.split(self.botstr)[1].split(self.humanstr)[0])
699
  else:
700
  # i.e. use after bot but only up to next bot
701
+ output = clean_response(output.split(self.botstr)[1].split(self.botstr)[0])
702
  else:
703
+ # output = clean_response(output)
704
  # assume just not printed yet
705
  output = ""
706
  else:
 
727
  allow_terminate = True
728
  output = output[len(prompt):]
729
  # clean after subtract prompt out, so correct removal of pre_response
730
+ output = clean_response(output)
731
  if self.repeat_penalty:
732
+ output = clean_repeats(output)
733
  if self.terminate_response and allow_terminate:
734
  finds = []
735
  for term in self.terminate_response:
 
737
  finds = [x for x in finds if x >= 0]
738
  if len(finds) > 0:
739
  termi = finds[0]
740
+ output = output[:termi]
741
  else:
742
+ output = output
 
 
743
  if multi_output:
744
  # prefix with output counter
745
  output = "\n=========== Output %d\n\n" % (1 + oi) + output
 
750
  # join all outputs, only one extra new line between outputs
751
  output = '\n'.join(outputs)
752
  if self.debug:
753
+ print("outputclean:\n%s" % '\n\n'.join(outputs), flush=True)
754
  return output
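A hedged end-to-end sketch of the reworked Prompter (the prompt_type, constructor keywords, and the fake model output are assumptions for illustration only):

prompter = Prompter('human_bot', None, chat=True, stream_output=False)
data_point = dict(instruction='Who are you?', input='', output='', context='')
prompt = prompter.generate_prompt(data_point)

# pretend the model echoed the prompt and appended an answer
raw_output = prompt + " I am h2oGPT."
response = prompter.get_response(raw_output, prompt=prompt, sanitize_bot_response=False)
# response should be the bot's answer with the prompt, any terminate_response tags,
# and the leading space (generates_leading_space for human_bot) stripped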
requirements.txt CHANGED
@@ -1,50 +1,50 @@
1
  # for generate (gradio server) and finetune
2
- datasets==2.12.0
3
- sentencepiece==0.1.97
4
- gradio==3.34.0
5
- huggingface_hub==0.14.1
6
  appdirs==1.4.4
7
  fire==0.5.0
8
- docutils==0.19
9
  torch==2.0.1
10
  evaluate==0.4.0
11
  rouge_score==0.1.2
12
  sacrebleu==2.3.1
13
  scikit-learn==1.2.2
14
  alt-profanity-check==1.2.2
15
- better-profanity==0.6.1
16
- numpy==1.24.2
17
- pandas==2.0.0
18
  matplotlib==3.7.1
19
  loralib==0.1.1
20
  bitsandbytes==0.39.0
21
- accelerate==0.19.0
22
- git+https://github.com/huggingface/peft.git@3714aa2fff158fdfa637b2b65952580801d890b2
23
- transformers==4.28.1
24
  tokenizers==0.13.3
25
  APScheduler==3.10.1
26
 
27
  # optional for generate
28
  pynvml==11.5.0
29
- psutil==5.9.4
30
  boto3==1.26.101
31
  botocore==1.29.101
32
 
33
  # optional for finetune
34
- tensorboard==2.12.1
35
- neptune==1.1.1
36
 
37
  # for gradio client
38
- gradio_client==0.2.6
39
  beautifulsoup4==4.12.2
40
- markdown==3.4.1
41
 
42
  # data and testing
43
  pytest==7.2.2
44
  pytest-xdist==3.2.1
45
  nltk==3.8.1
46
  textstat==0.7.3
47
- pandoc==2.3
48
  #pypandoc==1.11
49
  pypandoc_binary==1.11
50
  openpyxl==3.1.2
@@ -57,17 +57,20 @@ instructorembedding==1.0.1
57
 
58
  # for gpt4all .env file, but avoid worrying about imports
59
  python-dotenv==1.0.0
 
 
 
 
 
 
60
  # optional for chat with PDF
61
- langchain==0.0.193
62
- pypdf==3.8.1
63
- tiktoken==0.3.3
64
  # avoid textract, requires old six
65
  #textract==1.6.5
66
 
67
  # for HF embeddings
68
  sentence_transformers==2.2.2
69
- # for OpenAI embeddings (requires key)
70
- openai==0.27.6
71
 
72
  # local vector db
73
  chromadb==0.3.25
@@ -79,14 +82,14 @@ chromadb==0.3.25
79
 
80
  # strong support for images
81
  # Requires on Ubuntu: sudo apt-get install libmagic-dev poppler-utils tesseract-ocr libreoffice
82
- unstructured[local-inference]==0.6.6
83
  #pdf2image==1.16.3
84
  #pytesseract==0.3.10
85
  pillow
86
 
87
  pdfminer.six==20221105
88
- urllib3==1.26.6
89
- requests_file==1.5.1
90
 
91
  #pdf2image==1.16.3
92
  #pytesseract==0.3.10
@@ -101,18 +104,15 @@ tabulate==0.9.0
101
  pip-licenses==4.3.0
102
 
103
  # weaviate vector db
104
- weaviate-client==3.19.2
105
  # optional for chat with PDF
106
- langchain==0.0.193
107
- pypdf==3.8.1
108
- tiktoken==0.3.3
109
  # avoid textract, requires old six
110
  #textract==1.6.5
111
 
112
  # for HF embeddings
113
  sentence_transformers==2.2.2
114
- # for OpenAI embeddings (requires key)
115
- openai==0.27.6
116
 
117
  # local vector db
118
  chromadb==0.3.25
@@ -124,14 +124,14 @@ chromadb==0.3.25
124
 
125
  # strong support for images
126
  # Requires on Ubuntu: sudo apt-get install libmagic-dev poppler-utils tesseract-ocr libreoffice
127
- unstructured[local-inference]==0.6.6
128
  #pdf2image==1.16.3
129
  #pytesseract==0.3.10
130
  pillow
131
 
132
  pdfminer.six==20221105
133
- urllib3==1.26.6
134
- requests_file==1.5.1
135
 
136
  #pdf2image==1.16.3
137
  #pytesseract==0.3.10
@@ -146,7 +146,7 @@ tabulate==0.9.0
146
  pip-licenses==4.3.0
147
 
148
  # weaviate vector db
149
- weaviate-client==3.19.2
150
  faiss-gpu==1.7.2
151
  arxiv==1.4.7
152
  pymupdf==1.22.3 # AGPL license
 
1
  # for generate (gradio server) and finetune
2
+ datasets==2.13.0
3
+ sentencepiece==0.1.99
4
+ gradio==3.35.2
5
+ huggingface_hub==0.15.1
6
  appdirs==1.4.4
7
  fire==0.5.0
8
+ docutils==0.20.1
9
  torch==2.0.1
10
  evaluate==0.4.0
11
  rouge_score==0.1.2
12
  sacrebleu==2.3.1
13
  scikit-learn==1.2.2
14
  alt-profanity-check==1.2.2
15
+ better-profanity==0.7.0
16
+ numpy==1.24.3
17
+ pandas==2.0.2
18
  matplotlib==3.7.1
19
  loralib==0.1.1
20
  bitsandbytes==0.39.0
21
+ accelerate==0.20.3
22
+ git+https://github.com/huggingface/peft.git@0b62b4378b4ce9367932c73540349da9a41bdea8
23
+ transformers==4.30.2
24
  tokenizers==0.13.3
25
  APScheduler==3.10.1
26
 
27
  # optional for generate
28
  pynvml==11.5.0
29
+ psutil==5.9.5
30
  boto3==1.26.101
31
  botocore==1.29.101
32
 
33
  # optional for finetune
34
+ tensorboard==2.13.0
35
+ neptune==1.2.0
36
 
37
  # for gradio client
38
+ gradio_client==0.2.7
39
  beautifulsoup4==4.12.2
40
+ markdown==3.4.3
41
 
42
  # data and testing
43
  pytest==7.2.2
44
  pytest-xdist==3.2.1
45
  nltk==3.8.1
46
  textstat==0.7.3
47
+ # pandoc==2.3
48
  #pypandoc==1.11
49
  pypandoc_binary==1.11
50
  openpyxl==3.1.2
 
57
 
58
  # for gpt4all .env file, but avoid worrying about imports
59
  python-dotenv==1.0.0
60
+
61
+ text-generation==0.6.0
62
+ # for tokenization when don't have HF tokenizer
63
+ tiktoken==0.4.0
64
+ # optional: for OpenAI endpoint or embeddings (requires key)
65
+ openai==0.27.8
66
  # optional for chat with PDF
67
+ langchain==0.0.202
68
+ pypdf==3.9.1
 
69
  # avoid textract, requires old six
70
  #textract==1.6.5
71
 
72
  # for HF embeddings
73
  sentence_transformers==2.2.2
 
 
74
 
75
  # local vector db
76
  chromadb==0.3.25
 
82
 
83
  # strong support for images
84
  # Requires on Ubuntu: sudo apt-get install libmagic-dev poppler-utils tesseract-ocr libreoffice
85
+ unstructured[local-inference]==0.7.4
86
  #pdf2image==1.16.3
87
  #pytesseract==0.3.10
88
  pillow
89
 
90
  pdfminer.six==20221105
91
+ urllib3
92
+ requests_file
93
 
94
  #pdf2image==1.16.3
95
  #pytesseract==0.3.10
 
104
  pip-licenses==4.3.0
105
 
106
  # weaviate vector db
107
+ weaviate-client==3.20.0
108
  # optional for chat with PDF
109
+ langchain==0.0.202
110
+ pypdf==3.9.1
 
111
  # avoid textract, requires old six
112
  #textract==1.6.5
113
 
114
  # for HF embeddings
115
  sentence_transformers==2.2.2
 
 
116
 
117
  # local vector db
118
  chromadb==0.3.25
 
124
 
125
  # strong support for images
126
  # Requires on Ubuntu: sudo apt-get install libmagic-dev poppler-utils tesseract-ocr libreoffice
127
+ unstructured[local-inference]==0.7.4
128
  #pdf2image==1.16.3
129
  #pytesseract==0.3.10
130
  pillow
131
 
132
  pdfminer.six==20221105
133
+ urllib3
134
+ requests_file
135
 
136
  #pdf2image==1.16.3
137
  #pytesseract==0.3.10
 
146
  pip-licenses==4.3.0
147
 
148
  # weaviate vector db
149
+ weaviate-client==3.20.0
150
  faiss-gpu==1.7.2
151
  arxiv==1.4.7
152
  pymupdf==1.22.3 # AGPL license
stopping.py CHANGED
@@ -1,17 +1,18 @@
1
  import torch
2
  from transformers import StoppingCriteria, StoppingCriteriaList
3
 
4
- from prompter import PromptType
5
 
6
 
7
  class StoppingCriteriaSub(StoppingCriteria):
8
 
9
- def __init__(self, stops=[], encounters=[], device="cuda"):
10
  super().__init__()
11
  assert len(stops) % len(encounters) == 0, "Number of stops and encounters must match"
12
  self.encounters = encounters
13
  self.stops = [stop.to(device) for stop in stops]
14
  self.num_stops = [0] * len(stops)
 
15
 
16
  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
17
  for stopi, stop in enumerate(self.stops):
@@ -20,12 +21,15 @@ class StoppingCriteriaSub(StoppingCriteria):
20
  if self.num_stops[stopi] >= self.encounters[stopi % len(self.encounters)]:
21
  # print("Stopped", flush=True)
22
  return True
 
 
 
23
  # print("Tokens: %s" % input_ids[0].cpu().numpy(), flush=True)
24
  # print("Stop Tokens: %s" % [x.cpu().numpy() for x in self.stops], flush=True)
25
  return False
26
 
27
 
28
- def get_stopping(prompt_type, prompt_dict, tokenizer, device, human='<human>:', bot="<bot>:"):
29
  # FIXME: prompt_dict unused currently
30
  if prompt_type in [PromptType.human_bot.name, PromptType.instruct_vicuna.name, PromptType.instruct_with_end.name]:
31
  if prompt_type == PromptType.human_bot.name:
@@ -67,7 +71,8 @@ def get_stopping(prompt_type, prompt_dict, tokenizer, device, human='<human>:',
67
  stop_words_ids = [x[1:] if y[0] == '\n' else x for x, y in zip(stop_words_ids, stop_words)]
68
  # build stopper
69
  stopping_criteria = StoppingCriteriaList(
70
- [StoppingCriteriaSub(stops=stop_words_ids, encounters=encounters, device=device)])
 
71
  else:
72
  stopping_criteria = StoppingCriteriaList()
73
  return stopping_criteria
 
1
  import torch
2
  from transformers import StoppingCriteria, StoppingCriteriaList
3
 
4
+ from enums import PromptType
5
 
6
 
7
  class StoppingCriteriaSub(StoppingCriteria):
8
 
9
+ def __init__(self, stops=[], encounters=[], device="cuda", model_max_length=None):
10
  super().__init__()
11
  assert len(stops) % len(encounters) == 0, "Number of stops and encounters must match"
12
  self.encounters = encounters
13
  self.stops = [stop.to(device) for stop in stops]
14
  self.num_stops = [0] * len(stops)
15
+ self.model_max_length = model_max_length
16
 
17
  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
18
  for stopi, stop in enumerate(self.stops):
 
21
  if self.num_stops[stopi] >= self.encounters[stopi % len(self.encounters)]:
22
  # print("Stopped", flush=True)
23
  return True
24
+ if self.model_max_length is not None and input_ids[0].shape[0] >= self.model_max_length:
25
+ # critical limit
26
+ return True
27
  # print("Tokens: %s" % input_ids[0].cpu().numpy(), flush=True)
28
  # print("Stop Tokens: %s" % [x.cpu().numpy() for x in self.stops], flush=True)
29
  return False
30
 
31
 
32
+ def get_stopping(prompt_type, prompt_dict, tokenizer, device, human='<human>:', bot="<bot>:", model_max_length=None):
33
  # FIXME: prompt_dict unused currently
34
  if prompt_type in [PromptType.human_bot.name, PromptType.instruct_vicuna.name, PromptType.instruct_with_end.name]:
35
  if prompt_type == PromptType.human_bot.name:
 
71
  stop_words_ids = [x[1:] if y[0] == '\n' else x for x, y in zip(stop_words_ids, stop_words)]
72
  # build stopper
73
  stopping_criteria = StoppingCriteriaList(
74
+ [StoppingCriteriaSub(stops=stop_words_ids, encounters=encounters, device=device,
75
+ model_max_length=model_max_length)])
76
  else:
77
  stopping_criteria = StoppingCriteriaList()
78
  return stopping_criteria
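A hedged sketch of wiring the new model_max_length guard into generation (tokenizer, model, and inputs are placeholders assumed to be prepared elsewhere):

stopping_criteria = get_stopping('human_bot', None, tokenizer, 'cuda',
                                 model_max_length=getattr(tokenizer, 'model_max_length', None))
outputs = model.generate(**inputs,
                         max_new_tokens=256,
                         stopping_criteria=stopping_criteria)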
utils.py CHANGED
@@ -69,6 +69,25 @@ def ping():
69
  pass
70
 
71
 
 
 
72
  def get_torch_allocated():
73
  import torch
74
  return torch.cuda.memory_allocated()
@@ -98,27 +117,29 @@ def system_info():
98
  system['CPU_C/%s' % k] = v
99
 
100
  # https://github.com/gpuopenanalytics/pynvml/blob/master/help_query_gpu.txt
101
- from pynvml.smi import nvidia_smi
102
- nvsmi = nvidia_smi.getInstance()
103
-
104
- gpu_power_dict = {'W_gpu%d' % i: x['power_readings']['power_draw'] for i, x in
105
- enumerate(nvsmi.DeviceQuery('power.draw')['gpu'])}
106
- for k, v in gpu_power_dict.items():
107
- system['GPU_W/%s' % k] = v
108
-
109
- gpu_temp_dict = {'C_gpu%d' % i: x['temperature']['gpu_temp'] for i, x in
110
- enumerate(nvsmi.DeviceQuery('temperature.gpu')['gpu'])}
111
- for k, v in gpu_temp_dict.items():
112
- system['GPU_C/%s' % k] = v
113
-
114
- gpu_memory_free_dict = {'MiB_gpu%d' % i: x['fb_memory_usage']['free'] for i, x in
115
- enumerate(nvsmi.DeviceQuery('memory.free')['gpu'])}
116
- gpu_memory_total_dict = {'MiB_gpu%d' % i: x['fb_memory_usage']['total'] for i, x in
117
- enumerate(nvsmi.DeviceQuery('memory.total')['gpu'])}
118
- gpu_memory_frac_dict = {k: gpu_memory_free_dict[k] / gpu_memory_total_dict[k] for k in gpu_memory_total_dict}
119
- for k, v in gpu_memory_frac_dict.items():
120
- system[f'GPU_M/%s' % k] = v
121
-
 
 
122
  system['hash'] = get_githash()
123
 
124
  return system
@@ -167,35 +188,39 @@ def _zip_data(root_dirs=None, zip_file=None, base_dir='./'):
167
  return zip_file, zip_file
168
 
169
 
170
- def save_generate_output(output=None, base_model=None, save_dir=None):
 
171
  try:
172
- return _save_generate_output(output=output, base_model=base_model, save_dir=save_dir)
 
173
  except Exception as e:
174
  traceback.print_exc()
175
  print('Exception in saving: %s' % str(e))
176
 
177
 
178
- def _save_generate_output(output=None, base_model=None, save_dir=None):
 
179
  """
180
  Save conversation to .json, row by row.
181
  json_file_path is path to final JSON file. If not in ., then will attempt to make directories.
182
  Appends if file exists
183
  """
 
 
184
  assert save_dir, "save_dir must be provided"
185
  if os.path.exists(save_dir) and not os.path.isdir(save_dir):
186
  raise RuntimeError("save_dir already exists and is not a directory!")
187
  os.makedirs(save_dir, exist_ok=True)
188
  import json
189
- if output[-10:] == '\n\n<human>:':
190
- # remove trailing <human>:
191
- output = output[:-10]
192
  with filelock.FileLock("save_dir.lock"):
193
  # lock logging in case have concurrency
194
  with open(os.path.join(save_dir, "history.json"), "a") as f:
195
  # just add [ at start, and ] at end, and have proper JSON dataset
196
  f.write(
197
  " " + json.dumps(
198
- dict(text=output, time=time.ctime(), base_model=base_model)
199
  ) + ",\n"
200
  )
201
 
@@ -801,6 +826,7 @@ def get_kwargs(func, exclude_names=None, **kwargs):
801
 
802
 
803
  import pkg_resources
 
804
  have_faiss = False
805
 
806
  try:
@@ -828,7 +854,7 @@ def hash_file(file):
828
  BUF_SIZE = 65536 # lets read stuff in 64kb chunks!
829
 
830
  md5 = hashlib.md5()
831
- #sha1 = hashlib.sha1()
832
 
833
  with open(file, 'rb') as f:
834
  while True:
@@ -836,7 +862,7 @@ def hash_file(file):
836
  if not data:
837
  break
838
  md5.update(data)
839
- #sha1.update(data)
840
  except BaseException as e:
841
  print("Cannot hash %s due to %s" % (file, str(e)))
842
  traceback.print_exc()
@@ -848,8 +874,55 @@ def start_faulthandler():
848
 # If the server or any subprocess is hit with signal SIGUSR1, it'll print out all threads' stack traces, but won't quit or coredump
849
  # If more than one fork tries to write at same time, then looks corrupted.
850
  import faulthandler
851
- import signal
852
 
853
  # SIGUSR1 in h2oai/__init__.py as well
854
  faulthandler.enable()
855
- faulthandler.register(signal.SIGUSR1)

69
  pass
70
 
71
 
72
+ def ping_gpu():
73
+ try:
74
+ print('Ping_GPU: %s %s' % (str(datetime.now()), system_info()), flush=True)
75
+ except AttributeError:
76
+ # some programs wrap print and will fail with flush passed
77
+ pass
78
+ try:
79
+ ping_gpu_memory()
80
+ except Exception as e:
81
+ print('Ping_GPU memory failure: %s' % str(e), flush=True)
82
+
83
+
84
+ def ping_gpu_memory():
85
+ from models.gpu_mem_track import MemTracker
86
+ gpu_tracker = MemTracker() # define a GPU tracker
87
+ from torch.cuda import memory_summary
88
+ gpu_tracker.track()
89
+
90
+
91
  def get_torch_allocated():
92
  import torch
93
  return torch.cuda.memory_allocated()
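
A minimal sketch of how ping_gpu() above could be driven on a schedule; the 60-second interval and the daemon-thread loop are assumptions for illustration, not part of this commit.

import threading
import time

def start_gpu_ping(interval_s=60):
    # Hypothetical helper: log system_info() and GPU memory via ping_gpu() every interval_s seconds.
    def _loop():
        while True:
            ping_gpu()
            time.sleep(interval_s)
    t = threading.Thread(target=_loop, daemon=True)
    t.start()
    return t
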
 
117
  system['CPU_C/%s' % k] = v
118
 
119
  # https://github.com/gpuopenanalytics/pynvml/blob/master/help_query_gpu.txt
120
+ try:
121
+ from pynvml.smi import nvidia_smi
122
+ nvsmi = nvidia_smi.getInstance()
123
+
124
+ gpu_power_dict = {'W_gpu%d' % i: x['power_readings']['power_draw'] for i, x in
125
+ enumerate(nvsmi.DeviceQuery('power.draw')['gpu'])}
126
+ for k, v in gpu_power_dict.items():
127
+ system['GPU_W/%s' % k] = v
128
+
129
+ gpu_temp_dict = {'C_gpu%d' % i: x['temperature']['gpu_temp'] for i, x in
130
+ enumerate(nvsmi.DeviceQuery('temperature.gpu')['gpu'])}
131
+ for k, v in gpu_temp_dict.items():
132
+ system['GPU_C/%s' % k] = v
133
+
134
+ gpu_memory_free_dict = {'MiB_gpu%d' % i: x['fb_memory_usage']['free'] for i, x in
135
+ enumerate(nvsmi.DeviceQuery('memory.free')['gpu'])}
136
+ gpu_memory_total_dict = {'MiB_gpu%d' % i: x['fb_memory_usage']['total'] for i, x in
137
+ enumerate(nvsmi.DeviceQuery('memory.total')['gpu'])}
138
+ gpu_memory_frac_dict = {k: gpu_memory_free_dict[k] / gpu_memory_total_dict[k] for k in gpu_memory_total_dict}
139
+ for k, v in gpu_memory_frac_dict.items():
140
+ system[f'GPU_M/%s' % k] = v
141
+ except ModuleNotFoundError:
142
+ pass
143
  system['hash'] = get_githash()
144
 
145
  return system
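
For reference, the pynvml.smi calls used above can be exercised on their own; this standalone sketch mirrors the exact keys queried in system_info() and assumes the pynvml package and an NVIDIA driver are present.

from pynvml.smi import nvidia_smi

nvsmi = nvidia_smi.getInstance()
power = nvsmi.DeviceQuery('power.draw')['gpu']
temp = nvsmi.DeviceQuery('temperature.gpu')['gpu']
free = nvsmi.DeviceQuery('memory.free')['gpu']
total = nvsmi.DeviceQuery('memory.total')['gpu']
for i in range(len(power)):
    # Same nested keys as used in system_info() above.
    print('gpu%d: %sW %sC %s/%s MiB free' % (
        i,
        power[i]['power_readings']['power_draw'],
        temp[i]['temperature']['gpu_temp'],
        free[i]['fb_memory_usage']['free'],
        total[i]['fb_memory_usage']['total']))
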
 
188
  return zip_file, zip_file
189
 
190
 
191
+ def save_generate_output(prompt=None, output=None, base_model=None, save_dir=None, where_from='unknown where from',
192
+ extra_dict={}):
193
  try:
194
+ return _save_generate_output(prompt=prompt, output=output, base_model=base_model, save_dir=save_dir,
195
+ where_from=where_from, extra_dict=extra_dict)
196
  except Exception as e:
197
  traceback.print_exc()
198
  print('Exception in saving: %s' % str(e))
199
 
200
 
201
+ def _save_generate_output(prompt=None, output=None, base_model=None, save_dir=None, where_from='unknown where from',
202
+ extra_dict={}):
203
  """
204
  Save conversation to .json, row by row.
205
  json_file_path is path to final JSON file. If not in ., then will attempt to make directories.
206
  Appends if file exists
207
  """
208
+ prompt = '<not set>' if prompt is None else prompt
209
+ output = '<not set>' if output is None else output
210
  assert save_dir, "save_dir must be provided"
211
  if os.path.exists(save_dir) and not os.path.isdir(save_dir):
212
  raise RuntimeError("save_dir already exists and is not a directory!")
213
  os.makedirs(save_dir, exist_ok=True)
214
  import json
215
+ dict_to_save = dict(prompt=prompt, text=output, time=time.ctime(), base_model=base_model, where_from=where_from)
216
+ dict_to_save.update(extra_dict)
 
217
  with filelock.FileLock("save_dir.lock"):
218
  # lock logging in case have concurrency
219
  with open(os.path.join(save_dir, "history.json"), "a") as f:
220
  # just add [ at start, and ] at end, and have proper JSON dataset
221
  f.write(
222
  " " + json.dumps(
223
+ dict_to_save
224
  ) + ",\n"
225
  )
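
Since each row is written comma-terminated, the log can be turned back into valid JSON by wrapping it in brackets, as the comment above suggests; this reader is a sketch, not part of the commit.

import json
import os

def load_history(save_dir):
    # Hypothetical reader for the row-by-row history.json written by _save_generate_output().
    with open(os.path.join(save_dir, "history.json")) as f:
        rows = f.read().rstrip().rstrip(',')
    return json.loads('[' + rows + ']')
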
226
 
 
826
 
827
 
828
  import pkg_resources
829
+
830
  have_faiss = False
831
 
832
  try:
 
854
  BUF_SIZE = 65536  # let's read stuff in 64KB chunks
855
 
856
  md5 = hashlib.md5()
857
+ # sha1 = hashlib.sha1()
858
 
859
  with open(file, 'rb') as f:
860
  while True:
 
862
  if not data:
863
  break
864
  md5.update(data)
865
+ # sha1.update(data)
866
  except BaseException as e:
867
  print("Cannot hash %s due to %s" % (file, str(e)))
868
  traceback.print_exc()
 
874
  # If hit server or any subprocess with signal SIGUSR1, it'll print out all threads' stack traces, but won't quit or coredump
875
  # If more than one fork tries to write at same time, then looks corrupted.
876
  import faulthandler
 
877
 
878
  # SIGUSR1 in h2oai/__init__.py as well
879
  faulthandler.enable()
880
+ if hasattr(faulthandler, 'register'):
881
+ # faulthandler.register is not available on all platforms (e.g. Windows)
882
+ import signal
883
+ faulthandler.register(signal.SIGUSR1)
884
+
885
+
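
With SIGUSR1 registered as above (on platforms where faulthandler.register exists), all thread stacks can be dumped without stopping the server. A minimal, Unix-only way to trigger the dump from inside the process, as a sketch:

import os
import signal

# Sends SIGUSR1 to this process; faulthandler prints every thread's stack to stderr
# and the process keeps running (per the comment above).
os.kill(os.getpid(), signal.SIGUSR1)
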
886
+ def get_hf_server(inference_server):
887
+ inf_split = inference_server.split(" ")
888
+ assert len(inf_split) == 1 or len(inf_split) == 3
889
+ inference_server = inf_split[0]
890
+ if len(inf_split) == 3:
891
+ headers = {"authorization": "%s %s" % (inf_split[1], inf_split[2])}
892
+ else:
893
+ headers = None
894
+ return inference_server, headers
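
get_hf_server() accepts either a bare endpoint or an endpoint followed by an auth scheme and token; a small usage sketch with placeholder values:

# Bare endpoint: no auth headers.
server, headers = get_hf_server("http://192.168.1.1:8080")
assert headers is None

# "<endpoint> <scheme> <token>" becomes an authorization header.
server, headers = get_hf_server("https://example.com Bearer MY_TOKEN")
assert server == "https://example.com"
assert headers == {"authorization": "Bearer MY_TOKEN"}
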
895
+
896
+
897
+ class FakeTokenizer:
898
+ """
899
+ 1) For keeping track of model_max_length
900
+ 2) For when model doesn't directly expose tokenizer but need to count tokens
901
+ """
902
+
903
+ def __init__(self, model_max_length=2048, encoding_name="cl100k_base"):
904
+ # don't push the limit: when using the fake tokenizer this is only an estimate, and underestimates of around 250 tokens have been seen
905
+ self.model_max_length = model_max_length - 250
906
+ self.encoding_name = encoding_name
907
+ # The first time this runs, it will require an internet connection to download. Later runs won't need an internet connection.
908
+ import tiktoken
909
+ self.encoding = tiktoken.get_encoding(self.encoding_name)
910
+
911
+ def encode(self, x, *args, return_tensors="pt", **kwargs):
912
+ input_ids = self.encoding.encode(x, disallowed_special=())
913
+ if return_tensors == 'pt' and isinstance(input_ids, list):
914
+ import torch
915
+ input_ids = torch.tensor(input_ids)
916
+ return dict(input_ids=input_ids)
917
+
918
+ def decode(self, x, *args, **kwargs):
919
+ # input is input_ids[0] form
920
+ return self.encoding.decode(x)
921
+
922
+ def num_tokens_from_string(self, prompt: str) -> int:
923
+ """Returns the number of tokens in a text string."""
924
+ num_tokens = len(self.encoding.encode(prompt))
925
+ return num_tokens
926
+
927
+ def __call__(self, x, *args, **kwargs):
928
+ return self.encode(x, *args, **kwargs)
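
A usage sketch for FakeTokenizer; it assumes tiktoken (and torch, for the default tensor output) are installed, and the prompt is arbitrary:

tokenizer = FakeTokenizer(model_max_length=2048)
n = tokenizer.num_tokens_from_string("Who are you?")
ids = tokenizer("Who are you?")['input_ids']   # 1-D torch tensor, since return_tensors defaults to "pt"
text = tokenizer.decode(ids.tolist())          # round-trips back to the original string
assert n == len(ids)
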
utils_langchain.py ADDED
@@ -0,0 +1,64 @@
1
+ from typing import Any, Dict, List, Union, Optional
2
+ import time
3
+ import queue
4
+
5
+ from langchain.callbacks.base import BaseCallbackHandler
6
+ from langchain.schema import LLMResult
7
+
8
+
9
+ class StreamingGradioCallbackHandler(BaseCallbackHandler):
10
+ """
11
+ Similar to H2OTextIteratorStreamer that is for HF backend, but here LangChain backend
12
+ """
13
+ def __init__(self, timeout: Optional[float] = None, block=True):
14
+ super().__init__()
15
+ self.text_queue = queue.SimpleQueue()
16
+ self.stop_signal = None
17
+ self.do_stop = False
18
+ self.timeout = timeout
19
+ self.block = block
20
+
21
+ def on_llm_start(
22
+ self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
23
+ ) -> None:
24
+ """Run when LLM starts running. Clean the queue."""
25
+ while not self.text_queue.empty():
26
+ try:
27
+ self.text_queue.get(block=False)
28
+ except queue.Empty:
29
+ continue
30
+
31
+ def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
32
+ """Run on new LLM token. Only available when streaming is enabled."""
33
+ self.text_queue.put(token)
34
+
35
+ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
36
+ """Run when LLM ends running."""
37
+ self.text_queue.put(self.stop_signal)
38
+
39
+ def on_llm_error(
40
+ self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
41
+ ) -> None:
42
+ """Run when LLM errors."""
43
+ self.text_queue.put(self.stop_signal)
44
+
45
+ def __iter__(self):
46
+ return self
47
+
48
+ def __next__(self):
49
+ while True:
50
+ try:
51
+ value = self.stop_signal  # value looks unused in PyCharm, but it is compared against below
52
+ if self.do_stop:
53
+ print("hit stop", flush=True)
54
+ # could raise or break; raising is probably best so the parent sees any exception from the thread
55
+ raise StopIteration()
56
+ # break
57
+ value = self.text_queue.get(block=self.block, timeout=self.timeout)
58
+ break
59
+ except queue.Empty:
60
+ time.sleep(0.01)
61
+ if value == self.stop_signal:
62
+ raise StopIteration()
63
+ else:
64
+ return value
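
A sketch of how StreamingGradioCallbackHandler might be wired up: run the LLM call on a worker thread with the handler in its callbacks, then iterate the handler for tokens. The llm object, its streaming support, and the callbacks argument are assumptions for illustration, not verified against this repo.

import threading

handler = StreamingGradioCallbackHandler(timeout=None, block=True)

def _worker():
    # Assumed: `llm` is a LangChain LLM constructed with streaming=True.
    llm("Tell me a joke.", callbacks=[handler])

threading.Thread(target=_worker, daemon=True).start()

for token in handler:   # on_llm_end() pushes stop_signal, which ends iteration
    print(token, end='', flush=True)
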