import unittest
from copy import deepcopy
import torch
from torch.optim import SGD, Adam, AdamW, Adadelta, RMSprop
from torch_npu.npu.amp import GradScaler, autocast
from torch_npu.optim import (
NpuFusedSGD, NpuFusedAdam, NpuFusedAdamW, NpuFusedAdamP,
NpuFusedLamb, NpuFusedAdadelta, NpuFusedBertAdam,
NpuFusedRMSprop, NpuFusedRMSpropTF
)
from torch_npu.testing.testcase import TestCase, run_tests
class TestFusedOptim(TestCase):
    """Checks the ``torch_npu.optim`` fused optimizers.

    Fused optimizers that have a ``torch.optim`` counterpart are compared
    against it; the remaining ones are compared against recorded baseline
    values (``third_optim_baseline``).
    """

    def __init__(self, method_name='runTest'):
        # Zero-argument super() starts the MRO lookup *after* this class, so
        # TestCase.__init__ runs. super(TestCase, self) would start after
        # TestCase and silently skip its initialiser.
        super().__init__(method_name)

        # Each case is (reference optimizer or None, fused optimizer, kwargs).
        # A reference of None means there is no torch.optim counterpart.
        self.optim_cases = [
            (SGD, NpuFusedSGD, dict(lr=0.01, momentum=0.9, weight_decay=0.001)),
            (Adam, NpuFusedAdam, dict(eps=1e-8, betas=(0.9, 0.999), lr=2e-3, weight_decay=0.05)),
            (AdamW, NpuFusedAdamW, dict(eps=1e-8, betas=(0.9, 0.999), lr=2e-3, weight_decay=0.05)),
            (Adadelta, NpuFusedAdadelta, dict(lr=1.0, rho=0.9, eps=1e-6, weight_decay=0.05)),
            (RMSprop, NpuFusedRMSprop, dict(eps=0.001, lr=0.01, weight_decay=1e-5)),
            (None, NpuFusedAdamP, dict(eps=1e-5, betas=(0.9, 0.999), lr=2e-3, weight_decay=0.05)),
            (None, NpuFusedLamb, dict(lr=0.01, eps=1e-5)),
            (None, NpuFusedBertAdam, dict(lr=0.01, warmup=0.1, t_total=20, max_grad_norm=-1)),
            (None, NpuFusedRMSpropTF, dict(eps=0.001, lr=0.01, weight_decay=1e-5)),
        ]

        # Reference optimizers used to select the "base" subset of cases.
        self.base_cases = [SGD, Adam, AdamW, Adadelta, RMSprop]

        # Expected ``p.sum()`` of each gradient-carrying parameter from
        # _create_simple_params_and_grads, as checked by test_step_3rd_optims.
        self.third_optim_baseline = {
            NpuFusedAdamP: [14.885, 65.714, 14.882, 65.75, 104.615, 152.75],
            NpuFusedBertAdam: [12.982, 61.537, 13.023, 61.5625, 99.305, 146.125],
            NpuFusedLamb: [13.407, 62.683, 13.414, 62.625, 101.258, 149.0],
            NpuFusedRMSpropTF: [14.9797, 65.911, 15.0, 66.0, 104.8588, 153.0],
        }

    def _create_optimizer_cases(self, all_cases=False):
        """Return the optimizer cases to iterate over.

        Args:
            all_cases: when True return every entry of ``self.optim_cases``;
                otherwise only the entries whose reference optimizer is listed
                in ``self.base_cases``.
        """
        if all_cases:
            return self.optim_cases
        return [case for case in self.optim_cases if case[0] in self.base_cases]

    def _create_simple_model(self):
        """Build a small Conv2d -> BatchNorm2d -> ReLU -> Linear model on ``npu:0``."""
        model = torch.nn.Sequential(
            torch.nn.Conv2d(1, 8, kernel_size=3),
            torch.nn.BatchNorm2d(8, momentum=0.3),
            torch.nn.ReLU(),
            torch.nn.Linear(22, 12),
        )
        model.to("npu:0")
        return model

    def _create_simple_params_and_grads(self):
        """Create seven NPU tensors (mixed float/half) to act as parameters.

        All but the last tensor get ``requires_grad=True`` and a gradient
        equal to their own value divided by 100; the last one is left without
        a gradient.
        """
        params = [
            torch.arange(6).reshape(2, 3).float().npu(),
            torch.arange(12).reshape(4, 3).float().npu(),
            torch.arange(6).reshape(2, 3).half().npu(),
            torch.arange(12).reshape(4, 3).half().npu(),
            torch.arange(15).reshape(5, 3).float().npu(),
            torch.arange(18).reshape(6, 3).half().npu(),
            torch.arange(6).reshape(2, 3).float().npu(),
        ]
        # The final tensor deliberately stays gradient-free.
        for p in params[:-1]:
            p.requires_grad = True
            p.grad = p.clone().detach() / 100.
        return params

    def _create_params_clone(self, params):
        """Return detached copies of ``params``, mirroring requires_grad and grad."""
        params_clone = []
        for p in params:
            p_clone = p.clone().detach()
            if p.requires_grad:
                p_clone.requires_grad = True
                p_clone.grad = p.grad.clone().detach()
            params_clone.append(p_clone)
        return params_clone

    def test_zero_grad(self):
        """zero_grad() of a fused optimizer leaves the same grads as the reference."""
        for opt_obj, fused_opt_obj, opt_kwargs in self._create_optimizer_cases():
            params = self._create_simple_params_and_grads()
            params_clone = self._create_params_clone(params)
            with torch.no_grad():
                opt = opt_obj(params, **opt_kwargs)
                opt.zero_grad()
                fused_opt = fused_opt_obj(params_clone, **opt_kwargs)
                fused_opt.zero_grad()
            for p, p_clone in zip(params, params_clone):
                # Parameters whose grad is None after zero_grad() are skipped.
                if p.grad is not None:
                    self.assertEqual(p.grad, p_clone.grad)
                    self.assertEqual(p.grad, torch.zeros_like(p.grad))

    def test_step(self):
        """After several steps, fused and reference optimizers yield close params."""
        num_iters = 10
        for opt_obj, fused_opt_obj, opt_kwargs in self._create_optimizer_cases(all_cases=True):
            if opt_obj is None:
                # No reference optimizer to compare with.
                continue
            params = self._create_simple_params_and_grads()
            params_clone = self._create_params_clone(params)
            opt = opt_obj(params, **opt_kwargs)
            fused_opt = fused_opt_obj(params_clone, **opt_kwargs)
            with torch.no_grad():
                for _ in range(num_iters):
                    opt.step()
                    fused_opt.step()
            for p, p_clone in zip(params, params_clone):
                if p.grad is not None:
                    self.assertRtolEqual(p, p_clone, prec=1e-3)

    def test_step_3rd_optims(self):
        """Fused optimizers without a reference match their recorded baselines."""
        num_iters = 10
        for _, fused_opt_obj, opt_kwargs in self._create_optimizer_cases(all_cases=True):
            if fused_opt_obj not in self.third_optim_baseline:
                continue
            expected_sums = self.third_optim_baseline[fused_opt_obj]
            params = self._create_simple_params_and_grads()
            fused_opt = fused_opt_obj(params, **opt_kwargs)
            with torch.no_grad():
                for _step in range(num_iters):
                    fused_opt.step()
            for i, p in enumerate(params):
                if p.grad is not None:
                    self.assertRtolEqual(p.sum().item(), expected_sums[i])

    def test_unscale(self):
        """GradScaler.unscale_ on a fused optimizer divides grads by the scale."""
        init_scale = 128.0
        model = self._create_simple_model()
        input_tensor = torch.rand(3, 1, 24, 24).to("npu:0")
        for _, fused_opt_obj, opt_kwargs in self._create_optimizer_cases():
            m = deepcopy(model)
            optimizer = fused_opt_obj(m.parameters(), **opt_kwargs)
            t = input_tensor.detach().clone()
            scaler = GradScaler(init_scale=init_scale)
            with autocast():
                output = m(t)
                loss = output.mean()
            scaler.scale(loss).backward()

            # Snapshot the still-scaled gradients before unscale_ touches them.
            grads_before_unscale = {
                p: p.grad.clone().detach()
                for p in m.parameters()
                if p.grad is not None
            }
            scaler.unscale_(optimizer)
            for p in m.parameters():
                if p.grad is not None:
                    self.assertEqual(grads_before_unscale[p] / init_scale, p.grad)

    def test_simple_model_train_dynamic(self):
        """Losses match between reference and fused training with default GradScaler."""
        model = self._create_simple_model()
        num_iters = 10
        for opt_obj, fused_opt_obj, opt_kwargs in self._create_optimizer_cases():
            # Two independent copies of the same initial model.
            m = deepcopy(model)
            opt = opt_obj(m.parameters(), **opt_kwargs)
            scaler = GradScaler()

            m_clone = deepcopy(model)
            opt_fused = fused_opt_obj(m_clone.parameters(), **opt_kwargs)
            scaler_fused = GradScaler()

            for _ in range(num_iters):
                # Both models see the same random batch.
                input_tensor = torch.rand(3, 1, 24, 24).to("npu:0")

                with autocast():
                    output = m(input_tensor)
                    loss = output.mean()
                scaler.scale(loss).backward()
                scaler.step(opt)
                scaler.update()

                with autocast():
                    output_fused = m_clone(input_tensor)
                    loss_fused = output_fused.mean()
                scaler_fused.scale(loss_fused).backward()
                scaler_fused.step(opt_fused)
                scaler_fused.update()

                self.assertRtolEqual(loss, loss_fused)

    def test_simple_model_train_static(self):
        """Losses match between reference and fused training with a static scale of 128."""
        model = self._create_simple_model()
        num_iters = 10
        for opt_obj, fused_opt_obj, opt_kwargs in self._create_optimizer_cases():
            # Two independent copies of the same initial model.
            m = deepcopy(model)
            opt = opt_obj(m.parameters(), **opt_kwargs)
            scaler = GradScaler(dynamic=False, init_scale=128)

            m_clone = deepcopy(model)
            opt_fused = fused_opt_obj(m_clone.parameters(), **opt_kwargs)
            scaler_fused = GradScaler(dynamic=False, init_scale=128)

            for _ in range(num_iters):
                # Both models see the same random batch.
                input_tensor = torch.rand(3, 1, 24, 24).to("npu:0")

                with autocast():
                    output = m(input_tensor)
                    loss = output.float().mean()
                scaler.scale(loss).backward()
                scaler.step(opt)
                scaler.update()

                with autocast():
                    output_fused = m_clone(input_tensor)
                    loss_fused = output_fused.float().mean()
                scaler_fused.scale(loss_fused).backward()
                scaler_fused.step(opt_fused)
                scaler_fused.update()

                self.assertRtolEqual(loss, loss_fused, prec=1e-3)

    def test_clip_grad_norm_fused(self):
        """clip_grad_norm_fused_ agrees with torch.nn.utils.clip_grad_norm_."""
        max_norm = 5.0
        for _, fused_opt_obj, opt_kwargs in self._create_optimizer_cases():
            params = self._create_simple_params_and_grads()
            params_clone = self._create_params_clone(params)
            fused_opt = fused_opt_obj(params_clone, **opt_kwargs)

            grad_norm = torch.nn.utils.clip_grad_norm_(params, max_norm)
            grad_norm_fused = fused_opt.clip_grad_norm_fused_(max_norm)

            for p, p_clone in zip(params, params_clone):
                if p.grad is not None:
                    self.assertRtolEqual(p.grad, p_clone.grad, prec=1e-3)
            self.assertRtolEqual(grad_norm.float(), grad_norm_fused, prec=1e-3)
# Run the test suite only when this file is executed as a script.
if __name__ == "__main__":
    run_tests()