niraito sam-mosaic commited on
Commit
2ea6751
0 Parent(s):

Duplicate from mosaicml/mpt-7b-instruct

Browse files

Co-authored-by: Sam <[email protected]>

Files changed (5) hide show
  1. .gitattributes +34 -0
  2. README.md +13 -0
  3. app.py +246 -0
  4. quick_pipeline.py +85 -0
  5. requirements.txt +4 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: MPT-7B-Instruct
3
+ emoji: 💁
4
+ colorFrom: yellow
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 3.28.0
8
+ app_file: app.py
9
+ pinned: false
10
+ duplicated_from: mosaicml/mpt-7b-instruct
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 MosaicML spaces authors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ # and
4
+ # the https://huggingface.co/spaces/HuggingFaceH4/databricks-dolly authors
5
+ import datetime
6
+ import os
7
+ from threading import Event, Thread
8
+ from uuid import uuid4
9
+
10
+ import gradio as gr
11
+ import requests
12
+ import torch
13
+ from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
14
+
15
+ from quick_pipeline import InstructionTextGenerationPipeline as pipeline
16
+
17
+
18
+ # Configuration
19
+ HF_TOKEN = os.getenv("HF_TOKEN", None)
20
+
21
+ examples = [
22
+ # to do: add coupled hparams so e.g. poem has higher temp
23
+ "Write a travel blog about a 3-day trip to Thailand.",
24
+ "Write a short story about a robot that has a nice day.",
25
+ "Convert the following to a single line of JSON:\n\n```name: John\nage: 30\naddress:\n street:123 Main St.\n city: San Francisco\n state: CA\n zip: 94101\n```",
26
+ "Write a quick email to congratulate MosaicML about the launch of their inference offering.",
27
+ "Explain how a candle works to a 6 year old in a few sentences.",
28
+ "What are some of the most common misconceptions about birds?",
29
+ ]
30
+
31
+ # Initialize the model and tokenizer
32
+ generate = pipeline(
33
+ "mosaicml/mpt-7b-instruct",
34
+ torch_dtype=torch.bfloat16,
35
+ trust_remote_code=True,
36
+ use_auth_token=HF_TOKEN,
37
+ )
38
+ stop_token_ids = generate.tokenizer.convert_tokens_to_ids(["<|endoftext|>"])
39
+
40
+
41
+ # Define a custom stopping criteria
42
+ class StopOnTokens(StoppingCriteria):
43
+ def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
44
+ for stop_id in stop_token_ids:
45
+ if input_ids[0][-1] == stop_id:
46
+ return True
47
+ return False
48
+
49
+
50
+ def log_conversation(session_id, instruction, response, generate_kwargs):
51
+ logging_url = os.getenv("LOGGING_URL", None)
52
+ if logging_url is None:
53
+ return
54
+
55
+ timestamp = datetime.datetime.now().strftime("%Y-%m-%dT%H:%M:%S")
56
+
57
+ data = {
58
+ "session_id": session_id,
59
+ "timestamp": timestamp,
60
+ "instruction": instruction,
61
+ "response": response,
62
+ "generate_kwargs": generate_kwargs,
63
+ }
64
+
65
+ try:
66
+ requests.post(logging_url, json=data)
67
+ except requests.exceptions.RequestException as e:
68
+ print(f"Error logging conversation: {e}")
69
+
70
+
71
+ def process_stream(instruction, temperature, top_p, top_k, max_new_tokens, session_id):
72
+ # Tokenize the input
73
+ input_ids = generate.tokenizer(
74
+ generate.format_instruction(instruction), return_tensors="pt"
75
+ ).input_ids
76
+ input_ids = input_ids.to(generate.model.device)
77
+
78
+ # Initialize the streamer and stopping criteria
79
+ streamer = TextIteratorStreamer(
80
+ generate.tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
81
+ )
82
+ stop = StopOnTokens()
83
+
84
+ if temperature < 0.1:
85
+ temperature = 0.0
86
+ do_sample = False
87
+ else:
88
+ do_sample = True
89
+
90
+ gkw = {
91
+ **generate.generate_kwargs,
92
+ **{
93
+ "input_ids": input_ids,
94
+ "max_new_tokens": max_new_tokens,
95
+ "temperature": temperature,
96
+ "do_sample": do_sample,
97
+ "top_p": top_p,
98
+ "top_k": top_k,
99
+ "streamer": streamer,
100
+ "stopping_criteria": StoppingCriteriaList([stop]),
101
+ },
102
+ }
103
+
104
+ response = ""
105
+ stream_complete = Event()
106
+
107
+ def generate_and_signal_complete():
108
+ generate.model.generate(**gkw)
109
+ stream_complete.set()
110
+
111
+ def log_after_stream_complete():
112
+ stream_complete.wait()
113
+ log_conversation(
114
+ session_id,
115
+ instruction,
116
+ response,
117
+ {
118
+ "top_k": top_k,
119
+ "top_p": top_p,
120
+ "temperature": temperature,
121
+ },
122
+ )
123
+
124
+ t1 = Thread(target=generate_and_signal_complete)
125
+ t1.start()
126
+
127
+ t2 = Thread(target=log_after_stream_complete)
128
+ t2.start()
129
+
130
+ for new_text in streamer:
131
+ response += new_text
132
+ yield response
133
+
134
+
135
+ with gr.Blocks(
136
+ theme=gr.themes.Soft(),
137
+ css=".disclaimer {font-variant-caps: all-small-caps;}",
138
+ ) as demo:
139
+ session_id = gr.State(lambda: str(uuid4()))
140
+ gr.Markdown(
141
+ """<h1><center>MosaicML MPT-7B-Instruct</center></h1>
142
+
143
+ This demo is of [MPT-7B-Instruct](https://huggingface.co/mosaicml/mpt-7b-instruct). It is based on [MPT-7B](https://huggingface.co/mosaicml/mpt-7b) fine-tuned with approximately [60,000 instruction demonstrations](https://huggingface.co/datasets/sam-mosaic/dolly_hhrlhf)
144
+
145
+ If you're interested in [training](https://www.mosaicml.com/training) and [deploying](https://www.mosaicml.com/inference) your own MPT or LLMs, [sign up](https://forms.mosaicml.com/demo?utm_source=huggingface&utm_medium=referral&utm_campaign=mpt-7b) for MosaicML platform.
146
+
147
+ This is running on a smaller, shared GPU, so it may take a few seconds to respond. If you want to run it on your own GPU, you can [download the model from HuggingFace](https://huggingface.co/mosaicml/mpt-7b-instruct) and run it locally. Or [Duplicate the Space](https://huggingface.co/spaces/mosaicml/mpt-7b-instruct?duplicate=true) to skip the queue and run in a private space."""
148
+ )
149
+ with gr.Row():
150
+ with gr.Column():
151
+ with gr.Row():
152
+ instruction = gr.Textbox(
153
+ placeholder="Enter your question here",
154
+ label="Question/Instruction",
155
+ elem_id="q-input",
156
+ )
157
+ with gr.Accordion("Advanced Options:", open=False):
158
+ with gr.Row():
159
+ with gr.Column():
160
+ with gr.Row():
161
+ temperature = gr.Slider(
162
+ label="Temperature",
163
+ value=0.1,
164
+ minimum=0.0,
165
+ maximum=1.0,
166
+ step=0.1,
167
+ interactive=True,
168
+ info="Higher values produce more diverse outputs",
169
+ )
170
+ with gr.Column():
171
+ with gr.Row():
172
+ top_p = gr.Slider(
173
+ label="Top-p (nucleus sampling)",
174
+ value=1.0,
175
+ minimum=0.0,
176
+ maximum=1,
177
+ step=0.01,
178
+ interactive=True,
179
+ info=(
180
+ "Sample from the smallest possible set of tokens whose cumulative probability "
181
+ "exceeds top_p. Set to 1 to disable and sample from all tokens."
182
+ ),
183
+ )
184
+ with gr.Column():
185
+ with gr.Row():
186
+ top_k = gr.Slider(
187
+ label="Top-k",
188
+ value=0,
189
+ minimum=0.0,
190
+ maximum=200,
191
+ step=1,
192
+ interactive=True,
193
+ info="Sample from a shortlist of top-k tokens — 0 to disable and sample from all tokens.",
194
+ )
195
+ with gr.Column():
196
+ with gr.Row():
197
+ max_new_tokens = gr.Slider(
198
+ label="Maximum new tokens",
199
+ value=256,
200
+ minimum=0,
201
+ maximum=1664,
202
+ step=5,
203
+ interactive=True,
204
+ info="The maximum number of new tokens to generate",
205
+ )
206
+ with gr.Row():
207
+ submit = gr.Button("Submit")
208
+ with gr.Row():
209
+ with gr.Box():
210
+ gr.Markdown("**MPT-7B-Instruct**")
211
+ output_7b = gr.Markdown()
212
+
213
+ with gr.Row():
214
+ gr.Examples(
215
+ examples=examples,
216
+ inputs=[instruction],
217
+ cache_examples=False,
218
+ fn=process_stream,
219
+ outputs=output_7b,
220
+ )
221
+ with gr.Row():
222
+ gr.Markdown(
223
+ "Disclaimer: MPT-7B can produce factually incorrect output, and should not be relied on to produce "
224
+ "factually accurate information. MPT-7B was trained on various public datasets; while great efforts "
225
+ "have been taken to clean the pretraining data, it is possible that this model could generate lewd, "
226
+ "biased, or otherwise offensive outputs.",
227
+ elem_classes=["disclaimer"],
228
+ )
229
+ with gr.Row():
230
+ gr.Markdown(
231
+ "[Privacy policy](https://gist.github.com/samhavens/c29c68cdcd420a9aa0202d0839876dac)",
232
+ elem_classes=["disclaimer"],
233
+ )
234
+
235
+ submit.click(
236
+ process_stream,
237
+ inputs=[instruction, temperature, top_p, top_k, max_new_tokens, session_id],
238
+ outputs=output_7b,
239
+ )
240
+ instruction.submit(
241
+ process_stream,
242
+ inputs=[instruction, temperature, top_p, top_k, max_new_tokens, session_id],
243
+ outputs=output_7b,
244
+ )
245
+
246
+ demo.queue(max_size=32, concurrency_count=4).launch(debug=True)
quick_pipeline.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Dict, Tuple
2
+ import warnings
3
+
4
+ import torch
5
+ from transformers import AutoModelForCausalLM, AutoTokenizer
6
+
7
+
8
+ INSTRUCTION_KEY = "### Instruction:"
9
+ RESPONSE_KEY = "### Response:"
10
+ END_KEY = "### End"
11
+ INTRO_BLURB = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
12
+ PROMPT_FOR_GENERATION_FORMAT = """{intro}
13
+
14
+ {instruction_key}
15
+ {instruction}
16
+
17
+ {response_key}
18
+ """.format(
19
+ intro=INTRO_BLURB,
20
+ instruction_key=INSTRUCTION_KEY,
21
+ instruction="{instruction}",
22
+ response_key=RESPONSE_KEY,
23
+ )
24
+
25
+
26
+ class InstructionTextGenerationPipeline:
27
+ def __init__(
28
+ self,
29
+ model_name,
30
+ torch_dtype=torch.bfloat16,
31
+ trust_remote_code=True,
32
+ use_auth_token=None,
33
+ ) -> None:
34
+ self.model = AutoModelForCausalLM.from_pretrained(
35
+ model_name,
36
+ torch_dtype=torch_dtype,
37
+ trust_remote_code=trust_remote_code,
38
+ use_auth_token=use_auth_token,
39
+ )
40
+
41
+ tokenizer = AutoTokenizer.from_pretrained(
42
+ model_name,
43
+ trust_remote_code=trust_remote_code,
44
+ use_auth_token=use_auth_token,
45
+ )
46
+ if tokenizer.pad_token_id is None:
47
+ warnings.warn(
48
+ "pad_token_id is not set for the tokenizer. Using eos_token_id as pad_token_id."
49
+ )
50
+ tokenizer.pad_token = tokenizer.eos_token
51
+ tokenizer.padding_side = "left"
52
+ self.tokenizer = tokenizer
53
+
54
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
55
+ self.model.eval()
56
+ self.model.to(device=device, dtype=torch_dtype)
57
+
58
+ self.generate_kwargs = {
59
+ "temperature": 0.5,
60
+ "top_p": 0.92,
61
+ "top_k": 0,
62
+ "max_new_tokens": 512,
63
+ "use_cache": True,
64
+ "do_sample": True,
65
+ "eos_token_id": self.tokenizer.eos_token_id,
66
+ "pad_token_id": self.tokenizer.pad_token_id,
67
+ "repetition_penalty": 1.1, # 1.0 means no penalty, > 1.0 means penalty, 1.2 from CTRL paper
68
+ }
69
+
70
+ def format_instruction(self, instruction):
71
+ return PROMPT_FOR_GENERATION_FORMAT.format(instruction=instruction)
72
+
73
+ def __call__(
74
+ self, instruction: str, **generate_kwargs: Dict[str, Any]
75
+ ) -> Tuple[str, str, float]:
76
+ s = PROMPT_FOR_GENERATION_FORMAT.format(instruction=instruction)
77
+ input_ids = self.tokenizer(s, return_tensors="pt").input_ids
78
+ input_ids = input_ids.to(self.model.device)
79
+ gkw = {**self.generate_kwargs, **generate_kwargs}
80
+ with torch.no_grad():
81
+ output_ids = self.model.generate(input_ids, **gkw)
82
+ # Slice the output_ids tensor to get only new tokens
83
+ new_tokens = output_ids[0, len(input_ids[0]) :]
84
+ output_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
85
+ return output_text
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ -e git+https://github.com/samhavens/just-triton-flash.git#egg=flash_attn
2
+ einops
3
+ torch
4
+ transformers