# Source snapshot: commit ba8dbd32, created 2025-05-29 (from the repository's commit history view).
import pytest
import torch
import torch_npu

from tests_extend.unit_tests.common import DistributedTest
from mindspeed.ops.fusion_attention_v2 import npu_fusion_attention

# First 10 characters of the name reported for device 0, queried once at import
# time. The test below compares this prefix against 'Ascend910B' to decide
# whether it should be skipped.
DEVICE_NAME = torch_npu.npu.get_device_name(0)[:10]


class TestNPUFusionAttention(DistributedTest):
    """Check ``mindspeed``'s ``npu_fusion_attention`` against the ``torch_npu`` operator."""

    def supported_op_exec(self, query, key, value):
        """Run the ``torch_npu`` fused-attention operator that serves as the reference.

        Args:
            query: Query tensor, passed through unchanged.
            key: Key tensor, passed through unchanged.
            value: Value tensor, passed through unchanged.

        Returns:
            Whatever ``torch_npu.npu_fusion_attention`` returns when called with
            ``head_num=32``, ``input_layout="BSH"`` and ``scale=0.08838``.
        """
        reference_outputs = torch_npu.npu_fusion_attention(
            query,
            key,
            value,
            head_num=32,
            input_layout="BSH",
            scale=0.08838,
        )
        return reference_outputs

    def custom_op_exec(self, query, key, value):
        """Run the ``mindspeed`` fused-attention wrapper under test.

        The head count (32) and layout ("BSH") are passed positionally and the
        scale (0.08838) by keyword, mirroring the settings used for the
        reference operator.

        Returns:
            Whatever ``npu_fusion_attention`` returns for these arguments.
        """
        custom_outputs = npu_fusion_attention(query, key, value, 32, "BSH", scale=0.08838)
        return custom_outputs

    def trans_BNSD2BSH(self, tensor: torch.Tensor):
        """Swap dims 1 and 2 of ``tensor`` and flatten everything after dim 1.

        For a 4-D input of shape ``(B, N, S, D)`` the result has shape
        ``(B, S, N * D)``.
        """
        swapped = tensor.transpose(1, 2)
        batch, seq_len = swapped.shape[0], swapped.shape[1]
        # reshape (rather than view) because the transposed tensor is not contiguous.
        return swapped.reshape(batch, seq_len, -1)

    @pytest.mark.skipif(DEVICE_NAME != 'Ascend910B', reason='device type is not supported, skip this UT!')
    def test_npu_flash_attention(self):
        """Both operators must produce matching attention scores on random fp16 input."""
        bnsd_shape = (1, 32, 128, 128)
        # Draw query, key and value (in that order) on the host in half precision.
        query, key, value = (torch.randn(bnsd_shape, dtype=torch.float16) for _ in range(3))

        # Convert each tensor to the "BSH" layout and move it to the NPU.
        q_npu, k_npu, v_npu = (
            self.trans_BNSD2BSH(host_tensor).npu() for host_tensor in (query, key, value)
        )

        reference = self.supported_op_exec(q_npu, k_npu, v_npu)
        # The wrapper is expected to return exactly seven values; only the
        # attention score is compared below.
        (
            attention_score,
            softmax_max,
            softmax_sum,
            softmax_out,
            seed,
            offset,
            numels,
        ) = self.custom_op_exec(q_npu, k_npu, v_npu)

        assert torch.allclose(reference[0], attention_score, rtol=0.001, atol=0.001)