import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
class TestJitTrace(TestCase):
def test_script_npu_max(self):
class NpuModel(torch.nn.Module):
def __init__(self):
super(NpuModel, self).__init__()
def forward(self, x):
x = torch_npu.npu_max(x, dim=1)
return x
example_input = torch.rand(2, 8).npu()
model = NpuModel().to("npu")
output1 = model(example_input)
script_model = torch.jit.script(model)
output2 = script_model(example_input)
self.assertRtolEqual(output1, output2)
def test_script_npu_bert_apply_adam_out(self):
class NpuModel(torch.nn.Module):
def __init__(self):
super(NpuModel, self).__init__()
def forward(self, grad, var_in, m_in, v_in):
max_grad_norm = -1.
beta1 = 0.9
beta2 = 0.99
weight_decay = 0.
lr = 0.
epsilon = 1e-06
global_grad_norm = 0.
var_out, m_out, v_out = torch_npu.npu_bert_apply_adam(
lr, beta1, beta2, epsilon, grad, max_grad_norm, global_grad_norm, weight_decay,
out=(var_in, m_in, v_in))
return var_out, m_out, v_out
seed = 3
torch.manual_seed(seed)
torch.npu.manual_seed(seed)
torch.npu.manual_seed_all(seed)
var_in = torch.rand(321538).uniform_(-32., 21.).npu()
m_in = torch.zeros(321538).npu()
v_in = torch.zeros(321538).npu()
grad = torch.rand(321538).uniform_(-0.05, 0.03).npu()
model = NpuModel().to("npu")
output1 = model(grad, var_in, m_in, v_in)
script_model = torch.jit.script(model)
seed = 3
torch.manual_seed(seed)
torch.npu.manual_seed(seed)
torch.npu.manual_seed_all(seed)
var_in = torch.rand(321538).uniform_(-32., 21.).npu()
m_in = torch.zeros(321538).npu()
v_in = torch.zeros(321538).npu()
grad = torch.rand(321538).uniform_(-0.05, 0.03).npu()
output2 = script_model(grad, var_in, m_in, v_in)
self.assertRtolEqual(output1, output2)
def test_script_npu_rotary_mul(self):
class NpuModel(torch.nn.Module):
def __init__(self):
super(NpuModel, self).__init__()
def forward(self, x, r1, r2):
x = torch_npu.npu_rotary_mul(x, r1, r2)
return x
x = torch.rand([8192, 2, 5, 128], dtype=torch.float32).npu()
r1 = torch.rand([8192, 1, 1, 128], dtype=torch.float32).npu()
r2 = torch.rand([8192, 1, 1, 128], dtype=torch.float32).npu()
model = NpuModel().to("npu")
output1 = model(x, r1, r2)
script_model = torch.jit.script(model)
output2 = script_model(x, r1, r2)
self.assertRtolEqual(output1, output2)
if __name__ == '__main__':
run_tests()