Spaces:
Runtime error
Runtime error
import math | |
from mmcv import Config | |
from mmcv.runner import build_optimizer as mm_build_optimizer, OPTIMIZER_BUILDERS, DefaultOptimizerConstructor, \ | |
OPTIMIZERS | |
from mmcv.utils import _BatchNorm, _InstanceNorm | |
from torch.nn import GroupNorm, LayerNorm | |
from .logger import get_root_logger | |
from typing import Tuple, Optional, Callable | |
import torch | |
from torch.optim.optimizer import Optimizer | |
from came_pytorch import CAME | |
def auto_scale_lr(effective_bs, optimizer_cfg, rule='linear', base_batch_size=256):
    """Rescale ``optimizer_cfg['lr']`` in place relative to a reference batch size.

    The multiplier is ``effective_bs / base_batch_size`` for the ``'linear'``
    rule and the square root of that ratio for the ``'sqrt'`` rule. The
    rescaled learning rate is written back into ``optimizer_cfg`` and logged
    through the root logger.

    Args:
        effective_bs: Batch size the learning rate is being adapted to.
        optimizer_cfg: Mutable mapping holding the optimizer's ``'lr'`` entry;
            modified in place.
        rule: Either ``'linear'`` or ``'sqrt'``; any other value fails the
            assertion.
        base_batch_size: Reference batch size the ratio is computed against.

    Returns:
        The multiplier that was applied to ``optimizer_cfg['lr']``.
    """
    assert rule in ['linear', 'sqrt']
    logger = get_root_logger()
    ratio = effective_bs / base_batch_size
    if rule == 'sqrt':
        # Square-root scaling grows the lr more slowly than linear scaling.
        ratio = math.sqrt(ratio)
    optimizer_cfg['lr'] *= ratio
    logger.info(f'Automatically adapt lr to {optimizer_cfg["lr"]:.5f} (using {rule} scaling rule).')
    return ratio
