/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
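/*!
* \file batch_matmul_impl.h
* \brief Partial specialization of MatmulImpl that is enabled when A_TYPE::layout is not LayoutMode::NONE.
*/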
// Direct-include guard: when the internal-headers marker macro is not already defined at this point,
// emit a compile-time message telling the user to include the public matmul header instead.
// The marker is then defined by this file, and a second, file-specific macro records that fact so the
// marker can be undefined again at the bottom of this header.
#if !defined(__ASCENDC_INCLUDE_INTERNAL_HEADERS__)
#pragma message( \
"impl/adv_api/detail/matmul/batch_matmul_impl.h is an internal header file and must not be used directly. Functions or variables defined in this file may be removed in the future. Please use \"#include \"adv_api/matmul/matmul.h\"\" and use public functions or variables defined in interface headers files.")
#define __ASCENDC_INCLUDE_INTERNAL_HEADERS__
#define __UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_DETAIL_MATMUL_BATCH_MATMUL_IMPL_H__
#endif
#ifndef IMPL_MATMUL_BATCH_MATMUL_IMPL_H
#define IMPL_MATMUL_BATCH_MATMUL_IMPL_H
#include "matmul_impl_base.h"
namespace AscendC {
/**
 * Partial specialization of MatmulImpl that is enabled when A_TYPE::layout is not LayoutMode::NONE.
 *
 * The class derives publicly from MatmulImplBase and additionally imports the BatchScheduler,
 * BatchCopyCubeInParamsA/B, BatchCopyCubeInA/B and BatchLoop modules. The public member functions
 * declared here are thin forwarders to those modules or to the base class.
 */
template <
    class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE, const auto& MM_CFG, class MM_CB,
    MATMUL_POLICY_TEMPLATE_OF(MATMUL_POLICY)>
class MatmulImpl<
    A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG, MM_CB, MATMUL_POLICY, enable_if_t<A_TYPE::layout != LayoutMode::NONE>>
    : public MatmulImplBase<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG, MM_CB, MATMUL_POLICY>,
      MATMUL_IMPORT_MODULE(BatchScheduler),
      MATMUL_IMPORT_MODULE_PRIVATE(BatchCopyCubeInParamsA),
      MATMUL_IMPORT_MODULE_PRIVATE(BatchCopyCubeInParamsB),
      MATMUL_IMPORT_MODULE_PRIVATE(BatchCopyCubeInA),
      MATMUL_IMPORT_MODULE_PRIVATE(BatchCopyCubeInB),
      MATMUL_IMPORT_MODULE_PRIVATE(BatchLoop) {
private:
    // Element types of the A, B and C operands.
    using SrcAT = typename A_TYPE::T;
    using SrcBT = typename B_TYPE::T;
    using DstT = typename C_TYPE::T;
    // Alias of the primary-template instantiation with the same arguments.
    // NOTE(review): presumably consumed by the MATMUL_* module macros; their expansion is not visible here.
    using IMPL = MatmulImpl<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG, MM_CB, MATMUL_POLICY>;

public:
    MATMUL_ALLOW_USING(CopyCubeInA);
    MATMUL_ALLOW_USING(CopyCubeInB);
    MATMUL_ALLOW_USING(CopyCubeOut);
    MATMUL_ALLOW_USING(Scheduler);
    MATMUL_ALLOW_USING(BatchScheduler);
    MATMUL_ALLOW_USING_PRIVATE(BatchCopyCubeInParamsA);
    MATMUL_ALLOW_USING_PRIVATE(BatchCopyCubeInParamsB);
    MATMUL_ALLOW_USING_PRIVATE(BatchCopyCubeInA);
    MATMUL_ALLOW_USING_PRIVATE(BatchCopyCubeInB);
    MATMUL_ALLOW_USING_PRIVATE(BatchLoop);
    MATMUL_ALLOW_USING_PRIVATE(MatmulTensorInfoA);
    MATMUL_ALLOW_USING_PRIVATE(MatmulTensorInfoB);

    // Tag-indexed aliases: InputTypeTag::A selects the "A" flavour, any other tag selects the "B" flavour
    // (assuming AscendC::Conditional behaves like std::conditional).
    template <InputTypeTag TAG>
    using BatchCopyCubeInParams =
        typename AscendC::Conditional<TAG == InputTypeTag::A, BatchCopyCubeInParamsA, BatchCopyCubeInParamsB>::type;
    template <InputTypeTag TAG>
    using MatmulTensorInfo =
        typename AscendC::Conditional<TAG == InputTypeTag::A, MatmulTensorInfoA, MatmulTensorInfoB>::type;

private:
    MATMUL_USE_MODULE(CopyCubeInA);
    MATMUL_USE_MODULE(CopyCubeInB);
    MATMUL_USE_MODULE(Scheduler);
    MATMUL_USE_MODULE(BatchScheduler);
    MATMUL_USE_MODULE(BatchCopyCubeInA);
    MATMUL_USE_MODULE(BatchCopyCubeInB);
    MATMUL_USE_MODULE(BatchLoop);

    // Copy-in module used for global-memory A input: CopyCubeInA when GetCopyCubeInType() reports anything
    // other than CopyCubeInType::BMM, BatchCopyCubeInA otherwise (assuming std::conditional-like semantics).
    using ChosenCopyCubeInA = typename AscendC::Conditional<
        Impl::Detail::GetCopyCubeInType<MatmulInputAType<A_TYPE, typename A_TYPE::T>, MM_CFG>() !=
            Impl::Detail::CopyCubeInType::BMM,
        CopyCubeInA, BatchCopyCubeInA>::type;
    // Same selection rule for the B input.
    using ChosenCopyCubeInB = typename AscendC::Conditional<
        Impl::Detail::GetCopyCubeInType<MatmulInputBType<B_TYPE, typename B_TYPE::T>, MM_CFG>() !=
            Impl::Detail::CopyCubeInType::BMM,
        CopyCubeInB, BatchCopyCubeInB>::type;
    MATMUL_USE_MODULE(ChosenCopyCubeInA);
    MATMUL_USE_MODULE(ChosenCopyCubeInB);

public:
    using BASE_MODULE = MatmulImplBase<A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG, MM_CB, MATMUL_POLICY>;

    __aicore__ inline MatmulImpl() {}

    /**
     * Initializes the BatchScheduler module.
     * \param cubeTiling tiling description, forwarded unchanged to BatchScheduler::Init.
     * \param tpipe      not read by this function; the pipe pointer passed on is the one returned by
     *                   GetTPipePtr(). NOTE(review): confirm that ignoring this argument is intended.
     */
    __aicore__ inline void Init(const TCubeTiling* __restrict cubeTiling, TPipe* tpipe = nullptr)
    {
        MATMUL_MODULE(BatchScheduler)->Init(cubeTiling, GetTPipePtr());
    }

    /** Forwards to BatchScheduler::End. */
    __aicore__ inline void End()
    {
        MATMUL_MODULE(BatchScheduler)->End();
    }

    /**
     * Registers a global-memory A operand with the chosen A copy-in module, then resets the Scheduler module.
     * \param gm           global tensor holding A.
     * \param isTransposeA transpose flag passed to SetInput.
     */
    __aicore__ inline void SetTensorA(const GlobalTensor<SrcAT>& gm, bool isTransposeA = false)
    {
        MATMUL_MODULE(ChosenCopyCubeInA)->SetInput(gm, isTransposeA);
        MATMUL_MODULE(Scheduler)->Reset();
    }

    /** Local-tensor overload; delegates to the base-class SetTensorA. */
    __aicore__ inline void SetTensorA(const LocalTensor<SrcAT>& leftMatrix, bool isTransposeA = false)
    {
        BASE_MODULE::SetTensorA(leftMatrix, isTransposeA);
    }

    /**
     * Registers a global-memory B operand with the chosen B copy-in module, then resets the Scheduler module.
     * \param gm           global tensor holding B.
     * \param isTransposeB transpose flag passed to SetInput.
     */
    __aicore__ inline void SetTensorB(const GlobalTensor<SrcBT>& gm, bool isTransposeB = false)
    {
        MATMUL_MODULE(ChosenCopyCubeInB)->SetInput(gm, isTransposeB);
        MATMUL_MODULE(Scheduler)->Reset();
    }

    /** Local-tensor overload; delegates to the base-class SetTensorB. */
    __aicore__ inline void SetTensorB(const LocalTensor<SrcBT>& rightMatrix, bool isTransposeB = false)
    {
        BASE_MODULE::SetTensorB(rightMatrix, isTransposeB);
    }

    /** Scalar overload; delegates to the base-class SetTensorA. */
    __aicore__ inline void SetTensorA(SrcAT aScalar)
    {
        BASE_MODULE::SetTensorA(aScalar);
    }

    /** Scalar overload; delegates to the base-class SetTensorB. */
    __aicore__ inline void SetTensorB(SrcBT bScalar)
    {
        BASE_MODULE::SetTensorB(bScalar);
    }

    /** Forwards the A and B batch counts to BatchLoop::SetBatchNum. */
    __aicore__ inline void SetBatchNum(int32_t batchA, int32_t batchB)
    {
        MATMUL_MODULE(BatchLoop)->SetBatchNum(batchA, batchB);
    }

    /**
     * Sets the N-batch output count on both the BatchScheduler and BatchLoop modules.
     * \param nBatchOutNumIn requested value. It is replaced by 1 when MM_CFG's bmmOutMode is
     *                       BatchOutMode::SINGLE_BATCH; otherwise it is forwarded as given.
     */
    __aicore__ inline void SetNBatchOutNum(int32_t nBatchOutNumIn)
    {
        int32_t effectiveOutNum = nBatchOutNumIn;
        if constexpr (ToMatmulConfig(MM_CFG).bmmOutMode == BatchOutMode::SINGLE_BATCH) {
            // Compile-time override: this configuration always uses a count of one.
            effectiveOutNum = 1;
        }
        MATMUL_MODULE(BatchScheduler)->SetNBatchOutNum(effectiveOutNum);
        MATMUL_MODULE(BatchLoop)->SetNBatchOutNum(effectiveOutNum);
    }

    /**
     * Global-memory output overload: passes every argument, in order, to BatchScheduler::Schedule.
     */
    __aicore__ inline void IterateBatch(
        const GlobalTensor<DstT>& gm, bool enPartialSum, uint8_t enAtomic, bool enSequentialWrite,
        const uint32_t matrixStrideA = 0, const uint32_t matrixStrideB = 0, const uint32_t matrixStrideC = 0)
    {
        MATMUL_MODULE(BatchScheduler)->Schedule(
            gm, enPartialSum, enAtomic, enSequentialWrite, matrixStrideA, matrixStrideB, matrixStrideC);
    }

    /**
     * Local-tensor output overload: passes every argument, in order, to BatchScheduler::Schedule.
     */
    __aicore__ inline void IterateBatch(
        const LocalTensor<DstT>& ubCmatrix, bool enPartialSum, uint8_t enAtomic, bool enSequentialWrite,
        const uint32_t matrixStrideA = 0, const uint32_t matrixStrideB = 0, const uint32_t matrixStrideC = 0)
    {
        MATMUL_MODULE(BatchScheduler)->Schedule(
            ubCmatrix, enPartialSum, enAtomic, enSequentialWrite, matrixStrideA, matrixStrideB, matrixStrideC);
    }
};
}
#endif
// Cleanup for the direct-include guard at the top of this file: if this header was the one that defined
// the internal-headers marker, undefine both the marker and the file-specific bookkeeping macro.
#if defined(__UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_DETAIL_MATMUL_BATCH_MATMUL_IMPL_H__)
#undef __ASCENDC_INCLUDE_INTERNAL_HEADERS__
#undef __UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_DETAIL_MATMUL_BATCH_MATMUL_IMPL_H__
#endif