sdiazlor HF staff commited on
Commit
46f00bc
1 Parent(s): d27c1e6

fix: add seed for more randomized samples

Browse files
src/distilabel_dataset_generator/pipelines/textcat.py CHANGED
@@ -1,6 +1,7 @@
1
  from typing import List
2
 
3
  import pandas as pd
 
4
  from distilabel.llms import InferenceEndpointsLLM
5
  from distilabel.steps.tasks import (
6
  GenerateTextClassificationData,
@@ -88,6 +89,7 @@ def generate_pipeline_code(
88
  base_code = f"""
89
  # Requirements: `pip install distilabel[hf-inference-endpoints]`
90
  import os
 
91
  from distilabel.llms import InferenceEndpointsLLM
92
  from distilabel.pipeline import Pipeline
93
  from distilabel.steps import LoadDataFromDicts, KeepColumns
@@ -111,6 +113,8 @@ with Pipeline(name="textcat") as pipeline:
111
  generation_kwargs={{
112
  "temperature": 0.8,
113
  "max_new_tokens": 2048,
 
 
114
  }},
115
  ),
116
  difficulty={None if difficulty == "mixed" else repr(difficulty)},
@@ -151,6 +155,7 @@ with Pipeline(name="textcat") as pipeline:
151
  generation_kwargs={{
152
  "temperature": 0.8,
153
  "max_new_tokens": 2048,
 
154
  }},
155
  ),
156
  n={num_labels},
@@ -175,9 +180,10 @@ def get_textcat_generator(difficulty, clarity, is_sample):
175
  tokenizer_id=MODEL,
176
  api_key=_get_next_api_key(),
177
  generation_kwargs={
178
- "temperature": 0.8,
179
  "max_new_tokens": 256 if is_sample else 2048,
180
  "do_sample": True,
 
181
  },
182
  ),
183
  difficulty=None if difficulty == "mixed" else difficulty,
 
1
  from typing import List
2
 
3
  import pandas as pd
4
+ import random
5
  from distilabel.llms import InferenceEndpointsLLM
6
  from distilabel.steps.tasks import (
7
  GenerateTextClassificationData,
 
89
  base_code = f"""
90
  # Requirements: `pip install distilabel[hf-inference-endpoints]`
91
  import os
92
+ import random
93
  from distilabel.llms import InferenceEndpointsLLM
94
  from distilabel.pipeline import Pipeline
95
  from distilabel.steps import LoadDataFromDicts, KeepColumns
 
113
  generation_kwargs={{
114
  "temperature": 0.8,
115
  "max_new_tokens": 2048,
116
+ "do_sample": True,
117
+ "seed": random.randint(0, 2**32 - 1),
118
  }},
119
  ),
120
  difficulty={None if difficulty == "mixed" else repr(difficulty)},
 
155
  generation_kwargs={{
156
  "temperature": 0.8,
157
  "max_new_tokens": 2048,
158
+ "do_sample": True,
159
  }},
160
  ),
161
  n={num_labels},
 
180
  tokenizer_id=MODEL,
181
  api_key=_get_next_api_key(),
182
  generation_kwargs={
183
+ "temperature": 0.9,
184
  "max_new_tokens": 256 if is_sample else 2048,
185
  "do_sample": True,
186
+ "seed": random.randint(0, 2**32 - 1),
187
  },
188
  ),
189
  difficulty=None if difficulty == "mixed" else difficulty,