upload?
- app.py +250 -0
- lemur-7B/config.json +23 -0
- lemur-7B/generation_config.json +7 -0
- lemur-7B/pytorch_model.bin +3 -0
- lemur-7B/special_tokens_map.json +24 -0
- lemur-7B/tokenizer.json +0 -0
- lemur-7B/tokenizer.model +3 -0
- lemur-7B/tokenizer_config.json +31 -0
- utils/gradio.py +71 -0
- utils/inference.py +107 -0
- variables.py +5 -0
app.py
ADDED
@@ -0,0 +1,250 @@
import os
import re


import torch
import gradio as gr
import logging

from utils.inference import load_tokenizer_and_model, decode, \
    get_prompt_with_history, is_stop_word_or_prefix

from utils.gradio import reset_textbox, cancel_outputing, transfer_input, \
    delete_last_conversation, reset_state, convert_to_markdown


# set variables
model = "lemur-7B"

print("Loading model...")

import time

start = time.time()

tokenizer, model, device = load_tokenizer_and_model(model, load_8bit=True)

print("Model loaded in {} seconds.".format(time.time() - start))


def predict(
    text,
    chatbot,
    history,
    top_p,
    temperature,
    max_length_tokens,
    max_context_length_tokens,
):
    if text == "":
        yield chatbot, history, "Empty context."
        return

    inputs = get_prompt_with_history(
        text, history, tokenizer, max_length=max_context_length_tokens
    )
    if inputs is None:
        yield chatbot, history, "Input too long."
        return
    else:
        prompt, inputs = inputs

    input_ids = inputs["input_ids"][:, -max_context_length_tokens:].to(device)
    torch.cuda.empty_cache()

    with torch.no_grad():
        for x in decode(
            input_ids,
            model,
            tokenizer,
            stop_words=["[Human]", "[AI]"],
            max_length=max_length_tokens,
            temperature=temperature,
            top_p=top_p,
        ):
            if is_stop_word_or_prefix(x, ["[Human]", "[AI]"]) is False:
                if "[Human]" in x:
                    x = x[: x.index("[Human]")].strip()
                if "[AI]" in x:
                    x = x[: x.index("[AI]")].strip()
                x = x.strip(" ")
                a, b = [[y[0], convert_to_markdown(y[1])] for y in history] + [
                    [text, convert_to_markdown(x)]
                ], history + [[text, x]]
                yield a, b, "Generating..."

    torch.cuda.empty_cache()
    print(prompt)
    print(x)
    print("=" * 80)
    try:
        yield a, b, "Generate: Success"
    except:
        pass


def retry(
    text,
    chatbot,
    history,
    top_p,
    temperature,
    max_length_tokens,
    max_context_length_tokens,
):
    logging.info("Retry...")
    if len(history) == 0:
        yield chatbot, history, "Empty context."
        return
    chatbot.pop()
    inputs = history.pop()[0]
    for x in predict(
        inputs,
        chatbot,
        history,
        top_p,
        temperature,
        max_length_tokens,
        max_context_length_tokens,
    ):
        yield x


with gr.Blocks(
    theme=gr.themes.Soft(),
    css=".disclaimer {font-variant-caps: all-small-caps;}"
) as demo:
    history = gr.State([])
    user_question = gr.State("")
    with gr.Row():
        gr.HTML("<h1>Lemur 🦥</h1>")
        status_display = gr.Markdown("Success", elem_id="status_display")

    with gr.Row(scale=1).style(equal_height=True):
        with gr.Column(scale=5):
            with gr.Row(scale=1):
                chatbot = gr.Chatbot(elem_id="chuanhu_chatbot").style(height=800)
            with gr.Row(scale=1):
                with gr.Column(scale=12):
                    user_input = gr.Textbox(
                        show_label=False, placeholder="Enter text"
                    ).style(container=False)
                with gr.Column(min_width=70, scale=1):
                    submitBtn = gr.Button("📤 Send")
                with gr.Column(min_width=70, scale=1):
                    cancelBtn = gr.Button("⏸️ Stop")

            with gr.Row(scale=1):
                emptyBtn = gr.Button(
                    "🧹 New Conversation",
                )
                retryBtn = gr.Button("🔄 Regenerate")
                delLastBtn = gr.Button("🗑️ Remove Last Turn")
        with gr.Column():
            with gr.Column(min_width=50, scale=1):
                with gr.Tab(label="Parameter Setting"):
                    gr.Markdown("# Parameters")
                    top_p = gr.Slider(
                        minimum=-0,
                        maximum=1.0,
                        value=0.95,
                        step=0.05,
                        interactive=True,
                        label="Top-p",
                    )
                    temperature = gr.Slider(
                        minimum=0.1,
                        maximum=2.0,
                        value=1,
                        step=0.1,
                        interactive=True,
                        label="Temperature",
                    )
                    max_length_tokens = gr.Slider(
                        minimum=0,
                        maximum=512,
                        value=512,
                        step=8,
                        interactive=True,
                        label="Max Generation Tokens",
                    )
                    max_context_length_tokens = gr.Slider(
                        minimum=0,
                        maximum=4096,
                        value=2048,
                        step=128,
                        interactive=True,
                        label="Max History Tokens",
                    )

    predict_args = dict(
        fn=predict,
        inputs=[
            user_question,
            chatbot,
            history,
            top_p,
            temperature,
            max_length_tokens,
            max_context_length_tokens,
        ],
        outputs=[chatbot, history, status_display],
        show_progress=True,
    )
    retry_args = dict(
        fn=retry,
        inputs=[
            user_input,
            chatbot,
            history,
            top_p,
            temperature,
            max_length_tokens,
            max_context_length_tokens,
        ],
        outputs=[chatbot, history, status_display],
        show_progress=True,
    )

    reset_args = dict(fn=reset_textbox, inputs=[], outputs=[user_input, status_display])

    # Chatbot

    transfer_input_args = dict(
        fn=transfer_input,
        inputs=[user_input],
        outputs=[user_question, user_input, submitBtn, cancelBtn],
        show_progress=True,
    )

    submit_event = user_input.submit(**transfer_input_args).then(**predict_args)

    submit_click_event = submitBtn.click(**transfer_input_args).then(**predict_args)

    emptyBtn.click(
        reset_state,
        outputs=[chatbot, history, status_display],
        show_progress=True,
    )
    emptyBtn.click(**reset_args)

    retry_click_event = retryBtn.click(**retry_args)

    cancelBtn.click(
        fn=cancel_outputing,
        inputs=[],
        outputs=[status_display],
        cancels=[submit_event, submit_click_event]
    )

    delLastBtn.click(
        delete_last_conversation,
        [chatbot, history],
        [chatbot, history, status_display],
        show_progress=True,
    )

demo.title = "Lemur"
demo.queue(max_size=128, concurrency_count=2)
demo.launch()
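Editor's note (not part of the diff): a minimal, self-contained sketch of the streaming-and-cancel pattern app.py builds on. Once demo.queue() is enabled, a generator event handler re-renders the Chatbot on every yield, and cancels= lets a second button interrupt it, which is the same mechanism cancelBtn uses above. The handler and widget names below are invented for illustration; it assumes Gradio 3.x, consistent with the .style() calls in the diff.

import time
import gradio as gr

def stream_reply(message, chat):
    # Yield the full chat value each time; Gradio re-renders the Chatbot per yield.
    reply = ""
    for word in ("this", "is", "a", "streamed", "reply"):
        reply += word + " "
        time.sleep(0.2)
        yield chat + [[message, reply]]

with gr.Blocks() as demo:
    chat = gr.Chatbot()
    box = gr.Textbox()
    stop = gr.Button("Stop")
    submit_event = box.submit(stream_reply, [box, chat], [chat])
    # fn=None click whose only job is to cancel the running streaming event.
    stop.click(None, None, None, cancels=[submit_event])

demo.queue()
# demo.launch()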
lemur-7B/config.json
ADDED
@@ -0,0 +1,23 @@
{
  "_name_or_path": "llama-7B",
  "architectures": [
    "LlamaForCausalLM"
  ],
  "bos_token_id": 1,
  "eos_token_id": 2,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 11008,
  "max_position_embeddings": 2048,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "pad_token_id": 0,
  "rms_norm_eps": 1e-06,
  "tie_word_embeddings": false,
  "torch_dtype": "float32",
  "transformers_version": "4.30.1",
  "use_cache": true,
  "vocab_size": 32000
}
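Editor's note (not part of the diff): the config above is a stock LLaMA-7B layout, so it should load through AutoConfig. A quick hedged check, assuming transformers 4.30 as pinned in the file and the local "lemur-7B" folder added in this commit:

from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("lemur-7B")  # local folder, not a Hub model id
print(cfg.model_type, cfg.hidden_size, cfg.num_hidden_layers)  # llama 4096 32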
lemur-7B/generation_config.json
ADDED
@@ -0,0 +1,7 @@
{
  "_from_model_config": true,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "pad_token_id": 0,
  "transformers_version": "4.30.1"
}
lemur-7B/pytorch_model.bin
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8800f80fe257ad94942049beaa2dc86703571a8696bcaf0f03f57c021a2ec6ec
size 524332500
lemur-7B/special_tokens_map.json
ADDED
@@ -0,0 +1,24 @@
{
  "bos_token": {
    "content": "<s>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "eos_token": {
    "content": "</s>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "pad_token": "</s>",
  "unk_token": {
    "content": "<unk>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  }
}
lemur-7B/tokenizer.json
ADDED
The diff for this file is too large to render.
lemur-7B/tokenizer.model
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
size 499723
lemur-7B/tokenizer_config.json
ADDED
@@ -0,0 +1,31 @@
{
  "bos_token": {
    "__type": "AddedToken",
    "content": "<s>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "clean_up_tokenization_spaces": false,
  "eos_token": {
    "__type": "AddedToken",
    "content": "</s>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "model_max_length": 1000000000000000019884624838656,
  "pad_token": null,
  "sp_model_kwargs": {},
  "tokenizer_class": "LlamaTokenizer",
  "unk_token": {
    "__type": "AddedToken",
    "content": "<unk>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  }
}
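Editor's note (not part of the diff): a hedged sanity check that the tokenizer files above load as a LLaMA tokenizer, assuming transformers 4.30 and the local folder added in this commit; the LLaMA tokenizer prepends the BOS token by default, so the ids should start with bos_token_id 1.

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("lemur-7B")  # local folder added in this commit
print(tok("Hello, lemur!")["input_ids"])  # token ids, starting with bos_token_id 1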
utils/gradio.py
ADDED
@@ -0,0 +1,71 @@
import gradio as gr
from utils.inference import shared_state
import re


def convert_to_markdown(text):
    text = text.replace("$", "&#36;")

    def replace_leading_tabs_and_spaces(line):
        new_line = []

        for char in line:
            if char == "\t":
                new_line.append("&#9;")
            elif char == " ":
                new_line.append("&nbsp;")
            else:
                break
        return "".join(new_line) + line[len(new_line):]

    markdown_text = ""
    lines = text.split("\n")
    in_code_block = False

    for line in lines:
        if in_code_block is False and line.startswith("```"):
            in_code_block = True
            markdown_text += "```\n"
        elif in_code_block is True and line.startswith("```"):
            in_code_block = False
            markdown_text += "```\n"
        elif in_code_block:
            markdown_text += f"{line}\n"
        else:
            line = replace_leading_tabs_and_spaces(line)
            line = re.sub(r"^(#)", r"\\\1", line)
            markdown_text += f"{line}  \n"

    return markdown_text


def reset_textbox():
    return gr.update(value=""), ""


def cancel_outputing():
    shared_state.interrupt()
    textbox = reset_textbox()
    return "Stop Done"


def reset_state():
    return [], [], "Reset Done"


def transfer_input(inputs):
    textbox = reset_textbox()
    return (
        inputs,
        gr.update(value=""),
        gr.Button.update(visible=True),
        gr.Button.update(visible=True)
    )


def delete_last_conversation(chatbot, history):
    if len(chatbot) > 0:
        chatbot.pop()

    if len(history) > 0:
        history.pop()

    return (
        chatbot,
        history,
        "Delete Done",
    )
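Editor's note (not part of the diff): a hypothetical usage example of convert_to_markdown, assuming the repo layout above is importable and the HTML-entity escapes shown in the reconstruction. It illustrates what the escaping is for: headings and leading whitespace outside code fences are neutralized so Gradio's Markdown component renders model output literally, while fenced code is passed through untouched.

from utils.gradio import convert_to_markdown

raw = "# not a heading\n```\nprint('code is passed through verbatim')\n```\n\tindented line"
print(convert_to_markdown(raw))
# The leading "#" outside the fence comes back escaped as "\#", the fenced block is
# kept as-is, and the leading tab is replaced with an HTML entity so the renderer
# does not collapse it.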
utils/inference.py
ADDED
@@ -0,0 +1,107 @@
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
from typing import Iterator
from variables import SYSTEM, HUMAN, AI


def load_tokenizer_and_model(base_model, load_8bit=True):

    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    tokenizer = AutoTokenizer.from_pretrained(base_model)
    model = AutoModelForCausalLM.from_pretrained(base_model, load_8bit=load_8bit)

    return tokenizer, model, device


class State:
    interrupted = False

    def interrupt(self):
        self.interrupted = True

    def recover(self):
        self.interrupted = False


shared_state = State()


def decode(
    input_ids: torch.Tensor,
    model: PeftModel,
    tokenizer: AutoTokenizer,
    stop_words: list,
    max_length: int,
    temperature: float = 1.0,
    top_p: float = 1.0,
) -> Iterator[str]:
    generated_tokens = []
    past_key_values = None

    for _ in range(max_length):
        with torch.no_grad():
            if past_key_values is None:
                outputs = model(input_ids)
            else:
                outputs = model(input_ids[:, -1:], past_key_values=past_key_values)
            logits = outputs.logits[:, -1, :]
            past_key_values = outputs.past_key_values

        # apply temperature
        logits /= temperature

        probs = torch.softmax(logits, dim=-1)
        # apply top_p
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
        probs_sum = torch.cumsum(probs_sort, dim=-1)
        mask = probs_sum - probs_sort > top_p
        probs_sort[mask] = 0.0

        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
        next_token = torch.multinomial(probs_sort, num_samples=1)
        next_token = torch.gather(probs_idx, -1, next_token)

        input_ids = torch.cat((input_ids, next_token), dim=-1)

        generated_tokens.append(next_token[0].item())
        text = tokenizer.decode(generated_tokens)

        yield text
        if any([x in text for x in stop_words]):
            return


def get_prompt_with_history(text, history, tokenizer, max_length=2048):
    prompt = SYSTEM
    history = [f"\n{HUMAN} {x[0]}\n{AI} {x[1]}" for x in history]
    history.append(f"\n{HUMAN} {text}\n{AI}")
    history_text = ""
    flag = False
    for x in history[::-1]:
        if (
            tokenizer(prompt + history_text + x, return_tensors="pt")["input_ids"].size(
                -1
            )
            <= max_length
        ):
            history_text = x + history_text
            flag = True
        else:
            break
    if flag:
        return prompt + history_text, tokenizer(
            prompt + history_text, return_tensors="pt"
        )
    else:
        return None


def is_stop_word_or_prefix(s: str, stop_words: list) -> bool:
    for stop_word in stop_words:
        if s.endswith(stop_word):
            return True
        for i in range(1, len(stop_word)):
            if s.endswith(stop_word[:i]):
                return True
    return False
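Editor's note (not part of the diff): a toy, self-contained walk-through of the nucleus (top-p) sampling step inside decode(): sort the probabilities, zero out everything past the smallest prefix whose cumulative mass covers top_p, renormalize, and sample. The logits values are invented for illustration.

import torch

torch.manual_seed(0)
logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
top_p = 0.9

probs = torch.softmax(logits, dim=-1)
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)

# Same mask as decode(): drop tokens whose probability mass lies entirely beyond top_p.
mask = probs_sum - probs_sort > top_p
probs_sort[mask] = 0.0
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))

# Sample from the truncated distribution, then map back to the original vocab index.
next_token = torch.gather(probs_idx, -1, torch.multinomial(probs_sort, num_samples=1))
print(probs, next_token)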
variables.py
ADDED
@@ -0,0 +1,5 @@
SYSTEM = "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
HUMAN = "[Human]:"
AI = "[AI]:"
NAME = "Lemur"
ORGANIZATION = "UC San Diego (UCSD)"
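Editor's note (not part of the diff): an illustration of the prompt string that utils.inference.get_prompt_with_history() assembles from these constants for one prior turn plus a new question; the example turn itself is invented.

from variables import SYSTEM, HUMAN, AI

history = [["Hi!", "Hello! How can I help?"]]
text = "What is a lemur?"

# Mirrors get_prompt_with_history() when everything fits within max_length.
turns = [f"\n{HUMAN} {h}\n{AI} {a}" for h, a in history] + [f"\n{HUMAN} {text}\n{AI}"]
print(SYSTEM + "".join(turns))
# A chat between a curious human and an artificial intelligence assistant. ...
# [Human]: Hi!
# [AI]: Hello! How can I help?
# [Human]: What is a lemur?
# [AI]: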