* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file batch_scheduler.h
* \brief
*/
#if !defined(__ASCENDC_INCLUDE_INTERNAL_HEADERS__)
#pragma message("impl/adv_api/detail/matmul/scheduler/batch/batch_scheduler.h is an internal header file and must not be used directly. Functions or variables defined in this file may be removed in the future. Please use \"#include \"adv_api/matmul/matmul.h\"\" and use public functions or variables defined in interface headers files.")
#define __ASCENDC_INCLUDE_INTERNAL_HEADERS__
#define __UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_DETAIL_MATMUL_SCHEDULER_BATCH_BATCH_SCHEDULER_H__
#endif
#ifndef IMPL_MATMUL_SCHEDULER_BATCH_BATCH_SCHEDULER_H
#define IMPL_MATMUL_SCHEDULER_BATCH_BATCH_SCHEDULER_H
#include "batch_scheduler_intf.h"
#include "batch_scheduler_base.h"
namespace AscendC {
namespace Impl {
namespace Detail {
BatchScheduler is considered entirely experimental.
We retain the freedom to make incompatible changes, but do not guarantee the stability.
BatchScheduler is only for internal usage, does not support extension or customized specialization!
*/
template <typename IMPL, class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE, const auto& MM_CFG>
class BatchScheduler<IMPL, A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG,
enable_if_t<!MatmulFeatureTrait<MM_CFG>::IsNeedUB() && DoMatmulNorm(MM_CFG) &&
(A_TYPE::layout != LayoutMode::NONE && (A_TYPE::layout != LayoutMode::NORMAL ||
ToMatmulConfig(MM_CFG).batchMode != BatchMode::SINGLE_LARGE_THAN_L1))>>
: public BatchSchedulerBase<IMPL, A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG>
{
MATMUL_USE_MODULE(BatchLoop);
MATMUL_USE_MODULE(BatchCopyCubeInA);
MATMUL_USE_MODULE(BatchCopyCubeInB);
MATMUL_USE_MODULE(NLoop);
MATMUL_USE_MODULE(MLoop);
MATMUL_USE_MODULE(KLoop);
MATMUL_USE_MODULE(MatmulShapeTiling);
MATMUL_USE_MODULE(MatmulShapeInfo);
MATMUL_USE_MODULE(CubeOutBuffer);
MATMUL_USE_MODULE(BiasScheduler);
MATMUL_USE_MODULE(TBufPoolL0);
MATMUL_USE_MODULE(MmadCompute);
MATMUL_USE_MODULE(LoadToA2);
MATMUL_USE_MODULE(LoadToB2);
MATMUL_USE_MODULE(MatmulUnitFlag);
using TransAT = typename A_TYPE::T;
using TransBT = typename decltype(GetTransBDataType<A_TYPE, B_TYPE, MM_CFG>())::T;
using BiasT = typename BIAS_TYPE::T;
using DstT = typename C_TYPE::T;
public:
using BASE_MODULE = AscendC::Impl::Detail::BatchSchedulerBase<IMPL, A_TYPE, B_TYPE, C_TYPE, BIAS_TYPE, MM_CFG>;
__aicore__ inline BatchScheduler() = default;
__aicore__ inline ~BatchScheduler() = default;
template<class T>
__aicore__ inline void Schedule(const T& dst, bool enPartialSum, uint8_t enAtomic, bool enSequentialWrite,
const uint32_t matrixStrideA, const uint32_t matrixStrideB, const uint32_t matrixStrideC)
{
MATMUL_MODULE(BiasScheduler)->SetBias(MATMUL_MODULE(BiasScheduler)->IsBias() && !enPartialSum);
auto batchOffsetInfo = PrepareOffset();
auto ctx = BASE_MODULE::PrepareContext();
const auto batchLoop = MATMUL_MODULE(BatchLoop);
if constexpr (ToMatmulConfig(MM_CFG).bmmOutMode != BatchOutMode::SINGLE_BATCH) {
batchLoop->SetBatchOutCacheNum(0);
batchLoop->SetBatchOutOffsetNum(0);
}
LocalTensor<BiasT> bias;
if constexpr (!ToMatmulConfig(MM_CFG).isBiasBatch) {
bias = MATMUL_MODULE(BiasScheduler)->CopyIn(MATMUL_MODULE(MatmulShapeTiling)->GetTiling().GetSingleCoreN());
}
for (batchLoop->OuterStart(); !batchLoop->OuterEnd(); batchLoop->OuterNext()) {
if constexpr (ToMatmulConfig(MM_CFG).isBiasBatch) {
bias = MATMUL_MODULE(BiasScheduler)->CopyIn(
MATMUL_MODULE(MatmulShapeTiling)->GetTiling().GetSingleCoreN(),
batchLoop->GetBatchNum(), batchLoop->GetBiasBatchSrcOffset());
}
if constexpr (ToMatmulConfig(MM_CFG).batchMode != BatchMode::BATCH_LESS_THAN_L1) {
auto a1 = MATMUL_MODULE(BatchCopyCubeInA)->AllocTensor();
auto b1 = MATMUL_MODULE(BatchCopyCubeInB)->AllocTensor();
event_t eventIDMte2ToMte1 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_MTE1));
event_t eventIDMToMte1 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::M_MTE1));
auto batchLoop = MATMUL_MODULE(BatchLoop);
for (batchLoop->SplitStart(); !batchLoop->SplitEnd(); batchLoop->SplitNext()) {
uint32_t outerIdxA;
uint32_t outerIdxB;
batchLoop->CalcBatchOuterIdx(outerIdxA, outerIdxB);
MATMUL_MODULE(BatchCopyCubeInA)->BatchLoad(a1, matrixStrideA, outerIdxA,
batchLoop->GetSplitIndex(), batchLoop->GetSplitSize());
MATMUL_MODULE(BatchCopyCubeInB)->BatchLoad(b1, matrixStrideB, outerIdxB,
batchLoop->GetSplitIndex(), batchLoop->GetSplitSize());
SetFlag<HardEvent::MTE2_MTE1>(eventIDMte2ToMte1);
WaitFlag<HardEvent::MTE2_MTE1>(eventIDMte2ToMte1);
auto batchLoop = MATMUL_MODULE(BatchLoop);
for (batchLoop->InnerStart(); !batchLoop->InnerEnd(); batchLoop->InnerNext()) {
BASE_MODULE::isFirstIter_ = true;
if (batchOffsetInfo.setBiasFlag && (batchLoop->GetBatchIndex() % batchOffsetInfo.divisorBias == 1)) {
MATMUL_MODULE(BiasScheduler)->StopBias(bias);
}
UpdateInnerOffset(batchOffsetInfo, ctx);
while (BASE_MODULE::MoveNext()) {
MATMUL_MODULE(CubeOutBuffer)->AllocTensor();
ComputeBatch(a1, b1, bias, enPartialSum, ctx);
BatchScheduler::GetBatchResultImpl(dst, ctx, enAtomic, enSequentialWrite);
SetFlag<HardEvent::M_MTE1>(eventIDMToMte1);
WaitFlag<HardEvent::M_MTE1>(eventIDMToMte1);
}
EndIterate();
}
BASE_MODULE::End();
}
MATMUL_MODULE(BatchCopyCubeInA)->BatchDestroy(a1);
MATMUL_MODULE(BatchCopyCubeInB)->BatchDestroy(b1);
} else {
auto batchLoop = MATMUL_MODULE(BatchLoop);
for (batchLoop->SplitStart(); !batchLoop->SplitEnd(); batchLoop->SplitNext()) {
LocalTensor<TransAT> a1;
LocalTensor<TransBT> b1;
event_t eventIDMte2ToMte1 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::MTE2_MTE1));
event_t eventIDMToMte1 = static_cast<event_t>(GetTPipePtr()->FetchEventID(HardEvent::M_MTE1));
auto splitIdxA = batchLoop->template GetSplitIndex<InputTypeTag::A>();
auto splitIdxB = batchLoop->template GetSplitIndex<InputTypeTag::B>();
MATMUL_MODULE(BatchCopyCubeInA)->BatchLoad(a1, matrixStrideA, batchLoop->GetOuterIndex(),
splitIdxA, batchLoop->GetSplitSize());
MATMUL_MODULE(BatchCopyCubeInB)->BatchLoad(b1, matrixStrideB, batchLoop->GetOuterIndex(),
splitIdxB, batchLoop->GetSplitSize());
SetFlag<HardEvent::MTE2_MTE1>(eventIDMte2ToMte1);
WaitFlag<HardEvent::MTE2_MTE1>(eventIDMte2ToMte1);
auto batchLoop = MATMUL_MODULE(BatchLoop);
for (batchLoop->InnerStart(); !batchLoop->InnerEnd(); batchLoop->InnerNext()) {
BASE_MODULE::isFirstIter_ = true;
if (batchOffsetInfo.setBiasFlag && (batchLoop->GetBatchIndex() % batchOffsetInfo.divisorBias == 1)) {
MATMUL_MODULE(BiasScheduler)->StopBias(bias);
}
UpdateInnerOffset(batchOffsetInfo, ctx);
while (BASE_MODULE::MoveNext()) {
MATMUL_MODULE(CubeOutBuffer)->AllocTensor();
ComputeBatch(a1, b1, bias, enPartialSum, ctx);
BatchScheduler::GetBatchResultImpl(dst, ctx, enAtomic, enSequentialWrite);
SetFlag<HardEvent::M_MTE1>(eventIDMToMte1);
WaitFlag<HardEvent::M_MTE1>(eventIDMToMte1);
}
EndIterate();
}
MATMUL_MODULE(BatchCopyCubeInA)->BatchDestroy(a1);
MATMUL_MODULE(BatchCopyCubeInB)->BatchDestroy(b1);
MATMUL_MODULE(BiasScheduler)->End();
MATMUL_MODULE(CubeOutBuffer)->Destroy();
}
MATMUL_MODULE(BatchCopyCubeInA)->Reset();
MATMUL_MODULE(BatchCopyCubeInB)->Reset();
}
if constexpr (ToMatmulConfig(MM_CFG).isBiasBatch) {
MATMUL_MODULE(BiasScheduler)->Destroy(bias);
}
}
if constexpr (!ToMatmulConfig(MM_CFG).isBiasBatch) {
MATMUL_MODULE(BiasScheduler)->Destroy(bias);
}
}
private:
__aicore__ inline BatchOffsetInfo PrepareOffset()
{
BatchOffsetInfo batchOffsetInfo;
BASE_MODULE::CalcBatchIterateAOffsetInfo(batchOffsetInfo);
BASE_MODULE::CalcBatchIterateBOffsetInfo(batchOffsetInfo);
CalcBatchIterateBiasOffsetInfo(batchOffsetInfo);
return batchOffsetInfo;
}
__aicore__ inline void UpdateInnerOffset(BatchOffsetInfo& batchOffsetInfo, BatchSchedulerContext& ctx)
{
const auto& bL = MATMUL_MODULE(BatchLoop);
if constexpr (ToMatmulConfig(MM_CFG).batchMode != BatchMode::BATCH_LESS_THAN_L1) {
auto batchIndex = bL->GetBatchIndex();
ctx.offsetA = UpdateOffset<A_TYPE>(batchOffsetInfo.alignA, batchIndex, batchOffsetInfo.modA, batchOffsetInfo.divisorA);
ctx.offsetB = UpdateOffset<B_TYPE>(batchOffsetInfo.alignB, batchIndex, batchOffsetInfo.modB, batchOffsetInfo.divisorB);
ctx.offsetBias = batchOffsetInfo.alignBias *
(batchIndex % batchOffsetInfo.modBias + batchIndex / batchOffsetInfo.divisorBias);
if constexpr (ToMatmulConfig(MM_CFG).bmmOutMode != BatchOutMode::SINGLE_BATCH) {
bL->SetBatchOutCacheNum(bL->GetBatchOutCacheNum() + 1);
}
} else {
auto biasIndex = bL->GetBatchIndex();
if (bL->GetBatchA() == bL->GetBatchB()) {
auto batchIndex = bL->GetBatchIndex() % bL->GetSplitBatchNum();
ctx.offsetA = UpdateOffset<A_TYPE>(batchOffsetInfo.alignA, batchIndex, batchOffsetInfo.modA, batchOffsetInfo.divisorA);
ctx.offsetB = UpdateOffset<B_TYPE>(batchOffsetInfo.alignB, batchIndex, batchOffsetInfo.modB, batchOffsetInfo.divisorB);
} else {
auto batchAIndex = bL->GetBatchA() <= bL->GetSplitBatchNum() ? bL->GetBatchIndex()
: bL->GetBatchIndex() % bL->GetSplitBatchNum();
auto batchBIndex = bL->GetBatchB() <= bL->GetSplitBatchNum() ? bL->GetBatchIndex()
: bL->GetBatchIndex() % bL->GetSplitBatchNum();
ctx.offsetA = UpdateOffset<A_TYPE>(batchOffsetInfo.alignA, batchAIndex, batchOffsetInfo.modA, batchOffsetInfo.divisorA);
ctx.offsetB = UpdateOffset<B_TYPE>(batchOffsetInfo.alignB, batchBIndex, batchOffsetInfo.modB, batchOffsetInfo.divisorB);
}
ctx.offsetBias = batchOffsetInfo.alignBias *
(biasIndex % batchOffsetInfo.modBias + biasIndex / batchOffsetInfo.divisorBias);
if constexpr (ToMatmulConfig(MM_CFG).bmmOutMode != BatchOutMode::SINGLE_BATCH) {
bL->SetBatchOutCacheNum(bL->GetBatchOutCacheNum() + 1);
}
}
}
template <typename OFF_TYPE>
__aicore__ inline int32_t UpdateOffset(int32_t alignAB, int32_t batchIndex, int32_t modAB, int32_t divisorAB) {
if constexpr (OFF_TYPE::layout == LayoutMode::NORMAL) {
return alignAB * (batchIndex / divisorAB);
} else {
return alignAB * (batchIndex % modAB + batchIndex / divisorAB);
}
}
__aicore__ inline void ComputeBatch(LocalTensor<TransAT>& a1, LocalTensor<TransBT>& b1, LocalTensor<BiasT>& bias,
bool enPartialSum, BatchSchedulerContext& ctx)
{
if constexpr (ToMatmulConfig(MM_CFG).scheduleType == ScheduleType::OUTER_PRODUCT){
ComputeL0DB(a1, b1, bias, enPartialSum, ctx);
} else if constexpr (IsBasic(MM_CFG)) {
ComputeOneIter(a1, b1, bias, enPartialSum, ctx);
} else {
ComputeMultiIter(a1, b1, bias, enPartialSum, ctx);
}
}
__aicore__ inline void ComputeMultiIter(LocalTensor<TransAT>& a1, LocalTensor<TransBT>& b1, LocalTensor<BiasT>& bias,
bool enPartialSum, BatchSchedulerContext& ctx)
{
BASE_MODULE::InitSplitAParams(ctx.aL0Params);
BASE_MODULE::InitSplitBParams(ctx.bL0Params);
MATMUL_MODULE(KLoop)->OuterStart();
do {
int32_t sL0CInit;
int32_t sL0CLast;
BASE_MODULE::UpdateSplitParams(enPartialSum, ctx.aL0Params, ctx.bL0Params, sL0CInit, sL0CLast);
MATMUL_MODULE(BiasScheduler)->SplitLoad(bias, ctx.bL0Params.axisL0Len,
ctx.offsetBias + MATMUL_MODULE(NLoop)->GetOuterIdx() *
MATMUL_MODULE(MatmulShapeTiling)->GetTiling().GetBaseN());
BASE_MODULE::MacroCompute(a1, b1, ctx, sL0CInit, sL0CLast);
} while (MATMUL_MODULE(KLoop)->OuterNext());
}
__aicore__ inline void ComputeOneIter(LocalTensor<TransAT>& a1, LocalTensor<TransBT>& b1, LocalTensor<BiasT>& bias,
bool enPartialSum, BatchSchedulerContext& ctx)
{
BASE_MODULE::InitSplitAParams(ctx.aL0Params);
BASE_MODULE::InitSplitBParams(ctx.bL0Params);
MATMUL_MODULE(KLoop)->OuterStart();
BASE_MODULE::UpdateSplitParams(ctx.aL0Params, ctx.bL0Params);
int32_t sL0CInit = enPartialSum ? 0 : 1;
MATMUL_MODULE(BiasScheduler)->SplitLoad(bias, ctx.bL0Params.axisL0Len,
ctx.offsetBias + MATMUL_MODULE(NLoop)->GetOuterIdx() * MATMUL_MODULE(MatmulShapeTiling)->GetTiling().GetBaseN());
BASE_MODULE::MacroCompute(a1, b1, ctx, sL0CInit, 1);
}
__aicore__ inline void ComputeL0DB(LocalTensor<TransAT>& a1, LocalTensor<TransBT>& b1, LocalTensor<BiasT>& bias,
bool enPartialSum, BatchSchedulerContext& ctx)
{
BASE_MODULE::InitSplitAParams(ctx.aL0Params);
BASE_MODULE::InitSplitBParams(ctx.bL0Params);
MATMUL_MODULE(KLoop)->OuterStart();
MATMUL_MODULE(KLoop)->InnerStart();
do {
int32_t sL0CInit;
int32_t sL0CLast;
BASE_MODULE::UpdateSplitParams(enPartialSum, ctx.aL0Params, ctx.bL0Params, sL0CInit, sL0CLast);
if constexpr (ToMatmulConfig(MM_CFG).iterateOrder == IterateOrder::ORDER_N) {
MATMUL_MODULE(BiasScheduler)->SplitLoad(bias, ctx.bL0Params.axisL0Len,
ctx.offsetBias + MATMUL_MODULE(NLoop)->GetOuterIdx() *
MATMUL_MODULE(MatmulShapeTiling)->GetTiling().GetBaseN());
}
MATMUL_MODULE(LoadToA2)->Prepare(MATMUL_MODULE(MatmulShapeInfo)->IsTransposeA(),
ctx.aL0Params.kAxisL1Len, ctx.aL0Params.axisL1Len);
MATMUL_MODULE(LoadToB2)->Prepare(MATMUL_MODULE(MatmulShapeInfo)->IsTransposeB(),
ctx.bL0Params.kAxisL1Len);
if constexpr (ToMatmulConfig(MM_CFG).iterateOrder == IterateOrder::ORDER_M) {
ComputeMDb(a1, b1, bias, ctx, sL0CInit, sL0CLast, enPartialSum);
} else {
ComputeNDb(a1, b1, bias, ctx, sL0CInit, sL0CLast, enPartialSum);
MATMUL_MODULE(BiasScheduler)->Free();
}
} while(MATMUL_MODULE(KLoop)->OuterNext());
}
__aicore__ inline void ComputeNDb(LocalTensor<TransAT>& a1, LocalTensor<TransBT>& b1, LocalTensor<BiasT>& bias,
BatchSchedulerContext& ctx, int32_t sL0CInit, int32_t sL0CLast, bool enPartialSum)
{
int32_t aKL1Offset = ctx.aL0Params.kAxisL1Offset;
int32_t bKL1Offset = ctx.bL0Params.kAxisL1Offset;
int32_t computeK = MATMUL_MODULE(KLoop)->GetBaseShape();
int32_t nL0DBLoop = MATMUL_MODULE(MLoop)->GetL0DBLoopNum();
LocalTensor<TransBT> b2 = MATMUL_MODULE(TBufPoolL0)->template GetBuffer<TPosition::B2, TransBT>();
MATMUL_MODULE(LoadToB2)->Load(b2, b1[ctx.offsetB], ctx.bL0Params.axisL1Len,
ctx.bL0Params.kAxisL1Len, ctx.bL0Params.axisL0Len,
computeK, ctx.bL0Params.axisL1Offset,
bKL1Offset, MATMUL_MODULE(MatmulShapeInfo)->IsTransposeB());
int32_t axisL1DbOffset = ctx.aL0Params.axisL1Offset;
for (int32_t idx = 0; idx < nL0DBLoop; ++idx) {
auto& bufferPool = MATMUL_MODULE(TBufPoolL0)->Allocate();
LocalTensor<TransAT> a2 = bufferPool.template GetBuffer<TPosition::A2, TransAT>();
MATMUL_MODULE(LoadToA2)->Load(a2, a1[ctx.offsetA],
ctx.aL0Params.axisL1Len, ctx.aL0Params.kAxisL1Len,
ctx.aL0Params.axisL0Len, computeK, axisL1DbOffset, aKL1Offset,
MATMUL_MODULE(MatmulShapeInfo)->IsTransposeA());
bufferPool.EnQue();
bufferPool.DeQue();
bool cmatrixInitVal;
bool cmatrixSource;
BASE_MODULE::UpdateMmadComputeParams(sL0CInit, cmatrixSource, cmatrixInitVal);
MATMUL_MODULE(MmadCompute)->Compute(MATMUL_MODULE(CubeOutBuffer)->GetTensor()[idx *
MATMUL_MODULE(MatmulShapeTiling)->GetTiling().GetBaseM() *
MATMUL_MODULE(MatmulShapeTiling)->GetTiling().GetBaseN()],
a2, b2, ctx.aL0Params.axisL0Len, computeK, ctx.bL0Params.axisL0Len,
MATMUL_MODULE(MatmulShapeInfo)->IsTransposeA(),
MATMUL_MODULE(MatmulShapeInfo)->IsTransposeB(),
MATMUL_MODULE(MatmulUnitFlag)->GetUnitFlag(sL0CLast),
cmatrixSource,
cmatrixInitVal,
false);
bufferPool.Free();
axisL1DbOffset += MATMUL_MODULE(MatmulShapeTiling)->GetTiling().GetBaseM();
}
}
__aicore__ inline void ComputeMDb(LocalTensor<TransAT>& a1, LocalTensor<TransBT>& b1, LocalTensor<BiasT>& bias,
BatchSchedulerContext& ctx, int32_t sL0CInit, int32_t sL0CLast, bool enPartialSum)
{
int32_t aKL1Offset = ctx.aL0Params.kAxisL1Offset;
int32_t bKL1Offset = ctx.bL0Params.kAxisL1Offset;
int32_t computeK = MATMUL_MODULE(KLoop)->GetBaseShape();
int32_t mL0DbLoop = MATMUL_MODULE(NLoop)->GetL0DBLoopNum();
LocalTensor<TransAT> a2 = MATMUL_MODULE(TBufPoolL0)->template GetBuffer<TPosition::A2, TransAT>();
MATMUL_MODULE(LoadToA2)->Load(a2, a1[ctx.offsetA],
ctx.aL0Params.axisL1Len, ctx.aL0Params.kAxisL1Len, ctx.aL0Params.axisL0Len,
computeK, ctx.aL0Params.axisL1Offset, aKL1Offset,
MATMUL_MODULE(MatmulShapeInfo)->IsTransposeA());
int32_t axisL1DbOffset = ctx.bL0Params.axisL1Offset;
for (int32_t idx = 0; idx < mL0DbLoop; ++idx) {
auto& bufferPool = MATMUL_MODULE(TBufPoolL0)->Allocate();
LocalTensor<TransBT> b2 = bufferPool.template GetBuffer<TPosition::B2, TransBT>();
MATMUL_MODULE(LoadToB2)->Load(b2, b1[ctx.offsetB],
ctx.bL0Params.axisL1Len, ctx.bL0Params.kAxisL1Len, ctx.bL0Params.axisL0Len,
computeK, axisL1DbOffset, bKL1Offset,
MATMUL_MODULE(MatmulShapeInfo)->IsTransposeB());
bufferPool.EnQue();
bufferPool.DeQue();
MATMUL_MODULE(BiasScheduler)->SplitLoad(bias, ctx.bL0Params.axisL0Len,
ctx.offsetBias + (MATMUL_MODULE(NLoop)->GetOuterIdx() + idx) *
MATMUL_MODULE(MatmulShapeTiling)->GetTiling().GetBaseN());
bool cmatrixSource;
bool cmatrixInitVal;
BASE_MODULE::UpdateMmadComputeParams(sL0CInit, cmatrixSource, cmatrixInitVal);
MATMUL_MODULE(MmadCompute)->Compute(MATMUL_MODULE(CubeOutBuffer)->GetTensor()[idx *
MATMUL_MODULE(MatmulShapeTiling)->GetTiling().GetBaseM() *
MATMUL_MODULE(MatmulShapeTiling)->GetTiling().GetBaseN()],
a2, b2,
ctx.aL0Params.axisL0Len, computeK, ctx.bL0Params.axisL0Len,
MATMUL_MODULE(MatmulShapeInfo)->IsTransposeA(),
MATMUL_MODULE(MatmulShapeInfo)->IsTransposeB(),
MATMUL_MODULE(MatmulUnitFlag)->GetUnitFlag(sL0CLast),
cmatrixSource,
cmatrixInitVal,
false);
bufferPool.Free();
axisL1DbOffset += MATMUL_MODULE(MatmulShapeTiling)->GetTiling().GetBaseN();
MATMUL_MODULE(BiasScheduler)->Free();
}
}
__aicore__ inline void CalcBatchIterateBiasOffsetInfo(BatchOffsetInfo& batchOffsetInfo)
{
const auto tiling = MATMUL_MODULE(MatmulShapeTiling)->GetTiling();
if (tiling.GetCLayoutInfoG() == 1 && (tiling.GetALayoutInfoG() != 1 || tiling.GetBLayoutInfoG() != 1)) {
ASSERT((tiling.GetALayoutInfoG() > 0 && tiling.GetBLayoutInfoG() > 0) &&
(tiling.GetCLayoutInfoN() != 1 || (tiling.GetALayoutInfoN() == 1 && tiling.GetBLayoutInfoN() == 1)) &&
(tiling.GetCLayoutInfoB() != 1 || (tiling.GetALayoutInfoB() == 1 && tiling.GetBLayoutInfoB() == 1)));
auto gExtend = tiling.GetALayoutInfoG() != 1 ? tiling.GetALayoutInfoG() : tiling.GetBLayoutInfoG();
batchOffsetInfo.setBiasFlag = true;
batchOffsetInfo.modBias = 1;
batchOffsetInfo.divisorBias = gExtend;
} else if (tiling.GetCLayoutInfoN() == 1 && (tiling.GetALayoutInfoN() != 1 || tiling.GetBLayoutInfoN() != 1)) {
ASSERT((tiling.GetALayoutInfoN() > 0 && tiling.GetBLayoutInfoN() > 0) &&
(tiling.GetCLayoutInfoB() != 1 || (tiling.GetALayoutInfoB() == 1 && tiling.GetBLayoutInfoB() == 1)) &&
(tiling.GetCLayoutInfoG() != 1 || (tiling.GetALayoutInfoG() == 1 && tiling.GetBLayoutInfoG() == 1)));
auto gExtend = tiling.GetALayoutInfoG() != 1 ? tiling.GetALayoutInfoG() : tiling.GetBLayoutInfoG();
auto nExtend = tiling.GetALayoutInfoN() != 1 ? tiling.GetALayoutInfoN() : tiling.GetBLayoutInfoN();
batchOffsetInfo.modBias = gExtend;
batchOffsetInfo.divisorBias = gExtend * nExtend;
} else if (tiling.GetCLayoutInfoB() == 1 && (tiling.GetALayoutInfoB() != 1 || tiling.GetBLayoutInfoB() != 1)) {
ASSERT((tiling.GetALayoutInfoB() > 0 && tiling.GetBLayoutInfoB() > 0) &&
(tiling.GetCLayoutInfoN() != 1 || (tiling.GetALayoutInfoN() == 1 && tiling.GetBLayoutInfoN() == 1)) &&
(tiling.GetCLayoutInfoG() != 1 || (tiling.GetALayoutInfoG() == 1 && tiling.GetBLayoutInfoG() == 1)));
auto gExtend = tiling.GetALayoutInfoG() != 1 ? tiling.GetALayoutInfoG() : tiling.GetBLayoutInfoG();
auto nExtend = tiling.GetALayoutInfoN() != 1 ? tiling.GetALayoutInfoN() : tiling.GetBLayoutInfoN();
auto bExtend = tiling.GetALayoutInfoB() != 1 ? tiling.GetALayoutInfoB() : tiling.GetBLayoutInfoB();
batchOffsetInfo.modBias = gExtend * nExtend;
batchOffsetInfo.divisorBias = gExtend * nExtend * bExtend;
} else {
batchOffsetInfo.modBias = 1;
batchOffsetInfo.divisorBias = 1;
}
if constexpr (!ToMatmulConfig(MM_CFG).isBiasBatch) {
batchOffsetInfo.alignBias = 0;
} else {
batchOffsetInfo.alignBias = CeilAlign(MATMUL_MODULE(MatmulShapeInfo)->
template GetSingleCoreN<false, IsBasic(MM_CFG)>(), AscendCUtils::GetC0Count(sizeof(BiasT)));
}
}
template<class T>
__aicore__ inline void GetBatchResultImpl(const T &dst, const BatchSchedulerContext& ctx,
uint8_t enAtomic, bool enSequentialWrite)
{
const auto batchLoop = MATMUL_MODULE(BatchLoop);
if constexpr (ToMatmulConfig(MM_CFG).bmmOutMode != BatchOutMode::SINGLE_BATCH) {
if (batchLoop->IsNeedNBatchOut()) {
BASE_MODULE::GetBatchResult(dst, ctx, enAtomic, enSequentialWrite);
} else {
auto co1Local = MATMUL_MODULE(CubeOutBuffer)->GetTensor();
MATMUL_MODULE(CubeOutBuffer)->FreeTensor(co1Local);
}
} else {
BASE_MODULE::GetBatchResult(dst[batchLoop->GetDstOffset()], ctx, enAtomic, enSequentialWrite);
}
}
__aicore__ inline void EndIterate()
{
if constexpr (ToMatmulConfig(MM_CFG).bmmOutMode != BatchOutMode::SINGLE_BATCH) {
const auto batchLoop = MATMUL_MODULE(BatchLoop);
if (batchLoop->IsNeedNBatchOut()) {
batchLoop->SetBatchOutOffsetNum(batchLoop->GetBatchOutOffsetNum() + batchLoop->GetBatchOutCacheNum());
batchLoop->SetBatchOutCacheNum(0);
}
}
}
};
}
}
}
#endif
#if defined(__UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_DETAIL_MATMUL_SCHEDULER_BATCH_BATCH_SCHEDULER_H__)
#undef __ASCENDC_INCLUDE_INTERNAL_HEADERS__
#undef __UNDEF_ASCENDC_INCLUDE_INTERNAL_HEADERS_DETAIL_MATMUL_SCHEDULER_BATCH_BATCH_SCHEDULER_H__
#endif