Update app.py
Browse files
app.py
CHANGED
@@ -558,7 +558,7 @@ def modelTFT(csv_file, prax):
|
|
558 |
train_dataloader,
|
559 |
val_dataloader,
|
560 |
model_path="optuna_test",
|
561 |
-
n_trials=
|
562 |
max_epochs=MAX_EPOCHS,
|
563 |
gradient_clip_val_range=(0.01, 0.5),
|
564 |
hidden_size_range=(8, 64),
|
@@ -568,7 +568,7 @@ def modelTFT(csv_file, prax):
|
|
568 |
dropout_range=(0.1, 0.3),
|
569 |
trainer_kwargs=dict(limit_train_batches=30),
|
570 |
reduce_on_plateau_patience=4,
|
571 |
-
pruner=optuna.pruners.MedianPruner(n_min_trials=
|
572 |
use_learning_rate_finder=False, # use Optuna to find ideal learning rate or use in-built learning rate finder
|
573 |
)
|
574 |
#torch.cuda.empty_cache()
|
@@ -582,6 +582,7 @@ def modelTFT(csv_file, prax):
|
|
582 |
#fast_dev_run=True, # comment in to check that networkor dataset has no serious bugs
|
583 |
callbacks=[lr_logger, early_stop_callback],
|
584 |
logger=logger,
|
|
|
585 |
)
|
586 |
|
587 |
tft = TemporalFusionTransformer.from_dataset(
|
@@ -795,7 +796,7 @@ def modelTFT_OpenGap(csv_file, prax):
|
|
795 |
train_dataloader,
|
796 |
val_dataloader,
|
797 |
model_path="optuna_test",
|
798 |
-
n_trials=
|
799 |
max_epochs=MAX_EPOCHS,
|
800 |
gradient_clip_val_range=(0.01, 0.5),
|
801 |
hidden_size_range=(8, 64),
|
@@ -805,7 +806,7 @@ def modelTFT_OpenGap(csv_file, prax):
|
|
805 |
dropout_range=(0.1, 0.3),
|
806 |
trainer_kwargs=dict(limit_train_batches=30),
|
807 |
reduce_on_plateau_patience=4,
|
808 |
-
pruner=optuna.pruners.MedianPruner(n_min_trials=
|
809 |
use_learning_rate_finder=False, # use Optuna to find ideal learning rate or use in-built learning rate finder
|
810 |
)
|
811 |
#torch.cuda.empty_cache()
|
@@ -819,6 +820,7 @@ def modelTFT_OpenGap(csv_file, prax):
|
|
819 |
#fast_dev_run=True, # comment in to check that networkor dataset has no serious bugs
|
820 |
callbacks=[lr_logger, early_stop_callback],
|
821 |
logger=logger,
|
|
|
822 |
)
|
823 |
|
824 |
tft = TemporalFusionTransformer.from_dataset(
|
|
|
558 |
train_dataloader,
|
559 |
val_dataloader,
|
560 |
model_path="optuna_test",
|
561 |
+
n_trials=5,
|
562 |
max_epochs=MAX_EPOCHS,
|
563 |
gradient_clip_val_range=(0.01, 0.5),
|
564 |
hidden_size_range=(8, 64),
|
|
|
568 |
dropout_range=(0.1, 0.3),
|
569 |
trainer_kwargs=dict(limit_train_batches=30),
|
570 |
reduce_on_plateau_patience=4,
|
571 |
+
pruner=optuna.pruners.MedianPruner(n_min_trials=3, n_startup_trials=3),
|
572 |
use_learning_rate_finder=False, # use Optuna to find ideal learning rate or use in-built learning rate finder
|
573 |
)
|
574 |
#torch.cuda.empty_cache()
|
|
|
582 |
#fast_dev_run=True, # comment in to check that networkor dataset has no serious bugs
|
583 |
callbacks=[lr_logger, early_stop_callback],
|
584 |
logger=logger,
|
585 |
+
precision="bf16-mixed",
|
586 |
)
|
587 |
|
588 |
tft = TemporalFusionTransformer.from_dataset(
|
|
|
796 |
train_dataloader,
|
797 |
val_dataloader,
|
798 |
model_path="optuna_test",
|
799 |
+
n_trials=5,
|
800 |
max_epochs=MAX_EPOCHS,
|
801 |
gradient_clip_val_range=(0.01, 0.5),
|
802 |
hidden_size_range=(8, 64),
|
|
|
806 |
dropout_range=(0.1, 0.3),
|
807 |
trainer_kwargs=dict(limit_train_batches=30),
|
808 |
reduce_on_plateau_patience=4,
|
809 |
+
pruner=optuna.pruners.MedianPruner(n_min_trials=3, n_warmup_steps=3),
|
810 |
use_learning_rate_finder=False, # use Optuna to find ideal learning rate or use in-built learning rate finder
|
811 |
)
|
812 |
#torch.cuda.empty_cache()
|
|
|
820 |
#fast_dev_run=True, # comment in to check that networkor dataset has no serious bugs
|
821 |
callbacks=[lr_logger, early_stop_callback],
|
822 |
logger=logger,
|
823 |
+
precision="bf16-mixed",
|
824 |
)
|
825 |
|
826 |
tft = TemporalFusionTransformer.from_dataset(
|