/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
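/*!
* \file kernel_check_vec.h
* \brief CPU-debug (ASCENDC_CPU_DEBUG) helpers that pack tensor and intrinsic arguments into check::*ApiParams
*        objects and forward them to the matching check routines, plus dtype/scope assertions for
*        LoadData-with-transpose.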
*/
#if !defined(__ASCENDC_INCLUDE_INTERNAL_HEADERS__)
#pragma message( \
"impl/basic_api/utils/kernel_check_vec.h is an internal header file and must not be used directly. Functions or variables defined in this file may be removed in the future. Please use \"#include \"basic_api/kernel_tensor.h\"\" and use public functions or variables defined in interface headers files.")
#define __ASCENDC_INCLUDE_INTERNAL_HEADERS__
#define __UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_KERNEL_CHECK_VEC_H__
#endif
#ifndef ASCENDC_MODULE_CHECK_VEC_H
#define ASCENDC_MODULE_CHECK_VEC_H
#if ASCENDC_CPU_DEBUG
#include "kernel_check_util.h"
#include "kernel_common.h"
#include "kernel_struct_unary.h"
#include "kernel_struct_mm.h"
namespace AscendC {
template <typename T, typename U>
bool CheckVectorPadding(
const LocalTensor<T>& dst, const LocalTensor<U>& src, const uint8_t padMode, const bool padSide,
const uint64_t mask[], const uint8_t repeatTime, const UnaryRepeatParams& repeatParams, const char* intriName)
{
check::VectorPaddingApiParams chkParams{
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(dst.GetPhyAddr())),
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(src.GetPhyAddr())),
repeatTime,
static_cast<uint16_t>(repeatParams.dstBlkStride),
static_cast<uint16_t>(repeatParams.srcBlkStride),
static_cast<uint16_t>(repeatParams.dstRepStride),
static_cast<uint16_t>(repeatParams.srcRepStride),
static_cast<uint32_t>(sizeof(PrimT<T>)),
static_cast<uint32_t>(sizeof(PrimT<U>)),
static_cast<uint64_t>(dst.GetSize() * sizeof(PrimT<T>)),
static_cast<uint64_t>(src.GetSize() * sizeof(PrimT<U>)),
static_cast<uint8_t>(dst.GetPosition()),
static_cast<uint8_t>(src.GetPosition()),
padMode,
padSide};
return CheckVectorPaddingForMaskArray(chkParams, mask, intriName);
}
template <typename T, typename U>
bool CheckVectorPadding(
const LocalTensor<T>& dst, const LocalTensor<U>& src, const uint8_t padMode, const bool padSide,
const uint64_t mask, const uint8_t repeatTime, const UnaryRepeatParams& repeatParams, const char* intriName)
{
check::VectorPaddingApiParams chkParams{
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(dst.GetPhyAddr())),
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(src.GetPhyAddr())),
repeatTime,
static_cast<uint16_t>(repeatParams.dstBlkStride),
static_cast<uint16_t>(repeatParams.srcBlkStride),
static_cast<uint16_t>(repeatParams.dstRepStride),
static_cast<uint16_t>(repeatParams.srcRepStride),
static_cast<uint32_t>(sizeof(PrimT<T>)),
static_cast<uint32_t>(sizeof(PrimT<U>)),
static_cast<uint64_t>(dst.GetSize() * sizeof(PrimT<T>)),
static_cast<uint64_t>(src.GetSize() * sizeof(PrimT<U>)),
static_cast<uint8_t>(dst.GetPosition()),
static_cast<uint8_t>(src.GetPosition()),
padMode,
padSide};
return CheckVectorPadding(chkParams, mask, intriName);
}
template <typename T, typename U>
bool CheckVectorPadding(
const LocalTensor<T>& dst, const LocalTensor<U>& src, const uint8_t padMode, const bool padSide,
const uint32_t count, const char* intriName)
{
check::VectorPaddingApiParams chkParams{
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(dst.GetPhyAddr())),
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(src.GetPhyAddr())),
static_cast<uint32_t>(sizeof(PrimT<T>)),
static_cast<uint32_t>(sizeof(PrimT<U>)),
static_cast<uint64_t>(dst.GetSize() * sizeof(PrimT<T>)),
static_cast<uint64_t>(src.GetSize() * sizeof(PrimT<U>)),
static_cast<uint8_t>(dst.GetPosition()),
static_cast<uint8_t>(src.GetPosition()),
count,
padMode,
padSide};
return CheckVectorPadding(chkParams, intriName);
}
/*!
 * \brief Compile-time dtype gate for LoadData-with-transpose using LoadData2dTransposeParams.
 *
 * The accepted element types (PrimT<T>) depend on __NPU_ARCH__:
 *   - 3102:        uint8_t, int8_t, half
 *   - 3510 / 5102: uint8_t, int8_t, half, bfloat16_t, float, int32_t, uint32_t
 *   - otherwise:   no restriction is applied.
 * None of the function arguments are read; only the template type is inspected.
 *
 * \return dtypeMatch on the gated architectures (after ASSERT-ing it), true on all others.
 */
template <typename T>
bool CheckFuncLoadDataTranspose(
const LocalTensor<T>& dst, const LocalTensor<T>& src, const LoadData2dTransposeParams& loadDataParams,
const char* intriName)
{
#if __NPU_ARCH__ == 3102
// 3102: only u8 / s8 / fp16 pass.
constexpr bool dtypeMatch = SupportType<PrimT<T>, uint8_t, int8_t, half>();
ASSERT(dtypeMatch && "LoadData2dTransposeParams without dtype of u8/s8/fp16 is not supported on current device");
return dtypeMatch;
#elif (__NPU_ARCH__ == 3510) || (__NPU_ARCH__ == 5102)
// 3510 / 5102: the accepted list additionally contains bf16 / f32 / s32 / u32.
constexpr bool dtypeMatch = SupportType<PrimT<T>, uint8_t, int8_t, half, bfloat16_t, float, int32_t, uint32_t>();
ASSERT(
dtypeMatch &&
"LoadData2dTransposeParams without dtype of u8/s8/fp16/bf16/f32/s32/u32 is not supported on current device");
return dtypeMatch;
#else
// Every other architecture value: nothing is checked here.
return true;
#endif
}
/*!
 * \brief Validates LoadData-with-transpose calls that use LoadData2dTransposeParamsV2.
 *
 * On __NPU_ARCH__ 3102 / 3510 / 5102 two conditions are ASSERT-ed and AND-ed into the result:
 *   - scopeMatch: GetPhyType() of the dst position is Hardware::L0B and of the src position is
 *                 Hardware::L1;
 *   - dtypeMatch: on 3102, PrimT<T> is int4b_t or has the size of int8_t or half;
 *                 on 3510 / 5102, PrimT<T> has the size of int8_t, half or float.
 * On every other architecture value the function ASSERT-s on a constant-false condition and
 * returns false. loadDataParams and intriName are not read.
 *
 * \return scopeMatch && dtypeMatch on the supported architectures, false otherwise.
 */
template <typename T>
bool CheckFuncLoadDataTranspose(
const LocalTensor<T>& dst, const LocalTensor<T>& src, const LoadData2dTransposeParamsV2& loadDataParams,
const char* intriName)
{
#if __NPU_ARCH__ == 3102 || (__NPU_ARCH__ == 3510) || (__NPU_ARCH__ == 5102)
// Runtime check of the tensor positions: dst must map to L0B and src to L1.
bool scopeMatch =
(GetPhyType(static_cast<TPosition>(dst.GetPosition())) == Hardware::L0B &&
GetPhyType(static_cast<TPosition>(src.GetPosition())) == Hardware::L1);
ASSERT(scopeMatch && "LoadDataWithTranspose without B1->B2 is not supported on current device");
#if __NPU_ARCH__ == 3102
// 3102: int4b_t, or any type with the size of int8_t / half.
constexpr bool dtypeMatch =
IsSameType<PrimT<T>, int4b_t>::value || sizeof(PrimT<T>) == sizeof(int8_t) || sizeof(PrimT<T>) == sizeof(half);
#elif (__NPU_ARCH__ == 3510) || (__NPU_ARCH__ == 5102)
// 3510 / 5102: any type with the size of int8_t / half / float (matched by size, not by identity).
constexpr bool dtypeMatch =
sizeof(PrimT<T>) == sizeof(int8_t) || sizeof(PrimT<T>) == sizeof(half) || sizeof(PrimT<T>) == sizeof(float);
#endif
ASSERT(dtypeMatch && "LoadDataWithTranspose is not supported on current device");
return scopeMatch && dtypeMatch;
#else
// The V2 parameter form is rejected outright on all other architecture values.
ASSERT(false && "Current version don't support LoadDataWithTranspose using LoadData2dTransposeParamsV2");
return false;
#endif
}
template <typename T, typename U, typename S, typename V>
bool CheckMmadParams(
const LocalTensor<T>& dst, const LocalTensor<U>& fm, const LocalTensor<S>& filter, const LocalTensor<V>& bias,
const MmadParams& mmadParams, const char* intriName)
{
#if defined(__NPU_ARCH__) && ((__NPU_ARCH__ == 3003) || (__NPU_ARCH__ == 3113))
return true;
#else
check::MmadApiParams chkParams{
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(dst.GetPhyAddr())),
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(fm.GetPhyAddr())),
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(filter.GetPhyAddr())),
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(bias.GetPhyAddr())),
static_cast<uint32_t>(sizeof(PrimT<T>)),
static_cast<uint32_t>(sizeof(PrimT<U>)),
static_cast<uint32_t>(sizeof(PrimT<S>)),
static_cast<uint32_t>(sizeof(PrimT<V>)),
static_cast<uint64_t>(dst.GetSize() * sizeof(PrimT<T>)),
static_cast<uint64_t>(fm.GetSize() * sizeof(PrimT<U>)),
static_cast<uint64_t>(filter.GetSize() * sizeof(PrimT<S>)),
static_cast<uint64_t>(bias.GetSize() * sizeof(PrimT<V>)),
static_cast<uint8_t>(dst.GetPosition()),
static_cast<uint8_t>(fm.GetPosition()),
static_cast<uint8_t>(filter.GetPosition()),
static_cast<uint8_t>(bias.GetPosition()),
mmadParams.m,
mmadParams.n,
mmadParams.k,
mmadParams.isBias,
mmadParams.fmOffset,
mmadParams.enSsparse,
mmadParams.enWinogradA,
mmadParams.enWinogradB};
return CheckFuncMmadImpl(chkParams, intriName);
#endif
}
template <typename T, typename U, typename S>
bool CheckMmadParams(
const LocalTensor<T>& dst, const LocalTensor<U>& fm, const LocalTensor<S>& filter, const MmadParams& mmadParams,
const char* intriName)
{
#if defined(__NPU_ARCH__) && ((__NPU_ARCH__ == 3003) || (__NPU_ARCH__ == 3113))
return true;
#else
check::MmadApiParams chkParams{
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(dst.GetPhyAddr())),
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(fm.GetPhyAddr())),
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(filter.GetPhyAddr())),
static_cast<uint32_t>(sizeof(PrimT<T>)),
static_cast<uint32_t>(sizeof(PrimT<U>)),
static_cast<uint32_t>(sizeof(PrimT<S>)),
static_cast<uint64_t>(dst.GetSize() * sizeof(PrimT<T>)),
static_cast<uint64_t>(fm.GetSize() * sizeof(PrimT<U>)),
static_cast<uint64_t>(filter.GetSize() * sizeof(PrimT<S>)),
static_cast<uint8_t>(dst.GetPosition()),
static_cast<uint8_t>(fm.GetPosition()),
static_cast<uint8_t>(filter.GetPosition()),
mmadParams.m,
mmadParams.n,
mmadParams.k,
mmadParams.isBias,
mmadParams.fmOffset,
mmadParams.enSsparse,
mmadParams.enWinogradA,
mmadParams.enWinogradB};
return CheckFuncMmadImpl(chkParams, intriName);
#endif
}
#if defined(__NPU_ARCH__) && ((__NPU_ARCH__ == 3510) || (__NPU_ARCH__ == 5102))
/**
 * \brief Mmad check overload for the form that passes the parameters as a single uint64_t and
 *        carries a bias tensor.
 *
 * No argument is inspected; the overload accepts every call.
 * \return Always true.
 */
template <typename T, typename U, typename S, typename V>
bool CheckMmadParams(
    const LocalTensor<T>& dst,
    const LocalTensor<U>& fm,
    const LocalTensor<S>& filter,
    const LocalTensor<V>& bias,
    const uint64_t& mmadParams,
    const char* intriName)
{
    return true;
}
/**
 * \brief Mmad check overload for the form that passes the parameters as a single uint64_t and has
 *        no bias tensor.
 *
 * No argument is inspected; the overload accepts every call.
 * \return Always true.
 */
template <typename T, typename U, typename S>
bool CheckMmadParams(
    const LocalTensor<T>& dst,
    const LocalTensor<U>& fm,
    const LocalTensor<S>& filter,
    const uint64_t& mmadParams,
    const char* intriName)
{
    return true;
}
#endif
template <typename T, typename U>
bool CheckFuncBroadCastToMM(
const LocalTensor<T>& dst, const LocalTensor<U>& src, const int32_t blockCount, const uint8_t blockLen,
const uint8_t srcGap, const uint8_t dstGap, const char* intriName)
{
check::VecBroadCastToMMApiParams chkParams{
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(dst.GetPhyAddr())),
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(src.GetPhyAddr())),
static_cast<uint32_t>(sizeof(PrimT<T>)),
static_cast<uint32_t>(sizeof(PrimT<U>)),
static_cast<uint64_t>(dst.GetSize() * sizeof(PrimT<T>)),
static_cast<uint64_t>(src.GetSize() * sizeof(PrimT<U>)),
static_cast<uint8_t>(dst.GetPosition()),
static_cast<uint8_t>(src.GetPosition()),
static_cast<uint32_t>(blockCount),
static_cast<uint8_t>(blockLen),
static_cast<uint8_t>(srcGap),
static_cast<uint8_t>(dstGap)};
return CheckFuncBroadCastToMMImpl(chkParams, intriName);
}
/**
 * \brief Assembles the check::VecReduceApiParams used by the CheckFunVecReduceOther overloads.
 *
 * Constructor arguments, in order: dst/src addresses, sizeof(PrimT<U>) (dst) and sizeof(PrimT<T>)
 * (src), repeatTime, the three strides cast to uint16_t, GetSize() * element size for dst and src,
 * and the dst/src positions cast to uint8_t.
 */
template <typename T, typename U = T>
check::VecReduceApiParams BuildVecReduceOtherParams(
    const LocalTensor<U>& dst, const LocalTensor<T>& src, const int32_t repeatTime, const int32_t dstRepStride,
    const int32_t srcBlkStride, const int32_t srcRepStride)
{
    const auto dstAddr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(dst.GetPhyAddr()));
    const auto srcAddr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(src.GetPhyAddr()));
    constexpr auto dstElemSize = static_cast<uint32_t>(sizeof(PrimT<U>));
    constexpr auto srcElemSize = static_cast<uint32_t>(sizeof(PrimT<T>));
    const auto dstRepStride16 = static_cast<uint16_t>(dstRepStride);
    const auto srcBlkStride16 = static_cast<uint16_t>(srcBlkStride);
    const auto srcRepStride16 = static_cast<uint16_t>(srcRepStride);
    const auto dstBytes = static_cast<uint64_t>(dst.GetSize() * sizeof(PrimT<U>));
    const auto srcBytes = static_cast<uint64_t>(src.GetSize() * sizeof(PrimT<T>));
    const auto dstPos = static_cast<uint8_t>(dst.GetPosition());
    const auto srcPos = static_cast<uint8_t>(src.GetPosition());

    return check::VecReduceApiParams{
        dstAddr, srcAddr, dstElemSize, srcElemSize, repeatTime,
        dstRepStride16, srcBlkStride16, srcRepStride16,
        dstBytes, srcBytes, dstPos, srcPos};
}
template <typename T, typename U = T>
bool CheckFunVecReduceOther(
const LocalTensor<U>& dst, const LocalTensor<T>& src, const int32_t repeatTime, const int32_t maskCount,
const int32_t dstRepStride, const int32_t srcBlkStride, const int32_t srcRepStride, const char* intriName)
{
auto chkParams = BuildVecReduceOtherParams(dst, src, repeatTime, dstRepStride, srcBlkStride, srcRepStride);
return CheckFunReduceOtherImpl(chkParams, maskCount, intriName);
}
template <typename T, typename U = T>
bool CheckFunVecReduceOther(
const LocalTensor<U>& dst, const LocalTensor<T>& src, const int32_t repeatTime, const uint64_t mask[],
const int32_t dstRepStride, const int32_t srcBlkStride, const int32_t srcRepStride, const char* intriName)
{
auto chkParams = BuildVecReduceOtherParams(dst, src, repeatTime, dstRepStride, srcBlkStride, srcRepStride);
return CheckFunReduceOtherImplForMaskArray(chkParams, mask, intriName);
}
/**
 * \brief Assembles the check::VecReduceWhlApiParams used by the CheckFunVecReduceOtherWhl overloads.
 *
 * Constructor arguments, in order: dst/src addresses, sizeof(PrimT<T>) twice, repeatTime, the three
 * strides cast to uint16_t, order, GetSize() * sizeof(PrimT<T>) for dst and src, and the dst/src
 * positions cast to uint8_t.
 */
template <typename T>
check::VecReduceWhlApiParams BuildVecReduceOtherWhlParams(
    const LocalTensor<T>& dst, const LocalTensor<T>& src, const int32_t repeatTime, const int32_t dstRepStride,
    const int32_t srcBlkStride, const int32_t srcRepStride, ReduceOrder order)
{
    const auto dstAddr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(dst.GetPhyAddr()));
    const auto srcAddr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(src.GetPhyAddr()));
    // dst and src share the element type, so one size value serves both slots.
    constexpr auto elemSize = static_cast<uint32_t>(sizeof(PrimT<T>));
    const auto dstRepStride16 = static_cast<uint16_t>(dstRepStride);
    const auto srcBlkStride16 = static_cast<uint16_t>(srcBlkStride);
    const auto srcRepStride16 = static_cast<uint16_t>(srcRepStride);
    const auto dstBytes = static_cast<uint64_t>(dst.GetSize() * sizeof(PrimT<T>));
    const auto srcBytes = static_cast<uint64_t>(src.GetSize() * sizeof(PrimT<T>));
    const auto dstPos = static_cast<uint8_t>(dst.GetPosition());
    const auto srcPos = static_cast<uint8_t>(src.GetPosition());

    return check::VecReduceWhlApiParams{
        dstAddr, srcAddr, elemSize, elemSize, repeatTime,
        dstRepStride16, srcBlkStride16, srcRepStride16, order,
        dstBytes, srcBytes, dstPos, srcPos};
}
template <typename T>
bool CheckFunVecReduceOtherWhl(
const LocalTensor<T>& dst, const LocalTensor<T>& src, const int32_t repeatTime, const int32_t maskCount,
const int32_t dstRepStride, const int32_t srcBlkStride, const int32_t srcRepStride, ReduceOrder order,
const char* intriName)
{
auto chkParams =
BuildVecReduceOtherWhlParams(dst, src, repeatTime, dstRepStride, srcBlkStride, srcRepStride, order);
return CheckFunReduceOtherWhlImpl(chkParams, maskCount, intriName);
}
template <typename T>
bool CheckFunVecReduceOtherWhl(
const LocalTensor<T>& dst, const LocalTensor<T>& src, const int32_t repeatTime, const uint64_t mask[],
const int32_t dstRepStride, const int32_t srcBlkStride, const int32_t srcRepStride, ReduceOrder order,
const char* intriName)
{
auto chkParams =
BuildVecReduceOtherWhlParams(dst, src, repeatTime, dstRepStride, srcBlkStride, srcRepStride, order);
return CheckFunReduceOtherWhlImplForMaskArray(chkParams, mask, intriName);
}
/**
 * \brief Assembles the check::VecReduceApiParams for reduce checks that use a work tensor and a
 *        calIndex flag.
 *
 * Constructor arguments, in order: dst/src/work addresses, sizeof(PrimT<T>) three times,
 * repeatTime, calIndex, GetSize() * sizeof(PrimT<T>) for dst/src/work, the three positions cast to
 * uint8_t, and srcRepStride cast to uint16_t.
 */
template <typename T>
check::VecReduceApiParams BuildVecReduceWithCalIndexParams(
    const LocalTensor<T>& dst, const LocalTensor<T>& src, const LocalTensor<T>& work, const int32_t repeatTime,
    bool calIndex, const int32_t srcRepStride)
{
    const auto dstAddr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(dst.GetPhyAddr()));
    const auto srcAddr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(src.GetPhyAddr()));
    const auto workAddr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(work.GetPhyAddr()));
    // All three tensors share the element type.
    constexpr auto elemSize = static_cast<uint32_t>(sizeof(PrimT<T>));
    const auto dstBytes = static_cast<uint64_t>(dst.GetSize() * sizeof(PrimT<T>));
    const auto srcBytes = static_cast<uint64_t>(src.GetSize() * sizeof(PrimT<T>));
    const auto workBytes = static_cast<uint64_t>(work.GetSize() * sizeof(PrimT<T>));
    const auto dstPos = static_cast<uint8_t>(dst.GetPosition());
    const auto srcPos = static_cast<uint8_t>(src.GetPosition());
    const auto workPos = static_cast<uint8_t>(work.GetPosition());
    const auto srcRepStride16 = static_cast<uint16_t>(srcRepStride);

    return check::VecReduceApiParams{
        dstAddr, srcAddr, workAddr, elemSize, elemSize, elemSize,
        repeatTime, calIndex,
        dstBytes, srcBytes, workBytes,
        dstPos, srcPos, workPos, srcRepStride16};
}
template <typename T>
bool CheckFunVecReduce(
const LocalTensor<T>& dst, const LocalTensor<T>& src, const LocalTensor<T>& work, const int32_t repeatTime,
const int32_t mask, bool calIndex, const int32_t srcRepStride, const char* intriName)
{
auto chkParams = BuildVecReduceWithCalIndexParams(dst, src, work, repeatTime, calIndex, srcRepStride);
return CheckFunReduceImpl(chkParams, mask, intriName);
}
template <typename T>
bool CheckFunVecReduce(
const LocalTensor<T>& dst, const LocalTensor<T>& src, const LocalTensor<T>& work, const int32_t repeatTime,
const uint64_t mask[], bool calIndex, const int32_t srcRepStride, const char* intriName)
{
auto chkParams = BuildVecReduceWithCalIndexParams(dst, src, work, repeatTime, calIndex, srcRepStride);
return CheckFunReduceImplForMaskArray(chkParams, mask, intriName);
}
/**
 * \brief Assembles the check::VecReduceApiParams for reduce checks that use a work tensor but no
 *        calIndex flag.
 *
 * Constructor arguments, in order: dst/src/work addresses, sizeof(PrimT<T>) three times,
 * repeatTime, GetSize() * sizeof(PrimT<T>) for dst/src/work, the three positions cast to uint8_t,
 * and srcRepStride cast to uint16_t.
 */
template <typename T>
check::VecReduceApiParams BuildVecReduceWithoutCalIndexParams(
    const LocalTensor<T>& dst, const LocalTensor<T>& src, const LocalTensor<T>& work, const int32_t repeatTime,
    const int32_t srcRepStride)
{
    const auto dstAddr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(dst.GetPhyAddr()));
    const auto srcAddr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(src.GetPhyAddr()));
    const auto workAddr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(work.GetPhyAddr()));
    // All three tensors share the element type.
    constexpr auto elemSize = static_cast<uint32_t>(sizeof(PrimT<T>));
    const auto dstBytes = static_cast<uint64_t>(dst.GetSize() * sizeof(PrimT<T>));
    const auto srcBytes = static_cast<uint64_t>(src.GetSize() * sizeof(PrimT<T>));
    const auto workBytes = static_cast<uint64_t>(work.GetSize() * sizeof(PrimT<T>));
    const auto dstPos = static_cast<uint8_t>(dst.GetPosition());
    const auto srcPos = static_cast<uint8_t>(src.GetPosition());
    const auto workPos = static_cast<uint8_t>(work.GetPosition());
    const auto srcRepStride16 = static_cast<uint16_t>(srcRepStride);

    return check::VecReduceApiParams{
        dstAddr, srcAddr, workAddr, elemSize, elemSize, elemSize,
        repeatTime,
        dstBytes, srcBytes, workBytes,
        dstPos, srcPos, workPos, srcRepStride16};
}
template <typename T>
bool CheckFunVecReduce(
const LocalTensor<T>& dst, const LocalTensor<T>& src, const LocalTensor<T>& work, const int32_t repeatTime,
const int32_t mask, const int32_t srcRepStride, const char* intriName)
{
auto chkParams = BuildVecReduceWithoutCalIndexParams(dst, src, work, repeatTime, srcRepStride);
return CheckFunReduceImpl(chkParams, mask, intriName);
}
template <typename T>
bool CheckFunVecReduce(
const LocalTensor<T>& dst, const LocalTensor<T>& src, const LocalTensor<T>& work, const int32_t repeatTime,
const uint64_t mask[], const int32_t srcRepStride, const char* intriName)
{
auto chkParams = BuildVecReduceWithoutCalIndexParams(dst, src, work, repeatTime, srcRepStride);
return CheckFunReduceImplForMaskArray(chkParams, mask, intriName);
}
template <typename T>
bool CheckFunVecReduce(
const LocalTensor<T>& dst, const LocalTensor<T>& src, const LocalTensor<T>& work, int32_t repeatTime,
const int32_t count, bool calIndex, const char* intriName)
{
check::VecReduceApiParams chkParams{
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(dst.GetPhyAddr())),
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(src.GetPhyAddr())),
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(work.GetPhyAddr())),
static_cast<uint32_t>(sizeof(PrimT<T>)),
static_cast<uint32_t>(sizeof(PrimT<T>)),
static_cast<uint32_t>(sizeof(PrimT<T>)),
repeatTime,
static_cast<uint32_t>(count),
calIndex,
static_cast<uint64_t>(dst.GetSize() * sizeof(PrimT<T>)),
static_cast<uint64_t>(src.GetSize() * sizeof(PrimT<T>)),
static_cast<uint64_t>(work.GetSize() * sizeof(PrimT<T>)),
static_cast<uint8_t>(dst.GetPosition()),
static_cast<uint8_t>(src.GetPosition()),
static_cast<uint8_t>(work.GetPosition())};
return CheckFunReduceImpl(chkParams, intriName);
}
template <typename T>
bool CheckFunVecReduce(
const LocalTensor<T>& dst, const LocalTensor<T>& src, const LocalTensor<T>& work, const int32_t count,
int32_t repeatTime, const char* intriName)
{
check::VecReduceApiParams chkParams{
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(dst.GetPhyAddr())),
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(src.GetPhyAddr())),
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(work.GetPhyAddr())),
static_cast<uint32_t>(sizeof(PrimT<T>)),
static_cast<uint32_t>(sizeof(PrimT<T>)),
static_cast<uint32_t>(sizeof(PrimT<T>)),
repeatTime,
static_cast<uint32_t>(count),
static_cast<uint64_t>(dst.GetSize() * sizeof(PrimT<T>)),
static_cast<uint64_t>(src.GetSize() * sizeof(PrimT<T>)),
static_cast<uint64_t>(work.GetSize() * sizeof(PrimT<T>)),
static_cast<uint8_t>(dst.GetPosition()),
static_cast<uint8_t>(src.GetPosition()),
static_cast<uint8_t>(work.GetPosition())};
return CheckFunReduceImpl(chkParams, intriName);
}
template <typename T>
bool CheckFunVecReduceMode2(
const LocalTensor<T>& dst, const LocalTensor<T>& src, const int32_t count, const char* intriName)
{
check::VecReduceApiParams chkParams{
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(dst.GetPhyAddr())),
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(src.GetPhyAddr())),
static_cast<uint32_t>(sizeof(PrimT<T>)),
static_cast<uint32_t>(sizeof(PrimT<T>)),
static_cast<uint32_t>(count),
static_cast<uint64_t>(dst.GetSize() * sizeof(PrimT<T>)),
static_cast<uint64_t>(src.GetSize() * sizeof(PrimT<T>)),
static_cast<uint8_t>(dst.GetPosition()),
static_cast<uint8_t>(src.GetPosition())};
return CheckFunReduceImplMode2(chkParams, intriName);
}
/**
 * \brief Assembles the check::VecScatterApiParams used by the mask-based CheckFunScatter overloads.
 *
 * Constructor arguments, in order: dst/src/dstOffset addresses, sizeof(PrimT<T>) twice and
 * sizeof(PrimT<U>), dstBaseAddr, repeatTime, srcRepStride (as uint16_t), GetSize() * element size
 * for the three tensors, and their positions cast to uint8_t.
 */
template <typename T, typename U>
check::VecScatterApiParams BuildVecScatterParams(
    const LocalTensor<T>& dst, const LocalTensor<T>& src, const LocalTensor<U>& dstOffset, const uint32_t dstBaseAddr,
    const uint8_t repeatTime, const uint16_t srcRepStride)
{
    const auto dstAddr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(dst.GetPhyAddr()));
    const auto srcAddr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(src.GetPhyAddr()));
    const auto offsetAddr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(dstOffset.GetPhyAddr()));
    // dst and src share element type T; the offset tensor has its own type U.
    constexpr auto dataElemSize = static_cast<uint32_t>(sizeof(PrimT<T>));
    constexpr auto offsetElemSize = static_cast<uint32_t>(sizeof(PrimT<U>));
    const auto srcRepStride16 = static_cast<uint16_t>(srcRepStride);
    const auto dstBytes = static_cast<uint64_t>(dst.GetSize() * sizeof(PrimT<T>));
    const auto srcBytes = static_cast<uint64_t>(src.GetSize() * sizeof(PrimT<T>));
    const auto offsetBytes = static_cast<uint64_t>(dstOffset.GetSize() * sizeof(PrimT<U>));
    const auto dstPos = static_cast<uint8_t>(dst.GetPosition());
    const auto srcPos = static_cast<uint8_t>(src.GetPosition());
    const auto offsetPos = static_cast<uint8_t>(dstOffset.GetPosition());

    return check::VecScatterApiParams{
        dstAddr, srcAddr, offsetAddr, dataElemSize, dataElemSize, offsetElemSize,
        dstBaseAddr, repeatTime, srcRepStride16,
        dstBytes, srcBytes, offsetBytes,
        dstPos, srcPos, offsetPos};
}
template <typename T, typename U>
bool CheckFunScatter(
const LocalTensor<T>& dst, const LocalTensor<T>& src, const LocalTensor<U>& dstOffset, const uint32_t dstBaseAddr,
const uint64_t mask[], const uint8_t repeatTime, const uint16_t srcRepStride, const char* intriName)
{
auto chkParams = BuildVecScatterParams(dst, src, dstOffset, dstBaseAddr, repeatTime, srcRepStride);
return CheckFunScatterImplForMaskArray(chkParams, mask, intriName);
}
template <typename T, typename U>
bool CheckFunScatter(
const LocalTensor<T>& dst, const LocalTensor<T>& src, const LocalTensor<U>& dstOffset, const uint32_t dstBaseAddr,
const uint64_t mask, const uint8_t repeatTime, const uint16_t srcRepStride, const char* intriName)
{
auto chkParams = BuildVecScatterParams(dst, src, dstOffset, dstBaseAddr, repeatTime, srcRepStride);
return CheckFunScatterImpl(chkParams, mask, intriName);
}
template <typename T, typename U>
bool CheckFunScatter(
const LocalTensor<T>& dst, const LocalTensor<T>& src, const LocalTensor<U>& dstOffset, const uint32_t dstBaseAddr,
const uint32_t count, const char* intriName)
{
check::VecScatterApiParams chkParams{
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(dst.GetPhyAddr())),
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(src.GetPhyAddr())),
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(dstOffset.GetPhyAddr())),
static_cast<uint32_t>(sizeof(PrimT<T>)),
static_cast<uint32_t>(sizeof(PrimT<T>)),
static_cast<uint32_t>(sizeof(PrimT<U>)),
dstBaseAddr,
static_cast<uint32_t>(count),
static_cast<uint64_t>(dst.GetSize() * sizeof(PrimT<T>)),
static_cast<uint64_t>(src.GetSize() * sizeof(PrimT<T>)),
static_cast<uint64_t>(dstOffset.GetSize() * sizeof(PrimT<U>)),
static_cast<uint8_t>(dst.GetPosition()),
static_cast<uint8_t>(src.GetPosition()),
static_cast<uint8_t>(dstOffset.GetPosition())};
return CheckFunScatterImpl(chkParams, intriName);
}
}
#endif
#endif
#if defined(__UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_KERNEL_CHECK_VEC_H__)
#undef __ASCENDC_INCLUDE_INTERNAL_HEADERS__
#undef __UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_KERNEL_CHECK_VEC_H__
#endif