import math
import unittest
import numpy as np
import torch
import torch_npu
from torch_npu.testing.testcase import TestCase, run_tests
from torch_npu.testing.common_utils import SupportedDevices
def silu(x):
    """Return ``x * (1 / (1 + exp(-x)))`` computed element-wise with torch ops.

    The expression is evaluated step by step in the dtype of ``x``; it is
    used as the reference "silu" activation by the FFN tests in this file.
    """
    denominator = 1 + torch.exp(-x)
    gate = 1 / denominator
    return x * gate
class TestFFN(TestCase):
    """Compare ``torch_npu.npu_ffn`` with a reference FFN assembled from plain torch ops.

    ``supported_op_exec`` / ``calc_ffn`` build the reference result,
    ``custom_op_exec`` calls the fused operator, and every test asserts that
    the two agree within a relative tolerance of 0.001.
    """

    # ------------------------------------------------------------------
    # Data helpers
    # ------------------------------------------------------------------
    def deq_scale_generate(self, deq_scale_shape):
        """Generate random dequantization scales in two encodings.

        Args:
            deq_scale_shape: Shape passed to ``np.random.uniform``.

        Returns:
            tuple: ``(fp32_deq_scale, uint64_deq_scale)``. The first is a flat
            float32 array (consumed by the reference path); the second holds
            the same 32-bit patterns zero-extended into uint64 with the
            requested shape (consumed by ``custom_op_exec``).
        """
        fp32_deq_scale = np.random.uniform(low=0.01, high=0.05, size=deq_scale_shape).astype(np.float32)
        # Reinterpret the float32 bytes as uint32; this view shares memory with
        # the float array, so the mask below edits the float values in place.
        uint32_deq_scale = np.frombuffer(fp32_deq_scale, np.uint32).reshape(deq_scale_shape)
        # Clear the low 13 bits of every 32-bit pattern.
        # NOTE(review): presumably this matches the precision the operator
        # keeps for its dequant scale -- confirm against the npu_ffn docs.
        uint32_deq_scale &= 0XFFFFE000
        fp32_deq_scale = np.frombuffer(uint32_deq_scale, np.float32)
        # Zero-extend the masked bit patterns to 64 bits.
        uint64_deq_scale = uint32_deq_scale.astype(np.uint64)
        return fp32_deq_scale, uint64_deq_scale

    @staticmethod
    def _dequantize_weight(weight, antiquant_scale, antiquant_offset):
        """Turn a NumPy weight into float32 as ``(weight + offset) * scale``."""
        weight = torch.from_numpy(weight).to(torch.float16)
        antiquant_offset = torch.broadcast_to(antiquant_offset, weight.size())
        antiquant_scale = torch.broadcast_to(antiquant_scale, weight.size())
        # The weight is moved to the NPU before being combined with offset/scale.
        weight = (weight.npu() + antiquant_offset) * antiquant_scale
        return weight.to(torch.float32)

    @staticmethod
    def _apply_deq_scale(mm_res, deq_scale):
        """Multiply a matmul result by its dequantization scale and cast to float16."""
        # Reshape to a single row and keep only as many columns as the result has.
        deq_scale = deq_scale.reshape(1, -1)[:, :mm_res.shape[-1]]
        deq_scale = torch.from_numpy(deq_scale)
        return (mm_res * deq_scale).to(torch.float16)

    # ------------------------------------------------------------------
    # Reference implementation
    # ------------------------------------------------------------------
    def calc_ffn(self, x, weight1, weight2, activation, *, antiquant_scale1=None, antiquant_scale2=None,
                 antiquant_offset1=None, antiquant_offset2=None, scale=None, offset=None, deq_scale1=None,
                 deq_scale2=None):
        """Reference FFN: ``activation(x @ weight1) @ weight2``.

        Three modes are selected by the keyword arguments:

        * ``antiquant_scale1`` given: ``weight1``/``weight2`` are NumPy arrays
          that are dequantized with the antiquant scale/offset and the matmuls
          run in float32.
        * ``scale`` given: ``x`` and the NumPy weights are cast to int32, each
          matmul result is multiplied by ``deq_scale1``/``deq_scale2`` and the
          activation output is re-quantized with ``scale``/``offset``.
        * neither given: inputs are used as passed.

        Args:
            x: Input tensor.
            weight1: First weight (tensor, or NumPy array in the two quantized modes).
            weight2: Second weight (same convention as ``weight1``).
            activation: One of ``"relu"``, ``"gelu"``, ``"silu"``.

        Returns:
            torch.Tensor: The float16 result of the second matmul.

        Raises:
            ValueError: If ``activation`` is not one of the supported names.
        """
        quantized = scale is not None

        if antiquant_scale1 is not None:
            x = x.to(torch.float32)
            weight1 = self._dequantize_weight(weight1, antiquant_scale1, antiquant_offset1)
            weight2 = self._dequantize_weight(weight2, antiquant_scale2, antiquant_offset2)
        elif quantized:
            x = x.to(torch.int32)
            weight1 = torch.from_numpy(weight1).to(torch.int32)
            weight2 = torch.from_numpy(weight2).to(torch.int32)

        mm1_res = x @ weight1
        if quantized:
            mm1_res = self._apply_deq_scale(mm1_res, deq_scale1)

        if activation == "relu":
            activation_res = torch.nn.functional.relu(mm1_res)
        elif activation == "gelu":
            activation_res = torch.nn.functional.gelu(mm1_res)
        elif activation == "silu":
            activation_res = silu(mm1_res)
        else:
            # Fail with a clear message instead of an UnboundLocalError later on.
            raise ValueError(f"unsupported activation: {activation!r}")

        if quantized:
            # Re-quantize the activation output: scale, shift, round, then
            # clamp to [-128, 127] before the integer matmul.
            scale = torch.from_numpy(scale).to(torch.float16)
            offset = torch.from_numpy(offset).to(torch.float16)
            activation_res = activation_res * scale + offset
            activation_res = torch.round(activation_res, decimals=0)
            activation_res = activation_res.clamp(-128, 127).to(torch.int32)

        mm2_res = torch.matmul(activation_res, weight2).to(torch.float16)
        if quantized:
            mm2_res = self._apply_deq_scale(mm2_res, deq_scale2)
        return mm2_res

    def supported_op_exec(self, x, weight1, weight2, activation, *, expert_tokens=None, **kwargs):
        """Compute the reference output, optionally per expert.

        When ``expert_tokens`` is truthy, ``x`` is split along dim 0 into
        chunks of those sizes, chunk ``i`` is run through ``calc_ffn`` with
        ``weight1[i]`` / ``weight2[i]``, and the results are concatenated.
        Otherwise ``calc_ffn`` is called once on the full input.
        """
        if not expert_tokens:
            return self.calc_ffn(x, weight1, weight2, activation, **kwargs)

        chunks = x.split(expert_tokens, dim=0)
        outputs = [
            self.calc_ffn(chunk, weight1[idx], weight2[idx], activation, **kwargs)
            for idx, chunk in enumerate(chunks)
        ]
        return torch.cat(outputs)

    # ------------------------------------------------------------------
    # Operator under test
    # ------------------------------------------------------------------
    def custom_op_exec(self, x, weight1, weight2, activation, *, expert_tokens=None, expert_tokens_index=None,
                       antiquant_scale1=None, antiquant_scale2=None, antiquant_offset1=None, antiquant_offset2=None,
                       scale=None, offset=None, deq_scale1=None, deq_scale2=None, output_dtype=None):
        """Call ``torch_npu.npu_ffn`` with the arguments relevant to the selected mode.

        The mode is chosen like in ``calc_ffn``: antiquant when
        ``antiquant_scale1`` is given, quant when ``scale`` is given, plain
        otherwise. ``inner_precise=1`` is always passed. ``output_dtype`` is
        forwarded only in the antiquant and quant modes.
        """
        if antiquant_scale1 is not None:
            return torch_npu.npu_ffn(x, weight1, weight2, activation, expert_tokens=expert_tokens,
                                     expert_tokens_index=expert_tokens_index, antiquant_scale1=antiquant_scale1,
                                     antiquant_scale2=antiquant_scale2, antiquant_offset1=antiquant_offset1,
                                     antiquant_offset2=antiquant_offset2, inner_precise=1,
                                     output_dtype=output_dtype)

        if scale is not None:
            # Quant mode receives host-side data: move everything to the NPU.
            x = x.npu()
            weight1 = weight1.npu()
            weight2 = weight2.npu()
            scale = scale.npu()
            offset = offset.npu()
            # The uint64 NumPy scales are converted to int64 before being wrapped as tensors.
            deq_scale1 = torch.from_numpy(deq_scale1.astype(np.int64)).npu()
            deq_scale2 = torch.from_numpy(deq_scale2.astype(np.int64)).npu()
            return torch_npu.npu_ffn(x, weight1, weight2, activation, expert_tokens=expert_tokens,
                                     expert_tokens_index=expert_tokens_index, scale=scale, offset=offset,
                                     deq_scale1=deq_scale1, deq_scale2=deq_scale2, inner_precise=1,
                                     output_dtype=output_dtype)

        return torch_npu.npu_ffn(x, weight1, weight2, activation, expert_tokens=expert_tokens,
                                 expert_tokens_index=expert_tokens_index, inner_precise=1)

    # ------------------------------------------------------------------
    # Shared test bodies
    # ------------------------------------------------------------------
    def _check_expert_tokens_case(self, *, use_index):
        """Run the 16-expert float16 case.

        Args:
            use_index: When true the operator receives ``expert_tokens_index``
                (the cumulative sum of the per-expert counts); otherwise it
                receives ``expert_tokens`` itself. The reference path always
                uses the per-expert counts.
        """
        torch.manual_seed(0)
        num_tokens, inner_dim, hidden_dim, num_experts = 340, 2560, 5120, 16
        x = torch.randn(num_tokens, hidden_dim, dtype=torch.float16).npu()
        weight1 = (torch.randn(num_experts, hidden_dim, inner_dim, dtype=torch.float16) * 0.01).npu()
        weight2 = (torch.randn(num_experts, inner_dim, hidden_dim, dtype=torch.float16) * 0.01).npu()
        # Per-expert row counts; they add up to the 340 rows of ``x``.
        expert_tokens = [50, 15, 4, 15, 20, 21, 36, 15, 25, 21, 15, 10, 19, 30, 24, 20]
        if use_index:
            custom_kwargs = {"expert_tokens_index": np.cumsum(expert_tokens).tolist()}
        else:
            custom_kwargs = {"expert_tokens": expert_tokens}

        x_clone = x.clone()
        weight1_clone = weight1.clone()
        weight2_clone = weight2.clone()
        activation = "gelu"

        supported_output = self.supported_op_exec(x, weight1, weight2, activation, expert_tokens=expert_tokens)
        custom_output = self.custom_op_exec(x_clone, weight1_clone, weight2_clone, activation, **custom_kwargs)

        self.assertRtolEqual(x, x_clone, 0.001)
        self.assertRtolEqual(supported_output, custom_output, 0.001)

    def _check_quant_case(self, output_dtype=None):
        """Run the int8 quant case, forwarding ``output_dtype`` to the operator."""
        torch.manual_seed(0)
        np.random.seed(0)
        x = torch.from_numpy(np.random.randint(-128, 127, size=(8192, 320), dtype=np.int8))
        weight1 = np.random.randint(-128, 127, size=(320, 2560), dtype=np.int8)
        weight2 = np.random.randint(-128, 127, size=(2560, 320), dtype=np.int8)
        scale = np.ones(1, dtype=np.float32) * 0.001
        offset = np.zeros(1, dtype=np.float32)
        deq_scale1_fp32, deq_scale1_uint64 = self.deq_scale_generate((2560,))
        deq_scale2_fp32, deq_scale2_uint64 = self.deq_scale_generate((320,))

        x_clone = x.clone().npu()
        weight1_clone = torch.from_numpy(weight1)
        weight2_clone = torch.from_numpy(weight2)
        scale_clone = torch.from_numpy(scale)
        offset_clone = torch.from_numpy(offset)
        activation = "gelu"

        # The reference consumes the float32 scales, the operator the uint64 encoding.
        supported_output = self.supported_op_exec(x, weight1, weight2, activation, scale=scale, offset=offset,
                                                  deq_scale1=deq_scale1_fp32, deq_scale2=deq_scale2_fp32)
        custom_output = self.custom_op_exec(x_clone, weight1_clone, weight2_clone, activation, scale=scale_clone,
                                            offset=offset_clone, deq_scale1=deq_scale1_uint64,
                                            deq_scale2=deq_scale2_uint64, output_dtype=output_dtype)

        self.assertRtolEqual(x, x_clone, 0.001)
        self.assertRtolEqual(supported_output, custom_output, 0.001)

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------
    @SupportedDevices(['Ascend910B'])
    def test_npu_ffn(self, device="npu"):
        """Plain float16 FFN with the silu activation."""
        torch.manual_seed(0)
        x = torch.randn(8192, 320, dtype=torch.float16).npu()
        weight1 = (torch.randn(320, 2560, dtype=torch.float16) * 0.01).npu()
        weight2 = (torch.randn(2560, 320, dtype=torch.float16) * 0.01).npu()
        x_clone = x.clone()
        weight1_clone = weight1.clone()
        weight2_clone = weight2.clone()
        activation = "silu"

        supported_output = self.supported_op_exec(x, weight1, weight2, activation)
        custom_output = self.custom_op_exec(x_clone, weight1_clone, weight2_clone, activation)

        self.assertRtolEqual(x, x_clone, 0.001)
        self.assertRtolEqual(supported_output, custom_output, 0.001)

    @SupportedDevices(['Ascend910B'])
    def test_npu_ffn_expert_tokens(self, device="npu"):
        """Multi-expert FFN where the operator receives per-expert token counts."""
        self._check_expert_tokens_case(use_index=False)

    @SupportedDevices(['Ascend910B'])
    def test_npu_ffn_expert_tokens_index(self, device="npu"):
        """Multi-expert FFN where the operator receives cumulative token indices."""
        self._check_expert_tokens_case(use_index=True)

    @SupportedDevices(['Ascend910B'])
    def test_npu_ffn_antiquant(self, device="npu"):
        """Antiquant mode: int8 weights with float16 scale/offset, relu activation."""
        torch.manual_seed(0)
        np.random.seed(0)
        x = torch.normal(mean=0., std=0.02, size=(8192, 320), dtype=torch.float16).npu()
        weight1 = np.random.randint(-9, 9, size=(320, 2560), dtype=np.int8)
        weight2 = np.random.randint(-9, 9, size=(2560, 320), dtype=np.int8)
        antiquant_scale1 = torch.normal(mean=0., std=0.02, size=(2560,), dtype=torch.float16).npu()
        antiquant_scale2 = torch.normal(mean=0., std=0.02, size=(320,), dtype=torch.float16).npu()
        antiquant_offset1 = torch.normal(mean=0., std=0.02, size=(2560,), dtype=torch.float16).npu()
        antiquant_offset2 = torch.normal(mean=0., std=0.02, size=(320,), dtype=torch.float16).npu()

        reference_kwargs = {
            "antiquant_scale1": antiquant_scale1,
            "antiquant_scale2": antiquant_scale2,
            "antiquant_offset1": antiquant_offset1,
            "antiquant_offset2": antiquant_offset2,
        }
        # The operator gets its own copies of every tensor argument.
        x_clone = x.clone()
        weight1_clone = torch.from_numpy(weight1).npu()
        weight2_clone = torch.from_numpy(weight2).npu()
        custom_kwargs = {name: tensor.clone() for name, tensor in reference_kwargs.items()}
        activation = "relu"

        supported_output = self.supported_op_exec(x, weight1, weight2, activation, **reference_kwargs)
        custom_output = self.custom_op_exec(x_clone, weight1_clone, weight2_clone, activation, **custom_kwargs)

        self.assertRtolEqual(x, x_clone, 0.001)
        self.assertRtolEqual(supported_output, custom_output, 0.001)

    @SupportedDevices(['Ascend910B'])
    def test_npu_ffn_quant(self, device="npu"):
        """Quant mode without an explicit ``output_dtype``."""
        self._check_quant_case()

    @SupportedDevices(['Ascend910B'])
    def test_npu_ffn_quant_with_outputdtype(self, device="npu"):
        """Quant mode with ``output_dtype=torch.float16``."""
        self._check_quant_case(output_dtype=torch.float16)
# Script entry point: hand control to the torch_npu test runner.
if __name__ == "__main__":
    run_tests()