artificialguybr's picture
Hi
eadd7b4
import math
from mmcv import Config
from mmcv.runner import build_optimizer as mm_build_optimizer, OPTIMIZER_BUILDERS, DefaultOptimizerConstructor, \
OPTIMIZERS
from mmcv.utils import _BatchNorm, _InstanceNorm
from torch.nn import GroupNorm, LayerNorm
from .logger import get_root_logger
from typing import Tuple, Optional, Callable
import torch
from torch.optim.optimizer import Optimizer
from came_pytorch import CAME
def auto_scale_lr(effective_bs, optimizer_cfg, rule='linear', base_batch_size=256):
    """Rescale ``optimizer_cfg['lr']`` in place for the given effective batch size.

    The ratio is ``effective_bs / base_batch_size`` for ``rule='linear'``, or the
    square root of that ratio for ``rule='sqrt'``. The new learning rate is
    logged through the root logger.

    Args:
        effective_bs: Effective (total) batch size used for training.
        optimizer_cfg: Mapping holding an ``'lr'`` entry; mutated in place.
        rule: Either ``'linear'`` or ``'sqrt'``.
        base_batch_size: Batch size the configured ``lr`` was tuned for.

    Returns:
        The multiplicative factor that was applied to ``optimizer_cfg['lr']``.

    Raises:
        AssertionError: If ``rule`` is not ``'linear'`` or ``'sqrt'``.
    """
    assert rule in ['linear', 'sqrt']
    logger = get_root_logger()
    # Ratio of the effective batch size to the reference batch size.
    ratio = effective_bs / base_batch_size
    if rule == 'sqrt':
        ratio = math.sqrt(ratio)
    optimizer_cfg['lr'] *= ratio
    logger.info(f'Automatically adapt lr to {optimizer_cfg["lr"]:.5f} (using {rule} scaling rule).')
    return ratio
