final training stopped at 37500 of 50000
Browse files
- .run_t5_mlm_flax_streaming.py.swp +0 -0
- events.out.tfevents.1626337452.t1v-n-e90463ba-w-0.346366.3.v2 +2 -2
- events.out.tfevents.1627000897.t1v-n-e90463ba-w-0.387907.3.v2 +3 -0
- events.out.tfevents.1627065320.t1v-n-e90463ba-w-0.398067.3.v2 +3 -0
- events.out.tfevents.1627076517.t1v-n-e90463ba-w-0.400064.3.v2 +3 -0
- run.sh +1 -0
- run_t5_mlm_flax_streaming.py +11 -1
.run_t5_mlm_flax_streaming.py.swp
ADDED
Binary file (16.4 kB). View file
|
|
events.out.tfevents.1626337452.t1v-n-e90463ba-w-0.346366.3.v2
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:2886b719b5b9a9a977222e308c32a9c3f726b722f2683658c1a603b4d6572abb
|
3 |
+
size 5711137
|
events.out.tfevents.1627000897.t1v-n-e90463ba-w-0.387907.3.v2
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ac502db72300246192fdf676f06ee60927605b4b27bae345d81d23fdcb4def8a
|
3 |
+
size 220634
|
events.out.tfevents.1627065320.t1v-n-e90463ba-w-0.398067.3.v2
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:63ac734e6482b51764eeb77033dab1cdac8bb766271dc26fa2636f92634707b8
|
3 |
+
size 40
|
events.out.tfevents.1627076517.t1v-n-e90463ba-w-0.400064.3.v2
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:36e4a9799fa558b0f6bf4edeb2e4fd8acde9159fb3361d74682b9f9ef6091fa2
|
3 |
+
size 40
|
run.sh
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
./run_t5_mlm_flax_streaming.py --output_dir="./norwegian-mt5" --model_name_or_path="./norwegian-mt5" --dataset_name="pere/nb_nn_balanced_shuffled" --max_seq_length="512" --per_device_train_batch_size="16" --learning_rate="1e-2" --weight_decay="0.001" --warmup_steps="5000" --overwrite_output_dir --num_train_epochs="5" --logging_steps="500" --save_steps="2500" --eval_steps="2500" --adafactor --push_to_hub --adafactor --preprocessing_num_workers 94
|
run_t5_mlm_flax_streaming.py
CHANGED
@@ -551,7 +551,17 @@ if __name__ == "__main__":
|
|
551 |
rng = jax.random.PRNGKey(training_args.seed)
|
552 |
dropout_rngs = jax.random.split(rng, jax.local_device_count())
|
553 |
|
554 |
-
model = FlaxT5ForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
555 |
|
556 |
# Data collator
|
557 |
# This one will take care of randomly masking the tokens.
|
|
|
551 |
rng = jax.random.PRNGKey(training_args.seed)
|
552 |
dropout_rngs = jax.random.split(rng, jax.local_device_count())
|
553 |
|
554 |
+
#model = FlaxT5ForConditionalGeneration(config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype))
|
555 |
+
|
556 |
+
if model_args.model_name_or_path:
|
557 |
+
model = FlaxT5ForConditionalGeneration.from_pretrained(
|
558 |
+
model_args.model_name_or_path, config=config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
|
559 |
+
)
|
560 |
+
else:
|
561 |
+
model = FlaxT5ForConditionalGeneration.from_config(
|
562 |
+
config, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype)
|
563 |
+
)
|
564 |
+
|
565 |
|
566 |
# Data collator
|
567 |
# This one will take care of randomly masking the tokens.
|