davidberenstein1957 HF staff committed on
Commit
b000e50
1 Parent(s): 0d28c87

feat: update notification flow generation

Browse files
src/distilabel_dataset_generator/apps/sft.py CHANGED
@@ -10,7 +10,6 @@ from src.distilabel_dataset_generator.pipelines.sft import (
10
  DEFAULT_DATASET,
11
  DEFAULT_DATASET_DESCRIPTION,
12
  DEFAULT_SYSTEM_PROMPT,
13
- MODEL,
14
  PROMPT_CREATION_PROMPT,
15
  get_pipeline,
16
  get_prompt_generation_step,
@@ -104,10 +103,6 @@ def generate_dataset(
104
  else:
105
  duration = 1000
106
 
107
- gr.Info(
108
- "Dataset generation started. This might take a while. Don't close the page.",
109
- duration=duration,
110
- )
111
  result_queue = multiprocessing.Queue()
112
  p = multiprocessing.Process(
113
  target=_run_pipeline,
@@ -122,7 +117,7 @@ def generate_dataset(
122
  break
123
  progress(
124
  (step + 1) / total_steps,
125
- desc=f"Generating dataset with {num_rows} rows",
126
  )
127
  time.sleep(duration / total_steps) # Adjust this value based on your needs
128
  p.join()
@@ -151,52 +146,11 @@ def generate_dataset(
151
  return pd.DataFrame(outputs)
152
 
153
 
154
- def generate_pipeline_code(system_prompt):
155
- code = f"""
156
- from distilabel.pipeline import Pipeline
157
- from distilabel.steps import KeepColumns
158
- from distilabel.steps.tasks import MagpieGenerator
159
- from distilabel.llms import InferenceEndpointsLLM
160
-
161
- MODEL = "{MODEL}"
162
- SYSTEM_PROMPT = "{system_prompt}"
163
- # increase this to generate multi-turn conversations
164
- NUM_TURNS = 1
165
- # increase this to generate a larger dataset
166
- NUM_ROWS = 100
167
-
168
- with Pipeline(name="sft") as pipeline:
169
- magpie = MagpieGenerator(
170
- llm=InferenceEndpointsLLM(
171
- model_id=MODEL,
172
- tokenizer_id=MODEL,
173
- magpie_pre_query_template="llama3",
174
- generation_kwargs={{
175
- "temperature": 0.8,
176
- "do_sample": True,
177
- "max_new_tokens": 2048,
178
- "stop_sequences": [
179
- "<|eot_id|>",
180
- "<|end_of_text|>",
181
- "<|start_header_id|>",
182
- "<|end_header_id|>",
183
- "assistant",
184
- ],
185
- }}
186
- ),
187
- n_turns=NUM_TURNS,
188
- num_rows=NUM_ROWS,
189
- system_prompt=SYSTEM_PROMPT,
190
- )
191
-
192
- if __name__ == "__main__":
193
- distiset = pipeline.run()
194
- """
195
- return code
196
 
197
-
198
- def update_pipeline_code(system_prompt):
199
- return generate_pipeline_code(system_prompt)
200
 
201
 
202
  with gr.Blocks(
@@ -267,7 +221,7 @@ with gr.Blocks(
267
  minimum=1,
268
  maximum=4,
269
  step=1,
270
- info="Choose between 1 (single turn with 'instruction-response' columns) and 2-4 (multi-turn conversation with a 'conversation' column).",
271
  )
272
  num_rows = gr.Number(
273
  value=100,
@@ -297,6 +251,7 @@ with gr.Blocks(
297
  <div style="padding: 1em; background-color: #e6f3e6; border-radius: 5px; margin-top: 1em;">
298
  <h3 style="color: #2e7d32; margin: 0;">Dataset Published Successfully!</h3>
299
  <p style="margin-top: 0.5em;">
 
300
  Your dataset is now available at:
301
  <a href="https://huggingface.co/datasets/{org_name}/{repo_name}" target="_blank" style="color: #1565c0; text-decoration: none;">
302
  https://huggingface.co/datasets/{org_name}/{repo_name}
@@ -307,7 +262,13 @@ with gr.Blocks(
307
  visible=True,
308
  )
309
 
 
 
 
310
  btn_generate_full_dataset.click(
 
 
 
311
  fn=generate_dataset,
312
  inputs=[
313
  system_prompt,
@@ -329,13 +290,11 @@ with gr.Blocks(
329
  gr.Markdown("## Or run this pipeline locally with distilabel")
330
 
331
  with gr.Accordion("Run this pipeline on Distilabel", open=False):
332
- pipeline_code = gr.Code(language="python", label="Distilabel Pipeline Code")
333
-
334
- system_prompt.change(
335
- fn=update_pipeline_code,
336
- inputs=[system_prompt],
337
- outputs=[pipeline_code],
338
- )
339
 
340
  app.load(get_token, outputs=[hf_token])
341
  app.load(get_org_dropdown, outputs=[org_name])
 
10
  DEFAULT_DATASET,
11
  DEFAULT_DATASET_DESCRIPTION,
12
  DEFAULT_SYSTEM_PROMPT,
 
13
  PROMPT_CREATION_PROMPT,
14
  get_pipeline,
15
  get_prompt_generation_step,
 
103
  else:
104
  duration = 1000
105
 
 
 
 
 
106
  result_queue = multiprocessing.Queue()
107
  p = multiprocessing.Process(
108
  target=_run_pipeline,
 
117
  break
118
  progress(
119
  (step + 1) / total_steps,
120
+ desc=f"Generating dataset with {num_rows} rows. Don't close this window.",
121
  )
122
  time.sleep(duration / total_steps) # Adjust this value based on your needs
123
  p.join()
 
146
  return pd.DataFrame(outputs)
147
 
148
 
149
+ def generate_pipeline_code() -> str:
150
+ with open("src/distilabel_dataset_generator/pipelines/sft.py", "r") as f:
151
+ pipeline_code = f.read()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
 
153
+ return pipeline_code
 
 
154
 
155
 
156
  with gr.Blocks(
 
221
  minimum=1,
222
  maximum=4,
223
  step=1,
224
+ info="Choose between 1 (single turn with 'instruction-response' columns) and 2-4 (multi-turn conversation with a 'messages' column).",
225
  )
226
  num_rows = gr.Number(
227
  value=100,
 
251
  <div style="padding: 1em; background-color: #e6f3e6; border-radius: 5px; margin-top: 1em;">
252
  <h3 style="color: #2e7d32; margin: 0;">Dataset Published Successfully!</h3>
253
  <p style="margin-top: 0.5em;">
254
+ The generated dataset is in the right format for Fine-tuning with TRL, AutoTrain or other frameworks.
255
  Your dataset is now available at:
256
  <a href="https://huggingface.co/datasets/{org_name}/{repo_name}" target="_blank" style="color: #1565c0; text-decoration: none;">
257
  https://huggingface.co/datasets/{org_name}/{repo_name}
 
262
  visible=True,
263
  )
264
 
265
+ def hide_success_message():
266
+ return gr.Markdown(visible=False)
267
+
268
  btn_generate_full_dataset.click(
269
+ fn=hide_success_message,
270
+ outputs=[success_message],
271
+ ).then(
272
  fn=generate_dataset,
273
  inputs=[
274
  system_prompt,
 
290
  gr.Markdown("## Or run this pipeline locally with distilabel")
291
 
292
  with gr.Accordion("Run this pipeline on Distilabel", open=False):
293
+ pipeline_code = gr.Code(
294
+ value=generate_pipeline_code(),
295
+ language="python",
296
+ label="Distilabel Pipeline Code",
297
+ )
 
 
298
 
299
  app.load(get_token, outputs=[hf_token])
300
  app.load(get_org_dropdown, outputs=[org_name])
src/distilabel_dataset_generator/utils.py CHANGED
@@ -39,7 +39,7 @@ def get_login_button():
39
  or get_space() is None
40
  ):
41
  return gr.LoginButton(
42
- value="Sign in with Hugging Face to generate a full dataset and push it to the Hub!",
43
  size="lg",
44
  )
45
 
 
39
  or get_space() is None
40
  ):
41
  return gr.LoginButton(
42
+ value="Sign in with Hugging Face! (This resets the session)",
43
  size="lg",
44
  )
45