* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
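/*!
 * \file grouped_matmul_utils.h
 * \brief Shared dtype-selection macros, tiling-access macros, constants and helper functions
 *        used by the grouped matmul kernels.
 */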
#ifndef ASCENDC_GROUPED_MATMUL_UTILS_H
#define ASCENDC_GROUPED_MATMUL_UTILS_H
#if ((defined(__CCE_AICORE__) && (__CCE_AICORE__ == 310)) && !(defined(__NPU_ARCH__) && __NPU_ARCH__ == 3113))
#if ASC_DEVKIT_MAJOR >= 9
#include "kernel_basic_intf.h"
#else
#include "kernel_operator.h"
#endif
#else
#include "kernel_operator.h"
#endif
#include "kernel_tiling/kernel_tiling.h"
#include "lib/matmul_intf.h"
// Dtype-driven feature selection for __CCE_AICORE__ == 310 builds (excluding __NPU_ARCH__ 3113).
// Everything below depends on ORIG_DTYPE_* / DT_* macros supplied by the build.
#if ((defined(__CCE_AICORE__) && (__CCE_AICORE__ == 310)) && !(defined(__NPU_ARCH__) && __NPU_ARCH__ == 3113))
// L0C local element type: int32_t when x is int8, float otherwise.
#if defined(ORIG_DTYPE_X) && defined(DT_INT8) && ORIG_DTYPE_X == DT_INT8
#define DTYPE_L0C_LOCAL int32_t
#else
#define DTYPE_L0C_LOCAL float
#endif
// V310_GMM_QUANT: x and weight come from the same low-precision family
// (int8/int8, hifloat8/hifloat8, fp8 e5m2|e4m3fn pairs, fp4 e2m1|e1m2 pairs, int4/int4).
// NOTE(review): DT_FLOAT4_E1M2 is not in the defined() guard list; an undefined identifier
// evaluates to 0 inside #if, so confirm DT_FLOAT4_E1M2 is always defined when this is reached.
#if defined(ORIG_DTYPE_X) && defined(ORIG_DTYPE_WEIGHT) && defined(DT_FLOAT8_E5M2) && defined(DT_FLOAT8_E4M3FN) && \
defined(DT_HIFLOAT8) && defined(DT_INT8) && defined(DT_FLOAT4_E2M1) && defined(DT_INT4) && \
((ORIG_DTYPE_X == DT_INT8 && ORIG_DTYPE_WEIGHT == DT_INT8) || \
(ORIG_DTYPE_X == DT_HIFLOAT8 && ORIG_DTYPE_WEIGHT == DT_HIFLOAT8) || \
((ORIG_DTYPE_X == DT_FLOAT8_E5M2 || ORIG_DTYPE_X == DT_FLOAT8_E4M3FN) && \
(ORIG_DTYPE_WEIGHT == DT_FLOAT8_E5M2 || ORIG_DTYPE_WEIGHT == DT_FLOAT8_E4M3FN)) || \
((ORIG_DTYPE_X == DT_FLOAT4_E2M1 || ORIG_DTYPE_X == DT_FLOAT4_E1M2) && \
(ORIG_DTYPE_WEIGHT == DT_FLOAT4_E2M1 || ORIG_DTYPE_WEIGHT == DT_FLOAT4_E1M2)) || \
(ORIG_DTYPE_X == DT_INT4 && ORIG_DTYPE_WEIGHT == DT_INT4))
#define V310_GMM_QUANT
// The scale dtype selects the quant sub-mode:
//   scale == float8_e8m0             -> V310_GMM_QUANT_MX
//   scale neither uint64 nor int64   -> V310_GMM_QUANT_MIX + V310_GMM_QUANT_PERTENSOR_CUBE
//                                       (+ V310_GMM_QUANT_PERTILE when x is not int8 and scale is float)
//   otherwise (incl. ORIG_DTYPE_SCALE undefined) -> V310_GMM_QUANT_CUBE
#if defined(ORIG_DTYPE_SCALE) && defined(DT_FLOAT8_E8M0) && ORIG_DTYPE_SCALE == DT_FLOAT8_E8M0
#define V310_GMM_QUANT_MX
#elif defined(ORIG_DTYPE_SCALE) && defined(DT_UINT64) && defined(DT_INT64) && \
(ORIG_DTYPE_SCALE != DT_UINT64 && ORIG_DTYPE_SCALE != DT_INT64)
#define V310_GMM_QUANT_MIX
#define V310_GMM_QUANT_PERTENSOR_CUBE
// NOTE(review): DT_FLOAT is not checked with defined(); if it were undefined it would compare as 0.
#if (ORIG_DTYPE_X != DT_INT8 && ORIG_DTYPE_SCALE == DT_FLOAT)
#define V310_GMM_QUANT_PERTILE
#endif
#else
#define V310_GMM_QUANT_CUBE
#endif
#endif
// V310_GMM_ANTI_QUANT: x and weight dtypes differ in one of the listed combinations
// (fp16/bf16 x with fp8/hifloat8/int8/fp4/float/int32/int4 weight, int8 x with int4/int32 weight,
// fp8_e4m3fn x with fp4_e2m1/float weight). The float/int32 weight dtypes presumably denote
// packed low-bit weights — TODO confirm.
#if defined(ORIG_DTYPE_X) && defined(ORIG_DTYPE_WEIGHT) && ORIG_DTYPE_X != ORIG_DTYPE_WEIGHT
#if ((ORIG_DTYPE_X == DT_FLOAT16 || ORIG_DTYPE_X == DT_BF16) && \
(ORIG_DTYPE_WEIGHT == DT_FLOAT8_E5M2 || ORIG_DTYPE_WEIGHT == DT_FLOAT8_E4M3FN || \
ORIG_DTYPE_WEIGHT == DT_HIFLOAT8 || ORIG_DTYPE_WEIGHT == DT_INT8 || ORIG_DTYPE_WEIGHT == DT_FLOAT4_E2M1 || \
ORIG_DTYPE_WEIGHT == DT_FLOAT || ORIG_DTYPE_WEIGHT == DT_INT32 || \
ORIG_DTYPE_WEIGHT == DT_INT4)) || \
(ORIG_DTYPE_X == DT_INT8 && (ORIG_DTYPE_WEIGHT == DT_INT4 || ORIG_DTYPE_WEIGHT == DT_INT32)) || \
(ORIG_DTYPE_X == DT_FLOAT8_E4M3FN && (ORIG_DTYPE_WEIGHT == DT_FLOAT4_E2M1 || ORIG_DTYPE_WEIGHT == DT_FLOAT))
#define V310_GMM_ANTI_QUANT
#endif
#endif
#endif
// Dtype-driven mode selection, evaluated on every architecture. Requires the build to define
// ORIG_DTYPE_X / ORIG_DTYPE_WEIGHT / ORIG_DTYPE_Y together with the DT_* dtype constants;
// otherwise no GMM_* mode macro is defined here.
#if defined(ORIG_DTYPE_X) && defined(ORIG_DTYPE_WEIGHT) && defined(ORIG_DTYPE_Y) && defined(DT_INT8) && \
defined(DT_BF16) && defined(DT_INT4)
#if ORIG_DTYPE_X == ORIG_DTYPE_WEIGHT
// x and weight share a dtype.
#if ORIG_DTYPE_X == DT_INT8
// int8 x int8: the y dtype picks the variant; bf16/fp16 outputs use an int32_t matmul result.
#if ORIG_DTYPE_Y == DT_BF16
#define GMM_QUANT_BF16
#define MM_DTYPE_Y int32_t
#elif ORIG_DTYPE_Y == DT_FLOAT16
#define GMM_QUANT_FLOAT16
#define MM_DTYPE_Y int32_t
#elif ORIG_DTYPE_Y == DT_INT32
#define GMM_QUANT_INT32
#else
#define GMM_QUANT_INT8
#endif
#elif ORIG_DTYPE_X == DT_INT4
// int4 x int4: the matmul result type is half; the y dtype selects the bf16 / fp16 variant.
#define GMM_A4W4
#define MM_DTYPE_Y half
#if ORIG_DTYPE_Y == DT_BF16
#define GMM_A4W4_BF16
#elif ORIG_DTYPE_Y == DT_FLOAT16
#define GMM_A4W4_FP16
#endif
#else
// Any other same-dtype pair takes the plain floating-point path.
#define GMM_FLOAT
#endif
#else
// x and weight dtypes differ: anti-quant path. GMM_ANTI_QUANT is defined unconditionally here,
// so the nested branch below only adds the A8W4-specific macros.
#define GMM_ANTI_QUANT
#if ORIG_DTYPE_X == DT_INT8 && ORIG_DTYPE_WEIGHT == DT_INT4
// int8 x int4 (A8W4) uses the MSD variant with an int32_t matmul result. Any y dtype other
// than bf16 selects the fp16-output variant.
#define GMM_ANTI_QUANT_A8W4_MSD
#define GMM_ANTI_QUANT_A8W4
#if ORIG_DTYPE_Y == DT_BF16
#define GMM_ANTI_QUANT_A8W4_MSD_OUT_BF16
#else
#define GMM_ANTI_QUANT_A8W4_MSD_OUT_FP16
#endif
#define MM_DTYPE_Y int32_t
#endif
#endif
#endif
// Default: when no mode above chose a matmul result type, use the operator's output type.
#if defined(DTYPE_Y) && !defined(MM_DTYPE_Y)
#define MM_DTYPE_Y DTYPE_Y
#endif
// On __CCE_AICORE__ 220 or __NPU_ARCH__ 3003/3113, the int8 x int4 MSD anti-quant path replaces
// any DTYPE_SCALE supplied by the build: bfloat16_t for the bf16-output variant, float for the
// fp16-output variant. Other builds keep DTYPE_SCALE unchanged.
#if (defined(__CCE_AICORE__) && __CCE_AICORE__ == 220) || (defined(__NPU_ARCH__) && (__NPU_ARCH__ == 3003 || __NPU_ARCH__ == 3113))
#ifdef GMM_ANTI_QUANT_A8W4_MSD_OUT_BF16
#undef DTYPE_SCALE
#define DTYPE_SCALE bfloat16_t
#elif defined(GMM_ANTI_QUANT_A8W4_MSD_OUT_FP16)
#undef DTYPE_SCALE
#define DTYPE_SCALE float
#endif
#endif
// TILING_TYPE: element type used for tiling data. It is `const int32_t` when CONST_TILING is
// defined, otherwise `__gm__ int32_t` (global-memory qualified).
#if defined(CONST_TILING)
#define TILING_TYPE const int32_t
#else
#define TILING_TYPE __gm__ int32_t
#endif
// GET_TILING_DATA_MEMBER_ADDR(tilingType, member, var, tiling) declares `var` as an int32_t pointer
// to tiling member `member`:
//  - CONST_TILING: GET_TILING_DATA_MEMBER is asked to populate a local named `obj`, and `var`
//    points at `obj`. When V310_GMM_ANTI_QUANT is defined, the member is looked up in
//    GMMWeightQuantTilingData and the tilingType argument is ignored.
//  - otherwise: `var` points into `tiling` at the byte offset of `member` within `tilingType`
//    (the offset is kept in a local named offset##var). This presumably expects `tiling` to be a
//    byte pointer such as GM_ADDR, since the offset is added without scaling — TODO confirm.
// NOTE(review): the CONST_TILING variants always declare a local called `obj`, so they can be
// expanded at most once per scope, unlike the non-const variant whose helper local is suffixed
// with `var` — confirm no caller expands the macro twice in one scope.
// NOTE(review): the offset uses the null-pointer member-access idiom; offsetof is the standard
// form (valid for standard-layout tilingType).
#if defined(CONST_TILING)
#if defined(V310_GMM_ANTI_QUANT)
#define GET_TILING_DATA_MEMBER_ADDR(tilingType, member, var, tiling) \
GET_TILING_DATA_MEMBER(GMMWeightQuantTilingData, member, obj, tiling); \
const int32_t* (var) = (const int32_t*)((const uint8_t*)&obj);
#else
#define GET_TILING_DATA_MEMBER_ADDR(tilingType, member, var, tiling) \
GET_TILING_DATA_MEMBER(tilingType, member, obj, tiling); \
const int32_t* (var) = (const int32_t*)((const uint8_t*)&obj);
#endif
#else
#define GET_TILING_DATA_MEMBER_ADDR(tilingType, member, var, tiling) \
size_t offset##var = (size_t)(&((tilingType*)0)->member); \
__gm__ int32_t* (var) = (__gm__ int32_t*)((tiling) + (offset##var));
#endif
namespace GROUPED_MATMUL {
using namespace AscendC;

// Number of bits in one byte; used by GetTypeBits().
constexpr uint32_t INT8_BITS = 8;
constexpr int32_t MKN_LIST_LEN = 128;
// UB block sizes: one block, two blocks, half a block (presumably in bytes — TODO confirm).
constexpr uint32_t UB_BLOCK_UNIT_SIZE = 32;
constexpr uint32_t UB_BLOCK_DOUBLE_UNIT_SIZE = 64;
constexpr uint32_t HALF_UB_BLOCK_UNIT_SIZE = UB_BLOCK_UNIT_SIZE / 2;

// Matmul configurations. The int8 x int8 build on __CCE_AICORE__ 220 or __NPU_ARCH__ 3003/3113 passes
// extra GetMDLConfig arguments and enables K-dim reorder load; every other build uses the shorter forms.
#if ((defined(__CCE_AICORE__) && __CCE_AICORE__ == 220) || (defined(__NPU_ARCH__) && (__NPU_ARCH__ == 3003 || __NPU_ARCH__ == 3113))) && \
    defined(ORIG_DTYPE_X) && defined(ORIG_DTYPE_WEIGHT) && \
    ORIG_DTYPE_X == DT_INT8 && ORIG_DTYPE_WEIGHT == DT_INT8
constexpr MatmulConfig NZ_CFG_MDL =
    GetMDLConfig(false, false, 0, true, false, false, true, true, true, false, false, true);
constexpr MatmulConfig matmulCFGUnitFlag{.doMultiDataLoad = true, .enUnitFlag = true, .enableKdimReorderLoad = true};
#else
constexpr MatmulConfig NZ_CFG_MDL = GetMDLConfig(false, false, 0, true, false, false, true);
// Positional aggregate initialization: depends on the exact member order of MatmulConfig.
constexpr MatmulConfig matmulCFGUnitFlag{false, false, true, 0, 0, 0, false, false, false, false, false, 0, 0, 0,
                                         0, 0, 0, 0, true};
#endif

// Cross-core synchronization flag IDs and mode (directions as named: AIV->AIC and AIC->AIV).
constexpr uint64_t SYNC_AIV_AIC_FLAG = 3;
constexpr uint64_t SYNC_AIC_AIV_FLAG = 5;
constexpr uint64_t SYNC_MODE2 = 2;
// Basic block edge lengths used by GenGmmConf().
constexpr uint32_t BASIC_BLOCK_SIZE_128 = 128;
constexpr uint32_t BASIC_BLOCK_SIZE_256 = 256;
// Static-tiling parameters.
constexpr int32_t STATIC_TILING_DEPTH_A1_B1 = 8;
constexpr int32_t STATIC_TILING_STEP_KA_KB = 4;
constexpr uint64_t DOUBLE_BUFFER_L0A_L0B = 2;
// Upper bounds baked into GenGmmConf() as singleCoreK / singleCoreN.
constexpr uint32_t STATIC_TILING_MAX_K = 8192;
constexpr uint32_t STATIC_TILING_MAX_SINGLE_N = 1024;
// groupListType value selecting the sparse group-list layout, and the element offset added to
// the group index when reading a split value from that layout.
constexpr uint32_t GROUP_LIST_TYPE_SPARSE = 2;
constexpr uint32_t SPARSE_GROUP_LIST_SPLIT_VALUE_OFFSET = 1;

// Bundles A/B/C/bias type descriptors with the matmul::Matmul instantiation built from them.
template<class AT_, class BT_, class CT_, class BiasT_, const auto& MM_CFG = CFG_MDL>
struct MMType {
    using AT = AT_;
    using BT = BT_;
    using CT = CT_;
    using BiasT = BiasT_;
    using MT = matmul::Matmul<AT, BT, CT, BiasT, MM_CFG>;
};

// Same as MMType, except that MT names matmul::MatmulImpl instead of matmul::Matmul.
template<class AT_, class BT_, class CT_, class BiasT_, const auto& MM_CFG = CFG_MDL>
struct MMImplType {
    using AT = AT_;
    using BT = BT_;
    using CT = CT_;
    using BiasT = BiasT_;
    using MT = matmul::MatmulImpl<AT, BT, CT, BiasT, MM_CFG>;
};

// Activation function selector, stored as a uint8_t.
enum class ActiveType : std::uint8_t {
    INVALID_TYPE = 0,
    RELU,
    GELU_TANH,
    GELU_ERR_FUNC,
    FASTGELU,
    SILU
};

/**
 * Greatest common divisor via Euclid's algorithm.
 * Returns the other argument when one input is zero, and 0 when both are zero.
 */
template <typename T>
__aicore__ inline T GreatestCommonDivisor(T a, T b) {
    T c = a;
    if (a < b) {
        a = b;
        b = c;
    }
    while (b != 0) {
        c = a;
        a = b;
        b = c % b;
    }
    return a;
}

/**
 * Least common multiple of a and b.
 * Divides by the GCD before multiplying, so the result is exact whenever it fits in T
 * (the former a * b / gcd form overflowed on the intermediate product).
 * Returns 0 when both inputs are zero instead of dividing by zero.
 */
template <typename T>
__aicore__ inline T LeastCommonMultiple(T a, T b) {
    const T divisor = GreatestCommonDivisor(a, b);
    if (unlikely(divisor == 0)) {
        return static_cast<T>(0);
    }
    return a / divisor * b;
}

// Larger of a and b (returns b when they compare equal).
template <typename T>
__aicore__ inline T Max(T a, T b) {
    return a > b ? a : b;
}

// Smaller of a and b (returns a when they compare equal).
template <typename T>
__aicore__ inline T Min(T a, T b) {
    return a > b ? b : a;
}

// Rounds a up to the next multiple of the compile-time constant base.
template <uint32_t base, typename T = uint32_t>
__aicore__ inline T AlignUp(T a) {
    static_assert(base != 0, "AlignUp: base must be non-zero");
    return (a + base - 1) / base * base;
}

// Rounds a up to the next multiple of base; returns a unchanged when base is 0
// (matching AlignDown) instead of dividing by zero.
template <typename T>
__aicore__ inline T AlignUp(T a, T base) {
    if (unlikely(base == 0)) {
        return a;
    }
    return (a + base - 1) / base * base;
}

// Rounds a down to a multiple of base; returns a unchanged when base is 0.
template <typename T>
__aicore__ inline T AlignDown(T a, T base) {
    if (unlikely(base == 0)) {
        return a;
    }
    return a / base * base;
}

// Power-of-two specializations of AlignUp<base>: round up with a mask instead of a division.
template <>
__aicore__ inline uint32_t AlignUp<4, uint32_t>(uint32_t a) {
    return (a + 3U) & ~3U;
}
template <>
__aicore__ inline uint32_t AlignUp<8, uint32_t>(uint32_t a) {
    return (a + 7U) & ~7U;
}
template <>
__aicore__ inline uint32_t AlignUp<16, uint32_t>(uint32_t a) {
    return (a + 15U) & ~15U;
}
template <>
__aicore__ inline uint32_t AlignUp<32, uint32_t>(uint32_t a) {
    return (a + 31U) & ~31U;
}

/**
 * Returns the address of tensor `index` from a tensor-list descriptor at `tensorPtr`.
 * The descriptor's first uint64 word is treated as a byte offset (>> 3 turns it into a uint64
 * index) to an array of 64-bit addresses; entry `index` of that array is returned as a T pointer.
 */
template <typename T>
__aicore__ inline __gm__ T* GetTensorAddr(uint16_t index, GM_ADDR tensorPtr) {
    __gm__ uint64_t* dataAddr = reinterpret_cast<__gm__ uint64_t*>(tensorPtr);
    uint64_t tensorPtrOffset = *dataAddr;
    __gm__ uint64_t* retPtr = dataAddr + (tensorPtrOffset >> 3);
    return reinterpret_cast<__gm__ T*>(*(retPtr + index));
}

#if (defined(__CCE_AICORE__) && __CCE_AICORE__ != 310) || (defined(__NPU_ARCH__) && __NPU_ARCH__ == 3113)
/**
 * Returns the size of group `groupIdx` read from groupListGm.
 *  - groupType == -1: returns 0.
 *  - groupListType == 0: entries are running offsets; returns entry - preOffset and stores the
 *    entry in preOffset for the next call.
 *  - groupListType == GROUP_LIST_TYPE_SPARSE: returns entry groupIdx + SPARSE_GROUP_LIST_SPLIT_VALUE_OFFSET.
 *  - otherwise: returns entry groupIdx directly.
 * Entries are int64 and are narrowed to int32_t.
 */
__aicore__ inline int32_t GetSplitValueFromGroupList(uint32_t groupIdx, int32_t &preOffset,
                                                     const GMMBaseParams* __restrict &gmmBaseParams,
                                                     const GlobalTensor<int64_t> &groupListGm) {
    int32_t splitValue = 0;
    if (likely(gmmBaseParams->groupType != -1)) {
        if (gmmBaseParams->groupListType == 0) {
            int32_t offset = static_cast<int32_t>(groupListGm.GetValue(groupIdx));
            splitValue = offset - preOffset;
            preOffset = offset;
        } else if (gmmBaseParams->groupListType == GROUP_LIST_TYPE_SPARSE) {
            splitValue = static_cast<int32_t>(groupListGm.GetValue(groupIdx + SPARSE_GROUP_LIST_SPLIT_VALUE_OFFSET));
        } else {
            splitValue = static_cast<int32_t>(groupListGm.GetValue(groupIdx));
        }
    }
    return splitValue;
}
#endif

// Bit width of T: 4 for int4b_t (which sizeof cannot express), otherwise sizeof(T) * 8.
template <typename T>
__aicore__ inline constexpr uint32_t GetTypeBits() {
    if constexpr (IsSameType<T, int4b_t>::value) {
        return 4;
    } else {
        return sizeof(T) * INT8_BITS;
    }
}

// Builds the MatmulConfig for static-tiling GMM: 128x256x128 basic blocks, singleCoreM = 128,
// singleCoreN / singleCoreK capped at STATIC_TILING_MAX_SINGLE_N / STATIC_TILING_MAX_K, unit flag on.
// isND2NZ sets enVecND2NZ.
__aicore__ static constexpr MatmulConfig GenGmmConf(bool isND2NZ) {
    return {
        .doNorm = false,
        .doBasicBlock = false,
        .doMultiDataLoad = true,
        .basicM = BASIC_BLOCK_SIZE_128,
        .basicN = BASIC_BLOCK_SIZE_256,
        .basicK = BASIC_BLOCK_SIZE_128,
        .intrinsicsCheck = false,
        .isNBatch = false,
        .enVecND2NZ = isND2NZ,
        .doSpecialBasicBlock = false,
        .doMTE2Preload = 0,
        .singleCoreM = BASIC_BLOCK_SIZE_128,
        .singleCoreN = STATIC_TILING_MAX_SINGLE_N,
        .singleCoreK = STATIC_TILING_MAX_K,
        .stepM = 0,
        .stepN = 0,
        .baseMN = 0,
        .singleCoreMN = 0,
        .enUnitFlag = true,
        .isPerTensor = false,
        .hasAntiQuantOffset = false,
        .doIBShareNorm = false,
        .doSpecialMDL = false,
        .enableInit = false,
        .batchMode = BatchMode::NONE,
        .enableEnd = true,
        .enableGetTensorC = true,
        .enableSetOrgShape = true,
        .enableSetBias = false,
        .enableSetTail = true,
        .enableQuantVector = false,
        .enableSetDefineData = false,
        .iterateMode = IterateMode::ITERATE_MODE_DEFAULT,
        .enableReuse = true,
        .enableUBReuse = true,
        .enableL1CacheUB = false,
        .intraBlockPartSum = false,
        .iterateOrder = IterateOrder::UNDEF,
        .scheduleType = ScheduleType::INNER_PRODUCT,
        .enableDoubleCache = false,
        .isBiasBatch = true,
        .enableStaticPadZeros = false,
        .isA2B2Shared = false,
        .enableKdimReorderLoad = false,
        .isCO1Shared = false,
    };
}
}
#endif