* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
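/*!
* \file kernel_operator_mm_check.h
* \brief Parameter, alignment and data-type checks for the Mmad and LoadData (2D / 3D / 3D v2) kernel APIs.
*/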
#ifndef ASCENDC_MODULE_OPERATOR_MM_CHECK_H
#define ASCENDC_MODULE_OPERATOR_MM_CHECK_H
#include "kernel_check.h"
namespace AscendC {
/*
 * Returns true when (channelSize % oneBlkNum) equals any of the first `size` entries of `remainder`,
 * false otherwise (always false when size == 0).
 *
 * oneBlkNum is the number of T elements per ONE_BLK_SIZE-byte block, i.e. ONE_BLK_SIZE / sizeof(T),
 * except for int4b_t, where it is hard-coded to 64 (presumably because sizeof(int4b_t) does not
 * reflect its sub-byte element width -- confirm against the int4b_t definition).
 *
 * `remainder` is only read, so it is taken as pointer-to-const; callers may pass const/constexpr
 * tables as well as plain mutable arrays.
 */
template <typename T>
__aicore__ static inline bool ChannelSizeRemainder(const uint16_t channelSize, const uint16_t remainder[],
    uint16_t size)
{
    const uint16_t oneBlkNum = IsSameType<T, int4b_t>::value ? 64 : ONE_BLK_SIZE / sizeof(T);
    const uint16_t channelRemainder = channelSize % oneBlkNum;
    for (uint16_t i = 0; i < size; i++) {
        if (channelRemainder == remainder[i]) {
            return true;
        }
    }
    return false;
}
/*
 * Runs CheckTensorAlign on the three Mmad operands, reporting under the API name "Mmad".
 * dst is checked against VALUE_512 when dst, fm and filter all have element type half, and against
 * 1024 (align1024B) for every other dtype combination. fm and filter are always checked against
 * VALUE_512. The checks run in the order dst, fm, filter.
 */
template <typename T, typename U, typename S>
__aicore__ static inline void CheckMmadAlign(const LocalTensor<T>& dst, const LocalTensor<U>& fm,
    const LocalTensor<S>& filter)
{
    constexpr bool isAllHalf = IsSameType<PrimT<U>, half>::value && IsSameType<PrimT<S>, half>::value &&
        IsSameType<PrimT<T>, half>::value;
    if constexpr (!isAllHalf) {
        constexpr uint64_t align1024B = 1024;
        CheckTensorAlign<T>(dst, align1024B, "dst", "Mmad");
    } else {
        CheckTensorAlign<T>(dst, VALUE_512, "dst", "Mmad");
    }
    CheckTensorAlign<U>(fm, VALUE_512, "fm", "Mmad");
    CheckTensorAlign<S>(filter, VALUE_512, "filter", "Mmad");
}
/*
 * Asserts that PrimT<T> is a dtype accepted by LoadData with 2D params; src and dst must both use it.
 * Accepted dtypes depend on __NPU_ARCH__:
 *   2002: uint8_t, int8_t, uint16_t, int16_t, half, int4b_t
 *   2201: the 2002 set plus bfloat16_t, uint32_t, int32_t, float
 *   3102: uint8_t, int8_t, half, uint16_t, int16_t, int4b_t (message refers to LoadData2DParamsV2)
 * Any other (or undefined) __NPU_ARCH__ performs no check.
 */
template <typename T>
__aicore__ static inline void CheckLoadData2dDatatype()
{
#if defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2002)
    [[maybe_unused]] constexpr bool isDtypeSupported =
        SupportType<PrimT<T>, uint8_t, int8_t, uint16_t, int16_t, half, int4b_t>();
    ASCENDC_ASSERT((isDtypeSupported), {KERNEL_LOG(KERNEL_ERROR, "Failed to "
        "check dtype in LoadData with LoadData2DParams, current api support dtype combination is src and dst both: "
        "uint8_t / int8_t / uint16_t / int16_t / half / int4b_t.");});
#elif defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201)
    [[maybe_unused]] constexpr bool isDtypeSupported = SupportType<PrimT<T>, uint8_t, int8_t, uint16_t, int16_t,
        half, bfloat16_t, uint32_t, int32_t, float, int4b_t>();
    ASCENDC_ASSERT((isDtypeSupported), {KERNEL_LOG(KERNEL_ERROR,
        "Failed to check dtype in LoadData with LoadData2DParams, current api "
        "support dtype combination is src and dst both uint8_t / int8_t / uint16_t / int16_t / half / bfloat16_t / "
        "uint32_t / int32_t / float / int4b_t.");});
#elif defined(__NPU_ARCH__) && (__NPU_ARCH__ == 3102)
    [[maybe_unused]] constexpr bool isDtypeSupported =
        SupportType<PrimT<T>, uint8_t, int8_t, half, uint16_t, int16_t, int4b_t>();
    ASCENDC_ASSERT((isDtypeSupported), {KERNEL_LOG(KERNEL_ERROR,
        "Failed to check dtype in LoadData with LoadData2DParamsV2, current api support dtype combination is src and "
        "dst both: uint8_t / int8_t / half / uint16_t / int16_t / int4b_t.");});
#endif
}
/*
 * Range-checks the size and stride arguments of LoadData with LoadData3DParams via
 * ASCENDC_CHECK_VALUE_RANGE (bound inclusivity and failure behavior are defined by that macro).
 *   srcHeight  - reported as "l1H", checked against MIN_LOAD3D_L1 .. MAX_LOAD3D_L1
 *   srcWeight  - despite its name, this is the width: reported as "l1W",
 *                checked against MIN_LOAD3D_L1 .. MAX_LOAD3D_L1
 *   srcWStride - reported as "strideW", checked against MIN_LOAD3D_STRIDE .. MAX_LOAD3D_STRIDE
 *   srcHStride - reported as "strideH", checked against MIN_LOAD3D_STRIDE .. MAX_LOAD3D_STRIDE
 * Checks run in the order listed above.
 */
__aicore__ static inline void CheckLoadData3dParams(const uint16_t srcHeight, const uint16_t srcWeight,
const uint8_t srcWStride, const uint8_t srcHStride)
{
ASCENDC_CHECK_VALUE_RANGE(srcHeight, MIN_LOAD3D_L1, MAX_LOAD3D_L1, "l1H", "LoadData with LoadData3DParams");
// srcWeight carries the width; the label passed here is "l1W".
ASCENDC_CHECK_VALUE_RANGE(srcWeight, MIN_LOAD3D_L1, MAX_LOAD3D_L1, "l1W", "LoadData with LoadData3DParams");
ASCENDC_CHECK_VALUE_RANGE(srcWStride, MIN_LOAD3D_STRIDE, MAX_LOAD3D_STRIDE, "strideW",
"LoadData with LoadData3DParams");
ASCENDC_CHECK_VALUE_RANGE(srcHStride, MIN_LOAD3D_STRIDE, MAX_LOAD3D_STRIDE, "strideH",
"LoadData with LoadData3DParams");
}
/*
 * Asserts that channelSize is valid for LoadData with LoadData3DParamsV2 given element type PrimT<T>.
 * Validity is expressed as "channelSize % <elements per block> is one of a listed set of remainders"
 * (see ChannelSizeRemainder).
 *
 * __NPU_ARCH__ 2002: remainder 0 is NOT in the lists. Instead, channelSize equal to exactly one block
 *   (16 for half, 32 for int8_t/uint8_t, 64 for int4b_t) is accepted in addition to the listed remainders.
 * __NPU_ARCH__ 2201/3002/3003/3101/3102/3113/5102: remainder 0 is in every list. The half check also
 *   covers bfloat16_t, except on 3003/3102/3113, where only half is checked.
 * Other architectures, and dtypes not listed for an architecture, are not checked.
 */
template <typename T>
__aicore__ static inline void CheckLoadData3dv2ChannelSize(const uint16_t channelSize)
{
#if defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2002)
    if constexpr (IsSameType<PrimT<T>, half>::value) {
        uint16_t validRemainders[] = {4, 8};
        ASCENDC_ASSERT((channelSize == 16 || ChannelSizeRemainder<PrimT<T>>(channelSize, validRemainders,
            sizeof(validRemainders) / sizeof(validRemainders[0]))),
            {KERNEL_LOG(KERNEL_ERROR, "Failed to check param channelSize value in LoadData with LoadData3DParamsV2 "
            "with dtype half, it should be: 16 or channelSize % 16 = 4 / 8, current value is %u", channelSize);});
    } else if constexpr (SupportType<PrimT<T>, int8_t, uint8_t>()) {
        uint16_t validRemainders[] = {4, 8, 16};
        ASCENDC_ASSERT((channelSize == 32 || ChannelSizeRemainder<PrimT<T>>(channelSize, validRemainders,
            sizeof(validRemainders) / sizeof(validRemainders[0]))),
            {KERNEL_LOG(KERNEL_ERROR, "Failed to check param channelSize value in LoadData with LoadData3DParamsV2 "
            "with dtype int8_t / uint8_t, it should be: 32 or channelSize % 32 = 4 / 8 / 16, current value is %u",
            channelSize);});
    } else if constexpr (IsSameType<PrimT<T>, int4b_t>::value) {
        uint16_t validRemainders[] = {8, 16, 32};
        ASCENDC_ASSERT((channelSize == 64 || ChannelSizeRemainder<PrimT<T>>(channelSize, validRemainders,
            sizeof(validRemainders) / sizeof(validRemainders[0]))),
            {KERNEL_LOG(KERNEL_ERROR, "Failed to check param channelSize value in LoadData with LoadData3DParamsV2 "
            "with dtype int4b_t, it should be: 64 or channelSize % 64 = 8 / 16 / 32, current value is %u",
            channelSize);});
    }
#elif defined(__NPU_ARCH__) && ((__NPU_ARCH__ == 2201) || (__NPU_ARCH__ == 3002) || (__NPU_ARCH__ == 3003) || \
    (__NPU_ARCH__ == 3101) || (__NPU_ARCH__ == 3102) || (__NPU_ARCH__ == 3113) || (__NPU_ARCH__ == 5102))
    // 16-bit float dtypes: which ones are checked depends on the architecture.
#if (__NPU_ARCH__ == 3003) || (__NPU_ARCH__ == 3102) || (__NPU_ARCH__ == 3113)
    if constexpr (IsSameType<PrimT<T>, half>::value) {
        uint16_t validRemainders[] = {0, 4, 8};
        ASCENDC_ASSERT((ChannelSizeRemainder<PrimT<T>>(channelSize, validRemainders,
            sizeof(validRemainders) / sizeof(validRemainders[0]))),
            {KERNEL_LOG(KERNEL_ERROR, "Failed to "
            "check param channelSize value in LoadData with LoadData3DParamsV2 with dtype half, it should be: "
            "channelSize % 16 = 0 / 4 / 8, current value is %u", channelSize);});
    }
#else
    if constexpr (SupportType<PrimT<T>, half, bfloat16_t>()) {
        uint16_t validRemainders[] = {0, 4, 8};
        ASCENDC_ASSERT((ChannelSizeRemainder<PrimT<T>>(channelSize, validRemainders,
            sizeof(validRemainders) / sizeof(validRemainders[0]))),
            {KERNEL_LOG(KERNEL_ERROR, "Failed to "
            "check param channelSize value in LoadData with LoadData3DParamsV2 with dtype half / bfloat16_t, it should "
            "be: channelSize % 16 = 0 / 4 / 8, current value is %u", channelSize);});
    }
#endif
    // Remaining dtypes share the same rules on all of these architectures.
    if constexpr (SupportType<PrimT<T>, float, int32_t, uint32_t>()) {
        uint16_t validRemainders[] = {0, 4};
        ASCENDC_ASSERT((ChannelSizeRemainder<PrimT<T>>(channelSize, validRemainders,
            sizeof(validRemainders) / sizeof(validRemainders[0]))),
            {KERNEL_LOG(KERNEL_ERROR, "Failed to "
            "check param channelSize value in LoadData with LoadData3DParamsV2 with dtype float / int32_t / uint32_t, "
            "it should be: channelSize % 8 = 0 / 4, current value is %u", channelSize);});
    } else if constexpr (SupportType<PrimT<T>, int8_t, uint8_t>()) {
        uint16_t validRemainders[] = {0, 4, 8, 16};
        ASCENDC_ASSERT((ChannelSizeRemainder<PrimT<T>>(channelSize, validRemainders,
            sizeof(validRemainders) / sizeof(validRemainders[0]))),
            {KERNEL_LOG(KERNEL_ERROR, "Failed to "
            "check param channelSize value in LoadData with LoadData3DParamsV2 with dtype int8_t / uint8_t, it should "
            "be: channelSize % 32 = 0 / 4 / 8 / 16, current value is %u", channelSize);});
    } else if constexpr (IsSameType<PrimT<T>, int4b_t>::value) {
        uint16_t validRemainders[] = {0, 8, 16, 32};
        ASCENDC_ASSERT((ChannelSizeRemainder<PrimT<T>>(channelSize, validRemainders,
            sizeof(validRemainders) / sizeof(validRemainders[0]))),
            {KERNEL_LOG(KERNEL_ERROR, "Failed to "
            "check param channelSize value in LoadData with LoadData3DParamsV2 with dtype int4b_t, it should be: "
            "channelSize % 64 = 0 / 8 / 16 / 32, current value is %u", channelSize);});
    }
#endif
}
/*
 * Asserts the K/M extension and start-point constraints of LoadData with LoadData3DParamsV2.
 *   mExtension: must be a multiple of 16 for half / int8_t / int4b_t.
 *   kExtension, kStartPt: must be multiples of kExtBase for half / int8_t / int4b_t / int32_t /
 *     uint32_t / float. kExtBase is ONE_BLK_SIZE / sizeof(PrimT<T>), or 64 for int4b_t.
 *   mStartPt: on __NPU_ARCH__ 2002, a multiple of 16 for half / int8_t / int4b_t. On __NPU_ARCH__ 2201,
 *     range-checked against 0 .. UINT15_MAX.
 * Dtypes not listed for a constraint skip that constraint. Checks run in the order listed above.
 */
template <typename T>
__aicore__ static inline void CheckLoadData3dv2MatrixParams(const uint16_t kExtension, const uint16_t mExtension,
    const uint16_t kStartPt, const uint16_t mStartPt)
{
    using ElemT = PrimT<T>;
    constexpr uint16_t base16 = 16;

    if constexpr (SupportType<ElemT, half, int8_t, int4b_t>()) {
        ASCENDC_ASSERT((mExtension % base16 == 0), { KERNEL_LOG(KERNEL_ERROR, "Failed to check mExtension value in "
            "LoadData with LoadData3DParamsV2 when dtype is half / int8_t / int4b_t, it should be divisible by 16, "
            "current value is %u", mExtension);});
    }

    if constexpr (SupportType<ElemT, half, int8_t, int4b_t, int32_t, uint32_t, float>()) {
        // K-direction granularity in elements; int4b_t uses a fixed 64 instead of a sizeof-based count.
        const uint16_t kExtBase = (SupportType<ElemT, int4b_t>()) ? 64 : ONE_BLK_SIZE / sizeof(ElemT);
        ASCENDC_ASSERT((kExtension % kExtBase == 0), { KERNEL_LOG(KERNEL_ERROR, "Failed to check kExtension value in "
            "LoadData with LoadData3DParamsV2 when dtype is half / int8_t / int4b_t / int32_t / uint32_t / float, it "
            "should be divisible by %u, current value is %u", kExtBase, kExtension);});
        ASCENDC_ASSERT((kStartPt % kExtBase == 0), { KERNEL_LOG(KERNEL_ERROR, "Failed to check kStartPt value in "
            "LoadData with LoadData3DParamsV2 when dtype is half / int8_t / int4b_t / int32_t / uint32_t / float, it "
            "should be divisible by %u, current value is %u", kExtBase, kStartPt);});
    }

#if defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2002)
    if constexpr (SupportType<ElemT, half, int8_t, int4b_t>()) {
        ASCENDC_ASSERT((mStartPt % base16 == 0), { KERNEL_LOG(KERNEL_ERROR, "Failed to check mStartPt value in "
            "LoadData with LoadData3DParamsV2 when dtype is half / int8_t / int4b_t, it should be divisible by 16, "
            "current value is %u", mStartPt);});
    }
#elif defined(__NPU_ARCH__) && (__NPU_ARCH__ == 2201)
    ASCENDC_CHECK_VALUE_RANGE(mStartPt, 0, UINT15_MAX, "mStartPt", "LoadData with LoadData3DParamsV2");
#endif
}
}
#endif