Adding checkpointing, wandb, and new mlm script

Files changed:
- README.md +29 -1
- perplexity.py +22 -0
- run_mlm_flax.py +60 -31
- tokens.py +3 -1
README.md
CHANGED

````diff
@@ -1,19 +1,47 @@
+---
+language: no
+license: CC-BY 4.0
+tags:
+- spanish
+- roberta
+pipeline_tag: fill-mask
+widget:
+- text: "Lo hizo en un abrir y cerar de <mask>."
+---
+
 # BERTIN
+
 BERTIN is a series of BERT-based models for Spanish. This one is a RoBERTa-large model trained from scratch on the Spanish portion of mC4 using [Flax](https://github.com/google/flax), including training scripts.
 
 This is part of the
 [Flax/Jax Community Week](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/7104), organised by [HuggingFace](https://huggingface.co/) and TPU usage sponsored by Google.
 
+## Spanish mC4
+
+The Spanish portion of mC4 contains about 416 million records and 235 billion words.
+
+```bash
+$ zcat c4/multilingual/c4-es*.tfrecord*.json.gz | wc -l
+416057992
+```
+
+```bash
+$ zcat c4/multilingual/c4-es*.tfrecord-*.json.gz | jq -r '.text | split(" ") | length' | paste -s -d+ - | bc
+235303687795
+```
+
 ## Team members
+
 - Javier de la Rosa (versae)
 - Manu Romero (mrm8488)
 - María Grandury (mariagrandury)
 - Ari Polakov (aripo99)
 - Pablogps
 - daveni
-- Sri Lakshmi
+- Sri Lakshmi
 
 ## Useful links
+
 - [Community Week timeline](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/7104#summary-timeline-calendar-6)
 - [Community Week README](https://github.com/huggingface/transformers/blob/master/examples/research_projects/jax-projects/README.md)
 - [Community Week thread](https://discuss.huggingface.co/t/bertin-pretrain-roberta-large-from-scratch-in-spanish/7125)
````
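The record counts above can also be reproduced without touching the raw `.json.gz` shards by streaming the dataset through the `datasets` library, which is the same access path `perplexity.py` below uses. A minimal sketch, not part of the commit; iterating all ~416 million documents this way is slow, so it only peeks at a few records:

```python
# Sketch only (not in the commit): stream Spanish mC4 with the datasets library.
from itertools import islice

from datasets import load_dataset

mc4_es = load_dataset("mc4", "es", streaming=True)

# Peek at a few documents; a full pass over ~416M records would take a long time.
for sample in islice(mc4_es["train"], 3):
    print(len(sample["text"].split()), "words")
```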
perplexity.py
ADDED

@@ -0,0 +1,22 @@

```python
#!/usr/bin/env python
import kenlm
from datasets import load_dataset
from tqdm import tqdm


def pp(log_score, length):
    return 10.0 ** (-log_score / length)

# http://dl.fbaipublicfiles.com/cc_net/lm/es.arpa.bin
model = kenlm.Model("es.arpa.bin")
mc4 = load_dataset("mc4", "es", streaming=True)
with open("mc4-es-perplexity.txt", "w") as f:
    for sample in tqdm(mc4["train"].shuffle(buffer_size=100_000), total=416057992):
        lines = sample["text"].split("\n")
        doc_log_score, doc_length = 0, 0
        for line in lines:
            log_score = model.score(line)
            length = len(line.split()) + 1
            doc_log_score += log_score
            doc_length += length
        f.write(f"{pp(doc_log_score, doc_length)}\n")
```
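The script scores every streamed mC4 document with the Spanish KenLM model from cc_net and writes one perplexity value per line: `model.score` returns a base-10 log probability, so `pp` converts the accumulated document log score into a per-word perplexity, `10 ** (-log_score / length)`. Below is a minimal sketch of how the resulting file could be summarized to pick a filtering threshold; it is not part of the commit, and `numpy` plus the percentile choice are assumptions:

```python
# Sketch only (not in the commit): summarize mc4-es-perplexity.txt
# (one perplexity value per document) to pick an illustrative cut-off.
import numpy as np

perplexities = np.loadtxt("mc4-es-perplexity.txt")
print("median perplexity:", np.median(perplexities))
print("75th percentile:", np.percentile(perplexities, 75))  # example threshold
```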
run_mlm_flax.py
CHANGED

```diff
@@ -56,22 +56,6 @@ from transformers import (
 )
 
 
-# Cache the result
-has_tensorboard = is_tensorboard_available()
-if has_tensorboard:
-    try:
-        from flax.metrics.tensorboard import SummaryWriter
-    except ImportError as ie:
-        has_tensorboard = False
-        print(f"Unable to display metrics through TensorBoard because some package are not installed: {ie}")
-
-else:
-    print(
-        "Unable to display metrics through TensorBoard because the package is not installed: "
-        "Please run pip install tensorboard to enable."
-    )
-
-
 MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
 MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
 
@@ -126,6 +110,9 @@ class DataTrainingArguments:
     dataset_config_name: Optional[str] = field(
         default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
     )
+    dataset_streaming: bool = field(
+        default=False, metadata={"help": "Whether dataset_name should be retrieved using streaming if available."}
+    )
     train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
     validation_file: Optional[str] = field(
         default=None,
@@ -269,7 +256,7 @@ def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndarray:
     return batch_idx
 
 
-def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
+def write_train_metric(summary_writer, train_metrics, train_time, step):
     summary_writer.scalar("train_time", train_time, step)
 
     train_metrics = get_metrics(train_metrics)
@@ -278,6 +265,8 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
         for i, val in enumerate(vals):
             summary_writer.scalar(tag, val, step - len(vals) + i + 1)
 
+
+def write_eval_metric(summary_writer, eval_metrics, step):
     for metric_name, value in eval_metrics.items():
         summary_writer.scalar(f"eval_{metric_name}", value, step)
 
@@ -315,10 +304,6 @@ if __name__ == "__main__":
 
     # Log on each process the small summary:
     logger = logging.getLogger(__name__)
-    #logger.warning(
-    #    f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
-    #    + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
-    #)
 
     # Set the verbosity to info of the Transformers logger (on main process only):
     logger.info(f"Training/evaluation parameters {training_args}")
@@ -337,7 +322,7 @@ if __name__ == "__main__":
     # download the dataset.
     if data_args.dataset_name is not None:
         # Downloading and loading a dataset from the hub.
-        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
+        datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, streaming=data_args.dataset_streaming)
 
         if "validation" not in datasets.keys():
             datasets["validation"] = load_dataset(
@@ -345,12 +330,14 @@ if __name__ == "__main__":
                 data_args.dataset_config_name,
                 split=f"train[:{data_args.validation_split_percentage}%]",
                 cache_dir=model_args.cache_dir,
+                streaming=data_args.dataset_streaming,
             )
             datasets["train"] = load_dataset(
                 data_args.dataset_name,
                 data_args.dataset_config_name,
                 split=f"train[{data_args.validation_split_percentage}%:]",
                 cache_dir=model_args.cache_dir,
+                streaming=data_args.dataset_streaming,
             )
     else:
         data_files = {}
@@ -469,10 +456,32 @@ if __name__ == "__main__":
         num_proc=data_args.preprocessing_num_workers,
         load_from_cache_file=not data_args.overwrite_cache,
     )
-
     # Enable tensorboard only on the master node
+    has_tensorboard = is_tensorboard_available()
     if has_tensorboard and jax.process_index() == 0:
-
+        try:
+            from flax.metrics.tensorboard import SummaryWriter
+            summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
+            # Enable Weight&Biases
+            import wandb
+            wandb.init(
+                entity='wandb',
+                project='hf-flax-bertin-roberta-es',
+                sync_tensorboard=True,
+            )
+            wandb.config.update(training_args)
+            wandb.config.update(model_args)
+            wandb.config.update(data_args)
+        except ImportError as ie:
+            has_tensorboard = False
+            logger.warning(
+                f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
+            )
+    else:
+        logger.warning(
+            "Unable to display metrics through TensorBoard because the package is not installed: "
+            "Please run pip install tensorboard to enable."
+        )
 
     # Data collator
     # This one will take care of randomly masking the tokens.
@@ -521,7 +530,7 @@ if __name__ == "__main__":
         learning_rate=linear_decay_lr_schedule_fn,
         b1=training_args.adam_beta1,
         b2=training_args.adam_beta2,
-        eps=
+        eps=training_args.adam_epsilon,
         weight_decay=training_args.weight_decay,
         mask=decay_mask_fn,
     )
@@ -601,7 +610,7 @@ if __name__ == "__main__":
         train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
 
         # Gather the indexes for creating the batch and do a training step
-        for
+        for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
             samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
             model_inputs = data_collator(samples, pad_to_multiple_of=16)
 
@@ -610,11 +619,31 @@
             state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
             train_metrics.append(train_metric)
 
-
+            cur_step = epoch * (num_train_samples // train_batch_size) + step
 
-
-
-
+            if cur_step % training_args.logging_steps == 0 and cur_step > 0:
+                # Save metrics
+                train_metric = jax_utils.unreplicate(train_metric)
+                train_time += time.time() - train_start
+                if has_tensorboard and jax.process_index() == 0:
+                    write_train_metric(summary_writer, train_metrics, train_time, cur_step)
+
+                epochs.write(
+                    f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
+                )
+
+                train_metrics = []
+
+            if training_args.save_strategy == "steps" and cur_step and cur_step % training_args.save_steps == 0:
+                if jax.process_index() == 0:
+                    params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+                    model.save_pretrained(
+                        Path(str(training_args.output_dir)) / "checkpoints" / f"checkpoint-{cur_step}",
+                        params=params,
+                        push_to_hub=training_args.push_to_hub,
+                        temp_dir=True,
+                        commit_message=f"Saving weights and logs of step {cur_step}",
+                    )
 
         # ======================== Evaluating ==============================
         num_eval_samples = len(tokenized_datasets["validation"])
@@ -645,7 +674,7 @@ if __name__ == "__main__":
         # Save metrics
         if has_tensorboard and jax.process_index() == 0:
             cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size)
-
+            write_eval_metric(summary_writer, eval_metrics, cur_step)
 
         # save checkpoint after each epoch and push checkpoint to the hub
         if jax.process_index() == 0:
```
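With these changes the training loop logs to TensorBoard (mirrored to Weights & Biases via `sync_tensorboard=True`) every `logging_steps`, and writes a checkpoint under `output_dir/checkpoints/checkpoint-<step>` every `save_steps` when `save_strategy` is `"steps"`. Each step checkpoint is a regular `save_pretrained` directory, so it can be reloaded as sketched below; the output path and step number are illustrative and not from the commit:

```python
# Sketch only (not in the commit): reload a step checkpoint written by the
# save_strategy == "steps" branch above. Path and step number are illustrative.
from transformers import FlaxAutoModelForMaskedLM

model = FlaxAutoModelForMaskedLM.from_pretrained("output/checkpoints/checkpoint-1000")
print(model.config.model_type)
```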
tokens.py
CHANGED

```diff
@@ -3,12 +3,14 @@ from datasets import load_dataset
 from tokenizers import ByteLevelBPETokenizer
 
 # Load dataset
-dataset = load_dataset("
+dataset = load_dataset("oscar", "unshuffled_deduplicated_es")
+
 # Instantiate tokenizer
 tokenizer = ByteLevelBPETokenizer()
 def batch_iterator(batch_size=100_000_000):
     for i in range(0, len(dataset), batch_size):
         yield dataset["text"][i: i + batch_size]
+
 # Customized training
 tokenizer.train_from_iterator(batch_iterator(), vocab_size=50265, min_frequency=2, special_tokens=[
     "<s>",
```
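tokens.py now trains the byte-level BPE tokenizer on the deduplicated Spanish OSCAR split. As a follow-up sketch, not part of the commit, the trained `tokenizer` object could be serialized and wrapped for use with `transformers`; the file name below is illustrative and the snippet assumes it runs right after tokens.py:

```python
# Sketch only (not in the commit), run after tokens.py: serialize the trained
# ByteLevelBPETokenizer and wrap it as a transformers fast tokenizer.
from transformers import PreTrainedTokenizerFast

tokenizer.save("tokenizer.json")  # `tokenizer` is the ByteLevelBPETokenizer trained above
fast_tokenizer = PreTrainedTokenizerFast(tokenizer_file="tokenizer.json")
print(fast_tokenizer.tokenize("Lo hizo en un abrir y cerrar de ojos."))
```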