LKCell / base_ml /base_early_stopping.py
xiazhi1
initial commit
aea73e2
raw
history blame
2.76 kB
# -*- coding: utf-8 -*-
# Base Machine Learning Experiment
#
# @ Fabian Hörst, [email protected]
# Institute for Artificial Intelligence in Medicine,
# University Medicine Essen
import logging
# Module-level logger. Note that it is looked up under the literal name
# "__main__" rather than this module's own ``__name__``.
# NOTE(review): presumably this is so records end up on the logger configured
# by the entry-point script -- confirm against the training entry point.
logger = logging.getLogger("__main__")
# A NullHandler is attached here; it discards every record it receives, so any
# visible output depends on handlers that are configured elsewhere.
logger.addHandler(logging.NullHandler())
import wandb
class EarlyStopping:
    """Early Stopping Class

    Tracks the best value seen so far for a monitored metric and counts how
    many consecutive updates did not reach that value. Once the counter
    reaches ``patience``, the ``early_stop`` flag is set; it is up to the
    caller to check the flag and end training.

    Args:
        patience (int): Patience to wait before stopping
        strategy (str, optional): Optimization strategy.
            Please select 'minimize' or 'maximize' for strategy. Defaults to "minimize".

    Raises:
        AssertionError: If ``strategy`` is neither 'minimize' nor 'maximize'
            (compared case-insensitively).
    """

    def __init__(self, patience: int, strategy: str = "minimize"):
        normalized_strategy = strategy.lower()
        # Raised explicitly (instead of via an ``assert`` statement) so that
        # the check is not stripped when Python runs with ``-O``. The exception
        # type stays AssertionError to remain compatible with existing callers.
        if normalized_strategy not in ("minimize", "maximize"):
            raise AssertionError(
                "Please select 'minimize' or 'maximize' for strategy"
            )

        self.patience = patience
        # Number of consecutive calls without a new best metric.
        self.counter = 0
        self.strategy = normalized_strategy
        # Best metric / epoch observed so far; None until the first call.
        self.best_metric = None
        self.best_epoch = None
        # Set to True once ``counter`` has reached ``patience``.
        self.early_stop = False

        logger.info(
            f"Using early stopping with a range of {self.patience} and {self.strategy} strategy"
        )

    def __call__(self, metric: float, epoch: int) -> bool:
        """Early stopping update call

        Args:
            metric (float): Metric for early stopping
            epoch (int): Current epoch

        Returns:
            bool: Returns true if the model is performing better than (or equal to)
                the current best model, otherwise false. The very first call
                always returns true, since there is nothing to compare against.
        """
        if self.best_metric is None or self._is_improvement(metric):
            self._record_best(metric, epoch)
            return True

        # No improvement: spend one unit of patience and flag the stop once
        # the budget is used up.
        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True
        return False

    def _is_improvement(self, metric: float) -> bool:
        """Return True if ``metric`` is at least as good as the stored best.

        Ties count as an improvement for both strategies.
        """
        if self.strategy == "minimize":
            return self.best_metric >= metric
        return self.best_metric <= metric

    def _record_best(self, metric: float, epoch: int) -> None:
        """Store a new best metric/epoch, reset the counter and update wandb."""
        self.best_metric = metric
        self.best_epoch = epoch
        self.counter = 0

        # ``wandb.run`` is None when no run is active; skip the summary update
        # in that case instead of failing with an AttributeError.
        active_run = wandb.run
        if active_run is not None:
            active_run.summary["Best-Epoch"] = epoch
            active_run.summary["Best-Metric"] = metric