/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
 * \file kernel_qbmm_cube.h
* \brief Quantized batch matmul cube kernel (A8W8 fixpipe, Tensor API)
*/
#pragma once
#include "kernel_universal.h"
#if ASC_DEVKIT_MAJOR >= 9
#include "kernel_basic_intf.h"
#else
#include "kernel_operator.h"
#include "kernel_operator_intf.h"
#endif
#include "blaze/gemm/utils/common_utils.h"
#include "blaze/gemm/block/block_scheduler_qbmm.h"
#include "tensor_api/tensor.h"
namespace Blaze {
namespace Gemm {
namespace Kernel {
// Template header used to open the QBMM cube-kernel partial specialization of GemmUniversal.
#define QBMM_CUBE_KERNEL_CLASS_TEM_PARAMS \
template <class ProblemShape, class BlockMmad, class BlockEpilogue, class BlockScheduler>
// Specialization argument list: the four kernel parameters plus an enable_if that selects this
// specialization only when BlockMmad's dispatch policy schedule type is KernelMmadWithScaleFixpipeQuant.
#define QBMM_CUBE_KERNEL_TEM_PARAMS \
ProblemShape, BlockMmad, BlockEpilogue, BlockScheduler, \
AscendC::Std::enable_if_t< \
AscendC::Std::is_same_v<KernelMmadWithScaleFixpipeQuant, typename BlockMmad::DispatchPolicy::ScheduleType>>
// Template header for out-of-class member definitions; expands to the same text as
// QBMM_CUBE_KERNEL_CLASS_TEM_PARAMS.
#define QBMM_CUBE_KERNEL_CLASS_TEMPLATE_DEF_PARAMS template <class ProblemShape, class BlockMmad, class BlockEpilogue, class BlockScheduler>
// Plain template argument list (without the enable_if selector); not referenced in this section.
#define QBMM_CUBE_KERNEL_FUNC_TEMPLATE_PARAMS ProblemShape, BlockMmad, BlockEpilogue, BlockScheduler
// Quantized batch-matmul cube kernel (A8W8 with fixpipe dequantization).
// Selected when BlockMmad's dispatch policy schedule type is KernelMmadWithScaleFixpipeQuant.
// The kernel walks the (possibly broadcast) batch dimensions, rebases the A/B/C/bias GM pointers
// for each batch and lets BlockScheduler hand out M/N tiles that are fed to BlockMmad.
QBMM_CUBE_KERNEL_CLASS_TEM_PARAMS
class GemmUniversal<QBMM_CUBE_KERNEL_TEM_PARAMS> {
public:
    __aicore__ inline GemmUniversal()
    {}
    __aicore__ inline ~GemmUniversal()
    {}

    // Compile-time configuration forwarded from the block-level MMAD.
    static constexpr bool weightNz = BlockMmad::weightNz;  // B is stored in NZ (fractal) format
    static constexpr bool isAtomicAdd = BlockMmad::DispatchPolicy::isAtomicAdd;
    static constexpr bool transA = BlockMmad::transA;
    static constexpr bool transB = BlockMmad::transB;

    using BlockMmadParams = typename BlockMmad::Params;
    using AType = typename BlockMmad::AType;
    using BType = typename BlockMmad::BType;
    using CType = typename BlockMmad::CType;
    using BiasType = typename BlockMmad::BiasType;
    using LayoutA = typename BlockMmad::LayoutA;
    using LayoutB = typename BlockMmad::LayoutB;
    using LayoutC = typename BlockMmad::LayoutC;
    // Per-channel x2 scales are addressed in GM as 64-bit words.
    using X2ScaleType = uint64_t;
    // Scale element type declared by BlockMmad; selects how a per-tensor x2 scale is decoded in Init().
    using ScaleGmType = typename BlockMmad::X2ScaleType;
    static constexpr int64_t C0_SIZE = AscendC::AuxGetC0Size<AType>();
    // Mask applied to the fp32 bit pattern of a per-tensor dequant scale: clears the low 13 bits.
    // NOTE(review): presumably matches the precision the fixpipe dequant unit consumes — confirm.
    static constexpr uint64_t DEQ_SCALE_MUL = 0xFFFFE000;
    // Moves a bf16 bit pattern into the high half of an fp32 bit pattern.
    static constexpr uint32_t LEFT_SHIFT_16 = 16;

