/**

* Copyright (c) 2025 Huawei Technologies Co., Ltd.

* This program is free software, you can redistribute it and/or modify it under the terms and conditions of

* CANN Open Software License Agreement Version 2.0 (the "License").

* Please refer to the License for details. You may not use this file except in compliance with the License.

* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,

* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.

* See LICENSE in the root of the software repository for the full text of the License.

*/



/*!

 * \file kernel_operator_mm_base_impl.h

 * \brief

 */

#ifndef ASCENDC_MODULE_OPERATOR_MM_BASE_IMPL_H

#define ASCENDC_MODULE_OPERATOR_MM_BASE_IMPL_H

#include "kernel_tensor.h"



#if __NPU_ARCH__ == 1001

#include "dav_c100/kernel_operator_mm_impl.h"

#elif __NPU_ARCH__ == 2002

#include "dav_m200/kernel_operator_mm_impl.h"

#elif __NPU_ARCH__ == 2201

#include "dav_c220/kernel_operator_mm_impl.h"

#elif __NPU_ARCH__ == 3002

#include "dav_m300/kernel_operator_mm_impl.h"

#elif __NPU_ARCH__ == 3102

#include "dav_m310/kernel_operator_mm_impl.h"

#elif __NPU_ARCH__ == 3101

#include "dav_c310/kernel_operator_mm_impl.h"

#elif (__NPU_ARCH__ == 5102)

#include "dav_m510/kernel_operator_mm_impl.h"

#elif (__NPU_ARCH__ == 3003)

#include "dav_l300/kernel_operator_mm_impl.h"

#elif (__NPU_ARCH__ == 3113)

#include "dav_l311/kernel_operator_mm_impl.h"

#endif

#include "kernel_operator_mm_check.h"

#include "kernel_operator_mm_load2d_impl.h"

#include "kernel_struct_mm.h"

namespace AscendC {

/*
 * Compile-time switches consumed by the Load3D LoadDataImpl overloads in this file.
 * Each flag gates one of the register-programming calls made before the load itself.
 */
struct IsResetLoad3dConfig {
    // When true, LoadDataImpl calls Load3DSetFMatrixCal(l1H, l1W, padList) first.
    bool isSetFMatrix = true;
    // When true, LoadDataImpl calls Load3DSetPaddingCal(padValue) first.
    bool isSetPadding = true;