@OPTIMIZER_BUILDERS.register_module()
class MyOptimizerConstructor(DefaultOptimizerConstructor):
    """mmcv optimizer constructor with scope-aware ``custom_keys``.

    Behaves like mmcv's ``DefaultOptimizerConstructor``, except that each
    entry of ``paramwise_cfg['custom_keys']`` may be keyed either by a plain
    string or by a ``(scope, key_name)`` tuple. A tuple entry is only
    considered for parameters whose module prefix contains ``scope``;
    ``build_optimizer`` in this module produces such tuples from submodules'
    ``zero_weight_decay`` attribute.

    The class is registered in ``OPTIMIZER_BUILDERS`` so that mmcv's
    ``build_optimizer`` can resolve ``constructor='MyOptimizerConstructor'``
    by name.
    """

    def add_params(self, params, module, prefix='', is_dcn_module=None):
        """Add all parameters of ``module`` (recursively) to ``params``.

        Every parameter becomes its own param group. ``lr`` and
        ``weight_decay`` entries are written only when they differ from the
        optimizer defaults, following the rules in ``self.paramwise_cfg``.

        Args:
            params (list[dict]): A list of param groups; modified in place.
            module (nn.Module): The module whose parameters are added.
            prefix (str): Dotted name of ``module`` within the root model.
            is_dcn_module: Flag passed unchanged to child modules. When
                truthy, ``bias`` parameters are exempt from ``bias_lr_mult``
                and ``bias_decay_mult``. Defaults to None.
        """
        custom_keys = self.paramwise_cfg.get('custom_keys', {})
        bias_lr_mult = self.paramwise_cfg.get('bias_lr_mult', 1.)
        bias_decay_mult = self.paramwise_cfg.get('bias_decay_mult', 1.)
        norm_decay_mult = self.paramwise_cfg.get('norm_decay_mult', 1.)
        bypass_duplicate = self.paramwise_cfg.get('bypass_duplicate', False)

        # Norm layers get norm_decay_mult and are exempt from bias_lr_mult.
        is_norm = isinstance(module,
                             (_BatchNorm, _InstanceNorm, GroupNorm, LayerNorm))

        for name, param in module.named_parameters(recurse=False):
            base_lr = self.base_lr
            if name == 'bias' and not (is_norm or is_dcn_module):
                base_lr *= bias_lr_mult

            # Apply weight-decay policies (only when a base wd is configured).
            base_wd = self.base_wd
            if self.base_wd is not None:
                if is_norm:
                    base_wd *= norm_decay_mult
                elif name == 'bias' and not is_dcn_module:
                    # TODO: current bias_decay_mult will have affect on DCN
                    base_wd *= bias_decay_mult

            param_group = {'params': [param]}
            if not param.requires_grad:
                # Frozen params are still registered, but flagged so that
                # build_optimizer can skip them when summarizing groups.
                param_group['requires_grad'] = False
                params.append(param_group)
                continue
            if bypass_duplicate and self._is_in(param_group, params):
                logger = get_root_logger()
                logger.warning(f'{prefix} is duplicate. It is skipped since '
                               f'bypass_duplicate={bypass_duplicate}')
                continue

            # Every matching custom key is visited (no early exit): a later
            # key carrying lr_mult / decay_mult overrides an earlier one, while
            # a key without them only fills in the default if nothing is set.
            is_custom = False
            for key in custom_keys:
                if isinstance(key, tuple):
                    scope, key_name = key
                else:
                    scope, key_name = None, key
                # NOTE(review): scope matching is a substring test, so scope
                # 'blocks.1' also matches prefix 'blocks.10.attn' — confirm
                # this is intended.
                if scope is not None and scope not in prefix:
                    continue
                if key_name not in f'{prefix}.{name}':
                    continue
                is_custom = True
                key_cfg = custom_keys[key]
                if 'lr_mult' in key_cfg:
                    param_group['lr'] = self.base_lr * key_cfg['lr_mult']
                elif 'lr' not in param_group:
                    param_group['lr'] = base_lr
                if self.base_wd is not None:
                    if 'decay_mult' in key_cfg:
                        param_group['weight_decay'] = self.base_wd * key_cfg['decay_mult']
                    elif 'weight_decay' not in param_group:
                        param_group['weight_decay'] = base_wd

            if not is_custom:
                # bias_lr_mult affects all bias parameters
                # except for norm.bias dcn.conv_offset.bias
                if base_lr != self.base_lr:
                    param_group['lr'] = base_lr
                if base_wd != self.base_wd:
                    param_group['weight_decay'] = base_wd
            params.append(param_group)

        for child_name, child_mod in module.named_children():
            child_prefix = f'{prefix}.{child_name}' if prefix else child_name
            self.add_params(
                params,
                child_mod,
                prefix=child_prefix,
                is_dcn_module=is_dcn_module)
def build_optimizer(model, optimizer_cfg):
    """Build an optimizer for ``model`` through mmcv and log a group summary.

    ``optimizer_cfg`` is modified in place:

    * ``'constructor'`` defaults to ``'MyOptimizerConstructor'``.
    * ``'paramwise_cfg'`` is replaced by a config whose ``custom_keys`` map
      ``(module_name, key)`` to ``decay_mult=0`` for each ``key`` listed in
      a submodule's ``zero_weight_decay`` attribute. Any user-supplied
      ``paramwise_cfg`` is merged on top of it.

    Args:
        model: The model, or a wrapper exposing it as ``.module``.
        optimizer_cfg: Optimizer config mapping passed to mmcv's
            ``build_optimizer``.

    Returns:
        The optimizer returned by mmcv's ``build_optimizer``.
    """
    logger = get_root_logger()
    # Unwrap wrappers that expose the real model as ``.module``.
    model = model.module if hasattr(model, 'module') else model
    optimizer_cfg.setdefault('constructor', 'MyOptimizerConstructor')

    # Modules opt out of weight decay for some params via `zero_weight_decay`;
    # each key is scoped to the declaring module's name.
    zero_wd_keys = dict()
    for module_name, submodule in model.named_modules():
        if not hasattr(submodule, 'zero_weight_decay'):
            continue
        for key in submodule.zero_weight_decay:
            zero_wd_keys[(module_name, key)] = dict(decay_mult=0)

    paramwise_cfg = Config(dict(cfg=dict(custom_keys=zero_wd_keys)))
    user_paramwise = optimizer_cfg.get('paramwise_cfg')
    if user_paramwise:
        paramwise_cfg.merge_from_dict(dict(cfg=user_paramwise))
    optimizer_cfg['paramwise_cfg'] = paramwise_cfg.cfg

    optimizer = mm_build_optimizer(model, optimizer_cfg)

    # Bucket trainable param groups by learning rate and by weight decay.
    lr_groups, weight_decay_groups = dict(), dict()
    trainable_groups = [g for g in optimizer.param_groups if g.get('requires_grad', True)]
    for group in trainable_groups:
        lr_groups.setdefault(group['lr'], []).append(group)
        weight_decay_groups.setdefault(group['weight_decay'], []).append(group)

    grad_flags = [p.requires_grad for p in model.parameters()]
    learnable_count = sum(grad_flags)
    fix_count = len(grad_flags) - learnable_count

    fix_info = f"{learnable_count} are learnable, {fix_count} are fix"
    lr_info = "Lr group: " + ", ".join(
        f'{len(group)} params with lr {lr:.5f}' for lr, group in lr_groups.items())
    wd_info = "Weight decay group: " + ", ".join(
        f'{len(group)} params with weight decay {wd}' for wd, group in weight_decay_groups.items())
    logger.info(
        f"{optimizer.__class__.__name__} Optimizer: total {len(optimizer.param_groups)} param groups, "
        f"{fix_info}. {lr_info}; {wd_info}.")
    return optimizer
