Commit
•
b000e50
1
Parent(s):
0d28c87
feat: update notification flow generation
Browse files
src/distilabel_dataset_generator/apps/sft.py
CHANGED
@@ -10,7 +10,6 @@ from src.distilabel_dataset_generator.pipelines.sft import (
|
|
10 |
DEFAULT_DATASET,
|
11 |
DEFAULT_DATASET_DESCRIPTION,
|
12 |
DEFAULT_SYSTEM_PROMPT,
|
13 |
-
MODEL,
|
14 |
PROMPT_CREATION_PROMPT,
|
15 |
get_pipeline,
|
16 |
get_prompt_generation_step,
|
@@ -104,10 +103,6 @@ def generate_dataset(
|
|
104 |
else:
|
105 |
duration = 1000
|
106 |
|
107 |
-
gr.Info(
|
108 |
-
"Dataset generation started. This might take a while. Don't close the page.",
|
109 |
-
duration=duration,
|
110 |
-
)
|
111 |
result_queue = multiprocessing.Queue()
|
112 |
p = multiprocessing.Process(
|
113 |
target=_run_pipeline,
|
@@ -122,7 +117,7 @@ def generate_dataset(
|
|
122 |
break
|
123 |
progress(
|
124 |
(step + 1) / total_steps,
|
125 |
-
desc=f"Generating dataset with {num_rows} rows",
|
126 |
)
|
127 |
time.sleep(duration / total_steps) # Adjust this value based on your needs
|
128 |
p.join()
|
@@ -151,52 +146,11 @@ def generate_dataset(
|
|
151 |
return pd.DataFrame(outputs)
|
152 |
|
153 |
|
154 |
-
def generate_pipeline_code(
|
155 |
-
|
156 |
-
|
157 |
-
from distilabel.steps import KeepColumns
|
158 |
-
from distilabel.steps.tasks import MagpieGenerator
|
159 |
-
from distilabel.llms import InferenceEndpointsLLM
|
160 |
-
|
161 |
-
MODEL = "{MODEL}"
|
162 |
-
SYSTEM_PROMPT = "{system_prompt}"
|
163 |
-
# increase this to generate multi-turn conversations
|
164 |
-
NUM_TURNS = 1
|
165 |
-
# increase this to generate a larger dataset
|
166 |
-
NUM_ROWS = 100
|
167 |
-
|
168 |
-
with Pipeline(name="sft") as pipeline:
|
169 |
-
magpie = MagpieGenerator(
|
170 |
-
llm=InferenceEndpointsLLM(
|
171 |
-
model_id=MODEL,
|
172 |
-
tokenizer_id=MODEL,
|
173 |
-
magpie_pre_query_template="llama3",
|
174 |
-
generation_kwargs={{
|
175 |
-
"temperature": 0.8,
|
176 |
-
"do_sample": True,
|
177 |
-
"max_new_tokens": 2048,
|
178 |
-
"stop_sequences": [
|
179 |
-
"<|eot_id|>",
|
180 |
-
"<|end_of_text|>",
|
181 |
-
"<|start_header_id|>",
|
182 |
-
"<|end_header_id|>",
|
183 |
-
"assistant",
|
184 |
-
],
|
185 |
-
}}
|
186 |
-
),
|
187 |
-
n_turns=NUM_TURNS,
|
188 |
-
num_rows=NUM_ROWS,
|
189 |
-
system_prompt=SYSTEM_PROMPT,
|
190 |
-
)
|
191 |
-
|
192 |
-
if __name__ == "__main__":
|
193 |
-
distiset = pipeline.run()
|
194 |
-
"""
|
195 |
-
return code
|
196 |
|
197 |
-
|
198 |
-
def update_pipeline_code(system_prompt):
|
199 |
-
return generate_pipeline_code(system_prompt)
|
200 |
|
201 |
|
202 |
with gr.Blocks(
|
@@ -267,7 +221,7 @@ with gr.Blocks(
|
|
267 |
minimum=1,
|
268 |
maximum=4,
|
269 |
step=1,
|
270 |
-
info="Choose between 1 (single turn with 'instruction-response' columns) and 2-4 (multi-turn conversation with a '
|
271 |
)
|
272 |
num_rows = gr.Number(
|
273 |
value=100,
|
@@ -297,6 +251,7 @@ with gr.Blocks(
|
|
297 |
<div style="padding: 1em; background-color: #e6f3e6; border-radius: 5px; margin-top: 1em;">
|
298 |
<h3 style="color: #2e7d32; margin: 0;">Dataset Published Successfully!</h3>
|
299 |
<p style="margin-top: 0.5em;">
|
|
|
300 |
Your dataset is now available at:
|
301 |
<a href="https://huggingface.co/datasets/{org_name}/{repo_name}" target="_blank" style="color: #1565c0; text-decoration: none;">
|
302 |
https://huggingface.co/datasets/{org_name}/{repo_name}
|
@@ -307,7 +262,13 @@ with gr.Blocks(
|
|
307 |
visible=True,
|
308 |
)
|
309 |
|
|
|
|
|
|
|
310 |
btn_generate_full_dataset.click(
|
|
|
|
|
|
|
311 |
fn=generate_dataset,
|
312 |
inputs=[
|
313 |
system_prompt,
|
@@ -329,13 +290,11 @@ with gr.Blocks(
|
|
329 |
gr.Markdown("## Or run this pipeline locally with distilabel")
|
330 |
|
331 |
with gr.Accordion("Run this pipeline on Distilabel", open=False):
|
332 |
-
pipeline_code = gr.Code(
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
outputs=[pipeline_code],
|
338 |
-
)
|
339 |
|
340 |
app.load(get_token, outputs=[hf_token])
|
341 |
app.load(get_org_dropdown, outputs=[org_name])
|
|
|
10 |
DEFAULT_DATASET,
|
11 |
DEFAULT_DATASET_DESCRIPTION,
|
12 |
DEFAULT_SYSTEM_PROMPT,
|
|
|
13 |
PROMPT_CREATION_PROMPT,
|
14 |
get_pipeline,
|
15 |
get_prompt_generation_step,
|
|
|
103 |
else:
|
104 |
duration = 1000
|
105 |
|
|
|
|
|
|
|
|
|
106 |
result_queue = multiprocessing.Queue()
|
107 |
p = multiprocessing.Process(
|
108 |
target=_run_pipeline,
|
|
|
117 |
break
|
118 |
progress(
|
119 |
(step + 1) / total_steps,
|
120 |
+
desc=f"Generating dataset with {num_rows} rows. Don't close this window.",
|
121 |
)
|
122 |
time.sleep(duration / total_steps) # Adjust this value based on your needs
|
123 |
p.join()
|
|
|
146 |
return pd.DataFrame(outputs)
|
147 |
|
148 |
|
149 |
+
def generate_pipeline_code() -> str:
|
150 |
+
with open("src/distilabel_dataset_generator/pipelines/sft.py", "r") as f:
|
151 |
+
pipeline_code = f.read()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
152 |
|
153 |
+
return pipeline_code
|
|
|
|
|
154 |
|
155 |
|
156 |
with gr.Blocks(
|
|
|
221 |
minimum=1,
|
222 |
maximum=4,
|
223 |
step=1,
|
224 |
+
info="Choose between 1 (single turn with 'instruction-response' columns) and 2-4 (multi-turn conversation with a 'messages' column).",
|
225 |
)
|
226 |
num_rows = gr.Number(
|
227 |
value=100,
|
|
|
251 |
<div style="padding: 1em; background-color: #e6f3e6; border-radius: 5px; margin-top: 1em;">
|
252 |
<h3 style="color: #2e7d32; margin: 0;">Dataset Published Successfully!</h3>
|
253 |
<p style="margin-top: 0.5em;">
|
254 |
+
The generated dataset is in the right format for Fine-tuning with TRL, AutoTrain or other frameworks.
|
255 |
Your dataset is now available at:
|
256 |
<a href="https://huggingface.co/datasets/{org_name}/{repo_name}" target="_blank" style="color: #1565c0; text-decoration: none;">
|
257 |
https://huggingface.co/datasets/{org_name}/{repo_name}
|
|
|
262 |
visible=True,
|
263 |
)
|
264 |
|
265 |
+
def hide_success_message():
|
266 |
+
return gr.Markdown(visible=False)
|
267 |
+
|
268 |
btn_generate_full_dataset.click(
|
269 |
+
fn=hide_success_message,
|
270 |
+
outputs=[success_message],
|
271 |
+
).then(
|
272 |
fn=generate_dataset,
|
273 |
inputs=[
|
274 |
system_prompt,
|
|
|
290 |
gr.Markdown("## Or run this pipeline locally with distilabel")
|
291 |
|
292 |
with gr.Accordion("Run this pipeline on Distilabel", open=False):
|
293 |
+
pipeline_code = gr.Code(
|
294 |
+
value=generate_pipeline_code(),
|
295 |
+
language="python",
|
296 |
+
label="Distilabel Pipeline Code",
|
297 |
+
)
|
|
|
|
|
298 |
|
299 |
app.load(get_token, outputs=[hf_token])
|
300 |
app.load(get_org_dropdown, outputs=[org_name])
|
src/distilabel_dataset_generator/utils.py
CHANGED
@@ -39,7 +39,7 @@ def get_login_button():
|
|
39 |
or get_space() is None
|
40 |
):
|
41 |
return gr.LoginButton(
|
42 |
-
value="Sign in with Hugging Face
|
43 |
size="lg",
|
44 |
)
|
45 |
|
|
|
39 |
or get_space() is None
|
40 |
):
|
41 |
return gr.LoginButton(
|
42 |
+
value="Sign in with Hugging Face! (This resets the session)",
|
43 |
size="lg",
|
44 |
)
|
45 |
|