    __aicore__ constexpr IsResetLoad3dConfig(const bool isSetFMatrixIn, const bool isSetPaddingIn)
        : isSetFMatrix(isSetFMatrixIn), isSetPadding(isSetPaddingIn)
    {}
};



// Default Load3D configuration: both isSetFMatrix and isSetPadding are enabled.
// Used as the default template argument of the Load3D LoadDataImpl overloads below.
// NOTE(review): "RESER" looks like a misspelling of "RESET", but the identifier is referenced by name
// and must not be renamed here.
constexpr IsResetLoad3dConfig IS_RESER_LOAD3D_DEFAULT_CONFIG = {true, true};



/* **************************************************************************************************

 * LoadData 3dv1                                             *

 * ************************************************************************************************* */

/*

 * @ingroup DataLoad

 * @brief Cube data loading

 * @param [out] dst output LocalTensor

 * @param [in] src input LocalTensor

 * @param [in] loadDataParams.padList padding list

 * @param [in] loadDataParams.l1H operand height

 * @param [in] loadDataParams.l1W operand width

 * @param [in] loadDataParams.c1Index The starting point of the tensor C1 dimension

 * @param [in] loadDataParams.fetchFilterW The starting position of the w dimension on the convolution kernel

 * @param [in] loadDataParams.fetchFilterH The starting position of the H dimension on the convolution kernel

 * @param [in] loadDataParams.leftTopW Start point of the W dimension on the source operand

 * @param [in] loadDataParams.leftTopH Start point of the H dimension on the source operand

 * @param [in] loadDataParams.strideW W dimension stride

 * @param [in] loadDataParams.strideH H dimension stride

 * @param [in] loadDataParams.filterW Convolution kernel width

 * @param [in] loadDataParams.filterH Convolution kernel height

 * @param [in] loadDataParams.dilationFilterW Convolution kernel width expansion coefficient

 * @param [in] loadDataParams.dilationFilterH Convolution kernel height expansion coefficient

 * @param [in] loadDataParams.jumpStride repeat stride

 * @param [in] loadDataParams.repeatMode repeat mode

 * @param [in] loadDataParams.repeatTime repeat times

 * @param [in] loadDataParams.cSize judge whether to turn on optimization

 * @param [in] loadDataParams.padValue Value of Pad filling value

 */



/*
 * Load3Dv1 dispatcher.
 * Optionally programs the FMatrix / padding registers (controlled by defaultConfig), validates the
 * source tensor, then forwards to the L1->L0A, L1->L0B or L1->UB implementation chosen by the
 * physical position of dst. Any other dst position is reported through ASCENDC_CHECK_TPOSITION.
 */
template <typename T, const IsResetLoad3dConfig &defaultConfig = IS_RESER_LOAD3D_DEFAULT_CONFIG,
    typename U = PrimT<T>, typename std::enable_if<IsSameType<PrimT<T>, U>::value, bool>::type = true>
__aicore__ inline void LoadDataImpl(const LocalTensor<T>& dst, const LocalTensor<T>& src,
    const LoadData3DParamsV1<U>& loadDataParams)
{
    // Element type actually stored in the tensors (T may be a wrapper around it).
    using ElemT = PrimT<T>;

#if ASCENDC_CPU_DEBUG
    // Parameter validation only exists in the CPU debug build.
    if (!CheckFuncLoadData3dv1(dst, src, loadDataParams, "LoadData with LoadData3DParamsV1")) {
        ASCENDC_REPORT_CHECK_ERROR("LoadData with LoadData3DParamsV1", KernelFuncType::NONE_MODE);
    }
#endif
    ASCENDC_ASSERT((SupportType<PrimT<T>, uint8_t, int8_t, half>()),
        {KERNEL_LOG(KERNEL_ERROR, "Failed to check dtype in "
        "LoadData with LoadData3DParamsV1, current api support dtype combination is src and dst both: uint8_t / int8_t "
        "/ half.");});

    // Register setup is compiled in or out according to the template configuration.
    if constexpr (defaultConfig.isSetFMatrix) {
        Load3DSetFMatrixCal(loadDataParams.l1H, loadDataParams.l1W, loadDataParams.padList);
    }
    if constexpr (defaultConfig.isSetPadding) {
        Load3DSetPaddingCal(loadDataParams.padValue);
    }

    // The source must live in L1 and be aligned to one block.
    CheckTensorPos<T>(src, Hardware::L1, "src", "A1 / B1", "LoadData with LoadData3DParamsV1");
    CheckTensorAlign<T>(src, ONE_BLK_SIZE, "src", "LoadData with LoadData3DParamsV1");

    const Hardware dstHw = GetPhyType(static_cast<TPosition>(dst.GetPosition()));
    if (dstHw == Hardware::L0A) {
        CheckTensorAlign<T>(dst, VALUE_512, "dst", "LoadData with LoadData3DParamsV1"); // 512B align
        LoadData3DV1L12L0ACal((__ca__ ElemT*)dst.GetPhyAddr(),
                              (__cbuf__ ElemT*)src.GetPhyAddr(), loadDataParams);
        return;
    }
    if (dstHw == Hardware::L0B) {
        CheckTensorAlign<T>(dst, VALUE_512, "dst", "LoadData with LoadData3DParamsV1"); // 512B align
        LoadData3DV1L12L0BCal((__cb__ ElemT*)dst.GetPhyAddr(),
                              (__cbuf__ ElemT*)src.GetPhyAddr(), loadDataParams);
        return;
    }
    if (dstHw == Hardware::UB) {
        CheckTensorAlign<T>(dst, ONE_BLK_SIZE, "dst", "LoadData with LoadData3DParamsV1");
        LoadData3DV1L12UBCal((__ubuf__ ElemT*)dst.GetPhyAddr(),
                             (__cbuf__ ElemT*)src.GetPhyAddr(), loadDataParams);
        return;
    }
    // Unsupported destination position.
    ASCENDC_CHECK_TPOSITION((false), "dst", "A2 / B2 / UB", "LoadData with LoadData3DParamsV1",
        ConstDefiner::Instance().logicNameMap.at(static_cast<uint8_t>(dst.GetPosition())));
}



/* **************************************************************************************************

 * LoadData 3dv2                                             *

 * enhanced from v1, suitable for aicore > 200                                             *

 * ************************************************************************************************* */

/*

 * @ingroup DataLoad

 * @brief Cube data loading

 * @param [out] dst output LocalTensor

 * @param [in] src input LocalTensor

 * @param [in] loadDataParams.padList padding list

 * @param [in] loadDataParams.l1H operand height

 * @param [in] loadDataParams.l1W operand width

 * @param [in] loadDataParams.channelSize number of channels

 * @param [in] loadDataParams.kExtension Transmission length of K dimension

 * @param [in] loadDataParams.mExtension Transmission length of M dimension

 * @param [in] loadDataParams.kStartPt Start point of K dimension

 * @param [in] loadDataParams.mStartPt Start point of M dimension

 * @param [in] loadDataParams.strideW W dimension stride

 * @param [in] loadDataParams.strideH H dimension stride

 * @param [in] loadDataParams.filterW Convolution kernel width

 * @param [in] loadDataParams.filterH Convolution kernel height

 * @param [in] loadDataParams.dilationFilterW Convolution kernel width expansion coefficient

 * @param [in] loadDataParams.dilationFilterH Convolution kernel height expansion coefficient

 * @param [in] loadDataParams.enTranspose judge whether to enable the transpose function

 * @param [in] loadDataParams.enSmallK Whether to enable the small k feature

 * @param [in] loadDataParams.padValue Value of Pad filling value

 */

/*
 * Load3Dv2 dispatcher.
 * Validates the parameters, optionally programs the FMatrix / padding registers (controlled by
 * defaultConfig), runs the architecture-specific dtype / alignment assertions, then forwards to the
 * L1->L0A, L1->L0B or L1->UB implementation chosen by the physical position of dst.
 *
 * The byte-alignment assertions for __NPU_ARCH__ == 3101 use sizeof(PrimT<T>) so that the element
 * size is correct when T is a wrapper type whose primitive type differs from T itself.
 */
template <typename T, const IsResetLoad3dConfig &defaultConfig = IS_RESER_LOAD3D_DEFAULT_CONFIG,
    typename U = PrimT<T>, typename std::enable_if<IsSameType<PrimT<T>, U>::value, bool>::type = true>
__aicore__ inline void LoadDataImpl(const LocalTensor<T>& dst, const LocalTensor<T>& src,
    const LoadData3DParamsV2<U>& loadDataParams)
{
    ASCENDC_ASSERT(CheckFuncLoadData3dv2(dst, src, loadDataParams, "LoadData with LoadData3DParamsV2"), {
        ASCENDC_REPORT_CHECK_ERROR("LoadData with LoadData3DParamsV2", KernelFuncType::NONE_MODE);
    });
    // Register setup is compiled in or out according to the template configuration.
    if constexpr (defaultConfig.isSetFMatrix) {
        Load3DSetFMatrixCal(loadDataParams.l1H, loadDataParams.l1W, loadDataParams.padList);
    }
    if constexpr (defaultConfig.isSetPadding) {
        Load3DSetPaddingCal(loadDataParams.padValue);
    }

    const Hardware dstScope = GetPhyType((TPosition)dst.GetPosition());
    // Architecture-specific constraints; architectures not listed here have no extra assertion.
#if __NPU_ARCH__ == 2002
    ASCENDC_ASSERT((SupportType<PrimT<T>, uint8_t, int8_t, half, int4b_t>()),
        {KERNEL_LOG(KERNEL_ERROR, "Failed to check dtype in "
        "LoadData with LoadData3DParamsV2, current api support dtype combination is src and dst both: uint8_t / int8_t "
        "/ half / int4b_t.");});
#elif __NPU_ARCH__ == 2201
    if (dstScope == Hardware::L0A) {
        ASCENDC_ASSERT((SupportType<PrimT<T>, uint8_t, int8_t, half, bfloat16_t, float, uint32_t, int32_t, int4b_t>()),
            {KERNEL_LOG(KERNEL_ERROR, "Failed to check dtype in LoadData with LoadData3DParamsV2 when dst position is "
            "A2, current api support dtype combination is src and dst both: uint8_t / int8_t / half / bfloat16_t / "
            "float / uint32_t / int32_t / int4b_t.");});
    } else if (dstScope == Hardware::L0B) {
        ASCENDC_ASSERT((SupportType<PrimT<T>, half, bfloat16_t, float, uint32_t, int32_t>()), {KERNEL_LOG(KERNEL_ERROR,
            "Failed to check dtype in LoadData with LoadData3DParamsV2 when dst position is B2, current api support "
            "dtype combination is src and dst both: half / bfloat16_t / float / uint32_t / int32_t.");});
    }
#elif __NPU_ARCH__ == 3101
    // Element size must come from the primitive type, not from T, which may be a wrapper type.
    ASCENDC_ASSERT(loadDataParams.kExtension * sizeof(PrimT<T>) % ONE_BLK_SIZE == 0, {
        KERNEL_LOG(KERNEL_ERROR, "kExtension * sizeof(T) must be a multiple of 32");});
    ASCENDC_ASSERT(loadDataParams.mExtension % 16 == 0, {
        KERNEL_LOG(KERNEL_ERROR, "mExtension should be a multiple of 16");});
    ASCENDC_ASSERT(loadDataParams.kStartPt * sizeof(PrimT<T>) % ONE_BLK_SIZE == 0, {
        KERNEL_LOG(KERNEL_ERROR, "kStartPt * sizeof(T) must be a multiple of 32");});
    ASCENDC_ASSERT(loadDataParams.mStartPt % 16 == 0, {
        KERNEL_LOG(KERNEL_ERROR, "mStartPt should be a multiple of 16");});
#elif __NPU_ARCH__ == 3102
    if (dstScope == Hardware::L0A) {
        ASCENDC_ASSERT((SupportType<PrimT<T>, uint8_t, int8_t, half, uint16_t, int16_t, int4b_t>()),
            {KERNEL_LOG(KERNEL_ERROR,
            "Failed to check dtype in LoadData with LoadData3DParamsV2 when dst position is A2, current api support "
            "dtype combination is src and dst both: uint8_t / int8_t / half / uint16_t / int16_t / int4b_t.");});
    } else {
        // NOTE(review): this branch is taken for every non-L0A destination, although the message names B2.
        ASCENDC_ASSERT((SupportType<PrimT<T>, half, int16_t, uint16_t>()),
        {KERNEL_LOG(KERNEL_ERROR, "Failed to check dtype "
            "in LoadData with LoadData3DParamsV2 when dst position is B2, current api support dtype combination is src "
            "and dst both: half / int16_t / uint16_t.");});
    }
#endif

    // The source must live in L1; the destination alignment depends on the target buffer.
    CheckTensorPos<T>(src, Hardware::L1, "src", "A1 / B1", "LoadData with LoadData3DParamsV2");
    if (dstScope == Hardware::L0A) {
        CheckTensorAlign<T>(dst, VALUE_512, "dst", "LoadData with LoadData3DParamsV2");
        LoadData3DV2L12L0ACal((__ca__ PrimT<T>*)dst.GetPhyAddr(),
                              (__cbuf__ PrimT<T>*)src.GetPhyAddr(), loadDataParams);
    } else if (dstScope == Hardware::L0B) {
        CheckTensorAlign<T>(dst, VALUE_512, "dst", "LoadData with LoadData3DParamsV2");
        LoadData3DV2L12L0BCal((__cb__ PrimT<T>*)dst.GetPhyAddr(),
                              (__cbuf__ PrimT<T>*)src.GetPhyAddr(), loadDataParams);
    } else if (dstScope == Hardware::UB) {
        CheckTensorAlign<T>(dst, ONE_BLK_SIZE, "dst", "LoadData with LoadData3DParamsV2");
        LoadData3DV2L12UBCal((__ubuf__ PrimT<T>*)dst.GetPhyAddr(),
                             (__cbuf__ PrimT<T>*)src.GetPhyAddr(), loadDataParams);
    } else {
        ASCENDC_CHECK_TPOSITION((false), "dst", "A2 / B2 / UB", "LoadData with LoadData3DParamsV2",
            ConstDefiner::Instance().logicNameMap.at(static_cast<uint8_t>(dst.GetPosition())));
    }
}



#if ((__NPU_ARCH__ == 2201) || (__NPU_ARCH__ == 3002) ||            \