@OPTIMIZER_BUILDERS.register_module()
class MyOptimizerConstructor(DefaultOptimizerConstructor):
    """Optimizer constructor that puts every parameter in its own param group.

    ``paramwise_cfg['custom_keys']`` entries may be keyed either by a plain
    substring of the parameter's dotted name, or by a ``(scope, key_name)``
    tuple where ``scope`` must additionally be a substring of the owning
    module's prefix. Every matching key is applied, not only the first.
    """

    def add_params(self, params, module, prefix='', is_dcn_module=None):
        """Add all parameters of module to the params list.

        The parameters of the given module will be added to the list of param
        groups, with specific rules defined by paramwise_cfg. Each parameter
        becomes its own group, and the method recurses into child modules.

        Args:
            params (list[dict]): A list of param groups, it will be modified
                in place.
            module (nn.Module): The module to be added.
            prefix (str): The dotted prefix of the module ('' for the root).
            is_dcn_module (bool | None): When truthy, ``bias_lr_mult`` and
                ``bias_decay_mult`` are not applied to this module's ``bias``.
                Passed through unchanged to child modules.
        """
        # get param-wise options
        custom_keys = self.paramwise_cfg.get('custom_keys', {})
        # NOTE: custom_keys are visited in their stored order and ALL matching
        # keys are applied (there is no early break). A later key's lr_mult /
        # decay_mult therefore overwrites one set by an earlier matching key.
        bias_lr_mult = self.paramwise_cfg.get('bias_lr_mult', 1.)
        bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', 1.)
        norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', 1.)
        bypass_duplicate = self.paramwise_cfg.get('bypass_duplicate', False)

        # special rules for norm layers
        is_norm = isinstance(module,
                             (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm))

        for name, param in module.named_parameters(recurse=False):
            base_lr = self.base_lr
            if name == 'bias' and not (is_norm or is_dcn_module):
                base_lr *= bias_lr_mult

            # apply weight decay policies
            base_wd = self.base_wd
            if self.base_wd is not None:
                # norm decay
                if is_norm:
                    base_wd *= norm_decay_mult
                # bias lr and decay
                elif name == 'bias' and not is_dcn_module:
                    # TODO: current bias_decay_mult will have affect on DCN
                    base_wd *= bias_decay_mult

            param_group = {'params': [param]}
            if not param.requires_grad:
                # Frozen params are still appended as their own group, tagged
                # requires_grad=False; no lr / weight_decay overrides apply.
                param_group['requires_grad'] = False
                params.append(param_group)
                continue
            if bypass_duplicate and self._is_in(param_group, params):
                logger = get_root_logger()
                # Logger.warn is a deprecated alias; use warning().
                logger.warning(f'{prefix} is duplicate. It is skipped since '
                               f'bypass_duplicate={bypass_duplicate}')
                continue

            # if the parameter matches any custom key, the bias/norm rules above
            # are only used as fallbacks for lr / weight_decay
            full_name = f'{prefix}.{name}'
            is_custom = False
            for key, key_cfg in custom_keys.items():
                if isinstance(key, tuple):
                    scope, key_name = key
                else:
                    scope, key_name = None, key
                if scope is not None and scope not in prefix:
                    continue
                if key_name not in full_name:
                    continue
                is_custom = True
                if 'lr_mult' in key_cfg:
                    param_group['lr'] = self.base_lr * key_cfg['lr_mult']
                elif 'lr' not in param_group:
                    param_group['lr'] = base_lr
                if self.base_wd is not None:
                    if 'decay_mult' in key_cfg:
                        param_group['weight_decay'] = self.base_wd * key_cfg['decay_mult']
                    elif 'weight_decay' not in param_group:
                        param_group['weight_decay'] = base_wd

            if not is_custom:
                # bias_lr_mult affects all bias parameters
                # except for norm.bias dcn.conv_offset.bias
                if base_lr != self.base_lr:
                    param_group['lr'] = base_lr
                if base_wd != self.base_wd:
                    param_group['weight_decay'] = base_wd
            params.append(param_group)

        for child_name, child_mod in module.named_children():
            child_prefix = f'{prefix}.{child_name}' if prefix else child_name
            self.add_params(
                params,
                child_mod,
                prefix=child_prefix,
                is_dcn_module=is_dcn_module)
def build_optimizer(model, optimizer_cfg):
    """Build an optimizer for ``model`` through mmcv and log a group summary.

    ``optimizer_cfg`` is mutated in place:
    * ``'constructor'`` defaults to ``'MyOptimizerConstructor'``.
    * ``'paramwise_cfg'`` is replaced by a config whose ``custom_keys`` assign
      ``decay_mult=0`` to every ``(module_name, key)`` listed in a module's
      ``zero_weight_decay`` attribute. A user-supplied ``paramwise_cfg`` is
      merged on top of it.

    If ``model`` has a ``module`` attribute, that inner module is used instead
    (presumably to unwrap a DataParallel-style wrapper; verify against callers).

    Returns:
        The optimizer returned by mmcv's ``build_optimizer``.
    """
    logger = get_root_logger()
    model = model.module if hasattr(model, 'module') else model
    optimizer_cfg.setdefault('constructor', 'MyOptimizerConstructor')

    # Modules opt parameters out of weight decay via `zero_weight_decay`;
    # the module's name is used as the key's scope.
    custom_keys = {
        (mod_name, key): dict(decay_mult=0)
        for mod_name, mod in model.named_modules()
        if hasattr(mod, 'zero_weight_decay')
        for key in mod.zero_weight_decay
    }
    paramwise_cfg = Config(dict(cfg=dict(custom_keys=custom_keys)))
    user_paramwise = optimizer_cfg.get('paramwise_cfg')
    if user_paramwise:
        paramwise_cfg.merge_from_dict(dict(cfg=user_paramwise))
    optimizer_cfg['paramwise_cfg'] = paramwise_cfg.cfg

    optimizer = mm_build_optimizer(model, optimizer_cfg)

    # Bucket trainable param groups by lr and by weight decay for logging.
    # NOTE: "params" in the log counts param groups; with MyOptimizerConstructor
    # each group holds exactly one parameter.
    lr_groups, weight_decay_groups = {}, {}
    for group in optimizer.param_groups:
        if group.get('requires_grad', True):
            lr_groups.setdefault(group['lr'], []).append(group)
            weight_decay_groups.setdefault(group['weight_decay'], []).append(group)

    grad_flags = [p.requires_grad for p in model.parameters()]
    learnable_count = sum(grad_flags)
    fix_count = len(grad_flags) - learnable_count

    fix_info = f"{learnable_count} are learnable, {fix_count} are fix"
    lr_info = "Lr group: " + ", ".join(
        f'{len(group)} params with lr {lr:.5f}' for lr, group in lr_groups.items())
    wd_info = "Weight decay group: " + ", ".join(
        f'{len(group)} params with weight decay {wd}' for wd, group in weight_decay_groups.items())
    logger.info(
        f"{type(optimizer).__name__} Optimizer: total {len(optimizer.param_groups)} param groups, "
        f"{fix_info}. {lr_info}; {wd_info}.")
    return optimizer
