sdiazlor HF staff committed on
Commit
3f21280
1 Parent(s): 3c2fc33

fix: apply feedback

Browse files
src/distilabel_dataset_generator/apps/base.py CHANGED
@@ -38,8 +38,8 @@ def get_main_ui(
38
  if task == TEXTCAT_TASK:
39
  result = fn_generate_dataset(
40
  system_prompt=system_prompt,
41
- difficulty="mixed",
42
- clarity="mixed",
43
  labels=[],
44
  num_labels=1,
45
  num_rows=1,
@@ -271,7 +271,11 @@ def get_iterate_on_sample_dataset_ui(
271
  with gr.Row():
272
  sample_dataset = gr.Dataframe(
273
  value=default_datasets[0],
274
- label="Sample dataset. Prompts and completions truncated to 256 tokens.",
 
 
 
 
275
  interactive=False,
276
  wrap=True,
277
  )
 
38
  if task == TEXTCAT_TASK:
39
  result = fn_generate_dataset(
40
  system_prompt=system_prompt,
41
+ difficulty="high school",
42
+ clarity="clear",
43
  labels=[],
44
  num_labels=1,
45
  num_rows=1,
 
271
  with gr.Row():
272
  sample_dataset = gr.Dataframe(
273
  value=default_datasets[0],
274
+ label=(
275
+ "Sample dataset. Text truncated to 256 tokens."
276
+ if task == TEXTCAT_TASK
277
+ else "Sample dataset. Prompts and completions truncated to 256 tokens."
278
+ ),
279
  interactive=False,
280
  wrap=True,
281
  )
src/distilabel_dataset_generator/apps/textcat.py CHANGED
@@ -215,7 +215,6 @@ def generate_dataset(
215
  system_prompt=system_prompt,
216
  labels=labels,
217
  num_labels=num_labels,
218
- is_sample=is_sample,
219
  )
220
  total_steps: int = num_rows * 2
221
  batch_size = DEFAULT_BATCH_SIZE
@@ -309,6 +308,9 @@ def validate_input_labels(labels):
309
  )
310
  return labels
311
 
 
 
 
312
 
313
  (
314
  app,
@@ -354,7 +356,7 @@ with app:
354
  ],
355
  value="mixed",
356
  label="Difficulty",
357
- info="The difficulty of the text to be generated.",
358
  )
359
  clarity = gr.Dropdown(
360
  choices=[
@@ -368,7 +370,7 @@ with app:
368
  ],
369
  value="mixed",
370
  label="Clarity",
371
- info="The clarity of the text to be generated.",
372
  )
373
  with gr.Column():
374
  labels = gr.Dropdown(
@@ -385,18 +387,18 @@ with app:
385
  size="sm",
386
  )
387
  num_labels = gr.Number(
388
- label="Number of labels",
389
  value=1,
390
  minimum=1,
391
  maximum=10,
392
- info="The number of labels to classify the text.",
393
  )
394
  num_rows = gr.Number(
395
  label="Number of rows",
396
  value=10,
397
  minimum=1,
398
  maximum=500,
399
- info="More rows will take longer to generate.",
400
  )
401
 
402
  pipeline_code = get_pipeline_code_ui(
@@ -415,6 +417,10 @@ with app:
415
  fn=update_suggested_labels,
416
  inputs=[system_prompt],
417
  outputs=labels,
 
 
 
 
418
  )
419
 
420
  gr.on(
@@ -540,6 +546,10 @@ with app:
540
  fn=generate_pipeline_code,
541
  inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
542
  outputs=[pipeline_code],
 
 
 
 
543
  )
544
  num_labels.change(
545
  fn=generate_pipeline_code,
 
215
  system_prompt=system_prompt,
216
  labels=labels,
217
  num_labels=num_labels,
 
218
  )
219
  total_steps: int = num_rows * 2
220
  batch_size = DEFAULT_BATCH_SIZE
 
308
  )
309
  return labels
310
 
311
+ def update_max_num_labels(labels):
312
+ return gr.update(maximum=len(labels) if labels else 1)
313
+
314
 
315
  (
316
  app,
 
356
  ],
357
  value="mixed",
358
  label="Difficulty",
359
+ info="Select the comprehension level for the text. Ensure it matches the task context.",
360
  )
361
  clarity = gr.Dropdown(
362
  choices=[
 
370
  ],
371
  value="mixed",
372
  label="Clarity",
373
+ info="Set how easily the correct label can be identified.",
374
  )
375
  with gr.Column():
376
  labels = gr.Dropdown(
 
387
  size="sm",
388
  )
389
  num_labels = gr.Number(
390
+ label="Number of labels per text",
391
  value=1,
392
  minimum=1,
393
  maximum=10,
394
+ info="Select 1 for single-label and >1 for multi-label.",
395
  )
396
  num_rows = gr.Number(
397
  label="Number of rows",
398
  value=10,
399
  minimum=1,
400
  maximum=500,
401
+ info="Select the number of rows in the dataset. More rows will take more time.",
402
  )
403
 
404
  pipeline_code = get_pipeline_code_ui(
 
417
  fn=update_suggested_labels,
418
  inputs=[system_prompt],
419
  outputs=labels,
420
+ ).then(
421
+ fn=update_max_num_labels,
422
+ inputs=[labels, num_labels],
423
+ outputs=[num_labels],
424
  )
425
 
426
  gr.on(
 
546
  fn=generate_pipeline_code,
547
  inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
548
  outputs=[pipeline_code],
549
+ ).then(
550
+ fn=update_max_num_labels,
551
+ inputs=[labels, num_labels],
552
+ outputs=[num_labels],
553
  )
554
  num_labels.change(
555
  fn=generate_pipeline_code,
src/distilabel_dataset_generator/pipelines/textcat.py CHANGED
@@ -176,7 +176,8 @@ def get_textcat_generator(difficulty, clarity, is_sample):
176
  api_key=_get_next_api_key(),
177
  generation_kwargs={
178
  "temperature": 0.8,
179
- "max_new_tokens": 256 if is_sample else 1024,
 
180
  },
181
  ),
182
  difficulty=None if difficulty == "mixed" else difficulty,
@@ -186,7 +187,7 @@ def get_textcat_generator(difficulty, clarity, is_sample):
186
  return textcat_generator
187
 
188
 
189
- def get_labeller_generator(system_prompt, labels, num_labels, is_sample):
190
  labeller_generator = TextClassification(
191
  llm=InferenceEndpointsLLM(
192
  model_id=MODEL,
@@ -194,7 +195,7 @@ def get_labeller_generator(system_prompt, labels, num_labels, is_sample):
194
  api_key=_get_next_api_key(),
195
  generation_kwargs={
196
  "temperature": 0.8,
197
- "max_new_tokens": 256 if is_sample else 1024,
198
  },
199
  ),
200
  context=system_prompt,
 
176
  api_key=_get_next_api_key(),
177
  generation_kwargs={
178
  "temperature": 0.8,
179
+ "max_new_tokens": 256 if is_sample else 2048,
180
+ "do_sample": True,
181
  },
182
  ),
183
  difficulty=None if difficulty == "mixed" else difficulty,
 
187
  return textcat_generator
188
 
189
 
190
+ def get_labeller_generator(system_prompt, labels, num_labels):
191
  labeller_generator = TextClassification(
192
  llm=InferenceEndpointsLLM(
193
  model_id=MODEL,
 
195
  api_key=_get_next_api_key(),
196
  generation_kwargs={
197
  "temperature": 0.8,
198
+ "max_new_tokens": 2048,
199
  },
200
  ),
201
  context=system_prompt,