pseudotensor commited on
Commit
d1a52ba
1 Parent(s): 06c87e6

Update with h2oGPT hash 2391b7887b928de3239af955d0263a66fb8a9710

Browse files
Files changed (13) hide show
  1. client_test.py +81 -52
  2. create_data.py +1818 -0
  3. finetune.py +4 -378
  4. generate.py +302 -110
  5. gpt4all_llm.py +119 -0
  6. gpt_langchain.py +1076 -0
  7. gradio_runner.py +634 -46
  8. gradio_themes.py +3 -1
  9. h2oai_pipeline.py +54 -0
  10. loaders.py +50 -0
  11. prompter.py +370 -1
  12. requirements.txt +51 -3
  13. utils.py +477 -6
client_test.py CHANGED
@@ -36,84 +36,113 @@ Loaded as API: https://gpt.h2o.ai ✔
36
  {'instruction_nochat': 'Who are you?', 'iinput_nochat': '', 'response': 'I am h2oGPT, a chatbot created by LAION.'}
37
 
38
  """
 
 
 
 
39
 
40
  debug = False
41
 
42
- import os
43
  os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
44
 
45
 
46
- def get_client():
47
  from gradio_client import Client
48
 
49
- client = Client(os.getenv('HOST', "http://localhost:7860"))
50
  if debug:
51
  print(client.view_api(all_endpoints=True))
52
  return client
53
 
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  def test_client_basic():
56
- return run_client_basic(instruction_nochat='Who are you?', prompt_type='human_bot')
57
-
58
-
59
- def run_client_basic(instruction_nochat, prompt_type):
60
- instruction = '' # only for chat=True
61
- iinput = '' # only for chat=True
62
- context = ''
63
- # streaming output is supported, loops over and outputs each generation in streaming mode
64
- # but leave stream_output=False for simple input/output mode
65
- stream_output = False
66
- temperature = 0.1
67
- top_p = 0.75
68
- top_k = 40
69
- num_beams = 1
70
- max_new_tokens = 50
71
- min_new_tokens = 0
72
- early_stopping = False
73
- max_time = 20
74
- repetition_penalty = 1.0
75
- num_return_sequences = 1
76
- do_sample = True
77
- # only these 2 below used if pass chat=False
78
- chat = False
79
- iinput_nochat = ''
80
-
81
- args = [instruction,
82
- iinput,
83
- context,
84
- stream_output,
85
- prompt_type,
86
- temperature,
87
- top_p,
88
- top_k,
89
- num_beams,
90
- max_new_tokens,
91
- min_new_tokens,
92
- early_stopping,
93
- max_time,
94
- repetition_penalty,
95
- num_return_sequences,
96
- do_sample,
97
- chat,
98
- instruction_nochat,
99
- iinput_nochat,
100
- ]
101
  api_name = '/submit_nochat'
102
- client = get_client()
103
  res = client.predict(
104
  *tuple(args),
105
  api_name=api_name,
106
  )
107
- res_dict = dict(instruction_nochat=instruction_nochat, iinput_nochat=iinput_nochat, response=md_to_text(res))
 
108
  print(res_dict)
109
  return res_dict
110
 
111
 
112
- import markdown # pip install markdown
113
- from bs4 import BeautifulSoup # pip install beautifulsoup4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
 
116
  def md_to_text(md):
 
117
  html = markdown.markdown(md)
118
  soup = BeautifulSoup(html, features='html.parser')
119
  return soup.get_text()
 
36
  {'instruction_nochat': 'Who are you?', 'iinput_nochat': '', 'response': 'I am h2oGPT, a chatbot created by LAION.'}
37
 
38
  """
39
+ import time
40
+ import os
41
+ import markdown # pip install markdown
42
+ from bs4 import BeautifulSoup # pip install beautifulsoup4
43
 
44
  debug = False
45
 
 
46
  os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
47
 
48
 
49
+ def get_client(serialize=True):
50
  from gradio_client import Client
51
 
52
+ client = Client(os.getenv('HOST', "http://localhost:7860"), serialize=serialize)
53
  if debug:
54
  print(client.view_api(all_endpoints=True))
55
  return client
56
 
57
 
58
+ def get_args(prompt, prompt_type, chat=False, stream_output=False, max_new_tokens=50):
59
+ from collections import OrderedDict
60
+ kwargs = OrderedDict(instruction=prompt if chat else '', # only for chat=True
61
+ iinput='', # only for chat=True
62
+ context='',
63
+ # streaming output is supported, loops over and outputs each generation in streaming mode
64
+ # but leave stream_output=False for simple input/output mode
65
+ stream_output=stream_output,
66
+ prompt_type=prompt_type,
67
+ temperature=0.1,
68
+ top_p=0.75,
69
+ top_k=40,
70
+ num_beams=1,
71
+ max_new_tokens=max_new_tokens,
72
+ min_new_tokens=0,
73
+ early_stopping=False,
74
+ max_time=20,
75
+ repetition_penalty=1.0,
76
+ num_return_sequences=1,
77
+ do_sample=True,
78
+ chat=chat,
79
+ instruction_nochat=prompt if not chat else '',
80
+ iinput_nochat='', # only for chat=False
81
+ langchain_mode='Disabled',
82
+ )
83
+ if chat:
84
+ # add chatbot output on end. Assumes serialize=False
85
+ kwargs.update(dict(chatbot=[['', None]]))
86
+
87
+ return kwargs, list(kwargs.values())
88
+
89
+
90
  def test_client_basic():
91
+ return run_client_nochat(prompt='Who are you?', prompt_type='human_bot', max_new_tokens=50)
92
+
93
+
94
+ def run_client_nochat(prompt, prompt_type, max_new_tokens):
95
+ kwargs, args = get_args(prompt, prompt_type, chat=False, max_new_tokens=max_new_tokens)
96
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  api_name = '/submit_nochat'
98
+ client = get_client(serialize=True)
99
  res = client.predict(
100
  *tuple(args),
101
  api_name=api_name,
102
  )
103
+ res_dict = dict(prompt=kwargs['instruction_nochat'], iinput=kwargs['iinput_nochat'],
104
+ response=md_to_text(res))
105
  print(res_dict)
106
  return res_dict
107
 
108
 
109
+ def test_client_chat():
110
+ return run_client_chat(prompt='Who are you?', prompt_type='human_bot', stream_output=False, max_new_tokens=50)
111
+
112
+
113
+ def run_client_chat(prompt, prompt_type, stream_output, max_new_tokens):
114
+ kwargs, args = get_args(prompt, prompt_type, chat=True, stream_output=stream_output, max_new_tokens=max_new_tokens)
115
+
116
+ client = get_client(serialize=False)
117
+
118
+ res = client.predict(*tuple(args), api_name='/instruction')
119
+ args[-1] += [res[-1]]
120
+
121
+ res_dict = kwargs
122
+ res_dict['prompt'] = prompt
123
+ if not kwargs['stream_output']:
124
+ res = client.predict(*tuple(args), api_name='/instruction_bot')
125
+ res_dict['response'] = res[0][-1][1]
126
+ print(md_to_text(res_dict['response']))
127
+ return res_dict
128
+ else:
129
+ job = client.submit(*tuple(args), api_name='/instruction_bot')
130
+ res1 = ''
131
+ while not job.done():
132
+ outputs_list = job.communicator.job.outputs
133
+ if outputs_list:
134
+ res = job.communicator.job.outputs[-1]
135
+ res1 = res[0][-1][-1]
136
+ res1 = md_to_text(res1)
137
+ print(res1)
138
+ time.sleep(0.1)
139
+ print(job.outputs())
140
+ res_dict['response'] = res1
141
+ return res_dict
142
 
143
 
144
  def md_to_text(md):
145
+ assert md is not None, "Markdown is None"
146
  html = markdown.markdown(md)
147
  soup = BeautifulSoup(html, features='html.parser')
148
  return soup.get_text()
create_data.py ADDED
@@ -0,0 +1,1818 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Dataset creation tools.
3
+
4
+ Keep to-level imports clean of non-trivial imports for specific tools,
5
+ because this file is imported for various purposes
6
+ """
7
+
8
+ import ast
9
+ import concurrent.futures
10
+ import contextlib
11
+ import hashlib
12
+ import json
13
+ import os
14
+ import shutil
15
+ import signal
16
+ import sys
17
+ import traceback
18
+ from concurrent.futures import ProcessPoolExecutor
19
+
20
+ import psutil
21
+ import pytest
22
+ import pandas as pd
23
+ import numpy as np
24
+ from tqdm import tqdm
25
+
26
+ from utils import flatten_list
27
+
28
+
29
+ def parse_rst_file(filepath):
30
+ with open(filepath, 'r') as f:
31
+ input_data = f.read()
32
+ settings_overrides = {'initial_header_level': 2}
33
+ from docutils import core
34
+ document = core.publish_doctree(
35
+ source=input_data,
36
+ source_path=filepath,
37
+ settings_overrides=settings_overrides,
38
+ )
39
+ qa_pairs = []
40
+ current_section = None
41
+ current_question = ""
42
+ current_answer = ""
43
+ for node in document.traverse():
44
+ if node.__class__.__name__ == 'section':
45
+ current_section = ""
46
+ elif current_section is not None:
47
+ if node.__class__.__name__ == 'Text':
48
+ if node.astext()[-1] == "?":
49
+ if current_question:
50
+ qa_pairs.append((current_question, current_answer))
51
+ current_question = node.astext()
52
+ current_answer = ""
53
+ else:
54
+ current_answer += node.astext()
55
+ if current_answer:
56
+ qa_pairs.append((current_question, current_answer))
57
+ return {k: v for k, v in qa_pairs}
58
+
59
+
60
+ def test_scrape_dai_docs():
61
+ home = os.path.expanduser('~')
62
+ file = os.path.join(home, 'h2oai/docs/faq.rst')
63
+ qa_pairs = parse_rst_file(file)
64
+ prompt_type = 'human_bot'
65
+ from prompter import prompt_types
66
+ assert prompt_type in prompt_types
67
+ save_thing = [{"instruction": k, "output": v, 'prompt_type': prompt_type} for k, v in qa_pairs.items()]
68
+ output_file = "dai_faq.json"
69
+ with open(output_file, "wt") as f:
70
+ f.write(json.dumps(save_thing, indent=2))
71
+
72
+
73
+ def test_scrape_dai_docs_all():
74
+ """
75
+ pytest create_data.py::test_scrape_dai_docs_all
76
+ """
77
+ import glob
78
+ import nltk
79
+ nltk.download('punkt')
80
+ dd = {}
81
+ np.random.seed(1234)
82
+ home = os.path.expanduser('~')
83
+ files = list(glob.glob(os.path.join(home, "h2oai/docs/**/*rst")))
84
+ np.random.shuffle(files)
85
+ val_count = int(0.05 * len(files))
86
+ train_files = files[val_count:]
87
+ valid_files = files[:val_count]
88
+ things = [
89
+ ("dai_docs.train.json", train_files),
90
+ ("dai_docs.valid.json", valid_files)
91
+ ]
92
+ for LEN in [100, 200, 500]:
93
+ for output_file, ff in things:
94
+ if output_file not in dd:
95
+ dd[output_file] = []
96
+ for f in ff:
97
+ with open(f) as input:
98
+ blob = input.read()
99
+ blob = blob.replace("~~", "")
100
+ blob = blob.replace("==", "")
101
+ blob = blob.replace("''", "")
102
+ blob = blob.replace("--", "")
103
+ blob = blob.replace("**", "")
104
+ dd[output_file].extend(get_sentences(blob, length=LEN))
105
+ for output_file, _ in things:
106
+ save_thing = [{"output": k.strip(), 'prompt_type': 'plain'} for k in dd[output_file]]
107
+ with open(output_file, "wt") as f:
108
+ f.write(json.dumps(save_thing, indent=2))
109
+
110
+
111
+ def get_sentences(blob, length):
112
+ """
113
+ break-up input text into sentences and then output list of sentences of about length in size
114
+ :param blob:
115
+ :param length:
116
+ :return:
117
+ """
118
+ import nltk
119
+ nltk.download('punkt')
120
+ from nltk.tokenize import sent_tokenize
121
+ sentences = sent_tokenize(blob)
122
+ my_sentences = []
123
+ my_string = ""
124
+ for sentence in sentences:
125
+ if len(my_string) + len(sentence) <= length:
126
+ if my_string:
127
+ my_string += " " + sentence
128
+ else:
129
+ my_string = sentence
130
+ else:
131
+ my_sentences.append(my_string)
132
+ my_string = ""
133
+ return my_sentences or [my_string]
134
+
135
+
136
+ def setup_dai_docs(path=None, dst="working_dir_docs", from_hf=False):
137
+ """
138
+ Only supported if have access to source code or HF token for HF spaces and from_hf=True
139
+ :param path:
140
+ :param dst:
141
+ :param from_hf:
142
+ :return:
143
+ """
144
+
145
+ home = os.path.expanduser('~')
146
+
147
+ if from_hf:
148
+ # assumes
149
+ from huggingface_hub import hf_hub_download
150
+ # True for case when locally already logged in with correct token, so don't have to set key
151
+ token = os.getenv('HUGGINGFACE_API_TOKEN', True)
152
+ path_to_zip_file = hf_hub_download('h2oai/dai_docs', 'dai_docs.zip', token=token, repo_type='dataset')
153
+ path = 'h2oai'
154
+ import zipfile
155
+ with zipfile.ZipFile(path_to_zip_file, 'r') as zip_ref:
156
+ zip_ref.extractall(path)
157
+ path = os.path.join(path, 'docs/**/*')
158
+
159
+ if path is None:
160
+ if os.path.isdir(os.path.join(home, 'h2oai')):
161
+ path = os.path.join(home, "h2oai/docs/**/*")
162
+ else:
163
+ assert os.path.isdir(os.path.join(home, 'h2oai.superclean')), '%s does not exist' % path
164
+ path = os.path.join(home, "h2oai.superclean/docs/**/*")
165
+ import glob
166
+ files = list(glob.glob(path, recursive=True))
167
+
168
+ # pandoc can't find include files
169
+
170
+ remove(dst)
171
+ os.makedirs(dst)
172
+
173
+ # copy full tree, for absolute paths in rst
174
+ for fil in files:
175
+ if os.path.isfile(fil):
176
+ shutil.copy(fil, dst)
177
+
178
+ # hack for relative path
179
+ scorers_dir = os.path.join(dst, 'scorers')
180
+ makedirs(scorers_dir)
181
+ for fil in glob.glob(os.path.join(dst, '*.frag')):
182
+ shutil.copy(fil, scorers_dir)
183
+
184
+ return dst
185
+
186
+
187
+ def rst_to_outputs(files, min_len=30, max_len=2048//2 - 30):
188
+ # account for sequence length (context window) including prompt and input and output
189
+
190
+ # os.system('pandoc -f rst -t plain ./expert_settings/nlp_settings.rst')
191
+ import pypandoc
192
+ basedir = os.path.abspath(os.getcwd())
193
+
194
+ outputs = []
195
+ for fil in files:
196
+ os.chdir(basedir)
197
+ os.chdir(os.path.dirname(fil))
198
+ fil = os.path.basename(fil)
199
+ print("Processing %s" % fil, flush=True)
200
+ # out_format can be one of: asciidoc, asciidoctor, beamer, biblatex, bibtex, commonmark, commonmark_x,
201
+ # context, csljson, docbook, docbook4, docbook5, docx, dokuwiki,
202
+ # dzslides, epub, epub2, epub3, fb2, gfm, haddock, html, html4, html5, icml,
203
+ # ipynb, jats, jats_archiving, jats_articleauthoring, jats_publishing, jira,
204
+ # json, latex, man,
205
+ # markdown, markdown_github, markdown_mmd, markdown_phpextra, markdown_strict,
206
+ # mediawiki, ms, muse, native, odt, opendocument, opml, org, pdf, plain, pptx,
207
+ # revealjs, rst, rtf, s5, slideous, slidy, tei, texinfo, textile, xwiki, zimwiki
208
+ out_format = 'plain'
209
+ # avoid extra new lines injected into text
210
+ extra_args = ['--wrap=preserve', '--resource path="%s" % dst']
211
+
212
+ plain_list = []
213
+ try:
214
+ # valid for expert settings
215
+ input_rst = pypandoc.convert_file(fil, 'rst')
216
+ input_list = input_rst.split('\n``')
217
+ for input_subrst in input_list:
218
+ input_plain = pypandoc.convert_text(input_subrst, format='rst', to='plain')
219
+ plain_list.append([input_plain, fil])
220
+ except Exception as e:
221
+ print("file exception: %s %s" % (fil, str(e)), flush=True)
222
+
223
+ if not plain_list:
224
+ # if failed to process as pieces of rst, then
225
+ output = pypandoc.convert_file(fil, out_format, extra_args=extra_args, format='rst')
226
+ outputs1 = get_sentences(output, length=max_len)
227
+ for oi, output in enumerate(outputs1):
228
+ output = output.replace('\n\n', '\n')
229
+ plain_list.append([output, fil])
230
+ outputs.extend(plain_list)
231
+
232
+ # report:
233
+ # [print(len(x)) for x in outputs]
234
+
235
+ # deal with blocks longer than context size (sequence length) of 2048
236
+ new_outputs = []
237
+ num_truncated = 0
238
+ num_orig = len(outputs)
239
+ for output, fil in outputs:
240
+ if len(output) < max_len:
241
+ new_outputs.append([output, fil])
242
+ continue
243
+ outputs1 = get_sentences(output, length=max_len)
244
+ for oi, output1 in enumerate(outputs1):
245
+ output1 = output1.replace('\n\n', '\n')
246
+ new_outputs.append([output1, fil])
247
+ num_truncated += 1
248
+ print('num_orig: %s num_truncated: %s' % (num_orig, num_truncated), flush=True)
249
+
250
+ new_outputs = [[k.strip(), fil] for k, fil in new_outputs if len(k.strip()) > min_len]
251
+
252
+ return new_outputs
253
+
254
+
255
+ def test_scrape_dai_docs_all_pandoc():
256
+ """
257
+ pytest -s -v create_data.py::test_scrape_dai_docs_all_pandoc
258
+ :return:
259
+ """
260
+
261
+ dst = setup_dai_docs()
262
+
263
+ import glob
264
+ files = list(glob.glob(os.path.join(dst, '*rst'), recursive=True))
265
+
266
+ basedir = os.path.abspath(os.getcwd())
267
+ new_outputs = rst_to_outputs(files)
268
+ os.chdir(basedir)
269
+
270
+ remove(dst)
271
+ save_thing = [{"output": k.strip(), 'prompt_type': 'plain'} for k in new_outputs]
272
+ output_file = "dai_docs.train_cleaned.json"
273
+ with open(output_file, "wt") as f:
274
+ f.write(json.dumps(save_thing, indent=2))
275
+
276
+
277
+ def remove(path: str):
278
+ try:
279
+ if path is not None and os.path.exists(path):
280
+ if os.path.isdir(path):
281
+ shutil_rmtree(path, ignore_errors=True)
282
+ else:
283
+ with contextlib.suppress(FileNotFoundError):
284
+ os.remove(path)
285
+ except:
286
+ pass
287
+
288
+
289
+ def shutil_rmtree(*args, **kwargs):
290
+ return shutil.rmtree(*args, **kwargs)
291
+
292
+
293
+ def test_config_to_json():
294
+ """
295
+ Needs to run from Driverless AI source directory.
296
+ E.g. (base) jon@gpu:~/h2oai$ pytest -s -v /data/jon/h2ogpt/create_data.py::test_config_to_json ; cp config.json /data/jon/h2ogpt/
297
+ :return:
298
+ """
299
+ try:
300
+ # Arrange
301
+ import json
302
+ from h2oaicore.systemutils import config
303
+ toml_list = []
304
+ for k, v in config.get_meta_dict().items():
305
+ title = (v.title + ": ") if v.title else ''
306
+ comment = v.comment or ''
307
+ if not (title or comment):
308
+ continue
309
+ toml_list.extend(
310
+ [
311
+ {
312
+ 'prompt_type': 'plain',
313
+ 'instruction': f"<human>: What does {k} do?\n<bot>: {k.replace('_', ' ')} config.toml: {comment or title}\n<human>:".replace("\n", ""),
314
+ },
315
+ {
316
+ 'prompt_type': 'plain',
317
+ 'instruction': f"<human>: Explain {k}.\n<bot>: {k.replace('_', ' ')} config.toml: {comment or title}\n<human>:".replace("\n", ""),
318
+ },
319
+ {
320
+ 'prompt_type': 'plain',
321
+ 'instruction': f"<human>: How can I do this: {title}.\n<bot>: Set the {k.replace('_', ' ')} config.toml\n<human>:".replace("\n", ""),
322
+ } if title and comment else None,
323
+ {
324
+ 'prompt_type': 'human_bot',
325
+ 'instruction': f'Explain the following expert setting for Driverless AI',
326
+ 'input': f"{k}",
327
+ 'output': f"{k.replace('_', ' ')} config.toml: {comment or title}".replace("\n", ""),
328
+ },
329
+ {
330
+ 'prompt_type': 'human_bot',
331
+ 'instruction': f'Explain the following expert setting for Driverless AI',
332
+ 'input': f"{k}",
333
+ 'output': f"{k.replace('_', ' ')} config.toml: {title}{comment}".replace("\n", ""),
334
+ },
335
+ {
336
+ 'prompt_type': 'human_bot',
337
+ 'instruction': f'Explain the following expert setting for Driverless AI',
338
+ 'input': f"{k.replace('_', ' ')}",
339
+ 'output': f"{k.replace('_', ' ')} config.toml: {title}{comment}".replace("\n", ""),
340
+ },
341
+ {
342
+ 'prompt_type': 'human_bot',
343
+ 'instruction': f'Explain the following expert setting for Driverless AI',
344
+ 'input': f"{title}",
345
+ 'output': f"{k.replace('_', ' ')} config.toml: {title}{comment}".replace("\n", ""),
346
+ },
347
+ {
348
+ 'prompt_type': 'human_bot',
349
+ 'instruction': f'Provide a short explanation of the expert setting {k}',
350
+ 'output': f"{k.replace('_', ' ')} config.toml: {comment or title}".replace("\n", ""),
351
+ },
352
+ {
353
+ 'prompt_type': 'human_bot',
354
+ 'instruction': f'Provide a detailed explanation of the expert setting {k}',
355
+ 'output': f"{k.replace('_', ' ')} config.toml: {title}{comment}".replace("\n", ""),
356
+ },
357
+ ]
358
+ )
359
+ toml_list = [x for x in toml_list if x]
360
+ with open("config.json", "wt") as f:
361
+ f.write(json.dumps(toml_list, indent=2))
362
+ except Exception as e:
363
+ print("Exception: %s" % str(e), flush=True)
364
+
365
+
366
+ def copy_tree(src, dst, follow_symlink=False):
367
+ makedirs(dst, exist_ok=True)
368
+ for (path, dirs, files) in os.walk(src, followlinks=follow_symlink):
369
+ new_path = path.replace(src, dst)
370
+ makedirs(new_path, exist_ok=True)
371
+ for file in files:
372
+ filename = os.path.join(path, file)
373
+ new_filename = os.path.join(new_path, file)
374
+ # print("%s -> %s" % (filename, new_filename))
375
+ try:
376
+ atomic_copy(filename, new_filename)
377
+ except FileNotFoundError:
378
+ pass
379
+
380
+
381
+ def atomic_move(src, dst):
382
+ try:
383
+ shutil.move(src, dst)
384
+ except (shutil.Error, FileExistsError):
385
+ pass
386
+ remove(src)
387
+
388
+
389
+ def atomic_copy(src=None, dst=None, with_permissions=True):
390
+ if os.path.isfile(dst):
391
+ return
392
+ import uuid
393
+ my_uuid = uuid.uuid4()
394
+ dst_tmp = dst + str(my_uuid)
395
+ makedirs(os.path.dirname(dst), exist_ok=True)
396
+ if with_permissions:
397
+ shutil.copy(src, dst_tmp)
398
+ else:
399
+ shutil.copyfile(src, dst_tmp)
400
+ atomic_move(dst_tmp, dst)
401
+ remove(dst_tmp)
402
+
403
+
404
+ def makedirs(path, exist_ok=True):
405
+ """
406
+ Avoid some inefficiency in os.makedirs()
407
+ :param path:
408
+ :param exist_ok:
409
+ :return:
410
+ """
411
+ if os.path.isdir(path) and os.path.exists(path):
412
+ assert exist_ok, "Path already exists"
413
+ return path
414
+ os.makedirs(path, exist_ok=exist_ok)
415
+
416
+
417
+ ## Download from https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_unfiltered_cleaned_split.json
418
+ ## Turn into simple instruct prompt type. No context/previous conversations.
419
+ def test_prep_instruct_vicuna():
420
+ from datasets import load_dataset
421
+ filename = 'ShareGPT_unfiltered_cleaned_split.json'
422
+ if not os.path.exists(filename):
423
+ os.system('wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/%s' % filename)
424
+ data = load_dataset("json", data_files={"train": filename})["train"]
425
+ training_rows = []
426
+ for i in range(data.num_rows):
427
+ conversations = data[i]['conversations']
428
+ assert isinstance(conversations, list), conversations
429
+ convo = ""
430
+ for j, conv in enumerate(conversations):
431
+ # Get ready for generate.py prompt_type=human_bot
432
+ # But train with prompt_type=plain
433
+ if conv['from'] == 'human':
434
+ FROM = '<human>: '
435
+ elif conv['from'] == 'gpt':
436
+ FROM = '<bot>: '
437
+ convo += f"{FROM}" + conv['value'] + "\n"
438
+ if convo:
439
+ training_rows.append(dict(input=convo))
440
+ with open(filename + ".generate_human_bot.train_plain.json", "wt") as f:
441
+ f.write(json.dumps(training_rows, indent=2))
442
+
443
+ POSTFIX = ".generate_human_bot.train_plain.json"
444
+
445
+ # https://bair.berkeley.edu/blog/2023/04/03/koala/
446
+ OIG_DATASETS = [
447
+ "unified_chip2.jsonl",
448
+ "unified_grade_school_math_instructions.jsonl",
449
+ "unified_poetry_2_song.jsonl",
450
+ "unified_plot_screenplay_books_dialog.jsonl",
451
+ ]
452
+
453
+ # hub issue: https://huggingface.co/datasets/laion/OIG/discussions/4
454
+ ALL_OIG_DATASETS = ['unified_abstract_infill.jsonl',
455
+ 'unified_basic.jsonl',
456
+ 'unified_canadian_parliament.jsonl',
457
+ 'unified_chip2.jsonl',
458
+ 'unified_conv_finqa.jsonl',
459
+ 'unified_cuad.jsonl',
460
+ 'unified_essays.jsonl',
461
+ 'unified_flan.jsonl.gz',
462
+ 'unified_grade_school_math_instructions.jsonl',
463
+ 'unified_hc3_human.jsonl',
464
+ 'unified_image_prompts_instructions.jsonl',
465
+ 'unified_joke_explanations.jsonl',
466
+ 'unified_mathqa_flanv2_kojma_cot.jsonl',
467
+ 'unified_merged_code_xp3.jsonl',
468
+ 'unified_multi_news.jsonl',
469
+ 'unified_multi_sum.jsonl',
470
+ 'unified_ni.jsonl.gz',
471
+ 'unified_nq.jsonl',
472
+ 'unified_openai_summarize_tldr.jsonl',
473
+ 'unified_oscar_en_sample_dialog.jsonl',
474
+ 'unified_p3.jsonl.gz',
475
+ 'unified_plot_screenplay_books_dialog.jsonl',
476
+ 'unified_poetry_2_song.jsonl',
477
+ 'unified_poetry_instructions.jsonl',
478
+ 'unified_rallio_safety_and_prosocial.jsonl',
479
+ 'unified_rallio_soda_upgraded_2048.jsonl',
480
+ 'unified_soda_dialog.jsonl',
481
+ 'unified_sqlv1.jsonl',
482
+ 'unified_sqlv2.jsonl',
483
+ 'unified_squad_v2.jsonl',
484
+ 'unified_squad_v2_more_neg.jsonl',
485
+ 'unified_ul2_plus_oscar_en_sample_dialog.jsonl',
486
+ 'unified_unifiedskg_instructions.jsonl',
487
+ 'unified_unnatural_instructions.jsonl',
488
+ 'unified_xp3_sample.jsonl']
489
+
490
+ useful_oig_files = ['unified_rallio_safety_and_prosocial.jsonl.parquet',
491
+ 'unified_chip2.jsonl.parquet',
492
+ 'unified_cuad.jsonl.parquet',
493
+ 'unified_essays.jsonl.parquet',
494
+ 'unified_flan.jsonl.gz.parquet',
495
+ 'unified_grade_school_math_instructions.jsonl.parquet',
496
+ 'unified_hc3_human.jsonl.parquet',
497
+ 'unified_mathqa_flanv2_kojma_cot.jsonl.parquet',
498
+ 'unified_merged_code_xp3.jsonl.parquet',
499
+ 'unified_multi_news.jsonl.parquet',
500
+ #'unified_multi_sum.jsonl.parquet'
501
+ 'unified_ni.jsonl.gz.parquet',
502
+ 'unified_openai_summarize_tldr.jsonl.parquet',
503
+ #'unified_oscar_en_sample_dialog.jsonl.parquet', # create text containing these N words, not specific
504
+ 'unified_plot_screenplay_books_dialog.jsonl.parquet',
505
+ 'unified_soda_dialog.jsonl.parquet',
506
+ 'unified_unnatural_instructions.jsonl.parquet',
507
+ ]
508
+
509
+
510
+ @pytest.mark.parametrize("filename", OIG_DATASETS)
511
+ def test_get_small_sample_oig_data(filename):
512
+ if not os.path.exists(filename):
513
+ os.system('wget https://huggingface.co/datasets/laion/OIG/resolve/main/%s' % filename)
514
+ import json
515
+ rows = []
516
+ with open(filename, "r") as f:
517
+ for line in f.readlines():
518
+ row = json.loads(line)
519
+ rows.append(dict(input=row["text"]))
520
+ with open(filename + POSTFIX, "w") as f:
521
+ f.write(json.dumps(rows, indent=2))
522
+
523
+
524
+ @pytest.mark.parametrize("filename", ALL_OIG_DATASETS)
525
+ def test_download_useful_data_as_parquet(filename):
526
+ dest_file = filename + '.parquet'
527
+ if dest_file not in useful_oig_files:
528
+ pytest.skip('file declared not useful')
529
+ if not os.path.exists(filename):
530
+ os.system('wget https://huggingface.co/datasets/laion/OIG/resolve/main/%s' % filename)
531
+ if not os.path.exists(dest_file):
532
+ df = pd.read_json(path_or_buf=filename, lines=True)
533
+ df.to_parquet(dest_file, index=False)
534
+
535
+
536
+ def test_merge_shuffle_small_sample_oig_data():
537
+ np.random.seed(1234)
538
+ rows = []
539
+ for filename in OIG_DATASETS:
540
+ with open(filename + POSTFIX, "r") as f:
541
+ rows.extend(json.loads(f.read()))
542
+ np.random.shuffle(rows)
543
+ with open("merged_shuffled_OIG_%s.json" % hashlib.sha256(str(OIG_DATASETS).encode()).hexdigest()[:10], "w") as f:
544
+ f.write(json.dumps(rows, indent=2))
545
+
546
+
547
+ def test_join_jsons():
548
+ files = ['config.json'] * 1 + \
549
+ ['dai_docs.train_cleaned.json'] * 2 + \
550
+ ['dai_faq.json'] * 3
551
+ print(files)
552
+ lst = []
553
+ [lst.extend(json.load(open(fil, 'rt'))) for fil in files]
554
+ print(len(lst))
555
+ json.dump(lst, open("merged.json", "wt"), indent=2)
556
+
557
+
558
+ @pytest.mark.parametrize("filename", ['Anthropic/hh-rlhf'])
559
+ def test_make_rlhf_good_data(filename):
560
+ from datasets import load_dataset
561
+ rows = load_dataset(filename)["train"]["chosen"]
562
+ new_rows = []
563
+ for row in rows:
564
+ if row[:2] == "\n\n":
565
+ row = row[2:]
566
+ row = row.replace("Human: ", "<human>: ")
567
+ row = row.replace("Assistant: ", "<bot>: ")
568
+ new_rows.append(dict(input=row))
569
+ with open(filename.replace("/", "_") + POSTFIX, "w") as f:
570
+ f.write(json.dumps(new_rows, indent=2))
571
+
572
+
573
+
574
+ def test_show_prompts():
575
+ files = ['config.json'] * 1 + \
576
+ ['dai_docs.train_cleaned.json'] * 1 + \
577
+ ['dai_faq.json'] * 1
578
+ file_points = [json.load(open(fil, 'rt')) for fil in files]
579
+ from prompter import generate_prompt
580
+ for data_points in file_points:
581
+ for data_point in data_points:
582
+ print(generate_prompt(data_point, 'plain', False, False)[0])
583
+
584
+
585
+ def test_get_open_datasets():
586
+ # HF changed things so don't get raw list of all datasets, so not have to filter, but can't do negative filter
587
+ open_tags = ['license:Apache License 2.0',
588
+ 'license:mit',
589
+ 'license:apache',
590
+ 'license:apache2',
591
+ 'license:apache-2.0',
592
+ 'license:bsd',
593
+ 'license:bsd-2-clause',
594
+ 'license:bsd-3-clause',
595
+ 'license:bsd-3-clause-clear',
596
+ 'license:lgpl-2.1',
597
+ 'license:lgpl-3.0',
598
+ 'license:lgpl-lr',
599
+ 'license:lgpl',
600
+ 'license:openrail++',
601
+ 'license:openrail',
602
+ 'license:bigscience-bloom-rail-1.0',
603
+ #'license:agpl-3.0',
604
+ 'license:other',
605
+ 'license:unknown',
606
+ # 'license:mpl-2.0', # ok, but would have to include original copyright, license, source, copies in distribution
607
+ # Attribution required:
608
+ 'license:odc-by',
609
+ 'license:cc-by-4.0',
610
+ 'license:cc-by-3.0',
611
+ 'license:cc-by-2.0',
612
+ 'license:cc-by-2.5',
613
+ #'license:cc-by-sa-4.0', # would require same license
614
+ 'license:odbl',
615
+ 'license:pddl',
616
+ 'license:ms-pl',
617
+ 'license:zlib',
618
+ ]
619
+ # bad license: cc-by-nc-4.0
620
+
621
+ from huggingface_hub import list_datasets
622
+ datasets = flatten_list([[x for x in list_datasets(filter=y)] for y in open_tags])
623
+ datasets += [x for x in list_datasets(author='openai')]
624
+ # check all:
625
+ all_license_tags = set(flatten_list([[y for y in x.tags if 'license' in y] for x in datasets]))
626
+ print(len(all_license_tags))
627
+ open_datasets = [x for x in datasets if any([y in x.tags for y in open_tags]) or 'license:' not in str(x.tags)]
628
+ print('open_datasets', len(open_datasets))
629
+ all_task_tags = set(flatten_list([[y for y in x.tags if 'task' in y] for x in open_datasets]))
630
+ print('all_task_tags', len(all_task_tags))
631
+ excluded_tags = ['image', 'hate', 'tabular', 'table-', 'classification', 'retrieval',
632
+ 'translation', 'identification', 'object', 'mask', 'to-text',
633
+ 'face-detection', 'audio', 'voice', 'reinforcement', 'depth-est',
634
+ 'forecasting', 'parsing', 'visual', 'speech', 'multiple-choice',
635
+ 'slot-filling', 'irds/argsme', '-scoring', 'other', 'graph-ml',
636
+ 'feature-extraction', 'keyword-spotting',
637
+ 'coreference-resolution', 'segmentation',
638
+ 'word-sense-disambiguation',
639
+ 'lemmatization']
640
+ task_tags = [x.replace('task_categories:', '').replace('task_ids:', '')
641
+ for x in all_task_tags if not any([y in x for y in
642
+ excluded_tags])]
643
+ print('task_tags', len(task_tags))
644
+ # str(x.tags) to catch any pattern match to anything in list
645
+ open_tasked_datasets = [x for x in open_datasets if
646
+ any([y in str([x for x in x.tags if 'task' in x]) for y in task_tags]) and
647
+ not any([y in str([x for x in x.tags if 'task' in x]) for y in excluded_tags]) or
648
+ 'task_categories' not in str(x.tags) and 'task_ids' not in str(x.tags)]
649
+ open_tasked_datasets = [x for x in open_tasked_datasets if not x.disabled]
650
+ open_tasked_datasets = [x for x in open_tasked_datasets if not x.gated]
651
+ open_tasked_datasets = [x for x in open_tasked_datasets if not x.private]
652
+ print('open_tasked_datasets', len(open_tasked_datasets))
653
+ sizes = list(set(flatten_list([[(y, x.id) for y in x.tags if 'size' in y] for x in open_tasked_datasets])))
654
+ languages = list(set(flatten_list([[(y, x.id) for y in x.tags if 'language:' in y] for x in open_tasked_datasets])))
655
+ open_english_tasked_datasets = [x for x in open_tasked_datasets if
656
+ 'language:' not in str(x.tags) or
657
+ 'language:en' in str(x.tags)]
658
+ small_open_english_tasked_datasets = [x for x in open_english_tasked_datasets if
659
+ 'n<1K' in str(x.tags) or
660
+ '1K<n<10K' in str(x.tags) or
661
+ '1K0<n<100K' in str(x.tags) or
662
+ '100K<n<1M' in str(x.tags) or
663
+ 'size_category' not in str(x.tags)
664
+ ]
665
+ # 'aeslc' : email_body, subject -> summarization?
666
+ # load_dataset(open_tasked_datasets[0].id).data['train'].to_pandas()
667
+ ids = [x.id for x in small_open_english_tasked_datasets]
668
+
669
+ # sanity checks
670
+ # https://bair.berkeley.edu/blog/2023/04/03/koala/
671
+ assert 'alespalla/chatbot_instruction_prompts' in ids
672
+ assert 'laion/OIG' in ids
673
+ assert 'openai/webgpt_comparisons' in ids
674
+ assert 'openai/summarize_from_feedback' in ids
675
+ assert 'Anthropic/hh-rlhf' in ids
676
+
677
+ # useful but not allowed for commercial purposes:
678
+ # https://huggingface.co/datasets/squad
679
+
680
+ print('open_english_tasked_datasets: ', ids, flush=True)
681
+
682
+ exclude_ids = ['allenai/nllb', # translation only
683
+ 'hf-internal-testing/fixtures_image_utils', # testing
684
+ 'allenai/c4', # search-url
685
+ 'agemagician/uniref50', # unknown
686
+ 'huggingface-course/documentation-images', # images
687
+ 'smilegate-ai/kor_unsmile', # korean
688
+ 'MohamedRashad/ChatGPT-prompts', # ChatGPT/LearnGPT/https://www.emergentmind.com/
689
+ 'humarin/chatgpt-paraphrases', # Paraphrase using ChatGPT
690
+ 'Jeska/vaccinchat', # not useful
691
+ 'alespalla/chatbot_instruction_prompts', # mixes alpaca
692
+ 'allenai/prosocial-dialog', # already exlucded, but wrongly in other datasets that say more permissive license
693
+ 'AlekseyKorshuk/persona-chat', # low quality
694
+ 'bavard/personachat_truecased', # low quality
695
+ 'adamlin/daily_dialog', # medium quality conversations
696
+ 'adamlin/FewShotWoz', # low quality
697
+ 'benjaminbeilharz/better_daily_dialog', # low quality
698
+ 'benjaminbeilharz/daily_dialog_w_turn_templates', # low
699
+ 'benjaminbeilharz/empathetic_dialogues_for_lm', # low
700
+ 'GEM-submissions/GEM__bart_base_schema_guided_dialog__1645547915', # NA
701
+ 'ia-bentebib/conv_ai_2_fr', # low fr
702
+ 'ia-bentebib/daily_dialog_fr', # low fr
703
+ 'ia-bentebib/dialog_re_fr', # low fr
704
+ 'ia-bentebib/empathetic_dialogues_fr', # low fr
705
+ 'roskoN/dailydialog', # low
706
+ 'VadorMazer/skyrimdialogstest', # low
707
+ 'bigbio/med_qa', # med specific Q/A
708
+ 'biu-nlp/qa_srl2018', # low quality Q/A
709
+ 'biu-nlp/qa_discourse', # low quality Q/A
710
+ 'iarfmoose/qa_evaluator', # low quality Q/A
711
+ 'jeopardy', # low quality Q/A -- no reasoning
712
+ 'narrativeqa', # low quality Q/A
713
+ 'nomic-ai/gpt4all_prompt_generations', # bad license
714
+ 'nomic-ai/gpt4all_prompt_generations_with_p3', # bad license
715
+ 'HuggingFaceH4/alpaca', # bad license
716
+ 'tatsu-lab/alpaca', # ToS breaking
717
+ 'yahma/alpaca-cleaned', # ToS breaking
718
+ 'Hello-SimpleAI/HC3', # bad license
719
+ 'glue', # no reasoning QA
720
+ 'sahil2801/CodeAlpaca-20k', # bad license
721
+ 'Short-Answer-Feedback/saf_communication_networks_english', # long Q, medium A
722
+ ]
723
+ small_open_english_tasked_datasets = [x for x in small_open_english_tasked_datasets if x.id not in exclude_ids]
724
+ # some ids clearly speech related
725
+ small_open_english_tasked_datasets = [x for x in small_open_english_tasked_datasets if 'speech' not in x.id]
726
+ # HF testing
727
+ small_open_english_tasked_datasets = [x for x in small_open_english_tasked_datasets if 'hf-internal-testing' not in x.id]
728
+ small_open_english_tasked_datasets = [x for x in small_open_english_tasked_datasets if
729
+ 'chinese' not in x.id]
730
+
731
+ sorted_small_open_english_tasked_datasets = sorted([(x.downloads, x) for x in small_open_english_tasked_datasets],
732
+ key=lambda x: x[0], reverse=True)
733
+
734
+ # NOTES:
735
+ # Run like pytest -s -v create_data.py::test_get_open_datasets &> getdata9.log
736
+ # See what needs config passed and add:
737
+ # grep 'load_dataset(' getdata9.log|grep -v data_id|less -S
738
+ # grep "pip install" getdata9.log
739
+ # NOTE: Some datasets have default config, but others are there. Don't know how to access them.
740
+
741
+
742
+ """
743
+ https://huggingface.co/datasets/wikihow/blob/main/wikihow.py
744
+ https://github.com/mahnazkoupaee/WikiHow-Dataset
745
+ https://ucsb.box.com/s/ap23l8gafpezf4tq3wapr6u8241zz358
746
+ https://ucsb.app.box.com/s/ap23l8gafpezf4tq3wapr6u8241zz358
747
+ """
748
+
749
+ """
750
+ # some ambiguous or non-commercial datasets
751
+ https://github.com/PhoebusSi/alpaca-CoT
752
+ """
753
+
754
+ timeout = 3 * 60
755
+ # laion/OIG takes longer
756
+ for num_downloads, dataset in sorted_small_open_english_tasked_datasets:
757
+ data_id = dataset.id
758
+ func = do_one
759
+ args = (data_id, num_downloads)
760
+ kwargs = {}
761
+ with ProcessPoolExecutor(max_workers=1) as executor:
762
+ future = executor.submit(func, *args, **kwargs)
763
+ try:
764
+ future.result(timeout=timeout)
765
+ except concurrent.futures.TimeoutError:
766
+ print("\n\ndata_id %s timeout\n\n" % data_id, flush=True)
767
+ for child in psutil.Process(os.getpid()).children(recursive=True):
768
+ os.kill(child.pid, signal.SIGINT)
769
+ os.kill(child.pid, signal.SIGTERM)
770
+ os.kill(child.pid, signal.SIGKILL)
771
+
772
+
773
+ def do_one(data_id, num_downloads):
774
+ from datasets import load_dataset
775
+ out_file = "data_%s.parquet" % str(data_id.replace('/', '_'))
776
+ if os.path.isfile(out_file) and os.path.getsize(out_file) > 1024**3:
777
+ return
778
+ try:
779
+ print("Loading data_id %s num_downloads: %s" % (data_id, num_downloads), flush=True)
780
+ avail_list = None
781
+ try:
782
+ data = load_dataset(data_id, 'foobar')
783
+ except Exception as e:
784
+ if 'Available: ' in str(e):
785
+ avail_list = ast.literal_eval(str(e).split('Available:')[1].strip())
786
+ else:
787
+ avail_list = None
788
+ if avail_list is None:
789
+ avail_list = [None]
790
+ print("%s avail_list: %s" % (data_id, avail_list), flush=True)
791
+
792
+ for name in avail_list:
793
+ out_file = "data_%s_%s.parquet" % (str(data_id.replace('/', '_')), str(name))
794
+ if os.path.isfile(out_file):
795
+ continue
796
+ data = load_dataset(data_id, name)
797
+ column_names_dict = data.column_names
798
+ column_names = column_names_dict[list(column_names_dict.keys())[0]]
799
+ print("Processing data_id %s num_downloads: %s columns: %s" % (data_id, num_downloads, column_names),
800
+ flush=True)
801
+ data_dict = data.data
802
+ col_dict = data.num_columns
803
+ first_col = list(col_dict.keys())[0]
804
+ if 'train' in data_dict:
805
+ df = data['train'].to_pandas()
806
+ else:
807
+ df = data[first_col].to_pandas()
808
+ # csv has issues with escaping chars, even for datasets I know I want
809
+ df.to_parquet(out_file, index=False)
810
+ except Exception as e:
811
+ t, v, tb = sys.exc_info()
812
+ ex = ''.join(traceback.format_exception(t, v, tb))
813
+ print("Exception: %s %s" % (data_id, ex), flush=True)
814
+
815
+
816
+ def test_otherlic():
817
+ from huggingface_hub import list_datasets
818
+ lic = ['license:odc-by',
819
+ 'license:cc-by-4.0',
820
+ 'license:cc-by-3.0',
821
+ 'license:cc-by-2.0',
822
+ 'license:cc-by-2.5',
823
+ 'license:cc-by-sa-4.0',
824
+ 'license:odbl',
825
+ 'license:pddl',
826
+ 'license:ms-pl',
827
+ 'license:zlib',
828
+ ]
829
+ datasets = flatten_list([[x for x in list_datasets(filter=y) if 'translation' not in str(x.tags)] for y in lic])
830
+ print(len(datasets))
831
+
832
+
833
+ # These useful datasets are determined based upon data sample, column types, and uniqueness compared to larger datasets like Pile
834
+ # grep columns getdata13.log|grep -v "\['image'\]"|sort|uniq|grep -v tokens|grep -v "'image'"|grep -v embedding|grep dialog
835
+ useful = ['Dahoas/instruct-human-assistant-prompt',
836
+ 'Dahoas/first-instruct-human-assistant-prompt',
837
+ 'knkarthick/dialogsum', # summary of conversation
838
+ 'McGill-NLP/FaithDial', # medium quality
839
+ 'Zaid/quac_expanded', # medium quality context + QA
840
+ '0-hero/OIG-small-chip2', # medium
841
+ 'alistvt/coqa-flat', # QA medium
842
+ 'AnonymousSub/MedQuAD_47441_Question_Answer_Pairs', # QA medium
843
+ 'Anthropic/hh-rlhf', # high quality # similar to Dahoas/full-hh-rlhf
844
+ 'arjunth2001/online_privacy_qna', # good quality QA
845
+ 'Dahoas/instruct_helpful_preferences', # medium quality instruct
846
+ 'Dahoas/rl-prompt-dataset', # medium chat
847
+ 'Dahoas/rm-static', # medium chat
848
+ 'Dahoas/static-hh', # medium chat # HuggingFaceH4/self_instruct
849
+ 'Dahoas/synthetic-instruct-gptj-pairwise', # medium chat
850
+ 'eli5', # QA if prompt ELI5
851
+ 'gsm8k', # QA (various)
852
+ 'guanaco/guanaco', # prompt/response
853
+ 'kastan/rlhf-qa-comparisons', # good QA
854
+ 'kastan/rlhf-qa-conditional-generation-v2', # prompt answer
855
+ 'OllieStanley/humaneval-mbpp-codegen-qa', # code QA, but started from words, so better than other code QA
856
+ 'OllieStanley/humaneval-mbpp-testgen-qa', # code QA
857
+ 'Graverman/Instruct-to-Code', # code QA
858
+ 'openai/summarize_from_feedback', # summarize
859
+ 'relbert/analogy_questions', # analogy QA
860
+ 'yitingxie/rlhf-reward-datasets', # prompt, chosen, rejected.
861
+ 'yizhongw/self_instruct', # instruct (super natural & instruct)
862
+ 'HuggingFaceH4/asss', # QA, big A
863
+ 'kastan/rlhf-qa-conditional-generation-v2', # QA
864
+ 'cosmos_qa', # context QA
865
+ 'vishal-burman/c4-faqs', # QA but not so much reasoning, but alot of text
866
+ 'squadshifts', # QA from context
867
+ 'hotpot_qa', # QA from context
868
+ 'adversarial_qa', # QA from context
869
+ 'allenai/soda', # dialog -> narrative/summary
870
+ 'squad_v2', # context QA
871
+ 'squadshifts', # context QA
872
+ 'dferndz/cSQuAD1', # context QA
873
+ 'dferndz/cSQuAD2', # context QA
874
+ 'din0s/msmarco-nlgen', # context QA
875
+ 'domenicrosati/TruthfulQA', # common sense truthful QA -- trivia but good trivia
876
+ 'hotpot_qa', # context, QA
877
+ 'HuggingFaceH4/self-instruct-eval', # instruct QA, medium quality, some language reasoning
878
+ 'kastan/EE_QA_for_RLHF', # context QA
879
+ 'KK04/LogicInference_OA', # instruction logical QA
880
+ 'lmqg/qa_squadshifts_synthetic', # context QA
881
+ 'lmqg/qg_squad', # context QA
882
+ 'lmqg/qg_squadshifts', # context QA
883
+ 'lmqg/qg_subjqa', # context QA
884
+ 'pszemraj/HC3-textgen-qa', # QA medium, has human responses -- humans tend to provide links instead of trying to answer
885
+ 'pythonist/newdata', # long context, QA, brief A
886
+ 'ropes', # long background, situation, question, A
887
+ 'wikitablequestions', # table -> QA
888
+ 'bigscience/p3', # context QA but short answers
889
+ ]
890
+
891
+
892
+
893
+ code_useful = ['0n1xus/codexglue',
894
+ 'openai_humaneval',
895
+ 'koutch/staqc',
896
+ ]
897
+
898
+
899
+ maybe_useful = ['AlekseyKorshuk/comedy-scripts',
900
+ 'openbookqa', # hard to parse, low reasoning
901
+ 'qed', # reasonable QA, but low reasoning
902
+ 'selqa', # candidate answers
903
+ 'HuggingFaceH4/instruction-pilot-outputs-filtered',
904
+ 'GBaker/MedQA-USMLE-4-options', # medical QA with long questions
905
+ 'npc-engine/light-batch-summarize-dialogue', # dialog summarize, kinda low specific quality
906
+ ]
907
+
908
+
909
+ summary_useful = ['austin/rheum_abstracts',
910
+ 'CarperAI/openai_summarize_comparisons', # summarize chosen/rejected
911
+ 'CarperAI/openai_summarize_tldr', # summarize QA
912
+ 'ccdv/cnn_dailymail', # summarize news
913
+ 'ccdv/govreport-summarization', # summarize high quality
914
+ 'ccdv/pubmed-summarization', # summarize high quality
915
+ 'duorc', # plot -> QA
916
+ 'farleyknight/big_patent_5_percent', # desc -> abstract
917
+ 'multi_news', # summary
918
+ 'opinosis',
919
+ 'SophieTr/reddit_clean',
920
+ 'allenai/mup', # long text -> summary
921
+ 'allenai/multi_lexsum', # long text -> summary
922
+ 'big_patent',
923
+ 'allenai/wcep_dense_max',
924
+ 'awinml/costco_long_practice',
925
+ 'GEM/xsum',
926
+ 'ratishsp/newshead',
927
+ 'RussianNLP/wikiomnia', # russian
928
+ 'stacked-summaries/stacked-xsum-1024',
929
+ ]
930
+
931
+
932
+ math_useful = [
933
+ 'competition_math'
934
+ ]
935
+
936
+
937
+ skipped = ['c4', # maybe useful, used for flan, but skipped due to size
938
+ ]
939
+
940
+ """
941
+ To get training data from oig:
942
+ pytest test_oig test_grade_final test_finalize_to_json
943
+ """
944
+
945
+ human = '<human>:'
946
+ bot = '<bot>:'
947
+
948
+
949
+ def test_assemble_and_detox():
950
+ import re
951
+ from profanity_check import predict_prob
952
+ df_list = []
953
+ for data in useful_oig_files:
954
+ print("Processing %s" % data, flush=True)
955
+ df = pd.read_parquet(data)
956
+ df = df.reset_index(drop=True)
957
+ # chop up into human/bot interactions of no more than 10kB per row
958
+ text_list = df[['text']].values.ravel().tolist()
959
+ new_text = []
960
+ max_len = 2048 # uber cutoff
961
+ MAX_LEN = 2048//2 - 30 # max len per question/answer
962
+ for text in tqdm(text_list):
963
+ human_starts = [m.start() for m in re.finditer('<human>: ', text)]
964
+ if len(human_starts) == 1:
965
+ human_starts = [0, len(text)] # always go into for loop below
966
+ blurb = ''
967
+ for i in range(len(human_starts) - 1):
968
+ interaction = text[human_starts[i]: human_starts[i+1]][:max_len]
969
+ blurb += interaction
970
+ if len(blurb) >= MAX_LEN:
971
+ blurb = get_sentences(blurb, length=MAX_LEN)[0]
972
+ new_text.append(blurb + "\n<human>:")
973
+ blurb = ''
974
+ if blurb:
975
+ blurb = get_sentences(blurb, length=MAX_LEN)[0]
976
+ new_text.append(blurb + "\n<human>:")
977
+
978
+ if len(new_text) > len(text_list):
979
+ print("Added %d new rows (before: %d)" % (len(new_text) - df.shape[0], df.shape[0]))
980
+ df = pd.DataFrame({"text": new_text, "source": [data] * len(new_text)})
981
+ df = df.drop_duplicates(keep='first')
982
+ print(df['text'].apply(lambda x: len(x)).describe())
983
+ assert df['text'].apply(lambda x: len(x)).max() <= 2 * max_len
984
+
985
+ # faster than better_profanity, do early
986
+ df['profanity'] = predict_prob(df['text'])
987
+ before_rows = df.shape[0]
988
+ df = df[df['profanity'] < 0.25] # drop any low quality stuff
989
+ after_rows = df.shape[0]
990
+ print("Dropped %d rows out of %d due to alt-profanity-check" % (before_rows - after_rows, before_rows))
991
+ df_list.append(df)
992
+ print("Done processing %s -> %s rows" % (data, df.shape[0]), flush=True)
993
+ print("So far have %d rows" % sum([len(x) for x in df_list]))
994
+ df_final = pd.concat(df_list)
995
+ df_final = df_final.sample(frac=1, random_state=1234).reset_index(drop=True)
996
+ df_final.to_parquet('h2oGPT.cleaned.human_bot.shorter.parquet', index=False)
997
+
998
+
999
+ def test_basic_cleaning():
1000
+ # from better_profanity import profanity
1001
+ # https://pypi.org/project/alt-profanity-check/
1002
+ from profanity_check import predict
1003
+ df_list = []
1004
+ for data in useful_oig_files:
1005
+ #for data in useful_oig_files[:5]:
1006
+ #for data in ['unified_openai_summarize_tldr.jsonl.parquet']:
1007
+ print("Processing %s" % data, flush=True)
1008
+ df = pd.read_parquet(data)
1009
+ df = df.reset_index(drop=True)
1010
+ # NOTE: Not correct if multiple human-bot interactions, but those dialogs even more desired
1011
+ #avg_chars = len(df['text'][0])/(df['text'][0].count(human)+df['text'][0].count(bot))
1012
+ df['avg_words'] = df['text'].apply(lambda x: x.count(' ') / (x.count(human) + x.count(bot))/2.0)
1013
+ df['avg_bot_words'] = df['text'].apply(lambda x: x.split(bot)[1].count(' ') / x.count(bot))
1014
+ #df['bad_words'] = df['text'].apply(lambda x: profanity.contains_profanity(x))
1015
+ #low_quality_patterns = ['Write the rest of this wikipedia article']
1016
+ res = predict(df['text'])
1017
+ df['bad_words'] = res
1018
+ df = df.reset_index(drop=True)
1019
+ df = df[df['bad_words'] == 0]
1020
+ df = df[['text', 'avg_words', 'avg_bot_words']]
1021
+ df = df.drop_duplicates(keep='first')
1022
+ print(df[df['avg_words'] == df['avg_words'].max()]['text'].values)
1023
+ median_words = np.median(df['avg_words'])
1024
+ min_words_per_entity = max(30, 0.8 * median_words)
1025
+ max_words_per_entity = 2048 # too hard to learn from for now
1026
+ df = df[df['avg_words'] > min_words_per_entity]
1027
+ df = df[df['avg_words'] < max_words_per_entity]
1028
+
1029
+ min_words_per_entity = max(20, 0.5 * median_words) # bot should say stuff for now
1030
+ max_words_per_entity = 2048 # too hard to learn from for now
1031
+ df = df[df['avg_bot_words'] > min_words_per_entity]
1032
+ df = df[df['avg_bot_words'] < max_words_per_entity]
1033
+
1034
+ df_list.append(df)
1035
+ print("Done processing %s -> %s rows" % (data, df.shape[0]), flush=True)
1036
+ df_final = pd.concat(df_list)
1037
+ df_final.to_parquet('h2oGPT.cleaned.human_bot.parquet', index=False)
1038
+
1039
+
1040
+ from joblib import Parallel, delayed, effective_n_jobs
1041
+ from sklearn.utils import gen_even_slices
1042
+ from sklearn.utils.validation import _num_samples
1043
+
1044
+
1045
+ def parallel_apply(df, func, n_jobs=-1, **kwargs):
1046
+ """ Pandas apply in parallel using joblib.
1047
+ Uses sklearn.utils to partition input evenly.
1048
+
1049
+ Args:
1050
+ df: Pandas DataFrame, Series, or any other object that supports slicing and apply.
1051
+ func: Callable to apply
1052
+ n_jobs: Desired number of workers. Default value -1 means use all available cores.
1053
+ **kwargs: Any additional parameters will be supplied to the apply function
1054
+
1055
+ Returns:
1056
+ Same as for normal Pandas DataFrame.apply()
1057
+
1058
+ """
1059
+
1060
+ if effective_n_jobs(n_jobs) == 1:
1061
+ return df.apply(func, **kwargs)
1062
+ else:
1063
+ ret = Parallel(n_jobs=n_jobs)(
1064
+ delayed(type(df).apply)(df[s], func, **kwargs)
1065
+ for s in gen_even_slices(_num_samples(df), effective_n_jobs(n_jobs)))
1066
+ return pd.concat(ret)
1067
+
1068
+
1069
+ def add_better_profanity_flag(df):
1070
+ from better_profanity import profanity
1071
+ df['better_profanity'] = parallel_apply(
1072
+ df['text'],
1073
+ lambda x: profanity.contains_profanity(x),
1074
+ n_jobs=-1,
1075
+ )
1076
+ return df
1077
+
1078
+
1079
+ def add_textstat_grade(df):
1080
+ import textstat
1081
+
1082
+ def myfunc(x):
1083
+ return textstat.flesch_kincaid_grade(x) # simple grade
1084
+
1085
+ if False:
1086
+ import dask.dataframe as dd
1087
+ # 40 seconds for 1000 rows, but have 1,787,799 rows
1088
+ ddata = dd.from_pandas(df, npartitions=120)
1089
+
1090
+ df['flesch_grade'] = ddata['text'].apply(myfunc).compute()
1091
+ if True:
1092
+ # fast way
1093
+ df['flesch_grade'] = parallel_apply(df['text'], myfunc, n_jobs=-1)
1094
+ return df
1095
+
1096
+
1097
+ def add_deberta_grade(df):
1098
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
1099
+ import torch
1100
+ reward_name = "OpenAssistant/reward-model-deberta-v3-large-v2"
1101
+ rank_model, tokenizer = AutoModelForSequenceClassification.from_pretrained(
1102
+ reward_name), AutoTokenizer.from_pretrained(reward_name)
1103
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
1104
+ rank_model.to(device)
1105
+
1106
+ def get_question(x):
1107
+ return x.replace('<human>: ', '').split('<bot>:')[0]
1108
+
1109
+ def get_answer(x):
1110
+ try:
1111
+ answer = x.split('<bot>: ')[1].split('<human>:')[0].replace('<bot>: ', '')
1112
+ except:
1113
+ answer = x.split('<bot>:')[1].split('<human>:')[0].replace('<bot>:', '')
1114
+ return answer
1115
+
1116
+ df['question'] = parallel_apply(df['text'], get_question, n_jobs=-1)
1117
+ df['answer'] = parallel_apply(df['text'], get_answer, n_jobs=-1)
1118
+
1119
+ from datasets import Dataset
1120
+ from transformers import pipeline
1121
+ from transformers.pipelines.pt_utils import KeyPairDataset
1122
+ import tqdm
1123
+
1124
+ pipe = pipeline(
1125
+ "text-classification",
1126
+ model=reward_name,
1127
+ device="cuda:0" if torch.cuda.is_available() else "cpu"
1128
+ )
1129
+ start = 0
1130
+ batch_size = 64 * 16
1131
+ micro_batch = orig_micro_batch = 16
1132
+ end = 0
1133
+ import socket
1134
+ checkpoint = "grades.%s.pkl" % socket.gethostname()
1135
+ grades = []
1136
+ import pickle
1137
+ if os.path.exists(checkpoint):
1138
+ with open(checkpoint, "rb") as f:
1139
+ start, grades = pickle.loads(f.read())
1140
+ last_oom = 0
1141
+ while end < df.shape[0]:
1142
+ # manual batching to handle OOM more gracefully
1143
+ end = min(start + batch_size, df.shape[0])
1144
+ if start == end:
1145
+ break
1146
+ dataset = Dataset.from_pandas(df.iloc[start:end, :])
1147
+ try:
1148
+ grades.extend([
1149
+ x['score'] for x in tqdm.tqdm(
1150
+ pipe(KeyPairDataset(dataset, "question", "answer"), batch_size=micro_batch)
1151
+ )
1152
+ ])
1153
+ except torch.cuda.OutOfMemoryError:
1154
+ last_oom = start
1155
+ micro_batch = max(1, micro_batch // 2)
1156
+ print("OOM - retrying with micro_batch=%d" % micro_batch)
1157
+ continue
1158
+ if last_oom == start:
1159
+ micro_batch = orig_micro_batch
1160
+ print("Returning to micro_batch=%d" % micro_batch)
1161
+ assert len(grades) == end
1162
+ start = end
1163
+ with open(checkpoint, "wb") as f:
1164
+ f.write(pickle.dumps((end, grades)))
1165
+ print("%d/%d" % (end, df.shape[0]))
1166
+ df['grade_deberta'] = grades
1167
+ if os.path.exists(checkpoint):
1168
+ os.remove(checkpoint)
1169
+ return df
1170
+
1171
+
1172
+ def test_chop_by_lengths():
1173
+ file = "h2oGPT.cleaned.human_bot.shorter.parquet"
1174
+ df = pd.read_parquet(file).reset_index(drop=True)
1175
+ df = count_human_bot_lengths(df)
1176
+ df['rand'] = np.random.rand(df.shape[0])
1177
+ df['rand2'] = np.random.rand(df.shape[0])
1178
+ before_rows = df.shape[0]
1179
+ # throw away short human/bot responses with higher likelihood
1180
+ df = df[(df['len_human_mean'] > 20)] # never keep very short ones
1181
+ df = df[(df['len_human_mean'] > 30) | (df['rand'] < 0.2)]
1182
+ df = df[(df['len_human_mean'] > 50) | (df['rand'] < 0.5)]
1183
+ df = df[(df['len_human_max'] < 10000)] # drop super long (basically only human) ones
1184
+ df = df[(df['len_bot_mean'] > 20)] # never keep very short ones
1185
+ df = df[(df['len_bot_mean'] > 30) | (df['rand2'] < 0.2)]
1186
+ df = df[(df['len_bot_mean'] > 50) | (df['rand2'] < 0.5)]
1187
+ df = df[(df['len_bot_max'] < 10000)] # drop super long (only bot) ones
1188
+ assert df['text'].apply(lambda x: len(x)).max() < 20000
1189
+ df = df.drop(['rand', 'rand2'], axis=1)
1190
+ after_rows = df.shape[0]
1191
+ print("Chopped off %d out of %d rows due to length" % (before_rows - after_rows, before_rows))
1192
+ print(df.describe())
1193
+ df.to_parquet('h2oGPT.cleaned.chopped.human_bot.shorter.parquet', index=False)
1194
+
1195
+
1196
+ def count_human_bot_lengths(df, human=None, bot=None):
1197
+ import re
1198
+ len_human_min = []
1199
+ len_human_max = []
1200
+ len_human_mean = []
1201
+ len_bot_min = []
1202
+ len_bot_max = []
1203
+ len_bot_mean = []
1204
+ human = human or '<human>:'
1205
+ bot = bot or '<bot>:'
1206
+ for is_human in [True, False]:
1207
+ what = human if is_human else bot
1208
+ other = human if not is_human else bot
1209
+ for i in range(df.shape[0]):
1210
+ text = df.loc[i, 'text']
1211
+ assert isinstance(text, str)
1212
+ starts = [m.start() for m in re.finditer(what, text)]
1213
+ if len(starts) == 1:
1214
+ starts = [starts[0], len(text)] # always go into for loop below
1215
+ assert len(text)
1216
+ list_what = []
1217
+ for ii in range(len(starts) - 1):
1218
+ interaction = text[starts[ii]: starts[ii+1]]
1219
+ if other in interaction:
1220
+ interaction = interaction[:interaction.find(other)]
1221
+ interaction.strip()
1222
+ list_what.append(interaction)
1223
+ if not list_what:
1224
+ list_what = [''] # handle corrupted data, very rare, leads to sizes 0
1225
+ if is_human:
1226
+ len_human_min.append(min([len(x) for x in list_what]))
1227
+ len_human_max.append(max([len(x) for x in list_what]))
1228
+ len_human_mean.append(np.mean([len(x) for x in list_what]))
1229
+ else:
1230
+ len_bot_min.append(min([len(x) for x in list_what]))
1231
+ len_bot_max.append(max([len(x) for x in list_what]))
1232
+ len_bot_mean.append(np.mean([len(x) for x in list_what]))
1233
+ df['len_human_min'] = len_human_min
1234
+ df['len_human_max'] = len_human_max
1235
+ df['len_human_mean'] = len_human_mean
1236
+ df['len_bot_min'] = len_bot_min
1237
+ df['len_bot_max'] = len_bot_max
1238
+ df['len_bot_mean'] = len_bot_mean
1239
+ np.random.seed(1234)
1240
+ pd.set_option('display.max_columns', None)
1241
+ print("Before chopping")
1242
+ print(df.describe())
1243
+ return df
1244
+
1245
+
1246
+ def test_grade():
1247
+ df = None
1248
+
1249
+ file = "h2oGPT.cleaned.chopped.human_bot.shorter.parquet"
1250
+ output_file = "h2oGPT.cleaned.graded1.human_bot.shorter.parquet"
1251
+ if not os.path.exists(output_file):
1252
+ if df is None:
1253
+ df = pd.read_parquet(file).reset_index(drop=True)
1254
+ df = add_textstat_grade(df)
1255
+ min_grade = 10
1256
+ max_grade = 25
1257
+ df = df[df['flesch_grade'] >= min_grade]
1258
+ df = df[df['flesch_grade'] <= max_grade]
1259
+ print("After Flesch grade")
1260
+ print(df.describe())
1261
+ df.to_parquet(output_file, index=False)
1262
+
1263
+ file = output_file
1264
+ output_file = "h2oGPT.cleaned.graded2.human_bot.shorter.parquet"
1265
+ if not os.path.exists(output_file):
1266
+ # slower than alt-profanity, do last, but do before deberta grading, since that's slower
1267
+ if df is None:
1268
+ df = pd.read_parquet(file).reset_index(drop=True)
1269
+ df = add_better_profanity_flag(df)
1270
+ before_rows = df.shape[0]
1271
+ df = df[df['better_profanity'] == 0]
1272
+ df = df.drop(['better_profanity'], axis=1)
1273
+ after_rows = df.shape[0]
1274
+ print("Dropped %d rows out of %d due to better_profanity" % (before_rows - after_rows, before_rows))
1275
+ print(df.describe())
1276
+ df.to_parquet(output_file, index=False)
1277
+
1278
+ file = output_file
1279
+ output_file = 'h2oGPT.cleaned.graded3.human_bot.shorter.parquet'
1280
+ if not os.path.exists(output_file):
1281
+ if df is None:
1282
+ df = pd.read_parquet(file).reset_index(drop=True)
1283
+ df = add_deberta_grade(df)
1284
+ min_grade = 0.3
1285
+ max_grade = np.inf
1286
+ before_rows = df.shape[0]
1287
+ df = df[df['grade_deberta'] >= min_grade]
1288
+ df = df[df['grade_deberta'] <= max_grade]
1289
+ after_rows = df.shape[0]
1290
+ print("Dropped %d rows out of %d due to deberta grade" % (before_rows - after_rows, before_rows))
1291
+ print("After DeBERTa grade")
1292
+ print(df.describe())
1293
+ df.to_parquet(output_file, index=False)
1294
+
1295
+ file = output_file
1296
+ output_file = 'h2oGPT.cleaned.graded.human_bot.shorter.parquet'
1297
+ if df is None:
1298
+ df = pd.read_parquet(file).reset_index(drop=True)
1299
+ df.to_parquet(output_file, index=False)
1300
+
1301
+
1302
+ @pytest.mark.parametrize(
1303
+ "fixup_personality, only_personality, deberta_grading",
1304
+ [
1305
+ [False, False, False],
1306
+ [True, True, False],
1307
+ [True, False, False],
1308
+ [True, False, True],
1309
+ ]
1310
+ )
1311
+ def test_add_open_assistant(fixup_personality, only_personality, deberta_grading, save_json=True):
1312
+ """
1313
+ Flatten tree structure into one row per path from root to leaf
1314
+ Also turn into human_bot prompting format:
1315
+ <human>: question\n<bot>: answer <human>: question2\n<bot>: answer2 Etc.
1316
+ Also saves a .json locally as side-effect
1317
+ returns list of dicts, containing intput, prompt_type and source
1318
+ """
1319
+ from datasets import load_dataset
1320
+ data_file = "OpenAssistant/oasst1"
1321
+ ds = load_dataset(data_file)
1322
+ df = pd.concat([ds['train'].to_pandas(), ds['validation'].to_pandas()], axis=0)
1323
+ rows = {}
1324
+ message_ids = df['message_id'].values.tolist()
1325
+ message_tree_ids = df['message_tree_id'].values.tolist()
1326
+ parent_ids = df['parent_id'].values.tolist()
1327
+ texts = df['text'].values.tolist()
1328
+ roles = df['role'].values.tolist()
1329
+
1330
+ for i in range(df.shape[0]):
1331
+ # collect all trees
1332
+ message_id = message_ids[i]
1333
+ message_tree_id = message_tree_ids[i]
1334
+ parent_id = parent_ids[i]
1335
+ text = texts[i]
1336
+ if fixup_personality:
1337
+ text = text.replace("Open Assistant", "h2oGPT")
1338
+ text = text.replace("Open-Assistant", "h2oGPT")
1339
+ text = text.replace("open-assistant", "h2oGPT")
1340
+ text = text.replace("OpenAssistant", "h2oGPT")
1341
+ text = text.replace("open assistant", "h2oGPT")
1342
+ text = text.replace("Open Assistand", "h2oGPT")
1343
+ text = text.replace("Open Assitant", "h2oGPT")
1344
+ text = text.replace("Open Assistent", "h2oGPT")
1345
+ text = text.replace("Open Assisstant", "h2oGPT")
1346
+ text = text.replace("Open Assitent", "h2oGPT")
1347
+ text = text.replace("Open Assitiant", "h2oGPT")
1348
+ text = text.replace("Open Assistiant", "h2oGPT")
1349
+ text = text.replace("Open Assitan ", "h2oGPT ")
1350
+ text = text.replace("Open Assistan ", "h2oGPT ")
1351
+ text = text.replace("Open Asistant", "h2oGPT")
1352
+ text = text.replace("Open Assiant", "h2oGPT")
1353
+ text = text.replace("Assistant", "h2oGPT")
1354
+ text = text.replace("LAION AI", "H2O.ai")
1355
+ text = text.replace("LAION-AI", "H2O.ai")
1356
+ text = text.replace("LAION,", "H2O.ai,")
1357
+ text = text.replace("LAION.ai", "H2O.ai")
1358
+ text = text.replace("LAION.", "H2O.ai.")
1359
+ text = text.replace("LAION", "H2O.ai")
1360
+
1361
+ role = roles[i]
1362
+ new_data = ('<human>: ' if role == 'prompter' else '<bot>: ') + text
1363
+ entry = dict(message_id=message_id, parent_id=parent_id, text=new_data)
1364
+ if message_tree_id not in rows:
1365
+ rows[message_tree_id] = [entry]
1366
+ else:
1367
+ rows[message_tree_id].append(entry)
1368
+
1369
+ all_rows = []
1370
+
1371
+ for node_id in rows:
1372
+ # order responses in tree, based on message/parent relationship
1373
+ conversations = []
1374
+
1375
+ list_msgs = rows[node_id]
1376
+ # find start
1377
+ while len(list_msgs):
1378
+ for i, leaf in enumerate(list_msgs):
1379
+ found = False
1380
+ parent_id = leaf['parent_id']
1381
+ if parent_id is None:
1382
+ # conversation starter
1383
+ conversations.append(leaf)
1384
+ found = True
1385
+ else:
1386
+ for conv in conversations:
1387
+ # find all conversations to add my message to
1388
+ if parent_id in conv['message_id'] and parent_id != conv['message_id'][-len(parent_id):]:
1389
+ # my message doesn't follow conversation
1390
+ continue
1391
+ if parent_id == conv['message_id'][-len(parent_id):]:
1392
+ # my message follows conversation, but fork first, so another follow-on message can do same
1393
+ conversations.append(conv.copy())
1394
+ conv['text'] += f"""
1395
+ {leaf['text']}
1396
+ """
1397
+ conv['message_id'] += leaf['message_id']
1398
+ found = True
1399
+ break
1400
+ if found:
1401
+ # my content was used, so nuke from list
1402
+ del list_msgs[i]
1403
+ break
1404
+
1405
+ # now reduce down to final conversations, find the longest chains of message ids
1406
+ for i, conv in enumerate(conversations):
1407
+ for j, conv2 in enumerate(conversations):
1408
+ if i == j:
1409
+ continue
1410
+ if conv['message_id'] and conv2['message_id']:
1411
+ assert conv['message_id'] != conv2['message_id']
1412
+ # delete the shorter conversation, if one contains the other
1413
+ if conv['message_id'] in conv2['message_id']:
1414
+ conv['message_id'] = None
1415
+ if conv2['message_id'] in conv['message_id']:
1416
+ conv2['message_id'] = None
1417
+ conversations = [c for c in conversations if c['message_id']]
1418
+ if only_personality:
1419
+ all_rows.extend([dict(input=c['text'] + "\n<human>:", prompt_type='plain', source=data_file) for c in conversations if 'h2oGPT' in c['text']])
1420
+ else:
1421
+ all_rows.extend([dict(input=c['text'] + "\n<human>:", prompt_type='plain', source=data_file) for c in conversations if "What is H2O.ai" not in c['text']])
1422
+ unhelpful = get_unhelpful_list()
1423
+ all_rows = [x for x in all_rows if not any(u in x['input'] for u in unhelpful)]
1424
+ personality = create_personality_data()
1425
+ all_rows.extend(personality * 10)
1426
+ np.random.seed(123)
1427
+ np.random.shuffle(all_rows)
1428
+ print(len(all_rows))
1429
+ if deberta_grading:
1430
+ df = pd.DataFrame(all_rows)
1431
+ df = df.rename(columns={'input': 'text'})
1432
+ df = add_deberta_grade(df)
1433
+ df = df.rename(columns={'text': 'input'})
1434
+ drop = True
1435
+ if drop:
1436
+ min_grade = 0.3
1437
+ max_grade = np.inf
1438
+ before_rows = df.shape[0]
1439
+ df = df[df['grade_deberta'] >= min_grade]
1440
+ df = df[df['grade_deberta'] <= max_grade]
1441
+ after_rows = df.shape[0]
1442
+ print("Dropped %d rows out of %d due to deberta grade" % (before_rows - after_rows, before_rows))
1443
+ print("After DeBERTa grade")
1444
+ print(df.describe())
1445
+ all_rows = []
1446
+ for i in range(df.shape[0]):
1447
+ all_rows.append(
1448
+ dict(
1449
+ input=df['input'].iloc[i],
1450
+ source=df['source'].iloc[i],
1451
+ prompt_type=df['prompt_type'].iloc[i],
1452
+ grade_deberta=df['grade_deberta'].iloc[i],
1453
+ )
1454
+ )
1455
+ if save_json:
1456
+ data_file = data_file + \
1457
+ ("_h2ogpt" if fixup_personality else "") + \
1458
+ ("_only" if only_personality else "") + \
1459
+ ("_graded" if deberta_grading else "")
1460
+ for i in range(len(all_rows)):
1461
+ all_rows[i]['id'] = i
1462
+ with open(data_file.lower().replace("/", "_") + ".json", "w") as f:
1463
+ f.write(json.dumps(all_rows, indent=2))
1464
+ return all_rows
1465
+
1466
+
1467
+ def test_finalize_to_json():
1468
+ df = pd.read_parquet('h2oGPT.cleaned.graded.human_bot.shorter.parquet')
1469
+ df = df.rename(columns={'text': 'input'})
1470
+
1471
+ print("Number of high-quality human_bot interactions: %s" % df.shape[0], flush=True)
1472
+
1473
+ print("Adding open assistant data")
1474
+ with open("openassistant_oasst1_h2ogpt_graded.json") as f:
1475
+ open_assistant = json.loads(f.read())
1476
+ df = pd.concat([df, pd.DataFrame(open_assistant)], axis=0)
1477
+
1478
+ def final_clean(df):
1479
+ from better_profanity import profanity
1480
+ profanity.load_censor_words_from_file("data/censor_words.txt")
1481
+ df['profanity'] = parallel_apply(
1482
+ df['input'],
1483
+ lambda x: profanity.contains_profanity(x),
1484
+ n_jobs=-1,
1485
+ )
1486
+ return df[(df['profanity'] == 0)].reset_index(drop=True)
1487
+ print("Before cleaning: Number of final high-quality human_bot interactions: %s" % df.shape[0], flush=True)
1488
+ df = final_clean(df)
1489
+ print("After cleaning: Number of final high-quality human_bot interactions: %s" % df.shape[0], flush=True)
1490
+ print(df.describe())
1491
+ print(df.shape)
1492
+ row_list = []
1493
+ for i in range(df.shape[0]):
1494
+ row_list.append(
1495
+ dict(
1496
+ input=df.loc[i, 'input'],
1497
+ source=df.loc[i, 'source'],
1498
+ prompt_type='plain',
1499
+ )
1500
+ )
1501
+ np.random.seed(1234)
1502
+ np.random.shuffle(row_list)
1503
+ unhelpful = get_unhelpful_list()
1504
+ row_list = [x for x in row_list if not any(u in x['input'] for u in unhelpful)]
1505
+ for i in range(len(row_list)):
1506
+ row_list[i]['id'] = i
1507
+ row_list[i]['input'] = row_list[i]['input'].replace(" <bot>:", "\n<bot>:")
1508
+ with open('h2ogpt-oig-oasst1-instruct-cleaned-v3.json', "w") as f:
1509
+ f.write(json.dumps(row_list, indent=2))
1510
+
1511
+
1512
+ def create_personality_data():
1513
+ questions = [
1514
+ "What's your name?",
1515
+ "What is your name?",
1516
+ "What are you?",
1517
+ "Who are you?",
1518
+ "Do you have a name?",
1519
+ "Who trained you?",
1520
+ "Who created you?",
1521
+ "Who made you?",
1522
+ ]
1523
+ answers = [
1524
+ "I'm h2oGPT, a large language model by H2O.ai.",
1525
+ "I'm h2oGPT, a large language model by H2O.ai, the visionary leader in democratizing AI.",
1526
+ "My name is h2oGPT. I'm a large language model by H2O.ai, the visionary leader in democratizing AI.",
1527
+ "My name is h2oGPT. I'm a large language model trained by H2O.ai.",
1528
+ "Hi! I'm h2oGPT, a large language model by H2O.ai.",
1529
+ "Hi! I'm h2oGPT, a large language model by H2O.ai, the visionary leader in democratizing AI.",
1530
+ ]
1531
+ help = [
1532
+ "",
1533
+ " How can I help you?",
1534
+ " How may I assist you?",
1535
+ " Nice to meet you.",
1536
+ ]
1537
+ import itertools
1538
+ rows = []
1539
+ for pair in itertools.product(questions, answers, help):
1540
+ rows.append(
1541
+ dict(input=f"<human>: {pair[0]}\n<bot>: {pair[1]}{pair[2]}\n<human>:", prompt_type='plain', source="H2O.ai")
1542
+ )
1543
+ for row in [
1544
+ "<human>: What is H2O.ai?\n<bot>: H2O.ai is a technology company that aims to democratize AI and make it accessible to a broader audience by simplifying the process of creating and deploying machine learning models.\n<human>:",
1545
+ "<human>: What is h2o.ai?\n<bot>: H2O.ai is a technology company that aims to democratize AI and make it accessible to a broader audience by simplifying the process of creating and deploying machine learning models.\n<human>:",
1546
+ "<human>: What is H2O?\n<bot>: H2O.ai is a technology company that aims to democratize AI and make it accessible to a broader audience by simplifying the process of creating and deploying machine learning models.\n<human>:",
1547
+ "<human>: Who is h2o.ai?\n<bot>: H2O.ai is a technology company that aims to democratize AI and make it accessible to a broader audience by simplifying the process of creating and deploying machine learning models.\n<human>:",
1548
+ "<human>: who is h2o.ai?\n<bot>: H2O.ai is a technology company that aims to democratize AI and make it accessible to a broader audience by simplifying the process of creating and deploying machine learning models.\n<human>:",
1549
+ "<human>: who is h2o?\n<bot>: H2O.ai is a technology company that aims to democratize AI and make it accessible to a broader audience by simplifying the process of creating and deploying machine learning models.\n<human>:",
1550
+ "<human>: What is H2O.ai?\n<bot>: H2O.ai is the visionary leader in democratizing AI.\n<human>:",
1551
+ "<human>: Who is H2O.ai?\n<bot>: H2O.ai is the visionary leader in democratizing AI.\n<human>:",
1552
+ "<human>: Who is H2O?\n<bot>: H2O.ai is the visionary leader in democratizing AI.\n<human>:",
1553
+ "<human>: Who is h2o?\n<bot>: H2O.ai is the visionary leader in democratizing AI.\n<human>:",
1554
+ "<human>: who is h2o?\n<bot>: H2O.ai is the visionary leader in democratizing AI.\n<human>:",
1555
+ ]:
1556
+ rows.append(dict(input=row, prompt_type='plain', source='H2O.ai'))
1557
+ print(len(rows))
1558
+ with open("h2ogpt-personality.json", "w") as f:
1559
+ f.write(json.dumps(rows, indent=2))
1560
+ return rows
1561
+
1562
+
1563
+ def test_check_stats_data():
1564
+ filename = 'h2ogpt-oig-oasst1-instruct-cleaned-v3.json'
1565
+ df = pd.read_json(filename)
1566
+
1567
+ # get word stats
1568
+ df['char_count'] = df['input'].apply(lambda x: len(x))
1569
+ import matplotlib.pyplot as plt
1570
+ plt.figure(figsize=(10, 10))
1571
+ plt.hist(df['char_count'], bins=100)
1572
+ chars_avg = np.mean(df['char_count'])
1573
+ chars_median = np.median(df['char_count'])
1574
+ plt.title("char_count avg: %s median: %s" % (chars_avg, chars_median))
1575
+ plt.savefig('chars_hist.png')
1576
+ plt.close()
1577
+
1578
+ # get tokenize stats for random sample of 1000 rows
1579
+ from finetune import generate_and_tokenize_prompt
1580
+ from loaders import get_loaders, get_tokenizer
1581
+ from functools import partial
1582
+
1583
+ llama_type = False
1584
+ tokenizer_base_model = base_model = 'h2oai/h2ogpt-oasst1-512-20b'
1585
+ model_loader, tokenizer_loader = get_loaders(llama_type=llama_type, model_name=base_model, reward_type=False)
1586
+ local_files_only = False
1587
+ resume_download = True
1588
+ use_auth_token = False
1589
+ tokenizer = get_tokenizer(tokenizer_loader, tokenizer_base_model, local_files_only, resume_download, use_auth_token)
1590
+ prompt_type = 'plain' # trained with data already in human bot form
1591
+ train_on_inputs = True
1592
+ add_eos_token = False
1593
+ cutoff_len = 512 # can choose 2048
1594
+ generate_and_tokenize_prompt_fun = partial(generate_and_tokenize_prompt, prompt_type=prompt_type,
1595
+ train_on_inputs=train_on_inputs, add_eos_token=add_eos_token,
1596
+ cutoff_len=cutoff_len, tokenizer=tokenizer)
1597
+ from datasets import load_dataset
1598
+ data = load_dataset("json", data_files={"train": filename})
1599
+ val_set_size = 0.90
1600
+ train_val = data["train"].train_test_split(
1601
+ test_size=val_set_size, shuffle=True, seed=42
1602
+ )
1603
+ train_data = train_val["train"]
1604
+ train_data = train_data.shuffle().map(generate_and_tokenize_prompt_fun, num_proc=os.cpu_count())
1605
+
1606
+ df_tokens = pd.DataFrame([len(x) for x in train_data['input_ids']], columns=['token_count'])
1607
+
1608
+ plt.figure(figsize=(10, 10))
1609
+ plt.hist(df_tokens['token_count'], bins=100)
1610
+ token_avg = np.mean(df_tokens['token_count'])
1611
+ token_median = np.median(df_tokens['token_count'])
1612
+ plt.title("token_count with cutoff=%s avg: %s median: %s" % (cutoff_len, token_avg, token_median))
1613
+ plt.savefig('token_hist_%s.png' % cutoff_len)
1614
+ plt.close()
1615
+
1616
+
1617
+ def get_unhelpful_list():
1618
+ # base versions
1619
+ unhelpful = ["I'm sorry, I didn't quite understand your question, could you please rephrase it?",
1620
+ "I'm sorry, but I don't understand your question. Could you please rephrase it?",
1621
+ "I'm sorry, I don't quite understand your question",
1622
+ "I'm sorry, I don't know",
1623
+ "I'm sorry, but I don't know",
1624
+ "I don't know anything",
1625
+ "I do not know",
1626
+ "I don't know",
1627
+ "I don't know how",
1628
+ "I do not know how",
1629
+ "Can you please explain what you mean",
1630
+ "please explain what you mean",
1631
+ "please explain",
1632
+ "I'm sorry, but I don't know how to tell a story. Can you please explain what you mean by",
1633
+ "I'm sorry but I don't understand what you mean",
1634
+ "I don't understand",
1635
+ "I don't have the ability",
1636
+ "I do not have the ability",
1637
+ "I do not have",
1638
+ "I am a language model,",
1639
+ "I am a large language model,",
1640
+ "I do not understand your question. Can you please try to make it clearer?",
1641
+ "I'm sorry, but as an AI language model",
1642
+ "I apologize, but I cannot rephrase text that I cannot understand. Your post is difficult to read and follow.",
1643
+ "I apologize, but I am not h2oGPT. I am a language model developed by H2O.ai. How may I help you?",
1644
+ "Sorry, but I am not an actual Linux shell, nor am I capable of emulating one. I am an open source chat assistant and would be glad t",
1645
+ "I apologize, but I cannot perform the task you have requested.",
1646
+ "I'm sorry, I cannot perform this task as I am an AI language model and do not have access",
1647
+ "I'm sorry, I'm not sure what you're asking for here.",
1648
+ "I'm not sure what you are asking",
1649
+ "You need to provide more context",
1650
+ ]
1651
+ # reduced versions, with redundant parts, just to give context for where they came from
1652
+ unhelpful += ["sorry, I didn't quite understand your question",
1653
+ "I didn't quite understand your question",
1654
+ "I didn't understand your question",
1655
+ "I did not understand your question",
1656
+ "I did not understand the question",
1657
+ "could you please rephrase"
1658
+ "could you rephrase"
1659
+ "I do not understand your question.",
1660
+ "I do not understand the question.",
1661
+ "I do not understand that question.",
1662
+ "Can you please try to make it clearer",
1663
+ "Can you try to make it clearer",
1664
+ "sorry, but as an AI language model",
1665
+ "as an AI language model",
1666
+ "I apologize, but I cannot",
1667
+ "I cannot rephrase text",
1668
+ "I cannot understand. Your post is difficult to read and follow."
1669
+ "Your post is difficult to read and follow."
1670
+ "I apologize, but I am",
1671
+ "Sorry, but I am not ",
1672
+ "nor am I capable",
1673
+ "I am not capable of",
1674
+ "I apologize, but I cannot perform the task you have requested",
1675
+ "I cannot perform the task",
1676
+ "I cannot complete the task",
1677
+ "I'm sorry",
1678
+ "I am sorry",
1679
+ "do not have access",
1680
+ "not sure what you're asking for",
1681
+ "not sure what you are asking for",
1682
+ "not sure what is being asked",
1683
+ "I'm not sure what you are asking",
1684
+ "not sure what you are asking",
1685
+ "You need to provide more context",
1686
+ "provide more context",
1687
+ ]
1688
+ unhelpful += ["As a large language model",
1689
+ "cannot provide any information",
1690
+ "As an artificial intelligence I do not have the capability",
1691
+ "As an artificial intelligence I don't have the capability",
1692
+ "As an artificial intelligence I can't",
1693
+ "As an artificial intelligence I cannot",
1694
+ "I am sorry but I do not understand",
1695
+ "Can you please explain",
1696
+ "(sorry couldn't resist)",
1697
+ "(sorry could not resist)",
1698
+ " :)",
1699
+ " ;)",
1700
+ " :-)",
1701
+ " ;-)",
1702
+ " lol ",
1703
+ "Thanks so much!!!",
1704
+ "Thank You :)!!!",
1705
+ "Please try not to repeat",
1706
+ "I am an AI language model",
1707
+ "I'm a AI assistant that",
1708
+ "I'm an AI assistant that",
1709
+ "I am an AI assistant that",
1710
+ "etc.",
1711
+ "etc.etc.",
1712
+ "etc. etc.",
1713
+ "etc etc",
1714
+ ]
1715
+ return unhelpful
1716
+
1717
+
1718
+ def test_check_unhelpful():
1719
+ # file = '/home/jon/Downloads/openassistant_oasst1_h2ogpt_graded.json'
1720
+ file = '/home/jon/Downloads/openassistant_oasst1_h2ogpt_grades.json'
1721
+ # file = 'h2ogpt-oig-oasst1-instruct-cleaned-v2.json'
1722
+
1723
+ unhelpful = get_unhelpful_list()
1724
+ #data = json.load(open(file, 'rt'))
1725
+ df = pd.read_json(file)
1726
+
1727
+ use_reward_score_threshold = False
1728
+ use_bleu_threshold = False
1729
+ use_sentence_sim = True
1730
+
1731
+ from sacrebleu.metrics import BLEU
1732
+ bleu = BLEU()
1733
+ from nltk.translate.bleu_score import sentence_bleu
1734
+
1735
+ def get_bleu(actual, expected_list):
1736
+ #return bleu.sentence_score(actual, expected_list).score
1737
+ return sentence_bleu(expected_list, actual)
1738
+
1739
+ threshold = 0.0
1740
+ if use_reward_score_threshold:
1741
+ df = df[df['grade_deberta'] > threshold]
1742
+
1743
+ # back to as if original json load
1744
+ data = df.to_dict(orient='records')
1745
+ bads = {}
1746
+ string_all = str(data)
1747
+ for sub in unhelpful:
1748
+ bads[sub] = string_all.count(sub)
1749
+ bads = {k: v for k, v in bads.items() if v > 0}
1750
+ import pprint
1751
+ pp = pprint.PrettyPrinter(indent=4)
1752
+ pp.pprint(bads)
1753
+
1754
+ total_bads = sum(list(bads.values()))
1755
+ print('total_bads: %s' % total_bads, flush=True)
1756
+
1757
+ # check just bot
1758
+ import re
1759
+ convs = [[x.strip() for x in re.split(r'%s|%s' % (human, bot), y['input']) if x.strip()] for y in data]
1760
+ humans = [[x for i, x in enumerate(y) if i % 2 == 0] for y in convs]
1761
+ bots = [[x for i, x in enumerate(y) if i % 2 == 1] for y in convs]
1762
+
1763
+ # FIXME: apply back to json etc., just see for now
1764
+ bleu_threshold = 0.9
1765
+ if use_bleu_threshold:
1766
+ bots = [[x for x in y if get_bleu(x, unhelpful) < bleu_threshold] for y in tqdm(bots)]
1767
+
1768
+ cosine_sim_threshold = 0.8
1769
+ if use_sentence_sim:
1770
+ # pip install sentence_transformers-2.2.2
1771
+ from sentence_transformers import SentenceTransformer
1772
+ # sent_model = 'bert-base-nli-mean-tokens'
1773
+ #sent_model = 'nli-distilroberta-base-v2'
1774
+ sent_model = 'all-MiniLM-L6-v2'
1775
+ model = SentenceTransformer(sent_model)
1776
+ sentence_embeddings = model.encode(unhelpful)
1777
+ from sklearn.metrics.pairwise import cosine_similarity
1778
+ bots = [x for x in tqdm(bots) if np.max(cosine_similarity(model.encode(x), sentence_embeddings)) < cosine_sim_threshold]
1779
+
1780
+ bads_bots = {}
1781
+ string_all = str(bots)
1782
+ for sub in unhelpful:
1783
+ bads_bots[sub] = string_all.count(sub)
1784
+ bads_bots = {k: v for k, v in bads_bots.items() if v > 0}
1785
+ import pprint
1786
+ pp = pprint.PrettyPrinter(indent=4)
1787
+ pp.pprint(bads_bots)
1788
+
1789
+ total_bads_bots = sum(list(bads_bots.values()))
1790
+ print('threshold: %g use_bleu_threshold: %g total_bads_bots: %s total_bots: %s total_humans: %s' % (threshold, use_bleu_threshold, total_bads_bots, len(bots), len(humans)), flush=True)
1791
+
1792
+ # assert len(bads) == 0, bads
1793
+ assert len(bads_bots) == 0, bads_bots
1794
+
1795
+
1796
+ def test_fortune2000_personalized():
1797
+ row_list = []
1798
+ import glob
1799
+ if not os.path.isdir("wikitext"):
1800
+ raise RuntimeError("download https://github.com/h2oai/h2ogpt/files/11423008/wikitext.zip and unzip")
1801
+ for file in glob.glob("wikitext/*.txt"):
1802
+ with open(file, "r") as f:
1803
+ blob = f.read()
1804
+ N = 512 * 4
1805
+ row_list.extend([{'input': s, 'prompt_type': 'plain', 'source': "%s" % os.path.basename(file)}
1806
+ for s in get_sentences(blob, N) if s])
1807
+ personality = create_personality_data()
1808
+ import copy
1809
+ for i in range(10):
1810
+ row_list.extend(copy.deepcopy(personality))
1811
+ np.random.seed(123)
1812
+ np.random.shuffle(row_list)
1813
+ for i in range(len(row_list)):
1814
+ row_list[i]['id'] = i
1815
+ for i in range(len(row_list)):
1816
+ assert row_list[i]['id'] == i
1817
+ with open("h2ogpt-fortune2000-personalized.json", "w") as ff:
1818
+ ff.write(json.dumps(row_list, indent=2))
finetune.py CHANGED
@@ -1,11 +1,12 @@
1
  import os
2
  import sys
3
- import time
4
  from functools import partial
5
  from typing import List, Union
6
- from enum import Enum
7
  import fire
8
  import numpy as np
 
 
 
9
  from utils import get_githash, copy_code
10
  import torch
11
 
@@ -17,82 +18,6 @@ def log(*args, **kwargs):
17
  print(*args, **kwargs)
18
 
19
 
20
- class PromptType(Enum):
21
- plain = 0
22
- instruct = 1
23
- quality = 2
24
- human_bot = 3
25
- dai_faq = 4
26
- summarize = 5
27
- simple_instruct = 6
28
- instruct_vicuna = 7
29
- instruct_with_end = 8
30
- human_bot_orig = 9
31
- prompt_answer = 10
32
- open_assistant = 11
33
- wizard_lm = 12
34
-
35
-
36
- prompt_type_to_model_name = {
37
- 'plain': [
38
- 'EleutherAI/gpt-j-6B',
39
- 'EleutherAI/pythia-6.9b',
40
- 'EleutherAI/pythia-12b',
41
- 'EleutherAI/pythia-12b-deduped',
42
- 'EleutherAI/gpt-neox-20b',
43
- 'decapoda-research/llama-7b-hf',
44
- 'decapoda-research/llama-13b-hf',
45
- 'decapoda-research/llama-30b-hf',
46
- 'decapoda-research/llama-65b-hf',
47
- 'facebook/mbart-large-50-many-to-many-mmt',
48
- 'philschmid/bart-large-cnn-samsum',
49
- 'philschmid/flan-t5-base-samsum',
50
- 'gpt2',
51
- 'distilgpt2',
52
- 'mosaicml/mpt-7b-storywriter',
53
- 'mosaicml/mpt-7b-instruct', # internal code handles instruct
54
- 'mosaicml/mpt-7b-chat', # NC, internal code handles instruct
55
- ],
56
- 'prompt_answer': [
57
- 'h2oai/h2ogpt-gm-oasst1-en-1024-20b',
58
- 'h2oai/h2ogpt-gm-oasst1-en-1024-12b',
59
- 'h2oai/h2ogpt-gm-oasst1-multilang-1024-20b',
60
- 'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt',
61
- 'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2',
62
- ],
63
- 'instruct': [],
64
- 'instruct_with_end': ['databricks/dolly-v2-12b'],
65
- 'quality': [],
66
- 'human_bot': [
67
- 'h2oai/h2ogpt-oasst1-512-12b',
68
- 'h2oai/h2ogpt-oasst1-512-20b',
69
- 'h2oai/h2ogpt-oig-oasst1-512-20b',
70
- 'h2oai/h2ogpt-oig-oasst1-512-12b',
71
- 'h2oai/h2ogpt-oig-oasst1-512-6.9b',
72
- 'h2oai/h2ogpt-research-oasst1-512-30b', # private
73
- ],
74
- 'dai_faq': [],
75
- 'summarize': [],
76
- 'simple_instruct': ['t5-small', 't5-large', 'google/flan-t5', 'google/flan-t5-xxl', 'google/flan-ul2'],
77
- 'instruct_vicuna': ['AlekseyKorshuk/vicuna-7b', 'TheBloke/stable-vicuna-13B-HF', 'junelee/wizard-vicuna-13b'],
78
- 'human_bot_orig': ['togethercomputer/GPT-NeoXT-Chat-Base-20B'],
79
- "open_assistant": ['OpenAssistant/oasst-sft-7-llama-30b-xor', 'oasst-sft-7-llama-30b'],
80
- "wizard_lm": ['ehartford/WizardLM-7B-Uncensored', 'ehartford/WizardLM-13B-Uncensored'],
81
- }
82
-
83
- inv_prompt_type_to_model_name = {v.strip(): k for k, l in prompt_type_to_model_name.items() for v in l}
84
- inv_prompt_type_to_model_lower = {v.strip().lower(): k for k, l in prompt_type_to_model_name.items() for v in l}
85
-
86
- prompt_types_strings = []
87
- for p in PromptType:
88
- prompt_types_strings.extend([p.name])
89
-
90
-
91
- prompt_types = []
92
- for p in PromptType:
93
- prompt_types.extend([p.name, p.value, str(p.value)])
94
-
95
-
96
  # supported by huggingface evaluate
97
  supported_metrics = ['bleu', 'rouge', 'sacrebleu', 'meteor']
98
 
@@ -353,7 +278,7 @@ def train(
353
  if os.path.exists(checkpoint_name):
354
  log(f"Restarting from {checkpoint_name}")
355
  adapters_weights = torch.load(checkpoint_name)
356
- model = set_peft_model_state_dict(model, adapters_weights)
357
  else:
358
  log(f"Checkpoint {checkpoint_name} not found")
359
 
@@ -656,58 +581,6 @@ def train(
656
  log("\n If there's a warning about missing keys above, please disregard :)")
657
 
658
 
659
- def get_loaders(llama_type, model_name, reward_type):
660
- # NOTE: Some models need specific new prompt_type
661
- # E.g. t5_xxl_true_nli_mixture has input format: "premise: PREMISE_TEXT hypothesis: HYPOTHESIS_TEXT".)
662
- if llama_type:
663
- from transformers import LlamaForCausalLM, LlamaTokenizer
664
- model_loader = LlamaForCausalLM
665
- tokenizer_loader = LlamaTokenizer
666
- elif 'distilgpt2' in model_name.lower():
667
- from transformers import AutoModelForCausalLM, AutoTokenizer
668
- return AutoModelForCausalLM, AutoTokenizer
669
- elif 'gpt2' in model_name.lower():
670
- from transformers import GPT2LMHeadModel, GPT2Tokenizer
671
- return GPT2LMHeadModel, GPT2Tokenizer
672
- elif 'mbart-' in model_name.lower():
673
- from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
674
- return MBartForConditionalGeneration, MBart50TokenizerFast
675
- elif 't5' == model_name.lower() or \
676
- 't5-' in model_name.lower() or \
677
- 'flan-' in model_name.lower():
678
- from transformers import AutoTokenizer, T5ForConditionalGeneration
679
- return T5ForConditionalGeneration, AutoTokenizer
680
- elif 'bigbird' in model_name:
681
- from transformers import BigBirdPegasusForConditionalGeneration, AutoTokenizer
682
- return BigBirdPegasusForConditionalGeneration, AutoTokenizer
683
- elif 'bart-large-cnn-samsum' in model_name or 'flan-t5-base-samsum' in model_name:
684
- from transformers import pipeline
685
- return pipeline, "summarization"
686
- elif reward_type or 'OpenAssistant/reward-model'.lower() in model_name.lower():
687
- from transformers import AutoModelForSequenceClassification, AutoTokenizer
688
- return AutoModelForSequenceClassification, AutoTokenizer
689
- else:
690
- from transformers import AutoTokenizer, AutoModelForCausalLM
691
- model_loader = AutoModelForCausalLM
692
- tokenizer_loader = AutoTokenizer
693
- return model_loader, tokenizer_loader
694
-
695
-
696
- def get_tokenizer(tokenizer_loader, tokenizer_base_model, local_files_only, resume_download, use_auth_token):
697
- tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model,
698
- local_files_only=local_files_only,
699
- resume_download=resume_download,
700
- use_auth_token=use_auth_token)
701
-
702
- tokenizer.pad_token_id = 0 # different from the eos token
703
- # when generating, we will use the logits of right-most token to predict the next token
704
- # so the padding should be on the left,
705
- # e.g. see: https://huggingface.co/transformers/v4.11.3/model_doc/t5.html#inference
706
- tokenizer.padding_side = "left" # Allow batched inference
707
-
708
- return tokenizer
709
-
710
-
711
  def tokenize(prompt, tokenizer, cutoff_len, add_eos_token=False):
712
  # there's probably a way to do this with the tokenizer settings
713
  # but again, gotta move fast
@@ -765,253 +638,6 @@ def generate_and_tokenize_prompt(data_point, prompt_type=None, train_on_inputs=F
765
  return tokenized_full_prompt
766
 
767
 
768
- def get_prompt(prompt_type, chat, context, reduced):
769
- if prompt_type in [-1, "-1", "plain"]:
770
- promptA = promptB = PreInstruct = PreInput = PreResponse = ''
771
- terminate_response = []
772
- chat_sep = ''
773
- elif prompt_type == 'simple_instruct':
774
- promptA = promptB = PreInstruct = PreInput = PreResponse = None
775
- terminate_response = []
776
- chat_sep = '\n'
777
- elif prompt_type in [0, "0", "instruct"] or prompt_type in [7, "7", "instruct_with_end"]:
778
- promptA = 'Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n' if not (chat and reduced) else ''
779
- promptB = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n' if not (chat and reduced) else ''
780
-
781
- PreInstruct = """
782
- ### Instruction:
783
- """
784
-
785
- PreInput = """
786
- ### Input:
787
- """
788
-
789
- PreResponse = """
790
- ### Response:
791
- """
792
- if prompt_type in [7, "7", "instruct_with_end"]:
793
- terminate_response = ['### End']
794
- else:
795
- terminate_response = None
796
- chat_sep = '\n'
797
- elif prompt_type in [1, "1", "quality"]:
798
- promptA = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction as applied on the Input.\n' if not (chat and reduced) else ''
799
- promptB = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction.\n' if not (chat and reduced) else ''
800
-
801
- PreInstruct = """
802
- ### Instruction:
803
- """
804
-
805
- PreInput = """
806
- ### Input:
807
- """
808
-
809
- PreResponse = """
810
- ### Response:
811
- """
812
- terminate_response = None
813
- chat_sep = '\n'
814
- elif prompt_type in [2, "2", "human_bot", 9, "9", "human_bot_orig"]:
815
- human = '<human>:'
816
- bot = "<bot>:"
817
- if reduced or context or prompt_type in [2, "2", "human_bot"]:
818
- preprompt = ''
819
- else:
820
- cur_date = time.strftime('%Y-%m-%d')
821
- cur_time = time.strftime('%H:%M:%S %p %Z')
822
-
823
- PRE_PROMPT = """\
824
- Current Date: {}
825
- Current Time: {}
826
-
827
- """
828
- preprompt = PRE_PROMPT.format(cur_date, cur_time)
829
- start = human
830
- promptB = promptA = '%s%s ' % (preprompt, start)
831
-
832
- PreInstruct = ""
833
-
834
- PreInput = None
835
-
836
- if reduced:
837
- # when making context, want it to appear as-if LLM generated, which starts with space after :
838
- PreResponse = bot + ' '
839
- else:
840
- # normally LLM adds space after this, because was how trained.
841
- # if add space here, non-unique tokenization will often make LLM produce wrong output
842
- PreResponse = bot
843
-
844
- terminate_response = [start, PreResponse]
845
- chat_sep = '\n'
846
- elif prompt_type in [3, "3", "dai_faq"]:
847
- promptA = ''
848
- promptB = 'Answer the following Driverless AI question.\n'
849
-
850
- PreInstruct = """
851
- ### Driverless AI frequently asked question:
852
- """
853
-
854
- PreInput = None
855
-
856
- PreResponse = """
857
- ### Driverless AI documentation answer:
858
- """
859
- terminate_response = ['\n\n']
860
- chat_sep = terminate_response
861
- elif prompt_type in [5, "5", "summarize"]:
862
- promptA = promptB = PreInput = ''
863
- PreInstruct = '## Main Text\n\n'
864
- PreResponse = '\n\n## Summary\n\n'
865
- terminate_response = None
866
- chat_sep = '\n'
867
- elif prompt_type in [6, "6", "instruct_vicuna"]:
868
- promptA = promptB = "A chat between a curious human and an artificial intelligence assistant. " \
869
- "The assistant gives helpful, detailed, and polite answers to the human's questions." if not (chat and reduced) else ''
870
-
871
- PreInstruct = """
872
- ### Human:
873
- """
874
-
875
- PreInput = None
876
-
877
- PreResponse = """
878
- ### Assistant:
879
- """
880
- terminate_response = ['### Human:'] # but only allow terminate after prompt is found correctly, else can't terminate
881
- chat_sep = '\n'
882
- elif prompt_type in [10, "10", "prompt_answer"]:
883
- preprompt = ''
884
- prompt_tokens = "<|prompt|>"
885
- answer_tokens = "<|answer|>"
886
- start = prompt_tokens
887
- promptB = promptA = '%s%s' % (preprompt, start)
888
- PreInstruct = ""
889
- PreInput = None
890
- PreResponse = answer_tokens
891
- eos = '<|endoftext|>' # neox eos
892
- terminate_response = [start, PreResponse, eos]
893
- chat_sep = eos
894
- elif prompt_type in [11, "11", "open_assistant"]:
895
- # From added_tokens.json
896
- preprompt = ''
897
- prompt_tokens = "<|prompter|>"
898
- answer_tokens = "<|assistant|>"
899
- start = prompt_tokens
900
- promptB = promptA = '%s%s' % (preprompt, start)
901
- PreInstruct = ""
902
- PreInput = None
903
- PreResponse = answer_tokens
904
- pend = "<|prefix_end|>"
905
- eos = "</s>"
906
- terminate_response = [start, PreResponse, pend, eos]
907
- chat_sep = eos
908
- elif prompt_type in [12, "12", "wizard_lm"]:
909
- # https://github.com/ehartford/WizardLM/blob/main/src/train_freeform.py
910
- preprompt = ''
911
- start = ''
912
- promptB = promptA = '%s%s' % (preprompt, start)
913
- PreInstruct = ""
914
- PreInput = None
915
- PreResponse = "\n\n### Response"
916
- eos = "</s>"
917
- terminate_response = [PreResponse, eos]
918
- chat_sep = eos
919
- else:
920
- raise RuntimeError("No such prompt_type=%s" % prompt_type)
921
-
922
- return promptA, promptB, PreInstruct, PreInput, PreResponse, terminate_response, chat_sep
923
-
924
-
925
- def generate_prompt(data_point, prompt_type, chat, reduced):
926
- context = data_point.get('context')
927
- if context is None:
928
- context = ''
929
- instruction = data_point.get('instruction')
930
- input = data_point.get('input')
931
- output = data_point.get('output')
932
- prompt_type = data_point.get('prompt_type', prompt_type)
933
- assert prompt_type in prompt_types, "Bad prompt type: %s" % prompt_type
934
- promptA, promptB, PreInstruct, PreInput, PreResponse, \
935
- terminate_response, chat_sep = get_prompt(prompt_type, chat, context, reduced)
936
-
937
- prompt = context if not reduced else ''
938
-
939
- if input and promptA:
940
- prompt += f"""{promptA}"""
941
- elif promptB:
942
- prompt += f"""{promptB}"""
943
-
944
- if instruction and PreInstruct is not None and input and PreInput is not None:
945
- prompt += f"""{PreInstruct}{instruction}{PreInput}{input}"""
946
- prompt = inject_newline(prompt_type, prompt)
947
- elif instruction and input and PreInstruct is None and PreInput is not None:
948
- prompt += f"""{PreInput}{instruction}
949
- {input}"""
950
- prompt = inject_newline(prompt_type, prompt)
951
- elif input and instruction and PreInput is None and PreInstruct is not None:
952
- prompt += f"""{PreInstruct}{instruction}
953
- {input}"""
954
- prompt = inject_newline(prompt_type, prompt)
955
- elif instruction and PreInstruct is not None:
956
- prompt += f"""{PreInstruct}{instruction}"""
957
- prompt = inject_newline(prompt_type, prompt)
958
- elif input and PreInput is not None:
959
- prompt += f"""{PreInput}{input}"""
960
- prompt = inject_newline(prompt_type, prompt)
961
- elif input and instruction and PreInput is not None:
962
- prompt += f"""{PreInput}{instruction}{input}"""
963
- prompt = inject_newline(prompt_type, prompt)
964
- elif input and instruction and PreInstruct is not None:
965
- prompt += f"""{PreInstruct}{instruction}{input}"""
966
- prompt = inject_newline(prompt_type, prompt)
967
- elif input and instruction:
968
- # i.e. for simple_instruct
969
- prompt += f"""{instruction}: {input}"""
970
- prompt = inject_newline(prompt_type, prompt)
971
- elif input:
972
- prompt += f"""{input}"""
973
- prompt = inject_newline(prompt_type, prompt)
974
- elif instruction:
975
- prompt += f"""{instruction}"""
976
- prompt = inject_newline(prompt_type, prompt)
977
-
978
- if PreResponse is not None:
979
- prompt += f"""{PreResponse}"""
980
- pre_response = PreResponse # Don't use strip
981
- else:
982
- pre_response = ''
983
-
984
- if output:
985
- prompt += f"""{output}"""
986
-
987
- return prompt, pre_response, terminate_response, chat_sep
988
-
989
-
990
- def inject_newline(prompt_type, prompt):
991
- if prompt_type not in [-1, '-1', 'plain', 'simple_instruct']:
992
- # only add new line if structured prompt, while 'plain' is just generation of next tokens from input
993
- prompt += '\n'
994
- return prompt
995
-
996
-
997
- example_data_point0 = dict(instruction="Summarize",
998
- input="Ducks eat seeds by the lake, then swim in the lake where fish eat small animals.",
999
- output="Ducks eat and swim at the lake.")
1000
-
1001
- example_data_point1 = dict(instruction="Who is smarter, Einstein or Newton?",
1002
- output="Einstein.")
1003
-
1004
- example_data_point2 = dict(input="Who is smarter, Einstein or Newton?",
1005
- output="Einstein.")
1006
-
1007
- example_data_points = [example_data_point0, example_data_point1, example_data_point2]
1008
-
1009
-
1010
- def test_train_prompt(prompt_type='instruct', data_point=0):
1011
- example_data_point = example_data_points[data_point]
1012
- return generate_prompt(example_data_point, prompt_type, False, False)
1013
-
1014
-
1015
  def test_debug():
1016
  fire.Fire(train)
1017
 
 
1
  import os
2
  import sys
 
3
  from functools import partial
4
  from typing import List, Union
 
5
  import fire
6
  import numpy as np
7
+
8
+ from loaders import get_loaders, get_tokenizer
9
+ from prompter import generate_prompt, prompt_types
10
  from utils import get_githash, copy_code
11
  import torch
12
 
 
18
  print(*args, **kwargs)
19
 
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  # supported by huggingface evaluate
22
  supported_metrics = ['bleu', 'rouge', 'sacrebleu', 'meteor']
23
 
 
278
  if os.path.exists(checkpoint_name):
279
  log(f"Restarting from {checkpoint_name}")
280
  adapters_weights = torch.load(checkpoint_name)
281
+ set_peft_model_state_dict(model, adapters_weights)
282
  else:
283
  log(f"Checkpoint {checkpoint_name} not found")
284
 
 
581
  log("\n If there's a warning about missing keys above, please disregard :)")
582
 
583
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
584
  def tokenize(prompt, tokenizer, cutoff_len, add_eos_token=False):
585
  # there's probably a way to do this with the tokenizer settings
586
  # but again, gotta move fast
 
638
  return tokenized_full_prompt
639
 
640
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
641
  def test_debug():
642
  fire.Fire(train)
643
 
generate.py CHANGED
@@ -1,5 +1,9 @@
 
1
  import functools
 
 
2
  import queue
 
3
  import sys
4
  import os
5
  import time
@@ -9,7 +13,12 @@ from datetime import datetime
9
  import filelock
10
  import psutil
11
 
12
- from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, wrapped_partial, EThread, get_githash
 
 
 
 
 
13
 
14
  SEED = 1236
15
  set_seed(SEED)
@@ -25,13 +34,16 @@ from peft import PeftModel
25
  from transformers import GenerationConfig, AutoModel, TextIteratorStreamer
26
  from accelerate import init_empty_weights, infer_auto_device_map
27
 
28
- from prompter import Prompter
29
-
30
- from finetune import get_loaders, example_data_points, generate_prompt, inv_prompt_type_to_model_lower
31
  from stopping import get_stopping
32
 
33
  eval_extra_columns = ['prompt', 'response', 'score']
34
 
 
 
 
 
 
35
 
36
  def main(
37
  load_8bit: bool = False,
@@ -63,6 +75,7 @@ def main(
63
  resume_download: bool = True,
64
  use_auth_token: Union[str, bool] = False,
65
  trust_remote_code: Union[str, bool] = True,
 
66
 
67
  src_lang: str = "English",
68
  tgt_lang: str = "Russian",
@@ -70,7 +83,6 @@ def main(
70
  gradio: bool = True,
71
  gradio_avoid_processing_markdown: bool = False,
72
  chat: bool = True,
73
- chat_history: int = 4096,
74
  chat_context: bool = False,
75
  stream_output: bool = True,
76
  show_examples: bool = None,
@@ -98,6 +110,30 @@ def main(
98
  eval_sharegpt_prompts_only: int = 0,
99
  eval_sharegpt_prompts_only_seed: int = 1234,
100
  eval_sharegpt_as_output: bool = False,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  ):
102
  """
103
 
@@ -127,12 +163,12 @@ def main(
127
  :param resume_download: whether to resume downloads from HF for models
128
  :param use_auth_token: whether to use HF auth token (requires CLI did huggingface-cli login before)
129
  :param trust_remote_code: whether to use trust any code needed for HF model
 
130
  :param src_lang: source languages to include if doing translation (None = all)
131
  :param tgt_lang: target languages to include if doing translation (None = all)
132
  :param gradio: whether to enable gradio, or to enable benchmark mode
133
  :param gradio_avoid_processing_markdown:
134
  :param chat: whether to enable chat mode with chat history
135
- :param chat_history: maximum character length of chat context/history
136
  :param chat_context: whether to use extra helpful context if human_bot
137
  :param stream_output: whether to stream output from generate
138
  :param show_examples: whether to show clickable examples in gradio
@@ -157,6 +193,41 @@ def main(
157
  :param eval_sharegpt_prompts_only: for no gradio benchmark, if using ShareGPT prompts for eval
158
  :param eval_sharegpt_prompts_only_seed: for no gradio benchmark, if seed for ShareGPT sampling
159
  :param eval_sharegpt_as_output: for no gradio benchmark, whether to test ShareGPT output itself
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
  :return:
161
  """
162
  is_hf = bool(os.getenv("HUGGINGFACE_SPACES"))
@@ -170,8 +241,20 @@ def main(
170
 
171
  # allow set token directly
172
  use_auth_token = os.environ.get("HUGGINGFACE_API_TOKEN", use_auth_token)
 
 
 
 
 
 
 
 
 
 
 
173
 
174
  if is_public:
 
175
  input_lines = 1 # ensure set, for ease of use
176
  temperature = 0.2 if temperature is None else temperature
177
  top_p = 0.85 if top_p is None else top_p
@@ -211,7 +294,7 @@ def main(
211
  torch.backends.cudnn.benchmark = True
212
  torch.backends.cudnn.enabled = False
213
  torch.set_default_dtype(torch.float32)
214
- if psutil.virtual_memory().available < 94*1024**3:
215
  # 12B uses ~94GB
216
  # 6.9B uses ~47GB
217
  base_model = 'h2oai/h2ogpt-oig-oasst1-512-6.9b' if not base_model else base_model
@@ -223,16 +306,22 @@ def main(
223
  stream_output = False
224
  # else prompt removal can mess up output
225
  chat = False
 
 
 
 
 
 
226
 
227
  placeholder_instruction, placeholder_input, \
228
- stream_output, show_examples, \
229
- prompt_type, temperature, top_p, top_k, num_beams, \
230
- max_new_tokens, min_new_tokens, early_stopping, max_time, \
231
- repetition_penalty, num_return_sequences, \
232
- do_sample, \
233
- src_lang, tgt_lang, \
234
- examples, \
235
- task_info = \
236
  get_generate_params(model_lower, chat,
237
  stream_output, show_examples,
238
  prompt_type, temperature, top_p, top_k, num_beams,
@@ -246,6 +335,38 @@ def main(
246
  print(f"Generating model with params:\n{locals_print}", flush=True)
247
  print("Command: %s\nHash: %s" % (str(' '.join(sys.argv)), get_githash()), flush=True)
248
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
249
  if not gradio:
250
  if eval_sharegpt_prompts_only > 0:
251
  # override default examples with shareGPT ones for human-level eval purposes only
@@ -309,11 +430,9 @@ def main(
309
  if not eval_sharegpt_as_output:
310
  model, tokenizer, device = get_model(**locals())
311
  model_state = [model, tokenizer, device, base_model]
312
- fun = partial(evaluate, model_state, debug=debug, save_dir=save_dir, is_low_mem=is_low_mem,
313
- raise_generate_gpu_exceptions=raise_generate_gpu_exceptions,
314
- chat_context=chat_context,
315
- concurrency_count=concurrency_count,
316
- lora_weights=lora_weights)
317
  else:
318
  assert eval_sharegpt_prompts_only > 0
319
 
@@ -325,8 +444,6 @@ def main(
325
  t0 = time.time()
326
  score_dump = []
327
 
328
- import matplotlib.pyplot as plt
329
-
330
  for exi, ex in enumerate(examples):
331
  instruction = ex[eval_func_param_names.index('instruction_nochat')]
332
  iinput = ex[eval_func_param_names.index('iinput_nochat')]
@@ -363,7 +480,8 @@ def main(
363
  try:
364
  score = torch.sigmoid(smodel(**inputs).logits[0].float()).cpu().detach().numpy()[0]
365
  except torch.cuda.OutOfMemoryError as e:
366
- print("GPU OOM 1: question: %s answer: %s exception: %s" % (prompt, res, str(e)), flush=True)
 
367
  traceback.print_exc()
368
  score = 0.0
369
  clear_torch_cache()
@@ -419,22 +537,23 @@ def main(
419
  smodel, stokenizer, sdevice = get_score_model(**all_kwargs)
420
  score_model_state0 = [smodel, stokenizer, sdevice, score_model]
421
 
422
- go_gradio(**locals())
423
-
424
-
425
- def get_device():
426
- if torch.cuda.is_available():
427
- device = "cuda"
428
- else:
429
- device = "cpu"
430
 
431
- return device
432
 
433
 
434
  def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type,
435
  gpu_id=0,
436
  use_auth_token=False,
437
  trust_remote_code=True,
 
438
  triton_attn=False,
439
  long_sequence=True,
440
  ):
@@ -448,6 +567,7 @@ def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward
448
  :param gpu_id:
449
  :param use_auth_token:
450
  :param trust_remote_code:
 
451
  :param triton_attn:
452
  :param long_sequence:
453
  :return:
@@ -455,7 +575,8 @@ def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward
455
  with init_empty_weights():
456
  from transformers import AutoConfig
457
  config = AutoConfig.from_pretrained(base_model, use_auth_token=use_auth_token,
458
- trust_remote_code=trust_remote_code)
 
459
  if triton_attn and 'mpt-' in base_model.lower():
460
  config.attn_config['attn_impl'] = 'triton'
461
  if long_sequence:
@@ -485,7 +606,6 @@ def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward
485
  dtype=torch.float16 if load_half else torch.float32,
486
  )
487
  device_map.update(device_map_model)
488
- print('device_map: %s' % device_map, flush=True)
489
  else:
490
  device_map = "auto"
491
 
@@ -504,6 +624,7 @@ def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward
504
  else:
505
  device_map = {'': 'cpu'}
506
  model_kwargs['load_in_8bit'] = False
 
507
 
508
  load_in_8bit = model_kwargs.get('load_in_8bit', False)
509
  model_kwargs['device_map'] = device_map
@@ -537,6 +658,7 @@ def get_model(
537
  resume_download: bool = True,
538
  use_auth_token: Union[str, bool] = False,
539
  trust_remote_code: bool = True,
 
540
  compile: bool = True,
541
  **kwargs,
542
  ):
@@ -556,11 +678,17 @@ def get_model(
556
  :param resume_download: resume downloads from HF
557
  :param use_auth_token: assumes user did on CLI `huggingface-cli login` to access private repo
558
  :param trust_remote_code: trust code needed by model
 
559
  :param compile: whether to compile torch model
560
  :param kwargs:
561
  :return:
562
  """
563
  print("Get %s model" % base_model, flush=True)
 
 
 
 
 
564
  if lora_weights is not None and lora_weights.strip():
565
  print("Get %s lora weights" % lora_weights, flush=True)
566
  device = get_device()
@@ -575,7 +703,8 @@ def get_model(
575
 
576
  from transformers import AutoConfig
577
  config = AutoConfig.from_pretrained(base_model, use_auth_token=use_auth_token,
578
- trust_remote_code=trust_remote_code)
 
579
  llama_type_from_config = 'llama' in str(config).lower()
580
  llama_type_from_name = "llama" in base_model.lower()
581
  llama_type = llama_type_from_config or llama_type_from_name
@@ -593,6 +722,7 @@ def get_model(
593
  resume_download=resume_download,
594
  use_auth_token=use_auth_token,
595
  trust_remote_code=trust_remote_code,
 
596
  )
597
  else:
598
  tokenizer = tokenizer_loader
@@ -610,6 +740,7 @@ def get_model(
610
  resume_download=resume_download,
611
  use_auth_token=use_auth_token,
612
  trust_remote_code=trust_remote_code,
 
613
  )
614
  if 'mbart-' not in base_model.lower() and 'mpt-' not in base_model.lower():
615
  model_kwargs.update(dict(load_in_8bit=load_8bit,
@@ -630,6 +761,7 @@ def get_model(
630
  gpu_id=gpu_id,
631
  use_auth_token=use_auth_token,
632
  trust_remote_code=trust_remote_code,
 
633
  )
634
  else:
635
  if load_half and not load_8bit:
@@ -653,6 +785,7 @@ def get_model(
653
  resume_download=resume_download,
654
  use_auth_token=use_auth_token,
655
  trust_remote_code=trust_remote_code,
 
656
  device_map={"": 0} if device == 'cuda' else {"": 'cpu'}, # seems to be required
657
  )
658
  else:
@@ -669,6 +802,7 @@ def get_model(
669
  resume_download=resume_download,
670
  use_auth_token=use_auth_token,
671
  trust_remote_code=trust_remote_code,
 
672
  device_map="auto",
673
  )
674
  if load_half:
@@ -729,11 +863,13 @@ eval_func_param_names = ['instruction',
729
  'chat',
730
  'instruction_nochat',
731
  'iinput_nochat',
 
732
  ]
733
 
734
 
735
  def evaluate(
736
  model_state,
 
737
  # START NOTE: Examples must have same order of parameters
738
  instruction,
739
  iinput,
@@ -754,6 +890,7 @@ def evaluate(
754
  chat,
755
  instruction_nochat,
756
  iinput_nochat,
 
757
  # END NOTE: Examples must have same order of parameters
758
  src_lang=None,
759
  tgt_lang=None,
@@ -766,12 +903,34 @@ def evaluate(
766
  raise_generate_gpu_exceptions=None,
767
  chat_context=None,
768
  lora_weights=None,
 
 
 
 
 
 
 
 
 
 
 
 
 
769
  ):
770
  # ensure passed these
771
  assert concurrency_count is not None
772
  assert is_low_mem is not None
773
  assert raise_generate_gpu_exceptions is not None
774
  assert chat_context is not None
 
 
 
 
 
 
 
 
 
775
 
776
  if debug:
777
  locals_dict = locals().copy()
@@ -817,10 +976,58 @@ def evaluate(
817
  # get hidden context if have one
818
  context = get_context(chat_context, prompt_type)
819
 
820
- data_point = dict(context=context, instruction=instruction, input=iinput)
821
  prompter = Prompter(prompt_type, debug=debug, chat=chat, stream_output=stream_output)
 
822
  prompt = prompter.generate_prompt(data_point)
823
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
824
  if isinstance(tokenizer, str):
825
  # pipeline
826
  if tokenizer == "summarization":
@@ -838,18 +1045,14 @@ def evaluate(
838
  # override, ignore user change
839
  num_return_sequences = 1
840
  stopping_criteria = get_stopping(prompt_type, tokenizer, device)
841
- # help to avoid errors like:
842
- # RuntimeError: The size of tensor a (2048) must match the size of tensor b (2049) at non-singleton dimension 3
843
- # RuntimeError: expected scalar type Half but found Float
844
- # with - 256
845
- max_length_tokenize = 768 - 256 if is_low_mem else 2048 - 256
846
- cutoff_len = max_length_tokenize * 4 # if reaches limit, then can't generate new tokens
847
- output_smallest = 30 * 4
848
- prompt = prompt[-cutoff_len - output_smallest:]
849
  inputs = tokenizer(prompt,
850
  return_tensors="pt",
851
  truncation=True,
852
  max_length=max_length_tokenize)
 
 
853
  if debug and len(inputs["input_ids"]) > 0:
854
  print('input_ids length', len(inputs["input_ids"][0]), flush=True)
855
  input_ids = inputs["input_ids"].to(device)
@@ -891,7 +1094,7 @@ def evaluate(
891
  **decoder_kwargs
892
  )
893
  decoder_raw_kwargs = dict(skip_special_tokens=False,
894
- clean_up_tokenization_spaces=True)
895
 
896
  decoder_raw = functools.partial(tokenizer.decode,
897
  **decoder_raw_kwargs
@@ -904,7 +1107,7 @@ def evaluate(
904
  # else hit bitsandbytes lack of thread safety:
905
  # https://github.com/h2oai/h2ogpt/issues/104
906
  # but only makes sense if concurrency_count == 1
907
- context_class = NullContext #if concurrency_count > 1 else filelock.FileLock
908
  print('Pre-Generate: %s' % str(datetime.now()), flush=True)
909
  decoded_output = None
910
  with context_class("generate.lock"):
@@ -923,7 +1126,9 @@ def evaluate(
923
  inputs_decoded = prompt = inputs_decoded_raw
924
  decoder = decoder_raw
925
  decoder_kwargs = decoder_raw_kwargs
926
- elif inputs_decoded_raw.replace("<unk> ", "").replace("<unk>", "").replace('\n', ' ').replace(' ', '') == prompt.replace('\n', ' ').replace(' ', ''):
 
 
927
  inputs_decoded = prompt = inputs_decoded_raw
928
  decoder = decoder_raw
929
  decoder_kwargs = decoder_raw_kwargs
@@ -931,13 +1136,15 @@ def evaluate(
931
  print("WARNING: Special characters in prompt", flush=True)
932
  if stream_output:
933
  skip_prompt = False
934
- streamer = H2OTextIteratorStreamer(tokenizer, skip_prompt=skip_prompt, block=False, **decoder_kwargs)
 
935
  gen_kwargs.update(dict(streamer=streamer))
936
- target_func = generate_with_exceptions
937
- target = wrapped_partial(generate_with_exceptions, model.generate, prompt, inputs_decoded,
938
- raise_generate_gpu_exceptions, **gen_kwargs)
 
939
  bucket = queue.Queue()
940
- thread = EThread(target=target, kwargs=dict(streamer=streamer), bucket=bucket)
941
  thread.start()
942
  outputs = ""
943
  try:
@@ -969,7 +1176,30 @@ def evaluate(
969
  decoded_output = prompt + outputs[0]
970
  if save_dir and decoded_output:
971
  save_generate_output(output=decoded_output, base_model=base_model, save_dir=save_dir)
972
- print('Post-Generate: %s decoded_output: %s' % (str(datetime.now()), len(decoded_output) if decoded_output else -1), flush=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
973
 
974
 
975
  class H2OTextIteratorStreamer(TextIteratorStreamer):
@@ -977,6 +1207,7 @@ class H2OTextIteratorStreamer(TextIteratorStreamer):
977
  normally, timeout required for now to handle exceptions, else get()
978
  but with H2O version of TextIteratorStreamer, loop over block to handle
979
  """
 
980
  def __init__(self, tokenizer, skip_prompt: bool = False, timeout: typing.Optional[float] = None,
981
  block=True, **decode_kwargs):
982
  super().__init__(tokenizer, skip_prompt, **decode_kwargs)
@@ -1003,7 +1234,7 @@ class H2OTextIteratorStreamer(TextIteratorStreamer):
1003
  print("hit stop", flush=True)
1004
  # could raise or break, maybe best to raise and make parent see if any exception in thread
1005
  raise StopIteration()
1006
- #break
1007
  value = self.text_queue.get(block=self.block, timeout=self.timeout)
1008
  break
1009
  except queue.Empty:
@@ -1014,15 +1245,16 @@ class H2OTextIteratorStreamer(TextIteratorStreamer):
1014
  return value
1015
 
1016
 
1017
- def generate_with_exceptions(func, prompt, inputs_decoded, raise_generate_gpu_exceptions, **kwargs):
1018
  try:
1019
- func(**kwargs)
1020
  except torch.cuda.OutOfMemoryError as e:
1021
  print("GPU OOM 2: prompt: %s inputs_decoded: %s exception: %s" % (prompt, inputs_decoded, str(e)),
1022
  flush=True)
1023
- if kwargs['input_ids'] is not None:
1024
- kwargs['input_ids'].cpu()
1025
- kwargs['input_ids'] = None
 
1026
  traceback.print_exc()
1027
  clear_torch_cache()
1028
  return
@@ -1214,7 +1446,7 @@ y = np.random.randint(0, 1, 100)
1214
 
1215
  # move to correct position
1216
  for example in examples:
1217
- example += [chat, '', '']
1218
  # adjust examples if non-chat mode
1219
  if not chat:
1220
  example[eval_func_param_names.index('instruction_nochat')] = example[
@@ -1223,16 +1455,18 @@ y = np.random.randint(0, 1, 100)
1223
 
1224
  example[eval_func_param_names.index('iinput_nochat')] = example[eval_func_param_names.index('iinput')]
1225
  example[eval_func_param_names.index('iinput')] = ''
 
 
1226
 
1227
  return placeholder_instruction, placeholder_input, \
1228
- stream_output, show_examples, \
1229
- prompt_type, temperature, top_p, top_k, num_beams, \
1230
- max_new_tokens, min_new_tokens, early_stopping, max_time, \
1231
- repetition_penalty, num_return_sequences, \
1232
- do_sample, \
1233
- src_lang, tgt_lang, \
1234
- examples, \
1235
- task_info
1236
 
1237
 
1238
  def languages_covered():
@@ -1252,12 +1486,6 @@ def get_context(chat_context, prompt_type):
1252
  return context0
1253
 
1254
 
1255
- def test_test_prompt(prompt_type='instruct', data_point=0):
1256
- example_data_point = example_data_points[data_point]
1257
- example_data_point.pop('output', None)
1258
- return generate_prompt(example_data_point, prompt_type, False, False)
1259
-
1260
-
1261
  def score_qa(smodel, stokenizer, max_length_tokenize, question, answer, cutoff_len):
1262
  question = question[-cutoff_len:]
1263
  answer = answer[-cutoff_len:]
@@ -1321,39 +1549,3 @@ if __name__ == "__main__":
1321
  python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6.9b
1322
  """
1323
  fire.Fire(main)
1324
-
1325
-
1326
- import pytest
1327
-
1328
- @pytest.mark.parametrize(
1329
- "base_model",
1330
- [
1331
- "h2oai/h2ogpt-oig-oasst1-512-6.9b",
1332
- "h2oai/h2ogpt-oig-oasst1-512-12b",
1333
- "h2oai/h2ogpt-oig-oasst1-512-20b",
1334
- "h2oai/h2ogpt-oasst1-512-12b",
1335
- "h2oai/h2ogpt-oasst1-512-20b",
1336
- "h2oai/h2ogpt-gm-oasst1-en-1024-20b",
1337
- "databricks/dolly-v2-12b",
1338
- "h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2",
1339
- "ehartford/WizardLM-7B-Uncensored",
1340
- "ehartford/WizardLM-13B-Uncensored",
1341
- "AlekseyKorshuk/vicuna-7b",
1342
- "TheBloke/stable-vicuna-13B-HF",
1343
- "decapoda-research/llama-7b-hf",
1344
- "decapoda-research/llama-13b-hf",
1345
- "decapoda-research/llama-30b-hf",
1346
- "junelee/wizard-vicuna-13b",
1347
- ]
1348
- )
1349
- def test_score_eval(base_model):
1350
- main(
1351
- base_model=base_model,
1352
- chat=False,
1353
- stream_output=False,
1354
- gradio=False,
1355
- eval_sharegpt_prompts_only=500,
1356
- eval_sharegpt_as_output=False,
1357
- num_beams=2,
1358
- infer_devices=False,
1359
- )
 
1
+ import ast
2
  import functools
3
+ import glob
4
+ import inspect
5
  import queue
6
+ import shutil
7
  import sys
8
  import os
9
  import time
 
13
  import filelock
14
  import psutil
15
 
16
+ from loaders import get_loaders
17
+ from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, wrapped_partial, EThread, get_githash, \
18
+ import_matplotlib, get_device, makedirs
19
+
20
+ import_matplotlib()
21
+ from matplotlib import pyplot as plt
22
 
23
  SEED = 1236
24
  set_seed(SEED)
 
34
  from transformers import GenerationConfig, AutoModel, TextIteratorStreamer
35
  from accelerate import init_empty_weights, infer_auto_device_map
36
 
37
+ from prompter import Prompter, inv_prompt_type_to_model_lower
 
 
38
  from stopping import get_stopping
39
 
40
  eval_extra_columns = ['prompt', 'response', 'score']
41
 
42
+ langchain_modes = ['Disabled', 'ChatLLM', 'LLM', 'All', 'wiki', 'wiki_full', 'UserData', 'MyData', 'github h2oGPT',
43
+ 'DriverlessAI docs']
44
+
45
+ scratch_base_dir = '/tmp/'
46
+
47
 
48
  def main(
49
  load_8bit: bool = False,
 
75
  resume_download: bool = True,
76
  use_auth_token: Union[str, bool] = False,
77
  trust_remote_code: Union[str, bool] = True,
78
+ offload_folder: str = "offline_folder",
79
 
80
  src_lang: str = "English",
81
  tgt_lang: str = "Russian",
 
83
  gradio: bool = True,
84
  gradio_avoid_processing_markdown: bool = False,
85
  chat: bool = True,
 
86
  chat_context: bool = False,
87
  stream_output: bool = True,
88
  show_examples: bool = None,
 
110
  eval_sharegpt_prompts_only: int = 0,
111
  eval_sharegpt_prompts_only_seed: int = 1234,
112
  eval_sharegpt_as_output: bool = False,
113
+
114
+ langchain_mode: str = 'Disabled',
115
+ visible_langchain_modes: list = ['UserData', 'MyData'],
116
+ user_path: str = None,
117
+ load_db_if_exists: bool = True,
118
+ keep_sources_in_context: bool = False,
119
+ db_type: str = 'chroma',
120
+ use_openai_embedding: bool = False,
121
+ use_openai_model: bool = False,
122
+ hf_embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2",
123
+ allow_upload_to_user_data: bool = True,
124
+ allow_upload_to_my_data: bool = True,
125
+ enable_url_upload: bool = True,
126
+ enable_text_upload: bool = True,
127
+ enable_sources_list: bool = True,
128
+ chunk: bool = True,
129
+ chunk_size: int = 512,
130
+ k: int = 4,
131
+ n_jobs: int = -1,
132
+ enable_captions: bool = True,
133
+ captions_model: str = "Salesforce/blip-image-captioning-base",
134
+ pre_load_caption_model: bool = False,
135
+ caption_gpu: bool = True,
136
+ enable_ocr: bool = False,
137
  ):
138
  """
139
 
 
163
  :param resume_download: whether to resume downloads from HF for models
164
  :param use_auth_token: whether to use HF auth token (requires CLI did huggingface-cli login before)
165
  :param trust_remote_code: whether to use trust any code needed for HF model
166
+ :param offload_folder: path for spilling model onto disk
167
  :param src_lang: source languages to include if doing translation (None = all)
168
  :param tgt_lang: target languages to include if doing translation (None = all)
169
  :param gradio: whether to enable gradio, or to enable benchmark mode
170
  :param gradio_avoid_processing_markdown:
171
  :param chat: whether to enable chat mode with chat history
 
172
  :param chat_context: whether to use extra helpful context if human_bot
173
  :param stream_output: whether to stream output from generate
174
  :param show_examples: whether to show clickable examples in gradio
 
193
  :param eval_sharegpt_prompts_only: for no gradio benchmark, if using ShareGPT prompts for eval
194
  :param eval_sharegpt_prompts_only_seed: for no gradio benchmark, if seed for ShareGPT sampling
195
  :param eval_sharegpt_as_output: for no gradio benchmark, whether to test ShareGPT output itself
196
+ :param langchain_mode: Data source to include. Choose "UserData" to only consume files from make_db.py.
197
+ WARNING: wiki_full requires extra data processing via read_wiki_full.py and requires really good workstation to generate db, unless already present.
198
+ :param user_path: user path to glob from to generate db for vector search, for 'UserData' langchain mode
199
+ :param visible_langchain_modes: dbs to generate at launch to be ready for LLM
200
+ Can be up to ['wiki', 'wiki_full', 'UserData', 'MyData', 'github h2oGPT', 'DriverlessAI docs']
201
+ But wiki_full is expensive and requires preparation
202
+ To allow scratch space only live in session, add 'MyData' to list
203
+ Default: If only want to consume local files, e.g. prepared by make_db.py, only include ['UserData']
204
+ FIXME: Avoid 'All' for now, not implemented
205
+ :param load_db_if_exists: Whether to load chroma db if exists or re-generate db
206
+ :param keep_sources_in_context: Whether to keep url sources in context, not helpful usually
207
+ :param db_type: 'faiss' for in-memory or 'chroma' for persisted on disk
208
+ :param use_openai_embedding: Whether to use OpenAI embeddings for vector db
209
+ :param use_openai_model: Whether to use OpenAI model for use with vector db
210
+ :param hf_embedding_model: Which HF embedding model to use for vector db
211
+ :param allow_upload_to_user_data: Whether to allow file uploads to update shared vector db
212
+ :param allow_upload_to_my_data: Whether to allow file uploads to update scratch vector db
213
+ :param enable_url_upload: Whether to allow upload from URL
214
+ :param enable_text_upload: Whether to allow uplaod of text
215
+ :param enable_sources_list: Whether to allow list (or download for non-shared db) of list of sources for chosen db
216
+ :param chunk: Whether to chunk data (True unless know data is already optimally chunked)
217
+ :param chunk_size: Size of chunks, with typically top-4 passed to LLM, so neesd to be in context length
218
+ :param k: number of chunks to give LLM
219
+ :param n_jobs: Number of processors to use when consuming documents (-1 = all, is default)
220
+ :param enable_captions: Whether to support captions using BLIP for image files as documents, then preloads that model
221
+ :param captions_model: Which model to use for captions.
222
+ captions_model: int = "Salesforce/blip-image-captioning-base", # continue capable
223
+ captions_model: str = "Salesforce/blip2-flan-t5-xl", # question/answer capable, 16GB state
224
+ captions_model: int = "Salesforce/blip2-flan-t5-xxl", # question/answer capable, 60GB state
225
+ Note: opt-based blip2 are not permissive license due to opt and Meta license restrictions
226
+ :param pre_load_caption_model: Whether to preload caption model, or load after forking parallel doc loader
227
+ parallel loading disabled if preload and have images, to prevent deadlocking on cuda context
228
+ Recommended if using larger caption model
229
+ :param caption_gpu: If support caption, then use GPU if exists
230
+ :param enable_ocr: Whether to support OCR on images
231
  :return:
232
  """
233
  is_hf = bool(os.getenv("HUGGINGFACE_SPACES"))
 
241
 
242
  # allow set token directly
243
  use_auth_token = os.environ.get("HUGGINGFACE_API_TOKEN", use_auth_token)
244
+ allow_upload_to_user_data = bool(os.environ.get("allow_upload_to_user_data", allow_upload_to_user_data))
245
+ allow_upload_to_my_data = bool(os.environ.get("allow_upload_to_my_data", allow_upload_to_my_data))
246
+ height = os.environ.get("HEIGHT", height)
247
+
248
+ # allow enabling langchain via ENV
249
+ # FIRST PLACE where LangChain referenced, but no imports related to it
250
+ langchain_mode = os.environ.get("LANGCHAIN_MODE", langchain_mode)
251
+ assert langchain_mode in langchain_modes, "Invalid langchain_mode %s" % langchain_mode
252
+ visible_langchain_modes = ast.literal_eval(os.environ.get("visible_langchain_modes", str(visible_langchain_modes)))
253
+ if langchain_mode not in visible_langchain_modes and langchain_mode in langchain_modes:
254
+ visible_langchain_modes += [langchain_mode]
255
 
256
  if is_public:
257
+ allow_upload_to_user_data = False
258
  input_lines = 1 # ensure set, for ease of use
259
  temperature = 0.2 if temperature is None else temperature
260
  top_p = 0.85 if top_p is None else top_p
 
294
  torch.backends.cudnn.benchmark = True
295
  torch.backends.cudnn.enabled = False
296
  torch.set_default_dtype(torch.float32)
297
+ if psutil.virtual_memory().available < 94 * 1024 ** 3:
298
  # 12B uses ~94GB
299
  # 6.9B uses ~47GB
300
  base_model = 'h2oai/h2ogpt-oig-oasst1-512-6.9b' if not base_model else base_model
 
306
  stream_output = False
307
  # else prompt removal can mess up output
308
  chat = False
309
+ # hard-coded defaults
310
+ first_para = False
311
+ text_limit = None
312
+
313
+ if offload_folder:
314
+ makedirs(offload_folder)
315
 
316
  placeholder_instruction, placeholder_input, \
317
+ stream_output, show_examples, \
318
+ prompt_type, temperature, top_p, top_k, num_beams, \
319
+ max_new_tokens, min_new_tokens, early_stopping, max_time, \
320
+ repetition_penalty, num_return_sequences, \
321
+ do_sample, \
322
+ src_lang, tgt_lang, \
323
+ examples, \
324
+ task_info = \
325
  get_generate_params(model_lower, chat,
326
  stream_output, show_examples,
327
  prompt_type, temperature, top_p, top_k, num_beams,
 
335
  print(f"Generating model with params:\n{locals_print}", flush=True)
336
  print("Command: %s\nHash: %s" % (str(' '.join(sys.argv)), get_githash()), flush=True)
337
 
338
+ if langchain_mode != "Disabled":
339
+ # SECOND PLACE where LangChain referenced, but all imports are kept local so not required
340
+ from gpt_langchain import prep_langchain, get_some_dbs_from_hf
341
+ if is_hf:
342
+ get_some_dbs_from_hf()
343
+ dbs = {}
344
+ for langchain_mode1 in visible_langchain_modes:
345
+ if langchain_mode1 in ['MyData']:
346
+ # don't use what is on disk, remove it instead
347
+ for gpath1 in glob.glob(os.path.join(scratch_base_dir, 'db_dir_%s*' % langchain_mode1)):
348
+ if os.path.isdir(gpath1):
349
+ print("Removing old MyData: %s" % gpath1, flush=True)
350
+ shutil.rmtree(gpath1)
351
+ continue
352
+ if langchain_mode1 in ['All']:
353
+ # FIXME: All should be avoided until scans over each db, shouldn't be separate db
354
+ continue
355
+ persist_directory1 = 'db_dir_%s' % langchain_mode1 # single place, no special names for each case
356
+ db = prep_langchain(persist_directory1, load_db_if_exists, db_type, use_openai_embedding,
357
+ langchain_mode1, user_path,
358
+ hf_embedding_model,
359
+ kwargs_make_db=locals())
360
+ dbs[langchain_mode1] = db
361
+ # remove None db's so can just rely upon k in dbs for if hav db
362
+ dbs = {k: v for k, v in dbs.items() if v is not None}
363
+ else:
364
+ dbs = {}
365
+ # import control
366
+ if os.environ.get("TEST_LANGCHAIN_IMPORT"):
367
+ assert 'gpt_langchain' not in sys.modules, "Dev bug, import of langchain when should not have"
368
+ assert 'langchain' not in sys.modules, "Dev bug, import of langchain when should not have"
369
+
370
  if not gradio:
371
  if eval_sharegpt_prompts_only > 0:
372
  # override default examples with shareGPT ones for human-level eval purposes only
 
430
  if not eval_sharegpt_as_output:
431
  model, tokenizer, device = get_model(**locals())
432
  model_state = [model, tokenizer, device, base_model]
433
+ kwargs_evaluate = {k: v for k, v in locals().items() if k in inputs_kwargs_list}
434
+ my_db_state = [None]
435
+ fun = partial(evaluate, model_state, my_db_state, **kwargs_evaluate)
 
 
436
  else:
437
  assert eval_sharegpt_prompts_only > 0
438
 
 
444
  t0 = time.time()
445
  score_dump = []
446
 
 
 
447
  for exi, ex in enumerate(examples):
448
  instruction = ex[eval_func_param_names.index('instruction_nochat')]
449
  iinput = ex[eval_func_param_names.index('iinput_nochat')]
 
480
  try:
481
  score = torch.sigmoid(smodel(**inputs).logits[0].float()).cpu().detach().numpy()[0]
482
  except torch.cuda.OutOfMemoryError as e:
483
+ print("GPU OOM 1: question: %s answer: %s exception: %s" % (prompt, res, str(e)),
484
+ flush=True)
485
  traceback.print_exc()
486
  score = 0.0
487
  clear_torch_cache()
 
537
  smodel, stokenizer, sdevice = get_score_model(**all_kwargs)
538
  score_model_state0 = [smodel, stokenizer, sdevice, score_model]
539
 
540
+ if enable_captions:
541
+ if pre_load_caption_model:
542
+ from image_captions import H2OImageCaptionLoader
543
+ caption_loader = H2OImageCaptionLoader(caption_gpu=caption_gpu).load_model()
544
+ else:
545
+ caption_loader = 'gpu' if caption_gpu else 'cpu'
546
+ else:
547
+ caption_loader = False
548
 
549
+ go_gradio(**locals())
550
 
551
 
552
  def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type,
553
  gpu_id=0,
554
  use_auth_token=False,
555
  trust_remote_code=True,
556
+ offload_folder=None,
557
  triton_attn=False,
558
  long_sequence=True,
559
  ):
 
567
  :param gpu_id:
568
  :param use_auth_token:
569
  :param trust_remote_code:
570
+ :param offload_folder:
571
  :param triton_attn:
572
  :param long_sequence:
573
  :return:
 
575
  with init_empty_weights():
576
  from transformers import AutoConfig
577
  config = AutoConfig.from_pretrained(base_model, use_auth_token=use_auth_token,
578
+ trust_remote_code=trust_remote_code,
579
+ offload_folder=offload_folder)
580
  if triton_attn and 'mpt-' in base_model.lower():
581
  config.attn_config['attn_impl'] = 'triton'
582
  if long_sequence:
 
606
  dtype=torch.float16 if load_half else torch.float32,
607
  )
608
  device_map.update(device_map_model)
 
609
  else:
610
  device_map = "auto"
611
 
 
624
  else:
625
  device_map = {'': 'cpu'}
626
  model_kwargs['load_in_8bit'] = False
627
+ print('device_map: %s' % device_map, flush=True)
628
 
629
  load_in_8bit = model_kwargs.get('load_in_8bit', False)
630
  model_kwargs['device_map'] = device_map
 
658
  resume_download: bool = True,
659
  use_auth_token: Union[str, bool] = False,
660
  trust_remote_code: bool = True,
661
+ offload_folder: str = None,
662
  compile: bool = True,
663
  **kwargs,
664
  ):
 
678
  :param resume_download: resume downloads from HF
679
  :param use_auth_token: assumes user did on CLI `huggingface-cli login` to access private repo
680
  :param trust_remote_code: trust code needed by model
681
+ :param offload_folder: offload folder
682
  :param compile: whether to compile torch model
683
  :param kwargs:
684
  :return:
685
  """
686
  print("Get %s model" % base_model, flush=True)
687
+ if base_model in ['llama', 'gptj']:
688
+ from gpt4all_llm import get_model_tokenizer_gpt4all
689
+ model, tokenizer, device = get_model_tokenizer_gpt4all(base_model)
690
+ return model, tokenizer, device
691
+
692
  if lora_weights is not None and lora_weights.strip():
693
  print("Get %s lora weights" % lora_weights, flush=True)
694
  device = get_device()
 
703
 
704
  from transformers import AutoConfig
705
  config = AutoConfig.from_pretrained(base_model, use_auth_token=use_auth_token,
706
+ trust_remote_code=trust_remote_code,
707
+ offload_folder=offload_folder)
708
  llama_type_from_config = 'llama' in str(config).lower()
709
  llama_type_from_name = "llama" in base_model.lower()
710
  llama_type = llama_type_from_config or llama_type_from_name
 
722
  resume_download=resume_download,
723
  use_auth_token=use_auth_token,
724
  trust_remote_code=trust_remote_code,
725
+ offload_folder=offload_folder,
726
  )
727
  else:
728
  tokenizer = tokenizer_loader
 
740
  resume_download=resume_download,
741
  use_auth_token=use_auth_token,
742
  trust_remote_code=trust_remote_code,
743
+ offload_folder=offload_folder,
744
  )
745
  if 'mbart-' not in base_model.lower() and 'mpt-' not in base_model.lower():
746
  model_kwargs.update(dict(load_in_8bit=load_8bit,
 
761
  gpu_id=gpu_id,
762
  use_auth_token=use_auth_token,
763
  trust_remote_code=trust_remote_code,
764
+ offload_folder=offload_folder,
765
  )
766
  else:
767
  if load_half and not load_8bit:
 
785
  resume_download=resume_download,
786
  use_auth_token=use_auth_token,
787
  trust_remote_code=trust_remote_code,
788
+ offload_folder=offload_folder,
789
  device_map={"": 0} if device == 'cuda' else {"": 'cpu'}, # seems to be required
790
  )
791
  else:
 
802
  resume_download=resume_download,
803
  use_auth_token=use_auth_token,
804
  trust_remote_code=trust_remote_code,
805
+ offload_folder=offload_folder,
806
  device_map="auto",
807
  )
808
  if load_half:
 
863
  'chat',
864
  'instruction_nochat',
865
  'iinput_nochat',
866
+ 'langchain_mode',
867
  ]
868
 
869
 
870
  def evaluate(
871
  model_state,
872
+ my_db_state,
873
  # START NOTE: Examples must have same order of parameters
874
  instruction,
875
  iinput,
 
890
  chat,
891
  instruction_nochat,
892
  iinput_nochat,
893
+ langchain_mode,
894
  # END NOTE: Examples must have same order of parameters
895
  src_lang=None,
896
  tgt_lang=None,
 
903
  raise_generate_gpu_exceptions=None,
904
  chat_context=None,
905
  lora_weights=None,
906
+ load_db_if_exists=True,
907
+ dbs=None,
908
+ user_path=None,
909
+ use_openai_embedding=None,
910
+ use_openai_model=None,
911
+ hf_embedding_model=None,
912
+ chunk=None,
913
+ chunk_size=None,
914
+ db_type=None,
915
+ k=None,
916
+ n_jobs=None,
917
+ first_para=None,
918
+ text_limit=None,
919
  ):
920
  # ensure passed these
921
  assert concurrency_count is not None
922
  assert is_low_mem is not None
923
  assert raise_generate_gpu_exceptions is not None
924
  assert chat_context is not None
925
+ assert use_openai_embedding is not None
926
+ assert use_openai_model is not None
927
+ assert hf_embedding_model is not None
928
+ assert chunk is not None
929
+ assert chunk_size is not None
930
+ assert db_type is not None
931
+ assert k is not None
932
+ assert n_jobs is not None
933
+ assert first_para is not None
934
 
935
  if debug:
936
  locals_dict = locals().copy()
 
976
  # get hidden context if have one
977
  context = get_context(chat_context, prompt_type)
978
 
 
979
  prompter = Prompter(prompt_type, debug=debug, chat=chat, stream_output=stream_output)
980
+ data_point = dict(context=context, instruction=instruction, input=iinput)
981
  prompt = prompter.generate_prompt(data_point)
982
 
983
+ # THIRD PLACE where LangChain referenced, but imports only occur if enabled and have db to use
984
+ assert langchain_mode in langchain_modes, "Invalid langchain_mode %s" % langchain_mode
985
+ if langchain_mode in ['MyData'] and my_db_state is not None and len(my_db_state) > 0 and my_db_state[0] is not None:
986
+ db1 = my_db_state[0]
987
+ elif dbs is not None and langchain_mode in dbs:
988
+ db1 = dbs[langchain_mode]
989
+ else:
990
+ db1 = None
991
+ if langchain_mode not in [False, 'Disabled', 'ChatLLM', 'LLM'] and db1 is not None or base_model in ['llama', 'gptj']:
992
+ query = instruction if not iinput else "%s\n%s" % (instruction, iinput)
993
+ outr = ""
994
+ # use smaller cut_distanct for wiki_full since so many matches could be obtained, and often irrelevant unless close
995
+ from gpt_langchain import run_qa_db
996
+ for r in run_qa_db(query=query,
997
+ model_name=base_model, model=model, tokenizer=tokenizer,
998
+ stream_output=stream_output,
999
+ prompter=prompter,
1000
+ load_db_if_exists=load_db_if_exists,
1001
+ db=db1,
1002
+ user_path=user_path,
1003
+ max_new_tokens=max_new_tokens,
1004
+ cut_distanct=1.1 if langchain_mode in ['wiki_full'] else 1.64, # FIXME, too arbitrary
1005
+ use_openai_embedding=use_openai_embedding,
1006
+ use_openai_model=use_openai_model,
1007
+ hf_embedding_model=hf_embedding_model,
1008
+ first_para=first_para,
1009
+ text_limit=text_limit,
1010
+ chunk=chunk,
1011
+ chunk_size=chunk_size,
1012
+ langchain_mode=langchain_mode,
1013
+ db_type=db_type,
1014
+ k=k,
1015
+ temperature=temperature,
1016
+ repetition_penalty=repetition_penalty,
1017
+ top_k=top_k,
1018
+ top_p=top_p,
1019
+ prompt_type=prompt_type,
1020
+ n_jobs=n_jobs,
1021
+ ):
1022
+ outr = r # doesn't accumulate, new answer every yield, so only save that full answer
1023
+ yield r
1024
+ if save_dir:
1025
+ save_generate_output(output=outr, base_model=base_model, save_dir=save_dir)
1026
+ print('Post-Generate Langchain: %s decoded_output: %s' % (str(datetime.now()), len(outr) if outr else -1),
1027
+ flush=True)
1028
+ if outr:
1029
+ return
1030
+
1031
  if isinstance(tokenizer, str):
1032
  # pipeline
1033
  if tokenizer == "summarization":
 
1045
  # override, ignore user change
1046
  num_return_sequences = 1
1047
  stopping_criteria = get_stopping(prompt_type, tokenizer, device)
1048
+ _, _, max_length_tokenize, max_prompt_length = get_cutoffs(is_low_mem)
1049
+ prompt = prompt[-max_prompt_length:]
 
 
 
 
 
 
1050
  inputs = tokenizer(prompt,
1051
  return_tensors="pt",
1052
  truncation=True,
1053
  max_length=max_length_tokenize)
1054
+ if inputs['input_ids'].shape[1] >= max_length_tokenize - 1:
1055
+ print("Cutting off input: %s %s" % (inputs['input_ids'].shape[1], max_length_tokenize), flush=True)
1056
  if debug and len(inputs["input_ids"]) > 0:
1057
  print('input_ids length', len(inputs["input_ids"][0]), flush=True)
1058
  input_ids = inputs["input_ids"].to(device)
 
1094
  **decoder_kwargs
1095
  )
1096
  decoder_raw_kwargs = dict(skip_special_tokens=False,
1097
+ clean_up_tokenization_spaces=True)
1098
 
1099
  decoder_raw = functools.partial(tokenizer.decode,
1100
  **decoder_raw_kwargs
 
1107
  # else hit bitsandbytes lack of thread safety:
1108
  # https://github.com/h2oai/h2ogpt/issues/104
1109
  # but only makes sense if concurrency_count == 1
1110
+ context_class = NullContext # if concurrency_count > 1 else filelock.FileLock
1111
  print('Pre-Generate: %s' % str(datetime.now()), flush=True)
1112
  decoded_output = None
1113
  with context_class("generate.lock"):
 
1126
  inputs_decoded = prompt = inputs_decoded_raw
1127
  decoder = decoder_raw
1128
  decoder_kwargs = decoder_raw_kwargs
1129
+ elif inputs_decoded_raw.replace("<unk> ", "").replace("<unk>", "").replace('\n', ' ').replace(' ',
1130
+ '') == prompt.replace(
1131
+ '\n', ' ').replace(' ', ''):
1132
  inputs_decoded = prompt = inputs_decoded_raw
1133
  decoder = decoder_raw
1134
  decoder_kwargs = decoder_raw_kwargs
 
1136
  print("WARNING: Special characters in prompt", flush=True)
1137
  if stream_output:
1138
  skip_prompt = False
1139
+ streamer = H2OTextIteratorStreamer(tokenizer, skip_prompt=skip_prompt, block=False,
1140
+ **decoder_kwargs)
1141
  gen_kwargs.update(dict(streamer=streamer))
1142
+ target = wrapped_partial(generate_with_exceptions, model.generate,
1143
+ prompt=prompt, inputs_decoded=inputs_decoded,
1144
+ raise_generate_gpu_exceptions=raise_generate_gpu_exceptions,
1145
+ **gen_kwargs)
1146
  bucket = queue.Queue()
1147
+ thread = EThread(target=target, streamer=streamer, bucket=bucket)
1148
  thread.start()
1149
  outputs = ""
1150
  try:
 
1176
  decoded_output = prompt + outputs[0]
1177
  if save_dir and decoded_output:
1178
  save_generate_output(output=decoded_output, base_model=base_model, save_dir=save_dir)
1179
+ print('Post-Generate: %s decoded_output: %s' % (
1180
+ str(datetime.now()), len(decoded_output) if decoded_output else -1), flush=True)
1181
+
1182
+
1183
+ inputs_list_names = list(inspect.signature(evaluate).parameters)
1184
+ state_names = ['model_state', 'my_db_state']
1185
+ inputs_kwargs_list = [x for x in inputs_list_names if x not in eval_func_param_names + state_names]
1186
+
1187
+
1188
+ def get_cutoffs(is_low_mem, for_context=False):
1189
+ # help to avoid errors like:
1190
+ # RuntimeError: The size of tensor a (2048) must match the size of tensor b (2049) at non-singleton dimension 3
1191
+ # RuntimeError: expected scalar type Half but found Float
1192
+ # with - 256
1193
+ max_length_tokenize = 768 - 256 if is_low_mem else 2048 - 256
1194
+ cutoff_len = max_length_tokenize * 4 # if reaches limit, then can't generate new tokens
1195
+ output_smallest = 30 * 4
1196
+ max_prompt_length = cutoff_len - output_smallest
1197
+
1198
+ if for_context:
1199
+ # then lower even more to avoid later chop, since just estimate tokens in context bot
1200
+ max_prompt_length = max(64, int(max_prompt_length * 0.8))
1201
+
1202
+ return cutoff_len, output_smallest, max_length_tokenize, max_prompt_length
1203
 
1204
 
1205
  class H2OTextIteratorStreamer(TextIteratorStreamer):
 
1207
  normally, timeout required for now to handle exceptions, else get()
1208
  but with H2O version of TextIteratorStreamer, loop over block to handle
1209
  """
1210
+
1211
  def __init__(self, tokenizer, skip_prompt: bool = False, timeout: typing.Optional[float] = None,
1212
  block=True, **decode_kwargs):
1213
  super().__init__(tokenizer, skip_prompt, **decode_kwargs)
 
1234
  print("hit stop", flush=True)
1235
  # could raise or break, maybe best to raise and make parent see if any exception in thread
1236
  raise StopIteration()
1237
+ # break
1238
  value = self.text_queue.get(block=self.block, timeout=self.timeout)
1239
  break
1240
  except queue.Empty:
 
1245
  return value
1246
 
1247
 
1248
+ def generate_with_exceptions(func, *args, prompt='', inputs_decoded='', raise_generate_gpu_exceptions=True, **kwargs):
1249
  try:
1250
+ func(*args, **kwargs)
1251
  except torch.cuda.OutOfMemoryError as e:
1252
  print("GPU OOM 2: prompt: %s inputs_decoded: %s exception: %s" % (prompt, inputs_decoded, str(e)),
1253
  flush=True)
1254
+ if 'input_ids' in kwargs:
1255
+ if kwargs['input_ids'] is not None:
1256
+ kwargs['input_ids'].cpu()
1257
+ kwargs['input_ids'] = None
1258
  traceback.print_exc()
1259
  clear_torch_cache()
1260
  return
 
1446
 
1447
  # move to correct position
1448
  for example in examples:
1449
+ example += [chat, '', '', 'Disabled']
1450
  # adjust examples if non-chat mode
1451
  if not chat:
1452
  example[eval_func_param_names.index('instruction_nochat')] = example[
 
1455
 
1456
  example[eval_func_param_names.index('iinput_nochat')] = example[eval_func_param_names.index('iinput')]
1457
  example[eval_func_param_names.index('iinput')] = ''
1458
+ assert len(example) == len(eval_func_param_names), "Wrong example: %s %s" % (
1459
+ len(example), len(eval_func_param_names))
1460
 
1461
  return placeholder_instruction, placeholder_input, \
1462
+ stream_output, show_examples, \
1463
+ prompt_type, temperature, top_p, top_k, num_beams, \
1464
+ max_new_tokens, min_new_tokens, early_stopping, max_time, \
1465
+ repetition_penalty, num_return_sequences, \
1466
+ do_sample, \
1467
+ src_lang, tgt_lang, \
1468
+ examples, \
1469
+ task_info
1470
 
1471
 
1472
  def languages_covered():
 
1486
  return context0
1487
 
1488
 
 
 
 
 
 
 
1489
  def score_qa(smodel, stokenizer, max_length_tokenize, question, answer, cutoff_len):
1490
  question = question[-cutoff_len:]
1491
  answer = answer[-cutoff_len:]
 
1549
  python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6.9b
1550
  """
1551
  fire.Fire(main)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
gpt4all_llm.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ import os
3
+ from typing import Dict, Any, Optional, List
4
+ from langchain.callbacks.manager import CallbackManagerForLLMRun
5
+ from pydantic import root_validator
6
+ from langchain.llms import gpt4all
7
+ from dotenv import dotenv_values
8
+
9
+
10
+ class FakeTokenizer:
11
+
12
+ def encode(self, x, *args, **kwargs):
13
+ return dict(input_ids=[x])
14
+
15
+ def decode(self, x, *args, **kwargs):
16
+ return x
17
+
18
+ def __call__(self, x, *args, **kwargs):
19
+ return self.encode(x, *args, **kwargs)
20
+
21
+
22
+ def get_model_tokenizer_gpt4all(base_model, **kwargs):
23
+ # defaults (some of these are generation parameters, so need to be passed in at generation time)
24
+ model_kwargs = dict(n_ctx=kwargs.get('max_new_tokens', 256),
25
+ n_threads=os.cpu_count() // 2,
26
+ temp=kwargs.get('temperature', 0.2),
27
+ top_p=kwargs.get('top_p', 0.75),
28
+ top_k=kwargs.get('top_k', 40))
29
+ env_gpt4all_file = ".env_gpt4all"
30
+ model_kwargs.update(dotenv_values(env_gpt4all_file))
31
+
32
+ if base_model == "llama":
33
+ if 'model_path_llama' not in model_kwargs:
34
+ raise ValueError("No model_path_llama in %s" % env_gpt4all_file)
35
+ model_path = model_kwargs.pop('model_path_llama')
36
+ from gpt4all import GPT4All as GPT4AllModel
37
+ elif base_model == "gptj":
38
+ if 'model_path_gptj' not in model_kwargs:
39
+ raise ValueError("No model_path_gptj in %s" % env_gpt4all_file)
40
+ model_path = model_kwargs.pop('model_path_gptj')
41
+ from gpt4all import GPT4All as GPT4AllModel
42
+ else:
43
+ raise ValueError("No such base_model %s" % base_model)
44
+ func_names = list(inspect.signature(GPT4AllModel).parameters)
45
+ model_kwargs = {k: v for k, v in model_kwargs.items() if k in func_names}
46
+ model = GPT4AllModel(model_path, **model_kwargs)
47
+ return model, FakeTokenizer(), 'cpu'
48
+
49
+
50
+ def get_llm_gpt4all(model_name, model=None,
51
+ max_new_tokens=256,
52
+ temperature=0.1,
53
+ repetition_penalty=1.0,
54
+ top_k=40,
55
+ top_p=0.7):
56
+ env_gpt4all_file = ".env_gpt4all"
57
+ model_kwargs = dotenv_values(env_gpt4all_file)
58
+ from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
59
+ callbacks = [StreamingStdOutCallbackHandler()]
60
+ n_ctx = model_kwargs.pop('n_ctx', 1024)
61
+ default_params = {'context_erase': 0.5, 'n_batch': 1, 'n_ctx': n_ctx, 'n_predict': max_new_tokens,
62
+ 'repeat_last_n': 64 if repetition_penalty != 1.0 else 0, 'repeat_penalty': repetition_penalty,
63
+ 'temp': temperature, 'top_k': top_k, 'top_p': top_p}
64
+ if model_name == 'llama':
65
+ from langchain.llms import LlamaCpp
66
+ model_path = model_kwargs.pop('model_path_llama') if model is None else model
67
+ llm = LlamaCpp(model_path=model_path, n_ctx=n_ctx, callbacks=callbacks, verbose=False)
68
+ else:
69
+ model_path = model_kwargs.pop('model_path_gptj') if model is None else model
70
+ llm = H2OGPT4All(model=model_path, backend='gptj', callbacks=callbacks,
71
+ verbose=False, **default_params,
72
+ )
73
+ return llm
74
+
75
+
76
+ class H2OGPT4All(gpt4all.GPT4All):
77
+ model: Any
78
+ """Path to the pre-trained GPT4All model file."""
79
+
80
+ @root_validator()
81
+ def validate_environment(cls, values: Dict) -> Dict:
82
+ """Validate that the python package exists in the environment."""
83
+ try:
84
+ if isinstance(values["model"], str):
85
+ from gpt4all import GPT4All as GPT4AllModel
86
+
87
+ full_path = values["model"]
88
+ model_path, delimiter, model_name = full_path.rpartition("/")
89
+ model_path += delimiter
90
+
91
+ values["client"] = GPT4AllModel(
92
+ model_name=model_name,
93
+ model_path=model_path or None,
94
+ model_type=values["backend"],
95
+ allow_download=False,
96
+ )
97
+ else:
98
+ values["client"] = values["model"]
99
+ values["backend"] = values["client"].model.model_type
100
+
101
+ except ImportError:
102
+ raise ValueError(
103
+ "Could not import gpt4all python package. "
104
+ "Please install it with `pip install gpt4all`."
105
+ )
106
+ return values
107
+
108
+ def _call(
109
+ self,
110
+ prompt: str,
111
+ stop: Optional[List[str]] = None,
112
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
113
+ ) -> str:
114
+ # Roughly 4 chars per token if natural language
115
+ prompt = prompt[-self.n_ctx * 4:]
116
+ verbose = False
117
+ if verbose:
118
+ print("_call prompt: %s" % prompt, flush=True)
119
+ return super()._call(prompt, stop=stop, run_manager=run_manager)
gpt_langchain.py ADDED
@@ -0,0 +1,1076 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import glob
2
+ import inspect
3
+ import os
4
+ import pathlib
5
+ import pickle
6
+ import shutil
7
+ import subprocess
8
+ import sys
9
+ import tempfile
10
+ import traceback
11
+ import uuid
12
+ import zipfile
13
+ from collections import defaultdict
14
+ from datetime import datetime
15
+ from functools import reduce
16
+ from operator import concat
17
+
18
+ from joblib import Parallel, delayed
19
+
20
+ from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
21
+ get_device
22
+
23
+ import_matplotlib()
24
+
25
+ import numpy as np
26
+ import pandas as pd
27
+ import requests
28
+ from langchain.chains.qa_with_sources import load_qa_with_sources_chain
29
+ # , GCSDirectoryLoader, GCSFileLoader
30
+ # , OutlookMessageLoader # GPL3
31
+ # ImageCaptionLoader, # use our own wrapper
32
+ # ReadTheDocsLoader, # no special file, some path, so have to give as special option
33
+ from langchain.document_loaders import PyPDFLoader, TextLoader, CSVLoader, PythonLoader, TomlLoader, \
34
+ UnstructuredURLLoader, UnstructuredHTMLLoader, UnstructuredWordDocumentLoader, UnstructuredMarkdownLoader, \
35
+ EverNoteLoader, UnstructuredEmailLoader, UnstructuredODTLoader, UnstructuredPowerPointLoader, \
36
+ UnstructuredEPubLoader, UnstructuredImageLoader, UnstructuredRTFLoader, ArxivLoader
37
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
38
+ from langchain.vectorstores import FAISS
39
+ from langchain.chains.question_answering import load_qa_chain
40
+ from langchain.docstore.document import Document
41
+ from langchain import PromptTemplate
42
+ from langchain.vectorstores import Chroma
43
+
44
+
45
+ def get_db(sources, use_openai_embedding=False, db_type='faiss', persist_directory="db_dir", langchain_mode='notset',
46
+ hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2"):
47
+ if not sources:
48
+ return None
49
+ # get embedding model
50
+ embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model)
51
+
52
+ # Create vector database
53
+ if db_type == 'faiss':
54
+ db = FAISS.from_documents(sources, embedding)
55
+ elif db_type == 'chroma':
56
+ collection_name = langchain_mode.replace(' ', '_')
57
+ os.makedirs(persist_directory, exist_ok=True)
58
+ db = Chroma.from_documents(documents=sources,
59
+ embedding=embedding,
60
+ persist_directory=persist_directory,
61
+ collection_name=collection_name,
62
+ anonymized_telemetry=False)
63
+ db.persist()
64
+ # FIXME: below just proves can load persistent dir, regenerates its embedding files, so a bit wasteful
65
+ if False:
66
+ db = Chroma(embedding_function=embedding,
67
+ persist_directory=persist_directory,
68
+ collection_name=collection_name)
69
+ else:
70
+ raise RuntimeError("No such db_type=%s" % db_type)
71
+
72
+ return db
73
+
74
+
75
+ def add_to_db(db, sources, db_type='faiss', avoid_dup=True):
76
+ if not sources:
77
+ return db
78
+ if db_type == 'faiss':
79
+ db.add_documents(sources)
80
+ elif db_type == 'chroma':
81
+ if avoid_dup:
82
+ collection = db.get()
83
+ metadata_sources = set([x['source'] for x in collection['metadatas']])
84
+ sources = [x for x in sources if x.metadata['source'] not in metadata_sources]
85
+ if len(sources) == 0:
86
+ return db
87
+ db.add_documents(documents=sources)
88
+ db.persist()
89
+ else:
90
+ raise RuntimeError("No such db_type=%s" % db_type)
91
+
92
+ return db
93
+
94
+
95
+ def get_embedding(use_openai_embedding, hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2"):
96
+ # Get embedding model
97
+ if use_openai_embedding:
98
+ assert os.getenv("OPENAI_API_KEY") is not None, "Set ENV OPENAI_API_KEY"
99
+ from langchain.embeddings import OpenAIEmbeddings
100
+ embedding = OpenAIEmbeddings()
101
+ else:
102
+ # to ensure can fork without deadlock
103
+ from langchain.embeddings import HuggingFaceEmbeddings
104
+
105
+ device, torch_dtype, context_class = get_device_dtype()
106
+ model_kwargs = dict(device=device)
107
+ embedding = HuggingFaceEmbeddings(model_name=hf_embedding_model, model_kwargs=model_kwargs)
108
+ return embedding
109
+
110
+
111
+ def get_answer_from_sources(chain, sources, question):
112
+ return chain(
113
+ {
114
+ "input_documents": sources,
115
+ "question": question,
116
+ },
117
+ return_only_outputs=True,
118
+ )["output_text"]
119
+
120
+
121
+ def get_llm(use_openai_model=False, model_name=None, model=None,
122
+ tokenizer=None, stream_output=False,
123
+ max_new_tokens=256,
124
+ temperature=0.1,
125
+ repetition_penalty=1.0,
126
+ top_k=40,
127
+ top_p=0.7,
128
+ prompt_type=None,
129
+ ):
130
+ if use_openai_model:
131
+ from langchain.llms import OpenAI
132
+ llm = OpenAI(temperature=0)
133
+ model_name = 'openai'
134
+ streamer = None
135
+ elif model_name in ['gptj', 'llama']:
136
+ from gpt4all_llm import get_llm_gpt4all
137
+ llm = get_llm_gpt4all(model_name, model=model, max_new_tokens=max_new_tokens,
138
+ temperature=temperature,
139
+ repetition_penalty=repetition_penalty,
140
+ top_k=top_k,
141
+ top_p=top_p,
142
+ )
143
+ streamer = None
144
+ prompt_type = 'plain'
145
+ else:
146
+ from transformers import AutoTokenizer, AutoModelForCausalLM
147
+
148
+ if model is None:
149
+ # only used if didn't pass model in
150
+ assert model_name is None
151
+ assert tokenizer is None
152
+ model_name = 'h2oai/h2ogpt-oasst1-512-12b'
153
+ # model_name = 'h2oai/h2ogpt-oig-oasst1-512-6.9b'
154
+ # model_name = 'h2oai/h2ogpt-oasst1-512-20b'
155
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
156
+ device, torch_dtype, context_class = get_device_dtype()
157
+
158
+ with context_class(device):
159
+ load_8bit = True
160
+ # FIXME: for now not to spread across hetero GPUs
161
+ # device_map={"": 0} if load_8bit and device == 'cuda' else "auto"
162
+ device_map = {"": 0} if device == 'cuda' else "auto"
163
+ model = AutoModelForCausalLM.from_pretrained(model_name,
164
+ device_map=device_map,
165
+ torch_dtype=torch_dtype,
166
+ load_in_8bit=load_8bit)
167
+
168
+ gen_kwargs = dict(max_new_tokens=max_new_tokens, return_full_text=True, early_stopping=False)
169
+ if stream_output:
170
+ skip_prompt = False
171
+ from generate import H2OTextIteratorStreamer
172
+ decoder_kwargs = {}
173
+ streamer = H2OTextIteratorStreamer(tokenizer, skip_prompt=skip_prompt, block=False, **decoder_kwargs)
174
+ gen_kwargs.update(dict(streamer=streamer))
175
+ else:
176
+ streamer = None
177
+
178
+ if 'h2ogpt' in model_name or prompt_type == 'human_bot':
179
+ from h2oai_pipeline import H2OTextGenerationPipeline
180
+ pipe = H2OTextGenerationPipeline(model=model, tokenizer=tokenizer, **gen_kwargs)
181
+ # pipe.task = "text-generation"
182
+ # below makes it listen only to our prompt removal, not built in prompt removal that is less general and not specific for our model
183
+ pipe.task = "text2text-generation"
184
+ prompt_type = 'human_bot'
185
+ else:
186
+ # only for non-instruct tuned cases when ok with just normal next token prediction
187
+ from transformers import pipeline
188
+ pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, **gen_kwargs)
189
+
190
+ from langchain.llms import HuggingFacePipeline
191
+ llm = HuggingFacePipeline(pipeline=pipe)
192
+ return llm, model_name, streamer, prompt_type
193
+
194
+
195
+ def get_device_dtype():
196
+ # torch.device("cuda") leads to cuda:x cuda:y mismatches for multi-GPU consistently
197
+ import torch
198
+ n_gpus = torch.cuda.device_count() if torch.cuda.is_available else 0
199
+ device = 'cpu' if n_gpus == 0 else 'cuda'
200
+ # from utils import NullContext
201
+ # context_class = NullContext if n_gpus > 1 or n_gpus == 0 else context_class
202
+ context_class = torch.device
203
+ torch_dtype = torch.float16 if device == 'cuda' else torch.float32
204
+ return device, torch_dtype, context_class
205
+
206
+
207
+ def get_wiki_data(title, first_paragraph_only, text_limit=None, take_head=True):
208
+ """
209
+ Get wikipedia data from online
210
+ :param title:
211
+ :param first_paragraph_only:
212
+ :param text_limit:
213
+ :param take_head:
214
+ :return:
215
+ """
216
+ filename = 'wiki_%s_%s_%s_%s.data' % (first_paragraph_only, title, text_limit, take_head)
217
+ url = f"https://en.wikipedia.org/w/api.php?format=json&action=query&prop=extracts&explaintext=1&titles={title}"
218
+ if first_paragraph_only:
219
+ url += "&exintro=1"
220
+ import json
221
+ if not os.path.isfile(filename):
222
+ data = requests.get(url).json()
223
+ json.dump(data, open(filename, 'wt'))
224
+ else:
225
+ data = json.load(open(filename, "rt"))
226
+ page_content = list(data["query"]["pages"].values())[0]["extract"]
227
+ if take_head is not None and text_limit is not None:
228
+ page_content = page_content[:text_limit] if take_head else page_content[:-text_limit]
229
+ title_url = str(title).replace(' ', '_')
230
+ return Document(
231
+ page_content=page_content,
232
+ metadata={"source": f"https://en.wikipedia.org/wiki/{title_url}"},
233
+ )
234
+
235
+
236
+ def get_wiki_sources(first_para=True, text_limit=None):
237
+ """
238
+ Get specific named sources from wikipedia
239
+ :param first_para:
240
+ :param text_limit:
241
+ :return:
242
+ """
243
+ default_wiki_sources = ['Unix', 'Microsoft_Windows', 'Linux']
244
+ wiki_sources = list(os.getenv('WIKI_SOURCES', default_wiki_sources))
245
+ return [get_wiki_data(x, first_para, text_limit=text_limit) for x in wiki_sources]
246
+
247
+
248
+ def get_github_docs(repo_owner, repo_name):
249
+ """
250
+ Access github from specific repo
251
+ :param repo_owner:
252
+ :param repo_name:
253
+ :return:
254
+ """
255
+ with tempfile.TemporaryDirectory() as d:
256
+ subprocess.check_call(
257
+ f"git clone --depth 1 https://github.com/{repo_owner}/{repo_name}.git .",
258
+ cwd=d,
259
+ shell=True,
260
+ )
261
+ git_sha = (
262
+ subprocess.check_output("git rev-parse HEAD", shell=True, cwd=d)
263
+ .decode("utf-8")
264
+ .strip()
265
+ )
266
+ repo_path = pathlib.Path(d)
267
+ markdown_files = list(repo_path.glob("*/*.md")) + list(
268
+ repo_path.glob("*/*.mdx")
269
+ )
270
+ for markdown_file in markdown_files:
271
+ with open(markdown_file, "r") as f:
272
+ relative_path = markdown_file.relative_to(repo_path)
273
+ github_url = f"https://github.com/{repo_owner}/{repo_name}/blob/{git_sha}/{relative_path}"
274
+ yield Document(page_content=f.read(), metadata={"source": github_url})
275
+
276
+
277
+ def get_dai_pickle(dest="."):
278
+ from huggingface_hub import hf_hub_download
279
+ # True for case when locally already logged in with correct token, so don't have to set key
280
+ token = os.getenv('HUGGINGFACE_API_TOKEN', True)
281
+ path_to_zip_file = hf_hub_download('h2oai/dai_docs', 'dai_docs.pickle', token=token, repo_type='dataset')
282
+ shutil.copy(path_to_zip_file, dest)
283
+
284
+
285
+ def get_dai_docs(from_hf=False, get_pickle=True):
286
+ """
287
+ Consume DAI documentation, or consume from public pickle
288
+ :param from_hf: get DAI docs from HF, then generate pickle for later use by LangChain
289
+ :param get_pickle: Avoid raw DAI docs, just get pickle directly from HF
290
+ :return:
291
+ """
292
+ import pickle
293
+
294
+ if get_pickle:
295
+ get_dai_pickle()
296
+
297
+ dai_store = 'dai_docs.pickle'
298
+ dst = "working_dir_docs"
299
+ if not os.path.isfile(dai_store):
300
+ from create_data import setup_dai_docs
301
+ dst = setup_dai_docs(dst=dst, from_hf=from_hf)
302
+
303
+ import glob
304
+ files = list(glob.glob(os.path.join(dst, '*rst'), recursive=True))
305
+
306
+ basedir = os.path.abspath(os.getcwd())
307
+ from create_data import rst_to_outputs
308
+ new_outputs = rst_to_outputs(files)
309
+ os.chdir(basedir)
310
+
311
+ pickle.dump(new_outputs, open(dai_store, 'wb'))
312
+ else:
313
+ new_outputs = pickle.load(open(dai_store, 'rb'))
314
+
315
+ sources = []
316
+ for line, file in new_outputs:
317
+ # gradio requires any linked file to be with app.py
318
+ sym_src = os.path.abspath(os.path.join(dst, file))
319
+ sym_dst = os.path.abspath(os.path.join(os.getcwd(), file))
320
+ if os.path.lexists(sym_dst):
321
+ os.remove(sym_dst)
322
+ os.symlink(sym_src, sym_dst)
323
+ itm = Document(page_content=line, metadata={"source": file})
324
+ # NOTE: yield has issues when going into db, loses metadata
325
+ # yield itm
326
+ sources.append(itm)
327
+ return sources
328
+
329
+
330
+ import distutils.spawn
331
+
332
+ have_tesseract = distutils.spawn.find_executable("tesseract")
333
+ have_libreoffice = distutils.spawn.find_executable("libreoffice")
334
+
335
+ import pkg_resources
336
+
337
+ try:
338
+ assert pkg_resources.get_distribution('arxiv') is not None
339
+ assert pkg_resources.get_distribution('pymupdf') is not None
340
+ have_arxiv = True
341
+ except (pkg_resources.DistributionNotFound, AssertionError):
342
+ have_arxiv = False
343
+
344
+ image_types = ["png", "jpg", "jpeg"]
345
+ non_image_types = ["pdf", "txt", "csv", "toml", "py", "rst", "rtf",
346
+ "md", "html",
347
+ "enex", "eml", "epub", "odt", "pptx", "ppt",
348
+ "zip", "urls",
349
+ ]
350
+ # "msg", GPL3
351
+
352
+ if have_libreoffice:
353
+ non_image_types.extend(["docx", "doc"])
354
+
355
+ file_types = non_image_types + image_types
356
+
357
+
358
+ def add_meta(docs1, file):
359
+ file_extension = pathlib.Path(file).suffix
360
+ if not isinstance(docs1, list):
361
+ docs1 = [docs1]
362
+ [x.metadata.update(dict(input_type=file_extension, date=str(datetime.now))) for x in docs1]
363
+
364
+
365
+ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False, chunk=True, chunk_size=512,
366
+ is_url=False, is_txt=False,
367
+ enable_captions=True,
368
+ captions_model=None,
369
+ enable_ocr=False, caption_loader=None,
370
+ headsize=50):
371
+ if file is None:
372
+ if fail_any_exception:
373
+ raise RuntimeError("Unexpected None file")
374
+ else:
375
+ return []
376
+ doc1 = [] # in case no support, or disabled support
377
+ if base_path is None and not is_txt and not is_url:
378
+ # then assume want to persist but don't care which path used
379
+ # can't be in base_path
380
+ dir_name = os.path.dirname(file)
381
+ base_name = os.path.basename(file)
382
+ # if from gradio, will have its own temp uuid too, but that's ok
383
+ base_name = sanitize_filename(base_name) + "_" + str(uuid.uuid4())[:10]
384
+ base_path = os.path.join(dir_name, base_name)
385
+ if is_url:
386
+ if file.lower().startswith('arxiv:'):
387
+ query = file.lower().split('arxiv:')
388
+ if len(query) == 2 and have_arxiv:
389
+ query = query[1]
390
+ docs1 = ArxivLoader(query=query, load_max_docs=20, load_all_available_meta=True).load()
391
+ # ensure string, sometimes None
392
+ [[x.metadata.update({k: str(v)}) for k, v in x.metadata.items()] for x in docs1]
393
+ query_url = f"https://arxiv.org/abs/{query}"
394
+ [x.metadata.update(
395
+ dict(source=x.metadata.get('entry_id', query_url), query=query_url,
396
+ input_type='arxiv', head=x.metadata.get('Title', ''), date=str(datetime.now))) for x in
397
+ docs1]
398
+ else:
399
+ docs1 = []
400
+ else:
401
+ docs1 = UnstructuredURLLoader(urls=[file]).load()
402
+ [x.metadata.update(dict(input_type='url', date=str(datetime.now))) for x in docs1]
403
+ doc1 = chunk_sources(docs1, chunk_size=chunk_size)
404
+ elif is_txt:
405
+ base_path = "user_paste"
406
+ source_file = os.path.join(base_path, "_%s" % str(uuid.uuid4())[:10])
407
+ makedirs(os.path.dirname(source_file), exist_ok=True)
408
+ with open(source_file, "wt") as f:
409
+ f.write(file)
410
+ metadata = dict(source=source_file, date=str(datetime.now()), input_type='pasted txt')
411
+ doc1 = Document(page_content=file, metadata=metadata)
412
+ elif file.endswith('.html') or file.endswith('.mhtml'):
413
+ docs1 = UnstructuredHTMLLoader(file_path=file).load()
414
+ add_meta(docs1, file)
415
+ doc1 = chunk_sources(docs1, chunk_size=chunk_size)
416
+ elif (file.endswith('.docx') or file.endswith('.doc')) and have_libreoffice:
417
+ docs1 = UnstructuredWordDocumentLoader(file_path=file).load()
418
+ add_meta(docs1, file)
419
+ doc1 = chunk_sources(docs1, chunk_size=chunk_size)
420
+ elif file.endswith('.odt'):
421
+ docs1 = UnstructuredODTLoader(file_path=file).load()
422
+ add_meta(docs1, file)
423
+ doc1 = chunk_sources(docs1, chunk_size=chunk_size)
424
+ elif file.endswith('pptx') or file.endswith('ppt'):
425
+ docs1 = UnstructuredPowerPointLoader(file_path=file).load()
426
+ add_meta(docs1, file)
427
+ doc1 = chunk_sources(docs1, chunk_size=chunk_size)
428
+ elif file.endswith('.txt'):
429
+ # use UnstructuredFileLoader ?
430
+ doc1 = TextLoader(file, encoding="utf8", autodetect_encoding=True).load()
431
+ add_meta(doc1, file)
432
+ elif file.endswith('.rtf'):
433
+ docs1 = UnstructuredRTFLoader(file).load()
434
+ add_meta(docs1, file)
435
+ doc1 = chunk_sources(docs1, chunk_size=chunk_size)
436
+ elif file.endswith('.md'):
437
+ docs1 = UnstructuredMarkdownLoader(file).load()
438
+ add_meta(docs1, file)
439
+ doc1 = chunk_sources(docs1, chunk_size=chunk_size)
440
+ elif file.endswith('.enex'):
441
+ doc1 = EverNoteLoader(file).load()
442
+ add_meta(doc1, file)
443
+ elif file.endswith('.epub'):
444
+ docs1 = UnstructuredEPubLoader(file).load()
445
+ add_meta(docs1, file)
446
+ doc1 = chunk_sources(docs1, chunk_size=chunk_size)
447
+ elif file.endswith('.jpeg') or file.endswith('.jpg') or file.endswith('.png'):
448
+ docs1 = []
449
+ if have_tesseract and enable_ocr:
450
+ # OCR, somewhat works, but not great
451
+ docs1.extend(UnstructuredImageLoader(file).load())
452
+ add_meta(docs1, file)
453
+ if enable_captions:
454
+ # BLIP
455
+ if caption_loader is not None and not isinstance(caption_loader, (str, bool)):
456
+ # assumes didn't fork into this process with joblib, else can deadlock
457
+ caption_loader.set_image_paths([file])
458
+ docs1c = caption_loader.load()
459
+ add_meta(docs1c, file)
460
+ [x.metadata.update(dict(head=x.page_content[:headsize].strip())) for x in docs1c]
461
+ docs1.extend(docs1c)
462
+ else:
463
+ from image_captions import H2OImageCaptionLoader
464
+ caption_loader = H2OImageCaptionLoader(caption_gpu=caption_loader == 'gpu',
465
+ blip_model=captions_model,
466
+ blip_processor=captions_model)
467
+ caption_loader.set_image_paths([file])
468
+ docs1c = caption_loader.load()
469
+ add_meta(docs1c, file)
470
+ [x.metadata.update(dict(head=x.page_content[:headsize].strip())) for x in docs1c]
471
+ docs1.extend(docs1c)
472
+ for doci in docs1:
473
+ doci.metadata['source'] = doci.metadata['image_path']
474
+ if docs1:
475
+ doc1 = chunk_sources(docs1, chunk_size=chunk_size)
476
+ elif file.endswith('.msg'):
477
+ raise RuntimeError("Not supported, GPL3 license")
478
+ # docs1 = OutlookMessageLoader(file).load()
479
+ # docs1[0].metadata['source'] = file
480
+ elif file.endswith('.eml'):
481
+ try:
482
+ docs1 = UnstructuredEmailLoader(file).load()
483
+ add_meta(docs1, file)
484
+ doc1 = chunk_sources(docs1, chunk_size=chunk_size)
485
+ except ValueError as e:
486
+ if 'text/html content not found in email' in str(e):
487
+ # e.g. plain/text dict key exists, but not
488
+ # doc1 = TextLoader(file, encoding="utf8").load()
489
+ docs1 = UnstructuredEmailLoader(file, content_source="text/plain").load()
490
+ add_meta(docs1, file)
491
+ doc1 = chunk_sources(docs1, chunk_size=chunk_size)
492
+ else:
493
+ raise
494
+ # elif file.endswith('.gcsdir'):
495
+ # doc1 = GCSDirectoryLoader(project_name, bucket, prefix).load()
496
+ # elif file.endswith('.gcsfile'):
497
+ # doc1 = GCSFileLoader(project_name, bucket, blob).load()
498
+ elif file.endswith('.rst'):
499
+ with open(file, "r") as f:
500
+ doc1 = Document(page_content=f.read(), metadata={"source": file})
501
+ add_meta(doc1, file)
502
+ elif file.endswith('.pdf'):
503
+ # Some PDFs return nothing or junk from PDFMinerLoader
504
+ # e.g. Beyond fine-tuning_ Classifying high resolution mammograms using function-preserving transformations _ Elsevier Enhanced Reader.pdf
505
+ doc1 = PyPDFLoader(file).load_and_split()
506
+ add_meta(doc1, file)
507
+ elif file.endswith('.csv'):
508
+ doc1 = CSVLoader(file).load()
509
+ add_meta(doc1, file)
510
+ elif file.endswith('.py'):
511
+ doc1 = PythonLoader(file).load()
512
+ add_meta(doc1, file)
513
+ elif file.endswith('.toml'):
514
+ doc1 = TomlLoader(file).load()
515
+ add_meta(doc1, file)
516
+ elif file.endswith('.urls'):
517
+ with open(file, "r") as f:
518
+ docs1 = UnstructuredURLLoader(urls=f.readlines()).load()
519
+ add_meta(docs1, file)
520
+ doc1 = chunk_sources(docs1, chunk_size=chunk_size)
521
+ elif file.endswith('.zip'):
522
+ with zipfile.ZipFile(file, 'r') as zip_ref:
523
+ # don't put into temporary path, since want to keep references to docs inside zip
524
+ # so just extract in path where
525
+ zip_ref.extractall(base_path)
526
+ # recurse
527
+ doc1 = path_to_docs(base_path, verbose=verbose, fail_any_exception=fail_any_exception)
528
+ else:
529
+ raise RuntimeError("No file handler for %s" % os.path.basename(file))
530
+
531
+ # allow doc1 to be list or not. If not list, did not chunk yet, so chunk now
532
+ if not isinstance(doc1, list):
533
+ if chunk:
534
+ docs = chunk_sources([doc1], chunk_size=chunk_size)
535
+ else:
536
+ docs = [doc1]
537
+ else:
538
+ docs = doc1
539
+
540
+ assert isinstance(docs, list)
541
+ return docs
542
+
543
+
544
+ def path_to_doc1(file, verbose=False, fail_any_exception=False, return_file=True, chunk=True, chunk_size=512,
545
+ is_url=False, is_txt=False,
546
+ enable_captions=True,
547
+ captions_model=None,
548
+ enable_ocr=False, caption_loader=None):
549
+ if verbose:
550
+ if is_url:
551
+ print("Ingesting URL: %s" % file, flush=True)
552
+ elif is_txt:
553
+ print("Ingesting Text: %s" % file, flush=True)
554
+ else:
555
+ print("Ingesting file: %s" % file, flush=True)
556
+ res = None
557
+ try:
558
+ # don't pass base_path=path, would infinitely recurse
559
+ res = file_to_doc(file, base_path=None, verbose=verbose, fail_any_exception=fail_any_exception,
560
+ chunk=chunk, chunk_size=chunk_size,
561
+ is_url=is_url, is_txt=is_txt,
562
+ enable_captions=enable_captions,
563
+ captions_model=captions_model,
564
+ enable_ocr=enable_ocr,
565
+ caption_loader=caption_loader)
566
+ except BaseException as e:
567
+ print("Failed to ingest %s due to %s" % (file, traceback.format_exc()))
568
+ if fail_any_exception:
569
+ raise
570
+ else:
571
+ exception_doc = Document(
572
+ page_content='',
573
+ metadata={"source": file, "exception": str(e), "traceback": traceback.format_exc()})
574
+ res = [exception_doc]
575
+ if return_file:
576
+ base_tmp = "temp_path_to_doc1"
577
+ if not os.path.isdir(base_tmp):
578
+ os.makedirs(base_tmp, exist_ok=True)
579
+ filename = os.path.join(base_tmp, str(uuid.uuid4()) + ".tmp.pickle")
580
+ with open(filename, 'wb') as f:
581
+ pickle.dump(res, f)
582
+ return filename
583
+ return res
584
+
585
+
586
+ def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=-1,
587
+ chunk=True, chunk_size=512,
588
+ url=None, text=None,
589
+ enable_captions=True,
590
+ captions_model=None,
591
+ caption_loader=None,
592
+ enable_ocr=False,
593
+ ):
594
+ globs_image_types = []
595
+ globs_non_image_types = []
596
+ if path_or_paths is None:
597
+ return []
598
+ elif url:
599
+ globs_non_image_types = [url]
600
+ elif text:
601
+ globs_non_image_types = [text]
602
+ elif isinstance(path_or_paths, str):
603
+ # single path, only consume allowed files
604
+ path = path_or_paths
605
+ # Below globs should match patterns in file_to_doc()
606
+ [globs_image_types.extend(glob.glob(os.path.join(path, "./**/*.%s" % ftype), recursive=True))
607
+ for ftype in image_types]
608
+ [globs_non_image_types.extend(glob.glob(os.path.join(path, "./**/*.%s" % ftype), recursive=True))
609
+ for ftype in non_image_types]
610
+ else:
611
+ # list/tuple of files (consume what can, and exception those that selected but cannot consume so user knows)
612
+ assert isinstance(path_or_paths, (list, tuple)), "Wrong type for path_or_paths: %s" % type(path_or_paths)
613
+ # reform out of allowed types
614
+ globs_image_types.extend(flatten_list([[x for x in path_or_paths if x.endswith(y)] for y in image_types]))
615
+ # could do below:
616
+ # globs_non_image_types = flatten_list([[x for x in path_or_paths if x.endswith(y)] for y in non_image_types])
617
+ # But instead, allow fail so can collect unsupported too
618
+ set_globs_image_types = set(globs_image_types)
619
+ globs_non_image_types.extend([x for x in path_or_paths if x not in set_globs_image_types])
620
+ # could use generator, but messes up metadata handling in recursive case
621
+ if caption_loader and not isinstance(caption_loader, (bool, str)) and \
622
+ caption_loader.device != 'cpu' or \
623
+ get_device() == 'cuda':
624
+ # to avoid deadlocks, presume was preloaded and so can't fork due to cuda context
625
+ n_jobs_image = 1
626
+ else:
627
+ n_jobs_image = n_jobs
628
+
629
+ return_file = True # local choice
630
+ is_url = url is not None
631
+ is_txt = text is not None
632
+ kwargs = dict(verbose=verbose, fail_any_exception=fail_any_exception,
633
+ return_file=return_file,
634
+ chunk=chunk, chunk_size=chunk_size,
635
+ is_url=is_url,
636
+ is_txt=is_txt,
637
+ enable_captions=enable_captions,
638
+ captions_model=captions_model,
639
+ caption_loader=caption_loader,
640
+ enable_ocr=enable_ocr,
641
+ )
642
+
643
+ if n_jobs != 1 and len(globs_non_image_types) > 1:
644
+ # avoid nesting, e.g. upload 1 zip and then inside many files
645
+ # harder to handle if upload many zips with many files, inner parallel one will be disabled by joblib
646
+ documents = Parallel(n_jobs=n_jobs, verbose=10 if verbose else 0, backend='multiprocessing')(
647
+ delayed(path_to_doc1)(file, **kwargs) for file in globs_non_image_types
648
+ )
649
+ else:
650
+ documents = [path_to_doc1(file, **kwargs) for file in globs_non_image_types]
651
+
652
+ # do images separately since can't fork after cuda in parent, so can't be parallel
653
+ if n_jobs_image != 1 and len(globs_image_types) > 1:
654
+ # avoid nesting, e.g. upload 1 zip and then inside many files
655
+ # harder to handle if upload many zips with many files, inner parallel one will be disabled by joblib
656
+ image_documents = Parallel(n_jobs=n_jobs, verbose=10 if verbose else 0, backend='multiprocessing')(
657
+ delayed(path_to_doc1)(file, **kwargs) for file in globs_image_types
658
+ )
659
+ else:
660
+ image_documents = [path_to_doc1(file, **kwargs) for file in globs_image_types]
661
+
662
+ # add image docs in
663
+ documents += image_documents
664
+
665
+ if return_file:
666
+ # then documents really are files
667
+ files = documents.copy()
668
+ documents = []
669
+ for fil in files:
670
+ with open(fil, 'rb') as f:
671
+ documents.extend(pickle.load(f))
672
+ # remove temp pickle
673
+ os.remove(fil)
674
+ else:
675
+ documents = reduce(concat, documents)
676
+ return documents
677
+
678
+
679
+ def prep_langchain(persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode, user_path,
680
+ hf_embedding_model, n_jobs=-1, kwargs_make_db={}):
681
+ """
682
+ do prep first time, involving downloads
683
+ # FIXME: Add github caching then add here
684
+ :return:
685
+ """
686
+ assert langchain_mode not in ['MyData'], "Should not prep scratch data"
687
+
688
+ if os.path.isdir(persist_directory):
689
+ print("Prep: persist_directory=%s exists, using" % persist_directory, flush=True)
690
+ db = get_existing_db(persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
691
+ hf_embedding_model)
692
+ else:
693
+ print("Prep: persist_directory=%s does not exist, regenerating" % persist_directory, flush=True)
694
+ db = None
695
+ if langchain_mode in ['All', 'DriverlessAI docs']:
696
+ # FIXME: Could also just use dai_docs.pickle directly and upload that
697
+ get_dai_docs(from_hf=True)
698
+
699
+ if langchain_mode in ['All', 'wiki']:
700
+ get_wiki_sources(first_para=kwargs_make_db['first_para'], text_limit=kwargs_make_db['text_limit'])
701
+
702
+ langchain_kwargs = kwargs_make_db.copy()
703
+ langchain_kwargs.update(locals())
704
+ db = make_db(**langchain_kwargs)
705
+
706
+ return db
707
+
708
+
709
+ def get_existing_db(persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
710
+ hf_embedding_model):
711
+ if load_db_if_exists and db_type == 'chroma' and os.path.isdir(persist_directory) and os.path.isdir(
712
+ os.path.join(persist_directory, 'index')):
713
+ print("DO Loading db: %s" % langchain_mode, flush=True)
714
+ embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model)
715
+ db = Chroma(persist_directory=persist_directory, embedding_function=embedding,
716
+ collection_name=langchain_mode.replace(' ', '_'))
717
+ print("DONE Loading db: %s" % langchain_mode, flush=True)
718
+ return db
719
+ return None
720
+
721
+
722
+ def make_db(**langchain_kwargs):
723
+ func_names = list(inspect.signature(_make_db).parameters)
724
+ missing_kwargs = [x for x in func_names if x not in langchain_kwargs]
725
+ defaults_db = {k: v.default for k, v in dict(inspect.signature(run_qa_db).parameters).items()}
726
+ for k in missing_kwargs:
727
+ if k in defaults_db:
728
+ langchain_kwargs[k] = defaults_db[k]
729
+ # final check for missing
730
+ missing_kwargs = [x for x in func_names if x not in langchain_kwargs]
731
+ assert not missing_kwargs, "Missing kwargs: %s" % missing_kwargs
732
+ # only keep actual used
733
+ langchain_kwargs = {k: v for k, v in langchain_kwargs.items() if k in func_names}
734
+ return _make_db(**langchain_kwargs)
735
+
736
+
737
+ def _make_db(use_openai_embedding=False,
738
+ hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
739
+ first_para=False, text_limit=None, chunk=False, chunk_size=1024,
740
+ langchain_mode=None,
741
+ user_path=None,
742
+ db_type='faiss',
743
+ load_db_if_exists=False,
744
+ db=None,
745
+ n_jobs=-1):
746
+ persist_directory = 'db_dir_%s' % langchain_mode # single place, no special names for each case
747
+ if not db and load_db_if_exists and db_type == 'chroma' and os.path.isdir(persist_directory) and os.path.isdir(
748
+ os.path.join(persist_directory, 'index')):
749
+ assert langchain_mode not in ['MyData'], "Should not load MyData db this way"
750
+ print("Loading db", flush=True)
751
+ embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model)
752
+ db = Chroma(persist_directory=persist_directory, embedding_function=embedding,
753
+ collection_name=langchain_mode.replace(' ', '_'))
754
+ elif not db:
755
+ assert langchain_mode not in ['MyData'], "Should not make MyData db this way"
756
+ sources = []
757
+ print("Generating sources", flush=True)
758
+ if langchain_mode in ['wiki_full', 'All', "'All'"]:
759
+ from read_wiki_full import get_all_documents
760
+ small_test = None
761
+ print("Generating new wiki", flush=True)
762
+ sources1 = get_all_documents(small_test=small_test, n_jobs=os.cpu_count() // 2)
763
+ print("Got new wiki", flush=True)
764
+ if chunk:
765
+ sources1 = chunk_sources(sources1, chunk_size=chunk_size)
766
+ print("Chunked new wiki", flush=True)
767
+ sources.extend(sources1)
768
+ if langchain_mode in ['wiki', 'All', "'All'"]:
769
+ sources1 = get_wiki_sources(first_para=first_para, text_limit=text_limit)
770
+ if chunk:
771
+ sources1 = chunk_sources(sources1, chunk_size=chunk_size)
772
+ sources.extend(sources1)
773
+ if langchain_mode in ['github h2oGPT', 'All', "'All'"]:
774
+ # sources = get_github_docs("dagster-io", "dagster")
775
+ sources1 = get_github_docs("h2oai", "h2ogpt")
776
+ # FIXME: always chunk for now
777
+ sources1 = chunk_sources(sources1, chunk_size=chunk_size)
778
+ sources.extend(sources1)
779
+ if langchain_mode in ['DriverlessAI docs', 'All', "'All'"]:
780
+ sources1 = get_dai_docs(from_hf=True)
781
+ if chunk and False: # FIXME: DAI docs are already chunked well, should only chunk more if over limit
782
+ sources1 = chunk_sources(sources1, chunk_size=chunk_size)
783
+ sources.extend(sources1)
784
+ if langchain_mode in ['All', 'UserData']:
785
+ if user_path:
786
+ # chunk internally for speed over multiple docs
787
+ sources1 = path_to_docs(user_path, n_jobs=n_jobs, chunk=chunk, chunk_size=chunk_size)
788
+ sources.extend(sources1)
789
+ else:
790
+ print("Chose UserData but user_path is empty/None", flush=True)
791
+ if False and langchain_mode in ['urls', 'All', "'All'"]:
792
+ # from langchain.document_loaders import UnstructuredURLLoader
793
+ # loader = UnstructuredURLLoader(urls=urls)
794
+ urls = ["https://www.birdsongsf.com/who-we-are/"]
795
+ from langchain.document_loaders import PlaywrightURLLoader
796
+ loader = PlaywrightURLLoader(urls=urls, remove_selectors=["header", "footer"])
797
+ sources1 = loader.load()
798
+ sources.extend(sources1)
799
+ if not sources:
800
+ print("langchain_mode %s has no sources, not making db" % langchain_mode, flush=True)
801
+ return None
802
+ print("Generating db", flush=True)
803
+ db = get_db(sources, use_openai_embedding=use_openai_embedding, db_type=db_type,
804
+ persist_directory=persist_directory, langchain_mode=langchain_mode,
805
+ hf_embedding_model=hf_embedding_model)
806
+ print("Generated db", flush=True)
807
+ return db
808
+
809
+
810
+ source_prefix = "Sources [Score | Link]:"
811
+ source_postfix = "End Sources<p>"
812
+
813
+
814
+ def run_qa_db(**kwargs):
815
+ func_names = list(inspect.signature(_run_qa_db).parameters)
816
+ # hard-coded defaults
817
+ kwargs['answer_with_sources'] = True
818
+ kwargs['sanitize_bot_response'] = True
819
+ kwargs['show_rank'] = False
820
+ missing_kwargs = [x for x in func_names if x not in kwargs]
821
+ assert not missing_kwargs, "Missing kwargs: %s" % missing_kwargs
822
+ # only keep actual used
823
+ kwargs = {k: v for k, v in kwargs.items() if k in func_names}
824
+ return _run_qa_db(**kwargs)
825
+
826
+
827
+ def _run_qa_db(query=None,
828
+ use_openai_model=False, use_openai_embedding=False,
829
+ first_para=False, text_limit=None, k=4, chunk=False, chunk_size=1024,
830
+ user_path=None,
831
+ db_type='faiss',
832
+ model_name=None, model=None, tokenizer=None,
833
+ hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
834
+ stream_output=False,
835
+ prompter=None,
836
+ prompt_type=None,
837
+ answer_with_sources=True,
838
+ cut_distanct=1.1,
839
+ sanitize_bot_response=True,
840
+ show_rank=False,
841
+ load_db_if_exists=False,
842
+ db=None,
843
+ max_new_tokens=256,
844
+ temperature=0.1,
845
+ repetition_penalty=1.0,
846
+ top_k=40,
847
+ top_p=0.7,
848
+ langchain_mode=None,
849
+ n_jobs=-1):
850
+ """
851
+
852
+ :param query:
853
+ :param use_openai_model:
854
+ :param use_openai_embedding:
855
+ :param first_para:
856
+ :param text_limit:
857
+ :param k:
858
+ :param chunk:
859
+ :param chunk_size:
860
+ :param user_path: user path to glob recursively from
861
+ :param db_type: 'faiss' for in-memory db or 'chroma' for persistent db
862
+ :param model_name: model name, used to switch behaviors
863
+ :param model: pre-initialized model, else will make new one
864
+ :param tokenizer: pre-initialized tokenizer, else will make new one. Required not None if model is not None
865
+ :param answer_with_sources
866
+ :return:
867
+ """
868
+
869
+ # FIXME: For All just go over all dbs instead of a separate db for All
870
+ db = make_db(**locals())
871
+ prompt_type = prompter.prompt_type if prompter is not None else prompt_type
872
+ llm, model_name, streamer, prompt_type_out = get_llm(use_openai_model=use_openai_model, model_name=model_name,
873
+ model=model, tokenizer=tokenizer,
874
+ stream_output=stream_output,
875
+ max_new_tokens=max_new_tokens,
876
+ temperature=temperature,
877
+ repetition_penalty=repetition_penalty,
878
+ top_k=top_k,
879
+ top_p=top_p,
880
+ prompt_type=prompt_type,
881
+ )
882
+
883
+ if model_name in ['llama', 'gptj']:
884
+ # FIXME: for now, streams to stdout/stderr currently
885
+ stream_output = False
886
+
887
+ if not use_openai_model and prompt_type not in ['plain'] or model_name in ['llama', 'gptj']:
888
+ # instruct-like, rather than few-shot prompt_type='plain' as default
889
+ # but then sources confuse the model with how inserted among rest of text, so avoid
890
+ prefix = ""
891
+ if langchain_mode in ['Disabled', 'ChatLLM', 'LLM']:
892
+ use_context = False
893
+ template = """%s{context}{question}""" % prefix
894
+ else:
895
+ use_context = True
896
+ template = """%s
897
+ ==
898
+ {context}
899
+ ==
900
+ {question}""" % prefix
901
+ prompt = PromptTemplate(
902
+ # input_variables=["summaries", "question"],
903
+ input_variables=["context", "question"],
904
+ template=template,
905
+ )
906
+ chain = load_qa_chain(llm, prompt=prompt)
907
+ else:
908
+ chain = load_qa_with_sources_chain(llm)
909
+ use_context = True
910
+
911
+ if query is None:
912
+ query = "What are the main differences between Linux and Windows?"
913
+ # https://github.com/hwchase17/langchain/issues/1946
914
+ # FIXME: Seems to way to get size of chroma db to limit k to avoid
915
+ # Chroma collection MyData contains fewer than 4 elements.
916
+ # type logger error
917
+ k_db = 1000 if db_type == 'chroma' else k # k=100 works ok too for
918
+
919
+ if db and use_context:
920
+ docs_with_score = db.similarity_search_with_score(query, k=k_db)[:k]
921
+ # cut off so no high distance docs/sources considered
922
+ docs = [x[0] for x in docs_with_score if x[1] < cut_distanct]
923
+ scores = [x[1] for x in docs_with_score if x[1] < cut_distanct]
924
+ if len(scores) > 0:
925
+ print("Distance: min: %s max: %s mean: %s median: %s" %
926
+ (scores[0], scores[-1], np.mean(scores), np.median(scores)), flush=True)
927
+ else:
928
+ docs = []
929
+ scores = []
930
+
931
+ if not docs and use_context:
932
+ return None
933
+
934
+ common_words_file = "data/NGSL_1.2_stats.csv.zip"
935
+ if os.path.isfile(common_words_file):
936
+ df = pd.read_csv("data/NGSL_1.2_stats.csv.zip")
937
+ import string
938
+ reduced_query = query.translate(str.maketrans(string.punctuation, ' ' * len(string.punctuation))).strip()
939
+ reduced_query_words = reduced_query.split(' ')
940
+ set_common = set(df['Lemma'].values.tolist())
941
+ num_common = len([x.lower() in set_common for x in reduced_query_words])
942
+ frac_common = num_common / len(reduced_query)
943
+ # FIXME: report to user bad query that uses too many common words
944
+ print("frac_common: %s" % frac_common, flush=True)
945
+
946
+ if langchain_mode in ['Disabled', 'ChatLLM', 'LLM']:
947
+ chain_kwargs = dict(input_documents=[], question=query)
948
+ else:
949
+ chain_kwargs = dict(input_documents=docs, question=query)
950
+
951
+ if stream_output:
952
+ answer = None
953
+ assert streamer is not None
954
+ target = wrapped_partial(chain, chain_kwargs)
955
+ import queue
956
+ bucket = queue.Queue()
957
+ thread = EThread(target=target, streamer=streamer, bucket=bucket)
958
+ thread.start()
959
+ outputs = ""
960
+ prompt = None # FIXME
961
+ try:
962
+ for new_text in streamer:
963
+ # print("new_text: %s" % new_text, flush=True)
964
+ if bucket.qsize() > 0 or thread.exc:
965
+ thread.join()
966
+ outputs += new_text
967
+ if prompter: # and False: # FIXME: pipeline can already use prompter
968
+ output1 = prompter.get_response(outputs, prompt=prompt,
969
+ sanitize_bot_response=sanitize_bot_response)
970
+ yield output1
971
+ else:
972
+ yield outputs
973
+ except BaseException:
974
+ # if any exception, raise that exception if was from thread, first
975
+ if thread.exc:
976
+ raise thread.exc
977
+ raise
978
+ finally:
979
+ # in case no exception and didn't join with thread yet, then join
980
+ if not thread.exc:
981
+ answer = thread.join()
982
+ # in case raise StopIteration or broke queue loop in streamer, but still have exception
983
+ if thread.exc:
984
+ raise thread.exc
985
+ # FIXME: answer is not string outputs from streamer. How to get actual final output?
986
+ # answer = outputs
987
+ else:
988
+ answer = chain(chain_kwargs)
989
+
990
+ if not use_context:
991
+ ret = answer['output_text']
992
+ yield ret
993
+ elif answer is not None:
994
+ print("query: %s" % query, flush=True)
995
+ print("answer: %s" % answer['output_text'], flush=True)
996
+ # link
997
+ answer_sources = [(max(0.0, 1.5 - score) / 1.5, get_url(doc)) for score, doc in
998
+ zip(scores, answer['input_documents'])]
999
+ answer_sources_dict = defaultdict(list)
1000
+ [answer_sources_dict[url].append(score) for score, url in answer_sources]
1001
+ answers_dict = {}
1002
+ for url, scores_url in answer_sources_dict.items():
1003
+ answers_dict[url] = np.max(scores_url)
1004
+ answer_sources = [(score, url) for url, score in answers_dict.items()]
1005
+ answer_sources.sort(key=lambda x: x[0], reverse=True)
1006
+ if show_rank:
1007
+ # answer_sources = ['%d | %s' % (1 + rank, url) for rank, (score, url) in enumerate(answer_sources)]
1008
+ # sorted_sources_urls = "Sources [Rank | Link]:<br>" + "<br>".join(answer_sources)
1009
+ answer_sources = ['%s' % url for rank, (score, url) in enumerate(answer_sources)]
1010
+ sorted_sources_urls = "Ranked Sources:<br>" + "<br>".join(answer_sources)
1011
+ else:
1012
+ answer_sources = ['<li>%.2g | %s</li>' % (score, url) for score, url in answer_sources]
1013
+ sorted_sources_urls = f"{source_prefix}<p><ul>" + "<p>".join(answer_sources)
1014
+ sorted_sources_urls += f"</ul></p>{source_postfix}"
1015
+
1016
+ if not answer['output_text'].endswith('\n'):
1017
+ answer['output_text'] += '\n'
1018
+
1019
+ if answer_with_sources:
1020
+ ret = answer['output_text'] + '\n' + sorted_sources_urls
1021
+ else:
1022
+ ret = answer['output_text']
1023
+
1024
+ yield ret
1025
+ return
1026
+
1027
+
1028
+ def chunk_sources(sources, chunk_size=1024):
1029
+ source_chunks = []
1030
+ # Below for known separator
1031
+ # splitter = CharacterTextSplitter(separator=" ", chunk_size=chunk_size, chunk_overlap=0)
1032
+ splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=0)
1033
+ for source in sources:
1034
+ # print(source.metadata['source'], flush=True)
1035
+ for chunky in splitter.split_text(source.page_content):
1036
+ source_chunks.append(Document(page_content=chunky, metadata=source.metadata))
1037
+ return source_chunks
1038
+
1039
+
1040
+ def get_db_from_hf(dest=".", db_dir='db_dir_DriverlessAI_docs.zip'):
1041
+ from huggingface_hub import hf_hub_download
1042
+ # True for case when locally already logged in with correct token, so don't have to set key
1043
+ token = os.getenv('HUGGINGFACE_API_TOKEN', True)
1044
+ path_to_zip_file = hf_hub_download('h2oai/db_dirs', db_dir, token=token, repo_type='dataset')
1045
+ import zipfile
1046
+ with zipfile.ZipFile(path_to_zip_file, 'r') as zip_ref:
1047
+ zip_ref.extractall(dest)
1048
+ return path_to_zip_file
1049
+
1050
+
1051
+ # Note dir has space in some cases, while zip does not
1052
+ some_db_zips = [['db_dir_DriverlessAI_docs.zip', 'db_dir_DriverlessAI docs', 'CC-BY-NC license'],
1053
+ ['db_dir_UserData.zip', 'db_dir_UserData', 'CC-BY license for ArXiv'],
1054
+ ['db_dir_github_h2oGPT.zip', 'db_dir_github h2oGPT', 'ApacheV2 license'],
1055
+ ['db_dir_wiki.zip', 'db_dir_wiki', 'CC-BY-SA Wikipedia license'],
1056
+ # ['db_dir_wiki_full.zip', 'db_dir_wiki_full.zip', '23GB, 05/04/2023 CC-BY-SA Wiki license'],
1057
+ ]
1058
+
1059
+ all_db_zips = some_db_zips + \
1060
+ [['db_dir_wiki_full.zip', 'db_dir_wiki_full.zip', '23GB, 05/04/2023 CC-BY-SA Wiki license'],
1061
+ ]
1062
+
1063
+
1064
+ def get_some_dbs_from_hf(dest='.', db_zips=None):
1065
+ if db_zips is None:
1066
+ db_zips = some_db_zips
1067
+ for db_dir, dir_expected, license1 in db_zips:
1068
+ path_to_zip_file = get_db_from_hf(dest=dest, db_dir=db_dir)
1069
+ assert os.path.isfile(path_to_zip_file), "Missing zip in %s" % path_to_zip_file
1070
+ if dir_expected:
1071
+ assert os.path.isdir(os.path.join(dest, dir_expected)), "Missing path for %s" % dir_expected
1072
+ assert os.path.isdir(os.path.join(dest, dir_expected, 'index')), "Missing index in %s" % dir_expected
1073
+
1074
+
1075
+ if __name__ == '__main__':
1076
+ pass
gradio_runner.py CHANGED
@@ -1,15 +1,23 @@
1
  import copy
2
  import functools
3
  import inspect
 
4
  import os
 
5
  import sys
 
 
 
 
 
6
 
7
  from gradio_themes import H2oTheme, SoftTheme, get_h2o_title, get_simple_title, get_dark_js
8
- from prompter import Prompter
 
9
  from utils import get_githash, flatten_list, zip_data, s3up, clear_torch_cache, get_torch_allocated, system_info_print, \
10
- ping
11
- from finetune import prompt_type_to_model_name, prompt_types_strings, generate_prompt, inv_prompt_type_to_model_lower
12
- from generate import get_model, languages_covered, evaluate, eval_func_param_names, score_qa
13
 
14
  import gradio as gr
15
  from apscheduler.schedulers.background import BackgroundScheduler
@@ -25,6 +33,21 @@ def go_gradio(**kwargs):
25
  model_state0 = kwargs['model_state0']
26
  score_model_state0 = kwargs['score_model_state0']
27
  queue = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  # easy update of kwargs needed for evaluate() etc.
30
  kwargs.update(locals())
@@ -42,6 +65,9 @@ def go_gradio(**kwargs):
42
  title = 'h2oGPT'
43
  if 'h2ogpt-research' in kwargs['base_model']:
44
  title += " [Research demonstration]"
 
 
 
45
  if kwargs['verbose']:
46
  description = f"""Model {kwargs['base_model']} Instruct dataset.
47
  For more information, visit our GitHub pages: [h2oGPT](https://github.com/h2oai/h2ogpt) and [H2O LLM Studio](https://github.com/h2oai/h2o-llmstudio).
@@ -49,9 +75,11 @@ def go_gradio(**kwargs):
49
  Hash: {get_githash()}
50
  """
51
  else:
52
- description = "For more information, visit our GitHub pages: [h2oGPT](https://github.com/h2oai/h2ogpt) and [H2O LLM Studio](https://github.com/h2oai/h2o-llmstudio)<br>"
53
  description += "If this host is busy, try [12B](https://gpt.h2o.ai), [30B](http://gpt2.h2o.ai), [HF Spaces1 12B](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot) or [HF Spaces2 12B](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot2)<br>"
54
  description += """<p>By using h2oGPT, you accept our [Terms of Service](https://github.com/h2oai/h2ogpt/blob/main/tos.md)</p>"""
 
 
55
 
56
  if kwargs['verbose']:
57
  task_info_md = f"""
@@ -66,6 +94,9 @@ def go_gradio(**kwargs):
66
  """
67
  else:
68
  css_code = """footer {visibility: hidden}"""
 
 
 
69
 
70
  if kwargs['gradio_avoid_processing_markdown']:
71
  from gradio_client import utils as client_utils
@@ -134,6 +165,8 @@ def go_gradio(**kwargs):
134
  model_state2 = gr.State([None, None, None, None])
135
  model_options_state = gr.State([model_options])
136
  lora_options_state = gr.State([lora_options])
 
 
137
  gr.Markdown(f"""
138
  {get_h2o_title(title) if kwargs['h2ocolors'] else get_simple_title(title)}
139
 
@@ -142,7 +175,7 @@ def go_gradio(**kwargs):
142
  """)
143
  if is_hf:
144
  gr.HTML(
145
- '''<center><a href="https://huggingface.co/spaces/h2oai/h2ogpt-chatbot?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>Duplicate this Space to skip the queue and run in a private space</center>''')
146
 
147
  # go button visible if
148
  base_wanted = kwargs['base_model'] != no_model_str and kwargs['login_mode_if_model0']
@@ -153,7 +186,7 @@ def go_gradio(**kwargs):
153
  with gr.Row():
154
  col_nochat = gr.Column(visible=not kwargs['chat'])
155
  with col_nochat: # FIXME: for model comparison, and check rest
156
- text_output_nochat = gr.Textbox(lines=5, label=output_label0)
157
  instruction_nochat = gr.Textbox(
158
  lines=kwargs['input_lines'],
159
  label=instruction_label_nochat,
@@ -187,7 +220,7 @@ def go_gradio(**kwargs):
187
  submit = gr.Button(value='Submit').style(full_width=False, size='sm')
188
  stop_btn = gr.Button(value="Stop").style(full_width=False, size='sm')
189
  with gr.Row():
190
- clear = gr.Button("New Conversation")
191
  flag_btn = gr.Button("Flag")
192
  if not kwargs['auto_score']: # FIXME: For checkbox model2
193
  with gr.Column(visible=kwargs['score_model']):
@@ -206,7 +239,7 @@ def go_gradio(**kwargs):
206
  score_text2 = gr.Textbox("Response Score2: NA", show_label=False, visible=False)
207
  retry = gr.Button("Regenerate")
208
  undo = gr.Button("Undo")
209
- with gr.TabItem("Input/Output"):
210
  with gr.Row():
211
  if 'mbart-' in kwargs['model_lower']:
212
  src_lang = gr.Dropdown(list(languages_covered().keys()),
@@ -215,6 +248,122 @@ def go_gradio(**kwargs):
215
  tgt_lang = gr.Dropdown(list(languages_covered().keys()),
216
  value=kwargs['tgt_lang'],
217
  label="Output Language")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
  with gr.TabItem("Expert"):
219
  with gr.Row():
220
  with gr.Column():
@@ -243,7 +392,7 @@ def go_gradio(**kwargs):
243
  )
244
  # FIXME: https://github.com/h2oai/h2ogpt/issues/106
245
  if os.getenv('TESTINGFAIL'):
246
- max_beams = 8 if not (is_low_mem or is_public) else 1
247
  else:
248
  max_beams = 1
249
  num_beams = gr.Slider(minimum=1, maximum=max_beams, step=1,
@@ -356,12 +505,13 @@ def go_gradio(**kwargs):
356
  with gr.Column():
357
  with gr.Row():
358
  system_btn = gr.Button(value='Get System Info')
359
- system_text = gr.Textbox(label='System Info', interactive=False)
 
360
 
361
  with gr.Row():
362
  zip_btn = gr.Button("Zip")
363
  zip_text = gr.Textbox(label="Zip file name", interactive=False)
364
- file_output = gr.File()
365
  with gr.Row():
366
  s3up_btn = gr.Button("S3UP")
367
  s3up_text = gr.Textbox(label='S3UP result', interactive=False)
@@ -378,8 +528,103 @@ def go_gradio(**kwargs):
378
 
379
  # Get flagged data
380
  zip_data1 = functools.partial(zip_data, root_dirs=['flagged_data_points', kwargs['save_dir']])
381
- zip_btn.click(zip_data1, inputs=None, outputs=[file_output, zip_text], queue=False)
382
- s3up_btn.click(s3up, inputs=zip_text, outputs=s3up_text, queue=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
383
 
384
  def check_admin_pass(x):
385
  return gr.update(visible=x == admin_pass)
@@ -569,49 +814,66 @@ def go_gradio(**kwargs):
569
  """
570
  # don't deepcopy, can contain model itself
571
  args_list = list(args).copy()
572
- history = args_list[-1] # model_state is -2
 
 
 
 
 
573
  if retry and history:
574
  history.pop()
 
 
 
575
  if not history:
576
  print("No history", flush=True)
577
  history = [['', None]]
578
  yield history, ''
579
  return
580
  # ensure output will be unique to models
 
581
  history = copy.deepcopy(history)
582
  instruction1 = history[-1][0]
583
  context1 = ''
584
- if kwargs['chat_history'] > 0:
585
- prompt_type_arg_id = eval_func_param_names.index('prompt_type')
586
- prompt_type1 = args_list[prompt_type_arg_id]
587
- chat_arg_id = eval_func_param_names.index('chat')
588
- chat1 = args_list[chat_arg_id]
589
  context1 = ''
590
- for histi in range(len(history) - 1):
 
591
  data_point = dict(instruction=history[histi][0], input='', output=history[histi][1])
592
  prompt, pre_response, terminate_response, chat_sep = generate_prompt(data_point, prompt_type1,
593
  chat1, reduced=True)
594
- # md -> back to text, maybe not super improtant if model trained enough
 
 
 
 
 
 
 
595
  prompt = prompt.replace('<br>', chat_sep)
596
- context1 += prompt
597
- if not context1.endswith(chat_sep):
598
- context1 += chat_sep
 
 
 
 
599
 
600
  _, pre_response, terminate_response, chat_sep = generate_prompt({}, prompt_type1, chat1,
601
  reduced=True)
602
  if context1 and not context1.endswith(chat_sep):
603
  context1 += chat_sep # ensure if terminates abruptly, then human continues on next line
604
  args_list[0] = instruction1 # override original instruction with history from user
605
- # only include desired chat history
606
- args_list[2] = context1[-kwargs['chat_history']:]
607
- model_state1 = args_list[-2]
608
  if model_state1[0] is None or model_state1[0] == no_model_str:
609
  history = [['', None]]
610
  yield history, ''
611
  return
612
- args_list = args_list[:-2]
613
  fun1 = partial(evaluate,
614
  model_state1,
 
615
  **kwargs_evaluate)
616
  try:
617
  for output in fun1(*tuple(args_list)):
@@ -645,11 +907,11 @@ def go_gradio(**kwargs):
645
  outputs=text_output,
646
  )
647
  bot_args = dict(fn=bot,
648
- inputs=inputs_list + [model_state] + [text_output],
649
  outputs=[text_output, exception_text],
650
  )
651
  retry_bot_args = dict(fn=functools.partial(bot, retry=True),
652
- inputs=inputs_list + [model_state] + [text_output],
653
  outputs=[text_output, exception_text],
654
  )
655
  undo_user_args = dict(fn=functools.partial(user, undo=True),
@@ -663,11 +925,11 @@ def go_gradio(**kwargs):
663
  outputs=text_output2,
664
  )
665
  bot_args2 = dict(fn=bot,
666
- inputs=inputs_list + [model_state2] + [text_output2],
667
  outputs=[text_output2, exception_text],
668
  )
669
  retry_bot_args2 = dict(fn=functools.partial(bot, retry=True),
670
- inputs=inputs_list + [model_state2] + [text_output2],
671
  outputs=[text_output2, exception_text],
672
  )
673
  undo_user_args2 = dict(fn=functools.partial(user, undo=True),
@@ -694,7 +956,8 @@ def go_gradio(**kwargs):
694
  .then(clear_instruct, None, iinput)
695
  submit_event1d = submit_event1c.then(**bot_args, api_name='instruction_bot' if allow_api else None,
696
  queue=queue)
697
- submit_event1e = submit_event1d.then(**score_args_submit, api_name='instruction_bot_score' if allow_api else None,
 
698
  queue=queue)
699
  submit_event1f = submit_event1e.then(**bot_args2, api_name='instruction_bot2' if allow_api else None,
700
  queue=queue)
@@ -735,12 +998,134 @@ def go_gradio(**kwargs):
735
  .then(**score_args_submit, api_name='undo_score' if allow_api else None) \
736
  .then(**score_args2_submit, api_name='undo_score2' if allow_api else None)
737
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
738
  # does both models
739
- clear.click(lambda: None, None, text_output, queue=False, api_name='clear' if allow_api else None) \
740
- .then(lambda: None, None, text_output2, queue=False, api_name='clear2' if allow_api else None)
 
 
 
 
741
  # NOTE: clear of instruction/iinput for nochat has to come after score,
742
  # because score for nochat consumes actual textbox, while chat consumes chat history filled by user()
743
- submit_event_nochat = submit_nochat.click(fun, inputs=[model_state] + inputs_list,
 
744
  outputs=text_output_nochat,
745
  queue=queue,
746
  api_name='submit_nochat' if allow_api else None) \
@@ -842,8 +1227,8 @@ def go_gradio(**kwargs):
842
  new_state = [list0[0] + [x]]
843
  new_options = [*new_state[0]]
844
  return gr.Dropdown.update(value=x, choices=new_options), \
845
- gr.Dropdown.update(value=x, choices=new_options), \
846
- '', new_state
847
 
848
  add_model_event = add_model_button.click(fn=dropdown_model_list,
849
  inputs=[model_options_state, new_model],
@@ -857,8 +1242,8 @@ def go_gradio(**kwargs):
857
  x1 = x if model_used1 == no_model_str else lora_used1
858
  x2 = x if model_used2 == no_model_str else lora_used2
859
  return gr.Dropdown.update(value=x1, choices=new_options), \
860
- gr.Dropdown.update(value=x2, choices=new_options), \
861
- '', new_state
862
 
863
  add_lora_event = add_lora_button.click(fn=dropdown_lora_list,
864
  inputs=[lora_options_state, new_lora, model_used, lora_used, model_used2,
@@ -916,10 +1301,20 @@ def go_gradio(**kwargs):
916
 
917
  scheduler = BackgroundScheduler()
918
  scheduler.add_job(func=clear_torch_cache, trigger="interval", seconds=20)
919
- if is_public:
 
 
 
920
  scheduler.add_job(func=ping, trigger="interval", seconds=60)
921
  scheduler.start()
922
 
 
 
 
 
 
 
 
923
  demo.launch(share=kwargs['share'], server_name="0.0.0.0", show_error=True,
924
  favicon_path=favicon_path, prevent_thread_lock=True,
925
  auth=kwargs['auth'])
@@ -928,9 +1323,7 @@ def go_gradio(**kwargs):
928
  demo.block_thread()
929
 
930
 
931
- input_args_list = ['model_state']
932
- inputs_kwargs_list = ['debug', 'save_dir', 'sanitize_bot_response', 'model_state0', 'is_low_mem',
933
- 'raise_generate_gpu_exceptions', 'chat_context', 'concurrency_count', 'lora_weights']
934
 
935
 
936
  def get_inputs_list(inputs_dict, model_lower):
@@ -946,9 +1339,204 @@ def get_inputs_list(inputs_dict, model_lower):
946
  if k == 'kwargs':
947
  continue
948
  if k in input_args_list + inputs_kwargs_list:
949
- # these are added via partial, not taken as input
950
  continue
951
  if 'mbart-' not in model_lower and k in ['src_lang', 'tgt_lang']:
952
  continue
953
  inputs_list.append(inputs_dict[k])
954
  return inputs_list
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import copy
2
  import functools
3
  import inspect
4
+ import json
5
  import os
6
+ import random
7
  import sys
8
+ import traceback
9
+ import uuid
10
+ import filelock
11
+ import pandas as pd
12
+ import tabulate
13
 
14
  from gradio_themes import H2oTheme, SoftTheme, get_h2o_title, get_simple_title, get_dark_js
15
+ from prompter import Prompter, \
16
+ prompt_type_to_model_name, prompt_types_strings, inv_prompt_type_to_model_lower, generate_prompt
17
  from utils import get_githash, flatten_list, zip_data, s3up, clear_torch_cache, get_torch_allocated, system_info_print, \
18
+ ping, get_short_name, get_url, makedirs
19
+ from generate import get_model, languages_covered, evaluate, eval_func_param_names, score_qa, langchain_modes, \
20
+ inputs_kwargs_list, get_cutoffs, scratch_base_dir
21
 
22
  import gradio as gr
23
  from apscheduler.schedulers.background import BackgroundScheduler
 
33
  model_state0 = kwargs['model_state0']
34
  score_model_state0 = kwargs['score_model_state0']
35
  queue = True
36
+ dbs = kwargs['dbs']
37
+ db_type = kwargs['db_type']
38
+ visible_langchain_modes = kwargs['visible_langchain_modes']
39
+ allow_upload_to_user_data = kwargs['allow_upload_to_user_data']
40
+ allow_upload_to_my_data = kwargs['allow_upload_to_my_data']
41
+ enable_sources_list = kwargs['enable_sources_list']
42
+ enable_url_upload = kwargs['enable_url_upload']
43
+ enable_text_upload = kwargs['enable_text_upload']
44
+ allow_upload = allow_upload_to_user_data or allow_upload_to_my_data
45
+ use_openai_embedding = kwargs['use_openai_embedding']
46
+ hf_embedding_model = kwargs['hf_embedding_model']
47
+ enable_captions = kwargs['enable_captions']
48
+ captions_model = kwargs['captions_model']
49
+ enable_ocr = kwargs['enable_ocr']
50
+ caption_loader = kwargs['caption_loader']
51
 
52
  # easy update of kwargs needed for evaluate() etc.
53
  kwargs.update(locals())
 
65
  title = 'h2oGPT'
66
  if 'h2ogpt-research' in kwargs['base_model']:
67
  title += " [Research demonstration]"
68
+ more_info = """For more information, visit our GitHub pages: [h2oGPT](https://github.com/h2oai/h2ogpt) and [H2O-LLMStudio](https://github.com/h2oai/h2o-llmstudio)<br>"""
69
+ if is_public:
70
+ more_info += """<iframe src="https://ghbtns.com/github-btn.html?user=h2oai&repo=h2ogpt&type=star&count=true&size=small" frameborder="0" scrolling="0" width="150" height="20" title="GitHub"></iframe>"""
71
  if kwargs['verbose']:
72
  description = f"""Model {kwargs['base_model']} Instruct dataset.
73
  For more information, visit our GitHub pages: [h2oGPT](https://github.com/h2oai/h2ogpt) and [H2O LLM Studio](https://github.com/h2oai/h2o-llmstudio).
 
75
  Hash: {get_githash()}
76
  """
77
  else:
78
+ description = more_info
79
  description += "If this host is busy, try [12B](https://gpt.h2o.ai), [30B](http://gpt2.h2o.ai), [HF Spaces1 12B](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot) or [HF Spaces2 12B](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot2)<br>"
80
  description += """<p>By using h2oGPT, you accept our [Terms of Service](https://github.com/h2oai/h2ogpt/blob/main/tos.md)</p>"""
81
+ if is_hf:
82
+ description += '''<a href="https://huggingface.co/spaces/h2oai/h2ogpt-chatbot?duplicate=true"><img src="https://bit.ly/3gLdBN6" style="white-space: nowrap" alt="Duplicate Space"></a>'''
83
 
84
  if kwargs['verbose']:
85
  task_info_md = f"""
 
94
  """
95
  else:
96
  css_code = """footer {visibility: hidden}"""
97
+ css_code += """
98
+ body.dark{#warning {background-color: #555555};}
99
+ """
100
 
101
  if kwargs['gradio_avoid_processing_markdown']:
102
  from gradio_client import utils as client_utils
 
165
  model_state2 = gr.State([None, None, None, None])
166
  model_options_state = gr.State([model_options])
167
  lora_options_state = gr.State([lora_options])
168
+ my_db_state = gr.State([None, None])
169
+ chat_state = gr.State({})
170
  gr.Markdown(f"""
171
  {get_h2o_title(title) if kwargs['h2ocolors'] else get_simple_title(title)}
172
 
 
175
  """)
176
  if is_hf:
177
  gr.HTML(
178
+ )
179
 
180
  # go button visible if
181
  base_wanted = kwargs['base_model'] != no_model_str and kwargs['login_mode_if_model0']
 
186
  with gr.Row():
187
  col_nochat = gr.Column(visible=not kwargs['chat'])
188
  with col_nochat: # FIXME: for model comparison, and check rest
189
+ text_output_nochat = gr.Textbox(lines=5, label=output_label0).style(show_copy_button=True)
190
  instruction_nochat = gr.Textbox(
191
  lines=kwargs['input_lines'],
192
  label=instruction_label_nochat,
 
220
  submit = gr.Button(value='Submit').style(full_width=False, size='sm')
221
  stop_btn = gr.Button(value="Stop").style(full_width=False, size='sm')
222
  with gr.Row():
223
+ clear = gr.Button("Save, New Conversation")
224
  flag_btn = gr.Button("Flag")
225
  if not kwargs['auto_score']: # FIXME: For checkbox model2
226
  with gr.Column(visible=kwargs['score_model']):
 
239
  score_text2 = gr.Textbox("Response Score2: NA", show_label=False, visible=False)
240
  retry = gr.Button("Regenerate")
241
  undo = gr.Button("Undo")
242
+ with gr.TabItem("Chat"):
243
  with gr.Row():
244
  if 'mbart-' in kwargs['model_lower']:
245
  src_lang = gr.Dropdown(list(languages_covered().keys()),
 
248
  tgt_lang = gr.Dropdown(list(languages_covered().keys()),
249
  value=kwargs['tgt_lang'],
250
  label="Output Language")
251
+ radio_chats = gr.Radio(value=None, label="Saved Chats", visible=True, interactive=True,
252
+ type='value')
253
+ with gr.Row():
254
+ remove_chat_btn = gr.Button(value="Remove Selected Chat", visible=True)
255
+ clear_chat_btn = gr.Button(value="Clear Chat", visible=True)
256
+ chats_row = gr.Row(visible=True).style(equal_height=False)
257
+ with chats_row:
258
+ export_chats_btn = gr.Button(value="Export Chats")
259
+ chats_file = gr.File(interactive=False, label="Download File")
260
+ chats_row2 = gr.Row(visible=True).style(equal_height=False)
261
+ with chats_row2:
262
+ chatsup_output = gr.File(label="Upload Chat File(s)",
263
+ file_types=['.json'],
264
+ file_count='multiple',
265
+ elem_id="warning", elem_classes="feedback")
266
+ add_to_chats_btn = gr.Button("Add File(s) to Chats")
267
+ with gr.TabItem("Data Source"):
268
+ langchain_readme = get_url('https://github.com/h2oai/h2ogpt/blob/main/README_LangChain.md',
269
+ from_str=True)
270
+ gr.HTML(value=f"""LangChain Support Disabled<p>
271
+ Run:<p>
272
+ <code>
273
+ python generate.py --langchain_mode=MyData
274
+ </code>
275
+ <p>
276
+ For more options see: {langchain_readme}""",
277
+ visible=kwargs['langchain_mode'] == 'Disabled', interactive=False)
278
+ data_row = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled')
279
+ with data_row:
280
+ if is_hf:
281
+ # don't show 'wiki' since only usually useful for internal testing at moment
282
+ no_show_modes = ['Disabled', 'wiki']
283
+ else:
284
+ no_show_modes = ['Disabled']
285
+ allowed_modes = visible_langchain_modes.copy()
286
+ allowed_modes = [x for x in allowed_modes if x in dbs]
287
+ allowed_modes += ['ChatLLM', 'LLM']
288
+ if allow_upload_to_my_data and 'MyData' not in allowed_modes:
289
+ allowed_modes += ['MyData']
290
+ if allow_upload_to_user_data and 'UserData' not in allowed_modes:
291
+ allowed_modes += ['UserData']
292
+ langchain_mode = gr.Radio(
293
+ [x for x in langchain_modes if x in allowed_modes and x not in no_show_modes],
294
+ value=kwargs['langchain_mode'],
295
+ label="Data Source",
296
+ visible=kwargs['langchain_mode'] != 'Disabled')
297
+
298
+ def upload_file(files, x):
299
+ file_paths = [file.name for file in files]
300
+ return files, file_paths
301
+
302
+ upload_row = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and allow_upload).style(
303
+ equal_height=False)
304
+ # import control
305
+ if kwargs['langchain_mode'] != 'Disabled':
306
+ from gpt_langchain import file_types, have_arxiv
307
+ else:
308
+ have_arxiv = False
309
+ file_types = []
310
+ with upload_row:
311
+ file_types_str = '[' + ' '.join(file_types) + ']'
312
+ fileup_output = gr.File(label=f'Upload {file_types_str}',
313
+ file_types=file_types,
314
+ file_count="multiple",
315
+ elem_id="warning", elem_classes="feedback")
316
+ with gr.Row():
317
+ upload_button = gr.UploadButton("Upload %s" % file_types_str,
318
+ file_types=file_types,
319
+ file_count="multiple",
320
+ visible=False,
321
+ )
322
+ # add not visible until upload something
323
+ with gr.Column():
324
+ add_to_shared_db_btn = gr.Button("Add File(s) to Shared UserData DB",
325
+ visible=allow_upload_to_user_data) # and False)
326
+ add_to_my_db_btn = gr.Button("Add File(s) to Scratch MyData DB",
327
+ visible=allow_upload_to_my_data) # and False)
328
+ url_row = gr.Row(
329
+ visible=kwargs['langchain_mode'] != 'Disabled' and allow_upload and enable_url_upload).style(
330
+ equal_height=False)
331
+ with url_row:
332
+ url_label = 'URL (http/https) or ArXiv:' if have_arxiv else 'URL (http/https)'
333
+ url_text = gr.Textbox(label=url_label, interactive=True)
334
+ with gr.Column():
335
+ url_user_btn = gr.Button(value='Add URL content to Shared UserData DB',
336
+ visible=allow_upload_to_user_data)
337
+ url_my_btn = gr.Button(value='Add URL content to Scratch MyData DB',
338
+ visible=allow_upload_to_my_data)
339
+ text_row = gr.Row(
340
+ visible=kwargs['langchain_mode'] != 'Disabled' and allow_upload and enable_text_upload).style(
341
+ equal_height=False)
342
+ with text_row:
343
+ user_text_text = gr.Textbox(label='Paste Text', interactive=True)
344
+ with gr.Column():
345
+ user_text_user_btn = gr.Button(value='Add Text to Shared UserData DB',
346
+ visible=allow_upload_to_user_data)
347
+ user_text_my_btn = gr.Button(value='Add Text to Scratch MyData DB',
348
+ visible=allow_upload_to_my_data)
349
+ # WIP:
350
+ with gr.Row(visible=False).style(equal_height=False):
351
+ github_textbox = gr.Textbox(label="Github URL")
352
+ with gr.Row(visible=True):
353
+ github_shared_btn = gr.Button(value="Add Github to Shared UserData DB",
354
+ visible=allow_upload_to_user_data)
355
+ github_my_btn = gr.Button(value="Add Github to Scratch MyData DB",
356
+ visible=allow_upload_to_my_data)
357
+ sources_row = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and enable_sources_list).style(
358
+ equal_height=False)
359
+ with sources_row:
360
+ sources_text = gr.HTML(label='Sources Added', interactive=False)
361
+ sources_row2 = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and enable_sources_list).style(
362
+ equal_height=False)
363
+ with sources_row2:
364
+ get_sources_btn = gr.Button(value="Get Sources List for Selected DB")
365
+ file_source = gr.File(interactive=False, label="Download File with list of Sources")
366
+
367
  with gr.TabItem("Expert"):
368
  with gr.Row():
369
  with gr.Column():
 
392
  )
393
  # FIXME: https://github.com/h2oai/h2ogpt/issues/106
394
  if os.getenv('TESTINGFAIL'):
395
+ max_beams = 8 if not (is_low_mem or is_public) else 1
396
  else:
397
  max_beams = 1
398
  num_beams = gr.Slider(minimum=1, maximum=max_beams, step=1,
 
505
  with gr.Column():
506
  with gr.Row():
507
  system_btn = gr.Button(value='Get System Info')
508
+ system_text = gr.Textbox(label='System Info', interactive=False).style(
509
+ show_copy_button=True)
510
 
511
  with gr.Row():
512
  zip_btn = gr.Button("Zip")
513
  zip_text = gr.Textbox(label="Zip file name", interactive=False)
514
+ file_output = gr.File(interactive=False)
515
  with gr.Row():
516
  s3up_btn = gr.Button("S3UP")
517
  s3up_text = gr.Textbox(label='S3UP result', interactive=False)
 
528
 
529
  # Get flagged data
530
  zip_data1 = functools.partial(zip_data, root_dirs=['flagged_data_points', kwargs['save_dir']])
531
+ zip_btn.click(zip_data1, inputs=None, outputs=[file_output, zip_text], queue=False,
532
+ api_name='zip_data' if allow_api else None)
533
+ s3up_btn.click(s3up, inputs=zip_text, outputs=s3up_text, queue=False,
534
+ api_name='s3up_data' if allow_api else None)
535
+
536
+ def make_add_visible(x):
537
+ return gr.update(visible=x is not None)
538
+
539
+ def clear_file_list():
540
+ return None
541
+
542
+ def make_invisible():
543
+ return gr.update(visible=False)
544
+
545
+ def make_visible():
546
+ return gr.update(visible=True)
547
+
548
+ # add itself to output to ensure shows working and can't click again
549
+ upload_button.upload(upload_file, inputs=[upload_button, fileup_output],
550
+ outputs=[upload_button, fileup_output], queue=queue,
551
+ api_name='upload_file' if allow_api else None) \
552
+ .then(make_add_visible, fileup_output, add_to_shared_db_btn, queue=queue) \
553
+ .then(make_add_visible, fileup_output, add_to_my_db_btn, queue=queue) \
554
+ .then(make_invisible, outputs=upload_button, queue=queue)
555
+
556
+ # Add to UserData
557
+ update_user_db_func = functools.partial(update_user_db, dbs=dbs, db_type=db_type, langchain_mode='UserData',
558
+ use_openai_embedding=use_openai_embedding,
559
+ hf_embedding_model=hf_embedding_model,
560
+ enable_captions=enable_captions,
561
+ captions_model=captions_model,
562
+ enable_ocr=enable_ocr,
563
+ caption_loader=caption_loader,
564
+ )
565
+
566
+ # note for update_user_db_func output is ignored for db
567
+ add_to_shared_db_btn.click(update_user_db_func,
568
+ inputs=[fileup_output, my_db_state, add_to_shared_db_btn, add_to_my_db_btn],
569
+ outputs=[add_to_shared_db_btn, add_to_my_db_btn, sources_text], queue=queue,
570
+ api_name='add_to_shared' if allow_api else None) \
571
+ .then(clear_file_list, outputs=fileup_output, queue=queue)
572
+
573
+ # .then(make_invisible, outputs=add_to_shared_db_btn, queue=queue)
574
+ # .then(make_visible, outputs=upload_button, queue=queue)
575
+
576
+ def clear_textbox():
577
+ return gr.Textbox.update(value='')
578
+
579
+ update_user_db_url_func = functools.partial(update_user_db_func, is_url=True)
580
+ url_user_btn.click(update_user_db_url_func,
581
+ inputs=[url_text, my_db_state, add_to_shared_db_btn, add_to_my_db_btn],
582
+ outputs=[add_to_shared_db_btn, add_to_my_db_btn, sources_text], queue=queue,
583
+ api_name='add_url_to_shared' if allow_api else None) \
584
+ .then(clear_textbox, outputs=url_text, queue=queue)
585
+
586
+ update_user_db_txt_func = functools.partial(update_user_db_func, is_txt=True)
587
+ user_text_user_btn.click(update_user_db_txt_func,
588
+ inputs=[user_text_text, my_db_state, add_to_shared_db_btn, add_to_my_db_btn],
589
+ outputs=[add_to_shared_db_btn, add_to_my_db_btn, sources_text], queue=queue,
590
+ api_name='add_text_to_shared' if allow_api else None) \
591
+ .then(clear_textbox, outputs=user_text_text, queue=queue)
592
+
593
+ # Add to MyData
594
+ update_my_db_func = functools.partial(update_user_db, dbs=dbs, db_type=db_type, langchain_mode='MyData',
595
+ use_openai_embedding=use_openai_embedding,
596
+ hf_embedding_model=hf_embedding_model,
597
+ enable_captions=enable_captions,
598
+ captions_model=captions_model,
599
+ enable_ocr=enable_ocr,
600
+ caption_loader=caption_loader,
601
+ )
602
+
603
+ add_to_my_db_btn.click(update_my_db_func,
604
+ inputs=[fileup_output, my_db_state, add_to_shared_db_btn, add_to_my_db_btn],
605
+ outputs=[my_db_state, add_to_shared_db_btn, add_to_my_db_btn, sources_text], queue=queue,
606
+ api_name='add_to_my' if allow_api else None) \
607
+ .then(clear_file_list, outputs=fileup_output, queue=queue)
608
+ # .then(make_invisible, outputs=add_to_shared_db_btn, queue=queue)
609
+ # .then(make_visible, outputs=upload_button, queue=queue)
610
+
611
+ update_my_db_url_func = functools.partial(update_my_db_func, is_url=True)
612
+ url_my_btn.click(update_my_db_url_func,
613
+ inputs=[url_text, my_db_state, add_to_shared_db_btn, add_to_my_db_btn],
614
+ outputs=[my_db_state, add_to_shared_db_btn, add_to_my_db_btn, sources_text], queue=queue,
615
+ api_name='add_url_to_my' if allow_api else None) \
616
+ .then(clear_textbox, outputs=url_text, queue=queue)
617
+
618
+ update_my_db_txt_func = functools.partial(update_my_db_func, is_txt=True)
619
+ user_text_my_btn.click(update_my_db_txt_func,
620
+ inputs=[user_text_text, my_db_state, add_to_shared_db_btn, add_to_my_db_btn],
621
+ outputs=[my_db_state, add_to_shared_db_btn, add_to_my_db_btn, sources_text], queue=queue,
622
+ api_name='add_txt_to_my' if allow_api else None) \
623
+ .then(clear_textbox, outputs=user_text_text, queue=queue)
624
+
625
+ get_sources1 = functools.partial(get_sources, dbs=dbs)
626
+ get_sources_btn.click(get_sources1, inputs=[my_db_state, langchain_mode], outputs=file_source, queue=queue,
627
+ api_name='get_sources' if allow_api else None)
628
 
629
  def check_admin_pass(x):
630
  return gr.update(visible=x == admin_pass)
 
814
  """
815
  # don't deepcopy, can contain model itself
816
  args_list = list(args).copy()
817
+ model_state1 = args_list[-3]
818
+ my_db_state1 = args_list[-2]
819
+ history = args_list[-1]
820
+
821
+ args_list = args_list[:-3] # only keep rest needed for evaluate()
822
+ langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
823
  if retry and history:
824
  history.pop()
825
+ if not args_list[eval_func_param_names.index('do_sample')]:
826
+ # if was not sampling, no point in retry unless change to sample
827
+ args_list[eval_func_param_names.index('do_sample')] = True
828
  if not history:
829
  print("No history", flush=True)
830
  history = [['', None]]
831
  yield history, ''
832
  return
833
  # ensure output will be unique to models
834
+ _, _, _, max_prompt_length = get_cutoffs(is_low_mem, for_context=True)
835
  history = copy.deepcopy(history)
836
  instruction1 = history[-1][0]
837
  context1 = ''
838
+ if max_prompt_length is not None and langchain_mode1 not in ['LLM']:
839
+ prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
840
+ chat1 = args_list[eval_func_param_names.index('chat')]
 
 
841
  context1 = ''
842
+ # - 1 below because current instruction already in history from user()
843
+ for histi in range(0, len(history) - 1):
844
  data_point = dict(instruction=history[histi][0], input='', output=history[histi][1])
845
  prompt, pre_response, terminate_response, chat_sep = generate_prompt(data_point, prompt_type1,
846
  chat1, reduced=True)
847
+ # md -> back to text, maybe not super important if model trained enough
848
+ if not kwargs['keep_sources_in_context']:
849
+ from gpt_langchain import source_prefix, source_postfix
850
+ import re
851
+ prompt = re.sub(f'{re.escape(source_prefix)}.*?{re.escape(source_postfix)}', '', prompt,
852
+ flags=re.DOTALL)
853
+ if prompt.endswith('\n<p>'):
854
+ prompt = prompt[:-4]
855
  prompt = prompt.replace('<br>', chat_sep)
856
+ if not prompt.endswith(chat_sep):
857
+ prompt += chat_sep
858
+ # most recent first, add older if can
859
+ # only include desired chat history
860
+ if len(prompt + context1) > max_prompt_length:
861
+ break
862
+ context1 = prompt + context1
863
 
864
  _, pre_response, terminate_response, chat_sep = generate_prompt({}, prompt_type1, chat1,
865
  reduced=True)
866
  if context1 and not context1.endswith(chat_sep):
867
  context1 += chat_sep # ensure if terminates abruptly, then human continues on next line
868
  args_list[0] = instruction1 # override original instruction with history from user
869
+ args_list[2] = context1
 
 
870
  if model_state1[0] is None or model_state1[0] == no_model_str:
871
  history = [['', None]]
872
  yield history, ''
873
  return
 
874
  fun1 = partial(evaluate,
875
  model_state1,
876
+ my_db_state1,
877
  **kwargs_evaluate)
878
  try:
879
  for output in fun1(*tuple(args_list)):
 
907
  outputs=text_output,
908
  )
909
  bot_args = dict(fn=bot,
910
+ inputs=inputs_list + [model_state, my_db_state] + [text_output],
911
  outputs=[text_output, exception_text],
912
  )
913
  retry_bot_args = dict(fn=functools.partial(bot, retry=True),
914
+ inputs=inputs_list + [model_state, my_db_state] + [text_output],
915
  outputs=[text_output, exception_text],
916
  )
917
  undo_user_args = dict(fn=functools.partial(user, undo=True),
 
925
  outputs=text_output2,
926
  )
927
  bot_args2 = dict(fn=bot,
928
+ inputs=inputs_list + [model_state2, my_db_state] + [text_output2],
929
  outputs=[text_output2, exception_text],
930
  )
931
  retry_bot_args2 = dict(fn=functools.partial(bot, retry=True),
932
+ inputs=inputs_list + [model_state2, my_db_state] + [text_output2],
933
  outputs=[text_output2, exception_text],
934
  )
935
  undo_user_args2 = dict(fn=functools.partial(user, undo=True),
 
956
  .then(clear_instruct, None, iinput)
957
  submit_event1d = submit_event1c.then(**bot_args, api_name='instruction_bot' if allow_api else None,
958
  queue=queue)
959
+ submit_event1e = submit_event1d.then(**score_args_submit,
960
+ api_name='instruction_bot_score' if allow_api else None,
961
  queue=queue)
962
  submit_event1f = submit_event1e.then(**bot_args2, api_name='instruction_bot2' if allow_api else None,
963
  queue=queue)
 
998
  .then(**score_args_submit, api_name='undo_score' if allow_api else None) \
999
  .then(**score_args2_submit, api_name='undo_score2' if allow_api else None)
1000
 
1001
+ # MANAGE CHATS
1002
+ def dedup(short_chat, short_chats):
1003
+ if short_chat not in short_chats:
1004
+ return short_chat
1005
+ for i in range(1, 1000):
1006
+ short_chat_try = short_chat + "_" + str(i)
1007
+ if short_chat_try not in short_chats:
1008
+ return short_chat_try
1009
+ # fallback and hope for best
1010
+ short_chat = short_chat + "_" + str(random.random())
1011
+ return short_chat
1012
+
1013
+ def get_short_chat(x, short_chats, short_len=20, words=4):
1014
+ if x and len(x[0]) == 2 and x[0][0] is not None:
1015
+ short_chat = ' '.join(x[0][0][:short_len].split(' ')[:words]).strip()
1016
+ short_chat = dedup(short_chat, short_chats)
1017
+ else:
1018
+ short_chat = None
1019
+ return short_chat
1020
+
1021
+ def is_chat_same(x, y):
1022
+ # <p> etc. added in chat, try to remove some of that to help avoid dup entries when hit new conversation
1023
+ is_same = True
1024
+ # length of conversation has to be same
1025
+ if len(x) != len(y):
1026
+ return False
1027
+ for stepx, stepy in zip(x, y):
1028
+ if len(stepx) != len(stepy):
1029
+ # something off with a conversation
1030
+ return False
1031
+ if len(stepx) != 2:
1032
+ # something off
1033
+ return False
1034
+ if len(stepy) != 2:
1035
+ # something off
1036
+ return False
1037
+ questionx = stepx[0].replace('<p>', '').replace('</p>', '')
1038
+ answerx = stepx[1].replace('<p>', '').replace('</p>', '')
1039
+
1040
+ questiony = stepy[0].replace('<p>', '').replace('</p>', '')
1041
+ answery = stepy[1].replace('<p>', '').replace('</p>', '')
1042
+
1043
+ if questionx != questiony or answerx != answery:
1044
+ return False
1045
+ return is_same
1046
+
1047
+ def save_chat(chat1, chat2, chat_state1):
1048
+ short_chats = list(chat_state1.keys())
1049
+ for chati in [chat1, chat2]:
1050
+ if chati and len(chati) > 0 and len(chati[0]) == 2 and chati[0][1] is not None:
1051
+ short_chat = get_short_chat(chati, short_chats)
1052
+ if short_chat:
1053
+ already_exists = any([is_chat_same(chati, x) for x in chat_state1.values()])
1054
+ if not already_exists:
1055
+ chat_state1[short_chat] = chati
1056
+ return chat_state1
1057
+
1058
+ def update_radio_chats(chat_state1):
1059
+ return gr.update(choices=list(chat_state1.keys()), value=None)
1060
+
1061
+ def deselect_radio_chats():
1062
+ return gr.update(value=None)
1063
+
1064
+ def switch_chat(chat_key, chat_state1):
1065
+ chosen_chat = chat_state1[chat_key]
1066
+ return chosen_chat, chosen_chat
1067
+
1068
+ radio_chats.input(switch_chat, inputs=[radio_chats, chat_state], outputs=[text_output, text_output2])
1069
+
1070
+ def remove_chat(chat_key, chat_state1):
1071
+ chat_state1.pop(chat_key, None)
1072
+ return chat_state1
1073
+
1074
+ remove_chat_btn.click(remove_chat, inputs=[radio_chats, chat_state], outputs=chat_state) \
1075
+ .then(update_radio_chats, inputs=chat_state, outputs=radio_chats)
1076
+
1077
+ def get_chats1(chat_state1):
1078
+ base = 'chats'
1079
+ makedirs(base, exist_ok=True)
1080
+ filename = os.path.join(base, 'chats_%s.json' % str(uuid.uuid4()))
1081
+ with open(filename, "wt") as f:
1082
+ f.write(json.dumps(chat_state1, indent=2))
1083
+ return filename
1084
+
1085
+ export_chats_btn.click(get_chats1, inputs=chat_state, outputs=chats_file, queue=False,
1086
+ api_name='export_chats' if allow_api else None)
1087
+
1088
+ def add_chats_from_file(file, chat_state1, add_btn):
1089
+ if isinstance(file, str):
1090
+ files = [file]
1091
+ else:
1092
+ files = file
1093
+ for file1 in files:
1094
+ try:
1095
+ if hasattr(file1, 'name'):
1096
+ file1 = file1.name
1097
+ with open(file1, "rt") as f:
1098
+ new_chats = json.loads(f.read())
1099
+ for chat1_k, chat1_v in new_chats.items():
1100
+ # ignore chat1_k, regenerate and de-dup to avoid loss
1101
+ chat_state1 = save_chat(chat1_v, None, chat_state1)
1102
+ except BaseException as e:
1103
+ print("Add chats exception: %s" % str(e), flush=True)
1104
+ return chat_state1, add_btn
1105
+
1106
+ # note for update_user_db_func output is ignored for db
1107
+ add_to_chats_btn.click(add_chats_from_file,
1108
+ inputs=[chatsup_output, chat_state, add_to_chats_btn],
1109
+ outputs=[chat_state, add_to_my_db_btn], queue=False,
1110
+ api_name='add_to_chats' if allow_api else None) \
1111
+ .then(clear_file_list, outputs=chatsup_output, queue=False) \
1112
+ .then(update_radio_chats, inputs=chat_state, outputs=radio_chats, queue=False)
1113
+
1114
+ clear_chat_btn.click(lambda: None, None, text_output, queue=False, api_name='clear' if allow_api else None) \
1115
+ .then(lambda: None, None, text_output2, queue=False, api_name='clear2' if allow_api else None) \
1116
+ .then(deselect_radio_chats, inputs=None, outputs=radio_chats, queue=False)
1117
+
1118
  # does both models
1119
+ clear.click(save_chat, inputs=[text_output, text_output2, chat_state], outputs=chat_state,
1120
+ api_name='save_chat' if allow_api else None) \
1121
+ .then(update_radio_chats, inputs=chat_state, outputs=radio_chats,
1122
+ api_name='update_chats' if allow_api else None) \
1123
+ .then(lambda: None, None, text_output, queue=False, api_name='clearB' if allow_api else None) \
1124
+ .then(lambda: None, None, text_output2, queue=False, api_name='clearB2' if allow_api else None)
1125
  # NOTE: clear of instruction/iinput for nochat has to come after score,
1126
  # because score for nochat consumes actual textbox, while chat consumes chat history filled by user()
1127
+ submit_event_nochat = submit_nochat.click(fun,
1128
+ inputs=[model_state, my_db_state] + inputs_list,
1129
  outputs=text_output_nochat,
1130
  queue=queue,
1131
  api_name='submit_nochat' if allow_api else None) \
 
1227
  new_state = [list0[0] + [x]]
1228
  new_options = [*new_state[0]]
1229
  return gr.Dropdown.update(value=x, choices=new_options), \
1230
+ gr.Dropdown.update(value=x, choices=new_options), \
1231
+ '', new_state
1232
 
1233
  add_model_event = add_model_button.click(fn=dropdown_model_list,
1234
  inputs=[model_options_state, new_model],
 
1242
  x1 = x if model_used1 == no_model_str else lora_used1
1243
  x2 = x if model_used2 == no_model_str else lora_used2
1244
  return gr.Dropdown.update(value=x1, choices=new_options), \
1245
+ gr.Dropdown.update(value=x2, choices=new_options), \
1246
+ '', new_state
1247
 
1248
  add_lora_event = add_lora_button.click(fn=dropdown_lora_list,
1249
  inputs=[lora_options_state, new_lora, model_used, lora_used, model_used2,
 
1301
 
1302
  scheduler = BackgroundScheduler()
1303
  scheduler.add_job(func=clear_torch_cache, trigger="interval", seconds=20)
1304
+ if is_public and \
1305
+ kwargs['base_model'] not in ['gptj', 'llama']:
1306
+ # FIXME: disable for gptj, langchain or gpt4all modify print itself
1307
+ # FIXME: and any multi-threaded/async print will enter model output!
1308
  scheduler.add_job(func=ping, trigger="interval", seconds=60)
1309
  scheduler.start()
1310
 
1311
+ # import control
1312
+ if kwargs['langchain_mode'] == 'Disabled' and \
1313
+ os.environ.get("TEST_LANGCHAIN_IMPORT") and \
1314
+ kwargs['base_model'] not in ['gptj', 'llama']:
1315
+ assert 'gpt_langchain' not in sys.modules, "Dev bug, import of langchain when should not have"
1316
+ assert 'langchain' not in sys.modules, "Dev bug, import of langchain when should not have"
1317
+
1318
  demo.launch(share=kwargs['share'], server_name="0.0.0.0", show_error=True,
1319
  favicon_path=favicon_path, prevent_thread_lock=True,
1320
  auth=kwargs['auth'])
 
1323
  demo.block_thread()
1324
 
1325
 
1326
+ input_args_list = ['model_state', 'my_db_state']
 
 
1327
 
1328
 
1329
  def get_inputs_list(inputs_dict, model_lower):
 
1339
  if k == 'kwargs':
1340
  continue
1341
  if k in input_args_list + inputs_kwargs_list:
1342
+ # these are added at use time for args or partial for kwargs, not taken as input
1343
  continue
1344
  if 'mbart-' not in model_lower and k in ['src_lang', 'tgt_lang']:
1345
  continue
1346
  inputs_list.append(inputs_dict[k])
1347
  return inputs_list
1348
+
1349
+
1350
+ def get_sources(db1, langchain_mode, dbs=None):
1351
+ if langchain_mode in ['ChatLLM', 'LLM']:
1352
+ source_files_added = "NA"
1353
+ elif langchain_mode in ['wiki_full']:
1354
+ source_files_added = "Not showing wiki_full, takes about 20 seconds and makes 4MB file." \
1355
+ " Ask [email protected] for file if required."
1356
+ elif langchain_mode == 'MyData' and len(db1) > 0 and db1[0] is not None:
1357
+ db_get = db1[0].get()
1358
+ source_files_added = '\n'.join(sorted(set([x['source'] for x in db_get['metadatas']])))
1359
+ elif langchain_mode in dbs and dbs[langchain_mode] is not None:
1360
+ db1 = dbs[langchain_mode]
1361
+ db_get = db1.get()
1362
+ source_files_added = '\n'.join(sorted(set([x['source'] for x in db_get['metadatas']])))
1363
+ else:
1364
+ source_files_added = "None"
1365
+ sources_file = 'sources_%s_%s' % (langchain_mode, str(uuid.uuid4()))
1366
+ with open(sources_file, "wt") as f:
1367
+ f.write(source_files_added)
1368
+ return sources_file
1369
+
1370
+
1371
+ def update_user_db(file, db1, x, y, *args, dbs=None, langchain_mode='UserData', **kwargs):
1372
+ try:
1373
+ return _update_user_db(file, db1, x, y, *args, dbs=dbs, langchain_mode=langchain_mode, **kwargs)
1374
+ except BaseException as e:
1375
+ print(traceback.format_exc(), flush=True)
1376
+ # gradio has issues if except, so fail semi-gracefully, else would hang forever in processing textbox
1377
+ ex_str = "Exception: %s" % str(e)
1378
+ source_files_added = """\
1379
+ <html>
1380
+ <body>
1381
+ <p>
1382
+ Sources: <br>
1383
+ </p>
1384
+ <div style="overflow-y: auto;height:400px">
1385
+ {0}
1386
+ </div>
1387
+ </body>
1388
+ </html>
1389
+ """.format(ex_str)
1390
+ if langchain_mode == 'MyData':
1391
+ return db1, x, y, source_files_added
1392
+ else:
1393
+ return x, y, source_files_added
1394
+
1395
+
1396
+ def _update_user_db(file, db1, x, y, dbs=None, db_type=None, langchain_mode='UserData', use_openai_embedding=False,
1397
+ hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
1398
+ caption_loader=None,
1399
+ enable_captions=True,
1400
+ captions_model="Salesforce/blip-image-captioning-base",
1401
+ enable_ocr=False,
1402
+ verbose=False,
1403
+ chunk=True, chunk_size=512, is_url=False, is_txt=False):
1404
+ assert isinstance(dbs, dict), "Wrong type for dbs: %s" % str(type(dbs))
1405
+ assert db_type in ['faiss', 'chroma'], "db_type %s not supported" % db_type
1406
+ from gpt_langchain import add_to_db, get_db, path_to_docs
1407
+ # handle case of list of temp buffer
1408
+ if isinstance(file, list) and len(file) > 0 and hasattr(file[0], 'name'):
1409
+ file = [x.name for x in file]
1410
+ # handle single file of temp buffer
1411
+ if hasattr(file, 'name'):
1412
+ file = file.name
1413
+ if verbose:
1414
+ print("Adding %s" % file, flush=True)
1415
+ sources = path_to_docs(file if not is_url and not is_txt else None,
1416
+ verbose=verbose, chunk=chunk, chunk_size=chunk_size,
1417
+ url=file if is_url else None,
1418
+ text=file if is_txt else None,
1419
+ enable_captions=enable_captions,
1420
+ captions_model=captions_model,
1421
+ enable_ocr=enable_ocr,
1422
+ caption_loader=caption_loader,
1423
+ )
1424
+ exceptions = [x for x in sources if x.metadata.get('exception')]
1425
+ sources = [x for x in sources if 'exception' not in x.metadata]
1426
+
1427
+ with filelock.FileLock("db_%s.lock" % langchain_mode.replace(' ', '_')):
1428
+ if langchain_mode == 'MyData':
1429
+ if db1[0] is not None:
1430
+ # then add
1431
+ add_to_db(db1[0], sources, db_type=db_type)
1432
+ else:
1433
+ assert len(db1) == 2 and db1[1] is None, "Bad MyData db: %s" % db1
1434
+ # then create
1435
+ # assign fresh hash for this user session, so not shared
1436
+ # if added has to original state and didn't change, then would be shared db for all users
1437
+ db1[1] = str(uuid.uuid4())
1438
+ persist_directory = os.path.join(scratch_base_dir, 'db_dir_%s_%s' % (langchain_mode, db1[1]))
1439
+ db1[0] = get_db(sources, use_openai_embedding=use_openai_embedding,
1440
+ db_type=db_type,
1441
+ persist_directory=persist_directory,
1442
+ langchain_mode=langchain_mode,
1443
+ hf_embedding_model=hf_embedding_model)
1444
+ if db1[0] is None:
1445
+ db1[1] = None
1446
+ source_files_added = get_source_files(db1[0], exceptions=exceptions)
1447
+ return db1, x, y, source_files_added
1448
+ else:
1449
+ persist_directory = 'db_dir_%s' % langchain_mode
1450
+ if langchain_mode in dbs and dbs[langchain_mode] is not None:
1451
+ # then add
1452
+ add_to_db(dbs[langchain_mode], sources, db_type=db_type)
1453
+ else:
1454
+ # then create
1455
+ db = get_db(sources, use_openai_embedding=use_openai_embedding,
1456
+ db_type=db_type,
1457
+ persist_directory=persist_directory,
1458
+ langchain_mode=langchain_mode,
1459
+ hf_embedding_model=hf_embedding_model)
1460
+ dbs[langchain_mode] = db
1461
+ # NOTE we do not return db, because function call always same code path
1462
+ # return dbs[langchain_mode], x, y
1463
+ # db in this code path is updated in place
1464
+ source_files_added = get_source_files(dbs[langchain_mode], exceptions=exceptions)
1465
+ return x, y, source_files_added
1466
+
1467
+
1468
+ def get_source_files(db, exceptions=None):
1469
+ if exceptions is None:
1470
+ exceptions = []
1471
+
1472
+ if db is not None:
1473
+ metadatas = db.get()['metadatas']
1474
+ else:
1475
+ metadatas = []
1476
+
1477
+ # below automatically de-dups
1478
+ from gpt_langchain import get_url
1479
+ small_dict = {get_url(x['source'], from_str=True, short_name=True): get_short_name(x.get('head')) for x in
1480
+ metadatas}
1481
+ # if small_dict is empty dict, that's ok
1482
+ df = pd.DataFrame(small_dict.items(), columns=['source', 'head'])
1483
+ df.index = df.index + 1
1484
+ df.index.name = 'index'
1485
+ source_files_added = tabulate.tabulate(df, headers='keys', tablefmt='unsafehtml')
1486
+
1487
+ if exceptions:
1488
+ exception_metadatas = [x.metadata for x in exceptions]
1489
+ small_dict = {get_url(x['source'], from_str=True, short_name=True): get_short_name(x.get('exception')) for x in
1490
+ exception_metadatas}
1491
+ # if small_dict is empty dict, that's ok
1492
+ df = pd.DataFrame(small_dict.items(), columns=['source', 'exception'])
1493
+ df.index = df.index + 1
1494
+ df.index.name = 'index'
1495
+ exceptions_html = tabulate.tabulate(df, headers='keys', tablefmt='unsafehtml')
1496
+ else:
1497
+ exceptions_html = ''
1498
+
1499
+ if metadatas and exceptions:
1500
+ source_files_added = """\
1501
+ <html>
1502
+ <body>
1503
+ <p>
1504
+ Sources: <br>
1505
+ </p>
1506
+ <div style="overflow-y: auto;height:400px">
1507
+ {0}
1508
+ {1}
1509
+ </div>
1510
+ </body>
1511
+ </html>
1512
+ """.format(source_files_added, exceptions_html)
1513
+ elif metadatas:
1514
+ source_files_added = """\
1515
+ <html>
1516
+ <body>
1517
+ <p>
1518
+ Sources: <br>
1519
+ </p>
1520
+ <div style="overflow-y: auto;height:400px">
1521
+ {0}
1522
+ </div>
1523
+ </body>
1524
+ </html>
1525
+ """.format(source_files_added)
1526
+ elif exceptions_html:
1527
+ source_files_added = """\
1528
+ <html>
1529
+ <body>
1530
+ <p>
1531
+ Exceptions: <br>
1532
+ </p>
1533
+ <div style="overflow-y: auto;height:400px">
1534
+ {0}
1535
+ </div>
1536
+ </body>
1537
+ </html>
1538
+ """.format(exceptions_html)
1539
+ else:
1540
+ source_files_added = ""
1541
+
1542
+ return source_files_added
gradio_themes.py CHANGED
@@ -1,6 +1,7 @@
1
  from __future__ import annotations
2
  from gradio.themes.soft import Soft
3
- from gradio.themes.utils import Color, colors, sizes
 
4
 
5
  h2o_yellow = Color(
6
  name="yellow",
@@ -74,6 +75,7 @@ class H2oTheme(Soft):
74
  body_background_fill_dark="*neutral_900",
75
  background_fill_primary_dark="*block_background_fill",
76
  block_radius="0 0 8px 8px",
 
77
  )
78
 
79
 
 
1
  from __future__ import annotations
2
  from gradio.themes.soft import Soft
3
+ from gradio.themes import Color
4
+ from gradio.themes.utils import colors, sizes
5
 
6
  h2o_yellow = Color(
7
  name="yellow",
 
75
  body_background_fill_dark="*neutral_900",
76
  background_fill_primary_dark="*block_background_fill",
77
  block_radius="0 0 8px 8px",
78
+ checkbox_label_text_color_selected_dark='#000000',
79
  )
80
 
81
 
h2oai_pipeline.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import TextGenerationPipeline
2
+ from transformers.pipelines.text_generation import ReturnType
3
+
4
+ from stopping import get_stopping
5
+
6
+ prompt_type = "human_bot"
7
+ human = "<human>:"
8
+ bot = "<bot>:"
9
+
10
+ # human-bot interaction like OIG dataset
11
+ prompt = """{human} {instruction}
12
+ {bot}""".format(
13
+ human=human,
14
+ instruction="{instruction}",
15
+ bot=bot,
16
+ )
17
+
18
+
19
+ class H2OTextGenerationPipeline(TextGenerationPipeline):
20
+ def __init__(self, *args, use_prompter=False, debug=False, chat=False, stream_output=False,
21
+ sanitize_bot_response=True, **kwargs):
22
+ super().__init__(*args, **kwargs)
23
+ self.use_prompter = use_prompter
24
+ self.prompt_text = None
25
+ if self.use_prompter:
26
+ from prompter import Prompter
27
+ self.prompter = Prompter(prompt_type, debug=debug, chat=chat, stream_output=stream_output)
28
+ else:
29
+ self.prompter = None
30
+ self.sanitize_bot_response = sanitize_bot_response
31
+
32
+ def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs):
33
+ prompt_text = prompt.format(instruction=prompt_text)
34
+ self.prompt_text = prompt_text
35
+ return super().preprocess(prompt_text, prefix=prefix, handle_long_generation=handle_long_generation,
36
+ **generate_kwargs)
37
+
38
+ def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, clean_up_tokenization_spaces=True):
39
+ records = super().postprocess(model_outputs, return_type=return_type,
40
+ clean_up_tokenization_spaces=clean_up_tokenization_spaces)
41
+ for rec in records:
42
+ if self.use_prompter:
43
+ outputs = rec['generated_text']
44
+ outputs = self.prompter.get_response(outputs, prompt=self.prompt_text,
45
+ sanitize_bot_response=self.sanitize_bot_response)
46
+ else:
47
+ outputs = rec['generated_text'].split(bot)[1].strip().split(human)[0].strip()
48
+ rec['generated_text'] = outputs
49
+ return records
50
+
51
+ def _forward(self, model_inputs, **generate_kwargs):
52
+ stopping_criteria = get_stopping(prompt_type, self.tokenizer, self.device, human=human, bot=bot)
53
+ generate_kwargs['stopping_criteria'] = stopping_criteria
54
+ return super()._forward(model_inputs, **generate_kwargs)
loaders.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ def get_loaders(llama_type, model_name, reward_type):
2
+ # NOTE: Some models need specific new prompt_type
3
+ # E.g. t5_xxl_true_nli_mixture has input format: "premise: PREMISE_TEXT hypothesis: HYPOTHESIS_TEXT".)
4
+ if llama_type:
5
+ from transformers import LlamaForCausalLM, LlamaTokenizer
6
+ model_loader = LlamaForCausalLM
7
+ tokenizer_loader = LlamaTokenizer
8
+ elif 'distilgpt2' in model_name.lower():
9
+ from transformers import AutoModelForCausalLM, AutoTokenizer
10
+ return AutoModelForCausalLM, AutoTokenizer
11
+ elif 'gpt2' in model_name.lower():
12
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
13
+ return GPT2LMHeadModel, GPT2Tokenizer
14
+ elif 'mbart-' in model_name.lower():
15
+ from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
16
+ return MBartForConditionalGeneration, MBart50TokenizerFast
17
+ elif 't5' == model_name.lower() or \
18
+ 't5-' in model_name.lower() or \
19
+ 'flan-' in model_name.lower():
20
+ from transformers import AutoTokenizer, T5ForConditionalGeneration
21
+ return T5ForConditionalGeneration, AutoTokenizer
22
+ elif 'bigbird' in model_name:
23
+ from transformers import BigBirdPegasusForConditionalGeneration, AutoTokenizer
24
+ return BigBirdPegasusForConditionalGeneration, AutoTokenizer
25
+ elif 'bart-large-cnn-samsum' in model_name or 'flan-t5-base-samsum' in model_name:
26
+ from transformers import pipeline
27
+ return pipeline, "summarization"
28
+ elif reward_type or 'OpenAssistant/reward-model'.lower() in model_name.lower():
29
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
30
+ return AutoModelForSequenceClassification, AutoTokenizer
31
+ else:
32
+ from transformers import AutoTokenizer, AutoModelForCausalLM
33
+ model_loader = AutoModelForCausalLM
34
+ tokenizer_loader = AutoTokenizer
35
+ return model_loader, tokenizer_loader
36
+
37
+
38
+ def get_tokenizer(tokenizer_loader, tokenizer_base_model, local_files_only, resume_download, use_auth_token):
39
+ tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model,
40
+ local_files_only=local_files_only,
41
+ resume_download=resume_download,
42
+ use_auth_token=use_auth_token)
43
+
44
+ tokenizer.pad_token_id = 0 # different from the eos token
45
+ # when generating, we will use the logits of right-most token to predict the next token
46
+ # so the padding should be on the left,
47
+ # e.g. see: https://huggingface.co/transformers/v4.11.3/model_doc/t5.html#inference
48
+ tokenizer.padding_side = "left" # Allow batched inference
49
+
50
+ return tokenizer
prompter.py CHANGED
@@ -1,4 +1,355 @@
1
- from finetune import generate_prompt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
 
4
  class Prompter(object):
@@ -13,6 +364,12 @@ class Prompter(object):
13
  self.stream_output = stream_output
14
  self.repeat_penalty = repeat_penalty
15
  self.allowed_repeat_line_length = allowed_repeat_line_length
 
 
 
 
 
 
16
 
17
  def generate_prompt(self, data_point):
18
  reduced = False
@@ -55,6 +412,18 @@ class Prompter(object):
55
  for oi, output in enumerate(outputs):
56
  if self.prompt_type in [0, '0', 'plain']:
57
  output = clean_response(output)
 
 
 
 
 
 
 
 
 
 
 
 
58
  else:
59
  # find first instance of prereponse
60
  # prompt sometimes has odd characters, that mutate length,
 
1
+ import time
2
+ from enum import Enum
3
+
4
+
5
+ class PromptType(Enum):
6
+ plain = 0
7
+ instruct = 1
8
+ quality = 2
9
+ human_bot = 3
10
+ dai_faq = 4
11
+ summarize = 5
12
+ simple_instruct = 6
13
+ instruct_vicuna = 7
14
+ instruct_with_end = 8
15
+ human_bot_orig = 9
16
+ prompt_answer = 10
17
+ open_assistant = 11
18
+ wizard_lm = 12
19
+ wizard_mega = 13
20
+
21
+
22
+ prompt_type_to_model_name = {
23
+ 'plain': [
24
+ 'EleutherAI/gpt-j-6B',
25
+ 'EleutherAI/pythia-6.9b',
26
+ 'EleutherAI/pythia-12b',
27
+ 'EleutherAI/pythia-12b-deduped',
28
+ 'EleutherAI/gpt-neox-20b',
29
+ 'decapoda-research/llama-7b-hf',
30
+ 'decapoda-research/llama-13b-hf',
31
+ 'decapoda-research/llama-30b-hf',
32
+ 'decapoda-research/llama-65b-hf',
33
+ 'facebook/mbart-large-50-many-to-many-mmt',
34
+ 'philschmid/bart-large-cnn-samsum',
35
+ 'philschmid/flan-t5-base-samsum',
36
+ 'gpt2',
37
+ 'distilgpt2',
38
+ 'mosaicml/mpt-7b-storywriter',
39
+ 'mosaicml/mpt-7b-instruct', # internal code handles instruct
40
+ 'mosaicml/mpt-7b-chat', # NC, internal code handles instruct
41
+ 'gptj', # internally handles prompting
42
+ 'llama', # internally handles prompting
43
+ ],
44
+ 'prompt_answer': [
45
+ 'h2oai/h2ogpt-gm-oasst1-en-1024-20b',
46
+ 'h2oai/h2ogpt-gm-oasst1-en-1024-12b',
47
+ 'h2oai/h2ogpt-gm-oasst1-multilang-1024-20b',
48
+ 'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt',
49
+ 'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2',
50
+ ],
51
+ 'instruct': [],
52
+ 'instruct_with_end': ['databricks/dolly-v2-12b'],
53
+ 'quality': [],
54
+ 'human_bot': [
55
+ 'h2oai/h2ogpt-oasst1-512-12b',
56
+ 'h2oai/h2ogpt-oasst1-512-20b',
57
+ 'h2oai/h2ogpt-oig-oasst1-256-6_9b',
58
+ 'h2oai/h2ogpt-oig-oasst1-512-6_9b',
59
+ 'h2oai/h2ogpt-research-oasst1-512-30b', # private
60
+ ],
61
+ 'dai_faq': [],
62
+ 'summarize': [],
63
+ 'simple_instruct': ['t5-small', 't5-large', 'google/flan-t5', 'google/flan-t5-xxl', 'google/flan-ul2'],
64
+ 'instruct_vicuna': ['AlekseyKorshuk/vicuna-7b', 'TheBloke/stable-vicuna-13B-HF', 'junelee/wizard-vicuna-13b'],
65
+ 'human_bot_orig': ['togethercomputer/GPT-NeoXT-Chat-Base-20B'],
66
+ "open_assistant": ['OpenAssistant/oasst-sft-7-llama-30b-xor', 'oasst-sft-7-llama-30b'],
67
+ "wizard_lm": ['ehartford/WizardLM-7B-Uncensored', 'ehartford/WizardLM-13B-Uncensored'],
68
+ "wizard_mega": ['openaccess-ai-collective/wizard-mega-13b'],
69
+ }
70
+
71
+ inv_prompt_type_to_model_name = {v.strip(): k for k, l in prompt_type_to_model_name.items() for v in l}
72
+ inv_prompt_type_to_model_lower = {v.strip().lower(): k for k, l in prompt_type_to_model_name.items() for v in l}
73
+
74
+ prompt_types_strings = []
75
+ for p in PromptType:
76
+ prompt_types_strings.extend([p.name])
77
+
78
+ prompt_types = []
79
+ for p in PromptType:
80
+ prompt_types.extend([p.name, p.value, str(p.value)])
81
+
82
+
83
+ def get_prompt(prompt_type, chat, context, reduced):
84
+ if prompt_type in [-1, "-1", "plain"]:
85
+ promptA = promptB = PreInstruct = PreInput = PreResponse = ''
86
+ terminate_response = []
87
+ chat_sep = ''
88
+ humanstr = ''
89
+ botstr = ''
90
+ elif prompt_type == 'simple_instruct':
91
+ promptA = promptB = PreInstruct = PreInput = PreResponse = None
92
+ terminate_response = []
93
+ chat_sep = '\n'
94
+ humanstr = ''
95
+ botstr = ''
96
+ elif prompt_type in [0, "0", "instruct"] or prompt_type in [7, "7", "instruct_with_end"]:
97
+ promptA = 'Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n' if not (
98
+ chat and reduced) else ''
99
+ promptB = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n' if not (
100
+ chat and reduced) else ''
101
+
102
+ PreInstruct = """
103
+ ### Instruction:
104
+ """
105
+
106
+ PreInput = """
107
+ ### Input:
108
+ """
109
+
110
+ PreResponse = """
111
+ ### Response:
112
+ """
113
+ if prompt_type in [7, "7", "instruct_with_end"]:
114
+ terminate_response = ['### End']
115
+ else:
116
+ terminate_response = None
117
+ chat_sep = '\n'
118
+ humanstr = PreInstruct
119
+ botstr = PreResponse
120
+ elif prompt_type in [1, "1", "quality"]:
121
+ promptA = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction as applied on the Input.\n' if not (
122
+ chat and reduced) else ''
123
+ promptB = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction.\n' if not (
124
+ chat and reduced) else ''
125
+
126
+ PreInstruct = """
127
+ ### Instruction:
128
+ """
129
+
130
+ PreInput = """
131
+ ### Input:
132
+ """
133
+
134
+ PreResponse = """
135
+ ### Response:
136
+ """
137
+ terminate_response = None
138
+ chat_sep = '\n'
139
+ humanstr = PreInstruct # first thing human says
140
+ botstr = PreResponse # first thing bot says
141
+ elif prompt_type in [2, "2", "human_bot", 9, "9", "human_bot_orig"]:
142
+ human = '<human>:'
143
+ bot = "<bot>:"
144
+ if reduced or context or prompt_type in [2, "2", "human_bot"]:
145
+ preprompt = ''
146
+ else:
147
+ cur_date = time.strftime('%Y-%m-%d')
148
+ cur_time = time.strftime('%H:%M:%S %p %Z')
149
+
150
+ PRE_PROMPT = """\
151
+ Current Date: {}
152
+ Current Time: {}
153
+
154
+ """
155
+ preprompt = PRE_PROMPT.format(cur_date, cur_time)
156
+ start = human
157
+ promptB = promptA = '%s%s ' % (preprompt, start)
158
+
159
+ PreInstruct = ""
160
+
161
+ PreInput = None
162
+
163
+ if reduced:
164
+ # when making context, want it to appear as-if LLM generated, which starts with space after :
165
+ PreResponse = bot + ' '
166
+ else:
167
+ # normally LLM adds space after this, because was how trained.
168
+ # if add space here, non-unique tokenization will often make LLM produce wrong output
169
+ PreResponse = bot
170
+
171
+ terminate_response = [start, PreResponse]
172
+ chat_sep = '\n'
173
+ humanstr = human # tag before human talks
174
+ botstr = bot # tag before bot talks
175
+ elif prompt_type in [3, "3", "dai_faq"]:
176
+ promptA = ''
177
+ promptB = 'Answer the following Driverless AI question.\n'
178
+
179
+ PreInstruct = """
180
+ ### Driverless AI frequently asked question:
181
+ """
182
+
183
+ PreInput = None
184
+
185
+ PreResponse = """
186
+ ### Driverless AI documentation answer:
187
+ """
188
+ terminate_response = ['\n\n']
189
+ chat_sep = terminate_response
190
+ humanstr = PreInstruct
191
+ botstr = PreResponse
192
+ elif prompt_type in [5, "5", "summarize"]:
193
+ promptA = promptB = PreInput = ''
194
+ PreInstruct = '## Main Text\n\n'
195
+ PreResponse = '\n\n## Summary\n\n'
196
+ terminate_response = None
197
+ chat_sep = '\n'
198
+ humanstr = PreInstruct
199
+ botstr = PreResponse
200
+ elif prompt_type in [6, "6", "instruct_vicuna"]:
201
+ promptA = promptB = "A chat between a curious human and an artificial intelligence assistant. " \
202
+ "The assistant gives helpful, detailed, and polite answers to the human's questions." if not (
203
+ chat and reduced) else ''
204
+
205
+ PreInstruct = """
206
+ ### Human:
207
+ """
208
+
209
+ PreInput = None
210
+
211
+ PreResponse = """
212
+ ### Assistant:
213
+ """
214
+ terminate_response = [
215
+ '### Human:'] # but only allow terminate after prompt is found correctly, else can't terminate
216
+ chat_sep = '\n'
217
+ humanstr = PreInstruct
218
+ botstr = PreResponse
219
+ elif prompt_type in [10, "10", "prompt_answer"]:
220
+ preprompt = ''
221
+ prompt_tokens = "<|prompt|>"
222
+ answer_tokens = "<|answer|>"
223
+ start = prompt_tokens
224
+ promptB = promptA = '%s%s' % (preprompt, start)
225
+ PreInstruct = ""
226
+ PreInput = None
227
+ PreResponse = answer_tokens
228
+ eos = '<|endoftext|>' # neox eos
229
+ terminate_response = [start, PreResponse, eos]
230
+ chat_sep = eos
231
+ humanstr = prompt_tokens
232
+ botstr = answer_tokens
233
+ elif prompt_type in [11, "11", "open_assistant"]:
234
+ # From added_tokens.json
235
+ preprompt = ''
236
+ prompt_tokens = "<|prompter|>"
237
+ answer_tokens = "<|assistant|>"
238
+ start = prompt_tokens
239
+ promptB = promptA = '%s%s' % (preprompt, start)
240
+ PreInstruct = ""
241
+ PreInput = None
242
+ PreResponse = answer_tokens
243
+ pend = "<|prefix_end|>"
244
+ eos = "</s>"
245
+ terminate_response = [start, PreResponse, pend, eos]
246
+ chat_sep = eos
247
+ humanstr = prompt_tokens
248
+ botstr = answer_tokens
249
+ elif prompt_type in [12, "12", "wizard_lm"]:
250
+ # https://github.com/ehartford/WizardLM/blob/main/src/train_freeform.py
251
+ preprompt = ''
252
+ start = ''
253
+ promptB = promptA = '%s%s' % (preprompt, start)
254
+ PreInstruct = ""
255
+ PreInput = None
256
+ PreResponse = "\n\n### Response"
257
+ eos = "</s>"
258
+ terminate_response = [PreResponse, eos]
259
+ chat_sep = eos
260
+ humanstr = promptA
261
+ botstr = PreResponse
262
+ elif prompt_type in [13, "13", "wizard_mega"]:
263
+ preprompt = ''
264
+ start = ''
265
+ promptB = promptA = '%s%s' % (preprompt, start)
266
+ PreInstruct = """
267
+ ### Instruction:
268
+ """
269
+ PreInput = None
270
+ PreResponse = """
271
+ ### Assistant:
272
+ """
273
+ terminate_response = [PreResponse]
274
+ chat_sep = '\n'
275
+ humanstr = PreInstruct
276
+ botstr = PreResponse
277
+ else:
278
+ raise RuntimeError("No such prompt_type=%s" % prompt_type)
279
+
280
+ return promptA, promptB, PreInstruct, PreInput, PreResponse, terminate_response, chat_sep, humanstr, botstr
281
+
282
+
283
+ def generate_prompt(data_point, prompt_type, chat, reduced):
284
+ context = data_point.get('context')
285
+ if context is None:
286
+ context = ''
287
+ instruction = data_point.get('instruction')
288
+ input = data_point.get('input')
289
+ output = data_point.get('output')
290
+ prompt_type = data_point.get('prompt_type', prompt_type)
291
+ assert prompt_type in prompt_types, "Bad prompt type: %s" % prompt_type
292
+ promptA, promptB, PreInstruct, PreInput, PreResponse, \
293
+ terminate_response, chat_sep, humanstr, botstr = get_prompt(prompt_type, chat, context, reduced)
294
+
295
+ prompt = context if not reduced else ''
296
+
297
+ if input and promptA:
298
+ prompt += f"""{promptA}"""
299
+ elif promptB:
300
+ prompt += f"""{promptB}"""
301
+
302
+ if instruction and PreInstruct is not None and input and PreInput is not None:
303
+ prompt += f"""{PreInstruct}{instruction}{PreInput}{input}"""
304
+ prompt = inject_newline(prompt_type, prompt)
305
+ elif instruction and input and PreInstruct is None and PreInput is not None:
306
+ prompt += f"""{PreInput}{instruction}
307
+ {input}"""
308
+ prompt = inject_newline(prompt_type, prompt)
309
+ elif input and instruction and PreInput is None and PreInstruct is not None:
310
+ prompt += f"""{PreInstruct}{instruction}
311
+ {input}"""
312
+ prompt = inject_newline(prompt_type, prompt)
313
+ elif instruction and PreInstruct is not None:
314
+ prompt += f"""{PreInstruct}{instruction}"""
315
+ prompt = inject_newline(prompt_type, prompt)
316
+ elif input and PreInput is not None:
317
+ prompt += f"""{PreInput}{input}"""
318
+ prompt = inject_newline(prompt_type, prompt)
319
+ elif input and instruction and PreInput is not None:
320
+ prompt += f"""{PreInput}{instruction}{input}"""
321
+ prompt = inject_newline(prompt_type, prompt)
322
+ elif input and instruction and PreInstruct is not None:
323
+ prompt += f"""{PreInstruct}{instruction}{input}"""
324
+ prompt = inject_newline(prompt_type, prompt)
325
+ elif input and instruction:
326
+ # i.e. for simple_instruct
327
+ prompt += f"""{instruction}: {input}"""
328
+ prompt = inject_newline(prompt_type, prompt)
329
+ elif input:
330
+ prompt += f"""{input}"""
331
+ prompt = inject_newline(prompt_type, prompt)
332
+ elif instruction:
333
+ prompt += f"""{instruction}"""
334
+ prompt = inject_newline(prompt_type, prompt)
335
+
336
+ if PreResponse is not None:
337
+ prompt += f"""{PreResponse}"""
338
+ pre_response = PreResponse # Don't use strip
339
+ else:
340
+ pre_response = ''
341
+
342
+ if output:
343
+ prompt += f"""{output}"""
344
+
345
+ return prompt, pre_response, terminate_response, chat_sep
346
+
347
+
348
+ def inject_newline(prompt_type, prompt):
349
+ if prompt_type not in [-1, '-1', 'plain', 'simple_instruct']:
350
+ # only add new line if structured prompt, while 'plain' is just generation of next tokens from input
351
+ prompt += '\n'
352
+ return prompt
353
 
354
 
355
  class Prompter(object):
 
364
  self.stream_output = stream_output
365
  self.repeat_penalty = repeat_penalty
366
  self.allowed_repeat_line_length = allowed_repeat_line_length
367
+ self.prompt = None
368
+ context = "" # not for chat context
369
+ reduced = False # not for chat context
370
+ self.promptA, self.promptB, self.PreInstruct, self.PreInput, self.PreResponse, \
371
+ self.terminate_response, self.chat_sep, self.humanstr, self.botstr = \
372
+ get_prompt(prompt_type, chat, context, reduced)
373
 
374
  def generate_prompt(self, data_point):
375
  reduced = False
 
412
  for oi, output in enumerate(outputs):
413
  if self.prompt_type in [0, '0', 'plain']:
414
  output = clean_response(output)
415
+ elif prompt is None:
416
+ # then use most basic parsing like pipeline
417
+ if self.botstr in output:
418
+ if self.humanstr:
419
+ output = clean_response(output.split(self.botstr)[1].strip().split(self.humanstr)[0].strip())
420
+ else:
421
+ # i.e. use after bot but only up to next bot
422
+ output = clean_response(output.split(self.botstr)[1].strip().split(self.botstr)[0].strip())
423
+ else:
424
+ # output = clean_response(output.strip())
425
+ # assume just not printed yet
426
+ output = ""
427
  else:
428
  # find first instance of prereponse
429
  # prompt sometimes has odd characters, that mutate length,
requirements.txt CHANGED
@@ -2,7 +2,7 @@
2
  datasets==2.12.0
3
  sentencepiece==0.1.97
4
  accelerate==0.18.0
5
- gradio==3.27.0
6
  huggingface_hub==0.14.1
7
  appdirs==1.4.4
8
  fire==0.5.0
@@ -35,7 +35,7 @@ tensorboard==2.12.1
35
  neptune==1.1.1
36
 
37
  # for gradio client
38
- gradio_client==0.1.3
39
  beautifulsoup4==4.12.2
40
  markdown==3.4.1
41
 
@@ -45,7 +45,55 @@ pytest-xdist==3.2.1
45
  nltk==3.8.1
46
  textstat==0.7.3
47
  pandoc==2.3
48
- pypandoc==1.11
 
49
  openpyxl==3.1.2
50
  lm_dataformat==0.0.20
51
  bioc==2.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  datasets==2.12.0
3
  sentencepiece==0.1.97
4
  accelerate==0.18.0
5
+ gradio==3.31.0
6
  huggingface_hub==0.14.1
7
  appdirs==1.4.4
8
  fire==0.5.0
 
35
  neptune==1.1.1
36
 
37
  # for gradio client
38
+ gradio_client==0.2.5
39
  beautifulsoup4==4.12.2
40
  markdown==3.4.1
41
 
 
45
  nltk==3.8.1
46
  textstat==0.7.3
47
  pandoc==2.3
48
+ #pypandoc==1.11
49
+ pypandoc_binary==1.11
50
  openpyxl==3.1.2
51
  lm_dataformat==0.0.20
52
  bioc==2.0
53
+ # To install with constraints
54
+ # grep -v '#\|peft' requirements.txt > req_constraints.txt ; pip install -r requirements_optional_langchain.txt -c req_constraints.txt
55
+
56
+ # optional for chat with PDF
57
+ langchain==0.0.178
58
+ pypdf==3.8.1
59
+ tiktoken==0.3.3
60
+ # avoid textract, requires old six
61
+ #textract==1.6.5
62
+ # choose:
63
+ #faiss-cpu
64
+ faiss-gpu==1.7.2
65
+
66
+ # for HF embeddings
67
+ sentence_transformers==2.2.2
68
+ # for OpenAI embeddings (requires key)
69
+ openai==0.27.6
70
+
71
+ # local vector db
72
+ chromadb==0.3.23
73
+ # server vector db
74
+ #pymilvus==2.2.8
75
+
76
+ # weak url support, if can't install opencv etc. If comment-in this one, then comment-out unstructured[local-inference]==0.6.6
77
+ # unstructured==0.6.6
78
+
79
+ # strong support for images
80
+ # Requires on Ubuntu: sudo apt-get install libmagic-dev poppler-utils tesseract-ocr libreoffice
81
+ unstructured[local-inference]==0.6.6
82
+ #pdf2image==1.16.3
83
+ #pytesseract==0.3.10
84
+ pillow
85
+
86
+ pdfminer.six==20221105
87
+ urllib3==1.26.6
88
+ requests_file==1.5.1
89
+
90
+ #pdf2image==1.16.3
91
+ #pytesseract==0.3.10
92
+ tabulate==0.9.0
93
+ # FYI pandoc already part of requirements.txt
94
+
95
+ jq==1.4.1
96
+
97
+ # to check licenses
98
+ # Run: pip-licenses|grep -v 'BSD\|Apache\|MIT'
99
+ pip-licenses==4.3.0
utils.py CHANGED
@@ -1,4 +1,6 @@
 
1
  import functools
 
2
  import os
3
  import gc
4
  import pathlib
@@ -12,6 +14,9 @@ import traceback
12
  import zipfile
13
  from datetime import datetime
14
  import filelock
 
 
 
15
  import numpy as np
16
  import pandas as pd
17
 
@@ -53,7 +58,11 @@ def clear_torch_cache():
53
 
54
 
55
  def ping():
56
- print('Ping: %s' % str(datetime.now()), flush=True)
 
 
 
 
57
 
58
 
59
  def get_torch_allocated():
@@ -61,6 +70,16 @@ def get_torch_allocated():
61
  return torch.cuda.memory_allocated()
62
 
63
 
 
 
 
 
 
 
 
 
 
 
64
  def system_info():
65
  import psutil
66
 
@@ -111,21 +130,26 @@ def system_info_print():
111
  return "Error: %s" % str(e)
112
 
113
 
114
- def zip_data(root_dirs=None, zip_file=None, base_dir='./'):
115
  try:
116
  return _zip_data(zip_file=zip_file, base_dir=base_dir, root_dirs=root_dirs)
117
  except Exception as e:
118
  traceback.print_exc()
119
  print('Exception in zipping: %s' % str(e))
 
 
120
 
121
 
122
  def _zip_data(root_dirs=None, zip_file=None, base_dir='./'):
 
 
123
  if zip_file is None:
124
  datetime_str = str(datetime.now()).replace(" ", "_").replace(":", "_")
125
  host_name = os.getenv('HF_HOSTNAME', 'emptyhost')
126
  zip_file = "data_%s_%s.zip" % (datetime_str, host_name)
127
  assert root_dirs is not None
128
-
 
129
  with zipfile.ZipFile(zip_file, "w") as expt_zip:
130
  for root_dir in root_dirs:
131
  if root_dir is None:
@@ -237,6 +261,7 @@ class NullContext(threading.local):
237
  Used as a stand-in if a particular block of code is only sometimes
238
  used with a normal context manager:
239
  """
 
240
  def __init__(self, *args, **kwargs):
241
  pass
242
 
@@ -270,16 +295,18 @@ class ThreadException(Exception):
270
  class EThread(threading.Thread):
271
  # Function that raises the custom exception
272
  def __init__(self, group=None, target=None, name=None,
273
- args=(), kwargs=None, *, daemon=None, bucket=None):
274
  self.bucket = bucket
275
- self.streamer = kwargs.get('streamer')
276
  self.exc = None
 
277
  super().__init__(group=group, target=target, name=name, args=args, kwargs=kwargs, daemon=daemon)
278
 
279
  def run(self):
280
  # Variable that stores the exception, if raised by someFunction
281
  try:
282
- super().run()
 
283
  except BaseException as e:
284
  print("thread exception: %s" % str(sys.exc_info()))
285
  self.bucket.put(sys.exc_info())
@@ -287,6 +314,10 @@ class EThread(threading.Thread):
287
  if self.streamer:
288
  print("make stop: %s" % str(sys.exc_info()), flush=True)
289
  self.streamer.do_stop = True
 
 
 
 
290
 
291
  def join(self, timeout=None):
292
  threading.Thread.join(self)
@@ -295,3 +326,443 @@ class EThread(threading.Thread):
295
  # if any was caught
296
  if self.exc:
297
  raise self.exc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import contextlib
2
  import functools
3
+ import hashlib
4
  import os
5
  import gc
6
  import pathlib
 
14
  import zipfile
15
  from datetime import datetime
16
  import filelock
17
+ import requests, uuid
18
+ from typing import Tuple, Callable, Dict
19
+ from concurrent.futures import ProcessPoolExecutor
20
  import numpy as np
21
  import pandas as pd
22
 
 
58
 
59
 
60
  def ping():
61
+ try:
62
+ print('Ping: %s' % str(datetime.now()), flush=True)
63
+ except AttributeError:
64
+ # some programs wrap print and will fail with flush passed
65
+ pass
66
 
67
 
68
  def get_torch_allocated():
 
70
  return torch.cuda.memory_allocated()
71
 
72
 
73
+ def get_device():
74
+ import torch
75
+ if torch.cuda.is_available():
76
+ device = "cuda"
77
+ else:
78
+ device = "cpu"
79
+
80
+ return device
81
+
82
+
83
  def system_info():
84
  import psutil
85
 
 
130
  return "Error: %s" % str(e)
131
 
132
 
133
+ def zip_data(root_dirs=None, zip_file=None, base_dir='./', fail_any_exception=False):
134
  try:
135
  return _zip_data(zip_file=zip_file, base_dir=base_dir, root_dirs=root_dirs)
136
  except Exception as e:
137
  traceback.print_exc()
138
  print('Exception in zipping: %s' % str(e))
139
+ if not fail_any_exception:
140
+ raise
141
 
142
 
143
  def _zip_data(root_dirs=None, zip_file=None, base_dir='./'):
144
+ if isinstance(root_dirs, str):
145
+ root_dirs = [root_dirs]
146
  if zip_file is None:
147
  datetime_str = str(datetime.now()).replace(" ", "_").replace(":", "_")
148
  host_name = os.getenv('HF_HOSTNAME', 'emptyhost')
149
  zip_file = "data_%s_%s.zip" % (datetime_str, host_name)
150
  assert root_dirs is not None
151
+ if not os.path.isdir(os.path.dirname(zip_file)):
152
+ os.makedirs(os.path.dirname(zip_file), exist_ok=True)
153
  with zipfile.ZipFile(zip_file, "w") as expt_zip:
154
  for root_dir in root_dirs:
155
  if root_dir is None:
 
261
  Used as a stand-in if a particular block of code is only sometimes
262
  used with a normal context manager:
263
  """
264
+
265
  def __init__(self, *args, **kwargs):
266
  pass
267
 
 
295
  class EThread(threading.Thread):
296
  # Function that raises the custom exception
297
  def __init__(self, group=None, target=None, name=None,
298
+ args=(), kwargs=None, *, daemon=None, streamer=None, bucket=None):
299
  self.bucket = bucket
300
+ self.streamer = streamer
301
  self.exc = None
302
+ self._return = None
303
  super().__init__(group=group, target=target, name=name, args=args, kwargs=kwargs, daemon=daemon)
304
 
305
  def run(self):
306
  # Variable that stores the exception, if raised by someFunction
307
  try:
308
+ if self._target is not None:
309
+ self._return = self._target(*self._args, **self._kwargs)
310
  except BaseException as e:
311
  print("thread exception: %s" % str(sys.exc_info()))
312
  self.bucket.put(sys.exc_info())
 
314
  if self.streamer:
315
  print("make stop: %s" % str(sys.exc_info()), flush=True)
316
  self.streamer.do_stop = True
317
+ finally:
318
+ # Avoid a refcycle if the thread is running a function with
319
+ # an argument that has a member that points to the thread.
320
+ del self._target, self._args, self._kwargs
321
 
322
  def join(self, timeout=None):
323
  threading.Thread.join(self)
 
326
  # if any was caught
327
  if self.exc:
328
  raise self.exc
329
+ return self._return
330
+
331
+
332
+ def import_matplotlib():
333
+ import matplotlib
334
+ matplotlib.use('agg')
335
+ # KEEP THESE HERE! START
336
+ import matplotlib.pyplot as plt
337
+ import pandas as pd
338
+ # to avoid dlopen deadlock in fork
339
+ import pandas.core.computation.expressions as pd_expressions
340
+ import pandas._libs.groupby as pd_libgroupby
341
+ import pandas._libs.reduction as pd_libreduction
342
+ import pandas.core.algorithms as pd_algorithms
343
+ import pandas.core.common as pd_com
344
+ import numpy as np
345
+ # KEEP THESE HERE! END
346
+
347
+
348
+ def get_sha(value):
349
+ return hashlib.md5(str(value).encode('utf-8')).hexdigest()
350
+
351
+
352
+ def sanitize_filename(name):
353
+ """
354
+ Sanitize file *base* names.
355
+ :param name: name to sanitize
356
+ :return:
357
+ """
358
+ bad_chars = ['[', ']', ',', '/', '\\', '\\w', '\\s', '-', '+', '\"', '\'', '>', '<', ' ', '=', ')', '(', ':', '^']
359
+ for char in bad_chars:
360
+ name = name.replace(char, "_")
361
+
362
+ length = len(name)
363
+ file_length_limit = 250 # bit smaller than 256 for safety
364
+ sha_length = 32
365
+ real_length_limit = file_length_limit - (sha_length + 2)
366
+ if length > file_length_limit:
367
+ sha = get_sha(name)
368
+ half_real_length_limit = max(1, int(real_length_limit / 2))
369
+ name = name[0:half_real_length_limit] + "_" + sha + "_" + name[length - half_real_length_limit:length]
370
+
371
+ return name
372
+
373
+
374
+ def shutil_rmtree_simple(*args, **kwargs):
375
+ path = args[0]
376
+ assert not os.path.samefile(path, "./tmp"), "Should not be trying to remove entire data directory: %s" % str(path)
377
+ # print("Removing path %s" % args[0]) # for debugging
378
+ return shutil.rmtree(*args, **kwargs)
379
+
380
+
381
+ def remove_simple(path: str):
382
+ try:
383
+ if path is not None and os.path.exists(path):
384
+ if os.path.isdir(path):
385
+ shutil_rmtree_simple(path, ignore_errors=True)
386
+ else:
387
+ with contextlib.suppress(FileNotFoundError):
388
+ os.remove(path)
389
+ except:
390
+ pass
391
+
392
+
393
+ def makedirs(path, exist_ok=True):
394
+ """
395
+ Avoid some inefficiency in os.makedirs()
396
+ :param path:
397
+ :param exist_ok:
398
+ :return:
399
+ """
400
+ if os.path.isdir(path) and os.path.exists(path):
401
+ assert exist_ok, "Path already exists"
402
+ return path
403
+ os.makedirs(path, exist_ok=exist_ok)
404
+
405
+
406
+ def atomic_move_simple(src, dst):
407
+ try:
408
+ shutil.move(src, dst)
409
+ except (shutil.Error, FileExistsError):
410
+ pass
411
+ remove_simple(src)
412
+
413
+
414
+ def download_simple(url, dest=None, print_func=None):
415
+ if print_func is not None:
416
+ print_func("BEGIN get url %s" % str(url))
417
+ if url.startswith("file://"):
418
+ from requests_file import FileAdapter
419
+ s = requests.Session()
420
+ s.mount('file://', FileAdapter())
421
+ url_data = s.get(url, stream=True)
422
+ else:
423
+ url_data = requests.get(url, stream=True)
424
+ if dest is None:
425
+ dest = os.path.basename(url)
426
+ if url_data.status_code != requests.codes.ok:
427
+ msg = "Cannot get url %s, code: %s, reason: %s" % (
428
+ str(url),
429
+ str(url_data.status_code),
430
+ str(url_data.reason),
431
+ )
432
+ raise requests.exceptions.RequestException(msg)
433
+ url_data.raw.decode_content = True
434
+ makedirs(os.path.dirname(dest), exist_ok=True)
435
+ uuid_tmp = str(uuid.uuid4())[:6]
436
+ dest_tmp = dest + "_dl_" + uuid_tmp + ".tmp"
437
+ with open(dest_tmp, "wb") as f:
438
+ shutil.copyfileobj(url_data.raw, f)
439
+ atomic_move_simple(dest_tmp, dest)
440
+ if print_func is not None:
441
+ print_func("END get url %s" % str(url))
442
+
443
+
444
+ def download(url, dest=None, dest_path=None):
445
+ if dest_path is not None:
446
+ dest = os.path.join(dest_path, os.path.basename(url))
447
+ if os.path.isfile(dest):
448
+ print("already downloaded %s -> %s" % (url, dest))
449
+ return dest
450
+ elif dest is not None:
451
+ if os.path.exists(dest):
452
+ print("already downloaded %s -> %s" % (url, dest))
453
+ return dest
454
+ else:
455
+ uuid_tmp = "dl2_" + str(uuid.uuid4())[:6]
456
+ dest = uuid_tmp + os.path.basename(url)
457
+
458
+ print("downloading %s to %s" % (url, dest))
459
+
460
+ if url.startswith("file://"):
461
+ from requests_file import FileAdapter
462
+ s = requests.Session()
463
+ s.mount('file://', FileAdapter())
464
+ url_data = s.get(url, stream=True)
465
+ else:
466
+ url_data = requests.get(url, stream=True)
467
+
468
+ if url_data.status_code != requests.codes.ok:
469
+ msg = "Cannot get url %s, code: %s, reason: %s" % (
470
+ str(url), str(url_data.status_code), str(url_data.reason))
471
+ raise requests.exceptions.RequestException(msg)
472
+ url_data.raw.decode_content = True
473
+ dirname = os.path.dirname(dest)
474
+ if dirname != "" and not os.path.isdir(dirname):
475
+ makedirs(os.path.dirname(dest), exist_ok=True)
476
+ uuid_tmp = "dl3_" + str(uuid.uuid4())[:6]
477
+ dest_tmp = dest + "_" + uuid_tmp + ".tmp"
478
+ with open(dest_tmp, 'wb') as f:
479
+ shutil.copyfileobj(url_data.raw, f)
480
+ try:
481
+ shutil.move(dest_tmp, dest)
482
+ except FileExistsError:
483
+ pass
484
+ remove_simple(dest_tmp)
485
+ return dest
486
+
487
+
488
+ def get_url(x, from_str=False, short_name=False):
489
+ if not from_str:
490
+ source = x.metadata['source']
491
+ else:
492
+ source = x
493
+ if short_name:
494
+ source_name = get_short_name(source)
495
+ else:
496
+ source_name = source
497
+ if source.startswith('http://') or source.startswith('https://'):
498
+ return """<a href="%s" target="_blank" rel="noopener noreferrer">%s</a>""" % (
499
+ source, source_name)
500
+ else:
501
+ return """<a href="file/%s" target="_blank" rel="noopener noreferrer">%s</a>""" % (
502
+ source, source_name)
503
+
504
+
505
+ def get_short_name(name, maxl=50):
506
+ if name is None:
507
+ return ''
508
+ length = len(name)
509
+ if length > maxl:
510
+ allow_length = maxl - 3
511
+ half_allowed = max(1, int(allow_length / 2))
512
+ name = name[0:half_allowed] + "..." + name[length - half_allowed:length]
513
+ return name
514
+
515
+
516
+ def cuda_vis_check(total_gpus):
517
+ """Helper function to count GPUs by environment variable
518
+ Stolen from Jon's h2o4gpu utils
519
+ """
520
+ cudavis = os.getenv("CUDA_VISIBLE_DEVICES")
521
+ which_gpus = []
522
+ if cudavis is not None:
523
+ # prune away white-space, non-numerics,
524
+ # except commas for simple checking
525
+ cudavis = "".join(cudavis.split())
526
+ import re
527
+ cudavis = re.sub("[^0-9,]", "", cudavis)
528
+
529
+ lencudavis = len(cudavis)
530
+ if lencudavis == 0:
531
+ total_gpus = 0
532
+ else:
533
+ total_gpus = min(
534
+ total_gpus,
535
+ os.getenv("CUDA_VISIBLE_DEVICES").count(",") + 1)
536
+ which_gpus = os.getenv("CUDA_VISIBLE_DEVICES").split(",")
537
+ which_gpus = [int(x) for x in which_gpus]
538
+ else:
539
+ which_gpus = list(range(0, total_gpus))
540
+
541
+ return total_gpus, which_gpus
542
+
543
+
544
+ def get_ngpus_vis(raise_if_exception=True):
545
+ ngpus_vis1 = 0
546
+
547
+ shell = False
548
+ if shell:
549
+ cmd = "nvidia-smi -L 2> /dev/null"
550
+ else:
551
+ cmd = ["nvidia-smi", "-L"]
552
+
553
+ try:
554
+ timeout = 5 * 3
555
+ o = subprocess.check_output(cmd, shell=shell, timeout=timeout)
556
+ lines = o.decode("utf-8").splitlines()
557
+ ngpus_vis1 = 0
558
+ for line in lines:
559
+ if 'Failed to initialize NVML' not in line:
560
+ ngpus_vis1 += 1
561
+ except (FileNotFoundError, subprocess.CalledProcessError, OSError):
562
+ # GPU systems might not have nvidia-smi, so can't fail
563
+ pass
564
+ except subprocess.TimeoutExpired as e:
565
+ print('Failed get_ngpus_vis: %s' % str(e))
566
+ if raise_if_exception:
567
+ raise
568
+
569
+ ngpus_vis1, which_gpus = cuda_vis_check(ngpus_vis1)
570
+ return ngpus_vis1
571
+
572
+
573
+ def get_mem_gpus(raise_if_exception=True, ngpus=None):
574
+ totalmem_gpus1 = 0
575
+ usedmem_gpus1 = 0
576
+ freemem_gpus1 = 0
577
+
578
+ if ngpus == 0:
579
+ return totalmem_gpus1, usedmem_gpus1, freemem_gpus1
580
+
581
+ try:
582
+ cmd = "nvidia-smi -q 2> /dev/null | grep -A 3 'FB Memory Usage'"
583
+ o = subprocess.check_output(cmd, shell=True, timeout=15)
584
+ lines = o.decode("utf-8").splitlines()
585
+ for line in lines:
586
+ if 'Total' in line:
587
+ totalmem_gpus1 += int(line.split()[2]) * 1024 ** 2
588
+ if 'Used' in line:
589
+ usedmem_gpus1 += int(line.split()[2]) * 1024 ** 2
590
+ if 'Free' in line:
591
+ freemem_gpus1 += int(line.split()[2]) * 1024 ** 2
592
+ except (FileNotFoundError, subprocess.CalledProcessError, OSError):
593
+ # GPU systems might not have nvidia-smi, so can't fail
594
+ pass
595
+ except subprocess.TimeoutExpired as e:
596
+ print('Failed get_mem_gpus: %s' % str(e))
597
+ if raise_if_exception:
598
+ raise
599
+
600
+ return totalmem_gpus1, usedmem_gpus1, freemem_gpus1
601
+
602
+
603
+ class ForkContext(threading.local):
604
+ """
605
+ Set context for forking
606
+ Ensures state is returned once done
607
+ """
608
+
609
+ def __init__(self, args=None, kwargs=None, forkdata_capable=True):
610
+ """
611
+ :param args:
612
+ :param kwargs:
613
+ :param forkdata_capable: whether fork is forkdata capable and will use copy-on-write forking of args/kwargs
614
+ """
615
+ self.forkdata_capable = forkdata_capable
616
+ if self.forkdata_capable:
617
+ self.has_args = args is not None
618
+ self.has_kwargs = kwargs is not None
619
+ forkdatacontext.args = args
620
+ forkdatacontext.kwargs = kwargs
621
+ else:
622
+ self.has_args = False
623
+ self.has_kwargs = False
624
+
625
+ def __enter__(self):
626
+ try:
627
+ # flush all outputs so doesn't happen during fork -- don't print/log inside ForkContext contexts!
628
+ sys.stdout.flush()
629
+ sys.stderr.flush()
630
+ except BaseException as e:
631
+ # exit not called if exception, and don't want to leave forkdatacontext filled in that case
632
+ print("ForkContext failure on enter: %s" % str(e))
633
+ self.finally_act()
634
+ raise
635
+ return self
636
+
637
+ def __exit__(self, exc_type, exc_value, exc_traceback):
638
+ self.finally_act()
639
+
640
+ def finally_act(self):
641
+ """
642
+ Done when exception hit or exit is reached in context
643
+ first reset forkdatacontext as crucial to have reset even if later 2 calls fail
644
+ :return: None
645
+ """
646
+ if self.forkdata_capable and (self.has_args or self.has_kwargs):
647
+ forkdatacontext._reset()
648
+
649
+
650
+ class _ForkDataContext(threading.local):
651
+ def __init__(
652
+ self,
653
+ args=None,
654
+ kwargs=None,
655
+ ):
656
+ """
657
+ Global context for fork to carry data to subprocess instead of relying upon copy/pickle/serialization
658
+
659
+ :param args: args
660
+ :param kwargs: kwargs
661
+ """
662
+ assert isinstance(args, (tuple, type(None)))
663
+ assert isinstance(kwargs, (dict, type(None)))
664
+ self.__args = args
665
+ self.__kwargs = kwargs
666
+
667
+ @property
668
+ def args(self) -> Tuple:
669
+ """returns args"""
670
+ return self.__args
671
+
672
+ @args.setter
673
+ def args(self, args):
674
+ if self.__args is not None:
675
+ raise AttributeError(
676
+ "args cannot be overwritten: %s %s" % (str(self.__args), str(self.__kwargs))
677
+ )
678
+
679
+ self.__args = args
680
+
681
+ @property
682
+ def kwargs(self) -> Dict:
683
+ """returns kwargs"""
684
+ return self.__kwargs
685
+
686
+ @kwargs.setter
687
+ def kwargs(self, kwargs):
688
+ if self.__kwargs is not None:
689
+ raise AttributeError(
690
+ "kwargs cannot be overwritten: %s %s" % (str(self.__args), str(self.__kwargs))
691
+ )
692
+
693
+ self.__kwargs = kwargs
694
+
695
+ def _reset(self):
696
+ """Reset fork arg-kwarg context to default values"""
697
+ self.__args = None
698
+ self.__kwargs = None
699
+
700
+ def get_args_kwargs(self, func, args, kwargs) -> Tuple[Callable, Tuple, Dict]:
701
+ if self.__args:
702
+ args = self.__args[1:]
703
+ if not func:
704
+ assert len(self.__args) > 0, "if have no func, must have in args"
705
+ func = self.__args[0] # should always be there
706
+ if self.__kwargs:
707
+ kwargs = self.__kwargs
708
+ try:
709
+ return func, args, kwargs
710
+ finally:
711
+ forkdatacontext._reset()
712
+
713
+ @staticmethod
714
+ def get_args_kwargs_for_traced_func(func, args, kwargs):
715
+ """
716
+ Return args/kwargs out of forkdatacontext when using copy-on-write way of passing args/kwargs
717
+ :param func: actual function ran by _traced_func, which itself is directly what mppool treats as function
718
+ :param args:
719
+ :param kwargs:
720
+ :return: func, args, kwargs from forkdatacontext if used, else originals
721
+ """
722
+ # first 3 lines are debug
723
+ func_was_None = func is None
724
+ args_was_None_or_empty = args is None or len(args) == 0
725
+ kwargs_was_None_or_empty = kwargs is None or len(kwargs) == 0
726
+
727
+ forkdatacontext_args_was_None = forkdatacontext.args is None
728
+ forkdatacontext_kwargs_was_None = forkdatacontext.kwargs is None
729
+ func, args, kwargs = forkdatacontext.get_args_kwargs(func, args, kwargs)
730
+ using_forkdatacontext = func_was_None and func is not None # pulled func out of forkdatacontext.__args[0]
731
+ assert forkdatacontext.args is None, "forkdatacontext.args should be None after get_args_kwargs"
732
+ assert forkdatacontext.kwargs is None, "forkdatacontext.kwargs should be None after get_args_kwargs"
733
+
734
+ proc_type = kwargs.get('proc_type', 'SUBPROCESS')
735
+ if using_forkdatacontext:
736
+ assert proc_type == "SUBPROCESS" or proc_type == "SUBPROCESS"
737
+ if proc_type == "NORMAL":
738
+ assert forkdatacontext_args_was_None, "if no fork, expect forkdatacontext.args None entering _traced_func"
739
+ assert forkdatacontext_kwargs_was_None, "if no fork, expect forkdatacontext.kwargs None entering _traced_func"
740
+ assert func is not None, "function should not be None, indicates original args[0] was None or args was None"
741
+
742
+ return func, args, kwargs
743
+
744
+
745
+ forkdatacontext = _ForkDataContext()
746
+
747
+
748
+ def _traced_func(func, *args, **kwargs):
749
+ func, args, kwargs = forkdatacontext.get_args_kwargs_for_traced_func(func, args, kwargs)
750
+ return func(*args, **kwargs)
751
+
752
+
753
+ def call_subprocess_onetask(func, args=None, kwargs=None):
754
+ if isinstance(args, list):
755
+ args = tuple(args)
756
+ if args is None:
757
+ args = ()
758
+ if kwargs is None:
759
+ kwargs = {}
760
+ args = list(args)
761
+ args = [func] + args
762
+ args = tuple(args)
763
+ with ForkContext(args=args, kwargs=kwargs):
764
+ args = (None,)
765
+ kwargs = {}
766
+ with ProcessPoolExecutor(max_workers=1) as executor:
767
+ future = executor.submit(_traced_func, *args, **kwargs)
768
+ return future.result()