     (__NPU_ARCH__ == 3101) || (__NPU_ARCH__ == 5102))

// cce compiler processes load3d bfloat16_t using B8, so use the half dtype instead

/*
 * Deprecated bfloat16_t entry point for Load3Dv2 with an explicit IsResetLoad3dConfig.
 * Forwards unchanged to LoadDataImpl<bfloat16_t, defaultConfig>.
 * @param [out] dst output LocalTensor
 * @param [in] src input LocalTensor
 * @param [in] loadDataParams Load3Dv2 parameters
 */
template <const IsResetLoad3dConfig& defaultConfig>

[[deprecated("NOTICE: LoadData<IsResetLoad3dConfig> has been deprecated and will be removed in the next version."

             " Please do not use it!")]]

__aicore__ inline void LoadData(const LocalTensor<bfloat16_t>& dst, const LocalTensor<bfloat16_t>& src,

    const LoadData3DParamsV2<bfloat16_t>& loadDataParams)

{

    LoadDataImpl<bfloat16_t, defaultConfig>(dst, src, loadDataParams);

}

#endif



#if defined(__NPU_ARCH__) && (__NPU_ARCH__ == 5102)

/*
 * 2D load from a GlobalTensor into a LocalTensor located in L1, with ND-to-NZ parameters.
 * Only a destination whose physical position is L1 is accepted; anything else triggers an assertion
 * and no load is issued.
 */
template <typename T, typename U>
__aicore__ inline __inout_pipe__(MTE2) void LoadDataImpl(const LocalTensor<T>& dst, const GlobalTensor<U>& src,
    const LoadData2DParamsV2& loadDataParams, const Nd2NzParamsV2& nd2nzParams)
{
    const Hardware dstHw = GetPhyType(static_cast<TPosition>(dst.GetPosition()));
    if (dstHw != Hardware::L1) {
        ASCENDC_ASSERT((false), { KERNEL_LOG(KERNEL_ERROR, "dst only support A1/B1"); });
        return;
    }
    LoadData2DGM2L1Cal((__cbuf__ T *)dst.GetPhyAddr(), (__gm__ U *)src.GetPhyAddr(), loadDataParams, nd2nzParams);
}

#endif



#if defined(__NPU_ARCH__) && ((__NPU_ARCH__ == 3101) || (__NPU_ARCH__ == 5102))

/*
 * Load3Dv2 in bit-mode parameter form, with source and destination positions given at compile time.
 * @tparam Dst destination position; A2 and B2 are dispatched, anything else is reported as an error
 * @tparam Src source position; must be A1 or B1 (the source is read through a __cbuf__ pointer)
 * @param [out] dst output LocalTensor, checked for VALUE_512 alignment
 * @param [in] src input LocalTensor, checked for ONE_BLK_SIZE alignment
 * @param [in] loadDataParams bit-mode Load3D parameters, forwarded unchanged
 */
