Implements the Adafactor algorithm.
This implementation is based on: Adafactor: Adaptive Learning Rates with Sublinear Memory Cost (see https://arxiv.org/abs/1804.04235).
Note that this optimizer internally adjusts the learning rate depending on the `scale_parameter`, `relative_step` and `warmup_init` options. To use a manual (external) learning rate schedule you should set `scale_parameter=False` and `relative_step=False`.
Arguments
`params` (iterable): iterable of parameters to optimize or dicts defining parameter groups
`lr` (float, optional): external learning rate (default: None)
`eps` (tuple[float, float]): regularization constants for square gradient and parameter scale respectively (default: (1e-30, 1e-3))
`clip_threshold` (float): threshold of root mean square of final gradient update (default: 1.0)
`decay_rate` (float): coefficient used to compute running averages of square gradient (default: -0.8)
`mom` (float, optional): coefficient used for computing running averages of gradient (default: None)
`weight_decay` (float, optional): weight decay (L2 penalty) (default: 0)
`scale_parameter` (bool): if True, learning rate is scaled by root mean square of parameter (default: True)
`relative_step` (bool): if True, time-dependent learning rate is computed instead of external learning rate (default: True)
`warmup_init` (bool): if True, the time-dependent learning rate is computed with a warm-up initialization (default: False)
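For reference, the note above corresponds to two construction patterns: the default time-dependent (relative-step) learning rate, and a manual external learning rate with the internal schedule disabled. The snippet below is a minimal sketch, assuming the `Adafactor` class documented here is available in the namespace (as it is in the cells that follow); the tensors and values are illustrative only.
from torch import tensor

params = [tensor([1.,2.,3.], requires_grad=True)]

# Default: time-dependent, relative-step learning rate (leave `lr=None`)
opt_default = Adafactor(params)

# Manual (external) learning rate: disable the internal schedule, per the note above
opt_manual = Adafactor(params, lr=1e-3, scale_parameter=False, relative_step=False)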
@delegates(Adafactor.__init__)
def adafactor(param_groups, **kwargs):
return OptimWrapper(Adafactor([{'params': ps, **kwargs} for ps in param_groups]))
Wrapping a PyTorch optimizer in fastai's OptimWrapper enables its use with fastai.
ps = [tensor([1,2,3])]
adaf = Adafactor(ps, mom=0.9, weight_decay=1e-2)
test_adaf = adafactor(param_groups=ps, mom=0.9, weight_decay=1e-2)
#Access to param_groups
test_eq(test_adaf.param_lists[0], adaf.param_groups[0]['params'])
#Set param_groups
test_adaf.param_lists = [[tensor([4,5,6])]]
test_eq(test_adaf.opt.param_groups[0]['params'], [tensor([4,5,6])])
#Access to hypers
# test_eq(test_adaf.hypers, [{**adaf.defaults}])
#Set hypers
test_adaf.set_hyper('mom', 0.95)
test_eq(test_adaf.opt.param_groups[0]['mom'], 0.95)
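With the wrapper above, `adafactor` can be passed to a fastai `Learner` as its `opt_func`. The snippet below is a minimal sketch, not a definitive recipe: it assumes the usual fastai imports (`Learner`) are available, `dls` and `model` are hypothetical placeholders for your own DataLoaders and model, and the internal schedule is disabled because `Learner` passes an explicit `lr` to `opt_func` (see the note at the top of this page).
from functools import partial

# `dls` and `model` are placeholders for your own DataLoaders and model.
# Disable the internal schedule since Learner supplies an explicit `lr`.
opt_func = partial(adafactor, scale_parameter=False, relative_step=False)
learn = Learner(dls, model, opt_func=opt_func)
learn.fit(1, lr=1e-3)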