* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
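/*!
 * \file grouped_matmul.cpp
 * \brief Kernel entry point of the grouped matmul operator: defines the launch macros and selects one of them
 *        at compile time from the build macros and the kernel's template parameters.
 */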
#include "grouped_matmul_utils.h"
#include "grouped_matmul_antiquant.h"
#include "grouped_matmul_vector.h"
#include "grouped_matmul_tiling_key.h"
#include "grouped_matmul.h"
#include "kernel_operator.h"
#if (defined(__CCE_AICORE__) && __CCE_AICORE__ == 220) || (defined(__NPU_ARCH__) && (__NPU_ARCH__ == 3003 || __NPU_ARCH__ == 3113))
#include "grouped_matmul_antiquant_a16w8_msd.h"
#include "grouped_matmul_antiquant_a8w4_msd_pre.h"
#include "grouped_matmul_antiquant_a8w4_msd.h"
#include "grouped_matmul_antiquant_a8w4_pre.h"
#include "grouped_matmul_antiquant_a8w4.h"
#include "grouped_matmul_antiquant_a8w4_msd_new.h"
#include "grouped_matmul_quant_mixcore.h"
#include "grouped_matmul_pre_tiling.h"
#include "grouped_matmul_a4w4.h"
#include "grouped_matmul_autotiling_a8w4.h"
#include "a16w4_msd/grouped_matmul_weight_quant_a16w4_msd_controller.h"
#ifndef __CCE_KT_TEST__
#include "grouped_matmul_fixaxismove_interface.cpp"
#include "grouped_matmul_a4w4_interface.cpp"
#endif
#endif
// Bring the AscendC, matmul and GROUPED_MATMUL namespaces into file scope so the aliases,
// launch macros and kernel below can refer to their members unqualified.
using namespace AscendC;
using namespace matmul;
using namespace GROUPED_MATMUL;
// Fallback for builds in which the toolchain does not predefine FORMAT_FRACTAL_NZ.
// The macro is compared numerically in `#if ... FORMAT_WEIGHT == FORMAT_FRACTAL_NZ` further down this
// file, so it has to expand to an integer; an empty definition leaves `==` without a right operand and
// makes those directives ill-formed.
// 29 is the FRACTAL_NZ value of the CANN format enumeration (ACL_FORMAT_FRACTAL_NZ).
#ifndef FORMAT_FRACTAL_NZ
#define FORMAT_FRACTAL_NZ 29
#endif
namespace {
// The weight layout and the default matmul configuration both follow the compile-time weight format.
#if defined(FORMAT_WEIGHT) && FORMAT_WEIGHT == FORMAT_FRACTAL_NZ
constexpr MatmulConfig matmulCFG = NZ_CFG_MDL;
constexpr CubeFormat wFormat = CubeFormat::NZ;
#else
constexpr MatmulConfig matmulCFG = CFG_MDL;
constexpr CubeFormat wFormat = CubeFormat::ND;
#endif

#if defined(GMM_ANTI_QUANT_A8W4_MSD)
// Returns a copy of CFG_MDL with isPartialOutput switched on; used for A8W4_GMM_CFG_MDL_NEW.
constexpr auto GetMmCFG()
{
    auto partialOutputCfg = CFG_MDL;
    partialOutputCfg.isPartialOutput = true;
    return partialOutputCfg;
}

// Matmul configurations of the two A8W4 MSD kernel variants.
constexpr MatmulConfig A8W4_GMM_CFG_MDL_NEW = GetMmCFG();
constexpr MatmulConfig A8W4_GMM_CFG_MDL = GetNormalConfig();
#endif
} // namespace
// MatmulType descriptors used by the launch macros and the kernel in this file.
// Every descriptor places its operand in GM.

// Left operand: ND layout, element type DTYPE_X.
template <bool isTransposed = false>
using xType = MatmulType<AscendC::TPosition::GM, CubeFormat::ND, DTYPE_X, isTransposed>;

// Left operand of the MSD path: ND layout, but with the weight element type (DTYPE_WEIGHT).
template <bool isTransposed = false>
using xTypeMSD = MatmulType<AscendC::TPosition::GM, CubeFormat::ND, DTYPE_WEIGHT, isTransposed>;

// Right operand in the compile-time weight layout (wFormat), element type DTYPE_X.
template <bool isTransposed = false>
using weightType = MatmulType<AscendC::TPosition::GM, wFormat, DTYPE_X, isTransposed>;

// Right operand of the MSD path: weight layout (wFormat), element type DTYPE_WEIGHT.
template <bool isTransposed = false>
using weightTypeMSD = MatmulType<AscendC::TPosition::GM, wFormat, DTYPE_WEIGHT, isTransposed>;

// Output descriptors: MM_DTYPE_Y for the regular paths, int32_t for the MSD path.
using yType = MatmulType<AscendC::TPosition::GM, CubeFormat::ND, MM_DTYPE_Y>;
using yTypeMSD = MatmulType<AscendC::TPosition::GM, CubeFormat::ND, int32_t>;

// Bias descriptor: ND layout, element type DTYPE_BIAS.
using biasType = MatmulType<AscendC::TPosition::GM, CubeFormat::ND, DTYPE_BIAS>;
namespace {
// True when the weight tensor is compiled for the FRACTAL_NZ format.
#if defined(FORMAT_WEIGHT) && FORMAT_WEIGHT == FORMAT_FRACTAL_NZ
constexpr bool isWeightNZ = true;
#else
constexpr bool isWeightNZ = false;
#endif

// Builds the compile-time matmul tiling used by the static-tiling launch macros.
//   isND2NZ : forwarded to GenGmmConf to produce the base MatmulConfig
//   transB  : selects the transposed or non-transposed weightType instantiation
// The tiling returned by GetMatmulApiTiling is then overridden with fixed depth / step / double-buffer
// settings before being returned.
__aicore__ inline static constexpr MatmulApiStaticTiling GetGmmMatmulApiTiling(bool isND2NZ, bool transB)
{
    MatmulConfig gmmConf = GenGmmConf(isND2NZ);
    MatmulApiStaticTiling tilingCfg =
        transB ? GetMatmulApiTiling<xType<false>, weightType<true>, yType, biasType>(gmmConf)
               : GetMatmulApiTiling<xType<false>, weightType<false>, yType, biasType>(gmmConf);

    // L1 depth of both operands.
    tilingCfg.depthA1 = STATIC_TILING_DEPTH_A1_B1;
    tilingCfg.depthB1 = STATIC_TILING_DEPTH_A1_B1;

    // Step sizes: one base block along M and N, a fixed step along K.
    tilingCfg.stepM = 1;
    tilingCfg.stepN = 1;
    tilingCfg.stepKa = STATIC_TILING_STEP_KA_KB;
    tilingCfg.stepKb = STATIC_TILING_STEP_KA_KB;

    // Double-buffer switches for L0A / L0B / L0C.
    tilingCfg.dbL0A = DOUBLE_BUFFER_L0A_L0B;
    tilingCfg.dbL0B = DOUBLE_BUFFER_L0A_L0B;
    tilingCfg.dbL0C = 1;
    return tilingCfg;
}

// Static tilings for the non-transposed and transposed weight cases.
constexpr static auto staticCFG = GetGmmMatmulApiTiling(isWeightNZ, false);
constexpr static auto staticCFGtransB = GetGmmMatmulApiTiling(isWeightNZ, true);
} // namespace
// GMM_IMP: launch sequence that registers the matmul object through REGIST_MATMUL_OBJ.
//   computeClass : class template instantiated as computeClass<matmulType, sync> around the matmul object
//   processClass : class template instantiated over the compute object; its Process() runs the work
//   transA/transB: transpose flags forwarded to xType / weightType
//   sync         : second template argument of computeClass
//   cfg          : matmul configuration passed to MMType
// Steps: read gmmBaseParams, mmTilingData and the address of gmmArray from `tiling`; register `mm`
// with the pipe and the system workspace; computeOp.Init; op.Init; op.Process.
// The macro expands inside the kernel body and uses its locals (x, weight, bias, scale, offset,
// antiquantScale, antiquantOffset, groupList, perTokenScale, y, user1, tiling, tPipe).
#define GMM_IMP(computeClass, processClass, transA, transB, sync, cfg) \
do { \
using matmulType = MMType<xType<transA>, weightType<transB>, yType, biasType, cfg>; \
matmulType::MT mm; \
GET_TILING_DATA_MEMBER(GMMTilingData, gmmBaseParams, gmmBaseParams_, tiling); \
GET_TILING_DATA_MEMBER(GMMTilingData, mmTilingData, mmTilingData_, tiling); \
GET_TILING_DATA_MEMBER_ADDR(GMMTilingData, gmmArray, gmmArrayAddr_, tiling); \
REGIST_MATMUL_OBJ(&tPipe, GetSysWorkSpacePtr(), mm, &mmTilingData_); \
computeClass<matmulType, sync> computeOp(mm); \
computeOp.Init(x, weight, bias, scale, offset, antiquantScale, antiquantOffset, groupList, perTokenScale, \
y, user1, &gmmBaseParams_, &mmTilingData_, &tPipe); \
processClass<decltype(computeOp)> op(computeOp); \
op.Init(&gmmBaseParams_, &mmTilingData_, gmmArrayAddr_, groupList, tiling); \
op.Process(); \
} while (0)
// GMM_CUBE_STATIC_TILING_IMP: cube-only launch sequence driven by a compile-time tiling (`cfg`).
//   processClass : class template instantiated over GMMCompute<matmulType, sync>
//   transA/transB: transpose flags forwarded to xType / weightType
//   cfg          : static tiling object; also supplies baseM / baseN to InitStaticTiling
// AIV cores return from the kernel immediately. No mmTilingData is read from `tiling`: the matmul
// object is initialised with a null TCubeTiling pointer, and the compute/process objects receive
// nullptr (and 0 for the gmmArray address) in its place.
// NOTE: the `return` inside the macro leaves the enclosing kernel function, not just the macro body.
#define GMM_CUBE_STATIC_TILING_IMP(processClass, transA, transB, sync, cfg) \
do { \
if ASCEND_IS_AIV { \
return; \
} \
GET_TILING_DATA_MEMBER(GMMTilingData, gmmBaseParams, gmmBaseParams_, tiling); \
using matmulType = MMImplType<xType<transA>, weightType<transB>, yType, biasType, cfg>; \
matmulType::MT mm; \
mm.SetSubBlockIdx(0); \
mm.Init((TCubeTiling*)nullptr, &tPipe); \
GMMCompute<matmulType, sync> computeOp(mm); \
computeOp.Init(x, weight, bias, scale, offset, antiquantScale, antiquantOffset, groupList, perTokenScale, \
y, user1, &gmmBaseParams_, nullptr, &tPipe); \
processClass<decltype(computeOp)> op(computeOp); \
op.Init(&gmmBaseParams_, nullptr, 0, groupList, tiling); \
op.InitStaticTiling((cfg).baseM, (cfg).baseN); \
op.Process(); \
} while (0)
// GMM_CV_SPLIT_STATIC_TILING_IMP: launch sequence with a compile-time tiling (`cfg`) in which both
// AIC and AIV cores continue past the matmul set-up.
//   computeClass / processClass : compute and process class templates, as in the other launch macros
//   transA/transB               : transpose flags forwarded to aType / bType
//   cfg                         : static tiling object; supplies baseM / baseN
//   aType / bType / cType       : MatmulType descriptors of the left, right and output operands
// Only AIC cores initialise the matmul object (with a null TCubeTiling pointer). Both the compute
// object and the process object get an InitStaticTiling call with cfg.baseM / cfg.baseN.
#define GMM_CV_SPLIT_STATIC_TILING_IMP(computeClass, processClass, transA, transB, sync, cfg, aType, bType, cType) \
do { \
GET_TILING_DATA_MEMBER(GMMTilingData, gmmBaseParams, gmmBaseParams_, tiling); \
using matmulType = MMImplType<aType<transA>, bType<transB>, cType, biasType, cfg>; \
matmulType::MT mm; \
if ASCEND_IS_AIC { \
mm.SetSubBlockIdx(0); \
mm.Init((TCubeTiling*)nullptr, &tPipe); \
} \
computeClass<matmulType, sync> computeOp(mm); \
computeOp.Init(x, weight, bias, scale, offset, antiquantScale, antiquantOffset, groupList, perTokenScale, \
y, user1, &gmmBaseParams_, nullptr, &tPipe); \
computeOp.InitStaticTiling(&gmmBaseParams_, user1, (cfg).baseM, (cfg).baseN); \
processClass<decltype(computeOp)> op(computeOp); \
op.Init(&gmmBaseParams_, nullptr, 0, groupList, tiling); \
op.InitStaticTiling((cfg).baseM, (cfg).baseN); \
op.Process(); \
} while (0)
// GMM_CUBE_IMP: cube-only launch sequence that takes the matmul tiling from `tiling` at run time.
//   processClass : class template instantiated over GMMCompute<matmulType, sync>
//   transA/transB: transpose flags forwarded to xType / weightType
//   cfg          : matmul configuration passed to MMImplType
// AIV cores return from the kernel immediately. The remaining cores read gmmBaseParams,
// mmTilingData and the address of gmmArray from `tiling`, initialise `mm` with mmTilingData, then
// run computeOp.Init -> op.Init -> op.Process.
// NOTE: the `return` inside the macro leaves the enclosing kernel function, not just the macro body.
#define GMM_CUBE_IMP(processClass, transA, transB, sync, cfg) \
do { \
if ASCEND_IS_AIV { \
return; \
} \
using matmulType = MMImplType<xType<transA>, weightType<transB>, yType, biasType, cfg>; \
matmulType::MT mm; \
GET_TILING_DATA_MEMBER(GMMTilingData, gmmBaseParams, gmmBaseParams_, tiling); \
GET_TILING_DATA_MEMBER(GMMTilingData, mmTilingData, mmTilingData_, tiling); \
GET_TILING_DATA_MEMBER_ADDR(GMMTilingData, gmmArray, gmmArrayAddr_, tiling); \
mm.SetSubBlockIdx(0); \
mm.Init(&mmTilingData_, &tPipe); \
GMMCompute<matmulType, sync> computeOp(mm); \
computeOp.Init(x, weight, bias, scale, offset, antiquantScale, antiquantOffset, groupList, perTokenScale, \
y, user1, &gmmBaseParams_, &mmTilingData_, &tPipe); \
processClass<decltype(computeOp)> op(computeOp); \
op.Init(&gmmBaseParams_, &mmTilingData_, gmmArrayAddr_, groupList, tiling); \
op.Process(); \
} while (0)
// GMM_CV_SPLIT_IMP: launch sequence in which both AIC and AIV cores take part, with the matmul
// tiling read from `tiling`.
//   computeClass / processClass : compute class template (instantiated as computeClass<matmulType, sync>)
//                                 and the process class template wrapped around it
//   transA/transB               : transpose flags forwarded to aType / bType
//   sync                        : second template argument of computeClass
//   cfg                         : matmul configuration passed to MMImplType
//   aType / bType / cType       : MatmulType descriptors of the left, right and output operands
// Only AIC cores initialise the matmul object; every core then runs
// computeOp.Init -> op.Init -> op.Process.
// The macro expands inside the kernel body and uses its locals (x, weight, ..., y, user1, tiling, tPipe).
//
// Two variants exist:
//   * CONST_TILING defined : the tiling members are used exactly as read.
//   * otherwise            : GMMPreTilingProcess is run first on gmmBaseParams_ / mmTilingData_
//                            (both passed to Init and Process) before the matmul object is initialised.
#if defined(CONST_TILING)
#define GMM_CV_SPLIT_IMP(computeClass, processClass, transA, transB, sync, cfg, aType, bType, cType) \
    do { \
        using matmulType = MMImplType<aType<transA>, bType<transB>, cType, biasType, cfg>; \
        matmulType::MT mm; \
        GET_TILING_DATA_MEMBER(GMMTilingData, gmmBaseParams, gmmBaseParams_, tiling); \
        GET_TILING_DATA_MEMBER(GMMTilingData, mmTilingData, mmTilingData_, tiling); \
        GET_TILING_DATA_MEMBER_ADDR(GMMTilingData, gmmArray, gmmArrayAddr_, tiling); \
        if ASCEND_IS_AIC { \
            mm.SetSubBlockIdx(0); \
            mm.Init(&mmTilingData_, &tPipe); \
        } \
        computeClass<matmulType, sync> computeOp(mm); \
        computeOp.Init(x, weight, bias, scale, offset, antiquantScale, antiquantOffset, groupList, perTokenScale, \
                       y, user1, &gmmBaseParams_, &mmTilingData_, &tPipe); \
        processClass<decltype(computeOp)> op(computeOp); \
        op.Init(&gmmBaseParams_, &mmTilingData_, gmmArrayAddr_, groupList, tiling); \
        op.Process(); \
    } while (0)
#else
#define GMM_CV_SPLIT_IMP(computeClass, processClass, transA, transB, sync, cfg, aType, bType, cType) \
    do { \
        using matmulType = MMImplType<aType<transA>, bType<transB>, cType, biasType, cfg>; \
        matmulType::MT mm; \
        GET_TILING_DATA_MEMBER(GMMTilingData, gmmBaseParams, gmmBaseParams_, tiling); \
        GET_TILING_DATA_MEMBER(GMMTilingData, mmTilingData, mmTilingData_, tiling); \
        GET_TILING_DATA_MEMBER_ADDR(GMMTilingData, gmmArray, gmmArrayAddr_, tiling); \
        GMMPreTilingProcess preTiling; \
        preTiling.Init(groupList, gmmBaseParams_, mmTilingData_, &tPipe); \
        preTiling.Process(gmmBaseParams_, mmTilingData_); \
        if ASCEND_IS_AIC { \
            mm.SetSubBlockIdx(0); \
            mm.Init(&mmTilingData_, &tPipe); \
        } \
        computeClass<matmulType, sync> computeOp(mm); \
        computeOp.Init(x, weight, bias, scale, offset, antiquantScale, antiquantOffset, groupList, perTokenScale, \
                       y, user1, &gmmBaseParams_, &mmTilingData_, &tPipe); \
        processClass<decltype(computeOp)> op(computeOp); \
        op.Init(&gmmBaseParams_, &mmTilingData_, gmmArrayAddr_, groupList, tiling); \
        op.Process(); \
    } while (0)
#endif
// GMM_A4W4_IMP: launch sequence of the A4W4 path.
//   computeClass          : class template instantiated as computeClass<matmulType>; it is initialised
//                           and processed directly, without a separate process class
//   transA/transB         : transpose flags forwarded to aType / bType
//   cfg                   : matmul configuration passed to MMImplType
//   aType / bType / cType : MatmulType descriptors of the left, right and output operands
// Only AIC cores initialise the matmul object. computeOp.Init receives x, weight, scale, groupList
// and perTokenScale (no bias / offset / antiquant arguments).
// gmmArrayAddr_ is read from `tiling` but not used afterwards in this macro.
#define GMM_A4W4_IMP(computeClass, transA, transB, cfg, aType, bType, cType) \
do { \
using matmulType = MMImplType<aType<transA>, bType<transB>, cType, biasType, cfg>; \
matmulType::MT mm; \
GET_TILING_DATA_MEMBER(GMMTilingData, gmmBaseParams, gmmBaseParams_, tiling); \
GET_TILING_DATA_MEMBER(GMMTilingData, mmTilingData, mmTilingData_, tiling); \
GET_TILING_DATA_MEMBER_ADDR(GMMTilingData, gmmArray, gmmArrayAddr_, tiling); \
if ASCEND_IS_AIC { \
mm.SetSubBlockIdx(0); \
mm.Init(&mmTilingData_, &tPipe); \
} \
computeClass<matmulType> computeOp(mm); \
computeOp.Init(x, weight, scale, groupList, perTokenScale, \
y, user1, &gmmBaseParams_, &mmTilingData_, &tPipe); \
computeOp.Process(); \
} while (0)
// GMM_CV_SPLIT_IMP_A8W4_MSD: launch sequence of the A8W4 MSD kernels.
//   computeClass : class template instantiated as computeClass<matmulType>
//   cfg          : matmul configuration passed to MMImplType
// Stage 1 (AIV cores only): run GMMA8W4PreProcess, then reset, destroy and re-initialise tPipe so
//   the main stage starts from a fresh pipe.
// Stage 2 (all cores): build the matmul type from DTYPE_X_DEV_A8W4MSD / DTYPE_WEIGHT_DEV_A8W4MSD
//   operands with an int32_t bias and a half output, initialise `mm` on AIC cores, then op.Init and
//   op.Process.
// NOTE(review): op1.Init receives `x` as both its first and second argument - confirm this is intended.
// NOTE(review): unlike GMM_CV_SPLIT_IMP_A8W4 there is no SyncAll between the two stages here;
//   presumably the compute class synchronises internally - confirm.
#define GMM_CV_SPLIT_IMP_A8W4_MSD(computeClass, cfg) \
do { \
GET_TILING_DATA_MEMBER(GMMTilingData, gmmBaseParams, gmmBaseParams_, tiling); \
if ASCEND_IS_AIV { \
GMMA8W4PreProcess op1; \
op1.Init(x, x, groupList, user1, gmmBaseParams_, &tPipe); \
op1.Process(); \
tPipe.Reset(); \
tPipe.Destroy(); \
tPipe.Init(); \
} \
using aT = MatmulType<TPosition::GM, CubeFormat::ND, DTYPE_X_DEV_A8W4MSD, false>; \
using bT = MatmulType<TPosition::GM, wFormat, DTYPE_WEIGHT_DEV_A8W4MSD, false>; \
using biasT = MatmulType<TPosition::GM, CubeFormat::ND, int32_t, false>; \
using cT = MatmulType<TPosition::GM, CubeFormat::ND, half, false>; \
using matmulType = MMImplType<aT, bT, cT, biasT, cfg>; \
matmulType::MT mm; \
GET_TILING_DATA_MEMBER(GMMTilingData, mmTilingData, mmTilingData_, tiling); \
GET_TILING_DATA_MEMBER_ADDR(GMMTilingData, gmmArray, gmmArrayAddr_, tiling); \
if ASCEND_IS_AIC { \
mm.SetSubBlockIdx(0); \
mm.Init(&mmTilingData_, &tPipe); \
} \
computeClass<matmulType> op(mm); \
op.Init(x, weight, bias, groupList, scale, perTokenScale, offset, nullptr, nullptr, nullptr, \
y, user1, &gmmBaseParams_, &mmTilingData_, &tPipe); \
op.Process(); \
} while (0)
// GMM_CV_SPLIT_IMP_A8W4: launch sequence of the A8W4 per-group antiquant kernel.
//   computeClass : class template instantiated as computeClass<matmulType>
//   cfg          : matmul configuration passed to MMImplType
// Stage 1 (AIV cores only): run GMMA8W4FakeQuantPreProcess<wFormat> on weight / y / groupList /
//   user1, then reset, destroy and re-initialise tPipe.
// SyncAll<false>() is issued by every core between the two stages.
// Stage 2 (all cores): int8_t x int8_t matmul type with int32_t bias and half output; `mm` is
//   initialised on AIC cores only, followed by op.Init and op.Process.
// gmmArrayAddr_ is read from `tiling` but not used afterwards in this macro.
#define GMM_CV_SPLIT_IMP_A8W4(computeClass, cfg) \
do { \
GET_TILING_DATA_MEMBER(GMMTilingData, gmmBaseParams, gmmBaseParams_, tiling); \
if ASCEND_IS_AIV { \
GMMA8W4FakeQuantPreProcess<wFormat> op1; \
op1.Init(weight, y, groupList, user1, gmmBaseParams_, &tPipe); \
op1.Process(); \
tPipe.Reset(); \
tPipe.Destroy(); \
tPipe.Init(); \
} \
SyncAll<false>(); \
using aT = MatmulType<TPosition::GM, CubeFormat::ND, int8_t, false>; \
using bT = MatmulType<TPosition::GM, wFormat, int8_t, false>; \
using biasT = MatmulType<TPosition::GM, CubeFormat::ND, int32_t, false>; \
using cT = MatmulType<TPosition::GM, CubeFormat::ND, half, false>; \
using matmulType = MMImplType<aT, bT, cT, biasT, cfg>; \
matmulType::MT mm; \
GET_TILING_DATA_MEMBER(GMMTilingData, mmTilingData, mmTilingData_, tiling); \
GET_TILING_DATA_MEMBER_ADDR(GMMTilingData, gmmArray, gmmArrayAddr_, tiling); \
if ASCEND_IS_AIC { \
mm.SetSubBlockIdx(0); \
mm.Init(&mmTilingData_, &tPipe); \
} \
computeClass<matmulType> op(mm); \
op.Init(x, weight, bias, groupList, scale, perTokenScale, offset, nullptr, nullptr, nullptr, \
y, user1, &gmmBaseParams_, &mmTilingData_, &tPipe); \
op.Process(); \
} while (0)
// GMM_CV_SPLIT_IMP_A8W4_FAKEA8W8: launch sequence of the A8W4 per-channel antiquant kernel, which
// runs the int8 quant mix-core compute on pre-processed data.
//   computeClass, cfg : accepted for symmetry with the other A8W4 launch macros but NOT used; the
//                       macro always instantiates GMMQuantMixCoreCompute<matmulType, false> with
//                       matmulCFG and drives it through GMMProcess.
// Stage 1 (AIV cores only): run GMMA8W4FakeQuantPreProcess<wFormat> on weight / y / scale / user1,
//   then reset, destroy and re-initialise tPipe.
// SyncAll<false>() is issued by every core between the two stages.
// Stage 2 (all cores): int8_t x int8_t matmul type with int32_t bias and int32_t output; `mm` is
//   initialised on AIC cores only. computeOp.Init is given `user1` in both the weight and the scale
//   argument positions, and isA8W4FakeQuant is set on the compute object before Init.
// NOTE(review): the pre-process here takes `scale` where GMM_CV_SPLIT_IMP_A8W4 passes `groupList` -
//   confirm that the difference is intended.
#define GMM_CV_SPLIT_IMP_A8W4_FAKEA8W8(computeClass, cfg) \
    do { \
        GET_TILING_DATA_MEMBER(GMMTilingData, gmmBaseParams, gmmBaseParams_, tiling); \
        if ASCEND_IS_AIV { \
            GMMA8W4FakeQuantPreProcess<wFormat> op1; \
            op1.Init(weight, y, scale, user1, gmmBaseParams_, &tPipe); \
            op1.Process(); \
            tPipe.Reset(); \
            tPipe.Destroy(); \
            tPipe.Init(); \
        } \
        SyncAll<false>(); \
        using aT = MatmulType<TPosition::GM, CubeFormat::ND, int8_t, false>; \
        using bT = MatmulType<TPosition::GM, wFormat, int8_t, false>; \
        using biasT = MatmulType<TPosition::GM, CubeFormat::ND, int32_t, false>; \
        using cT = MatmulType<TPosition::GM, CubeFormat::ND, int32_t, false>; \
        using matmulType = MMImplType<aT, bT, cT, biasT, matmulCFG>; \
        matmulType::MT mm; \
        GET_TILING_DATA_MEMBER(GMMTilingData, mmTilingData, mmTilingData_, tiling); \
        GET_TILING_DATA_MEMBER_ADDR(GMMTilingData, gmmArray, gmmArrayAddr_, tiling); \
        if ASCEND_IS_AIC { \
            mm.SetSubBlockIdx(0); \
            mm.Init(&mmTilingData_, &tPipe); \
        } \
        GMMQuantMixCoreCompute<matmulType, false> computeOp(mm); \
        computeOp.isA8W4FakeQuant = true; \
        computeOp.Init(x, user1, bias, user1, offset, antiquantScale, antiquantOffset, groupList, perTokenScale, \
                       y, user1, &gmmBaseParams_, &mmTilingData_, &tPipe); \
        GMMProcess<decltype(computeOp)> op(computeOp); \
        op.Init(&gmmBaseParams_, &mmTilingData_, gmmArrayAddr_, groupList, tiling); \
        op.Process(); \
    } while (0)
// GMM_CV_SPLIT_IMP_A16W4_MSD: launch sequence of the A16W4 MSD controller.
//   computeClass : controller class template, instantiated as
//                  computeClass<DTYPE_X, DTYPE_WEIGHT, DTYPE_BIAS, GROUP_LIST_TYPE>
//   ...          : extra arguments are accepted but ignored
// GROUP_LIST_TYPE is not a macro parameter; it is taken from the scope in which the macro is
// expanded (the kernel's template parameter of that name).
// The controller is initialised with the tensor addresses and gmmBaseParams, then processed with
// the raw `workspace` pointer (not user1) and the pipe.
#define GMM_CV_SPLIT_IMP_A16W4_MSD(computeClass, ...) \
do { \
GET_TILING_DATA_MEMBER(GMMTilingData, gmmBaseParams, gmmBaseParams_, tiling); \
computeClass<DTYPE_X, DTYPE_WEIGHT, DTYPE_BIAS, GROUP_LIST_TYPE> op; \
op.Init(x, weight, antiquantScale, bias, groupList, y, &gmmBaseParams_); \
op.Process(workspace, &tPipe); \
} while (0)
/**
 * Grouped matmul kernel entry.
 *
 * The implementation is selected at compile time in two layers:
 *   1. preprocessor macros (core architecture and the GMM_* mode macros) choose one section of the body;
 *   2. inside that section the template parameters choose which launch macro is expanded.
 * A combination of template parameters that matches no branch expands no launch macro, so the kernel
 * only runs the common prologue.
 *
 * Template parameters (compared against the GMM_TPL_* and GROUPED_MATMUL_* constants):
 *   D_T_A, D_T_B, D_T_Y     data-type codes of x, weight and y
 *   TRANS_A, TRANS_B        transpose flags (0 / 1) of x and weight
 *   GROUP_LIST_TYPE         group-list encoding (e.g. ..._CUMSUM, ..._SPARSEM)
 *   IS_STATIC_TILING_API    0 selects the run-time tiling macros, 1 the static-tiling macros
 *   A8W4_KERNEL_TEMPLATE    A8W4 kernel variant, or ..._NONE for the non-A8W4 paths
 *   A16W8_KERNEL_TEMPLATE   A16W8 / A16W4 kernel variant
 *   AIV_AIC_RATIO           selector compared against GROUPED_MATMUL_CUBE_ONLY / ..._AIV_AIC_RATIO_1 / _2
 *   IS_ENABLE_FIXED_AXIS    selects the Catlass fixed-axis-move path of the int8 quant section
 *
 * Parameters: GM addresses of the operator's tensors. `workspace` is the raw workspace address (the
 * user part is obtained through GetUserWorkspace) and `tiling` is the address of the GMMTilingData.
 *
 * NOTE: several launch macros contain `return` statements for AIV cores; those leave this function.
 */
template <int D_T_A, int D_T_B, int D_T_Y, int TRANS_A, int TRANS_B, int GROUP_LIST_TYPE,
          int IS_STATIC_TILING_API, int A8W4_KERNEL_TEMPLATE, int A16W8_KERNEL_TEMPLATE, int AIV_AIC_RATIO, bool IS_ENABLE_FIXED_AXIS>
__global__ __aicore__ void grouped_matmul(GM_ADDR x, GM_ADDR weight, GM_ADDR bias, GM_ADDR scale,
                                          GM_ADDR offset, GM_ADDR antiquantScale, GM_ADDR antiquantOffset,
                                          GM_ADDR groupList, GM_ADDR perTokenScale, GM_ADDR y,
                                          GM_ADDR workspace, GM_ADDR tiling)
{
    // Common prologue: pipe, overflow mode, default task type and the user part of the workspace.
    TPipe tPipe;
    AscendCUtils::SetOverflow(1);
    KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_AIC_ONLY);
    GM_ADDR user1 = GetUserWorkspace(workspace);

#if (defined(__CCE_AICORE__) && __CCE_AICORE__ == 220) || (defined(__NPU_ARCH__) && (__NPU_ARCH__ == 3003 || __NPU_ARCH__ == 3113))
#if defined(GMM_ANTI_QUANT_A8W4_MSD)
    // ---- A8W4: int8 activations, int4 weights ----
    if constexpr (D_T_A == GMM_TPL_INT8 && D_T_B == GMM_TPL_INT4) {
        if constexpr (A8W4_KERNEL_TEMPLATE == GROUPED_MATMUL_A8W4_KERNEL_TEMPLATE_MSD_API_DEQUANT) {
            GMM_CV_SPLIT_IMP_A8W4_MSD(GMMA8W4MSDCompute, A8W4_GMM_CFG_MDL);
        } else if constexpr (A8W4_KERNEL_TEMPLATE == GROUPED_MATMUL_A8W4_KERNEL_TEMPLATE_MSD_VECTOR_DEQUANT) {
            GMM_CV_SPLIT_IMP_A8W4_MSD(GMMA8W4MSDComputeNew, A8W4_GMM_CFG_MDL_NEW);
        } else if constexpr (A8W4_KERNEL_TEMPLATE == GROUPED_MATMUL_A8W4_KERNEL_TEMPLATE_PERCHANNEL_ANTIQUANT) {
            GMM_CV_SPLIT_IMP_A8W4_FAKEA8W8(GMMA8W4Compute, A8W4_GMM_CFG_MDL);
        } else if constexpr (A8W4_KERNEL_TEMPLATE == GROUPED_MATMUL_A8W4_KERNEL_TEMPLATE_PERGROUP_ANTIQUANT) {
            GMM_CV_SPLIT_IMP_A8W4(GMMA8W4Compute, A8W4_GMM_CFG_MDL);
        } else if constexpr (A8W4_KERNEL_TEMPLATE == GROUPED_MATMUL_A8W4_KERNEL_TEMPLATE_AUTOTILING) {
            // Autotiling variant: uses its own tiling member (hpTilingData) and argument naming.
            GET_TILING_DATA_MEMBER(GMMTilingData, hpTilingData, tilingData, tiling);
            GM_ADDR A = x;
            GM_ADDR B = weight;
            GM_ADDR C = y;
            GM_ADDR groupListOptional = groupList;
            GM_ADDR bias_ = bias;
            GM_ADDR offset_ = offset;
            GM_ADDR sa = perTokenScale;
            GM_ADDR sw = scale;
            GM_ADDR workspaceDevice = user1;
            GMMA4W8AutotilingCompute op(A, B, C, groupListOptional, bias_, offset_, sa, sw, workspaceDevice,
                                        const_cast<A8W4HPTiling *>(&tilingData), &tPipe);
            op.Init();
            op.Process();
        }
    }
#elif defined(GMM_ANTI_QUANT)
    // ---- Weight antiquant: fp16 / bf16 activations ----
    if constexpr ((D_T_A == GMM_TPL_BF16) &&
                  A16W8_KERNEL_TEMPLATE == GROUPED_MATMUL_A16W4_KERNEL_TEMPLATE_MSD_ANTIQUANT_GS32) {
        if constexpr (GROUP_LIST_TYPE == GROUPED_MATMUL_GROUP_LIST_TYPE_SPARSEM) {
            GMM_CV_SPLIT_IMP_A16W4_MSD(A16W4Msd::GMMWeightQuantA16W4MsdControllerMSparse, false);
        } else {
            GMM_CV_SPLIT_IMP_A16W4_MSD(A16W4Msd::GMMWeightQuantA16W4MsdController, false);
        }
    } else if constexpr ((D_T_A == GMM_TPL_FLOAT16 || D_T_A == GMM_TPL_BF16) &&
                         A16W8_KERNEL_TEMPLATE != GROUPED_MATMUL_A16W8_KERNEL_TEMPLATE_MSD) {
        if constexpr (TRANS_B == 0 && AIV_AIC_RATIO == GROUPED_MATMUL_AIV_AIC_RATIO_1) {
            if constexpr (GROUP_LIST_TYPE == GROUPED_MATMUL_GROUP_LIST_TYPE_SPARSEM) {
                GMM_IMP(GMMAntiquantComputeNorm, GMMAntiquantSparseProcess, false, false, false, matmulCFG);
            } else {
                GMM_IMP(GMMAntiquantComputeNorm, GMMAntiquantProcess, false, false, false, matmulCFG);
            }
        } else if constexpr (TRANS_B == 1 && AIV_AIC_RATIO == GROUPED_MATMUL_AIV_AIC_RATIO_1) {
            if constexpr (GROUP_LIST_TYPE == GROUPED_MATMUL_GROUP_LIST_TYPE_SPARSEM) {
                GMM_IMP(GMMAntiquantComputeNorm, GMMAntiquantSparseProcess, false, true, false, matmulCFG);
            } else {
                GMM_IMP(GMMAntiquantComputeNorm, GMMAntiquantProcess, false, true, false, matmulCFG);
            }
        } else if constexpr (TRANS_B == 0 && AIV_AIC_RATIO == GROUPED_MATMUL_AIV_AIC_RATIO_2) {
            if constexpr (GROUP_LIST_TYPE == GROUPED_MATMUL_GROUP_LIST_TYPE_SPARSEM) {
                GMM_IMP(GMMAntiquantComputePerformance, GMMAntiquantSparseProcess, false, false, false, matmulCFG);
            } else {
                GMM_IMP(GMMAntiquantComputePerformance, GMMAntiquantProcess, false, false, false, matmulCFG);
            }
        }
    }
#if defined(ORIG_DTYPE_WEIGHT) && defined(DT_INT8) && ORIG_DTYPE_WEIGHT == DT_INT8
    // A16W8 MSD: only compiled when the original weight dtype is int8.
    if constexpr ((D_T_A == GMM_TPL_FLOAT16 || D_T_A == GMM_TPL_BF16) && D_T_B == GMM_TPL_INT8 &&
                  A16W8_KERNEL_TEMPLATE == GROUPED_MATMUL_A16W8_KERNEL_TEMPLATE_MSD) {
        if constexpr (TRANS_B == 0) {
            if constexpr (GROUP_LIST_TYPE == GROUPED_MATMUL_GROUP_LIST_TYPE_SPARSEM) {
                GMM_CV_SPLIT_IMP(GMMA16W8MSDCompute, GMMA16W8MSDMSparseProcess, false, false, false,
                                 matmulCFG, xTypeMSD, weightTypeMSD, yTypeMSD);
            } else {
                GMM_CV_SPLIT_IMP(GMMA16W8MSDCompute, GMMA16W8MSDProcess, false, false, false,
                                 matmulCFG, xTypeMSD, weightTypeMSD, yTypeMSD);
            }
        } else if constexpr (TRANS_B == 1) {
            if constexpr (GROUP_LIST_TYPE == GROUPED_MATMUL_GROUP_LIST_TYPE_SPARSEM) {
                GMM_CV_SPLIT_IMP(GMMA16W8MSDCompute, GMMA16W8MSDMSparseProcess, false, true, false,
                                 matmulCFG, xTypeMSD, weightTypeMSD, yTypeMSD);
            } else {
                GMM_CV_SPLIT_IMP(GMMA16W8MSDCompute, GMMA16W8MSDProcess, false, true, false,
                                 matmulCFG, xTypeMSD, weightTypeMSD, yTypeMSD);
            }
        }
    }
#endif
#elif defined(GMM_QUANT_BF16) || defined(GMM_QUANT_FLOAT16)
    // ---- int8 x int8 with bf16 / fp16 output ----
    if constexpr (D_T_A == GMM_TPL_INT8 && D_T_B == GMM_TPL_INT8 && (D_T_Y == GMM_TPL_BF16 || D_T_Y == GMM_TPL_FLOAT16) &&
                  TRANS_A == 0 && A8W4_KERNEL_TEMPLATE == GROUPED_MATMUL_A8W4_KERNEL_TEMPLATE_NONE) {
        if constexpr (IS_STATIC_TILING_API == 0) {
            if constexpr (AIV_AIC_RATIO == GROUPED_MATMUL_AIV_AIC_RATIO_1) {
                if constexpr(IS_ENABLE_FIXED_AXIS == 0) {
                    if constexpr (TRANS_B == 0 && GROUP_LIST_TYPE != GROUPED_MATMUL_GROUP_LIST_TYPE_SPARSEM) {
                        GMM_CV_SPLIT_IMP(GMMQuantMixCoreCompute, GMMProcess, false, false, false, matmulCFG, xType, weightType, yType);
                    } else if constexpr (TRANS_B == 1 && GROUP_LIST_TYPE != GROUPED_MATMUL_GROUP_LIST_TYPE_SPARSEM) {
                        GMM_CV_SPLIT_IMP(GMMQuantMixCoreCompute, GMMProcess, false, true, false, matmulCFG, xType, weightType, yType);
                    } else if constexpr(TRANS_B == 0 && GROUP_LIST_TYPE == GROUPED_MATMUL_GROUP_LIST_TYPE_SPARSEM) {
                        GMM_CV_SPLIT_IMP(GMMQuantMixCoreCompute, GMMGroupMSparseProcess, false, false, false, matmulCFG, xType,
                                         weightType, yType);
                    } else if constexpr (TRANS_B == 1 && GROUP_LIST_TYPE == GROUPED_MATMUL_GROUP_LIST_TYPE_SPARSEM) {
                        GMM_CV_SPLIT_IMP(GMMQuantMixCoreCompute, GMMGroupMSparseProcess, false, true, false, matmulCFG, xType,
                                         weightType, yType);
                    }
                } else if constexpr(IS_ENABLE_FIXED_AXIS == 1 && TRANS_B == 0 && GROUP_LIST_TYPE == GROUPED_MATMUL_GROUP_LIST_TYPE_CUMSUM) {
                    // Catlass fixed-axis-move path: the pipe created in the prologue is destroyed
                    // before the Catlass entry point is called.
                    // NOTE(review): unlike the A4W4 Catlass path in this function, the layout
                    // transform is not switched back off afterwards - confirm that is intended.
                    tPipe.Destroy();
                    AscendC::SetMMLayoutTransform(true);
                    GET_TILING_DATA_MEMBER(GMMTilingData, gmmBaseParams, gmmBaseParams_, tiling);
                    using XDType = int8_t;
                    using WeightDType = int8_t;
                    using CDType = int32_t;
                    using ScaleDType = float;
                    using GrouplistDType = int64_t;
                    using PerTokenScaleDType = float;
                    using YDType = half;
#ifndef __CCE_KT_TEST__
                    Catlass::grouped_matmul_fixaxismove<XDType, WeightDType, CDType, ScaleDType, GrouplistDType, PerTokenScaleDType, YDType>(
                        gmmBaseParams_.m, gmmBaseParams_.k, gmmBaseParams_.n, gmmBaseParams_.groupNum,
                        x, weight, scale, groupList, perTokenScale, y, user1, gmmBaseParams_.coreNum);
#endif
                }
            } else if constexpr (AIV_AIC_RATIO == GROUPED_MATMUL_AIV_AIC_RATIO_2) {
                if constexpr (TRANS_B == 0) {
                    GMM_CV_SPLIT_IMP(GMMQuantMixCoreCompute, GMMProcess, false, false, false, matmulCFG, xType, weightType, yType);
                } else if constexpr (TRANS_B == 1) {
                    GMM_CV_SPLIT_IMP(GMMQuantMixCoreCompute, GMMProcess, false, true, false, matmulCFG, xType, weightType, yType);
                }
            }
        } else if constexpr (IS_STATIC_TILING_API == 1) {
            // `constexpr` keeps this branch from being instantiated for other IS_STATIC_TILING_API
            // values, matching the int8 / int32 output section below.
            if constexpr (TRANS_B == 0 && GROUP_LIST_TYPE != GROUPED_MATMUL_GROUP_LIST_TYPE_SPARSEM) {
                GMM_CV_SPLIT_STATIC_TILING_IMP(GMMQuantMixCoreCompute, GMMProcess,
                                               false, false, false, staticCFG, xType, weightType, yType);
            } else if constexpr (TRANS_B == 1 && GROUP_LIST_TYPE != GROUPED_MATMUL_GROUP_LIST_TYPE_SPARSEM) {
                GMM_CV_SPLIT_STATIC_TILING_IMP(GMMQuantMixCoreCompute, GMMProcess,
                                               false, true, false, staticCFGtransB, xType, weightType, yType);
            } else if constexpr (TRANS_B == 0 && GROUP_LIST_TYPE == GROUPED_MATMUL_GROUP_LIST_TYPE_SPARSEM) {
                GMM_CV_SPLIT_STATIC_TILING_IMP(GMMQuantMixCoreCompute, GMMGroupMSparseProcess,
                                               false, false, false, staticCFG, xType, weightType, yType);
            } else if constexpr (TRANS_B == 1 && GROUP_LIST_TYPE == GROUPED_MATMUL_GROUP_LIST_TYPE_SPARSEM) {
                GMM_CV_SPLIT_STATIC_TILING_IMP(GMMQuantMixCoreCompute, GMMGroupMSparseProcess,
                                               false, true, false, staticCFGtransB, xType, weightType, yType);
            }
        }
    }
#elif defined(GMM_A4W4)
    // ---- A4W4: int4 activations and weights ----
    if constexpr (D_T_A == GMM_TPL_INT4 && D_T_B == GMM_TPL_INT4) {
        if constexpr (TRANS_B == 0) {
            GET_TILING_DATA_MEMBER(GMMTilingData, gmmBaseParams, gmmBaseParams_, tiling);
            if (gmmBaseParams_.isA4W4Optimize) {
                // Run-time switch to the Catlass implementation; the layout transform is enabled
                // only for the duration of the call.
                tPipe.Destroy();
                AscendC::SetMMLayoutTransform(true);
#ifndef __CCE_KT_TEST__
                Catlass::grouped_matmul_a4w4_catlass(
                    gmmBaseParams_.m, gmmBaseParams_.k, gmmBaseParams_.n, gmmBaseParams_.groupNum, gmmBaseParams_.quantGroupNum,
                    x, weight, scale, groupList, perTokenScale, y, user1, gmmBaseParams_.coreNum);
#endif
                AscendC::SetMMLayoutTransform(false);
            } else {
                if constexpr (GROUP_LIST_TYPE == GROUPED_MATMUL_GROUP_LIST_TYPE_SPARSEM) {
                    GMM_A4W4_IMP(GMMA4W4SparseCompute, false, false, matmulCFG, xType, weightType, yType);
                } else {
                    GMM_A4W4_IMP(GMMA4W4Compute, false, false, matmulCFG, xType, weightType, yType);
                }
            }
        } else {
            if constexpr (GROUP_LIST_TYPE == GROUPED_MATMUL_GROUP_LIST_TYPE_SPARSEM) {
                GMM_A4W4_IMP(GMMA4W4SparseCompute, false, true, matmulCFG, xType, weightType, yType);
            } else {
                GMM_A4W4_IMP(GMMA4W4Compute, false, true, matmulCFG, xType, weightType, yType);
            }
        }
    }
#elif defined(GMM_QUANT_INT8) || defined(GMM_QUANT_INT32)
    // ---- int8 x int8 with int8 / int32 output (cube only) ----
    if constexpr (D_T_A == GMM_TPL_INT8 && D_T_B == GMM_TPL_INT8 && (D_T_Y == GMM_TPL_INT8 || D_T_Y == GMM_TPL_INT32) &&
                  TRANS_A == 0 && A8W4_KERNEL_TEMPLATE == GROUPED_MATMUL_A8W4_KERNEL_TEMPLATE_NONE &&
                  AIV_AIC_RATIO == GROUPED_MATMUL_CUBE_ONLY) {
        if constexpr (IS_STATIC_TILING_API == 0) {
            if constexpr (TRANS_B == 0 && GROUP_LIST_TYPE != GROUPED_MATMUL_GROUP_LIST_TYPE_SPARSEM) {
                GMM_CUBE_IMP(GMMProcess, false, false, false, matmulCFGUnitFlag);
            } else if constexpr (TRANS_B == 1 && GROUP_LIST_TYPE != GROUPED_MATMUL_GROUP_LIST_TYPE_SPARSEM) {
                GMM_CUBE_IMP(GMMProcess, false, true, false, matmulCFGUnitFlag);
            } else if constexpr (TRANS_B == 0 && GROUP_LIST_TYPE == GROUPED_MATMUL_GROUP_LIST_TYPE_SPARSEM) {
                GMM_CUBE_IMP(GMMGroupMSparseProcess, false, false, false, matmulCFGUnitFlag);
            } else if constexpr (TRANS_B == 1 && GROUP_LIST_TYPE == GROUPED_MATMUL_GROUP_LIST_TYPE_SPARSEM) {
                GMM_CUBE_IMP(GMMGroupMSparseProcess, false, true, false, matmulCFGUnitFlag);
            }
        } else if constexpr (IS_STATIC_TILING_API == 1) {
            if constexpr (TRANS_B == 0 && GROUP_LIST_TYPE != GROUPED_MATMUL_GROUP_LIST_TYPE_SPARSEM) {
                GMM_CUBE_STATIC_TILING_IMP(GMMProcess, false, false, false, staticCFG);
            } else if constexpr (TRANS_B == 1 && GROUP_LIST_TYPE != GROUPED_MATMUL_GROUP_LIST_TYPE_SPARSEM) {
                GMM_CUBE_STATIC_TILING_IMP(GMMProcess, false, true, false, staticCFGtransB);
            } else if constexpr (TRANS_B == 0 && GROUP_LIST_TYPE == GROUPED_MATMUL_GROUP_LIST_TYPE_SPARSEM) {
                GMM_CUBE_STATIC_TILING_IMP(GMMGroupMSparseProcess, false, false, false, staticCFG);
            } else if constexpr (TRANS_B == 1 && GROUP_LIST_TYPE == GROUPED_MATMUL_GROUP_LIST_TYPE_SPARSEM) {
                GMM_CUBE_STATIC_TILING_IMP(GMMGroupMSparseProcess, false, true, false, staticCFGtransB);
            }
        }
    }
#elif defined(GMM_FLOAT)
    // ---- Non-quantised path ----
    // The condition depends only on template parameters, so it is evaluated at compile time and the
    // body is not instantiated for other combinations.
    if constexpr (IS_STATIC_TILING_API == 0 && A8W4_KERNEL_TEMPLATE == GROUPED_MATMUL_A8W4_KERNEL_TEMPLATE_NONE) {
        GET_TILING_DATA_MEMBER(GMMTilingData, gmmBaseParams, gmmBaseParams_, tiling);
        if constexpr (TRANS_A == 0 && TRANS_B == 0 && AIV_AIC_RATIO == GROUPED_MATMUL_CUBE_ONLY) {
            // groupType is a run-time tiling field, hence the ordinary `if`.
            if (GROUP_LIST_TYPE == GROUPED_MATMUL_GROUP_LIST_TYPE_SPARSEM && gmmBaseParams_.groupType == 0) {
                GMM_CUBE_IMP(GMMGroupMSparseProcess, false, false, false, matmulCFGUnitFlag);
            } else {
                GMM_CUBE_IMP(GMMProcess, false, false, false, matmulCFGUnitFlag);
            }
        } else if constexpr (TRANS_A == 0 && TRANS_B == 1 && AIV_AIC_RATIO == GROUPED_MATMUL_CUBE_ONLY) {
            if (GROUP_LIST_TYPE == GROUPED_MATMUL_GROUP_LIST_TYPE_SPARSEM && gmmBaseParams_.groupType == 0) {
                GMM_CUBE_IMP(GMMGroupMSparseProcess, false, true, false, matmulCFGUnitFlag);
            } else {
                GMM_CUBE_IMP(GMMProcess, false, true, false, matmulCFGUnitFlag);
            }
        } else if constexpr (TRANS_A == 1 && AIV_AIC_RATIO == GROUPED_MATMUL_AIV_AIC_RATIO_1) {
            // Transposed x: AIV cores run EmptyTensorCompute on y while AIC cores do the matmul.
            if ASCEND_IS_AIV {
                GET_TILING_DATA(tilingData, tiling);
                EmptyTensorCompute<DTYPE_Y>(groupList, y, &tilingData);
            }
            if ASCEND_IS_AIC {
                if (GROUP_LIST_TYPE == GROUPED_MATMUL_GROUP_LIST_TYPE_SPARSEM && gmmBaseParams_.groupType == 0) {
                    GMM_CUBE_IMP(GMMGroupMSparseProcess, true, false, false, matmulCFG);
                } else {
                    GMM_CUBE_IMP(GMMProcess, true, false, false, matmulCFG);
                }
            }
        }
    }
#endif
#endif

#if defined(__CCE_AICORE__) && __CCE_AICORE__ == 200
#if defined(GMM_FLOAT)
    // ---- Non-quantised path for the __CCE_AICORE__ == 200 build ----
    GET_TILING_DATA_MEMBER(GMMTilingData, gmmBaseParams, gmmBaseParams_, tiling);
    if constexpr (TRANS_A == 0 && TRANS_B == 0) {
        if (GROUP_LIST_TYPE == GROUPED_MATMUL_GROUP_LIST_TYPE_SPARSEM && gmmBaseParams_.groupType == 0) {
            GMM_CUBE_IMP(GMMGroupMSparseProcess, false, false, false, matmulCFG);
        } else {
            GMM_CUBE_IMP(GMMProcess, false, false, false, matmulCFG);
        }
    } else if constexpr (TRANS_A == 0 && TRANS_B == 1) {
        if (GROUP_LIST_TYPE == GROUPED_MATMUL_GROUP_LIST_TYPE_SPARSEM && gmmBaseParams_.groupType == 0) {
            GMM_CUBE_IMP(GMMGroupMSparseProcess, false, true, false, matmulCFG);
        } else {
            GMM_CUBE_IMP(GMMProcess, false, true, false, matmulCFG);
        }
    } else if constexpr (TRANS_A == 1 && TRANS_B == 0) {
        if ASCEND_IS_AIV {
            GET_TILING_DATA(tilingData, tiling);
            EmptyTensorCompute<DTYPE_Y>(groupList, y, &tilingData);
        }
        if ASCEND_IS_AIC {
            if (GROUP_LIST_TYPE == GROUPED_MATMUL_GROUP_LIST_TYPE_SPARSEM && gmmBaseParams_.groupType == 0) {
                GMM_CUBE_IMP(GMMGroupMSparseProcess, true, false, false, matmulCFG);
            } else {
                GMM_CUBE_IMP(GMMProcess, true, false, false, matmulCFG);
            }
        }
    }
#endif
#endif
}