# akhaliq's picture
# akhaliq HF staff
# add files
# c80917c
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import time
import os
from collections import defaultdict
import captioning.utils.opts as opts
import captioning.models as models
from captioning.data.pth_loader import CaptionDataset
import captioning.utils.eval_utils as eval_utils
import captioning.utils.misc as utils
from captioning.utils.rewards import init_scorer, get_self_critical_reward
from captioning.modules.loss_wrapper import LossWrapper
import pytorch_lightning as pl
import detectron2.utils.comm as d2comm
from detectron2.utils.env import seed_all_rng
# Seed the global RNGs once, at import time, with a fixed seed (1234) for
# reproducibility. Presumably this seeds Python/NumPy/torch RNGs via
# detectron2 — confirm against detectron2.utils.env.seed_all_rng.
seed_all_rng(1234)
class LitModel(pl.LightningModule):
    """PyTorch Lightning harness around a captioning model and its loss.

    Owns the caption dataset, the captioning model (``self.model``), the loss
    wrapper (``self.lw_model``), a frozen CLIPScore model used for validation
    metrics (``self.val_clipscore_model``) and a BERTScore scorer.

    ``self.sc_flag`` / ``self.struc_flag`` start as ``None`` and are expected
    to be set externally before training steps run (e.g. by
    ``OnEpochStartCallback``).
    """

    def __init__(self, opt):
        """Build dataset, model, loss wrapper and evaluation scorers.

        Args:
            opt: parsed option namespace. ``vocab_size`` and ``seq_length``
                are written back onto it from the dataset.
        """
        super().__init__()
        self.opt = opt

        # Initialize dataset
        self.dataset = CaptionDataset(opt)
        opt.vocab_size = self.dataset.vocab_size
        opt.seq_length = self.dataset.seq_length
        self.batch_size = opt.batch_size

        # Build model. The vocab is attached to opt only for models.setup and
        # removed again afterwards.
        opt.vocab = self.dataset.get_vocab()
        model = models.setup(opt)
        del opt.vocab

        # Wrapper that computes the training loss.
        lw_model = LossWrapper(model, opt)

        self.model = model
        self.lw_model = lw_model

        # Training-stage flags (self-critical / structure loss).
        self.struc_flag = None
        self.sc_flag = None

        # Validation CLIPScore model. If CLIP-S+Grammar is used in the reward,
        # launch a separate CLIP-S whose parameters stay unchanged; otherwise
        # reuse the loss wrapper's CLIPScore model when it has one.
        if getattr(self.opt, 'use_grammar', False) or self.lw_model.clipscore_model is None:
            from captioning.utils.clipscore import CLIPScore
            self.val_clipscore_model = CLIPScore(
                mode=opt.clipscore_mode, use_grammar=False)
            for p in self.val_clipscore_model.parameters():
                p.requires_grad = False
        else:
            self.val_clipscore_model = self.lw_model.clipscore_model
        self.val_clipscore_model.eval()

        # BERTScore scorer. It is created on CPU and moved to this module's
        # device lazily in validation_step.
        from bert_score import BERTScorer
        self.bert_scorer = BERTScorer(
            lang="en",
            rescale_with_baseline=False,
            device='cpu'
        )

    def forward(self, *args, **kwargs):
        """Not supported: call ``self.model`` / ``self.lw_model`` directly.

        This LightningModule is only a training harness and is not meant to
        be used like an ``nn.Module``.
        """
        raise NotImplementedError

    def train_dataloader(self):
        """Return a shuffled DataLoader over the dataset's ``'train'`` split."""
        train_dataset = torch.utils.data.Subset(
            self.dataset,
            self.dataset.split_ix['train']
        )

        train_loader = torch.utils.data.DataLoader(
            dataset=train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=4,
            collate_fn=self.dataset.collate_func
        )
        return train_loader

    def val_dataloader(self, split='val'):
        """Return an unshuffled DataLoader over the given dataset split.

        Args:
            split: key into ``self.dataset.split_ix`` (``'val'`` by default).
        """
        val_dataset = torch.utils.data.Subset(
            self.dataset,
            self.dataset.split_ix[split]
        )
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=4,
            drop_last=False,
            collate_fn=self.dataset.collate_func
        )
        return val_loader

    def test_dataloader(self):
        """Return the DataLoader for the ``'test'`` split."""
        return self.val_dataloader('test')

    def training_step(self, data, batch_idx):
        """Compute the training loss for one batch and log its metrics.

        Every entry of the loss wrapper's output dict is logged under
        ``train/<key>``. Reward-related keys and ``data_time`` are also shown
        on the progress bar.

        Returns:
            ``model_out['loss']`` from the loss wrapper.
        """
        sc_flag, struc_flag = self.sc_flag, self.struc_flag

        fc_feats = data['fc_feats']
        att_feats = data['att_feats']
        labels = data['labels']
        masks = data['masks']
        att_masks = data['att_masks']

        # Optionally score against the raw (un-tokenized) references.
        if int(os.getenv('M2_cider', '0')) != 0:
            data['gts'] = data['rawgts']

        extra_kwargs = {}
        if self.opt.use_clipscore:
            extra_kwargs['clip_vis_feats'] = data['clip_vis_feats']
        model_out = self.lw_model(fc_feats, att_feats, labels, masks, att_masks,
                                  data['gts'], torch.arange(0, len(data['gts'])),
                                  sc_flag, struc_flag, **extra_kwargs)

        loss = model_out['loss']

        # Batch-fetch time as recorded by the Lightning profiler.
        data_time = self.trainer.profiler.recorded_durations["get_train_batch"][-1]
        data_time = torch.tensor(data_time)

        # The copy already carries any 'CLIP-S' / 'RefCLIP-S' / 'CIDEr' /
        # 'grammar_reward' entries returned by the loss wrapper.
        logger_logs = model_out.copy()
        if struc_flag or sc_flag:
            logger_logs['reward'] = model_out['reward'].mean()
        if struc_flag:
            logger_logs['reward_var'] = model_out['reward'].var(1).mean()

        logger_logs['scheduled_sampling_prob'] = torch.tensor(
            self.model.ss_prob)
        logger_logs['loss'] = loss
        logger_logs['data_time'] = data_time

        prog_bar_keys = {'reward', 'reward_var', 'data_time',
                         'CLIP-S', 'RefCLIP-S', 'CIDEr', 'grammar_reward'}
        for k, v in logger_logs.items():
            if k in prog_bar_keys:
                self.log('train/' + k, v, prog_bar=True)
            else:
                self.log('train/' + k, v)

        return loss

    def validation_step(self, data, batch_idx):
        """Compute the loss, generate captions and score them for one batch.

        Returns:
            dict with
            - ``'loss'``: the teacher-forced loss. It is ``torch.tensor(0)``
              when the batch has no labels or ``verbose_loss`` is off.
            - ``'predictions'``: one entry dict per image.
            - ``'n_predictions'``: extra samples when ``sample_n > 1``.
        """
        model = self.model
        crit = self.lw_model.crit

        opt = self.opt
        eval_kwargs = {'dataset': opt.input_json}
        eval_kwargs.update(vars(opt))

        # CLIPScore / grammar options
        use_grammar = getattr(self.opt, 'use_grammar', False)
        joint_out = getattr(self.opt, 'joint_out', False)

        verbose = eval_kwargs.get('verbose', True)
        verbose_beam = eval_kwargs.get('verbose_beam', 0)
        verbose_loss = eval_kwargs.get('verbose_loss', 1)
        beam_size = eval_kwargs.get('beam_size', 1)
        sample_n = eval_kwargs.get('sample_n', 1)
        remove_bad_endings = eval_kwargs.get('remove_bad_endings', 0)
        # Use this nasty way to make other code clean since it's a global configuration
        os.environ["REMOVE_BAD_ENDINGS"] = str(remove_bad_endings)

        predictions = []
        n_predictions = []

        # The features are needed for sampling even when the loss is skipped,
        # so they are unpacked unconditionally. Binding them only inside the
        # loss branch raised NameError for label-less batches or verbose_loss=0.
        fc_feats = data['fc_feats']
        att_feats = data['att_feats']
        att_masks = data['att_masks']

        loss = torch.tensor(0)
        if data.get('labels', None) is not None and verbose_loss:
            # forward the model to get loss
            labels, masks = data['labels'], data['masks']
            loss = crit(model(fc_feats, att_feats,
                              labels[..., :-1], att_masks), labels[..., 1:], masks[..., 1:])

        # Forward the model to also get generated samples for each image.
        # Only leave one feature for each image, in case duplicate sample.
        tmp_eval_kwargs = eval_kwargs.copy()
        tmp_eval_kwargs.update({'sample_n': 1})
        seq, seq_logprobs = model(
            fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
        seq = seq.data
        # Length normaliser: count of token ids > 0 along dim 1, plus one.
        n_tokens = (seq > 0).to(seq_logprobs).sum(1) + 1
        entropy = - (F.softmax(seq_logprobs, dim=2) *
                     seq_logprobs).sum(2).sum(1) / n_tokens
        perplexity = - seq_logprobs.gather(2, seq.unsqueeze(2)).squeeze(2).sum(1) / n_tokens

        # Print beam search
        if beam_size > 1 and verbose_beam:
            for i in range(fc_feats.shape[0]):
                print('\n'.join([utils.decode_sequence(model.vocab, _[
                    'seq'].unsqueeze(0))[0] for _ in model.done_beams[i]]))
                print('--' * 10)
        sents = utils.decode_sequence(model.vocab, seq)

        # CLIP-S: text features are projected with CLIP's text_projection and
        # L2-normalised before scoring against the image features.
        text_feat = self.val_clipscore_model.text_extract(sents, proj_norm=False)
        text_cont_feat = self.val_clipscore_model.clip_model.text_projection(text_feat)
        text_cont_feat = text_cont_feat / text_cont_feat.norm(dim=-1, keepdim=True)
        vis_feat = data['clip_vis_feats']
        clip_s = self.val_clipscore_model(text_feat=text_cont_feat, img_feat=vis_feat, mode='clip_s')

        # RefCLIP-S: decode each image's references and pad them with '' to
        # the same count, with a 1/0 mask marking the real ones. BERTScore gets
        # the unpadded, non-empty references per image.
        data_gts = list(data['gts'])
        B = len(data_gts)
        max_n_refs = max([len(_gts) for _gts in data_gts])
        ref_text = []
        ref_text_mask = []
        ref_text_per_example = []
        for refs in data_gts:
            decoded = utils.decode_sequence(model.vocab, refs)
            n_ref = len(decoded)
            n_pad = max_n_refs - n_ref
            ref_text.extend(decoded + [''] * n_pad)
            ref_text_mask.extend([1] * n_ref + [0] * n_pad)
            ref_text_per_example.append([ref for ref in decoded if len(ref) > 0])
        assert len(ref_text) == B * max_n_refs
        assert len(ref_text_mask) == B * max_n_refs

        refclip_s = self.val_clipscore_model(
            text_feat=text_cont_feat, img_feat=vis_feat,
            ref_text=ref_text, ref_text_mask=ref_text_mask, mode='refclip_s')

        if use_grammar and not joint_out:
            with torch.no_grad():
                grammar_logit = self.lw_model.clipscore_model.grammar_score_head(text_feat.view(-1, 512))
                grammar_prob = torch.softmax(grammar_logit, dim=-1)[:, 1]

        # BERTScore: move the scorer's model to this module's device on first use.
        if next(self.bert_scorer._model.parameters()).device != self.device:
            self.bert_scorer._model.to(self.device)
            self.bert_scorer.device = self.device
        # score() returns (P, R, F1); only F1 is reported.
        _, _, bertscore_f1 = self.bert_scorer.score(
            sents,
            ref_text_per_example,
        )

        report_clip = self.opt.use_clipscore or os.getenv('EVALUATE', '0') == '1'
        for k, sent in enumerate(sents):
            entry = {'image_id': data['infos'][k]['id'], 'caption': sent,
                     'perplexity': perplexity[k].item(), 'entropy': entropy[k].item()}
            if report_clip:
                entry['CLIP-S'] = clip_s[k].item()
                entry['RefCLIP-S'] = refclip_s[k].item()
                if use_grammar and not joint_out:
                    entry['grammar_prob'] = grammar_prob[k].item()
                # BERT-S
                entry['BERT-S'] = bertscore_f1[k].item()

            if eval_kwargs.get('dump_path', 0) == 1:
                entry['file_name'] = data['infos'][k]['file_path']
            predictions.append(entry)
            if eval_kwargs.get('dump_images', 0) == 1:
                # Dump the raw image to the vis/ folder.
                # NOTE(review): shell command built from dataset paths; only
                # safe for trusted image_root / file_path values.
                cmd = 'cp "' + os.path.join(eval_kwargs['image_root'], data['infos'][k]['file_path']) + \
                    '" vis/imgs/img' + \
                    str(len(predictions)) + '.jpg'  # bit gross
                print(cmd)
                os.system(cmd)

            if verbose:
                print('image %s: %s' %
                      (entry['image_id'], entry['caption']))

        if sample_n > 1:
            eval_utils.eval_split_n(model, n_predictions, [
                fc_feats, att_feats, att_masks, data], eval_kwargs)

        output = {
            'loss': loss,
            'predictions': predictions,
            'n_predictions': n_predictions,
        }
        return output

    def test_step(self, *args, **kwargs):
        """Test steps are identical to validation steps."""
        return self.validation_step(*args, **kwargs)

    def validation_epoch_end(self, outputs, split='val'):
        """Aggregate step outputs on the main process and log the metrics.

        Step outputs are gathered across processes. The main process then:
        - averages the loss;
        - saves the predictions under ``eval_results/``;
        - optionally runs language evaluation and steps a reduce-on-plateau
          scheduler;
        - averages the per-caption CLIP-S / RefCLIP-S / grammar / BERT-S scores.

        The main process's dict is shared with all processes, and every value
        is logged as ``{split}/{key}``.
        """
        outputs = d2comm.gather(outputs)
        # master node
        if d2comm.is_main_process():
            assert self.trainer.node_rank == 0 and self.trainer.local_rank == 0
            outputs = sum(outputs, [])

            opt = self.opt
            val_loss_mean = sum([_['loss'].cpu()
                                 for _ in outputs]) / len(outputs)

            predictions = sum([_['predictions'] for _ in outputs], [])
            if len(outputs[0]['n_predictions']) != 0:
                n_predictions = sum([_['n_predictions'] for _ in outputs], [])
            else:
                n_predictions = []

            lang_stats = None
            if len(n_predictions) > 0 and 'perplexity' in n_predictions[0]:
                n_predictions = sorted(
                    n_predictions, key=lambda x: x['perplexity'])

            os.makedirs('eval_results', exist_ok=True)
            torch.save((predictions, n_predictions), os.path.join(
                'eval_results/', '.saved_pred_' + opt.id + '_' + split + '.pth'))

            if opt.language_eval:
                lang_stats = eval_utils.language_eval(
                    opt.input_json, predictions, n_predictions, vars(opt), split)

            # lang_stats stays None when language_eval is off; fall back to
            # the loss instead of failing on `'CIDEr' in None`.
            if opt.reduce_on_plateau:
                optimizer = self.trainer.optimizers[0]
                if lang_stats is not None and 'CIDEr' in lang_stats:
                    optimizer.scheduler_step(-lang_stats['CIDEr'])
                else:
                    optimizer.scheduler_step(val_loss_mean)

            out = {
                'loss': val_loss_mean
            }
            if lang_stats is not None:
                out.update(lang_stats)

            if self.opt.use_clipscore or os.getenv('EVALUATE', '0') == '1':
                out['CLIP-S'] = sum([p['CLIP-S'] for p in predictions]) / len(predictions)
                print('CLIP-S', out['CLIP-S'])
                out['RefCLIP-S'] = sum([p['RefCLIP-S'] for p in predictions]) / len(predictions)
                print('RefCLIP-S', out['RefCLIP-S'])

                if getattr(self.opt, 'use_grammar', False) and not getattr(self.opt, 'joint_out', False):
                    out['grammar_prob'] = sum([p['grammar_prob'] for p in predictions]) / len(predictions)
                    print('grammar_prob', out['grammar_prob'])

                out['BERT-S'] = sum([p['BERT-S'] for p in predictions]) / len(predictions)
                print('BERT-S', out['BERT-S'])
        else:
            out = {}

        out = d2comm.all_gather(out)[0]  # Only the one from master node
        assert len(out) > 0  # make sure the head has index 0

        # must all be tensors
        out = {k: torch.tensor(v) if not torch.is_tensor(
            v) else v for k, v in out.items()}

        for k, v in out.items():
            self.log(f'{split}/{k}', v)

    def test_epoch_end(self, outputs):
        """Aggregate and log test metrics under the ``test/`` prefix."""
        self.validation_epoch_end(outputs, 'test')

    def configure_optimizers(self):
        """Build the optimizer selected by ``opt``.

        The choice is one of: a Noam optimizer, a reduce-on-plateau wrapper,
        or the plain optimizer from ``utils.build_optimizer``. Only parameters
        with ``requires_grad`` are passed to the latter two.
        """
        opt = self.opt
        model = self.model

        parameters = [p for p in model.parameters() if p.requires_grad]

        if opt.noamopt:
            optimizer = utils.get_std_opt(
                model, optim_func=opt.optim, factor=opt.noamopt_factor, warmup=opt.noamopt_warmup)
        elif opt.reduce_on_plateau:
            optimizer = utils.build_optimizer(parameters, opt)
            optimizer = utils.ReduceLROnPlateau(optimizer,
                                                factor=opt.reduce_on_plateau_factor,
                                                patience=opt.reduce_on_plateau_patience)
        else:
            optimizer = utils.build_optimizer(parameters, opt)
        return [optimizer], []

    def optimizer_step(self, epoch, batch_idx, optimizer,
                       optimizer_idx, *args, **kwargs):
        """Apply linear LR warm-up (when ``opt.use_warmup``), then step."""
        # warm up lr
        opt = self.opt
        iteration = self.trainer.global_step
        if opt.use_warmup and (iteration < opt.noamopt_warmup):
            opt.current_lr = opt.learning_rate * \
                (iteration + 1) / opt.noamopt_warmup
            utils.set_lr(optimizer, opt.current_lr)

        super().optimizer_step(epoch, batch_idx, optimizer,
                               optimizer_idx, *args, **kwargs)

    def state_dict(self):
        """Return the model state dict plus serialized ``_opt`` and ``_vocab``.

        Both extras are serialized to tensors on the same device as the first
        model tensor.
        """
        state_dict = self.model.state_dict()
        device = next(iter(state_dict.values())).device
        assert '_vocab' not in state_dict and '_opt' not in state_dict, 'Just in case'
        state_dict.update({
            '_vocab': utils.serialize_to_tensor(self.model.vocab).to(device),
            '_opt': utils.serialize_to_tensor(self.opt).to(device)
        })
        return state_dict

    def load_state_dict(self, state_dict=None, strict=True):
        """Load a state dict produced by :meth:`state_dict`.

        ``_vocab`` (if present) replaces ``self.model.vocab``. ``_opt`` (if
        present) is checked for compatibility with the current options. Both
        keys are removed from ``state_dict`` in place before the remainder is
        loaded into ``self.model``.
        """
        if '_vocab' in state_dict:
            self.model.vocab = utils.deserialize(state_dict['_vocab'])
            del state_dict['_vocab']
        if '_opt' in state_dict:
            saved_model_opt = utils.deserialize(state_dict['_opt'])
            del state_dict['_opt']
            opt = self.opt
            # Make sure the saved opt is compatible with the current opt
            need_be_same = ["caption_model",
                            "rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                # 'updown' and 'topdown' are treated as interchangeable.
                if getattr(saved_model_opt, checkme) in ['updown', 'topdown'] and \
                        getattr(opt, checkme) in ['updown', 'topdown']:
                    continue
                assert getattr(saved_model_opt, checkme) == getattr(
                    opt, checkme), "Command line argument and saved model disagree on '%s' " % checkme
        self.model.load_state_dict(state_dict, strict)
class OnEpochStartCallback(pl.Callback):
    """Apply per-epoch schedule updates to a ``LitModel``.

    At the start of each epoch this callback:
    - sets the step-decayed learning rate, unless Noam or reduce-on-plateau
      scheduling is enabled;
    - advances the scheduled-sampling probability;
    - sets ``pl_module.sc_flag`` / ``pl_module.struc_flag``, calling
      ``init_scorer`` for each stage that is active.
    """

    def on_epoch_start(self, trainer, pl_module):
        cfg = pl_module.opt
        net = pl_module.model
        cur_epoch = trainer.current_epoch
        first_optimizer = trainer.optimizers[0]

        # Step-decay learning rate (only when no other scheduler owns the LR).
        if not (cfg.noamopt or cfg.reduce_on_plateau):
            lr_start = cfg.learning_rate_decay_start
            if cur_epoch > lr_start and lr_start >= 0:
                n_decays = (cur_epoch - lr_start) // cfg.learning_rate_decay_every
                cfg.current_lr = cfg.learning_rate * cfg.learning_rate_decay_rate ** n_decays
            else:
                cfg.current_lr = cfg.learning_rate
            utils.set_lr(first_optimizer, cfg.current_lr)  # set the decayed rate

        # Scheduled sampling probability, capped at the configured maximum.
        ss_start = cfg.scheduled_sampling_start
        if cur_epoch > ss_start and ss_start >= 0:
            n_increments = (cur_epoch - ss_start) // cfg.scheduled_sampling_increase_every
            cfg.ss_prob = min(cfg.scheduled_sampling_increase_prob * n_increments,
                              cfg.scheduled_sampling_max_prob)
            net.ss_prob = cfg.ss_prob

        # Self-critical training stage (-1 disables it).
        use_sc = cfg.self_critical_after != -1 and cur_epoch >= cfg.self_critical_after
        if use_sc:
            init_scorer(cfg.cached_tokens)

        # Structure-loss training stage (-1 disables it).
        use_struc = cfg.structure_after != -1 and cur_epoch >= cfg.structure_after
        if use_struc:
            init_scorer(cfg.cached_tokens)

        pl_module.struc_flag = use_struc
        pl_module.sc_flag = use_sc
class ModelCheckpoint(pl.callbacks.ModelCheckpoint):
    """Lightning ``ModelCheckpoint`` that also saves a checkpoint on Ctrl-C."""

    def on_keyboard_interrupt(self, trainer, pl_module):
        """Save ``<dirpath>/<prefix>interrupt.ckpt`` when training is interrupted.

        Uses the public ``Trainer.save_checkpoint`` API. The private
        ``_save_model`` helper takes the trainer as well in Lightning 1.x, so
        calling it with only a path raised TypeError and nothing was saved.
        """
        filepath = os.path.join(self.dirpath, self.prefix + 'interrupt.ckpt')
        os.makedirs(self.dirpath, exist_ok=True)
        trainer.save_checkpoint(filepath, weights_only=self.save_weights_only)
# ---------------------------------------------------------------------------
# Script entry: parse options, configure callbacks/logging, then train or test
# (test when the EVALUATE environment variable is '1').
# ---------------------------------------------------------------------------
opt = opts.parse_opt()

checkpoint_callback = ModelCheckpoint(
    filepath=opt.checkpoint_path,
    save_last=True,
    save_top_k=1,
    verbose=True,
    monitor='val/CIDEr',
    mode='max',
    prefix=opt.id,
)

# Only the local rank-0 process (LOCAL_RANK unset or '0') prints and logs config.
verbose = True
if 'LOCAL_RANK' in os.environ and os.environ['LOCAL_RANK'] != '0':
    verbose = False
if verbose:
    print(opt)
    print("""
val_images_use,
save_checkpoint_every,
save_every_epoch,
save_history_ckpt will be ignored.
""")

# Lightning defines batch size as batch size per gpu.
# Treat "no visible GPU" as one device so the divisibility check and the
# split below do not divide by zero.
n_devices = max(torch.cuda.device_count(), 1)
if opt.batch_size % n_devices != 0:
    raise ValueError(
        f'batch_size ({opt.batch_size}) must be divisible by the number of GPUs ({n_devices})')
opt.batch_size = opt.batch_size // n_devices

# If resume from last checkpoint
if opt.start_from is not None:
    resume_from = os.path.join(opt.start_from, f'{opt.id}-last.ckpt')
    if os.path.isfile(resume_from):
        if verbose:
            print('Loading checkpoint from', resume_from)
    else:
        print("Checkpoint not found:", resume_from)
        resume_from = None
else:
    resume_from = None

from pytorch_lightning.loggers import WandbLogger
wandb_logger = WandbLogger(
    project='CLIP-ViL-COCOCaption',
    name=opt.id,
)

if verbose:
    wandb_logger.experiment.config.update(opt)
    from pathlib import Path
    import glob
    import wandb
    # Upload every Python source file under the working directory to the run.
    glob_str = "**/*.py"
    base_path = './'
    wandb.save(glob_str=glob_str, base_path=base_path)

lit = LitModel(opt)
# warning grad_clip_mode is ignored.
trainer = pl.Trainer(
    callbacks=[
        OnEpochStartCallback(),
        pl.callbacks.LearningRateMonitor()
    ],
    default_root_dir=opt.checkpoint_path,
    resume_from_checkpoint=resume_from,
    distributed_backend='ddp',
    check_val_every_n_epoch=1,
    max_epochs=opt.max_epochs,
    gradient_clip_val=opt.grad_clip_value,
    gpus=torch.cuda.device_count(),
    checkpoint_callback=checkpoint_callback,
    log_gpu_memory='min_max',
    log_every_n_steps=opt.losses_log_every,
    profiler=True,
    flush_logs_every_n_steps=10,
    num_sanity_val_steps=0,
    precision=opt.precision,
    logger=wandb_logger
)

if os.getenv('EVALUATE', '0') == '1':
    trainer.test(lit)
else:
    trainer.fit(lit)