    using BlockShape = AscendC::Te::Shape<int64_t, int64_t, int64_t, int64_t>;
    using BlockCoord = AscendC::Te::Coord<int64_t, int64_t, int64_t, int64_t>;
    using BlockSchedulerParams = typename BlockScheduler::Params;
    // Layout factories for the GM views of A, B and C.
    using MakeLayoutA = AscendC::Te::FrameLayoutFormat<LayoutA, AscendC::Std::Int<C0_SIZE>>;
    using MakeLayoutB = AscendC::Te::FrameLayoutFormat<LayoutB, AscendC::Std::Int<C0_SIZE>>;
    using MakeLayoutC = AscendC::Te::FrameLayoutFormat<LayoutC, AscendC::Std::Int<AscendC::AuxGetC0Size<CType>()>>;

    // Host-provided tiling data.
    struct QBMMTiling {
        // 4-D batch extents of A, B and C. A/B extents are divided by the matching C extent to get
        // the per-dimension broadcast step (see ProcessWithBatch).
        uint32_t batchA1;
        uint32_t batchA2;
        uint32_t batchA3;
        uint32_t batchA4;
        uint32_t batchB1;
        uint32_t batchB2;
        uint32_t batchB3;
        uint32_t batchB4;
        uint32_t batchC1;
        uint32_t batchC2;
        uint32_t batchC3;
        uint32_t batchC4;
        uint32_t biasThreeDim;  // 1: bias has a per-batch slice of N elements
        uint32_t x1QuantMode;   // QuantMode of the x1 (A) scale
        uint32_t x2QuantMode;   // QuantMode of the x2 (B) scale
        uint32_t kAL1;          // forwarded to BlockMmad::Init
        uint32_t kBL1;          // forwarded to BlockMmad::Init
        uint32_t nBufferNum;    // forwarded to BlockMmad::Init
        uint32_t baseM;         // L0 tile M
        uint32_t baseN;         // L0 tile N
        uint32_t baseK;         // L0 tile K
        uint32_t isBias;        // 1: bias input present
        uint32_t dbL0C;         // > 1: enable L0C ping-pong
    };

    // Full kernel parameter bundle.
    struct Params {
        ProblemShape problemShape;
        BlockMmadParams mmadParams;
        BlockSchedulerParams schParams;
        QBMMTiling qbmmParams;
    };

public:
    // Resolves GM base pointers, bias flags and the dequant scale (pointer or scalar).
    __aicore__ inline void Init(const Params& params);
    // Kernel entry: initializes, then processes one batch or all batches.
    __aicore__ inline void Run(const Params& params);
    __aicore__ inline void operator()(const Params& params)
    {
        Run(params);
    }

private:
    // Restores the un-offset GM base pointers from params.
    __aicore__ inline void ResetGmAddr(const Params& params);
    // Rebases the GM pointers onto the batch selected by batch{A,B,C}Offset_.
    __aicore__ inline void AddBatchOffset(const Params& params);
    // Processes every tile of the current batch assigned to this core. restBatch is the number of
    // batches still to be processed after this one; isTailRound marks a batch whose tiles reach the
    // last partial round across cores.
    __aicore__ inline void ProcessSingleBatch(
        const Params& params, BlockScheduler& bs, uint64_t restBatch, bool isTailRound);
    // Iterates over the 4-D batch space of C with broadcasting of A and B.
    __aicore__ inline void ProcessWithBatch(const Params& params, BlockScheduler& bs);

