versae commited on
Commit
d988382
1 Parent(s): 48f8c78

Adding checkpointing, wandb, and new mlm script

Browse files
Files changed (4) hide show
  1. README.md +29 -1
  2. perplexity.py +22 -0
  3. run_mlm_flax.py +60 -31
  4. tokens.py +3 -1
README.md CHANGED
@@ -1,19 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
1
  # BERTIN
 
2
  BERTIN is a series of BERT-based models for Spanish. This one is a RoBERTa-large model trained from scratch on the Spanish portion of mC4 using [Flax](https://github.com/google/flax), including training scripts.
3
 
4
  This is part of the
5
  [Flax/Jax Community Week](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/7104), organised by [HuggingFace](https://huggingface.co/) and TPU usage sponsored by Google.
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  ## Team members
 
8
  - Javier de la Rosa (versae)
9
  - Manu Romero (mrm8488)
10
  - María Grandury (mariagrandury)
11
  - Ari Polakov (aripo99)
12
  - Pablogps
13
  - daveni
14
- - Sri Lakshmi
15
 
16
  ## Useful links
 
17
  - [Community Week timeline](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/7104#summary-timeline-calendar-6)
18
  - [Community Week README](https://github.com/huggingface/transformers/blob/master/examples/research_projects/jax-projects/README.md)
19
  - [Community Week thread](https://discuss.huggingface.co/t/bertin-pretrain-roberta-large-from-scratch-in-spanish/7125)
 
1
+ ---
2
+ language: no
3
+ license: CC-BY 4.0
4
+ tags:
5
+ - spanish
6
+ - roberta
7
+ pipeline_tag: fill-mask
8
+ widget:
9
+ - text: "Lo hizo en un abrir y cerar de <mask>."
10
+ ---
11
+
12
  # BERTIN
13
+
14
  BERTIN is a series of BERT-based models for Spanish. This one is a RoBERTa-large model trained from scratch on the Spanish portion of mC4 using [Flax](https://github.com/google/flax), including training scripts.
15
 
16
  This is part of the
17
  [Flax/Jax Community Week](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/7104), organised by [HuggingFace](https://huggingface.co/) and TPU usage sponsored by Google.
18
 
19
+ ## Spanish mC4
20
+
21
+ The Spanish portion of mC4 containes about 416 million records and 235 billion words.
22
+
23
+ ```bash
24
+ $ zcat c4/multilingual/c4-es*.tfrecord*.json.gz | wc -l
25
+ 416057992
26
+ ```
27
+
28
+ ```bash
29
+ $ zcat c4/multilingual/c4-es*.tfrecord-*.json.gz | jq -r '.text | split(" ") | length' | paste -s -d+ - | bc
30
+ 235303687795
31
+ ```
32
+
33
  ## Team members
34
+
35
  - Javier de la Rosa (versae)
36
  - Manu Romero (mrm8488)
37
  - María Grandury (mariagrandury)
38
  - Ari Polakov (aripo99)
39
  - Pablogps
40
  - daveni
41
+ - Sri Lakshmi
42
 
43
  ## Useful links
44
+
45
  - [Community Week timeline](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/7104#summary-timeline-calendar-6)
46
  - [Community Week README](https://github.com/huggingface/transformers/blob/master/examples/research_projects/jax-projects/README.md)
47
  - [Community Week thread](https://discuss.huggingface.co/t/bertin-pretrain-roberta-large-from-scratch-in-spanish/7125)
perplexity.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ import kenlm
3
+ from datasets import load_dataset
4
+ from tqdm import tqdm
5
+
6
+
7
+ def pp(log_score, length):
8
+ return 10.0 ** (-log_score / length)
9
+
10
+ # http://dl.fbaipublicfiles.com/cc_net/lm/es.arpa.bin
11
+ model = kenlm.Model("es.arpa.bin")
12
+ mc4 = load_dataset("mc4", "es", streaming=True)
13
+ with open("mc4-es-perplexity.txt", "w") as f:
14
+ for sample in tqdm(mc4["train"].shuffle(buffer_size=100_000), total=416057992):
15
+ lines = sample["text"].split("\n")
16
+ doc_log_score, doc_length = 0, 0
17
+ for line in lines:
18
+ log_score = model.score(line)
19
+ length = len(line.split()) + 1
20
+ doc_log_score += log_score
21
+ doc_length += length
22
+ f.write(f"{pp(doc_log_score, doc_length)}\n")
run_mlm_flax.py CHANGED
@@ -56,22 +56,6 @@ from transformers import (
56
  )
57
 
58
 
59
- # Cache the result
60
- has_tensorboard = is_tensorboard_available()
61
- if has_tensorboard:
62
- try:
63
- from flax.metrics.tensorboard import SummaryWriter
64
- except ImportError as ie:
65
- has_tensorboard = False
66
- print(f"Unable to display metrics through TensorBoard because some package are not installed: {ie}")
67
-
68
- else:
69
- print(
70
- "Unable to display metrics through TensorBoard because the package is not installed: "
71
- "Please run pip install tensorboard to enable."
72
- )
73
-
74
-
75
  MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
76
  MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
77
 
@@ -126,6 +110,9 @@ class DataTrainingArguments:
126
  dataset_config_name: Optional[str] = field(
127
  default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
128
  )
 
 
 
129
  train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
130
  validation_file: Optional[str] = field(
131
  default=None,
@@ -269,7 +256,7 @@ def generate_batch_splits(samples_idx: jnp.ndarray, batch_size: int) -> jnp.ndar
269
  return batch_idx
270
 
271
 
272
- def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
273
  summary_writer.scalar("train_time", train_time, step)
274
 
275
  train_metrics = get_metrics(train_metrics)
@@ -278,6 +265,8 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
278
  for i, val in enumerate(vals):
279
  summary_writer.scalar(tag, val, step - len(vals) + i + 1)
280
 
 
 
281
  for metric_name, value in eval_metrics.items():
282
  summary_writer.scalar(f"eval_{metric_name}", value, step)
283
 
@@ -315,10 +304,6 @@ if __name__ == "__main__":
315
 
316
  # Log on each process the small summary:
317
  logger = logging.getLogger(__name__)
318
- #logger.warning(
319
- # f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
320
- # + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
321
- #)
322
 
323
  # Set the verbosity to info of the Transformers logger (on main process only):
324
  logger.info(f"Training/evaluation parameters {training_args}")
@@ -337,7 +322,7 @@ if __name__ == "__main__":
337
  # download the dataset.
338
  if data_args.dataset_name is not None:
339
  # Downloading and loading a dataset from the hub.
340
- datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir)
341
 
342
  if "validation" not in datasets.keys():
343
  datasets["validation"] = load_dataset(
@@ -345,12 +330,14 @@ if __name__ == "__main__":
345
  data_args.dataset_config_name,
346
  split=f"train[:{data_args.validation_split_percentage}%]",
347
  cache_dir=model_args.cache_dir,
 
348
  )
349
  datasets["train"] = load_dataset(
350
  data_args.dataset_name,
351
  data_args.dataset_config_name,
352
  split=f"train[{data_args.validation_split_percentage}%:]",
353
  cache_dir=model_args.cache_dir,
 
354
  )
355
  else:
356
  data_files = {}
@@ -469,10 +456,32 @@ if __name__ == "__main__":
469
  num_proc=data_args.preprocessing_num_workers,
470
  load_from_cache_file=not data_args.overwrite_cache,
471
  )
472
-
473
  # Enable tensorboard only on the master node
 
474
  if has_tensorboard and jax.process_index() == 0:
475
- summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
476
 
477
  # Data collator
478
  # This one will take care of randomly masking the tokens.
@@ -521,7 +530,7 @@ if __name__ == "__main__":
521
  learning_rate=linear_decay_lr_schedule_fn,
522
  b1=training_args.adam_beta1,
523
  b2=training_args.adam_beta2,
524
- eps=1e-8,
525
  weight_decay=training_args.weight_decay,
526
  mask=decay_mask_fn,
527
  )
@@ -601,7 +610,7 @@ if __name__ == "__main__":
601
  train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
602
 
603
  # Gather the indexes for creating the batch and do a training step
604
- for i, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
605
  samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
606
  model_inputs = data_collator(samples, pad_to_multiple_of=16)
607
 
@@ -610,11 +619,31 @@ if __name__ == "__main__":
610
  state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
611
  train_metrics.append(train_metric)
612
 
613
- train_time += time.time() - train_start
614
 
615
- epochs.write(
616
- f"Epoch... ({epoch + 1}/{num_epochs} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
617
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
618
 
619
  # ======================== Evaluating ==============================
620
  num_eval_samples = len(tokenized_datasets["validation"])
@@ -645,7 +674,7 @@ if __name__ == "__main__":
645
  # Save metrics
646
  if has_tensorboard and jax.process_index() == 0:
647
  cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size)
648
- write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
649
 
650
  # save checkpoint after each epoch and push checkpoint to the hub
651
  if jax.process_index() == 0:
 
56
  )
57
 
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  MODEL_CONFIG_CLASSES = list(FLAX_MODEL_FOR_MASKED_LM_MAPPING.keys())
60
  MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
61
 
 
110
  dataset_config_name: Optional[str] = field(
111
  default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
112
  )
113
+ dataset_streaming: bool = field(
114
+ default=False, metadata={"help": "Whether dataset_name should be retrieved using streaming if available."}
115
+ )
116
  train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
117
  validation_file: Optional[str] = field(
118
  default=None,
 
256
  return batch_idx
257
 
258
 
259
+ def write_train_metric(summary_writer, train_metrics, train_time, step):
260
  summary_writer.scalar("train_time", train_time, step)
261
 
262
  train_metrics = get_metrics(train_metrics)
 
265
  for i, val in enumerate(vals):
266
  summary_writer.scalar(tag, val, step - len(vals) + i + 1)
267
 
268
+
269
+ def write_eval_metric(summary_writer, eval_metrics, step):
270
  for metric_name, value in eval_metrics.items():
271
  summary_writer.scalar(f"eval_{metric_name}", value, step)
272
 
 
304
 
305
  # Log on each process the small summary:
306
  logger = logging.getLogger(__name__)
 
 
 
 
307
 
308
  # Set the verbosity to info of the Transformers logger (on main process only):
309
  logger.info(f"Training/evaluation parameters {training_args}")
 
322
  # download the dataset.
323
  if data_args.dataset_name is not None:
324
  # Downloading and loading a dataset from the hub.
325
+ datasets = load_dataset(data_args.dataset_name, data_args.dataset_config_name, cache_dir=model_args.cache_dir, streaming=data_args.dataset_streaming)
326
 
327
  if "validation" not in datasets.keys():
328
  datasets["validation"] = load_dataset(
 
330
  data_args.dataset_config_name,
331
  split=f"train[:{data_args.validation_split_percentage}%]",
332
  cache_dir=model_args.cache_dir,
333
+ streaming=data_args.dataset_streaming,
334
  )
335
  datasets["train"] = load_dataset(
336
  data_args.dataset_name,
337
  data_args.dataset_config_name,
338
  split=f"train[{data_args.validation_split_percentage}%:]",
339
  cache_dir=model_args.cache_dir,
340
+ streaming=data_args.dataset_streaming,
341
  )
342
  else:
343
  data_files = {}
 
456
  num_proc=data_args.preprocessing_num_workers,
457
  load_from_cache_file=not data_args.overwrite_cache,
458
  )
 
459
  # Enable tensorboard only on the master node
460
+ has_tensorboard = is_tensorboard_available()
461
  if has_tensorboard and jax.process_index() == 0:
462
+ try:
463
+ from flax.metrics.tensorboard import SummaryWriter
464
+ summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir))
465
+ # Enable Weight&Biases
466
+ import wandb
467
+ wandb.init(
468
+ entity='wandb',
469
+ project='hf-flax-bertin-roberta-es',
470
+ sync_tensorboard=True,
471
+ )
472
+ wandb.config.update(training_args)
473
+ wandb.config.update(model_args)
474
+ wandb.config.update(data_args)
475
+ except ImportError as ie:
476
+ has_tensorboard = False
477
+ logger.warning(
478
+ f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
479
+ )
480
+ else:
481
+ logger.warning(
482
+ "Unable to display metrics through TensorBoard because the package is not installed: "
483
+ "Please run pip install tensorboard to enable."
484
+ )
485
 
486
  # Data collator
487
  # This one will take care of randomly masking the tokens.
 
530
  learning_rate=linear_decay_lr_schedule_fn,
531
  b1=training_args.adam_beta1,
532
  b2=training_args.adam_beta2,
533
+ eps=training_args.adam_epsilon,
534
  weight_decay=training_args.weight_decay,
535
  mask=decay_mask_fn,
536
  )
 
610
  train_batch_idx = generate_batch_splits(train_samples_idx, train_batch_size)
611
 
612
  # Gather the indexes for creating the batch and do a training step
613
+ for step, batch_idx in enumerate(tqdm(train_batch_idx, desc="Training...", position=1)):
614
  samples = [tokenized_datasets["train"][int(idx)] for idx in batch_idx]
615
  model_inputs = data_collator(samples, pad_to_multiple_of=16)
616
 
 
619
  state, train_metric, dropout_rngs = p_train_step(state, model_inputs, dropout_rngs)
620
  train_metrics.append(train_metric)
621
 
622
+ cur_step = epoch * (num_train_samples // train_batch_size) + step
623
 
624
+ if cur_step % training_args.logging_steps == 0 and cur_step > 0:
625
+ # Save metrics
626
+ train_metric = jax_utils.unreplicate(train_metric)
627
+ train_time += time.time() - train_start
628
+ if has_tensorboard and jax.process_index() == 0:
629
+ write_train_metric(summary_writer, train_metrics, train_time, cur_step)
630
+
631
+ epochs.write(
632
+ f"Step... ({cur_step} | Loss: {train_metric['loss']}, Learning Rate: {train_metric['learning_rate']})"
633
+ )
634
+
635
+ train_metrics = []
636
+
637
+ if training_args.save_strategy == "steps" and cur_step and cur_step % training_args.save_steps == 0:
638
+ if jax.process_index() == 0:
639
+ params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
640
+ model.save_pretrained(
641
+ Path(str(training_args.output_dir)) / "checkpoints" / f"checkpoint-{cur_step}",
642
+ params=params,
643
+ push_to_hub=training_args.push_to_hub,
644
+ temp_dir=True,
645
+ commit_message=f"Saving weights and logs of step {cur_step}",
646
+ )
647
 
648
  # ======================== Evaluating ==============================
649
  num_eval_samples = len(tokenized_datasets["validation"])
 
674
  # Save metrics
675
  if has_tensorboard and jax.process_index() == 0:
676
  cur_step = epoch * (len(tokenized_datasets["train"]) // train_batch_size)
677
+ write_eval_metric(summary_writer, eval_metrics, cur_step)
678
 
679
  # save checkpoint after each epoch and push checkpoint to the hub
680
  if jax.process_index() == 0:
tokens.py CHANGED
@@ -3,12 +3,14 @@ from datasets import load_dataset
3
  from tokenizers import ByteLevelBPETokenizer
4
 
5
  # Load dataset
6
- dataset = load_dataset("large_spanish_corpus", split="train")
 
7
  # Instantiate tokenizer
8
  tokenizer = ByteLevelBPETokenizer()
9
  def batch_iterator(batch_size=100_000_000):
10
  for i in range(0, len(dataset), batch_size):
11
  yield dataset["text"][i: i + batch_size]
 
12
  # Customized training
13
  tokenizer.train_from_iterator(batch_iterator(), vocab_size=50265, min_frequency=2, special_tokens=[
14
  "<s>",
 
3
  from tokenizers import ByteLevelBPETokenizer
4
 
5
  # Load dataset
6
+ dataset = load_dataset("oscar", "unshuffled_deduplicated_es")
7
+
8
  # Instantiate tokenizer
9
  tokenizer = ByteLevelBPETokenizer()
10
  def batch_iterator(batch_size=100_000_000):
11
  for i in range(0, len(dataset), batch_size):
12
  yield dataset["text"][i: i + batch_size]
13
+
14
  # Customized training
15
  tokenizer.train_from_iterator(batch_iterator(), vocab_size=50265, min_frequency=2, special_tokens=[
16
  "<s>",