For the AlexNet from Section 7.1: with a learning rate of 0.02, Adam diverges while Yogi converges; at 0.01 both converge; at 0.03 both diverge.
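This matches Yogi's design. Per step, Adam changes the second-moment estimate by $(1-\beta_2)(\mathbf{g}_t^2 - \mathbf{s}_{t-1})$, which can be large in magnitude, while Yogi uses the sign-controlled additive update

$$\mathbf{s}_t \leftarrow \mathbf{s}_{t-1} + (1-\beta_2)\,\mathbf{g}_t^2 \odot \operatorname{sgn}(\mathbf{g}_t^2 - \mathbf{s}_{t-1}),$$

so the change is bounded by $(1-\beta_2)\,\mathbf{g}_t^2$ in magnitude, which typically lets Yogi tolerate a larger learning rate. The class below wraps this update rule in a `torch.optim.Optimizer`.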
# Wrap the Yogi update rule in a torch.optim.Optimizer subclass
import torch
from torch.optim import Optimizer

class Yogi(Optimizer):
    def __init__(self, params, lr=0.01, t=1):
        defaults = dict(lr=lr, t=t)
        super(Yogi, self).__init__(params, defaults)
        device = self.param_groups[0]['params'][0].device
        # One (v, s) state pair per parameter in every parameter group
        self.states = [self.init_adam_states(group['params'], device=device)
                       for group in self.param_groups]
        self.hyperparams = {'lr': lr, 't': t}

    def step(self, closure=None):
        for group, states in zip(self.param_groups, self.states):
            self.yogi_step(group['params'], states, self.hyperparams)

    def init_adam_states(self, params, device):
        # Momentum v and second-moment estimate s both start at zero
        states = []
        for param in params:
            v = torch.zeros_like(param, device=device)
            s = torch.zeros_like(param, device=device)
            states.append((v, s))
        return states

    def yogi_step(self, params, states, hyperparams):
        beta1, beta2, eps = 0.9, 0.999, 1e-3
        for p, (v, s) in zip(params, states):
            with torch.no_grad():
                # Momentum update (same as Adam)
                v[:] = beta1 * v + (1 - beta1) * p.grad
                # Yogi second-moment update: additive and sign-controlled,
                # instead of Adam's exponential moving average
                s[:] = s + (1 - beta2) * torch.sign(
                    torch.square(p.grad) - s) * torch.square(p.grad)
                # Bias correction
                v_bias_corr = v / (1 - beta1 ** hyperparams['t'])
                s_bias_corr = s / (1 - beta2 ** hyperparams['t'])
                # Parameter update
                p[:] -= (hyperparams['lr'] * v_bias_corr
                         / (torch.sqrt(s_bias_corr) + eps))
            p.grad.data.zero_()  # gradients are cleared inside step()
        hyperparams['t'] += 1
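A minimal usage sketch for the learning-rate comparison above, assuming the AlexNet model and Fashion-MNIST loader from Section 7.1 are available as `net` and `train_iter` (these names, and the helper `train_with_yogi`, are illustrative, not part of the original code):

# Plug the Yogi class into a standard PyTorch training loop.
# `net` and `train_iter` are assumed to come from the Section 7.1 setup.
from torch import nn

def train_with_yogi(net, train_iter, lr=0.02, num_epochs=10,
                    device=torch.device('cuda' if torch.cuda.is_available()
                                        else 'cpu')):
    net.to(device)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = Yogi(net.parameters(), lr=lr)  # the class defined above
    for epoch in range(num_epochs):
        total_loss, n = 0.0, 0
        for X, y in train_iter:
            X, y = X.to(device), y.to(device)
            loss = loss_fn(net(X), y)
            loss.backward()
            optimizer.step()  # also zeroes the gradients internally
            total_loss += loss.item() * y.numel()
            n += y.numel()
        print(f'epoch {epoch + 1}, loss {total_loss / n:.3f}')

Note that no separate `optimizer.zero_grad()` call is needed, because `yogi_step` clears each parameter's gradient after applying the update.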