import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import create_common_tensor
def _output_m_compute(m, beta1_broad, grad):
    """Return the updated first-moment estimate.

    Evaluates ``m_t = m + (beta1 - 1) * (m - grad)``, which is algebraically
    ``beta1 * m + (1 - beta1) * grad``.

    Args:
        m: current first-moment tensor.
        beta1_broad: beta1 already broadcast to the shape of ``m``.
        grad: gradient tensor.

    Returns:
        A new tensor holding ``m_t``; the inputs are not modified.
    """
    # One-element tensor holding -1 in the dtype of ``m``.
    minus_one = torch.ones(1, dtype=m.dtype).mul(-1)
    beta1_minus_one = beta1_broad + minus_one
    m_minus_grad = m - grad
    return m + beta1_minus_one * m_minus_grad
def _output_v_compute(v, beta2, grad):
    """Return the updated second-moment estimate.

    Computes ``v_t = v + (1 - beta2) * (grad * grad - v)``, evaluated here as
    ``v + (beta2 - 1) * (v - grad * grad)``.

    Args:
        v: current second-moment tensor.
        beta2: scalar decay rate; converted to a tensor with the dtype of ``v``.
        grad: gradient tensor.

    Returns:
        A new tensor holding ``v_t``; the inputs are not modified.
    """
    input_dtype = v.dtype
    sneg_one = torch.ones((1), dtype=input_dtype) * -1
    # Broadcast the scalar beta2 to the shape of v before the element-wise math.
    beta2_tensor = torch.tensor(beta2, dtype=input_dtype)
    beta2_broad = beta2_tensor.expand_as(v)
    vsub_beta2_1 = torch.add(beta2_broad, sneg_one)  # beta2 - 1
    vmul_grad_grad = torch.mul(grad, grad)  # grad ** 2
    vsub_v_grad = torch.sub(v, vmul_grad_grad)  # v - grad ** 2
    vmul_grad = torch.mul(vsub_beta2_1, vsub_v_grad)
    v_t = torch.add(v, vmul_grad)
    return v_t
def _inner_eps_add_sqrt_vt_compute(epsilon, v_t):
"""
(epsilon + sqrt(v_t) )
"""
sqrt_vt = torch.sqrt(v_t)
compute_shape = v_t.shape
input_dtype = v_t.dtype
epsilon_tensor = torch.tensor(epsilon, dtype=input_dtype)
epsilon_broad = epsilon_tensor.expand_as(v_t)
v_add_sqrt_v = torch.add(sqrt_vt, epsilon_broad)
return v_add_sqrt_v
def _inner_lr_compute(lr, beta2_power, beta1_power, compute_shape_tensor):
    """Return the bias-corrected learning rate, broadcast to a reference shape.

    Computes ``lr_t = lr * sqrt(1 - beta2_power) / (1 - beta1_power)``.

    Args:
        lr: base learning rate.
        beta2_power: accumulated power of beta2.
        beta1_power: accumulated power of beta1.
        compute_shape_tensor: tensor whose dtype is used for the constants and
            whose shape the result is expanded to.

    Returns:
        ``lr_t`` expanded to the shape of ``compute_shape_tensor``.
    """
    dtype = compute_shape_tensor.dtype
    one = torch.ones((1), dtype=dtype)
    minus_one = torch.ones((1), dtype=dtype) * -1

    def one_minus(value):
        # 1 - value, formed as value * -1 + 1.
        return torch.add(torch.mul(value, minus_one), one)

    numerator = torch.mul(lr, torch.sqrt(one_minus(beta2_power)))
    lr_t = torch.div(numerator, one_minus(beta1_power))
    return lr_t.expand_as(compute_shape_tensor)
def _output_var_t_compute_use_nesterov(var, lr_t, m_t, beta1_broad, grad, epsilon, v_t):
    """Return the updated variable for the Nesterov variant.

    Computes
    ``var_t = var - lr_t * (m_t * beta1 + (1 - beta1) * grad) / (epsilon + sqrt(v_t))``.

    Args:
        var: current variable tensor.
        lr_t: bias-corrected learning rate; expanded to the shape of ``var``.
        m_t: updated first-moment tensor.
        beta1_broad: beta1 as a tensor.
        grad: gradient tensor.
        epsilon: scalar added to ``sqrt(v_t)`` in the denominator.
        v_t: updated second-moment tensor.

    Returns:
        A new tensor holding ``var_t``; the inputs are not modified.
    """
    input_dtype = var.dtype
    s_one = torch.ones((1), dtype=input_dtype)
    s_neg_one = torch.ones((1), dtype=input_dtype) * -1
    v_muls_mt_beta1 = torch.mul(m_t, beta1_broad)
    # (1 - beta1), formed as beta1 * -1 + 1.
    v_neg_beta1 = torch.mul(beta1_broad, s_neg_one)
    vsub_1_beta1 = torch.add(v_neg_beta1, s_one)
    v_mul_grad = torch.mul(vsub_1_beta1, grad)
    # m_t * beta1 + (1 - beta1) * grad
    v_div_left = torch.add(v_muls_mt_beta1, v_mul_grad)
    lrt_broad = lr_t.expand_as(var)
    v_mul_left = torch.mul(lrt_broad, v_div_left)
    v_add_sqrt_v = _inner_eps_add_sqrt_vt_compute(epsilon, v_t)
    v_div_res = torch.div(v_mul_left, v_add_sqrt_v)
    var_t = torch.sub(var, v_div_res)
    return var_t
