/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
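/*!
* \file mc2_matmul_block.h
* \brief Base-block partitioning, per-core block scheduling and GM offset calculation for the MC2 matmul.
*/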
#ifndef MC2_MATMUL_BLOCK_H
#define MC2_MATMUL_BLOCK_H
namespace AscendC {
// Fractal (NZ) granule in elements: K extents are aligned up to it and NZ row offsets
// advance in steps of it (see CalcGMOffset).
constexpr uint32_t C0_SIZE{16U};
/**
 * Global-memory offsets of the current base block into the A, B, C and bias tensors.
 * The values are element offsets as produced by CalcGMOffset (no byte scaling is applied there).
 *
 * Every field defaults to zero. CalcGMOffset writes a field only for the layouts it handles,
 * e.g. offsetC only for ND/ND_ALIGN outputs and offsetBias only for ND bias. Without the
 * defaults, an unhandled layout would leave an indeterminate value behind.
 */
struct BaseBlockOffset {
    uint64_t offsetA = 0;
    uint64_t offsetB = 0;
    uint64_t offsetC = 0;
    uint64_t offsetBias = 0;
};
/**
 * Scheduling state for the base-block matmul: the block grid geometry, this core's share of
 * blocks, and the extents/offsets of the block currently being processed.
 *
 * All fields default to zero/false. Some of them (e.g. blockCurrIdx, singleCoreM/N) would
 * otherwise be indeterminate until the corresponding Init/Update call has run.
 */
struct BaseBlockArguments
{
    bool isRowOrder = false;      // true: blocks walked row by row; false: column by column (Init: N > 5 * M)
    bool isAtomic = false;        // set by Init to the value of isTransA
    bool isTransA = false;        // cfg.isTransposeA > 0
    bool isTransB = false;        // cfg.isTransposeB > 0
    uint32_t singleCoreM = 0;     // M extent of the current block (baseM or mBaseTail)
    uint32_t singleCoreN = 0;     // N extent of the current block (baseN or nBaseTail)
    uint32_t mBlockCnt = 0;       // number of blocks along M
    uint32_t nBlockCnt = 0;       // number of blocks along N
    uint32_t nBaseTail = 0;       // N extent of the last block column
    uint32_t mBaseTail = 0;       // M extent of the last block row
    uint32_t totalBlockCnt = 0;   // mBlockCnt * nBlockCnt
    uint32_t blockCnt = 0;        // blocks assigned to this core
    uint32_t blockStartIdx = 0;   // row-major index of this core's first block
    uint32_t blockCurrIdx = 0;    // row-major index of the block being processed
    uint32_t preCoreNum = 0;      // totalBlockCnt % usedCoreNum: cores that get one extra block
    uint32_t preCoreStartIdx = 0; // first extra-block core for the next InitBlockWithoutIndex call
    uint64_t mBlockOffset = 0;    // M offset (elements) of the current block
    uint64_t nBlockOffset = 0;    // N offset (elements) of the current block
    uint64_t mCWorkOffset = 0;    // M offset used for C (UpdateBlockParams sets it to mBlockOffset)
};
/**
 * Base-block scheduler for the MC2 matmul. It splits the M x N output into baseM x baseN blocks,
 * distributes them over tiling.usedCoreNum cores, tracks the block this core is working on,
 * and converts that block's position into GM offsets for A, B, C and bias.
 */
class MatmulBaseBlockMC2 {
public:
    __aicore__ inline MatmulBaseBlockMC2() {}

    // Copies tiling/config and derives the block grid. l2Tiling and rankID are ignored.
    __aicore__ inline void Init(Mc2Tiling::RCSTiling& cfg,
                                TCubeTiling& tiling,
                                Mc2Tiling::TileL2Tiling &l2Tiling,
                                uint32_t rankID = 0);

    // Splits blocks over cores; the window of cores receiving an extra block is rotated by `index`.
    __aicore__ inline void InitBlockIndex(uint32_t index = 0);

    // Like InitBlockIndex, but the extra-block window continues from where the previous call ended.
    __aicore__ inline void InitBlockWithoutIndex();

    // Selects the block handled at this core's local position currPos.
    __aicore__ inline void UpdateBlockIndex(uint32_t currPos);

    // Derives extents and M/N offsets of the current block. The tile indices are ignored.
    __aicore__ inline void UpdateBlockParams(int32_t mTileIndex = 0, int32_t nTileIndex = 0);

    // Turns the current block's M/N offsets into GM element offsets for the given tensor layouts.
    template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE>
    __aicore__ inline void CalcGMOffset();

    // Computes this core's block count and first block from the extra-block window [startIdx, endIdx).
    __aicore__ inline void GetBlockStartIdx(uint32_t startIdx, uint32_t endIdx);

public:
    BaseBlockOffset offset_;    // GM offsets of the current block
    BaseBlockArguments args_;   // grid geometry and per-core scheduling state
    TCubeTiling tiling_;        // private copy of the cube tiling
    Mc2Tiling::RCSTiling cfg_;  // private copy of the MC2 configuration
};
/**
 * Caches the tiling and MC2 configuration and derives the base-block grid.
 *
 * @param cfg       MC2 configuration; only isTransposeA / isTransposeB are read.
 * @param tiling    cube tiling; M, N, baseM, baseN are read here, and the whole struct is copied.
 * @param l2Tiling  unused by this block type (kept for interface compatibility).
 * @param rankID    unused by this block type (kept for interface compatibility).
 */
__aicore__ inline void MatmulBaseBlockMC2::Init(Mc2Tiling::RCSTiling& cfg, TCubeTiling& tiling, Mc2Tiling::TileL2Tiling &l2Tiling, uint32_t rankID)
{
    (void)l2Tiling;
    (void)rankID;

    cfg_ = cfg;
    tiling_ = tiling;

    // Block counts per axis and the extent of the last (possibly partial) block on each axis.
    args_.mBlockCnt = DivCeil(tiling.M, tiling.baseM);
    args_.mBaseTail = tiling.M - (args_.mBlockCnt - 1) * tiling.baseM;
    args_.nBlockCnt = DivCeil(tiling.N, tiling.baseN);
    args_.nBaseTail = tiling.N - (args_.nBlockCnt - 1) * tiling.baseN;
    args_.totalBlockCnt = args_.mBlockCnt * args_.nBlockCnt;

    // Cursor consumed and advanced by InitBlockWithoutIndex.
    args_.preCoreStartIdx = 0;

    // Walk blocks row by row, unless N exceeds 5 * M, in which case walk them column by
    // column (UpdateBlockIndex then steps by nBlockCnt).
    const bool wideN = tiling_.N > 5 * tiling_.M;
    args_.isRowOrder = !wideN;

    args_.isTransA = cfg.isTransposeA > 0;
    args_.isTransB = cfg.isTransposeB > 0;
    // Atomic mode mirrors transA exactly.
    // NOTE(review): the reason for this coupling is not visible here; confirm with the consumer of isAtomic.
    args_.isAtomic = args_.isTransA;
}
/**
 * Distributes the totalBlockCnt base blocks over tiling_.usedCoreNum cores.
 *
 * Every core gets totalBlockCnt / usedCoreNum blocks. The remainder (preCoreNum) goes one
 * block each to the cores in the cyclic window starting at index * preCoreNum (mod usedCoreNum).
 * Presumably successive calls with increasing `index` therefore spread the remainder over
 * different cores; TODO confirm against the caller.
 *
 * @param index  rotation selector for the extra-block window.
 */
__aicore__ inline void MatmulBaseBlockMC2::InitBlockIndex(uint32_t index)
{
    const auto coreNum = tiling_.usedCoreNum;

    args_.totalBlockCnt = args_.mBlockCnt * args_.nBlockCnt;
    args_.blockCnt = args_.totalBlockCnt / coreNum;
    args_.preCoreNum = args_.totalBlockCnt % coreNum;

    const auto windowBegin = index * args_.preCoreNum % coreNum;
    const auto windowEnd = (windowBegin + args_.preCoreNum) % coreNum;
    GetBlockStartIdx(windowBegin, windowEnd);
}
/**
 * Distributes the base blocks over the cores like InitBlockIndex. The difference is that the
 * window of cores receiving an extra block starts at args_.preCoreStartIdx, and the cursor is
 * then advanced to the window's end. Consecutive calls thus hand the remainder to the next
 * cores in round-robin order.
 */
__aicore__ inline void MatmulBaseBlockMC2::InitBlockWithoutIndex()
{
    const auto coreNum = tiling_.usedCoreNum;

    args_.totalBlockCnt = args_.mBlockCnt * args_.nBlockCnt;
    args_.blockCnt = args_.totalBlockCnt / coreNum;
    args_.preCoreNum = args_.totalBlockCnt % coreNum;

    const auto windowBegin = args_.preCoreStartIdx;
    const auto windowEnd = (windowBegin + args_.preCoreNum) % coreNum;
    // Advance the round-robin cursor for the next call.
    args_.preCoreStartIdx = windowEnd;
    GetBlockStartIdx(windowBegin, windowEnd);
}
/**
* Finalizes this core's block count and computes the index of its first block.
*
* Both callers (InitBlockIndex / InitBlockWithoutIndex) set args_.blockCnt to the even share
* totalBlockCnt / usedCoreNum and args_.preCoreNum to the remainder before calling.
* Cores whose id lies in the cyclic half-open window [startIdx, endIdx) over 0..usedCoreNum-1
* receive one extra block (args_.blockCnt is incremented for them). Blocks are handed out
* contiguously in core-id order, so args_.blockStartIdx equals the total number of blocks owned
* by all lower-numbered cores.
* block_idx is declared outside this file (presumably the AscendC id of the executing core; confirm).
*
* When args_.isRowOrder is false, that linear position is read as a column-major position in
* the mBlockCnt x nBlockCnt grid and converted to the equivalent row-major block index.
*
* @param startIdx  first core of the extra-block window
* @param endIdx    one past the last core of the window, already reduced modulo usedCoreNum
*/
__aicore__ inline void MatmulBaseBlockMC2::GetBlockStartIdx(uint32_t startIdx, uint32_t endIdx)
{
if (startIdx > endIdx) {
// Window wraps around: the extra-block cores are [0, endIdx) and [startIdx, usedCoreNum).
if (block_idx < endIdx) {
// All lower cores also have base+1 blocks: start = block_idx * (base + 1).
args_.blockCnt += 1;
args_.blockStartIdx = block_idx * args_.blockCnt;
} else if (block_idx >= startIdx) {
// Of the lower cores, (startIdx - endIdx) == usedCoreNum - preCoreNum have only the base share:
// start = block_idx * (base + 1) - (usedCoreNum - preCoreNum).
args_.blockCnt += 1;
args_.blockStartIdx = block_idx * args_.blockCnt - (tiling_.usedCoreNum - args_.preCoreNum);
} else {
// Core between the two extra ranges: endIdx lower cores had one extra block.
args_.blockStartIdx = block_idx * args_.blockCnt + endIdx;
}
} else {
// Contiguous window (empty when startIdx == endIdx, i.e. preCoreNum == 0).
if (block_idx < startIdx) {
// No lower core has an extra block.
args_.blockStartIdx = block_idx * args_.blockCnt;
} else if (block_idx >= endIdx) {
// All preCoreNum extra blocks belong to lower cores.
args_.blockStartIdx = block_idx * args_.blockCnt + args_.preCoreNum;
} else {
// Inside the window: startIdx lower cores have only the base share,
// so start = block_idx * (base + 1) - startIdx.
args_.blockCnt += 1;
args_.blockStartIdx = block_idx * args_.blockCnt - startIdx;
}
}
if (!args_.isRowOrder) {
// Column-major position k -> (row k % mBlockCnt, column k / mBlockCnt) -> row-major row * nBlockCnt + column.
auto blockStart = args_.blockStartIdx;
args_.blockStartIdx = blockStart / args_.mBlockCnt + blockStart % args_.mBlockCnt * args_.nBlockCnt;
}
}
/**
 * Sets args_.blockCurrIdx to the row-major index of the block at local position currPos
 * (taken modulo args_.blockCnt) within this core's share.
 *
 * Row order: blocks follow blockStartIdx consecutively.
 * Column order: each step moves one row down (+nBlockCnt). Once the index runs past the grid,
 * `idx % total` returns it to the top rows and `idx / total` advances it by that many columns.
 */
__aicore__ inline void MatmulBaseBlockMC2::UpdateBlockIndex(uint32_t currPos)
{
    const uint32_t step = currPos % args_.blockCnt;
    if (!args_.isRowOrder) {
        uint32_t idx = args_.blockStartIdx + step * args_.nBlockCnt;
        if (idx >= args_.totalBlockCnt) {
            idx = idx % args_.totalBlockCnt + idx / args_.totalBlockCnt;
        }
        args_.blockCurrIdx = idx;
    } else {
        args_.blockCurrIdx = args_.blockStartIdx + step;
    }
}
/**
 * Derives the extents and M/N offsets of block args_.blockCurrIdx (row-major in the
 * mBlockCnt x nBlockCnt grid).
 *
 * Blocks in the last row use mBaseTail for M. Blocks in the last column outside the last row
 * use nBaseTail for N. The very last block uses both tails. Every other block uses baseM x baseN.
 *
 * @param mTileIndex  unused by this block type.
 * @param nTileIndex  unused by this block type.
 */
__aicore__ inline void MatmulBaseBlockMC2::UpdateBlockParams(int32_t mTileIndex, int32_t nTileIndex)
{
    (void)mTileIndex;
    (void)nTileIndex;

    const uint32_t idx = args_.blockCurrIdx;
    const bool isFinalBlock = (idx == args_.totalBlockCnt - 1);
    const bool isBottomRow = (idx >= (args_.mBlockCnt - 1) * args_.nBlockCnt);

    args_.singleCoreM = (isFinalBlock || isBottomRow) ? args_.mBaseTail : tiling_.baseM;

    if (isFinalBlock) {
        args_.singleCoreN = args_.nBaseTail;
    } else if (!isBottomRow && (idx + 1) % args_.nBlockCnt == 0) {
        // Right-most column, not in the last row.
        args_.singleCoreN = args_.nBaseTail;
    } else {
        args_.singleCoreN = tiling_.baseN;
    }

    // Row and column of the block scaled to element offsets; C reuses the M offset.
    args_.mBlockOffset = idx / args_.nBlockCnt * tiling_.baseM;
    args_.nBlockOffset = idx % args_.nBlockCnt * tiling_.baseN;
    args_.mCWorkOffset = args_.mBlockOffset;
}
/**
 * Converts the current block's M/N offsets (args_.mBlockOffset / nBlockOffset / mCWorkOffset)
 * into element offsets into the A, B, C and bias GM tensors, according to each tensor's format.
 *
 * Only ND and NZ are handled for A and B, ND/ND_ALIGN for C, and ND for bias. For any other
 * format the corresponding offset_ field is left untouched.
 */
template <class A_TYPE, class B_TYPE, class C_TYPE, class BIAS_TYPE>
__aicore__ inline void MatmulBaseBlockMC2::CalcGMOffset()
{
    const uint64_t mOffset = args_.mBlockOffset;
    const uint64_t nOffset = args_.nBlockOffset;

    // A. ND: row stride Ka when not transposed; a plain column offset when transposed.
    //    NZ: rows advance by C0_SIZE; transposed, whole fractal columns of aligned Ka.
    if constexpr (A_TYPE::format == CubeFormat::ND) {
        offset_.offsetA = args_.isTransA ? mOffset : mOffset * tiling_.Ka;
    } else if constexpr (A_TYPE::format == CubeFormat::NZ) {
        offset_.offsetA = args_.isTransA ? mOffset * AlignUp(tiling_.Ka, C0_SIZE) : mOffset * C0_SIZE;
    }

    // B. ND: a plain column offset, or row stride Kb when transposed.
    //    NZ: fractal columns of aligned Kb, or rows of C0_SIZE when transposed.
    if constexpr (B_TYPE::format == CubeFormat::ND) {
        offset_.offsetB = args_.isTransB ? nOffset * tiling_.Kb : nOffset;
    } else if constexpr (B_TYPE::format == CubeFormat::NZ) {
        offset_.offsetB = args_.isTransB ? nOffset * C0_SIZE : nOffset * AlignUp(tiling_.Kb, C0_SIZE);
    }

    // C: row-major with row stride N.
    if constexpr (C_TYPE::format == CubeFormat::ND || C_TYPE::format == CubeFormat::ND_ALIGN) {
        offset_.offsetC = nOffset + args_.mCWorkOffset * tiling_.N;
    }

    // Bias: indexed by N only.
    if constexpr (BIAS_TYPE::format == CubeFormat::ND) {
        offset_.offsetBias = nOffset;
    }
}
}
#endif