kryox64 committed on
Commit
a734b5e
1 Parent(s): cdec8c2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -4
app.py CHANGED
@@ -558,7 +558,7 @@ def modelTFT(csv_file, prax):
558
  train_dataloader,
559
  val_dataloader,
560
  model_path="optuna_test",
561
- n_trials=7,
562
  max_epochs=MAX_EPOCHS,
563
  gradient_clip_val_range=(0.01, 0.5),
564
  hidden_size_range=(8, 64),
@@ -568,7 +568,7 @@ def modelTFT(csv_file, prax):
568
  dropout_range=(0.1, 0.3),
569
  trainer_kwargs=dict(limit_train_batches=30),
570
  reduce_on_plateau_patience=4,
571
- pruner=optuna.pruners.MedianPruner(n_min_trials=5, n_startup_trials=5),
572
  use_learning_rate_finder=False, # use Optuna to find ideal learning rate or use in-built learning rate finder
573
  )
574
  #torch.cuda.empty_cache()
@@ -582,6 +582,7 @@ def modelTFT(csv_file, prax):
582
  #fast_dev_run=True, # comment in to check that networkor dataset has no serious bugs
583
  callbacks=[lr_logger, early_stop_callback],
584
  logger=logger,
 
585
  )
586
 
587
  tft = TemporalFusionTransformer.from_dataset(
@@ -795,7 +796,7 @@ def modelTFT_OpenGap(csv_file, prax):
795
  train_dataloader,
796
  val_dataloader,
797
  model_path="optuna_test",
798
- n_trials=7,
799
  max_epochs=MAX_EPOCHS,
800
  gradient_clip_val_range=(0.01, 0.5),
801
  hidden_size_range=(8, 64),
@@ -805,7 +806,7 @@ def modelTFT_OpenGap(csv_file, prax):
805
  dropout_range=(0.1, 0.3),
806
  trainer_kwargs=dict(limit_train_batches=30),
807
  reduce_on_plateau_patience=4,
808
- pruner=optuna.pruners.MedianPruner(n_min_trials=5, n_warmup_steps=5),
809
  use_learning_rate_finder=False, # use Optuna to find ideal learning rate or use in-built learning rate finder
810
  )
811
  #torch.cuda.empty_cache()
@@ -819,6 +820,7 @@ def modelTFT_OpenGap(csv_file, prax):
819
  #fast_dev_run=True, # comment in to check that networkor dataset has no serious bugs
820
  callbacks=[lr_logger, early_stop_callback],
821
  logger=logger,
 
822
  )
823
 
824
  tft = TemporalFusionTransformer.from_dataset(
 
558
  train_dataloader,
559
  val_dataloader,
560
  model_path="optuna_test",
561
+ n_trials=5,
562
  max_epochs=MAX_EPOCHS,
563
  gradient_clip_val_range=(0.01, 0.5),
564
  hidden_size_range=(8, 64),
 
568
  dropout_range=(0.1, 0.3),
569
  trainer_kwargs=dict(limit_train_batches=30),
570
  reduce_on_plateau_patience=4,
571
+ pruner=optuna.pruners.MedianPruner(n_min_trials=3, n_startup_trials=3),
572
  use_learning_rate_finder=False, # use Optuna to find ideal learning rate or use in-built learning rate finder
573
  )
574
  #torch.cuda.empty_cache()
 
582
  #fast_dev_run=True, # comment in to check that networkor dataset has no serious bugs
583
  callbacks=[lr_logger, early_stop_callback],
584
  logger=logger,
585
+ precision="bf16-mixed",
586
  )
587
 
588
  tft = TemporalFusionTransformer.from_dataset(
 
796
  train_dataloader,
797
  val_dataloader,
798
  model_path="optuna_test",
799
+ n_trials=5,
800
  max_epochs=MAX_EPOCHS,
801
  gradient_clip_val_range=(0.01, 0.5),
802
  hidden_size_range=(8, 64),
 
806
  dropout_range=(0.1, 0.3),
807
  trainer_kwargs=dict(limit_train_batches=30),
808
  reduce_on_plateau_patience=4,
809
+ pruner=optuna.pruners.MedianPruner(n_min_trials=3, n_warmup_steps=3),
810
  use_learning_rate_finder=False, # use Optuna to find ideal learning rate or use in-built learning rate finder
811
  )
812
  #torch.cuda.empty_cache()
 
820
  #fast_dev_run=True, # comment in to check that networkor dataset has no serious bugs
821
  callbacks=[lr_logger, early_stop_callback],
822
  logger=logger,
823
+ precision="bf16-mixed",
824
  )
825
 
826
  tft = TemporalFusionTransformer.from_dataset(