davidberenstein1957 HF staff committed on
Commit
75f9ac3
1 Parent(s): 35ca0fa

feat: add login after generation

Browse files
src/distilabel_dataset_generator/sft.py CHANGED
@@ -1,5 +1,4 @@
1
  import multiprocessing
2
- import os
3
 
4
  import gradio as gr
5
  import pandas as pd
@@ -10,11 +9,9 @@ from distilabel.steps.tasks import MagpieGenerator, TextGeneration
10
 
11
  from src.distilabel_dataset_generator.utils import (
12
  OAuthToken,
13
- get_css,
14
  get_duplicate_button,
15
  get_login_button,
16
  get_org_dropdown,
17
- list_orgs,
18
  swap_visibilty,
19
  )
20
 
@@ -172,9 +169,6 @@ def _run_pipeline(result_queue, num_turns, num_rows, system_prompt, token: str =
172
 
173
 
174
  def generate_system_prompt(dataset_description, token: OAuthToken = None):
175
- if token is None:
176
- raise gr.Error("Please sign in with Hugging Face to generate a dataset.")
177
- os.environ["HF_TOKEN"] = token.token
178
  generate_description = TextGeneration(
179
  llm=InferenceEndpointsLLM(
180
  model_id=MODEL,
@@ -210,18 +204,14 @@ def generate_dataset(
210
  dataset_name=None,
211
  token: OAuthToken = None,
212
  ):
213
- if token is None:
214
- raise gr.Error("Please sign in with Hugging Face to generate a dataset.")
215
  if dataset_name is not None:
216
  if not dataset_name:
217
  raise gr.Error("Please provide a dataset name to push the dataset to.")
218
- if orgs_selector is not None:
219
- if not orgs_selector:
220
  raise gr.Error(
221
- f"Please select an organization to push the dataset to from: {list_orgs(token)}"
222
  )
223
 
224
- os.environ["HF_TOKEN"] = token.token
225
  gr.Info("Started pipeline execution.")
226
  result_queue = multiprocessing.Queue()
227
  p = multiprocessing.Process(
@@ -243,7 +233,7 @@ def generate_dataset(
243
  )
244
  gr.Info(f"Dataset pushed to Hugging Face Hub: https://huggingface.co/{repo_id}")
245
  else:
246
- # If not pushing to hub, generate the dataset directly
247
  distiset = distiset["default"]["train"]
248
  if num_turns == 1:
249
  outputs = distiset.to_pandas()[["instruction", "response"]]
@@ -261,7 +251,6 @@ def generate_dataset(
261
  with gr.Blocks(
262
  title="⚗️ Distilabel Dataset Generator",
263
  head="⚗️ Distilabel Dataset Generator",
264
- css=get_css(),
265
  ) as app:
266
  gr.Markdown(
267
  """
@@ -270,83 +259,74 @@ with gr.Blocks(
270
  More information on distilabel and techniques can be found in the "FAQ" tab. The code can be found in the [Spaces repository](https://huggingface.co/spaces/argilla/distilabel-dataset-generator/tree/main).
271
  """
272
  )
273
- with gr.Row(variant="panel"):
274
- with gr.Column():
275
- btn_login = get_login_button()
276
- with gr.Column():
277
- btn_duplicate = get_duplicate_button()
278
- with gr.Row():
279
- with gr.Column(visible=True) as main_ui:
280
- dataset_description = gr.Textbox(
281
- label="Provide a description of the dataset",
282
- value=DEFAULT_SYSTEM_PROMPT_DESCRIPTION,
283
- )
284
 
285
- btn_generate_system_prompt = gr.Button(value="🧪 Generate Sytem Prompt")
 
 
 
286
 
287
- system_prompt = gr.Textbox(
288
- label="Provide or correct the system prompt",
289
- value=DEFAULT_SYSTEM_PROMPT,
290
- )
291
 
292
- btn_generate_system_prompt.click(
293
- fn=generate_system_prompt,
294
- inputs=[dataset_description],
295
- outputs=[system_prompt],
296
- )
297
 
298
- btn_generate_sample_dataset = gr.Button(
299
- value="🧪 Generate Sample Dataset of 5 rows and a single turn",
300
- )
 
 
301
 
302
- table = gr.Dataframe(
303
- label="Generated Dataset", wrap=True, value=DEFAULT_DATASET
304
- )
305
 
306
- btn_generate_sample_dataset.click(
307
- fn=generate_dataset,
308
- inputs=[system_prompt],
309
- outputs=[table],
310
- )
311
 
312
- with gr.Row(variant="panel"):
313
- num_turns = gr.Number(
314
- value=1,
315
- label="Number of turns in the conversation",
316
- minimum=1,
317
- info="Whether the dataset is for a single turn with 'instruction-response' columns or a multi-turn conversation with a 'conversation' column.",
318
- )
319
- num_rows = gr.Number(
320
- value=100,
321
- label="Number of rows in the dataset",
322
- minimum=1,
323
- info="The number of rows in the dataset. Note that you are able to generate several 1000 rows at once but that this will take time.",
324
- )
325
- private = gr.Checkbox(
326
- label="Private dataset", value=True, interactive=True
327
- )
328
-
329
- with gr.Row(variant="panel"):
330
- orgs_selector = gr.Dropdown(label="Organization")
331
- dataset_name_push_to_hub = gr.Textbox(
332
- label="Dataset Name to push to Hub"
333
- )
334
-
335
- btn_generate_full_dataset = gr.Button(
336
- value="⚗️ Generate Full Dataset", variant="primary"
337
  )
338
-
339
- btn_generate_full_dataset.click(
340
- fn=generate_dataset,
341
- inputs=[
342
- system_prompt,
343
- num_turns,
344
- num_rows,
345
- private,
346
- orgs_selector,
347
- dataset_name_push_to_hub,
348
- ],
349
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
350
 
351
  app.load(get_org_dropdown, outputs=[orgs_selector])
352
- app.load(fn=swap_visibilty, outputs=main_ui)
 
1
  import multiprocessing
 
2
 
3
  import gradio as gr
4
  import pandas as pd
 
9
 
10
  from src.distilabel_dataset_generator.utils import (
11
  OAuthToken,
 
12
  get_duplicate_button,
13
  get_login_button,
14
  get_org_dropdown,
 
15
  swap_visibilty,
16
  )
17
 
 
169
 
170
 
171
  def generate_system_prompt(dataset_description, token: OAuthToken = None):
 
 
 
172
  generate_description = TextGeneration(
173
  llm=InferenceEndpointsLLM(
174
  model_id=MODEL,
 
204
  dataset_name=None,
205
  token: OAuthToken = None,
206
  ):
 
 
207
  if dataset_name is not None:
208
  if not dataset_name:
209
  raise gr.Error("Please provide a dataset name to push the dataset to.")
210
+ if token is None:
 
211
  raise gr.Error(
212
+ "Please sign in with Hugging Face to be able to push the dataset to the Hub."
213
  )
214
 
 
215
  gr.Info("Started pipeline execution.")
216
  result_queue = multiprocessing.Queue()
217
  p = multiprocessing.Process(
 
233
  )
234
  gr.Info(f"Dataset pushed to Hugging Face Hub: https://huggingface.co/{repo_id}")
235
  else:
236
+ # If not pushing to hub generate the dataset directly
237
  distiset = distiset["default"]["train"]
238
  if num_turns == 1:
239
  outputs = distiset.to_pandas()[["instruction", "response"]]
 
251
  with gr.Blocks(
252
  title="⚗️ Distilabel Dataset Generator",
253
  head="⚗️ Distilabel Dataset Generator",
 
254
  ) as app:
255
  gr.Markdown(
256
  """
 
259
  More information on distilabel and techniques can be found in the "FAQ" tab. The code can be found in the [Spaces repository](https://huggingface.co/spaces/argilla/distilabel-dataset-generator/tree/main).
260
  """
261
  )
262
+ btn_duplicate = get_duplicate_button()
 
 
 
 
 
 
 
 
 
 
263
 
264
+ dataset_description = gr.Textbox(
265
+ label="Provide a description of the dataset",
266
+ value=DEFAULT_SYSTEM_PROMPT_DESCRIPTION,
267
+ )
268
 
269
+ btn_generate_system_prompt = gr.Button(value="🧪 Generate Sytem Prompt")
 
 
 
270
 
271
+ system_prompt = gr.Textbox(
272
+ label="Provide or correct the system prompt",
273
+ value=DEFAULT_SYSTEM_PROMPT,
274
+ )
 
275
 
276
+ btn_generate_system_prompt.click(
277
+ fn=generate_system_prompt,
278
+ inputs=[dataset_description],
279
+ outputs=[system_prompt],
280
+ )
281
 
282
+ btn_generate_sample_dataset = gr.Button(
283
+ value="🧪 Generate Sample Dataset of 5 rows and a single turn",
284
+ )
285
 
286
+ table = gr.Dataframe(label="Generated Dataset", wrap=True, value=DEFAULT_DATASET)
 
 
 
 
287
 
288
+ btn_generate_sample_dataset.click(
289
+ fn=generate_dataset,
290
+ inputs=[system_prompt],
291
+ outputs=[table],
292
+ )
293
+ btn_login: gr.LoginButton | None = get_login_button()
294
+ with gr.Column() as push_to_hub_ui:
295
+ with gr.Row(variant="panel"):
296
+ num_turns = gr.Number(
297
+ value=1,
298
+ label="Number of turns in the conversation",
299
+ minimum=1,
300
+ info="Whether the dataset is for a single turn with 'instruction-response' columns or a multi-turn conversation with a 'conversation' column.",
 
 
 
 
 
 
 
 
 
 
 
 
301
  )
302
+ num_rows = gr.Number(
303
+ value=100,
304
+ label="Number of rows in the dataset",
305
+ minimum=1,
306
+ maximum=5000,
307
+ info="The number of rows in the dataset. Note that you are able to generate more rows at once but that this will take time.",
 
 
 
 
 
308
  )
309
+ private = gr.Checkbox(label="Private dataset", value=True, interactive=True)
310
+
311
+ with gr.Row(variant="panel"):
312
+ orgs_selector = gr.Dropdown(label="Organization")
313
+ dataset_name_push_to_hub = gr.Textbox(label="Dataset Name to push to Hub")
314
+
315
+ btn_generate_full_dataset = gr.Button(
316
+ value="⚗️ Generate Full Dataset", variant="primary"
317
+ )
318
+
319
+ btn_generate_full_dataset.click(
320
+ fn=generate_dataset,
321
+ inputs=[
322
+ system_prompt,
323
+ num_turns,
324
+ num_rows,
325
+ private,
326
+ orgs_selector,
327
+ dataset_name_push_to_hub,
328
+ ],
329
+ )
330
 
331
  app.load(get_org_dropdown, outputs=[orgs_selector])
332
+ app.load(fn=swap_visibilty, outputs=push_to_hub_ui)
src/distilabel_dataset_generator/utils.py CHANGED
@@ -39,7 +39,7 @@ def get_login_button():
39
  or get_space() is None
40
  ):
41
  return gr.LoginButton(
42
- value="Sign in with Hugging Face to generate a dataset!",
43
  size="lg",
44
  )
45
 
@@ -70,23 +70,7 @@ def get_org_dropdown(token: OAuthToken = None):
70
 
71
 
72
  def swap_visibilty(profile: Union[gr.OAuthProfile, None]):
73
- if get_space():
74
- if profile is None:
75
- return gr.Column(visible=False)
76
- else:
77
- return gr.Column(visible=True)
78
  else:
79
  return gr.Column(visible=True)
80
-
81
-
82
- def get_css():
83
- css = """
84
- h1{font-size: 2em}
85
- h3{margin-top: 0}
86
- #component-1{text-align:center}
87
- .main_ui_logged_out{opacity: 0.3; pointer-events: none}
88
- .tabitem{border: 0px}
89
- .group_padding{padding: .55em}
90
- #space_model .wrap > label:last-child{opacity: 0.3; pointer-events:none}
91
- """
92
- return css
 
39
  or get_space() is None
40
  ):
41
  return gr.LoginButton(
42
+ value="Sign in with Hugging Face to generate a full dataset and push it to the Hub!",
43
  size="lg",
44
  )
45
 
 
70
 
71
 
72
  def swap_visibilty(profile: Union[gr.OAuthProfile, None]):
73
+ if profile is None:
74
+ return gr.Column(visible=False)
 
 
 
75
  else:
76
  return gr.Column(visible=True)