import torch
from bisect import bisect_right

class _LRScheduler(object):
    """Base class for iteration-based LR schedulers; subclasses implement _get_new_lr()."""
    def __init__(self, optimizer, last_iter=-1):
        if not isinstance(optimizer, torch.optim.Optimizer):
            raise TypeError('{} is not an Optimizer'.format(
                type(optimizer).__name__))
        self.optimizer = optimizer
        if last_iter == -1:
            for group in optimizer.param_groups:
                group.setdefault('initial_lr', group['lr'])
        else:
            for i, group in enumerate(optimizer.param_groups):
                if 'initial_lr' not in group:
                    raise KeyError("param 'initial_lr' is not specified "
                                   "in param_groups[{}] when resuming an optimizer".format(i))
        self.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
        self.last_iter = last_iter

    def _get_new_lr(self):
        raise NotImplementedError

    def get_lr(self):
        return list(map(lambda group: group['lr'], self.optimizer.param_groups))

    def step(self, this_iter=None):
        # Advance to `this_iter` (defaults to the next iteration) and write the
        # newly computed lr into every param group.
        if this_iter is None:
            this_iter = self.last_iter + 1
        self.last_iter = this_iter
        for param_group, lr in zip(self.optimizer.param_groups, self._get_new_lr()):
            param_group['lr'] = lr

class _WarmUpLRSchedulerOld(_LRScheduler):
    """Older warmup variant: a single linear ramp from base_lr to warmup_lr over warmup_steps iterations."""

    def __init__(self, optimizer, base_lr, warmup_lr, warmup_steps, last_iter=-1):
        self.base_lr = base_lr
        self.warmup_steps = warmup_steps
        if warmup_steps == 0:
            self.warmup_lr = base_lr
        else:
            self.warmup_lr = warmup_lr
        super(_WarmUpLRSchedulerOld, self).__init__(optimizer, last_iter)

    def _get_warmup_lr(self):
        if self.warmup_steps > 0 and self.last_iter < self.warmup_steps:
            # Linearly interpolate from base_lr to warmup_lr, expressed as a scale
            # relative to base_lr so it can be applied to every param group.
            scale = ((self.last_iter / self.warmup_steps) * (self.warmup_lr - self.base_lr) + self.base_lr) / self.base_lr
            return [scale * base_lr for base_lr in self.base_lrs]
        else:
            return None

class _WarmUpLRScheduler(_LRScheduler):
    """Warmup that linearly interpolates the lr through a list of (warmup_steps, warmup_lr) breakpoints."""

    def __init__(self, optimizer, base_lr, warmup_lr, warmup_steps, last_iter=-1):
        self.base_lr = base_lr
        self.warmup_lr = warmup_lr
        self.warmup_steps = warmup_steps
        assert isinstance(warmup_lr, list)
        assert isinstance(warmup_steps, list)
        assert len(warmup_lr) == len(warmup_steps)
        super(_WarmUpLRScheduler, self).__init__(optimizer, last_iter)
    
    def _get_warmup_lr(self):
        # Locate which warmup segment last_iter falls into; once past the last
        # breakpoint, warmup is over and the caller falls back to its own schedule.
        pos = bisect_right(self.warmup_steps, self.last_iter)
        if pos >= len(self.warmup_steps):
            return None
        else:
            if pos == 0:
                curr_lr = self.base_lr + self.last_iter * (self.warmup_lr[pos] - self.base_lr) / self.warmup_steps[pos]
            else:
                curr_lr = self.warmup_lr[pos - 1] + (self.last_iter - self.warmup_steps[pos - 1]) * (self.warmup_lr[pos] - self.warmup_lr[pos - 1]) / (self.warmup_steps[pos] - self.warmup_steps[pos - 1])
        scale = curr_lr / self.base_lr
        return [scale * base_lr for base_lr in self.base_lrs]

class StepLRScheduler(_WarmUpLRScheduler):
    """Multi-step decay: after warmup, scale the lr by the cumulative product of lr_mults at each milestone."""
    def __init__(self, optimizer, milestones, lr_mults, base_lr, warmup_lr, warmup_steps, last_iter=-1):
        super(StepLRScheduler, self).__init__(optimizer, base_lr, warmup_lr, warmup_steps, last_iter)

        assert len(milestones) == len(lr_mults), "{} vs {}".format(milestones, lr_mults)
        for x in milestones:
            assert isinstance(x, int)
        if not list(milestones) == sorted(milestones):
            raise ValueError('Milestones should be a list of'
                             ' increasing integers. Got {}'.format(milestones))
        self.milestones = milestones
        # Prefix products: self.lr_mults[i] is the total decay after passing i milestones.
        self.lr_mults = [1.0]
        for x in lr_mults:
            self.lr_mults.append(self.lr_mults[-1] * x)
    
    def _get_new_lr(self):
        # During warmup the warmup schedule takes precedence over the step decay.
        warmup_lrs = self._get_warmup_lr()
        if warmup_lrs is not None:
            return warmup_lrs

        pos = bisect_right(self.milestones, self.last_iter)
        if len(self.warmup_lr) == 0:
            # No warmup configured: decay directly from base_lr.
            scale = self.lr_mults[pos]
        else:
            # Decay from the lr reached at the end of warmup, relative to base_lr.
            scale = self.warmup_lr[-1] * self.lr_mults[pos] / self.base_lr
        return [base_lr * scale for base_lr in self.base_lrs]
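if __name__ == '__main__':
    # Hypothetical usage sketch (not part of the original module): build a toy
    # optimizer and drive StepLRScheduler for a few iterations to show how the
    # warmup breakpoints and milestones interact. All hyperparameters below are
    # made-up example values.
    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    scheduler = StepLRScheduler(
        optimizer,
        milestones=[30, 60],   # decay at iterations 30 and 60
        lr_mults=[0.1, 0.1],   # multiply the lr by 0.1 at each milestone
        base_lr=0.1,           # lr at iteration 0
        warmup_lr=[0.4],       # ramp linearly up to 0.4 ...
        warmup_steps=[10],     # ... over the first 10 iterations
    )

    for it in range(70):
        scheduler.step()       # set the lr for this iteration
        # a training step would go here; print a few lrs to see the schedule
        if it in (0, 5, 10, 30, 60):
            print(it, scheduler.get_lr())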