import logging

log = logging.getLogger()


def get_parameter_groups(model, stage_cfg, print_log=False):
    """
    Assign different weight decays and learning rates to different parameters.
    Returns a list of parameter groups which can be passed to the optimizer.
    """
    weight_decay = stage_cfg.weight_decay
    embed_weight_decay = stage_cfg.embed_weight_decay
    backbone_lr_ratio = stage_cfg.backbone_lr_ratio
    base_lr = stage_cfg.learning_rate

    backbone_params = []
    embed_params = []
    other_params = []

    embedding_names = ['summary_pos', 'query_init', 'query_emb', 'obj_pe']
    embedding_names = [e + '.weight' for e in embedding_names]

    # inspired by detectron2
    memo = set()
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # Avoid duplicating parameters
        if param in memo:
            continue
        memo.add(param)

        # Strip the 'module.' prefix added by DataParallel/DistributedDataParallel
        if name.startswith('module'):
            name = name[7:]

        inserted = False
        if name.startswith('pixel_encoder.'):
            backbone_params.append(param)
            inserted = True
            if print_log:
                log.info(f'{name} counted as a backbone parameter.')
        else:
            for e in embedding_names:
                if name.endswith(e):
                    embed_params.append(param)
                    inserted = True
                    if print_log:
                        log.info(f'{name} counted as an embedding parameter.')
                    break

        if not inserted:
            other_params.append(param)

    parameter_groups = [
        {
            'params': backbone_params,
            'lr': base_lr * backbone_lr_ratio,
            'weight_decay': weight_decay
        },
        {
            'params': embed_params,
            'lr': base_lr,
            'weight_decay': embed_weight_decay
        },
        {
            'params': other_params,
            'lr': base_lr,
            'weight_decay': weight_decay
        },
    ]
    return parameter_groups
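
# For reference, a minimal usage sketch. The toy model, the SimpleNamespace config,
# and the AdamW choice are illustrative assumptions, not part of the original module;
# the only requirements are a 'pixel_encoder' submodule for the backbone group,
# embedding parameters matching the names above, and the four config attributes read
# by get_parameter_groups.
if __name__ == '__main__':
    import torch
    import torch.nn as nn
    from types import SimpleNamespace

    class ToyModel(nn.Module):
        # Hypothetical model whose submodule names mimic the expected layout.
        def __init__(self):
            super().__init__()
            self.pixel_encoder = nn.Conv2d(3, 16, 3)  # goes to the backbone group
            self.query_emb = nn.Embedding(8, 16)      # goes to the embedding group
            self.head = nn.Linear(16, 4)              # goes to the default group

    model = ToyModel()
    # Illustrative values; the real stage_cfg comes from the training configuration.
    stage_cfg = SimpleNamespace(
        learning_rate=1e-4,
        weight_decay=0.05,
        embed_weight_decay=0.0,
        backbone_lr_ratio=0.1,
    )

    param_groups = get_parameter_groups(model, stage_cfg, print_log=True)
    # Per-group 'lr' and 'weight_decay' override the optimizer defaults.
    optimizer = torch.optim.AdamW(param_groups)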