@OPTIMIZERS.register_module()
class Lion(Optimizer):
    """Lion optimizer: sign-of-momentum update with decoupled weight decay.

    For each parameter ``p`` with gradient ``g`` and momentum state ``m``::

        p <- p * (1 - lr * wd)
        p <- p - lr * sign(lerp(m, g, 1 - beta1))
        m <- lerp(m, g, 1 - beta2)

    where ``lerp(a, b, w) = a + w * (b - a)``.
    """

    def __init__(
        self,
        params,
        lr: float = 1e-4,
        betas: Tuple[float, float] = (0.9, 0.99),
        weight_decay: float = 0.0,
    ):
        assert lr > 0.
        assert all(0. <= beta <= 1. for beta in betas)
        super().__init__(params, {'lr': lr, 'betas': betas, 'weight_decay': weight_decay})

    @staticmethod
    def update_fn(p, grad, exp_avg, lr, wd, beta1, beta2):
        """Apply one Lion update to ``p`` in place and refresh ``exp_avg``."""
        # decoupled (step-size scaled) weight decay on the raw tensor
        p.data.mul_(1 - lr * wd)
        # move by the sign of the beta1-blend of momentum and gradient
        direction = torch.lerp(exp_avg, grad, 1 - beta1).sign_()
        p.add_(direction, alpha=-lr)
        # only afterwards update the momentum, using beta2
        exp_avg.lerp_(grad, 1 - beta2)

    @staticmethod
    def exists(val):
        """Return True when ``val`` is not None."""
        return val is not None

    @torch.no_grad()
    def step(
        self,
        closure: Optional[Callable] = None
    ):
        """Perform one optimization step; returns the closure's loss, if any."""
        loss = None
        if self.exists(closure):
            # the closure recomputes the loss, which needs autograd enabled
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if not self.exists(p.grad):
                    continue
                grad = p.grad
                lr = group['lr']
                wd = group['weight_decay']
                beta1, beta2 = group['betas']
                state = self.state[p]
                # lazily create the momentum buffer on the first step
                if not state:
                    state['exp_avg'] = torch.zeros_like(p)
                self.update_fn(p, grad, state['exp_avg'], lr, wd, beta1, beta2)
        return loss
@OPTIMIZERS.register_module()
class CAMEWrapper(CAME):
    """``came_pytorch.CAME`` registered in mmcv's ``OPTIMIZERS`` registry.

    All constructor arguments are forwarded unchanged to ``CAME``.
    """

    def __init__(self, *args, **kwargs):
        super(CAMEWrapper, self).__init__(*args, **kwargs)