template <TPosition Dst, TPosition Src, typename T>
__aicore__ inline void LoadDataImpl(const LocalTensor<T>& dst, const LocalTensor<T>& src,
    const Load3DBitModeParam& loadDataParams)
{
    CheckTensorAlign<T>(src, ONE_BLK_SIZE, "src", "LoadData with LoadData3DParams");
    CheckTensorAlign<T>(dst, VALUE_512, "dst", "LoadData with LoadData3DParams");

    // The accepted source positions are A1 / B1, as the message states. The previous condition tested
    // A2 instead of B1, which rejected a valid B1 source and let an A2 source through.
    if constexpr (Src != TPosition::A1 && Src != TPosition::B1) {
        ASCENDC_CHECK_TPOSITION(false, "src", "A1 / B1",
            "LoadData with LoadDataBitModeParams",
            ConstDefiner::Instance().logicNameMap.at(static_cast<uint8_t>(src.GetPosition())));
    }
    if constexpr (Dst == TPosition::A2) {
        LoadData3DV2L12L0ACal((__ca__ PrimT<T>*)dst.GetPhyAddr(),
                            (__cbuf__ PrimT<T>*)src.GetPhyAddr(), loadDataParams);
    } else if constexpr (Dst == TPosition::B2) {
        LoadData3DV2L12L0BCal((__cb__ PrimT<T>*)dst.GetPhyAddr(),
                            (__cbuf__ PrimT<T>*)src.GetPhyAddr(), loadDataParams);
    } else {
        ASCENDC_CHECK_TPOSITION(false, "dst", "A2 / B2",
            "LoadData with LoadData3DParams",
            ConstDefiner::Instance().logicNameMap.at(static_cast<uint8_t>(dst.GetPosition())));
    }
}

#endif

/* **************************************************************************************************

 * LoadData 3dv2Pro                                             *

 * enhanced from v1, suitable for aicore > 200                                             *

 * ************************************************************************************************* */

/*

 * @ingroup DataLoad

 * @brief Cube data loading

 * @param [out] dst output LocalTensor

 * @param [in] src input LocalTensor

 * @param [in] loadDataParams.channelSize number of channels

 * @param [in] loadDataParams.GetKExtension() Transmission length of K dimension

 * @param [in] loadDataParams.GetMExtension() Transmission length of M dimension

 * @param [in] loadDataParams.GetKStartPt() Start point of K dimension

 * @param [in] loadDataParams.GetMStartPt() Start point of M dimension

 * @param [in] loadDataParams.GetStrideW() W dimension stride

 * @param [in] loadDataParams.GetStrideH() H dimension stride

 * @param [in] loadDataParams.GetFilterW() Convolution kernel width

 * @param [in] loadDataParams.GetFilterH() Convolution kernel height

 * @param [in] loadDataParams.GetDilationFilterW() Convolution kernel width expansion coefficient

 * @param [in] loadDataParams.GetDilationFilterH() Convolution kernel height expansion coefficient

 * @param [in] loadDataParams.enTranspose judge whether to enable the transpose function

 * @param [in] loadDataParams.enSmallK Whether to enable the small k feature

 */

/*
 * Load3Dv2Pro dispatcher.
 * In the CPU debug build the parameters are validated first. The call is then forwarded to the
 * L1->L0A, L1->L0B or L1->UB implementation chosen by the physical position of dst. Any other
 * destination is reported through ASCENDC_CHECK_TPOSITION.
 */
template <typename T>
__aicore__ inline void LoadDataImpl(const LocalTensor<T>& dst, const LocalTensor<T>& src,
    const LoadData3DParamsV2Pro& loadDataParams)
{
#if ASCENDC_CPU_DEBUG
    if (!CheckFuncLoadData3dv2Pro(dst, src, loadDataParams, "LoadData with LoadData3DParamsV2Pro")) {
        ASCENDC_REPORT_CHECK_ERROR("LoadData with LoadData3DParamsV2Pro", KernelFuncType::NONE_MODE);
    }
#endif
    const Hardware dstScope = GetPhyType((TPosition)dst.GetPosition());
    if (dstScope == Hardware::L0A) {
        LoadData3DV2L12L0ACal((__ca__ PrimT<T>*)dst.GetPhyAddr(),
                              (__cbuf__ PrimT<T>*)src.GetPhyAddr(), loadDataParams);
    } else if (dstScope == Hardware::L0B) {
        LoadData3DV2L12L0BCal((__cb__ PrimT<T>*)dst.GetPhyAddr(),
                              (__cbuf__ PrimT<T>*)src.GetPhyAddr(), loadDataParams);
    } else if (dstScope == Hardware::UB) {
        LoadData3DV2L12UBCal((__ubuf__ PrimT<T>*)dst.GetPhyAddr(),
                             (__cbuf__ PrimT<T>*)src.GetPhyAddr(), loadDataParams);
    } else {
        // The supported destinations are L0A, L0B and UB, i.e. A2 / B2 / UB. The previous text
        // listed "A1 / A2 / UB", which did not match the branches above.
        ASCENDC_CHECK_TPOSITION((false), "dst", "A2 / B2 / UB", "LoadData with LoadData3DParamsV2Pro",
            ConstDefiner::Instance().logicNameMap.at(static_cast<uint8_t>(dst.GetPosition())));
    }
}



#if defined(__NPU_ARCH__) && ((__NPU_ARCH__ == 3003))

// cce compiler processes load3d bfloat16_t using B8, so use the half dtype instead

/*
 * bfloat16_t specialization of the Load3Dv2Pro dispatcher.
 * The tensor addresses are reinterpreted as half pointers before being handed to the L1->L0A or
 * L1->L0B implementation. Only L0A and L0B destinations are handled; any other position asserts.
 */
template <>
__aicore__ inline void LoadDataImpl(const LocalTensor<bfloat16_t>& dst, const LocalTensor<bfloat16_t>& src,
    const LoadData3DParamsV2Pro& loadDataParams)
{
#if ASCENDC_CPU_DEBUG
    ASCENDC_ASSERT(CheckFuncLoadData3dv2Pro(dst, src, loadDataParams, "loaddata3dv2Pro"), {
        KERNEL_LOG(KERNEL_ERROR, "check loaddata3dv2Pro instr failed");
    });
#endif

    // Use TPosition for the cast, matching every other dispatcher in this file.
    const Hardware dstScope = GetPhyType((TPosition)dst.GetPosition());
    // compiler process bfloat16_t load3dv2 is using B8 type, so cast to half which is using B16 type
    if (dstScope == Hardware::L0A) {
        LoadData3DV2L12L0ACal((__ca__ half*)dst.GetPhyAddr(),
            (__cbuf__ half*)src.GetPhyAddr(), loadDataParams);
    } else if (dstScope == Hardware::L0B) {
        LoadData3DV2L12L0BCal((__cb__ half*)dst.GetPhyAddr(),
            (__cbuf__ half*)src.GetPhyAddr(), loadDataParams);
    } else {
        ASCENDC_ASSERT((false), { KERNEL_LOG(KERNEL_ERROR, "dst only support A2/B2"); });
    }
}

