mtasic85 committed
Commit: c721106
Parent: 716dba4

general pretrain data generation

Files changed (1):
  1. scripts/prepare_pretrain_dataset.py (+1 -333)
scripts/prepare_pretrain_dataset.py CHANGED
@@ -1,319 +1,10 @@
-import gc
 from typing import Optional
+from functools import partial
 
 from datasets import load_dataset
 from litdata import optimize, TokensLoader
 from litgpt.tokenizer import Tokenizer
-from functools import partial
-
-
-"""
-def batch_iterator_1(path=None):
-    # text
-    if path in (None, 'saillab/taco-datasets'):
-        dataset = (
-            load_dataset(path, data_dir=data_dir, split='train')
-            for data_dir in [
-                'multilingual-instruction-tuning-dataset /multilingual-alpaca-52k-gpt-4',
-                'multilingual-instruction-tuning-dataset /multilinugal-dolly-15k',
-            ]
-        )
-
-        for d in dataset:
-            for row in d:
-                for n in row:
-                    yield (
-                        row['instruction'] +
-                        ' ' +
-                        row['input'] +
-                        ' ' +
-                        row['output']
-                    )
-
-        del dataset
-        gc.collect()
-
-    # text
-    if path in (None, 'xu-song/cc100-samples'):
-        dataset = (
-            load_dataset(path, lang, split='train')
-            for lang in [
-                'am', 'ar', 'as', 'az', 'be', 'bg', 'bn', 'bn_rom', 'br',
-                'bs', 'ca', 'cs', 'cy', 'da', 'de', 'el', 'en', 'eo', 'es',
-                'et', 'eu', 'fa', 'ff', 'fi', 'fr', 'fy', 'ga', 'gd', 'gl',
-                'gn', 'gu', 'ha', 'he', 'hi', 'hi_rom', 'hr', 'ht', 'hu',
-                'hy', 'id', 'ig', 'is', 'it', 'ja', 'jv', 'ka', 'kk', 'km',
-                'kn', 'ko', 'ku', 'ky', 'la', 'lg', 'li', 'ln', 'lo', 'lt',
-                'lv', 'mg', 'mk', 'ml', 'mn', 'mr', 'ms', 'my', 'my_zaw',
-                'ne', 'nl', 'no', 'ns', 'om', 'or', 'pa', 'pl', 'ps', 'pt',
-                'qu', 'rm', 'ro', 'ru', 'sa', 'si', 'sc', 'sd', 'sk', 'sl',
-                'so', 'sq', 'sr', 'ss', 'su', 'sv', 'sw', 'ta', 'ta_rom',
-                'te', 'te_rom', 'th', 'tl', 'tn', 'tr', 'ug', 'uk', 'ur',
-                'ur_rom', 'uz', 'vi', 'wo', 'xh', 'yi', 'yo',
-                'zh-Hans', 'zh-Hant', 'zu',
-            ]
-        )
-
-        for d in dataset:
-            for row in d['text']:
-                yield row
-
-        del dataset
-        gc.collect()
-
-    # text
-    if path in (None, 'ontocord/fineweb-permissive-multilingual-2m'):
-        dataset = load_dataset(path, split='train')
-
-        for row in dataset['text']:
-            yield row
-
-        del dataset
-        gc.collect()
-
-    # text
-    if path in (None, 'MuskumPillerum/General-Knowledge'):
-        dataset = load_dataset(path, split='train')
-
-        for row in dataset:
-            if not row['Question'] or not row['Answer']:
-                continue
-
-            yield row['Question'] + ' ' + row['Answer']
-
-        del dataset
-        gc.collect()
-
-    # text
-    if path in (None, 'yirenc/general_knowledge_boolean'):
-        for split in ['train', 'validation']:
-            dataset = load_dataset(path, split=split)
-
-            for row in dataset:
-                yield row['question'] + '? ' + str(row['answer']) + '. ' + row['passage']
-
-            del dataset
-            gc.collect()
-
-    # text
-    if path in (None, 'nampdn-ai/tiny-textbooks'):
-        for split in ['train', 'test']:
-            dataset = load_dataset(path, split=split)
-
-            for row in dataset['textbook']:
-                yield row
-
-            del dataset
-            gc.collect()
-
-    # code
-    if path in (None, 'nampdn-ai/tiny-codes'):
-        dataset = load_dataset(path, split='train')
-
-        for row in dataset:
-            yield (
-                row['prompt'] +
-                ' ' +
-                row['response']
-            )
-
-        del dataset
-        gc.collect()
-
-    # code
-    if path in (None, 'bigcode/the-stack-smol-xs'):
-        dataset = (
-            load_dataset(path, lang, split='train', trust_remote_code=True)
-            for lang in [
-                'ada', 'agda', 'alloy', 'antlr', 'applescript', 'assembly',
-                'augeas', 'awk', 'batchfile', 'bison', 'bluespec', 'c',
-                'c++', 'c-sharp', 'clojure', 'cmake', 'coffeescript', 'common-lisp',
-                'css', 'cuda', 'dart', 'dockerfile', 'elixir',
-                'elm', 'emacs-lisp','erlang', 'f-sharp', 'fortran', 'glsl', 'go',
-                'groovy', 'haskell','html', 'idris', 'isabelle', 'java',
-                'java-server-pages', 'javascript', 'julia', 'kotlin', 'lean',
-                'literate-agda', 'literate-coffeescript', 'literate-haskell',
-                'lua', 'makefile', 'maple', 'markdown', 'mathematica', 'matlab',
-                'ocaml', 'pascal', 'perl', 'php', 'powershell', 'prolog',
-                'protocol-buffer', 'python', 'r', 'racket', 'restructuredtext',
-                'rmarkdown', 'ruby', 'rust', 'sas', 'scala', 'scheme',
-                'shell', 'smalltalk', 'solidity', 'sparql', 'sql', 'stan',
-                'standard-ml', 'stata', 'systemverilog', 'tcl', 'tcsh', 'tex',
-                'thrift', 'typescript', 'verilog', 'vhdl', 'visual-basic', 'xslt',
-                'yacc', 'zig',
-            ]
-        )
-
-        for d in dataset:
-            for row in d['content']:
-                yield row
-
-        del dataset
-        gc.collect()
-
-    # text + code
-    if path in (None, 'm-a-p/CodeFeedback-Filtered-Instruction'):
-        dataset = load_dataset(path, split='train')
-
-        for row in dataset:
-            yield (
-                row['query'] +
-                ' ' +
-                row['answer']
-            )
-
-        del dataset
-        gc.collect()
-
-    # code
-    if path in (None, 'jtatman/python-code-dataset-500k'):
-        dataset = load_dataset(path, split='train')
-
-        for row in dataset:
-            yield (
-                row['instruction'] +
-                ' ' +
-                row['output']
-            )
-
-        del dataset
-        gc.collect()
-
-    # code
-    if path in (None, 'iamtarun/python_code_instructions_18k_alpaca'):
-        dataset = load_dataset(path, split='train')
-
-        for row in dataset:
-            yield (
-                row['instruction'] +
-                ' ' +
-                row['input'] +
-                ' ' +
-                row['output']
-            )
-
-        del dataset
-        gc.collect()
-
-    # code
-    if path in (None, 'HuggingFaceH4/CodeAlpaca_20K'):
-        for split in ['train', 'test']:
-            dataset = load_dataset(path, split=split)
-
-            for row in dataset:
-                yield (
-                    row['prompt'] +
-                    ' ' +
-                    row['completion']
-                )
-
-            del dataset
-            gc.collect()
-
-    # math
-    if path in (None, 'gair-prox/open-web-math-pro'):
-        dataset = load_dataset(path, split='train')
-
-        for row in dataset['text']:
-            yield row
-
-        del dataset
-        gc.collect()
-
-    # math
-    if path in (None, 'rvv-karma/Math-QA'):
-        for split in ['train', 'val', 'test']:
-            dataset = load_dataset(path, split=split)
-
-            for row in dataset:
-                yield (
-                    row['question'] +
-                    ' ' +
-                    row['answer']
-                )
-
-            del dataset
-            gc.collect()
-
-    # math
-    if path in (None, 'ajibawa-2023/Maths-College'):
-        dataset = load_dataset(path, split='train')
-
-        for row in dataset:
-            yield (
-                row['instruction'] +
-                ' ' +
-                row['output']
-            )
-
-        del dataset
-        gc.collect()
-
-    # math
-    if path in (None, 'microsoft/orca-math-word-problems-200k'):
-        dataset = load_dataset(path, split='train')
-
-        for row in dataset:
-            yield (
-                row['question'] +
-                ' ' +
-                row['answer']
-            )
-
-        del dataset
-        gc.collect()
-
-    # math
-    if path in (None, 'fblgit/simple-math'):
-        for split in ['train', 'test']:
-            dataset = load_dataset(path, revision='refs/convert/parquet', split=split)
-
-            for row in dataset:
-                yield (
-                    str(row['instruction']) +
-                    ' = ' +
-                    str(row['output'])
-                )
-
-            del dataset
-            gc.collect()
-
-    # reasoning
-    if path in (None, 'SkunkworksAI/reasoning-0.01'):
-        dataset = load_dataset(path, split='train')
-
-        for row in dataset:
-            yield (
-                row['instruction'] +
-                ' ' +
-                row['reasoning'] +
-                ' ' +
-                row['output']
-            )
-
-        del dataset
-        gc.collect()
-
-    # emoji
-    if path in (None, 'badrex/llm-emoji-dataset'):
-        dataset = load_dataset(path, split='train')
-
-        for row in dataset:
-            yield (
-                row['character'] +
-                ' ' +
-                row['unicode'] +
-                ' ' +
-                row['short description'] +
-                ' ' +
-                str(row['tags']) +
-                ' ' +
-                row['LLM description']
-            )
-
-        del dataset
-        gc.collect()
-"""
 
 def batch_iterator(path: str,
                    name: Optional[str]=None,
@@ -342,29 +33,6 @@ def tokenize_fn(datasets_config, tokenizer=None):
     text_ids = tokenizer.encode(text, bos=False, eos=True)
     yield text_ids
 
-"""
-datasets_names = [
-    'saillab/taco-datasets',
-    # 'xu-song/cc100-samples',
-    # 'ontocord/fineweb-permissive-multilingual-2m',
-    'MuskumPillerum/General-Knowledge',
-    'yirenc/general_knowledge_boolean',
-    'nampdn-ai/tiny-textbooks',
-    # 'nampdn-ai/tiny-codes',
-    'bigcode/the-stack-smol-xs',
-    'm-a-p/CodeFeedback-Filtered-Instruction',
-    # 'jtatman/python-code-dataset-500k',
-    'iamtarun/python_code_instructions_18k_alpaca',
-    'HuggingFaceH4/CodeAlpaca_20K',
-    # 'gair-prox/open-web-math-pro',
-    'rvv-karma/Math-QA',
-    # 'ajibawa-2023/Maths-College',
-    'microsoft/orca-math-word-problems-200k',
-    'fblgit/simple-math',
-    # 'SkunkworksAI/reasoning-0.01',
-    'badrex/llm-emoji-dataset',
-]
-"""
 
 datasets_configs = [
    {'path': 'yahma/alpaca-cleaned', 'format': '{instruction} {input} {output}'},
 
 
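After this commit, every dataset flows through the same config-driven path: each `datasets_configs` entry pairs a Hugging Face dataset `path` (optionally a config `name`) with a format string such as '{instruction} {input} {output}', and `batch_iterator` / `tokenize_fn` turn the rendered rows into token ids. The sketch below is a minimal, hypothetical reconstruction of that flow for illustration only; `render_row`, `iter_texts`, and `tokenize_texts` are invented names, and the actual `batch_iterator`, `tokenize_fn`, and `litdata.optimize` call sites in the script may differ in details not shown in these hunks.

# Minimal sketch of the config-driven flow (assumption: helper names and
# details are illustrative, not the script's actual implementation).
from typing import Optional

from datasets import load_dataset
from litgpt.tokenizer import Tokenizer


def render_row(row: dict, fmt: str) -> str:
    # Fill the config's format string with the row's columns,
    # e.g. '{instruction} {input} {output}' for yahma/alpaca-cleaned,
    # and collapse the double space left behind when 'input' is empty.
    return ' '.join(fmt.format(**row).split())


def iter_texts(path: str, fmt: str, name: Optional[str]=None, split: str='train'):
    # One dataset config -> a stream of plain-text training examples.
    dataset = load_dataset(path, name, split=split)

    for row in dataset:
        yield render_row(row, fmt)


def tokenize_texts(texts, tokenizer: Tokenizer):
    # Mirrors the kept context lines above: encode each text with eos=True
    # so examples stay separated once the token streams are concatenated.
    for text in texts:
        yield tokenizer.encode(text, bos=False, eos=True)

Presumably the resulting token streams are then written out via litdata's `optimize` with a `TokensLoader` item loader, which would explain the `optimize`, `TokensLoader`, and `partial` imports kept at the top of the file, but those call sites fall outside the hunks shown in this diff.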