/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
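/*!
 * \file block_scheduler_qbmm.h
 * \brief Block scheduler for quantized batch matmul (QuantBatchMatmulV3): maps each core's rounds to M/N tile
 *        coordinates, tile shapes and tile start positions.
 */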
#pragma once
#include "blaze/gemm/utils/layout_utils.h"
#include "blaze/gemm/utils/common_utils.h"
#include "tensor_api/tensor.h"
namespace Blaze {
namespace Gemm {
namespace Block {
template <class ProblemShape_, uint64_t FullLoadMode_, class LayoutA_, class LayoutB_, class AType_>
class BlockSchedulerQuantBatchMatmulV3 {
public:
    int64_t baseM_{0};          // base tile size along M (Params::baseM)
    int64_t baseN_{0};          // base tile size along N (Params::baseN)
    int64_t mCnt_{0};           // number of tiles along M: CeilDiv(m, baseM_)
    int64_t nCnt_{0};           // number of tiles along N: CeilDiv(n, baseN_)
    int64_t totalCnt_{0};       // tiles per batch: mCnt_ * nCnt_
    // Tail re-splitting. M side is used only when A is not transposed; N side only when B is transposed.
    // Tiles [0, *BaseNormCnt_) have the base size, later tiles use *BaseTailMain_,
    // and the very last tile uses *BaseTailLast_.
    int64_t mBaseNormCnt_{0};
    int64_t nBaseNormCnt_{0};
    int64_t mBaseTailMain_{0};
    int64_t nBaseTailMain_{0};
    int64_t mBaseTailLast_{0};
    int64_t nBaseTailLast_{0};
    int64_t mCoreNum_{0};       // M tiles per window row: min(WINDOW_LEN, mCnt_)
    int64_t mTailCoreNum_{0};   // M tiles in the final window row: mCnt_ - mCoreNum_ * mainRow_
    int64_t blockIdx_{AscendC::GetBlockIdx() / AscendC::GetTaskRation()};
    int64_t blockNum_{AscendC::GetBlockNum()};
    int64_t startBlockIdx_{0};  // core that receives the first tile of the current batch
    int64_t endBlockIdx_{0};    // core that receives the last tile of the current batch
    int64_t roundIdx_{0};       // next round to hand out; advanced by GetTileIdx
    int64_t round_{0};          // number of rounds this core runs for the current batch
    int64_t mTailTile_{1};      // last-round split factor of a tile along M
    int64_t nTailTile_{1};      // last-round split factor of a tile along N
    int64_t totalTailTile_{1};  // mTailTile_ * nTailTile_ (1 = no last-round split for this core)
    int64_t mainRow_{0};        // window rows walked with mCoreNum_ M tiles; the rest form one tail row
    int64_t mSplitAddrOffset_{0};  // M offset of this core's sub-block inside its tile (set by GetBlockShape)
    int64_t nSplitAddrOffset_{0};  // N offset of this core's sub-block inside its tile (set by GetBlockShape)
    using BlockShape = AscendC::Te::Shape<int64_t, int64_t, int64_t, int64_t>;
    using BlockCoord = AscendC::Te::Coord<int64_t, int64_t, int64_t, int64_t>;
    using ProblemShape = ProblemShape_;
    using AType = AType_;
    static constexpr bool transA = IsTrans<LayoutA_>::value;
    static constexpr bool transB = IsTrans<LayoutB_>::value;
    static constexpr int64_t WINDOW_LEN = 4;  // maximum number of M tiles per window row
    struct Params {
        int64_t baseM;              // base tile size along M
        int64_t baseN;              // base tile size along N
        int64_t mTailTile;          // not read by the constructor; last-round splitting is set via UpdateTailTile
        int64_t nTailTile;          // not read by the constructor; last-round splitting is set via UpdateTailTile
        int64_t mBaseTailSplitCnt;  // number of trailing M tiles covering the M tail (used when !transA)
        int64_t nBaseTailSplitCnt;  // number of trailing N tiles covering the N tail (used when transB)
        int64_t mTailMain;          // size of each non-last trailing M tile when mBaseTailSplitCnt > 1
        int64_t nTailMain;          // size of each non-last trailing N tile when nBaseTailSplitCnt > 1
    };
public:
    /**
     * @brief Build the schedule for one batch of an (m, n) problem.
     *
     * Tiles are dealt round-robin starting at core 0. The last tile lands on endBlockIdx_, and every
     * core after it runs one round fewer.
     *
     * @param shape  Problem shape; only the M and N extents are read.
     * @param params Host tiling parameters (base sizes and tail re-split description).
     */
    __aicore__ inline BlockSchedulerQuantBatchMatmulV3(const ProblemShape& shape, const Params& params)
    {
        const int64_t m = AscendC::Te::Get<MNK_M>(shape);
        const int64_t n = AscendC::Te::Get<MNK_N>(shape);
        baseM_ = params.baseM;
        baseN_ = params.baseN;
        mCnt_ = Blaze::Gemm::CeilDiv(m, baseM_);
        nCnt_ = Blaze::Gemm::CeilDiv(n, baseN_);
        totalCnt_ = mCnt_ * nCnt_;

        mCoreNum_ = Blaze::Gemm::Min(WINDOW_LEN, mCnt_);
        if (mCoreNum_ != 0) {
            mainRow_ = mCnt_ / mCoreNum_ - 1;
        }
        mTailCoreNum_ = mCnt_ - mCoreNum_ * mainRow_;

        endBlockIdx_ = (totalCnt_ - 1) % blockNum_;
        round_ = Blaze::Gemm::CeilDiv(totalCnt_, blockNum_);
        if (blockIdx_ > endBlockIdx_) {
            round_ -= 1;
        }

        if constexpr (!transA) {
            // Trailing mBaseTailSplitCnt tiles share the remainder (mMergeSize): each of them is
            // mBaseTailMain_ wide except the last one, which takes whatever is left.
            mBaseNormCnt_ = mCnt_ - params.mBaseTailSplitCnt;
            const int64_t mMergeSize = m - mBaseNormCnt_ * baseM_;
            mBaseTailMain_ = params.mBaseTailSplitCnt == 1 ? mMergeSize : params.mTailMain;
            mBaseTailLast_ = mMergeSize - (params.mBaseTailSplitCnt - 1) * mBaseTailMain_;
        } else {
            // Only the last M tile is shortened.
            mBaseTailMain_ = m - (mCnt_ - 1) * baseM_;
        }
        if constexpr (transB) {
            nBaseNormCnt_ = nCnt_ - params.nBaseTailSplitCnt;
            const int64_t nMergeSize = n - nBaseNormCnt_ * baseN_;
            nBaseTailMain_ = params.nBaseTailSplitCnt == 1 ? nMergeSize : params.nTailMain;
            nBaseTailLast_ = nMergeSize - (params.nBaseTailSplitCnt - 1) * nBaseTailMain_;
        } else {
            // Only the last N tile is shortened.
            nBaseTailMain_ = n - (nCnt_ - 1) * baseN_;
        }
    }

    /**
     * @brief Split each last-round tile into mTailTile x nTailTile sub-blocks spread over extra cores.
     *
     * The endBlockIdx_ + 1 tiles of the last round (capped at totalCnt_) are handed to
     * consecutive groups of totalTailTile_ cores. endBlockIdx_ therefore moves to the last core of
     * the last group. Cores that now join the last round gain one round; cores past the new end
     * fall back to an unsplit schedule.
     * NOTE(review): assumes tailOriCnt * mTailTile * nTailTile <= blockNum_ (host tiling) - confirm.
     *
     * @param mTailTile Split factor along M.
     * @param nTailTile Split factor along N.
     */
    __aicore__ inline void UpdateTailTile(uint32_t mTailTile, uint32_t nTailTile)
    {
        mTailTile_ = mTailTile;
        nTailTile_ = nTailTile;
        // Widen before multiplying so the product cannot wrap in 32-bit arithmetic.
        totalTailTile_ = static_cast<int64_t>(mTailTile) * static_cast<int64_t>(nTailTile);
        // Number of tiles in the last round; kept signed to avoid mixing signed/unsigned arithmetic.
        const int64_t tailOriCnt = AscendC::Std::min(totalCnt_, endBlockIdx_ + 1);
        const int64_t newEndBlockIdx = endBlockIdx_ + tailOriCnt * (totalTailTile_ - 1);
        if (blockIdx_ > endBlockIdx_ && blockIdx_ <= newEndBlockIdx) {
            round_ += 1;
        }
        if (blockIdx_ > newEndBlockIdx) {
            mTailTile_ = 1;
            nTailTile_ = 1;
            totalTailTile_ = 1;
        }
        endBlockIdx_ = newEndBlockIdx;
    }

    /** @brief Number of tiles per batch (mCnt_ * nCnt_). */
    __aicore__ inline int64_t GetTotalCnt() const
    {
        return totalCnt_;
    }

    /** @brief Core index that receives the last tile of the current batch. */
    __aicore__ inline int64_t GetEndBlockIdx() const
    {
        return endBlockIdx_;
    }

    /**
     * @brief Round the input value up to the smallest power of two.
     *
     * Uses bit smearing. After the decrement, every bit below the highest set bit is filled in,
     * so the increment carries into the next power of two. The smear covers all 64 bits, so the
     * result is exact for any input in [1, 2^62]. An input of 0 yields 0.
     *
     * @param inputValue Input value to be rounded up.
     * @return Smallest power of two greater than or equal to inputValue.
     */
    __aicore__ inline int64_t CeilPowerOfTwo(int64_t inputValue)
    {
        inputValue--;
        inputValue |= inputValue >> 1;
        inputValue |= inputValue >> 2;
        inputValue |= inputValue >> 4;
        inputValue |= inputValue >> 8;
        inputValue |= inputValue >> 16;
        inputValue |= inputValue >> 32;
        inputValue++;
        return inputValue;
    }

    /**
     * @brief Shrink singleCoreM / singleCoreN for tail tiles.
     *
     * The caller initialises both values to the base tile sizes. On the re-split side
     * (M when !transA, N when transB), tiles from *BaseNormCnt_ onward get *BaseTailMain_, and the
     * last tile gets *BaseTailLast_. On the other side, only the last tile is shortened to
     * *BaseTailMain_.
     *
     * @param singleCoreM [in,out] Tile extent along M.
     * @param singleCoreN [in,out] Tile extent along N.
     * @param blockCoord  Tile coordinate; only the M and N components are read.
     */
    __aicore__ inline void CalSingleCoreShapeByCoord(
        int64_t& singleCoreM, int64_t& singleCoreN, BlockCoord blockCoord)
    {
        const int64_t mIdx = AscendC::Te::Get<MNK_M>(blockCoord);
        const int64_t nIdx = AscendC::Te::Get<MNK_N>(blockCoord);
        if constexpr (!transA) {
            if (mIdx >= mBaseNormCnt_) {
                singleCoreM = mIdx < mCnt_ - 1 ? mBaseTailMain_ : mBaseTailLast_;
            }
        } else {
            if (mIdx == mCnt_ - 1) {
                singleCoreM = mBaseTailMain_;
            }
        }
        if constexpr (transB) {
            if (nIdx >= nBaseNormCnt_) {
                singleCoreN = nIdx < nCnt_ - 1 ? nBaseTailMain_ : nBaseTailLast_;
            }
        } else {
            if (nIdx == nCnt_ - 1) {
                singleCoreN = nBaseTailMain_;
            }
        }
    }

    /**
     * @brief Shape and in-tile offset of the work this core performs for blockCoord.
     *
     * Must be called after GetTileIdx, which has already advanced roundIdx_. Hence
     * roundIdx_ == round_ identifies the last round.
     * - Without last-round splitting, or outside the last round, the full tile is returned with zero
     *   offsets.
     * - In the split last round, the tile is cut into mTailTile_ x nTailTile_ sub-blocks and the core
     *   takes the one selected by blockIdx_. Sub-block sizes are first rounded up to the
     *   alignment/power-of-two constraints of the quant mode, FP4 and NZ weight format.
     * - If the selected sub-block starts past the tile, an all-zero shape is returned.
     *
     * The in-tile offsets are also stored in m/nSplitAddrOffset_ for GetTileCoord.
     *
     * @return {rows, cols, mOffset, nOffset}.
     */
    template <QuantMode aQuantMode, QuantMode bQuantMode, bool weightNz = false>
    __aicore__ inline BlockShape GetBlockShape(BlockCoord blockCoord)
    {
        int64_t singleCoreM = baseM_;
        int64_t singleCoreN = baseN_;
        CalSingleCoreShapeByCoord(singleCoreM, singleCoreN, blockCoord);
        // Reset the stored offsets so GetTileCoord never reuses a stale offset left over from an
        // earlier split round (e.g. the last round of the previous batch). They are only non-zero
        // for a split last round.
        mSplitAddrOffset_ = 0;
        nSplitAddrOffset_ = 0;
        if (totalTailTile_ == 1 || roundIdx_ < round_) {
            return {singleCoreM, singleCoreN, 0, 0};
        }
        int64_t singleCoreMSplit = Blaze::Gemm::CeilDiv(singleCoreM, mTailTile_);
        int64_t singleCoreNSplit = Blaze::Gemm::CeilDiv(singleCoreN, nTailTile_);
        // FP4 packs two elements per byte: keep the split size even along the dimension that is
        // contiguous in memory for the given layout.
        if constexpr (IsFp4<AType>() && transA) {
            singleCoreMSplit = (singleCoreMSplit + 1) & ~1;
        }
        if constexpr (IsFp4<AType>() && !transB) {
            singleCoreNSplit = (singleCoreNSplit + 1) & ~1;
        }
        if constexpr ((aQuantMode == QuantMode::PERGROUP_MODE || aQuantMode == QuantMode::PERBLOCK_MODE) && transA) {
            // PER_BLOCK_SIZE, or 2 * PER_BLOCK_SIZE if the split exceeds one block.
            singleCoreMSplit = PER_BLOCK_SIZE << (singleCoreMSplit > PER_BLOCK_SIZE);
        } else if constexpr (aQuantMode == QuantMode::PERBLOCK_MODE) {
            singleCoreMSplit = CeilPowerOfTwo(singleCoreMSplit);
        }
        if constexpr (bQuantMode == QuantMode::PERBLOCK_MODE) {
            if constexpr (!transB) {
                singleCoreNSplit = PER_BLOCK_SIZE << (singleCoreNSplit > PER_BLOCK_SIZE);
            } else {
                singleCoreNSplit = CeilPowerOfTwo(singleCoreNSplit);
            }
        }
        if constexpr (weightNz) {
            if constexpr (!transB) {
                if constexpr (IsFp4<AType>()) {
                    singleCoreNSplit = Align64(singleCoreNSplit);
                } else {
                    singleCoreNSplit = Align32(singleCoreNSplit);
                }
            } else {
                singleCoreNSplit = Align16(singleCoreNSplit);
            }
        }
        const int64_t tailSplitIdx = blockIdx_ % totalTailTile_;
        const int64_t mSplitIdx = tailSplitIdx % mTailTile_;
        int64_t nSplitIdx = 0;
        if constexpr (FullLoadMode_ == A_FULL_LOAD_MODE) {
            nSplitIdx = blockIdx_ / mCnt_ % nTailTile_;
        } else {
            nSplitIdx = tailSplitIdx / mTailTile_;
        }
        mSplitAddrOffset_ = mSplitIdx * singleCoreMSplit;
        nSplitAddrOffset_ = nSplitIdx * singleCoreNSplit;
        if (mSplitAddrOffset_ >= singleCoreM || nSplitAddrOffset_ >= singleCoreN) {
            // Aligned sub-blocks can overshoot a short tile: this core has nothing to do.
            return {0, 0, 0, 0};
        }
        singleCoreM = Blaze::Gemm::Min(singleCoreM - mSplitAddrOffset_, singleCoreMSplit);
        singleCoreN = Blaze::Gemm::Min(singleCoreN - nSplitAddrOffset_, singleCoreNSplit);
        return {singleCoreM, singleCoreN, mSplitAddrOffset_, nSplitAddrOffset_};
    }

    /**
     * @brief Re-plan rounds for the next batch.
     *
     * The first tile of the new batch goes to the core after the previous batch's last core,
     * wrapping to 0. Cores outside the circular range [startBlockIdx_, endBlockIdx_] run one round
     * fewer than CeilDiv(totalCnt_, blockNum_).
     */
    __aicore__ inline void UpdateNextBatchBlockRoundParams()
    {
        startBlockIdx_ = endBlockIdx_ + 1 == blockNum_ ? 0 : (endBlockIdx_ + 1);
        endBlockIdx_ = (totalCnt_ + startBlockIdx_ - 1) % blockNum_;
        roundIdx_ = 0;
        round_ = Blaze::Gemm::CeilDiv(totalCnt_, blockNum_);
        const bool wrapped = startBlockIdx_ > endBlockIdx_;
        if (wrapped && (blockIdx_ > endBlockIdx_ && blockIdx_ < startBlockIdx_)) {
            round_ -= 1;
        } else if (!wrapped && (blockIdx_ > endBlockIdx_ || blockIdx_ < startBlockIdx_)) {
            round_ -= 1;
        }
    }

    /**
     * @brief Produce this core's tile coordinate for the next round.
     *
     * @param blockCoord [out] {mIdx, nIdx, 0, 0} of the tile to process.
     * @return false once all of this core's rounds have been handed out, true otherwise.
     */
    __aicore__ inline bool GetTileIdx(BlockCoord& blockCoord)
    {
        if (roundIdx_ >= round_) {
            return false;
        }
        const bool isLastRound = (roundIdx_ == round_ - 1);
        if constexpr (FullLoadMode_ == A_FULL_LOAD_MODE) {
            // Each core keeps a fixed M tile (blockIdx_ % mCnt_) and walks N across rounds.
            const int64_t blockCoordM = blockIdx_ % mCnt_;
            const int64_t curNTailTile = isLastRound ? nTailTile_ : 1;
            const int64_t blockCoordN = roundIdx_ * blockNum_ / mCnt_ % nCnt_ + blockIdx_ / mCnt_ / curNTailTile;
            roundIdx_++;
            blockCoord = BlockCoord{blockCoordM, blockCoordN, 0, 0};
            return true;
        }
        // In the last round, consecutive groups of totalTailTile_ cores share one tile.
        const int64_t newBlockIdx = isLastRound ? blockIdx_ / totalTailTile_ : blockIdx_;
        int64_t tileIdx = newBlockIdx + roundIdx_ * blockNum_;
        // Rebase the linear tile index to the core that started this batch.
        // NOTE(review): the middle branch appears to target batches whose tiles all fit in a single
        // split round - confirm against the host tiling.
        if (blockIdx_ < startBlockIdx_) {
            tileIdx += blockNum_ - startBlockIdx_;
        } else if (endBlockIdx_ + 1 >= totalTailTile_ * totalCnt_) {
            tileIdx -= startBlockIdx_ / totalTailTile_;
        } else {
            tileIdx -= startBlockIdx_;
        }
        // Windowed traversal: each window row covers mCoreNum_ M tiles across all N tiles. The
        // remaining M tiles (mTailCoreNum_) form a single tail row.
        int64_t rowIdx = tileIdx / nCnt_ / mCoreNum_;
        int64_t blockCoordM = 0;
        int64_t nIdx = 0;
        if (rowIdx < mainRow_) {
            blockCoordM = rowIdx * mCoreNum_ + tileIdx % mCoreNum_;
            nIdx = (tileIdx / mCoreNum_) % nCnt_;
        } else {
            rowIdx = mainRow_;
            const int64_t tailIdx = tileIdx - mainRow_ * mCoreNum_ * nCnt_;
            blockCoordM = mainRow_ * mCoreNum_ + tailIdx % mTailCoreNum_;
            nIdx = (tailIdx / mTailCoreNum_) % nCnt_;
        }
        // Serpentine order: odd window rows traverse N in reverse.
        if (rowIdx & 1) {
            nIdx = nCnt_ - 1 - nIdx;
        }
        roundIdx_++;
        blockCoord = BlockCoord{blockCoordM, nIdx, 0, 0};
        return true;
    }

    /**
     * @brief Element start position of the tile (plus this core's in-tile sub-block offset).
     *
     * Tiles past *BaseNormCnt_ on the re-split side are narrower than the base size, so their
     * start is pulled back by (baseSize - *BaseTailMain_) per narrower tile before them.
     * The sub-block offsets are the ones recorded by the latest GetBlockShape call.
     *
     * @param blockCoord Tile coordinate; only the M and N components are read.
     * @param mPos [out] Start row.
     * @param nPos [out] Start column.
     */
    __aicore__ inline void GetTileCoord(BlockCoord blockCoord, int64_t& mPos, int64_t& nPos) const
    {
        const auto mTileIdx = AscendC::Te::Get<MNK_M>(blockCoord);
        const auto nTileIdx = AscendC::Te::Get<MNK_N>(blockCoord);
        mPos = mTileIdx * baseM_ + mSplitAddrOffset_;
        nPos = nTileIdx * baseN_ + nSplitAddrOffset_;
        if constexpr (!transA) {
            if (mTileIdx > mBaseNormCnt_) {
                mPos -= (mTileIdx - mBaseNormCnt_) * (baseM_ - mBaseTailMain_);
            }
        }
        if constexpr (transB) {
            if (nTileIdx > nBaseNormCnt_) {
                nPos -= (nTileIdx - nBaseNormCnt_) * (baseN_ - nBaseTailMain_);
            }
        }
    }
};
}
}
}