#endif



/* **************************************************************************************************

 * Mmad                                             *

 * ************************************************************************************************* */

/*

 * @ingroup Mmad

 * @brief Matrix multiplication and addition

 * @param [out] dst output LocalTensor

 * @param [in] fm input LocalTensor

 * @param [in] filter input LocalTensor

 * @param [in] mmadParams.m Left matrix row number

 * @param [in] mmadParams.n right matrix column number

 * @param [in] mmadParams.k Left matrix column number

 * @param [in] mmadParams.unitFlag whether enable unit flag

 * @param [in] mmadParams.kDirectionAlign is the indicator for alignment in L0A/L0B in the K direction

 * @param [in] mmadParams.cmatrixSource indicates the C matrix source, 1: the C matrix is in bias table buffer, 0: the C

 * matrix is in L0C

 * @param [in] mmadParams.cmatrixInitVal indicates the initial matrix, 1: the number in C matrix is 0, 0:use the real

 * number in C matrix

 */



/*
 * Matrix multiply-add without a separate bias tensor.
 * In the CPU debug build the parameters and tensor alignment are validated first; the call is then
 * forwarded to MmadCal with the physical addresses of dst, fm and filter.
 */
template <typename T, typename U, typename S>
__aicore__ inline void MmadImpl(const LocalTensor<T>& dst, const LocalTensor<U>& fm,
    const LocalTensor<S>& filter, const MmadParams& mmadParams)
{
#if ASCENDC_CPU_DEBUG
    if (!CheckMmadParams(dst, fm, filter, mmadParams, "Mmad")) {
        ASCENDC_REPORT_CHECK_ERROR("Mmad", KernelFuncType::NONE_MODE);
    }
    CheckMmadAlign(dst, fm, filter);
#endif
    // Resolve the three operand addresses up front, then issue the computation.
    __cc__ PrimT<T>* dstAddr = (__cc__ PrimT<T>*)dst.GetPhyAddr();
    __ca__ PrimT<U>* fmAddr = (__ca__ PrimT<U>*)fm.GetPhyAddr();
    __cb__ PrimT<S>* filterAddr = (__cb__ PrimT<S>*)filter.GetPhyAddr();
    MmadCal(dstAddr, fmAddr, filterAddr, mmadParams);
}



/*
 * Matrix multiply-add with a bias tensor.
 * The physical position of bias selects the cmatrixSource flag passed to MmadCal:
 * Hardware::BIAS -> true, Hardware::L0C -> false. Any other position asserts and the flag stays false.
 */
template <typename T, typename U, typename S, typename V>
__aicore__ inline void MmadImpl(const LocalTensor<T>& dst, const LocalTensor<U>& fm,
    const LocalTensor<S>& filter, const LocalTensor<V>& bias, const MmadParams& mmadParams)
{
#if ASCENDC_CPU_DEBUG
    if (!CheckMmadParams(dst, fm, filter, bias, mmadParams, "Mmad with bias")) {
        ASCENDC_REPORT_CHECK_ERROR("Mmad with bias", KernelFuncType::NONE_MODE);
    }
    CheckMmadAlign(dst, fm, filter);
    CheckTensorAlign<V>(bias, 128, "bias", "Mmad");
#if __NPU_ARCH__ == 2201
    ASCENDC_ASSERT((SupportType<Tuple<PrimT<T>, PrimT<U>, PrimT<S>, PrimT<V>>,
        Tuple<int32_t, int8_t, int8_t, int32_t>,
        Tuple<float, half, half, float>, Tuple<float, float, float, float>,
        Tuple<float, bfloat16_t, bfloat16_t, float>>()),
        {KERNEL_LOG(KERNEL_ERROR, "Failed to check dtype in Mmad, current api support dtype combination is "
        "Dst: int32_t, src0: int8_t, src1: int8_t, Bias: int32_t; Dst: float, src0: half, src1: half, Bias: float; "
        "Dst: float, src0: float, src1: float, Bias: float; "
        "Dst: float, src0: bfloat16_t, src1: bfloat16_t, Bias: float");});
#endif
#endif
    const Hardware biasHw = GetPhyType(static_cast<TPosition>(bias.GetPosition()));
    // True only when the bias tensor sits in the bias table buffer.
    const bool biasFromBiasTable = (biasHw == Hardware::BIAS);
    if (!biasFromBiasTable && biasHw != Hardware::L0C) {
        ASCENDC_ASSERT((false), { KERNEL_LOG(KERNEL_ERROR,
            "Failed to check bias tensor position in Mmad, supported positions are CO1 or C2"); });
    }
    MmadCal((__cc__ PrimT<T>*)dst.GetPhyAddr(), (__ca__ PrimT<U>*)fm.GetPhyAddr(),
        (__cb__ PrimT<S>*)filter.GetPhyAddr(), (uint64_t)bias.GetPhyAddr(), mmadParams, biasFromBiasTable);
}



#if defined(__NPU_ARCH__) && (__NPU_ARCH__ == 3101)

/*
 * Matrix multiply-add taking bit-mode parameters, without a separate bias tensor.
 * The CPU debug build validates mmadParams.GetConfig0() and the tensor alignment; the call is then
 * forwarded to MmadCal with the physical addresses of dst, fm and filter.
 */
template <typename T, typename U, typename S>
__aicore__ inline void MmadImpl(const LocalTensor<T>& dst, const LocalTensor<U>& fm,
    const LocalTensor<S>& filter, const MmadBitModeParams& mmadParams)
{
#if ASCENDC_CPU_DEBUG
    if (!CheckMmadParams(dst, fm, filter, mmadParams.GetConfig0(), "Mmad")) {
        ASCENDC_REPORT_CHECK_ERROR("Mmad", KernelFuncType::NONE_MODE);
    }
    CheckMmadAlign(dst, fm, filter);
#endif
    // Resolve the three operand addresses up front, then issue the computation.
    __cc__ PrimT<T>* dstAddr = (__cc__ PrimT<T>*)dst.GetPhyAddr();
    __ca__ PrimT<U>* fmAddr = (__ca__ PrimT<U>*)fm.GetPhyAddr();
    __cb__ PrimT<S>* filterAddr = (__cb__ PrimT<S>*)filter.GetPhyAddr();
    MmadCal(dstAddr, fmAddr, filterAddr, mmadParams);
}



