* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <gtest/gtest.h>
#include "kernel_operator.h"
#include "include/adv_api/matmul/matmul.h"
#include "copy_cube_in/base_tiling_struct.h"
using namespace std;
using namespace AscendC;
namespace {
class CustomMatmulScheduler {
public:
__aicore__ inline void Init(const TCubeTiling *__restrict cubeTiling, TPipe *tpipe) {}
__aicore__ inline void GetResult(const LocalTensor<float>& co2Local, uint8_t enAtomic = 0, bool enSequentialWrite = false) {}
__aicore__ inline bool ScheduleOnce(bool enPartialSum) {
return false;
}
__aicore__ inline void Reset() {}
__aicore__ inline void End() {}
};
template <const auto& MM_CFG, typename IMPL, typename A_TYPE, typename B_TYPE, typename C_TYPE, typename BIAS_TYPE>
class CustomMatmulPolicy : public Impl::Detail::MatmulWithScalePolicy<MM_CFG, IMPL, A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE>
{
public:
using Scheduler = CustomMatmulScheduler;
};
}
class TestMxMatmulImpl : public testing::Test {
protected:
void SetUp() {}
void TearDown() {}
private:
using AS_TYPE_GM = MatmulTypeWithScale<TPosition::GM, TPosition::GM, CubeFormat::ND,
fp4x2_e1m2_t, false, TPosition::GM, CubeFormat::ND, false, TPosition::GM>;
using AS_TYPE_UB = MatmulTypeWithScale<TPosition::GM, TPosition::VECOUT, CubeFormat::ND,
fp8_e4m3fn_t, false, TPosition::GM, CubeFormat::ND, false, TPosition::GM>;
using BS_TYPE_GM = MatmulTypeWithScale<TPosition::GM, TPosition::GM, CubeFormat::ND,
fp4x2_e2m1_t, false, TPosition::GM, CubeFormat::ND, false, TPosition::GM>;
using BS_TYPE_UB = MatmulTypeWithScale<TPosition::GM, TPosition::VECOUT, CubeFormat::ND,
fp8_e5m2_t, false, TPosition::GM, CubeFormat::ND, false, TPosition::GM>;
using C_TYPE_GM = MatmulType<TPosition::GM, CubeFormat::ND, float>;
using C_TYPE_UB = MatmulType<TPosition::VECIN, CubeFormat::NZ, float>;
using BIAS_TYPE = MatmulType<TPosition::GM, CubeFormat::ND, float>;
MatmulImpl<AS_TYPE_GM, BS_TYPE_GM, C_TYPE_GM, BIAS_TYPE, CFG_NORM,
MatmulCallBackFunc<nullptr, nullptr, nullptr>, Impl::Detail::MatmulWithScalePolicy> mm1;
MatmulImpl<AS_TYPE_UB, BS_TYPE_UB, C_TYPE_UB, BIAS_TYPE, CFG_NORM,
MatmulCallBackFunc<nullptr, nullptr, nullptr>, CustomMatmulPolicy> mm2;
};
TEST_F(TestMxMatmulImpl, test_mx_matmul_impl_scaleUb_OutputUb) {
TPipe pipe;
TCubeTiling tiling;
TilingParamsMx tilingParamsMx = {1, 64, 64, 64, 64, 64, 64, 32, 32, 32, 2, 2, 1, 1, 2, 2, 1, 1, 0};
tilingParamsMx.GetTiling(tiling);
bool isTransposeA = false;
bool isTransposeAS = false;
bool isTransposeB = false;
bool isTransposeBS = false;
using A_T = typename AS_TYPE_UB::T;
using B_T = typename BS_TYPE_UB::T;
using C_T = typename C_TYPE_UB::T;
typedef fp8_e8m0_t MX_T;
const int32_t left_data_size = tiling.M * tiling.Ka;
const int32_t left_scale_data_size = tiling.M * ((tiling.Ka + 31) / 32);
const int32_t right_data_size = tiling.Kb * tiling.N;
const int32_t right_scale_data_size = ((tiling.Kb + 31) / 32) * tiling.N;
const int32_t output_data_size = tiling.M * tiling.N;
uint8_t left_global[left_data_size * sizeof(A_T)] = {0};
uint8_t left_scale_global[left_scale_data_size * sizeof(MX_T)] = {0};
uint8_t right_global[right_data_size * sizeof(B_T)] = {0};
uint8_t right_scale_global[right_scale_data_size * sizeof(MX_T)] = {0};
uint8_t output_global[output_data_size * sizeof(C_T)] = {0};
GM_ADDR aGM = left_global;
GM_ADDR bGM = right_global;
TQue<TPosition::VECOUT, 1> leftMatrixScale;
TQue<TPosition::VECOUT, 1> rightMatrixScale;
TQue<TPosition::VECIN, 1> resultCMatrix;
GlobalTensor<A_T> aGlobal;
GlobalTensor<B_T> bGlobal;
LocalTensor<MX_T> bufferMxLeft;
LocalTensor<MX_T> bufferMxRight;
LocalTensor<C_T> bufferC;
aGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ A_T*>(aGM), left_data_size);
bGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ B_T*>(bGM), right_data_size);
pipe.InitBuffer(leftMatrixScale, 1, left_scale_data_size * sizeof(MX_T));
pipe.InitBuffer(rightMatrixScale, 1, right_scale_data_size * sizeof(MX_T));
pipe.InitBuffer(resultCMatrix, 1, output_data_size * sizeof(C_T));
auto gmA = aGlobal[0];
auto gmB = bGlobal[0];
bufferMxLeft = leftMatrixScale.AllocTensor<MX_T>();
bufferMxRight = rightMatrixScale.AllocTensor<MX_T>();
bufferC = resultCMatrix.AllocTensor<C_T>();
mm2.Init(&tiling, &pipe);
mm2.SetTensorA(gmA, isTransposeA);
mm2.SetTensorScaleA(bufferMxLeft, isTransposeAS);
mm2.SetTensorB(gmB, isTransposeB);
mm2.SetTensorScaleB(bufferMxRight, isTransposeBS);
mm2.IterateAll(bufferC);
KfcSetIntraScaleAId(mm2, 0);
KfcSetIntraScaleBId(mm2, 0);
mm2.End();
}