import torch
import numpy as np
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
# Executed at import time, before any test runs: configure NPU op compilation
# with jit_compile=False for the ops exercised below.
torch.npu.set_compile_mode(jit_compile=False)
class TestNPUDropoutGenMask(TestCase):
    """Tests for ``torch_npu._npu_dropout_gen_mask``."""

    @staticmethod
    def _mask_byte_len(numels, byte_len=8, data_align=128):
        """Return the byte length of a packed mask covering ``numels`` elements.

        The element count (one bit per element) is rounded up to a multiple of
        ``data_align`` bits and then converted to bytes of ``byte_len`` bits.
        """
        return (numels + data_align - 1) // data_align * data_align // byte_len

    def cpu_op_exec_p0(self, x):
        """Build the expected mask for dropout probability 0 (keep everything).

        Every byte that holds at least one element's bit is set to 0xFF; the
        trailing alignment padding is left as 0.

        Args:
            x: input tensor; only its shape is used.

        Returns:
            1-D ``np.uint8`` array of length ``_mask_byte_len(x.numel())``.
        """
        byte_len = 8
        # int() matters for 0-d tensors: np.prod(()) is the float 1.0, which
        # np.zeros would reject as a length.
        numels = int(np.prod(x.shape))
        output = np.zeros(self._mask_byte_len(numels, byte_len), dtype=np.uint8)
        # ceil(numels / byte_len) bytes carry element bits; all of them are kept.
        output[:(numels + byte_len - 1) // byte_len] = 255
        return output

    def npu_op_exec(self, x, p, seed=1, offset=0, parallel=False):
        """Run ``_npu_dropout_gen_mask`` on the NPU and return the mask as numpy.

        Args:
            x: CPU tensor; it is copied to the NPU and its shape is passed as
                the mask size.
            p: dropout probability forwarded to the op.
            seed, offset: forwarded to the op unchanged.
            parallel: forwarded as the op's ``parallel`` keyword.

        Returns:
            The generated mask copied back to the host as a numpy array.
        """
        x_npu = x.to("npu")
        output = torch_npu._npu_dropout_gen_mask(
            x_npu, x.shape, p, seed, offset, parallel=parallel
        )
        return output.to("cpu").numpy()

    def test_dropout_gen_full_mask(self):
        # With p == 0 nothing is dropped, so the mask is fully deterministic.
        h, w = 32, 16
        x = torch.randn(h, w, dtype=torch.float16)
        prob = 0.0
        res = self.npu_op_exec(x, prob)
        res_cpu = self.cpu_op_exec_p0(x)
        self.assertRtolEqual(res, res_cpu)

    def test_dropout_gen_mask(self):
        # Mask bits are random for 0 < p < 1, so check the packed layout
        # (dtype and aligned byte length) rather than the values.
        h, w = 17, 19
        x = torch.randn(h, w, dtype=torch.float32)
        prob = 0.4
        res = self.npu_op_exec(x, prob)
        self.assertEqual(res.dtype, np.uint8)
        self.assertEqual(res.shape, (self._mask_byte_len(h * w),))

    def test_gen_mask_enable_parallel(self):
        # The same seed/offset must yield the same mask regardless of `parallel`.
        h, w = 17, 19
        x = torch.randn(h, w, dtype=torch.float32)
        prob = 0.7
        res = self.npu_op_exec(x, prob, seed=2, offset=100, parallel=True)
        res1 = self.npu_op_exec(x, prob, seed=2, offset=100, parallel=False)
        self.assertRtolEqual(res, res1)
# Script entry point: hand control to the torch_npu test runner.
if __name__ == "__main__":
    run_tests()