fix: add seed for more randomized samples
Browse files
src/distilabel_dataset_generator/pipelines/textcat.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1 |
from typing import List
|
2 |
|
3 |
import pandas as pd
|
|
|
4 |
from distilabel.llms import InferenceEndpointsLLM
|
5 |
from distilabel.steps.tasks import (
|
6 |
GenerateTextClassificationData,
|
@@ -88,6 +89,7 @@ def generate_pipeline_code(
|
|
88 |
base_code = f"""
|
89 |
# Requirements: `pip install distilabel[hf-inference-endpoints]`
|
90 |
import os
|
|
|
91 |
from distilabel.llms import InferenceEndpointsLLM
|
92 |
from distilabel.pipeline import Pipeline
|
93 |
from distilabel.steps import LoadDataFromDicts, KeepColumns
|
@@ -111,6 +113,8 @@ with Pipeline(name="textcat") as pipeline:
|
|
111 |
generation_kwargs={{
|
112 |
"temperature": 0.8,
|
113 |
"max_new_tokens": 2048,
|
|
|
|
|
114 |
}},
|
115 |
),
|
116 |
difficulty={None if difficulty == "mixed" else repr(difficulty)},
|
@@ -151,6 +155,7 @@ with Pipeline(name="textcat") as pipeline:
|
|
151 |
generation_kwargs={{
|
152 |
"temperature": 0.8,
|
153 |
"max_new_tokens": 2048,
|
|
|
154 |
}},
|
155 |
),
|
156 |
n={num_labels},
|
@@ -175,9 +180,10 @@ def get_textcat_generator(difficulty, clarity, is_sample):
|
|
175 |
tokenizer_id=MODEL,
|
176 |
api_key=_get_next_api_key(),
|
177 |
generation_kwargs={
|
178 |
-
"temperature": 0.
|
179 |
"max_new_tokens": 256 if is_sample else 2048,
|
180 |
"do_sample": True,
|
|
|
181 |
},
|
182 |
),
|
183 |
difficulty=None if difficulty == "mixed" else difficulty,
|
|
|
1 |
from typing import List
|
2 |
|
3 |
import pandas as pd
|
4 |
+
import random
|
5 |
from distilabel.llms import InferenceEndpointsLLM
|
6 |
from distilabel.steps.tasks import (
|
7 |
GenerateTextClassificationData,
|
|
|
89 |
base_code = f"""
|
90 |
# Requirements: `pip install distilabel[hf-inference-endpoints]`
|
91 |
import os
|
92 |
+
import random
|
93 |
from distilabel.llms import InferenceEndpointsLLM
|
94 |
from distilabel.pipeline import Pipeline
|
95 |
from distilabel.steps import LoadDataFromDicts, KeepColumns
|
|
|
113 |
generation_kwargs={{
|
114 |
"temperature": 0.8,
|
115 |
"max_new_tokens": 2048,
|
116 |
+
"do_sample": True,
|
117 |
+
"seed": random.randint(0, 2**32 - 1),
|
118 |
}},
|
119 |
),
|
120 |
difficulty={None if difficulty == "mixed" else repr(difficulty)},
|
|
|
155 |
generation_kwargs={{
|
156 |
"temperature": 0.8,
|
157 |
"max_new_tokens": 2048,
|
158 |
+
"do_sample": True,
|
159 |
}},
|
160 |
),
|
161 |
n={num_labels},
|
|
|
180 |
tokenizer_id=MODEL,
|
181 |
api_key=_get_next_api_key(),
|
182 |
generation_kwargs={
|
183 |
+
"temperature": 0.9,
|
184 |
"max_new_tokens": 256 if is_sample else 2048,
|
185 |
"do_sample": True,
|
186 |
+
"seed": random.randint(0, 2**32 - 1),
|
187 |
},
|
188 |
),
|
189 |
difficulty=None if difficulty == "mixed" else difficulty,
|