    BlockMmad mmadOp_;
    __gm__ AType* aGmBase_{nullptr};
    __gm__ BType* bGmBase_{nullptr};
    __gm__ CType* cGmBase_{nullptr};
    __gm__ BiasType* biasGmBase_{nullptr};
    __gm__ X2ScaleType* scaleGmBase_{nullptr};  // set only for per-channel x2 scales
    bool isBias_{false};
    bool isBiasThreeDim_{false};
    uint64_t scaleScalar_{0};     // fixpipe scale for the per-tensor paths
    uint64_t batchCOffset_{0};    // current batch index into C (and 3-D bias)
    uint64_t batchAOffset_{0};    // current (broadcast) batch index into A
    uint64_t batchBOffset_{0};    // current (broadcast) batch index into B
    bool needUpdateTail_{false};  // once set, tail-tile splitting stays enabled for later batches
};
QBMM_CUBE_KERNEL_CLASS_TEMPLATE_DEF_PARAMS
__aicore__ inline void GemmUniversal<QBMM_CUBE_KERNEL_TEM_PARAMS>::Run(const Params& params)
{
    // Atomic accumulation stays enabled for the whole run and is switched off on exit.
    // NOTE(review): the atomic type is fixed to float independent of CType — confirm isAtomicAdd
    // is only enabled for float outputs.
    if constexpr (isAtomicAdd) {
        AscendC::SetAtomicAdd<float>();
    }
    Init(params);
    BlockScheduler bs(params.problemShape, params.schParams);

    const auto& tiling = params.qbmmParams;
    // L0 tile extents (M, N, K); the fourth (batch) extent is unused here.
    BlockShape l0TileShape{static_cast<int64_t>(tiling.baseM), static_cast<int64_t>(tiling.baseN),
        static_cast<int64_t>(tiling.baseK), 0};
    bool enableL0CPingPong = tiling.dbL0C > 1;
    mmadOp_.Init(params.problemShape, l0TileShape, tiling.kAL1, tiling.kBL1, tiling.nBufferNum,
        static_cast<QuantMode>(tiling.x2QuantMode), isBias_, enableL0CPingPong);

    if (AscendC::Te::Get<MNK_B>(params.problemShape) != 1) {
        ProcessWithBatch(params, bs);
    } else {
        // Single batch: offsets are all zero, no batches follow, and the batch is the tail round.
        AddBatchOffset(params);
        ProcessSingleBatch(params, bs, 0, true);
    }

    if constexpr (isAtomicAdd) {
        AscendC::SetAtomicNone();
    }
}
// Resolves kernel state from params: bias flags, the GM base pointers of A/B/C/bias, and the
// dequantization scale. Returns immediately when ASCEND_IS_AIV.
//
// Scale resolution (first matching rule wins):
//   1. x2 per-channel: keep a GM pointer to the per-channel scale vector (scaleGmBase_).
//   2. x1 per-tensor : scaleScalar_ = fp32 bits of (x1Scale[0] * x2Scale[0]), masked.
//   3. x2 per-tensor : decode x2Scale[0] according to ScaleGmType into scaleScalar_.
// Otherwise scaleGmBase_ and scaleScalar_ keep their defaults (nullptr / 0).
QBMM_CUBE_KERNEL_CLASS_TEMPLATE_DEF_PARAMS
__aicore__ inline void GemmUniversal<QBMM_CUBE_KERNEL_TEM_PARAMS>::Init(const Params& params)
{
    if ASCEND_IS_AIV {
        return;
    }
    if (params.qbmmParams.isBias == 1) {
        isBias_ = true;
        biasGmBase_ = reinterpret_cast<__gm__ BiasType*>(params.mmadParams.biasGmAddr);
        if (params.qbmmParams.biasThreeDim == 1) {
            isBiasThreeDim_ = true;
        }
    }
    aGmBase_ = reinterpret_cast<__gm__ AType*>(params.mmadParams.aGmAddr);
    bGmBase_ = reinterpret_cast<__gm__ BType*>(params.mmadParams.bGmAddr);
    cGmBase_ = reinterpret_cast<__gm__ CType*>(params.mmadParams.cGmAddr);

    const QuantMode x1QuantMode = static_cast<QuantMode>(params.qbmmParams.x1QuantMode);
    const QuantMode x2QuantMode = static_cast<QuantMode>(params.qbmmParams.x2QuantMode);
    if (x2QuantMode == QuantMode::PERCHANNEL_MODE) {
        scaleGmBase_ = reinterpret_cast<__gm__ uint64_t*>(params.mmadParams.scaleBGmAddr);
    } else if (x1QuantMode == QuantMode::PERTENSOR_MODE) {
        // Fold the x1 and x2 per-tensor scales into a single fp32 dequant factor.
        // NOTE(review): the x2 scale is read as float here regardless of ScaleGmType — confirm the
        // tiling only selects this path with fp32 scales.
        auto x1Scale = AscendC::GlobalTensor<float>();
        auto x2Scale = AscendC::GlobalTensor<float>();
        x1Scale.SetGlobalBuffer(reinterpret_cast<__gm__ float*>(params.mmadParams.scaleAGmAddr));
        x2Scale.SetGlobalBuffer(reinterpret_cast<__gm__ float*>(params.mmadParams.scaleBGmAddr));
        float deqScale = x1Scale.GetValue(0) * x2Scale.GetValue(0);
        // Reinterpret the fp32 value as its raw bit pattern.
        uint32_t uint32Scale = *(reinterpret_cast<uint32_t*>(&deqScale));
        scaleScalar_ = static_cast<uint64_t>(uint32Scale & DEQ_SCALE_MUL);
    } else if (x2QuantMode == QuantMode::PERTENSOR_MODE) {
        if constexpr (AscendC::IsSameType<ScaleGmType, uint64_t>::value ||
                      AscendC::IsSameType<ScaleGmType, int64_t>::value) {
            // 64-bit scales are used verbatim.
            auto scale = AscendC::GlobalTensor<uint64_t>();
            scale.SetGlobalBuffer(reinterpret_cast<__gm__ uint64_t*>(params.mmadParams.scaleBGmAddr));
            scaleScalar_ = scale.GetValue(0);
        } else if constexpr (AscendC::IsSameType<ScaleGmType, bfloat16_t>::value) {
            // bf16 is the high half of fp32: widen to 32 bits BEFORE shifting so the sign bit is not
            // shifted into a promoted signed int.
            auto scale = AscendC::GlobalTensor<uint16_t>();
            scale.SetGlobalBuffer(reinterpret_cast<__gm__ uint16_t*>(params.mmadParams.scaleBGmAddr));
            uint16_t uint16Scale = scale.GetValue(0);
            uint32_t uint32Scale = static_cast<uint32_t>(uint16Scale) << LEFT_SHIFT_16;
            scaleScalar_ = static_cast<uint64_t>(uint32Scale & DEQ_SCALE_MUL);
        } else {
            // Any other scale type is read as a raw 32-bit pattern.
            auto scale = AscendC::GlobalTensor<uint32_t>();
            scale.SetGlobalBuffer(reinterpret_cast<__gm__ uint32_t*>(params.mmadParams.scaleBGmAddr));
            uint32_t uint32Scale = scale.GetValue(0);
            scaleScalar_ = static_cast<uint64_t>(uint32Scale & DEQ_SCALE_MUL);
        }
    }
}
// Points the A/B/C (and, when bias is enabled, bias) base pointers back at the start of their
// GM buffers as given in params.mmadParams. No-op when ASCEND_IS_AIV.
QBMM_CUBE_KERNEL_CLASS_TEMPLATE_DEF_PARAMS
__aicore__ inline void GemmUniversal<QBMM_CUBE_KERNEL_TEM_PARAMS>::ResetGmAddr(const Params& params)
{
    if ASCEND_IS_AIV {
        return;
    }
    const auto& mmad = params.mmadParams;
    if (isBias_) {
        biasGmBase_ = reinterpret_cast<__gm__ BiasType*>(mmad.biasGmAddr);
    }
    cGmBase_ = reinterpret_cast<__gm__ CType*>(mmad.cGmAddr);
    bGmBase_ = reinterpret_cast<__gm__ BType*>(mmad.bGmAddr);
    aGmBase_ = reinterpret_cast<__gm__ AType*>(mmad.aGmAddr);
}
// Rebases the GM pointers onto the batch selected by batchAOffset_/batchBOffset_/batchCOffset_.
// A and C are dense per batch (M*K and M*N elements). B is dense (N*K) or, for NZ weights,
// padded to C0_SIZE along one dim and BLOCK_CUBE along the other. A 3-D bias advances by N per
// C batch.
QBMM_CUBE_KERNEL_CLASS_TEMPLATE_DEF_PARAMS
__aicore__ inline void GemmUniversal<QBMM_CUBE_KERNEL_TEM_PARAMS>::AddBatchOffset(const Params& params)
{
    ResetGmAddr(params);
    const int64_t m = AscendC::Te::Get<MNK_M>(params.problemShape);
    const int64_t n = AscendC::Te::Get<MNK_N>(params.problemShape);
    const int64_t k = AscendC::Te::Get<MNK_K>(params.problemShape);

    aGmBase_ += batchAOffset_ * m * k;

    if constexpr (weightNz) {
        // transB: K is tiled by C0_SIZE and N by BLOCK_CUBE; otherwise the roles are swapped.
        const int64_t c0Dim = transB ? k : n;
        const int64_t cubeDim = transB ? n : k;
        bGmBase_ += batchBOffset_ * Blaze::Gemm::CeilDiv(c0Dim, C0_SIZE) *
            Blaze::Gemm::CeilDiv(cubeDim, static_cast<int64_t>(BLOCK_CUBE)) * BLOCK_CUBE * C0_SIZE;
    } else {
        bGmBase_ += batchBOffset_ * n * k;
    }

    cGmBase_ += batchCOffset_ * m * n;
    if (isBiasThreeDim_) {
        biasGmBase_ += batchCOffset_ * n;
    }
}
// Walks the 4-D batch space of C in row-major order (C1 outermost, C4 innermost) and processes
// each batch with ProcessSingleBatch. Every C batch dim d advances A (resp. B) by
// batchA_d / batchC_d (resp. batchB_d / batchC_d) times the inner A (B) extents, so a dim of
// extent 1 in A or B is broadcast (step 0).
// NOTE(review): assumes the tiling guarantees batchC_d > 0 and that A/B extents are either 1 or
// equal to the C extent — confirm with the host side.
QBMM_CUBE_KERNEL_CLASS_TEMPLATE_DEF_PARAMS
__aicore__ inline void GemmUniversal<QBMM_CUBE_KERNEL_TEM_PARAMS>::ProcessWithBatch(
    const Params& params, BlockScheduler& bs)
{
    const auto& tiling = params.qbmmParams;

    // Batch strides of C.
    const uint64_t cStride3 = tiling.batchC4;
    const uint64_t cStride2 = static_cast<uint64_t>(tiling.batchC3) * tiling.batchC4;
    const uint64_t cStride1 = tiling.batchC2 * cStride2;

    // Batch strides of A, scaled by the per-dim broadcast factor.
    const uint64_t aInner34 = static_cast<uint64_t>(tiling.batchA3) * tiling.batchA4;
    const uint64_t aStride4 = tiling.batchA4 / tiling.batchC4;
    const uint64_t aStride3 = tiling.batchA4 * static_cast<uint64_t>(tiling.batchA3 / tiling.batchC3);
    const uint64_t aStride2 = aInner34 * (tiling.batchA2 / tiling.batchC2);
    const uint64_t aStride1 = (tiling.batchA2 * aInner34) * (tiling.batchA1 / tiling.batchC1);

    // Batch strides of B, scaled by the per-dim broadcast factor.
    const uint64_t bInner34 = static_cast<uint64_t>(tiling.batchB3) * tiling.batchB4;
    const uint64_t bStride4 = tiling.batchB4 / tiling.batchC4;
    const uint64_t bStride3 = tiling.batchB4 * static_cast<uint64_t>(tiling.batchB3 / tiling.batchC3);
    const uint64_t bStride2 = bInner34 * (tiling.batchB2 / tiling.batchC2);
    const uint64_t bStride1 = (tiling.batchB2 * bInner34) * (tiling.batchB1 / tiling.batchC1);

    // Tiles up to nonTailRoundCnt fill whole rounds across all cores; later tiles form the tail round.
    const uint64_t totalCnt = bs.GetTotalCnt() * AscendC::Te::Get<MNK_B>(params.problemShape);
    const uint64_t nonTailRoundCnt = (totalCnt / AscendC::GetBlockNum()) * AscendC::GetBlockNum();

    uint64_t curBatchC = 1UL;  // 1-based index of the C batch being processed
    for (uint64_t i1 = 0; i1 < tiling.batchC1; ++i1) {
        for (uint64_t i2 = 0; i2 < tiling.batchC2; ++i2) {
            for (uint64_t i3 = 0; i3 < tiling.batchC3; ++i3) {
                for (uint64_t i4 = 0; i4 < tiling.batchC4; ++i4) {
                    batchCOffset_ = i1 * cStride1 + i2 * cStride2 + i3 * cStride3 + i4;
                    batchAOffset_ = i1 * aStride1 + i2 * aStride2 + i3 * aStride3 + i4 * aStride4;
                    batchBOffset_ = i1 * bStride1 + i2 * bStride2 + i3 * bStride3 + i4 * bStride4;
                    // Re-queried each batch: the scheduler state may change inside ProcessSingleBatch.
                    const bool isTailRound = curBatchC * bs.GetTotalCnt() > nonTailRoundCnt;
                    AddBatchOffset(params);
                    ProcessSingleBatch(
                        params, bs, (AscendC::Te::Get<MNK_B>(params.problemShape) - curBatchC), isTailRound);
                    ++curBatchC;
                }
            }
        }
    }
}
// Processes every M/N tile that BlockScheduler assigns to this core for the batch currently
// selected by aGmBase_/bGmBase_/cGmBase_ (and biasGmBase_), calling mmadOp_ once per tile.
//   restBatch   - callers pass the number of batches still to be processed after this one
//                 (0 for the single-batch path).
//   isTailRound - callers set this when the cumulative tile count through this batch exceeds the
//                 largest multiple of the core count.
// After the tile loop, bs.UpdateNextBatchBlockRoundParams() is called.
QBMM_CUBE_KERNEL_CLASS_TEMPLATE_DEF_PARAMS
__aicore__ inline void GemmUniversal<QBMM_CUBE_KERNEL_TEM_PARAMS>::ProcessSingleBatch(
const Params& params, BlockScheduler& bs, uint64_t restBatch, bool isTailRound)
{
const int64_t m = AscendC::Te::Get<MNK_M>(params.problemShape);
const int64_t n = AscendC::Te::Get<MNK_N>(params.problemShape);
const int64_t k = AscendC::Te::Get<MNK_K>(params.problemShape);
// Every tile covers the full K extent, so the K coordinate of each slice is always 0.
constexpr int64_t kPos = 0;
auto layoutA = MakeLayoutA{}(m, k);
auto layoutB = MakeLayoutB{}(k, n);
auto layoutC = MakeLayoutC{}(m, n);
auto gmA = AscendC::Te::MakeTensor(AscendC::Te::MakeMemPtr<AscendC::Te::Location::GM>(aGmBase_), layoutA);
auto gmB = AscendC::Te::MakeTensor(AscendC::Te::MakeMemPtr<AscendC::Te::Location::GM>(bGmBase_), layoutB);
auto gmC = AscendC::Te::MakeTensor(AscendC::Te::MakeMemPtr<AscendC::Te::Location::GM>(cGmBase_), layoutC);
auto layoutBias = AscendC::Te::MakeFrameLayout<AscendC::Te::NDExtLayoutPtn>(1L, n);
// Without bias, cGmBase_ is used only as a placeholder address so gmBias can still be built.
// NOTE(review): presumably mmadOp_ ignores the bias tensor when bias is disabled (it was
// initialized with isBias_) — confirm it never reads through this placeholder.
__gm__ BiasType* biasPtr = isBias_ ? biasGmBase_ : reinterpret_cast<__gm__ BiasType*>(cGmBase_);
auto gmBias = AscendC::Te::MakeTensor(AscendC::Te::MakeMemPtr<AscendC::Te::Location::GM>(biasPtr), layoutBias);
const bool isPerChannel =
static_cast<QuantMode>(params.qbmmParams.x2QuantMode) == QuantMode::PERCHANNEL_MODE;
BlockCoord blockIdx;
// Switch the scheduler to the split tail tiles (mTailTile x nTailTile) when the remaining tiles,
// after splitting, fit within the core count. needUpdateTail_ keeps this enabled for all later
// batches once it has been triggered.
if (needUpdateTail_ || (isTailRound && ((bs.GetEndBlockIdx() + 1) + (restBatch * bs.GetTotalCnt())) *
params.schParams.mTailTile * params.schParams.nTailTile <=
AscendC::GetBlockNum())) {
needUpdateTail_ = true;
bs.UpdateTailTile(params.schParams.mTailTile, params.schParams.nTailTile);
}
int64_t mPos = 0L;
int64_t nPos = 0L;
// Per-channel scale view (1 x N); only sliced in the per-channel path.
auto layoutScale = AscendC::Te::MakeFrameLayout<AscendC::Te::NDExtLayoutPtn, AscendC::Te::LayoutTraitDefault<X2ScaleType>>(1, n);
auto gmScale = AscendC::Te::MakeTensor(AscendC::Te::MakeMemPtr<AscendC::Te::Location::GM>(scaleGmBase_), layoutScale);
while (bs.GetTileIdx(blockIdx)) {
BlockShape singleShape = bs.template GetBlockShape<QuantMode::DEFAULT, QuantMode::DEFAULT, weightNz>(blockIdx);
// NOTE(review): this `return` also skips bs.UpdateNextBatchBlockRoundParams() below; confirm
// that is intended rather than `break` when this core still has later batches.
if (AscendC::Te::Get<IDX_M_TILEIDX>(singleShape) <= 0 || AscendC::Te::Get<IDX_N_TILEIDX>(singleShape) <= 0) {
return;
}
bs.GetTileCoord(blockIdx, mPos, nPos);
const int64_t curM = AscendC::Te::Get<IDX_M_TILEIDX>(singleShape);
const int64_t curN = AscendC::Te::Get<IDX_N_TILEIDX>(singleShape);
auto gmBlockA = gmA.Slice(AscendC::Te::MakeCoord(mPos, kPos), AscendC::Te::MakeShape(curM, k));
auto gmBlockB = gmB.Slice(AscendC::Te::MakeCoord(kPos, nPos), AscendC::Te::MakeShape(k, curN));
auto gmBlockC = gmC.Slice(AscendC::Te::MakeCoord(mPos, nPos), AscendC::Te::MakeShape(curM, curN));
// The bias slices in the two branches have different shape types ((1, curN) vs (1, 1)), so
// each combination calls mmadOp_ separately.
if (isPerChannel) {
// Per-channel dequant: pass the N-slice of the scale vector matching this tile.
auto gmBlockScale = gmScale.Slice(AscendC::Te::MakeCoord(0UL, nPos), AscendC::Te::MakeShape(1, curN));
if (isBias_) {
auto gmBlockBias = gmBias.Slice(AscendC::Te::MakeCoord(0, nPos), AscendC::Te::MakeShape(1, curN));
mmadOp_(gmBlockA, gmBlockB, gmBlockScale, gmBlockBias, gmBlockC, singleShape);
} else {
auto gmBlockBias = gmBias.Slice(AscendC::Te::MakeCoord(0, 0), AscendC::Te::MakeShape(1, 1));
mmadOp_(gmBlockA, gmBlockB, gmBlockScale, gmBlockBias, gmBlockC, singleShape);
}
} else {
// Per-tensor dequant: pass the scalar resolved in Init().
if (isBias_) {
auto gmBlockBias = gmBias.Slice(AscendC::Te::MakeCoord(0, nPos), AscendC::Te::MakeShape(1, curN));
mmadOp_(gmBlockA, gmBlockB, scaleScalar_, gmBlockBias, gmBlockC, singleShape);
} else {
auto gmBlockBias = gmBias.Slice(AscendC::Te::MakeCoord(0, 0), AscendC::Te::MakeShape(1, 1));
mmadOp_(gmBlockA, gmBlockB, scaleScalar_, gmBlockBias, gmBlockC, singleShape);
}
}
}
bs.UpdateNextBatchBlockRoundParams();
}
}
}
}