* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file kernel_operator_mm_check.h
* \brief Parameter and tensor validation helpers for the Mmad, Fixpipe and LoadData APIs.
*/
#if !defined(__ASCENDC_INCLUDE_INTERNAL_HEADERS__)
#pragma message("impl/basic_api/kernel_operator_mm_check.h is an internal header file and must not be used directly. Functions or variables defined in this file may be removed in the future. Please use \"#include \"basic_api/kernel_operator_mm_intf.h\"\" and use public functions or variables defined in interface headers files.")
#define __ASCENDC_INCLUDE_INTERNAL_HEADERS__
#define __UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_KERNEL_OPERATOR_MM_CHECK_H__
#endif
#ifndef ASCENDC_MODULE_OPERATOR_MM_CHECK_H
#define ASCENDC_MODULE_OPERATOR_MM_CHECK_H
#include "kernel_check.h"
#include "kernel_npu_debug.h"
#include "kernel_log.h"
#include "kernel_struct_mm.h"
#include "kernel_struct_fixpipe.h"
namespace AscendC {
/*!
 * \brief Tells whether channelSize, reduced modulo the per-block element count of T, equals one of
 *        the candidate remainders.
 *
 * The per-block element count is ONE_BLK_SIZE / sizeof(T), except for int4b_t where it is fixed to 64.
 *
 * \tparam T element type used to derive the per-block element count
 * \param [in] channelSize value whose remainder is examined
 * \param [in] remainder array of accepted remainders (only read)
 * \param [in] size number of entries of remainder to inspect
 * \return true if any of the first size entries matches, false otherwise (including size == 0)
 */
template <typename T>
__aicore__ static inline bool ChannelSizeRemainder(const uint16_t channelSize, uint16_t remainder[], uint16_t size)
{
    uint16_t elemsPerBlock = ONE_BLK_SIZE / sizeof(T);
    if constexpr (IsSameType<T, int4b_t>::value) {
        // int4b_t does not follow the sizeof-based formula; the count is hard-coded.
        elemsPerBlock = 64;
    }
    bool matched = false;
    // Stop scanning as soon as one accepted remainder is found.
    for (uint16_t idx = 0; !matched && idx < size; ++idx) {
        matched = (channelSize % elemsPerBlock == remainder[idx]);
    }
    return matched;
}
/*!
 * \brief Runs the alignment checks for the three Mmad tensors through CheckTensorAlign.
 *
 * dst is checked against VALUE_512 when dst, fm and filter all have half as their primitive type,
 * otherwise against 1024. fm and filter are always checked against VALUE_512.
 *
 * \tparam T element type of dst
 * \tparam U element type of fm
 * \tparam S element type of filter
 * \param [in] dst destination tensor
 * \param [in] fm feature-map tensor
 * \param [in] filter filter tensor
 */
template <typename T, typename U, typename S>
__aicore__ static inline void CheckMmadAlign(const LocalTensor<T>& dst, const LocalTensor<U>& fm,
    const LocalTensor<S>& filter)
{
    constexpr uint64_t align1024B = 1024;
    constexpr bool allOperandsHalf = (IsSameType<PrimT<U>, half>::value) &&
        (IsSameType<PrimT<S>, half>::value) && (IsSameType<PrimT<T>, half>::value);
    if constexpr (allOperandsHalf) {
        CheckTensorAlign<T>(dst, VALUE_512, "dst", "Mmad");
    } else {
        CheckTensorAlign<T>(dst, align1024B, "dst", "Mmad");
    }
    CheckTensorAlign<U>(fm, VALUE_512, "fm", "Mmad");
    CheckTensorAlign<S>(filter, VALUE_512, "filter", "Mmad");
}
/*!
 * \brief Validates the MmadParams fields m, n, k and unitFlag.
 * \param [in] mmadParams parameters to validate
 * \param [in] apiName API name forwarded to the checkers and printed in the unitFlag error message
 */
__aicore__ inline void CheckMmadParamsCommon(const MmadParams& mmadParams, const __gm__ char* apiName)
{
// m, n and k are each range-checked against [0, UINT12_MAX].
CheckValueRange<uint16_t>(mmadParams.m, 0, UINT12_MAX, "m", apiName);
CheckValueRange<uint16_t>(mmadParams.n, 0, UINT12_MAX, "n", apiName);
CheckValueRange<uint16_t>(mmadParams.k, 0, UINT12_MAX, "k", apiName);
// unitFlag must be one of 0, 2 or 3 (1 is rejected).
ASCENDC_DEBUG_ASSERT((mmadParams.unitFlag == 0 || mmadParams.unitFlag == 2 || mmadParams.unitFlag == 3),
KERNEL_LOG_INTERNAL(KERNEL_ERROR, "Failed to check unitFlag value in %s, supported values are 0, 2, and 3.\n",
apiName));
// NOTE(review): ReportNopWarning presumably warns when the value turns the instruction into a no-op
// (e.g. 0) - confirm against its definition.
ReportNopWarning<uint16_t>(mmadParams.m, "mmadParams.m", apiName);
ReportNopWarning<uint16_t>(mmadParams.n, "mmadParams.n", apiName);
ReportNopWarning<uint16_t>(mmadParams.k, "mmadParams.k", apiName);
}
/*!
 * \brief Validates the parameters and the dst / fm / filter tensors of an Mmad call without bias.
 *
 * Order of checks: MmadParams fields, then the physical position of each tensor
 * (dst: L0C, fm: L0A, filter: L0B), then the alignment of each tensor
 * (dst: 1024, fm and filter: VALUE_512).
 *
 * \param [in] dst destination tensor
 * \param [in] fm feature-map tensor
 * \param [in] filter filter tensor
 * \param [in] mmadParams Mmad parameters, checked by CheckMmadParamsCommon
 * \param [in] apiName API name forwarded to every checker
 */
template <typename T, typename U, typename S>
__aicore__ inline void CheckMmadTensorCommon(const LocalTensor<T>& dst, const LocalTensor<U>& fm,
    const LocalTensor<S>& filter, const MmadParams& mmadParams, const __gm__ char* apiName)
{
    CheckMmadParamsCommon(mmadParams, apiName);

    // Physical position of each operand.
    CheckTensorPhyPosition<Hardware::L0C>(dst, "dstLocal", "CO1", apiName);
    CheckTensorPhyPosition<Hardware::L0A>(fm, "fmLocal", "A2", apiName);
    CheckTensorPhyPosition<Hardware::L0B>(filter, "filterLocal", "B2", apiName);

    // Address alignment of each operand.
    constexpr uint32_t dstAlignment = 1024;
    CheckTensorAlignment(dst, dstAlignment, "dst", apiName);
    CheckTensorAlignment(fm, VALUE_512, "fm", apiName);
    CheckTensorAlignment(filter, VALUE_512, "filter", apiName);
}
/*!
 * \brief Validates an Mmad call that carries a bias tensor.
 *
 * Runs the bias-less CheckMmadTensorCommon overload first, then checks the bias tensor's alignment
 * (128) and its physical position.
 *
 * \param [in] dst destination tensor
 * \param [in] fm feature-map tensor
 * \param [in] filter filter tensor
 * \param [in] bias bias tensor
 * \param [in] mmadParams Mmad parameters
 * \param [in] apiName API name forwarded to every checker
 */
template <typename T, typename U, typename S, typename V>
__aicore__ inline void CheckMmadTensorCommon(const LocalTensor<T>& dst, const LocalTensor<U>& fm,
const LocalTensor<S>& filter, const LocalTensor<V>& bias, const MmadParams& mmadParams, const __gm__ char* apiName)
{
CheckMmadTensorCommon(dst, fm, filter, mmadParams, apiName);
CheckTensorAlignment(bias, 128, "bias", apiName);
// On __NPU_ARCH__ 1001 / 2002 bias is only accepted in L0C; on every other arch L0C or BIAS is accepted.
#if defined(__NPU_ARCH__) && ((__NPU_ARCH__ == 1001) || (__NPU_ARCH__ == 2002))
CheckTensorPhyPosition<Hardware::L0C>(bias, "bias", "CO1", apiName);
#else
CheckTensorPhyPosition<Hardware::L0C, Hardware::BIAS>(bias, "bias", "CO1 / C2", apiName);
#endif
}
/*!
 * \brief Asserts that quantPre is one of the modes accepted when a cbufWorkspace is supplied:
 *        VDEQF16, VQF322B8_PRE or VREQ8.
 * \param [in] quantPre quantization mode to validate
 * \param [in] apiName API name printed in the error message
 */
__aicore__ inline void CheckFixpipeQuantPreWithWorkspaceCommon(const QuantMode_t quantPre, const __gm__ char* apiName)
{
ASCENDC_DEBUG_ASSERT((quantPre == QuantMode_t::VDEQF16 || quantPre == QuantMode_t::VQF322B8_PRE ||
quantPre == QuantMode_t::VREQ8), KERNEL_LOG_INTERNAL(KERNEL_ERROR, "Failed to check quantPre value in %s, "
"when cbufWorkspace is given, supported values are VDEQF16 / VQF322B8_PRE / VREQ8.\n", apiName));
}
/*!
 * \brief Asserts that quantPre is one of the nine quantization modes accepted by the Fixpipe checks:
 *        NoQuant, F322F16, F322BF16, DEQF16, VDEQF16, QF322B8_PRE, VQF322B8_PRE, REQ8 or VREQ8.
 * \param [in] quantPre quantization mode to validate
 * \param [in] apiName API name printed in the error message
 */
__aicore__ inline void CheckFixpipeQuantPreValid(const QuantMode_t quantPre, const __gm__ char* apiName)
{
ASCENDC_DEBUG_ASSERT((quantPre == QuantMode_t::NoQuant || quantPre == QuantMode_t::F322F16 ||
quantPre == QuantMode_t::F322BF16 || quantPre == QuantMode_t::DEQF16 ||
quantPre == QuantMode_t::VDEQF16 || quantPre == QuantMode_t::QF322B8_PRE ||
quantPre == QuantMode_t::VQF322B8_PRE || quantPre == QuantMode_t::REQ8 ||
quantPre == QuantMode_t::VREQ8), KERNEL_LOG_INTERNAL(KERNEL_ERROR,
"Failed to check quantPre value in %s, supported values are NoQuant / F322F16 / F322BF16 / DEQF16 / "
"VDEQF16 / QF322B8_PRE / VQF322B8_PRE / REQ8 / VREQ8.\n", apiName));
}
/*!
 * \brief Validates quantPre against the src / dst element types of a Fixpipe call.
 *
 * First asserts that quantPre is a generally supported mode (CheckFixpipeQuantPreValid), then applies
 * the restriction for the matching (src, dst) pair. Type pairs not listed below get no extra check.
 *   - float -> int8_t / uint8_t : QF322B8_PRE, VQF322B8_PRE
 *   - float -> half             : F322F16
 *   - float -> bfloat16_t       : F322BF16 (branch compiled only for __NPU_ARCH__ 2201 / 3510 / 5102)
 *   - int32_t -> int8_t / uint8_t : REQ8, VREQ8
 *   - int32_t -> half           : DEQF16, VDEQF16
 *
 * \tparam T dst element type (its PrimT is inspected)
 * \tparam U src element type (its PrimT is inspected)
 * \param [in] quantPre quantization mode to validate
 * \param [in] apiName API name printed in the error messages
 */
template <typename T, typename U>
__aicore__ inline void CheckFixpipeQuantPreCommon(const QuantMode_t quantPre, const __gm__ char* apiName)
{
CheckFixpipeQuantPreValid(quantPre, apiName);
if constexpr (IsSameType<PrimT<U>, float>::value && SupportType<PrimT<T>, int8_t, uint8_t>()) {
ASCENDC_DEBUG_ASSERT((quantPre == QuantMode_t::QF322B8_PRE ||
quantPre == QuantMode_t::VQF322B8_PRE), KERNEL_LOG_INTERNAL(KERNEL_ERROR,
"Failed to check quantPre value in %s, when src is float and dst is int8_t / uint8_t, supported values "
"are QF322B8_PRE and VQF322B8_PRE.\n", apiName));
} else if constexpr (IsSameType<PrimT<U>, float>::value && IsSameType<PrimT<T>, half>::value) {
ASCENDC_DEBUG_ASSERT((quantPre == QuantMode_t::F322F16), KERNEL_LOG_INTERNAL(KERNEL_ERROR,
"Failed to check quantPre value in %s, when src is float and dst is half, supported value is F322F16.\n",
apiName));
// The bfloat16_t branch is spliced into the else-if chain only on the architectures listed below.
#if defined(__NPU_ARCH__) && ((__NPU_ARCH__ == 2201) || (__NPU_ARCH__ == 3510) || (__NPU_ARCH__ == 5102))
} else if constexpr (IsSameType<PrimT<U>, float>::value && IsSameType<PrimT<T>, bfloat16_t>::value) {
ASCENDC_DEBUG_ASSERT((quantPre == QuantMode_t::F322BF16), KERNEL_LOG_INTERNAL(KERNEL_ERROR,
"Failed to check quantPre value in %s, when src is float and dst is bfloat16_t, supported value is "
"F322BF16.\n", apiName));
#endif
} else if constexpr (IsSameType<PrimT<U>, int32_t>::value && SupportType<PrimT<T>, int8_t, uint8_t>()) {
ASCENDC_DEBUG_ASSERT((quantPre == QuantMode_t::REQ8 ||
quantPre == QuantMode_t::VREQ8), KERNEL_LOG_INTERNAL(KERNEL_ERROR,
"Failed to check quantPre value in %s, when src is int32_t and dst is int8_t / uint8_t, supported values "
"are REQ8 and VREQ8.\n", apiName));
} else if constexpr (IsSameType<PrimT<U>, int32_t>::value && IsSameType<PrimT<T>, half>::value) {
ASCENDC_DEBUG_ASSERT((quantPre == QuantMode_t::DEQF16 ||
quantPre == QuantMode_t::VDEQF16), KERNEL_LOG_INTERNAL(KERNEL_ERROR,
"Failed to check quantPre value in %s, when src is int32_t and dst is half, supported values are DEQF16 "
"and VDEQF16.\n", apiName));
}
}
/*!
 * \brief Validates the fields of FixpipeParamsV220 for a Fixpipe call.
 *
 * Checks, in order: nSize (rule depends on isChannelSplit and config.format), mSize, dstStride,
 * srcNdStride / dstNdStride (only when ndNum > 1), unitFlag, and finally quantPre against the
 * src / dst element types.
 *
 * \tparam T dst element type
 * \tparam U src element type
 * \tparam config compile-time Fixpipe configuration; only config.format is read here
 * \param [in] intriParams parameters to validate
 * \param [in] apiName API name forwarded to the checkers and printed in error messages
 */
template <typename T, typename U, const FixpipeConfig& config>
__aicore__ inline void CheckFixpipeParamsV220Common(const FixpipeParamsV220& intriParams, const __gm__ char* apiName)
{
// isChannelSplit is a runtime flag, so the first branch is a plain `if`; the remaining two branches are
// selected at compile time from config.format.
if (intriParams.isChannelSplit) {
// Channel split: nSize in [0, UINT12_MAX] and a multiple of 8.
// NOTE(review): `nSize >= 0` is always true if nSize is an unsigned field - confirm its type.
ASCENDC_DEBUG_ASSERT((intriParams.nSize >= 0 && intriParams.nSize <= UINT12_MAX &&
intriParams.nSize % 8 == 0), KERNEL_LOG_INTERNAL(KERNEL_ERROR, "Failed to check nSize value in %s, "
"when isChannelSplit is true, its valid range is 0 ~ 4095 and must be divisible by 8, current value "
"is %u.\n", apiName, intriParams.nSize));
// Channel split additionally requires float src and float dst ...
ASCENDC_DEBUG_ASSERT((IsSameType<PrimT<T>, float>::value && IsSameType<PrimT<U>, float>::value),
KERNEL_LOG_INTERNAL(KERNEL_ERROR, "Failed to check isChannelSplit value in %s, isChannelSplit can be "
"enabled only when src and dst are float.\n", apiName));
// ... and a format other than ROW_MAJOR (reported as NZ2ND in the message).
ASCENDC_DEBUG_ASSERT((config.format != CO2Layout::ROW_MAJOR), KERNEL_LOG_INTERNAL(KERNEL_ERROR,
"Failed to check isChannelSplit value in %s, isChannelSplit and NZ2ND cannot be enabled at the same "
"time.\n", apiName));
} else if constexpr (config.format == CO2Layout::ROW_MAJOR) {
// ROW_MAJOR without channel split: range check only, no divisibility requirement.
CheckValueRange<uint16_t>(intriParams.nSize, 0, UINT12_MAX, "nSize", apiName);
} else {
// Any other format without channel split: nSize in [0, UINT12_MAX] and a multiple of 16.
ASCENDC_DEBUG_ASSERT((intriParams.nSize >= 0 && intriParams.nSize <= UINT12_MAX &&
intriParams.nSize % 16 == 0), KERNEL_LOG_INTERNAL(KERNEL_ERROR, "Failed to check nSize value in %s, "
"when isChannelSplit is false and format is NZ, its valid range is 0 ~ 4095 and must be divisible by 16, "
"current value is %u.\n", apiName, intriParams.nSize));
}
ReportNopWarning<uint16_t>(intriParams.nSize, "intriParams.nSize", apiName);
// Upper bound of mSize is 8192 for ROW_MAJOR and UINT16_MAX otherwise.
constexpr uint16_t maxMSize = config.format == CO2Layout::ROW_MAJOR ? 8192 : UINT16_MAX;
CheckValueRange<uint16_t>(intriParams.mSize, 0, maxMSize, "mSize", apiName);
ReportNopWarning<uint16_t>(intriParams.mSize, "intriParams.mSize", apiName);
// dstStride must be non-zero.
ASCENDC_DEBUG_ASSERT((intriParams.dstStride != 0), KERNEL_LOG_INTERNAL(KERNEL_ERROR,
"Failed to check dstStride value in %s, its valid range is 1 ~ 4294967295, current value is %u.\n", apiName,
intriParams.dstStride));
ReportNopWarning<uint16_t>(intriParams.ndNum, "intriParams.ndNum", apiName);
// The ND strides are validated only when more than one ND matrix is processed.
if (intriParams.ndNum > 1) {
CheckValueRange<uint16_t>(intriParams.srcNdStride, 1, VALUE_512, "srcNdStride", apiName);
CheckValueRange<uint16_t>(intriParams.dstNdStride, 1, UINT16_MAX, "dstNdStride", apiName);
}
// unitFlag must be one of 0, 2 or 3.
ASCENDC_DEBUG_ASSERT((intriParams.unitFlag == 0 || intriParams.unitFlag == 2 || intriParams.unitFlag == 3),
KERNEL_LOG_INTERNAL(KERNEL_ERROR, "Failed to check unitFlag value in %s, supported values are 0, 2, and 3.\n",
apiName));
CheckFixpipeQuantPreCommon<T, U>(intriParams.quantPre, apiName);
}
/*!
 * \brief Validates the cbufWorkspace tensor of a Fixpipe call and the quantPre mode used with it.
 *
 * Checks, in order:
 *   1. the primitive element type of cbufWorkspace is uint64_t;
 *   2. cbufWorkspace is located in L1 (reported as "A1");
 *   3. quantPre is VDEQF16, VQF322B8_PRE or VREQ8.
 *
 * \tparam T element type of cbufWorkspace
 * \param [in] cbufWorkspace workspace tensor to validate
 * \param [in] intriParams Fixpipe parameters; only quantPre is read
 * \param [in] apiName API name forwarded to the checkers and printed in error messages
 */
template <typename T>
__aicore__ inline void CheckFixpipeWorkspace(const LocalTensor<T>& cbufWorkspace,
    const FixpipeParamsV220& intriParams, const __gm__ char* apiName)
{
    ASCENDC_DEBUG_ASSERT((SupportType<PrimT<T>, uint64_t>()), KERNEL_LOG_INTERNAL(KERNEL_ERROR,
        "Failed to check cbufWorkspace dtype in %s, supported dtype is uint64_t.\n", apiName));
    CheckTensorPhyPosition<Hardware::L1>(cbufWorkspace, "cbufWorkspace", "A1", apiName);
    // Delegate to the shared helper instead of repeating the same assertion and message here,
    // so the accepted mode list lives in a single place.
    CheckFixpipeQuantPreWithWorkspaceCommon(intriParams.quantPre, apiName);
}
/*!
 * \brief Validates a Fixpipe call whose destination is a LocalTensor and that has no workspace.
 *
 * Order of checks: FixpipeParamsV220 fields, src physical position (L0C), src alignment
 * (16 * sizeof(float)), dst alignment (ONE_BLK_SIZE), dst physical position (L1 or UB).
 *
 * \tparam T dst element type
 * \tparam U src element type
 * \tparam config compile-time Fixpipe configuration forwarded to the parameter check
 * \param [in] dst destination tensor
 * \param [in] src source tensor
 * \param [in] intriParams Fixpipe parameters
 * \param [in] apiName API name forwarded to every checker
 */
template <typename T, typename U, const FixpipeConfig& config>
__aicore__ inline void CheckFixpipeTensor(const LocalTensor<T>& dst, const LocalTensor<U>& src,
    const FixpipeParamsV220& intriParams, const __gm__ char* apiName)
{
    CheckFixpipeParamsV220Common<T, U, config>(intriParams, apiName);

    // Source side.
    CheckTensorPhyPosition<Hardware::L0C>(src, "src", "CO1", apiName);
    constexpr uint32_t srcAlignment = 16 * sizeof(float);
    CheckTensorAlignment(src, srcAlignment, "src", apiName);

    // Destination side.
    CheckTensorAlignment(dst, ONE_BLK_SIZE, "dst", apiName);
    // NOTE(review): both L1 and UB are accepted but the description string only mentions "A1" -
    // confirm whether the message should also name the UB position.
    CheckTensorPhyPosition<Hardware::L1, Hardware::UB>(dst, "dst", "A1", apiName);
}
/*!
 * \brief Validates a Fixpipe call whose destination is a LocalTensor and that uses a cbufWorkspace.
 *
 * Runs the workspace-less LocalTensor overload first, then CheckFixpipeWorkspace.
 *
 * \tparam T dst element type
 * \tparam U src element type
 * \tparam config compile-time Fixpipe configuration
 * \tparam S cbufWorkspace element type
 * \param [in] dst destination tensor
 * \param [in] src source tensor
 * \param [in] cbufWorkspace workspace tensor
 * \param [in] intriParams Fixpipe parameters
 * \param [in] apiName API name forwarded to every checker
 */
template <typename T, typename U, const FixpipeConfig& config, typename S>
__aicore__ inline void CheckFixpipeTensor(const LocalTensor<T>& dst, const LocalTensor<U>& src,
const LocalTensor<S>& cbufWorkspace, const FixpipeParamsV220& intriParams, const __gm__ char* apiName)
{
CheckFixpipeTensor<T, U, config>(dst, src, intriParams, apiName);
CheckFixpipeWorkspace(cbufWorkspace, intriParams, apiName);
}
/*!
 * \brief Validates a Fixpipe call whose destination is a GlobalTensor and that has no workspace.
 *
 * Only the FixpipeParamsV220 fields and the src physical position (L0C) are checked;
 * the destination tensor itself is not inspected.
 *
 * \tparam T dst element type
 * \tparam U src element type
 * \tparam config compile-time Fixpipe configuration forwarded to the parameter check
 * \param [in] dst destination tensor (unused)
 * \param [in] src source tensor
 * \param [in] intriParams Fixpipe parameters
 * \param [in] apiName API name forwarded to every checker
 */
template <typename T, typename U, const FixpipeConfig& config>
__aicore__ inline void CheckFixpipeTensor(const GlobalTensor<T>& dst, const LocalTensor<U>& src,
const FixpipeParamsV220& intriParams, const __gm__ char* apiName)
{
// dst is deliberately ignored; the cast silences the unused-parameter diagnostic.
(void)dst;
CheckFixpipeParamsV220Common<T, U, config>(intriParams, apiName);
CheckTensorPhyPosition<Hardware::L0C>(src, "src", "CO1", apiName);
}
/*!
 * \brief Validates a Fixpipe call whose destination is a GlobalTensor and that uses a cbufWorkspace.
 *
 * Runs the workspace-less GlobalTensor overload first, then CheckFixpipeWorkspace.
 *
 * \tparam T dst element type
 * \tparam U src element type
 * \tparam config compile-time Fixpipe configuration
 * \tparam S cbufWorkspace element type
 * \param [in] dst destination tensor
 * \param [in] src source tensor
 * \param [in] cbufWorkspace workspace tensor
 * \param [in] intriParams Fixpipe parameters
 * \param [in] apiName API name forwarded to every checker
 */
template <typename T, typename U, const FixpipeConfig& config, typename S>
__aicore__ inline void CheckFixpipeTensor(const GlobalTensor<T>& dst, const LocalTensor<U>& src,
const LocalTensor<S>& cbufWorkspace, const FixpipeParamsV220& intriParams, const __gm__ char* apiName)
{
CheckFixpipeTensor<T, U, config>(dst, src, intriParams, apiName);
CheckFixpipeWorkspace(cbufWorkspace, intriParams, apiName);
}
/*!
 * \brief Asserts that the primitive type of T is in the architecture-specific list of element types
 *        accepted for the 2D LoadData check.
 *
 * The accepted list depends on __NPU_ARCH__ (2002, 2201 or 3102); on any other architecture the
 * function body is empty.
 *
 * \tparam T element type to validate (its PrimT is inspected)
 */
template <typename T>
__aicore__ static inline void CheckLoadData2dDatatype()
{
// NOTE(review): the 2002 and 3102 branches use ASCENDC_ASSERT + KERNEL_LOG while the 2201 branch uses
// ASCENDC_DEBUG_ASSERT + KERNEL_LOG_INTERNAL - confirm the difference is intentional.
#if __NPU_ARCH__ == 2002
ASCENDC_ASSERT((SupportType<PrimT<T>, uint8_t, int8_t, uint16_t, int16_t, half, int4b_t>()),
{KERNEL_LOG(KERNEL_ERROR, "Failed to check dtype in LoadData with LoadData2DParams, current api support dtype "
"combination is src and dst both: uint8_t / int8_t / uint16_t / int16_t / half / int4b_t.");});
#elif __NPU_ARCH__ == 2201
// 2201 additionally accepts bfloat16_t, uint32_t, int32_t and float.
ASCENDC_DEBUG_ASSERT((SupportType<PrimT<T>, uint8_t, int8_t, uint16_t, int16_t, half, bfloat16_t, uint32_t, int32_t,
float, int4b_t>()), KERNEL_LOG_INTERNAL(KERNEL_ERROR, "Failed to check dtype in LoadData with LoadData2DParams,"
" current api support dtype combination is src and dst both: uint8_t / int8_t / uint16_t / int16_t / half / "
"bfloat16_t / uint32_t / int32_t / float / int4b_t.\n"));
#elif __NPU_ARCH__ == 3102
// Same type set as 2002; the message refers to LoadData2DParamsV2.
ASCENDC_ASSERT((SupportType<PrimT<T>, uint8_t, int8_t, half, uint16_t, int16_t, int4b_t>()),
{KERNEL_LOG(KERNEL_ERROR, "Failed to check dtype in LoadData with LoadData2DParamsV2, current api support "
"dtype combination is src and dst both: uint8_t / int8_t / half / uint16_t / int16_t / int4b_t.");});
#endif
}
/*!
 * \brief Validates the tensors of a local-to-local 2D LoadData call.
 *
 * src must be in L1 and is alignment-checked against ONE_BLK_SIZE;
 * dst must be in L0A or L0B and is alignment-checked against VALUE_512.
 *
 * \param [in] dst destination tensor
 * \param [in] src source tensor
 * \param [in] apiName API name forwarded to every checker
 */
template <typename T>
__aicore__ static inline void CheckLoadData2dLocal2Local(const LocalTensor<T>& dst, const LocalTensor<T>& src,
const __gm__ char* apiName)
{
// Both position checks run before either alignment check.
CheckTensorPhyPosition<Hardware::L1>(src, "src", "A1 / B1", apiName);
CheckTensorPhyPosition<Hardware::L0A, Hardware::L0B>(dst, "dst", "A2 / B2", apiName);
CheckTensorAlignment(src, ONE_BLK_SIZE, "src", apiName);
CheckTensorAlignment(dst, VALUE_512, "dst", apiName);
}
/*!
 * \brief Validates the destination tensor of a global-to-local 2D LoadData call.
 *
 * On __NPU_ARCH__ 3510 dst must be in L1 and is alignment-checked against ONE_BLK_SIZE.
 * On every other architecture dst may be in L1, L0A or L0B; the alignment is VALUE_512 for
 * L0A / L0B and ONE_BLK_SIZE otherwise.
 *
 * \param [in] dst destination tensor
 * \param [in] apiName API name forwarded to every checker
 */
template <typename T>
__aicore__ static inline void CheckLoadData2dGlobal2Local(const LocalTensor<T>& dst, const __gm__ char* apiName)
{
#if __NPU_ARCH__ == 3510
    CheckTensorPhyPosition<Hardware::L1>(dst, "dst", "A1 / B1", apiName);
    CheckTensorAlignment(dst, ONE_BLK_SIZE, "dst", apiName);
#else
    CheckTensorPhyPosition<Hardware::L1, Hardware::L0A, Hardware::L0B>(dst, "dst", "A1 / B1 / A2 / B2", apiName);
    const Hardware dstHardware = GetPhyType(static_cast<TPosition>(dst.GetPosition()));
    const bool dstIsL0 = (dstHardware == Hardware::L0A) || (dstHardware == Hardware::L0B);
    // The required alignment follows the buffer the data lands in.
    if (!dstIsL0) {
        CheckTensorAlignment(dst, ONE_BLK_SIZE, "dst", apiName);
    } else {
        CheckTensorAlignment(dst, VALUE_512, "dst", apiName);
    }
#endif
}
/*!
 * \brief Validates LoadData2DParams: sid (debug builds only), repeatTimes and the use of ifTranspose.
 *
 * When ifTranspose is set:
 *   - with checkTranspose == true a warning is raised unless sizeof(T) == 2;
 *   - with checkTranspose == false a warning is always raised.
 *
 * \tparam T element type of the transfer; only sizeof(T) is used
 * \param [in] loadDataParams parameters to validate
 * \param [in] checkTranspose selects which of the two ifTranspose warnings applies
 */
template <typename T>
__aicore__ static inline void CheckLoadData2dParams(const LoadData2DParams& loadDataParams, bool checkTranspose)
{
#if defined(ASCENDC_DEBUG) || defined(ASCENDC_CPU_DEBUG)
    CheckValueRange<uint8_t>(loadDataParams.sid, 0, MAX_LOAD2D_SID, "loadDataParams.sid",
        "LoadData with LoadData2DParams");
#endif
    ReportNopWarning<uint8_t>(loadDataParams.repeatTimes, "loadDataParams.repeatTimes",
        "LoadData with LoadData2DParams");

    // Nothing more to validate unless transposition was requested.
    if (!loadDataParams.ifTranspose) {
        return;
    }
    if (!checkTranspose) {
        // Unconditional warning on this path.
        ASCENDC_DEBUG_WARNING((false), KERNEL_LOG_INTERNAL(KERNEL_WARN, "ifTranspose is effective in LoadData with "
            "LoadData2DParams when src->dst is A1->A2 / B1->B2.\n"));
        return;
    }
    ASCENDC_DEBUG_WARNING((sizeof(T) == 2), KERNEL_LOG_INTERNAL(KERNEL_WARN, "ifTranspose in LoadData2DParams "
        "should be enabled only when dtype is uint16_t / int16_t / half / bfloat16_t.\n"));
}
/*!
 * \brief Asserts that the primitive type of T is accepted by the transposing LoadData check.
 *
 * Only active on __NPU_ARCH__ 2201; on every other architecture the body is empty.
 * The int4b_t type is accepted only on the tPosIsA1 == false path.
 *
 * \tparam T element type to validate (its PrimT is inspected)
 * \param [in] apiName API name printed in the error message
 * \param [in] tPosIsA1 selects the type list. NOTE(review): despite its name, the error message of the
 *             `true` path says "dst TPosition is A2" and the `false` path says "B2"; the caller in
 *             this file passes true when dst is at TPosition::A2.
 */
template <typename T>
__aicore__ static inline void CheckLoadDataWithTransposeDtype(const __gm__ char* apiName, bool tPosIsA1)
{
#if __NPU_ARCH__ == 2201
if (tPosIsA1) {
ASCENDC_DEBUG_ASSERT((SupportType<PrimT<T>, uint8_t, int8_t, half, bfloat16_t, uint32_t, int32_t, float>()),
KERNEL_LOG_INTERNAL(KERNEL_ERROR, "Failed to check dtype in %s, current api support dtype combination is "
"src and dst both: uint8_t / int8_t / half / bfloat16_t / uint32_t / int32_t / float when dst TPosition "
"is A2.\n", apiName));
} else {
// Same list as above plus int4b_t.
ASCENDC_DEBUG_ASSERT((SupportType<PrimT<T>, uint8_t, int8_t, half, bfloat16_t, uint32_t, int32_t, float,
int4b_t>()), KERNEL_LOG_INTERNAL(KERNEL_ERROR, "Failed to check dtype in %s, current api support dtype "
"combination is src and dst both: uint8_t / int8_t / half / bfloat16_t / uint32_t / int32_t / float / "
"int4b_t when dst TPosition is B2.\n", apiName));
}
#endif
}
/*!
 * \brief Validates the tensors of a transposing LoadData call.
 *
 * Except on __NPU_ARCH__ 3510 / 5102 / 3102, src must be in L1, dst must be in L0A or L0B, and the
 * element type is validated by CheckLoadDataWithTransposeDtype according to whether dst sits at
 * TPosition::A2. On every architecture src is then alignment-checked against ONE_BLK_SIZE and dst
 * against VALUE_512.
 *
 * \param [in] dst destination tensor
 * \param [in] src source tensor
 * \param [in] apiName API name forwarded to every checker
 */
template <typename T>
__aicore__ static inline void CheckLoadDataWithTranspose(const LocalTensor<T>& dst, const LocalTensor<T>& src,
    const __gm__ char* apiName)
{
#if __NPU_ARCH__ != 3510 && __NPU_ARCH__ != 5102 && __NPU_ARCH__ != 3102
    CheckTensorPhyPosition<Hardware::L1>(src, "src", "A1 / B1", apiName);
    CheckTensorPhyPosition<Hardware::L0A, Hardware::L0B>(dst, "dst", "A2 / B2", apiName);
    // The dtype helper takes a flag that is true exactly when dst is located at A2.
    const bool dstIsA2 = (static_cast<TPosition>(dst.GetPosition()) == TPosition::A2);
    CheckLoadDataWithTransposeDtype<T>(apiName, dstIsA2);
#endif
    CheckTensorAlignment(src, ONE_BLK_SIZE, "src", apiName);
    CheckTensorAlignment(dst, VALUE_512, "dst", apiName);
}
/*!
 * \brief Range-checks the source height / width and the W / H strides of a 3D LoadData call.
 *
 * The checks are compiled only when ASCENDC_DEBUG or ASCENDC_CPU_DEBUG is defined; otherwise the
 * function body is empty.
 *
 * \param [in] srcHeight checked against [0, MAX_LOAD3D_L1], reported as "l1H"
 * \param [in] srcWidth checked against [0, MAX_LOAD3D_L1], reported as "l1W"
 * \param [in] srcWStride checked against [MIN_LOAD3D_STRIDE, MAX_LOAD3D_STRIDE], reported as "strideW"
 * \param [in] srcHStride checked against [MIN_LOAD3D_STRIDE, MAX_LOAD3D_STRIDE], reported as "strideH"
 */
__aicore__ static inline void CheckLoadData3dParams(const uint16_t srcHeight, const uint16_t srcWidth,
const uint8_t srcWStride, const uint8_t srcHStride)
{
#if defined(ASCENDC_DEBUG) || defined(ASCENDC_CPU_DEBUG)
CheckValueRange<uint16_t>(srcHeight, 0, MAX_LOAD3D_L1, "l1H", "LoadData with LoadData3DParams");
CheckValueRange<uint16_t>(srcWidth, 0, MAX_LOAD3D_L1, "l1W", "LoadData with LoadData3DParams");
CheckValueRange<uint8_t>(srcWStride, static_cast<uint8_t>(MIN_LOAD3D_STRIDE),
static_cast<uint8_t>(MAX_LOAD3D_STRIDE), "strideW", "LoadData with LoadData3DParams");
CheckValueRange<uint8_t>(srcHStride, static_cast<uint8_t>(MIN_LOAD3D_STRIDE),
static_cast<uint8_t>(MAX_LOAD3D_STRIDE), "strideH", "LoadData with LoadData3DParams");
#endif
}
#if defined(ASCENDC_DEBUG) || defined(ASCENDC_CPU_DEBUG)
/*!
 * \brief Range-checks the fields of LoadData3DParamsV1 through CheckValueRange.
 *
 * Fields covered: c1Index, fetchFilterW/H, leftTopW/H, filterW/H, dilationFilterW/H, jumpStride,
 * repeatMode, cSize and repeatTime.
 *
 * \tparam U template argument of LoadData3DParamsV1
 * \param [in] loadDataParams parameters to validate
 */
template <typename U>
__aicore__ inline void CheckLoadData3dv1Params(const LoadData3DParamsV1<U>& loadDataParams)
{
CheckValueRange<uint16_t>(loadDataParams.c1Index, static_cast<uint16_t>(MIN_LOAD3D_C1_IDX),
static_cast<uint16_t>(MAX_LOAD3D_C1_IDX), "c1Index", "LoadData with LoadData3DParamsV1");
CheckValueRange<uint8_t>(loadDataParams.fetchFilterW, static_cast<uint8_t>(MIN_LOAD3D_FETCH_FILTER),
static_cast<uint8_t>(MAX_LOAD3D_FETCH_FILTER), "fetchFilterW", "LoadData with LoadData3DParamsV1");
CheckValueRange<uint8_t>(loadDataParams.fetchFilterH, static_cast<uint8_t>(MIN_LOAD3D_FETCH_FILTER),
static_cast<uint8_t>(MAX_LOAD3D_FETCH_FILTER), "fetchFilterH", "LoadData with LoadData3DParamsV1");
// leftTopW / leftTopH are the only fields checked as signed 16-bit values.
CheckValueRange<int16_t>(loadDataParams.leftTopW, static_cast<int16_t>(MIN_LOAD3D_LEFT_TOP),
static_cast<int16_t>(MAX_LOAD3D_LEFT_TOP), "leftTopW", "LoadData with LoadData3DParamsV1");
CheckValueRange<int16_t>(loadDataParams.leftTopH, static_cast<int16_t>(MIN_LOAD3D_LEFT_TOP),
static_cast<int16_t>(MAX_LOAD3D_LEFT_TOP), "leftTopH", "LoadData with LoadData3DParamsV1");
CheckValueRange<uint8_t>(loadDataParams.filterW, static_cast<uint8_t>(MIN_LOAD3D_FILTER),
static_cast<uint8_t>(MAX_LOAD3D_FILTER), "filterW", "LoadData with LoadData3DParamsV1");
CheckValueRange<uint8_t>(loadDataParams.filterH, static_cast<uint8_t>(MIN_LOAD3D_FILTER),
static_cast<uint8_t>(MAX_LOAD3D_FILTER), "filterH", "LoadData with LoadData3DParamsV1");
// NOTE(review): the dilation checks use MAX_LOAD3D_FILTER as their upper bound (no dedicated
// MAX_LOAD3D_DILATION_FILTER constant is referenced) - confirm this is intended.
CheckValueRange<uint8_t>(loadDataParams.dilationFilterW, static_cast<uint8_t>(MIN_LOAD3D_DILATION_FILTER),
static_cast<uint8_t>(MAX_LOAD3D_FILTER), "dilationFilterW", "LoadData with LoadData3DParamsV1");
CheckValueRange<uint8_t>(loadDataParams.dilationFilterH, static_cast<uint8_t>(MIN_LOAD3D_DILATION_FILTER),
static_cast<uint8_t>(MAX_LOAD3D_FILTER), "dilationFilterH", "LoadData with LoadData3DParamsV1");
CheckValueRange<uint8_t>(loadDataParams.jumpStride, static_cast<uint8_t>(MIN_LOAD3D_JUMP_STRIDE),
static_cast<uint8_t>(MAX_LOAD3D_JUMP_STRIDE), "jumpStride", "LoadData with LoadData3DParamsV1");
// repeatMode and cSize are restricted to 0 or 1.
CheckValueRange<uint8_t>(loadDataParams.repeatMode, 0, 1, "repeatMode", "LoadData with LoadData3DParamsV1");
CheckValueRange<uint8_t>(loadDataParams.cSize, 0, 1, "cSize", "LoadData with LoadData3DParamsV1");
// NOTE(review): repeatTime is bounded above by MAX_LOAD3D_FILTER rather than a repeat-specific
// constant - confirm this is intended.
CheckValueRange<uint8_t>(loadDataParams.repeatTime, static_cast<uint8_t>(MIN_LOAD3D_REPEAT_TIMES),
static_cast<uint8_t>(MAX_LOAD3D_FILTER), "repeatTime", "LoadData with LoadData3DParamsV1");
}
/*!
 * \brief Validates the fields of LoadData3DParamsV2.
 *
 * Range-checks dilationFilterW/H, passes kExtension, mExtension, channelSize, l1H and l1W to
 * ReportNopWarning, and warns when filterW (resp. filterH) is 0 while filterSizeW (resp. filterSizeH)
 * is false, a combination the message describes as equivalent to a NOP.
 *
 * \tparam U template argument of LoadData3DParamsV2
 * \param [in] loadDataParams parameters to validate
 */
template <typename U>
__aicore__ inline void CheckLoadData3dv2Params(const LoadData3DParamsV2<U>& loadDataParams)
{
// NOTE(review): as in the V1 check, the dilation upper bound is MAX_LOAD3D_FILTER - confirm.
CheckValueRange<uint8_t>(loadDataParams.dilationFilterW, static_cast<uint8_t>(MIN_LOAD3D_DILATION_FILTER),
static_cast<uint8_t>(MAX_LOAD3D_FILTER), "dilationFilterW", "LoadData with LoadData3DParamsV2");
CheckValueRange<uint8_t>(loadDataParams.dilationFilterH, static_cast<uint8_t>(MIN_LOAD3D_DILATION_FILTER),
static_cast<uint8_t>(MAX_LOAD3D_FILTER), "dilationFilterH", "LoadData with LoadData3DParamsV2");
ReportNopWarning<uint16_t>(loadDataParams.kExtension, "loadDataParams.kExtension",
"LoadData with LoadData3DParamsV2");
ReportNopWarning<uint16_t>(loadDataParams.mExtension, "loadDataParams.mExtension",
"LoadData with LoadData3DParamsV2");
ReportNopWarning<uint16_t>(loadDataParams.channelSize, "loadDataParams.channelSize",
"LoadData with LoadData3DParamsV2");
ReportNopWarning<uint16_t>(loadDataParams.l1H, "loadDataParams.l1H", "LoadData with LoadData3DParamsV2");
ReportNopWarning<uint16_t>(loadDataParams.l1W, "loadDataParams.l1W", "LoadData with LoadData3DParamsV2");
// A zero filter extent is only flagged when the matching filterSize flag is false.
ASCENDC_DEBUG_WARNING((!(loadDataParams.filterW == 0 && loadDataParams.filterSizeW == false)),
KERNEL_LOG_INTERNAL(KERNEL_WARN, "In LoadData with LoadData3DParamsV2, loadDataParams.filterW = 0 and "
"loadDataParams.filterSizeW == false, which makes LoadData with LoadData3DParamsV2 equivalent to a NOP.\n"));
ASCENDC_DEBUG_WARNING((!(loadDataParams.filterH == 0 && loadDataParams.filterSizeH == false)),
KERNEL_LOG_INTERNAL(KERNEL_WARN, "In LoadData with LoadData3DParamsV2, loadDataParams.filterH = 0 and "
"loadDataParams.filterSizeH == false, which makes LoadData with LoadData3DParamsV2 equivalent to a NOP.\n"));
}
#endif
/*!
 * \brief Asserts that channelSize is acceptable for LoadData with LoadData3DParamsV2, based on the
 *        primitive type of T and on __NPU_ARCH__.
 *
 * The rule is expressed as a set of accepted remainders of channelSize modulo the per-block element
 * count (16 for 2-byte types, 8 for 4-byte types, 32 for 1-byte types, 64 for int4b_t, as stated in
 * the error messages). On architectures not listed in the preprocessor conditions the body is empty.
 *
 * \tparam T element type (its PrimT is inspected)
 * \param [in] channelSize value to validate
 */
template <typename T>
__aicore__ static inline void CheckLoadData3dv2ChannelSize(const uint16_t channelSize)
{
// Arch 2002: a remainder of 0 is not in the lists; instead exactly one full block (16 / 32 / 64) is accepted.
#if __NPU_ARCH__ == 2002
if constexpr (IsSameType<PrimT<T>, half>::value) {
uint16_t remainderList[] = {4, 8};
ASCENDC_DEBUG_ASSERT((ChannelSizeRemainder<PrimT<T>>(channelSize, remainderList, 2) || channelSize == 16),
KERNEL_LOG_INTERNAL(KERNEL_ERROR, "Failed to check param channelSize value in LoadData with "
"LoadData3DParamsV2 with dtype half, allowed value is 16, or allowed remainders when divided by 16 "
"are 4 and 8, current value is %u.\n", channelSize));
} else if constexpr(SupportType<PrimT<T>, int8_t, uint8_t>()) {
uint16_t remainderList[] = {4, 8, 16};
ASCENDC_DEBUG_ASSERT((ChannelSizeRemainder<PrimT<T>>(channelSize, remainderList, 3) || channelSize == 32),
KERNEL_LOG_INTERNAL(KERNEL_ERROR, "Failed to check param channelSize value in LoadData with "
"LoadData3DParamsV2 with dtype int8_t / uint8_t, allowed value is 32, or allowed remainders when "
"divided by 32 are 4, 8, and 16, current value is %u.\n", channelSize));
} else if constexpr (IsSameType<PrimT<T>, int4b_t>::value) {
uint16_t remainderList[] = {8, 16, 32};
ASCENDC_DEBUG_ASSERT((ChannelSizeRemainder<PrimT<T>>(channelSize, remainderList, 3) || channelSize == 64),
KERNEL_LOG_INTERNAL(KERNEL_ERROR, "Failed to check param channelSize value in LoadData with "
"LoadData3DParamsV2 with dtype int4b_t, allowed value is 64, or allowed remainders when divided by "
"64 are 8, 16, and 32, current value is %u.\n", channelSize));
}
#elif defined(__NPU_ARCH__) && ((__NPU_ARCH__ == 2201) || (__NPU_ARCH__ == 3002) || \
(__NPU_ARCH__ == 3102) || (__NPU_ARCH__ == 5102) || (__NPU_ARCH__ == 3003) || (__NPU_ARCH__ == 3113) || \
(__NPU_ARCH__ == 3510))
// Other listed archs: remainder 0 is part of every list. The 2-byte check is a standalone `if constexpr`;
// on 3102 / 3003 / 3113 it covers half only, on the remaining archs it also covers bfloat16_t.
#if defined(__NPU_ARCH__) && (__NPU_ARCH__ == 3102 || (__NPU_ARCH__ == 3003) || \
(__NPU_ARCH__ == 3113))
if constexpr (IsSameType<PrimT<T>, half>::value) {
uint16_t remainderList[] = {0, 4, 8};
ASCENDC_DEBUG_ASSERT((ChannelSizeRemainder<PrimT<T>>(channelSize, remainderList, 3)),
KERNEL_LOG_INTERNAL(KERNEL_ERROR, "Failed to check param channelSize value in LoadData with "
"LoadData3DParamsV2 with dtype half, allowed remainders when divided by 16 are 0, 4, and 8, current "
"value is %u.\n", channelSize));
}
#else
if constexpr (SupportType<PrimT<T>, half, bfloat16_t>()) {
uint16_t remainderList[] = {0, 4, 8};
ASCENDC_DEBUG_ASSERT((ChannelSizeRemainder<PrimT<T>>(channelSize, remainderList, 3)),
KERNEL_LOG_INTERNAL(KERNEL_ERROR, "Failed to check param channelSize value in LoadData with "
"LoadData3DParamsV2 with dtype half / bfloat16_t, allowed remainders when divided by 16 are 0, 4, "
"and 8, current value is %u.\n", channelSize));
}
#endif
// The 4-byte, 1-byte and int4b_t cases form a separate if / else-if chain shared by all these archs.
if constexpr (SupportType<PrimT<T>, float, int32_t, uint32_t>()) {
uint16_t remainderList[] = {0, 4};
ASCENDC_DEBUG_ASSERT((ChannelSizeRemainder<PrimT<T>>(channelSize, remainderList, 2)),
KERNEL_LOG_INTERNAL(KERNEL_ERROR, "Failed to check param channelSize value in LoadData with "
"LoadData3DParamsV2 with dtype float / int32_t / uint32_t, allowed remainders when divided by 8 are "
"0 and 4, current value is %u.\n", channelSize));
} else if constexpr (SupportType<PrimT<T>, int8_t, uint8_t>()) {
uint16_t remainderList[] = {0, 4, 8, 16};
ASCENDC_DEBUG_ASSERT((ChannelSizeRemainder<PrimT<T>>(channelSize, remainderList, 4)),
KERNEL_LOG_INTERNAL(KERNEL_ERROR, "Failed to check param channelSize value in LoadData with "
"LoadData3DParamsV2 with dtype int8_t / uint8_t, allowed remainders when divided by 32 are 0, 4, 8, "
"and 16, current value is %u.\n", channelSize));
} else if constexpr (IsSameType<PrimT<T>, int4b_t>::value) {
uint16_t remainderList[] = {0, 8, 16, 32};
ASCENDC_DEBUG_ASSERT((ChannelSizeRemainder<PrimT<T>>(channelSize, remainderList, 4)),
KERNEL_LOG_INTERNAL(KERNEL_ERROR, "Failed to check param channelSize value in LoadData with "
"LoadData3DParamsV2 with dtype int4b_t, allowed remainders when divided by 64 are 0, 8, 16, and 32, "
"current value is %u.\n", channelSize));
}
#endif
}
/*!
 * \brief Validates the matrix extension / start-point parameters of LoadData with LoadData3DParamsV2.
 *
 * - mExtension must be a multiple of 16 for half / int8_t / uint8_t / int4b_t.
 * - kExtension and kStartPt must be multiples of kExtBase for half / int8_t / uint8_t / int4b_t /
 *   int32_t / uint32_t / float, where kExtBase is 64 for int4b_t and
 *   ONE_BLK_SIZE / sizeof(PrimT<T>) otherwise.
 * - mStartPt is checked only on __NPU_ARCH__ 2002 (multiple of 16 for half / int8_t / uint8_t /
 *   int4b_t) and 2201 (range [0, UINT15_MAX], plus multiple of 16 for half / int8_t / uint8_t).
 *
 * \tparam T element type (its PrimT is inspected)
 * \param [in] kExtension extension along k
 * \param [in] mExtension extension along m
 * \param [in] kStartPt start point along k
 * \param [in] mStartPt start point along m
 */