/*
 * Matrix multiply-add taking bit-mode parameters, with a bias tensor.
 * The CPU debug build validates mmadParams.GetConfig0(), the operand alignment and the bias
 * alignment (128). The bias position must resolve to Hardware::BIAS or Hardware::L0C, otherwise an
 * assertion fires. The call is then forwarded to MmadCal with the bias address as a uint64_t.
 *
 * NOTE(review): cmatrixSource is computed below but, unlike the MmadParams overload, it is not
 * passed to MmadCal. Presumably the bit-mode parameters already carry this setting - confirm against
 * the MmadCal(MmadBitModeParams) implementation.
 */
template <typename T, typename U, typename S, typename V>

__aicore__ inline void MmadImpl(const LocalTensor<T>& dst, const LocalTensor<U>& fm,

    const LocalTensor<S>& filter, const LocalTensor<V>& bias, const MmadBitModeParams& mmadParams)

{

#if ASCENDC_CPU_DEBUG

    if (!CheckMmadParams(dst, fm, filter, bias, mmadParams.GetConfig0(), "Mmad with bias")) {

        ASCENDC_REPORT_CHECK_ERROR("Mmad with bias", KernelFuncType::NONE_MODE);

    }

    CheckMmadAlign(dst, fm, filter);

    CheckTensorAlign<V>(bias, 128, "bias", "Mmad");

#endif

    const Hardware biasScope = GetPhyType((TPosition)bias.GetPosition());

    // Only used to validate the bias position in this overload (see the review note above).
    bool cmatrixSource = false;

    if (biasScope == Hardware::BIAS) {

        cmatrixSource = true;

    } else if (biasScope == Hardware::L0C) {

        cmatrixSource = false;

    } else {

        ASCENDC_ASSERT((false), { KERNEL_LOG(KERNEL_ERROR,

            "Failed to check bias tensor position in Mmad, supported positions are CO1 or C2"); });

    }

    MmadCal((__cc__ PrimT<T>*)dst.GetPhyAddr(), (__ca__ PrimT<U>*)fm.GetPhyAddr(),

        (__cb__ PrimT<S>*)filter.GetPhyAddr(), (uint64_t)bias.GetPhyAddr(), mmadParams);

}

#endif



#if __NPU_ARCH__ == 2201

/*
 * Sparse matrix multiply-add ("MmadWithSparse").
 * Enabled only when PrimT<T> is int32_t and PrimT<U> is int8_t.
 * Checks that dst / fm / filter are located in L0C / L0A / L0B, checks their alignment and that
 * m, n and k lie within [0, UINT12_MAX], then forwards to MmadSpCal.
 * @param [out] dst output LocalTensor (CO1)
 * @param [in] fm input LocalTensor (A2)
 * @param [in] filter input LocalTensor (B2)
 * @param [in] mmadParams matrix sizes and control flags, forwarded unchanged
 */
template <typename T = int32_t, typename U = int8_t,

    typename std::enable_if<IsSameType<PrimT<T>, int32_t>::value, bool>::type = true,

    typename std::enable_if<IsSameType<PrimT<U>, int8_t>::value, bool>::type = true>

__aicore__ inline void MmadSpImpl(const LocalTensor<T>& dst, const LocalTensor<U>& fm,

    const LocalTensor<U>& filter, const MmadParams& mmadParams)

{

    // Position checks: each operand must sit in its dedicated buffer.
    CheckTensorPos<T>(dst, Hardware::L0C, "dst", "CO1", "MmadWithSparse");

    CheckTensorPos<U>(fm, Hardware::L0A, "fm", "A2", "MmadWithSparse");

    CheckTensorPos<U>(filter, Hardware::L0B, "filter", "B2", "MmadWithSparse");

    // Alignment checks.
    CheckTensorAlign<T>(dst, 1024, "dst", "MmadWithSparse");             // 1024B aligned

    CheckTensorAlign<U>(fm, VALUE_512, "fm", "MmadWithSparse");           // 512B aligned

    CheckTensorAlign<U>(filter, VALUE_512, "filter", "MmadWithSparse");   // 512B aligned

    // Range checks on the matrix dimensions.
    ASCENDC_CHECK_VALUE_RANGE(mmadParams.m, 0, UINT12_MAX, "m", "MmadWithSparse");

    ASCENDC_CHECK_VALUE_RANGE(mmadParams.n, 0, UINT12_MAX, "n", "MmadWithSparse");

    ASCENDC_CHECK_VALUE_RANGE(mmadParams.k, 0, UINT12_MAX, "k", "MmadWithSparse");

    MmadSpCal((__cc__ int32_t*)dst.GetPhyAddr(), (__ca__ int8_t*)fm.GetPhyAddr(),

        (__cb__ int8_t*)filter.GetPhyAddr(), mmadParams);

}



/*
 * Sparse 2D load ("LoadDataWithSparse").
 * Enabled only when PrimT<T> is int8_t and PrimT<U> is uint8_t.
 * Checks that dst is in L0B and that src and idx are in L1, checks their alignment, then forwards to
 * LoadDataWithSparseCal.
 * @param [out] dst output LocalTensor (B2)
 * @param [in] src input LocalTensor (B1)
 * @param [in] idx index LocalTensor (B1)
 * @param [in] loadDataParam 2D load parameters, forwarded unchanged
 */
template <typename T = int8_t, typename U = uint8_t,

    typename std::enable_if<IsSameType<PrimT<T>, int8_t>::value, bool>::type = true,

    typename std::enable_if<IsSameType<PrimT<U>, uint8_t>::value, bool>::type = true>

__aicore__ inline void LoadDataWithSparseImpl(const LocalTensor<T> &dst, const LocalTensor<T> &src,

    const LocalTensor<U> &idx, const LoadData2dParams &loadDataParam)

{

    // Position checks.
    CheckTensorPos<T>(dst, Hardware::L0B, "dst", "B2", "LoadDataWithSparse");

    CheckTensorPos<T>(src, Hardware::L1, "src", "B1", "LoadDataWithSparse");

    CheckTensorPos<U>(idx, Hardware::L1, "idx", "B1", "LoadDataWithSparse");

    // Alignment checks.
    CheckTensorAlign<T>(dst, VALUE_512, "dst", "LoadDataWithSparse");        // 512B align

    CheckTensorAlign<T>(src, ONE_BLK_SIZE, "src", "LoadDataWithSparse");     // 32B align

    CheckTensorAlign<U>(idx, ONE_BLK_SIZE, "idx", "LoadDataWithSparse");    // 32B align

    LoadDataWithSparseCal(dst, src, idx, loadDataParam);

}

