tjxj
basicsr
6f7f0bf
from basicsr.utils import get_root_logger
from basicsr.utils.registry import MODEL_REGISTRY
from .video_base_model import VideoBaseModel
@MODEL_REGISTRY.register()
class EDVRModel(VideoBaseModel):
"""EDVR Model.
Paper: EDVR: Video Restoration with Enhanced Deformable Convolutional Networks. # noqa: E501
"""
def __init__(self, opt):
super(EDVRModel, self).__init__(opt)
if self.is_train:
self.train_tsa_iter = opt['train'].get('tsa_iter')
def setup_optimizers(self):
train_opt = self.opt['train']
dcn_lr_mul = train_opt.get('dcn_lr_mul', 1)
logger = get_root_logger()
logger.info(f'Multiple the learning rate for dcn with {dcn_lr_mul}.')
if dcn_lr_mul == 1:
optim_params = self.net_g.parameters()
else: # separate dcn params and normal params for different lr
normal_params = []
dcn_params = []
for name, param in self.net_g.named_parameters():
if 'dcn' in name:
dcn_params.append(param)
else:
normal_params.append(param)
optim_params = [
{ # add normal params first
'params': normal_params,
'lr': train_opt['optim_g']['lr']
},
{
'params': dcn_params,
'lr': train_opt['optim_g']['lr'] * dcn_lr_mul
},
]
optim_type = train_opt['optim_g'].pop('type')
self.optimizer_g = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g'])
self.optimizers.append(self.optimizer_g)
def optimize_parameters(self, current_iter):
if self.train_tsa_iter:
if current_iter == 1:
logger = get_root_logger()
logger.info(f'Only train TSA module for {self.train_tsa_iter} iters.')
for name, param in self.net_g.named_parameters():
if 'fusion' not in name:
param.requires_grad = False
elif current_iter == self.train_tsa_iter:
logger = get_root_logger()
logger.warning('Train all the parameters.')
for param in self.net_g.parameters():
param.requires_grad = True
super(EDVRModel, self).optimize_parameters(current_iter)