pere committed on
Commit
d7f0cd6
1 Parent(s): fe264c0

final training stopped at 37500 of 50000

Browse files
.run_t5_mlm_flax_streaming.py.swp ADDED
Binary file (16.4 kB). View file
 
events.out.tfevents.1626337452.t1v-n-e90463ba-w-0.346366.3.v2 CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:3da1b6317e90c6df82bfee8608f82c3e9493305abea2149f0391d700f55af098
3
- size 5561997
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2886b719b5b9a9a977222e308c32a9c3f726b722f2683658c1a603b4d6572abb
3
+ size 5711137
events.out.tfevents.1627000897.t1v-n-e90463ba-w-0.387907.3.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ac502db72300246192fdf676f06ee60927605b4b27bae345d81d23fdcb4def8a
3
+ size 220634
events.out.tfevents.1627065320.t1v-n-e90463ba-w-0.398067.3.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:63ac734e6482b51764eeb77033dab1cdac8bb766271dc26fa2636f92634707b8
3
+ size 40
events.out.tfevents.1627076517.t1v-n-e90463ba-w-0.400064.3.v2 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:36e4a9799fa558b0f6bf4edeb2e4fd8acde9159fb3361d74682b9f9ef6091fa2
3
+ size 40
run.sh ADDED
@@ -0,0 +1 @@
 
 
1
+ ./run_t5_mlm_flax_streaming.py --output_dir="./norwegian-mt5" --model_name_or_path="./norwegian-mt5" --dataset_name="pere/nb_nn_balanced_shuffled" --max_seq_length="512" --per_device_train_batch_size="16" --learning_rate="1e-2" --weight_decay="0.001" --warmup_steps="5000" --overwrite_output_dir --num_train_epochs="5" --logging_steps="500" --save_steps="2500" --eval_steps="2500" --adafactor --push_to_hub --adafactor --preprocessing_num_workers 94
run_t5_mlm_flax_streaming.py CHANGED
@@ -551,7 +551,17 @@ if __name__ == "__main__":
551
  rng = jax.random.PRNGKey(training_args.seed)
552
  dropout_rngs = jax.random.split(rng, jax.local_device_count())
553
 
554
- model = FlaxT5ForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
 
 
 
 
 
 
 
 
 
 
555
 
556
  # Data collator
557
  # This one will take care of randomly masking the tokens.
 
551
  rng = jax.random.PRNGKey(training_args.seed)
552
  dropout_rngs = jax.random.split(rng, jax.local_device_count())
553
 
554
+ #model = FlaxT5ForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
555
+
556
+ if model_args.model_name_or_path:
557
+ model = FlaxT5ForConditionalGeneration.from_pretrained(
558
+ model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
559
+ )
560
+ else:
561
+ model = FlaxT5ForConditionalGeneration.from_config(
562
+ config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
563
+ )
564
+
565
 
566
  # Data collator
567
  # This one will take care of randomly masking the tokens.