#endif



#if __NPU_ARCH__ == 2002

/*
 * Forwards to LoadUnzipIndexCal. Enabled only when PrimT<T> is int8_t.
 * @param [in] src input GlobalTensor, passed through unchanged
 * @param [in] numOfIndexTabEntry passed through unchanged
 *             (presumably the number of index-table entries to load - confirm against LoadUnzipIndexCal)
 */
template <typename T = int8_t, typename std::enable_if<IsSameType<PrimT<T>, int8_t>::value, bool>::type = true>

__aicore__ inline void LoadUnzipIndexImpl(const GlobalTensor<T>& src, uint32_t numOfIndexTabEntry)

{

    LoadUnzipIndexCal(src, numOfIndexTabEntry);

}

#endif



/* **************************************************************************************************

 * BroadCastVecToMM                                             *

 * ************************************************************************************************* */

/*
 * Broadcast from a vector-side tensor to a cube-side tensor.
 * The CPU debug build validates the arguments; the call is then forwarded to BroadCastVecToMMCal
 * with dst addressed through a __cc__ pointer and src through a __ubuf__ pointer.
 * @param [out] dst output LocalTensor
 * @param [in] src input LocalTensor
 * @param [in] blockCount, blockLen, srcGap, dstGap forwarded unchanged
 */
template <typename T, typename U>
__aicore__ inline __inout_pipe__(V) void BroadCastVecToMMImpl(const LocalTensor<T> &dst,
    const LocalTensor<U> &src, const int32_t blockCount, const uint8_t blockLen, const uint8_t srcGap,
    const uint8_t dstGap)
{
#if ASCENDC_CPU_DEBUG
    if (!CheckFuncBroadCastToMM(dst, src, blockCount, blockLen, srcGap, dstGap, "BroadCastVecToMM")) {
        ASCENDC_REPORT_CHECK_ERROR("BroadCastVecToMM", KernelFuncType::NONE_MODE);
    }
#endif
    // Resolve both addresses first, then issue the broadcast.
    __cc__ PrimT<T>* dstAddr = (__cc__ PrimT<T>*)dst.GetPhyAddr();
    __ubuf__ PrimT<U>* srcAddr = (__ubuf__ PrimT<U>*)src.GetPhyAddr();
    BroadCastVecToMMCal(dstAddr, srcAddr, blockCount, blockLen, srcGap, dstGap);
}



/* **************************************************************************************************

 * SetLoadDataPaddingValue                                             *

 * ************************************************************************************************* */

/*

 * @ingroup SetLoadDataPaddingValue

 * @brief setting loadData pad value

 * @param [in]padValue padding value

 */

/*
 * Thin wrapper: forwards padValue unchanged to Load3DSetPaddingCal.
 */
template <typename T>

__aicore__ inline void Load3DSetPaddingImpl(const T padValue)

{

    Load3DSetPaddingCal(padValue);

}



/* **************************************************************************************************

 * Fill                                             *

 * ************************************************************************************************* */

/*

 * @ingroup Fill

 * @brief L0A/L0B value initializing

 * @param [out] dst output LocalTensor

 * @param [in] InitConstValueParams.repeatTimes repeat times

 * @param [in] InitConstValueParams.blockNum block number

 * @param [in] InitConstValueParams.dstGap interval between the previous tail and the next block head

 * @param [in] InitConstValueParams.initValue initialize Value

 */

/*
 * Fill dispatcher: initializes dst with the constant described by initConstValueParams.
 * The physical position of dst selects the implementation (L0A, L0B or L1), each preceded by an
 * alignment check. Any other position is reported through ASCENDC_CHECK_TPOSITION.
 */
template <typename T, typename U = PrimT<T>,
    typename std::enable_if<IsSameType<PrimT<T>, U>::value, bool>::type = true>
__aicore__ inline void FillImpl(const LocalTensor<T> &dst,
    const InitConstValueParams<U> &initConstValueParams)
{
    // Element type actually stored in the tensor.
    using ElemT = PrimT<T>;

    const Hardware dstHw = GetPhyType(static_cast<TPosition>(dst.GetPosition()));
    if (dstHw == Hardware::L0A) {
        CheckTensorAlign<T>(dst, VALUE_512, "dst", "Fill when TPosition is A2");
        InitL0ANzMatrixCal((__ca__ ElemT*)dst.GetPhyAddr(), initConstValueParams);
        return;
    }
    if (dstHw == Hardware::L0B) {
        CheckTensorAlign<T>(dst, VALUE_512, "dst", "Fill when TPosition is B2");
        InitL0BNzMatrixCal((__cb__ ElemT*)dst.GetPhyAddr(), initConstValueParams);
        return;
    }
    if (dstHw == Hardware::L1) {
        CheckTensorAlign<T>(dst, ONE_BLK_SIZE, "dst", "Fill when TPosition is A1 / B1");
        InitL1BufferCal((__cbuf__ ElemT*)dst.GetPhyAddr(), initConstValueParams);
        return;
    }
    // Unsupported destination position.
    ASCENDC_CHECK_TPOSITION(false, "dst", "A1 / B1 / A2 / B2", "Fill",
        ConstDefiner::Instance().logicNameMap.at(static_cast<uint8_t>(dst.GetPosition())));
}



/* **************************************************************************************************

 * InitConstValue                                             *

 * ************************************************************************************************* */

/*

 * @ingroup InitConstValue

 * @brief L0A/L0B value initializing

 * @param [out] dst output LocalTensor

 * @param [in] InitConstValueParams.repeatTimes repeat times

 * @param [in] InitConstValueParams.blockNum block number

 * @param [in] InitConstValueParams.dstGap interval between the previous tail and the next block head

 * @param [in] InitConstValueParams.initValue initialize Value

 */

template <typename T, typename U = PrimT<T>,

    typename std::enable_if<IsSameType<PrimT<T>, U>::value, bool>::type = true>

__aicore__ inline void InitConstValueImpl(const LocalTensor<T> &dst,

    const InitConstValueParams<U> &initConstValueParams)

