Spaces:
Runtime error
Runtime error
pseudotensor
commited on
Commit
•
1e8c453
1
Parent(s):
663f03d
Update with h2oGPT hash 236c95819e80ab122193bfb843b55618ae285c39
Browse files- client_test.py +78 -20
- create_data.py +2 -2
- enums.py +38 -0
- finetune.py +11 -11
- generate.py +978 -170
- gpt4all_llm.py +64 -25
- gpt_langchain.py +789 -144
- gradio_runner.py +0 -0
- gradio_themes.py +46 -6
- gradio_utils/__pycache__/css.cpython-310.pyc +0 -0
- gradio_utils/__pycache__/grclient.cpython-310.pyc +0 -0
- gradio_utils/__pycache__/prompt_form.cpython-310.pyc +0 -0
- gradio_utils/css.py +53 -0
- gradio_utils/grclient.py +82 -0
- gradio_utils/prompt_form.py +118 -0
- h2oai_pipeline.py +44 -24
- iterators/__init__.py +4 -0
- iterators/__pycache__/__init__.cpython-310.pyc +0 -0
- iterators/__pycache__/iterator_pipe.cpython-310.pyc +0 -0
- iterators/__pycache__/timeout_iterator.cpython-310.pyc +0 -0
- iterators/iterator_pipe.py +93 -0
- iterators/timeout_iterator.py +170 -0
- loaders.py +5 -2
- prompter.py +243 -99
- requirements.txt +35 -35
- stopping.py +9 -4
- utils.py +105 -32
- utils_langchain.py +64 -0
client_test.py
CHANGED
@@ -97,6 +97,8 @@ def get_args(prompt, prompt_type, chat=False, stream_output=False,
|
|
97 |
chunk_size=512,
|
98 |
document_choice=[DocumentChoices.All_Relevant.name],
|
99 |
)
|
|
|
|
|
100 |
if chat:
|
101 |
# add chatbot output on end. Assumes serialize=False
|
102 |
kwargs.update(dict(chatbot=[]))
|
@@ -105,8 +107,8 @@ def get_args(prompt, prompt_type, chat=False, stream_output=False,
|
|
105 |
|
106 |
|
107 |
@pytest.mark.skip(reason="For manual use against some server, no server launched")
|
108 |
-
def test_client_basic():
|
109 |
-
return run_client_nochat(prompt='Who are you?', prompt_type=
|
110 |
|
111 |
|
112 |
def run_client_nochat(prompt, prompt_type, max_new_tokens):
|
@@ -122,12 +124,12 @@ def run_client_nochat(prompt, prompt_type, max_new_tokens):
|
|
122 |
res_dict = dict(prompt=kwargs['instruction_nochat'], iinput=kwargs['iinput_nochat'],
|
123 |
response=md_to_text(res))
|
124 |
print(res_dict)
|
125 |
-
return res_dict
|
126 |
|
127 |
|
128 |
@pytest.mark.skip(reason="For manual use against some server, no server launched")
|
129 |
-
def test_client_basic_api():
|
130 |
-
return run_client_nochat_api(prompt='Who are you?', prompt_type=
|
131 |
|
132 |
|
133 |
def run_client_nochat_api(prompt, prompt_type, max_new_tokens):
|
@@ -144,12 +146,12 @@ def run_client_nochat_api(prompt, prompt_type, max_new_tokens):
|
|
144 |
response=md_to_text(ast.literal_eval(res)['response']),
|
145 |
sources=ast.literal_eval(res)['sources'])
|
146 |
print(res_dict)
|
147 |
-
return res_dict
|
148 |
|
149 |
|
150 |
@pytest.mark.skip(reason="For manual use against some server, no server launched")
|
151 |
-
def test_client_basic_api_lean():
|
152 |
-
return run_client_nochat_api_lean(prompt='Who are you?', prompt_type=
|
153 |
|
154 |
|
155 |
def run_client_nochat_api_lean(prompt, prompt_type, max_new_tokens):
|
@@ -166,21 +168,21 @@ def run_client_nochat_api_lean(prompt, prompt_type, max_new_tokens):
|
|
166 |
response=md_to_text(ast.literal_eval(res)['response']),
|
167 |
sources=ast.literal_eval(res)['sources'])
|
168 |
print(res_dict)
|
169 |
-
return res_dict
|
170 |
|
171 |
|
172 |
@pytest.mark.skip(reason="For manual use against some server, no server launched")
|
173 |
-
def test_client_basic_api_lean_morestuff():
|
174 |
-
return run_client_nochat_api_lean_morestuff(prompt='Who are you?', prompt_type=
|
175 |
|
176 |
|
177 |
-
def run_client_nochat_api_lean_morestuff(prompt, prompt_type, max_new_tokens):
|
178 |
kwargs = dict(
|
179 |
instruction='',
|
180 |
iinput='',
|
181 |
context='',
|
182 |
stream_output=False,
|
183 |
-
prompt_type=
|
184 |
temperature=0.1,
|
185 |
top_p=0.75,
|
186 |
top_k=40,
|
@@ -211,12 +213,19 @@ def run_client_nochat_api_lean_morestuff(prompt, prompt_type, max_new_tokens):
|
|
211 |
response=md_to_text(ast.literal_eval(res)['response']),
|
212 |
sources=ast.literal_eval(res)['sources'])
|
213 |
print(res_dict)
|
214 |
-
return res_dict
|
215 |
|
216 |
|
217 |
@pytest.mark.skip(reason="For manual use against some server, no server launched")
|
218 |
-
def test_client_chat():
|
219 |
-
return run_client_chat(prompt='Who are you?', prompt_type=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
220 |
langchain_mode='Disabled')
|
221 |
|
222 |
|
@@ -229,6 +238,7 @@ def run_client_chat(prompt, prompt_type, stream_output, max_new_tokens, langchai
|
|
229 |
|
230 |
|
231 |
def run_client(client, prompt, args, kwargs, do_md_to_text=True, verbose=False):
|
|
|
232 |
res = client.predict(*tuple(args), api_name='/instruction')
|
233 |
args[-1] += [res[-1]]
|
234 |
|
@@ -262,6 +272,46 @@ def run_client(client, prompt, args, kwargs, do_md_to_text=True, verbose=False):
|
|
262 |
return res_dict, client
|
263 |
|
264 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
265 |
def md_to_text(md, do_md_to_text=True):
|
266 |
if not do_md_to_text:
|
267 |
return md
|
@@ -271,8 +321,16 @@ def md_to_text(md, do_md_to_text=True):
|
|
271 |
return soup.get_text()
|
272 |
|
273 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
274 |
if __name__ == '__main__':
|
275 |
-
|
276 |
-
test_client_basic_api()
|
277 |
-
test_client_basic_api_lean()
|
278 |
-
test_client_basic_api_lean_morestuff()
|
|
|
97 |
chunk_size=512,
|
98 |
document_choice=[DocumentChoices.All_Relevant.name],
|
99 |
)
|
100 |
+
from generate import eval_func_param_names
|
101 |
+
assert len(set(eval_func_param_names).difference(set(list(kwargs.keys())))) == 0
|
102 |
if chat:
|
103 |
# add chatbot output on end. Assumes serialize=False
|
104 |
kwargs.update(dict(chatbot=[]))
|
|
|
107 |
|
108 |
|
109 |
@pytest.mark.skip(reason="For manual use against some server, no server launched")
|
110 |
+
def test_client_basic(prompt_type='human_bot'):
|
111 |
+
return run_client_nochat(prompt='Who are you?', prompt_type=prompt_type, max_new_tokens=50)
|
112 |
|
113 |
|
114 |
def run_client_nochat(prompt, prompt_type, max_new_tokens):
|
|
|
124 |
res_dict = dict(prompt=kwargs['instruction_nochat'], iinput=kwargs['iinput_nochat'],
|
125 |
response=md_to_text(res))
|
126 |
print(res_dict)
|
127 |
+
return res_dict, client
|
128 |
|
129 |
|
130 |
@pytest.mark.skip(reason="For manual use against some server, no server launched")
|
131 |
+
def test_client_basic_api(prompt_type='human_bot'):
|
132 |
+
return run_client_nochat_api(prompt='Who are you?', prompt_type=prompt_type, max_new_tokens=50)
|
133 |
|
134 |
|
135 |
def run_client_nochat_api(prompt, prompt_type, max_new_tokens):
|
|
|
146 |
response=md_to_text(ast.literal_eval(res)['response']),
|
147 |
sources=ast.literal_eval(res)['sources'])
|
148 |
print(res_dict)
|
149 |
+
return res_dict, client
|
150 |
|
151 |
|
152 |
@pytest.mark.skip(reason="For manual use against some server, no server launched")
|
153 |
+
def test_client_basic_api_lean(prompt_type='human_bot'):
|
154 |
+
return run_client_nochat_api_lean(prompt='Who are you?', prompt_type=prompt_type, max_new_tokens=50)
|
155 |
|
156 |
|
157 |
def run_client_nochat_api_lean(prompt, prompt_type, max_new_tokens):
|
|
|
168 |
response=md_to_text(ast.literal_eval(res)['response']),
|
169 |
sources=ast.literal_eval(res)['sources'])
|
170 |
print(res_dict)
|
171 |
+
return res_dict, client
|
172 |
|
173 |
|
174 |
@pytest.mark.skip(reason="For manual use against some server, no server launched")
|
175 |
+
def test_client_basic_api_lean_morestuff(prompt_type='human_bot'):
|
176 |
+
return run_client_nochat_api_lean_morestuff(prompt='Who are you?', prompt_type=prompt_type, max_new_tokens=50)
|
177 |
|
178 |
|
179 |
+
def run_client_nochat_api_lean_morestuff(prompt, prompt_type='human_bot', max_new_tokens=512):
|
180 |
kwargs = dict(
|
181 |
instruction='',
|
182 |
iinput='',
|
183 |
context='',
|
184 |
stream_output=False,
|
185 |
+
prompt_type=prompt_type,
|
186 |
temperature=0.1,
|
187 |
top_p=0.75,
|
188 |
top_k=40,
|
|
|
213 |
response=md_to_text(ast.literal_eval(res)['response']),
|
214 |
sources=ast.literal_eval(res)['sources'])
|
215 |
print(res_dict)
|
216 |
+
return res_dict, client
|
217 |
|
218 |
|
219 |
@pytest.mark.skip(reason="For manual use against some server, no server launched")
|
220 |
+
def test_client_chat(prompt_type='human_bot'):
|
221 |
+
return run_client_chat(prompt='Who are you?', prompt_type=prompt_type, stream_output=False, max_new_tokens=50,
|
222 |
+
langchain_mode='Disabled')
|
223 |
+
|
224 |
+
|
225 |
+
@pytest.mark.skip(reason="For manual use against some server, no server launched")
|
226 |
+
def test_client_chat_stream(prompt_type='human_bot'):
|
227 |
+
return run_client_chat(prompt="Tell a very long kid's story about birds.", prompt_type=prompt_type,
|
228 |
+
stream_output=True, max_new_tokens=512,
|
229 |
langchain_mode='Disabled')
|
230 |
|
231 |
|
|
|
238 |
|
239 |
|
240 |
def run_client(client, prompt, args, kwargs, do_md_to_text=True, verbose=False):
|
241 |
+
assert kwargs['chat'], "Chat mode only"
|
242 |
res = client.predict(*tuple(args), api_name='/instruction')
|
243 |
args[-1] += [res[-1]]
|
244 |
|
|
|
272 |
return res_dict, client
|
273 |
|
274 |
|
275 |
+
@pytest.mark.skip(reason="For manual use against some server, no server launched")
|
276 |
+
def test_client_nochat_stream(prompt_type='human_bot'):
|
277 |
+
return run_client_nochat_gen(prompt="Tell a very long kid's story about birds.", prompt_type=prompt_type,
|
278 |
+
stream_output=True, max_new_tokens=512,
|
279 |
+
langchain_mode='Disabled')
|
280 |
+
|
281 |
+
|
282 |
+
def run_client_nochat_gen(prompt, prompt_type, stream_output, max_new_tokens, langchain_mode):
|
283 |
+
client = get_client(serialize=False)
|
284 |
+
|
285 |
+
kwargs, args = get_args(prompt, prompt_type, chat=False, stream_output=stream_output,
|
286 |
+
max_new_tokens=max_new_tokens, langchain_mode=langchain_mode)
|
287 |
+
return run_client_gen(client, prompt, args, kwargs)
|
288 |
+
|
289 |
+
|
290 |
+
def run_client_gen(client, prompt, args, kwargs, do_md_to_text=True, verbose=False):
|
291 |
+
res_dict = kwargs
|
292 |
+
res_dict['prompt'] = prompt
|
293 |
+
if not kwargs['stream_output']:
|
294 |
+
res = client.predict(str(dict(kwargs)), api_name='/submit_nochat_api')
|
295 |
+
res_dict['response'] = res[0]
|
296 |
+
print(md_to_text(res_dict['response'], do_md_to_text=do_md_to_text))
|
297 |
+
return res_dict, client
|
298 |
+
else:
|
299 |
+
job = client.submit(str(dict(kwargs)), api_name='/submit_nochat_api')
|
300 |
+
while not job.done():
|
301 |
+
outputs_list = job.communicator.job.outputs
|
302 |
+
if outputs_list:
|
303 |
+
res = job.communicator.job.outputs[-1]
|
304 |
+
res_dict = ast.literal_eval(res)
|
305 |
+
print('Stream: %s' % res_dict['response'])
|
306 |
+
time.sleep(0.1)
|
307 |
+
res_list = job.outputs()
|
308 |
+
assert len(res_list) > 0, "No response, check server"
|
309 |
+
res = res_list[-1]
|
310 |
+
res_dict = ast.literal_eval(res)
|
311 |
+
print('Final: %s' % res_dict['response'])
|
312 |
+
return res_dict, client
|
313 |
+
|
314 |
+
|
315 |
def md_to_text(md, do_md_to_text=True):
|
316 |
if not do_md_to_text:
|
317 |
return md
|
|
|
321 |
return soup.get_text()
|
322 |
|
323 |
|
324 |
+
def run_client_many(prompt_type='human_bot'):
|
325 |
+
ret1, _ = test_client_chat(prompt_type=prompt_type)
|
326 |
+
ret2, _ = test_client_chat_stream(prompt_type=prompt_type)
|
327 |
+
ret3, _ = test_client_nochat_stream(prompt_type=prompt_type)
|
328 |
+
ret4, _ = test_client_basic(prompt_type=prompt_type)
|
329 |
+
ret5, _ = test_client_basic_api(prompt_type=prompt_type)
|
330 |
+
ret6, _ = test_client_basic_api_lean(prompt_type=prompt_type)
|
331 |
+
ret7, _ = test_client_basic_api_lean_morestuff(prompt_type=prompt_type)
|
332 |
+
return ret1, ret2, ret3, ret4, ret5, ret6, ret7
|
333 |
+
|
334 |
+
|
335 |
if __name__ == '__main__':
|
336 |
+
run_client_many()
|
|
|
|
|
|
create_data.py
CHANGED
@@ -567,7 +567,7 @@ def test_show_prompts():
|
|
567 |
from prompter import generate_prompt
|
568 |
for data_points in file_points:
|
569 |
for data_point in data_points:
|
570 |
-
print(generate_prompt(data_point, 'plain', '', False, False)[0])
|
571 |
|
572 |
|
573 |
def test_get_open_datasets():
|
@@ -1571,7 +1571,7 @@ def test_check_stats_data():
|
|
1571 |
|
1572 |
llama_type = False
|
1573 |
tokenizer_base_model = base_model = 'h2oai/h2ogpt-oasst1-512-20b'
|
1574 |
-
model_loader, tokenizer_loader = get_loaders(
|
1575 |
local_files_only = False
|
1576 |
resume_download = True
|
1577 |
use_auth_token = False
|
|
|
567 |
from prompter import generate_prompt
|
568 |
for data_points in file_points:
|
569 |
for data_point in data_points:
|
570 |
+
print(generate_prompt(data_point, 'plain', '', False, False, False)[0])
|
571 |
|
572 |
|
573 |
def test_get_open_datasets():
|
|
|
1571 |
|
1572 |
llama_type = False
|
1573 |
tokenizer_base_model = base_model = 'h2oai/h2ogpt-oasst1-512-20b'
|
1574 |
+
model_loader, tokenizer_loader = get_loaders(model_name=base_model, reward_type=False, llama_type=llama_type)
|
1575 |
local_files_only = False
|
1576 |
resume_download = True
|
1577 |
use_auth_token = False
|
enums.py
CHANGED
@@ -22,6 +22,12 @@ class PromptType(Enum):
|
|
22 |
wizard2 = 16
|
23 |
wizard3 = 17
|
24 |
instruct_simple = 18
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
|
26 |
|
27 |
class DocumentChoices(Enum):
|
@@ -44,3 +50,35 @@ class LangChainMode(Enum):
|
|
44 |
MY_DATA = "MyData"
|
45 |
GITHUB_H2OGPT = "github h2oGPT"
|
46 |
H2O_DAI_DOCS = "DriverlessAI docs"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
wizard2 = 16
|
23 |
wizard3 = 17
|
24 |
instruct_simple = 18
|
25 |
+
wizard_vicuna = 19
|
26 |
+
openai = 20
|
27 |
+
openai_chat = 21
|
28 |
+
gptj = 22
|
29 |
+
prompt_answer_openllama = 23
|
30 |
+
vicuna11 = 24
|
31 |
|
32 |
|
33 |
class DocumentChoices(Enum):
|
|
|
50 |
MY_DATA = "MyData"
|
51 |
GITHUB_H2OGPT = "github h2oGPT"
|
52 |
H2O_DAI_DOCS = "DriverlessAI docs"
|
53 |
+
|
54 |
+
|
55 |
+
no_server_str = no_lora_str = no_model_str = '[None/Remove]'
|
56 |
+
|
57 |
+
|
58 |
+
# from site-packages/langchain/llms/openai.py, but needed since ChatOpenAI doesn't have this information
|
59 |
+
model_token_mapping = {
|
60 |
+
"gpt-4": 8192,
|
61 |
+
"gpt-4-0314": 8192,
|
62 |
+
"gpt-4-32k": 32768,
|
63 |
+
"gpt-4-32k-0314": 32768,
|
64 |
+
"gpt-3.5-turbo": 4096,
|
65 |
+
"gpt-3.5-turbo-16k": 16*1024,
|
66 |
+
"gpt-3.5-turbo-0301": 4096,
|
67 |
+
"text-ada-001": 2049,
|
68 |
+
"ada": 2049,
|
69 |
+
"text-babbage-001": 2040,
|
70 |
+
"babbage": 2049,
|
71 |
+
"text-curie-001": 2049,
|
72 |
+
"curie": 2049,
|
73 |
+
"davinci": 2049,
|
74 |
+
"text-davinci-003": 4097,
|
75 |
+
"text-davinci-002": 4097,
|
76 |
+
"code-davinci-002": 8001,
|
77 |
+
"code-davinci-001": 8001,
|
78 |
+
"code-cushman-002": 2048,
|
79 |
+
"code-cushman-001": 2048,
|
80 |
+
}
|
81 |
+
|
82 |
+
|
83 |
+
source_prefix = "Sources [Score | Link]:"
|
84 |
+
source_postfix = "End Sources<p>"
|
finetune.py
CHANGED
@@ -5,6 +5,9 @@ from typing import List, Union
|
|
5 |
import fire
|
6 |
import numpy as np
|
7 |
|
|
|
|
|
|
|
8 |
from loaders import get_loaders, get_tokenizer
|
9 |
from prompter import generate_prompt, prompt_types, PromptType
|
10 |
from utils import get_githash, copy_code
|
@@ -182,7 +185,7 @@ def train(
|
|
182 |
log("num_gpus: %d" % gpus)
|
183 |
log("max mem: %s" % max_memory)
|
184 |
|
185 |
-
model_loader, tokenizer_loader = get_loaders(
|
186 |
|
187 |
model = model_loader.from_pretrained(
|
188 |
base_model,
|
@@ -556,13 +559,6 @@ def train(
|
|
556 |
)
|
557 |
model.config.use_cache = False
|
558 |
|
559 |
-
old_state_dict = model.state_dict
|
560 |
-
from peft import get_peft_model_state_dict
|
561 |
-
|
562 |
-
model.state_dict = (
|
563 |
-
lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict())
|
564 |
-
).__get__(model, type(model))
|
565 |
-
|
566 |
if torch.__version__ >= "2" and sys.platform != "win32":
|
567 |
model = torch.compile(model)
|
568 |
# WIP (not generally replacing layers until pytorch 2.1)
|
@@ -621,10 +617,10 @@ def generate_and_tokenize_prompt(data_point, prompt_type=None, train_on_inputs=F
|
|
621 |
assert tokenizer is not None
|
622 |
prompt_dict = '' # only for custom prompt_type
|
623 |
assert prompt_type != PromptType.custom.name, "custom not setup for finetune"
|
624 |
-
full_prompt, _, _, _ = generate_prompt(data_point, prompt_type, prompt_dict, False, False)
|
625 |
tokenized_full_prompt = tokenize(full_prompt, tokenizer, cutoff_len, add_eos_token=add_eos_token)
|
626 |
if not train_on_inputs:
|
627 |
-
user_prompt, _, _, _ = generate_prompt({**data_point, "output": ""}, prompt_type, prompt_dict, False, False)
|
628 |
tokenized_user_prompt = tokenize(user_prompt, tokenizer, cutoff_len, add_eos_token=add_eos_token)
|
629 |
user_prompt_len = len(tokenized_user_prompt["input_ids"])
|
630 |
if add_eos_token:
|
@@ -643,7 +639,7 @@ def test_debug():
|
|
643 |
fire.Fire(train)
|
644 |
|
645 |
|
646 |
-
|
647 |
CONFIG = "NCCL_P2P_LEVEL=LOC WORLD_SIZE=5 torchrun --nnodes=5 --master_addr=10.10.10.2 --master_port=1111 --nproc_per_node=1"
|
648 |
CMD = "finetune.py --data_path=config.json --num_epochs=1 --base_model=decapoda-research/llama-13b-hf"
|
649 |
log(f"""
|
@@ -674,3 +670,7 @@ NCCL_P2P_LEVEL=LOC WORLD_SIZE=7 CUDA_VISIBLE_DEVICES="0,1" torchrun --node_rank
|
|
674 |
"CUDA_VISIBLE_DEVICES") is not None, "Run python script using: torchrun finetune.py OR set CUDA_VISIBLE_DEVICES to single GPU"
|
675 |
|
676 |
fire.Fire(train)
|
|
|
|
|
|
|
|
|
|
5 |
import fire
|
6 |
import numpy as np
|
7 |
|
8 |
+
if os.path.dirname(os.path.abspath(__file__)) not in sys.path:
|
9 |
+
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
10 |
+
|
11 |
from loaders import get_loaders, get_tokenizer
|
12 |
from prompter import generate_prompt, prompt_types, PromptType
|
13 |
from utils import get_githash, copy_code
|
|
|
185 |
log("num_gpus: %d" % gpus)
|
186 |
log("max mem: %s" % max_memory)
|
187 |
|
188 |
+
model_loader, tokenizer_loader = get_loaders(model_name=base_model, reward_type=False, llama_type=llama_type)
|
189 |
|
190 |
model = model_loader.from_pretrained(
|
191 |
base_model,
|
|
|
559 |
)
|
560 |
model.config.use_cache = False
|
561 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
562 |
if torch.__version__ >= "2" and sys.platform != "win32":
|
563 |
model = torch.compile(model)
|
564 |
# WIP (not generally replacing layers until pytorch 2.1)
|
|
|
617 |
assert tokenizer is not None
|
618 |
prompt_dict = '' # only for custom prompt_type
|
619 |
assert prompt_type != PromptType.custom.name, "custom not setup for finetune"
|
620 |
+
full_prompt, _, _, _, _ = generate_prompt(data_point, prompt_type, prompt_dict, False, False, False)
|
621 |
tokenized_full_prompt = tokenize(full_prompt, tokenizer, cutoff_len, add_eos_token=add_eos_token)
|
622 |
if not train_on_inputs:
|
623 |
+
user_prompt, _, _, _, _ = generate_prompt({**data_point, "output": ""}, prompt_type, prompt_dict, False, False, False)
|
624 |
tokenized_user_prompt = tokenize(user_prompt, tokenizer, cutoff_len, add_eos_token=add_eos_token)
|
625 |
user_prompt_len = len(tokenized_user_prompt["input_ids"])
|
626 |
if add_eos_token:
|
|
|
639 |
fire.Fire(train)
|
640 |
|
641 |
|
642 |
+
def entrypoint_main():
|
643 |
CONFIG = "NCCL_P2P_LEVEL=LOC WORLD_SIZE=5 torchrun --nnodes=5 --master_addr=10.10.10.2 --master_port=1111 --nproc_per_node=1"
|
644 |
CMD = "finetune.py --data_path=config.json --num_epochs=1 --base_model=decapoda-research/llama-13b-hf"
|
645 |
log(f"""
|
|
|
670 |
"CUDA_VISIBLE_DEVICES") is not None, "Run python script using: torchrun finetune.py OR set CUDA_VISIBLE_DEVICES to single GPU"
|
671 |
|
672 |
fire.Fire(train)
|
673 |
+
|
674 |
+
|
675 |
+
if __name__ == "__main__":
|
676 |
+
entrypoint_main()
|
generate.py
CHANGED
@@ -1,27 +1,37 @@
|
|
1 |
import ast
|
|
|
2 |
import functools
|
3 |
import glob
|
4 |
import inspect
|
5 |
import queue
|
6 |
-
import shutil
|
7 |
import sys
|
8 |
import os
|
9 |
import time
|
10 |
import traceback
|
|
|
11 |
import typing
|
12 |
import warnings
|
13 |
from datetime import datetime
|
14 |
import filelock
|
|
|
15 |
import psutil
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
|
18 |
os.environ['BITSANDBYTES_NOWELCOME'] = '1'
|
19 |
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
|
20 |
|
21 |
-
from enums import DocumentChoices, LangChainMode
|
|
|
22 |
from loaders import get_loaders
|
23 |
from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, wrapped_partial, EThread, get_githash, \
|
24 |
-
import_matplotlib, get_device, makedirs, get_kwargs, start_faulthandler
|
25 |
|
26 |
start_faulthandler()
|
27 |
import_matplotlib()
|
@@ -34,9 +44,8 @@ from typing import Union
|
|
34 |
import fire
|
35 |
import torch
|
36 |
from transformers import GenerationConfig, AutoModel, TextIteratorStreamer
|
37 |
-
from accelerate import init_empty_weights, infer_auto_device_map
|
38 |
|
39 |
-
from prompter import Prompter, inv_prompt_type_to_model_lower, non_hf_types, PromptType, get_prompt
|
40 |
from stopping import get_stopping
|
41 |
|
42 |
eval_extra_columns = ['prompt', 'response', 'score']
|
@@ -56,9 +65,15 @@ def main(
|
|
56 |
lora_weights: str = "",
|
57 |
gpu_id: int = 0,
|
58 |
compile_model: bool = True,
|
59 |
-
|
|
|
60 |
prompt_type: Union[int, str] = None,
|
61 |
prompt_dict: typing.Dict = None,
|
|
|
|
|
|
|
|
|
|
|
62 |
# input to generation
|
63 |
temperature: float = None,
|
64 |
top_p: float = None,
|
@@ -88,14 +103,13 @@ def main(
|
|
88 |
cli: bool = False,
|
89 |
cli_loop: bool = True,
|
90 |
gradio: bool = True,
|
91 |
-
gradio_avoid_processing_markdown: bool = False,
|
92 |
gradio_offline_level: int = 0,
|
93 |
chat: bool = True,
|
94 |
chat_context: bool = False,
|
95 |
stream_output: bool = True,
|
96 |
show_examples: bool = None,
|
97 |
verbose: bool = False,
|
98 |
-
h2ocolors: bool =
|
99 |
height: int = 600,
|
100 |
show_lora: bool = True,
|
101 |
login_mode_if_model0: bool = False,
|
@@ -104,16 +118,19 @@ def main(
|
|
104 |
api_open: bool = False,
|
105 |
allow_api: bool = True,
|
106 |
input_lines: int = 1,
|
|
|
107 |
auth: typing.List[typing.Tuple[str, str]] = None,
|
|
|
|
|
108 |
|
109 |
-
sanitize_user_prompt: bool =
|
110 |
-
sanitize_bot_response: bool =
|
111 |
|
112 |
extra_model_options: typing.List[str] = [],
|
113 |
extra_lora_options: typing.List[str] = [],
|
|
|
114 |
|
115 |
score_model: str = 'OpenAssistant/reward-model-deberta-v3-large-v2',
|
116 |
-
auto_score: bool = True,
|
117 |
|
118 |
eval_filename: str = None,
|
119 |
eval_prompts_only_num: int = 0,
|
@@ -121,6 +138,7 @@ def main(
|
|
121 |
eval_as_output: bool = False,
|
122 |
|
123 |
langchain_mode: str = 'Disabled',
|
|
|
124 |
visible_langchain_modes: list = ['UserData', 'MyData'],
|
125 |
document_choice: list = [DocumentChoices.All_Relevant.name],
|
126 |
user_path: str = None,
|
@@ -138,7 +156,10 @@ def main(
|
|
138 |
enable_sources_list: bool = True,
|
139 |
chunk: bool = True,
|
140 |
chunk_size: int = 512,
|
141 |
-
top_k_docs: int =
|
|
|
|
|
|
|
142 |
n_jobs: int = -1,
|
143 |
enable_captions: bool = True,
|
144 |
captions_model: str = "Salesforce/blip-image-captioning-base",
|
@@ -157,8 +178,31 @@ def main(
|
|
157 |
:param lora_weights: LORA weights path/HF link
|
158 |
:param gpu_id: if infer_devices, then use gpu_id for cuda device ID, or auto mode if gpu_id != -1
|
159 |
:param compile_model Whether to compile the model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
160 |
:param prompt_type: type of prompt, usually matched to fine-tuned model or plain for foundational model
|
161 |
:param prompt_dict: If prompt_type=custom, then expects (some) items returned by get_prompt(..., return_dict=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
162 |
:param temperature: generation temperature
|
163 |
:param top_p: generation top_p
|
164 |
:param top_k: generation top_k
|
@@ -184,13 +228,13 @@ def main(
|
|
184 |
:param cli: whether to use CLI (non-gradio) interface.
|
185 |
:param cli_loop: whether to loop for CLI (False usually only for testing)
|
186 |
:param gradio: whether to enable gradio, or to enable benchmark mode
|
187 |
-
:param gradio_avoid_processing_markdown:
|
188 |
:param gradio_offline_level: > 0, then change fonts so full offline
|
189 |
== 1 means backend won't need internet for fonts, but front-end UI might if font not cached
|
190 |
== 2 means backend and frontend don't need internet to download any fonts.
|
191 |
Note: Some things always disabled include HF telemetry, gradio telemetry, chromadb posthog that involve uploading.
|
192 |
This option further disables google fonts for downloading, which is less intrusive than uploading,
|
193 |
but still required in air-gapped case. The fonts don't look as nice as google fonts, but ensure full offline behavior.
|
|
|
194 |
:param chat: whether to enable chat mode with chat history
|
195 |
:param chat_context: whether to use extra helpful context if human_bot
|
196 |
:param stream_output: whether to stream output from generate
|
@@ -205,20 +249,25 @@ def main(
|
|
205 |
:param api_open: If False, don't let API calls skip gradio queue
|
206 |
:param allow_api: whether to allow API calls at all to gradio server
|
207 |
:param input_lines: how many input lines to show for chat box (>1 forces shift-enter for submit, else enter is submit)
|
|
|
|
|
208 |
:param auth: gradio auth for launcher in form [(user1, pass1), (user2, pass2), ...]
|
209 |
e.g. --auth=[('jon','password')] with no spaces
|
210 |
-
:param
|
211 |
-
:param
|
|
|
|
|
212 |
:param extra_model_options: extra models to show in list in gradio
|
213 |
:param extra_lora_options: extra LORA to show in list in gradio
|
|
|
214 |
:param score_model: which model to score responses (None means no scoring)
|
215 |
-
:param auto_score: whether to automatically score responses
|
216 |
:param eval_filename: json file to use for evaluation, if None is sharegpt
|
217 |
:param eval_prompts_only_num: for no gradio benchmark, if using eval_filename prompts for eval instead of examples
|
218 |
:param eval_prompts_only_seed: for no gradio benchmark, seed for eval_filename sampling
|
219 |
:param eval_as_output: for no gradio benchmark, whether to test eval_filename output itself
|
220 |
:param langchain_mode: Data source to include. Choose "UserData" to only consume files from make_db.py.
|
221 |
WARNING: wiki_full requires extra data processing via read_wiki_full.py and requires really good workstation to generate db, unless already present.
|
|
|
222 |
:param user_path: user path to glob from to generate db for vector search, for 'UserData' langchain mode.
|
223 |
If already have db, any new/changed files are added automatically if path set, does not have to be same path used for prior db sources
|
224 |
:param detect_user_path_changes_every_query: whether to detect if any files changed or added every similarity search (by file hashes).
|
@@ -248,12 +297,17 @@ def main(
|
|
248 |
:param chunk: Whether to chunk data (True unless know data is already optimally chunked)
|
249 |
:param chunk_size: Size of chunks, with typically top-4 passed to LLM, so neesd to be in context length
|
250 |
:param top_k_docs: number of chunks to give LLM
|
|
|
|
|
|
|
|
|
|
|
251 |
:param n_jobs: Number of processors to use when consuming documents (-1 = all, is default)
|
252 |
:param enable_captions: Whether to support captions using BLIP for image files as documents, then preloads that model
|
253 |
:param captions_model: Which model to use for captions.
|
254 |
-
captions_model:
|
255 |
captions_model: str = "Salesforce/blip2-flan-t5-xl", # question/answer capable, 16GB state
|
256 |
-
captions_model:
|
257 |
Note: opt-based blip2 are not permissive license due to opt and Meta license restrictions
|
258 |
:param pre_load_caption_model: Whether to preload caption model, or load after forking parallel doc loader
|
259 |
parallel loading disabled if preload and have images, to prevent deadlocking on cuda context
|
@@ -262,6 +316,32 @@ def main(
|
|
262 |
:param enable_ocr: Whether to support OCR on images
|
263 |
:return:
|
264 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
265 |
is_hf = bool(int(os.getenv("HUGGINGFACE_SPACES", '0')))
|
266 |
is_gpth2oai = bool(int(os.getenv("GPT_H2O_AI", '0')))
|
267 |
is_public = is_hf or is_gpth2oai # multi-user case with fixed model and disclaimer
|
@@ -276,7 +356,8 @@ def main(
|
|
276 |
|
277 |
# allow set token directly
|
278 |
use_auth_token = os.environ.get("HUGGINGFACE_API_TOKEN", use_auth_token)
|
279 |
-
allow_upload_to_user_data = bool(
|
|
|
280 |
allow_upload_to_my_data = bool(int(os.environ.get("allow_upload_to_my_data", str(int(allow_upload_to_my_data)))))
|
281 |
height = int(os.environ.get("HEIGHT", height))
|
282 |
h2ocolors = bool(int(os.getenv('h2ocolors', h2ocolors)))
|
@@ -289,6 +370,12 @@ def main(
|
|
289 |
if langchain_mode not in visible_langchain_modes and langchain_mode in langchain_modes:
|
290 |
visible_langchain_modes += [langchain_mode]
|
291 |
|
|
|
|
|
|
|
|
|
|
|
|
|
292 |
if is_public:
|
293 |
allow_upload_to_user_data = False
|
294 |
input_lines = 1 # ensure set, for ease of use
|
@@ -297,26 +384,50 @@ def main(
|
|
297 |
top_k = 70 if top_k is None else top_k
|
298 |
if is_hf:
|
299 |
do_sample = True if do_sample is None else do_sample
|
|
|
300 |
else:
|
301 |
# by default don't sample, too chatty
|
302 |
do_sample = False if do_sample is None else do_sample
|
|
|
303 |
|
304 |
if memory_restriction_level == 2:
|
305 |
-
if not base_model:
|
306 |
base_model = 'h2oai/h2ogpt-oasst1-512-12b'
|
307 |
# don't set load_8bit if passed base_model, doesn't always work so can't just override
|
308 |
load_8bit = True
|
309 |
load_4bit = False # FIXME - consider using 4-bit instead of 8-bit
|
310 |
-
|
311 |
-
|
312 |
if memory_restriction_level >= 2:
|
313 |
load_8bit = True
|
314 |
load_4bit = False # FIXME - consider using 4-bit instead of 8-bit
|
315 |
if hf_embedding_model is None:
|
316 |
hf_embedding_model = "sentence-transformers/all-MiniLM-L6-v2"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
317 |
if is_hf:
|
318 |
# must override share if in spaces
|
319 |
share = False
|
|
|
|
|
|
|
|
|
|
|
320 |
save_dir = os.getenv('SAVE_DIR', save_dir)
|
321 |
score_model = os.getenv('SCORE_MODEL', score_model)
|
322 |
if score_model == 'None' or score_model is None:
|
@@ -335,7 +446,7 @@ def main(
|
|
335 |
torch.backends.cudnn.benchmark = True
|
336 |
torch.backends.cudnn.enabled = False
|
337 |
torch.set_default_dtype(torch.float32)
|
338 |
-
if psutil.virtual_memory().available < 94 * 1024 ** 3:
|
339 |
# 12B uses ~94GB
|
340 |
# 6.9B uses ~47GB
|
341 |
base_model = 'h2oai/h2ogpt-oig-oasst1-512-6_9b' if not base_model else base_model
|
@@ -360,8 +471,8 @@ def main(
|
|
360 |
|
361 |
if offload_folder:
|
362 |
makedirs(offload_folder)
|
363 |
-
|
364 |
-
|
365 |
|
366 |
placeholder_instruction, placeholder_input, \
|
367 |
stream_output, show_examples, \
|
@@ -386,11 +497,12 @@ def main(
|
|
386 |
verbose,
|
387 |
)
|
388 |
|
|
|
389 |
locals_dict = locals()
|
390 |
locals_print = '\n'.join(['%s: %s' % (k, v) for k, v in locals_dict.items()])
|
391 |
if verbose:
|
392 |
print(f"Generating model with params:\n{locals_print}", flush=True)
|
393 |
-
print("Command: %s\nHash: %s" % (str(' '.join(sys.argv)),
|
394 |
|
395 |
if langchain_mode != "Disabled":
|
396 |
# SECOND PLACE where LangChain referenced, but all imports are kept local so not required
|
@@ -404,7 +516,7 @@ def main(
|
|
404 |
for gpath1 in glob.glob(os.path.join(scratch_base_dir, 'db_dir_%s*' % langchain_mode1)):
|
405 |
if os.path.isdir(gpath1):
|
406 |
print("Removing old MyData: %s" % gpath1, flush=True)
|
407 |
-
|
408 |
continue
|
409 |
if langchain_mode1 in ['All']:
|
410 |
# FIXME: All should be avoided until scans over each db, shouldn't be separate db
|
@@ -430,6 +542,10 @@ def main(
|
|
430 |
assert 'gpt_langchain' not in sys.modules, "Dev bug, import of langchain when should not have"
|
431 |
assert 'langchain' not in sys.modules, "Dev bug, import of langchain when should not have"
|
432 |
|
|
|
|
|
|
|
|
|
433 |
if cli:
|
434 |
from cli import run_cli
|
435 |
return run_cli(**get_kwargs(run_cli, exclude_names=['model_state0'], **locals()))
|
@@ -441,20 +557,68 @@ def main(
|
|
441 |
from gradio_runner import go_gradio
|
442 |
|
443 |
# get default model
|
444 |
-
|
445 |
-
|
446 |
-
|
447 |
-
|
448 |
-
|
449 |
-
|
450 |
-
|
451 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
452 |
|
453 |
# get score model
|
|
|
454 |
smodel, stokenizer, sdevice = get_score_model(reward_type=True,
|
455 |
**get_kwargs(get_score_model, exclude_names=['reward_type'],
|
456 |
**all_kwargs))
|
457 |
-
score_model_state0 =
|
|
|
|
|
458 |
|
459 |
if enable_captions:
|
460 |
if pre_load_caption_model:
|
@@ -469,34 +633,33 @@ def main(
|
|
469 |
go_gradio(**locals())
|
470 |
|
471 |
|
472 |
-
def
|
473 |
-
|
474 |
-
|
475 |
-
|
476 |
-
|
477 |
-
|
478 |
-
|
479 |
-
|
480 |
-
|
481 |
-
|
482 |
-
:param base_model:
|
483 |
-
:param model_loader:
|
484 |
-
:param load_half:
|
485 |
-
:param model_kwargs:
|
486 |
-
:param reward_type:
|
487 |
-
:param gpu_id:
|
488 |
-
:param use_auth_token:
|
489 |
-
:param trust_remote_code:
|
490 |
-
:param offload_folder:
|
491 |
-
:param triton_attn:
|
492 |
-
:param long_sequence:
|
493 |
-
:return:
|
494 |
-
"""
|
495 |
with init_empty_weights():
|
496 |
from transformers import AutoConfig
|
497 |
-
|
498 |
-
|
499 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
500 |
if triton_attn and 'mpt-' in base_model.lower():
|
501 |
config.attn_config['attn_impl'] = 'triton'
|
502 |
if long_sequence:
|
@@ -504,18 +667,36 @@ def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward
|
|
504 |
config.update({"max_seq_len": 83968})
|
505 |
if 'mosaicml/mpt-7b-chat' in base_model.lower():
|
506 |
config.update({"max_seq_len": 4096})
|
507 |
-
|
|
|
|
|
|
|
508 |
model = AutoModel.from_config(
|
509 |
config,
|
|
|
510 |
)
|
511 |
else:
|
512 |
# can't infer
|
513 |
model = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
514 |
|
515 |
if model is not None:
|
516 |
# NOTE: Can specify max_memory={0: max_mem, 1: max_mem}, to shard model
|
517 |
# NOTE: Some models require avoiding sharding some layers,
|
518 |
# then would pass no_split_module_classes and give list of those layers.
|
|
|
519 |
device_map = infer_auto_device_map(
|
520 |
model,
|
521 |
dtype=torch.float16 if load_half else torch.float32,
|
@@ -567,12 +748,59 @@ def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward
|
|
567 |
return model
|
568 |
|
569 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
570 |
def get_model(
|
571 |
load_8bit: bool = False,
|
572 |
load_4bit: bool = False,
|
573 |
load_half: bool = True,
|
574 |
infer_devices: bool = True,
|
575 |
base_model: str = '',
|
|
|
576 |
tokenizer_base_model: str = '',
|
577 |
lora_weights: str = "",
|
578 |
gpu_id: int = 0,
|
@@ -596,6 +824,7 @@ def get_model(
|
|
596 |
For non-LORA case, False will spread shards across multiple GPUs, but this can lead to cuda:x cuda:y mismatches
|
597 |
So it is not the default
|
598 |
:param base_model: name/path of base model
|
|
|
599 |
:param tokenizer_base_model: name/path of tokenizer
|
600 |
:param lora_weights: name/path
|
601 |
:param gpu_id: which GPU (0..n_gpus-1) or allow all GPUs if relevant (-1)
|
@@ -611,11 +840,120 @@ def get_model(
|
|
611 |
"""
|
612 |
if verbose:
|
613 |
print("Get %s model" % base_model, flush=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
614 |
if base_model in non_hf_types:
|
615 |
from gpt4all_llm import get_model_tokenizer_gpt4all
|
616 |
model, tokenizer, device = get_model_tokenizer_gpt4all(base_model)
|
617 |
return model, tokenizer, device
|
618 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
619 |
if lora_weights is not None and lora_weights.strip():
|
620 |
if verbose:
|
621 |
print("Get %s lora weights" % lora_weights, flush=True)
|
@@ -630,30 +968,13 @@ def get_model(
|
|
630 |
"Please choose a base model with --base_model (CLI) or load one from Models Tab (gradio)"
|
631 |
)
|
632 |
|
633 |
-
|
634 |
-
config = AutoConfig.from_pretrained(base_model, use_auth_token=use_auth_token,
|
635 |
-
trust_remote_code=trust_remote_code,
|
636 |
-
offload_folder=offload_folder)
|
637 |
-
llama_type_from_config = 'llama' in str(config).lower()
|
638 |
-
llama_type_from_name = "llama" in base_model.lower()
|
639 |
-
llama_type = llama_type_from_config or llama_type_from_name
|
640 |
-
if llama_type:
|
641 |
-
if verbose:
|
642 |
-
print("Detected as llama type from"
|
643 |
-
" config (%s) or name (%s)" % (llama_type_from_config, llama_type_from_name), flush=True)
|
644 |
|
645 |
-
|
646 |
-
if not tokenizer_base_model:
|
647 |
-
tokenizer_base_model = base_model
|
648 |
|
649 |
if tokenizer_loader is not None and not isinstance(tokenizer_loader, str):
|
650 |
tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model,
|
651 |
-
|
652 |
-
resume_download=resume_download,
|
653 |
-
use_auth_token=use_auth_token,
|
654 |
-
trust_remote_code=trust_remote_code,
|
655 |
-
offload_folder=offload_folder,
|
656 |
-
)
|
657 |
else:
|
658 |
tokenizer = tokenizer_loader
|
659 |
|
@@ -677,7 +998,7 @@ def get_model(
|
|
677 |
load_in_4bit=load_4bit,
|
678 |
device_map={"": 0} if (load_8bit or load_4bit) and device == 'cuda' else "auto",
|
679 |
))
|
680 |
-
if 'mpt-' in base_model.lower() and gpu_id >= 0:
|
681 |
model_kwargs.update(dict(device_map={"": gpu_id} if device == 'cuda' else "cpu"))
|
682 |
|
683 |
if 'OpenAssistant/reward-model'.lower() in base_model.lower():
|
@@ -688,25 +1009,30 @@ def get_model(
|
|
688 |
|
689 |
if not lora_weights:
|
690 |
with torch.device(device):
|
|
|
691 |
if infer_devices:
|
|
|
692 |
model = get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type,
|
|
|
693 |
gpu_id=gpu_id,
|
694 |
-
use_auth_token=use_auth_token,
|
695 |
-
trust_remote_code=trust_remote_code,
|
696 |
-
offload_folder=offload_folder,
|
697 |
)
|
698 |
else:
|
|
|
699 |
if load_half and not (load_8bit or load_4bit):
|
700 |
model = model_loader.from_pretrained(
|
701 |
base_model,
|
|
|
702 |
**model_kwargs).half()
|
703 |
else:
|
704 |
model = model_loader.from_pretrained(
|
705 |
base_model,
|
|
|
706 |
**model_kwargs)
|
707 |
elif load_8bit or load_4bit:
|
|
|
708 |
model = model_loader.from_pretrained(
|
709 |
base_model,
|
|
|
710 |
**model_kwargs
|
711 |
)
|
712 |
from peft import PeftModel # loads cuda, so avoid in global scope
|
@@ -723,8 +1049,10 @@ def get_model(
|
|
723 |
)
|
724 |
else:
|
725 |
with torch.device(device):
|
|
|
726 |
model = model_loader.from_pretrained(
|
727 |
base_model,
|
|
|
728 |
**model_kwargs
|
729 |
)
|
730 |
from peft import PeftModel # loads cuda, so avoid in global scope
|
@@ -758,6 +1086,15 @@ def get_model(
|
|
758 |
if torch.__version__ >= "2" and sys.platform != "win32" and compile_model:
|
759 |
model = torch.compile(model)
|
760 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
761 |
if hasattr(config, 'max_seq_len') and isinstance(config.max_seq_len, int):
|
762 |
tokenizer.model_max_length = config.max_seq_len
|
763 |
elif hasattr(config, 'max_position_embeddings') and isinstance(config.max_position_embeddings, int):
|
@@ -767,8 +1104,9 @@ def get_model(
|
|
767 |
if verbose:
|
768 |
print("Could not determine model_max_length, setting to 2048", flush=True)
|
769 |
tokenizer.model_max_length = 2048
|
770 |
-
|
771 |
-
|
|
|
772 |
|
773 |
|
774 |
def pop_unused_model_kwargs(model_kwargs):
|
@@ -790,6 +1128,7 @@ def get_score_model(score_model: str = None,
|
|
790 |
load_half: bool = True,
|
791 |
infer_devices: bool = True,
|
792 |
base_model: str = '',
|
|
|
793 |
tokenizer_base_model: str = '',
|
794 |
lora_weights: str = "",
|
795 |
gpu_id: int = 0,
|
@@ -811,6 +1150,7 @@ def get_score_model(score_model: str = None,
|
|
811 |
base_model = score_model.strip()
|
812 |
tokenizer_base_model = ''
|
813 |
lora_weights = ''
|
|
|
814 |
llama_type = False
|
815 |
compile_model = False
|
816 |
smodel, stokenizer, sdevice = get_model(reward_type=True,
|
@@ -877,9 +1217,12 @@ def evaluate_from_str(
|
|
877 |
debug=False,
|
878 |
concurrency_count=None,
|
879 |
save_dir=None,
|
880 |
-
sanitize_bot_response=
|
881 |
model_state0=None,
|
882 |
memory_restriction_level=None,
|
|
|
|
|
|
|
883 |
raise_generate_gpu_exceptions=None,
|
884 |
chat_context=None,
|
885 |
lora_weights=None,
|
@@ -890,20 +1233,26 @@ def evaluate_from_str(
|
|
890 |
use_openai_embedding=None,
|
891 |
use_openai_model=None,
|
892 |
hf_embedding_model=None,
|
893 |
-
chunk=None,
|
894 |
-
chunk_size=None,
|
895 |
db_type=None,
|
896 |
n_jobs=None,
|
897 |
first_para=None,
|
898 |
text_limit=None,
|
899 |
verbose=False,
|
900 |
cli=False,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
901 |
):
|
902 |
if isinstance(user_kwargs, str):
|
903 |
user_kwargs = ast.literal_eval(user_kwargs)
|
904 |
# only used for submit_nochat_api
|
905 |
user_kwargs['chat'] = False
|
906 |
-
|
|
|
907 |
if 'langchain_mode' not in user_kwargs:
|
908 |
# if user doesn't specify, then assume disabled, not use default
|
909 |
user_kwargs['langchain_mode'] = 'Disabled'
|
@@ -926,6 +1275,9 @@ def evaluate_from_str(
|
|
926 |
sanitize_bot_response=sanitize_bot_response,
|
927 |
model_state0=model_state0,
|
928 |
memory_restriction_level=memory_restriction_level,
|
|
|
|
|
|
|
929 |
raise_generate_gpu_exceptions=raise_generate_gpu_exceptions,
|
930 |
chat_context=chat_context,
|
931 |
lora_weights=lora_weights,
|
@@ -942,6 +1294,13 @@ def evaluate_from_str(
|
|
942 |
text_limit=text_limit,
|
943 |
verbose=verbose,
|
944 |
cli=cli,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
945 |
)
|
946 |
try:
|
947 |
for ret1 in ret:
|
@@ -986,9 +1345,12 @@ def evaluate(
|
|
986 |
debug=False,
|
987 |
concurrency_count=None,
|
988 |
save_dir=None,
|
989 |
-
sanitize_bot_response=
|
990 |
model_state0=None,
|
991 |
memory_restriction_level=None,
|
|
|
|
|
|
|
992 |
raise_generate_gpu_exceptions=None,
|
993 |
chat_context=None,
|
994 |
lora_weights=None,
|
@@ -1005,6 +1367,13 @@ def evaluate(
|
|
1005 |
text_limit=None,
|
1006 |
verbose=False,
|
1007 |
cli=False,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1008 |
):
|
1009 |
# ensure passed these
|
1010 |
assert concurrency_count is not None
|
@@ -1025,29 +1394,58 @@ def evaluate(
|
|
1025 |
locals_dict = locals().copy()
|
1026 |
locals_dict.pop('model_state', None)
|
1027 |
locals_dict.pop('model_state0', None)
|
|
|
1028 |
print(locals_dict)
|
1029 |
|
1030 |
-
no_model_msg = "Please choose a base model with --base_model (CLI) or load in Models Tab (gradio).\
|
|
|
1031 |
|
|
|
|
|
1032 |
if model_state0 is None:
|
1033 |
# e.g. for no gradio case, set dummy value, else should be set
|
1034 |
-
model_state0 =
|
1035 |
-
|
1036 |
-
|
1037 |
-
|
1038 |
-
|
1039 |
-
|
1040 |
-
|
1041 |
-
|
1042 |
-
|
1043 |
-
|
1044 |
-
|
1045 |
-
|
1046 |
-
|
1047 |
-
|
1048 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1049 |
else:
|
1050 |
raise AssertionError(no_model_msg)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1051 |
|
1052 |
if base_model is None:
|
1053 |
raise AssertionError(no_model_msg)
|
@@ -1061,10 +1459,48 @@ def evaluate(
|
|
1061 |
instruction = instruction_nochat
|
1062 |
iinput = iinput_nochat
|
1063 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1064 |
if not context:
|
1065 |
# get hidden context if have one
|
1066 |
context = get_context(chat_context, prompt_type)
|
1067 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1068 |
prompter = Prompter(prompt_type, prompt_dict, debug=debug, chat=chat, stream_output=stream_output)
|
1069 |
data_point = dict(context=context, instruction=instruction, input=iinput)
|
1070 |
prompt = prompter.generate_prompt(data_point)
|
@@ -1077,13 +1513,30 @@ def evaluate(
|
|
1077 |
db1 = dbs[langchain_mode]
|
1078 |
else:
|
1079 |
db1 = None
|
1080 |
-
|
|
|
|
|
|
|
|
|
1081 |
query = instruction if not iinput else "%s\n%s" % (instruction, iinput)
|
1082 |
outr = ""
|
1083 |
# use smaller cut_distanct for wiki_full since so many matches could be obtained, and often irrelevant unless close
|
1084 |
from gpt_langchain import run_qa_db
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1085 |
for r in run_qa_db(query=query,
|
1086 |
model_name=base_model, model=model, tokenizer=tokenizer,
|
|
|
1087 |
stream_output=stream_output,
|
1088 |
prompter=prompter,
|
1089 |
load_db_if_exists=load_db_if_exists,
|
@@ -1103,29 +1556,31 @@ def evaluate(
|
|
1103 |
db_type=db_type,
|
1104 |
top_k_docs=top_k_docs,
|
1105 |
|
1106 |
-
|
1107 |
-
do_sample=do_sample,
|
1108 |
-
temperature=temperature,
|
1109 |
-
repetition_penalty=repetition_penalty,
|
1110 |
-
top_k=top_k,
|
1111 |
-
top_p=top_p,
|
1112 |
-
num_beams=num_beams,
|
1113 |
-
min_new_tokens=min_new_tokens,
|
1114 |
-
max_new_tokens=max_new_tokens,
|
1115 |
-
early_stopping=early_stopping,
|
1116 |
-
max_time=max_time,
|
1117 |
-
num_return_sequences=num_return_sequences,
|
1118 |
|
1119 |
prompt_type=prompt_type,
|
1120 |
prompt_dict=prompt_dict,
|
1121 |
n_jobs=n_jobs,
|
1122 |
verbose=verbose,
|
1123 |
cli=cli,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1124 |
):
|
1125 |
outr, extra = r # doesn't accumulate, new answer every yield, so only save that full answer
|
1126 |
yield dict(response=outr, sources=extra)
|
1127 |
if save_dir:
|
1128 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
1129 |
if verbose:
|
1130 |
print(
|
1131 |
'Post-Generate Langchain: %s decoded_output: %s' % (str(datetime.now()), len(outr) if outr else -1),
|
@@ -1138,6 +1593,266 @@ def evaluate(
|
|
1138 |
clear_torch_cache()
|
1139 |
return
|
1140 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1141 |
if isinstance(tokenizer, str):
|
1142 |
# pipeline
|
1143 |
if tokenizer == "summarization":
|
@@ -1151,37 +1866,37 @@ def evaluate(
|
|
1151 |
assert src_lang is not None
|
1152 |
tokenizer.src_lang = languages_covered()[src_lang]
|
1153 |
|
1154 |
-
|
1155 |
-
|
1156 |
-
|
1157 |
-
|
1158 |
-
_, _, max_length_tokenize, max_prompt_length = get_cutoffs(memory_restriction_level,
|
1159 |
-
model_max_length=tokenizer.model_max_length)
|
1160 |
-
prompt = prompt[-max_prompt_length:]
|
1161 |
-
inputs = tokenizer(prompt,
|
1162 |
-
return_tensors="pt",
|
1163 |
-
truncation=True,
|
1164 |
-
max_length=max_length_tokenize)
|
1165 |
-
if inputs['input_ids'].shape[1] >= max_length_tokenize - 1:
|
1166 |
-
print("Cutting off input: %s %s" % (inputs['input_ids'].shape[1], max_length_tokenize), flush=True)
|
1167 |
if debug and len(inputs["input_ids"]) > 0:
|
1168 |
print('input_ids length', len(inputs["input_ids"][0]), flush=True)
|
1169 |
input_ids = inputs["input_ids"].to(device)
|
1170 |
# CRITICAL LIMIT else will fail
|
1171 |
max_max_tokens = tokenizer.model_max_length
|
1172 |
-
max_input_tokens = max_max_tokens -
|
|
|
1173 |
input_ids = input_ids[:, -max_input_tokens:]
|
1174 |
-
|
1175 |
-
|
1176 |
-
|
1177 |
-
|
1178 |
-
|
1179 |
-
|
1180 |
-
|
1181 |
-
|
1182 |
-
|
1183 |
-
|
1184 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1185 |
|
1186 |
gen_kwargs = dict(input_ids=input_ids,
|
1187 |
generation_config=generation_config,
|
@@ -1200,7 +1915,10 @@ def evaluate(
|
|
1200 |
tgt_lang = languages_covered()[tgt_lang]
|
1201 |
gen_kwargs.update(dict(forced_bos_token_id=tokenizer.lang_code_to_id[tgt_lang]))
|
1202 |
else:
|
1203 |
-
|
|
|
|
|
|
|
1204 |
|
1205 |
decoder_kwargs = dict(skip_special_tokens=True,
|
1206 |
clean_up_tokenization_spaces=True)
|
@@ -1216,7 +1934,8 @@ def evaluate(
|
|
1216 |
)
|
1217 |
|
1218 |
with torch.no_grad():
|
1219 |
-
|
|
|
1220 |
with context_class_cast(device):
|
1221 |
# protection for gradio not keeping track of closed users,
|
1222 |
# else hit bitsandbytes lack of thread safety:
|
@@ -1299,7 +2018,11 @@ def evaluate(
|
|
1299 |
if outputs and len(outputs) >= 1:
|
1300 |
decoded_output = prompt + outputs[0]
|
1301 |
if save_dir and decoded_output:
|
1302 |
-
|
|
|
|
|
|
|
|
|
1303 |
if verbose:
|
1304 |
print('Post-Generate: %s decoded_output: %s' % (
|
1305 |
str(datetime.now()), len(decoded_output) if decoded_output else -1), flush=True)
|
@@ -1318,6 +2041,7 @@ def get_cutoffs(memory_restriction_level, for_context=False, model_max_length=20
|
|
1318 |
if memory_restriction_level > 0:
|
1319 |
max_length_tokenize = 768 - 256 if memory_restriction_level <= 2 else 512 - 256
|
1320 |
else:
|
|
|
1321 |
max_length_tokenize = model_max_length - 256
|
1322 |
cutoff_len = max_length_tokenize * 4 # if reaches limit, then can't generate new tokens
|
1323 |
output_smallest = 30 * 4
|
@@ -1422,7 +2146,8 @@ def get_generate_params(model_lower, chat,
|
|
1422 |
if model_lower:
|
1423 |
print(f"Using Model {model_lower}", flush=True)
|
1424 |
else:
|
1425 |
-
|
|
|
1426 |
|
1427 |
min_new_tokens = min_new_tokens if min_new_tokens is not None else 0
|
1428 |
early_stopping = early_stopping if early_stopping is not None else False
|
@@ -1478,12 +2203,14 @@ Philipp: ok, ok you can find everything here. https://huggingface.co/blog/the-pa
|
|
1478 |
use_defaults = True
|
1479 |
else:
|
1480 |
if chat:
|
1481 |
-
placeholder_instruction = "
|
1482 |
else:
|
1483 |
placeholder_instruction = "Give detailed answer for whether Einstein or Newton is smarter."
|
1484 |
placeholder_input = ""
|
1485 |
-
if model_lower:
|
1486 |
-
|
|
|
|
|
1487 |
prompt_type = prompt_type or 'plain'
|
1488 |
else:
|
1489 |
prompt_type = ''
|
@@ -1591,7 +2318,7 @@ y = np.random.randint(0, 1, 100)
|
|
1591 |
|
1592 |
# get prompt_dict from prompt_type, so user can see in UI etc., or for custom do nothing except check format
|
1593 |
prompt_dict, error0 = get_prompt(prompt_type, prompt_dict,
|
1594 |
-
chat=False, context='', reduced=False, return_dict=True)
|
1595 |
if error0:
|
1596 |
raise RuntimeError("Prompt wrong: %s" % error0)
|
1597 |
|
@@ -1644,7 +2371,8 @@ def score_qa(smodel, stokenizer, max_length_tokenize, question, answer, cutoff_l
|
|
1644 |
if 'Expected all tensors to be on the same device' in str(e) or \
|
1645 |
'expected scalar type Half but found Float' in str(e) or \
|
1646 |
'probability tensor contains either' in str(e) or \
|
1647 |
-
'cublasLt ran into an error!' in str(e)
|
|
|
1648 |
print("GPU Error: question: %s answer: %s exception: %s" % (question, answer, str(e)),
|
1649 |
flush=True)
|
1650 |
traceback.print_exc()
|
@@ -1677,25 +2405,101 @@ def check_locals(**kwargs):
|
|
1677 |
assert k in kwargs, "Missing %s" % k
|
1678 |
|
1679 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1680 |
def get_max_max_new_tokens(model_state, **kwargs):
|
1681 |
-
if
|
1682 |
-
max_max_new_tokens =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1683 |
elif kwargs['memory_restriction_level'] == 1:
|
1684 |
-
|
1685 |
elif kwargs['memory_restriction_level'] == 2:
|
1686 |
-
|
1687 |
elif kwargs['memory_restriction_level'] >= 3:
|
1688 |
-
|
1689 |
else:
|
1690 |
-
|
1691 |
-
|
1692 |
-
else:
|
1693 |
-
# FIXME: Need to update after new model loaded, so user can control with slider
|
1694 |
-
max_max_new_tokens = 2048
|
1695 |
-
return max_max_new_tokens
|
1696 |
|
1697 |
|
1698 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1699 |
"""
|
1700 |
Examples:
|
1701 |
|
@@ -1726,3 +2530,7 @@ if __name__ == "__main__":
|
|
1726 |
python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6_9b
|
1727 |
"""
|
1728 |
fire.Fire(main)
|
|
|
|
|
|
|
|
|
|
1 |
import ast
|
2 |
+
import copy
|
3 |
import functools
|
4 |
import glob
|
5 |
import inspect
|
6 |
import queue
|
|
|
7 |
import sys
|
8 |
import os
|
9 |
import time
|
10 |
import traceback
|
11 |
+
import types
|
12 |
import typing
|
13 |
import warnings
|
14 |
from datetime import datetime
|
15 |
import filelock
|
16 |
+
import requests
|
17 |
import psutil
|
18 |
+
from requests import ConnectTimeout, JSONDecodeError
|
19 |
+
from urllib3.exceptions import ConnectTimeoutError, MaxRetryError, ConnectionError
|
20 |
+
from requests.exceptions import ConnectionError as ConnectionError2
|
21 |
+
from requests.exceptions import ReadTimeout as ReadTimeout2
|
22 |
+
|
23 |
+
if os.path.dirname(os.path.abspath(__file__)) not in sys.path:
|
24 |
+
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
25 |
|
26 |
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
|
27 |
os.environ['BITSANDBYTES_NOWELCOME'] = '1'
|
28 |
warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
|
29 |
|
30 |
+
from enums import DocumentChoices, LangChainMode, no_lora_str, model_token_mapping, no_model_str, source_prefix, \
|
31 |
+
source_postfix
|
32 |
from loaders import get_loaders
|
33 |
from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, wrapped_partial, EThread, get_githash, \
|
34 |
+
import_matplotlib, get_device, makedirs, get_kwargs, start_faulthandler, get_hf_server, FakeTokenizer, remove
|
35 |
|
36 |
start_faulthandler()
|
37 |
import_matplotlib()
|
|
|
44 |
import fire
|
45 |
import torch
|
46 |
from transformers import GenerationConfig, AutoModel, TextIteratorStreamer
|
|
|
47 |
|
48 |
+
from prompter import Prompter, inv_prompt_type_to_model_lower, non_hf_types, PromptType, get_prompt, generate_prompt
|
49 |
from stopping import get_stopping
|
50 |
|
51 |
eval_extra_columns = ['prompt', 'response', 'score']
|
|
|
65 |
lora_weights: str = "",
|
66 |
gpu_id: int = 0,
|
67 |
compile_model: bool = True,
|
68 |
+
use_cache: bool = None,
|
69 |
+
inference_server: str = "",
|
70 |
prompt_type: Union[int, str] = None,
|
71 |
prompt_dict: typing.Dict = None,
|
72 |
+
|
73 |
+
model_lock: typing.List[typing.Dict[str, str]] = None,
|
74 |
+
model_lock_columns: int = None,
|
75 |
+
fail_if_cannot_connect: bool = False,
|
76 |
+
|
77 |
# input to generation
|
78 |
temperature: float = None,
|
79 |
top_p: float = None,
|
|
|
103 |
cli: bool = False,
|
104 |
cli_loop: bool = True,
|
105 |
gradio: bool = True,
|
|
|
106 |
gradio_offline_level: int = 0,
|
107 |
chat: bool = True,
|
108 |
chat_context: bool = False,
|
109 |
stream_output: bool = True,
|
110 |
show_examples: bool = None,
|
111 |
verbose: bool = False,
|
112 |
+
h2ocolors: bool = True,
|
113 |
height: int = 600,
|
114 |
show_lora: bool = True,
|
115 |
login_mode_if_model0: bool = False,
|
|
|
118 |
api_open: bool = False,
|
119 |
allow_api: bool = True,
|
120 |
input_lines: int = 1,
|
121 |
+
gradio_size: str = None,
|
122 |
auth: typing.List[typing.Tuple[str, str]] = None,
|
123 |
+
max_max_time=None,
|
124 |
+
max_max_new_tokens=None,
|
125 |
|
126 |
+
sanitize_user_prompt: bool = False,
|
127 |
+
sanitize_bot_response: bool = False,
|
128 |
|
129 |
extra_model_options: typing.List[str] = [],
|
130 |
extra_lora_options: typing.List[str] = [],
|
131 |
+
extra_server_options: typing.List[str] = [],
|
132 |
|
133 |
score_model: str = 'OpenAssistant/reward-model-deberta-v3-large-v2',
|
|
|
134 |
|
135 |
eval_filename: str = None,
|
136 |
eval_prompts_only_num: int = 0,
|
|
|
138 |
eval_as_output: bool = False,
|
139 |
|
140 |
langchain_mode: str = 'Disabled',
|
141 |
+
force_langchain_evaluate: bool = False,
|
142 |
visible_langchain_modes: list = ['UserData', 'MyData'],
|
143 |
document_choice: list = [DocumentChoices.All_Relevant.name],
|
144 |
user_path: str = None,
|
|
|
156 |
enable_sources_list: bool = True,
|
157 |
chunk: bool = True,
|
158 |
chunk_size: int = 512,
|
159 |
+
top_k_docs: int = None,
|
160 |
+
reverse_docs: bool = True,
|
161 |
+
auto_reduce_chunks: bool = True,
|
162 |
+
max_chunks: int = 100,
|
163 |
n_jobs: int = -1,
|
164 |
enable_captions: bool = True,
|
165 |
captions_model: str = "Salesforce/blip-image-captioning-base",
|
|
|
178 |
:param lora_weights: LORA weights path/HF link
|
179 |
:param gpu_id: if infer_devices, then use gpu_id for cuda device ID, or auto mode if gpu_id != -1
|
180 |
:param compile_model Whether to compile the model
|
181 |
+
:param use_cache: Whether to use caching in model (some models fail when multiple threads use)
|
182 |
+
:param inference_server: Consume base_model as type of model at this address
|
183 |
+
Address can be text-generation-server hosting that base_model
|
184 |
+
e.g. python generate.py --inference_server="http://192.168.1.46:6112" --base_model=h2oai/h2ogpt-oasst1-512-12b
|
185 |
+
Or Address can be "openai_chat" or "openai" for OpenAI API
|
186 |
+
e.g. python generate.py --inference_server="openai_chat" --base_model=gpt-3.5-turbo
|
187 |
+
e.g. python generate.py --inference_server="openai" --base_model=text-davinci-003
|
188 |
:param prompt_type: type of prompt, usually matched to fine-tuned model or plain for foundational model
|
189 |
:param prompt_dict: If prompt_type=custom, then expects (some) items returned by get_prompt(..., return_dict=True)
|
190 |
+
:param model_lock: Lock models to specific combinations, for ease of use and extending to many models
|
191 |
+
Only used if gradio = True
|
192 |
+
List of dicts, each dict has base_model, tokenizer_base_model, lora_weights, inference_server, prompt_type, and prompt_dict
|
193 |
+
If all models have same prompt_type, and prompt_dict, can still specify that once in CLI outside model_lock as default for dict
|
194 |
+
Can specify model_lock instead of those items on CLI
|
195 |
+
As with CLI itself, base_model can infer prompt_type and prompt_dict if in prompter.py.
|
196 |
+
Also, tokenizer_base_model and lora_weights are optional.
|
197 |
+
Also, inference_server is optional if loading model from local system.
|
198 |
+
All models provided will automatically appear in compare model mode
|
199 |
+
Model loading-unloading and related choices will be disabled. Model/lora/server adding will be disabled
|
200 |
+
:param model_lock_columns: How many columns to show if locking models (and so showing all at once)
|
201 |
+
If None, then defaults to up to 3
|
202 |
+
if -1, then all goes into 1 row
|
203 |
+
Maximum value is 4 due to non-dynamic gradio rendering elements
|
204 |
+
:param fail_if_cannot_connect: if doing model locking (e.g. with many models), fail if True. Otherwise ignore.
|
205 |
+
Useful when many endpoints and want to just see what works, but still have to wait for timeout.
|
206 |
:param temperature: generation temperature
|
207 |
:param top_p: generation top_p
|
208 |
:param top_k: generation top_k
|
|
|
228 |
:param cli: whether to use CLI (non-gradio) interface.
|
229 |
:param cli_loop: whether to loop for CLI (False usually only for testing)
|
230 |
:param gradio: whether to enable gradio, or to enable benchmark mode
|
|
|
231 |
:param gradio_offline_level: > 0, then change fonts so full offline
|
232 |
== 1 means backend won't need internet for fonts, but front-end UI might if font not cached
|
233 |
== 2 means backend and frontend don't need internet to download any fonts.
|
234 |
Note: Some things always disabled include HF telemetry, gradio telemetry, chromadb posthog that involve uploading.
|
235 |
This option further disables google fonts for downloading, which is less intrusive than uploading,
|
236 |
but still required in air-gapped case. The fonts don't look as nice as google fonts, but ensure full offline behavior.
|
237 |
+
Also set --share=False to avoid sharing a gradio live link.
|
238 |
:param chat: whether to enable chat mode with chat history
|
239 |
:param chat_context: whether to use extra helpful context if human_bot
|
240 |
:param stream_output: whether to stream output from generate
|
|
|
249 |
:param api_open: If False, don't let API calls skip gradio queue
|
250 |
:param allow_api: whether to allow API calls at all to gradio server
|
251 |
:param input_lines: how many input lines to show for chat box (>1 forces shift-enter for submit, else enter is submit)
|
252 |
+
:param gradio_size: Overall size of text and spaces: "xsmall", "small", "medium", "large".
|
253 |
+
Small useful for many chatbots in model_lock mode
|
254 |
:param auth: gradio auth for launcher in form [(user1, pass1), (user2, pass2), ...]
|
255 |
e.g. --auth=[('jon','password')] with no spaces
|
256 |
+
:param max_max_time: Maximum max_time for gradio slider
|
257 |
+
:param max_max_new_tokens: Maximum max_new_tokens for gradio slider
|
258 |
+
:param sanitize_user_prompt: whether to remove profanity from user input (slows down input processing)
|
259 |
+
:param sanitize_bot_response: whether to remove profanity and repeat lines from bot output (about 2x slower generation for long streaming cases due to better_profanity being slow)
|
260 |
:param extra_model_options: extra models to show in list in gradio
|
261 |
:param extra_lora_options: extra LORA to show in list in gradio
|
262 |
+
:param extra_server_options: extra servers to show in list in gradio
|
263 |
:param score_model: which model to score responses (None means no scoring)
|
|
|
264 |
:param eval_filename: json file to use for evaluation, if None is sharegpt
|
265 |
:param eval_prompts_only_num: for no gradio benchmark, if using eval_filename prompts for eval instead of examples
|
266 |
:param eval_prompts_only_seed: for no gradio benchmark, seed for eval_filename sampling
|
267 |
:param eval_as_output: for no gradio benchmark, whether to test eval_filename output itself
|
268 |
:param langchain_mode: Data source to include. Choose "UserData" to only consume files from make_db.py.
|
269 |
WARNING: wiki_full requires extra data processing via read_wiki_full.py and requires really good workstation to generate db, unless already present.
|
270 |
+
:param force_langchain_evaluate: Whether to force langchain LLM use even if not doing langchain, mostly for testing.
|
271 |
:param user_path: user path to glob from to generate db for vector search, for 'UserData' langchain mode.
|
272 |
If already have db, any new/changed files are added automatically if path set, does not have to be same path used for prior db sources
|
273 |
:param detect_user_path_changes_every_query: whether to detect if any files changed or added every similarity search (by file hashes).
|
|
|
297 |
:param chunk: Whether to chunk data (True unless know data is already optimally chunked)
|
298 |
:param chunk_size: Size of chunks, with typically top-4 passed to LLM, so neesd to be in context length
|
299 |
:param top_k_docs: number of chunks to give LLM
|
300 |
+
:param reverse_docs: whether to reverse docs order so most relevant is closest to question.
|
301 |
+
Best choice for sufficiently smart model, and truncation occurs for oldest context, so best then too.
|
302 |
+
But smaller 6_9 models fail to use newest context and can get stuck on old information.
|
303 |
+
:param auto_reduce_chunks: Whether to automatically reduce top_k_docs to fit context given prompt
|
304 |
+
:param max_chunks: If top_k_docs=-1, maximum number of chunks to allow
|
305 |
:param n_jobs: Number of processors to use when consuming documents (-1 = all, is default)
|
306 |
:param enable_captions: Whether to support captions using BLIP for image files as documents, then preloads that model
|
307 |
:param captions_model: Which model to use for captions.
|
308 |
+
captions_model: str = "Salesforce/blip-image-captioning-base", # continue capable
|
309 |
captions_model: str = "Salesforce/blip2-flan-t5-xl", # question/answer capable, 16GB state
|
310 |
+
captions_model: str = "Salesforce/blip2-flan-t5-xxl", # question/answer capable, 60GB state
|
311 |
Note: opt-based blip2 are not permissive license due to opt and Meta license restrictions
|
312 |
:param pre_load_caption_model: Whether to preload caption model, or load after forking parallel doc loader
|
313 |
parallel loading disabled if preload and have images, to prevent deadlocking on cuda context
|
|
|
316 |
:param enable_ocr: Whether to support OCR on images
|
317 |
:return:
|
318 |
"""
|
319 |
+
if base_model is None:
|
320 |
+
base_model = ''
|
321 |
+
if tokenizer_base_model is None:
|
322 |
+
tokenizer_base_model = ''
|
323 |
+
if lora_weights is None:
|
324 |
+
lora_weights = ''
|
325 |
+
if inference_server is None:
|
326 |
+
inference_server = ''
|
327 |
+
|
328 |
+
# listen to env if set
|
329 |
+
model_lock = os.getenv('model_lock', str(model_lock))
|
330 |
+
model_lock = ast.literal_eval(model_lock)
|
331 |
+
|
332 |
+
if model_lock:
|
333 |
+
assert gradio, "model_lock only supported for gradio=True"
|
334 |
+
if len(model_lock) > 1:
|
335 |
+
assert chat, "model_lock only works for multiple models for chat=True"
|
336 |
+
assert not cli, "model_lock only supported for cli=False"
|
337 |
+
assert not (not cli and not gradio), "model_lock only supported for eval (cli=gradio=False)"
|
338 |
+
assert not base_model, "Don't specify model_lock and base_model"
|
339 |
+
assert not tokenizer_base_model, "Don't specify model_lock and tokenizer_base_model"
|
340 |
+
assert not lora_weights, "Don't specify model_lock and lora_weights"
|
341 |
+
assert not inference_server, "Don't specify model_lock and inference_server"
|
342 |
+
# assert not prompt_type, "Don't specify model_lock and prompt_type"
|
343 |
+
# assert not prompt_dict, "Don't specify model_lock and prompt_dict"
|
344 |
+
|
345 |
is_hf = bool(int(os.getenv("HUGGINGFACE_SPACES", '0')))
|
346 |
is_gpth2oai = bool(int(os.getenv("GPT_H2O_AI", '0')))
|
347 |
is_public = is_hf or is_gpth2oai # multi-user case with fixed model and disclaimer
|
|
|
356 |
|
357 |
# allow set token directly
|
358 |
use_auth_token = os.environ.get("HUGGINGFACE_API_TOKEN", use_auth_token)
|
359 |
+
allow_upload_to_user_data = bool(
|
360 |
+
int(os.environ.get("allow_upload_to_user_data", str(int(allow_upload_to_user_data)))))
|
361 |
allow_upload_to_my_data = bool(int(os.environ.get("allow_upload_to_my_data", str(int(allow_upload_to_my_data)))))
|
362 |
height = int(os.environ.get("HEIGHT", height))
|
363 |
h2ocolors = bool(int(os.getenv('h2ocolors', h2ocolors)))
|
|
|
370 |
if langchain_mode not in visible_langchain_modes and langchain_mode in langchain_modes:
|
371 |
visible_langchain_modes += [langchain_mode]
|
372 |
|
373 |
+
# if specifically chose not to show My or User Data, disable upload, so gradio elements are simpler
|
374 |
+
if LangChainMode.MY_DATA.value not in visible_langchain_modes:
|
375 |
+
allow_upload_to_my_data = False
|
376 |
+
if LangChainMode.USER_DATA.value not in visible_langchain_modes:
|
377 |
+
allow_upload_to_user_data = False
|
378 |
+
|
379 |
if is_public:
|
380 |
allow_upload_to_user_data = False
|
381 |
input_lines = 1 # ensure set, for ease of use
|
|
|
384 |
top_k = 70 if top_k is None else top_k
|
385 |
if is_hf:
|
386 |
do_sample = True if do_sample is None else do_sample
|
387 |
+
top_k_docs = 3 if top_k_docs is None else top_k_docs
|
388 |
else:
|
389 |
# by default don't sample, too chatty
|
390 |
do_sample = False if do_sample is None else do_sample
|
391 |
+
top_k_docs = 4 if top_k_docs is None else top_k_docs
|
392 |
|
393 |
if memory_restriction_level == 2:
|
394 |
+
if not base_model and not inference_server:
|
395 |
base_model = 'h2oai/h2ogpt-oasst1-512-12b'
|
396 |
# don't set load_8bit if passed base_model, doesn't always work so can't just override
|
397 |
load_8bit = True
|
398 |
load_4bit = False # FIXME - consider using 4-bit instead of 8-bit
|
399 |
+
elif not inference_server:
|
400 |
+
top_k_docs = 10 if top_k_docs is None else top_k_docs
|
401 |
if memory_restriction_level >= 2:
|
402 |
load_8bit = True
|
403 |
load_4bit = False # FIXME - consider using 4-bit instead of 8-bit
|
404 |
if hf_embedding_model is None:
|
405 |
hf_embedding_model = "sentence-transformers/all-MiniLM-L6-v2"
|
406 |
+
top_k_docs = 3 if top_k_docs is None else top_k_docs
|
407 |
+
if top_k_docs is None:
|
408 |
+
top_k_docs = 3
|
409 |
+
if is_public:
|
410 |
+
if not max_time:
|
411 |
+
max_time = 60 * 2
|
412 |
+
if not max_max_time:
|
413 |
+
max_max_time = max_time
|
414 |
+
if not max_new_tokens:
|
415 |
+
max_new_tokens = 256
|
416 |
+
if not max_max_new_tokens:
|
417 |
+
max_max_new_tokens = 256
|
418 |
+
else:
|
419 |
+
if not max_max_time:
|
420 |
+
max_max_time = 60 * 20
|
421 |
+
if not max_max_new_tokens:
|
422 |
+
max_max_new_tokens = 512
|
423 |
if is_hf:
|
424 |
# must override share if in spaces
|
425 |
share = False
|
426 |
+
if not max_time:
|
427 |
+
max_time = 60 * 1
|
428 |
+
if not max_max_time:
|
429 |
+
max_max_time = max_time
|
430 |
+
# HF accounted for later in get_max_max_new_tokens()
|
431 |
save_dir = os.getenv('SAVE_DIR', save_dir)
|
432 |
score_model = os.getenv('SCORE_MODEL', score_model)
|
433 |
if score_model == 'None' or score_model is None:
|
|
|
446 |
torch.backends.cudnn.benchmark = True
|
447 |
torch.backends.cudnn.enabled = False
|
448 |
torch.set_default_dtype(torch.float32)
|
449 |
+
if psutil.virtual_memory().available < 94 * 1024 ** 3 and not inference_server:
|
450 |
# 12B uses ~94GB
|
451 |
# 6.9B uses ~47GB
|
452 |
base_model = 'h2oai/h2ogpt-oig-oasst1-512-6_9b' if not base_model else base_model
|
|
|
471 |
|
472 |
if offload_folder:
|
473 |
makedirs(offload_folder)
|
474 |
+
if user_path:
|
475 |
+
makedirs(user_path)
|
476 |
|
477 |
placeholder_instruction, placeholder_input, \
|
478 |
stream_output, show_examples, \
|
|
|
497 |
verbose,
|
498 |
)
|
499 |
|
500 |
+
git_hash = get_githash()
|
501 |
locals_dict = locals()
|
502 |
locals_print = '\n'.join(['%s: %s' % (k, v) for k, v in locals_dict.items()])
|
503 |
if verbose:
|
504 |
print(f"Generating model with params:\n{locals_print}", flush=True)
|
505 |
+
print("Command: %s\nHash: %s" % (str(' '.join(sys.argv)), git_hash), flush=True)
|
506 |
|
507 |
if langchain_mode != "Disabled":
|
508 |
# SECOND PLACE where LangChain referenced, but all imports are kept local so not required
|
|
|
516 |
for gpath1 in glob.glob(os.path.join(scratch_base_dir, 'db_dir_%s*' % langchain_mode1)):
|
517 |
if os.path.isdir(gpath1):
|
518 |
print("Removing old MyData: %s" % gpath1, flush=True)
|
519 |
+
remove(gpath1)
|
520 |
continue
|
521 |
if langchain_mode1 in ['All']:
|
522 |
# FIXME: All should be avoided until scans over each db, shouldn't be separate db
|
|
|
542 |
assert 'gpt_langchain' not in sys.modules, "Dev bug, import of langchain when should not have"
|
543 |
assert 'langchain' not in sys.modules, "Dev bug, import of langchain when should not have"
|
544 |
|
545 |
+
model_state_none = dict(model=None, tokenizer=None, device=None,
|
546 |
+
base_model=None, tokenizer_base_model=None, lora_weights=None,
|
547 |
+
inference_server=None, prompt_type=None, prompt_dict=None)
|
548 |
+
|
549 |
if cli:
|
550 |
from cli import run_cli
|
551 |
return run_cli(**get_kwargs(run_cli, exclude_names=['model_state0'], **locals()))
|
|
|
557 |
from gradio_runner import go_gradio
|
558 |
|
559 |
# get default model
|
560 |
+
model_states = []
|
561 |
+
model_list = [dict(base_model=base_model, tokenizer_base_model=tokenizer_base_model, lora_weights=lora_weights,
|
562 |
+
inference_server=inference_server, prompt_type=prompt_type, prompt_dict=prompt_dict)]
|
563 |
+
model_list0 = copy.deepcopy(model_list) # just strings, safe to deepcopy
|
564 |
+
model_state0 = model_state_none.copy()
|
565 |
+
assert len(model_state_none) == len(model_state0)
|
566 |
+
if model_lock:
|
567 |
+
model_list = model_lock
|
568 |
+
for model_dict in reversed(model_list):
|
569 |
+
# do reverse, so first is default base_model etc., so some logic works in go_gradio() more easily
|
570 |
+
# handles defaults user didn't have to pass
|
571 |
+
model_dict['base_model'] = base_model = model_dict.get('base_model', '')
|
572 |
+
model_dict['tokenizer_base_model'] = tokenizer_base_model = model_dict.get('tokenizer_base_model', '')
|
573 |
+
model_dict['lora_weights'] = lora_weights = model_dict.get('lora_weights', '')
|
574 |
+
model_dict['inference_server'] = inference_server = model_dict.get('inference_server', '')
|
575 |
+
prompt_type = model_dict.get('prompt_type', model_list0[0]['prompt_type']) # don't use mutated value
|
576 |
+
# try to infer, ignore empty initial state leading to get_generate_params -> 'plain'
|
577 |
+
if model_dict.get('prompt_type') is None:
|
578 |
+
model_lower = base_model.lower()
|
579 |
+
if model_lower in inv_prompt_type_to_model_lower:
|
580 |
+
prompt_type = inv_prompt_type_to_model_lower[model_lower]
|
581 |
+
prompt_dict, error0 = get_prompt(prompt_type, '',
|
582 |
+
chat=False, context='', reduced=False, making_context=False,
|
583 |
+
return_dict=True)
|
584 |
+
model_dict['prompt_type'] = prompt_type
|
585 |
+
model_dict['prompt_dict'] = prompt_dict = model_dict.get('prompt_dict', prompt_dict)
|
586 |
+
all_kwargs = locals().copy()
|
587 |
+
if base_model and not login_mode_if_model0:
|
588 |
+
model0, tokenizer0, device = get_model(reward_type=False,
|
589 |
+
**get_kwargs(get_model, exclude_names=['reward_type'],
|
590 |
+
**all_kwargs))
|
591 |
+
else:
|
592 |
+
# if empty model, then don't load anything, just get gradio up
|
593 |
+
model0, tokenizer0, device = None, None, None
|
594 |
+
if model0 is None:
|
595 |
+
if fail_if_cannot_connect:
|
596 |
+
raise RuntimeError("Could not connect, see logs")
|
597 |
+
# skip
|
598 |
+
if isinstance(model_lock, list):
|
599 |
+
model_lock.remove(model_dict)
|
600 |
+
continue
|
601 |
+
model_state_trial = dict(model=model0, tokenizer=tokenizer0, device=device)
|
602 |
+
model_state_trial.update(model_dict)
|
603 |
+
assert len(model_state_none) == len(model_state_trial)
|
604 |
+
print("Model %s" % model_dict, flush=True)
|
605 |
+
if model_lock:
|
606 |
+
# last in iteration will be first
|
607 |
+
model_states.insert(0, model_state_trial)
|
608 |
+
# fill model_state0 so go_gradio() easier, manage model_states separately
|
609 |
+
model_state0 = model_state_trial.copy()
|
610 |
+
else:
|
611 |
+
model_state0 = model_state_trial.copy()
|
612 |
+
assert len(model_state_none) == len(model_state0)
|
613 |
|
614 |
# get score model
|
615 |
+
all_kwargs = locals().copy()
|
616 |
smodel, stokenizer, sdevice = get_score_model(reward_type=True,
|
617 |
**get_kwargs(get_score_model, exclude_names=['reward_type'],
|
618 |
**all_kwargs))
|
619 |
+
score_model_state0 = dict(model=smodel, tokenizer=stokenizer, device=sdevice,
|
620 |
+
base_model=score_model, tokenizer_base_model='', lora_weights='',
|
621 |
+
inference_server='', prompt_type='', prompt_dict='')
|
622 |
|
623 |
if enable_captions:
|
624 |
if pre_load_caption_model:
|
|
|
633 |
go_gradio(**locals())
|
634 |
|
635 |
|
636 |
+
def get_config(base_model,
|
637 |
+
use_auth_token=False,
|
638 |
+
trust_remote_code=True,
|
639 |
+
offload_folder=None,
|
640 |
+
triton_attn=False,
|
641 |
+
long_sequence=True,
|
642 |
+
return_model=False,
|
643 |
+
raise_exception=False,
|
644 |
+
):
|
645 |
+
from accelerate import init_empty_weights
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
646 |
with init_empty_weights():
|
647 |
from transformers import AutoConfig
|
648 |
+
try:
|
649 |
+
config = AutoConfig.from_pretrained(base_model, use_auth_token=use_auth_token,
|
650 |
+
trust_remote_code=trust_remote_code,
|
651 |
+
offload_folder=offload_folder)
|
652 |
+
except OSError as e:
|
653 |
+
if raise_exception:
|
654 |
+
raise
|
655 |
+
if 'not a local folder and is not a valid model identifier listed on' in str(
|
656 |
+
e) or '404 Client Error' in str(e):
|
657 |
+
# e.g. llama, gpjt, etc.
|
658 |
+
# e.g. HF TGI but not model on HF or private etc.
|
659 |
+
# HF TGI server only should really require prompt_type, not HF model state
|
660 |
+
return None, None
|
661 |
+
else:
|
662 |
+
raise
|
663 |
if triton_attn and 'mpt-' in base_model.lower():
|
664 |
config.attn_config['attn_impl'] = 'triton'
|
665 |
if long_sequence:
|
|
|
667 |
config.update({"max_seq_len": 83968})
|
668 |
if 'mosaicml/mpt-7b-chat' in base_model.lower():
|
669 |
config.update({"max_seq_len": 4096})
|
670 |
+
if 'mpt-30b' in base_model.lower():
|
671 |
+
config.update({"max_seq_len": 2 * 8192})
|
672 |
+
if return_model and \
|
673 |
+
issubclass(config.__class__, tuple(AutoModel._model_mapping.keys())):
|
674 |
model = AutoModel.from_config(
|
675 |
config,
|
676 |
+
trust_remote_code=trust_remote_code,
|
677 |
)
|
678 |
else:
|
679 |
# can't infer
|
680 |
model = None
|
681 |
+
if 'falcon' in base_model.lower():
|
682 |
+
config.use_cache = False
|
683 |
+
|
684 |
+
return config, model
|
685 |
+
|
686 |
+
|
687 |
+
def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type,
|
688 |
+
config, model,
|
689 |
+
gpu_id=0,
|
690 |
+
):
|
691 |
+
"""
|
692 |
+
Ensure model gets on correct device
|
693 |
+
"""
|
694 |
|
695 |
if model is not None:
|
696 |
# NOTE: Can specify max_memory={0: max_mem, 1: max_mem}, to shard model
|
697 |
# NOTE: Some models require avoiding sharding some layers,
|
698 |
# then would pass no_split_module_classes and give list of those layers.
|
699 |
+
from accelerate import infer_auto_device_map
|
700 |
device_map = infer_auto_device_map(
|
701 |
model,
|
702 |
dtype=torch.float16 if load_half else torch.float32,
|
|
|
748 |
return model
|
749 |
|
750 |
|
751 |
+
def get_client_from_inference_server(inference_server, raise_connection_exception=False):
|
752 |
+
inference_server, headers = get_hf_server(inference_server)
|
753 |
+
# preload client since slow for gradio case especially
|
754 |
+
from gradio_utils.grclient import GradioClient
|
755 |
+
gr_client = None
|
756 |
+
hf_client = None
|
757 |
+
if headers is None:
|
758 |
+
try:
|
759 |
+
print("GR Client Begin: %s" % inference_server, flush=True)
|
760 |
+
# first do sanity check if alive, else gradio client takes too long by default
|
761 |
+
requests.get(inference_server, timeout=int(os.getenv('REQUEST_TIMEOUT', '30')))
|
762 |
+
gr_client = GradioClient(inference_server)
|
763 |
+
print("GR Client End: %s" % inference_server, flush=True)
|
764 |
+
except (OSError, ValueError) as e:
|
765 |
+
# Occurs when wrong endpoint and should have been HF client, so don't hard raise, just move to HF
|
766 |
+
gr_client = None
|
767 |
+
print("GR Client Failed %s: %s" % (inference_server, str(e)), flush=True)
|
768 |
+
except (ConnectTimeoutError, ConnectTimeout, MaxRetryError, ConnectionError, ConnectionError2,
|
769 |
+
JSONDecodeError, ReadTimeout2, KeyError) as e:
|
770 |
+
t, v, tb = sys.exc_info()
|
771 |
+
ex = ''.join(traceback.format_exception(t, v, tb))
|
772 |
+
print("GR Client Failed %s: %s" % (inference_server, str(ex)), flush=True)
|
773 |
+
if raise_connection_exception:
|
774 |
+
raise
|
775 |
+
|
776 |
+
if gr_client is None:
|
777 |
+
res = None
|
778 |
+
from text_generation import Client as HFClient
|
779 |
+
print("HF Client Begin: %s" % inference_server)
|
780 |
+
try:
|
781 |
+
hf_client = HFClient(inference_server, headers=headers, timeout=int(os.getenv('REQUEST_TIMEOUT', '30')))
|
782 |
+
# quick check valid TGI endpoint
|
783 |
+
res = hf_client.generate('What?', max_new_tokens=1)
|
784 |
+
hf_client = HFClient(inference_server, headers=headers, timeout=300)
|
785 |
+
except (ConnectTimeoutError, ConnectTimeout, MaxRetryError, ConnectionError, ConnectionError2,
|
786 |
+
JSONDecodeError, ReadTimeout2, KeyError) as e:
|
787 |
+
hf_client = None
|
788 |
+
t, v, tb = sys.exc_info()
|
789 |
+
ex = ''.join(traceback.format_exception(t, v, tb))
|
790 |
+
print("HF Client Failed %s: %s" % (inference_server, str(ex)))
|
791 |
+
if raise_connection_exception:
|
792 |
+
raise
|
793 |
+
print("HF Client End: %s %s" % (inference_server, res))
|
794 |
+
return inference_server, gr_client, hf_client
|
795 |
+
|
796 |
+
|
797 |
def get_model(
|
798 |
load_8bit: bool = False,
|
799 |
load_4bit: bool = False,
|
800 |
load_half: bool = True,
|
801 |
infer_devices: bool = True,
|
802 |
base_model: str = '',
|
803 |
+
inference_server: str = "",
|
804 |
tokenizer_base_model: str = '',
|
805 |
lora_weights: str = "",
|
806 |
gpu_id: int = 0,
|
|
|
824 |
For non-LORA case, False will spread shards across multiple GPUs, but this can lead to cuda:x cuda:y mismatches
|
825 |
So it is not the default
|
826 |
:param base_model: name/path of base model
|
827 |
+
:param inference_server: whether base_model is hosted locally ('') or via http (url)
|
828 |
:param tokenizer_base_model: name/path of tokenizer
|
829 |
:param lora_weights: name/path
|
830 |
:param gpu_id: which GPU (0..n_gpus-1) or allow all GPUs if relevant (-1)
|
|
|
840 |
"""
|
841 |
if verbose:
|
842 |
print("Get %s model" % base_model, flush=True)
|
843 |
+
|
844 |
+
triton_attn = False
|
845 |
+
long_sequence = True
|
846 |
+
config_kwargs = dict(use_auth_token=use_auth_token,
|
847 |
+
trust_remote_code=trust_remote_code,
|
848 |
+
offload_folder=offload_folder,
|
849 |
+
triton_attn=triton_attn,
|
850 |
+
long_sequence=long_sequence)
|
851 |
+
config, _ = get_config(base_model, **config_kwargs, raise_exception=False)
|
852 |
+
|
853 |
+
if base_model in non_hf_types:
|
854 |
+
assert config is None, "Expected config None for %s" % base_model
|
855 |
+
|
856 |
+
llama_type_from_config = 'llama' in str(config).lower()
|
857 |
+
llama_type_from_name = "llama" in base_model.lower()
|
858 |
+
llama_type = llama_type_from_config or llama_type_from_name
|
859 |
+
if "xgen" in base_model.lower():
|
860 |
+
llama_type = False
|
861 |
+
if llama_type:
|
862 |
+
if verbose:
|
863 |
+
print("Detected as llama type from"
|
864 |
+
" config (%s) or name (%s)" % (llama_type_from_config, llama_type_from_name), flush=True)
|
865 |
+
|
866 |
+
model_loader, tokenizer_loader = get_loaders(model_name=base_model, reward_type=reward_type, llama_type=llama_type)
|
867 |
+
|
868 |
+
tokenizer_kwargs = dict(local_files_only=local_files_only,
|
869 |
+
resume_download=resume_download,
|
870 |
+
use_auth_token=use_auth_token,
|
871 |
+
trust_remote_code=trust_remote_code,
|
872 |
+
offload_folder=offload_folder,
|
873 |
+
padding_side='left',
|
874 |
+
config=config,
|
875 |
+
)
|
876 |
+
if not tokenizer_base_model:
|
877 |
+
tokenizer_base_model = base_model
|
878 |
+
|
879 |
+
if config is not None and tokenizer_loader is not None and not isinstance(tokenizer_loader, str):
|
880 |
+
tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model, **tokenizer_kwargs)
|
881 |
+
# sets raw (no cushion) limit
|
882 |
+
set_model_max_len(config, tokenizer, verbose=False)
|
883 |
+
# if using fake tokenizer, not really accurate when lots of numbers, give a bit of buffer, else get:
|
884 |
+
# Generation Failed: Input validation error: `inputs` must have less than 2048 tokens. Given: 2233
|
885 |
+
tokenizer.model_max_length = tokenizer.model_max_length - 50
|
886 |
+
else:
|
887 |
+
tokenizer = FakeTokenizer()
|
888 |
+
|
889 |
+
if isinstance(inference_server, str) and inference_server.startswith("http"):
|
890 |
+
inference_server, gr_client, hf_client = get_client_from_inference_server(inference_server)
|
891 |
+
client = gr_client or hf_client
|
892 |
+
# Don't return None, None for model, tokenizer so triggers
|
893 |
+
return client, tokenizer, 'http'
|
894 |
+
if isinstance(inference_server, str) and inference_server.startswith('openai'):
|
895 |
+
assert os.getenv('OPENAI_API_KEY'), "Set environment for OPENAI_API_KEY"
|
896 |
+
# Don't return None, None for model, tokenizer so triggers
|
897 |
+
# include small token cushion
|
898 |
+
tokenizer = FakeTokenizer(model_max_length=model_token_mapping[base_model] - 50)
|
899 |
+
return inference_server, tokenizer, inference_server
|
900 |
+
assert not inference_server, "Malformed inference_server=%s" % inference_server
|
901 |
if base_model in non_hf_types:
|
902 |
from gpt4all_llm import get_model_tokenizer_gpt4all
|
903 |
model, tokenizer, device = get_model_tokenizer_gpt4all(base_model)
|
904 |
return model, tokenizer, device
|
905 |
|
906 |
+
# get local torch-HF model
|
907 |
+
return get_hf_model(load_8bit=load_8bit,
|
908 |
+
load_4bit=load_4bit,
|
909 |
+
load_half=load_half,
|
910 |
+
infer_devices=infer_devices,
|
911 |
+
base_model=base_model,
|
912 |
+
tokenizer_base_model=tokenizer_base_model,
|
913 |
+
lora_weights=lora_weights,
|
914 |
+
gpu_id=gpu_id,
|
915 |
+
|
916 |
+
reward_type=reward_type,
|
917 |
+
local_files_only=local_files_only,
|
918 |
+
resume_download=resume_download,
|
919 |
+
use_auth_token=use_auth_token,
|
920 |
+
trust_remote_code=trust_remote_code,
|
921 |
+
offload_folder=offload_folder,
|
922 |
+
compile_model=compile_model,
|
923 |
+
|
924 |
+
llama_type=llama_type,
|
925 |
+
config_kwargs=config_kwargs,
|
926 |
+
tokenizer_kwargs=tokenizer_kwargs,
|
927 |
+
|
928 |
+
verbose=verbose)
|
929 |
+
|
930 |
+
|
931 |
+
def get_hf_model(load_8bit: bool = False,
|
932 |
+
load_4bit: bool = False,
|
933 |
+
load_half: bool = True,
|
934 |
+
infer_devices: bool = True,
|
935 |
+
base_model: str = '',
|
936 |
+
tokenizer_base_model: str = '',
|
937 |
+
lora_weights: str = "",
|
938 |
+
gpu_id: int = 0,
|
939 |
+
|
940 |
+
reward_type: bool = None,
|
941 |
+
local_files_only: bool = False,
|
942 |
+
resume_download: bool = True,
|
943 |
+
use_auth_token: Union[str, bool] = False,
|
944 |
+
trust_remote_code: bool = True,
|
945 |
+
offload_folder: str = None,
|
946 |
+
compile_model: bool = True,
|
947 |
+
|
948 |
+
llama_type: bool = False,
|
949 |
+
config_kwargs=None,
|
950 |
+
tokenizer_kwargs=None,
|
951 |
+
|
952 |
+
verbose: bool = False,
|
953 |
+
):
|
954 |
+
assert config_kwargs is not None
|
955 |
+
assert tokenizer_kwargs is not None
|
956 |
+
|
957 |
if lora_weights is not None and lora_weights.strip():
|
958 |
if verbose:
|
959 |
print("Get %s lora weights" % lora_weights, flush=True)
|
|
|
968 |
"Please choose a base model with --base_model (CLI) or load one from Models Tab (gradio)"
|
969 |
)
|
970 |
|
971 |
+
model_loader, tokenizer_loader = get_loaders(model_name=base_model, reward_type=reward_type, llama_type=llama_type)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
972 |
|
973 |
+
config, _ = get_config(base_model, return_model=False, raise_exception=True, **config_kwargs)
|
|
|
|
|
974 |
|
975 |
if tokenizer_loader is not None and not isinstance(tokenizer_loader, str):
|
976 |
tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model,
|
977 |
+
**tokenizer_kwargs)
|
|
|
|
|
|
|
|
|
|
|
978 |
else:
|
979 |
tokenizer = tokenizer_loader
|
980 |
|
|
|
998 |
load_in_4bit=load_4bit,
|
999 |
device_map={"": 0} if (load_8bit or load_4bit) and device == 'cuda' else "auto",
|
1000 |
))
|
1001 |
+
if 'mpt-' in base_model.lower() and gpu_id is not None and gpu_id >= 0:
|
1002 |
model_kwargs.update(dict(device_map={"": gpu_id} if device == 'cuda' else "cpu"))
|
1003 |
|
1004 |
if 'OpenAssistant/reward-model'.lower() in base_model.lower():
|
|
|
1009 |
|
1010 |
if not lora_weights:
|
1011 |
with torch.device(device):
|
1012 |
+
|
1013 |
if infer_devices:
|
1014 |
+
config, model = get_config(base_model, return_model=True, raise_exception=True, **config_kwargs)
|
1015 |
model = get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type,
|
1016 |
+
config, model,
|
1017 |
gpu_id=gpu_id,
|
|
|
|
|
|
|
1018 |
)
|
1019 |
else:
|
1020 |
+
config, _ = get_config(base_model, **config_kwargs)
|
1021 |
if load_half and not (load_8bit or load_4bit):
|
1022 |
model = model_loader.from_pretrained(
|
1023 |
base_model,
|
1024 |
+
config=config,
|
1025 |
**model_kwargs).half()
|
1026 |
else:
|
1027 |
model = model_loader.from_pretrained(
|
1028 |
base_model,
|
1029 |
+
config=config,
|
1030 |
**model_kwargs)
|
1031 |
elif load_8bit or load_4bit:
|
1032 |
+
config, _ = get_config(base_model, **config_kwargs)
|
1033 |
model = model_loader.from_pretrained(
|
1034 |
base_model,
|
1035 |
+
config=config,
|
1036 |
**model_kwargs
|
1037 |
)
|
1038 |
from peft import PeftModel # loads cuda, so avoid in global scope
|
|
|
1049 |
)
|
1050 |
else:
|
1051 |
with torch.device(device):
|
1052 |
+
config, _ = get_config(base_model, raise_exception=True, **config_kwargs)
|
1053 |
model = model_loader.from_pretrained(
|
1054 |
base_model,
|
1055 |
+
config=config,
|
1056 |
**model_kwargs
|
1057 |
)
|
1058 |
from peft import PeftModel # loads cuda, so avoid in global scope
|
|
|
1086 |
if torch.__version__ >= "2" and sys.platform != "win32" and compile_model:
|
1087 |
model = torch.compile(model)
|
1088 |
|
1089 |
+
set_model_max_len(config, tokenizer, verbose=False, reward_type=reward_type)
|
1090 |
+
|
1091 |
+
return model, tokenizer, device
|
1092 |
+
|
1093 |
+
|
1094 |
+
def set_model_max_len(config, tokenizer, verbose=False, reward_type=False):
|
1095 |
+
if reward_type:
|
1096 |
+
# limit deberta, else uses too much memory and not worth response score
|
1097 |
+
tokenizer.model_max_length = 512
|
1098 |
if hasattr(config, 'max_seq_len') and isinstance(config.max_seq_len, int):
|
1099 |
tokenizer.model_max_length = config.max_seq_len
|
1100 |
elif hasattr(config, 'max_position_embeddings') and isinstance(config.max_position_embeddings, int):
|
|
|
1104 |
if verbose:
|
1105 |
print("Could not determine model_max_length, setting to 2048", flush=True)
|
1106 |
tokenizer.model_max_length = 2048
|
1107 |
+
# for bug in HF transformers
|
1108 |
+
if tokenizer.model_max_length > 100000000:
|
1109 |
+
tokenizer.model_max_length = 2048
|
1110 |
|
1111 |
|
1112 |
def pop_unused_model_kwargs(model_kwargs):
|
|
|
1128 |
load_half: bool = True,
|
1129 |
infer_devices: bool = True,
|
1130 |
base_model: str = '',
|
1131 |
+
inference_server: str = '',
|
1132 |
tokenizer_base_model: str = '',
|
1133 |
lora_weights: str = "",
|
1134 |
gpu_id: int = 0,
|
|
|
1150 |
base_model = score_model.strip()
|
1151 |
tokenizer_base_model = ''
|
1152 |
lora_weights = ''
|
1153 |
+
inference_server = ''
|
1154 |
llama_type = False
|
1155 |
compile_model = False
|
1156 |
smodel, stokenizer, sdevice = get_model(reward_type=True,
|
|
|
1217 |
debug=False,
|
1218 |
concurrency_count=None,
|
1219 |
save_dir=None,
|
1220 |
+
sanitize_bot_response=False,
|
1221 |
model_state0=None,
|
1222 |
memory_restriction_level=None,
|
1223 |
+
max_max_new_tokens=None,
|
1224 |
+
is_public=None,
|
1225 |
+
max_max_time=None,
|
1226 |
raise_generate_gpu_exceptions=None,
|
1227 |
chat_context=None,
|
1228 |
lora_weights=None,
|
|
|
1233 |
use_openai_embedding=None,
|
1234 |
use_openai_model=None,
|
1235 |
hf_embedding_model=None,
|
|
|
|
|
1236 |
db_type=None,
|
1237 |
n_jobs=None,
|
1238 |
first_para=None,
|
1239 |
text_limit=None,
|
1240 |
verbose=False,
|
1241 |
cli=False,
|
1242 |
+
reverse_docs=True,
|
1243 |
+
use_cache=None,
|
1244 |
+
auto_reduce_chunks=None,
|
1245 |
+
max_chunks=None,
|
1246 |
+
model_lock=None,
|
1247 |
+
force_langchain_evaluate=None,
|
1248 |
+
model_state_none=None,
|
1249 |
):
|
1250 |
if isinstance(user_kwargs, str):
|
1251 |
user_kwargs = ast.literal_eval(user_kwargs)
|
1252 |
# only used for submit_nochat_api
|
1253 |
user_kwargs['chat'] = False
|
1254 |
+
if 'stream_output' not in user_kwargs:
|
1255 |
+
user_kwargs['stream_output'] = False
|
1256 |
if 'langchain_mode' not in user_kwargs:
|
1257 |
# if user doesn't specify, then assume disabled, not use default
|
1258 |
user_kwargs['langchain_mode'] = 'Disabled'
|
|
|
1275 |
sanitize_bot_response=sanitize_bot_response,
|
1276 |
model_state0=model_state0,
|
1277 |
memory_restriction_level=memory_restriction_level,
|
1278 |
+
max_max_new_tokens=max_max_new_tokens,
|
1279 |
+
is_public=is_public,
|
1280 |
+
max_max_time=max_max_time,
|
1281 |
raise_generate_gpu_exceptions=raise_generate_gpu_exceptions,
|
1282 |
chat_context=chat_context,
|
1283 |
lora_weights=lora_weights,
|
|
|
1294 |
text_limit=text_limit,
|
1295 |
verbose=verbose,
|
1296 |
cli=cli,
|
1297 |
+
reverse_docs=reverse_docs,
|
1298 |
+
use_cache=use_cache,
|
1299 |
+
auto_reduce_chunks=auto_reduce_chunks,
|
1300 |
+
max_chunks=max_chunks,
|
1301 |
+
model_lock=model_lock,
|
1302 |
+
force_langchain_evaluate=force_langchain_evaluate,
|
1303 |
+
model_state_none=model_state_none,
|
1304 |
)
|
1305 |
try:
|
1306 |
for ret1 in ret:
|
|
|
1345 |
debug=False,
|
1346 |
concurrency_count=None,
|
1347 |
save_dir=None,
|
1348 |
+
sanitize_bot_response=False,
|
1349 |
model_state0=None,
|
1350 |
memory_restriction_level=None,
|
1351 |
+
max_max_new_tokens=None,
|
1352 |
+
is_public=None,
|
1353 |
+
max_max_time=None,
|
1354 |
raise_generate_gpu_exceptions=None,
|
1355 |
chat_context=None,
|
1356 |
lora_weights=None,
|
|
|
1367 |
text_limit=None,
|
1368 |
verbose=False,
|
1369 |
cli=False,
|
1370 |
+
reverse_docs=True,
|
1371 |
+
use_cache=None,
|
1372 |
+
auto_reduce_chunks=None,
|
1373 |
+
max_chunks=None,
|
1374 |
+
model_lock=None,
|
1375 |
+
force_langchain_evaluate=None,
|
1376 |
+
model_state_none=None,
|
1377 |
):
|
1378 |
# ensure passed these
|
1379 |
assert concurrency_count is not None
|
|
|
1394 |
locals_dict = locals().copy()
|
1395 |
locals_dict.pop('model_state', None)
|
1396 |
locals_dict.pop('model_state0', None)
|
1397 |
+
locals_dict.pop('model_states', None)
|
1398 |
print(locals_dict)
|
1399 |
|
1400 |
+
no_model_msg = "Please choose a base model with --base_model (CLI) or load in Models Tab (gradio).\n" \
|
1401 |
+
"Then start New Conversation"
|
1402 |
|
1403 |
+
if model_state is None:
|
1404 |
+
model_state = model_state_none.copy()
|
1405 |
if model_state0 is None:
|
1406 |
# e.g. for no gradio case, set dummy value, else should be set
|
1407 |
+
model_state0 = model_state_none.copy()
|
1408 |
+
|
1409 |
+
# model_state['model] is only 'model' if should use model_state0
|
1410 |
+
# model could also be None
|
1411 |
+
have_model_lock = model_lock is not None
|
1412 |
+
have_fresh_model = model_state['model'] not in [None, 'model', no_model_str]
|
1413 |
+
# for gradio UI control, expect model_state and model_state0 to match, so if have_model_lock=True, then should have_fresh_model=True
|
1414 |
+
# but gradio API control will only use nochat api etc. and won't use fresh model, so can't assert in general
|
1415 |
+
# if have_model_lock:
|
1416 |
+
# assert have_fresh_model, "Expected model_state and model_state0 to match if have_model_lock"
|
1417 |
+
have_cli_model = model_state0['model'] not in [None, 'model', no_model_str]
|
1418 |
+
|
1419 |
+
if have_fresh_model:
|
1420 |
+
# USE FRESH MODEL
|
1421 |
+
if not have_model_lock:
|
1422 |
+
# model_state0 is just one of model_state if model_lock, so don't nuke
|
1423 |
+
# try to free-up original model (i.e. list was passed as reference)
|
1424 |
+
if model_state0['model'] and hasattr(model_state0['model'], 'cpu'):
|
1425 |
+
model_state0['model'].cpu()
|
1426 |
+
model_state0['model'] = None
|
1427 |
+
# try to free-up original tokenizer (i.e. list was passed as reference)
|
1428 |
+
if model_state0['tokenizer']:
|
1429 |
+
model_state0['tokenizer'] = None
|
1430 |
+
clear_torch_cache()
|
1431 |
+
chosen_model_state = model_state
|
1432 |
+
elif have_cli_model:
|
1433 |
+
# USE MODEL SETUP AT CLI
|
1434 |
+
assert isinstance(model_state['model'], str) # expect no fresh model
|
1435 |
+
chosen_model_state = model_state0
|
1436 |
else:
|
1437 |
raise AssertionError(no_model_msg)
|
1438 |
+
# get variables
|
1439 |
+
model = chosen_model_state['model']
|
1440 |
+
tokenizer = chosen_model_state['tokenizer']
|
1441 |
+
device = chosen_model_state['device']
|
1442 |
+
base_model = chosen_model_state['base_model']
|
1443 |
+
tokenizer_base_model = chosen_model_state['tokenizer_base_model']
|
1444 |
+
lora_weights = chosen_model_state['lora_weights']
|
1445 |
+
inference_server = chosen_model_state['inference_server']
|
1446 |
+
# prefer use input from API over model state
|
1447 |
+
prompt_type = prompt_type or chosen_model_state['prompt_type']
|
1448 |
+
prompt_dict = prompt_dict or chosen_model_state['prompt_dict']
|
1449 |
|
1450 |
if base_model is None:
|
1451 |
raise AssertionError(no_model_msg)
|
|
|
1459 |
instruction = instruction_nochat
|
1460 |
iinput = iinput_nochat
|
1461 |
|
1462 |
+
# in some cases, like lean nochat API, don't want to force sending prompt_type, allow default choice
|
1463 |
+
model_lower = base_model.lower()
|
1464 |
+
if not prompt_type and model_lower in inv_prompt_type_to_model_lower:
|
1465 |
+
prompt_type = inv_prompt_type_to_model_lower[model_lower]
|
1466 |
+
if verbose:
|
1467 |
+
print("Auto-selecting prompt_type=%s for %s" % (prompt_type, model_lower), flush=True)
|
1468 |
+
assert prompt_type is not None, "prompt_type was None"
|
1469 |
+
|
1470 |
+
# Control generation hyperparameters
|
1471 |
+
# adjust for bad inputs, e.g. in case also come from API that doesn't get constrained by gradio sliders
|
1472 |
+
# below is for TGI server, not required for HF transformers
|
1473 |
+
# limits are chosen similar to gradio_runner.py sliders/numbers
|
1474 |
+
top_p = min(max(1e-3, top_p), 1.0 - 1e-3)
|
1475 |
+
top_k = min(max(1, int(top_k)), 100)
|
1476 |
+
temperature = min(max(0.01, temperature), 2.0)
|
1477 |
+
# FIXME: https://github.com/h2oai/h2ogpt/issues/106
|
1478 |
+
num_beams = 1 if stream_output else num_beams # See max_beams in gradio_runner
|
1479 |
+
max_max_new_tokens = get_max_max_new_tokens(chosen_model_state,
|
1480 |
+
memory_restriction_level=memory_restriction_level,
|
1481 |
+
max_new_tokens=max_new_tokens,
|
1482 |
+
max_max_new_tokens=max_max_new_tokens)
|
1483 |
+
model_max_length = get_model_max_length(chosen_model_state)
|
1484 |
+
max_new_tokens = min(max(1, int(max_new_tokens)), max_max_new_tokens)
|
1485 |
+
min_new_tokens = min(max(0, int(min_new_tokens)), max_new_tokens)
|
1486 |
+
max_time = min(max(0, max_time), max_max_time)
|
1487 |
+
repetition_penalty = min(max(0.01, repetition_penalty), 3.0)
|
1488 |
+
num_return_sequences = 1 if chat else min(max(1, int(num_return_sequences)), 10)
|
1489 |
+
min_top_k_docs, max_top_k_docs, label_top_k_docs = get_minmax_top_k_docs(is_public)
|
1490 |
+
top_k_docs = min(max(min_top_k_docs, int(top_k_docs)), max_top_k_docs)
|
1491 |
+
chunk_size = min(max(128, int(chunk_size)), 2048)
|
1492 |
if not context:
|
1493 |
# get hidden context if have one
|
1494 |
context = get_context(chat_context, prompt_type)
|
1495 |
|
1496 |
+
# restrict instruction, typically what has large input
|
1497 |
+
from h2oai_pipeline import H2OTextGenerationPipeline
|
1498 |
+
instruction, num_prompt_tokens1 = H2OTextGenerationPipeline.limit_prompt(instruction, tokenizer)
|
1499 |
+
context, num_prompt_tokens2 = H2OTextGenerationPipeline.limit_prompt(context, tokenizer)
|
1500 |
+
iinput, num_prompt_tokens3 = H2OTextGenerationPipeline.limit_prompt(iinput, tokenizer)
|
1501 |
+
num_prompt_tokens = (num_prompt_tokens1 or 0) + (num_prompt_tokens2 or 0) + (num_prompt_tokens3 or 0)
|
1502 |
+
|
1503 |
+
# get prompt
|
1504 |
prompter = Prompter(prompt_type, prompt_dict, debug=debug, chat=chat, stream_output=stream_output)
|
1505 |
data_point = dict(context=context, instruction=instruction, input=iinput)
|
1506 |
prompt = prompter.generate_prompt(data_point)
|
|
|
1513 |
db1 = dbs[langchain_mode]
|
1514 |
else:
|
1515 |
db1 = None
|
1516 |
+
do_langchain_path = langchain_mode not in [False, 'Disabled', 'ChatLLM', 'LLM'] and \
|
1517 |
+
db1 is not None or \
|
1518 |
+
base_model in non_hf_types or \
|
1519 |
+
force_langchain_evaluate
|
1520 |
+
if do_langchain_path:
|
1521 |
query = instruction if not iinput else "%s\n%s" % (instruction, iinput)
|
1522 |
outr = ""
|
1523 |
# use smaller cut_distanct for wiki_full since so many matches could be obtained, and often irrelevant unless close
|
1524 |
from gpt_langchain import run_qa_db
|
1525 |
+
gen_hyper_langchain = dict(do_sample=do_sample,
|
1526 |
+
temperature=temperature,
|
1527 |
+
repetition_penalty=repetition_penalty,
|
1528 |
+
top_k=top_k,
|
1529 |
+
top_p=top_p,
|
1530 |
+
num_beams=num_beams,
|
1531 |
+
min_new_tokens=min_new_tokens,
|
1532 |
+
max_new_tokens=max_new_tokens,
|
1533 |
+
early_stopping=early_stopping,
|
1534 |
+
max_time=max_time,
|
1535 |
+
num_return_sequences=num_return_sequences,
|
1536 |
+
)
|
1537 |
for r in run_qa_db(query=query,
|
1538 |
model_name=base_model, model=model, tokenizer=tokenizer,
|
1539 |
+
inference_server=inference_server,
|
1540 |
stream_output=stream_output,
|
1541 |
prompter=prompter,
|
1542 |
load_db_if_exists=load_db_if_exists,
|
|
|
1556 |
db_type=db_type,
|
1557 |
top_k_docs=top_k_docs,
|
1558 |
|
1559 |
+
**gen_hyper_langchain,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1560 |
|
1561 |
prompt_type=prompt_type,
|
1562 |
prompt_dict=prompt_dict,
|
1563 |
n_jobs=n_jobs,
|
1564 |
verbose=verbose,
|
1565 |
cli=cli,
|
1566 |
+
sanitize_bot_response=sanitize_bot_response,
|
1567 |
+
reverse_docs=reverse_docs,
|
1568 |
+
|
1569 |
+
lora_weights=lora_weights,
|
1570 |
+
|
1571 |
+
auto_reduce_chunks=auto_reduce_chunks,
|
1572 |
+
max_chunks=max_chunks,
|
1573 |
):
|
1574 |
outr, extra = r # doesn't accumulate, new answer every yield, so only save that full answer
|
1575 |
yield dict(response=outr, sources=extra)
|
1576 |
if save_dir:
|
1577 |
+
extra_dict = gen_hyper_langchain.copy()
|
1578 |
+
extra_dict.update(prompt_type=prompt_type, inference_server=inference_server,
|
1579 |
+
langchain_mode=langchain_mode, document_choice=document_choice,
|
1580 |
+
num_prompt_tokens=num_prompt_tokens)
|
1581 |
+
save_generate_output(prompt=query, output=outr, base_model=base_model, save_dir=save_dir,
|
1582 |
+
where_from='run_qa_db',
|
1583 |
+
extra_dict=extra_dict)
|
1584 |
if verbose:
|
1585 |
print(
|
1586 |
'Post-Generate Langchain: %s decoded_output: %s' % (str(datetime.now()), len(outr) if outr else -1),
|
|
|
1593 |
clear_torch_cache()
|
1594 |
return
|
1595 |
|
1596 |
+
if inference_server.startswith('openai') or inference_server.startswith('http'):
|
1597 |
+
if inference_server.startswith('openai'):
|
1598 |
+
import openai
|
1599 |
+
where_from = "openai_client"
|
1600 |
+
|
1601 |
+
openai.api_key = os.getenv("OPENAI_API_KEY")
|
1602 |
+
stop_sequences = list(set(prompter.terminate_response + [prompter.PreResponse]))
|
1603 |
+
# OpenAI will complain if ask for too many new tokens, takes it as min in some sense, wrongly so.
|
1604 |
+
max_new_tokens_openai = min(max_new_tokens, model_max_length - num_prompt_tokens)
|
1605 |
+
gen_server_kwargs = dict(temperature=temperature if do_sample else 0,
|
1606 |
+
max_tokens=max_new_tokens_openai,
|
1607 |
+
top_p=top_p if do_sample else 1,
|
1608 |
+
frequency_penalty=0,
|
1609 |
+
n=num_return_sequences,
|
1610 |
+
presence_penalty=1.07 - repetition_penalty + 0.6, # so good default
|
1611 |
+
)
|
1612 |
+
if inference_server == 'openai':
|
1613 |
+
response = openai.Completion.create(
|
1614 |
+
model=base_model,
|
1615 |
+
prompt=prompt,
|
1616 |
+
**gen_server_kwargs,
|
1617 |
+
stop=stop_sequences,
|
1618 |
+
stream=stream_output,
|
1619 |
+
)
|
1620 |
+
if not stream_output:
|
1621 |
+
text = response['choices'][0]['text']
|
1622 |
+
yield dict(response=prompter.get_response(prompt + text, prompt=prompt,
|
1623 |
+
sanitize_bot_response=sanitize_bot_response),
|
1624 |
+
sources='')
|
1625 |
+
else:
|
1626 |
+
collected_events = []
|
1627 |
+
text = ''
|
1628 |
+
for event in response:
|
1629 |
+
collected_events.append(event) # save the event response
|
1630 |
+
event_text = event['choices'][0]['text'] # extract the text
|
1631 |
+
text += event_text # append the text
|
1632 |
+
yield dict(response=prompter.get_response(prompt + text, prompt=prompt,
|
1633 |
+
sanitize_bot_response=sanitize_bot_response),
|
1634 |
+
sources='')
|
1635 |
+
elif inference_server == 'openai_chat':
|
1636 |
+
response = openai.ChatCompletion.create(
|
1637 |
+
model=base_model,
|
1638 |
+
messages=[
|
1639 |
+
{"role": "system", "content": "You are a helpful assistant."},
|
1640 |
+
{'role': 'user',
|
1641 |
+
'content': prompt,
|
1642 |
+
}
|
1643 |
+
],
|
1644 |
+
stream=stream_output,
|
1645 |
+
**gen_server_kwargs,
|
1646 |
+
)
|
1647 |
+
if not stream_output:
|
1648 |
+
text = response["choices"][0]["message"]["content"]
|
1649 |
+
yield dict(response=prompter.get_response(prompt + text, prompt=prompt,
|
1650 |
+
sanitize_bot_response=sanitize_bot_response),
|
1651 |
+
sources='')
|
1652 |
+
else:
|
1653 |
+
text = ""
|
1654 |
+
for chunk in response:
|
1655 |
+
delta = chunk["choices"][0]["delta"]
|
1656 |
+
if 'content' in delta:
|
1657 |
+
text += delta['content']
|
1658 |
+
yield dict(response=prompter.get_response(prompt + text, prompt=prompt,
|
1659 |
+
sanitize_bot_response=sanitize_bot_response),
|
1660 |
+
sources='')
|
1661 |
+
else:
|
1662 |
+
raise RuntimeError("No such OpenAI mode: %s" % inference_server)
|
1663 |
+
elif inference_server.startswith('http'):
|
1664 |
+
inference_server, headers = get_hf_server(inference_server)
|
1665 |
+
from gradio_utils.grclient import GradioClient
|
1666 |
+
from text_generation import Client as HFClient
|
1667 |
+
if isinstance(model, GradioClient):
|
1668 |
+
gr_client = model
|
1669 |
+
hf_client = None
|
1670 |
+
elif isinstance(model, HFClient):
|
1671 |
+
gr_client = None
|
1672 |
+
hf_client = model
|
1673 |
+
else:
|
1674 |
+
inference_server, gr_client, hf_client = get_client_from_inference_server(inference_server)
|
1675 |
+
|
1676 |
+
# quick sanity check to avoid long timeouts, just see if can reach server
|
1677 |
+
requests.get(inference_server, timeout=int(os.getenv('REQUEST_TIMEOUT_FAST', '10')))
|
1678 |
+
|
1679 |
+
if gr_client is not None:
|
1680 |
+
# Note: h2oGPT gradio server could handle input token size issues for prompt,
|
1681 |
+
# but best to handle here so send less data to server
|
1682 |
+
|
1683 |
+
chat_client = False
|
1684 |
+
where_from = "gr_client"
|
1685 |
+
client_langchain_mode = 'Disabled'
|
1686 |
+
gen_server_kwargs = dict(temperature=temperature,
|
1687 |
+
top_p=top_p,
|
1688 |
+
top_k=top_k,
|
1689 |
+
num_beams=num_beams,
|
1690 |
+
max_new_tokens=max_new_tokens,
|
1691 |
+
min_new_tokens=min_new_tokens,
|
1692 |
+
early_stopping=early_stopping,
|
1693 |
+
max_time=max_time,
|
1694 |
+
repetition_penalty=repetition_penalty,
|
1695 |
+
num_return_sequences=num_return_sequences,
|
1696 |
+
do_sample=do_sample,
|
1697 |
+
chat=chat_client,
|
1698 |
+
)
|
1699 |
+
# account for gradio into gradio that handles prompting, avoid duplicating prompter prompt injection
|
1700 |
+
if prompt_type in [None, '', PromptType.plain.name, PromptType.plain.value,
|
1701 |
+
str(PromptType.plain.value)]:
|
1702 |
+
# if our prompt is plain, assume either correct or gradio server knows different prompt type,
|
1703 |
+
# so pass empty prompt_Type
|
1704 |
+
gr_prompt_type = ''
|
1705 |
+
gr_prompt_dict = ''
|
1706 |
+
gr_prompt = prompt # already prepared prompt
|
1707 |
+
gr_context = ''
|
1708 |
+
gr_iinput = ''
|
1709 |
+
else:
|
1710 |
+
# if already have prompt_type that is not plain, None, or '', then already applied some prompting
|
1711 |
+
# But assume server can handle prompting, and need to avoid double-up.
|
1712 |
+
# Also assume server can do better job of using stopping.py to stop early, so avoid local prompting, let server handle
|
1713 |
+
# So avoid "prompt" and let gradio server reconstruct from prompt_type we passed
|
1714 |
+
# Note it's ok that prompter.get_response() has prompt+text, prompt=prompt passed,
|
1715 |
+
# because just means extra processing and removal of prompt, but that has no human-bot prompting doesn't matter
|
1716 |
+
# since those won't appear
|
1717 |
+
gr_context = context
|
1718 |
+
gr_prompt = instruction
|
1719 |
+
gr_iinput = iinput
|
1720 |
+
gr_prompt_type = prompt_type
|
1721 |
+
gr_prompt_dict = prompt_dict
|
1722 |
+
client_kwargs = dict(instruction=gr_prompt if chat_client else '', # only for chat=True
|
1723 |
+
iinput=gr_iinput, # only for chat=True
|
1724 |
+
context=gr_context,
|
1725 |
+
# streaming output is supported, loops over and outputs each generation in streaming mode
|
1726 |
+
# but leave stream_output=False for simple input/output mode
|
1727 |
+
stream_output=stream_output,
|
1728 |
+
|
1729 |
+
**gen_server_kwargs,
|
1730 |
+
|
1731 |
+
prompt_type=gr_prompt_type,
|
1732 |
+
prompt_dict=gr_prompt_dict,
|
1733 |
+
|
1734 |
+
instruction_nochat=gr_prompt if not chat_client else '',
|
1735 |
+
iinput_nochat=gr_iinput, # only for chat=False
|
1736 |
+
langchain_mode=client_langchain_mode,
|
1737 |
+
top_k_docs=top_k_docs,
|
1738 |
+
chunk=chunk,
|
1739 |
+
chunk_size=chunk_size,
|
1740 |
+
document_choice=[DocumentChoices.All_Relevant.name],
|
1741 |
+
)
|
1742 |
+
api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing
|
1743 |
+
if not stream_output:
|
1744 |
+
res = gr_client.predict(str(dict(client_kwargs)), api_name=api_name)
|
1745 |
+
res_dict = ast.literal_eval(res)
|
1746 |
+
text = res_dict['response']
|
1747 |
+
sources = res_dict['sources']
|
1748 |
+
yield dict(response=prompter.get_response(prompt + text, prompt=prompt,
|
1749 |
+
sanitize_bot_response=sanitize_bot_response),
|
1750 |
+
sources=sources)
|
1751 |
+
else:
|
1752 |
+
job = gr_client.submit(str(dict(client_kwargs)), api_name=api_name)
|
1753 |
+
text = ''
|
1754 |
+
sources = ''
|
1755 |
+
res_dict = dict(response=text, sources=sources)
|
1756 |
+
while not job.done():
|
1757 |
+
outputs_list = job.communicator.job.outputs
|
1758 |
+
if outputs_list:
|
1759 |
+
res = job.communicator.job.outputs[-1]
|
1760 |
+
res_dict = ast.literal_eval(res)
|
1761 |
+
text = res_dict['response']
|
1762 |
+
sources = res_dict['sources']
|
1763 |
+
if gr_prompt_type == 'plain':
|
1764 |
+
# then gradio server passes back full prompt + text
|
1765 |
+
prompt_and_text = text
|
1766 |
+
else:
|
1767 |
+
prompt_and_text = prompt + text
|
1768 |
+
yield dict(response=prompter.get_response(prompt_and_text, prompt=prompt,
|
1769 |
+
sanitize_bot_response=sanitize_bot_response),
|
1770 |
+
sources=sources)
|
1771 |
+
time.sleep(0.01)
|
1772 |
+
# ensure get last output to avoid race
|
1773 |
+
res_all = job.outputs()
|
1774 |
+
if len(res_all) > 0:
|
1775 |
+
res = res_all[-1]
|
1776 |
+
res_dict = ast.literal_eval(res)
|
1777 |
+
text = res_dict['response']
|
1778 |
+
sources = res_dict['sources']
|
1779 |
+
else:
|
1780 |
+
# go with old text if last call didn't work
|
1781 |
+
e = job.future._exception
|
1782 |
+
if e is not None:
|
1783 |
+
stre = str(e)
|
1784 |
+
strex = ''.join(traceback.format_tb(e.__traceback__))
|
1785 |
+
else:
|
1786 |
+
stre = ''
|
1787 |
+
strex = ''
|
1788 |
+
|
1789 |
+
print("Bad final response: %s %s %s %s %s: %s %s" % (base_model, inference_server,
|
1790 |
+
res_all, prompt, text, stre, strex),
|
1791 |
+
flush=True)
|
1792 |
+
if gr_prompt_type == 'plain':
|
1793 |
+
# then gradio server passes back full prompt + text
|
1794 |
+
prompt_and_text = text
|
1795 |
+
else:
|
1796 |
+
prompt_and_text = prompt + text
|
1797 |
+
yield dict(response=prompter.get_response(prompt_and_text, prompt=prompt,
|
1798 |
+
sanitize_bot_response=sanitize_bot_response),
|
1799 |
+
sources=sources)
|
1800 |
+
elif hf_client:
|
1801 |
+
# HF inference server needs control over input tokens
|
1802 |
+
where_from = "hf_client"
|
1803 |
+
|
1804 |
+
# prompt must include all human-bot like tokens, already added by prompt
|
1805 |
+
# https://github.com/huggingface/text-generation-inference/tree/main/clients/python#types
|
1806 |
+
stop_sequences = list(set(prompter.terminate_response + [prompter.PreResponse]))
|
1807 |
+
gen_server_kwargs = dict(do_sample=do_sample,
|
1808 |
+
max_new_tokens=max_new_tokens,
|
1809 |
+
# best_of=None,
|
1810 |
+
repetition_penalty=repetition_penalty,
|
1811 |
+
return_full_text=True,
|
1812 |
+
seed=SEED,
|
1813 |
+
stop_sequences=stop_sequences,
|
1814 |
+
temperature=temperature,
|
1815 |
+
top_k=top_k,
|
1816 |
+
top_p=top_p,
|
1817 |
+
# truncate=False, # behaves oddly
|
1818 |
+
# typical_p=top_p,
|
1819 |
+
# watermark=False,
|
1820 |
+
# decoder_input_details=False,
|
1821 |
+
)
|
1822 |
+
# work-around for timeout at constructor time, will be issue if multi-threading,
|
1823 |
+
# so just do something reasonable or max_time if larger
|
1824 |
+
# lower bound because client is re-used if multi-threading
|
1825 |
+
hf_client.timeout = max(300, max_time)
|
1826 |
+
if not stream_output:
|
1827 |
+
text = hf_client.generate(prompt, **gen_server_kwargs).generated_text
|
1828 |
+
yield dict(response=prompter.get_response(text, prompt=prompt,
|
1829 |
+
sanitize_bot_response=sanitize_bot_response),
|
1830 |
+
sources='')
|
1831 |
+
else:
|
1832 |
+
text = ""
|
1833 |
+
for response in hf_client.generate_stream(prompt, **gen_server_kwargs):
|
1834 |
+
if not response.token.special:
|
1835 |
+
# stop_sequences
|
1836 |
+
text_chunk = response.token.text
|
1837 |
+
text += text_chunk
|
1838 |
+
yield dict(response=prompter.get_response(prompt + text, prompt=prompt,
|
1839 |
+
sanitize_bot_response=sanitize_bot_response),
|
1840 |
+
sources='')
|
1841 |
+
else:
|
1842 |
+
raise RuntimeError("Failed to get client: %s" % inference_server)
|
1843 |
+
else:
|
1844 |
+
raise RuntimeError("No such inference_server %s" % inference_server)
|
1845 |
+
|
1846 |
+
if save_dir and text:
|
1847 |
+
# save prompt + new text
|
1848 |
+
extra_dict = gen_server_kwargs.copy()
|
1849 |
+
extra_dict.update(dict(inference_server=inference_server, num_prompt_tokens=num_prompt_tokens))
|
1850 |
+
save_generate_output(prompt=prompt, output=text, base_model=base_model, save_dir=save_dir,
|
1851 |
+
where_from=where_from, extra_dict=extra_dict)
|
1852 |
+
return
|
1853 |
+
else:
|
1854 |
+
assert not inference_server, "inferene_server=%s not supported" % inference_server
|
1855 |
+
|
1856 |
if isinstance(tokenizer, str):
|
1857 |
# pipeline
|
1858 |
if tokenizer == "summarization":
|
|
|
1866 |
assert src_lang is not None
|
1867 |
tokenizer.src_lang = languages_covered()[src_lang]
|
1868 |
|
1869 |
+
stopping_criteria = get_stopping(prompt_type, prompt_dict, tokenizer, device,
|
1870 |
+
model_max_length=tokenizer.model_max_length)
|
1871 |
+
|
1872 |
+
inputs = tokenizer(prompt, return_tensors="pt")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1873 |
if debug and len(inputs["input_ids"]) > 0:
|
1874 |
print('input_ids length', len(inputs["input_ids"][0]), flush=True)
|
1875 |
input_ids = inputs["input_ids"].to(device)
|
1876 |
# CRITICAL LIMIT else will fail
|
1877 |
max_max_tokens = tokenizer.model_max_length
|
1878 |
+
max_input_tokens = max_max_tokens - min_new_tokens
|
1879 |
+
# NOTE: Don't limit up front due to max_new_tokens, let go up to max or reach max_max_tokens in stopping.py
|
1880 |
input_ids = input_ids[:, -max_input_tokens:]
|
1881 |
+
# required for falcon if multiple threads or asyncio accesses to model during generation
|
1882 |
+
if use_cache is None:
|
1883 |
+
use_cache = False if 'falcon' in base_model else True
|
1884 |
+
gen_config_kwargs = dict(temperature=float(temperature),
|
1885 |
+
top_p=float(top_p),
|
1886 |
+
top_k=top_k,
|
1887 |
+
num_beams=num_beams,
|
1888 |
+
do_sample=do_sample,
|
1889 |
+
repetition_penalty=float(repetition_penalty),
|
1890 |
+
num_return_sequences=num_return_sequences,
|
1891 |
+
renormalize_logits=True,
|
1892 |
+
remove_invalid_values=True,
|
1893 |
+
use_cache=use_cache,
|
1894 |
+
)
|
1895 |
+
token_ids = ['eos_token_id', 'pad_token_id', 'bos_token_id', 'cls_token_id', 'sep_token_id']
|
1896 |
+
for token_id in token_ids:
|
1897 |
+
if hasattr(tokenizer, token_id) and getattr(tokenizer, token_id) is not None:
|
1898 |
+
gen_config_kwargs.update({token_id: getattr(tokenizer, token_id)})
|
1899 |
+
generation_config = GenerationConfig(**gen_config_kwargs)
|
1900 |
|
1901 |
gen_kwargs = dict(input_ids=input_ids,
|
1902 |
generation_config=generation_config,
|
|
|
1915 |
tgt_lang = languages_covered()[tgt_lang]
|
1916 |
gen_kwargs.update(dict(forced_bos_token_id=tokenizer.lang_code_to_id[tgt_lang]))
|
1917 |
else:
|
1918 |
+
token_ids = ['eos_token_id', 'bos_token_id', 'pad_token_id']
|
1919 |
+
for token_id in token_ids:
|
1920 |
+
if hasattr(tokenizer, token_id) and getattr(tokenizer, token_id) is not None:
|
1921 |
+
gen_kwargs.update({token_id: getattr(tokenizer, token_id)})
|
1922 |
|
1923 |
decoder_kwargs = dict(skip_special_tokens=True,
|
1924 |
clean_up_tokenization_spaces=True)
|
|
|
1934 |
)
|
1935 |
|
1936 |
with torch.no_grad():
|
1937 |
+
have_lora_weights = lora_weights not in [no_lora_str, '', None]
|
1938 |
+
context_class_cast = NullContext if device == 'cpu' or have_lora_weights else torch.autocast
|
1939 |
with context_class_cast(device):
|
1940 |
# protection for gradio not keeping track of closed users,
|
1941 |
# else hit bitsandbytes lack of thread safety:
|
|
|
2018 |
if outputs and len(outputs) >= 1:
|
2019 |
decoded_output = prompt + outputs[0]
|
2020 |
if save_dir and decoded_output:
|
2021 |
+
extra_dict = gen_config_kwargs.copy()
|
2022 |
+
extra_dict.update(dict(num_prompt_tokens=num_prompt_tokens))
|
2023 |
+
save_generate_output(prompt=prompt, output=decoded_output, base_model=base_model, save_dir=save_dir,
|
2024 |
+
where_from="evaluate_%s" % str(stream_output),
|
2025 |
+
extra_dict=gen_config_kwargs)
|
2026 |
if verbose:
|
2027 |
print('Post-Generate: %s decoded_output: %s' % (
|
2028 |
str(datetime.now()), len(decoded_output) if decoded_output else -1), flush=True)
|
|
|
2041 |
if memory_restriction_level > 0:
|
2042 |
max_length_tokenize = 768 - 256 if memory_restriction_level <= 2 else 512 - 256
|
2043 |
else:
|
2044 |
+
# at least give room for 1 paragraph output
|
2045 |
max_length_tokenize = model_max_length - 256
|
2046 |
cutoff_len = max_length_tokenize * 4 # if reaches limit, then can't generate new tokens
|
2047 |
output_smallest = 30 * 4
|
|
|
2146 |
if model_lower:
|
2147 |
print(f"Using Model {model_lower}", flush=True)
|
2148 |
else:
|
2149 |
+
if verbose:
|
2150 |
+
print("No model defined yet", flush=True)
|
2151 |
|
2152 |
min_new_tokens = min_new_tokens if min_new_tokens is not None else 0
|
2153 |
early_stopping = early_stopping if early_stopping is not None else False
|
|
|
2203 |
use_defaults = True
|
2204 |
else:
|
2205 |
if chat:
|
2206 |
+
placeholder_instruction = ""
|
2207 |
else:
|
2208 |
placeholder_instruction = "Give detailed answer for whether Einstein or Newton is smarter."
|
2209 |
placeholder_input = ""
|
2210 |
+
if model_lower in inv_prompt_type_to_model_lower:
|
2211 |
+
prompt_type = inv_prompt_type_to_model_lower[model_lower]
|
2212 |
+
elif model_lower:
|
2213 |
+
# default is plain, because might rely upon trust_remote_code to handle prompting
|
2214 |
prompt_type = prompt_type or 'plain'
|
2215 |
else:
|
2216 |
prompt_type = ''
|
|
|
2318 |
|
2319 |
# get prompt_dict from prompt_type, so user can see in UI etc., or for custom do nothing except check format
|
2320 |
prompt_dict, error0 = get_prompt(prompt_type, prompt_dict,
|
2321 |
+
chat=False, context='', reduced=False, making_context=False, return_dict=True)
|
2322 |
if error0:
|
2323 |
raise RuntimeError("Prompt wrong: %s" % error0)
|
2324 |
|
|
|
2371 |
if 'Expected all tensors to be on the same device' in str(e) or \
|
2372 |
'expected scalar type Half but found Float' in str(e) or \
|
2373 |
'probability tensor contains either' in str(e) or \
|
2374 |
+
'cublasLt ran into an error!' in str(e) or \
|
2375 |
+
'device-side assert triggered' in str(e):
|
2376 |
print("GPU Error: question: %s answer: %s exception: %s" % (question, answer, str(e)),
|
2377 |
flush=True)
|
2378 |
traceback.print_exc()
|
|
|
2405 |
assert k in kwargs, "Missing %s" % k
|
2406 |
|
2407 |
|
2408 |
+
def get_model_max_length(model_state):
|
2409 |
+
if not isinstance(model_state['tokenizer'], (str, types.NoneType)):
|
2410 |
+
return model_state['tokenizer'].model_max_length
|
2411 |
+
else:
|
2412 |
+
return 2048
|
2413 |
+
|
2414 |
+
|
2415 |
def get_max_max_new_tokens(model_state, **kwargs):
|
2416 |
+
if not isinstance(model_state['tokenizer'], (str, types.NoneType)):
|
2417 |
+
max_max_new_tokens = model_state['tokenizer'].model_max_length
|
2418 |
+
else:
|
2419 |
+
max_max_new_tokens = None
|
2420 |
+
|
2421 |
+
if kwargs['max_max_new_tokens'] is not None and max_max_new_tokens is not None:
|
2422 |
+
return min(max_max_new_tokens, kwargs['max_max_new_tokens'])
|
2423 |
+
elif kwargs['max_max_new_tokens'] is not None:
|
2424 |
+
return kwargs['max_max_new_tokens']
|
2425 |
elif kwargs['memory_restriction_level'] == 1:
|
2426 |
+
return 768
|
2427 |
elif kwargs['memory_restriction_level'] == 2:
|
2428 |
+
return 512
|
2429 |
elif kwargs['memory_restriction_level'] >= 3:
|
2430 |
+
return 256
|
2431 |
else:
|
2432 |
+
# FIXME: Need to update after new model loaded, so user can control with slider
|
2433 |
+
return 2048
|
|
|
|
|
|
|
|
|
2434 |
|
2435 |
|
2436 |
+
def get_minmax_top_k_docs(is_public):
|
2437 |
+
if is_public:
|
2438 |
+
min_top_k_docs = 1
|
2439 |
+
max_top_k_docs = 3
|
2440 |
+
label_top_k_docs = "Number of document chunks"
|
2441 |
+
else:
|
2442 |
+
min_top_k_docs = -1
|
2443 |
+
max_top_k_docs = 100
|
2444 |
+
label_top_k_docs = "Number of document chunks (-1 = auto fill model context)"
|
2445 |
+
return min_top_k_docs, max_top_k_docs, label_top_k_docs
|
2446 |
+
|
2447 |
+
|
2448 |
+
def history_to_context(history, langchain_mode1, prompt_type1, prompt_dict1, chat1, model_max_length1,
|
2449 |
+
memory_restriction_level1, keep_sources_in_context1):
|
2450 |
+
"""
|
2451 |
+
consumes all history up to (but not including) latest history item that is presumed to be an [instruction, None] pair
|
2452 |
+
:param history:
|
2453 |
+
:param langchain_mode1:
|
2454 |
+
:param prompt_type1:
|
2455 |
+
:param prompt_dict1:
|
2456 |
+
:param chat1:
|
2457 |
+
:param model_max_length1:
|
2458 |
+
:param memory_restriction_level1:
|
2459 |
+
:param keep_sources_in_context1:
|
2460 |
+
:return:
|
2461 |
+
"""
|
2462 |
+
# ensure output will be unique to models
|
2463 |
+
_, _, _, max_prompt_length = get_cutoffs(memory_restriction_level1,
|
2464 |
+
for_context=True, model_max_length=model_max_length1)
|
2465 |
+
context1 = ''
|
2466 |
+
if max_prompt_length is not None and langchain_mode1 not in ['LLM']:
|
2467 |
+
context1 = ''
|
2468 |
+
# - 1 below because current instruction already in history from user()
|
2469 |
+
for histi in range(0, len(history) - 1):
|
2470 |
+
data_point = dict(instruction=history[histi][0], input='', output=history[histi][1])
|
2471 |
+
prompt, pre_response, terminate_response, chat_sep, chat_turn_sep = generate_prompt(data_point,
|
2472 |
+
prompt_type1,
|
2473 |
+
prompt_dict1,
|
2474 |
+
chat1,
|
2475 |
+
reduced=True,
|
2476 |
+
making_context=True)
|
2477 |
+
# md -> back to text, maybe not super important if model trained enough
|
2478 |
+
if not keep_sources_in_context1 and langchain_mode1 != 'Disabled' and prompt.find(source_prefix) >= 0:
|
2479 |
+
# FIXME: This is relatively slow even for small amount of text, like 0.3s each history item
|
2480 |
+
import re
|
2481 |
+
prompt = re.sub(f'{re.escape(source_prefix)}.*?{re.escape(source_postfix)}', '', prompt,
|
2482 |
+
flags=re.DOTALL)
|
2483 |
+
if prompt.endswith('\n<p>'):
|
2484 |
+
prompt = prompt[:-4]
|
2485 |
+
prompt = prompt.replace('<br>', chat_turn_sep)
|
2486 |
+
if not prompt.endswith(chat_turn_sep):
|
2487 |
+
prompt += chat_turn_sep
|
2488 |
+
# most recent first, add older if can
|
2489 |
+
# only include desired chat history
|
2490 |
+
if len(prompt + context1) > max_prompt_length:
|
2491 |
+
break
|
2492 |
+
context1 += prompt
|
2493 |
+
|
2494 |
+
_, pre_response, terminate_response, chat_sep, chat_turn_sep = generate_prompt({}, prompt_type1, prompt_dict1,
|
2495 |
+
chat1, reduced=True,
|
2496 |
+
making_context=True)
|
2497 |
+
if context1 and not context1.endswith(chat_turn_sep):
|
2498 |
+
context1 += chat_turn_sep # ensure if terminates abruptly, then human continues on next line
|
2499 |
+
return context1
|
2500 |
+
|
2501 |
+
|
2502 |
+
def entrypoint_main():
|
2503 |
"""
|
2504 |
Examples:
|
2505 |
|
|
|
2530 |
python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6_9b
|
2531 |
"""
|
2532 |
fire.Fire(main)
|
2533 |
+
|
2534 |
+
|
2535 |
+
if __name__ == "__main__":
|
2536 |
+
entrypoint_main()
|
gpt4all_llm.py
CHANGED
@@ -1,24 +1,13 @@
|
|
1 |
import inspect
|
2 |
import os
|
3 |
-
import
|
4 |
from typing import Dict, Any, Optional, List
|
5 |
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
6 |
from pydantic import root_validator
|
7 |
from langchain.llms import gpt4all
|
8 |
from dotenv import dotenv_values
|
9 |
|
10 |
-
|
11 |
-
class FakeTokenizer:
|
12 |
-
model_max_length = 2048
|
13 |
-
|
14 |
-
def encode(self, x, *args, **kwargs):
|
15 |
-
return dict(input_ids=[x])
|
16 |
-
|
17 |
-
def decode(self, x, *args, **kwargs):
|
18 |
-
return x
|
19 |
-
|
20 |
-
def __call__(self, x, *args, **kwargs):
|
21 |
-
return self.encode(x, *args, **kwargs)
|
22 |
|
23 |
|
24 |
def get_model_tokenizer_gpt4all(base_model, **kwargs):
|
@@ -74,9 +63,9 @@ class H2OStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
|
|
74 |
pass
|
75 |
|
76 |
|
77 |
-
def get_model_kwargs(env_kwargs, default_kwargs, cls):
|
78 |
# default from class
|
79 |
-
model_kwargs = {k: v.default for k, v in dict(inspect.signature(cls).parameters).items()}
|
80 |
# from our defaults
|
81 |
model_kwargs.update(default_kwargs)
|
82 |
# from user defaults
|
@@ -94,10 +83,14 @@ def get_llm_gpt4all(model_name,
|
|
94 |
repetition_penalty=1.0,
|
95 |
top_k=40,
|
96 |
top_p=0.7,
|
97 |
-
|
|
|
|
|
|
|
|
|
|
|
98 |
env_gpt4all_file = ".env_gpt4all"
|
99 |
env_kwargs = dotenv_values(env_gpt4all_file)
|
100 |
-
callbacks = [H2OStreamingStdOutCallbackHandler()]
|
101 |
n_ctx = env_kwargs.pop('n_ctx', 2048 - max_new_tokens)
|
102 |
default_kwargs = dict(context_erase=0.5,
|
103 |
n_batch=1,
|
@@ -114,21 +107,23 @@ def get_llm_gpt4all(model_name,
|
|
114 |
if model_name == 'llama':
|
115 |
cls = H2OLlamaCpp
|
116 |
model_path = env_kwargs.pop('model_path_llama') if model is None else model
|
117 |
-
model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls)
|
118 |
-
model_kwargs.update(dict(model_path=model_path, callbacks=callbacks))
|
119 |
llm = cls(**model_kwargs)
|
120 |
llm.client.verbose = verbose
|
121 |
elif model_name == 'gpt4all_llama':
|
122 |
cls = H2OGPT4All
|
123 |
model_path = env_kwargs.pop('model_path_gpt4all_llama') if model is None else model
|
124 |
-
model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls)
|
125 |
-
model_kwargs.update(
|
|
|
126 |
llm = cls(**model_kwargs)
|
127 |
elif model_name == 'gptj':
|
128 |
cls = H2OGPT4All
|
129 |
model_path = env_kwargs.pop('model_path_gptj') if model is None else model
|
130 |
-
model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls)
|
131 |
-
model_kwargs.update(
|
|
|
132 |
llm = cls(**model_kwargs)
|
133 |
else:
|
134 |
raise RuntimeError("No such model_name %s" % model_name)
|
@@ -137,6 +132,7 @@ def get_llm_gpt4all(model_name,
|
|
137 |
|
138 |
class H2OGPT4All(gpt4all.GPT4All):
|
139 |
model: Any
|
|
|
140 |
"""Path to the pre-trained GPT4All model file."""
|
141 |
|
142 |
@root_validator()
|
@@ -156,9 +152,16 @@ class H2OGPT4All(gpt4all.GPT4All):
|
|
156 |
model_type=values["backend"],
|
157 |
allow_download=False,
|
158 |
)
|
|
|
|
|
|
|
159 |
else:
|
160 |
values["client"] = values["model"]
|
161 |
-
|
|
|
|
|
|
|
|
|
162 |
|
163 |
except ImportError:
|
164 |
raise ValueError(
|
@@ -172,12 +175,19 @@ class H2OGPT4All(gpt4all.GPT4All):
|
|
172 |
prompt: str,
|
173 |
stop: Optional[List[str]] = None,
|
174 |
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
|
|
175 |
) -> str:
|
176 |
# Roughly 4 chars per token if natural language
|
177 |
prompt = prompt[-self.n_ctx * 4:]
|
|
|
|
|
|
|
|
|
|
|
178 |
verbose = False
|
179 |
if verbose:
|
180 |
print("_call prompt: %s" % prompt, flush=True)
|
|
|
181 |
return super()._call(prompt, stop=stop, run_manager=run_manager)
|
182 |
|
183 |
|
@@ -186,6 +196,7 @@ from langchain.llms import LlamaCpp
|
|
186 |
|
187 |
class H2OLlamaCpp(LlamaCpp):
|
188 |
model_path: Any
|
|
|
189 |
"""Path to the pre-trained GPT4All model file."""
|
190 |
|
191 |
@root_validator()
|
@@ -237,6 +248,7 @@ class H2OLlamaCpp(LlamaCpp):
|
|
237 |
prompt: str,
|
238 |
stop: Optional[List[str]] = None,
|
239 |
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
|
|
240 |
) -> str:
|
241 |
verbose = False
|
242 |
# tokenize twice, just to count tokens, since llama cpp python wrapper has no way to truncate
|
@@ -253,6 +265,33 @@ class H2OLlamaCpp(LlamaCpp):
|
|
253 |
prompt_tokens2 = self.client.tokenize(b" " + prompt.encode("utf-8"))
|
254 |
num_prompt_tokens2 = len(prompt_tokens2)
|
255 |
print("reduced tokens from %d -> %d" % (num_prompt_tokens, num_prompt_tokens2), flush=True)
|
|
|
|
|
|
|
|
|
|
|
256 |
if verbose:
|
257 |
print("_call prompt: %s" % prompt, flush=True)
|
258 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import inspect
|
2 |
import os
|
3 |
+
from functools import partial
|
4 |
from typing import Dict, Any, Optional, List
|
5 |
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
6 |
from pydantic import root_validator
|
7 |
from langchain.llms import gpt4all
|
8 |
from dotenv import dotenv_values
|
9 |
|
10 |
+
from utils import FakeTokenizer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
|
13 |
def get_model_tokenizer_gpt4all(base_model, **kwargs):
|
|
|
63 |
pass
|
64 |
|
65 |
|
66 |
+
def get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=[]):
|
67 |
# default from class
|
68 |
+
model_kwargs = {k: v.default for k, v in dict(inspect.signature(cls).parameters).items() if k not in exclude_list}
|
69 |
# from our defaults
|
70 |
model_kwargs.update(default_kwargs)
|
71 |
# from user defaults
|
|
|
83 |
repetition_penalty=1.0,
|
84 |
top_k=40,
|
85 |
top_p=0.7,
|
86 |
+
streaming=False,
|
87 |
+
callbacks=None,
|
88 |
+
prompter=None,
|
89 |
+
verbose=False,
|
90 |
+
):
|
91 |
+
assert prompter is not None
|
92 |
env_gpt4all_file = ".env_gpt4all"
|
93 |
env_kwargs = dotenv_values(env_gpt4all_file)
|
|
|
94 |
n_ctx = env_kwargs.pop('n_ctx', 2048 - max_new_tokens)
|
95 |
default_kwargs = dict(context_erase=0.5,
|
96 |
n_batch=1,
|
|
|
107 |
if model_name == 'llama':
|
108 |
cls = H2OLlamaCpp
|
109 |
model_path = env_kwargs.pop('model_path_llama') if model is None else model
|
110 |
+
model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=['lc_kwargs'])
|
111 |
+
model_kwargs.update(dict(model_path=model_path, callbacks=callbacks, streaming=streaming, prompter=prompter))
|
112 |
llm = cls(**model_kwargs)
|
113 |
llm.client.verbose = verbose
|
114 |
elif model_name == 'gpt4all_llama':
|
115 |
cls = H2OGPT4All
|
116 |
model_path = env_kwargs.pop('model_path_gpt4all_llama') if model is None else model
|
117 |
+
model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=['lc_kwargs'])
|
118 |
+
model_kwargs.update(
|
119 |
+
dict(model=model_path, backend='llama', callbacks=callbacks, streaming=streaming, prompter=prompter))
|
120 |
llm = cls(**model_kwargs)
|
121 |
elif model_name == 'gptj':
|
122 |
cls = H2OGPT4All
|
123 |
model_path = env_kwargs.pop('model_path_gptj') if model is None else model
|
124 |
+
model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=['lc_kwargs'])
|
125 |
+
model_kwargs.update(
|
126 |
+
dict(model=model_path, backend='gptj', callbacks=callbacks, streaming=streaming, prompter=prompter))
|
127 |
llm = cls(**model_kwargs)
|
128 |
else:
|
129 |
raise RuntimeError("No such model_name %s" % model_name)
|
|
|
132 |
|
133 |
class H2OGPT4All(gpt4all.GPT4All):
|
134 |
model: Any
|
135 |
+
prompter: Any
|
136 |
"""Path to the pre-trained GPT4All model file."""
|
137 |
|
138 |
@root_validator()
|
|
|
152 |
model_type=values["backend"],
|
153 |
allow_download=False,
|
154 |
)
|
155 |
+
if values["n_threads"] is not None:
|
156 |
+
# set n_threads
|
157 |
+
values["client"].model.set_thread_count(values["n_threads"])
|
158 |
else:
|
159 |
values["client"] = values["model"]
|
160 |
+
try:
|
161 |
+
values["backend"] = values["client"].model_type
|
162 |
+
except AttributeError:
|
163 |
+
# The below is for compatibility with GPT4All Python bindings <= 0.2.3.
|
164 |
+
values["backend"] = values["client"].model.model_type
|
165 |
|
166 |
except ImportError:
|
167 |
raise ValueError(
|
|
|
175 |
prompt: str,
|
176 |
stop: Optional[List[str]] = None,
|
177 |
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
178 |
+
**kwargs,
|
179 |
) -> str:
|
180 |
# Roughly 4 chars per token if natural language
|
181 |
prompt = prompt[-self.n_ctx * 4:]
|
182 |
+
|
183 |
+
# use instruct prompting
|
184 |
+
data_point = dict(context='', instruction=prompt, input='')
|
185 |
+
prompt = self.prompter.generate_prompt(data_point)
|
186 |
+
|
187 |
verbose = False
|
188 |
if verbose:
|
189 |
print("_call prompt: %s" % prompt, flush=True)
|
190 |
+
# FIXME: GPT4ALl doesn't support yield during generate, so cannot support streaming except via itself to stdout
|
191 |
return super()._call(prompt, stop=stop, run_manager=run_manager)
|
192 |
|
193 |
|
|
|
196 |
|
197 |
class H2OLlamaCpp(LlamaCpp):
|
198 |
model_path: Any
|
199 |
+
prompter: Any
|
200 |
"""Path to the pre-trained GPT4All model file."""
|
201 |
|
202 |
@root_validator()
|
|
|
248 |
prompt: str,
|
249 |
stop: Optional[List[str]] = None,
|
250 |
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
251 |
+
**kwargs,
|
252 |
) -> str:
|
253 |
verbose = False
|
254 |
# tokenize twice, just to count tokens, since llama cpp python wrapper has no way to truncate
|
|
|
265 |
prompt_tokens2 = self.client.tokenize(b" " + prompt.encode("utf-8"))
|
266 |
num_prompt_tokens2 = len(prompt_tokens2)
|
267 |
print("reduced tokens from %d -> %d" % (num_prompt_tokens, num_prompt_tokens2), flush=True)
|
268 |
+
|
269 |
+
# use instruct prompting
|
270 |
+
data_point = dict(context='', instruction=prompt, input='')
|
271 |
+
prompt = self.prompter.generate_prompt(data_point)
|
272 |
+
|
273 |
if verbose:
|
274 |
print("_call prompt: %s" % prompt, flush=True)
|
275 |
+
|
276 |
+
if self.streaming:
|
277 |
+
text_callback = None
|
278 |
+
if run_manager:
|
279 |
+
text_callback = partial(
|
280 |
+
run_manager.on_llm_new_token, verbose=self.verbose
|
281 |
+
)
|
282 |
+
# parent handler of streamer expects to see prompt first else output="" and lose if prompt=None in prompter
|
283 |
+
if text_callback:
|
284 |
+
text_callback(prompt)
|
285 |
+
text = ""
|
286 |
+
for token in self.stream(prompt=prompt, stop=stop, run_manager=run_manager):
|
287 |
+
text_chunk = token["choices"][0]["text"]
|
288 |
+
# self.stream already calls text_callback
|
289 |
+
# if text_callback:
|
290 |
+
# text_callback(text_chunk)
|
291 |
+
text += text_chunk
|
292 |
+
return text
|
293 |
+
else:
|
294 |
+
params = self._get_parameters(stop)
|
295 |
+
params = {**params, **kwargs}
|
296 |
+
result = self.client(prompt=prompt, **params)
|
297 |
+
return result["choices"][0]["text"]
|
gpt_langchain.py
CHANGED
@@ -1,31 +1,34 @@
|
|
|
|
1 |
import glob
|
2 |
import inspect
|
3 |
import os
|
4 |
import pathlib
|
5 |
import pickle
|
6 |
-
import queue
|
7 |
-
import random
|
8 |
import shutil
|
9 |
import subprocess
|
10 |
-
import sys
|
11 |
import tempfile
|
|
|
12 |
import traceback
|
|
|
13 |
import uuid
|
14 |
import zipfile
|
15 |
from collections import defaultdict
|
16 |
from datetime import datetime
|
17 |
from functools import reduce
|
18 |
from operator import concat
|
|
|
19 |
|
20 |
-
from joblib import
|
|
|
21 |
from langchain.embeddings import HuggingFaceInstructEmbeddings
|
22 |
from tqdm import tqdm
|
23 |
|
24 |
-
from enums import DocumentChoices
|
25 |
-
from generate import gen_hyper
|
26 |
-
from prompter import non_hf_types, PromptType
|
27 |
from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
|
28 |
-
get_device, ProgressParallel, remove, hash_file, clear_torch_cache
|
|
|
29 |
|
30 |
import_matplotlib()
|
31 |
|
@@ -40,11 +43,11 @@ from langchain.chains.qa_with_sources import load_qa_with_sources_chain
|
|
40 |
from langchain.document_loaders import PyPDFLoader, TextLoader, CSVLoader, PythonLoader, TomlLoader, \
|
41 |
UnstructuredURLLoader, UnstructuredHTMLLoader, UnstructuredWordDocumentLoader, UnstructuredMarkdownLoader, \
|
42 |
EverNoteLoader, UnstructuredEmailLoader, UnstructuredODTLoader, UnstructuredPowerPointLoader, \
|
43 |
-
UnstructuredEPubLoader, UnstructuredImageLoader, UnstructuredRTFLoader, ArxivLoader
|
44 |
-
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
45 |
from langchain.chains.question_answering import load_qa_chain
|
46 |
from langchain.docstore.document import Document
|
47 |
-
from langchain import PromptTemplate
|
48 |
from langchain.vectorstores import Chroma
|
49 |
|
50 |
|
@@ -142,7 +145,7 @@ def add_to_db(db, sources, db_type='faiss',
|
|
142 |
return db, num_new_sources, []
|
143 |
db.add_documents(documents=sources)
|
144 |
elif db_type == 'chroma':
|
145 |
-
collection = db
|
146 |
# files we already have:
|
147 |
metadata_files = set([x['source'] for x in collection['metadatas']])
|
148 |
if avoid_dup_by_file:
|
@@ -157,6 +160,9 @@ def add_to_db(db, sources, db_type='faiss',
|
|
157 |
[x['hashid'] for x in collection['metadatas'] if 'hashid' in x and x['hashid'] not in ["None", None]])
|
158 |
# avoid sources with same hash
|
159 |
sources = [x for x in sources if x.metadata.get('hashid') not in metadata_hash_ids]
|
|
|
|
|
|
|
160 |
# get new file names that match existing file names. delete existing files we are overridding
|
161 |
dup_metadata_files = set([x.metadata['source'] for x in sources if x.metadata['source'] in metadata_files])
|
162 |
print("Removing %s duplicate files from db because ingesting those as new documents" % len(
|
@@ -233,7 +239,7 @@ def get_embedding(use_openai_embedding, hf_embedding_model="sentence-transformer
|
|
233 |
if use_openai_embedding:
|
234 |
assert os.getenv("OPENAI_API_KEY") is not None, "Set ENV OPENAI_API_KEY"
|
235 |
from langchain.embeddings import OpenAIEmbeddings
|
236 |
-
embedding = OpenAIEmbeddings()
|
237 |
else:
|
238 |
# to ensure can fork without deadlock
|
239 |
from langchain.embeddings import HuggingFaceEmbeddings
|
@@ -260,8 +266,315 @@ def get_answer_from_sources(chain, sources, question):
|
|
260 |
)["output_text"]
|
261 |
|
262 |
|
263 |
-
|
264 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
265 |
do_sample=False,
|
266 |
temperature=0.1,
|
267 |
top_k=40,
|
@@ -276,47 +589,142 @@ def get_llm(use_openai_model=False, model_name=None, model=None,
|
|
276 |
prompt_type=None,
|
277 |
prompt_dict=None,
|
278 |
prompter=None,
|
|
|
279 |
verbose=False,
|
280 |
):
|
281 |
-
if use_openai_model:
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
287 |
elif model_name in non_hf_types:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
288 |
from gpt4all_llm import get_llm_gpt4all
|
289 |
llm = get_llm_gpt4all(model_name, model=model, max_new_tokens=max_new_tokens,
|
290 |
temperature=temperature,
|
291 |
repetition_penalty=repetition_penalty,
|
292 |
top_k=top_k,
|
293 |
top_p=top_p,
|
|
|
294 |
verbose=verbose,
|
|
|
|
|
295 |
)
|
296 |
-
streamer = None
|
297 |
-
prompt_type = 'plain'
|
298 |
else:
|
299 |
-
from transformers import AutoTokenizer, AutoModelForCausalLM
|
300 |
-
|
301 |
if model is None:
|
302 |
# only used if didn't pass model in
|
303 |
assert tokenizer is None
|
304 |
prompt_type = 'human_bot'
|
305 |
-
model_name
|
306 |
-
|
307 |
-
|
308 |
-
|
309 |
-
|
310 |
-
|
311 |
-
|
312 |
-
load_8bit = True
|
313 |
-
# FIXME: for now not to spread across hetero GPUs
|
314 |
-
# device_map={"": 0} if load_8bit and device == 'cuda' else "auto"
|
315 |
-
device_map = {"": 0} if device == 'cuda' else "auto"
|
316 |
-
model = AutoModelForCausalLM.from_pretrained(model_name,
|
317 |
-
device_map=device_map,
|
318 |
-
torch_dtype=torch_dtype,
|
319 |
-
load_in_8bit=load_8bit)
|
320 |
|
321 |
max_max_tokens = tokenizer.model_max_length
|
322 |
gen_kwargs = dict(do_sample=do_sample,
|
@@ -331,7 +739,7 @@ def get_llm(use_openai_model=False, model_name=None, model=None,
|
|
331 |
repetition_penalty=repetition_penalty,
|
332 |
num_return_sequences=num_return_sequences,
|
333 |
return_full_text=True,
|
334 |
-
handle_long_generation=
|
335 |
assert len(set(gen_hyper).difference(gen_kwargs.keys())) == 0
|
336 |
|
337 |
if stream_output:
|
@@ -348,10 +756,11 @@ def get_llm(use_openai_model=False, model_name=None, model=None,
|
|
348 |
prompter=prompter,
|
349 |
prompt_type=prompt_type,
|
350 |
prompt_dict=prompt_dict,
|
351 |
-
sanitize_bot_response=
|
352 |
chat=False, stream_output=stream_output,
|
353 |
tokenizer=tokenizer,
|
354 |
-
|
|
|
355 |
**gen_kwargs)
|
356 |
# pipe.task = "text-generation"
|
357 |
# below makes it listen only to our prompt removal,
|
@@ -396,7 +805,7 @@ def get_wiki_data(title, first_paragraph_only, text_limit=None, take_head=True):
|
|
396 |
data = json.load(open(filename, "rt"))
|
397 |
page_content = list(data["query"]["pages"].values())[0]["extract"]
|
398 |
if take_head is not None and text_limit is not None:
|
399 |
-
page_content = page_content[:text_limit] if take_head else page_content[
|
400 |
title_url = str(title).replace(' ', '_')
|
401 |
return Document(
|
402 |
page_content=page_content,
|
@@ -518,6 +927,21 @@ try:
|
|
518 |
except (pkg_resources.DistributionNotFound, AssertionError):
|
519 |
have_pymupdf = False
|
520 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
521 |
image_types = ["png", "jpg", "jpeg"]
|
522 |
non_image_types = ["pdf", "txt", "csv", "toml", "py", "rst", "rtf",
|
523 |
"md", "html",
|
@@ -535,7 +959,7 @@ file_types = non_image_types + image_types
|
|
535 |
def add_meta(docs1, file):
|
536 |
file_extension = pathlib.Path(file).suffix
|
537 |
hashid = hash_file(file)
|
538 |
-
if not isinstance(docs1, list):
|
539 |
docs1 = [docs1]
|
540 |
[x.metadata.update(dict(input_type=file_extension, date=str(datetime.now), hashid=hashid)) for x in docs1]
|
541 |
|
@@ -577,8 +1001,24 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
|
|
577 |
else:
|
578 |
docs1 = []
|
579 |
else:
|
|
|
|
|
580 |
docs1 = UnstructuredURLLoader(urls=[file]).load()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
581 |
[x.metadata.update(dict(input_type='url', date=str(datetime.now))) for x in docs1]
|
|
|
582 |
doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
|
583 |
elif is_txt:
|
584 |
base_path = "user_paste"
|
@@ -588,10 +1028,12 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
|
|
588 |
f.write(file)
|
589 |
metadata = dict(source=source_file, date=str(datetime.now()), input_type='pasted txt')
|
590 |
doc1 = Document(page_content=file, metadata=metadata)
|
|
|
591 |
elif file.lower().endswith('.html') or file.lower().endswith('.mhtml'):
|
592 |
docs1 = UnstructuredHTMLLoader(file_path=file).load()
|
593 |
add_meta(docs1, file)
|
594 |
-
|
|
|
595 |
elif (file.lower().endswith('.docx') or file.lower().endswith('.doc')) and have_libreoffice:
|
596 |
docs1 = UnstructuredWordDocumentLoader(file_path=file).load()
|
597 |
add_meta(docs1, file)
|
@@ -603,12 +1045,14 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
|
|
603 |
elif file.lower().endswith('pptx') or file.lower().endswith('ppt'):
|
604 |
docs1 = UnstructuredPowerPointLoader(file_path=file).load()
|
605 |
add_meta(docs1, file)
|
|
|
606 |
doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
|
607 |
elif file.lower().endswith('.txt'):
|
608 |
# use UnstructuredFileLoader ?
|
609 |
docs1 = TextLoader(file, encoding="utf8", autodetect_encoding=True).load()
|
610 |
# makes just one, but big one
|
611 |
doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
|
|
|
612 |
add_meta(doc1, file)
|
613 |
elif file.lower().endswith('.rtf'):
|
614 |
docs1 = UnstructuredRTFLoader(file).load()
|
@@ -617,7 +1061,8 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
|
|
617 |
elif file.lower().endswith('.md'):
|
618 |
docs1 = UnstructuredMarkdownLoader(file).load()
|
619 |
add_meta(docs1, file)
|
620 |
-
|
|
|
621 |
elif file.lower().endswith('.enex'):
|
622 |
docs1 = EverNoteLoader(file).load()
|
623 |
add_meta(doc1, file)
|
@@ -682,6 +1127,7 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
|
|
682 |
with open(file, "r") as f:
|
683 |
doc1 = Document(page_content=f.read(), metadata={"source": file})
|
684 |
add_meta(doc1, file)
|
|
|
685 |
elif file.lower().endswith('.pdf'):
|
686 |
env_gpt4all_file = ".env_gpt4all"
|
687 |
from dotenv import dotenv_values
|
@@ -692,11 +1138,17 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
|
|
692 |
from langchain.document_loaders import PyMuPDFLoader
|
693 |
# load() still chunks by pages, but every page has title at start to help
|
694 |
doc1 = PyMuPDFLoader(file).load()
|
|
|
|
|
|
|
|
|
695 |
else:
|
696 |
# open-source fallback
|
697 |
# load() still chunks by pages, but every page has title at start to help
|
698 |
doc1 = PyPDFLoader(file).load()
|
|
|
699 |
# Some PDFs return nothing or junk from PDFMinerLoader
|
|
|
700 |
add_meta(doc1, file)
|
701 |
elif file.lower().endswith('.csv'):
|
702 |
doc1 = CSVLoader(file).load()
|
@@ -704,6 +1156,7 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
|
|
704 |
elif file.lower().endswith('.py'):
|
705 |
doc1 = PythonLoader(file).load()
|
706 |
add_meta(doc1, file)
|
|
|
707 |
elif file.lower().endswith('.toml'):
|
708 |
doc1 = TomlLoader(file).load()
|
709 |
add_meta(doc1, file)
|
@@ -794,15 +1247,16 @@ def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=
|
|
794 |
existing_files=[],
|
795 |
existing_hash_ids={},
|
796 |
):
|
|
|
797 |
globs_image_types = []
|
798 |
globs_non_image_types = []
|
799 |
if not path_or_paths and not url and not text:
|
800 |
return []
|
801 |
elif url:
|
802 |
-
globs_non_image_types = [url]
|
803 |
elif text:
|
804 |
-
globs_non_image_types = [text]
|
805 |
-
elif isinstance(path_or_paths, str):
|
806 |
# single path, only consume allowed files
|
807 |
path = path_or_paths
|
808 |
# Below globs should match patterns in file_to_doc()
|
@@ -811,8 +1265,11 @@ def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=
|
|
811 |
[globs_non_image_types.extend(glob.glob(os.path.join(path, "./**/*.%s" % ftype), recursive=True))
|
812 |
for ftype in non_image_types]
|
813 |
else:
|
|
|
|
|
814 |
# list/tuple of files (consume what can, and exception those that selected but cannot consume so user knows)
|
815 |
-
assert isinstance(path_or_paths, (list, tuple)), "Wrong type for path_or_paths: %s" % type(
|
|
|
816 |
# reform out of allowed types
|
817 |
globs_image_types.extend(flatten_list([[x for x in path_or_paths if x.endswith(y)] for y in image_types]))
|
818 |
# could do below:
|
@@ -972,7 +1429,7 @@ def check_update_chroma_embedding(db, use_openai_embedding, hf_embedding_model,
|
|
972 |
if load_embed(db) != (use_openai_embedding, hf_embedding_model):
|
973 |
print("Detected new embedding, updating db: %s" % langchain_mode, flush=True)
|
974 |
# handle embedding changes
|
975 |
-
db_get = db
|
976 |
sources = [Document(page_content=result[0], metadata=result[1] or {})
|
977 |
for result in zip(db_get['documents'], db_get['metadatas'])]
|
978 |
# delete index, has to be redone
|
@@ -1023,14 +1480,17 @@ def get_existing_db(db, persist_directory, load_db_if_exists, db_type, use_opena
|
|
1023 |
if changed_db:
|
1024 |
db = db_trial
|
1025 |
# only call persist if really changed db, else takes too long for large db
|
1026 |
-
db
|
1027 |
-
|
|
|
1028 |
save_embed(db, use_openai_embedding, hf_embedding_model)
|
1029 |
return db
|
1030 |
return None
|
1031 |
|
1032 |
|
1033 |
def clear_embedding(db):
|
|
|
|
|
1034 |
# don't keep on GPU, wastes memory, push back onto CPU and only put back on GPU once again embed
|
1035 |
db._embedding_function.client.cpu()
|
1036 |
clear_torch_cache()
|
@@ -1052,9 +1512,10 @@ def make_db(**langchain_kwargs):
|
|
1052 |
|
1053 |
|
1054 |
def save_embed(db, use_openai_embedding, hf_embedding_model):
|
1055 |
-
|
1056 |
-
|
1057 |
-
|
|
|
1058 |
return use_openai_embedding, hf_embedding_model
|
1059 |
|
1060 |
|
@@ -1201,28 +1662,90 @@ def _make_db(use_openai_embedding=False,
|
|
1201 |
return db, len(new_sources_metadata), new_sources_metadata
|
1202 |
|
1203 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1204 |
def get_existing_files(db):
|
1205 |
-
|
1206 |
-
metadata_sources = set([x['source'] for x in
|
1207 |
return metadata_sources
|
1208 |
|
1209 |
|
1210 |
def get_existing_hash_ids(db):
|
1211 |
-
|
1212 |
# assume consistency, that any prior hashed source was single hashed file at the time among all source chunks
|
1213 |
-
metadata_hash_ids = {x['source']: x.get('hashid') for x in
|
1214 |
return metadata_hash_ids
|
1215 |
|
1216 |
|
1217 |
-
source_prefix = "Sources [Score | Link]:"
|
1218 |
-
source_postfix = "End Sources<p>"
|
1219 |
-
|
1220 |
-
|
1221 |
def run_qa_db(**kwargs):
|
1222 |
func_names = list(inspect.signature(_run_qa_db).parameters)
|
1223 |
# hard-coded defaults
|
1224 |
kwargs['answer_with_sources'] = True
|
1225 |
-
kwargs['sanitize_bot_response'] = True
|
1226 |
kwargs['show_rank'] = False
|
1227 |
missing_kwargs = [x for x in func_names if x not in kwargs]
|
1228 |
assert not missing_kwargs, "Missing kwargs: %s" % missing_kwargs
|
@@ -1240,7 +1763,7 @@ def _run_qa_db(query=None,
|
|
1240 |
user_path=None,
|
1241 |
detect_user_path_changes_every_query=False,
|
1242 |
db_type='faiss',
|
1243 |
-
model_name=None, model=None, tokenizer=None,
|
1244 |
hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
|
1245 |
stream_output=False,
|
1246 |
prompter=None,
|
@@ -1248,7 +1771,7 @@ def _run_qa_db(query=None,
|
|
1248 |
prompt_dict=None,
|
1249 |
answer_with_sources=True,
|
1250 |
cut_distanct=1.1,
|
1251 |
-
sanitize_bot_response=
|
1252 |
show_rank=False,
|
1253 |
load_db_if_exists=False,
|
1254 |
db=None,
|
@@ -1267,7 +1790,12 @@ def _run_qa_db(query=None,
|
|
1267 |
document_choice=[DocumentChoices.All_Relevant.name],
|
1268 |
n_jobs=-1,
|
1269 |
verbose=False,
|
1270 |
-
cli=False
|
|
|
|
|
|
|
|
|
|
|
1271 |
"""
|
1272 |
|
1273 |
:param query:
|
@@ -1286,6 +1814,8 @@ def _run_qa_db(query=None,
|
|
1286 |
:param answer_with_sources
|
1287 |
:return:
|
1288 |
"""
|
|
|
|
|
1289 |
assert query is not None
|
1290 |
assert prompter is not None or prompt_type is not None or model is None # if model is None, then will generate
|
1291 |
if prompter is not None:
|
@@ -1299,7 +1829,9 @@ def _run_qa_db(query=None,
|
|
1299 |
prompt_dict = ''
|
1300 |
assert len(set(gen_hyper).difference(inspect.signature(get_llm).parameters)) == 0
|
1301 |
llm, model_name, streamer, prompt_type_out = get_llm(use_openai_model=use_openai_model, model_name=model_name,
|
1302 |
-
model=model,
|
|
|
|
|
1303 |
stream_output=stream_output,
|
1304 |
do_sample=do_sample,
|
1305 |
temperature=temperature,
|
@@ -1315,13 +1847,10 @@ def _run_qa_db(query=None,
|
|
1315 |
prompt_type=prompt_type,
|
1316 |
prompt_dict=prompt_dict,
|
1317 |
prompter=prompter,
|
|
|
1318 |
verbose=verbose,
|
1319 |
)
|
1320 |
|
1321 |
-
if model_name in non_hf_types:
|
1322 |
-
# FIXME: for now, streams to stdout/stderr currently
|
1323 |
-
stream_output = False
|
1324 |
-
|
1325 |
use_context = False
|
1326 |
scores = []
|
1327 |
chain = None
|
@@ -1349,43 +1878,49 @@ def _run_qa_db(query=None,
|
|
1349 |
# can only return if HF type
|
1350 |
return
|
1351 |
|
1352 |
-
|
1353 |
-
|
1354 |
-
|
1355 |
-
|
1356 |
-
|
1357 |
-
|
1358 |
-
|
1359 |
-
|
1360 |
-
|
1361 |
-
|
1362 |
-
|
1363 |
-
|
1364 |
-
|
1365 |
-
|
1366 |
-
|
1367 |
-
|
1368 |
-
|
1369 |
-
|
1370 |
-
|
1371 |
-
|
1372 |
-
|
1373 |
-
|
1374 |
-
|
1375 |
-
|
1376 |
-
|
1377 |
-
|
1378 |
-
|
1379 |
-
|
1380 |
-
|
1381 |
-
|
1382 |
-
|
1383 |
-
|
1384 |
-
|
1385 |
-
|
1386 |
-
|
1387 |
-
|
1388 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
1389 |
|
1390 |
if not use_context:
|
1391 |
ret = answer['output_text']
|
@@ -1404,6 +1939,7 @@ def get_similarity_chain(query=None,
|
|
1404 |
detect_user_path_changes_every_query=False,
|
1405 |
db_type='faiss',
|
1406 |
model_name=None,
|
|
|
1407 |
hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
|
1408 |
prompt_type=None,
|
1409 |
prompt_dict=None,
|
@@ -1415,8 +1951,14 @@ def get_similarity_chain(query=None,
|
|
1415 |
n_jobs=-1,
|
1416 |
# beyond run_db_query:
|
1417 |
llm=None,
|
|
|
1418 |
verbose=False,
|
1419 |
cmd=None,
|
|
|
|
|
|
|
|
|
|
|
1420 |
):
|
1421 |
# determine whether use of context out of docs is planned
|
1422 |
if not use_openai_model and prompt_type not in ['plain'] or model_name in non_hf_types:
|
@@ -1431,7 +1973,11 @@ def get_similarity_chain(query=None,
|
|
1431 |
# FIXME: Seems to way to get size of chroma db to limit top_k_docs to avoid
|
1432 |
# Chroma collection MyData contains fewer than 4 elements.
|
1433 |
# type logger error
|
1434 |
-
|
|
|
|
|
|
|
|
|
1435 |
|
1436 |
# FIXME: For All just go over all dbs instead of a separate db for All
|
1437 |
if not detect_user_path_changes_every_query and db is not None:
|
@@ -1452,6 +1998,29 @@ def get_similarity_chain(query=None,
|
|
1452 |
n_jobs=n_jobs,
|
1453 |
verbose=verbose)
|
1454 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1455 |
if db and use_context:
|
1456 |
if not isinstance(db, Chroma):
|
1457 |
# only chroma supports filtering
|
@@ -1472,17 +2041,84 @@ def get_similarity_chain(query=None,
|
|
1472 |
docs = []
|
1473 |
scores = []
|
1474 |
elif cmd == DocumentChoices.Only_All_Sources.name:
|
1475 |
-
|
1476 |
-
db_get = db._collection.get(where=filter_kwargs.get('filter'))
|
1477 |
-
else:
|
1478 |
-
db_get = db.get()
|
1479 |
# similar to langchain's chroma's _results_to_docs_and_scores
|
1480 |
docs_with_score = [(Document(page_content=result[0], metadata=result[1] or {}), 0)
|
1481 |
-
for result in zip(
|
1482 |
docs = [x[0] for x in docs_with_score]
|
1483 |
scores = [x[1] for x in docs_with_score]
|
1484 |
else:
|
1485 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1486 |
# cut off so no high distance docs/sources considered
|
1487 |
docs = [x[0] for x in docs_with_score if x[1] < cut_distanct]
|
1488 |
scores = [x[1] for x in docs_with_score if x[1] < cut_distanct]
|
@@ -1517,19 +2153,11 @@ def get_similarity_chain(query=None,
|
|
1517 |
if len(docs) == 0:
|
1518 |
# avoid context == in prompt then
|
1519 |
use_context = False
|
|
|
1520 |
|
1521 |
-
if
|
1522 |
# instruct-like, rather than few-shot prompt_type='plain' as default
|
1523 |
# but then sources confuse the model with how inserted among rest of text, so avoid
|
1524 |
-
prefix = ""
|
1525 |
-
if langchain_mode in ['Disabled', 'ChatLLM', 'LLM'] or not use_context:
|
1526 |
-
template = """%s{context}{question}""" % prefix
|
1527 |
-
else:
|
1528 |
-
template = """%s
|
1529 |
-
==
|
1530 |
-
{context}
|
1531 |
-
==
|
1532 |
-
{question}""" % prefix
|
1533 |
prompt = PromptTemplate(
|
1534 |
# input_variables=["summaries", "question"],
|
1535 |
input_variables=["context", "question"],
|
@@ -1589,17 +2217,32 @@ def get_sources_answer(query, answer, scores, show_rank, answer_with_sources, ve
|
|
1589 |
return ret, extra
|
1590 |
|
1591 |
|
1592 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1593 |
if not chunk:
|
1594 |
return sources
|
1595 |
-
|
1596 |
-
|
1597 |
-
|
1598 |
-
|
1599 |
-
|
1600 |
-
#
|
1601 |
-
|
1602 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1603 |
return source_chunks
|
1604 |
|
1605 |
|
@@ -1647,15 +2290,17 @@ def _create_local_weaviate_client():
|
|
1647 |
WEAVIATE_SCOPE = os.getenv('WEAVIATE_SCOPE', "offline_access")
|
1648 |
|
1649 |
resource_owner_config = None
|
1650 |
-
if WEAVIATE_USERNAME is not None and WEAVIATE_PASSWORD is not None:
|
1651 |
-
resource_owner_config = weaviate.AuthClientPassword(
|
1652 |
-
username=WEAVIATE_USERNAME,
|
1653 |
-
password=WEAVIATE_PASSWORD,
|
1654 |
-
scope=WEAVIATE_SCOPE
|
1655 |
-
)
|
1656 |
-
|
1657 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1658 |
client = weaviate.Client(WEAVIATE_URL, auth_client_secret=resource_owner_config)
|
|
|
1659 |
except Exception as e:
|
1660 |
print(f"Failed to create Weaviate client: {e}")
|
1661 |
return None
|
|
|
1 |
+
import ast
|
2 |
import glob
|
3 |
import inspect
|
4 |
import os
|
5 |
import pathlib
|
6 |
import pickle
|
|
|
|
|
7 |
import shutil
|
8 |
import subprocess
|
|
|
9 |
import tempfile
|
10 |
+
import time
|
11 |
import traceback
|
12 |
+
import types
|
13 |
import uuid
|
14 |
import zipfile
|
15 |
from collections import defaultdict
|
16 |
from datetime import datetime
|
17 |
from functools import reduce
|
18 |
from operator import concat
|
19 |
+
import filelock
|
20 |
|
21 |
+
from joblib import delayed
|
22 |
+
from langchain.callbacks import streaming_stdout
|
23 |
from langchain.embeddings import HuggingFaceInstructEmbeddings
|
24 |
from tqdm import tqdm
|
25 |
|
26 |
+
from enums import DocumentChoices, no_lora_str, model_token_mapping, source_prefix, source_postfix
|
27 |
+
from generate import gen_hyper, get_model, SEED
|
28 |
+
from prompter import non_hf_types, PromptType, Prompter
|
29 |
from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
|
30 |
+
get_device, ProgressParallel, remove, hash_file, clear_torch_cache, NullContext, get_hf_server, FakeTokenizer
|
31 |
+
from utils_langchain import StreamingGradioCallbackHandler
|
32 |
|
33 |
import_matplotlib()
|
34 |
|
|
|
43 |
from langchain.document_loaders import PyPDFLoader, TextLoader, CSVLoader, PythonLoader, TomlLoader, \
|
44 |
UnstructuredURLLoader, UnstructuredHTMLLoader, UnstructuredWordDocumentLoader, UnstructuredMarkdownLoader, \
|
45 |
EverNoteLoader, UnstructuredEmailLoader, UnstructuredODTLoader, UnstructuredPowerPointLoader, \
|
46 |
+
UnstructuredEPubLoader, UnstructuredImageLoader, UnstructuredRTFLoader, ArxivLoader, UnstructuredPDFLoader
|
47 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter, Language
|
48 |
from langchain.chains.question_answering import load_qa_chain
|
49 |
from langchain.docstore.document import Document
|
50 |
+
from langchain import PromptTemplate, HuggingFaceTextGenInference
|
51 |
from langchain.vectorstores import Chroma
|
52 |
|
53 |
|
|
|
145 |
return db, num_new_sources, []
|
146 |
db.add_documents(documents=sources)
|
147 |
elif db_type == 'chroma':
|
148 |
+
collection = get_documents(db)
|
149 |
# files we already have:
|
150 |
metadata_files = set([x['source'] for x in collection['metadatas']])
|
151 |
if avoid_dup_by_file:
|
|
|
160 |
[x['hashid'] for x in collection['metadatas'] if 'hashid' in x and x['hashid'] not in ["None", None]])
|
161 |
# avoid sources with same hash
|
162 |
sources = [x for x in sources if x.metadata.get('hashid') not in metadata_hash_ids]
|
163 |
+
num_nohash = len([x for x in sources if not x.metadata.get('hashid')])
|
164 |
+
print("Found %s new sources (%d have no hash in original source,"
|
165 |
+
" so have to reprocess for migration to sources with hash)" % (len(sources), num_nohash), flush=True)
|
166 |
# get new file names that match existing file names. delete existing files we are overridding
|
167 |
dup_metadata_files = set([x.metadata['source'] for x in sources if x.metadata['source'] in metadata_files])
|
168 |
print("Removing %s duplicate files from db because ingesting those as new documents" % len(
|
|
|
239 |
if use_openai_embedding:
|
240 |
assert os.getenv("OPENAI_API_KEY") is not None, "Set ENV OPENAI_API_KEY"
|
241 |
from langchain.embeddings import OpenAIEmbeddings
|
242 |
+
embedding = OpenAIEmbeddings(disallowed_special=())
|
243 |
else:
|
244 |
# to ensure can fork without deadlock
|
245 |
from langchain.embeddings import HuggingFaceEmbeddings
|
|
|
266 |
)["output_text"]
|
267 |
|
268 |
|
269 |
+
"""Wrapper around Huggingface text generation inference API."""
|
270 |
+
from functools import partial
|
271 |
+
from typing import Any, Dict, List, Optional, Set
|
272 |
+
|
273 |
+
from pydantic import Extra, Field, root_validator
|
274 |
+
|
275 |
+
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
276 |
+
|
277 |
+
"""Wrapper around Huggingface text generation inference API."""
|
278 |
+
from functools import partial
|
279 |
+
from typing import Any, Dict, List, Optional
|
280 |
+
|
281 |
+
from pydantic import Extra, Field, root_validator
|
282 |
+
|
283 |
+
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
284 |
+
from langchain.llms.base import LLM
|
285 |
+
|
286 |
+
|
287 |
+
class GradioInference(LLM):
|
288 |
+
"""
|
289 |
+
Gradio generation inference API.
|
290 |
+
"""
|
291 |
+
inference_server_url: str = ""
|
292 |
+
|
293 |
+
temperature: float = 0.8
|
294 |
+
top_p: Optional[float] = 0.95
|
295 |
+
top_k: Optional[int] = None
|
296 |
+
num_beams: Optional[int] = 1
|
297 |
+
max_new_tokens: int = 512
|
298 |
+
min_new_tokens: int = 1
|
299 |
+
early_stopping: bool = False
|
300 |
+
max_time: int = 180
|
301 |
+
repetition_penalty: Optional[float] = None
|
302 |
+
num_return_sequences: Optional[int] = 1
|
303 |
+
do_sample: bool = False
|
304 |
+
chat_client: bool = False
|
305 |
+
|
306 |
+
return_full_text: bool = True
|
307 |
+
stream: bool = False
|
308 |
+
sanitize_bot_response: bool = False
|
309 |
+
|
310 |
+
prompter: Any = None
|
311 |
+
client: Any = None
|
312 |
+
|
313 |
+
class Config:
|
314 |
+
"""Configuration for this pydantic object."""
|
315 |
+
|
316 |
+
extra = Extra.forbid
|
317 |
+
|
318 |
+
@root_validator()
|
319 |
+
def validate_environment(cls, values: Dict) -> Dict:
|
320 |
+
"""Validate that python package exists in environment."""
|
321 |
+
|
322 |
+
try:
|
323 |
+
if values['client'] is None:
|
324 |
+
import gradio_client
|
325 |
+
values["client"] = gradio_client.Client(
|
326 |
+
values["inference_server_url"]
|
327 |
+
)
|
328 |
+
except ImportError:
|
329 |
+
raise ImportError(
|
330 |
+
"Could not import gradio_client python package. "
|
331 |
+
"Please install it with `pip install gradio_client`."
|
332 |
+
)
|
333 |
+
return values
|
334 |
+
|
335 |
+
@property
|
336 |
+
def _llm_type(self) -> str:
|
337 |
+
"""Return type of llm."""
|
338 |
+
return "gradio_inference"
|
339 |
+
|
340 |
+
def _call(
|
341 |
+
self,
|
342 |
+
prompt: str,
|
343 |
+
stop: Optional[List[str]] = None,
|
344 |
+
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
345 |
+
**kwargs: Any,
|
346 |
+
) -> str:
|
347 |
+
# NOTE: prompt here has no prompt_type (e.g. human: bot:) prompt injection,
|
348 |
+
# so server should get prompt_type or '', not plain
|
349 |
+
# This is good, so gradio server can also handle stopping.py conditions
|
350 |
+
# this is different than TGI server that uses prompter to inject prompt_type prompting
|
351 |
+
stream_output = self.stream
|
352 |
+
gr_client = self.client
|
353 |
+
client_langchain_mode = 'Disabled'
|
354 |
+
top_k_docs = 1
|
355 |
+
chunk = True
|
356 |
+
chunk_size = 512
|
357 |
+
client_kwargs = dict(instruction=prompt if self.chat_client else '', # only for chat=True
|
358 |
+
iinput='', # only for chat=True
|
359 |
+
context='',
|
360 |
+
# streaming output is supported, loops over and outputs each generation in streaming mode
|
361 |
+
# but leave stream_output=False for simple input/output mode
|
362 |
+
stream_output=stream_output,
|
363 |
+
prompt_type=self.prompter.prompt_type,
|
364 |
+
prompt_dict='',
|
365 |
+
|
366 |
+
temperature=self.temperature,
|
367 |
+
top_p=self.top_p,
|
368 |
+
top_k=self.top_k,
|
369 |
+
num_beams=self.num_beams,
|
370 |
+
max_new_tokens=self.max_new_tokens,
|
371 |
+
min_new_tokens=self.min_new_tokens,
|
372 |
+
early_stopping=self.early_stopping,
|
373 |
+
max_time=self.max_time,
|
374 |
+
repetition_penalty=self.repetition_penalty,
|
375 |
+
num_return_sequences=self.num_return_sequences,
|
376 |
+
do_sample=self.do_sample,
|
377 |
+
chat=self.chat_client,
|
378 |
+
|
379 |
+
instruction_nochat=prompt if not self.chat_client else '',
|
380 |
+
iinput_nochat='', # only for chat=False
|
381 |
+
langchain_mode=client_langchain_mode,
|
382 |
+
top_k_docs=top_k_docs,
|
383 |
+
chunk=chunk,
|
384 |
+
chunk_size=chunk_size,
|
385 |
+
document_choice=[DocumentChoices.All_Relevant.name],
|
386 |
+
)
|
387 |
+
api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing
|
388 |
+
if not stream_output:
|
389 |
+
res = gr_client.predict(str(dict(client_kwargs)), api_name=api_name)
|
390 |
+
res_dict = ast.literal_eval(res)
|
391 |
+
text = res_dict['response']
|
392 |
+
return self.prompter.get_response(prompt + text, prompt=prompt,
|
393 |
+
sanitize_bot_response=self.sanitize_bot_response)
|
394 |
+
else:
|
395 |
+
text_callback = None
|
396 |
+
if run_manager:
|
397 |
+
text_callback = partial(
|
398 |
+
run_manager.on_llm_new_token, verbose=self.verbose
|
399 |
+
)
|
400 |
+
|
401 |
+
job = gr_client.submit(str(dict(client_kwargs)), api_name=api_name)
|
402 |
+
text0 = ''
|
403 |
+
while not job.done():
|
404 |
+
outputs_list = job.communicator.job.outputs
|
405 |
+
if outputs_list:
|
406 |
+
res = job.communicator.job.outputs[-1]
|
407 |
+
res_dict = ast.literal_eval(res)
|
408 |
+
text = res_dict['response']
|
409 |
+
text = self.prompter.get_response(prompt + text, prompt=prompt,
|
410 |
+
sanitize_bot_response=self.sanitize_bot_response)
|
411 |
+
# FIXME: derive chunk from full for now
|
412 |
+
text_chunk = text[len(text0):]
|
413 |
+
# save old
|
414 |
+
text0 = text
|
415 |
+
|
416 |
+
if text_callback:
|
417 |
+
text_callback(text_chunk)
|
418 |
+
|
419 |
+
time.sleep(0.01)
|
420 |
+
|
421 |
+
# ensure get last output to avoid race
|
422 |
+
res_all = job.outputs()
|
423 |
+
if len(res_all) > 0:
|
424 |
+
res = res_all[-1]
|
425 |
+
res_dict = ast.literal_eval(res)
|
426 |
+
text = res_dict['response']
|
427 |
+
# FIXME: derive chunk from full for now
|
428 |
+
else:
|
429 |
+
# go with old if failure
|
430 |
+
text = text0
|
431 |
+
text_chunk = text[len(text0):]
|
432 |
+
if text_callback:
|
433 |
+
text_callback(text_chunk)
|
434 |
+
return self.prompter.get_response(prompt + text, prompt=prompt,
|
435 |
+
sanitize_bot_response=self.sanitize_bot_response)
|
436 |
+
|
437 |
+
|
438 |
+
class H2OHuggingFaceTextGenInference(HuggingFaceTextGenInference):
|
439 |
+
max_new_tokens: int = 512
|
440 |
+
do_sample: bool = False
|
441 |
+
top_k: Optional[int] = None
|
442 |
+
top_p: Optional[float] = 0.95
|
443 |
+
typical_p: Optional[float] = 0.95
|
444 |
+
temperature: float = 0.8
|
445 |
+
repetition_penalty: Optional[float] = None
|
446 |
+
return_full_text: bool = False
|
447 |
+
stop_sequences: List[str] = Field(default_factory=list)
|
448 |
+
seed: Optional[int] = None
|
449 |
+
inference_server_url: str = ""
|
450 |
+
timeout: int = 300
|
451 |
+
headers: dict = None
|
452 |
+
stream: bool = False
|
453 |
+
sanitize_bot_response: bool = False
|
454 |
+
prompter: Any = None
|
455 |
+
tokenizer: Any = None
|
456 |
+
client: Any = None
|
457 |
+
|
458 |
+
@root_validator()
|
459 |
+
def validate_environment(cls, values: Dict) -> Dict:
|
460 |
+
"""Validate that python package exists in environment."""
|
461 |
+
|
462 |
+
try:
|
463 |
+
if values['client'] is None:
|
464 |
+
import text_generation
|
465 |
+
|
466 |
+
values["client"] = text_generation.Client(
|
467 |
+
values["inference_server_url"],
|
468 |
+
timeout=values["timeout"],
|
469 |
+
headers=values["headers"],
|
470 |
+
)
|
471 |
+
except ImportError:
|
472 |
+
raise ImportError(
|
473 |
+
"Could not import text_generation python package. "
|
474 |
+
"Please install it with `pip install text_generation`."
|
475 |
+
)
|
476 |
+
return values
|
477 |
+
|
478 |
+
def _call(
|
479 |
+
self,
|
480 |
+
prompt: str,
|
481 |
+
stop: Optional[List[str]] = None,
|
482 |
+
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
483 |
+
**kwargs: Any,
|
484 |
+
) -> str:
|
485 |
+
if stop is None:
|
486 |
+
stop = self.stop_sequences
|
487 |
+
else:
|
488 |
+
stop += self.stop_sequences
|
489 |
+
|
490 |
+
# HF inference server needs control over input tokens
|
491 |
+
assert self.tokenizer is not None
|
492 |
+
from h2oai_pipeline import H2OTextGenerationPipeline
|
493 |
+
prompt, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt, self.tokenizer)
|
494 |
+
|
495 |
+
# NOTE: TGI server does not add prompting, so must do here
|
496 |
+
data_point = dict(context='', instruction=prompt, input='')
|
497 |
+
prompt = self.prompter.generate_prompt(data_point)
|
498 |
+
|
499 |
+
gen_server_kwargs = dict(do_sample=self.do_sample,
|
500 |
+
stop_sequences=stop,
|
501 |
+
max_new_tokens=self.max_new_tokens,
|
502 |
+
top_k=self.top_k,
|
503 |
+
top_p=self.top_p,
|
504 |
+
typical_p=self.typical_p,
|
505 |
+
temperature=self.temperature,
|
506 |
+
repetition_penalty=self.repetition_penalty,
|
507 |
+
return_full_text=self.return_full_text,
|
508 |
+
seed=self.seed,
|
509 |
+
)
|
510 |
+
gen_server_kwargs.update(kwargs)
|
511 |
+
|
512 |
+
# lower bound because client is re-used if multi-threading
|
513 |
+
self.client.timeout = max(300, self.timeout)
|
514 |
+
|
515 |
+
if not self.stream:
|
516 |
+
res = self.client.generate(
|
517 |
+
prompt,
|
518 |
+
**gen_server_kwargs,
|
519 |
+
)
|
520 |
+
if self.return_full_text:
|
521 |
+
gen_text = res.generated_text[len(prompt):]
|
522 |
+
else:
|
523 |
+
gen_text = res.generated_text
|
524 |
+
# remove stop sequences from the end of the generated text
|
525 |
+
for stop_seq in stop:
|
526 |
+
if stop_seq in gen_text:
|
527 |
+
gen_text = gen_text[:gen_text.index(stop_seq)]
|
528 |
+
text = prompt + gen_text
|
529 |
+
text = self.prompter.get_response(text, prompt=prompt,
|
530 |
+
sanitize_bot_response=self.sanitize_bot_response)
|
531 |
+
else:
|
532 |
+
text_callback = None
|
533 |
+
if run_manager:
|
534 |
+
text_callback = partial(
|
535 |
+
run_manager.on_llm_new_token, verbose=self.verbose
|
536 |
+
)
|
537 |
+
# parent handler of streamer expects to see prompt first else output="" and lose if prompt=None in prompter
|
538 |
+
if text_callback:
|
539 |
+
text_callback(prompt)
|
540 |
+
text = ""
|
541 |
+
# Note: Streaming ignores return_full_text=True
|
542 |
+
for response in self.client.generate_stream(prompt, **gen_server_kwargs):
|
543 |
+
text_chunk = response.token.text
|
544 |
+
text += text_chunk
|
545 |
+
text = self.prompter.get_response(prompt + text, prompt=prompt,
|
546 |
+
sanitize_bot_response=self.sanitize_bot_response)
|
547 |
+
# stream part
|
548 |
+
is_stop = False
|
549 |
+
for stop_seq in stop:
|
550 |
+
if stop_seq in response.token.text:
|
551 |
+
is_stop = True
|
552 |
+
break
|
553 |
+
if is_stop:
|
554 |
+
break
|
555 |
+
if not response.token.special:
|
556 |
+
if text_callback:
|
557 |
+
text_callback(response.token.text)
|
558 |
+
return text
|
559 |
+
|
560 |
+
|
561 |
+
from langchain.chat_models import ChatOpenAI
|
562 |
+
|
563 |
+
|
564 |
+
class H2OChatOpenAI(ChatOpenAI):
|
565 |
+
@classmethod
|
566 |
+
def all_required_field_names(cls) -> Set:
|
567 |
+
all_required_field_names = super(ChatOpenAI, cls).all_required_field_names()
|
568 |
+
all_required_field_names.update({'top_p', 'frequency_penalty', 'presence_penalty'})
|
569 |
+
return all_required_field_names
|
570 |
+
|
571 |
+
|
572 |
+
def get_llm(use_openai_model=False,
|
573 |
+
model_name=None,
|
574 |
+
model=None,
|
575 |
+
tokenizer=None,
|
576 |
+
inference_server=None,
|
577 |
+
stream_output=False,
|
578 |
do_sample=False,
|
579 |
temperature=0.1,
|
580 |
top_k=40,
|
|
|
589 |
prompt_type=None,
|
590 |
prompt_dict=None,
|
591 |
prompter=None,
|
592 |
+
sanitize_bot_response=False,
|
593 |
verbose=False,
|
594 |
):
|
595 |
+
if use_openai_model or inference_server in ['openai', 'openai_chat']:
|
596 |
+
if use_openai_model and model_name is None:
|
597 |
+
model_name = "gpt-3.5-turbo"
|
598 |
+
if inference_server == 'openai':
|
599 |
+
from langchain.llms import OpenAI
|
600 |
+
cls = OpenAI
|
601 |
+
else:
|
602 |
+
cls = H2OChatOpenAI
|
603 |
+
callbacks = [StreamingGradioCallbackHandler()]
|
604 |
+
llm = cls(model_name=model_name,
|
605 |
+
temperature=temperature if do_sample else 0,
|
606 |
+
# FIXME: Need to count tokens and reduce max_new_tokens to fit like in generate.py
|
607 |
+
max_tokens=max_new_tokens,
|
608 |
+
top_p=top_p if do_sample else 1,
|
609 |
+
frequency_penalty=0,
|
610 |
+
presence_penalty=1.07 - repetition_penalty + 0.6, # so good default
|
611 |
+
callbacks=callbacks if stream_output else None,
|
612 |
+
)
|
613 |
+
streamer = callbacks[0] if stream_output else None
|
614 |
+
if inference_server in ['openai', 'openai_chat']:
|
615 |
+
prompt_type = inference_server
|
616 |
+
else:
|
617 |
+
prompt_type = prompt_type or 'plain'
|
618 |
+
elif inference_server:
|
619 |
+
assert inference_server.startswith(
|
620 |
+
'http'), "Malformed inference_server=%s. Did you add http:// in front?" % inference_server
|
621 |
+
|
622 |
+
from gradio_utils.grclient import GradioClient
|
623 |
+
from text_generation import Client as HFClient
|
624 |
+
if isinstance(model, GradioClient):
|
625 |
+
gr_client = model
|
626 |
+
hf_client = None
|
627 |
+
else:
|
628 |
+
gr_client = None
|
629 |
+
hf_client = model
|
630 |
+
assert isinstance(hf_client, HFClient)
|
631 |
+
|
632 |
+
inference_server, headers = get_hf_server(inference_server)
|
633 |
+
|
634 |
+
# quick sanity check to avoid long timeouts, just see if can reach server
|
635 |
+
requests.get(inference_server, timeout=int(os.getenv('REQUEST_TIMEOUT_FAST', '10')))
|
636 |
+
|
637 |
+
callbacks = [StreamingGradioCallbackHandler()]
|
638 |
+
assert prompter is not None
|
639 |
+
stop_sequences = list(set(prompter.terminate_response + [prompter.PreResponse]))
|
640 |
+
|
641 |
+
if gr_client:
|
642 |
+
chat_client = False
|
643 |
+
llm = GradioInference(
|
644 |
+
inference_server_url=inference_server,
|
645 |
+
return_full_text=True,
|
646 |
+
|
647 |
+
temperature=temperature,
|
648 |
+
top_p=top_p,
|
649 |
+
top_k=top_k,
|
650 |
+
num_beams=num_beams,
|
651 |
+
max_new_tokens=max_new_tokens,
|
652 |
+
min_new_tokens=min_new_tokens,
|
653 |
+
early_stopping=early_stopping,
|
654 |
+
max_time=max_time,
|
655 |
+
repetition_penalty=repetition_penalty,
|
656 |
+
num_return_sequences=num_return_sequences,
|
657 |
+
do_sample=do_sample,
|
658 |
+
chat_client=chat_client,
|
659 |
+
|
660 |
+
callbacks=callbacks if stream_output else None,
|
661 |
+
stream=stream_output,
|
662 |
+
prompter=prompter,
|
663 |
+
client=gr_client,
|
664 |
+
sanitize_bot_response=sanitize_bot_response,
|
665 |
+
)
|
666 |
+
elif hf_client:
|
667 |
+
llm = H2OHuggingFaceTextGenInference(
|
668 |
+
inference_server_url=inference_server,
|
669 |
+
do_sample=do_sample,
|
670 |
+
max_new_tokens=max_new_tokens,
|
671 |
+
repetition_penalty=repetition_penalty,
|
672 |
+
return_full_text=True,
|
673 |
+
seed=SEED,
|
674 |
+
|
675 |
+
stop_sequences=stop_sequences,
|
676 |
+
temperature=temperature,
|
677 |
+
top_k=top_k,
|
678 |
+
top_p=top_p,
|
679 |
+
# typical_p=top_p,
|
680 |
+
callbacks=callbacks if stream_output else None,
|
681 |
+
stream=stream_output,
|
682 |
+
prompter=prompter,
|
683 |
+
tokenizer=tokenizer,
|
684 |
+
client=hf_client,
|
685 |
+
timeout=max_time,
|
686 |
+
sanitize_bot_response=sanitize_bot_response,
|
687 |
+
)
|
688 |
+
else:
|
689 |
+
raise RuntimeError("No defined client")
|
690 |
+
streamer = callbacks[0] if stream_output else None
|
691 |
elif model_name in non_hf_types:
|
692 |
+
if model_name == 'llama':
|
693 |
+
callbacks = [StreamingGradioCallbackHandler()]
|
694 |
+
streamer = callbacks[0] if stream_output else None
|
695 |
+
else:
|
696 |
+
# stream_output = False
|
697 |
+
# doesn't stream properly as generator, but at least
|
698 |
+
callbacks = [streaming_stdout.StreamingStdOutCallbackHandler()]
|
699 |
+
streamer = None
|
700 |
+
if prompter:
|
701 |
+
prompt_type = prompter.prompt_type
|
702 |
+
else:
|
703 |
+
prompter = Prompter(prompt_type, prompt_dict, debug=False, chat=False, stream_output=stream_output)
|
704 |
+
pass # assume inputted prompt_type is correct
|
705 |
from gpt4all_llm import get_llm_gpt4all
|
706 |
llm = get_llm_gpt4all(model_name, model=model, max_new_tokens=max_new_tokens,
|
707 |
temperature=temperature,
|
708 |
repetition_penalty=repetition_penalty,
|
709 |
top_k=top_k,
|
710 |
top_p=top_p,
|
711 |
+
callbacks=callbacks,
|
712 |
verbose=verbose,
|
713 |
+
streaming=stream_output,
|
714 |
+
prompter=prompter,
|
715 |
)
|
|
|
|
|
716 |
else:
|
|
|
|
|
717 |
if model is None:
|
718 |
# only used if didn't pass model in
|
719 |
assert tokenizer is None
|
720 |
prompt_type = 'human_bot'
|
721 |
+
if model_name is None:
|
722 |
+
model_name = 'h2oai/h2ogpt-oasst1-512-12b'
|
723 |
+
# model_name = 'h2oai/h2ogpt-oig-oasst1-512-6_9b'
|
724 |
+
# model_name = 'h2oai/h2ogpt-oasst1-512-20b'
|
725 |
+
inference_server = ''
|
726 |
+
model, tokenizer, device = get_model(load_8bit=True, base_model=model_name,
|
727 |
+
inference_server=inference_server, gpu_id=0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
728 |
|
729 |
max_max_tokens = tokenizer.model_max_length
|
730 |
gen_kwargs = dict(do_sample=do_sample,
|
|
|
739 |
repetition_penalty=repetition_penalty,
|
740 |
num_return_sequences=num_return_sequences,
|
741 |
return_full_text=True,
|
742 |
+
handle_long_generation=None)
|
743 |
assert len(set(gen_hyper).difference(gen_kwargs.keys())) == 0
|
744 |
|
745 |
if stream_output:
|
|
|
756 |
prompter=prompter,
|
757 |
prompt_type=prompt_type,
|
758 |
prompt_dict=prompt_dict,
|
759 |
+
sanitize_bot_response=sanitize_bot_response,
|
760 |
chat=False, stream_output=stream_output,
|
761 |
tokenizer=tokenizer,
|
762 |
+
# leave some room for 1 paragraph, even if min_new_tokens=0
|
763 |
+
max_input_tokens=max_max_tokens - max(min_new_tokens, 256),
|
764 |
**gen_kwargs)
|
765 |
# pipe.task = "text-generation"
|
766 |
# below makes it listen only to our prompt removal,
|
|
|
805 |
data = json.load(open(filename, "rt"))
|
806 |
page_content = list(data["query"]["pages"].values())[0]["extract"]
|
807 |
if take_head is not None and text_limit is not None:
|
808 |
+
page_content = page_content[:text_limit] if take_head else page_content[-text_limit:]
|
809 |
title_url = str(title).replace(' ', '_')
|
810 |
return Document(
|
811 |
page_content=page_content,
|
|
|
927 |
except (pkg_resources.DistributionNotFound, AssertionError):
|
928 |
have_pymupdf = False
|
929 |
|
930 |
+
try:
|
931 |
+
assert pkg_resources.get_distribution('selenium') is not None
|
932 |
+
have_selenium = True
|
933 |
+
except (pkg_resources.DistributionNotFound, AssertionError):
|
934 |
+
have_selenium = False
|
935 |
+
|
936 |
+
try:
|
937 |
+
assert pkg_resources.get_distribution('playwright') is not None
|
938 |
+
have_playwright = True
|
939 |
+
except (pkg_resources.DistributionNotFound, AssertionError):
|
940 |
+
have_playwright = False
|
941 |
+
|
942 |
+
# disable, hangs too often
|
943 |
+
have_playwright = False
|
944 |
+
|
945 |
image_types = ["png", "jpg", "jpeg"]
|
946 |
non_image_types = ["pdf", "txt", "csv", "toml", "py", "rst", "rtf",
|
947 |
"md", "html",
|
|
|
959 |
def add_meta(docs1, file):
|
960 |
file_extension = pathlib.Path(file).suffix
|
961 |
hashid = hash_file(file)
|
962 |
+
if not isinstance(docs1, (list, tuple, types.GeneratorType)):
|
963 |
docs1 = [docs1]
|
964 |
[x.metadata.update(dict(input_type=file_extension, date=str(datetime.now), hashid=hashid)) for x in docs1]
|
965 |
|
|
|
1001 |
else:
|
1002 |
docs1 = []
|
1003 |
else:
|
1004 |
+
if not (file.startswith("http://") or file.startswith("file://") or file.startswith("https://")):
|
1005 |
+
file = 'http://' + file
|
1006 |
docs1 = UnstructuredURLLoader(urls=[file]).load()
|
1007 |
+
if len(docs1) == 0 and have_playwright:
|
1008 |
+
# then something went wrong, try another loader:
|
1009 |
+
from langchain.document_loaders import PlaywrightURLLoader
|
1010 |
+
docs1 = PlaywrightURLLoader(urls=[file]).load()
|
1011 |
+
if len(docs1) == 0 and have_selenium:
|
1012 |
+
# then something went wrong, try another loader:
|
1013 |
+
# but requires Chrome binary, else get: selenium.common.exceptions.WebDriverException: Message: unknown error: cannot find Chrome binary
|
1014 |
+
from langchain.document_loaders import SeleniumURLLoader
|
1015 |
+
from selenium.common.exceptions import WebDriverException
|
1016 |
+
try:
|
1017 |
+
docs1 = SeleniumURLLoader(urls=[file]).load()
|
1018 |
+
except WebDriverException as e:
|
1019 |
+
print("No web driver: %s" % str(e), flush=True)
|
1020 |
[x.metadata.update(dict(input_type='url', date=str(datetime.now))) for x in docs1]
|
1021 |
+
docs1 = clean_doc(docs1)
|
1022 |
doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
|
1023 |
elif is_txt:
|
1024 |
base_path = "user_paste"
|
|
|
1028 |
f.write(file)
|
1029 |
metadata = dict(source=source_file, date=str(datetime.now()), input_type='pasted txt')
|
1030 |
doc1 = Document(page_content=file, metadata=metadata)
|
1031 |
+
doc1 = clean_doc(doc1)
|
1032 |
elif file.lower().endswith('.html') or file.lower().endswith('.mhtml'):
|
1033 |
docs1 = UnstructuredHTMLLoader(file_path=file).load()
|
1034 |
add_meta(docs1, file)
|
1035 |
+
docs1 = clean_doc(docs1)
|
1036 |
+
doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size, language=Language.HTML)
|
1037 |
elif (file.lower().endswith('.docx') or file.lower().endswith('.doc')) and have_libreoffice:
|
1038 |
docs1 = UnstructuredWordDocumentLoader(file_path=file).load()
|
1039 |
add_meta(docs1, file)
|
|
|
1045 |
elif file.lower().endswith('pptx') or file.lower().endswith('ppt'):
|
1046 |
docs1 = UnstructuredPowerPointLoader(file_path=file).load()
|
1047 |
add_meta(docs1, file)
|
1048 |
+
docs1 = clean_doc(docs1)
|
1049 |
doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
|
1050 |
elif file.lower().endswith('.txt'):
|
1051 |
# use UnstructuredFileLoader ?
|
1052 |
docs1 = TextLoader(file, encoding="utf8", autodetect_encoding=True).load()
|
1053 |
# makes just one, but big one
|
1054 |
doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
|
1055 |
+
doc1 = clean_doc(doc1)
|
1056 |
add_meta(doc1, file)
|
1057 |
elif file.lower().endswith('.rtf'):
|
1058 |
docs1 = UnstructuredRTFLoader(file).load()
|
|
|
1061 |
elif file.lower().endswith('.md'):
|
1062 |
docs1 = UnstructuredMarkdownLoader(file).load()
|
1063 |
add_meta(docs1, file)
|
1064 |
+
docs1 = clean_doc(docs1)
|
1065 |
+
doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size, language=Language.MARKDOWN)
|
1066 |
elif file.lower().endswith('.enex'):
|
1067 |
docs1 = EverNoteLoader(file).load()
|
1068 |
add_meta(doc1, file)
|
|
|
1127 |
with open(file, "r") as f:
|
1128 |
doc1 = Document(page_content=f.read(), metadata={"source": file})
|
1129 |
add_meta(doc1, file)
|
1130 |
+
doc1 = chunk_sources(doc1, chunk=chunk, chunk_size=chunk_size, language=Language.RST)
|
1131 |
elif file.lower().endswith('.pdf'):
|
1132 |
env_gpt4all_file = ".env_gpt4all"
|
1133 |
from dotenv import dotenv_values
|
|
|
1138 |
from langchain.document_loaders import PyMuPDFLoader
|
1139 |
# load() still chunks by pages, but every page has title at start to help
|
1140 |
doc1 = PyMuPDFLoader(file).load()
|
1141 |
+
doc1 = clean_doc(doc1)
|
1142 |
+
elif pdf_class_name == 'UnstructuredPDFLoader':
|
1143 |
+
doc1 = UnstructuredPDFLoader(file).load()
|
1144 |
+
# seems to not need cleaning in most cases
|
1145 |
else:
|
1146 |
# open-source fallback
|
1147 |
# load() still chunks by pages, but every page has title at start to help
|
1148 |
doc1 = PyPDFLoader(file).load()
|
1149 |
+
doc1 = clean_doc(doc1)
|
1150 |
# Some PDFs return nothing or junk from PDFMinerLoader
|
1151 |
+
doc1 = chunk_sources(doc1, chunk=chunk, chunk_size=chunk_size)
|
1152 |
add_meta(doc1, file)
|
1153 |
elif file.lower().endswith('.csv'):
|
1154 |
doc1 = CSVLoader(file).load()
|
|
|
1156 |
elif file.lower().endswith('.py'):
|
1157 |
doc1 = PythonLoader(file).load()
|
1158 |
add_meta(doc1, file)
|
1159 |
+
doc1 = chunk_sources(doc1, chunk=chunk, chunk_size=chunk_size, language=Language.PYTHON)
|
1160 |
elif file.lower().endswith('.toml'):
|
1161 |
doc1 = TomlLoader(file).load()
|
1162 |
add_meta(doc1, file)
|
|
|
1247 |
existing_files=[],
|
1248 |
existing_hash_ids={},
|
1249 |
):
|
1250 |
+
# path_or_paths could be str, list, tuple, generator
|
1251 |
globs_image_types = []
|
1252 |
globs_non_image_types = []
|
1253 |
if not path_or_paths and not url and not text:
|
1254 |
return []
|
1255 |
elif url:
|
1256 |
+
globs_non_image_types = url if isinstance(url, (list, tuple, types.GeneratorType)) else [url]
|
1257 |
elif text:
|
1258 |
+
globs_non_image_types = text if isinstance(text, (list, tuple, types.GeneratorType)) else [text]
|
1259 |
+
elif isinstance(path_or_paths, str) and os.path.isdir(path_or_paths):
|
1260 |
# single path, only consume allowed files
|
1261 |
path = path_or_paths
|
1262 |
# Below globs should match patterns in file_to_doc()
|
|
|
1265 |
[globs_non_image_types.extend(glob.glob(os.path.join(path, "./**/*.%s" % ftype), recursive=True))
|
1266 |
for ftype in non_image_types]
|
1267 |
else:
|
1268 |
+
if isinstance(path_or_paths, str) and (os.path.isfile(path_or_paths) or os.path.isdir(path_or_paths)):
|
1269 |
+
path_or_paths = [path_or_paths]
|
1270 |
# list/tuple of files (consume what can, and exception those that selected but cannot consume so user knows)
|
1271 |
+
assert isinstance(path_or_paths, (list, tuple, types.GeneratorType)), "Wrong type for path_or_paths: %s" % type(
|
1272 |
+
path_or_paths)
|
1273 |
# reform out of allowed types
|
1274 |
globs_image_types.extend(flatten_list([[x for x in path_or_paths if x.endswith(y)] for y in image_types]))
|
1275 |
# could do below:
|
|
|
1429 |
if load_embed(db) != (use_openai_embedding, hf_embedding_model):
|
1430 |
print("Detected new embedding, updating db: %s" % langchain_mode, flush=True)
|
1431 |
# handle embedding changes
|
1432 |
+
db_get = get_documents(db)
|
1433 |
sources = [Document(page_content=result[0], metadata=result[1] or {})
|
1434 |
for result in zip(db_get['documents'], db_get['metadatas'])]
|
1435 |
# delete index, has to be redone
|
|
|
1480 |
if changed_db:
|
1481 |
db = db_trial
|
1482 |
# only call persist if really changed db, else takes too long for large db
|
1483 |
+
if db is not None:
|
1484 |
+
db.persist()
|
1485 |
+
clear_embedding(db)
|
1486 |
save_embed(db, use_openai_embedding, hf_embedding_model)
|
1487 |
return db
|
1488 |
return None
|
1489 |
|
1490 |
|
1491 |
def clear_embedding(db):
|
1492 |
+
if db is None:
|
1493 |
+
return
|
1494 |
# don't keep on GPU, wastes memory, push back onto CPU and only put back on GPU once again embed
|
1495 |
db._embedding_function.client.cpu()
|
1496 |
clear_torch_cache()
|
|
|
1512 |
|
1513 |
|
1514 |
def save_embed(db, use_openai_embedding, hf_embedding_model):
|
1515 |
+
if db is not None:
|
1516 |
+
embed_info_file = os.path.join(db._persist_directory, 'embed_info')
|
1517 |
+
with open(embed_info_file, 'wb') as f:
|
1518 |
+
pickle.dump((use_openai_embedding, hf_embedding_model), f)
|
1519 |
return use_openai_embedding, hf_embedding_model
|
1520 |
|
1521 |
|
|
|
1662 |
return db, len(new_sources_metadata), new_sources_metadata
|
1663 |
|
1664 |
|
1665 |
+
def get_metadatas(db):
|
1666 |
+
from langchain.vectorstores import FAISS
|
1667 |
+
if isinstance(db, FAISS):
|
1668 |
+
metadatas = [v.metadata for k, v in db.docstore._dict.items()]
|
1669 |
+
elif isinstance(db, Chroma):
|
1670 |
+
metadatas = get_documents(db)['metadatas']
|
1671 |
+
else:
|
1672 |
+
# FIXME: Hack due to https://github.com/weaviate/weaviate/issues/1947
|
1673 |
+
# seems no way to get all metadata, so need to avoid this approach for weaviate
|
1674 |
+
metadatas = [x.metadata for x in db.similarity_search("", k=10000)]
|
1675 |
+
return metadatas
|
1676 |
+
|
1677 |
+
|
1678 |
+
def get_documents(db):
|
1679 |
+
if hasattr(db, '_persist_directory'):
|
1680 |
+
name_path = os.path.basename(db._persist_directory)
|
1681 |
+
base_path = 'locks'
|
1682 |
+
makedirs(base_path)
|
1683 |
+
with filelock.FileLock(os.path.join(base_path, "getdb_%s.lock" % name_path)):
|
1684 |
+
# get segfaults and other errors when multiple threads access this
|
1685 |
+
return _get_documents(db)
|
1686 |
+
else:
|
1687 |
+
return _get_documents(db)
|
1688 |
+
|
1689 |
+
|
1690 |
+
def _get_documents(db):
|
1691 |
+
from langchain.vectorstores import FAISS
|
1692 |
+
if isinstance(db, FAISS):
|
1693 |
+
documents = [v for k, v in db.docstore._dict.items()]
|
1694 |
+
elif isinstance(db, Chroma):
|
1695 |
+
documents = db.get()
|
1696 |
+
else:
|
1697 |
+
# FIXME: Hack due to https://github.com/weaviate/weaviate/issues/1947
|
1698 |
+
# seems no way to get all metadata, so need to avoid this approach for weaviate
|
1699 |
+
documents = [x for x in db.similarity_search("", k=10000)]
|
1700 |
+
return documents
|
1701 |
+
|
1702 |
+
|
1703 |
+
def get_docs_and_meta(db, top_k_docs, filter_kwargs={}):
|
1704 |
+
if hasattr(db, '_persist_directory'):
|
1705 |
+
name_path = os.path.basename(db._persist_directory)
|
1706 |
+
base_path = 'locks'
|
1707 |
+
makedirs(base_path)
|
1708 |
+
with filelock.FileLock(os.path.join(base_path, "getdb_%s.lock" % name_path)):
|
1709 |
+
return _get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs)
|
1710 |
+
else:
|
1711 |
+
return _get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs)
|
1712 |
+
|
1713 |
+
|
1714 |
+
def _get_docs_and_meta(db, top_k_docs, filter_kwargs={}):
|
1715 |
+
from langchain.vectorstores import FAISS
|
1716 |
+
if isinstance(db, Chroma):
|
1717 |
+
db_get = db._collection.get(where=filter_kwargs.get('filter'))
|
1718 |
+
db_metadatas = db_get['metadatas']
|
1719 |
+
db_documents = db_get['documents']
|
1720 |
+
elif isinstance(db, FAISS):
|
1721 |
+
import itertools
|
1722 |
+
db_metadatas = get_metadatas(db)
|
1723 |
+
# FIXME: FAISS has no filter
|
1724 |
+
# slice dict first
|
1725 |
+
db_documents = list(dict(itertools.islice(db.docstore._dict.items(), top_k_docs)).values())
|
1726 |
+
else:
|
1727 |
+
db_metadatas = get_metadatas(db)
|
1728 |
+
db_documents = get_documents(db)
|
1729 |
+
return db_documents, db_metadatas
|
1730 |
+
|
1731 |
+
|
1732 |
def get_existing_files(db):
|
1733 |
+
metadatas = get_metadatas(db)
|
1734 |
+
metadata_sources = set([x['source'] for x in metadatas])
|
1735 |
return metadata_sources
|
1736 |
|
1737 |
|
1738 |
def get_existing_hash_ids(db):
|
1739 |
+
metadatas = get_metadatas(db)
|
1740 |
# assume consistency, that any prior hashed source was single hashed file at the time among all source chunks
|
1741 |
+
metadata_hash_ids = {x['source']: x.get('hashid') for x in metadatas}
|
1742 |
return metadata_hash_ids
|
1743 |
|
1744 |
|
|
|
|
|
|
|
|
|
1745 |
def run_qa_db(**kwargs):
|
1746 |
func_names = list(inspect.signature(_run_qa_db).parameters)
|
1747 |
# hard-coded defaults
|
1748 |
kwargs['answer_with_sources'] = True
|
|
|
1749 |
kwargs['show_rank'] = False
|
1750 |
missing_kwargs = [x for x in func_names if x not in kwargs]
|
1751 |
assert not missing_kwargs, "Missing kwargs: %s" % missing_kwargs
|
|
|
1763 |
user_path=None,
|
1764 |
detect_user_path_changes_every_query=False,
|
1765 |
db_type='faiss',
|
1766 |
+
model_name=None, model=None, tokenizer=None, inference_server=None,
|
1767 |
hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
|
1768 |
stream_output=False,
|
1769 |
prompter=None,
|
|
|
1771 |
prompt_dict=None,
|
1772 |
answer_with_sources=True,
|
1773 |
cut_distanct=1.1,
|
1774 |
+
sanitize_bot_response=False,
|
1775 |
show_rank=False,
|
1776 |
load_db_if_exists=False,
|
1777 |
db=None,
|
|
|
1790 |
document_choice=[DocumentChoices.All_Relevant.name],
|
1791 |
n_jobs=-1,
|
1792 |
verbose=False,
|
1793 |
+
cli=False,
|
1794 |
+
reverse_docs=True,
|
1795 |
+
lora_weights='',
|
1796 |
+
auto_reduce_chunks=True,
|
1797 |
+
max_chunks=100,
|
1798 |
+
):
|
1799 |
"""
|
1800 |
|
1801 |
:param query:
|
|
|
1814 |
:param answer_with_sources
|
1815 |
:return:
|
1816 |
"""
|
1817 |
+
if model is not None:
|
1818 |
+
assert model_name is not None # require so can make decisions
|
1819 |
assert query is not None
|
1820 |
assert prompter is not None or prompt_type is not None or model is None # if model is None, then will generate
|
1821 |
if prompter is not None:
|
|
|
1829 |
prompt_dict = ''
|
1830 |
assert len(set(gen_hyper).difference(inspect.signature(get_llm).parameters)) == 0
|
1831 |
llm, model_name, streamer, prompt_type_out = get_llm(use_openai_model=use_openai_model, model_name=model_name,
|
1832 |
+
model=model,
|
1833 |
+
tokenizer=tokenizer,
|
1834 |
+
inference_server=inference_server,
|
1835 |
stream_output=stream_output,
|
1836 |
do_sample=do_sample,
|
1837 |
temperature=temperature,
|
|
|
1847 |
prompt_type=prompt_type,
|
1848 |
prompt_dict=prompt_dict,
|
1849 |
prompter=prompter,
|
1850 |
+
sanitize_bot_response=sanitize_bot_response,
|
1851 |
verbose=verbose,
|
1852 |
)
|
1853 |
|
|
|
|
|
|
|
|
|
1854 |
use_context = False
|
1855 |
scores = []
|
1856 |
chain = None
|
|
|
1878 |
# can only return if HF type
|
1879 |
return
|
1880 |
|
1881 |
+
# context stuff similar to used in evaluate()
|
1882 |
+
import torch
|
1883 |
+
device, torch_dtype, context_class = get_device_dtype()
|
1884 |
+
with torch.no_grad():
|
1885 |
+
have_lora_weights = lora_weights not in [no_lora_str, '', None]
|
1886 |
+
context_class_cast = NullContext if device == 'cpu' or have_lora_weights else torch.autocast
|
1887 |
+
with context_class_cast(device):
|
1888 |
+
if stream_output and streamer:
|
1889 |
+
answer = None
|
1890 |
+
import queue
|
1891 |
+
bucket = queue.Queue()
|
1892 |
+
thread = EThread(target=chain, streamer=streamer, bucket=bucket)
|
1893 |
+
thread.start()
|
1894 |
+
outputs = ""
|
1895 |
+
prompt = None # FIXME
|
1896 |
+
try:
|
1897 |
+
for new_text in streamer:
|
1898 |
+
# print("new_text: %s" % new_text, flush=True)
|
1899 |
+
if bucket.qsize() > 0 or thread.exc:
|
1900 |
+
thread.join()
|
1901 |
+
outputs += new_text
|
1902 |
+
if prompter: # and False: # FIXME: pipeline can already use prompter
|
1903 |
+
output1 = prompter.get_response(outputs, prompt=prompt,
|
1904 |
+
sanitize_bot_response=sanitize_bot_response)
|
1905 |
+
yield output1, ''
|
1906 |
+
else:
|
1907 |
+
yield outputs, ''
|
1908 |
+
except BaseException:
|
1909 |
+
# if any exception, raise that exception if was from thread, first
|
1910 |
+
if thread.exc:
|
1911 |
+
raise thread.exc
|
1912 |
+
raise
|
1913 |
+
finally:
|
1914 |
+
# in case no exception and didn't join with thread yet, then join
|
1915 |
+
if not thread.exc:
|
1916 |
+
answer = thread.join()
|
1917 |
+
# in case raise StopIteration or broke queue loop in streamer, but still have exception
|
1918 |
+
if thread.exc:
|
1919 |
+
raise thread.exc
|
1920 |
+
# FIXME: answer is not string outputs from streamer. How to get actual final output?
|
1921 |
+
# answer = outputs
|
1922 |
+
else:
|
1923 |
+
answer = chain()
|
1924 |
|
1925 |
if not use_context:
|
1926 |
ret = answer['output_text']
|
|
|
1939 |
detect_user_path_changes_every_query=False,
|
1940 |
db_type='faiss',
|
1941 |
model_name=None,
|
1942 |
+
inference_server='',
|
1943 |
hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
|
1944 |
prompt_type=None,
|
1945 |
prompt_dict=None,
|
|
|
1951 |
n_jobs=-1,
|
1952 |
# beyond run_db_query:
|
1953 |
llm=None,
|
1954 |
+
tokenizer=None,
|
1955 |
verbose=False,
|
1956 |
cmd=None,
|
1957 |
+
reverse_docs=True,
|
1958 |
+
|
1959 |
+
# local
|
1960 |
+
auto_reduce_chunks=True,
|
1961 |
+
max_chunks=100,
|
1962 |
):
|
1963 |
# determine whether use of context out of docs is planned
|
1964 |
if not use_openai_model and prompt_type not in ['plain'] or model_name in non_hf_types:
|
|
|
1973 |
# FIXME: Seems to way to get size of chroma db to limit top_k_docs to avoid
|
1974 |
# Chroma collection MyData contains fewer than 4 elements.
|
1975 |
# type logger error
|
1976 |
+
if top_k_docs == -1:
|
1977 |
+
k_db = 1000 if db_type == 'chroma' else 100
|
1978 |
+
else:
|
1979 |
+
# top_k_docs=100 works ok too
|
1980 |
+
k_db = 1000 if db_type == 'chroma' else top_k_docs
|
1981 |
|
1982 |
# FIXME: For All just go over all dbs instead of a separate db for All
|
1983 |
if not detect_user_path_changes_every_query and db is not None:
|
|
|
1998 |
n_jobs=n_jobs,
|
1999 |
verbose=verbose)
|
2000 |
|
2001 |
+
if 'falcon' in model_name:
|
2002 |
+
extra = "According to only the information in the document sources provided within the context above, "
|
2003 |
+
prefix = "Pay attention and remember information below, which will help to answer the question or imperative after the context ends."
|
2004 |
+
elif inference_server in ['openai', 'openai_chat']:
|
2005 |
+
extra = "According to (primarily) the information in the document sources provided within context above, "
|
2006 |
+
prefix = "Pay attention and remember information below, which will help to answer the question or imperative after the context ends. If the answer cannot be primarily obtained from information within the context, then respond that the answer does not appear in the context of the documents."
|
2007 |
+
else:
|
2008 |
+
extra = ""
|
2009 |
+
prefix = ""
|
2010 |
+
if langchain_mode in ['Disabled', 'ChatLLM', 'LLM'] or not use_context:
|
2011 |
+
template_if_no_docs = template = """%s{context}{question}""" % prefix
|
2012 |
+
else:
|
2013 |
+
template = """%s
|
2014 |
+
\"\"\"
|
2015 |
+
{context}
|
2016 |
+
\"\"\"
|
2017 |
+
%s{question}""" % (prefix, extra)
|
2018 |
+
template_if_no_docs = """%s{context}%s{question}""" % (prefix, extra)
|
2019 |
+
if not use_openai_model and prompt_type not in ['plain'] or model_name in non_hf_types:
|
2020 |
+
use_template = True
|
2021 |
+
else:
|
2022 |
+
use_template = False
|
2023 |
+
|
2024 |
if db and use_context:
|
2025 |
if not isinstance(db, Chroma):
|
2026 |
# only chroma supports filtering
|
|
|
2041 |
docs = []
|
2042 |
scores = []
|
2043 |
elif cmd == DocumentChoices.Only_All_Sources.name:
|
2044 |
+
db_documents, db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs)
|
|
|
|
|
|
|
2045 |
# similar to langchain's chroma's _results_to_docs_and_scores
|
2046 |
docs_with_score = [(Document(page_content=result[0], metadata=result[1] or {}), 0)
|
2047 |
+
for result in zip(db_documents, db_metadatas)][:top_k_docs]
|
2048 |
docs = [x[0] for x in docs_with_score]
|
2049 |
scores = [x[1] for x in docs_with_score]
|
2050 |
else:
|
2051 |
+
if top_k_docs == -1 or auto_reduce_chunks:
|
2052 |
+
# docs_with_score = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)[:top_k_docs]
|
2053 |
+
top_k_docs_tokenize = 100
|
2054 |
+
base_path = 'locks'
|
2055 |
+
makedirs(base_path)
|
2056 |
+
if hasattr(db, '_persist_directory'):
|
2057 |
+
name_path = "sim_%s.lock" % os.path.basename(db._persist_directory)
|
2058 |
+
else:
|
2059 |
+
name_path = "sim.lock"
|
2060 |
+
with filelock.FileLock(os.path.join(base_path, name_path)):
|
2061 |
+
docs_with_score = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)[
|
2062 |
+
:top_k_docs_tokenize]
|
2063 |
+
if hasattr(llm, 'pipeline') and hasattr(llm.pipeline, 'tokenizer'):
|
2064 |
+
# more accurate
|
2065 |
+
tokens = [len(llm.pipeline.tokenizer(x[0].page_content)['input_ids']) for x in docs_with_score]
|
2066 |
+
template_tokens = len(llm.pipeline.tokenizer(template)['input_ids'])
|
2067 |
+
elif inference_server in ['openai', 'openai_chat'] or use_openai_model or db_type in ['faiss',
|
2068 |
+
'weaviate']:
|
2069 |
+
# use ticktoken for faiss since embedding called differently
|
2070 |
+
tokens = [llm.get_num_tokens(x[0].page_content) for x in docs_with_score]
|
2071 |
+
template_tokens = llm.get_num_tokens(template)
|
2072 |
+
elif isinstance(tokenizer, FakeTokenizer):
|
2073 |
+
tokens = [tokenizer.num_tokens_from_string(x[0].page_content) for x in docs_with_score]
|
2074 |
+
template_tokens = tokenizer.num_tokens_from_string(template)
|
2075 |
+
else:
|
2076 |
+
# in case model is not our pipeline with HF tokenizer
|
2077 |
+
tokens = [db._embedding_function.client.tokenize([x[0].page_content])['input_ids'].shape[1] for x in
|
2078 |
+
docs_with_score]
|
2079 |
+
template_tokens = db._embedding_function.client.tokenize([template])['input_ids'].shape[1]
|
2080 |
+
tokens_cumsum = np.cumsum(tokens)
|
2081 |
+
if hasattr(llm, 'pipeline') and hasattr(llm.pipeline, 'max_input_tokens'):
|
2082 |
+
max_input_tokens = llm.pipeline.max_input_tokens
|
2083 |
+
elif inference_server in ['openai']:
|
2084 |
+
max_tokens = llm.modelname_to_contextsize(model_name)
|
2085 |
+
# leave some room for 1 paragraph, even if min_new_tokens=0
|
2086 |
+
max_input_tokens = max_tokens - 256
|
2087 |
+
elif inference_server in ['openai_chat']:
|
2088 |
+
max_tokens = model_token_mapping[model_name]
|
2089 |
+
# leave some room for 1 paragraph, even if min_new_tokens=0
|
2090 |
+
max_input_tokens = max_tokens - 256
|
2091 |
+
elif isinstance(tokenizer, FakeTokenizer):
|
2092 |
+
max_input_tokens = tokenizer.model_max_length - 256
|
2093 |
+
else:
|
2094 |
+
# leave some room for 1 paragraph, even if min_new_tokens=0
|
2095 |
+
max_input_tokens = 2048 - 256
|
2096 |
+
max_input_tokens -= template_tokens
|
2097 |
+
# FIXME: Doesn't account for query, == context, or new lines between contexts
|
2098 |
+
where_res = np.where(tokens_cumsum < max_input_tokens)[0]
|
2099 |
+
if where_res.shape[0] == 0:
|
2100 |
+
# then no chunk can fit, still do first one
|
2101 |
+
top_k_docs_trial = 1
|
2102 |
+
else:
|
2103 |
+
top_k_docs_trial = 1 + where_res[-1]
|
2104 |
+
if 0 < top_k_docs_trial < max_chunks:
|
2105 |
+
# avoid craziness
|
2106 |
+
if top_k_docs == -1:
|
2107 |
+
top_k_docs = top_k_docs_trial
|
2108 |
+
else:
|
2109 |
+
top_k_docs = min(top_k_docs, top_k_docs_trial)
|
2110 |
+
if top_k_docs == -1:
|
2111 |
+
# if here, means 0 and just do best with 1 doc
|
2112 |
+
print("Unexpected large chunks and can't add to context, will add 1 anyways", flush=True)
|
2113 |
+
top_k_docs = 1
|
2114 |
+
docs_with_score = docs_with_score[:top_k_docs]
|
2115 |
+
else:
|
2116 |
+
docs_with_score = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)[:top_k_docs]
|
2117 |
+
# put most relevant chunks closest to question,
|
2118 |
+
# esp. if truncation occurs will be "oldest" or "farthest from response" text that is truncated
|
2119 |
+
# BUT: for small models, e.g. 6_9 pythia, if sees some stuff related to h2oGPT first, it can connect that and not listen to rest
|
2120 |
+
if reverse_docs:
|
2121 |
+
docs_with_score.reverse()
|
2122 |
# cut off so no high distance docs/sources considered
|
2123 |
docs = [x[0] for x in docs_with_score if x[1] < cut_distanct]
|
2124 |
scores = [x[1] for x in docs_with_score if x[1] < cut_distanct]
|
|
|
2153 |
if len(docs) == 0:
|
2154 |
# avoid context == in prompt then
|
2155 |
use_context = False
|
2156 |
+
template = template_if_no_docs
|
2157 |
|
2158 |
+
if use_template:
|
2159 |
# instruct-like, rather than few-shot prompt_type='plain' as default
|
2160 |
# but then sources confuse the model with how inserted among rest of text, so avoid
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2161 |
prompt = PromptTemplate(
|
2162 |
# input_variables=["summaries", "question"],
|
2163 |
input_variables=["context", "question"],
|
|
|
2217 |
return ret, extra
|
2218 |
|
2219 |
|
2220 |
+
def clean_doc(docs1):
|
2221 |
+
if not isinstance(docs1, (list, tuple, types.GeneratorType)):
|
2222 |
+
docs1 = [docs1]
|
2223 |
+
for doci, doc in enumerate(docs1):
|
2224 |
+
docs1[doci].page_content = '\n'.join([x.strip() for x in doc.page_content.split("\n") if x.strip()])
|
2225 |
+
return docs1
|
2226 |
+
|
2227 |
+
|
2228 |
+
def chunk_sources(sources, chunk=True, chunk_size=512, language=None):
|
2229 |
if not chunk:
|
2230 |
return sources
|
2231 |
+
if not isinstance(sources, (list, tuple, types.GeneratorType)) and not callable(sources):
|
2232 |
+
# if just one document
|
2233 |
+
sources = [sources]
|
2234 |
+
if language and False:
|
2235 |
+
# Bug in langchain, keep separator=True not working
|
2236 |
+
# https://github.com/hwchase17/langchain/issues/2836
|
2237 |
+
# so avoid this for now
|
2238 |
+
keep_separator = True
|
2239 |
+
separators = RecursiveCharacterTextSplitter.get_separators_for_language(language)
|
2240 |
+
else:
|
2241 |
+
separators = ["\n\n", "\n", " ", ""]
|
2242 |
+
keep_separator = False
|
2243 |
+
splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=0, keep_separator=keep_separator,
|
2244 |
+
separators=separators)
|
2245 |
+
source_chunks = splitter.split_documents(sources)
|
2246 |
return source_chunks
|
2247 |
|
2248 |
|
|
|
2290 |
WEAVIATE_SCOPE = os.getenv('WEAVIATE_SCOPE', "offline_access")
|
2291 |
|
2292 |
resource_owner_config = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2293 |
try:
|
2294 |
+
import weaviate
|
2295 |
+
if WEAVIATE_USERNAME is not None and WEAVIATE_PASSWORD is not None:
|
2296 |
+
resource_owner_config = weaviate.AuthClientPassword(
|
2297 |
+
username=WEAVIATE_USERNAME,
|
2298 |
+
password=WEAVIATE_PASSWORD,
|
2299 |
+
scope=WEAVIATE_SCOPE
|
2300 |
+
)
|
2301 |
+
|
2302 |
client = weaviate.Client(WEAVIATE_URL, auth_client_secret=resource_owner_config)
|
2303 |
+
return client
|
2304 |
except Exception as e:
|
2305 |
print(f"Failed to create Weaviate client: {e}")
|
2306 |
return None
|
gradio_runner.py
CHANGED
The diff for this file is too large to render.
See raw diff
|
|
gradio_themes.py
CHANGED
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|
3 |
from typing import Iterable
|
4 |
|
5 |
from gradio.themes.soft import Soft
|
6 |
-
from gradio.themes import Color
|
7 |
from gradio.themes.utils import colors, sizes, fonts
|
8 |
|
9 |
h2o_yellow = Color(
|
@@ -36,6 +36,42 @@ h2o_gray = Color(
|
|
36 |
)
|
37 |
|
38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
class H2oTheme(Soft):
|
40 |
def __init__(
|
41 |
self,
|
@@ -158,19 +194,23 @@ h2o_logo = '<svg id="Layer_1" data-name="Layer 1" xmlns="http://www.w3.org/2000/
|
|
158 |
'11.93S497.56,252.55,497.56,246.06Zm2.52,21.47h20.68v71.31H500.08Z"/></svg>'
|
159 |
|
160 |
|
161 |
-
def get_h2o_title(title):
|
162 |
-
|
|
|
|
|
|
|
|
|
163 |
<div style="height: 60px; width: 60px; margin-right:20px;">{h2o_logo}</div>
|
164 |
<h1 style="line-height:60px">{title}</h1>
|
165 |
</div>
|
166 |
<div style="float:right; height: 80px; width: 80px; margin-top:-100px">
|
167 |
-
<img src=https://raw.githubusercontent.com/h2oai/h2ogpt/main/docs/h2o-qr.png
|
168 |
</div>
|
169 |
"""
|
170 |
|
171 |
|
172 |
-
def get_simple_title(title):
|
173 |
-
return f"""<h1 align="center"> {title}</h1>"""
|
174 |
|
175 |
|
176 |
def get_dark_js():
|
|
|
3 |
from typing import Iterable
|
4 |
|
5 |
from gradio.themes.soft import Soft
|
6 |
+
from gradio.themes import Color, Size
|
7 |
from gradio.themes.utils import colors, sizes, fonts
|
8 |
|
9 |
h2o_yellow = Color(
|
|
|
36 |
)
|
37 |
|
38 |
|
39 |
+
text_xsm = Size(
|
40 |
+
name="text_xsm",
|
41 |
+
xxs="4px",
|
42 |
+
xs="5px",
|
43 |
+
sm="6px",
|
44 |
+
md="7px",
|
45 |
+
lg="8px",
|
46 |
+
xl="10px",
|
47 |
+
xxl="12px",
|
48 |
+
)
|
49 |
+
|
50 |
+
|
51 |
+
spacing_xsm = Size(
|
52 |
+
name="spacing_xsm",
|
53 |
+
xxs="1px",
|
54 |
+
xs="1px",
|
55 |
+
sm="1px",
|
56 |
+
md="2px",
|
57 |
+
lg="3px",
|
58 |
+
xl="5px",
|
59 |
+
xxl="7px",
|
60 |
+
)
|
61 |
+
|
62 |
+
|
63 |
+
radius_xsm = Size(
|
64 |
+
name="radius_xsm",
|
65 |
+
xxs="1px",
|
66 |
+
xs="1px",
|
67 |
+
sm="1px",
|
68 |
+
md="2px",
|
69 |
+
lg="3px",
|
70 |
+
xl="5px",
|
71 |
+
xxl="7px",
|
72 |
+
)
|
73 |
+
|
74 |
+
|
75 |
class H2oTheme(Soft):
|
76 |
def __init__(
|
77 |
self,
|
|
|
194 |
'11.93S497.56,252.55,497.56,246.06Zm2.52,21.47h20.68v71.31H500.08Z"/></svg>'
|
195 |
|
196 |
|
197 |
+
def get_h2o_title(title, description):
|
198 |
+
# NOTE: Check full width desktop, smallest width browser desktop, iPhone browsers to ensure no overlap etc.
|
199 |
+
return f"""<div style="float:left; justify-content:left; height: 80px; width: 195px; margin-top:0px">
|
200 |
+
{description}
|
201 |
+
</div>
|
202 |
+
<div style="display:flex; justify-content:center; margin-bottom:30px; margin-right:330px;">
|
203 |
<div style="height: 60px; width: 60px; margin-right:20px;">{h2o_logo}</div>
|
204 |
<h1 style="line-height:60px">{title}</h1>
|
205 |
</div>
|
206 |
<div style="float:right; height: 80px; width: 80px; margin-top:-100px">
|
207 |
+
<img src="https://raw.githubusercontent.com/h2oai/h2ogpt/main/docs/h2o-qr.png">
|
208 |
</div>
|
209 |
"""
|
210 |
|
211 |
|
212 |
+
def get_simple_title(title, description):
|
213 |
+
return f"""{description}<h1 align="center"> {title}</h1>"""
|
214 |
|
215 |
|
216 |
def get_dark_js():
|
gradio_utils/__pycache__/css.cpython-310.pyc
ADDED
Binary file (1.53 kB). View file
|
|
gradio_utils/__pycache__/grclient.cpython-310.pyc
ADDED
Binary file (2.69 kB). View file
|
|
gradio_utils/__pycache__/prompt_form.cpython-310.pyc
ADDED
Binary file (3.59 kB). View file
|
|
gradio_utils/css.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def get_css(kwargs) -> str:
|
2 |
+
if kwargs['h2ocolors']:
|
3 |
+
css_code = """footer {visibility: hidden;}
|
4 |
+
body{background:linear-gradient(#f5f5f5,#e5e5e5);}
|
5 |
+
body.dark{background:linear-gradient(#000000,#0d0d0d);}
|
6 |
+
"""
|
7 |
+
else:
|
8 |
+
css_code = """footer {visibility: hidden}"""
|
9 |
+
|
10 |
+
css_code += make_css_base()
|
11 |
+
return css_code
|
12 |
+
|
13 |
+
|
14 |
+
def make_css_base() -> str:
|
15 |
+
return """
|
16 |
+
@import url('https://fonts.googleapis.com/css2?family=Source+Sans+Pro:wght@400;600&display=swap');
|
17 |
+
|
18 |
+
body.dark{#warning {background-color: #555555};}
|
19 |
+
|
20 |
+
#small_btn {
|
21 |
+
margin: 0.6em 0em 0.55em 0;
|
22 |
+
max-width: 20em;
|
23 |
+
min-width: 5em !important;
|
24 |
+
height: 5em;
|
25 |
+
font-size: 14px !important;
|
26 |
+
}
|
27 |
+
|
28 |
+
#prompt-form {
|
29 |
+
border: 1px solid var(--primary-500) !important;
|
30 |
+
}
|
31 |
+
|
32 |
+
#prompt-form.block {
|
33 |
+
border-radius: var(--block-radius) !important;
|
34 |
+
}
|
35 |
+
|
36 |
+
#prompt-form textarea {
|
37 |
+
border: 1px solid rgb(209, 213, 219);
|
38 |
+
}
|
39 |
+
|
40 |
+
#prompt-form label > div {
|
41 |
+
margin-top: 4px;
|
42 |
+
}
|
43 |
+
|
44 |
+
button.primary:hover {
|
45 |
+
background-color: var(--primary-600) !important;
|
46 |
+
transition: .2s;
|
47 |
+
}
|
48 |
+
|
49 |
+
#prompt-form-area {
|
50 |
+
margin-bottom: 2.5rem;
|
51 |
+
}
|
52 |
+
.chatsmall chatbot {font-size: 10px !important}
|
53 |
+
"""
|
gradio_utils/grclient.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import traceback
|
2 |
+
from typing import Callable
|
3 |
+
import os
|
4 |
+
|
5 |
+
from gradio_client.client import Job
|
6 |
+
|
7 |
+
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
|
8 |
+
|
9 |
+
from gradio_client import Client
|
10 |
+
|
11 |
+
|
12 |
+
class GradioClient(Client):
|
13 |
+
"""
|
14 |
+
Parent class of gradio client
|
15 |
+
To handle automatically refreshing client if detect gradio server changed
|
16 |
+
"""
|
17 |
+
|
18 |
+
def __init__(self, *args, **kwargs):
|
19 |
+
self.args = args
|
20 |
+
self.kwargs = kwargs
|
21 |
+
super().__init__(*args, **kwargs)
|
22 |
+
self.server_hash = self.get_server_hash()
|
23 |
+
|
24 |
+
def get_server_hash(self):
|
25 |
+
"""
|
26 |
+
Get server hash using super without any refresh action triggered
|
27 |
+
Returns: git hash of gradio server
|
28 |
+
"""
|
29 |
+
return super().submit(api_name='/system_hash').result()
|
30 |
+
|
31 |
+
def refresh_client_if_should(self):
|
32 |
+
# get current hash in order to update api_name -> fn_index map in case gradio server changed
|
33 |
+
# FIXME: Could add cli api as hash
|
34 |
+
server_hash = self.get_server_hash()
|
35 |
+
if self.server_hash != server_hash:
|
36 |
+
self.refresh_client()
|
37 |
+
self.server_hash = server_hash
|
38 |
+
else:
|
39 |
+
self.reset_session()
|
40 |
+
|
41 |
+
def refresh_client(self):
|
42 |
+
"""
|
43 |
+
Ensure every client call is independent
|
44 |
+
Also ensure map between api_name and fn_index is updated in case server changed (e.g. restarted with new code)
|
45 |
+
Returns:
|
46 |
+
"""
|
47 |
+
# need session hash to be new every time, to avoid "generator already executing"
|
48 |
+
self.reset_session()
|
49 |
+
|
50 |
+
client = Client(*self.args, **self.kwargs)
|
51 |
+
for k, v in client.__dict__.items():
|
52 |
+
setattr(self, k, v)
|
53 |
+
|
54 |
+
def submit(
|
55 |
+
self,
|
56 |
+
*args,
|
57 |
+
api_name: str | None = None,
|
58 |
+
fn_index: int | None = None,
|
59 |
+
result_callbacks: Callable | list[Callable] | None = None,
|
60 |
+
) -> Job:
|
61 |
+
# Note predict calls submit
|
62 |
+
try:
|
63 |
+
self.refresh_client_if_should()
|
64 |
+
job = super().submit(*args, api_name=api_name, fn_index=fn_index)
|
65 |
+
except Exception as e:
|
66 |
+
print("Hit e=%s" % str(e), flush=True)
|
67 |
+
# force reconfig in case only that
|
68 |
+
self.refresh_client()
|
69 |
+
job = super().submit(*args, api_name=api_name, fn_index=fn_index)
|
70 |
+
|
71 |
+
# see if immediately failed
|
72 |
+
e = job.future._exception
|
73 |
+
if e is not None:
|
74 |
+
print("GR job failed: %s %s" % (str(e), ''.join(traceback.format_tb(e.__traceback__))), flush=True)
|
75 |
+
# force reconfig in case only that
|
76 |
+
self.refresh_client()
|
77 |
+
job = super().submit(*args, api_name=api_name, fn_index=fn_index)
|
78 |
+
e2 = job.future._exception
|
79 |
+
if e2 is not None:
|
80 |
+
print("GR job failed again: %s\n%s" % (str(e2), ''.join(traceback.format_tb(e2.__traceback__))), flush=True)
|
81 |
+
|
82 |
+
return job
|
gradio_utils/prompt_form.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import math
|
3 |
+
|
4 |
+
import gradio as gr
|
5 |
+
|
6 |
+
|
7 |
+
def make_chatbots(output_label0, output_label0_model2, **kwargs):
|
8 |
+
text_outputs = []
|
9 |
+
chat_kwargs = []
|
10 |
+
for model_state_lock in kwargs['model_states']:
|
11 |
+
if os.environ.get('DEBUG_MODEL_LOCK'):
|
12 |
+
model_name = model_state_lock["base_model"] + " : " + model_state_lock["inference_server"]
|
13 |
+
else:
|
14 |
+
model_name = model_state_lock["base_model"]
|
15 |
+
output_label = f'h2oGPT [{model_name}]'
|
16 |
+
min_width = 250 if kwargs['gradio_size'] in ['small', 'large', 'medium'] else 160
|
17 |
+
chat_kwargs.append(dict(label=output_label, visible=kwargs['model_lock'], elem_classes='chatsmall',
|
18 |
+
height=kwargs['height'] or 400, min_width=min_width))
|
19 |
+
|
20 |
+
if kwargs['model_lock_columns'] == -1:
|
21 |
+
kwargs['model_lock_columns'] = len(kwargs['model_states'])
|
22 |
+
if kwargs['model_lock_columns'] is None:
|
23 |
+
kwargs['model_lock_columns'] = 3
|
24 |
+
|
25 |
+
ncols = kwargs['model_lock_columns']
|
26 |
+
if kwargs['model_states'] == 0:
|
27 |
+
nrows = 0
|
28 |
+
else:
|
29 |
+
nrows = math.ceil(len(kwargs['model_states']) / kwargs['model_lock_columns'])
|
30 |
+
|
31 |
+
if kwargs['model_lock_columns'] == 0:
|
32 |
+
# not using model_lock
|
33 |
+
pass
|
34 |
+
elif nrows <= 1:
|
35 |
+
with gr.Row():
|
36 |
+
for chat_kwargs1, model_state_lock in zip(chat_kwargs, kwargs['model_states']):
|
37 |
+
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
38 |
+
elif nrows == kwargs['model_states']:
|
39 |
+
with gr.Row():
|
40 |
+
for chat_kwargs1, model_state_lock in zip(chat_kwargs, kwargs['model_states']):
|
41 |
+
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
42 |
+
elif nrows == 2:
|
43 |
+
with gr.Row():
|
44 |
+
for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
|
45 |
+
if mii >= len(kwargs['model_states']) / 2:
|
46 |
+
continue
|
47 |
+
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
48 |
+
with gr.Row():
|
49 |
+
for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
|
50 |
+
if mii < len(kwargs['model_states']) / 2:
|
51 |
+
continue
|
52 |
+
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
53 |
+
elif nrows == 3:
|
54 |
+
with gr.Row():
|
55 |
+
for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
|
56 |
+
if mii >= 1 * len(kwargs['model_states']) / 3:
|
57 |
+
continue
|
58 |
+
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
59 |
+
with gr.Row():
|
60 |
+
for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
|
61 |
+
if mii < 1 * len(kwargs['model_states']) / 3 or mii >= 2 * len(kwargs['model_states']) / 3:
|
62 |
+
continue
|
63 |
+
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
64 |
+
with gr.Row():
|
65 |
+
for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
|
66 |
+
if mii < 2 * len(kwargs['model_states']) / 3:
|
67 |
+
continue
|
68 |
+
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
69 |
+
elif nrows >= 4:
|
70 |
+
with gr.Row():
|
71 |
+
for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
|
72 |
+
if mii >= 1 * len(kwargs['model_states']) / 4:
|
73 |
+
continue
|
74 |
+
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
75 |
+
with gr.Row():
|
76 |
+
for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
|
77 |
+
if mii < 1 * len(kwargs['model_states']) / 4 or mii >= 2 * len(kwargs['model_states']) / 4:
|
78 |
+
continue
|
79 |
+
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
80 |
+
with gr.Row():
|
81 |
+
for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
|
82 |
+
if mii < 2 * len(kwargs['model_states']) / 4 or mii >= 3 * len(kwargs['model_states']) / 4:
|
83 |
+
continue
|
84 |
+
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
85 |
+
with gr.Row():
|
86 |
+
for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])):
|
87 |
+
if mii < 3 * len(kwargs['model_states']) / 4:
|
88 |
+
continue
|
89 |
+
text_outputs.append(gr.Chatbot(**chat_kwargs1))
|
90 |
+
|
91 |
+
with gr.Row():
|
92 |
+
text_output = gr.Chatbot(label=output_label0, visible=not kwargs['model_lock'], height=kwargs['height'] or 400)
|
93 |
+
text_output2 = gr.Chatbot(label=output_label0_model2,
|
94 |
+
visible=False and not kwargs['model_lock'], height=kwargs['height'] or 400)
|
95 |
+
return text_output, text_output2, text_outputs
|
96 |
+
|
97 |
+
|
98 |
+
def make_prompt_form(kwargs):
|
99 |
+
if kwargs['input_lines'] > 1:
|
100 |
+
instruction_label = "Shift-Enter to Submit, Enter for more lines"
|
101 |
+
else:
|
102 |
+
instruction_label = "Enter to Submit, Shift-Enter for more lines"
|
103 |
+
|
104 |
+
with gr.Row():#elem_id='prompt-form-area'):
|
105 |
+
with gr.Column(scale=50):
|
106 |
+
instruction = gr.Textbox(
|
107 |
+
lines=kwargs['input_lines'],
|
108 |
+
label='Ask anything',
|
109 |
+
placeholder=instruction_label,
|
110 |
+
info=None,
|
111 |
+
elem_id='prompt-form',
|
112 |
+
container=True,
|
113 |
+
)
|
114 |
+
with gr.Row():
|
115 |
+
submit = gr.Button(value='Submit', variant='primary', scale=0, size='sm')
|
116 |
+
stop_btn = gr.Button(value="Stop", variant='secondary', scale=0, size='sm')
|
117 |
+
|
118 |
+
return instruction, submit, stop_btn
|
h2oai_pipeline.py
CHANGED
@@ -9,7 +9,7 @@ from prompter import Prompter, PromptType
|
|
9 |
|
10 |
class H2OTextGenerationPipeline(TextGenerationPipeline):
|
11 |
def __init__(self, *args, debug=False, chat=False, stream_output=False,
|
12 |
-
sanitize_bot_response=
|
13 |
use_prompter=True, prompter=None,
|
14 |
prompt_type=None, prompt_dict=None,
|
15 |
max_input_tokens=2048 - 256, **kwargs):
|
@@ -51,25 +51,37 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
|
|
51 |
self.sanitize_bot_response = sanitize_bot_response
|
52 |
self.max_input_tokens = max_input_tokens # not for generate, so ok that not kwargs
|
53 |
|
54 |
-
|
55 |
-
|
|
|
|
|
|
|
56 |
# model_max_length only defined for generate.py, not raw use of h2oai_pipeline.py
|
57 |
-
model_max_length =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
58 |
else:
|
59 |
# unknown
|
60 |
model_max_length = None
|
61 |
|
62 |
-
|
63 |
if model_max_length is not None:
|
64 |
-
num_prompt_tokens = None
|
65 |
# can't wait for "hole" if not plain prompt_type, since would lose prefix like <human>:
|
66 |
# For https://github.com/h2oai/h2ogpt/issues/192
|
67 |
for trial in range(0, 3):
|
68 |
-
prompt_tokens =
|
69 |
num_prompt_tokens = len(prompt_tokens)
|
70 |
if num_prompt_tokens > model_max_length:
|
71 |
# conservative by using int()
|
72 |
chars_per_token = int(len(prompt_text) / num_prompt_tokens)
|
|
|
73 |
prompt_text = prompt_text[-model_max_length * chars_per_token:]
|
74 |
if verbose:
|
75 |
print("reducing %s tokens, assuming average of %s chars/token for %s characters" % (
|
@@ -79,20 +91,27 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
|
|
79 |
print("using %s tokens with %s chars" % (num_prompt_tokens, len(prompt_text)), flush=True)
|
80 |
break
|
81 |
|
82 |
-
#
|
83 |
-
|
84 |
-
|
85 |
-
#
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
|
97 |
data_point = dict(context='', instruction=prompt_text, input='')
|
98 |
if self.prompter is not None:
|
@@ -100,7 +119,7 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
|
|
100 |
self.prompt_text = prompt_text
|
101 |
if handle_long_generation is None:
|
102 |
# forces truncation of inputs to avoid critical failure
|
103 |
-
handle_long_generation =
|
104 |
return super().preprocess(prompt_text, prefix=prefix, handle_long_generation=handle_long_generation,
|
105 |
**generate_kwargs)
|
106 |
|
@@ -113,7 +132,7 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
|
|
113 |
outputs = self.prompter.get_response(outputs, prompt=self.prompt_text,
|
114 |
sanitize_bot_response=self.sanitize_bot_response)
|
115 |
elif self.bot and self.human:
|
116 |
-
outputs = rec['generated_text'].split(self.bot)[1].
|
117 |
else:
|
118 |
outputs = rec['generated_text']
|
119 |
rec['generated_text'] = outputs
|
@@ -123,7 +142,8 @@ class H2OTextGenerationPipeline(TextGenerationPipeline):
|
|
123 |
if self.can_stop:
|
124 |
stopping_criteria = get_stopping(self.prompt_type, self.prompt_dict,
|
125 |
self.tokenizer, self.device,
|
126 |
-
human=self.human, bot=self.bot
|
|
|
127 |
generate_kwargs['stopping_criteria'] = stopping_criteria
|
128 |
# return super()._forward(model_inputs, **generate_kwargs)
|
129 |
return self.__forward(model_inputs, **generate_kwargs)
|
|
|
9 |
|
10 |
class H2OTextGenerationPipeline(TextGenerationPipeline):
|
11 |
def __init__(self, *args, debug=False, chat=False, stream_output=False,
|
12 |
+
sanitize_bot_response=False,
|
13 |
use_prompter=True, prompter=None,
|
14 |
prompt_type=None, prompt_dict=None,
|
15 |
max_input_tokens=2048 - 256, **kwargs):
|
|
|
51 |
self.sanitize_bot_response = sanitize_bot_response
|
52 |
self.max_input_tokens = max_input_tokens # not for generate, so ok that not kwargs
|
53 |
|
54 |
+
@staticmethod
|
55 |
+
def limit_prompt(prompt_text, tokenizer, max_prompt_length=None):
|
56 |
+
verbose = bool(int(os.getenv('VERBOSE_PIPELINE', '0')))
|
57 |
+
|
58 |
+
if hasattr(tokenizer, 'model_max_length'):
|
59 |
# model_max_length only defined for generate.py, not raw use of h2oai_pipeline.py
|
60 |
+
model_max_length = tokenizer.model_max_length
|
61 |
+
if max_prompt_length is not None:
|
62 |
+
model_max_length = min(model_max_length, max_prompt_length)
|
63 |
+
# cut at some upper likely limit to avoid excessive tokenization etc
|
64 |
+
# upper bound of 10 chars/token, e.g. special chars sometimes are long
|
65 |
+
if len(prompt_text) > model_max_length * 10:
|
66 |
+
len0 = len(prompt_text)
|
67 |
+
prompt_text = prompt_text[-model_max_length * 10:]
|
68 |
+
if verbose:
|
69 |
+
print("Cut of input: %s -> %s" % (len0, len(prompt_text)), flush=True)
|
70 |
else:
|
71 |
# unknown
|
72 |
model_max_length = None
|
73 |
|
74 |
+
num_prompt_tokens = None
|
75 |
if model_max_length is not None:
|
|
|
76 |
# can't wait for "hole" if not plain prompt_type, since would lose prefix like <human>:
|
77 |
# For https://github.com/h2oai/h2ogpt/issues/192
|
78 |
for trial in range(0, 3):
|
79 |
+
prompt_tokens = tokenizer(prompt_text)['input_ids']
|
80 |
num_prompt_tokens = len(prompt_tokens)
|
81 |
if num_prompt_tokens > model_max_length:
|
82 |
# conservative by using int()
|
83 |
chars_per_token = int(len(prompt_text) / num_prompt_tokens)
|
84 |
+
# keep tail, where question is if using langchain
|
85 |
prompt_text = prompt_text[-model_max_length * chars_per_token:]
|
86 |
if verbose:
|
87 |
print("reducing %s tokens, assuming average of %s chars/token for %s characters" % (
|
|
|
91 |
print("using %s tokens with %s chars" % (num_prompt_tokens, len(prompt_text)), flush=True)
|
92 |
break
|
93 |
|
94 |
+
# Why Below False: don't limit max_new_tokens more, just rely upon stopping to reach limit of model
|
95 |
+
if False:
|
96 |
+
# if input prompt is some number of tokens, despite user request, can't have max_new_tokens more
|
97 |
+
#
|
98 |
+
assert num_prompt_tokens is not None
|
99 |
+
if self.prompt_type not in [PromptType.plain.name, PromptType.plain.value]:
|
100 |
+
# then give room for prompt
|
101 |
+
fudge = 20
|
102 |
+
else:
|
103 |
+
fudge = 0
|
104 |
+
max_new_tokens = max(0, min(generate_kwargs['max_new_tokens'],
|
105 |
+
model_max_length - (num_prompt_tokens + fudge)))
|
106 |
+
if max_new_tokens < generate_kwargs['max_new_tokens']:
|
107 |
+
if verbose:
|
108 |
+
print("Reduced max_new_tokens from %s -> %s" % (
|
109 |
+
generate_kwargs['max_new_tokens'], max_new_tokens))
|
110 |
+
generate_kwargs['max_new_tokens'] = max_new_tokens
|
111 |
+
return prompt_text, num_prompt_tokens
|
112 |
+
|
113 |
+
def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs):
|
114 |
+
prompt_text, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt_text, self.tokenizer)
|
115 |
|
116 |
data_point = dict(context='', instruction=prompt_text, input='')
|
117 |
if self.prompter is not None:
|
|
|
119 |
self.prompt_text = prompt_text
|
120 |
if handle_long_generation is None:
|
121 |
# forces truncation of inputs to avoid critical failure
|
122 |
+
handle_long_generation = None # disable with new approaches
|
123 |
return super().preprocess(prompt_text, prefix=prefix, handle_long_generation=handle_long_generation,
|
124 |
**generate_kwargs)
|
125 |
|
|
|
132 |
outputs = self.prompter.get_response(outputs, prompt=self.prompt_text,
|
133 |
sanitize_bot_response=self.sanitize_bot_response)
|
134 |
elif self.bot and self.human:
|
135 |
+
outputs = rec['generated_text'].split(self.bot)[1].split(self.human)[0]
|
136 |
else:
|
137 |
outputs = rec['generated_text']
|
138 |
rec['generated_text'] = outputs
|
|
|
142 |
if self.can_stop:
|
143 |
stopping_criteria = get_stopping(self.prompt_type, self.prompt_dict,
|
144 |
self.tokenizer, self.device,
|
145 |
+
human=self.human, bot=self.bot,
|
146 |
+
model_max_length=self.tokenizer.model_max_length)
|
147 |
generate_kwargs['stopping_criteria'] = stopping_criteria
|
148 |
# return super()._forward(model_inputs, **generate_kwargs)
|
149 |
return self.__forward(model_inputs, **generate_kwargs)
|
iterators/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .timeout_iterator import TimeoutIterator, AsyncTimeoutIterator
|
2 |
+
from .iterator_pipe import IteratorPipe, AsyncIteratorPipe
|
3 |
+
|
4 |
+
__all__ = ["TimeoutIterator", "AsyncTimeoutIterator", "IteratorPipe", "AsyncIteratorPipe"]
|
iterators/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (337 Bytes). View file
|
|
iterators/__pycache__/iterator_pipe.cpython-310.pyc
ADDED
Binary file (2.71 kB). View file
|
|
iterators/__pycache__/timeout_iterator.cpython-310.pyc
ADDED
Binary file (5.63 kB). View file
|
|
iterators/iterator_pipe.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import queue
|
2 |
+
import asyncio
|
3 |
+
|
4 |
+
|
5 |
+
class IteratorPipe:
|
6 |
+
"""
|
7 |
+
Iterator Pipe creates an iterator that can be fed in data from another block of code or thread of execution
|
8 |
+
"""
|
9 |
+
|
10 |
+
def __init__(self, sentinel=object()):
|
11 |
+
self._q = queue.Queue()
|
12 |
+
self._sentinel = sentinel
|
13 |
+
self._sentinel_pushed = False
|
14 |
+
self._closed = False
|
15 |
+
|
16 |
+
def __iter__(self):
|
17 |
+
return self
|
18 |
+
|
19 |
+
def __next__(self):
|
20 |
+
if self._closed:
|
21 |
+
raise StopIteration
|
22 |
+
|
23 |
+
data = self._q.get(block=True)
|
24 |
+
if data is self._sentinel:
|
25 |
+
self._closed = True
|
26 |
+
raise StopIteration
|
27 |
+
|
28 |
+
return data
|
29 |
+
|
30 |
+
def put(self, data) -> bool:
|
31 |
+
"""
|
32 |
+
Pushes next item to Iterator and returns True
|
33 |
+
If iterator has been closed via close(), doesn't push anything and returns False
|
34 |
+
"""
|
35 |
+
if self._sentinel_pushed:
|
36 |
+
return False
|
37 |
+
|
38 |
+
self._q.put(data)
|
39 |
+
return True
|
40 |
+
|
41 |
+
def close(self):
|
42 |
+
"""
|
43 |
+
Close is idempotent. Calling close multiple times is safe
|
44 |
+
Iterator will raise StopIteration only after all elements pushed before close have been iterated
|
45 |
+
"""
|
46 |
+
# make close idempotent
|
47 |
+
if not self._sentinel_pushed:
|
48 |
+
self._sentinel_pushed = True
|
49 |
+
self._q.put(self._sentinel)
|
50 |
+
|
51 |
+
|
52 |
+
class AsyncIteratorPipe:
|
53 |
+
|
54 |
+
def __init__(self, sentinel=object()):
|
55 |
+
self._q = asyncio.Queue()
|
56 |
+
self._sentinel = sentinel
|
57 |
+
self._sentinel_pushed = False
|
58 |
+
self._closed = False
|
59 |
+
|
60 |
+
def __aiter__(self):
|
61 |
+
return self
|
62 |
+
|
63 |
+
async def __anext__(self):
|
64 |
+
if self._closed:
|
65 |
+
raise StopAsyncIteration
|
66 |
+
|
67 |
+
data = await self._q.get()
|
68 |
+
if data is self._sentinel:
|
69 |
+
self._closed = True
|
70 |
+
raise StopAsyncIteration
|
71 |
+
|
72 |
+
return data
|
73 |
+
|
74 |
+
async def put(self, data) -> bool:
|
75 |
+
"""
|
76 |
+
Pushes next item to Iterator and returns True
|
77 |
+
If iterator has been closed via close(), doesn't push anything and returns False
|
78 |
+
"""
|
79 |
+
if self._sentinel_pushed:
|
80 |
+
return False
|
81 |
+
|
82 |
+
await self._q.put(data)
|
83 |
+
return True
|
84 |
+
|
85 |
+
async def close(self):
|
86 |
+
"""
|
87 |
+
Close is idempotent. Calling close multiple times is safe
|
88 |
+
Iterator will raise StopIteration only after all elements pushed before close have been iterated
|
89 |
+
"""
|
90 |
+
# make close idempotent
|
91 |
+
if not self._sentinel_pushed:
|
92 |
+
self._sentinel_pushed = True
|
93 |
+
await self._q.put(self._sentinel)
|
iterators/timeout_iterator.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import queue
|
2 |
+
import asyncio
|
3 |
+
import threading
|
4 |
+
import traceback
|
5 |
+
|
6 |
+
|
7 |
+
class TimeoutIterator:
|
8 |
+
"""
|
9 |
+
Wrapper class to add timeout feature to synchronous iterators
|
10 |
+
- timeout: timeout for next(). Default=ZERO_TIMEOUT i.e. no timeout or blocking calls to next. Updated using set_timeout()
|
11 |
+
- sentinel: the object returned by iterator when timeout happens
|
12 |
+
- reset_on_next: if set to True, timeout is reset to the value of ZERO_TIMEOUT on each iteration
|
13 |
+
|
14 |
+
TimeoutIterator uses a thread internally.
|
15 |
+
The thread stops once the iterator exhausts or raises an exception during iteration.
|
16 |
+
|
17 |
+
Any exceptions raised within the wrapped iterator are propagated as it is.
|
18 |
+
Exception is raised when all elements generated by the actual iterator before exception have been consumed
|
19 |
+
Timeout can be set dynamically before going for iteration
|
20 |
+
"""
|
21 |
+
ZERO_TIMEOUT = 0.0
|
22 |
+
|
23 |
+
def __init__(self, iterator, timeout=0.0, sentinel=object(), reset_on_next=False, raise_on_exception=True):
|
24 |
+
self._iterator = iterator
|
25 |
+
self._timeout = timeout
|
26 |
+
self._sentinel = sentinel
|
27 |
+
self._reset_on_next = reset_on_next
|
28 |
+
self._raise_on_exception = raise_on_exception
|
29 |
+
|
30 |
+
self._interrupt = False
|
31 |
+
self._done = False
|
32 |
+
self._buffer = queue.Queue()
|
33 |
+
self._thread = threading.Thread(target=self.__lookahead)
|
34 |
+
self._thread.start()
|
35 |
+
|
36 |
+
def get_sentinel(self):
|
37 |
+
return self._sentinel
|
38 |
+
|
39 |
+
def set_reset_on_next(self, reset_on_next):
|
40 |
+
self._reset_on_next = reset_on_next
|
41 |
+
|
42 |
+
def set_timeout(self, timeout: float):
|
43 |
+
"""
|
44 |
+
Set timeout for next iteration
|
45 |
+
"""
|
46 |
+
self._timeout = timeout
|
47 |
+
|
48 |
+
def interrupt(self):
|
49 |
+
"""
|
50 |
+
interrupt and stop the underlying thread.
|
51 |
+
the thread acutally dies only after interrupt has been set and
|
52 |
+
the underlying iterator yields a value after that.
|
53 |
+
"""
|
54 |
+
self._interrupt = True
|
55 |
+
|
56 |
+
def __iter__(self):
|
57 |
+
return self
|
58 |
+
|
59 |
+
def __next__(self):
|
60 |
+
"""
|
61 |
+
yield the result from iterator
|
62 |
+
if timeout > 0:
|
63 |
+
yield data if available.
|
64 |
+
otherwise yield sentinal
|
65 |
+
"""
|
66 |
+
if self._done:
|
67 |
+
raise StopIteration
|
68 |
+
|
69 |
+
data = self._sentinel
|
70 |
+
try:
|
71 |
+
if self._timeout > self.ZERO_TIMEOUT:
|
72 |
+
data = self._buffer.get(timeout=self._timeout)
|
73 |
+
else:
|
74 |
+
data = self._buffer.get()
|
75 |
+
except queue.Empty:
|
76 |
+
pass
|
77 |
+
finally:
|
78 |
+
# see if timeout needs to be reset
|
79 |
+
if self._reset_on_next:
|
80 |
+
self._timeout = self.ZERO_TIMEOUT
|
81 |
+
|
82 |
+
# propagate any exceptions including StopIteration
|
83 |
+
if isinstance(data, BaseException):
|
84 |
+
self._done = True
|
85 |
+
if isinstance(data, StopIteration):
|
86 |
+
raise data
|
87 |
+
ex = ''.join(traceback.format_tb(data.__traceback__))
|
88 |
+
print("Generation Failed: %s %s" % (str(data), str(ex)), flush=True)
|
89 |
+
if self._raise_on_exception:
|
90 |
+
raise data
|
91 |
+
else:
|
92 |
+
return data
|
93 |
+
|
94 |
+
return data
|
95 |
+
|
96 |
+
def __lookahead(self):
|
97 |
+
try:
|
98 |
+
while True:
|
99 |
+
self._buffer.put(next(self._iterator))
|
100 |
+
if self._interrupt:
|
101 |
+
raise StopIteration()
|
102 |
+
except BaseException as e:
|
103 |
+
self._buffer.put(e)
|
104 |
+
|
105 |
+
|
106 |
+
class AsyncTimeoutIterator:
|
107 |
+
"""
|
108 |
+
Async version of TimeoutIterator. See method documentation of TimeoutIterator
|
109 |
+
"""
|
110 |
+
ZERO_TIMEOUT = 0.0
|
111 |
+
|
112 |
+
def __init__(self, iterator, timeout=0.0, sentinel=object(), reset_on_next=False):
|
113 |
+
self._iterator = iterator
|
114 |
+
self._timeout = timeout
|
115 |
+
self._sentinel = sentinel
|
116 |
+
self._reset_on_next = reset_on_next
|
117 |
+
|
118 |
+
self._interrupt = False
|
119 |
+
self._done = False
|
120 |
+
self._buffer = asyncio.Queue()
|
121 |
+
self._task = asyncio.get_event_loop().create_task(self.__lookahead())
|
122 |
+
|
123 |
+
def get_sentinel(self):
|
124 |
+
return self._sentinel
|
125 |
+
|
126 |
+
def set_reset_on_next(self, reset_on_next):
|
127 |
+
self._reset_on_next = reset_on_next
|
128 |
+
|
129 |
+
def set_timeout(self, timeout: float):
|
130 |
+
self._timeout = timeout
|
131 |
+
|
132 |
+
def interrupt(self):
|
133 |
+
self._interrupt = True
|
134 |
+
|
135 |
+
def __aiter__(self):
|
136 |
+
return self
|
137 |
+
|
138 |
+
async def __anext__(self):
|
139 |
+
if self._done:
|
140 |
+
raise StopAsyncIteration
|
141 |
+
|
142 |
+
data = self._sentinel
|
143 |
+
try:
|
144 |
+
if self._timeout > self.ZERO_TIMEOUT:
|
145 |
+
data = await asyncio.wait_for(self._buffer.get(), self._timeout)
|
146 |
+
else:
|
147 |
+
data = await self._buffer.get()
|
148 |
+
except asyncio.TimeoutError:
|
149 |
+
pass
|
150 |
+
finally:
|
151 |
+
# see if timeout needs to be reset
|
152 |
+
if self._reset_on_next:
|
153 |
+
self._timeout = self.ZERO_TIMEOUT
|
154 |
+
|
155 |
+
# propagate any exceptions including StopIteration
|
156 |
+
if isinstance(data, BaseException):
|
157 |
+
self._done = True
|
158 |
+
raise data
|
159 |
+
|
160 |
+
return data
|
161 |
+
|
162 |
+
async def __lookahead(self):
|
163 |
+
try:
|
164 |
+
while True:
|
165 |
+
data = await self._iterator.__anext__()
|
166 |
+
await self._buffer.put(data)
|
167 |
+
if self._interrupt:
|
168 |
+
raise StopAsyncIteration()
|
169 |
+
except BaseException as e:
|
170 |
+
await self._buffer.put(e)
|
loaders.py
CHANGED
@@ -1,6 +1,8 @@
|
|
1 |
-
def get_loaders(
|
2 |
# NOTE: Some models need specific new prompt_type
|
3 |
# E.g. t5_xxl_true_nli_mixture has input format: "premise: PREMISE_TEXT hypothesis: HYPOTHESIS_TEXT".)
|
|
|
|
|
4 |
if llama_type:
|
5 |
from transformers import LlamaForCausalLM, LlamaTokenizer
|
6 |
model_loader = LlamaForCausalLM
|
@@ -39,7 +41,8 @@ def get_tokenizer(tokenizer_loader, tokenizer_base_model, local_files_only, resu
|
|
39 |
tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model,
|
40 |
local_files_only=local_files_only,
|
41 |
resume_download=resume_download,
|
42 |
-
use_auth_token=use_auth_token
|
|
|
43 |
|
44 |
tokenizer.pad_token_id = 0 # different from the eos token
|
45 |
# when generating, we will use the logits of right-most token to predict the next token
|
|
|
1 |
+
def get_loaders(model_name, reward_type, llama_type=None):
|
2 |
# NOTE: Some models need specific new prompt_type
|
3 |
# E.g. t5_xxl_true_nli_mixture has input format: "premise: PREMISE_TEXT hypothesis: HYPOTHESIS_TEXT".)
|
4 |
+
if llama_type is None:
|
5 |
+
llama_type = "llama" in model_name.lower()
|
6 |
if llama_type:
|
7 |
from transformers import LlamaForCausalLM, LlamaTokenizer
|
8 |
model_loader = LlamaForCausalLM
|
|
|
41 |
tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model,
|
42 |
local_files_only=local_files_only,
|
43 |
resume_download=resume_download,
|
44 |
+
use_auth_token=use_auth_token,
|
45 |
+
padding_side='left')
|
46 |
|
47 |
tokenizer.pad_token_id = 0 # different from the eos token
|
48 |
# when generating, we will use the logits of right-most token to predict the next token
|
prompter.py
CHANGED
@@ -1,10 +1,10 @@
|
|
|
|
1 |
import ast
|
2 |
import time
|
3 |
from enums import PromptType # also supports imports from this file from other files
|
4 |
|
5 |
non_hf_types = ['gpt4all_llama', 'llama', 'gptj']
|
6 |
|
7 |
-
|
8 |
prompt_type_to_model_name = {
|
9 |
'plain': [
|
10 |
'EleutherAI/gpt-j-6B',
|
@@ -25,23 +25,29 @@ prompt_type_to_model_name = {
|
|
25 |
'mosaicml/mpt-7b-storywriter',
|
26 |
'mosaicml/mpt-7b-instruct', # internal code handles instruct
|
27 |
'mosaicml/mpt-7b-chat', # NC, internal code handles instruct
|
28 |
-
'
|
29 |
-
'llama', # plain, or need to choose prompt_type for given TheBloke model
|
30 |
-
'gpt4all_llama', # internally handles prompting
|
31 |
],
|
|
|
32 |
'prompt_answer': [
|
33 |
'h2oai/h2ogpt-gm-oasst1-en-1024-20b',
|
34 |
'h2oai/h2ogpt-gm-oasst1-en-1024-12b',
|
35 |
'h2oai/h2ogpt-gm-oasst1-multilang-1024-20b',
|
36 |
-
'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt',
|
37 |
-
'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2',
|
38 |
-
'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-700bt',
|
39 |
-
'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b',
|
40 |
'h2oai/h2ogpt-gm-oasst1-multilang-2048-falcon-7b',
|
41 |
'h2oai/h2ogpt-gm-oasst1-multilang-2048-falcon-7b-v2',
|
|
|
42 |
'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b',
|
43 |
'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v2',
|
44 |
'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-40b-v1',
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
],
|
46 |
'instruct': [],
|
47 |
'instruct_with_end': ['databricks/dolly-v2-12b'],
|
@@ -54,6 +60,7 @@ prompt_type_to_model_name = {
|
|
54 |
'h2oai/h2ogpt-oig-oasst1-256-6.9b', # legacy
|
55 |
'h2oai/h2ogpt-oig-oasst1-512-6.9b', # legacy
|
56 |
'h2oai/h2ogpt-research-oasst1-512-30b',
|
|
|
57 |
'h2oai/h2ogpt-oasst1-falcon-40b',
|
58 |
'h2oai/h2ogpt-oig-oasst1-falcon-40b',
|
59 |
],
|
@@ -66,7 +73,16 @@ prompt_type_to_model_name = {
|
|
66 |
"wizard_lm": ['ehartford/WizardLM-7B-Uncensored', 'ehartford/WizardLM-13B-Uncensored'],
|
67 |
"wizard_mega": ['openaccess-ai-collective/wizard-mega-13b'],
|
68 |
"instruct_simple": ['JosephusCheung/Guanaco'],
|
|
|
|
|
|
|
|
|
69 |
}
|
|
|
|
|
|
|
|
|
|
|
70 |
|
71 |
inv_prompt_type_to_model_name = {v.strip(): k for k, l in prompt_type_to_model_name.items() for v in l}
|
72 |
inv_prompt_type_to_model_lower = {v.strip().lower(): k for k, l in prompt_type_to_model_name.items() for v in l}
|
@@ -80,18 +96,29 @@ for p in PromptType:
|
|
80 |
prompt_types.extend([p.name, p.value, str(p.value)])
|
81 |
|
82 |
|
83 |
-
def get_prompt(prompt_type, prompt_dict, chat, context, reduced, return_dict=False):
|
84 |
prompt_dict_error = ''
|
|
|
|
|
85 |
if prompt_type == PromptType.custom.name and not isinstance(prompt_dict, dict):
|
86 |
try:
|
87 |
prompt_dict = ast.literal_eval(prompt_dict)
|
88 |
except BaseException as e:
|
89 |
prompt_dict_error = str(e)
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
promptA = prompt_dict.get('promptA', '')
|
96 |
promptB = prompt_dict('promptB', '')
|
97 |
PreInstruct = prompt_dict.get('PreInstruct', '')
|
@@ -99,21 +126,23 @@ def get_prompt(prompt_type, prompt_dict, chat, context, reduced, return_dict=Fal
|
|
99 |
PreResponse = prompt_dict.get('PreResponse', '')
|
100 |
terminate_response = prompt_dict.get('terminate_response', None)
|
101 |
chat_sep = prompt_dict.get('chat_sep', '\n')
|
|
|
102 |
humanstr = prompt_dict.get('humanstr', '')
|
103 |
botstr = prompt_dict.get('botstr', '')
|
104 |
elif prompt_type in [PromptType.plain.value, str(PromptType.plain.value),
|
105 |
PromptType.plain.name]:
|
106 |
-
promptA = promptB = PreInstruct = PreInput = PreResponse =
|
107 |
terminate_response = []
|
108 |
-
chat_sep = ''
|
109 |
-
|
110 |
-
|
|
|
111 |
elif prompt_type == 'simple_instruct':
|
112 |
promptA = promptB = PreInstruct = PreInput = PreResponse = None
|
113 |
terminate_response = []
|
114 |
-
chat_sep = '\n'
|
115 |
-
humanstr =
|
116 |
-
botstr =
|
117 |
elif prompt_type in [PromptType.instruct.value, str(PromptType.instruct.value),
|
118 |
PromptType.instruct.name] + [PromptType.instruct_with_end.value,
|
119 |
str(PromptType.instruct_with_end.value),
|
@@ -139,7 +168,7 @@ def get_prompt(prompt_type, prompt_dict, chat, context, reduced, return_dict=Fal
|
|
139 |
terminate_response = ['### End']
|
140 |
else:
|
141 |
terminate_response = None
|
142 |
-
chat_sep = '\n'
|
143 |
humanstr = PreInstruct
|
144 |
botstr = PreResponse
|
145 |
elif prompt_type in [PromptType.quality.value, str(PromptType.quality.value),
|
@@ -161,7 +190,7 @@ def get_prompt(prompt_type, prompt_dict, chat, context, reduced, return_dict=Fal
|
|
161 |
### Response:
|
162 |
"""
|
163 |
terminate_response = None
|
164 |
-
chat_sep = '\n'
|
165 |
humanstr = PreInstruct # first thing human says
|
166 |
botstr = PreResponse # first thing bot says
|
167 |
elif prompt_type in [PromptType.human_bot.value, str(PromptType.human_bot.value),
|
@@ -183,14 +212,14 @@ Current Time: {}
|
|
183 |
|
184 |
"""
|
185 |
preprompt = PRE_PROMPT.format(cur_date, cur_time)
|
186 |
-
start =
|
187 |
-
promptB = promptA = '%s%s
|
188 |
|
189 |
-
PreInstruct =
|
190 |
|
191 |
PreInput = None
|
192 |
|
193 |
-
if
|
194 |
# when making context, want it to appear as-if LLM generated, which starts with space after :
|
195 |
PreResponse = bot + ' '
|
196 |
else:
|
@@ -198,10 +227,11 @@ Current Time: {}
|
|
198 |
# if add space here, non-unique tokenization will often make LLM produce wrong output
|
199 |
PreResponse = bot
|
200 |
|
201 |
-
terminate_response = [
|
202 |
-
chat_sep = '\n'
|
203 |
humanstr = human # tag before human talks
|
204 |
botstr = bot # tag before bot talks
|
|
|
205 |
elif prompt_type in [PromptType.dai_faq.value, str(PromptType.dai_faq.value),
|
206 |
PromptType.dai_faq.name]:
|
207 |
promptA = ''
|
@@ -217,7 +247,7 @@ Current Time: {}
|
|
217 |
### Driverless AI documentation answer:
|
218 |
"""
|
219 |
terminate_response = ['\n\n']
|
220 |
-
chat_sep = terminate_response
|
221 |
humanstr = PreInstruct
|
222 |
botstr = PreResponse
|
223 |
elif prompt_type in [PromptType.summarize.value, str(PromptType.summarize.value),
|
@@ -226,7 +256,7 @@ Current Time: {}
|
|
226 |
PreInstruct = '## Main Text\n\n'
|
227 |
PreResponse = '\n\n## Summary\n\n'
|
228 |
terminate_response = None
|
229 |
-
chat_sep = '\n'
|
230 |
humanstr = PreInstruct
|
231 |
botstr = PreResponse
|
232 |
elif prompt_type in [PromptType.instruct_vicuna.value, str(PromptType.instruct_vicuna.value),
|
@@ -246,7 +276,7 @@ Current Time: {}
|
|
246 |
"""
|
247 |
terminate_response = [
|
248 |
'### Human:'] # but only allow terminate after prompt is found correctly, else can't terminate
|
249 |
-
chat_sep = '\n'
|
250 |
humanstr = PreInstruct
|
251 |
botstr = PreResponse
|
252 |
elif prompt_type in [PromptType.prompt_answer.value, str(PromptType.prompt_answer.value),
|
@@ -254,33 +284,50 @@ Current Time: {}
|
|
254 |
preprompt = ''
|
255 |
prompt_tokens = "<|prompt|>"
|
256 |
answer_tokens = "<|answer|>"
|
257 |
-
start =
|
258 |
promptB = promptA = '%s%s' % (preprompt, start)
|
259 |
-
PreInstruct =
|
260 |
PreInput = None
|
261 |
PreResponse = answer_tokens
|
262 |
eos = '<|endoftext|>' # neox eos
|
263 |
-
terminate_response = [start, PreResponse, eos]
|
264 |
-
chat_sep = eos
|
265 |
humanstr = prompt_tokens
|
266 |
botstr = answer_tokens
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
267 |
elif prompt_type in [PromptType.open_assistant.value, str(PromptType.open_assistant.value),
|
268 |
PromptType.open_assistant.name]:
|
269 |
# From added_tokens.json
|
270 |
preprompt = ''
|
271 |
prompt_tokens = "<|prompter|>"
|
272 |
answer_tokens = "<|assistant|>"
|
273 |
-
start =
|
274 |
promptB = promptA = '%s%s' % (preprompt, start)
|
275 |
-
PreInstruct =
|
276 |
PreInput = None
|
277 |
PreResponse = answer_tokens
|
278 |
pend = "<|prefix_end|>"
|
279 |
eos = "</s>"
|
280 |
-
terminate_response = [start, PreResponse, pend, eos]
|
281 |
-
chat_sep = eos
|
282 |
humanstr = prompt_tokens
|
283 |
botstr = answer_tokens
|
|
|
|
|
284 |
elif prompt_type in [PromptType.wizard_lm.value, str(PromptType.wizard_lm.value),
|
285 |
PromptType.wizard_lm.name]:
|
286 |
# https://github.com/ehartford/WizardLM/blob/main/src/train_freeform.py
|
@@ -292,7 +339,7 @@ Current Time: {}
|
|
292 |
PreResponse = "\n\n### Response\n"
|
293 |
eos = "</s>"
|
294 |
terminate_response = [PreResponse, eos]
|
295 |
-
chat_sep = eos
|
296 |
humanstr = promptA
|
297 |
botstr = PreResponse
|
298 |
elif prompt_type in [PromptType.wizard_mega.value, str(PromptType.wizard_mega.value),
|
@@ -308,13 +355,12 @@ Current Time: {}
|
|
308 |
### Assistant:
|
309 |
"""
|
310 |
terminate_response = [PreResponse]
|
311 |
-
chat_sep = '\n'
|
312 |
humanstr = PreInstruct
|
313 |
botstr = PreResponse
|
314 |
elif prompt_type in [PromptType.instruct_vicuna2.value, str(PromptType.instruct_vicuna2.value),
|
315 |
PromptType.instruct_vicuna2.name]:
|
316 |
-
promptA = promptB = "" if not (
|
317 |
-
chat and reduced) else ''
|
318 |
|
319 |
PreInstruct = """
|
320 |
HUMAN:
|
@@ -327,13 +373,12 @@ ASSISTANT:
|
|
327 |
"""
|
328 |
terminate_response = [
|
329 |
'HUMAN:'] # but only allow terminate after prompt is found correctly, else can't terminate
|
330 |
-
chat_sep = '\n'
|
331 |
humanstr = PreInstruct
|
332 |
botstr = PreResponse
|
333 |
elif prompt_type in [PromptType.instruct_vicuna3.value, str(PromptType.instruct_vicuna3.value),
|
334 |
PromptType.instruct_vicuna3.name]:
|
335 |
-
promptA = promptB = "" if not (
|
336 |
-
chat and reduced) else ''
|
337 |
|
338 |
PreInstruct = """
|
339 |
### User:
|
@@ -346,13 +391,14 @@ ASSISTANT:
|
|
346 |
"""
|
347 |
terminate_response = [
|
348 |
'### User:'] # but only allow terminate after prompt is found correctly, else can't terminate
|
349 |
-
chat_sep = '\n'
|
350 |
humanstr = PreInstruct
|
351 |
botstr = PreResponse
|
352 |
elif prompt_type in [PromptType.wizard2.value, str(PromptType.wizard2.value),
|
353 |
PromptType.wizard2.name]:
|
354 |
# https://huggingface.co/TheBloke/WizardLM-7B-uncensored-GGML
|
355 |
-
preprompt = """Below is an instruction that describes a task. Write a response that appropriately completes the request."""
|
|
|
356 |
start = ''
|
357 |
promptB = promptA = '%s%s' % (preprompt, start)
|
358 |
PreInstruct = """
|
@@ -363,27 +409,39 @@ ASSISTANT:
|
|
363 |
### Response:
|
364 |
"""
|
365 |
terminate_response = [PreResponse]
|
366 |
-
chat_sep = '\n'
|
367 |
humanstr = PreInstruct
|
368 |
botstr = PreResponse
|
369 |
elif prompt_type in [PromptType.wizard3.value, str(PromptType.wizard3.value),
|
370 |
PromptType.wizard3.name]:
|
371 |
# https://huggingface.co/TheBloke/wizardLM-13B-1.0-GGML
|
372 |
-
preprompt = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions."""
|
|
|
373 |
start = ''
|
374 |
promptB = promptA = '%s%s' % (preprompt, start)
|
375 |
PreInstruct = """USER: """
|
376 |
PreInput = None
|
377 |
PreResponse = """ASSISTANT: """
|
378 |
terminate_response = [PreResponse]
|
379 |
-
chat_sep = '\n'
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
380 |
humanstr = PreInstruct
|
381 |
botstr = PreResponse
|
382 |
|
383 |
elif prompt_type in [PromptType.instruct_simple.value, str(PromptType.instruct_simple.value),
|
384 |
PromptType.instruct_simple.name]:
|
385 |
-
promptA = '' if not (chat and reduced) else ''
|
386 |
-
promptB = '' if not (chat and reduced) else ''
|
387 |
|
388 |
PreInstruct = """
|
389 |
### Instruction:
|
@@ -397,21 +455,90 @@ ASSISTANT:
|
|
397 |
### Response:
|
398 |
"""
|
399 |
terminate_response = None
|
400 |
-
chat_sep = '\n'
|
401 |
humanstr = PreInstruct
|
402 |
botstr = PreResponse
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
403 |
else:
|
404 |
raise RuntimeError("No such prompt_type=%s" % prompt_type)
|
405 |
|
406 |
-
if
|
407 |
-
|
|
|
|
|
408 |
PreResponse=PreResponse, terminate_response=terminate_response, chat_sep=chat_sep,
|
409 |
-
|
|
|
|
|
|
|
|
|
|
|
410 |
else:
|
411 |
-
return
|
412 |
|
413 |
|
414 |
-
def generate_prompt(data_point, prompt_type, prompt_dict, chat, reduced):
|
415 |
context = data_point.get('context')
|
416 |
if context is None:
|
417 |
context = ''
|
@@ -422,9 +549,12 @@ def generate_prompt(data_point, prompt_type, prompt_dict, chat, reduced):
|
|
422 |
prompt_dict = data_point.get('prompt_dict', prompt_dict)
|
423 |
assert prompt_type in prompt_types, "Bad prompt type: %s" % prompt_type
|
424 |
promptA, promptB, PreInstruct, PreInput, PreResponse, \
|
425 |
-
terminate_response, chat_sep, humanstr, botstr
|
|
|
|
|
426 |
|
427 |
-
|
|
|
428 |
|
429 |
if input and promptA:
|
430 |
prompt += f"""{promptA}"""
|
@@ -433,37 +563,37 @@ def generate_prompt(data_point, prompt_type, prompt_dict, chat, reduced):
|
|
433 |
|
434 |
if instruction and PreInstruct is not None and input and PreInput is not None:
|
435 |
prompt += f"""{PreInstruct}{instruction}{PreInput}{input}"""
|
436 |
-
prompt =
|
437 |
elif instruction and input and PreInstruct is None and PreInput is not None:
|
438 |
prompt += f"""{PreInput}{instruction}
|
439 |
{input}"""
|
440 |
-
prompt =
|
441 |
elif input and instruction and PreInput is None and PreInstruct is not None:
|
442 |
prompt += f"""{PreInstruct}{instruction}
|
443 |
{input}"""
|
444 |
-
prompt =
|
445 |
elif instruction and PreInstruct is not None:
|
446 |
prompt += f"""{PreInstruct}{instruction}"""
|
447 |
-
prompt =
|
448 |
elif input and PreInput is not None:
|
449 |
prompt += f"""{PreInput}{input}"""
|
450 |
-
prompt =
|
451 |
elif input and instruction and PreInput is not None:
|
452 |
prompt += f"""{PreInput}{instruction}{input}"""
|
453 |
-
prompt =
|
454 |
elif input and instruction and PreInstruct is not None:
|
455 |
prompt += f"""{PreInstruct}{instruction}{input}"""
|
456 |
-
prompt =
|
457 |
elif input and instruction:
|
458 |
# i.e. for simple_instruct
|
459 |
prompt += f"""{instruction}: {input}"""
|
460 |
-
prompt =
|
461 |
elif input:
|
462 |
prompt += f"""{input}"""
|
463 |
-
prompt =
|
464 |
elif instruction:
|
465 |
prompt += f"""{instruction}"""
|
466 |
-
prompt =
|
467 |
|
468 |
if PreResponse is not None:
|
469 |
prompt += f"""{PreResponse}"""
|
@@ -474,13 +604,13 @@ def generate_prompt(data_point, prompt_type, prompt_dict, chat, reduced):
|
|
474 |
if output:
|
475 |
prompt += f"""{output}"""
|
476 |
|
477 |
-
return prompt, pre_response, terminate_response, chat_sep
|
478 |
|
479 |
|
480 |
-
def
|
481 |
-
if
|
482 |
# only add new line if structured prompt, while 'plain' is just generation of next tokens from input
|
483 |
-
prompt +=
|
484 |
return prompt
|
485 |
|
486 |
|
@@ -489,9 +619,6 @@ class Prompter(object):
|
|
489 |
allowed_repeat_line_length=10):
|
490 |
self.prompt_type = prompt_type
|
491 |
self.prompt_dict = prompt_dict
|
492 |
-
data_point = dict(instruction='', input='', output='')
|
493 |
-
_, self.pre_response, self.terminate_response, self.chat_sep = \
|
494 |
-
generate_prompt(data_point, self.prompt_type, self.prompt_dict, chat, False)
|
495 |
self.debug = debug
|
496 |
self.chat = chat
|
497 |
self.stream_output = stream_output
|
@@ -500,23 +627,41 @@ class Prompter(object):
|
|
500 |
self.prompt = None
|
501 |
context = "" # not for chat context
|
502 |
reduced = False # not for chat context
|
|
|
503 |
self.promptA, self.promptB, self.PreInstruct, self.PreInput, self.PreResponse, \
|
504 |
-
self.terminate_response, self.chat_sep, self.humanstr, self.botstr
|
505 |
-
|
506 |
-
|
507 |
-
|
508 |
-
|
509 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
510 |
if self.debug:
|
511 |
-
print("prompt: "
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
512 |
self.prompt = prompt
|
513 |
return prompt
|
514 |
|
515 |
-
def get_response(self, outputs, prompt=None, sanitize_bot_response=
|
516 |
if isinstance(outputs, str):
|
517 |
outputs = [outputs]
|
518 |
if self.debug:
|
519 |
-
print("output:\n"
|
520 |
if prompt is not None:
|
521 |
self.prompt = prompt
|
522 |
|
@@ -527,7 +672,8 @@ class Prompter(object):
|
|
527 |
if sanitize_bot_response:
|
528 |
from better_profanity import profanity
|
529 |
response = profanity.censor(response)
|
530 |
-
response
|
|
|
531 |
return response
|
532 |
|
533 |
def clean_repeats(response):
|
@@ -549,12 +695,12 @@ class Prompter(object):
|
|
549 |
# then use most basic parsing like pipeline
|
550 |
if self.botstr in output:
|
551 |
if self.humanstr:
|
552 |
-
output = clean_response(output.split(self.botstr)[1].
|
553 |
else:
|
554 |
# i.e. use after bot but only up to next bot
|
555 |
-
output = clean_response(output.split(self.botstr)[1].
|
556 |
else:
|
557 |
-
# output = clean_response(output
|
558 |
# assume just not printed yet
|
559 |
output = ""
|
560 |
else:
|
@@ -581,9 +727,9 @@ class Prompter(object):
|
|
581 |
allow_terminate = True
|
582 |
output = output[len(prompt):]
|
583 |
# clean after subtract prompt out, so correct removal of pre_response
|
584 |
-
output = clean_response(output)
|
585 |
if self.repeat_penalty:
|
586 |
-
output = clean_repeats(output)
|
587 |
if self.terminate_response and allow_terminate:
|
588 |
finds = []
|
589 |
for term in self.terminate_response:
|
@@ -591,11 +737,9 @@ class Prompter(object):
|
|
591 |
finds = [x for x in finds if x >= 0]
|
592 |
if len(finds) > 0:
|
593 |
termi = finds[0]
|
594 |
-
output = output[:termi]
|
595 |
else:
|
596 |
-
output = output
|
597 |
-
else:
|
598 |
-
output = output.strip()
|
599 |
if multi_output:
|
600 |
# prefix with output counter
|
601 |
output = "\n=========== Output %d\n\n" % (1 + oi) + output
|
@@ -606,5 +750,5 @@ class Prompter(object):
|
|
606 |
# join all outputs, only one extra new line between outputs
|
607 |
output = '\n'.join(outputs)
|
608 |
if self.debug:
|
609 |
-
print("outputclean:\n"
|
610 |
return output
|
|
|
1 |
+
import os
|
2 |
import ast
|
3 |
import time
|
4 |
from enums import PromptType # also supports imports from this file from other files
|
5 |
|
6 |
non_hf_types = ['gpt4all_llama', 'llama', 'gptj']
|
7 |
|
|
|
8 |
prompt_type_to_model_name = {
|
9 |
'plain': [
|
10 |
'EleutherAI/gpt-j-6B',
|
|
|
25 |
'mosaicml/mpt-7b-storywriter',
|
26 |
'mosaicml/mpt-7b-instruct', # internal code handles instruct
|
27 |
'mosaicml/mpt-7b-chat', # NC, internal code handles instruct
|
28 |
+
'mosaicml/mpt-30b-instruct', # internal code handles instruct
|
|
|
|
|
29 |
],
|
30 |
+
'gptj': ['gptj', 'gpt4all_llama'],
|
31 |
'prompt_answer': [
|
32 |
'h2oai/h2ogpt-gm-oasst1-en-1024-20b',
|
33 |
'h2oai/h2ogpt-gm-oasst1-en-1024-12b',
|
34 |
'h2oai/h2ogpt-gm-oasst1-multilang-1024-20b',
|
|
|
|
|
|
|
|
|
35 |
'h2oai/h2ogpt-gm-oasst1-multilang-2048-falcon-7b',
|
36 |
'h2oai/h2ogpt-gm-oasst1-multilang-2048-falcon-7b-v2',
|
37 |
+
'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3',
|
38 |
'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b',
|
39 |
'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v2',
|
40 |
'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-40b-v1',
|
41 |
+
'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-40b-v2',
|
42 |
+
'h2oai/h2ogpt-gm-oasst1-en-xgen-7b-8k',
|
43 |
+
'h2oai/h2ogpt-gm-oasst1-multilang-xgen-7b-8k',
|
44 |
+
],
|
45 |
+
'prompt_answer_openllama': [
|
46 |
+
'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt',
|
47 |
+
'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2',
|
48 |
+
'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-700bt',
|
49 |
+
'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b',
|
50 |
+
'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-13b',
|
51 |
],
|
52 |
'instruct': [],
|
53 |
'instruct_with_end': ['databricks/dolly-v2-12b'],
|
|
|
60 |
'h2oai/h2ogpt-oig-oasst1-256-6.9b', # legacy
|
61 |
'h2oai/h2ogpt-oig-oasst1-512-6.9b', # legacy
|
62 |
'h2oai/h2ogpt-research-oasst1-512-30b',
|
63 |
+
'h2oai/h2ogpt-research-oasst1-llama-65b',
|
64 |
'h2oai/h2ogpt-oasst1-falcon-40b',
|
65 |
'h2oai/h2ogpt-oig-oasst1-falcon-40b',
|
66 |
],
|
|
|
73 |
"wizard_lm": ['ehartford/WizardLM-7B-Uncensored', 'ehartford/WizardLM-13B-Uncensored'],
|
74 |
"wizard_mega": ['openaccess-ai-collective/wizard-mega-13b'],
|
75 |
"instruct_simple": ['JosephusCheung/Guanaco'],
|
76 |
+
"wizard_vicuna": ['ehartford/Wizard-Vicuna-13B-Uncensored'],
|
77 |
+
"wizard2": ['llama', 'mosaicml/mpt-30b-instruct'],
|
78 |
+
"vicuna11": ['lmsys/vicuna-33b-v1.3'],
|
79 |
+
# could be plain, but default is correct prompt_type for default TheBloke model ggml-wizardLM-7B.q4_2.bin
|
80 |
}
|
81 |
+
if os.getenv('OPENAI_API_KEY'):
|
82 |
+
prompt_type_to_model_name.update({
|
83 |
+
"openai": ["text-davinci-003", "text-curie-001", "text-babbage-001", "text-ada-001"],
|
84 |
+
"openai_chat": ["gpt-3.5-turbo", "gpt-3.5-turbo-16k"],
|
85 |
+
})
|
86 |
|
87 |
inv_prompt_type_to_model_name = {v.strip(): k for k, l in prompt_type_to_model_name.items() for v in l}
|
88 |
inv_prompt_type_to_model_lower = {v.strip().lower(): k for k, l in prompt_type_to_model_name.items() for v in l}
|
|
|
96 |
prompt_types.extend([p.name, p.value, str(p.value)])
|
97 |
|
98 |
|
99 |
+
def get_prompt(prompt_type, prompt_dict, chat, context, reduced, making_context, return_dict=False):
|
100 |
prompt_dict_error = ''
|
101 |
+
generates_leading_space = False
|
102 |
+
|
103 |
if prompt_type == PromptType.custom.name and not isinstance(prompt_dict, dict):
|
104 |
try:
|
105 |
prompt_dict = ast.literal_eval(prompt_dict)
|
106 |
except BaseException as e:
|
107 |
prompt_dict_error = str(e)
|
108 |
+
if prompt_dict_error:
|
109 |
+
promptA = None
|
110 |
+
promptB = None
|
111 |
+
PreInstruct = None
|
112 |
+
PreInput = ''
|
113 |
+
PreResponse = ''
|
114 |
+
terminate_response = None
|
115 |
+
chat_sep = ''
|
116 |
+
chat_turn_sep = ''
|
117 |
+
humanstr = ''
|
118 |
+
botstr = ''
|
119 |
+
generates_leading_space = False
|
120 |
+
elif prompt_type in [PromptType.custom.value, str(PromptType.custom.value),
|
121 |
+
PromptType.custom.name]:
|
122 |
promptA = prompt_dict.get('promptA', '')
|
123 |
promptB = prompt_dict('promptB', '')
|
124 |
PreInstruct = prompt_dict.get('PreInstruct', '')
|
|
|
126 |
PreResponse = prompt_dict.get('PreResponse', '')
|
127 |
terminate_response = prompt_dict.get('terminate_response', None)
|
128 |
chat_sep = prompt_dict.get('chat_sep', '\n')
|
129 |
+
chat_turn_sep = prompt_dict.get('chat_turn_sep', '\n')
|
130 |
humanstr = prompt_dict.get('humanstr', '')
|
131 |
botstr = prompt_dict.get('botstr', '')
|
132 |
elif prompt_type in [PromptType.plain.value, str(PromptType.plain.value),
|
133 |
PromptType.plain.name]:
|
134 |
+
promptA = promptB = PreInstruct = PreInput = PreResponse = None
|
135 |
terminate_response = []
|
136 |
+
chat_turn_sep = chat_sep = ''
|
137 |
+
# plain should have None for human/bot, so nothing truncated out, not '' that would truncate after first token
|
138 |
+
humanstr = None
|
139 |
+
botstr = None
|
140 |
elif prompt_type == 'simple_instruct':
|
141 |
promptA = promptB = PreInstruct = PreInput = PreResponse = None
|
142 |
terminate_response = []
|
143 |
+
chat_turn_sep = chat_sep = '\n'
|
144 |
+
humanstr = None
|
145 |
+
botstr = None
|
146 |
elif prompt_type in [PromptType.instruct.value, str(PromptType.instruct.value),
|
147 |
PromptType.instruct.name] + [PromptType.instruct_with_end.value,
|
148 |
str(PromptType.instruct_with_end.value),
|
|
|
168 |
terminate_response = ['### End']
|
169 |
else:
|
170 |
terminate_response = None
|
171 |
+
chat_turn_sep = chat_sep = '\n'
|
172 |
humanstr = PreInstruct
|
173 |
botstr = PreResponse
|
174 |
elif prompt_type in [PromptType.quality.value, str(PromptType.quality.value),
|
|
|
190 |
### Response:
|
191 |
"""
|
192 |
terminate_response = None
|
193 |
+
chat_turn_sep = chat_sep = '\n'
|
194 |
humanstr = PreInstruct # first thing human says
|
195 |
botstr = PreResponse # first thing bot says
|
196 |
elif prompt_type in [PromptType.human_bot.value, str(PromptType.human_bot.value),
|
|
|
212 |
|
213 |
"""
|
214 |
preprompt = PRE_PROMPT.format(cur_date, cur_time)
|
215 |
+
start = ''
|
216 |
+
promptB = promptA = '%s%s' % (preprompt, start)
|
217 |
|
218 |
+
PreInstruct = human + ' '
|
219 |
|
220 |
PreInput = None
|
221 |
|
222 |
+
if making_context:
|
223 |
# when making context, want it to appear as-if LLM generated, which starts with space after :
|
224 |
PreResponse = bot + ' '
|
225 |
else:
|
|
|
227 |
# if add space here, non-unique tokenization will often make LLM produce wrong output
|
228 |
PreResponse = bot
|
229 |
|
230 |
+
terminate_response = ['\n' + human, '\n' + bot, human, bot, PreResponse]
|
231 |
+
chat_turn_sep = chat_sep = '\n'
|
232 |
humanstr = human # tag before human talks
|
233 |
botstr = bot # tag before bot talks
|
234 |
+
generates_leading_space = True
|
235 |
elif prompt_type in [PromptType.dai_faq.value, str(PromptType.dai_faq.value),
|
236 |
PromptType.dai_faq.name]:
|
237 |
promptA = ''
|
|
|
247 |
### Driverless AI documentation answer:
|
248 |
"""
|
249 |
terminate_response = ['\n\n']
|
250 |
+
chat_turn_sep = chat_sep = terminate_response
|
251 |
humanstr = PreInstruct
|
252 |
botstr = PreResponse
|
253 |
elif prompt_type in [PromptType.summarize.value, str(PromptType.summarize.value),
|
|
|
256 |
PreInstruct = '## Main Text\n\n'
|
257 |
PreResponse = '\n\n## Summary\n\n'
|
258 |
terminate_response = None
|
259 |
+
chat_turn_sep = chat_sep = '\n'
|
260 |
humanstr = PreInstruct
|
261 |
botstr = PreResponse
|
262 |
elif prompt_type in [PromptType.instruct_vicuna.value, str(PromptType.instruct_vicuna.value),
|
|
|
276 |
"""
|
277 |
terminate_response = [
|
278 |
'### Human:'] # but only allow terminate after prompt is found correctly, else can't terminate
|
279 |
+
chat_turn_sep = chat_sep = '\n'
|
280 |
humanstr = PreInstruct
|
281 |
botstr = PreResponse
|
282 |
elif prompt_type in [PromptType.prompt_answer.value, str(PromptType.prompt_answer.value),
|
|
|
284 |
preprompt = ''
|
285 |
prompt_tokens = "<|prompt|>"
|
286 |
answer_tokens = "<|answer|>"
|
287 |
+
start = ''
|
288 |
promptB = promptA = '%s%s' % (preprompt, start)
|
289 |
+
PreInstruct = prompt_tokens
|
290 |
PreInput = None
|
291 |
PreResponse = answer_tokens
|
292 |
eos = '<|endoftext|>' # neox eos
|
|
|
|
|
293 |
humanstr = prompt_tokens
|
294 |
botstr = answer_tokens
|
295 |
+
terminate_response = [humanstr, PreResponse, eos]
|
296 |
+
chat_sep = ''
|
297 |
+
chat_turn_sep = eos
|
298 |
+
elif prompt_type in [PromptType.prompt_answer_openllama.value, str(PromptType.prompt_answer_openllama.value),
|
299 |
+
PromptType.prompt_answer_openllama.name]:
|
300 |
+
preprompt = ''
|
301 |
+
prompt_tokens = "<|prompt|>"
|
302 |
+
answer_tokens = "<|answer|>"
|
303 |
+
start = ''
|
304 |
+
promptB = promptA = '%s%s' % (preprompt, start)
|
305 |
+
PreInstruct = prompt_tokens
|
306 |
+
PreInput = None
|
307 |
+
PreResponse = answer_tokens
|
308 |
+
eos = '</s>' # llama eos
|
309 |
+
humanstr = prompt_tokens
|
310 |
+
botstr = answer_tokens
|
311 |
+
terminate_response = [humanstr, PreResponse, eos]
|
312 |
+
chat_sep = ''
|
313 |
+
chat_turn_sep = eos
|
314 |
elif prompt_type in [PromptType.open_assistant.value, str(PromptType.open_assistant.value),
|
315 |
PromptType.open_assistant.name]:
|
316 |
# From added_tokens.json
|
317 |
preprompt = ''
|
318 |
prompt_tokens = "<|prompter|>"
|
319 |
answer_tokens = "<|assistant|>"
|
320 |
+
start = ''
|
321 |
promptB = promptA = '%s%s' % (preprompt, start)
|
322 |
+
PreInstruct = prompt_tokens
|
323 |
PreInput = None
|
324 |
PreResponse = answer_tokens
|
325 |
pend = "<|prefix_end|>"
|
326 |
eos = "</s>"
|
|
|
|
|
327 |
humanstr = prompt_tokens
|
328 |
botstr = answer_tokens
|
329 |
+
terminate_response = [humanstr, PreResponse, pend, eos]
|
330 |
+
chat_turn_sep = chat_sep = eos
|
331 |
elif prompt_type in [PromptType.wizard_lm.value, str(PromptType.wizard_lm.value),
|
332 |
PromptType.wizard_lm.name]:
|
333 |
# https://github.com/ehartford/WizardLM/blob/main/src/train_freeform.py
|
|
|
339 |
PreResponse = "\n\n### Response\n"
|
340 |
eos = "</s>"
|
341 |
terminate_response = [PreResponse, eos]
|
342 |
+
chat_turn_sep = chat_sep = eos
|
343 |
humanstr = promptA
|
344 |
botstr = PreResponse
|
345 |
elif prompt_type in [PromptType.wizard_mega.value, str(PromptType.wizard_mega.value),
|
|
|
355 |
### Assistant:
|
356 |
"""
|
357 |
terminate_response = [PreResponse]
|
358 |
+
chat_turn_sep = chat_sep = '\n'
|
359 |
humanstr = PreInstruct
|
360 |
botstr = PreResponse
|
361 |
elif prompt_type in [PromptType.instruct_vicuna2.value, str(PromptType.instruct_vicuna2.value),
|
362 |
PromptType.instruct_vicuna2.name]:
|
363 |
+
promptA = promptB = "" if not (chat and reduced) else ''
|
|
|
364 |
|
365 |
PreInstruct = """
|
366 |
HUMAN:
|
|
|
373 |
"""
|
374 |
terminate_response = [
|
375 |
'HUMAN:'] # but only allow terminate after prompt is found correctly, else can't terminate
|
376 |
+
chat_turn_sep = chat_sep = '\n'
|
377 |
humanstr = PreInstruct
|
378 |
botstr = PreResponse
|
379 |
elif prompt_type in [PromptType.instruct_vicuna3.value, str(PromptType.instruct_vicuna3.value),
|
380 |
PromptType.instruct_vicuna3.name]:
|
381 |
+
promptA = promptB = "" if not (chat and reduced) else ''
|
|
|
382 |
|
383 |
PreInstruct = """
|
384 |
### User:
|
|
|
391 |
"""
|
392 |
terminate_response = [
|
393 |
'### User:'] # but only allow terminate after prompt is found correctly, else can't terminate
|
394 |
+
chat_turn_sep = chat_sep = '\n'
|
395 |
humanstr = PreInstruct
|
396 |
botstr = PreResponse
|
397 |
elif prompt_type in [PromptType.wizard2.value, str(PromptType.wizard2.value),
|
398 |
PromptType.wizard2.name]:
|
399 |
# https://huggingface.co/TheBloke/WizardLM-7B-uncensored-GGML
|
400 |
+
preprompt = """Below is an instruction that describes a task. Write a response that appropriately completes the request.""" if not (
|
401 |
+
chat and reduced) else ''
|
402 |
start = ''
|
403 |
promptB = promptA = '%s%s' % (preprompt, start)
|
404 |
PreInstruct = """
|
|
|
409 |
### Response:
|
410 |
"""
|
411 |
terminate_response = [PreResponse]
|
412 |
+
chat_turn_sep = chat_sep = '\n'
|
413 |
humanstr = PreInstruct
|
414 |
botstr = PreResponse
|
415 |
elif prompt_type in [PromptType.wizard3.value, str(PromptType.wizard3.value),
|
416 |
PromptType.wizard3.name]:
|
417 |
# https://huggingface.co/TheBloke/wizardLM-13B-1.0-GGML
|
418 |
+
preprompt = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.""" if not (
|
419 |
+
chat and reduced) else ''
|
420 |
start = ''
|
421 |
promptB = promptA = '%s%s' % (preprompt, start)
|
422 |
PreInstruct = """USER: """
|
423 |
PreInput = None
|
424 |
PreResponse = """ASSISTANT: """
|
425 |
terminate_response = [PreResponse]
|
426 |
+
chat_turn_sep = chat_sep = '\n'
|
427 |
+
humanstr = PreInstruct
|
428 |
+
botstr = PreResponse
|
429 |
+
elif prompt_type in [PromptType.wizard_vicuna.value, str(PromptType.wizard_vicuna.value),
|
430 |
+
PromptType.wizard_vicuna.name]:
|
431 |
+
preprompt = ''
|
432 |
+
start = ''
|
433 |
+
promptB = promptA = '%s%s' % (preprompt, start)
|
434 |
+
PreInstruct = """USER: """
|
435 |
+
PreInput = None
|
436 |
+
PreResponse = """ASSISTANT: """
|
437 |
+
terminate_response = [PreResponse]
|
438 |
+
chat_turn_sep = chat_sep = '\n'
|
439 |
humanstr = PreInstruct
|
440 |
botstr = PreResponse
|
441 |
|
442 |
elif prompt_type in [PromptType.instruct_simple.value, str(PromptType.instruct_simple.value),
|
443 |
PromptType.instruct_simple.name]:
|
444 |
+
promptB = promptA = '' if not (chat and reduced) else ''
|
|
|
445 |
|
446 |
PreInstruct = """
|
447 |
### Instruction:
|
|
|
455 |
### Response:
|
456 |
"""
|
457 |
terminate_response = None
|
458 |
+
chat_turn_sep = chat_sep = '\n'
|
459 |
humanstr = PreInstruct
|
460 |
botstr = PreResponse
|
461 |
+
elif prompt_type in [PromptType.openai.value, str(PromptType.openai.value),
|
462 |
+
PromptType.openai.name]:
|
463 |
+
preprompt = """The following is a conversation with an AI assistant. The assistant is helpful, creative, clever, and very friendly.""" if not (
|
464 |
+
chat and reduced) else ''
|
465 |
+
start = ''
|
466 |
+
promptB = promptA = '%s%s' % (preprompt, start)
|
467 |
+
PreInstruct = "\nHuman: "
|
468 |
+
PreInput = None
|
469 |
+
PreResponse = "\nAI:"
|
470 |
+
terminate_response = [PreResponse] + [" Human:", " AI:"]
|
471 |
+
chat_turn_sep = chat_sep = '\n'
|
472 |
+
humanstr = PreInstruct
|
473 |
+
botstr = PreResponse
|
474 |
+
elif prompt_type in [PromptType.gptj.value, str(PromptType.gptj.value),
|
475 |
+
PromptType.gptj.name]:
|
476 |
+
preprompt = "### Instruction:\n The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response." if not (
|
477 |
+
chat and reduced) else ''
|
478 |
+
start = ''
|
479 |
+
promptB = promptA = '%s%s' % (preprompt, start)
|
480 |
+
PreInstruct = "\n### Prompt: "
|
481 |
+
PreInput = None
|
482 |
+
PreResponse = "\n### Response: "
|
483 |
+
terminate_response = [PreResponse] + ["Prompt:", "Response:"]
|
484 |
+
chat_turn_sep = chat_sep = '\n'
|
485 |
+
humanstr = PreInstruct
|
486 |
+
botstr = PreResponse
|
487 |
+
elif prompt_type in [PromptType.openai_chat.value, str(PromptType.openai_chat.value),
|
488 |
+
PromptType.openai_chat.name]:
|
489 |
+
# prompting and termination all handled by endpoint
|
490 |
+
preprompt = """"""
|
491 |
+
start = ''
|
492 |
+
promptB = promptA = '%s%s' % (preprompt, start)
|
493 |
+
PreInstruct = ""
|
494 |
+
PreInput = None
|
495 |
+
PreResponse = ""
|
496 |
+
terminate_response = []
|
497 |
+
chat_turn_sep = chat_sep = '\n'
|
498 |
+
humanstr = None
|
499 |
+
botstr = None
|
500 |
+
elif prompt_type in [PromptType.vicuna11.value, str(PromptType.vicuna11.value),
|
501 |
+
PromptType.vicuna11.name]:
|
502 |
+
preprompt = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. """ if not (
|
503 |
+
chat and reduced) else ''
|
504 |
+
start = ''
|
505 |
+
promptB = promptA = '%s%s' % (preprompt, start)
|
506 |
+
eos = '</s>'
|
507 |
+
PreInstruct = """USER: """
|
508 |
+
PreInput = None
|
509 |
+
PreResponse = """ASSISTANT:"""
|
510 |
+
terminate_response = [PreResponse]
|
511 |
+
chat_sep = ' '
|
512 |
+
chat_turn_sep = eos
|
513 |
+
humanstr = PreInstruct
|
514 |
+
botstr = PreResponse
|
515 |
+
|
516 |
+
if making_context:
|
517 |
+
# when making context, want it to appear as-if LLM generated, which starts with space after :
|
518 |
+
PreResponse = PreResponse + ' '
|
519 |
+
else:
|
520 |
+
# normally LLM adds space after this, because was how trained.
|
521 |
+
# if add space here, non-unique tokenization will often make LLM produce wrong output
|
522 |
+
PreResponse = PreResponse
|
523 |
else:
|
524 |
raise RuntimeError("No such prompt_type=%s" % prompt_type)
|
525 |
|
526 |
+
if isinstance(terminate_response, (tuple, list)):
|
527 |
+
assert '' not in terminate_response, "Bad terminate_response"
|
528 |
+
|
529 |
+
ret_dict = dict(promptA=promptA, promptB=promptB, PreInstruct=PreInstruct, PreInput=PreInput,
|
530 |
PreResponse=PreResponse, terminate_response=terminate_response, chat_sep=chat_sep,
|
531 |
+
chat_turn_sep=chat_turn_sep,
|
532 |
+
humanstr=humanstr, botstr=botstr,
|
533 |
+
generates_leading_space=generates_leading_space)
|
534 |
+
|
535 |
+
if return_dict:
|
536 |
+
return ret_dict, prompt_dict_error
|
537 |
else:
|
538 |
+
return tuple(list(ret_dict.values()))
|
539 |
|
540 |
|
541 |
+
def generate_prompt(data_point, prompt_type, prompt_dict, chat, reduced, making_context):
|
542 |
context = data_point.get('context')
|
543 |
if context is None:
|
544 |
context = ''
|
|
|
549 |
prompt_dict = data_point.get('prompt_dict', prompt_dict)
|
550 |
assert prompt_type in prompt_types, "Bad prompt type: %s" % prompt_type
|
551 |
promptA, promptB, PreInstruct, PreInput, PreResponse, \
|
552 |
+
terminate_response, chat_sep, chat_turn_sep, humanstr, botstr, \
|
553 |
+
generates_leading_space = get_prompt(prompt_type, prompt_dict, chat,
|
554 |
+
context, reduced, making_context)
|
555 |
|
556 |
+
# could avoid if reduce=True, but too complex for parent functions to handle
|
557 |
+
prompt = context
|
558 |
|
559 |
if input and promptA:
|
560 |
prompt += f"""{promptA}"""
|
|
|
563 |
|
564 |
if instruction and PreInstruct is not None and input and PreInput is not None:
|
565 |
prompt += f"""{PreInstruct}{instruction}{PreInput}{input}"""
|
566 |
+
prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
|
567 |
elif instruction and input and PreInstruct is None and PreInput is not None:
|
568 |
prompt += f"""{PreInput}{instruction}
|
569 |
{input}"""
|
570 |
+
prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
|
571 |
elif input and instruction and PreInput is None and PreInstruct is not None:
|
572 |
prompt += f"""{PreInstruct}{instruction}
|
573 |
{input}"""
|
574 |
+
prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
|
575 |
elif instruction and PreInstruct is not None:
|
576 |
prompt += f"""{PreInstruct}{instruction}"""
|
577 |
+
prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
|
578 |
elif input and PreInput is not None:
|
579 |
prompt += f"""{PreInput}{input}"""
|
580 |
+
prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
|
581 |
elif input and instruction and PreInput is not None:
|
582 |
prompt += f"""{PreInput}{instruction}{input}"""
|
583 |
+
prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
|
584 |
elif input and instruction and PreInstruct is not None:
|
585 |
prompt += f"""{PreInstruct}{instruction}{input}"""
|
586 |
+
prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
|
587 |
elif input and instruction:
|
588 |
# i.e. for simple_instruct
|
589 |
prompt += f"""{instruction}: {input}"""
|
590 |
+
prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
|
591 |
elif input:
|
592 |
prompt += f"""{input}"""
|
593 |
+
prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
|
594 |
elif instruction:
|
595 |
prompt += f"""{instruction}"""
|
596 |
+
prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep)
|
597 |
|
598 |
if PreResponse is not None:
|
599 |
prompt += f"""{PreResponse}"""
|
|
|
604 |
if output:
|
605 |
prompt += f"""{output}"""
|
606 |
|
607 |
+
return prompt, pre_response, terminate_response, chat_sep, chat_turn_sep
|
608 |
|
609 |
|
610 |
+
def inject_chatsep(prompt_type, prompt, chat_sep=None):
|
611 |
+
if chat_sep:
|
612 |
# only add new line if structured prompt, while 'plain' is just generation of next tokens from input
|
613 |
+
prompt += chat_sep
|
614 |
return prompt
|
615 |
|
616 |
|
|
|
619 |
allowed_repeat_line_length=10):
|
620 |
self.prompt_type = prompt_type
|
621 |
self.prompt_dict = prompt_dict
|
|
|
|
|
|
|
622 |
self.debug = debug
|
623 |
self.chat = chat
|
624 |
self.stream_output = stream_output
|
|
|
627 |
self.prompt = None
|
628 |
context = "" # not for chat context
|
629 |
reduced = False # not for chat context
|
630 |
+
making_context = False # not for chat context
|
631 |
self.promptA, self.promptB, self.PreInstruct, self.PreInput, self.PreResponse, \
|
632 |
+
self.terminate_response, self.chat_sep, self.chat_turn_sep, self.humanstr, self.botstr, \
|
633 |
+
self.generates_leading_space = \
|
634 |
+
get_prompt(self.prompt_type, self.prompt_dict, chat, context, reduced, making_context)
|
635 |
+
self.pre_response = self.PreResponse
|
636 |
+
|
637 |
+
def generate_prompt(self, data_point, reduced=None):
|
638 |
+
"""
|
639 |
+
data_point['context'] is assumed to be like a system prompt or pre-conversation, not inserted after user prompt
|
640 |
+
:param data_point:
|
641 |
+
:param reduced:
|
642 |
+
:return:
|
643 |
+
"""
|
644 |
+
reduced = data_point.get('context') not in ['', None] if reduced is None else reduced
|
645 |
+
making_context = False # whether really making final prompt or just generating context
|
646 |
+
prompt, _, _, _, _ = generate_prompt(data_point, self.prompt_type, self.prompt_dict, self.chat, reduced,
|
647 |
+
making_context)
|
648 |
if self.debug:
|
649 |
+
print("prompt: %s" % prompt, flush=True)
|
650 |
+
# if have context, should have always reduced and only preappend promptA/B here
|
651 |
+
if data_point.get('context'):
|
652 |
+
if data_point.get('input') and self.promptA:
|
653 |
+
prompt = self.promptA + prompt
|
654 |
+
elif self.promptB:
|
655 |
+
prompt = self.promptB + prompt
|
656 |
+
|
657 |
self.prompt = prompt
|
658 |
return prompt
|
659 |
|
660 |
+
def get_response(self, outputs, prompt=None, sanitize_bot_response=False):
|
661 |
if isinstance(outputs, str):
|
662 |
outputs = [outputs]
|
663 |
if self.debug:
|
664 |
+
print("output:\n%s" % '\n\n'.join(outputs), flush=True)
|
665 |
if prompt is not None:
|
666 |
self.prompt = prompt
|
667 |
|
|
|
672 |
if sanitize_bot_response:
|
673 |
from better_profanity import profanity
|
674 |
response = profanity.censor(response)
|
675 |
+
if self.generates_leading_space and isinstance(response, str) and len(response) > 0 and response[0] == ' ':
|
676 |
+
response = response[1:]
|
677 |
return response
|
678 |
|
679 |
def clean_repeats(response):
|
|
|
695 |
# then use most basic parsing like pipeline
|
696 |
if self.botstr in output:
|
697 |
if self.humanstr:
|
698 |
+
output = clean_response(output.split(self.botstr)[1].split(self.humanstr)[0])
|
699 |
else:
|
700 |
# i.e. use after bot but only up to next bot
|
701 |
+
output = clean_response(output.split(self.botstr)[1].split(self.botstr)[0])
|
702 |
else:
|
703 |
+
# output = clean_response(output)
|
704 |
# assume just not printed yet
|
705 |
output = ""
|
706 |
else:
|
|
|
727 |
allow_terminate = True
|
728 |
output = output[len(prompt):]
|
729 |
# clean after subtract prompt out, so correct removal of pre_response
|
730 |
+
output = clean_response(output)
|
731 |
if self.repeat_penalty:
|
732 |
+
output = clean_repeats(output)
|
733 |
if self.terminate_response and allow_terminate:
|
734 |
finds = []
|
735 |
for term in self.terminate_response:
|
|
|
737 |
finds = [x for x in finds if x >= 0]
|
738 |
if len(finds) > 0:
|
739 |
termi = finds[0]
|
740 |
+
output = output[:termi]
|
741 |
else:
|
742 |
+
output = output
|
|
|
|
|
743 |
if multi_output:
|
744 |
# prefix with output counter
|
745 |
output = "\n=========== Output %d\n\n" % (1 + oi) + output
|
|
|
750 |
# join all outputs, only one extra new line between outputs
|
751 |
output = '\n'.join(outputs)
|
752 |
if self.debug:
|
753 |
+
print("outputclean:\n%s" % '\n\n'.join(outputs), flush=True)
|
754 |
return output
|
requirements.txt
CHANGED
@@ -1,50 +1,50 @@
|
|
1 |
# for generate (gradio server) and finetune
|
2 |
-
datasets==2.
|
3 |
-
sentencepiece==0.1.
|
4 |
-
gradio==3.
|
5 |
-
huggingface_hub==0.
|
6 |
appdirs==1.4.4
|
7 |
fire==0.5.0
|
8 |
-
docutils==0.
|
9 |
torch==2.0.1
|
10 |
evaluate==0.4.0
|
11 |
rouge_score==0.1.2
|
12 |
sacrebleu==2.3.1
|
13 |
scikit-learn==1.2.2
|
14 |
alt-profanity-check==1.2.2
|
15 |
-
better-profanity==0.
|
16 |
-
numpy==1.24.
|
17 |
-
pandas==2.0.
|
18 |
matplotlib==3.7.1
|
19 |
loralib==0.1.1
|
20 |
bitsandbytes==0.39.0
|
21 |
-
accelerate==0.
|
22 |
-
git+https://github.com/huggingface/peft.git@
|
23 |
-
transformers==4.
|
24 |
tokenizers==0.13.3
|
25 |
APScheduler==3.10.1
|
26 |
|
27 |
# optional for generate
|
28 |
pynvml==11.5.0
|
29 |
-
psutil==5.9.
|
30 |
boto3==1.26.101
|
31 |
botocore==1.29.101
|
32 |
|
33 |
# optional for finetune
|
34 |
-
tensorboard==2.
|
35 |
-
neptune==1.
|
36 |
|
37 |
# for gradio client
|
38 |
-
gradio_client==0.2.
|
39 |
beautifulsoup4==4.12.2
|
40 |
-
markdown==3.4.
|
41 |
|
42 |
# data and testing
|
43 |
pytest==7.2.2
|
44 |
pytest-xdist==3.2.1
|
45 |
nltk==3.8.1
|
46 |
textstat==0.7.3
|
47 |
-
pandoc==2.3
|
48 |
#pypandoc==1.11
|
49 |
pypandoc_binary==1.11
|
50 |
openpyxl==3.1.2
|
@@ -57,17 +57,20 @@ instructorembedding==1.0.1
|
|
57 |
|
58 |
# for gpt4all .env file, but avoid worrying about imports
|
59 |
python-dotenv==1.0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
# optional for chat with PDF
|
61 |
-
langchain==0.0.
|
62 |
-
pypdf==3.
|
63 |
-
tiktoken==0.3.3
|
64 |
# avoid textract, requires old six
|
65 |
#textract==1.6.5
|
66 |
|
67 |
# for HF embeddings
|
68 |
sentence_transformers==2.2.2
|
69 |
-
# for OpenAI embeddings (requires key)
|
70 |
-
openai==0.27.6
|
71 |
|
72 |
# local vector db
|
73 |
chromadb==0.3.25
|
@@ -79,14 +82,14 @@ chromadb==0.3.25
|
|
79 |
|
80 |
# strong support for images
|
81 |
# Requires on Ubuntu: sudo apt-get install libmagic-dev poppler-utils tesseract-ocr libreoffice
|
82 |
-
unstructured[local-inference]==0.
|
83 |
#pdf2image==1.16.3
|
84 |
#pytesseract==0.3.10
|
85 |
pillow
|
86 |
|
87 |
pdfminer.six==20221105
|
88 |
-
urllib3
|
89 |
-
requests_file
|
90 |
|
91 |
#pdf2image==1.16.3
|
92 |
#pytesseract==0.3.10
|
@@ -101,18 +104,15 @@ tabulate==0.9.0
|
|
101 |
pip-licenses==4.3.0
|
102 |
|
103 |
# weaviate vector db
|
104 |
-
weaviate-client==3.
|
105 |
# optional for chat with PDF
|
106 |
-
langchain==0.0.
|
107 |
-
pypdf==3.
|
108 |
-
tiktoken==0.3.3
|
109 |
# avoid textract, requires old six
|
110 |
#textract==1.6.5
|
111 |
|
112 |
# for HF embeddings
|
113 |
sentence_transformers==2.2.2
|
114 |
-
# for OpenAI embeddings (requires key)
|
115 |
-
openai==0.27.6
|
116 |
|
117 |
# local vector db
|
118 |
chromadb==0.3.25
|
@@ -124,14 +124,14 @@ chromadb==0.3.25
|
|
124 |
|
125 |
# strong support for images
|
126 |
# Requires on Ubuntu: sudo apt-get install libmagic-dev poppler-utils tesseract-ocr libreoffice
|
127 |
-
unstructured[local-inference]==0.
|
128 |
#pdf2image==1.16.3
|
129 |
#pytesseract==0.3.10
|
130 |
pillow
|
131 |
|
132 |
pdfminer.six==20221105
|
133 |
-
urllib3
|
134 |
-
requests_file
|
135 |
|
136 |
#pdf2image==1.16.3
|
137 |
#pytesseract==0.3.10
|
@@ -146,7 +146,7 @@ tabulate==0.9.0
|
|
146 |
pip-licenses==4.3.0
|
147 |
|
148 |
# weaviate vector db
|
149 |
-
weaviate-client==3.
|
150 |
faiss-gpu==1.7.2
|
151 |
arxiv==1.4.7
|
152 |
pymupdf==1.22.3 # AGPL license
|
|
|
1 |
# for generate (gradio server) and finetune
|
2 |
+
datasets==2.13.0
|
3 |
+
sentencepiece==0.1.99
|
4 |
+
gradio==3.35.2
|
5 |
+
huggingface_hub==0.15.1
|
6 |
appdirs==1.4.4
|
7 |
fire==0.5.0
|
8 |
+
docutils==0.20.1
|
9 |
torch==2.0.1
|
10 |
evaluate==0.4.0
|
11 |
rouge_score==0.1.2
|
12 |
sacrebleu==2.3.1
|
13 |
scikit-learn==1.2.2
|
14 |
alt-profanity-check==1.2.2
|
15 |
+
better-profanity==0.7.0
|
16 |
+
numpy==1.24.3
|
17 |
+
pandas==2.0.2
|
18 |
matplotlib==3.7.1
|
19 |
loralib==0.1.1
|
20 |
bitsandbytes==0.39.0
|
21 |
+
accelerate==0.20.3
|
22 |
+
git+https://github.com/huggingface/peft.git@0b62b4378b4ce9367932c73540349da9a41bdea8
|
23 |
+
transformers==4.30.2
|
24 |
tokenizers==0.13.3
|
25 |
APScheduler==3.10.1
|
26 |
|
27 |
# optional for generate
|
28 |
pynvml==11.5.0
|
29 |
+
psutil==5.9.5
|
30 |
boto3==1.26.101
|
31 |
botocore==1.29.101
|
32 |
|
33 |
# optional for finetune
|
34 |
+
tensorboard==2.13.0
|
35 |
+
neptune==1.2.0
|
36 |
|
37 |
# for gradio client
|
38 |
+
gradio_client==0.2.7
|
39 |
beautifulsoup4==4.12.2
|
40 |
+
markdown==3.4.3
|
41 |
|
42 |
# data and testing
|
43 |
pytest==7.2.2
|
44 |
pytest-xdist==3.2.1
|
45 |
nltk==3.8.1
|
46 |
textstat==0.7.3
|
47 |
+
# pandoc==2.3
|
48 |
#pypandoc==1.11
|
49 |
pypandoc_binary==1.11
|
50 |
openpyxl==3.1.2
|
|
|
57 |
|
58 |
# for gpt4all .env file, but avoid worrying about imports
|
59 |
python-dotenv==1.0.0
|
60 |
+
|
61 |
+
text-generation==0.6.0
|
62 |
+
# for tokenization when don't have HF tokenizer
|
63 |
+
tiktoken==0.4.0
|
64 |
+
# optional: for OpenAI endpoint or embeddings (requires key)
|
65 |
+
openai==0.27.8
|
66 |
# optional for chat with PDF
|
67 |
+
langchain==0.0.202
|
68 |
+
pypdf==3.9.1
|
|
|
69 |
# avoid textract, requires old six
|
70 |
#textract==1.6.5
|
71 |
|
72 |
# for HF embeddings
|
73 |
sentence_transformers==2.2.2
|
|
|
|
|
74 |
|
75 |
# local vector db
|
76 |
chromadb==0.3.25
|
|
|
82 |
|
83 |
# strong support for images
|
84 |
# Requires on Ubuntu: sudo apt-get install libmagic-dev poppler-utils tesseract-ocr libreoffice
|
85 |
+
unstructured[local-inference]==0.7.4
|
86 |
#pdf2image==1.16.3
|
87 |
#pytesseract==0.3.10
|
88 |
pillow
|
89 |
|
90 |
pdfminer.six==20221105
|
91 |
+
urllib3
|
92 |
+
requests_file
|
93 |
|
94 |
#pdf2image==1.16.3
|
95 |
#pytesseract==0.3.10
|
|
|
104 |
pip-licenses==4.3.0
|
105 |
|
106 |
# weaviate vector db
|
107 |
+
weaviate-client==3.20.0
|
108 |
# optional for chat with PDF
|
109 |
+
langchain==0.0.202
|
110 |
+
pypdf==3.9.1
|
|
|
111 |
# avoid textract, requires old six
|
112 |
#textract==1.6.5
|
113 |
|
114 |
# for HF embeddings
|
115 |
sentence_transformers==2.2.2
|
|
|
|
|
116 |
|
117 |
# local vector db
|
118 |
chromadb==0.3.25
|
|
|
124 |
|
125 |
# strong support for images
|
126 |
# Requires on Ubuntu: sudo apt-get install libmagic-dev poppler-utils tesseract-ocr libreoffice
|
127 |
+
unstructured[local-inference]==0.7.4
|
128 |
#pdf2image==1.16.3
|
129 |
#pytesseract==0.3.10
|
130 |
pillow
|
131 |
|
132 |
pdfminer.six==20221105
|
133 |
+
urllib3
|
134 |
+
requests_file
|
135 |
|
136 |
#pdf2image==1.16.3
|
137 |
#pytesseract==0.3.10
|
|
|
146 |
pip-licenses==4.3.0
|
147 |
|
148 |
# weaviate vector db
|
149 |
+
weaviate-client==3.20.0
|
150 |
faiss-gpu==1.7.2
|
151 |
arxiv==1.4.7
|
152 |
pymupdf==1.22.3 # AGPL license
|
stopping.py
CHANGED
@@ -1,17 +1,18 @@
|
|
1 |
import torch
|
2 |
from transformers import StoppingCriteria, StoppingCriteriaList
|
3 |
|
4 |
-
from
|
5 |
|
6 |
|
7 |
class StoppingCriteriaSub(StoppingCriteria):
|
8 |
|
9 |
-
def __init__(self, stops=[], encounters=[], device="cuda"):
|
10 |
super().__init__()
|
11 |
assert len(stops) % len(encounters) == 0, "Number of stops and encounters must match"
|
12 |
self.encounters = encounters
|
13 |
self.stops = [stop.to(device) for stop in stops]
|
14 |
self.num_stops = [0] * len(stops)
|
|
|
15 |
|
16 |
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
17 |
for stopi, stop in enumerate(self.stops):
|
@@ -20,12 +21,15 @@ class StoppingCriteriaSub(StoppingCriteria):
|
|
20 |
if self.num_stops[stopi] >= self.encounters[stopi % len(self.encounters)]:
|
21 |
# print("Stopped", flush=True)
|
22 |
return True
|
|
|
|
|
|
|
23 |
# print("Tokens: %s" % input_ids[0].cpu().numpy(), flush=True)
|
24 |
# print("Stop Tokens: %s" % [x.cpu().numpy() for x in self.stops], flush=True)
|
25 |
return False
|
26 |
|
27 |
|
28 |
-
def get_stopping(prompt_type, prompt_dict, tokenizer, device, human='<human>:', bot="<bot>:"):
|
29 |
# FIXME: prompt_dict unused currently
|
30 |
if prompt_type in [PromptType.human_bot.name, PromptType.instruct_vicuna.name, PromptType.instruct_with_end.name]:
|
31 |
if prompt_type == PromptType.human_bot.name:
|
@@ -67,7 +71,8 @@ def get_stopping(prompt_type, prompt_dict, tokenizer, device, human='<human>:',
|
|
67 |
stop_words_ids = [x[1:] if y[0] == '\n' else x for x, y in zip(stop_words_ids, stop_words)]
|
68 |
# build stopper
|
69 |
stopping_criteria = StoppingCriteriaList(
|
70 |
-
[StoppingCriteriaSub(stops=stop_words_ids, encounters=encounters, device=device
|
|
|
71 |
else:
|
72 |
stopping_criteria = StoppingCriteriaList()
|
73 |
return stopping_criteria
|
|
|
1 |
import torch
|
2 |
from transformers import StoppingCriteria, StoppingCriteriaList
|
3 |
|
4 |
+
from enums import PromptType
|
5 |
|
6 |
|
7 |
class StoppingCriteriaSub(StoppingCriteria):
|
8 |
|
9 |
+
def __init__(self, stops=[], encounters=[], device="cuda", model_max_length=None):
|
10 |
super().__init__()
|
11 |
assert len(stops) % len(encounters) == 0, "Number of stops and encounters must match"
|
12 |
self.encounters = encounters
|
13 |
self.stops = [stop.to(device) for stop in stops]
|
14 |
self.num_stops = [0] * len(stops)
|
15 |
+
self.model_max_length = model_max_length
|
16 |
|
17 |
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
|
18 |
for stopi, stop in enumerate(self.stops):
|
|
|
21 |
if self.num_stops[stopi] >= self.encounters[stopi % len(self.encounters)]:
|
22 |
# print("Stopped", flush=True)
|
23 |
return True
|
24 |
+
if self.model_max_length is not None and input_ids[0].shape[0] >= self.model_max_length:
|
25 |
+
# critical limit
|
26 |
+
return True
|
27 |
# print("Tokens: %s" % input_ids[0].cpu().numpy(), flush=True)
|
28 |
# print("Stop Tokens: %s" % [x.cpu().numpy() for x in self.stops], flush=True)
|
29 |
return False
|
30 |
|
31 |
|
32 |
+
def get_stopping(prompt_type, prompt_dict, tokenizer, device, human='<human>:', bot="<bot>:", model_max_length=None):
|
33 |
# FIXME: prompt_dict unused currently
|
34 |
if prompt_type in [PromptType.human_bot.name, PromptType.instruct_vicuna.name, PromptType.instruct_with_end.name]:
|
35 |
if prompt_type == PromptType.human_bot.name:
|
|
|
71 |
stop_words_ids = [x[1:] if y[0] == '\n' else x for x, y in zip(stop_words_ids, stop_words)]
|
72 |
# build stopper
|
73 |
stopping_criteria = StoppingCriteriaList(
|
74 |
+
[StoppingCriteriaSub(stops=stop_words_ids, encounters=encounters, device=device,
|
75 |
+
model_max_length=model_max_length)])
|
76 |
else:
|
77 |
stopping_criteria = StoppingCriteriaList()
|
78 |
return stopping_criteria
|
utils.py
CHANGED
@@ -69,6 +69,25 @@ def ping():
|
|
69 |
pass
|
70 |
|
71 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
def get_torch_allocated():
|
73 |
import torch
|
74 |
return torch.cuda.memory_allocated()
|
@@ -98,27 +117,29 @@ def system_info():
|
|
98 |
system['CPU_C/%s' % k] = v
|
99 |
|
100 |
# https://github.com/gpuopenanalytics/pynvml/blob/master/help_query_gpu.txt
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
|
|
|
|
122 |
system['hash'] = get_githash()
|
123 |
|
124 |
return system
|
@@ -167,35 +188,39 @@ def _zip_data(root_dirs=None, zip_file=None, base_dir='./'):
|
|
167 |
return zip_file, zip_file
|
168 |
|
169 |
|
170 |
-
def save_generate_output(output=None, base_model=None, save_dir=None
|
|
|
171 |
try:
|
172 |
-
return _save_generate_output(output=output, base_model=base_model, save_dir=save_dir
|
|
|
173 |
except Exception as e:
|
174 |
traceback.print_exc()
|
175 |
print('Exception in saving: %s' % str(e))
|
176 |
|
177 |
|
178 |
-
def _save_generate_output(output=None, base_model=None, save_dir=None
|
|
|
179 |
"""
|
180 |
Save conversation to .json, row by row.
|
181 |
json_file_path is path to final JSON file. If not in ., then will attempt to make directories.
|
182 |
Appends if file exists
|
183 |
"""
|
|
|
|
|
184 |
assert save_dir, "save_dir must be provided"
|
185 |
if os.path.exists(save_dir) and not os.path.isdir(save_dir):
|
186 |
raise RuntimeError("save_dir already exists and is not a directory!")
|
187 |
os.makedirs(save_dir, exist_ok=True)
|
188 |
import json
|
189 |
-
|
190 |
-
|
191 |
-
output = output[:-10]
|
192 |
with filelock.FileLock("save_dir.lock"):
|
193 |
# lock logging in case have concurrency
|
194 |
with open(os.path.join(save_dir, "history.json"), "a") as f:
|
195 |
# just add [ at start, and ] at end, and have proper JSON dataset
|
196 |
f.write(
|
197 |
" " + json.dumps(
|
198 |
-
|
199 |
) + ",\n"
|
200 |
)
|
201 |
|
@@ -801,6 +826,7 @@ def get_kwargs(func, exclude_names=None, **kwargs):
|
|
801 |
|
802 |
|
803 |
import pkg_resources
|
|
|
804 |
have_faiss = False
|
805 |
|
806 |
try:
|
@@ -828,7 +854,7 @@ def hash_file(file):
|
|
828 |
BUF_SIZE = 65536 # lets read stuff in 64kb chunks!
|
829 |
|
830 |
md5 = hashlib.md5()
|
831 |
-
#sha1 = hashlib.sha1()
|
832 |
|
833 |
with open(file, 'rb') as f:
|
834 |
while True:
|
@@ -836,7 +862,7 @@ def hash_file(file):
|
|
836 |
if not data:
|
837 |
break
|
838 |
md5.update(data)
|
839 |
-
#sha1.update(data)
|
840 |
except BaseException as e:
|
841 |
print("Cannot hash %s due to %s" % (file, str(e)))
|
842 |
traceback.print_exc()
|
@@ -848,8 +874,55 @@ def start_faulthandler():
|
|
848 |
# If hit server or any subprocess with signal SIGUSR1, it'll print out all threads stack trace, but wont't quit or coredump
|
849 |
# If more than one fork tries to write at same time, then looks corrupted.
|
850 |
import faulthandler
|
851 |
-
import signal
|
852 |
|
853 |
# SIGUSR1 in h2oai/__init__.py as well
|
854 |
faulthandler.enable()
|
855 |
-
faulthandler
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
pass
|
70 |
|
71 |
|
72 |
+
def ping_gpu():
|
73 |
+
try:
|
74 |
+
print('Ping_GPU: %s %s' % (str(datetime.now()), system_info()), flush=True)
|
75 |
+
except AttributeError:
|
76 |
+
# some programs wrap print and will fail with flush passed
|
77 |
+
pass
|
78 |
+
try:
|
79 |
+
ping_gpu_memory()
|
80 |
+
except Exception as e:
|
81 |
+
print('Ping_GPU memory failure: %s' % str(e), flush=True)
|
82 |
+
|
83 |
+
|
84 |
+
def ping_gpu_memory():
|
85 |
+
from models.gpu_mem_track import MemTracker
|
86 |
+
gpu_tracker = MemTracker() # define a GPU tracker
|
87 |
+
from torch.cuda import memory_summary
|
88 |
+
gpu_tracker.track()
|
89 |
+
|
90 |
+
|
91 |
def get_torch_allocated():
|
92 |
import torch
|
93 |
return torch.cuda.memory_allocated()
|
|
|
117 |
system['CPU_C/%s' % k] = v
|
118 |
|
119 |
# https://github.com/gpuopenanalytics/pynvml/blob/master/help_query_gpu.txt
|
120 |
+
try:
|
121 |
+
from pynvml.smi import nvidia_smi
|
122 |
+
nvsmi = nvidia_smi.getInstance()
|
123 |
+
|
124 |
+
gpu_power_dict = {'W_gpu%d' % i: x['power_readings']['power_draw'] for i, x in
|
125 |
+
enumerate(nvsmi.DeviceQuery('power.draw')['gpu'])}
|
126 |
+
for k, v in gpu_power_dict.items():
|
127 |
+
system['GPU_W/%s' % k] = v
|
128 |
+
|
129 |
+
gpu_temp_dict = {'C_gpu%d' % i: x['temperature']['gpu_temp'] for i, x in
|
130 |
+
enumerate(nvsmi.DeviceQuery('temperature.gpu')['gpu'])}
|
131 |
+
for k, v in gpu_temp_dict.items():
|
132 |
+
system['GPU_C/%s' % k] = v
|
133 |
+
|
134 |
+
gpu_memory_free_dict = {'MiB_gpu%d' % i: x['fb_memory_usage']['free'] for i, x in
|
135 |
+
enumerate(nvsmi.DeviceQuery('memory.free')['gpu'])}
|
136 |
+
gpu_memory_total_dict = {'MiB_gpu%d' % i: x['fb_memory_usage']['total'] for i, x in
|
137 |
+
enumerate(nvsmi.DeviceQuery('memory.total')['gpu'])}
|
138 |
+
gpu_memory_frac_dict = {k: gpu_memory_free_dict[k] / gpu_memory_total_dict[k] for k in gpu_memory_total_dict}
|
139 |
+
for k, v in gpu_memory_frac_dict.items():
|
140 |
+
system[f'GPU_M/%s' % k] = v
|
141 |
+
except ModuleNotFoundError:
|
142 |
+
pass
|
143 |
system['hash'] = get_githash()
|
144 |
|
145 |
return system
|
|
|
188 |
return zip_file, zip_file
|
189 |
|
190 |
|
191 |
+
def save_generate_output(prompt=None, output=None, base_model=None, save_dir=None, where_from='unknown where from',
|
192 |
+
extra_dict={}):
|
193 |
try:
|
194 |
+
return _save_generate_output(prompt=prompt, output=output, base_model=base_model, save_dir=save_dir,
|
195 |
+
where_from=where_from, extra_dict=extra_dict)
|
196 |
except Exception as e:
|
197 |
traceback.print_exc()
|
198 |
print('Exception in saving: %s' % str(e))
|
199 |
|
200 |
|
201 |
+
def _save_generate_output(prompt=None, output=None, base_model=None, save_dir=None, where_from='unknown where from',
|
202 |
+
extra_dict={}):
|
203 |
"""
|
204 |
Save conversation to .json, row by row.
|
205 |
json_file_path is path to final JSON file. If not in ., then will attempt to make directories.
|
206 |
Appends if file exists
|
207 |
"""
|
208 |
+
prompt = '<not set>' if prompt is None else prompt
|
209 |
+
output = '<not set>' if output is None else output
|
210 |
assert save_dir, "save_dir must be provided"
|
211 |
if os.path.exists(save_dir) and not os.path.isdir(save_dir):
|
212 |
raise RuntimeError("save_dir already exists and is not a directory!")
|
213 |
os.makedirs(save_dir, exist_ok=True)
|
214 |
import json
|
215 |
+
dict_to_save = dict(prompt=prompt, text=output, time=time.ctime(), base_model=base_model, where_from=where_from)
|
216 |
+
dict_to_save.update(extra_dict)
|
|
|
217 |
with filelock.FileLock("save_dir.lock"):
|
218 |
# lock logging in case have concurrency
|
219 |
with open(os.path.join(save_dir, "history.json"), "a") as f:
|
220 |
# just add [ at start, and ] at end, and have proper JSON dataset
|
221 |
f.write(
|
222 |
" " + json.dumps(
|
223 |
+
dict_to_save
|
224 |
) + ",\n"
|
225 |
)
|
226 |
|
|
|
826 |
|
827 |
|
828 |
import pkg_resources
|
829 |
+
|
830 |
have_faiss = False
|
831 |
|
832 |
try:
|
|
|
854 |
BUF_SIZE = 65536 # lets read stuff in 64kb chunks!
|
855 |
|
856 |
md5 = hashlib.md5()
|
857 |
+
# sha1 = hashlib.sha1()
|
858 |
|
859 |
with open(file, 'rb') as f:
|
860 |
while True:
|
|
|
862 |
if not data:
|
863 |
break
|
864 |
md5.update(data)
|
865 |
+
# sha1.update(data)
|
866 |
except BaseException as e:
|
867 |
print("Cannot hash %s due to %s" % (file, str(e)))
|
868 |
traceback.print_exc()
|
|
|
874 |
# If hit server or any subprocess with signal SIGUSR1, it'll print out all threads stack trace, but wont't quit or coredump
|
875 |
# If more than one fork tries to write at same time, then looks corrupted.
|
876 |
import faulthandler
|
|
|
877 |
|
878 |
# SIGUSR1 in h2oai/__init__.py as well
|
879 |
faulthandler.enable()
|
880 |
+
if hasattr(faulthandler, 'register'):
|
881 |
+
# windows/mac
|
882 |
+
import signal
|
883 |
+
faulthandler.register(signal.SIGUSR1)
|
884 |
+
|
885 |
+
|
886 |
+
def get_hf_server(inference_server):
|
887 |
+
inf_split = inference_server.split(" ")
|
888 |
+
assert len(inf_split) == 1 or len(inf_split) == 3
|
889 |
+
inference_server = inf_split[0]
|
890 |
+
if len(inf_split) == 3:
|
891 |
+
headers = {"authorization": "%s %s" % (inf_split[1], inf_split[2])}
|
892 |
+
else:
|
893 |
+
headers = None
|
894 |
+
return inference_server, headers
|
895 |
+
|
896 |
+
|
897 |
+
class FakeTokenizer:
|
898 |
+
"""
|
899 |
+
1) For keeping track of model_max_length
|
900 |
+
2) For when model doesn't directly expose tokenizer but need to count tokens
|
901 |
+
"""
|
902 |
+
|
903 |
+
def __init__(self, model_max_length=2048, encoding_name="cl100k_base"):
|
904 |
+
# dont' push limit, since if using fake tokenizer, only estimate, and seen underestimates by order 250
|
905 |
+
self.model_max_length = model_max_length - 250
|
906 |
+
self.encoding_name = encoding_name
|
907 |
+
# The first time this runs, it will require an internet connection to download. Later runs won't need an internet connection.
|
908 |
+
import tiktoken
|
909 |
+
self.encoding = tiktoken.get_encoding(self.encoding_name)
|
910 |
+
|
911 |
+
def encode(self, x, *args, return_tensors="pt", **kwargs):
|
912 |
+
input_ids = self.encoding.encode(x, disallowed_special=())
|
913 |
+
if return_tensors == 'pt' and isinstance(input_ids, list):
|
914 |
+
import torch
|
915 |
+
input_ids = torch.tensor(input_ids)
|
916 |
+
return dict(input_ids=input_ids)
|
917 |
+
|
918 |
+
def decode(self, x, *args, **kwargs):
|
919 |
+
# input is input_ids[0] form
|
920 |
+
return self.encoding.decode(x)
|
921 |
+
|
922 |
+
def num_tokens_from_string(self, prompt: str) -> int:
|
923 |
+
"""Returns the number of tokens in a text string."""
|
924 |
+
num_tokens = len(self.encoding.encode(prompt))
|
925 |
+
return num_tokens
|
926 |
+
|
927 |
+
def __call__(self, x, *args, **kwargs):
|
928 |
+
return self.encode(x, *args, **kwargs)
|
utils_langchain.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Dict, List, Union, Optional
|
2 |
+
import time
|
3 |
+
import queue
|
4 |
+
|
5 |
+
from langchain.callbacks.base import BaseCallbackHandler
|
6 |
+
from langchain.schema import LLMResult
|
7 |
+
|
8 |
+
|
9 |
+
class StreamingGradioCallbackHandler(BaseCallbackHandler):
|
10 |
+
"""
|
11 |
+
Similar to H2OTextIteratorStreamer that is for HF backend, but here LangChain backend
|
12 |
+
"""
|
13 |
+
def __init__(self, timeout: Optional[float] = None, block=True):
|
14 |
+
super().__init__()
|
15 |
+
self.text_queue = queue.SimpleQueue()
|
16 |
+
self.stop_signal = None
|
17 |
+
self.do_stop = False
|
18 |
+
self.timeout = timeout
|
19 |
+
self.block = block
|
20 |
+
|
21 |
+
def on_llm_start(
|
22 |
+
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
23 |
+
) -> None:
|
24 |
+
"""Run when LLM starts running. Clean the queue."""
|
25 |
+
while not self.text_queue.empty():
|
26 |
+
try:
|
27 |
+
self.text_queue.get(block=False)
|
28 |
+
except queue.Empty:
|
29 |
+
continue
|
30 |
+
|
31 |
+
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
32 |
+
"""Run on new LLM token. Only available when streaming is enabled."""
|
33 |
+
self.text_queue.put(token)
|
34 |
+
|
35 |
+
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
36 |
+
"""Run when LLM ends running."""
|
37 |
+
self.text_queue.put(self.stop_signal)
|
38 |
+
|
39 |
+
def on_llm_error(
|
40 |
+
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
41 |
+
) -> None:
|
42 |
+
"""Run when LLM errors."""
|
43 |
+
self.text_queue.put(self.stop_signal)
|
44 |
+
|
45 |
+
def __iter__(self):
|
46 |
+
return self
|
47 |
+
|
48 |
+
def __next__(self):
|
49 |
+
while True:
|
50 |
+
try:
|
51 |
+
value = self.stop_signal # value looks unused in pycharm, not true
|
52 |
+
if self.do_stop:
|
53 |
+
print("hit stop", flush=True)
|
54 |
+
# could raise or break, maybe best to raise and make parent see if any exception in thread
|
55 |
+
raise StopIteration()
|
56 |
+
# break
|
57 |
+
value = self.text_queue.get(block=self.block, timeout=self.timeout)
|
58 |
+
break
|
59 |
+
except queue.Empty:
|
60 |
+
time.sleep(0.01)
|
61 |
+
if value == self.stop_signal:
|
62 |
+
raise StopIteration()
|
63 |
+
else:
|
64 |
+
return value
|