versae commited on
Commit
cefa210
1 Parent(s): 529e26d

Step... (1000/50000 | Loss: 1.7686773538589478, Acc: 0.6487793326377869): 3%|▊ | 1286/50000 [29:40<20:20:20, 1.50s/it]

Browse files
Files changed (45) hide show
  1. README.md +31 -0
  2. config.json +25 -0
  3. configs/base/config.json +25 -0
  4. configs/base/tokenizer.json +0 -0
  5. configs/large/config.json +25 -0
  6. configs/large/tokenizer.json +0 -0
  7. convert.py +29 -0
  8. flax_model.msgpack +3 -0
  9. merges.txt +0 -0
  10. outputs/checkpoints/checkpoint-1000/config.json +25 -0
  11. outputs/checkpoints/checkpoint-1000/data_collator.joblib +3 -0
  12. outputs/checkpoints/checkpoint-1000/flax_model.msgpack +3 -0
  13. outputs/checkpoints/checkpoint-1000/optimizer_state.msgpack +3 -0
  14. outputs/checkpoints/checkpoint-1000/training_args.joblib +3 -0
  15. outputs/checkpoints/checkpoint-1000/training_state.json +1 -0
  16. outputs/config.json +25 -0
  17. outputs/data_collator.joblib +3 -0
  18. outputs/events.out.tfevents.1627258355.tablespoon.3000110.3.v2 +3 -0
  19. outputs/flax_model.msgpack +3 -0
  20. outputs/optimizer_state.msgpack +3 -0
  21. outputs/training_args.joblib +3 -0
  22. outputs/training_state.json +1 -0
  23. push_to_hub.sh +3 -0
  24. pytorch_model.bin +3 -0
  25. run_mlm_flax_stream.py +832 -0
  26. run_stream.128.sh +27 -0
  27. run_stream.512.log +0 -0
  28. run_stream.512.sh +27 -0
  29. special_tokens_map.json +1 -0
  30. tokenizer.json +0 -0
  31. tokenizer_config.json +1 -0
  32. vocab.json +0 -0
  33. wandb/debug-internal.log +1 -0
  34. wandb/debug.log +1 -0
  35. wandb/latest-run +1 -0
  36. wandb/run-20210726_001233-17u6inbn/files/code/run_mlm_flax_stream.py +832 -0
  37. wandb/run-20210726_001233-17u6inbn/files/config.yaml +324 -0
  38. wandb/run-20210726_001233-17u6inbn/files/events.out.tfevents.1627258355.tablespoon.3000110.3.v2 +1 -0
  39. wandb/run-20210726_001233-17u6inbn/files/output.log +823 -0
  40. wandb/run-20210726_001233-17u6inbn/files/requirements.txt +108 -0
  41. wandb/run-20210726_001233-17u6inbn/files/wandb-metadata.json +47 -0
  42. wandb/run-20210726_001233-17u6inbn/files/wandb-summary.json +1 -0
  43. wandb/run-20210726_001233-17u6inbn/logs/debug-internal.log +0 -0
  44. wandb/run-20210726_001233-17u6inbn/logs/debug.log +27 -0
  45. wandb/run-20210726_001233-17u6inbn/run-17u6inbn.wandb +0 -0
