/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/

/*!
 * \file batch_matmul_impl.h
 * \brief Batch specialization of MatmulImpl, selected when A_TYPE::layout != LayoutMode::NONE.
 */
// Direct-inclusion guard: when __ASCENDC_INCLUDE_INTERNAL_HEADERS__ is not already defined at this
// point, a build-time message is emitted, the macro is defined for the remainder of this header, and
// a file-specific marker is defined so that the block at the end of this file can undefine both again.
#if !defined(__ASCENDC_INCLUDE_INTERNAL_HEADERS__)
#pragma message("impl/adv_api/detail/matmul/batch_matmul_impl.h is an internal header file and must not be used directly. Functions or variables defined in this file may be removed in the future. Please use \"#include \"adv_api/matmul/matmul.h\"\" and use public functions or variables defined in interface headers files.")
#define __ASCENDC_INCLUDE_INTERNAL_HEADERS__
#define __UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_DETAIL_MATMUL_BATCH_MATMUL_IMPL_H__
#endif

#ifndef IMPL_MATMUL_BATCH_MATMUL_IMPL_H
#define IMPL_MATMUL_BATCH_MATMUL_IMPL_H

#include "matmul_impl_base.h"

namespace AscendC {

// Match Policy with CallBack parameter.
//
// Partial specialization of MatmulImpl that is enabled when A_TYPE::layout != LayoutMode::NONE.
// It derives from MatmulImplBase and additionally pulls in the batch-related modules
// (BatchScheduler, BatchCopyCubeInParamsA/B, BatchCopyCubeInA/B, BatchLoop) through the
// MATMUL_IMPORT_MODULE / MATMUL_IMPORT_MODULE_PRIVATE macros.
// NOTE(review): the expansions of the MATMUL_* macros are defined outside this file; the comments
// below describe only what is visible here.
template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE, const auto& MM_CFG, class MM_CB,
    MATMUL_POLICY_TEMPLATE_OF(MATMUL_POLICY)>
class MatmulImpl<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG, MM_CB, MATMUL_POLICY,
    enable_if_t<A_TYPE::layout != LayoutMode::NONE>>
    : public MatmulImplBase<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG, MM_CB, MATMUL_POLICY>
    , MATMUL_IMPORT_MODULE(BatchScheduler)
    , MATMUL_IMPORT_MODULE_PRIVATE(BatchCopyCubeInParamsA)
    , MATMUL_IMPORT_MODULE_PRIVATE(BatchCopyCubeInParamsB)
    , MATMUL_IMPORT_MODULE_PRIVATE(BatchCopyCubeInA)
    , MATMUL_IMPORT_MODULE_PRIVATE(BatchCopyCubeInB)
    , MATMUL_IMPORT_MODULE_PRIVATE(BatchLoop)
{
private:
    // Element types of the A input, the B input and the C output.
    using SrcAT = typename A_TYPE::T;
    using SrcBT = typename B_TYPE::T;
    using DstT = typename C_TYPE::T;
    // The primary-template spelling of this class (without the enable_if argument).
    // NOTE(review): presumably referenced by the MATMUL_* macros - confirm before renaming.
    using IMPL = MatmulImpl<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG, MM_CB, MATMUL_POLICY>;
public:
    // Module names made available through the MATMUL_ALLOW_USING* macros.
    MATMUL_ALLOW_USING(CopyCubeInA);
    MATMUL_ALLOW_USING(CopyCubeInB);
    MATMUL_ALLOW_USING(CopyCubeOut);
    MATMUL_ALLOW_USING(Scheduler);
    MATMUL_ALLOW_USING(BatchScheduler);
    MATMUL_ALLOW_USING_PRIVATE(BatchCopyCubeInParamsA);
    MATMUL_ALLOW_USING_PRIVATE(BatchCopyCubeInParamsB);
    MATMUL_ALLOW_USING_PRIVATE(BatchCopyCubeInA);
    MATMUL_ALLOW_USING_PRIVATE(BatchCopyCubeInB);
    MATMUL_ALLOW_USING_PRIVATE(BatchLoop);
    MATMUL_ALLOW_USING_PRIVATE(MatmulTensorInfoA);
    MATMUL_ALLOW_USING_PRIVATE(MatmulTensorInfoB);

    // Tag-based selector: BatchCopyCubeInParamsA for InputTypeTag::A, BatchCopyCubeInParamsB otherwise.
    template <InputTypeTag TAG>
    using BatchCopyCubeInParams = typename AscendC::Conditional<TAG == InputTypeTag::A, BatchCopyCubeInParamsA, BatchCopyCubeInParamsB>::type;

    // Tag-based selector: MatmulTensorInfoA for InputTypeTag::A, MatmulTensorInfoB otherwise.
    template <InputTypeTag TAG>
    using MatmulTensorInfo =
        typename AscendC::Conditional<TAG == InputTypeTag::A, MatmulTensorInfoA, MatmulTensorInfoB>::type;
private:
    // Modules accessed from the member functions below via MATMUL_MODULE(...).
    MATMUL_USE_MODULE(CopyCubeInA);
    MATMUL_USE_MODULE(CopyCubeInB);
    MATMUL_USE_MODULE(Scheduler);
    MATMUL_USE_MODULE(BatchScheduler);
    MATMUL_USE_MODULE(BatchCopyCubeInA);
    MATMUL_USE_MODULE(BatchCopyCubeInB);
    MATMUL_USE_MODULE(BatchLoop);

    // Copy-in module used for the A input: CopyCubeInA when GetCopyCubeInType(...) for A is not
    // CopyCubeInType::BMM, BatchCopyCubeInA when it is BMM.
    using ChosenCopyCubeInA = typename AscendC::Conditional<Impl::Detail::GetCopyCubeInType<MatmulInputAType<A_TYPE, typename A_TYPE::T>, MM_CFG>() != Impl::Detail::CopyCubeInType::BMM,
                                                            CopyCubeInA, BatchCopyCubeInA>::type;

    // Copy-in module used for the B input; same selection rule as for A, applied to B_TYPE.
    using ChosenCopyCubeInB = typename AscendC::Conditional<Impl::Detail::GetCopyCubeInType<MatmulInputBType<B_TYPE, typename B_TYPE::T>, MM_CFG>() != Impl::Detail::CopyCubeInType::BMM,
                                                            CopyCubeInB, BatchCopyCubeInB>::type;
    MATMUL_USE_MODULE(ChosenCopyCubeInA);
    MATMUL_USE_MODULE(ChosenCopyCubeInB);
public:
    using BASE_MODULE = MatmulImplBase<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG, MM_CB, MATMUL_POLICY>;
    __aicore__ inline MatmulImpl() {}

    // Initializes the BatchScheduler module with the given tiling.
    // cubeTiling: tiling description forwarded to BatchScheduler::Init.
    // tpipe: not referenced in this function; the pipe pointer passed on is the one returned by
    //        GetTPipePtr().
    __aicore__ inline void Init(const TCubeTiling* __restrict cubeTiling, TPipe* tpipe = nullptr)
    {
        auto tpipePtr = GetTPipePtr();
        MATMUL_MODULE(BatchScheduler)->Init(cubeTiling, tpipePtr);
    }

    // Forwards to BatchScheduler::End.
    __aicore__ inline void End()
    {
        MATMUL_MODULE(BatchScheduler)->End();
    }

    // Sets the A input from a GlobalTensor: hands it (with the transpose flag) to the chosen A
    // copy-in module, then resets the Scheduler module.
    __aicore__ inline void SetTensorA(const GlobalTensor<SrcAT>& gm, bool isTransposeA = false)
    {
        MATMUL_MODULE(ChosenCopyCubeInA)->SetInput(gm, isTransposeA);
        MATMUL_MODULE(Scheduler)->Reset();
    }

    // Sets the A input from a LocalTensor by delegating to the base-class implementation.
    __aicore__ inline void SetTensorA(const LocalTensor<SrcAT>& leftMatrix, bool isTransposeA = false)
    {
        BASE_MODULE::SetTensorA(leftMatrix, isTransposeA);
    }

    // Sets the B input from a GlobalTensor: hands it (with the transpose flag) to the chosen B
    // copy-in module, then resets the Scheduler module.
    __aicore__ inline void SetTensorB(const GlobalTensor<SrcBT>& gm, bool isTransposeB = false)
    {
        MATMUL_MODULE(ChosenCopyCubeInB)->SetInput(gm, isTransposeB);
        MATMUL_MODULE(Scheduler)->Reset();
    }

    // Sets the B input from a LocalTensor by delegating to the base-class implementation.
    __aicore__ inline void SetTensorB(const LocalTensor<SrcBT>& rightMatrix, bool isTransposeB = false)
    {
        BASE_MODULE::SetTensorB(rightMatrix, isTransposeB);
    }

    // Scalar overload for A; delegates to the base-class implementation.
    __aicore__ inline void SetTensorA(SrcAT aScalar)
    {
        BASE_MODULE::SetTensorA(aScalar);
    }

    // Scalar overload for B; delegates to the base-class implementation.
    __aicore__ inline void SetTensorB(SrcBT bScalar)
    {
        BASE_MODULE::SetTensorB(bScalar);
    }

    // Forwards the batch counts for A and B to the BatchLoop module.
    __aicore__ inline void SetBatchNum(int32_t batchA, int32_t batchB)
    {
        MATMUL_MODULE(BatchLoop)->SetBatchNum(batchA, batchB);
    }

    // Passes the batch-out count to both BatchScheduler and BatchLoop.
    // When MM_CFG's bmmOutMode is BatchOutMode::SINGLE_BATCH the requested value is ignored and 1 is
    // used; for any other mode nBatchOutNumIn is used as given.
    __aicore__ inline void SetNBatchOutNum(int32_t nBatchOutNumIn)
    {
        int32_t nBatchOutNum = 1;
        if constexpr (ToMatmulConfig(MM_CFG).bmmOutMode != BatchOutMode::SINGLE_BATCH) {
            nBatchOutNum = nBatchOutNumIn;
        }
        MATMUL_MODULE(BatchScheduler)->SetNBatchOutNum(nBatchOutNum);
        MATMUL_MODULE(BatchLoop)->SetNBatchOutNum(nBatchOutNum);
    }

    // Runs the batch computation with a GlobalTensor destination by forwarding every argument,
    // unchanged and in the same order, to BatchScheduler::Schedule.
    __aicore__ inline void IterateBatch(const GlobalTensor<DstT>& gm,
        bool enPartialSum, uint8_t enAtomic, bool enSequentialWrite, const uint32_t matrixStrideA = 0,
        const uint32_t matrixStrideB = 0, const uint32_t matrixStrideC = 0)
    {
        MATMUL_MODULE(BatchScheduler)->Schedule(gm, enPartialSum, enAtomic, enSequentialWrite, matrixStrideA, matrixStrideB,
            matrixStrideC);
    }

    // Same as the GlobalTensor overload, but the destination is a LocalTensor; all arguments are
    // forwarded unchanged to BatchScheduler::Schedule.
    __aicore__ inline void IterateBatch(const LocalTensor<DstT>& ubCmatrix,
        bool enPartialSum, uint8_t enAtomic, bool enSequentialWrite, const uint32_t matrixStrideA = 0,
        const uint32_t matrixStrideB = 0, const uint32_t matrixStrideC = 0)
    {
        MATMUL_MODULE(BatchScheduler)->Schedule(ubCmatrix, enPartialSum, enAtomic, enSequentialWrite, matrixStrideA,
            matrixStrideB, matrixStrideC);
    }
};

} // namespace AscendC

#endif

// If the file-specific marker is defined (this header defines it at the top when
// __ASCENDC_INCLUDE_INTERNAL_HEADERS__ was not already set), undefine both macros again.
#if defined(__UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_DETAIL_MATMUL_BATCH_MATMUL_IMPL_H__)
#undef __ASCENDC_INCLUDE_INTERNAL_HEADERS__
#undef __UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_DETAIL_MATMUL_BATCH_MATMUL_IMPL_H__
#endif