Text Generation
Transformers
PyTorch
Safetensors
Finnish
llama
finnish
text-generation-inference
aapot commited on
Commit
a971b09
1 Parent(s): f8bd133
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__
EasyLM/__init__.py ADDED
File without changes
EasyLM/checkpoint.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ from ml_collections import ConfigDict
4
+ import mlxu
5
+ import jax
6
+ import jax.numpy as jnp
7
+ import flax
8
+ from flax.serialization import (
9
+ from_bytes, to_bytes, to_state_dict, from_state_dict
10
+ )
11
+ from flax.traverse_util import flatten_dict, unflatten_dict, empty_node
12
+ import msgpack
13
+
14
+ from EasyLM.jax_utils import tree_apply, float_tensor_to_dtype
15
+
16
+
17
+ class StreamingCheckpointer(object):
18
+ """ Custom msgpack checkpointer that saves large train states by serializing
19
+ and saving tensors one by one in a streaming fashion. Avoids running
20
+ out of memory or local TPU disk with default flax checkpointer.
21
+ """
22
+
23
+ @staticmethod
24
+ def get_default_config(updates=None):
25
+ config = ConfigDict()
26
+ config.float_dtype = 'bf16'
27
+ config.save_optimizer_state = False
28
+
29
+ if updates is not None:
30
+ config.update(ConfigDict(updates).copy_and_resolve_references())
31
+ return config
32
+
33
+ def __init__(self, config, checkpoint_dir, enable=True):
34
+ self.config = self.get_default_config(config)
35
+ self.checkpoint_dir = checkpoint_dir
36
+ self.enable = enable
37
+
38
+ def save_checkpoint(self, train_state, filename, gather_fns=None):
39
+ if self.enable:
40
+ path = os.path.join(self.checkpoint_dir, filename)
41
+ else:
42
+ path = '/dev/null'
43
+ self.save_train_state_to_file(
44
+ train_state, path, gather_fns, self.config.float_dtype
45
+ )
46
+
47
+ @staticmethod
48
+ def save_train_state_to_file(train_state, path, gather_fns=None, float_dtype=None):
49
+ train_state = to_state_dict(train_state)
50
+ packer = msgpack.Packer()
51
+ flattend_train_state = flatten_dict(train_state)
52
+ if gather_fns is not None:
53
+ gather_fns = flatten_dict(to_state_dict(gather_fns))
54
+
55
+ with mlxu.open_file(path, "wb") as fout:
56
+ for key, value in flattend_train_state.items():
57
+ if gather_fns is not None:
58
+ value = gather_fns[key](value)
59
+ value = float_tensor_to_dtype(value, float_dtype)
60
+ fout.write(packer.pack((key, to_bytes(value))))
61
+
62
+ def save_pickle(self, obj, filename):
63
+ if self.enable:
64
+ path = os.path.join(self.checkpoint_dir, filename)
65
+ else:
66
+ path = '/dev/null'
67
+ mlxu.save_pickle(obj, path)
68
+
69
+ def save_all(self, train_state, gather_fns, metadata=None, dataset=None, milestone=False):
70
+ step = int(jax.device_get(train_state.step))
71
+ if self.config.save_optimizer_state:
72
+ checkpoint_state = train_state
73
+ checkpoint_name = 'streaming_train_state'
74
+ checkpoint_gather_fns = gather_fns
75
+ else:
76
+ checkpoint_state = train_state.params['params']
77
+ checkpoint_name = 'streaming_params'
78
+ checkpoint_gather_fns = gather_fns.params['params']
79
+
80
+ if milestone:
81
+ # Save a milestone checkpoint that will not be overwritten
82
+ self.save_pickle(metadata, f'metadata_{step}.pkl')
83
+ self.save_pickle(dataset, f'dataset_{step}.pkl')
84
+ self.save_checkpoint(
85
+ checkpoint_state, f'{checkpoint_name}_{step}', checkpoint_gather_fns
86
+ )
87
+ else:
88
+ # Save a normal checkpoint that can be overwritten
89
+ self.save_pickle(metadata, 'metadata.pkl')
90
+ self.save_pickle(dataset, 'dataset.pkl')
91
+ self.save_checkpoint(
92
+ checkpoint_state, f'{checkpoint_name}', checkpoint_gather_fns
93
+ )
94
+
95
+ @staticmethod
96
+ def load_checkpoint(path, target=None, shard_fns=None, remove_dict_prefix=None):
97
+ if shard_fns is not None:
98
+ shard_fns = flatten_dict(
99
+ to_state_dict(shard_fns)
100
+ )
101
+ if remove_dict_prefix is not None:
102
+ remove_dict_prefix = tuple(remove_dict_prefix)
103
+ flattend_train_state = {}
104
+ with mlxu.open_file(path) as fin:
105
+ # 83886080 bytes = 80 MB, which is 16 blocks on GCS
106
+ unpacker = msgpack.Unpacker(fin, read_size=83886080, max_buffer_size=0)
107
+ for key, value in unpacker:
108
+ key = tuple(key)
109
+ if remove_dict_prefix is not None:
110
+ if key[:len(remove_dict_prefix)] == remove_dict_prefix:
111
+ key = key[len(remove_dict_prefix):]
112
+ else:
113
+ continue
114
+
115
+ tensor = from_bytes(None, value)
116
+ if shard_fns is not None:
117
+ tensor = shard_fns[key](tensor)
118
+ flattend_train_state[key] = tensor
119
+
120
+ if target is not None:
121
+ flattened_target = flatten_dict(
122
+ to_state_dict(target), keep_empty_nodes=True
123
+ )
124
+ for key, value in flattened_target.items():
125
+ if key not in flattend_train_state and value == empty_node:
126
+ flattend_train_state[key] = value
127
+
128
+ train_state = unflatten_dict(flattend_train_state)
129
+ if target is None:
130
+ return train_state
131
+
132
+ return from_state_dict(target, train_state)
133
+
134
+ @staticmethod
135
+ def load_flax_checkpoint(path, target=None, shard_fns=None):
136
+ """ Load a standard flax checkpoint that's not saved with the
137
+ msgpack streaming format.
138
+ """
139
+ with mlxu.open_file(path, "rb") as fin:
140
+ encoded_bytes = fin.read()
141
+
142
+ state_dict = flax.serialization.msgpack_restore(encoded_bytes)
143
+ if shard_fns is not None:
144
+ shard_fns = to_state_dict(shard_fns)
145
+ state_dict = tree_apply(shard_fns, state_dict)
146
+
147
+ if target is None:
148
+ return state_dict
149
+ return from_state_dict(target, state_dict)
150
+
151
+ @classmethod
152
+ def load_trainstate_checkpoint(cls, load_from, trainstate_target=None,
153
+ trainstate_shard_fns=None,
154
+ disallow_trainstate=False):
155
+ if trainstate_target is not None:
156
+ params_target = trainstate_target.params['params']
157
+ else:
158
+ params_target = None
159
+
160
+ if trainstate_shard_fns is not None:
161
+ params_shard_fns = trainstate_shard_fns.params['params']
162
+ else:
163
+ params_shard_fns = None
164
+
165
+ load_type, load_path = load_from.split('::', 1)
166
+ if disallow_trainstate:
167
+ assert load_type != 'trainstate', 'Loading full trainstate is not allowed!'
168
+ train_state = None
169
+ restored_params = None
170
+ if load_type == 'trainstate':
171
+ # Load the entire train state in the streaming format
172
+ train_state = cls.load_checkpoint(
173
+ path=load_path,
174
+ target=trainstate_target,
175
+ shard_fns=trainstate_shard_fns,
176
+ )
177
+ elif load_type == 'trainstate_params':
178
+ # Load the params part of the train state in the streaming format
179
+ restored_params = cls.load_checkpoint(
180
+ path=load_path,
181
+ target=params_target,
182
+ shard_fns=params_shard_fns,
183
+ remove_dict_prefix=('params', 'params'),
184
+ )
185
+ restored_params = flax.core.frozen_dict.freeze(
186
+ {'params': restored_params}
187
+ )
188
+ elif load_type == 'params':
189
+ # Load the params in the streaming format
190
+ restored_params = cls.load_checkpoint(
191
+ path=load_path,
192
+ target=params_target,
193
+ shard_fns=params_shard_fns,
194
+ )
195
+ restored_params = flax.core.frozen_dict.freeze(
196
+ {'params': restored_params}
197
+ )
198
+ elif load_type == 'flax_params':
199
+ # Load the params in the standard flax format (non-streaming)
200
+ # This requires the entire params to fit in memory
201
+ restored_params = cls.load_flax_checkpoint(
202
+ path=load_path,
203
+ target=params_target,
204
+ shard_fns=params_shard_fns
205
+ )
206
+ restored_params = flax.core.frozen_dict.freeze(
207
+ {'params': restored_params}
208
+ )
209
+ else:
210
+ raise ValueError(f'Invalid load_from type: {load_type}')
211
+
212
+ return train_state, restored_params
EasyLM/data.py ADDED
@@ -0,0 +1,412 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ import pprint
3
+ import time
4
+ from functools import partial
5
+ import json
6
+ from multiprocessing import Pool
7
+
8
+ import h5py
9
+ import mlxu
10
+ from ml_collections.config_dict import config_dict
11
+ from ml_collections import ConfigDict
12
+ from tqdm import tqdm, trange
13
+ import numpy as np
14
+
15
+ from datasets import load_dataset, load_from_disk
16
+
17
+
18
+ class DatasetFactory(object):
19
+ """ Datset builder class. """
20
+
21
+ @staticmethod
22
+ def get_default_config(updates=None):
23
+ config = ConfigDict()
24
+ config.type = 'huggingface'
25
+ config.text_processor = TextProcessor.get_default_config()
26
+ config.huggingface_dataset = HuggingfaceDataset.get_default_config()
27
+ config.json_dataset = JsonDataset.get_default_config()
28
+
29
+ if updates is not None:
30
+ config.update(ConfigDict(updates).copy_and_resolve_references())
31
+ return config
32
+
33
+ @classmethod
34
+ def load_dataset(cls, config, tokenizer, **kwargs):
35
+ config = cls.get_default_config(config)
36
+ text_processor = TextProcessor(config.text_processor, tokenizer)
37
+ if config.type == 'huggingface':
38
+ return HuggingfaceDataset(
39
+ config.huggingface_dataset, tokenizer, text_processor, **kwargs
40
+ )
41
+ elif config.type == 'json':
42
+ return JsonDataset(config.json_dataset, tokenizer, text_processor, **kwargs)
43
+ else:
44
+ raise ValueError(f'Unknown dataset type: {config.type}')
45
+
46
+ def __init__(self):
47
+ raise ValueError('DatasetFactory is a static class and should not be instantiated.')
48
+
49
+
50
+ class TextProcessor(object):
51
+ """ Example processor that converts a dictionary of texts into tokens. """
52
+
53
+ @staticmethod
54
+ def get_default_config(updates=None):
55
+ config = ConfigDict()
56
+ config.fields_from_example = ''
57
+ config.fields = ''
58
+ config.subfield_separator = ' '
59
+ config.add_bos_token = True
60
+ config.add_eos_token = True
61
+ config.prepend_text = ''
62
+ if updates is not None:
63
+ config.update(ConfigDict(updates).copy_and_resolve_references())
64
+ return config
65
+
66
+ def __init__(self, config, tokenizer):
67
+ self.config = self.get_default_config(config)
68
+ assert self.config.fields != '' or self.config.fields_from_example != '', (
69
+ 'Either fields or fields_from_example must be specified.'
70
+ )
71
+ self.tokenizer = tokenizer
72
+
73
+ def __call__(self, example, has_aux=False):
74
+ if has_aux:
75
+ example, *aux = example
76
+ else:
77
+ aux = tuple()
78
+ token_buffer = []
79
+ loss_mask_buffer = []
80
+
81
+ if self.config.add_bos_token:
82
+ token_buffer.append(self.tokenizer.bos_token_id)
83
+ loss_mask_buffer.append(0.0)
84
+
85
+ if self.config.fields_from_example != '':
86
+ fields = example[self.config.fields_from_example].split(',')
87
+ else:
88
+ fields = self.config.fields.split(',')
89
+
90
+ for i, field in enumerate(fields):
91
+ if field.startswith('[') and field.endswith(']'):
92
+ # No loss for this field.
93
+ field = field[1:-1]
94
+ mask = 0.0
95
+ else:
96
+ mask = 1.0
97
+
98
+ if field == '<|bos|>':
99
+ token_buffer.append(self.tokenizer.bos_token_id)
100
+ loss_mask_buffer.append(mask)
101
+ elif field == '<|eos|>':
102
+ token_buffer.append(self.tokenizer.eos_token_id)
103
+ loss_mask_buffer.append(mask)
104
+ else:
105
+ subfields = field.split('+')
106
+ text = self.config.subfield_separator.join(
107
+ [example[subfield] for subfield in subfields]
108
+ )
109
+ if i == 0:
110
+ text = self.config.prepend_text + text
111
+ tokens = self.tokenizer.encode(text)
112
+ token_buffer.extend(tokens)
113
+ loss_mask_buffer.extend([mask for _ in range(len(tokens))])
114
+
115
+ if self.config.add_eos_token:
116
+ token_buffer.append(self.tokenizer.eos_token_id)
117
+ loss_mask_buffer.append(1.0)
118
+
119
+ return token_buffer, loss_mask_buffer, *aux
120
+
121
+
122
+ class HuggingfaceDataset(object):
123
+ """ Huggingface dataset, where the dataset is loaded using the huggingface
124
+ datasets.load_dataset() function.
125
+ """
126
+
127
+ @staticmethod
128
+ def get_default_config(updates=None):
129
+ config = ConfigDict()
130
+ config.path = 'c4'
131
+ config.name = 'en'
132
+ config.split = 'train'
133
+ config.streaming = False
134
+ config.seq_length = 1024
135
+ config.batch_size = 8
136
+ config.always_start_with_bos = False
137
+ config.start_seek_loc = 0
138
+ config.tokens_count_at_start = 0
139
+
140
+ if updates is not None:
141
+ config.update(ConfigDict(updates).copy_and_resolve_references())
142
+ return config
143
+
144
+ def __init__(self, config, tokenizer, text_processor, eval_dataset=False):
145
+ self.config = self.get_default_config(config)
146
+ name = self.config.name if self.config.name != '' else None
147
+ split = self.config.split if self.config.split != '' else None
148
+ self._tokenizer = tokenizer
149
+ self._text_processor = text_processor
150
+ self._dataset = load_from_disk(
151
+ self.config.path
152
+ )[split]
153
+ self._dataset = self._dataset.to_iterable_dataset(num_shards=128 if len(self._dataset) > 128 else len(self._dataset))
154
+ self._eval_dataset = eval_dataset
155
+ self._train_epochs = 0
156
+ self._dataset_loc = self.config.start_seek_loc
157
+ self._total_tokens = self.config.tokens_count_at_start
158
+ self._index = 0
159
+
160
+ def __iter__(self):
161
+ chunk_size = self.config.batch_size * self.config.seq_length
162
+ total_tokens = 0
163
+ while True:
164
+ token_buffer = []
165
+ loss_mask_buffer = []
166
+ for index, example in enumerate(self._dataset):
167
+ self._index = index
168
+ if not self._eval_dataset and self._dataset_loc > index:
169
+ continue
170
+ tokens, loss_masks = self.text_processor(example)
171
+ token_buffer.extend(tokens)
172
+ loss_mask_buffer.extend(loss_masks)
173
+ while len(token_buffer) > chunk_size + 1:
174
+ self._total_tokens += chunk_size
175
+ metrics = {
176
+ 'dataset_example_index': index,
177
+ 'dataset_total_tokens': self._total_tokens,
178
+ 'epoch': self._train_epochs,
179
+ }
180
+ batch = {
181
+ 'input_tokens': np.array(token_buffer[:chunk_size], dtype=np.int32).reshape(
182
+ self.config.batch_size, -1
183
+ ),
184
+ 'target_tokens': np.array(token_buffer[1:chunk_size + 1], dtype=np.int32).reshape(
185
+ self.config.batch_size, -1
186
+ ),
187
+ 'loss_masks': np.array(loss_mask_buffer[1:chunk_size + 1], dtype=np.float32).reshape(
188
+ self.config.batch_size, -1
189
+ ),
190
+ }
191
+ if self.config.always_start_with_bos:
192
+ batch['input_tokens'][:, 0] = self.tokenizer.bos_token_id
193
+ yield batch, metrics
194
+ token_buffer = token_buffer[chunk_size:]
195
+ loss_mask_buffer = loss_mask_buffer[chunk_size:]
196
+
197
+ if self._eval_dataset:
198
+ break
199
+ else:
200
+ self._dataset_loc = 0
201
+ self._shuffle()
202
+ self._train_epochs += 1
203
+ print(f"TRAIN {self._train_epochs} EPOCH DONE")
204
+
205
+ def _shuffle(self):
206
+ self._dataset = self._dataset.shuffle(buffer_size=100)
207
+
208
+ def get_state_dict(self):
209
+ return dict(
210
+ config=self.config,
211
+ dataset_loc=self._index if self._train_epochs < 1 else 0,
212
+ total_tokens=self._total_tokens,
213
+ epochs=self._train_epochs,
214
+ )
215
+
216
+ def load_state_dict(self, state_dict):
217
+ if 'config' in state_dict:
218
+ self.config.update(ConfigDict(state_dict['config']))
219
+ self._dataset_loc = state_dict.get('dataset_loc', self.config.start_seek_loc)
220
+ self._total_tokens = state_dict.get('total_tokens', self.config.tokens_count_at_start)
221
+ self._train_epochs = state_dict.get('epochs', 0)
222
+
223
+ @property
224
+ def seq_length(self):
225
+ return self.config.seq_length
226
+
227
+ @property
228
+ def tokenizer(self):
229
+ return self._tokenizer
230
+
231
+ @property
232
+ def text_processor(self):
233
+ return self._text_processor
234
+
235
+ @property
236
+ def dataset(self):
237
+ return self._dataset
238
+
239
+ @property
240
+ def vocab_size(self):
241
+ return len(self._tokenizer)
242
+
243
+
244
+ class JsonDataset(object):
245
+ """ JSON dataset, where each line of the data file contains a JSON
246
+ dictionary with text fields.
247
+ """
248
+
249
+ @staticmethod
250
+ def get_default_config(updates=None):
251
+ config = ConfigDict()
252
+ config.path = ''
253
+ config.seq_length = 1024
254
+ config.batch_size = 8
255
+ config.always_start_with_bos = False
256
+ config.start_seek_loc = 0
257
+ config.example_index_at_start = 0
258
+ config.tokens_count_at_start = 0
259
+ config.tokenizer_processes = 1
260
+ config.tokenizer_parallel_chunk_size = 32
261
+ config.tokenizer_parallel_batch_size = 1024
262
+ config.throughput_average_window_size = 200
263
+
264
+ if updates is not None:
265
+ config.update(ConfigDict(updates).copy_and_resolve_references())
266
+ return config
267
+
268
+ def __init__(self, config, tokenizer, text_processor):
269
+ self.config = self.get_default_config(config)
270
+ assert self.config.path != ''
271
+ self._tokenizer = tokenizer
272
+ self._text_processor = text_processor
273
+ self._index = self.config.example_index_at_start
274
+ self._file_loc = self.config.start_seek_loc
275
+ self._total_tokens = self.config.tokens_count_at_start
276
+
277
+ def parse_json(self, line):
278
+ if not line or line == '\n':
279
+ return None
280
+ try:
281
+ data = json.loads(line)
282
+ except json.decoder.JSONDecodeError:
283
+ print(f'Error parsing json line:\n{line}')
284
+ return None
285
+ return data
286
+
287
+ def json_iterator(self):
288
+ with mlxu.open_file(self.config.path, 'r') as fin:
289
+ fin.seek(self._file_loc)
290
+ while True:
291
+ line = fin.readline()
292
+ self._file_loc = fin.tell()
293
+ if not line: # Reached EOF
294
+ self._index = 0
295
+ fin.seek(0)
296
+ continue
297
+
298
+ data = self.parse_json(line)
299
+ if data is not None:
300
+ # JSON parsing succeeded
301
+ yield data, self._file_loc, self._index
302
+ self._index += 1
303
+
304
+ def batched(self, iterator, batch_size):
305
+ batch = []
306
+ for example in iterator:
307
+ batch.append(example)
308
+ if len(batch) == batch_size:
309
+ yield batch
310
+ batch = []
311
+ if len(batch) > 0:
312
+ yield batch
313
+
314
+ def parallel_example_iterator(self):
315
+ if self.config.tokenizer_processes == 1:
316
+ for example, loc, index in self.json_iterator():
317
+ yield self.text_processor((example, loc, index), has_aux=True)
318
+ else:
319
+ process_pool = Pool(self.config.tokenizer_processes)
320
+ batched_iterator = self.batched(
321
+ self.json_iterator(), self.config.tokenizer_parallel_batch_size
322
+ )
323
+ with process_pool as pool:
324
+ map_fn = partial(self.text_processor, has_aux=True)
325
+ next_batch = pool.map_async(
326
+ map_fn, next(batched_iterator),
327
+ chunksize=self.config.tokenizer_parallel_chunk_size
328
+ )
329
+ while True:
330
+ current_batch = next_batch
331
+ next_batch = pool.map_async(
332
+ map_fn, next(batched_iterator),
333
+ chunksize=self.config.tokenizer_parallel_chunk_size
334
+ )
335
+ for example in current_batch.get():
336
+ yield example
337
+
338
+ def __iter__(self):
339
+ chunk_size = self.config.batch_size * self.config.seq_length
340
+ token_buffer = []
341
+ loss_mask_buffer = []
342
+ last_time = 0.0
343
+ step_times = []
344
+ start_time = time.time()
345
+ start_tokens = self._total_tokens
346
+ for tokens, loss_masks, loc, index in self.parallel_example_iterator():
347
+ token_buffer.extend(tokens)
348
+ loss_mask_buffer.extend(loss_masks)
349
+ while len(token_buffer) > chunk_size + 1:
350
+ self._total_tokens += chunk_size
351
+ step_times.append(time.time() - last_time)
352
+ last_time = time.time()
353
+ if len(step_times) > self.config.throughput_average_window_size:
354
+ step_times = step_times[-self.config.throughput_average_window_size:]
355
+ average_throughput = chunk_size / np.mean(step_times)
356
+ accumulated_throughput = (
357
+ (self._total_tokens - start_tokens) / (time.time() - start_time)
358
+ )
359
+ metrics = {
360
+ 'dataset_file_loc': loc,
361
+ 'dataset_example_index': index,
362
+ 'dataset_total_tokens': self._total_tokens,
363
+ 'dataset_accumulated_tps': accumulated_throughput,
364
+ 'dataset_average_tps': average_throughput,
365
+ }
366
+ batch = {
367
+ 'input_tokens': np.array(token_buffer[:chunk_size], dtype=np.int32).reshape(
368
+ self.config.batch_size, -1
369
+ ),
370
+ 'target_tokens': np.array(token_buffer[1:chunk_size + 1], dtype=np.int32).reshape(
371
+ self.config.batch_size, -1
372
+ ),
373
+ 'loss_masks': np.array(loss_mask_buffer[1:chunk_size + 1], dtype=np.float32).reshape(
374
+ self.config.batch_size, -1
375
+ ),
376
+ }
377
+ if self.config.always_start_with_bos:
378
+ batch['input_tokens'][:, 0] = self.tokenizer.bos_token_id
379
+ yield batch, metrics
380
+ token_buffer = token_buffer[chunk_size:]
381
+ loss_mask_buffer = loss_mask_buffer[chunk_size:]
382
+
383
+ def get_state_dict(self):
384
+ return dict(
385
+ config=self.config,
386
+ index=self._index,
387
+ file_loc=self._file_loc,
388
+ total_tokens=self._total_tokens,
389
+ )
390
+
391
+ def load_state_dict(self, state_dict):
392
+ if 'config' in state_dict:
393
+ self.config.update(ConfigDict(state_dict['config']))
394
+ self._index = state_dict.get('index', self.config.example_index_at_start)
395
+ self._file_loc = state_dict.get('file_loc', self.config.start_seek_loc)
396
+ self._total_tokens = state_dict.get('total_tokens', self.config.tokens_count_at_start)
397
+
398
+ @property
399
+ def seq_length(self):
400
+ return self.config.seq_length
401
+
402
+ @property
403
+ def tokenizer(self):
404
+ return self._tokenizer
405
+
406
+ @property
407
+ def text_processor(self):
408
+ return self._text_processor
409
+
410
+ @property
411
+ def vocab_size(self):
412
+ return len(self.tokenizer)
EasyLM/jax_utils.py ADDED
@@ -0,0 +1,401 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import math
3
+ from typing import Any, Mapping, Text, Tuple, Union, NamedTuple
4
+ from functools import partial
5
+ import re
6
+ import dataclasses
7
+ import random
8
+ from ml_collections import ConfigDict
9
+ from ml_collections.config_dict.config_dict import placeholder
10
+
11
+ import flax
12
+ import jax
13
+ import jax.numpy as jnp
14
+ from jax.sharding import PartitionSpec as PS
15
+ from jax.sharding import Mesh
16
+ from jax.experimental import mesh_utils
17
+ from jax.experimental.pjit import with_sharding_constraint as _with_sharding_constraint
18
+ from jax.experimental.pjit import pjit
19
+ from jax.interpreters import pxla
20
+ import numpy as np
21
+ from transformers import FlaxLogitsWarper
22
+
23
+
24
+ class JaxRNG(object):
25
+ """ A convenient stateful Jax RNG wrapper. Can be used to wrap RNG inside
26
+ pure function.
27
+ """
28
+
29
+ @classmethod
30
+ def from_seed(cls, seed):
31
+ return cls(jax.random.PRNGKey(seed))
32
+
33
+ def __init__(self, rng):
34
+ self.rng = rng
35
+
36
+ def __call__(self, keys=None):
37
+ if keys is None:
38
+ self.rng, split_rng = jax.random.split(self.rng)
39
+ return split_rng
40
+ elif isinstance(keys, int):
41
+ split_rngs = jax.random.split(self.rng, num=keys + 1)
42
+ self.rng = split_rngs[0]
43
+ return tuple(split_rngs[1:])
44
+ else:
45
+ split_rngs = jax.random.split(self.rng, num=len(keys) + 1)
46
+ self.rng = split_rngs[0]
47
+ return {key: val for key, val in zip(keys, split_rngs[1:])}
48
+
49
+
50
+ class JaxDistributedConfig(object):
51
+ """ Utility class for initializing JAX distributed. """
52
+
53
+ @staticmethod
54
+ def get_default_config(updates=None):
55
+ config = ConfigDict()
56
+ config.initialize_jax_distributed = False
57
+ config.coordinator_address = placeholder(str)
58
+ config.num_processes = placeholder(int)
59
+ config.process_id = placeholder(int)
60
+ config.local_device_ids = placeholder(str)
61
+
62
+ if updates is not None:
63
+ config.update(ConfigDict(updates).copy_and_resolve_references())
64
+ return config
65
+
66
+ @classmethod
67
+ def initialize(cls, config):
68
+ config = cls.get_default_config(config)
69
+ if config.initialize_jax_distributed:
70
+ if config.local_device_ids is not None:
71
+ local_device_ids = [int(x) for x in config.local_device_ids.split(',')]
72
+ else:
73
+ local_device_ids = None
74
+
75
+ jax.distributed.initialize(
76
+ coordinator_address=config.coordinator_address,
77
+ num_processes=config.num_processes,
78
+ process_id=config.process_id,
79
+ local_device_ids=local_device_ids,
80
+ )
81
+
82
+
83
+ class FlaxTemperatureLogitsWarper(FlaxLogitsWarper):
84
+ """ JIT traceable version of FlaxLogitsWarper that performs temperature scaling."""
85
+ def __init__(self, temperature):
86
+ self.temperature = temperature
87
+
88
+ def __call__(self, input_ids, scores, cur_len):
89
+ return scores / jnp.clip(self.temperature, a_min=1e-8)
90
+
91
+
92
+ def make_shard_and_gather_fns(partition_specs, dtype_specs=None):
93
+ """ Create pytree of sharding and gathering functions from pytree of
94
+ partition specs.
95
+ """
96
+ float_dtypes = (jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64)
97
+
98
+ def make_to_dtype_fn(dtype_spec):
99
+ def to_dtype(tensor):
100
+ if dtype_specs in float_dtypes and getattr(tensor, 'dtype', None) in float_dtypes:
101
+ # Convert all float tensors to the same dtype
102
+ return tensor.astype(dtype_specs)
103
+ elif hasattr(dtype_spec, 'dtype') and hasattr(tensor, 'dtype'):
104
+ return tensor.astype(dtype_spec.dtype)
105
+ return tensor
106
+ return to_dtype
107
+
108
+ def make_shard_fn(partition_spec, dtype_spec=None):
109
+ jax_shard_function = pjit(
110
+ make_to_dtype_fn(dtype_spec),
111
+ in_shardings=None,
112
+ out_shardings=partition_spec
113
+ )
114
+ def shard_fn(tensor):
115
+ return jax_shard_function(tensor).block_until_ready()
116
+ return shard_fn
117
+
118
+ def make_gather_fn(partition_spec, dtype_spec=None):
119
+ jax_gather_fn = pjit(
120
+ make_to_dtype_fn(dtype_spec),
121
+ in_shardings=partition_spec,
122
+ out_shardings=None
123
+ )
124
+ def gather_fn(tensor):
125
+ return jax.device_get(jax_gather_fn(tensor))
126
+ return gather_fn
127
+
128
+ if dtype_specs is None or dtype_specs in float_dtypes:
129
+ shard_fns = jax.tree_util.tree_map(make_shard_fn, partition_specs)
130
+ gather_fns = jax.tree_util.tree_map(make_gather_fn, partition_specs)
131
+ else:
132
+ shard_fns = jax.tree_util.tree_map(
133
+ make_shard_fn, partition_specs, dtype_specs
134
+ )
135
+ gather_fns = jax.tree_util.tree_map(
136
+ make_gather_fn, partition_specs, dtype_specs
137
+ )
138
+ return shard_fns, gather_fns
139
+
140
+
141
+ def set_random_seed(seed):
142
+ np.random.seed(seed)
143
+ random.seed(seed)
144
+ init_rng(seed)
145
+
146
+
147
+ def get_jax_mesh(axis_dims, names):
148
+ if axis_dims.startswith('!'):
149
+ # Allow splitting a physical mesh axis if needed
150
+ mesh_axis_splitting = True
151
+ axis_dims = axis_dims[1:]
152
+ else:
153
+ mesh_axis_splitting = False
154
+
155
+ if ':' in axis_dims:
156
+ dims = []
157
+ dim_names = []
158
+ for axis in axis_dims.split(','):
159
+ name, dim = axis.split(':')
160
+ assert name in names
161
+ dims.append(int(dim))
162
+ dim_names.append(name)
163
+ assert(set(dim_names) == set(names))
164
+ else:
165
+ dims = [int(x) for x in axis_dims.split(',')]
166
+ dim_names = names
167
+ assert len(dims) == len(names)
168
+ mesh_shape = np.arange(jax.device_count()).reshape(dims).shape
169
+ if mesh_axis_splitting:
170
+ physical_mesh = np.array(jax.devices()).reshape(mesh_shape)
171
+ else:
172
+ physical_mesh = mesh_utils.create_device_mesh(mesh_shape)
173
+ return Mesh(physical_mesh, dim_names)
174
+
175
+
176
+ def names_in_current_mesh(*names):
177
+ """ Check if current mesh axes contain these names. """
178
+ mesh_axis_names = pxla.thread_resources.env.physical_mesh.axis_names
179
+ return set(names) <= set(mesh_axis_names)
180
+
181
+
182
+ def get_names_from_parition_spec(partition_specs):
183
+ """ Return axis names from partition specs. """
184
+ names = set()
185
+ if isinstance(partition_specs, dict):
186
+ partition_specs = partition_specs.values()
187
+ for item in partition_specs:
188
+ if item is None:
189
+ continue
190
+ elif isinstance(item, str):
191
+ names.add(item)
192
+ else:
193
+ names.update(get_names_from_parition_spec(item))
194
+
195
+ return list(names)
196
+
197
+
198
+ def with_sharding_constraint(x, partition_specs):
199
+ """ A smarter version of with_sharding_constraint that only applies the
200
+ constraint if the current mesh contains the axes in the partition specs.
201
+ """
202
+ axis_names = get_names_from_parition_spec(partition_specs)
203
+ if names_in_current_mesh(*axis_names):
204
+ x = _with_sharding_constraint(x, partition_specs)
205
+ return x
206
+
207
+
208
+ def wrap_function_with_rng(rng):
209
+ """ To be used as decorator, automatically bookkeep a RNG for the wrapped function. """
210
+ def wrap_function(function):
211
+ def wrapped(*args, **kwargs):
212
+ nonlocal rng
213
+ rng, split_rng = jax.random.split(rng)
214
+ return function(split_rng, *args, **kwargs)
215
+ return wrapped
216
+ return wrap_function
217
+
218
+
219
+ def init_rng(seed):
220
+ global jax_utils_rng
221
+ jax_utils_rng = JaxRNG.from_seed(seed)
222
+
223
+
224
+ def next_rng(*args, **kwargs):
225
+ global jax_utils_rng
226
+ return jax_utils_rng(*args, **kwargs)
227
+
228
+
229
+ def get_metrics(metrics, unreplicate=False, stack=False):
230
+ if unreplicate:
231
+ metrics = flax.jax_utils.unreplicate(metrics)
232
+ metrics = jax.device_get(metrics)
233
+ if stack:
234
+ return jax.tree_map(lambda *args: np.stack(args), *metrics)
235
+ else:
236
+ return {key: float(val) for key, val in metrics.items()}
237
+
238
+
239
+ def mse_loss(val, target, valid=None):
240
+ if valid is None:
241
+ valid = jnp.ones((*target.shape[:2], 1))
242
+ valid = valid.astype(jnp.float32)
243
+ loss = jnp.mean(
244
+ jnp.where(
245
+ valid > 0.0,
246
+ jnp.square(val - target),
247
+ 0.0
248
+ )
249
+ )
250
+ return loss
251
+
252
+
253
+ def cross_entropy_loss_and_accuracy(logits, tokens, valid=None):
254
+ if valid is None:
255
+ valid = jnp.ones(tokens.shape[:2])
256
+ valid = valid.astype(jnp.float32)
257
+ valid_text_length = jnp.maximum(jnp.sum(valid, axis=-1), 1e-10)
258
+ logits = logits.astype(jnp.float32) # for numerical stability
259
+ token_log_prob = jnp.squeeze(
260
+ jnp.take_along_axis(
261
+ jax.nn.log_softmax(logits, axis=-1),
262
+ jnp.expand_dims(tokens, -1),
263
+ axis=-1,
264
+ ),
265
+ -1,
266
+ )
267
+ token_log_prob = jnp.where(valid > 0.0, token_log_prob, jnp.array(0.0))
268
+ loss = -jnp.mean(jnp.sum(token_log_prob, axis=-1) / valid_text_length)
269
+ correct = jnp.where(
270
+ valid > 0.0,
271
+ jnp.argmax(logits, axis=-1) == tokens,
272
+ jnp.array(False)
273
+ )
274
+ accuracy = jnp.mean(jnp.sum(correct, axis=-1) / valid_text_length)
275
+ return loss, accuracy
276
+
277
+
278
+ def global_norm(tree):
279
+ """ Return the global L2 norm of a pytree. """
280
+ squared = jax.tree_util.tree_map(lambda x: jnp.sum(jnp.square(x)), tree)
281
+ flattened, _ = jax.flatten_util.ravel_pytree(squared)
282
+ return jnp.sqrt(jnp.sum(flattened))
283
+
284
+
285
+ def average_metrics(metrics):
286
+ return jax.tree_map(
287
+ lambda *args: jnp.mean(jnp.stack(args)),
288
+ *metrics
289
+ )
290
+
291
+
292
+ def get_float_dtype_by_name(dtype):
293
+ return {
294
+ 'bf16': jnp.bfloat16,
295
+ 'bfloat16': jnp.bfloat16,
296
+ 'fp16': jnp.float16,
297
+ 'float16': jnp.float16,
298
+ 'fp32': jnp.float32,
299
+ 'float32': jnp.float32,
300
+ 'fp64': jnp.float64,
301
+ 'float64': jnp.float64,
302
+ }[dtype]
303
+
304
+
305
+ def float_tensor_to_dtype(tensor, dtype):
306
+ if dtype is None or dtype == '':
307
+ return tensor
308
+ if isinstance(dtype, str):
309
+ dtype = get_float_dtype_by_name(dtype)
310
+ float_dtypes = (jnp.bfloat16, jnp.float16, jnp.float32, jnp.float64)
311
+ if getattr(tensor, 'dtype', None) in float_dtypes:
312
+ tensor = tensor.astype(dtype)
313
+ return tensor
314
+
315
+
316
+ def float_to_dtype(tree, dtype):
317
+ return jax.tree_util.tree_map(
318
+ partial(float_tensor_to_dtype, dtype=dtype), tree
319
+ )
320
+
321
+
322
+ def get_gradient_checkpoint_policy(name):
323
+ return {
324
+ 'everything_saveable': jax.checkpoint_policies.everything_saveable,
325
+ 'nothing_saveable': jax.checkpoint_policies.nothing_saveable,
326
+ 'checkpoint_dots': jax.checkpoint_policies.checkpoint_dots,
327
+ 'checkpoint_dots_with_no_batch_dims': jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims,
328
+ }[name]
329
+
330
+
331
+ def tree_path_to_string(path, sep=None):
332
+ keys = []
333
+ for key in path:
334
+ if isinstance(key, jax.tree_util.SequenceKey):
335
+ keys.append(str(key.idx))
336
+ elif isinstance(key, jax.tree_util.DictKey):
337
+ keys.append(str(key.key))
338
+ elif isinstance(key, jax.tree_util.GetAttrKey):
339
+ keys.append(str(key.name))
340
+ elif isinstance(key, jax.tree_util.FlattenedIndexKey):
341
+ keys.append(str(key.key))
342
+ else:
343
+ keys.append(str(key))
344
+ if sep is None:
345
+ return tuple(keys)
346
+ return sep.join(keys)
347
+
348
+
349
+ def flatten_tree(xs, is_leaf=None, sep=None):
350
+ flattened, _ = jax.tree_util.tree_flatten_with_path(xs, is_leaf=is_leaf)
351
+ output = {}
352
+ for key, val in flattened:
353
+ output[tree_path_to_string(key, sep=sep)] = val
354
+ return output
355
+
356
+
357
+ def named_tree_map(f, tree, *rest, is_leaf=None, sep=None):
358
+ """ An extended version of jax.tree_util.tree_map, where the mapped function
359
+ f takes both the name (path) and the tree leaf as input.
360
+ """
361
+ return jax.tree_util.tree_map_with_path(
362
+ lambda path, x, *r: f(tree_path_to_string(path, sep=sep), x, *r),
363
+ tree, *rest,
364
+ is_leaf=is_leaf
365
+ )
366
+
367
+
368
+ def match_partition_rules(rules, params):
369
+ """ Returns a pytree of PartitionSpec according to rules. Supports handling
370
+ Flax TrainState and Optax optimizer state.
371
+ """
372
+ def get_partition_spec(name, leaf):
373
+ if len(leaf.shape) == 0 or np.prod(leaf.shape) == 1:
374
+ """ Don't partition scalar values. """
375
+ return PS()
376
+ for rule, ps in rules:
377
+ if re.search(rule, name) is not None:
378
+ return ps
379
+ raise ValueError(f'Partition rule not found for param: {name}')
380
+ return named_tree_map(get_partition_spec, params, sep='/')
381
+
382
+
383
+ def get_weight_decay_mask(exclusions):
384
+ """ Return a weight decay mask function that computes the pytree masks
385
+ according to the given exclusion rules.
386
+ """
387
+ def decay(name, _):
388
+ for rule in exclusions:
389
+ if re.search(rule, name) is not None:
390
+ return False
391
+ return True
392
+
393
+ def weight_decay_mask(params):
394
+ return named_tree_map(decay, params, sep='/')
395
+
396
+ return weight_decay_mask
397
+
398
+
399
+ def tree_apply(fns, tree):
400
+ """ Apply a pytree of functions to the pytree. """
401
+ return jax.tree_util.tree_map(lambda fn, x: fn(x), fns, tree)
EasyLM/models/__init__.py ADDED
File without changes
EasyLM/models/gptj/__init__.py ADDED
File without changes
EasyLM/models/gptj/gptj_model.py ADDED
@@ -0,0 +1,1054 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The EleutherAI and The HuggingFace Inc. team.
3
+ # Modifications copyright 2022 Xinyang Geng
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+
18
+ from functools import partial
19
+ from typing import Optional, Tuple
20
+ import json
21
+
22
+ import numpy as np
23
+
24
+ import flax.linen as nn
25
+ import jax
26
+ import jax.numpy as jnp
27
+ from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
28
+ from flax.linen import combine_masks, make_causal_mask
29
+ from flax.linen.attention import dot_product_attention_weights
30
+ from flax.traverse_util import flatten_dict, unflatten_dict
31
+ from jax import lax
32
+ from flax.linen import partitioning as nn_partitioning
33
+
34
+ from transformers.configuration_utils import PretrainedConfig
35
+ from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
36
+ from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
37
+ from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
38
+ from transformers.generation.flax_logits_process import FlaxLogitsProcessorList
39
+ from transformers import AutoTokenizer
40
+ from jax.sharding import PartitionSpec
41
+
42
+ from ml_collections import ConfigDict
43
+ from ml_collections.config_dict import config_dict
44
+ from mlxu import function_args_to_config, load_pickle, open_file
45
+
46
+ from EasyLM.jax_utils import (
47
+ with_sharding_constraint, get_jax_mesh, get_gradient_checkpoint_policy
48
+ )
49
+
50
+
51
+ """
52
+ The follow code is taken from
53
+ transformers/src/transformers/models/gptj/configuration_gptj.py
54
+ and modified to work with EasyLM.
55
+ """
56
+
57
+
58
+ GPTJ_STANDARD_CONFIGS = {
59
+ '6b': {
60
+ "vocab_size": 50400,
61
+ "n_positions": 2048,
62
+ "n_embd": 4096,
63
+ "n_layer": 28,
64
+ "n_head": 16,
65
+ "rotary_dim": 64,
66
+ "n_inner": None,
67
+ "activation_function": "gelu_new",
68
+ "layer_norm_epsilon": 1e-5,
69
+ "initializer_range": 0.02,
70
+ "scale_attn_weights": True,
71
+ "use_cache": True,
72
+ "bos_token_id": 50256,
73
+ "eos_token_id": 50256,
74
+ "tie_word_embeddings": False,
75
+ "n_real_tokens": 50257,
76
+ }
77
+ }
78
+
79
+
80
+ class GPTJConfig(PretrainedConfig):
81
+ r"""
82
+ This is the configuration class to store the configuration of a [`GPTJModel`]. It is used to instantiate a GPT-J
83
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
84
+ defaults will yield a similar configuration to that of the GPT-J
85
+ [EleutherAI/gpt-j-6B](https://huggingface.co/EleutherAI/gpt-j-6B) architecture. Configuration objects inherit from
86
+ [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`]
87
+ for more information.
88
+ Args:
89
+ vocab_size (`int`, *optional*, defaults to 50400):
90
+ Vocabulary size of the GPT-J model. Defines the number of different tokens that can be represented by the
91
+ `inputs_ids` passed when calling [`GPTJModel`].
92
+ n_positions (`int`, *optional*, defaults to 2048):
93
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
94
+ just in case (e.g., 512 or 1024 or 2048).
95
+ n_embd (`int`, *optional*, defaults to 4096):
96
+ Dimensionality of the embeddings and hidden states.
97
+ n_layer (`int`, *optional*, defaults to 28):
98
+ Number of hidden layers in the Transformer encoder.
99
+ n_head (`int`, *optional*, defaults to 16):
100
+ Number of attention heads for each attention layer in the Transformer encoder.
101
+ rotary_dim (`int`, *optional*, defaults to 64):
102
+ Number of dimensions in the embedding that Rotary Position Embedding is applied to.
103
+ n_inner (`int`, *optional*, defaults to 0):
104
+ Dimensionality of the inner feed-forward layers. 0 will set it to 4 times n_embd
105
+ activation_function (`str`, *optional*, defaults to `"gelu_new"`):
106
+ Activation function, to be selected in the list `["relu", "silu", "gelu", "tanh", "gelu_new"]`.
107
+ resid_pdrop (`float`, *optional*, defaults to 0.1):
108
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
109
+ embd_pdrop (`int`, *optional*, defaults to 0.1):
110
+ The dropout ratio for the embeddings.
111
+ attn_pdrop (`float`, *optional*, defaults to 0.1):
112
+ The dropout ratio for the attention.
113
+ layer_norm_epsilon (`float`, *optional*, defaults to 1e-5):
114
+ The epsilon to use in the layer normalization layers.
115
+ initializer_range (`float`, *optional*, defaults to 0.02):
116
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
117
+ scale_attn_weights (`bool`, *optional*, defaults to `True`):
118
+ Scale attention weights by dividing by sqrt(hidden_size).
119
+ use_cache (`bool`, *optional*, defaults to `True`):
120
+ Whether or not the model should return the last key/values attentions (not used by all models).
121
+ Example:
122
+ ```python
123
+ >>> from transformers import GPTJModel, GPTJConfig
124
+ >>> # Initializing a GPT-J 6B configuration
125
+ >>> configuration = GPTJConfig()
126
+ >>> # Initializing a model from the configuration
127
+ >>> model = GPTJModel(configuration)
128
+ >>> # Accessing the model configuration
129
+ >>> configuration = model.config
130
+ ```"""
131
+ model_type = "gptj"
132
+ attribute_map = {
133
+ "max_position_embeddings": "n_positions",
134
+ "hidden_size": "n_embd",
135
+ "num_attention_heads": "n_head",
136
+ "num_hidden_layers": "n_layer",
137
+ }
138
+
139
+ def __init__(
140
+ self,
141
+ vocab_size=50400,
142
+ n_positions=2048,
143
+ n_embd=4096,
144
+ n_layer=28,
145
+ n_head=16,
146
+ rotary_dim=64,
147
+ n_inner=None,
148
+ activation_function="gelu_new",
149
+ resid_pdrop=0.0,
150
+ embd_pdrop=0.0,
151
+ attn_pdrop=0.0,
152
+ layer_norm_epsilon=1e-5,
153
+ initializer_range=0.02,
154
+ scale_attn_weights=True,
155
+ use_cache=True,
156
+ bos_token_id=50256,
157
+ eos_token_id=50256,
158
+ tie_word_embeddings=False,
159
+ gradient_checkpointing=True,
160
+ gradient_checkpointing_policy='nothing_saveable',
161
+ n_real_tokens=50257,
162
+ fcm_min_ratio=0.0,
163
+ fcm_max_ratio=0.0,
164
+ **kwargs
165
+ ):
166
+ self.vocab_size = vocab_size
167
+ self.n_positions = n_positions
168
+ self.n_embd = n_embd
169
+ self.n_layer = n_layer
170
+ self.n_head = n_head
171
+ self.n_inner = n_inner
172
+ self.rotary_dim = rotary_dim
173
+ self.activation_function = activation_function
174
+ self.resid_pdrop = resid_pdrop
175
+ self.embd_pdrop = embd_pdrop
176
+ self.attn_pdrop = attn_pdrop
177
+ self.layer_norm_epsilon = layer_norm_epsilon
178
+ self.initializer_range = initializer_range
179
+ self.scale_attn_weights = scale_attn_weights
180
+ self.use_cache = use_cache
181
+ self.gradient_checkpointing = gradient_checkpointing
182
+ self.gradient_checkpointing_policy = gradient_checkpointing_policy
183
+ self.n_real_tokens = n_real_tokens
184
+ self.fcm_min_ratio = fcm_min_ratio
185
+ self.fcm_max_ratio = fcm_max_ratio
186
+ if self.n_real_tokens is None:
187
+ self.n_real_tokens = self.vocab_size
188
+
189
+ self.bos_token_id = bos_token_id
190
+ self.eos_token_id = eos_token_id
191
+
192
+ super().__init__(
193
+ bos_token_id=bos_token_id, eos_token_id=eos_token_id, tie_word_embeddings=tie_word_embeddings, **kwargs
194
+ )
195
+
196
+ @classmethod
197
+ def get_default_config(cls, updates=None):
198
+ none_arg_types = dict(
199
+ n_inner=int,
200
+ rotary_dim=int,
201
+ )
202
+ config = function_args_to_config(cls.__init__, none_arg_types=none_arg_types)
203
+
204
+ if updates is not None:
205
+ config.update(ConfigDict(updates).copy_and_resolve_references())
206
+
207
+ return config
208
+
209
+ @staticmethod
210
+ def get_jax_mesh(axis_dims):
211
+ return get_jax_mesh(axis_dims, ('dp', 'fsdp', 'mp'))
212
+
213
+ @staticmethod
214
+ def get_partition_rules():
215
+ """ Parition rules for GPTJ. Note that these rules are orderd, so that
216
+ the beginning rules match first. It is important to use
217
+ PartitionSpec() instead of None here because JAX does not treat
218
+ None as a pytree leaf.
219
+ """
220
+ return (
221
+ ('transformer/wte/embedding', PartitionSpec('mp', 'fsdp')),
222
+ ('attn/(k_proj|q_proj|v_proj)/kernel', PartitionSpec('fsdp', 'mp')),
223
+ ('attn/out_proj/kernel', PartitionSpec('mp', 'fsdp')),
224
+ ('mlp/fc_in/kernel', PartitionSpec('fsdp', 'mp')),
225
+ ('mlp/fc_in/bias', PartitionSpec('mp')),
226
+ ('mlp/fc_out/kernel', PartitionSpec('mp', 'fsdp')),
227
+ ('mlp/fc_out/bias', PartitionSpec()),
228
+ ('ln_[0-9]+/bias', PartitionSpec()),
229
+ ('[0-9]+/ln_[0-9]+/scale', PartitionSpec()),
230
+ ('ln_f/bias', PartitionSpec()),
231
+ ('ln_f/scale', PartitionSpec()),
232
+ ('lm_head/kernel', PartitionSpec('fsdp', 'mp')),
233
+ ('lm_head/bias', PartitionSpec('mp')),
234
+ ('.*', PartitionSpec()),
235
+ )
236
+
237
+ @staticmethod
238
+ def get_weight_decay_exclusions():
239
+ return (
240
+ 'ln_[0-9]+/bias', 'ln_[0-9]+/scale', 'ln_f/bias', 'ln_f/scale',
241
+ 'bias'
242
+ )
243
+
244
+ @staticmethod
245
+ def rng_keys():
246
+ return ('params', 'dropout', 'fcm')
247
+
248
+ @staticmethod
249
+ def get_tokenizer_config(updates=None):
250
+ config = ConfigDict()
251
+ config.name = 'EleutherAI/gpt-j-6B'
252
+ config.bos_token = '<|endoftext|>'
253
+ config.eos_token = '<|endoftext|>'
254
+ config.pad_token = '<|extratoken_40|>'
255
+ config.cls_token = '<|extratoken_41|>'
256
+ config.mask_token = '<|extratoken_42|>'
257
+
258
+ if updates is not None:
259
+ config.update(ConfigDict(updates).copy_and_resolve_references())
260
+
261
+ return config
262
+
263
+ @classmethod
264
+ def get_tokenizer(cls, config, padding_side='left', truncation_side='right'):
265
+ config = cls.get_tokenizer_config(config)
266
+ return AutoTokenizer.from_pretrained(
267
+ config.name,
268
+ bos_token=config.bos_token,
269
+ eos_token=config.eos_token,
270
+ pad_token=config.pad_token,
271
+ cls_token=config.cls_token,
272
+ mask_token=config.mask_token,
273
+ padding_side=padding_side,
274
+ truncation_side=truncation_side,
275
+ )
276
+
277
+ @staticmethod
278
+ def load_pretrained(name, dtype=jnp.float32):
279
+ with jax.default_device(jax.devices("cpu")[0]):
280
+ params = FlaxGPTJForCausalLM.from_pretrained(
281
+ name, _do_init=False, dtype=dtype
282
+ )[1]
283
+ params = freeze({'params': params})
284
+ return jax.device_get(params)
285
+
286
+ @classmethod
287
+ def load_config(cls, path):
288
+ if path in GPTJ_STANDARD_CONFIGS:
289
+ return cls.from_dict(GPTJ_STANDARD_CONFIGS[path])
290
+ load_type, load_path = path.split('::', 1)
291
+ if load_type == 'pickle':
292
+ return cls.from_dict(load_pickle(load_path)['gptj_config'])
293
+ elif load_type == 'json':
294
+ with open_file(load_path, 'r') as fin:
295
+ raw_config = fin.read()
296
+ return cls.from_dict(json.loads(raw_config))
297
+ elif load_type == 'huggingface':
298
+ return cls.from_pretrained(load_path)
299
+ else:
300
+ raise ValueError(f'Unsupported load config type: {load_type}')
301
+
302
+
303
+ """
304
+ The follow code is taken from
305
+ transformers/src/transformers/models/gptj/modeling_flax_gptj.py
306
+ and modified to work with EasyLM.
307
+ """
308
+
309
+ logger = logging.get_logger(__name__)
310
+
311
+ _CHECKPOINT_FOR_DOC = "gptj"
312
+ _CONFIG_FOR_DOC = "GPTJConfig"
313
+
314
+ remat = nn_partitioning.remat
315
+
316
+
317
+ GPTJ_START_DOCSTRING = r"""
318
+ This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
319
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
320
+ etc.)
321
+ This model is also a Flax Linen
322
+ [flax.nn.Module](https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html) subclass. Use it as a
323
+ regular Flax Module and refer to the Flax documentation for all matter related to general usage and behavior.
324
+ Finally, this model supports inherent JAX features such as:
325
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
326
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
327
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
328
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
329
+ Parameters:
330
+ config ([`GPTJConfig`]): Model configuration class with all the parameters of the model.
331
+ Initializing with a config file does not load the weights associated with the model, only the
332
+ configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
333
+ dtype (`jax.numpy.dtype`, *optional*, defaults to `jax.numpy.float32`):
334
+ The data type of the computation. Can be one of `jax.numpy.float32`, `jax.numpy.float16` (on GPUs) and
335
+ `jax.numpy.bfloat16` (on TPUs).
336
+ This can be used to enable mixed-precision training or half-precision inference on GPUs or TPUs. If
337
+ specified all the computation will be performed with the given `dtype`.
338
+ **Note that this only specifies the dtype of the computation and does not influence the dtype of model
339
+ parameters.**
340
+ If you wish to change the dtype of the model parameters, see [`~FlaxPreTrainedModel.to_fp16`] and
341
+ [`~FlaxPreTrainedModel.to_bf16`].
342
+ """
343
+
344
+ GPTJ_INPUTS_DOCSTRING = r"""
345
+ Args:
346
+ input_ids (`numpy.ndarray` of shape `(batch_size, input_ids_length)`):
347
+ `input_ids_length` = `sequence_length`. Indices of input sequence tokens in the vocabulary.
348
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
349
+ [`PreTrainedTokenizer.__call__`] for details.
350
+ [What are input IDs?](../glossary#input-ids)
351
+ attention_mask (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
352
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
353
+ - 1 for tokens that are **not masked**,
354
+ - 0 for tokens that are **masked**.
355
+ [What are attention masks?](../glossary#attention-mask)
356
+ position_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`, *optional*):
357
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
358
+ config.max_position_embeddings - 1]`.
359
+ past_key_values (`Dict[str, np.ndarray]`, *optional*, returned by `init_cache` or when passing previous `past_key_values`):
360
+ Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast
361
+ auto-regressive decoding. Pre-computed key and value hidden-states are of shape *[batch_size, max_length]*.
362
+ output_attentions (`bool`, *optional*):
363
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
364
+ tensors for more detail.
365
+ output_hidden_states (`bool`, *optional*):
366
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
367
+ more detail.
368
+ return_dict (`bool`, *optional*):
369
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
370
+ """
371
+
372
+
373
+
374
+ def create_sinusoidal_positions(num_pos, dim):
375
+ inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim))
376
+ sinusoid_inp = np.einsum("i , j -> i j", np.arange(num_pos), inv_freq).astype("float32")
377
+ sin, cos = np.sin(sinusoid_inp), np.cos(sinusoid_inp)
378
+
379
+ sentinel = dim // 2 + dim % 2
380
+ out = np.zeros((num_pos, dim))
381
+ out[:, 0:sentinel] = sin
382
+ out[:, sentinel:] = cos
383
+
384
+ return jnp.array(out)
385
+
386
+
387
+ def rotate_every_two(tensor):
388
+ rotate_half_tensor = jnp.stack((-tensor[:, :, :, 1::2], tensor[:, :, :, ::2]), axis=-1)
389
+ rotate_half_tensor = rotate_half_tensor.reshape(rotate_half_tensor.shape[:-2] + (-1,))
390
+ return rotate_half_tensor
391
+
392
+
393
+ def apply_rotary_pos_emb(tensor, sincos):
394
+ sin_pos, cos_pos = sincos
395
+ sin_pos = sin_pos[:, :, None, :].repeat(2, 3)
396
+ cos_pos = cos_pos[:, :, None, :].repeat(2, 3)
397
+ return (tensor * cos_pos) + (rotate_every_two(tensor) * sin_pos)
398
+
399
+
400
+ class FlaxGPTJAttention(nn.Module):
401
+ config: GPTJConfig
402
+ dtype: jnp.dtype = jnp.float32
403
+ causal: bool = True
404
+ is_cross_attention: bool = False
405
+
406
+ def setup(self):
407
+ config = self.config
408
+ self.embed_dim = config.hidden_size
409
+ self.num_heads = config.num_attention_heads
410
+ self.head_dim = self.embed_dim // self.num_heads
411
+
412
+ self.rotary_dim = config.rotary_dim
413
+
414
+ dense = partial(
415
+ nn.Dense,
416
+ self.embed_dim,
417
+ use_bias=False,
418
+ dtype=self.dtype,
419
+ kernel_init=jax.nn.initializers.variance_scaling(
420
+ scale=1.0, mode='fan_in',
421
+ distribution='normal',
422
+ )
423
+ )
424
+
425
+ self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
426
+ self.out_proj = dense()
427
+
428
+ self.resid_dropout = nn.Dropout(rate=config.resid_pdrop)
429
+
430
+ self.causal_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool")
431
+
432
+ if self.rotary_dim is not None and self.rotary_dim > 0:
433
+ pos_embd_dim = self.rotary_dim
434
+ else:
435
+ pos_embd_dim = self.embed_dim // self.num_heads
436
+ self.embed_positions = create_sinusoidal_positions(config.max_position_embeddings, pos_embd_dim)
437
+
438
+ def _split_heads(self, hidden_states):
439
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
440
+
441
+ def _merge_heads(self, hidden_states):
442
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
443
+
444
+ @nn.compact
445
+ def _concatenate_to_cache(self, key, value, query, attention_mask):
446
+ """
447
+ This function takes projected key, value states from a single input token and concatenates the states to cached
448
+ states from previous steps. This function is slighly adapted from the official Flax repository:
449
+ https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
450
+ """
451
+ # detect if we're initializing by absence of existing cache data.
452
+ is_initialized = self.has_variable("cache", "cached_key")
453
+ cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
454
+ cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
455
+ cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
456
+
457
+ if is_initialized:
458
+ *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
459
+ # update key, value caches with our new 1d spatial slices
460
+ cur_index = cache_index.value
461
+ indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
462
+ key = lax.dynamic_update_slice(cached_key.value, key, indices)
463
+ value = lax.dynamic_update_slice(cached_value.value, value, indices)
464
+ cached_key.value = key
465
+ cached_value.value = value
466
+ num_updated_cache_vectors = query.shape[1]
467
+ cache_index.value = cache_index.value + num_updated_cache_vectors
468
+ # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
469
+ pad_mask = jnp.broadcast_to(
470
+ jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
471
+ tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
472
+ )
473
+ attention_mask = combine_masks(pad_mask, attention_mask)
474
+ return key, value, attention_mask
475
+
476
+ def __call__(
477
+ self,
478
+ hidden_states,
479
+ attention_mask,
480
+ position_ids,
481
+ deterministic: bool = True,
482
+ init_cache: bool = False,
483
+ output_attentions: bool = False,
484
+ fcm_mask=None,
485
+ ):
486
+
487
+ query = self.q_proj(hidden_states)
488
+ key = self.k_proj(hidden_states)
489
+ value = self.v_proj(hidden_states)
490
+
491
+ query = self._split_heads(query)
492
+ key = self._split_heads(key)
493
+ value = self._split_heads(value)
494
+
495
+ sincos = jnp.take(self.embed_positions, position_ids, axis=0)
496
+ sincos = jnp.split(sincos, 2, axis=-1)
497
+ # Rotary position embeddings induce some weird issues in multi-host environments, so we remove activation-sharding for keys/query vectors to fix this.
498
+ # key = with_sharding_constraint(key, PartitionSpec("dp", None, None, None))
499
+ # query = with_sharding_constraint(query, PartitionSpec("dp", None, None, None))
500
+ if self.rotary_dim is not None and self.rotary_dim > 0:
501
+ k_rot = key[:, :, :, : self.rotary_dim]
502
+ k_pass = key[:, :, :, self.rotary_dim :]
503
+
504
+ q_rot = query[:, :, :, : self.rotary_dim]
505
+ q_pass = query[:, :, :, self.rotary_dim :]
506
+
507
+ k_rot = apply_rotary_pos_emb(k_rot, sincos)
508
+ q_rot = apply_rotary_pos_emb(q_rot, sincos)
509
+
510
+ key = jnp.concatenate([k_rot, k_pass], axis=-1)
511
+ query = jnp.concatenate([q_rot, q_pass], axis=-1)
512
+ else:
513
+ key = apply_rotary_pos_emb(key, sincos)
514
+ query = apply_rotary_pos_emb(query, sincos)
515
+
516
+ query_length, key_length = query.shape[1], key.shape[1]
517
+
518
+ if self.has_variable("cache", "cached_key"):
519
+ mask_shift = self.variables["cache"]["cache_index"]
520
+ max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
521
+ causal_mask = lax.dynamic_slice(
522
+ self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
523
+ )
524
+ else:
525
+ causal_mask = self.causal_mask[:, :, :query_length, :key_length]
526
+
527
+ batch_size = hidden_states.shape[0]
528
+ causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
529
+
530
+ attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
531
+ if self.causal:
532
+ attention_mask = combine_masks(attention_mask, causal_mask, fcm_mask)
533
+ else:
534
+ attention_mask = attention_mask
535
+
536
+ dropout_rng = None
537
+ if not deterministic and self.config.attn_pdrop > 0.0:
538
+ dropout_rng = self.make_rng("dropout")
539
+
540
+ # During fast autoregressive decoding, we feed one position at a time,
541
+ # and cache the keys and values step by step.
542
+ if self.has_variable("cache", "cached_key") or init_cache:
543
+ key, value, attention_mask = self._concatenate_to_cache(key, value, query, attention_mask)
544
+
545
+ # transform boolean mask into float mask
546
+ attention_bias = lax.select(
547
+ attention_mask > 0,
548
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
549
+ jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
550
+ )
551
+
552
+ # usual dot product attention
553
+ attn_weights = dot_product_attention_weights(
554
+ query,
555
+ key,
556
+ bias=attention_bias,
557
+ dropout_rng=dropout_rng,
558
+ dropout_rate=self.config.attn_pdrop,
559
+ deterministic=deterministic,
560
+ dtype=jnp.promote_types(self.dtype, jnp.float32),
561
+ precision=None,
562
+ )
563
+
564
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
565
+ attn_output = self._merge_heads(attn_output)
566
+ attn_output = self.out_proj(attn_output)
567
+ attn_output = self.resid_dropout(attn_output, deterministic=deterministic)
568
+
569
+ outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
570
+ return outputs
571
+
572
+
573
+ class FlaxGPTJMLP(nn.Module):
574
+ config: GPTJConfig
575
+ intermediate_size: int
576
+ dtype: jnp.dtype = jnp.float32
577
+
578
+ def setup(self):
579
+ embed_dim = self.config.hidden_size
580
+ kernel_init=jax.nn.initializers.variance_scaling(
581
+ scale=1.0, mode='fan_in',
582
+ distribution='normal',
583
+ )
584
+
585
+ self.fc_in = nn.Dense(self.intermediate_size, dtype=self.dtype, kernel_init=kernel_init)
586
+ self.fc_out = nn.Dense(embed_dim, dtype=self.dtype, kernel_init=kernel_init)
587
+
588
+ self.act = ACT2FN[self.config.activation_function]
589
+ self.dropout = nn.Dropout(rate=self.config.resid_pdrop)
590
+
591
+ def __call__(self, hidden_states, deterministic: bool = True):
592
+ hidden_states = self.fc_in(hidden_states)
593
+ hidden_states = self.act(hidden_states)
594
+ hidden_states = self.fc_out(hidden_states)
595
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
596
+ return hidden_states
597
+
598
+
599
+ class FlaxGPTJBlock(nn.Module):
600
+ config: GPTJConfig
601
+ dtype: jnp.dtype = jnp.float32
602
+
603
+ def setup(self):
604
+ hidden_size = self.config.hidden_size
605
+ inner_dim = self.config.n_inner if self.config.n_inner is not None else 4 * hidden_size
606
+
607
+ self.ln_1 = nn.LayerNorm(
608
+ epsilon=self.config.layer_norm_epsilon,
609
+ dtype=jnp.promote_types(self.dtype, jnp.float32)
610
+ )
611
+ self.attn = FlaxGPTJAttention(self.config, dtype=self.dtype)
612
+
613
+ self.mlp = FlaxGPTJMLP(self.config, inner_dim, dtype=self.dtype)
614
+
615
+ def __call__(
616
+ self,
617
+ hidden_states,
618
+ attention_mask=None,
619
+ position_ids=None,
620
+ deterministic: bool = True,
621
+ init_cache: bool = False,
622
+ output_attentions: bool = False,
623
+ fcm_mask=None,
624
+ ):
625
+ residual = hidden_states
626
+ hidden_states = self.ln_1(hidden_states)
627
+ attn_outputs = self.attn(
628
+ hidden_states,
629
+ attention_mask=attention_mask,
630
+ position_ids=position_ids,
631
+ deterministic=deterministic,
632
+ init_cache=init_cache,
633
+ output_attentions=output_attentions,
634
+ fcm_mask=fcm_mask,
635
+ )
636
+ attn_output = attn_outputs[0]
637
+
638
+ feed_forward_hidden_states = self.mlp(hidden_states, deterministic=deterministic)
639
+ # residual connection
640
+ hidden_states = attn_output + feed_forward_hidden_states + residual
641
+
642
+ return (hidden_states,) + attn_outputs[1:]
643
+
644
+
645
+ class FlaxGPTJPreTrainedModel(FlaxPreTrainedModel):
646
+ """
647
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
648
+ models.
649
+ """
650
+
651
+ config_class = GPTJConfig
652
+ base_model_prefix = "transformer"
653
+ module_class: nn.Module = None
654
+
655
+ def __init__(
656
+ self,
657
+ config: GPTJConfig,
658
+ input_shape: Tuple = (1, 1),
659
+ seed: int = 0,
660
+ dtype: jnp.dtype = jnp.float32,
661
+ _do_init: bool = True,
662
+ **kwargs,
663
+ ):
664
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
665
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
666
+
667
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
668
+ # init input tensors
669
+ input_ids = jnp.zeros(input_shape, dtype="i4")
670
+ attention_mask = jnp.ones_like(input_ids)
671
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
672
+ params_rng, dropout_rng = jax.random.split(rng)
673
+ rngs = {"params": params_rng, "dropout": dropout_rng}
674
+
675
+ if self.config.add_cross_attention:
676
+ encoder_hidden_states = jnp.zeros(input_shape + (self.config.n_embd,))
677
+ encoder_attention_mask = attention_mask
678
+ module_init_outputs = self.module.init(
679
+ rngs,
680
+ input_ids,
681
+ attention_mask,
682
+ position_ids,
683
+ encoder_hidden_states,
684
+ encoder_attention_mask,
685
+ return_dict=False,
686
+ )
687
+ else:
688
+ module_init_outputs = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)
689
+
690
+ random_params = module_init_outputs["params"]
691
+
692
+ if params is not None:
693
+ random_params = flatten_dict(unfreeze(random_params))
694
+ params = flatten_dict(unfreeze(params))
695
+ for missing_key in self._missing_keys:
696
+ params[missing_key] = random_params[missing_key]
697
+ self._missing_keys = set()
698
+ return freeze(unflatten_dict(params))
699
+ else:
700
+ return random_params
701
+
702
+ def init_cache(self, batch_size, max_length):
703
+ r"""
704
+ Args:
705
+ batch_size (`int`):
706
+ batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
707
+ max_length (`int`):
708
+ maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
709
+ cache.
710
+ """
711
+ # init input variables to retrieve cache
712
+ input_ids = jnp.ones((batch_size, max_length))
713
+ attention_mask = jnp.ones_like(input_ids)
714
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
715
+
716
+ init_variables = self.module.init(
717
+ jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
718
+ )
719
+ return init_variables["cache"]
720
+
721
+ def _get_logits_processor(self,*args, **kwargs) -> FlaxLogitsProcessorList:
722
+ processors = super()._get_logits_processor(*args, **kwargs)
723
+ def squash_extra_tokens(input_ids, scores, cur_len):
724
+ return scores.at[:, self.config.n_real_tokens:].set(-float('inf'))
725
+
726
+ processors.append(squash_extra_tokens)
727
+ return processors
728
+
729
+ @add_start_docstrings_to_model_forward(GPTJ_INPUTS_DOCSTRING)
730
+ def __call__(
731
+ self,
732
+ input_ids,
733
+ attention_mask=None,
734
+ position_ids=None,
735
+ params: dict = None,
736
+ past_key_values: dict = None,
737
+ dropout_rng: jax.random.PRNGKey = None,
738
+ train: bool = False,
739
+ output_attentions: Optional[bool] = None,
740
+ output_hidden_states: Optional[bool] = None,
741
+ return_dict: Optional[bool] = None,
742
+ ):
743
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
744
+ output_hidden_states = (
745
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
746
+ )
747
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
748
+
749
+ batch_size, sequence_length = input_ids.shape
750
+
751
+ if position_ids is None:
752
+ if past_key_values is not None:
753
+ raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.")
754
+
755
+ position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
756
+
757
+ if attention_mask is None:
758
+ attention_mask = jnp.ones((batch_size, sequence_length))
759
+
760
+ # Handle any PRNG if needed
761
+ rngs = {}
762
+ if dropout_rng is not None:
763
+ rngs["dropout"] = dropout_rng
764
+
765
+ inputs = {"params": params or self.params}
766
+
767
+ # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be changed by FlaxGPTJAttention module
768
+ if past_key_values:
769
+ inputs["cache"] = past_key_values
770
+ mutable = ["cache"]
771
+ else:
772
+ mutable = False
773
+
774
+ outputs = self.module.apply(
775
+ inputs,
776
+ jnp.array(input_ids, dtype="i4"),
777
+ jnp.array(attention_mask, dtype="i4"),
778
+ jnp.array(position_ids, dtype="i4"),
779
+ not train,
780
+ False,
781
+ output_attentions,
782
+ output_hidden_states,
783
+ return_dict,
784
+ rngs=rngs,
785
+ mutable=mutable,
786
+ )
787
+
788
+ # add updated cache to model output
789
+ if past_key_values is not None and return_dict:
790
+ outputs, past_key_values = outputs
791
+ outputs["past_key_values"] = unfreeze(past_key_values["cache"])
792
+ return outputs
793
+ elif past_key_values is not None and not return_dict:
794
+ outputs, past_key_values = outputs
795
+ outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
796
+
797
+ return outputs
798
+
799
+
800
+ class FlaxGPTJBlockCollection(nn.Module):
801
+ config: GPTJConfig
802
+ dtype: jnp.dtype = jnp.float32
803
+
804
+ def setup(self):
805
+ block = FlaxGPTJBlock
806
+ if self.config.gradient_checkpointing:
807
+ FlaxGPT2CheckpointBlock = remat(
808
+ block, static_argnums=(3, 4, 5),
809
+ policy=get_gradient_checkpoint_policy(
810
+ self.config.gradient_checkpointing_policy
811
+ )
812
+ )
813
+ block = FlaxGPT2CheckpointBlock
814
+ self.blocks = [
815
+ block(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_layers)
816
+ ]
817
+
818
+ def __call__(
819
+ self,
820
+ hidden_states,
821
+ attention_mask=None,
822
+ position_ids=None,
823
+ deterministic: bool = True,
824
+ init_cache: bool = False,
825
+ output_attentions: bool = False,
826
+ output_hidden_states: bool = False,
827
+ return_dict: bool = True,
828
+ ):
829
+ all_attentions = () if output_attentions else None
830
+ all_hidden_states = () if output_hidden_states else None
831
+
832
+ if not deterministic and self.config.fcm_max_ratio > 0:
833
+ # Apply forgetful causal mask
834
+ batch_size, seq_length = hidden_states.shape[0], hidden_states.shape[1]
835
+ fcm_ratio = jax.random.uniform(
836
+ self.make_rng('fcm'), shape=(batch_size, 1, 1, 1),
837
+ minval=self.config.fcm_min_ratio,
838
+ maxval=self.config.fcm_max_ratio
839
+ )
840
+ fcm_mask = jax.random.uniform(
841
+ self.make_rng('fcm'),
842
+ shape=(batch_size, 1, seq_length, seq_length)
843
+ ) > fcm_ratio
844
+ fcm_mask = fcm_mask.at[:, :, :, 0].set(True)
845
+ fcm_mask = fcm_mask.astype('bool')
846
+ else:
847
+ fcm_mask = None
848
+
849
+ for block in self.blocks:
850
+ if output_hidden_states:
851
+ all_hidden_states += (hidden_states,)
852
+
853
+ layer_outputs = block(
854
+ hidden_states,
855
+ attention_mask,
856
+ position_ids,
857
+ deterministic,
858
+ init_cache,
859
+ output_attentions,
860
+ fcm_mask,
861
+ )
862
+ hidden_states = layer_outputs[0]
863
+
864
+ if output_attentions:
865
+ all_attentions += (layer_outputs[1],)
866
+
867
+ # this contains possible `None` values - `FlaxGPTJModule` will filter them out
868
+ outputs = (hidden_states, all_hidden_states, all_attentions)
869
+
870
+ return outputs
871
+
872
+
873
+ class FlaxGPTJModule(nn.Module):
874
+ config: GPTJConfig
875
+ dtype: jnp.dtype = jnp.float32
876
+
877
+ def setup(self):
878
+ self.embed_dim = self.config.hidden_size
879
+
880
+ self.wte = nn.Embed(
881
+ self.config.vocab_size,
882
+ self.config.hidden_size,
883
+ embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
884
+ )
885
+ self.dropout = nn.Dropout(rate=self.config.embd_pdrop)
886
+ self.h = FlaxGPTJBlockCollection(self.config, dtype=self.dtype)
887
+ self.ln_f = nn.LayerNorm(
888
+ epsilon=self.config.layer_norm_epsilon,
889
+ dtype=jnp.promote_types(self.dtype, jnp.float32)
890
+ )
891
+
892
+ def __call__(
893
+ self,
894
+ input_ids,
895
+ attention_mask,
896
+ position_ids,
897
+ deterministic=True,
898
+ init_cache: bool = False,
899
+ output_attentions: bool = False,
900
+ output_hidden_states: bool = False,
901
+ return_dict: bool = True,
902
+ ):
903
+ input_embeds = self.wte(input_ids.astype("i4"))
904
+
905
+ hidden_states = self.dropout(input_embeds, deterministic=deterministic)
906
+
907
+ outputs = self.h(
908
+ hidden_states,
909
+ attention_mask,
910
+ position_ids=position_ids,
911
+ deterministic=deterministic,
912
+ init_cache=init_cache,
913
+ output_attentions=output_attentions,
914
+ output_hidden_states=output_hidden_states,
915
+ return_dict=return_dict,
916
+ )
917
+
918
+ hidden_states = outputs[0]
919
+ hidden_states = self.ln_f(hidden_states)
920
+
921
+ if output_hidden_states:
922
+ all_hidden_states = outputs[1] + (hidden_states,)
923
+ outputs = (hidden_states, all_hidden_states) + outputs[2:]
924
+ else:
925
+ outputs = (hidden_states,) + outputs[1:]
926
+
927
+ if not return_dict:
928
+ return tuple(v for v in outputs if v is not None)
929
+
930
+ return FlaxBaseModelOutput(
931
+ last_hidden_state=hidden_states,
932
+ hidden_states=outputs[1],
933
+ attentions=outputs[-1],
934
+ )
935
+
936
+
937
+ @add_start_docstrings(
938
+ "The bare GPTJ Model transformer outputting raw hidden-states without any specific head on top.",
939
+ GPTJ_START_DOCSTRING,
940
+ )
941
+ class FlaxGPTJModel(FlaxGPTJPreTrainedModel):
942
+ module_class = FlaxGPTJModule
943
+
944
+
945
+ append_call_sample_docstring(
946
+ FlaxGPTJModel,
947
+ _CHECKPOINT_FOR_DOC,
948
+ FlaxCausalLMOutput,
949
+ _CONFIG_FOR_DOC,
950
+ )
951
+
952
+
953
+ class FlaxGPTJForCausalLMModule(nn.Module):
954
+ config: GPTJConfig
955
+ dtype: jnp.dtype = jnp.float32
956
+
957
+ def setup(self):
958
+ self.transformer = FlaxGPTJModule(self.config, dtype=self.dtype)
959
+ self.lm_head = nn.Dense(
960
+ self.config.vocab_size,
961
+ dtype=self.dtype,
962
+ kernel_init=jax.nn.initializers.variance_scaling(
963
+ scale=1.0, mode='fan_in',
964
+ distribution='normal',
965
+ )
966
+ )
967
+
968
+ def __call__(
969
+ self,
970
+ input_ids,
971
+ attention_mask=None,
972
+ position_ids=None,
973
+ deterministic: bool = True,
974
+ init_cache: bool = False,
975
+ output_attentions: bool = False,
976
+ output_hidden_states: bool = False,
977
+ return_dict: bool = True,
978
+ ):
979
+ batch_size, seq_length = input_ids.shape
980
+ if attention_mask is None:
981
+ attention_mask = jnp.ones_like(input_ids)
982
+ if position_ids is None:
983
+ position_ids = jnp.broadcast_to(
984
+ jnp.clip(jnp.cumsum(attention_mask, axis=-1) - 1, a_min=0),
985
+ (batch_size, seq_length)
986
+ )
987
+
988
+ outputs = self.transformer(
989
+ input_ids,
990
+ attention_mask,
991
+ position_ids,
992
+ deterministic=deterministic,
993
+ init_cache=init_cache,
994
+ output_attentions=output_attentions,
995
+ output_hidden_states=output_hidden_states,
996
+ return_dict=return_dict,
997
+ )
998
+
999
+ hidden_states = outputs[0]
1000
+
1001
+ if self.config.tie_word_embeddings:
1002
+ shared_kernel = self.transformer.variables["params"]["wte"]["embedding"].T
1003
+ lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states)
1004
+ else:
1005
+ lm_logits = self.lm_head(hidden_states)
1006
+
1007
+ if not return_dict:
1008
+ return (lm_logits,) + outputs[1:]
1009
+
1010
+ return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
1011
+
1012
+
1013
+ @add_start_docstrings(
1014
+ """
1015
+ The GPTJ Model transformer with a language modeling head on top.
1016
+ """,
1017
+ GPTJ_START_DOCSTRING,
1018
+ )
1019
+ class FlaxGPTJForCausalLM(FlaxGPTJPreTrainedModel):
1020
+ module_class = FlaxGPTJForCausalLMModule
1021
+
1022
+ def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
1023
+ # initializing the cache
1024
+ batch_size, seq_length = input_ids.shape
1025
+
1026
+ past_key_values = self.init_cache(batch_size, max_length)
1027
+ # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
1028
+ # But since GPTJ uses a causal mask, those positions are masked anyways.
1029
+ # Thus we can create a single static attention_mask here, which is more efficient for compilation
1030
+ extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
1031
+ if attention_mask is not None:
1032
+ position_ids = attention_mask.cumsum(axis=-1) - 1
1033
+ extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
1034
+ else:
1035
+ position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
1036
+
1037
+ return {
1038
+ "past_key_values": past_key_values,
1039
+ "attention_mask": extended_attention_mask,
1040
+ "position_ids": position_ids,
1041
+ }
1042
+
1043
+ def update_inputs_for_generation(self, model_outputs, model_kwargs):
1044
+ model_kwargs["past_key_values"] = model_outputs.past_key_values
1045
+ model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
1046
+ return model_kwargs
1047
+
1048
+
1049
+ append_call_sample_docstring(
1050
+ FlaxGPTJForCausalLM,
1051
+ _CHECKPOINT_FOR_DOC,
1052
+ FlaxCausalLMOutput,
1053
+ _CONFIG_FOR_DOC,
1054
+ )
EasyLM/models/gptj/gptj_serve.py ADDED
@@ -0,0 +1,396 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pprint
2
+ from functools import partial
3
+
4
+ import numpy as np
5
+ import mlxu
6
+
7
+ import jax
8
+ import jax.numpy as jnp
9
+ from jax.experimental.pjit import pjit
10
+ from jax.sharding import PartitionSpec as PS
11
+ import flax
12
+ from flax import linen as nn
13
+ from flax.jax_utils import prefetch_to_device
14
+ from flax.training.train_state import TrainState
15
+ import optax
16
+ from transformers import GenerationConfig, FlaxLogitsProcessorList
17
+
18
+ from EasyLM.checkpoint import StreamingCheckpointer
19
+ from EasyLM.serving import LMServer
20
+ from EasyLM.jax_utils import (
21
+ JaxRNG, next_rng, match_partition_rules, tree_apply,
22
+ set_random_seed, get_float_dtype_by_name, make_shard_and_gather_fns,
23
+ with_sharding_constraint, FlaxTemperatureLogitsWarper
24
+ )
25
+ from EasyLM.models.gptj.gptj_model import (
26
+ GPTJConfig, FlaxGPTJForCausalLMModule, FlaxGPTJForCausalLM
27
+ )
28
+
29
+
30
+ FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
31
+ seed=42,
32
+ initialize_jax_distributed=False,
33
+ mesh_dim='1,-1,1',
34
+ dtype='bf16',
35
+ input_length=1024,
36
+ seq_length=2048,
37
+ top_k=50,
38
+ top_p=1.0,
39
+ do_sample=True,
40
+ num_beams=1,
41
+ add_bos_token=False,
42
+ load_gptj_config='',
43
+ load_checkpoint='',
44
+ tokenizer=GPTJConfig.get_tokenizer_config(),
45
+ lm_server=LMServer.get_default_config(),
46
+ )
47
+
48
+
49
+ def main(argv):
50
+ if FLAGS.initialize_jax_distributed:
51
+ jax.distributed.initialize()
52
+ set_random_seed(FLAGS.seed)
53
+
54
+ prefix_tokenizer = GPTJConfig.get_tokenizer(
55
+ FLAGS.tokenizer, truncation_side='left', padding_side='left'
56
+ )
57
+ tokenizer = GPTJConfig.get_tokenizer(
58
+ FLAGS.tokenizer, truncation_side='right', padding_side='right'
59
+ )
60
+
61
+ with jax.default_device(jax.devices("cpu")[0]):
62
+ gptj_config = GPTJConfig.load_config(FLAGS.load_gptj_config)
63
+ load_type, load_path = FLAGS.load_checkpoint.split('::', 1)
64
+ if load_type == 'huggingface':
65
+ params = gptj_config.load_pretrained(load_path)
66
+ else:
67
+ _, params = StreamingCheckpointer.load_trainstate_checkpoint(
68
+ FLAGS.load_checkpoint, disallow_trainstate=True
69
+ )
70
+
71
+ hf_model = FlaxGPTJForCausalLM(
72
+ gptj_config,
73
+ input_shape=(1, FLAGS.seq_length),
74
+ seed=FLAGS.seed,
75
+ _do_init=False
76
+ )
77
+
78
+ model_ps = match_partition_rules(
79
+ GPTJConfig.get_partition_rules(), params
80
+ )
81
+ shard_fns, _ = make_shard_and_gather_fns(
82
+ model_ps, get_float_dtype_by_name(FLAGS.dtype)
83
+ )
84
+
85
+ @partial(
86
+ pjit,
87
+ in_shardings=(model_ps, PS(), PS()),
88
+ out_shardings=(PS(), PS(), PS())
89
+ )
90
+ def forward_loglikelihood(params, rng, batch):
91
+ batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
92
+ rng_generator = JaxRNG(rng)
93
+ input_tokens = batch['input_tokens']
94
+ output_tokens = batch['output_tokens']
95
+ input_mask = batch['input_mask']
96
+ output_mask = batch['output_mask']
97
+
98
+ logits = hf_model.module.apply(
99
+ params, input_tokens, attention_mask=input_mask,
100
+ deterministic=True, rngs=rng_generator(gptj_config.rng_keys()),
101
+ ).logits
102
+ if gptj_config.n_real_tokens is not None:
103
+ logits = logits.at[:, :, gptj_config.n_real_tokens:].set(-1e8)
104
+ loglikelihood = -optax.softmax_cross_entropy_with_integer_labels(
105
+ logits, output_tokens
106
+ )
107
+ loglikelihood = jnp.sum(loglikelihood * output_mask, axis=-1)
108
+ match_count = jnp.sum(
109
+ (jnp.argmax(logits, axis=-1) == output_tokens) * output_mask,
110
+ axis=-1
111
+ )
112
+ total = jnp.sum(output_mask, axis=-1)
113
+ is_greedy = match_count == total
114
+ return loglikelihood, is_greedy, rng_generator()
115
+
116
+
117
+ @partial(
118
+ pjit,
119
+ in_shardings=(model_ps, PS(), PS(), PS()),
120
+ out_shardings=(PS(), PS())
121
+ )
122
+ def forward_generate(params, rng, batch, temperature):
123
+ batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
124
+ rng_generator = JaxRNG(rng)
125
+ output = hf_model.generate(
126
+ batch['input_tokens'],
127
+ attention_mask=batch['attention_mask'],
128
+ params=params['params'],
129
+ prng_key=rng_generator(),
130
+ logits_processor=FlaxLogitsProcessorList(
131
+ [FlaxTemperatureLogitsWarper(temperature)]
132
+ ),
133
+ generation_config=GenerationConfig(
134
+ max_new_tokens=FLAGS.seq_length - FLAGS.input_length,
135
+ pad_token_id=tokenizer.eos_token_id,
136
+ bos_token_id=tokenizer.bos_token_id,
137
+ eos_token_id=tokenizer.eos_token_id,
138
+ do_sample=FLAGS.do_sample,
139
+ num_beams=FLAGS.num_beams,
140
+ top_k=FLAGS.top_k,
141
+ top_p=FLAGS.top_p,
142
+ )
143
+ ).sequences[:, batch['input_tokens'].shape[1]:]
144
+ return output, rng_generator()
145
+
146
+ @partial(
147
+ pjit,
148
+ in_shardings=(model_ps, PS(), PS()),
149
+ out_shardings=(PS(), PS())
150
+ )
151
+ def forward_greedy_generate(params, rng, batch):
152
+ batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
153
+ rng_generator = JaxRNG(rng)
154
+ output = hf_model.generate(
155
+ batch['input_tokens'],
156
+ attention_mask=batch['attention_mask'],
157
+ params=params['params'],
158
+ prng_key=rng_generator(),
159
+ generation_config=GenerationConfig(
160
+ max_new_tokens=FLAGS.seq_length - FLAGS.input_length,
161
+ pad_token_id=tokenizer.eos_token_id,
162
+ bos_token_id=tokenizer.bos_token_id,
163
+ eos_token_id=tokenizer.eos_token_id,
164
+ do_sample=False,
165
+ num_beams=1,
166
+ )
167
+ ).sequences[:, batch['input_tokens'].shape[1]:]
168
+ return output, rng_generator()
169
+
170
+ mesh = GPTJConfig.get_jax_mesh(FLAGS.mesh_dim)
171
+ with mesh:
172
+ params = tree_apply(shard_fns, params)
173
+ sharded_rng = next_rng()
174
+
175
+ class ModelServer(LMServer):
176
+
177
+ @staticmethod
178
+ def loglikelihood(prefix_text, text):
179
+ nonlocal sharded_rng
180
+ prefix = prefix_tokenizer(
181
+ prefix_text,
182
+ padding='max_length',
183
+ truncation=True,
184
+ max_length=FLAGS.input_length,
185
+ return_tensors='np',
186
+ )
187
+ inputs = tokenizer(
188
+ text,
189
+ padding='max_length',
190
+ truncation=True,
191
+ max_length=FLAGS.seq_length - FLAGS.input_length,
192
+ return_tensors='np',
193
+ )
194
+ output_tokens = np.concatenate([prefix.input_ids, inputs.input_ids], axis=1)
195
+ bos_tokens = np.full(
196
+ (output_tokens.shape[0], 1), tokenizer.bos_token_id, dtype=np.int32
197
+ )
198
+ input_tokens = np.concatenate([bos_tokens, output_tokens[:, :-1]], axis=-1)
199
+ input_mask = np.concatenate(
200
+ [prefix.attention_mask, inputs.attention_mask], axis=1
201
+ )
202
+ if FLAGS.add_bos_token:
203
+ bos_mask = np.ones_like(input_mask[:, :1])
204
+ else:
205
+ bos_mask = np.zeros_like(input_mask[:, :1])
206
+
207
+ input_mask = np.concatenate([bos_mask, input_mask[:, :-1]], axis=1)
208
+ output_mask = np.concatenate(
209
+ [np.zeros_like(prefix.attention_mask), inputs.attention_mask], axis=1
210
+ )
211
+ batch = dict(
212
+ input_tokens=input_tokens,
213
+ output_tokens=output_tokens,
214
+ input_mask=input_mask,
215
+ output_mask=output_mask,
216
+ )
217
+ with mesh:
218
+ loglikelihood, is_greedy, sharded_rng = forward_loglikelihood(
219
+ params, sharded_rng, batch
220
+ )
221
+ loglikelihood, is_greedy = jax.device_get((loglikelihood, is_greedy))
222
+ return loglikelihood, is_greedy
223
+
224
+ @staticmethod
225
+ def loglikelihood_rolling(text):
226
+ nonlocal sharded_rng
227
+ inputs = tokenizer(
228
+ text,
229
+ padding='longest',
230
+ truncation=False,
231
+ max_length=np.iinfo(np.int32).max,
232
+ return_tensors='np',
233
+ )
234
+ batch_size = inputs.input_ids.shape[0]
235
+ output_tokens = inputs.input_ids
236
+ attention_mask = inputs.attention_mask
237
+
238
+ if output_tokens.shape[1] < FLAGS.seq_length:
239
+ padding_length = FLAGS.seq_length - output_tokens.shape[1]
240
+ pad_tokens = np.full(
241
+ (batch_size, padding_length), tokenizer.pad_token_id, dtype=np.int32
242
+ )
243
+ output_tokens = np.concatenate([output_tokens, pad_tokens], axis=-1)
244
+ pad_mask = np.zeros(
245
+ (batch_size, padding_length), dtype=inputs.attention_mask.dtype
246
+ )
247
+ attention_mask = np.concatenate([attention_mask, pad_mask], axis=-1)
248
+
249
+ bos_tokens = np.full(
250
+ (batch_size, 1), tokenizer.bos_token_id, dtype=np.int32
251
+ )
252
+ input_tokens = np.concatenate([bos_tokens, output_tokens[:, :-1]], axis=-1)
253
+ bos_mask = np.ones((batch_size, 1), dtype=inputs.attention_mask.dtype)
254
+ total_seq_length = output_tokens.shape[1]
255
+
256
+ total_loglikelihood = 0.0
257
+ total_is_greedy = True
258
+ # Sliding window
259
+ for i in range(0, total_seq_length, FLAGS.seq_length):
260
+ # Last window
261
+ if i + FLAGS.seq_length > total_seq_length:
262
+ last_output_mask = np.copy(attention_mask[:, -FLAGS.seq_length:])
263
+ last_output_mask[:, :i - total_seq_length] = 0.0
264
+
265
+ batch = dict(
266
+ input_tokens=input_tokens[:, -FLAGS.seq_length:],
267
+ output_tokens=output_tokens[:, -FLAGS.seq_length:],
268
+ input_mask=attention_mask[:, -FLAGS.seq_length:],
269
+ output_mask=last_output_mask,
270
+ )
271
+
272
+ # Normal window
273
+ else:
274
+ batch = dict(
275
+ input_tokens=input_tokens[:, i:i + FLAGS.seq_length],
276
+ output_tokens=output_tokens[:, i:i + FLAGS.seq_length],
277
+ input_mask=attention_mask[:, i:i + FLAGS.seq_length],
278
+ output_mask=attention_mask[:, i:i + FLAGS.seq_length],
279
+ )
280
+
281
+ with mesh:
282
+ loglikelihood, is_greedy, sharded_rng = forward_loglikelihood(
283
+ params, sharded_rng, batch
284
+ )
285
+ loglikelihood, is_greedy = jax.device_get((loglikelihood, is_greedy))
286
+
287
+ total_loglikelihood += loglikelihood
288
+ total_is_greedy = np.logical_and(is_greedy, total_is_greedy)
289
+
290
+ return total_loglikelihood, total_is_greedy
291
+
292
+ @staticmethod
293
+ def generate(text, temperature):
294
+ nonlocal sharded_rng
295
+ inputs = prefix_tokenizer(
296
+ text,
297
+ padding='max_length',
298
+ truncation=True,
299
+ max_length=FLAGS.input_length,
300
+ return_tensors='np',
301
+ )
302
+ input_tokens = inputs.input_ids
303
+ input_mask = inputs.attention_mask
304
+ if FLAGS.add_bos_token:
305
+ input_tokens[:, 0] = tokenizer.bos_token_id
306
+ input_mask[:, 0] = 1
307
+ batch = dict(
308
+ input_tokens=input_tokens,
309
+ attention_mask=input_mask,
310
+ )
311
+ with mesh:
312
+ output, sharded_rng = forward_generate(
313
+ params, sharded_rng, batch, temperature
314
+ )
315
+ output = jax.device_get(output)
316
+ output_text = []
317
+ for text in list(tokenizer.batch_decode(output)):
318
+ if tokenizer.eos_token in text:
319
+ text = text.split(tokenizer.eos_token, maxsplit=1)[0]
320
+ output_text.append(text)
321
+
322
+ return output_text
323
+
324
+ @staticmethod
325
+ def greedy_until(prefix_text, until, max_length):
326
+ nonlocal sharded_rng
327
+ all_outputs = []
328
+ for pf, ut in zip(prefix_text, until):
329
+ if isinstance(ut, str):
330
+ ut = [ut]
331
+ total_length = 0
332
+ total_generated = ''
333
+
334
+ while total_length < max_length:
335
+ pf_tokens = tokenizer(
336
+ pf,
337
+ padding=False,
338
+ truncation=False,
339
+ max_length=np.iinfo(np.int32).max,
340
+ return_tensors='np',
341
+ )
342
+ input_tokens = pf_tokens.input_ids
343
+ attention_mask = pf_tokens.attention_mask
344
+
345
+ if input_tokens.shape[1] < FLAGS.input_length:
346
+ extra = FLAGS.input_length - input_tokens.shape[1]
347
+ pad_tokens = np.full(
348
+ (1, extra), tokenizer.pad_token_id, dtype=np.int32
349
+ )
350
+ input_tokens = np.concatenate(
351
+ [pad_tokens, input_tokens], axis=1
352
+ )
353
+ pad_attention = np.zeros((1, extra), dtype=attention_mask.dtype)
354
+ attention_mask = np.concatenate(
355
+ [pad_attention, attention_mask], axis=1
356
+ )
357
+ elif input_tokens.shape[1] > FLAGS.input_length:
358
+ input_tokens = input_tokens[:, -FLAGS.input_length:]
359
+ attention_mask = attention_mask[:, -FLAGS.input_length:]
360
+
361
+ if FLAGS.add_bos_token:
362
+ input_tokens[:, 0] = tokenizer.bos_token_id
363
+ attention_mask[:, 0] = 1
364
+
365
+ batch = dict(input_tokens=input_tokens, attention_mask=attention_mask)
366
+
367
+ with mesh:
368
+ output, sharded_rng = forward_greedy_generate(
369
+ params, sharded_rng, batch
370
+ )
371
+ output = jax.device_get(output)
372
+
373
+ total_length += output.shape[1]
374
+ output_text = tokenizer.batch_decode(output)[0]
375
+ total_generated = total_generated + output_text
376
+ pf = pf + output_text
377
+
378
+ done = False
379
+ for s in ut:
380
+ if s in total_generated:
381
+ total_generated = total_generated.split(s, maxsplit=1)[0]
382
+ done = True
383
+ if done:
384
+ break
385
+
386
+ all_outputs.append(total_generated)
387
+
388
+ return all_outputs
389
+
390
+
391
+ server = ModelServer(FLAGS.lm_server)
392
+ server.run()
393
+
394
+
395
+ if __name__ == "__main__":
396
+ mlxu.run(main)
EasyLM/models/gptj/gptj_train.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pprint
2
+ from functools import partial
3
+
4
+ from tqdm import tqdm, trange
5
+ import numpy as np
6
+ import mlxu
7
+
8
+ import jax
9
+ import jax.numpy as jnp
10
+ from jax.experimental.pjit import pjit, with_sharding_constraint
11
+ from jax.sharding import PartitionSpec as PS
12
+ from flax.training.train_state import TrainState
13
+
14
+ from EasyLM.data import DatasetFactory
15
+ from EasyLM.checkpoint import StreamingCheckpointer
16
+ from EasyLM.optimizers import OptimizerFactory
17
+ from EasyLM.jax_utils import (
18
+ JaxRNG, next_rng, match_partition_rules,
19
+ cross_entropy_loss_and_accuracy, global_norm, get_float_dtype_by_name,
20
+ set_random_seed, average_metrics, get_weight_decay_mask,
21
+ make_shard_and_gather_fns, tree_apply
22
+ )
23
+ from EasyLM.models.gptj.gptj_model import GPTJConfig, FlaxGPTJForCausalLMModule
24
+
25
+
26
+ FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
27
+ seed=42,
28
+ initialize_jax_distributed=False,
29
+ mesh_dim='1,-1,1',
30
+ dtype='fp32',
31
+ total_steps=10000,
32
+ load_gptj_config='',
33
+ update_gptj_config='',
34
+ load_checkpoint='',
35
+ load_dataset_state='',
36
+ log_freq=50,
37
+ save_model_freq=0,
38
+ save_milestone_freq=0,
39
+ eval_steps=0,
40
+ tokenizer=GPTJConfig.get_tokenizer_config(),
41
+ train_dataset=DatasetFactory.get_default_config(),
42
+ eval_dataset=DatasetFactory.get_default_config(),
43
+ optimizer=OptimizerFactory.get_default_config(),
44
+ checkpointer=StreamingCheckpointer.get_default_config(),
45
+ gptj=GPTJConfig.get_default_config(),
46
+ logger=mlxu.WandBLogger.get_default_config(),
47
+ log_all_worker=False,
48
+ )
49
+
50
+
51
+ def main(argv):
52
+ if FLAGS.initialize_jax_distributed:
53
+ jax.distributed.initialize()
54
+
55
+ variant = mlxu.get_user_flags(FLAGS, FLAGS_DEF)
56
+ flags_config_dict = mlxu.user_flags_to_config_dict(FLAGS, FLAGS_DEF)
57
+ logger = mlxu.WandBLogger(
58
+ config=FLAGS.logger,
59
+ variant=variant,
60
+ enable=FLAGS.log_all_worker or (jax.process_index() == 0),
61
+ )
62
+ set_random_seed(FLAGS.seed)
63
+
64
+ tokenizer = GPTJConfig.get_tokenizer(FLAGS.tokenizer)
65
+ dataset = DatasetFactory.load_dataset(FLAGS.train_dataset, tokenizer)
66
+ if FLAGS.load_dataset_state != '':
67
+ dataset.load_state_dict(mlxu.load_pickle(FLAGS.load_dataset_state))
68
+
69
+ if FLAGS.eval_steps > 0:
70
+ eval_dataset = DatasetFactory.load_dataset(
71
+ FLAGS.eval_dataset, dataset.tokenizer
72
+ )
73
+ eval_iterator = iter(eval_dataset)
74
+
75
+ seq_length = dataset.seq_length
76
+
77
+ if FLAGS.load_gptj_config != '':
78
+ gptj_config = GPTJConfig.load_config(FLAGS.load_gptj_config)
79
+ else:
80
+ gptj_config = GPTJConfig(**FLAGS.gptj)
81
+
82
+ if FLAGS.update_gptj_config != '':
83
+ gptj_config.update(dict(eval(FLAGS.update_gptj_config)))
84
+
85
+ gptj_config.update(dict(
86
+ bos_token_id=dataset.tokenizer.bos_token_id,
87
+ eos_token_id=dataset.tokenizer.eos_token_id,
88
+ ))
89
+ if gptj_config.vocab_size < dataset.vocab_size:
90
+ gptj_config.update(dict(vocab_size=dataset.vocab_size))
91
+
92
+ model = FlaxGPTJForCausalLMModule(
93
+ gptj_config, dtype=get_float_dtype_by_name(FLAGS.dtype)
94
+ )
95
+
96
+ optimizer, optimizer_info = OptimizerFactory.get_optimizer(
97
+ FLAGS.optimizer,
98
+ get_weight_decay_mask(GPTJConfig.get_weight_decay_exclusions()),
99
+ )
100
+
101
+ def create_trainstate_from_params(params):
102
+ return TrainState.create(params=params, tx=optimizer, apply_fn=None)
103
+
104
+ def init_fn(rng):
105
+ rng_generator = JaxRNG(rng)
106
+ params = model.init(
107
+ input_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
108
+ position_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
109
+ attention_mask=jnp.ones((4, seq_length), dtype=jnp.int32),
110
+ rngs=rng_generator(gptj_config.rng_keys()),
111
+ )
112
+ return TrainState.create(params=params, tx=optimizer, apply_fn=None)
113
+
114
+ def train_step(train_state, rng, batch):
115
+ rng_generator = JaxRNG(rng)
116
+ batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
117
+ def loss_and_accuracy(params):
118
+ logits = model.apply(
119
+ params, batch['input_tokens'], deterministic=False,
120
+ rngs=rng_generator(gptj_config.rng_keys()),
121
+ ).logits
122
+ return cross_entropy_loss_and_accuracy(
123
+ logits, batch['target_tokens'], batch['loss_masks']
124
+ )
125
+ grad_fn = jax.value_and_grad(loss_and_accuracy, has_aux=True)
126
+ (loss, accuracy), grads = grad_fn(train_state.params)
127
+ train_state = train_state.apply_gradients(grads=grads)
128
+ metrics = dict(
129
+ loss=loss,
130
+ accuracy=accuracy,
131
+ learning_rate=optimizer_info['learning_rate_schedule'](train_state.step),
132
+ gradient_norm=global_norm(grads),
133
+ param_norm=global_norm(train_state.params),
134
+ )
135
+ return train_state, rng_generator(), metrics
136
+
137
+ def eval_step(train_state, rng, batch):
138
+ rng_generator = JaxRNG(rng)
139
+ batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
140
+ logits = model.apply(
141
+ train_state.params, batch['input_tokens'], deterministic=True,
142
+ rngs=rng_generator(gptj_config.rng_keys()),
143
+ ).logits
144
+ loss, accuracy = cross_entropy_loss_and_accuracy(
145
+ logits, batch['target_tokens'], batch['loss_masks']
146
+ )
147
+ metrics = dict(
148
+ eval_loss=loss,
149
+ eval_accuracy=accuracy,
150
+ )
151
+ return rng_generator(), metrics
152
+
153
+ train_state_shapes = jax.eval_shape(init_fn, next_rng())
154
+ train_state_partition = match_partition_rules(
155
+ GPTJConfig.get_partition_rules(), train_state_shapes
156
+ )
157
+
158
+ shard_fns, gather_fns = make_shard_and_gather_fns(
159
+ train_state_partition, train_state_shapes
160
+ )
161
+ checkpointer = StreamingCheckpointer(
162
+ FLAGS.checkpointer, logger.output_dir,
163
+ enable=jax.process_index() == 0,
164
+ )
165
+
166
+ sharded_init_fn = pjit(
167
+ init_fn,
168
+ in_shardings=PS(),
169
+ out_shardings=train_state_partition
170
+ )
171
+
172
+ sharded_create_trainstate_from_params = pjit(
173
+ create_trainstate_from_params,
174
+ in_shardings=(train_state_partition.params, ),
175
+ out_shardings=train_state_partition,
176
+ donate_argnums=(0, ),
177
+ )
178
+
179
+ sharded_train_step = pjit(
180
+ train_step,
181
+ in_shardings=(train_state_partition, PS(), PS()),
182
+ out_shardings=(train_state_partition, PS(), PS()),
183
+ donate_argnums=(0, 1),
184
+ )
185
+
186
+ sharded_eval_step = pjit(
187
+ eval_step,
188
+ in_shardings=(train_state_partition, PS(), PS()),
189
+ out_shardings=(PS(), PS()),
190
+ donate_argnums=(1,),
191
+ )
192
+
193
+ def save_checkpoint(train_state, milestone=False):
194
+ step = int(jax.device_get(train_state.step))
195
+ metadata = dict(
196
+ step=step,
197
+ variant=variant,
198
+ flags=flags_config_dict,
199
+ gptj_config=gptj_config.to_dict(),
200
+ )
201
+ checkpointer.save_all(
202
+ train_state=train_state,
203
+ gather_fns=gather_fns,
204
+ metadata=metadata,
205
+ dataset=dataset.get_state_dict(),
206
+ milestone=milestone,
207
+ )
208
+
209
+ mesh = GPTJConfig.get_jax_mesh(FLAGS.mesh_dim)
210
+ with mesh:
211
+ train_state, restored_params = None, None
212
+ if FLAGS.load_checkpoint != '':
213
+ load_type, load_path = FLAGS.load_checkpoint.split('::', 1)
214
+ if load_type == 'huggingface':
215
+ restored_params = tree_apply(
216
+ shard_fns.params, gptj_config.load_pretrained(load_path)
217
+ )
218
+ train_state = None
219
+ else:
220
+ train_state, restored_params = checkpointer.load_trainstate_checkpoint(
221
+ FLAGS.load_checkpoint, train_state_shapes, shard_fns
222
+ )
223
+
224
+ if train_state is None and restored_params is None:
225
+ # Initialize from scratch
226
+ train_state = sharded_init_fn(next_rng())
227
+ elif train_state is None and restored_params is not None:
228
+ # Restore from params but initialize train_state
229
+ train_state = sharded_create_trainstate_from_params(restored_params)
230
+ del restored_params
231
+
232
+ start_step = int(jax.device_get(train_state.step))
233
+
234
+ if FLAGS.save_model_freq > 0:
235
+ save_checkpoint(train_state)
236
+
237
+ sharded_rng = next_rng()
238
+
239
+ step_counter = trange(start_step, FLAGS.total_steps, ncols=0)
240
+
241
+ for step, (batch, dataset_metrics) in zip(step_counter, dataset):
242
+ train_state, sharded_rng, metrics = sharded_train_step(
243
+ train_state, sharded_rng, batch
244
+ )
245
+
246
+ if step % FLAGS.log_freq == 0:
247
+ if FLAGS.eval_steps > 0:
248
+ eval_metric_list = []
249
+ for _ in range(FLAGS.eval_steps):
250
+ eval_batch, _ = next(eval_iterator)
251
+ sharded_rng, eval_metrics = sharded_eval_step(
252
+ train_state, sharded_rng, eval_batch
253
+ )
254
+ eval_metric_list.append(eval_metrics)
255
+ metrics.update(average_metrics(eval_metric_list))
256
+
257
+ log_metrics = {"step": step}
258
+ log_metrics.update(metrics)
259
+ log_metrics.update(dataset_metrics)
260
+ log_metrics = jax.device_get(log_metrics)
261
+ logger.log(log_metrics)
262
+ tqdm.write("\n" + pprint.pformat(log_metrics) + "\n")
263
+
264
+ if FLAGS.save_milestone_freq > 0 and (step + 1) % FLAGS.save_milestone_freq == 0:
265
+ save_checkpoint(train_state, milestone=True)
266
+ elif FLAGS.save_model_freq > 0 and (step + 1) % FLAGS.save_model_freq == 0:
267
+ save_checkpoint(train_state)
268
+
269
+ if FLAGS.save_model_freq > 0:
270
+ save_checkpoint(train_state)
271
+
272
+
273
+ if __name__ == "__main__":
274
+ mlxu.run(main)
EasyLM/models/llama/convert_easylm_to_hf.py ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2022 EleutherAI and The HuggingFace Inc. team. All rights reserved.
2
+ # Copyright 2023 Xinyang Geng
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ # This script converts LLaMA model checkpoint trained by EsayLM to the
17
+ # HuggingFace transformers LLaMA PyTorch format, which can then be loaded
18
+ # by HuggingFace transformers.
19
+
20
+ import gc
21
+ import json
22
+ import math
23
+ import os
24
+ import shutil
25
+
26
+ import numpy as np
27
+ import mlxu
28
+ import jax
29
+ import jax.numpy as jnp
30
+ import flax
31
+ from flax.traverse_util import flatten_dict
32
+ import torch
33
+ from transformers import LlamaConfig, LlamaForCausalLM
34
+
35
+ from EasyLM.checkpoint import StreamingCheckpointer
36
+ from EasyLM.jax_utils import float_tensor_to_dtype
37
+
38
+
39
+ FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
40
+ load_checkpoint='',
41
+ tokenizer_path='',
42
+ model_size='13b',
43
+ output_dir='',
44
+ )
45
+
46
+
47
+ LLAMA_STANDARD_CONFIGS = {
48
+ 'small': {
49
+ 'vocab_size': 64256,
50
+ 'dim': 768,
51
+ 'intermediate_size': 3072,
52
+ 'n_layers': 12,
53
+ 'n_heads': 12,
54
+ 'norm_eps': 1e-6,
55
+ },
56
+ 'medium': {
57
+ 'vocab_size': 64256,
58
+ 'dim': 1024,
59
+ 'intermediate_size': 4096,
60
+ 'n_layers': 24,
61
+ 'n_heads': 16,
62
+ 'norm_eps': 1e-6,
63
+ },
64
+ 'large': {
65
+ 'vocab_size': 64256,
66
+ 'dim': 1536,
67
+ 'intermediate_size': 6144,
68
+ 'n_layers': 24,
69
+ 'n_heads': 16,
70
+ 'norm_eps': 1e-6,
71
+ },
72
+ 'xlarge': {
73
+ 'vocab_size': 64256,
74
+ 'dim': 2048,
75
+ 'intermediate_size': 8192,
76
+ 'n_layers': 24,
77
+ 'n_heads': 32,
78
+ 'norm_eps': 1e-6,
79
+ },
80
+ '3b': {
81
+ 'vocab_size': 64256,
82
+ 'dim': 3200,
83
+ 'intermediate_size': 8640,
84
+ 'n_layers': 26,
85
+ 'n_heads': 32,
86
+ 'norm_eps': 1e-6,
87
+ },
88
+ '7b': {
89
+ 'vocab_size': 64256,
90
+ 'dim': 4096,
91
+ 'intermediate_size': 11008,
92
+ 'n_layers': 32,
93
+ 'n_heads': 32,
94
+ 'norm_eps': 1e-6,
95
+ },
96
+ '13b': {
97
+ 'vocab_size': 64256,
98
+ 'dim': 5120,
99
+ 'intermediate_size': 13824,
100
+ 'n_layers': 40,
101
+ 'n_heads': 40,
102
+ 'norm_eps': 1e-6,
103
+ },
104
+ '30b': {
105
+ 'vocab_size': 64256,
106
+ 'dim': 6656,
107
+ 'intermediate_size': 17920,
108
+ 'n_layers': 60,
109
+ 'n_heads': 52,
110
+ 'norm_eps': 1e-6,
111
+ },
112
+ '65b': {
113
+ 'vocab_size': 64256,
114
+ 'dim': 8192,
115
+ 'intermediate_size': 22016,
116
+ 'n_layers': 80,
117
+ 'n_heads': 64,
118
+ 'norm_eps': 1e-5,
119
+ },
120
+ }
121
+
122
+
123
+ def match_keywords(string, positives, negatives):
124
+ for positive in positives:
125
+ if positive not in string:
126
+ return False
127
+ for negative in negatives:
128
+ if negative in string:
129
+ return False
130
+ return True
131
+
132
+
133
+ def load_and_convert_checkpoint(path):
134
+ _, flax_params = StreamingCheckpointer.load_trainstate_checkpoint(path)
135
+ flax_params = flatten_dict(flax_params['params']['params']['params'], sep='.')
136
+ torch_params = {}
137
+ for key, tensor in flax_params.items():
138
+ if match_keywords(key, ["kernel"], ["norm", 'ln_f']):
139
+ tensor = tensor.T
140
+ torch_params[key] = torch.tensor(
141
+ float_tensor_to_dtype(tensor, 'fp32'), dtype=torch.float16
142
+ )
143
+ return torch_params
144
+
145
+
146
+ def read_json(path):
147
+ with open(path, "r") as f:
148
+ return json.load(f)
149
+
150
+
151
+ def write_json(text, path):
152
+ with open(path, "w") as f:
153
+ json.dump(text, f)
154
+
155
+
156
+ def write_model(loaded, model_path, model_size):
157
+ os.makedirs(model_path, exist_ok=True)
158
+ tmp_model_path = os.path.join(model_path, "tmp")
159
+ os.makedirs(tmp_model_path, exist_ok=True)
160
+
161
+ params = LLAMA_STANDARD_CONFIGS[model_size]
162
+
163
+ n_layers = params["n_layers"]
164
+ n_heads = params["n_heads"]
165
+ dim = params["dim"]
166
+ dims_per_head = dim // n_heads
167
+ base = 10000.0
168
+ inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
169
+
170
+ # permute for sliced rotary
171
+ def permute(w):
172
+ return w.view(n_heads, dim // n_heads // 2, 2, dim).transpose(1, 2).reshape(dim, dim)
173
+
174
+
175
+ param_count = 0
176
+ index_dict = {"weight_map": {}}
177
+ for layer_i in range(n_layers):
178
+ filename = f"pytorch_model-{layer_i + 1}-of-{n_layers + 1}.bin"
179
+ state_dict = {
180
+ f"model.layers.{layer_i}.self_attn.q_proj.weight": permute(
181
+ loaded[f"transformer.h.{layer_i}.attention.wq.kernel"]
182
+ ),
183
+ f"model.layers.{layer_i}.self_attn.k_proj.weight": permute(
184
+ loaded[f"transformer.h.{layer_i}.attention.wk.kernel"]
185
+ ),
186
+ f"model.layers.{layer_i}.self_attn.v_proj.weight": loaded[f"transformer.h.{layer_i}.attention.wv.kernel"],
187
+ f"model.layers.{layer_i}.self_attn.o_proj.weight": loaded[f"transformer.h.{layer_i}.attention.wo.kernel"],
188
+
189
+ f"model.layers.{layer_i}.mlp.gate_proj.weight": loaded[f"transformer.h.{layer_i}.feed_forward.w1.kernel"],
190
+ f"model.layers.{layer_i}.mlp.down_proj.weight": loaded[f"transformer.h.{layer_i}.feed_forward.w2.kernel"],
191
+ f"model.layers.{layer_i}.mlp.up_proj.weight": loaded[f"transformer.h.{layer_i}.feed_forward.w3.kernel"],
192
+
193
+ f"model.layers.{layer_i}.input_layernorm.weight": loaded[f"transformer.h.{layer_i}.attention_norm.kernel"],
194
+ f"model.layers.{layer_i}.post_attention_layernorm.weight": loaded[f"transformer.h.{layer_i}.ffn_norm.kernel"],
195
+
196
+ }
197
+
198
+ state_dict[f"model.layers.{layer_i}.self_attn.rotary_emb.inv_freq"] = inv_freq
199
+ for k, v in state_dict.items():
200
+ index_dict["weight_map"][k] = filename
201
+ param_count += v.numel()
202
+ torch.save(state_dict, os.path.join(tmp_model_path, filename))
203
+
204
+ filename = f"pytorch_model-{n_layers + 1}-of-{n_layers + 1}.bin"
205
+ # Unsharded
206
+ state_dict = {
207
+ "model.embed_tokens.weight": loaded["transformer.wte.embedding"],
208
+ "model.norm.weight": loaded["transformer.ln_f.kernel"],
209
+ "lm_head.weight": loaded["lm_head.kernel"],
210
+ }
211
+
212
+ for k, v in state_dict.items():
213
+ index_dict["weight_map"][k] = filename
214
+ param_count += v.numel()
215
+ torch.save(state_dict, os.path.join(tmp_model_path, filename))
216
+
217
+ # Write configs
218
+ index_dict["metadata"] = {"total_size": param_count * 2}
219
+ write_json(index_dict, os.path.join(tmp_model_path, "pytorch_model.bin.index.json"))
220
+
221
+ config = LlamaConfig(
222
+ vocab_size=params["vocab_size"],
223
+ hidden_size=dim,
224
+ intermediate_size=params["intermediate_size"],
225
+ num_attention_heads=params["n_heads"],
226
+ num_hidden_layers=params["n_layers"],
227
+ rms_norm_eps=params["norm_eps"],
228
+ )
229
+ config.save_pretrained(tmp_model_path)
230
+
231
+ # Make space so we can load the model properly now.
232
+ del state_dict
233
+ del loaded
234
+ gc.collect()
235
+
236
+ print("Loading the checkpoint in a Llama model.")
237
+ model = LlamaForCausalLM.from_pretrained(tmp_model_path, torch_dtype=torch.float16)
238
+ print("Model parameter count", model.num_parameters())
239
+ # Avoid saving this as part of the config.
240
+ del model.config._name_or_path
241
+
242
+ print("Saving in the Transformers format.")
243
+ model.save_pretrained(model_path)
244
+ shutil.rmtree(tmp_model_path)
245
+
246
+
247
+ def write_tokenizer(tokenizer_path, input_tokenizer_path):
248
+ print(f"Fetching the tokenizer from {input_tokenizer_path}.")
249
+ os.makedirs(tokenizer_path, exist_ok=True)
250
+ write_json(
251
+ {
252
+ "bos_token": {
253
+ "content": "<s>",
254
+ "lstrip": False,
255
+ "normalized": False,
256
+ "rstrip": False,
257
+ "single_word": False
258
+ },
259
+ "eos_token": {
260
+ "content": "</s>",
261
+ "lstrip": False,
262
+ "normalized": False,
263
+ "rstrip": False,
264
+ "single_word": False
265
+ },
266
+ "unk_token": {
267
+ "content": "<unk>",
268
+ "lstrip": False,
269
+ "normalized": False,
270
+ "rstrip": False,
271
+ "single_word": False
272
+ },
273
+ },
274
+ os.path.join(tokenizer_path, "special_tokens_map.json")
275
+ )
276
+ write_json(
277
+ {
278
+ "add_bos_token": True,
279
+ "add_eos_token": False,
280
+ "model_max_length": 2048,
281
+ "pad_token": None,
282
+ "sp_model_kwargs": {},
283
+ "tokenizer_class": "LlamaTokenizer",
284
+ "clean_up_tokenization_spaces": False,
285
+ "bos_token": {
286
+ "__type": "AddedToken",
287
+ "content": "<s>",
288
+ "lstrip": False,
289
+ "normalized": False,
290
+ "rstrip": False,
291
+ "single_word": False
292
+ },
293
+ "eos_token": {
294
+ "__type": "AddedToken",
295
+ "content": "</s>",
296
+ "lstrip": False,
297
+ "normalized": False,
298
+ "rstrip": False,
299
+ "single_word": False
300
+ },
301
+ "unk_token": {
302
+ "__type": "AddedToken",
303
+ "content": "<unk>",
304
+ "lstrip": False,
305
+ "normalized": False,
306
+ "rstrip": False,
307
+ "single_word": False
308
+ },
309
+ },
310
+ os.path.join(tokenizer_path, "tokenizer_config.json"),
311
+ )
312
+ shutil.copyfile(input_tokenizer_path, os.path.join(tokenizer_path, "tokenizer.model"))
313
+
314
+
315
+ def main(argv):
316
+ assert FLAGS.load_checkpoint != "" and FLAGS.output_dir != "" #and FLAGS.tokenizer_path != ""
317
+ assert FLAGS.model_size in LLAMA_STANDARD_CONFIGS
318
+ # write_tokenizer(
319
+ # tokenizer_path=FLAGS.output_dir,
320
+ # input_tokenizer_path=FLAGS.tokenizer_path,
321
+ # )
322
+ write_model(
323
+ load_and_convert_checkpoint(FLAGS.load_checkpoint),
324
+ model_path=FLAGS.output_dir,
325
+ model_size=FLAGS.model_size,
326
+ )
327
+
328
+
329
+ if __name__ == "__main__":
330
+ mlxu.run(main)
EasyLM/models/llama/convert_torch_to_easylm.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This script converts the standrd LLaMA PyTorch checkpoint released by Meta
2
+ # to the EasyLM checkpoint format. The converted checkpoint can then be loaded
3
+ # by EasyLM for fine-tuning or inference.
4
+
5
+ # This script is largely borrow from https://github.com/Sea-Snell/JAX_llama
6
+
7
+ from pathlib import Path
8
+ import json
9
+ import numpy as np
10
+ import torch
11
+ import flax
12
+ import mlxu
13
+
14
+ from EasyLM.checkpoint import StreamingCheckpointer
15
+
16
+
17
+ FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
18
+ checkpoint_dir='',
19
+ output_file='',
20
+ streaming=True,
21
+ )
22
+
23
+
24
+ def main(argv):
25
+ ckpt_paths = sorted(Path(FLAGS.checkpoint_dir).glob("*.pth"))
26
+ ckpts = {}
27
+ for i, ckpt_path in enumerate(ckpt_paths):
28
+ checkpoint = torch.load(ckpt_path, map_location="cpu")
29
+ ckpts[int(ckpt_path.name.split('.', maxsplit=2)[1])] = checkpoint
30
+ ckpts = [ckpts[i] for i in sorted(list(ckpts.keys()))]
31
+ with open(Path(FLAGS.checkpoint_dir) / "params.json", "r") as f:
32
+ params = json.loads(f.read())
33
+
34
+ jax_weights = {
35
+ 'transformer': {
36
+ 'wte': {'embedding': np.concatenate([ckpt['tok_embeddings.weight'].numpy() for ckpt in ckpts], axis=1)},
37
+ 'ln_f': {'kernel': ckpts[0]['norm.weight'].numpy()},
38
+ 'h': {
39
+ '%d' % (layer): {
40
+ 'attention': {
41
+ 'wq': {'kernel': np.concatenate([ckpt['layers.%d.attention.wq.weight' % (layer)].numpy() for ckpt in ckpts], axis=0).transpose()},
42
+ 'wk': {'kernel': np.concatenate([ckpt['layers.%d.attention.wk.weight' % (layer)].numpy() for ckpt in ckpts], axis=0).transpose()},
43
+ 'wv': {'kernel': np.concatenate([ckpt['layers.%d.attention.wv.weight' % (layer)].numpy() for ckpt in ckpts], axis=0).transpose()},
44
+ 'wo': {'kernel': np.concatenate([ckpt['layers.%d.attention.wo.weight' % (layer)].numpy() for ckpt in ckpts], axis=1).transpose()},
45
+ },
46
+ 'feed_forward': {
47
+ 'w1': {'kernel': np.concatenate([ckpt['layers.%d.feed_forward.w1.weight' % (layer)].numpy() for ckpt in ckpts], axis=0).transpose()},
48
+ 'w2': {'kernel': np.concatenate([ckpt['layers.%d.feed_forward.w2.weight' % (layer)].numpy() for ckpt in ckpts], axis=1).transpose()},
49
+ 'w3': {'kernel': np.concatenate([ckpt['layers.%d.feed_forward.w3.weight' % (layer)].numpy() for ckpt in ckpts], axis=0).transpose()},
50
+ },
51
+ 'attention_norm': {'kernel': ckpts[0]['layers.%d.attention_norm.weight' % (layer)].numpy()},
52
+ 'ffn_norm': {'kernel': ckpts[0]['layers.%d.ffn_norm.weight' % (layer)].numpy()},
53
+ }
54
+ for layer in range(params['n_layers'])},
55
+ },
56
+ 'lm_head': {'kernel': np.concatenate([ckpt['output.weight'].numpy() for ckpt in ckpts], axis=0).transpose()},
57
+ }
58
+ if FLAGS.streaming:
59
+ StreamingCheckpointer.save_train_state_to_file(
60
+ jax_weights, FLAGS.output_file
61
+ )
62
+ else:
63
+ with mlxu.open_file(FLAGS.output_file, 'wb') as fout:
64
+ fout.write(flax.serialization.msgpack_serialize(jax_weights, in_place=True))
65
+
66
+
67
+ if __name__ == '__main__':
68
+ mlxu.run(main)
EasyLM/models/llama/llama_model.py ADDED
@@ -0,0 +1,1257 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from shutil import copyfile
3
+ from typing import Any, Dict, List, Optional, Tuple, Union
4
+ import json
5
+ import tempfile
6
+
7
+ import numpy as np
8
+ import jax
9
+ import jax.numpy as jnp
10
+ from jax import lax
11
+ from jax.sharding import PartitionSpec as PS
12
+ import flax.linen as nn
13
+ from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
14
+ from flax.linen import combine_masks, make_causal_mask
15
+ from flax.linen.attention import dot_product_attention_weights
16
+ from flax.traverse_util import flatten_dict, unflatten_dict
17
+ from flax.linen import partitioning as nn_partitioning
18
+
19
+ import sentencepiece as spm
20
+ from transformers.configuration_utils import PretrainedConfig
21
+ from transformers.utils import logging
22
+ from transformers.tokenization_utils import PreTrainedTokenizer
23
+ from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput
24
+ from transformers.modeling_flax_utils import ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring
25
+ from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging
26
+
27
+ from ml_collections import ConfigDict
28
+ from ml_collections.config_dict import config_dict
29
+ from mlxu import function_args_to_config, load_pickle, open_file
30
+
31
+ from EasyLM.jax_utils import (
32
+ with_sharding_constraint, get_jax_mesh, get_gradient_checkpoint_policy
33
+ )
34
+
35
+
36
+ LLAMA_STANDARD_CONFIGS = {
37
+ 'small': {
38
+ 'vocab_size': 64256,
39
+ 'hidden_size': 768,
40
+ 'intermediate_size': 3072,
41
+ 'num_hidden_layers': 12,
42
+ 'num_attention_heads': 12,
43
+ 'max_sequence_length': 2048,
44
+ 'initializer_range': 0.02,
45
+ 'rms_norm_eps': 1e-6,
46
+ 'use_cache': True,
47
+ 'tie_word_embeddings': False,
48
+ },
49
+ 'medium': {
50
+ 'vocab_size': 64256,
51
+ 'hidden_size': 1024,
52
+ 'intermediate_size': 4096,
53
+ 'num_hidden_layers': 24,
54
+ 'num_attention_heads': 16,
55
+ 'max_sequence_length': 2048,
56
+ 'initializer_range': 0.02,
57
+ 'rms_norm_eps': 1e-6,
58
+ 'use_cache': True,
59
+ 'tie_word_embeddings': False,
60
+ },
61
+ 'large': {
62
+ 'vocab_size': 64256,
63
+ 'hidden_size': 1536,
64
+ 'intermediate_size': 6144,
65
+ 'num_hidden_layers': 24,
66
+ 'num_attention_heads': 16,
67
+ 'max_sequence_length': 2048,
68
+ 'initializer_range': 0.02,
69
+ 'rms_norm_eps': 1e-6,
70
+ 'use_cache': True,
71
+ 'tie_word_embeddings': False,
72
+ },
73
+ 'xlarge': {
74
+ 'vocab_size': 64256,
75
+ 'hidden_size': 2048,
76
+ 'intermediate_size': 8192,
77
+ 'num_hidden_layers': 24,
78
+ 'num_attention_heads': 32,
79
+ 'max_sequence_length': 2048,
80
+ 'initializer_range': 0.02,
81
+ 'rms_norm_eps': 1e-6,
82
+ 'use_cache': True,
83
+ 'tie_word_embeddings': False,
84
+ },
85
+ '3b': {
86
+ 'vocab_size': 64256,
87
+ 'hidden_size': 3200,
88
+ 'intermediate_size': 8640,
89
+ 'num_hidden_layers': 26,
90
+ 'num_attention_heads': 32,
91
+ 'max_sequence_length': 2048,
92
+ 'initializer_range': 0.02,
93
+ 'rms_norm_eps': 1e-6,
94
+ 'use_cache': True,
95
+ 'tie_word_embeddings': False,
96
+ },
97
+ '7b': {
98
+ 'vocab_size': 64256,
99
+ 'hidden_size': 4096,
100
+ 'intermediate_size': 11008,
101
+ 'num_hidden_layers': 32,
102
+ 'num_attention_heads': 32,
103
+ 'max_sequence_length': 2048,
104
+ 'initializer_range': 0.02,
105
+ 'rms_norm_eps': 1e-6,
106
+ 'use_cache': True,
107
+ 'tie_word_embeddings': False,
108
+ },
109
+ '13b': {
110
+ 'vocab_size': 64256,
111
+ 'hidden_size': 5120,
112
+ 'intermediate_size': 13824,
113
+ 'num_hidden_layers': 40,
114
+ 'num_attention_heads': 40,
115
+ 'max_sequence_length': 2048,
116
+ 'initializer_range': 0.02,
117
+ 'rms_norm_eps': 1e-6,
118
+ 'use_cache': True,
119
+ 'tie_word_embeddings': False,
120
+ },
121
+ '30b': {
122
+ 'vocab_size': 64256,
123
+ 'hidden_size': 6656,
124
+ 'intermediate_size': 17920,
125
+ 'num_hidden_layers': 60,
126
+ 'num_attention_heads': 52,
127
+ 'max_sequence_length': 2048,
128
+ 'initializer_range': 0.02,
129
+ 'rms_norm_eps': 1e-6,
130
+ 'use_cache': True,
131
+ 'tie_word_embeddings': False,
132
+ },
133
+ '65b': {
134
+ 'vocab_size': 64256,
135
+ 'hidden_size': 8192,
136
+ 'intermediate_size': 22016,
137
+ 'num_hidden_layers': 80,
138
+ 'num_attention_heads': 64,
139
+ 'max_sequence_length': 2048,
140
+ 'initializer_range': 0.02,
141
+ 'rms_norm_eps': 1e-5,
142
+ 'use_cache': True,
143
+ 'tie_word_embeddings': False,
144
+ },
145
+ 'debug': { # A small model for debugging
146
+ 'vocab_size': 64256,
147
+ 'hidden_size': 128,
148
+ 'intermediate_size': 256,
149
+ 'num_hidden_layers': 2,
150
+ 'num_attention_heads': 4,
151
+ 'max_sequence_length': 2048,
152
+ 'initializer_range': 0.02,
153
+ 'rms_norm_eps': 1e-6,
154
+ 'use_cache': True,
155
+ 'tie_word_embeddings': False,
156
+ },
157
+ }
158
+
159
+
160
+ class LLaMAConfig(PretrainedConfig):
161
+ r"""
162
+ This is the configuration class to store the configuration of a [`~LLaMAModel`]. It is used to instantiate an LLaMA
163
+ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
164
+ defaults will yield a similar configuration to that of the LLaMA-7B.
165
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
166
+ documentation from [`PretrainedConfig`] for more information.
167
+ Args:
168
+ vocab_size (`int`, *optional*, defaults to 32000):
169
+ Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the
170
+ `inputs_ids` passed when calling [`~LLaMAModel`] or [`~TFLLaMAModel`].
171
+ hidden_size (`int`, *optional*, defaults to 4096):
172
+ Dimension of the hidden representations.
173
+ intermediate_size (`int`, *optional*, defaults to 11008):
174
+ Dimension of the MLP representations.
175
+ num_hidden_layers (`int`, *optional*, defaults to 32):
176
+ Number of hidden layers in the Transformer encoder.
177
+ num_attention_heads (`int`, *optional*, defaults to 32):
178
+ Number of attention heads for each attention layer in the Transformer encoder.
179
+ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
180
+ The non-linear activation function (function or string) in the decoder.
181
+ max_sequence_length (`int`, *optional*, defaults to 2048):
182
+ Max sequence length for model (for RoPE computation)
183
+ initializer_range (`float`, *optional*, defaults to 0.02):
184
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
185
+ rms_norm_eps (`float`, *optional*, defaults to 1e-12):
186
+ The epsilon used by the rms normalization layers.
187
+ use_cache (`bool`, *optional*, defaults to `True`):
188
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
189
+ relevant if `config.is_decoder=True`.
190
+ tie_word_embeddings(`bool`, *optional*, defaults to `False`):
191
+ Whether to tie weight embeddings
192
+ Example:
193
+ ```python
194
+ >>> from transformers import LLaMAModel, LLaMAConfig
195
+ >>> # Initializing a LLaMA llama-7b style configuration
196
+ >>> configuration = LLaMAConfig()
197
+ >>> # Initializing a model from the llama-7b style configuration
198
+ >>> model = LLaMAModel(configuration)
199
+ >>> # Accessing the model configuration
200
+ >>> configuration = model.config
201
+ ```"""
202
+ model_type = "llama"
203
+
204
+ def __init__(
205
+ self,
206
+ vocab_size=32000,
207
+ hidden_size=4096,
208
+ intermediate_size=11008,
209
+ num_hidden_layers=32,
210
+ num_attention_heads=32,
211
+ max_sequence_length=2048,
212
+ rms_norm_eps=1e-6,
213
+ initializer_range=0.02,
214
+ use_cache=True,
215
+ # pad_token_id=-1,
216
+ bos_token_id=0,
217
+ eos_token_id=1,
218
+ resid_pdrop=0.0,
219
+ embd_pdrop=0.0,
220
+ attn_pdrop=0.0,
221
+ tie_word_embeddings=False,
222
+ gradient_checkpointing='nothing_saveable',
223
+ fcm_min_ratio=0.0,
224
+ fcm_max_ratio=0.0,
225
+ **kwargs,
226
+ ):
227
+ self.vocab_size = vocab_size
228
+ self.hidden_size = hidden_size
229
+ self.initializer_range = initializer_range
230
+ self.intermediate_size = intermediate_size
231
+ self.num_hidden_layers = num_hidden_layers
232
+ self.num_attention_heads = num_attention_heads
233
+ self.max_sequence_length = max_sequence_length
234
+ self.rms_norm_eps = rms_norm_eps
235
+ self.use_cache = use_cache
236
+ self.resid_pdrop = resid_pdrop
237
+ self.embd_pdrop = embd_pdrop
238
+ self.attn_pdrop = attn_pdrop
239
+ self.gradient_checkpointing = gradient_checkpointing
240
+ self.fcm_min_ratio = fcm_min_ratio
241
+ self.fcm_max_ratio = fcm_max_ratio
242
+ super().__init__(
243
+ # pad_token_id=pad_token_id,
244
+ bos_token_id=bos_token_id,
245
+ eos_token_id=eos_token_id,
246
+ tie_word_embeddings=tie_word_embeddings,
247
+ **kwargs,
248
+ )
249
+
250
+ @classmethod
251
+ def get_default_config(cls, updates=None):
252
+ config = function_args_to_config(cls.__init__)
253
+
254
+ if updates is not None:
255
+ config.update(ConfigDict(updates).copy_and_resolve_references())
256
+
257
+ return config
258
+
259
+ @staticmethod
260
+ def get_jax_mesh(axis_dims):
261
+ return get_jax_mesh(axis_dims, ('dp', 'fsdp', 'mp'))
262
+
263
+ @staticmethod
264
+ def get_partition_rules():
265
+ """ Parition rules for GPTJ. Note that these rules are orderd, so that
266
+ the beginning rules match first. It is important to use
267
+ PartitionSpec() instead of None here because JAX does not treat
268
+ None as a pytree leaf.
269
+ """
270
+ return (
271
+ # embeddings
272
+ ("transformer/wte/embedding", PS("mp", "fsdp")),
273
+ # atention
274
+ ("attention/(wq|wk|wv)/kernel", PS("fsdp", "mp")),
275
+ ("attention/wo/kernel", PS("mp", "fsdp")),
276
+ # mlp
277
+ ("feed_forward/w1/kernel", PS("fsdp", "mp")),
278
+ ("feed_forward/w2/kernel", PS("mp", "fsdp")),
279
+ ("feed_forward/w3/kernel", PS("fsdp", "mp")),
280
+ # layer norms
281
+ ("attention_norm/kernel", PS(None)),
282
+ ("ffn_norm/kernel", PS(None)),
283
+ # output head
284
+ ("transformer/ln_f/kernel", PS(None)),
285
+ ("lm_head/kernel", PS("fsdp", "mp")),
286
+ ('.*', PS(None)),
287
+ )
288
+
289
+ @staticmethod
290
+ def get_weight_decay_exclusions():
291
+ return (
292
+ "attention_norm/kernel",
293
+ "ffn_norm/kernel",
294
+ "transformer/ln_f/kernel",
295
+ )
296
+
297
+ @staticmethod
298
+ def rng_keys():
299
+ return ('params', 'dropout', 'fcm')
300
+
301
+ @staticmethod
302
+ def get_tokenizer_config(updates=None):
303
+ config = ConfigDict()
304
+ config.vocab_file = ''
305
+ config.add_bos_token = False
306
+ config.add_eos_token = False
307
+
308
+ if updates is not None:
309
+ config.update(ConfigDict(updates).copy_and_resolve_references())
310
+ return config
311
+
312
+ @classmethod
313
+ def get_tokenizer(cls, config, padding_side='left', truncation_side='right'):
314
+ config = cls.get_tokenizer_config(config)
315
+ assert config.vocab_file != '', 'vocab_file must be specified'
316
+ tokenizer = LLaMATokenizer(
317
+ vocab_file=config.vocab_file,
318
+ add_bos_token=config.add_bos_token,
319
+ add_eos_token=config.add_eos_token,
320
+ padding_side=padding_side,
321
+ truncation_side=truncation_side,
322
+ )
323
+ return tokenizer
324
+
325
+ @classmethod
326
+ def load_config(cls, path):
327
+ if path in LLAMA_STANDARD_CONFIGS:
328
+ return cls.from_dict(LLAMA_STANDARD_CONFIGS[path])
329
+ load_type, load_path = path.split('::', 1)
330
+ if load_type == 'pickle':
331
+ return cls.from_dict(load_pickle(load_path)['llama_config'])
332
+ elif load_type == 'json':
333
+ with open_file(load_path, 'r') as fin:
334
+ raw_config = fin.read()
335
+ return cls.from_dict(json.loads(raw_config))
336
+ else:
337
+ raise ValueError(f'Unsupported load config type: {load_type}')
338
+
339
+
340
+ remat = nn_partitioning.remat
341
+
342
+ logger = logging.get_logger(__name__)
343
+
344
+
345
+ class RMSNorm(nn.Module):
346
+ dim: int
347
+ eps: float=1e-6
348
+ dtype: jnp.dtype=jnp.float32
349
+ param_dtype: jnp.dtype=jnp.float32
350
+
351
+ def setup(self) -> None:
352
+ self.weight = self.param(
353
+ 'kernel',
354
+ nn.initializers.ones,
355
+ (self.dim,),
356
+ self.param_dtype,
357
+ )
358
+
359
+ def _norm(self, x: jnp.ndarray) -> jnp.ndarray:
360
+ return x * jax.lax.rsqrt(jnp.square(x).mean(-1, keepdims=True) + self.eps)
361
+
362
+ def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
363
+ x = x.astype(jnp.promote_types(self.dtype, jnp.float32))
364
+ output = self._norm(x).astype(self.dtype)
365
+ weight = jnp.asarray(self.weight, self.dtype)
366
+ return output * weight
367
+
368
+ def precompute_freqs_cis(dim: int, end: int, theta: float=10000.0, dtype: jnp.dtype=jnp.float32) -> jnp.ndarray:
369
+ freqs = 1.0 / (theta ** (np.arange(0, dim, 2)[: (dim // 2)].astype(dtype) / dim))
370
+ t = np.arange(end) # type: ignore
371
+ freqs = np.outer(t, freqs).astype(dtype) # type: ignore
372
+ sin, cos = np.sin(freqs), np.cos(freqs)
373
+ freqs_cis = np.complex64(cos + 1j * sin)
374
+ return jnp.asarray(freqs_cis)
375
+
376
+ def apply_rotary_emb(
377
+ xq: jnp.ndarray,
378
+ xk: jnp.ndarray,
379
+ freqs_cis: jnp.ndarray,
380
+ dtype: jnp.dtype=jnp.float32,
381
+ ) -> Tuple[jnp.ndarray, jnp.ndarray]:
382
+
383
+ reshape_xq = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 2)
384
+ reshape_xk = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 2)
385
+
386
+ xq_ = jax.lax.complex(reshape_xq[..., 0], reshape_xq[..., 1])
387
+ xk_ = jax.lax.complex(reshape_xk[..., 0], reshape_xk[..., 1])
388
+
389
+ # add head dim
390
+ freqs_cis = jnp.reshape(freqs_cis, (*freqs_cis.shape[:2], 1, *freqs_cis.shape[2:]))
391
+
392
+ xq_out = xq_ * freqs_cis
393
+ xq_out = jnp.stack((jnp.real(xq_out), jnp.imag(xq_out)), axis=-1).reshape(*xq_out.shape[:-1], -1)
394
+
395
+ xk_out = xk_ * freqs_cis
396
+ xk_out = jnp.stack((jnp.real(xk_out), jnp.imag(xk_out)), axis=-1).reshape(*xk_out.shape[:-1], -1)
397
+
398
+ return xq_out.astype(dtype), xk_out.astype(dtype)
399
+
400
+
401
+ class FlaxLLaMAAttention(nn.Module):
402
+ config: LLaMAConfig
403
+ dtype: jnp.dtype=jnp.float32
404
+ param_dtype: jnp.dtype=jnp.float32
405
+ precision: Optional[Union[jax.lax.Precision, str]]=None
406
+
407
+ def setup(self):
408
+ config = self.config
409
+ self.embed_dim = config.hidden_size
410
+ self.num_heads = config.num_attention_heads
411
+ self.head_dim = self.embed_dim // self.num_heads
412
+
413
+ self.wq = nn.Dense(
414
+ config.num_attention_heads*self.head_dim,
415
+ dtype=self.dtype,
416
+ param_dtype=self.param_dtype,
417
+ use_bias=False,
418
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
419
+ precision=self.precision,
420
+ )
421
+ self.wk = nn.Dense(
422
+ config.num_attention_heads*self.head_dim,
423
+ dtype=self.dtype,
424
+ param_dtype=self.param_dtype,
425
+ use_bias=False,
426
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
427
+ precision=self.precision,
428
+ )
429
+ self.wv = nn.Dense(
430
+ config.num_attention_heads*self.head_dim,
431
+ dtype=self.dtype,
432
+ param_dtype=self.param_dtype,
433
+ use_bias=False,
434
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
435
+ precision=self.precision,
436
+ )
437
+ self.wo = nn.Dense(
438
+ config.hidden_size,
439
+ dtype=self.dtype,
440
+ param_dtype=self.param_dtype,
441
+ use_bias=False,
442
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
443
+ precision=self.precision,
444
+ )
445
+
446
+ self.resid_dropout = nn.Dropout(rate=config.resid_pdrop)
447
+
448
+ self.causal_mask = make_causal_mask(jnp.ones((1, config.max_sequence_length), dtype="bool"), dtype="bool")
449
+
450
+ self.freqs_cis = precompute_freqs_cis(
451
+ self.head_dim,
452
+ config.max_sequence_length * 2,
453
+ dtype=self.dtype,
454
+ )
455
+
456
+ def _split_heads(self, hidden_states):
457
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
458
+
459
+ def _merge_heads(self, hidden_states):
460
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
461
+
462
+ @nn.compact
463
+ def _concatenate_to_cache(self, key, value, query, attention_mask):
464
+ """
465
+ This function takes projected key, value states from a single input token and concatenates the states to cached
466
+ states from previous steps. This function is slighly adapted from the official Flax repository:
467
+ https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
468
+ """
469
+ # detect if we're initializing by absence of existing cache data.
470
+ is_initialized = self.has_variable("cache", "cached_key")
471
+ cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
472
+ cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
473
+ cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
474
+
475
+ if is_initialized:
476
+ *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
477
+ # update key, value caches with our new 1d spatial slices
478
+ cur_index = cache_index.value
479
+ indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
480
+ key = lax.dynamic_update_slice(cached_key.value, key, indices)
481
+ value = lax.dynamic_update_slice(cached_value.value, value, indices)
482
+ cached_key.value = key
483
+ cached_value.value = value
484
+ num_updated_cache_vectors = query.shape[1]
485
+ cache_index.value = cache_index.value + num_updated_cache_vectors
486
+ # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
487
+ pad_mask = jnp.broadcast_to(
488
+ jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
489
+ tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
490
+ )
491
+ attention_mask = combine_masks(pad_mask, attention_mask)
492
+ return key, value, attention_mask
493
+
494
+ def __call__(
495
+ self,
496
+ hidden_states,
497
+ attention_mask,
498
+ position_ids,
499
+ deterministic: bool = True,
500
+ init_cache: bool = False,
501
+ output_attentions: bool = False,
502
+ fcm_mask=None,
503
+ ):
504
+ xq, xk, xv = self.wq(hidden_states), self.wk(hidden_states), self.wv(hidden_states)
505
+
506
+ xq = with_sharding_constraint(xq, PS(("dp", "fsdp"), None, "mp"))
507
+ xk = with_sharding_constraint(xk, PS(("dp", "fsdp"), None, "mp"))
508
+ xv = with_sharding_constraint(xv, PS(("dp", "fsdp"), None, "mp"))
509
+
510
+ xq = self._split_heads(xq)
511
+ xk = self._split_heads(xk)
512
+ xv = self._split_heads(xv)
513
+
514
+ freqs_cis = jnp.take(self.freqs_cis, position_ids, axis=0)
515
+
516
+ xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis, dtype=self.dtype)
517
+
518
+ query_length, key_length = xq.shape[1], xk.shape[1]
519
+
520
+ if self.has_variable("cache", "cached_key"):
521
+ mask_shift = self.variables["cache"]["cache_index"]
522
+ max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
523
+ causal_mask = lax.dynamic_slice(
524
+ self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
525
+ )
526
+ else:
527
+ causal_mask = self.causal_mask[:, :, :query_length, :key_length]
528
+
529
+ batch_size = hidden_states.shape[0]
530
+ causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
531
+
532
+ attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
533
+ attention_mask = combine_masks(attention_mask, causal_mask, fcm_mask)
534
+
535
+ dropout_rng = None
536
+ if not deterministic and self.config.attn_pdrop > 0.0:
537
+ dropout_rng = self.make_rng("dropout")
538
+
539
+ # During fast autoregressive decoding, we feed one position at a time,
540
+ # and cache the keys and values step by step.
541
+ if self.has_variable("cache", "cached_key") or init_cache:
542
+ xk, xv, attention_mask = self._concatenate_to_cache(xk, xv, xq, attention_mask)
543
+
544
+ # transform boolean mask into float mask
545
+ attention_bias = lax.select(
546
+ attention_mask > 0,
547
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
548
+ jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
549
+ )
550
+
551
+ # usual dot product attention
552
+ attn_weights = dot_product_attention_weights(
553
+ xq,
554
+ xk,
555
+ bias=attention_bias,
556
+ dropout_rng=dropout_rng,
557
+ dropout_rate=self.config.attn_pdrop,
558
+ deterministic=deterministic,
559
+ dtype=jnp.promote_types(self.dtype, jnp.float32),
560
+ precision=self.precision,
561
+ )
562
+ attn_weights = with_sharding_constraint(attn_weights, PS(("dp", "fsdp"), "mp", None, None))
563
+
564
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, xv, precision=self.precision)
565
+ attn_output = self._merge_heads(attn_output)
566
+ attn_output = self.wo(attn_output)
567
+ attn_output = self.resid_dropout(attn_output, deterministic=deterministic)
568
+ outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
569
+ return outputs
570
+
571
+
572
+ class FlaxLLaMAMLP(nn.Module):
573
+ config: LLaMAConfig
574
+ dtype: jnp.dtype=jnp.float32
575
+ param_dtype: jnp.dtype=jnp.float32
576
+ precision: Optional[Union[jax.lax.Precision, str]]=None
577
+
578
+ def setup(self) -> None:
579
+ config = self.config
580
+
581
+ self.w1 = nn.Dense(
582
+ config.intermediate_size,
583
+ dtype=self.dtype,
584
+ param_dtype=self.param_dtype,
585
+ use_bias=False,
586
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
587
+ precision=self.precision,
588
+ )
589
+ self.w2 = nn.Dense(
590
+ config.hidden_size,
591
+ dtype=self.dtype,
592
+ param_dtype=self.param_dtype,
593
+ use_bias=False,
594
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
595
+ precision=self.precision,
596
+ )
597
+ self.w3 = nn.Dense(
598
+ config.intermediate_size,
599
+ dtype=self.dtype,
600
+ param_dtype=self.param_dtype,
601
+ use_bias=False,
602
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
603
+ precision=self.precision,
604
+ )
605
+ self.dropout = nn.Dropout(rate=self.config.resid_pdrop)
606
+
607
+ def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
608
+ x = self.w2(nn.silu(self.w1(x)) * self.w3(x))
609
+ x = self.dropout(x, deterministic=deterministic)
610
+ return x
611
+
612
+
613
+ class FlaxLLaMABlock(nn.Module):
614
+ config: LLaMAConfig
615
+ dtype: jnp.dtype=jnp.float32
616
+ param_dtype: jnp.dtype=jnp.float32
617
+ precision: Optional[Union[jax.lax.Precision, str]]=None
618
+
619
+ def setup(self) -> None:
620
+ self.attention = FlaxLLaMAAttention(
621
+ self.config,
622
+ dtype=self.dtype,
623
+ param_dtype=self.param_dtype,
624
+ precision=self.precision,
625
+ )
626
+ self.feed_forward = FlaxLLaMAMLP(
627
+ self.config,
628
+ dtype=self.dtype,
629
+ param_dtype=self.param_dtype,
630
+ precision=self.precision,
631
+ )
632
+ self.attention_norm = RMSNorm(
633
+ self.config.hidden_size,
634
+ eps=self.config.rms_norm_eps,
635
+ dtype=self.dtype,
636
+ param_dtype=self.param_dtype,
637
+ )
638
+ self.ffn_norm = RMSNorm(
639
+ self.config.hidden_size,
640
+ eps=self.config.rms_norm_eps,
641
+ dtype=self.dtype,
642
+ param_dtype=self.param_dtype,
643
+ )
644
+
645
+ def __call__(
646
+ self,
647
+ hidden_states,
648
+ attention_mask=None,
649
+ position_ids=None,
650
+ deterministic: bool = True,
651
+ init_cache: bool = False,
652
+ output_attentions: bool = False,
653
+ fcm_mask: Optional[jnp.ndarray] = None,
654
+ ):
655
+ attn_outputs = self.attention(
656
+ self.attention_norm(hidden_states),
657
+ attention_mask=attention_mask,
658
+ position_ids=position_ids,
659
+ deterministic=deterministic,
660
+ init_cache=init_cache,
661
+ output_attentions=output_attentions,
662
+ fcm_mask=fcm_mask,
663
+ )
664
+ attn_output = attn_outputs[0]
665
+ hidden_states = hidden_states + attn_output
666
+
667
+ feed_forward_hidden_states = self.feed_forward(
668
+ self.ffn_norm(hidden_states),
669
+ deterministic=deterministic,
670
+ )
671
+ hidden_states = hidden_states + feed_forward_hidden_states
672
+
673
+ return (hidden_states,) + attn_outputs[1:]
674
+
675
+
676
+ class FlaxLLaMAPreTrainedModel(FlaxPreTrainedModel):
677
+ """
678
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
679
+ models.
680
+ """
681
+
682
+ config_class = LLaMAConfig
683
+ base_model_prefix = "transformer"
684
+ module_class: nn.Module = None
685
+
686
+ def __init__(
687
+ self,
688
+ config: LLaMAConfig,
689
+ input_shape: Tuple = (1, 1),
690
+ seed: int = 0,
691
+ dtype: jnp.dtype = jnp.float32,
692
+ _do_init: bool = True,
693
+ **kwargs,
694
+ ):
695
+ module = self.module_class(config=config, dtype=dtype, **kwargs)
696
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
697
+
698
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
699
+ # init input tensors
700
+ input_ids = jnp.zeros(input_shape, dtype="i4")
701
+ attention_mask = jnp.ones_like(input_ids)
702
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)
703
+ params_rng, dropout_rng = jax.random.split(rng)
704
+ rngs = {"params": params_rng, "dropout": dropout_rng}
705
+
706
+ if self.config.add_cross_attention:
707
+ encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,))
708
+ encoder_attention_mask = attention_mask
709
+ module_init_outputs = self.module.init(
710
+ rngs,
711
+ input_ids,
712
+ attention_mask,
713
+ position_ids,
714
+ encoder_hidden_states,
715
+ encoder_attention_mask,
716
+ return_dict=False,
717
+ )
718
+ else:
719
+ module_init_outputs = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)
720
+
721
+ random_params = module_init_outputs["params"]
722
+
723
+ if params is not None:
724
+ random_params = flatten_dict(unfreeze(random_params))
725
+ params = flatten_dict(unfreeze(params))
726
+ for missing_key in self._missing_keys:
727
+ params[missing_key] = random_params[missing_key]
728
+ self._missing_keys = set()
729
+ return freeze(unflatten_dict(params))
730
+ else:
731
+ return random_params
732
+
733
+ def init_cache(self, batch_size, max_length):
734
+ r"""
735
+ Args:
736
+ batch_size (`int`):
737
+ batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
738
+ max_length (`int`):
739
+ maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
740
+ cache.
741
+ """
742
+ # init input variables to retrieve cache
743
+ input_ids = jnp.ones((batch_size, max_length))
744
+ attention_mask = jnp.ones_like(input_ids)
745
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
746
+
747
+ init_variables = self.module.init(
748
+ jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
749
+ )
750
+ return init_variables["cache"]
751
+
752
+ @add_start_docstrings_to_model_forward("")
753
+ def __call__(
754
+ self,
755
+ input_ids,
756
+ attention_mask=None,
757
+ position_ids=None,
758
+ params: dict = None,
759
+ past_key_values: dict = None,
760
+ dropout_rng: jax.random.PRNGKey = None,
761
+ train: bool = False,
762
+ output_attentions: Optional[bool] = None,
763
+ output_hidden_states: Optional[bool] = None,
764
+ return_dict: Optional[bool] = None,
765
+ ):
766
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
767
+ output_hidden_states = (
768
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
769
+ )
770
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
771
+
772
+ batch_size, sequence_length = input_ids.shape
773
+
774
+ if position_ids is None:
775
+ if past_key_values is not None:
776
+ raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.")
777
+
778
+ position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
779
+
780
+ if attention_mask is None:
781
+ attention_mask = jnp.ones((batch_size, sequence_length))
782
+
783
+ # Handle any PRNG if needed
784
+ rngs = {}
785
+ if dropout_rng is not None:
786
+ rngs["dropout"] = dropout_rng
787
+
788
+ inputs = {"params": params or self.params}
789
+
790
+ # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be changed by FlaxGPTJAttention module
791
+ if past_key_values:
792
+ inputs["cache"] = past_key_values
793
+ mutable = ["cache"]
794
+ else:
795
+ mutable = False
796
+
797
+ outputs = self.module.apply(
798
+ inputs,
799
+ jnp.array(input_ids, dtype="i4"),
800
+ jnp.array(attention_mask, dtype="i4"),
801
+ jnp.array(position_ids, dtype="i4"),
802
+ not train,
803
+ False,
804
+ output_attentions,
805
+ output_hidden_states,
806
+ return_dict,
807
+ rngs=rngs,
808
+ mutable=mutable,
809
+ )
810
+
811
+ # add updated cache to model output
812
+ if past_key_values is not None and return_dict:
813
+ outputs, past_key_values = outputs
814
+ outputs["past_key_values"] = unfreeze(past_key_values["cache"])
815
+ return outputs
816
+ elif past_key_values is not None and not return_dict:
817
+ outputs, past_key_values = outputs
818
+ outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
819
+
820
+ return outputs
821
+
822
+
823
+ class FlaxLLaMABlockCollection(nn.Module):
824
+ config: LLaMAConfig
825
+ dtype: jnp.dtype = jnp.float32
826
+ param_dtype: jnp.dtype=jnp.float32
827
+ precision: Optional[Union[jax.lax.Precision, str]]=None
828
+
829
+ def setup(self):
830
+ block = FlaxLLaMABlock
831
+ if self.config.gradient_checkpointing != '':
832
+ FlaxLLaMACheckpointBlock = remat(
833
+ block, static_argnums=(3, 4, 5),
834
+ policy=get_gradient_checkpoint_policy(self.config.gradient_checkpointing)
835
+ )
836
+ block = FlaxLLaMACheckpointBlock
837
+ self.blocks = [
838
+ block(self.config, name=str(i), dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision) for i in range(self.config.num_hidden_layers)
839
+ ]
840
+
841
+ def __call__(
842
+ self,
843
+ hidden_states,
844
+ attention_mask=None,
845
+ position_ids=None,
846
+ deterministic: bool = True,
847
+ init_cache: bool = False,
848
+ output_attentions: bool = False,
849
+ output_hidden_states: bool = False,
850
+ return_dict: bool = True,
851
+ ):
852
+ all_attentions = () if output_attentions else None
853
+ all_hidden_states = () if output_hidden_states else None
854
+
855
+ if not deterministic and self.config.fcm_max_ratio > 0:
856
+ # Apply forgetful causal mask
857
+ batch_size, seq_length = hidden_states.shape[0], hidden_states.shape[1]
858
+ fcm_ratio = jax.random.uniform(
859
+ self.make_rng('fcm'), shape=(batch_size, 1, 1, 1),
860
+ minval=self.config.fcm_min_ratio,
861
+ maxval=self.config.fcm_max_ratio
862
+ )
863
+ fcm_mask = jax.random.uniform(
864
+ self.make_rng('fcm'),
865
+ shape=(batch_size, 1, seq_length, seq_length)
866
+ ) > fcm_ratio
867
+ fcm_mask = fcm_mask.at[:, :, :, 0].set(True)
868
+ fcm_mask = fcm_mask.astype('bool')
869
+ else:
870
+ fcm_mask = None
871
+
872
+ for block in self.blocks:
873
+ if output_hidden_states:
874
+ all_hidden_states += (hidden_states,)
875
+
876
+ layer_outputs = block(
877
+ hidden_states,
878
+ attention_mask,
879
+ position_ids,
880
+ deterministic,
881
+ init_cache,
882
+ output_attentions,
883
+ fcm_mask,
884
+ )
885
+ hidden_states = layer_outputs[0]
886
+
887
+ if output_attentions:
888
+ all_attentions += (layer_outputs[1],)
889
+
890
+ # this contains possible `None` values - `FlaxGPTJModule` will filter them out
891
+ outputs = (hidden_states, all_hidden_states, all_attentions)
892
+
893
+ return outputs
894
+
895
+
896
+ class FlaxLLaMAModule(nn.Module):
897
+ config: LLaMAConfig
898
+ dtype: jnp.dtype = jnp.float32
899
+ param_dtype: jnp.dtype=jnp.float32
900
+ precision: Optional[Union[jax.lax.Precision, str]]=None
901
+
902
+ def setup(self):
903
+ self.embed_dim = self.config.hidden_size
904
+
905
+ self.wte = nn.Embed(
906
+ self.config.vocab_size,
907
+ self.config.hidden_size,
908
+ embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
909
+ dtype=self.dtype,
910
+ param_dtype=self.param_dtype,
911
+ )
912
+ self.dropout = nn.Dropout(rate=self.config.embd_pdrop)
913
+ self.h = FlaxLLaMABlockCollection(self.config, dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision)
914
+ self.ln_f = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps, dtype=self.dtype, param_dtype=self.param_dtype)
915
+
916
+ def __call__(
917
+ self,
918
+ input_ids,
919
+ attention_mask,
920
+ position_ids,
921
+ deterministic=True,
922
+ init_cache: bool = False,
923
+ output_attentions: bool = False,
924
+ output_hidden_states: bool = False,
925
+ return_dict: bool = True,
926
+ ):
927
+ input_embeds = self.wte(input_ids.astype("i4"))
928
+
929
+ hidden_states = self.dropout(input_embeds, deterministic=deterministic)
930
+
931
+ outputs = self.h(
932
+ hidden_states,
933
+ attention_mask,
934
+ position_ids=position_ids,
935
+ deterministic=deterministic,
936
+ init_cache=init_cache,
937
+ output_attentions=output_attentions,
938
+ output_hidden_states=output_hidden_states,
939
+ return_dict=return_dict,
940
+ )
941
+
942
+ hidden_states = outputs[0]
943
+ hidden_states = self.ln_f(hidden_states)
944
+
945
+ if output_hidden_states:
946
+ all_hidden_states = outputs[1] + (hidden_states,)
947
+ outputs = (hidden_states, all_hidden_states) + outputs[2:]
948
+ else:
949
+ outputs = (hidden_states,) + outputs[1:]
950
+
951
+ if not return_dict:
952
+ return tuple(v for v in outputs if v is not None)
953
+
954
+ return FlaxBaseModelOutput(
955
+ last_hidden_state=hidden_states,
956
+ hidden_states=outputs[1],
957
+ attentions=outputs[-1],
958
+ )
959
+
960
+ @add_start_docstrings("", "")
961
+ class FlaxLLaMAModel(FlaxLLaMAPreTrainedModel):
962
+ module_class = FlaxLLaMAModule
963
+
964
+ # append_call_sample_docstring(
965
+ # FlaxLLaMAModel,
966
+ # _TOKENIZER_FOR_DOC,
967
+ # _CHECKPOINT_FOR_DOC,
968
+ # FlaxCausalLMOutput,
969
+ # _CONFIG_FOR_DOC,
970
+ # )
971
+
972
+ class FlaxLLaMAForCausalLMModule(nn.Module):
973
+ config: LLaMAConfig
974
+ dtype: jnp.dtype = jnp.float32
975
+ param_dtype: jnp.dtype=jnp.float32
976
+ precision: Optional[Union[jax.lax.Precision, str]]=None
977
+
978
+ def setup(self):
979
+ self.transformer = FlaxLLaMAModule(self.config, dtype=self.dtype)
980
+ self.lm_head = nn.Dense(
981
+ self.config.vocab_size,
982
+ dtype=self.dtype,
983
+ param_dtype=self.param_dtype,
984
+ use_bias=False,
985
+ kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
986
+ precision=self.precision,
987
+ )
988
+
989
+ def __call__(
990
+ self,
991
+ input_ids,
992
+ attention_mask=None,
993
+ position_ids=None,
994
+ deterministic: bool = True,
995
+ init_cache: bool = False,
996
+ output_attentions: bool = False,
997
+ output_hidden_states: bool = False,
998
+ return_dict: bool = True,
999
+ ):
1000
+ batch_size, seq_length = input_ids.shape
1001
+ if attention_mask is None:
1002
+ attention_mask = jnp.ones_like(input_ids)
1003
+ if position_ids is None:
1004
+ position_ids = jnp.broadcast_to(
1005
+ jnp.clip(jnp.cumsum(attention_mask, axis=-1) - 1, a_min=0),
1006
+ (batch_size, seq_length)
1007
+ )
1008
+ outputs = self.transformer(
1009
+ input_ids,
1010
+ attention_mask,
1011
+ position_ids,
1012
+ deterministic=deterministic,
1013
+ init_cache=init_cache,
1014
+ output_attentions=output_attentions,
1015
+ output_hidden_states=output_hidden_states,
1016
+ return_dict=return_dict,
1017
+ )
1018
+
1019
+ hidden_states = outputs[0]
1020
+
1021
+ if self.config.tie_word_embeddings:
1022
+ shared_kernel = self.transformer.variables["params"]["wte"]["embedding"].T
1023
+ lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states)
1024
+ else:
1025
+ lm_logits = self.lm_head(hidden_states)
1026
+
1027
+ if not return_dict:
1028
+ return (lm_logits,) + outputs[1:]
1029
+
1030
+ return FlaxCausalLMOutput(logits=lm_logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
1031
+
1032
+
1033
+ @add_start_docstrings("", "")
1034
+ class FlaxLLaMAForCausalLM(FlaxLLaMAPreTrainedModel):
1035
+ module_class = FlaxLLaMAForCausalLMModule
1036
+
1037
+ def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
1038
+ # initializing the cache
1039
+ batch_size, seq_length = input_ids.shape
1040
+
1041
+ past_key_values = self.init_cache(batch_size, max_length)
1042
+ # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
1043
+ # But since GPTJ uses a causal mask, those positions are masked anyways.
1044
+ # Thus we can create a single static attention_mask here, which is more efficient for compilation
1045
+ extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
1046
+ if attention_mask is not None:
1047
+ position_ids = attention_mask.cumsum(axis=-1) - 1
1048
+ extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
1049
+ else:
1050
+ position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
1051
+
1052
+ return {
1053
+ "past_key_values": past_key_values,
1054
+ "attention_mask": extended_attention_mask,
1055
+ "position_ids": position_ids,
1056
+ }
1057
+
1058
+ def update_inputs_for_generation(self, model_outputs, model_kwargs):
1059
+ model_kwargs["past_key_values"] = model_outputs.past_key_values
1060
+ model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
1061
+ return model_kwargs
1062
+
1063
+ # append_call_sample_docstring(
1064
+ # FlaxGPTJForCausalLM,
1065
+ # _TOKENIZER_FOR_DOC,
1066
+ # _CHECKPOINT_FOR_DOC,
1067
+ # FlaxCausalLMOutput,
1068
+ # _CONFIG_FOR_DOC,
1069
+ # )
1070
+
1071
+
1072
+
1073
+ VOCAB_FILES_NAMES = {"vocab_file": "tokenizer.model"}
1074
+
1075
+ PRETRAINED_VOCAB_FILES_MAP = {}
1076
+
1077
+
1078
+ class LLaMATokenizer(PreTrainedTokenizer):
1079
+ """
1080
+ Construct a LLaMA tokenizer. Based on byte-level Byte-Pair-Encoding.
1081
+ Args:
1082
+ vocab_file (`str`):
1083
+ Path to the vocabulary file.
1084
+ """
1085
+
1086
+ vocab_files_names = VOCAB_FILES_NAMES
1087
+ pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
1088
+ model_input_names = ["input_ids", "attention_mask"]
1089
+
1090
+ def __init__(
1091
+ self,
1092
+ vocab_file,
1093
+ unk_token="<unk>",
1094
+ bos_token="<s>",
1095
+ eos_token="</s>",
1096
+ sp_model_kwargs: Optional[Dict[str, Any]] = None,
1097
+ add_bos_token=False,
1098
+ add_eos_token=False,
1099
+ **kwargs,
1100
+ ):
1101
+ self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
1102
+ super().__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, **kwargs)
1103
+ self.vocab_file = vocab_file
1104
+ self.add_bos_token = add_bos_token
1105
+ self.add_eos_token = add_eos_token
1106
+ self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
1107
+
1108
+ with tempfile.NamedTemporaryFile() as tfile:
1109
+ with open_file(self.vocab_file, 'rb') as fin:
1110
+ tfile.write(fin.read())
1111
+ tfile.flush()
1112
+ tfile.seek(0)
1113
+ self.sp_model.Load(tfile.name)
1114
+ """ Initialisation"""
1115
+ self.add_special_tokens(dict(
1116
+ unk_token=unk_token,
1117
+ bos_token=bos_token,
1118
+ eos_token=eos_token,
1119
+ ))
1120
+ self.pad_token_id = self.unk_token_id
1121
+
1122
+ @property
1123
+ def vocab_size(self):
1124
+ """Returns vocab size"""
1125
+ return self.sp_model.get_piece_size()
1126
+
1127
+ @property
1128
+ def bos_token_id(self) -> Optional[int]:
1129
+ return self.sp_model.bos_id()
1130
+
1131
+ @property
1132
+ def eos_token_id(self) -> Optional[int]:
1133
+ return self.sp_model.eos_id()
1134
+
1135
+ def get_vocab(self):
1136
+ """Returns vocab as a dict"""
1137
+ vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
1138
+ vocab.update(self.added_tokens_encoder)
1139
+ return vocab
1140
+
1141
+ def _tokenize(self, text):
1142
+ """Returns a tokenized string."""
1143
+ return self.sp_model.encode(text, out_type=str)
1144
+
1145
+ def _convert_token_to_id(self, token):
1146
+ """Converts a token (str) in an id using the vocab."""
1147
+ return self.sp_model.piece_to_id(token)
1148
+
1149
+ def _convert_id_to_token(self, index):
1150
+ """Converts an index (integer) in a token (str) using the vocab."""
1151
+ token = self.sp_model.IdToPiece(index)
1152
+ return token
1153
+
1154
+ def convert_tokens_to_string(self, tokens):
1155
+ """Converts a sequence of tokens (string) in a single string."""
1156
+ current_sub_tokens = []
1157
+ out_string = ""
1158
+ prev_is_special = False
1159
+ for token in tokens:
1160
+ # make sure that special tokens are not decoded using sentencepiece model
1161
+ if token in self.all_special_tokens:
1162
+ if not prev_is_special:
1163
+ out_string += " "
1164
+ out_string += self.sp_model.decode(current_sub_tokens) + token
1165
+ prev_is_special = True
1166
+ current_sub_tokens = []
1167
+ else:
1168
+ current_sub_tokens.append(token)
1169
+ prev_is_special = False
1170
+ out_string += self.sp_model.decode(current_sub_tokens)
1171
+ return out_string.strip()
1172
+
1173
+ def save_vocabulary(self, save_directory, filename_prefix: Optional[str] = None) -> Tuple[str]:
1174
+ """
1175
+ Save the vocabulary and special tokens file to a directory.
1176
+ Args:
1177
+ save_directory (`str`):
1178
+ The directory in which to save the vocabulary.
1179
+ Returns:
1180
+ `Tuple(str)`: Paths to the files saved.
1181
+ """
1182
+ if not os.path.isdir(save_directory):
1183
+ logger.error(f"Vocabulary path ({save_directory}) should be a directory")
1184
+ return
1185
+ out_vocab_file = os.path.join(
1186
+ save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
1187
+ )
1188
+
1189
+ if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
1190
+ copyfile(self.vocab_file, out_vocab_file)
1191
+ elif not os.path.isfile(self.vocab_file):
1192
+ with open(out_vocab_file, "wb") as fi:
1193
+ content_spiece_model = self.sp_model.serialized_model_proto()
1194
+ fi.write(content_spiece_model)
1195
+
1196
+ return (out_vocab_file,)
1197
+
1198
+ def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
1199
+ if self.add_bos_token:
1200
+ bos_token_ids = [self.bos_token_id]
1201
+ else:
1202
+ bos_token_ids = []
1203
+
1204
+ output = bos_token_ids + token_ids_0
1205
+
1206
+ if token_ids_1 is not None:
1207
+ output = output + token_ids_1
1208
+
1209
+ if self.add_eos_token:
1210
+ output = output + [self.eos_token_id]
1211
+
1212
+ return output
1213
+
1214
+ def get_special_tokens_mask(
1215
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
1216
+ ) -> List[int]:
1217
+ """
1218
+ Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
1219
+ special tokens using the tokenizer `prepare_for_model` method.
1220
+ Args:
1221
+ token_ids_0 (`List[int]`):
1222
+ List of IDs.
1223
+ token_ids_1 (`List[int]`, *optional*):
1224
+ Optional second list of IDs for sequence pairs.
1225
+ already_has_special_tokens (`bool`, *optional*, defaults to `False`):
1226
+ Whether or not the token list is already formatted with special tokens for the model.
1227
+ Returns:
1228
+ `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
1229
+ """
1230
+ if already_has_special_tokens:
1231
+ return super().get_special_tokens_mask(
1232
+ token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
1233
+ )
1234
+
1235
+ if token_ids_1 is None:
1236
+ return [1] + ([0] * len(token_ids_0)) + [1]
1237
+ return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
1238
+
1239
+ def create_token_type_ids_from_sequences(
1240
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
1241
+ ) -> List[int]:
1242
+ """
1243
+ Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
1244
+ use of token type ids, therefore a list of zeros is returned.
1245
+ Args:
1246
+ token_ids_0 (`List[int]`):
1247
+ List of IDs.
1248
+ token_ids_1 (`List[int]`, *optional*):
1249
+ Optional second list of IDs for sequence pairs.
1250
+ Returns:
1251
+ `List[int]`: List of zeros.
1252
+ """
1253
+ eos = [self.eos_token_id]
1254
+
1255
+ if token_ids_1 is None:
1256
+ return len(token_ids_0 + eos) * [0]
1257
+ return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
EasyLM/models/llama/llama_serve.py ADDED
@@ -0,0 +1,386 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pprint
2
+ from functools import partial
3
+
4
+ import numpy as np
5
+ import mlxu
6
+
7
+ import jax
8
+ import jax.numpy as jnp
9
+ from jax.experimental.pjit import pjit
10
+ from jax.sharding import PartitionSpec as PS
11
+ import optax
12
+ from transformers import GenerationConfig, FlaxLogitsProcessorList
13
+
14
+ from EasyLM.checkpoint import StreamingCheckpointer
15
+ from EasyLM.serving import LMServer
16
+ from EasyLM.jax_utils import (
17
+ JaxRNG, next_rng, match_partition_rules, tree_apply,
18
+ set_random_seed, get_float_dtype_by_name, make_shard_and_gather_fns,
19
+ with_sharding_constraint, FlaxTemperatureLogitsWarper
20
+ )
21
+ from EasyLM.models.llama.llama_model import LLaMAConfig, FlaxLLaMAForCausalLM
22
+
23
+
24
+ FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
25
+ seed=42,
26
+ initialize_jax_distributed=False,
27
+ mesh_dim='1,-1,1',
28
+ dtype='bf16',
29
+ input_length=1024,
30
+ seq_length=2048,
31
+ top_k=50,
32
+ top_p=1.0,
33
+ do_sample=True,
34
+ num_beams=1,
35
+ add_bos_token=True,
36
+ load_llama_config='',
37
+ load_checkpoint='',
38
+ tokenizer=LLaMAConfig.get_tokenizer_config(),
39
+ lm_server=LMServer.get_default_config(),
40
+ )
41
+
42
+
43
+ def main(argv):
44
+ if FLAGS.initialize_jax_distributed:
45
+ jax.distributed.initialize()
46
+ set_random_seed(FLAGS.seed)
47
+
48
+ prefix_tokenizer = LLaMAConfig.get_tokenizer(
49
+ FLAGS.tokenizer, truncation_side='left', padding_side='left'
50
+ )
51
+ tokenizer = LLaMAConfig.get_tokenizer(
52
+ FLAGS.tokenizer, truncation_side='right', padding_side='right'
53
+ )
54
+
55
+ with jax.default_device(jax.devices("cpu")[0]):
56
+ llama_config = LLaMAConfig.load_config(FLAGS.load_llama_config)
57
+ _, params = StreamingCheckpointer.load_trainstate_checkpoint(
58
+ FLAGS.load_checkpoint, disallow_trainstate=True
59
+ )
60
+
61
+ hf_model = FlaxLLaMAForCausalLM(
62
+ llama_config,
63
+ input_shape=(1, FLAGS.seq_length),
64
+ seed=FLAGS.seed,
65
+ _do_init=False
66
+ )
67
+
68
+ model_ps = match_partition_rules(
69
+ LLaMAConfig.get_partition_rules(), params
70
+ )
71
+ shard_fns, _ = make_shard_and_gather_fns(
72
+ model_ps, get_float_dtype_by_name(FLAGS.dtype)
73
+ )
74
+
75
+ @partial(
76
+ pjit,
77
+ in_shardings=(model_ps, PS(), PS()),
78
+ out_shardings=(PS(), PS(), PS())
79
+ )
80
+ def forward_loglikelihood(params, rng, batch):
81
+ batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
82
+ rng_generator = JaxRNG(rng)
83
+ input_tokens = batch['input_tokens']
84
+ output_tokens = batch['output_tokens']
85
+ input_mask = batch['input_mask']
86
+ output_mask = batch['output_mask']
87
+
88
+ logits = hf_model.module.apply(
89
+ params, input_tokens, attention_mask=input_mask,
90
+ deterministic=True, rngs=rng_generator(llama_config.rng_keys()),
91
+ ).logits
92
+ # if llama_config.n_real_tokens is not None:
93
+ # logits = logits.at[:, :, llama_config.n_real_tokens:].set(-1e8)
94
+ loglikelihood = -optax.softmax_cross_entropy_with_integer_labels(
95
+ logits, output_tokens
96
+ )
97
+ loglikelihood = jnp.sum(loglikelihood * output_mask, axis=-1)
98
+ match_count = jnp.sum(
99
+ (jnp.argmax(logits, axis=-1) == output_tokens) * output_mask,
100
+ axis=-1
101
+ )
102
+ total = jnp.sum(output_mask, axis=-1)
103
+ is_greedy = match_count == total
104
+ return loglikelihood, is_greedy, rng_generator()
105
+
106
+
107
+ @partial(
108
+ pjit,
109
+ in_shardings=(model_ps, PS(), PS(), PS()),
110
+ out_shardings=(PS(), PS())
111
+ )
112
+ def forward_generate(params, rng, batch, temperature):
113
+ batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
114
+ rng_generator = JaxRNG(rng)
115
+ output = hf_model.generate(
116
+ batch['input_tokens'],
117
+ attention_mask=batch['attention_mask'],
118
+ params=params['params'],
119
+ prng_key=rng_generator(),
120
+ logits_processor=FlaxLogitsProcessorList(
121
+ [FlaxTemperatureLogitsWarper(temperature)]
122
+ ),
123
+ generation_config=GenerationConfig(
124
+ max_new_tokens=FLAGS.seq_length - FLAGS.input_length,
125
+ pad_token_id=tokenizer.eos_token_id,
126
+ bos_token_id=tokenizer.bos_token_id,
127
+ eos_token_id=tokenizer.eos_token_id,
128
+ do_sample=FLAGS.do_sample,
129
+ num_beams=FLAGS.num_beams,
130
+ top_k=FLAGS.top_k,
131
+ top_p=FLAGS.top_p,
132
+ )
133
+ ).sequences[:, batch['input_tokens'].shape[1]:]
134
+ return output, rng_generator()
135
+
136
+ @partial(
137
+ pjit,
138
+ in_shardings=(model_ps, PS(), PS()),
139
+ out_shardings=(PS(), PS())
140
+ )
141
+ def forward_greedy_generate(params, rng, batch):
142
+ batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
143
+ rng_generator = JaxRNG(rng)
144
+ output = hf_model.generate(
145
+ batch['input_tokens'],
146
+ attention_mask=batch['attention_mask'],
147
+ params=params['params'],
148
+ prng_key=rng_generator(),
149
+ generation_config=GenerationConfig(
150
+ max_new_tokens=FLAGS.seq_length - FLAGS.input_length,
151
+ pad_token_id=tokenizer.eos_token_id,
152
+ bos_token_id=tokenizer.bos_token_id,
153
+ eos_token_id=tokenizer.eos_token_id,
154
+ do_sample=False,
155
+ num_beams=1,
156
+ )
157
+ ).sequences[:, batch['input_tokens'].shape[1]:]
158
+ return output, rng_generator()
159
+
160
+ mesh = LLaMAConfig.get_jax_mesh(FLAGS.mesh_dim)
161
+ with mesh:
162
+ params = tree_apply(shard_fns, params)
163
+ sharded_rng = next_rng()
164
+
165
+ class ModelServer(LMServer):
166
+
167
+ @staticmethod
168
+ def loglikelihood(prefix_text, text):
169
+ nonlocal sharded_rng
170
+ prefix = prefix_tokenizer(
171
+ prefix_text,
172
+ padding='max_length',
173
+ truncation=True,
174
+ max_length=FLAGS.input_length,
175
+ return_tensors='np',
176
+ )
177
+ inputs = tokenizer(
178
+ text,
179
+ padding='max_length',
180
+ truncation=True,
181
+ max_length=FLAGS.seq_length - FLAGS.input_length,
182
+ return_tensors='np',
183
+ )
184
+ output_tokens = np.concatenate([prefix.input_ids, inputs.input_ids], axis=1)
185
+ bos_tokens = np.full(
186
+ (output_tokens.shape[0], 1), tokenizer.bos_token_id, dtype=np.int32
187
+ )
188
+ input_tokens = np.concatenate([bos_tokens, output_tokens[:, :-1]], axis=-1)
189
+ input_mask = np.concatenate(
190
+ [prefix.attention_mask, inputs.attention_mask], axis=1
191
+ )
192
+ if FLAGS.add_bos_token:
193
+ bos_mask = np.ones_like(input_mask[:, :1])
194
+ else:
195
+ bos_mask = np.zeros_like(input_mask[:, :1])
196
+
197
+ input_mask = np.concatenate([bos_mask, input_mask[:, :-1]], axis=1)
198
+ output_mask = np.concatenate(
199
+ [np.zeros_like(prefix.attention_mask), inputs.attention_mask], axis=1
200
+ )
201
+ batch = dict(
202
+ input_tokens=input_tokens,
203
+ output_tokens=output_tokens,
204
+ input_mask=input_mask,
205
+ output_mask=output_mask,
206
+ )
207
+ with mesh:
208
+ loglikelihood, is_greedy, sharded_rng = forward_loglikelihood(
209
+ params, sharded_rng, batch
210
+ )
211
+ loglikelihood, is_greedy = jax.device_get((loglikelihood, is_greedy))
212
+ return loglikelihood, is_greedy
213
+
214
+ @staticmethod
215
+ def loglikelihood_rolling(text):
216
+ nonlocal sharded_rng
217
+ inputs = tokenizer(
218
+ text,
219
+ padding='longest',
220
+ truncation=False,
221
+ max_length=np.iinfo(np.int32).max,
222
+ return_tensors='np',
223
+ )
224
+ batch_size = inputs.input_ids.shape[0]
225
+ output_tokens = inputs.input_ids
226
+ attention_mask = inputs.attention_mask
227
+
228
+ if output_tokens.shape[1] < FLAGS.seq_length:
229
+ padding_length = FLAGS.seq_length - output_tokens.shape[1]
230
+ pad_tokens = np.full(
231
+ (batch_size, padding_length), tokenizer.pad_token_id, dtype=np.int32
232
+ )
233
+ output_tokens = np.concatenate([output_tokens, pad_tokens], axis=-1)
234
+ pad_mask = np.zeros(
235
+ (batch_size, padding_length), dtype=inputs.attention_mask.dtype
236
+ )
237
+ attention_mask = np.concatenate([attention_mask, pad_mask], axis=-1)
238
+
239
+ bos_tokens = np.full(
240
+ (batch_size, 1), tokenizer.bos_token_id, dtype=np.int32
241
+ )
242
+ input_tokens = np.concatenate([bos_tokens, output_tokens[:, :-1]], axis=-1)
243
+ bos_mask = np.ones((batch_size, 1), dtype=inputs.attention_mask.dtype)
244
+ total_seq_length = output_tokens.shape[1]
245
+
246
+ total_loglikelihood = 0.0
247
+ total_is_greedy = True
248
+ # Sliding window
249
+ for i in range(0, total_seq_length, FLAGS.seq_length):
250
+ # Last window
251
+ if i + FLAGS.seq_length > total_seq_length:
252
+ last_output_mask = np.copy(attention_mask[:, -FLAGS.seq_length:])
253
+ last_output_mask[:, :i - total_seq_length] = 0.0
254
+
255
+ batch = dict(
256
+ input_tokens=input_tokens[:, -FLAGS.seq_length:],
257
+ output_tokens=output_tokens[:, -FLAGS.seq_length:],
258
+ input_mask=attention_mask[:, -FLAGS.seq_length:],
259
+ output_mask=last_output_mask,
260
+ )
261
+
262
+ # Normal window
263
+ else:
264
+ batch = dict(
265
+ input_tokens=input_tokens[:, i:i + FLAGS.seq_length],
266
+ output_tokens=output_tokens[:, i:i + FLAGS.seq_length],
267
+ input_mask=attention_mask[:, i:i + FLAGS.seq_length],
268
+ output_mask=attention_mask[:, i:i + FLAGS.seq_length],
269
+ )
270
+
271
+ with mesh:
272
+ loglikelihood, is_greedy, sharded_rng = forward_loglikelihood(
273
+ params, sharded_rng, batch
274
+ )
275
+ loglikelihood, is_greedy = jax.device_get((loglikelihood, is_greedy))
276
+
277
+ total_loglikelihood += loglikelihood
278
+ total_is_greedy = np.logical_and(is_greedy, total_is_greedy)
279
+
280
+ return total_loglikelihood, total_is_greedy
281
+
282
+ @staticmethod
283
+ def generate(text, temperature):
284
+ nonlocal sharded_rng
285
+ inputs = prefix_tokenizer(
286
+ text,
287
+ padding='max_length',
288
+ truncation=True,
289
+ max_length=FLAGS.input_length,
290
+ return_tensors='np',
291
+ )
292
+ input_tokens = inputs.input_ids
293
+ input_mask = inputs.attention_mask
294
+ if FLAGS.add_bos_token:
295
+ input_tokens[:, 0] = tokenizer.bos_token_id
296
+ input_mask[:, 0] = 1
297
+ batch = dict(
298
+ input_tokens=input_tokens,
299
+ attention_mask=input_mask,
300
+ )
301
+ with mesh:
302
+ output, sharded_rng = forward_generate(
303
+ params, sharded_rng, batch, temperature
304
+ )
305
+ output = jax.device_get(output)
306
+ output_text = []
307
+ for text in list(tokenizer.batch_decode(output)):
308
+ if tokenizer.eos_token in text:
309
+ text = text.split(tokenizer.eos_token, maxsplit=1)[0]
310
+ output_text.append(text)
311
+
312
+ return output_text
313
+
314
+ @staticmethod
315
+ def greedy_until(prefix_text, until, max_length):
316
+ nonlocal sharded_rng
317
+ all_outputs = []
318
+ for pf, ut in zip(prefix_text, until):
319
+ if isinstance(ut, str):
320
+ ut = [ut]
321
+ total_length = 0
322
+ total_generated = ''
323
+
324
+ while total_length < max_length:
325
+ pf_tokens = tokenizer(
326
+ pf,
327
+ padding=False,
328
+ truncation=False,
329
+ max_length=np.iinfo(np.int32).max,
330
+ return_tensors='np',
331
+ )
332
+ input_tokens = pf_tokens.input_ids
333
+ attention_mask = pf_tokens.attention_mask
334
+
335
+ if input_tokens.shape[1] < FLAGS.input_length:
336
+ extra = FLAGS.input_length - input_tokens.shape[1]
337
+ pad_tokens = np.full(
338
+ (1, extra), tokenizer.pad_token_id, dtype=np.int32
339
+ )
340
+ input_tokens = np.concatenate(
341
+ [pad_tokens, input_tokens], axis=1
342
+ )
343
+ pad_attention = np.zeros((1, extra), dtype=attention_mask.dtype)
344
+ attention_mask = np.concatenate(
345
+ [pad_attention, attention_mask], axis=1
346
+ )
347
+ elif input_tokens.shape[1] > FLAGS.input_length:
348
+ input_tokens = input_tokens[:, -FLAGS.input_length:]
349
+ attention_mask = attention_mask[:, -FLAGS.input_length:]
350
+
351
+ if FLAGS.add_bos_token:
352
+ input_tokens[:, 0] = tokenizer.bos_token_id
353
+ attention_mask[:, 0] = 1
354
+
355
+ batch = dict(input_tokens=input_tokens, attention_mask=attention_mask)
356
+
357
+ with mesh:
358
+ output, sharded_rng = forward_greedy_generate(
359
+ params, sharded_rng, batch
360
+ )
361
+ output = jax.device_get(output)
362
+
363
+ total_length += output.shape[1]
364
+ output_text = tokenizer.batch_decode(output)[0]
365
+ total_generated = total_generated + output_text
366
+ pf = pf + output_text
367
+
368
+ done = False
369
+ for s in ut:
370
+ if s in total_generated:
371
+ total_generated = total_generated.split(s, maxsplit=1)[0]
372
+ done = True
373
+ if done:
374
+ break
375
+
376
+ all_outputs.append(total_generated)
377
+
378
+ return all_outputs
379
+
380
+
381
+ server = ModelServer(FLAGS.lm_server)
382
+ server.run()
383
+
384
+
385
+ if __name__ == "__main__":
386
+ mlxu.run(main)
EasyLM/models/llama/llama_train.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pprint
2
+ from functools import partial
3
+
4
+ from tqdm import tqdm, trange
5
+ import numpy as np
6
+ import mlxu
7
+
8
+ import jax
9
+ import jax.numpy as jnp
10
+ from jax.experimental.pjit import pjit
11
+ from jax.sharding import PartitionSpec as PS
12
+ from flax.training.train_state import TrainState
13
+
14
+ from EasyLM.data import DatasetFactory
15
+ from EasyLM.checkpoint import StreamingCheckpointer
16
+ from EasyLM.optimizers import OptimizerFactory
17
+ from EasyLM.jax_utils import (
18
+ JaxRNG, next_rng, match_partition_rules,
19
+ cross_entropy_loss_and_accuracy, global_norm, get_float_dtype_by_name,
20
+ set_random_seed, average_metrics, get_weight_decay_mask,
21
+ make_shard_and_gather_fns, with_sharding_constraint,
22
+ )
23
+ from EasyLM.models.llama.llama_model import (
24
+ LLaMAConfig, FlaxLLaMAForCausalLMModule
25
+ )
26
+
27
+
28
+ FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
29
+ seed=42,
30
+ initialize_jax_distributed=False,
31
+ mesh_dim='1,-1,1',
32
+ dtype='fp32',
33
+ param_dtype='fp32',
34
+ total_steps=10000,
35
+ load_llama_config='',
36
+ update_llama_config='',
37
+ load_checkpoint='',
38
+ load_dataset_state='',
39
+ log_freq=50,
40
+ save_model_freq=0,
41
+ save_milestone_freq=0,
42
+ eval_freq=0,
43
+ tokenizer=LLaMAConfig.get_tokenizer_config(),
44
+ train_dataset=DatasetFactory.get_default_config(),
45
+ eval_dataset=DatasetFactory.get_default_config(),
46
+ optimizer=OptimizerFactory.get_default_config(),
47
+ checkpointer=StreamingCheckpointer.get_default_config(),
48
+ llama=LLaMAConfig.get_default_config(),
49
+ logger=mlxu.WandBLogger.get_default_config(),
50
+ log_all_worker=False,
51
+ )
52
+
53
+
54
+ def main(argv):
55
+ if FLAGS.initialize_jax_distributed:
56
+ jax.distributed.initialize()
57
+
58
+ variant = mlxu.get_user_flags(FLAGS, FLAGS_DEF)
59
+ flags_config_dict = mlxu.user_flags_to_config_dict(FLAGS, FLAGS_DEF)
60
+ logger = mlxu.WandBLogger(
61
+ config=FLAGS.logger,
62
+ variant=variant,
63
+ enable=FLAGS.log_all_worker or (jax.process_index() == 0),
64
+ )
65
+ set_random_seed(FLAGS.seed)
66
+
67
+ tokenizer = LLaMAConfig.get_tokenizer(FLAGS.tokenizer)
68
+ dataset = DatasetFactory.load_dataset(FLAGS.train_dataset, tokenizer)
69
+
70
+ if FLAGS.load_dataset_state != '':
71
+ dataset.load_state_dict(mlxu.load_pickle(FLAGS.load_dataset_state))
72
+
73
+ if FLAGS.eval_freq > 0:
74
+ eval_dataset = DatasetFactory.load_dataset(
75
+ FLAGS.eval_dataset, dataset.tokenizer, eval_dataset=True
76
+ )
77
+
78
+ seq_length = dataset.seq_length
79
+
80
+ if FLAGS.load_llama_config != '':
81
+ llama_config = LLaMAConfig.load_config(FLAGS.load_llama_config)
82
+ else:
83
+ llama_config = LLaMAConfig(**FLAGS.llama)
84
+
85
+ if FLAGS.update_llama_config != '':
86
+ llama_config.update(dict(eval(FLAGS.update_llama_config)))
87
+
88
+ llama_config.update(dict(
89
+ bos_token_id=dataset.tokenizer.bos_token_id,
90
+ eos_token_id=dataset.tokenizer.eos_token_id,
91
+ ))
92
+ if llama_config.vocab_size < dataset.vocab_size:
93
+ llama_config.update(dict(vocab_size=dataset.vocab_size))
94
+
95
+ model = FlaxLLaMAForCausalLMModule(
96
+ llama_config, dtype=get_float_dtype_by_name(FLAGS.dtype), param_dtype=get_float_dtype_by_name(FLAGS.param_dtype)
97
+ )
98
+
99
+ optimizer, optimizer_info = OptimizerFactory.get_optimizer(
100
+ FLAGS.optimizer,
101
+ get_weight_decay_mask(LLaMAConfig.get_weight_decay_exclusions())
102
+ )
103
+
104
+ def create_trainstate_from_params(params):
105
+ return TrainState.create(params=params, tx=optimizer, apply_fn=None)
106
+
107
+ def init_fn(rng):
108
+ rng_generator = JaxRNG(rng)
109
+ params = model.init(
110
+ input_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
111
+ position_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
112
+ attention_mask=jnp.ones((4, seq_length), dtype=jnp.int32),
113
+ rngs=rng_generator(llama_config.rng_keys()),
114
+ )
115
+ return TrainState.create(params=params, tx=optimizer, apply_fn=None)
116
+
117
+ def train_step(train_state, rng, batch):
118
+ rng_generator = JaxRNG(rng)
119
+ batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
120
+ def loss_and_accuracy(params):
121
+ logits = model.apply(
122
+ params, batch['input_tokens'], deterministic=False,
123
+ rngs=rng_generator(llama_config.rng_keys()),
124
+ ).logits
125
+ return cross_entropy_loss_and_accuracy(
126
+ logits, batch['target_tokens'], batch['loss_masks']
127
+ )
128
+ grad_fn = jax.value_and_grad(loss_and_accuracy, has_aux=True)
129
+ (loss, accuracy), grads = grad_fn(train_state.params)
130
+ train_state = train_state.apply_gradients(grads=grads)
131
+ metrics = dict(
132
+ loss=loss,
133
+ accuracy=accuracy,
134
+ learning_rate=optimizer_info['learning_rate_schedule'](train_state.step),
135
+ gradient_norm=global_norm(grads),
136
+ param_norm=global_norm(train_state.params),
137
+ )
138
+ return train_state, rng_generator(), metrics
139
+
140
+ def eval_step(train_state, rng, batch):
141
+ rng_generator = JaxRNG(rng)
142
+ batch = with_sharding_constraint(batch, PS(('dp', 'fsdp')))
143
+ logits = model.apply(
144
+ train_state.params, batch['input_tokens'], deterministic=True,
145
+ rngs=rng_generator(llama_config.rng_keys()),
146
+ ).logits
147
+ loss, accuracy = cross_entropy_loss_and_accuracy(
148
+ logits, batch['target_tokens'], batch['loss_masks']
149
+ )
150
+ metrics = dict(
151
+ eval_loss=loss,
152
+ eval_accuracy=accuracy,
153
+ )
154
+ return rng_generator(), metrics
155
+
156
+ train_state_shapes = jax.eval_shape(init_fn, next_rng())
157
+ train_state_partition = match_partition_rules(
158
+ LLaMAConfig.get_partition_rules(), train_state_shapes
159
+ )
160
+
161
+ shard_fns, gather_fns = make_shard_and_gather_fns(
162
+ train_state_partition, train_state_shapes
163
+ )
164
+ checkpointer = StreamingCheckpointer(
165
+ FLAGS.checkpointer, logger.output_dir,
166
+ enable=jax.process_index() == 0,
167
+ )
168
+
169
+ sharded_init_fn = pjit(
170
+ init_fn,
171
+ in_shardings=PS(),
172
+ out_shardings=train_state_partition
173
+ )
174
+
175
+ sharded_create_trainstate_from_params = pjit(
176
+ create_trainstate_from_params,
177
+ in_shardings=(train_state_partition.params, ),
178
+ out_shardings=train_state_partition,
179
+ donate_argnums=(0, ),
180
+ )
181
+
182
+ sharded_train_step = pjit(
183
+ train_step,
184
+ in_shardings=(train_state_partition, PS(), PS()),
185
+ out_shardings=(train_state_partition, PS(), PS()),
186
+ donate_argnums=(0, 1),
187
+ )
188
+
189
+ sharded_eval_step = pjit(
190
+ eval_step,
191
+ in_shardings=(train_state_partition, PS(), PS()),
192
+ out_shardings=(PS(), PS()),
193
+ donate_argnums=(1,),
194
+ )
195
+
196
+ def save_checkpoint(train_state, milestone=False):
197
+ step = int(jax.device_get(train_state.step))
198
+ metadata = dict(
199
+ step=step,
200
+ variant=variant,
201
+ flags=flags_config_dict,
202
+ llama_config=llama_config.to_dict(),
203
+ )
204
+ checkpointer.save_all(
205
+ train_state=train_state,
206
+ gather_fns=gather_fns,
207
+ metadata=metadata,
208
+ dataset=dataset.get_state_dict(),
209
+ milestone=milestone,
210
+ )
211
+
212
+ mesh = LLaMAConfig.get_jax_mesh(FLAGS.mesh_dim)
213
+ with mesh:
214
+ train_state, restored_params = None, None
215
+ if FLAGS.load_checkpoint != '':
216
+ train_state, restored_params = checkpointer.load_trainstate_checkpoint(
217
+ FLAGS.load_checkpoint, train_state_shapes, shard_fns
218
+ )
219
+
220
+ if train_state is None and restored_params is None:
221
+ # Initialize from scratch
222
+ train_state = sharded_init_fn(next_rng())
223
+ elif train_state is None and restored_params is not None:
224
+ # Restore from params but initialize train_state
225
+ train_state = sharded_create_trainstate_from_params(restored_params)
226
+ del restored_params
227
+
228
+ start_step = int(jax.device_get(train_state.step))
229
+
230
+ if FLAGS.save_model_freq > 0:
231
+ save_checkpoint(train_state)
232
+
233
+ sharded_rng = next_rng()
234
+
235
+ step_counter = trange(start_step, FLAGS.total_steps, ncols=0)
236
+
237
+ for step, (batch, dataset_metrics) in zip(step_counter, dataset):
238
+ train_state, sharded_rng, metrics = sharded_train_step(
239
+ train_state, sharded_rng, batch
240
+ )
241
+
242
+ if FLAGS.eval_freq > 0 and (step + 1) % FLAGS.eval_freq == 0:
243
+ eval_metric_list = []
244
+ eval_iterator = iter(eval_dataset)
245
+ for eval_batch, _ in eval_iterator:
246
+ sharded_rng, eval_metrics = sharded_eval_step(
247
+ train_state, sharded_rng, eval_batch
248
+ )
249
+ eval_metric_list.append(eval_metrics)
250
+ metrics.update(average_metrics(eval_metric_list))
251
+
252
+ if FLAGS.log_freq > 0 and (step + 1) % FLAGS.log_freq == 0:
253
+ log_metrics = {"step": step}
254
+ log_metrics.update(metrics)
255
+ log_metrics.update(dataset_metrics)
256
+ log_metrics = jax.device_get(log_metrics)
257
+ logger.log(log_metrics)
258
+ tqdm.write("\n" + pprint.pformat(log_metrics) + "\n")
259
+
260
+ if FLAGS.save_milestone_freq > 0 and (step + 1) % FLAGS.save_milestone_freq == 0:
261
+ save_checkpoint(train_state, milestone=True)
262
+ elif FLAGS.save_model_freq > 0 and (step + 1) % FLAGS.save_model_freq == 0:
263
+ save_checkpoint(train_state)
264
+
265
+ if FLAGS.save_model_freq > 0:
266
+ save_checkpoint(train_state)
267
+
268
+
269
+ if __name__ == "__main__":
270
+ mlxu.run(main)
EasyLM/models/roberta/__init__.py ADDED
File without changes
EasyLM/models/roberta/roberta_model.py ADDED
@@ -0,0 +1,1694 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The Google Flax Team Authors and The HuggingFace Inc. team.
3
+ # Modifications copyright 2022 Xinyang Geng
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ from typing import Callable, Optional, Tuple
17
+ from collections import OrderedDict
18
+ from typing import Mapping
19
+
20
+ import numpy as np
21
+
22
+ import flax.linen as nn
23
+ import jax
24
+ import jax.numpy as jnp
25
+ from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
26
+ from flax.linen import combine_masks, make_causal_mask
27
+ from flax.linen import partitioning as nn_partitioning
28
+ from flax.linen.attention import dot_product_attention_weights
29
+ from flax.traverse_util import flatten_dict, unflatten_dict
30
+ from jax import lax
31
+ from jax.sharding import PartitionSpec
32
+
33
+ from transformers.configuration_utils import PretrainedConfig
34
+ from transformers.modeling_flax_outputs import (
35
+ FlaxBaseModelOutputWithPastAndCrossAttentions,
36
+ FlaxBaseModelOutputWithPooling,
37
+ FlaxBaseModelOutputWithPoolingAndCrossAttentions,
38
+ FlaxCausalLMOutputWithCrossAttentions,
39
+ FlaxMaskedLMOutput,
40
+ FlaxMultipleChoiceModelOutput,
41
+ FlaxQuestionAnsweringModelOutput,
42
+ FlaxSequenceClassifierOutput,
43
+ FlaxTokenClassifierOutput,
44
+ )
45
+ from transformers.modeling_flax_utils import (
46
+ ACT2FN, FlaxPreTrainedModel, append_call_sample_docstring,
47
+ overwrite_call_docstring
48
+ )
49
+ from transformers.utils import (
50
+ add_start_docstrings, add_start_docstrings_to_model_forward, logging
51
+ )
52
+ from transformers import AutoTokenizer
53
+
54
+ from ml_collections import ConfigDict
55
+ from ml_collections.config_dict import config_dict
56
+ from mlxu import function_args_to_config, load_pickle
57
+
58
+ from EasyLM.jax_utils import with_sharding_constraint, get_jax_mesh
59
+
60
+
61
+ """
62
+ The follow code is taken from
63
+ transformers/src/transformers/models/roberta/configuration_roberta.py
64
+ and modified to work with EasyLM.
65
+ """
66
+
67
+
68
+ ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
69
+ "roberta-base": "https://huggingface.co/roberta-base/resolve/main/config.json",
70
+ "roberta-large": "https://huggingface.co/roberta-large/resolve/main/config.json",
71
+ "roberta-large-mnli": "https://huggingface.co/roberta-large-mnli/resolve/main/config.json",
72
+ "distilroberta-base": "https://huggingface.co/distilroberta-base/resolve/main/config.json",
73
+ "roberta-base-openai-detector": "https://huggingface.co/roberta-base-openai-detector/resolve/main/config.json",
74
+ "roberta-large-openai-detector": "https://huggingface.co/roberta-large-openai-detector/resolve/main/config.json",
75
+ }
76
+
77
+
78
+ class RobertaConfig(PretrainedConfig):
79
+ r"""
80
+ This is the configuration class to store the configuration of a [`RobertaModel`] or a [`TFRobertaModel`]. It is
81
+ used to instantiate a RoBERTa model according to the specified arguments, defining the model architecture.
82
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the RoBERTa
83
+ [roberta-base](https://huggingface.co/roberta-base) architecture.
84
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
85
+ documentation from [`PretrainedConfig`] for more information.
86
+ Args:
87
+ vocab_size (`int`, *optional*, defaults to 30522):
88
+ Vocabulary size of the RoBERTa model. Defines the number of different tokens that can be represented by the
89
+ `inputs_ids` passed when calling [`RobertaModel`] or [`TFRobertaModel`].
90
+ hidden_size (`int`, *optional*, defaults to 768):
91
+ Dimensionality of the encoder layers and the pooler layer.
92
+ num_hidden_layers (`int`, *optional*, defaults to 12):
93
+ Number of hidden layers in the Transformer encoder.
94
+ num_attention_heads (`int`, *optional*, defaults to 12):
95
+ Number of attention heads for each attention layer in the Transformer encoder.
96
+ intermediate_size (`int`, *optional*, defaults to 3072):
97
+ Dimensionality of the "intermediate" (often named feed-forward) layer in the Transformer encoder.
98
+ hidden_act (`str` or `Callable`, *optional*, defaults to `"gelu"`):
99
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
100
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
101
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
102
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
103
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
104
+ The dropout ratio for the attention probabilities.
105
+ max_position_embeddings (`int`, *optional*, defaults to 512):
106
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
107
+ just in case (e.g., 512 or 1024 or 2048).
108
+ type_vocab_size (`int`, *optional*, defaults to 2):
109
+ The vocabulary size of the `token_type_ids` passed when calling [`RobertaModel`] or [`TFRobertaModel`].
110
+ initializer_range (`float`, *optional*, defaults to 0.02):
111
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
112
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
113
+ The epsilon used by the layer normalization layers.
114
+ position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
115
+ Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
116
+ positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
117
+ [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
118
+ For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
119
+ with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
120
+ use_cache (`bool`, *optional*, defaults to `True`):
121
+ Whether or not the model should return the last key/values attentions (not used by all models). Only
122
+ relevant if `config.is_decoder=True`.
123
+ classifier_dropout (`float`, *optional*):
124
+ The dropout ratio for the classification head.
125
+ Examples:
126
+ ```python
127
+ >>> from transformers import RobertaConfig, RobertaModel
128
+ >>> # Initializing a RoBERTa configuration
129
+ >>> configuration = RobertaConfig()
130
+ >>> # Initializing a model (with random weights) from the configuration
131
+ >>> model = RobertaModel(configuration)
132
+ >>> # Accessing the model configuration
133
+ >>> configuration = model.config
134
+ ```"""
135
+ model_type = "roberta"
136
+
137
+ def __init__(
138
+ self,
139
+ vocab_size=50265,
140
+ hidden_size=768,
141
+ num_hidden_layers=12,
142
+ num_attention_heads=12,
143
+ intermediate_size=3072,
144
+ hidden_act="gelu",
145
+ hidden_dropout_prob=0.1,
146
+ attention_probs_dropout_prob=0.1,
147
+ max_position_embeddings=514,
148
+ type_vocab_size=1,
149
+ initializer_range=0.02,
150
+ layer_norm_eps=1e-5,
151
+ pad_token_id=1,
152
+ bos_token_id=0,
153
+ eos_token_id=2,
154
+ position_embedding_type="absolute",
155
+ use_cache=True,
156
+ classifier_dropout=None,
157
+ **kwargs
158
+ ):
159
+ super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
160
+
161
+ self.vocab_size = vocab_size
162
+ self.hidden_size = hidden_size
163
+ self.num_hidden_layers = num_hidden_layers
164
+ self.num_attention_heads = num_attention_heads
165
+ self.hidden_act = hidden_act
166
+ self.intermediate_size = intermediate_size
167
+ self.hidden_dropout_prob = hidden_dropout_prob
168
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
169
+ self.max_position_embeddings = max_position_embeddings
170
+ self.type_vocab_size = type_vocab_size
171
+ self.initializer_range = initializer_range
172
+ self.layer_norm_eps = layer_norm_eps
173
+ self.position_embedding_type = position_embedding_type
174
+ self.use_cache = use_cache
175
+ self.classifier_dropout = classifier_dropout
176
+
177
+ @classmethod
178
+ def get_default_config(cls, updates=None):
179
+ none_arg_types = dict(
180
+ classifier_dropout=float,
181
+ )
182
+ config = function_args_to_config(cls.__init__, none_arg_types=none_arg_types)
183
+ config.tie_word_embeddings = True
184
+
185
+ if updates is not None:
186
+ config.update(ConfigDict(updates).copy_and_resolve_references())
187
+
188
+ return config
189
+
190
+ @staticmethod
191
+ def get_jax_mesh(axis_dims):
192
+ return get_jax_mesh(axis_dims, ('dp', 'fsdp', 'mp'))
193
+
194
+ @staticmethod
195
+ def get_partition_rules():
196
+ """ Parition rules for Roberta model. """
197
+ return (
198
+ ('embeddings/(position_embeddings|token_type_embeddings)/embedding', PartitionSpec()),
199
+ ('embeddings/word_embeddings/embedding', PartitionSpec()),
200
+ ('attention/self/(key|query|value)/kernel', PartitionSpec('fsdp', 'mp')),
201
+ ('attention/self/(key|query|value)/bias', PartitionSpec()),
202
+ ('attention/output/dense/kernel', PartitionSpec('mp', 'fsdp')),
203
+ ('attention/output/dense/bias', PartitionSpec()),
204
+ ('(LayerNorm|layer_norm)/(bias|scale)', PartitionSpec()),
205
+ ('intermediate/dense/kernel', PartitionSpec('fsdp', 'mp')),
206
+ ('intermediate/dense/bias', PartitionSpec('mp')),
207
+ ('output/dense/kernel', PartitionSpec('mp', 'fsdp')),
208
+ ('output/dense/bias', PartitionSpec()),
209
+ ('lm_head/dense/kernel', PartitionSpec()),
210
+ ('lm_head/dense/bias', PartitionSpec()),
211
+ ('lm_head/decoder/kernel', PartitionSpec('fsdp', 'mp')),
212
+ ('lm_head/decoder/bias', PartitionSpec('mp')),
213
+ ('.*', PartitionSpec()),
214
+ )
215
+
216
+ @staticmethod
217
+ def get_weight_decay_exclusions():
218
+ return ('bias', 'LayerNorm/scale', 'layer_norm/scale')
219
+
220
+ @staticmethod
221
+ def rng_keys():
222
+ return ('params', 'dropout')
223
+
224
+ @staticmethod
225
+ def get_tokenizer_config(updates=None):
226
+ config = ConfigDict()
227
+ config.name = 'roberta-base'
228
+
229
+ if updates is not None:
230
+ config.update(ConfigDict(updates).copy_and_resolve_references())
231
+
232
+ return config
233
+
234
+ @classmethod
235
+ def get_tokenizer(cls, config):
236
+ config = cls.get_tokenizer_config(config)
237
+ return AutoTokenizer.from_pretrained(
238
+ config.name,
239
+ )
240
+
241
+ @staticmethod
242
+ def load_pretrained(name):
243
+ with jax.default_device(jax.devices("cpu")[0]):
244
+ params = FlaxRobertaForMaskedLM.from_pretrained(name, _do_init=False)[1]
245
+ params = freeze({'params': params})
246
+ return params
247
+
248
+ @classmethod
249
+ def load_config(cls, path):
250
+ load_type, load_path = path.split('::', 1)
251
+ if load_type == 'pickle':
252
+ return cls.from_dict(load_pickle(load_path)['roberta_config'])
253
+ elif load_type == 'huggingface':
254
+ return cls.from_pretrained(load_path)
255
+ else:
256
+ raise ValueError(f'Unsupported load config type: {load_type}')
257
+
258
+
259
+ """
260
+ The follow code is taken from
261
+ transformers/src/transformers/models/roberta/modeling_flax_roberta.py
262
+ and modified to work with EasyLM.
263
+ """
264
+
265
+
266
+ logger = logging.get_logger(__name__)
267
+
268
+ _CHECKPOINT_FOR_DOC = "roberta-base"
269
+ _CONFIG_FOR_DOC = "RobertaConfig"
270
+
271
+ remat = nn_partitioning.remat
272
+
273
+
274
+ def create_position_ids_from_input_ids(input_ids, padding_idx):
275
+ """
276
+ Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
277
+ are ignored. This is modified from fairseq's `utils.make_positions`.
278
+ Args:
279
+ input_ids: jnp.ndarray
280
+ padding_idx: int
281
+ Returns: jnp.ndarray
282
+ """
283
+ # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
284
+ mask = (input_ids != padding_idx).astype("i4")
285
+
286
+ if mask.ndim > 2:
287
+ mask = mask.reshape((-1, mask.shape[-1]))
288
+ incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask
289
+ incremental_indices = incremental_indices.reshape(input_ids.shape)
290
+ else:
291
+ incremental_indices = jnp.cumsum(mask, axis=1).astype("i4") * mask
292
+
293
+ return incremental_indices.astype("i4") + padding_idx
294
+
295
+
296
+ ROBERTA_START_DOCSTRING = r"""
297
+ This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
298
+ library implements for all its model (such as downloading, saving and converting weights from PyTorch models)
299
+ This model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
300
+ subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to
301
+ general usage and behavior.
302
+ Finally, this model supports inherent JAX features such as:
303
+ - [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
304
+ - [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
305
+ - [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
306
+ - [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
307
+ Parameters:
308
+ config ([`RobertaConfig`]): Model configuration class with all the parameters of the
309
+ model. Initializing with a config file does not load the weights associated with the model, only the
310
+ configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
311
+ """
312
+
313
+ ROBERTA_INPUTS_DOCSTRING = r"""
314
+ Args:
315
+ input_ids (`numpy.ndarray` of shape `({0})`):
316
+ Indices of input sequence tokens in the vocabulary.
317
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
318
+ [`PreTrainedTokenizer.__call__`] for details.
319
+ [What are input IDs?](../glossary#input-ids)
320
+ attention_mask (`numpy.ndarray` of shape `({0})`, *optional*):
321
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
322
+ - 1 for tokens that are **not masked**,
323
+ - 0 for tokens that are **masked**.
324
+ [What are attention masks?](../glossary#attention-mask)
325
+ token_type_ids (`numpy.ndarray` of shape `({0})`, *optional*):
326
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
327
+ 1]`:
328
+ - 0 corresponds to a *sentence A* token,
329
+ - 1 corresponds to a *sentence B* token.
330
+ [What are token type IDs?](../glossary#token-type-ids)
331
+ position_ids (`numpy.ndarray` of shape `({0})`, *optional*):
332
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
333
+ config.max_position_embeddings - 1]`.
334
+ head_mask (`numpy.ndarray` of shape `({0})`, `optional):
335
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
336
+ - 1 indicates the head is **not masked**,
337
+ - 0 indicates the head is **masked**.
338
+ return_dict (`bool`, *optional*):
339
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
340
+ """
341
+
342
+
343
+ # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEmbeddings with Bert->Roberta
344
+ class FlaxRobertaEmbeddings(nn.Module):
345
+ """Construct the embeddings from word, position and token_type embeddings."""
346
+
347
+ config: RobertaConfig
348
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
349
+
350
+ def setup(self):
351
+ self.word_embeddings = nn.Embed(
352
+ self.config.vocab_size,
353
+ self.config.hidden_size,
354
+ embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
355
+ dtype=self.dtype,
356
+ )
357
+ self.position_embeddings = nn.Embed(
358
+ self.config.max_position_embeddings,
359
+ self.config.hidden_size,
360
+ embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
361
+ dtype=self.dtype,
362
+ )
363
+ self.token_type_embeddings = nn.Embed(
364
+ self.config.type_vocab_size,
365
+ self.config.hidden_size,
366
+ embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
367
+ dtype=self.dtype,
368
+ )
369
+ self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
370
+ self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
371
+
372
+ def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
373
+ # Embed
374
+ inputs_embeds = self.word_embeddings(input_ids.astype("i4"))
375
+ position_embeds = self.position_embeddings(position_ids.astype("i4"))
376
+ token_type_embeddings = self.token_type_embeddings(token_type_ids.astype("i4"))
377
+
378
+ # Sum all embeddings
379
+ hidden_states = inputs_embeds + token_type_embeddings + position_embeds
380
+
381
+ # Layer Norm
382
+ hidden_states = self.LayerNorm(hidden_states)
383
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
384
+ return hidden_states
385
+
386
+
387
+ # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfAttention with Bert->Roberta
388
+ class FlaxRobertaSelfAttention(nn.Module):
389
+ config: RobertaConfig
390
+ causal: bool = False
391
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
392
+
393
+ def setup(self):
394
+ self.head_dim = self.config.hidden_size // self.config.num_attention_heads
395
+ if self.config.hidden_size % self.config.num_attention_heads != 0:
396
+ raise ValueError(
397
+ "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads` "
398
+ " : {self.config.num_attention_heads}"
399
+ )
400
+
401
+ self.query = nn.Dense(
402
+ self.config.hidden_size,
403
+ dtype=self.dtype,
404
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
405
+ )
406
+ self.key = nn.Dense(
407
+ self.config.hidden_size,
408
+ dtype=self.dtype,
409
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
410
+ )
411
+ self.value = nn.Dense(
412
+ self.config.hidden_size,
413
+ dtype=self.dtype,
414
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
415
+ )
416
+
417
+ if self.causal:
418
+ self.causal_mask = make_causal_mask(
419
+ jnp.ones((1, self.config.max_position_embeddings), dtype="bool"), dtype="bool"
420
+ )
421
+
422
+ def _split_heads(self, hidden_states):
423
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.config.num_attention_heads, self.head_dim))
424
+
425
+ def _merge_heads(self, hidden_states):
426
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.config.hidden_size,))
427
+
428
+ @nn.compact
429
+ # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartAttention._concatenate_to_cache
430
+ def _concatenate_to_cache(self, key, value, query, attention_mask):
431
+ """
432
+ This function takes projected key, value states from a single input token and concatenates the states to cached
433
+ states from previous steps. This function is slighly adapted from the official Flax repository:
434
+ https://github.com/google/flax/blob/491ce18759622506588784b4fca0e4bf05f8c8cd/flax/linen/attention.py#L252
435
+ """
436
+ # detect if we're initializing by absence of existing cache data.
437
+ is_initialized = self.has_variable("cache", "cached_key")
438
+ cached_key = self.variable("cache", "cached_key", jnp.zeros, key.shape, key.dtype)
439
+ cached_value = self.variable("cache", "cached_value", jnp.zeros, value.shape, value.dtype)
440
+ cache_index = self.variable("cache", "cache_index", lambda: jnp.array(0, dtype=jnp.int32))
441
+
442
+ if is_initialized:
443
+ *batch_dims, max_length, num_heads, depth_per_head = cached_key.value.shape
444
+ # update key, value caches with our new 1d spatial slices
445
+ cur_index = cache_index.value
446
+ indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
447
+ key = lax.dynamic_update_slice(cached_key.value, key, indices)
448
+ value = lax.dynamic_update_slice(cached_value.value, value, indices)
449
+ cached_key.value = key
450
+ cached_value.value = value
451
+ num_updated_cache_vectors = query.shape[1]
452
+ cache_index.value = cache_index.value + num_updated_cache_vectors
453
+ # causal mask for cached decoder self-attention: our single query position should only attend to those key positions that have already been generated and cached, not the remaining zero elements.
454
+ pad_mask = jnp.broadcast_to(
455
+ jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
456
+ tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
457
+ )
458
+ attention_mask = combine_masks(pad_mask, attention_mask)
459
+ return key, value, attention_mask
460
+
461
+ def __call__(
462
+ self,
463
+ hidden_states,
464
+ attention_mask,
465
+ layer_head_mask,
466
+ key_value_states: Optional[jnp.array] = None,
467
+ init_cache: bool = False,
468
+ deterministic=True,
469
+ output_attentions: bool = False,
470
+ ):
471
+ # if key_value_states are provided this layer is used as a cross-attention layer
472
+ # for the decoder
473
+ is_cross_attention = key_value_states is not None
474
+ batch_size = hidden_states.shape[0]
475
+
476
+ # get query proj
477
+ query_states = self.query(hidden_states)
478
+ # get key, value proj
479
+ if is_cross_attention:
480
+ # cross_attentions
481
+ key_states = self.key(key_value_states)
482
+ value_states = self.value(key_value_states)
483
+ else:
484
+ # self_attention
485
+ key_states = self.key(hidden_states)
486
+ value_states = self.value(hidden_states)
487
+
488
+ query_states = self._split_heads(query_states)
489
+ key_states = self._split_heads(key_states)
490
+ value_states = self._split_heads(value_states)
491
+
492
+ # handle cache prepare causal attention mask
493
+ if self.causal:
494
+ query_length, key_length = query_states.shape[1], key_states.shape[1]
495
+ if self.has_variable("cache", "cached_key"):
496
+ mask_shift = self.variables["cache"]["cache_index"]
497
+ max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
498
+ causal_mask = lax.dynamic_slice(
499
+ self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
500
+ )
501
+ else:
502
+ causal_mask = self.causal_mask[:, :, :query_length, :key_length]
503
+ causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
504
+
505
+ # combine masks if needed
506
+ if attention_mask is not None and self.causal:
507
+ attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
508
+ attention_mask = combine_masks(attention_mask, causal_mask)
509
+ elif self.causal:
510
+ attention_mask = causal_mask
511
+ elif attention_mask is not None:
512
+ attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
513
+
514
+ # During fast autoregressive decoding, we feed one position at a time,
515
+ # and cache the keys and values step by step.
516
+ if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
517
+ key_states, value_states, attention_mask = self._concatenate_to_cache(
518
+ key_states, value_states, query_states, attention_mask
519
+ )
520
+
521
+ # Convert the boolean attention mask to an attention bias.
522
+ if attention_mask is not None:
523
+ # attention mask in the form of attention bias
524
+ attention_bias = lax.select(
525
+ attention_mask > 0,
526
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
527
+ jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
528
+ )
529
+ else:
530
+ attention_bias = None
531
+
532
+ dropout_rng = None
533
+ if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
534
+ dropout_rng = self.make_rng("dropout")
535
+
536
+ attn_weights = dot_product_attention_weights(
537
+ query_states,
538
+ key_states,
539
+ bias=attention_bias,
540
+ dropout_rng=dropout_rng,
541
+ dropout_rate=self.config.attention_probs_dropout_prob,
542
+ broadcast_dropout=True,
543
+ deterministic=deterministic,
544
+ dtype=self.dtype,
545
+ precision=None,
546
+ )
547
+
548
+ # Mask heads if we want to
549
+ if layer_head_mask is not None:
550
+ attn_weights = jnp.einsum("...hqk,h->...hqk", attn_weights, layer_head_mask)
551
+
552
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
553
+ attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))
554
+
555
+ outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
556
+ return outputs
557
+
558
+
559
+ # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertSelfOutput with Bert->Roberta
560
+ class FlaxRobertaSelfOutput(nn.Module):
561
+ config: RobertaConfig
562
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
563
+
564
+ def setup(self):
565
+ self.dense = nn.Dense(
566
+ self.config.hidden_size,
567
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
568
+ dtype=self.dtype,
569
+ )
570
+ self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
571
+ self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
572
+
573
+ def __call__(self, hidden_states, input_tensor, deterministic: bool = True):
574
+ hidden_states = self.dense(hidden_states)
575
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
576
+ hidden_states = self.LayerNorm(hidden_states + input_tensor)
577
+ return hidden_states
578
+
579
+
580
+ # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertAttention with Bert->Roberta
581
+ class FlaxRobertaAttention(nn.Module):
582
+ config: RobertaConfig
583
+ causal: bool = False
584
+ dtype: jnp.dtype = jnp.float32
585
+
586
+ def setup(self):
587
+ self.self = FlaxRobertaSelfAttention(self.config, causal=self.causal, dtype=self.dtype)
588
+ self.output = FlaxRobertaSelfOutput(self.config, dtype=self.dtype)
589
+
590
+ def __call__(
591
+ self,
592
+ hidden_states,
593
+ attention_mask,
594
+ layer_head_mask,
595
+ key_value_states=None,
596
+ init_cache=False,
597
+ deterministic=True,
598
+ output_attentions: bool = False,
599
+ ):
600
+ # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
601
+ # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
602
+ # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
603
+ attn_outputs = self.self(
604
+ hidden_states,
605
+ attention_mask,
606
+ layer_head_mask=layer_head_mask,
607
+ key_value_states=key_value_states,
608
+ init_cache=init_cache,
609
+ deterministic=deterministic,
610
+ output_attentions=output_attentions,
611
+ )
612
+ attn_output = attn_outputs[0]
613
+ hidden_states = self.output(attn_output, hidden_states, deterministic=deterministic)
614
+
615
+ outputs = (hidden_states,)
616
+
617
+ if output_attentions:
618
+ outputs += (attn_outputs[1],)
619
+
620
+ return outputs
621
+
622
+
623
+ # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertIntermediate with Bert->Roberta
624
+ class FlaxRobertaIntermediate(nn.Module):
625
+ config: RobertaConfig
626
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
627
+
628
+ def setup(self):
629
+ self.dense = nn.Dense(
630
+ self.config.intermediate_size,
631
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
632
+ dtype=self.dtype,
633
+ )
634
+ self.activation = ACT2FN[self.config.hidden_act]
635
+
636
+ def __call__(self, hidden_states):
637
+ hidden_states = self.dense(hidden_states)
638
+ hidden_states = self.activation(hidden_states)
639
+ return hidden_states
640
+
641
+
642
+ # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertOutput with Bert->Roberta
643
+ class FlaxRobertaOutput(nn.Module):
644
+ config: RobertaConfig
645
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
646
+
647
+ def setup(self):
648
+ self.dense = nn.Dense(
649
+ self.config.hidden_size,
650
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
651
+ dtype=self.dtype,
652
+ )
653
+ self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
654
+ self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
655
+
656
+ def __call__(self, hidden_states, attention_output, deterministic: bool = True):
657
+ hidden_states = self.dense(hidden_states)
658
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
659
+ hidden_states = self.LayerNorm(hidden_states + attention_output)
660
+ return hidden_states
661
+
662
+
663
+ # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayer with Bert->Roberta
664
+ class FlaxRobertaLayer(nn.Module):
665
+ config: RobertaConfig
666
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
667
+
668
+ def setup(self):
669
+ self.attention = FlaxRobertaAttention(self.config, causal=self.config.is_decoder, dtype=self.dtype)
670
+ self.intermediate = FlaxRobertaIntermediate(self.config, dtype=self.dtype)
671
+ self.output = FlaxRobertaOutput(self.config, dtype=self.dtype)
672
+ if self.config.add_cross_attention:
673
+ self.crossattention = FlaxRobertaAttention(self.config, causal=False, dtype=self.dtype)
674
+
675
+ def __call__(
676
+ self,
677
+ hidden_states,
678
+ attention_mask,
679
+ layer_head_mask,
680
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
681
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
682
+ init_cache: bool = False,
683
+ deterministic: bool = True,
684
+ output_attentions: bool = False,
685
+ ):
686
+ # Self Attention
687
+ attention_outputs = self.attention(
688
+ hidden_states,
689
+ attention_mask,
690
+ layer_head_mask=layer_head_mask,
691
+ init_cache=init_cache,
692
+ deterministic=deterministic,
693
+ output_attentions=output_attentions,
694
+ )
695
+ attention_output = attention_outputs[0]
696
+
697
+ # Cross-Attention Block
698
+ if encoder_hidden_states is not None:
699
+ cross_attention_outputs = self.crossattention(
700
+ attention_output,
701
+ attention_mask=encoder_attention_mask,
702
+ layer_head_mask=layer_head_mask,
703
+ key_value_states=encoder_hidden_states,
704
+ deterministic=deterministic,
705
+ output_attentions=output_attentions,
706
+ )
707
+ attention_output = cross_attention_outputs[0]
708
+
709
+ hidden_states = self.intermediate(attention_output)
710
+ hidden_states = self.output(hidden_states, attention_output, deterministic=deterministic)
711
+
712
+ outputs = (hidden_states,)
713
+
714
+ if output_attentions:
715
+ outputs += (attention_outputs[1],)
716
+ if encoder_hidden_states is not None:
717
+ outputs += (cross_attention_outputs[1],)
718
+ return outputs
719
+
720
+
721
+ # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertLayerCollection with Bert->Roberta
722
+ class FlaxRobertaLayerCollection(nn.Module):
723
+ config: RobertaConfig
724
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
725
+ gradient_checkpointing: bool = False
726
+
727
+ def setup(self):
728
+ if self.gradient_checkpointing:
729
+ FlaxRobertaCheckpointLayer = remat(FlaxRobertaLayer, static_argnums=(5, 6, 7))
730
+ self.layers = [
731
+ FlaxRobertaCheckpointLayer(self.config, name=str(i), dtype=self.dtype)
732
+ for i in range(self.config.num_hidden_layers)
733
+ ]
734
+ else:
735
+ self.layers = [
736
+ FlaxRobertaLayer(self.config, name=str(i), dtype=self.dtype)
737
+ for i in range(self.config.num_hidden_layers)
738
+ ]
739
+
740
+ def __call__(
741
+ self,
742
+ hidden_states,
743
+ attention_mask,
744
+ head_mask,
745
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
746
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
747
+ init_cache: bool = False,
748
+ deterministic: bool = True,
749
+ output_attentions: bool = False,
750
+ output_hidden_states: bool = False,
751
+ return_dict: bool = True,
752
+ ):
753
+ all_attentions = () if output_attentions else None
754
+ all_hidden_states = () if output_hidden_states else None
755
+ all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
756
+
757
+ # Check if head_mask has a correct number of layers specified if desired
758
+ if head_mask is not None:
759
+ if head_mask.shape[0] != (len(self.layers)):
760
+ raise ValueError(
761
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for "
762
+ f" {head_mask.shape[0]}."
763
+ )
764
+
765
+ for i, layer in enumerate(self.layers):
766
+ if output_hidden_states:
767
+ all_hidden_states += (hidden_states,)
768
+
769
+ layer_outputs = layer(
770
+ hidden_states,
771
+ attention_mask,
772
+ head_mask[i] if head_mask is not None else None,
773
+ encoder_hidden_states,
774
+ encoder_attention_mask,
775
+ init_cache,
776
+ deterministic,
777
+ output_attentions,
778
+ )
779
+
780
+ hidden_states = layer_outputs[0]
781
+
782
+ if output_attentions:
783
+ all_attentions += (layer_outputs[1],)
784
+
785
+ if encoder_hidden_states is not None:
786
+ all_cross_attentions += (layer_outputs[2],)
787
+
788
+ if output_hidden_states:
789
+ all_hidden_states += (hidden_states,)
790
+
791
+ outputs = (hidden_states, all_hidden_states, all_attentions, all_cross_attentions)
792
+
793
+ if not return_dict:
794
+ return tuple(v for v in outputs if v is not None)
795
+
796
+ return FlaxBaseModelOutputWithPastAndCrossAttentions(
797
+ last_hidden_state=hidden_states,
798
+ hidden_states=all_hidden_states,
799
+ attentions=all_attentions,
800
+ cross_attentions=all_cross_attentions,
801
+ )
802
+
803
+
804
+ # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertEncoder with Bert->Roberta
805
+ class FlaxRobertaEncoder(nn.Module):
806
+ config: RobertaConfig
807
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
808
+ gradient_checkpointing: bool = False
809
+
810
+ def setup(self):
811
+ self.layer = FlaxRobertaLayerCollection(
812
+ self.config,
813
+ dtype=self.dtype,
814
+ gradient_checkpointing=self.gradient_checkpointing,
815
+ )
816
+
817
+ def __call__(
818
+ self,
819
+ hidden_states,
820
+ attention_mask,
821
+ head_mask,
822
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
823
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
824
+ init_cache: bool = False,
825
+ deterministic: bool = True,
826
+ output_attentions: bool = False,
827
+ output_hidden_states: bool = False,
828
+ return_dict: bool = True,
829
+ ):
830
+ return self.layer(
831
+ hidden_states,
832
+ attention_mask,
833
+ head_mask=head_mask,
834
+ encoder_hidden_states=encoder_hidden_states,
835
+ encoder_attention_mask=encoder_attention_mask,
836
+ init_cache=init_cache,
837
+ deterministic=deterministic,
838
+ output_attentions=output_attentions,
839
+ output_hidden_states=output_hidden_states,
840
+ return_dict=return_dict,
841
+ )
842
+
843
+
844
+ # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPooler with Bert->Roberta
845
+ class FlaxRobertaPooler(nn.Module):
846
+ config: RobertaConfig
847
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
848
+
849
+ def setup(self):
850
+ self.dense = nn.Dense(
851
+ self.config.hidden_size,
852
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
853
+ dtype=self.dtype,
854
+ )
855
+
856
+ def __call__(self, hidden_states):
857
+ cls_hidden_state = hidden_states[:, 0]
858
+ cls_hidden_state = self.dense(cls_hidden_state)
859
+ return nn.tanh(cls_hidden_state)
860
+
861
+
862
+ class FlaxRobertaLMHead(nn.Module):
863
+ config: RobertaConfig
864
+ dtype: jnp.dtype = jnp.float32
865
+ bias_init: Callable[..., np.ndarray] = jax.nn.initializers.zeros
866
+
867
+ def setup(self):
868
+ self.dense = nn.Dense(
869
+ self.config.hidden_size,
870
+ dtype=self.dtype,
871
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
872
+ )
873
+ self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
874
+ self.decoder = nn.Dense(
875
+ self.config.vocab_size,
876
+ dtype=self.dtype,
877
+ use_bias=False,
878
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
879
+ )
880
+ self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,))
881
+
882
+ def __call__(self, hidden_states, shared_embedding=None):
883
+ hidden_states = self.dense(hidden_states)
884
+ hidden_states = ACT2FN["gelu"](hidden_states)
885
+ hidden_states = self.layer_norm(hidden_states)
886
+
887
+ if shared_embedding is not None:
888
+ hidden_states = self.decoder.apply({"params": {"kernel": shared_embedding.T}}, hidden_states)
889
+ else:
890
+ hidden_states = self.decoder(hidden_states)
891
+
892
+ bias = jnp.asarray(self.bias, self.dtype)
893
+ hidden_states += bias
894
+ return hidden_states
895
+
896
+
897
+ class FlaxRobertaClassificationHead(nn.Module):
898
+ config: RobertaConfig
899
+ dtype: jnp.dtype = jnp.float32
900
+
901
+ def setup(self):
902
+ self.dense = nn.Dense(
903
+ self.config.hidden_size,
904
+ dtype=self.dtype,
905
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
906
+ )
907
+ classifier_dropout = (
908
+ self.config.classifier_dropout
909
+ if self.config.classifier_dropout is not None
910
+ else self.config.hidden_dropout_prob
911
+ )
912
+ self.dropout = nn.Dropout(rate=classifier_dropout)
913
+ self.out_proj = nn.Dense(
914
+ self.config.num_labels,
915
+ dtype=self.dtype,
916
+ kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
917
+ )
918
+
919
+ def __call__(self, hidden_states, deterministic=True):
920
+ hidden_states = hidden_states[:, 0, :] # take <s> token (equiv. to [CLS])
921
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
922
+ hidden_states = self.dense(hidden_states)
923
+ hidden_states = nn.tanh(hidden_states)
924
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
925
+ hidden_states = self.out_proj(hidden_states)
926
+ return hidden_states
927
+
928
+
929
+ class FlaxRobertaPreTrainedModel(FlaxPreTrainedModel):
930
+ """
931
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
932
+ models.
933
+ """
934
+
935
+ config_class = RobertaConfig
936
+ base_model_prefix = "roberta"
937
+
938
+ module_class: nn.Module = None
939
+
940
+ def __init__(
941
+ self,
942
+ config: RobertaConfig,
943
+ input_shape: Tuple = (1, 1),
944
+ seed: int = 0,
945
+ dtype: jnp.dtype = jnp.float32,
946
+ _do_init: bool = True,
947
+ gradient_checkpointing: bool = False,
948
+ **kwargs,
949
+ ):
950
+ module = self.module_class(config=config, dtype=dtype, gradient_checkpointing=gradient_checkpointing, **kwargs)
951
+ super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
952
+
953
+ # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertPreTrainedModel.enable_gradient_checkpointing
954
+ def enable_gradient_checkpointing(self):
955
+ self._module = self.module_class(
956
+ config=self.config,
957
+ dtype=self.dtype,
958
+ gradient_checkpointing=True,
959
+ )
960
+
961
+ def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
962
+ # init input tensors
963
+ input_ids = jnp.zeros(input_shape, dtype="i4")
964
+ token_type_ids = jnp.ones_like(input_ids)
965
+ position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id)
966
+ attention_mask = jnp.ones_like(input_ids)
967
+ head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))
968
+
969
+ params_rng, dropout_rng = jax.random.split(rng)
970
+ rngs = {"params": params_rng, "dropout": dropout_rng}
971
+
972
+ if self.config.add_cross_attention:
973
+ encoder_hidden_states = jnp.zeros(input_shape + (self.config.hidden_size,))
974
+ encoder_attention_mask = attention_mask
975
+ module_init_outputs = self.module.init(
976
+ rngs,
977
+ input_ids,
978
+ attention_mask,
979
+ token_type_ids,
980
+ position_ids,
981
+ head_mask,
982
+ encoder_hidden_states,
983
+ encoder_attention_mask,
984
+ return_dict=False,
985
+ )
986
+ else:
987
+ module_init_outputs = self.module.init(
988
+ rngs, input_ids, attention_mask, token_type_ids, position_ids, head_mask, return_dict=False
989
+ )
990
+
991
+ random_params = module_init_outputs["params"]
992
+
993
+ if params is not None:
994
+ random_params = flatten_dict(unfreeze(random_params))
995
+ params = flatten_dict(unfreeze(params))
996
+ for missing_key in self._missing_keys:
997
+ params[missing_key] = random_params[missing_key]
998
+ self._missing_keys = set()
999
+ return freeze(unflatten_dict(params))
1000
+ else:
1001
+ return random_params
1002
+
1003
+ # Copied from transformers.models.bart.modeling_flax_bart.FlaxBartDecoderPreTrainedModel.init_cache
1004
+ def init_cache(self, batch_size, max_length):
1005
+ r"""
1006
+ Args:
1007
+ batch_size (`int`):
1008
+ batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.
1009
+ max_length (`int`):
1010
+ maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized
1011
+ cache.
1012
+ """
1013
+ # init input variables to retrieve cache
1014
+ input_ids = jnp.ones((batch_size, max_length), dtype="i4")
1015
+ attention_mask = jnp.ones_like(input_ids, dtype="i4")
1016
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
1017
+
1018
+ init_variables = self.module.init(
1019
+ jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True
1020
+ )
1021
+ return unfreeze(init_variables["cache"])
1022
+
1023
+ @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
1024
+ def __call__(
1025
+ self,
1026
+ input_ids,
1027
+ attention_mask=None,
1028
+ token_type_ids=None,
1029
+ position_ids=None,
1030
+ head_mask=None,
1031
+ encoder_hidden_states=None,
1032
+ encoder_attention_mask=None,
1033
+ params: dict = None,
1034
+ dropout_rng: jax.random.PRNGKey = None,
1035
+ train: bool = False,
1036
+ output_attentions: Optional[bool] = None,
1037
+ output_hidden_states: Optional[bool] = None,
1038
+ return_dict: Optional[bool] = None,
1039
+ past_key_values: dict = None,
1040
+ ):
1041
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1042
+ output_hidden_states = (
1043
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1044
+ )
1045
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
1046
+
1047
+ # init input tensors if not passed
1048
+ if token_type_ids is None:
1049
+ token_type_ids = jnp.zeros_like(input_ids)
1050
+
1051
+ if position_ids is None:
1052
+ position_ids = create_position_ids_from_input_ids(input_ids, self.config.pad_token_id)
1053
+
1054
+ if attention_mask is None:
1055
+ attention_mask = jnp.ones_like(input_ids)
1056
+
1057
+ if head_mask is None:
1058
+ head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))
1059
+
1060
+ # Handle any PRNG if needed
1061
+ rngs = {}
1062
+ if dropout_rng is not None:
1063
+ rngs["dropout"] = dropout_rng
1064
+
1065
+ inputs = {"params": params or self.params}
1066
+
1067
+ if self.config.add_cross_attention:
1068
+ # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed
1069
+ # down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be
1070
+ # changed by FlaxRobertaAttention module
1071
+ if past_key_values:
1072
+ inputs["cache"] = past_key_values
1073
+ mutable = ["cache"]
1074
+ else:
1075
+ mutable = False
1076
+
1077
+ outputs = self.module.apply(
1078
+ inputs,
1079
+ jnp.array(input_ids, dtype="i4"),
1080
+ jnp.array(attention_mask, dtype="i4"),
1081
+ token_type_ids=jnp.array(token_type_ids, dtype="i4"),
1082
+ position_ids=jnp.array(position_ids, dtype="i4"),
1083
+ head_mask=jnp.array(head_mask, dtype="i4"),
1084
+ encoder_hidden_states=encoder_hidden_states,
1085
+ encoder_attention_mask=encoder_attention_mask,
1086
+ deterministic=not train,
1087
+ output_attentions=output_attentions,
1088
+ output_hidden_states=output_hidden_states,
1089
+ return_dict=return_dict,
1090
+ rngs=rngs,
1091
+ mutable=mutable,
1092
+ )
1093
+
1094
+ # add updated cache to model output
1095
+ if past_key_values is not None and return_dict:
1096
+ outputs, past_key_values = outputs
1097
+ outputs["past_key_values"] = unfreeze(past_key_values["cache"])
1098
+ return outputs
1099
+ elif past_key_values is not None and not return_dict:
1100
+ outputs, past_key_values = outputs
1101
+ outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]
1102
+
1103
+ else:
1104
+ outputs = self.module.apply(
1105
+ inputs,
1106
+ jnp.array(input_ids, dtype="i4"),
1107
+ jnp.array(attention_mask, dtype="i4"),
1108
+ token_type_ids=jnp.array(token_type_ids, dtype="i4"),
1109
+ position_ids=jnp.array(position_ids, dtype="i4"),
1110
+ head_mask=jnp.array(head_mask, dtype="i4"),
1111
+ deterministic=not train,
1112
+ output_attentions=output_attentions,
1113
+ output_hidden_states=output_hidden_states,
1114
+ return_dict=return_dict,
1115
+ rngs=rngs,
1116
+ )
1117
+
1118
+ return outputs
1119
+
1120
+
1121
+ # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertModule with Bert->Roberta
1122
+ class FlaxRobertaModule(nn.Module):
1123
+ config: RobertaConfig
1124
+ dtype: jnp.dtype = jnp.float32 # the dtype of the computation
1125
+ add_pooling_layer: bool = True
1126
+ gradient_checkpointing: bool = False
1127
+
1128
+ def setup(self):
1129
+ self.embeddings = FlaxRobertaEmbeddings(self.config, dtype=self.dtype)
1130
+ self.encoder = FlaxRobertaEncoder(
1131
+ self.config,
1132
+ dtype=self.dtype,
1133
+ gradient_checkpointing=self.gradient_checkpointing,
1134
+ )
1135
+ self.pooler = FlaxRobertaPooler(self.config, dtype=self.dtype)
1136
+
1137
+ def __call__(
1138
+ self,
1139
+ input_ids,
1140
+ attention_mask,
1141
+ token_type_ids: Optional[jnp.ndarray] = None,
1142
+ position_ids: Optional[jnp.ndarray] = None,
1143
+ head_mask: Optional[jnp.ndarray] = None,
1144
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
1145
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
1146
+ init_cache: bool = False,
1147
+ deterministic: bool = True,
1148
+ output_attentions: bool = False,
1149
+ output_hidden_states: bool = False,
1150
+ return_dict: bool = True,
1151
+ ):
1152
+ # make sure `token_type_ids` is correctly initialized when not passed
1153
+ if token_type_ids is None:
1154
+ token_type_ids = jnp.zeros_like(input_ids)
1155
+
1156
+ # make sure `position_ids` is correctly initialized when not passed
1157
+ if position_ids is None:
1158
+ position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
1159
+
1160
+ hidden_states = self.embeddings(
1161
+ input_ids, token_type_ids, position_ids, attention_mask, deterministic=deterministic
1162
+ )
1163
+ outputs = self.encoder(
1164
+ hidden_states,
1165
+ attention_mask,
1166
+ head_mask=head_mask,
1167
+ deterministic=deterministic,
1168
+ encoder_hidden_states=encoder_hidden_states,
1169
+ encoder_attention_mask=encoder_attention_mask,
1170
+ init_cache=init_cache,
1171
+ output_attentions=output_attentions,
1172
+ output_hidden_states=output_hidden_states,
1173
+ return_dict=return_dict,
1174
+ )
1175
+ hidden_states = outputs[0]
1176
+ pooled = self.pooler(hidden_states) if self.add_pooling_layer else None
1177
+
1178
+ if not return_dict:
1179
+ # if pooled is None, don't return it
1180
+ if pooled is None:
1181
+ return (hidden_states,) + outputs[1:]
1182
+ return (hidden_states, pooled) + outputs[1:]
1183
+
1184
+ return FlaxBaseModelOutputWithPoolingAndCrossAttentions(
1185
+ last_hidden_state=hidden_states,
1186
+ pooler_output=pooled,
1187
+ hidden_states=outputs.hidden_states,
1188
+ attentions=outputs.attentions,
1189
+ cross_attentions=outputs.cross_attentions,
1190
+ )
1191
+
1192
+
1193
+ @add_start_docstrings(
1194
+ "The bare RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
1195
+ ROBERTA_START_DOCSTRING,
1196
+ )
1197
+ class FlaxRobertaModel(FlaxRobertaPreTrainedModel):
1198
+ module_class = FlaxRobertaModule
1199
+
1200
+
1201
+ append_call_sample_docstring(FlaxRobertaModel, _CHECKPOINT_FOR_DOC, FlaxBaseModelOutputWithPooling, _CONFIG_FOR_DOC)
1202
+
1203
+
1204
+ class FlaxRobertaForMaskedLMModule(nn.Module):
1205
+ config: RobertaConfig
1206
+ dtype: jnp.dtype = jnp.float32
1207
+ gradient_checkpointing: bool = False
1208
+
1209
+ def setup(self):
1210
+ self.roberta = FlaxRobertaModule(
1211
+ config=self.config,
1212
+ add_pooling_layer=False,
1213
+ dtype=self.dtype,
1214
+ gradient_checkpointing=self.gradient_checkpointing,
1215
+ )
1216
+ self.lm_head = FlaxRobertaLMHead(config=self.config, dtype=self.dtype)
1217
+
1218
+ def __call__(
1219
+ self,
1220
+ input_ids,
1221
+ attention_mask,
1222
+ token_type_ids,
1223
+ position_ids,
1224
+ head_mask,
1225
+ deterministic: bool = True,
1226
+ output_attentions: bool = False,
1227
+ output_hidden_states: bool = False,
1228
+ return_dict: bool = True,
1229
+ ):
1230
+ # Model
1231
+ outputs = self.roberta(
1232
+ input_ids,
1233
+ attention_mask,
1234
+ token_type_ids,
1235
+ position_ids,
1236
+ head_mask,
1237
+ deterministic=deterministic,
1238
+ output_attentions=output_attentions,
1239
+ output_hidden_states=output_hidden_states,
1240
+ return_dict=return_dict,
1241
+ )
1242
+
1243
+ hidden_states = outputs[0]
1244
+ if self.config.tie_word_embeddings:
1245
+ shared_embedding = self.roberta.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
1246
+ else:
1247
+ shared_embedding = None
1248
+
1249
+ # Compute the prediction scores
1250
+ logits = self.lm_head(hidden_states, shared_embedding=shared_embedding)
1251
+
1252
+ if not return_dict:
1253
+ return (logits,) + outputs[1:]
1254
+
1255
+ return FlaxMaskedLMOutput(
1256
+ logits=logits,
1257
+ hidden_states=outputs.hidden_states,
1258
+ attentions=outputs.attentions,
1259
+ )
1260
+
1261
+
1262
+ @add_start_docstrings("""RoBERTa Model with a `language modeling` head on top.""", ROBERTA_START_DOCSTRING)
1263
+ class FlaxRobertaForMaskedLM(FlaxRobertaPreTrainedModel):
1264
+ module_class = FlaxRobertaForMaskedLMModule
1265
+
1266
+
1267
+ append_call_sample_docstring(
1268
+ FlaxRobertaForMaskedLM,
1269
+ _CHECKPOINT_FOR_DOC,
1270
+ FlaxBaseModelOutputWithPooling,
1271
+ _CONFIG_FOR_DOC,
1272
+ mask="<mask>",
1273
+ )
1274
+
1275
+
1276
+ class FlaxRobertaForSequenceClassificationModule(nn.Module):
1277
+ config: RobertaConfig
1278
+ dtype: jnp.dtype = jnp.float32
1279
+ gradient_checkpointing: bool = False
1280
+
1281
+ def setup(self):
1282
+ self.roberta = FlaxRobertaModule(
1283
+ config=self.config,
1284
+ dtype=self.dtype,
1285
+ add_pooling_layer=False,
1286
+ gradient_checkpointing=self.gradient_checkpointing,
1287
+ )
1288
+ self.classifier = FlaxRobertaClassificationHead(config=self.config, dtype=self.dtype)
1289
+
1290
+ def __call__(
1291
+ self,
1292
+ input_ids,
1293
+ attention_mask,
1294
+ token_type_ids,
1295
+ position_ids,
1296
+ head_mask,
1297
+ deterministic: bool = True,
1298
+ output_attentions: bool = False,
1299
+ output_hidden_states: bool = False,
1300
+ return_dict: bool = True,
1301
+ ):
1302
+ # Model
1303
+ outputs = self.roberta(
1304
+ input_ids,
1305
+ attention_mask,
1306
+ token_type_ids,
1307
+ position_ids,
1308
+ head_mask,
1309
+ deterministic=deterministic,
1310
+ output_attentions=output_attentions,
1311
+ output_hidden_states=output_hidden_states,
1312
+ return_dict=return_dict,
1313
+ )
1314
+
1315
+ sequence_output = outputs[0]
1316
+ logits = self.classifier(sequence_output, deterministic=deterministic)
1317
+
1318
+ if not return_dict:
1319
+ return (logits,) + outputs[1:]
1320
+
1321
+ return FlaxSequenceClassifierOutput(
1322
+ logits=logits,
1323
+ hidden_states=outputs.hidden_states,
1324
+ attentions=outputs.attentions,
1325
+ )
1326
+
1327
+
1328
+ @add_start_docstrings(
1329
+ """
1330
+ Roberta Model transformer with a sequence classification/regression head on top (a linear layer on top of the
1331
+ pooled output) e.g. for GLUE tasks.
1332
+ """,
1333
+ ROBERTA_START_DOCSTRING,
1334
+ )
1335
+ class FlaxRobertaForSequenceClassification(FlaxRobertaPreTrainedModel):
1336
+ module_class = FlaxRobertaForSequenceClassificationModule
1337
+
1338
+
1339
+ append_call_sample_docstring(
1340
+ FlaxRobertaForSequenceClassification,
1341
+ _CHECKPOINT_FOR_DOC,
1342
+ FlaxSequenceClassifierOutput,
1343
+ _CONFIG_FOR_DOC,
1344
+ )
1345
+
1346
+
1347
+ # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForMultipleChoiceModule with Bert->Roberta, with self.bert->self.roberta
1348
+ class FlaxRobertaForMultipleChoiceModule(nn.Module):
1349
+ config: RobertaConfig
1350
+ dtype: jnp.dtype = jnp.float32
1351
+ gradient_checkpointing: bool = False
1352
+
1353
+ def setup(self):
1354
+ self.roberta = FlaxRobertaModule(
1355
+ config=self.config,
1356
+ dtype=self.dtype,
1357
+ gradient_checkpointing=self.gradient_checkpointing,
1358
+ )
1359
+ self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
1360
+ self.classifier = nn.Dense(1, dtype=self.dtype)
1361
+
1362
+ def __call__(
1363
+ self,
1364
+ input_ids,
1365
+ attention_mask,
1366
+ token_type_ids,
1367
+ position_ids,
1368
+ head_mask,
1369
+ deterministic: bool = True,
1370
+ output_attentions: bool = False,
1371
+ output_hidden_states: bool = False,
1372
+ return_dict: bool = True,
1373
+ ):
1374
+ num_choices = input_ids.shape[1]
1375
+ input_ids = input_ids.reshape(-1, input_ids.shape[-1]) if input_ids is not None else None
1376
+ attention_mask = attention_mask.reshape(-1, attention_mask.shape[-1]) if attention_mask is not None else None
1377
+ token_type_ids = token_type_ids.reshape(-1, token_type_ids.shape[-1]) if token_type_ids is not None else None
1378
+ position_ids = position_ids.reshape(-1, position_ids.shape[-1]) if position_ids is not None else None
1379
+
1380
+ # Model
1381
+ outputs = self.roberta(
1382
+ input_ids,
1383
+ attention_mask,
1384
+ token_type_ids,
1385
+ position_ids,
1386
+ head_mask,
1387
+ deterministic=deterministic,
1388
+ output_attentions=output_attentions,
1389
+ output_hidden_states=output_hidden_states,
1390
+ return_dict=return_dict,
1391
+ )
1392
+
1393
+ pooled_output = outputs[1]
1394
+ pooled_output = self.dropout(pooled_output, deterministic=deterministic)
1395
+ logits = self.classifier(pooled_output)
1396
+
1397
+ reshaped_logits = logits.reshape(-1, num_choices)
1398
+
1399
+ if not return_dict:
1400
+ return (reshaped_logits,) + outputs[2:]
1401
+
1402
+ return FlaxMultipleChoiceModelOutput(
1403
+ logits=reshaped_logits,
1404
+ hidden_states=outputs.hidden_states,
1405
+ attentions=outputs.attentions,
1406
+ )
1407
+
1408
+
1409
+ @add_start_docstrings(
1410
+ """
1411
+ Roberta Model with a multiple choice classification head on top (a linear layer on top of the pooled output and a
1412
+ softmax) e.g. for RocStories/SWAG tasks.
1413
+ """,
1414
+ ROBERTA_START_DOCSTRING,
1415
+ )
1416
+ class FlaxRobertaForMultipleChoice(FlaxRobertaPreTrainedModel):
1417
+ module_class = FlaxRobertaForMultipleChoiceModule
1418
+
1419
+
1420
+ overwrite_call_docstring(
1421
+ FlaxRobertaForMultipleChoice, ROBERTA_INPUTS_DOCSTRING.format("batch_size, num_choices, sequence_length")
1422
+ )
1423
+ append_call_sample_docstring(
1424
+ FlaxRobertaForMultipleChoice,
1425
+ _CHECKPOINT_FOR_DOC,
1426
+ FlaxMultipleChoiceModelOutput,
1427
+ _CONFIG_FOR_DOC,
1428
+ )
1429
+
1430
+
1431
+ # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForTokenClassificationModule with Bert->Roberta, with self.bert->self.roberta
1432
+ class FlaxRobertaForTokenClassificationModule(nn.Module):
1433
+ config: RobertaConfig
1434
+ dtype: jnp.dtype = jnp.float32
1435
+ gradient_checkpointing: bool = False
1436
+
1437
+ def setup(self):
1438
+ self.roberta = FlaxRobertaModule(
1439
+ config=self.config,
1440
+ dtype=self.dtype,
1441
+ add_pooling_layer=False,
1442
+ gradient_checkpointing=self.gradient_checkpointing,
1443
+ )
1444
+ classifier_dropout = (
1445
+ self.config.classifier_dropout
1446
+ if self.config.classifier_dropout is not None
1447
+ else self.config.hidden_dropout_prob
1448
+ )
1449
+ self.dropout = nn.Dropout(rate=classifier_dropout)
1450
+ self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)
1451
+
1452
+ def __call__(
1453
+ self,
1454
+ input_ids,
1455
+ attention_mask,
1456
+ token_type_ids,
1457
+ position_ids,
1458
+ head_mask,
1459
+ deterministic: bool = True,
1460
+ output_attentions: bool = False,
1461
+ output_hidden_states: bool = False,
1462
+ return_dict: bool = True,
1463
+ ):
1464
+ # Model
1465
+ outputs = self.roberta(
1466
+ input_ids,
1467
+ attention_mask,
1468
+ token_type_ids,
1469
+ position_ids,
1470
+ head_mask,
1471
+ deterministic=deterministic,
1472
+ output_attentions=output_attentions,
1473
+ output_hidden_states=output_hidden_states,
1474
+ return_dict=return_dict,
1475
+ )
1476
+
1477
+ hidden_states = outputs[0]
1478
+ hidden_states = self.dropout(hidden_states, deterministic=deterministic)
1479
+ logits = self.classifier(hidden_states)
1480
+
1481
+ if not return_dict:
1482
+ return (logits,) + outputs[1:]
1483
+
1484
+ return FlaxTokenClassifierOutput(
1485
+ logits=logits,
1486
+ hidden_states=outputs.hidden_states,
1487
+ attentions=outputs.attentions,
1488
+ )
1489
+
1490
+
1491
+ @add_start_docstrings(
1492
+ """
1493
+ Roberta Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for
1494
+ Named-Entity-Recognition (NER) tasks.
1495
+ """,
1496
+ ROBERTA_START_DOCSTRING,
1497
+ )
1498
+ class FlaxRobertaForTokenClassification(FlaxRobertaPreTrainedModel):
1499
+ module_class = FlaxRobertaForTokenClassificationModule
1500
+
1501
+
1502
+ append_call_sample_docstring(
1503
+ FlaxRobertaForTokenClassification,
1504
+ _CHECKPOINT_FOR_DOC,
1505
+ FlaxTokenClassifierOutput,
1506
+ _CONFIG_FOR_DOC,
1507
+ )
1508
+
1509
+
1510
+ # Copied from transformers.models.bert.modeling_flax_bert.FlaxBertForQuestionAnsweringModule with Bert->Roberta, with self.bert->self.roberta
1511
+ class FlaxRobertaForQuestionAnsweringModule(nn.Module):
1512
+ config: RobertaConfig
1513
+ dtype: jnp.dtype = jnp.float32
1514
+ gradient_checkpointing: bool = False
1515
+
1516
+ def setup(self):
1517
+ self.roberta = FlaxRobertaModule(
1518
+ config=self.config,
1519
+ dtype=self.dtype,
1520
+ add_pooling_layer=False,
1521
+ gradient_checkpointing=self.gradient_checkpointing,
1522
+ )
1523
+ self.qa_outputs = nn.Dense(self.config.num_labels, dtype=self.dtype)
1524
+
1525
+ def __call__(
1526
+ self,
1527
+ input_ids,
1528
+ attention_mask,
1529
+ token_type_ids,
1530
+ position_ids,
1531
+ head_mask,
1532
+ deterministic: bool = True,
1533
+ output_attentions: bool = False,
1534
+ output_hidden_states: bool = False,
1535
+ return_dict: bool = True,
1536
+ ):
1537
+ # Model
1538
+ outputs = self.roberta(
1539
+ input_ids,
1540
+ attention_mask,
1541
+ token_type_ids,
1542
+ position_ids,
1543
+ head_mask,
1544
+ deterministic=deterministic,
1545
+ output_attentions=output_attentions,
1546
+ output_hidden_states=output_hidden_states,
1547
+ return_dict=return_dict,
1548
+ )
1549
+
1550
+ hidden_states = outputs[0]
1551
+
1552
+ logits = self.qa_outputs(hidden_states)
1553
+ start_logits, end_logits = logits.split(self.config.num_labels, axis=-1)
1554
+ start_logits = start_logits.squeeze(-1)
1555
+ end_logits = end_logits.squeeze(-1)
1556
+
1557
+ if not return_dict:
1558
+ return (start_logits, end_logits) + outputs[1:]
1559
+
1560
+ return FlaxQuestionAnsweringModelOutput(
1561
+ start_logits=start_logits,
1562
+ end_logits=end_logits,
1563
+ hidden_states=outputs.hidden_states,
1564
+ attentions=outputs.attentions,
1565
+ )
1566
+
1567
+
1568
+ @add_start_docstrings(
1569
+ """
1570
+ Roberta Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
1571
+ layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
1572
+ """,
1573
+ ROBERTA_START_DOCSTRING,
1574
+ )
1575
+ class FlaxRobertaForQuestionAnswering(FlaxRobertaPreTrainedModel):
1576
+ module_class = FlaxRobertaForQuestionAnsweringModule
1577
+
1578
+
1579
+ append_call_sample_docstring(
1580
+ FlaxRobertaForQuestionAnswering,
1581
+ _CHECKPOINT_FOR_DOC,
1582
+ FlaxQuestionAnsweringModelOutput,
1583
+ _CONFIG_FOR_DOC,
1584
+ )
1585
+
1586
+
1587
+ class FlaxRobertaForCausalLMModule(nn.Module):
1588
+ config: RobertaConfig
1589
+ dtype: jnp.dtype = jnp.float32
1590
+ gradient_checkpointing: bool = False
1591
+
1592
+ def setup(self):
1593
+ self.roberta = FlaxRobertaModule(
1594
+ config=self.config,
1595
+ add_pooling_layer=False,
1596
+ dtype=self.dtype,
1597
+ gradient_checkpointing=self.gradient_checkpointing,
1598
+ )
1599
+ self.lm_head = FlaxRobertaLMHead(config=self.config, dtype=self.dtype)
1600
+
1601
+ def __call__(
1602
+ self,
1603
+ input_ids,
1604
+ attention_mask,
1605
+ position_ids,
1606
+ token_type_ids: Optional[jnp.ndarray] = None,
1607
+ head_mask: Optional[jnp.ndarray] = None,
1608
+ encoder_hidden_states: Optional[jnp.ndarray] = None,
1609
+ encoder_attention_mask: Optional[jnp.ndarray] = None,
1610
+ init_cache: bool = False,
1611
+ deterministic: bool = True,
1612
+ output_attentions: bool = False,
1613
+ output_hidden_states: bool = False,
1614
+ return_dict: bool = True,
1615
+ ):
1616
+ # Model
1617
+ outputs = self.roberta(
1618
+ input_ids,
1619
+ attention_mask,
1620
+ token_type_ids,
1621
+ position_ids,
1622
+ head_mask,
1623
+ encoder_hidden_states=encoder_hidden_states,
1624
+ encoder_attention_mask=encoder_attention_mask,
1625
+ init_cache=init_cache,
1626
+ deterministic=deterministic,
1627
+ output_attentions=output_attentions,
1628
+ output_hidden_states=output_hidden_states,
1629
+ return_dict=return_dict,
1630
+ )
1631
+
1632
+ hidden_states = outputs[0]
1633
+ if self.config.tie_word_embeddings:
1634
+ shared_embedding = self.roberta.variables["params"]["embeddings"]["word_embeddings"]["embedding"]
1635
+ else:
1636
+ shared_embedding = None
1637
+
1638
+ # Compute the prediction scores
1639
+ logits = self.lm_head(hidden_states, shared_embedding=shared_embedding)
1640
+
1641
+ if not return_dict:
1642
+ return (logits,) + outputs[1:]
1643
+
1644
+ return FlaxCausalLMOutputWithCrossAttentions(
1645
+ logits=logits,
1646
+ hidden_states=outputs.hidden_states,
1647
+ attentions=outputs.attentions,
1648
+ cross_attentions=outputs.cross_attentions,
1649
+ )
1650
+
1651
+
1652
+ @add_start_docstrings(
1653
+ """
1654
+ Roberta Model with a language modeling head on top (a linear layer on top of the hidden-states output) e.g for
1655
+ autoregressive tasks.
1656
+ """,
1657
+ ROBERTA_START_DOCSTRING,
1658
+ )
1659
+ class FlaxRobertaForCausalLM(FlaxRobertaPreTrainedModel):
1660
+ module_class = FlaxRobertaForCausalLMModule
1661
+
1662
+ def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
1663
+ # initializing the cache
1664
+ batch_size, seq_length = input_ids.shape
1665
+
1666
+ past_key_values = self.init_cache(batch_size, max_length)
1667
+ # Note that usually one would have to put 0's in the attention_mask for x > input_ids.shape[-1] and x < cache_length.
1668
+ # But since the decoder uses a causal mask, those positions are masked anyway.
1669
+ # Thus, we can create a single static attention_mask here, which is more efficient for compilation
1670
+ extended_attention_mask = jnp.ones((batch_size, max_length), dtype="i4")
1671
+ if attention_mask is not None:
1672
+ position_ids = attention_mask.cumsum(axis=-1) - 1
1673
+ extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, attention_mask, (0, 0))
1674
+ else:
1675
+ position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))
1676
+
1677
+ return {
1678
+ "past_key_values": past_key_values,
1679
+ "attention_mask": extended_attention_mask,
1680
+ "position_ids": position_ids,
1681
+ }
1682
+
1683
+ def update_inputs_for_generation(self, model_outputs, model_kwargs):
1684
+ model_kwargs["past_key_values"] = model_outputs.past_key_values
1685
+ model_kwargs["position_ids"] = model_kwargs["position_ids"][:, -1:] + 1
1686
+ return model_kwargs
1687
+
1688
+
1689
+ append_call_sample_docstring(
1690
+ FlaxRobertaForCausalLM,
1691
+ _CHECKPOINT_FOR_DOC,
1692
+ FlaxCausalLMOutputWithCrossAttentions,
1693
+ _CONFIG_FOR_DOC,
1694
+ )
EasyLM/models/roberta/roberta_train.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ import pprint
3
+ from functools import partial
4
+ import re
5
+
6
+ from tqdm import tqdm, trange
7
+ import numpy as np
8
+ import mlxu
9
+
10
+ import jax
11
+ import jax.numpy as jnp
12
+ from jax.experimental.pjit import pjit, with_sharding_constraint
13
+ from jax.sharding import PartitionSpec as PS
14
+ from flax.training.train_state import TrainState
15
+
16
+ from EasyLM.data import DatasetFactory
17
+ from EasyLM.checkpoint import StreamingCheckpointer
18
+ from EasyLM.optimizers import OptimizerFactory
19
+ from EasyLM.jax_utils import (
20
+ JaxRNG, next_rng, match_partition_rules, get_float_dtype_by_name,
21
+ cross_entropy_loss_and_accuracy, named_tree_map, global_norm,
22
+ set_random_seed, average_metrics, get_weight_decay_mask,
23
+ make_shard_and_gather_fns, tree_apply
24
+ )
25
+ from EasyLM.models.roberta.roberta_model import (
26
+ RobertaConfig, FlaxRobertaForMaskedLMModule
27
+ )
28
+
29
+
30
+ FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
31
+ seed=42,
32
+ initialize_jax_distributed=False,
33
+ mesh_dim='-1,1,1',
34
+ dtype='fp32',
35
+ mask_token_probability=0.15,
36
+ total_steps=10000,
37
+ load_roberta_config='',
38
+ update_roberta_config='',
39
+ load_checkpoint='',
40
+ load_dataset_state='',
41
+ log_freq=50,
42
+ save_model_freq=0,
43
+ save_milestone_freq=0,
44
+ eval_steps=0,
45
+ tokenizer=RobertaConfig.get_tokenizer_config(),
46
+ train_dataset=DatasetFactory.get_default_config(),
47
+ eval_dataset=DatasetFactory.get_default_config(),
48
+ optimizer=OptimizerFactory.get_default_config(),
49
+ checkpointer=StreamingCheckpointer.get_default_config(),
50
+ roberta=RobertaConfig.get_default_config(),
51
+ logger=mlxu.WandBLogger.get_default_config(),
52
+ log_all_worker=False,
53
+ )
54
+
55
+
56
+ def main(argv):
57
+ if FLAGS.initialize_jax_distributed:
58
+ jax.distributed.initialize()
59
+
60
+ variant = mlxu.get_user_flags(FLAGS, FLAGS_DEF)
61
+ flags_config_dict = mlxu.user_flags_to_config_dict(FLAGS, FLAGS_DEF)
62
+ logger = mlxu.WandBLogger(
63
+ config=FLAGS.logger,
64
+ variant=variant,
65
+ enable=FLAGS.log_all_worker or (jax.process_index() == 0),
66
+ )
67
+ set_random_seed(FLAGS.seed)
68
+
69
+ tokenizer = RobertaConfig.get_tokenizer(FLAGS.tokenizer)
70
+ dataset = DatasetFactory.load_dataset(FLAGS.train_dataset, tokenizer)
71
+ if FLAGS.load_dataset_state != '':
72
+ dataset.load_state_dict(mlxu.load_pickle(FLAGS.load_dataset_state))
73
+
74
+ if FLAGS.eval_steps > 0:
75
+ eval_dataset = DatasetFactory.load_dataset(
76
+ FLAGS.eval_dataset, dataset.tokenizer
77
+ )
78
+ eval_iterator = iter(eval_dataset)
79
+
80
+ seq_length = dataset.seq_length
81
+
82
+ if FLAGS.load_roberta_config != '':
83
+ roberta_config = RobertaConfig.load_config(FLAGS.load_roberta_config)
84
+ else:
85
+ roberta_config = RobertaConfig(**FLAGS.roberta)
86
+
87
+ if FLAGS.update_roberta_config != '':
88
+ roberta_config.update(dict(eval(FLAGS.update_roberta_config)))
89
+
90
+ roberta_config.update(dict(
91
+ bos_token_id=dataset.tokenizer.bos_token_id,
92
+ eos_token_id=dataset.tokenizer.eos_token_id,
93
+ pad_token_id=dataset.tokenizer.pad_token_id,
94
+ vocab_size=dataset.vocab_size,
95
+ ))
96
+
97
+ model = FlaxRobertaForMaskedLMModule(
98
+ roberta_config, dtype=get_float_dtype_by_name(FLAGS.dtype)
99
+ )
100
+
101
+ optimizer, optimizer_info = OptimizerFactory.get_optimizer(
102
+ FLAGS.optimizer,
103
+ get_weight_decay_mask(RobertaConfig.get_weight_decay_exclusions()),
104
+ )
105
+
106
+ def create_trainstate_from_params(params):
107
+ return TrainState.create(params=params, tx=optimizer, apply_fn=None)
108
+
109
+ def init_fn(rng):
110
+ rng_generator = JaxRNG(rng)
111
+ params = model.init(
112
+ input_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
113
+ position_ids=jnp.zeros((4, seq_length), dtype=jnp.int32),
114
+ attention_mask=jnp.ones((4, seq_length), dtype=jnp.int32),
115
+ token_type_ids=None,
116
+ head_mask=None,
117
+ rngs=rng_generator(roberta_config.rng_keys()),
118
+ )
119
+ return TrainState.create(params=params, tx=optimizer, apply_fn=None)
120
+
121
+ def train_step(train_state, rng, batch):
122
+ rng_generator = JaxRNG(rng)
123
+ tokens = with_sharding_constraint(batch['target_tokens'], PS(('dp', 'fsdp')))
124
+ def loss_and_accuracy(params):
125
+ altered_tokens = jax.random.uniform(
126
+ rng_generator(), shape=tokens.shape
127
+ ) < FLAGS.mask_token_probability
128
+ random_uniform = jax.random.uniform(rng_generator(), shape=tokens.shape)
129
+ altered_by_mask = altered_tokens & (random_uniform < 0.8)
130
+ altered_by_random = altered_tokens & (random_uniform >= 0.8) & (random_uniform < 0.9)
131
+ inputs = jnp.where(altered_by_mask, dataset.tokenizer.mask_token_id, tokens)
132
+ random_tokens = jax.random.randint(
133
+ rng_generator(), shape=tokens.shape, minval=0, maxval=dataset.vocab_size
134
+ )
135
+ inputs = jnp.where(altered_by_random, random_tokens, inputs)
136
+ logits = model.apply(
137
+ params, inputs,
138
+ attention_mask=jnp.ones_like(inputs),
139
+ token_type_ids=None,
140
+ position_ids=None,
141
+ head_mask=None,
142
+ deterministic=False,
143
+ rngs=rng_generator(roberta_config.rng_keys()),
144
+ ).logits
145
+ return cross_entropy_loss_and_accuracy(logits, tokens, valid=altered_tokens)
146
+ grad_fn = jax.value_and_grad(loss_and_accuracy, has_aux=True)
147
+ (loss, accuracy), grads = grad_fn(train_state.params)
148
+ train_state = train_state.apply_gradients(grads=grads)
149
+ metrics = dict(
150
+ loss=loss,
151
+ accuracy=accuracy,
152
+ learning_rate=optimizer_info['learning_rate_schedule'](train_state.step),
153
+ gradient_norm=global_norm(grads),
154
+ param_norm=global_norm(train_state.params),
155
+ )
156
+ return train_state, rng_generator(), metrics
157
+
158
+ def eval_step(train_state, rng, batch):
159
+ rng_generator = JaxRNG(rng)
160
+ tokens = with_sharding_constraint(batch['target_tokens'], PS(('dp', 'fsdp')))
161
+ altered_tokens = jax.random.uniform(
162
+ rng_generator(), shape=tokens.shape
163
+ ) < FLAGS.mask_token_probability
164
+ random_uniform = jax.random.uniform(rng_generator(), shape=tokens.shape)
165
+ altered_by_mask = altered_tokens & (random_uniform < 0.8)
166
+ altered_by_random = altered_tokens & (random_uniform >= 0.8) & (random_uniform < 0.9)
167
+ inputs = jnp.where(altered_by_mask, dataset.tokenizer.mask_token_id, tokens)
168
+ random_tokens = jax.random.randint(
169
+ rng_generator(), shape=tokens.shape, minval=0, maxval=dataset.vocab_size
170
+ )
171
+ inputs = jnp.where(altered_by_random, random_tokens, inputs)
172
+ logits = model.apply(
173
+ train_state.params, inputs,
174
+ attention_mask=jnp.ones_like(inputs),
175
+ token_type_ids=None,
176
+ position_ids=None,
177
+ head_mask=None,
178
+ deterministic=False,
179
+ rngs=rng_generator(roberta_config.rng_keys()),
180
+ ).logits
181
+ loss, accuracy = cross_entropy_loss_and_accuracy(logits, tokens, valid=altered_tokens)
182
+ metrics = dict(
183
+ eval_loss=loss,
184
+ eval_accuracy=accuracy,
185
+ )
186
+ return rng_generator(), metrics
187
+
188
+ train_state_shapes = jax.eval_shape(init_fn, next_rng())
189
+ train_state_partition = match_partition_rules(
190
+ RobertaConfig.get_partition_rules(), train_state_shapes
191
+ )
192
+
193
+ shard_fns, gather_fns = make_shard_and_gather_fns(
194
+ train_state_partition, train_state_shapes
195
+ )
196
+ checkpointer = StreamingCheckpointer(
197
+ FLAGS.checkpointer, logger.output_dir,
198
+ enable=jax.process_index() == 0
199
+ )
200
+
201
+ sharded_init_fn = pjit(
202
+ init_fn,
203
+ in_shardings=PS(),
204
+ out_shardings=train_state_partition
205
+ )
206
+
207
+ sharded_create_trainstate_from_params = pjit(
208
+ create_trainstate_from_params,
209
+ in_shardings=(train_state_partition.params, ),
210
+ out_shardings=train_state_partition,
211
+ donate_argnums=(0, ),
212
+ )
213
+
214
+ sharded_train_step = pjit(
215
+ train_step,
216
+ in_shardings=(train_state_partition, PS(), PS()),
217
+ out_shardings=(train_state_partition, PS(), PS()),
218
+ donate_argnums=(0, 1),
219
+ )
220
+
221
+ sharded_eval_step = pjit(
222
+ eval_step,
223
+ in_shardings=(train_state_partition, PS(), PS()),
224
+ out_shardings=(PS(), PS()),
225
+ donate_argnums=(1,),
226
+ )
227
+
228
+ def save_checkpoint(train_state, milestone=False):
229
+ step = int(jax.device_get(train_state.step))
230
+ metadata = dict(
231
+ step=step,
232
+ variant=variant,
233
+ flags=flags_config_dict,
234
+ roberta_config=roberta_config.to_dict(),
235
+ )
236
+ checkpointer.save_all(
237
+ train_state=train_state,
238
+ gather_fns=gather_fns,
239
+ metadata=metadata,
240
+ dataset=dataset.get_state_dict(),
241
+ milestone=milestone,
242
+ )
243
+
244
+ mesh = RobertaConfig.get_jax_mesh(FLAGS.mesh_dim)
245
+ with mesh:
246
+ train_state, restored_params = None, None
247
+ if FLAGS.load_checkpoint != '':
248
+ load_type, load_path = FLAGS.load_checkpoint.split('::', 1)
249
+ if load_type == 'huggingface':
250
+ restored_params = tree_apply(
251
+ shard_fns.params, roberta_config.load_pretrained(load_path)
252
+ )
253
+ train_state = None
254
+ else:
255
+ train_state, restored_params = checkpointer.load_trainstate_checkpoint(
256
+ FLAGS.load_checkpoint, train_state_shapes, shard_fns
257
+ )
258
+
259
+ if train_state is None and restored_params is None:
260
+ # Initialize from scratch
261
+ train_state = sharded_init_fn(next_rng())
262
+ elif train_state is None and restored_params is not None:
263
+ # Restore from params but initialize train_state
264
+ train_state = sharded_create_trainstate_from_params(restored_params)
265
+ del restored_params
266
+
267
+ start_step = int(jax.device_get(train_state.step))
268
+
269
+ if FLAGS.save_model_freq > 0:
270
+ save_checkpoint(train_state)
271
+
272
+ sharded_rng = next_rng()
273
+
274
+ step_counter = trange(start_step, FLAGS.total_steps, ncols=0)
275
+
276
+ for step, (batch, dataset_metrics) in zip(step_counter, dataset):
277
+ train_state, sharded_rng, metrics = sharded_train_step(
278
+ train_state, sharded_rng, batch
279
+ )
280
+
281
+ if step % FLAGS.log_freq == 0:
282
+ if FLAGS.eval_steps > 0:
283
+ eval_metric_list = []
284
+ for _ in range(FLAGS.eval_steps):
285
+ eval_batch, _ = next(eval_iterator)
286
+ sharded_rng, eval_metrics = sharded_eval_step(
287
+ train_state, sharded_rng, eval_batch
288
+ )
289
+ eval_metric_list.append(eval_metrics)
290
+ metrics.update(average_metrics(eval_metric_list))
291
+
292
+ log_metrics = {"step": step}
293
+ log_metrics.update(metrics)
294
+ log_metrics.update(dataset_metrics)
295
+ log_metrics = jax.device_get(log_metrics)
296
+ logger.log(log_metrics)
297
+ tqdm.write("\n" + pprint.pformat(log_metrics) + "\n")
298
+
299
+ if FLAGS.save_milestone_freq > 0 and (step + 1) % FLAGS.save_milestone_freq == 0:
300
+ save_checkpoint(train_state, milestone=True)
301
+ elif FLAGS.save_model_freq > 0 and (step + 1) % FLAGS.save_model_freq == 0:
302
+ save_checkpoint(train_state)
303
+
304
+ if FLAGS.save_model_freq > 0:
305
+ save_checkpoint(train_state)
306
+
307
+
308
+ if __name__ == "__main__":
309
+ mlxu.run(main)
EasyLM/optimizers.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ from typing import Any, Mapping, Text, Tuple, Union, NamedTuple
4
+ from functools import partial
5
+ import re
6
+ import dataclasses
7
+ import random
8
+
9
+ from ml_collections.config_dict import config_dict
10
+ from ml_collections import ConfigDict
11
+ import jax
12
+ import jax.numpy as jnp
13
+ import numpy as np
14
+ from absl import logging
15
+ import optax
16
+
17
+ from EasyLM.jax_utils import float_to_dtype
18
+
19
+
20
+ class OptimizerFactory(object):
21
+ """ Configurable optax optimizer factory. """
22
+
23
+ def __init__(self):
24
+ raise NotImplementedError
25
+
26
+ @staticmethod
27
+ def get_default_config(updates=None):
28
+ config = ConfigDict()
29
+ config.accumulate_gradient_steps = 1
30
+ config.type = 'adamw'
31
+ config.palm_optimizer = PalmOptimizerFactory.get_default_config()
32
+ config.adamw_optimizer = AdamWOptimizerFactory.get_default_config()
33
+ config.lion_optimizer = LionOptimizerFactory.get_default_config()
34
+
35
+ if updates is not None:
36
+ config.update(ConfigDict(updates).copy_and_resolve_references())
37
+ return config
38
+
39
+ @classmethod
40
+ def get_optimizer(cls, config, weight_decay_mask=None):
41
+ config = cls.get_default_config(config)
42
+ if config.type == 'palm':
43
+ optimizer, optimizer_info = PalmOptimizerFactory.get_optimizer(
44
+ config.palm_optimizer, weight_decay_mask
45
+ )
46
+ elif config.type == 'adamw':
47
+ optimizer, optimizer_info = AdamWOptimizerFactory.get_optimizer(
48
+ config.adamw_optimizer, weight_decay_mask
49
+ )
50
+ elif config.type == 'lion':
51
+ optimizer, optimizer_info = LionOptimizerFactory.get_optimizer(
52
+ config.lion_optimizer, weight_decay_mask
53
+ )
54
+ else:
55
+ raise ValueError(f'Unknown optimizer type: {config.type}')
56
+
57
+ if config.accumulate_gradient_steps > 1:
58
+ optimizer = optax.MultiSteps(
59
+ optimizer, config.accumulate_gradient_steps
60
+ )
61
+
62
+ return optimizer, optimizer_info
63
+
64
+
65
+ class PalmOptimizerFactory(object):
66
+ """ PaLM optimizer factory. This optimizer implements the optimizer
67
+ described in the PaLM paper: https://arxiv.org/abs/2204.02311
68
+ """
69
+
70
+ def __init__(self):
71
+ raise NotImplementedError
72
+
73
+ @staticmethod
74
+ def get_default_config(updates=None):
75
+ config = ConfigDict()
76
+ config.lr = 0.01
77
+ config.lr_warmup_steps = 10000
78
+ config.b1 = 0.9
79
+ config.b2 = 0.99
80
+ config.clip_gradient = 1.0
81
+ config.weight_decay = 1e-4
82
+ config.bf16_momentum = False
83
+
84
+ if updates is not None:
85
+ config.update(ConfigDict(updates).copy_and_resolve_references())
86
+ return config
87
+
88
+ @classmethod
89
+ def get_optimizer(cls, config, weight_decay_mask=None):
90
+ config = cls.get_default_config(config)
91
+
92
+ def learning_rate_schedule(step):
93
+ multiplier = config.lr / 0.01
94
+ return multiplier / jnp.sqrt(jnp.maximum(step, config.lr_warmup_steps))
95
+
96
+ def weight_decay_schedule(step):
97
+ multiplier = config.weight_decay / 1e-4
98
+ return -multiplier * jnp.square(learning_rate_schedule(step))
99
+
100
+ optimizer_info = dict(
101
+ learning_rate_schedule=learning_rate_schedule,
102
+ weight_decay_schedule=weight_decay_schedule,
103
+ )
104
+
105
+ optimizer = optax.chain(
106
+ optax.clip_by_global_norm(config.clip_gradient),
107
+ optax.adafactor(
108
+ learning_rate=learning_rate_schedule,
109
+ multiply_by_parameter_scale=True,
110
+ momentum=config.b1,
111
+ decay_rate=config.b2,
112
+ factored=False,
113
+ clipping_threshold=None,
114
+ dtype_momentum=jnp.bfloat16 if config.bf16_momentum else jnp.float32,
115
+ ),
116
+ optax_add_scheduled_weight_decay(
117
+ weight_decay_schedule, weight_decay_mask
118
+ )
119
+ )
120
+ return optimizer, optimizer_info
121
+
122
+
123
+ class AdamWOptimizerFactory(object):
124
+ """ AdamW optimizer with cosine schedule. """
125
+
126
+ def __init__(self):
127
+ raise NotImplementedError
128
+
129
+ @staticmethod
130
+ def get_default_config(updates=None):
131
+ config = ConfigDict()
132
+ config.init_lr = 0.0
133
+ config.end_lr = 0.001
134
+ config.lr = 0.01
135
+ config.lr_warmup_steps = 2000
136
+ config.lr_decay_steps = 500000
137
+ config.b1 = 0.9
138
+ config.b2 = 0.95
139
+ config.clip_gradient = 1.0
140
+ config.weight_decay = 1e-4
141
+ config.bf16_momentum = False
142
+ config.multiply_by_parameter_scale = False
143
+
144
+ if updates is not None:
145
+ config.update(ConfigDict(updates).copy_and_resolve_references())
146
+ return config
147
+
148
+ @classmethod
149
+ def get_optimizer(cls, config, weight_decay_mask=None):
150
+ config = cls.get_default_config(config)
151
+
152
+ learning_rate_schedule = optax.warmup_cosine_decay_schedule(
153
+ init_value=config.init_lr,
154
+ peak_value=config.lr,
155
+ warmup_steps=config.lr_warmup_steps,
156
+ decay_steps=config.lr_decay_steps,
157
+ end_value=config.end_lr,
158
+ )
159
+
160
+ optimizer_info = dict(
161
+ learning_rate_schedule=learning_rate_schedule,
162
+ )
163
+
164
+ if config.multiply_by_parameter_scale:
165
+ optimizer = optax.chain(
166
+ optax.clip_by_global_norm(config.clip_gradient),
167
+ optax.adafactor(
168
+ learning_rate=learning_rate_schedule,
169
+ multiply_by_parameter_scale=True,
170
+ momentum=config.b1,
171
+ decay_rate=config.b2,
172
+ factored=False,
173
+ clipping_threshold=None,
174
+ dtype_momentum=jnp.bfloat16 if config.bf16_momentum else jnp.float32,
175
+ ),
176
+ optax_add_scheduled_weight_decay(
177
+ lambda step: -learning_rate_schedule(step) * config.weight_decay,
178
+ weight_decay_mask
179
+ )
180
+ )
181
+ else:
182
+ optimizer = optax.chain(
183
+ optax.clip_by_global_norm(config.clip_gradient),
184
+ optax.adamw(
185
+ learning_rate=learning_rate_schedule,
186
+ weight_decay=config.weight_decay,
187
+ b1=config.b1,
188
+ b2=config.b2,
189
+ mask=weight_decay_mask,
190
+ mu_dtype=jnp.bfloat16 if config.bf16_momentum else jnp.float32,
191
+ ),
192
+ )
193
+
194
+ return optimizer, optimizer_info
195
+
196
+
197
+ class LionOptimizerFactory(object):
198
+ """ Lion optimizer with cosine schedule. """
199
+
200
+ def __init__(self):
201
+ raise NotImplementedError
202
+
203
+ @staticmethod
204
+ def get_default_config(updates=None):
205
+ config = ConfigDict()
206
+ config.init_lr = 0.0
207
+ config.end_lr = 0.0001
208
+ config.lr = 0.001
209
+ config.lr_warmup_steps = 2000
210
+ config.lr_decay_steps = 500000
211
+ config.b1 = 0.9
212
+ config.b2 = 0.98
213
+ config.clip_gradient = 1.0
214
+ config.weight_decay = 1e-3
215
+ config.bf16_momentum = False
216
+
217
+ if updates is not None:
218
+ config.update(ConfigDict(updates).copy_and_resolve_references())
219
+ return config
220
+
221
+ @classmethod
222
+ def get_optimizer(cls, config, weight_decay_mask=None):
223
+ config = cls.get_default_config(config)
224
+
225
+ learning_rate_schedule = optax.warmup_cosine_decay_schedule(
226
+ init_value=config.init_lr,
227
+ peak_value=config.lr,
228
+ warmup_steps=config.lr_warmup_steps,
229
+ decay_steps=config.lr_decay_steps,
230
+ end_value=config.end_lr,
231
+ )
232
+
233
+ optimizer_info = dict(
234
+ learning_rate_schedule=learning_rate_schedule,
235
+ )
236
+
237
+ optimizer = optax.chain(
238
+ optax.clip_by_global_norm(config.clip_gradient),
239
+ optax.lion(
240
+ learning_rate=learning_rate_schedule,
241
+ weight_decay=config.weight_decay,
242
+ b1=config.b1,
243
+ b2=config.b2,
244
+ mask=weight_decay_mask,
245
+ mu_dtype=jnp.bfloat16 if config.bf16_momentum else jnp.float32,
246
+ ),
247
+ )
248
+
249
+ return optimizer, optimizer_info
250
+
251
+
252
+ class OptaxScheduledWeightDecayState(NamedTuple):
253
+ count: jnp.DeviceArray
254
+
255
+
256
+ def optax_add_scheduled_weight_decay(schedule_fn, mask=None):
257
+ """ Apply weight decay with schedule. """
258
+
259
+ def init_fn(params):
260
+ del params
261
+ return OptaxScheduledWeightDecayState(count=jnp.zeros([], jnp.int32))
262
+
263
+ def update_fn(updates, state, params):
264
+ if params is None:
265
+ raise ValueError('Params cannot be None for weight decay!')
266
+
267
+ weight_decay = schedule_fn(state.count)
268
+ updates = jax.tree_util.tree_map(
269
+ lambda g, p: g + weight_decay * p, updates, params
270
+ )
271
+ return updates, OptaxScheduledWeightDecayState(
272
+ count=optax.safe_int32_increment(state.count)
273
+ )
274
+
275
+ if mask is not None:
276
+ return optax.masked(optax.GradientTransformation(init_fn, update_fn), mask)
277
+ return optax.GradientTransformation(init_fn, update_fn)
EasyLM/scripts/__init__.py ADDED
File without changes
EasyLM/scripts/convert_checkpoint.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This script converts model checkpoint trained by EsayLM to a standard
2
+ # mspack checkpoint that can be loaded by huggingface transformers or
3
+ # flax.serialization.msgpack_restore. Such conversion allows models to be
4
+ # used by other frameworks that integrate with huggingface transformers.
5
+
6
+ import pprint
7
+ from functools import partial
8
+ import os
9
+ import numpy as np
10
+ import mlxu
11
+ import jax.numpy as jnp
12
+ import flax.serialization
13
+ from EasyLM.checkpoint import StreamingCheckpointer
14
+ from EasyLM.jax_utils import float_to_dtype
15
+
16
+
17
+ FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
18
+ load_checkpoint='',
19
+ output_file='',
20
+ streaming=False,
21
+ float_dtype='bf16',
22
+ )
23
+
24
+
25
+ def main(argv):
26
+ assert FLAGS.load_checkpoint != '' and FLAGS.output_file != '', 'input and output must be specified'
27
+ params = StreamingCheckpointer.load_trainstate_checkpoint(
28
+ FLAGS.load_checkpoint, disallow_trainstate=True
29
+ )[1]['params']
30
+
31
+ if FLAGS.streaming:
32
+ StreamingCheckpointer.save_train_state_to_file(
33
+ params, FLAGS.output_file, float_dtype=FLAGS.float_dtype
34
+ )
35
+ else:
36
+ params = float_to_dtype(params, FLAGS.float_dtype)
37
+ with mlxu.open_file(FLAGS.output, 'wb') as fout:
38
+ fout.write(flax.serialization.msgpack_serialize(params, in_place=True))
39
+
40
+
41
+ if __name__ == "__main__":
42
+ mlxu.run(main)
EasyLM/scripts/diff_checkpoint.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This script converts model checkpoint trained by EsayLM to a standard
2
+ # mspack checkpoint that can be loaded by huggingface transformers or
3
+ # flax.serialization.msgpack_restore. Such conversion allows models to be
4
+ # used by other frameworks that integrate with huggingface transformers.
5
+
6
+ import pprint
7
+ from functools import partial
8
+ import os
9
+ import numpy as np
10
+ import jax
11
+ import jax.numpy as jnp
12
+ import flax.serialization
13
+ import mlxu
14
+ from EasyLM.checkpoint import StreamingCheckpointer
15
+ from EasyLM.jax_utils import float_to_dtype
16
+
17
+
18
+ FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
19
+ recover_diff=False,
20
+ load_base_checkpoint='',
21
+ load_target_checkpoint='',
22
+ output_file='',
23
+ streaming=True,
24
+ float_dtype='bf16',
25
+ )
26
+
27
+
28
+ def main(argv):
29
+ assert FLAGS.load_base_checkpoint != '' and FLAGS.load_target_checkpoint != ''
30
+ assert FLAGS.output_file != ''
31
+ base_params = StreamingCheckpointer.load_trainstate_checkpoint(
32
+ FLAGS.load_base_checkpoint, disallow_trainstate=True
33
+ )[1]['params']
34
+
35
+ target_params = StreamingCheckpointer.load_trainstate_checkpoint(
36
+ FLAGS.load_target_checkpoint, disallow_trainstate=True
37
+ )[1]['params']
38
+
39
+ if FLAGS.recover_diff:
40
+ params = jax.tree_util.tree_map(
41
+ lambda b, t: b + t, base_params, target_params
42
+ )
43
+ else:
44
+ params = jax.tree_util.tree_map(
45
+ lambda b, t: t - b, base_params, target_params
46
+ )
47
+
48
+ if FLAGS.streaming:
49
+ StreamingCheckpointer.save_train_state_to_file(
50
+ params, FLAGS.output_file, float_dtype=FLAGS.float_dtype
51
+ )
52
+ else:
53
+ params = float_to_dtype(params, FLAGS.float_dtype)
54
+ with mlxu.open_file(FLAGS.output, 'wb') as fout:
55
+ fout.write(flax.serialization.msgpack_serialize(params, in_place=True))
56
+
57
+
58
+ if __name__ == "__main__":
59
+ mlxu.run(main)
EasyLM/scripts/lm_eval_harness.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This script runs lm_eval_harness evaluations against a served language model.
2
+ # Typically, you need to run a language model server first, e.g.:
3
+ # python -m EasyLM.models.gptj.gptj_serve ...
4
+
5
+ import dataclasses
6
+ import pprint
7
+ from functools import partial
8
+ import os
9
+ from tqdm import tqdm, trange
10
+ import numpy as np
11
+ import mlxu
12
+
13
+ from flax.traverse_util import flatten_dict
14
+ from lm_eval import evaluator, tasks
15
+ from lm_eval.base import LM
16
+
17
+ from EasyLM.serving import LMClient
18
+
19
+
20
+ FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
21
+ tasks='wsc,piqa,winogrande,openbookqa,logiqa',
22
+ shots=0,
23
+ lm_client=LMClient.get_default_config(),
24
+ logger=mlxu.WandBLogger.get_default_config(),
25
+ )
26
+
27
+
28
+ class LMEvalHarnessInterface(LM):
29
+
30
+ def __init__(self, lm_client):
31
+ self.lm_client = lm_client
32
+
33
+ def greedy_until(self, inputs):
34
+ prefix, until = zip(*inputs)
35
+ return self.lm_client.greedy_until(prefix, until)
36
+
37
+ def loglikelihood_rolling(self, inputs):
38
+ loglikelihood, is_greedy = self.lm_client.loglikelihood_rolling(inputs)
39
+ return list(zip(loglikelihood, is_greedy))
40
+
41
+ def loglikelihood(self, inputs):
42
+ prefix, text = zip(*inputs)
43
+ loglikelihood, is_greedy = self.lm_client.loglikelihood(prefix, text)
44
+ return list(zip(loglikelihood, is_greedy))
45
+
46
+
47
+ def main(argv):
48
+ logger = mlxu.WandBLogger(
49
+ config=FLAGS.logger, variant=mlxu.get_user_flags(FLAGS, FLAGS_DEF)
50
+ )
51
+ model = LMEvalHarnessInterface(LMClient(FLAGS.lm_client))
52
+ task_list = FLAGS.tasks.split(',')
53
+ results = evaluator.evaluate(
54
+ model, tasks.get_task_dict(task_list), False, FLAGS.shots, None
55
+ )
56
+ logger.log(flatten_dict(results['results'], sep='/'))
57
+ pprint.pprint(results)
58
+
59
+
60
+ if __name__ == "__main__":
61
+ mlxu.run(main)
EasyLM/scripts/lm_eval_json.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import mlxu
3
+ from EasyLM.serving import LMClient
4
+
5
+
6
+ FLAGS, FLAGS_DEF = mlxu.define_flags_with_default(
7
+ input_file='',
8
+ output_file='',
9
+ prefix_field='prefix',
10
+ text_field='text',
11
+ until_field='until',
12
+ eval_type='loglikelihood',
13
+ lm_client=LMClient.get_default_config(),
14
+ )
15
+
16
+
17
+ def main(argv):
18
+ lm_client = LMClient(FLAGS.lm_client)
19
+ with mlxu.open_file(FLAGS.input_file, 'r') as fin:
20
+ input_data = json.load(fin)
21
+
22
+ if FLAGS.eval_type == 'loglikelihood':
23
+ prefix = input_data[FLAGS.prefix_field]
24
+ text = input_data[FLAGS.text_field]
25
+ loglikelihoods, is_greedys = lm_client.loglikelihood(prefix, text)
26
+ output_data = {
27
+ 'loglikelihood': loglikelihoods,
28
+ 'is_greedy': is_greedys,
29
+ }
30
+ elif FLAGS.eval_type == 'loglikelihood_rolling':
31
+ text = input_data[FLAGS.text_field]
32
+ loglikelihoods, is_greedys = lm_client.loglikelihood_rolling(text)
33
+ output_data = {
34
+ 'loglikelihood': loglikelihoods,
35
+ 'is_greedy': is_greedys,
36
+ }
37
+ elif FLAGS.eval_type == 'greedy_until':
38
+ prefix = input_data[FLAGS.prefix_field]
39
+ until = input_data[FLAGS.until_field]
40
+ output_data = {'output_text': lm_client.greedy_until(prefix, until)}
41
+ elif FLAGS.eval_type == 'generate':
42
+ prefix = input_data[FLAGS.prefix_field]
43
+ output_data = {'output_text': lm_client.generate(prefix)}
44
+ else:
45
+ raise ValueError(f'Unknown eval_type: {FLAGS.eval_type}')
46
+
47
+ with mlxu.open_file(FLAGS.output_file, 'w') as fout:
48
+ json.dump(output_data, fout)
49
+
50
+
51
+ if __name__ == "__main__":
52
+ mlxu.run(main)
EasyLM/serving.py ADDED
@@ -0,0 +1,566 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import dataclasses
2
+ import pprint
3
+ from functools import partial
4
+ import re
5
+ import os
6
+ from threading import Lock
7
+ import urllib
8
+ import time
9
+ from typing import List, Optional, Union
10
+
11
+ from pydantic import BaseModel
12
+ import absl.logging
13
+ from tqdm import tqdm, trange
14
+ import numpy as np
15
+ import mlxu
16
+ from ml_collections import ConfigDict
17
+ import uvicorn
18
+ from fastapi import FastAPI
19
+ import gradio as gr
20
+ import requests
21
+ from requests.exceptions import Timeout, ConnectionError
22
+
23
+
24
+ class InferenceRequest(BaseModel):
25
+ prefix_text: Optional[List[str]] = None
26
+ text: Optional[List[str]] = None
27
+ until: Optional[Union[List[str], List[List[str]]]] = None
28
+ temperature: Optional[float] = None
29
+
30
+
31
+ class ChatRequest(BaseModel):
32
+ prompt: str
33
+ context: str = ''
34
+ temperature: Optional[float] = None
35
+
36
+
37
+ class LMServer(object):
38
+ """ HTTP server for serving langauge models. """
39
+
40
+ @staticmethod
41
+ def get_default_config(updates=None):
42
+ config = ConfigDict()
43
+ config.host = '0.0.0.0'
44
+ config.port = 5007
45
+ config.batch_size = 1
46
+ config.logging = False
47
+ config.pre_compile = 'loglikelihood'
48
+ config.default_temperature = 1.0
49
+ config.greedy_until_max_length = 5000
50
+ config.prepend_to_prefix = ''
51
+ config.append_to_prefix = ''
52
+ config.prepend_to_text = ''
53
+ config.append_to_text = ''
54
+ config.chat_prepend_text = ''
55
+ config.chat_user_prefix = ''
56
+ config.chat_user_suffix = ''
57
+ config.chat_lm_prefix = ''
58
+ config.chat_lm_suffix = ''
59
+ config.notes = ''
60
+
61
+ if updates is not None:
62
+ config.update(ConfigDict(updates).copy_and_resolve_references())
63
+ return config
64
+
65
+ def __init__(self, config):
66
+ self.config = self.get_default_config(config)
67
+ self.lock = Lock()
68
+ self.app = FastAPI()
69
+ self.app.post('/loglikelihood')(self.serve_loglikelihood)
70
+ self.app.post('/loglikelihood-rolling')(self.serve_loglikelihood_rolling)
71
+ self.app.post('/generate')(self.serve_generate)
72
+ self.app.post('/greedy-until')(self.serve_greedy_until)
73
+ self.app.post('/chat')(self.serve_chat)
74
+ self.app.get('/ready')(self.serve_ready)
75
+ self.app = gr.mount_gradio_app(self.app, self.create_chat_app(), '/')
76
+
77
+ @staticmethod
78
+ def loglikelihood(prefix_text, text):
79
+ raise NotImplementedError()
80
+
81
+ @staticmethod
82
+ def loglikelihood_rolling(text):
83
+ raise NotImplementedError()
84
+
85
+ @staticmethod
86
+ def generate(text, temperature):
87
+ raise NotImplementedError()
88
+
89
+ @staticmethod
90
+ def greedy_until(prefix_text, until, max_length):
91
+ raise NotImplementedError()
92
+
93
+ @staticmethod
94
+ def to_list(x):
95
+ if isinstance(x, np.ndarray):
96
+ return x.tolist()
97
+ return x
98
+
99
+ def serve_ready(self):
100
+ return 'Ready!\n'
101
+
102
+ def serve_loglikelihood(self, data: InferenceRequest):
103
+ with self.lock:
104
+ if self.config.logging:
105
+ absl.logging.info(
106
+ '\n========= Serving Log Likelihood Request ========= \n'
107
+ + pprint.pformat(data) + '\n'
108
+ )
109
+
110
+ if data.prefix_text is None:
111
+ data.prefix_text = ['' for _ in data.text]
112
+
113
+ prefix_text = [
114
+ self.config.prepend_to_prefix + p + self.config.append_to_prefix
115
+ for p in data.prefix_text
116
+ ]
117
+ text = [
118
+ self.config.prepend_to_text + t + self.config.append_to_text
119
+ for t in data.text
120
+ ]
121
+
122
+ log_likelihood = []
123
+ is_greedy = []
124
+ for i in trange(0, len(text), self.config.batch_size, ncols=0):
125
+ batch_prefix_text = prefix_text[i:i + self.config.batch_size]
126
+ batch_text = text[i:i + self.config.batch_size]
127
+ batch_size = len(batch_text)
128
+
129
+ if batch_size < self.config.batch_size:
130
+ extra = self.config.batch_size - batch_size
131
+ batch_prefix_text.extend(['a' for _ in range(extra)])
132
+ batch_text.extend(['a' for _ in range(extra)])
133
+
134
+ batch_log_likelihood, batch_is_greedy = self.loglikelihood(
135
+ batch_prefix_text, batch_text
136
+ )
137
+ batch_log_likelihood = self.to_list(batch_log_likelihood)
138
+ batch_is_greedy = self.to_list(batch_is_greedy)
139
+ log_likelihood.extend(batch_log_likelihood[:batch_size])
140
+ is_greedy.extend(batch_is_greedy[:batch_size])
141
+
142
+ output = {
143
+ 'prefix_text': data.prefix_text,
144
+ 'text': data.text,
145
+ 'log_likelihood': log_likelihood,
146
+ 'is_greedy': is_greedy,
147
+ }
148
+ if self.config.logging:
149
+ absl.logging.info(
150
+ '\n========= Output ========= \n'
151
+ + pprint.pformat(output) + '\n'
152
+ )
153
+
154
+ return output
155
+
156
+ def serve_loglikelihood_rolling(self, data: InferenceRequest):
157
+ with self.lock:
158
+ if self.config.logging:
159
+ absl.logging.info(
160
+ '\n========= Serving Log Likelihood Request ========= \n'
161
+ + pprint.pformat(data) + '\n'
162
+ )
163
+
164
+ text = [
165
+ self.config.prepend_to_text + t + self.config.append_to_text
166
+ for t in data.text
167
+ ]
168
+ log_likelihood = []
169
+ is_greedy = []
170
+ for i in trange(0, len(text), self.config.batch_size, ncols=0):
171
+ batch_text = text[i:i + self.config.batch_size]
172
+ batch_size = len(batch_text)
173
+
174
+ if batch_size < self.config.batch_size:
175
+ extra = self.config.batch_size - batch_size
176
+ batch_text.extend(['a' for _ in range(extra)])
177
+
178
+ batch_log_likelihood, batch_is_greedy = self.loglikelihood_rolling(
179
+ batch_text
180
+ )
181
+ batch_log_likelihood = self.to_list(batch_log_likelihood)
182
+ batch_is_greedy = self.to_list(batch_is_greedy)
183
+ log_likelihood.extend(batch_log_likelihood[:batch_size])
184
+ is_greedy.extend(batch_is_greedy[:batch_size])
185
+
186
+ output = {
187
+ 'text': data.text,
188
+ 'log_likelihood': log_likelihood,
189
+ 'is_greedy': is_greedy,
190
+ }
191
+ if self.config.logging:
192
+ absl.logging.info(
193
+ '\n========= Output ========= \n'
194
+ + pprint.pformat(output) + '\n'
195
+ )
196
+
197
+ return output
198
+
199
+ def serve_generate(self, data: InferenceRequest):
200
+ with self.lock:
201
+ if self.config.logging:
202
+ absl.logging.info(
203
+ '\n========= Serving Generate Request ========= \n'
204
+ + pprint.pformat(data) + '\n'
205
+ )
206
+ prefix_text = [
207
+ self.config.prepend_to_prefix + p + self.config.append_to_prefix
208
+ for p in data.prefix_text
209
+ ]
210
+
211
+ if data.temperature is None:
212
+ data.temperature = self.config.default_temperature
213
+
214
+ output_text = []
215
+ for i in trange(0, len(prefix_text), self.config.batch_size, ncols=0):
216
+ batch_prefix_text = prefix_text[i:i + self.config.batch_size]
217
+ batch_size = len(batch_prefix_text)
218
+
219
+ if batch_size < self.config.batch_size:
220
+ extra = self.config.batch_size - batch_size
221
+ batch_prefix_text.extend(['a' for _ in range(extra)])
222
+
223
+ batch_output_text = self.generate(
224
+ batch_prefix_text,
225
+ temperature=data.temperature,
226
+ )
227
+ output_text.extend(self.to_list(batch_output_text)[:batch_size])
228
+
229
+ output = {
230
+ 'prefix_text': data.prefix_text,
231
+ 'output_text': output_text,
232
+ 'temperature': data.temperature,
233
+ }
234
+ if self.config.logging:
235
+ absl.logging.info(
236
+ '\n========= Output ========= \n'
237
+ + pprint.pformat(output) + '\n'
238
+ )
239
+ return output
240
+
241
+ def serve_greedy_until(self, data: InferenceRequest):
242
+ with self.lock:
243
+ if self.config.logging:
244
+ absl.logging.info(
245
+ '\n========= Serving Greedy Until Request ========= \n'
246
+ + pprint.pformat(data) + '\n'
247
+ )
248
+ prefix_text = [
249
+ self.config.prepend_to_prefix + p + self.config.append_to_prefix
250
+ for p in data.prefix_text
251
+ ]
252
+ until = data.until
253
+ max_length = self.config.greedy_until_max_length
254
+
255
+ output_text = []
256
+ for i in range(0, len(prefix_text), self.config.batch_size):
257
+ batch_prefix_text = prefix_text[i:i + self.config.batch_size]
258
+ batch_until = until[i:i + self.config.batch_size]
259
+ batch_size = len(batch_prefix_text)
260
+
261
+ batch_output_text = self.greedy_until(batch_prefix_text, batch_until, max_length)
262
+ output_text.extend(self.to_list(batch_output_text)[:batch_size])
263
+
264
+ output = {
265
+ 'prefix_text': data.prefix_text,
266
+ 'until': data.until,
267
+ 'max_length': max_length,
268
+ 'output_text': output_text,
269
+ }
270
+ if self.config.logging:
271
+ absl.logging.info(
272
+ '\n========= Output ========= \n'
273
+ + pprint.pformat(output) + '\n'
274
+ )
275
+ return output
276
+
277
+ def process_chat(self, prompt, context, temperature):
278
+ context = (
279
+ context + self.config.chat_user_prefix
280
+ + prompt + self.config.chat_user_suffix
281
+ + self.config.chat_lm_prefix
282
+ )
283
+ response = self.generate(
284
+ [self.config.chat_prepend_text + context],
285
+ temperature=float(temperature),
286
+ )[0]
287
+ context = context + response + self.config.chat_lm_suffix
288
+ return response, context
289
+
290
+ def serve_chat(self, data: ChatRequest):
291
+ if data.temperature is None:
292
+ data.temperature = self.config.default_temperature
293
+ response, context = self.process_chat(
294
+ data.prompt, data.context,
295
+ temperature=data.temperature,
296
+ )
297
+ return {
298
+ 'response': response,
299
+ 'context': context,
300
+ 'temperature': data.temperature,
301
+ }
302
+
303
+ def create_chat_app(self):
304
+ with gr.Blocks(analytics_enabled=False, title='EasyLM Chat') as gradio_chatbot:
305
+ gr.Markdown('# Chatbot Powered by [EasyLM](https://github.com/young-geng/EasyLM)')
306
+ gr.Markdown(self.config.notes)
307
+ chatbot = gr.Chatbot(label='Chat history')
308
+ msg = gr.Textbox(
309
+ placeholder='Type your message here...',
310
+ show_label=False
311
+ )
312
+ with gr.Row():
313
+ send = gr.Button('Send')
314
+ regenerate = gr.Button('Regenerate', interactive=False)
315
+ clear = gr.Button('Reset')
316
+
317
+ temp_slider = gr.Slider(
318
+ label='Temperature', minimum=0, maximum=2.0,
319
+ value=self.config.default_temperature
320
+ )
321
+
322
+ context_state = gr.State(['', ''])
323
+
324
+ def user_fn(user_message, history, context):
325
+ return {
326
+ msg: gr.update(value='', interactive=False),
327
+ clear: gr.update(interactive=False),
328
+ send: gr.update(interactive=False),
329
+ regenerate: gr.update(interactive=False),
330
+ chatbot: history + [[user_message, None]],
331
+ context_state: [context[1], context[1]],
332
+ }
333
+
334
+ def model_fn(history, context, temperature):
335
+ history[-1][1], new_context = self.process_chat(
336
+ history[-1][0], context[0], temperature
337
+ )
338
+ return {
339
+ msg: gr.update(value='', interactive=True),
340
+ clear: gr.update(interactive=True),
341
+ send: gr.update(interactive=True),
342
+ chatbot: history,
343
+ context_state: [context[0], new_context],
344
+ regenerate: gr.update(interactive=True),
345
+ }
346
+
347
+ def regenerate_fn():
348
+ return {
349
+ msg: gr.update(value='', interactive=False),
350
+ clear: gr.update(interactive=False),
351
+ send: gr.update(interactive=False),
352
+ regenerate: gr.update(interactive=False),
353
+ }
354
+
355
+ def clear_fn():
356
+ return {
357
+ chatbot: None,
358
+ msg: '',
359
+ context_state: ['', ''],
360
+ regenerate: gr.update(interactive=False),
361
+ }
362
+
363
+ msg.submit(
364
+ user_fn,
365
+ inputs=[msg, chatbot, context_state],
366
+ outputs=[msg, clear, send, chatbot, context_state, regenerate],
367
+ queue=False
368
+ ).then(
369
+ model_fn,
370
+ inputs=[chatbot, context_state, temp_slider],
371
+ outputs=[msg, clear, send, chatbot, context_state, regenerate],
372
+ queue=True
373
+ )
374
+ send.click(
375
+ user_fn,
376
+ inputs=[msg, chatbot, context_state],
377
+ outputs=[msg, clear, send, chatbot, context_state, regenerate],
378
+ queue=False
379
+ ).then(
380
+ model_fn,
381
+ inputs=[chatbot, context_state, temp_slider],
382
+ outputs=[msg, clear, send, chatbot, context_state, regenerate],
383
+ queue=True
384
+ )
385
+ regenerate.click(
386
+ regenerate_fn,
387
+ inputs=None,
388
+ outputs=[msg, clear, send, regenerate],
389
+ queue=False
390
+ ).then(
391
+ model_fn,
392
+ inputs=[chatbot, context_state, temp_slider],
393
+ outputs=[msg, clear, send, chatbot, context_state, regenerate],
394
+ queue=True
395
+ )
396
+ clear.click(
397
+ clear_fn,
398
+ inputs=None,
399
+ outputs=[chatbot, msg, context_state, regenerate],
400
+ queue=False
401
+ )
402
+
403
+ gradio_chatbot.queue(concurrency_count=1)
404
+ return gradio_chatbot
405
+
406
+ def run(self):
407
+ if self.config.pre_compile != '':
408
+ if self.config.pre_compile == 'all':
409
+ pre_compile = ['loglikelihood', 'generate', 'greedy_until', 'chat']
410
+ else:
411
+ pre_compile = self.config.pre_compile.split(',')
412
+
413
+ pre_compile_data = ['a' for _ in range(self.config.batch_size)]
414
+ for task in pre_compile:
415
+ if task == 'loglikelihood':
416
+ self.loglikelihood(pre_compile_data, pre_compile_data)
417
+ self.loglikelihood_rolling(pre_compile_data)
418
+ elif task == 'generate':
419
+ self.generate(pre_compile_data, 1.0)
420
+ elif task == 'greedy_until':
421
+ self.greedy_until(
422
+ pre_compile_data, pre_compile_data,
423
+ self.config.greedy_until_max_length
424
+ )
425
+ elif task == 'chat':
426
+ self.process_chat('a', 'a', 1.0)
427
+ else:
428
+ raise ValueError(f'Invalid precompile task: {task}!')
429
+
430
+ uvicorn.run(self.app, host=self.config.host, port=self.config.port)
431
+
432
+
433
+ class LMClient(object):
434
+ """ A simple client for the LM server. """
435
+
436
+ @staticmethod
437
+ def get_default_config(updates=None):
438
+ config = ConfigDict()
439
+ config.url = 'http://localhost:5007'
440
+ config.batch_size = 1
441
+ config.wait_for_ready = True
442
+ config.dummy = False
443
+
444
+ if updates is not None:
445
+ config.update(ConfigDict(updates).copy_and_resolve_references())
446
+ return config
447
+
448
+ def __init__(self, config=None):
449
+ self.config = self.get_default_config(config)
450
+ if self.config.wait_for_ready:
451
+ self.wait_for_ready()
452
+
453
+ def wait_for_ready(self):
454
+ if self.config.dummy:
455
+ return
456
+ while True:
457
+ try:
458
+ requests.get(urllib.parse.urljoin(self.config.url, 'ready'))
459
+ return
460
+ except (Timeout, ConnectionError) as e:
461
+ time.sleep(10)
462
+
463
+ @staticmethod
464
+ def batched(iterator, batch_size):
465
+ batch = []
466
+ for example in iterator:
467
+ batch.append(example)
468
+ if len(batch) == batch_size:
469
+ yield batch
470
+ batch = []
471
+ if len(batch) > 0:
472
+ yield batch
473
+
474
+ def loglikelihood(self, prefix, text):
475
+ prefix, text = list(prefix), list(text)
476
+ if self.config.dummy:
477
+ return [-1.0 for _ in text], [False for _ in text]
478
+
479
+ log_likelihood = []
480
+ is_greedy = []
481
+
482
+ batched_iterator = list(zip(
483
+ self.batched(prefix, self.config.batch_size),
484
+ self.batched(text, self.config.batch_size)
485
+ ))
486
+ for batch_prefix, batch_text in tqdm(batched_iterator, ncols=0):
487
+ response = requests.post(
488
+ urllib.parse.urljoin(self.config.url, 'loglikelihood'),
489
+ json={'prefix_text': batch_prefix, 'text': batch_text}
490
+ ).json()
491
+ log_likelihood.extend(response['log_likelihood'])
492
+ is_greedy.extend(response['is_greedy'])
493
+
494
+ return log_likelihood, is_greedy
495
+
496
+ def loglikelihood_rolling(self, text):
497
+ text = list(text)
498
+ if self.config.dummy:
499
+ return [-1.0 for _ in text], [False for _ in text]
500
+
501
+ log_likelihood = []
502
+ is_greedy = []
503
+ batched_iterator = list(self.batched(text, self.config.batch_size))
504
+ for batch_text in tqdm(batched_iterator, ncols=0):
505
+ response = requests.post(
506
+ urllib.parse.urljoin(self.config.url, 'loglikelihood-rolling'),
507
+ json={'text': batch_text}
508
+ ).json()
509
+ log_likelihood.extend(response['log_likelihood'])
510
+ is_greedy.extend(response['is_greedy'])
511
+ return log_likelihood, is_greedy
512
+
513
+ def greedy_until(self, prefix, until):
514
+ prefix, until = list(prefix), list(until)
515
+ if self.config.dummy:
516
+ results = []
517
+ for u in until:
518
+ if isinstance(u, str):
519
+ results.append('dummy text ' + u)
520
+ else:
521
+ results.append('dummy text ' + u[0])
522
+ return results
523
+
524
+ batched_iterator = list(zip(
525
+ self.batched(prefix, self.config.batch_size),
526
+ self.batched(until, self.config.batch_size),
527
+ ))
528
+ output_text = []
529
+ for batch_prefix, batch_until in tqdm(batched_iterator, ncols=0):
530
+ response = requests.post(
531
+ urllib.parse.urljoin(self.config.url, 'greedy-until'),
532
+ json={'prefix_text': batch_prefix, 'until': batch_until}
533
+ ).json()
534
+ output_text.extend(response['output_text'])
535
+ return output_text
536
+
537
+ def generate(self, prefix, temperature=None):
538
+ prefix = list(prefix)
539
+ if self.config.dummy:
540
+ return ['' for _ in prefix]
541
+
542
+ output_text = []
543
+ batched_iterator = list(self.batched(prefix, self.config.batch_size))
544
+ for batch_prefix in tqdm(batched_iterator, ncols=0):
545
+ response = requests.post(
546
+ urllib.parse.urljoin(self.config.url, 'generate'),
547
+ json={
548
+ 'prefix_text': batch_prefix,
549
+ 'temperature': temperature,
550
+ }
551
+ ).json()
552
+ output_text.extend(response['output_text'])
553
+ return output_text
554
+
555
+ def chat(self, prompt, context, temperature=None):
556
+ if self.config.dummy:
557
+ return ''
558
+ response = requests.post(
559
+ urllib.parse.urljoin(self.config.url, 'chat'),
560
+ json={
561
+ 'prompt': prompt,
562
+ 'context': context,
563
+ 'temperature': temperature,
564
+ }
565
+ ).json()
566
+ return response['response'], response['context']
convert_to_hf_model.sh ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ JAX_PLATFORM_NAME=cpu python3 -m EasyLM.models.llama.convert_easylm_to_hf \
2
+ --load_checkpoint='' \
3
+ --model_size='7b' \
4
+ --output_dir='./'
pretrain_llama_7b.sh ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #! /bin/bash
2
+
3
+ # Put your WANDB API key here to enable logging to wandb.
4
+ export WANDB_API_KEY=''
5
+
6
+ # TPU specific flags to improve training throughput
7
+ export LIBTPU_INIT_ARGS='--xla_jf_spmd_threshold_for_windowed_einsum_mib=0 --xla_tpu_spmd_threshold_for_allgather_cse=10000 --xla_tpu_spmd_rewrite_einsum_with_reshape=true --jax_enable_async_collective_offload=true --xla_tpu_enable_latency_hiding_scheduler=true TPU_MEGACORE=MEGACORE_DENSE'
8
+
9
+
10
+ python3 -m EasyLM.models.llama.llama_train \
11
+ --initialize_jax_distributed=True \
12
+ --mesh_dim='1,-1,4' \
13
+ --dtype='bf16' \
14
+ --total_steps=1000000 \
15
+ --eval_freq=50000 \
16
+ --log_freq=1000 \
17
+ --save_model_freq=2000 \
18
+ --save_milestone_freq=50000 \
19
+ --load_llama_config='7b' \
20
+ --update_llama_config='' \
21
+ --load_dataset_state='' \
22
+ --load_checkpoint='' \
23
+ --tokenizer.vocab_file='tokenizer.model' \
24
+ --optimizer.type='lion' \
25
+ --optimizer.lion_optimizer.weight_decay=1.0 \
26
+ --optimizer.lion_optimizer.lr=3e-5 \
27
+ --optimizer.lion_optimizer.end_lr=3e-6 \
28
+ --optimizer.lion_optimizer.lr_warmup_steps=2000 \
29
+ --optimizer.lion_optimizer.lr_decay_steps=1000000 \
30
+ --optimizer.lion_optimizer.bf16_momentum=True \
31
+ --train_dataset.type='huggingface' \
32
+ --train_dataset.text_processor.fields='text' \
33
+ --train_dataset.huggingface_dataset.path='/researchdisk/lm_training_dataset_v2_filtered' \
34
+ --train_dataset.huggingface_dataset.split='train' \
35
+ --train_dataset.huggingface_dataset.seq_length=2048 \
36
+ --train_dataset.huggingface_dataset.batch_size=64 \
37
+ --eval_dataset.type='huggingface' \
38
+ --eval_dataset.text_processor.fields='text' \
39
+ --eval_dataset.huggingface_dataset.path='/researchdisk/lm_training_dataset_v2_filtered' \
40
+ --eval_dataset.huggingface_dataset.split='validation' \
41
+ --eval_dataset.huggingface_dataset.seq_length=2048 \
42
+ --eval_dataset.huggingface_dataset.batch_size=64 \
43
+ --checkpointer.save_optimizer_state=True \
44
+ --logger.online=True \
45
+ --logger.prefix='EasyLM' \
46
+ --logger.project="open_llama_7b" \
47
+ --logger.output_dir="gs://finnish-nlp-research/llama-7b-checkpoint" \
48
+ --logger.wandb_dir="./"
49
+
special_tokens_map.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "</s>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "unk_token": {
17
+ "content": "<unk>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ }
23
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ab1b681ec7fc02fed5edd3026687d7a692a918c4dd8e150ca2e3994a6229843b
3
+ size 534194
tokenizer_config.json ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": true,
3
+ "add_eos_token": false,
4
+ "bos_token": {
5
+ "__type": "AddedToken",
6
+ "content": "<s>",
7
+ "lstrip": false,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false
11
+ },
12
+ "clean_up_tokenization_spaces": false,
13
+ "eos_token": {
14
+ "__type": "AddedToken",
15
+ "content": "</s>",
16
+ "lstrip": false,
17
+ "normalized": false,
18
+ "rstrip": false,
19
+ "single_word": false
20
+ },
21
+ "model_max_length": 2048,
22
+ "pad_token": null,
23
+ "sp_model_kwargs": {},
24
+ "tokenizer_class": "LlamaTokenizer",
25
+ "unk_token": {
26
+ "__type": "AddedToken",
27
+ "content": "<unk>",
28
+ "lstrip": false,
29
+ "normalized": false,
30
+ "rstrip": false,
31
+ "single_word": false
32
+ }
33
+ }
train_tokenizer.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datasets
2
+ from transformers import AutoTokenizer
3
+
4
+ dataset = datasets.load_from_disk("/researchdisk/lm_training_dataset_v2_filtered")
5
+ dataset = dataset["train"].train_test_split(train_size=0.02)
6
+
7
+ old_tokenizer = AutoTokenizer.from_pretrained("openlm-research/open_llama_7b_700bt_preview")
8
+
9
+ def get_training_corpus():
10
+ return (
11
+ dataset["train"][i : i + 1000]["text"]
12
+ for i in range(0, len(dataset["train"]), 1000)
13
+ )
14
+
15
+
16
+ training_corpus = get_training_corpus()
17
+
18
+ tokenizer = old_tokenizer.train_new_from_iterator(training_corpus, vocab_size=64256, min_frequency=2)
19
+ tokenizer.save_pretrained("./")