import unittest
import numpy as np
import torch
import torch.nn.functional as F
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import create_common_tensor
from torch_npu.contrib.module import Mish, SiLU
from torch_npu.contrib.function.fused_attention import _check_compatibility_once
from torch_npu.contrib.function.fused_attention import _is_format_matched
class TestActivations(TestCase):
    """Tests for the contrib ``Mish``/``SiLU`` modules and fused-attention checks.

    The activation tests run the same input through a CPU reference formula
    and through the ``torch_npu.contrib.module`` implementation, then compare
    both the forward output and the input gradient.
    """

    def _assert_matches_cpu(self, cpu_fn, npu_fn, shape_format):
        """Compare a CPU reference against an NPU implementation.

        Args:
            cpu_fn: callable taking the CPU tensor and returning
                ``(output, input_grad)``.
            npu_fn: callable taking the NPU tensor and returning
                ``(output, input_grad)`` already moved to the CPU.
            shape_format: iterable of ``[dtype, format, shape]`` items passed
                to ``create_common_tensor``.
        """
        for item in shape_format:
            # NOTE(review): the 1 and 10 are presumably the value range of the
            # generated data -- confirm against create_common_tensor.
            cpu_input, npu_input = create_common_tensor(item, 1, 10)
            is_half = cpu_input.dtype == torch.float16
            if is_half:
                # The CPU reference is evaluated in fp32 for half inputs ...
                cpu_input = cpu_input.float()
            cpu_output, cpu_inputgrad = cpu_fn(cpu_input)
            if is_half:
                # ... and cast back so both sides are compared as fp16.
                cpu_output = cpu_output.half()
                cpu_inputgrad = cpu_inputgrad.half()
            npu_output, npu_inputgrad = npu_fn(npu_input)
            self.assertRtolEqual(cpu_output, npu_output)
            self.assertRtolEqual(cpu_inputgrad, npu_inputgrad)

    def cpu_mish(self, input1):
        """Reference mish forward/backward.

        Applies the mish function element-wise:
        mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))

        Note: sets ``input1.requires_grad`` in place.

        Returns:
            Tuple ``(output, input_grad)``; the output is detached and the
            gradient is taken with an all-ones upstream gradient.
        """
        input1.requires_grad = True
        res = input1 * torch.tanh(F.softplus(input1))
        res.backward(torch.ones_like(res))
        return res.detach(), input1.grad

    def npu_mish(self, input1):
        """Run the contrib ``Mish`` module forward/backward on ``input1``.

        Note: sets ``input1.requires_grad`` in place.

        Returns:
            Tuple ``(output, input_grad)``, both moved to the CPU.
        """
        input1.requires_grad = True
        model = Mish()
        res = model(input1)
        res.backward(torch.ones_like(res))
        return res.detach().cpu(), input1.grad.cpu()

    def test_mish(self):
        """``Mish`` matches the CPU formula for every dtype/format/shape combo."""
        dtype_list = [np.float16, np.float32]
        format_list = [-1, 0, 2]
        shape_list = [
            [4],
            [2, 3],
            [6, 5, 8, 10],
            [1, 2, 3, 6, 6],
            [2, 5, 6, 8, 9, 2],
            [2, 5, 6, 8, 9, 2, 2],
        ]
        shape_format = [
            [i, j, k] for i in dtype_list for j in format_list for k in shape_list
        ]
        self._assert_matches_cpu(self.cpu_mish, self.npu_mish, shape_format)

    def cpu_silu(self, input1):
        """Reference SiLU forward/backward.

        Applies the silu function element-wise:
        silu(x) = x * sigmoid(x)

        Note: sets ``input1.requires_grad`` in place.

        Returns:
            Tuple ``(output, input_grad)``; the output is detached and the
            gradient is that of ``output.sum()``.
        """
        input1.requires_grad = True
        res = input1 * torch.sigmoid(input1)
        output = res.sum()
        output.backward()
        return res.detach(), input1.grad

    def npu_silu(self, input1):
        """Run the contrib ``SiLU`` module forward/backward on ``input1``.

        Note: sets ``input1.requires_grad`` in place.

        Returns:
            Tuple ``(output, input_grad)``, both moved to the CPU.
        """
        input1.requires_grad = True
        model = SiLU()
        res = model(input1)
        output = res.sum()
        output.backward()
        return res.detach().cpu(), input1.grad.cpu()

    def test_silu(self):
        """``SiLU`` matches the CPU formula for every dtype/format/shape combo."""
        dtype_list = [np.float32, np.float16]
        format_list = [-1, 0, 2]
        shape_list = [
            [5],
            [2, 3],
            [6, 5, 2, 10],
            [1, 2, 4, 6, 6],
            [2, 5, 6, 2, 9, 2],
            [2, 5, 6, 3, 9, 2, 2],
        ]
        shape_format = [
            [i, j, k] for i in dtype_list for j in format_list for k in shape_list
        ]
        self._assert_matches_cpu(self.cpu_silu, self.npu_silu, shape_format)

    def test_check_compatibility_once_invalid_hidden_states_shape(self):
        """``_check_compatibility_once`` raises ``RuntimeError`` for these inputs.

        ``hidden_states`` is (30, 1024) and ``attention_mask`` is (2, 1, 8, 8);
        per the test name, the hidden-states shape is the invalid part.
        """
        hidden_states = torch_npu.npu_format_cast(torch.randn(30, 1024).npu(), 29)
        attention_mask = torch_npu.npu_format_cast(torch.randn(2, 1, 8, 8).npu(), 29)
        query_kernel = torch_npu.npu_format_cast(torch.randn(1024, 1024).npu(), 29)
        key_kernel = torch_npu.npu_format_cast(torch.randn(1024, 1024).npu(), 29)
        value_kernel = torch_npu.npu_format_cast(torch.randn(1024, 1024).npu(), 29)
        query_bias = torch_npu.npu_format_cast(torch.randn(1024).npu(), 2)
        key_bias = torch_npu.npu_format_cast(torch.randn(1024).npu(), 2)
        value_bias = torch_npu.npu_format_cast(torch.randn(1024).npu(), 2)
        with self.assertRaises(RuntimeError):
            _check_compatibility_once(
                hidden_states,
                attention_mask,
                query_kernel,
                key_kernel,
                value_kernel,
                query_bias,
                key_bias,
                value_bias,
            )

    def test_is_format_matched_invalid(self):
        """``_is_format_matched`` returns a falsy value for this format mix."""
        # Format id requested for each of the eight tensors, in call order.
        # NOTE(review): presumably the trailing 29 is the deliberate mismatch
        # -- confirm against _is_format_matched's expected layout.
        requested_formats = [29, 29, 29, 29, 29, 2, 2, 29]
        tensors = [
            torch_npu.npu_format_cast(torch.randn(4, 4).npu(), fmt)
            for fmt in requested_formats
        ]
        result = _is_format_matched(tensors)
        self.assertFalse(result)
# Run the suite through torch_npu's test runner when executed as a script.
if __name__ == "__main__":
    run_tests()