{

    const Hardware dstScope = GetPhyType((TPosition)dst.GetPosition());

    if (dstScope == Hardware::L0A) {

        CheckTensorAlign<T>(dst, VALUE_512, "dst", "InitConstValue when TPosition is A2");

        InitL0ANzMatrixCal((__ca__ PrimT<T>*)dst.GetPhyAddr(), initConstValueParams);

    } else if (dstScope == Hardware::L0B) {

        CheckTensorAlign<T>(dst, VALUE_512, "dst", "InitConstValue when TPosition is B2");

        InitL0BNzMatrixCal((__cb__ PrimT<T>*)dst.GetPhyAddr(), initConstValueParams);

    } else if (dstScope == Hardware::L1) {

        CheckTensorAlign<T>(dst, ONE_BLK_SIZE, "dst", "InitConstValue when TPosition is A1 / B1");

        InitL1BufferCal((__cbuf__ PrimT<T>*)dst.GetPhyAddr(), initConstValueParams);

    } else {

        ASCENDC_CHECK_TPOSITION(false, "dst", "A1 / B1 / A2 / B2", "InitConstValue",

            ConstDefiner::Instance().logicNameMap.at(static_cast<uint8_t>(dst.GetPosition())));

    }

}



/* **************************************************************************************************

 * SetFmatrix                                             *

 * ************************************************************************************************* */

/*
 * @ingroup SetFmatrix
 * @brief Set the fmatrix configuration by forwarding to Load3DSetFMatrixCal (FMATRIX_LEFT)
 *        or Load3DSetFMatrixBCal (FMATRIX_RIGHT). Any other mode value is a no-op.
 * @param [in] l1H operand height
 * @param [in] l1W operand width
 * @param [in] padList padding list
 * @param [in] fmatrixMode set fmatrix_a or fmatrix_b
 */
__aicore__ inline void SetFmatrixImpl(uint16_t l1H, uint16_t l1W, const uint8_t padList[4],
    const FmatrixMode &fmatrixMode)
{
    switch (fmatrixMode) {
        case FmatrixMode::FMATRIX_LEFT:
            Load3DSetFMatrixCal(l1H, l1W, padList);
            break;
        case FmatrixMode::FMATRIX_RIGHT:
            Load3DSetFMatrixBCal(l1H, l1W, padList);
            break;
        default:
            // Modes other than LEFT / RIGHT are intentionally ignored.
            break;
    }
}



#if defined(__NPU_ARCH__) && ((__NPU_ARCH__ == 3101) || (__NPU_ARCH__ == 5102))

// Bit-mode overload of SetFmatrixImpl: forwards param.GetConfig0() to Load3DSetFMatrixCal
// (FMATRIX_LEFT) or Load3DSetFMatrixBCal (FMATRIX_RIGHT). Any other mode value is a no-op.
__aicore__ inline void SetFmatrixImpl(const SetFMatrixBitModeParams &param, const FmatrixMode &fmatrixMode)
{
    switch (fmatrixMode) {
        case FmatrixMode::FMATRIX_LEFT:
            Load3DSetFMatrixCal(param.GetConfig0());
            break;
        case FmatrixMode::FMATRIX_RIGHT:
            Load3DSetFMatrixBCal(param.GetConfig0());
            break;
        default:
            // Modes other than LEFT / RIGHT are intentionally ignored.
            break;
    }
}

#endif



/* **************************************************************************************************

 * SetLoadDataBoundary                                             *

 * ************************************************************************************************* */

/*

 * @ingroup SetLoadDataBoundary

 * @brief Set the load-data boundary by forwarding the value to SetLoadDataBoundaryCal.

 * @param [in] boundaryValue boundary value, passed through unchanged

 */

__aicore__ inline void SetLoadDataBoundaryImpl(uint32_t boundaryValue)

{

    SetLoadDataBoundaryCal(boundaryValue);

}



/* **************************************************************************************************

 * SetLoadDataRepeat                                             *

 * ************************************************************************************************* */

/*

 * @ingroup SetLoadDataRepeat

 * @brief Set the load-data repeat configuration by forwarding to SetLoadDataRepeatCal.

 * @param [in] repeatParams repeat parameters, passed through unchanged

 */

__aicore__ inline void SetLoadDataRepeatImpl(const LoadDataRepeatParam& repeatParams)

{

    SetLoadDataRepeatCal(repeatParams);

}



/* **************************************************************************************************

 * LoadDataUnzipImpl                                             *

 * ************************************************************************************************* */

/*
 * @ingroup LoadDataUnzip
 * @brief Load data from a GlobalTensor into dst through the unzip routine that matches the
 *        physical position of dst (L1 / L0A / L0B). When ASCENDC_CPU_DEBUG is enabled, the
 *        alignment of dst and the element type (int8_t only) are checked first.
 * @param [out] dst output LocalTensor
 * @param [in] src input GlobalTensor
 */
template <typename T>
__aicore__ inline void LoadDataUnzipImpl(const LocalTensor<T>& dst, const GlobalTensor<T>& src)
{
    // Resolve the logical position of dst to the hardware buffer that backs it.
    const Hardware dstHw = GetPhyType(static_cast<TPosition>(dst.GetPosition()));
#if ASCENDC_CPU_DEBUG
    // Debug-only validation: alignment first, then the element type.
    switch (dstHw) {
        case Hardware::L1:
            CheckTensorAlign<T>(dst, ONE_BLK_SIZE, "dst", "LoadDataUnzip in A1 / B1"); // 32B align
            break;
        case Hardware::L0A:
        case Hardware::L0B:
            CheckTensorAlign<T>(dst, VALUE_512, "dst", "LoadDataUnzip in B2"); // 512B align
            break;
        default:
            break;
    }
    if constexpr (!SupportType<PrimT<T>, int8_t>()) {
        ASCENDC_ASSERT(false, {KERNEL_LOG(KERNEL_ERROR, "Failed to check dtype in LoadDataUnzip, current api support "
            "dtype combination is dst: int8_t.");});
    }
#endif
    // Dispatch to the backend routine for the destination buffer.
    switch (dstHw) {
        case Hardware::L1:
            LoadDataUnzipToL1Cal((__cbuf__ PrimT<T>*)dst.GetPhyAddr(), (__gm__ PrimT<T>*)src.GetPhyAddr());
            break;
        case Hardware::L0A:
            LoadDataUnzipToL0ACal((__ca__ PrimT<T>*)dst.GetPhyAddr(), (__gm__ PrimT<T>*)src.GetPhyAddr());
            break;
        case Hardware::L0B:
            LoadDataUnzipToL0BCal((__cb__ PrimT<T>*)dst.GetPhyAddr(), (__gm__ PrimT<T>*)src.GetPhyAddr());
            break;
        default:
            ASCENDC_ASSERT((false), { KERNEL_LOG(KERNEL_ERROR, "Failed to check dst tensor position in LoadDataUnzip, "
                "supported positions are A1 / B1 / B2"); });
            break;
    }
}



} // namespace AscendC

#endif // ASCENDC_MODULE_OPERATOR_MM_BASE_IMPL_H