* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <gtest/gtest.h>
#include "kernel_operator.h"
#include "include/adv_api/matmul/tiling.h"
#include "impl/adv_api/detail/matmul/utils/matmul_param.h"
#include "impl/adv_api/detail/matmul/policy/matmul_policy.h"
#include "impl/adv_api/detail/matmul/policy/matmul_private_modules.h"
#include "impl/adv_api/detail/matmul/utils/matmul_call_back.h"
#include "impl/adv_api/detail/matmul/utils/matmul_utils.h"
using namespace std;
using namespace AscendC;
class TestMatmulUtilsMx : public testing::Test {
protected:
void SetUp() {}
void TearDown() {}
private:
using AS_TYPE_GM = MatmulTypeWithScale<TPosition::GM, TPosition::GM, CubeFormat::ND,
fp8_e4m3fn_t, false, TPosition::GM, CubeFormat::ND, false, TPosition::GM>;
using INPUT_AS_TYPE_GM = MatmulInputScaleAType<AS_TYPE_GM, fp8_e8m0_t>;
using AS_TYPE_UB = MatmulTypeWithScale<TPosition::VECOUT, TPosition::VECOUT, CubeFormat::ND,
fp8_e4m3fn_t, false, TPosition::GM, CubeFormat::ND, false, TPosition::GM>;
using INPUT_AS_TYPE_UB = MatmulInputScaleAType<AS_TYPE_UB, fp8_e8m0_t>;
using AS_TYPE_L1 = MatmulTypeWithScale<TPosition::TSCM, TPosition::TSCM, CubeFormat::ND,
fp8_e4m3fn_t, false, TPosition::GM, CubeFormat::ND, false, TPosition::GM>;
using INPUT_AS_TYPE_L1 = MatmulInputScaleAType<AS_TYPE_L1, fp8_e8m0_t>;
using A_TYPE_GM = MatmulType<TPosition::GM, CubeFormat::ND, half, false>;
using INPUT_A_TYPE_GM = MatmulInputAType<A_TYPE_GM, typename A_TYPE_GM::T>;
using A_TYPE_UB = MatmulType<TPosition::VECOUT, CubeFormat::ND, half, false>;
using INPUT_A_TYPE_UB = MatmulInputAType<A_TYPE_UB, typename A_TYPE_UB::T>;
using A_TYPE_L1 = MatmulType<TPosition::TSCM, CubeFormat::ND, half, false>;
using INPUT_A_TYPE_L1 = MatmulInputAType<A_TYPE_L1, typename A_TYPE_L1::T>;
using AS_TYPE_GM_L1 = MatmulTypeWithScale<TPosition::GM, TPosition::TSCM, CubeFormat::ND,
fp8_e4m3fn_t, false, TPosition::GM, CubeFormat::ND, false, TPosition::GM>;
using AS_TYPE_GM_TP = MatmulTypeWithScale<TPosition::GM, TPosition::GM, CubeFormat::ND,
fp8_e4m3fn_t, true, TPosition::GM, CubeFormat::ND, true, TPosition::GM>;
using BS_TYPE_GM = MatmulTypeWithScale<TPosition::GM, TPosition::GM, CubeFormat::ND,
fp8_e4m3fn_t, false, TPosition::GM, CubeFormat::ND, false, TPosition::GM>;
using INPUT_BS_TYPE_GM = MatmulInputScaleBType<BS_TYPE_GM, fp8_e8m0_t>;
using BS_TYPE_GM_TP = MatmulTypeWithScale<TPosition::GM, TPosition::GM, CubeFormat::ND,
fp8_e4m3fn_t, true, TPosition::GM, CubeFormat::ND, true, TPosition::GM>;
using BS_TYPE_GM_F4 = MatmulTypeWithScale<TPosition::GM, TPosition::GM, CubeFormat::ND,
fp4x2_e1m2_t, false, TPosition::GM, CubeFormat::ND, false, TPosition::GM>;
using B_TYPE_GM = MatmulType<TPosition::GM, CubeFormat::ND, half>;
using INPUT_B_TYPE_GM = MatmulInputBType<B_TYPE_GM, typename B_TYPE_GM::T>;
};
// Checks AuxGetC0Size<T>() for every element type used by this suite.
// The expected value doubles each time the element width halves.
TEST_F(TestMatmulUtilsMx, test_mx_matmul_utils_AuxGetC0Size) {
    constexpr int kExpectedFor32Bit = 8;
    constexpr int kExpectedFor16Bit = 16;
    constexpr int kExpectedFor8Bit = 32;
    constexpr int kExpectedFor4Bit = 64;

    // 32-bit element types
    EXPECT_EQ(AuxGetC0Size<float>(), kExpectedFor32Bit);
    EXPECT_EQ(AuxGetC0Size<int32_t>(), kExpectedFor32Bit);
    EXPECT_EQ(AuxGetC0Size<uint32_t>(), kExpectedFor32Bit);

    // 16-bit element types
    EXPECT_EQ(AuxGetC0Size<half>(), kExpectedFor16Bit);
    EXPECT_EQ(AuxGetC0Size<bfloat16_t>(), kExpectedFor16Bit);
    EXPECT_EQ(AuxGetC0Size<int16_t>(), kExpectedFor16Bit);
    EXPECT_EQ(AuxGetC0Size<uint16_t>(), kExpectedFor16Bit);

    // 8-bit element types, including the fp8 family
    EXPECT_EQ(AuxGetC0Size<int8_t>(), kExpectedFor8Bit);
    EXPECT_EQ(AuxGetC0Size<uint8_t>(), kExpectedFor8Bit);
    EXPECT_EQ(AuxGetC0Size<hifloat8_t>(), kExpectedFor8Bit);
    EXPECT_EQ(AuxGetC0Size<fp8_e4m3fn_t>(), kExpectedFor8Bit);
    EXPECT_EQ(AuxGetC0Size<fp8_e5m2_t>(), kExpectedFor8Bit);
    EXPECT_EQ(AuxGetC0Size<fp8_e8m0_t>(), kExpectedFor8Bit);

    // 4-bit element types
    EXPECT_EQ(AuxGetC0Size<int4b_t>(), kExpectedFor4Bit);
    EXPECT_EQ(AuxGetC0Size<fp4x2_e1m2_t>(), kExpectedFor4Bit);
    EXPECT_EQ(AuxGetC0Size<fp4x2_e2m1_t>(), kExpectedFor4Bit);
}
// Checks IsSupportB8<T>(): only int8_t, hifloat8_t, fp8_e4m3fn_t and
// fp8_e5m2_t are expected to report true.
TEST_F(TestMatmulUtilsMx, test_mx_matmul_utils_IsSupportB8) {
    // Types expected to be accepted.
    EXPECT_TRUE(IsSupportB8<int8_t>());
    EXPECT_TRUE(IsSupportB8<hifloat8_t>());
    EXPECT_TRUE(IsSupportB8<fp8_e4m3fn_t>());
    EXPECT_TRUE(IsSupportB8<fp8_e5m2_t>());

    // 8-bit types that are nevertheless expected to be rejected.
    EXPECT_FALSE(IsSupportB8<uint8_t>());
    EXPECT_FALSE(IsSupportB8<fp8_e8m0_t>());

    // Wider types.
    EXPECT_FALSE(IsSupportB8<float>());
    EXPECT_FALSE(IsSupportB8<int32_t>());
    EXPECT_FALSE(IsSupportB8<uint32_t>());
    EXPECT_FALSE(IsSupportB8<half>());
    EXPECT_FALSE(IsSupportB8<bfloat16_t>());
    EXPECT_FALSE(IsSupportB8<int16_t>());
    EXPECT_FALSE(IsSupportB8<uint16_t>());

    // 4-bit types.
    EXPECT_FALSE(IsSupportB8<int4b_t>());
    EXPECT_FALSE(IsSupportB8<fp4x2_e1m2_t>());
    EXPECT_FALSE(IsSupportB8<fp4x2_e2m1_t>());
}
// Checks IsSupportB4<T>(): only the three 4-bit types are expected to
// report true; every wider type must report false.
TEST_F(TestMatmulUtilsMx, test_mx_matmul_utils_IsSupportB4) {
    // 4-bit types: accepted.
    EXPECT_TRUE(IsSupportB4<int4b_t>());
    EXPECT_TRUE(IsSupportB4<fp4x2_e1m2_t>());
    EXPECT_TRUE(IsSupportB4<fp4x2_e2m1_t>());

    // 32-bit types: rejected.
    EXPECT_FALSE(IsSupportB4<float>());
    EXPECT_FALSE(IsSupportB4<int32_t>());
    EXPECT_FALSE(IsSupportB4<uint32_t>());

    // 16-bit types: rejected.
    EXPECT_FALSE(IsSupportB4<half>());
    EXPECT_FALSE(IsSupportB4<bfloat16_t>());
    EXPECT_FALSE(IsSupportB4<int16_t>());
    EXPECT_FALSE(IsSupportB4<uint16_t>());

    // 8-bit types: rejected.
    EXPECT_FALSE(IsSupportB4<int8_t>());
    EXPECT_FALSE(IsSupportB4<uint8_t>());
    EXPECT_FALSE(IsSupportB4<hifloat8_t>());
    EXPECT_FALSE(IsSupportB4<fp8_e4m3fn_t>());
    EXPECT_FALSE(IsSupportB4<fp8_e5m2_t>());
    EXPECT_FALSE(IsSupportB4<fp8_e8m0_t>());
}
// Checks PhyMxScalePosIsL1<INPUT>(): among the inputs tried here, only the
// scaled input built from AS_TYPE_L1 is expected to report true.
TEST_F(TestMatmulUtilsMx, test_mx_matmul_utils_PhyMxScalePosIsL1) {
    // Scaled inputs (MatmulInputScaleAType).
    EXPECT_FALSE(PhyMxScalePosIsL1<INPUT_AS_TYPE_GM>());
    EXPECT_FALSE(PhyMxScalePosIsL1<INPUT_AS_TYPE_UB>());
    EXPECT_TRUE(PhyMxScalePosIsL1<INPUT_AS_TYPE_L1>());

    // Plain inputs (MatmulInputAType) report false for every position,
    // including the TSCM-based one.
    EXPECT_FALSE(PhyMxScalePosIsL1<INPUT_A_TYPE_GM>());
    EXPECT_FALSE(PhyMxScalePosIsL1<INPUT_A_TYPE_UB>());
    EXPECT_FALSE(PhyMxScalePosIsL1<INPUT_A_TYPE_L1>());
}
// Checks PhyMxScalePosIsUB<INPUT>(): among the inputs tried here, only the
// scaled input built from AS_TYPE_UB is expected to report true.
TEST_F(TestMatmulUtilsMx, test_mx_matmul_utils_PhyMxScalePosIsUB) {
    // Scaled inputs (MatmulInputScaleAType).
    EXPECT_FALSE(PhyMxScalePosIsUB<INPUT_AS_TYPE_GM>());
    EXPECT_TRUE(PhyMxScalePosIsUB<INPUT_AS_TYPE_UB>());
    EXPECT_FALSE(PhyMxScalePosIsUB<INPUT_AS_TYPE_L1>());

    // Plain inputs (MatmulInputAType) report false for every position,
    // including the VECOUT-based one.
    EXPECT_FALSE(PhyMxScalePosIsUB<INPUT_A_TYPE_GM>());
    EXPECT_FALSE(PhyMxScalePosIsUB<INPUT_A_TYPE_UB>());
    EXPECT_FALSE(PhyMxScalePosIsUB<INPUT_A_TYPE_L1>());
}
// Checks PhyMxScalePosIsGM<INPUT>(): among the inputs tried here, only the
// scaled input built from AS_TYPE_GM is expected to report true.
TEST_F(TestMatmulUtilsMx, test_mx_matmul_utils_PhyMxScalePosIsGM) {
    // Scaled inputs (MatmulInputScaleAType).
    EXPECT_TRUE(PhyMxScalePosIsGM<INPUT_AS_TYPE_GM>());
    EXPECT_FALSE(PhyMxScalePosIsGM<INPUT_AS_TYPE_UB>());
    EXPECT_FALSE(PhyMxScalePosIsGM<INPUT_AS_TYPE_L1>());

    // Plain inputs (MatmulInputAType) report false for every position,
    // including the GM-based one.
    EXPECT_FALSE(PhyMxScalePosIsGM<INPUT_A_TYPE_GM>());
    EXPECT_FALSE(PhyMxScalePosIsGM<INPUT_A_TYPE_UB>());
    EXPECT_FALSE(PhyMxScalePosIsGM<INPUT_A_TYPE_L1>());
}
// Checks GetBitSize<T>() for every element type used by this suite,
// grouped from the widest to the narrowest expected result.
TEST_F(TestMatmulUtilsMx, test_mx_matmul_utils_GetBitSize) {
    // Expected: 32
    EXPECT_EQ(GetBitSize<float>(), 32);
    EXPECT_EQ(GetBitSize<int32_t>(), 32);
    EXPECT_EQ(GetBitSize<uint32_t>(), 32);

    // Expected: 16
    EXPECT_EQ(GetBitSize<half>(), 16);
    EXPECT_EQ(GetBitSize<bfloat16_t>(), 16);
    EXPECT_EQ(GetBitSize<int16_t>(), 16);
    EXPECT_EQ(GetBitSize<uint16_t>(), 16);

    // Expected: 8
    EXPECT_EQ(GetBitSize<int8_t>(), 8);
    EXPECT_EQ(GetBitSize<uint8_t>(), 8);
    EXPECT_EQ(GetBitSize<hifloat8_t>(), 8);
    EXPECT_EQ(GetBitSize<fp8_e4m3fn_t>(), 8);
    EXPECT_EQ(GetBitSize<fp8_e5m2_t>(), 8);
    EXPECT_EQ(GetBitSize<fp8_e8m0_t>(), 8);

    // Expected: 4 (this includes the fp4x2_* types)
    EXPECT_EQ(GetBitSize<int4b_t>(), 4);
    EXPECT_EQ(GetBitSize<fp4x2_e1m2_t>(), 4);
    EXPECT_EQ(GetBitSize<fp4x2_e2m1_t>(), 4);
}
// Checks the IsScaleTransWithInlv<TYPE> variable template.
TEST_F(TestMatmulUtilsMx, test_mx_matmul_utils_IsScaleTransWithInlv) {
    // Types whose two leading TPosition arguments are identical: false.
    EXPECT_FALSE(IsScaleTransWithInlv<AS_TYPE_GM>);
    EXPECT_FALSE(IsScaleTransWithInlv<AS_TYPE_UB>);
    EXPECT_FALSE(IsScaleTransWithInlv<AS_TYPE_L1>);

    // The GM + TSCM combination is the only one expected to be true.
    EXPECT_TRUE(IsScaleTransWithInlv<AS_TYPE_GM_L1>);
}
// Checks IsL1BNeedTrans<A, B, CFG>() for an fp8 B operand (BS_TYPE_GM) and
// an fp4 B operand (BS_TYPE_GM_F4) under both CFG_NORM and CFG_MDL.
// In both configurations the fp8 case expects false and the fp4 case true.
TEST_F(TestMatmulUtilsMx, test_mx_matmul_utils_IsL1BNeedTrans) {
    const auto normWithFp8B = IsL1BNeedTrans<AS_TYPE_GM, BS_TYPE_GM, CFG_NORM>();
    EXPECT_FALSE(normWithFp8B);

    const auto normWithFp4B = IsL1BNeedTrans<AS_TYPE_GM, BS_TYPE_GM_F4, CFG_NORM>();
    EXPECT_TRUE(normWithFp4B);

    const auto mdlWithFp8B = IsL1BNeedTrans<AS_TYPE_GM, BS_TYPE_GM, CFG_MDL>();
    EXPECT_FALSE(mdlWithFp8B);

    const auto mdlWithFp4B = IsL1BNeedTrans<AS_TYPE_GM, BS_TYPE_GM_F4, CFG_MDL>();
    EXPECT_TRUE(mdlWithFp4B);
}
// Instantiates and calls GetTransBDataType<A, B, CFG>() for several operand
// and config combinations. There are no assertions: the test passes as long
// as every instantiation compiles and each call returns, so the results are
// discarded explicitly.
TEST_F(TestMatmulUtilsMx, test_mx_matmul_utils_GetTransBDataType) {
    // Scaled A operand.
    static_cast<void>(GetTransBDataType<AS_TYPE_GM, BS_TYPE_GM, CFG_NORM>());
    static_cast<void>(GetTransBDataType<AS_TYPE_GM, BS_TYPE_GM_F4, CFG_NORM>());
    static_cast<void>(GetTransBDataType<AS_TYPE_GM, BS_TYPE_GM, CFG_MDL>());
    static_cast<void>(GetTransBDataType<AS_TYPE_GM, BS_TYPE_GM_F4, CFG_MDL>());

    // Plain (half) A operand.
    static_cast<void>(GetTransBDataType<A_TYPE_GM, B_TYPE_GM, CFG_NORM>());
    static_cast<void>(GetTransBDataType<A_TYPE_GM, BS_TYPE_GM_F4, CFG_NORM>());
    static_cast<void>(GetTransBDataType<A_TYPE_GM, B_TYPE_GM, CFG_MDL>());
    static_cast<void>(GetTransBDataType<A_TYPE_GM, BS_TYPE_GM_F4, CFG_MDL>());
}
// Checks IsScaleTag<INPUT>(): the MatmulInputScale{A,B}Type inputs expect
// true, the plain MatmulInput{A,B}Type inputs expect false.
TEST_F(TestMatmulUtilsMx, test_mx_matmul_utils_IsScaleTag) {
    // A side
    EXPECT_TRUE(IsScaleTag<INPUT_AS_TYPE_GM>());
    EXPECT_FALSE(IsScaleTag<INPUT_A_TYPE_GM>());

    // B side
    EXPECT_TRUE(IsScaleTag<INPUT_BS_TYPE_GM>());
    EXPECT_FALSE(IsScaleTag<INPUT_B_TYPE_GM>());
}
// Checks InputPhyPosIsL1<INPUT>(): both the scaled and the plain input
// built on TSCM expect true; the GM and VECOUT variants expect false.
TEST_F(TestMatmulUtilsMx, test_mx_matmul_utils_InputPhyPosIsL1) {
    // Scaled inputs (MatmulInputScaleAType).
    EXPECT_FALSE(InputPhyPosIsL1<INPUT_AS_TYPE_GM>());
    EXPECT_FALSE(InputPhyPosIsL1<INPUT_AS_TYPE_UB>());
    EXPECT_TRUE(InputPhyPosIsL1<INPUT_AS_TYPE_L1>());

    // Plain inputs (MatmulInputAType).
    EXPECT_FALSE(InputPhyPosIsL1<INPUT_A_TYPE_GM>());
    EXPECT_FALSE(InputPhyPosIsL1<INPUT_A_TYPE_UB>());
    EXPECT_TRUE(InputPhyPosIsL1<INPUT_A_TYPE_L1>());
}
// Checks InputPhyPosIsUB<INPUT>(): both the scaled and the plain input
// built on VECOUT expect true; the GM and TSCM variants expect false.
TEST_F(TestMatmulUtilsMx, test_mx_matmul_utils_InputPhyPosIsUB) {
    // Scaled inputs (MatmulInputScaleAType).
    EXPECT_FALSE(InputPhyPosIsUB<INPUT_AS_TYPE_GM>());
    EXPECT_TRUE(InputPhyPosIsUB<INPUT_AS_TYPE_UB>());
    EXPECT_FALSE(InputPhyPosIsUB<INPUT_AS_TYPE_L1>());

    // Plain inputs (MatmulInputAType).
    EXPECT_FALSE(InputPhyPosIsUB<INPUT_A_TYPE_GM>());
    EXPECT_TRUE(InputPhyPosIsUB<INPUT_A_TYPE_UB>());
    EXPECT_FALSE(InputPhyPosIsUB<INPUT_A_TYPE_L1>());
}
// Checks the SupportMXFP8<T> variable template against the 8-bit types:
// only fp8_e4m3fn_t and fp8_e5m2_t are expected to compare equal to 1.
TEST_F(TestMatmulUtilsMx, test_mx_matmul_utils_SupportMXFP8) {
    // Integer 8-bit types.
    EXPECT_EQ(SupportMXFP8<int8_t>, 0);
    EXPECT_EQ(SupportMXFP8<uint8_t>, 0);

    // hifloat8_t is an 8-bit float type but is expected to yield 0.
    EXPECT_EQ(SupportMXFP8<hifloat8_t>, 0);

    // The two fp8 formats expected to yield 1.
    EXPECT_EQ(SupportMXFP8<fp8_e4m3fn_t>, 1);
    EXPECT_EQ(SupportMXFP8<fp8_e5m2_t>, 1);

    // fp8_e8m0_t (used as the scale type by the fixture) is expected to yield 0.
    EXPECT_EQ(SupportMXFP8<fp8_e8m0_t>, 0);
}
// Checks the IsMxDisableUnitFlag<A, B, CFG> variable template under CFG_NORM
// for the four combinations of the plain and *_TP A/B operand aliases.
// NOTE(review): the test name says "EnableUnitFlag" while the trait under
// test is IsMxDisableUnitFlag - confirm which wording is intended. The name
// is kept unchanged because test filters may reference it.
TEST_F(TestMatmulUtilsMx, test_mx_matmul_utils_IsMxEnableUnitFlag) {
    const auto plainA_plainB = IsMxDisableUnitFlag<AS_TYPE_GM, BS_TYPE_GM, CFG_NORM>;
    EXPECT_TRUE(plainA_plainB);

    const auto tpA_plainB = IsMxDisableUnitFlag<AS_TYPE_GM_TP, BS_TYPE_GM, CFG_NORM>;
    EXPECT_TRUE(tpA_plainB);

    // The only combination expected to be false: plain A with *_TP B.
    const auto plainA_tpB = IsMxDisableUnitFlag<AS_TYPE_GM, BS_TYPE_GM_TP, CFG_NORM>;
    EXPECT_FALSE(plainA_tpB);

    const auto tpA_tpB = IsMxDisableUnitFlag<AS_TYPE_GM_TP, BS_TYPE_GM_TP, CFG_NORM>;
    EXPECT_TRUE(tpA_tpB);
}