from collections import defaultdict
import math
import torch
from torch.optim.optimizer import required
from torch_npu.utils import npu_combine_tensors
from torch_npu.utils._error_code import ErrCode, pta_error
from .npu_fused_optim_base import NpuFusedOptimizerBase
# Lower bound accepted for the learning rate; NpuFusedSGD rejects smaller values.
LR_MIN = 0.0
# Lower bound (and constructor default) for the momentum factor. Nesterov
# momentum additionally requires a value strictly greater than this.
MOMENTUM_MIN = 0.0
# Constructor default for dampening; Nesterov momentum requires dampening to
# be (numerically) equal to this value.
DAMPENING_DEFAULT = 0.0
# Lower bound (and constructor default) for the weight decay coefficient.
WEIGHT_DECAY_MIN = 0.0
# Public API of this module.
__all__ = ["NpuFusedSGD"]
class NpuFusedSGD(NpuFusedOptimizerBase):
    r"""Stochastic gradient descent (optionally with momentum), fused variant.

    The update itself is applied in ``_group_step`` to the combined tensors
    stored in ``combined_params_indexed_by_group`` and
    ``combined_grads_indexed_by_group`` (these attributes are not defined in
    this class; presumably they are maintained by ``NpuFusedOptimizerBase``).

    Nesterov momentum is based on the formula from the paper
    "On the importance of initialization and momentum in deep learning".

    Args:
        params (iterable): iterable of parameters to optimize or dicts
            defining parameter groups
        lr (float): learning rate
        momentum (float, optional, default=0): momentum factor
        dampening (float, optional, default=0): dampening for momentum
        weight_decay (float, optional, default=0): weight decay (L2 penalty)
        nesterov (bool, optional, default=False): enables Nesterov momentum

    Raises:
        ValueError: if ``lr``, ``momentum`` or ``weight_decay`` is below its
            minimum, or if ``nesterov`` is requested without a positive
            momentum or with a non-zero dampening.
    """

    def __init__(self,
                 params,
                 lr=required,
                 momentum=MOMENTUM_MIN,
                 dampening=DAMPENING_DEFAULT,
                 weight_decay=WEIGHT_DECAY_MIN,
                 nesterov=False):
        """Validate the hyper-parameters and register them as group defaults."""
        if lr is not required and lr < LR_MIN:
            message = "Invalid learning rate: {}".format(lr) + pta_error(ErrCode.VALUE)
            raise ValueError(message)
        if momentum < MOMENTUM_MIN:
            message = "Invalid momentum value: {}".format(momentum) + pta_error(ErrCode.VALUE)
            raise ValueError(message)
        if weight_decay < WEIGHT_DECAY_MIN:
            message = "Invalid weight_decay value: {}".format(weight_decay) + pta_error(ErrCode.VALUE)
            raise ValueError(message)
        if nesterov:
            # The dampening comparison is only evaluated when momentum is
            # positive, mirroring the short-circuit of a single `or` test.
            if (momentum <= MOMENTUM_MIN
                    or not math.isclose(dampening, DAMPENING_DEFAULT, abs_tol=1e-15)):
                message = ("Nesterov momentum requires a momentum and zero dampening"
                           + pta_error(ErrCode.VALUE))
                raise ValueError(message)

        defaults = {
            'lr': lr,
            'momentum': momentum,
            'dampening': dampening,
            'weight_decay': weight_decay,
            'nesterov': nesterov,
        }
        # Must exist before the base constructor runs; it is flipped to True
        # once a step has completed or a pre-existing buffer is found.
        self._momentum_buffer_already_in_state = False
        super().__init__(params, defaults)

    def __setstate__(self, state):
        """Restore optimizer state, defaulting ``nesterov`` to False per group."""
        super().__setstate__(state)
        for param_group in self.param_groups:
            param_group.setdefault('nesterov', False)

    def _init_param_state(self, p, weight_decay):
        """Create or re-materialise the ``momentum_buffer`` entry for ``p``.

        If no buffer exists yet, it is seeded with a detached copy of the
        gradient (plus ``weight_decay * p`` when weight decay is non-zero).
        If a buffer already exists, its values are copied into a fresh
        detached clone of the gradient, which then replaces it.

        Also records in ``_momentum_buffer_already_in_state`` whether the
        buffer pre-existed.
        """
        grad = p.grad
        param_state = self.state[p]

        if 'momentum_buffer' in param_state:
            self._momentum_buffer_already_in_state = True
            # Allocate from the gradient, then overwrite with the stored values.
            restored = torch.clone(grad).detach()
            restored.copy_(param_state['momentum_buffer'])
            param_state['momentum_buffer'] = restored
            return

        self._momentum_buffer_already_in_state = False
        if weight_decay != 0:
            grad = grad.add(p, alpha=weight_decay)
        param_state['momentum_buffer'] = torch.clone(grad).detach()

    def _combine_group_param_states(self, group_index):
        """Build the combined momentum buffers for one parameter group.

        For each parameter list in
        ``params_lists_indexed_by_group[group_index]`` a mapping with a
        single ``'momentum_buffer'`` key is produced: ``None`` when momentum
        is zero or no parameter in the list has a gradient, otherwise the
        result of ``npu_combine_tensors`` over the per-parameter buffers.
        The resulting list is stored in
        ``combined_param_states_indexed_by_group[group_index]``.
        """
        group = self.param_groups[group_index]
        params_by_dtype = self.params_lists_indexed_by_group[group_index]
        weight_decay = group['weight_decay']
        momentum = group['momentum']

        combined_states = []
        for params_one_dtype in params_by_dtype:
            states_one_dtype = defaultdict(dict)
            if momentum == 0:
                # Without momentum there is no buffer to keep.
                states_one_dtype['momentum_buffer'] = None
            else:
                buffers = []
                for p in params_one_dtype:
                    if p.grad is not None:
                        self._init_param_state(p, weight_decay)
                        buffers.append(self.state[p]['momentum_buffer'])
                states_one_dtype['momentum_buffer'] = (
                    npu_combine_tensors(buffers) if buffers else None)
            combined_states.append(states_one_dtype)

        self.combined_param_states_indexed_by_group[group_index] = combined_states

    def _maybe_init_combined_states(self):
        """Combine the per-parameter states of every group, at most once.

        Does nothing when ``is_states_combined`` is already set. Otherwise
        rebuilds ``combined_param_states_indexed_by_group`` and sets
        ``is_states_combined`` if at least one group entry was produced.
        """
        if self.is_states_combined:
            return

        num_groups = len(self.param_groups)
        self.combined_param_states_indexed_by_group = [None] * num_groups
        for group_index in range(num_groups):
            self._combine_group_param_states(group_index)

        if any(entry is not None
               for entry in self.combined_param_states_indexed_by_group):
            self.is_states_combined = True

    def _group_step(self, group_index):
        """Apply one SGD update to the combined tensors of a parameter group."""
        group = self.param_groups[group_index]
        weight_decay = group['weight_decay']
        momentum = group['momentum']
        dampening = group['dampening']
        nesterov = group['nesterov']

        params_per_dtype = self.combined_params_indexed_by_group[group_index]
        grads_per_dtype = self.combined_grads_indexed_by_group[group_index]
        states_per_dtype = self.combined_param_states_indexed_by_group[group_index]

        for param, grad, state in zip(params_per_dtype, grads_per_dtype,
                                      states_per_dtype):
            if param is None or grad is None:
                continue

            if weight_decay != 0:
                grad = grad.add(param, alpha=weight_decay)

            if momentum != 0:
                buf = state['momentum_buffer']
                # A buffer freshly seeded by _init_param_state is used as-is;
                # the accumulation only happens once a buffer pre-exists.
                if self._momentum_buffer_already_in_state:
                    buf.mul_(momentum)
                    buf.add_(grad, alpha=1 - dampening)
                grad = grad.add(buf, alpha=momentum) if nesterov else buf

            step_size = -group['lr']
            param.add_(grad, alpha=step_size)

    def step(self, closure=None):
        """Run the base-class step, then mark momentum buffers as initialised.

        Args:
            closure: forwarded unchanged to the base-class ``step``.

        Returns:
            Whatever the base-class ``step`` returns.
        """
        result = super().step(closure)
        self._momentum_buffer_already_in_state = True
        return result