The Adam algorithm

https://zh.d2l.ai/chapter_optimization/adam.html

In Section 11.10.1, the explanation of the normalization step is not clear.
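For context, the step in question is where Adam rescales the update with the bias-corrected moment estimates; copying the formulas from the linked section:

$$\hat{\mathbf{v}}_t = \frac{\mathbf{v}_t}{1 - \beta_1^t}, \qquad \hat{\mathbf{s}}_t = \frac{\mathbf{s}_t}{1 - \beta_2^t}, \qquad \mathbf{g}_t' = \frac{\eta\, \hat{\mathbf{v}}_t}{\sqrt{\hat{\mathbf{s}}_t} + \epsilon}$$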


Edit:


For the AlexNet from Section 7.1: with learning rate 0.02, Adam diverges while Yogi converges; at 0.01 both converge; at 0.03 both diverge.
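For reference, the Yogi class below differs from Adam only in the second-moment update; per the Yogi part of the linked chapter, the exponential moving average is replaced by an additive, sign-controlled step:

$$\text{Adam: } \mathbf{s}_t \leftarrow \beta_2\, \mathbf{s}_{t-1} + (1 - \beta_2)\, \mathbf{g}_t^2$$
$$\text{Yogi: } \mathbf{s}_t \leftarrow \mathbf{s}_{t-1} + (1 - \beta_2)\, \mathbf{g}_t^2 \odot \operatorname{sign}\!\left(\mathbf{g}_t^2 - \mathbf{s}_{t-1}\right)$$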

# Yogi optimizer implemented as a torch.optim.Optimizer subclass
import torch
from torch.optim import Optimizer
class Yogi(Optimizer):
    def __init__(self, params, lr=0.01, t=1):
        defaults = dict(lr=lr, t=t)
        super(Yogi, self).__init__(params, defaults)
        # Pre-allocate the (v, s) state pair for every parameter of the first
        # (and only supported) parameter group, on that group's device
        device = self.param_groups[0]['params'][0].device
        self.states = self.init_adam_states(self.param_groups[0]['params'], device=device)

        self.hyperparams = {'lr': lr, 't': t}

    def step(self, closure=None):
        # Only the first parameter group is updated, since the states were
        # initialized for that group alone
        for group in self.param_groups:
            self.yogi_step(group['params'], self.states, self.hyperparams)
            break
                
    def init_adam_states(self, params, device):
        # One state pair per parameter: v for the momentum estimate,
        # s for the second-moment estimate
        states = []
        for param in params:
            v = torch.zeros_like(param, device=device)
            s = torch.zeros_like(param, device=device)
            states.append((v, s))
        return states
    
    def yogi_step(self, params, states, hyperparams):
        beta1, beta2, eps = 0.9, 0.999, 1e-3
        for p, (v, s) in zip(params, states):
            with torch.no_grad():
                # Momentum estimate, identical to Adam
                v[:] = beta1 * v + (1 - beta1) * p.grad
                # Yogi second-moment update: additive, controlled by the sign
                # of (g^2 - s) instead of Adam's exponential moving average
                s[:] = s + (1 - beta2) * torch.sign(
                    torch.square(p.grad) - s) * torch.square(p.grad)
                # Bias correction, then the parameter update
                v_bias_corr = v / (1 - beta1 ** hyperparams['t'])
                s_bias_corr = s / (1 - beta2 ** hyperparams['t'])
                p[:] -= (hyperparams['lr'] * v_bias_corr
                         / (torch.sqrt(s_bias_corr) + eps))
            # Gradients are cleared here, so no separate zero_grad() call is needed
            p.grad.data.zero_()
        hyperparams['t'] += 1
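
A minimal sketch of how the class plugs into a plain PyTorch training loop. The stand-in linear model and the single random Fashion-MNIST-shaped batch below are made up for illustration only, not the AlexNet run reported above; note that step() already clears the gradients.

# Usage sketch (stand-in model and fake data, not the AlexNet experiment)
import torch
from torch import nn

net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
loss_fn = nn.CrossEntropyLoss()
optimizer = Yogi(net.parameters(), lr=0.01)

X = torch.randn(256, 1, 28, 28)            # fake Fashion-MNIST-shaped batch
y = torch.randint(0, 10, (256,))
for epoch in range(5):
    l = loss_fn(net(X), y)
    l.backward()
    optimizer.step()                        # also zeroes the gradients
    print(f'epoch {epoch + 1}, loss {l.item():.3f}')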