Optimizers are essential tools within the modeling stack. One of the most widely used optimizers is Adam, introduced by Kingma and Ba [paper]. This optimizer keeps track of running averages of the gradient (aka the momentum term) and of the squared gradient (aka the energy term) using exponential moving average (EMA) filters, and uses the square root of the energy term to normalize the momentum term before taking a step.
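For reference, the standard (bias-corrected) Adam update for a parameter θ with gradient g(t), learning rate γ, and EMA coefficients β1 and β2 can be written as:

m(t) = β1 * m(t-1) + (1 - β1) * g(t)        (momentum term)
v(t) = β2 * v(t-1) + (1 - β2) * g(t)^2      (energy term)
θ(t) = θ(t-1) - γ * m̂(t) / (sqrt(v̂(t)) + ε)

where m̂(t) = m(t) / (1 - β1^t) and v̂(t) = v(t) / (1 - β2^t) are the bias-corrected estimates.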
This sounds like a very good idea, right? For one, it resembles a diagonal approximation of second-order optimization, and for another, when the gradients are noisy, the denominator (the square root of the energy term) is large relative to the numerator (the momentum term) and the steps are small. Conversely, when the gradients are consistent, the denominator is roughly equal to the numerator, and we take constant-sized steps equal to the learning rate. For example, if a parameter's gradient alternates between +1 and -1, the momentum term is near zero while the energy term is near one, so the step shrinks; if the gradient is a constant c, both terms match in magnitude and the step size is roughly γ. That is why this optimizer is the de facto choice among ML researchers and practitioners.
Many variations of the Adam optimizer have been proposed and studied (see AdamW in the PyTorch docs and optimizers like QHAdam from the torch_optimizer package). Here, we propose a variation that does not deviate from Adam significantly for noisy "recent" gradients but greatly amplifies the effective step size for parameters with a consistent "recent" gradient history (how recent depends on the EMA parameters β1 and β2).
Our modification is simple yet effective: we replace the EMA filter of the gradient energy term with an EMA filter of the gradient variance term. This implies the final update equation θ(t) = θ(t-1) - γ * SNR, where SNR is the signal-to-noise ratio of the "recent" gradient history (SNR is a term borrowed from the signal-processing literature and refers to the ratio of the mean of a signal to its standard deviation). So parameters with a high SNR, i.e., with consistent "recent" gradient histories, will see a much larger step size (as large as infinity if the gradient is constant) than those with noisy "recent" gradient histories (or low gradient SNRs). The implementation of this optimizer is straightforward and is given below for completeness.
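In the same notation as above, the change amounts to replacing the energy EMA with a variance EMA that tracks the squared deviation of the current gradient from the (bias-corrected) running mean, which is what the code below computes:

v(t) = β2 * v(t-1) + (1 - β2) * (g(t) - m̂(t-1))^2
θ(t) = θ(t-1) - γ * m̂(t) / (sqrt(v̂(t)) + ε)

so each step divides the bias-corrected gradient mean by the bias-corrected gradient standard deviation, i.e., an estimate of the SNR.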
from typing import Tuple

import torch
from torch.optim.optimizer import Optimizer


class SNRAdam(Optimizer):
    r"""Implements the SNRAdam optimization algorithm, which uses the std deviation for the
    denominator rather than the sqrt(energy) term used in conventional Adam. Why is this a
    good idea? If the gradient stddev for a param is small, we should take larger steps since
    it means the gradient is consistent over time.

    Arguments:
        params: iterable of parameters to optimize or dicts defining
            parameter groups
        lr: learning rate (default: 1e-3)
        betas: coefficients used for computing
            running averages of the gradient and its variance (default: (0.9, 0.999))
        eps: term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay: weight decay (L2 penalty) (default: 0)
    """

    def __init__(
        self,
        params,
        lr: float = 1e-3,
        betas: Tuple[float, float] = (0.9, 0.999),
        weight_decay: float = 0.0,
        eps: float = 1e-8,
    ):
        if lr <= 0.0:
            raise ValueError('Invalid learning rate: {}'.format(lr))
        if eps < 0.0:
            raise ValueError('Invalid epsilon value: {}'.format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(
                'Invalid beta parameter at index 0: {}'.format(betas[0])
            )
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(
                'Invalid beta parameter at index 1: {}'.format(betas[1])
            )
        if weight_decay < 0:
            raise ValueError(
                'Invalid weight_decay value: {}'.format(weight_decay)
            )
        defaults = {
            'lr': lr,
            'betas': betas,
            'weight_decay': weight_decay,
            'eps': eps,
        }
        super().__init__(params, defaults)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure: A closure that reevaluates the model and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()
        for group in self.param_groups:
            lr = group['lr']
            beta1, beta2 = group['betas']
            weight_decay = group['weight_decay']
            eps = group['eps']
            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                if d_p.is_sparse:
                    raise RuntimeError(
                        'SNRAdam does not support sparse gradients, '
                        'please consider SparseAdam instead'
                    )
                state = self.state[p]
                # Decoupled weight decay, applied before the adaptive update.
                if weight_decay != 0:
                    p.data.mul_(1 - lr * weight_decay)
                if len(state) == 0:
                    state['iter_'] = 1
                    state['exp_avg'] = torch.zeros_like(
                        p.data, memory_format=torch.preserve_format
                    )
                    state['exp_avg_sq'] = torch.zeros_like(
                        p.data, memory_format=torch.preserve_format
                    )
                iter_ = state['iter_']
                exp_avg = state['exp_avg']
                # Deviation of the current gradient from the bias-corrected
                # running mean; on the first step the running mean is zero.
                if iter_ == 1:
                    d_sub_p_sq = d_p - exp_avg
                else:
                    d_sub_p_sq = d_p - exp_avg.mul(1.0 / (1 - beta1 ** (iter_ - 1)))
                d_sub_p_sq.mul_(d_sub_p_sq)
                exp_avg_sq = state['exp_avg_sq']
                # EMA of the gradient (mean) and of the squared deviation (variance).
                exp_avg.mul_(beta1).add_(d_p, alpha=1.0 - beta1)
                exp_avg_sq.mul_(beta2).add_(d_sub_p_sq, alpha=1.0 - beta2)
                # Step: bias-corrected mean divided by bias-corrected std deviation.
                p.data.addcdiv_(exp_avg.mul(1.0 / (1 - beta1 ** iter_)),
                                exp_avg_sq.mul(1.0 / (1 - beta2 ** iter_)).sqrt() + eps,
                                value=-lr)
                state['iter_'] += 1
        return loss
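The optimizer drops into a standard PyTorch training loop. A minimal usage sketch (the linear model and random data here are hypothetical placeholders, not the experimental setup from this post):

import torch
import torch.nn as nn

# Hypothetical toy setup: a small linear regression on random data.
model = nn.Linear(10, 1)
optimizer = SNRAdam(model.parameters(), lr=1e-3, betas=(0.9, 0.999))
loss_fn = nn.MSELoss()

x, y = torch.randn(64, 10), torch.randn(64, 1)
for _ in range(100):
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()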
We ran experiments using a simple 100K-parameter Vision Transformer model (loosely inspired by this medium post) on the MNIST dataset with batch size 4096 over 20 epochs with learning rate 1e-3. The results indicate faster convergence for the proposed optimizer compared to the Adam optimizer (we plot one of the runs but observed the same behavior consistently for this model and dataset combination):
To disentangle the source of the gains, we set the batch size to ∞ (i.e., full-batch gradient descent over the entire training set) and compare the two algorithms. This shows whether the gains come from correcting for the "stochastic" in stochastic gradient descent or for the noise in the gradient arising from the optimization trajectory (the gradient descent portion). We see that the gains come from compensating for the latter (noise in the gradient coming from the trajectory rather than from SGD):
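For concreteness, a minimal sketch of what "batch size = ∞" means here: one optimizer step per epoch over the entire training set, so the only gradient noise left comes from the optimization trajectory itself. The tensors and model below are hypothetical stand-ins, not the ViT/MNIST setup:

import torch
import torch.nn as nn

# Hypothetical stand-ins for the full training set.
x_full, y_full = torch.randn(1000, 10), torch.randint(0, 10, (1000,))
model = nn.Linear(10, 10)
optimizer = SNRAdam(model.parameters(), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()

for epoch in range(20):
    optimizer.zero_grad()
    loss = loss_fn(model(x_full), y_full)  # all examples in a single batch
    loss.backward()
    optimizer.step()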