# NOTE(review): this class is not registered in mmcv's OPTIMIZERS registry
# here; if configs build it by name (type='Lion'), add
# `@OPTIMIZERS.register_module()` — confirm.
class Lion(Optimizer):
    """Lion optimizer (sign-of-momentum update with decoupled weight decay).

    For every parameter ``p`` with a gradient ``g``, each step applies:

    1. ``p *= 1 - lr * wd`` (decoupled weight decay);
    2. ``p -= lr * sign(beta1 * m + (1 - beta1) * g)``;
    3. ``m = beta2 * m + (1 - beta2) * g``;

    where ``m`` is the per-parameter ``exp_avg`` state, zero-initialized on
    the first step.
    """

    def __init__(
            self,
            params,
            lr: float = 1e-4,
            betas: Tuple[float, float] = (0.9, 0.99),
            weight_decay: float = 0.0,
    ):
        assert lr > 0.
        assert all([0. <= beta <= 1. for beta in betas])
        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay)
        super().__init__(params, defaults)

    # Must be a staticmethod: it is invoked as `self.update_fn(p, ...)`, and
    # without the decorator `self` would be bound to `p`, raising TypeError.
    @staticmethod
    def update_fn(p, grad, exp_avg, lr, wd, beta1, beta2):
        """Apply one Lion update to ``p`` and ``exp_avg`` in place."""
        # step weight decay
        p.mul_(1 - lr * wd)
        # weight update: sign of the beta1-interpolated momentum
        update = exp_avg.clone().lerp_(grad, 1 - beta1).sign_()
        p.add_(update, alpha=-lr)
        # decay the momentum running average coefficient
        exp_avg.lerp_(grad, 1 - beta2)

    @staticmethod
    def exists(val):
        """Return True when ``val`` is not None."""
        return val is not None

    # In-place updates of leaf parameters that require grad are only legal
    # with autograd disabled.
    @torch.no_grad()
    def step(
            self,
            closure: Optional[Callable] = None
    ):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss; it is run with gradients enabled.

        Returns:
            Whatever ``closure`` returned, or None when no closure is given.
        """
        loss = None
        if self.exists(closure):
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr, wd = group['lr'], group['weight_decay']
            beta1, beta2 = group['betas']
            for p in group['params']:
                # Parameters without a gradient are left untouched.
                if not self.exists(p.grad):
                    continue
                state = self.state[p]
                # init state - exponential moving average of gradient values
                if len(state) == 0:
                    state['exp_avg'] = torch.zeros_like(p)
                self.update_fn(p, p.grad, state['exp_avg'], lr, wd, beta1, beta2)

        return loss
@OPTIMIZERS.register_module()
class CAMEWrapper(CAME):
    """``came_pytorch.CAME`` registered in mmcv's ``OPTIMIZERS`` registry.

    The subclass exists so that configs can select CAME by name
    (``type='CAMEWrapper'``). All constructor arguments are forwarded to
    ``CAME`` unchanged.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)