from typing import List
import torch
import torch.distributed
import torch_npu
from torch import Tensor
from .utils import TensorManager, ShareMemory
class CompressTensor(TensorManager):
    """Manage one optimizer momentum tensor that may be held in encoded form.

    Used for the first- and second-order momentum states. A tensor whose
    element count passes :meth:`can_be_compress` is encoded after every
    optimizer update and its dense copy is dropped; any other tensor is kept
    as a plain dense buffer.

    ``malloc``, ``encode`` and ``decode`` are inherited from ``TensorManager``
    and are not visible in this file; presumably they reserve the shared
    buffer and pack/unpack ``self.tensor`` -- confirm against ``.utils``.
    """

    def __init__(self, like_tensor: torch.Tensor, compress_ratio: float = 0.5) -> None:
        """Create a zero-filled state shaped like ``like_tensor``.

        Args:
            like_tensor: Tensor whose shape, dtype and device the state copies.
            compress_ratio: Value in ``[0, 1]`` forwarded to ``TensorManager``.

        Raises:
            ValueError: If ``compress_ratio`` is outside ``[0, 1]`` (NaN
                included).
        """
        # Negated chained comparison so that NaN is rejected as well.
        if not 0 <= compress_ratio <= 1:
            raise ValueError(
                f"compress_ratio must be within [0, 1], got {compress_ratio!r}"
            )
        super().__init__(like_tensor, compress_ratio)
        self.normal_init(like_tensor)
        self.filter_condition = self.can_be_compress(like_tensor)
        if self.filter_condition:
            self.compress_init(like_tensor)
        # Remembered so decode_state() can re-allocate the dense tensor later.
        self.shape, self.dtype, self.device = like_tensor.shape, like_tensor.dtype, like_tensor.device

    def normal_init(self, like_tensor: torch.Tensor):
        """Allocate the dense state as zeros matching ``like_tensor``."""
        self.tensor = torch.zeros(like_tensor.shape, device=like_tensor.device, dtype=like_tensor.dtype)

    def compress_init(self, like_tensor: torch.Tensor):
        """Register a half-sized ``ShareMemory`` buffer and run a first encode.

        The buffer is requested with ``numel // 2`` elements of the same dtype
        and marked ``can_be_used = False`` before being handed to ``malloc``.
        NOTE(review): the meaning of ``can_be_used`` and of ``statistic=True``
        is defined in ``.utils`` -- confirm there.
        """
        var = ShareMemory(numel=like_tensor.numel() // 2, dtype=like_tensor.dtype)
        var.can_be_used = False
        self.malloc(var, statistic=True)
        self.encode()

    def encode_state(self):
        """Encode the current state, then drop the dense tensor reference."""
        self.encode()
        self.tensor = None

    def decode_state(self):
        """Allocate an uninitialised dense tensor and call ``decode()``.

        The tensor is created with the shape, dtype and device recorded in
        ``__init__``; ``decode()`` presumably fills it -- confirm in ``.utils``.
        """
        self.tensor = torch.empty(self.shape, dtype=self.dtype, device=self.device)
        self.decode()

    def recover(self):
        """Return the dense state tensor, decoding it first when compressed."""
        if self.filter_condition:
            self.decode_state()
        return self.tensor

    def update(self, step):
        """Re-encode the state (when compressed) and refresh the statistic flag.

        Args:
            step: Current optimizer step number.
        """
        if self.filter_condition:
            self.encode_state()
        self.adjust_pdf_statistic(step)

    def adjust_pdf_statistic(self, step):
        """Set ``self.statistic`` to True for ``step < 3`` and every 100th step.

        Args:
            step: Current optimizer step number.
        """
        self.statistic = bool(step < 3 or step % 100 == 0)

    @staticmethod
    def can_be_compress(tensor: torch.Tensor):
        """Return True when ``tensor`` has more than 32768 elements and the
        element count is a multiple of 64."""
        return tensor.numel() % 64 == 0 and tensor.numel() > 32768
def compress_adamw_impl(params: List[Tensor],
                        grads: List[Tensor],
                        exp_avgs: List[CompressTensor],
                        exp_avg_sqs: List[CompressTensor],
                        max_exp_avg_sqs: List[Tensor],
                        step: int,
                        *,
                        amsgrad: bool,
                        beta1: float,
                        beta2: float,
                        lr: float,
                        weight_decay: float,
                        eps: float,
                        maximize: bool):
    r"""Apply one AdamW update to every parameter via ``npu_apply_adam_w``.

    For each parameter the two momentum states are recovered to dense
    tensors, updated in place by the NPU kernel together with ``param.data``,
    and then handed back to their ``CompressTensor`` through ``update(step)``.
    See :class:`~torch.optim.AdamW` for the algorithm itself.

    Args:
        params: Parameters to update; ``param.data`` is overwritten.
        grads: Gradients, index-aligned with ``params``.
        exp_avgs: First-moment states, index-aligned with ``params``.
        exp_avg_sqs: Second-moment states, index-aligned with ``params``.
        max_exp_avg_sqs: Accepted for signature compatibility; not read here.
        step: Current step number; ``beta ** (step - 1)`` is what gets passed
            to the kernel as the beta powers.
        amsgrad, beta1, beta2, lr, weight_decay, eps, maximize: Forwarded
            unchanged to ``torch_npu.npu_apply_adam_w``.
    """
    for idx, param in enumerate(params):
        gradient = grads[idx]
        first_moment = exp_avgs[idx].recover()
        second_moment = exp_avg_sqs[idx].recover()

        beta1_power = beta1 ** (step - 1)
        beta2_power = beta2 ** (step - 1)

        # NOTE(review): None is passed in the max_grad_norm slot even when
        # amsgrad is true -- confirm this is what the kernel expects.
        updated = torch_npu.npu_apply_adam_w(
            beta1_power,
            beta2_power,
            lr,
            weight_decay,
            beta1,
            beta2,
            eps,
            gradient,
            None,
            amsgrad,
            maximize,
            out=(param.data, first_moment, second_moment)
        )
        param.data, first_moment, second_moment = updated

        # Hand the freshly updated moments back so they can be re-encoded.
        exp_avgs[idx].update(step)
        exp_avg_sqs[idx].update(step)
def compress_optimizer_step_impl(self, closure=None):
    """Perform a single optimization step using ``CompressTensor`` states.

    For every parameter group this bumps the group's ``'step'`` counter,
    gathers the parameters that have a gradient, lazily creates their
    momentum states on first use, and delegates the update to
    ``compress_adamw_impl``.

    Args:
        closure: Optional callable that re-evaluates the model and returns
            the loss. It is invoked with gradients enabled.

    Returns:
        The value returned by ``closure``, or ``None`` if no closure is given.

    Raises:
        RuntimeError: If a parameter has a sparse gradient.
    """
    loss = None
    if closure is not None:
        with torch.enable_grad():
            loss = closure()

    for group in self.param_groups:
        params_with_grad = []
        grads = []
        exp_avgs = []
        exp_avg_sqs = []
        max_exp_avg_sqs = []
        amsgrad = group['amsgrad']
        beta1, beta2 = group['betas']

        # The step counter is kept per parameter group, not per parameter.
        if 'step' in group:
            group['step'] += 1
        else:
            group['step'] = 1

        for p in group['params']:
            if p.grad is None:
                continue
            params_with_grad.append(p)
            if p.grad.is_sparse:
                raise RuntimeError('AdamW does not support sparse gradients')
            grads.append(p.grad)

            state = self.state[p]
            # Lazy state initialisation on the first step that sees a gradient.
            if not state:
                state['exp_avg'] = CompressTensor(p)
                state['exp_avg_sq'] = CompressTensor(p, 0.4)
                if amsgrad:
                    state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)

            exp_avgs.append(state['exp_avg'])
            exp_avg_sqs.append(state['exp_avg_sq'])
            if amsgrad:
                max_exp_avg_sqs.append(state['max_exp_avg_sq'])

        compress_adamw_impl(params_with_grad,
                            grads,
                            exp_avgs,
                            exp_avg_sqs,
                            max_exp_avg_sqs,
                            group['step'],
                            amsgrad=amsgrad,
                            beta1=beta1,
                            beta2=beta2,
                            lr=group['lr'],
                            weight_decay=group['weight_decay'],
                            eps=group['eps'],
                            maximize=group['maximize'])

    return loss