fix: apply feedback
Browse files
src/distilabel_dataset_generator/apps/base.py
CHANGED
@@ -38,8 +38,8 @@ def get_main_ui(
|
|
38 |
if task == TEXTCAT_TASK:
|
39 |
result = fn_generate_dataset(
|
40 |
system_prompt=system_prompt,
|
41 |
-
difficulty="
|
42 |
-
clarity="
|
43 |
labels=[],
|
44 |
num_labels=1,
|
45 |
num_rows=1,
|
@@ -271,7 +271,11 @@ def get_iterate_on_sample_dataset_ui(
|
|
271 |
with gr.Row():
|
272 |
sample_dataset = gr.Dataframe(
|
273 |
value=default_datasets[0],
|
274 |
-
label=
|
|
|
|
|
|
|
|
|
275 |
interactive=False,
|
276 |
wrap=True,
|
277 |
)
|
|
|
38 |
if task == TEXTCAT_TASK:
|
39 |
result = fn_generate_dataset(
|
40 |
system_prompt=system_prompt,
|
41 |
+
difficulty="high school",
|
42 |
+
clarity="clear",
|
43 |
labels=[],
|
44 |
num_labels=1,
|
45 |
num_rows=1,
|
|
|
271 |
with gr.Row():
|
272 |
sample_dataset = gr.Dataframe(
|
273 |
value=default_datasets[0],
|
274 |
+
label=(
|
275 |
+
"Sample dataset. Text truncated to 256 tokens."
|
276 |
+
if task == TEXTCAT_TASK
|
277 |
+
else "Sample dataset. Prompts and completions truncated to 256 tokens."
|
278 |
+
),
|
279 |
interactive=False,
|
280 |
wrap=True,
|
281 |
)
|
src/distilabel_dataset_generator/apps/textcat.py
CHANGED
@@ -215,7 +215,6 @@ def generate_dataset(
|
|
215 |
system_prompt=system_prompt,
|
216 |
labels=labels,
|
217 |
num_labels=num_labels,
|
218 |
-
is_sample=is_sample,
|
219 |
)
|
220 |
total_steps: int = num_rows * 2
|
221 |
batch_size = DEFAULT_BATCH_SIZE
|
@@ -309,6 +308,9 @@ def validate_input_labels(labels):
|
|
309 |
)
|
310 |
return labels
|
311 |
|
|
|
|
|
|
|
312 |
|
313 |
(
|
314 |
app,
|
@@ -354,7 +356,7 @@ with app:
|
|
354 |
],
|
355 |
value="mixed",
|
356 |
label="Difficulty",
|
357 |
-
info="
|
358 |
)
|
359 |
clarity = gr.Dropdown(
|
360 |
choices=[
|
@@ -368,7 +370,7 @@ with app:
|
|
368 |
],
|
369 |
value="mixed",
|
370 |
label="Clarity",
|
371 |
-
info="
|
372 |
)
|
373 |
with gr.Column():
|
374 |
labels = gr.Dropdown(
|
@@ -385,18 +387,18 @@ with app:
|
|
385 |
size="sm",
|
386 |
)
|
387 |
num_labels = gr.Number(
|
388 |
-
label="Number of labels",
|
389 |
value=1,
|
390 |
minimum=1,
|
391 |
maximum=10,
|
392 |
-
info="
|
393 |
)
|
394 |
num_rows = gr.Number(
|
395 |
label="Number of rows",
|
396 |
value=10,
|
397 |
minimum=1,
|
398 |
maximum=500,
|
399 |
-
info="More rows will take
|
400 |
)
|
401 |
|
402 |
pipeline_code = get_pipeline_code_ui(
|
@@ -415,6 +417,10 @@ with app:
|
|
415 |
fn=update_suggested_labels,
|
416 |
inputs=[system_prompt],
|
417 |
outputs=labels,
|
|
|
|
|
|
|
|
|
418 |
)
|
419 |
|
420 |
gr.on(
|
@@ -540,6 +546,10 @@ with app:
|
|
540 |
fn=generate_pipeline_code,
|
541 |
inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
|
542 |
outputs=[pipeline_code],
|
|
|
|
|
|
|
|
|
543 |
)
|
544 |
num_labels.change(
|
545 |
fn=generate_pipeline_code,
|
|
|
215 |
system_prompt=system_prompt,
|
216 |
labels=labels,
|
217 |
num_labels=num_labels,
|
|
|
218 |
)
|
219 |
total_steps: int = num_rows * 2
|
220 |
batch_size = DEFAULT_BATCH_SIZE
|
|
|
308 |
)
|
309 |
return labels
|
310 |
|
311 |
+
def update_max_num_labels(labels):
|
312 |
+
return gr.update(maximum=len(labels) if labels else 1)
|
313 |
+
|
314 |
|
315 |
(
|
316 |
app,
|
|
|
356 |
],
|
357 |
value="mixed",
|
358 |
label="Difficulty",
|
359 |
+
info="Select the comprehension level for the text. Ensure it matches the task context.",
|
360 |
)
|
361 |
clarity = gr.Dropdown(
|
362 |
choices=[
|
|
|
370 |
],
|
371 |
value="mixed",
|
372 |
label="Clarity",
|
373 |
+
info="Set how easily the correct label can be identified.",
|
374 |
)
|
375 |
with gr.Column():
|
376 |
labels = gr.Dropdown(
|
|
|
387 |
size="sm",
|
388 |
)
|
389 |
num_labels = gr.Number(
|
390 |
+
label="Number of labels per text",
|
391 |
value=1,
|
392 |
minimum=1,
|
393 |
maximum=10,
|
394 |
+
info="Select 1 for single-label and >1 for multi-label.",
|
395 |
)
|
396 |
num_rows = gr.Number(
|
397 |
label="Number of rows",
|
398 |
value=10,
|
399 |
minimum=1,
|
400 |
maximum=500,
|
401 |
+
info="Select the number of rows in the dataset. More rows will take more time.",
|
402 |
)
|
403 |
|
404 |
pipeline_code = get_pipeline_code_ui(
|
|
|
417 |
fn=update_suggested_labels,
|
418 |
inputs=[system_prompt],
|
419 |
outputs=labels,
|
420 |
+
).then(
|
421 |
+
fn=update_max_num_labels,
|
422 |
+
inputs=[labels],
|
423 |
+
outputs=[num_labels],
|
424 |
)
|
425 |
|
426 |
gr.on(
|
|
|
546 |
fn=generate_pipeline_code,
|
547 |
inputs=[system_prompt, difficulty, clarity, labels, num_labels, num_rows],
|
548 |
outputs=[pipeline_code],
|
549 |
+
).then(
|
550 |
+
fn=update_max_num_labels,
|
551 |
+
inputs=[labels],
|
552 |
+
outputs=[num_labels],
|
553 |
)
|
554 |
num_labels.change(
|
555 |
fn=generate_pipeline_code,
|
src/distilabel_dataset_generator/pipelines/textcat.py
CHANGED
@@ -176,7 +176,8 @@ def get_textcat_generator(difficulty, clarity, is_sample):
|
|
176 |
api_key=_get_next_api_key(),
|
177 |
generation_kwargs={
|
178 |
"temperature": 0.8,
|
179 |
-
"max_new_tokens": 256 if is_sample else
|
|
|
180 |
},
|
181 |
),
|
182 |
difficulty=None if difficulty == "mixed" else difficulty,
|
@@ -186,7 +187,7 @@ def get_textcat_generator(difficulty, clarity, is_sample):
|
|
186 |
return textcat_generator
|
187 |
|
188 |
|
189 |
-
def get_labeller_generator(system_prompt, labels, num_labels, is_sample):
|
190 |
labeller_generator = TextClassification(
|
191 |
llm=InferenceEndpointsLLM(
|
192 |
model_id=MODEL,
|
@@ -194,7 +195,7 @@ def get_labeller_generator(system_prompt, labels, num_labels, is_sample):
|
|
194 |
api_key=_get_next_api_key(),
|
195 |
generation_kwargs={
|
196 |
"temperature": 0.8,
|
197 |
-
"max_new_tokens":
|
198 |
},
|
199 |
),
|
200 |
context=system_prompt,
|
|
|
176 |
api_key=_get_next_api_key(),
|
177 |
generation_kwargs={
|
178 |
"temperature": 0.8,
|
179 |
+
"max_new_tokens": 256 if is_sample else 2048,
|
180 |
+
"do_sample": True,
|
181 |
},
|
182 |
),
|
183 |
difficulty=None if difficulty == "mixed" else difficulty,
|
|
|
187 |
return textcat_generator
|
188 |
|
189 |
|
190 |
+
def get_labeller_generator(system_prompt, labels, num_labels):
|
191 |
labeller_generator = TextClassification(
|
192 |
llm=InferenceEndpointsLLM(
|
193 |
model_id=MODEL,
|
|
|
195 |
api_key=_get_next_api_key(),
|
196 |
generation_kwargs={
|
197 |
"temperature": 0.8,
|
198 |
+
"max_new_tokens": 2048,
|
199 |
},
|
200 |
),
|
201 |
context=system_prompt,
|