def _output_var_t_compute(var, lr_t, m_t, epsilon, v_t):
    """Return the updated variable for the plain (non-Nesterov) variant.

    Computes ``var_t = var - lr_t * m_t / (epsilon + sqrt(v_t))``.

    Args:
        var: current variable tensor.
        lr_t: bias-corrected learning rate.
        m_t: updated first-moment tensor.
        epsilon: scalar added to ``sqrt(v_t)`` in the denominator.
        v_t: updated second-moment tensor.

    Returns:
        A new tensor holding ``var_t``.
    """
    scaled_moment = torch.mul(lr_t, m_t)
    denominator = _inner_eps_add_sqrt_vt_compute(epsilon, v_t)
    step = torch.div(scaled_moment, denominator)
    return torch.sub(var, step)
def apply_adam_d(beta1_power, beta2_power, lr, beta1, beta2, epsilon, grad, use_locking, use_nesterov, var, m, v):
    """Compute one Adam step out-of-place and return the updated state.

    Args:
        beta1_power: accumulated power of beta1, used for bias correction.
        beta2_power: accumulated power of beta2, used for bias correction.
        lr: base learning rate.
        beta1: decay rate of the first moment.
        beta2: decay rate of the second moment.
        epsilon: scalar added to ``sqrt(v_t)`` in the denominator.
        grad: gradient tensor.
        use_locking: accepted but not read by this implementation.
        use_nesterov: when truthy, use the Nesterov form of the variable update.
        var: current variable tensor.
        m: current first-moment tensor.
        v: current second-moment tensor.

    Returns:
        The list ``[var_t, m_t, v_t]``.
    """
    beta1_tensor = torch.tensor(beta1, dtype=m.dtype)
    beta1_broad = beta1_tensor.expand_as(m)
    m_t = _output_m_compute(m, beta1_broad, grad)
    v_t = _output_v_compute(v, beta2, grad)
    # ``m`` only supplies the dtype and target shape for the learning rate.
    lr_r = _inner_lr_compute(lr, beta2_power, beta1_power, m)
    # Truthiness test: ``is True`` would silently ignore flags such as 1.
    if use_nesterov:
        var_t = _output_var_t_compute_use_nesterov(var, lr_r, m_t, beta1_broad, grad, epsilon, v_t)
    else:
        var_t = _output_var_t_compute(var, lr_r, m_t, epsilon, v_t)
    res = [var_t, m_t, v_t]
    return res
class TestApplyAdam(TestCase):
    """Compare ``torch_npu.npu_apply_adam`` with the CPU reference ``apply_adam_d``."""

    def test_apply_adam(self):
        """Check the returned second moment for float32 and float16 inputs."""
        shape = (2, 2, 2, 2)
        var_fp32 = torch.randn(*shape, dtype=torch.float32)
        m_fp32 = torch.randn(*shape, dtype=torch.float32)
        v_fp32 = torch.randn(*shape, dtype=torch.float32)
        grad_fp32 = torch.randn(*shape, dtype=torch.float32)

        # Half-precision copies of the same random data.
        var_fp16 = var_fp32.to(torch.half)
        m_fp16 = m_fp32.to(torch.half)
        v_fp16 = v_fp32.to(torch.half)
        grad_fp16 = grad_fp32.to(torch.half)

        # beta1_power, beta2_power, lr, beta1, beta2, epsilon
        hyper = (1, 1, 0.2, 0.2, 0.2, 0.2)

        cpu_fp32 = apply_adam_d(*hyper, grad_fp32, False, False, var_fp32, m_fp32, v_fp32)
        cpu_fp16 = apply_adam_d(*hyper, grad_fp16, False, False, var_fp16, m_fp16, v_fp16)
        npu_fp32 = torch_npu.npu_apply_adam(
            *hyper, grad_fp32.to("npu"), False, False,
            out=(var_fp32.to("npu"), m_fp32.to("npu"), v_fp32.to("npu")))
        npu_fp16 = torch_npu.npu_apply_adam(
            *hyper, grad_fp16.to("npu"), False, False,
            out=(var_fp16.to("npu"), m_fp16.to("npu"), v_fp16.to("npu")))

        # Only the third output (second moment) is compared in this test.
        self.assertRtolEqual(cpu_fp32[2], npu_fp32[2].cpu())
        self.assertRtolEqual(cpu_fp16[2], npu_fp16[2].cpu())

    def test_apply_adam_out_fp32(self):
        """Check both returned moments for float32 inputs."""
        shape = (2, 2, 2, 2)
        var = torch.randn(*shape, dtype=torch.float32)
        m = torch.randn(*shape, dtype=torch.float32)
        v = torch.randn(*shape, dtype=torch.float32)
        grad = torch.randn(*shape, dtype=torch.float32)

        beta1_power, beta2_power = 0.9, 0.9
        lr, beta1, beta2, epsilon = 0.2, 0.2, 0.2, 0.2
        use_locking, use_nesterov = False, False

        cpu_out = apply_adam_d(
            beta1_power, beta2_power, lr, beta1, beta2, epsilon,
            grad, use_locking, use_nesterov, var, m, v)
        npu_out = torch_npu.npu_apply_adam(
            beta1_power, beta2_power, lr, beta1, beta2, epsilon,
            grad.to("npu"), use_locking, use_nesterov,
            out=(var.to("npu"), m.to("npu"), v.to("npu")))

        # The first and second moments are compared; the variable output is not.
        self.assertRtolEqual(cpu_out[1], npu_out[1].cpu())
        self.assertRtolEqual(cpu_out[2], npu_out[2].cpu())
# Run the test cases in this module when the file is executed as a script.
if __name__ == "__main__":
    run_tests()