pseudotensor commited on
Commit
2ce9a1a
1 Parent(s): b368114

Update with h2oGPT hash ad9d685b188cece0b9c69716ea8e320b74f0caf7

Browse files
client_test.py CHANGED
@@ -48,7 +48,7 @@ import markdown # pip install markdown
48
  import pytest
49
  from bs4 import BeautifulSoup # pip install beautifulsoup4
50
 
51
- from enums import DocumentChoices, LangChainAction
52
 
53
  debug = False
54
 
@@ -68,6 +68,7 @@ def get_args(prompt, prompt_type, chat=False, stream_output=False,
68
  max_new_tokens=50,
69
  top_k_docs=3,
70
  langchain_mode='Disabled',
 
71
  langchain_action=LangChainAction.QUERY.value,
72
  langchain_agents=[],
73
  prompt_dict=None):
@@ -95,12 +96,13 @@ def get_args(prompt, prompt_type, chat=False, stream_output=False,
95
  instruction_nochat=prompt if not chat else '',
96
  iinput_nochat='', # only for chat=False
97
  langchain_mode=langchain_mode,
 
98
  langchain_action=langchain_action,
99
  langchain_agents=langchain_agents,
100
  top_k_docs=top_k_docs,
101
  chunk=True,
102
  chunk_size=512,
103
- document_subset=DocumentChoices.Relevant.name,
104
  document_choice=[],
105
  )
106
  from evaluate_params import eval_func_param_names
@@ -204,10 +206,11 @@ def run_client_nochat_api_lean_morestuff(prompt, prompt_type='human_bot', max_ne
204
  instruction_nochat=prompt,
205
  iinput_nochat='',
206
  langchain_mode='Disabled',
 
207
  langchain_action=LangChainAction.QUERY.value,
208
  langchain_agents=[],
209
  top_k_docs=4,
210
- document_subset=DocumentChoices.Relevant.name,
211
  document_choice=[],
212
  )
213
 
 
48
  import pytest
49
  from bs4 import BeautifulSoup # pip install beautifulsoup4
50
 
51
+ from enums import DocumentSubset, LangChainAction
52
 
53
  debug = False
54
 
 
68
  max_new_tokens=50,
69
  top_k_docs=3,
70
  langchain_mode='Disabled',
71
+ add_chat_history_to_context=True,
72
  langchain_action=LangChainAction.QUERY.value,
73
  langchain_agents=[],
74
  prompt_dict=None):
 
96
  instruction_nochat=prompt if not chat else '',
97
  iinput_nochat='', # only for chat=False
98
  langchain_mode=langchain_mode,
99
+ add_chat_history_to_context=add_chat_history_to_context,
100
  langchain_action=langchain_action,
101
  langchain_agents=langchain_agents,
102
  top_k_docs=top_k_docs,
103
  chunk=True,
104
  chunk_size=512,
105
+ document_subset=DocumentSubset.Relevant.name,
106
  document_choice=[],
107
  )
108
  from evaluate_params import eval_func_param_names
 
206
  instruction_nochat=prompt,
207
  iinput_nochat='',
208
  langchain_mode='Disabled',
209
+ add_chat_history_to_context=True,
210
  langchain_action=LangChainAction.QUERY.value,
211
  langchain_agents=[],
212
  top_k_docs=4,
213
+ document_subset=DocumentSubset.Relevant.name,
214
  document_choice=[],
215
  )
216
 
enums.py CHANGED
@@ -32,25 +32,29 @@ class PromptType(Enum):
32
  mptchat = 26
33
  falcon = 27
34
  guanaco = 28
 
35
 
36
 
37
- class DocumentChoices(Enum):
38
  Relevant = 0
39
- Sources = 1
40
- All = 2
41
 
42
 
43
  non_query_commands = [
44
- DocumentChoices.Sources.name,
45
- DocumentChoices.All.name
46
  ]
47
 
48
 
 
 
 
 
49
  class LangChainMode(Enum):
50
  """LangChain mode"""
51
 
52
  DISABLED = "Disabled"
53
- CHAT_LLM = "ChatLLM"
54
  LLM = "LLM"
55
  ALL = "All"
56
  WIKI = "wiki"
@@ -61,6 +65,12 @@ class LangChainMode(Enum):
61
  H2O_DAI_DOCS = "DriverlessAI docs"
62
 
63
 
 
 
 
 
 
 
64
  class LangChainAction(Enum):
65
  """LangChain action"""
66
 
 
32
  mptchat = 26
33
  falcon = 27
34
  guanaco = 28
35
+ llama2 = 29
36
 
37
 
38
+ class DocumentSubset(Enum):
39
  Relevant = 0
40
+ RelSources = 1
41
+ TopKSources = 2
42
 
43
 
44
  non_query_commands = [
45
+ DocumentSubset.RelSources.name,
46
+ DocumentSubset.TopKSources.name
47
  ]
48
 
49
 
50
+ class DocumentChoice(Enum):
51
+ ALL = 'All'
52
+
53
+
54
  class LangChainMode(Enum):
55
  """LangChain mode"""
56
 
57
  DISABLED = "Disabled"
 
58
  LLM = "LLM"
59
  ALL = "All"
60
  WIKI = "wiki"
 
65
  H2O_DAI_DOCS = "DriverlessAI docs"
66
 
67
 
68
+ # modes should not be removed from visible list or added by name
69
+ langchain_modes_intrinsic = [LangChainMode.DISABLED.value,
70
+ LangChainMode.LLM.value,
71
+ LangChainMode.MY_DATA.value]
72
+
73
+
74
  class LangChainAction(Enum):
75
  """LangChain action"""
76
 
evaluate_params.py CHANGED
@@ -1,3 +1,6 @@
 
 
 