template <typename T>
__aicore__ static inline void CheckLoadData3dv2MatrixParams(const uint16_t kExtension, const uint16_t mExtension,
    const uint16_t kStartPt, const uint16_t mStartPt)
{
    constexpr uint16_t base16 = 16;

    // --- m extension ---
    if constexpr (SupportType<PrimT<T>, half, int8_t, uint8_t, int4b_t>()) {
        ASCENDC_DEBUG_ASSERT((mExtension % base16 == 0), KERNEL_LOG_INTERNAL(KERNEL_ERROR, "Failed to check "
            "mExtension value in LoadData with LoadData3DParamsV2 when dtype is half / int8_t / uint8_t / int4b_t, "
            "it should be divisible by 16, current value is %u.\n", mExtension));
    }

    // --- k extension and k start point ---
    // Granularity along k: one block worth of elements, with int4b_t hard-coded to 64.
    uint16_t kExtBase = ONE_BLK_SIZE / sizeof(PrimT<T>);
    if constexpr (SupportType<PrimT<T>, int4b_t>()) {
        kExtBase = 64;
    }
    if constexpr (SupportType<PrimT<T>, half, int8_t, uint8_t, int4b_t, int32_t, uint32_t, float>()) {
        ASCENDC_DEBUG_ASSERT((kExtension % kExtBase == 0), KERNEL_LOG_INTERNAL(KERNEL_ERROR, "Failed to check "
            "kExtension value in LoadData with LoadData3DParamsV2 when dtype is half / int8_t / uint8_t / int4b_t / "
            "int32_t / uint32_t / float, it should be divisible by %u, current value is %u.\n", kExtBase,
            kExtension));
        ASCENDC_DEBUG_ASSERT((kStartPt % kExtBase == 0), KERNEL_LOG_INTERNAL(KERNEL_ERROR, "Failed to check "
            "kStartPt value in LoadData with LoadData3DParamsV2 when dtype is half / int8_t / uint8_t / int4b_t / "
            "int32_t / uint32_t / float, it should be divisible by %u, current value is %u.\n", kExtBase,
            kStartPt));
    }

    // --- m start point (architecture specific) ---
#if __NPU_ARCH__ == 2002
    if constexpr (SupportType<PrimT<T>, half, int8_t, uint8_t, int4b_t>()) {
        ASCENDC_DEBUG_ASSERT((mStartPt % base16 == 0), KERNEL_LOG_INTERNAL(KERNEL_ERROR, "Failed to check "
            "mStartPt value in LoadData with LoadData3DParamsV2 when dtype is half / int8_t / uint8_t / int4b_t, it "
            "should be divisible by 16, current value is %u.\n", mStartPt));
    }
#elif __NPU_ARCH__ == 2201
    CheckValueRange<uint16_t>(mStartPt, static_cast<uint16_t>(0), static_cast<uint16_t>(UINT15_MAX), "mStartPt",
        "LoadData with LoadData3DParamsV2");
    if constexpr (SupportType<PrimT<T>, half, int8_t, uint8_t>()) {
        ASCENDC_DEBUG_ASSERT((mStartPt % base16 == 0), KERNEL_LOG_INTERNAL(KERNEL_ERROR, "Failed to check "
            "mStartPt value in LoadData with LoadData3DParamsV2 when dtype is half / int8_t / uint8_t, it should be "
            "divisible by 16, current value is %u.\n", mStartPt));
    }
#endif
}
}
#endif
#if defined(__UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_KERNEL_OPERATOR_MM_CHECK_H__)
#undef __ASCENDC_INCLUDE_INTERNAL_HEADERS__
#undef __UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_KERNEL_OPERATOR_MM_CHECK_H__
#endif