|
import matplotlib |
|
from torch.nn import DataParallel |
|
from torch.nn.parallel import DistributedDataParallel |
|
|
|
matplotlib.use('Agg') |
|
import glob |
|
import itertools |
|
import subprocess |
|
import threading |
|
import traceback |
|
|
|
|
from pytorch_lightning.callbacks import ModelCheckpoint |
|
|
|
from functools import wraps |
|
from torch.cuda._utils import _get_device_index |
|
import numpy as np |
|
import torch.optim |
|
import torch.utils.data |
|
import copy |
|
import logging |
|
import os |
|
import re |
|
import sys |
|
import torch |
|
import torch.distributed as dist |
|
import torch.multiprocessing as mp |
|
import tqdm |
|
from torch.optim.optimizer import Optimizer |
|
|
|
|
|
def get_a_var(obj): |
|
if isinstance(obj, torch.Tensor): |
|
return obj |
|
|
|
if isinstance(obj, list) or isinstance(obj, tuple): |
|
for result in map(get_a_var, obj): |
|
if isinstance(result, torch.Tensor): |
|
return result |
|
if isinstance(obj, dict): |
|
for result in map(get_a_var, obj.items()): |
|
if isinstance(result, torch.Tensor): |
|
return result |
|
return None |
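

# Usage sketch (the batch keys below are hypothetical): get_a_var walks nested
# lists, tuples and dicts and returns the first tensor it finds; parallel_apply
# below uses it to infer a device when none is passed explicitly.
def _example_get_a_var():
    batch = {'mel': torch.zeros(2, 80), 'ids': [1, 2]}
    return get_a_var(batch)  # -> the (2, 80) zero tensor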
|
|
|
|
|
def data_loader(fn): |
|
""" |
|
Decorator to make any fx with this use the lazy property |
|
:param fn: |
|
:return: |
|
""" |
|
|
|
    attr_name = '_lazy_' + fn.__name__

    @wraps(fn)
    def _get_data_loader(self):
|
try: |
|
value = getattr(self, attr_name) |
|
except AttributeError: |
|
try: |
|
value = fn(self) |
|
if ( |
|
value is not None and |
|
not isinstance(value, list) and |
|
fn.__name__ in ['test_dataloader', 'val_dataloader'] |
|
): |
|
value = [value] |
|
except AttributeError as e: |
|
|
|
traceback.print_exc() |
|
error = f'{fn.__name__}: An AttributeError was encountered: ' + str(e) |
|
raise RuntimeError(error) from e |
|
setattr(self, attr_name, value) |
|
return value |
|
|
|
return _get_data_loader |
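

# Usage sketch (_ToyTask is a hypothetical stand-in for a LightningModule-style task):
# the decorated method caches its result under `_lazy_<name>` and wraps a single
# val/test loader in a list so the trainer can always iterate over it.
def _example_data_loader_usage():
    class _ToyTask:
        @data_loader
        def val_dataloader(self):
            dataset = torch.utils.data.TensorDataset(torch.zeros(4, 1))
            return torch.utils.data.DataLoader(dataset, batch_size=2)

    task = _ToyTask()
    loaders = task.val_dataloader()          # builds and caches -> [DataLoader]
    assert loaders is task.val_dataloader()  # second call hits the cache
    return loaders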
|
|
|
|
|
def parallel_apply(modules, inputs, kwargs_tup=None, devices=None): |
|
r"""Applies each `module` in :attr:`modules` in parallel on arguments |
|
contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword) |
|
on each of :attr:`devices`. |
|
|
|
Args: |
|
modules (Module): modules to be parallelized |
|
inputs (tensor): inputs to the modules |
|
devices (list of int or torch.device): CUDA devices |
|
|
|
:attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and |
|
:attr:`devices` (if given) should all have same length. Moreover, each |
|
element of :attr:`inputs` can either be a single object as the only argument |
|
to a module, or a collection of positional arguments. |
|
""" |
|
assert len(modules) == len(inputs) |
|
if kwargs_tup is not None: |
|
assert len(modules) == len(kwargs_tup) |
|
else: |
|
kwargs_tup = ({},) * len(modules) |
|
if devices is not None: |
|
assert len(modules) == len(devices) |
|
else: |
|
devices = [None] * len(modules) |
|
devices = list(map(lambda x: _get_device_index(x, True), devices)) |
|
lock = threading.Lock() |
|
results = {} |
|
grad_enabled = torch.is_grad_enabled() |
|
|
|
def _worker(i, module, input, kwargs, device=None): |
|
torch.set_grad_enabled(grad_enabled) |
|
if device is None: |
|
device = get_a_var(input).get_device() |
|
try: |
|
with torch.cuda.device(device): |
|
|
|
if not isinstance(input, (list, tuple)): |
|
input = (input,) |
|
|
|
|
|
|
|
if module.training: |
|
output = module.training_step(*input, **kwargs) |
|
|
|
elif module.testing: |
|
output = module.test_step(*input, **kwargs) |
|
|
|
else: |
|
output = module.validation_step(*input, **kwargs) |
|
|
|
|
|
with lock: |
|
results[i] = output |
|
except Exception as e: |
|
with lock: |
|
results[i] = e |
|
|
|
|
|
|
|
root_m = modules[0] |
|
for m in modules[1:]: |
|
m.training = root_m.training |
|
m.testing = root_m.testing |
|
|
|
if len(modules) > 1: |
|
threads = [threading.Thread(target=_worker, |
|
args=(i, module, input, kwargs, device)) |
|
for i, (module, input, kwargs, device) in |
|
enumerate(zip(modules, inputs, kwargs_tup, devices))] |
|
|
|
for thread in threads: |
|
thread.start() |
|
for thread in threads: |
|
thread.join() |
|
else: |
|
_worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0]) |
|
|
|
outputs = [] |
|
for i in range(len(inputs)): |
|
output = results[i] |
|
if isinstance(output, Exception): |
|
raise output |
|
outputs.append(output) |
|
return outputs |
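

# Usage sketch (_ToyStep is a hypothetical stand-in for the Lightning-style modules this
# helper expects): parallel_apply runs training_step/test_step/validation_step on each
# replica in its own thread, one per device, instead of forward() as the stock
# torch.nn.parallel helper does.
def _example_parallel_apply():
    if torch.cuda.device_count() < 2:
        return None  # the sketch needs at least two GPUs to scatter across

    class _ToyStep(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.testing = False
            self.linear = torch.nn.Linear(4, 1)

        def training_step(self, batch):
            return {'loss': self.linear(batch).mean()}

        def validation_step(self, batch):
            return {'val_loss': self.linear(batch).mean()}

    devices = [0, 1]
    model = _ToyStep().train()
    replicas = [copy.deepcopy(model).cuda(d) for d in devices]
    inputs = [(torch.randn(8, 4, device=f'cuda:{d}'),) for d in devices]
    return parallel_apply(replicas, inputs, devices=devices)  # -> [{'loss': ...}, {'loss': ...}]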
|
|
|
|
|
def _find_tensors(obj): |
|
r""" |
|
Recursively find all tensors contained in the specified object. |
|
""" |
|
if isinstance(obj, torch.Tensor): |
|
return [obj] |
|
if isinstance(obj, (list, tuple)): |
|
return itertools.chain(*map(_find_tensors, obj)) |
|
if isinstance(obj, dict): |
|
return itertools.chain(*map(_find_tensors, obj.values())) |
|
return [] |
|
|
|
|
|
class DDP(DistributedDataParallel): |
|
""" |
|
Override the forward call in lightning so it goes to training and validation step respectively |
|
""" |
|
|
|
def parallel_apply(self, replicas, inputs, kwargs): |
|
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)]) |
|
|
|
def forward(self, *inputs, **kwargs): |
|
self._sync_params() |
|
if self.device_ids: |
|
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) |
|
if len(self.device_ids) == 1: |
|
|
|
|
|
|
|
|
|
|
|
|
|
if self.module.training: |
|
output = self.module.training_step(*inputs[0], **kwargs[0]) |
|
elif self.module.testing: |
|
output = self.module.test_step(*inputs[0], **kwargs[0]) |
|
else: |
|
output = self.module.validation_step(*inputs[0], **kwargs[0]) |
|
else: |
|
outputs = self.parallel_apply(self._module_copies[:len(inputs)], inputs, kwargs) |
|
output = self.gather(outputs, self.output_device) |
|
else: |
|
|
|
output = self.module(*inputs, **kwargs) |
|
|
|
if torch.is_grad_enabled(): |
|
|
|
|
|
|
|
|
|
|
|
if self.find_unused_parameters: |
|
self.reducer.prepare_for_backward(list(_find_tensors(output))) |
|
else: |
|
self.reducer.prepare_for_backward([]) |
|
return output |
|
|
|
|
|
class DP(DataParallel): |
|
""" |
|
Override the forward call in lightning so it goes to training and validation step respectively |
|
""" |
|
|
|
def forward(self, *inputs, **kwargs): |
|
if not self.device_ids: |
|
return self.module(*inputs, **kwargs) |
|
|
|
for t in itertools.chain(self.module.parameters(), self.module.buffers()): |
|
if t.device != self.src_device_obj: |
|
raise RuntimeError("module must have its parameters and buffers " |
|
"on device {} (device_ids[0]) but found one of " |
|
"them on device: {}".format(self.src_device_obj, t.device)) |
|
|
|
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) |
|
if len(self.device_ids) == 1: |
|
|
|
if self.module.training: |
|
return self.module.training_step(*inputs[0], **kwargs[0]) |
|
elif self.module.testing: |
|
return self.module.test_step(*inputs[0], **kwargs[0]) |
|
else: |
|
return self.module.validation_step(*inputs[0], **kwargs[0]) |
|
|
|
replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) |
|
outputs = self.parallel_apply(replicas, inputs, kwargs) |
|
return self.gather(outputs, self.output_device) |
|
|
|
def parallel_apply(self, replicas, inputs, kwargs): |
|
return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)]) |
|
|
|
|
|
class GradientAccumulationScheduler: |
|
def __init__(self, scheduling: dict): |
|
if scheduling == {}: |
|
raise TypeError("Empty dict cannot be interpreted correct") |
|
|
|
for key in scheduling.keys(): |
|
if not isinstance(key, int) or not isinstance(scheduling[key], int): |
|
raise TypeError("All epoches and accumulation factor must be integers") |
|
|
|
minimal_epoch = min(scheduling.keys()) |
|
if minimal_epoch < 1: |
|
msg = f"Epochs indexing from 1, epoch {minimal_epoch} cannot be interpreted correct" |
|
raise IndexError(msg) |
|
elif minimal_epoch != 1: |
|
scheduling.update({1: 1}) |
|
|
|
self.scheduling = scheduling |
|
self.epochs = sorted(scheduling.keys()) |
|
|
|
def on_epoch_begin(self, epoch, trainer): |
|
epoch += 1 |
|
for i in reversed(range(len(self.epochs))): |
|
if epoch >= self.epochs[i]: |
|
trainer.accumulate_grad_batches = self.scheduling.get(self.epochs[i]) |
|
break |
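

# Usage sketch (the schedule values and stub trainer are hypothetical): with {1: 1, 5: 4},
# epochs 1-4 accumulate gradients over 1 batch and epoch 5 onwards over 4 batches.
# on_epoch_begin receives the zero-based epoch index and shifts it onto the 1-based keys.
def _example_accumulation_schedule():
    class _StubTrainer:
        accumulate_grad_batches = 1

    scheduler = GradientAccumulationScheduler({1: 1, 5: 4})
    trainer = _StubTrainer()
    scheduler.on_epoch_begin(epoch=4, trainer=trainer)  # zero-based epoch 4 == schedule epoch 5
    return trainer.accumulate_grad_batches  # -> 4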
|
|
|
|
|
class LatestModelCheckpoint(ModelCheckpoint): |
|
def __init__(self, filepath, monitor='val_loss', verbose=0, num_ckpt_keep=5, |
|
save_weights_only=False, mode='auto', period=1, prefix='model', save_best=True): |
|
super(ModelCheckpoint, self).__init__() |
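        # super(ModelCheckpoint, self) deliberately skips ModelCheckpoint.__init__
        # (invoking the next class in the MRO instead); every attribute is set manually below.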
|
self.monitor = monitor |
|
self.verbose = verbose |
|
self.filepath = filepath |
|
os.makedirs(filepath, exist_ok=True) |
|
self.num_ckpt_keep = num_ckpt_keep |
|
self.save_best = save_best |
|
self.save_weights_only = save_weights_only |
|
self.period = period |
|
self.epochs_since_last_check = 0 |
|
self.prefix = prefix |
|
self.best_k_models = {} |
|
|
|
self.kth_best_model = '' |
|
self.save_top_k = 1 |
|
self.task = None |
|
if mode == 'min': |
|
self.monitor_op = np.less |
|
self.best = np.Inf |
|
self.mode = 'min' |
|
elif mode == 'max': |
|
self.monitor_op = np.greater |
|
self.best = -np.Inf |
|
self.mode = 'max' |
|
else: |
|
if 'acc' in self.monitor or self.monitor.startswith('fmeasure'): |
|
self.monitor_op = np.greater |
|
self.best = -np.Inf |
|
self.mode = 'max' |
|
else: |
|
self.monitor_op = np.less |
|
self.best = np.Inf |
|
self.mode = 'min' |
|
if os.path.exists(f'{self.filepath}/best_valid.npy'): |
|
self.best = np.load(f'{self.filepath}/best_valid.npy')[0] |
|
|
|
def get_all_ckpts(self): |
|
return sorted(glob.glob(f'{self.filepath}/{self.prefix}_ckpt_steps_*.ckpt'), |
|
                      key=lambda x: -int(re.findall(r'.*steps_(\d+)\.ckpt', x)[0]))
|
|
|
def on_epoch_end(self, epoch, logs=None): |
|
logs = logs or {} |
|
self.epochs_since_last_check += 1 |
|
best_filepath = f'{self.filepath}/{self.prefix}_ckpt_best.pt' |
|
if self.epochs_since_last_check >= self.period: |
|
self.epochs_since_last_check = 0 |
|
filepath = f'{self.filepath}/{self.prefix}_ckpt_steps_{self.task.global_step}.ckpt' |
|
if self.verbose > 0: |
|
logging.info(f'Epoch {epoch:05d}@{self.task.global_step}: saving model to {filepath}') |
|
self._save_model(filepath) |
|
for old_ckpt in self.get_all_ckpts()[self.num_ckpt_keep:]: |
|
|
|
os.remove(old_ckpt) |
|
|
|
if self.verbose > 0: |
|
                    logging.info(f'Deleted ckpt: {os.path.basename(old_ckpt)}')
|
current = logs.get(self.monitor) |
|
if current is not None and self.save_best: |
|
if self.monitor_op(current, self.best): |
|
self.best = current |
|
if self.verbose > 0: |
|
logging.info( |
|
f'Epoch {epoch:05d}@{self.task.global_step}: {self.monitor} reached' |
|
f' {current:0.5f} (best {self.best:0.5f}), saving model to' |
|
f' {best_filepath} as top 1') |
|
self._save_model(best_filepath) |
|
np.save(f'{self.filepath}/best_valid.npy', [self.best]) |
|
|
|
    def _save_model(self, path):
|
return self.save_function(path) |
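

# Wiring sketch (the path and keyword values are hypothetical): the surrounding training
# code is expected to assign `save_function` (BaseTrainer.__init__ does this below) and
# `task` (anything exposing `global_step`) before on_epoch_end is called.
def _example_checkpoint_callback():
    ckpt = LatestModelCheckpoint(filepath='checkpoints/demo', monitor='val_loss',
                                 num_ckpt_keep=2, mode='min', prefix='model')
    ckpt.save_function = lambda path: torch.save({'state_dict': {}}, path)
    return ckpt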
|
|
|
|
|
|
|
class BaseTrainer: |
|
def __init__( |
|
self, |
|
logger=True, |
|
checkpoint_callback=True, |
|
default_save_path=None, |
|
gradient_clip_val=0, |
|
process_position=0, |
|
gpus=-1, |
|
log_gpu_memory=None, |
|
show_progress_bar=True, |
|
track_grad_norm=-1, |
|
check_val_every_n_epoch=1, |
|
accumulate_grad_batches=1, |
|
max_updates=1000, |
|
min_epochs=1, |
|
val_check_interval=1.0, |
|
log_save_interval=100, |
|
row_log_interval=10, |
|
print_nan_grads=False, |
|
weights_summary='full', |
|
num_sanity_val_steps=5, |
|
resume_from_checkpoint=None, |
|
): |
|
self.log_gpu_memory = log_gpu_memory |
|
self.gradient_clip_val = gradient_clip_val |
|
self.check_val_every_n_epoch = check_val_every_n_epoch |
|
self.track_grad_norm = track_grad_norm |
|
        self.on_gpu = bool(gpus) and torch.cuda.is_available()
|
self.process_position = process_position |
|
self.weights_summary = weights_summary |
|
self.max_updates = max_updates |
|
self.min_epochs = min_epochs |
|
self.num_sanity_val_steps = num_sanity_val_steps |
|
self.print_nan_grads = print_nan_grads |
|
self.resume_from_checkpoint = resume_from_checkpoint |
|
self.default_save_path = default_save_path |
|
|
|
|
|
self.total_batch_idx = 0 |
|
self.running_loss = [] |
|
self.avg_loss = 0 |
|
self.batch_idx = 0 |
|
self.tqdm_metrics = {} |
|
self.callback_metrics = {} |
|
self.num_val_batches = 0 |
|
self.num_training_batches = 0 |
|
self.num_test_batches = 0 |
|
self.get_train_dataloader = None |
|
self.get_test_dataloaders = None |
|
self.get_val_dataloaders = None |
|
self.is_iterable_train_dataloader = False |
|
|
|
|
|
self.model = None |
|
self.testing = False |
|
self.disable_validation = False |
|
self.lr_schedulers = [] |
|
self.optimizers = None |
|
self.global_step = 0 |
|
self.current_epoch = 0 |
|
self.total_batches = 0 |
|
|
|
|
|
self.checkpoint_callback = checkpoint_callback |
|
self.checkpoint_callback.save_function = self.save_checkpoint |
|
self.weights_save_path = self.checkpoint_callback.filepath |
|
|
|
|
|
self.configure_accumulated_gradients(accumulate_grad_batches) |
|
|
|
|
|
self.data_parallel_device_ids = [ |
|
int(x) for x in os.environ.get("CUDA_VISIBLE_DEVICES", "").split(",") if x != ''] |
|
if len(self.data_parallel_device_ids) == 0: |
|
self.root_gpu = None |
|
self.on_gpu = False |
|
else: |
|
self.root_gpu = self.data_parallel_device_ids[0] |
|
self.on_gpu = True |
|
|
|
|
|
self.use_ddp = False |
|
self.use_dp = False |
|
self.single_gpu = False |
|
self.distributed_backend = 'ddp' if self.num_gpus > 0 else 'dp' |
|
self.set_distributed_mode(self.distributed_backend) |
|
|
|
self.proc_rank = 0 |
|
self.world_size = 1 |
|
self.node_rank = 0 |
|
|
|
|
|
|
|
self.show_progress_bar = show_progress_bar |
|
|
|
|
|
self.log_save_interval = log_save_interval |
|
self.val_check_interval = val_check_interval |
|
self.logger = logger |
|
self.logger.rank = 0 |
|
self.row_log_interval = row_log_interval |
|
|
|
@property |
|
def num_gpus(self): |
|
gpus = self.data_parallel_device_ids |
|
if gpus is None: |
|
return 0 |
|
else: |
|
return len(gpus) |
|
|
|
@property |
|
def data_parallel(self): |
|
return self.use_dp or self.use_ddp |
|
|
|
def get_model(self): |
|
is_dp_module = isinstance(self.model, (DDP, DP)) |
|
model = self.model.module if is_dp_module else self.model |
|
return model |
|
|
|
|
|
|
|
|
|
def fit(self, model): |
|
if self.use_ddp: |
|
mp.spawn(self.ddp_train, nprocs=self.num_gpus, args=(model,)) |
|
else: |
|
model.model = model.build_model() |
|
if not self.testing: |
|
self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers()) |
|
if self.use_dp: |
|
model.cuda(self.root_gpu) |
|
model = DP(model, device_ids=self.data_parallel_device_ids) |
|
elif self.single_gpu: |
|
model.cuda(self.root_gpu) |
|
self.run_pretrain_routine(model) |
|
return 1 |
|
|
|
def init_optimizers(self, optimizers): |
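        # configure_optimizers() may return a single Optimizer, a plain list/tuple of
        # optimizers, or an ([optimizers], [lr_schedulers]) pair; all of these are
        # normalized to an (optimizers, lr_schedulers) pair here.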
|
|
|
|
|
if isinstance(optimizers, Optimizer): |
|
return [optimizers], [] |
|
|
|
|
|
elif len(optimizers) == 2 and isinstance(optimizers[0], list): |
|
optimizers, lr_schedulers = optimizers |
|
return optimizers, lr_schedulers |
|
|
|
|
|
elif isinstance(optimizers, list) or isinstance(optimizers, tuple): |
|
return optimizers, [] |
|
|
|
def run_pretrain_routine(self, model): |
|
"""Sanity check a few things before starting actual training. |
|
|
|
:param model: |
|
""" |
|
ref_model = model |
|
if self.data_parallel: |
|
ref_model = model.module |
|
|
|
|
|
ref_model.trainer = self |
|
|
|
|
|
self.copy_trainer_model_properties(ref_model) |
|
|
|
|
|
if self.logger is not None: |
|
ref_model.logger = self.logger |
|
self.logger.save() |
|
|
|
if self.use_ddp: |
|
dist.barrier() |
|
|
|
|
|
|
|
|
|
|
|
self.get_dataloaders(ref_model) |
|
|
|
|
|
|
|
self.model = model |
|
|
|
|
|
self.restore_weights(model) |
|
|
|
|
|
if self.testing: |
|
self.run_evaluation(test=True) |
|
return |
|
|
|
|
|
self.disable_validation = self.num_val_batches == 0 |
|
|
|
|
|
|
|
ref_model.on_sanity_check_start() |
|
ref_model.on_train_start() |
|
if not self.disable_validation and self.num_sanity_val_steps > 0: |
|
|
|
pbar = tqdm.tqdm(desc='Validation sanity check', |
|
total=self.num_sanity_val_steps * len(self.get_val_dataloaders()), |
|
leave=False, position=2 * self.process_position, |
|
disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch') |
|
self.main_progress_bar = pbar |
|
|
|
self.val_progress_bar = tqdm.tqdm(disable=True) |
|
|
|
self.evaluate(model, self.get_val_dataloaders(), self.num_sanity_val_steps, self.testing) |
|
|
|
|
|
self.main_progress_bar.close() |
|
self.val_progress_bar.close() |
|
|
|
|
|
pbar = tqdm.tqdm(leave=True, position=2 * self.process_position, |
|
disable=not self.show_progress_bar, dynamic_ncols=True, unit='batch', |
|
file=sys.stdout) |
|
self.main_progress_bar = pbar |
|
|
|
|
|
if self.on_gpu: |
|
torch.cuda.empty_cache() |
|
|
|
|
|
self.train() |
|
|
|
def test(self, model): |
|
self.testing = True |
|
self.fit(model) |
|
|
|
@property |
|
def training_tqdm_dict(self): |
|
tqdm_dict = { |
|
'step': '{}'.format(self.global_step), |
|
} |
|
tqdm_dict.update(self.tqdm_metrics) |
|
return tqdm_dict |
|
|
|
|
|
|
|
|
|
def restore_weights(self, model): |
|
""" |
|
To restore weights we have two cases. |
|
First, attempt to restore hpc weights. If successful, don't restore |
|
other weights. |
|
|
|
Otherwise, try to restore actual weights |
|
:param model: |
|
:return: |
|
""" |
|
|
|
if self.on_gpu: |
|
torch.cuda.empty_cache() |
|
|
|
if self.resume_from_checkpoint is not None: |
|
self.restore(self.resume_from_checkpoint, on_gpu=self.on_gpu) |
|
else: |
|
|
|
self.restore_state_if_checkpoint_exists(model) |
|
|
|
|
|
if self.use_ddp: |
|
|
|
dist.barrier() |
|
|
|
|
|
if self.on_gpu: |
|
torch.cuda.empty_cache() |
|
|
|
def restore_state_if_checkpoint_exists(self, model): |
|
did_restore = False |
|
|
|
|
|
no_ckpt_callback = (self.checkpoint_callback is None) or (not self.checkpoint_callback) |
|
if no_ckpt_callback or not os.path.exists(self.checkpoint_callback.filepath): |
|
return did_restore |
|
|
|
|
|
last_steps = -1 |
|
last_ckpt_name = None |
|
|
|
|
|
checkpoints = os.listdir(self.checkpoint_callback.filepath) |
|
for name in checkpoints: |
|
if '.ckpt' in name and not name.endswith('part'): |
|
if 'steps_' in name: |
|
steps = name.split('steps_')[1] |
|
steps = int(re.sub('[^0-9]', '', steps)) |
|
|
|
if steps > last_steps: |
|
last_steps = steps |
|
last_ckpt_name = name |
|
|
|
|
|
if last_ckpt_name is not None: |
|
last_ckpt_path = os.path.join(self.checkpoint_callback.filepath, last_ckpt_name) |
|
self.restore(last_ckpt_path, self.on_gpu) |
|
logging.info(f'model and trainer restored from checkpoint: {last_ckpt_path}') |
|
did_restore = True |
|
|
|
return did_restore |
|
|
|
def restore(self, checkpoint_path, on_gpu): |
|
checkpoint = torch.load(checkpoint_path, map_location='cpu') |
|
|
|
|
|
model = self.get_model() |
|
|
|
|
|
model.load_state_dict(checkpoint['state_dict'], strict=False) |
|
if on_gpu: |
|
model.cuda(self.root_gpu) |
|
|
|
self.restore_training_state(checkpoint) |
|
model.global_step = self.global_step |
|
del checkpoint |
|
|
|
try: |
|
if dist.is_initialized() and dist.get_rank() > 0: |
|
return |
|
except Exception as e: |
|
print(e) |
|
return |
|
|
|
def restore_training_state(self, checkpoint): |
|
""" |
|
Restore trainer state. |
|
Model will get its change to update |
|
:param checkpoint: |
|
:return: |
|
""" |
|
if self.checkpoint_callback is not None and self.checkpoint_callback is not False: |
|
|
|
self.checkpoint_callback.best = checkpoint['checkpoint_callback_best'] |
|
|
|
self.global_step = checkpoint['global_step'] |
|
self.current_epoch = checkpoint['epoch'] |
|
|
|
if self.testing: |
|
return |
|
|
|
|
|
optimizer_states = checkpoint['optimizer_states'] |
|
for optimizer, opt_state in zip(self.optimizers, optimizer_states): |
|
if optimizer is None: |
|
return |
|
optimizer.load_state_dict(opt_state) |
|
|
|
|
|
|
|
if self.root_gpu is not None: |
|
for state in optimizer.state.values(): |
|
for k, v in state.items(): |
|
if isinstance(v, torch.Tensor): |
|
state[k] = v.cuda(self.root_gpu) |
|
|
|
|
|
lr_schedulers = checkpoint['lr_schedulers'] |
|
for scheduler, lrs_state in zip(self.lr_schedulers, lr_schedulers): |
|
scheduler.load_state_dict(lrs_state) |
|
|
|
|
|
|
|
|
|
def _atomic_save(self, checkpoint, filepath): |
|
"""Saves a checkpoint atomically, avoiding the creation of incomplete checkpoints. |
|
|
|
        This will create a temporary checkpoint with a suffix of ``.part``, then move it into
        place with ``os.replace`` once saving is finished.
|
|
|
Args: |
|
checkpoint (object): The object to save. |
|
Built to be used with the ``dump_checkpoint`` method, but can deal with anything which ``torch.save`` |
|
accepts. |
|
filepath (str|pathlib.Path): The path to which the checkpoint will be saved. |
|
This points to the file that the checkpoint will be stored in. |
|
""" |
|
tmp_path = str(filepath) + ".part" |
|
torch.save(checkpoint, tmp_path) |
|
os.replace(tmp_path, filepath) |
|
|
|
def save_checkpoint(self, filepath): |
|
checkpoint = self.dump_checkpoint() |
|
self._atomic_save(checkpoint, filepath) |
|
|
|
def dump_checkpoint(self): |
|
|
|
checkpoint = { |
|
'epoch': self.current_epoch, |
|
'global_step': self.global_step |
|
} |
|
|
|
if self.checkpoint_callback is not None and self.checkpoint_callback is not False: |
|
checkpoint['checkpoint_callback_best'] = self.checkpoint_callback.best |
|
|
|
|
|
optimizer_states = [] |
|
for i, optimizer in enumerate(self.optimizers): |
|
if optimizer is not None: |
|
optimizer_states.append(optimizer.state_dict()) |
|
|
|
checkpoint['optimizer_states'] = optimizer_states |
|
|
|
|
|
lr_schedulers = [] |
|
for i, scheduler in enumerate(self.lr_schedulers): |
|
lr_schedulers.append(scheduler.state_dict()) |
|
|
|
checkpoint['lr_schedulers'] = lr_schedulers |
|
|
|
|
|
model = self.get_model() |
|
checkpoint['state_dict'] = model.state_dict() |
|
|
|
model.on_save_checkpoint(checkpoint) |
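        # the resulting checkpoint dict holds: 'epoch', 'global_step',
        # 'checkpoint_callback_best' (when a checkpoint callback is configured),
        # 'optimizer_states', 'lr_schedulers', 'state_dict', plus anything the model
        # adds in on_save_checkpoint.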
|
|
|
return checkpoint |
|
|
|
def copy_trainer_model_properties(self, model): |
|
if isinstance(model, DP): |
|
ref_model = model.module |
|
elif isinstance(model, DDP): |
|
ref_model = model.module |
|
else: |
|
ref_model = model |
|
|
|
for m in [model, ref_model]: |
|
m.trainer = self |
|
m.on_gpu = self.on_gpu |
|
m.use_dp = self.use_dp |
|
m.use_ddp = self.use_ddp |
|
m.testing = self.testing |
|
m.single_gpu = self.single_gpu |
|
|
|
def transfer_batch_to_gpu(self, batch, gpu_id): |
|
|
|
if callable(getattr(batch, 'cuda', None)): |
|
return batch.cuda(gpu_id, non_blocking=True) |
|
|
|
elif callable(getattr(batch, 'to', None)): |
|
return batch.to(torch.device('cuda', gpu_id), non_blocking=True) |
|
|
|
|
|
elif isinstance(batch, list): |
|
for i, x in enumerate(batch): |
|
batch[i] = self.transfer_batch_to_gpu(x, gpu_id) |
|
return batch |
|
|
|
|
|
elif isinstance(batch, tuple): |
|
batch = list(batch) |
|
for i, x in enumerate(batch): |
|
batch[i] = self.transfer_batch_to_gpu(x, gpu_id) |
|
return tuple(batch) |
|
|
|
|
|
elif isinstance(batch, dict): |
|
for k, v in batch.items(): |
|
batch[k] = self.transfer_batch_to_gpu(v, gpu_id) |
|
|
|
return batch |
|
|
|
|
|
return batch |
|
|
|
def set_distributed_mode(self, distributed_backend): |
|
|
|
if self.num_gpus == 0: |
|
return |
|
|
|
|
|
|
|
|
|
elif self.num_gpus == 1: |
|
self.single_gpu = True |
|
self.use_dp = False |
|
self.use_ddp = False |
|
self.root_gpu = 0 |
|
self.data_parallel_device_ids = [0] |
|
else: |
|
if distributed_backend is not None: |
|
self.use_dp = distributed_backend == 'dp' |
|
self.use_ddp = distributed_backend == 'ddp' |
|
            else:
|
self.use_dp = True |
|
self.use_ddp = False |
|
|
|
logging.info(f'gpu available: {torch.cuda.is_available()}, used: {self.on_gpu}') |
|
|
|
def ddp_train(self, gpu_idx, model): |
|
""" |
|
Entry point into a DP thread |
|
:param gpu_idx: |
|
:param model: |
|
:param cluster_obj: |
|
:return: |
|
""" |
|
|
|
self.node_rank = 0 |
|
|
|
|
|
self.show_progress_bar = self.show_progress_bar and self.node_rank == 0 and gpu_idx == 0 |
|
|
|
|
|
if self.use_ddp: |
|
self.proc_rank = self.node_rank * self.num_gpus + gpu_idx |
|
self.world_size = self.num_gpus |
|
|
|
|
|
if self.logger is not None: |
|
self.logger.rank = self.proc_rank |
|
|
|
|
|
|
|
|
|
model.trainer = self |
|
model.init_ddp_connection(self.proc_rank, self.world_size) |
|
|
|
|
|
|
|
model.model = model.build_model() |
|
if not self.testing: |
|
self.optimizers, self.lr_schedulers = self.init_optimizers(model.configure_optimizers()) |
|
|
|
|
|
|
|
if self.distributed_backend == 'ddp': |
|
torch.cuda.set_device(gpu_idx) |
|
model.cuda(gpu_idx) |
|
|
|
|
|
self.copy_trainer_model_properties(model) |
|
|
|
|
|
self.root_gpu = gpu_idx |
|
|
|
if self.distributed_backend == 'ddp': |
|
device_ids = [gpu_idx] |
|
else: |
|
device_ids = None |
|
|
|
|
|
model = model.configure_ddp(model, device_ids) |
|
|
|
|
|
self.run_pretrain_routine(model) |
|
|
|
def resolve_root_node_address(self, root_node): |
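        # e.g. a SLURM-style node list such as 'host[12-15,20]' (hypothetical value)
        # resolves to 'host12', the first node of the bracketed range.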
|
if '[' in root_node: |
|
name = root_node.split('[')[0] |
|
number = root_node.split(',')[0] |
|
if '-' in number: |
|
number = number.split('-')[0] |
|
|
|
number = re.sub('[^0-9]', '', number) |
|
root_node = name + number |
|
|
|
return root_node |
|
|
|
def log_metrics(self, metrics, grad_norm_dic, step=None): |
|
"""Logs the metric dict passed in. |
|
|
|
:param metrics: |
|
:param grad_norm_dic: |
|
""" |
|
|
|
metrics['epoch'] = self.current_epoch |
|
|
|
|
|
metrics.update(grad_norm_dic) |
|
|
|
|
|
scalar_metrics = self.metrics_to_scalars(metrics) |
|
|
|
step = step if step is not None else self.global_step |
|
|
|
if self.proc_rank == 0 and self.logger is not None: |
|
self.logger.log_metrics(scalar_metrics, step=step) |
|
self.logger.save() |
|
|
|
def add_tqdm_metrics(self, metrics): |
|
for k, v in metrics.items(): |
|
if type(v) is torch.Tensor: |
|
v = v.item() |
|
|
|
self.tqdm_metrics[k] = v |
|
|
|
def metrics_to_scalars(self, metrics): |
|
new_metrics = {} |
|
for k, v in metrics.items(): |
|
if isinstance(v, torch.Tensor): |
|
v = v.item() |
|
|
|
if type(v) is dict: |
|
v = self.metrics_to_scalars(v) |
|
|
|
new_metrics[k] = v |
|
|
|
return new_metrics |
|
|
|
def process_output(self, output, train=False): |
|
"""Reduces output according to the training mode. |
|
|
|
Separates loss from logging and tqdm metrics |
|
:param output: |
|
:return: |
|
""" |
|
|
|
|
|
|
|
|
|
callback_metrics = {} |
|
for k, v in output.items(): |
|
if k not in ['progress_bar', 'log', 'hiddens']: |
|
callback_metrics[k] = v |
|
|
|
if train and self.use_dp: |
|
num_gpus = self.num_gpus |
|
callback_metrics = self.reduce_distributed_output(callback_metrics, num_gpus) |
|
|
|
for k, v in callback_metrics.items(): |
|
if isinstance(v, torch.Tensor): |
|
callback_metrics[k] = v.item() |
|
|
|
|
|
|
|
|
|
try: |
|
progress_output = output['progress_bar'] |
|
|
|
|
|
if train and self.use_dp: |
|
num_gpus = self.num_gpus |
|
progress_output = self.reduce_distributed_output(progress_output, num_gpus) |
|
|
|
progress_bar_metrics = progress_output |
|
except Exception: |
|
progress_bar_metrics = {} |
|
|
|
|
|
|
|
|
|
|
|
try: |
|
log_output = output['log'] |
|
|
|
|
|
if train and self.use_dp: |
|
num_gpus = self.num_gpus |
|
log_output = self.reduce_distributed_output(log_output, num_gpus) |
|
|
|
log_metrics = log_output |
|
except Exception: |
|
log_metrics = {} |
|
|
|
|
|
|
|
|
|
|
|
|
|
loss = None |
|
if train: |
|
try: |
|
loss = output['loss'] |
|
except Exception: |
|
if type(output) is torch.Tensor: |
|
loss = output |
|
else: |
|
raise RuntimeError( |
|
'No `loss` value in the dictionary returned from `model.training_step()`.' |
|
) |
|
|
|
|
|
if self.use_dp: |
|
loss = self.reduce_distributed_output(loss, self.num_gpus) |
|
|
|
|
|
|
|
|
|
hiddens = output.get('hiddens') |
|
|
|
|
|
callback_metrics.update(progress_bar_metrics) |
|
callback_metrics.update(log_metrics) |
|
|
|
|
|
for k, v in callback_metrics.items(): |
|
if isinstance(v, torch.Tensor): |
|
callback_metrics[k] = v.item() |
|
|
|
return loss, progress_bar_metrics, log_metrics, callback_metrics, hiddens |
|
|
|
def reduce_distributed_output(self, output, num_gpus): |
|
if num_gpus <= 1: |
|
return output |
|
|
|
|
|
|
|
if type(output) is torch.Tensor: |
|
return output.mean() |
|
|
|
for k, v in output.items(): |
|
|
|
if isinstance(output[k], dict): |
|
output[k] = self.reduce_distributed_output(output[k], num_gpus) |
|
|
|
|
|
elif isinstance(output[k], torch.Tensor) and output[k].dim() == 0: |
|
pass |
|
|
|
|
|
elif output[k].size(0) == num_gpus: |
|
reduced = torch.mean(output[k]) |
|
output[k] = reduced |
|
return output |
|
|
|
def clip_gradients(self): |
|
if self.gradient_clip_val > 0: |
|
model = self.get_model() |
|
torch.nn.utils.clip_grad_norm_(model.parameters(), self.gradient_clip_val) |
|
|
|
def print_nan_gradients(self): |
|
model = self.get_model() |
|
for param in model.parameters(): |
|
if (param.grad is not None) and torch.isnan(param.grad.float()).any(): |
|
                logging.info(f'{param}: {param.grad}')
|
|
|
def configure_accumulated_gradients(self, accumulate_grad_batches): |
|
self.accumulate_grad_batches = None |
|
|
|
if isinstance(accumulate_grad_batches, dict): |
|
self.accumulation_scheduler = GradientAccumulationScheduler(accumulate_grad_batches) |
|
elif isinstance(accumulate_grad_batches, int): |
|
schedule = {1: accumulate_grad_batches} |
|
self.accumulation_scheduler = GradientAccumulationScheduler(schedule) |
|
else: |
|
raise TypeError("Gradient accumulation supports only int and dict types") |
|
|
|
def get_dataloaders(self, model): |
|
if not self.testing: |
|
self.init_train_dataloader(model) |
|
self.init_val_dataloader(model) |
|
else: |
|
self.init_test_dataloader(model) |
|
|
|
if self.use_ddp: |
|
dist.barrier() |
|
if not self.testing: |
|
self.get_train_dataloader() |
|
self.get_val_dataloaders() |
|
else: |
|
self.get_test_dataloaders() |
|
|
|
def init_train_dataloader(self, model): |
|
        self.first_epoch = True
|
self.get_train_dataloader = model.train_dataloader |
|
if isinstance(self.get_train_dataloader(), torch.utils.data.DataLoader): |
|
self.num_training_batches = len(self.get_train_dataloader()) |
|
self.num_training_batches = int(self.num_training_batches) |
|
else: |
|
self.num_training_batches = float('inf') |
|
self.is_iterable_train_dataloader = True |
|
if isinstance(self.val_check_interval, int): |
|
self.val_check_batch = self.val_check_interval |
|
else: |
|
self._percent_range_check('val_check_interval') |
|
self.val_check_batch = int(self.num_training_batches * self.val_check_interval) |
|
self.val_check_batch = max(1, self.val_check_batch) |
|
|
|
def init_val_dataloader(self, model): |
|
self.get_val_dataloaders = model.val_dataloader |
|
self.num_val_batches = 0 |
|
if self.get_val_dataloaders() is not None: |
|
if isinstance(self.get_val_dataloaders()[0], torch.utils.data.DataLoader): |
|
self.num_val_batches = sum(len(dataloader) for dataloader in self.get_val_dataloaders()) |
|
self.num_val_batches = int(self.num_val_batches) |
|
else: |
|
self.num_val_batches = float('inf') |
|
|
|
def init_test_dataloader(self, model): |
|
self.get_test_dataloaders = model.test_dataloader |
|
if self.get_test_dataloaders() is not None: |
|
if isinstance(self.get_test_dataloaders()[0], torch.utils.data.DataLoader): |
|
self.num_test_batches = sum(len(dataloader) for dataloader in self.get_test_dataloaders()) |
|
self.num_test_batches = int(self.num_test_batches) |
|
else: |
|
self.num_test_batches = float('inf') |
|
|
|
def evaluate(self, model, dataloaders, max_batches, test=False): |
|
"""Run evaluation code. |
|
|
|
        :param model: PyTorch model to evaluate
        :param dataloaders: list of PyTorch dataloaders
        :param max_batches: maximum number of batches to evaluate from each dataloader
        :param test: run the test loop instead of the validation loop
        :return: result of ``model.test_end`` / ``model.validation_end``
        """
|
|
|
model.zero_grad() |
|
model.eval() |
|
|
|
|
|
self.copy_trainer_model_properties(model) |
|
|
|
|
|
torch.set_grad_enabled(False) |
|
|
|
if test: |
|
self.get_model().test_start() |
|
|
|
outputs = [] |
|
|
|
|
|
for dataloader_idx, dataloader in enumerate(dataloaders): |
|
dl_outputs = [] |
|
for batch_idx, batch in enumerate(dataloader): |
|
|
|
if batch is None: |
|
continue |
|
|
|
|
|
if batch_idx >= max_batches: |
|
break |
|
|
|
|
|
|
|
|
|
output = self.evaluation_forward(model, |
|
batch, |
|
batch_idx, |
|
dataloader_idx, |
|
test) |
|
|
|
|
|
dl_outputs.append(output) |
|
|
|
|
|
if test: |
|
self.test_progress_bar.update(1) |
|
else: |
|
self.val_progress_bar.update(1) |
|
outputs.append(dl_outputs) |
|
|
|
|
|
if len(dataloaders) == 1: |
|
outputs = outputs[0] |
|
|
|
|
|
model = self.get_model() |
|
        if test:
            eval_results = model.test_end(outputs)
        else:
            eval_results = model.validation_end(outputs)
|
|
|
|
|
model.train() |
|
|
|
|
|
torch.set_grad_enabled(True) |
|
|
|
return eval_results |
|
|
|
def run_evaluation(self, test=False): |
|
|
|
model = self.get_model() |
|
model.on_pre_performance_check() |
|
|
|
|
|
if test: |
|
dataloaders = self.get_test_dataloaders() |
|
max_batches = self.num_test_batches |
|
else: |
|
|
|
dataloaders = self.get_val_dataloaders() |
|
max_batches = self.num_val_batches |
|
|
|
|
|
|
|
position = 2 * self.process_position + (not test) |
|
desc = 'Testing' if test else 'Validating' |
|
pbar = tqdm.tqdm(desc=desc, total=max_batches, leave=test, position=position, |
|
disable=not self.show_progress_bar, dynamic_ncols=True, |
|
unit='batch', file=sys.stdout) |
|
setattr(self, f'{"test" if test else "val"}_progress_bar', pbar) |
|
|
|
|
|
eval_results = self.evaluate(self.model, |
|
dataloaders, |
|
max_batches, |
|
test) |
|
if eval_results is not None: |
|
_, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output( |
|
eval_results) |
|
|
|
|
|
self.add_tqdm_metrics(prog_bar_metrics) |
|
|
|
|
|
self.log_metrics(log_metrics, {}) |
|
|
|
|
|
self.callback_metrics.update(callback_metrics) |
|
|
|
|
|
model.on_post_performance_check() |
|
|
|
|
|
tqdm_metrics = self.training_tqdm_dict |
|
if not test: |
|
self.main_progress_bar.set_postfix(**tqdm_metrics) |
|
|
|
|
|
if test: |
|
self.test_progress_bar.close() |
|
else: |
|
self.val_progress_bar.close() |
|
|
|
|
|
if self.proc_rank == 0 and self.checkpoint_callback is not None and not test: |
|
self.checkpoint_callback.on_epoch_end(epoch=self.current_epoch, |
|
logs=self.callback_metrics) |
|
|
|
def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test=False): |
|
|
|
args = [batch, batch_idx] |
|
|
|
if test and len(self.get_test_dataloaders()) > 1: |
|
args.append(dataloader_idx) |
|
|
|
elif not test and len(self.get_val_dataloaders()) > 1: |
|
args.append(dataloader_idx) |
|
|
|
|
|
if self.use_ddp or self.use_dp: |
|
output = model(*args) |
|
return output |
|
|
|
|
|
if self.single_gpu: |
|
|
|
root_gpu = 0 |
|
if isinstance(self.data_parallel_device_ids, list): |
|
root_gpu = self.data_parallel_device_ids[0] |
|
batch = self.transfer_batch_to_gpu(batch, root_gpu) |
|
args[0] = batch |
|
|
|
|
|
if test: |
|
output = model.test_step(*args) |
|
else: |
|
output = model.validation_step(*args) |
|
|
|
return output |
|
|
|
def train(self): |
|
model = self.get_model() |
|
|
|
for epoch in range(self.current_epoch, 1000000): |
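            # the upper bound is effectively infinite: training stops once global_step
            # exceeds max_updates (run_training_epoch calls exit() at that point).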
|
|
|
if self.use_ddp and hasattr(self.get_train_dataloader().sampler, 'set_epoch'): |
|
self.get_train_dataloader().sampler.set_epoch(epoch) |
|
|
|
|
|
model = self.get_model() |
|
|
|
|
|
model.current_epoch = epoch |
|
self.current_epoch = epoch |
|
|
|
total_val_batches = 0 |
|
if not self.disable_validation: |
|
|
|
is_val_epoch = (self.current_epoch + 1) % self.check_val_every_n_epoch == 0 |
|
val_checks_per_epoch = self.num_training_batches // self.val_check_batch |
|
val_checks_per_epoch = val_checks_per_epoch if is_val_epoch else 0 |
|
total_val_batches = self.num_val_batches * val_checks_per_epoch |
|
|
|
|
|
self.total_batches = self.num_training_batches + total_val_batches |
|
self.batch_loss_value = 0 |
|
|
|
if self.is_iterable_train_dataloader: |
|
|
|
num_iterations = None |
|
else: |
|
num_iterations = self.total_batches |
|
|
|
|
|
|
|
desc = f'Epoch {epoch + 1}' if not self.is_iterable_train_dataloader else '' |
|
self.main_progress_bar.set_description(desc) |
|
|
|
|
|
self.accumulation_scheduler.on_epoch_begin(epoch, self) |
|
|
|
|
|
|
|
|
|
self.run_training_epoch() |
|
|
|
|
|
if self.lr_schedulers is not None: |
|
for lr_scheduler in self.lr_schedulers: |
|
lr_scheduler.step(epoch=self.current_epoch) |
|
|
|
self.main_progress_bar.close() |
|
|
|
model.on_train_end() |
|
|
|
if self.logger is not None: |
|
self.logger.finalize("success") |
|
|
|
def run_training_epoch(self): |
|
|
|
if self.is_function_implemented('on_epoch_start'): |
|
model = self.get_model() |
|
model.on_epoch_start() |
|
|
|
|
|
for batch_idx, batch in enumerate(self.get_train_dataloader()): |
|
|
|
if batch_idx >= self.num_training_batches: |
|
break |
|
|
|
self.batch_idx = batch_idx |
|
|
|
model = self.get_model() |
|
model.global_step = self.global_step |
|
|
|
|
|
|
|
|
|
output = self.run_training_batch(batch, batch_idx) |
|
batch_result, grad_norm_dic, batch_step_metrics = output |
|
|
|
|
|
early_stop_epoch = batch_result == -1 |
|
|
|
|
|
|
|
|
|
            should_check_val = (
                not self.disable_validation and self.global_step % self.val_check_batch == 0
                and not self.first_epoch)
            self.first_epoch = False
|
|
|
if should_check_val: |
|
self.run_evaluation(test=self.testing) |
|
|
|
|
|
should_save_log = (batch_idx + 1) % self.log_save_interval == 0 or early_stop_epoch |
|
if should_save_log: |
|
if self.proc_rank == 0 and self.logger is not None: |
|
self.logger.save() |
|
|
|
|
|
should_log_metrics = batch_idx % self.row_log_interval == 0 or early_stop_epoch |
|
if should_log_metrics: |
|
|
|
self.log_metrics(batch_step_metrics, grad_norm_dic) |
|
|
|
self.global_step += 1 |
|
self.total_batch_idx += 1 |
|
|
|
|
|
|
|
|
|
if early_stop_epoch: |
|
break |
|
if self.global_step > self.max_updates: |
|
print("| Training end..") |
|
exit() |
|
|
|
|
|
if self.is_function_implemented('on_epoch_end'): |
|
model = self.get_model() |
|
model.on_epoch_end() |
|
|
|
def run_training_batch(self, batch, batch_idx): |
|
|
|
grad_norm_dic = {} |
|
|
|
|
|
all_callback_metrics = [] |
|
|
|
|
|
all_log_metrics = [] |
|
|
|
if batch is None: |
|
return 0, grad_norm_dic, {} |
|
|
|
|
|
if self.is_function_implemented('on_batch_start'): |
|
model_ref = self.get_model() |
|
response = model_ref.on_batch_start(batch) |
|
|
|
if response == -1: |
|
return -1, grad_norm_dic, {} |
|
|
|
splits = [batch] |
|
self.hiddens = None |
|
for split_idx, split_batch in enumerate(splits): |
|
self.split_idx = split_idx |
|
|
|
|
|
for opt_idx, optimizer in enumerate(self.optimizers): |
|
if optimizer is None: |
|
continue |
|
|
|
|
|
if len(self.optimizers) > 1: |
|
for param in self.get_model().parameters(): |
|
param.requires_grad = False |
|
for group in optimizer.param_groups: |
|
for param in group['params']: |
|
param.requires_grad = True |
|
|
|
|
|
def optimizer_closure(): |
|
|
|
output = self.training_forward( |
|
split_batch, batch_idx, opt_idx, self.hiddens) |
|
|
|
closure_loss = output[0] |
|
progress_bar_metrics = output[1] |
|
log_metrics = output[2] |
|
callback_metrics = output[3] |
|
self.hiddens = output[4] |
|
if closure_loss is None: |
|
return None |
|
|
|
|
|
|
|
closure_loss = closure_loss / self.accumulate_grad_batches |
|
|
|
|
|
model_ref = self.get_model() |
|
if closure_loss.requires_grad: |
|
model_ref.backward(closure_loss, optimizer) |
|
|
|
|
|
all_callback_metrics.append(callback_metrics) |
|
|
|
|
|
self.add_tqdm_metrics(progress_bar_metrics) |
|
all_log_metrics.append(log_metrics) |
|
|
|
|
|
if self.is_function_implemented('on_after_backward'): |
|
model_ref = self.get_model() |
|
model_ref.on_after_backward() |
|
|
|
return closure_loss |
|
|
|
|
|
loss = optimizer_closure() |
|
if loss is None: |
|
continue |
|
|
|
|
|
if self.print_nan_grads: |
|
self.print_nan_gradients() |
|
|
|
|
|
self.batch_loss_value += loss.item() |
|
|
|
|
|
if (self.batch_idx + 1) % self.accumulate_grad_batches == 0: |
|
|
|
|
|
if batch_idx % self.row_log_interval == 0: |
|
if self.track_grad_norm > 0: |
|
model = self.get_model() |
|
grad_norm_dic = model.grad_norm( |
|
self.track_grad_norm) |
|
|
|
|
|
self.clip_gradients() |
|
|
|
|
|
|
|
model = self.get_model() |
|
model.optimizer_step(self.current_epoch, batch_idx, optimizer, opt_idx) |
|
|
|
|
|
self.running_loss.append(self.batch_loss_value) |
|
self.batch_loss_value = 0 |
|
self.avg_loss = np.mean(self.running_loss[-100:]) |
|
|
|
|
|
if self.is_function_implemented('on_batch_end'): |
|
model = self.get_model() |
|
model.on_batch_end() |
|
|
|
|
|
self.main_progress_bar.update(1) |
|
self.main_progress_bar.set_postfix(**self.training_tqdm_dict) |
|
|
|
|
|
all_log_metrics = {k: v for d in all_log_metrics for k, v in d.items()} |
|
|
|
|
|
self.callback_metrics.update({k: v for d in all_callback_metrics for k, v in d.items()}) |
|
|
|
return 0, grad_norm_dic, all_log_metrics |
|
|
|
def training_forward(self, batch, batch_idx, opt_idx, hiddens): |
|
""" |
|
Handle forward for each training case (distributed, single gpu, etc...) |
|
:param batch: |
|
:param batch_idx: |
|
:return: |
|
""" |
|
|
|
|
|
|
|
|
|
args = [batch, batch_idx, opt_idx] |
|
|
|
|
|
if self.use_ddp or self.use_dp: |
|
output = self.model(*args) |
|
|
|
elif self.single_gpu: |
|
gpu_id = 0 |
|
if isinstance(self.data_parallel_device_ids, list): |
|
gpu_id = self.data_parallel_device_ids[0] |
|
batch = self.transfer_batch_to_gpu(copy.copy(batch), gpu_id) |
|
args[0] = batch |
|
output = self.model.training_step(*args) |
|
|
|
else: |
|
output = self.model.training_step(*args) |
|
|
|
|
|
model_ref = self.get_model() |
|
output_ = model_ref.training_end(output) |
|
if output_ is not None: |
|
output = output_ |
|
|
|
|
|
output = self.process_output(output, train=True) |
|
|
|
return output |
|
|
|
|
|
|
|
|
|
def is_function_implemented(self, f_name): |
|
model = self.get_model() |
|
f_op = getattr(model, f_name, None) |
|
return callable(f_op) |
|
|
|
def _percent_range_check(self, name): |
|
value = getattr(self, name) |
|
msg = f"`{name}` must lie in the range [0.0, 1.0], but got {value:.3f}." |
|
if name == "val_check_interval": |
|
msg += " If you want to disable validation set `val_percent_check` to 0.0 instead." |
|
|
|
if not 0. <= value <= 1.: |
|
raise ValueError(msg) |
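

# End-to-end wiring sketch (some_logger, task, the paths and the numbers below are all
# hypothetical placeholders; the tasks that drive this trainer are defined elsewhere and
# implement build_model/configure_optimizers, the *_dataloader methods, the
# training_step/validation_step/validation_end hooks, optimizer_step, backward, etc.):
#
#     checkpoint = LatestModelCheckpoint(filepath='checkpoints/demo', monitor='val_loss')
#     trainer = BaseTrainer(logger=some_logger,          # needs .rank, .save(), .log_metrics(), .finalize()
#                           checkpoint_callback=checkpoint,
#                           gradient_clip_val=1.0,
#                           max_updates=200000,
#                           accumulate_grad_batches=1)
#     trainer.fit(task)    # spawns one DDP process per GPU when more than one GPU is visible
#     trainer.test(task)   # sets testing=True and re-enters fit() to run the test loop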
|
|