/**
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file kernel_check_vec.h
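 * \brief Parameter checks used in CPU debug builds (ASCENDC_CPU_DEBUG) for vector padding, LoadData
 *        transpose, Mmad, BroadCastToMM, reduce and scatter intrinsics.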
 */
#if !defined(__ASCENDC_INCLUDE_INTERNAL_HEADERS__)
#pragma message( \
    "impl/basic_api/utils/kernel_check_vec.h is an internal header file and must not be used directly. Functions or variables defined in this file may be removed in the future. Please use \"#include \"basic_api/kernel_tensor.h\"\" and use public functions or variables defined in interface headers files.")
#define __ASCENDC_INCLUDE_INTERNAL_HEADERS__
#define __UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_KERNEL_CHECK_VEC_H__
#endif
#ifndef ASCENDC_MODULE_CHECK_VEC_H
#define ASCENDC_MODULE_CHECK_VEC_H

#if ASCENDC_CPU_DEBUG
#include "kernel_check_util.h"
#include "kernel_common.h"
#include "kernel_struct_unary.h"
#include "kernel_struct_mm.h"

namespace AscendC {
template <typename T, typename U>
bool CheckVectorPadding(
    const LocalTensor<T>& dst, const LocalTensor<U>& src, const uint8_t padMode, const bool padSide,
    const uint64_t mask[], const uint8_t repeatTime, const UnaryRepeatParams& repeatParams, const char* intriName)
{
    check::VectorPaddingApiParams chkParams{
        static_cast<uint64_t>(reinterpret_cast<uintptr_t>(dst.GetPhyAddr())),
        static_cast<uint64_t>(reinterpret_cast<uintptr_t>(src.GetPhyAddr())),
        repeatTime,
        static_cast<uint16_t>(repeatParams.dstBlkStride),
        static_cast<uint16_t>(repeatParams.srcBlkStride),
        static_cast<uint16_t>(repeatParams.dstRepStride),
        static_cast<uint16_t>(repeatParams.srcRepStride),
        static_cast<uint32_t>(sizeof(PrimT<T>)),
        static_cast<uint32_t>(sizeof(PrimT<U>)),
        static_cast<uint64_t>(dst.GetSize() * sizeof(PrimT<T>)),
        static_cast<uint64_t>(src.GetSize() * sizeof(PrimT<U>)),
        static_cast<uint8_t>(dst.GetPosition()),
        static_cast<uint8_t>(src.GetPosition()),
        padMode,
        padSide};
    return CheckVectorPaddingForMaskArray(chkParams, mask, intriName);
}

template <typename T, typename U>
bool CheckVectorPadding(
    const LocalTensor<T>& dst, const LocalTensor<U>& src, const uint8_t padMode, const bool padSide,
    const uint64_t mask, const uint8_t repeatTime, const UnaryRepeatParams& repeatParams, const char* intriName)
{
    check::VectorPaddingApiParams chkParams{
        static_cast<uint64_t>(reinterpret_cast<uintptr_t>(dst.GetPhyAddr())),
        static_cast<uint64_t>(reinterpret_cast<uintptr_t>(src.GetPhyAddr())),
        repeatTime,
        static_cast<uint16_t>(repeatParams.dstBlkStride),
        static_cast<uint16_t>(repeatParams.srcBlkStride),
        static_cast<uint16_t>(repeatParams.dstRepStride),
        static_cast<uint16_t>(repeatParams.srcRepStride),
        static_cast<uint32_t>(sizeof(PrimT<T>)),
        static_cast<uint32_t>(sizeof(PrimT<U>)),
        static_cast<uint64_t>(dst.GetSize() * sizeof(PrimT<T>)),
        static_cast<uint64_t>(src.GetSize() * sizeof(PrimT<U>)),
        static_cast<uint8_t>(dst.GetPosition()),
        static_cast<uint8_t>(src.GetPosition()),
        padMode,
        padSide};
    return CheckVectorPadding(chkParams, mask, intriName);
}

template <typename T, typename U>
bool CheckVectorPadding(
    const LocalTensor<T>& dst, const LocalTensor<U>& src, const uint8_t padMode, const bool padSide,
    const uint32_t count, const char* intriName)
{
    check::VectorPaddingApiParams chkParams{
        static_cast<uint64_t>(reinterpret_cast<uintptr_t>(dst.GetPhyAddr())),
        static_cast<uint64_t>(reinterpret_cast<uintptr_t>(src.GetPhyAddr())),
        static_cast<uint32_t>(sizeof(PrimT<T>)),
        static_cast<uint32_t>(sizeof(PrimT<U>)),
        static_cast<uint64_t>(dst.GetSize() * sizeof(PrimT<T>)),
        static_cast<uint64_t>(src.GetSize() * sizeof(PrimT<U>)),
        static_cast<uint8_t>(dst.GetPosition()),
        static_cast<uint8_t>(src.GetPosition()),
        count,
        padMode,
        padSide};
    return CheckVectorPadding(chkParams, intriName);
}

/**
 * \brief CPU-debug element-type check for LoadData with LoadData2dTransposeParams.
 *
 * Only PrimT<T> is inspected; dst, src, loadDataParams and intriName are not read.
 * - __NPU_ARCH__ 3510 / 5102: accepts u8, s8, fp16, bf16, f32, s32, u32
 * - __NPU_ARCH__ 3102: accepts u8, s8, fp16
 * - any other architecture: no restriction
 *
 * \return true when the element type is accepted for the current __NPU_ARCH__
 */
template <typename T>
bool CheckFuncLoadDataTranspose(
    const LocalTensor<T>& dst, const LocalTensor<T>& src, const LoadData2dTransposeParams& loadDataParams,
    const char* intriName)
{
#if (__NPU_ARCH__ == 3510) || (__NPU_ARCH__ == 5102)
    constexpr bool dtypeMatch = SupportType<PrimT<T>, uint8_t, int8_t, half, bfloat16_t, float, int32_t, uint32_t>();
    ASSERT(
        dtypeMatch &&
        "LoadData2dTransposeParams without dtype of u8/s8/fp16/bf16/f32/s32/u32 is not supported on current device");
    return dtypeMatch;
#elif __NPU_ARCH__ == 3102
    constexpr bool dtypeMatch = SupportType<PrimT<T>, uint8_t, int8_t, half>();
    ASSERT(dtypeMatch && "LoadData2dTransposeParams without dtype of u8/s8/fp16 is not supported on current device");
    return dtypeMatch;
#else
    // No element-type restriction is enforced on other architectures.
    return true;
#endif
}

/**
 * \brief CPU-debug check for LoadDataWithTranspose driven by LoadData2dTransposeParamsV2.
 *
 * Supported only on __NPU_ARCH__ 3102, 3510 and 5102. On any other architecture the function
 * asserts and returns false. On the supported ones it requires:
 * - dst on Hardware::L0B and src on Hardware::L1, as classified by GetPhyType;
 * - an element type of accepted width:
 *   - 3102: int4b_t, or 1- or 2-byte types;
 *   - 3510 / 5102: 1-, 2- or 4-byte types.
 *
 * \return true only when both the scope check and the element-type check pass
 */
template <typename T>
bool CheckFuncLoadDataTranspose(
    const LocalTensor<T>& dst, const LocalTensor<T>& src, const LoadData2dTransposeParamsV2& loadDataParams,
    const char* intriName)
{
#if !(__NPU_ARCH__ == 3102 || (__NPU_ARCH__ == 3510) || (__NPU_ARCH__ == 5102))
    ASSERT(false && "Current version don't support LoadDataWithTranspose using LoadData2dTransposeParamsV2");
    return false;
#else
    // Element-width rule for the current architecture (compile-time constant, no side effects).
#if __NPU_ARCH__ == 3102
    constexpr bool dtypeMatch =
        IsSameType<PrimT<T>, int4b_t>::value || sizeof(PrimT<T>) == sizeof(int8_t) || sizeof(PrimT<T>) == sizeof(half);
#elif (__NPU_ARCH__ == 3510) || (__NPU_ARCH__ == 5102)
    constexpr bool dtypeMatch =
        sizeof(PrimT<T>) == sizeof(int8_t) || sizeof(PrimT<T>) == sizeof(half) || sizeof(PrimT<T>) == sizeof(float);
#endif
    // src is only classified when dst already sits on L0B (short-circuit kept from the original).
    const bool dstOnL0B = GetPhyType(static_cast<TPosition>(dst.GetPosition())) == Hardware::L0B;
    const bool scopeMatch = dstOnL0B && GetPhyType(static_cast<TPosition>(src.GetPosition())) == Hardware::L1;

    // The scope assertion fires before the dtype assertion.
    ASSERT(scopeMatch && "LoadDataWithTranspose without B1->B2 is not supported on current device");
    ASSERT(dtypeMatch && "LoadDataWithTranspose is not supported on current device");
    return scopeMatch && dtypeMatch;
#endif
}

template <typename T, typename U, typename S, typename V>
bool CheckMmadParams(
    const LocalTensor<T>& dst, const LocalTensor<U>& fm, const LocalTensor<S>& filter, const LocalTensor<V>& bias,
    const MmadParams& mmadParams, const char* intriName)
{
#if defined(__NPU_ARCH__) && ((__NPU_ARCH__ == 3003) || (__NPU_ARCH__ == 3113))
    return true;
#else
    check::MmadApiParams chkParams{
        static_cast<uint64_t>(reinterpret_cast<uintptr_t>(dst.GetPhyAddr())),
        static_cast<uint64_t>(reinterpret_cast<uintptr_t>(fm.GetPhyAddr())),
        static_cast<uint64_t>(reinterpret_cast<uintptr_t>(filter.GetPhyAddr())),
        static_cast<uint64_t>(reinterpret_cast<uintptr_t>(bias.GetPhyAddr())),
        static_cast<uint32_t>(sizeof(PrimT<T>)),
        static_cast<uint32_t>(sizeof(PrimT<U>)),
        static_cast<uint32_t>(sizeof(PrimT<S>)),
        static_cast<uint32_t>(sizeof(PrimT<V>)),
        static_cast<uint64_t>(dst.GetSize() * sizeof(PrimT<T>)),
        static_cast<uint64_t>(fm.GetSize() * sizeof(PrimT<U>)),
        static_cast<uint64_t>(filter.GetSize() * sizeof(PrimT<S>)),
        static_cast<uint64_t>(bias.GetSize() * sizeof(PrimT<V>)),
        static_cast<uint8_t>(dst.GetPosition()),
        static_cast<uint8_t>(fm.GetPosition()),
        static_cast<uint8_t>(filter.GetPosition()),
        static_cast<uint8_t>(bias.GetPosition()),
        mmadParams.m,
        mmadParams.n,
        mmadParams.k,
        mmadParams.isBias,
        mmadParams.fmOffset,
        mmadParams.enSsparse,
        mmadParams.enWinogradA,
        mmadParams.enWinogradB};
    return CheckFuncMmadImpl(chkParams, intriName);
#endif
}
template <typename T, typename U, typename S>
bool CheckMmadParams(
    const LocalTensor<T>& dst, const LocalTensor<U>& fm, const LocalTensor<S>& filter, const MmadParams& mmadParams,
    const char* intriName)
{
#if defined(__NPU_ARCH__) && ((__NPU_ARCH__ == 3003) || (__NPU_ARCH__ == 3113))
    return true;
#else
    check::MmadApiParams chkParams{
        static_cast<uint64_t>(reinterpret_cast<uintptr_t>(dst.GetPhyAddr())),
        static_cast<uint64_t>(reinterpret_cast<uintptr_t>(fm.GetPhyAddr())),
        static_cast<uint64_t>(reinterpret_cast<uintptr_t>(filter.GetPhyAddr())),
        static_cast<uint32_t>(sizeof(PrimT<T>)),
        static_cast<uint32_t>(sizeof(PrimT<U>)),
        static_cast<uint32_t>(sizeof(PrimT<S>)),
        static_cast<uint64_t>(dst.GetSize() * sizeof(PrimT<T>)),
        static_cast<uint64_t>(fm.GetSize() * sizeof(PrimT<U>)),
        static_cast<uint64_t>(filter.GetSize() * sizeof(PrimT<S>)),
        static_cast<uint8_t>(dst.GetPosition()),
        static_cast<uint8_t>(fm.GetPosition()),
        static_cast<uint8_t>(filter.GetPosition()),
        mmadParams.m,
        mmadParams.n,
        mmadParams.k,
        mmadParams.isBias,
        mmadParams.fmOffset,
        mmadParams.enSsparse,
        mmadParams.enWinogradA,
        mmadParams.enWinogradB};
    return CheckFuncMmadImpl(chkParams, intriName);
#endif
}

#if defined(__NPU_ARCH__) && ((__NPU_ARCH__ == 3510) || (__NPU_ARCH__ == 5102))
template <typename T, typename U, typename S, typename V>
bool CheckMmadParams(
    const LocalTensor<T>& dst, const LocalTensor<U>& fm, const LocalTensor<S>& filter, const LocalTensor<V>& bias,
    const uint64_t& mmadParams, const char* intriName)
{
    return true;
}

template <typename T, typename U, typename S>
bool CheckMmadParams(
    const LocalTensor<T>& dst, const LocalTensor<U>& fm, const LocalTensor<S>& filter, const uint64_t& mmadParams,
    const char* intriName)
{
    return true;
}
#endif

template <typename T, typename U>
bool CheckFuncBroadCastToMM(
    const LocalTensor<T>& dst, const LocalTensor<U>& src, const int32_t blockCount, const uint8_t blockLen,
    const uint8_t srcGap, const uint8_t dstGap, const char* intriName)
{
    check::VecBroadCastToMMApiParams chkParams{
        static_cast<uint64_t>(reinterpret_cast<uintptr_t>(dst.GetPhyAddr())),
        static_cast<uint64_t>(reinterpret_cast<uintptr_t>(src.GetPhyAddr())),
        static_cast<uint32_t>(sizeof(PrimT<T>)),
        static_cast<uint32_t>(sizeof(PrimT<U>)),
        static_cast<uint64_t>(dst.GetSize() * sizeof(PrimT<T>)),
        static_cast<uint64_t>(src.GetSize() * sizeof(PrimT<U>)),
        static_cast<uint8_t>(dst.GetPosition()),
        static_cast<uint8_t>(src.GetPosition()),
        static_cast<uint32_t>(blockCount),
        static_cast<uint8_t>(blockLen),
        static_cast<uint8_t>(srcGap),
        static_cast<uint8_t>(dstGap)};
    return CheckFuncBroadCastToMMImpl(chkParams, intriName);
}

/**
 * \brief Packs the operands of a reduce-other intrinsic into a check::VecReduceApiParams.
 *
 * dst has element type U and src has element type T. The three strides are narrowed to uint16_t, and
 * repeatTime is forwarded unchanged.
 */
template <typename T, typename U = T>
check::VecReduceApiParams BuildVecReduceOtherParams(
    const LocalTensor<U>& dst, const LocalTensor<T>& src, const int32_t repeatTime, const int32_t dstRepStride,
    const int32_t srcBlkStride, const int32_t srcRepStride)
{
    const uint64_t dstAddr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(dst.GetPhyAddr()));
    const uint64_t srcAddr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(src.GetPhyAddr()));
    constexpr uint32_t dstElemBytes = sizeof(PrimT<U>);
    constexpr uint32_t srcElemBytes = sizeof(PrimT<T>);

    const uint16_t dstRep = static_cast<uint16_t>(dstRepStride);
    const uint16_t srcBlk = static_cast<uint16_t>(srcBlkStride);
    const uint16_t srcRep = static_cast<uint16_t>(srcRepStride);

    const uint64_t dstTotalBytes = static_cast<uint64_t>(dst.GetSize() * sizeof(PrimT<U>));
    const uint64_t srcTotalBytes = static_cast<uint64_t>(src.GetSize() * sizeof(PrimT<T>));
    const uint8_t dstPos = static_cast<uint8_t>(dst.GetPosition());
    const uint8_t srcPos = static_cast<uint8_t>(src.GetPosition());

    return check::VecReduceApiParams{dstAddr, srcAddr, dstElemBytes, srcElemBytes, repeatTime, dstRep, srcBlk,
        srcRep, dstTotalBytes, srcTotalBytes, dstPos, srcPos};
}

template <typename T, typename U = T>
bool CheckFunVecReduceOther(
    const LocalTensor<U>& dst, const LocalTensor<T>& src, const int32_t repeatTime, const int32_t maskCount,
    const int32_t dstRepStride, const int32_t srcBlkStride, const int32_t srcRepStride, const char* intriName)
{
    auto chkParams = BuildVecReduceOtherParams(dst, src, repeatTime, dstRepStride, srcBlkStride, srcRepStride);
    return CheckFunReduceOtherImpl(chkParams, maskCount, intriName);
}

template <typename T, typename U = T>
bool CheckFunVecReduceOther(
    const LocalTensor<U>& dst, const LocalTensor<T>& src, const int32_t repeatTime, const uint64_t mask[],
    const int32_t dstRepStride, const int32_t srcBlkStride, const int32_t srcRepStride, const char* intriName)
{
    auto chkParams = BuildVecReduceOtherParams(dst, src, repeatTime, dstRepStride, srcBlkStride, srcRepStride);
    return CheckFunReduceOtherImplForMaskArray(chkParams, mask, intriName);
}

/**
 * \brief Packs the operands of a whole-reduce intrinsic into a check::VecReduceWhlApiParams.
 *
 * dst and src share element type T. The strides are narrowed to uint16_t; repeatTime and order are
 * forwarded unchanged.
 */
template <typename T>
check::VecReduceWhlApiParams BuildVecReduceOtherWhlParams(
    const LocalTensor<T>& dst, const LocalTensor<T>& src, const int32_t repeatTime, const int32_t dstRepStride,
    const int32_t srcBlkStride, const int32_t srcRepStride, ReduceOrder order)
{
    const uint64_t dstAddr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(dst.GetPhyAddr()));
    const uint64_t srcAddr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(src.GetPhyAddr()));
    // dst and src have the same element type, so one width serves both.
    constexpr uint32_t elemBytes = sizeof(PrimT<T>);

    const uint16_t dstRep = static_cast<uint16_t>(dstRepStride);
    const uint16_t srcBlk = static_cast<uint16_t>(srcBlkStride);
    const uint16_t srcRep = static_cast<uint16_t>(srcRepStride);

    const uint64_t dstTotalBytes = static_cast<uint64_t>(dst.GetSize() * sizeof(PrimT<T>));
    const uint64_t srcTotalBytes = static_cast<uint64_t>(src.GetSize() * sizeof(PrimT<T>));
    const uint8_t dstPos = static_cast<uint8_t>(dst.GetPosition());
    const uint8_t srcPos = static_cast<uint8_t>(src.GetPosition());

    return check::VecReduceWhlApiParams{dstAddr, srcAddr, elemBytes, elemBytes, repeatTime, dstRep, srcBlk,
        srcRep, order, dstTotalBytes, srcTotalBytes, dstPos, srcPos};
}

template <typename T>
bool CheckFunVecReduceOtherWhl(
    const LocalTensor<T>& dst, const LocalTensor<T>& src, const int32_t repeatTime, const int32_t maskCount,
    const int32_t dstRepStride, const int32_t srcBlkStride, const int32_t srcRepStride, ReduceOrder order,
    const char* intriName)
{
    auto chkParams =
        BuildVecReduceOtherWhlParams(dst, src, repeatTime, dstRepStride, srcBlkStride, srcRepStride, order);
    return CheckFunReduceOtherWhlImpl(chkParams, maskCount, intriName);
}

template <typename T>
bool CheckFunVecReduceOtherWhl(
    const LocalTensor<T>& dst, const LocalTensor<T>& src, const int32_t repeatTime, const uint64_t mask[],
    const int32_t dstRepStride, const int32_t srcBlkStride, const int32_t srcRepStride, ReduceOrder order,
    const char* intriName)
{
    auto chkParams =
        BuildVecReduceOtherWhlParams(dst, src, repeatTime, dstRepStride, srcBlkStride, srcRepStride, order);
    return CheckFunReduceOtherWhlImplForMaskArray(chkParams, mask, intriName);
}

/**
 * \brief Packs dst/src/work operands, repeatTime, calIndex and srcRepStride into a check::VecReduceApiParams.
 *
 * All three tensors share element type T. srcRepStride is narrowed to uint16_t.
 */
template <typename T>
check::VecReduceApiParams BuildVecReduceWithCalIndexParams(
    const LocalTensor<T>& dst, const LocalTensor<T>& src, const LocalTensor<T>& work, const int32_t repeatTime,
    bool calIndex, const int32_t srcRepStride)
{
    const uint64_t dstAddr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(dst.GetPhyAddr()));
    const uint64_t srcAddr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(src.GetPhyAddr()));
    const uint64_t workAddr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(work.GetPhyAddr()));
    // dst, src and work share one element type.
    constexpr uint32_t elemBytes = sizeof(PrimT<T>);

    const uint64_t dstTotalBytes = static_cast<uint64_t>(dst.GetSize() * sizeof(PrimT<T>));
    const uint64_t srcTotalBytes = static_cast<uint64_t>(src.GetSize() * sizeof(PrimT<T>));
    const uint64_t workTotalBytes = static_cast<uint64_t>(work.GetSize() * sizeof(PrimT<T>));

    const uint8_t dstPos = static_cast<uint8_t>(dst.GetPosition());
    const uint8_t srcPos = static_cast<uint8_t>(src.GetPosition());
    const uint8_t workPos = static_cast<uint8_t>(work.GetPosition());
    const uint16_t srcRep = static_cast<uint16_t>(srcRepStride);

    return check::VecReduceApiParams{dstAddr, srcAddr, workAddr, elemBytes, elemBytes, elemBytes, repeatTime,
        calIndex, dstTotalBytes, srcTotalBytes, workTotalBytes, dstPos, srcPos, workPos, srcRep};
}

template <typename T>
bool CheckFunVecReduce(
    const LocalTensor<T>& dst, const LocalTensor<T>& src, const LocalTensor<T>& work, const int32_t repeatTime,
    const int32_t mask, bool calIndex, const int32_t srcRepStride, const char* intriName)
{
    auto chkParams = BuildVecReduceWithCalIndexParams(dst, src, work, repeatTime, calIndex, srcRepStride);
    return CheckFunReduceImpl(chkParams, mask, intriName);
}

template <typename T>
bool CheckFunVecReduce(
    const LocalTensor<T>& dst, const LocalTensor<T>& src, const LocalTensor<T>& work, const int32_t repeatTime,
    const uint64_t mask[], bool calIndex, const int32_t srcRepStride, const char* intriName)
{
    auto chkParams = BuildVecReduceWithCalIndexParams(dst, src, work, repeatTime, calIndex, srcRepStride);
    return CheckFunReduceImplForMaskArray(chkParams, mask, intriName);
}

/**
 * \brief Packs dst/src/work operands, repeatTime and srcRepStride into a check::VecReduceApiParams (no calIndex).
 *
 * All three tensors share element type T. srcRepStride is narrowed to uint16_t.
 */
template <typename T>
check::VecReduceApiParams BuildVecReduceWithoutCalIndexParams(
    const LocalTensor<T>& dst, const LocalTensor<T>& src, const LocalTensor<T>& work, const int32_t repeatTime,
    const int32_t srcRepStride)
{
    const uint64_t dstAddr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(dst.GetPhyAddr()));
    const uint64_t srcAddr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(src.GetPhyAddr()));
    const uint64_t workAddr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(work.GetPhyAddr()));
    // dst, src and work share one element type.
    constexpr uint32_t elemBytes = sizeof(PrimT<T>);

    const uint64_t dstTotalBytes = static_cast<uint64_t>(dst.GetSize() * sizeof(PrimT<T>));
    const uint64_t srcTotalBytes = static_cast<uint64_t>(src.GetSize() * sizeof(PrimT<T>));
    const uint64_t workTotalBytes = static_cast<uint64_t>(work.GetSize() * sizeof(PrimT<T>));

    const uint8_t dstPos = static_cast<uint8_t>(dst.GetPosition());
    const uint8_t srcPos = static_cast<uint8_t>(src.GetPosition());
    const uint8_t workPos = static_cast<uint8_t>(work.GetPosition());
    const uint16_t srcRep = static_cast<uint16_t>(srcRepStride);

    return check::VecReduceApiParams{dstAddr, srcAddr, workAddr, elemBytes, elemBytes, elemBytes, repeatTime,
        dstTotalBytes, srcTotalBytes, workTotalBytes, dstPos, srcPos, workPos, srcRep};
}

template <typename T>
bool CheckFunVecReduce(
    const LocalTensor<T>& dst, const LocalTensor<T>& src, const LocalTensor<T>& work, const int32_t repeatTime,
    const int32_t mask, const int32_t srcRepStride, const char* intriName)
{
    auto chkParams = BuildVecReduceWithoutCalIndexParams(dst, src, work, repeatTime, srcRepStride);
    return CheckFunReduceImpl(chkParams, mask, intriName);
}

template <typename T>
bool CheckFunVecReduce(
    const LocalTensor<T>& dst, const LocalTensor<T>& src, const LocalTensor<T>& work, const int32_t repeatTime,
    const uint64_t mask[], const int32_t srcRepStride, const char* intriName)
{
    auto chkParams = BuildVecReduceWithoutCalIndexParams(dst, src, work, repeatTime, srcRepStride);
    return CheckFunReduceImplForMaskArray(chkParams, mask, intriName);
}

template <typename T>
bool CheckFunVecReduce(
    const LocalTensor<T>& dst, const LocalTensor<T>& src, const LocalTensor<T>& work, int32_t repeatTime,
    const int32_t count, bool calIndex, const char* intriName)
{
    // max or min level2
    check::VecReduceApiParams chkParams{
        static_cast<uint64_t>(reinterpret_cast<uintptr_t>(dst.GetPhyAddr())),
        static_cast<uint64_t>(reinterpret_cast<uintptr_t>(src.GetPhyAddr())),
        static_cast<uint64_t>(reinterpret_cast<uintptr_t>(work.GetPhyAddr())),
        static_cast<uint32_t>(sizeof(PrimT<T>)),
        static_cast<uint32_t>(sizeof(PrimT<T>)),
        static_cast<uint32_t>(sizeof(PrimT<T>)),
        repeatTime,
        static_cast<uint32_t>(count),
        calIndex,
        static_cast<uint64_t>(dst.GetSize() * sizeof(PrimT<T>)),
        static_cast<uint64_t>(src.GetSize() * sizeof(PrimT<T>)),
        static_cast<uint64_t>(work.GetSize() * sizeof(PrimT<T>)),
        static_cast<uint8_t>(dst.GetPosition()),
        static_cast<uint8_t>(src.GetPosition()),
        static_cast<uint8_t>(work.GetPosition())};
    return CheckFunReduceImpl(chkParams, intriName);
}

template <typename T>
bool CheckFunVecReduce(
    const LocalTensor<T>& dst, const LocalTensor<T>& src, const LocalTensor<T>& work, const int32_t count,
    int32_t repeatTime, const char* intriName)
{
    // sum level 2
    check::VecReduceApiParams chkParams{
        static_cast<uint64_t>(reinterpret_cast<uintptr_t>(dst.GetPhyAddr())),
        static_cast<uint64_t>(reinterpret_cast<uintptr_t>(src.GetPhyAddr())),
        static_cast<uint64_t>(reinterpret_cast<uintptr_t>(work.GetPhyAddr())),
        static_cast<uint32_t>(sizeof(PrimT<T>)),
        static_cast<uint32_t>(sizeof(PrimT<T>)),
        static_cast<uint32_t>(sizeof(PrimT<T>)),
        repeatTime,
        static_cast<uint32_t>(count),
        static_cast<uint64_t>(dst.GetSize() * sizeof(PrimT<T>)),
        static_cast<uint64_t>(src.GetSize() * sizeof(PrimT<T>)),
        static_cast<uint64_t>(work.GetSize() * sizeof(PrimT<T>)),
        static_cast<uint8_t>(dst.GetPosition()),
        static_cast<uint8_t>(src.GetPosition()),
        static_cast<uint8_t>(work.GetPosition())};
    return CheckFunReduceImpl(chkParams, intriName);
}

template <typename T>
bool CheckFunVecReduceMode2(
    const LocalTensor<T>& dst, const LocalTensor<T>& src, const int32_t count, const char* intriName)
{
    check::VecReduceApiParams chkParams{
        static_cast<uint64_t>(reinterpret_cast<uintptr_t>(dst.GetPhyAddr())),
        static_cast<uint64_t>(reinterpret_cast<uintptr_t>(src.GetPhyAddr())),
        static_cast<uint32_t>(sizeof(PrimT<T>)),
        static_cast<uint32_t>(sizeof(PrimT<T>)),
        static_cast<uint32_t>(count),
        static_cast<uint64_t>(dst.GetSize() * sizeof(PrimT<T>)),
        static_cast<uint64_t>(src.GetSize() * sizeof(PrimT<T>)),
        static_cast<uint8_t>(dst.GetPosition()),
        static_cast<uint8_t>(src.GetPosition())};
    return CheckFunReduceImplMode2(chkParams, intriName);
}

/**
 * \brief Packs the operands of a scatter intrinsic into a check::VecScatterApiParams.
 *
 * dst and src have element type T; dstOffset has element type U. dstBaseAddr, repeatTime and
 * srcRepStride are forwarded unchanged.
 */
template <typename T, typename U>
check::VecScatterApiParams BuildVecScatterParams(
    const LocalTensor<T>& dst, const LocalTensor<T>& src, const LocalTensor<U>& dstOffset, const uint32_t dstBaseAddr,
    const uint8_t repeatTime, const uint16_t srcRepStride)
{
    const uint64_t dstAddr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(dst.GetPhyAddr()));
    const uint64_t srcAddr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(src.GetPhyAddr()));
    const uint64_t offsetAddr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(dstOffset.GetPhyAddr()));

    constexpr uint32_t dataElemBytes = sizeof(PrimT<T>);
    constexpr uint32_t offsetElemBytes = sizeof(PrimT<U>);

    const uint64_t dstTotalBytes = static_cast<uint64_t>(dst.GetSize() * sizeof(PrimT<T>));
    const uint64_t srcTotalBytes = static_cast<uint64_t>(src.GetSize() * sizeof(PrimT<T>));
    const uint64_t offsetTotalBytes = static_cast<uint64_t>(dstOffset.GetSize() * sizeof(PrimT<U>));

    const uint8_t dstPos = static_cast<uint8_t>(dst.GetPosition());
    const uint8_t srcPos = static_cast<uint8_t>(src.GetPosition());
    const uint8_t offsetPos = static_cast<uint8_t>(dstOffset.GetPosition());

    return check::VecScatterApiParams{dstAddr, srcAddr, offsetAddr, dataElemBytes, dataElemBytes, offsetElemBytes,
        dstBaseAddr, repeatTime, srcRepStride, dstTotalBytes, srcTotalBytes, offsetTotalBytes, dstPos, srcPos,
        offsetPos};
}

template <typename T, typename U>
bool CheckFunScatter(
    const LocalTensor<T>& dst, const LocalTensor<T>& src, const LocalTensor<U>& dstOffset, const uint32_t dstBaseAddr,
    const uint64_t mask[], const uint8_t repeatTime, const uint16_t srcRepStride, const char* intriName)
{
    auto chkParams = BuildVecScatterParams(dst, src, dstOffset, dstBaseAddr, repeatTime, srcRepStride);
    return CheckFunScatterImplForMaskArray(chkParams, mask, intriName);
}

template <typename T, typename U>
bool CheckFunScatter(
    const LocalTensor<T>& dst, const LocalTensor<T>& src, const LocalTensor<U>& dstOffset, const uint32_t dstBaseAddr,
    const uint64_t mask, const uint8_t repeatTime, const uint16_t srcRepStride, const char* intriName)
{
    auto chkParams = BuildVecScatterParams(dst, src, dstOffset, dstBaseAddr, repeatTime, srcRepStride);
    return CheckFunScatterImpl(chkParams, mask, intriName);
}

template <typename T, typename U>
bool CheckFunScatter(
    const LocalTensor<T>& dst, const LocalTensor<T>& src, const LocalTensor<U>& dstOffset, const uint32_t dstBaseAddr,
    const uint32_t count, const char* intriName)
{
    check::VecScatterApiParams chkParams{
        static_cast<uint64_t>(reinterpret_cast<uintptr_t>(dst.GetPhyAddr())),
        static_cast<uint64_t>(reinterpret_cast<uintptr_t>(src.GetPhyAddr())),
        static_cast<uint64_t>(reinterpret_cast<uintptr_t>(dstOffset.GetPhyAddr())),
        static_cast<uint32_t>(sizeof(PrimT<T>)),
        static_cast<uint32_t>(sizeof(PrimT<T>)),
        static_cast<uint32_t>(sizeof(PrimT<U>)),
        dstBaseAddr,
        static_cast<uint32_t>(count),
        static_cast<uint64_t>(dst.GetSize() * sizeof(PrimT<T>)),
        static_cast<uint64_t>(src.GetSize() * sizeof(PrimT<T>)),
        static_cast<uint64_t>(dstOffset.GetSize() * sizeof(PrimT<U>)),
        static_cast<uint8_t>(dst.GetPosition()),
        static_cast<uint8_t>(src.GetPosition()),
        static_cast<uint8_t>(dstOffset.GetPosition())};
    return CheckFunScatterImpl(chkParams, intriName);
}
} // namespace AscendC
#endif

#endif
#if defined(__UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_KERNEL_CHECK_VEC_H__)
#undef __ASCENDC_INCLUDE_INTERNAL_HEADERS__
#undef __UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_KERNEL_CHECK_VEC_H__
#endif