1
  no_default_param_names = [
2
  'instruction',
3
  'iinput',
@@ -30,6 +33,7 @@ eval_func_param_names = ['instruction',
30
  'instruction_nochat',
31
  'iinput_nochat',
32
  'langchain_mode',
 
33
  'langchain_action',
34
  'langchain_agents',
35
  'top_k_docs',
 
1
+ input_args_list = ['model_state', 'my_db_state', 'selection_docs_state']
2
+
3
+
4
  no_default_param_names = [
5
  'instruction',
6
  'iinput',
 
33
  'instruction_nochat',
34
  'iinput_nochat',
35
  'langchain_mode',
36
+ 'add_chat_history_to_context',
37
  'langchain_action',
38
  'langchain_agents',
39
  'top_k_docs',
gen.py CHANGED
@@ -8,7 +8,6 @@ import sys
8
  import os
9
  import time
10
  import traceback
11
- import types
12
  import typing
13
  import warnings
14
  from datetime import datetime
@@ -28,12 +27,12 @@ os.environ['BITSANDBYTES_NOWELCOME'] = '1'
28
  warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
29
 
30
  from evaluate_params import eval_func_param_names, no_default_param_names
31
- from enums import DocumentChoices, LangChainMode, no_lora_str, model_token_mapping, no_model_str, source_prefix, \
32
- source_postfix, LangChainAction, LangChainAgent
33
  from loaders import get_loaders
34
  from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, wrapped_partial, EThread, get_githash, \
35
  import_matplotlib, get_device, makedirs, get_kwargs, start_faulthandler, get_hf_server, FakeTokenizer, remove, \
36
- have_langchain, set_openai
37
 
38
  start_faulthandler()
39
  import_matplotlib()
@@ -50,8 +49,6 @@ from transformers import GenerationConfig, AutoModel, TextIteratorStreamer
50
  from prompter import Prompter, inv_prompt_type_to_model_lower, non_hf_types, PromptType, get_prompt, generate_prompt
51
  from stopping import get_stopping
52
 
53
- langchain_modes = [x.value for x in list(LangChainMode)]
54
-
55
  langchain_actions = [x.value for x in list(LangChainAction)]
56
 
57
  langchain_agents_list = [x.value for x in list(LangChainAgent)]
@@ -116,6 +113,7 @@ def main(
116
  show_examples: bool = None,
117
  verbose: bool = False,
118
  h2ocolors: bool = True,
 
119
  height: int = 600,
120
  show_lora: bool = True,
121
  login_mode_if_model0: bool = False,
@@ -147,14 +145,16 @@ def main(
147
  langchain_action: str = LangChainAction.QUERY.value,
148
  langchain_agents: list = [],
149
  force_langchain_evaluate: bool = False,
 
150
  visible_langchain_modes: list = ['UserData', 'MyData'],
151
  # WIP:
152
  # visible_langchain_actions: list = langchain_actions.copy(),
153
  visible_langchain_actions: list = [LangChainAction.QUERY.value, LangChainAction.SUMMARIZE_MAP.value],
154
  visible_langchain_agents: list = langchain_agents_list.copy(),
155
- document_subset: str = DocumentChoices.Relevant.name,
156
- document_choice: list = [],
157
  user_path: str = None,
 
158
  detect_user_path_changes_every_query: bool = False,
159
  use_llm_if_no_docs: bool = False,
160
  load_db_if_exists: bool = True,
@@ -163,7 +163,10 @@ def main(
163
  use_openai_embedding: bool = False,
164
  use_openai_model: bool = False,
165
  hf_embedding_model: str = None,
 
 
166
  allow_upload_to_user_data: bool = True,
 
167
  allow_upload_to_my_data: bool = True,
168
  enable_url_upload: bool = True,
169
  enable_text_upload: bool = True,
@@ -180,6 +183,7 @@ def main(
180
  pre_load_caption_model: bool = False,
181
  caption_gpu: bool = True,
182
  enable_ocr: bool = False,
 
183
  ):
184
  """
185
 
@@ -259,6 +263,7 @@ def main(
259
  :param show_examples: whether to show clickable examples in gradio
260
  :param verbose: whether to show verbose prints
261
  :param h2ocolors: whether to use H2O.ai theme
 
262
  :param height: height of chat window
263
  :param show_lora: whether to show LORA options in UI (expert so can be hard to understand)
264
  :param login_mode_if_model0: set to True to load --base_model after client logs in, to be able to free GPU memory when model is swapped
@@ -287,7 +292,7 @@ def main(
287
  :param eval_prompts_only_seed: for no gradio benchmark, seed for eval_filename sampling
288
  :param eval_as_output: for no gradio benchmark, whether to test eval_filename output itself
289
  :param langchain_mode: Data source to include. Choose "UserData" to only consume files from make_db.py.
290
- None: auto mode, check if langchain package exists, at least do ChatLLM if so, else Disabled
291
  WARNING: wiki_full requires extra data processing via read_wiki_full.py and requires really good workstation to generate db, unless already present.
292
  :param langchain_action: Mode langchain operations in on documents.
293
  Query: Make query of document(s)
@@ -299,18 +304,28 @@ def main(
299
  :param force_langchain_evaluate: Whether to force langchain LLM use even if not doing langchain, mostly for testing.
300
  :param user_path: user path to glob from to generate db for vector search, for 'UserData' langchain mode.
301
  If already have db, any new/changed files are added automatically if path set, does not have to be same path used for prior db sources
 
 
 
 
302
  :param detect_user_path_changes_every_query: whether to detect if any files changed or added every similarity search (by file hashes).
303
  Expensive for large number of files, so not done by default. By default only detect changes during db loading.
 
304
  :param visible_langchain_modes: dbs to generate at launch to be ready for LLM
305
  Can be up to ['wiki', 'wiki_full', 'UserData', 'MyData', 'github h2oGPT', 'DriverlessAI docs']
306
  But wiki_full is expensive and requires preparation
307
  To allow scratch space only live in session, add 'MyData' to list
308
  Default: If only want to consume local files, e.g. prepared by make_db.py, only include ['UserData']
 
 
 
 
 
309
  :param visible_langchain_actions: Which actions to allow
310
  :param visible_langchain_agents: Which agents to allow
311
  :param document_subset: Default document choice when taking subset of collection
312
- :param document_choice: Chosen document(s) by internal name
313
- :param use_llm_if_no_docs: Whether to use LLM even if no documents, when langchain_mode=UserData or MyData
314
  :param load_db_if_exists: Whether to load chroma db if exists or re-generate db
315
  :param keep_sources_in_context: Whether to keep url sources in context, not helpful usually
316
  :param db_type: 'faiss' for in-memory or 'chroma' or 'weaviate' for persisted on disk
@@ -321,13 +336,20 @@ def main(
321
  Can also choose simpler model with 384 parameters per embedding: "sentence-transformers/all-MiniLM-L6-v2"
322
  Can also choose even better embedding with 1024 parameters: 'hkunlp/instructor-xl'
323
  We support automatically changing of embeddings for chroma, with a backup of db made if this is done
324
- :param allow_upload_to_user_data: Whether to allow file uploads to update shared vector db
 
 
 
 
 
 
 
325
  :param allow_upload_to_my_data: Whether to allow file uploads to update scratch vector db
326
  :param enable_url_upload: Whether to allow upload from URL
327
  :param enable_text_upload: Whether to allow upload of text
328
  :param enable_sources_list: Whether to allow list (or download for non-shared db) of list of sources for chosen db
329
  :param chunk: Whether to chunk data (True unless know data is already optimally chunked)
330
- :param chunk_size: Size of chunks, with typically top-4 passed to LLM, so neesd to be in context length
331
  :param top_k_docs: number of chunks to give LLM
332
  :param reverse_docs: whether to reverse docs order so most relevant is closest to question.
333
  Best choice for sufficiently smart model, and truncation occurs for oldest context, so best then too.
@@ -347,6 +369,9 @@ def main(
347
  Recommended if using larger caption model
348
  :param caption_gpu: If support caption, then use GPU if exists
349
  :param enable_ocr: Whether to support OCR on images
 
 
 
350
  :return:
351
  """
352
  if base_model is None:
@@ -408,6 +433,26 @@ def main(
408
  if langchain_mode is not None:
409
  visible_langchain_modes += [langchain_mode]
410
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
411
  assert langchain_action in langchain_actions, "Invalid langchain_action %s" % langchain_action
412
  assert len(
413
  set(langchain_agents).difference(langchain_agents_list)) == 0, "Invalid langchain_agents %s" % langchain_agents
@@ -421,22 +466,22 @@ def main(
421
  # auto-set langchain_mode
422
  if have_langchain and langchain_mode is None:
423
  # start in chat mode, in case just want to chat and don't want to get "No documents to query" by default.
424
- langchain_mode = LangChainMode.CHAT_LLM.value
425
- if allow_upload_to_user_data and not is_public and user_path:
426
  print("Auto set langchain_mode=%s. Could use UserData instead." % langchain_mode, flush=True)
427
  elif allow_upload_to_my_data:
428
  print("Auto set langchain_mode=%s. Could use MyData instead."
429
  " To allow UserData to pull files from disk,"
430
- " set user_path and ensure allow_upload_to_user_data=True" % langchain_mode, flush=True)
 
431
  else:
432
  raise RuntimeError("Please pass --langchain_mode=<chosen mode> out of %s" % langchain_modes)
433
- if not have_langchain and langchain_mode not in [None, LangChainMode.DISABLED.value, LangChainMode.LLM.value,
434
- LangChainMode.CHAT_LLM.value]:
435
  raise RuntimeError("Asked for LangChain mode but langchain python package cannot be found.")
436
  if langchain_mode is None:
437
  # if not set yet, disable
438
  langchain_mode = LangChainMode.DISABLED.value
439
- print("Auto set langchain_mode=%s" % langchain_mode, flush=True)
440
 
441
  if is_public:
442
  allow_upload_to_user_data = False
@@ -547,8 +592,6 @@ def main(
547
 
548
  if offload_folder:
549
  makedirs(offload_folder)
550
- if user_path:
551
- makedirs(user_path)
552
 
553
  placeholder_instruction, placeholder_input, \
554
  stream_output, show_examples, \
@@ -574,7 +617,7 @@ def main(
574
  verbose,
575
  )
576
 
577
- git_hash = get_githash()
578
  locals_dict = locals()
579
  locals_print = '\n'.join(['%s: %s' % (k, v) for k, v in locals_dict.items()])
580
  if verbose:
@@ -588,7 +631,7 @@ def main(
588
  get_some_dbs_from_hf()
589
  dbs = {}
590
  for langchain_mode1 in visible_langchain_modes:
591
- if langchain_mode1 in ['MyData']:
592
  # don't use what is on disk, remove it instead
593
  for gpath1 in glob.glob(os.path.join(scratch_base_dir, 'db_dir_%s*' % langchain_mode1)):
594
  if os.path.isdir(gpath1):
@@ -603,7 +646,7 @@ def main(
603
  db = prep_langchain(persist_directory1,
604
  load_db_if_exists,
605
  db_type, use_openai_embedding,
606
- langchain_mode1, user_path,
607
  hf_embedding_model,
608
  kwargs_make_db=locals())
609
  finally:
@@ -622,6 +665,14 @@ def main(
622
  model_state_none = dict(model=None, tokenizer=None, device=None,
623
  base_model=None, tokenizer_base_model=None, lora_weights=None,
624
  inference_server=None, prompt_type=None, prompt_dict=None)
 
 
 
 
 
 
 
 
625
 
626
  if cli:
627
  from cli import run_cli
@@ -1280,6 +1331,7 @@ def get_score_model(score_model: str = None,
1280
  def evaluate(
1281
  model_state,
1282
  my_db_state,
 
1283
  # START NOTE: Examples must have same order of parameters
1284
  instruction,
1285
  iinput,
@@ -1302,6 +1354,7 @@ def evaluate(
1302
  instruction_nochat,
1303
  iinput_nochat,
1304
  langchain_mode,
 
1305
  langchain_action,
1306
  langchain_agents,
1307
  top_k_docs,
@@ -1317,6 +1370,9 @@ def evaluate(
1317
  save_dir=None,
1318
  sanitize_bot_response=False,
1319
  model_state0=None,
 
 
 
1320
  memory_restriction_level=None,
1321
  max_max_new_tokens=None,
1322
  is_public=None,
@@ -1327,11 +1383,11 @@ def evaluate(
1327
  use_llm_if_no_docs=False,
1328
  load_db_if_exists=True,
1329
  dbs=None,
1330
- user_path=None,
1331
  detect_user_path_changes_every_query=None,
1332
  use_openai_embedding=None,
1333
  use_openai_model=None,
1334
  hf_embedding_model=None,
 
1335
  db_type=None,
1336
  n_jobs=None,
1337
  first_para=None,
@@ -1360,6 +1416,16 @@ def evaluate(
1360
  assert chunk_size is not None and isinstance(chunk_size, int)
1361
  assert n_jobs is not None
1362
  assert first_para is not None
 
 
 
 
 
 
 
 
 
 
1363
 
1364
  if debug:
1365
  locals_dict = locals().copy()
@@ -1481,18 +1547,22 @@ def evaluate(
1481
  assert langchain_action in langchain_actions, "Invalid langchain_action %s" % langchain_action
1482
  assert len(
1483
  set(langchain_agents).difference(langchain_agents_list)) == 0, "Invalid langchain_agents %s" % langchain_agents
1484
- if langchain_mode in ['MyData'] and my_db_state is not None and len(my_db_state) > 0 and my_db_state[0] is not None:
1485
- db1 = my_db_state[0]
1486
- elif dbs is not None and langchain_mode in dbs:
1487
- db1 = dbs[langchain_mode]
 
 
 
 
1488
  else:
1489
- db1 = None
1490
- do_langchain_path = langchain_mode not in [False, 'Disabled', 'ChatLLM', 'LLM'] or \
1491
  base_model in non_hf_types or \
1492
  force_langchain_evaluate
1493
  if do_langchain_path:
1494
  outr = ""
1495
- # use smaller cut_distanct for wiki_full since so many matches could be obtained, and often irrelevant unless close
1496
  from gpt_langchain import run_qa_db
1497
  gen_hyper_langchain = dict(do_sample=do_sample,
1498
  temperature=temperature,
@@ -1515,10 +1585,11 @@ def evaluate(
1515
  prompter=prompter,
1516
  use_llm_if_no_docs=use_llm_if_no_docs,
1517
  load_db_if_exists=load_db_if_exists,
1518
- db=db1,
1519
- user_path=user_path,
1520
  detect_user_path_changes_every_query=detect_user_path_changes_every_query,
1521
- cut_distanct=1.1 if langchain_mode in ['wiki_full'] else 1.64, # FIXME, too arbitrary
 
1522
  use_openai_embedding=use_openai_embedding,
1523
  use_openai_model=use_openai_model,
1524
  hf_embedding_model=hf_embedding_model,
@@ -1676,6 +1747,7 @@ def evaluate(
1676
  chat_client = False
1677
  where_from = "gr_client"
1678
  client_langchain_mode = 'Disabled'
 
1679
  client_langchain_action = LangChainAction.QUERY.value
1680
  client_langchain_agents = []
1681
  gen_server_kwargs = dict(temperature=temperature,
@@ -1729,13 +1801,14 @@ def evaluate(
1729
  instruction_nochat=gr_prompt if not chat_client else '',
1730
  iinput_nochat=gr_iinput, # only for chat=False
1731
  langchain_mode=client_langchain_mode,
 
1732
  langchain_action=client_langchain_action,
1733
  langchain_agents=client_langchain_agents,
1734
  top_k_docs=top_k_docs,
1735
  chunk=chunk,
1736
  chunk_size=chunk_size,
1737
- document_subset=DocumentChoices.Relevant.name,
1738
- document_choice=[],
1739
  )
1740
  api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing
1741
  if not stream_output:
@@ -2029,7 +2102,7 @@ def evaluate(
2029
 
2030
 
2031
  inputs_list_names = list(inspect.signature(evaluate).parameters)
2032
- state_names = ['model_state', 'my_db_state']
2033
  inputs_kwargs_list = [x for x in inputs_list_names if x not in eval_func_param_names + state_names]
2034
 
2035
 
@@ -2312,8 +2385,8 @@ y = np.random.randint(0, 1, 100)
2312
 
2313
  # move to correct position
2314
  for example in examples:
2315
- example += [chat, '', '', LangChainMode.DISABLED.value, LangChainAction.QUERY.value, [],
2316
- top_k_docs, chunk, chunk_size, DocumentChoices.Relevant.name, []
2317
  ]
2318
  # adjust examples if non-chat mode
2319
  if not chat:
@@ -2373,7 +2446,7 @@ def score_qa(smodel, stokenizer, max_length_tokenize, question, answer, cutoff_l
2373
  truncation=True,
2374
  max_length=max_length_tokenize).to(smodel.device)
2375
  try:
2376
- score = torch.sigmoid(smodel(**inputs).logits[0]).cpu().detach().numpy()[0]
2377
  except torch.cuda.OutOfMemoryError as e:
2378
  print("GPU OOM 3: question: %s answer: %s exception: %s" % (question, answer, str(e)), flush=True)
2379
  del inputs
@@ -2458,12 +2531,15 @@ def get_minmax_top_k_docs(is_public):
2458
  return min_top_k_docs, max_top_k_docs, label_top_k_docs
2459
 
2460
 
2461
- def history_to_context(history, langchain_mode1, prompt_type1, prompt_dict1, chat1, model_max_length1,
 
 
2462
  memory_restriction_level1, keep_sources_in_context1):
2463
  """
2464
  consumes all history up to (but not including) latest history item that is presumed to be an [instruction, None] pair
2465
  :param history:
2466
  :param langchain_mode1:
 
2467
  :param prompt_type1:
2468
  :param prompt_dict1:
2469
  :param chat1:
@@ -2476,7 +2552,7 @@ def history_to_context(history, langchain_mode1, prompt_type1, prompt_dict1, cha
2476
  _, _, _, max_prompt_length = get_cutoffs(memory_restriction_level1,
2477
  for_context=True, model_max_length=model_max_length1)
2478
  context1 = ''
2479
- if max_prompt_length is not None and langchain_mode1 not in ['LLM']:
2480
  context1 = ''
2481
  # - 1 below because current instruction already in history from user()
2482
  for histi in range(0, len(history) - 1):
@@ -2512,6 +2588,22 @@ def history_to_context(history, langchain_mode1, prompt_type1, prompt_dict1, cha
2512
  return context1
2513
 
2514
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2515
  def entrypoint_main():
2516
  """
2517
  Examples:
 
8
  import os
9
  import time
10
  import traceback
 
11
  import typing
12
  import warnings
13
  from datetime import datetime
 
27
  warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
28
 
29
  from evaluate_params import eval_func_param_names, no_default_param_names
30
+ from enums import DocumentSubset, LangChainMode, no_lora_str, model_token_mapping, no_model_str, source_prefix, \
31
+ source_postfix, LangChainAction, LangChainAgent, DocumentChoice
32
  from loaders import get_loaders
33
  from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, wrapped_partial, EThread, get_githash, \
34
  import_matplotlib, get_device, makedirs, get_kwargs, start_faulthandler, get_hf_server, FakeTokenizer, remove, \
35
+ have_langchain, set_openai, load_collection_enum
36
 
37
  start_faulthandler()
38
  import_matplotlib()
 
49
  from prompter import Prompter, inv_prompt_type_to_model_lower, non_hf_types, PromptType, get_prompt, generate_prompt
50
  from stopping import get_stopping
51
 
 
 
52
  langchain_actions = [x.value for x in list(LangChainAction)]
53
 
54
  langchain_agents_list = [x.value for x in list(LangChainAgent)]
 
113
  show_examples: bool = None,
114
  verbose: bool = False,
115
  h2ocolors: bool = True,
116
+ dark: bool = False, # light tends to be best
117
  height: int = 600,
118
  show_lora: bool = True,
119
  login_mode_if_model0: bool = False,
 
145
  langchain_action: str = LangChainAction.QUERY.value,
146
  langchain_agents: list = [],
147
  force_langchain_evaluate: bool = False,
148
+ langchain_modes: list = [x.value for x in list(LangChainMode)],
149
  visible_langchain_modes: list = ['UserData', 'MyData'],
150
  # WIP:
151
  # visible_langchain_actions: list = langchain_actions.copy(),
152
  visible_langchain_actions: list = [LangChainAction.QUERY.value, LangChainAction.SUMMARIZE_MAP.value],
153
  visible_langchain_agents: list = langchain_agents_list.copy(),
154
+ document_subset: str = DocumentSubset.Relevant.name,
155
+ document_choice: list = [DocumentChoice.ALL.value],
156
  user_path: str = None,
157
+ langchain_mode_paths: dict = {'UserData': None},
158
  detect_user_path_changes_every_query: bool = False,
159
  use_llm_if_no_docs: bool = False,
160
  load_db_if_exists: bool = True,
 
163
  use_openai_embedding: bool = False,
164
  use_openai_model: bool = False,
165
  hf_embedding_model: str = None,
166
+ cut_distance: float = 1.64,
167
+ add_chat_history_to_context: bool = True,
168
  allow_upload_to_user_data: bool = True,
169
+ reload_langchain_state: bool = True,
170
  allow_upload_to_my_data: bool = True,
171
  enable_url_upload: bool = True,
172
  enable_text_upload: bool = True,
 
183
  pre_load_caption_model: bool = False,
184
  caption_gpu: bool = True,
185
  enable_ocr: bool = False,
186
+ enable_pdf_ocr: str = 'auto',
187
  ):
188
  """
189
 
 
263
  :param show_examples: whether to show clickable examples in gradio
264
  :param verbose: whether to show verbose prints
265
  :param h2ocolors: whether to use H2O.ai theme
266
+ :param dark: whether to use dark mode for UI by default (still controlled in UI)
267
  :param height: height of chat window
268
  :param show_lora: whether to show LORA options in UI (expert so can be hard to understand)
269
  :param login_mode_if_model0: set to True to load --base_model after client logs in, to be able to free GPU memory when model is swapped
 
292
  :param eval_prompts_only_seed: for no gradio benchmark, seed for eval_filename sampling
293
  :param eval_as_output: for no gradio benchmark, whether to test eval_filename output itself
294
  :param langchain_mode: Data source to include. Choose "UserData" to only consume files from make_db.py.
295
+ None: auto mode, check if langchain package exists, at least do LLM if so, else Disabled
296
  WARNING: wiki_full requires extra data processing via read_wiki_full.py and requires really good workstation to generate db, unless already present.
297
  :param langchain_action: Mode langchain operations in on documents.
298
  Query: Make query of document(s)
 
304
  :param force_langchain_evaluate: Whether to force langchain LLM use even if not doing langchain, mostly for testing.
305
  :param user_path: user path to glob from to generate db for vector search, for 'UserData' langchain mode.
306
  If already have db, any new/changed files are added automatically if path set, does not have to be same path used for prior db sources
307
+ :param langchain_mode_paths: dict of langchain_mode keys and disk path values to use for source of documents
308
+ E.g. "{'UserData2': 'userpath2'}"
309
+ Can be None even if existing DB, to avoid new documents being added from that path, source links that are on disk still work.
310
+ If user_path is not None, that path is used for 'UserData' instead of the value in this dict
311
  :param detect_user_path_changes_every_query: whether to detect if any files changed or added every similarity search (by file hashes).
312
  Expensive for large number of files, so not done by default. By default only detect changes during db loading.
313
+ :param langchain_modes: names of collections/dbs to potentially have
314
  :param visible_langchain_modes: dbs to generate at launch to be ready for LLM
315
  Can be up to ['wiki', 'wiki_full', 'UserData', 'MyData', 'github h2oGPT', 'DriverlessAI docs']
316
  But wiki_full is expensive and requires preparation
317
  To allow scratch space only live in session, add 'MyData' to list
318
  Default: If only want to consume local files, e.g. prepared by make_db.py, only include ['UserData']
319
+ If have own user modes, need to add these here or add in UI.
320
+ A state file is stored in visible_langchain_modes.pkl containing last UI-selected values of:
321
+ langchain_modes, visible_langchain_modes, and langchain_mode_paths
322
+ Delete the file if you want to start fresh,
323
+ but in any case the user_path passed in CLI is used for UserData even if was None or different
324
  :param visible_langchain_actions: Which actions to allow
325
  :param visible_langchain_agents: Which agents to allow
326
  :param document_subset: Default document choice when taking subset of collection
327
+ :param document_choice: Chosen document(s) by internal name, 'All' means use all docs
328
+ :param use_llm_if_no_docs: Whether to use LLM even if no documents, when langchain_mode=UserData or MyData or custom
329
  :param load_db_if_exists: Whether to load chroma db if exists or re-generate db
330
  :param keep_sources_in_context: Whether to keep url sources in context, not helpful usually
331
  :param db_type: 'faiss' for in-memory or 'chroma' or 'weaviate' for persisted on disk
 
336
  Can also choose simpler model with 384 parameters per embedding: "sentence-transformers/all-MiniLM-L6-v2"
337
  Can also choose even better embedding with 1024 parameters: 'hkunlp/instructor-xl'
338
  We support automatically changing of embeddings for chroma, with a backup of db made if this is done
339
+ :param cut_distance: Distance to cut off references with larger distances when showing references.
340
+ 1.64 is good to avoid dropping references for all-MiniLM-L6-v2, but instructor-large will always show excessive references.
341
+ For all-MiniLM-L6-v2, a value of 1.5 can push out even more references, or a large value of 100 can avoid any loss of references.
342
+ :param add_chat_history_to_context: Include chat context when performing action
343
+ Not supported yet for openai_chat when using document collection instead of LLM
344
+ Also not supported when using CLI mode
345
+ :param allow_upload_to_user_data: Whether to allow file uploads to update shared vector db (UserData or custom user dbs)
346
+ :param reload_langchain_state: Whether to reload visible_langchain_modes.pkl file that contains any new user collections.
347
  :param allow_upload_to_my_data: Whether to allow file uploads to update scratch vector db
348
  :param enable_url_upload: Whether to allow upload from URL
349
  :param enable_text_upload: Whether to allow upload of text
350
  :param enable_sources_list: Whether to allow list (or download for non-shared db) of list of sources for chosen db
351
  :param chunk: Whether to chunk data (True unless know data is already optimally chunked)
352
+ :param chunk_size: Size of chunks, with typically top-4 passed to LLM, so needs to be in context length
353
  :param top_k_docs: number of chunks to give LLM
354
  :param reverse_docs: whether to reverse docs order so most relevant is closest to question.
355
  Best choice for sufficiently smart model, and truncation occurs for oldest context, so best then too.
 
369
  Recommended if using larger caption model
370
  :param caption_gpu: If support caption, then use GPU if exists
371
  :param enable_ocr: Whether to support OCR on images
372
+ :param enable_pdf_ocr: 'auto' means only use OCR if normal text extraction fails. Useful for pure image-based PDFs with text
373
+ 'on' means always do OCR as additional parsing of same documents
374
+ 'off' means don't do OCR (e.g. because it's slow even if 'auto' only would trigger if nothing else worked)
375
  :return:
376
  """
377
  if base_model is None:
 
433
  if langchain_mode is not None:
434
  visible_langchain_modes += [langchain_mode]
435
 
436
+ # update
437
+ if isinstance(langchain_mode_paths, str):
438
+ langchain_mode_paths = ast.literal_eval(langchain_mode_paths)
439
+ assert isinstance(langchain_mode_paths, dict)
440
+ if user_path:
441
+ langchain_mode_paths['UserData'] = user_path
442
+ makedirs(user_path)
443
+
444
+ if is_public:
445
+ allow_upload_to_user_data = False
446
+ if LangChainMode.USER_DATA.value in visible_langchain_modes:
447
+ visible_langchain_modes.remove(LangChainMode.USER_DATA.value)
448
+
449
+ # in-place, for non-scratch dbs
450
+ if allow_upload_to_user_data:
451
+ update_langchain(langchain_modes, visible_langchain_modes, langchain_mode_paths, '')
452
+ # always listen to CLI-passed user_path if passed
453
+ if user_path:
454
+ langchain_mode_paths['UserData'] = user_path
455
+
456
  assert langchain_action in langchain_actions, "Invalid langchain_action %s" % langchain_action
457
  assert len(
458
  set(langchain_agents).difference(langchain_agents_list)) == 0, "Invalid langchain_agents %s" % langchain_agents
 
466
  # auto-set langchain_mode
467
  if have_langchain and langchain_mode is None:
468
  # start in chat mode, in case just want to chat and don't want to get "No documents to query" by default.
469
+ langchain_mode = LangChainMode.LLM.value
470
+ if allow_upload_to_user_data and not is_public and langchain_mode_paths['UserData']:
471
  print("Auto set langchain_mode=%s. Could use UserData instead." % langchain_mode, flush=True)
472
  elif allow_upload_to_my_data:
473
  print("Auto set langchain_mode=%s. Could use MyData instead."
474
  " To allow UserData to pull files from disk,"
475
+ " set user_path or langchain_mode_paths, and ensure allow_upload_to_user_data=True" % langchain_mode,
476
+ flush=True)
477
  else:
478
  raise RuntimeError("Please pass --langchain_mode=<chosen mode> out of %s" % langchain_modes)
479
+ if not have_langchain and langchain_mode not in [None, LangChainMode.DISABLED.value, LangChainMode.LLM.value]:
 
480
  raise RuntimeError("Asked for LangChain mode but langchain python package cannot be found.")
481
  if langchain_mode is None:
482
  # if not set yet, disable
483
  langchain_mode = LangChainMode.DISABLED.value
484
+ print("Auto set langchain_mode=%s Have langchain package: %s" % (langchain_mode, have_langchain), flush=True)
485
 
486
  if is_public:
487
  allow_upload_to_user_data = False
 
592
 
593
  if offload_folder:
594
  makedirs(offload_folder)
 
 
595
 
596
  placeholder_instruction, placeholder_input, \
597
  stream_output, show_examples, \
 
617
  verbose,
618
  )
619
 
620
+ git_hash = get_githash() if is_public or os.getenv('GET_GITHASH') else "GET_GITHASH"
621
  locals_dict = locals()
622
  locals_print = '\n'.join(['%s: %s' % (k, v) for k, v in locals_dict.items()])
623
  if verbose:
 
631
  get_some_dbs_from_hf()
632
  dbs = {}
633
  for langchain_mode1 in visible_langchain_modes:
634
+ if langchain_mode1 in ['MyData']: # FIXME: Remove other custom temp dbs
635
  # don't use what is on disk, remove it instead
636
  for gpath1 in glob.glob(os.path.join(scratch_base_dir, 'db_dir_%s*' % langchain_mode1)):
637
  if os.path.isdir(gpath1):
 
646
  db = prep_langchain(persist_directory1,
647
  load_db_if_exists,
648
  db_type, use_openai_embedding,
649
+ langchain_mode1, langchain_mode_paths,
650
  hf_embedding_model,
651
  kwargs_make_db=locals())
652
  finally:
 
665
  model_state_none = dict(model=None, tokenizer=None, device=None,
666
  base_model=None, tokenizer_base_model=None, lora_weights=None,
667
  inference_server=None, prompt_type=None, prompt_dict=None)
668
+ my_db_state0 = {LangChainMode.MY_DATA.value: [None, None]}
669
+ selection_docs_state0 = dict(visible_langchain_modes=visible_langchain_modes,
670
+ langchain_mode_paths=langchain_mode_paths,
671
+ langchain_modes=langchain_modes)
672
+ selection_docs_state = selection_docs_state0
673
+ langchain_modes0 = langchain_modes
674
+ langchain_mode_paths0 = langchain_mode_paths
675
+ visible_langchain_modes0 = visible_langchain_modes
676
 
677
  if cli:
678
  from cli import run_cli
 
1331
  def evaluate(
1332
  model_state,
1333
  my_db_state,
1334
+ selection_docs_state,
1335
  # START NOTE: Examples must have same order of parameters
1336
  instruction,
1337
  iinput,
 
1354
  instruction_nochat,
1355
  iinput_nochat,
1356
  langchain_mode,
1357
+ add_chat_history_to_context,
1358
  langchain_action,
1359
  langchain_agents,
1360
  top_k_docs,
 
1370
  save_dir=None,
1371
  sanitize_bot_response=False,
1372
  model_state0=None,
1373
+ langchain_modes0=None,
1374
+ langchain_mode_paths0=None,
1375
+ visible_langchain_modes0=None,
1376
  memory_restriction_level=None,
1377
  max_max_new_tokens=None,
1378
  is_public=None,
 
1383
  use_llm_if_no_docs=False,
1384
  load_db_if_exists=True,
1385
  dbs=None,
 
1386
  detect_user_path_changes_every_query=None,
1387
  use_openai_embedding=None,
1388
  use_openai_model=None,
1389
  hf_embedding_model=None,
1390
+ cut_distance=None,
1391
  db_type=None,
1392
  n_jobs=None,
1393
  first_para=None,
 
1416
  assert chunk_size is not None and isinstance(chunk_size, int)
1417
  assert n_jobs is not None
1418
  assert first_para is not None
1419
+ assert isinstance(add_chat_history_to_context, bool)
1420
+
1421
+ if selection_docs_state is not None:
1422
+ langchain_modes = selection_docs_state.get('langchain_modes', langchain_modes0)
1423
+ langchain_mode_paths = selection_docs_state.get('langchain_mode_paths', langchain_mode_paths0)
1424
+ visible_langchain_modes = selection_docs_state.get('visible_langchain_modes', visible_langchain_modes0)
1425
+ else:
1426
+ langchain_modes = langchain_modes0
1427
+ langchain_mode_paths = langchain_mode_paths0
1428
+ visible_langchain_modes = visible_langchain_modes0
1429
 
1430
  if debug:
1431
  locals_dict = locals().copy()
 
1547
  assert langchain_action in langchain_actions, "Invalid langchain_action %s" % langchain_action
1548
  assert len(
1549
  set(langchain_agents).difference(langchain_agents_list)) == 0, "Invalid langchain_agents %s" % langchain_agents
1550
+ if dbs is not None and langchain_mode in dbs:
1551
+ db = dbs[langchain_mode]
1552
+ elif my_db_state is not None and langchain_mode in my_db_state:
1553
+ db1 = my_db_state[langchain_mode]
1554
+ if db1 is not None and len(db1) == 2:
1555
+ db = db1[0]
1556
+ else:
1557
+ db = None
1558
  else:
1559
+ db = None
1560
+ do_langchain_path = langchain_mode not in [False, 'Disabled', 'LLM'] or \
1561
  base_model in non_hf_types or \
1562
  force_langchain_evaluate
1563
  if do_langchain_path:
1564
  outr = ""
1565
+ # use smaller cut_distance for wiki_full since so many matches could be obtained, and often irrelevant unless close
1566
  from gpt_langchain import run_qa_db
1567
  gen_hyper_langchain = dict(do_sample=do_sample,
1568
  temperature=temperature,
 
1585
  prompter=prompter,
1586
  use_llm_if_no_docs=use_llm_if_no_docs,
1587
  load_db_if_exists=load_db_if_exists,
1588
+ db=db,
1589
+ langchain_mode_paths=langchain_mode_paths,
1590
  detect_user_path_changes_every_query=detect_user_path_changes_every_query,
1591
+ cut_distance=1.1 if langchain_mode in ['wiki_full'] else cut_distance,
1592
+ add_chat_history_to_context=add_chat_history_to_context,
1593
  use_openai_embedding=use_openai_embedding,
1594
  use_openai_model=use_openai_model,
1595
  hf_embedding_model=hf_embedding_model,
 
1747
  chat_client = False
1748
  where_from = "gr_client"
1749
  client_langchain_mode = 'Disabled'
1750
+ client_add_chat_history_to_context = True
1751
  client_langchain_action = LangChainAction.QUERY.value
1752
  client_langchain_agents = []
1753
  gen_server_kwargs = dict(temperature=temperature,
 
1801
  instruction_nochat=gr_prompt if not chat_client else '',
1802
  iinput_nochat=gr_iinput, # only for chat=False
1803
  langchain_mode=client_langchain_mode,
1804
+ add_chat_history_to_context=client_add_chat_history_to_context,
1805
  langchain_action=client_langchain_action,
1806
  langchain_agents=client_langchain_agents,
1807
  top_k_docs=top_k_docs,
1808
  chunk=chunk,
1809
  chunk_size=chunk_size,
1810
+ document_subset=DocumentSubset.Relevant.name,
1811
+ document_choice=[DocumentChoice.ALL.value],
1812
  )
1813
  api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing
1814
  if not stream_output:
 
2102
 
2103
 
2104
  inputs_list_names = list(inspect.signature(evaluate).parameters)
2105
+ state_names = ['model_state', 'my_db_state', 'selection_docs_state']
2106
  inputs_kwargs_list = [x for x in inputs_list_names if x not in eval_func_param_names + state_names]
2107
 
2108
 
 
2385
 
2386
  # move to correct position
2387
  for example in examples:
2388
+ example += [chat, '', '', LangChainMode.DISABLED.value, True, LangChainAction.QUERY.value, [],
2389
+ top_k_docs, chunk, chunk_size, DocumentSubset.Relevant.name, []
2390
  ]
2391
  # adjust examples if non-chat mode
2392
  if not chat:
 
2446
  truncation=True,
2447
  max_length=max_length_tokenize).to(smodel.device)
2448
  try:
2449
+ score = torch.sigmoid(smodel(**inputs.to(smodel.device)).logits[0].float()).cpu().detach().numpy()[0]
2450
  except torch.cuda.OutOfMemoryError as e:
2451
  print("GPU OOM 3: question: %s answer: %s exception: %s" % (question, answer, str(e)), flush=True)
2452
  del inputs
 
2531
  return min_top_k_docs, max_top_k_docs, label_top_k_docs
2532
 
2533
 
2534
+ def history_to_context(history, langchain_mode1,
2535
+ add_chat_history_to_context,
2536
+ prompt_type1, prompt_dict1, chat1, model_max_length1,
2537
  memory_restriction_level1, keep_sources_in_context1):
2538
  """
2539
  consumes all history up to (but not including) latest history item that is presumed to be an [instruction, None] pair
2540
  :param history:
2541
  :param langchain_mode1:
2542
+ :param add_chat_history_to_context:
2543
  :param prompt_type1:
2544
  :param prompt_dict1:
2545
  :param chat1:
 
2552
  _, _, _, max_prompt_length = get_cutoffs(memory_restriction_level1,
2553
  for_context=True, model_max_length=model_max_length1)
2554
  context1 = ''
2555
+ if max_prompt_length is not None and add_chat_history_to_context:
2556
  context1 = ''
2557
  # - 1 below because current instruction already in history from user()
2558
  for histi in range(0, len(history) - 1):
 
2588
  return context1
2589
 
2590
 
2591
+ def update_langchain(langchain_modes, visible_langchain_modes, langchain_mode_paths, extra):
2592
+ # update from saved state on disk
2593
+ langchain_modes_from_file, visible_langchain_modes_from_file, langchain_mode_paths_from_file = \
2594
+ load_collection_enum(extra)
2595
+
2596
+ visible_langchain_modes_temp = visible_langchain_modes.copy() + visible_langchain_modes_from_file
2597
+ visible_langchain_modes.clear() # don't lose original reference
2598
+ [visible_langchain_modes.append(x) for x in visible_langchain_modes_temp if x not in visible_langchain_modes]
2599
+
2600
+ langchain_mode_paths.update(langchain_mode_paths_from_file)
2601
+
2602
+ langchain_modes_temp = langchain_modes.copy() + langchain_modes_from_file
2603
+ langchain_modes.clear() # don't lose original reference
2604
+ [langchain_modes.append(x) for x in langchain_modes_temp if x not in langchain_modes]
2605
+
2606
+
2607
  def entrypoint_main():
2608
  """
2609
  Examples:
gpt4all_llm.py CHANGED
@@ -95,15 +95,17 @@ def get_llm_gpt4all(model_name,
95
  streaming=False,
96
  callbacks=None,
97
  prompter=None,
 
 
98
  verbose=False,
99
  ):
100
  assert prompter is not None
101
  env_gpt4all_file = ".env_gpt4all"
102
  env_kwargs = dotenv_values(env_gpt4all_file)
103
- n_ctx = env_kwargs.pop('n_ctx', 2048 - max_new_tokens)
104
  default_kwargs = dict(context_erase=0.5,
105
  n_batch=1,
106
- n_ctx=n_ctx,
107
  n_predict=max_new_tokens,
108
  repeat_last_n=64 if repetition_penalty != 1.0 else 0,
109
  repeat_penalty=repetition_penalty,
@@ -117,7 +119,8 @@ def get_llm_gpt4all(model_name,
117
  cls = H2OLlamaCpp
118
  model_path = env_kwargs.pop('model_path_llama') if model is None else model
119
  model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=['lc_kwargs'])
120
- model_kwargs.update(dict(model_path=model_path, callbacks=callbacks, streaming=streaming, prompter=prompter))
 
121
  llm = cls(**model_kwargs)
122
  llm.client.verbose = verbose
123
  elif model_name == 'gpt4all_llama':
@@ -125,14 +128,16 @@ def get_llm_gpt4all(model_name,
125
  model_path = env_kwargs.pop('model_path_gpt4all_llama') if model is None else model
126
  model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=['lc_kwargs'])
127
  model_kwargs.update(
128
- dict(model=model_path, backend='llama', callbacks=callbacks, streaming=streaming, prompter=prompter))
 
129
  llm = cls(**model_kwargs)
130
  elif model_name == 'gptj':
131
  cls = H2OGPT4All
132
  model_path = env_kwargs.pop('model_path_gptj') if model is None else model
133
  model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=['lc_kwargs'])
134
  model_kwargs.update(
135
- dict(model=model_path, backend='gptj', callbacks=callbacks, streaming=streaming, prompter=prompter))
 
136
  llm = cls(**model_kwargs)
137
  else:
138
  raise RuntimeError("No such model_name %s" % model_name)
@@ -142,6 +147,8 @@ def get_llm_gpt4all(model_name,
142
  class H2OGPT4All(gpt4all.GPT4All):
143
  model: Any
144
  prompter: Any
 
 
145
  """Path to the pre-trained GPT4All model file."""
146
 
147
  @root_validator()
@@ -187,10 +194,11 @@ class H2OGPT4All(gpt4all.GPT4All):
187
  **kwargs,
188
  ) -> str:
189
  # Roughly 4 chars per token if natural language
190
- prompt = prompt[-self.n_ctx * 4:]
 
191
 
192
  # use instruct prompting
193
- data_point = dict(context='', instruction=prompt, input='')
194
  prompt = self.prompter.generate_prompt(data_point)
195
 
196
  verbose = False
@@ -206,6 +214,8 @@ from langchain.llms import LlamaCpp
206
  class H2OLlamaCpp(LlamaCpp):
207
  model_path: Any
208
  prompter: Any
 
 
209
  """Path to the pre-trained GPT4All model file."""
210
 
211
  @root_validator()
@@ -276,7 +286,7 @@ class H2OLlamaCpp(LlamaCpp):
276
  print("reduced tokens from %d -> %d" % (num_prompt_tokens, num_prompt_tokens2), flush=True)
277
 
278
  # use instruct prompting
279
- data_point = dict(context='', instruction=prompt, input='')
280
  prompt = self.prompter.generate_prompt(data_point)
281
 
282
  if verbose:
 
95
  streaming=False,
96
  callbacks=None,
97
  prompter=None,
98
+ context='',
99
+ iinput='',
100
  verbose=False,
101
  ):
102
  assert prompter is not None
103
  env_gpt4all_file = ".env_gpt4all"
104
  env_kwargs = dotenv_values(env_gpt4all_file)
105
+ max_tokens = env_kwargs.pop('max_tokens', 2048 - max_new_tokens)
106
  default_kwargs = dict(context_erase=0.5,
107
  n_batch=1,
108
+ max_tokens=max_tokens,
109
  n_predict=max_new_tokens,
110
  repeat_last_n=64 if repetition_penalty != 1.0 else 0,
111
  repeat_penalty=repetition_penalty,
 
119
  cls = H2OLlamaCpp
120
  model_path = env_kwargs.pop('model_path_llama') if model is None else model
121
  model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=['lc_kwargs'])
122
+ model_kwargs.update(dict(model_path=model_path, callbacks=callbacks, streaming=streaming,
123
+ prompter=prompter, context=context, iinput=iinput))
124
  llm = cls(**model_kwargs)
125
  llm.client.verbose = verbose
126
  elif model_name == 'gpt4all_llama':
 
128
  model_path = env_kwargs.pop('model_path_gpt4all_llama') if model is None else model
129
  model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=['lc_kwargs'])
130
  model_kwargs.update(
131
+ dict(model=model_path, backend='llama', callbacks=callbacks, streaming=streaming,
132
+ prompter=prompter, context=context, iinput=iinput))
133
  llm = cls(**model_kwargs)
134
  elif model_name == 'gptj':
135
  cls = H2OGPT4All
136
  model_path = env_kwargs.pop('model_path_gptj') if model is None else model
137
  model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=['lc_kwargs'])
138
  model_kwargs.update(
139
+ dict(model=model_path, backend='gptj', callbacks=callbacks, streaming=streaming,
140
+ prompter=prompter, context=context, iinput=iinput))
141
  llm = cls(**model_kwargs)
142
  else:
143
  raise RuntimeError("No such model_name %s" % model_name)
 
147
  class H2OGPT4All(gpt4all.GPT4All):
148
  model: Any
149
  prompter: Any
150
+ context: Any = ''
151
+ iinput: Any = ''
152
  """Path to the pre-trained GPT4All model file."""
153
 
154
  @root_validator()
 
194
  **kwargs,
195
  ) -> str:
196
  # Roughly 4 chars per token if natural language
197
+ n_ctx = 2048
198
+ prompt = prompt[-self.max_tokens * 4:]
199
 
200
  # use instruct prompting
201
+ data_point = dict(context=self.context, instruction=prompt, input=self.iinput)
202
  prompt = self.prompter.generate_prompt(data_point)
203
 
204
  verbose = False
 
214
  class H2OLlamaCpp(LlamaCpp):
215
  model_path: Any
216
  prompter: Any
217
+ context: Any
218
+ iinput: Any
219
  """Path to the pre-trained GPT4All model file."""
220
 
221
  @root_validator()
 
286
  print("reduced tokens from %d -> %d" % (num_prompt_tokens, num_prompt_tokens2), flush=True)
287
 
288
  # use instruct prompting
289
+ data_point = dict(context=self.context, instruction=prompt, input=self.iinput)
290
  prompt = self.prompter.generate_prompt(data_point)
291
 
292
  if verbose:
gpt_langchain.py CHANGED
@@ -24,8 +24,8 @@ from langchain.embeddings import HuggingFaceInstructEmbeddings
24
  from langchain.schema import LLMResult
25
  from tqdm import tqdm
26
 
27
- from enums import DocumentChoices, no_lora_str, model_token_mapping, source_prefix, source_postfix, non_query_commands, \
28
- LangChainAction, LangChainMode
29
  from evaluate_params import gen_hyper
30
  from gen import get_model, SEED
31
  from prompter import non_hf_types, PromptType, Prompter
@@ -96,11 +96,15 @@ def get_db(sources, use_openai_embedding=False, db_type='faiss',
96
  db = get_existing_db(None, persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
97
  hf_embedding_model, verbose=False)
98
  if db is None:
 
 
 
 
99
  db = Chroma.from_documents(documents=sources,
100
  embedding=embedding,
101
  persist_directory=persist_directory,
102
  collection_name=collection_name,
103
- anonymized_telemetry=False)
104
  db.persist()
105
  clear_embedding(db)
106
  save_embed(db, use_openai_embedding, hf_embedding_model)
@@ -305,6 +309,8 @@ class GradioInference(LLM):
305
  sanitize_bot_response: bool = False
306
 
307
  prompter: Any = None
 
 
308
  client: Any = None
309
 
310
  class Config:
@@ -348,14 +354,15 @@ class GradioInference(LLM):
348
  stream_output = self.stream
349
  gr_client = self.client
350
  client_langchain_mode = 'Disabled'
 
351
  client_langchain_action = LangChainAction.QUERY.value
352
  client_langchain_agents = []
353
  top_k_docs = 1
354
  chunk = True
355
  chunk_size = 512
356
  client_kwargs = dict(instruction=prompt if self.chat_client else '', # only for chat=True
357
- iinput='', # only for chat=True
358
- context='',
359
  # streaming output is supported, loops over and outputs each generation in streaming mode
360
  # but leave stream_output=False for simple input/output mode
361
  stream_output=stream_output,
@@ -376,15 +383,16 @@ class GradioInference(LLM):
376
  chat=self.chat_client,
377
 
378
  instruction_nochat=prompt if not self.chat_client else '',
379
- iinput_nochat='', # only for chat=False
380
  langchain_mode=client_langchain_mode,
 
381
  langchain_action=client_langchain_action,
382
  langchain_agents=client_langchain_agents,
383
  top_k_docs=top_k_docs,
384
  chunk=chunk,
385
  chunk_size=chunk_size,
386
- document_subset=DocumentChoices.Relevant.name,
387
- document_choice=[],
388
  )
389
  api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing
390
  if not stream_output:
@@ -454,6 +462,8 @@ class H2OHuggingFaceTextGenInference(HuggingFaceTextGenInference):
454
  stream: bool = False
455
  sanitize_bot_response: bool = False
456
  prompter: Any = None
 
 
457
  tokenizer: Any = None
458
  client: Any = None
459
 
@@ -495,7 +505,7 @@ class H2OHuggingFaceTextGenInference(HuggingFaceTextGenInference):
495
  prompt, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt, self.tokenizer)
496
 
497
  # NOTE: TGI server does not add prompting, so must do here
498
- data_point = dict(context='', instruction=prompt, input='')
499
  prompt = self.prompter.generate_prompt(data_point)
500
 
501
  gen_server_kwargs = dict(do_sample=self.do_sample,
@@ -574,6 +584,8 @@ class H2OOpenAI(OpenAI):
574
  stop_sequences: Any = None
575
  sanitize_bot_response: bool = False
576
  prompter: Any = None
 
 
577
  tokenizer: Any = None
578
 
579
  @classmethod
@@ -599,7 +611,7 @@ class H2OOpenAI(OpenAI):
599
  for prompti, prompt in enumerate(prompts):
600
  prompt, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt, self.tokenizer)
601
  # NOTE: OpenAI/vLLM server does not add prompting, so must do here
602
- data_point = dict(context='', instruction=prompt, input='')
603
  prompt = self.prompter.generate_prompt(data_point)
604
  prompts[prompti] = prompt
605
 
@@ -677,17 +689,22 @@ def get_llm(use_openai_model=False,
677
  prompt_type=None,
678
  prompt_dict=None,
679
  prompter=None,
 
 
680
  sanitize_bot_response=False,
681
  verbose=False,
682
  ):
 
 
683
  if use_openai_model or inference_server.startswith('openai') or inference_server.startswith('vllm'):
684
  if use_openai_model and model_name is None:
685
  model_name = "gpt-3.5-turbo"
686
- openai, inf_type = set_openai(
687
- inference_server) # FIXME: Will later import be ignored? I think so, so should be fine
688
  kwargs_extra = {}
689
  if inference_server == 'openai_chat' or inf_type == 'vllm_chat':
690
  cls = H2OChatOpenAI
 
691
  else:
692
  cls = H2OOpenAI
693
  if inf_type == 'vllm':
@@ -697,6 +714,8 @@ def get_llm(use_openai_model=False,
697
  kwargs_extra = dict(stop_sequences=stop_sequences,
698
  sanitize_bot_response=sanitize_bot_response,
699
  prompter=prompter,
 
 
700
  tokenizer=tokenizer,
701
  client=None)
702
 
@@ -711,7 +730,7 @@ def get_llm(use_openai_model=False,
711
  callbacks=callbacks if stream_output else None,
712
  openai_api_key=openai.api_key,
713
  openai_api_base=openai.api_base,
714
- logit_bias=None if inf_type =='vllm' else {},
715
  max_retries=2,
716
  streaming=stream_output,
717
  **kwargs_extra
@@ -769,6 +788,8 @@ def get_llm(use_openai_model=False,
769
  callbacks=callbacks if stream_output else None,
770
  stream=stream_output,
771
  prompter=prompter,
 
 
772
  client=gr_client,
773
  sanitize_bot_response=sanitize_bot_response,
774
  )
@@ -789,6 +810,8 @@ def get_llm(use_openai_model=False,
789
  callbacks=callbacks if stream_output else None,
790
  stream=stream_output,
791
  prompter=prompter,
 
 
792
  tokenizer=tokenizer,
793
  client=hf_client,
794
  timeout=max_time,
@@ -821,6 +844,8 @@ def get_llm(use_openai_model=False,
821
  verbose=verbose,
822
  streaming=stream_output,
823
  prompter=prompter,
 
 
824
  )
825
  else:
826
  if model is None:
@@ -863,6 +888,8 @@ def get_llm(use_openai_model=False,
863
  from h2oai_pipeline import H2OTextGenerationPipeline
864
  pipe = H2OTextGenerationPipeline(model=model, use_prompter=True,
865
  prompter=prompter,
 
 
866
  prompt_type=prompt_type,
867
  prompt_dict=prompt_dict,
868
  sanitize_bot_response=sanitize_bot_response,
@@ -1048,7 +1075,7 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
1048
  is_url=False, is_txt=False,
1049
  enable_captions=True,
1050
  captions_model=None,
1051
- enable_ocr=False, caption_loader=None,
1052
  headsize=50):
1053
  if file is None:
1054
  if fail_any_exception:
@@ -1065,6 +1092,7 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
1065
  base_name = sanitize_filename(base_name) + "_" + str(uuid.uuid4())[:10]
1066
  base_path = os.path.join(dir_name, base_name)
1067
  if is_url:
 
1068
  if file.lower().startswith('arxiv:'):
1069
  query = file.lower().split('arxiv:')
1070
  if len(query) == 2 and have_arxiv:
@@ -1216,21 +1244,54 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
1216
  from dotenv import dotenv_values
1217
  env_kwargs = dotenv_values(env_gpt4all_file)
1218
  pdf_class_name = env_kwargs.get('PDF_CLASS_NAME', 'PyMuPDFParser')
 
 
1219
  if have_pymupdf and pdf_class_name == 'PyMuPDFParser':
1220
  # GPL, only use if installed
1221
  from langchain.document_loaders import PyMuPDFLoader
1222
  # load() still chunks by pages, but every page has title at start to help
1223
  doc1 = PyMuPDFLoader(file).load()
 
 
 
1224
  doc1 = clean_doc(doc1)
1225
- elif pdf_class_name == 'UnstructuredPDFLoader':
1226
  doc1 = UnstructuredPDFLoader(file).load()
 
 
 
1227
  # seems to not need cleaning in most cases
1228
- else:
1229
  # open-source fallback
1230
  # load() still chunks by pages, but every page has title at start to help
1231
  doc1 = PyPDFLoader(file).load()
 
 
 
1232
  doc1 = clean_doc(doc1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1233
  # Some PDFs return nothing or junk from PDFMinerLoader
 
 
 
 
 
 
1234
  doc1 = chunk_sources(doc1, chunk=chunk, chunk_size=chunk_size)
1235
  add_meta(doc1, file)
1236
  elif file.lower().endswith('.csv'):
@@ -1283,7 +1344,7 @@ def path_to_doc1(file, verbose=False, fail_any_exception=False, return_file=True
1283
  is_url=False, is_txt=False,
1284
  enable_captions=True,
1285
  captions_model=None,
1286
- enable_ocr=False, caption_loader=None):
1287
  if verbose:
1288
  if is_url:
1289
  print("Ingesting URL: %s" % file, flush=True)
@@ -1301,6 +1362,7 @@ def path_to_doc1(file, verbose=False, fail_any_exception=False, return_file=True
1301
  enable_captions=enable_captions,
1302
  captions_model=captions_model,
1303
  enable_ocr=enable_ocr,
 
1304
  caption_loader=caption_loader)
1305
  except BaseException as e:
1306
  print("Failed to ingest %s due to %s" % (file, traceback.format_exc()))
@@ -1309,7 +1371,7 @@ def path_to_doc1(file, verbose=False, fail_any_exception=False, return_file=True
1309
  else:
1310
  exception_doc = Document(
1311
  page_content='',
1312
- metadata={"source": file, "exception": '%s hit %s' % (file, str(e)),
1313
  "traceback": traceback.format_exc()})
1314
  res = [exception_doc]
1315
  if return_file:
@@ -1330,6 +1392,7 @@ def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=
1330
  captions_model=None,
1331
  caption_loader=None,
1332
  enable_ocr=False,
 
1333
  existing_files=[],
1334
  existing_hash_ids={},
1335
  ):
@@ -1351,11 +1414,15 @@ def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=
1351
  [globs_non_image_types.extend(glob.glob(os.path.join(path, "./**/*.%s" % ftype), recursive=True))
1352
  for ftype in non_image_types]
1353
  else:
1354
- if isinstance(path_or_paths, str) and (os.path.isfile(path_or_paths) or os.path.isdir(path_or_paths)):
1355
- path_or_paths = [path_or_paths]
 
 
 
 
1356
  # list/tuple of files (consume what can, and exception those that selected but cannot consume so user knows)
1357
- assert isinstance(path_or_paths, (list, tuple, types.GeneratorType)), "Wrong type for path_or_paths: %s" % type(
1358
- path_or_paths)
1359
  # reform out of allowed types
1360
  globs_image_types.extend(flatten_list([[x for x in path_or_paths if x.endswith(y)] for y in image_types]))
1361
  # could do below:
@@ -1407,6 +1474,7 @@ def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=
1407
  captions_model=captions_model,
1408
  caption_loader=caption_loader,
1409
  enable_ocr=enable_ocr,
 
1410
  )
1411
 
1412
  if n_jobs != 1 and len(globs_non_image_types) > 1:
@@ -1439,7 +1507,7 @@ def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=
1439
  with open(fil, 'rb') as f:
1440
  documents.extend(pickle.load(f))
1441
  # remove temp pickle
1442
- os.remove(fil)
1443
  else:
1444
  documents = reduce(concat, documents)
1445
  return documents
@@ -1447,7 +1515,7 @@ def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=
1447
 
1448
  def prep_langchain(persist_directory,
1449
  load_db_if_exists,
1450
- db_type, use_openai_embedding, langchain_mode, user_path,
1451
  hf_embedding_model, n_jobs=-1, kwargs_make_db={}):
1452
  """
1453
  do prep first time, involving downloads
@@ -1457,6 +1525,7 @@ def prep_langchain(persist_directory,
1457
  assert langchain_mode not in ['MyData'], "Should not prep scratch data"
1458
 
1459
  db_dir_exists = os.path.isdir(persist_directory)
 
1460
 
1461
  if db_dir_exists and user_path is None:
1462
  print("Prep: persist_directory=%s exists, using" % persist_directory, flush=True)
@@ -1592,7 +1661,7 @@ def make_db(**langchain_kwargs):
1592
  langchain_kwargs[k] = defaults_db[k]
1593
  # final check for missing
1594
  missing_kwargs = [x for x in func_names if x not in langchain_kwargs]
1595
- assert not missing_kwargs, "Missing kwargs: %s" % missing_kwargs
1596
  # only keep actual used
1597
  langchain_kwargs = {k: v for k, v in langchain_kwargs.items() if k in func_names}
1598
  return _make_db(**langchain_kwargs)
@@ -1626,13 +1695,14 @@ def _make_db(use_openai_embedding=False,
1626
  first_para=False, text_limit=None,
1627
  chunk=True, chunk_size=512,
1628
  langchain_mode=None,
1629
- user_path=None,
1630
  db_type='faiss',
1631
  load_db_if_exists=True,
1632
  db=None,
1633
  n_jobs=-1,
1634
  verbose=False):
1635
  persist_directory = get_persist_directory(langchain_mode)
 
1636
  # see if can get persistent chroma db
1637
  db_trial = get_existing_db(db, persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
1638
  hf_embedding_model, verbose=verbose)
@@ -1640,23 +1710,8 @@ def _make_db(use_openai_embedding=False,
1640
  db = db_trial
1641
 
1642
  sources = []
1643
- if not db and langchain_mode not in ['MyData'] or \
1644
- user_path is not None and \
1645
- langchain_mode in ['UserData']:
1646
- # Should not make MyData db this way, why avoided, only upload from UI
1647
- assert langchain_mode not in ['MyData'], "Should not make MyData db this way"
1648
- if verbose:
1649
- if langchain_mode in ['UserData']:
1650
- if user_path is not None:
1651
- print("Checking if changed or new sources in %s, and generating sources them" % user_path,
1652
- flush=True)
1653
- elif db is None:
1654
- print("user_path not passed and no db, no sources", flush=True)
1655
- else:
1656
- print("user_path not passed, using only existing db, no new sources", flush=True)
1657
- else:
1658
- print("Generating %s sources" % langchain_mode, flush=True)
1659
- if langchain_mode in ['wiki_full', 'All', "'All'"]:
1660
  from read_wiki_full import get_all_documents
1661
  small_test = None
1662
  print("Generating new wiki", flush=True)
@@ -1666,55 +1721,48 @@ def _make_db(use_openai_embedding=False,
1666
  sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
1667
  print("Chunked new wiki", flush=True)
1668
  sources.extend(sources1)
1669
- if langchain_mode in ['wiki', 'All', "'All'"]:
1670
  sources1 = get_wiki_sources(first_para=first_para, text_limit=text_limit)
1671
  if chunk:
1672
  sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
1673
  sources.extend(sources1)
1674
- if langchain_mode in ['github h2oGPT', 'All', "'All'"]:
1675
  # sources = get_github_docs("dagster-io", "dagster")
1676
  sources1 = get_github_docs("h2oai", "h2ogpt")
1677
  # FIXME: always chunk for now
1678
  sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
1679
  sources.extend(sources1)
1680
- if langchain_mode in ['DriverlessAI docs', 'All', "'All'"]:
1681
  sources1 = get_dai_docs(from_hf=True)
1682
  if chunk and False: # FIXME: DAI docs are already chunked well, should only chunk more if over limit
1683
  sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
1684
  sources.extend(sources1)
1685
- if langchain_mode in ['All', 'UserData']:
1686
- if user_path:
1687
- if db is not None:
1688
- # NOTE: Ignore file names for now, only go by hash ids
1689
- # existing_files = get_existing_files(db)
1690
- existing_files = []
1691
- existing_hash_ids = get_existing_hash_ids(db)
1692
- else:
1693
- # pretend no existing files so won't filter
1694
- existing_files = []
1695
- existing_hash_ids = []
1696
- # chunk internally for speed over multiple docs
1697
- # FIXME: If first had old Hash=None and switch embeddings,
1698
- # then re-embed, and then hit here and reload so have hash, and then re-embed.
1699
- sources1 = path_to_docs(user_path, n_jobs=n_jobs, chunk=chunk, chunk_size=chunk_size,
1700
- existing_files=existing_files, existing_hash_ids=existing_hash_ids)
1701
- new_metadata_sources = set([x.metadata['source'] for x in sources1])
1702
- if new_metadata_sources:
1703
- print("Loaded %s new files as sources to add to UserData" % len(new_metadata_sources), flush=True)
1704
- if verbose:
1705
- print("Files added: %s" % '\n'.join(new_metadata_sources), flush=True)
1706
- sources.extend(sources1)
1707
- print("Loaded %s sources for potentially adding to UserData" % len(sources), flush=True)
1708
- else:
1709
- print("Chose UserData but user_path is empty/None", flush=True)
1710
- if False and langchain_mode in ['urls', 'All', "'All'"]:
1711
- # from langchain.document_loaders import UnstructuredURLLoader
1712
- # loader = UnstructuredURLLoader(urls=urls)
1713
- urls = ["https://www.birdsongsf.com/who-we-are/"]
1714
- from langchain.document_loaders import PlaywrightURLLoader
1715
- loader = PlaywrightURLLoader(urls=urls, remove_selectors=["header", "footer"])
1716
- sources1 = loader.load()
1717
- sources.extend(sources1)
1718
  if not sources:
1719
  if verbose:
1720
  if db is not None:
@@ -1737,7 +1785,7 @@ def _make_db(use_openai_embedding=False,
1737
  else:
1738
  print("Did not generate db since no sources", flush=True)
1739
  new_sources_metadata = [x.metadata for x in sources]
1740
- elif user_path is not None and langchain_mode in ['UserData']:
1741
  print("Existing db, potentially adding %s sources from user_path=%s" % (len(sources), user_path), flush=True)
1742
  db, num_new_sources, new_sources_metadata = add_to_db(db, sources, db_type=db_type,
1743
  use_openai_embedding=use_openai_embedding,
@@ -1835,7 +1883,7 @@ def run_qa_db(**kwargs):
1835
  kwargs['answer_with_sources'] = True
1836
  kwargs['show_rank'] = False
1837
  missing_kwargs = [x for x in func_names if x not in kwargs]
1838
- assert not missing_kwargs, "Missing kwargs: %s" % missing_kwargs
1839
  # only keep actual used
1840
  kwargs = {k: v for k, v in kwargs.items() if k in func_names}
1841
  try:
@@ -1849,7 +1897,7 @@ def _run_qa_db(query=None,
1849
  context=None,
1850
  use_openai_model=False, use_openai_embedding=False,
1851
  first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512,
1852
- user_path=None,
1853
  detect_user_path_changes_every_query=False,
1854
  db_type='faiss',
1855
  model_name=None, model=None, tokenizer=None, inference_server=None,
@@ -1859,7 +1907,8 @@ def _run_qa_db(query=None,
1859
  prompt_type=None,
1860
  prompt_dict=None,
1861
  answer_with_sources=True,
1862
- cut_distanct=1.1,
 
1863
  sanitize_bot_response=False,
1864
  show_rank=False,
1865
  use_llm_if_no_docs=False,
@@ -1879,8 +1928,8 @@ def _run_qa_db(query=None,
1879
  langchain_mode=None,
1880
  langchain_action=None,
1881
  langchain_agents=None,
1882
- document_subset=DocumentChoices.Relevant.name,
1883
- document_choice=[],
1884
  n_jobs=-1,
1885
  verbose=False,
1886
  cli=False,
@@ -1899,7 +1948,7 @@ def _run_qa_db(query=None,
1899
  :param top_k_docs:
1900
  :param chunk:
1901
  :param chunk_size:
1902
- :param user_path: user path to glob recursively from
1903
  :param db_type: 'faiss' for in-memory db or 'chroma' or 'weaviate' for persistent db
1904
  :param model_name: model name, used to switch behaviors
1905
  :param model: pre-initialized model, else will make new one
@@ -1907,6 +1956,7 @@ def _run_qa_db(query=None,
1907
  :param answer_with_sources
1908
  :return:
1909
  """
 
1910
  if model is not None:
1911
  assert model_name is not None # require so can make decisions
1912
  assert query is not None
@@ -1921,6 +1971,8 @@ def _run_qa_db(query=None,
1921
  else:
1922
  prompt_dict = ''
1923
  assert len(set(gen_hyper).difference(inspect.signature(get_llm).parameters)) == 0
 
 
1924
  llm, model_name, streamer, prompt_type_out = get_llm(use_openai_model=use_openai_model, model_name=model_name,
1925
  model=model,
1926
  tokenizer=tokenizer,
@@ -1940,11 +1992,13 @@ def _run_qa_db(query=None,
1940
  prompt_type=prompt_type,
1941
  prompt_dict=prompt_dict,
1942
  prompter=prompter,
 
 
1943
  sanitize_bot_response=sanitize_bot_response,
1944
  verbose=verbose,
1945
  )
1946
 
1947
- use_context = False
1948
  scores = []
1949
  chain = None
1950
 
@@ -1956,9 +2010,13 @@ def _run_qa_db(query=None,
1956
  sim_kwargs = {k: v for k, v in locals().items() if k in func_names}
1957
  missing_kwargs = [x for x in func_names if x not in sim_kwargs]
1958
  assert not missing_kwargs, "Missing: %s" % missing_kwargs
1959
- docs, chain, scores, use_context, have_any_docs = get_chain(**sim_kwargs)
1960
  if document_subset in non_query_commands:
1961
  formatted_doc_chunks = '\n\n'.join([get_url(x) + '\n\n' + x.page_content for x in docs])
 
 
 
 
1962
  yield formatted_doc_chunks, ''
1963
  return
1964
  if not use_llm_if_no_docs:
@@ -1970,7 +2028,6 @@ def _run_qa_db(query=None,
1970
  yield ret, extra
1971
  return
1972
  if not docs and langchain_mode not in [LangChainMode.DISABLED.value,
1973
- LangChainMode.CHAT_LLM.value,
1974
  LangChainMode.LLM.value]:
1975
  ret = 'No relevant documents to query.' if have_any_docs else 'No documents to query.'
1976
  extra = ''
@@ -2026,7 +2083,7 @@ def _run_qa_db(query=None,
2026
  else:
2027
  answer = chain()
2028
 
2029
- if not use_context:
2030
  ret = answer['output_text']
2031
  extra = ''
2032
  yield ret, extra
@@ -2038,9 +2095,10 @@ def _run_qa_db(query=None,
2038
 
2039
  def get_chain(query=None,
2040
  iinput=None,
 
2041
  use_openai_model=False, use_openai_embedding=False,
2042
  first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512,
2043
- user_path=None,
2044
  detect_user_path_changes_every_query=False,
2045
  db_type='faiss',
2046
  model_name=None,
@@ -2048,14 +2106,15 @@ def get_chain(query=None,
2048
  hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
2049
  prompt_type=None,
2050
  prompt_dict=None,
2051
- cut_distanct=1.1,
 
2052
  load_db_if_exists=False,
2053
  db=None,
2054
  langchain_mode=None,
2055
  langchain_action=None,
2056
  langchain_agents=None,
2057
- document_subset=DocumentChoices.Relevant.name,
2058
- document_choice=[],
2059
  n_jobs=-1,
2060
  # beyond run_db_query:
2061
  llm=None,
@@ -2070,12 +2129,12 @@ def get_chain(query=None,
2070
  assert langchain_agents is not None # should be at least []
2071
  # determine whether use of context out of docs is planned
2072
  if not use_openai_model and prompt_type not in ['plain'] or model_name in non_hf_types:
2073
- if langchain_mode in ['Disabled', 'ChatLLM', 'LLM']:
2074
- use_context = False
2075
  else:
2076
- use_context = True
2077
  else:
2078
- use_context = True
2079
 
2080
  # https://github.com/hwchase17/langchain/issues/1946
2081
  # FIXME: Seems to way to get size of chroma db to limit top_k_docs to avoid
@@ -2092,14 +2151,17 @@ def get_chain(query=None,
2092
  # avoid looking at user_path during similarity search db handling,
2093
  # if already have db and not updating from user_path every query
2094
  # but if db is None, no db yet loaded (e.g. from prep), so allow user_path to be whatever it was
2095
- user_path = None
 
 
 
2096
  db, num_new_sources, new_sources_metadata = make_db(use_openai_embedding=use_openai_embedding,
2097
  hf_embedding_model=hf_embedding_model,
2098
  first_para=first_para, text_limit=text_limit,
2099
  chunk=chunk,
2100
  chunk_size=chunk_size,
2101
  langchain_mode=langchain_mode,
2102
- user_path=user_path,
2103
  db_type=db_type,
2104
  load_db_if_exists=load_db_if_exists,
2105
  db=db,
@@ -2119,7 +2181,7 @@ def get_chain(query=None,
2119
  else:
2120
  extra = ""
2121
  prefix = ""
2122
- if langchain_mode in ['Disabled', 'ChatLLM', 'LLM'] or not use_context:
2123
  template_if_no_docs = template = """%s{context}{question}""" % prefix
2124
  else:
2125
  template = """%s
@@ -2160,7 +2222,7 @@ def get_chain(query=None,
2160
  else:
2161
  use_template = False
2162
 
2163
- if db and use_context:
2164
  base_path = 'locks'
2165
  makedirs(base_path)
2166
  if hasattr(db, '_persist_directory'):
@@ -2174,10 +2236,10 @@ def get_chain(query=None,
2174
  filter_kwargs = {}
2175
  else:
2176
  assert document_choice is not None, "Document choice was None"
2177
- if len(document_choice) >= 1 and document_choice[0] == DocumentChoices.All.name:
2178
  filter_kwargs = {}
2179
  elif len(document_choice) >= 2:
2180
- if document_choice[0] == DocumentChoices.All.name:
2181
  # remove 'All'
2182
  document_choice = document_choice[1:]
2183
  or_filter = [{"source": {"$eq": x}} for x in document_choice]
@@ -2189,10 +2251,10 @@ def get_chain(query=None,
2189
  else:
2190
  # shouldn't reach
2191
  filter_kwargs = {}
2192
- if langchain_mode in [LangChainMode.LLM.value, LangChainMode.CHAT_LLM.value]:
2193
  docs = []
2194
  scores = []
2195
- elif document_subset == DocumentChoices.All.name or query in [None, '', '\n']:
2196
  db_documents, db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs)
2197
  # similar to langchain's chroma's _results_to_docs_and_scores
2198
  docs_with_score = [(Document(page_content=result[0], metadata=result[1] or {}), 0)
@@ -2280,8 +2342,8 @@ def get_chain(query=None,
2280
  docs_with_score.reverse()
2281
  # cut off so no high distance docs/sources considered
2282
  have_any_docs |= len(docs_with_score) > 0 # before cut
2283
- docs = [x[0] for x in docs_with_score if x[1] < cut_distanct]
2284
- scores = [x[1] for x in docs_with_score if x[1] < cut_distanct]
2285
  if len(scores) > 0 and verbose:
2286
  print("Distance: min: %s max: %s mean: %s median: %s" %
2287
  (scores[0], scores[-1], np.mean(scores), np.median(scores)), flush=True)
@@ -2289,7 +2351,7 @@ def get_chain(query=None,
2289
  docs = []
2290
  scores = []
2291
 
2292
- if not docs and use_context and model_name not in non_hf_types:
2293
  # if HF type and have no docs, can bail out
2294
  return docs, None, [], False, have_any_docs
2295
 
@@ -2312,7 +2374,7 @@ def get_chain(query=None,
2312
 
2313
  if len(docs) == 0:
2314
  # avoid context == in prompt then
2315
- use_context = False
2316
  template = template_if_no_docs
2317
 
2318
  if langchain_action == LangChainAction.QUERY.value:
@@ -2328,7 +2390,7 @@ def get_chain(query=None,
2328
  else:
2329
  # only if use_openai_model = True, unused normally except in testing
2330
  chain = load_qa_with_sources_chain(llm)
2331
- if not use_context:
2332
  chain_kwargs = dict(input_documents=[], question=query)
2333
  else:
2334
  chain_kwargs = dict(input_documents=docs, question=query)
@@ -2355,7 +2417,7 @@ def get_chain(query=None,
2355
  else:
2356
  raise RuntimeError("No such langchain_action=%s" % langchain_action)
2357
 
2358
- return docs, target, scores, use_context, have_any_docs
2359
 
2360
 
2361
  def get_sources_answer(query, answer, scores, show_rank, answer_with_sources, verbose=False):
 
24
  from langchain.schema import LLMResult
25
  from tqdm import tqdm
26
 
27
+ from enums import DocumentSubset, no_lora_str, model_token_mapping, source_prefix, source_postfix, non_query_commands, \
28
+ LangChainAction, LangChainMode, DocumentChoice
29
  from evaluate_params import gen_hyper
30
  from gen import get_model, SEED
31
  from prompter import non_hf_types, PromptType, Prompter
 
96
  db = get_existing_db(None, persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
97
  hf_embedding_model, verbose=False)
98
  if db is None:
99
+ from chromadb.config import Settings
100
+ client_settings = Settings(anonymized_telemetry=False,
101
+ chroma_db_impl="duckdb+parquet",
102
+ persist_directory=persist_directory)
103
  db = Chroma.from_documents(documents=sources,
104
  embedding=embedding,
105
  persist_directory=persist_directory,
106
  collection_name=collection_name,
107
+ client_settings=client_settings)
108
  db.persist()
109
  clear_embedding(db)
110
  save_embed(db, use_openai_embedding, hf_embedding_model)
 
309
  sanitize_bot_response: bool = False
310
 
311
  prompter: Any = None
312
+ context: Any = ''
313
+ iinput: Any = ''
314
  client: Any = None
315
 
316
  class Config:
 
354
  stream_output = self.stream
355
  gr_client = self.client
356
  client_langchain_mode = 'Disabled'
357
+ client_add_chat_history_to_context = True
358
  client_langchain_action = LangChainAction.QUERY.value
359
  client_langchain_agents = []
360
  top_k_docs = 1
361
  chunk = True
362
  chunk_size = 512
363
  client_kwargs = dict(instruction=prompt if self.chat_client else '', # only for chat=True
364
+ iinput=self.iinput if self.chat_client else '', # only for chat=True
365
+ context=self.context,
366
  # streaming output is supported, loops over and outputs each generation in streaming mode
367
  # but leave stream_output=False for simple input/output mode
368
  stream_output=stream_output,
 
383
  chat=self.chat_client,
384
 
385
  instruction_nochat=prompt if not self.chat_client else '',
386
+ iinput_nochat=self.iinput if not self.chat_client else '',
387
  langchain_mode=client_langchain_mode,
388
+ add_chat_history_to_context=client_add_chat_history_to_context,
389
  langchain_action=client_langchain_action,
390
  langchain_agents=client_langchain_agents,
391
  top_k_docs=top_k_docs,
392
  chunk=chunk,
393
  chunk_size=chunk_size,
394
+ document_subset=DocumentSubset.Relevant.name,
395
+ document_choice=[DocumentChoice.ALL.value],
396
  )
397
  api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing
398
  if not stream_output:
 
462
  stream: bool = False
463
  sanitize_bot_response: bool = False
464
  prompter: Any = None
465
+ context: Any = ''
466
+ iinput: Any = ''
467
  tokenizer: Any = None
468
  client: Any = None
469
 
 
505
  prompt, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt, self.tokenizer)
506
 
507
  # NOTE: TGI server does not add prompting, so must do here
508
+ data_point = dict(context=self.context, instruction=prompt, input=self.iinput)
509
  prompt = self.prompter.generate_prompt(data_point)
510
 
511
  gen_server_kwargs = dict(do_sample=self.do_sample,
 
584
  stop_sequences: Any = None
585
  sanitize_bot_response: bool = False
586
  prompter: Any = None
587
+ context: Any = ''
588
+ iinput: Any = ''
589
  tokenizer: Any = None
590
 
591
  @classmethod
 
611
  for prompti, prompt in enumerate(prompts):
612
  prompt, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt, self.tokenizer)
613
  # NOTE: OpenAI/vLLM server does not add prompting, so must do here
614
+ data_point = dict(context=self.context, instruction=prompt, input=self.iinput)
615
  prompt = self.prompter.generate_prompt(data_point)
616
  prompts[prompti] = prompt
617
 
 
689
  prompt_type=None,
690
  prompt_dict=None,
691
  prompter=None,
692
+ context=None,
693
+ iinput=None,
694
  sanitize_bot_response=False,
695
  verbose=False,
696
  ):
697
+ if inference_server is None:
698
+ inference_server = ''
699
  if use_openai_model or inference_server.startswith('openai') or inference_server.startswith('vllm'):
700
  if use_openai_model and model_name is None:
701
  model_name = "gpt-3.5-turbo"
702
+ # FIXME: Will later import be ignored? I think so, so should be fine
703
+ openai, inf_type = set_openai(inference_server)
704
  kwargs_extra = {}
705
  if inference_server == 'openai_chat' or inf_type == 'vllm_chat':
706
  cls = H2OChatOpenAI
707
+ # FIXME: Support context, iinput
708
  else:
709
  cls = H2OOpenAI
710
  if inf_type == 'vllm':
 
714
  kwargs_extra = dict(stop_sequences=stop_sequences,
715
  sanitize_bot_response=sanitize_bot_response,
716
  prompter=prompter,
717
+ context=context,
718
+ iinput=iinput,
719
  tokenizer=tokenizer,
720
  client=None)
721
 
 
730
  callbacks=callbacks if stream_output else None,
731
  openai_api_key=openai.api_key,
732
  openai_api_base=openai.api_base,
733
+ logit_bias=None if inf_type == 'vllm' else {},
734
  max_retries=2,
735
  streaming=stream_output,
736
  **kwargs_extra
 
788
  callbacks=callbacks if stream_output else None,
789
  stream=stream_output,
790
  prompter=prompter,
791
+ context=context,
792
+ iinput=iinput,
793
  client=gr_client,
794
  sanitize_bot_response=sanitize_bot_response,
795
  )
 
810
  callbacks=callbacks if stream_output else None,
811
  stream=stream_output,
812
  prompter=prompter,
813
+ context=context,
814
+ iinput=iinput,
815
  tokenizer=tokenizer,
816
  client=hf_client,
817
  timeout=max_time,
 
844
  verbose=verbose,
845
  streaming=stream_output,
846
  prompter=prompter,
847
+ context=context,
848
+ iinput=iinput,
849
  )
850
  else:
851
  if model is None:
 
888
  from h2oai_pipeline import H2OTextGenerationPipeline
889
  pipe = H2OTextGenerationPipeline(model=model, use_prompter=True,
890
  prompter=prompter,
891
+ context=context,
892
+ iinpout=iinput,
893
  prompt_type=prompt_type,
894
  prompt_dict=prompt_dict,
895
  sanitize_bot_response=sanitize_bot_response,
 
1075
  is_url=False, is_txt=False,
1076
  enable_captions=True,
1077
  captions_model=None,
1078
+ enable_ocr=False, enable_pdf_ocr='auto', caption_loader=None,
1079
  headsize=50):
1080
  if file is None:
1081
  if fail_any_exception:
 
1092
  base_name = sanitize_filename(base_name) + "_" + str(uuid.uuid4())[:10]
1093
  base_path = os.path.join(dir_name, base_name)
1094
  if is_url:
1095
+ file = file.strip() # in case accidental spaces in front or at end
1096
  if file.lower().startswith('arxiv:'):
1097
  query = file.lower().split('arxiv:')
1098
  if len(query) == 2 and have_arxiv:
 
1244
  from dotenv import dotenv_values
1245
  env_kwargs = dotenv_values(env_gpt4all_file)
1246
  pdf_class_name = env_kwargs.get('PDF_CLASS_NAME', 'PyMuPDFParser')
1247
+ doc1 = []
1248
+ handled = False
1249
  if have_pymupdf and pdf_class_name == 'PyMuPDFParser':
1250
  # GPL, only use if installed
1251
  from langchain.document_loaders import PyMuPDFLoader
1252
  # load() still chunks by pages, but every page has title at start to help
1253
  doc1 = PyMuPDFLoader(file).load()
1254
+ # remove empty documents
1255
+ handled |= len(doc1) > 0
1256
+ doc1 = [x for x in doc1 if x.page_content]
1257
  doc1 = clean_doc(doc1)
1258
+ if len(doc1) == 0:
1259
  doc1 = UnstructuredPDFLoader(file).load()
1260
+ handled |= len(doc1) > 0
1261
+ # remove empty documents
1262
+ doc1 = [x for x in doc1 if x.page_content]
1263
  # seems to not need cleaning in most cases
1264
+ if len(doc1) == 0:
1265
  # open-source fallback
1266
  # load() still chunks by pages, but every page has title at start to help
1267
  doc1 = PyPDFLoader(file).load()
1268
+ handled |= len(doc1) > 0
1269
+ # remove empty documents
1270
+ doc1 = [x for x in doc1 if x.page_content]
1271
  doc1 = clean_doc(doc1)
1272
+ if have_pymupdf and len(doc1) == 0:
1273
+ # GPL, only use if installed
1274
+ from langchain.document_loaders import PyMuPDFLoader
1275
+ # load() still chunks by pages, but every page has title at start to help
1276
+ doc1 = PyMuPDFLoader(file).load()
1277
+ handled |= len(doc1) > 0
1278
+ # remove empty documents
1279
+ doc1 = [x for x in doc1 if x.page_content]
1280
+ doc1 = clean_doc(doc1)
1281
+ if len(doc1) == 0 and enable_pdf_ocr == 'auto' or enable_pdf_ocr == 'on':
1282
+ # try OCR in end since slowest, but works on pure image pages well
1283
+ doc1 = UnstructuredPDFLoader(file, strategy='ocr_only').load()
1284
+ handled |= len(doc1) > 0
1285
+ # remove empty documents
1286
+ doc1 = [x for x in doc1 if x.page_content]
1287
+ # seems to not need cleaning in most cases
1288
  # Some PDFs return nothing or junk from PDFMinerLoader
1289
+ if len(doc1) == 0:
1290
+ # if literally nothing, show failed to parse so user knows, since unlikely nothing in PDF at all.
1291
+ if handled:
1292
+ raise ValueError("%s had no valid text, but meta data was parsed" % file)
1293
+ else:
1294
+ raise ValueError("%s had no valid text and no meta data was parsed" % file)
1295
  doc1 = chunk_sources(doc1, chunk=chunk, chunk_size=chunk_size)
1296
  add_meta(doc1, file)
1297
  elif file.lower().endswith('.csv'):
 
1344
  is_url=False, is_txt=False,
1345
  enable_captions=True,
1346
  captions_model=None,
1347
+ enable_ocr=False, enable_pdf_ocr='auto', caption_loader=None):
1348
  if verbose:
1349
  if is_url:
1350
  print("Ingesting URL: %s" % file, flush=True)
 
1362
  enable_captions=enable_captions,
1363
  captions_model=captions_model,
1364
  enable_ocr=enable_ocr,
1365
+ enable_pdf_ocr=enable_pdf_ocr,
1366
  caption_loader=caption_loader)
1367
  except BaseException as e:
1368
  print("Failed to ingest %s due to %s" % (file, traceback.format_exc()))
 
1371
  else:
1372
  exception_doc = Document(
1373
  page_content='',
1374
+ metadata={"source": file, "exception": '%s Exception: %s' % (file, str(e)),
1375
  "traceback": traceback.format_exc()})
1376
  res = [exception_doc]
1377
  if return_file:
 
1392
  captions_model=None,
1393
  caption_loader=None,
1394
  enable_ocr=False,
1395
+ enable_pdf_ocr='auto',
1396
  existing_files=[],
1397
  existing_hash_ids={},
1398
  ):
 
1414
  [globs_non_image_types.extend(glob.glob(os.path.join(path, "./**/*.%s" % ftype), recursive=True))
1415
  for ftype in non_image_types]
1416
  else:
1417
+ if isinstance(path_or_paths, str):
1418
+ if os.path.isfile(path_or_paths) or os.path.isdir(path_or_paths):
1419
+ path_or_paths = [path_or_paths]
1420
+ else:
1421
+ # path was deleted etc.
1422
+ return []
1423
  # list/tuple of files (consume what can, and exception those that selected but cannot consume so user knows)
1424
+ assert isinstance(path_or_paths, (list, tuple, types.GeneratorType)), \
1425
+ "Wrong type for path_or_paths: %s %s" % (path_or_paths, type(path_or_paths))
1426
  # reform out of allowed types
1427
  globs_image_types.extend(flatten_list([[x for x in path_or_paths if x.endswith(y)] for y in image_types]))
1428
  # could do below:
 
1474
  captions_model=captions_model,
1475
  caption_loader=caption_loader,
1476
  enable_ocr=enable_ocr,
1477
+ enable_pdf_ocr=enable_pdf_ocr,
1478
  )
1479
 
1480
  if n_jobs != 1 and len(globs_non_image_types) > 1:
 
1507
  with open(fil, 'rb') as f:
1508
  documents.extend(pickle.load(f))
1509
  # remove temp pickle
1510
+ remove(fil)
1511
  else:
1512
  documents = reduce(concat, documents)
1513
  return documents
 
1515
 
1516
  def prep_langchain(persist_directory,
1517
  load_db_if_exists,
1518
+ db_type, use_openai_embedding, langchain_mode, langchain_mode_paths,
1519
  hf_embedding_model, n_jobs=-1, kwargs_make_db={}):
1520
  """
1521
  do prep first time, involving downloads
 
1525
  assert langchain_mode not in ['MyData'], "Should not prep scratch data"
1526
 
1527
  db_dir_exists = os.path.isdir(persist_directory)
1528
+ user_path = langchain_mode_paths.get(langchain_mode)
1529
 
1530
  if db_dir_exists and user_path is None:
1531
  print("Prep: persist_directory=%s exists, using" % persist_directory, flush=True)
 
1661
  langchain_kwargs[k] = defaults_db[k]
1662
  # final check for missing
1663
  missing_kwargs = [x for x in func_names if x not in langchain_kwargs]
1664
+ assert not missing_kwargs, "Missing kwargs for make_db: %s" % missing_kwargs
1665
  # only keep actual used
1666
  langchain_kwargs = {k: v for k, v in langchain_kwargs.items() if k in func_names}
1667
  return _make_db(**langchain_kwargs)
 
1695
  first_para=False, text_limit=None,
1696
  chunk=True, chunk_size=512,
1697
  langchain_mode=None,
1698
+ langchain_mode_paths=None,
1699
  db_type='faiss',
1700
  load_db_if_exists=True,
1701
  db=None,
1702
  n_jobs=-1,
1703
  verbose=False):
1704
  persist_directory = get_persist_directory(langchain_mode)
1705
+ user_path = langchain_mode_paths.get(langchain_mode)
1706
  # see if can get persistent chroma db
1707
  db_trial = get_existing_db(db, persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
1708
  hf_embedding_model, verbose=verbose)
 
1710
  db = db_trial
1711
 
1712
  sources = []
1713
+ if not db:
1714
+ if langchain_mode in ['wiki_full']:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1715
  from read_wiki_full import get_all_documents
1716
  small_test = None
1717
  print("Generating new wiki", flush=True)
 
1721
  sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
1722
  print("Chunked new wiki", flush=True)
1723
  sources.extend(sources1)
1724
+ elif langchain_mode in ['wiki']:
1725
  sources1 = get_wiki_sources(first_para=first_para, text_limit=text_limit)
1726
  if chunk:
1727
  sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
1728
  sources.extend(sources1)
1729
+ elif langchain_mode in ['github h2oGPT']:
1730
  # sources = get_github_docs("dagster-io", "dagster")
1731
  sources1 = get_github_docs("h2oai", "h2ogpt")
1732
  # FIXME: always chunk for now
1733
  sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
1734
  sources.extend(sources1)
1735
+ elif langchain_mode in ['DriverlessAI docs']:
1736
  sources1 = get_dai_docs(from_hf=True)
1737
  if chunk and False: # FIXME: DAI docs are already chunked well, should only chunk more if over limit
1738
  sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
1739
  sources.extend(sources1)
1740
+ if user_path:
1741
+ # UserData or custom, which has to be from user's disk
1742
+ if db is not None:
1743
+ # NOTE: Ignore file names for now, only go by hash ids
1744
+ # existing_files = get_existing_files(db)
1745
+ existing_files = []
1746
+ existing_hash_ids = get_existing_hash_ids(db)
1747
+ else:
1748
+ # pretend no existing files so won't filter
1749
+ existing_files = []
1750
+ existing_hash_ids = []
1751
+ # chunk internally for speed over multiple docs
1752
+ # FIXME: If first had old Hash=None and switch embeddings,
1753
+ # then re-embed, and then hit here and reload so have hash, and then re-embed.
1754
+ sources1 = path_to_docs(user_path, n_jobs=n_jobs, chunk=chunk, chunk_size=chunk_size,
1755
+ existing_files=existing_files, existing_hash_ids=existing_hash_ids)
1756
+ new_metadata_sources = set([x.metadata['source'] for x in sources1])
1757
+ if new_metadata_sources:
1758
+ print("Loaded %s new files as sources to add to %s" % (len(new_metadata_sources), langchain_mode),
1759
+ flush=True)
1760
+ if verbose:
1761
+ print("Files added: %s" % '\n'.join(new_metadata_sources), flush=True)
1762
+ sources.extend(sources1)
1763
+ print("Loaded %s sources for potentially adding to %s" % (len(sources), langchain_mode), flush=True)
1764
+
1765
+ # see if got sources
 
 
 
 
 
 
 
1766
  if not sources:
1767
  if verbose:
1768
  if db is not None:
 
1785
  else:
1786
  print("Did not generate db since no sources", flush=True)
1787
  new_sources_metadata = [x.metadata for x in sources]
1788
+ elif user_path is not None:
1789
  print("Existing db, potentially adding %s sources from user_path=%s" % (len(sources), user_path), flush=True)
1790
  db, num_new_sources, new_sources_metadata = add_to_db(db, sources, db_type=db_type,
1791
  use_openai_embedding=use_openai_embedding,
 
1883
  kwargs['answer_with_sources'] = True
1884
  kwargs['show_rank'] = False
1885
  missing_kwargs = [x for x in func_names if x not in kwargs]
1886
+ assert not missing_kwargs, "Missing kwargs for run_qa_db: %s" % missing_kwargs
1887
  # only keep actual used
1888
  kwargs = {k: v for k, v in kwargs.items() if k in func_names}
1889
  try:
 
1897
  context=None,
1898
  use_openai_model=False, use_openai_embedding=False,
1899
  first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512,
1900
+ langchain_mode_paths={},
1901
  detect_user_path_changes_every_query=False,
1902
  db_type='faiss',
1903
  model_name=None, model=None, tokenizer=None, inference_server=None,
 
1907
  prompt_type=None,
1908
  prompt_dict=None,
1909
  answer_with_sources=True,
1910
+ cut_distance=1.64,
1911
+ add_chat_history_to_context=True,
1912
  sanitize_bot_response=False,
1913
  show_rank=False,
1914
  use_llm_if_no_docs=False,
 
1928
  langchain_mode=None,
1929
  langchain_action=None,
1930
  langchain_agents=None,
1931
+ document_subset=DocumentSubset.Relevant.name,
1932
+ document_choice=[DocumentChoice.ALL.value],
1933
  n_jobs=-1,
1934
  verbose=False,
1935
  cli=False,
 
1948
  :param top_k_docs:
1949
  :param chunk:
1950
  :param chunk_size:
1951
+ :param langchain_mode_paths: dict of langchain_mode -> user path to glob recursively from
1952
  :param db_type: 'faiss' for in-memory db or 'chroma' or 'weaviate' for persistent db
1953
  :param model_name: model name, used to switch behaviors
1954
  :param model: pre-initialized model, else will make new one
 
1956
  :param answer_with_sources
1957
  :return:
1958
  """
1959
+ assert langchain_mode_paths is not None
1960
  if model is not None:
1961
  assert model_name is not None # require so can make decisions
1962
  assert query is not None
 
1971
  else:
1972
  prompt_dict = ''
1973
  assert len(set(gen_hyper).difference(inspect.signature(get_llm).parameters)) == 0
1974
+ # pass in context to LLM directly, since already has prompt_type structure
1975
+ # can't pass through langchain in get_chain() to LLM: https://github.com/hwchase17/langchain/issues/6638
1976
  llm, model_name, streamer, prompt_type_out = get_llm(use_openai_model=use_openai_model, model_name=model_name,
1977
  model=model,
1978
  tokenizer=tokenizer,
 
1992
  prompt_type=prompt_type,
1993
  prompt_dict=prompt_dict,
1994
  prompter=prompter,
1995
+ context=context if add_chat_history_to_context else '',
1996
+ iinput=iinput if add_chat_history_to_context else '',
1997
  sanitize_bot_response=sanitize_bot_response,
1998
  verbose=verbose,
1999
  )
2000
 
2001
+ use_docs_planned = False
2002
  scores = []
2003
  chain = None
2004
 
 
2010
  sim_kwargs = {k: v for k, v in locals().items() if k in func_names}
2011
  missing_kwargs = [x for x in func_names if x not in sim_kwargs]
2012
  assert not missing_kwargs, "Missing: %s" % missing_kwargs
2013
+ docs, chain, scores, use_docs_planned, have_any_docs = get_chain(**sim_kwargs)
2014
  if document_subset in non_query_commands:
2015
  formatted_doc_chunks = '\n\n'.join([get_url(x) + '\n\n' + x.page_content for x in docs])
2016
+ if not formatted_doc_chunks and not use_llm_if_no_docs:
2017
+ yield "No sources", ''
2018
+ return
2019
+ # if no souces, outside gpt_langchain, LLM will be used with '' input
2020
  yield formatted_doc_chunks, ''
2021
  return
2022
  if not use_llm_if_no_docs:
 
2028
  yield ret, extra
2029
  return
2030
  if not docs and langchain_mode not in [LangChainMode.DISABLED.value,
 
2031
  LangChainMode.LLM.value]:
2032
  ret = 'No relevant documents to query.' if have_any_docs else 'No documents to query.'
2033
  extra = ''
 
2083
  else:
2084
  answer = chain()
2085
 
2086
+ if not use_docs_planned:
2087
  ret = answer['output_text']
2088
  extra = ''
2089
  yield ret, extra
 
2095
 
2096
  def get_chain(query=None,
2097
  iinput=None,
2098
+ context=None, # FIXME: https://github.com/hwchase17/langchain/issues/6638
2099
  use_openai_model=False, use_openai_embedding=False,
2100
  first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512,
2101
+ langchain_mode_paths=None,
2102
  detect_user_path_changes_every_query=False,
2103
  db_type='faiss',
2104
  model_name=None,
 
2106
  hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
2107
  prompt_type=None,
2108
  prompt_dict=None,
2109
+ cut_distance=1.1,
2110
+ add_chat_history_to_context=True, # FIXME: https://github.com/hwchase17/langchain/issues/6638
2111
  load_db_if_exists=False,
2112
  db=None,
2113
  langchain_mode=None,
2114
  langchain_action=None,
2115
  langchain_agents=None,
2116
+ document_subset=DocumentSubset.Relevant.name,
2117
+ document_choice=[DocumentChoice.ALL.value],
2118
  n_jobs=-1,
2119
  # beyond run_db_query:
2120
  llm=None,
 
2129
  assert langchain_agents is not None # should be at least []
2130
  # determine whether use of context out of docs is planned
2131
  if not use_openai_model and prompt_type not in ['plain'] or model_name in non_hf_types:
2132
+ if langchain_mode in ['Disabled', 'LLM']:
2133
+ use_docs_planned = False
2134
  else:
2135
+ use_docs_planned = True
2136
  else:
2137
+ use_docs_planned = True
2138
 
2139
  # https://github.com/hwchase17/langchain/issues/1946
2140
  # FIXME: Seems to way to get size of chroma db to limit top_k_docs to avoid
 
2151
  # avoid looking at user_path during similarity search db handling,
2152
  # if already have db and not updating from user_path every query
2153
  # but if db is None, no db yet loaded (e.g. from prep), so allow user_path to be whatever it was
2154
+ if langchain_mode_paths is None:
2155
+ langchain_mode_paths = {}
2156
+ langchain_mode_paths = langchain_mode_paths.copy()
2157
+ langchain_mode_paths[langchain_mode] = None
2158
  db, num_new_sources, new_sources_metadata = make_db(use_openai_embedding=use_openai_embedding,
2159
  hf_embedding_model=hf_embedding_model,
2160
  first_para=first_para, text_limit=text_limit,
2161
  chunk=chunk,
2162
  chunk_size=chunk_size,
2163
  langchain_mode=langchain_mode,
2164
+ langchain_mode_paths=langchain_mode_paths,
2165
  db_type=db_type,
2166
  load_db_if_exists=load_db_if_exists,
2167
  db=db,
 
2181
  else:
2182
  extra = ""
2183
  prefix = ""
2184
+ if langchain_mode in ['Disabled', 'LLM'] or not use_docs_planned:
2185
  template_if_no_docs = template = """%s{context}{question}""" % prefix
2186
  else:
2187
  template = """%s
 
2222
  else:
2223
  use_template = False
2224
 
2225
+ if db and use_docs_planned:
2226
  base_path = 'locks'
2227
  makedirs(base_path)
2228
  if hasattr(db, '_persist_directory'):
 
2236
  filter_kwargs = {}
2237
  else:
2238
  assert document_choice is not None, "Document choice was None"
2239
+ if len(document_choice) >= 1 and document_choice[0] == DocumentChoice.ALL.value:
2240
  filter_kwargs = {}
2241
  elif len(document_choice) >= 2:
2242
+ if document_choice[0] == DocumentChoice.ALL.value:
2243
  # remove 'All'
2244
  document_choice = document_choice[1:]
2245
  or_filter = [{"source": {"$eq": x}} for x in document_choice]
 
2251
  else:
2252
  # shouldn't reach
2253
  filter_kwargs = {}
2254
+ if langchain_mode in [LangChainMode.LLM.value]:
2255
  docs = []
2256
  scores = []
2257
+ elif document_subset == DocumentSubset.TopKSources.name or query in [None, '', '\n']:
2258
  db_documents, db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs)
2259
  # similar to langchain's chroma's _results_to_docs_and_scores
2260
  docs_with_score = [(Document(page_content=result[0], metadata=result[1] or {}), 0)
 
2342
  docs_with_score.reverse()
2343
  # cut off so no high distance docs/sources considered
2344
  have_any_docs |= len(docs_with_score) > 0 # before cut
2345
+ docs = [x[0] for x in docs_with_score if x[1] < cut_distance]
2346
+ scores = [x[1] for x in docs_with_score if x[1] < cut_distance]
2347
  if len(scores) > 0 and verbose:
2348
  print("Distance: min: %s max: %s mean: %s median: %s" %
2349
  (scores[0], scores[-1], np.mean(scores), np.median(scores)), flush=True)
 
2351
  docs = []
2352
  scores = []
2353
 
2354
+ if not docs and use_docs_planned and model_name not in non_hf_types:
2355
  # if HF type and have no docs, can bail out
2356
  return docs, None, [], False, have_any_docs
2357
 
 
2374
 
2375
  if len(docs) == 0:
2376
  # avoid context == in prompt then
2377
+ use_docs_planned = False
2378
  template = template_if_no_docs
2379
 
2380
  if langchain_action == LangChainAction.QUERY.value:
 
2390
  else:
2391
  # only if use_openai_model = True, unused normally except in testing
2392
  chain = load_qa_with_sources_chain(llm)
2393
+ if not use_docs_planned:
2394
  chain_kwargs = dict(input_documents=[], question=query)
2395
  else:
2396
  chain_kwargs = dict(input_documents=docs, question=query)
 
2417
  else:
2418
  raise RuntimeError("No such langchain_action=%s" % langchain_action)
2419
 
2420
+ return docs, target, scores, use_docs_planned, have_any_docs
2421
 
2422
 
2423
  def get_sources_answer(query, answer, scores, show_rank, answer_with_sources, verbose=False):
gradio_runner.py CHANGED
@@ -50,16 +50,20 @@ def fix_pydantic_duplicate_validators_error():
50
 
51
  fix_pydantic_duplicate_validators_error()
52
 
53
- from enums import DocumentChoices, no_model_str, no_lora_str, no_server_str, LangChainAction, LangChainMode
 
54
  from gradio_themes import H2oTheme, SoftTheme, get_h2o_title, get_simple_title, get_dark_js, spacing_xsm, radius_xsm, \
55
  text_xsm
56
  from prompter import prompt_type_to_model_name, prompt_types_strings, inv_prompt_type_to_model_lower, non_hf_types, \
57
  get_prompt
58
- from utils import get_githash, flatten_list, zip_data, s3up, clear_torch_cache, get_torch_allocated, system_info_print, \
59
- ping, get_short_name, makedirs, get_kwargs, remove, system_info, ping_gpu, get_url, get_local_ip
60
- from gen import get_model, languages_covered, evaluate, score_qa, langchain_modes, inputs_kwargs_list, scratch_base_dir, \
61
- get_max_max_new_tokens, get_minmax_top_k_docs, history_to_context, langchain_actions, langchain_agents_list
62
- from evaluate_params import eval_func_param_names, no_default_param_names, eval_func_param_names_defaults
 
 
 
63
 
64
  from apscheduler.schedulers.background import BackgroundScheduler
65
 
@@ -94,12 +98,9 @@ def go_gradio(**kwargs):
94
  memory_restriction_level = kwargs['memory_restriction_level']
95
  n_gpus = kwargs['n_gpus']
96
  admin_pass = kwargs['admin_pass']
97
- model_state0 = kwargs['model_state0']
98
  model_states = kwargs['model_states']
99
- score_model_state0 = kwargs['score_model_state0']
100
  dbs = kwargs['dbs']
101
  db_type = kwargs['db_type']
102
- visible_langchain_modes = kwargs['visible_langchain_modes']
103
  visible_langchain_actions = kwargs['visible_langchain_actions']
104
  visible_langchain_agents = kwargs['visible_langchain_agents']
105
  allow_upload_to_user_data = kwargs['allow_upload_to_user_data']
@@ -112,8 +113,19 @@ def go_gradio(**kwargs):
112
  enable_captions = kwargs['enable_captions']
113
  captions_model = kwargs['captions_model']
114
  enable_ocr = kwargs['enable_ocr']
 
115
  caption_loader = kwargs['caption_loader']
116
 
 
 
 
 
 
 
 
 
 
 
117
  # easy update of kwargs needed for evaluate() etc.
118
  queue = True
119
  allow_upload = allow_upload_to_user_data or allow_upload_to_my_data
@@ -133,25 +145,11 @@ def go_gradio(**kwargs):
133
  " use Enter for multiple input lines)"
134
 
135
  title = 'h2oGPT'
136
- more_info = """<iframe src="https://ghbtns.com/github-btn.html?user=h2oai&repo=h2ogpt&type=star&count=true&size=small" frameborder="0" scrolling="0" width="250" height="20" title="GitHub"></iframe><small><a href="https://github.com/h2oai/h2ogpt">h2oGPT</a> <a href="https://github.com/h2oai/h2o-llmstudio">H2O LLM Studio</a><br><a href="https://huggingface.co/h2oai">🤗 Models</a>"""
137
- if kwargs['verbose']:
138
- description = f"""Model {kwargs['base_model']} Instruct dataset.
139
- For more information, visit our GitHub pages: [h2oGPT](https://github.com/h2oai/h2ogpt) and [H2O LLM Studio](https://github.com/h2oai/h2o-llmstudio).
140
- Command: {str(' '.join(sys.argv))}
141
- Hash: {get_githash()}
142
- """
143
- else:
144
- description = more_info
145
- description_bottom = "If this host is busy, try [Multi-Model](https://gpt.h2o.ai), [Falcon 40B](http://falcon.h2o.ai), [HF Spaces1](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot) or [HF Spaces2](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot2)<br>"
146
  if is_hf:
147
  description_bottom += '''<a href="https://huggingface.co/spaces/h2oai/h2ogpt-chatbot?duplicate=true"><img src="https://bit.ly/3gLdBN6" style="white-space: nowrap" alt="Duplicate Space"></a>'''
148
-
149
- if kwargs['verbose']:
150
- task_info_md = f"""
151
- ### Task: {kwargs['task_info']}"""
152
- else:
153
- task_info_md = ''
154
-
155
  css_code = get_css(kwargs)
156
 
157
  if kwargs['gradio_offline_level'] >= 0:
@@ -181,9 +179,9 @@ def go_gradio(**kwargs):
181
  demo = gr.Blocks(theme=theme, css=css_code, title="h2oGPT", analytics_enabled=False)
182
  callback = gr.CSVLogger()
183
 
184
- model_options = flatten_list(list(prompt_type_to_model_name.values())) + kwargs['extra_model_options']
185
- if kwargs['base_model'].strip() not in model_options:
186
- model_options = [kwargs['base_model'].strip()] + model_options
187
  lora_options = kwargs['extra_lora_options']
188
  if kwargs['lora_weights'].strip() not in lora_options:
189
  lora_options = [kwargs['lora_weights'].strip()] + lora_options
@@ -198,7 +196,7 @@ def go_gradio(**kwargs):
198
 
199
  # always add in no lora case
200
  # add fake space so doesn't go away in gradio dropdown
201
- model_options = [no_model_str] + model_options
202
  lora_options = [no_lora_str] + lora_options
203
  server_options = [no_server_str] + server_options
204
  # always add in no model case so can free memory
@@ -252,6 +250,14 @@ def go_gradio(**kwargs):
252
  # else gets input_list at time of submit that is old, and shows up as truncated in chatbot
253
  return x
254
 
 
 
 
 
 
 
 
 
255
  with demo:
256
  # avoid actual model/tokenizer here or anything that would be bad to deepcopy
257
  # https://github.com/gradio-app/gradio/issues/3558
@@ -265,18 +271,32 @@ def go_gradio(**kwargs):
265
  prompt_dict=kwargs['prompt_dict'],
266
  )
267
  )
 
 
 
 
 
 
 
 
 
 
 
268
  model_state2 = gr.State(kwargs['model_state_none'].copy())
269
- model_options_state = gr.State([model_options])
270
  lora_options_state = gr.State([lora_options])
271
  server_options_state = gr.State([server_options])
272
- my_db_state = gr.State([None, None])
273
  chat_state = gr.State({})
274
- docs_state00 = kwargs['document_choice'] + [DocumentChoices.All.name]
275
  docs_state0 = []
276
  [docs_state0.append(x) for x in docs_state00 if x not in docs_state0]
277
  docs_state = gr.State(docs_state0)
278
  viewable_docs_state0 = []
279
  viewable_docs_state = gr.State(viewable_docs_state0)
 
 
 
280
  gr.Markdown(f"""
281
  {get_h2o_title(title, description) if kwargs['h2ocolors'] else get_simple_title(title, description)}
282
  """)
@@ -290,7 +310,7 @@ def go_gradio(**kwargs):
290
  'model_lock'] else "Response Scores: %s" % nas
291
 
292
  if kwargs['langchain_mode'] != LangChainMode.DISABLED.value:
293
- extra_prompt_form = ". For summarization, empty submission uses first top_k_docs documents."
294
  else:
295
  extra_prompt_form = ""
296
  if kwargs['input_lines'] > 1:
@@ -298,6 +318,34 @@ def go_gradio(**kwargs):
298
  else:
299
  instruction_label = "Enter to Submit, Shift-Enter for more lines%s" % extra_prompt_form
300
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
301
  normal_block = gr.Row(visible=not base_wanted, equal_height=False)
302
  with normal_block:
303
  side_bar = gr.Column(elem_id="col_container", scale=1, min_width=100)
@@ -318,6 +366,7 @@ def go_gradio(**kwargs):
318
  scale=1,
319
  min_width=0,
320
  elem_id="warning", elem_classes="feedback")
 
321
  url_visible = kwargs['langchain_mode'] != 'Disabled' and allow_upload and enable_url_upload
322
  url_label = 'URL/ArXiv' if have_arxiv else 'URL'
323
  url_text = gr.Textbox(label=url_label,
@@ -331,29 +380,20 @@ def go_gradio(**kwargs):
331
  visible=text_visible)
332
  github_textbox = gr.Textbox(label="Github URL", visible=False) # FIXME WIP
333
  database_visible = kwargs['langchain_mode'] != 'Disabled'
334
- with gr.Accordion("Database", open=False, visible=database_visible):
335
- if is_hf:
336
- # don't show 'wiki' since only usually useful for internal testing at moment
337
- no_show_modes = ['Disabled', 'wiki']
338
- else:
339
- no_show_modes = ['Disabled']
340
- allowed_modes = visible_langchain_modes.copy()
341
- allowed_modes = [x for x in allowed_modes if x in dbs]
342
- allowed_modes += ['ChatLLM', 'LLM']
343
- if allow_upload_to_my_data and 'MyData' not in allowed_modes:
344
- allowed_modes += ['MyData']
345
- if allow_upload_to_user_data and 'UserData' not in allowed_modes:
346
- allowed_modes += ['UserData']
347
  langchain_mode = gr.Radio(
348
- [x for x in langchain_modes if x in allowed_modes and x not in no_show_modes],
349
  value=kwargs['langchain_mode'],
350
  label="Collections",
351
  show_label=True,
352
  visible=kwargs['langchain_mode'] != 'Disabled',
353
  min_width=100)
354
- document_subset = gr.Radio([x.name for x in DocumentChoices],
 
 
355
  label="Subset",
356
- value=DocumentChoices.Relevant.name,
357
  interactive=True,
358
  )
359
  allowed_actions = [x for x in langchain_actions if x in visible_langchain_actions]
@@ -417,9 +457,9 @@ def go_gradio(**kwargs):
417
  mw1 = 50
418
  mw2 = 50
419
  with gr.Column(min_width=mw1):
420
- submit = gr.Button(value='Submit', variant='primary', scale=0, size='sm',
421
  min_width=mw1)
422
- stop_btn = gr.Button(value="Stop", variant='secondary', scale=0, size='sm',
423
  min_width=mw1)
424
  save_chat_btn = gr.Button("Save", size='sm', min_width=mw1)
425
  with gr.Column(min_width=mw2):
@@ -440,20 +480,50 @@ def go_gradio(**kwargs):
440
  with gr.TabItem("Document Selection"):
441
  document_choice = gr.Dropdown(docs_state0,
442
  label="Select Subset of Document(s) %s" % file_types_str,
443
- value='All',
444
  interactive=True,
445
  multiselect=True,
446
  visible=kwargs['langchain_mode'] != 'Disabled',
447
  )
448
  sources_visible = kwargs['langchain_mode'] != 'Disabled' and enable_sources_list
449
  with gr.Row():
450
- get_sources_btn = gr.Button(value="Update UI with Document(s) from DB", scale=0, size='sm',
451
- visible=sources_visible)
452
- show_sources_btn = gr.Button(value="Show Sources from DB", scale=0, size='sm',
453
- visible=sources_visible)
454
- refresh_sources_btn = gr.Button(value="Update DB with new/changed files on disk", scale=0,
455
- size='sm',
456
- visible=sources_visible and allow_upload_to_user_data)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
457
 
458
  sources_row = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and enable_sources_list,
459
  equal_height=False)
@@ -723,19 +793,20 @@ def go_gradio(**kwargs):
723
  side_bar_btn = gr.Button("Toggle SideBar", variant="secondary", size="sm")
724
  submit_buttons_btn = gr.Button("Toggle Submit Buttons", variant="secondary", size="sm")
725
  col_tabs_scale = gr.Slider(minimum=1, maximum=20, value=10, step=1, label='Window Size')
726
- text_outputs_height = gr.Slider(minimum=100, maximum=1000, value=kwargs['height'] or 400,
727
- step=100, label='Chat Height')
728
  dark_mode_btn = gr.Button("Dark Mode", variant="secondary", size="sm")
729
  with gr.Column(scale=4):
730
  pass
 
731
  admin_row = gr.Row()
732
  with admin_row:
733
  with gr.Column(scale=1):
734
- admin_pass_textbox = gr.Textbox(label="Admin Password", type='password', visible=is_public)
735
- admin_btn = gr.Button(value="Admin Access", visible=is_public, size='sm')
736
  with gr.Column(scale=4):
737
  pass
738
- system_row = gr.Row(visible=not is_public)
739
  with system_row:
740
  with gr.Column():
741
  with gr.Row():
@@ -799,23 +870,24 @@ def go_gradio(**kwargs):
799
  else:
800
  return tuple([gr.update(interactive=True)] * len(args))
801
 
802
- # Add to UserData
803
  update_db_func = functools.partial(update_user_db,
804
  dbs=dbs,
805
  db_type=db_type,
806
  use_openai_embedding=use_openai_embedding,
807
  hf_embedding_model=hf_embedding_model,
808
- enable_captions=enable_captions,
809
  captions_model=captions_model,
810
- enable_ocr=enable_ocr,
811
  caption_loader=caption_loader,
 
 
812
  verbose=kwargs['verbose'],
813
- user_path=kwargs['user_path'],
814
  n_jobs=kwargs['n_jobs'],
815
  )
816
  add_file_outputs = [fileup_output, langchain_mode]
817
  add_file_kwargs = dict(fn=update_db_func,
818
- inputs=[fileup_output, my_db_state, chunk, chunk_size, langchain_mode],
 
819
  outputs=add_file_outputs + [sources_text, doc_exception_text],
820
  queue=queue,
821
  api_name='add_file' if allow_api and allow_upload_to_user_data else None)
@@ -827,6 +899,15 @@ def go_gradio(**kwargs):
827
  eventdb1b = eventdb1.then(make_interactive, inputs=add_file_outputs, outputs=add_file_outputs,
828
  show_progress='minimal')
829
 
 
 
 
 
 
 
 
 
 
830
  # note for update_user_db_func output is ignored for db
831
 
832
  def clear_textbox():
@@ -836,7 +917,8 @@ def go_gradio(**kwargs):
836
 
837
  add_url_outputs = [url_text, langchain_mode]
838
  add_url_kwargs = dict(fn=update_user_db_url_func,
839
- inputs=[url_text, my_db_state, chunk, chunk_size, langchain_mode],
 
840
  outputs=add_url_outputs + [sources_text, doc_exception_text],
841
  queue=queue,
842
  api_name='add_url' if allow_api and allow_upload_to_user_data else None)
@@ -853,7 +935,8 @@ def go_gradio(**kwargs):
853
  update_user_db_txt_func = functools.partial(update_db_func, is_txt=True)
854
  add_text_outputs = [user_text_text, langchain_mode]
855
  add_text_kwargs = dict(fn=update_user_db_txt_func,
856
- inputs=[user_text_text, my_db_state, chunk, chunk_size, langchain_mode],
 
857
  outputs=add_text_outputs + [sources_text, doc_exception_text],
858
  queue=queue,
859
  api_name='add_text' if allow_api and allow_upload_to_user_data else None
@@ -865,7 +948,7 @@ def go_gradio(**kwargs):
865
  eventdb3 = eventdb3b.then(**add_text_kwargs, show_progress='full')
866
  eventdb3c = eventdb3.then(make_interactive, inputs=add_text_outputs, outputs=add_text_outputs,
867
  show_progress='minimal')
868
- db_events = [eventdb1a, eventdb1, eventdb1b,
869
  eventdb2a, eventdb2, eventdb2b, eventdb2c,
870
  eventdb3a, eventdb3b, eventdb3, eventdb3c]
871
 
@@ -873,14 +956,14 @@ def go_gradio(**kwargs):
873
 
874
  # if change collection source, must clear doc selections from it to avoid inconsistency
875
  def clear_doc_choice():
876
- return gr.Dropdown.update(choices=docs_state0, value=DocumentChoices.All.name)
877
 
878
  langchain_mode.change(clear_doc_choice, inputs=None, outputs=document_choice, queue=False)
879
 
880
  def resize_col_tabs(x):
881
  return gr.Dropdown.update(scale=x)
882
 
883
- col_tabs_scale.change(fn=resize_col_tabs, inputs=col_tabs_scale, outputs=col_tabs)
884
 
885
  def resize_chatbots(x, num_model_lock=0):
886
  if num_model_lock == 0:
@@ -891,7 +974,7 @@ def go_gradio(**kwargs):
891
 
892
  resize_chatbots_func = functools.partial(resize_chatbots, num_model_lock=len(text_outputs))
893
  text_outputs_height.change(fn=resize_chatbots_func, inputs=text_outputs_height,
894
- outputs=[text_output, text_output2] + text_outputs)
895
 
896
  def update_dropdown(x):
897
  return gr.Dropdown.update(choices=x, value=[docs_state0[0]])
@@ -982,7 +1065,8 @@ def go_gradio(**kwargs):
982
  if file.startswith('http') or file.startswith('https'):
983
  # if file is online, then might as well use google(?)
984
  document1 = file
985
- return gr.update(visible=True, value=f"""<iframe width="1000" height="800" src="https://docs.google.com/viewerng/viewer?url={document1}&embedded=true" frameborder="0" height="100%" width="100%">
 
986
  </iframe>
987
  """), dummy1, dummy1, dummy1
988
  else:
@@ -1005,9 +1089,11 @@ def go_gradio(**kwargs):
1005
 
1006
  refresh_sources1 = functools.partial(update_and_get_source_files_given_langchain_mode,
1007
  **get_kwargs(update_and_get_source_files_given_langchain_mode,
1008
- exclude_names=['db1', 'langchain_mode'],
 
1009
  **all_kwargs))
1010
- eventdb9 = refresh_sources_btn.click(fn=refresh_sources1, inputs=[my_db_state, langchain_mode],
 
1011
  outputs=sources_text,
1012
  api_name='refresh_sources' if allow_api else None)
1013
 
@@ -1017,9 +1103,153 @@ def go_gradio(**kwargs):
1017
  def close_admin(x):
1018
  return gr.update(visible=not (x == admin_pass))
1019
 
1020
- admin_btn.click(check_admin_pass, inputs=admin_pass_textbox, outputs=system_row, queue=False) \
1021
  .then(close_admin, inputs=admin_pass_textbox, outputs=admin_row, queue=False)
1022
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1023
  inputs_list, inputs_dict = get_inputs_list(all_kwargs, kwargs['model_lower'], model_id=1)
1024
  inputs_list2, inputs_dict2 = get_inputs_list(all_kwargs, kwargs['model_lower'], model_id=2)
1025
  from functools import partial
@@ -1031,11 +1261,11 @@ def go_gradio(**kwargs):
1031
  def evaluate_nochat(*args1, default_kwargs1=None, str_api=False, **kwargs1):
1032
  args_list = list(args1)
1033
  if str_api:
1034
- user_kwargs = args_list[2]
1035
  assert isinstance(user_kwargs, str)
1036
  user_kwargs = ast.literal_eval(user_kwargs)
1037
  else:
1038
- user_kwargs = {k: v for k, v in zip(eval_func_param_names, args_list[2:])}
1039
  # only used for submit_nochat_api
1040
  user_kwargs['chat'] = False
1041
  if 'stream_output' not in user_kwargs:
@@ -1054,10 +1284,11 @@ def go_gradio(**kwargs):
1054
  # correct ordering. Note some things may not be in default_kwargs, so can't be default of user_kwargs.get()
1055
  model_state1 = args_list[0]
1056
  my_db_state1 = args_list[1]
 
1057
  args_list = [user_kwargs[k] if k in user_kwargs and user_kwargs[k] is not None else default_kwargs1[k] for k
1058
  in eval_func_param_names]
1059
  assert len(args_list) == len(eval_func_param_names)
1060
- args_list = [model_state1, my_db_state1] + args_list
1061
 
1062
  try:
1063
  for res_dict in evaluate(*tuple(args_list), **kwargs1):
@@ -1261,10 +1492,7 @@ def go_gradio(**kwargs):
1261
  history[-1][1] = None
1262
  return history
1263
  if user_message1 in ['', None, '\n']:
1264
- if langchain_action1 in LangChainAction.QUERY.value and \
1265
- DocumentChoices.All.name != document_subset1 \
1266
- or \
1267
- langchain_mode1 in [LangChainMode.CHAT_LLM.value, LangChainMode.LLM.value]:
1268
  # reject non-retry submit/enter
1269
  return history
1270
  user_message1 = fix_text_for_gradio(user_message1)
@@ -1311,10 +1539,12 @@ def go_gradio(**kwargs):
1311
  API only called for which_model=0, default for inputs_list, but rest should ignore inputs_list
1312
  :return: last element is True if should run bot, False if should just yield history
1313
  """
 
1314
  # don't deepcopy, can contain model itself
1315
  args_list = list(args).copy()
1316
- model_state1 = args_list[-3]
1317
- my_db_state1 = args_list[-2]
 
1318
  history = args_list[-1]
1319
  prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
1320
  prompt_dict1 = args_list[eval_func_param_names.index('prompt_dict')]
@@ -1322,8 +1552,9 @@ def go_gradio(**kwargs):
1322
  if model_state1['model'] is None or model_state1['model'] == no_model_str:
1323
  return history, None, None, None
1324
 
1325
- args_list = args_list[:-3] # only keep rest needed for evaluate()
1326
  langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
 
1327
  langchain_action1 = args_list[eval_func_param_names.index('langchain_action')]
1328
  langchain_agents1 = args_list[eval_func_param_names.index('langchain_agents')]
1329
  document_subset1 = args_list[eval_func_param_names.index('document_subset')]
@@ -1338,10 +1569,7 @@ def go_gradio(**kwargs):
1338
  instruction1 = history[-1][0]
1339
  history[-1][1] = None
1340
  elif not instruction1:
1341
- if langchain_action1 in LangChainAction.QUERY.value and \
1342
- DocumentChoices.All.name != document_choice1 \
1343
- or \
1344
- langchain_mode1 in [LangChainMode.CHAT_LLM.value, LangChainMode.LLM.value]:
1345
  # if not retrying, then reject empty query
1346
  return history, None, None, None
1347
  elif len(history) > 0 and history[-1][1] not in [None, '']:
@@ -1358,7 +1586,9 @@ def go_gradio(**kwargs):
1358
 
1359
  chat1 = args_list[eval_func_param_names.index('chat')]
1360
  model_max_length1 = get_model_max_length(model_state1)
1361
- context1 = history_to_context(history, langchain_mode1, prompt_type1, prompt_dict1, chat1,
 
 
1362
  model_max_length1, memory_restriction_level,
1363
  kwargs['keep_sources_in_context'])
1364
  args_list[0] = instruction1 # override original instruction with history from user
@@ -1367,6 +1597,7 @@ def go_gradio(**kwargs):
1367
  fun1 = partial(evaluate,
1368
  model_state1,
1369
  my_db_state1,
 
1370
  *tuple(args_list),
1371
  **kwargs_evaluate)
1372
 
@@ -1412,24 +1643,26 @@ def go_gradio(**kwargs):
1412
  clear_torch_cache()
1413
  return
1414
 
1415
- def clear_embeddings(langchain_mode1, my_db):
1416
  # clear any use of embedding that sits on GPU, else keeps accumulating GPU usage even if clear torch cache
1417
- if db_type == 'chroma' and langchain_mode1 not in ['ChatLLM', 'LLM', 'Disabled', None, '']:
1418
  from gpt_langchain import clear_embedding
1419
  db = dbs.get('langchain_mode1')
1420
  if db is not None and not isinstance(db, str):
1421
  clear_embedding(db)
1422
- if langchain_mode1 == LangChainMode.MY_DATA.value and my_db is not None:
1423
- clear_embedding(my_db[0])
 
 
1424
 
1425
  def bot(*args, retry=False):
1426
- history, fun1, langchain_mode1, my_db_state1 = prep_bot(*args, retry=retry)
1427
  try:
1428
  for res in get_response(fun1, history):
1429
  yield res
1430
  finally:
1431
  clear_torch_cache()
1432
- clear_embeddings(langchain_mode1, my_db_state1)
1433
 
1434
  def all_bot(*args, retry=False, model_states1=None):
1435
  args_list = list(args).copy()
@@ -1439,12 +1672,14 @@ def go_gradio(**kwargs):
1439
  stream_output1 = args_list[eval_func_param_names.index('stream_output')]
1440
  max_time1 = args_list[eval_func_param_names.index('max_time')]
1441
  langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
1442
- my_db_state1 = None # will be filled below by some bot
 
1443
  try:
1444
  gen_list = []
1445
  for chatboti, (chatbot1, model_state1) in enumerate(zip(chatbots, model_states1)):
1446
  args_list1 = args_list0.copy()
1447
- args_list1.insert(-1, model_state1) # insert at -1 so is at -2
 
1448
  # if at start, have None in response still, replace with '' so client etc. acts like normal
1449
  # assumes other parts of code treat '' and None as if no response yet from bot
1450
  # can't do this later in bot code as racy with threaded generators
@@ -1454,8 +1689,8 @@ def go_gradio(**kwargs):
1454
  # so consistent with prep_bot()
1455
  # with model_state1 at -3, my_db_state1 at -2, and history(chatbot) at -1
1456
  # langchain_mode1 and my_db_state1 should be same for every bot
1457
- history, fun1, langchain_mode1, my_db_state1 = prep_bot(*tuple(args_list1), retry=retry,
1458
- which_model=chatboti)
1459
  gen1 = get_response(fun1, history)
1460
  if stream_output1:
1461
  gen1 = TimeoutIterator(gen1, timeout=0.01, sentinel=None, raise_on_exception=False)
@@ -1501,7 +1736,7 @@ def go_gradio(**kwargs):
1501
  print("Generate exceptions: %s" % exceptions, flush=True)
1502
  finally:
1503
  clear_torch_cache()
1504
- clear_embeddings(langchain_mode1, my_db_state1)
1505
 
1506
  # NORMAL MODEL
1507
  user_args = dict(fn=functools.partial(user, sanitize_user_prompt=kwargs['sanitize_user_prompt']),
@@ -1509,11 +1744,11 @@ def go_gradio(**kwargs):
1509
  outputs=text_output,
1510
  )
1511
  bot_args = dict(fn=bot,
1512
- inputs=inputs_list + [model_state, my_db_state] + [text_output],
1513
  outputs=[text_output, chat_exception_text],
1514
  )
1515
  retry_bot_args = dict(fn=functools.partial(bot, retry=True),
1516
- inputs=inputs_list + [model_state, my_db_state] + [text_output],
1517
  outputs=[text_output, chat_exception_text],
1518
  )
1519
  retry_user_args = dict(fn=functools.partial(user, retry=True),
@@ -1531,11 +1766,11 @@ def go_gradio(**kwargs):
1531
  outputs=text_output2,
1532
  )
1533
  bot_args2 = dict(fn=bot,
1534
- inputs=inputs_list2 + [model_state2, my_db_state] + [text_output2],
1535
  outputs=[text_output2, chat_exception_text],
1536
  )
1537
  retry_bot_args2 = dict(fn=functools.partial(bot, retry=True),
1538
- inputs=inputs_list2 + [model_state2, my_db_state] + [text_output2],
1539
  outputs=[text_output2, chat_exception_text],
1540
  )
1541
  retry_user_args2 = dict(fn=functools.partial(user, retry=True),
@@ -1556,11 +1791,11 @@ def go_gradio(**kwargs):
1556
  outputs=text_outputs,
1557
  )
1558
  all_bot_args = dict(fn=functools.partial(all_bot, model_states1=model_states),
1559
- inputs=inputs_list + [my_db_state] + text_outputs,
1560
  outputs=text_outputs + [chat_exception_text],
1561
  )
1562
  all_retry_bot_args = dict(fn=functools.partial(all_bot, model_states1=model_states, retry=True),
1563
- inputs=inputs_list + [my_db_state] + text_outputs,
1564
  outputs=text_outputs + [chat_exception_text],
1565
  )
1566
  all_retry_user_args = dict(fn=functools.partial(all_user, retry=True,
@@ -1722,6 +1957,11 @@ def go_gradio(**kwargs):
1722
  def get_short_chat(x, short_chats, short_len=20, words=4):
1723
  if x and len(x[0]) == 2 and x[0][0] is not None:
1724
  short_chat = ' '.join(x[0][0][:short_len].split(' ')[:words]).strip()
 
 
 
 
 
1725
  short_chat = dedup(short_chat, short_chats)
1726
  else:
1727
  short_chat = None
@@ -1789,14 +2029,12 @@ def go_gradio(**kwargs):
1789
  already_exists = any([is_chat_same(chat_list, x) for x in old_chat_lists])
1790
  if not already_exists:
1791
  chat_state1[short_chat] = chat_list.copy()
1792
- # clear chat_list so saved and then new conversation starts
1793
- # FIXME: seems less confusing to clear, since have clear button right next
1794
- # chat_list = [[]] * len(chat_list)
1795
- if not chat_is_list:
1796
- ret_list = chat_list + [chat_state1]
1797
- else:
1798
- ret_list = [chat_list] + [chat_state1]
1799
- return tuple(ret_list)
1800
 
1801
  def switch_chat(chat_key, chat_state1, num_model_lock=0):
1802
  chosen_chat = chat_state1[chat_key]
@@ -1827,7 +2065,7 @@ def go_gradio(**kwargs):
1827
 
1828
  remove_chat_event = remove_chat_btn.click(remove_chat,
1829
  inputs=[radio_chats, chat_state], outputs=[radio_chats, chat_state],
1830
- queue=False)
1831
 
1832
  def get_chats1(chat_state1):
1833
  base = 'chats'
@@ -1858,7 +2096,7 @@ def go_gradio(**kwargs):
1858
  new_chats = json.loads(f.read())
1859
  for chat1_k, chat1_v in new_chats.items():
1860
  # ignore chat1_k, regenerate and de-dup to avoid loss
1861
- _, chat_state1 = save_chat(chat1_v, chat_state1, chat_is_list=True)
1862
  except BaseException as e:
1863
  t, v, tb = sys.exc_info()
1864
  ex = ''.join(traceback.format_exception(t, v, tb))
@@ -1884,24 +2122,17 @@ def go_gradio(**kwargs):
1884
  .then(deselect_radio_chats, inputs=None, outputs=radio_chats, queue=False) \
1885
  .then(clear_scores, outputs=[score_text, score_text2, score_text_nochat])
1886
 
1887
- def update_radio_chats(chat_state1):
1888
- # reverse so newest at top
1889
- choices = list(chat_state1.keys()).copy()
1890
- choices.reverse()
1891
- return gr.update(choices=choices, value=None)
1892
-
1893
  clear_event = save_chat_btn.click(save_chat,
1894
  inputs=[text_output, text_output2] + text_outputs + [chat_state],
1895
- outputs=[text_output, text_output2] + text_outputs + [chat_state],
1896
- api_name='save_chat' if allow_api else None) \
1897
- .then(update_radio_chats, inputs=chat_state, outputs=radio_chats,
1898
- api_name='update_chats' if allow_api else None) \
1899
- .then(clear_scores, outputs=[score_text, score_text2, score_text_nochat])
1900
 
1901
  # NOTE: clear of instruction/iinput for nochat has to come after score,
1902
  # because score for nochat consumes actual textbox, while chat consumes chat history filled by user()
1903
  no_chat_args = dict(fn=fun,
1904
- inputs=[model_state, my_db_state] + inputs_list,
1905
  outputs=text_output_nochat,
1906
  queue=queue,
1907
  )
@@ -1920,7 +2151,8 @@ def go_gradio(**kwargs):
1920
  .then(clear_torch_cache)
1921
 
1922
  submit_event_nochat_api = submit_nochat_api.click(fun_with_dict_str,
1923
- inputs=[model_state, my_db_state, inputs_dict_str],
 
1924
  outputs=text_output_nochat_api,
1925
  queue=True, # required for generator
1926
  api_name='submit_nochat_api' if allow_api else None) \
@@ -2170,6 +2402,8 @@ def go_gradio(**kwargs):
2170
  print("Exception: %s" % str(e), flush=True)
2171
  return json.dumps(sys_dict)
2172
 
 
 
2173
  get_system_info_dict_func = functools.partial(get_system_info_dict, **all_kwargs)
2174
 
2175
  system_dict_event = system_btn2.click(get_system_info_dict_func,
@@ -2199,12 +2433,15 @@ def go_gradio(**kwargs):
2199
  else:
2200
  tokenizer = None
2201
  if tokenizer is not None:
2202
- langchain_mode1 = 'ChatLLM'
 
2203
  # fake user message to mimic bot()
2204
  chat1 = copy.deepcopy(chat1)
2205
  chat1 = chat1 + [['user_message1', None]]
2206
  model_max_length1 = tokenizer.model_max_length
2207
- context1 = history_to_context(chat1, langchain_mode1, prompt_type1, prompt_dict1, chat1,
 
 
2208
  model_max_length1,
2209
  memory_restriction_level1, keep_sources_in_context1)
2210
  return str(tokenizer(context1, return_tensors="pt")['input_ids'].shape[1])
@@ -2234,7 +2471,7 @@ def go_gradio(**kwargs):
2234
  ,
2235
  queue=False, api_name='stop' if allow_api else None).then(clear_torch_cache, queue=False)
2236
 
2237
- demo.load(None, None, None, _js=get_dark_js() if kwargs['h2ocolors'] and False else None) # light best
2238
 
2239
  demo.queue(concurrency_count=kwargs['concurrency_count'], api_open=kwargs['api_open'])
2240
  favicon_path = "h2o-logo.svg"
@@ -2249,7 +2486,8 @@ def go_gradio(**kwargs):
2249
  # FIXME: disable for gptj, langchain or gpt4all modify print itself
2250
  # FIXME: and any multi-threaded/async print will enter model output!
2251
  scheduler.add_job(func=ping, trigger="interval", seconds=60)
2252
- scheduler.add_job(func=ping_gpu, trigger="interval", seconds=60 * 10)
 
2253
  scheduler.start()
2254
 
2255
  # import control
@@ -2268,9 +2506,6 @@ def go_gradio(**kwargs):
2268
  demo.block_thread()
2269
 
2270
 
2271
- input_args_list = ['model_state', 'my_db_state']
2272
-
2273
-
2274
  def get_inputs_list(inputs_dict, model_lower, model_id=1):
2275
  """
2276
  map gradio objects in locals() to inputs for evaluate().
@@ -2304,8 +2539,9 @@ def get_inputs_list(inputs_dict, model_lower, model_id=1):
2304
  return inputs_list, inputs_dict_out
2305
 
2306
 
2307
- def get_sources(db1, langchain_mode, dbs=None, docs_state0=None):
2308
- set_userid(db1)
 
2309
 
2310
  if langchain_mode in ['ChatLLM', 'LLM']:
2311
  source_files_added = "NA"
@@ -2314,7 +2550,8 @@ def get_sources(db1, langchain_mode, dbs=None, docs_state0=None):
2314
  source_files_added = "Not showing wiki_full, takes about 20 seconds and makes 4MB file." \
2315
  " Ask [email protected] for file if required."
2316
  source_list = []
2317
- elif langchain_mode == 'MyData' and len(db1) > 0 and db1[0] is not None:
 
2318
  from gpt_langchain import get_metadatas
2319
  metadatas = get_metadatas(db1[0])
2320
  source_list = sorted(set([x['source'] for x in metadatas]))
@@ -2345,14 +2582,13 @@ def set_userid(db1):
2345
  db1[1] = str(uuid.uuid4())
2346
 
2347
 
2348
- def update_user_db(file, db1, chunk, chunk_size, langchain_mode, dbs=None, **kwargs):
2349
- set_userid(db1)
2350
-
2351
  if file is None:
2352
  raise RuntimeError("Don't use change, use input")
2353
 
2354
  try:
2355
- return _update_user_db(file, db1=db1, chunk=chunk, chunk_size=chunk_size,
2356
  langchain_mode=langchain_mode, dbs=dbs,
2357
  **kwargs)
2358
  except BaseException as e:
@@ -2383,25 +2619,30 @@ def get_lock_file(db1, langchain_mode):
2383
  user_id = db1[1]
2384
  base_path = 'locks'
2385
  makedirs(base_path)
2386
- lock_file = "db_%s_%s.lock" % (langchain_mode.replace(' ', '_'), user_id)
2387
  return lock_file
2388
 
2389
 
2390
  def _update_user_db(file,
2391
- db1=None,
2392
  chunk=None, chunk_size=None,
2393
- dbs=None, db_type=None, langchain_mode='UserData',
2394
- user_path=None,
 
 
 
2395
  use_openai_embedding=None,
2396
  hf_embedding_model=None,
2397
  caption_loader=None,
2398
  enable_captions=None,
2399
  captions_model=None,
2400
  enable_ocr=None,
 
2401
  verbose=None,
 
2402
  is_url=None, is_txt=None,
2403
- n_jobs=-1):
2404
- assert db1 is not None
2405
  assert chunk is not None
2406
  assert chunk_size is not None
2407
  assert use_openai_embedding is not None
@@ -2410,10 +2651,9 @@ def _update_user_db(file,
2410
  assert enable_captions is not None
2411
  assert captions_model is not None
2412
  assert enable_ocr is not None
 
2413
  assert verbose is not None
2414
 
2415
- set_userid(db1)
2416
-
2417
  if dbs is None:
2418
  dbs = {}
2419
  assert isinstance(dbs, dict), "Wrong type for dbs: %s" % str(type(dbs))
@@ -2431,17 +2671,22 @@ def _update_user_db(file,
2431
  if langchain_mode == LangChainMode.DISABLED.value:
2432
  return None, langchain_mode, get_source_files(), ""
2433
 
2434
- if langchain_mode in [LangChainMode.CHAT_LLM.value, LangChainMode.CHAT_LLM.value]:
2435
  # then switch to MyData, so langchain_mode also becomes way to select where upload goes
2436
  # but default to mydata if nothing chosen, since safest
2437
- langchain_mode = LangChainMode.MY_DATA.value
2438
-
2439
- if langchain_mode == 'UserData' and user_path is not None:
 
 
 
 
 
2440
  # move temp files from gradio upload to stable location
2441
  for fili, fil in enumerate(file):
2442
- if isinstance(fil, str):
2443
- if fil.startswith('/tmp/gradio/'):
2444
- new_fil = os.path.join(user_path, os.path.basename(fil))
2445
  if os.path.isfile(new_fil):
2446
  remove(new_fil)
2447
  try:
@@ -2461,15 +2706,22 @@ def _update_user_db(file,
2461
  enable_captions=enable_captions,
2462
  captions_model=captions_model,
2463
  enable_ocr=enable_ocr,
 
2464
  caption_loader=caption_loader,
2465
  )
2466
  exceptions = [x for x in sources if x.metadata.get('exception')]
2467
  exceptions_strs = [x.metadata['exception'] for x in exceptions]
2468
  sources = [x for x in sources if 'exception' not in x.metadata]
2469
 
2470
- lock_file = get_lock_file(db1, langchain_mode)
 
 
 
 
 
 
2471
  with filelock.FileLock(lock_file):
2472
- if langchain_mode == 'MyData':
2473
  if db1[0] is not None:
2474
  # then add
2475
  db, num_new_sources, new_sources_metadata = add_to_db(db1[0], sources, db_type=db_type,
@@ -2479,7 +2731,8 @@ def _update_user_db(file,
2479
  # in testing expect:
2480
  # assert len(db1) == 2 and db1[1] is None, "Bad MyData db: %s" % db1
2481
  # for production hit, when user gets clicky:
2482
- assert len(db1) == 2, "Bad MyData db: %s" % db1
 
2483
  # then create
2484
  # if added has to original state and didn't change, then would be shared db for all users
2485
  persist_directory = os.path.join(scratch_base_dir, 'db_dir_%s_%s' % (langchain_mode, db1[1]))
@@ -2501,7 +2754,7 @@ def _update_user_db(file,
2501
  use_openai_embedding=use_openai_embedding,
2502
  hf_embedding_model=hf_embedding_model)
2503
  else:
2504
- # then create
2505
  db = get_db(sources, use_openai_embedding=use_openai_embedding,
2506
  db_type=db_type,
2507
  persist_directory=persist_directory,
@@ -2515,14 +2768,15 @@ def _update_user_db(file,
2515
  return None, langchain_mode, source_files_added, '\n'.join(exceptions_strs)
2516
 
2517
 
2518
- def get_db(db1, langchain_mode, dbs=None):
2519
- lock_file = get_lock_file(db1, langchain_mode)
 
2520
 
2521
  with filelock.FileLock(lock_file):
2522
  if langchain_mode in ['wiki_full']:
2523
  # NOTE: avoid showing full wiki. Takes about 30 seconds over about 90k entries, but not useful for now
2524
  db = None
2525
- elif langchain_mode == 'MyData' and len(db1) > 0 and db1[0] is not None:
2526
  db = db1[0]
2527
  elif dbs is not None and langchain_mode in dbs and dbs[langchain_mode] is not None:
2528
  db = dbs[langchain_mode]
@@ -2531,8 +2785,8 @@ def get_db(db1, langchain_mode, dbs=None):
2531
  return db
2532
 
2533
 
2534
- def get_source_files_given_langchain_mode(db1, langchain_mode='UserData', dbs=None):
2535
- db = get_db(db1, langchain_mode, dbs=dbs)
2536
  if langchain_mode in ['ChatLLM', 'LLM'] or db is None:
2537
  return "Sources: N/A"
2538
  return get_source_files(db=db, exceptions=None)
@@ -2631,11 +2885,19 @@ def get_source_files(db=None, exceptions=None, metadatas=None):
2631
  return source_files_added
2632
 
2633
 
2634
- def update_and_get_source_files_given_langchain_mode(db1, langchain_mode, dbs=None, first_para=None,
2635
- text_limit=None, chunk=None, chunk_size=None,
2636
- user_path=None, db_type=None, load_db_if_exists=None,
 
2637
  n_jobs=None, verbose=None):
2638
- db = get_db(db1, langchain_mode, dbs=dbs)
 
 
 
 
 
 
 
2639
 
2640
  from gpt_langchain import make_db
2641
  db, num_new_sources, new_sources_metadata = make_db(use_openai_embedding=False,
@@ -2644,11 +2906,27 @@ def update_and_get_source_files_given_langchain_mode(db1, langchain_mode, dbs=No
2644
  chunk=chunk,
2645
  chunk_size=chunk_size,
2646
  langchain_mode=langchain_mode,
2647
- user_path=user_path,
2648
  db_type=db_type,
2649
  load_db_if_exists=load_db_if_exists,
2650
  db=db,
2651
  n_jobs=n_jobs,
2652
  verbose=verbose)
 
 
 
 
 
 
 
2653
  # return only new sources with text saying such
2654
  return get_source_files(db=None, exceptions=None, metadatas=new_sources_metadata)
 
 
 
 
 
 
 
 
 
 
50
 
51
  fix_pydantic_duplicate_validators_error()
52
 
53
+ from enums import DocumentSubset, no_model_str, no_lora_str, no_server_str, LangChainAction, LangChainMode, \
54
+ DocumentChoice, langchain_modes_intrinsic
55
  from gradio_themes import H2oTheme, SoftTheme, get_h2o_title, get_simple_title, get_dark_js, spacing_xsm, radius_xsm, \
56
  text_xsm
57
  from prompter import prompt_type_to_model_name, prompt_types_strings, inv_prompt_type_to_model_lower, non_hf_types, \
58
  get_prompt
59
+ from utils import flatten_list, zip_data, s3up, clear_torch_cache, get_torch_allocated, system_info_print, \
60
+ ping, get_short_name, makedirs, get_kwargs, remove, system_info, ping_gpu, get_url, get_local_ip, \
61
+ save_collection_names
62
+ from gen import get_model, languages_covered, evaluate, score_qa, inputs_kwargs_list, scratch_base_dir, \
63
+ get_max_max_new_tokens, get_minmax_top_k_docs, history_to_context, langchain_actions, langchain_agents_list, \
64
+ update_langchain
65
+ from evaluate_params import eval_func_param_names, no_default_param_names, eval_func_param_names_defaults, \
66
+ input_args_list
67
 
68
  from apscheduler.schedulers.background import BackgroundScheduler
69
 
 
98
  memory_restriction_level = kwargs['memory_restriction_level']
99
  n_gpus = kwargs['n_gpus']
100
  admin_pass = kwargs['admin_pass']
 
101
  model_states = kwargs['model_states']
 
102
  dbs = kwargs['dbs']
103
  db_type = kwargs['db_type']
 
104
  visible_langchain_actions = kwargs['visible_langchain_actions']
105
  visible_langchain_agents = kwargs['visible_langchain_agents']
106
  allow_upload_to_user_data = kwargs['allow_upload_to_user_data']
 
113
  enable_captions = kwargs['enable_captions']
114
  captions_model = kwargs['captions_model']
115
  enable_ocr = kwargs['enable_ocr']
116
+ enable_pdf_ocr = kwargs['enable_pdf_ocr']
117
  caption_loader = kwargs['caption_loader']
118
 
119
+ # for dynamic state per user session in gradio
120
+ model_state0 = kwargs['model_state0']
121
+ score_model_state0 = kwargs['score_model_state0']
122
+ my_db_state0 = kwargs['my_db_state0']
123
+ selection_docs_state0 = kwargs['selection_docs_state0']
124
+ # for evaluate defaults
125
+ langchain_modes0 = kwargs['langchain_modes']
126
+ visible_langchain_modes0 = kwargs['visible_langchain_modes']
127
+ langchain_mode_paths0 = kwargs['langchain_mode_paths']
128
+
129
  # easy update of kwargs needed for evaluate() etc.
130
  queue = True
131
  allow_upload = allow_upload_to_user_data or allow_upload_to_my_data
 
145
  " use Enter for multiple input lines)"
146
 
147
  title = 'h2oGPT'
148
+ description = """<iframe src="https://ghbtns.com/github-btn.html?user=h2oai&repo=h2ogpt&type=star&count=true&size=small" frameborder="0" scrolling="0" width="250" height="20" title="GitHub"></iframe><small><a href="https://github.com/h2oai/h2ogpt">h2oGPT</a> <a href="https://github.com/h2oai/h2o-llmstudio">H2O LLM Studio</a><br><a href="https://huggingface.co/h2oai">🤗 Models</a>"""
149
+ description_bottom = "If this host is busy, try<br>[Multi-Model](https://gpt.h2o.ai)<br>[Falcon 40B](https://falcon.h2o.ai)<br>[Vicuna 33B](https://wizardvicuna.h2o.ai)<br>[MPT 30B-Chat](https://mpt.h2o.ai)<br>[HF Spaces1](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot)<br>[HF Spaces2](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot2)<br>"
 
 
 
 
 
 
 
 
150
  if is_hf:
151
  description_bottom += '''<a href="https://huggingface.co/spaces/h2oai/h2ogpt-chatbot?duplicate=true"><img src="https://bit.ly/3gLdBN6" style="white-space: nowrap" alt="Duplicate Space"></a>'''
152
+ task_info_md = ''
 
 
 
 
 
 
153
  css_code = get_css(kwargs)
154
 
155
  if kwargs['gradio_offline_level'] >= 0:
 
179
  demo = gr.Blocks(theme=theme, css=css_code, title="h2oGPT", analytics_enabled=False)
180
  callback = gr.CSVLogger()
181
 
182
+ model_options0 = flatten_list(list(prompt_type_to_model_name.values())) + kwargs['extra_model_options']
183
+ if kwargs['base_model'].strip() not in model_options0:
184
+ model_options0 = [kwargs['base_model'].strip()] + model_options0
185
  lora_options = kwargs['extra_lora_options']
186
  if kwargs['lora_weights'].strip() not in lora_options:
187
  lora_options = [kwargs['lora_weights'].strip()] + lora_options
 
196
 
197
  # always add in no lora case
198
  # add fake space so doesn't go away in gradio dropdown
199
+ model_options0 = [no_model_str] + model_options0
200
  lora_options = [no_lora_str] + lora_options
201
  server_options = [no_server_str] + server_options
202
  # always add in no model case so can free memory
 
250
  # else gets input_list at time of submit that is old, and shows up as truncated in chatbot
251
  return x
252
 
253
+ def allow_empty_instruction(langchain_mode1, document_subset1, langchain_action1):
254
+ allow = False
255
+ allow |= langchain_action1 not in LangChainAction.QUERY.value
256
+ allow |= document_subset1 in DocumentSubset.TopKSources.name
257
+ if langchain_mode1 in [LangChainMode.LLM.value]:
258
+ allow = False
259
+ return allow
260
+
261
  with demo:
262
  # avoid actual model/tokenizer here or anything that would be bad to deepcopy
263
  # https://github.com/gradio-app/gradio/issues/3558
 
271
  prompt_dict=kwargs['prompt_dict'],
272
  )
273
  )
274
+
275
+ def update_langchain_mode_paths(db1s, selection_docs_state1):
276
+ if allow_upload_to_my_data:
277
+ selection_docs_state1['langchain_mode_paths'].update({k: None for k in db1s})
278
+ dup = selection_docs_state1['langchain_mode_paths'].copy()
279
+ for k, v in dup.items():
280
+ if k not in selection_docs_state1['visible_langchain_modes']:
281
+ selection_docs_state1['langchain_mode_paths'].pop(k)
282
+ return selection_docs_state1
283
+
284
+ # Setup some gradio states for per-user dynamic state
285
  model_state2 = gr.State(kwargs['model_state_none'].copy())
286
+ model_options_state = gr.State([model_options0])
287
  lora_options_state = gr.State([lora_options])
288
  server_options_state = gr.State([server_options])
289
+ my_db_state = gr.State(my_db_state0)
290
  chat_state = gr.State({})
291
+ docs_state00 = kwargs['document_choice'] + [DocumentChoice.ALL.value]
292
  docs_state0 = []
293
  [docs_state0.append(x) for x in docs_state00 if x not in docs_state0]
294
  docs_state = gr.State(docs_state0)
295
  viewable_docs_state0 = []
296
  viewable_docs_state = gr.State(viewable_docs_state0)
297
+ selection_docs_state0 = update_langchain_mode_paths(my_db_state0, selection_docs_state0)
298
+ selection_docs_state = gr.State(selection_docs_state0)
299
+
300
  gr.Markdown(f"""
301
  {get_h2o_title(title, description) if kwargs['h2ocolors'] else get_simple_title(title, description)}
302
  """)
 
310
  'model_lock'] else "Response Scores: %s" % nas
311
 
312
  if kwargs['langchain_mode'] != LangChainMode.DISABLED.value:
313
+ extra_prompt_form = ". For summarization, no query required, just click submit"
314
  else:
315
  extra_prompt_form = ""
316
  if kwargs['input_lines'] > 1:
 
318
  else:
319
  instruction_label = "Enter to Submit, Shift-Enter for more lines%s" % extra_prompt_form
320
 
321
+ def get_langchain_choices(selection_docs_state1):
322
+ langchain_modes = selection_docs_state1['langchain_modes']
323
+ visible_langchain_modes = selection_docs_state1['visible_langchain_modes']
324
+
325
+ if is_hf:
326
+ # don't show 'wiki' since only usually useful for internal testing at moment
327
+ no_show_modes = ['Disabled', 'wiki']
328
+ else:
329
+ no_show_modes = ['Disabled']
330
+ allowed_modes = visible_langchain_modes.copy()
331
+ # allowed_modes = [x for x in allowed_modes if x in dbs]
332
+ allowed_modes += ['LLM']
333
+ if allow_upload_to_my_data and 'MyData' not in allowed_modes:
334
+ allowed_modes += ['MyData']
335
+ if allow_upload_to_user_data and 'UserData' not in allowed_modes:
336
+ allowed_modes += ['UserData']
337
+ choices = [x for x in langchain_modes if x in allowed_modes and x not in no_show_modes]
338
+ return choices
339
+
340
+ def get_df_langchain_mode_paths(selection_docs_state1):
341
+ langchain_mode_paths = selection_docs_state1['langchain_mode_paths']
342
+ if langchain_mode_paths:
343
+ df = pd.DataFrame.from_dict(langchain_mode_paths.items(), orient='columns')
344
+ df.columns = ['Collection', 'Path']
345
+ else:
346
+ df = pd.DataFrame(None)
347
+ return df
348
+
349
  normal_block = gr.Row(visible=not base_wanted, equal_height=False)
350
  with normal_block:
351
  side_bar = gr.Column(elem_id="col_container", scale=1, min_width=100)
 
366
  scale=1,
367
  min_width=0,
368
  elem_id="warning", elem_classes="feedback")
369
+ fileup_output_text = gr.Textbox(visible=False)
370
  url_visible = kwargs['langchain_mode'] != 'Disabled' and allow_upload and enable_url_upload
371
  url_label = 'URL/ArXiv' if have_arxiv else 'URL'
372
  url_text = gr.Textbox(label=url_label,
 
380
  visible=text_visible)
381
  github_textbox = gr.Textbox(label="Github URL", visible=False) # FIXME WIP
382
  database_visible = kwargs['langchain_mode'] != 'Disabled'
383
+ with gr.Accordion("Resources", open=False, visible=database_visible):
384
+ langchain_choices0 = get_langchain_choices(selection_docs_state0)
 
 
 
 
 
 
 
 
 
 
 
385
  langchain_mode = gr.Radio(
386
+ langchain_choices0,
387
  value=kwargs['langchain_mode'],
388
  label="Collections",
389
  show_label=True,
390
  visible=kwargs['langchain_mode'] != 'Disabled',
391
  min_width=100)
392
+ add_chat_history_to_context = gr.Checkbox(label="Chat History",
393
+ value=kwargs['add_chat_history_to_context'])
394
+ document_subset = gr.Radio([x.name for x in DocumentSubset],
395
  label="Subset",
396
+ value=DocumentSubset.Relevant.name,
397
  interactive=True,
398
  )
399
  allowed_actions = [x for x in langchain_actions if x in visible_langchain_actions]
 
457
  mw1 = 50
458
  mw2 = 50
459
  with gr.Column(min_width=mw1):
460
+ submit = gr.Button(value='Submit', variant='primary', size='sm',
461
  min_width=mw1)
462
+ stop_btn = gr.Button(value="Stop", variant='secondary', size='sm',
463
  min_width=mw1)
464
  save_chat_btn = gr.Button("Save", size='sm', min_width=mw1)
465
  with gr.Column(min_width=mw2):
 
480
  with gr.TabItem("Document Selection"):
481
  document_choice = gr.Dropdown(docs_state0,
482
  label="Select Subset of Document(s) %s" % file_types_str,
483
+ value=[DocumentChoice.ALL.value],
484
  interactive=True,
485
  multiselect=True,
486
  visible=kwargs['langchain_mode'] != 'Disabled',
487
  )
488
  sources_visible = kwargs['langchain_mode'] != 'Disabled' and enable_sources_list
489
  with gr.Row():
490
+ with gr.Column(scale=1):
491
+ get_sources_btn = gr.Button(value="Update UI with Document(s) from DB", scale=0, size='sm',
492
+ visible=sources_visible)
493
+ show_sources_btn = gr.Button(value="Show Sources from DB", scale=0, size='sm',
494
+ visible=sources_visible)
495
+ refresh_sources_btn = gr.Button(value="Update DB with new/changed files on disk", scale=0,
496
+ size='sm',
497
+ visible=sources_visible and allow_upload_to_user_data)
498
+ with gr.Column(scale=4):
499
+ pass
500
+ with gr.Row():
501
+ with gr.Column(scale=1):
502
+ add_placeholder = "e.g. UserData2, user_path2 (optional)" \
503
+ if not is_public else "e.g. MyData2"
504
+ remove_placeholder = "e.g. UserData2" if not is_public else "e.g. MyData2"
505
+ new_langchain_mode_text = gr.Textbox(value="", visible=allow_upload_to_user_data or
506
+ allow_upload_to_my_data,
507
+ label='Add Collection',
508
+ placeholder=add_placeholder,
509
+ interactive=True)
510
+ remove_langchain_mode_text = gr.Textbox(value="", visible=allow_upload_to_user_data or
511
+ allow_upload_to_my_data,
512
+ label='Remove Collection',
513
+ placeholder=remove_placeholder,
514
+ interactive=True)
515
+ load_langchain = gr.Button(value="Load LangChain State", scale=0, size='sm',
516
+ visible=allow_upload_to_user_data)
517
+ with gr.Column(scale=1):
518
+ df0 = get_df_langchain_mode_paths(selection_docs_state0)
519
+ langchain_mode_path_text = gr.Dataframe(value=df0,
520
+ visible=allow_upload_to_user_data or
521
+ allow_upload_to_my_data,
522
+ label='LangChain Mode-Path',
523
+ show_label=False,
524
+ interactive=False)
525
+ with gr.Column(scale=4):
526
+ pass
527
 
528
  sources_row = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and enable_sources_list,
529
  equal_height=False)
 
793
  side_bar_btn = gr.Button("Toggle SideBar", variant="secondary", size="sm")
794
  submit_buttons_btn = gr.Button("Toggle Submit Buttons", variant="secondary", size="sm")
795
  col_tabs_scale = gr.Slider(minimum=1, maximum=20, value=10, step=1, label='Window Size')
796
+ text_outputs_height = gr.Slider(minimum=100, maximum=2000, value=kwargs['height'] or 400,
797
+ step=50, label='Chat Height')
798
  dark_mode_btn = gr.Button("Dark Mode", variant="secondary", size="sm")
799
  with gr.Column(scale=4):
800
  pass
801
+ system_visible0 = not is_public and not admin_pass
802
  admin_row = gr.Row()
803
  with admin_row:
804
  with gr.Column(scale=1):
805
+ admin_pass_textbox = gr.Textbox(label="Admin Password", type='password',
806
+ visible=not system_visible0)
807
  with gr.Column(scale=4):
808
  pass
809
+ system_row = gr.Row(visible=system_visible0)
810
  with system_row:
811
  with gr.Column():
812
  with gr.Row():
 
870
  else:
871
  return tuple([gr.update(interactive=True)] * len(args))
872
 
873
+ # Add to UserData or custom user db
874
  update_db_func = functools.partial(update_user_db,
875
  dbs=dbs,
876
  db_type=db_type,
877
  use_openai_embedding=use_openai_embedding,
878
  hf_embedding_model=hf_embedding_model,
 
879
  captions_model=captions_model,
880
+ enable_captions=enable_captions,
881
  caption_loader=caption_loader,
882
+ enable_ocr=enable_ocr,
883
+ enable_pdf_ocr=enable_pdf_ocr,
884
  verbose=kwargs['verbose'],
 
885
  n_jobs=kwargs['n_jobs'],
886
  )
887
  add_file_outputs = [fileup_output, langchain_mode]
888
  add_file_kwargs = dict(fn=update_db_func,
889
+ inputs=[fileup_output, my_db_state, selection_docs_state, chunk, chunk_size,
890
+ langchain_mode],
891
  outputs=add_file_outputs + [sources_text, doc_exception_text],
892
  queue=queue,
893
  api_name='add_file' if allow_api and allow_upload_to_user_data else None)
 
899
  eventdb1b = eventdb1.then(make_interactive, inputs=add_file_outputs, outputs=add_file_outputs,
900
  show_progress='minimal')
901
 
902
+ # deal with challenge to have fileup_output itself as input
903
+ add_file_kwargs2 = dict(fn=update_db_func,
904
+ inputs=[fileup_output_text, my_db_state, selection_docs_state, chunk, chunk_size,
905
+ langchain_mode],
906
+ outputs=add_file_outputs + [sources_text, doc_exception_text],
907
+ queue=queue,
908
+ api_name='add_file_api' if allow_api and allow_upload_to_user_data else None)
909
+ eventdb1_api = fileup_output_text.submit(**add_file_kwargs2, show_progress='full')
910
+
911
  # note for update_user_db_func output is ignored for db
912
 
913
  def clear_textbox():
 
917
 
918
  add_url_outputs = [url_text, langchain_mode]
919
  add_url_kwargs = dict(fn=update_user_db_url_func,
920
+ inputs=[url_text, my_db_state, selection_docs_state, chunk, chunk_size,
921
+ langchain_mode],
922
  outputs=add_url_outputs + [sources_text, doc_exception_text],
923
  queue=queue,
924
  api_name='add_url' if allow_api and allow_upload_to_user_data else None)
 
935
  update_user_db_txt_func = functools.partial(update_db_func, is_txt=True)
936
  add_text_outputs = [user_text_text, langchain_mode]
937
  add_text_kwargs = dict(fn=update_user_db_txt_func,
938
+ inputs=[user_text_text, my_db_state, selection_docs_state, chunk, chunk_size,
939
+ langchain_mode],
940
  outputs=add_text_outputs + [sources_text, doc_exception_text],
941
  queue=queue,
942
  api_name='add_text' if allow_api and allow_upload_to_user_data else None
 
948
  eventdb3 = eventdb3b.then(**add_text_kwargs, show_progress='full')
949
  eventdb3c = eventdb3.then(make_interactive, inputs=add_text_outputs, outputs=add_text_outputs,
950
  show_progress='minimal')
951
+ db_events = [eventdb1a, eventdb1, eventdb1b, eventdb1_api,
952
  eventdb2a, eventdb2, eventdb2b, eventdb2c,
953
  eventdb3a, eventdb3b, eventdb3, eventdb3c]
954
 
 
956
 
957
  # if change collection source, must clear doc selections from it to avoid inconsistency
958
  def clear_doc_choice():
959
+ return gr.Dropdown.update(choices=docs_state0, value=DocumentChoice.ALL.value)
960
 
961
  langchain_mode.change(clear_doc_choice, inputs=None, outputs=document_choice, queue=False)
962
 
963
  def resize_col_tabs(x):
964
  return gr.Dropdown.update(scale=x)
965
 
966
+ col_tabs_scale.change(fn=resize_col_tabs, inputs=col_tabs_scale, outputs=col_tabs, queue=False)
967
 
968
  def resize_chatbots(x, num_model_lock=0):
969
  if num_model_lock == 0:
 
974
 
975
  resize_chatbots_func = functools.partial(resize_chatbots, num_model_lock=len(text_outputs))
976
  text_outputs_height.change(fn=resize_chatbots_func, inputs=text_outputs_height,
977
+ outputs=[text_output, text_output2] + text_outputs, queue=False)
978
 
979
  def update_dropdown(x):
980
  return gr.Dropdown.update(choices=x, value=[docs_state0[0]])
 
1065
  if file.startswith('http') or file.startswith('https'):
1066
  # if file is online, then might as well use google(?)
1067
  document1 = file
1068
+ return gr.update(visible=True,
1069
+ value=f"""<iframe width="1000" height="800" src="https://docs.google.com/viewerng/viewer?url={document1}&embedded=true" frameborder="0" height="100%" width="100%">
1070
  </iframe>
1071
  """), dummy1, dummy1, dummy1
1072
  else:
 
1089
 
1090
  refresh_sources1 = functools.partial(update_and_get_source_files_given_langchain_mode,
1091
  **get_kwargs(update_and_get_source_files_given_langchain_mode,
1092
+ exclude_names=['db1s', 'langchain_mode', 'chunk',
1093
+ 'chunk_size'],
1094
  **all_kwargs))
1095
+ eventdb9 = refresh_sources_btn.click(fn=refresh_sources1,
1096
+ inputs=[my_db_state, langchain_mode, chunk, chunk_size],
1097
  outputs=sources_text,
1098
  api_name='refresh_sources' if allow_api else None)
1099
 
 
1103
  def close_admin(x):
1104
  return gr.update(visible=not (x == admin_pass))
1105
 
1106
+ admin_pass_textbox.submit(check_admin_pass, inputs=admin_pass_textbox, outputs=system_row, queue=False) \
1107
  .then(close_admin, inputs=admin_pass_textbox, outputs=admin_row, queue=False)
1108
 
1109
+ def add_langchain_mode(db1s, selection_docs_state1, langchain_mode1, y):
1110
+ for k in db1s:
1111
+ set_userid(db1s[k])
1112
+ langchain_modes = selection_docs_state1['langchain_modes']
1113
+ langchain_mode_paths = selection_docs_state1['langchain_mode_paths']
1114
+ visible_langchain_modes = selection_docs_state1['visible_langchain_modes']
1115
+
1116
+ user_path = None
1117
+ valid = True
1118
+ y2 = y.strip().replace(' ', '').split(',')
1119
+ if len(y2) >= 1:
1120
+ langchain_mode2 = y2[0]
1121
+ if len(langchain_mode2) >= 3 and langchain_mode2.isalnum():
1122
+ # real restriction is:
1123
+ # ValueError: Expected collection name that (1) contains 3-63 characters, (2) starts and ends with an alphanumeric character, (3) otherwise contains only alphanumeric characters, underscores or hyphens (-), (4) contains no two consecutive periods (..) and (5) is not a valid IPv4 address, got me
1124
+ # but just make simpler
1125
+ user_path = y2[1] if len(y2) > 1 else None # assume scratch if don't have user_path
1126
+ if user_path in ['', "''"]:
1127
+ # for scratch spaces
1128
+ user_path = None
1129
+ if langchain_mode2 in langchain_modes_intrinsic:
1130
+ user_path = None
1131
+ textbox = "Invalid access to use internal name: %s" % langchain_mode2
1132
+ valid = False
1133
+ langchain_mode2 = langchain_mode1
1134
+ elif user_path and allow_upload_to_user_data or not user_path and allow_upload_to_my_data:
1135
+ langchain_mode_paths.update({langchain_mode2: user_path})
1136
+ if langchain_mode2 not in visible_langchain_modes:
1137
+ visible_langchain_modes.append(langchain_mode2)
1138
+ if langchain_mode2 not in langchain_modes:
1139
+ langchain_modes.append(langchain_mode2)
1140
+ textbox = ''
1141
+ if user_path:
1142
+ makedirs(user_path, exist_ok=True)
1143
+ else:
1144
+ valid = False
1145
+ langchain_mode2 = langchain_mode1
1146
+ textbox = "Invalid access. user allowed: %s " \
1147
+ "scratch allowed: %s" % (allow_upload_to_user_data, allow_upload_to_my_data)
1148
+ else:
1149
+ valid = False
1150
+ langchain_mode2 = langchain_mode1
1151
+ textbox = "Invalid, collection must be >=3 characters and alphanumeric"
1152
+ else:
1153
+ valid = False
1154
+ langchain_mode2 = langchain_mode1
1155
+ textbox = "Invalid, must be like UserData2, user_path2"
1156
+ selection_docs_state1 = update_langchain_mode_paths(db1s, selection_docs_state1)
1157
+ df_langchain_mode_paths1 = get_df_langchain_mode_paths(selection_docs_state1)
1158
+ choices = get_langchain_choices(selection_docs_state1)
1159
+
1160
+ if valid and not user_path:
1161
+ # needs to have key for it to make it known different from userdata case in _update_user_db()
1162
+ db1s[langchain_mode2] = [None, None]
1163
+ if valid:
1164
+ save_collection_names(langchain_modes, visible_langchain_modes, langchain_mode_paths, LangChainMode,
1165
+ db1s)
1166
+
1167
+ return db1s, selection_docs_state1, gr.update(choices=choices,
1168
+ value=langchain_mode2), textbox, df_langchain_mode_paths1
1169
+
1170
+ def remove_langchain_mode(db1s, selection_docs_state1, langchain_mode1, langchain_mode2, dbsu=None):
1171
+ for k in db1s:
1172
+ set_userid(db1s[k])
1173
+ assert dbsu is not None
1174
+ langchain_modes = selection_docs_state1['langchain_modes']
1175
+ langchain_mode_paths = selection_docs_state1['langchain_mode_paths']
1176
+ visible_langchain_modes = selection_docs_state1['visible_langchain_modes']
1177
+
1178
+ if langchain_mode2 in db1s and not allow_upload_to_my_data or \
1179
+ dbsu is not None and langchain_mode2 in dbsu and not allow_upload_to_user_data or \
1180
+ langchain_mode2 in langchain_modes_intrinsic:
1181
+ # NOTE: Doesn't fail if remove MyData, but didn't debug odd behavior seen with upload after gone
1182
+ textbox = "Invalid access, cannot remove %s" % langchain_mode2
1183
+ df_langchain_mode_paths1 = get_df_langchain_mode_paths(selection_docs_state1)
1184
+ else:
1185
+ # change global variables
1186
+ if langchain_mode2 in visible_langchain_modes:
1187
+ visible_langchain_modes.remove(langchain_mode2)
1188
+ textbox = ""
1189
+ else:
1190
+ textbox = "%s was not visible" % langchain_mode2
1191
+ if langchain_mode2 in langchain_modes:
1192
+ langchain_modes.remove(langchain_mode2)
1193
+ if langchain_mode2 in langchain_mode_paths:
1194
+ langchain_mode_paths.pop(langchain_mode2)
1195
+ if langchain_mode2 in db1s:
1196
+ # remove db entirely, so not in list, else need to manage visible list in update_langchain_mode_paths()
1197
+ # FIXME: Remove location?
1198
+ if langchain_mode2 != LangChainMode.MY_DATA.value:
1199
+ # don't remove last MyData, used as user hash
1200
+ db1s.pop(langchain_mode2)
1201
+ # only show
1202
+ selection_docs_state1 = update_langchain_mode_paths(db1s, selection_docs_state1)
1203
+ df_langchain_mode_paths1 = get_df_langchain_mode_paths(selection_docs_state1)
1204
+
1205
+ save_collection_names(langchain_modes, visible_langchain_modes, langchain_mode_paths, LangChainMode,
1206
+ db1s)
1207
+
1208
+ return db1s, selection_docs_state1, \
1209
+ gr.update(choices=get_langchain_choices(selection_docs_state1),
1210
+ value=langchain_mode2), textbox, df_langchain_mode_paths1
1211
+
1212
+ new_langchain_mode_text.submit(fn=add_langchain_mode,
1213
+ inputs=[my_db_state, selection_docs_state, langchain_mode,
1214
+ new_langchain_mode_text],
1215
+ outputs=[my_db_state, selection_docs_state, langchain_mode,
1216
+ new_langchain_mode_text,
1217
+ langchain_mode_path_text],
1218
+ api_name='new_langchain_mode_text' if allow_api and allow_upload_to_user_data else None)
1219
+ remove_langchain_mode_func = functools.partial(remove_langchain_mode, dbsu=dbs)
1220
+ remove_langchain_mode_text.submit(fn=remove_langchain_mode_func,
1221
+ inputs=[my_db_state, selection_docs_state, langchain_mode,
1222
+ remove_langchain_mode_text],
1223
+ outputs=[my_db_state, selection_docs_state, langchain_mode,
1224
+ remove_langchain_mode_text,
1225
+ langchain_mode_path_text],
1226
+ api_name='remove_langchain_mode_text' if allow_api and allow_upload_to_user_data else None)
1227
+
1228
+ def update_langchain_gr(db1s, selection_docs_state1, langchain_mode1):
1229
+ for k in db1s:
1230
+ set_userid(db1s[k])
1231
+ langchain_modes = selection_docs_state1['langchain_modes']
1232
+ langchain_mode_paths = selection_docs_state1['langchain_mode_paths']
1233
+ visible_langchain_modes = selection_docs_state1['visible_langchain_modes']
1234
+ # in-place
1235
+
1236
+ # update user collaborative collections
1237
+ update_langchain(langchain_modes, visible_langchain_modes, langchain_mode_paths, '')
1238
+ # update scratch single-user collections
1239
+ user_hash = db1s.get(LangChainMode.MY_DATA.value, '')[1]
1240
+ update_langchain(langchain_modes, visible_langchain_modes, langchain_mode_paths, user_hash)
1241
+
1242
+ selection_docs_state1 = update_langchain_mode_paths(db1s, selection_docs_state1)
1243
+ df_langchain_mode_paths1 = get_df_langchain_mode_paths(selection_docs_state1)
1244
+ return selection_docs_state1, \
1245
+ gr.update(choices=get_langchain_choices(selection_docs_state1),
1246
+ value=langchain_mode1), df_langchain_mode_paths1
1247
+
1248
+ load_langchain.click(fn=update_langchain_gr,
1249
+ inputs=[my_db_state, selection_docs_state, langchain_mode],
1250
+ outputs=[selection_docs_state, langchain_mode, langchain_mode_path_text],
1251
+ api_name='load_langchain' if allow_api and allow_upload_to_user_data else None)
1252
+
1253
  inputs_list, inputs_dict = get_inputs_list(all_kwargs, kwargs['model_lower'], model_id=1)
1254
  inputs_list2, inputs_dict2 = get_inputs_list(all_kwargs, kwargs['model_lower'], model_id=2)
1255
  from functools import partial
 
1261
  def evaluate_nochat(*args1, default_kwargs1=None, str_api=False, **kwargs1):
1262
  args_list = list(args1)
1263
  if str_api:
1264
+ user_kwargs = args_list[len(input_args_list)]
1265
  assert isinstance(user_kwargs, str)
1266
  user_kwargs = ast.literal_eval(user_kwargs)
1267
  else:
1268
+ user_kwargs = {k: v for k, v in zip(eval_func_param_names, args_list[len(input_args_list):])}
1269
  # only used for submit_nochat_api
1270
  user_kwargs['chat'] = False
1271
  if 'stream_output' not in user_kwargs:
 
1284
  # correct ordering. Note some things may not be in default_kwargs, so can't be default of user_kwargs.get()
1285
  model_state1 = args_list[0]
1286
  my_db_state1 = args_list[1]
1287
+ selection_docs_state1 = args_list[2]
1288
  args_list = [user_kwargs[k] if k in user_kwargs and user_kwargs[k] is not None else default_kwargs1[k] for k
1289
  in eval_func_param_names]
1290
  assert len(args_list) == len(eval_func_param_names)
1291
+ args_list = [model_state1, my_db_state1, selection_docs_state1] + args_list
1292
 
1293
  try:
1294
  for res_dict in evaluate(*tuple(args_list), **kwargs1):
 
1492
  history[-1][1] = None
1493
  return history
1494
  if user_message1 in ['', None, '\n']:
1495
+ if not allow_empty_instruction(langchain_mode1, document_subset1, langchain_action1):
 
 
 
1496
  # reject non-retry submit/enter
1497
  return history
1498
  user_message1 = fix_text_for_gradio(user_message1)
 
1539
  API only called for which_model=0, default for inputs_list, but rest should ignore inputs_list
1540
  :return: last element is True if should run bot, False if should just yield history
1541
  """
1542
+ isize = len(input_args_list) + 1 # states + chat history
1543
  # don't deepcopy, can contain model itself
1544
  args_list = list(args).copy()
1545
+ model_state1 = args_list[-isize]
1546
+ my_db_state1 = args_list[-isize + 1]
1547
+ selection_docs_state1 = args_list[-isize + 2]
1548
  history = args_list[-1]
1549
  prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
1550
  prompt_dict1 = args_list[eval_func_param_names.index('prompt_dict')]
 
1552
  if model_state1['model'] is None or model_state1['model'] == no_model_str:
1553
  return history, None, None, None
1554
 
1555
+ args_list = args_list[:-isize] # only keep rest needed for evaluate()
1556
  langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
1557
+ add_chat_history_to_context1 = args_list[eval_func_param_names.index('add_chat_history_to_context')]
1558
  langchain_action1 = args_list[eval_func_param_names.index('langchain_action')]
1559
  langchain_agents1 = args_list[eval_func_param_names.index('langchain_agents')]
1560
  document_subset1 = args_list[eval_func_param_names.index('document_subset')]
 
1569
  instruction1 = history[-1][0]
1570
  history[-1][1] = None
1571
  elif not instruction1:
1572
+ if not allow_empty_instruction(langchain_mode1, document_subset1, langchain_action1):
 
 
 
1573
  # if not retrying, then reject empty query
1574
  return history, None, None, None
1575
  elif len(history) > 0 and history[-1][1] not in [None, '']:
 
1586
 
1587
  chat1 = args_list[eval_func_param_names.index('chat')]
1588
  model_max_length1 = get_model_max_length(model_state1)
1589
+ context1 = history_to_context(history, langchain_mode1,
1590
+ add_chat_history_to_context1,
1591
+ prompt_type1, prompt_dict1, chat1,
1592
  model_max_length1, memory_restriction_level,
1593
  kwargs['keep_sources_in_context'])
1594
  args_list[0] = instruction1 # override original instruction with history from user
 
1597
  fun1 = partial(evaluate,
1598
  model_state1,
1599
  my_db_state1,
1600
+ selection_docs_state1,
1601
  *tuple(args_list),
1602
  **kwargs_evaluate)
1603
 
 
1643
  clear_torch_cache()
1644
  return
1645
 
1646
+ def clear_embeddings(langchain_mode1, db1s):
1647
  # clear any use of embedding that sits on GPU, else keeps accumulating GPU usage even if clear torch cache
1648
+ if db_type == 'chroma' and langchain_mode1 not in ['LLM', 'Disabled', None, '']:
1649
  from gpt_langchain import clear_embedding
1650
  db = dbs.get('langchain_mode1')
1651
  if db is not None and not isinstance(db, str):
1652
  clear_embedding(db)
1653
+ if db1s is not None and langchain_mode1 in db1s:
1654
+ db1 = db1s[langchain_mode1]
1655
+ if len(db1) == 2:
1656
+ clear_embedding(db1[0])
1657
 
1658
  def bot(*args, retry=False):
1659
+ history, fun1, langchain_mode1, db1 = prep_bot(*args, retry=retry)
1660
  try:
1661
  for res in get_response(fun1, history):
1662
  yield res
1663
  finally:
1664
  clear_torch_cache()
1665
+ clear_embeddings(langchain_mode1, db1)
1666
 
1667
  def all_bot(*args, retry=False, model_states1=None):
1668
  args_list = list(args).copy()
 
1672
  stream_output1 = args_list[eval_func_param_names.index('stream_output')]
1673
  max_time1 = args_list[eval_func_param_names.index('max_time')]
1674
  langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
1675
+ isize = len(input_args_list) + 1 # states + chat history
1676
+ db1s = None
1677
  try:
1678
  gen_list = []
1679
  for chatboti, (chatbot1, model_state1) in enumerate(zip(chatbots, model_states1)):
1680
  args_list1 = args_list0.copy()
1681
+ args_list1.insert(-isize + 2,
1682
+ model_state1) # insert at -2 so is at -3, and after chatbot1 added, at -4
1683
  # if at start, have None in response still, replace with '' so client etc. acts like normal
1684
  # assumes other parts of code treat '' and None as if no response yet from bot
1685
  # can't do this later in bot code as racy with threaded generators
 
1689
  # so consistent with prep_bot()
1690
  # with model_state1 at -3, my_db_state1 at -2, and history(chatbot) at -1
1691
  # langchain_mode1 and my_db_state1 should be same for every bot
1692
+ history, fun1, langchain_mode1, db1s = prep_bot(*tuple(args_list1), retry=retry,
1693
+ which_model=chatboti)
1694
  gen1 = get_response(fun1, history)
1695
  if stream_output1:
1696
  gen1 = TimeoutIterator(gen1, timeout=0.01, sentinel=None, raise_on_exception=False)
 
1736
  print("Generate exceptions: %s" % exceptions, flush=True)
1737
  finally:
1738
  clear_torch_cache()
1739
+ clear_embeddings(langchain_mode1, db1s)
1740
 
1741
  # NORMAL MODEL
1742
  user_args = dict(fn=functools.partial(user, sanitize_user_prompt=kwargs['sanitize_user_prompt']),
 
1744
  outputs=text_output,
1745
  )
1746
  bot_args = dict(fn=bot,
1747
+ inputs=inputs_list + [model_state, my_db_state, selection_docs_state] + [text_output],
1748
  outputs=[text_output, chat_exception_text],
1749
  )
1750
  retry_bot_args = dict(fn=functools.partial(bot, retry=True),
1751
+ inputs=inputs_list + [model_state, my_db_state, selection_docs_state] + [text_output],
1752
  outputs=[text_output, chat_exception_text],
1753
  )
1754
  retry_user_args = dict(fn=functools.partial(user, retry=True),
 
1766
  outputs=text_output2,
1767
  )
1768
  bot_args2 = dict(fn=bot,
1769
+ inputs=inputs_list2 + [model_state2, my_db_state, selection_docs_state] + [text_output2],
1770
  outputs=[text_output2, chat_exception_text],
1771
  )
1772
  retry_bot_args2 = dict(fn=functools.partial(bot, retry=True),
1773
+ inputs=inputs_list2 + [model_state2, my_db_state, selection_docs_state] + [text_output2],
1774
  outputs=[text_output2, chat_exception_text],
1775
  )
1776
  retry_user_args2 = dict(fn=functools.partial(user, retry=True),
 
1791
  outputs=text_outputs,
1792
  )
1793
  all_bot_args = dict(fn=functools.partial(all_bot, model_states1=model_states),
1794
+ inputs=inputs_list + [my_db_state, selection_docs_state] + text_outputs,
1795
  outputs=text_outputs + [chat_exception_text],
1796
  )
1797
  all_retry_bot_args = dict(fn=functools.partial(all_bot, model_states1=model_states, retry=True),
1798
+ inputs=inputs_list + [my_db_state, selection_docs_state] + text_outputs,
1799
  outputs=text_outputs + [chat_exception_text],
1800
  )
1801
  all_retry_user_args = dict(fn=functools.partial(all_user, retry=True,
 
1957
  def get_short_chat(x, short_chats, short_len=20, words=4):
1958
  if x and len(x[0]) == 2 and x[0][0] is not None:
1959
  short_chat = ' '.join(x[0][0][:short_len].split(' ')[:words]).strip()
1960
+ if not short_chat:
1961
+ # e.g.summarization, try using answer
1962
+ short_chat = ' '.join(x[0][1][:short_len].split(' ')[:words]).strip()
1963
+ if not short_chat:
1964
+ short_chat = 'Unk'
1965
  short_chat = dedup(short_chat, short_chats)
1966
  else:
1967
  short_chat = None
 
2029
  already_exists = any([is_chat_same(chat_list, x) for x in old_chat_lists])
2030
  if not already_exists:
2031
  chat_state1[short_chat] = chat_list.copy()
2032
+
2033
+ # reverse so newest at top
2034
+ choices = list(chat_state1.keys()).copy()
2035
+ choices.reverse()
2036
+
2037
+ return chat_state1, gr.update(choices=choices, value=None)
 
 
2038
 
2039
  def switch_chat(chat_key, chat_state1, num_model_lock=0):
2040
  chosen_chat = chat_state1[chat_key]
 
2065
 
2066
  remove_chat_event = remove_chat_btn.click(remove_chat,
2067
  inputs=[radio_chats, chat_state], outputs=[radio_chats, chat_state],
2068
+ queue=False, api_name='remove_chat')
2069
 
2070
  def get_chats1(chat_state1):
2071
  base = 'chats'
 
2096
  new_chats = json.loads(f.read())
2097
  for chat1_k, chat1_v in new_chats.items():
2098
  # ignore chat1_k, regenerate and de-dup to avoid loss
2099
+ chat_state1, _ = save_chat(chat1_v, chat_state1, chat_is_list=True)
2100
  except BaseException as e:
2101
  t, v, tb = sys.exc_info()
2102
  ex = ''.join(traceback.format_exception(t, v, tb))
 
2122
  .then(deselect_radio_chats, inputs=None, outputs=radio_chats, queue=False) \
2123
  .then(clear_scores, outputs=[score_text, score_text2, score_text_nochat])
2124
 
 
 
 
 
 
 
2125
  clear_event = save_chat_btn.click(save_chat,
2126
  inputs=[text_output, text_output2] + text_outputs + [chat_state],
2127
+ outputs=[chat_state, radio_chats],
2128
+ api_name='save_chat' if allow_api else None)
2129
+ if kwargs['score_model']:
2130
+ clear_event2 = clear_event.then(clear_scores, outputs=[score_text, score_text2, score_text_nochat])
 
2131
 
2132
  # NOTE: clear of instruction/iinput for nochat has to come after score,
2133
  # because score for nochat consumes actual textbox, while chat consumes chat history filled by user()
2134
  no_chat_args = dict(fn=fun,
2135
+ inputs=[model_state, my_db_state, selection_docs_state] + inputs_list,
2136
  outputs=text_output_nochat,
2137
  queue=queue,
2138
  )
 
2151
  .then(clear_torch_cache)
2152
 
2153
  submit_event_nochat_api = submit_nochat_api.click(fun_with_dict_str,
2154
+ inputs=[model_state, my_db_state, selection_docs_state,
2155
+ inputs_dict_str],
2156
  outputs=text_output_nochat_api,
2157
  queue=True, # required for generator
2158
  api_name='submit_nochat_api' if allow_api else None) \
 
2402
  print("Exception: %s" % str(e), flush=True)
2403
  return json.dumps(sys_dict)
2404
 
2405
+ system_kwargs = all_kwargs.copy()
2406
+ system_kwargs.update(dict(command=str(' '.join(sys.argv))))
2407
  get_system_info_dict_func = functools.partial(get_system_info_dict, **all_kwargs)
2408
 
2409
  system_dict_event = system_btn2.click(get_system_info_dict_func,
 
2433
  else:
2434
  tokenizer = None
2435
  if tokenizer is not None:
2436
+ langchain_mode1 = 'LLM'
2437
+ add_chat_history_to_context1 = True
2438
  # fake user message to mimic bot()
2439
  chat1 = copy.deepcopy(chat1)
2440
  chat1 = chat1 + [['user_message1', None]]
2441
  model_max_length1 = tokenizer.model_max_length
2442
+ context1 = history_to_context(chat1, langchain_mode1,
2443
+ add_chat_history_to_context1,
2444
+ prompt_type1, prompt_dict1, chat1,
2445
  model_max_length1,
2446
  memory_restriction_level1, keep_sources_in_context1)
2447
  return str(tokenizer(context1, return_tensors="pt")['input_ids'].shape[1])
 
2471
  ,
2472
  queue=False, api_name='stop' if allow_api else None).then(clear_torch_cache, queue=False)
2473
 
2474
+ demo.load(None, None, None, _js=get_dark_js() if kwargs['dark'] else None)
2475
 
2476
  demo.queue(concurrency_count=kwargs['concurrency_count'], api_open=kwargs['api_open'])
2477
  favicon_path = "h2o-logo.svg"
 
2486
  # FIXME: disable for gptj, langchain or gpt4all modify print itself
2487
  # FIXME: and any multi-threaded/async print will enter model output!
2488
  scheduler.add_job(func=ping, trigger="interval", seconds=60)
2489
+ if is_public or os.getenv('PING_GPU'):
2490
+ scheduler.add_job(func=ping_gpu, trigger="interval", seconds=60 * 10)
2491
  scheduler.start()
2492
 
2493
  # import control
 
2506
  demo.block_thread()
2507
 
2508
 
 
 
 
2509
  def get_inputs_list(inputs_dict, model_lower, model_id=1):
2510
  """
2511
  map gradio objects in locals() to inputs for evaluate().
 
2539
  return inputs_list, inputs_dict_out
2540
 
2541
 
2542
+ def get_sources(db1s, langchain_mode, dbs=None, docs_state0=None):
2543
+ for k in db1s:
2544
+ set_userid(db1s[k])
2545
 
2546
  if langchain_mode in ['ChatLLM', 'LLM']:
2547
  source_files_added = "NA"
 
2550
  source_files_added = "Not showing wiki_full, takes about 20 seconds and makes 4MB file." \
2551
  " Ask [email protected] for file if required."
2552
  source_list = []
2553
+ elif langchain_mode in db1s and len(db1s[langchain_mode]) == 2 and db1s[langchain_mode][0] is not None:
2554
+ db1 = db1s[langchain_mode]
2555
  from gpt_langchain import get_metadatas
2556
  metadatas = get_metadatas(db1[0])
2557
  source_list = sorted(set([x['source'] for x in metadatas]))
 
2582
  db1[1] = str(uuid.uuid4())
2583
 
2584
 
2585
+ def update_user_db(file, db1s, selection_docs_state1, chunk, chunk_size, langchain_mode, dbs=None, **kwargs):
2586
+ kwargs.update(selection_docs_state1)
 
2587
  if file is None:
2588
  raise RuntimeError("Don't use change, use input")
2589
 
2590
  try:
2591
+ return _update_user_db(file, db1s=db1s, chunk=chunk, chunk_size=chunk_size,
2592
  langchain_mode=langchain_mode, dbs=dbs,
2593
  **kwargs)
2594
  except BaseException as e:
 
2619
  user_id = db1[1]
2620
  base_path = 'locks'
2621
  makedirs(base_path)
2622
+ lock_file = os.path.join(base_path, "db_%s_%s.lock" % (langchain_mode.replace(' ', '_'), user_id))
2623
  return lock_file
2624
 
2625
 
2626
  def _update_user_db(file,
2627
+ db1s=None,
2628
  chunk=None, chunk_size=None,
2629
+ dbs=None, db_type=None,
2630
+ langchain_mode='UserData',
2631
+ langchain_modes=None, # unused but required as part of selection_docs_state1
2632
+ langchain_mode_paths=None,
2633
+ visible_langchain_modes=None,
2634
  use_openai_embedding=None,
2635
  hf_embedding_model=None,
2636
  caption_loader=None,
2637
  enable_captions=None,
2638
  captions_model=None,
2639
  enable_ocr=None,
2640
+ enable_pdf_ocr=None,
2641
  verbose=None,
2642
+ n_jobs=-1,
2643
  is_url=None, is_txt=None,
2644
+ ):
2645
+ assert db1s is not None
2646
  assert chunk is not None
2647
  assert chunk_size is not None
2648
  assert use_openai_embedding is not None
 
2651
  assert enable_captions is not None
2652
  assert captions_model is not None
2653
  assert enable_ocr is not None
2654
+ assert enable_pdf_ocr is not None
2655
  assert verbose is not None
2656
 
 
 
2657
  if dbs is None:
2658
  dbs = {}
2659
  assert isinstance(dbs, dict), "Wrong type for dbs: %s" % str(type(dbs))
 
2671
  if langchain_mode == LangChainMode.DISABLED.value:
2672
  return None, langchain_mode, get_source_files(), ""
2673
 
2674
+ if langchain_mode in [LangChainMode.LLM.value]:
2675
  # then switch to MyData, so langchain_mode also becomes way to select where upload goes
2676
  # but default to mydata if nothing chosen, since safest
2677
+ if LangChainMode.MY_DATA.value in visible_langchain_modes:
2678
+ langchain_mode = LangChainMode.MY_DATA.value
2679
+
2680
+ if langchain_mode_paths is None:
2681
+ langchain_mode_paths = {}
2682
+ user_path = langchain_mode_paths.get(langchain_mode)
2683
+ # UserData or custom, which has to be from user's disk
2684
+ if user_path is not None:
2685
  # move temp files from gradio upload to stable location
2686
  for fili, fil in enumerate(file):
2687
+ if isinstance(fil, str) and os.path.isfile(fil): # not url, text
2688
+ new_fil = os.path.normpath(os.path.join(user_path, os.path.basename(fil)))
2689
+ if os.path.normpath(os.path.abspath(fil)) != os.path.normpath(os.path.abspath(new_fil)):
2690
  if os.path.isfile(new_fil):
2691
  remove(new_fil)
2692
  try:
 
2706
  enable_captions=enable_captions,
2707
  captions_model=captions_model,
2708
  enable_ocr=enable_ocr,
2709
+ enable_pdf_ocr=enable_pdf_ocr,
2710
  caption_loader=caption_loader,
2711
  )
2712
  exceptions = [x for x in sources if x.metadata.get('exception')]
2713
  exceptions_strs = [x.metadata['exception'] for x in exceptions]
2714
  sources = [x for x in sources if 'exception' not in x.metadata]
2715
 
2716
+ # below must at least come after langchain_mode is modified in case was LLM -> MyData,
2717
+ # so original langchain mode changed
2718
+ for k in db1s:
2719
+ set_userid(db1s[k])
2720
+ db1 = get_db1(db1s, langchain_mode)
2721
+
2722
+ lock_file = get_lock_file(db1s[LangChainMode.MY_DATA.value], langchain_mode) # user-level lock, not db-level lock
2723
  with filelock.FileLock(lock_file):
2724
+ if langchain_mode in db1s:
2725
  if db1[0] is not None:
2726
  # then add
2727
  db, num_new_sources, new_sources_metadata = add_to_db(db1[0], sources, db_type=db_type,
 
2731
  # in testing expect:
2732
  # assert len(db1) == 2 and db1[1] is None, "Bad MyData db: %s" % db1
2733
  # for production hit, when user gets clicky:
2734
+ assert len(db1) == 2, "Bad %s db: %s" % (langchain_mode, db1)
2735
+ assert db1[1] is not None, "db hash was None, not allowed"
2736
  # then create
2737
  # if added has to original state and didn't change, then would be shared db for all users
2738
  persist_directory = os.path.join(scratch_base_dir, 'db_dir_%s_%s' % (langchain_mode, db1[1]))
 
2754
  use_openai_embedding=use_openai_embedding,
2755
  hf_embedding_model=hf_embedding_model)
2756
  else:
2757
+ # then create. Or might just be that dbs is unfilled, then it will fill, then add
2758
  db = get_db(sources, use_openai_embedding=use_openai_embedding,
2759
  db_type=db_type,
2760
  persist_directory=persist_directory,
 
2768
  return None, langchain_mode, source_files_added, '\n'.join(exceptions_strs)
2769
 
2770
 
2771
+ def get_db(db1s, langchain_mode, dbs=None):
2772
+ db1 = get_db1(db1s, langchain_mode)
2773
+ lock_file = get_lock_file(db1s[LangChainMode.MY_DATA.value], langchain_mode)
2774
 
2775
  with filelock.FileLock(lock_file):
2776
  if langchain_mode in ['wiki_full']:
2777
  # NOTE: avoid showing full wiki. Takes about 30 seconds over about 90k entries, but not useful for now
2778
  db = None
2779
+ elif langchain_mode in db1s and len(db1) == 2 and db1[0] is not None:
2780
  db = db1[0]
2781
  elif dbs is not None and langchain_mode in dbs and dbs[langchain_mode] is not None:
2782
  db = dbs[langchain_mode]
 
2785
  return db
2786
 
2787
 
2788
+ def get_source_files_given_langchain_mode(db1s, langchain_mode='UserData', dbs=None):
2789
+ db = get_db(db1s, langchain_mode, dbs=dbs)
2790
  if langchain_mode in ['ChatLLM', 'LLM'] or db is None:
2791
  return "Sources: N/A"
2792
  return get_source_files(db=db, exceptions=None)
 
2885
  return source_files_added
2886
 
2887
 
2888
+ def update_and_get_source_files_given_langchain_mode(db1s, langchain_mode, chunk, chunk_size,
2889
+ dbs=None, first_para=None,
2890
+ text_limit=None,
2891
+ langchain_mode_paths=None, db_type=None, load_db_if_exists=None,
2892
  n_jobs=None, verbose=None):
2893
+ has_path = {k: v for k, v in langchain_mode_paths.items() if v}
2894
+ if langchain_mode in [LangChainMode.LLM.value, LangChainMode.MY_DATA.value]:
2895
+ # then assume user really meant UserData, to avoid extra clicks in UI,
2896
+ # since others can't be on disk, except custom user modes, which they should then select to query it
2897
+ if LangChainMode.USER_DATA.value in has_path:
2898
+ langchain_mode = LangChainMode.USER_DATA.value
2899
+
2900
+ db = get_db(db1s, langchain_mode, dbs=dbs)
2901
 
2902
  from gpt_langchain import make_db
2903
  db, num_new_sources, new_sources_metadata = make_db(use_openai_embedding=False,
 
2906
  chunk=chunk,
2907
  chunk_size=chunk_size,
2908
  langchain_mode=langchain_mode,
2909
+ langchain_mode_paths=langchain_mode_paths,
2910
  db_type=db_type,
2911
  load_db_if_exists=load_db_if_exists,
2912
  db=db,
2913
  n_jobs=n_jobs,
2914
  verbose=verbose)
2915
+ # during refreshing, might have "created" new db since not in dbs[] yet, so insert back just in case
2916
+ # so even if persisted, not kept up-to-date with dbs memory
2917
+ if langchain_mode in db1s:
2918
+ db1s[langchain_mode][0] = db
2919
+ else:
2920
+ dbs[langchain_mode] = db
2921
+
2922
  # return only new sources with text saying such
2923
  return get_source_files(db=None, exceptions=None, metadatas=new_sources_metadata)
2924
+
2925
+
2926
+ def get_db1(db1s, langchain_mode1):
2927
+ if langchain_mode1 in db1s:
2928
+ db1 = db1s[langchain_mode1]
2929
+ else:
2930
+ # indicates to code that not scratch database
2931
+ db1 = [None, None]
2932
+ return db1
gradio_utils/__init__.py ADDED
File without changes
gradio_utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (134 Bytes). View file
 
gradio_utils/__pycache__/css.cpython-310.pyc CHANGED
Binary files a/gradio_utils/__pycache__/css.cpython-310.pyc and b/gradio_utils/__pycache__/css.cpython-310.pyc differ
 
gradio_utils/css.py CHANGED
@@ -53,4 +53,8 @@ def make_css_base() -> str:
53
  margin-bottom: 2.5rem;
54
  }
55
  .chatsmall chatbot {font-size: 10px !important}
 
 
 
 
56
  """
 
53
  margin-bottom: 2.5rem;
54
  }
55
  .chatsmall chatbot {font-size: 10px !important}
56
+
57
+ .gradio-container {
58
+ max-width: none !important;
59
+ }
60
  """
h2oai_pipeline.py CHANGED
@@ -11,6 +11,7 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
11
  def __init__(self, *args, debug=False, chat=False, stream_output=False,
12
  sanitize_bot_response=False,
13
  use_prompter=True, prompter=None,
 
14
  prompt_type=None, prompt_dict=None,
15
  max_input_tokens=2048 - 256, **kwargs):
16
  """
@@ -34,6 +35,8 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
34
  self.prompt_type = prompt_type
35
  self.prompt_dict = prompt_dict
36
  self.prompter = prompter
 
 
37
  if self.use_prompter:
38
  if self.prompter is not None:
39
  assert self.prompter.prompt_type is not None
@@ -113,7 +116,7 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
113
  def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs):
114
  prompt_text, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt_text, self.tokenizer)
115
 
116
- data_point = dict(context='', instruction=prompt_text, input='')
117
  if self.prompter is not None:
118
  prompt_text = self.prompter.generate_prompt(data_point)
119
  self.prompt_text = prompt_text
 
11
  def __init__(self, *args, debug=False, chat=False, stream_output=False,
12
  sanitize_bot_response=False,
13
  use_prompter=True, prompter=None,
14
+ context='', iinput='',
15
  prompt_type=None, prompt_dict=None,
16
  max_input_tokens=2048 - 256, **kwargs):
17
  """
 
35
  self.prompt_type = prompt_type
36
  self.prompt_dict = prompt_dict
37
  self.prompter = prompter
38
+ self.context = context
39
+ self.iinput = iinput
40
  if self.use_prompter:
41
  if self.prompter is not None:
42
  assert self.prompter.prompt_type is not None
 
116
  def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs):
117
  prompt_text, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt_text, self.tokenizer)
118
 
119
+ data_point = dict(context=self.context, instruction=prompt_text, input=self.iinput)
120
  if self.prompter is not None:
121
  prompt_text = self.prompter.generate_prompt(data_point)
122
  self.prompt_text = prompt_text
iterators/__pycache__/timeout_iterator.cpython-310.pyc CHANGED
Binary files a/iterators/__pycache__/timeout_iterator.cpython-310.pyc and b/iterators/__pycache__/timeout_iterator.cpython-310.pyc differ
 
iterators/timeout_iterator.py CHANGED
@@ -48,7 +48,7 @@ class TimeoutIterator:
48
  def interrupt(self):
49
  """
50
  interrupt and stop the underlying thread.
51
- the thread acutally dies only after interrupt has been set and
52
  the underlying iterator yields a value after that.
53
  """
54
  self._interrupt = True
 
48
  def interrupt(self):
49
  """
50
  interrupt and stop the underlying thread.
51
+ the thread actually dies only after interrupt has been set and
52
  the underlying iterator yields a value after that.
53
  """
54
  self._interrupt = True
prompter.py CHANGED
@@ -77,6 +77,12 @@ prompt_type_to_model_name = {
77
  "mptchat": ['mosaicml/mpt-7b-chat', 'mosaicml/mpt-30b-chat', 'TheBloke/mpt-30B-chat-GGML'],
78
  "vicuna11": ['lmsys/vicuna-33b-v1.3'],
79
  "falcon": ['tiiuae/falcon-40b-instruct', 'tiiuae/falcon-40b', 'tiiuae/falcon-7b-instruct', 'tiiuae/falcon-7b'],
 
 
 
 
 
 
80
  # could be plain, but default is correct prompt_type for default TheBloke model ggml-wizardLM-7B.q4_2.bin
81
  }
82
  if os.getenv('OPENAI_API_KEY'):
@@ -596,6 +602,28 @@ ASSISTANT:
596
  chat_turn_sep = chat_sep = '\n'
597
  humanstr = PreInstruct
598
  botstr = PreResponse
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
599
  else:
600
  raise RuntimeError("No such prompt_type=%s" % prompt_type)
601
 
 
77
  "mptchat": ['mosaicml/mpt-7b-chat', 'mosaicml/mpt-30b-chat', 'TheBloke/mpt-30B-chat-GGML'],
78
  "vicuna11": ['lmsys/vicuna-33b-v1.3'],
79
  "falcon": ['tiiuae/falcon-40b-instruct', 'tiiuae/falcon-40b', 'tiiuae/falcon-7b-instruct', 'tiiuae/falcon-7b'],
80
+ "llama2": [
81
+ 'meta-llama/Llama-2-7b-chat-hf',
82
+ 'meta-llama/Llama-2-13b-chat-hf',
83
+ 'meta-llama/Llama-2-34b-chat-hf',
84
+ 'meta-llama/Llama-2-70b-chat-hf',
85
+ ],
86
  # could be plain, but default is correct prompt_type for default TheBloke model ggml-wizardLM-7B.q4_2.bin
87
  }
88
  if os.getenv('OPENAI_API_KEY'):
 
602
  chat_turn_sep = chat_sep = '\n'
603
  humanstr = PreInstruct
604
  botstr = PreResponse
605
+ elif prompt_type in [PromptType.llama2.value, str(PromptType.llama2.value),
606
+ PromptType.llama2.name]:
607
+ PreInstruct = ""
608
+ llama2_sys = "<<SYS>>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n"
609
+ prompt = "<s>[INST] "
610
+ enable_sys = False # too much safety, hurts accuracy
611
+ if not (chat and reduced):
612
+ if enable_sys:
613
+ promptA = promptB = prompt + llama2_sys
614
+ else:
615
+ promptA = promptB = prompt
616
+ else:
617
+ promptA = promptB = ''
618
+ PreInput = None
619
+ PreResponse = ""
620
+ terminate_response = ["[INST]", "</s>"]
621
+ chat_sep = ' [/INST]'
622
+ chat_turn_sep = ' </s><s>[INST] '
623
+ humanstr = PreInstruct
624
+ botstr = PreResponse
625
+ if making_context:
626
+ PreResponse += " "
627
  else:
628
  raise RuntimeError("No such prompt_type=%s" % prompt_type)
629
 
requirements.txt CHANGED
@@ -1,8 +1,8 @@
1
  # for generate (gradio server) and finetune
2
  datasets==2.13.0
3
  sentencepiece==0.1.99
4
- gradio==3.35.2
5
- huggingface_hub==0.15.1
6
  appdirs==1.4.4
7
  fire==0.5.0
8
  docutils==0.20.1
@@ -19,7 +19,7 @@ matplotlib==3.7.1
19
  loralib==0.1.1
20
  bitsandbytes==0.39.0
21
  accelerate==0.20.3
22
- git+https://github.com/huggingface/peft.git@06fd06a4d2e8ed8c3a253c67d9c3cb23e0f497ad
23
  transformers==4.30.2
24
  tokenizers==0.13.3
25
  APScheduler==3.10.1
@@ -35,7 +35,7 @@ tensorboard==2.13.0
35
  neptune==1.2.0
36
 
37
  # for gradio client
38
- gradio_client==0.2.7
39
  beautifulsoup4==4.12.2
40
  markdown==3.4.3
41
 
@@ -64,8 +64,8 @@ tiktoken==0.4.0
64
  # optional: for OpenAI endpoint or embeddings (requires key)
65
  openai==0.27.8
66
  # optional for chat with PDF
67
- langchain==0.0.202
68
- pypdf==3.9.1
69
  # avoid textract, requires old six
70
  #textract==1.6.5
71
 
@@ -78,10 +78,10 @@ chromadb==0.3.25
78
  #pymilvus==2.2.8
79
 
80
  # weak url support, if can't install opencv etc. If comment-in this one, then comment-out unstructured[local-inference]==0.6.6
81
- # unstructured==0.6.6
82
 
83
  # strong support for images
84
- # Requires on Ubuntu: sudo apt-get install libmagic-dev poppler-utils tesseract-ocr libreoffice
85
  unstructured[local-inference]==0.7.4
86
  #pdf2image==1.16.3
87
  #pytesseract==0.3.10
@@ -104,10 +104,10 @@ tabulate==0.9.0
104
  pip-licenses==4.3.0
105
 
106
  # weaviate vector db
107
- weaviate-client==3.20.0
108
  # optional for chat with PDF
109
- langchain==0.0.202
110
- pypdf==3.9.1
111
  # avoid textract, requires old six
112
  #textract==1.6.5
113
 
@@ -120,10 +120,10 @@ chromadb==0.3.25
120
  #pymilvus==2.2.8
121
 
122
  # weak url support, if can't install opencv etc. If comment-in this one, then comment-out unstructured[local-inference]==0.6.6
123
- # unstructured==0.6.6
124
 
125
  # strong support for images
126
- # Requires on Ubuntu: sudo apt-get install libmagic-dev poppler-utils tesseract-ocr libreoffice
127
  unstructured[local-inference]==0.7.4
128
  #pdf2image==1.16.3
129
  #pytesseract==0.3.10
@@ -146,8 +146,8 @@ tabulate==0.9.0
146
  pip-licenses==4.3.0
147
 
148
  # weaviate vector db
149
- weaviate-client==3.20.0
150
  faiss-gpu==1.7.2
151
- arxiv==1.4.7
152
- pymupdf==1.22.3 # AGPL license
153
  # extract-msg==0.41.1 # GPL3
 
1
  # for generate (gradio server) and finetune
2
  datasets==2.13.0
3
  sentencepiece==0.1.99
4
+ gradio==3.37.0
5
+ huggingface_hub==0.16.4
6
  appdirs==1.4.4
7
  fire==0.5.0
8
  docutils==0.20.1
 
19
  loralib==0.1.1
20
  bitsandbytes==0.39.0
21
  accelerate==0.20.3
22
+ peft==0.4.0
23
  transformers==4.30.2
24
  tokenizers==0.13.3
25
  APScheduler==3.10.1
 
35
  neptune==1.2.0
36
 
37
  # for gradio client
38
+ gradio_client==0.2.10
39
  beautifulsoup4==4.12.2
40
  markdown==3.4.3
41
 
 
64
  # optional: for OpenAI endpoint or embeddings (requires key)
65
  openai==0.27.8
66
  # optional for chat with PDF
67
+ langchain==0.0.235
68
+ pypdf==3.12.2
69
  # avoid textract, requires old six
70
  #textract==1.6.5
71
 
 
78
  #pymilvus==2.2.8
79
 
80
  # weak url support, if can't install opencv etc. If comment-in this one, then comment-out unstructured[local-inference]==0.6.6
81
+ # unstructured==0.8.1
82
 
83
  # strong support for images
84
+ # Requires on Ubuntu: sudo apt-get install libmagic-dev poppler-utils tesseract-ocr libtesseract-dev libreoffice
85
  unstructured[local-inference]==0.7.4
86
  #pdf2image==1.16.3
87
  #pytesseract==0.3.10
 
104
  pip-licenses==4.3.0
105
 
106
  # weaviate vector db
107
+ weaviate-client==3.22.1
108
  # optional for chat with PDF
109
+ langchain==0.0.235
110
+ pypdf==3.12.2
111
  # avoid textract, requires old six
112
  #textract==1.6.5
113
 
 
120
  #pymilvus==2.2.8
121
 
122
  # weak url support, if can't install opencv etc. If comment-in this one, then comment-out unstructured[local-inference]==0.6.6
123
+ # unstructured==0.8.1
124
 
125
  # strong support for images
126
+ # Requires on Ubuntu: sudo apt-get install libmagic-dev poppler-utils tesseract-ocr libtesseract-dev libreoffice
127
  unstructured[local-inference]==0.7.4
128
  #pdf2image==1.16.3
129
  #pytesseract==0.3.10
 
146
  pip-licenses==4.3.0
147
 
148
  # weaviate vector db
149
+ weaviate-client==3.22.1
150
  faiss-gpu==1.7.2
151
+ arxiv==1.4.8
152
+ pymupdf==1.22.5 # AGPL license
153
  # extract-msg==0.41.1 # GPL3
utils.py CHANGED
@@ -5,6 +5,7 @@ import inspect
5
  import os
6
  import gc
7
  import pathlib
 
8
  import random
9
  import shutil
10
  import subprocess
@@ -111,12 +112,15 @@ def system_info():
111
  system = {}
112
  # https://stackoverflow.com/questions/48951136/plot-multiple-graphs-in-one-plot-using-tensorboard
113
  # https://arshren.medium.com/monitoring-your-devices-in-python-5191d672f749
114
- temps = psutil.sensors_temperatures(fahrenheit=False)
115
- if 'coretemp' in temps:
116
- coretemp = temps['coretemp']
117
- temp_dict = {k.label: k.current for k in coretemp}
118
- for k, v in temp_dict.items():
119
- system['CPU_C/%s' % k] = v
 
 
 
120
 
121
  # https://github.com/gpuopenanalytics/pynvml/blob/master/help_query_gpu.txt
122
  try:
@@ -779,6 +783,9 @@ def _traced_func(func, *args, **kwargs):
779
 
780
 
781
  def call_subprocess_onetask(func, args=None, kwargs=None):
 
 
 
782
  if isinstance(args, list):
783
  args = tuple(args)
784
  if args is None:
@@ -1001,3 +1008,73 @@ def set_openai(inference_server):
1001
  openai.api_base = os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1")
1002
  inf_type = inference_server
1003
  return openai, inf_type
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  import os
6
  import gc
7
  import pathlib
8
+ import pickle
9
  import random
10
  import shutil
11
  import subprocess
 
112
  system = {}
113
  # https://stackoverflow.com/questions/48951136/plot-multiple-graphs-in-one-plot-using-tensorboard
114
  # https://arshren.medium.com/monitoring-your-devices-in-python-5191d672f749
115
+ try:
116
+ temps = psutil.sensors_temperatures(fahrenheit=False)
117
+ if 'coretemp' in temps:
118
+ coretemp = temps['coretemp']
119
+ temp_dict = {k.label: k.current for k in coretemp}
120
+ for k, v in temp_dict.items():
121
+ system['CPU_C/%s' % k] = v
122
+ except AttributeError:
123
+ pass
124
 
125
  # https://github.com/gpuopenanalytics/pynvml/blob/master/help_query_gpu.txt
126
  try:
 
783
 
784
 
785
  def call_subprocess_onetask(func, args=None, kwargs=None):
786
+ import platform
787
+ if platform.system() in ['Darwin', 'Windows']:
788
+ return func(*args, **kwargs)
789
  if isinstance(args, list):
790
  args = tuple(args)
791
  if args is None:
 
1008
  openai.api_base = os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1")
1009
  inf_type = inference_server
1010
  return openai, inf_type
1011
+
1012
+
1013
+ visible_langchain_modes_file = 'visible_langchain_modes.pkl'
1014
+
1015
+
1016
+ def save_collection_names(langchain_modes, visible_langchain_modes, langchain_mode_paths, LangChainMode, db1s):
1017
+ """
1018
+ extra controls if UserData type of MyData type
1019
+ """
1020
+
1021
+ # use first default MyData hash as general user hash to maintain file
1022
+ # if user moves MyData from langchain modes, db will still survive, so can still use hash
1023
+ scratch_collection_names = list(db1s.keys())
1024
+ user_hash = db1s.get(LangChainMode.MY_DATA.value, '')[1]
1025
+
1026
+ llms = ['ChatLLM', 'LLM', 'Disabled']
1027
+
1028
+ scratch_langchain_modes = [x for x in langchain_modes if x in scratch_collection_names]
1029
+ scratch_visible_langchain_modes = [x for x in visible_langchain_modes if x in scratch_collection_names]
1030
+ scratch_langchain_mode_paths = {k: v for k, v in langchain_mode_paths.items() if
1031
+ k in scratch_collection_names and k not in llms}
1032
+
1033
+ user_langchain_modes = [x for x in langchain_modes if x not in scratch_collection_names]
1034
+ user_visible_langchain_modes = [x for x in visible_langchain_modes if x not in scratch_collection_names]
1035
+ user_langchain_mode_paths = {k: v for k, v in langchain_mode_paths.items() if
1036
+ k not in scratch_collection_names and k not in llms}
1037
+
1038
+ base_path = 'locks'
1039
+ makedirs(base_path)
1040
+
1041
+ # user
1042
+ extra = ''
1043
+ file = "%s%s" % (visible_langchain_modes_file, extra)
1044
+ with filelock.FileLock(os.path.join(base_path, "%s.lock" % file)):
1045
+ with open(file, 'wb') as f:
1046
+ pickle.dump((user_langchain_modes, user_visible_langchain_modes, user_langchain_mode_paths), f)
1047
+
1048
+ # scratch
1049
+ extra = user_hash
1050
+ file = "%s%s" % (visible_langchain_modes_file, extra)
1051
+ with filelock.FileLock(os.path.join(base_path, "%s.lock" % file)):
1052
+ with open(file, 'wb') as f:
1053
+ pickle.dump((scratch_langchain_modes, scratch_visible_langchain_modes, scratch_langchain_mode_paths), f)
1054
+
1055
+
1056
+ def load_collection_enum(extra):
1057
+ """
1058
+ extra controls if UserData type of MyData type
1059
+ """
1060
+ file = "%s%s" % (visible_langchain_modes_file, extra)
1061
+ langchain_modes_from_file = []
1062
+ visible_langchain_modes_from_file = []
1063
+ langchain_mode_paths_from_file = {}
1064
+ if os.path.isfile(visible_langchain_modes_file):
1065
+ try:
1066
+ with filelock.FileLock("%s.lock" % file):
1067
+ with open(file, 'rb') as f:
1068
+ langchain_modes_from_file, visible_langchain_modes_from_file, langchain_mode_paths_from_file = pickle.load(
1069
+ f)
1070
+ except BaseException as e:
1071
+ print("Cannot load %s, ignoring error: %s" % (file, str(e)), flush=True)
1072
+ for k, v in langchain_mode_paths_from_file.items():
1073
+ if v is not None and not os.path.isdir(v) and isinstance(v, str):
1074
+ # assume was deleted, but need to make again to avoid extra code elsewhere
1075
+ makedirs(v)
1076
+ return langchain_modes_from_file, visible_langchain_modes_from_file, langchain_mode_paths_from_file
1077
+
1078
+
1079
+ def remove_collection_enum():
1080
+ remove(visible_langchain_modes_file)