# Source revision: commit 9583d925, created on 2024-10-23 (commit history).
import os
import pytest
import torch
import torch_npu
import torchair as tng
import numpy as np

from torchair.configs.compiler_config import CompilerConfig
from mindspeed.ops.quant_gmm import npu_quant_gmm
from mindspeed.ops.quant_gmm import npu_quant_gmm_v2

# Set ENABLE_ACLNN in the process environment before the NPU backend is built below.
# NOTE(review): the component that reads this variable is not visible in this file — confirm.
os.environ["ENABLE_ACLNN"] = "true"
# CompilerConfig built with no arguments; the backend derived from it is what
# the graph-mode test variant passes to torch.compile.
config = CompilerConfig()
npu_backend = tng.get_npu_backend(compiler_config=config)

# First 10 characters of device 0's name; the test below is skipped unless
# this equals 'Ascend910B'.
DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10]


class QuantGMMModel(torch.nn.Module):
    """Module wrapper that dispatches to one of the quantized grouped-matmul ops.

    Wrapping the op call in an ``nn.Module`` allows the same call to be run
    directly or after being passed through ``torch.compile``.
    """

    def forward(self, *args):
        """Run the quantized grouped matmul selected by ``group_list_type``.

        Args:
            *args: Exactly seven positional values, in this order: ``x``,
                ``weight``, ``scale``, ``per_token_scale``, ``group_list``,
                ``group_list_type`` and ``output_dtype``. A
                ``group_list_type`` of 0 calls ``npu_quant_gmm`` and 1 calls
                ``npu_quant_gmm_v2``; the remaining values are forwarded to
                the selected op unchanged, with ``bias=None``.

        Returns:
            The result of the selected op.

        Raises:
            ValueError: If ``group_list_type`` is neither 0 nor 1 (or if the
                number of positional arguments is not seven).
        """
        x, weight, scale, per_token_scale, group_list, group_list_type, output_dtype = args
        if group_list_type == 0:
            return npu_quant_gmm(x, weight, scale=scale, per_token_scale=per_token_scale, bias=None,
                                 group_list=group_list, output_dtype=output_dtype)
        if group_list_type == 1:
            return npu_quant_gmm_v2(x, weight, scale=scale, per_token_scale=per_token_scale, bias=None,
                                    group_list=group_list, output_dtype=output_dtype)
        # Fail loudly instead of falling through to an unbound result variable.
        raise ValueError(f"unsupported group_list_type: {group_list_type!r} (expected 0 or 1)")


class TestNPUQuantGMM:
    """Compare the NPU quantized grouped matmul with a plain PyTorch reference."""

    def supported_op_exec(self, *args):
        """Compute the reference output with one ordinary matmul per expert.

        Args:
            *args: Six positional values, in this order: ``x``, ``weight``,
                ``scale``, ``per_token_scale``, ``group_list`` (used as the
                split sizes along dim 0) and ``output_dtype``.

        Returns:
            The per-expert results, each cast to ``output_dtype``,
            concatenated along dim 0.
        """
        x, weight, scale, per_token_scale, group_list, output_dtype = args

        # Cast both operands to float before the reference matmul.
        token_chunks = x.to(torch.float).split(group_list, dim=0)
        weight_fp = weight.to(torch.float)
        token_scale_chunks = per_token_scale.split(group_list, dim=0)

        # Per expert: matmul, then scale columns by `scale` and rows by the
        # per-token scale, then cast to the requested dtype.
        per_expert = [
            (
                (token_chunks[idx] @ weight_fp[idx])
                * scale[idx].reshape(1, -1)
                * token_scale_chunks[idx].reshape(-1, 1)
            ).to(output_dtype)
            for idx in range(len(group_list))
        ]
        return torch.cat(per_expert, dim=0)

    @pytest.mark.skip(reason='Cann package need update')
    @pytest.mark.skipif(DEVICE_NAME != 'Ascend910B', reason='device type is not supported, skip this UT!')
    @pytest.mark.parametrize('group_list_type', [0, 1])
    @pytest.mark.parametrize('is_graph_mode', [True, False])
    def test_npu_quant_gmm(self, group_list_type, is_graph_mode):
        """Check the op output against the reference in eager and graph mode.

        Args:
            group_list_type: 1 passes the per-group counts as-is; otherwise
                their cumulative sum is passed.
            is_graph_mode: When true, the model is wrapped with
                ``torch.compile`` using the module-level NPU backend.
        """
        num_tokens, hidden, out_features, num_experts = 42, 64, 32, 8

        x = torch.randint(-128, 128, (num_tokens, hidden), dtype=torch.int8).npu()
        weight = torch.randint(-128, 128, (num_experts, hidden, out_features), dtype=torch.int8).npu()
        scale = torch.rand(num_experts, out_features, dtype=torch.float32).npu()
        per_token_scale = torch.rand(num_tokens, dtype=torch.float32).npu()

        group_list = [1, 3, 5, 7, 9, 6, 7, 4]
        expected = self.supported_op_exec(x, weight, scale, per_token_scale, group_list, torch.float16)

        if group_list_type == 1:
            group_list_value = group_list
        else:
            group_list_value = np.cumsum(group_list)
        group_list_tensor = torch.tensor(group_list_value).to(torch.int64).npu()

        model = QuantGMMModel().npu()
        if is_graph_mode:
            model = torch.compile(model, backend=npu_backend, dynamic=True)
        actual = model(x, weight, scale, per_token_scale, group_list_tensor, group_list_type, torch.float16)

        assert torch.allclose(actual, expected, rtol=0.005, atol=0.005)