README.md ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ language: es
3
+ license: CC-BY 4.0
4
+ tags:
5
+ - spanish
6
+ - roberta
7
+ pipeline_tag: fill-mask
8
+ widget:
9
+ - text: "Fui a la librería a comprar un <mask>."
10
+ ---
11
+
12
+ This is a **RoBERTa-base** model trained from scratch in Spanish.
13
+
14
+ The training dataset is [mc4](https://huggingface.co/datasets/bertin-project/mc4-es-sampled ) subsampling documents to a total of about 50 million examples. Sampling is biased towards average perplexity values (using a Gaussian function), discarding more often documents with very large values (poor quality) of very small values (short, repetitive texts).
15
+
16
+ This model takes the one using [sequence length 128](https://huggingface.co/bertin-project/bertin-base-stepwise) and trains during 25.000 steps using sequence length 512.
17
+
18
+ Please see our main [card](https://huggingface.co/bertin-project/bertin-roberta-base-spanish) for more information.
19
+
20
+ This is part of the
21
+ [Flax/Jax Community Week](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/7104), organised by [HuggingFace](https://huggingface.co/) and TPU usage sponsored by Google.
22
+
23
+
24
+ ## Team members
25
+
26
+ - Eduardo González ([edugp](https://huggingface.co/edugp))
27
+ - Javier de la Rosa ([versae](https://huggingface.co/versae))
28
+ - Manu Romero ([mrm8488](https://huggingface.co/))
29
+ - María Grandury ([mariagrandury](https://huggingface.co/))
30
+ - Pablo González de Prado ([Pablogps](https://huggingface.co/Pablogps))
31
+ - Paulo Villegas ([paulo](https://huggingface.co/paulo))
config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "RobertaForMaskedLM"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "bos_token_id": 0,
7
+ "eos_token_id": 2,
8
+ "gradient_checkpointing": false,
9
+ "hidden_act": "gelu",
10
+ "hidden_dropout_prob": 0.1,
11
+ "hidden_size": 768,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": 3072,
14
+ "layer_norm_eps": 1e-05,
15
+ "max_position_embeddings": 514,
16
+ "model_type": "roberta",
17
+ "num_attention_heads": 12,
18
+ "num_hidden_layers": 12,
19
+ "pad_token_id": 1,
20
+ "position_embedding_type": "absolute",
21
+ "transformers_version": "4.9.0.dev0",
22
+ "type_vocab_size": 1,
23
+ "use_cache": true,
24
+ "vocab_size": 50265
25
+ }
configs/base/config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "RobertaForMaskedLM"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "bos_token_id": 0,
7
+ "eos_token_id": 2,
8
+ "gradient_checkpointing": false,
9
+ "hidden_act": "gelu",
10
+ "hidden_dropout_prob": 0.1,
11
+ "hidden_size": 768,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": 3072,
14
+ "layer_norm_eps": 1e-05,
15
+ "max_position_embeddings": 514,
16
+ "model_type": "roberta",
17
+ "num_attention_heads": 12,
18
+ "num_hidden_layers": 12,
19
+ "pad_token_id": 1,
20
+ "position_embedding_type": "absolute",
21
+ "transformers_version": "4.9.0.dev0",
22
+ "type_vocab_size": 1,
23
+ "use_cache": true,
24
+ "vocab_size": 50265
25
+ }
configs/base/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
configs/large/config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "RobertaForMaskedLM"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "bos_token_id": 0,
7
+ "eos_token_id": 2,
8
+ "gradient_checkpointing": false,
9
+ "hidden_act": "gelu",
10
+ "hidden_dropout_prob": 0.1,
11
+ "hidden_size": 1024,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": 4096,
14
+ "layer_norm_eps": 1e-05,
15
+ "max_position_embeddings": 514,
16
+ "model_type": "roberta",
17
+ "num_attention_heads": 16,
18
+ "num_hidden_layers": 24,
19
+ "pad_token_id": 1,
20
+ "position_embedding_type": "absolute",
21
+ "transformers_version": "4.9.0.dev0",
22
+ "type_vocab_size": 1,
23
+ "use_cache": true,
24
+ "vocab_size": 50265
25
+ }
configs/large/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
convert.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ import tempfile
3
+
4
+ import jax
5
+ from jax import numpy as jnp
6
+ from transformers import AutoTokenizer, FlaxRobertaForMaskedLM, RobertaForMaskedLM
7
+
8
+
9
+ def to_f32(t):
10
+ return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
11
+
12
+
13
+ def main():
14
+ # Saving extra files from config.json and tokenizer.json files
15
+ tokenizer = AutoTokenizer.from_pretrained("./")
16
+ tokenizer.save_pretrained("./")
17
+
18
+ # Temporary saving bfloat16 Flax model into float32
19
+ tmp = tempfile.mkdtemp()
20
+ flax_model = FlaxRobertaForMaskedLM.from_pretrained("./")
21
+ flax_model.params = to_f32(flax_model.params)
22
+ flax_model.save_pretrained(tmp)
23
+ # Converting float32 Flax to PyTorch
24
+ model = RobertaForMaskedLM.from_pretrained(tmp, from_flax=True)
25
+ model.save_pretrained("./", save_config=False)
26
+
27
+
28
+ if __name__ == "__main__":
29
+ main()
flax_model.msgpack ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5a03759b12cc223a5e978aaada836d2e6c39f107a37072fc713c1cd6c8cb58f1
3
+ size 249750019
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
outputs/checkpoints/checkpoint-1000/config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "RobertaForMaskedLM"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "bos_token_id": 0,
7
+ "eos_token_id": 2,
8
+ "gradient_checkpointing": false,
9
+ "hidden_act": "gelu",
10
+ "hidden_dropout_prob": 0.1,
11
+ "hidden_size": 768,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": 3072,
14
+ "layer_norm_eps": 1e-05,
15
+ "max_position_embeddings": 514,
16
+ "model_type": "roberta",
17
+ "num_attention_heads": 12,
18
+ "num_hidden_layers": 12,
19
+ "pad_token_id": 1,
20
+ "position_embedding_type": "absolute",
21
+ "transformers_version": "4.9.0.dev0",
22
+ "type_vocab_size": 1,
23
+ "use_cache": true,
24
+ "vocab_size": 50265
25
+ }
outputs/checkpoints/checkpoint-1000/data_collator.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e02a6e9cfa63cb321cac9402efd29841b652999fcbf787800ae050e747b161ee
3
+ size 1471394
outputs/checkpoints/checkpoint-1000/flax_model.msgpack ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5a03759b12cc223a5e978aaada836d2e6c39f107a37072fc713c1cd6c8cb58f1
3
+ size 249750019
outputs/checkpoints/checkpoint-1000/optimizer_state.msgpack ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:45abc23402f1ddc533f4b24453f01978f94bc139d5216bf074d6542cec750bab
3
+ size 499500278
outputs/checkpoints/checkpoint-1000/training_args.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4dcde99c91fe01c5143995806e1d6595b728cb8ed0a2d9f2f3c5610aeebeb7c2
3
+ size 1871
outputs/checkpoints/checkpoint-1000/training_state.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"step": 1001}
outputs/config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "RobertaForMaskedLM"
4
+ ],
5
+ "attention_probs_dropout_prob": 0.1,
6
+ "bos_token_id": 0,
7
+ "eos_token_id": 2,
8
+ "gradient_checkpointing": false,
9
+ "hidden_act": "gelu",
10
+ "hidden_dropout_prob": 0.1,
11
+ "hidden_size": 768,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": 3072,
14
+ "layer_norm_eps": 1e-05,
15
+ "max_position_embeddings": 514,
16
+ "model_type": "roberta",
17
+ "num_attention_heads": 12,
18
+ "num_hidden_layers": 12,
19
+ "pad_token_id": 1,
20
+ "position_embedding_type": "absolute",
21
+ "transformers_version": "4.9.0.dev0",
22
+ "type_vocab_size": 1,
23
+ "use_cache": true,
24
+ "vocab_size": 50265
25
+ }
outputs/data_collator.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e02a6e9cfa63cb321cac9402efd29841b652999fcbf787800ae050e747b161ee
3
+ size 1471394
outputs/events.out.tfevents.1627258355.tablespoon.3000110.3.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c15d6accbbffb760087f0c740d328a1efb90da5cf430f6b59cc62f5450455429
3
+ size 147205
outputs/flax_model.msgpack ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5a03759b12cc223a5e978aaada836d2e6c39f107a37072fc713c1cd6c8cb58f1
3
+ size 249750019
outputs/optimizer_state.msgpack ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:45abc23402f1ddc533f4b24453f01978f94bc139d5216bf074d6542cec750bab
3
+ size 499500278
outputs/training_args.joblib ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4dcde99c91fe01c5143995806e1d6595b728cb8ed0a2d9f2f3c5610aeebeb7c2
3
+ size 1871
outputs/training_state.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"step": 1001}
push_to_hub.sh ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ git add -A
2
+ git commit -m "$(sed 's/\r/\n/g' run_stream.512.log | tail -1)"
3
+ git push
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f4c870186998e7c252130640a265d196f785ac49cace7f1f7085dfd0cf139aa6
3
+ size 498858859
run_mlm_flax_stream.py ADDED
@@ -0,0 +1,832 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) with whole word masking on a
18
+ text file or a dataset.
19
+
20
+ Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
21
+ https://huggingface.co/models?filter=masked-lm
22
+ """
23
+ import logging
24
+ import json
25
+ import os
26
+ import shutil
27
+ import sys
28
+ import tempfile
29
+ import time
30
+ from collections import defaultdict
31
+ from dataclasses import dataclass, field
32
+
33
+ # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
34
+ import joblib
35
+ from pathlib import Path
36
+ from typing import Dict, List, Optional, Tuple
37
+
38
+ import datasets
39
+ import numpy as np
40
+ from datasets import load_dataset
41
+ from tqdm import tqdm
42
+
43
+ import flax
44
+ import jax
45
+ import jax.numpy as jnp
46
+ import kenlm # pip install https://github.com/kpu/kenlm/archive/master.zip
47
+ import optax
48
+ from flax import jax_utils, traverse_util
49
+ from flax.serialization import from_bytes, to_bytes
50
+ from flax.training import train_state
51
+ from flax.training.common_utils import get_metrics, onehot, shard
52
+ from transformers import (
53
+ CONFIG_MAPPING,
54
+ FLAX_MODEL_FOR_MASKED_LM_MAPPING,
55
+ AutoConfig,
56
+ AutoTokenizer,
57
+ FlaxAutoModelForMaskedLM,
58
+ HfArgumentParser,
59
+ PreTrainedTokenizerBase,
60
+ TensorType,
61
+ TrainingArguments,
62
+ is_tensorboard_available,
63
+ set_seed,
64
+ FlaxRobertaForMaskedLM,
65
+ RobertaForMaskedLM,
66
+ )
67
+
68
+
69
+ if datasets.__version__ <= "1.8.0":
70
+ raise ValueError("Make sure to upgrade `datasets` to a version >= 1.9.0 to use dataset streaming")
71
+
72
+
73
+ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
74
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
75
+
76
+
77
+ @dataclass
78
+ class ModelArguments:
79
+ """
80
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
81
+ """
82
+
83
+ model_name_or_path: Optional[str] = field(
84
+ default=None,
85
+ metadata={
86
+ "help": "The model checkpoint for weights initialization."
87
+ "Don't set if you want to train a model from scratch."
88
+ },
89
+ )
90
+ model_type: Optional[str] = field(
91
+ default=None,
92
+ metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
93
+ )
94
+ config_name: Optional[str] = field(
95
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
96
+ )
97
+ tokenizer_name: Optional[str] = field(
98
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
99
+ )
100
+ cache_dir: Optional[str] = field(
101
+ default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
102
+ )
103
+ use_fast_tokenizer: bool = field(
104
+ default=True,
105
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
106
+ )
107
+ dtype: Optional[str] = field(
108
+ default="float32",
109
+ metadata={
110
+ "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
111
+ },
112
+ )
113
+
114
+ @dataclass
115
+ class DataTrainingArguments:
116
+ """
117
+ Arguments pertaining to what data we are going to input our model for training and eval.
118
+ """
119
+
120
+ dataset_name: Optional[str] = field(
121
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
122
+ )
123
+ dataset_config_name: Optional[str] = field(
124
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
125
+ )
126
+ train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
127
+ validation_file: Optional[str] = field(
128
+ default=None,
129
+ metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
130
+ )
131
+ train_ref_file: Optional[str] = field(
132
+ default=None,
133
+ metadata={"help": "An optional input train ref data file for whole word masking in Chinese."},
134
+ )
135
+ validation_ref_file: Optional[str] = field(
136
+ default=None,
137
+ metadata={"help": "An optional input validation ref data file for whole word masking in Chinese."},
138
+ )
139
+ overwrite_cache: bool = field(
140
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
141
+ )
142
+ validation_split_percentage: Optional[int] = field(
143
+ default=5,
144
+ metadata={
145
+ "help": "The percentage of the train set used as validation set in case there's no validation split"
146
+ },
147
+ )
148
+ max_seq_length: Optional[int] = field(
149
+ default=None,
150
+ metadata={
151
+ "help": "The maximum total input sequence length after tokenization. Sequences longer "
152
+ "than this will be truncated. Default to the max input length of the model."
153
+ },
154
+ )
155
+ preprocessing_num_workers: Optional[int] = field(
156
+ default=None,
157
+ metadata={"help": "The number of processes to use for the preprocessing."},
158
+ )
159
+ mlm_probability: float = field(
160
+ default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
161
+ )
162
+ pad_to_max_length: bool = field(
163
+ default=False,
164
+ metadata={
165
+ "help": "Whether to pad all samples to `max_seq_length`. "
166
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
167
+ },
168
+ )
169
+ line_by_line: bool = field(
170
+ default=False,
171
+ metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
172
+ )
173
+ text_column_name: str = field(
174
+ default="text", metadata={"help": "The name of the column to retrieve the training text."}
175
+ )
176
+ shuffle_buffer_size: int = field(
177
+ default=10000, metadata={"help": "The number of examples to pre-load for shuffling."}
178
+ )
179
+ num_train_steps: int = field(default=50000, metadata={"help": "The number of training steps."})
180
+ num_eval_samples: int = field(default=50000, metadata={"help": "The number of samples to be used for evaluation"})
181
+
182
+ def __post_init__(self):
183
+ if self.dataset_name is None and self.train_file is None and self.validation_file is None:
184
+ raise ValueError("Need either a dataset name or a training/validation file.")
185
+ else:
186
+ if self.train_file is not None:
187
+ extension = self.train_file.split(".")[-1]
188
+ assert extension in ["csv", "json", "jsonl", "txt", "gz"], "`train_file` should be a csv, a json (lines) or a txt file."
189
+ if self.validation_file is not None:
190
+ extension = self.validation_file.split(".")[-1]
191
+ assert extension in ["csv", "json", "jsonl", "txt", "gz"], "`validation_file` should be a csv, a json (lines) or a txt file."
192
+
193
+
194
+ @flax.struct.dataclass
195
+ class FlaxDataCollatorForLanguageModeling:
196
+ """
197
+ Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
198
+ are not all of the same length.
199
+
200
+ Args:
201
+ tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
202
+ The tokenizer used for encoding the data.
203
+ mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
204
+ The probability with which to (randomly) mask tokens in the input.
205
+
206
+ .. note::
207
+
208
+ For best performance, this data collator should be used with a dataset having items that are dictionaries or
209
+ BatchEncoding, with the :obj:`"special_tokens_mask"` key, as returned by a
210
+ :class:`~transformers.PreTrainedTokenizer` or a :class:`~transformers.PreTrainedTokenizerFast` with the
211
+ argument :obj:`return_special_tokens_mask=True`.
212
+ """
213
+
214
+ tokenizer: PreTrainedTokenizerBase
215
+ mlm_probability: float = 0.15
216
+
217
+ def __post_init__(self):
218
+ if self.tokenizer.mask_token is None:
219
+ raise ValueError(
220
+ "This tokenizer does not have a mask token which is necessary for masked language modeling. "
221
+ "You should pass `mlm=False` to train on causal language modeling instead."
222
+ )
223
+
224
+ def __call__(self, examples: List[Dict[str, np.ndarray]], pad_to_multiple_of: int) -> Dict[str, np.ndarray]:
225
+ # Handle dict or lists with proper padding and conversion to tensor.
226
+ batch = self.tokenizer.pad(examples, pad_to_multiple_of=pad_to_multiple_of, return_tensors=TensorType.NUMPY)
227
+
228
+ # If special token mask has been preprocessed, pop it from the dict.
229
+ special_tokens_mask = batch.pop("special_tokens_mask", None)
230
+
231
+ batch["input_ids"], batch["labels"] = self.mask_tokens(
232
+ batch["input_ids"], special_tokens_mask=special_tokens_mask
233
+ )
234
+ return batch
235
+
236
+ def mask_tokens(
237
+ self, inputs: np.ndarray, special_tokens_mask: Optional[np.ndarray]
238
+ ) -> Tuple[jnp.ndarray, jnp.ndarray]:
239
+ """
240
+ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
241
+ """
242
+ labels = inputs.copy()
243
+ # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
244
+ probability_matrix = np.full(labels.shape, self.mlm_probability)
245
+ special_tokens_mask = special_tokens_mask.astype("bool")
246
+
247
+ probability_matrix[special_tokens_mask] = 0.0
248
+ masked_indices = np.random.binomial(1, probability_matrix).astype("bool")
249
+ labels[~masked_indices] = -100 # We only compute loss on masked tokens
250
+
251
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
252
+ indices_replaced = np.random.binomial(1, np.full(labels.shape, 0.8)).astype("bool") & masked_indices
253
+ inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
254
+
255
+ # 10% of the time, we replace masked input tokens with random word
256
+ indices_random = np.random.binomial(1, np.full(labels.shape, 0.5)).astype("bool")
257
+ indices_random &= masked_indices & ~indices_replaced
258
+
259
+ random_words = np.random.randint(self.tokenizer.vocab_size, size=labels.shape, dtype="i4")
260
+ inputs[indices_random] = random_words[indices_random]
261
+
262
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
263
+ return inputs, labels
264
+
265
+
266
+ @dataclass
267
+ class SamplingArguments:
268
+ """
269
+ Arguments pertaining to how to perform sampling of the dataset.
270
+ """
271
+
272
+ perplexity_model: Optional[str] = field(
273
+ default="./es.arpa.bin", metadata={"help": "Path to KenLM model to use to get perplexity values."}
274
+ )
275
+ sampling_method: Optional[str] = field(
276
+ default=None, metadata={"help": "Sample using a 'step' or 'gaussian' perplexity function per document, or 'random'."}
277
+ )
278
+ sampling_factor: Optional[float] = field(
279
+ default=None, metadata={"help": "Sampling factor. Integers for step function, decimals for gaussian."}
280
+ )
281
+ boundaries: Optional[str] = field(
282
+ default="536394.99320948,662247.50212365,919250.87225178", metadata={"help": "Quartile boundaries"}
283
+ )
284
+
285
+ def __post_init__(self):
286
+ self.boundaries = [float(q.strip()) for q in self.boundaries.split(",")]
287
+
288
+
289
+ def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
290
+ num_samples = len(samples_idx)
291
+ samples_to_remove = num_samples % batch_size
292
+
293
+ if samples_to_remove != 0:
294
+ samples_idx = samples_idx[:-samples_to_remove]
295
+ sections_split = num_samples // batch_size
296
+ batch_idx = np.split(samples_idx, sections_split)
297
+ return batch_idx
298
+
299
+
300
+ def advance_iter_and_group_samples(train_iterator, num_samples, max_seq_length):
301
+ """
302
+ The training iterator is advanced so that after groupifying the samples,
303
+ `num_samples` of length `max_seq_length` are returned.
304
+ """
305
+ num_total_tokens = max_seq_length * num_samples
306
+ samples = defaultdict(list)
307
+
308
+ i = 0
309
+ while i < num_total_tokens:
310
+ tokenized_samples = next(train_iterator)
311
+ i += len(tokenized_samples["input_ids"])
312
+
313
+ # concatenate tokenized samples to list
314
+ samples = {k: samples[k] + tokenized_samples[k] for k in tokenized_samples.keys()}
315
+
316
+ # Concatenated tokens are split to lists of length `max_seq_length`.
317
+ # Note that remainedr of % max_seq_length are thrown away.
318
+ def group_texts(examples):
319
+ result = {
320
+ k: [t[i : i + max_seq_length] for i in range(0, num_total_tokens, max_seq_length)]
321
+ for k, t in examples.items()
322
+ }
323
+ return result
324
+
325
+ grouped_samples = group_texts(samples)
326
+ return grouped_samples
327
+
328
+
329
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
330
+ summary_writer.scalar("train_time", train_time, step)
331
+
332
+ train_metrics = get_metrics(train_metrics)
333
+ for key, vals in train_metrics.items():
334
+ tag = f"train_{key}"
335
+ for i, val in enumerate(vals):
336
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
337
+
338
+
339
+ def write_eval_metric(summary_writer, eval_metrics, step):
340
+ for metric_name, value in eval_metrics.items():
341
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
342
+
343
+
344
+ def save_checkpoint_files(state, data_collator, training_args, save_dir):
345
+ unreplicated_state = jax_utils.unreplicate(state)
346
+ with open(os.path.join(save_dir, "optimizer_state.msgpack"), "wb") as f:
347
+ f.write(to_bytes(unreplicated_state.opt_state))
348
+ joblib.dump(training_args, os.path.join(save_dir, "training_args.joblib"))
349
+ joblib.dump(data_collator, os.path.join(save_dir, "data_collator.joblib"))
350
+ with open(os.path.join(save_dir, "training_state.json"), "w") as f:
351
+ json.dump({"step": unreplicated_state.step.item()}, f)
352
+
353
+
354
+ def restore_checkpoint(save_dir, state):
355
+ logger.info(f"Restoring checkpoint from {save_dir}")
356
+ with open(os.path.join(save_dir, "flax_model.msgpack"), "rb") as f:
357
+ params = from_bytes(state.params, f.read())
358
+
359
+ with open(os.path.join(save_dir, "optimizer_state.msgpack"), "rb") as f:
360
+ opt_state = from_bytes(state.opt_state, f.read())
361
+
362
+ args = joblib.load(os.path.join(save_dir, "training_args.joblib"))
363
+ data_collator = joblib.load(os.path.join(save_dir, "data_collator.joblib"))
364
+
365
+ with open(os.path.join(save_dir, "training_state.json"), "r") as f:
366
+ training_state = json.load(f)
367
+ step = training_state["step"]
368
+
369
+ return params, opt_state, step, args, data_collator
370
+
371
+
372
+ def rotate_checkpoints(path, max_checkpoints=5):
373
+ paths = sorted(Path(path).iterdir(), key=os.path.getmtime)[::-1]
374
+ if len(paths) > max_checkpoints:
375
+ for path_to_delete in paths[max_checkpoints:]:
376
+ try:
377
+ shutil.rmtree(path_to_delete)
378
+ except OSError:
379
+ os.remove(path_to_delete)
380
+
381
+
382
+ def to_f32(t):
383
+ return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
384
+
385
+
386
+ def convert(output_dir, destination_dir="./"):
387
+ shutil.copyfile(Path(output_dir) / "flax_model.msgpack", Path(destination_dir) / "flax_model.msgpack")
388
+ shutil.copyfile(Path(output_dir) / "config.json", Path(destination_dir) / "config.json")
389
+ # Saving extra files from config.json and tokenizer.json files
390
+ tokenizer = AutoTokenizer.from_pretrained(destination_dir)
391
+ tokenizer.save_pretrained(destination_dir)
392
+
393
+ # Temporary saving bfloat16 Flax model into float32
394
+ tmp = tempfile.mkdtemp()
395
+ flax_model = FlaxRobertaForMaskedLM.from_pretrained(destination_dir)
396
+ flax_model.params = to_f32(flax_model.params)
397
+ flax_model.save_pretrained(tmp)
398
+ # Converting float32 Flax to PyTorch
399
+ model = RobertaForMaskedLM.from_pretrained(tmp, from_flax=True)
400
+ model.save_pretrained(destination_dir, save_config=False)
401
+
402
+
403
+ if __name__ == "__main__":
404
+ # See all possible arguments in src/transformers/training_args.py
405
+ # or by passing the --help flag to this script.
406
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
407
+
408
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments, SamplingArguments))
409
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
410
+ # If we pass only one argument to the script and it's the path to a json file,
411
+ # let's parse it to get our arguments.
412
+ model_args, data_args, training_args, sampling_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
413
+ else:
414
+ model_args, data_args, training_args, sampling_args = parser.parse_args_into_dataclasses()
415
+
416
+ if (
417
+ os.path.exists(training_args.output_dir)
418
+ and os.listdir(training_args.output_dir)
419
+ and training_args.do_train
420
+ and not training_args.overwrite_output_dir
421
+ ):
422
+ raise ValueError(
423
+ f"Output directory ({training_args.output_dir}) already exists and is not empty."
424
+ "Use --overwrite_output_dir to overcome."
425
+ )
426
+
427
+ # Setup logging
428
+ logging.basicConfig(
429
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
430
+ level="INFO",
431
+ datefmt="[%X]",
432
+ )
433
+
434
+ # Log on each process the small summary:
435
+ logger = logging.getLogger(__name__)
436
+ logger.warning(
437
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
438
+ + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
439
+ )
440
+
441
+ # Set the verbosity to info of the Transformers logger (on main process only):
442
+ logger.info(f"Training/evaluation parameters {training_args}")
443
+
444
+ # Set seed before initializing model.
445
+ set_seed(training_args.seed)
446
+
447
+ # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
448
+ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
449
+ # (the dataset will be downloaded automatically from the datasets Hub).
450
+ #
451
+ # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
452
+ # 'text' is found. You can easily tweak this behavior (see below).
453
+ if data_args.dataset_name is not None:
454
+ # Downloading and loading a dataset from the hub.
455
+ filepaths = {}
456
+ if data_args.train_file:
457
+ filepaths["train"] = data_args.train_file
458
+ if data_args.validation_file:
459
+ filepaths["validation"] = data_args.validation_file
460
+ try:
461
+ dataset = load_dataset(
462
+ data_args.dataset_name,
463
+ data_args.dataset_config_name,
464
+ cache_dir=model_args.cache_dir,
465
+ streaming=True,
466
+ split="train",
467
+ sampling_method=sampling_args.sampling_method,
468
+ sampling_factor=sampling_args.sampling_factor,
469
+ boundaries=sampling_args.boundaries,
470
+ perplexity_model=sampling_args.perplexity_model,
471
+ seed=training_args.seed,
472
+ data_files=filepaths,
473
+ )
474
+ except Exception as exc:
475
+ logger.warning(
476
+ f"Unable to load local dataset with perplexity sampling support. Using huggingface.co/datasets/{data_args.dataset_name}: {exc}"
477
+ )
478
+ dataset = load_dataset(
479
+ data_args.dataset_name,
480
+ data_args.dataset_config_name,
481
+ cache_dir=model_args.cache_dir,
482
+ streaming=True,
483
+ split="train",
484
+ )
485
+
486
+ if model_args.config_name:
487
+ config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
488
+ elif model_args.model_name_or_path:
489
+ config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
490
+ else:
491
+ config = CONFIG_MAPPING[model_args.model_type]()
492
+ logger.warning("You are instantiating a new config instance from scratch.")
493
+
494
+ if model_args.tokenizer_name:
495
+ tokenizer = AutoTokenizer.from_pretrained(
496
+ model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
497
+ )
498
+ elif model_args.model_name_or_path:
499
+ tokenizer = AutoTokenizer.from_pretrained(
500
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
501
+ )
502
+ else:
503
+ raise ValueError(
504
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script."
505
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
506
+ )
507
+
508
+ # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
509
+ # We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
510
+ # efficient when it receives the `special_tokens_mask`.
511
+ def tokenize_function(examples):
512
+ return tokenizer(
513
+ examples[data_args.text_column_name],
514
+ return_special_tokens_mask=True
515
+ )
516
+
517
+ tokenized_datasets = dataset.map(
518
+ tokenize_function,
519
+ batched=True,
520
+ )
521
+
522
+ shuffle_seed = training_args.seed
523
+ tokenized_datasets = tokenized_datasets.shuffle(buffer_size=data_args.shuffle_buffer_size, seed=shuffle_seed)
524
+
525
+ # Enable tensorboard only on the master node
526
+ has_tensorboard = is_tensorboard_available()
527
+ if has_tensorboard and jax.process_index() == 0:
528
+ try:
529
+ # Enable Weight&Biases
530
+ import wandb
531
+ wandb.init(
532
+ entity='wandb',
533
+ project='hf-flax-bertin-roberta-es',
534
+ sync_tensorboard=True,
535
+ )
536
+ wandb.config.update(training_args)
537
+ wandb.config.update(model_args)
538
+ wandb.config.update(data_args)
539
+ from flax.metrics.tensorboard import SummaryWriter
540
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
541
+ except ImportError as ie:
542
+ has_tensorboard = False
543
+ logger.warning(
544
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
545
+ )
546
+ else:
547
+ logger.warning(
548
+ "Unable to display metrics through TensorBoard because the package is not installed: "
549
+ "Please run pip install tensorboard to enable."
550
+ )
551
+
552
+ # Data collator
553
+ # This one will take care of randomly masking the tokens.
554
+ data_collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
555
+
556
+ # Initialize our training
557
+ rng = jax.random.PRNGKey(training_args.seed)
558
+ dropout_rngs = jax.random.split(rng, jax.local_device_count())
559
+
560
+ if model_args.model_name_or_path:
561
+ model = FlaxAutoModelForMaskedLM.from_pretrained(
562
+ model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
563
+ )
564
+ else:
565
+ model = FlaxAutoModelForMaskedLM.from_config(
566
+ config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
567
+ )
568
+
569
+ # Store some constant
570
+ num_epochs = int(training_args.num_train_epochs)
571
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
572
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
573
+
574
+ # define number steps per stream epoch
575
+ num_train_steps = data_args.num_train_steps
576
+
577
+ # Create learning rate schedule
578
+ warmup_fn = optax.linear_schedule(
579
+ init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
580
+ )
581
+ decay_fn = optax.linear_schedule(
582
+ init_value=training_args.learning_rate,
583
+ end_value=0,
584
+ transition_steps=num_train_steps - training_args.warmup_steps,
585
+ )
586
+ linear_decay_lr_schedule_fn = optax.join_schedules(
587
+ schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
588
+ )
589
+
590
+ # We use Optax's "masking" functionality to not apply weight decay
591
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
592
+ # mask boolean with the same structure as the parameters.
593
+ # The mask is True for parameters that should be decayed.
594
+ # Note that this mask is specifically adapted for FlaxBERT-like models.
595
+ # For other models, one should correct the layer norm parameter naming
596
+ # accordingly.
597
+ def decay_mask_fn(params):
598
+ flat_params = traverse_util.flatten_dict(params)
599
+ flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
600
+ return traverse_util.unflatten_dict(flat_mask)
601
+
602
+ # create adam optimizer
603
+ adamw = optax.adamw(
604
+ learning_rate=linear_decay_lr_schedule_fn,
605
+ b1=training_args.adam_beta1,
606
+ b2=training_args.adam_beta2,
607
+ eps=training_args.adam_epsilon,
608
+ weight_decay=training_args.weight_decay,
609
+ mask=decay_mask_fn,
610
+ )
611
+
612
+ # Setup train state
613
+ state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)
614
+ saved_step = -1
615
+ if model_args.model_name_or_path and "checkpoint" in model_args.model_name_or_path:
616
+ params, opt_state, saved_step, args, data_collator = restore_checkpoint(model_args.model_name_or_path, state)
617
+ # Create learning rate schedule
618
+ warmup_fn = optax.linear_schedule(
619
+ init_value=0.0, end_value=args.learning_rate, transition_steps=args.warmup_steps
620
+ )
621
+ decay_fn = optax.linear_schedule(
622
+ init_value=args.learning_rate,
623
+ end_value=0,
624
+ transition_steps=data_args.num_train_steps - args.warmup_steps,
625
+ )
626
+ linear_decay_lr_schedule_fn = optax.join_schedules(
627
+ schedules=[warmup_fn, decay_fn], boundaries=[args.warmup_steps]
628
+ )
629
+ # create adam optimizer
630
+ adamw = optax.adamw(
631
+ learning_rate=linear_decay_lr_schedule_fn,
632
+ b1=training_args.adam_beta1,
633
+ b2=training_args.adam_beta2,
634
+ eps=training_args.adam_epsilon,
635
+ weight_decay=args.weight_decay,
636
+ mask=decay_mask_fn,
637
+ )
638
+ state = train_state.TrainState(
639
+ step=saved_step,
640
+ apply_fn=model.__call__,
641
+ params=params,
642
+ tx=adamw,
643
+ opt_state=opt_state,
644
+ )
645
+ # self.args = args
646
+ # data_collator = data_collator
647
+ # scheduler_fn = args.learning_rate
648
+ model.params = params
649
+
650
+
651
+ # Define gradient update step fn
652
+ def train_step(state, batch, dropout_rng):
653
+ dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
654
+
655
+ def loss_fn(params):
656
+ labels = batch.pop("labels")
657
+
658
+ logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
659
+
660
+ # compute loss, ignore padded input tokens
661
+ label_mask = jnp.where(labels > 0, 1.0, 0.0)
662
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
663
+
664
+ # take average
665
+ loss = loss.sum() / label_mask.sum()
666
+
667
+ return loss
668
+
669
+ grad_fn = jax.value_and_grad(loss_fn)
670
+ loss, grad = grad_fn(state.params)
671
+ grad = jax.lax.pmean(grad, "batch")
672
+ new_state = state.apply_gradients(grads=grad)
673
+
674
+ metrics = jax.lax.pmean(
675
+ {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
676
+ )
677
+
678
+ return new_state, metrics, new_dropout_rng
679
+
680
+ # Create parallel version of the train step
681
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
682
+
683
+ # Define eval fn
684
+ def eval_step(params, batch):
685
+ labels = batch.pop("labels")
686
+
687
+ logits = model(**batch, params=params, train=False)[0]
688
+
689
+ # compute loss, ignore padded input tokens
690
+ label_mask = jnp.where(labels > 0, 1.0, 0.0)
691
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
692
+
693
+ # compute accuracy
694
+ accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels) * label_mask
695
+
696
+ # summarize metrics
697
+ metrics = {"loss": loss.sum(), "accuracy": accuracy.sum(), "normalizer": label_mask.sum()}
698
+ metrics = jax.lax.psum(metrics, axis_name="batch")
699
+
700
+ return metrics
701
+
702
+ p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))
703
+
704
+ # Replicate the train state on each device
705
+ state = jax_utils.replicate(state)
706
+
707
+ train_time = 0
708
+ train_start = time.time()
709
+ train_metrics = []
710
+ eval_metrics = []
711
+
712
+ training_iter = iter(tokenized_datasets)
713
+
714
+ max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
715
+ eval_samples = advance_iter_and_group_samples(training_iter, data_args.num_eval_samples, max_seq_length)
716
+
717
+ last_desc = ""
718
+ steps = tqdm(range(num_train_steps), desc="Training...", position=0)
719
+ for step in range(num_train_steps):
720
+ if step < saved_step:
721
+ steps.update(1)
722
+ continue
723
+ # ======================== Training ================================
724
+ try:
725
+ samples = advance_iter_and_group_samples(training_iter, train_batch_size, max_seq_length)
726
+ except StopIteration:
727
+ # Once the end of the dataset stream is reached, the training iterator
728
+ # is reinitialized and reshuffled and a new eval dataset is randomely chosen.
729
+ shuffle_seed += 1
730
+ tokenized_datasets.set_epoch(shuffle_seed)
731
+
732
+ training_iter = iter(tokenized_datasets)
733
+
734
+ eval_dataset = advance_iter_and_group_samples(training_iter, data_args.num_eval_samples, max_seq_length)
735
+ samples = advance_iter_and_group_samples(training_iter, train_batch_size, max_seq_length)
736
+
737
+ # process input samples
738
+ model_inputs = data_collator(samples, pad_to_multiple_of=16)
739
+
740
+ # Model forward
741
+ model_inputs = shard(model_inputs.data)
742
+ state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
743
+
744
+ train_metrics.append(train_metric)
745
+
746
+ if step % training_args.logging_steps == 0 and step > 0:
747
+ steps.write(
748
+ f"Step... ({step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
749
+ )
750
+ train_time += time.time() - train_start
751
+ if has_tensorboard and jax.process_index() == 0:
752
+ write_train_metric(summary_writer, train_metrics, train_time, step)
753
+ train_metrics = []
754
+
755
+ # ======================== Evaluating ==============================
756
+ if step % training_args.eval_steps == 0 and step > 0:
757
+ eval_samples_idx = jnp.arange(data_args.num_eval_samples)
758
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
759
+
760
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=1)):
761
+ # process input samples
762
+ batch_eval_samples = {k: [v[idx] for idx in batch_idx] for k, v in eval_samples.items()}
763
+ model_inputs = data_collator(batch_eval_samples, pad_to_multiple_of=16)
764
+
765
+ # Model forward
766
+ model_inputs = shard(model_inputs.data)
767
+ metrics = p_eval_step(state.params, model_inputs)
768
+ eval_metrics.append(metrics)
769
+
770
+ # normalize eval metrics
771
+ eval_metrics = get_metrics(eval_metrics)
772
+ eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
773
+ eval_normalizer = eval_metrics.pop("normalizer")
774
+ eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
775
+
776
+ # Update progress bar
777
+ steps.desc = f"Step... ({step}/{num_train_steps} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
778
+ last_desc = steps.desc
779
+
780
+ if has_tensorboard and jax.process_index() == 0:
781
+ write_eval_metric(summary_writer, eval_metrics, step)
782
+ eval_metrics = []
783
+
784
+ # save checkpoint after eval_steps
785
+ if step % training_args.save_steps == 0 and step > 0 and jax.process_index() == 0:
786
+ logger.info(f"Saving checkpoint at {step} steps")
787
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
788
+ model.save_pretrained(
789
+ training_args.output_dir,
790
+ params=params,
791
+ push_to_hub=False,
792
+ )
793
+ save_checkpoint_files(state, data_collator, training_args, training_args.output_dir)
794
+ checkpoints_dir = Path(training_args.output_dir) / "checkpoints" / f"checkpoint-{step}"
795
+ checkpoints_dir.mkdir(parents=True, exist_ok=True)
796
+ model.save_pretrained(checkpoints_dir, params=params)
797
+ save_checkpoint_files(state, data_collator, training_args, checkpoints_dir)
798
+ rotate_checkpoints(
799
+ Path(training_args.output_dir) / "checkpoints",
800
+ max_checkpoints=training_args.save_total_limit
801
+ )
802
+ convert(training_args.output_dir, "./")
803
+ model.save_pretrained(
804
+ training_args.output_dir,
805
+ params=params,
806
+ push_to_hub=training_args.push_to_hub,
807
+ commit_message=last_desc,
808
+ )
809
+
810
+ # update tqdm bar
811
+ steps.update(1)
812
+
813
+ if jax.process_index() == 0:
814
+ logger.info(f"Saving checkpoint at {step} steps")
815
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
816
+ model.save_pretrained(
817
+ training_args.output_dir,
818
+ params=params,
819
+ push_to_hub=False,
820
+ )
821
+ save_checkpoint_files(state, data_collator, training_args, training_args.output_dir)
822
+ checkpoints_dir = Path(training_args.output_dir) / "checkpoints" / f"checkpoint-{step}"
823
+ checkpoints_dir.mkdir(parents=True, exist_ok=True)
824
+ model.save_pretrained(checkpoints_dir, params=params)
825
+ save_checkpoint_files(state, data_collator, training_args, checkpoints_dir)
826
+ convert(training_args.output_dir, "./")
827
+ model.save_pretrained(
828
+ training_args.output_dir,
829
+ params=params,
830
+ push_to_hub=training_args.push_to_hub,
831
+ commit_message=last_desc or "Saving model after training",
832
+ )
run_stream.128.sh ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # From https://arxiv.org/pdf/1907.11692.pdf for base model
2
+ python -c "import jax; print('TPUs', jax.device_count())"
3
+ python ./run_mlm_flax_stream.py \
4
+ --output_dir="./outputs" \
5
+ --model_type="roberta" \
6
+ --config_name="./configs/base" \
7
+ --tokenizer_name="./configs/base" \
8
+ --dataset_name="./mc4" \
9
+ --dataset_config_name="es" \
10
+ --train_file="../mc4-es-train-50M-random.jsonl" \
11
+ --max_seq_length="128" \
12
+ --pad_to_max_length \
13
+ --per_device_train_batch_size="256" \
14
+ --per_device_eval_batch_size="256" \
15
+ --adam_beta1="0.9" \
16
+ --adam_beta2="0.98" \
17
+ --adam_epsilon="1e-6" \
18
+ --learning_rate="6e-4" \
19
+ --weight_decay="0.01" \
20
+ --save_steps="10000" \
21
+ --save_total_limit="5" \
22
+ --warmup_steps="24000" \
23
+ --overwrite_output_dir \
24
+ --num_train_steps="250000" \
25
+ --eval_steps="10000" \
26
+ --dtype="bfloat16" \
27
+ --logging_steps="500" 2>&1 | tee run_stream.log
run_stream.512.log ADDED
The diff for this file is too large to render. See raw diff
 
run_stream.512.sh ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # From https://arxiv.org/pdf/1907.11692.pdf for base model
2
+ python -c "import jax; print('TPUs', jax.device_count())"
3
+ python ./run_mlm_flax_stream.py \
4
+ --model_name_or_path="bertin-project/bertin-base-stepwise" \
5
+ --output_dir="./outputs" \
6
+ --model_type="roberta" \
7
+ --config_name="./configs/base" \
8
+ --tokenizer_name="./configs/base" \
9
+ --dataset_name="bertin-project/mc4-es-sampled" \
10
+ --dataset_config_name="stepwise" \
11
+ --max_seq_length="512" \
12
+ --pad_to_max_length \
13
+ --per_device_train_batch_size="48" \
14
+ --per_device_eval_batch_size="48" \
15
+ --adam_beta1="0.9" \
16
+ --adam_beta2="0.98" \
17
+ --adam_epsilon="1e-6" \
18
+ --learning_rate="6e-4" \
19
+ --weight_decay="0.01" \
20
+ --save_steps="1000" \
21
+ --save_total_limit="5" \
22
+ --warmup_steps="500" \
23
+ --overwrite_output_dir \
24
+ --num_train_steps="50000" \
25
+ --eval_steps="1000" \
26
+ --dtype="bfloat16" \
27
+ --logging_steps="500" 2>&1 | tee run_stream.512.log
special_tokens_map.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"bos_token": "<s>", "eos_token": "</s>", "unk_token": "<unk>", "sep_token": "</s>", "pad_token": "<pad>", "cls_token": "<s>", "mask_token": {"content": "<mask>", "single_word": false, "lstrip": true, "rstrip": false, "normalized": false}}
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>", "add_prefix_space": false, "errors": "replace", "sep_token": "</s>", "cls_token": "<s>", "pad_token": "<pad>", "mask_token": "<mask>", "special_tokens_map_file": null, "name_or_path": "./", "tokenizer_class": "RobertaTokenizer"}
vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
wandb/debug-internal.log ADDED
@@ -0,0 +1 @@
 
 
1
+ run-20210726_001233-17u6inbn/logs/debug-internal.log
wandb/debug.log ADDED
@@ -0,0 +1 @@
 
 
1
+ run-20210726_001233-17u6inbn/logs/debug.log
wandb/latest-run ADDED
@@ -0,0 +1 @@
 
 
1
+ run-20210726_001233-17u6inbn
wandb/run-20210726_001233-17u6inbn/files/code/run_mlm_flax_stream.py ADDED
@@ -0,0 +1,832 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2021 The HuggingFace Team All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """
17
+ Fine-tuning the library models for masked language modeling (BERT, ALBERT, RoBERTa...) with whole word masking on a
18
+ text file or a dataset.
19
+
20
+ Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
21
+ https://huggingface.co/models?filter=masked-lm
22
+ """
23
+ import logging
24
+ import json
25
+ import os
26
+ import shutil
27
+ import sys
28
+ import tempfile
29
+ import time
30
+ from collections import defaultdict
31
+ from dataclasses import dataclass, field
32
+
33
+ # You can also adapt this script on your own masked language modeling task. Pointers for this are left as comments.
34
+ import joblib
35
+ from pathlib import Path
36
+ from typing import Dict, List, Optional, Tuple
37
+
38
+ import datasets
39
+ import numpy as np
40
+ from datasets import load_dataset
41
+ from tqdm import tqdm
42
+
43
+ import flax
44
+ import jax
45
+ import jax.numpy as jnp
46
+ import kenlm # pip install https://github.com/kpu/kenlm/archive/master.zip
47
+ import optax
48
+ from flax import jax_utils, traverse_util
49
+ from flax.serialization import from_bytes, to_bytes
50
+ from flax.training import train_state
51
+ from flax.training.common_utils import get_metrics, onehot, shard
52
+ from transformers import (
53
+ CONFIG_MAPPING,
54
+ FLAX_MODEL_FOR_MASKED_LM_MAPPING,
55
+ AutoConfig,
56
+ AutoTokenizer,
57
+ FlaxAutoModelForMaskedLM,
58
+ HfArgumentParser,
59
+ PreTrainedTokenizerBase,
60
+ TensorType,
61
+ TrainingArguments,
62
+ is_tensorboard_available,
63
+ set_seed,
64
+ FlaxRobertaForMaskedLM,
65
+ RobertaForMaskedLM,
66
+ )
67
+
68
+
69
+ if datasets.__version__ <= "1.8.0":
70
+ raise ValueError("Make sure to upgrade `datasets` to a version >= 1.9.0 to use dataset streaming")
71
+
72
+
73
+ MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
74
+ MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
75
+
76
+
77
+ @dataclass
78
+ class ModelArguments:
79
+ """
80
+ Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
81
+ """
82
+
83
+ model_name_or_path: Optional[str] = field(
84
+ default=None,
85
+ metadata={
86
+ "help": "The model checkpoint for weights initialization."
87
+ "Don't set if you want to train a model from scratch."
88
+ },
89
+ )
90
+ model_type: Optional[str] = field(
91
+ default=None,
92
+ metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
93
+ )
94
+ config_name: Optional[str] = field(
95
+ default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
96
+ )
97
+ tokenizer_name: Optional[str] = field(
98
+ default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
99
+ )
100
+ cache_dir: Optional[str] = field(
101
+ default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
102
+ )
103
+ use_fast_tokenizer: bool = field(
104
+ default=True,
105
+ metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
106
+ )
107
+ dtype: Optional[str] = field(
108
+ default="float32",
109
+ metadata={
110
+ "help": "Floating-point format in which the model weights should be initialized and trained. Choose one of `[float32, float16, bfloat16]`."
111
+ },
112
+ )
113
+
114
+ @dataclass
115
+ class DataTrainingArguments:
116
+ """
117
+ Arguments pertaining to what data we are going to input our model for training and eval.
118
+ """
119
+
120
+ dataset_name: Optional[str] = field(
121
+ default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
122
+ )
123
+ dataset_config_name: Optional[str] = field(
124
+ default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
125
+ )
126
+ train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
127
+ validation_file: Optional[str] = field(
128
+ default=None,
129
+ metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
130
+ )
131
+ train_ref_file: Optional[str] = field(
132
+ default=None,
133
+ metadata={"help": "An optional input train ref data file for whole word masking in Chinese."},
134
+ )
135
+ validation_ref_file: Optional[str] = field(
136
+ default=None,
137
+ metadata={"help": "An optional input validation ref data file for whole word masking in Chinese."},
138
+ )
139
+ overwrite_cache: bool = field(
140
+ default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
141
+ )
142
+ validation_split_percentage: Optional[int] = field(
143
+ default=5,
144
+ metadata={
145
+ "help": "The percentage of the train set used as validation set in case there's no validation split"
146
+ },
147
+ )
148
+ max_seq_length: Optional[int] = field(
149
+ default=None,
150
+ metadata={
151
+ "help": "The maximum total input sequence length after tokenization. Sequences longer "
152
+ "than this will be truncated. Default to the max input length of the model."
153
+ },
154
+ )
155
+ preprocessing_num_workers: Optional[int] = field(
156
+ default=None,
157
+ metadata={"help": "The number of processes to use for the preprocessing."},
158
+ )
159
+ mlm_probability: float = field(
160
+ default=0.15, metadata={"help": "Ratio of tokens to mask for masked language modeling loss"}
161
+ )
162
+ pad_to_max_length: bool = field(
163
+ default=False,
164
+ metadata={
165
+ "help": "Whether to pad all samples to `max_seq_length`. "
166
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
167
+ },
168
+ )
169
+ line_by_line: bool = field(
170
+ default=False,
171
+ metadata={"help": "Whether distinct lines of text in the dataset are to be handled as distinct sequences."},
172
+ )
173
+ text_column_name: str = field(
174
+ default="text", metadata={"help": "The name of the column to retrieve the training text."}
175
+ )
176
+ shuffle_buffer_size: int = field(
177
+ default=10000, metadata={"help": "The number of examples to pre-load for shuffling."}
178
+ )
179
+ num_train_steps: int = field(default=50000, metadata={"help": "The number of training steps."})
180
+ num_eval_samples: int = field(default=50000, metadata={"help": "The number of samples to be used for evaluation"})
181
+
182
+ def __post_init__(self):
183
+ if self.dataset_name is None and self.train_file is None and self.validation_file is None:
184
+ raise ValueError("Need either a dataset name or a training/validation file.")
185
+ else:
186
+ if self.train_file is not None:
187
+ extension = self.train_file.split(".")[-1]
188
+ assert extension in ["csv", "json", "jsonl", "txt", "gz"], "`train_file` should be a csv, a json (lines) or a txt file."
189
+ if self.validation_file is not None:
190
+ extension = self.validation_file.split(".")[-1]
191
+ assert extension in ["csv", "json", "jsonl", "txt", "gz"], "`validation_file` should be a csv, a json (lines) or a txt file."
192
+
193
+
194
+ @flax.struct.dataclass
195
+ class FlaxDataCollatorForLanguageModeling:
196
+ """
197
+ Data collator used for language modeling. Inputs are dynamically padded to the maximum length of a batch if they
198
+ are not all of the same length.
199
+
200
+ Args:
201
+ tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
202
+ The tokenizer used for encoding the data.
203
+ mlm_probability (:obj:`float`, `optional`, defaults to 0.15):
204
+ The probability with which to (randomly) mask tokens in the input.
205
+
206
+ .. note::
207
+
208
+ For best performance, this data collator should be used with a dataset having items that are dictionaries or
209
+ BatchEncoding, with the :obj:`"special_tokens_mask"` key, as returned by a
210
+ :class:`~transformers.PreTrainedTokenizer` or a :class:`~transformers.PreTrainedTokenizerFast` with the
211
+ argument :obj:`return_special_tokens_mask=True`.
212
+ """
213
+
214
+ tokenizer: PreTrainedTokenizerBase
215
+ mlm_probability: float = 0.15
216
+
217
+ def __post_init__(self):
218
+ if self.tokenizer.mask_token is None:
219
+ raise ValueError(
220
+ "This tokenizer does not have a mask token which is necessary for masked language modeling. "
221
+ "You should pass `mlm=False` to train on causal language modeling instead."
222
+ )
223
+
224
+ def __call__(self, examples: List[Dict[str, np.ndarray]], pad_to_multiple_of: int) -> Dict[str, np.ndarray]:
225
+ # Handle dict or lists with proper padding and conversion to tensor.
226
+ batch = self.tokenizer.pad(examples, pad_to_multiple_of=pad_to_multiple_of, return_tensors=TensorType.NUMPY)
227
+
228
+ # If special token mask has been preprocessed, pop it from the dict.
229
+ special_tokens_mask = batch.pop("special_tokens_mask", None)
230
+
231
+ batch["input_ids"], batch["labels"] = self.mask_tokens(
232
+ batch["input_ids"], special_tokens_mask=special_tokens_mask
233
+ )
234
+ return batch
235
+
236
+ def mask_tokens(
237
+ self, inputs: np.ndarray, special_tokens_mask: Optional[np.ndarray]
238
+ ) -> Tuple[jnp.ndarray, jnp.ndarray]:
239
+ """
240
+ Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
241
+ """
242
+ labels = inputs.copy()
243
+ # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
244
+ probability_matrix = np.full(labels.shape, self.mlm_probability)
245
+ special_tokens_mask = special_tokens_mask.astype("bool")
246
+
247
+ probability_matrix[special_tokens_mask] = 0.0
248
+ masked_indices = np.random.binomial(1, probability_matrix).astype("bool")
249
+ labels[~masked_indices] = -100 # We only compute loss on masked tokens
250
+
251
+ # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
252
+ indices_replaced = np.random.binomial(1, np.full(labels.shape, 0.8)).astype("bool") & masked_indices
253
+ inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
254
+
255
+ # 10% of the time, we replace masked input tokens with random word
256
+ indices_random = np.random.binomial(1, np.full(labels.shape, 0.5)).astype("bool")
257
+ indices_random &= masked_indices & ~indices_replaced
258
+
259
+ random_words = np.random.randint(self.tokenizer.vocab_size, size=labels.shape, dtype="i4")
260
+ inputs[indices_random] = random_words[indices_random]
261
+
262
+ # The rest of the time (10% of the time) we keep the masked input tokens unchanged
263
+ return inputs, labels
264
+
265
+
266
+ @dataclass
267
+ class SamplingArguments:
268
+ """
269
+ Arguments pertaining to how to perform sampling of the dataset.
270
+ """
271
+
272
+ perplexity_model: Optional[str] = field(
273
+ default="./es.arpa.bin", metadata={"help": "Path to KenLM model to use to get perplexity values."}
274
+ )
275
+ sampling_method: Optional[str] = field(
276
+ default=None, metadata={"help": "Sample using a 'step' or 'gaussian' perplexity function per document, or 'random'."}
277
+ )
278
+ sampling_factor: Optional[float] = field(
279
+ default=None, metadata={"help": "Sampling factor. Integers for step function, decimals for gaussian."}
280
+ )
281
+ boundaries: Optional[str] = field(
282
+ default="536394.99320948,662247.50212365,919250.87225178", metadata={"help": "Quartile boundaries"}
283
+ )
284
+
285
+ def __post_init__(self):
286
+ self.boundaries = [float(q.strip()) for q in self.boundaries.split(",")]
287
+
288
+
289
+ def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
290
+ num_samples = len(samples_idx)
291
+ samples_to_remove = num_samples % batch_size
292
+
293
+ if samples_to_remove != 0:
294
+ samples_idx = samples_idx[:-samples_to_remove]
295
+ sections_split = num_samples // batch_size
296
+ batch_idx = np.split(samples_idx, sections_split)
297
+ return batch_idx
298
+
299
+
300
+ def advance_iter_and_group_samples(train_iterator, num_samples, max_seq_length):
301
+ """
302
+ The training iterator is advanced so that after groupifying the samples,
303
+ `num_samples` of length `max_seq_length` are returned.
304
+ """
305
+ num_total_tokens = max_seq_length * num_samples
306
+ samples = defaultdict(list)
307
+
308
+ i = 0
309
+ while i < num_total_tokens:
310
+ tokenized_samples = next(train_iterator)
311
+ i += len(tokenized_samples["input_ids"])
312
+
313
+ # concatenate tokenized samples to list
314
+ samples = {k: samples[k] + tokenized_samples[k] for k in tokenized_samples.keys()}
315
+
316
+ # Concatenated tokens are split to lists of length `max_seq_length`.
317
+ # Note that remainedr of % max_seq_length are thrown away.
318
+ def group_texts(examples):
319
+ result = {
320
+ k: [t[i : i + max_seq_length] for i in range(0, num_total_tokens, max_seq_length)]
321
+ for k, t in examples.items()
322
+ }
323
+ return result
324
+
325
+ grouped_samples = group_texts(samples)
326
+ return grouped_samples
327
+
328
+
329
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
330
+ summary_writer.scalar("train_time", train_time, step)
331
+
332
+ train_metrics = get_metrics(train_metrics)
333
+ for key, vals in train_metrics.items():
334
+ tag = f"train_{key}"
335
+ for i, val in enumerate(vals):
336
+ summary_writer.scalar(tag, val, step - len(vals) + i + 1)
337
+
338
+
339
+ def write_eval_metric(summary_writer, eval_metrics, step):
340
+ for metric_name, value in eval_metrics.items():
341
+ summary_writer.scalar(f"eval_{metric_name}", value, step)
342
+
343
+
344
+ def save_checkpoint_files(state, data_collator, training_args, save_dir):
345
+ unreplicated_state = jax_utils.unreplicate(state)
346
+ with open(os.path.join(save_dir, "optimizer_state.msgpack"), "wb") as f:
347
+ f.write(to_bytes(unreplicated_state.opt_state))
348
+ joblib.dump(training_args, os.path.join(save_dir, "training_args.joblib"))
349
+ joblib.dump(data_collator, os.path.join(save_dir, "data_collator.joblib"))
350
+ with open(os.path.join(save_dir, "training_state.json"), "w") as f:
351
+ json.dump({"step": unreplicated_state.step.item()}, f)
352
+
353
+
354
+ def restore_checkpoint(save_dir, state):
355
+ logger.info(f"Restoring checkpoint from {save_dir}")
356
+ with open(os.path.join(save_dir, "flax_model.msgpack"), "rb") as f:
357
+ params = from_bytes(state.params, f.read())
358
+
359
+ with open(os.path.join(save_dir, "optimizer_state.msgpack"), "rb") as f:
360
+ opt_state = from_bytes(state.opt_state, f.read())
361
+
362
+ args = joblib.load(os.path.join(save_dir, "training_args.joblib"))
363
+ data_collator = joblib.load(os.path.join(save_dir, "data_collator.joblib"))
364
+
365
+ with open(os.path.join(save_dir, "training_state.json"), "r") as f:
366
+ training_state = json.load(f)
367
+ step = training_state["step"]
368
+
369
+ return params, opt_state, step, args, data_collator
370
+
371
+
372
+ def rotate_checkpoints(path, max_checkpoints=5):
373
+ paths = sorted(Path(path).iterdir(), key=os.path.getmtime)[::-1]
374
+ if len(paths) > max_checkpoints:
375
+ for path_to_delete in paths[max_checkpoints:]:
376
+ try:
377
+ shutil.rmtree(path_to_delete)
378
+ except OSError:
379
+ os.remove(path_to_delete)
380
+
381
+
382
+ def to_f32(t):
383
+ return jax.tree_map(lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x, t)
384
+
385
+
386
+ def convert(output_dir, destination_dir="./"):
387
+ shutil.copyfile(Path(output_dir) / "flax_model.msgpack", Path(destination_dir) / "flax_model.msgpack")
388
+ shutil.copyfile(Path(output_dir) / "config.json", Path(destination_dir) / "config.json")
389
+ # Saving extra files from config.json and tokenizer.json files
390
+ tokenizer = AutoTokenizer.from_pretrained(destination_dir)
391
+ tokenizer.save_pretrained(destination_dir)
392
+
393
+ # Temporary saving bfloat16 Flax model into float32
394
+ tmp = tempfile.mkdtemp()
395
+ flax_model = FlaxRobertaForMaskedLM.from_pretrained(destination_dir)
396
+ flax_model.params = to_f32(flax_model.params)
397
+ flax_model.save_pretrained(tmp)
398
+ # Converting float32 Flax to PyTorch
399
+ model = RobertaForMaskedLM.from_pretrained(tmp, from_flax=True)
400
+ model.save_pretrained(destination_dir, save_config=False)
401
+
402
+
403
+ if __name__ == "__main__":
404
+ # See all possible arguments in src/transformers/training_args.py
405
+ # or by passing the --help flag to this script.
406
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
407
+
408
+ parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments, SamplingArguments))
409
+ if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
410
+ # If we pass only one argument to the script and it's the path to a json file,
411
+ # let's parse it to get our arguments.
412
+ model_args, data_args, training_args, sampling_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
413
+ else:
414
+ model_args, data_args, training_args, sampling_args = parser.parse_args_into_dataclasses()
415
+
416
+ if (
417
+ os.path.exists(training_args.output_dir)
418
+ and os.listdir(training_args.output_dir)
419
+ and training_args.do_train
420
+ and not training_args.overwrite_output_dir
421
+ ):
422
+ raise ValueError(
423
+ f"Output directory ({training_args.output_dir}) already exists and is not empty."
424
+ "Use --overwrite_output_dir to overcome."
425
+ )
426
+
427
+ # Setup logging
428
+ logging.basicConfig(
429
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
430
+ level="INFO",
431
+ datefmt="[%X]",
432
+ )
433
+
434
+ # Log on each process the small summary:
435
+ logger = logging.getLogger(__name__)
436
+ logger.warning(
437
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
438
+ + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
439
+ )
440
+
441
+ # Set the verbosity to info of the Transformers logger (on main process only):
442
+ logger.info(f"Training/evaluation parameters {training_args}")
443
+
444
+ # Set seed before initializing model.
445
+ set_seed(training_args.seed)
446
+
447
+ # Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
448
+ # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
449
+ # (the dataset will be downloaded automatically from the datasets Hub).
450
+ #
451
+ # For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
452
+ # 'text' is found. You can easily tweak this behavior (see below).
453
+ if data_args.dataset_name is not None:
454
+ # Downloading and loading a dataset from the hub.
455
+ filepaths = {}
456
+ if data_args.train_file:
457
+ filepaths["train"] = data_args.train_file
458
+ if data_args.validation_file:
459
+ filepaths["validation"] = data_args.validation_file
460
+ try:
461
+ dataset = load_dataset(
462
+ data_args.dataset_name,
463
+ data_args.dataset_config_name,
464
+ cache_dir=model_args.cache_dir,
465
+ streaming=True,
466
+ split="train",
467
+ sampling_method=sampling_args.sampling_method,
468
+ sampling_factor=sampling_args.sampling_factor,
469
+ boundaries=sampling_args.boundaries,
470
+ perplexity_model=sampling_args.perplexity_model,
471
+ seed=training_args.seed,
472
+ data_files=filepaths,
473
+ )
474
+ except Exception as exc:
475
+ logger.warning(
476
+ f"Unable to load local dataset with perplexity sampling support. Using huggingface.co/datasets/{data_args.dataset_name}: {exc}"
477
+ )
478
+ dataset = load_dataset(
479
+ data_args.dataset_name,
480
+ data_args.dataset_config_name,
481
+ cache_dir=model_args.cache_dir,
482
+ streaming=True,
483
+ split="train",
484
+ )
485
+
486
+ if model_args.config_name:
487
+ config = AutoConfig.from_pretrained(model_args.config_name, cache_dir=model_args.cache_dir)
488
+ elif model_args.model_name_or_path:
489
+ config = AutoConfig.from_pretrained(model_args.model_name_or_path, cache_dir=model_args.cache_dir)
490
+ else:
491
+ config = CONFIG_MAPPING[model_args.model_type]()
492
+ logger.warning("You are instantiating a new config instance from scratch.")
493
+
494
+ if model_args.tokenizer_name:
495
+ tokenizer = AutoTokenizer.from_pretrained(
496
+ model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
497
+ )
498
+ elif model_args.model_name_or_path:
499
+ tokenizer = AutoTokenizer.from_pretrained(
500
+ model_args.model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
501
+ )
502
+ else:
503
+ raise ValueError(
504
+ "You are instantiating a new tokenizer from scratch. This is not supported by this script."
505
+ "You can do it from another script, save it, and load it from here, using --tokenizer_name."
506
+ )
507
+
508
+ # Otherwise, we tokenize every text, then concatenate them together before splitting them in smaller parts.
509
+ # We use `return_special_tokens_mask=True` because DataCollatorForLanguageModeling (see below) is more
510
+ # efficient when it receives the `special_tokens_mask`.
511
+ def tokenize_function(examples):
512
+ return tokenizer(
513
+ examples[data_args.text_column_name],
514
+ return_special_tokens_mask=True
515
+ )
516
+
517
+ tokenized_datasets = dataset.map(
518
+ tokenize_function,
519
+ batched=True,
520
+ )
521
+
522
+ shuffle_seed = training_args.seed
523
+ tokenized_datasets = tokenized_datasets.shuffle(buffer_size=data_args.shuffle_buffer_size, seed=shuffle_seed)
524
+
525
+ # Enable tensorboard only on the master node
526
+ has_tensorboard = is_tensorboard_available()
527
+ if has_tensorboard and jax.process_index() == 0:
528
+ try:
529
+ # Enable Weight&Biases
530
+ import wandb
531
+ wandb.init(
532
+ entity='wandb',
533
+ project='hf-flax-bertin-roberta-es',
534
+ sync_tensorboard=True,
535
+ )
536
+ wandb.config.update(training_args)
537
+ wandb.config.update(model_args)
538
+ wandb.config.update(data_args)
539
+ from flax.metrics.tensorboard import SummaryWriter
540
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
541
+ except ImportError as ie:
542
+ has_tensorboard = False
543
+ logger.warning(
544
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
545
+ )
546
+ else:
547
+ logger.warning(
548
+ "Unable to display metrics through TensorBoard because the package is not installed: "
549
+ "Please run pip install tensorboard to enable."
550
+ )
551
+
552
+ # Data collator
553
+ # This one will take care of randomly masking the tokens.
554
+ data_collator = FlaxDataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=data_args.mlm_probability)
555
+
556
+ # Initialize our training
557
+ rng = jax.random.PRNGKey(training_args.seed)
558
+ dropout_rngs = jax.random.split(rng, jax.local_device_count())
559
+
560
+ if model_args.model_name_or_path:
561
+ model = FlaxAutoModelForMaskedLM.from_pretrained(
562
+ model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
563
+ )
564
+ else:
565
+ model = FlaxAutoModelForMaskedLM.from_config(
566
+ config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
567
+ )
568
+
569
+ # Store some constant
570
+ num_epochs = int(training_args.num_train_epochs)
571
+ train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
572
+ eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
573
+
574
+ # define number steps per stream epoch
575
+ num_train_steps = data_args.num_train_steps
576
+
577
+ # Create learning rate schedule
578
+ warmup_fn = optax.linear_schedule(
579
+ init_value=0.0, end_value=training_args.learning_rate, transition_steps=training_args.warmup_steps
580
+ )
581
+ decay_fn = optax.linear_schedule(
582
+ init_value=training_args.learning_rate,
583
+ end_value=0,
584
+ transition_steps=num_train_steps - training_args.warmup_steps,
585
+ )
586
+ linear_decay_lr_schedule_fn = optax.join_schedules(
587
+ schedules=[warmup_fn, decay_fn], boundaries=[training_args.warmup_steps]
588
+ )
589
+
590
+ # We use Optax's "masking" functionality to not apply weight decay
591
+ # to bias and LayerNorm scale parameters. decay_mask_fn returns a
592
+ # mask boolean with the same structure as the parameters.
593
+ # The mask is True for parameters that should be decayed.
594
+ # Note that this mask is specifically adapted for FlaxBERT-like models.
595
+ # For other models, one should correct the layer norm parameter naming
596
+ # accordingly.
597
+ def decay_mask_fn(params):
598
+ flat_params = traverse_util.flatten_dict(params)
599
+ flat_mask = {path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale")) for path in flat_params}
600
+ return traverse_util.unflatten_dict(flat_mask)
601
+
602
+ # create adam optimizer
603
+ adamw = optax.adamw(
604
+ learning_rate=linear_decay_lr_schedule_fn,
605
+ b1=training_args.adam_beta1,
606
+ b2=training_args.adam_beta2,
607
+ eps=training_args.adam_epsilon,
608
+ weight_decay=training_args.weight_decay,
609
+ mask=decay_mask_fn,
610
+ )
611
+
612
+ # Setup train state
613
+ state = train_state.TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw)
614
+ saved_step = -1
615
+ if model_args.model_name_or_path and "checkpoint" in model_args.model_name_or_path:
616
+ params, opt_state, saved_step, args, data_collator = restore_checkpoint(model_args.model_name_or_path, state)
617
+ # Create learning rate schedule
618
+ warmup_fn = optax.linear_schedule(
619
+ init_value=0.0, end_value=args.learning_rate, transition_steps=args.warmup_steps
620
+ )
621
+ decay_fn = optax.linear_schedule(
622
+ init_value=args.learning_rate,
623
+ end_value=0,
624
+ transition_steps=data_args.num_train_steps - args.warmup_steps,
625
+ )
626
+ linear_decay_lr_schedule_fn = optax.join_schedules(
627
+ schedules=[warmup_fn, decay_fn], boundaries=[args.warmup_steps]
628
+ )
629
+ # create adam optimizer
630
+ adamw = optax.adamw(
631
+ learning_rate=linear_decay_lr_schedule_fn,
632
+ b1=training_args.adam_beta1,
633
+ b2=training_args.adam_beta2,
634
+ eps=training_args.adam_epsilon,
635
+ weight_decay=args.weight_decay,
636
+ mask=decay_mask_fn,
637
+ )
638
+ state = train_state.TrainState(
639
+ step=saved_step,
640
+ apply_fn=model.__call__,
641
+ params=params,
642
+ tx=adamw,
643
+ opt_state=opt_state,
644
+ )
645
+ # self.args = args
646
+ # data_collator = data_collator
647
+ # scheduler_fn = args.learning_rate
648
+ model.params = params
649
+
650
+
651
+ # Define gradient update step fn
652
+ def train_step(state, batch, dropout_rng):
653
+ dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
654
+
655
+ def loss_fn(params):
656
+ labels = batch.pop("labels")
657
+
658
+ logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
659
+
660
+ # compute loss, ignore padded input tokens
661
+ label_mask = jnp.where(labels > 0, 1.0, 0.0)
662
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
663
+
664
+ # take average
665
+ loss = loss.sum() / label_mask.sum()
666
+
667
+ return loss
668
+
669
+ grad_fn = jax.value_and_grad(loss_fn)
670
+ loss, grad = grad_fn(state.params)
671
+ grad = jax.lax.pmean(grad, "batch")
672
+ new_state = state.apply_gradients(grads=grad)
673
+
674
+ metrics = jax.lax.pmean(
675
+ {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}, axis_name="batch"
676
+ )
677
+
678
+ return new_state, metrics, new_dropout_rng
679
+
680
+ # Create parallel version of the train step
681
+ p_train_step = jax.pmap(train_step, "batch", donate_argnums=(0,))
682
+
683
+ # Define eval fn
684
+ def eval_step(params, batch):
685
+ labels = batch.pop("labels")
686
+
687
+ logits = model(**batch, params=params, train=False)[0]
688
+
689
+ # compute loss, ignore padded input tokens
690
+ label_mask = jnp.where(labels > 0, 1.0, 0.0)
691
+ loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])) * label_mask
692
+
693
+ # compute accuracy
694
+ accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels) * label_mask
695
+
696
+ # summarize metrics
697
+ metrics = {"loss": loss.sum(), "accuracy": accuracy.sum(), "normalizer": label_mask.sum()}
698
+ metrics = jax.lax.psum(metrics, axis_name="batch")
699
+
700
+ return metrics
701
+
702
+ p_eval_step = jax.pmap(eval_step, "batch", donate_argnums=(0,))
703
+
704
+ # Replicate the train state on each device
705
+ state = jax_utils.replicate(state)
706
+
707
+ train_time = 0
708
+ train_start = time.time()
709
+ train_metrics = []
710
+ eval_metrics = []
711
+
712
+ training_iter = iter(tokenized_datasets)
713
+
714
+ max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
715
+ eval_samples = advance_iter_and_group_samples(training_iter, data_args.num_eval_samples, max_seq_length)
716
+
717
+ last_desc = ""
718
+ steps = tqdm(range(num_train_steps), desc="Training...", position=0)
719
+ for step in range(num_train_steps):
720
+ if step < saved_step:
721
+ steps.update(1)
722
+ continue
723
+ # ======================== Training ================================
724
+ try:
725
+ samples = advance_iter_and_group_samples(training_iter, train_batch_size, max_seq_length)
726
+ except StopIteration:
727
+ # Once the end of the dataset stream is reached, the training iterator
728
+ # is reinitialized and reshuffled and a new eval dataset is randomely chosen.
729
+ shuffle_seed += 1
730
+ tokenized_datasets.set_epoch(shuffle_seed)
731
+
732
+ training_iter = iter(tokenized_datasets)
733
+
734
+ eval_dataset = advance_iter_and_group_samples(training_iter, data_args.num_eval_samples, max_seq_length)
735
+ samples = advance_iter_and_group_samples(training_iter, train_batch_size, max_seq_length)
736
+
737
+ # process input samples
738
+ model_inputs = data_collator(samples, pad_to_multiple_of=16)
739
+
740
+ # Model forward
741
+ model_inputs = shard(model_inputs.data)
742
+ state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
743
+
744
+ train_metrics.append(train_metric)
745
+
746
+ if step % training_args.logging_steps == 0 and step > 0:
747
+ steps.write(
748
+ f"Step... ({step} | Loss: {train_metric['loss'].mean()}, Learning Rate: {train_metric['learning_rate'].mean()})"
749
+ )
750
+ train_time += time.time() - train_start
751
+ if has_tensorboard and jax.process_index() == 0:
752
+ write_train_metric(summary_writer, train_metrics, train_time, step)
753
+ train_metrics = []
754
+
755
+ # ======================== Evaluating ==============================
756
+ if step % training_args.eval_steps == 0 and step > 0:
757
+ eval_samples_idx = jnp.arange(data_args.num_eval_samples)
758
+ eval_batch_idx = generate_batch_splits(eval_samples_idx, eval_batch_size)
759
+
760
+ for i, batch_idx in enumerate(tqdm(eval_batch_idx, desc="Evaluating ...", position=1)):
761
+ # process input samples
762
+ batch_eval_samples = {k: [v[idx] for idx in batch_idx] for k, v in eval_samples.items()}
763
+ model_inputs = data_collator(batch_eval_samples, pad_to_multiple_of=16)
764
+
765
+ # Model forward
766
+ model_inputs = shard(model_inputs.data)
767
+ metrics = p_eval_step(state.params, model_inputs)
768
+ eval_metrics.append(metrics)
769
+
770
+ # normalize eval metrics
771
+ eval_metrics = get_metrics(eval_metrics)
772
+ eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
773
+ eval_normalizer = eval_metrics.pop("normalizer")
774
+ eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
775
+
776
+ # Update progress bar
777
+ steps.desc = f"Step... ({step}/{num_train_steps} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
778
+ last_desc = steps.desc
779
+
780
+ if has_tensorboard and jax.process_index() == 0:
781
+ write_eval_metric(summary_writer, eval_metrics, step)
782
+ eval_metrics = []
783
+
784
+ # save checkpoint after eval_steps
785
+ if step % training_args.save_steps == 0 and step > 0 and jax.process_index() == 0:
786
+ logger.info(f"Saving checkpoint at {step} steps")
787
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
788
+ model.save_pretrained(
789
+ training_args.output_dir,
790
+ params=params,
791
+ push_to_hub=False,
792
+ )
793
+ save_checkpoint_files(state, data_collator, training_args, training_args.output_dir)
794
+ checkpoints_dir = Path(training_args.output_dir) / "checkpoints" / f"checkpoint-{step}"
795
+ checkpoints_dir.mkdir(parents=True, exist_ok=True)
796
+ model.save_pretrained(checkpoints_dir, params=params)
797
+ save_checkpoint_files(state, data_collator, training_args, checkpoints_dir)
798
+ rotate_checkpoints(
799
+ Path(training_args.output_dir) / "checkpoints",
800
+ max_checkpoints=training_args.save_total_limit
801
+ )
802
+ convert(training_args.output_dir, "./")
803
+ model.save_pretrained(
804
+ training_args.output_dir,
805
+ params=params,
806
+ push_to_hub=training_args.push_to_hub,
807
+ commit_message=last_desc,
808
+ )
809
+
810
+ # update tqdm bar
811
+ steps.update(1)
812
+
813
+ if jax.process_index() == 0:
814
+ logger.info(f"Saving checkpoint at {step} steps")
815
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
816
+ model.save_pretrained(
817
+ training_args.output_dir,
818
+ params=params,
819
+ push_to_hub=False,
820
+ )
821
+ save_checkpoint_files(state, data_collator, training_args, training_args.output_dir)
822
+ checkpoints_dir = Path(training_args.output_dir) / "checkpoints" / f"checkpoint-{step}"
823
+ checkpoints_dir.mkdir(parents=True, exist_ok=True)
824
+ model.save_pretrained(checkpoints_dir, params=params)
825
+ save_checkpoint_files(state, data_collator, training_args, checkpoints_dir)
826
+ convert(training_args.output_dir, "./")
827
+ model.save_pretrained(
828
+ training_args.output_dir,
829
+ params=params,
830
+ push_to_hub=training_args.push_to_hub,
831
+ commit_message=last_desc or "Saving model after training",
832
+ )
wandb/run-20210726_001233-17u6inbn/files/config.yaml ADDED
@@ -0,0 +1,324 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ wandb_version: 1
2
+
3
+ __cached__setup_devices:
4
+ desc: null
5
+ value: cpu
6
+ _n_gpu:
7
+ desc: null
8
+ value: 0
9
+ _wandb:
10
+ desc: null
11
+ value:
12
+ cli_version: 0.10.33
13
+ code_path: code/run_mlm_flax_stream.py
14
+ framework: huggingface
15
+ huggingface_version: 4.9.0.dev0
16
+ is_jupyter_run: false
17
+ is_kaggle_kernel: false
18
+ python_version: 3.8.10
19
+ t:
20
+ 1:
21
+ - 1
22
+ - 3
23
+ - 11
24
+ 4: 3.8.10
25
+ 5: 0.10.33
26
+ 6: 4.9.0.dev0
27
+ 8:
28
+ - 5
29
+ adafactor:
30
+ desc: null
31
+ value: false
32
+ adam_beta1:
33
+ desc: null
34
+ value: 0.9
35
+ adam_beta2:
36
+ desc: null
37
+ value: 0.98
38
+ adam_epsilon:
39
+ desc: null
40
+ value: 1.0e-06
41
+ cache_dir:
42
+ desc: null
43
+ value: null
44
+ config_name:
45
+ desc: null
46
+ value: ./configs/base
47
+ dataloader_drop_last:
48
+ desc: null
49
+ value: false
50
+ dataloader_num_workers:
51
+ desc: null
52
+ value: 0
53
+ dataloader_pin_memory:
54
+ desc: null
55
+ value: true
56
+ dataset_config_name:
57
+ desc: null
58
+ value: stepwise
59
+ dataset_name:
60
+ desc: null
61
+ value: bertin-project/mc4-es-sampled
62
+ ddp_find_unused_parameters:
63
+ desc: null
64
+ value: null
65
+ debug:
66
+ desc: null
67
+ value: []
68
+ deepspeed:
69
+ desc: null
70
+ value: null
71
+ disable_tqdm:
72
+ desc: null
73
+ value: false
74
+ do_eval:
75
+ desc: null
76
+ value: false
77
+ do_predict:
78
+ desc: null
79
+ value: false
80
+ do_train:
81
+ desc: null
82
+ value: false
83
+ dtype:
84
+ desc: null
85
+ value: bfloat16
86
+ eval_accumulation_steps:
87
+ desc: null
88
+ value: null
89
+ eval_steps:
90
+ desc: null
91
+ value: 1000
92
+ evaluation_strategy:
93
+ desc: null
94
+ value: IntervalStrategy.NO
95
+ fp16:
96
+ desc: null
97
+ value: false
98
+ fp16_backend:
99
+ desc: null
100
+ value: auto
101
+ fp16_full_eval:
102
+ desc: null
103
+ value: false
104
+ fp16_opt_level:
105
+ desc: null
106
+ value: O1
107
+ gradient_accumulation_steps:
108
+ desc: null
109
+ value: 1
110
+ greater_is_better:
111
+ desc: null
112
+ value: null
113
+ group_by_length:
114
+ desc: null
115
+ value: false
116
+ ignore_data_skip:
117
+ desc: null
118
+ value: false
119
+ label_names:
120
+ desc: null
121
+ value: null
122
+ label_smoothing_factor:
123
+ desc: null
124
+ value: 0.0
125
+ learning_rate:
126
+ desc: null
127
+ value: 0.0006
128
+ length_column_name:
129
+ desc: null
130
+ value: length
131
+ line_by_line:
132
+ desc: null
133
+ value: false
134
+ load_best_model_at_end:
135
+ desc: null
136
+ value: false
137
+ local_rank:
138
+ desc: null
139
+ value: -1
140
+ log_level:
141
+ desc: null
142
+ value: -1
143
+ log_level_replica:
144
+ desc: null
145
+ value: -1
146
+ log_on_each_node:
147
+ desc: null
148
+ value: true
149
+ logging_dir:
150
+ desc: null
151
+ value: ./outputs/runs/Jul26_00-12-25_tablespoon
152
+ logging_first_step:
153
+ desc: null
154
+ value: false
155
+ logging_steps:
156
+ desc: null
157
+ value: 500
158
+ logging_strategy:
159
+ desc: null
160
+ value: IntervalStrategy.STEPS
161
+ lr_scheduler_type:
162
+ desc: null
163
+ value: SchedulerType.LINEAR
164
+ max_grad_norm:
165
+ desc: null
166
+ value: 1.0
167
+ max_seq_length:
168
+ desc: null
169
+ value: 512
170
+ max_steps:
171
+ desc: null
172
+ value: -1
173
+ metric_for_best_model:
174
+ desc: null
175
+ value: null
176
+ mlm_probability:
177
+ desc: null
178
+ value: 0.15
179
+ model_name_or_path:
180
+ desc: null
181
+ value: bertin-project/bertin-base-stepwise
182
+ model_type:
183
+ desc: null
184
+ value: roberta
185
+ mp_parameters:
186
+ desc: null
187
+ value: ''
188
+ no_cuda:
189
+ desc: null
190
+ value: false
191
+ num_eval_samples:
192
+ desc: null
193
+ value: 50000
194
+ num_train_epochs:
195
+ desc: null
196
+ value: 3.0
197
+ num_train_steps:
198
+ desc: null
199
+ value: 50000
200
+ output_dir:
201
+ desc: null
202
+ value: ./outputs
203
+ overwrite_cache:
204
+ desc: null
205
+ value: false
206
+ overwrite_output_dir:
207
+ desc: null
208
+ value: true
209
+ pad_to_max_length:
210
+ desc: null
211
+ value: true
212
+ past_index:
213
+ desc: null
214
+ value: -1
215
+ per_device_eval_batch_size:
216
+ desc: null
217
+ value: 48
218
+ per_device_train_batch_size:
219
+ desc: null
220
+ value: 48
221
+ per_gpu_eval_batch_size:
222
+ desc: null
223
+ value: null
224
+ per_gpu_train_batch_size:
225
+ desc: null
226
+ value: null
227
+ prediction_loss_only:
228
+ desc: null
229
+ value: false
230
+ preprocessing_num_workers:
231
+ desc: null
232
+ value: null
233
+ push_to_hub:
234
+ desc: null
235
+ value: false
236
+ push_to_hub_model_id:
237
+ desc: null
238
+ value: outputs
239
+ push_to_hub_organization:
240
+ desc: null
241
+ value: null
242
+ push_to_hub_token:
243
+ desc: null
244
+ value: null
245
+ remove_unused_columns:
246
+ desc: null
247
+ value: true
248
+ report_to:
249
+ desc: null
250
+ value:
251
+ - tensorboard
252
+ - wandb
253
+ resume_from_checkpoint:
254
+ desc: null
255
+ value: null
256
+ run_name:
257
+ desc: null
258
+ value: ./outputs
259
+ save_on_each_node:
260
+ desc: null
261
+ value: false
262
+ save_steps:
263
+ desc: null
264
+ value: 1000
265
+ save_strategy:
266
+ desc: null
267
+ value: IntervalStrategy.STEPS
268
+ save_total_limit:
269
+ desc: null
270
+ value: 5
271
+ seed:
272
+ desc: null
273
+ value: 42
274
+ sharded_ddp:
275
+ desc: null
276
+ value: []
277
+ shuffle_buffer_size:
278
+ desc: null
279
+ value: 10000
280
+ skip_memory_metrics:
281
+ desc: null
282
+ value: true
283
+ text_column_name:
284
+ desc: null
285
+ value: text
286
+ tokenizer_name:
287
+ desc: null
288
+ value: ./configs/base
289
+ tpu_metrics_debug:
290
+ desc: null
291
+ value: false
292
+ tpu_num_cores:
293
+ desc: null
294
+ value: null
295
+ train_file:
296
+ desc: null
297
+ value: null
298
+ train_ref_file:
299
+ desc: null
300
+ value: null
301
+ use_fast_tokenizer:
302
+ desc: null
303
+ value: true
304
+ use_legacy_prediction_loop:
305
+ desc: null
306
+ value: false
307
+ validation_file:
308
+ desc: null
309
+ value: null
310
+ validation_ref_file:
311
+ desc: null
312
+ value: null
313
+ validation_split_percentage:
314
+ desc: null
315
+ value: 5
316
+ warmup_ratio:
317
+ desc: null
318
+ value: 0.0
319
+ warmup_steps:
320
+ desc: null
321
+ value: 500
322
+ weight_decay:
323
+ desc: null
324
+ value: 0.01
wandb/run-20210726_001233-17u6inbn/files/events.out.tfevents.1627258355.tablespoon.3000110.3.v2 ADDED
@@ -0,0 +1 @@
 
 
1
+ /var/hf/experiment-base-exp-512seq-stepwise/outputs/events.out.tfevents.1627258355.tablespoon.3000110.3.v2
wandb/run-20210726_001233-17u6inbn/files/output.log ADDED
@@ -0,0 +1,823 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2021-07-26 00:12:35.575266: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory
2
+ 2021-07-26 00:12:35.575304: W tensorflow/stream_executor/cuda/cuda_driver.cc:326] failed call to cuInit: UNKNOWN ERROR (303)
3
+ [00:12:36] - INFO - filelock - Lock 139656499698272 acquired on /home/versae/.cache/huggingface/transformers/27b7e968d2908b27f8c1df265c2dc08aef61be0f25bdc735df4df552829968fd.04a8293889c44bb7f31a5ee6212b8aa0b690121444e9c7ce1616fbe2a461ebba.lock
4
+
5
+
6
+
7
+ Downloading: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 250M/250M [00:06<00:00, 35.8MB/s]
8
+ [00:12:43] - INFO - filelock - Lock 139656499698272 released on /home/versae/.cache/huggingface/transformers/27b7e968d2908b27f8c1df265c2dc08aef61be0f25bdc735df4df552829968fd.04a8293889c44bb7f31a5ee6212b8aa0b690121444e9c7ce1616fbe2a461ebba.lock
9
+ /var/hf/venv/lib/python3.8/site-packages/jax/lib/xla_bridge.py:386: UserWarning: jax.host_count has been renamed to jax.process_count. This alias will eventually be removed; please update your code.
10
+ warnings.warn(
11
+ /var/hf/venv/lib/python3.8/site-packages/jax/lib/xla_bridge.py:373: UserWarning: jax.host_id has been renamed to jax.process_index. This alias will eventually be removed; please update your code.
12
+ warnings.warn(
13
+
14
+
15
+
16
+
17
+
18
+
19
+
20
+
21
+
22
+
23
+
24
+
25
+
26
+
27
+
28
+
29
+
30
+
31
+
32
+
33
+
34
+
35
+
36
+
37
+
38
+
39
+
40
+
41
+
42
+
43
+
44
+
45
+
46
+
47
+
48
+
49
+
50
+
51
+
52
+
53
+
54
+
55
+
56
+
57
+
58
+
59
+
60
+
61
+
62
+
63
+
64
+
65
+
66
+
67
+
68
+
69
+
70
+
71
+
72
+
73
+
74
+
75
+
76
+
77
+
78
+
79
+
80
+
81
+
82
+
83
+
84
+
85
+
86
+
87
+
88
+
89
+
90
+
91
+
92
+
93
+
94
+
95
+
96
+
97
+
98
+
99
+
100
+
101
+
102
+
103
+
104
+
105
+
106
+
107
+
108
+
109
+
110
+
111
+
112
+
113
+
114
+
115
+
116
+
117
+
118
+
119
+
120
+
121
+
122
+
123
+
124
+
125
+
126
+
127
+
128
+
129
+
130
+
131
+
132
+
133
+
134
+
135
+
136
+
137
+
138
+
139
+
140
+
141
+
142
+
143
+
144
+
145
+
146
+
147
+
148
+
149
+
150
+
151
+
152
+
153
+
154
+
155
+
156
+
157
+
158
+
159
+
160
+
161
+
162
+
163
+
164
+
165
+
166
+
167
+
168
+
169
+
170
+
171
+
172
+
173
+
174
+
175
+
176
+
177
+
178
+
179
+
180
+
181
+
182
+
183
+
184
+
185
+
186
+
187
+
188
+
189
+
190
+
191
+
192
+
193
+
194
+
195
+
196
+
197
+
198
+
199
+
200
+
201
+
202
+
203
+
204
+
205
+
206
+
207
+
208
+
209
+
210
+
211
+
212
+
213
+
214
+
215
+
216
+
217
+
218
+
219
+
220
+
221
+
222
+
223
+
224
+
225
+
226
+
227
+
228
+
229
+
230
+
231
+
232
+
233
+
234
+
235
+
236
+
237
+
238
+
239
+
240
+
241
+
242
+
243
+
244
+
245
+
246
+
247
+
248
+
249
+
250
+
251
+
252
+
253
+
254
+
255
+
256
+
257
+
258
+
259
+
260
+
261
+
262
+
263
+
264
+
265
+
266
+
267
+
268
+
269
+
270
+
271
+
272
+
273
+
274
+
275
+
276
+
277
+
278
+
279
+
280
+
281
+
282
+
283
+
284
+
285
+
286
+
287
+
288
+
289
+
290
+
291
+
292
+
293
+
294
+
295
+
296
+
297
+
298
+
299
+
300
+
301
+
302
+
303
+
304
+
305
+
306
+
307
+
308
+
309
+
310
+
311
+
312
+
313
+
314
+
315
+
316
+
317
+
318
+
319
+
320
+
321
+
322
+
323
+
324
+
325
+
326
+
327
+
328
+
329
+
330
+
331
+
332
+
333
+
334
+
335
+
336
+
337
+
338
+
339
+
340
+
341
+
342
+
343
+
344
+
345
+
346
+
347
+
348
+
349
+
350
+
351
+
352
+
353
+
354
+
355
+
356
+
357
+
358
+
359
+
360
+
361
+
362
+
363
+
364
+
365
+
366
+
367
+
368
+
369
+
370
+
371
+
372
+
373
+
374
+
375
+
376
+
377
+
378
+
379
+
380
+
381
+
382
+
383
+
384
+
385
+
386
+
387
+
388
+
389
+
390
+
391
+
392
+
393
+
394
+
395
+
396
+
397
+
398
+
399
+
400
+
401
+
402
+
403
+
404
+
405
+
406
+
407
+
408
+
409
+
410
+
411
+
412
+
413
+
414
+
415
+
416
+
417
+
418
+
419
+
420
+
421
+
422
+
423
+
424
+
425
+
426
+
427
+
428
+
429
+
430
+
431
+
432
+
433
+
434
+
435
+
436
+
437
+
438
+
439
+
440
+
441
+
442
+
443
+
444
+
445
+
446
+
447
+
448
+
449
+
450
+
451
+
452
+
453
+
454
+
455
+
456
+
457
+
458
+
459
+
460
+
461
+
462
+
463
+
464
+
465
+
466
+
467
+
468
+
469
+
470
+
471
+
472
+
473
+
474
+
475
+
476
+
477
+
478
+
479
+
480
+
481
+
482
+
483
+
484
+
485
+
486
+
487
+
488
+
489
+
490
+
491
+
492
+
493
+
494
+
495
+
496
+
497
+
498
+
499
+
500
+
501
+
502
+
503
+
504
+
505
+
506
+
507
+
508
+
509
+
510
+
511
+
512
+
513
+
514
+
515
+
516
+
517
+
518
+
519
+
520
+
521
+
522
+
523
+
524
+
525
+
526
+
527
+
528
+
529
+
530
+
531
+
532
+
533
+
534
+
535
+
536
+
537
+
538
+
539
+
540
+
541
+
542
+
543
+
544
+
545
+
546
+
547
+
548
+
549
+
550
+
551
+
552
+
553
+
554
+
555
+
556
+
557
+
558
+
559
+
560
+
561
+
562
+
563
+
564
+
565
+
566
+
567
+
568
+
569
+
570
+
571
+
572
+
573
+
574
+
575
+
576
+
577
+
578
+
579
+
580
+
581
+
582
+
583
+
584
+
585
+
586
+
587
+
588
+
589
+
590
+
591
+
592
+
593
+
594
+
595
+
596
+
597
+
598
+
599
+
600
+
601
+
602
+
603
+
604
+
605
+
606
+
607
+
608
+
609
+
610
+
611
+
612
+
613
+
614
+
615
+
616
+
617
+
618
+
619
+
620
+
621
+
622
+
623
+
624
+
625
+
626
+ Training...: 2%|█▊ | 1000/50000 [22:19<17:30:45, 1.29s/it]
627
+ Step... (500 | Loss: 1.8920137882232666, Learning Rate: 0.0006000000284984708)
628
+ Training...: 2%|█▊ | 1000/50000 [22:21<17:30:45, 1.29s/it]
629
+
630
+
631
+
632
+
633
+
634
+
635
+
636
+
637
+
638
+
639
+
640
+
641
+
642
+ [02:30:54] - INFO - __main__ - Saving checkpoint at 1000 steps██████████████████████████████████████████████████████| 130/130 [00:31<00:00, 4.59it/s]
643
+ /var/hf/transformers-orig/src/transformers/modeling_flax_pytorch_utils.py:201: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:180.)
644
+ pt_model_dict[flax_key] = torch.from_numpy(flax_tensor)
645
+ All Flax model weights were used when initializing RobertaForMaskedLM.
646
+ Some weights of RobertaForMaskedLM were not initialized from the Flax model and are newly initialized: ['lm_head.decoder.weight', 'roberta.embeddings.position_ids', 'lm_head.decoder.bias']
647
+ You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
648
+
649
+
650
+
651
+
652
+
653
+
654
+
655
+
656
+
657
+
658
+
659
+
660
+
661
+
662
+
663
+
664
+
665
+
666
+
667
+
668
+
669
+
670
+
671
+
672
+
673
+
674
+
675
+
676
+
677
+
678
+
679
+
680
+
681
+
682
+
683
+
684
+
685
+
686
+
687
+
688
+
689
+
690
+
691
+
692
+
693
+
694
+
695
+
696
+
697
+
698
+
699
+
700
+
701
+
702
+
703
+
704
+
705
+
706
+
707
+
708
+
709
+
710
+
711
+
712
+
713
+
714
+
715
+
716
+
717
+
718
+
719
+
720
+
721
+
722
+
723
+
724
+
725
+
726
+
727
+
728
+
729
+
730
+
731
+
732
+
733
+
734
+
735
+
736
+
737
+
738
+
739
+
740
+
741
+
742
+
743
+
744
+
745
+
746
+
747
+
748
+
749
+
750
+
751
+
752
+
753
+
754
+
755
+
756
+
757
+
758
+
759
+
760
+
761
+
762
+
763
+
764
+
765
+
766
+
767
+
768
+
769
+
770
+
771
+
772
+
773
+
774
+
775
+
776
+
777
+
778
+
779
+
780
+
781
+
782
+
783
+
784
+
785
+
786
+
787
+
788
+
789
+
790
+
791
+
792
+
793
+
794
+
795
+
796
+
797
+
798
+
799
+
800
+
801
+
802
+
803
+
804
+
805
+
806
+
807
+
808
+
809
+
810
+
811
+
812
+
813
+
814
+
815
+
816
+
817
+
818
+
819
+
820
+
821
+
822
+
823
+
wandb/run-20210726_001233-17u6inbn/files/requirements.txt ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==0.13.0
2
+ aiohttp==3.7.4.post0
3
+ astunparse==1.6.3
4
+ async-timeout==3.0.1
5
+ attrs==21.2.0
6
+ backcall==0.2.0
7
+ cachetools==4.2.2
8
+ certifi==2021.5.30
9
+ chardet==4.0.0
10
+ chex==0.0.8
11
+ click==8.0.1
12
+ configparser==5.0.2
13
+ cycler==0.10.0
14
+ datasets==1.9.1.dev0
15
+ decorator==5.0.9
16
+ dill==0.3.4
17
+ dm-tree==0.1.6
18
+ docker-pycreds==0.4.0
19
+ filelock==3.0.12
20
+ flatbuffers==1.12
21
+ flax==0.3.4
22
+ fsspec==2021.6.1
23
+ gast==0.4.0
24
+ gitdb==4.0.7
25
+ gitpython==3.1.18
26
+ google-auth-oauthlib==0.4.4
27
+ google-auth==1.32.1
28
+ google-pasta==0.2.0
29
+ grpcio==1.38.1
30
+ h5py==3.1.0
31
+ huggingface-hub==0.0.12
32
+ idna==2.10
33
+ ipython-genutils==0.2.0
34
+ ipython==7.25.0
35
+ jax==0.2.17
36
+ jaxlib==0.1.68
37
+ jedi==0.18.0
38
+ joblib==1.0.1
39
+ kenlm==0.0.0
40
+ keras-nightly==2.5.0.dev2021032900
41
+ keras-preprocessing==1.1.2
42
+ kiwisolver==1.3.1
43
+ libtpu-nightly==0.1.dev20210615
44
+ markdown==3.3.4
45
+ matplotlib-inline==0.1.2
46
+ matplotlib==3.4.2
47
+ msgpack==1.0.2
48
+ multidict==5.1.0
49
+ multiprocess==0.70.12.2
50
+ numpy==1.21.0
51
+ oauthlib==3.1.1
52
+ opt-einsum==3.3.0
53
+ optax==0.0.9
54
+ packaging==21.0
55
+ pandas==1.3.0
56
+ parso==0.8.2
57
+ pathtools==0.1.2
58
+ pexpect==4.8.0
59
+ pickleshare==0.7.5
60
+ pillow==8.3.1
61
+ pip==20.0.2
62
+ pkg-resources==0.0.0
63
+ promise==2.3
64
+ prompt-toolkit==3.0.19
65
+ protobuf==3.17.3
66
+ psutil==5.8.0
67
+ ptyprocess==0.7.0
68
+ pyarrow==4.0.1
69
+ pyasn1-modules==0.2.8
70
+ pyasn1==0.4.8
71
+ pygments==2.9.0
72
+ pyparsing==2.4.7
73
+ python-dateutil==2.8.1
74
+ pytz==2021.1
75
+ pyyaml==5.4.1
76
+ regex==2021.7.6
77
+ requests-oauthlib==1.3.0
78
+ requests==2.25.1
79
+ rsa==4.7.2
80
+ sacremoses==0.0.45
81
+ scipy==1.7.0
82
+ sentry-sdk==1.3.0
83
+ setuptools==44.0.0
84
+ shortuuid==1.0.1
85
+ six==1.15.0
86
+ smmap==4.0.0
87
+ subprocess32==3.5.4
88
+ tensorboard-data-server==0.6.1
89
+ tensorboard-plugin-wit==1.8.0
90
+ tensorboard==2.5.0
91
+ tensorflow-estimator==2.5.0
92
+ tensorflow==2.5.0
93
+ termcolor==1.1.0
94
+ tokenizers==0.10.3
95
+ toolz==0.11.1
96
+ torch==1.9.0
97
+ tqdm==4.61.2
98
+ traitlets==5.0.5
99
+ transformers==4.9.0.dev0
100
+ typing-extensions==3.7.4.3
101
+ urllib3==1.26.6
102
+ wandb==0.10.33
103
+ wcwidth==0.2.5
104
+ werkzeug==2.0.1
105
+ wheel==0.36.2
106
+ wrapt==1.12.1
107
+ xxhash==2.0.2
108
+ yarl==1.6.3
wandb/run-20210726_001233-17u6inbn/files/wandb-metadata.json ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "os": "Linux-5.4.0-1043-gcp-x86_64-with-glibc2.29",
3
+ "python": "3.8.10",
4
+ "heartbeatAt": "2021-07-26T00:12:35.406409",
5
+ "startedAt": "2021-07-26T00:12:33.305928",
6
+ "docker": null,
7
+ "cpu_count": 96,
8
+ "cuda": null,
9
+ "args": [
10
+ "--model_name_or_path=bertin-project/bertin-base-stepwise",
11
+ "--output_dir=./outputs",
12
+ "--model_type=roberta",
13
+ "--config_name=./configs/base",
14
+ "--tokenizer_name=./configs/base",
15
+ "--dataset_name=bertin-project/mc4-es-sampled",
16
+ "--dataset_config_name=stepwise",
17
+ "--max_seq_length=512",
18
+ "--pad_to_max_length",
19
+ "--per_device_train_batch_size=48",
20
+ "--per_device_eval_batch_size=48",
21
+ "--adam_beta1=0.9",
22
+ "--adam_beta2=0.98",
23
+ "--adam_epsilon=1e-6",
24
+ "--learning_rate=6e-4",
25
+ "--weight_decay=0.01",
26
+ "--save_steps=1000",
27
+ "--save_total_limit=5",
28
+ "--warmup_steps=500",
29
+ "--overwrite_output_dir",
30
+ "--num_train_steps=50000",
31
+ "--eval_steps=1000",
32
+ "--dtype=bfloat16",
33
+ "--logging_steps=500"
34
+ ],
35
+ "state": "running",
36
+ "program": "./run_mlm_flax_stream.py",
37
+ "codePath": "run_mlm_flax_stream.py",
38
+ "git": {
39
+ "remote": "https://huggingface.co/bertin-project/bertin-base-stepwise-exp-512seqlen",
40
+ "commit": "529e26d977dcd80df13f8ff4dc528756c974c3b3"
41
+ },
42
+ "email": "[email protected]",
43
+ "root": "/var/hf/experiment-base-exp-512seq-stepwise",
44
+ "host": "tablespoon",
45
+ "username": "versae",
46
+ "executable": "/var/hf/venv/bin/python"
47
+ }
wandb/run-20210726_001233-17u6inbn/files/wandb-summary.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"global_step": 1000, "_timestamp": 1627266616.074129, "train_time": 15844.4716796875, "train_learning_rate": 0.0005939393886364996, "_step": 1994, "train_loss": 1.8408620357513428}
wandb/run-20210726_001233-17u6inbn/logs/debug-internal.log ADDED
The diff for this file is too large to render. See raw diff
 
wandb/run-20210726_001233-17u6inbn/logs/debug.log ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 2021-07-26 00:12:33,307 INFO MainThread:3000110 [wandb_setup.py:_flush():69] setting env: {}
2
+ 2021-07-26 00:12:33,307 INFO MainThread:3000110 [wandb_setup.py:_flush():69] setting login settings: {}
3
+ 2021-07-26 00:12:33,307 INFO MainThread:3000110 [wandb_init.py:_log_setup():337] Logging user logs to /var/hf/experiment-base-exp-512seq-stepwise/wandb/run-20210726_001233-17u6inbn/logs/debug.log
4
+ 2021-07-26 00:12:33,307 INFO MainThread:3000110 [wandb_init.py:_log_setup():338] Logging internal logs to /var/hf/experiment-base-exp-512seq-stepwise/wandb/run-20210726_001233-17u6inbn/logs/debug-internal.log
5
+ 2021-07-26 00:12:33,307 INFO MainThread:3000110 [wandb_init.py:init():370] calling init triggers
6
+ 2021-07-26 00:12:33,307 INFO MainThread:3000110 [wandb_init.py:init():375] wandb.init called with sweep_config: {}
7
+ config: {}
8
+ 2021-07-26 00:12:33,307 INFO MainThread:3000110 [wandb_init.py:init():419] starting backend
9
+ 2021-07-26 00:12:33,307 INFO MainThread:3000110 [backend.py:_multiprocessing_setup():70] multiprocessing start_methods=fork,spawn,forkserver, using: spawn
10
+ 2021-07-26 00:12:33,351 INFO MainThread:3000110 [backend.py:ensure_launched():135] starting backend process...
11
+ 2021-07-26 00:12:33,394 INFO MainThread:3000110 [backend.py:ensure_launched():139] started backend process with pid: 3001431
12
+ 2021-07-26 00:12:33,396 INFO MainThread:3000110 [wandb_init.py:init():424] backend started and connected
13
+ 2021-07-26 00:12:33,399 INFO MainThread:3000110 [wandb_init.py:init():472] updated telemetry
14
+ 2021-07-26 00:12:33,400 INFO MainThread:3000110 [wandb_init.py:init():491] communicating current version
15
+ 2021-07-26 00:12:34,050 INFO MainThread:3000110 [wandb_init.py:init():496] got version response upgrade_message: "wandb version 0.11.0 is available! To upgrade, please run:\n $ pip install wandb --upgrade"
16
+
17
+ 2021-07-26 00:12:34,050 INFO MainThread:3000110 [wandb_init.py:init():504] communicating run to backend with 30 second timeout
18
+ 2021-07-26 00:12:34,261 INFO MainThread:3000110 [wandb_init.py:init():529] starting run threads in backend
19
+ 2021-07-26 00:12:35,502 INFO MainThread:3000110 [wandb_run.py:_console_start():1623] atexit reg
20
+ 2021-07-26 00:12:35,502 INFO MainThread:3000110 [wandb_run.py:_redirect():1497] redirect: SettingsConsole.REDIRECT
21
+ 2021-07-26 00:12:35,503 INFO MainThread:3000110 [wandb_run.py:_redirect():1502] Redirecting console.
22
+ 2021-07-26 00:12:35,505 INFO MainThread:3000110 [wandb_run.py:_redirect():1558] Redirects installed.
23
+ 2021-07-26 00:12:35,505 INFO MainThread:3000110 [wandb_init.py:init():554] run started, returning control to user process
24
+ 2021-07-26 00:12:35,506 INFO MainThread:3000110 [wandb_run.py:_config_callback():872] config_cb None None {'output_dir': './outputs', 'overwrite_output_dir': True, 'do_train': False, 'do_eval': False, 'do_predict': False, 'evaluation_strategy': 'IntervalStrategy.NO', 'prediction_loss_only': False, 'per_device_train_batch_size': 48, 'per_device_eval_batch_size': 48, 'per_gpu_train_batch_size': None, 'per_gpu_eval_batch_size': None, 'gradient_accumulation_steps': 1, 'eval_accumulation_steps': None, 'learning_rate': 0.0006, 'weight_decay': 0.01, 'adam_beta1': 0.9, 'adam_beta2': 0.98, 'adam_epsilon': 1e-06, 'max_grad_norm': 1.0, 'num_train_epochs': 3.0, 'max_steps': -1, 'lr_scheduler_type': 'SchedulerType.LINEAR', 'warmup_ratio': 0.0, 'warmup_steps': 500, 'log_level': -1, 'log_level_replica': -1, 'log_on_each_node': True, 'logging_dir': './outputs/runs/Jul26_00-12-25_tablespoon', 'logging_strategy': 'IntervalStrategy.STEPS', 'logging_first_step': False, 'logging_steps': 500, 'save_strategy': 'IntervalStrategy.STEPS', 'save_steps': 1000, 'save_total_limit': 5, 'save_on_each_node': False, 'no_cuda': False, 'seed': 42, 'fp16': False, 'fp16_opt_level': 'O1', 'fp16_backend': 'auto', 'fp16_full_eval': False, 'local_rank': -1, 'tpu_num_cores': None, 'tpu_metrics_debug': False, 'debug': [], 'dataloader_drop_last': False, 'eval_steps': 1000, 'dataloader_num_workers': 0, 'past_index': -1, 'run_name': './outputs', 'disable_tqdm': False, 'remove_unused_columns': True, 'label_names': None, 'load_best_model_at_end': False, 'metric_for_best_model': None, 'greater_is_better': None, 'ignore_data_skip': False, 'sharded_ddp': [], 'deepspeed': None, 'label_smoothing_factor': 0.0, 'adafactor': False, 'group_by_length': False, 'length_column_name': 'length', 'report_to': ['tensorboard', 'wandb'], 'ddp_find_unused_parameters': None, 'dataloader_pin_memory': True, 'skip_memory_metrics': True, 'use_legacy_prediction_loop': False, 'push_to_hub': False, 'resume_from_checkpoint': None, 'push_to_hub_model_id': 'outputs', 'push_to_hub_organization': None, 'push_to_hub_token': None, 'mp_parameters': '', '_n_gpu': 0, '__cached__setup_devices': 'cpu'}
25
+ 2021-07-26 00:12:35,507 INFO MainThread:3000110 [wandb_run.py:_config_callback():872] config_cb None None {'model_name_or_path': 'bertin-project/bertin-base-stepwise', 'model_type': 'roberta', 'config_name': './configs/base', 'tokenizer_name': './configs/base', 'cache_dir': None, 'use_fast_tokenizer': True, 'dtype': 'bfloat16'}
26
+ 2021-07-26 00:12:35,508 INFO MainThread:3000110 [wandb_run.py:_config_callback():872] config_cb None None {'dataset_name': 'bertin-project/mc4-es-sampled', 'dataset_config_name': 'stepwise', 'train_file': None, 'validation_file': None, 'train_ref_file': None, 'validation_ref_file': None, 'overwrite_cache': False, 'validation_split_percentage': 5, 'max_seq_length': 512, 'preprocessing_num_workers': None, 'mlm_probability': 0.15, 'pad_to_max_length': True, 'line_by_line': False, 'text_column_name': 'text', 'shuffle_buffer_size': 10000, 'num_train_steps': 50000, 'num_eval_samples': 50000}
27
+ 2021-07-26 00:12:35,587 INFO MainThread:3000110 [wandb_run.py:_tensorboard_callback():943] tensorboard callback: outputs, None
wandb/run-20210726_001233-17u6inbn/run-17u6inbn.wandb ADDED
Binary file (463 kB). View file